├── .github
└── FUNDING.yml
├── IFL.md
├── LICENSE
├── README.md
├── _COCOGeneration
├── 1.Visualization.ipynb
├── 1.Visualization.py
├── 2.SplitGeneration.ipynb
├── 2.SplitGeneration.py
├── 3.DataGeneration.ipynb
├── 3.DataGeneration.py
├── 4.Analysis.ipynb
├── 4.Analysis.py
├── README.md
├── cocottributes_py3.jbl
├── mscoco_glt_crop.py
└── mscoco_glt_generation.py
├── _ImageNetGeneration
├── 1.MoCoFeature.ipynb
├── 1.SupervisedFeature.ipynb
├── 1.UnsupervisedFeature.ipynb
├── 2.SplitGeneration.ipynb
├── 3.Analysis.ipynb
├── 4.ChangePath.ipynb
├── 5.visualize_tsne.ipynb
├── README.md
├── imagenet_glt_generation.py
└── long-tail-distribution.pytorch
├── __init__.py
├── config
├── COCO_BL.yaml
├── COCO_LT.yaml
├── ColorMNIST_BL.yaml
├── ColorMNIST_LT.yaml
├── ImageNet_BL.yaml
├── ImageNet_LT.yaml
├── __init__.py
└── algorithms_config.yaml
├── data
├── DT_COCO_LT.py
├── DT_ColorMNIST.py
├── DT_ImageNet_LT.py
├── Sampler_ClassAware.py
├── Sampler_MultiEnv.py
├── __init__.py
└── dataloader.py
├── deprecated
├── _ColorMNISTGeneration
│ └── 1.DataGeneration.ipynb
└── __init__.py
├── figure
├── generalized-long-tail.jpg
├── glt_formulation.jpg
├── ifl.jpg
├── ifl_code.jpg
├── imagenet-glt-statistics.jpg
├── imagenet-glt-visualization.jpg
├── imagenet-glt.jpg
├── mscoco-glt-statistics.jpg
├── mscoco-glt-testgeneration.jpg
├── mscoco-glt-visualization.jpg
└── mscoco-glt.jpg
├── main.py
├── models
├── ClassifierCOS.py
├── ClassifierFC.py
├── ClassifierLA.py
├── ClassifierLDAM.py
├── ClassifierLWS.py
├── ClassifierMultiHead.py
├── ClassifierRIDE.py
├── ClassifierTDE.py
├── ResNet.py
├── ResNet_BBN.py
├── ResNet_RIDE.py
└── __init__.py
├── test_baseline.py
├── test_la.py
├── test_tde.py
├── train_baseline.py
├── train_bbn.py
├── train_center_dual.py
├── train_center_dual_mixup.py
├── train_center_ldam_dual.py
├── train_center_ride.py
├── train_center_ride_mixup.py
├── train_center_single.py
├── train_center_tade.py
├── train_center_triple.py
├── train_irm_dual.py
├── train_la.py
├── train_ldam.py
├── train_lff.py
├── train_mixup.py
├── train_ride.py
├── train_stage1.py
├── train_stage2.py
├── train_stage2_ride.py
├── train_tade.py
├── train_tde.py
├── utils
├── __init__.py
├── checkpoint_utils.py
├── general_utils.py
├── logger_utils.py
├── test_loader.py
├── train_loader.py
└── training_utils.py
└── visualization
├── figure1.ipynb
├── figure2.ipynb
└── figure3.ipynb
/.github/FUNDING.yml:
--------------------------------------------------------------------------------
1 | # These are supported funding model platforms
2 |
3 | github: [KaihuaTang]
4 | patreon: # Replace with a single Patreon username
5 | open_collective: # Replace with a single Open Collective username
6 | ko_fi: # Replace with a single Ko-fi username
7 | tidelift: # Replace with a single Tidelift platform-name/package-name e.g., npm/babel
8 | community_bridge: # Replace with a single Community Bridge project-name e.g., cloud-foundry
9 | liberapay: # Replace with a single Liberapay username
10 | issuehunt: # Replace with a single IssueHunt username
11 | lfx_crowdfunding: # Replace with a single LFX Crowdfunding project-name e.g., cloud-foundry
12 | polar: # Replace with a single Polar username
13 | buy_me_a_coffee: tkhchipaomg
14 | thanks_dev: # Replace with a single thanks.dev username
15 | custom: ['https://kaihuatang.github.io/donate']
16 |
--------------------------------------------------------------------------------
/IFL.md:
--------------------------------------------------------------------------------
1 | # Invariant Feature Learning
2 |
3 | ## The framework of the proposed IFL
4 |
5 |

6 | Figure 1. Invariant Feature Learning.
7 |
8 |
9 | ## The pseudo code of IFL
10 |
11 | 
12 | Figure 2. Pseudo Code.
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | NTUItive Dual License
2 |
3 | Copyright (c) 2022 Kaihua Tang
4 |
5 | NANYANG TECHNOLOGICAL UNIVERSITY - NTUITIVE PTE LTD (NTUITIVE) Dual License Agreement
6 | Non-Commercial Use Only
7 | This NTUITIVE License Agreement, including all exhibits ("NTUITIVE-LA") is a legal agreement between you and NTUITIVE (or “we”) located at 71 Nanyang Drive, NTU Innovation Centre, #01-109, Singapore 637722, a wholly owned subsidiary of Nanyang Technological University (“NTU”) for the software or data identified above, which may include source code, and any associated materials, text or speech files, associated media and "online" or electronic documentation and any updates we provide in our discretion (together, the "Software").
8 |
9 | By installing, copying, or otherwise using this Software, found at *************** (researcher to provide link to software), you agree to be bound by the terms of this NTUITIVE-LA. If you do not agree, do not install copy or use the Software. The Software is protected by copyright and other intellectual property laws and is licensed, not sold. If you wish to obtain a commercial royalty bearing license to this software please contact us at XXX (provide email).
10 |
11 | SCOPE OF RIGHTS:
12 | You may use, copy, reproduce, and distribute this Software for any non-commercial purpose, subject to the restrictions in this NTUITIVE-LA. Some purposes which can be non-commercial are teaching, academic research, public demonstrations and personal experimentation. You may also distribute this Software with books or other teaching materials, or publish the Software on websites, that are intended to teach the use of the Software for academic or other non-commercial purposes.
13 | You may not use or distribute this Software or any derivative works in any form for commercial purposes. Examples of commercial purposes would be running business operations, licensing, leasing, or selling the Software, distributing the Software for use with commercial products, using the Software in the creation or use of commercial products or any other activity which purpose is to procure a commercial gain to you or others.
14 | If the Software includes source code or data, you may create derivative works of such portions of the Software and distribute the modified Software for non-commercial purposes, as provided herein.
15 | If you distribute the Software or any derivative works of the Software, you will distribute them under the same terms and conditions as in this license, and you will not grant other rights to the Software or derivative works that are different from those provided by this NTUITIVE-LA.
16 | If you have created derivative works of the Software, and distribute such derivative works, you will cause the modified files to carry prominent notices so that recipients know that they are not receiving the original Software. Such notices must state: (i) that you have changed the Software; and (ii) the date of any changes.
17 |
18 | You may not distribute this Software or any derivative works.
19 | In return, we simply require that you agree:
20 | 1. That you will not remove any copyright or other notices from the Software.
21 | 2. That if any of the Software is in binary format, you will not attempt to modify such portions of the Software, or to reverse engineer or decompile them, except and only to the extent authorized by applicable law.
22 | 3. That NTUITIVE is granted back, without any restrictions or limitations, a non-exclusive, perpetual, irrevocable, royalty-free, assignable and sub-licensable license, to reproduce, publicly perform or display, install, use, modify, post, distribute, make and have made, sell and transfer your modifications to and/or derivative works of the Software source code or data, for any purpose.
23 | 4. That any feedback about the Software provided by you to us is voluntarily given, and NTUITIVE shall be free to use the feedback as it sees fit without obligation or restriction of any kind, even if the feedback is designated by you as confidential.
24 | 5. THAT THE SOFTWARE COMES "AS IS", WITH NO WARRANTIES. THIS MEANS NO EXPRESS, IMPLIED OR STATUTORY WARRANTY, INCLUDING WITHOUT LIMITATION, WARRANTIES OF MERCHANTABILITY OR FITNESS FOR A PARTICULAR PURPOSE, ANY WARRANTY AGAINST INTERFERENCE WITH YOUR ENJOYMENT OF THE SOFTWARE OR ANY WARRANTY OF TITLE OR NON-INFRINGEMENT. THERE IS NO WARRANTY THAT THIS SOFTWARE WILL FULFILL ANY OF YOUR PARTICULAR PURPOSES OR NEEDS. ALSO, YOU MUST PASS THIS DISCLAIMER ON WHENEVER YOU DISTRIBUTE THE SOFTWARE OR DERIVATIVE WORKS.
25 | 6. THAT NEITHER NTUITIVE NOR NTU NOR ANY CONTRIBUTOR TO THE SOFTWARE WILL BE LIABLE FOR ANY DAMAGES RELATED TO THE SOFTWARE OR THIS NTUITIVE-LA, INCLUDING DIRECT, INDIRECT, SPECIAL, CONSEQUENTIAL OR INCIDENTAL DAMAGES, TO THE MAXIMUM EXTENT THE LAW PERMITS, NO MATTER WHAT LEGAL THEORY IT IS BASED ON. ALSO, YOU MUST PASS THIS LIMITATION OF LIABILITY ON WHENEVER YOU DISTRIBUTE THE SOFTWARE OR DERIVATIVE WORKS.
26 | 7. That we have no duty of reasonable care or lack of negligence, and we are not obligated to (and will not) provide technical support for the Software.
27 | 8. That if you breach this NTUITIVE-LA or if you sue anyone over patents that you think may apply to or read on the Software or anyone's use of the Software, this NTUITIVE-LA (and your license and rights obtained herein) terminate automatically. Upon any such termination, you shall destroy all of your copies of the Software immediately. Sections 3, 4, 5, 6, 7, 8, 11 and 12 of this NTUITIVE-LA shall survive any termination of this NTUITIVE-LA.
28 | 9. That the patent rights, if any, granted to you in this NTUITIVE-LA only apply to the Software, not to any derivative works you make.
29 | 10. That the Software may be subject to U.S. export jurisdiction at the time it is licensed to you, and it may be subject to additional export or import laws in other places. You agree to comply with all such laws and regulations that may apply to the Software after delivery of the software to you.
30 | 11. That all rights not expressly granted to you in this NTUITIVE-LA are reserved.
31 | 12. That this NTUITIVE-LA shall be construed and controlled by the laws of the Republic of Singapore without regard to conflicts of law. If any provision of this NTUITIVE-LA shall be deemed unenforceable or contrary to law, the rest of this NTUITIVE-LA shall remain in full effect and interpreted in an enforceable manner that most nearly captures the intent of the original language.
32 |
33 |
34 | Do you accept all of the terms of the preceding NTUITIVE-LA license agreement? If you accept the terms, click “I Agree,” then “Next.” Otherwise click “Cancel.”
35 |
36 | Copyright (c) NTUITIVE. All rights reserved.
37 |
38 |
39 |
40 |
--------------------------------------------------------------------------------
/_COCOGeneration/1.Visualization.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # coding: utf-8
3 |
4 | # In[2]:
5 |
6 |
7 | # Python 2
8 | #from sklearn.externals import joblib
9 | # Python 3
10 | import joblib
11 |
12 | from PIL import Image, ImageDraw
13 | from io import BytesIO
14 | import json
15 | import joblib
16 | import os
17 | import random
18 |
19 | import numpy as np
20 | import matplotlib.pyplot as plt
21 |
22 | random.seed(25)
23 |
24 | COCO_IMAGE_PATH = '/data4/coco2014/images/'
25 | COCO_ATTRIBUTE_PATH = './cocottributes_py3.jbl'
26 | COCO_ANNOTATION_PATH = '/data4/coco2014/annotations/'
27 |
28 |
29 | # In[3]:
30 |
31 |
32 | cocottributes = joblib.load(COCO_ATTRIBUTE_PATH)
33 |
34 |
35 | # In[4]:
36 |
37 |
def load_coco(root_path=COCO_ANNOTATION_PATH):
    """Merge the train2014 and val2014 COCO instance annotations into one dict.

    Returns a dict with 'categories' (taken from the train split) plus the
    concatenated 'images' and 'annotations' lists of both splits.
    """
    def _read_split(split):
        # Read a single instances_<split>.json file from root_path.
        with open(os.path.join(root_path, 'instances_{}.json'.format(split)), 'r') as f:
            return json.load(f)

    train2014 = _read_split('train2014')
    val2014 = _read_split('val2014')
    return {
        'categories': train2014['categories'],
        'images': train2014['images'] + val2014['images'],
        'annotations': train2014['annotations'] + val2014['annotations'],
    }
51 |
52 |
53 | # In[5]:
54 |
55 |
def id_to_path(data_path=COCO_IMAGE_PATH):
    """Map each COCO image id to the full path of its .jpg file.

    Scans the train2014 and val2014 sub-folders; the numeric id is parsed
    from the last underscore-separated token of the file name.
    """
    mapping = {}
    for subset in ('train2014', 'val2014'):
        folder = os.path.join(data_path, subset)
        for fname in os.listdir(folder):
            if not fname.endswith(".jpg"):
                continue
            image_id = int(fname.split('.')[0].split('_')[-1])
            mapping[image_id] = os.path.join(folder, fname)
    return mapping
65 |
66 |
67 | # In[6]:
68 |
69 |
def print_coco_attributes_instance(cocottributes, coco_data, id2path, ex_ind, sname):
    """Visualize the `ex_ind`-th annotated instance of the COCO-Attributes set.

    Looks up the instance's attribute vector, draws its segmentation polygon
    and bounding box on the source image, crops the box, and displays the crop
    with its positive attribute names and category.

    Args:
        cocottributes: dict loaded from the cocottributes .jbl file; uses the
            'attributes', 'ann_vecs' and 'patch_id_to_ann_id' entries.
        coco_data: merged COCO annotation dict (see load_coco).
        id2path: {image_id: image file path} mapping (see id_to_path).
        ex_ind: positional index into the keys of 'ann_vecs'.
        sname: save path forwarded to print_image_with_attributes.
    """
    # List of COCO Attributes, ordered by their dataset id
    attr_details = sorted(cocottributes['attributes'], key=lambda x:x['id'])
    attr_names = [item['name'] for item in attr_details]

    # COCO Attributes instance ID for this example
    coco_attr_id = list(cocottributes['ann_vecs'].keys())[ex_ind]

    # COCO Attribute annotation vector, attributes in order sorted by dataset ID
    instance_attrs = cocottributes['ann_vecs'][coco_attr_id]

    # An attribute is considered positive if the worker vote is > 0.5
    pos_attrs = [a for ind, a in enumerate(attr_names) if instance_attrs[ind] > 0.5]
    coco_dataset_ann_id = cocottributes['patch_id_to_ann_id'][coco_attr_id]

    # Linear scan over all annotations; assumes exactly one match per id
    coco_annotation = [ann for ann in coco_data['annotations'] if ann['id'] == coco_dataset_ann_id][0]

    img_path = id2path[coco_annotation['image_id']]
    img = Image.open(img_path)
    # NOTE(review): only the first polygon of the segmentation is drawn, so
    # multi-part segmentations would be rendered incompletely.
    polygon = coco_annotation['segmentation'][0]
    bbox = coco_annotation['bbox']
    # Overlay the semi-transparent segmentation mask and the bbox rectangle
    ImageDraw.Draw(img, 'RGBA').polygon(polygon, outline=(255,0,0), fill=(255,0,0,50))
    ImageDraw.Draw(img, 'RGBA').rectangle(((bbox[0], bbox[1]), (bbox[0]+bbox[2], bbox[1]+bbox[3])), outline="red")
    img = img.crop((bbox[0], bbox[1], bbox[0]+bbox[2], bbox[1]+bbox[3]))
    img = np.array(img)
    category = [c['name'] for c in coco_data['categories'] if c['id'] == coco_annotation['category_id']][0]

    print_image_with_attributes(img, pos_attrs, category, sname)
98 |
99 |
100 | # In[7]:
101 |
102 |
def print_image_with_attributes(img, attrs, category, sname):
    """Show a cropped instance (numpy array) with its category as the plot
    title and its positive attribute names listed beside the image.

    `sname` is the intended save path; the savefig call is commented out, so
    the figure is currently only shown, not written to disk.
    """
    fig = plt.figure()
    plt.imshow(img)
    plt.axis('off') # clear x- and y-axes
    plt.title(category)
    for ind, a in enumerate(attrs):
        # x: just right of the image (capped at 1000); y: rows spaced by 10%
        # of img.shape[1]. NOTE(review): shape[1] is the width of an HxWxC
        # array — the vertical spacing was probably meant to use shape[0]
        # (height); confirm the intent before changing.
        plt.text(min(img.shape[1]+10, 1000), (ind+1)*img.shape[1]*0.1, a, ha='left')

    #fig.savefig(sname, dpi = 300, bbox_inches='tight')
113 |
114 |
115 | # In[8]:
116 |
117 |
# Build the merged annotation dict and the image-id -> file-path index.
coco_data = load_coco()
id2path = id_to_path()


# In[9]:


# Visualize a handful of example instances; sname is a per-example save-path
# template (only used if saving is re-enabled inside the plotting helper).
ex_inds = [0,10,50,1000,2000,3000,4000,5000]
sname = '/Users/tangkaihua/Desktop/example_cocottributes_annotation{}.jpg'
for ex_ind in ex_inds:
    print_coco_attributes_instance(cocottributes, coco_data, id2path, ex_ind, sname.format(ex_ind))
129 |
130 |
131 | # In[ ]:
132 |
133 |
134 |
135 |
136 |
137 | # In[ ]:
138 |
139 |
140 |
141 |
142 |
143 | # In[ ]:
144 |
145 |
146 |
147 |
148 |
149 | # In[ ]:
150 |
151 |
152 |
153 |
154 |
--------------------------------------------------------------------------------
/_COCOGeneration/3.DataGeneration.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # coding: utf-8
3 |
4 | # In[1]:
5 |
6 |
7 | # Python 2
8 | #from sklearn.externals import joblib
9 | # Python 3
10 | import joblib
11 |
12 | from PIL import Image, ImageDraw
13 | from io import BytesIO
14 | import json
15 | import joblib
16 | import os
17 | import random
18 |
19 | import torch
20 | import numpy as np
21 | import matplotlib.pyplot as plt
22 |
23 | random.seed(25)
24 |
25 | # annotation path
26 | ANNOTATION_LT = './coco_intra_lt_inter_lt.jbl'
27 | ANNOTATION_BL = './coco_intra_lt_inter_bl.jbl'
28 | COCO_IMAGE_PATH = '/data4/coco2014/images/'
29 | ROOT_PATH = '/data4/'
30 |
31 |
32 | # In[2]:
33 |
34 |
def id_to_path(data_path=COCO_IMAGE_PATH):
    """Map each COCO image id to the full path of its .jpg file.

    Args:
        data_path: root image folder containing the 'val2014' and
            'train2014' sub-folders.

    Returns:
        dict: {int image_id: .jpg file path}; the id is parsed from the last
        underscore-separated token of the file name.
    """
    id2path = {}
    subpath = ['val2014', 'train2014']
    for spath in subpath:
        # Use os.path.join consistently: the original listed the directory
        # via raw string concatenation (data_path + spath), which silently
        # breaks when data_path lacks a trailing separator.
        split_dir = os.path.join(data_path, spath)
        for file in os.listdir(split_dir):
            if file.endswith(".jpg"):
                idx = int(file.split('.')[0].split('_')[-1])
                id2path[idx] = os.path.join(split_dir, file)
    return id2path
