├── .gitignore ├── README.md ├── datasets ├── awa1 │ ├── classes.txt │ └── testclasses.txt ├── awa2 │ ├── classes.txt │ └── testclasses.txt ├── cub │ ├── classes.txt │ └── testclasses.txt └── sun │ ├── classes.txt │ └── testclasses.txt ├── datautils.py ├── main.py ├── models.py └── trainer.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | *.egg-info/ 24 | .installed.cfg 25 | *.egg 26 | MANIFEST 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | .pytest_cache/ 49 | 50 | # Translations 51 | *.mo 52 | *.pot 53 | 54 | # Django stuff: 55 | *.log 56 | local_settings.py 57 | db.sqlite3 58 | 59 | # Flask stuff: 60 | instance/ 61 | .webassets-cache 62 | 63 | # Scrapy stuff: 64 | .scrapy 65 | 66 | # Sphinx documentation 67 | docs/_build/ 68 | 69 | # PyBuilder 70 | target/ 71 | 72 | # Jupyter Notebook 73 | .ipynb_checkpoints 74 | 75 | # pyenv 76 | .python-version 77 | 78 | # celery beat schedule file 79 | celerybeat-schedule 80 | 81 | # SageMath parsed files 82 | *.sage.py 83 | 84 | # Environments 85 | .env 86 | .venv 87 | env/ 88 | venv/ 89 | ENV/ 90 | env.bak/ 91 | venv.bak/ 92 | 93 | # Spyder project settings 94 | .spyderproject 95 | .spyproject 96 | 97 | # Rope project settings 98 | .ropeproject 99 | 100 | # mkdocs documentation 101 | /site 102 | 103 | # mypy 104 | .mypy_cache/ 105 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | WARNING: There are some mistakes in the code available here, please do not use it as a benchmark or component. I will try to fix the project as soon as possibe. Please see the issues page to know the error. 2 | 3 | ## Feature Generating Networks for ZSL in Pytorch 4 | 5 | PyTorch implementation of paper: [Feature Generating Networks for Zero-Shot Learning](https://arxiv.org/abs/1712.00981) 6 | 7 | 4 datasets are currently supported: SUN, CUB, AWA1 & AWA2. All datasets can be downloaded [here](http://datasets.d2.mpi-inf.mpg.de/xian/xlsa17.zip). 8 | 9 | #### IMPORTANT: 10 | The downloaded zip will have many files for each dataset, but we only require 2 files ``res101.mat`` & ``att_splits.mat``. Move these 2 files per dataset to the appropriate folder in this repo before starting to train/test. 11 | 12 | * For training the model, use: 13 | ```python3 main.py --n_epochs 20 --use_cls_loss``` 14 | 15 | All trainable parameters are saved in a folder named ``saved_models`` at the end of every epoch. 16 | -------------------------------------------------------------------------------- /datasets/awa1/classes.txt: -------------------------------------------------------------------------------- 1 | 1 antelope 2 | 2 grizzly+bear 3 | 3 killer+whale 4 | 4 beaver 5 | 5 dalmatian 6 | 6 persian+cat 7 | 7 horse 8 | 8 german+shepherd 9 | 9 blue+whale 10 | 10 siamese+cat 11 | 11 skunk 12 | 12 mole 13 | 13 tiger 14 | 14 hippopotamus 15 | 15 leopard 16 | 16 moose 17 | 17 spider+monkey 18 | 18 humpback+whale 19 | 19 elephant 20 | 20 gorilla 21 | 21 ox 22 | 22 fox 23 | 23 sheep 24 | 24 seal 25 | 25 chimpanzee 26 | 26 hamster 27 | 27 squirrel 28 | 28 rhinoceros 29 | 29 rabbit 30 | 30 bat 31 | 31 giraffe 32 | 32 wolf 33 | 33 chihuahua 34 | 34 rat 35 | 35 weasel 36 | 36 otter 37 | 37 buffalo 38 | 38 zebra 39 | 39 giant+panda 40 | 40 deer 41 | 41 bobcat 42 | 42 pig 43 | 43 lion 44 | 44 mouse 45 | 45 polar+bear 46 | 46 collie 47 | 47 walrus 48 | 48 raccoon 49 | 49 cow 50 | 50 dolphin 51 | -------------------------------------------------------------------------------- /datasets/awa1/testclasses.txt: -------------------------------------------------------------------------------- 1 | sheep 2 | dolphin 3 | bat 4 | seal 5 | blue+whale 6 | rat 7 | horse 8 | walrus 9 | giraffe 10 | bobcat -------------------------------------------------------------------------------- /datasets/awa2/classes.txt: -------------------------------------------------------------------------------- 1 | 1 antelope 2 | 2 grizzly+bear 3 | 3 killer+whale 4 | 4 beaver 5 | 5 dalmatian 6 | 6 persian+cat 7 | 7 horse 8 | 8 german+shepherd 9 | 9 blue+whale 10 | 10 siamese+cat 11 | 11 skunk 12 | 12 mole 13 | 13 tiger 14 | 14 hippopotamus 15 | 15 leopard 16 | 16 moose 17 | 17 spider+monkey 18 | 18 humpback+whale 19 | 19 elephant 20 | 20 gorilla 21 | 21 ox 22 | 22 fox 23 | 23 sheep 24 | 24 seal 25 | 25 chimpanzee 26 | 26 hamster 27 | 27 squirrel 28 | 28 rhinoceros 29 | 29 rabbit 30 | 30 bat 31 | 31 giraffe 32 | 32 wolf 33 | 33 chihuahua 34 | 34 rat 35 | 35 weasel 36 | 36 otter 37 | 37 buffalo 38 | 38 zebra 39 | 39 giant+panda 40 | 40 deer 41 | 41 bobcat 42 | 42 pig 43 | 43 lion 44 | 44 mouse 45 | 45 polar+bear 46 | 46 collie 47 | 47 walrus 48 | 48 raccoon 49 | 49 cow 50 | 50 dolphin 51 | -------------------------------------------------------------------------------- /datasets/awa2/testclasses.txt: -------------------------------------------------------------------------------- 1 | sheep 2 | dolphin 3 | bat 4 | seal 5 | blue+whale 6 | rat 7 | horse 8 | walrus 9 | giraffe 10 | bobcat 11 | -------------------------------------------------------------------------------- /datasets/cub/classes.txt: -------------------------------------------------------------------------------- 1 | 1 Laysan_Albatross 2 | 2 Sooty_Albatross 3 | 3 Crested_Auklet 4 | 4 Parakeet_Auklet 5 | 5 Red_winged_Blackbird 6 | 6 Rusty_Blackbird 7 | 7 Yellow_headed_Blackbird 8 | 8 Bobolink 9 | 9 Lazuli_Bunting 10 | 10 Painted_Bunting 11 | 11 Cardinal 12 | 12 Spotted_Catbird 13 | 13 Gray_Catbird 14 | 14 Yellow_breasted_Chat 15 | 15 Eastern_Towhee 16 | 16 Chuck_will_Widow 17 | 17 Red_faced_Cormorant 18 | 18 Pelagic_Cormorant 19 | 19 Bronzed_Cowbird 20 | 20 Shiny_Cowbird 21 | 21 Brown_Creeper 22 | 22 Fish_Crow 23 | 23 Mangrove_Cuckoo 24 | 24 Least_Flycatcher 25 | 25 Olive_sided_Flycatcher 26 | 26 Scissor_tailed_Flycatcher 27 | 27 Vermilion_Flycatcher 28 | 28 Frigatebird 29 | 29 Northern_Fulmar 30 | 30 Gadwall 31 | 31 American_Goldfinch 32 | 32 European_Goldfinch 33 | 33 Eared_Grebe 34 | 34 Pied_billed_Grebe 35 | 35 Blue_Grosbeak 36 | 36 Evening_Grosbeak 37 | 37 Pine_Grosbeak 38 | 38 Rose_breasted_Grosbeak 39 | 39 Pigeon_Guillemot 40 | 40 California_Gull 41 | 41 Glaucous_winged_Gull 42 | 42 Heermann_Gull 43 | 43 Herring_Gull 44 | 44 Ivory_Gull 45 | 45 Ring_billed_Gull 46 | 46 Slaty_backed_Gull 47 | 47 Anna_Hummingbird 48 | 48 Ruby_throated_Hummingbird 49 | 49 Rufous_Hummingbird 50 | 50 Green_Violetear 51 | 51 Long_tailed_Jaeger 52 | 52 Blue_Jay 53 | 53 Florida_Jay 54 | 54 Green_Jay 55 | 55 Dark_eyed_Junco 56 | 56 Tropical_Kingbird 57 | 57 Gray_Kingbird 58 | 58 Green_Kingfisher 59 | 59 Pied_Kingfisher 60 | 60 Ringed_Kingfisher 61 | 61 Horned_Lark 62 | 62 Mallard 63 | 63 Western_Meadowlark 64 | 64 Hooded_Merganser 65 | 65 Red_breasted_Merganser 66 | 66 Nighthawk 67 | 67 Clark_Nutcracker 68 | 68 White_breasted_Nuthatch 69 | 69 Orchard_Oriole 70 | 70 Ovenbird 71 | 71 Brown_Pelican 72 | 72 American_Pipit 73 | 73 Whip_poor_Will 74 | 74 Horned_Puffin 75 | 75 Common_Raven 76 | 76 White_necked_Raven 77 | 77 American_Redstart 78 | 78 Geococcyx 79 | 79 Loggerhead_Shrike 80 | 80 Baird_Sparrow 81 | 81 Brewer_Sparrow 82 | 82 Chipping_Sparrow 83 | 83 Clay_colored_Sparrow 84 | 84 House_Sparrow 85 | 85 Fox_Sparrow 86 | 86 Harris_Sparrow 87 | 87 Henslow_Sparrow 88 | 88 Le_Conte_Sparrow 89 | 89 Lincoln_Sparrow 90 | 90 Nelson_Sharp_tailed_Sparrow 91 | 91 Savannah_Sparrow 92 | 92 Seaside_Sparrow 93 | 93 Song_Sparrow 94 | 94 Vesper_Sparrow 95 | 95 White_crowned_Sparrow 96 | 96 White_throated_Sparrow 97 | 97 Cape_Glossy_Starling 98 | 98 Barn_Swallow 99 | 99 Cliff_Swallow 100 | 100 Scarlet_Tanager 101 | 101 Summer_Tanager 102 | 102 Artic_Tern 103 | 103 Black_Tern 104 | 104 Caspian_Tern 105 | 105 Common_Tern 106 | 106 Elegant_Tern 107 | 107 Forsters_Tern 108 | 108 Green_tailed_Towhee 109 | 109 Brown_Thrasher 110 | 110 Sage_Thrasher 111 | 111 Black_capped_Vireo 112 | 112 Blue_headed_Vireo 113 | 113 Philadelphia_Vireo 114 | 114 Red_eyed_Vireo 115 | 115 Warbling_Vireo 116 | 116 Yellow_throated_Vireo 117 | 117 Bay_breasted_Warbler 118 | 118 Black_and_white_Warbler 119 | 119 Black_throated_Blue_Warbler 120 | 120 Blue_winged_Warbler 121 | 121 Canada_Warbler 122 | 122 Cerulean_Warbler 123 | 123 Hooded_Warbler 124 | 124 Kentucky_Warbler 125 | 125 Magnolia_Warbler 126 | 126 Mourning_Warbler 127 | 127 Myrtle_Warbler 128 | 128 Nashville_Warbler 129 | 129 Orange_crowned_Warbler 130 | 130 Palm_Warbler 131 | 131 Pine_Warbler 132 | 132 Prairie_Warbler 133 | 133 Prothonotary_Warbler 134 | 134 Swainson_Warbler 135 | 135 Tennessee_Warbler 136 | 136 Worm_eating_Warbler 137 | 137 Yellow_Warbler 138 | 138 Louisiana_Waterthrush 139 | 139 Pileated_Woodpecker 140 | 140 Red_bellied_Woodpecker 141 | 141 Red_cockaded_Woodpecker 142 | 142 Red_headed_Woodpecker 143 | 143 Downy_Woodpecker 144 | 144 Bewick_Wren 145 | 145 Cactus_Wren 146 | 146 Carolina_Wren 147 | 147 House_Wren 148 | 148 Rock_Wren 149 | 149 Winter_Wren 150 | 150 Common_Yellowthroat 151 | 151 Black_footed_Albatross 152 | 152 Groove_billed_Ani 153 | 153 Least_Auklet 154 | 154 Rhinoceros_Auklet 155 | 155 Brewer_Blackbird 156 | 156 Indigo_Bunting 157 | 157 Brandt_Cormorant 158 | 158 American_Crow 159 | 159 Black_billed_Cuckoo 160 | 160 Yellow_billed_Cuckoo 161 | 161 Gray_crowned_Rosy_Finch 162 | 162 Purple_Finch 163 | 163 Northern_Flicker 164 | 164 Acadian_Flycatcher 165 | 165 Great_Crested_Flycatcher 166 | 166 Yellow_bellied_Flycatcher 167 | 167 Boat_tailed_Grackle 168 | 168 Horned_Grebe 169 | 169 Western_Grebe 170 | 170 Western_Gull 171 | 171 Pomarine_Jaeger 172 | 172 Belted_Kingfisher 173 | 173 White_breasted_Kingfisher 174 | 174 Red_legged_Kittiwake 175 | 175 Pacific_Loon 176 | 176 Mockingbird 177 | 177 Baltimore_Oriole 178 | 178 Hooded_Oriole 179 | 179 Scott_Oriole 180 | 180 White_Pelican 181 | 181 Western_Wood_Pewee 182 | 182 Sayornis 183 | 183 Great_Grey_Shrike 184 | 184 Black_throated_Sparrow 185 | 185 Field_Sparrow 186 | 186 Grasshopper_Sparrow 187 | 187 Tree_Sparrow 188 | 188 Bank_Swallow 189 | 189 Tree_Swallow 190 | 190 Least_Tern 191 | 191 White_eyed_Vireo 192 | 192 Cape_May_Warbler 193 | 193 Chestnut_sided_Warbler 194 | 194 Golden_winged_Warbler 195 | 195 Wilson_Warbler 196 | 196 Northern_Waterthrush 197 | 197 Bohemian_Waxwing 198 | 198 Cedar_Waxwing 199 | 199 American_Three_toed_Woodpecker 200 | 200 Marsh_Wren 201 | -------------------------------------------------------------------------------- /datasets/cub/testclasses.txt: -------------------------------------------------------------------------------- 1 | Yellow_bellied_Flycatcher 2 | Loggerhead_Shrike 3 | Brandt_Cormorant 4 | Scott_Oriole 5 | Evening_Grosbeak 6 | Tree_Sparrow 7 | Scarlet_Tanager 8 | Henslow_Sparrow 9 | White_eyed_Vireo 10 | Le_Conte_Sparrow 11 | Common_Yellowthroat 12 | Pomarine_Jaeger 13 | Orange_crowned_Warbler 14 | Brown_Creeper 15 | Field_Sparrow 16 | Chestnut_sided_Warbler 17 | Sayornis 18 | Wilson_Warbler 19 | Tropical_Kingbird 20 | Yellow_headed_Blackbird 21 | Northern_Fulmar 22 | Red_cockaded_Woodpecker 23 | Red_headed_Woodpecker 24 | Tree_Swallow 25 | Yellow_throated_Vireo 26 | Pied_billed_Grebe 27 | Yellow_billed_Cuckoo 28 | Cerulean_Warbler 29 | Black_billed_Cuckoo 30 | Caspian_Tern 31 | White_breasted_Nuthatch 32 | Green_Violetear 33 | Orchard_Oriole 34 | Mockingbird 35 | American_Pipit 36 | Savannah_Sparrow 37 | Blue_winged_Warbler 38 | Boat_tailed_Grackle 39 | Magnolia_Warbler 40 | Green_tailed_Towhee 41 | Baird_Sparrow 42 | Mallard 43 | Cape_May_Warbler 44 | Barn_Swallow 45 | Pileated_Woodpecker 46 | Red_legged_Kittiwake 47 | Bronzed_Cowbird 48 | Groove_billed_Ani 49 | White_crowned_Sparrow 50 | Kentucky_Warbler 51 | -------------------------------------------------------------------------------- /datasets/sun/classes.txt: -------------------------------------------------------------------------------- 1 | 1 abbey 2 | 2 access_road 3 | 3 airfield 4 | 4 airlock 5 | 5 airplane_cabin 6 | 6 airport_airport 7 | 7 airport_entrance 8 | 8 airport_terminal 9 | 9 airport_ticket_counter 10 | 10 alcove 11 | 11 alley 12 | 12 amphitheater 13 | 13 amusement_arcade 14 | 14 amusement_park 15 | 15 anechoic_chamber 16 | 16 apartment_building_outdoor 17 | 17 apse_indoor 18 | 18 apse_outdoor 19 | 19 aquarium 20 | 20 aquatic_theater 21 | 21 aqueduct 22 | 22 arch 23 | 23 archaelogical_excavation 24 | 24 archive 25 | 25 arena_basketball 26 | 26 arena_hockey 27 | 27 arena_performance 28 | 28 armory 29 | 29 arrival_gate_outdoor 30 | 30 art_gallery 31 | 31 art_school 32 | 32 art_studio 33 | 33 artists_loft 34 | 34 assembly_line 35 | 35 athletic_field_outdoor 36 | 36 atrium_home 37 | 37 atrium_public 38 | 38 attic 39 | 39 auditorium 40 | 40 auto_factory 41 | 41 auto_mechanics_indoor 42 | 42 auto_racing_paddock 43 | 43 auto_showroom 44 | 44 backstage 45 | 45 badlands 46 | 46 badminton_court_indoor 47 | 47 badminton_court_outdoor 48 | 48 baggage_claim 49 | 49 bakery_kitchen 50 | 50 bakery_shop 51 | 51 balcony_exterior 52 | 52 balcony_interior 53 | 53 ball_pit 54 | 54 ballroom 55 | 55 bamboo_forest 56 | 56 bank_indoor 57 | 57 bank_outdoor 58 | 58 bank_vault 59 | 59 banquet_hall 60 | 60 baptistry_indoor 61 | 61 baptistry_outdoor 62 | 62 bar 63 | 63 barn 64 | 64 barndoor 65 | 65 baseball_field 66 | 66 basement 67 | 67 basilica 68 | 68 basketball_court_indoor 69 | 69 basketball_court_outdoor 70 | 70 bathroom 71 | 71 batters_box 72 | 72 batting_cage_indoor 73 | 73 batting_cage_outdoor 74 | 74 bayou 75 | 75 bazaar_indoor 76 | 76 bazaar_outdoor 77 | 77 beach 78 | 78 beach_house 79 | 79 beauty_salon 80 | 80 bedchamber 81 | 81 bedroom 82 | 82 beer_garden 83 | 83 beer_hall 84 | 84 bell_foundry 85 | 85 berth 86 | 86 betting_shop 87 | 87 bicycle_racks 88 | 88 bindery 89 | 89 biology_laboratory 90 | 90 bistro_indoor 91 | 91 bistro_outdoor 92 | 92 bleachers_outdoor 93 | 93 boardwalk 94 | 94 boat_deck 95 | 95 boathouse 96 | 96 bog 97 | 97 bookstore 98 | 98 booth_indoor 99 | 99 botanical_garden 100 | 100 bow_window_indoor 101 | 101 bow_window_outdoor 102 | 102 bowling_alley 103 | 103 boxing_ring 104 | 104 brewery_indoor 105 | 105 brewery_outdoor 106 | 106 brickyard_outdoor 107 | 107 bridge 108 | 108 building_complex 109 | 109 building_facade 110 | 110 bullpen 111 | 111 bullring 112 | 112 burial_chamber 113 | 113 bus_depot_outdoor 114 | 114 bus_interior 115 | 115 bus_shelter 116 | 116 bus_station_outdoor 117 | 117 butchers_shop 118 | 118 butte 119 | 119 cabana 120 | 120 cabin_outdoor 121 | 121 cafeteria 122 | 122 call_center 123 | 123 campsite 124 | 124 campus 125 | 125 canal_natural 126 | 126 canal_urban 127 | 127 candy_store 128 | 128 canteen 129 | 129 canyon 130 | 130 car_interior_backseat 131 | 131 car_interior_frontseat 132 | 132 caravansary 133 | 133 cardroom 134 | 134 cargo_deck_airplane 135 | 135 carport_freestanding 136 | 136 carport_outdoor 137 | 137 carrousel 138 | 138 casino_indoor 139 | 139 casino_outdoor 140 | 140 castle 141 | 141 catacomb 142 | 142 cathedral_indoor 143 | 143 cathedral_outdoor 144 | 144 catwalk 145 | 145 cavern_indoor 146 | 146 cemetery 147 | 147 chalet 148 | 148 chaparral 149 | 149 chapel 150 | 150 checkout_counter 151 | 151 cheese_factory 152 | 152 chemical_plant 153 | 153 chemistry_lab 154 | 154 chicken_coop_indoor 155 | 155 chicken_coop_outdoor 156 | 156 chicken_farm_indoor 157 | 157 chicken_farm_outdoor 158 | 158 childs_room 159 | 159 church_indoor 160 | 160 church_outdoor 161 | 161 circus_tent_indoor 162 | 162 circus_tent_outdoor 163 | 163 city 164 | 164 classroom 165 | 165 clean_room 166 | 166 cliff 167 | 167 cloister_indoor 168 | 168 cloister_outdoor 169 | 169 closet 170 | 170 clothing_store 171 | 171 coast 172 | 172 cockpit 173 | 173 coffee_shop 174 | 174 computer_room 175 | 175 conference_center 176 | 176 conference_hall 177 | 177 conference_room 178 | 178 confessional 179 | 179 construction_site 180 | 180 control_room 181 | 181 control_tower_indoor 182 | 182 control_tower_outdoor 183 | 183 convenience_store_indoor 184 | 184 corn_field 185 | 185 corral 186 | 186 corridor 187 | 187 cottage 188 | 188 cottage_garden 189 | 189 courthouse 190 | 190 courtroom 191 | 191 courtyard 192 | 192 covered_bridge_exterior 193 | 193 crawl_space 194 | 194 creek 195 | 195 crevasse 196 | 196 crosswalk 197 | 197 cubicle_office 198 | 198 cybercafe 199 | 199 dacha 200 | 200 dairy_indoor 201 | 201 dam 202 | 202 darkroom 203 | 203 day_care_center 204 | 204 delicatessen 205 | 205 dentists_office 206 | 206 departure_lounge 207 | 207 desert_road 208 | 208 desert_sand 209 | 209 desert_vegetation 210 | 210 diner_indoor 211 | 211 diner_outdoor 212 | 212 dinette_home 213 | 213 dinette_vehicle 214 | 214 dining_car 215 | 215 dining_hall 216 | 216 dining_room 217 | 217 dirt_track 218 | 218 discotheque 219 | 219 dock 220 | 220 dolmen 221 | 221 donjon 222 | 222 doorway_indoor 223 | 223 doorway_outdoor 224 | 224 dorm_room 225 | 225 downtown 226 | 226 drainage_ditch 227 | 227 drill_rig 228 | 228 driveway 229 | 229 driving_range_outdoor 230 | 230 drugstore 231 | 231 dry_dock 232 | 232 dugout 233 | 233 earth_fissure 234 | 234 editing_room 235 | 235 electrical_substation 236 | 236 elevator_door 237 | 237 elevator_freight_elevator 238 | 238 elevator_interior 239 | 239 elevator_lobby 240 | 240 elevator_shaft 241 | 241 embassy 242 | 242 engine_room 243 | 243 escalator_indoor 244 | 244 escalator_outdoor 245 | 245 estuary 246 | 246 excavation 247 | 247 exhibition_hall 248 | 248 factory_indoor 249 | 249 factory_outdoor 250 | 250 fairway 251 | 251 farm 252 | 252 fastfood_restaurant 253 | 253 fence 254 | 254 ferryboat_outdoor 255 | 255 field_cultivated 256 | 256 field_road 257 | 257 field_wild 258 | 258 fire_escape 259 | 259 fire_station 260 | 260 firing_range_indoor 261 | 261 firing_range_outdoor 262 | 262 fish_farm 263 | 263 fishpond 264 | 264 fjord 265 | 265 flea_market_indoor 266 | 266 flea_market_outdoor 267 | 267 flight_of_stairs_natural 268 | 268 flight_of_stairs_urban 269 | 269 flood 270 | 270 florist_shop_indoor 271 | 271 fly_bridge 272 | 272 food_court 273 | 273 football_field 274 | 274 forest_broadleaf 275 | 275 forest_needleleaf 276 | 276 forest_path 277 | 277 forest_road 278 | 278 formal_garden 279 | 279 fort 280 | 280 fortress 281 | 281 foundry_indoor 282 | 282 foundry_outdoor 283 | 283 fountain 284 | 284 freeway 285 | 285 funeral_chapel 286 | 286 furnace_room 287 | 287 galley 288 | 288 game_room 289 | 289 gangplank 290 | 290 garage_indoor 291 | 291 garage_outdoor 292 | 292 garbage_dump 293 | 293 gas_station 294 | 294 gasworks 295 | 295 gatehouse 296 | 296 gazebo_exterior 297 | 297 general_store_indoor 298 | 298 general_store_outdoor 299 | 299 geodesic_dome_indoor 300 | 300 geodesic_dome_outdoor 301 | 301 ghost_town 302 | 302 gift_shop 303 | 303 glacier 304 | 304 golf_course 305 | 305 gorge 306 | 306 great_hall 307 | 307 greenhouse_indoor 308 | 308 greenhouse_outdoor 309 | 309 grotto 310 | 310 guardhouse 311 | 311 gulch 312 | 312 gun_deck_indoor 313 | 313 gymnasium_indoor 314 | 314 hacienda 315 | 315 hallway 316 | 316 hangar_indoor 317 | 317 hangar_outdoor 318 | 318 harbor 319 | 319 hayfield 320 | 320 heath 321 | 321 hedge_maze 322 | 322 hedgerow 323 | 323 heliport 324 | 324 herb_garden 325 | 325 highway 326 | 326 hill 327 | 327 home_office 328 | 328 home_theater 329 | 329 hoodoo 330 | 330 hospital 331 | 331 hospital_room 332 | 332 hot_spring 333 | 333 hot_tub_indoor 334 | 334 hot_tub_outdoor 335 | 335 hotel_breakfast_area 336 | 336 hotel_outdoor 337 | 337 hotel_room 338 | 338 house 339 | 339 hunting_lodge_indoor 340 | 340 hunting_lodge_outdoor 341 | 341 hut 342 | 342 ice_floe 343 | 343 ice_shelf 344 | 344 ice_skating_rink_indoor 345 | 345 ice_skating_rink_outdoor 346 | 346 iceberg 347 | 347 igloo 348 | 348 industrial_area 349 | 349 industrial_park 350 | 350 inn_indoor 351 | 351 inn_outdoor 352 | 352 irrigation_ditch 353 | 353 islet 354 | 354 jacuzzi_indoor 355 | 355 jacuzzi_outdoor 356 | 356 jail_cell 357 | 357 jail_indoor 358 | 358 jail_outdoor 359 | 359 japanese_garden 360 | 360 jewelry_shop 361 | 361 joss_house 362 | 362 junk_pile 363 | 363 junkyard 364 | 364 jury_box 365 | 365 kasbah 366 | 366 kennel_indoor 367 | 367 kennel_outdoor 368 | 368 kindergarden_classroom 369 | 369 kiosk_indoor 370 | 370 kiosk_outdoor 371 | 371 kitchen 372 | 372 kitchenette 373 | 373 lab_classroom 374 | 374 labyrinth_indoor 375 | 375 labyrinth_outdoor 376 | 376 lagoon 377 | 377 lake_artificial 378 | 378 lake_natural 379 | 379 landfill 380 | 380 landing_deck 381 | 381 laundromat 382 | 382 lawn 383 | 383 lean-to 384 | 384 lecture_room 385 | 385 levee 386 | 386 library_indoor 387 | 387 library_outdoor 388 | 388 lido_deck_outdoor 389 | 389 lift_bridge 390 | 390 lighthouse 391 | 391 limousine_interior 392 | 392 liquor_store_indoor 393 | 393 liquor_store_outdoor 394 | 394 living_room 395 | 395 loading_dock 396 | 396 lobby 397 | 397 lock_chamber 398 | 398 locker_room 399 | 399 lookout_station_outdoor 400 | 400 machine_shop 401 | 401 manhole 402 | 402 mansion 403 | 403 manufactured_home 404 | 404 market_indoor 405 | 405 market_outdoor 406 | 406 marsh 407 | 407 martial_arts_gym 408 | 408 mastaba 409 | 409 mausoleum 410 | 410 medina 411 | 411 mesa 412 | 412 military_hospital 413 | 413 military_hut 414 | 414 mine 415 | 415 mineshaft 416 | 416 mini_golf_course_outdoor 417 | 417 mission 418 | 418 moat_dry 419 | 419 moat_water 420 | 420 mobile_home 421 | 421 monastery_outdoor 422 | 422 moor 423 | 423 morgue 424 | 424 mosque_indoor 425 | 425 mosque_outdoor 426 | 426 motel 427 | 427 mountain 428 | 428 mountain_path 429 | 429 mountain_road 430 | 430 mountain_snowy 431 | 431 movie_theater_indoor 432 | 432 movie_theater_outdoor 433 | 433 museum_indoor 434 | 434 museum_outdoor 435 | 435 music_store 436 | 436 music_studio 437 | 437 natural_history_museum 438 | 438 naval_base 439 | 439 newsroom 440 | 440 newsstand_outdoor 441 | 441 nightclub 442 | 442 nuclear_power_plant_indoor 443 | 443 nuclear_power_plant_outdoor 444 | 444 nursery 445 | 445 nursing_home 446 | 446 oasis 447 | 447 oast_house 448 | 448 observatory_indoor 449 | 449 observatory_outdoor 450 | 450 ocean 451 | 451 office 452 | 452 office_building 453 | 453 office_cubicles 454 | 454 oil_refinery_outdoor 455 | 455 oilrig 456 | 456 operating_room 457 | 457 optician 458 | 458 orchard 459 | 459 organ_loft_exterior 460 | 460 ossuary 461 | 461 outcropping 462 | 462 outhouse_outdoor 463 | 463 overpass 464 | 464 packaging_plant 465 | 465 pagoda 466 | 466 palace 467 | 467 pantry 468 | 468 parade_ground 469 | 469 park 470 | 470 parking_garage_indoor 471 | 471 parking_garage_outdoor 472 | 472 parking_lot 473 | 473 parlor 474 | 474 particle_accelerator 475 | 475 pasture 476 | 476 patio 477 | 477 pavilion 478 | 478 pedestrian_overpass_outdoor 479 | 479 pet_shop 480 | 480 pharmacy 481 | 481 phone_booth 482 | 482 physics_laboratory 483 | 483 piano_store 484 | 484 picnic_area 485 | 485 pier 486 | 486 pig_farm 487 | 487 pilothouse_indoor 488 | 488 pilothouse_outdoor 489 | 489 pitchers_mound 490 | 490 pizzeria 491 | 491 planetarium_indoor 492 | 492 planetarium_outdoor 493 | 493 plantation_house 494 | 494 playground 495 | 495 playroom 496 | 496 plaza 497 | 497 podium_indoor 498 | 498 podium_outdoor 499 | 499 pond 500 | 500 poolroom_establishment 501 | 501 poolroom_home 502 | 502 porch 503 | 503 portico 504 | 504 power_plant_indoor 505 | 505 power_plant_outdoor 506 | 506 print_shop 507 | 507 priory 508 | 508 promenade 509 | 509 promenade_deck 510 | 510 pub_indoor 511 | 511 pub_outdoor 512 | 512 pulpit 513 | 513 pump_room 514 | 514 putting_green 515 | 515 quadrangle 516 | 516 quay 517 | 517 quonset_hut_outdoor 518 | 518 racecourse 519 | 519 raceway 520 | 520 raft 521 | 521 railroad_track 522 | 522 railway_yard 523 | 523 rainforest 524 | 524 ramp 525 | 525 ranch 526 | 526 ranch_house 527 | 527 reading_room 528 | 528 reception 529 | 529 recreation_room 530 | 530 rectory 531 | 531 recycling_plant_indoor 532 | 532 recycling_plant_outdoor 533 | 533 repair_shop 534 | 534 residential_neighborhood 535 | 535 resort 536 | 536 restaurant 537 | 537 restaurant_kitchen 538 | 538 restaurant_patio 539 | 539 restroom_indoor 540 | 540 restroom_outdoor 541 | 541 revolving_door 542 | 542 rice_paddy 543 | 543 riding_arena 544 | 544 river 545 | 545 road_cut 546 | 546 rock_arch 547 | 547 rolling_mill 548 | 548 roof 549 | 549 root_cellar 550 | 550 rope_bridge 551 | 551 roundabout 552 | 552 rubble 553 | 553 ruin 554 | 554 runway 555 | 555 sacristy 556 | 556 salt_plain 557 | 557 sand_trap 558 | 558 sandbar 559 | 559 sandbox 560 | 560 sauna 561 | 561 savanna 562 | 562 sawmill 563 | 563 schoolhouse 564 | 564 schoolyard 565 | 565 science_museum 566 | 566 sea_cliff 567 | 567 seawall 568 | 568 server_room 569 | 569 sewing_room 570 | 570 shed 571 | 571 shipping_room 572 | 572 shipyard_outdoor 573 | 573 shoe_shop 574 | 574 shopfront 575 | 575 shopping_mall_indoor 576 | 576 shower 577 | 577 signal_box 578 | 578 skatepark 579 | 579 ski_jump 580 | 580 ski_lodge 581 | 581 ski_resort 582 | 582 ski_slope 583 | 583 sky 584 | 584 skyscraper 585 | 585 slum 586 | 586 snowfield 587 | 587 soccer_field 588 | 588 spillway 589 | 589 squash_court 590 | 590 stable 591 | 591 stadium_baseball 592 | 592 stadium_football 593 | 593 stadium_outdoor 594 | 594 stadium_soccer 595 | 595 stage_indoor 596 | 596 stage_outdoor 597 | 597 staircase 598 | 598 steel_mill_indoor 599 | 599 steel_mill_outdoor 600 | 600 stilt_house_water 601 | 601 stone_circle 602 | 602 street 603 | 603 strip_mall 604 | 604 strip_mine 605 | 605 submarine_interior 606 | 606 subway_interior 607 | 607 subway_station_corridor 608 | 608 subway_station_platform 609 | 609 sun_deck 610 | 610 supermarket 611 | 611 sushi_bar 612 | 612 swamp 613 | 613 swimming_hole 614 | 614 swimming_pool_indoor 615 | 615 swimming_pool_outdoor 616 | 616 synagogue_indoor 617 | 617 synagogue_outdoor 618 | 618 tea_garden 619 | 619 tearoom 620 | 620 teashop 621 | 621 television_studio 622 | 622 temple_east_asia 623 | 623 temple_south_asia 624 | 624 temple_western 625 | 625 tennis_court_indoor 626 | 626 tennis_court_outdoor 627 | 627 tent_indoor 628 | 628 tent_outdoor 629 | 629 terrace_farm 630 | 630 theater_indoor_procenium 631 | 631 theater_indoor_round 632 | 632 theater_indoor_seats 633 | 633 theater_outdoor 634 | 634 thriftshop 635 | 635 throne_room 636 | 636 ticket_booth 637 | 637 ticket_window_outdoor 638 | 638 toll_plaza 639 | 639 tollbooth 640 | 640 topiary_garden 641 | 641 tower 642 | 642 town_house 643 | 643 toyshop 644 | 644 track_indoor 645 | 645 track_outdoor 646 | 646 trading_floor 647 | 647 trailer_park 648 | 648 train_depot 649 | 649 train_railway 650 | 650 train_station_outdoor 651 | 651 train_station_platform 652 | 652 train_station_station 653 | 653 tree_farm 654 | 654 tree_house 655 | 655 trench 656 | 656 trestle_bridge 657 | 657 tundra 658 | 658 tunnel_rail_outdoor 659 | 659 tunnel_road_outdoor 660 | 660 underwater_coral_reef 661 | 661 underwater_ice 662 | 662 underwater_kelp_forest 663 | 663 underwater_ocean_deep 664 | 664 underwater_ocean_shallow 665 | 665 underwater_pool 666 | 666 underwater_wreck 667 | 667 utility_room 668 | 668 valley 669 | 669 van_interior 670 | 670 vegetable_garden 671 | 671 velodrome_indoor 672 | 672 velodrome_outdoor 673 | 673 ventilation_shaft 674 | 674 veranda 675 | 675 vestry 676 | 676 veterinarians_office 677 | 677 viaduct 678 | 678 videostore 679 | 679 village 680 | 680 vineyard 681 | 681 volcano 682 | 682 volleyball_court_outdoor 683 | 683 voting_booth 684 | 684 waiting_room 685 | 685 warehouse_indoor 686 | 686 watchtower 687 | 687 water_mill 688 | 688 water_tower 689 | 689 water_treatment_plant_indoor 690 | 690 water_treatment_plant_outdoor 691 | 691 waterfall_block 692 | 692 waterfall_cascade 693 | 693 waterfall_cataract 694 | 694 waterfall_fan 695 | 695 waterfall_plunge 696 | 696 watering_hole 697 | 697 wave 698 | 698 weighbridge 699 | 699 wet_bar 700 | 700 wharf 701 | 701 wheat_field 702 | 702 wind_farm 703 | 703 windmill 704 | 704 window_seat 705 | 705 wine_cellar_barrel_storage 706 | 706 wine_cellar_bottle_storage 707 | 707 winery 708 | 708 witness_stand 709 | 709 woodland 710 | 710 workroom 711 | 711 workshop 712 | 712 wrestling_ring_indoor 713 | 713 yard 714 | 714 youth_hostel 715 | 715 zen_garden 716 | 716 ziggurat 717 | 717 zoo 718 | -------------------------------------------------------------------------------- /datasets/sun/testclasses.txt: -------------------------------------------------------------------------------- 1 | airlock 2 | temple_south_asia 3 | monastery_outdoor 4 | piano_store 5 | watering_hole 6 | geodesic_dome_indoor 7 | chemistry_lab 8 | vineyard 9 | brewery_indoor 10 | pub_indoor 11 | vestry 12 | bog 13 | wrestling_ring_indoor 14 | trading_floor 15 | volleyball_court_outdoor 16 | racecourse 17 | cemetery 18 | betting_shop 19 | casino_outdoor 20 | theater_indoor_seats 21 | dirt_track 22 | motel 23 | ballroom 24 | sandbox 25 | mosque_indoor 26 | batting_cage_outdoor 27 | bow_window_indoor 28 | car_interior_frontseat 29 | alley 30 | nightclub 31 | arena_basketball 32 | hotel_room 33 | jacuzzi_indoor 34 | tunnel_road_outdoor 35 | bank_vault 36 | excavation 37 | lawn 38 | galley 39 | bazaar_outdoor 40 | canal_natural 41 | train_station_platform 42 | cubicle_office 43 | fishpond 44 | ice_shelf 45 | bazaar_indoor 46 | field_cultivated 47 | exhibition_hall 48 | observatory_outdoor 49 | rectory 50 | tundra 51 | firing_range_indoor 52 | doorway_indoor 53 | ticket_booth 54 | bus_depot_outdoor 55 | auditorium 56 | ski_resort 57 | ziggurat 58 | elevator_interior 59 | artists_loft 60 | archive 61 | playground 62 | hoodoo 63 | parking_lot 64 | yard 65 | corral 66 | japanese_garden 67 | promenade_deck 68 | landing_deck 69 | workshop 70 | hangar_indoor 71 | church_indoor 72 | savanna 73 | -------------------------------------------------------------------------------- /datautils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils.data.dataset import Dataset 3 | from torch.autograd import Variable 4 | 5 | import numpy as np 6 | import random 7 | import scipy.io as scio 8 | from sklearn.preprocessing import MinMaxScaler 9 | 10 | class ZSLDataset(Dataset): 11 | def __init__(self, dset, n_train, n_test, gzsl=False, train=True, synthetic=False, syn_dataset=None): 12 | ''' 13 | Base class for all datasets 14 | Args: 15 | dset : Name of dataset - 1 among [sun, cub, awa1, awa2] 16 | n_train : Number of train classes 17 | n_test : Number of test classes 18 | gzsl : Boolean for Generalized ZSL 19 | train : Boolean indicating whether train/test 20 | synthetic : Boolean indicating whether dataset is for synthetic examples 21 | syn_dataset : A list consisting of 3-tuple (z, _, y) used for sampling 22 | only when synthetic flag is True 23 | ''' 24 | super(ZSLDataset, self).__init__() 25 | self.dset = dset 26 | self.n_train = n_train 27 | self.n_test = n_test 28 | self.train = train 29 | self.gzsl = gzsl 30 | self.synthetic = synthetic 31 | 32 | res101_data = scio.loadmat('./datasets/%s/res101.mat' % dset) 33 | self.features = self.normalize(res101_data['features'].T) 34 | self.labels = res101_data['labels'].reshape(-1) 35 | 36 | self.attribute_dict = scio.loadmat('./datasets/%s/att_splits.mat' % dset) 37 | self.attributes = self.attribute_dict['att'].T 38 | 39 | # file with all class names for deciding train/test split 40 | self.class_names_file = './datasets/%s/classes.txt' % dset 41 | 42 | # test class names 43 | with open('./datasets/%s/testclasses.txt' % dset) as fp: 44 | self.test_class_names = [i.strip() for i in fp.readlines() if i != ''] 45 | 46 | assert len(self.test_class_names) == self.n_test 47 | 48 | if self.synthetic: 49 | assert syn_dataset is not None 50 | self.syn_dataset = syn_dataset 51 | else: 52 | self.dataset = self.create_orig_dataset() 53 | if self.train: 54 | self.gzsl_dataset = self.create_gzsl_dataset() 55 | 56 | def normalize(self, matrix): 57 | scaler = MinMaxScaler() 58 | return scaler.fit_transform(matrix) 59 | 60 | def get_classmap(self): 61 | ''' 62 | Creates a mapping between serial number of a class 63 | in provided dataset and the indices used for classification. 64 | Returns: 65 | 2 dicts, 1 each for train and test classes 66 | ''' 67 | with open(self.class_names_file) as fp: 68 | all_classes = fp.readlines() 69 | 70 | test_count = 0 71 | train_count = 0 72 | 73 | train_classmap = dict() 74 | test_classmap = dict() 75 | for line in all_classes: 76 | idx, name = [i.strip() for i in line.split(' ')] 77 | if name in self.test_class_names: 78 | if self.gzsl: 79 | # train classes are also included in test time 80 | test_classmap[int(idx)] = self.n_train + test_count 81 | else: 82 | test_classmap[int(idx)] = test_count 83 | test_count += 1 84 | else: 85 | train_classmap[int(idx)] = train_count 86 | train_count += 1 87 | return train_classmap, test_classmap 88 | 89 | def create_gzsl_dataset(self, n_samples=200): 90 | ''' 91 | Create an auxillary dataset to be used during training final 92 | classifier on seen classes 93 | ''' 94 | dataset = [] 95 | for key in self.gzsl_map.keys(): 96 | features = self.gzsl_map[key]['feat'] 97 | if len(features) < n_samples: 98 | aug_features = [random.choice(features) for _ in range(n_samples)] 99 | else: 100 | aug_features = random.sample(features, n_samples) 101 | label = self.gzsl_map[key]['label'] 102 | dataset.extend([(torch.FloatTensor(f), label, key) for f in aug_features]) 103 | return dataset 104 | 105 | def create_orig_dataset(self): 106 | ''' 107 | Returns list of 3-tuple: (feature, label_in_dataset, label_for_classification) 108 | ''' 109 | self.train_classmap, self.test_classmap = self.get_classmap() 110 | 111 | if self.train: 112 | labels = self.attribute_dict['trainval_loc'].reshape(-1) 113 | classmap = self.train_classmap 114 | self.gzsl_map = dict() 115 | else: 116 | labels = self.attribute_dict['test_unseen_loc'].reshape(-1) 117 | if self.gzsl: 118 | labels = np.concatenate((labels, self.attribute_dict['test_seen_loc'].reshape(-1))) 119 | classmap = {**self.train_classmap, **self.test_classmap} 120 | else: 121 | classmap = self.test_classmap 122 | 123 | dataset = [] 124 | for l in labels: 125 | idx = self.labels[l - 1] 126 | dataset.append((self.features[l - 1], idx, classmap[idx])) 127 | if self.train: 128 | # create a map bw class label and features 129 | if self.gzsl_map.get(classmap[idx], None): 130 | try: 131 | self.gzsl_map[classmap[idx]]['feat'].append(self.features[l - 1]) 132 | except Exception: 133 | self.gzsl_map[classmap[idx]]['feat'] = [self.features[l - 1]] 134 | else: 135 | self.gzsl_map[classmap[idx]] = {} 136 | 137 | # Add the label to map 138 | self.gzsl_map[classmap[idx]]['label'] = idx 139 | return dataset 140 | 141 | def __getitem__(self, index): 142 | if self.synthetic: 143 | # choose an example from synthetic dataset 144 | img_feature, orig_label, label_idx = self.syn_dataset[index] 145 | else: 146 | # choose an example from original dataset 147 | img_feature, orig_label, label_idx = self.dataset[index] 148 | 149 | label_attr = self.attributes[orig_label - 1] 150 | return img_feature, label_attr, label_idx 151 | 152 | def __len__(self): 153 | if self.synthetic: 154 | return len(self.syn_dataset) 155 | else: 156 | return len(self.dataset) -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils.data import DataLoader 3 | import argparse 4 | 5 | from datautils import ZSLDataset 6 | from trainer import Trainer 7 | 8 | parser = argparse.ArgumentParser() 9 | 10 | parser.add_argument('--dataset', type=str, default='awa2') 11 | parser.add_argument('--gzsl', action='store_true', default=False) 12 | parser.add_argument('--latent_dim', type=int, default=128) 13 | parser.add_argument('--n_critic', type=int, default=5) 14 | parser.add_argument('--lmbda', type=float, default=10.0) 15 | parser.add_argument('--beta', type=float, default=0.01) 16 | parser.add_argument('--batch_size', type=int, default=128) 17 | parser.add_argument('--n_epochs', type=int, default=10) 18 | parser.add_argument('--use_cls_loss', action='store_true', default=False) 19 | parser.add_argument('--visualize', action='store_true', default=False) 20 | 21 | args = parser.parse_args() 22 | 23 | if args.dataset == 'awa2' or args.dataset == 'awa1': 24 | x_dim = 2048 25 | attr_dim = 85 26 | n_train = 40 27 | n_test = 10 28 | elif args.dataset == 'cub': 29 | x_dim = 2048 30 | attr_dim = 312 31 | n_train = 150 32 | n_test = 50 33 | elif args.dataset == 'sun': 34 | x_dim = 2048 35 | attr_dim = 102 36 | n_train = 645 37 | n_test = 72 38 | else: 39 | raise NotImplementedError 40 | 41 | n_epochs = args.n_epochs 42 | 43 | # trainer object for mini batch training 44 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 45 | train_agent = Trainer( 46 | device, x_dim, args.latent_dim, attr_dim, 47 | n_train=n_train, n_test=n_test, gzsl=args.gzsl, 48 | n_critic=args.n_critic, lmbda=args.lmbda, beta=args.beta, 49 | batch_size=args.batch_size 50 | ) 51 | 52 | params = { 53 | 'batch_size': args.batch_size, 54 | 'shuffle': True, 55 | 'num_workers': 0, 56 | 'drop_last': True 57 | } 58 | 59 | train_dataset = ZSLDataset(args.dataset, n_train, n_test, args.gzsl) 60 | train_generator = DataLoader(train_dataset, **params) 61 | 62 | # ============================================================= 63 | # PRETRAIN THE SOFTMAX CLASSIFIER 64 | # ============================================================= 65 | model_name = "%s_disc_classifier" % args.dataset 66 | success = train_agent.load_model(model=model_name) 67 | if success: 68 | print("Discriminative classifier parameters loaded...") 69 | else: 70 | print("Training the discriminative classifier...") 71 | for ep in range(1, n_epochs + 1): 72 | loss = 0 73 | for idx, (img_features, label_attr, label_idx) in enumerate(train_generator): 74 | l = train_agent.fit_classifier(img_features, label_attr, label_idx) 75 | loss += l 76 | 77 | print("Loss for epoch: %3d - %.4f" %(ep, loss)) 78 | 79 | train_agent.save_model(model=model_name) 80 | 81 | # ============================================================= 82 | # TRAIN THE GANs 83 | # ============================================================= 84 | model_name = "%s_generator" % args.dataset 85 | success = train_agent.load_model(model=model_name) 86 | if success: 87 | print("\nGAN parameters loaded....") 88 | else: 89 | print("\nTraining the GANS...") 90 | for ep in range(1, n_epochs + 1): 91 | loss_dis = 0 92 | loss_gan = 0 93 | for idx, (img_features, label_attr, label_idx) in enumerate(train_generator): 94 | l_d, l_g = train_agent.fit_GAN(img_features, label_attr, label_idx, args.use_cls_loss) 95 | loss_dis += l_d 96 | loss_gan += l_g 97 | 98 | print("Loss for epoch: %3d - D: %.4f | G: %.4f"\ 99 | %(ep, loss_dis, loss_gan)) 100 | 101 | train_agent.save_model(model=model_name) 102 | 103 | # ============================================================= 104 | # TRAIN FINAL CLASSIFIER ON SYNTHETIC DATASET 105 | # ============================================================= 106 | 107 | # create new synthetic dataset using trained Generator 108 | seen_dataset = None 109 | if args.gzsl: 110 | seen_dataset = train_dataset.gzsl_dataset 111 | 112 | syn_dataset = train_agent.create_syn_dataset( 113 | train_dataset.test_classmap, train_dataset.attributes, seen_dataset) 114 | final_dataset = ZSLDataset(args.dataset, n_train, n_test, 115 | gzsl=args.gzsl, train=True, synthetic=True, syn_dataset=syn_dataset) 116 | final_train_generator = DataLoader(final_dataset, **params) 117 | 118 | model_name = "%s_final_classifier" % args.dataset 119 | success = train_agent.load_model(model=model_name) 120 | if success: 121 | print("\nFinal classifier parameters loaded....") 122 | else: 123 | print("\nTraining the final classifier on the synthetic dataset...") 124 | for ep in range(1, n_epochs + 1): 125 | syn_loss = 0 126 | for idx, (img, label_attr, label_idx) in enumerate(final_train_generator): 127 | l = train_agent.fit_final_classifier(img, label_attr, label_idx) 128 | syn_loss += l 129 | 130 | # print losses on real and synthetic datasets 131 | print("Loss for epoch: %3d - %.4f" %(ep, syn_loss)) 132 | 133 | train_agent.save_model(model=model_name) 134 | 135 | # ============================================================= 136 | # TESTING PHASE 137 | # ============================================================= 138 | test_dataset = ZSLDataset(args.dataset, n_train, n_test, gzsl=args.gzsl, train=False) 139 | test_generator = DataLoader(test_dataset, **params) 140 | 141 | print("\nFinal Accuracy on ZSL Task: %.3f" % train_agent.test(test_generator)) -------------------------------------------------------------------------------- /models.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torchvision.models as models 4 | 5 | class Generator(nn.Module): 6 | def __init__(self, z_dim, attr_dim): 7 | super(Generator, self).__init__() 8 | self.model = nn.Sequential( 9 | nn.Linear(z_dim + attr_dim, 4096), 10 | nn.LeakyReLU(), 11 | nn.Linear(4096, 2048), 12 | nn.ReLU(), 13 | ) 14 | 15 | def forward(self, z): 16 | return self.model(z) 17 | 18 | class Discriminator(nn.Module): 19 | def __init__(self, x_dim, attr_dim): 20 | super(Discriminator, self).__init__() 21 | self.model = nn.Sequential( 22 | nn.Linear(x_dim + attr_dim, 4096), 23 | nn.LeakyReLU(), 24 | nn.Linear(4096, 1), 25 | nn.Sigmoid(), 26 | ) 27 | 28 | def forward(self, x): 29 | return self.model(x) 30 | 31 | class MLPClassifier(nn.Module): 32 | def __init__(self, x_dim, attr_dim, out_dim): 33 | super(MLPClassifier, self).__init__() 34 | self.model = nn.Sequential( 35 | nn.Linear(x_dim + attr_dim, 2000), 36 | nn.LeakyReLU(), 37 | nn.Dropout(0.2), 38 | nn.Linear(2000, 1200), 39 | nn.LeakyReLU(), 40 | nn.Dropout(0.2), 41 | nn.Linear(1200, 1200), 42 | nn.LeakyReLU(), 43 | nn.Dropout(0.2), 44 | nn.Linear(1200, out_dim), 45 | ) 46 | 47 | def forward(self, x): 48 | return self.model(x) 49 | 50 | class Resnet101(nn.Module): 51 | def __init__(self, finetune=False): 52 | super(Resnet101, self).__init__() 53 | resnet101 = models.resnet101(pretrained=True) 54 | modules = list(resnet101.children())[:-1] 55 | 56 | self.model = nn.Sequential(*modules) 57 | if not finetune: 58 | for p in self.model.parameters(): 59 | p.requires_grad = False 60 | 61 | def forward(self, x): 62 | return self.model(x).squeeze() 63 | -------------------------------------------------------------------------------- /trainer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.autograd as autograd 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | import torch.optim as optim 6 | from torch.distributions import uniform, normal 7 | 8 | import os 9 | import numpy as np 10 | from sklearn.metrics import accuracy_score 11 | 12 | from models import Generator, Discriminator, MLPClassifier, Resnet101 13 | 14 | class Trainer: 15 | def __init__( 16 | self, device, x_dim, z_dim, attr_dim, **kwargs): 17 | ''' 18 | Trainer class. 19 | Args: 20 | device: CPU/GPU 21 | x_dim: Dimension of image feature vector 22 | z_dim: Dimension of noise vector 23 | attr_dim: Dimension of attribute vector 24 | kwargs 25 | ''' 26 | self.device = device 27 | 28 | self.x_dim = x_dim 29 | self.z_dim = z_dim 30 | self.attr_dim = attr_dim 31 | 32 | self.n_critic = kwargs.get('n_critic', 5) 33 | self.lmbda = kwargs.get('lmbda', 10.0) 34 | self.beta = kwargs.get('beta', 0.01) 35 | self.bs = kwargs.get('batch_size', 32) 36 | 37 | self.gzsl = kwargs.get('gzsl', False) 38 | self.n_train = kwargs.get('n_train') 39 | self.n_test = kwargs.get('n_test') 40 | if self.gzsl: 41 | self.n_test = self.n_train + self.n_test 42 | 43 | self.eps_dist = uniform.Uniform(0, 1) 44 | self.Z_dist = normal.Normal(0, 1) 45 | 46 | self.eps_shape = torch.Size([self.bs, 1]) 47 | self.z_shape = torch.Size([self.bs, self.z_dim]) 48 | 49 | self.net_G = Generator(self.z_dim, self.attr_dim).to(self.device) 50 | self.optim_G = optim.Adam(self.net_G.parameters(), lr=1e-4) 51 | 52 | self.net_D = Discriminator(self.x_dim, self.attr_dim).to(self.device) 53 | self.optim_D = optim.Adam(self.net_D.parameters(), lr=1e-4) 54 | 55 | # classifier for judging the output of generator 56 | self.classifier = MLPClassifier( 57 | self.x_dim, self.attr_dim, self.n_train 58 | ).to(self.device) 59 | self.optim_cls = optim.Adam(self.classifier.parameters(), lr=1e-4) 60 | 61 | # Final classifier trained on augmented data for GZSL 62 | self.final_classifier = MLPClassifier( 63 | self.x_dim, self.attr_dim, self.n_test 64 | ).to(self.device) 65 | self.optim_final_cls = optim.Adam(self.final_classifier.parameters(), lr=1e-4) 66 | 67 | self.criterion_cls = nn.CrossEntropyLoss() 68 | 69 | self.model_save_dir = "saved_models" 70 | if not os.path.exists(self.model_save_dir): 71 | os.mkdir(self.model_save_dir) 72 | 73 | def get_conditional_input(self, X, C_Y): 74 | new_X = torch.cat([X, C_Y], dim=1).float() 75 | return autograd.Variable(new_X).to(self.device) 76 | 77 | def fit_classifier(self, img_features, label_attr, label_idx): 78 | ''' 79 | Train the classifier in supervised manner on a single 80 | minibatch of available data 81 | Args: 82 | img : bs X 2048 83 | label_attr : bs X 85 84 | label_idx : bs 85 | Returns: 86 | loss for the minibatch 87 | ''' 88 | img_features = autograd.Variable(img_features).to(self.device) 89 | label_attr = autograd.Variable(label_attr).to(self.device) 90 | label_idx = autograd.Variable(label_idx).to(self.device) 91 | 92 | X_inp = self.get_conditional_input(img_features, label_attr) 93 | Y_pred = self.classifier(X_inp) 94 | 95 | self.optim_cls.zero_grad() 96 | loss = self.criterion_cls(Y_pred, label_idx) 97 | loss.backward() 98 | self.optim_cls.step() 99 | 100 | return loss.item() 101 | 102 | def get_gradient_penalty(self, X_real, X_gen): 103 | eps = self.eps_dist.sample(self.eps_shape).to(self.device) 104 | X_penalty = eps * X_real + (1 - eps) * X_gen 105 | 106 | X_penalty = autograd.Variable(X_penalty, requires_grad=True).to(self.device) 107 | critic_pred = self.net_D(X_penalty) 108 | grad_outputs = torch.ones(critic_pred.size()).to(self.device) 109 | gradients = autograd.grad( 110 | outputs=critic_pred, inputs=X_penalty, 111 | grad_outputs=grad_outputs, 112 | create_graph=True, retain_graph=True, only_inputs=True 113 | )[0] 114 | 115 | grad_penalty = ((gradients.norm(2, dim=1) - 1) ** 2).mean() 116 | return grad_penalty 117 | 118 | def fit_GAN(self, img_features, label_attr, label_idx, use_cls_loss=True): 119 | L_gen = 0 120 | L_disc = 0 121 | total_L_disc = 0 122 | 123 | img_features = autograd.Variable(img_features.float()).to(self.device) 124 | label_attr = autograd.Variable(label_attr.float()).to(self.device) 125 | label_idx = label_idx.to(self.device) 126 | 127 | # ============================================================= 128 | # optimize discriminator 129 | # ============================================================= 130 | X_real = self.get_conditional_input(img_features, label_attr) 131 | for _ in range(self.n_critic): 132 | Z = self.Z_dist.sample(self.z_shape).to(self.device) 133 | Z = self.get_conditional_input(Z, label_attr) 134 | 135 | X_gen = self.net_G(Z) 136 | X_gen = self.get_conditional_input(X_gen, label_attr) 137 | 138 | # calculate normal GAN loss 139 | L_disc = (self.net_D(X_gen) - self.net_D(X_real)).mean() 140 | 141 | # calculate gradient penalty 142 | grad_penalty = self.get_gradient_penalty(X_real, X_gen) 143 | L_disc += self.lmbda * grad_penalty 144 | 145 | # update critic params 146 | self.optim_D.zero_grad() 147 | L_disc.backward() 148 | self.optim_D.step() 149 | 150 | total_L_disc += L_disc.item() 151 | 152 | # ============================================================= 153 | # optimize generator 154 | # ============================================================= 155 | Z = self.Z_dist.sample(self.z_shape).to(self.device) 156 | Z = self.get_conditional_input(Z, label_attr) 157 | 158 | X_gen = self.net_G(Z) 159 | X = torch.cat([X_gen, label_attr], dim=1).float() 160 | L_gen = -1 * torch.mean(self.net_D(X)) 161 | 162 | if use_cls_loss: 163 | self.classifier.eval() 164 | Y_pred = F.softmax(self.classifier(X), dim=0) 165 | log_prob = torch.log(torch.gather(Y_pred, 1, label_idx.unsqueeze(1))) 166 | L_cls = -1 * torch.mean(log_prob) 167 | L_gen += self.beta * L_cls 168 | 169 | self.optim_G.zero_grad() 170 | L_gen.backward() 171 | self.optim_G.step() 172 | 173 | return total_L_disc, L_gen.item() 174 | 175 | def fit_final_classifier(self, img_features, label_attr, label_idx): 176 | img_features = autograd.Variable(img_features.float()).to(self.device) 177 | label_attr = autograd.Variable(label_attr.float()).to(self.device) 178 | label_idx = label_idx.to(self.device) 179 | 180 | X_inp = self.get_conditional_input(img_features, label_attr) 181 | Y_pred = self.final_classifier(X_inp) 182 | 183 | self.optim_final_cls.zero_grad() 184 | loss = self.criterion_cls(Y_pred, label_idx) 185 | loss.backward() 186 | self.optim_final_cls.step() 187 | 188 | return loss.item() 189 | 190 | def create_syn_dataset(self, test_labels, attributes, seen_dataset, n_examples=400): 191 | ''' 192 | Creates a synthetic dataset based on attribute vectors of unseen class 193 | Args: 194 | test_labels: A dict with key as original serial number in provided 195 | dataset and value as the index which is predicted during 196 | classification by network 197 | attributes: A np array containing class attributes for each class 198 | of dataset 199 | seen_dataset: A list of 3-tuple (x, orig_label, y) where x belongs to one of the 200 | seen classes and y is classification label. Used for generating 201 | latent representations of seen classes in GZSL 202 | n_samples: Number of samples of each unseen class to be generated(Default: 400) 203 | Returns: 204 | A list of 3-tuple (z, _, y) where z is latent representations and y is 205 | ''' 206 | syn_dataset = [] 207 | for test_cls, idx in test_labels.items(): 208 | attr = attributes[test_cls - 1] 209 | z = self.Z_dist.sample(torch.Size([n_examples, self.z_dim])) 210 | c_y = torch.stack([torch.FloatTensor(attr) for _ in range(n_examples)]) 211 | 212 | z_inp = self.get_conditional_input(z, c_y) 213 | X_gen = self.net_G(z_inp) 214 | 215 | syn_dataset.extend([(X_gen[i], test_cls, idx) for i in range(n_examples)]) 216 | 217 | if seen_dataset is not None: 218 | syn_dataset.extend(seen_dataset) 219 | 220 | return syn_dataset 221 | 222 | def test(self, data_generator, pretrained=False): 223 | if pretrained: 224 | model = self.classifier 225 | else: 226 | model = self.final_classifier 227 | 228 | # eval mode 229 | model.eval() 230 | batch_accuracies = [] 231 | for idx, (img_features, label_attr, label_idx) in enumerate(data_generator): 232 | img_features = img_features.to(self.device) 233 | label_attr = label_attr.to(self.device) 234 | 235 | X_inp = self.get_conditional_input(img_features, label_attr) 236 | with torch.no_grad(): 237 | Y_probs = model(X_inp) 238 | _, Y_pred = torch.max(Y_probs, dim=1) 239 | 240 | Y_pred = Y_pred.cpu().numpy() 241 | Y_real = label_idx.cpu().numpy() 242 | 243 | acc = accuracy_score(Y_pred, Y_real) 244 | batch_accuracies.append(acc) 245 | return np.mean(batch_accuracies) 246 | 247 | def save_model(self, model=None): 248 | if "disc_classifier" in model: 249 | ckpt_path = os.path.join(self.model_save_dir, model + ".pth") 250 | torch.save(self.classifier.state_dict(), ckpt_path) 251 | 252 | elif "gan" in model: 253 | dset_name = model.split('_')[0] 254 | g_ckpt_path = os.path.join(self.model_save_dir, "%s_generator.pth" % dset_name) 255 | torch.save(self.net_G.state_dict(), g_ckpt_path) 256 | 257 | d_ckpt_path = os.path.join(self.model_save_dir, "%s_discriminator.pth" % dset_name) 258 | torch.save(self.net_D.state_dict(), d_ckpt_path) 259 | 260 | elif "final_classifier" in model: 261 | ckpt_path = os.path.join(self.model_save_dir, model + ".pth") 262 | torch.save(self.final_classifier.state_dict(), ckpt_path) 263 | 264 | else: 265 | raise Exception("Trying to save unknown model: %s" % model) 266 | 267 | def load_model(self, model=None): 268 | if "disc_classifier" in model: 269 | ckpt_path = os.path.join(self.model_save_dir, model + ".pth") 270 | if os.path.exists(ckpt_path): 271 | self.classifier.load_state_dict(torch.load(ckpt_path)) 272 | return True 273 | 274 | elif "gan" in model: 275 | f1, f2 = False, False 276 | dset_name = model.split('_')[0] 277 | g_ckpt_path = os.path.join(self.model_save_dir, "%s_generator.pth" % dset_name) 278 | if os.path.exists(g_ckpt_path): 279 | self.net_G.load_state_dict(torch.load(g_ckpt_path)) 280 | f1 = True 281 | 282 | d_ckpt_path = os.path.join(self.model_save_dir, "%s_discriminator.pth" % dset_name) 283 | if os.path.exists(d_ckpt_path): 284 | self.net_D.load_state_dict(torch.load(d_ckpt_path)) 285 | f2 = True 286 | 287 | return f1 and f2 288 | 289 | elif "final_classifier" in model: 290 | ckpt_path = os.path.join(self.model_save_dir, model + ".pth") 291 | if os.path.exists(ckpt_path): 292 | self.final_classifier.load_state_dict(torch.load(ckpt_path)) 293 | return True 294 | 295 | else: 296 | raise Exception("Trying to load unknown model: %s" % model) 297 | 298 | return False 299 | --------------------------------------------------------------------------------