├── LICENSE ├── README.md ├── attribute ├── AWA2 │ ├── new_des.csv │ └── predicates.txt ├── CUB │ ├── attributes.txt │ └── new_des.csv └── SUN │ ├── attributes.mat │ └── new_des.csv ├── core ├── AWA2DataLoader.py ├── CUBDataLoader.py ├── CUBDataLoader_standard_split.py ├── DAZLE.py ├── DeepFashionDataLoader.py ├── SUNDataLoader.py ├── __init__.py └── helper_func.py ├── data ├── AWA2 │ └── instruction_AWA2.txt ├── CUB │ └── instruction_CUB.txt ├── DeepFashion │ ├── Anno │ │ └── instruction_Anno.txt │ ├── Eval │ │ └── intruction_Eval.txt │ ├── annotation.pkl │ └── instruction_DeepFashion.txt ├── SUN │ └── instruction_SUN.txt ├── standard_split │ └── instruction_standard_split.txt └── xlsa17 │ └── instruction_xlsa17.txt ├── extract_feature ├── extract_annotation_DeepFashion.py ├── extract_attribute_w2v_AWA2.py ├── extract_attribute_w2v_CUB.py ├── extract_attribute_w2v_DeepFashion.py ├── extract_attribute_w2v_SUN.py ├── extract_feature_map_ResNet_101_AWA2.py ├── extract_feature_map_ResNet_101_CUB.py ├── extract_feature_map_ResNet_101_DeepFashion.py └── extract_feature_map_ResNet_101_SUN.py ├── fig └── high_level_schematic.png ├── global_setting.py ├── notebook ├── .ipynb_checkpoints │ ├── DAZLE_AWA2-checkpoint.ipynb │ ├── DAZLE_CUB-checkpoint.ipynb │ ├── DAZLE_CUB_SS-checkpoint.ipynb │ ├── DAZLE_DeepFashion-checkpoint.ipynb │ └── DAZLE_SUN-checkpoint.ipynb ├── DAZLE_AWA2.ipynb ├── DAZLE_CUB.ipynb ├── DAZLE_CUB_SS.ipynb ├── DAZLE_DeepFashion.ipynb └── DAZLE_SUN.ipynb ├── requirements.txt └── w2v ├── AWA2_attribute.pkl ├── CUB_attribute.pkl ├── DeepFashion_attribute.pkl └── SUN_attribute.pkl /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 Dat Huynh 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Fine-Grained Generalized Zero-Shot Learning via Dense Attribute-Based Attention 2 | 3 | ## Overview 4 | This repository contains the implementation of [Fine-Grained Generalized Zero-Shot Learning via Dense Attribute-Based Attention](http://khoury.neu.edu/home/eelhami/publications/FineGrainedZSL-CVPR20.pdf). 5 | > In this work, we develop a zero-shot fine-grained recognition with the ability to localize attributes using a dense attribute-based attention and embedding mechanism. 
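For readers who want to see the mechanism in code, below is a minimal, self-contained sketch of the dense attribute-based attention computed in `./core/DAZLE.py` (forward pass). The tensor sizes are illustrative placeholders, and the normalization, bias, and ablation options of the full model are omitted:
```python
import torch
import torch.nn.functional as F

# Illustrative sizes only: batch b, feature dim f, word-vector dim v,
# regions r, attributes i, classes k.
b, f, v, r, i, k = 4, 2048, 300, 49, 85, 50

Fs  = torch.randn(b, f, r)              # region features from a ResNet feature map
V   = F.normalize(torch.randn(i, v))    # attribute word vectors
att = F.normalize(torch.randn(k, i))    # class-attribute matrix
W_1, W_2, W_3 = (torch.randn(v, f) for _ in range(3))

S    = torch.einsum('iv,vf,bfr->bir', V, W_1, Fs)                   # attribute score on each region
A    = torch.einsum('iv,vf,bfr->bir', V, W_2, Fs).softmax(dim=-1)   # attention map over regions, per attribute
F_p  = torch.einsum('bir,bfr->bif', A, Fs)                          # attribute-based (attended) features
S_p  = torch.einsum('bir,bir->bi', A, S)                            # attended attribute scores
A_p  = torch.sigmoid(torch.einsum('iv,vf,bif->bi', V, W_3, F_p))    # attention over attributes
S_pp = torch.einsum('ki,bi,bi->bk', att, A_p, S_p)                  # class scores (summed over attributes)
print(S_pp.shape)  # torch.Size([4, 50])
```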
6 | 7 | ![Image](https://github.com/hbdat/cvpr20_DAZLE/raw/master/fig/high_level_schematic.png) 8 | 9 | --- 10 | ## Prerequisites 11 | To install all the dependency packages, please run: 12 | ``` 13 | pip install -r requirements.txt 14 | ``` 15 | 16 | --- 17 | ## Data Preparation 18 | 1) Please download and extract the required data into the `./data` folder. We include details about the download links, as well as what each file is used for, in each folder within `./data`. 19 | 20 | 2) **[Optional]** For the DeepFashion dataset, we partition seen/unseen classes and the training/testing split via: 21 | ``` 22 | python ./extract_feature/extract_annotation_DeepFashion.py #create ./data/DeepFashion/annotation.pkl 23 | ``` 24 | We have included the resulting file in the repository by default. Similarly, we have also included the attribute semantics from the GloVe model for all datasets, which are computed by: 25 | ``` 26 | python ./extract_feature/extract_attribute_w2v_DeepFashion.py #create ./w2v/DeepFashion_attribute.pkl 27 | python ./extract_feature/extract_attribute_w2v_AWA2.py #create ./w2v/AWA2_attribute.pkl 28 | python ./extract_feature/extract_attribute_w2v_CUB.py #create ./w2v/CUB_attribute.pkl 29 | python ./extract_feature/extract_attribute_w2v_SUN.py #create ./w2v/SUN_attribute.pkl 30 | ``` 31 | 32 | 3) Please run the feature extraction scripts in the `./extract_feature` folder to extract features from the last convolutional layer of ResNet-101 as region features for the attention mechanism: 33 | ``` 34 | python ./extract_feature/extract_feature_map_ResNet_101_DeepFashion.py #create ./data/DeepFashion/feature_map_ResNet_101_DeepFashion_sep_seen_samples.hdf5 35 | python ./extract_feature/extract_feature_map_ResNet_101_AWA2.py #create ./data/AWA2/feature_map_ResNet_101_AWA2.hdf5 36 | python ./extract_feature/extract_feature_map_ResNet_101_CUB.py #create ./data/CUB/feature_map_ResNet_101_CUB.hdf5 37 | python ./extract_feature/extract_feature_map_ResNet_101_SUN.py #create ./data/SUN/feature_map_ResNet_101_SUN.hdf5 38 | ``` 39 | These scripts create hdf5 files that contain the image features and the data splits used for training and evaluation. 40 | 41 | --- 42 | ## Training and Evaluation 43 | 1) We provide separate Jupyter notebooks for training and evaluation on all four datasets in the `./notebook` folder: 44 | ``` 45 | ./notebook/DAZLE_DeepFashion.ipynb 46 | ./notebook/DAZLE_AWA2.ipynb 47 | ./notebook/DAZLE_CUB.ipynb 48 | ./notebook/DAZLE_SUN.ipynb 49 | ``` 50 | 51 | ## Pretrained Models 52 | Since the training process is not resource-intensive, most experiments can be reproduced within 30 minutes.
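As a rough outline of what each notebook does (using CUB as an example), the sketch below builds the dataloader and the model and runs a few training steps. It assumes the repository root as the working directory and the extracted `feature_map_ResNet_101_CUB.hdf5` file from step 3 above; the optimizer settings, `lambda_`, model flags, batch size, and iteration count are illustrative placeholders rather than the exact configuration used in the paper:
```python
import torch
from core.CUBDataLoader import CUBDataLoader
from core.DAZLE import DAZLE

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
dataloader = CUBDataLoader('./', device, is_balance=True)  # reads ./data/CUB/feature_map_ResNet_101_CUB.hdf5

model = DAZLE(dim_f=2048, dim_v=300,                       # ResNet-101 feature dim / GloVe vector dim
              init_w2v_att=dataloader.w2v_att, att=dataloader.att,
              normalize_att=dataloader.normalize_att,
              seenclass=dataloader.seenclasses, unseenclass=dataloader.unseenclasses,
              lambda_=0.1,                                 # self-calibration weight (placeholder value)
              trainable_w2v=True, normalize_V=True, normalize_F=True,
              is_conservative=True, is_bias=True).to(device)
optimizer = torch.optim.RMSprop(model.parameters(), lr=1e-4,
                                weight_decay=1e-4, momentum=0.9)  # placeholder optimizer settings

for it in range(1000):
    batch_label, batch_feature, batch_att = dataloader.next_batch(50)
    out_package = model(batch_feature)          # forward pass returns a dict with class scores 'S_pp'
    out_package['batch_label'] = batch_label
    loss = model.compute_loss(out_package)['loss']
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
```
Evaluation in the notebooks uses the `test_seen` and `test_unseen` splits that the same dataloader loads from the hdf5 file.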
53 | 54 | + If you need the pretrained models, please reach out to me via huynh.dat@northeastern.edu 55 | 56 | --- 57 | ## Citation 58 | If this code is helpful for your research, we would appreciate if you cite the work: 59 | ``` 60 | @article{Huynh-DAZLE:CVPR20, 61 | author = {D.~Huynh and E.~Elhamifar}, 62 | title = {Fine-Grained Generalized Zero-Shot Learning via Dense Attribute-Based Attention}, 63 | journal = {{IEEE} Conference on Computer Vision and Pattern Recognition}, 64 | year = {2020}} 65 | ``` 66 | 67 | --- 68 | ## References 69 | We adapt our dataloader classes from the following project: 70 | https://github.com/edgarschnfld/CADA-VAE-PyTorch 71 | -------------------------------------------------------------------------------- /attribute/AWA2/new_des.csv: -------------------------------------------------------------------------------- 1 | ,idx,des,new_des 2 | 0,1,black,black 3 | 1,2,white,white 4 | 2,3,blue,blue 5 | 3,4,brown,brown 6 | 4,5,gray,gray 7 | 5,6,orange,orange 8 | 6,7,red,red 9 | 7,8,yellow,yellow 10 | 8,9,patches,patches 11 | 9,10,spots,spots 12 | 10,11,stripes,stripes 13 | 11,12,furry,furry 14 | 12,13,hairless,hairless 15 | 13,14,toughskin,toughskin 16 | 14,15,big,big 17 | 15,16,small,small 18 | 16,17,bulbous,bulbous 19 | 17,18,lean,lean 20 | 18,19,flippers,flippers 21 | 19,20,hands,hands 22 | 20,21,hooves,hooves 23 | 21,22,pads,pads 24 | 22,23,paws,paws 25 | 23,24,longleg,longleg 26 | 24,25,longneck,longneck 27 | 25,26,tail,tail 28 | 26,27,chewteeth,chewteeth 29 | 27,28,meatteeth,meatteeth 30 | 28,29,buckteeth,buckteeth 31 | 29,30,strainteeth,strainteeth 32 | 30,31,horns,horns 33 | 31,32,claws,claws 34 | 32,33,tusks,tusks 35 | 33,34,smelly,smelly 36 | 34,35,flys,flys 37 | 35,36,hops,hops 38 | 36,37,swims,swims 39 | 37,38,tunnels,tunnels 40 | 38,39,walks,walks 41 | 39,40,fast,fast 42 | 40,41,slow,slow 43 | 41,42,strong,strong 44 | 42,43,weak,weak 45 | 43,44,muscle,muscle 46 | 44,45,bipedal,bipedal 47 | 45,46,quadrapedal,quadrapedal 48 | 46,47,active,active 49 | 47,48,inactive,inactive 50 | 48,49,nocturnal,nocturnal 51 | 49,50,hibernate,hibernate 52 | 50,51,agility,agility 53 | 51,52,fish,fish 54 | 52,53,meat,meat 55 | 53,54,plankton,plankton 56 | 54,55,vegetation,vegetation 57 | 55,56,insects,insects 58 | 56,57,forager,forager 59 | 57,58,grazer,grazer 60 | 58,59,hunter,hunter 61 | 59,60,scavenger,scavenger 62 | 60,61,skimmer,skimmer 63 | 61,62,stalker,stalker 64 | 62,63,newworld,newworld 65 | 63,64,oldworld,oldworld 66 | 64,65,arctic,arctic 67 | 65,66,coastal,coastal 68 | 66,67,desert,desert 69 | 67,68,bush,bush 70 | 68,69,plains,plains 71 | 69,70,forest,forest 72 | 70,71,fields,fields 73 | 71,72,jungle,jungle 74 | 72,73,mountains,mountains 75 | 73,74,ocean,ocean 76 | 74,75,ground,ground 77 | 75,76,water,water 78 | 76,77,tree,tree 79 | 77,78,cave,cave 80 | 78,79,fierce,fierce 81 | 79,80,timid,timid 82 | 80,81,smart,smart 83 | 81,82,group,group 84 | 82,83,solitary,solitary 85 | 83,84,nestspot,nestspot 86 | 84,85,domestic,domestic 87 | -------------------------------------------------------------------------------- /attribute/AWA2/predicates.txt: -------------------------------------------------------------------------------- 1 | 1 black 2 | 2 white 3 | 3 blue 4 | 4 brown 5 | 5 gray 6 | 6 orange 7 | 7 red 8 | 8 yellow 9 | 9 patches 10 | 10 spots 11 | 11 stripes 12 | 12 furry 13 | 13 hairless 14 | 14 toughskin 15 | 15 big 16 | 16 small 17 | 17 bulbous 18 | 18 lean 19 | 19 flippers 20 | 20 hands 21 | 21 hooves 22 | 22 pads 23 | 23 paws 24 | 24 longleg 25 | 25 longneck 
26 | 26 tail 27 | 27 chewteeth 28 | 28 meatteeth 29 | 29 buckteeth 30 | 30 strainteeth 31 | 31 horns 32 | 32 claws 33 | 33 tusks 34 | 34 smelly 35 | 35 flys 36 | 36 hops 37 | 37 swims 38 | 38 tunnels 39 | 39 walks 40 | 40 fast 41 | 41 slow 42 | 42 strong 43 | 43 weak 44 | 44 muscle 45 | 45 bipedal 46 | 46 quadrapedal 47 | 47 active 48 | 48 inactive 49 | 49 nocturnal 50 | 50 hibernate 51 | 51 agility 52 | 52 fish 53 | 53 meat 54 | 54 plankton 55 | 55 vegetation 56 | 56 insects 57 | 57 forager 58 | 58 grazer 59 | 59 hunter 60 | 60 scavenger 61 | 61 skimmer 62 | 62 stalker 63 | 63 newworld 64 | 64 oldworld 65 | 65 arctic 66 | 66 coastal 67 | 67 desert 68 | 68 bush 69 | 69 plains 70 | 70 forest 71 | 71 fields 72 | 72 jungle 73 | 73 mountains 74 | 74 ocean 75 | 75 ground 76 | 76 water 77 | 77 tree 78 | 78 cave 79 | 79 fierce 80 | 80 timid 81 | 81 smart 82 | 82 group 83 | 83 solitary 84 | 84 nestspot 85 | 85 domestic 86 | -------------------------------------------------------------------------------- /attribute/CUB/attributes.txt: -------------------------------------------------------------------------------- 1 | 1 has_bill_shape::curved_(up_or_down) 2 | 2 has_bill_shape::dagger 3 | 3 has_bill_shape::hooked 4 | 4 has_bill_shape::needle 5 | 5 has_bill_shape::hooked_seabird 6 | 6 has_bill_shape::spatulate 7 | 7 has_bill_shape::all-purpose 8 | 8 has_bill_shape::cone 9 | 9 has_bill_shape::specialized 10 | 10 has_wing_color::blue 11 | 11 has_wing_color::brown 12 | 12 has_wing_color::iridescent 13 | 13 has_wing_color::purple 14 | 14 has_wing_color::rufous 15 | 15 has_wing_color::grey 16 | 16 has_wing_color::yellow 17 | 17 has_wing_color::olive 18 | 18 has_wing_color::green 19 | 19 has_wing_color::pink 20 | 20 has_wing_color::orange 21 | 21 has_wing_color::black 22 | 22 has_wing_color::white 23 | 23 has_wing_color::red 24 | 24 has_wing_color::buff 25 | 25 has_upperparts_color::blue 26 | 26 has_upperparts_color::brown 27 | 27 has_upperparts_color::iridescent 28 | 28 has_upperparts_color::purple 29 | 29 has_upperparts_color::rufous 30 | 30 has_upperparts_color::grey 31 | 31 has_upperparts_color::yellow 32 | 32 has_upperparts_color::olive 33 | 33 has_upperparts_color::green 34 | 34 has_upperparts_color::pink 35 | 35 has_upperparts_color::orange 36 | 36 has_upperparts_color::black 37 | 37 has_upperparts_color::white 38 | 38 has_upperparts_color::red 39 | 39 has_upperparts_color::buff 40 | 40 has_underparts_color::blue 41 | 41 has_underparts_color::brown 42 | 42 has_underparts_color::iridescent 43 | 43 has_underparts_color::purple 44 | 44 has_underparts_color::rufous 45 | 45 has_underparts_color::grey 46 | 46 has_underparts_color::yellow 47 | 47 has_underparts_color::olive 48 | 48 has_underparts_color::green 49 | 49 has_underparts_color::pink 50 | 50 has_underparts_color::orange 51 | 51 has_underparts_color::black 52 | 52 has_underparts_color::white 53 | 53 has_underparts_color::red 54 | 54 has_underparts_color::buff 55 | 55 has_breast_pattern::solid 56 | 56 has_breast_pattern::spotted 57 | 57 has_breast_pattern::striped 58 | 58 has_breast_pattern::multi-colored 59 | 59 has_back_color::blue 60 | 60 has_back_color::brown 61 | 61 has_back_color::iridescent 62 | 62 has_back_color::purple 63 | 63 has_back_color::rufous 64 | 64 has_back_color::grey 65 | 65 has_back_color::yellow 66 | 66 has_back_color::olive 67 | 67 has_back_color::green 68 | 68 has_back_color::pink 69 | 69 has_back_color::orange 70 | 70 has_back_color::black 71 | 71 has_back_color::white 72 | 72 has_back_color::red 73 | 73 
has_back_color::buff 74 | 74 has_tail_shape::forked_tail 75 | 75 has_tail_shape::rounded_tail 76 | 76 has_tail_shape::notched_tail 77 | 77 has_tail_shape::fan-shaped_tail 78 | 78 has_tail_shape::pointed_tail 79 | 79 has_tail_shape::squared_tail 80 | 80 has_upper_tail_color::blue 81 | 81 has_upper_tail_color::brown 82 | 82 has_upper_tail_color::iridescent 83 | 83 has_upper_tail_color::purple 84 | 84 has_upper_tail_color::rufous 85 | 85 has_upper_tail_color::grey 86 | 86 has_upper_tail_color::yellow 87 | 87 has_upper_tail_color::olive 88 | 88 has_upper_tail_color::green 89 | 89 has_upper_tail_color::pink 90 | 90 has_upper_tail_color::orange 91 | 91 has_upper_tail_color::black 92 | 92 has_upper_tail_color::white 93 | 93 has_upper_tail_color::red 94 | 94 has_upper_tail_color::buff 95 | 95 has_head_pattern::spotted 96 | 96 has_head_pattern::malar 97 | 97 has_head_pattern::crested 98 | 98 has_head_pattern::masked 99 | 99 has_head_pattern::unique_pattern 100 | 100 has_head_pattern::eyebrow 101 | 101 has_head_pattern::eyering 102 | 102 has_head_pattern::plain 103 | 103 has_head_pattern::eyeline 104 | 104 has_head_pattern::striped 105 | 105 has_head_pattern::capped 106 | 106 has_breast_color::blue 107 | 107 has_breast_color::brown 108 | 108 has_breast_color::iridescent 109 | 109 has_breast_color::purple 110 | 110 has_breast_color::rufous 111 | 111 has_breast_color::grey 112 | 112 has_breast_color::yellow 113 | 113 has_breast_color::olive 114 | 114 has_breast_color::green 115 | 115 has_breast_color::pink 116 | 116 has_breast_color::orange 117 | 117 has_breast_color::black 118 | 118 has_breast_color::white 119 | 119 has_breast_color::red 120 | 120 has_breast_color::buff 121 | 121 has_throat_color::blue 122 | 122 has_throat_color::brown 123 | 123 has_throat_color::iridescent 124 | 124 has_throat_color::purple 125 | 125 has_throat_color::rufous 126 | 126 has_throat_color::grey 127 | 127 has_throat_color::yellow 128 | 128 has_throat_color::olive 129 | 129 has_throat_color::green 130 | 130 has_throat_color::pink 131 | 131 has_throat_color::orange 132 | 132 has_throat_color::black 133 | 133 has_throat_color::white 134 | 134 has_throat_color::red 135 | 135 has_throat_color::buff 136 | 136 has_eye_color::blue 137 | 137 has_eye_color::brown 138 | 138 has_eye_color::purple 139 | 139 has_eye_color::rufous 140 | 140 has_eye_color::grey 141 | 141 has_eye_color::yellow 142 | 142 has_eye_color::olive 143 | 143 has_eye_color::green 144 | 144 has_eye_color::pink 145 | 145 has_eye_color::orange 146 | 146 has_eye_color::black 147 | 147 has_eye_color::white 148 | 148 has_eye_color::red 149 | 149 has_eye_color::buff 150 | 150 has_bill_length::about_the_same_as_head 151 | 151 has_bill_length::longer_than_head 152 | 152 has_bill_length::shorter_than_head 153 | 153 has_forehead_color::blue 154 | 154 has_forehead_color::brown 155 | 155 has_forehead_color::iridescent 156 | 156 has_forehead_color::purple 157 | 157 has_forehead_color::rufous 158 | 158 has_forehead_color::grey 159 | 159 has_forehead_color::yellow 160 | 160 has_forehead_color::olive 161 | 161 has_forehead_color::green 162 | 162 has_forehead_color::pink 163 | 163 has_forehead_color::orange 164 | 164 has_forehead_color::black 165 | 165 has_forehead_color::white 166 | 166 has_forehead_color::red 167 | 167 has_forehead_color::buff 168 | 168 has_under_tail_color::blue 169 | 169 has_under_tail_color::brown 170 | 170 has_under_tail_color::iridescent 171 | 171 has_under_tail_color::purple 172 | 172 has_under_tail_color::rufous 173 | 173 has_under_tail_color::grey 174 | 
174 has_under_tail_color::yellow 175 | 175 has_under_tail_color::olive 176 | 176 has_under_tail_color::green 177 | 177 has_under_tail_color::pink 178 | 178 has_under_tail_color::orange 179 | 179 has_under_tail_color::black 180 | 180 has_under_tail_color::white 181 | 181 has_under_tail_color::red 182 | 182 has_under_tail_color::buff 183 | 183 has_nape_color::blue 184 | 184 has_nape_color::brown 185 | 185 has_nape_color::iridescent 186 | 186 has_nape_color::purple 187 | 187 has_nape_color::rufous 188 | 188 has_nape_color::grey 189 | 189 has_nape_color::yellow 190 | 190 has_nape_color::olive 191 | 191 has_nape_color::green 192 | 192 has_nape_color::pink 193 | 193 has_nape_color::orange 194 | 194 has_nape_color::black 195 | 195 has_nape_color::white 196 | 196 has_nape_color::red 197 | 197 has_nape_color::buff 198 | 198 has_belly_color::blue 199 | 199 has_belly_color::brown 200 | 200 has_belly_color::iridescent 201 | 201 has_belly_color::purple 202 | 202 has_belly_color::rufous 203 | 203 has_belly_color::grey 204 | 204 has_belly_color::yellow 205 | 205 has_belly_color::olive 206 | 206 has_belly_color::green 207 | 207 has_belly_color::pink 208 | 208 has_belly_color::orange 209 | 209 has_belly_color::black 210 | 210 has_belly_color::white 211 | 211 has_belly_color::red 212 | 212 has_belly_color::buff 213 | 213 has_wing_shape::rounded-wings 214 | 214 has_wing_shape::pointed-wings 215 | 215 has_wing_shape::broad-wings 216 | 216 has_wing_shape::tapered-wings 217 | 217 has_wing_shape::long-wings 218 | 218 has_size::large_(16_-_32_in) 219 | 219 has_size::small_(5_-_9_in) 220 | 220 has_size::very_large_(32_-_72_in) 221 | 221 has_size::medium_(9_-_16_in) 222 | 222 has_size::very_small_(3_-_5_in) 223 | 223 has_shape::upright-perching_water-like 224 | 224 has_shape::chicken-like-marsh 225 | 225 has_shape::long-legged-like 226 | 226 has_shape::duck-like 227 | 227 has_shape::owl-like 228 | 228 has_shape::gull-like 229 | 229 has_shape::hummingbird-like 230 | 230 has_shape::pigeon-like 231 | 231 has_shape::tree-clinging-like 232 | 232 has_shape::hawk-like 233 | 233 has_shape::sandpiper-like 234 | 234 has_shape::upland-ground-like 235 | 235 has_shape::swallow-like 236 | 236 has_shape::perching-like 237 | 237 has_back_pattern::solid 238 | 238 has_back_pattern::spotted 239 | 239 has_back_pattern::striped 240 | 240 has_back_pattern::multi-colored 241 | 241 has_tail_pattern::solid 242 | 242 has_tail_pattern::spotted 243 | 243 has_tail_pattern::striped 244 | 244 has_tail_pattern::multi-colored 245 | 245 has_belly_pattern::solid 246 | 246 has_belly_pattern::spotted 247 | 247 has_belly_pattern::striped 248 | 248 has_belly_pattern::multi-colored 249 | 249 has_primary_color::blue 250 | 250 has_primary_color::brown 251 | 251 has_primary_color::iridescent 252 | 252 has_primary_color::purple 253 | 253 has_primary_color::rufous 254 | 254 has_primary_color::grey 255 | 255 has_primary_color::yellow 256 | 256 has_primary_color::olive 257 | 257 has_primary_color::green 258 | 258 has_primary_color::pink 259 | 259 has_primary_color::orange 260 | 260 has_primary_color::black 261 | 261 has_primary_color::white 262 | 262 has_primary_color::red 263 | 263 has_primary_color::buff 264 | 264 has_leg_color::blue 265 | 265 has_leg_color::brown 266 | 266 has_leg_color::iridescent 267 | 267 has_leg_color::purple 268 | 268 has_leg_color::rufous 269 | 269 has_leg_color::grey 270 | 270 has_leg_color::yellow 271 | 271 has_leg_color::olive 272 | 272 has_leg_color::green 273 | 273 has_leg_color::pink 274 | 274 has_leg_color::orange 275 | 275 
has_leg_color::black 276 | 276 has_leg_color::white 277 | 277 has_leg_color::red 278 | 278 has_leg_color::buff 279 | 279 has_bill_color::blue 280 | 280 has_bill_color::brown 281 | 281 has_bill_color::iridescent 282 | 282 has_bill_color::purple 283 | 283 has_bill_color::rufous 284 | 284 has_bill_color::grey 285 | 285 has_bill_color::yellow 286 | 286 has_bill_color::olive 287 | 287 has_bill_color::green 288 | 288 has_bill_color::pink 289 | 289 has_bill_color::orange 290 | 290 has_bill_color::black 291 | 291 has_bill_color::white 292 | 292 has_bill_color::red 293 | 293 has_bill_color::buff 294 | 294 has_crown_color::blue 295 | 295 has_crown_color::brown 296 | 296 has_crown_color::iridescent 297 | 297 has_crown_color::purple 298 | 298 has_crown_color::rufous 299 | 299 has_crown_color::grey 300 | 300 has_crown_color::yellow 301 | 301 has_crown_color::olive 302 | 302 has_crown_color::green 303 | 303 has_crown_color::pink 304 | 304 has_crown_color::orange 305 | 305 has_crown_color::black 306 | 306 has_crown_color::white 307 | 307 has_crown_color::red 308 | 308 has_crown_color::buff 309 | 309 has_wing_pattern::solid 310 | 310 has_wing_pattern::spotted 311 | 311 has_wing_pattern::striped 312 | 312 has_wing_pattern::multi-colored 313 | -------------------------------------------------------------------------------- /attribute/SUN/attributes.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hbdat/cvpr20_DAZLE/e5a04601d8368903008fd96dee3d95dde398aa51/attribute/SUN/attributes.mat -------------------------------------------------------------------------------- /attribute/SUN/new_des.csv: -------------------------------------------------------------------------------- 1 | ,new_des 2 | 0,sailing boating 3 | 1,driving 4 | 2,biking 5 | 3,transporting things or people 6 | 4,sunbathing 7 | 5,vacationing touring 8 | 6,hiking 9 | 7,climbing 10 | 8,camping 11 | 9,reading 12 | 10,studying learning 13 | 11,teaching training 14 | 12,research 15 | 13,diving 16 | 14,swimming 17 | 15,bathing 18 | 16,eating 19 | 17,cleaning 20 | 18,socializing 21 | 19,congregating 22 | 20,waiting in line queuing 23 | 21,competing 24 | 22,sports 25 | 23,exercise 26 | 24,playing 27 | 25,gaming 28 | 26,spectating being in an audience 29 | 27,farming 30 | 28,constructing building 31 | 29,shopping 32 | 30,medical activity 33 | 31,working 34 | 32,using tools 35 | 33,digging 36 | 34,conducting business 37 | 35,praying 38 | 36,fencing 39 | 37,railing 40 | 38,wire 41 | 39,railroad 42 | 40,trees 43 | 41,grass 44 | 42,vegetation 45 | 43,shrubbery 46 | 44,foliage 47 | 45,leaves 48 | 46,flowers 49 | 47,asphalt 50 | 48,pavement 51 | 49,shingles 52 | 50,carpet 53 | 51,brick 54 | 52,tiles 55 | 53,concrete 56 | 54,metal 57 | 55,paper 58 | 56,wood (not part of a tree) 59 | 57,vinyl linoleum 60 | 58,rubber plastic 61 | 59,cloth 62 | 60,sand 63 | 61,rock stone 64 | 62,dirt soil 65 | 63,marble 66 | 64,glass 67 | 65,waves surf 68 | 66,ocean 69 | 67,running water 70 | 68,still water 71 | 69,ice 72 | 70,snow 73 | 71,clouds 74 | 72,smoke 75 | 73,fire 76 | 74,natural light 77 | 75,direct sun sunny 78 | 76,electric indoor lighting 79 | 77,aged worn 80 | 78,glossy 81 | 79,matte 82 | 80,sterile 83 | 81,moist damp 84 | 82,dry 85 | 83,dirty 86 | 84,rusty 87 | 85,warm 88 | 86,cold 89 | 87,natural 90 | 88,man-made 91 | 89,open area 92 | 90,semi enclosed area 93 | 91,enclosed area 94 | 92,faraway horizon 95 | 93,no horizon 96 | 94,rugged scene 97 | 95,mostly vertical components 98 | 96,mostly 
horizontal components 99 | 97,symmetrical 100 | 98,cluttered space 101 | 99,scary 102 | 100,soothing 103 | 101,stressful 104 | -------------------------------------------------------------------------------- /core/AWA2DataLoader.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Created on Sat Jul 20 21:23:18 2019 4 | 5 | @author: badat 6 | """ 7 | 8 | import os,sys 9 | #import scipy.io as sio 10 | import torch 11 | import numpy as np 12 | import h5py 13 | import time 14 | import pickle 15 | from sklearn import preprocessing 16 | from global_setting import NFS_path 17 | #%% 18 | import scipy.io as sio 19 | import pandas as pd 20 | #%% 21 | import pdb 22 | #%% 23 | dataset = 'AWA2' 24 | img_dir = os.path.join(NFS_path,'data/{}/'.format(dataset)) 25 | mat_path = os.path.join(NFS_path,'data/xlsa17/data/{}/res101.mat'.format(dataset)) 26 | attr_path = './attribute/{}/new_des.csv'.format(dataset) 27 | 28 | 29 | class AWA2DataLoader(): 30 | def __init__(self, data_path, device, is_scale = False,is_balance =True): 31 | 32 | print(data_path) 33 | sys.path.append(data_path) 34 | 35 | self.data_path = data_path 36 | self.device = device 37 | self.dataset = 'AWA2' 38 | print('$'*30) 39 | print(self.dataset) 40 | print('$'*30) 41 | self.datadir = self.data_path + 'data/{}/'.format(self.dataset) 42 | self.index_in_epoch = 0 43 | self.epochs_completed = 0 44 | self.is_scale = is_scale 45 | self.is_balance = is_balance 46 | if self.is_balance: 47 | print('Balance dataloader') 48 | self.read_matdataset() 49 | self.get_idx_classes() 50 | 51 | 52 | def augment_img_path(self,mat_path=mat_path,img_dir=img_dir): 53 | self.matcontent = sio.loadmat(mat_path) 54 | self.image_files = np.squeeze(self.matcontent['image_files']) 55 | 56 | def convert_path(image_files,img_dir): 57 | new_image_files = [] 58 | for idx in range(len(image_files)): 59 | image_file = image_files[idx][0] 60 | image_file = os.path.join(img_dir,'/'.join(image_file.split('/')[5:])) 61 | new_image_files.append(image_file) 62 | return np.array(new_image_files) 63 | 64 | self.image_files = convert_path(self.image_files,img_dir) 65 | 66 | path= self.datadir + 'feature_map_ResNet_101_{}.hdf5'.format(self.dataset) 67 | hf = h5py.File(path, 'r') 68 | 69 | trainval_loc = np.array(hf.get('trainval_loc')) 70 | test_seen_loc = np.array(hf.get('test_seen_loc')) 71 | test_unseen_loc = np.array(hf.get('test_unseen_loc')) 72 | 73 | self.data['train_seen']['img_path'] = self.image_files[trainval_loc] 74 | self.data['test_seen']['img_path'] = self.image_files[test_seen_loc] 75 | self.data['test_unseen']['img_path'] = self.image_files[test_unseen_loc] 76 | 77 | self.attr_name = pd.read_csv(attr_path)['new_des'] 78 | 79 | 80 | def next_batch_img(self, batch_size,class_id,is_trainset = False): 81 | features = None 82 | labels = None 83 | img_files = None 84 | if class_id in self.seenclasses: 85 | if is_trainset: 86 | features = self.data['train_seen']['resnet_features'] 87 | labels = self.data['train_seen']['labels'] 88 | img_files = self.data['train_seen']['img_path'] 89 | else: 90 | features = self.data['test_seen']['resnet_features'] 91 | labels = self.data['test_seen']['labels'] 92 | img_files = self.data['test_seen']['img_path'] 93 | elif class_id in self.unseenclasses: 94 | features = self.data['test_unseen']['resnet_features'] 95 | labels = self.data['test_unseen']['labels'] 96 | img_files = self.data['test_unseen']['img_path'] 97 | else: 98 | raise Exception("Cannot find this class 
{}".format(class_id)) 99 | 100 | #note that img_files is numpy type !!!!! 101 | 102 | idx_c = torch.squeeze(torch.nonzero(labels == class_id)) 103 | 104 | features = features[idx_c] 105 | labels = labels[idx_c] 106 | img_files = img_files[idx_c.cpu().numpy()] 107 | 108 | batch_label = labels[:batch_size].to(self.device) 109 | batch_feature = features[:batch_size].to(self.device) 110 | batch_files = img_files[:batch_size] 111 | batch_att = self.att[batch_label].to(self.device) 112 | 113 | return batch_label, batch_feature,batch_files, batch_att 114 | 115 | def next_batch(self, batch_size): 116 | if self.is_balance: 117 | idx = [] 118 | n_samples_class = max(batch_size //self.ntrain_class,1) 119 | sampled_idx_c = np.random.choice(np.arange(self.ntrain_class),min(self.ntrain_class,batch_size),replace=False).tolist() 120 | for i_c in sampled_idx_c: 121 | idxs = self.idxs_list[i_c] 122 | idx.append(np.random.choice(idxs,n_samples_class)) 123 | idx = np.concatenate(idx) 124 | idx = torch.from_numpy(idx) 125 | else: 126 | idx = torch.randperm(self.ntrain)[0:batch_size] 127 | 128 | batch_feature = self.data['train_seen']['resnet_features'][idx].to(self.device) 129 | batch_label = self.data['train_seen']['labels'][idx].to(self.device) 130 | batch_att = self.att[batch_label].to(self.device) 131 | return batch_label, batch_feature, batch_att 132 | 133 | def get_idx_classes(self): 134 | n_classes = self.seenclasses.size(0) 135 | self.idxs_list = [] 136 | train_label = self.data['train_seen']['labels'] 137 | for i in range(n_classes): 138 | idx_c = torch.nonzero(train_label == self.seenclasses[i].cpu()).cpu().numpy() 139 | idx_c = np.squeeze(idx_c) 140 | self.idxs_list.append(idx_c) 141 | return self.idxs_list 142 | 143 | def read_matdataset(self): 144 | 145 | path= self.datadir + 'feature_map_ResNet_101_{}.hdf5'.format(self.dataset) 146 | print('_____') 147 | print(path) 148 | tic = time.clock() 149 | hf = h5py.File(path, 'r') 150 | features = np.array(hf.get('feature_map')) 151 | # shape = features.shape 152 | # features = features.reshape(shape[0],shape[1],shape[2]*shape[3]) 153 | labels = np.array(hf.get('labels')) 154 | trainval_loc = np.array(hf.get('trainval_loc')) 155 | # train_loc = np.array(hf.get('train_loc')) #--> train_feature = TRAIN SEEN 156 | # val_unseen_loc = np.array(hf.get('val_unseen_loc')) #--> test_unseen_feature = TEST UNSEEN 157 | test_seen_loc = np.array(hf.get('test_seen_loc')) 158 | test_unseen_loc = np.array(hf.get('test_unseen_loc')) 159 | 160 | 161 | print('Expert Attr') 162 | att = np.array(hf.get('att')) 163 | 164 | print("threshold at zero attribute with negative value") 165 | att[att<0]=0 166 | 167 | self.att = torch.from_numpy(att).float().to(self.device) 168 | 169 | original_att = np.array(hf.get('original_att')) 170 | self.original_att = torch.from_numpy(original_att).float().to(self.device) 171 | 172 | w2v_att = np.array(hf.get('w2v_att')) 173 | self.w2v_att = torch.from_numpy(w2v_att).float().to(self.device) 174 | 175 | self.normalize_att = self.original_att/100 176 | 177 | print('Finish loading data in ',time.clock()-tic) 178 | 179 | train_feature = features[trainval_loc] 180 | test_seen_feature = features[test_seen_loc] 181 | test_unseen_feature = features[test_unseen_loc] 182 | if self.is_scale: 183 | scaler = preprocessing.MinMaxScaler() 184 | 185 | train_feature = scaler.fit_transform(train_feature) 186 | test_seen_feature = scaler.fit_transform(test_seen_feature) 187 | test_unseen_feature = scaler.fit_transform(test_unseen_feature) 188 | 189 | train_feature 
= torch.from_numpy(train_feature).float() #.to(self.device) 190 | test_seen_feature = torch.from_numpy(test_seen_feature) #.float().to(self.device) 191 | test_unseen_feature = torch.from_numpy(test_unseen_feature) #.float().to(self.device) 192 | 193 | train_label = torch.from_numpy(labels[trainval_loc]).long() #.to(self.device) 194 | test_unseen_label = torch.from_numpy(labels[test_unseen_loc]) #.long().to(self.device) 195 | test_seen_label = torch.from_numpy(labels[test_seen_loc]) #.long().to(self.device) 196 | 197 | self.seenclasses = torch.from_numpy(np.unique(train_label.cpu().numpy())).to(self.device) 198 | 199 | 200 | 201 | self.unseenclasses = torch.from_numpy(np.unique(test_unseen_label.cpu().numpy())).to(self.device) 202 | self.ntrain = train_feature.size()[0] 203 | self.ntrain_class = self.seenclasses.size(0) 204 | self.ntest_class = self.unseenclasses.size(0) 205 | self.train_class = self.seenclasses.clone() 206 | self.allclasses = torch.arange(0, self.ntrain_class+self.ntest_class).long() 207 | 208 | # self.train_mapped_label = map_label(train_label, self.seenclasses) 209 | 210 | self.data = {} 211 | self.data['train_seen'] = {} 212 | self.data['train_seen']['resnet_features'] = train_feature 213 | self.data['train_seen']['labels']= train_label 214 | 215 | 216 | self.data['train_unseen'] = {} 217 | self.data['train_unseen']['resnet_features'] = None 218 | self.data['train_unseen']['labels'] = None 219 | 220 | self.data['test_seen'] = {} 221 | self.data['test_seen']['resnet_features'] = test_seen_feature 222 | self.data['test_seen']['labels'] = test_seen_label 223 | 224 | self.data['test_unseen'] = {} 225 | self.data['test_unseen']['resnet_features'] = test_unseen_feature 226 | self.data['test_unseen']['labels'] = test_unseen_label 227 | -------------------------------------------------------------------------------- /core/CUBDataLoader.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Created on Thu Jul 4 11:53:09 2019 4 | 5 | @author: badat 6 | """ 7 | import os,sys 8 | #import scipy.io as sio 9 | import torch 10 | import numpy as np 11 | import h5py 12 | import time 13 | import pickle 14 | import pdb 15 | from sklearn import preprocessing 16 | from global_setting import NFS_path 17 | #%% 18 | import scipy.io as sio 19 | import pandas as pd 20 | #%% 21 | #import pdb 22 | #%% 23 | 24 | img_dir = os.path.join(NFS_path,'data/CUB/') 25 | mat_path = os.path.join(NFS_path,'data/xlsa17/data/CUB/res101.mat') 26 | attr_path = './attribute/CUB/new_des.csv' 27 | 28 | class CUBDataLoader(): 29 | def __init__(self, data_path, device, is_scale = False, is_balance=True): 30 | 31 | print(data_path) 32 | sys.path.append(data_path) 33 | 34 | self.data_path = data_path 35 | self.device = device 36 | self.dataset = 'CUB' 37 | print('$'*30) 38 | print(self.dataset) 39 | print('$'*30) 40 | self.datadir = self.data_path + 'data/{}/'.format(self.dataset) 41 | self.index_in_epoch = 0 42 | self.epochs_completed = 0 43 | self.is_scale = is_scale 44 | self.is_balance = is_balance 45 | if self.is_balance: 46 | print('Balance dataloader') 47 | self.read_matdataset() 48 | self.get_idx_classes() 49 | 50 | def augment_img_path(self,mat_path=mat_path,img_dir=img_dir): 51 | self.matcontent = sio.loadmat(mat_path) 52 | self.image_files = np.squeeze(self.matcontent['image_files']) 53 | 54 | def convert_path(image_files,img_dir): 55 | new_image_files = [] 56 | for idx in range(len(image_files)): 57 | image_file = image_files[idx][0] 58 | 
image_file = os.path.join(img_dir,'/'.join(image_file.split('/')[5:])) 59 | new_image_files.append(image_file) 60 | return np.array(new_image_files) 61 | 62 | self.image_files = convert_path(self.image_files,img_dir) 63 | 64 | path= self.datadir + 'feature_map_ResNet_101_{}.hdf5'.format(self.dataset) 65 | hf = h5py.File(path, 'r') 66 | 67 | trainval_loc = np.array(hf.get('trainval_loc')) 68 | test_seen_loc = np.array(hf.get('test_seen_loc')) 69 | test_unseen_loc = np.array(hf.get('test_unseen_loc')) 70 | 71 | self.data['train_seen']['img_path'] = self.image_files[trainval_loc] 72 | self.data['test_seen']['img_path'] = self.image_files[test_seen_loc] 73 | self.data['test_unseen']['img_path'] = self.image_files[test_unseen_loc] 74 | 75 | self.attr_name = pd.read_csv(attr_path)['new_des'] 76 | 77 | 78 | 79 | def next_batch_img(self, batch_size,class_id,is_trainset = False): 80 | features = None 81 | labels = None 82 | img_files = None 83 | if class_id in self.seenclasses: 84 | if is_trainset: 85 | features = self.data['train_seen']['resnet_features'] 86 | labels = self.data['train_seen']['labels'] 87 | img_files = self.data['train_seen']['img_path'] 88 | else: 89 | features = self.data['test_seen']['resnet_features'] 90 | labels = self.data['test_seen']['labels'] 91 | img_files = self.data['test_seen']['img_path'] 92 | elif class_id in self.unseenclasses: 93 | features = self.data['test_unseen']['resnet_features'] 94 | labels = self.data['test_unseen']['labels'] 95 | img_files = self.data['test_unseen']['img_path'] 96 | else: 97 | raise Exception("Cannot find this class {}".format(class_id)) 98 | 99 | #note that img_files is numpy type !!!!! 100 | 101 | idx_c = torch.squeeze(torch.nonzero(labels == class_id)) 102 | 103 | features = features[idx_c] 104 | labels = labels[idx_c] 105 | img_files = img_files[idx_c.cpu().numpy()] 106 | 107 | batch_label = labels[:batch_size].to(self.device) 108 | batch_feature = features[:batch_size].to(self.device) 109 | batch_files = img_files[:batch_size] 110 | batch_att = self.att[batch_label].to(self.device) 111 | 112 | return batch_label, batch_feature,batch_files, batch_att 113 | 114 | 115 | def next_batch(self, batch_size): 116 | if self.is_balance: 117 | idx = [] 118 | n_samples_class = max(batch_size //self.ntrain_class,1) 119 | sampled_idx_c = np.random.choice(np.arange(self.ntrain_class),min(self.ntrain_class,batch_size),replace=False).tolist() 120 | for i_c in sampled_idx_c: 121 | idxs = self.idxs_list[i_c] 122 | idx.append(np.random.choice(idxs,n_samples_class)) 123 | idx = np.concatenate(idx) 124 | idx = torch.from_numpy(idx) 125 | else: 126 | idx = torch.randperm(self.ntrain)[0:batch_size] 127 | 128 | batch_feature = self.data['train_seen']['resnet_features'][idx].to(self.device) 129 | batch_label = self.data['train_seen']['labels'][idx].to(self.device) 130 | batch_att = self.att[batch_label].to(self.device) 131 | return batch_label, batch_feature, batch_att 132 | 133 | def get_idx_classes(self): 134 | n_classes = self.seenclasses.size(0) 135 | self.idxs_list = [] 136 | train_label = self.data['train_seen']['labels'] 137 | for i in range(n_classes): 138 | idx_c = torch.nonzero(train_label == self.seenclasses[i].cpu()).cpu().numpy() 139 | idx_c = np.squeeze(idx_c) 140 | self.idxs_list.append(idx_c) 141 | return self.idxs_list 142 | 143 | 144 | def read_matdataset(self): 145 | 146 | path= self.datadir + 'feature_map_ResNet_101_{}.hdf5'.format(self.dataset) 147 | print('_____') 148 | print(path) 149 | tic = time.clock() 150 | hf = h5py.File(path, 'r') 
151 | features = np.array(hf.get('feature_map')) 152 | # shape = features.shape 153 | # features = features.reshape(shape[0],shape[1],shape[2]*shape[3]) 154 | # pdb.set_trace() 155 | labels = np.array(hf.get('labels')) 156 | trainval_loc = np.array(hf.get('trainval_loc')) 157 | # train_loc = np.array(hf.get('train_loc')) #--> train_feature = TRAIN SEEN 158 | # val_unseen_loc = np.array(hf.get('val_unseen_loc')) #--> test_unseen_feature = TEST UNSEEN 159 | test_seen_loc = np.array(hf.get('test_seen_loc')) 160 | test_unseen_loc = np.array(hf.get('test_unseen_loc')) 161 | 162 | 163 | print('Expert Attr') 164 | att = np.array(hf.get('att')) 165 | self.att = torch.from_numpy(att).float().to(self.device) 166 | 167 | original_att = np.array(hf.get('original_att')) 168 | self.original_att = torch.from_numpy(original_att).float().to(self.device) 169 | 170 | w2v_att = np.array(hf.get('w2v_att')) 171 | self.w2v_att = torch.from_numpy(w2v_att).float().to(self.device) 172 | 173 | self.normalize_att = self.original_att/100 174 | 175 | print('Finish loading data in ',time.clock()-tic) 176 | 177 | train_feature = features[trainval_loc] 178 | test_seen_feature = features[test_seen_loc] 179 | test_unseen_feature = features[test_unseen_loc] 180 | if self.is_scale: 181 | scaler = preprocessing.MinMaxScaler() 182 | 183 | train_feature = scaler.fit_transform(train_feature) 184 | test_seen_feature = scaler.fit_transform(test_seen_feature) 185 | test_unseen_feature = scaler.fit_transform(test_unseen_feature) 186 | 187 | train_feature = torch.from_numpy(train_feature).float() #.to(self.device) 188 | test_seen_feature = torch.from_numpy(test_seen_feature) #.float().to(self.device) 189 | test_unseen_feature = torch.from_numpy(test_unseen_feature) #.float().to(self.device) 190 | 191 | train_label = torch.from_numpy(labels[trainval_loc]).long() #.to(self.device) 192 | test_unseen_label = torch.from_numpy(labels[test_unseen_loc]) #.long().to(self.device) 193 | test_seen_label = torch.from_numpy(labels[test_seen_loc]) #.long().to(self.device) 194 | 195 | self.seenclasses = torch.from_numpy(np.unique(train_label.cpu().numpy())).to(self.device) 196 | self.unseenclasses = torch.from_numpy(np.unique(test_unseen_label.cpu().numpy())).to(self.device) 197 | self.ntrain = train_feature.size()[0] 198 | self.ntrain_class = self.seenclasses.size(0) 199 | self.ntest_class = self.unseenclasses.size(0) 200 | self.train_class = self.seenclasses.clone() 201 | self.allclasses = torch.arange(0, self.ntrain_class+self.ntest_class).long() 202 | 203 | # self.train_mapped_label = map_label(train_label, self.seenclasses) 204 | 205 | self.data = {} 206 | self.data['train_seen'] = {} 207 | self.data['train_seen']['resnet_features'] = train_feature 208 | self.data['train_seen']['labels']= train_label 209 | 210 | 211 | self.data['train_unseen'] = {} 212 | self.data['train_unseen']['resnet_features'] = None 213 | self.data['train_unseen']['labels'] = None 214 | 215 | self.data['test_seen'] = {} 216 | self.data['test_seen']['resnet_features'] = test_seen_feature 217 | self.data['test_seen']['labels'] = test_seen_label 218 | 219 | self.data['test_unseen'] = {} 220 | self.data['test_unseen']['resnet_features'] = test_unseen_feature 221 | self.data['test_unseen']['labels'] = test_unseen_label 222 | -------------------------------------------------------------------------------- /core/CUBDataLoader_standard_split.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Created on Wed 
Oct 23 12:31:01 2019 4 | 5 | @author: Warmachine 6 | """ 7 | 8 | import os,sys 9 | #import scipy.io as sio 10 | import torch 11 | import numpy as np 12 | import h5py 13 | import time 14 | import pickle 15 | import pdb 16 | from sklearn import preprocessing 17 | from global_setting import NFS_path 18 | #%% 19 | import scipy.io as sio 20 | import pandas as pd 21 | #%% 22 | #import pdb 23 | #%% 24 | 25 | img_dir = os.path.join(NFS_path,'data/CUB/') 26 | mat_path = os.path.join(NFS_path,'data/xlsa17/data/CUB/res101.mat') 27 | attr_path = './attribute/CUB/new_des.csv' 28 | 29 | class CUBDataLoader(): 30 | def __init__(self, data_path, device, is_scale = False, is_balance=True): 31 | print("!!!!!!!!!! Standard Split !!!!!!!!!!") 32 | print(data_path) 33 | sys.path.append(data_path) 34 | 35 | self.data_path = data_path 36 | self.device = device 37 | self.dataset = 'CUB' 38 | print('$'*30) 39 | print(self.dataset) 40 | print('$'*30) 41 | self.datadir = self.data_path + 'data/{}/'.format(self.dataset) 42 | 43 | self.standard_loc = self.data_path + 'data/standard_split/{}/att_splits.mat'.format(self.dataset) 44 | 45 | self.index_in_epoch = 0 46 | self.epochs_completed = 0 47 | self.is_scale = is_scale 48 | self.is_balance = is_balance 49 | if self.is_balance: 50 | print('Balance dataloader') 51 | self.read_matdataset() 52 | self.get_idx_classes() 53 | 54 | def augment_img_path(self,mat_path=mat_path,img_dir=img_dir): 55 | self.matcontent = sio.loadmat(mat_path) 56 | self.image_files = np.squeeze(self.matcontent['image_files']) 57 | 58 | def convert_path(image_files,img_dir): 59 | new_image_files = [] 60 | for idx in range(len(image_files)): 61 | image_file = image_files[idx][0] 62 | image_file = os.path.join(img_dir,'/'.join(image_file.split('/')[5:])) 63 | new_image_files.append(image_file) 64 | return np.array(new_image_files) 65 | 66 | self.image_files = convert_path(self.image_files,img_dir) 67 | 68 | path= self.datadir + 'feature_map_ResNet_101_{}.hdf5'.format(self.dataset) 69 | hf = h5py.File(path, 'r') 70 | 71 | trainval_loc = np.array(hf.get('trainval_loc')) 72 | test_seen_loc = np.array(hf.get('test_seen_loc')) 73 | test_unseen_loc = np.array(hf.get('test_unseen_loc')) 74 | 75 | self.data['train_seen']['img_path'] = self.image_files[trainval_loc] 76 | self.data['test_seen']['img_path'] = self.image_files[test_seen_loc] 77 | self.data['test_unseen']['img_path'] = self.image_files[test_unseen_loc] 78 | 79 | self.attr_name = pd.read_csv(attr_path)['new_des'] 80 | 81 | 82 | 83 | def next_batch_img(self, batch_size,class_id,is_trainset = False): 84 | features = None 85 | labels = None 86 | img_files = None 87 | if class_id in self.seenclasses: 88 | if is_trainset: 89 | features = self.data['train_seen']['resnet_features'] 90 | labels = self.data['train_seen']['labels'] 91 | img_files = self.data['train_seen']['img_path'] 92 | else: 93 | features = self.data['test_seen']['resnet_features'] 94 | labels = self.data['test_seen']['labels'] 95 | img_files = self.data['test_seen']['img_path'] 96 | elif class_id in self.unseenclasses: 97 | features = self.data['test_unseen']['resnet_features'] 98 | labels = self.data['test_unseen']['labels'] 99 | img_files = self.data['test_unseen']['img_path'] 100 | else: 101 | raise Exception("Cannot find this class {}".format(class_id)) 102 | 103 | #note that img_files is numpy type !!!!! 
104 | 105 | idx_c = torch.squeeze(torch.nonzero(labels == class_id)) 106 | 107 | features = features[idx_c] 108 | labels = labels[idx_c] 109 | img_files = img_files[idx_c.cpu().numpy()] 110 | 111 | batch_label = labels[:batch_size].to(self.device) 112 | batch_feature = features[:batch_size].to(self.device) 113 | batch_files = img_files[:batch_size] 114 | batch_att = self.att[batch_label].to(self.device) 115 | 116 | return batch_label, batch_feature,batch_files, batch_att 117 | 118 | 119 | def next_batch(self, batch_size): 120 | if self.is_balance: 121 | idx = [] 122 | n_samples_class = max(batch_size //self.ntrain_class,1) 123 | sampled_idx_c = np.random.choice(np.arange(self.ntrain_class),min(self.ntrain_class,batch_size),replace=False).tolist() 124 | for i_c in sampled_idx_c: 125 | idxs = self.idxs_list[i_c] 126 | idx.append(np.random.choice(idxs,n_samples_class)) 127 | idx = np.concatenate(idx) 128 | idx = torch.from_numpy(idx) 129 | else: 130 | idx = torch.randperm(self.ntrain)[0:batch_size] 131 | 132 | batch_feature = self.data['train_seen']['resnet_features'][idx].to(self.device) 133 | batch_label = self.data['train_seen']['labels'][idx].to(self.device) 134 | batch_att = self.att[batch_label].to(self.device) 135 | return batch_label, batch_feature, batch_att 136 | 137 | def get_idx_classes(self): 138 | n_classes = self.seenclasses.size(0) 139 | self.idxs_list = [] 140 | train_label = self.data['train_seen']['labels'] 141 | for i in range(n_classes): 142 | idx_c = torch.nonzero(train_label == self.seenclasses[i].cpu()).cpu().numpy() 143 | idx_c = np.squeeze(idx_c) 144 | self.idxs_list.append(idx_c) 145 | return self.idxs_list 146 | 147 | 148 | def read_matdataset(self): 149 | 150 | path= self.datadir + 'feature_map_ResNet_101_{}.hdf5'.format(self.dataset) 151 | print('_____') 152 | print(path) 153 | tic = time.clock() 154 | hf = h5py.File(path, 'r') 155 | features = np.array(hf.get('feature_map')) 156 | 157 | labels = np.array(hf.get('labels')) 158 | # trainval_loc = np.array(hf.get('trainval_loc')) 159 | # test_seen_loc = np.array(hf.get('test_seen_loc')) 160 | # test_unseen_loc = np.array(hf.get('test_unseen_loc')) 161 | mat_stand_split = sio.loadmat(self.standard_loc) 162 | trainval_loc = np.squeeze(mat_stand_split['trainval_loc'])-1 163 | test_seen_loc = np.array([],dtype=np.uint16) 164 | test_unseen_loc = np.squeeze(mat_stand_split['test_unseen_loc'])-1 165 | 166 | 167 | print('Expert Attr') 168 | att = np.array(hf.get('att')) 169 | self.att = torch.from_numpy(att).float().to(self.device) 170 | 171 | original_att = np.array(hf.get('original_att')) 172 | self.original_att = torch.from_numpy(original_att).float().to(self.device) 173 | 174 | w2v_att = np.array(hf.get('w2v_att')) 175 | self.w2v_att = torch.from_numpy(w2v_att).float().to(self.device) 176 | 177 | self.normalize_att = self.original_att/100 178 | 179 | print('Finish loading data in ',time.clock()-tic) 180 | 181 | train_feature = features[trainval_loc] 182 | test_seen_feature = features[test_seen_loc] 183 | test_unseen_feature = features[test_unseen_loc] 184 | if self.is_scale: 185 | scaler = preprocessing.MinMaxScaler() 186 | 187 | train_feature = scaler.fit_transform(train_feature) 188 | test_seen_feature = scaler.fit_transform(test_seen_feature) 189 | test_unseen_feature = scaler.fit_transform(test_unseen_feature) 190 | 191 | train_feature = torch.from_numpy(train_feature).float() #.to(self.device) 192 | test_seen_feature = torch.from_numpy(test_seen_feature) #.float().to(self.device) 193 | test_unseen_feature = 
torch.from_numpy(test_unseen_feature) #.float().to(self.device) 194 | 195 | train_label = torch.from_numpy(labels[trainval_loc]).long() #.to(self.device) 196 | test_unseen_label = torch.from_numpy(labels[test_unseen_loc]) #.long().to(self.device) 197 | test_seen_label = torch.from_numpy(labels[test_seen_loc]) #.long().to(self.device) 198 | 199 | self.seenclasses = torch.from_numpy(np.unique(train_label.cpu().numpy())).to(self.device) 200 | self.unseenclasses = torch.from_numpy(np.unique(test_unseen_label.cpu().numpy())).to(self.device) 201 | self.ntrain = train_feature.size()[0] 202 | self.ntrain_class = self.seenclasses.size(0) 203 | self.ntest_class = self.unseenclasses.size(0) 204 | self.train_class = self.seenclasses.clone() 205 | self.allclasses = torch.arange(0, self.ntrain_class+self.ntest_class).long() 206 | 207 | # self.train_mapped_label = map_label(train_label, self.seenclasses) 208 | 209 | self.data = {} 210 | self.data['train_seen'] = {} 211 | self.data['train_seen']['resnet_features'] = train_feature 212 | self.data['train_seen']['labels']= train_label 213 | 214 | 215 | self.data['train_unseen'] = {} 216 | self.data['train_unseen']['resnet_features'] = None 217 | self.data['train_unseen']['labels'] = None 218 | 219 | self.data['test_seen'] = {} 220 | self.data['test_seen']['resnet_features'] = test_seen_feature 221 | self.data['test_seen']['labels'] = test_seen_label 222 | 223 | self.data['test_unseen'] = {} 224 | self.data['test_unseen']['resnet_features'] = test_unseen_feature 225 | self.data['test_unseen']['labels'] = test_unseen_label 226 | -------------------------------------------------------------------------------- /core/DAZLE.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Created on Thu Jul 4 17:39:45 2019 4 | 5 | @author: badat 6 | """ 7 | 8 | import torch 9 | import torch.nn as nn 10 | import torch.nn.functional as F 11 | import numpy as np 12 | #%% 13 | import pdb 14 | #%% 15 | 16 | class DAZLE(nn.Module): 17 | ##### 18 | # einstein sum notation 19 | # b: Batch size \ f: dim feature \ v: dim w2v \ r: number of region \ k: number of classes 20 | # i: number of attribute \ h : hidden attention dim 21 | ##### 22 | def __init__(self,dim_f,dim_v, 23 | init_w2v_att,att,normalize_att, 24 | seenclass,unseenclass, 25 | lambda_, 26 | trainable_w2v = False, normalize_V = False, normalize_F = False, is_conservative = False, 27 | prob_prune=0,desired_mass = -1,uniform_att_1 = False,uniform_att_2 = False, is_conv = False, 28 | is_bias = False,bias = 1,non_linear_act=False, 29 | loss_type = 'CE',non_linear_emb = False, 30 | is_sigmoid = False): 31 | super(DAZLE, self).__init__() 32 | self.dim_f = dim_f 33 | self.dim_v = dim_v 34 | self.dim_att = att.shape[1] 35 | self.nclass = att.shape[0] 36 | self.hidden = self.dim_att//2 37 | self.init_w2v_att = init_w2v_att 38 | self.non_linear_act = non_linear_act 39 | self.loss_type = loss_type 40 | if is_conv: 41 | r_dim = dim_f//2 42 | self.conv1 = nn.Conv2d(dim_f, r_dim, 2) #[2x2] kernel with same input and output dims 43 | print('***Reduce dim {} -> {}***'.format(self.dim_f,r_dim)) 44 | self.dim_f = r_dim 45 | self.conv1_bn = nn.BatchNorm2d(self.dim_f) 46 | 47 | 48 | if init_w2v_att is None: 49 | self.V = nn.Parameter(nn.init.normal_(torch.empty(self.dim_att,self.dim_v)),requires_grad = True) 50 | else: 51 | self.init_w2v_att = F.normalize(torch.tensor(init_w2v_att)) 52 | self.V = nn.Parameter(self.init_w2v_att.clone(),requires_grad = trainable_w2v) 53 | 54 
| self.att = nn.Parameter(F.normalize(torch.tensor(att)),requires_grad = False) 55 | 56 | self.W_1 = nn.Parameter(nn.init.normal_(torch.empty(self.dim_v,self.dim_f)),requires_grad = True) #nn.utils.weight_norm(nn.Linear(self.dim_v,self.dim_f,bias=False))# 57 | self.W_2 = nn.Parameter(nn.init.zeros_(torch.empty(self.dim_v,self.dim_f)),requires_grad = True) #nn.utils.weight_norm(nn.Linear(self.dim_v,self.dim_f,bias=False))# 58 | ## second layer attenion conditioned on image features 59 | self.W_3 = nn.Parameter(nn.init.zeros_(torch.empty(self.dim_v,self.dim_f)),requires_grad = True) 60 | 61 | ## Compute the similarity between classes 62 | self.P = torch.mm(self.att,torch.transpose(self.att,1,0)) 63 | assert self.P.size(1)==self.P.size(0) and self.P.size(0)==self.nclass 64 | self.weight_ce = nn.Parameter(torch.eye(self.nclass).float(),requires_grad = False)#nn.Parameter(torch.tensor(weight_ce).float(),requires_grad = False) 65 | 66 | self.normalize_V = normalize_V 67 | self.normalize_F = normalize_F 68 | self.is_conservative = is_conservative 69 | self.is_conv = is_conv 70 | self.is_bias = is_bias 71 | 72 | self.seenclass = seenclass 73 | self.unseenclass = unseenclass 74 | self.normalize_att = normalize_att 75 | 76 | if is_bias: 77 | self.bias = nn.Parameter(torch.tensor(bias),requires_grad = False) 78 | mask_bias = np.ones((1,self.nclass)) 79 | mask_bias[:,self.seenclass.cpu().numpy()] *= -1 80 | self.mask_bias = nn.Parameter(torch.tensor(mask_bias).float(),requires_grad = False) 81 | 82 | if desired_mass == -1: 83 | self.desired_mass = self.unseenclass.size(0)/self.nclass#nn.Parameter(torch.tensor(self.unseenclass.size(0)/self.nclass),requires_grad = False)#nn.Parameter(torch.tensor(0.1),requires_grad = False)# 84 | else: 85 | self.desired_mass = desired_mass#nn.Parameter(torch.tensor(desired_mass),requires_grad = False)#nn.Parameter(torch.tensor(self.unseenclass.size(0)/self.nclass),requires_grad = False)# 86 | self.prob_prune = nn.Parameter(torch.tensor(prob_prune),requires_grad = False) 87 | 88 | self.lambda_ = lambda_ 89 | self.loss_att_func = nn.BCEWithLogitsLoss() 90 | self.log_softmax_func = nn.LogSoftmax(dim=1) 91 | self.uniform_att_1 = uniform_att_1 92 | self.uniform_att_2 = uniform_att_2 93 | 94 | self.non_linear_emb = non_linear_emb 95 | 96 | 97 | print('-'*30) 98 | print('Configuration') 99 | 100 | print('loss_type {}'.format(loss_type)) 101 | 102 | if self.is_conv: 103 | print('Learn CONV layer correct') 104 | 105 | if self.normalize_V: 106 | print('normalize V') 107 | else: 108 | print('no constraint V') 109 | 110 | if self.normalize_F: 111 | print('normalize F') 112 | else: 113 | print('no constraint F') 114 | 115 | if self.is_conservative: 116 | print('training to exclude unseen class [seen upperbound]') 117 | if init_w2v_att is None: 118 | print('Learning word2vec from scratch with dim {}'.format(self.V.size())) 119 | else: 120 | print('Init word2vec') 121 | 122 | if self.non_linear_act: 123 | print('Non-linear relu model') 124 | else: 125 | print('Linear model') 126 | 127 | print('loss_att {}'.format(self.loss_att_func)) 128 | print('Bilinear attention module') 129 | print('*'*30) 130 | print('Measure w2v deviation') 131 | if self.uniform_att_1: 132 | print('WARNING: UNIFORM ATTENTION LEVEL 1') 133 | if self.uniform_att_2: 134 | print('WARNING: UNIFORM ATTENTION LEVEL 2') 135 | print('Compute Pruning loss {}'.format(self.prob_prune)) 136 | if self.is_bias: 137 | print('Add one smoothing') 138 | print('Second layer attenion conditioned on image features') 139 | 
print('-'*30) 140 | 141 | if self.non_linear_emb: 142 | print('non_linear embedding') 143 | self.emb_func = torch.nn.Sequential( 144 | torch.nn.Linear(self.dim_att, self.dim_att//2), 145 | torch.nn.ReLU(), 146 | torch.nn.Linear(self.dim_att//2, 1), 147 | ) 148 | 149 | self.is_sigmoid = is_sigmoid 150 | if self.is_sigmoid: 151 | print("Sigmoid on attr score!!!") 152 | else: 153 | print("No sigmoid on attr score") 154 | 155 | 156 | def compute_loss_rank(self,in_package): 157 | # this is pairwise ranking loss 158 | batch_label = in_package['batch_label'] 159 | S_pp = in_package['S_pp'] 160 | 161 | batch_label_idx = torch.argmax(batch_label,dim = 1) 162 | 163 | s_c = torch.gather(S_pp,1,batch_label_idx.view(-1,1)) 164 | if self.is_conservative: 165 | S_seen = S_pp 166 | else: 167 | S_seen = S_pp[:,self.seenclass] 168 | assert S_seen.size(1) == len(self.seenclass) 169 | 170 | margin = 1-(s_c-S_seen) 171 | loss_rank = torch.max(margin,torch.zeros_like(margin)) 172 | loss_rank = torch.mean(loss_rank) 173 | return loss_rank 174 | 175 | def compute_loss_Self_Calibrate(self,in_package): 176 | S_pp = in_package['S_pp'] 177 | Prob_all = F.softmax(S_pp,dim=-1) 178 | Prob_unseen = Prob_all[:,self.unseenclass] 179 | assert Prob_unseen.size(1) == len(self.unseenclass) 180 | mass_unseen = torch.sum(Prob_unseen,dim=1) 181 | 182 | loss_pmp = -torch.log(torch.mean(mass_unseen)) 183 | return loss_pmp 184 | 185 | def compute_V(self): 186 | if self.normalize_V: 187 | V_n = F.normalize(self.V) 188 | else: 189 | V_n = self.V 190 | return V_n 191 | 192 | def compute_aug_cross_entropy(self,in_package): 193 | batch_label = in_package['batch_label'] 194 | S_pp = in_package['S_pp'] 195 | 196 | Labels = batch_label 197 | 198 | if self.is_bias: 199 | S_pp = S_pp - self.vec_bias # remove the margin +1/-1 from prediction scores 200 | 201 | if not self.is_conservative: 202 | S_pp = S_pp[:,self.seenclass] 203 | Labels = Labels[:,self.seenclass] 204 | assert S_pp.size(1) == len(self.seenclass) 205 | 206 | Prob = self.log_softmax_func(S_pp) 207 | 208 | loss = -torch.einsum('bk,bk->b',Prob,Labels) 209 | loss = torch.mean(loss) 210 | return loss 211 | 212 | def compute_loss(self,in_package): 213 | 214 | if len(in_package['batch_label'].size()) == 1: 215 | in_package['batch_label'] = self.weight_ce[in_package['batch_label']] 216 | 217 | ## loss rank 218 | if self.loss_type == 'CE': 219 | loss_CE = self.compute_aug_cross_entropy(in_package) 220 | elif self.loss_type == 'rank': 221 | loss_CE = self.compute_loss_rank(in_package) 222 | else: 223 | raise Exception('Unknown loss type') 224 | 225 | ## loss self-calibration 226 | loss_cal = self.compute_loss_Self_Calibrate(in_package) 227 | 228 | ## total loss 229 | loss = loss_CE + self.lambda_*loss_cal 230 | 231 | out_package = {'loss':loss,'loss_CE':loss_CE, 232 | 'loss_cal':loss_cal} 233 | 234 | return out_package 235 | 236 | def forward(self,Fs): 237 | 238 | if self.is_conv: 239 | Fs = self.conv1(Fs) 240 | Fs = self.conv1_bn(Fs) 241 | Fs = F.relu(Fs) 242 | 243 | shape = Fs.shape 244 | Fs = Fs.reshape(shape[0],shape[1],shape[2]*shape[3]) 245 | 246 | R = Fs.size(2) 247 | B = Fs.size(0) 248 | V_n = self.compute_V() 249 | 250 | if self.normalize_F and not self.is_conv: 251 | Fs = F.normalize(Fs,dim = 1) 252 | 253 | ## Compute attribute score on each image region 254 | S = torch.einsum('iv,vf,bfr->bir',V_n,self.W_1,Fs) 255 | 256 | if self.is_sigmoid: 257 | S=torch.sigmoid(S) 258 | 259 | ## Ablation setting 260 | A_b = Fs.new_full((B,self.dim_att,R),1/R) 261 | A_b_p = 
self.att.new_full((B,self.dim_att),fill_value = 1) 262 | S_b_p = torch.einsum('bir,bir->bi',A_b,S) 263 | S_b_pp = torch.einsum('ki,bi,bi->bk',self.att,A_b_p,S_b_p) 264 | ## 265 | 266 | ## compute Dense Attention 267 | A = torch.einsum('iv,vf,bfr->bir',V_n,self.W_2,Fs) 268 | A = F.softmax(A,dim = -1) # compute an attention map for each attribute 269 | F_p = torch.einsum('bir,bfr->bif',A,Fs) # compute attribute-based features 270 | 271 | if self.uniform_att_1: 272 | S_p = torch.einsum('bir,bir->bi',A_b,S) # ablation: compute attribute score using average image region features 273 | else: 274 | S_p = torch.einsum('bir,bir->bi',A,S) # compute attribute scores from attribute attention maps 275 | 276 | if self.non_linear_act: 277 | S_p = F.relu(S_p) 278 | ## 279 | 280 | ## compute Attention over Attribute 281 | A_p = torch.einsum('iv,vf,bif->bi',V_n,self.W_3,F_p) 282 | A_p = torch.sigmoid(A_p) 283 | ## 284 | 285 | if self.uniform_att_2: 286 | S_pp = torch.einsum('ki,bi,bi->bik',self.att,A_b_p,S_p) # ablation: setting attention over attribute to 1 287 | else: 288 | S_pp = torch.einsum('ki,bi,bi->bik',self.att,A_p,S_p) # compute the final prediction as the product of semantic scores, attribute scores, and attention over attribute scores 289 | 290 | S_attr = torch.einsum('bi,bi->bi',A_b_p,S_p) 291 | 292 | if self.non_linear_emb: 293 | S_pp = torch.transpose(S_pp,2,1) #[bki] <== [bik] 294 | S_pp = self.emb_func(S_pp) #[bk1] <== [bki] 295 | S_pp = S_pp[:,:,0] #[bk] <== [bk1] 296 | else: 297 | S_pp = torch.sum(S_pp,axis=1) #[bk] <== [bik] 298 | 299 | # augment prediction scores by adding a margin of 1 to unseen classes and -1 to seen classes 300 | if self.is_bias: 301 | self.vec_bias = self.mask_bias*self.bias 302 | S_pp = S_pp + self.vec_bias 303 | 304 | ## spatial attention supervision 305 | Pred_att = torch.einsum('iv,vf,bif->bi',V_n,self.W_1,F_p) 306 | 307 | package = {'S_pp':S_pp,'Pred_att':Pred_att,'S_b_pp':S_b_pp,'A_p':A_p,'A':A,'S_attr':S_attr} 308 | 309 | return package 310 | -------------------------------------------------------------------------------- /core/DeepFashionDataLoader.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Created on Sat Aug 24 10:31:33 2019 4 | 5 | @author: Warmachine 6 | """ 7 | 8 | import os,sys 9 | #import scipy.io as sio 10 | import torch 11 | import torch.nn.functional as F 12 | import numpy as np 13 | import h5py 14 | import time 15 | import math 16 | from sklearn import preprocessing 17 | from global_setting import NFS_path 18 | #%% 19 | import scipy.io as sio 20 | import pandas as pd 21 | import pickle 22 | #%% 23 | #import pdb 24 | #%% 25 | img_dir = os.path.join(NFS_path,'data/DeepFashion/') 26 | anno_path = os.path.join(NFS_path,'data/DeepFashion/annotation.pkl') 27 | 28 | class DeepFashionDataLoader(): 29 | def __init__(self, data_path, device, is_balance =True, is_scale = False,verbose = False): 30 | 31 | print(data_path) 32 | sys.path.append(data_path) 33 | 34 | self.data_path = data_path 35 | self.device = device 36 | self.dataset = 'DeepFashion' 37 | print('$'*30) 38 | print(self.dataset) 39 | print('$'*30) 40 | self.datadir = self.data_path + 'data/{}/'.format(self.dataset) 41 | self.index_in_epoch = 0 42 | self.epochs_completed = 0 43 | self.is_scale = is_scale 44 | self.is_balance = is_balance 45 | self.verbose = verbose 46 | self.n_reuse = 2 47 | self.read_matdataset() 48 | 49 | self.seeker = np.zeros(self.ntrain_class) 50 | ### setup balance training ### 51 | if 
self.is_balance: 52 | print('Balance dataloader') 53 | else: 54 | print('No balance adjustment') 55 | self.cur_classes_idx = 0 56 | 57 | self.idx_part = 0 58 | self.part_size = 10000 59 | self.idx_b = 0 60 | self.part_features = None 61 | self.part_labels = None 62 | 63 | print('Partition size {}'.format(self.part_size)) 64 | 65 | self.convert_new_classes() 66 | print('Excluding non-sample classes') 67 | 68 | print('-'*30) 69 | print('DeepFashion') 70 | print('-'*30) 71 | 72 | def reset_seeker(self): 73 | self.seeker[:] = 0 74 | 75 | def get_class(self): 76 | self.cur_classes_idx = (self.cur_classes_idx+1)%self.ntrain_class 77 | 78 | return self.train_class[self.cur_classes_idx].cpu() 79 | 80 | def augment_img_path(self,anno_path=anno_path,img_dir=img_dir): 81 | self.package = pickle.load(open(anno_path,'rb')) 82 | self.image_files = self.package['image_names'] 83 | self.cat_names = self.package['cat_names'] 84 | 85 | def convert_path(image_files,img_dir): 86 | new_image_files = [] 87 | for idx in range(len(image_files)): 88 | image_file = image_files[idx] 89 | image_file = os.path.join(img_dir,image_file) 90 | new_image_files.append(image_file) 91 | return np.array(new_image_files) 92 | self.image_files = convert_path(self.image_files,img_dir) 93 | 94 | path= self.datadir + 'feature_map_ResNet_101_{}.hdf5'.format(self.dataset) 95 | hf = h5py.File(path, 'r') 96 | 97 | trainval_loc = np.array(hf.get('trainval_loc')) 98 | test_seen_loc = np.array(hf.get('test_seen_loc')) 99 | test_unseen_loc = np.array(hf.get('test_unseen_loc')) 100 | 101 | self.data['train_seen']['img_path'] = None#self.image_files[trainval_loc] 102 | self.data['test_seen']['img_path'] = self.image_files[test_seen_loc] 103 | self.data['test_unseen']['img_path'] = self.image_files[test_unseen_loc] 104 | 105 | self.attr_name = self.package['att_names'] 106 | 107 | def next_batch_img(self, batch_size,class_id,is_trainset = False): 108 | features = None 109 | labels = None 110 | img_files = None 111 | if class_id in self.seenclasses: 112 | if is_trainset: 113 | features = self.data['train_seen']['resnet_features'] 114 | labels = self.data['train_seen']['labels'] 115 | img_files = self.data['train_seen']['img_path'] 116 | else: 117 | features = self.data['test_seen']['resnet_features'] 118 | labels = self.data['test_seen']['labels'] 119 | img_files = self.data['test_seen']['img_path'] 120 | elif class_id in self.unseenclasses: 121 | features = self.data['test_unseen']['resnet_features'] 122 | labels = self.data['test_unseen']['labels'] 123 | img_files = self.data['test_unseen']['img_path'] 124 | else: 125 | raise Exception("Cannot find this class {}".format(class_id)) 126 | 127 | #note that img_files is numpy type !!!!! 128 | 129 | idx_c = torch.squeeze(torch.nonzero(labels == class_id)) 130 | 131 | features = features[idx_c] 132 | labels = labels[idx_c] 133 | img_files = img_files[idx_c.cpu().numpy()] 134 | 135 | batch_label = labels[:batch_size].to(self.device) 136 | batch_feature = features[:batch_size].to(self.device) 137 | batch_files = img_files[:batch_size] 138 | batch_att = self.att[batch_label].to(self.device) 139 | 140 | return batch_label, batch_feature,batch_files, batch_att 141 | 142 | def sample_from_parition(self): 143 | print('load data from hdf') # this is slow because it loads non-consecutive memory block. Thus it may need to traverse the whole file to load stuff. 
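        # To keep these reads sequential, each call below caches one contiguous slice of roughly
        # part_size//ntrain_class samples per seen class in memory (self.part_features / self.part_labels);
        # next_batch() draws minibatches from this cache and only reloads a fresh slice after
        # part_size//batch_size * n_reuse batches, with self.seeker tracking which slice of each class to fetch next.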
144 | 145 | self.part_features = [] 146 | self.part_labels = [] 147 | target_class_size = self.part_size//self.ntrain_class 148 | for idx_l,l in enumerate(self.data['train_seen']['labels']): 149 | 150 | features_c = self.data['train_seen']['resnet_features_hdf'][l] 151 | n_samples_c = len(features_c) 152 | if self.seeker[idx_l]*target_class_size >= n_samples_c: ## seek to different partition of data 153 | self.seeker[idx_l] = 0 154 | start_seek = self.seeker[idx_l]*target_class_size 155 | end_seek = min(n_samples_c,(self.seeker[idx_l]+1)*target_class_size) 156 | idx_samples_c = np.arange(start_seek,end_seek).tolist()#np.random.choice(n_samples_c,size = samples_class_size) 157 | 158 | n_select_c = len(idx_samples_c) 159 | 160 | print(l,end='..') 161 | 162 | 163 | self.part_features.append(features_c[idx_samples_c]) ## only work with list datatype 164 | part_labels = [int(l)]*n_select_c 165 | part_labels = self.map_old2new_classes[part_labels] 166 | self.part_labels.append(part_labels) 167 | self.seeker[idx_l] += 1 168 | print() 169 | 170 | def next_batch(self, batch_size): 171 | if self.is_balance: 172 | if self.idx_b == 0: 173 | tic = time.clock() 174 | self.sample_from_parition() 175 | print('Elapsed time {}'.format(time.clock()-tic)) 176 | 177 | batch_feature = [] 178 | batch_label = [] 179 | 180 | n_target_class = max(batch_size //self.ntrain_class,1) 181 | sampled_idx_c = np.random.choice(np.arange(self.ntrain_class),min(self.ntrain_class,batch_size),replace=False).tolist() 182 | for i_c in sampled_idx_c: 183 | part_feature = self.part_features[i_c] 184 | part_label = self.part_labels[i_c] 185 | n_samples_c = part_feature.shape[0] 186 | idx_sample_c = np.random.choice(n_samples_c,n_target_class) 187 | batch_feature.append(part_feature[idx_sample_c]) 188 | batch_label.append(part_label[idx_sample_c]) 189 | 190 | 191 | 192 | # idx_samples_b = torch.randperm(self.part_labels.size(0))[0:batch_size] 193 | # batch_feature = self.part_features[idx_samples_b].to(self.device) 194 | # batch_label = self.part_labels[idx_samples_b].to(self.device) 195 | batch_feature = torch.from_numpy(np.concatenate(batch_feature)).to(self.device) 196 | batch_label= torch.from_numpy(np.concatenate(batch_label)).to(self.device) 197 | batch_att = self.att[batch_label].to(self.device) 198 | 199 | ##increment 200 | self.idx_b = (self.idx_b + 1)%(self.part_size//batch_size*self.n_reuse) 201 | 202 | else: 203 | raise Exception('Not Implemented') 204 | if self.idx_b == 0: 205 | # idx_samples = (np.arange(self.part_size)+self.idx_part*self.part_size).tolist() 206 | self.part_features = torch.tensor(self.data['train_seen']['resnet_features_hdf'][self.idx_part*self.part_size:(self.idx_part+1)*self.part_size]) 207 | self.part_labels = self.data['train_seen']['labels'][self.idx_part*self.part_size:(self.idx_part+1)*self.part_size] 208 | ## permute 209 | idx_permute = torch.randperm(self.part_features.size(0)) 210 | self.part_features = self.part_features[idx_permute] 211 | self.part_labels = self.part_labels[idx_permute] 212 | 213 | ## increment 214 | self.idx_part = (self.idx_part+1)%(self.ntrain//self.part_size) 215 | 216 | batch_feature = self.part_features[self.idx_b*batch_size:(self.idx_b+1)*batch_size].to(self.device) 217 | batch_label = self.part_labels[self.idx_b*batch_size:(self.idx_b+1)*batch_size].to(self.device) 218 | batch_att = self.att[batch_label].to(self.device) 219 | 220 | ##increment 221 | self.idx_b = (self.idx_b + 1)%(self.part_labels.size(0)//batch_size) 222 | 223 | if self.verbose: 224 | print('unique 
labels in batch {}'.format(torch.unique(batch_label))) 225 | 226 | return batch_label, batch_feature, batch_att 227 | 228 | def convert_new_classes(self): 229 | self.att = self.att[self.available_classes] 230 | self.att = F.normalize((self.att+1)/2) 231 | 232 | # self.data['train_seen']['labels'] = self.map_old2new_classes[self.data['train_seen']['labels']] 233 | self.data['test_seen']['labels'] = self.map_old2new_classes[self.data['test_seen']['labels']] 234 | self.data['test_unseen']['labels'] = self.map_old2new_classes[self.data['test_unseen']['labels']] 235 | 236 | self.seenclasses = self.map_old2new_classes[self.seenclasses].to(self.device)#torch.unique(self.data['train_seen']['labels']).to(self.device) 237 | self.unseenclasses = torch.unique(self.data['test_unseen']['labels']).to(self.device) 238 | # self.ntrain = self.data['train_seen']['labels'].size()[0] 239 | # self.ntrain_class = self.seenclasses.size(0) 240 | # self.ntest_class = self.unseenclasses.size(0) 241 | # self.train_class = self.seenclasses.clone() 242 | self.allclasses = torch.arange(0, self.ntrain_class+self.ntest_class).long() 243 | 244 | def read_matdataset(self): 245 | 246 | path= self.datadir + 'feature_map_ResNet_101_{}_sep_seen_samples.hdf5'.format(self.dataset) 247 | print('_____') 248 | print(path) 249 | tic = time.clock() 250 | hf = h5py.File(path, 'r') 251 | 252 | # labels = np.array(hf.get('labels')) 253 | 254 | att = np.array(hf.get('att')) 255 | 256 | ## remap classes this is because there is some classes that does not have training sample 257 | self.available_classes = np.where(np.sum(att,axis = 1)!=0)[0] 258 | self.map_old2new_classes = np.ones(att.shape[0])*-1 259 | self.map_old2new_classes[self.available_classes] = np.arange(self.available_classes.shape[0]) 260 | self.map_old2new_classes = torch.from_numpy(self.map_old2new_classes).long() 261 | ## 262 | 263 | self.att = torch.from_numpy(att).float().to(self.device) 264 | 265 | w2v_att = np.array(hf.get('w2v_att')) 266 | self.w2v_att = torch.from_numpy(w2v_att).float().to(self.device) 267 | 268 | labels = hf['label_train'] 269 | seenclasses = [int(l) for l in labels] 270 | n_sample_classes = [len(labels[str(l)]) for l in seenclasses] 271 | 272 | 273 | test_unseen_label = torch.from_numpy(np.array(hf.get('label_test_unseen'),dtype=np.int32)).long()#.to(self.device) 274 | test_seen_label = torch.from_numpy(np.array(hf.get('label_test_seen'),dtype=np.int32)).long()#.to(self.device) 275 | 276 | self.seenclasses = torch.tensor(seenclasses) 277 | self.unseenclasses = torch.unique(test_unseen_label).to(self.device) 278 | self.ntrain = sum(n_sample_classes) 279 | self.ntrain_class = self.seenclasses.size(0) 280 | self.ntest_class = self.unseenclasses.size(0) 281 | self.train_class = self.seenclasses.clone() 282 | ## containing missing classes therefore cannot determine the set of all available label 283 | # self.allclasses = torch.arange(0, self.ntrain_class+self.ntest_class).long() 284 | 285 | self.data = {} 286 | self.data['train_seen'] = {} 287 | self.data['train_seen']['resnet_features_hdf'] = hf['feature_map_train'] 288 | self.data['train_seen']['labels']= labels 289 | 290 | # input('Debug version b') 291 | self.data['train_unseen'] = {} 292 | self.data['train_unseen']['resnet_features'] = None 293 | self.data['train_unseen']['labels'] = None 294 | 295 | self.data['test_seen'] = {} 296 | self.data['test_seen']['resnet_features'] = torch.from_numpy(np.array(hf.get('feature_map_test_seen'),dtype=np.float32)).float() 297 | self.data['test_seen']['labels'] = 
test_seen_label 298 | 299 | self.data['test_unseen'] = {} 300 | self.data['test_unseen']['resnet_features'] = torch.from_numpy(np.array(hf.get('feature_map_test_unseen'),dtype=np.float32)).float() 301 | self.data['test_unseen']['labels'] = test_unseen_label 302 | 303 | print('Finish loading data in ',time.clock()-tic) -------------------------------------------------------------------------------- /core/SUNDataLoader.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | """ 4 | Created on Thu Aug 1 12:11:40 2019 5 | 6 | @author: war-machince 7 | """ 8 | 9 | import os,sys 10 | #import scipy.io as sio 11 | import torch 12 | import numpy as np 13 | import h5py 14 | import time 15 | import pickle 16 | from sklearn import preprocessing 17 | from core.helper_func import mix_up 18 | #%% 19 | import pdb 20 | #%% 21 | 22 | class SUNDataLoader(): 23 | def __init__(self, data_path, device, is_scale = False, is_balance=True): 24 | 25 | print(data_path) 26 | sys.path.append(data_path) 27 | 28 | self.data_path = data_path 29 | self.device = device 30 | self.dataset = 'SUN' 31 | print('$'*30) 32 | print(self.dataset) 33 | print('$'*30) 34 | self.datadir = self.data_path + 'data/{}/'.format(self.dataset) 35 | self.index_in_epoch = 0 36 | self.epochs_completed = 0 37 | self.is_scale = is_scale 38 | self.is_balance = is_balance 39 | if self.is_balance: 40 | print('Balance dataloader') 41 | self.read_matdataset() 42 | self.get_idx_classes() 43 | self.I = torch.eye(self.allclasses.size(0)).to(device) 44 | 45 | def next_batch(self, batch_size): 46 | if self.is_balance: 47 | idx = [] 48 | n_samples_class = max(batch_size //self.ntrain_class,1) 49 | sampled_idx_c = np.random.choice(np.arange(self.ntrain_class),min(self.ntrain_class,batch_size),replace=False).tolist() 50 | for i_c in sampled_idx_c: 51 | idxs = self.idxs_list[i_c] 52 | idx.append(np.random.choice(idxs,n_samples_class)) 53 | idx = np.concatenate(idx) 54 | idx = torch.from_numpy(idx) 55 | else: 56 | idx = torch.randperm(self.ntrain)[0:batch_size] 57 | 58 | batch_feature = self.data['train_seen']['resnet_features'][idx].to(self.device) 59 | batch_label = self.data['train_seen']['labels'][idx].to(self.device) 60 | batch_att = self.att[batch_label].to(self.device) 61 | return batch_label, batch_feature, batch_att 62 | 63 | def get_idx_classes(self): 64 | n_classes = self.seenclasses.size(0) 65 | self.idxs_list = [] 66 | train_label = self.data['train_seen']['labels'] 67 | for i in range(n_classes): 68 | idx_c = torch.nonzero(train_label == self.seenclasses[i].cpu()).cpu().numpy() 69 | idx_c = np.squeeze(idx_c) 70 | self.idxs_list.append(idx_c) 71 | return self.idxs_list 72 | 73 | # def next_batch_mix_up(self,batch_size): 74 | # Y1,S1,_=self.next_batch(batch_size) 75 | # Y2,S2,_=self.next_batch(batch_size) 76 | # S,Y=mix_up(S1,S2,Y1,Y2) 77 | # return Y,S,None 78 | 79 | def read_matdataset(self): 80 | 81 | path= self.datadir + 'feature_map_ResNet_101_{}.hdf5'.format(self.dataset) 82 | print('_____') 83 | print(path) 84 | tic = time.clock() 85 | hf = h5py.File(path, 'r') 86 | features = np.array(hf.get('feature_map')) 87 | # shape = features.shape 88 | # features = features.reshape(shape[0],shape[1],shape[2]*shape[3]) 89 | labels = np.array(hf.get('labels')) 90 | trainval_loc = np.array(hf.get('trainval_loc')) 91 | # train_loc = np.array(hf.get('train_loc')) #--> train_feature = TRAIN SEEN 92 | # val_unseen_loc = np.array(hf.get('val_unseen_loc')) #--> 
test_unseen_feature = TEST UNSEEN 93 | test_seen_loc = np.array(hf.get('test_seen_loc')) 94 | test_unseen_loc = np.array(hf.get('test_unseen_loc')) 95 | 96 | 97 | print('Expert Attr') 98 | att = np.array(hf.get('att')) 99 | self.att = torch.from_numpy(att).float().to(self.device) 100 | 101 | original_att = np.array(hf.get('original_att')) 102 | self.original_att = torch.from_numpy(original_att).float().to(self.device) 103 | 104 | w2v_att = np.array(hf.get('w2v_att')) 105 | self.w2v_att = torch.from_numpy(w2v_att).float().to(self.device) 106 | 107 | self.normalize_att = self.original_att/100 108 | 109 | print('Finish loading data in ',time.clock()-tic) 110 | 111 | train_feature = features[trainval_loc] 112 | test_seen_feature = features[test_seen_loc] 113 | test_unseen_feature = features[test_unseen_loc] 114 | if self.is_scale: 115 | scaler = preprocessing.MinMaxScaler() 116 | 117 | train_feature = scaler.fit_transform(train_feature) 118 | test_seen_feature = scaler.fit_transform(test_seen_feature) 119 | test_unseen_feature = scaler.fit_transform(test_unseen_feature) 120 | 121 | train_feature = torch.from_numpy(train_feature).float() #.to(self.device) 122 | test_seen_feature = torch.from_numpy(test_seen_feature) #.float().to(self.device) 123 | test_unseen_feature = torch.from_numpy(test_unseen_feature) #.float().to(self.device) 124 | 125 | train_label = torch.from_numpy(labels[trainval_loc]).long() #.to(self.device) 126 | test_unseen_label = torch.from_numpy(labels[test_unseen_loc]) #.long().to(self.device) 127 | test_seen_label = torch.from_numpy(labels[test_seen_loc]) #.long().to(self.device) 128 | 129 | self.seenclasses = torch.from_numpy(np.unique(train_label.cpu().numpy())).to(self.device) 130 | self.unseenclasses = torch.from_numpy(np.unique(test_unseen_label.cpu().numpy())).to(self.device) 131 | self.ntrain = train_feature.size()[0] 132 | self.ntrain_class = self.seenclasses.size(0) 133 | self.ntest_class = self.unseenclasses.size(0) 134 | self.train_class = self.seenclasses.clone() 135 | self.allclasses = torch.arange(0, self.ntrain_class+self.ntest_class).long() 136 | 137 | # self.train_mapped_label = map_label(train_label, self.seenclasses) 138 | 139 | self.data = {} 140 | self.data['train_seen'] = {} 141 | self.data['train_seen']['resnet_features'] = train_feature 142 | self.data['train_seen']['labels']= train_label 143 | 144 | 145 | self.data['train_unseen'] = {} 146 | self.data['train_unseen']['resnet_features'] = None 147 | self.data['train_unseen']['labels'] = None 148 | 149 | self.data['test_seen'] = {} 150 | self.data['test_seen']['resnet_features'] = test_seen_feature 151 | self.data['test_seen']['labels'] = test_seen_label 152 | 153 | self.data['test_unseen'] = {} 154 | self.data['test_unseen']['resnet_features'] = test_unseen_feature 155 | self.data['test_unseen']['labels'] = test_unseen_label 156 | -------------------------------------------------------------------------------- /core/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hbdat/cvpr20_DAZLE/e5a04601d8368903008fd96dee3d95dde398aa51/core/__init__.py -------------------------------------------------------------------------------- /data/AWA2/instruction_AWA2.txt: -------------------------------------------------------------------------------- 1 | Please download and extract AwA2-data.zip (https://cvml.ist.ac.at/AwA2/AwA2-data.zip) into this folder in order to run ./extract_feature/extract_feature_map_ResNet_101_AWA2.py 
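For reference, a minimal Python sketch of this download-and-extract step (an illustrative addition, not one of the repository scripts). It assumes the command is run from the repository root, that the URL above is reachable, and that enough disk space is available for the archive; the folder layout expected by the extraction script may still require moving the extracted directory.

```python
# Hypothetical helper: fetch AwA2-data.zip and unpack it under ./data/AWA2/.
import os
import urllib.request
import zipfile

url = "https://cvml.ist.ac.at/AwA2/AwA2-data.zip"
dest_dir = "./data/AWA2/"
archive = os.path.join(dest_dir, "AwA2-data.zip")

os.makedirs(dest_dir, exist_ok=True)
urllib.request.urlretrieve(url, archive)   # single-stream download of the full archive
with zipfile.ZipFile(archive, "r") as zf:
    zf.extractall(dest_dir)                # unpack next to this instruction file
```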
-------------------------------------------------------------------------------- /data/CUB/instruction_CUB.txt: -------------------------------------------------------------------------------- 1 | Please download and extract CUB_200_2011.tgz (http://www.vision.caltech.edu/visipedia-data/CUB-200-2011/CUB_200_2011.tgz) into this folder in order to run ./extract_feature/extract_feature_map_ResNet_101_CUB.py -------------------------------------------------------------------------------- /data/DeepFashion/Anno/instruction_Anno.txt: -------------------------------------------------------------------------------- 1 | Please download and extract the whole Anno folder (https://drive.google.com/drive/folders/0B7EVK8r0v71pWWxJeGVqMjRkUVE) into this folder in order to run ./extract_feature/extract_annotation_DeepFashion.py -------------------------------------------------------------------------------- /data/DeepFashion/Eval/intruction_Eval.txt: -------------------------------------------------------------------------------- 1 | Please download and extract the whole Eval folder (https://drive.google.com/drive/folders/0B7EVK8r0v71pdDVIVGJpVFJOY0k) into this folder in order to run ./extract_feature/extract_annotation_DeepFashion.py -------------------------------------------------------------------------------- /data/DeepFashion/annotation.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hbdat/cvpr20_DAZLE/e5a04601d8368903008fd96dee3d95dde398aa51/data/DeepFashion/annotation.pkl -------------------------------------------------------------------------------- /data/DeepFashion/instruction_DeepFashion.txt: -------------------------------------------------------------------------------- 1 | Please download and extract img.zip (https://drive.google.com/drive/folders/0B7EVK8r0v71pekpRNUlMS3Z5cUk) into this folder in order to run ./extract_feature/extract_feature_map_ResNet_101_DeepFashion.py -------------------------------------------------------------------------------- /data/SUN/instruction_SUN.txt: -------------------------------------------------------------------------------- 1 | Please download and extract SUNAttributeDB_Images.tar.gz (http://cs.brown.edu/~gmpatter/Attributes/SUNAttributeDB_Images.tar.gz) into this folder in order to run ./extract_feature/extract_feature_map_ResNet_101_SUN.py -------------------------------------------------------------------------------- /data/standard_split/instruction_standard_split.txt: -------------------------------------------------------------------------------- 1 | Please download and extract standard_split.zip (http://datasets.d2.mpi-inf.mpg.de/xian/standard_split.zip) into this folder. This file contains the standard splits for CUB needed to run DAZLE_CUB_SS.ipynb -------------------------------------------------------------------------------- /data/xlsa17/instruction_xlsa17.txt: -------------------------------------------------------------------------------- 1 | Please download and extract xlsa17.zip (http://datasets.d2.mpi-inf.mpg.de/xian/xlsa17.zip) into this folder. 
This file contains the proposed splits for AWA2, CUB, SUN needed for all files in ./extract_feature/ -------------------------------------------------------------------------------- /extract_feature/extract_annotation_DeepFashion.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Created on Thu Aug 22 10:28:17 2019 4 | 5 | @author: Warmachine 6 | """ 7 | 8 | import pandas as pd 9 | from io import StringIO 10 | import numpy as np 11 | from global_setting_Local import NFS_path 12 | import pickle 13 | #%% 14 | import pdb 15 | #%% 16 | anno_dir = NFS_path+'data/DeepFashion/'+'Anno/' 17 | eval_dir = NFS_path+'data/DeepFashion/'+'Eval/' 18 | #%% 19 | unseen_classes = [12,15,2,19,29,30,26,44,42,48] 20 | #%% 21 | def convert_string_np(s,idx_e): 22 | if s is None: 23 | print(idx_e) 24 | list_int = [int(e) for e in s.split(' ') if e] 25 | return np.array(list_int).reshape((1,-1)) 26 | 27 | def read_text(file_path,sep): #Need to hardcode the logic that row is always tuple 28 | with open(file_path,'r') as f: 29 | rows = [] 30 | for idx_l,line in enumerate(f): 31 | line = line.strip() 32 | if idx_l == 0: 33 | n = int(line) 34 | elif idx_l == 1: 35 | columns = [e for e in line.split(sep) if e] 36 | else: 37 | es = [e for e in line.split(sep) if e] 38 | row = [] 39 | row.append(es[0]) #element 1 40 | row.append(' '.join(es[1:])) #element 2 41 | # row[-1] = int(row[-1]) 42 | rows.append(row) 43 | assert len(rows)==n 44 | df = pd.DataFrame(rows,columns =columns) 45 | return df 46 | #%% 47 | df_cat_name = read_text(anno_dir+'list_category_cloth.txt',sep = ' ') 48 | df_cat_name['category_type'] = df_cat_name['category_type'].astype(int) 49 | 50 | df_att_name = read_text(anno_dir+'list_attr_cloth.txt',sep = ' ') 51 | df_att_name['attribute_type'] = df_att_name['attribute_type'].astype(int) 52 | #%% 53 | df_labels = read_text(anno_dir+'list_category_img.txt',sep = ' ') 54 | df_labels['category_label'] = df_labels['category_label'].astype(int) 55 | df_attr_anno = read_text(anno_dir+'list_attr_img.txt',sep = ' ') 56 | df_split = read_text(eval_dir+'list_eval_partition.txt',sep = ' ') 57 | 58 | df_join = df_attr_anno.join(df_labels.set_index('image_name'),on='image_name') 59 | df_join = df_join.join(df_split.set_index('image_name'),on='image_name') 60 | #%% 61 | image_names = df_join['image_name'].values 62 | cat_names = df_cat_name['category_name'].values 63 | att_names = df_att_name['attribute_name'].values 64 | 65 | 66 | labels = df_join['category_label'].values 67 | labels = [convert_string_np(attr,idx_e) for idx_e,attr in enumerate(labels)] 68 | labels = np.concatenate(labels,0) 69 | 70 | attr_annos = df_join['attribute_labels'].values 71 | attr_annos = [convert_string_np(attr,idx_e) for idx_e,attr in enumerate(attr_annos)] 72 | attr_annos = np.concatenate(attr_annos,0) 73 | 74 | split = df_join['evaluation_status'].values 75 | #%% 76 | classes = np.unique(labels) 77 | att = np.zeros((len(df_cat_name),attr_annos.shape[1])) 78 | for c in classes: 79 | mask = np.squeeze(labels == c) 80 | att[c,:] = np.mean(attr_annos[mask],0) 81 | #%% 82 | freq = df_labels['category_label'].value_counts() 83 | df_cat_name = df_cat_name.join(freq) 84 | #%% 85 | df_cat_name.join(freq.to_frame()) 86 | #%% 87 | df_unseen_cat = df_cat_name.iloc[unseen_classes] 88 | #%% 89 | test_unseen_loc = np.array([i for i in range(labels.shape[0]) if labels[i] in unseen_classes]) 90 | test_seen_loc = np.array([i for i in range(labels.shape[0]) if (labels[i] not in 
unseen_classes) and (split[i]=='test')]) 91 | train_loc = np.array([i for i in range(labels.shape[0]) if (labels[i] not in unseen_classes) and (split[i]!='test')]) 92 | assert len(train_loc)+len(test_seen_loc)+len(test_unseen_loc) == len(split) 93 | #%% 94 | package = {'image_names':image_names,'cat_names':cat_names,'att_names':att_names,'train_loc':train_loc,'test_seen_loc':test_seen_loc,'test_unseen_loc':test_unseen_loc,'att':att,'labels':labels} 95 | pickle.dump(package, open( NFS_path+"data/DeepFashion/annotation.pkl", "wb" ) ) -------------------------------------------------------------------------------- /extract_feature/extract_attribute_w2v_AWA2.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | """ 4 | Created on Fri Jul 19 15:02:44 2019 5 | 6 | @author: war-machince 7 | """ 8 | 9 | import os,sys 10 | pwd = os.getcwd() 11 | sys.path.insert(0,pwd) 12 | #%% 13 | print('-'*30) 14 | print(os.getcwd()) 15 | print('-'*30) 16 | #%% 17 | import pdb 18 | import pandas as pd 19 | import numpy as np 20 | import gensim.downloader as api 21 | import scipy.io as sio 22 | import pickle 23 | #%% 24 | print('Loading pretrain w2v model') 25 | model_name = 'word2vec-google-news-300'#best model 26 | model = api.load(model_name) 27 | dim_w2v = 300 28 | print('Done loading model') 29 | #%% 30 | replace_word = [('newworld','new world'),('oldworld','old world'),('nestspot','nest spot'),('toughskin','tough skin'), 31 | ('longleg','long leg'),('chewteeth','chew teeth'),('meatteeth','meat teeth'),('strainteeth','strain teeth'), 32 | ('quadrapedal','quadrupedal')] 33 | dataset = 'AWA2' 34 | #%% 35 | path = './attribute/{}/predicates.txt'.format('AWA2') 36 | df=pd.read_csv(path,sep='\t',header = None, names = ['idx','des']) 37 | des = df['des'].values 38 | #%% filter 39 | #new_des = [' '.join(i.split('_')) for i in des] 40 | #new_des = [' '.join(i.split('-')) for i in new_des] 41 | #new_des = [' '.join(i.split('::')) for i in new_des] 42 | #new_des = [i.split('(')[0] for i in new_des] 43 | #new_des = [i[4:] for i in new_des] 44 | #%% replace out of dictionary words 45 | for pair in replace_word: 46 | for idx,s in enumerate(des): 47 | des[idx]=s.replace(pair[0],pair[1]) 48 | print('Done replace OOD words') 49 | #%% 50 | df['new_des']=des 51 | df.to_csv('./attribute/{}/new_des.csv'.format(dataset)) 52 | #print('Done preprocessing attribute des') 53 | #%% 54 | counter_err = 0 55 | all_w2v = [] 56 | for s in des: 57 | print(s) 58 | words = s.split(' ') 59 | if words[-1] == '': #remove empty element 60 | words = words[:-1] 61 | w2v = np.zeros(dim_w2v) 62 | for w in words: 63 | try: 64 | w2v += model[w] 65 | except Exception as e: 66 | print(e) 67 | counter_err += 1 68 | all_w2v.append(w2v[np.newaxis,:]) 69 | print('counter_err ',counter_err) 70 | #%% 71 | all_w2v=np.concatenate(all_w2v,axis=0) 72 | pdb.set_trace() 73 | #%% 74 | with open('./w2v/{}_attribute.pkl'.format(dataset),'wb') as f: 75 | pickle.dump(all_w2v,f) 76 | -------------------------------------------------------------------------------- /extract_feature/extract_attribute_w2v_CUB.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Created on Thu Jul 4 13:43:05 2019 4 | 5 | @author: badat 6 | """ 7 | import os,sys 8 | pwd = os.getcwd() 9 | sys.path.insert(0,pwd) 10 | #%% 11 | print('-'*30) 12 | print(os.getcwd()) 13 | print('-'*30) 14 | #%% 15 | import pdb 16 | import pandas as pd 17 | 
import numpy as np 18 | import gensim.downloader as api 19 | import pickle 20 | #%% 21 | print('Loading pretrain w2v model') 22 | model_name = 'word2vec-google-news-300'#best model 23 | model = api.load(model_name) 24 | dim_w2v = 300 25 | print('Done loading model') 26 | #%% 27 | replace_word = [('spatulate','broad'),('upperparts','upper parts'),('grey','gray')] 28 | #%% 29 | path = './attribute/CUB/attributes.txt' 30 | df=pd.read_csv(path,sep=' ',header = None, names = ['idx','des']) 31 | des = df['des'].values 32 | #%% filter 33 | new_des = [' '.join(i.split('_')) for i in des] 34 | new_des = [' '.join(i.split('-')) for i in new_des] 35 | new_des = [' '.join(i.split('::')) for i in new_des] 36 | new_des = [i.split('(')[0] for i in new_des] 37 | new_des = [i[4:] for i in new_des] 38 | #%% replace out of dictionary words 39 | for pair in replace_word: 40 | for idx,s in enumerate(new_des): 41 | new_des[idx]=s.replace(pair[0],pair[1]) 42 | print('Done replace OOD words') 43 | #%% 44 | df['new_des']=new_des 45 | df.to_csv('./attribute/CUB/new_des.csv') 46 | print('Done preprocessing attribute des') 47 | #%% 48 | all_w2v = [] 49 | for s in new_des: 50 | print(s) 51 | words = s.split(' ') 52 | if words[-1] == '': #remove empty element 53 | words = words[:-1] 54 | w2v = np.zeros(dim_w2v) 55 | for w in words: 56 | try: 57 | w2v += model[w] 58 | except Exception as e: 59 | print(e) 60 | all_w2v.append(w2v[np.newaxis,:]) 61 | #%% 62 | all_w2v=np.concatenate(all_w2v,axis=0) 63 | pdb.set_trace() 64 | #%% 65 | with open('./w2v/CUB_attribute.pkl','wb') as f: 66 | pickle.dump(all_w2v,f) -------------------------------------------------------------------------------- /extract_feature/extract_attribute_w2v_DeepFashion.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | """ 4 | Created on Thu Aug 22 17:25:09 2019 5 | 6 | @author: war-machince 7 | """ 8 | 9 | import os,sys 10 | pwd = os.getcwd() 11 | sys.path.insert(0,pwd) 12 | #%% 13 | print('-'*30) 14 | print(os.getcwd()) 15 | print('-'*30) 16 | #%% 17 | import pdb 18 | import pandas as pd 19 | import numpy as np 20 | import gensim.downloader as api 21 | import scipy.io as sio 22 | import pickle 23 | from global_setting_Pegasus import NFS_path 24 | #%% 25 | print('Loading pretrain w2v model') 26 | model_name = 'word2vec-google-news-300'#best model 27 | model = api.load(model_name) 28 | dim_w2v = 300 29 | print('Done loading model') 30 | #%% 31 | replace_word = [('-',' '),('eiffel','Eiffel')] 32 | dataset = 'DeepFashion' 33 | #%% 34 | path = os.path.join(NFS_path,'data/{}/annotation.pkl'.format(dataset)) 35 | package=pickle.load(open(path,'rb')) 36 | des = package['att_names'] 37 | #%% filter 38 | #new_des = [' '.join(i.split('_')) for i in des] 39 | #new_des = [' '.join(i.split('-')) for i in new_des] 40 | #new_des = [' '.join(i.split('::')) for i in new_des] 41 | #new_des = [i.split('(')[0] for i in new_des] 42 | #new_des = [i[4:] for i in new_des] 43 | #%% replace out of dictionary words 44 | for pair in replace_word: 45 | for idx,s in enumerate(des): 46 | des[idx]=s.replace(pair[0],pair[1]) 47 | print('Done replace OOD words') 48 | #%% 49 | #df['new_des']=new_des 50 | #df.to_csv('./attribute/CUB/new_des.csv') 51 | #print('Done preprocessing attribute des') 52 | #%% 53 | counter_err = 0 54 | all_w2v = [] 55 | for s in des: 56 | # print(s) 57 | words = s.split(' ') 58 | if words[-1] == '': #remove empty element 59 | words = words[:-1] 60 | w2v = np.zeros(dim_w2v) 
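    # Each attribute embedding is the sum of word2vec vectors over the tokens of the phrase;
    # tokens missing from the vocabulary raise an exception that is printed and tallied in counter_err,
    # while the remaining tokens still contribute to the sum.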
61 | for w in words: 62 | try: 63 | w2v += model[w] 64 | except Exception as e: 65 | print(e) 66 | counter_err += 1 67 | all_w2v.append(w2v[np.newaxis,:]) 68 | print('counter_err ',counter_err) 69 | #%% 70 | all_w2v=np.concatenate(all_w2v,axis=0) 71 | pdb.set_trace() 72 | #%% 73 | with open('./w2v/{}_attribute.pkl'.format(dataset),'wb') as f: 74 | pickle.dump(all_w2v,f) 75 | -------------------------------------------------------------------------------- /extract_feature/extract_attribute_w2v_SUN.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | """ 4 | Created on Tue Jul 30 12:27:57 2019 5 | 6 | @author: war-machince 7 | """ 8 | 9 | import os,sys 10 | pwd = os.getcwd() 11 | sys.path.insert(0,pwd) 12 | #%% 13 | print('-'*30) 14 | print(os.getcwd()) 15 | print('-'*30) 16 | #%% 17 | import pdb 18 | import pandas as pd 19 | import numpy as np 20 | import gensim.downloader as api 21 | import scipy.io as sio 22 | import pickle 23 | #%% 24 | dataset = 'SUN' 25 | #%% 26 | print('Loading pretrain w2v model') 27 | model_name = 'word2vec-google-news-300'#best model 28 | model = api.load(model_name) 29 | dim_w2v = 300 30 | print('Done loading model') 31 | #%% 32 | replace_word = [('rockstone','rock stone'),('dirtsoil','dirt soil'),('man-made','man-made'),('sunsunny','sun sunny'), 33 | ('electricindoor','electric indoor'),('semi-enclosed','semi enclosed'),('far-away','faraway')] 34 | #%% 35 | file_path = './attribute/{}/attributes.mat'.format(dataset) 36 | matcontent = sio.loadmat(file_path) 37 | des = matcontent['attributes'].flatten() 38 | #%% 39 | df = pd.DataFrame() 40 | #%% filter 41 | new_des = [''.join(i.item().split('/')) for i in des] 42 | #%% replace out of dictionary words 43 | for pair in replace_word: 44 | for idx,s in enumerate(new_des): 45 | new_des[idx]=s.replace(pair[0],pair[1]) 46 | print('Done replace OOD words') 47 | #%% 48 | df['new_des']=new_des 49 | df.to_csv('./attribute/{}/new_des.csv'.format(dataset)) 50 | print('Done preprocessing attribute des') 51 | #%% 52 | all_w2v = [] 53 | for s in new_des: 54 | print(s) 55 | words = s.split(' ') 56 | if words[-1] == '': #remove empty element 57 | words = words[:-1] 58 | w2v = np.zeros(dim_w2v) 59 | for w in words: 60 | try: 61 | w2v += model[w] 62 | except Exception as e: 63 | print(e) 64 | all_w2v.append(w2v[np.newaxis,:]) 65 | #%% 66 | all_w2v=np.concatenate(all_w2v,axis=0) 67 | pdb.set_trace() 68 | #%% 69 | with open('./w2v/{}_attribute.pkl'.format(dataset),'wb') as f: 70 | pickle.dump(all_w2v,f) -------------------------------------------------------------------------------- /extract_feature/extract_feature_map_ResNet_101_AWA2.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | """ 4 | Created on Fri Jul 19 14:45:03 2019 5 | 6 | @author: war-machince 7 | """ 8 | 9 | import os,sys 10 | pwd = os.getcwd() 11 | sys.path.insert(0,pwd) 12 | #%% 13 | print('-'*30) 14 | print(os.getcwd()) 15 | print('-'*30) 16 | #%% 17 | import torch 18 | import torchvision 19 | import torch.nn as nn 20 | from torchvision import datasets, transforms 21 | from torch.utils.data import Dataset, DataLoader 22 | import torchvision.models.resnet as models 23 | from PIL import Image 24 | import h5py 25 | import numpy as np 26 | import scipy.io as sio 27 | import pickle 28 | from global_setting_Pegasus import NFS_path 29 | 30 | #%% 31 | import pdb 32 | #%% 33 | idx_GPU = 6 34 | 
is_save = True 35 | dataset = 'AWA2' 36 | input('is_save {}'.format(is_save)) 37 | #%% 38 | print("PyTorch Version: ",torch.__version__) 39 | print("Torchvision Version: ",torchvision.__version__) 40 | #%% 41 | img_dir = os.path.join(NFS_path,'data/{}/'.format(dataset)) 42 | file_paths = os.path.join(NFS_path,'data/xlsa17/data/{}/res101.mat'.format(dataset)) 43 | save_path = os.path.join(NFS_path,'data/{}/feature_map_ResNet_101_{}.hdf5'.format(dataset,dataset)) 44 | attribute_path = './w2v/{}_attribute.pkl'.format(dataset) 45 | #pdb.set_trace() 46 | # Models to choose from [resnet, alexnet, vgg, squeezenet, densenet, inception] 47 | model_name = "resnet" 48 | 49 | # Batch size for training (change depending on how much memory you have) 50 | batch_size = 32 51 | 52 | device = torch.device("cuda:{}".format(idx_GPU) if torch.cuda.is_available() else "cpu") 53 | #%% 54 | 55 | model_ref = models.resnet101(pretrained=True) 56 | model_ref.eval() 57 | 58 | model_f = nn.Sequential(*list(model_ref.children())[:-2]) 59 | model_f.to(device) 60 | model_f.eval() 61 | 62 | for param in model_f.parameters(): 63 | param.requires_grad = False 64 | #%% 65 | class CustomedDataset(Dataset): 66 | """Face Landmarks dataset.""" 67 | 68 | def __init__(self, img_dir , file_paths, transform=None): 69 | self.matcontent = sio.loadmat(file_paths) 70 | self.image_files = np.squeeze(self.matcontent['image_files']) 71 | self.img_dir = img_dir 72 | self.transform = transform 73 | 74 | def __len__(self): 75 | return len(self.image_files) 76 | 77 | def __getitem__(self, idx): 78 | image_file = self.image_files[idx][0] 79 | image_file = os.path.join(self.img_dir, 80 | '/'.join(image_file.split('/')[5:])) 81 | image = Image.open(image_file) 82 | if image.mode == 'L': 83 | image=image.convert('RGB') 84 | if self.transform: 85 | image = self.transform(image) 86 | return image 87 | 88 | #%% 89 | input_size = 224 90 | data_transforms = transforms.Compose([ 91 | transforms.Resize(input_size), 92 | transforms.CenterCrop(input_size), 93 | transforms.ToTensor(), 94 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) 95 | ]) 96 | 97 | AWA2Dataset = CustomedDataset(img_dir , file_paths, data_transforms) 98 | dataset_loader = torch.utils.data.DataLoader(AWA2Dataset, 99 | batch_size=batch_size, shuffle=False, 100 | num_workers=4) 101 | #%% 102 | #with torch.no_grad(): 103 | all_features = [] 104 | for i_batch, imgs in enumerate(dataset_loader): 105 | print(i_batch) 106 | #pdb.set_trace() 107 | imgs=imgs.to(device) 108 | features = model_f(imgs) 109 | all_features.append(features.cpu().numpy()) 110 | all_features = np.concatenate(all_features,axis=0) 111 | #%% get remaining metadata 112 | matcontent = AWA2Dataset.matcontent 113 | labels = matcontent['labels'].astype(int).squeeze() - 1 114 | 115 | split_path = os.path.join(NFS_path,'data/xlsa17/data/{}/att_splits.mat'.format(dataset)) 116 | matcontent = sio.loadmat(split_path) 117 | trainval_loc = matcontent['trainval_loc'].squeeze() - 1 118 | #train_loc = matcontent['train_loc'].squeeze() - 1 #--> train_feature = TRAIN SEEN 119 | #val_unseen_loc = matcontent['val_loc'].squeeze() - 1 #--> test_unseen_feature = TEST UNSEEN 120 | test_seen_loc = matcontent['test_seen_loc'].squeeze() - 1 121 | test_unseen_loc = matcontent['test_unseen_loc'].squeeze() - 1 122 | att = matcontent['att'].T 123 | original_att = matcontent['original_att'].T 124 | #%% construct attribute w2v 125 | with open(attribute_path,'rb') as f: 126 | w2v_att = pickle.load(f) 127 | assert w2v_att.shape ==
(85,300) 128 | print('save w2v_att') 129 | #%% 130 | if is_save: 131 | f = h5py.File(save_path, "w") 132 | f.create_dataset('feature_map', data=all_features,compression="gzip") 133 | f.create_dataset('labels', data=labels,compression="gzip") 134 | f.create_dataset('trainval_loc', data=trainval_loc,compression="gzip") 135 | # f.create_dataset('train_loc', data=train_loc,compression="gzip") 136 | # f.create_dataset('val_unseen_loc', data=val_unseen_loc,compression="gzip") 137 | f.create_dataset('test_seen_loc', data=test_seen_loc,compression="gzip") 138 | f.create_dataset('test_unseen_loc', data=test_unseen_loc,compression="gzip") 139 | f.create_dataset('att', data=att,compression="gzip") 140 | f.create_dataset('original_att', data=original_att,compression="gzip") 141 | f.create_dataset('w2v_att', data=w2v_att,compression="gzip") 142 | f.close() -------------------------------------------------------------------------------- /extract_feature/extract_feature_map_ResNet_101_CUB.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | """ 4 | Created on Wed Jul 3 12:03:15 2019 5 | 6 | @author: war-machince 7 | """ 8 | 9 | import os,sys 10 | pwd = os.getcwd() 11 | sys.path.insert(0,pwd) 12 | #%% 13 | print('-'*30) 14 | print(os.getcwd()) 15 | print('-'*30) 16 | #%% 17 | import torch 18 | import torchvision 19 | import torch.nn as nn 20 | from torchvision import datasets, transforms 21 | from torch.utils.data import Dataset, DataLoader 22 | import torchvision.models.resnet as models 23 | from PIL import Image 24 | import h5py 25 | import numpy as np 26 | import scipy.io as sio 27 | import pickle 28 | from global_setting_Pegasus import NFS_path 29 | 30 | #%% 31 | #import pdb 32 | #%% 33 | idx_GPU = 6 34 | is_save = True 35 | #%% 36 | print("PyTorch Version: ",torch.__version__) 37 | print("Torchvision Version: ",torchvision.__version__) 38 | #%% 39 | img_dir = os.path.join(NFS_path,'data/CUB/') 40 | file_paths = os.path.join(NFS_path,'data/xlsa17/data/CUB/res101.mat') 41 | save_path = os.path.join(NFS_path,'data/CUB/feature_map_ResNet_101_CUB.hdf5') 42 | attribute_path = './w2v/CUB_attribute.pkl' 43 | #pdb.set_trace() 44 | # Models to choose from [resnet, alexnet, vgg, squeezenet, densenet, inception] 45 | model_name = "resnet" 46 | 47 | # Batch size for training (change depending on how much memory you have) 48 | batch_size = 32 49 | 50 | device = torch.device("cuda:{}".format(idx_GPU) if torch.cuda.is_available() else "cpu") 51 | #%% 52 | 53 | model_ref = models.resnet101(pretrained=True) 54 | model_ref.eval() 55 | 56 | model_f = nn.Sequential(*list(model_ref.children())[:-2]) 57 | model_f.to(device) 58 | model_f.eval() 59 | 60 | for param in model_f.parameters(): 61 | param.requires_grad = False 62 | #%% 63 | class CustomedDataset(Dataset): 64 | """Face Landmarks dataset.""" 65 | 66 | def __init__(self, img_dir , file_paths, transform=None): 67 | self.matcontent = sio.loadmat(file_paths) 68 | self.image_files = np.squeeze(self.matcontent['image_files']) 69 | self.img_dir = img_dir 70 | self.transform = transform 71 | 72 | def __len__(self): 73 | return len(self.image_files) 74 | 75 | def __getitem__(self, idx): 76 | image_file = self.image_files[idx][0] 77 | image_file = os.path.join(self.img_dir, 78 | '/'.join(image_file.split('/')[5:])) 79 | image = Image.open(image_file) 80 | if image.mode == 'L': 81 | image=image.convert('RGB') 82 | if self.transform: 83 | image = self.transform(image) 84 | return image 
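    # Note: __getitem__ returns only the image tensor; labels are read later from res101.mat in the
    # same file order, so the DataLoader below must keep shuffle=False for features and labels to stay aligned.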
85 | 86 | #%% 87 | input_size = 224 88 | data_transforms = transforms.Compose([ 89 | transforms.Resize(input_size), 90 | transforms.CenterCrop(input_size), 91 | transforms.ToTensor(), 92 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) 93 | ]) 94 | 95 | CUBDataset = CustomedDataset(img_dir , file_paths, data_transforms) 96 | dataset_loader = torch.utils.data.DataLoader(CUBDataset, 97 | batch_size=batch_size, shuffle=False, 98 | num_workers=4) 99 | #%% 100 | #with torch.no_grad(): 101 | all_features = [] 102 | for i_batch, imgs in enumerate(dataset_loader): 103 | print(i_batch) 104 | imgs=imgs.to(device) 105 | features = model_f(imgs) 106 | all_features.append(features.cpu().numpy()) 107 | all_features = np.concatenate(all_features,axis=0) 108 | #%% get remaining metadata 109 | matcontent = CUBDataset.matcontent 110 | labels = matcontent['labels'].astype(int).squeeze() - 1 111 | 112 | split_path = os.path.join(NFS_path,'data/xlsa17/data/CUB/att_splits.mat') 113 | matcontent = sio.loadmat(split_path) 114 | trainval_loc = matcontent['trainval_loc'].squeeze() - 1 115 | #train_loc = matcontent['train_loc'].squeeze() - 1 #--> train_feature = TRAIN SEEN 116 | #val_unseen_loc = matcontent['val_loc'].squeeze() - 1 #--> test_unseen_feature = TEST UNSEEN 117 | test_seen_loc = matcontent['test_seen_loc'].squeeze() - 1 118 | test_unseen_loc = matcontent['test_unseen_loc'].squeeze() - 1 119 | att = matcontent['att'].T 120 | original_att = matcontent['original_att'].T 121 | #%% construct attribute w2v 122 | with open(attribute_path,'rb') as f: 123 | w2v_att = pickle.load(f) 124 | assert w2v_att.shape == (312,300) 125 | print('save w2v_att') 126 | #%% 127 | if is_save: 128 | f = h5py.File(save_path, "w") 129 | f.create_dataset('feature_map', data=all_features,compression="gzip") 130 | f.create_dataset('labels', data=labels,compression="gzip") 131 | f.create_dataset('trainval_loc', data=trainval_loc,compression="gzip") 132 | # f.create_dataset('train_loc', data=train_loc,compression="gzip") 133 | # f.create_dataset('val_unseen_loc', data=val_unseen_loc,compression="gzip") 134 | f.create_dataset('test_seen_loc', data=test_seen_loc,compression="gzip") 135 | f.create_dataset('test_unseen_loc', data=test_unseen_loc,compression="gzip") 136 | f.create_dataset('att', data=att,compression="gzip") 137 | f.create_dataset('original_att', data=original_att,compression="gzip") 138 | f.create_dataset('w2v_att', data=w2v_att,compression="gzip") 139 | f.close() -------------------------------------------------------------------------------- /extract_feature/extract_feature_map_ResNet_101_DeepFashion.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | """ 4 | Created on Thu Aug 22 17:14:16 2019 5 | 6 | @author: war-machince 7 | """ 8 | 9 | import os,sys 10 | pwd = os.getcwd() 11 | sys.path.insert(0,pwd) 12 | #%% 13 | print('-'*30) 14 | print(os.getcwd()) 15 | print('-'*30) 16 | #%% 17 | import torch 18 | import torchvision 19 | import torch.nn as nn 20 | from torchvision import datasets, transforms 21 | from torch.utils.data import Dataset, DataLoader 22 | import torchvision.models.resnet as models 23 | from PIL import Image 24 | import h5py 25 | import numpy as np 26 | import scipy.io as sio 27 | import pickle 28 | from global_setting_Pegasus import NFS_path 29 | 30 | #%% 31 | import pdb 32 | #%% 33 | idx_GPU = 4 34 | is_save = True 35 | dataset = 'DeepFashion' 36 | #%% 37 | print("PyTorch Version: 
",torch.__version__) 38 | print("Torchvision Version: ",torchvision.__version__) 39 | #%% 40 | img_dir = os.path.join(NFS_path,'data/{}/'.format(dataset)) 41 | file_paths = os.path.join(NFS_path,'data/{}/annotation.pkl'.format(dataset)) 42 | save_path = os.path.join(NFS_path,'data/{}/feature_map_ResNet_101_{}_sep_seen_samples.hdf5'.format(dataset,dataset)) 43 | attribute_path = './w2v/{}_attribute.pkl'.format(dataset) 44 | #pdb.set_trace() 45 | # Models to choose from [resnet, alexnet, vgg, squeezenet, densenet, inception] 46 | model_name = "resnet" 47 | 48 | # Batch size for training (change depending on how much memory you have) 49 | batch_size = 32 50 | 51 | device = torch.device("cuda:{}".format(idx_GPU) if torch.cuda.is_available() else "cpu") 52 | #%% 53 | 54 | model_ref = models.resnet101(pretrained=True) 55 | model_ref.eval() 56 | 57 | model_f = nn.Sequential(*list(model_ref.children())[:-2]) 58 | model_f.to(device) 59 | model_f.eval() 60 | 61 | for param in model_f.parameters(): 62 | param.requires_grad = False 63 | #%% 64 | class CustomedDataset(Dataset): 65 | """Face Landmarks dataset.""" 66 | 67 | def __init__(self, img_dir , file_paths, transform=None): 68 | self.package = pickle.load(open(file_paths,'rb')) 69 | self.image_files = self.package['image_names'] 70 | self.att = self.package['att'] 71 | self.img_dir = img_dir 72 | self.transform = transform 73 | 74 | def __len__(self): 75 | return len(self.image_files) 76 | 77 | def __getitem__(self, idx): 78 | image_file = self.image_files[idx] 79 | image_file = os.path.join(self.img_dir,image_file) 80 | image = Image.open(image_file) 81 | if image.mode == 'L': 82 | image=image.convert('RGB') 83 | if self.transform: 84 | image = self.transform(image) 85 | return image 86 | 87 | #%% 88 | input_size = 224 89 | data_transforms = transforms.Compose([ 90 | transforms.Resize(input_size), 91 | transforms.CenterCrop(input_size), 92 | transforms.ToTensor(), 93 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) 94 | ]) 95 | 96 | dataset = CustomedDataset(img_dir , file_paths, data_transforms) 97 | dataset_loader = torch.utils.data.DataLoader(dataset, 98 | batch_size=batch_size, shuffle=False, 99 | num_workers=4) 100 | #%% 101 | with torch.no_grad(): 102 | all_features = [] 103 | for i_batch, imgs in enumerate(dataset_loader): 104 | print(i_batch) 105 | imgs=imgs.to(device) 106 | features = model_f(imgs) 107 | all_features.append(features.cpu().numpy()) 108 | all_features = np.concatenate(all_features,axis=0) 109 | #%% get remaining metadata 110 | labels = dataset.package['labels'].squeeze() 111 | 112 | train_loc = dataset.package['train_loc'].squeeze()#--> train_feature = TRAIN SEEN 113 | test_seen_loc = dataset.package['test_seen_loc'].squeeze() 114 | test_unseen_loc = dataset.package['test_unseen_loc'].squeeze() 115 | att = dataset.att 116 | #%% 117 | features_train = all_features[train_loc] 118 | label_train = labels[train_loc] 119 | img_train = dataset.image_files[train_loc].astype('S') 120 | 121 | features_test_seen = all_features[test_seen_loc] 122 | label_test_seen = labels[test_seen_loc] 123 | img_test_seen = dataset.image_files[test_seen_loc].astype('S') 124 | 125 | features_test_unseen = all_features[test_unseen_loc] 126 | label_test_unseen = labels[test_unseen_loc] 127 | img_test_unseen = dataset.image_files[test_unseen_loc].astype('S') 128 | #%% construct attribute w2v 129 | with open(attribute_path,'rb') as f: 130 | w2v_att = pickle.load(f) 131 | assert w2v_att.shape == (1000,300) 132 | print('load w2v_att') 
133 | #%% 134 | string_dt = h5py.special_dtype(vlen=str) 135 | if is_save: 136 | f = h5py.File(save_path, "w") 137 | 138 | # f.create_dataset('feature_map_train', data=features_train,compression="gzip") 139 | # f.create_dataset('label_train', data=label_train,compression="gzip") 140 | # f.create_dataset('img_train', data=img_train,dtype = string_dt,compression="gzip") 141 | pdb.set_trace() 142 | unique_label_train = np.squeeze(np.unique(label_train)) 143 | for l in unique_label_train: 144 | mask_l = label_train == l 145 | print(l,np.sum(mask_l)) 146 | f.create_dataset('feature_map_train/{}'.format(l), data=features_train[mask_l],compression="gzip") 147 | f.create_dataset('label_train/{}'.format(l), data=label_train[mask_l],compression="gzip") 148 | f.create_dataset('img_train/{}'.format(l), data=img_train[mask_l],dtype = string_dt,compression="gzip") 149 | 150 | f.create_dataset('feature_map_test_seen', data=features_test_seen,compression="gzip") 151 | f.create_dataset('label_test_seen', data=label_test_seen,compression="gzip") 152 | f.create_dataset('img_test_seen', data=img_test_seen,dtype = string_dt,compression="gzip") 153 | 154 | f.create_dataset('feature_map_test_unseen', data=features_test_unseen,compression="gzip") 155 | f.create_dataset('label_test_unseen', data=label_test_unseen,compression="gzip") 156 | f.create_dataset('img_test_unseen', data=img_test_unseen,dtype = string_dt,compression="gzip") 157 | 158 | f.create_dataset('labels', data=labels,compression="gzip") 159 | f.create_dataset('imgs', data=dataset.image_files,dtype = string_dt,compression="gzip") 160 | f.create_dataset('train_loc', data=train_loc,compression="gzip") 161 | f.create_dataset('test_seen_loc', data=test_seen_loc,compression="gzip") 162 | f.create_dataset('test_unseen_loc', data=test_unseen_loc,compression="gzip") 163 | f.create_dataset('att', data=att,compression="gzip") 164 | f.create_dataset('w2v_att', data=w2v_att,compression="gzip") 165 | f.close() 166 | pdb.set_trace() -------------------------------------------------------------------------------- /extract_feature/extract_feature_map_ResNet_101_SUN.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | """ 4 | Created on Tue Jul 30 12:11:07 2019 5 | 6 | @author: war-machince 7 | """ 8 | 9 | import os,sys 10 | pwd = os.getcwd() 11 | sys.path.insert(0,pwd) 12 | #%% 13 | print('-'*30) 14 | print(os.getcwd()) 15 | print('-'*30) 16 | #%% 17 | import torch 18 | import torchvision 19 | import torch.nn as nn 20 | from torchvision import datasets, transforms 21 | from torch.utils.data import Dataset, DataLoader 22 | import torchvision.models.resnet as models 23 | from PIL import Image 24 | import h5py 25 | import numpy as np 26 | import scipy.io as sio 27 | import pickle 28 | from global_setting_Pegasus import NFS_path 29 | 30 | #%% 31 | import pdb 32 | #%% 33 | idx_GPU = 6 34 | is_save = True 35 | dataset = 'SUN' 36 | #%% 37 | print("PyTorch Version: ",torch.__version__) 38 | print("Torchvision Version: ",torchvision.__version__) 39 | #%% 40 | img_dir = os.path.join(NFS_path,'data/{}/'.format(dataset)) 41 | file_paths = os.path.join(NFS_path,'data/xlsa17/data/{}/res101.mat'.format(dataset)) 42 | save_path = os.path.join(NFS_path,'data/{}/feature_map_ResNet_101_{}.hdf5'.format(dataset,dataset)) 43 | attribute_path = './w2v/{}_attribute.pkl'.format(dataset) 44 | #pdb.set_trace() 45 | # Models to choose from [resnet, alexnet, vgg, squeezenet, densenet, inception] 46 | 
model_name = "resnet" 47 | 48 | # Batch size for training (change depending on how much memory you have) 49 | batch_size = 1 50 | 51 | device = torch.device("cuda:{}".format(idx_GPU) if torch.cuda.is_available() else "cpu") 52 | #%% 53 | 54 | model_ref = models.resnet101(pretrained=True) 55 | model_ref.eval() 56 | 57 | model_f = nn.Sequential(*list(model_ref.children())[:-2]) 58 | model_f.to(device) 59 | model_f.eval() 60 | 61 | for param in model_f.parameters(): 62 | param.requires_grad = False 63 | #%% 64 | class CustomedDataset(Dataset): 65 | """Face Landmarks dataset.""" 66 | 67 | def __init__(self, img_dir , file_paths, transform=None): 68 | self.matcontent = sio.loadmat(file_paths) 69 | self.image_files = np.squeeze(self.matcontent['image_files']) 70 | self.img_dir = img_dir 71 | self.transform = transform 72 | 73 | def __len__(self): 74 | return len(self.image_files) 75 | 76 | def __getitem__(self, idx): 77 | image_file = self.image_files[idx][0] 78 | image_file = os.path.join(self.img_dir, 79 | '/'.join(image_file.split('/')[7:])) 80 | 81 | image = Image.open(image_file) 82 | if image.mode == 'L': 83 | image=image.convert('RGB') 84 | if self.transform: 85 | image = self.transform(image) 86 | return image,image_file 87 | 88 | #%% 89 | input_size = 224 90 | data_transforms = transforms.Compose([ 91 | transforms.Resize(input_size), 92 | transforms.CenterCrop(input_size), 93 | transforms.ToTensor(), 94 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) 95 | ]) 96 | 97 | AWA2Dataset = CustomedDataset(img_dir , file_paths, data_transforms) 98 | dataset_loader = torch.utils.data.DataLoader(AWA2Dataset, 99 | batch_size=batch_size, shuffle=False, 100 | num_workers=4) 101 | #%% 102 | #with torch.no_grad(): 103 | all_features = [] 104 | error_files = [] 105 | for i_batch, package in enumerate(dataset_loader): 106 | print(i_batch) 107 | imgs,image_files=package 108 | imgs=imgs.to(device) 109 | 110 | print(imgs.size(1)) 111 | if imgs.size(1) != 3: 112 | print('Error') 113 | features = torch.zeros((1,2048,7,7)) 114 | error_files.append(image_files) 115 | else: 116 | features = model_f(imgs) 117 | 118 | all_features.append(features.cpu().numpy()) 119 | 120 | print('err_counter {}'.format(error_files)) 121 | all_features = np.concatenate(all_features,axis=0) 122 | #%% get remaining metadata 123 | matcontent = AWA2Dataset.matcontent 124 | labels = matcontent['labels'].astype(int).squeeze() - 1 125 | 126 | split_path = os.path.join(NFS_path,'data/xlsa17/data/{}/att_splits.mat'.format(dataset)) 127 | matcontent = sio.loadmat(split_path) 128 | trainval_loc = matcontent['trainval_loc'].squeeze() - 1 129 | #train_loc = matcontent['train_loc'].squeeze() - 1 #--> train_feature = TRAIN SEEN 130 | #val_unseen_loc = matcontent['val_loc'].squeeze() - 1 #--> test_unseen_feature = TEST UNSEEN 131 | test_seen_loc = matcontent['test_seen_loc'].squeeze() - 1 132 | test_unseen_loc = matcontent['test_unseen_loc'].squeeze() - 1 133 | att = matcontent['att'].T 134 | original_att = matcontent['original_att'].T 135 | #%% construct attribute w2v 136 | with open(attribute_path,'rb') as f: 137 | w2v_att = pickle.load(f) 138 | assert w2v_att.shape == (102,300) 139 | print('save w2v_att') 140 | #%% 141 | if is_save: 142 | f = h5py.File(save_path, "w") 143 | f.create_dataset('feature_map', data=all_features,compression="gzip") 144 | f.create_dataset('labels', data=labels,compression="gzip") 145 | f.create_dataset('trainval_loc', data=trainval_loc,compression="gzip") 146 | # f.create_dataset('train_loc', 
data=train_loc,compression="gzip") 147 | # f.create_dataset('val_unseen_loc', data=val_unseen_loc,compression="gzip") 148 | f.create_dataset('test_seen_loc', data=test_seen_loc,compression="gzip") 149 | f.create_dataset('test_unseen_loc', data=test_unseen_loc,compression="gzip") 150 | f.create_dataset('att', data=att,compression="gzip") 151 | f.create_dataset('original_att', data=original_att,compression="gzip") 152 | f.create_dataset('w2v_att', data=w2v_att,compression="gzip") 153 | f.close() -------------------------------------------------------------------------------- /fig/high_level_schematic.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hbdat/cvpr20_DAZLE/e5a04601d8368903008fd96dee3d95dde398aa51/fig/high_level_schematic.png -------------------------------------------------------------------------------- /global_setting.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | """ 4 | Created on Wed Jul 3 18:59:49 2019 5 | 6 | @author: war-machince 7 | """ 8 | 9 | NFS_path = './' -------------------------------------------------------------------------------- /notebook/.ipynb_checkpoints/DAZLE_AWA2-checkpoint.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": {}, 7 | "outputs": [ 8 | { 9 | "name": "stdout", 10 | "output_type": "stream", 11 | "text": [ 12 | "------------------------------\n", 13 | "/home/project_amadeus/home/hbdat/[RELEASE]_DenseAttentionZSL\n", 14 | "------------------------------\n" 15 | ] 16 | } 17 | ], 18 | "source": [ 19 | "import os,sys\n", 20 | "pwd = os.getcwd()\n", 21 | "parent = '/'.join(pwd.split('/')[:-1])\n", 22 | "sys.path.insert(0,parent)\n", 23 | "os.chdir(parent)\n", 24 | "#%%\n", 25 | "print('-'*30)\n", 26 | "print(os.getcwd())\n", 27 | "print('-'*30)\n", 28 | "#%%\n", 29 | "import torch\n", 30 | "import torch.optim as optim\n", 31 | "import torch.nn as nn\n", 32 | "import pandas as pd\n", 33 | "from core.DAZLE import DAZLE\n", 34 | "from core.AWA2DataLoader import AWA2DataLoader\n", 35 | "from core.helper_func import eval_zs_gzsl,visualize_attention#,get_attribute_attention_stats\n", 36 | "from global_setting import NFS_path\n", 37 | "import importlib\n", 38 | "import numpy as np\n", 39 | "import matplotlib.pyplot as plt" 40 | ] 41 | }, 42 | { 43 | "cell_type": "code", 44 | "execution_count": 2, 45 | "metadata": {}, 46 | "outputs": [], 47 | "source": [ 48 | "idx_GPU = 0\n", 49 | "device = torch.device(\"cuda:{}\".format(idx_GPU) if torch.cuda.is_available() else \"cpu\")" 50 | ] 51 | }, 52 | { 53 | "cell_type": "code", 54 | "execution_count": 3, 55 | "metadata": {}, 56 | "outputs": [], 57 | "source": [ 58 | "torch.backends.cudnn.benchmark = True" 59 | ] 60 | }, 61 | { 62 | "cell_type": "code", 63 | "execution_count": 4, 64 | "metadata": {}, 65 | "outputs": [ 66 | { 67 | "name": "stdout", 68 | "output_type": "stream", 69 | "text": [ 70 | "/home/project_amadeus/mnt/raptor/hbdat/Attention_over_attention/\n", 71 | "$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$\n", 72 | "AWA2\n", 73 | "$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$\n", 74 | "Balance dataloader\n", 75 | "_____\n", 76 | "/home/project_amadeus/mnt/raptor/hbdat/Attention_over_attention/data/AWA2/feature_map_ResNet_101_AWA2.hdf5\n", 77 | "Expert Attr\n", 78 | "threshold at zero attribute with negative value\n", 79 | "Finish loading data 
in 283.97918799999997\n" 80 | ] 81 | } 82 | ], 83 | "source": [ 84 | "dataloader = AWA2DataLoader(NFS_path,device)" 85 | ] 86 | }, 87 | { 88 | "cell_type": "code", 89 | "execution_count": 5, 90 | "metadata": {}, 91 | "outputs": [], 92 | "source": [ 93 | "def get_lr(optimizer):\n", 94 | " lr = []\n", 95 | " for param_group in optimizer.param_groups:\n", 96 | " lr.append(param_group['lr'])\n", 97 | " return lr" 98 | ] 99 | }, 100 | { 101 | "cell_type": "code", 102 | "execution_count": null, 103 | "metadata": {}, 104 | "outputs": [], 105 | "source": [ 106 | "seed = 214#214\n", 107 | "torch.manual_seed(seed)\n", 108 | "torch.cuda.manual_seed_all(seed)\n", 109 | "np.random.seed(seed)\n", 110 | "\n", 111 | "batch_size = 50\n", 112 | "nepoches = 20\n", 113 | "niters = dataloader.ntrain * nepoches//batch_size\n", 114 | "dim_f = 2048\n", 115 | "dim_v = 300\n", 116 | "init_w2v_att = dataloader.w2v_att\n", 117 | "att = dataloader.att#dataloader.normalize_att#\n", 118 | "att[att<0] = 0\n", 119 | "normalize_att = dataloader.normalize_att\n", 120 | "#assert (att.min().item() == 0 and att.max().item() == 1)\n", 121 | "\n", 122 | "trainable_w2v = True\n", 123 | "lambda_ = 0.1#0.1\n", 124 | "bias = 0\n", 125 | "prob_prune = 0\n", 126 | "uniform_att_1 = False\n", 127 | "uniform_att_2 = False\n", 128 | "\n", 129 | "seenclass = dataloader.seenclasses\n", 130 | "unseenclass = dataloader.unseenclasses\n", 131 | "desired_mass = 1#unseenclass.size(0)/(seenclass.size(0)+unseenclass.size(0))\n", 132 | "report_interval = niters//nepoches#10000//batch_size#\n", 133 | "\n", 134 | "model = DAZLE(dim_f,dim_v,init_w2v_att,att,normalize_att,\n", 135 | " seenclass,unseenclass,\n", 136 | " lambda_,\n", 137 | " trainable_w2v,normalize_V=True,normalize_F=True,is_conservative=True,\n", 138 | " uniform_att_1=uniform_att_1,uniform_att_2=uniform_att_2,\n", 139 | " prob_prune=prob_prune,desired_mass=desired_mass, is_conv=False,\n", 140 | " is_bias=True)\n", 141 | "model.to(device)\n", 142 | "\n", 143 | "setup = {'pmp':{'init_lambda':0.1,'final_lambda':0.1,'phase':0.8},\n", 144 | " 'desired_mass':{'init_lambda':-1,'final_lambda':-1,'phase':0.8}}\n", 145 | "print(setup)\n", 146 | "\n", 147 | "params_to_update = []\n", 148 | "params_names = []\n", 149 | "for name,param in model.named_parameters():\n", 150 | " if param.requires_grad == True:\n", 151 | " params_to_update.append(param)\n", 152 | " params_names.append(name)\n", 153 | " print(\"\\t\",name)\n", 154 | "#%%\n", 155 | "lr = 0.0001\n", 156 | "weight_decay = 0.0001#0.000#0.#\n", 157 | "momentum = 0.#0.#\n", 158 | "#%%\n", 159 | "lr_seperator = 1\n", 160 | "lr_factor = 1\n", 161 | "print('default lr {} {}x lr {}'.format(params_names[:lr_seperator],lr_factor,params_names[lr_seperator:]))\n", 162 | "optimizer = optim.RMSprop( params_to_update ,lr=lr,weight_decay=weight_decay, momentum=momentum)\n", 163 | "print('-'*30)\n", 164 | "print('learing rate {}'.format(lr))\n", 165 | "print('trainable V {}'.format(trainable_w2v))\n", 166 | "print('lambda_ {}'.format(lambda_))\n", 167 | "print('optimized seen only')\n", 168 | "print('optimizer: RMSProp with momentum = {} and weight_decay = {}'.format(momentum,weight_decay))\n", 169 | "print('-'*30)" 170 | ] 171 | }, 172 | { 173 | "cell_type": "code", 174 | "execution_count": 7, 175 | "metadata": { 176 | "scrolled": true 177 | }, 178 | "outputs": [ 179 | { 180 | "name": "stdout", 181 | "output_type": "stream", 182 | "text": [ 183 | "------------------------------\n", 184 | "bias_seen 0 bias_unseen 0\n", 185 | "{'iter': 0, 'loss': 
3.955843210220337, 'loss_CE': 3.9093644618988037, 'loss_cal': 0.46478742361068726, 'acc_seen': 0, 'acc_novel': 0, 'H': 0, 'acc_zs': 0}\n", 186 | "------------------------------\n", 187 | "bias_seen 0 bias_unseen 0\n", 188 | "{'iter': 470, 'loss': 1.3514924049377441, 'loss_CE': 1.2345727682113647, 'loss_cal': 1.169196605682373, 'acc_seen': 0.5430148839950562, 'acc_novel': 0.6056535840034485, 'H': 0.5726263405347534, 'acc_zs': 0.646105170249939}\n", 189 | "------------------------------\n", 190 | "bias_seen 0 bias_unseen 0\n", 191 | "{'iter': 940, 'loss': 0.9182254672050476, 'loss_CE': 0.7962093949317932, 'loss_cal': 1.2201608419418335, 'acc_seen': 0.7086815237998962, 'acc_novel': 0.6011627912521362, 'H': 0.6505093132987168, 'acc_zs': 0.6683744192123413}\n", 192 | "------------------------------\n", 193 | "bias_seen 0 bias_unseen 0\n", 194 | "{'iter': 1410, 'loss': 0.7537283301353455, 'loss_CE': 0.6223986744880676, 'loss_cal': 1.3132964372634888, 'acc_seen': 0.7395804524421692, 'acc_novel': 0.5977693200111389, 'H': 0.6611561361974528, 'acc_zs': 0.6678923964500427}\n", 195 | "------------------------------\n", 196 | "bias_seen 0 bias_unseen 0\n", 197 | "{'iter': 1880, 'loss': 0.662609338760376, 'loss_CE': 0.5255433917045593, 'loss_cal': 1.370659351348877, 'acc_seen': 0.7518362998962402, 'acc_novel': 0.6027500033378601, 'H': 0.6690889036601549, 'acc_zs': 0.6755213737487793}\n", 198 | "------------------------------\n", 199 | "bias_seen 0 bias_unseen 0\n", 200 | "{'iter': 2350, 'loss': 0.6536160707473755, 'loss_CE': 0.5199357271194458, 'loss_cal': 1.3368035554885864, 'acc_seen': 0.7530007362365723, 'acc_novel': 0.6061833500862122, 'H': 0.6716625268092912, 'acc_zs': 0.6787406206130981}\n", 201 | "------------------------------\n", 202 | "bias_seen 0 bias_unseen 0\n", 203 | "{'iter': 2820, 'loss': 0.6108418107032776, 'loss_CE': 0.48195621371269226, 'loss_cal': 1.2888559103012085, 'acc_seen': 0.7530007362365723, 'acc_novel': 0.6061833500862122, 'H': 0.6716625268092912, 'acc_zs': 0.6787406206130981}\n", 204 | "------------------------------\n", 205 | "bias_seen 0 bias_unseen 0\n", 206 | "{'iter': 3290, 'loss': 0.5895015001296997, 'loss_CE': 0.45069620013237, 'loss_cal': 1.388053059577942, 'acc_seen': 0.7628469467163086, 'acc_novel': 0.6038219332695007, 'H': 0.6740823983052283, 'acc_zs': 0.6789496541023254}\n", 207 | "------------------------------\n", 208 | "bias_seen 0 bias_unseen 0\n", 209 | "{'iter': 3760, 'loss': 0.644405722618103, 'loss_CE': 0.511443555355072, 'loss_cal': 1.3296215534210205, 'acc_seen': 0.7628469467163086, 'acc_novel': 0.6038219332695007, 'H': 0.6740823983052283, 'acc_zs': 0.6789496541023254}\n", 210 | "------------------------------\n", 211 | "bias_seen 0 bias_unseen 0\n", 212 | "{'iter': 4230, 'loss': 0.5973643064498901, 'loss_CE': 0.462933212518692, 'loss_cal': 1.3443106412887573, 'acc_seen': 0.7628469467163086, 'acc_novel': 0.6038219332695007, 'H': 0.6740823983052283, 'acc_zs': 0.6789496541023254}\n", 213 | "------------------------------\n", 214 | "bias_seen 0 bias_unseen 0\n", 215 | "{'iter': 4700, 'loss': 0.6409440636634827, 'loss_CE': 0.5101036429405212, 'loss_cal': 1.3084039688110352, 'acc_seen': 0.7628469467163086, 'acc_novel': 0.6038219332695007, 'H': 0.6740823983052283, 'acc_zs': 0.6789496541023254}\n", 216 | "------------------------------\n", 217 | "bias_seen 0 bias_unseen 0\n", 218 | "{'iter': 5170, 'loss': 0.6012772917747498, 'loss_CE': 0.4705732464790344, 'loss_cal': 1.3070402145385742, 'acc_seen': 0.7628469467163086, 'acc_novel': 0.6038219332695007, 'H': 
0.6740823983052283, 'acc_zs': 0.6789496541023254}\n", 219 | "------------------------------\n", 220 | "bias_seen 0 bias_unseen 0\n", 221 | "{'iter': 5640, 'loss': 0.6804268956184387, 'loss_CE': 0.5569705963134766, 'loss_cal': 1.2345629930496216, 'acc_seen': 0.7628469467163086, 'acc_novel': 0.6038219332695007, 'H': 0.6740823983052283, 'acc_zs': 0.6789496541023254}\n", 222 | "------------------------------\n", 223 | "bias_seen 0 bias_unseen 0\n", 224 | "{'iter': 6110, 'loss': 0.5830560922622681, 'loss_CE': 0.45000630617141724, 'loss_cal': 1.3304975032806396, 'acc_seen': 0.7628469467163086, 'acc_novel': 0.6038219332695007, 'H': 0.6740823983052283, 'acc_zs': 0.6789496541023254}\n", 225 | "------------------------------\n", 226 | "bias_seen 0 bias_unseen 0\n", 227 | "{'iter': 6580, 'loss': 0.680651843547821, 'loss_CE': 0.566943883895874, 'loss_cal': 1.1370794773101807, 'acc_seen': 0.7628469467163086, 'acc_novel': 0.6038219332695007, 'H': 0.6740823983052283, 'acc_zs': 0.6789496541023254}\n", 228 | "------------------------------\n", 229 | "bias_seen 0 bias_unseen 0\n", 230 | "{'iter': 7050, 'loss': 0.5572313666343689, 'loss_CE': 0.42874494194984436, 'loss_cal': 1.2848644256591797, 'acc_seen': 0.7628469467163086, 'acc_novel': 0.6038219332695007, 'H': 0.6740823983052283, 'acc_zs': 0.6789496541023254}\n", 231 | "------------------------------\n", 232 | "bias_seen 0 bias_unseen 0\n", 233 | "{'iter': 7520, 'loss': 0.5773841142654419, 'loss_CE': 0.4520866870880127, 'loss_cal': 1.2529743909835815, 'acc_seen': 0.7628469467163086, 'acc_novel': 0.6038219332695007, 'H': 0.6740823983052283, 'acc_zs': 0.6789496541023254}\n", 234 | "------------------------------\n", 235 | "bias_seen 0 bias_unseen 0\n", 236 | "{'iter': 7990, 'loss': 0.6745968461036682, 'loss_CE': 0.5561791658401489, 'loss_cal': 1.1841765642166138, 'acc_seen': 0.7628469467163086, 'acc_novel': 0.6038219332695007, 'H': 0.6740823983052283, 'acc_zs': 0.6789496541023254}\n", 237 | "------------------------------\n", 238 | "bias_seen 0 bias_unseen 0\n", 239 | "{'iter': 8460, 'loss': 0.549967885017395, 'loss_CE': 0.4177410304546356, 'loss_cal': 1.3222687244415283, 'acc_seen': 0.7628469467163086, 'acc_novel': 0.6038219332695007, 'H': 0.6740823983052283, 'acc_zs': 0.6789496541023254}\n", 240 | "------------------------------\n", 241 | "bias_seen 0 bias_unseen 0\n", 242 | "{'iter': 8930, 'loss': 0.6991280317306519, 'loss_CE': 0.5688791275024414, 'loss_cal': 1.302489161491394, 'acc_seen': 0.7628469467163086, 'acc_novel': 0.6038219332695007, 'H': 0.6740823983052283, 'acc_zs': 0.6789496541023254}\n", 243 | "------------------------------\n", 244 | "bias_seen 0 bias_unseen 0\n", 245 | "{'iter': 9400, 'loss': 0.6614841818809509, 'loss_CE': 0.5430386662483215, 'loss_cal': 1.184455394744873, 'acc_seen': 0.7628469467163086, 'acc_novel': 0.6038219332695007, 'H': 0.6740823983052283, 'acc_zs': 0.6789496541023254}\n" 246 | ] 247 | } 248 | ], 249 | "source": [ 250 | "best_performance = [0,0,0,0]\n", 251 | "for i in range(0,niters):\n", 252 | " model.train()\n", 253 | " optimizer.zero_grad()\n", 254 | " \n", 255 | " batch_label, batch_feature, batch_att = dataloader.next_batch(batch_size)\n", 256 | " out_package = model(batch_feature)\n", 257 | " \n", 258 | " in_package = out_package\n", 259 | " in_package['batch_label'] = batch_label\n", 260 | " \n", 261 | " out_package=model.compute_loss(in_package)\n", 262 | " loss,loss_CE,loss_cal = out_package['loss'],out_package['loss_CE'],out_package['loss_cal']\n", 263 | " \n", 264 | " loss.backward()\n", 265 | " 
optimizer.step()\n", 266 | " if i%report_interval==0:\n", 267 | " print('-'*30)\n", 268 | " acc_seen, acc_novel, H, acc_zs = eval_zs_gzsl(dataloader,model,device,bias_seen=-bias,bias_unseen=bias)\n", 269 | " \n", 270 | " if H > best_performance[2]:\n", 271 | " best_performance = [acc_seen, acc_novel, H, acc_zs]\n", 272 | " stats_package = {'iter':i, 'loss':loss.item(), 'loss_CE':loss_CE.item(),\n", 273 | " 'loss_cal': loss_cal.item(),\n", 274 | " 'acc_seen':best_performance[0], 'acc_novel':best_performance[1], 'H':best_performance[2], 'acc_zs':best_performance[3]}\n", 275 | " \n", 276 | " print(stats_package)" 277 | ] 278 | }, 279 | { 280 | "cell_type": "code", 281 | "execution_count": null, 282 | "metadata": {}, 283 | "outputs": [], 284 | "source": [] 285 | } 286 | ], 287 | "metadata": { 288 | "kernelspec": { 289 | "display_name": "Python 3", 290 | "language": "python", 291 | "name": "python3" 292 | }, 293 | "language_info": { 294 | "codemirror_mode": { 295 | "name": "ipython", 296 | "version": 3 297 | }, 298 | "file_extension": ".py", 299 | "mimetype": "text/x-python", 300 | "name": "python", 301 | "nbconvert_exporter": "python", 302 | "pygments_lexer": "ipython3", 303 | "version": "3.6.8" 304 | } 305 | }, 306 | "nbformat": 4, 307 | "nbformat_minor": 2 308 | } 309 | -------------------------------------------------------------------------------- /notebook/.ipynb_checkpoints/DAZLE_CUB_SS-checkpoint.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": {}, 7 | "outputs": [ 8 | { 9 | "name": "stdout", 10 | "output_type": "stream", 11 | "text": [ 12 | "------------------------------\n", 13 | "/home/project_amadeus/home/hbdat/[RELEASE]_DenseAttentionZSL\n", 14 | "------------------------------\n" 15 | ] 16 | } 17 | ], 18 | "source": [ 19 | "%matplotlib inline\n", 20 | "import os,sys\n", 21 | "pwd = os.getcwd()\n", 22 | "parent = '/'.join(pwd.split('/')[:-1])\n", 23 | "sys.path.insert(0,parent)\n", 24 | "os.chdir(parent)\n", 25 | "#%%\n", 26 | "print('-'*30)\n", 27 | "print(os.getcwd())\n", 28 | "print('-'*30)\n", 29 | "#%%\n", 30 | "import torch\n", 31 | "import torch.optim as optim\n", 32 | "import torch.nn as nn\n", 33 | "import pandas as pd\n", 34 | "from core.DAZLE import DAZLE\n", 35 | "from core.CUBDataLoader_standard_split import CUBDataLoader\n", 36 | "from core.helper_func import eval_zs_gzsl,visualize_attention,eval_zs_gzsl#,get_attribute_attention_stats\n", 37 | "from global_setting import NFS_path\n", 38 | "#from core.Scheduler import Scheduler\n", 39 | "import importlib\n", 40 | "import pdb\n", 41 | "import numpy as np" 42 | ] 43 | }, 44 | { 45 | "cell_type": "code", 46 | "execution_count": 2, 47 | "metadata": {}, 48 | "outputs": [], 49 | "source": [ 50 | "idx_GPU = 0\n", 51 | "device = torch.device(\"cuda:{}\".format(idx_GPU) if torch.cuda.is_available() else \"cpu\")" 52 | ] 53 | }, 54 | { 55 | "cell_type": "code", 56 | "execution_count": 3, 57 | "metadata": {}, 58 | "outputs": [], 59 | "source": [ 60 | "torch.backends.cudnn.benchmark = True" 61 | ] 62 | }, 63 | { 64 | "cell_type": "code", 65 | "execution_count": 4, 66 | "metadata": {}, 67 | "outputs": [ 68 | { 69 | "name": "stdout", 70 | "output_type": "stream", 71 | "text": [ 72 | "!!!!!!!!!! 
Standard Split !!!!!!!!!!\n", 73 | "/home/project_amadeus/mnt/raptor/hbdat/Attention_over_attention/\n", 74 | "$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$\n", 75 | "CUB\n", 76 | "$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$\n", 77 | "_____\n", 78 | "/home/project_amadeus/mnt/raptor/hbdat/Attention_over_attention/data/CUB/feature_map_ResNet_101_CUB.hdf5\n", 79 | "Expert Attr\n", 80 | "Finish loading data in 61.513001\n" 81 | ] 82 | } 83 | ], 84 | "source": [ 85 | "dataloader = CUBDataLoader(NFS_path,device,is_unsupervised_attr=False,is_balance=False)" 86 | ] 87 | }, 88 | { 89 | "cell_type": "code", 90 | "execution_count": 5, 91 | "metadata": {}, 92 | "outputs": [], 93 | "source": [ 94 | "def get_attr_entropy(att): #the lower the more discriminative it is\n", 95 | " eps = 1e-8\n", 96 | " mass=np.sum(att,axis = 0,keepdims=True)\n", 97 | " att_n = np.divide(att,mass+eps)\n", 98 | " entropy = np.sum(-att_n*np.log(att_n+eps),axis=0)\n", 99 | " assert len(entropy.shape)==1\n", 100 | " return entropy" 101 | ] 102 | }, 103 | { 104 | "cell_type": "code", 105 | "execution_count": 6, 106 | "metadata": {}, 107 | "outputs": [], 108 | "source": [ 109 | "batch_size = 50\n", 110 | "nepoches = 1\n", 111 | "niters = dataloader.ntrain * nepoches//batch_size\n", 112 | "dim_f = 2048\n", 113 | "dim_v = 300\n", 114 | "init_w2v_att = dataloader.w2v_att\n", 115 | "att = dataloader.att#dataloader.normalize_att#\n", 116 | "normalize_att = dataloader.att\n", 117 | "#%% attribute selection\n", 118 | "attr_entropy = get_attr_entropy(att.cpu().numpy())\n", 119 | "idx_attr_dis = np.argsort(attr_entropy)" 120 | ] 121 | }, 122 | { 123 | "cell_type": "code", 124 | "execution_count": 7, 125 | "metadata": {}, 126 | "outputs": [ 127 | { 128 | "name": "stdout", 129 | "output_type": "stream", 130 | "text": [ 131 | "torch.Size([312, 300])\n" 132 | ] 133 | } 134 | ], 135 | "source": [ 136 | "print(init_w2v_att.shape)" 137 | ] 138 | }, 139 | { 140 | "cell_type": "code", 141 | "execution_count": 8, 142 | "metadata": {}, 143 | "outputs": [], 144 | "source": [ 145 | "def get_lr(optimizer):\n", 146 | " lr = []\n", 147 | " for param_group in optimizer.param_groups:\n", 148 | " lr.append(param_group['lr'])\n", 149 | " return lr" 150 | ] 151 | }, 152 | { 153 | "cell_type": "code", 154 | "execution_count": null, 155 | "metadata": {}, 156 | "outputs": [], 157 | "source": [ 158 | "seed = 214#215#\n", 159 | "torch.manual_seed(seed)\n", 160 | "torch.cuda.manual_seed_all(seed)\n", 161 | "np.random.seed(seed)\n", 162 | "\n", 163 | "batch_size = 50\n", 164 | "nepoches = 20\n", 165 | "niters = dataloader.ntrain * nepoches//batch_size\n", 166 | "dim_f = 2048\n", 167 | "dim_v = 300\n", 168 | "init_w2v_att = dataloader.w2v_att\n", 169 | "att = dataloader.att#dataloader.normalize_att#\n", 170 | "normalize_att = dataloader.normalize_att\n", 171 | "#assert (att.min().item() == 0 and att.max().item() == 1)\n", 172 | "\n", 173 | "trainable_w2v = True\n", 174 | "lambda_ = 0.1\n", 175 | "bias = 0\n", 176 | "prob_prune = 0\n", 177 | "uniform_att_1 = False\n", 178 | "uniform_att_2 = False\n", 179 | "\n", 180 | "seenclass = dataloader.seenclasses\n", 181 | "unseenclass = dataloader.unseenclasses\n", 182 | "desired_mass = 1#unseenclass.size(0)/(seenclass.size(0)+unseenclass.size(0))\n", 183 | "report_interval = niters//nepoches#10000//batch_size#\n", 184 | "\n", 185 | "model = DAZLE(dim_f,dim_v,init_w2v_att,att,normalize_att,\n", 186 | " seenclass,unseenclass,\n", 187 | " lambda_,\n", 188 | " trainable_w2v,normalize_V=False,normalize_F=True,is_conservative=True,\n", 189 | " 
uniform_att_1=uniform_att_1,uniform_att_2=uniform_att_2,\n", 190 | " prob_prune=prob_prune,desired_mass=desired_mass, is_conv=False,\n", 191 | " is_bias=True)\n", 192 | "model.to(device)\n", 193 | "\n", 194 | "setup = {'pmp':{'init_lambda':0.1,'final_lambda':0.1,'phase':0.8},\n", 195 | " 'desired_mass':{'init_lambda':-1,'final_lambda':-1,'phase':0.8}}\n", 196 | "print(setup)\n", 197 | "#scheduler = Scheduler(model,niters,batch_size,report_interval,setup)\n", 198 | "\n", 199 | "params_to_update = []\n", 200 | "params_names = []\n", 201 | "for name,param in model.named_parameters():\n", 202 | " if param.requires_grad == True:\n", 203 | " params_to_update.append(param)\n", 204 | " params_names.append(name)\n", 205 | " print(\"\\t\",name)\n", 206 | "#%%\n", 207 | "lr = 0.0001\n", 208 | "weight_decay = 0.00005#0.000#0.#\n", 209 | "momentum = 0.9#0.#\n", 210 | "#%%\n", 211 | "lr_seperator = 1\n", 212 | "lr_factor = 1\n", 213 | "print('default lr {} {}x lr {}'.format(params_names[:lr_seperator],lr_factor,params_names[lr_seperator:]))\n", 214 | "optimizer = optim.RMSprop( params_to_update ,lr=lr,weight_decay=weight_decay, momentum=momentum)\n", 215 | "print('-'*30)\n", 216 | "print('learing rate {}'.format(lr))\n", 217 | "print('trainable V {}'.format(trainable_w2v))\n", 218 | "print('lambda_ {}'.format(lambda_))\n", 219 | "print('optimized seen only')\n", 220 | "print('optimizer: RMSProp with momentum = {} and weight_decay = {}'.format(momentum,weight_decay))\n", 221 | "print('-'*30)" 222 | ] 223 | }, 224 | { 225 | "cell_type": "code", 226 | "execution_count": 10, 227 | "metadata": { 228 | "scrolled": false 229 | }, 230 | "outputs": [ 231 | { 232 | "name": "stdout", 233 | "output_type": "stream", 234 | "text": [ 235 | "------------------------------\n", 236 | "bias_seen 0 bias_unseen 0\n", 237 | "{'iter': 0, 'loss': 5.307318687438965, 'loss_CE': 5.271770000457764, 'loss_cal': 0.35548919439315796, 'acc_seen': nan, 'acc_novel': 0.040221136063337326, 'H': 0, 'acc_zs': 0.04022114351391792}\n", 238 | "------------------------------\n", 239 | "bias_seen 0 bias_unseen 0\n", 240 | "{'iter': 177, 'loss': 1.6891649961471558, 'loss_CE': 1.5842840671539307, 'loss_cal': 1.0488090515136719, 'acc_seen': nan, 'acc_novel': 0.4655452370643616, 'H': 0, 'acc_zs': 0.5405555367469788}\n", 241 | "------------------------------\n", 242 | "bias_seen 0 bias_unseen 0\n", 243 | "{'iter': 354, 'loss': 1.5217909812927246, 'loss_CE': 1.4060150384902954, 'loss_cal': 1.1577597856521606, 'acc_seen': nan, 'acc_novel': 0.510532796382904, 'H': 0, 'acc_zs': 0.6040189266204834}\n", 244 | "------------------------------\n", 245 | "bias_seen 0 bias_unseen 0\n", 246 | "{'iter': 531, 'loss': 1.156381607055664, 'loss_CE': 1.0383371114730835, 'loss_cal': 1.1804451942443848, 'acc_seen': nan, 'acc_novel': 0.5140025019645691, 'H': 0, 'acc_zs': 0.6206724643707275}\n", 247 | "------------------------------\n", 248 | "bias_seen 0 bias_unseen 0\n", 249 | "{'iter': 708, 'loss': 1.1116386651992798, 'loss_CE': 0.9800852537155151, 'loss_cal': 1.3155337572097778, 'acc_seen': nan, 'acc_novel': 0.5589466094970703, 'H': 0, 'acc_zs': 0.6678099632263184}\n", 250 | "------------------------------\n", 251 | "bias_seen 0 bias_unseen 0\n", 252 | "{'iter': 885, 'loss': 1.3317115306854248, 'loss_CE': 1.1996527910232544, 'loss_cal': 1.320586919784546, 'acc_seen': nan, 'acc_novel': 0.5589466094970703, 'H': 0, 'acc_zs': 0.6678099632263184}\n", 253 | "------------------------------\n", 254 | "bias_seen 0 bias_unseen 0\n", 255 | "{'iter': 1062, 'loss': 
1.118662714958191, 'loss_CE': 0.9750721454620361, 'loss_cal': 1.4359060525894165, 'acc_seen': nan, 'acc_novel': 0.5580828785896301, 'H': 0, 'acc_zs': 0.6697779297828674}\n", 256 | "------------------------------\n", 257 | "bias_seen 0 bias_unseen 0\n", 258 | "{'iter': 1239, 'loss': 0.8553117513656616, 'loss_CE': 0.6984010338783264, 'loss_cal': 1.5691068172454834, 'acc_seen': nan, 'acc_novel': 0.5738272070884705, 'H': 0, 'acc_zs': 0.6717338562011719}\n", 259 | "------------------------------\n", 260 | "bias_seen 0 bias_unseen 0\n", 261 | "{'iter': 1416, 'loss': 0.6992800831794739, 'loss_CE': 0.5499823689460754, 'loss_cal': 1.4929770231246948, 'acc_seen': nan, 'acc_novel': 0.5738272070884705, 'H': 0, 'acc_zs': 0.6717338562011719}\n", 262 | "------------------------------\n", 263 | "bias_seen 0 bias_unseen 0\n", 264 | "{'iter': 1593, 'loss': 1.0170717239379883, 'loss_CE': 0.88196861743927, 'loss_cal': 1.3510308265686035, 'acc_seen': nan, 'acc_novel': 0.5738272070884705, 'H': 0, 'acc_zs': 0.6717338562011719}\n", 265 | "------------------------------\n", 266 | "bias_seen 0 bias_unseen 0\n", 267 | "{'iter': 1770, 'loss': 0.9129882454872131, 'loss_CE': 0.7532180547714233, 'loss_cal': 1.5977017879486084, 'acc_seen': nan, 'acc_novel': 0.5701281428337097, 'H': 0, 'acc_zs': 0.6731362342834473}\n", 268 | "------------------------------\n", 269 | "bias_seen 0 bias_unseen 0\n", 270 | "{'iter': 1947, 'loss': 0.6349995136260986, 'loss_CE': 0.48872146010398865, 'loss_cal': 1.4627807140350342, 'acc_seen': nan, 'acc_novel': 0.5772320032119751, 'H': 0, 'acc_zs': 0.6761006116867065}\n", 271 | "------------------------------\n", 272 | "bias_seen 0 bias_unseen 0\n", 273 | "{'iter': 2124, 'loss': 0.7970255613327026, 'loss_CE': 0.6616990566253662, 'loss_cal': 1.3532648086547852, 'acc_seen': nan, 'acc_novel': 0.5840730667114258, 'H': 0, 'acc_zs': 0.6778362989425659}\n", 274 | "------------------------------\n", 275 | "bias_seen 0 bias_unseen 0\n", 276 | "{'iter': 2301, 'loss': 0.7347122430801392, 'loss_CE': 0.5974063873291016, 'loss_cal': 1.373058557510376, 'acc_seen': nan, 'acc_novel': 0.5840730667114258, 'H': 0, 'acc_zs': 0.6778362989425659}\n", 277 | "------------------------------\n", 278 | "bias_seen 0 bias_unseen 0\n", 279 | "{'iter': 2478, 'loss': 0.5548276901245117, 'loss_CE': 0.3850424587726593, 'loss_cal': 1.6978520154953003, 'acc_seen': nan, 'acc_novel': 0.5840730667114258, 'H': 0, 'acc_zs': 0.6778362989425659}\n", 280 | "------------------------------\n", 281 | "bias_seen 0 bias_unseen 0\n", 282 | "{'iter': 2655, 'loss': 0.6615085601806641, 'loss_CE': 0.5202628374099731, 'loss_cal': 1.4124568700790405, 'acc_seen': nan, 'acc_novel': 0.5840730667114258, 'H': 0, 'acc_zs': 0.6778362989425659}\n", 283 | "------------------------------\n", 284 | "bias_seen 0 bias_unseen 0\n", 285 | "{'iter': 2832, 'loss': 0.6729946136474609, 'loss_CE': 0.501806914806366, 'loss_cal': 1.711876630783081, 'acc_seen': nan, 'acc_novel': 0.5840730667114258, 'H': 0, 'acc_zs': 0.6778362989425659}\n", 286 | "------------------------------\n", 287 | "bias_seen 0 bias_unseen 0\n", 288 | "{'iter': 3009, 'loss': 0.6648485064506531, 'loss_CE': 0.5202063322067261, 'loss_cal': 1.4464218616485596, 'acc_seen': nan, 'acc_novel': 0.5840730667114258, 'H': 0, 'acc_zs': 0.6778362989425659}\n", 289 | "------------------------------\n", 290 | "bias_seen 0 bias_unseen 0\n", 291 | "{'iter': 3186, 'loss': 0.5674977898597717, 'loss_CE': 0.42501401901245117, 'loss_cal': 1.424837589263916, 'acc_seen': nan, 'acc_novel': 0.5840730667114258, 'H': 0, 'acc_zs': 
0.6778362989425659}\n", 292 | "------------------------------\n", 293 | "bias_seen 0 bias_unseen 0\n", 294 | "{'iter': 3363, 'loss': 0.5623935461044312, 'loss_CE': 0.4077602028846741, 'loss_cal': 1.5463331937789917, 'acc_seen': nan, 'acc_novel': 0.5840730667114258, 'H': 0, 'acc_zs': 0.6778362989425659}\n", 295 | "------------------------------\n", 296 | "bias_seen 0 bias_unseen 0\n", 297 | "{'iter': 3540, 'loss': 0.6784080862998962, 'loss_CE': 0.5070964694023132, 'loss_cal': 1.713115930557251, 'acc_seen': nan, 'acc_novel': 0.5840730667114258, 'H': 0, 'acc_zs': 0.6778362989425659}\n" 298 | ] 299 | } 300 | ], 301 | "source": [ 302 | "best_performance = [0,0,0,0]\n", 303 | "for i in range(0,niters):\n", 304 | " model.train()\n", 305 | " optimizer.zero_grad()\n", 306 | " \n", 307 | " batch_label, batch_feature, batch_att = dataloader.next_batch(batch_size)\n", 308 | " out_package = model(batch_feature)\n", 309 | " \n", 310 | " in_package = out_package\n", 311 | " in_package['batch_label'] = batch_label\n", 312 | " \n", 313 | " out_package=model.compute_loss(in_package)\n", 314 | " loss,loss_CE,loss_cal = out_package['loss'],out_package['loss_CE'],out_package['loss_cal']\n", 315 | " \n", 316 | " loss.backward()\n", 317 | " optimizer.step()\n", 318 | " if i%report_interval==0:\n", 319 | " print('-'*30)\n", 320 | " acc_seen, acc_novel, H, acc_zs = eval_zs_gzsl(dataloader,model,device,bias_seen=-bias,bias_unseen=bias)\n", 321 | " \n", 322 | " if acc_zs > best_performance[3]:\n", 323 | " best_performance = [acc_seen, acc_novel, H, acc_zs]\n", 324 | " stats_package = {'iter':i, 'loss':loss.item(), 'loss_CE':loss_CE.item(),\n", 325 | " 'loss_cal': loss_cal.item(),\n", 326 | " 'acc_seen':best_performance[0], 'acc_novel':best_performance[1], 'H':best_performance[2], 'acc_zs':best_performance[3]}\n", 327 | " \n", 328 | " print(stats_package)" 329 | ] 330 | }, 331 | { 332 | "cell_type": "code", 333 | "execution_count": null, 334 | "metadata": {}, 335 | "outputs": [], 336 | "source": [] 337 | } 338 | ], 339 | "metadata": { 340 | "kernelspec": { 341 | "display_name": "Python 3", 342 | "language": "python", 343 | "name": "python3" 344 | }, 345 | "language_info": { 346 | "codemirror_mode": { 347 | "name": "ipython", 348 | "version": 3 349 | }, 350 | "file_extension": ".py", 351 | "mimetype": "text/x-python", 352 | "name": "python", 353 | "nbconvert_exporter": "python", 354 | "pygments_lexer": "ipython3", 355 | "version": "3.6.8" 356 | } 357 | }, 358 | "nbformat": 4, 359 | "nbformat_minor": 2 360 | } 361 | -------------------------------------------------------------------------------- /notebook/.ipynb_checkpoints/DAZLE_DeepFashion-checkpoint.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": {}, 7 | "outputs": [ 8 | { 9 | "name": "stdout", 10 | "output_type": "stream", 11 | "text": [ 12 | "------------------------------\n", 13 | "/home/project_amadeus/home/hbdat/[RELEASE]_DenseAttentionZSL/notebook\n", 14 | "------------------------------\n" 15 | ] 16 | } 17 | ], 18 | "source": [ 19 | "import os,sys\n", 20 | "pwd = os.getcwd()\n", 21 | "parent = '/'.join(pwd.split('/')[:-1])\n", 22 | "sys.path.insert(0,parent)\n", 23 | "#%%\n", 24 | "print('-'*30)\n", 25 | "print(os.getcwd())\n", 26 | "print('-'*30)\n", 27 | "#%%\n", 28 | "import torch\n", 29 | "import torch.optim as optim\n", 30 | "import torch.nn as nn\n", 31 | "import pandas as pd\n", 32 | "import numpy as np\n", 33 | "import 
time\n", 34 | "import h5py\n", 35 | "from core.DAZLE import DAZLE\n", 36 | "from core.DeepFashionDataLoader import DeepFashionDataLoader\n", 37 | "from core.helper_func import eval_zs_gzsl,get_lr,get_attr_entropy#get_attribute_attention_stats\n", 38 | "from global_setting import NFS_path" 39 | ] 40 | }, 41 | { 42 | "cell_type": "code", 43 | "execution_count": 2, 44 | "metadata": {}, 45 | "outputs": [], 46 | "source": [ 47 | "idx_GPU = 7\n", 48 | "device = torch.device(\"cuda:{}\".format(idx_GPU) if torch.cuda.is_available() else \"cpu\")" 49 | ] 50 | }, 51 | { 52 | "cell_type": "code", 53 | "execution_count": 3, 54 | "metadata": {}, 55 | "outputs": [], 56 | "source": [ 57 | "torch.backends.cudnn.benchmark = True" 58 | ] 59 | }, 60 | { 61 | "cell_type": "code", 62 | "execution_count": 4, 63 | "metadata": {}, 64 | "outputs": [ 65 | { 66 | "name": "stdout", 67 | "output_type": "stream", 68 | "text": [ 69 | "/home/project_amadeus/mnt/raptor/hbdat/Attention_over_attention/\n", 70 | "$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$\n", 71 | "DeepFashion\n", 72 | "$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$\n", 73 | "_____\n", 74 | "/home/project_amadeus/mnt/raptor/hbdat/Attention_over_attention/data/DeepFashion/feature_map_ResNet_101_DeepFashion_sep_seen_samples.hdf5\n", 75 | "Finish loading data in 551.489919\n", 76 | "Balance dataloader\n", 77 | "Partition size 10000\n", 78 | "Excluding non-sample classes\n", 79 | "------------------------------\n", 80 | "DeepFashion\n", 81 | "------------------------------\n" 82 | ] 83 | } 84 | ], 85 | "source": [ 86 | "dataloader = DeepFashionDataLoader(NFS_path,device,is_balance = True)" 87 | ] 88 | }, 89 | { 90 | "cell_type": "code", 91 | "execution_count": 5, 92 | "metadata": {}, 93 | "outputs": [ 94 | { 95 | "name": "stdout", 96 | "output_type": "stream", 97 | "text": [ 98 | "Randomize seed 214\n", 99 | "seeker [0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.\n", 100 | " 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 
0.]\n" 101 | ] 102 | }, 103 | { 104 | "name": "stderr", 105 | "output_type": "stream", 106 | "text": [ 107 | "/home/project_amadeus/home/hbdat/[RELEASE]_DenseAttentionZSL/core/DAZLE.py:51: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n", 108 | " self.init_w2v_att = F.normalize(torch.tensor(init_w2v_att))\n", 109 | "/home/project_amadeus/home/hbdat/[RELEASE]_DenseAttentionZSL/core/DAZLE.py:54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n", 110 | " self.att = nn.Parameter(F.normalize(torch.tensor(att)),requires_grad = False)\n" 111 | ] 112 | }, 113 | { 114 | "name": "stdout", 115 | "output_type": "stream", 116 | "text": [ 117 | "------------------------------\n", 118 | "Configuration\n", 119 | "loss_type CE\n", 120 | "no constraint V\n", 121 | "normalize F\n", 122 | "training to exclude unseen class [seen upperbound]\n", 123 | "Init word2vec\n", 124 | "Linear model\n", 125 | "loss_att BCEWithLogitsLoss()\n", 126 | "Bilinear attention module\n", 127 | "******************************\n", 128 | "Measure w2v deviation\n", 129 | "new Laplacian smoothing with desire mass 1 4\n", 130 | "Compute Pruning loss 0\n", 131 | "Add one smoothing\n", 132 | "Second layer attenion conditioned on image features\n", 133 | "------------------------------\n", 134 | "No sigmoid on attr score\n", 135 | "\t V\n", 136 | "\t W_1\n", 137 | "\t W_2\n", 138 | "\t W_3\n", 139 | "------------------------------\n", 140 | "learing rate 0.0001\n", 141 | "trainable V True\n", 142 | "lambda_ 0.1\n", 143 | "optimized seen only\n", 144 | "optimizer: RMSProp with momentum = 0.9 and weight_decay = 0.0001\n", 145 | "------------------------------\n" 146 | ] 147 | } 148 | ], 149 | "source": [ 150 | "#%%\n", 151 | "seed = 214\n", 152 | "torch.manual_seed(seed)\n", 153 | "torch.cuda.manual_seed_all(seed)\n", 154 | "np.random.seed(seed)\n", 155 | "dataloader.idx_b = 0\n", 156 | "print('Randomize seed {}'.format(seed))\n", 157 | "#%%\n", 158 | "batch_size = 50\n", 159 | "nepoches = 1\n", 160 | "niters = dataloader.ntrain * nepoches//batch_size\n", 161 | "dim_f = 2048\n", 162 | "dim_v = 300\n", 163 | "init_w2v_att = dataloader.w2v_att\n", 164 | "att = dataloader.att#dataloader.normalize_att#\n", 165 | "normalize_att = dataloader.att\n", 166 | "#assert (att.min().item() == 0 and att.max().item() == 1)\n", 167 | "\n", 168 | "trainable_w2v = True\n", 169 | "lambda_ = 0.1\n", 170 | "bias = 0.\n", 171 | "prob_prune = 0\n", 172 | "uniform_att_1 = False\n", 173 | "uniform_att_2 = False\n", 174 | "\n", 175 | "dataloader.seeker[:] = 0\n", 176 | "print('seeker ',dataloader.seeker)\n", 177 | "\n", 178 | "seenclass = dataloader.seenclasses\n", 179 | "unseenclass = dataloader.unseenclasses\n", 180 | "desired_mass = 1#unseenclass.size(0)/(seenclass.size(0)+unseenclass.size(0))\n", 181 | "report_interval = 200#niters//nepoches\n", 182 | "#%%\n", 183 | "model = DAZLE(dim_f,dim_v,init_w2v_att,att,normalize_att,\n", 184 | " seenclass,unseenclass,\n", 185 | " lambda_,\n", 186 | " trainable_w2v,normalize_V=False,normalize_F=True,is_conservative=True,\n", 187 | " uniform_att_1=uniform_att_1,uniform_att_2=uniform_att_2,\n", 188 | " prob_prune=prob_prune,desired_mass=desired_mass, is_conv=False,\n", 189 | " is_bias=True,non_linear_act=False)\n", 190 | 
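    "# NOTE: the configuration banner printed in this cell's output appears to mirror the flags above, e.g. 'no constraint V' (normalize_V=False), 'normalize F' (normalize_F=True), 'Linear model' (is_conv=False), and 'Second layer attenion conditioned on image features' (uniform_att_2=False).\n",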
"model.to(device)\n", 191 | "#%%\n", 192 | "params_to_update = []\n", 193 | "for name,param in model.named_parameters():\n", 194 | " if param.requires_grad == True:\n", 195 | " params_to_update.append(param)\n", 196 | " print(\"\\t\",name)\n", 197 | "#%%\n", 198 | "lr = 0.0001\n", 199 | "weight_decay = 0.0001#0.000#0.#\n", 200 | "momentum = 0.9#0.#\n", 201 | "optimizer = optim.RMSprop( params_to_update ,lr=lr,weight_decay=weight_decay, momentum=momentum)\n", 202 | "#%%\n", 203 | "print('-'*30)\n", 204 | "print('learing rate {}'.format(lr))\n", 205 | "print('trainable V {}'.format(trainable_w2v))\n", 206 | "print('lambda_ {}'.format(lambda_))\n", 207 | "print('optimized seen only')\n", 208 | "print('optimizer: RMSProp with momentum = {} and weight_decay = {}'.format(momentum,weight_decay))\n", 209 | "print('-'*30)" 210 | ] 211 | }, 212 | { 213 | "cell_type": "code", 214 | "execution_count": null, 215 | "metadata": { 216 | "scrolled": true 217 | }, 218 | "outputs": [ 219 | { 220 | "name": "stdout", 221 | "output_type": "stream", 222 | "text": [ 223 | "load data from hdf\n", 224 | "1..10..11..13..14..16..17..18..20..21..22..23..24..25..27..28..3..31..32..33..34..35..36..37..39..4..40..41..43..46..47..5..6..7..8..9..\n", 225 | "Elapsed time 33.30335600000001\n", 226 | "------------------------------\n", 227 | "bias_seen -0.0 bias_unseen 0.0\n", 228 | "{'iter': 0, 'loss': 4.189703941345215, 'loss_CE': 4.158728122711182, 'loss_cal': 0.3097580075263977, 'acc_seen': 0, 'acc_novel': 0, 'H': 0, 'acc_zs': 0}\n", 229 | "------------------------------\n", 230 | "bias_seen -0.0 bias_unseen 0.0\n", 231 | "{'iter': 200, 'loss': 1.8229509592056274, 'loss_CE': 1.7100118398666382, 'loss_cal': 1.1293911933898926, 'acc_seen': 0.3145444989204407, 'acc_novel': 0.19587568938732147, 'H': 0.24141529657475722, 'acc_zs': 0.3387182056903839}\n", 232 | "load data from hdf\n", 233 | "1..10..11..13..14..16..17..18..20..21..22..23..24..25..27..28..3..31..32..33..34..35..36..37..39..4..40..41..43..46..47..5..6..7..8..9..\n", 234 | "Elapsed time 35.98181899999997\n", 235 | "------------------------------\n", 236 | "bias_seen -0.0 bias_unseen 0.0\n", 237 | "{'iter': 400, 'loss': 1.8814144134521484, 'loss_CE': 1.7609858512878418, 'loss_cal': 1.2042860984802246, 'acc_seen': 0.37800484895706177, 'acc_novel': 0.1890132874250412, 'H': 0.25201288840550584, 'acc_zs': 0.36421096324920654}\n", 238 | "------------------------------\n", 239 | "bias_seen -0.0 bias_unseen 0.0\n", 240 | "{'iter': 600, 'loss': 1.2690142393112183, 'loss_CE': 1.1435215473175049, 'loss_cal': 1.2549272775650024, 'acc_seen': 0.36895838379859924, 'acc_novel': 0.20752708613872528, 'H': 0.2656402018406512, 'acc_zs': 0.36523857712745667}\n", 241 | "load data from hdf\n", 242 | "1..10..11..13..14..16..17..18..20..21..22..23..24..25..27..28..3..31..32..33..34..35..36..37..39..4..40..41..43..46..47..5..6..7..8..9..\n", 243 | "Elapsed time 36.09313999999995\n", 244 | "------------------------------\n", 245 | "bias_seen -0.0 bias_unseen 0.0\n", 246 | "{'iter': 800, 'loss': 1.343553066253662, 'loss_CE': 1.224740982055664, 'loss_cal': 1.1881206035614014, 'acc_seen': 0.36895838379859924, 'acc_novel': 0.20752708613872528, 'H': 0.2656402018406512, 'acc_zs': 0.36523857712745667}\n", 247 | "------------------------------\n", 248 | "bias_seen -0.0 bias_unseen 0.0\n", 249 | "{'iter': 1000, 'loss': 1.1187831163406372, 'loss_CE': 0.9923740029335022, 'loss_cal': 1.2640912532806396, 'acc_seen': 0.36895838379859924, 'acc_novel': 0.20752708613872528, 'H': 0.2656402018406512, 'acc_zs': 
0.36523857712745667}\n", 250 | "load data from hdf\n", 251 | "1..10..11..13..14..16..17..18..20..21..22..23..24..25..27..28..3..31..32..33..34..35..36..37..39..4..40..41..43..46..47..5..6..7..8..9..\n", 252 | "Elapsed time 35.591203000000064\n", 253 | "------------------------------\n", 254 | "bias_seen -0.0 bias_unseen 0.0\n", 255 | "{'iter': 1200, 'loss': 1.322019100189209, 'loss_CE': 1.2000318765640259, 'loss_cal': 1.2198721170425415, 'acc_seen': 0.36895838379859924, 'acc_novel': 0.20752708613872528, 'H': 0.2656402018406512, 'acc_zs': 0.36523857712745667}\n", 256 | "------------------------------\n", 257 | "bias_seen -0.0 bias_unseen 0.0\n", 258 | "{'iter': 1400, 'loss': 1.013647198677063, 'loss_CE': 0.8876954913139343, 'loss_cal': 1.2595171928405762, 'acc_seen': 0.38100937008857727, 'acc_novel': 0.21498239040374756, 'H': 0.27487059579550444, 'acc_zs': 0.3899243175983429}\n", 259 | "load data from hdf\n", 260 | "1..10..11..13..14..16..17..18..20..21..22..23..24..25..27..28..3..31..32..33..34..35..36..37..39..4..40..41..43..46..47..5..6..7..8..9..\n", 261 | "Elapsed time 37.208722999999964\n", 262 | "------------------------------\n", 263 | "bias_seen -0.0 bias_unseen 0.0\n", 264 | "{'iter': 1600, 'loss': 1.5902624130249023, 'loss_CE': 1.4395182132720947, 'loss_cal': 1.5074422359466553, 'acc_seen': 0.38100937008857727, 'acc_novel': 0.21498239040374756, 'H': 0.27487059579550444, 'acc_zs': 0.3899243175983429}\n", 265 | "------------------------------\n", 266 | "bias_seen -0.0 bias_unseen 0.0\n", 267 | "{'iter': 1800, 'loss': 1.0743675231933594, 'loss_CE': 0.9522173404693604, 'loss_cal': 1.2215015888214111, 'acc_seen': 0.38100937008857727, 'acc_novel': 0.21498239040374756, 'H': 0.27487059579550444, 'acc_zs': 0.3899243175983429}\n", 268 | "load data from hdf\n", 269 | "1..10..11..13..14..16..17..18..20..21..22..23..24..25..27..28..3..31..32..33..34..35..36..37..39..4..40..41..43..46..47..5..6..7..8..9..\n", 270 | "Elapsed time 31.795719000000076\n", 271 | "------------------------------\n", 272 | "bias_seen -0.0 bias_unseen 0.0\n", 273 | "{'iter': 2000, 'loss': 1.172136902809143, 'loss_CE': 1.0547785758972168, 'loss_cal': 1.1735827922821045, 'acc_seen': 0.38100937008857727, 'acc_novel': 0.21498239040374756, 'H': 0.27487059579550444, 'acc_zs': 0.3899243175983429}\n", 274 | "------------------------------\n", 275 | "bias_seen -0.0 bias_unseen 0.0\n" 276 | ] 277 | } 278 | ], 279 | "source": [ 280 | "best_performance = [0,0,0,0]\n", 281 | "for i in range(0,niters):\n", 282 | " model.train()\n", 283 | " optimizer.zero_grad()\n", 284 | " \n", 285 | " batch_label, batch_feature, batch_att = dataloader.next_batch(batch_size)\n", 286 | " out_package = model(batch_feature)\n", 287 | " \n", 288 | " in_package = out_package\n", 289 | " in_package['batch_label'] = batch_label\n", 290 | " \n", 291 | " out_package=model.compute_loss(in_package)\n", 292 | " loss,loss_CE,loss_cal = out_package['loss'],out_package['loss_CE'],out_package['loss_cal']\n", 293 | " \n", 294 | " loss.backward()\n", 295 | " optimizer.step()\n", 296 | " if i%report_interval==0:\n", 297 | " print('-'*30)\n", 298 | " acc_seen, acc_novel, H, acc_zs = eval_zs_gzsl(dataloader,model,device,bias_seen=-bias,bias_unseen=bias)\n", 299 | " \n", 300 | " if H > best_performance[2]:\n", 301 | " best_performance = [acc_seen, acc_novel, H, acc_zs]\n", 302 | " stats_package = {'iter':i, 'loss':loss.item(), 'loss_CE':loss_CE.item(),\n", 303 | " 'loss_cal': loss_cal.item(),\n", 304 | " 'acc_seen':best_performance[0], 'acc_novel':best_performance[1], 
'H':best_performance[2], 'acc_zs':best_performance[3]}\n", 305 | " \n", 306 | " print(stats_package)" 307 | ] 308 | }, 309 | { 310 | "cell_type": "code", 311 | "execution_count": null, 312 | "metadata": {}, 313 | "outputs": [], 314 | "source": [] 315 | } 316 | ], 317 | "metadata": { 318 | "kernelspec": { 319 | "display_name": "Python 3", 320 | "language": "python", 321 | "name": "python3" 322 | }, 323 | "language_info": { 324 | "codemirror_mode": { 325 | "name": "ipython", 326 | "version": 3 327 | }, 328 | "file_extension": ".py", 329 | "mimetype": "text/x-python", 330 | "name": "python", 331 | "nbconvert_exporter": "python", 332 | "pygments_lexer": "ipython3", 333 | "version": "3.6.8" 334 | } 335 | }, 336 | "nbformat": 4, 337 | "nbformat_minor": 2 338 | } 339 | -------------------------------------------------------------------------------- /notebook/DAZLE_AWA2.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": {}, 7 | "outputs": [ 8 | { 9 | "name": "stdout", 10 | "output_type": "stream", 11 | "text": [ 12 | "------------------------------\n", 13 | "/home/project_amadeus/home/hbdat/[RELEASE]_DenseAttentionZSL\n", 14 | "------------------------------\n" 15 | ] 16 | } 17 | ], 18 | "source": [ 19 | "import os,sys\n", 20 | "pwd = os.getcwd()\n", 21 | "parent = '/'.join(pwd.split('/')[:-1])\n", 22 | "sys.path.insert(0,parent)\n", 23 | "os.chdir(parent)\n", 24 | "#%%\n", 25 | "print('-'*30)\n", 26 | "print(os.getcwd())\n", 27 | "print('-'*30)\n", 28 | "#%%\n", 29 | "import torch\n", 30 | "import torch.optim as optim\n", 31 | "import torch.nn as nn\n", 32 | "import pandas as pd\n", 33 | "from core.DAZLE import DAZLE\n", 34 | "from core.AWA2DataLoader import AWA2DataLoader\n", 35 | "from core.helper_func import eval_zs_gzsl,visualize_attention#,get_attribute_attention_stats\n", 36 | "from global_setting import NFS_path\n", 37 | "import importlib\n", 38 | "import numpy as np\n", 39 | "import matplotlib.pyplot as plt" 40 | ] 41 | }, 42 | { 43 | "cell_type": "code", 44 | "execution_count": 2, 45 | "metadata": {}, 46 | "outputs": [], 47 | "source": [ 48 | "idx_GPU = 0\n", 49 | "device = torch.device(\"cuda:{}\".format(idx_GPU) if torch.cuda.is_available() else \"cpu\")" 50 | ] 51 | }, 52 | { 53 | "cell_type": "code", 54 | "execution_count": 3, 55 | "metadata": {}, 56 | "outputs": [], 57 | "source": [ 58 | "torch.backends.cudnn.benchmark = True" 59 | ] 60 | }, 61 | { 62 | "cell_type": "code", 63 | "execution_count": 4, 64 | "metadata": {}, 65 | "outputs": [ 66 | { 67 | "name": "stdout", 68 | "output_type": "stream", 69 | "text": [ 70 | "/home/project_amadeus/mnt/raptor/hbdat/Attention_over_attention/\n", 71 | "$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$\n", 72 | "AWA2\n", 73 | "$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$\n", 74 | "Balance dataloader\n", 75 | "_____\n", 76 | "/home/project_amadeus/mnt/raptor/hbdat/Attention_over_attention/data/AWA2/feature_map_ResNet_101_AWA2.hdf5\n", 77 | "Expert Attr\n", 78 | "threshold at zero attribute with negative value\n", 79 | "Finish loading data in 283.97918799999997\n" 80 | ] 81 | } 82 | ], 83 | "source": [ 84 | "dataloader = AWA2DataLoader(NFS_path,device)" 85 | ] 86 | }, 87 | { 88 | "cell_type": "code", 89 | "execution_count": 5, 90 | "metadata": {}, 91 | "outputs": [], 92 | "source": [ 93 | "def get_lr(optimizer):\n", 94 | " lr = []\n", 95 | " for param_group in optimizer.param_groups:\n", 96 | " lr.append(param_group['lr'])\n", 97 | " return lr" 98 | 
] 99 | }, 100 | { 101 | "cell_type": "code", 102 | "execution_count": null, 103 | "metadata": {}, 104 | "outputs": [], 105 | "source": [ 106 | "seed = 214#214\n", 107 | "torch.manual_seed(seed)\n", 108 | "torch.cuda.manual_seed_all(seed)\n", 109 | "np.random.seed(seed)\n", 110 | "\n", 111 | "batch_size = 50\n", 112 | "nepoches = 20\n", 113 | "niters = dataloader.ntrain * nepoches//batch_size\n", 114 | "dim_f = 2048\n", 115 | "dim_v = 300\n", 116 | "init_w2v_att = dataloader.w2v_att\n", 117 | "att = dataloader.att#dataloader.normalize_att#\n", 118 | "att[att<0] = 0\n", 119 | "normalize_att = dataloader.normalize_att\n", 120 | "#assert (att.min().item() == 0 and att.max().item() == 1)\n", 121 | "\n", 122 | "trainable_w2v = True\n", 123 | "lambda_ = 0.1#0.1\n", 124 | "bias = 0\n", 125 | "prob_prune = 0\n", 126 | "uniform_att_1 = False\n", 127 | "uniform_att_2 = False\n", 128 | "\n", 129 | "seenclass = dataloader.seenclasses\n", 130 | "unseenclass = dataloader.unseenclasses\n", 131 | "desired_mass = 1#unseenclass.size(0)/(seenclass.size(0)+unseenclass.size(0))\n", 132 | "report_interval = niters//nepoches#10000//batch_size#\n", 133 | "\n", 134 | "model = DAZLE(dim_f,dim_v,init_w2v_att,att,normalize_att,\n", 135 | " seenclass,unseenclass,\n", 136 | " lambda_,\n", 137 | " trainable_w2v,normalize_V=True,normalize_F=True,is_conservative=True,\n", 138 | " uniform_att_1=uniform_att_1,uniform_att_2=uniform_att_2,\n", 139 | " prob_prune=prob_prune,desired_mass=desired_mass, is_conv=False,\n", 140 | " is_bias=True)\n", 141 | "model.to(device)\n", 142 | "\n", 143 | "setup = {'pmp':{'init_lambda':0.1,'final_lambda':0.1,'phase':0.8},\n", 144 | " 'desired_mass':{'init_lambda':-1,'final_lambda':-1,'phase':0.8}}\n", 145 | "print(setup)\n", 146 | "\n", 147 | "params_to_update = []\n", 148 | "params_names = []\n", 149 | "for name,param in model.named_parameters():\n", 150 | " if param.requires_grad == True:\n", 151 | " params_to_update.append(param)\n", 152 | " params_names.append(name)\n", 153 | " print(\"\\t\",name)\n", 154 | "#%%\n", 155 | "lr = 0.0001\n", 156 | "weight_decay = 0.0001#0.000#0.#\n", 157 | "momentum = 0.#0.#\n", 158 | "#%%\n", 159 | "lr_seperator = 1\n", 160 | "lr_factor = 1\n", 161 | "print('default lr {} {}x lr {}'.format(params_names[:lr_seperator],lr_factor,params_names[lr_seperator:]))\n", 162 | "optimizer = optim.RMSprop( params_to_update ,lr=lr,weight_decay=weight_decay, momentum=momentum)\n", 163 | "print('-'*30)\n", 164 | "print('learing rate {}'.format(lr))\n", 165 | "print('trainable V {}'.format(trainable_w2v))\n", 166 | "print('lambda_ {}'.format(lambda_))\n", 167 | "print('optimized seen only')\n", 168 | "print('optimizer: RMSProp with momentum = {} and weight_decay = {}'.format(momentum,weight_decay))\n", 169 | "print('-'*30)" 170 | ] 171 | }, 172 | { 173 | "cell_type": "code", 174 | "execution_count": 7, 175 | "metadata": { 176 | "scrolled": true 177 | }, 178 | "outputs": [ 179 | { 180 | "name": "stdout", 181 | "output_type": "stream", 182 | "text": [ 183 | "------------------------------\n", 184 | "bias_seen 0 bias_unseen 0\n", 185 | "{'iter': 0, 'loss': 3.955843210220337, 'loss_CE': 3.9093644618988037, 'loss_cal': 0.46478742361068726, 'acc_seen': 0, 'acc_novel': 0, 'H': 0, 'acc_zs': 0}\n", 186 | "------------------------------\n", 187 | "bias_seen 0 bias_unseen 0\n", 188 | "{'iter': 470, 'loss': 1.3514924049377441, 'loss_CE': 1.2345727682113647, 'loss_cal': 1.169196605682373, 'acc_seen': 0.5430148839950562, 'acc_novel': 0.6056535840034485, 'H': 0.5726263405347534, 'acc_zs': 
0.646105170249939}\n", 189 | "------------------------------\n", 190 | "bias_seen 0 bias_unseen 0\n", 191 | "{'iter': 940, 'loss': 0.9182254672050476, 'loss_CE': 0.7962093949317932, 'loss_cal': 1.2201608419418335, 'acc_seen': 0.7086815237998962, 'acc_novel': 0.6011627912521362, 'H': 0.6505093132987168, 'acc_zs': 0.6683744192123413}\n", 192 | "------------------------------\n", 193 | "bias_seen 0 bias_unseen 0\n", 194 | "{'iter': 1410, 'loss': 0.7537283301353455, 'loss_CE': 0.6223986744880676, 'loss_cal': 1.3132964372634888, 'acc_seen': 0.7395804524421692, 'acc_novel': 0.5977693200111389, 'H': 0.6611561361974528, 'acc_zs': 0.6678923964500427}\n", 195 | "------------------------------\n", 196 | "bias_seen 0 bias_unseen 0\n", 197 | "{'iter': 1880, 'loss': 0.662609338760376, 'loss_CE': 0.5255433917045593, 'loss_cal': 1.370659351348877, 'acc_seen': 0.7518362998962402, 'acc_novel': 0.6027500033378601, 'H': 0.6690889036601549, 'acc_zs': 0.6755213737487793}\n", 198 | "------------------------------\n", 199 | "bias_seen 0 bias_unseen 0\n", 200 | "{'iter': 2350, 'loss': 0.6536160707473755, 'loss_CE': 0.5199357271194458, 'loss_cal': 1.3368035554885864, 'acc_seen': 0.7530007362365723, 'acc_novel': 0.6061833500862122, 'H': 0.6716625268092912, 'acc_zs': 0.6787406206130981}\n", 201 | "------------------------------\n", 202 | "bias_seen 0 bias_unseen 0\n", 203 | "{'iter': 2820, 'loss': 0.6108418107032776, 'loss_CE': 0.48195621371269226, 'loss_cal': 1.2888559103012085, 'acc_seen': 0.7530007362365723, 'acc_novel': 0.6061833500862122, 'H': 0.6716625268092912, 'acc_zs': 0.6787406206130981}\n", 204 | "------------------------------\n", 205 | "bias_seen 0 bias_unseen 0\n", 206 | "{'iter': 3290, 'loss': 0.5895015001296997, 'loss_CE': 0.45069620013237, 'loss_cal': 1.388053059577942, 'acc_seen': 0.7628469467163086, 'acc_novel': 0.6038219332695007, 'H': 0.6740823983052283, 'acc_zs': 0.6789496541023254}\n", 207 | "------------------------------\n", 208 | "bias_seen 0 bias_unseen 0\n", 209 | "{'iter': 3760, 'loss': 0.644405722618103, 'loss_CE': 0.511443555355072, 'loss_cal': 1.3296215534210205, 'acc_seen': 0.7628469467163086, 'acc_novel': 0.6038219332695007, 'H': 0.6740823983052283, 'acc_zs': 0.6789496541023254}\n", 210 | "------------------------------\n", 211 | "bias_seen 0 bias_unseen 0\n", 212 | "{'iter': 4230, 'loss': 0.5973643064498901, 'loss_CE': 0.462933212518692, 'loss_cal': 1.3443106412887573, 'acc_seen': 0.7628469467163086, 'acc_novel': 0.6038219332695007, 'H': 0.6740823983052283, 'acc_zs': 0.6789496541023254}\n", 213 | "------------------------------\n", 214 | "bias_seen 0 bias_unseen 0\n", 215 | "{'iter': 4700, 'loss': 0.6409440636634827, 'loss_CE': 0.5101036429405212, 'loss_cal': 1.3084039688110352, 'acc_seen': 0.7628469467163086, 'acc_novel': 0.6038219332695007, 'H': 0.6740823983052283, 'acc_zs': 0.6789496541023254}\n", 216 | "------------------------------\n", 217 | "bias_seen 0 bias_unseen 0\n", 218 | "{'iter': 5170, 'loss': 0.6012772917747498, 'loss_CE': 0.4705732464790344, 'loss_cal': 1.3070402145385742, 'acc_seen': 0.7628469467163086, 'acc_novel': 0.6038219332695007, 'H': 0.6740823983052283, 'acc_zs': 0.6789496541023254}\n", 219 | "------------------------------\n", 220 | "bias_seen 0 bias_unseen 0\n", 221 | "{'iter': 5640, 'loss': 0.6804268956184387, 'loss_CE': 0.5569705963134766, 'loss_cal': 1.2345629930496216, 'acc_seen': 0.7628469467163086, 'acc_novel': 0.6038219332695007, 'H': 0.6740823983052283, 'acc_zs': 0.6789496541023254}\n", 222 | "------------------------------\n", 223 | "bias_seen 0 
bias_unseen 0\n", 224 | "{'iter': 6110, 'loss': 0.5830560922622681, 'loss_CE': 0.45000630617141724, 'loss_cal': 1.3304975032806396, 'acc_seen': 0.7628469467163086, 'acc_novel': 0.6038219332695007, 'H': 0.6740823983052283, 'acc_zs': 0.6789496541023254}\n", 225 | "------------------------------\n", 226 | "bias_seen 0 bias_unseen 0\n", 227 | "{'iter': 6580, 'loss': 0.680651843547821, 'loss_CE': 0.566943883895874, 'loss_cal': 1.1370794773101807, 'acc_seen': 0.7628469467163086, 'acc_novel': 0.6038219332695007, 'H': 0.6740823983052283, 'acc_zs': 0.6789496541023254}\n", 228 | "------------------------------\n", 229 | "bias_seen 0 bias_unseen 0\n", 230 | "{'iter': 7050, 'loss': 0.5572313666343689, 'loss_CE': 0.42874494194984436, 'loss_cal': 1.2848644256591797, 'acc_seen': 0.7628469467163086, 'acc_novel': 0.6038219332695007, 'H': 0.6740823983052283, 'acc_zs': 0.6789496541023254}\n", 231 | "------------------------------\n", 232 | "bias_seen 0 bias_unseen 0\n", 233 | "{'iter': 7520, 'loss': 0.5773841142654419, 'loss_CE': 0.4520866870880127, 'loss_cal': 1.2529743909835815, 'acc_seen': 0.7628469467163086, 'acc_novel': 0.6038219332695007, 'H': 0.6740823983052283, 'acc_zs': 0.6789496541023254}\n", 234 | "------------------------------\n", 235 | "bias_seen 0 bias_unseen 0\n", 236 | "{'iter': 7990, 'loss': 0.6745968461036682, 'loss_CE': 0.5561791658401489, 'loss_cal': 1.1841765642166138, 'acc_seen': 0.7628469467163086, 'acc_novel': 0.6038219332695007, 'H': 0.6740823983052283, 'acc_zs': 0.6789496541023254}\n", 237 | "------------------------------\n", 238 | "bias_seen 0 bias_unseen 0\n", 239 | "{'iter': 8460, 'loss': 0.549967885017395, 'loss_CE': 0.4177410304546356, 'loss_cal': 1.3222687244415283, 'acc_seen': 0.7628469467163086, 'acc_novel': 0.6038219332695007, 'H': 0.6740823983052283, 'acc_zs': 0.6789496541023254}\n", 240 | "------------------------------\n", 241 | "bias_seen 0 bias_unseen 0\n", 242 | "{'iter': 8930, 'loss': 0.6991280317306519, 'loss_CE': 0.5688791275024414, 'loss_cal': 1.302489161491394, 'acc_seen': 0.7628469467163086, 'acc_novel': 0.6038219332695007, 'H': 0.6740823983052283, 'acc_zs': 0.6789496541023254}\n", 243 | "------------------------------\n", 244 | "bias_seen 0 bias_unseen 0\n", 245 | "{'iter': 9400, 'loss': 0.6614841818809509, 'loss_CE': 0.5430386662483215, 'loss_cal': 1.184455394744873, 'acc_seen': 0.7628469467163086, 'acc_novel': 0.6038219332695007, 'H': 0.6740823983052283, 'acc_zs': 0.6789496541023254}\n" 246 | ] 247 | } 248 | ], 249 | "source": [ 250 | "best_performance = [0,0,0,0]\n", 251 | "for i in range(0,niters):\n", 252 | " model.train()\n", 253 | " optimizer.zero_grad()\n", 254 | " \n", 255 | " batch_label, batch_feature, batch_att = dataloader.next_batch(batch_size)\n", 256 | " out_package = model(batch_feature)\n", 257 | " \n", 258 | " in_package = out_package\n", 259 | " in_package['batch_label'] = batch_label\n", 260 | " \n", 261 | " out_package=model.compute_loss(in_package)\n", 262 | " loss,loss_CE,loss_cal = out_package['loss'],out_package['loss_CE'],out_package['loss_cal']\n", 263 | " \n", 264 | " loss.backward()\n", 265 | " optimizer.step()\n", 266 | " if i%report_interval==0:\n", 267 | " print('-'*30)\n", 268 | " acc_seen, acc_novel, H, acc_zs = eval_zs_gzsl(dataloader,model,device,bias_seen=-bias,bias_unseen=bias)\n", 269 | " \n", 270 | " if H > best_performance[2]:\n", 271 | " best_performance = [acc_seen, acc_novel, H, acc_zs]\n", 272 | " stats_package = {'iter':i, 'loss':loss.item(), 'loss_CE':loss_CE.item(),\n", 273 | " 'loss_cal': loss_cal.item(),\n", 
274 | " 'acc_seen':best_performance[0], 'acc_novel':best_performance[1], 'H':best_performance[2], 'acc_zs':best_performance[3]}\n", 275 | " \n", 276 | " print(stats_package)" 277 | ] 278 | }, 279 | { 280 | "cell_type": "code", 281 | "execution_count": null, 282 | "metadata": {}, 283 | "outputs": [], 284 | "source": [] 285 | } 286 | ], 287 | "metadata": { 288 | "kernelspec": { 289 | "display_name": "Python 3", 290 | "language": "python", 291 | "name": "python3" 292 | }, 293 | "language_info": { 294 | "codemirror_mode": { 295 | "name": "ipython", 296 | "version": 3 297 | }, 298 | "file_extension": ".py", 299 | "mimetype": "text/x-python", 300 | "name": "python", 301 | "nbconvert_exporter": "python", 302 | "pygments_lexer": "ipython3", 303 | "version": "3.6.8" 304 | } 305 | }, 306 | "nbformat": 4, 307 | "nbformat_minor": 2 308 | } 309 | -------------------------------------------------------------------------------- /notebook/DAZLE_CUB.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": {}, 7 | "outputs": [ 8 | { 9 | "name": "stdout", 10 | "output_type": "stream", 11 | "text": [ 12 | "------------------------------\n", 13 | "/home/project_amadeus/home/hbdat/[RELEASE]_DenseAttentionZSL\n", 14 | "------------------------------\n" 15 | ] 16 | } 17 | ], 18 | "source": [ 19 | "%matplotlib inline\n", 20 | "import os,sys\n", 21 | "pwd = os.getcwd()\n", 22 | "parent = '/'.join(pwd.split('/')[:-1])\n", 23 | "sys.path.insert(0,parent)\n", 24 | "os.chdir(parent)\n", 25 | "#%%\n", 26 | "print('-'*30)\n", 27 | "print(os.getcwd())\n", 28 | "print('-'*30)\n", 29 | "#%%\n", 30 | "import torch\n", 31 | "import torch.optim as optim\n", 32 | "import torch.nn as nn\n", 33 | "import pandas as pd\n", 34 | "from core.DAZLE import DAZLE\n", 35 | "from core.CUBDataLoader import CUBDataLoader\n", 36 | "from core.helper_func import eval_zs_gzsl,visualize_attention,eval_zs_gzsl#,get_attribute_attention_stats\n", 37 | "from global_setting import NFS_path\n", 38 | "import importlib\n", 39 | "import pdb\n", 40 | "import numpy as np" 41 | ] 42 | }, 43 | { 44 | "cell_type": "code", 45 | "execution_count": 2, 46 | "metadata": {}, 47 | "outputs": [], 48 | "source": [ 49 | "idx_GPU = 4\n", 50 | "device = torch.device(\"cuda:{}\".format(idx_GPU) if torch.cuda.is_available() else \"cpu\")" 51 | ] 52 | }, 53 | { 54 | "cell_type": "code", 55 | "execution_count": 3, 56 | "metadata": {}, 57 | "outputs": [], 58 | "source": [ 59 | "torch.backends.cudnn.benchmark = True" 60 | ] 61 | }, 62 | { 63 | "cell_type": "code", 64 | "execution_count": 4, 65 | "metadata": {}, 66 | "outputs": [ 67 | { 68 | "name": "stdout", 69 | "output_type": "stream", 70 | "text": [ 71 | "/home/project_amadeus/mnt/raptor/hbdat/Attention_over_attention/\n", 72 | "$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$\n", 73 | "CUB\n", 74 | "$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$\n", 75 | "_____\n", 76 | "/home/project_amadeus/mnt/raptor/hbdat/Attention_over_attention/data/CUB/feature_map_ResNet_101_CUB.hdf5\n", 77 | "Expert Attr\n", 78 | "Finish loading data in 61.433818\n" 79 | ] 80 | } 81 | ], 82 | "source": [ 83 | "dataloader = CUBDataLoader(NFS_path,device,is_balance=False)" 84 | ] 85 | }, 86 | { 87 | "cell_type": "code", 88 | "execution_count": 5, 89 | "metadata": {}, 90 | "outputs": [], 91 | "source": [ 92 | "dataloader.augment_img_path()" 93 | ] 94 | }, 95 | { 96 | "cell_type": "code", 97 | "execution_count": 6, 98 | "metadata": {}, 99 | "outputs": [], 100 | 
"source": [ 101 | "def get_lr(optimizer):\n", 102 | " lr = []\n", 103 | " for param_group in optimizer.param_groups:\n", 104 | " lr.append(param_group['lr'])\n", 105 | " return lr" 106 | ] 107 | }, 108 | { 109 | "cell_type": "code", 110 | "execution_count": null, 111 | "metadata": {}, 112 | "outputs": [], 113 | "source": [ 114 | "seed = 214#215#\n", 115 | "torch.manual_seed(seed)\n", 116 | "torch.cuda.manual_seed_all(seed)\n", 117 | "np.random.seed(seed)\n", 118 | "\n", 119 | "batch_size = 50\n", 120 | "nepoches = 20\n", 121 | "niters = dataloader.ntrain * nepoches//batch_size\n", 122 | "dim_f = 2048\n", 123 | "dim_v = 300\n", 124 | "init_w2v_att = dataloader.w2v_att\n", 125 | "att = dataloader.att\n", 126 | "normalize_att = dataloader.normalize_att\n", 127 | "\n", 128 | "trainable_w2v = True\n", 129 | "lambda_ = 0.1#0.1\n", 130 | "bias = 0\n", 131 | "prob_prune = 0\n", 132 | "uniform_att_1 = False\n", 133 | "uniform_att_2 = False\n", 134 | "\n", 135 | "seenclass = dataloader.seenclasses\n", 136 | "unseenclass = dataloader.unseenclasses\n", 137 | "desired_mass = 1\n", 138 | "report_interval = niters//nepoches\n", 139 | "\n", 140 | "model = DAZLE(dim_f,dim_v,init_w2v_att,att,normalize_att,\n", 141 | " seenclass,unseenclass,\n", 142 | " lambda_,\n", 143 | " trainable_w2v,normalize_V=False,normalize_F=True,is_conservative=True,\n", 144 | " uniform_att_1=uniform_att_1,uniform_att_2=uniform_att_2,\n", 145 | " prob_prune=prob_prune,desired_mass=desired_mass, is_conv=False,\n", 146 | " is_bias=True)\n", 147 | "model.to(device)\n", 148 | "\n", 149 | "setup = {'pmp':{'init_lambda':0.1,'final_lambda':0.1,'phase':0.8},\n", 150 | " 'desired_mass':{'init_lambda':-1,'final_lambda':-1,'phase':0.8}}\n", 151 | "print(setup)\n", 152 | "#scheduler = Scheduler(model,niters,batch_size,report_interval,setup)\n", 153 | "\n", 154 | "params_to_update = []\n", 155 | "params_names = []\n", 156 | "for name,param in model.named_parameters():\n", 157 | " if param.requires_grad == True:\n", 158 | " params_to_update.append(param)\n", 159 | " params_names.append(name)\n", 160 | " print(\"\\t\",name)\n", 161 | "#%%\n", 162 | "lr = 0.0001\n", 163 | "weight_decay = 0.0001#0.000#0.#\n", 164 | "momentum = 0.9#0.#\n", 165 | "#%%\n", 166 | "lr_seperator = 1\n", 167 | "lr_factor = 1\n", 168 | "print('default lr {} {}x lr {}'.format(params_names[:lr_seperator],lr_factor,params_names[lr_seperator:]))\n", 169 | "optimizer = optim.RMSprop( params_to_update ,lr=lr,weight_decay=weight_decay, momentum=momentum)\n", 170 | "\n", 171 | "print('-'*30)\n", 172 | "print('learing rate {}'.format(lr))\n", 173 | "print('trainable V {}'.format(trainable_w2v))\n", 174 | "print('lambda_ {}'.format(lambda_))\n", 175 | "print('optimized seen only')\n", 176 | "print('optimizer: RMSProp with momentum = {} and weight_decay = {}'.format(momentum,weight_decay))\n", 177 | "print('-'*30)" 178 | ] 179 | }, 180 | { 181 | "cell_type": "code", 182 | "execution_count": 8, 183 | "metadata": { 184 | "scrolled": true 185 | }, 186 | "outputs": [ 187 | { 188 | "name": "stdout", 189 | "output_type": "stream", 190 | "text": [ 191 | "------------------------------\n", 192 | "bias_seen 0 bias_unseen 0\n", 193 | "{'iter': 0, 'loss': 5.339369297027588, 'loss_CE': 5.305324554443359, 'loss_cal': 0.34044915437698364, 'acc_seen': 0, 'acc_novel': 0, 'H': 0, 'acc_zs': 0}\n", 194 | "------------------------------\n", 195 | "bias_seen 0 bias_unseen 0\n", 196 | "{'iter': 141, 'loss': 1.6474179029464722, 'loss_CE': 1.5508091449737549, 'loss_cal': 0.9660871028900146, 'acc_seen': 
0.32329174876213074, 'acc_novel': 0.4783506989479065, 'H': 0.3858249633018277, 'acc_zs': 0.5206593871116638}\n", 197 | "------------------------------\n", 198 | "bias_seen 0 bias_unseen 0\n", 199 | "{'iter': 282, 'loss': 1.5724619626998901, 'loss_CE': 1.4667860269546509, 'loss_cal': 1.0567588806152344, 'acc_seen': 0.44570451974868774, 'acc_novel': 0.5117195248603821, 'H': 0.4764361337239027, 'acc_zs': 0.579412579536438}\n", 200 | "------------------------------\n", 201 | "bias_seen 0 bias_unseen 0\n", 202 | "{'iter': 423, 'loss': 1.3222812414169312, 'loss_CE': 1.2042889595031738, 'loss_cal': 1.179922342300415, 'acc_seen': 0.5000437498092651, 'acc_novel': 0.537502646446228, 'H': 0.5180970023728759, 'acc_zs': 0.610663890838623}\n", 203 | "------------------------------\n", 204 | "bias_seen 0 bias_unseen 0\n", 205 | "{'iter': 564, 'loss': 1.115435004234314, 'loss_CE': 1.0016543865203857, 'loss_cal': 1.137805700302124, 'acc_seen': 0.5307056903839111, 'acc_novel': 0.5414375066757202, 'H': 0.5360178874764615, 'acc_zs': 0.6320300102233887}\n", 206 | "------------------------------\n", 207 | "bias_seen 0 bias_unseen 0\n", 208 | "{'iter': 705, 'loss': 1.1483986377716064, 'loss_CE': 1.0214109420776367, 'loss_cal': 1.2698768377304077, 'acc_seen': 0.5512126684188843, 'acc_novel': 0.5535522103309631, 'H': 0.5523799621707404, 'acc_zs': 0.645402193069458}\n", 209 | "------------------------------\n", 210 | "bias_seen 0 bias_unseen 0\n", 211 | "{'iter': 846, 'loss': 1.0018386840820312, 'loss_CE': 0.853941798210144, 'loss_cal': 1.4789692163467407, 'acc_seen': 0.5494822263717651, 'acc_novel': 0.5576741695404053, 'H': 0.5535478915182929, 'acc_zs': 0.6444504261016846}\n", 212 | "------------------------------\n", 213 | "bias_seen 0 bias_unseen 0\n", 214 | "{'iter': 987, 'loss': 0.9130017161369324, 'loss_CE': 0.7828669548034668, 'loss_cal': 1.3013474941253662, 'acc_seen': 0.5765684247016907, 'acc_novel': 0.5499624013900757, 'H': 0.5629512270244877, 'acc_zs': 0.6390618681907654}\n", 215 | "------------------------------\n", 216 | "bias_seen 0 bias_unseen 0\n", 217 | "{'iter': 1128, 'loss': 1.13380765914917, 'loss_CE': 0.9990037083625793, 'loss_cal': 1.3480396270751953, 'acc_seen': 0.5809220671653748, 'acc_novel': 0.5495595335960388, 'H': 0.5648057607873046, 'acc_zs': 0.648659348487854}\n", 218 | "------------------------------\n", 219 | "bias_seen 0 bias_unseen 0\n", 220 | "{'iter': 1269, 'loss': 0.8725115060806274, 'loss_CE': 0.7155832648277283, 'loss_cal': 1.5692821741104126, 'acc_seen': 0.5809220671653748, 'acc_novel': 0.5495595335960388, 'H': 0.5648057607873046, 'acc_zs': 0.648659348487854}\n", 221 | "------------------------------\n", 222 | "bias_seen 0 bias_unseen 0\n", 223 | "{'iter': 1410, 'loss': 0.7587224841117859, 'loss_CE': 0.6123248338699341, 'loss_cal': 1.463976263999939, 'acc_seen': 0.5878117680549622, 'acc_novel': 0.5457720160484314, 'H': 0.5660123551645453, 'acc_zs': 0.6505881547927856}\n", 224 | "------------------------------\n", 225 | "bias_seen 0 bias_unseen 0\n", 226 | "{'iter': 1551, 'loss': 0.9522080421447754, 'loss_CE': 0.8048012852668762, 'loss_cal': 1.474067211151123, 'acc_seen': 0.5906508564949036, 'acc_novel': 0.5587936639785767, 'H': 0.5742807945126683, 'acc_zs': 0.6692891120910645}\n", 227 | "------------------------------\n", 228 | "bias_seen 0 bias_unseen 0\n", 229 | "{'iter': 1692, 'loss': 1.0499943494796753, 'loss_CE': 0.9086294174194336, 'loss_cal': 1.413649559020996, 'acc_seen': 0.5906508564949036, 'acc_novel': 0.5587936639785767, 'H': 0.5742807945126683, 'acc_zs': 
0.6692891120910645}\n", 230 | "------------------------------\n", 231 | "bias_seen 0 bias_unseen 0\n", 232 | "{'iter': 1833, 'loss': 0.762641429901123, 'loss_CE': 0.6273956298828125, 'loss_cal': 1.352458119392395, 'acc_seen': 0.591661274433136, 'acc_novel': 0.5631834864616394, 'H': 0.5770712577531473, 'acc_zs': 0.6591776013374329}\n", 233 | "------------------------------\n", 234 | "bias_seen 0 bias_unseen 0\n", 235 | "{'iter': 1974, 'loss': 0.8310383558273315, 'loss_CE': 0.6961447596549988, 'loss_cal': 1.3489360809326172, 'acc_seen': 0.591661274433136, 'acc_novel': 0.5631834864616394, 'H': 0.5770712577531473, 'acc_zs': 0.6591776013374329}\n", 236 | "------------------------------\n", 237 | "bias_seen 0 bias_unseen 0\n", 238 | "{'iter': 2115, 'loss': 0.6843529939651489, 'loss_CE': 0.5164155960083008, 'loss_cal': 1.6793742179870605, 'acc_seen': 0.5928998589515686, 'acc_novel': 0.564140260219574, 'H': 0.5781626326884738, 'acc_zs': 0.6699299812316895}\n", 239 | "------------------------------\n", 240 | "bias_seen 0 bias_unseen 0\n", 241 | "{'iter': 2256, 'loss': 0.7171906232833862, 'loss_CE': 0.5732181668281555, 'loss_cal': 1.4397245645523071, 'acc_seen': 0.5928998589515686, 'acc_novel': 0.564140260219574, 'H': 0.5781626326884738, 'acc_zs': 0.6699299812316895}\n", 242 | "------------------------------\n", 243 | "bias_seen 0 bias_unseen 0\n", 244 | "{'iter': 2397, 'loss': 0.7362838387489319, 'loss_CE': 0.6180073618888855, 'loss_cal': 1.1827645301818848, 'acc_seen': 0.5928998589515686, 'acc_novel': 0.564140260219574, 'H': 0.5781626326884738, 'acc_zs': 0.6699299812316895}\n", 245 | "------------------------------\n", 246 | "bias_seen 0 bias_unseen 0\n", 247 | "{'iter': 2538, 'loss': 0.7740883827209473, 'loss_CE': 0.6076768040657043, 'loss_cal': 1.66411554813385, 'acc_seen': 0.5928998589515686, 'acc_novel': 0.564140260219574, 'H': 0.5781626326884738, 'acc_zs': 0.6699299812316895}\n", 248 | "------------------------------\n", 249 | "bias_seen 0 bias_unseen 0\n", 250 | "{'iter': 2679, 'loss': 1.0031527280807495, 'loss_CE': 0.8608006834983826, 'loss_cal': 1.423520803451538, 'acc_seen': 0.5928998589515686, 'acc_novel': 0.564140260219574, 'H': 0.5781626326884738, 'acc_zs': 0.6699299812316895}\n", 251 | "------------------------------\n", 252 | "bias_seen 0 bias_unseen 0\n", 253 | "{'iter': 2820, 'loss': 0.8257772922515869, 'loss_CE': 0.7006416320800781, 'loss_cal': 1.251356840133667, 'acc_seen': 0.5979835391044617, 'acc_novel': 0.5665391087532043, 'H': 0.5818367928121845, 'acc_zs': 0.6588420867919922}\n" 254 | ] 255 | } 256 | ], 257 | "source": [ 258 | "best_performance = [0,0,0,0]\n", 259 | "for i in range(0,niters):\n", 260 | " model.train()\n", 261 | " optimizer.zero_grad()\n", 262 | " \n", 263 | " batch_label, batch_feature, batch_att = dataloader.next_batch(batch_size)\n", 264 | " out_package = model(batch_feature)\n", 265 | " \n", 266 | " in_package = out_package\n", 267 | " in_package['batch_label'] = batch_label\n", 268 | " \n", 269 | " out_package=model.compute_loss(in_package)\n", 270 | " loss,loss_CE,loss_cal = out_package['loss'],out_package['loss_CE'],out_package['loss_cal']\n", 271 | " \n", 272 | " loss.backward()\n", 273 | " optimizer.step()\n", 274 | " if i%report_interval==0:\n", 275 | " print('-'*30)\n", 276 | " acc_seen, acc_novel, H, acc_zs = eval_zs_gzsl(dataloader,model,device,bias_seen=-bias,bias_unseen=bias)\n", 277 | " \n", 278 | " if H > best_performance[2]:\n", 279 | " best_performance = [acc_seen, acc_novel, H, acc_zs]\n", 280 | " stats_package = {'iter':i, 
'loss':loss.item(), 'loss_CE':loss_CE.item(),\n", 281 | " 'loss_cal': loss_cal.item(),\n", 282 | " 'acc_seen':best_performance[0], 'acc_novel':best_performance[1], 'H':best_performance[2], 'acc_zs':best_performance[3]}\n", 283 | " \n", 284 | " print(stats_package)" 285 | ] 286 | }, 287 | { 288 | "cell_type": "code", 289 | "execution_count": null, 290 | "metadata": {}, 291 | "outputs": [], 292 | "source": [] 293 | } 294 | ], 295 | "metadata": { 296 | "kernelspec": { 297 | "display_name": "Python 3", 298 | "language": "python", 299 | "name": "python3" 300 | }, 301 | "language_info": { 302 | "codemirror_mode": { 303 | "name": "ipython", 304 | "version": 3 305 | }, 306 | "file_extension": ".py", 307 | "mimetype": "text/x-python", 308 | "name": "python", 309 | "nbconvert_exporter": "python", 310 | "pygments_lexer": "ipython3", 311 | "version": "3.6.8" 312 | } 313 | }, 314 | "nbformat": 4, 315 | "nbformat_minor": 2 316 | } 317 | -------------------------------------------------------------------------------- /notebook/DAZLE_CUB_SS.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": {}, 7 | "outputs": [ 8 | { 9 | "name": "stdout", 10 | "output_type": "stream", 11 | "text": [ 12 | "------------------------------\n", 13 | "/home/project_amadeus/home/hbdat/[RELEASE]_DenseAttentionZSL\n", 14 | "------------------------------\n" 15 | ] 16 | } 17 | ], 18 | "source": [ 19 | "%matplotlib inline\n", 20 | "import os,sys\n", 21 | "pwd = os.getcwd()\n", 22 | "parent = '/'.join(pwd.split('/')[:-1])\n", 23 | "sys.path.insert(0,parent)\n", 24 | "os.chdir(parent)\n", 25 | "#%%\n", 26 | "print('-'*30)\n", 27 | "print(os.getcwd())\n", 28 | "print('-'*30)\n", 29 | "#%%\n", 30 | "import torch\n", 31 | "import torch.optim as optim\n", 32 | "import torch.nn as nn\n", 33 | "import pandas as pd\n", 34 | "from core.DAZLE import DAZLE\n", 35 | "from core.CUBDataLoader_standard_split import CUBDataLoader\n", 36 | "from core.helper_func import eval_zs_gzsl,visualize_attention,eval_zs_gzsl#,get_attribute_attention_stats\n", 37 | "from global_setting import NFS_path\n", 38 | "#from core.Scheduler import Scheduler\n", 39 | "import importlib\n", 40 | "import pdb\n", 41 | "import numpy as np" 42 | ] 43 | }, 44 | { 45 | "cell_type": "code", 46 | "execution_count": 2, 47 | "metadata": {}, 48 | "outputs": [], 49 | "source": [ 50 | "idx_GPU = 0\n", 51 | "device = torch.device(\"cuda:{}\".format(idx_GPU) if torch.cuda.is_available() else \"cpu\")" 52 | ] 53 | }, 54 | { 55 | "cell_type": "code", 56 | "execution_count": 3, 57 | "metadata": {}, 58 | "outputs": [], 59 | "source": [ 60 | "torch.backends.cudnn.benchmark = True" 61 | ] 62 | }, 63 | { 64 | "cell_type": "code", 65 | "execution_count": 4, 66 | "metadata": {}, 67 | "outputs": [ 68 | { 69 | "name": "stdout", 70 | "output_type": "stream", 71 | "text": [ 72 | "!!!!!!!!!! 
Standard Split !!!!!!!!!!\n", 73 | "/home/project_amadeus/mnt/raptor/hbdat/Attention_over_attention/\n", 74 | "$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$\n", 75 | "CUB\n", 76 | "$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$\n", 77 | "_____\n", 78 | "/home/project_amadeus/mnt/raptor/hbdat/Attention_over_attention/data/CUB/feature_map_ResNet_101_CUB.hdf5\n", 79 | "Expert Attr\n", 80 | "Finish loading data in 61.513001\n" 81 | ] 82 | } 83 | ], 84 | "source": [ 85 | "dataloader = CUBDataLoader(NFS_path,device,is_balance=False)" 86 | ] 87 | }, 88 | { 89 | "cell_type": "code", 90 | "execution_count": 5, 91 | "metadata": {}, 92 | "outputs": [], 93 | "source": [ 94 | "def get_attr_entropy(att): #the lower the more discriminative it is\n", 95 | " eps = 1e-8\n", 96 | " mass=np.sum(att,axis = 0,keepdims=True)\n", 97 | " att_n = np.divide(att,mass+eps)\n", 98 | " entropy = np.sum(-att_n*np.log(att_n+eps),axis=0)\n", 99 | " assert len(entropy.shape)==1\n", 100 | " return entropy" 101 | ] 102 | }, 103 | { 104 | "cell_type": "code", 105 | "execution_count": 6, 106 | "metadata": {}, 107 | "outputs": [], 108 | "source": [ 109 | "batch_size = 50\n", 110 | "nepoches = 1\n", 111 | "niters = dataloader.ntrain * nepoches//batch_size\n", 112 | "dim_f = 2048\n", 113 | "dim_v = 300\n", 114 | "init_w2v_att = dataloader.w2v_att\n", 115 | "att = dataloader.att#dataloader.normalize_att#\n", 116 | "normalize_att = dataloader.att\n", 117 | "#%% attribute selection\n", 118 | "attr_entropy = get_attr_entropy(att.cpu().numpy())\n", 119 | "idx_attr_dis = np.argsort(attr_entropy)" 120 | ] 121 | }, 122 | { 123 | "cell_type": "code", 124 | "execution_count": 7, 125 | "metadata": {}, 126 | "outputs": [ 127 | { 128 | "name": "stdout", 129 | "output_type": "stream", 130 | "text": [ 131 | "torch.Size([312, 300])\n" 132 | ] 133 | } 134 | ], 135 | "source": [ 136 | "print(init_w2v_att.shape)" 137 | ] 138 | }, 139 | { 140 | "cell_type": "code", 141 | "execution_count": 8, 142 | "metadata": {}, 143 | "outputs": [], 144 | "source": [ 145 | "def get_lr(optimizer):\n", 146 | " lr = []\n", 147 | " for param_group in optimizer.param_groups:\n", 148 | " lr.append(param_group['lr'])\n", 149 | " return lr" 150 | ] 151 | }, 152 | { 153 | "cell_type": "code", 154 | "execution_count": null, 155 | "metadata": {}, 156 | "outputs": [], 157 | "source": [ 158 | "seed = 214#215#\n", 159 | "torch.manual_seed(seed)\n", 160 | "torch.cuda.manual_seed_all(seed)\n", 161 | "np.random.seed(seed)\n", 162 | "\n", 163 | "batch_size = 50\n", 164 | "nepoches = 20\n", 165 | "niters = dataloader.ntrain * nepoches//batch_size\n", 166 | "dim_f = 2048\n", 167 | "dim_v = 300\n", 168 | "init_w2v_att = dataloader.w2v_att\n", 169 | "att = dataloader.att#dataloader.normalize_att#\n", 170 | "normalize_att = dataloader.normalize_att\n", 171 | "#assert (att.min().item() == 0 and att.max().item() == 1)\n", 172 | "\n", 173 | "trainable_w2v = True\n", 174 | "lambda_ = 0.1\n", 175 | "bias = 0\n", 176 | "prob_prune = 0\n", 177 | "uniform_att_1 = False\n", 178 | "uniform_att_2 = False\n", 179 | "\n", 180 | "seenclass = dataloader.seenclasses\n", 181 | "unseenclass = dataloader.unseenclasses\n", 182 | "desired_mass = 1#unseenclass.size(0)/(seenclass.size(0)+unseenclass.size(0))\n", 183 | "report_interval = niters//nepoches#10000//batch_size#\n", 184 | "\n", 185 | "model = DAZLE(dim_f,dim_v,init_w2v_att,att,normalize_att,\n", 186 | " seenclass,unseenclass,\n", 187 | " lambda_,\n", 188 | " trainable_w2v,normalize_V=False,normalize_F=True,is_conservative=True,\n", 189 | " 
uniform_att_1=uniform_att_1,uniform_att_2=uniform_att_2,\n", 190 | " prob_prune=prob_prune,desired_mass=desired_mass, is_conv=False,\n", 191 | " is_bias=True)\n", 192 | "model.to(device)\n", 193 | "\n", 194 | "setup = {'pmp':{'init_lambda':0.1,'final_lambda':0.1,'phase':0.8},\n", 195 | " 'desired_mass':{'init_lambda':-1,'final_lambda':-1,'phase':0.8}}\n", 196 | "print(setup)\n", 197 | "#scheduler = Scheduler(model,niters,batch_size,report_interval,setup)\n", 198 | "\n", 199 | "params_to_update = []\n", 200 | "params_names = []\n", 201 | "for name,param in model.named_parameters():\n", 202 | " if param.requires_grad == True:\n", 203 | " params_to_update.append(param)\n", 204 | " params_names.append(name)\n", 205 | " print(\"\\t\",name)\n", 206 | "#%%\n", 207 | "lr = 0.0001\n", 208 | "weight_decay = 0.00005#0.000#0.#\n", 209 | "momentum = 0.9#0.#\n", 210 | "#%%\n", 211 | "lr_seperator = 1\n", 212 | "lr_factor = 1\n", 213 | "print('default lr {} {}x lr {}'.format(params_names[:lr_seperator],lr_factor,params_names[lr_seperator:]))\n", 214 | "optimizer = optim.RMSprop( params_to_update ,lr=lr,weight_decay=weight_decay, momentum=momentum)\n", 215 | "print('-'*30)\n", 216 | "print('learing rate {}'.format(lr))\n", 217 | "print('trainable V {}'.format(trainable_w2v))\n", 218 | "print('lambda_ {}'.format(lambda_))\n", 219 | "print('optimized seen only')\n", 220 | "print('optimizer: RMSProp with momentum = {} and weight_decay = {}'.format(momentum,weight_decay))\n", 221 | "print('-'*30)" 222 | ] 223 | }, 224 | { 225 | "cell_type": "code", 226 | "execution_count": 10, 227 | "metadata": { 228 | "scrolled": false 229 | }, 230 | "outputs": [ 231 | { 232 | "name": "stdout", 233 | "output_type": "stream", 234 | "text": [ 235 | "------------------------------\n", 236 | "bias_seen 0 bias_unseen 0\n", 237 | "{'iter': 0, 'loss': 5.307318687438965, 'loss_CE': 5.271770000457764, 'loss_cal': 0.35548919439315796, 'acc_seen': nan, 'acc_novel': 0.040221136063337326, 'H': 0, 'acc_zs': 0.04022114351391792}\n", 238 | "------------------------------\n", 239 | "bias_seen 0 bias_unseen 0\n", 240 | "{'iter': 177, 'loss': 1.6891649961471558, 'loss_CE': 1.5842840671539307, 'loss_cal': 1.0488090515136719, 'acc_seen': nan, 'acc_novel': 0.4655452370643616, 'H': 0, 'acc_zs': 0.5405555367469788}\n", 241 | "------------------------------\n", 242 | "bias_seen 0 bias_unseen 0\n", 243 | "{'iter': 354, 'loss': 1.5217909812927246, 'loss_CE': 1.4060150384902954, 'loss_cal': 1.1577597856521606, 'acc_seen': nan, 'acc_novel': 0.510532796382904, 'H': 0, 'acc_zs': 0.6040189266204834}\n", 244 | "------------------------------\n", 245 | "bias_seen 0 bias_unseen 0\n", 246 | "{'iter': 531, 'loss': 1.156381607055664, 'loss_CE': 1.0383371114730835, 'loss_cal': 1.1804451942443848, 'acc_seen': nan, 'acc_novel': 0.5140025019645691, 'H': 0, 'acc_zs': 0.6206724643707275}\n", 247 | "------------------------------\n", 248 | "bias_seen 0 bias_unseen 0\n", 249 | "{'iter': 708, 'loss': 1.1116386651992798, 'loss_CE': 0.9800852537155151, 'loss_cal': 1.3155337572097778, 'acc_seen': nan, 'acc_novel': 0.5589466094970703, 'H': 0, 'acc_zs': 0.6678099632263184}\n", 250 | "------------------------------\n", 251 | "bias_seen 0 bias_unseen 0\n", 252 | "{'iter': 885, 'loss': 1.3317115306854248, 'loss_CE': 1.1996527910232544, 'loss_cal': 1.320586919784546, 'acc_seen': nan, 'acc_novel': 0.5589466094970703, 'H': 0, 'acc_zs': 0.6678099632263184}\n", 253 | "------------------------------\n", 254 | "bias_seen 0 bias_unseen 0\n", 255 | "{'iter': 1062, 'loss': 
1.118662714958191, 'loss_CE': 0.9750721454620361, 'loss_cal': 1.4359060525894165, 'acc_seen': nan, 'acc_novel': 0.5580828785896301, 'H': 0, 'acc_zs': 0.6697779297828674}\n", 256 | "------------------------------\n", 257 | "bias_seen 0 bias_unseen 0\n", 258 | "{'iter': 1239, 'loss': 0.8553117513656616, 'loss_CE': 0.6984010338783264, 'loss_cal': 1.5691068172454834, 'acc_seen': nan, 'acc_novel': 0.5738272070884705, 'H': 0, 'acc_zs': 0.6717338562011719}\n", 259 | "------------------------------\n", 260 | "bias_seen 0 bias_unseen 0\n", 261 | "{'iter': 1416, 'loss': 0.6992800831794739, 'loss_CE': 0.5499823689460754, 'loss_cal': 1.4929770231246948, 'acc_seen': nan, 'acc_novel': 0.5738272070884705, 'H': 0, 'acc_zs': 0.6717338562011719}\n", 262 | "------------------------------\n", 263 | "bias_seen 0 bias_unseen 0\n", 264 | "{'iter': 1593, 'loss': 1.0170717239379883, 'loss_CE': 0.88196861743927, 'loss_cal': 1.3510308265686035, 'acc_seen': nan, 'acc_novel': 0.5738272070884705, 'H': 0, 'acc_zs': 0.6717338562011719}\n", 265 | "------------------------------\n", 266 | "bias_seen 0 bias_unseen 0\n", 267 | "{'iter': 1770, 'loss': 0.9129882454872131, 'loss_CE': 0.7532180547714233, 'loss_cal': 1.5977017879486084, 'acc_seen': nan, 'acc_novel': 0.5701281428337097, 'H': 0, 'acc_zs': 0.6731362342834473}\n", 268 | "------------------------------\n", 269 | "bias_seen 0 bias_unseen 0\n", 270 | "{'iter': 1947, 'loss': 0.6349995136260986, 'loss_CE': 0.48872146010398865, 'loss_cal': 1.4627807140350342, 'acc_seen': nan, 'acc_novel': 0.5772320032119751, 'H': 0, 'acc_zs': 0.6761006116867065}\n", 271 | "------------------------------\n", 272 | "bias_seen 0 bias_unseen 0\n", 273 | "{'iter': 2124, 'loss': 0.7970255613327026, 'loss_CE': 0.6616990566253662, 'loss_cal': 1.3532648086547852, 'acc_seen': nan, 'acc_novel': 0.5840730667114258, 'H': 0, 'acc_zs': 0.6778362989425659}\n", 274 | "------------------------------\n", 275 | "bias_seen 0 bias_unseen 0\n", 276 | "{'iter': 2301, 'loss': 0.7347122430801392, 'loss_CE': 0.5974063873291016, 'loss_cal': 1.373058557510376, 'acc_seen': nan, 'acc_novel': 0.5840730667114258, 'H': 0, 'acc_zs': 0.6778362989425659}\n", 277 | "------------------------------\n", 278 | "bias_seen 0 bias_unseen 0\n", 279 | "{'iter': 2478, 'loss': 0.5548276901245117, 'loss_CE': 0.3850424587726593, 'loss_cal': 1.6978520154953003, 'acc_seen': nan, 'acc_novel': 0.5840730667114258, 'H': 0, 'acc_zs': 0.6778362989425659}\n", 280 | "------------------------------\n", 281 | "bias_seen 0 bias_unseen 0\n", 282 | "{'iter': 2655, 'loss': 0.6615085601806641, 'loss_CE': 0.5202628374099731, 'loss_cal': 1.4124568700790405, 'acc_seen': nan, 'acc_novel': 0.5840730667114258, 'H': 0, 'acc_zs': 0.6778362989425659}\n", 283 | "------------------------------\n", 284 | "bias_seen 0 bias_unseen 0\n", 285 | "{'iter': 2832, 'loss': 0.6729946136474609, 'loss_CE': 0.501806914806366, 'loss_cal': 1.711876630783081, 'acc_seen': nan, 'acc_novel': 0.5840730667114258, 'H': 0, 'acc_zs': 0.6778362989425659}\n", 286 | "------------------------------\n", 287 | "bias_seen 0 bias_unseen 0\n", 288 | "{'iter': 3009, 'loss': 0.6648485064506531, 'loss_CE': 0.5202063322067261, 'loss_cal': 1.4464218616485596, 'acc_seen': nan, 'acc_novel': 0.5840730667114258, 'H': 0, 'acc_zs': 0.6778362989425659}\n", 289 | "------------------------------\n", 290 | "bias_seen 0 bias_unseen 0\n", 291 | "{'iter': 3186, 'loss': 0.5674977898597717, 'loss_CE': 0.42501401901245117, 'loss_cal': 1.424837589263916, 'acc_seen': nan, 'acc_novel': 0.5840730667114258, 'H': 0, 'acc_zs': 
0.6778362989425659}\n", 292 | "------------------------------\n", 293 | "bias_seen 0 bias_unseen 0\n", 294 | "{'iter': 3363, 'loss': 0.5623935461044312, 'loss_CE': 0.4077602028846741, 'loss_cal': 1.5463331937789917, 'acc_seen': nan, 'acc_novel': 0.5840730667114258, 'H': 0, 'acc_zs': 0.6778362989425659}\n", 295 | "------------------------------\n", 296 | "bias_seen 0 bias_unseen 0\n", 297 | "{'iter': 3540, 'loss': 0.6784080862998962, 'loss_CE': 0.5070964694023132, 'loss_cal': 1.713115930557251, 'acc_seen': nan, 'acc_novel': 0.5840730667114258, 'H': 0, 'acc_zs': 0.6778362989425659}\n" 298 | ] 299 | } 300 | ], 301 | "source": [ 302 | "best_performance = [0,0,0,0]\n", 303 | "for i in range(0,niters):\n", 304 | " model.train()\n", 305 | " optimizer.zero_grad()\n", 306 | " \n", 307 | " batch_label, batch_feature, batch_att = dataloader.next_batch(batch_size)\n", 308 | " out_package = model(batch_feature)\n", 309 | " \n", 310 | " in_package = out_package\n", 311 | " in_package['batch_label'] = batch_label\n", 312 | " \n", 313 | " out_package=model.compute_loss(in_package)\n", 314 | " loss,loss_CE,loss_cal = out_package['loss'],out_package['loss_CE'],out_package['loss_cal']\n", 315 | " \n", 316 | " loss.backward()\n", 317 | " optimizer.step()\n", 318 | " if i%report_interval==0:\n", 319 | " print('-'*30)\n", 320 | " acc_seen, acc_novel, H, acc_zs = eval_zs_gzsl(dataloader,model,device,bias_seen=-bias,bias_unseen=bias)\n", 321 | " \n", 322 | " if acc_zs > best_performance[3]:\n", 323 | " best_performance = [acc_seen, acc_novel, H, acc_zs]\n", 324 | " stats_package = {'iter':i, 'loss':loss.item(), 'loss_CE':loss_CE.item(),\n", 325 | " 'loss_cal': loss_cal.item(),\n", 326 | " 'acc_seen':best_performance[0], 'acc_novel':best_performance[1], 'H':best_performance[2], 'acc_zs':best_performance[3]}\n", 327 | " \n", 328 | " print(stats_package)" 329 | ] 330 | }, 331 | { 332 | "cell_type": "code", 333 | "execution_count": null, 334 | "metadata": {}, 335 | "outputs": [], 336 | "source": [] 337 | } 338 | ], 339 | "metadata": { 340 | "kernelspec": { 341 | "display_name": "Python 3", 342 | "language": "python", 343 | "name": "python3" 344 | }, 345 | "language_info": { 346 | "codemirror_mode": { 347 | "name": "ipython", 348 | "version": 3 349 | }, 350 | "file_extension": ".py", 351 | "mimetype": "text/x-python", 352 | "name": "python", 353 | "nbconvert_exporter": "python", 354 | "pygments_lexer": "ipython3", 355 | "version": "3.6.8" 356 | } 357 | }, 358 | "nbformat": 4, 359 | "nbformat_minor": 2 360 | } 361 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | h5py==2.9.0 2 | jupyter-client==5.2.4 3 | jupyter-core==4.4.0 4 | jupyterlab==0.35.6 5 | jupyterlab-server==0.2.0 6 | matplotlib==3.1.0 7 | notebook==5.7.8 8 | numpy==1.16.4 9 | opencv-python==3.4.1.15 10 | pandas==0.24.2 11 | Pillow==7.1.2 12 | Pillow-SIMD==5.3.0.post1 13 | scikit-image==0.17.2 14 | scikit-learn==0.21.2 15 | scipy==1.2.1 16 | torch==1.4.0 17 | torchtext==0.4.0 18 | torchvision==0.5.0 -------------------------------------------------------------------------------- /w2v/AWA2_attribute.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hbdat/cvpr20_DAZLE/e5a04601d8368903008fd96dee3d95dde398aa51/w2v/AWA2_attribute.pkl -------------------------------------------------------------------------------- /w2v/CUB_attribute.pkl: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/hbdat/cvpr20_DAZLE/e5a04601d8368903008fd96dee3d95dde398aa51/w2v/CUB_attribute.pkl -------------------------------------------------------------------------------- /w2v/DeepFashion_attribute.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hbdat/cvpr20_DAZLE/e5a04601d8368903008fd96dee3d95dde398aa51/w2v/DeepFashion_attribute.pkl -------------------------------------------------------------------------------- /w2v/SUN_attribute.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hbdat/cvpr20_DAZLE/e5a04601d8368903008fd96dee3d95dde398aa51/w2v/SUN_attribute.pkl --------------------------------------------------------------------------------
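
A note on the metrics logged in the notebook cells above: each evaluation line prints `acc_seen`, `acc_novel`, `H`, and `acc_zs` as returned by `eval_zs_gzsl`. The logged values are consistent with `H` being the harmonic mean of the seen-class and novel-class accuracies (the usual generalized zero-shot summary score), while `acc_zs` is the unseen-classes-only zero-shot accuracy. A minimal check, reusing a pair of accuracies copied verbatim from one of the training logs above:

```python
# Harmonic mean of seen/unseen accuracies (standard GZSL summary score).
# The two accuracies below are copied from one of the training logs above.
acc_seen, acc_novel = 0.7628469467163086, 0.6038219332695007

H = 2 * acc_seen * acc_novel / (acc_seen + acc_novel)
print(H)  # ~0.674082, matching the 'H' value printed alongside these accuracies
```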
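The `./w2v/*_attribute.pkl` files listed directly above are small binary artifacts shipped with the repository, so this dump links them rather than inlining their contents. The sketch below is only a hypothetical way to inspect one of them; it assumes each pickle holds per-attribute word vectors (for CUB, something shaped like 312 x 300, matching the `init_w2v_att` shape printed in the CUB notebooks), which is an assumption about the file layout rather than something stated in the dump.

```python
import pickle

# Hypothetical inspection of a shipped attribute-embedding pickle.
# Assumption (not stated in the dump): it stores per-attribute word vectors,
# e.g. an array of shape (312, 300) for CUB, as used for init_w2v_att above.
with open('./w2v/CUB_attribute.pkl', 'rb') as f:
    w2v_att = pickle.load(f)

print(type(w2v_att))
print(getattr(w2v_att, 'shape', None))  # (num_attributes, 300) if it is an array
```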