44 |
45 |
46 | # In[ ]:
47 |
48 |
49 |
50 |
51 |
52 | # In[3]:
53 |
54 |
def generate_images_labels():
    """Materialize one GLT dataset split-by-split: crop every instance image
    and write a single annotation json.

    Reads the module-level globals `cocottributes_all` (loaded .jbl split
    annotation), `id2path`, `OUTPUT_PATH` (per-key image path template) and
    `OUTPUT_ANNO` (output json path).

    Each instance is also assigned an attribute bucket (0 / 1 / 2) by ranking
    a training-frequency-based attribute score into tertiles.
    """
    annotations = {}
    cat2id = cocottributes_all['cat2id']
    id2cat = {i:cat for cat,i in cat2id.items()}
    annotations['cat2id'] = cat2id
    annotations['id2cat'] = id2cat
    # NOTE(review): 'key2path' is initialized but never filled below —
    # per-key paths are stored under annotations[setname]['path'] instead.
    annotations['key2path'] = {}

    # Per-category attribute counts of the training split; used as the base
    # frequency profile when scoring the instances of every split.
    train_count_array = get_att_count(cocottributes_all, 'train')

    for setname in ['train', 'val', 'test_lt', 'test_bl', 'test_bbl']:
        annotations[setname] = {'label':{}, 'frequency':{}, 'attribute':{}, 'path':{}, 'attribute_score':{}}
        # check validity: skip splits that carry no instances
        all_keys = list(cocottributes_all[setname]['label'].keys())
        if len(all_keys) == 0:
            print('Skip {}'.format(setname))
            continue

        # attribute distribution (tensors converted to lists for json dumping)
        annotations[setname]['attribute_dist'] = get_att_count(cocottributes_all, setname)
        for cat_id in annotations[setname]['attribute_dist'].keys():
            annotations[setname]['attribute_dist'][cat_id] = annotations[setname]['attribute_dist'][cat_id].tolist()

        # Score each instance: dot product between the normalized training
        # attribute frequency of its class and its own binarized (> 0.5)
        # attribute vector. Higher score = more frequent attributes.
        for coco_attr_key in all_keys:
            cat_id = cocottributes_all[setname]['label'][coco_attr_key]
            att_array = cocottributes_all[setname]['attribute'][coco_attr_key]
            base_score = normalize_vector(train_count_array[cat_id])
            attr_score = normalize_vector((torch.from_numpy(att_array) > 0.5).float())
            annotations[setname]['attribute_score'][coco_attr_key] = (base_score * attr_score).sum().item()
        # Tertile thresholds over the scores sorted in descending order
        att_scores = list(annotations[setname]['attribute_score'].values())
        att_scores.sort(reverse=True)
        attribute_high_mid_thres = att_scores[len(att_scores) // 3]
        attribute_mid_low_thres = att_scores[len(att_scores) // 3 * 2]

        for i, coco_attr_key in enumerate(all_keys):
            if (i%1000 == 0):
                print('==== Processing : {}'.format(i/len(all_keys)))
            # generate image (cropped + square-padded, saved to OUTPUT_PATH)
            print_coco_attributes_instance(cocottributes_all, id2path, coco_attr_key, OUTPUT_PATH.format(coco_attr_key), setname)
            # generate label / path / frequency records for this instance
            annotations[setname]['label'][coco_attr_key] = cocottributes_all[setname]['label'][coco_attr_key]
            annotations[setname]['path'][coco_attr_key] = OUTPUT_PATH.format(coco_attr_key)
            annotations[setname]['frequency'][coco_attr_key] = cocottributes_all[setname]['frequency'][coco_attr_key]
            # Bucket by score tertile: 0 = high, 1 = mid, 2 = low
            if annotations[setname]['attribute_score'][coco_attr_key] > attribute_high_mid_thres:
                annotations[setname]['attribute'][coco_attr_key] = 0
            elif annotations[setname]['attribute_score'][coco_attr_key] > attribute_mid_low_thres:
                annotations[setname]['attribute'][coco_attr_key] = 1
            else:
                annotations[setname]['attribute'][coco_attr_key] = 2

    with open(OUTPUT_ANNO, 'w') as outfile:
        json.dump(annotations, outfile)
108 |
109 |
def normalize_vector(vector):
    """Scale a tensor so that its elements sum (approximately) to one.

    A tiny epsilon in the denominator makes an all-zero input return zeros
    instead of dividing by zero.
    """
    return vector / (vector.sum() + 1e-9)
113 |
114 |
def get_att_count(cocottributes, setname):
    """Count, per category, how many instances carry each attribute.

    An attribute counts as present when its annotation score is > 0.5.

    Returns:
        dict: {category_id: FloatTensor of per-attribute counts}.
    """
    split_data = cocottributes[setname]
    num_attrs = len(cocottributes['attributes'])

    # one zero-initialised counter per category id
    count_array = {cid: torch.zeros(num_attrs)
                   for cid in set(cocottributes['cat2id'].values())}

    for key, cat_id in split_data['label'].items():
        present = (torch.from_numpy(split_data['attribute'][key]) > 0.5).float()
        count_array[cat_id] = count_array[cat_id] + present

    return count_array
129 |
130 |
def print_coco_attributes_instance(cocottributes, id2path, coco_attr_id, sname, setname):
    """Crop one annotated object (with margin) from its COCO image and save it.

    The bounding box is enlarged by a fixed 50px margin when a side is
    smaller than 100px, or by 20% of the side length otherwise, clamped to
    the image border, then padded to a square before being written to sname.
    """
    def _expanded(start, length, limit):
        # Widen one axis of the bbox, clamped to [0, limit].
        if length < 100:
            lo, hi = start - 50, start + 50 + length
        else:
            lo, hi = start - length * 0.2, start + 1.2 * length
        return max(lo, 0), min(hi, limit)

    coco_annotation = cocottributes['annotations'][coco_attr_id]
    img = Image.open(id2path[coco_annotation['image_id']])
    bbox = coco_annotation['bbox']

    # crop the (expanded) object bounding box
    x1, x2 = _expanded(bbox[0], bbox[2], img.size[0])
    y1, y2 = _expanded(bbox[1], bbox[3], img.size[1])
    img = img.crop((x1, y1, x2, y2))

    # pad the rectangular crop into a square one
    w, h = img.size
    pad_size = (max(w, h) - min(w, h)) / 2
    if w > h:
        img = img.crop((0, -pad_size, w, h + pad_size))
    else:
        img = img.crop((-pad_size, 0, w + pad_size, h))

    # save image
    img.save(sname)
165 |
166 |
167 | # In[ ]:
168 |
169 |
170 |
171 |
172 |
173 | # In[ ]:
174 |
175 |
176 |
177 |
178 |
179 | # In[4]:
180 |
181 |
# Build the dataset from the inter-class balanced annotation
# (coco_intra_lt_inter_bl.jbl) and write it under ROOT_PATH/coco_bl/.
DATA_TYPE = 'coco_bl'
# output path
OUTPUT_PATH = ROOT_PATH + DATA_TYPE + '/images/{}.jpg'
OUTPUT_ANNO = ROOT_PATH + DATA_TYPE + '/annotations/annotation.json'

cocottributes_all = joblib.load(ANNOTATION_BL)

id2path = id_to_path()
generate_images_labels()


# In[5]:


# Build the dataset from the long-tailed annotation
# (coco_intra_lt_inter_lt.jbl) and write it under ROOT_PATH/coco_lt/.
DATA_TYPE = 'coco_lt' #'coco_lt' / 'coco_half_lt'
# output path
OUTPUT_PATH = ROOT_PATH + DATA_TYPE + '/images/{}.jpg'
OUTPUT_ANNO = ROOT_PATH + DATA_TYPE + '/annotations/annotation.json'

cocottributes_all = joblib.load(ANNOTATION_LT)

id2path = id_to_path()
generate_images_labels()
205 |
206 |
207 | # In[ ]:
208 |
209 |
210 |
211 |
212 |
213 | # In[ ]:
214 |
215 |
216 |
217 |
218 |
219 | # In[ ]:
220 |
221 |
222 |
223 |
224 |
225 | # In[ ]:
226 |
227 |
228 |
229 |
230 |
231 | # In[ ]:
232 |
233 |
234 |
235 |
236 |
237 | # In[ ]:
238 |
239 |
240 |
241 |
242 |
243 | # In[ ]:
244 |
245 |
246 |
247 |
248 |
--------------------------------------------------------------------------------
/_COCOGeneration/4.Analysis.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # coding: utf-8
3 |
4 | # In[1]:
5 |
6 |
7 | # Python 2
8 | #from sklearn.externals import joblib
9 | # Python 3
10 | import joblib
11 |
12 | from PIL import Image, ImageDraw
13 | from io import BytesIO
14 | import json
15 | import joblib
16 | import os
17 | import random
18 | import torch
19 |
20 | import numpy as np
21 | import matplotlib.pyplot as plt
22 |
23 | random.seed(25)
24 |
25 |
26 | ANNOTATION_LT = '/data4/coco_lt/annotations/annotation.json'
27 | ANNOTATION_BL = '/data4/coco_bl/annotations/annotation.json'
28 |
29 |
30 | # In[2]:
31 |
32 |
# Load the generated annotation files: the long-tailed (coco_lt) and the
# class-balanced (coco_bl) variants produced by 3.DataGeneration.
lt_annotation = json.load(open(ANNOTATION_LT))
bl_annotation = json.load(open(ANNOTATION_BL))
35 |
36 |
37 | # In[3]:
38 |
39 |
def show_statistics(vals):
    """Plot a bar chart of `vals` sorted in descending order.

    Args:
        vals: sequence of per-class (or per-attribute) counts to visualize.
    """
    # sort the values in descending order and rearrange the data accordingly
    order = np.argsort(vals)[::-1]
    att_values = np.array(vals)[order]
    indexes = np.arange(len(vals))
    # (the unused `bar_width` local of the original was removed)
    plt.bar(indexes, att_values)
    plt.show()
49 |
50 |
51 | # In[4]:
52 |
53 |
54 | lt_annotation.keys()
55 |
56 |
57 | # In[5]:
58 |
59 |
def duplication_check(annotation):
    """Report the total instance count vs. the number of unique keys across
    all splits.

    If the two printed numbers differ, some key appears in more than one
    split (i.e., the splits overlap).
    """
    split_names = ['train', 'val', 'test_lt', 'test_bl', 'test_bbl']

    total = 0
    unique_keys = set()
    for split in split_names:
        labels = annotation[split]['label']
        total += len(labels)
        unique_keys.update(labels.keys())

    print('Counting Result: {}'.format(total))
    print('Number of Images: {}'.format(len(unique_keys)))
72 |
73 |
74 | # In[6]:
75 |
76 |
77 | duplication_check(lt_annotation)
78 |
79 |
80 | # In[7]:
81 |
82 |
83 | duplication_check(bl_annotation)
84 |
85 |
86 | # In[8]:
87 |
88 |
def count_label_dist(dataset, num_cls=29):
    """Return the per-class instance counts of a split.

    Also prints the total number of labelled instances in the split.
    """
    histogram = [0 for _ in range(num_cls)]
    for label in dataset['label'].values():
        histogram[int(label)] += 1
    print(len(dataset['label']))
    return histogram
95 |
96 |
97 | # In[9]:
98 |
99 |
def count_attribuate_dist(dataset, num_cls=204):
    """Sum the per-category attribute distributions of a split.

    Prints the std of the normalized attribute frequencies and returns the
    summed counts sorted in descending order. Falls back to a zero list of
    length `num_cls` when the split has no 'attribute_dist' entry.
    """
    if 'attribute_dist' not in dataset:
        return [0] * num_cls

    per_category = [torch.FloatTensor(v) for v in dataset['attribute_dist'].values()]
    totals = sum(per_category)
    print('std: ', (totals / totals.sum()).std().item())
    return sorted(totals.tolist(), reverse=True)
111 |
112 |
113 | # In[10]:
114 |
115 |
116 | show_statistics(count_label_dist(lt_annotation['train']))
117 |
118 |
119 | # In[11]:
120 |
121 |
122 | show_statistics(count_label_dist(lt_annotation['val']))
123 |
124 |
125 | # In[12]:
126 |
127 |
128 | show_statistics(count_label_dist(lt_annotation['test_lt']))
129 |
130 |
131 | # In[13]:
132 |
133 |
134 | show_statistics(count_label_dist(lt_annotation['test_bl']))
135 |
136 |
137 | # In[14]:
138 |
139 |
140 | show_statistics(count_label_dist(lt_annotation['test_bbl']))
141 |
142 |
143 | # In[ ]:
144 |
145 |
146 |
147 |
148 |
149 | # In[15]:
150 |
151 |
152 | show_statistics(count_label_dist(bl_annotation['train']))
153 |
154 |
155 | # In[16]:
156 |
157 |
158 | show_statistics(count_label_dist(bl_annotation['val']))
159 |
160 |
161 | # In[17]:
162 |
163 |
164 | show_statistics(count_label_dist(bl_annotation['test_lt']))
165 |
166 |
167 | # In[18]:
168 |
169 |
170 | show_statistics(count_label_dist(bl_annotation['test_bl']))
171 |
172 |
173 | # In[19]:
174 |
175 |
176 | show_statistics(count_label_dist(bl_annotation['test_bbl']))
177 |
178 |
179 | # In[ ]:
180 |
181 |
182 |
183 |
184 |
185 | # In[20]:
186 |
187 |
188 | show_statistics(count_attribuate_dist(lt_annotation['train']))
189 |
190 |
191 | # In[21]:
192 |
193 |
194 | show_statistics(count_attribuate_dist(lt_annotation['val']))
195 |
196 |
197 | # In[22]:
198 |
199 |
200 | show_statistics(count_attribuate_dist(lt_annotation['test_lt']))
201 |
202 |
203 | # In[23]:
204 |
205 |
206 | show_statistics(count_attribuate_dist(lt_annotation['test_bl']))
207 |
208 |
209 | # In[24]:
210 |
211 |
212 | show_statistics(count_attribuate_dist(lt_annotation['test_bbl']))
213 |
214 |
215 | # In[ ]:
216 |
217 |
218 |
219 |
220 |
221 | # In[25]:
222 |
223 |
224 | show_statistics(count_attribuate_dist(bl_annotation['train']))
225 |
226 |
227 | # In[26]:
228 |
229 |
230 | show_statistics(count_attribuate_dist(bl_annotation['val']))
231 |
232 |
233 | # In[27]:
234 |
235 |
236 | show_statistics(count_attribuate_dist(bl_annotation['test_lt']))
237 |
238 |
239 | # In[28]:
240 |
241 |
242 | show_statistics(count_attribuate_dist(bl_annotation['test_bl']))
243 |
244 |
245 | # In[29]:
246 |
247 |
248 | show_statistics(count_attribuate_dist(bl_annotation['test_bbl']))
249 |
250 |
251 | # In[ ]:
252 |
253 |
254 |
255 |
256 |
257 | # In[ ]:
258 |
259 |
260 |
261 |
262 |
263 | # In[30]:
264 |
265 |
# Load the raw split annotations (.jbl) so that per-instance attribute
# vectors can be shown next to the generated crops below.
JBL_LT = './coco_intra_lt_inter_lt.jbl'
JBL_BL = './coco_intra_lt_inter_bl.jbl'
cocottributes_lt = joblib.load(JBL_LT)
cocottributes_bl = joblib.load(JBL_BL)
270 |
271 |
272 | # In[48]:
273 |
274 |
def print_image_by_index(annotation, cocottributes, setname='train', index=1):
    """Display the `index`-th instance of a split together with its positive
    attributes and category name.

    Args:
        annotation: generated annotation dict (per-split 'label' and 'path').
        cocottributes: matching raw .jbl dict (per-split 'attribute', plus
            global 'attributes' and 'cat2id').
        setname: split name, e.g. 'train'.
        index: positional index into the split's label keys.
    """
    key = list(annotation[setname]['label'].keys())[index]
    img_path = annotation[setname]['path'][key]
    img = Image.open(img_path)

    att_array = cocottributes[setname]['attribute'][key]
    # List of COCO Attributes, ordered by dataset id
    attr_details = sorted(cocottributes['attributes'], key=lambda x:x['id'])
    attr_names = [item['name'] for item in attr_details]
    # An attribute is positive when its vote score exceeds 0.5
    pos_attrs = [a for ind, a in enumerate(attr_names) if att_array[ind] > 0.5]
    # map the numeric label back to its category name
    id2cat = {i:cat for cat,i in cocottributes['cat2id'].items()}
    category = id2cat[annotation[setname]['label'][key]]
    print_image_with_attributes(img, pos_attrs, category)
289 |
def print_image_with_attributes(img, attrs, category):
    """Show a PIL image titled with its category, listing the positive
    attribute names as text to the right of the image.

    Args:
        img: PIL.Image of the cropped object.
        attrs: list of positive attribute names for this object.
        category: category name used as the plot title.
    """
    plt.figure()
    plt.imshow(img)
    plt.axis('off') # clear x- and y-axes
    plt.title(category)
    for ind, a in enumerate(attrs):
        # PIL's Image.size is (width, height): the text column starts just
        # right of the image (x = width + 10, capped at 1000) and rows are
        # spaced by 10% of the image height. The original used size[1]
        # (height) for the x offset — a mistranslation of the numpy
        # shape[1] (width) in the 1.Visualization variant — which misplaces
        # the text for non-square crops.
        plt.text(min(img.size[0]+10, 1000), (ind+1)*img.size[1]*0.1, a, ha='left')
298 |
299 |
300 | # In[49]:
301 |
302 |
# Example indices to visualize from each dataset.
ex_inds = [0,10,50,1000,2000,3000,4000,5000]


# In[50]:


# Show example crops with attributes from the LT training set.
for ex_ind in ex_inds:
    print_image_by_index(lt_annotation, cocottributes_lt, setname='train', index=ex_ind)


# In[51]:


# Show example crops with attributes from the BL training set.
for ex_ind in ex_inds:
    print_image_by_index(bl_annotation, cocottributes_bl, setname='train', index=ex_ind)
318 |
319 |
320 | # In[ ]:
321 |
322 |
# NOTE(review): removed a stray notebook cell that read `lt_annotation['']`;
# '' is not a key of the annotation dict, so the expression could only
# raise a KeyError.
324 |
325 |
--------------------------------------------------------------------------------
/_COCOGeneration/README.md:
--------------------------------------------------------------------------------
1 | # MSCOCO-GLT Dataset Generation
2 |
3 | ## Introduction
4 | 
5 | Figure 1. Balancing the Attribute Distribution of MSCOCO-Attribute Through Minimized STD.
6 |
7 | For MSCOCO-GLT, although we can directly obtain the attribute annotations from [MSCOCO-Attribute](https://github.com/genp/cocottributes), each object may have multiple attributes, making it prohibitive to strictly balance the attribute distribution: every time we sample an object with a rare attribute, it usually contains co-occurring frequent attributes as well. Therefore, as illustrated in Figure 1 above, we collect a subset of images with a relatively more balanced attribute distribution (by minimizing the STD of attribute frequencies) to serve as the attribute-wise balanced test set. The other long-tailed and class-wise splits are constructed in the same way as in ImageNet-GLT.
8 |
9 | Note that the size of the Train-CBL split is significantly smaller than that of Train-GLT, because one single super-head class, “person”, contains over 40% of the entire data, making class-wise data re-balancing extremely expensive. The worse results on Train-CBL further prove the importance of long-tailed classification in real-world applications, as large long-tailed datasets are better than their small balanced counterparts.
10 |
11 | ## Steps for MSCOCO-GLT Generation
12 | 1. Download the [MSCOCO dataset](https://cocodataset.org/#download)
13 | 2. Download the [MSCOCO-Attribute Annotations](https://github.com/genp/cocottributes) (you can skip this step, as the file has already been included in this folder, i.e., cocottributes_py3.jbl)
14 | 3. Construct training sets and testing sets for GLT, see [SplitGeneration.ipynb](2.SplitGeneration.ipynb). Note that we have three evaluation protocols: 1) Class-wise Long Tail with (Train-GLT, Test-CBL), 2) Attribute-wise Long Tail with (Train-CBL, Test-GBL), and 3) Generalized Long Tail with (Train-GLT, Test-GBL), where Class-wise Long Tail protocol and Generalized Long Tail protocol share the same training set Train-GLT, so there are two annotation files corresponding to two different training sets.
15 | 1. coco_intra_lt_inter_lt.json:
16 | 1. Train-GLT: classes and attribute clusters are both long-tailed
17 | 2. Test-GLT: same as above
18 | 3. Test-CBL: classes are balanced but the attribute clusters are still i.i.d. sampled within each classes, i.e., long-tailed intra-class distribution
19 | 4. Test-GBL: both classes and pretext attributes are balanced
20 | 2. coco_intra_lt_inter_bl.json:
21 | 1. Train-CBL: classes are balanced, but the attribute clusters are still i.i.d. sampled within each classes, i.e., long-tailed intra-class distribution
22 | 2. Test-CBL: same as above
23 | 3. Test-GBL: both classes and pretext attributes are balanced
24 | 4. We further cropped each object from the original MSCOCO images to generate the image data, see [DataGeneration.ipynb](3.DataGeneration.ipynb).
25 |
26 |
27 | ## Download Our Generated Annotation files
28 | You can directly download our generated json files from [the link](https://1drv.ms/u/s!AmRLLNf6bzcitKlpBoTerorv5yaeIw?e=0bw83U).
29 |
30 | Since OneDrive links might be broken in mainland China, we also provide the following alternate link using BaiduNetDisk:
31 |
32 | Link:https://pan.baidu.com/s/1VSpgMVfMFnhQ0RTMzKNIyw Extraction code:1234
33 |
34 | ## Commands
35 | Run the following command to generate the MSCOCO-GLT annotations: 1) coco_intra_lt_inter_lt.json and 2) coco_intra_lt_inter_bl.json
36 | ```
37 | python mscoco_glt_generation.py --data_path YOUR_COCO_IMAGE_FOLDER --anno_path YOUR_COCO_ANNOTATION_FOLDER --attribute_path ./cocottributes_py3.jbl
38 | ```
39 | Run the following command to crop objects from MSCOCO images
40 | ```
41 | python mscoco_glt_crop.py --data_path YOUR_COCO_IMAGE_FOLDER --output_path OUTPUT_FOLDER
42 | ```
43 |
44 | ## Visualization of Attributes and Dataset Statistics
45 |
46 | 
47 | Figure 2. Examples of object with attributes in MSCOCO-GLT.
48 |
49 | 
50 | Figure 3. The algorithm used to generate Test-GBL for MSCOCO-GLT.
51 |
52 |
53 | 
54 | Figure 4. Class distributions and attribute distributions for each split of MSCOCO-GLT benchmark, where the most frequent category is the “person”. Although we cannot strictly balance the attribute distribution in Test-GBL split due to the fact that both head attributes and tail attributes can co-occur in one object, the selected Test-GBL has lower standard deviation of attributes than other splits.
55 |
56 |
57 |
58 |
--------------------------------------------------------------------------------
/_COCOGeneration/cocottributes_py3.jbl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/KaihuaTang/Generalized-Long-Tailed-Benchmarks.pytorch/6317d8feb0ba107e1a64822567ed59115d51c581/_COCOGeneration/cocottributes_py3.jbl
--------------------------------------------------------------------------------
/_COCOGeneration/mscoco_glt_crop.py:
--------------------------------------------------------------------------------
1 | # Python 2
2 | #from sklearn.externals import joblib
3 | # Python 3
4 | import joblib
5 |
6 | import argparse
7 | from PIL import Image, ImageDraw
8 | from io import BytesIO
9 | import json
10 | import joblib
11 | import os
12 | import random
13 |
14 | import torch
15 | import numpy as np
16 | import matplotlib.pyplot as plt
17 |
# annotation path
# Pre-generated split annotation files (output of mscoco_glt_generation.py):
# LT = long-tailed classes (Train-GLT), BL = class-balanced (Train-CBL).
ANNOTATION_LT = './coco_intra_lt_inter_lt.jbl'
ANNOTATION_BL = './coco_intra_lt_inter_bl.jbl'
# NOTE(review): ROOT_PATH is never referenced in this script — possibly dead.
ROOT_PATH = '/data4/'


# ============================================================================
# argument parser
# Parsed at import time; the module cannot be imported without argv support.
parser = argparse.ArgumentParser()
parser.add_argument('--data_path', default='/data4/coco2014/images/', type=str, help='indicate the path of images.')
parser.add_argument('--output_path', default='/data4/coco_glt/', type=str, help='output path.')
parser.add_argument('--seed', default=25, type=int, help='Fix the random seed for reproduction. Default is 25.')
args = parser.parse_args()
31 |
32 |
def id_to_path(data_path=args.data_path):
    """Build a mapping from MSCOCO image id to the image's file path.

    Scans the 'val2014' and 'train2014' sub-folders of ``data_path`` and, for
    every ``*.jpg`` file, parses the trailing integer of the file name
    (e.g. 'COCO_val2014_000000123456.jpg' -> 123456) as the image id.

    Args:
        data_path: root folder containing the 'val2014' and 'train2014'
            sub-folders. Defaults to the --data_path CLI argument.

    Returns:
        dict mapping int image id -> full path of the image file.
    """
    id2path = {}
    subpath = ['val2014', 'train2014']
    for spath in subpath:
        # Use os.path.join instead of raw string concatenation so the scan
        # also works when data_path is given without a trailing slash.
        folder = os.path.join(data_path, spath)
        for file in os.listdir(folder):
            if file.endswith(".jpg"):
                idx = int(file.split('.')[0].split('_')[-1])
                id2path[idx] = os.path.join(folder, file)
    return id2path
42 |
def generate_images_labels(output_img_path, output_ano_path, cocottributes_all):
    # Build the final annotation dict for every split and crop/save every
    # object instance image, then write the annotations to ``output_ano_path``
    # as JSON.
    #
    # Args:
    #   output_img_path: format string ('.../{}.jpg') producing one crop path
    #     per instance key.
    #   output_ano_path: destination path of the generated annotation.json.
    #   cocottributes_all: loaded .jbl dict with per-split 'label' /
    #     'attribute' / 'frequency' maps plus 'cat2id' and 'attributes'.
    #
    # NOTE(review): relies on the module-level global ``id2path`` being built
    # before this function is called (see the script body below).
    annotations = {}
    cat2id = cocottributes_all['cat2id']
    id2cat = {i:cat for cat,i in cat2id.items()}
    annotations['cat2id'] = cat2id
    annotations['id2cat'] = id2cat
    # NOTE(review): 'key2path' is initialised but never filled anywhere in
    # this function; the per-split 'path' maps are used instead.
    annotations['key2path'] = {}

    # Per-category attribute counts of the training split; used as the base
    # distribution when scoring how "head-like" an instance's attributes are.
    train_count_array = get_att_count(cocottributes_all, 'train')

    for setname in ['train', 'val', 'test_lt', 'test_bl', 'test_bbl']:
        annotations[setname] = {'label':{}, 'frequency':{}, 'attribute':{}, 'path':{}, 'attribute_score':{}}
        # check validity
        all_keys = list(cocottributes_all[setname]['label'].keys())
        if len(all_keys) == 0:
            print('Skip {}'.format(setname))
            continue

        # attribute distribution (tensors converted to lists so the dict is
        # JSON-serialisable)
        annotations[setname]['attribute_dist'] = get_att_count(cocottributes_all, setname)
        for cat_id in annotations[setname]['attribute_dist'].keys():
            annotations[setname]['attribute_dist'][cat_id] = annotations[setname]['attribute_dist'][cat_id].tolist()

        # find attribute threshold: score each instance as the dot product of
        # its category's normalised train attribute distribution and its own
        # normalised, binarised (score > 0.5) attribute vector, then take the
        # 1/3 and 2/3 cut points of the descending score list
        for coco_attr_key in all_keys:
            cat_id = cocottributes_all[setname]['label'][coco_attr_key]
            att_array = cocottributes_all[setname]['attribute'][coco_attr_key]
            base_score = normalize_vector(train_count_array[cat_id])
            attr_score = normalize_vector((torch.from_numpy(att_array) > 0.5).float())
            annotations[setname]['attribute_score'][coco_attr_key] = (base_score * attr_score).sum().item()
        att_scores = list(annotations[setname]['attribute_score'].values())
        att_scores.sort(reverse=True)
        attribute_high_mid_thres = att_scores[len(att_scores) // 3]
        attribute_mid_low_thres = att_scores[len(att_scores) // 3 * 2]

        for i, coco_attr_key in enumerate(all_keys):
            if (i%1000 == 0):
                print('==== Processing : {}'.format(i/len(all_keys)))
            # generate image (crop the object from its source image and save)
            print_coco_attributes_instance(cocottributes_all, id2path, coco_attr_key, output_img_path.format(coco_attr_key), setname)
            # generate label
            annotations[setname]['label'][coco_attr_key] = cocottributes_all[setname]['label'][coco_attr_key]
            annotations[setname]['path'][coco_attr_key] = output_img_path.format(coco_attr_key)
            annotations[setname]['frequency'][coco_attr_key] = cocottributes_all[setname]['frequency'][coco_attr_key]
            # bucket each instance into head(0) / mid(1) / tail(2) attribute
            # groups using the thresholds computed above
            if annotations[setname]['attribute_score'][coco_attr_key] > attribute_high_mid_thres:
                annotations[setname]['attribute'][coco_attr_key] = 0
            elif annotations[setname]['attribute_score'][coco_attr_key] > attribute_mid_low_thres:
                annotations[setname]['attribute'][coco_attr_key] = 1
            else:
                annotations[setname]['attribute'][coco_attr_key] = 2

    with open(output_ano_path, 'w') as outfile:
        json.dump(annotations, outfile)
96 |
97 |
def normalize_vector(vector):
    """Return ``vector`` scaled so that its entries sum to 1.

    A tiny epsilon is added to the denominator so an all-zero vector is
    returned unchanged instead of producing NaNs.
    """
    total = vector.sum() + 1e-9
    return vector / total
101 |
102 |
def get_att_count(cocottributes, setname):
    """Count, per category, how often each attribute occurs in one split.

    An attribute is considered present for an instance when its score
    exceeds 0.5.

    Args:
        cocottributes: annotation dict with 'cat2id', 'attributes' and the
            per-split sub-dicts.
        setname: split key, e.g. 'train' or 'test_bl'.

    Returns:
        dict mapping category id -> FloatTensor of per-attribute counts.
    """
    split_data = cocottributes[setname]
    num_attributes = len(cocottributes['attributes'])

    # one zero-initialised counter vector per category id
    count_array = {
        cat_id: torch.zeros(num_attributes).float()
        for cat_id in set(cocottributes['cat2id'].values())
    }

    # accumulate the binarised (score > 0.5) attribute vector of every instance
    for key, cat_id in split_data['label'].items():
        att_array = split_data['attribute'][key]
        count_array[cat_id] = count_array[cat_id] + (torch.from_numpy(att_array) > 0.5).float()

    return count_array
117 |
118 |
def print_coco_attributes_instance(cocottributes, id2path, coco_attr_id, sname, setname):
    """Crop one annotated object out of its MSCOCO image and save it to ``sname``.

    The bounding box is enlarged (a fixed 50px margin per side for small
    edges, a 20% margin for large ones), clipped to the image, and finally
    padded into a square before saving.

    Args:
        cocottributes: annotation dict providing 'annotations'.
        id2path: mapping from image id to image file path.
        coco_attr_id: key of the instance annotation to crop.
        sname: destination file path for the saved crop.
        setname: split name (unused; kept for interface compatibility).
    """
    # List of COCO Attributes
    coco_annotation = cocottributes['annotations'][coco_attr_id]
    img = Image.open(id2path[coco_annotation['image_id']])
    bbox = coco_annotation['bbox']

    # enlarge the bounding box, clipping to the image borders
    if bbox[2] < 100:
        x1, x2 = max(bbox[0] - 50, 0), min(bbox[0] + 50 + bbox[2], img.size[0])
    else:
        x1, x2 = max(bbox[0] - bbox[2] * 0.2, 0), min(bbox[0] + 1.2 * bbox[2], img.size[0])

    if bbox[3] < 100:
        y1, y2 = max(bbox[1] - 50, 0), min(bbox[1] + 50 + bbox[3], img.size[1])
    else:
        y1, y2 = max(bbox[1] - bbox[3] * 0.2, 0), min(bbox[1] + 1.2 * bbox[3], img.size[1])

    img = img.crop((x1, y1, x2, y2))

    # pad the rectangular crop to a square (PIL fills out-of-range regions)
    w, h = img.size
    pad_size = (max(w, h) - min(w, h)) / 2
    square_box = (0, -pad_size, w, h + pad_size) if w > h else (-pad_size, 0, w + pad_size, h)
    img = img.crop(square_box)

    # save image
    img.save(sname)
153 |
154 |
DATA_TYPE = 'coco_bl'

# Fix the random seed for reproducibility. (The --seed argument was parsed
# but previously never applied.)
random.seed(args.seed)

# output path (makedirs with exist_ok replaces the repeated exists/mkdir pairs)
os.makedirs(args.output_path, exist_ok=True)

# Map MSCOCO image ids to file paths once; the mapping is identical for both
# splits, so there is no need to rescan the image folders a second time.
# NOTE: generate_images_labels reads this module-level name.
id2path = id_to_path()


def _generate_split(split_dir, annotation_file):
    """Create <output_path>/<split_dir>/images crops plus annotation.json
    from the given .jbl annotation file."""
    img_dir = os.path.join(args.output_path, split_dir, 'images')
    os.makedirs(img_dir, exist_ok=True)
    output_img_path = img_dir + '/{}.jpg'
    output_ano_path = os.path.join(args.output_path, split_dir, 'annotation.json')
    cocottributes_all = joblib.load(annotation_file)
    generate_images_labels(output_img_path, output_ano_path, cocottributes_all)


# generate Train-GLT
_generate_split('coco_glt', ANNOTATION_LT)

# generate Train-CBL
_generate_split('coco_cbl', ANNOTATION_BL)
195 |
--------------------------------------------------------------------------------
/_ImageNetGeneration/4.ChangePath.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": 8,
6 | "id": "60106164",
7 | "metadata": {},
8 | "outputs": [],
9 | "source": [
10 | "import json\n",
11 | "\n",
12 | "LT_PATH = './imagenet_sup_intra_lt_inter_lt.json'\n",
13 | "BL_PATH = './imagenet_sup_intra_lt_inter_bl.json'\n",
14 | "\n",
15 | "OLD_PATH = '/data4'\n",
16 | "NEW_PATH = '/home/kaihua.tkh/datasets'"
17 | ]
18 | },
19 | {
20 | "cell_type": "code",
21 | "execution_count": 9,
22 | "id": "05c1f024",
23 | "metadata": {},
24 | "outputs": [],
25 | "source": [
26 | "anno_file = json.load(open(LT_PATH))\n",
27 | "\n",
28 | "for split in ('train', 'val', 'test_lt', 'test_bl', 'test_bbl'):\n",
29 | " split_new_data = {}\n",
30 | " split_data = anno_file[split]\n",
31 | " for k_type, val in split_data.items():\n",
32 | " split_new_data[k_type] = {}\n",
33 | " for path, label in val.items():\n",
34 | " assert path.startswith(OLD_PATH) \n",
35 | " new_path = NEW_PATH + path[len(OLD_PATH):]\n",
36 | " split_new_data[k_type][new_path] = label\n",
37 | " anno_file[split] = split_new_data\n",
38 | "\n",
39 | "with open(LT_PATH, 'w') as outfile:\n",
40 | " json.dump(anno_file, outfile)"
41 | ]
42 | },
43 | {
44 | "cell_type": "code",
45 | "execution_count": 10,
46 | "id": "b843650f",
47 | "metadata": {},
48 | "outputs": [],
49 | "source": [
50 | "anno_file = json.load(open(BL_PATH))\n",
51 | "\n",
52 | "for split in ('train', 'val', 'test_lt', 'test_bl', 'test_bbl'):\n",
53 | " split_new_data = {}\n",
54 | " split_data = anno_file[split]\n",
55 | " for k_type, val in split_data.items():\n",
56 | " split_new_data[k_type] = {}\n",
57 | " for path, label in val.items():\n",
58 | " assert path.startswith(OLD_PATH) \n",
59 | " new_path = NEW_PATH + path[len(OLD_PATH):]\n",
60 | " split_new_data[k_type][new_path] = label\n",
61 | " anno_file[split] = split_new_data\n",
62 | "\n",
63 | "with open(BL_PATH, 'w') as outfile:\n",
64 | " json.dump(anno_file, outfile)"
65 | ]
66 | },
67 | {
68 | "cell_type": "code",
69 | "execution_count": null,
70 | "id": "5a7bdaf8",
71 | "metadata": {},
72 | "outputs": [],
73 | "source": []
74 | }
75 | ],
76 | "metadata": {
77 | "kernelspec": {
78 | "display_name": "Python 3",
79 | "language": "python",
80 | "name": "python3"
81 | },
82 | "language_info": {
83 | "codemirror_mode": {
84 | "name": "ipython",
85 | "version": 3
86 | },
87 | "file_extension": ".py",
88 | "mimetype": "text/x-python",
89 | "name": "python",
90 | "nbconvert_exporter": "python",
91 | "pygments_lexer": "ipython3",
92 | "version": "3.6.10"
93 | }
94 | },
95 | "nbformat": 4,
96 | "nbformat_minor": 5
97 | }
98 |
--------------------------------------------------------------------------------
/_ImageNetGeneration/README.md:
--------------------------------------------------------------------------------
1 | # ImageNet-GLT Dataset Generation
2 |
3 | ## Introduction
4 |
5 | 
6 | Figure 1. Collecting an “Attribute-Wise Balanced” Test Set for ImageNet.
7 |
8 | For [ImageNet](https://ieeexplore.ieee.org/abstract/document/5206848) dataset, there is no ground-truth attribute labels for us to monitor the intra-class distribution and construct the GLT benchmark. However, attribute-wise long-tailed distribution, i.e., intra-class imbalance, commonly exists in all kinds of data. Therefore, we can apply the clustering algorithms (e.g. K-Means in our project) to create a set of clusters within each class as "pretext attributes". In other words, each cluster represents a meta attribute layout for this class, e.g., one cluster for cat category can be ginger cat in house, another cluster for cat category is black cat on street, etc.
9 |
10 | Note that our clustered attributes here are purely constructed based on intra-class feature distribution, so it stands for all kinds of factors causing the intra-class variance, including object-level attributes like textures, or image-level attributes like contexts.
11 |
12 | ## Steps for ImageNet-GLT Generation
13 | 1. Download the [ImageNet dataset](https://image-net.org/download.php)
14 | 2. Use a pre-trained model to extract features (our default choice is ii: A torchvision Pre-trained ResNet):
15 | 1. [MoCo Pre-trained Model](https://github.com/facebookresearch/moco), see [MoCoFeature.ipynb](1.MoCoFeature.ipynb)
16 | 2. Torchvision Pre-trained ResNet, see [SupervisedFeature.ipynb](1.SupervisedFeature.ipynb)
17 | 3. Using MMD-VAE with reconstruction loss to train a model and extract features, see [UnsupervisedFeature.ipynb](1.UnsupervisedFeature.ipynb)
18 | 3. Apply the K-Means algorithm to cluster images within each class into K (6 by default) pretext attributes based on extracted features
19 | 4. Construct training sets and testing sets for GLT, see [SplitGeneration.ipynb](2.SplitGeneration.ipynb). Note that we have three evaluation protocols: 1) Class-wise Long Tail with (Train-GLT, Test-CBL), 2) Attribute-wise Long Tail with (Train-CBL, Test-GBL), and 3) Generalized Long Tail with (Train-GLT, Test-GBL), where Class-wise Long Tail protocol and Generalized Long Tail protocol share the same training set Train-GLT, so there are two annotation files corresponding to two different training sets.
20 | 1. imagenet_sup_intra_lt_inter_lt.json:
21 | 1. Train-GLT: classes and attribute clusters are both long-tailed
22 | 2. Test-GLT: same as above
23 | 3. Test-CBL: classes are balanced but the attribute clusters are still i.i.d. sampled within each class, i.e., long-tailed intra-class distribution
24 | 4. Test-GBL: both classes and pretext attributes are balanced
25 | 2. imagenet_sup_intra_lt_inter_bl.json:
26 | 1. Train-CBL: classes are balanced, but the attribute clusters are still i.i.d. sampled within each class, i.e., long-tailed intra-class distribution
27 | 2. Test-CBL: same as above
28 | 3. Test-GBL: both classes and pretext attributes are balanced
29 |
30 |
31 | ## Download Our Generated Annotation files
32 | You can directly download our generated json files from [the link](https://1drv.ms/u/s!AmRLLNf6bzcitKlpBoTerorv5yaeIw?e=0bw83U).
33 |
34 | Since OneDrive links might be broken in mainland China, we also provide the following alternate link using BaiduNetDisk:
35 |
36 | Link:https://pan.baidu.com/s/1VSpgMVfMFnhQ0RTMzKNIyw Extraction code:1234
37 |
38 | ## Commands
39 | Run the following command to generate the ImageNet-GLT annotations: 1) imagenet_sup_intra_lt_inter_lt.json and 2) imagenet_sup_intra_lt_inter_bl.json
40 | ```
41 | python imagenet_glt_generation.py --data_path YOUR_IMAGENET_FOLDER_FOR_TRAINSET
42 | ```
43 |
44 | ## Visualization of Clusters and Dataset Statistics
45 |
46 | 
47 | Figure 2. Examples of feature clusters using KMeans (through t-SNE) and the corresponding visualized images for each cluster in the original ImageNet.
48 |
49 |
50 | 
51 | Figure 3. Class distributions and cluster distributions for each split of ImageNet-GLT benchmark. Note that clusters may represent different attribute layouts in each class, so ImageNet-GLT actually has 6 × 1000 pretext attributes rather than 6, i.e., each column in the cluster distribution stands for 1000 pretext attributes having the same frequency.
52 |
--------------------------------------------------------------------------------
/_ImageNetGeneration/long-tail-distribution.pytorch:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/KaihuaTang/Generalized-Long-Tailed-Benchmarks.pytorch/6317d8feb0ba107e1a64822567ed59115d51c581/_ImageNetGeneration/long-tail-distribution.pytorch
--------------------------------------------------------------------------------
/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/KaihuaTang/Generalized-Long-Tailed-Benchmarks.pytorch/6317d8feb0ba107e1a64822567ed59115d51c581/__init__.py
--------------------------------------------------------------------------------
/config/COCO_BL.yaml:
--------------------------------------------------------------------------------
1 | output_dir: null
2 |
3 | dataset:
4 | name: 'MSCOCO-BL'
5 | data_path: ./data/coco_bl/images
6 | anno_path: ./data/coco_bl/annotations/annotation.json
7 | # lt : test distribution is long-tailed in both categories and attributes
8 | # bl : the category is balanced
9 | # bbl : both category and attribute are balanced
10 | testset: 'test_lt' # test_lt / test_bl / test_bbl
11 | rgb_mean: [0.473, 0.429, 0.370]
12 | rgb_std: [0.277, 0.268, 0.274]
13 | rand_aug: False
14 |
15 | sampler: default
16 |
17 | networks:
18 | type: ResNext
19 | ResNext:
20 | def_file: ./models/ResNet.py
21 | params: {m_type: 'resnext50'}
22 | BBN:
23 | def_file: ./models/ResNet_BBN.py
24 | params: {m_type: 'bbn_resnet50'}
25 | RIDE:
26 | def_file: ./models/ResNet_RIDE.py
27 | params: {m_type: 'resnext50', num_experts: 3, reduce_dimension: True}
28 |
29 | classifiers:
30 | type: FC
31 | FC:
32 | def_file: ./models/ClassifierFC.py
33 | params: {feat_dim: 2048, num_classes: 29}
34 | RIDE:
35 | def_file: ./models/ClassifierRIDE.py
36 | params: {feat_dim: 1536, num_classes: 29, num_experts: 3, use_norm: True}
37 | BBN:
38 | def_file: ./models/ClassifierFC.py
39 | params: {feat_dim: 4096, num_classes: 29}
40 | COS:
41 | def_file: ./models/ClassifierCOS.py
42 | params: {feat_dim: 2048, num_classes: 29, num_head: 1, tau: 30.0}
43 | LDAM:
44 | def_file: ./models/ClassifierLDAM.py
45 | params: {feat_dim: 2048, num_classes: 29}
46 | TDE:
47 | def_file: ./models/ClassifierTDE.py
48 | params: {feat_dim: 2048, num_classes: 29, num_head: 2, tau: 16.0, alpha: 0.0, gamma: 0.03125}
49 | LA:
50 | def_file: ./models/ClassifierLA.py
51 | params: {feat_dim: 2048, num_classes: 29, posthoc: True, loss: False}
52 | LWS:
53 | def_file: ./models/ClassifierLWS.py
54 | params: {feat_dim: 2048, num_classes: 29}
55 | MultiHead:
56 | def_file: ./models/ClassifierMultiHead.py
57 | params: {feat_dim: 2048, num_classes: 29}
58 |
59 | training_opt:
60 | type: baseline # baseline / mixup / two_stage2
61 | num_epochs: 100
62 | batch_size: 256
63 | data_workers: 4
64 | loss: 'CrossEntropy' # CrossEntropy / Focal / BalancedSoftmax / LDAM
65 | loss_params: {alpha: 1.0, gamma: 2.0}
66 | optimizer: 'SGD' # 'Adam' / 'SGD'
67 | optim_params: {lr: 0.1, momentum: 0.9, weight_decay: 0.0005}
68 | scheduler: 'cosine' # 'cosine' / 'step' / 'multistep'
69 | scheduler_params: {endlr: 0.0, gamma: 0.1, step_size: 35, milestones: [120, 160]}
70 |
71 | testing_opt:
72 | type: baseline # baseline / TDE
73 |
74 | logger_opt:
75 | print_grad: false
76 | print_iter: 100
77 |
78 | checkpoint_opt:
79 | checkpoint_step: 10
80 | checkpoint_name: 'test_model.pth'
81 |
82 | saving_opt:
83 | save_all: false
--------------------------------------------------------------------------------
/config/COCO_LT.yaml:
--------------------------------------------------------------------------------
1 | output_dir: null
2 |
3 | dataset:
4 | name: 'MSCOCO-LT'
5 | data_path: ./data/coco_lt/images
6 | anno_path: ./data/coco_lt/annotations/annotation.json
7 | # lt : test distribution is long-tailed in both categories and attributes
8 | # bl : the category is balanced
9 | # bbl : both category and attribute are balanced
10 | testset: 'test_lt' # test_lt / test_bl / test_bbl
11 | rgb_mean: [0.473, 0.429, 0.370]
12 | rgb_std: [0.277, 0.268, 0.274]
13 | rand_aug: False
14 |
15 | sampler: default
16 |
17 | networks:
18 | type: ResNext
19 | ResNext:
20 | def_file: ./models/ResNet.py
21 | params: {m_type: 'resnext50'}
22 | BBN:
23 | def_file: ./models/ResNet_BBN.py
24 | params: {m_type: 'bbn_resnet50'}
25 | RIDE:
26 | def_file: ./models/ResNet_RIDE.py
27 | params: {m_type: 'resnext50', num_experts: 3, reduce_dimension: True}
28 |
29 | classifiers:
30 | type: FC
31 | FC:
32 | def_file: ./models/ClassifierFC.py
33 | params: {feat_dim: 2048, num_classes: 29}
34 | RIDE:
35 | def_file: ./models/ClassifierRIDE.py
36 | params: {feat_dim: 1536, num_classes: 29, num_experts: 3, use_norm: True}
37 | BBN:
38 | def_file: ./models/ClassifierFC.py
39 | params: {feat_dim: 4096, num_classes: 29}
40 | COS:
41 | def_file: ./models/ClassifierCOS.py
42 | params: {feat_dim: 2048, num_classes: 29, num_head: 1, tau: 30.0}
43 | LDAM:
44 | def_file: ./models/ClassifierLDAM.py
45 | params: {feat_dim: 2048, num_classes: 29}
46 | TDE:
47 | def_file: ./models/ClassifierTDE.py
48 | params: {feat_dim: 2048, num_classes: 29, num_head: 2, tau: 16.0, alpha: 1.0, gamma: 0.03125}
49 | LA:
50 | def_file: ./models/ClassifierLA.py
51 | params: {feat_dim: 2048, num_classes: 29, posthoc: True, loss: False}
52 | LWS:
53 | def_file: ./models/ClassifierLWS.py
54 | params: {feat_dim: 2048, num_classes: 29}
55 | MultiHead:
56 | def_file: ./models/ClassifierMultiHead.py
57 | params: {feat_dim: 2048, num_classes: 29}
58 |
59 | training_opt:
60 | type: baseline # baseline / mixup / two_stage2
61 | num_epochs: 100
62 | batch_size: 256
63 | data_workers: 4
64 | loss: 'CrossEntropy' # CrossEntropy / Focal / BalancedSoftmax / LDAM
65 | loss_params: {alpha: 1.0, gamma: 2.0}
66 | optimizer: 'SGD' # 'Adam' / 'SGD'
67 | optim_params: {lr: 0.1, momentum: 0.9, weight_decay: 0.0005}
68 | scheduler: 'cosine' # 'cosine' / 'step' / 'multistep'
69 | scheduler_params: {endlr: 0.0, gamma: 0.1, step_size: 35, milestones: [120, 160]}
70 |
71 | testing_opt:
72 | type: baseline # baseline / TDE
73 |
74 | logger_opt:
75 | print_grad: false
76 | print_iter: 100
77 |
78 | checkpoint_opt:
79 | checkpoint_step: 10
80 | checkpoint_name: 'test_model.pth'
81 |
82 | saving_opt:
83 | save_all: false
--------------------------------------------------------------------------------
/config/ColorMNIST_BL.yaml:
--------------------------------------------------------------------------------
1 | output_dir: null
2 |
3 | dataset:
4 | name: 'ColorMNIST-BL'
5 | data_path: './_ColorMNISTGeneration/'
6 | # lt : test distribution is long-tailed in both categories and attributes
7 | # bl : the category is balanced
8 | # bbl : both category and attribute are balanced
9 | testset: 'test_lt' # test_lt / test_bl / test_bbl
10 | cat_ratio: 1.0
11 | att_ratio: 0.1
12 | rand_aug: False
13 |
14 | sampler: default
15 |
16 | networks:
17 | type: ResNext
18 | ResNext:
19 | def_file: ./models/ResNet.py
20 | params: {m_type: 'resnext50'}
21 | BBN:
22 | def_file: ./models/ResNet_BBN.py
23 | params: {m_type: 'bbn_resnet50'}
24 | RIDE:
25 | def_file: ./models/ResNet_RIDE.py
26 | params: {m_type: 'resnext50', num_experts: 3, reduce_dimension: True}
27 |
28 | classifiers:
29 | type: FC
30 | FC:
31 | def_file: ./models/ClassifierFC.py
32 | params: {feat_dim: 2048, num_classes: 3}
33 | RIDE:
34 | def_file: ./models/ClassifierRIDE.py
35 | params: {feat_dim: 1536, num_classes: 3, num_experts: 3, use_norm: True}
36 | BBN:
37 | def_file: ./models/ClassifierFC.py
38 | params: {feat_dim: 4096, num_classes: 3}
39 | COS:
40 | def_file: ./models/ClassifierCOS.py
41 | params: {feat_dim: 2048, num_classes: 3, num_head: 1, tau: 30.0}
42 | LDAM:
43 | def_file: ./models/ClassifierLDAM.py
44 | params: {feat_dim: 2048, num_classes: 3}
45 | TDE:
46 | def_file: ./models/ClassifierTDE.py
47 | params: {feat_dim: 2048, num_classes: 3, num_head: 2, tau: 16.0, alpha: 0.0, gamma: 0.03125}
48 | LA:
49 | def_file: ./models/ClassifierLA.py
50 | params: {feat_dim: 2048, num_classes: 3, posthoc: True, loss: False}
51 | LWS:
52 | def_file: ./models/ClassifierLWS.py
53 | params: {feat_dim: 2048, num_classes: 3}
54 | MultiHead:
55 | def_file: ./models/ClassifierMultiHead.py
56 | params: {feat_dim: 2048, num_classes: 3}
57 |
58 | training_opt:
59 | type: baseline # baseline / mixup / two_stage2
60 | num_epochs: 100
61 | batch_size: 256
62 | data_workers: 4
63 | loss: 'CrossEntropy' # CrossEntropy / Focal / BalancedSoftmax / LDAM
64 | loss_params: {alpha: 1.0, gamma: 2.0}
65 | optimizer: 'SGD' # 'Adam' / 'SGD'
66 | optim_params: {lr: 0.1, momentum: 0.9, weight_decay: 0.0005}
67 | scheduler: 'cosine' # 'cosine' / 'step' / 'multistep'
68 | scheduler_params: {endlr: 0.0, gamma: 0.1, step_size: 35, milestones: [120, 160]}
69 |
70 | testing_opt:
71 | type: baseline # baseline / TDE
72 |
73 | logger_opt:
74 | print_grad: false
75 | print_iter: 100
76 |
77 | checkpoint_opt:
78 | checkpoint_step: 10
79 | checkpoint_name: 'test_model.pth'
80 |
81 | saving_opt:
82 | save_all: false
--------------------------------------------------------------------------------
/config/ColorMNIST_LT.yaml:
--------------------------------------------------------------------------------
1 | output_dir: null
2 |
3 | dataset:
4 | name: 'ColorMNIST-LT'
5 | data_path: './_ColorMNISTGeneration/'
6 | # lt : test distribution is long-tailed in both categories and attributes
7 | # bl : the category is balanced
8 | # bbl : both category and attribute are balanced
9 | testset: 'test_lt' # test_lt / test_bl / test_bbl
10 | cat_ratio: 0.1
11 | att_ratio: 0.1
12 | rand_aug: False
13 |
14 | sampler: default
15 |
16 | networks:
17 | type: ResNext
18 | ResNext:
19 | def_file: ./models/ResNet.py
20 | params: {m_type: 'resnext50'}
21 | BBN:
22 | def_file: ./models/ResNet_BBN.py
23 | params: {m_type: 'bbn_resnet50'}
24 | RIDE:
25 | def_file: ./models/ResNet_RIDE.py
26 | params: {m_type: 'resnext50', num_experts: 3, reduce_dimension: True}
27 |
28 | classifiers:
29 | type: FC
30 | FC:
31 | def_file: ./models/ClassifierFC.py
32 | params: {feat_dim: 2048, num_classes: 3}
33 | RIDE:
34 | def_file: ./models/ClassifierRIDE.py
35 | params: {feat_dim: 1536, num_classes: 3, num_experts: 3, use_norm: True}
36 | BBN:
37 | def_file: ./models/ClassifierFC.py
38 | params: {feat_dim: 4096, num_classes: 3}
39 | COS:
40 | def_file: ./models/ClassifierCOS.py
41 | params: {feat_dim: 2048, num_classes: 3, num_head: 1, tau: 30.0}
42 | LDAM:
43 | def_file: ./models/ClassifierLDAM.py
44 | params: {feat_dim: 2048, num_classes: 3}
45 | TDE:
46 | def_file: ./models/ClassifierTDE.py
47 | params: {feat_dim: 2048, num_classes: 3, num_head: 2, tau: 16.0, alpha: 0.1, gamma: 0.03125}
48 | LA:
49 | def_file: ./models/ClassifierLA.py
50 | params: {feat_dim: 2048, num_classes: 3, posthoc: True, loss: False}
51 | LWS:
52 | def_file: ./models/ClassifierLWS.py
53 | params: {feat_dim: 2048, num_classes: 3}
54 | MultiHead:
55 | def_file: ./models/ClassifierMultiHead.py
56 | params: {feat_dim: 2048, num_classes: 3}
57 |
58 | training_opt:
59 | type: baseline # baseline / mixup / two_stage2
60 | num_epochs: 100
61 | batch_size: 256
62 | data_workers: 4
63 | loss: 'CrossEntropy' # CrossEntropy / Focal / BalancedSoftmax / LDAM
64 | loss_params: {alpha: 1.0, gamma: 2.0}
65 | optimizer: 'SGD' # 'Adam' / 'SGD'
66 | optim_params: {lr: 0.1, momentum: 0.9, weight_decay: 0.0005}
67 | scheduler: 'cosine' # 'cosine' / 'step' / 'multistep'
68 | scheduler_params: {endlr: 0.0, gamma: 0.1, step_size: 35, milestones: [120, 160]}
69 |
70 | testing_opt:
71 | type: baseline # baseline / TDE
72 |
73 | logger_opt:
74 | print_grad: false
75 | print_iter: 100
76 |
77 | checkpoint_opt:
78 | checkpoint_step: 10
79 | checkpoint_name: 'test_model.pth'
80 |
81 | saving_opt:
82 | save_all: false
--------------------------------------------------------------------------------
/config/ImageNet_BL.yaml:
--------------------------------------------------------------------------------
1 | output_dir: null
2 |
3 | dataset:
4 | name: 'ImageNet-BL'
5 | anno_path: ./_ImageNetGeneration/imagenet_sup_intra_lt_inter_bl.json
6 | # lt : test distribution is long-tailed in both categories and attributes
7 | # bl : the category is balanced
8 | # bbl : both category and attribute are balanced
9 | testset: 'test_lt' # test_lt / test_bl / test_bbl
10 | rgb_mean: [0.485, 0.456, 0.406]
11 | rgb_std: [0.229, 0.224, 0.225]
12 | rand_aug: False
13 |
14 | sampler: default
15 |
16 | networks:
17 | type: ResNext
18 | ResNext:
19 | def_file: ./models/ResNet.py
20 | params: {m_type: 'resnext50'}
21 | BBN:
22 | def_file: ./models/ResNet_BBN.py
23 | params: {m_type: 'bbn_resnet50'}
24 | RIDE:
25 | def_file: ./models/ResNet_RIDE.py
26 | params: {m_type: 'resnext50', num_experts: 3, reduce_dimension: True}
27 |
28 | classifiers:
29 | type: FC
30 | FC:
31 | def_file: ./models/ClassifierFC.py
32 | params: {feat_dim: 2048, num_classes: 1000}
33 | RIDE:
34 | def_file: ./models/ClassifierRIDE.py
35 | params: {feat_dim: 1536, num_classes: 1000, num_experts: 3, use_norm: True}
36 | BBN:
37 | def_file: ./models/ClassifierFC.py
38 | params: {feat_dim: 4096, num_classes: 1000}
39 | COS:
40 | def_file: ./models/ClassifierCOS.py
41 | params: {feat_dim: 2048, num_classes: 1000, num_head: 1, tau: 30.0}
42 | LDAM:
43 | def_file: ./models/ClassifierLDAM.py
44 | params: {feat_dim: 2048, num_classes: 1000}
45 | TDE:
46 | def_file: ./models/ClassifierTDE.py
47 | params: {feat_dim: 2048, num_classes: 1000, num_head: 2, tau: 16.0, alpha: 0.0, gamma: 0.03125}
48 | LA:
49 | def_file: ./models/ClassifierLA.py
50 | params: {feat_dim: 2048, num_classes: 1000, posthoc: True, loss: False}
51 | LWS:
52 | def_file: ./models/ClassifierLWS.py
53 | params: {feat_dim: 2048, num_classes: 1000}
54 | MultiHead:
55 | def_file: ./models/ClassifierMultiHead.py
56 | params: {feat_dim: 2048, num_classes: 1000}
57 |
58 | training_opt:
59 | type: baseline # baseline / mixup / two_stage2
60 | num_epochs: 100
61 | batch_size: 256
62 | data_workers: 4
63 | loss: 'CrossEntropy' # CrossEntropy / Focal / BalancedSoftmax / LDAM
64 | loss_params: {alpha: 1.0, gamma: 2.0}
65 | optimizer: 'SGD' # 'Adam' / 'SGD'
66 | optim_params: {lr: 0.1, momentum: 0.9, weight_decay: 0.0005}
67 | scheduler: 'cosine' # 'cosine' / 'step' / 'multistep'
68 | scheduler_params: {endlr: 0.0, gamma: 0.1, step_size: 35, milestones: [120, 160]}
69 |
70 | testing_opt:
71 | type: baseline # baseline / TDE
72 |
73 | logger_opt:
74 | print_grad: false
75 | print_iter: 100
76 |
77 | checkpoint_opt:
78 | checkpoint_step: 10
79 | checkpoint_name: 'test_model.pth'
80 |
81 | saving_opt:
82 | save_all: false
--------------------------------------------------------------------------------
/config/ImageNet_LT.yaml:
--------------------------------------------------------------------------------
1 | output_dir: null
2 |
3 | dataset:
4 | name: 'ImageNet-LT'
5 | anno_path: ./_ImageNetGeneration/imagenet_sup_intra_lt_inter_lt.json
6 | # lt : test distribution is long-tailed in both categories and attributes
7 | # bl : the category is balanced
8 | # bbl : both category and attribute are balanced
9 | testset: 'test_lt' # test_lt / test_bl / test_bbl
10 | rgb_mean: [0.485, 0.456, 0.406]
11 | rgb_std: [0.229, 0.224, 0.225]
12 | rand_aug: False
13 |
14 | sampler: default
15 |
16 | networks:
17 | type: ResNext
18 | ResNext:
19 | def_file: ./models/ResNet.py
20 | params: {m_type: 'resnext50'}
21 | BBN:
22 | def_file: ./models/ResNet_BBN.py
23 | params: {m_type: 'bbn_resnet50'}
24 | RIDE:
25 | def_file: ./models/ResNet_RIDE.py
26 | params: {m_type: 'resnext50', num_experts: 3, reduce_dimension: True}
27 |
28 | classifiers:
29 | type: FC
30 | FC:
31 | def_file: ./models/ClassifierFC.py
32 | params: {feat_dim: 2048, num_classes: 1000}
33 | RIDE:
34 | def_file: ./models/ClassifierRIDE.py
35 | params: {feat_dim: 1536, num_classes: 1000, num_experts: 3, use_norm: True}
36 | BBN:
37 | def_file: ./models/ClassifierFC.py
38 | params: {feat_dim: 4096, num_classes: 1000}
39 | COS:
40 | def_file: ./models/ClassifierCOS.py
41 | params: {feat_dim: 2048, num_classes: 1000, num_head: 1, tau: 30.0}
42 | LDAM:
43 | def_file: ./models/ClassifierLDAM.py
44 | params: {feat_dim: 2048, num_classes: 1000}
45 | TDE:
46 | def_file: ./models/ClassifierTDE.py
47 | params: {feat_dim: 2048, num_classes: 1000, num_head: 2, tau: 16.0, alpha: 2.0, gamma: 0.03125}
48 | LA:
49 | def_file: ./models/ClassifierLA.py
50 | params: {feat_dim: 2048, num_classes: 1000, posthoc: True, loss: False}
51 | LWS:
52 | def_file: ./models/ClassifierLWS.py
53 | params: {feat_dim: 2048, num_classes: 1000}
54 | MultiHead:
55 | def_file: ./models/ClassifierMultiHead.py
56 | params: {feat_dim: 2048, num_classes: 1000}
57 |
58 | training_opt:
59 | type: baseline # baseline / mixup / two_stage2
60 | num_epochs: 100
61 | batch_size: 256
62 | data_workers: 4
63 | loss: 'CrossEntropy' # CrossEntropy / Focal / BalancedSoftmax / LDAM
64 | loss_params: {alpha: 1.0, gamma: 2.0}
65 | optimizer: 'SGD' # 'Adam' / 'SGD'
66 | optim_params: {lr: 0.1, momentum: 0.9, weight_decay: 0.0005}
67 | scheduler: 'cosine' # 'cosine' / 'step' / 'multistep'
68 | scheduler_params: {endlr: 0.0, gamma: 0.1, step_size: 35, milestones: [120, 160]}
69 |
70 | testing_opt:
71 | type: baseline # baseline / TDE
72 |
73 | logger_opt:
74 | print_grad: false
75 | print_iter: 100
76 |
77 | checkpoint_opt:
78 | checkpoint_step: 10
79 | checkpoint_name: 'test_model.pth'
80 |
81 | saving_opt:
82 | save_all: false
--------------------------------------------------------------------------------
/config/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/KaihuaTang/Generalized-Long-Tailed-Benchmarks.pytorch/6317d8feb0ba107e1a64822567ed59115d51c581/config/__init__.py
--------------------------------------------------------------------------------
/data/DT_COCO_LT.py:
--------------------------------------------------------------------------------
1 | ######################################
2 | # Kaihua Tang
3 | ######################################
4 |
5 |
6 | import os
7 | import json
8 |
9 | import torch
10 | import torch.utils.data as data
11 | import torchvision.transforms as transforms
12 | from PIL import Image
13 |
14 | from randaugment import RandAugment
15 |
class COCO_LT(data.Dataset):
    """MSCOCO-GLT dataset split.

    Loads an annotation json mapping each split name to per-image
    'path' / 'label' / 'frequency' dictionaries (plus 'attribute' for
    non-train splits) and serves transformed image tuples.  For
    phase == 'test' the concrete split is chosen by `testset`
    (test_lt / test_bl / test_bbl per the config comments).
    """

    def __init__(self, phase, data_path, anno_path, testset, rgb_mean, rgb_std, rand_aug, output_path, logger):
        super(COCO_LT, self).__init__()
        valid_phase = ['train', 'val', 'test']
        assert phase in valid_phase
        if phase == 'train':
            full_phase = 'train'
        elif phase == 'test':
            # the actual test split name comes from the config
            full_phase = testset
        else:
            full_phase = phase
        logger.info('====== The Current Split is : {}'.format(full_phase))
        self.logger = logger

        self.dataset_info = {}
        self.phase = phase
        self.rand_aug = rand_aug
        self.data_path = data_path

        # use a context manager so the annotation file handle is not leaked
        # (the original `json.load(open(anno_path))` never closed the file)
        with open(anno_path) as f:
            self.annotations = json.load(f)
        self.data = self.annotations[full_phase]
        self.transform = self.get_data_transform(phase, rgb_mean, rgb_std)

        # load dataset category info
        logger.info('=====> Load dataset category info')
        self.id2cat, self.cat2id = self.annotations['id2cat'], self.annotations['cat2id']

        # load all image info
        logger.info('=====> Load image info')
        self.img_paths, self.labels, self.attributes, self.frequencies = self.load_img_info()

        # save dataset info
        logger.info('=====> Save dataset info')
        self.dataset_info['cat2id'] = self.cat2id
        self.dataset_info['id2cat'] = self.id2cat
        self.save_dataset_info(output_path)


    def __len__(self):
        return len(self.labels)


    def __getitem__(self, index):
        """Return (img, label, frequency, index), plus the attribute label
        between frequency and index for non-train phases.

        Intra-class attributes are deliberately withheld during training.
        """
        path = self.img_paths[index]
        label = self.labels[index]
        rarity = self.frequencies[index]

        with open(path, 'rb') as f:
            sample = Image.open(f).convert('RGB')

        if self.transform is not None:
            sample = self.transform(sample)

        # intra-class attribute SHOULD NOT be used during training
        if self.phase != 'train':
            attribute = self.attributes[index]
            return sample, label, rarity, attribute, index
        else:
            return sample, label, rarity, index


    #######################################
    # Load image info
    #######################################
    def load_img_info(self):
        """Collect parallel per-image lists: paths, int labels, frequencies,
        and (for non-train phases only) attribute labels.
        Also mirrors them into self.dataset_info for later saving."""
        img_paths = []
        labels = []
        attributes = []
        frequencies = []

        for key, label in self.data['label'].items():
            img_paths.append(self.data['path'][key])
            labels.append(int(label))
            frequencies.append(int(self.data['frequency'][key]))

            # intra-class attribute SHOULD NOT be used in training
            if self.phase != 'train':
                att_label = int(self.data['attribute'][key])
                attributes.append(att_label)

        # save dataset info
        self.dataset_info['img_paths'] = img_paths
        self.dataset_info['labels'] = labels
        self.dataset_info['attributes'] = attributes
        self.dataset_info['frequencies'] = frequencies

        return img_paths, labels, attributes, frequencies


    #######################################
    # Save dataset info
    #######################################
    def save_dataset_info(self, output_path):
        """Dump the collected dataset info to json, then free the in-memory copy."""
        with open(os.path.join(output_path, 'dataset_info_{}.json'.format(self.phase)), 'w') as f:
            json.dump(self.dataset_info, f)

        # released after saving; no method should touch dataset_info afterwards
        del self.dataset_info


    #######################################
    # transform
    #######################################
    def get_data_transform(self, phase, rgb_mean, rgb_std):
        """Build the torchvision transform for this phase and record a
        human-readable description of it in self.dataset_info."""
        transform_info = {
            'rgb_mean': rgb_mean,
            'rgb_std': rgb_std,
        }

        if phase == 'train':
            if self.rand_aug:
                self.logger.info('============= Using Rand Augmentation in Dataset ===========')
                trans = transforms.Compose([
                    transforms.RandomResizedCrop(112, scale=(0.5, 1.0)),
                    transforms.RandomHorizontalFlip(),
                    RandAugment(),
                    transforms.ToTensor(),
                    transforms.Normalize(rgb_mean, rgb_std)
                ])
                transform_info['operations'] = ['RandomResizedCrop(112, scale=(0.5, 1.0)),', 'RandomHorizontalFlip()',
                                                'RandAugment()', 'ToTensor()', 'Normalize(rgb_mean, rgb_std)']
            else:
                self.logger.info('============= Using normal transforms in Dataset ===========')
                trans = transforms.Compose([
                    transforms.RandomResizedCrop(112, scale=(0.5, 1.0)),
                    transforms.RandomHorizontalFlip(),
                    transforms.ToTensor(),
                    transforms.Normalize(rgb_mean, rgb_std)
                ])
                transform_info['operations'] = ['RandomResizedCrop(112, scale=(0.5, 1.0)),', 'RandomHorizontalFlip()',
                                                'ToTensor()', 'Normalize(rgb_mean, rgb_std)']
        else:
            trans = transforms.Compose([
                transforms.Resize(128),
                transforms.CenterCrop(112),
                transforms.ToTensor(),
                transforms.Normalize(rgb_mean, rgb_std)
            ])
            transform_info['operations'] = ['Resize(128)', 'CenterCrop(112)', 'ToTensor()', 'Normalize(rgb_mean, rgb_std)']

        # save dataset info
        self.dataset_info['transform_info'] = transform_info

        return trans
--------------------------------------------------------------------------------
/data/DT_ColorMNIST.py:
--------------------------------------------------------------------------------
1 | ######################################
2 | # Kaihua Tang
3 | ######################################
4 |
5 |
6 | from PIL import Image, ImageDraw
7 | from io import BytesIO
8 | import json
9 | import os
10 | import random
11 |
12 | import numpy as np
13 | import matplotlib.pyplot as plt
14 |
15 | import torch
16 | import torchvision
17 | import torch.nn as nn
18 | import torch.optim as optim
19 | import torch.utils.data as data
20 | import torchvision.transforms as transforms
21 |
22 | from randaugment import RandAugment
23 |
24 |
class ColorMNIST_LT(torchvision.datasets.MNIST):
    """Long-tailed ColorMNIST.

    The ten digits are merged into 3 super-classes (0-3 / 4-6 / 7-9), each
    with a signature color (R/G/B).  With probability `att_ratio` an image
    is tinted a random color instead, so color acts as an intra-class
    attribute.  `cat_ratio` controls the geometric decay of class sizes.
    """

    def __init__(self, phase, testset, data_path, logger, cat_ratio=1.0, att_ratio=0.1, rand_aug=False):
        super(ColorMNIST_LT, self).__init__(root=data_path, train=(phase == 'train'), download=True)
        # mnist dataset contains self.data, self.targets
        self.dig2label = {0: 0, 1: 0, 2: 0, 3: 0, 4: 1, 5: 1, 6: 1, 7: 2, 8: 2, 9: 2}
        self.dig2attri = {}
        self.colors = {0:[1,0,0], 1:[0,1,0], 2:[0,0,1]}
        self.logger = logger
        self.phase = phase

        # ColorMNIST should not use rand augmentation, as its patterns are too simple
        # and its color confounder could be over-written by augmentation
        assert rand_aug == False

        valid_phase = ['train', 'val', 'test']
        assert phase in valid_phase
        if phase == 'train':
            full_phase = 'train'
            # generate long-tailed data
            self.cat_ratio = cat_ratio
            self.att_ratio = att_ratio
            self.generate_lt_label(cat_ratio)
        elif phase == 'test':
            full_phase = testset
            # generate long-tailed data
            if full_phase == 'test_iid':
                self.cat_ratio = cat_ratio
                self.att_ratio = att_ratio
                self.generate_lt_label(cat_ratio)
            elif full_phase == 'test_half_bl':
                self.cat_ratio = 1.0
                self.att_ratio = att_ratio
                self.generate_lt_label(1.0)
            elif full_phase == 'test_bl':
                self.cat_ratio = 1.0
                self.att_ratio = 1.0
                self.generate_lt_label(1.0)
            else:
                # previously an unknown testset fell through silently and only
                # crashed later in __getitem__ (att_ratio never set); fail fast
                raise ValueError('Unknown ColorMNIST testset: {}'.format(testset))
        else:
            full_phase = phase
            self.cat_ratio = cat_ratio
            self.att_ratio = att_ratio
            self.generate_lt_label(cat_ratio)
        logger.info('====== The Current Split is : {}'.format(full_phase))



    def generate_lt_label(self, ratio=1.0):
        """Bucket images by super-class, then truncate class sizes.

        ratio == 1.0 balances every class to the smallest class size;
        ratio < 1.0 shrinks each successive class by `ratio` (long tail).
        Populates self.labels and self.imgs as parallel lists.
        """
        self.label2list = {i:[] for i in range(3)}
        for img, dig in zip(self.data, self.targets):
            label = self.dig2label[int(dig)]
            self.label2list[label].append(img)
        if ratio == 1.0:
            balance_size = min([len(val) for key, val in self.label2list.items()])
            for key, val in self.label2list.items():
                self.label2list[key] = val[:balance_size]
        elif ratio < 1.0:
            current_size = len(self.label2list[0])
            for key, val in self.label2list.items():
                max_size = len(val)
                self.label2list[key] = val[:min(max_size, current_size)]
                current_size = int(current_size * ratio)
        else:
            raise ValueError('Wrong Ratio in ColorMNIST')

        self.labels = []
        self.imgs = []
        for key, val in self.label2list.items():
            for item in val:
                self.labels.append(key)
                self.imgs.append(item)
            self.logger.info('Generate ColorMNIST: label {} has {} images.'.format(key, len(val)))


    def __len__(self):
        return len(self.labels)

    def __getitem__(self, index):
        """Return (img, label, label, index); non-train phases additionally
        return an attribute flag (1 = color deviates from the class color)
        before the index."""
        img = self.imgs[index].unsqueeze(0).repeat(3,1,1)
        label = self.labels[index]

        # generate tail colors
        if random.random() < self.att_ratio:
            att_label = random.randint(0,2)
        else:
            att_label = label
        color = self.colors[att_label]

        # assign attribute
        img = self.to_color(img, color)

        if self.phase != 'train':
            # attribute: 0 when the sample keeps its class signature color
            attribute = 1 - int(att_label == label)
            return img, label, label, attribute, index
        else:
            return img, label, label, index

    def to_color(self, img, rgb=[1,0,0]):
        # scale each channel by the rgb mask; the mutable default is never
        # modified, only read, so sharing it across calls is safe
        return (img * torch.FloatTensor(rgb).unsqueeze(-1).unsqueeze(-1)).float()
--------------------------------------------------------------------------------
/data/DT_ImageNet_LT.py:
--------------------------------------------------------------------------------
1 | ######################################
2 | # Kaihua Tang
3 | ######################################
4 |
5 |
6 | import os
7 | import json
8 |
9 | import torch
10 | import torch.utils.data as data
11 | import torchvision.transforms as transforms
12 | from PIL import Image
13 |
14 | from randaugment import RandAugment
15 |
class ImageNet_LT(data.Dataset):
    """ImageNet-GLT dataset split.

    Loads an annotation json mapping each split name to per-image-path
    'label' / 'frequency' dictionaries (plus 'attribute' for non-train
    splits) and serves transformed image tuples.  For phase == 'test'
    the concrete split is chosen by `testset` (test_lt / test_bl / test_bbl).
    """

    def __init__(self, phase, anno_path, testset, rgb_mean, rgb_std, rand_aug, output_path, logger):
        super(ImageNet_LT, self).__init__()
        valid_phase = ['train', 'val', 'test']
        assert phase in valid_phase
        if phase == 'train':
            full_phase = 'train'
        elif phase == 'test':
            # the actual test split name comes from the config
            full_phase = testset
        else:
            full_phase = phase
        logger.info('====== The Current Split is : {}'.format(full_phase))
        self.logger = logger

        self.dataset_info = {}
        self.phase = phase
        self.rand_aug = rand_aug

        # load annotation; use a context manager so the file handle is not
        # leaked (the original `json.load(open(anno_path))` never closed it)
        with open(anno_path) as f:
            self.annotations = json.load(f)
        self.data = self.annotations[full_phase]

        # get transform
        self.transform = self.get_data_transform(phase, rgb_mean, rgb_std)

        # load dataset category info
        logger.info('=====> Load dataset category info')
        self.id2cat, self.cat2id = self.annotations['id2cat'], self.annotations['cat2id']

        # load all image info
        logger.info('=====> Load image info')
        self.img_paths, self.labels, self.attributes, self.frequencies = self.load_img_info()

        # save dataset info
        logger.info('=====> Save dataset info')
        self.dataset_info['cat2id'] = self.cat2id
        self.dataset_info['id2cat'] = self.id2cat
        self.save_dataset_info(output_path)


    def __len__(self):
        return len(self.labels)


    def __getitem__(self, index):
        """Return (img, label, frequency, index), plus the attribute label
        between frequency and index for non-train phases.

        Intra-class attributes are deliberately withheld during training.
        """
        path = self.img_paths[index]
        label = self.labels[index]
        rarity = self.frequencies[index]

        with open(path, 'rb') as f:
            sample = Image.open(f).convert('RGB')

        if self.transform is not None:
            sample = self.transform(sample)

        # intra-class attribute SHOULD NOT be used during training
        if self.phase != 'train':
            attribute = self.attributes[index]
            return sample, label, rarity, attribute, index
        else:
            return sample, label, rarity, index


    #######################################
    # Load image info
    #######################################
    def load_img_info(self):
        """Collect parallel per-image lists keyed by image path: paths,
        int labels, frequencies, and (non-train phases only) attributes.
        Also mirrors them into self.dataset_info for later saving."""
        img_paths = []
        labels = []
        attributes = []
        frequencies = []


        for path, label in self.data['label'].items():
            img_paths.append(path)
            labels.append(int(label))
            frequencies.append(int(self.data['frequency'][path]))

            # intra-class attribute SHOULD NOT be used in training
            if self.phase != 'train':
                att_label = int(self.data['attribute'][path])
                attributes.append(att_label)

        # save dataset info
        self.dataset_info['img_paths'] = img_paths
        self.dataset_info['labels'] = labels
        self.dataset_info['attributes'] = attributes
        self.dataset_info['frequencies'] = frequencies

        return img_paths, labels, attributes, frequencies


    #######################################
    # Save dataset info
    #######################################
    def save_dataset_info(self, output_path):
        """Dump the collected dataset info to json, then free the in-memory copy."""
        with open(os.path.join(output_path, 'dataset_info_{}.json'.format(self.phase)), 'w') as f:
            json.dump(self.dataset_info, f)

        # released after saving; no method should touch dataset_info afterwards
        del self.dataset_info


    #######################################
    # transform
    #######################################
    def get_data_transform(self, phase, rgb_mean, rgb_std):
        """Build the torchvision transform for this phase and record a
        human-readable description of it in self.dataset_info."""
        transform_info = {
            'rgb_mean': rgb_mean,
            'rgb_std': rgb_std,
        }

        if phase == 'train':
            if self.rand_aug:
                self.logger.info('============= Using Rand Augmentation in Dataset ===========')
                trans = transforms.Compose([
                    transforms.RandomResizedCrop(112),
                    transforms.RandomHorizontalFlip(),
                    RandAugment(),
                    transforms.ToTensor(),
                    transforms.Normalize(rgb_mean, rgb_std)
                ])
                transform_info['operations'] = ['RandomResizedCrop(112)', 'RandomHorizontalFlip()',
                                                'RandAugment()', 'ToTensor()', 'Normalize(rgb_mean, rgb_std)']
            else:
                self.logger.info('============= Using normal transforms in Dataset ===========')
                trans = transforms.Compose([
                    transforms.RandomResizedCrop(112),
                    transforms.RandomHorizontalFlip(),
                    transforms.ToTensor(),
                    transforms.Normalize(rgb_mean, rgb_std)
                ])
                transform_info['operations'] = ['RandomResizedCrop(112)', 'RandomHorizontalFlip()',
                                                'ToTensor()', 'Normalize(rgb_mean, rgb_std)']
        else:
            trans = transforms.Compose([
                transforms.Resize(128),
                transforms.CenterCrop(112),
                transforms.ToTensor(),
                transforms.Normalize(rgb_mean, rgb_std)
            ])
            transform_info['operations'] = ['Resize(128)', 'CenterCrop(112)', 'ToTensor()', 'Normalize(rgb_mean, rgb_std)']

        # save dataset info
        self.dataset_info['transform_info'] = transform_info

        return trans
--------------------------------------------------------------------------------
/data/Sampler_ClassAware.py:
--------------------------------------------------------------------------------
1 |
2 | import random
3 | import numpy as np
4 | from torch.utils.data.sampler import Sampler
5 |
6 |
7 |
8 | ##################################
9 | ## Class-aware sampling, partly implemented by frombeijingwithlove
10 | ## github: https://github.com/facebookresearch/classifier-balancing/blob/main/data/ClassAwareSampler.py
11 | ##################################
12 |
class RandomCycleIter:
    """Endless iterator that cycles over `data` in random order.

    The pool is shuffled once at construction and again at the end of
    every full pass — unless test_mode is set, in which case the initial
    shuffled order is replayed forever.
    """

    def __init__(self, data, test_mode=False):
        self.data_list = list(data)
        random.shuffle(self.data_list)
        self.length = len(self.data_list)
        # start one step before the front so the first __next__ yields index 0
        self.i = self.length - 1
        self.test_mode = test_mode

    def __iter__(self):
        return self

    def __next__(self):
        self.i += 1
        if self.i >= self.length:
            # wrapped around: restart, reshuffling unless in test mode
            self.i = 0
            if not self.test_mode:
                random.shuffle(self.data_list)
        return self.data_list[self.i]
34 |
def class_aware_sample_generator(cls_iter, data_iter_list, n, num_samples_cls=1):
    """Yield n sample indices, drawn num_samples_cls at a time.

    Each group first picks a class from cls_iter, then draws
    num_samples_cls consecutive indices from that class's iterator.
    A full group is always drawn from the iterator even when fewer than
    num_samples_cls yields remain (matching the original behavior).
    """
    emitted = 0
    while emitted < n:
        source = data_iter_list[next(cls_iter)]
        group = [next(source) for _ in range(num_samples_cls)]
        for sample in group:
            if emitted == n:
                return
            yield sample
            emitted += 1
54 |
class ClassAwareSampler(Sampler):
    """Re-balancing sampler: classes are drawn uniformly, then indices
    are drawn from within the chosen class.

    Epoch length is largest-class-size * num-classes, capped at
    max_aug times the dataset size.
    """

    def __init__(self, data_source, num_samples_cls=1, max_aug=4):
        labels = data_source.labels
        num_images = len(labels)
        num_classes = len(np.unique(labels))
        self.class_iter = RandomCycleIter(range(num_classes))
        # bucket sample indices by class label
        per_class_indices = [[] for _ in range(num_classes)]
        for idx, lab in enumerate(labels):
            per_class_indices[lab].append(idx)
        self.data_iter_list = [RandomCycleIter(bucket) for bucket in per_class_indices]
        largest = max(len(bucket) for bucket in per_class_indices)
        # cap the oversampled epoch at max_aug passes over the data
        self.num_samples = min(largest * len(per_class_indices), num_images * max_aug)
        self.num_samples_cls = num_samples_cls

    def __iter__(self):
        return class_aware_sample_generator(self.class_iter, self.data_iter_list,
                                            self.num_samples, self.num_samples_cls)

    def __len__(self):
        return self.num_samples
--------------------------------------------------------------------------------
/data/Sampler_MultiEnv.py:
--------------------------------------------------------------------------------
1 | import random
2 | import torch
3 | import numpy as np
4 | from torch.utils.data.sampler import Sampler
5 |
6 |
class WeightedSampler(Sampler):
    """Hard weighted sampling: an index with weight w contributes floor(w)
    copies to the candidate pool; the pool is shuffled and then
    repeated/trimmed so each epoch has exactly len(dataset) indices.

    NOTE: iterating decrements self.weight in place, so set_parameter()
    is expected before each epoch.
    """

    def __init__(self, dataset):
        self.num_samples = len(dataset)
        self.indexes = torch.arange(self.num_samples)
        self.weight = torch.zeros_like(self.indexes).fill_(1.0).float()  # init weight

    def __iter__(self):
        # MAKE SURE self.weight.sum() == self.num_samples
        pool = []
        while (self.weight >= 1.0).sum().item() > 0:
            pool += self.indexes[self.weight >= 1.0].tolist()
            self.weight = self.weight - 1.0
        pool = torch.LongTensor(pool)
        # shuffle the pool, then map num_samples slots onto it cyclically
        pool_size = pool.shape[0]
        pool = pool[torch.randperm(pool_size)]
        slots = torch.randperm(self.num_samples) % pool_size
        chosen = pool[slots].tolist()

        assert len(chosen) == self.num_samples
        return iter(chosen)

    def __len__(self):
        return self.num_samples

    def set_parameter(self, weight):
        # per-sample sampling weights (float tensor, one entry per index)
        self.weight = weight.float()
36 |
37 |
class DistributionSampler(Sampler):
    """Soft weighted sampling: indices are drawn i.i.d. (with replacement)
    from the distribution given by the per-sample weights."""

    def __init__(self, dataset):
        self.num_samples = len(dataset)
        self.indexes = torch.arange(self.num_samples)
        # uniform weights until set_parameter() is called
        self.weight = torch.zeros_like(self.indexes).fill_(1.0).float()

    def __iter__(self):
        # normalize weights into a probability vector for multinomial draws
        self.prob = self.weight / self.weight.sum()

        drawn = torch.multinomial(self.prob, self.num_samples, replacement=True).tolist()
        assert len(drawn) == self.num_samples
        return iter(drawn)

    def __len__(self):
        return self.num_samples

    def set_parameter(self, weight):
        # unnormalized per-sample weights (float tensor, one entry per index)
        self.weight = weight.float()
57 |
58 |
class FixSeedSampler(Sampler):
    """Deterministic shuffling: the permutation is a pure function of the
    epoch number, so all replicas sharing an epoch see the same order."""

    def __init__(self, dataset):
        self.dataset = dataset
        self.num_samples = len(dataset)

    def __iter__(self):
        # deterministically shuffle based on epoch
        gen = torch.Generator()
        gen.manual_seed(self.epoch)
        order = torch.randperm(len(self.dataset), generator=gen).tolist()
        assert len(order) == self.num_samples
        return iter(order)

    def __len__(self):
        return self.num_samples

    def set_parameter(self, epoch):
        # must be called before iterating; seeds the per-epoch permutation
        self.epoch = epoch
77 |
78 |
--------------------------------------------------------------------------------
/data/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/KaihuaTang/Generalized-Long-Tailed-Benchmarks.pytorch/6317d8feb0ba107e1a64822567ed59115d51c581/data/__init__.py
--------------------------------------------------------------------------------
/data/dataloader.py:
--------------------------------------------------------------------------------
1 | ######################################
2 | # Kaihua Tang
3 | ######################################
4 | import math
5 | import random
6 | import numpy as np
7 | import torch
8 | import torch.utils.data as data
9 | import torch.distributed as dist
10 | from torch.utils.data.sampler import Sampler
11 |
12 |
13 | from .DT_COCO_LT import COCO_LT
14 | from .DT_ColorMNIST import ColorMNIST_LT
15 | from .DT_ImageNet_LT import ImageNet_LT
16 |
17 | from .Sampler_ClassAware import ClassAwareSampler
18 | from .Sampler_MultiEnv import WeightedSampler, DistributionSampler, FixSeedSampler
19 |
20 | ##################################
21 | # return a dataloader
22 | ##################################
def _build_multi_sampler_loaders(config, split, sampler_factory):
    """Build one DataLoader per sampler environment.

    Shared by the WeightedSampler / DistributionSampler / FixSeedSampler
    branches, which differ only in the sampler class they instantiate.
    When config['batch_split'] is set, the total batch size is divided
    evenly across the environments.
    """
    loaders = []
    num_sampler = config['num_sampler']
    batch_size = config['training_opt']['batch_size']
    if config['batch_split']:
        batch_size = batch_size // num_sampler
    for _ in range(num_sampler):
        loaders.append(data.DataLoader(split, num_workers=config['training_opt']['data_workers'],
                                       batch_size=batch_size,
                                       sampler=sampler_factory(split),
                                       pin_memory=True,))
    return loaders


##################################
# return a dataloader
##################################
def get_loader(config, phase, testset, logger):
    """Construct the dataset split for `phase` and wrap it in DataLoader(s).

    Returns a single DataLoader for the default / ClassAwareSampler cases,
    or a list of DataLoaders (one per environment) for the multi-sampler
    types (Weighted / Distribution / FixSeed).

    Raises:
        ValueError: if config['dataset']['name'] is not a known dataset.
    """
    dataset_name = config['dataset']['name']
    if dataset_name in ('MSCOCO-LT', 'MSCOCO-BL'):
        split = COCO_LT(phase=phase,
                        data_path=config['dataset']['data_path'],
                        anno_path=config['dataset']['anno_path'],
                        testset=testset,
                        rgb_mean=config['dataset']['rgb_mean'],
                        rgb_std=config['dataset']['rgb_std'],
                        rand_aug=config['dataset']['rand_aug'],
                        output_path=config['output_dir'],
                        logger=logger)
    elif dataset_name in ('ColorMNIST-LT', 'ColorMNIST-BL'):
        split = ColorMNIST_LT(phase=phase,
                              testset=testset,
                              data_path=config['dataset']['data_path'],
                              cat_ratio=config['dataset']['cat_ratio'],
                              att_ratio=config['dataset']['att_ratio'],
                              rand_aug=config['dataset']['rand_aug'],
                              logger=logger)
    elif dataset_name in ('ImageNet-LT', 'ImageNet-BL'):
        split = ImageNet_LT(phase=phase,
                            anno_path=config['dataset']['anno_path'],
                            testset=testset,
                            rgb_mean=config['dataset']['rgb_mean'],
                            rgb_std=config['dataset']['rgb_std'],
                            rand_aug=config['dataset']['rand_aug'],
                            output_path=config['output_dir'],
                            logger=logger)
    else:
        # previously this branch only logged, and the function then crashed
        # with an UnboundLocalError on `split`; fail fast with a clear error
        logger.info('********** ERROR: unidentified dataset **********')
        raise ValueError('Unidentified dataset: {}'.format(dataset_name))

    # create data sampler
    sampler_type = config['sampler']

    # class aware sampling (re-balancing)
    if sampler_type == 'ClassAwareSampler' and phase == 'train':
        logger.info('======> Sampler Type {}'.format(sampler_type))
        sampler = ClassAwareSampler(split, num_samples_cls=4)
        loader = data.DataLoader(split, num_workers=config['training_opt']['data_workers'],
                                 batch_size=config['training_opt']['batch_size'],
                                 sampler=sampler,
                                 pin_memory=True,)
    # hard weighted sampling (don't sample items with weights smaller than 1.0)
    elif sampler_type == 'WeightedSampler' and phase == 'train':
        logger.info('======> Sampler Type {}, Sampler Number {}'.format(sampler_type, config['num_sampler']))
        loader = _build_multi_sampler_loaders(config, split, WeightedSampler)
    # soft weighted sampling (sampling samples by the provided weights)
    elif sampler_type == 'DistributionSampler' and phase == 'train':
        logger.info('======> Sampler Type {}, Sampler Number {}'.format(sampler_type, config['num_sampler']))
        loader = _build_multi_sampler_loaders(config, split, DistributionSampler)
    # Random Sampling with given seed
    elif sampler_type == 'FixSeedSampler' and phase == 'train':
        logger.info('======> Sampler Type {}, Sampler Number {}'.format(sampler_type, config['num_sampler']))
        loader = _build_multi_sampler_loaders(config, split, FixSeedSampler)
    else:
        logger.info('======> Sampler Type Naive Sampling with shuffle type: {}'.format(True if phase == 'train' else False))
        loader = data.DataLoader(split, num_workers=config['training_opt']['data_workers'],
                                 batch_size=config['training_opt']['batch_size'],
                                 shuffle=True if phase == 'train' else False,
                                 pin_memory=True,)

    return loader
115 |
116 |
117 |
118 |
119 |
120 |
121 |
122 |
123 |
124 |
--------------------------------------------------------------------------------
/deprecated/_ColorMNISTGeneration/1.DataGeneration.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": 1,
6 | "id": "5a5db594",
7 | "metadata": {},
8 | "outputs": [],
9 | "source": [
10 | "from PIL import Image, ImageDraw\n",
11 | "from io import BytesIO\n",
12 | "import json\n",
13 | "import joblib\n",
14 | "import os\n",
15 | "import requests\n",
16 | "import random\n",
17 | "\n",
18 | "import cv2\n",
19 | "import numpy as np\n",
20 | "import matplotlib.pyplot as plt\n",
21 | "from sklearn.cluster import KMeans\n",
22 | "from skimage import feature as skif\n",
23 | "\n",
24 | "import torch\n",
25 | "import torchvision\n",
26 | "import torch.nn as nn\n",
27 | "import torch.optim as optim\n",
28 | "import torch.utils.data as data\n",
29 | "import torchvision.transforms as transforms\n",
30 | "from torch.utils.data import Dataset, DataLoader, ConcatDataset\n",
31 | "\n",
32 | "random.seed(25)"
33 | ]
34 | },
35 | {
36 | "cell_type": "code",
37 | "execution_count": 132,
38 | "id": "74f34318",
39 | "metadata": {},
40 | "outputs": [],
41 | "source": [
42 | "class ColorMNIST_LT(torchvision.datasets.MNIST):\n",
43 | " def __init__(self, phase, test_type, output_path, logger, cat_ratio=1.0, att_ratio=0.1):\n",
44 | " super(ColorMNIST_LT, self).__init__(root='./', train=(phase == 'train'), download=True)\n",
45 | " # mnist dataset contains self.data, self.targets\n",
46 | " self.dig2label = {0: 0, 1: 0, 2: 0, 3: 0, 4: 1, 5: 1, 6: 1, 7: 2, 8: 2, 9: 2}\n",
47 | " self.dig2attri = {}\n",
48 | " self.colors = {0:[1,0,0], 1:[0,1,0], 2:[0,0,1]}\n",
49 | " \n",
50 | " self.cat_ratio = cat_ratio\n",
51 | " self.att_ratio = att_ratio\n",
52 | " # generate long-tailed data\n",
53 | " self.generate_lt_label(cat_ratio)\n",
54 | " \n",
55 | " \n",
56 | " def generate_lt_label(self, ratio=1.0):\n",
57 | " self.label2list = {i:[] for i in range(3)}\n",
58 | " for img, dig in zip(self.data, self.targets):\n",
59 | " label = self.dig2label[int(dig)]\n",
60 | " self.label2list[label].append(img)\n",
61 | " if ratio == 1.0:\n",
62 | " balance_size = min([len(val) for key, val in self.label2list.items()])\n",
63 | " for key, val in self.label2list.items():\n",
64 | " self.label2list[key] = val[:balance_size]\n",
65 | " elif ratio < 1.0:\n",
66 | " current_size = len(self.label2list[0])\n",
67 | " for key, val in self.label2list.items():\n",
68 | " max_size = len(val)\n",
69 | " self.label2list[key] = val[:min(max_size, current_size)]\n",
70 | " current_size = int(current_size * ratio)\n",
71 | " else:\n",
72 | " raise ValueError('Wrong Ratio in ColorMNIST')\n",
73 | " \n",
74 | " self.lt_labels = []\n",
75 | " self.lt_imgs = []\n",
76 | " for key, val in self.label2list.items():\n",
77 | " for item in val:\n",
78 | " self.lt_labels.append(key)\n",
79 | " self.lt_imgs.append(item)\n",
80 | " print('Generate ColorMNIST: label {} has {} images.'.format(key, len(val)))\n",
81 | " \n",
82 | " \n",
83 | " def __len__(self):\n",
84 | " return len(self.lt_labels)\n",
85 | " \n",
86 | " def __getitem__(self, index):\n",
87 | " img = self.lt_imgs[index].unsqueeze(-1).repeat(1,1,3)\n",
88 | " label = self.lt_labels[index]\n",
89 | " \n",
90 | " # generate tail colors\n",
91 | " if random.random() < self.att_ratio:\n",
92 | " att_label = random.randint(0,2)\n",
93 | " color = self.colors[att_label]\n",
94 | " else:\n",
95 | " color = self.colors[label]\n",
96 | " \n",
97 | " # assign attribute\n",
98 | " img = self.to_color(img, color)\n",
99 | " \n",
100 | " return img, label\n",
101 | " \n",
102 | " def to_color(self, img, rgb=[1,0,0]):\n",
103 | " return (img * torch.FloatTensor(rgb).unsqueeze(0).unsqueeze(0)).byte()"
104 | ]
105 | },
106 | {
107 | "cell_type": "code",
108 | "execution_count": 133,
109 | "id": "0bddd1e3",
110 | "metadata": {},
111 | "outputs": [],
112 | "source": [
113 | "def visualization(input_tensor):\n",
114 | " return Image.fromarray(input_tensor.numpy()).resize((64,64))"
115 | ]
116 | },
117 | {
118 | "cell_type": "code",
119 | "execution_count": 135,
120 | "id": "77a12b35",
121 | "metadata": {},
122 | "outputs": [
123 | {
124 | "name": "stdout",
125 | "output_type": "stream",
126 | "text": [
127 | "Generate ColorMNIST: label 0 has 24754 images.\n",
128 | "Generate ColorMNIST: label 1 has 2475 images.\n",
129 | "Generate ColorMNIST: label 2 has 247 images.\n"
130 | ]
131 | }
132 | ],
133 | "source": [
134 | "dataset = ColorMNIST_LT('train', 'test_iid', '.', None, 0.1, 0.1)"
135 | ]
136 | },
137 | {
138 | "cell_type": "code",
139 | "execution_count": 165,
140 | "id": "72d73ef4",
141 | "metadata": {},
142 | "outputs": [
143 | {
144 | "data": {
145 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAEAAAABACAIAAAAlC+aJAAAGcUlEQVR4nO3a6W9T2RkG8J/tLI4hIdsEsk0yCYQtIQxQllLUUTuqqn7rp/5X/SP6D1QjVZVGoyJ1OtN2gGGAGZYQAgkhG1lIyOIsjt0PVwc5gypsx6FVm0f+YPn6nvM+55znfZ9z7mUf+9jHPvbx/4zYe+umkiQpaqgBadKssU6GbEktV5QzzH+PShrooY/j9IIRhhjiGYtskiu+5T0nEKOaVnoY4GPO0Q9+oJ16UowwyVbx7b8PAi38nGt0c4SOcKmLBO2c4Qv+VBKBvUUlzXzKH1ggwyaZvM8WGRb4Pa0ldbGHM1BFJx9zhUEaQAJhpCvDPxtooZ6Z4qW8VwQqaGSA33KZw+SIkQujvs0GiZCRUrQyzWu2i+yozIhU28FprnKWHoQMk2ObNZaYJ8tRmungGrWMMc0S6f8UgSo+4Cf8hrO0hd+jGYiR4TVPuckan/IJx6njAt9wm+8Zf88EEhyggSa6ucIVuomHP/yoZKZ5zizHWKGOOlrIsMzz90+ghkEGOUYP3RwOks0nkKOaJtpp4jWrTNFOiioa+YBUwf2WjcAhzvIr+ukKet3eWVwjDlVU0cwhKllghAo+JEsNtVS9ZwIVNNHDcdpCoNGKj/Q6xRqdQdDR1TSzbFNJDW1UESNbTCIqA4EKGviQLo5QGfSKdWYY4y6LXKEzpP9ZXjLBNJsc5hw1YeoK95i7JVCdV62OcjD0nWGNUW4xzCRx2hkiyxR3GeIlWbY4wwYHaaCNJlKsF1DXdkUgGvvz/I7zNOeN3Brj3OCPDNNIF+N8xQy3GWGaZXJBxxvEaOM4H9HMTAEWtUQCsVBrexngHB+CLDG2mGOE+yGjR+t7m+dMcof5vNa2WCEDkhyhlz4yzLxLDyUSiFPLUS5yitrwe5YcK0xwnxHWwCuGGCMRZP32iLxBDSeYY4P5vSBQQT1dnOE8vSTJssgrtljgAY+YCL4t/S5rEM/jUEknp3icVwfLSSDJCS5xgQE6SLHCd9xjnlfMMMViWBhFIR7UnCogHZVCoJZ+PuEUbSTBOP/gOi+CNKvZKNJaRog20NVUlJ3AG+32cJJWkizxkG/4J4+YZ5M4qVCMS0Bs56IqG4EK6miji45Q8J/xGV8znKe5XJBvCfv06K5sYfcWRyBFH2fpphqs8Zw73OdVXpe54kOP7fxe4Ay8U+U7UM9P+QVdYJ1JxvLWfcmI5YUbRR8vLLhCZyBOBa0MMEATWZYZZ5yFXR8o/Ih8ho3CMlihBA7QQh+dNFFBlnlGmWSzyHDfxmae3CMPu8hqAbNaKIGowkfRR3lzmxmeMFlSsn+DaPEk83Y8aaYYY6GMZi5BNcmgXawywUNGWS8p9AiH6eR4OJ6IssItvuRpAUNTdCF7I7XIQg7n+YUSkOQoVxkMdWOOIf7OlyyXl0BsZ67YZoW5UgUQDx77Etc4QQ1LDHOPEV4V1k4RBN7W09YuFk9kOS/ySy7RSI4JbnOb2YLbKWUJ5VhnlqXihz9BkgN8xGUuc5yDYf9wl295zGLBDZbihbKM8z0TxR9lHqCL05zlKr1UMcrXfM4Yk7wq+FhOCQRiYRO4WEw3ieACO+nnIoOcJMUoN/mCz0pakIUSeCPfLHEO00s9G4XdHhnYfgbpo4tGNhnjr/yFh6XKqbgZiOxxnFZO0kmK1XfdFeXKK/yMi7QSY4VhbvJnru+iFBZNIFr0MRo4xgVeMM8qGeIkqaWeJmpJcojTnKaPJuKhCD7gLqO7cyKFEsjlfSLUcJJfc48HTJEOR3QdHOMU7TTSSAuHwDJTzDLMDX4oJuHsikDsLY8ebb0vUkdTODxM0EIXPRyllcZw2pVjgWlGecwTHjDKcri65wQSxPMOnGM0k6KNAZZYI04zLdSR4kDwTmnmgkX7lq+YYJm1Up+uFk0gOipcI802ifBA4ABNdLHOOjHqg12N
sMUSz3jGTKi134UN5+5RKIFVppngJd3U5q2lSLiV4Wi2Ou+udSZ4xN+4yyorvCxf9EUQSPOSMZ7QQjt1JPKSUiLsALfJkiHNJEPc4nPulC/oUghkSTPKdeboo5duGvP+E83Ja57zgnGeM8ZTRssbdR6KqAMZXvCaYfo5T3YngQhj3OAO9xgPxmbvHsEXZ6c3mA0BRXboKcmQRmJshPz4mKfM7U3Q+SjldZvIEtdwkNTOw49seIMmeonmv+7dh33sYx/72Mf/Gv4FIJ/2aCjskFoAAAAASUVORK5CYII=\n",
146 | "text/plain": [
147 | ""
148 | ]
149 | },
150 | "execution_count": 165,
151 | "metadata": {},
152 | "output_type": "execute_result"
153 | }
154 | ],
155 | "source": [
156 | "visualization(dataset[0][0])"
157 | ]
158 | },
159 | {
160 | "cell_type": "code",
161 | "execution_count": null,
162 | "id": "a9f1a89c",
163 | "metadata": {},
164 | "outputs": [],
165 | "source": []
166 | },
167 | {
168 | "cell_type": "code",
169 | "execution_count": null,
170 | "id": "36f89360",
171 | "metadata": {},
172 | "outputs": [],
173 | "source": []
174 | },
175 | {
176 | "cell_type": "code",
177 | "execution_count": null,
178 | "id": "617f415f",
179 | "metadata": {},
180 | "outputs": [],
181 | "source": []
182 | }
183 | ],
184 | "metadata": {
185 | "kernelspec": {
186 | "display_name": "Python 3",
187 | "language": "python",
188 | "name": "python3"
189 | },
190 | "language_info": {
191 | "codemirror_mode": {
192 | "name": "ipython",
193 | "version": 3
194 | },
195 | "file_extension": ".py",
196 | "mimetype": "text/x-python",
197 | "name": "python",
198 | "nbconvert_exporter": "python",
199 | "pygments_lexer": "ipython3",
200 | "version": "3.6.10"
201 | }
202 | },
203 | "nbformat": 4,
204 | "nbformat_minor": 5
205 | }
206 |
--------------------------------------------------------------------------------
/deprecated/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/KaihuaTang/Generalized-Long-Tailed-Benchmarks.pytorch/6317d8feb0ba107e1a64822567ed59115d51c581/deprecated/__init__.py
--------------------------------------------------------------------------------
/figure/generalized-long-tail.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/KaihuaTang/Generalized-Long-Tailed-Benchmarks.pytorch/6317d8feb0ba107e1a64822567ed59115d51c581/figure/generalized-long-tail.jpg
--------------------------------------------------------------------------------
/figure/glt_formulation.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/KaihuaTang/Generalized-Long-Tailed-Benchmarks.pytorch/6317d8feb0ba107e1a64822567ed59115d51c581/figure/glt_formulation.jpg
--------------------------------------------------------------------------------
/figure/ifl.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/KaihuaTang/Generalized-Long-Tailed-Benchmarks.pytorch/6317d8feb0ba107e1a64822567ed59115d51c581/figure/ifl.jpg
--------------------------------------------------------------------------------
/figure/ifl_code.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/KaihuaTang/Generalized-Long-Tailed-Benchmarks.pytorch/6317d8feb0ba107e1a64822567ed59115d51c581/figure/ifl_code.jpg
--------------------------------------------------------------------------------
/figure/imagenet-glt-statistics.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/KaihuaTang/Generalized-Long-Tailed-Benchmarks.pytorch/6317d8feb0ba107e1a64822567ed59115d51c581/figure/imagenet-glt-statistics.jpg
--------------------------------------------------------------------------------
/figure/imagenet-glt-visualization.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/KaihuaTang/Generalized-Long-Tailed-Benchmarks.pytorch/6317d8feb0ba107e1a64822567ed59115d51c581/figure/imagenet-glt-visualization.jpg
--------------------------------------------------------------------------------
/figure/imagenet-glt.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/KaihuaTang/Generalized-Long-Tailed-Benchmarks.pytorch/6317d8feb0ba107e1a64822567ed59115d51c581/figure/imagenet-glt.jpg
--------------------------------------------------------------------------------
/figure/mscoco-glt-statistics.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/KaihuaTang/Generalized-Long-Tailed-Benchmarks.pytorch/6317d8feb0ba107e1a64822567ed59115d51c581/figure/mscoco-glt-statistics.jpg
--------------------------------------------------------------------------------
/figure/mscoco-glt-testgeneration.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/KaihuaTang/Generalized-Long-Tailed-Benchmarks.pytorch/6317d8feb0ba107e1a64822567ed59115d51c581/figure/mscoco-glt-testgeneration.jpg
--------------------------------------------------------------------------------
/figure/mscoco-glt-visualization.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/KaihuaTang/Generalized-Long-Tailed-Benchmarks.pytorch/6317d8feb0ba107e1a64822567ed59115d51c581/figure/mscoco-glt-visualization.jpg
--------------------------------------------------------------------------------
/figure/mscoco-glt.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/KaihuaTang/Generalized-Long-Tailed-Benchmarks.pytorch/6317d8feb0ba107e1a64822567ed59115d51c581/figure/mscoco-glt.jpg
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
######################################
# Kaihua Tang
######################################
"""Entry point: trains or evaluates a long-tailed classification model
driven by a YAML config, command-line overrides, and the selected phase."""

import json
import yaml
import os
import argparse
import torch
import torch.nn as nn
import random
import utils.general_utils as utils
from utils.logger_utils import custom_logger
from data.dataloader import get_loader
from utils.checkpoint_utils import Checkpoint
from utils.training_utils import *


from utils.train_loader import train_loader
from utils.test_loader import test_loader

# ============================================================================
# argument parser
parser = argparse.ArgumentParser()
parser.add_argument('--cfg', default=None, type=str, help='Indicate the config file used for the training.')
parser.add_argument('--seed', default=25, type=int, help='Fix the random seed for reproduction. Default is 25.')
parser.add_argument('--phase', default='train', type=str, help='Indicate train/val/test phase.')
parser.add_argument('--load_dir', default=None, type=str, help='Load model from this directory for testing')
parser.add_argument('--output_dir', default=None, type=str, help='Output directory that saves everything.')
parser.add_argument('--require_eval', action='store_true', help='Require evaluation on val set during training.')
parser.add_argument('--logger_name', default='logger_eval', type=str, help='Name of TXT output for the logger.')
# update config settings
parser.add_argument('--lr', default=None, type=float, help='Learning Rate')
parser.add_argument('--testset', default=None, type=str, help='Reset the type of test set.')
parser.add_argument('--loss_type', default=None, type=str, help='Reset the type of loss function.')
parser.add_argument('--model_type', default=None, type=str, help='Reset the type of model.')
parser.add_argument('--train_type', default=None, type=str, help='Reset the type of training strategy.')
parser.add_argument('--sample_type', default=None, type=str, help='Reset the type of sampling strategy.')
parser.add_argument('--rand_aug', action='store_true', help='Apply Random Augmentation During Training.')
parser.add_argument('--save_all', action='store_true', help='Save All Output Information During Testing.')

args = parser.parse_args()

# ============================================================================
# init logger
if args.output_dir is None:
    # fail fast: every later step (logger, config dump, checkpoints) writes here.
    # The old code only printed a warning and then crashed on os.path.exists(None).
    parser.error('Please specify output directory with --output_dir')
# makedirs handles nested paths and is race-free with exist_ok
os.makedirs(args.output_dir, exist_ok=True)
if args.phase != 'train':
    logger = custom_logger(args.output_dir, name='{}.txt'.format(args.logger_name))
else:
    logger = custom_logger(args.output_dir)
logger.info('========================= Start Main =========================')


# ============================================================================
# fix random seed
logger.info('=====> Using fixed random seed: ' + str(args.seed))
random.seed(args.seed)
torch.manual_seed(args.seed)
torch.cuda.manual_seed(args.seed)
torch.cuda.manual_seed_all(args.seed)

# ============================================================================
# load config
logger.info('=====> Load config from yaml: ' + str(args.cfg))
with open(args.cfg) as f:
    # safe_load avoids arbitrary object construction from untrusted YAML
    # (plain yaml.load without a Loader is deprecated and unsafe)
    config = yaml.safe_load(f)

# load detailed settings for each algorithms
logger.info('=====> Load algorithm details from yaml: config/algorithms_config.yaml')
with open('config/algorithms_config.yaml') as f:
    algo_config = yaml.safe_load(f)

# update config
logger.info('=====> Merge arguments from command')
config = utils.update(config, algo_config, args)

# save config
logger.info('=====> Save config as config.json')
with open(os.path.join(args.output_dir, 'config.json'), 'w') as f:
    json.dump(config, f)
utils.print_config(config, logger)

# ============================================================================
# training
if args.phase == 'train':
    logger.info('========= The Current Training Strategy is {} ========='.format(config['training_opt']['type']))
    train_func = train_loader(config)
    training = train_func(args, config, logger, eval=args.require_eval)
    training.run()


else:
    # ============================================================================
    # create model
    logger.info('=====> Model construction from: ' + str(config['networks']['type']))
    model_type = config['networks']['type']
    model_file = config['networks'][model_type]['def_file']
    model_args = config['networks'][model_type]['params']
    logger.info('=====> Classifier construction from: ' + str(config['classifiers']['type']))
    classifier_type = config['classifiers']['type']
    classifier_file = config['classifiers'][classifier_type]['def_file']
    classifier_args = config['classifiers'][classifier_type]['params']
    model = utils.source_import(model_file).create_model(**model_args)
    classifier = utils.source_import(classifier_file).create_model(**classifier_args)

    model = nn.DataParallel(model).cuda()
    classifier = nn.DataParallel(classifier).cuda()

    # ============================================================================
    # load checkpoint
    checkpoint = Checkpoint(config)
    ckpt = checkpoint.load(model, classifier, args.load_dir, logger)

    # ============================================================================
    # testing
    test_func = test_loader(config)
    if args.phase == 'val':
        # run validation set
        testing = test_func(config, logger, model, classifier, val=True, add_ckpt=ckpt)
        testing.run_val(epoch=-1)
    else:
        assert args.phase == 'test'
        # Run a specific test split
        if args.testset:
            testing = test_func(config, logger, model, classifier, val=False, add_ckpt=ckpt)
            testing.run_test()
        # Run all test splits
        else:
            if 'LT' in config['dataset']['name']:
                config['dataset']['testset'] = 'test_lt'
                testing = test_func(config, logger, model, classifier, val=False, add_ckpt=ckpt)
                testing.run_test()

            config['dataset']['testset'] = 'test_bl'
            testing = test_func(config, logger, model, classifier, val=False, add_ckpt=ckpt)
            testing.run_test()

            config['dataset']['testset'] = 'test_bbl'
            testing = test_func(config, logger, model, classifier, val=False, add_ckpt=ckpt)
            testing.run_test()

logger.info('========================= Complete =========================')
--------------------------------------------------------------------------------
/models/ClassifierCOS.py:
--------------------------------------------------------------------------------
1 | ######################################
2 | # Kaihua Tang
3 | ######################################
4 |
5 | import torch
6 | import torch.nn as nn
7 | import torchvision.models as models
8 |
9 | import math
10 |
class ClassifierCOS(nn.Module):
    """Cosine-similarity classifier with multi-head normalization.

    Features and class weights are split into `num_head` equal chunks,
    each chunk is L2-normalized independently, and the logits are the
    scaled dot product (scale = tau / num_head) of the re-concatenated
    normalized vectors.
    """

    def __init__(self, feat_dim, num_classes=1000, num_head=2, tau=16.0):
        super(ClassifierCOS, self).__init__()

        # learnable classifier weight matrix, shape (num_classes, feat_dim)
        self.weight = nn.Parameter(torch.Tensor(num_classes, feat_dim).cuda(), requires_grad=True)
        self.reset_parameters(self.weight)

        # scaling and head-splitting hyper-parameters
        self.scale = tau / num_head  # 16.0 / num_head
        self.num_head = num_head
        self.head_dim = feat_dim // num_head

    def reset_parameters(self, weight):
        # uniform init in [-1/sqrt(feat_dim), 1/sqrt(feat_dim)]
        bound = 1. / math.sqrt(weight.size(1))
        weight.data.uniform_(-bound, bound)

    def forward(self, x, add_inputs=None):
        # add_inputs is unused; kept for a uniform classifier interface
        feats = self.multi_head_call(self.l2_norm, x)
        weights = self.multi_head_call(self.l2_norm, self.weight)
        return torch.mm(feats * self.scale, weights.t())

    def multi_head_call(self, func, x):
        # apply `func` independently to every head_dim-wide chunk of x
        assert len(x.shape) == 2
        chunks = torch.split(x, self.head_dim, dim=1)
        assert len(chunks) == self.num_head
        results = [func(chunk) for chunk in chunks]
        assert len(results) == self.num_head
        return torch.cat(results, dim=1)

    def l2_norm(self, x):
        # row-wise L2 normalization
        return x / torch.norm(x, 2, 1, keepdim=True)
45 |
def create_model(feat_dim=2048, num_classes=1000, num_head=2, tau=16.0):
    # factory entry point used by the config-driven model loader
    return ClassifierCOS(feat_dim=feat_dim, num_classes=num_classes, num_head=num_head, tau=tau)
--------------------------------------------------------------------------------
/models/ClassifierFC.py:
--------------------------------------------------------------------------------
1 | ######################################
2 | # Kaihua Tang
3 | ######################################
4 |
5 | import torch
6 | import torch.nn as nn
7 | import torchvision.models as models
8 |
9 | import math
10 |
class ClassifierFC(nn.Module):
    """Plain fully-connected (bias-free linear) classification head."""

    def __init__(self, feat_dim, num_classes=1000):
        super(ClassifierFC, self).__init__()
        # single linear projection from features to class logits
        self.fc = nn.Linear(feat_dim, num_classes, bias=False)

    def forward(self, x, add_inputs=None):
        # add_inputs is unused; kept for a uniform classifier interface
        return self.fc(x)
20 |
21 |
def create_model(feat_dim=2048, num_classes=1000):
    # factory entry point used by the config-driven model loader
    return ClassifierFC(feat_dim=feat_dim, num_classes=num_classes)
--------------------------------------------------------------------------------
/models/ClassifierLA.py:
--------------------------------------------------------------------------------
1 | ######################################
2 | # Kaihua Tang
3 | ######################################
4 |
5 | import torch
6 | import torch.nn as nn
7 | import torchvision.models as models
8 |
9 | import math
10 |
class ClassifierLA(nn.Module):
    """Logit-Adjustment classifier head.

    Exactly one of two mutually exclusive modes must be enabled:
      * loss mode    : the adjustment is ADDED to the logits during training
      * posthoc mode : the adjustment is SUBTRACTED from the logits at eval
    The adjustment tensor is supplied by the caller via add_inputs['logit_adj'].
    """

    def __init__(self, feat_dim, num_classes=1000, posthoc=False, loss=False):
        super(ClassifierLA, self).__init__()

        self.posthoc = posthoc
        self.loss = loss
        # enforce: exactly one of the two modes is active
        assert (self.posthoc and self.loss) == False
        assert (self.posthoc or self.loss) == True

        self.fc = nn.Linear(feat_dim, num_classes, bias=False)

    def forward(self, x, add_inputs=None):
        logits = self.fc(x)
        if self.training:
            if self.loss:
                logits = logits + add_inputs['logit_adj']
        elif self.posthoc:
            logits = logits - add_inputs['logit_adj']
        return logits
31 |
32 |
def create_model(feat_dim=2048, num_classes=1000, posthoc=True, loss=False):
    # factory entry point; defaults to the post-hoc adjustment variant
    return ClassifierLA(feat_dim=feat_dim, num_classes=num_classes, posthoc=posthoc, loss=loss)
--------------------------------------------------------------------------------
/models/ClassifierLDAM.py:
--------------------------------------------------------------------------------
1 | ######################################
2 | # Kaihua Tang
3 | ######################################
4 |
5 | import torch
6 | import torch.nn as nn
7 | import torchvision.models as models
8 | import torch.nn.functional as F
9 |
10 | import math
11 |
class ClassifierLDAM(nn.Module):
    """Normalized-weight classifier head used with the LDAM loss.

    Produces cosine logits: feature rows and classifier-weight columns
    are both L2-normalized before the dot product.
    """

    def __init__(self, feat_dim, num_classes=1000):
        super(ClassifierLDAM, self).__init__()

        # weight is (feat_dim x num_classes); each column is one class vector.
        # Init uniformly, then rescale every column to (roughly) unit L2 norm.
        self.weight = nn.Parameter(torch.Tensor(feat_dim, num_classes).cuda(), requires_grad=True)
        self.weight.data.uniform_(-1, 1).renorm_(2, 1, 1e-5).mul_(1e5)

    def forward(self, x, add_inputs=None):
        # cosine similarity between normalized features and class columns
        return torch.mm(F.normalize(x, dim=1), F.normalize(self.weight, dim=0))
25 |
26 |
def create_model(feat_dim=2048, num_classes=1000):
    # factory entry point used by the config-driven model loader
    return ClassifierLDAM(feat_dim=feat_dim, num_classes=num_classes)
--------------------------------------------------------------------------------
/models/ClassifierLWS.py:
--------------------------------------------------------------------------------
1 | ######################################
2 | # Kaihua Tang
3 | ######################################
4 |
5 | import torch
6 | import torch.nn as nn
7 | import torchvision.models as models
8 |
9 | import math
10 |
class ClassifierLWS(nn.Module):
    """Learnable Weight Scaling (LWS) classifier head.

    The linear classifier itself is frozen (trained in an earlier stage);
    only one scaling factor per class remains learnable here.
    """

    def __init__(self, feat_dim, num_classes=1000):
        super(ClassifierLWS, self).__init__()

        self.fc = nn.Linear(feat_dim, num_classes, bias=False)

        # per-class learnable magnitudes applied on top of frozen logits
        self.scales = nn.Parameter(torch.ones(num_classes))
        # freeze the underlying linear classifier
        for param in self.fc.parameters():
            param.requires_grad = False

    def forward(self, x, add_inputs=None):
        # frozen logits, rescaled class-wise by the learned factors
        return self.fc(x) * self.scales
25 |
26 |
def create_model(feat_dim=2048, num_classes=1000):
    # factory entry point used by the config-driven model loader
    return ClassifierLWS(feat_dim=feat_dim, num_classes=num_classes)
--------------------------------------------------------------------------------
/models/ClassifierMultiHead.py:
--------------------------------------------------------------------------------
1 | ######################################
2 | # Kaihua Tang
3 | ######################################
4 |
5 | import torch
6 | import torch.nn as nn
7 | import torchvision.models as models
8 |
9 | import math
10 |
class ClassifierMultiHead(nn.Module):
    """Three-headed output module: a (gradient-detached) prediction head
    plus contrastive and metric projection heads of full feature width."""

    def __init__(self, feat_dim, num_classes=1000):
        super(ClassifierMultiHead, self).__init__()

        self.fc1 = nn.Linear(feat_dim, num_classes, bias=False)
        self.fc2 = nn.Linear(feat_dim, feat_dim, bias=False)
        self.fc3 = nn.Linear(feat_dim, feat_dim, bias=False)

    def forward(self, x, add_inputs=None):
        # the prediction head sees detached features so its gradient does not
        # reach the backbone (this head is re-trained in stage 2)
        pred = self.fc1(x.detach())
        contrast = self.fc2(x)  # contrastive head
        metric = self.fc3(x)    # metric head
        return pred, contrast, metric
24 |
25 |
def create_model(feat_dim=2048, num_classes=1000):
    # factory entry point used by the config-driven model loader
    return ClassifierMultiHead(feat_dim=feat_dim, num_classes=num_classes)
--------------------------------------------------------------------------------
/models/ClassifierRIDE.py:
--------------------------------------------------------------------------------
1 | ######################################
2 | # Kaihua Tang
3 | ######################################
4 | import math
5 | import torch
6 | import torch.nn as nn
7 | import torch.nn.functional as F
8 | import torchvision.models as models
9 |
10 |
11 |
class NormedLinear(nn.Module):
    """Bias-free linear layer that L2-normalizes input rows and weight
    columns, so its outputs are cosine similarities."""

    def __init__(self, in_features, out_features):
        super(NormedLinear, self).__init__()
        self.weight = nn.Parameter(torch.Tensor(in_features, out_features))
        # uniform init, then rescale each column to (roughly) unit L2 norm
        self.weight.data.uniform_(-1, 1).renorm_(2, 1, 1e-5).mul_(1e5)

    def forward(self, x):
        # cosine logits: normalized features times column-normalized weights
        return F.normalize(x, dim=1).mm(F.normalize(self.weight, dim=0))
21 |
class ClassifierRIDE(nn.Module):
    """RIDE multi-expert classifier head.

    One classifier per expert (cosine-normalized or plain linear); the
    ensemble logit is the mean over all experts. A single expert can be
    selected by passing `index`.
    """

    def __init__(self, feat_dim, num_classes=1000, num_experts=3, use_norm=True):
        super(ClassifierRIDE, self).__init__()
        self.num_experts = num_experts
        if use_norm:
            # cosine classifiers need a larger logit scale
            self.linears = nn.ModuleList([NormedLinear(feat_dim, num_classes) for _ in range(num_experts)])
            self.s = 30
        else:
            self.linears = nn.ModuleList([nn.Linear(feat_dim, num_classes) for _ in range(num_experts)])
            self.s = 1

    def forward(self, x, add_inputs=None, index=None):
        if index is not None:
            # single-expert routing; x is (batch, feat_dim)
            return (self.linears[index])(x)
        # all experts; x is (batch, num_experts, feat_dim)
        logits = [self.linears[i](x[:, i, :]) * self.s for i in range(self.num_experts)]
        y = torch.stack(logits, dim=1).mean(dim=1)
        return y, logits
44 |
45 |
def create_model(feat_dim=2048, num_classes=1000, num_experts=3, use_norm=True):
    # factory entry point used by the config-driven model loader
    return ClassifierRIDE(feat_dim=feat_dim, num_classes=num_classes,
                          num_experts=num_experts, use_norm=use_norm)
--------------------------------------------------------------------------------
/models/ClassifierTDE.py:
--------------------------------------------------------------------------------
1 | ######################################
2 | # Kaihua Tang
3 | ######################################
4 |
5 | import torch
6 | import torch.nn as nn
7 | import torchvision.models as models
8 |
9 | import math
10 |
class ClassifierTDE(nn.Module):
    """Causal classifier head with Total Direct Effect (TDE) inference.

    Training: multi-head cosine classifier whose weights are normalized
    by an additive "causal norm" (norm + gamma). Inference: the component
    of each feature chunk along a caller-supplied bias direction
    (`add_inputs['embed']`) is removed, head by head, before the logits
    are computed.
    """

    def __init__(self, feat_dim, num_classes=1000, num_head=2, tau=16.0, alpha=3.0, gamma=0.03125):
        super(ClassifierTDE, self).__init__()

        # classifier weights, shape (num_classes, feat_dim)
        self.weight = nn.Parameter(torch.Tensor(num_classes, feat_dim).cuda(), requires_grad=True)
        self.reset_parameters(self.weight)

        # parameters
        self.scale = tau / num_head   # per-head logit scale, 16.0 / num_head
        self.norm_scale = gamma       # additive constant of the causal norm
        self.alpha = alpha            # strength of the TDE removal at inference
        self.num_head = num_head
        self.head_dim = feat_dim // num_head

    def reset_parameters(self, weight):
        # uniform init in [-1/sqrt(feat_dim), 1/sqrt(feat_dim)]
        stdv = 1. / math.sqrt(weight.size(1))
        weight.data.uniform_(-stdv, stdv)

    def forward(self, x, add_inputs=None):
        # logits: L2-normalized features vs causal-normalized weights
        normed_x = self.multi_head_call(self.l2_norm, x)
        normed_w = self.multi_head_call(self.causal_norm, self.weight, weight=self.norm_scale)
        y = torch.mm(normed_x * self.scale, normed_w.t())

        # apply TDE during inference
        if (not self.training):
            # bias direction accumulated by the trainer, passed in at eval
            self.embed = add_inputs['embed']
            normed_c = self.multi_head_call(self.l2_norm, self.embed)
            x_list = torch.split(normed_x, self.head_dim, dim=1)
            c_list = torch.split(normed_c, self.head_dim, dim=1)
            w_list = torch.split(normed_w, self.head_dim, dim=1)
            output = []

            for nx, nc, nw in zip(x_list, c_list, w_list):
                # subtract the feature's projection onto the bias direction
                # (sin_val is part of the helper's interface but unused here)
                cos_val, sin_val = self.get_cos_sin(nx, nc)
                y0 = torch.mm((nx - cos_val * self.alpha * nc) * self.scale, nw.t())
                output.append(y0)
            y = sum(output)
        return y

    def get_cos_sin(self, x, y):
        """Return per-row cosine and sine of the angle between x and y."""
        cos_val = (x * y).sum(-1, keepdim=True) / torch.norm(x, 2, 1, keepdim=True) / torch.norm(y, 2, 1, keepdim=True)
        sin_val = (1 - cos_val * cos_val).sqrt()
        return cos_val, sin_val

    def multi_head_call(self, func, x, weight=None):
        """Apply `func` to each head_dim-wide chunk of x and re-concatenate.

        `weight`, when given, is forwarded as func's second argument.
        """
        assert len(x.shape) == 2
        x_list = torch.split(x, self.head_dim, dim=1)
        # compare against None explicitly: a numeric weight of 0.0 is a valid
        # value and must still be forwarded (the old truthiness test `if weight:`
        # would drop it and crash inside causal_norm on the missing argument)
        if weight is not None:
            y_list = [func(item, weight) for item in x_list]
        else:
            y_list = [func(item) for item in x_list]
        assert len(x_list) == self.num_head
        assert len(y_list) == self.num_head
        return torch.cat(y_list, dim=1)

    def l2_norm(self, x):
        # row-wise L2 normalization
        normed_x = x / torch.norm(x, 2, 1, keepdim=True)
        return normed_x

    def causal_norm(self, x, weight):
        # softened normalization: divide by (L2 norm + additive constant)
        norm = torch.norm(x, 2, 1, keepdim=True)
        normed_x = x / (norm + weight)
        return normed_x
75 |
def create_model(feat_dim=2048, num_classes=1000, num_head=2, tau=16.0, alpha=3.0, gamma=0.03125):
    # factory entry point used by the config-driven model loader
    return ClassifierTDE(feat_dim=feat_dim, num_classes=num_classes, num_head=num_head,
                         tau=tau, alpha=alpha, gamma=gamma)
--------------------------------------------------------------------------------
/models/ResNet.py:
--------------------------------------------------------------------------------
1 | ######################################
2 | # Kaihua Tang
3 | ######################################
4 |
5 | import torch
6 | import torch.nn as nn
7 | import torchvision.models as models
8 |
9 |
def create_model(m_type='resnet101'):
    """Build a torchvision ResNet/ResNeXt backbone whose final fc layer is
    replaced by a ReLU, so the network outputs pooled features rather than
    ImageNet logits. Raises ValueError for an unknown m_type."""
    factories = {
        'resnet18': models.resnet18,
        'resnet50': models.resnet50,
        'resnet101': models.resnet101,
        'resnext50': models.resnext50_32x4d,
        'resnext101': models.resnext101_32x8d,
    }
    if m_type not in factories:
        raise ValueError('Wrong Model Type')
    model = factories[m_type](pretrained=False)
    # drop the classification head; identity-like pass-through of features
    model.fc = nn.ReLU()
    return model
--------------------------------------------------------------------------------
/models/ResNet_BBN.py:
--------------------------------------------------------------------------------
1 | ######################################
2 | # Kaihua Tang
3 | ######################################
4 |
5 | import torch
6 | import torch.nn as nn
7 | import torchvision.models as models
8 | import torch.nn.functional as F
9 | import math
10 |
class BottleNeck(nn.Module):
    """Standard ResNet bottleneck block: 1x1 reduce -> 3x3 (strided) ->
    1x1 expand, with a projection shortcut whenever the output shape
    differs from the input."""

    expansion = 4  # channel expansion factor of the final 1x1 conv

    def __init__(self, inplanes, planes, stride=1):
        super(BottleNeck, self).__init__()
        # 1x1 channel reduction
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.relu1 = nn.ReLU(True)
        # 3x3 spatial convolution; carries the block's stride
        self.conv2 = nn.Conv2d(
            planes, planes, kernel_size=3, stride=stride, padding=1, bias=False
        )
        self.bn2 = nn.BatchNorm2d(planes)
        self.relu2 = nn.ReLU(True)
        # 1x1 channel expansion
        self.conv3 = nn.Conv2d(
            planes, planes * self.expansion, kernel_size=1, bias=False
        )
        self.bn3 = nn.BatchNorm2d(planes * self.expansion)
        # projection shortcut when spatial size or channel count changes
        if stride != 1 or self.expansion * planes != inplanes:
            self.downsample = nn.Sequential(
                nn.Conv2d(
                    inplanes,
                    self.expansion * planes,
                    kernel_size=1,
                    stride=stride,
                    bias=False,
                ),
                nn.BatchNorm2d(self.expansion * planes),
            )
        else:
            self.downsample = None
        self.relu = nn.ReLU(True)

    def forward(self, x):
        out = self.relu1(self.bn1(self.conv1(x)))

        out = self.relu2(self.bn2(self.conv2(out)))

        out = self.bn3(self.conv3(out))

        # identity (or projected) residual connection; use the idiomatic
        # `is not None` test rather than equality comparison against None
        if self.downsample is not None:
            residual = self.downsample(x)
        else:
            residual = x
        out = out + residual
        out = self.relu(out)
        return out
58 |
59 |
60 |
class BBN_ResNet(nn.Module):
    """ResNet backbone for BBN (Bilateral-Branch Network).

    The last residual block of stage 4 is replaced by TWO parallel blocks:
    `cb_block` (conventional-learning branch) and `rb_block` (re-balancing
    branch). forward() returns one branch's features or the channel-wise
    concatenation of both, depending on the keywords passed in.
    """
    def __init__(
        self,
        block_type,
        num_blocks,
        last_layer_stride=2,
    ):
        super(BBN_ResNet, self).__init__()
        # NOTE: self.inplanes is advanced by every _make_layer call below,
        # so the construction order of the stages matters.
        self.inplanes = 64
        self.block = block_type

        # standard ResNet stem: 7x7 conv, BN, ReLU, 3x3 max-pool
        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(True)
        self.pool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        self.layer1 = self._make_layer(num_blocks[0], 64)
        self.layer2 = self._make_layer(num_blocks[1], 128, stride=2)
        self.layer3 = self._make_layer(num_blocks[2], 256, stride=2)
        # stage 4 is built one block short: its final block is replaced by
        # the two parallel branch blocks constructed just below
        self.layer4 = self._make_layer(num_blocks[3] - 1, 512, stride=last_layer_stride)

        self.cb_block = self.block(self.inplanes, self.inplanes // 4, stride=1)
        self.rb_block = self.block(self.inplanes, self.inplanes // 4, stride=1)

        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))

    def load_model(self, pretrain):
        """Load backbone weights from a checkpoint file at path `pretrain`,
        stripping 'module.'/'backbone.' prefixes and skipping any key that
        looks like a classifier ('fc'/'classifier')."""
        print("Loading Backbone pretrain model from {}......".format(pretrain))
        model_dict = self.state_dict()
        pretrain_dict = torch.load(pretrain)
        # checkpoints may wrap weights in a 'state_dict' entry
        pretrain_dict = pretrain_dict["state_dict"] if "state_dict" in pretrain_dict else pretrain_dict
        from collections import OrderedDict

        new_dict = OrderedDict()
        for k, v in pretrain_dict.items():
            if k.startswith("module"):
                # strip the DataParallel 'module.' prefix
                k = k[7:]
            if "fc" not in k and "classifier" not in k:
                k = k.replace("backbone.", "")
                new_dict[k] = v

        model_dict.update(new_dict)
        self.load_state_dict(model_dict)
        print("Backbone model has been loaded......")

    def _make_layer(self, num_block, planes, stride=1):
        """Build one residual stage of `num_block` blocks; only the first
        block carries `stride`. Side effect: advances self.inplanes."""
        strides = [stride] + [1] * (num_block - 1)
        layers = []
        for now_stride in strides:
            layers.append(self.block(self.inplanes, planes, stride=now_stride))
            self.inplanes = planes * self.block.expansion
        return nn.Sequential(*layers)

    def forward(self, x, **kwargs):
        """Run the backbone and return flattened (batch, features).

        kwargs routing: 'feature_cb' -> conventional branch only,
        'feature_rb' -> re-balancing branch only, otherwise both branch
        outputs are concatenated along the channel dimension.
        """
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        out = self.pool(out)

        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)

        if "feature_cb" in kwargs:
            out = self.cb_block(out)
        elif "feature_rb" in kwargs:
            out = self.rb_block(out)
        else:
            out1 = self.cb_block(out)
            out2 = self.rb_block(out)
            out = torch.cat((out1, out2), dim=1)

        out = self.avgpool(out)
        out = out.view(x.shape[0], -1)
        return out
137 |
def create_model(m_type='bbn_resnet50'):
    """Factory for the BBN backbone (ResNet-50 layout: 3/4/6/3 blocks)."""
    # create various resnet models
    return BBN_ResNet(BottleNeck, [3, 4, 6, 3], last_layer_stride=2)
--------------------------------------------------------------------------------
/models/ResNet_RIDE.py:
--------------------------------------------------------------------------------
1 | ######################################
2 | # Kaihua Tang
3 | ######################################
4 |
5 | import math
6 | from typing import ForwardRef
7 | import torch
8 | import torch.nn as nn
9 | import torch.nn.functional as F
10 | import torchvision.models as models
11 |
def conv3x3(in_planes, out_planes, stride=1):
    """3x3 convolution with padding=1 and no bias term."""
    return nn.Conv2d(
        in_planes,
        out_planes,
        kernel_size=3,
        stride=stride,
        padding=1,
        bias=False,
    )
16 |
17 |
class BasicBlock(nn.Module):
    """Standard two-convolution residual block (ResNet v1 style)."""

    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(BasicBlock, self).__init__()
        # Layer construction order matches checkpoint key ordering.
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = nn.BatchNorm2d(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        """conv-bn-relu, conv-bn, add shortcut, final relu."""
        # The shortcut is independent of the main path, so compute it first.
        shortcut = x if self.downsample is None else self.downsample(x)

        y = self.relu(self.bn1(self.conv1(x)))
        y = self.bn2(self.conv2(y))

        y += shortcut
        return self.relu(y)
48 |
class Bottleneck(nn.Module):
    """ResNeXt-style bottleneck residual block (1x1 -> grouped 3x3 -> 1x1).

    Output channels are ``planes * expansion``. The previously hard-coded
    factor ``4`` in conv3/bn3 now references ``self.expansion`` for
    consistency with the class attribute (same value, clearer intent).
    """

    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None,
                 groups=1, base_width=64, is_last=False):
        # inplanes: input channels; planes: base width of this block.
        # groups / base_width implement the ResNeXt cardinality trick.
        super(Bottleneck, self).__init__()
        # Width of the middle grouped 3x3 conv, scaled by base_width/groups.
        width = int(planes * (base_width / 64.)) * groups
        self.conv1 = nn.Conv2d(inplanes, width, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(width)
        self.conv2 = nn.Conv2d(width, width, kernel_size=3, stride=stride,
                               groups=groups, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(width)
        self.conv3 = nn.Conv2d(width, planes * self.expansion, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(planes * self.expansion)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride
        # Kept for API compatibility with _make_layer; unused in forward().
        self.is_last = is_last

    def forward(self, x):
        """Run the bottleneck path, add the (optionally projected) identity."""
        residual = x

        out = self.relu(self.bn1(self.conv1(x)))
        out = self.relu(self.bn2(self.conv2(out)))
        out = self.bn3(self.conv3(out))

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)

        return out
89 |
class ResNext(nn.Module):
    """Multi-expert ResNeXt trunk used by RIDE.

    Stages 1-2 (layer1/layer2) are shared across experts; stages 3-4 are
    duplicated per expert in nn.ModuleList containers (layer3s/layer4s).
    """

    def __init__(self, block, layers, groups=1, width_per_group=64, num_experts=1, reduce_dimension=False):
        # block: residual block class (e.g. Bottleneck); layers: blocks per stage.
        # groups / width_per_group: ResNeXt cardinality settings.
        # reduce_dimension shrinks the expert stages (256->192, 512->384) so
        # the multi-expert model stays comparable in parameter count.
        self.inplanes = 64
        self.num_experts = num_experts
        super(ResNext, self).__init__()

        self.groups = groups
        self.base_width = width_per_group

        # Standard ResNet stem: 7x7 stride-2 conv + BN + ReLU + 3x3 max pool.
        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3,
                               bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.inplanes = self.next_inplanes
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        self.inplanes = self.next_inplanes

        if reduce_dimension:
            layer3_output_dim = 192
        else:
            layer3_output_dim = 256

        if reduce_dimension:
            layer4_output_dim = 384
        else:
            layer4_output_dim = 512

        # _make_layer records its output width in self.next_inplanes instead
        # of mutating self.inplanes, so every expert's layer3 is built from
        # the same input width; likewise for layer4 below.
        self.layer3s = nn.ModuleList([self._make_layer(block, layer3_output_dim, layers[2], stride=2) for _ in range(num_experts)])
        self.inplanes = self.next_inplanes
        self.layer4s = nn.ModuleList([self._make_layer(block, layer4_output_dim, layers[3], stride=2) for _ in range(num_experts)])
        self.inplanes = self.next_inplanes
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))


        # He-style initialization for convs; BatchNorm starts as identity.
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()


    def _make_layer(self, block, planes, blocks, stride=1, is_last=False):
        # Build one residual stage of `blocks` blocks. The first block may
        # downsample and/or expand channels via a 1x1 projection shortcut.
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, planes * block.expansion,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(planes * block.expansion),
            )

        layers = []
        layers.append(block(self.inplanes, planes, stride, downsample,
                            groups=self.groups, base_width=self.base_width))
        # Record the stage's output width; the caller commits it to
        # self.inplanes when appropriate (enables per-expert stage reuse).
        self.next_inplanes = planes * block.expansion
        for i in range(1, blocks):
            layers.append(block(self.next_inplanes, planes,
                                groups=self.groups, base_width=self.base_width,
                                is_last=(is_last and i == blocks-1)))

        return nn.Sequential(*layers)

    def _separate_part(self, x, ind):
        # Run expert `ind`'s private layer3/layer4, then global average pool
        # and flatten to (N, D).
        x = (self.layer3s[ind])(x)
        x = (self.layer4s[ind])(x)
        x = self.avgpool(x)
        x = x.view(x.size(0), -1)
        return x

    def forward(self, x, index=None):
        """Shared stem + stages 1-2, then expert-specific stages.

        If index is None, all experts run and features are stacked to
        (N, num_experts, D); otherwise only expert `index` runs -> (N, D).
        """
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)

        if index is None:
            feats = []
            for ind in range(self.num_experts):
                feats.append(self._separate_part(x, ind))
            return torch.stack(feats, dim=1)
        else:
            return self._separate_part(x, index)
178 |
class ResNeXt50Model(nn.Module):
    """Thin wrapper exposing a ResNeXt-50 (32x4d) trunk with optional experts."""

    def __init__(self, reduce_dimension=False, num_experts=1):
        super(ResNeXt50Model, self).__init__()
        self.backbone = ResNext(
            Bottleneck,
            [3, 4, 6, 3],
            groups=32,
            width_per_group=4,
            reduce_dimension=reduce_dimension,
            num_experts=num_experts,
        )

    def forward(self, x, index=None):
        """Delegate to the backbone; `index` selects a single expert branch."""
        return self.backbone(x, index)
185 |
186 |
def create_model(m_type='resnext50', num_experts=3, reduce_dimension=True):
    """Factory for RIDE backbones.

    Args:
        m_type: backbone identifier; currently only 'resnext50' is supported.
        num_experts: number of expert branches (duplicated layer3/layer4).
        reduce_dimension: if True, shrink expert stage widths (RIDE default).
    Returns:
        The constructed ResNeXt50Model.
    Raises:
        ValueError: for an unsupported ``m_type``. (Previously an unknown
            type fell through to ``return model`` with ``model`` unbound,
            raising an opaque NameError.)
    """
    if m_type == 'resnext50':
        return ResNeXt50Model(reduce_dimension=reduce_dimension, num_experts=num_experts)
    raise ValueError('Unsupported model type: {}'.format(m_type))
--------------------------------------------------------------------------------
/models/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/KaihuaTang/Generalized-Long-Tailed-Benchmarks.pytorch/6317d8feb0ba107e1a64822567ed59115d51c581/models/__init__.py
--------------------------------------------------------------------------------
/train_baseline.py:
--------------------------------------------------------------------------------
1 | ######################################
2 | # Kaihua Tang
3 | ######################################
4 |
5 | import torch
6 | import torch.nn as nn
7 | import torch.optim as optim
8 | import torch.optim.lr_scheduler as lr_scheduler
9 | import torch.nn.functional as F
10 |
11 | import utils.general_utils as utils
12 | from data.dataloader import get_loader
13 | from utils.checkpoint_utils import Checkpoint
14 | from utils.training_utils import *
15 | from utils.test_loader import test_loader
16 |
class train_baseline():
    """Plain supervised training loop: backbone + classifier, single loss.

    Model/classifier classes are loaded dynamically from paths given in the
    config; training runs for config['training_opt']['num_epochs'] epochs
    with per-epoch validation (optional) and checkpointing.
    """

    def __init__(self, args, config, logger, eval=False):
        # args is accepted for interface parity with the other trainers but
        # is not used here.
        # ============================================================================
        # create model
        logger.info('=====> Model construction from: ' + str(config['networks']['type']))
        model_type = config['networks']['type']
        model_file = config['networks'][model_type]['def_file']
        model_args = config['networks'][model_type]['params']
        logger.info('=====> Classifier construction from: ' + str(config['classifiers']['type']))
        classifier_type = config['classifiers']['type']
        classifier_file = config['classifiers'][classifier_type]['def_file']
        classifier_args = config['classifiers'][classifier_type]['params']
        # source_import loads the module from its file path; each module
        # exposes a create_model(**params) factory.
        model = utils.source_import(model_file).create_model(**model_args)
        classifier = utils.source_import(classifier_file).create_model(**classifier_args)

        model = nn.DataParallel(model).cuda()
        classifier = nn.DataParallel(classifier).cuda()

        # other initialization
        self.config = config
        self.logger = logger
        self.model = model
        self.classifier = classifier
        self.optimizer = create_optimizer(model, classifier, logger, config)
        self.scheduler = create_scheduler(self.optimizer, logger, config)
        self.eval = eval
        self.training_opt = config['training_opt']

        self.checkpoint = Checkpoint(config)

        # get dataloader
        self.logger.info('=====> Get train dataloader')
        self.train_loader = get_loader(config, 'train', config['dataset']['testset'], logger)

        # get loss
        self.loss_fc = create_loss(logger, config, self.train_loader)

        # set eval: build the validation runner only when requested
        if self.eval:
            test_func = test_loader(config)
            self.testing = test_func(config, logger, model, classifier, val=True)


    def run(self):
        """Run the full training schedule; saves a checkpoint every epoch."""
        # Start Training
        self.logger.info('=====> Start Baseline Training')

        # run epoch
        for epoch in range(self.training_opt['num_epochs']):
            self.logger.info('------------ Start Epoch {} -----------'.format(epoch))

            # preprocess for each epoch
            total_batch = len(self.train_loader)

            for step, (inputs, labels, _, _) in enumerate(self.train_loader):
                self.optimizer.zero_grad()

                # additional inputs (none for the baseline; kept so the
                # classifier interface matches the other trainers)
                inputs, labels = inputs.cuda(), labels.cuda()
                add_inputs = {}

                features = self.model(inputs)
                predictions = self.classifier(features, add_inputs)

                # calculate loss
                loss = self.loss_fc(predictions, labels)
                iter_info_print = {self.training_opt['loss'] : loss.sum().item(),}

                loss.backward()
                self.optimizer.step()

                # calculate accuracy (top-1 over the batch)
                accuracy = (predictions.max(1)[1] == labels).sum().float() / predictions.shape[0]

                # log information
                iter_info_print.update({'Accuracy' : accuracy.item(), 'Loss' : loss.sum().item(), 'Poke LR' : float(self.optimizer.param_groups[0]['lr'])})
                self.logger.info_iter(epoch, step, total_batch, iter_info_print, self.config['logger_opt']['print_iter'])

                first_batch = (epoch == 0) and (step == 0)
                # NOTE(review): `and` binds tighter than `or`, so this reads
                # first_batch or (print_grad and step % 1000 == 0) — confirm
                # that grouping is the intended one.
                if first_batch or self.config['logger_opt']['print_grad'] and step % 1000 == 0:
                    utils.print_grad(self.classifier.named_parameters())
                    utils.print_grad(self.model.named_parameters())

            # evaluation on validation set
            if self.eval:
                val_acc = self.testing.run_val(epoch)
            else:
                val_acc = 0.0

            # checkpoint (keeps the best model by val_acc)
            self.checkpoint.save(self.model, self.classifier, epoch, self.logger, acc=val_acc)

            # update scheduler once per epoch
            self.scheduler.step()

        # save best model path
        self.checkpoint.save_best_model_path(self.logger)
--------------------------------------------------------------------------------
/train_bbn.py:
--------------------------------------------------------------------------------
1 | ######################################
2 | # Kaihua Tang
3 | ######################################
4 |
5 | import torch
6 | import torch.nn as nn
7 | import torch.optim as optim
8 | import torch.optim.lr_scheduler as lr_scheduler
9 | import torch.nn.functional as F
10 |
11 | import utils.general_utils as utils
12 | from data.dataloader import get_loader
13 | from utils.checkpoint_utils import Checkpoint
14 | from utils.training_utils import *
15 | from utils.test_loader import test_loader
16 |
class train_bbn():
    """BBN (Bilateral-Branch Network) training loop.

    Uses two dataloaders over the same dataset: env1 with uniform instance
    sampling (conventional branch) and env2 with a reversed, class-balanced
    distribution (re-balancing branch). Features from the two branches are
    mixed with a weight l that decays over epochs.
    """

    def __init__(self, args, config, logger, eval=False):
        # args is accepted for interface parity with the other trainers but
        # is not used here.
        # ============================================================================
        # create model
        logger.info('=====> Model construction from: ' + str(config['networks']['type']))
        model_type = config['networks']['type']
        model_file = config['networks'][model_type]['def_file']
        model_args = config['networks'][model_type]['params']
        logger.info('=====> Classifier construction from: ' + str(config['classifiers']['type']))
        classifier_type = config['classifiers']['type']
        classifier_file = config['classifiers'][classifier_type]['def_file']
        classifier_args = config['classifiers'][classifier_type]['params']
        # source_import loads the module from its file path; each module
        # exposes a create_model(**params) factory.
        model = utils.source_import(model_file).create_model(**model_args)
        classifier = utils.source_import(classifier_file).create_model(**classifier_args)

        model = nn.DataParallel(model).cuda()
        classifier = nn.DataParallel(classifier).cuda()

        # other initialization
        self.config = config
        self.logger = logger
        self.model = model
        self.classifier = classifier
        self.optimizer = create_optimizer(model, classifier, logger, config)
        self.scheduler = create_scheduler(self.optimizer, logger, config)
        self.eval = eval
        self.training_opt = config['training_opt']

        self.checkpoint = Checkpoint(config)

        # get dataloader (for BBN this is a pair of loaders; see run())
        self.logger.info('=====> Get train dataloader')
        self.train_loader = get_loader(config, 'train', config['dataset']['testset'], logger)

        # get loss
        self.loss_fc = create_loss(logger, config, self.train_loader)

        # set eval: build the validation runner only when requested
        if self.eval:
            test_func = test_loader(config)
            self.testing = test_func(config, logger, model, classifier, val=True)


    def calculate_reverse_instance_weight(self, dataloader):
        """Per-instance sampling weights for the reversed distribution.

        Rarer classes get larger weights: each instance of class c receives
        (max_freq / n_c) / n_c, i.e. weight proportional to 1 / n_c^2.
        NOTE(review): indexing label_freq_array by raw label assumes labels
        are contiguous integers 0..C-1 — confirm against the dataset.
        """
        # counting frequency
        label_freq = {}
        for key in dataloader.dataset.labels:
            label_freq[key] = label_freq.get(key, 0) + 1
        label_freq = dict(sorted(label_freq.items()))
        label_freq_array = torch.FloatTensor(list(label_freq.values()))
        reverse_class_weight = label_freq_array.max() / label_freq_array
        # generate reverse weight
        reverse_instance_weight = torch.zeros(len(dataloader.dataset)).fill_(1.0)
        for i, label in enumerate(dataloader.dataset.labels):
            reverse_instance_weight[i] = reverse_class_weight[label] / (label_freq_array[label] + 1e-9)
        return reverse_instance_weight


    def run(self):
        """Run BBN training: dual-branch feature mixing with decaying weight."""
        # Start Training
        self.logger.info('=====> Start BBN Training')

        # preprocess for each epoch
        env1_loader, env2_loader = self.train_loader
        assert len(env1_loader) == len(env2_loader)
        total_batch = len(env1_loader)
        total_image = len(env1_loader.dataset)

        # set dataloader distribution: env1 samples uniformly, env2 samples
        # from the reversed (rare-class-heavy) distribution
        instance_normal_weight = torch.zeros(total_image).fill_(1.0)
        env1_loader.sampler.set_parameter(instance_normal_weight) # conventional distribution
        instance_reverse_weight = self.calculate_reverse_instance_weight(env1_loader)
        env2_loader.sampler.set_parameter(instance_reverse_weight) # reverse distribution

        # run epoch
        num_epoch = self.training_opt['num_epochs']
        for epoch in range(num_epoch):
            self.logger.info('------------ Start Epoch {} -----------'.format(epoch))

            for step, ((inputs1, labels1, _, indexs1), (inputs2, labels2, _, indexs2)) in enumerate(zip(env1_loader, env2_loader)):
                iter_info_print = {}

                self.optimizer.zero_grad()

                # additional inputs
                inputs1, inputs2 = inputs1.cuda(), inputs2.cuda()
                labels1, labels2 = labels1.cuda(), labels2.cuda()

                # each branch of the backbone sees its own distribution
                feature1 = self.model(inputs1, feature_cb=True)
                feature2 = self.model(inputs2, feature_rb=True)

                # mixing weight: starts near 1 (conventional branch dominant)
                # and decays toward 0 across epochs.
                # NOTE(review): (epoch - 1) makes epoch 0 give l slightly
                # below 1 — confirm this matches the intended BBN schedule.
                l = 1 - ((epoch - 1) / num_epoch) ** 2 # parabolic decay

                # factor 2 keeps the concatenated feature magnitude comparable
                # to a single unweighted branch pair
                mixed_feature = 2 * torch.cat((l * feature1, (1-l) * feature2), dim=1)

                predictions = self.classifier(mixed_feature)

                # calculate loss: mixup-style convex combination of the two
                # branch losses with the same weight l
                loss = l * self.loss_fc(predictions, labels1) + (1 - l) * self.loss_fc(predictions, labels2)
                iter_info_print = {'BBN mixup loss': loss.sum().item(),}

                loss.backward()
                self.optimizer.step()

                # calculate accuracy (weighted across the two label sets)
                accuracy = l * (predictions.max(1)[1] == labels1).float() + (1 - l) * (predictions.max(1)[1] == labels2).float()
                accuracy = accuracy.sum() / accuracy.shape[0]

                # log information
                iter_info_print.update({'Accuracy' : accuracy.item(), 'Loss' : loss.sum().item(), 'Poke LR' : float(self.optimizer.param_groups[0]['lr'])})
                self.logger.info_iter(epoch, step, total_batch, iter_info_print, self.config['logger_opt']['print_iter'])

                first_batch = (epoch == 0) and (step == 0)
                # NOTE(review): `and` binds tighter than `or`, so this reads
                # first_batch or (print_grad and step % 1000 == 0).
                if first_batch or self.config['logger_opt']['print_grad'] and step % 1000 == 0:
                    utils.print_grad(self.classifier.named_parameters())
                    utils.print_grad(self.model.named_parameters())

            # evaluation on validation set
            if self.eval:
                val_acc = self.testing.run_val(epoch)
            else:
                val_acc = 0.0

            # checkpoint (keeps the best model by val_acc)
            self.checkpoint.save(self.model, self.classifier, epoch, self.logger, acc=val_acc)

            # update scheduler once per epoch
            self.scheduler.step()

        # save best model path
        self.checkpoint.save_best_model_path(self.logger)
--------------------------------------------------------------------------------
/train_la.py:
--------------------------------------------------------------------------------
1 | ######################################
2 | # Kaihua Tang
3 | ######################################
4 |
5 | import torch
6 | import torch.nn as nn
7 | import torch.optim as optim
8 | import torch.optim.lr_scheduler as lr_scheduler
9 | import torch.nn.functional as F
10 |
11 | import utils.general_utils as utils
12 | from data.dataloader import get_loader
13 | from utils.checkpoint_utils import Checkpoint
14 | from utils.training_utils import *
15 | from utils.test_loader import test_loader
16 |
class train_la():
    """Training loop with Logit Adjustment.

    Identical to the baseline trainer except that a per-class logit offset
    (computed from training label frequencies by utils.compute_adjustment,
    scaled by algorithm_opt['tro']) is passed to the classifier each step.
    """

    def __init__(self, args, config, logger, eval=False):
        # args is accepted for interface parity with the other trainers but
        # is not used here.
        # ============================================================================
        # create model
        logger.info('=====> Model construction from: ' + str(config['networks']['type']))
        model_type = config['networks']['type']
        model_file = config['networks'][model_type]['def_file']
        model_args = config['networks'][model_type]['params']
        logger.info('=====> Classifier construction from: ' + str(config['classifiers']['type']))
        classifier_type = config['classifiers']['type']
        classifier_file = config['classifiers'][classifier_type]['def_file']
        classifier_args = config['classifiers'][classifier_type]['params']
        # source_import loads the module from its file path; each module
        # exposes a create_model(**params) factory.
        model = utils.source_import(model_file).create_model(**model_args)
        classifier = utils.source_import(classifier_file).create_model(**classifier_args)

        model = nn.DataParallel(model).cuda()
        classifier = nn.DataParallel(classifier).cuda()

        # other initialization
        self.algorithm_opt = config['algorithm_opt']
        self.config = config
        self.logger = logger
        self.model = model
        self.classifier = classifier
        self.optimizer = create_optimizer(model, classifier, logger, config)
        self.scheduler = create_scheduler(self.optimizer, logger, config)
        self.eval = eval
        self.training_opt = config['training_opt']

        self.checkpoint = Checkpoint(config)

        # get dataloader
        self.logger.info('=====> Get train dataloader')
        self.train_loader = get_loader(config, 'train', config['dataset']['testset'], logger)

        # get loss
        self.loss_fc = create_loss(logger, config, self.train_loader)

        # set eval: build the validation runner only when requested
        if self.eval:
            test_func = test_loader(config)
            self.testing = test_func(config, logger, model, classifier, val=True)


    def run(self):
        """Run training with per-class logit adjustment fed to the classifier."""
        # Start Training
        self.logger.info('=====> Start Baseline Training')

        # logit adjustment: fixed per-class offsets derived from training
        # label frequencies; 'tro' is the temperature/scale factor
        logit_adj = utils.compute_adjustment(self.train_loader, self.algorithm_opt['tro'])
        logit_adj.requires_grad = False

        # run epoch
        for epoch in range(self.training_opt['num_epochs']):
            self.logger.info('------------ Start Epoch {} -----------'.format(epoch))

            # preprocess for each epoch
            total_batch = len(self.train_loader)

            for step, (inputs, labels, _, _) in enumerate(self.train_loader):
                self.optimizer.zero_grad()

                # additional inputs: broadcast the (C,) adjustment to (B, C)
                inputs, labels = inputs.cuda(), labels.cuda()
                add_inputs = {}
                batch_size = inputs.shape[0]
                add_inputs['logit_adj'] = logit_adj.to(inputs.device).view(1, -1).repeat(batch_size, 1)

                features = self.model(inputs)
                predictions = self.classifier(features, add_inputs)

                # calculate loss
                loss = self.loss_fc(predictions, labels)
                iter_info_print = {self.training_opt['loss'] : loss.sum().item(),}

                loss.backward()
                self.optimizer.step()

                # calculate accuracy (top-1 over the batch)
                accuracy = (predictions.max(1)[1] == labels).sum().float() / predictions.shape[0]

                # log information
                iter_info_print.update({'Accuracy' : accuracy.item(), 'Loss' : loss.sum().item(), 'Poke LR' : float(self.optimizer.param_groups[0]['lr'])})
                self.logger.info_iter(epoch, step, total_batch, iter_info_print, self.config['logger_opt']['print_iter'])

                first_batch = (epoch == 0) and (step == 0)
                # NOTE(review): `and` binds tighter than `or`, so this reads
                # first_batch or (print_grad and step % 1000 == 0).
                if first_batch or self.config['logger_opt']['print_grad'] and step % 1000 == 0:
                    utils.print_grad(self.classifier.named_parameters())
                    utils.print_grad(self.model.named_parameters())

            # evaluation on validation set
            if self.eval:
                val_acc = self.testing.run_val(epoch)
            else:
                val_acc = 0.0

            # checkpoint (keeps the best model by val_acc)
            self.checkpoint.save(self.model, self.classifier, epoch, self.logger, acc=val_acc)

            # update scheduler once per epoch
            self.scheduler.step()

        # save best model path
        self.checkpoint.save_best_model_path(self.logger)
--------------------------------------------------------------------------------
/train_ldam.py:
--------------------------------------------------------------------------------
1 | ######################################
2 | # Kaihua Tang
3 | ######################################
4 |
5 | import torch
6 | import torch.nn as nn
7 | import torch.optim as optim
8 | import torch.optim.lr_scheduler as lr_scheduler
9 | import torch.nn.functional as F
10 |
11 | import utils.general_utils as utils
12 | from data.dataloader import get_loader
13 | from utils.checkpoint_utils import Checkpoint
14 | from utils.training_utils import *
15 | from utils.test_loader import test_loader
16 |
class train_ldam():
    """Training loop for LDAM (Label-Distribution-Aware Margin) loss.

    Identical to the baseline trainer except the loss object exposes
    set_weight(epoch), which updates the per-class re-weighting schedule at
    the start of every epoch (e.g. deferred re-weighting — see create_loss).
    """

    def __init__(self, args, config, logger, eval=False):
        # args is accepted for interface parity with the other trainers but
        # is not used here.
        # ============================================================================
        # create model
        logger.info('=====> Model construction from: ' + str(config['networks']['type']))
        model_type = config['networks']['type']
        model_file = config['networks'][model_type]['def_file']
        model_args = config['networks'][model_type]['params']
        logger.info('=====> Classifier construction from: ' + str(config['classifiers']['type']))
        classifier_type = config['classifiers']['type']
        classifier_file = config['classifiers'][classifier_type]['def_file']
        classifier_args = config['classifiers'][classifier_type]['params']
        # source_import loads the module from its file path; each module
        # exposes a create_model(**params) factory.
        model = utils.source_import(model_file).create_model(**model_args)
        classifier = utils.source_import(classifier_file).create_model(**classifier_args)

        model = nn.DataParallel(model).cuda()
        classifier = nn.DataParallel(classifier).cuda()

        # other initialization
        self.config = config
        self.logger = logger
        self.model = model
        self.classifier = classifier
        self.optimizer = create_optimizer(model, classifier, logger, config)
        self.scheduler = create_scheduler(self.optimizer, logger, config)
        self.eval = eval
        self.training_opt = config['training_opt']

        self.checkpoint = Checkpoint(config)

        # get dataloader
        self.logger.info('=====> Get train dataloader')
        self.train_loader = get_loader(config, 'train', config['dataset']['testset'], logger)

        # get loss (LDAM loss with an epoch-dependent weighting schedule)
        self.loss_ldam = create_loss(logger, config, self.train_loader)

        # set eval: build the validation runner only when requested
        if self.eval:
            test_func = test_loader(config)
            self.testing = test_func(config, logger, model, classifier, val=True)


    def run(self):
        """Run LDAM training; refreshes the loss weighting every epoch."""
        # Start Training
        self.logger.info('=====> Start LDAM Training')

        # run epoch
        for epoch in range(self.training_opt['num_epochs']):
            self.logger.info('------------ Start Epoch {} -----------'.format(epoch))

            # preprocess for each epoch
            total_batch = len(self.train_loader)

            # set LDAM weight for this epoch (deferred re-weighting schedule)
            self.loss_ldam.set_weight(epoch)

            for step, (inputs, labels, _, _) in enumerate(self.train_loader):
                self.optimizer.zero_grad()

                # additional inputs (none for LDAM; kept so the classifier
                # interface matches the other trainers)
                inputs, labels = inputs.cuda(), labels.cuda()
                add_inputs = {}

                features = self.model(inputs)
                predictions = self.classifier(features, add_inputs)

                # calculate loss
                loss = self.loss_ldam(predictions, labels)
                iter_info_print = {self.training_opt['loss'] : loss.sum().item(),}

                loss.backward()
                self.optimizer.step()

                # calculate accuracy (top-1 over the batch)
                accuracy = (predictions.max(1)[1] == labels).sum().float() / predictions.shape[0]

                # log information
                iter_info_print.update({'Accuracy' : accuracy.item(), 'Loss' : loss.sum().item(), 'Poke LR' : float(self.optimizer.param_groups[0]['lr'])})
                self.logger.info_iter(epoch, step, total_batch, iter_info_print, self.config['logger_opt']['print_iter'])

                first_batch = (epoch == 0) and (step == 0)
                # NOTE(review): `and` binds tighter than `or`, so this reads
                # first_batch or (print_grad and step % 1000 == 0).
                if first_batch or self.config['logger_opt']['print_grad'] and step % 1000 == 0:
                    utils.print_grad(self.classifier.named_parameters())
                    utils.print_grad(self.model.named_parameters())

            # evaluation on validation set
            if self.eval:
                val_acc = self.testing.run_val(epoch)
            else:
                val_acc = 0.0

            # checkpoint (keeps the best model by val_acc)
            self.checkpoint.save(self.model, self.classifier, epoch, self.logger, acc=val_acc)

            # update scheduler once per epoch
            self.scheduler.step()

        # save best model path
        self.checkpoint.save_best_model_path(self.logger)
--------------------------------------------------------------------------------
/train_lff.py:
--------------------------------------------------------------------------------
1 | ######################################
2 | # Kaihua Tang
3 | ######################################
4 |
5 | import torch
6 | import torch.nn as nn
7 | import torch.optim as optim
8 | import torch.optim.lr_scheduler as lr_scheduler
9 | import torch.nn.functional as F
10 |
11 | import utils.general_utils as utils
12 | from data.dataloader import get_loader
13 | from utils.checkpoint_utils import Checkpoint
14 | from utils.training_utils import *
15 | from utils.test_loader import test_loader
16 |
17 | # hard example mining
18 |
class GeneralizedCELoss(nn.Module):
    """Generalized Cross Entropy (GCE) loss used to train the biased model in LfF.

    Each sample's CE gradient is scaled by p_target**q * q, which emphasizes
    samples the model already finds easy (bias-aligned samples).
    """

    def __init__(self, q=0.7):
        super(GeneralizedCELoss, self).__init__()
        # GCE exponent: q -> 0 approaches plain CE, q = 1 approaches MAE.
        self.q = q

    def forward(self, logits, targets, requires_weight=False, weight_base=0):
        """Compute the per-sample GCE loss.

        Args:
            logits: (N, C) unnormalized class scores.
            targets: (N,) integer class indices.
            requires_weight: if True, also return the per-sample weights.
            weight_base: constant added to every per-sample loss.
        Returns:
            (N,) loss tensor, or (loss, loss_weight) when requires_weight.
        Raises:
            NameError: NaN sentinel kept from the original code so existing
                callers that catch NameError keep working.
        """
        p = F.softmax(logits, dim=1)
        # Bug fix: the original used np.isnan, but numpy is never imported
        # explicitly in this file (it could only reach scope via a wildcard
        # import); torch.isnan is unambiguous and has no extra dependency.
        if torch.isnan(p.mean()).item():
            raise NameError('GCE_p')
        Yg = torch.gather(p, 1, torch.unsqueeze(targets, 1))
        # modify gradient of cross entropy: detached weight p_t**q * q turns
        # the CE gradient into the GCE gradient
        loss_weight = (Yg.squeeze().detach() ** self.q) * self.q
        if torch.isnan(Yg.mean()).item():
            raise NameError('GCE_Yg')

        loss = F.cross_entropy(logits, targets, reduction='none') * loss_weight + weight_base
        if requires_weight:
            return loss, loss_weight
        return loss
38 |
39 | class train_lff():
40 | def __init__(self, args, config, logger, eval=False):
41 | # ============================================================================
42 | # create model
43 | logger.info('=====> Model construction from: ' + str(config['networks']['type']))
44 | model_type = config['networks']['type']
45 | model_file = config['networks'][model_type]['def_file']
46 | model_args = config['networks'][model_type]['params']
47 | logger.info('=====> Classifier construction from: ' + str(config['classifiers']['type']))
48 | classifier_type = config['classifiers']['type']
49 | classifier_file = config['classifiers'][classifier_type]['def_file']
50 | classifier_args = config['classifiers'][classifier_type]['params']
51 | model_b = utils.source_import(model_file).create_model(**model_args)
52 | model_d = utils.source_import(model_file).create_model(**model_args)
53 | classifier_b = utils.source_import(classifier_file).create_model(**classifier_args)
54 | classifier_d = utils.source_import(classifier_file).create_model(**classifier_args)
55 |
56 | model_b = nn.DataParallel(model_b).cuda()
57 | model_d = nn.DataParallel(model_d).cuda()
58 | classifier_b = nn.DataParallel(classifier_b).cuda()
59 | classifier_d = nn.DataParallel(classifier_d).cuda()
60 |
61 | # other initialization
62 | self.algorithm_opt = config['algorithm_opt']
63 | self.config = config
64 | self.logger = logger
65 | self.model_b = model_b
66 | self.model_d = model_d
67 | self.classifier_b = classifier_b
68 | self.classifier_d = classifier_d
69 | self.optimizer_b = create_optimizer(model_b, classifier_b, logger, config)
70 | self.optimizer_d = create_optimizer(model_d, classifier_d, logger, config)
71 | self.scheduler_b = create_scheduler(self.optimizer_b, logger, config)
72 | self.scheduler_d = create_scheduler(self.optimizer_d, logger, config)
73 | self.eval = eval
74 | self.training_opt = config['training_opt']
75 |
76 | self.checkpoint = Checkpoint(config)
77 |
78 | # get dataloader
79 | self.logger.info('=====> Get train dataloader')
80 | self.train_loader = get_loader(config, 'train', config['dataset']['testset'], logger)
81 |
82 | # get loss
83 | self.loss_fc = nn.CrossEntropyLoss(reduction='none')
84 | # biased loss
85 | self.loss_bias = GeneralizedCELoss()
86 |
87 |
88 | # set eval
89 | if self.eval:
90 | test_func = test_loader(config)
91 | self.testing = test_func(config, logger, model_d, classifier_d, val=True)
92 |
93 |
    def run(self):
        # Dual-branch debiased training loop: the "biased" branch
        # (model_b / classifier_b) is fitted with GeneralizedCELoss so it
        # latches onto easy/biased samples, while the target branch
        # (model_d / classifier_d) is trained with per-sample CE re-weighted
        # by the relative difficulty loss_b / (loss_b + loss_d), so samples
        # the biased branch finds hard dominate the target branch's loss.
        # NOTE(review): this mirrors the Learning-from-Failure (LfF)
        # debiasing scheme -- confirm against the intended algorithm.
        self.logger.info('=====> Start Baseline Training')

        # logit adjustment: fixed per-class logit offsets derived from the
        # training label distribution, scaled by algorithm_opt['tro'];
        # explicitly excluded from the autograd graph.
        logit_adj = utils.compute_adjustment(self.train_loader, self.algorithm_opt['tro'])
        logit_adj.requires_grad = False

        # run epoch
        for epoch in range(self.training_opt['num_epochs']):
            self.logger.info('------------ Start Epoch {} -----------'.format(epoch))

            # preprocess for each epoch
            total_batch = len(self.train_loader)

            for step, (inputs, labels, _, _) in enumerate(self.train_loader):
                self.optimizer_b.zero_grad()
                self.optimizer_d.zero_grad()

                # additional inputs: broadcast the (num_classes,) adjustment
                # vector to one row per sample for the classifier heads
                inputs, labels = inputs.cuda(), labels.cuda()
                add_inputs = {}
                batch_size = inputs.shape[0]
                add_inputs['logit_adj'] = logit_adj.to(inputs.device).view(1, -1).repeat(batch_size, 1)

                # biased prediction
                predictions_b = self.classifier_b(self.model_b(inputs), add_inputs)
                # targeted prediction
                predictions_d = self.classifier_d(self.model_d(inputs), add_inputs)

                # calculate hard example mining weight: per-sample CE losses,
                # detached so the weights act as constants w.r.t. autograd
                loss_b = self.loss_fc(predictions_b, labels).detach()
                loss_d = self.loss_fc(predictions_d, labels).detach()

                # relative difficulty in [0, 1); 1e-8 guards divide-by-zero
                loss_weight = loss_b / (loss_b + loss_d + 1e-8)

                # calculate loss
                # biased model: generalized CE emphasises already-easy samples
                loss_b_update = self.loss_bias(predictions_b, labels)
                # target model: CE re-weighted toward samples the biased branch struggles with
                loss_d_update = self.loss_fc(predictions_d, labels) * loss_weight.cuda().detach()
                loss = loss_b_update.mean() + loss_d_update.mean()

                iter_info_print = {'biased loss' : loss_b_update.mean().item(), 'target loss': loss_d_update.mean().item()}

                # single backward pass updates both branches (their parameter
                # sets are disjoint: two separately constructed model pairs)
                loss.backward()
                self.optimizer_b.step()
                self.optimizer_d.step()

                # calculate accuracy (target branch only)
                accuracy = (predictions_d.max(1)[1] == labels).sum().float() / predictions_d.shape[0]

                # log information
                iter_info_print.update({'Accuracy' : accuracy.item(), 'Loss' : loss.sum().item(), 'Poke LR' : float(self.optimizer_d.param_groups[0]['lr'])})
                self.logger.info_iter(epoch, step, total_batch, iter_info_print, self.config['logger_opt']['print_iter'])

            # evaluation on validation set
            if self.eval:
                val_acc = self.testing.run_val(epoch)
            else:
                val_acc = 0.0

            # checkpoint: only the target (debiased) branch is saved
            self.checkpoint.save(self.model_d, self.classifier_d, epoch, self.logger, acc=val_acc)

            # update scheduler
            self.scheduler_b.step()
            self.scheduler_d.step()

        # save best model path
        self.checkpoint.save_best_model_path(self.logger)
--------------------------------------------------------------------------------
/train_mixup.py:
--------------------------------------------------------------------------------
1 | ######################################
2 | # Kaihua Tang
3 | ######################################
4 |
5 | import torch
6 | import torch.nn as nn
7 | import torch.optim as optim
8 | import torch.optim.lr_scheduler as lr_scheduler
9 | import torch.nn.functional as F
10 |
11 | import utils.general_utils as utils
12 | from data.dataloader import get_loader
13 | from utils.checkpoint_utils import Checkpoint
14 | from utils.training_utils import *
15 | from utils.test_loader import test_loader
16 |
class train_mixup():
    """Mixup trainer: each batch is convexly combined with a shuffled copy of
    itself and the loss/accuracy are mixed with the same coefficient
    (Zhang et al., "mixup: Beyond Empirical Risk Minimization")."""

    def __init__(self, args, config, logger, eval=False):
        """Build backbone/classifier from config, plus optimizer, scheduler,
        dataloader, loss and (optionally) the validation harness.

        Args:
            args: command-line namespace (not referenced in this constructor).
            config: parsed experiment config dict.
            logger: project logger.
            eval: when True, also build the validation test harness.
        """
        # ============================================================================
        # create model: backbone and classifier head are instantiated from the
        # definition files / params declared in the config
        logger.info('=====> Model construction from: ' + str(config['networks']['type']))
        model_type = config['networks']['type']
        model_file = config['networks'][model_type]['def_file']
        model_args = config['networks'][model_type]['params']
        logger.info('=====> Classifier construction from: ' + str(config['classifiers']['type']))
        classifier_type = config['classifiers']['type']
        classifier_file = config['classifiers'][classifier_type]['def_file']
        classifier_args = config['classifiers'][classifier_type]['params']
        model = utils.source_import(model_file).create_model(**model_args)
        classifier = utils.source_import(classifier_file).create_model(**classifier_args)

        # multi-GPU wrappers; .cuda() assumes at least one CUDA device
        model = nn.DataParallel(model).cuda()
        classifier = nn.DataParallel(classifier).cuda()

        # other initialization
        self.config = config
        self.logger = logger
        self.model = model
        self.classifier = classifier
        self.optimizer = create_optimizer(model, classifier, logger, config)
        self.scheduler = create_scheduler(self.optimizer, logger, config)
        self.eval = eval
        self.training_opt = config['training_opt']

        self.checkpoint = Checkpoint(config)

        # get dataloader
        self.logger.info('=====> Get train dataloader')
        self.train_loader = get_loader(config, 'train', config['dataset']['testset'], logger)

        # get loss
        self.loss_fc = create_loss(logger, config, self.train_loader)

        # set eval
        if self.eval:
            test_func = test_loader(config)
            self.testing = test_func(config, logger, model, classifier, val=True)

    def mixup_data(self, x, y, alpha=1.0):
        """Mix a batch with a random permutation of itself.

        Args:
            x: input batch tensor (first dim = batch size).
            y: labels aligned with x.
            alpha: Beta(alpha, alpha) concentration; alpha <= 0 disables
                mixing (lam == 1, output equals input).

        Returns:
            (mixed_x, y_a, y_b, lam) where mixed_x = lam*x + (1-lam)*x[perm],
            y_a are the original labels and y_b the permuted labels.
        """
        # Fix: numpy was never imported at the top of this module and was only
        # reachable through the `from utils.training_utils import *` wildcard;
        # import it explicitly so this method cannot fail with NameError.
        import numpy as np
        lam = np.random.beta(alpha, alpha) if alpha > 0 else 1
        batch_size = x.shape[0]
        index = torch.randperm(batch_size).to(x.device)
        mixed_x = lam * x + (1 - lam) * x[index]
        y_a, y_b = y, y[index]
        return mixed_x, y_a, y_b, lam

    def mixup_criterion(self, pred, y_a, y_b, lam):
        """Loss on a mixed batch: the same convex combination (weight lam)
        of the losses against the two label views used to build the batch."""
        return lam * self.loss_fc(pred, y_a) + (1 - lam) * self.loss_fc(pred, y_b)

    def mixup_accuracy(self, pred, y_a, y_b, lam):
        """Soft accuracy under mixup: a prediction counts lam toward y_a and
        (1 - lam) toward y_b, averaged over the batch."""
        correct = lam * (pred.max(1)[1] == y_a) + (1 - lam) * (pred.max(1)[1] == y_b)
        accuracy = correct.sum().float() / pred.shape[0]
        return accuracy

    def run(self):
        """Run the full mixup training schedule: per-epoch loop over the
        loader, optional validation, checkpointing and scheduler step."""
        # Start Training
        self.logger.info('=====> Start Mixup Training')

        # run epoch
        for epoch in range(self.training_opt['num_epochs']):
            self.logger.info('------------ Start Epoch {} -----------'.format(epoch))

            # preprocess for each epoch
            total_batch = len(self.train_loader)

            for step, (inputs, labels, _, _) in enumerate(self.train_loader):
                self.optimizer.zero_grad()

                # additional inputs
                inputs, labels = inputs.cuda(), labels.cuda()
                add_inputs = {}

                # mixup: replace the batch by its convex mix with a permutation
                inputs, labels_a, labels_b, lam = self.mixup_data(inputs, labels)

                features = self.model(inputs)
                predictions = self.classifier(features, add_inputs)

                # calculate loss
                loss = self.mixup_criterion(predictions, labels_a, labels_b, lam)
                iter_info_print = {self.training_opt['loss'] : loss.sum().item(),}

                loss.backward()
                self.optimizer.step()

                # calculate accuracy (soft, mixup-weighted)
                accuracy = self.mixup_accuracy(predictions, labels_a, labels_b, lam)

                # log information
                iter_info_print.update({'Accuracy' : accuracy.item(), 'Loss' : loss.sum().item(), 'Poke LR' : float(self.optimizer.param_groups[0]['lr'])})
                self.logger.info_iter(epoch, step, total_batch, iter_info_print, self.config['logger_opt']['print_iter'])

                # precedence: first_batch OR (print_grad AND step % 1000 == 0)
                first_batch = (epoch == 0) and (step == 0)
                if first_batch or self.config['logger_opt']['print_grad'] and step % 1000 == 0:
                    utils.print_grad(self.classifier.named_parameters())
                    utils.print_grad(self.model.named_parameters())

            # evaluation on validation set
            if self.eval:
                val_acc = self.testing.run_val(epoch)
            else:
                val_acc = 0.0

            # checkpoint
            self.checkpoint.save(self.model, self.classifier, epoch, self.logger, acc=val_acc)

            # update scheduler
            self.scheduler.step()

        # save best model path
        self.checkpoint.save_best_model_path(self.logger)
--------------------------------------------------------------------------------
/train_ride.py:
--------------------------------------------------------------------------------
1 | ######################################
2 | # Kaihua Tang
3 | ######################################
4 |
5 | import torch
6 | import torch.nn as nn
7 | import torch.optim as optim
8 | import torch.optim.lr_scheduler as lr_scheduler
9 | import torch.nn.functional as F
10 |
11 | import utils.general_utils as utils
12 | from data.dataloader import get_loader
13 | from utils.checkpoint_utils import Checkpoint
14 | from utils.training_utils import *
15 | from utils.test_loader import test_loader
16 |
class train_ride():
    """RIDE (multi-expert) trainer: the classifier returns an aggregated
    prediction plus per-expert logits; training uses either the dedicated
    RIDE loss (fed the expert logits as extra info) or, for any other
    configured loss, the sum of that loss applied to each expert."""

    def __init__(self, args, config, logger, eval=False):
        """Build backbone/classifier from config, plus optimizer, scheduler,
        dataloader, loss and (optionally) the validation harness.

        Args:
            args: command-line namespace (not referenced in this constructor).
            config: parsed experiment config dict.
            logger: project logger.
            eval: when True, also build the validation test harness.
        """
        # ============================================================================
        # create model: backbone and classifier head from config-declared files
        logger.info('=====> Model construction from: ' + str(config['networks']['type']))
        model_type = config['networks']['type']
        model_file = config['networks'][model_type]['def_file']
        model_args = config['networks'][model_type]['params']
        logger.info('=====> Classifier construction from: ' + str(config['classifiers']['type']))
        classifier_type = config['classifiers']['type']
        classifier_file = config['classifiers'][classifier_type]['def_file']
        classifier_args = config['classifiers'][classifier_type]['params']
        model = utils.source_import(model_file).create_model(**model_args)
        classifier = utils.source_import(classifier_file).create_model(**classifier_args)

        # multi-GPU wrappers; .cuda() assumes at least one CUDA device
        model = nn.DataParallel(model).cuda()
        classifier = nn.DataParallel(classifier).cuda()

        # other initialization
        self.config = config
        self.logger = logger
        self.model = model
        self.classifier = classifier
        self.optimizer = create_optimizer(model, classifier, logger, config)
        self.scheduler = create_scheduler(self.optimizer, logger, config)
        self.eval = eval
        self.training_opt = config['training_opt']

        self.checkpoint = Checkpoint(config)

        # get dataloader
        self.logger.info('=====> Get train dataloader')
        self.train_loader = get_loader(config, 'train', config['dataset']['testset'], logger)

        # get loss
        self.loss_fc = create_loss(logger, config, self.train_loader)

        # set eval
        if self.eval:
            test_func = test_loader(config)
            self.testing = test_func(config, logger, model, classifier, val=True)

    def run(self):
        """Run the RIDE training schedule: per-epoch loop over the loader,
        optional validation, checkpointing and scheduler step."""
        self.logger.info('=====> Start RIDE Training')

        # run epoch
        for epoch in range(self.training_opt['num_epochs']):
            self.logger.info('------------ Start Epoch {} -----------'.format(epoch))

            # preprocess for each epoch
            total_batch = len(self.train_loader)
            if self.training_opt['loss'] == 'RIDE':
                # the RIDE loss is epoch-aware; hand it the current epoch
                self.loss_fc.set_epoch(epoch)

            for step, (inputs, labels, _, _) in enumerate(self.train_loader):
                self.optimizer.zero_grad()
                # Fix: build the log dict fresh each step (it used to be
                # created once per epoch, so stale entries from earlier
                # iterations could linger); matches train_stage2_ride.
                iter_info_print = {}

                # additional inputs
                inputs, labels = inputs.cuda(), labels.cuda()
                add_inputs = {}

                features = self.model(inputs)
                # predictions: aggregated output; all_logits: per-expert logits
                predictions, all_logits = self.classifier(features, add_inputs)

                # calculate loss
                if self.training_opt['loss'] == 'RIDE':
                    extra_info = {'logits': all_logits}
                    loss = self.loss_fc(predictions, labels, extra_info)
                    iter_info_print[self.training_opt['loss']] = loss.sum().item()
                else:
                    # generic loss: apply independently to every expert and sum
                    all_loss = []
                    for logit in all_logits:
                        all_loss.append(self.loss_fc(logit, labels))
                    loss = sum(all_loss)
                    for i, branch_loss in enumerate(all_loss):
                        iter_info_print[self.training_opt['loss'] + '_{}'.format(i)] = branch_loss.sum().item()

                loss.backward()
                self.optimizer.step()

                # calculate accuracy
                accuracy = (predictions.max(1)[1] == labels).sum().float() / predictions.shape[0]

                # log information
                iter_info_print.update({'Accuracy' : accuracy.item(), 'Loss' : loss.sum().item(), 'Poke LR' : float(self.optimizer.param_groups[0]['lr'])})
                self.logger.info_iter(epoch, step, total_batch, iter_info_print, self.config['logger_opt']['print_iter'])

                # precedence: first_batch OR (print_grad AND step % 1000 == 0)
                first_batch = (epoch == 0) and (step == 0)
                if first_batch or self.config['logger_opt']['print_grad'] and step % 1000 == 0:
                    utils.print_grad(self.classifier.named_parameters())
                    utils.print_grad(self.model.named_parameters())

            # evaluation on validation set
            if self.eval:
                val_acc = self.testing.run_val(epoch)
            else:
                val_acc = 0.0

            # checkpoint
            self.checkpoint.save(self.model, self.classifier, epoch, self.logger, acc=val_acc)

            # update scheduler
            self.scheduler.step()

        # save best model path
        self.checkpoint.save_best_model_path(self.logger)
--------------------------------------------------------------------------------
/train_stage1.py:
--------------------------------------------------------------------------------
1 | ######################################
2 | # Kaihua Tang
3 | ######################################
4 |
5 | import torch
6 | import torch.nn as nn
7 | import torch.optim as optim
8 | import torch.optim.lr_scheduler as lr_scheduler
9 | import torch.nn.functional as F
10 |
11 | import utils.general_utils as utils
12 | from data.dataloader import get_loader
13 | from utils.checkpoint_utils import Checkpoint
14 | from utils.training_utils import *
15 | from utils.test_loader import test_loader
16 |
class train_stage1():
    """Stage-1 trainer: jointly trains the backbone and classifier head with
    the loss declared in the experiment config."""

    def __init__(self, args, config, logger, eval=False):
        """Instantiate networks, optimizer, scheduler, dataloader and loss;
        optionally build the validation harness when eval is True."""
        # ------------------------------------------------------------------
        # Instantiate backbone and classifier from their config-declared
        # definition files and parameter dicts.
        net_cfg = config['networks']
        net_type = net_cfg['type']
        logger.info('=====> Model construction from: ' + str(net_type))
        cls_cfg = config['classifiers']
        cls_type = cls_cfg['type']
        logger.info('=====> Classifier construction from: ' + str(cls_type))
        backbone = utils.source_import(net_cfg[net_type]['def_file']).create_model(**net_cfg[net_type]['params'])
        head = utils.source_import(cls_cfg[cls_type]['def_file']).create_model(**cls_cfg[cls_type]['params'])

        # Wrap for multi-GPU execution and move to the default CUDA device.
        self.model = nn.DataParallel(backbone).cuda()
        self.classifier = nn.DataParallel(head).cuda()

        # Bookkeeping and optimization state.
        self.config = config
        self.logger = logger
        self.optimizer = create_optimizer(self.model, self.classifier, logger, config)
        self.scheduler = create_scheduler(self.optimizer, logger, config)
        self.eval = eval
        self.training_opt = config['training_opt']
        self.checkpoint = Checkpoint(config)

        # Training dataloader.
        self.logger.info('=====> Get train dataloader')
        self.train_loader = get_loader(config, 'train', config['dataset']['testset'], logger)

        # Loss is built from the config (and may use label statistics of the loader).
        self.loss_fc = create_loss(logger, config, self.train_loader)

        # Optional validation harness.
        if self.eval:
            self.testing = test_loader(config)(config, logger, self.model, self.classifier, val=True)

    def run(self):
        """Execute the full stage-1 training schedule: per-epoch pass over the
        loader, optional validation, checkpointing and scheduler step."""
        self.logger.info('=====> Start Stage 1 Training')

        for epoch in range(self.training_opt['num_epochs']):
            self.logger.info('------------ Start Epoch {} -----------'.format(epoch))
            num_batches = len(self.train_loader)

            for batch_idx, (inputs, labels, _, _) in enumerate(self.train_loader):
                self.optimizer.zero_grad()

                inputs = inputs.cuda()
                labels = labels.cuda()
                extra_inputs = {}

                # Forward pass: backbone features, then classifier logits.
                predictions = self.classifier(self.model(inputs), extra_inputs)

                loss = self.loss_fc(predictions, labels)
                iter_info = {self.training_opt['loss']: loss.sum().item()}

                loss.backward()
                self.optimizer.step()

                # Batch accuracy from the arg-max prediction.
                accuracy = (predictions.max(1)[1] == labels).float().mean()

                iter_info['Accuracy'] = accuracy.item()
                iter_info['Loss'] = loss.sum().item()
                iter_info['Poke LR'] = float(self.optimizer.param_groups[0]['lr'])
                self.logger.info_iter(epoch, batch_idx, num_batches, iter_info, self.config['logger_opt']['print_iter'])

                # Dump gradients on the very first batch, and periodically
                # thereafter when the config asks for it.
                want_grad_log = self.config['logger_opt']['print_grad'] and batch_idx % 1000 == 0
                if (epoch == 0 and batch_idx == 0) or want_grad_log:
                    utils.print_grad(self.classifier.named_parameters())
                    utils.print_grad(self.model.named_parameters())

            # End of epoch: validate (if enabled), checkpoint, step the LR schedule.
            val_acc = self.testing.run_val(epoch) if self.eval else 0.0
            self.checkpoint.save(self.model, self.classifier, epoch, self.logger, acc=val_acc)
            self.scheduler.step()

        self.checkpoint.save_best_model_path(self.logger)
--------------------------------------------------------------------------------
/train_stage2.py:
--------------------------------------------------------------------------------
1 | ######################################
2 | # Kaihua Tang
3 | ######################################
4 |
5 | import torch
6 | import torch.nn as nn
7 | import torch.optim as optim
8 | import torch.optim.lr_scheduler as lr_scheduler
9 | import torch.nn.functional as F
10 |
11 | import utils.general_utils as utils
12 | from data.dataloader import get_loader
13 | from utils.checkpoint_utils import Checkpoint
14 | from utils.training_utils import *
15 | from utils.test_loader import test_loader
16 |
class train_stage2():
    """Stage-2 (decoupled) trainer: re-trains the classifier head on top of a
    stage-1 checkpoint loaded from args.load_dir."""

    def __init__(self, args, config, logger, eval=False):
        """Build models, load stage-1 weights, and prepare optimizer, loader
        and loss.

        Args:
            args: command-line namespace; args.load_dir points at the stage-1
                checkpoint directory.
            config: parsed experiment config dict.
            logger: project logger.
            eval: when True, also build the validation test harness.
        """
        # ============================================================================
        # create model: backbone and classifier head from config-declared files
        logger.info('=====> Model construction from: ' + str(config['networks']['type']))
        model_type = config['networks']['type']
        model_file = config['networks'][model_type]['def_file']
        model_args = config['networks'][model_type]['params']
        logger.info('=====> Classifier construction from: ' + str(config['classifiers']['type']))
        classifier_type = config['classifiers']['type']
        classifier_file = config['classifiers'][classifier_type]['def_file']
        classifier_args = config['classifiers'][classifier_type]['params']
        model = utils.source_import(model_file).create_model(**model_args)
        classifier = utils.source_import(classifier_file).create_model(**classifier_args)

        # multi-GPU wrappers; .cuda() assumes at least one CUDA device
        model = nn.DataParallel(model).cuda()
        classifier = nn.DataParallel(classifier).cuda()

        # other initialization
        self.config = config
        self.logger = logger
        self.model = model
        self.classifier = classifier
        # stage-2-specific optimizer (NOTE(review): presumably freezes or
        # down-weights backbone updates relative to stage 1 -- confirm in
        # utils.training_utils.create_optimizer_stage2)
        self.optimizer = create_optimizer_stage2(model, classifier, logger, config)
        self.scheduler = create_scheduler(self.optimizer, logger, config)
        self.eval = eval
        self.training_opt = config['training_opt']

        self.checkpoint = Checkpoint(config)

        # LWS loads the full stage-1 checkpoint (backbone + classifier);
        # every other head only reuses the stage-1 backbone.
        # NOTE(review): rationale (LWS re-scales the stage-1 classifier
        # weights) inferred from the name -- confirm in models/ClassifierLWS.py.
        if config['classifiers']['type'] == 'LWS':
            self.checkpoint.load(self.model, self.classifier, args.load_dir, logger)
        else:
            self.checkpoint.load_backbone(self.model, args.load_dir, logger)

        # get dataloader
        self.logger.info('=====> Get train dataloader')
        self.train_loader = get_loader(config, 'train', config['dataset']['testset'], logger)

        # get loss
        self.loss_fc = create_loss(logger, config, self.train_loader)

        # set eval
        if self.eval:
            test_func = test_loader(config)
            self.testing = test_func(config, logger, model, classifier, val=True)


    def run(self):
        """Run the stage-2 training schedule: per-epoch loop over the loader,
        optional validation, checkpointing and scheduler step."""
        # Start Training
        self.logger.info('=====> Start Stage 2 Training')

        # run epoch
        for epoch in range(self.training_opt['num_epochs']):
            self.logger.info('------------ Start Epoch {} -----------'.format(epoch))

            # preprocess for each epoch
            total_batch = len(self.train_loader)

            for step, (inputs, labels, _, _) in enumerate(self.train_loader):
                self.optimizer.zero_grad()

                # additional inputs
                inputs, labels = inputs.cuda(), labels.cuda()
                add_inputs = {}

                features = self.model(inputs)
                predictions = self.classifier(features, add_inputs)

                # calculate loss
                loss = self.loss_fc(predictions, labels)
                iter_info_print = {self.training_opt['loss'] : loss.sum().item(),}

                loss.backward()
                self.optimizer.step()

                # calculate accuracy
                accuracy = (predictions.max(1)[1] == labels).sum().float() / predictions.shape[0]

                # log information
                iter_info_print.update({'Accuracy' : accuracy.item(), 'Loss' : loss.sum().item(), 'Poke LR' : float(self.optimizer.param_groups[0]['lr'])})
                self.logger.info_iter(epoch, step, total_batch, iter_info_print, self.config['logger_opt']['print_iter'])

                # precedence: first_batch OR (print_grad AND step % 1000 == 0)
                first_batch = (epoch == 0) and (step == 0)
                if first_batch or self.config['logger_opt']['print_grad'] and step % 1000 == 0:
                    utils.print_grad(self.classifier.named_parameters())
                    utils.print_grad(self.model.named_parameters())

            # evaluation on validation set
            if self.eval:
                val_acc = self.testing.run_val(epoch)
            else:
                val_acc = 0.0

            # checkpoint
            self.checkpoint.save(self.model, self.classifier, epoch, self.logger, acc=val_acc)

            # update scheduler
            self.scheduler.step()

        # save best model path
        self.checkpoint.save_best_model_path(self.logger)
--------------------------------------------------------------------------------
/train_stage2_ride.py:
--------------------------------------------------------------------------------
1 | ######################################
2 | # Kaihua Tang
3 | ######################################
4 |
5 | import torch
6 | import torch.nn as nn
7 | import torch.optim as optim
8 | import torch.optim.lr_scheduler as lr_scheduler
9 | import torch.nn.functional as F
10 |
11 | import utils.general_utils as utils
12 | from data.dataloader import get_loader
13 | from utils.checkpoint_utils import Checkpoint
14 | from utils.training_utils import *
15 | from utils.test_loader import test_loader
16 |
class train_stage2_ride():
    """Stage-2 trainer for RIDE-style multi-expert classifiers: loads a
    stage-1 checkpoint (full model for LWS, backbone-only otherwise) and
    re-trains the head, summing the configured loss over each expert."""

    def __init__(self, args, config, logger, eval=False):
        """Build models, load stage-1 weights, and prepare optimizer, loader
        and loss.

        Args:
            args: command-line namespace; args.load_dir points at the stage-1
                checkpoint directory.
            config: parsed experiment config dict.
            logger: project logger.
            eval: when True, also build the validation test harness.
        """
        # ============================================================================
        # create model: backbone and classifier head from config-declared files
        logger.info('=====> Model construction from: ' + str(config['networks']['type']))
        model_type = config['networks']['type']
        model_file = config['networks'][model_type]['def_file']
        model_args = config['networks'][model_type]['params']
        logger.info('=====> Classifier construction from: ' + str(config['classifiers']['type']))
        classifier_type = config['classifiers']['type']
        classifier_file = config['classifiers'][classifier_type]['def_file']
        classifier_args = config['classifiers'][classifier_type]['params']
        model = utils.source_import(model_file).create_model(**model_args)
        classifier = utils.source_import(classifier_file).create_model(**classifier_args)

        # multi-GPU wrappers; .cuda() assumes at least one CUDA device
        model = nn.DataParallel(model).cuda()
        classifier = nn.DataParallel(classifier).cuda()

        # other initialization
        self.config = config
        self.logger = logger
        self.model = model
        self.classifier = classifier
        # stage-2-specific optimizer (NOTE(review): presumably freezes or
        # down-weights backbone updates -- confirm in create_optimizer_stage2)
        self.optimizer = create_optimizer_stage2(model, classifier, logger, config)
        self.scheduler = create_scheduler(self.optimizer, logger, config)
        self.eval = eval
        self.training_opt = config['training_opt']

        self.checkpoint = Checkpoint(config)

        # LWS loads the full stage-1 checkpoint (backbone + classifier);
        # other heads only reuse the stage-1 backbone.
        if config['classifiers']['type'] == 'LWS':
            self.checkpoint.load(self.model, self.classifier, args.load_dir, logger)
        else:
            self.checkpoint.load_backbone(self.model, args.load_dir, logger)

        # get dataloader
        self.logger.info('=====> Get train dataloader')
        self.train_loader = get_loader(config, 'train', config['dataset']['testset'], logger)

        # get loss
        self.loss_fc = create_loss(logger, config, self.train_loader)

        # set eval
        if self.eval:
            test_func = test_loader(config)
            self.testing = test_func(config, logger, model, classifier, val=True)


    def run(self):
        """Run the RIDE stage-2 training schedule: per-epoch loop over the
        loader, optional validation, checkpointing and scheduler step."""
        # Start Training
        self.logger.info('=====> Start RIDE Stage 2 Training')

        # run epoch
        for epoch in range(self.training_opt['num_epochs']):
            self.logger.info('------------ Start Epoch {} -----------'.format(epoch))

            # preprocess for each epoch
            total_batch = len(self.train_loader)

            for step, (inputs, labels, _, _) in enumerate(self.train_loader):
                self.optimizer.zero_grad()
                # fresh per-step log dict
                iter_info_print = {}
                # additional inputs
                inputs, labels = inputs.cuda(), labels.cuda()
                add_inputs = {}

                features = self.model(inputs)
                # predictions: aggregated output; all_logits: per-expert logits
                predictions, all_logits = self.classifier(features, add_inputs)

                # calculate loss: apply the configured loss to every expert
                # independently and sum the branch losses
                all_loss = []
                for logit in all_logits:
                    all_loss.append(self.loss_fc(logit, labels))
                loss = sum(all_loss)
                for i, branch_loss in enumerate(all_loss):
                    iter_info_print[self.training_opt['loss'] + '_{}'.format(i)] = branch_loss.sum().item()

                loss.backward()
                self.optimizer.step()

                # calculate accuracy (aggregated prediction)
                accuracy = (predictions.max(1)[1] == labels).sum().float() / predictions.shape[0]

                # log information
                iter_info_print.update({'Accuracy' : accuracy.item(), 'Loss' : loss.sum().item(), 'Poke LR' : float(self.optimizer.param_groups[0]['lr'])})
                self.logger.info_iter(epoch, step, total_batch, iter_info_print, self.config['logger_opt']['print_iter'])

                # precedence: first_batch OR (print_grad AND step % 1000 == 0)
                first_batch = (epoch == 0) and (step == 0)
                if first_batch or self.config['logger_opt']['print_grad'] and step % 1000 == 0:
                    utils.print_grad(self.classifier.named_parameters())
                    utils.print_grad(self.model.named_parameters())

            # evaluation on validation set
            if self.eval:
                val_acc = self.testing.run_val(epoch)
            else:
                val_acc = 0.0

            # checkpoint
            self.checkpoint.save(self.model, self.classifier, epoch, self.logger, acc=val_acc)

            # update scheduler
            self.scheduler.step()

        # save best model path
        self.checkpoint.save_best_model_path(self.logger)
--------------------------------------------------------------------------------
/train_tade.py:
--------------------------------------------------------------------------------
1 | ######################################
2 | # Kaihua Tang
3 | ######################################
4 |
5 | import torch
6 | import torch.nn as nn
7 | import torch.optim as optim
8 | import torch.optim.lr_scheduler as lr_scheduler
9 | import torch.nn.functional as F
10 |
11 | import utils.general_utils as utils
12 | from data.dataloader import get_loader
13 | from utils.checkpoint_utils import Checkpoint
14 | from utils.training_utils import *
15 | from utils.test_loader import test_loader
16 |
class train_tade():
    """TADE (test-agnostic expert ensemble) trainer: the classifier returns an
    aggregated prediction plus per-expert logits; training uses either the
    dedicated TADE loss (fed the expert logits as extra info) or, for any
    other configured loss, the sum of that loss applied to each expert."""

    def __init__(self, args, config, logger, eval=False):
        """Build backbone/classifier from config, plus optimizer, scheduler,
        dataloader, loss and (optionally) the validation harness.

        Args:
            args: command-line namespace (not referenced in this constructor).
            config: parsed experiment config dict.
            logger: project logger.
            eval: when True, also build the validation test harness.
        """
        # ============================================================================
        # create model: backbone and classifier head from config-declared files
        logger.info('=====> Model construction from: ' + str(config['networks']['type']))
        model_type = config['networks']['type']
        model_file = config['networks'][model_type]['def_file']
        model_args = config['networks'][model_type]['params']
        logger.info('=====> Classifier construction from: ' + str(config['classifiers']['type']))
        classifier_type = config['classifiers']['type']
        classifier_file = config['classifiers'][classifier_type]['def_file']
        classifier_args = config['classifiers'][classifier_type]['params']
        model = utils.source_import(model_file).create_model(**model_args)
        classifier = utils.source_import(classifier_file).create_model(**classifier_args)

        # multi-GPU wrappers; .cuda() assumes at least one CUDA device
        model = nn.DataParallel(model).cuda()
        classifier = nn.DataParallel(classifier).cuda()

        # other initialization
        self.config = config
        self.logger = logger
        self.model = model
        self.classifier = classifier
        self.optimizer = create_optimizer(model, classifier, logger, config)
        self.scheduler = create_scheduler(self.optimizer, logger, config)
        self.eval = eval
        self.training_opt = config['training_opt']

        self.checkpoint = Checkpoint(config)

        # get dataloader
        self.logger.info('=====> Get train dataloader')
        self.train_loader = get_loader(config, 'train', config['dataset']['testset'], logger)

        # get loss
        self.loss_fc = create_loss(logger, config, self.train_loader)

        # set eval
        if self.eval:
            test_func = test_loader(config)
            self.testing = test_func(config, logger, model, classifier, val=True)

    def run(self):
        """Run the TADE training schedule: per-epoch loop over the loader,
        optional validation, checkpointing and scheduler step."""
        self.logger.info('=====> Start TADE Training')

        # run epoch
        for epoch in range(self.training_opt['num_epochs']):
            self.logger.info('------------ Start Epoch {} -----------'.format(epoch))

            # preprocess for each epoch
            total_batch = len(self.train_loader)

            for step, (inputs, labels, _, _) in enumerate(self.train_loader):
                self.optimizer.zero_grad()
                # Fix: build the log dict fresh each step (it used to be
                # created once per epoch, so stale entries from earlier
                # iterations could linger); matches train_stage2_ride.
                iter_info_print = {}

                # additional inputs
                inputs, labels = inputs.cuda(), labels.cuda()
                add_inputs = {}

                features = self.model(inputs)
                # predictions: aggregated output; all_logits: per-expert logits
                predictions, all_logits = self.classifier(features, add_inputs)

                # calculate loss
                if self.training_opt['loss'] == 'TADE':
                    extra_info = {'logits': all_logits}
                    loss = self.loss_fc(predictions, labels, extra_info)
                    iter_info_print[self.training_opt['loss']] = loss.sum().item()
                else:
                    # generic loss: apply independently to every expert and sum
                    all_loss = []
                    for logit in all_logits:
                        all_loss.append(self.loss_fc(logit, labels))
                    loss = sum(all_loss)
                    for i, branch_loss in enumerate(all_loss):
                        iter_info_print[self.training_opt['loss'] + '_{}'.format(i)] = branch_loss.sum().item()

                loss.backward()
                self.optimizer.step()

                # calculate accuracy
                accuracy = (predictions.max(1)[1] == labels).sum().float() / predictions.shape[0]

                # log information
                iter_info_print.update({'Accuracy' : accuracy.item(), 'Loss' : loss.sum().item(), 'Poke LR' : float(self.optimizer.param_groups[0]['lr'])})
                self.logger.info_iter(epoch, step, total_batch, iter_info_print, self.config['logger_opt']['print_iter'])

                # precedence: first_batch OR (print_grad AND step % 1000 == 0)
                first_batch = (epoch == 0) and (step == 0)
                if first_batch or self.config['logger_opt']['print_grad'] and step % 1000 == 0:
                    utils.print_grad(self.classifier.named_parameters())
                    utils.print_grad(self.model.named_parameters())

            # evaluation on validation set
            if self.eval:
                val_acc = self.testing.run_val(epoch)
            else:
                val_acc = 0.0

            # checkpoint
            self.checkpoint.save(self.model, self.classifier, epoch, self.logger, acc=val_acc)

            # update scheduler
            self.scheduler.step()

        # save best model path
        self.checkpoint.save_best_model_path(self.logger)
--------------------------------------------------------------------------------
/train_tde.py:
--------------------------------------------------------------------------------
1 | ######################################
2 | # Kaihua Tang
3 | ######################################
4 |
5 | import torch
6 | import torch.nn as nn
7 | import torch.optim as optim
8 | import torch.optim.lr_scheduler as lr_scheduler
9 | import torch.nn.functional as F
10 |
11 | import utils.general_utils as utils
12 | from data.dataloader import get_loader
13 | from utils.checkpoint_utils import Checkpoint
14 | from utils.training_utils import *
15 | from utils.test_loader import test_loader
16 |
17 | class train_tde():
    def __init__(self, args, config, logger, eval=False):
        """TDE trainer setup: build backbone/classifier from config, plus
        optimizer, scheduler, dataloader, loss and (optionally) the
        validation harness.

        Args:
            args: command-line namespace (not referenced in this constructor).
            config: parsed experiment config dict.
            logger: project logger.
            eval: when True, also build the validation test harness.
        """
        # ============================================================================
        # create model: backbone and classifier head are instantiated from the
        # definition files / params declared in the config
        logger.info('=====> Model construction from: ' + str(config['networks']['type']))
        model_type = config['networks']['type']
        model_file = config['networks'][model_type]['def_file']
        model_args = config['networks'][model_type]['params']
        logger.info('=====> Classifier construction from: ' + str(config['classifiers']['type']))
        classifier_type = config['classifiers']['type']
        classifier_file = config['classifiers'][classifier_type]['def_file']
        classifier_args = config['classifiers'][classifier_type]['params']
        model = utils.source_import(model_file).create_model(**model_args)
        classifier = utils.source_import(classifier_file).create_model(**classifier_args)

        # multi-GPU wrappers; .cuda() assumes at least one CUDA device
        model = nn.DataParallel(model).cuda()
        classifier = nn.DataParallel(classifier).cuda()

        # other initialization
        self.config = config
        self.logger = logger
        self.model = model
        self.classifier = classifier
        self.optimizer = create_optimizer(model, classifier, logger, config)
        self.scheduler = create_scheduler(self.optimizer, logger, config)
        self.eval = eval  # flag only; intentionally mirrors the constructor arg name
        self.training_opt = config['training_opt']

        self.checkpoint = Checkpoint(config)

        # get dataloader
        self.logger.info('=====> Get train dataloader')
        self.train_loader = get_loader(config, 'train', config['dataset']['testset'], logger)

        # get loss
        self.loss_fc = create_loss(logger, config, self.train_loader)

        # set eval
        if self.eval:
            test_func = test_loader(config)
            self.testing = test_func(config, logger, model, classifier, val=True)
58 |
59 |
60 | def update_embed(self, embed, input):
61 | # embed is only updated during training
62 | assert len(input.shape) == 2
63 | with torch.no_grad():
64 | embed = embed * 0.995 + (1 - 0.995) * input.mean(0).view(-1)
65 | return embed
66 |
67 |
68 | def run(self):
69 | # Start Training
70 | self.logger.info('=====> Start TDE Training')
71 |
72 | # init embed
73 | self.embed = torch.zeros(2048).cuda()
74 |
75 | # run epoch
76 | for epoch in range(self.training_opt['num_epochs']):
77 | self.logger.info('------------ Start Epoch {} -----------'.format(epoch))
78 |
79 | # preprocess for each epoch
80 | total_batch = len(self.train_loader)
81 |
82 | for step, (inputs, labels, _, _) in enumerate(self.train_loader):
83 | self.optimizer.zero_grad()
84 |
85 | # additional inputs
86 | inputs, labels = inputs.cuda(), labels.cuda()
87 | add_inputs = {}
88 |
89 | features = self.model(inputs)
90 | predictions = self.classifier(features, add_inputs)
91 |
92 | # update embed during training
93 | self.embed = self.update_embed(self.embed, features)
94 |
95 | # calculate loss
96 | loss = self.loss_fc(predictions, labels)
97 | iter_info_print = {self.training_opt['loss'] : loss.sum().item(),}
98 |
99 | loss.backward()
100 | self.optimizer.step()
101 |
102 | # calculate accuracy
103 | accuracy = (predictions.max(1)[1] == labels).sum().float() / predictions.shape[0]
104 |
105 | # log information
106 | iter_info_print.update({'Accuracy' : accuracy.item(), 'Loss' : loss.sum().item(), 'Poke LR' : float(self.optimizer.param_groups[0]['lr'])})
107 | self.logger.info_iter(epoch, step, total_batch, iter_info_print, self.config['logger_opt']['print_iter'])
108 |
109 | first_batch = (epoch == 0) and (step == 0)
110 | if first_batch or self.config['logger_opt']['print_grad'] and step % 1000 == 0:
111 | utils.print_grad(self.classifier.named_parameters())
112 | utils.print_grad(self.model.named_parameters())
113 |
114 | # evaluation on validation set
115 | if self.eval:
116 | val_acc = self.testing.run_val(epoch, self.embed)
117 | else:
118 | val_acc = 0.0
119 |
120 | # checkpoint
121 | add_dict = {}
122 | add_dict['embed'] = self.embed.cpu()
123 | self.logger.info('Embed Mean: {}'.format(self.embed.mean().item()))
124 | self.checkpoint.save(self.model, self.classifier, epoch, self.logger, acc=val_acc, add_dict=add_dict)
125 |
126 | # update scheduler
127 | self.scheduler.step()
128 |
129 | # save best model path
130 | self.checkpoint.save_best_model_path(self.logger)
--------------------------------------------------------------------------------
/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/KaihuaTang/Generalized-Long-Tailed-Benchmarks.pytorch/6317d8feb0ba107e1a64822567ed59115d51c581/utils/__init__.py
--------------------------------------------------------------------------------
/utils/checkpoint_utils.py:
--------------------------------------------------------------------------------
1 | ######################################
2 | # Kaihua Tang
3 | ######################################
4 |
5 | import os
6 | import json
7 | import torch
8 | import numpy as np
9 |
class Checkpoint():
    """Saves model+classifier checkpoints and tracks the best-performing epoch."""
    def __init__(self, config):
        self.config = config
        self.best_epoch = -1            # epoch index of the best checkpoint so far
        self.best_performance = -1      # best validation accuracy so far
        self.best_model_path = None     # path of the best checkpoint file


    def save(self, model, classifier, epoch, logger, add_dict=None, acc=None):
        """Save a checkpoint for `epoch` and update the best-model record.

        Args:
            add_dict: extra entries merged into the checkpoint dict (e.g. the
                TDE feature EMA). Default changed from a mutable `{}` literal
                to None to avoid the shared-default-argument pitfall;
                behavior is unchanged for all callers.
            acc: validation accuracy used to rank checkpoints; when None the
                newest checkpoint is always treated as the best.
        """
        # only save at certain steps or the last epoch
        if epoch % self.config['checkpoint_opt']['checkpoint_step'] != 0 and epoch < (self.config['training_opt']['num_epochs'] - 1):
            return

        output = {
            'model': model.state_dict(),
            'classifier': classifier.state_dict(),
            'epoch': epoch,
        }
        if add_dict:
            output.update(add_dict)

        model_name = 'epoch_{}_'.format(epoch) + self.config['checkpoint_opt']['checkpoint_name']
        model_path = os.path.join(self.config['output_dir'], model_name)

        logger.info('Model at epoch {} is saved to {}'.format(epoch, model_path))
        torch.save(output, model_path)

        # update best model
        if acc is not None:
            if float(acc) > self.best_performance:
                self.best_epoch = epoch
                self.best_performance = float(acc)
                self.best_model_path = model_path
        else:
            # if acc is None, the newest is always the best
            self.best_epoch = epoch
            self.best_model_path = model_path

        logger.info('Best model is at epoch {} with accuracy {:9.3f}'.format(self.best_epoch, self.best_performance))


    def save_best_model_path(self, logger):
        """Write '<path> <epoch> <acc>' for the best checkpoint to `best_checkpoint`."""
        logger.info('Best model is at epoch {} with accuracy {:9.3f} (Path: {})'.format(self.best_epoch, self.best_performance, self.best_model_path))
        with open(os.path.join(self.config['output_dir'], 'best_checkpoint'), 'w+') as f:
            f.write(self.best_model_path + ' ' + str(self.best_epoch) + ' ' + str(self.best_performance) + '\n')

    def load(self, model, classifier, path, logger):
        """Load model and classifier weights from `path`.

        `path` is either a .pth file, or a directory containing a
        `best_checkpoint` file (as written by save_best_model_path) whose
        first line starts with the checkpoint path.

        BUG FIX: the original read the path with `f[0]`, but file objects are
        not subscriptable (TypeError); `readline()` is the correct call.
        """
        if path.split('.')[-1] != 'pth':
            with open(os.path.join(path, 'best_checkpoint')) as f:
                path = f.readline().split(' ')[0]

        checkpoint = torch.load(path, map_location='cpu')
        logger.info('Loading checkpoint pretrained with epoch {}.'.format(checkpoint['epoch']))

        model_state = checkpoint['model']
        classifier_state = checkpoint['classifier']

        model = self.load_module(model, model_state, logger)
        classifier = self.load_module(classifier, classifier_state, logger)
        return checkpoint


    def load_backbone(self, model, path, logger):
        """Load only the backbone weights; same path resolution as `load`."""
        if path.split('.')[-1] != 'pth':
            with open(os.path.join(path, 'best_checkpoint')) as f:
                # was `f[0]`: file objects are not subscriptable
                path = f.readline().split(' ')[0]

        checkpoint = torch.load(path, map_location='cpu')
        logger.info('Loading checkpoint pretrained with epoch {}.'.format(checkpoint['epoch']))

        model_state = checkpoint['model']

        model = self.load_module(model, model_state, logger)
        return checkpoint


    def load_module(self, module, module_state, logger):
        """Copy matching keys from `module_state` into `module`'s state dict,
        tolerating a 'module.' (DataParallel) prefix mismatch; keys missing
        from the checkpoint are logged and left at their current values."""
        x = module.state_dict()
        for key, _ in x.items():
            if key in module_state:
                x[key] = module_state[key]
                logger.info('Load {:>50} from checkpoint.'.format(key))
            elif 'module.' + key in module_state:
                x[key] = module_state['module.' + key]
                logger.info('Load {:>50} from checkpoint (rematch with module.).'.format(key))
            else:
                logger.info('WARNING: Key {} is missing in the checkpoint.'.format(key))

        module.load_state_dict(x)
        return module
--------------------------------------------------------------------------------
/utils/general_utils.py:
--------------------------------------------------------------------------------
1 | ######################################
2 | # Kaihua Tang
3 | ######################################
4 | import torch
5 | import numpy as np
6 | import importlib
7 |
8 |
def count_dataset(train_loader):
    """Count per-class sample frequencies of the training set.

    Accepts a single dataloader or a list/tuple of dataloaders (only the
    first is inspected). Returns a numpy array of counts ordered by class id.

    BUG FIX: the original keyed the counts on str(label) and sorted the dict,
    which orders classes lexicographically ('10' < '2'), so the frequency
    array was misaligned with class indices for any dataset with 10+ classes.
    Counts are now keyed and sorted by the raw label value.
    """
    if isinstance(train_loader, (list, tuple)):
        all_labels = train_loader[0].dataset.labels
    else:
        all_labels = train_loader.dataset.labels
    label_freq = {}
    for label in all_labels:
        label_freq[label] = label_freq.get(label, 0) + 1
    return np.array([label_freq[k] for k in sorted(label_freq)])
21 |
22 |
def compute_adjustment(train_loader, tro=1.0):
    """compute the base probabilities"""
    # Tally how often each label occurs in the training set.
    counts = {}
    for key in train_loader.dataset.labels:
        counts[key] = counts.get(key, 0) + 1
    # Frequencies ordered by label, normalized into prior probabilities.
    freq = np.array([counts[k] for k in sorted(counts)])
    priors = freq / freq.sum()
    # Logit adjustment: log(p^tro), with epsilon guarding log(0).
    return torch.from_numpy(np.log(priors ** tro + 1e-12))
34 |
35 |
def update(config, algo_config, args):
    """Override entries of `config` in place from CLI `args`, pulling
    per-algorithm defaults out of `algo_config` when a train type is given.
    Returns the (mutated) config."""
    if args.output_dir:
        config['output_dir'] = args.output_dir
    if args.load_dir:
        config['load_dir'] = args.load_dir

    # select the algorithm we use
    if args.train_type:
        config['training_opt']['type'] = args.train_type
        # update algorithm details based on training_type
        algo_info = algo_config[args.train_type]
        config['sampler'] = algo_info['sampler']
        config['num_sampler'] = algo_info.get('num_sampler', 1)
        config['batch_split'] = algo_info.get('batch_split', False)
        config['testing_opt']['type'] = algo_info['test_type']
        config['training_opt']['loss'] = algo_info['loss_type']
        config['networks']['type'] = algo_info['backbone_type']
        config['classifiers']['type'] = algo_info['classifier_type']
        config['algorithm_opt'] = algo_info['algorithm_opt']
        config['dataset']['rand_aug'] = algo_info.get('rand_aug', False)

        # optional training hyper-parameters are copied only when present
        for opt_key in ('num_epochs', 'batch_size', 'optim_params',
                        'scheduler', 'scheduler_params'):
            if opt_key in algo_info:
                config['training_opt'][opt_key] = algo_info[opt_key]

    # other updates
    if args.lr:
        config['training_opt']['optim_params']['lr'] = args.lr
    if args.testset:
        config['dataset']['testset'] = args.testset
    if args.model_type:
        config['classifiers']['type'] = args.model_type
    if args.loss_type:
        config['training_opt']['loss'] = args.loss_type
    if args.sample_type:
        config['sampler'] = args.sample_type
    if args.rand_aug:
        config['dataset']['rand_aug'] = True
    if args.save_all:
        config['saving_opt']['save_all'] = True
    return config
85 |
86 |
def source_import(file_path):
    """Import a python module directly from its source file via importlib."""
    loaded_spec = importlib.util.spec_from_file_location('', file_path)
    loaded_module = importlib.util.module_from_spec(loaded_spec)
    loaded_spec.loader.exec_module(loaded_module)
    return loaded_module
93 |
def print_grad(named_parameters):
    """Print per-parameter gradient L2 norms (largest first) and return the
    total gradient norm; parameters without a gradient are skipped."""
    total_norm = 0
    norm_by_name = {}
    shape_by_name = {}
    for name, param in named_parameters:
        if param.grad is None:
            continue
        grad_norm = param.grad.data.norm(2)
        total_norm += grad_norm ** 2
        norm_by_name[name] = grad_norm
        shape_by_name[name] = param.size()
    total_norm = total_norm ** (1. / 2)
    print('---Total norm {:.3f} -----------------'.format(total_norm))
    for name in sorted(norm_by_name, key=lambda k: -norm_by_name[k]):
        print("{:<50s}: {:.3f}, ({})".format(name, norm_by_name[name], shape_by_name[name]))
    print('-------------------------------', flush=True)
    return total_norm
111 |
def print_config(config, logger, head=''):
    """Recursively log every key of a (possibly nested) config dict, indenting
    one level per nesting depth."""
    for key, val in config.items():
        if not isinstance(val, dict):
            logger.info(head + '{} : {}'.format(str(key), str(val)))
        else:
            logger.info(head + str(key))
            print_config(val, logger, head=head + '    ')
119 |
class TriggerAction():
    """A named registry of zero-argument callbacks that can be fired together."""
    def __init__(self, name):
        self.name = name
        self.action = {}  # str(action name) -> callable

    def add_action(self, name, func):
        # registering the same name twice is a programming error
        key = str(name)
        assert key not in self.action
        self.action[key] = func

    def remove_action(self, name):
        # removing an unregistered name is a programming error
        key = str(name)
        assert key in self.action
        del self.action[key]
        assert key not in self.action

    def run_all(self, logger=None):
        # fire every registered callback, optionally logging each trigger
        for key, func in self.action.items():
            if logger:
                logger.info('trigger {}'.format(key))
            func()
139 |
140 |
def calculate_recall(prediction, label, split_mask=None):
    """Mean accuracy of `prediction` against `label`, optionally restricted to
    the samples selected by `split_mask`."""
    hits = (prediction == label).float()
    selected = hits if split_mask is None else hits[split_mask]
    return selected.mean().item()
148 |
149 |
def calculate_precision(prediction, label, num_class, split_mask=None):
    """Macro-averaged precision: each correct prediction contributes
    1 / (#times its class was predicted), summed and divided by the number of
    distinct classes present in `label` (restricted by `split_mask` if given)."""
    # how many times each class id was predicted
    pred_count = torch.zeros(num_class).to(label.device)
    for cls in range(num_class):
        pred_count[cls] = (prediction == cls).float().sum()

    # per-sample precision contribution
    per_sample = (prediction == label).float() / pred_count[prediction].float()

    if split_mask is None:
        num_seen = len(set(label.tolist()))
        return per_sample.sum().item() / num_seen
    num_seen = len(set(label[split_mask].tolist()))
    return per_sample[split_mask].sum().item() / num_seen
165 |
166 |
def calculate_f1(recall, precision):
    """Harmonic mean (F1) of recall and precision.

    Returns 0.0 when both inputs are zero; the original raised
    ZeroDivisionError in that case, which can occur on an empty or
    all-wrong split.
    """
    if recall + precision == 0:
        return 0.0
    return 2 * recall * precision / (recall + precision)
--------------------------------------------------------------------------------
/utils/logger_utils.py:
--------------------------------------------------------------------------------
1 | ######################################
2 | # Kaihua Tang
3 | ######################################
4 | import os
5 | import csv
6 | from datetime import datetime
7 |
class custom_logger():
    """Writes timestamped messages to a per-run text log and result rows to a
    CSV file, both inside `output_path`."""
    def __init__(self, output_path, name='logger.txt'):
        now = datetime.now()
        # "20%y" prefixes the 2-digit year with "20" (e.g. "2024_Mar_01_")
        logger_name = str(now.strftime("20%y_%h_%d_")) + name
        self.logger_path = os.path.join(output_path, logger_name)
        self.csv_path = os.path.join(output_path, 'results.csv')
        # init logger file; `with` guarantees the handle is closed
        # (the original left it to manual close)
        with open(self.logger_path, "w+") as f:
            f.write(self.get_local_time() + 'Start Logging \n')
        # init csv; newline='' is required by the csv module to avoid
        # spurious blank rows on Windows
        with open(self.csv_path, 'w', newline='') as f:
            writer = csv.writer(f)
            writer.writerow([self.get_local_time(), ])

    def get_local_time(self):
        """Timestamp prefix used for every log line."""
        now = datetime.now()
        return str(now.strftime("%y_%h_%d %H:%M:%S : "))

    def info(self, log_str):
        """Print `log_str` to stdout and append it (timestamped) to the log file."""
        print(str(log_str))
        with open(self.logger_path, "a") as f:
            f.write(self.get_local_time() + str(log_str) + '\n')

    def raise_error(self, error):
        """Log `error` prominently, then raise it as a ValueError."""
        prototype = '************* Error: {} *************'.format(str(error))
        self.info(prototype)
        raise ValueError(str(error))

    def info_iter(self, epoch, batch, total_batch, info_dict, print_iter):
        """Log one training-iteration summary every `print_iter` batches."""
        if batch % print_iter != 0:
            return
        acc_log = 'Epoch {:5d}, Batch {:6d}/{},'.format(epoch, batch, total_batch)
        for key, val in info_dict.items():
            acc_log += ' {}: {:9.3f},'.format(str(key), float(val))
        self.info(acc_log)

    def write_results(self, result_list):
        """Append one row to results.csv (newline='' per the csv module docs)."""
        with open(self.csv_path, 'a', newline='') as f:
            writer = csv.writer(f)
            writer.writerow(result_list)
--------------------------------------------------------------------------------
/utils/test_loader.py:
--------------------------------------------------------------------------------
1 | ######################################
2 | # Kaihua Tang
3 | ######################################
4 |
5 | from test_baseline import test_baseline
6 | from test_tde import test_tde
7 | from test_la import test_la
8 |
def test_loader(config):
    """Return the test pipeline class for config['testing_opt']['type'].

    BUG FIX: the original wrote `in ('baseline')` — parentheses without a
    comma are just a string, so the check was Python substring membership
    (e.g. 'base' or 'line' would silently match). The option sets are real
    one-element tuples now, giving exact matching.
    """
    test_type = config['testing_opt']['type']
    if test_type in ('baseline',):
        return test_baseline
    elif test_type in ('TDE',):
        return test_tde
    elif test_type in ('LA',):
        return test_la
    else:
        raise ValueError('Wrong Test Pipeline')
18 |
19 |
--------------------------------------------------------------------------------
/utils/train_loader.py:
--------------------------------------------------------------------------------
1 | ######################################
2 | # Kaihua Tang
3 | ######################################
4 |
5 | import imp
6 | from train_baseline import train_baseline
7 | from train_la import train_la
8 | from train_bbn import train_bbn
9 | from train_tde import train_tde
10 | from train_ride import train_ride
11 | from train_tade import train_tade
12 | from train_ldam import train_ldam
13 | from train_mixup import train_mixup
14 |
15 | from train_lff import train_lff
16 |
17 | from train_stage1 import train_stage1
18 | from train_stage2 import train_stage2
19 | from train_stage2_ride import train_stage2_ride
20 |
21 | from train_irm_dual import train_irm_dual
22 | from train_center_dual import train_center_dual
23 | from train_center_single import train_center_single
24 | from train_center_triple import train_center_triple
25 | from train_center_tade import train_center_tade
26 | from train_center_ride import train_center_ride
27 | from train_center_ride_mixup import train_center_ride_mixup
28 | from train_center_ldam_dual import train_center_ldam_dual
29 | from train_center_dual_mixup import train_center_dual_mixup
30 |
31 |
def train_loader(config):
    """Return the training pipeline class for config['training_opt']['type'].

    BUG FIX: the original tested single-option branches with `in ('TDE')`;
    parentheses without a comma are just a string, so those checks were
    substring membership (e.g. 'TD' or 'E' would silently match). Every
    option set is a real tuple now, giving exact matching.
    """
    train_type = config['training_opt']['type']
    # (accepted type names) -> pipeline class
    dispatch = (
        (('baseline', 'Focal'), train_baseline),
        (('LFF', 'LFFLA'), train_lff),
        (('LA', 'FocalLA'), train_la),
        (('BBN',), train_bbn),
        (('TDE',), train_tde),
        (('mixup',), train_mixup),
        (('LDAM',), train_ldam),
        (('RIDE',), train_ride),
        (('TADE',), train_tade),
        (('stage1',), train_stage1),
        (('crt_stage2', 'lws_stage2'), train_stage2),
        (('ride_stage2',), train_stage2_ride),
        (('center_dual', 'env_dual'), train_center_dual),
        (('center_single',), train_center_single),
        (('center_triple',), train_center_triple),
        (('center_LDAM_dual',), train_center_ldam_dual),
        (('center_dual_mixup',), train_center_dual_mixup),
        (('center_tade',), train_center_tade),
        (('center_ride',), train_center_ride),
        (('center_ride_mixup',), train_center_ride_mixup),
        (('irm_dual',), train_irm_dual),
    )
    for options, pipeline in dispatch:
        if train_type in options:
            return pipeline
    raise ValueError('Wrong Train Type')
--------------------------------------------------------------------------------