├── src
│   ├── clip
│   │   ├── __init__.py
│   │   ├── simple_tokenizer.py
│   │   └── clip.py
│   ├── ds
│   │   ├── annotations
│   │   │   ├── flo
│   │   │   │   ├── testclasses.txt
│   │   │   │   └── trainvalclasses.txt
│   │   │   └── cub
│   │   │       ├── valclasses1.txt
│   │   │       ├── valclasses2.txt
│   │   │       ├── valclasses3.txt
│   │   │       ├── testclasses.txt
│   │   │       ├── trainclasses3.txt
│   │   │       ├── trainclasses2.txt
│   │   │       ├── trainclasses1.txt
│   │   │       └── trainvalclasses.txt
│   │   ├── __init__.py
│   │   ├── vocabs
│   │   │   └── make_vocab.py
│   │   ├── vocab.py
│   │   ├── simple_tokenizer.py
│   │   ├── flickr.py
│   │   ├── flo.py
│   │   ├── cub.py
│   │   ├── fashion200k.py
│   │   ├── _transforms.py
│   │   └── coco.py
│   ├── ds_lavis
│   │   ├── annotations
│   │   │   ├── flo
│   │   │   │   ├── testclasses.txt
│   │   │   │   └── trainvalclasses.txt
│   │   │   └── cub
│   │   │       ├── valclasses1.txt
│   │   │       ├── valclasses2.txt
│   │   │       ├── valclasses3.txt
│   │   │       ├── testclasses.txt
│   │   │       ├── trainclasses3.txt
│   │   │       ├── trainclasses2.txt
│   │   │       ├── trainclasses1.txt
│   │   │       └── trainvalclasses.txt
│   │   ├── __init__.py
│   │   ├── vocabs
│   │   │   └── make_vocab.py
│   │   ├── vocab.py
│   │   ├── simple_tokenizer.py
│   │   ├── flickr.py
│   │   ├── flo.py
│   │   ├── cub.py
│   │   ├── fashion200k.py
│   │   ├── _transforms.py
│   │   └── coco.py
│   ├── losses.py
│   ├── train_ProbVLM_CLIP.ipynb
│   ├── utils_lavis.py
│   └── networks.py
├── .gitignore
├── figs
│   └── probvlm.png
├── LICENSE
├── README.md
└── requirements.txt
/src/clip/__init__.py:
--------------------------------------------------------------------------------
1 | from .clip import *
2 | 
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | **.ipynb_checkpoints/
2 | **.pkl
3 | **.zip
4 | **.npy
5 | **.txt.gz
--------------------------------------------------------------------------------
/figs/probvlm.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ExplainableML/ProbVLM/HEAD/figs/probvlm.png
--------------------------------------------------------------------------------
/src/ds/annotations/flo/testclasses.txt:
--------------------------------------------------------------------------------
1 | class_00001 2 | class_00002 3 | class_00003 4 | class_00004 5 | class_00005 6 | class_00006 7 | class_00007 8 | class_00008 9 | class_00009 10 | class_00010 11 | class_00011 12 | class_00012 13 | class_00013 14 | class_00014 15 | class_00015 16 | class_00016 17 | class_00017 18 | class_00018 19 | class_00019 20 | class_00020 21 | 
--------------------------------------------------------------------------------
/src/ds_lavis/annotations/flo/testclasses.txt:
--------------------------------------------------------------------------------
1 | class_00001 2 | class_00002 3 | class_00003 4 | class_00004 5 | class_00005 6 | class_00006 7 | class_00007 8 | class_00008 9 | class_00009 10 | class_00010 11 | class_00011 12 | class_00012 13 | class_00013 14 | class_00014 15 | class_00015 16 | class_00016 17 | class_00017 18 | class_00018 19 | class_00019 20 | class_00020 21 | 
--------------------------------------------------------------------------------
/src/ds/__init__.py:
--------------------------------------------------------------------------------
1 | """Modules for multi-modal datasets
2 | 
3 | PCME
4 | Copyright (c) 2021-present NAVER Corp.
5 | MIT license
6 | """
7 | 
8 | from ds._dataloader import prepare_cub_dataloaders
9 | from ds._dataloader import prepare_coco_dataloaders, prepare_flickr_dataloaders
10 | from ds._dataloader import prepare_fashion_dataloaders
11 | from ds._dataloader import prepare_flo_dataloaders
12 | from ds.vocab import Vocabulary
13 | 
14 | 
15 | __all__ = [
16 |     'Vocabulary',
17 |     'prepare_coco_dataloaders',
18 |     'prepare_cub_dataloaders',
19 |     # 'prepare_coco_dataset_with_bbox',
20 |     'prepare_flickr_dataloaders',
21 |     # 'prepare_flickr_dataset_with_bbox',
22 |     'prepare_fashion_dataloaders',
23 |     'prepare_flo_dataloaders'
24 | ]
25 | 
--------------------------------------------------------------------------------
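Note: `ds` and `ds_lavis` (next file) export the same loader-factory API; `ds_lavis` is an identical copy with its imports rewired to the `ds_lavis` package, presumably for the LAVIS/BLIP pipeline served by `utils_lavis.py`. A minimal usage sketch, mirroring the `train_ProbVLM_CLIP.ipynb` notebook at the end of this dump; the exact `prepare_cub_dataloaders` signature lives in `ds/_dataloader.py`, which is not included here, so the call below is an assumption:

# Minimal sketch, mirroring train_ProbVLM_CLIP.ipynb (further down in this dump).
# The prepare_cub_dataloaders signature is an assumption: ds/_dataloader.py is
# not part of this dump. '/mnt/Datasets/CUB' is the path the notebook uses.
from munch import Munch as mch
from ds import prepare_cub_dataloaders

dataloader_config = mch({
    'batch_size': 64,
    'random_erasing_prob': 0.,
    'traindata_shuffle': True,
})
loaders, vocab = prepare_cub_dataloaders(dataloader_config, '/mnt/Datasets/CUB')  # assumed signature
train_loader, valid_loader, test_loader = loaders['train'], loaders['val'], loaders['test']

--------------------------------------------------------------------------------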
/src/ds_lavis/__init__.py:
--------------------------------------------------------------------------------
1 | """Modules for multi-modal datasets
2 | 
3 | PCME
4 | Copyright (c) 2021-present NAVER Corp.
5 | MIT license
6 | """
7 | 
8 | from ds_lavis._dataloader import prepare_cub_dataloaders
9 | from ds_lavis._dataloader import prepare_coco_dataloaders, prepare_flickr_dataloaders
10 | from ds_lavis._dataloader import prepare_fashion_dataloaders
11 | from ds_lavis._dataloader import prepare_flo_dataloaders
12 | from ds_lavis.vocab import Vocabulary
13 | 
14 | 
15 | __all__ = [
16 |     'Vocabulary',
17 |     'prepare_coco_dataloaders',
18 |     'prepare_cub_dataloaders',
19 |     # 'prepare_coco_dataset_with_bbox',
20 |     'prepare_flickr_dataloaders',
21 |     # 'prepare_flickr_dataset_with_bbox',
22 |     'prepare_fashion_dataloaders',
23 |     'prepare_flo_dataloaders'
24 | ]
25 | 
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 | 
3 | Copyright (c) 2023 EML Tübingen
4 | 
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 | 
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 | -------------------------------------------------------------------------------- /src/ds/annotations/cub/valclasses1.txt: -------------------------------------------------------------------------------- 1 | 076.Dark_eyed_Junco 2 | 117.Clay_colored_Sparrow 3 | 150.Sage_Thrasher 4 | 182.Yellow_Warbler 5 | 140.Summer_Tanager 6 | 069.Rufous_Hummingbird 7 | 048.European_Goldfinch 8 | 114.Black_throated_Sparrow 9 | 109.American_Redstart 10 | 005.Crested_Auklet 11 | 051.Horned_Grebe 12 | 144.Common_Tern 13 | 177.Prothonotary_Warbler 14 | 153.Philadelphia_Vireo 15 | 189.Red_bellied_Woodpecker 16 | 151.Black_capped_Vireo 17 | 162.Canada_Warbler 18 | 089.Hooded_Merganser 19 | 155.Warbling_Vireo 20 | 059.California_Gull 21 | 066.Western_Gull 22 | 184.Louisiana_Waterthrush 23 | 198.Rock_Wren 24 | 063.Ivory_Gull 25 | 194.Cactus_Wren 26 | 195.Carolina_Wren 27 | 083.White_breasted_Kingfisher 28 | 080.Green_Kingfisher 29 | 036.Northern_Flicker 30 | 018.Spotted_Catbird 31 | 056.Pine_Grosbeak 32 | 192.Downy_Woodpecker 33 | 128.Seaside_Sparrow 34 | 073.Blue_Jay 35 | 121.Grasshopper_Sparrow 36 | 034.Gray_crowned_Rosy_Finch 37 | 064.Ring_billed_Gull 38 | 174.Palm_Warbler 39 | 118.House_Sparrow 40 | 061.Heermann_Gull 41 | 116.Chipping_Sparrow 42 | 029.American_Crow 43 | 009.Brewer_Blackbird 44 | 158.Bay_breasted_Warbler 45 | 086.Pacific_Loon 46 | 179.Tennessee_Warbler 47 | 074.Florida_Jay 48 | 007.Parakeet_Auklet 49 | 100.Brown_Pelican 50 | 141.Artic_Tern 51 | -------------------------------------------------------------------------------- /src/ds/annotations/cub/valclasses2.txt: -------------------------------------------------------------------------------- 1 | 061.Heermann_Gull 2 | 140.Summer_Tanager 3 | 019.Gray_Catbird 4 | 011.Rusty_Blackbird 5 | 066.Western_Gull 6 | 121.Grasshopper_Sparrow 7 | 144.Common_Tern 8 | 014.Indigo_Bunting 9 | 102.Western_Wood_Pewee 10 | 135.Bank_Swallow 11 | 134.Cape_Glossy_Starling 12 | 039.Least_Flycatcher 13 | 115.Brewer_Sparrow 14 | 133.White_throated_Sparrow 15 | 017.Cardinal 16 | 175.Pine_Warbler 17 | 172.Nashville_Warbler 18 | 063.Ivory_Gull 19 | 093.Clark_Nutcracker 20 | 106.Horned_Puffin 21 | 083.White_breasted_Kingfisher 22 | 056.Pine_Grosbeak 23 | 179.Tennessee_Warbler 24 | 009.Brewer_Blackbird 25 | 178.Swainson_Warbler 26 | 158.Bay_breasted_Warbler 27 | 069.Rufous_Hummingbird 28 | 150.Sage_Thrasher 29 | 197.Marsh_Wren 30 | 034.Gray_crowned_Rosy_Finch 31 | 181.Worm_eating_Warbler 32 | 015.Lazuli_Bunting 33 | 128.Seaside_Sparrow 34 | 058.Pigeon_Guillemot 35 | 086.Pacific_Loon 36 | 060.Glaucous_winged_Gull 37 | 035.Purple_Finch 38 | 018.Spotted_Catbird 39 | 187.American_Three_toed_Woodpecker 40 | 078.Gray_Kingbird 41 | 007.Parakeet_Auklet 42 | 196.House_Wren 43 | 174.Palm_Warbler 44 | 073.Blue_Jay 45 | 044.Frigatebird 46 | 038.Great_Crested_Flycatcher 47 | 162.Canada_Warbler 48 | 096.Hooded_Oriole 49 | 002.Laysan_Albatross 50 | 001.Black_footed_Albatross 51 | -------------------------------------------------------------------------------- /src/ds_lavis/annotations/cub/valclasses1.txt: -------------------------------------------------------------------------------- 1 | 076.Dark_eyed_Junco 2 | 117.Clay_colored_Sparrow 3 | 150.Sage_Thrasher 4 | 182.Yellow_Warbler 5 | 140.Summer_Tanager 6 | 069.Rufous_Hummingbird 7 | 048.European_Goldfinch 8 | 114.Black_throated_Sparrow 9 | 109.American_Redstart 10 | 005.Crested_Auklet 11 | 051.Horned_Grebe 12 | 144.Common_Tern 13 | 177.Prothonotary_Warbler 14 | 153.Philadelphia_Vireo 15 | 189.Red_bellied_Woodpecker 16 | 
151.Black_capped_Vireo 17 | 162.Canada_Warbler 18 | 089.Hooded_Merganser 19 | 155.Warbling_Vireo 20 | 059.California_Gull 21 | 066.Western_Gull 22 | 184.Louisiana_Waterthrush 23 | 198.Rock_Wren 24 | 063.Ivory_Gull 25 | 194.Cactus_Wren 26 | 195.Carolina_Wren 27 | 083.White_breasted_Kingfisher 28 | 080.Green_Kingfisher 29 | 036.Northern_Flicker 30 | 018.Spotted_Catbird 31 | 056.Pine_Grosbeak 32 | 192.Downy_Woodpecker 33 | 128.Seaside_Sparrow 34 | 073.Blue_Jay 35 | 121.Grasshopper_Sparrow 36 | 034.Gray_crowned_Rosy_Finch 37 | 064.Ring_billed_Gull 38 | 174.Palm_Warbler 39 | 118.House_Sparrow 40 | 061.Heermann_Gull 41 | 116.Chipping_Sparrow 42 | 029.American_Crow 43 | 009.Brewer_Blackbird 44 | 158.Bay_breasted_Warbler 45 | 086.Pacific_Loon 46 | 179.Tennessee_Warbler 47 | 074.Florida_Jay 48 | 007.Parakeet_Auklet 49 | 100.Brown_Pelican 50 | 141.Artic_Tern 51 | -------------------------------------------------------------------------------- /src/ds_lavis/annotations/cub/valclasses2.txt: -------------------------------------------------------------------------------- 1 | 061.Heermann_Gull 2 | 140.Summer_Tanager 3 | 019.Gray_Catbird 4 | 011.Rusty_Blackbird 5 | 066.Western_Gull 6 | 121.Grasshopper_Sparrow 7 | 144.Common_Tern 8 | 014.Indigo_Bunting 9 | 102.Western_Wood_Pewee 10 | 135.Bank_Swallow 11 | 134.Cape_Glossy_Starling 12 | 039.Least_Flycatcher 13 | 115.Brewer_Sparrow 14 | 133.White_throated_Sparrow 15 | 017.Cardinal 16 | 175.Pine_Warbler 17 | 172.Nashville_Warbler 18 | 063.Ivory_Gull 19 | 093.Clark_Nutcracker 20 | 106.Horned_Puffin 21 | 083.White_breasted_Kingfisher 22 | 056.Pine_Grosbeak 23 | 179.Tennessee_Warbler 24 | 009.Brewer_Blackbird 25 | 178.Swainson_Warbler 26 | 158.Bay_breasted_Warbler 27 | 069.Rufous_Hummingbird 28 | 150.Sage_Thrasher 29 | 197.Marsh_Wren 30 | 034.Gray_crowned_Rosy_Finch 31 | 181.Worm_eating_Warbler 32 | 015.Lazuli_Bunting 33 | 128.Seaside_Sparrow 34 | 058.Pigeon_Guillemot 35 | 086.Pacific_Loon 36 | 060.Glaucous_winged_Gull 37 | 035.Purple_Finch 38 | 018.Spotted_Catbird 39 | 187.American_Three_toed_Woodpecker 40 | 078.Gray_Kingbird 41 | 007.Parakeet_Auklet 42 | 196.House_Wren 43 | 174.Palm_Warbler 44 | 073.Blue_Jay 45 | 044.Frigatebird 46 | 038.Great_Crested_Flycatcher 47 | 162.Canada_Warbler 48 | 096.Hooded_Oriole 49 | 002.Laysan_Albatross 50 | 001.Black_footed_Albatross 51 | -------------------------------------------------------------------------------- /src/ds/annotations/cub/valclasses3.txt: -------------------------------------------------------------------------------- 1 | 187.American_Three_toed_Woodpecker 2 | 018.Spotted_Catbird 3 | 064.Ring_billed_Gull 4 | 102.Western_Wood_Pewee 5 | 171.Myrtle_Warbler 6 | 181.Worm_eating_Warbler 7 | 112.Great_Grey_Shrike 8 | 030.Fish_Crow 9 | 125.Lincoln_Sparrow 10 | 114.Black_throated_Sparrow 11 | 041.Scissor_tailed_Flycatcher 12 | 142.Black_Tern 13 | 115.Brewer_Sparrow 14 | 108.White_necked_Raven 15 | 155.Warbling_Vireo 16 | 008.Rhinoceros_Auklet 17 | 080.Green_Kingfisher 18 | 009.Brewer_Blackbird 19 | 105.Whip_poor_Will 20 | 050.Eared_Grebe 21 | 039.Least_Flycatcher 22 | 075.Green_Jay 23 | 175.Pine_Warbler 24 | 005.Crested_Auklet 25 | 160.Black_throated_Blue_Warbler 26 | 086.Pacific_Loon 27 | 079.Belted_Kingfisher 28 | 088.Western_Meadowlark 29 | 013.Bobolink 30 | 038.Great_Crested_Flycatcher 31 | 144.Common_Tern 32 | 020.Yellow_breasted_Chat 33 | 060.Glaucous_winged_Gull 34 | 037.Acadian_Flycatcher 35 | 007.Parakeet_Auklet 36 | 153.Philadelphia_Vireo 37 | 056.Pine_Grosbeak 38 | 126.Nelson_Sharp_tailed_Sparrow 39 | 
015.Lazuli_Bunting 40 | 032.Mangrove_Cuckoo 41 | 184.Louisiana_Waterthrush 42 | 006.Least_Auklet 43 | 131.Vesper_Sparrow 44 | 090.Red_breasted_Merganser 45 | 166.Golden_winged_Warbler 46 | 147.Least_Tern 47 | 101.White_Pelican 48 | 179.Tennessee_Warbler 49 | 129.Song_Sparrow 50 | 110.Geococcyx 51 | -------------------------------------------------------------------------------- /src/ds_lavis/annotations/cub/valclasses3.txt: -------------------------------------------------------------------------------- 1 | 187.American_Three_toed_Woodpecker 2 | 018.Spotted_Catbird 3 | 064.Ring_billed_Gull 4 | 102.Western_Wood_Pewee 5 | 171.Myrtle_Warbler 6 | 181.Worm_eating_Warbler 7 | 112.Great_Grey_Shrike 8 | 030.Fish_Crow 9 | 125.Lincoln_Sparrow 10 | 114.Black_throated_Sparrow 11 | 041.Scissor_tailed_Flycatcher 12 | 142.Black_Tern 13 | 115.Brewer_Sparrow 14 | 108.White_necked_Raven 15 | 155.Warbling_Vireo 16 | 008.Rhinoceros_Auklet 17 | 080.Green_Kingfisher 18 | 009.Brewer_Blackbird 19 | 105.Whip_poor_Will 20 | 050.Eared_Grebe 21 | 039.Least_Flycatcher 22 | 075.Green_Jay 23 | 175.Pine_Warbler 24 | 005.Crested_Auklet 25 | 160.Black_throated_Blue_Warbler 26 | 086.Pacific_Loon 27 | 079.Belted_Kingfisher 28 | 088.Western_Meadowlark 29 | 013.Bobolink 30 | 038.Great_Crested_Flycatcher 31 | 144.Common_Tern 32 | 020.Yellow_breasted_Chat 33 | 060.Glaucous_winged_Gull 34 | 037.Acadian_Flycatcher 35 | 007.Parakeet_Auklet 36 | 153.Philadelphia_Vireo 37 | 056.Pine_Grosbeak 38 | 126.Nelson_Sharp_tailed_Sparrow 39 | 015.Lazuli_Bunting 40 | 032.Mangrove_Cuckoo 41 | 184.Louisiana_Waterthrush 42 | 006.Least_Auklet 43 | 131.Vesper_Sparrow 44 | 090.Red_breasted_Merganser 45 | 166.Golden_winged_Warbler 46 | 147.Least_Tern 47 | 101.White_Pelican 48 | 179.Tennessee_Warbler 49 | 129.Song_Sparrow 50 | 110.Geococcyx 51 | -------------------------------------------------------------------------------- /src/ds/annotations/cub/testclasses.txt: -------------------------------------------------------------------------------- 1 | 043.Yellow_bellied_Flycatcher 2 | 111.Loggerhead_Shrike 3 | 023.Brandt_Cormorant 4 | 098.Scott_Oriole 5 | 055.Evening_Grosbeak 6 | 130.Tree_Sparrow 7 | 139.Scarlet_Tanager 8 | 123.Henslow_Sparrow 9 | 156.White_eyed_Vireo 10 | 124.Le_Conte_Sparrow 11 | 200.Common_Yellowthroat 12 | 072.Pomarine_Jaeger 13 | 173.Orange_crowned_Warbler 14 | 028.Brown_Creeper 15 | 119.Field_Sparrow 16 | 165.Chestnut_sided_Warbler 17 | 103.Sayornis 18 | 180.Wilson_Warbler 19 | 077.Tropical_Kingbird 20 | 012.Yellow_headed_Blackbird 21 | 045.Northern_Fulmar 22 | 190.Red_cockaded_Woodpecker 23 | 191.Red_headed_Woodpecker 24 | 138.Tree_Swallow 25 | 157.Yellow_throated_Vireo 26 | 052.Pied_billed_Grebe 27 | 033.Yellow_billed_Cuckoo 28 | 164.Cerulean_Warbler 29 | 031.Black_billed_Cuckoo 30 | 143.Caspian_Tern 31 | 094.White_breasted_Nuthatch 32 | 070.Green_Violetear 33 | 097.Orchard_Oriole 34 | 091.Mockingbird 35 | 104.American_Pipit 36 | 127.Savannah_Sparrow 37 | 161.Blue_winged_Warbler 38 | 049.Boat_tailed_Grackle 39 | 169.Magnolia_Warbler 40 | 148.Green_tailed_Towhee 41 | 113.Baird_Sparrow 42 | 087.Mallard 43 | 163.Cape_May_Warbler 44 | 136.Barn_Swallow 45 | 188.Pileated_Woodpecker 46 | 084.Red_legged_Kittiwake 47 | 026.Bronzed_Cowbird 48 | 004.Groove_billed_Ani 49 | 132.White_crowned_Sparrow 50 | 168.Kentucky_Warbler 51 | -------------------------------------------------------------------------------- /src/ds_lavis/annotations/cub/testclasses.txt: -------------------------------------------------------------------------------- 1 | 
043.Yellow_bellied_Flycatcher 2 | 111.Loggerhead_Shrike 3 | 023.Brandt_Cormorant 4 | 098.Scott_Oriole 5 | 055.Evening_Grosbeak 6 | 130.Tree_Sparrow 7 | 139.Scarlet_Tanager 8 | 123.Henslow_Sparrow 9 | 156.White_eyed_Vireo 10 | 124.Le_Conte_Sparrow 11 | 200.Common_Yellowthroat 12 | 072.Pomarine_Jaeger 13 | 173.Orange_crowned_Warbler 14 | 028.Brown_Creeper 15 | 119.Field_Sparrow 16 | 165.Chestnut_sided_Warbler 17 | 103.Sayornis 18 | 180.Wilson_Warbler 19 | 077.Tropical_Kingbird 20 | 012.Yellow_headed_Blackbird 21 | 045.Northern_Fulmar 22 | 190.Red_cockaded_Woodpecker 23 | 191.Red_headed_Woodpecker 24 | 138.Tree_Swallow 25 | 157.Yellow_throated_Vireo 26 | 052.Pied_billed_Grebe 27 | 033.Yellow_billed_Cuckoo 28 | 164.Cerulean_Warbler 29 | 031.Black_billed_Cuckoo 30 | 143.Caspian_Tern 31 | 094.White_breasted_Nuthatch 32 | 070.Green_Violetear 33 | 097.Orchard_Oriole 34 | 091.Mockingbird 35 | 104.American_Pipit 36 | 127.Savannah_Sparrow 37 | 161.Blue_winged_Warbler 38 | 049.Boat_tailed_Grackle 39 | 169.Magnolia_Warbler 40 | 148.Green_tailed_Towhee 41 | 113.Baird_Sparrow 42 | 087.Mallard 43 | 163.Cape_May_Warbler 44 | 136.Barn_Swallow 45 | 188.Pileated_Woodpecker 46 | 084.Red_legged_Kittiwake 47 | 026.Bronzed_Cowbird 48 | 004.Groove_billed_Ani 49 | 132.White_crowned_Sparrow 50 | 168.Kentucky_Warbler 51 | -------------------------------------------------------------------------------- /src/ds/annotations/flo/trainvalclasses.txt: -------------------------------------------------------------------------------- 1 | class_00021 2 | class_00022 3 | class_00023 4 | class_00024 5 | class_00025 6 | class_00026 7 | class_00027 8 | class_00028 9 | class_00029 10 | class_00030 11 | class_00031 12 | class_00032 13 | class_00033 14 | class_00034 15 | class_00035 16 | class_00036 17 | class_00037 18 | class_00038 19 | class_00039 20 | class_00040 21 | class_00041 22 | class_00042 23 | class_00043 24 | class_00044 25 | class_00045 26 | class_00046 27 | class_00047 28 | class_00048 29 | class_00049 30 | class_00050 31 | class_00051 32 | class_00052 33 | class_00053 34 | class_00054 35 | class_00055 36 | class_00056 37 | class_00057 38 | class_00058 39 | class_00059 40 | class_00060 41 | class_00061 42 | class_00062 43 | class_00063 44 | class_00064 45 | class_00065 46 | class_00066 47 | class_00067 48 | class_00068 49 | class_00069 50 | class_00070 51 | class_00071 52 | class_00072 53 | class_00073 54 | class_00074 55 | class_00075 56 | class_00076 57 | class_00077 58 | class_00078 59 | class_00079 60 | class_00080 61 | class_00081 62 | class_00082 63 | class_00083 64 | class_00084 65 | class_00085 66 | class_00086 67 | class_00087 68 | class_00088 69 | class_00089 70 | class_00090 71 | class_00091 72 | class_00092 73 | class_00093 74 | class_00094 75 | class_00095 76 | class_00096 77 | class_00097 78 | class_00098 79 | class_00099 80 | class_00100 81 | class_00101 82 | class_00102 83 | -------------------------------------------------------------------------------- /src/ds_lavis/annotations/flo/trainvalclasses.txt: -------------------------------------------------------------------------------- 1 | class_00021 2 | class_00022 3 | class_00023 4 | class_00024 5 | class_00025 6 | class_00026 7 | class_00027 8 | class_00028 9 | class_00029 10 | class_00030 11 | class_00031 12 | class_00032 13 | class_00033 14 | class_00034 15 | class_00035 16 | class_00036 17 | class_00037 18 | class_00038 19 | class_00039 20 | class_00040 21 | class_00041 22 | class_00042 23 | class_00043 24 | class_00044 25 | class_00045 26 | 
class_00046 27 | class_00047 28 | class_00048 29 | class_00049 30 | class_00050 31 | class_00051 32 | class_00052 33 | class_00053 34 | class_00054 35 | class_00055 36 | class_00056 37 | class_00057 38 | class_00058 39 | class_00059 40 | class_00060 41 | class_00061 42 | class_00062 43 | class_00063 44 | class_00064 45 | class_00065 46 | class_00066 47 | class_00067 48 | class_00068 49 | class_00069 50 | class_00070 51 | class_00071 52 | class_00072 53 | class_00073 54 | class_00074 55 | class_00075 56 | class_00076 57 | class_00077 58 | class_00078 59 | class_00079 60 | class_00080 61 | class_00081 62 | class_00082 63 | class_00083 64 | class_00084 65 | class_00085 66 | class_00086 67 | class_00087 68 | class_00088 69 | class_00089 70 | class_00090 71 | class_00091 72 | class_00092 73 | class_00093 74 | class_00094 75 | class_00095 76 | class_00096 77 | class_00097 78 | class_00098 79 | class_00099 80 | class_00100 81 | class_00101 82 | class_00102 83 | -------------------------------------------------------------------------------- /src/ds/annotations/cub/trainclasses3.txt: -------------------------------------------------------------------------------- 1 | 011.Rusty_Blackbird 2 | 069.Rufous_Hummingbird 3 | 071.Long_tailed_Jaeger 4 | 107.Common_Raven 5 | 017.Cardinal 6 | 019.Gray_Catbird 7 | 140.Summer_Tanager 8 | 054.Blue_Grosbeak 9 | 159.Black_and_white_Warbler 10 | 192.Downy_Woodpecker 11 | 092.Nighthawk 12 | 089.Hooded_Merganser 13 | 186.Cedar_Waxwing 14 | 152.Blue_headed_Vireo 15 | 170.Mourning_Warbler 16 | 046.Gadwall 17 | 118.House_Sparrow 18 | 027.Shiny_Cowbird 19 | 003.Sooty_Albatross 20 | 085.Horned_Lark 21 | 001.Black_footed_Albatross 22 | 174.Palm_Warbler 23 | 162.Canada_Warbler 24 | 121.Grasshopper_Sparrow 25 | 117.Clay_colored_Sparrow 26 | 109.American_Redstart 27 | 074.Florida_Jay 28 | 063.Ivory_Gull 29 | 048.European_Goldfinch 30 | 100.Brown_Pelican 31 | 051.Horned_Grebe 32 | 076.Dark_eyed_Junco 33 | 099.Ovenbird 34 | 120.Fox_Sparrow 35 | 029.American_Crow 36 | 065.Slaty_backed_Gull 37 | 042.Vermilion_Flycatcher 38 | 195.Carolina_Wren 39 | 150.Sage_Thrasher 40 | 040.Olive_sided_Flycatcher 41 | 151.Black_capped_Vireo 42 | 193.Bewick_Wren 43 | 053.Western_Grebe 44 | 057.Rose_breasted_Grosbeak 45 | 093.Clark_Nutcracker 46 | 177.Prothonotary_Warbler 47 | 024.Red_faced_Cormorant 48 | 128.Seaside_Sparrow 49 | 106.Horned_Puffin 50 | 059.California_Gull 51 | 022.Chuck_will_Widow 52 | 158.Bay_breasted_Warbler 53 | 146.Forsters_Tern 54 | 141.Artic_Tern 55 | 014.Indigo_Bunting 56 | 062.Herring_Gull 57 | 122.Harris_Sparrow 58 | 034.Gray_crowned_Rosy_Finch 59 | 182.Yellow_Warbler 60 | 073.Blue_Jay 61 | 185.Bohemian_Waxwing 62 | 154.Red_eyed_Vireo 63 | 172.Nashville_Warbler 64 | 133.White_throated_Sparrow 65 | 035.Purple_Finch 66 | 058.Pigeon_Guillemot 67 | 083.White_breasted_Kingfisher 68 | 194.Cactus_Wren 69 | 095.Baltimore_Oriole 70 | 044.Frigatebird 71 | 021.Eastern_Towhee 72 | 047.American_Goldfinch 73 | 078.Gray_Kingbird 74 | 081.Pied_Kingfisher 75 | 096.Hooded_Oriole 76 | 145.Elegant_Tern 77 | 167.Hooded_Warbler 78 | 068.Ruby_throated_Hummingbird 79 | 189.Red_bellied_Woodpecker 80 | 197.Marsh_Wren 81 | 134.Cape_Glossy_Starling 82 | 010.Red_winged_Blackbird 83 | 067.Anna_Hummingbird 84 | 196.House_Wren 85 | 066.Western_Gull 86 | 199.Winter_Wren 87 | 016.Painted_Bunting 88 | 116.Chipping_Sparrow 89 | 176.Prairie_Warbler 90 | 183.Northern_Waterthrush 91 | 137.Cliff_Swallow 92 | 036.Northern_Flicker 93 | 149.Brown_Thrasher 94 | 178.Swainson_Warbler 95 | 198.Rock_Wren 96 | 
082.Ringed_Kingfisher 97 | 002.Laysan_Albatross 98 | 135.Bank_Swallow 99 | 061.Heermann_Gull 100 | 025.Pelagic_Cormorant 101 | -------------------------------------------------------------------------------- /src/ds_lavis/annotations/cub/trainclasses3.txt: -------------------------------------------------------------------------------- 1 | 011.Rusty_Blackbird 2 | 069.Rufous_Hummingbird 3 | 071.Long_tailed_Jaeger 4 | 107.Common_Raven 5 | 017.Cardinal 6 | 019.Gray_Catbird 7 | 140.Summer_Tanager 8 | 054.Blue_Grosbeak 9 | 159.Black_and_white_Warbler 10 | 192.Downy_Woodpecker 11 | 092.Nighthawk 12 | 089.Hooded_Merganser 13 | 186.Cedar_Waxwing 14 | 152.Blue_headed_Vireo 15 | 170.Mourning_Warbler 16 | 046.Gadwall 17 | 118.House_Sparrow 18 | 027.Shiny_Cowbird 19 | 003.Sooty_Albatross 20 | 085.Horned_Lark 21 | 001.Black_footed_Albatross 22 | 174.Palm_Warbler 23 | 162.Canada_Warbler 24 | 121.Grasshopper_Sparrow 25 | 117.Clay_colored_Sparrow 26 | 109.American_Redstart 27 | 074.Florida_Jay 28 | 063.Ivory_Gull 29 | 048.European_Goldfinch 30 | 100.Brown_Pelican 31 | 051.Horned_Grebe 32 | 076.Dark_eyed_Junco 33 | 099.Ovenbird 34 | 120.Fox_Sparrow 35 | 029.American_Crow 36 | 065.Slaty_backed_Gull 37 | 042.Vermilion_Flycatcher 38 | 195.Carolina_Wren 39 | 150.Sage_Thrasher 40 | 040.Olive_sided_Flycatcher 41 | 151.Black_capped_Vireo 42 | 193.Bewick_Wren 43 | 053.Western_Grebe 44 | 057.Rose_breasted_Grosbeak 45 | 093.Clark_Nutcracker 46 | 177.Prothonotary_Warbler 47 | 024.Red_faced_Cormorant 48 | 128.Seaside_Sparrow 49 | 106.Horned_Puffin 50 | 059.California_Gull 51 | 022.Chuck_will_Widow 52 | 158.Bay_breasted_Warbler 53 | 146.Forsters_Tern 54 | 141.Artic_Tern 55 | 014.Indigo_Bunting 56 | 062.Herring_Gull 57 | 122.Harris_Sparrow 58 | 034.Gray_crowned_Rosy_Finch 59 | 182.Yellow_Warbler 60 | 073.Blue_Jay 61 | 185.Bohemian_Waxwing 62 | 154.Red_eyed_Vireo 63 | 172.Nashville_Warbler 64 | 133.White_throated_Sparrow 65 | 035.Purple_Finch 66 | 058.Pigeon_Guillemot 67 | 083.White_breasted_Kingfisher 68 | 194.Cactus_Wren 69 | 095.Baltimore_Oriole 70 | 044.Frigatebird 71 | 021.Eastern_Towhee 72 | 047.American_Goldfinch 73 | 078.Gray_Kingbird 74 | 081.Pied_Kingfisher 75 | 096.Hooded_Oriole 76 | 145.Elegant_Tern 77 | 167.Hooded_Warbler 78 | 068.Ruby_throated_Hummingbird 79 | 189.Red_bellied_Woodpecker 80 | 197.Marsh_Wren 81 | 134.Cape_Glossy_Starling 82 | 010.Red_winged_Blackbird 83 | 067.Anna_Hummingbird 84 | 196.House_Wren 85 | 066.Western_Gull 86 | 199.Winter_Wren 87 | 016.Painted_Bunting 88 | 116.Chipping_Sparrow 89 | 176.Prairie_Warbler 90 | 183.Northern_Waterthrush 91 | 137.Cliff_Swallow 92 | 036.Northern_Flicker 93 | 149.Brown_Thrasher 94 | 178.Swainson_Warbler 95 | 198.Rock_Wren 96 | 082.Ringed_Kingfisher 97 | 002.Laysan_Albatross 98 | 135.Bank_Swallow 99 | 061.Heermann_Gull 100 | 025.Pelagic_Cormorant 101 | -------------------------------------------------------------------------------- /src/ds/annotations/cub/trainclasses2.txt: -------------------------------------------------------------------------------- 1 | 108.White_necked_Raven 2 | 099.Ovenbird 3 | 185.Bohemian_Waxwing 4 | 192.Downy_Woodpecker 5 | 036.Northern_Flicker 6 | 131.Vesper_Sparrow 7 | 005.Crested_Auklet 8 | 147.Least_Tern 9 | 189.Red_bellied_Woodpecker 10 | 071.Long_tailed_Jaeger 11 | 167.Hooded_Warbler 12 | 116.Chipping_Sparrow 13 | 003.Sooty_Albatross 14 | 064.Ring_billed_Gull 15 | 171.Myrtle_Warbler 16 | 053.Western_Grebe 17 | 050.Eared_Grebe 18 | 122.Harris_Sparrow 19 | 184.Louisiana_Waterthrush 20 | 183.Northern_Waterthrush 21 | 
195.Carolina_Wren 22 | 040.Olive_sided_Flycatcher 23 | 142.Black_Tern 24 | 166.Golden_winged_Warbler 25 | 117.Clay_colored_Sparrow 26 | 090.Red_breasted_Merganser 27 | 065.Slaty_backed_Gull 28 | 029.American_Crow 29 | 024.Red_faced_Cormorant 30 | 129.Song_Sparrow 31 | 041.Scissor_tailed_Flycatcher 32 | 032.Mangrove_Cuckoo 33 | 193.Bewick_Wren 34 | 022.Chuck_will_Widow 35 | 100.Brown_Pelican 36 | 198.Rock_Wren 37 | 020.Yellow_breasted_Chat 38 | 149.Brown_Thrasher 39 | 027.Shiny_Cowbird 40 | 081.Pied_Kingfisher 41 | 177.Prothonotary_Warbler 42 | 107.Common_Raven 43 | 125.Lincoln_Sparrow 44 | 010.Red_winged_Blackbird 45 | 105.Whip_poor_Will 46 | 057.Rose_breasted_Grosbeak 47 | 154.Red_eyed_Vireo 48 | 199.Winter_Wren 49 | 160.Black_throated_Blue_Warbler 50 | 194.Cactus_Wren 51 | 037.Acadian_Flycatcher 52 | 054.Blue_Grosbeak 53 | 016.Painted_Bunting 54 | 062.Herring_Gull 55 | 088.Western_Meadowlark 56 | 155.Warbling_Vireo 57 | 076.Dark_eyed_Junco 58 | 074.Florida_Jay 59 | 025.Pelagic_Cormorant 60 | 176.Prairie_Warbler 61 | 182.Yellow_Warbler 62 | 042.Vermilion_Flycatcher 63 | 048.European_Goldfinch 64 | 141.Artic_Tern 65 | 114.Black_throated_Sparrow 66 | 030.Fish_Crow 67 | 109.American_Redstart 68 | 159.Black_and_white_Warbler 69 | 120.Fox_Sparrow 70 | 006.Least_Auklet 71 | 110.Geococcyx 72 | 170.Mourning_Warbler 73 | 146.Forsters_Tern 74 | 112.Great_Grey_Shrike 75 | 089.Hooded_Merganser 76 | 047.American_Goldfinch 77 | 067.Anna_Hummingbird 78 | 092.Nighthawk 79 | 152.Blue_headed_Vireo 80 | 021.Eastern_Towhee 81 | 118.House_Sparrow 82 | 059.California_Gull 83 | 145.Elegant_Tern 84 | 085.Horned_Lark 85 | 080.Green_Kingfisher 86 | 068.Ruby_throated_Hummingbird 87 | 151.Black_capped_Vireo 88 | 079.Belted_Kingfisher 89 | 153.Philadelphia_Vireo 90 | 186.Cedar_Waxwing 91 | 082.Ringed_Kingfisher 92 | 095.Baltimore_Oriole 93 | 046.Gadwall 94 | 101.White_Pelican 95 | 137.Cliff_Swallow 96 | 126.Nelson_Sharp_tailed_Sparrow 97 | 075.Green_Jay 98 | 051.Horned_Grebe 99 | 013.Bobolink 100 | 008.Rhinoceros_Auklet 101 | -------------------------------------------------------------------------------- /src/ds/annotations/cub/trainclasses1.txt: -------------------------------------------------------------------------------- 1 | 108.White_necked_Raven 2 | 167.Hooded_Warbler 3 | 142.Black_Tern 4 | 039.Least_Flycatcher 5 | 002.Laysan_Albatross 6 | 187.American_Three_toed_Woodpecker 7 | 106.Horned_Puffin 8 | 181.Worm_eating_Warbler 9 | 060.Glaucous_winged_Gull 10 | 015.Lazuli_Bunting 11 | 067.Anna_Hummingbird 12 | 107.Common_Raven 13 | 013.Bobolink 14 | 105.Whip_poor_Will 15 | 088.Western_Meadowlark 16 | 147.Least_Tern 17 | 006.Least_Auklet 18 | 160.Black_throated_Blue_Warbler 19 | 110.Geococcyx 20 | 183.Northern_Waterthrush 21 | 024.Red_faced_Cormorant 22 | 152.Blue_headed_Vireo 23 | 022.Chuck_will_Widow 24 | 008.Rhinoceros_Auklet 25 | 019.Gray_Catbird 26 | 154.Red_eyed_Vireo 27 | 185.Bohemian_Waxwing 28 | 068.Ruby_throated_Hummingbird 29 | 196.House_Wren 30 | 122.Harris_Sparrow 31 | 014.Indigo_Bunting 32 | 020.Yellow_breasted_Chat 33 | 054.Blue_Grosbeak 34 | 038.Great_Crested_Flycatcher 35 | 115.Brewer_Sparrow 36 | 079.Belted_Kingfisher 37 | 101.White_Pelican 38 | 027.Shiny_Cowbird 39 | 186.Cedar_Waxwing 40 | 053.Western_Grebe 41 | 099.Ovenbird 42 | 003.Sooty_Albatross 43 | 030.Fish_Crow 44 | 112.Great_Grey_Shrike 45 | 092.Nighthawk 46 | 166.Golden_winged_Warbler 47 | 071.Long_tailed_Jaeger 48 | 078.Gray_Kingbird 49 | 172.Nashville_Warbler 50 | 159.Black_and_white_Warbler 51 | 131.Vesper_Sparrow 52 | 197.Marsh_Wren 
53 | 017.Cardinal 54 | 042.Vermilion_Flycatcher 55 | 133.White_throated_Sparrow 56 | 085.Horned_Lark 57 | 176.Prairie_Warbler 58 | 016.Painted_Bunting 59 | 129.Song_Sparrow 60 | 171.Myrtle_Warbler 61 | 090.Red_breasted_Merganser 62 | 146.Forsters_Tern 63 | 044.Frigatebird 64 | 035.Purple_Finch 65 | 065.Slaty_backed_Gull 66 | 041.Scissor_tailed_Flycatcher 67 | 050.Eared_Grebe 68 | 081.Pied_Kingfisher 69 | 062.Herring_Gull 70 | 082.Ringed_Kingfisher 71 | 125.Lincoln_Sparrow 72 | 170.Mourning_Warbler 73 | 021.Eastern_Towhee 74 | 193.Bewick_Wren 75 | 096.Hooded_Oriole 76 | 095.Baltimore_Oriole 77 | 040.Olive_sided_Flycatcher 78 | 037.Acadian_Flycatcher 79 | 075.Green_Jay 80 | 058.Pigeon_Guillemot 81 | 145.Elegant_Tern 82 | 102.Western_Wood_Pewee 83 | 025.Pelagic_Cormorant 84 | 001.Black_footed_Albatross 85 | 093.Clark_Nutcracker 86 | 137.Cliff_Swallow 87 | 149.Brown_Thrasher 88 | 175.Pine_Warbler 89 | 047.American_Goldfinch 90 | 199.Winter_Wren 91 | 178.Swainson_Warbler 92 | 126.Nelson_Sharp_tailed_Sparrow 93 | 046.Gadwall 94 | 011.Rusty_Blackbird 95 | 135.Bank_Swallow 96 | 032.Mangrove_Cuckoo 97 | 120.Fox_Sparrow 98 | 010.Red_winged_Blackbird 99 | 057.Rose_breasted_Grosbeak 100 | 134.Cape_Glossy_Starling 101 | -------------------------------------------------------------------------------- /src/ds_lavis/annotations/cub/trainclasses2.txt: -------------------------------------------------------------------------------- 1 | 108.White_necked_Raven 2 | 099.Ovenbird 3 | 185.Bohemian_Waxwing 4 | 192.Downy_Woodpecker 5 | 036.Northern_Flicker 6 | 131.Vesper_Sparrow 7 | 005.Crested_Auklet 8 | 147.Least_Tern 9 | 189.Red_bellied_Woodpecker 10 | 071.Long_tailed_Jaeger 11 | 167.Hooded_Warbler 12 | 116.Chipping_Sparrow 13 | 003.Sooty_Albatross 14 | 064.Ring_billed_Gull 15 | 171.Myrtle_Warbler 16 | 053.Western_Grebe 17 | 050.Eared_Grebe 18 | 122.Harris_Sparrow 19 | 184.Louisiana_Waterthrush 20 | 183.Northern_Waterthrush 21 | 195.Carolina_Wren 22 | 040.Olive_sided_Flycatcher 23 | 142.Black_Tern 24 | 166.Golden_winged_Warbler 25 | 117.Clay_colored_Sparrow 26 | 090.Red_breasted_Merganser 27 | 065.Slaty_backed_Gull 28 | 029.American_Crow 29 | 024.Red_faced_Cormorant 30 | 129.Song_Sparrow 31 | 041.Scissor_tailed_Flycatcher 32 | 032.Mangrove_Cuckoo 33 | 193.Bewick_Wren 34 | 022.Chuck_will_Widow 35 | 100.Brown_Pelican 36 | 198.Rock_Wren 37 | 020.Yellow_breasted_Chat 38 | 149.Brown_Thrasher 39 | 027.Shiny_Cowbird 40 | 081.Pied_Kingfisher 41 | 177.Prothonotary_Warbler 42 | 107.Common_Raven 43 | 125.Lincoln_Sparrow 44 | 010.Red_winged_Blackbird 45 | 105.Whip_poor_Will 46 | 057.Rose_breasted_Grosbeak 47 | 154.Red_eyed_Vireo 48 | 199.Winter_Wren 49 | 160.Black_throated_Blue_Warbler 50 | 194.Cactus_Wren 51 | 037.Acadian_Flycatcher 52 | 054.Blue_Grosbeak 53 | 016.Painted_Bunting 54 | 062.Herring_Gull 55 | 088.Western_Meadowlark 56 | 155.Warbling_Vireo 57 | 076.Dark_eyed_Junco 58 | 074.Florida_Jay 59 | 025.Pelagic_Cormorant 60 | 176.Prairie_Warbler 61 | 182.Yellow_Warbler 62 | 042.Vermilion_Flycatcher 63 | 048.European_Goldfinch 64 | 141.Artic_Tern 65 | 114.Black_throated_Sparrow 66 | 030.Fish_Crow 67 | 109.American_Redstart 68 | 159.Black_and_white_Warbler 69 | 120.Fox_Sparrow 70 | 006.Least_Auklet 71 | 110.Geococcyx 72 | 170.Mourning_Warbler 73 | 146.Forsters_Tern 74 | 112.Great_Grey_Shrike 75 | 089.Hooded_Merganser 76 | 047.American_Goldfinch 77 | 067.Anna_Hummingbird 78 | 092.Nighthawk 79 | 152.Blue_headed_Vireo 80 | 021.Eastern_Towhee 81 | 118.House_Sparrow 82 | 059.California_Gull 83 | 145.Elegant_Tern 84 | 
085.Horned_Lark 85 | 080.Green_Kingfisher 86 | 068.Ruby_throated_Hummingbird 87 | 151.Black_capped_Vireo 88 | 079.Belted_Kingfisher 89 | 153.Philadelphia_Vireo 90 | 186.Cedar_Waxwing 91 | 082.Ringed_Kingfisher 92 | 095.Baltimore_Oriole 93 | 046.Gadwall 94 | 101.White_Pelican 95 | 137.Cliff_Swallow 96 | 126.Nelson_Sharp_tailed_Sparrow 97 | 075.Green_Jay 98 | 051.Horned_Grebe 99 | 013.Bobolink 100 | 008.Rhinoceros_Auklet 101 | -------------------------------------------------------------------------------- /src/ds_lavis/annotations/cub/trainclasses1.txt: -------------------------------------------------------------------------------- 1 | 108.White_necked_Raven 2 | 167.Hooded_Warbler 3 | 142.Black_Tern 4 | 039.Least_Flycatcher 5 | 002.Laysan_Albatross 6 | 187.American_Three_toed_Woodpecker 7 | 106.Horned_Puffin 8 | 181.Worm_eating_Warbler 9 | 060.Glaucous_winged_Gull 10 | 015.Lazuli_Bunting 11 | 067.Anna_Hummingbird 12 | 107.Common_Raven 13 | 013.Bobolink 14 | 105.Whip_poor_Will 15 | 088.Western_Meadowlark 16 | 147.Least_Tern 17 | 006.Least_Auklet 18 | 160.Black_throated_Blue_Warbler 19 | 110.Geococcyx 20 | 183.Northern_Waterthrush 21 | 024.Red_faced_Cormorant 22 | 152.Blue_headed_Vireo 23 | 022.Chuck_will_Widow 24 | 008.Rhinoceros_Auklet 25 | 019.Gray_Catbird 26 | 154.Red_eyed_Vireo 27 | 185.Bohemian_Waxwing 28 | 068.Ruby_throated_Hummingbird 29 | 196.House_Wren 30 | 122.Harris_Sparrow 31 | 014.Indigo_Bunting 32 | 020.Yellow_breasted_Chat 33 | 054.Blue_Grosbeak 34 | 038.Great_Crested_Flycatcher 35 | 115.Brewer_Sparrow 36 | 079.Belted_Kingfisher 37 | 101.White_Pelican 38 | 027.Shiny_Cowbird 39 | 186.Cedar_Waxwing 40 | 053.Western_Grebe 41 | 099.Ovenbird 42 | 003.Sooty_Albatross 43 | 030.Fish_Crow 44 | 112.Great_Grey_Shrike 45 | 092.Nighthawk 46 | 166.Golden_winged_Warbler 47 | 071.Long_tailed_Jaeger 48 | 078.Gray_Kingbird 49 | 172.Nashville_Warbler 50 | 159.Black_and_white_Warbler 51 | 131.Vesper_Sparrow 52 | 197.Marsh_Wren 53 | 017.Cardinal 54 | 042.Vermilion_Flycatcher 55 | 133.White_throated_Sparrow 56 | 085.Horned_Lark 57 | 176.Prairie_Warbler 58 | 016.Painted_Bunting 59 | 129.Song_Sparrow 60 | 171.Myrtle_Warbler 61 | 090.Red_breasted_Merganser 62 | 146.Forsters_Tern 63 | 044.Frigatebird 64 | 035.Purple_Finch 65 | 065.Slaty_backed_Gull 66 | 041.Scissor_tailed_Flycatcher 67 | 050.Eared_Grebe 68 | 081.Pied_Kingfisher 69 | 062.Herring_Gull 70 | 082.Ringed_Kingfisher 71 | 125.Lincoln_Sparrow 72 | 170.Mourning_Warbler 73 | 021.Eastern_Towhee 74 | 193.Bewick_Wren 75 | 096.Hooded_Oriole 76 | 095.Baltimore_Oriole 77 | 040.Olive_sided_Flycatcher 78 | 037.Acadian_Flycatcher 79 | 075.Green_Jay 80 | 058.Pigeon_Guillemot 81 | 145.Elegant_Tern 82 | 102.Western_Wood_Pewee 83 | 025.Pelagic_Cormorant 84 | 001.Black_footed_Albatross 85 | 093.Clark_Nutcracker 86 | 137.Cliff_Swallow 87 | 149.Brown_Thrasher 88 | 175.Pine_Warbler 89 | 047.American_Goldfinch 90 | 199.Winter_Wren 91 | 178.Swainson_Warbler 92 | 126.Nelson_Sharp_tailed_Sparrow 93 | 046.Gadwall 94 | 011.Rusty_Blackbird 95 | 135.Bank_Swallow 96 | 032.Mangrove_Cuckoo 97 | 120.Fox_Sparrow 98 | 010.Red_winged_Blackbird 99 | 057.Rose_breasted_Grosbeak 100 | 134.Cape_Glossy_Starling 101 | -------------------------------------------------------------------------------- /src/ds/vocabs/make_vocab.py: -------------------------------------------------------------------------------- 1 | """ a script for making vocabulary pickle. 
2 | 
3 | Original code:
4 | https://github.com/yalesong/pvse/blob/master/vocab.py
5 | """
6 | 
7 | import nltk
8 | import pickle
9 | from collections import Counter
10 | import fire
11 | import os
12 | from tqdm import tqdm
13 | 
14 | 
15 | class Vocabulary(object):
16 |     """Simple vocabulary wrapper."""
17 | 
18 |     def __init__(self):
19 |         self.word2idx = {}
20 |         self.idx2word = {}
21 |         self.idx = 0
22 | 
23 |     def add_word(self, word):
24 |         if word not in self.word2idx:
25 |             self.word2idx[word] = self.idx
26 |             self.idx2word[self.idx] = word
27 |             self.idx += 1
28 | 
29 |     def __call__(self, word):
30 |         if word not in self.word2idx:
31 |             return self.word2idx['<unk>']
32 |         return self.word2idx[word]
33 | 
34 |     def __len__(self):
35 |         return len(self.word2idx)
36 | 
37 | 
38 | def from_txt(txt):
39 |     captions = []
40 |     with open(txt, 'rb') as f:
41 |         for line in f:
42 |             captions.append(line.strip())
43 |     return captions
44 | 
45 | 
46 | def build_vocab(data_path, threshold):
47 |     """Build a simple vocabulary wrapper."""
48 |     counter = Counter()
49 |     captions = []
50 |     for cname in os.listdir(data_path):
51 |         for fname in os.listdir(os.path.join(data_path, cname)):
52 |             full_path = os.path.join(data_path, cname, fname)
53 |             captions.extend(from_txt(full_path))
54 | 
55 |     for i, caption in tqdm(enumerate(captions), total=len(captions)):
56 |         tokens = nltk.tokenize.word_tokenize(
57 |             caption.lower().decode('utf-8'))
58 |         counter.update(tokens)
59 | 
60 |     # Discard if the occurrence of the word is less than min_word_cnt.
61 |     words = [word for word, cnt in counter.items() if cnt >= threshold]
62 | 
63 |     # Create a vocab wrapper and add some special tokens.
64 |     vocab = Vocabulary()
65 |     vocab.add_word('<pad>')
66 |     vocab.add_word('<start>')
67 |     vocab.add_word('<end>')
68 |     vocab.add_word('<unk>')
69 | 
70 |     # Add words to the vocabulary.
71 |     for i, word in enumerate(words):
72 |         vocab.add_word(word)
73 |     dict_vocab = {
74 |         'idx': vocab.idx,
75 |         'idx2word': vocab.idx2word,
76 |         'word2idx': vocab.word2idx,
77 |     }
78 |     return dict_vocab
79 | 
80 | 
81 | def main(data_path):
82 |     vocab = build_vocab(data_path, threshold=4)
83 |     with open('./vocab_local.pkl', 'wb') as f:
84 |         pickle.dump(vocab, f, pickle.HIGHEST_PROTOCOL)
85 | 
86 | 
87 | if __name__ == '__main__':
88 |     fire.Fire(main)
89 | 
--------------------------------------------------------------------------------
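Note: `fire.Fire(main)` turns `main(data_path)` into a CLI, so the script can be run as `python make_vocab.py --data_path=/path/to/captions` (the path is a placeholder). `build_vocab` expects `data_path` to hold one sub-directory per class containing caption text files, one caption per line. The sketch below, run from the same directory as the script, is the programmatic equivalent and shows what the resulting pickle contains; the `ds_lavis` copy of the script that follows is identical.

# Programmatic equivalent of `python make_vocab.py --data_path=/path/to/captions`.
# '/path/to/captions' is a placeholder directory of per-class caption files.
import pickle

from make_vocab import build_vocab

dict_vocab = build_vocab('/path/to/captions', threshold=4)
with open('./vocab_local.pkl', 'wb') as f:
    pickle.dump(dict_vocab, f, pickle.HIGHEST_PROTOCOL)

# The pickle is a plain dict, inspectable without the Vocabulary class:
with open('./vocab_local.pkl', 'rb') as f:
    d = pickle.load(f)
print(d['idx'])                # vocabulary size (the next free index)
print(d['word2idx']['<unk>'])  # 3: the special tokens occupy indices 0-3

--------------------------------------------------------------------------------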
/src/ds_lavis/vocabs/make_vocab.py:
--------------------------------------------------------------------------------
1 | """ a script for making vocabulary pickle.
2 | 
3 | Original code:
4 | https://github.com/yalesong/pvse/blob/master/vocab.py
5 | """
6 | 
7 | import nltk
8 | import pickle
9 | from collections import Counter
10 | import fire
11 | import os
12 | from tqdm import tqdm
13 | 
14 | 
15 | class Vocabulary(object):
16 |     """Simple vocabulary wrapper."""
17 | 
18 |     def __init__(self):
19 |         self.word2idx = {}
20 |         self.idx2word = {}
21 |         self.idx = 0
22 | 
23 |     def add_word(self, word):
24 |         if word not in self.word2idx:
25 |             self.word2idx[word] = self.idx
26 |             self.idx2word[self.idx] = word
27 |             self.idx += 1
28 | 
29 |     def __call__(self, word):
30 |         if word not in self.word2idx:
31 |             return self.word2idx['<unk>']
32 |         return self.word2idx[word]
33 | 
34 |     def __len__(self):
35 |         return len(self.word2idx)
36 | 
37 | 
38 | def from_txt(txt):
39 |     captions = []
40 |     with open(txt, 'rb') as f:
41 |         for line in f:
42 |             captions.append(line.strip())
43 |     return captions
44 | 
45 | 
46 | def build_vocab(data_path, threshold):
47 |     """Build a simple vocabulary wrapper."""
48 |     counter = Counter()
49 |     captions = []
50 |     for cname in os.listdir(data_path):
51 |         for fname in os.listdir(os.path.join(data_path, cname)):
52 |             full_path = os.path.join(data_path, cname, fname)
53 |             captions.extend(from_txt(full_path))
54 | 
55 |     for i, caption in tqdm(enumerate(captions), total=len(captions)):
56 |         tokens = nltk.tokenize.word_tokenize(
57 |             caption.lower().decode('utf-8'))
58 |         counter.update(tokens)
59 | 
60 |     # Discard if the occurrence of the word is less than min_word_cnt.
61 |     words = [word for word, cnt in counter.items() if cnt >= threshold]
62 | 
63 |     # Create a vocab wrapper and add some special tokens.
64 |     vocab = Vocabulary()
65 |     vocab.add_word('<pad>')
66 |     vocab.add_word('<start>')
67 |     vocab.add_word('<end>')
68 |     vocab.add_word('<unk>')
69 | 
70 |     # Add words to the vocabulary.
71 |     for i, word in enumerate(words):
72 |         vocab.add_word(word)
73 |     dict_vocab = {
74 |         'idx': vocab.idx,
75 |         'idx2word': vocab.idx2word,
76 |         'word2idx': vocab.word2idx,
77 |     }
78 |     return dict_vocab
79 | 
80 | 
81 | def main(data_path):
82 |     vocab = build_vocab(data_path, threshold=4)
83 |     with open('./vocab_local.pkl', 'wb') as f:
84 |         pickle.dump(vocab, f, pickle.HIGHEST_PROTOCOL)
85 | 
86 | 
87 | if __name__ == '__main__':
88 |     fire.Fire(main)
89 | 
--------------------------------------------------------------------------------
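Note: the file that follows, `src/losses.py`, implements the training objective used by the notebook at the end of this dump. For reference, `GenGaussLoss` follows the negative log-likelihood of a generalized Gaussian density with mean \mu, scale \alpha and shape \beta:

p(x \mid \mu, \alpha, \beta) = \frac{\beta}{2\alpha\,\Gamma(1/\beta)} \exp\!\left(-\left(\frac{|x-\mu|}{\alpha}\right)^{\beta}\right),
\qquad
-\log p = \left(\frac{|x-\mu|}{\alpha}\right)^{\beta} - \log\frac{1}{\alpha} + \log\Gamma(1/\beta) - \log\beta + \log 2 .

Up to the constant \log 2, this is the line `l = resi - log_one_over_alpha + lgamma_beta - log_beta` in `GenGaussLoss`. Note, though, that the shipped code computes the residual term as the clamped product |x-\mu| \cdot (1/\alpha) \cdot \beta rather than the power form, which survives only in the commented-out line above it. `TempCombLoss` then mixes this NLL with a plain L1 term as `T1*l1 + T2*l2`.

--------------------------------------------------------------------------------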
/src/losses.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | import torchvision.models as models
5 | from torch import Tensor
6 | 
7 | class GenGaussLoss(nn.Module):
8 |     def __init__(
9 |         self, reduction='mean',
10 |         alpha_eps = 1e-4, beta_eps=1e-4,
11 |         resi_min = 1e-4, resi_max=1e3
12 |     ) -> None:
13 |         super(GenGaussLoss, self).__init__()
14 |         self.reduction = reduction
15 |         self.alpha_eps = alpha_eps
16 |         self.beta_eps = beta_eps
17 |         self.resi_min = resi_min
18 |         self.resi_max = resi_max
19 | 
20 |     def forward(
21 |         self,
22 |         mean: Tensor, one_over_alpha: Tensor, beta: Tensor, target: Tensor
23 |     ):
24 |         one_over_alpha1 = one_over_alpha + self.alpha_eps
25 |         beta1 = beta + self.beta_eps
26 | 
27 |         resi = torch.abs(mean - target)
28 |         # resi = torch.pow(resi*one_over_alpha1, beta1).clamp(min=self.resi_min, max=self.resi_max)
29 |         resi = (resi*one_over_alpha1*beta1).clamp(min=self.resi_min, max=self.resi_max)
30 |         ## check if resi has nans
31 |         if torch.sum(resi != resi) > 0:
32 |             print('resi has nans!!')
33 |             return None
34 | 
35 |         log_one_over_alpha = torch.log(one_over_alpha1)
36 |         log_beta = torch.log(beta1)
37 |         lgamma_beta = torch.lgamma(torch.pow(beta1, -1))
38 | 
39 |         if torch.sum(log_one_over_alpha != log_one_over_alpha) > 0:
40 |             print('log_one_over_alpha has nan')
41 |         if torch.sum(lgamma_beta != lgamma_beta) > 0:
42 |             print('lgamma_beta has nan')
43 |         if torch.sum(log_beta != log_beta) > 0:
44 |             print('log_beta has nan')
45 | 
46 |         l = resi - log_one_over_alpha + lgamma_beta - log_beta
47 | 
48 |         if self.reduction == 'mean':
49 |             return l.mean()
50 |         elif self.reduction == 'sum':
51 |             return l.sum()
52 |         else:
53 |             print('Reduction not supported')
54 |             return None
55 | 
56 | class TempCombLoss(nn.Module):
57 |     def __init__(
58 |         self, reduction='mean',
59 |         alpha_eps = 1e-4, beta_eps=1e-4,
60 |         resi_min = 1e-4, resi_max=1e3
61 |     ) -> None:
62 |         super(TempCombLoss, self).__init__()
63 |         self.reduction = reduction
64 |         self.alpha_eps = alpha_eps
65 |         self.beta_eps = beta_eps
66 |         self.resi_min = resi_min
67 |         self.resi_max = resi_max
68 | 
69 |         self.L_GenGauss = GenGaussLoss(
70 |             reduction=self.reduction,
71 |             alpha_eps=self.alpha_eps, beta_eps=self.beta_eps,
72 |             resi_min=self.resi_min, resi_max=self.resi_max
73 |         )
74 |         self.L_l1 = nn.L1Loss(reduction=self.reduction)
75 | 
76 |     def forward(
77 |         self,
78 |         mean: Tensor, one_over_alpha: Tensor, beta: Tensor, target: Tensor,
79 |         T1: float, T2: float
80 |     ):
81 |         l1 = self.L_l1(mean, target)
82 |         l2 = self.L_GenGauss(mean, one_over_alpha, beta, target)
83 |         l = T1*l1 + T2*l2
84 | 
85 |         return l
86 | 
87 | 
88 | # x1 = torch.randn(4,3,32,32)
89 | # x2 = torch.rand(4,3,32,32)
90 | # x3 = torch.rand(4,3,32,32)
91 | # x4 = torch.randn(4,3,32,32)
92 | 
93 | # L = GenGaussLoss(alpha_eps=1e-4, beta_eps=1e-4, resi_min=1e-4, resi_max=1e3)
94 | # L2 = TempCombLoss(alpha_eps=1e-4, beta_eps=1e-4, resi_min=1e-4, resi_max=1e3)
95 | # print(L(x1, x2, x3, x4), L2(x1, x2, x3, x4, 1e0, 1e-2))
--------------------------------------------------------------------------------
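Note: the commented-out test at the bottom of `losses.py` exercises both losses on image-shaped tensors. The sketch below does the same on 512-dimensional embedding batches, matching the `out_dim=512` adapter and the temperatures `T1=1e0`, `T2=1e-4` used in the notebook that follows. The positive `1/alpha` and `beta` inputs are simulated here with `torch.rand`, standing in for the outputs of the BayesCap heads.

# Runnable sketch (run from src/ so that losses.py is importable).
import torch
from losses import GenGaussLoss, TempCombLoss

B, D = 64, 512                     # a batch of CLIP-sized embeddings
mean = torch.randn(B, D)           # predicted mean embedding
one_over_alpha = torch.rand(B, D)  # predicted inverse scale (must be positive)
beta = torch.rand(B, D)            # predicted shape parameter (must be positive)
target = torch.randn(B, D)         # frozen CLIP embedding to match

nll = GenGaussLoss()(mean, one_over_alpha, beta, target)
comb = TempCombLoss()(mean, one_over_alpha, beta, target, T1=1e0, T2=1e-4)
print(nll.item(), comb.item())

--------------------------------------------------------------------------------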
resi_min=1e-4, resi_max=1e3) 95 | # print(L(x1, x2, x3, x4), L2(x1, x2, x3, x4, 1e0, 1e-2)) -------------------------------------------------------------------------------- /src/train_ProbVLM_CLIP.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "id": "2986807e-7850-404f-a957-eaeb16371d24", 7 | "metadata": {}, 8 | "outputs": [], 9 | "source": [ 10 | "import os\n", 11 | "\n", 12 | "from os.path import join as ospj\n", 13 | "from os.path import expanduser\n", 14 | "from munch import Munch as mch\n", 15 | "import numpy as np\n", 16 | "\n", 17 | "from ds import prepare_coco_dataloaders, prepare_flickr_dataloaders, prepare_cub_dataloaders, prepare_flo_dataloaders\n", 18 | "\n", 19 | "from utils import *\n", 20 | "from networks import *\n", 21 | "from train_ProbVLM import *\n", 22 | "\n", 23 | "import matplotlib.pyplot as plt" 24 | ] 25 | }, 26 | { 27 | "cell_type": "code", 28 | "execution_count": null, 29 | "id": "fbdb7f39-293e-48fe-bd3b-da616157f4a9", 30 | "metadata": {}, 31 | "outputs": [], 32 | "source": [ 33 | "dataset = 'CUB' # coco or flickr\n", 34 | "data_dir = ospj('/mnt/Datasets/', dataset) # e.g. ospj(expanduser('~'), 'Documents', 'jm', 'data', dataset)\n", 35 | "dataloader_config = mch({\n", 36 | " 'batch_size': 64,\n", 37 | " 'random_erasing_prob': 0.,\n", 38 | " 'traindata_shuffle': True\n", 39 | "})\n", 40 | "loaders,vocab = load_data_loader(dataset, data_dir, dataloader_config)\n", 41 | "cub_train_loader, cub_valid_loader, cub_test_loader = loaders['train'], loaders['val'], loaders['test']" 42 | ] 43 | }, 44 | { 45 | "cell_type": "code", 46 | "execution_count": null, 47 | "id": "7da29997-5483-4fa0-b927-74b38dc36cdf", 48 | "metadata": {}, 49 | "outputs": [], 50 | "source": [ 51 | "# clip_net = load_model('cuda')\n", 52 | "CLIP_Net = load_model(device='cuda', model_path=None)\n", 53 | "ProbVLM_Net = BayesCap_for_CLIP(\n", 54 | " inp_dim=512,\n", 55 | " out_dim=512,\n", 56 | " hid_dim=256,\n", 57 | " num_layers=3,\n", 58 | " p_drop=0.05,\n", 59 | ")" 60 | ] 61 | }, 62 | { 63 | "cell_type": "code", 64 | "execution_count": null, 65 | "id": "e1513f40-28d1-4d4d-8d9b-e113c8cd4183", 66 | "metadata": {}, 67 | "outputs": [], 68 | "source": [ 69 | "train_ProbVLM(\n", 70 | " CLIP_Net,\n", 71 | " ProbVLM_Net,\n", 72 | " cub_train_loader,\n", 73 | " cub_valid_loader,\n", 74 | " Cri = TempCombLoss(),\n", 75 | " device='cuda',\n", 76 | " dtype=torch.cuda.FloatTensor,\n", 77 | " init_lr=8e-5,\n", 78 | " num_epochs=500,\n", 79 | " eval_every=5,\n", 80 | " ckpt_path='../ckpt/ProbVLM_Net',\n", 81 | " T1=1e0,\n", 82 | " T2=1e-4\n", 83 | ")" 84 | ] 85 | } 86 | ], 87 | "metadata": { 88 | "kernelspec": { 89 | "display_name": "Python 3 (ipykernel)", 90 | "language": "python", 91 | "name": "python3" 92 | }, 93 | "language_info": { 94 | "codemirror_mode": { 95 | "name": "ipython", 96 | "version": 3 97 | }, 98 | "file_extension": ".py", 99 | "mimetype": "text/x-python", 100 | "name": "python", 101 | "nbconvert_exporter": "python", 102 | "pygments_lexer": "ipython3", 103 | "version": "3.10.5" 104 | } 105 | }, 106 | "nbformat": 4, 107 | "nbformat_minor": 5 108 | } 109 | -------------------------------------------------------------------------------- /src/ds/annotations/cub/trainvalclasses.txt: -------------------------------------------------------------------------------- 1 | 002.Laysan_Albatross 2 | 003.Sooty_Albatross 3 | 015.Lazuli_Bunting 4 | 016.Painted_Bunting 5 | 
020.Yellow_breasted_Chat 6 | 022.Chuck_will_Widow 7 | 047.American_Goldfinch 8 | 048.European_Goldfinch 9 | 067.Anna_Hummingbird 10 | 068.Ruby_throated_Hummingbird 11 | 069.Rufous_Hummingbird 12 | 073.Blue_Jay 13 | 074.Florida_Jay 14 | 075.Green_Jay 15 | 076.Dark_eyed_Junco 16 | 089.Hooded_Merganser 17 | 090.Red_breasted_Merganser 18 | 100.Brown_Pelican 19 | 149.Brown_Thrasher 20 | 150.Sage_Thrasher 21 | 001.Black_footed_Albatross 22 | 014.Indigo_Bunting 23 | 034.Gray_crowned_Rosy_Finch 24 | 035.Purple_Finch 25 | 101.White_Pelican 26 | 120.Fox_Sparrow 27 | 110.Geococcyx 28 | 085.Horned_Lark 29 | 008.Rhinoceros_Auklet 30 | 054.Blue_Grosbeak 31 | 171.Myrtle_Warbler 32 | 041.Scissor_tailed_Flycatcher 33 | 056.Pine_Grosbeak 34 | 109.American_Redstart 35 | 187.American_Three_toed_Woodpecker 36 | 175.Pine_Warbler 37 | 030.Fish_Crow 38 | 195.Carolina_Wren 39 | 051.Horned_Grebe 40 | 107.Common_Raven 41 | 117.Clay_colored_Sparrow 42 | 135.Bank_Swallow 43 | 134.Cape_Glossy_Starling 44 | 046.Gadwall 45 | 147.Least_Tern 46 | 037.Acadian_Flycatcher 47 | 160.Black_throated_Blue_Warbler 48 | 126.Nelson_Sharp_tailed_Sparrow 49 | 137.Cliff_Swallow 50 | 125.Lincoln_Sparrow 51 | 027.Shiny_Cowbird 52 | 189.Red_bellied_Woodpecker 53 | 197.Marsh_Wren 54 | 186.Cedar_Waxwing 55 | 158.Bay_breasted_Warbler 56 | 064.Ring_billed_Gull 57 | 044.Frigatebird 58 | 007.Parakeet_Auklet 59 | 183.Northern_Waterthrush 60 | 142.Black_Tern 61 | 086.Pacific_Loon 62 | 159.Black_and_white_Warbler 63 | 081.Pied_Kingfisher 64 | 128.Seaside_Sparrow 65 | 011.Rusty_Blackbird 66 | 145.Elegant_Tern 67 | 029.American_Crow 68 | 166.Golden_winged_Warbler 69 | 059.California_Gull 70 | 095.Baltimore_Oriole 71 | 155.Warbling_Vireo 72 | 010.Red_winged_Blackbird 73 | 141.Artic_Tern 74 | 102.Western_Wood_Pewee 75 | 079.Belted_Kingfisher 76 | 106.Horned_Puffin 77 | 192.Downy_Woodpecker 78 | 005.Crested_Auklet 79 | 083.White_breasted_Kingfisher 80 | 178.Swainson_Warbler 81 | 162.Canada_Warbler 82 | 194.Cactus_Wren 83 | 196.House_Wren 84 | 050.Eared_Grebe 85 | 039.Least_Flycatcher 86 | 105.Whip_poor_Will 87 | 036.Northern_Flicker 88 | 032.Mangrove_Cuckoo 89 | 146.Forsters_Tern 90 | 082.Ringed_Kingfisher 91 | 060.Glaucous_winged_Gull 92 | 144.Common_Tern 93 | 199.Winter_Wren 94 | 093.Clark_Nutcracker 95 | 198.Rock_Wren 96 | 066.Western_Gull 97 | 099.Ovenbird 98 | 053.Western_Grebe 99 | 151.Black_capped_Vireo 100 | 018.Spotted_Catbird 101 | 152.Blue_headed_Vireo 102 | 116.Chipping_Sparrow 103 | 061.Heermann_Gull 104 | 025.Pelagic_Cormorant 105 | 024.Red_faced_Cormorant 106 | 078.Gray_Kingbird 107 | 017.Cardinal 108 | 176.Prairie_Warbler 109 | 058.Pigeon_Guillemot 110 | 021.Eastern_Towhee 111 | 193.Bewick_Wren 112 | 057.Rose_breasted_Grosbeak 113 | 040.Olive_sided_Flycatcher 114 | 153.Philadelphia_Vireo 115 | 088.Western_Meadowlark 116 | 013.Bobolink 117 | 118.House_Sparrow 118 | 121.Grasshopper_Sparrow 119 | 179.Tennessee_Warbler 120 | 062.Herring_Gull 121 | 154.Red_eyed_Vireo 122 | 092.Nighthawk 123 | 038.Great_Crested_Flycatcher 124 | 140.Summer_Tanager 125 | 182.Yellow_Warbler 126 | 096.Hooded_Oriole 127 | 172.Nashville_Warbler 128 | 071.Long_tailed_Jaeger 129 | 042.Vermilion_Flycatcher 130 | 185.Bohemian_Waxwing 131 | 177.Prothonotary_Warbler 132 | 019.Gray_Catbird 133 | 065.Slaty_backed_Gull 134 | 009.Brewer_Blackbird 135 | 112.Great_Grey_Shrike 136 | 063.Ivory_Gull 137 | 006.Least_Auklet 138 | 080.Green_Kingfisher 139 | 181.Worm_eating_Warbler 140 | 108.White_necked_Raven 141 | 122.Harris_Sparrow 142 | 115.Brewer_Sparrow 143 | 
184.Louisiana_Waterthrush 144 | 167.Hooded_Warbler 145 | 129.Song_Sparrow 146 | 133.White_throated_Sparrow 147 | 114.Black_throated_Sparrow 148 | 170.Mourning_Warbler 149 | 131.Vesper_Sparrow 150 | 174.Palm_Warbler 151 | -------------------------------------------------------------------------------- /src/ds_lavis/annotations/cub/trainvalclasses.txt: -------------------------------------------------------------------------------- 1 | 002.Laysan_Albatross 2 | 003.Sooty_Albatross 3 | 015.Lazuli_Bunting 4 | 016.Painted_Bunting 5 | 020.Yellow_breasted_Chat 6 | 022.Chuck_will_Widow 7 | 047.American_Goldfinch 8 | 048.European_Goldfinch 9 | 067.Anna_Hummingbird 10 | 068.Ruby_throated_Hummingbird 11 | 069.Rufous_Hummingbird 12 | 073.Blue_Jay 13 | 074.Florida_Jay 14 | 075.Green_Jay 15 | 076.Dark_eyed_Junco 16 | 089.Hooded_Merganser 17 | 090.Red_breasted_Merganser 18 | 100.Brown_Pelican 19 | 149.Brown_Thrasher 20 | 150.Sage_Thrasher 21 | 001.Black_footed_Albatross 22 | 014.Indigo_Bunting 23 | 034.Gray_crowned_Rosy_Finch 24 | 035.Purple_Finch 25 | 101.White_Pelican 26 | 120.Fox_Sparrow 27 | 110.Geococcyx 28 | 085.Horned_Lark 29 | 008.Rhinoceros_Auklet 30 | 054.Blue_Grosbeak 31 | 171.Myrtle_Warbler 32 | 041.Scissor_tailed_Flycatcher 33 | 056.Pine_Grosbeak 34 | 109.American_Redstart 35 | 187.American_Three_toed_Woodpecker 36 | 175.Pine_Warbler 37 | 030.Fish_Crow 38 | 195.Carolina_Wren 39 | 051.Horned_Grebe 40 | 107.Common_Raven 41 | 117.Clay_colored_Sparrow 42 | 135.Bank_Swallow 43 | 134.Cape_Glossy_Starling 44 | 046.Gadwall 45 | 147.Least_Tern 46 | 037.Acadian_Flycatcher 47 | 160.Black_throated_Blue_Warbler 48 | 126.Nelson_Sharp_tailed_Sparrow 49 | 137.Cliff_Swallow 50 | 125.Lincoln_Sparrow 51 | 027.Shiny_Cowbird 52 | 189.Red_bellied_Woodpecker 53 | 197.Marsh_Wren 54 | 186.Cedar_Waxwing 55 | 158.Bay_breasted_Warbler 56 | 064.Ring_billed_Gull 57 | 044.Frigatebird 58 | 007.Parakeet_Auklet 59 | 183.Northern_Waterthrush 60 | 142.Black_Tern 61 | 086.Pacific_Loon 62 | 159.Black_and_white_Warbler 63 | 081.Pied_Kingfisher 64 | 128.Seaside_Sparrow 65 | 011.Rusty_Blackbird 66 | 145.Elegant_Tern 67 | 029.American_Crow 68 | 166.Golden_winged_Warbler 69 | 059.California_Gull 70 | 095.Baltimore_Oriole 71 | 155.Warbling_Vireo 72 | 010.Red_winged_Blackbird 73 | 141.Artic_Tern 74 | 102.Western_Wood_Pewee 75 | 079.Belted_Kingfisher 76 | 106.Horned_Puffin 77 | 192.Downy_Woodpecker 78 | 005.Crested_Auklet 79 | 083.White_breasted_Kingfisher 80 | 178.Swainson_Warbler 81 | 162.Canada_Warbler 82 | 194.Cactus_Wren 83 | 196.House_Wren 84 | 050.Eared_Grebe 85 | 039.Least_Flycatcher 86 | 105.Whip_poor_Will 87 | 036.Northern_Flicker 88 | 032.Mangrove_Cuckoo 89 | 146.Forsters_Tern 90 | 082.Ringed_Kingfisher 91 | 060.Glaucous_winged_Gull 92 | 144.Common_Tern 93 | 199.Winter_Wren 94 | 093.Clark_Nutcracker 95 | 198.Rock_Wren 96 | 066.Western_Gull 97 | 099.Ovenbird 98 | 053.Western_Grebe 99 | 151.Black_capped_Vireo 100 | 018.Spotted_Catbird 101 | 152.Blue_headed_Vireo 102 | 116.Chipping_Sparrow 103 | 061.Heermann_Gull 104 | 025.Pelagic_Cormorant 105 | 024.Red_faced_Cormorant 106 | 078.Gray_Kingbird 107 | 017.Cardinal 108 | 176.Prairie_Warbler 109 | 058.Pigeon_Guillemot 110 | 021.Eastern_Towhee 111 | 193.Bewick_Wren 112 | 057.Rose_breasted_Grosbeak 113 | 040.Olive_sided_Flycatcher 114 | 153.Philadelphia_Vireo 115 | 088.Western_Meadowlark 116 | 013.Bobolink 117 | 118.House_Sparrow 118 | 121.Grasshopper_Sparrow 119 | 179.Tennessee_Warbler 120 | 062.Herring_Gull 121 | 154.Red_eyed_Vireo 122 | 092.Nighthawk 123 | 
038.Great_Crested_Flycatcher 124 | 140.Summer_Tanager 125 | 182.Yellow_Warbler 126 | 096.Hooded_Oriole 127 | 172.Nashville_Warbler 128 | 071.Long_tailed_Jaeger 129 | 042.Vermilion_Flycatcher 130 | 185.Bohemian_Waxwing 131 | 177.Prothonotary_Warbler 132 | 019.Gray_Catbird 133 | 065.Slaty_backed_Gull 134 | 009.Brewer_Blackbird 135 | 112.Great_Grey_Shrike 136 | 063.Ivory_Gull 137 | 006.Least_Auklet 138 | 080.Green_Kingfisher 139 | 181.Worm_eating_Warbler 140 | 108.White_necked_Raven 141 | 122.Harris_Sparrow 142 | 115.Brewer_Sparrow 143 | 184.Louisiana_Waterthrush 144 | 167.Hooded_Warbler 145 | 129.Song_Sparrow 146 | 133.White_throated_Sparrow 147 | 114.Black_throated_Sparrow 148 | 170.Mourning_Warbler 149 | 131.Vesper_Sparrow 150 | 174.Palm_Warbler 151 | -------------------------------------------------------------------------------- /src/ds/vocab.py: -------------------------------------------------------------------------------- 1 | """ Create a vocabulary wrapper. 2 | 3 | Original code: 4 | https://github.com/yalesong/pvse/blob/master/vocab.py 5 | """ 6 | 7 | from collections import Counter 8 | import json 9 | import os 10 | import pickle 11 | 12 | import fire 13 | from nltk.tokenize import word_tokenize 14 | from pycocotools.coco import COCO 15 | 16 | ANNOTATIONS = { 17 | 'mrw': ['mrw-v1.0.json'], 18 | 'tgif': ['tgif-v1.0.tsv'], 19 | 'coco': ['annotations/captions_train2014.json', 20 | 'annotations/captions_val2014.json'], 21 | } 22 | 23 | 24 | class Vocabulary(object): 25 | """Simple vocabulary wrapper.""" 26 | 27 | def __init__(self): 28 | self.idx = 0 29 | self.word2idx = {} 30 | self.idx2word = {} 31 | 32 | def add_word(self, word): 33 | if word not in self.word2idx: 34 | self.word2idx[word] = self.idx 35 | self.idx2word[self.idx] = word 36 | self.idx += 1 37 | 38 | def load_from_pickle(self, data_path): 39 | with open(data_path, 'rb') as fin: 40 | data = pickle.load(fin) 41 | self.idx = data['idx'] 42 | self.word2idx = data['word2idx'] 43 | self.idx2word = data['idx2word'] 44 | 45 | def __call__(self, word): 46 | if word not in self.word2idx: 47 | return self.word2idx['<unk>'] 48 | return self.word2idx[word] 49 | 50 | def __len__(self): 51 | return len(self.word2idx) 52 | 53 | 54 | def from_tgif_tsv(path): 55 | captions = [line.strip().split('\t')[1] 56 | for line in open(path, 'r').readlines()] 57 | return captions 58 | 59 | 60 | def from_mrw_json(path): 61 | dataset = json.load(open(path, 'r')) 62 | captions = [] 63 | for datum in dataset: 64 | cap = datum['sentence'] 65 | cap = cap.replace('/r/', '') 66 | cap = cap.replace('r/', '') 67 | cap = cap.replace('/u/', '') 68 | cap = cap.replace('u/', '') 69 | cap = cap.replace('..', '') 70 | cap = cap.replace('/', ' ') 71 | cap = cap.replace('-', ' ') 72 | captions += [cap] 73 | return captions 74 | 75 | 76 | def from_coco_json(path): 77 | coco = COCO(path) 78 | ids = coco.anns.keys() 79 | captions = [] 80 | for idx in ids: 81 | captions.append(str(coco.anns[idx]['caption'])) 82 | 83 | return captions 84 | 85 | 86 | def from_txt(txt): 87 | captions = [] 88 | with open(txt, 'rb') as f: 89 | for line in f: 90 | captions.append(line.strip()) 91 | return captions 92 | 93 | 94 | def build_vocab(data_path, data_name, jsons, threshold): 95 | """Build a simple vocabulary wrapper.""" 96 | counter = Counter() 97 | for path in jsons[data_name]: 98 | full_path = os.path.join(os.path.join(data_path, data_name), path) 99 | if data_name == 'tgif': 100 | captions = from_tgif_tsv(full_path) 101 | elif data_name == 'mrw': 102 | captions = 
from_mrw_json(full_path) 103 | elif data_name == 'coco': 104 | captions = from_coco_json(full_path) 105 | else: 106 | captions = from_txt(full_path) 107 | 108 | for caption in captions: 109 | tokens = word_tokenize(caption.lower()) 110 | counter.update(tokens) 111 | 112 | # Discard a word if it occurs fewer than `threshold` times. 113 | words = [word for word, cnt in counter.items() if cnt >= threshold] 114 | print('Vocabulary size: {}'.format(len(words))) 115 | 116 | # Create a vocab wrapper and add some special tokens. 117 | vocab = Vocabulary() 118 | vocab.add_word('<pad>') 119 | vocab.add_word('<start>') 120 | vocab.add_word('<end>') 121 | vocab.add_word('<unk>') 122 | 123 | # Add words to the vocabulary. 124 | for word in words: 125 | vocab.add_word(word) 126 | return vocab 127 | 128 | 129 | def main(data_path, data_name, threshold=0): 130 | vocab = build_vocab(data_path, data_name, jsons=ANNOTATIONS, threshold=threshold) 131 | if not os.path.isdir('./vocab'): 132 | os.makedirs('./vocab') 133 | with open('./vocab/%s_vocab.pkl' % data_name, 'wb') as f: 134 | pickle.dump(vocab, f, pickle.HIGHEST_PROTOCOL) 135 | print("Saved vocabulary file to ", './vocab/%s_vocab.pkl' % data_name) 136 | 137 | 138 | if __name__ == '__main__': 139 | fire.Fire(main) 140 | -------------------------------------------------------------------------------- /src/ds_lavis/vocab.py: -------------------------------------------------------------------------------- 1 | """ Create a vocabulary wrapper. 2 | 3 | Original code: 4 | https://github.com/yalesong/pvse/blob/master/vocab.py 5 | """ 6 | 7 | from collections import Counter 8 | import json 9 | import os 10 | import pickle 11 | 12 | import fire 13 | from nltk.tokenize import word_tokenize 14 | from pycocotools.coco import COCO 15 | 16 | ANNOTATIONS = { 17 | 'mrw': ['mrw-v1.0.json'], 18 | 'tgif': ['tgif-v1.0.tsv'], 19 | 'coco': ['annotations/captions_train2014.json', 20 | 'annotations/captions_val2014.json'], 21 | } 22 | 23 | 24 | class Vocabulary(object): 25 | """Simple vocabulary wrapper.""" 26 | 27 | def __init__(self): 28 | self.idx = 0 29 | self.word2idx = {} 30 | self.idx2word = {} 31 | 32 | def add_word(self, word): 33 | if word not in self.word2idx: 34 | self.word2idx[word] = self.idx 35 | self.idx2word[self.idx] = word 36 | self.idx += 1 37 | 38 | def load_from_pickle(self, data_path): 39 | with open(data_path, 'rb') as fin: 40 | data = pickle.load(fin) 41 | self.idx = data['idx'] 42 | self.word2idx = data['word2idx'] 43 | self.idx2word = data['idx2word'] 44 | 45 | def __call__(self, word): 46 | if word not in self.word2idx: 47 | return self.word2idx['<unk>'] 48 | return self.word2idx[word] 49 | 50 | def __len__(self): 51 | return len(self.word2idx) 52 | 53 | 54 | def from_tgif_tsv(path): 55 | captions = [line.strip().split('\t')[1] 56 | for line in open(path, 'r').readlines()] 57 | return captions 58 | 59 | 60 | def from_mrw_json(path): 61 | dataset = json.load(open(path, 'r')) 62 | captions = [] 63 | for datum in dataset: 64 | cap = datum['sentence'] 65 | cap = cap.replace('/r/', '') 66 | cap = cap.replace('r/', '') 67 | cap = cap.replace('/u/', '') 68 | cap = cap.replace('u/', '') 69 | cap = cap.replace('..', '') 70 | cap = cap.replace('/', ' ') 71 | cap = cap.replace('-', ' ') 72 | captions += [cap] 73 | return captions 74 | 75 | 76 | def from_coco_json(path): 77 | coco = COCO(path) 78 | ids = coco.anns.keys() 79 | captions = [] 80 | for idx in ids: 81 | captions.append(str(coco.anns[idx]['caption'])) 82 | 83 | return captions 84 | 85 | 86 | def from_txt(txt): 87 | 
captions = [] 88 | with open(txt, 'rb') as f: 89 | for line in f: 90 | captions.append(line.strip()) 91 | return captions 92 | 93 | 94 | def build_vocab(data_path, data_name, jsons, threshold): 95 | """Build a simple vocabulary wrapper.""" 96 | counter = Counter() 97 | for path in jsons[data_name]: 98 | full_path = os.path.join(os.path.join(data_path, data_name), path) 99 | if data_name == 'tgif': 100 | captions = from_tgif_tsv(full_path) 101 | elif data_name == 'mrw': 102 | captions = from_mrw_json(full_path) 103 | elif data_name == 'coco': 104 | captions = from_coco_json(full_path) 105 | else: 106 | captions = from_txt(full_path) 107 | 108 | for caption in captions: 109 | tokens = word_tokenize(caption.lower()) 110 | counter.update(tokens) 111 | 112 | # Discard a word if it occurs fewer than `threshold` times. 113 | words = [word for word, cnt in counter.items() if cnt >= threshold] 114 | print('Vocabulary size: {}'.format(len(words))) 115 | 116 | # Create a vocab wrapper and add some special tokens. 117 | vocab = Vocabulary() 118 | vocab.add_word('<pad>') 119 | vocab.add_word('<start>') 120 | vocab.add_word('<end>') 121 | vocab.add_word('<unk>') 122 | 123 | # Add words to the vocabulary. 124 | for word in words: 125 | vocab.add_word(word) 126 | return vocab 127 | 128 | 129 | def main(data_path, data_name, threshold=0): 130 | vocab = build_vocab(data_path, data_name, jsons=ANNOTATIONS, threshold=threshold) 131 | if not os.path.isdir('./vocab'): 132 | os.makedirs('./vocab') 133 | with open('./vocab/%s_vocab.pkl' % data_name, 'wb') as f: 134 | pickle.dump(vocab, f, pickle.HIGHEST_PROTOCOL) 135 | print("Saved vocabulary file to ", './vocab/%s_vocab.pkl' % data_name) 136 | 137 | 138 | if __name__ == '__main__': 139 | fire.Fire(main) 140 | -------------------------------------------------------------------------------- /src/ds/simple_tokenizer.py: -------------------------------------------------------------------------------- 1 | import gzip 2 | import html 3 | import os 4 | from functools import lru_cache 5 | 6 | import ftfy 7 | import regex as re 8 | 9 | 10 | @lru_cache() 11 | def default_bpe(): 12 | return os.path.join(os.path.dirname(os.path.abspath(__file__)), "bpe_simple_vocab_16e6.txt.gz") 13 | 14 | 15 | @lru_cache() 16 | def bytes_to_unicode(): 17 | """ 18 | Returns list of utf-8 bytes and a corresponding list of unicode strings. 19 | The reversible bpe codes work on unicode strings. 20 | This means you need a large # of unicode characters in your vocab if you want to avoid UNKs. 21 | When you're at something like a 10B token dataset you end up needing around 5K for decent coverage. 22 | This is a significant percentage of your normal, say, 32K bpe vocab. 23 | To avoid that, we want lookup tables between utf-8 bytes and unicode strings. 24 | And avoids mapping to whitespace/control characters the bpe code barfs on. 25 | """ 26 | bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1)) 27 | cs = bs[:] 28 | n = 0 29 | for b in range(2**8): 30 | if b not in bs: 31 | bs.append(b) 32 | cs.append(2**8+n) 33 | n += 1 34 | cs = [chr(n) for n in cs] 35 | return dict(zip(bs, cs)) 36 | 37 | 38 | def get_pairs(word): 39 | """Return set of symbol pairs in a word. 40 | Word is represented as tuple of symbols (symbols being variable-length strings). 
41 | """ 42 | pairs = set() 43 | prev_char = word[0] 44 | for char in word[1:]: 45 | pairs.add((prev_char, char)) 46 | prev_char = char 47 | return pairs 48 | 49 | 50 | def basic_clean(text): 51 | text = ftfy.fix_text(text) 52 | text = html.unescape(html.unescape(text)) 53 | return text.strip() 54 | 55 | 56 | def whitespace_clean(text): 57 | text = re.sub(r'\s+', ' ', text) 58 | text = text.strip() 59 | return text 60 | 61 | 62 | class SimpleTokenizer(object): 63 | def __init__(self, bpe_path: str = default_bpe()): 64 | self.byte_encoder = bytes_to_unicode() 65 | self.byte_decoder = {v: k for k, v in self.byte_encoder.items()} 66 | merges = gzip.open(bpe_path).read().decode("utf-8").split('\n') 67 | merges = merges[1:49152-256-2+1] 68 | merges = [tuple(merge.split()) for merge in merges] 69 | vocab = list(bytes_to_unicode().values()) 70 | vocab = vocab + [v+'' for v in vocab] 71 | for merge in merges: 72 | vocab.append(''.join(merge)) 73 | vocab.extend(['<|startoftext|>', '<|endoftext|>']) 74 | self.encoder = dict(zip(vocab, range(len(vocab)))) 75 | self.decoder = {v: k for k, v in self.encoder.items()} 76 | self.bpe_ranks = dict(zip(merges, range(len(merges)))) 77 | self.cache = {'<|startoftext|>': '<|startoftext|>', '<|endoftext|>': '<|endoftext|>'} 78 | self.pat = re.compile(r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", re.IGNORECASE) 79 | 80 | def bpe(self, token): 81 | if token in self.cache: 82 | return self.cache[token] 83 | word = tuple(token[:-1]) + ( token[-1] + '',) 84 | pairs = get_pairs(word) 85 | 86 | if not pairs: 87 | return token+'' 88 | 89 | while True: 90 | bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf'))) 91 | if bigram not in self.bpe_ranks: 92 | break 93 | first, second = bigram 94 | new_word = [] 95 | i = 0 96 | while i < len(word): 97 | try: 98 | j = word.index(first, i) 99 | new_word.extend(word[i:j]) 100 | i = j 101 | except: 102 | new_word.extend(word[i:]) 103 | break 104 | 105 | if word[i] == first and i < len(word)-1 and word[i+1] == second: 106 | new_word.append(first+second) 107 | i += 2 108 | else: 109 | new_word.append(word[i]) 110 | i += 1 111 | new_word = tuple(new_word) 112 | word = new_word 113 | if len(word) == 1: 114 | break 115 | else: 116 | pairs = get_pairs(word) 117 | word = ' '.join(word) 118 | self.cache[token] = word 119 | return word 120 | 121 | def encode(self, text): 122 | bpe_tokens = [] 123 | text = whitespace_clean(basic_clean(text)).lower() 124 | for token in re.findall(self.pat, text): 125 | token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8')) 126 | bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' ')) 127 | return bpe_tokens 128 | 129 | def decode(self, tokens): 130 | text = ''.join([self.decoder[token] for token in tokens]) 131 | text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('', ' ') 132 | return text 133 | -------------------------------------------------------------------------------- /src/clip/simple_tokenizer.py: -------------------------------------------------------------------------------- 1 | import gzip 2 | import html 3 | import os 4 | from functools import lru_cache 5 | 6 | import ftfy 7 | import regex as re 8 | 9 | 10 | @lru_cache() 11 | def default_bpe(): 12 | return os.path.join(os.path.dirname(os.path.abspath(__file__)), "bpe_simple_vocab_16e6.txt.gz") 13 | 14 | 15 | @lru_cache() 16 | def bytes_to_unicode(): 17 | """ 18 | Returns list of 
utf-8 bytes and a corresponding list of unicode strings. 19 | The reversible bpe codes work on unicode strings. 20 | This means you need a large # of unicode characters in your vocab if you want to avoid UNKs. 21 | When you're at something like a 10B token dataset you end up needing around 5K for decent coverage. 22 | This is a significant percentage of your normal, say, 32K bpe vocab. 23 | To avoid that, we want lookup tables between utf-8 bytes and unicode strings. 24 | And avoids mapping to whitespace/control characters the bpe code barfs on. 25 | """ 26 | bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1)) 27 | cs = bs[:] 28 | n = 0 29 | for b in range(2**8): 30 | if b not in bs: 31 | bs.append(b) 32 | cs.append(2**8+n) 33 | n += 1 34 | cs = [chr(n) for n in cs] 35 | return dict(zip(bs, cs)) 36 | 37 | 38 | def get_pairs(word): 39 | """Return set of symbol pairs in a word. 40 | Word is represented as tuple of symbols (symbols being variable-length strings). 41 | """ 42 | pairs = set() 43 | prev_char = word[0] 44 | for char in word[1:]: 45 | pairs.add((prev_char, char)) 46 | prev_char = char 47 | return pairs 48 | 49 | 50 | def basic_clean(text): 51 | text = ftfy.fix_text(text) 52 | text = html.unescape(html.unescape(text)) 53 | return text.strip() 54 | 55 | 56 | def whitespace_clean(text): 57 | text = re.sub(r'\s+', ' ', text) 58 | text = text.strip() 59 | return text 60 | 61 | 62 | class SimpleTokenizer(object): 63 | def __init__(self, bpe_path: str = default_bpe()): 64 | self.byte_encoder = bytes_to_unicode() 65 | self.byte_decoder = {v: k for k, v in self.byte_encoder.items()} 66 | merges = gzip.open(bpe_path).read().decode("utf-8").split('\n') 67 | merges = merges[1:49152-256-2+1] 68 | merges = [tuple(merge.split()) for merge in merges] 69 | vocab = list(bytes_to_unicode().values()) 70 | vocab = vocab + [v+'</w>' for v in vocab] 71 | for merge in merges: 72 | vocab.append(''.join(merge)) 73 | vocab.extend(['<|startoftext|>', '<|endoftext|>']) 74 | self.encoder = dict(zip(vocab, range(len(vocab)))) 75 | self.decoder = {v: k for k, v in self.encoder.items()} 76 | self.bpe_ranks = dict(zip(merges, range(len(merges)))) 77 | self.cache = {'<|startoftext|>': '<|startoftext|>', '<|endoftext|>': '<|endoftext|>'} 78 | self.pat = re.compile(r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", re.IGNORECASE) 79 | 80 | def bpe(self, token): 81 | if token in self.cache: 82 | return self.cache[token] 83 | word = tuple(token[:-1]) + ( token[-1] + '</w>',) 84 | pairs = get_pairs(word) 85 | 86 | if not pairs: 87 | return token+'</w>' 88 | 89 | while True: 90 | bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf'))) 91 | if bigram not in self.bpe_ranks: 92 | break 93 | first, second = bigram 94 | new_word = [] 95 | i = 0 96 | while i < len(word): 97 | try: 98 | j = word.index(first, i) 99 | new_word.extend(word[i:j]) 100 | i = j 101 | except: 102 | new_word.extend(word[i:]) 103 | break 104 | 105 | if word[i] == first and i < len(word)-1 and word[i+1] == second: 106 | new_word.append(first+second) 107 | i += 2 108 | else: 109 | new_word.append(word[i]) 110 | i += 1 111 | new_word = tuple(new_word) 112 | word = new_word 113 | if len(word) == 1: 114 | break 115 | else: 116 | pairs = get_pairs(word) 117 | word = ' '.join(word) 118 | self.cache[token] = word 119 | return word 120 | 121 | def encode(self, text): 122 | bpe_tokens = [] 123 | text = whitespace_clean(basic_clean(text)).lower() 124 | 
for token in re.findall(self.pat, text): 125 | token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8')) 126 | bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' ')) 127 | return bpe_tokens 128 | 129 | def decode(self, tokens): 130 | text = ''.join([self.decoder[token] for token in tokens]) 131 | text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('</w>', ' ') 132 | return text 133 | -------------------------------------------------------------------------------- /src/ds_lavis/simple_tokenizer.py: -------------------------------------------------------------------------------- 1 | import gzip 2 | import html 3 | import os 4 | from functools import lru_cache 5 | 6 | import ftfy 7 | import regex as re 8 | 9 | 10 | @lru_cache() 11 | def default_bpe(): 12 | return os.path.join(os.path.dirname(os.path.abspath(__file__)), "bpe_simple_vocab_16e6.txt.gz") 13 | 14 | 15 | @lru_cache() 16 | def bytes_to_unicode(): 17 | """ 18 | Returns list of utf-8 bytes and a corresponding list of unicode strings. 19 | The reversible bpe codes work on unicode strings. 20 | This means you need a large # of unicode characters in your vocab if you want to avoid UNKs. 21 | When you're at something like a 10B token dataset you end up needing around 5K for decent coverage. 22 | This is a significant percentage of your normal, say, 32K bpe vocab. 23 | To avoid that, we want lookup tables between utf-8 bytes and unicode strings. 24 | And avoids mapping to whitespace/control characters the bpe code barfs on. 25 | """ 26 | bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1)) 27 | cs = bs[:] 28 | n = 0 29 | for b in range(2**8): 30 | if b not in bs: 31 | bs.append(b) 32 | cs.append(2**8+n) 33 | n += 1 34 | cs = [chr(n) for n in cs] 35 | return dict(zip(bs, cs)) 36 | 37 | 38 | def get_pairs(word): 39 | """Return set of symbol pairs in a word. 40 | Word is represented as tuple of symbols (symbols being variable-length strings). 
41 | """ 42 | pairs = set() 43 | prev_char = word[0] 44 | for char in word[1:]: 45 | pairs.add((prev_char, char)) 46 | prev_char = char 47 | return pairs 48 | 49 | 50 | def basic_clean(text): 51 | text = ftfy.fix_text(text) 52 | text = html.unescape(html.unescape(text)) 53 | return text.strip() 54 | 55 | 56 | def whitespace_clean(text): 57 | text = re.sub(r'\s+', ' ', text) 58 | text = text.strip() 59 | return text 60 | 61 | 62 | class SimpleTokenizer(object): 63 | def __init__(self, bpe_path: str = default_bpe()): 64 | self.byte_encoder = bytes_to_unicode() 65 | self.byte_decoder = {v: k for k, v in self.byte_encoder.items()} 66 | merges = gzip.open(bpe_path).read().decode("utf-8").split('\n') 67 | merges = merges[1:49152-256-2+1] 68 | merges = [tuple(merge.split()) for merge in merges] 69 | vocab = list(bytes_to_unicode().values()) 70 | vocab = vocab + [v+'' for v in vocab] 71 | for merge in merges: 72 | vocab.append(''.join(merge)) 73 | vocab.extend(['<|startoftext|>', '<|endoftext|>']) 74 | self.encoder = dict(zip(vocab, range(len(vocab)))) 75 | self.decoder = {v: k for k, v in self.encoder.items()} 76 | self.bpe_ranks = dict(zip(merges, range(len(merges)))) 77 | self.cache = {'<|startoftext|>': '<|startoftext|>', '<|endoftext|>': '<|endoftext|>'} 78 | self.pat = re.compile(r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", re.IGNORECASE) 79 | 80 | def bpe(self, token): 81 | if token in self.cache: 82 | return self.cache[token] 83 | word = tuple(token[:-1]) + ( token[-1] + '',) 84 | pairs = get_pairs(word) 85 | 86 | if not pairs: 87 | return token+'' 88 | 89 | while True: 90 | bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf'))) 91 | if bigram not in self.bpe_ranks: 92 | break 93 | first, second = bigram 94 | new_word = [] 95 | i = 0 96 | while i < len(word): 97 | try: 98 | j = word.index(first, i) 99 | new_word.extend(word[i:j]) 100 | i = j 101 | except: 102 | new_word.extend(word[i:]) 103 | break 104 | 105 | if word[i] == first and i < len(word)-1 and word[i+1] == second: 106 | new_word.append(first+second) 107 | i += 2 108 | else: 109 | new_word.append(word[i]) 110 | i += 1 111 | new_word = tuple(new_word) 112 | word = new_word 113 | if len(word) == 1: 114 | break 115 | else: 116 | pairs = get_pairs(word) 117 | word = ' '.join(word) 118 | self.cache[token] = word 119 | return word 120 | 121 | def encode(self, text): 122 | bpe_tokens = [] 123 | text = whitespace_clean(basic_clean(text)).lower() 124 | for token in re.findall(self.pat, text): 125 | token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8')) 126 | bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' ')) 127 | return bpe_tokens 128 | 129 | def decode(self, tokens): 130 | text = ''.join([self.decoder[token] for token in tokens]) 131 | text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('', ' ') 132 | return text 133 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # ProbVLM: Probabilistic Adapter for Frozen Vision-Language Models 2 | ## [Arxiv paper](https://arxiv.org/pdf/2307.00398.pdf) || [Blog](https://www.eml-unitue.de/publication/ProbVLM) 3 | 4 | ## Introduction 5 | ![probvlm teaser](./figs/probvlm.png) 6 | 7 | **Abstract.** Large-scale vision-language models (VLMs) like CLIP successfully find correspondences between images and 
text. Through the standard deterministic mapping process, an image or a text sample is mapped to a single vector in the embedding space. This is problematic: as multiple samples (images or text) can abstract the same concept in the physical world, deterministic embeddings do not reflect the inherent ambiguity in the embedding space. We propose **ProbVLM**, a probabilistic adapter that estimates probability distributions for the embeddings of pre-trained VLMs via inter/intra-modal alignment in a post-hoc manner without needing large-scale datasets or computing. On four challenging datasets, i.e., COCO, Flickr, CUB, and Oxford-flowers, we estimate the multi-modal embedding uncertainties for two VLMs, i.e., CLIP and BLIP, quantify the calibration of embedding uncertainties in retrieval tasks and show that **ProbVLM** outperforms other methods. Furthermore, we propose active learning and model selection as two real-world downstream tasks for VLMs and show that the estimated uncertainty aids both tasks. Lastly, we present a novel technique for visualizing the embedding distributions using a large-scale pre-trained latent diffusion model. 8 | 9 | ***TLDR:*** This is the official [PyTorch](https://pytorch.org/) implementation of ProbVLM (from *ICCV 2023*) that allows estimating calibrated uncertainty for pre-trained (frozen) vision-language models in a fast and efficient manner. 10 | 11 | The structure of the repository is as follows: 12 | ``` 13 | ProbVLM 14 | |-src/ (has the relevant code to train ProbVLM for a pretrained vision-language model, e.g., CLIP) 15 | |-requirements.txt (the environment dependencies) 16 | |-figs/ (has some example images) 17 | ``` 18 | 19 | ## Getting Started 20 | 21 | The `src/` directory already provides all the code to load, train, and evaluate the CLIP checkpoints along with the ProbVLM checkpoint. To get started, first ensure that you have all the requirements listed in `requirements.txt`; the environment can be set up by running 22 | ``` 23 | conda create --name <env_name> --file requirements.txt 24 | ``` 25 | 26 | The notebook `src/train_ProbVLM_CLIP.ipynb` shows how to load CLIP as the base model and how to train ProbVLM for it using the function 27 | ```python 28 | train_ProbVLM( 29 | CLIP_Net, 30 | ProbVLM_Net, 31 | cub_train_loader, 32 | cub_valid_loader, 33 | Cri = TempCombLoss(), 34 | device='cuda', 35 | dtype=torch.cuda.FloatTensor, 36 | init_lr=8e-5, 37 | num_epochs=500, 38 | eval_every=5, 39 | ckpt_path='../ckpt/ProbVLM_Net', 40 | T1=1e0, 41 | T2=1e-4 42 | ) 43 | ``` 44 | 45 | 46 | ## Dataset Setup 47 | To set up the COCO, CUB, Flickr, and Oxford-Flowers datasets, please follow the instructions below: 48 | 49 | COCO: Download the 2014 data [here](https://cocodataset.org/#home) and set up the directory in the following way 50 | ``` 51 | coco 52 | |-images/ 53 | |--train2014 54 | |--val2014 55 | |-captions_train2014.json 56 | |-captions_val2014.json 57 | 58 | ``` 59 | 60 | CUB: Download the CUB-200-2011 dataset [here](http://www.vision.caltech.edu/datasets/cub_200_2011/) and the captions from [https://github.com/reedscot/cvpr2016](https://github.com/reedscot/cvpr2016). 61 | 62 | Flowers: Download the images [here](https://www.robots.ox.ac.uk/~vgg/data/flowers/102/index.html) and the captions from [https://github.com/reedscot/cvpr2016](https://github.com/reedscot/cvpr2016). 63 | 64 | 65 | Flickr: Download the images and captions from [here](https://www.kaggle.com/datasets/hsankesara/flickr-image-dataset). 
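Once a dataset is arranged as above, its root directory can be handed to the dataloader utilities used in the notebook. A minimal sketch for the CUB loaders, assuming it is run from `src/` and that `load_data_loader` behaves as in `train_ProbVLM_CLIP.ipynb` (the data path below is a placeholder):

```python
from munch import Munch as mch

from utils import load_data_loader  # same import path as in the notebook

# Dataloader settings used in train_ProbVLM_CLIP.ipynb.
dataloader_config = mch({
    'batch_size': 64,
    'random_erasing_prob': 0.,
    'traindata_shuffle': True
})

# Dataset keys follow the loader utilities: 'coco', 'flickr', 'CUB', 'FLO'.
# '/path/to/CUB' is a placeholder for the directory prepared above.
loaders, vocab = load_data_loader('CUB', '/path/to/CUB', dataloader_config)
train_loader, valid_loader, test_loader = loaders['train'], loaders['val'], loaders['test']
```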
66 | 67 | ## Resources 68 | 69 | We use the following resources as the base models (on top of which we train different ProbVLM models): 70 | 71 | - CLIP: [https://github.com/openai/CLIP](https://github.com/openai/CLIP) 72 | - BLIP: [https://github.com/salesforce/LAVIS](https://github.com/salesforce/LAVIS) 73 | - BayesCap: [https://github.com/ExplainableML/BayesCap](https://github.com/ExplainableML/BayesCap) 74 | 75 | ## BibTex 76 | 77 | Please cite the following works: 78 | 79 | ``` 80 | @inproceedings{Upa_probvlm, 81 | title = {ProbVLM: Probabilistic Adapter for Frozen Vision-Language Models}, 82 | author = {Upadhyay, U. and Karthik, S. and Mancini, M. and Akata, Z.}, 83 | booktitle = {International Conference on Computer Vision (ICCV 2023)}, 84 | year = {2023} 85 | } 86 | ``` 87 | 88 | 89 | ``` 90 | @inproceedings{Upa_bayescap, 91 | title = {BayesCap: Bayesian Identity Cap for Calibrated Uncertainty in Frozen Neural Networks}, 92 | author = {Upadhyay, U. and Karthik, S. and Chen, Y. and Mancini, M. and Akata, Z.}, 93 | booktitle = {European Conference on Computer Vision (ECCV 2022)}, 94 | year = {2022} 95 | } 96 | ``` 97 | 98 | ``` 99 | @inproceedings{upadhyay2021uncerguidedi2i, 100 | title={Uncertainty Guided Progressive GANs for Medical Image Translation}, 101 | author={Upadhyay, Uddeshya and Chen, Yanbei and Hebb, Tobias and Gatidis, Sergios and Akata, Zeynep}, 102 | booktitle={International Conference on Medical Image Computing and Computer-Assisted Intervention (MICCAI)}, 103 | year={2021}, 104 | organization={Springer} 105 | } 106 | ``` 107 | 108 | ``` 109 | @inproceedings{UpaCheAka21, 110 | title = {Robustness via Uncertainty-aware Cycle Consistency}, 111 | author = {Upadhyay, U. and Chen, Y. and Akata, Z.}, 112 | booktitle = {Advances in Neural Information Processing Systems 34 (NeurIPS 2021)}, 113 | year = {2021} 114 | } 115 | ``` 116 | 117 | -------------------------------------------------------------------------------- /src/utils_lavis.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from os.path import join as ospj 4 | from os.path import expanduser 5 | from munch import Munch as mch 6 | from tqdm import tqdm_notebook 7 | import numpy as np 8 | import torch 9 | import torch.nn as nn 10 | 11 | import clip 12 | import ds_lavis 13 | from ds_lavis import prepare_coco_dataloaders, prepare_flickr_dataloaders, prepare_cub_dataloaders, prepare_flo_dataloaders 14 | from tqdm import tqdm 15 | from losses import * 16 | 17 | def get_pred_ranks(q_features, g_features, recall_ks=(1,5,10)): 18 | """ 19 | Args: 20 | q_features (torch.tensor, size=[#query, embedding dim]) 21 | g_features (torch.tensor, size=[#gallery, embedding dim]) 22 | recall_ks (list[:int] or tuple[:int]) 23 | Returns: 24 | pred_ranks_all (np.ndarray, size=[#query, max(recall_ks)]): 25 | data indices of similarity ranking in descending order 26 | """ 27 | max_k = max(recall_ks) 28 | n_q_features = len(q_features) 29 | 30 | pred_ranks_all = [] 31 | for idx in range(n_q_features): 32 | sims = (q_features[idx : idx + 1] @ g_features.t()) 33 | _, pred_ranks = torch.topk(sims, k=max_k, dim=-1) 34 | pred_ranks_all.append(pred_ranks) 35 | pred_ranks_all = torch.cat(pred_ranks_all, dim=0).cpu().numpy() 36 | 37 | return pred_ranks_all 38 | 39 | 40 | def get_recall(pred_ranks_all, recall_ks=(1,5,10), n_gallery_per_query=5): 41 | """ 42 | Args: 43 | pred_ranks_all (np.ndarray, size=[#query, max(recall_ks)]): 44 | data indices of similarity ranking in descending order 45 
| recall_ks (list[:int] or tuple[:int]) 46 | n_gallery_per_query (float) 47 | Returns: 48 | recall_scores (list[:float]): list of recall@k 49 | """ 50 | existence = lambda arr1, arr2: any([i in arr2 for i in arr1]) 51 | def gt_idxs(query_idx): 52 | if n_gallery_per_query >= 1: 53 | return np.arange(query_idx * n_gallery_per_query, 54 | (query_idx + 1) * n_gallery_per_query) 55 | else: 56 | return np.array([int(query_idx * n_gallery_per_query)]) 57 | 58 | recall_scores = [] 59 | for recall_k in recall_ks: 60 | score = sum([existence(pred_ranks[:recall_k], gt_idxs(query_idx)) 61 | for query_idx, pred_ranks in enumerate(pred_ranks_all)]) / len(pred_ranks_all) 62 | recall_scores.append(score) 63 | 64 | return recall_scores 65 | 66 | 67 | def get_recall_COCOFLICKR(pred_ranks_all, recall_ks=(1,5,10), n_gallery_per_query=5, q_idx=None): 68 | """ 69 | Args: 70 | pred_ranks_all (np.ndarray, size=[#query, max(recall_ks)]): 71 | data indices of similarity ranking in descending order 72 | recall_ks (list[:int] or tuple[:int]) 73 | n_gallery_per_query (float) 74 | Returns: 75 | recall_scores (list[:float]): list of recall@k 76 | """ 77 | existence = lambda arr1, arr2: any([i in arr2 for i in arr1]) 78 | def gt_idxs(query_idx): 79 | if n_gallery_per_query >= 1: 80 | return np.arange(query_idx * n_gallery_per_query, 81 | (query_idx + 1) * n_gallery_per_query) 82 | else: 83 | return np.array([int(query_idx * n_gallery_per_query)]) 84 | 85 | recall_scores = [] 86 | for recall_k in recall_ks: 87 | score = sum([existence(pred_ranks[:recall_k], q_idx) 88 | for query_idx, pred_ranks in enumerate(pred_ranks_all)]) / len(pred_ranks_all) 89 | recall_scores.append(score) 90 | 91 | return recall_scores 92 | 93 | 94 | def new_recall(pred_ranks_all,recall_ks=(1,5,10),q_classes_all=None,g_classes_all=None): 95 | recall_scores = [] 96 | for recall_k in recall_ks: 97 | corr=0 98 | total = len(pred_ranks_all) 99 | for i in range(len(pred_ranks_all)): 100 | gt_class = q_classes_all[i] 101 | pred_classes = [g_classes_all[j] for j in pred_ranks_all[i][:recall_k]] 102 | if gt_class in pred_classes: 103 | corr+=1 104 | recall_scores.append(corr/total) 105 | 106 | return recall_scores 107 | 108 | def load_data_loader(dataset, data_dir, dataloader_config): 109 | print('lavis!!!') 110 | prepare_loaders = { 111 | 'coco': prepare_coco_dataloaders, 112 | 'flickr': prepare_flickr_dataloaders, 113 | 'CUB':prepare_cub_dataloaders, 114 | 'FLO':prepare_flo_dataloaders 115 | }[dataset] 116 | if dataset == 'CUB': 117 | loaders = prepare_loaders( 118 | dataloader_config, 119 | dataset_root=data_dir, 120 | caption_root=data_dir+'/text_c10', 121 | vocab_path='ds/vocabs/cub_vocab.pkl') 122 | elif dataset == 'FLO': 123 | loaders = prepare_loaders( 124 | dataloader_config, 125 | dataset_root=data_dir, 126 | caption_root=data_dir+'/text_c10',) 127 | else: 128 | loaders = prepare_loaders( 129 | dataloader_config, 130 | dataset_root=data_dir, 131 | vocab_path='ds/vocabs/coco_vocab.pkl') 132 | return loaders 133 | 134 | def load_model(device, model_path=None): 135 | # load zero-shot CLIP model 136 | model, _ = clip.load(name='ViT-B/32', 137 | device=device, 138 | loss_type='contrastive') 139 | if model_path is None: 140 | # Convert the dtype of parameters from float16 to float32 141 | for name, param in model.named_parameters(): 142 | param.data = param.data.type(torch.float32) 143 | else: 144 | ckpt = torch.load(model_path) 145 | model.load_state_dict(ckpt['state_dict']) 146 | for name, param in model.named_parameters(): 147 | param.data = 
param.data.type(torch.float32) 148 | if torch.cuda.device_count() > 1: 149 | model = nn.DataParallel(model) 150 | return model 151 | 152 | 153 | ### training and evaluation 154 | def emb_mae(x1, x2): 155 | m = torch.abs(x1-x2).mean() 156 | return m 157 | 158 | def emb_mse(x1, x2): 159 | m = torch.pow(torch.abs(x1-x2),2).mean() 160 | return m 161 | -------------------------------------------------------------------------------- /src/ds/flickr.py: -------------------------------------------------------------------------------- 1 | import os 2 | from os.path import join as ospj 3 | from os.path import expanduser 4 | import csv 5 | from PIL import Image 6 | import numpy as np 7 | import torch 8 | from torch.utils.data import Dataset 9 | import torchvision as tv 10 | from torchvision.models.detection import maskrcnn_resnet50_fpn as maskrcnn 11 | 12 | class UnNormalize(object): 13 | def __init__(self, 14 | # mean=[0.485, 0.456, 0.406], 15 | # std=[0.229, 0.224, 0.225]): 16 | mean=(0.48145466, 0.4578275, 0.40821073), 17 | std=(0.26862954, 0.26130258, 0.27577711)): 18 | self.mean = mean 19 | self.std = std 20 | 21 | def __call__(self, tensor): 22 | """ 23 | Args: 24 | tensor (Tensor): Tensor image of size (C, H, W) to be normalized. 25 | Returns: 26 | Tensor: Normalized image. 27 | """ 28 | unnormed_tensor = torch.zeros_like(tensor) 29 | for i, (t, m, s) in enumerate(zip(tensor, self.mean, self.std)): 30 | unnormed_tensor[i] = t.mul(s).add(m) 31 | # The normalize code -> t.sub_(m).div_(s) 32 | return unnormed_tensor 33 | 34 | COCO_INSTANCE_CATEGORY_NAMES = [ 35 | '__background__', 'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', 36 | 'train', 'truck', 'boat', 'traffic light', 'fire hydrant', 'N/A', 'stop sign', 37 | 'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep', 'cow', 38 | 'elephant', 'bear', 'zebra', 'giraffe', 'N/A', 'backpack', 'umbrella', 'N/A', 'N/A', 39 | 'handbag', 'tie', 'suitcase', 'frisbee', 'skis', 'snowboard', 'sports ball', 40 | 'kite', 'baseball bat', 'baseball glove', 'skateboard', 'surfboard', 'tennis racket', 41 | 'bottle', 'N/A', 'wine glass', 'cup', 'fork', 'knife', 'spoon', 'bowl', 42 | 'banana', 'apple', 'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza', 43 | 'donut', 'cake', 'chair', 'couch', 'potted plant', 'bed', 'N/A', 'dining table', 44 | 'N/A', 'N/A', 'toilet', 'N/A', 'tv', 'laptop', 'mouse', 'remote', 'keyboard', 'cell phone', 45 | 'microwave', 'oven', 'toaster', 'sink', 'refrigerator', 'N/A', 'book', 46 | 'clock', 'vase', 'scissors', 'teddy bear', 'hair drier', 'toothbrush' 47 | ] 48 | 49 | class FlickrCap(Dataset): 50 | def __init__(self, data_root, image_ids_path=None, transform=None, target_transform=None): 51 | self.root = expanduser(data_root) 52 | self.transform = transform 53 | self.target_transform = target_transform 54 | 55 | # load ids 56 | # image_ids_path = './datasets/annotations/flickr/train.txt' 57 | with open(image_ids_path) as f: 58 | lines = f.readlines() 59 | image_files = [line.strip() + '.jpg' for line in lines] 60 | 61 | # load data 62 | self.datas = [] 63 | data_path = ospj(os.path.dirname(self.root), 'results.csv') 64 | reader = csv.reader(open(data_path)) 65 | for i, row in enumerate(reader): 66 | if i == 0: 67 | continue 68 | data = [val.strip() for val in row[0].split('|')] # ex: ['1001465944.jpg', '0', 'A woman is walking .'] 69 | if data[0] in image_files: 70 | self.datas.append(data) 71 | 72 | def __getitem__(self, index, get_caption=False): 73 | image_file, _, caption = self.datas[index] 74 | 
img = Image.open(ospj(self.root, image_file)).convert('RGB') 75 | 76 | if self.transform is not None: 77 | img = self.transform(img) 78 | is_img_masked = False 79 | img_masked = img 80 | if self.target_transform is not None: 81 | # target = self.target_transform(target) 82 | target = self.target_transform(caption) 83 | target = target.squeeze(0) 84 | if get_caption: 85 | return img, target, caption, img_masked, is_img_masked 86 | else: 87 | return img, target, img_masked, is_img_masked 88 | 89 | def __len__(self): 90 | return len(self.datas) 91 | 92 | 93 | class FlickrBboxes(FlickrCap): 94 | def __init__(self, data_root, device, image_ids_path=None, transform=None, target_transform=None): 95 | super().__init__(data_root, image_ids_path, transform, target_transform) 96 | self.device = device 97 | self.detector = maskrcnn(pretrained=True) 98 | self.detector = self.detector.to(self. device); self.detector.eval() 99 | self.unnorm = UnNormalize() 100 | self.norm = tv.transforms.Compose([tv.transforms.ToTensor(),]) 101 | 102 | def __getitem__(self, index, get_caption=False): 103 | image_file, _, caption = self.datas[index] 104 | img = Image.open(ospj(self.root, image_file)).convert('RGB') 105 | 106 | if self.transform is not None: 107 | img, img_masked, is_img_masked = self.transform(img) 108 | 109 | if self.target_transform is not None: 110 | target = self.target_transform(caption) 111 | target = target.squeeze(0) 112 | 113 | # bbox 114 | input_for_bbox = tv.transforms.ToPILImage()(self.unnorm(img)) 115 | input_for_bbox = self.norm(input_for_bbox) 116 | input_for_bbox = input_for_bbox.to(self.device) 117 | with torch.no_grad(): 118 | p = self.detector([input_for_bbox]) 119 | bboxes = p[0]['boxes'].cpu().numpy() 120 | bboxes[:,2] = bboxes[:,2] - bboxes[:,0] 121 | bboxes[:,3] = bboxes[:,3] - bboxes[:,1] 122 | cats = p[0]['labels'].cpu().numpy() 123 | scores = p[0]['scores'].cpu().numpy() 124 | bboxes = [bbox for i, bbox in enumerate(bboxes) if scores[i] >= 0.5] 125 | bboxes = torch.tensor(np.array(bboxes)) 126 | bbox_cats = [cat for i, cat in enumerate(cats) if scores[i] >= 0.5] 127 | bbox_cats = [COCO_INSTANCE_CATEGORY_NAMES[i] for i in bbox_cats] 128 | if len(bboxes) == 0: 129 | bboxes = torch.tensor([[0., 0., 0., 0.]]) 130 | bbox_cats = ['none'] 131 | 132 | if get_caption: 133 | return img, target, caption, bboxes, bbox_cats 134 | else: 135 | return img, target, bboxes -------------------------------------------------------------------------------- /src/ds_lavis/flickr.py: -------------------------------------------------------------------------------- 1 | import os 2 | from os.path import join as ospj 3 | from os.path import expanduser 4 | import csv 5 | from PIL import Image 6 | import numpy as np 7 | import torch 8 | from torch.utils.data import Dataset 9 | import torchvision as tv 10 | from torchvision.models.detection import maskrcnn_resnet50_fpn as maskrcnn 11 | 12 | class UnNormalize(object): 13 | def __init__(self, 14 | # mean=[0.485, 0.456, 0.406], 15 | # std=[0.229, 0.224, 0.225]): 16 | mean=(0.48145466, 0.4578275, 0.40821073), 17 | std=(0.26862954, 0.26130258, 0.27577711)): 18 | self.mean = mean 19 | self.std = std 20 | 21 | def __call__(self, tensor): 22 | """ 23 | Args: 24 | tensor (Tensor): Tensor image of size (C, H, W) to be normalized. 25 | Returns: 26 | Tensor: Normalized image. 
27 | """ 28 | unnormed_tensor = torch.zeros_like(tensor) 29 | for i, (t, m, s) in enumerate(zip(tensor, self.mean, self.std)): 30 | unnormed_tensor[i] = t.mul(s).add(m) 31 | # The normalize code -> t.sub_(m).div_(s) 32 | return unnormed_tensor 33 | 34 | COCO_INSTANCE_CATEGORY_NAMES = [ 35 | '__background__', 'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', 36 | 'train', 'truck', 'boat', 'traffic light', 'fire hydrant', 'N/A', 'stop sign', 37 | 'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep', 'cow', 38 | 'elephant', 'bear', 'zebra', 'giraffe', 'N/A', 'backpack', 'umbrella', 'N/A', 'N/A', 39 | 'handbag', 'tie', 'suitcase', 'frisbee', 'skis', 'snowboard', 'sports ball', 40 | 'kite', 'baseball bat', 'baseball glove', 'skateboard', 'surfboard', 'tennis racket', 41 | 'bottle', 'N/A', 'wine glass', 'cup', 'fork', 'knife', 'spoon', 'bowl', 42 | 'banana', 'apple', 'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza', 43 | 'donut', 'cake', 'chair', 'couch', 'potted plant', 'bed', 'N/A', 'dining table', 44 | 'N/A', 'N/A', 'toilet', 'N/A', 'tv', 'laptop', 'mouse', 'remote', 'keyboard', 'cell phone', 45 | 'microwave', 'oven', 'toaster', 'sink', 'refrigerator', 'N/A', 'book', 46 | 'clock', 'vase', 'scissors', 'teddy bear', 'hair drier', 'toothbrush' 47 | ] 48 | 49 | class FlickrCap(Dataset): 50 | def __init__(self, data_root, image_ids_path=None, transform=None, target_transform=None): 51 | self.root = expanduser(data_root) 52 | self.transform = transform 53 | self.target_transform = target_transform 54 | 55 | # load ids 56 | # image_ids_path = './datasets/annotations/flickr/train.txt' 57 | with open(image_ids_path) as f: 58 | lines = f.readlines() 59 | image_files = [line.strip() + '.jpg' for line in lines] 60 | 61 | # load data 62 | self.datas = [] 63 | data_path = ospj(os.path.dirname(self.root), 'results.csv') 64 | reader = csv.reader(open(data_path)) 65 | for i, row in enumerate(reader): 66 | if i == 0: 67 | continue 68 | data = [val.strip() for val in row[0].split('|')] # ex: ['1001465944.jpg', '0', 'A woman is walking .'] 69 | if data[0] in image_files: 70 | self.datas.append(data) 71 | 72 | def __getitem__(self, index, get_caption=False): 73 | image_file, _, caption = self.datas[index] 74 | img = Image.open(ospj(self.root, image_file)).convert('RGB') 75 | 76 | if self.transform is not None: 77 | img = self.transform(img) 78 | is_img_masked = False 79 | img_masked = img 80 | if self.target_transform is not None: 81 | # target = self.target_transform(target) 82 | target = self.target_transform(caption) 83 | target = target.squeeze(0) 84 | if get_caption: 85 | return img, target, caption, img_masked, is_img_masked 86 | else: 87 | return img, target, img_masked, is_img_masked 88 | 89 | def __len__(self): 90 | return len(self.datas) 91 | 92 | 93 | class FlickrBboxes(FlickrCap): 94 | def __init__(self, data_root, device, image_ids_path=None, transform=None, target_transform=None): 95 | super().__init__(data_root, image_ids_path, transform, target_transform) 96 | self.device = device 97 | self.detector = maskrcnn(pretrained=True) 98 | self.detector = self.detector.to(self. 
device); self.detector.eval() 99 | self.unnorm = UnNormalize() 100 | self.norm = tv.transforms.Compose([tv.transforms.ToTensor(),]) 101 | 102 | def __getitem__(self, index, get_caption=False): 103 | image_file, _, caption = self.datas[index] 104 | img = Image.open(ospj(self.root, image_file)).convert('RGB') 105 | 106 | if self.transform is not None: 107 | img, img_masked, is_img_masked = self.transform(img) 108 | 109 | if self.target_transform is not None: 110 | target = self.target_transform(caption) 111 | target = target.squeeze(0) 112 | 113 | # bbox 114 | input_for_bbox = tv.transforms.ToPILImage()(self.unnorm(img)) 115 | input_for_bbox = self.norm(input_for_bbox) 116 | input_for_bbox = input_for_bbox.to(self.device) 117 | with torch.no_grad(): 118 | p = self.detector([input_for_bbox]) 119 | bboxes = p[0]['boxes'].cpu().numpy() 120 | bboxes[:,2] = bboxes[:,2] - bboxes[:,0] 121 | bboxes[:,3] = bboxes[:,3] - bboxes[:,1] 122 | cats = p[0]['labels'].cpu().numpy() 123 | scores = p[0]['scores'].cpu().numpy() 124 | bboxes = [bbox for i, bbox in enumerate(bboxes) if scores[i] >= 0.5] 125 | bboxes = torch.tensor(np.array(bboxes)) 126 | bbox_cats = [cat for i, cat in enumerate(cats) if scores[i] >= 0.5] 127 | bbox_cats = [COCO_INSTANCE_CATEGORY_NAMES[i] for i in bbox_cats] 128 | if len(bboxes) == 0: 129 | bboxes = torch.tensor([[0., 0., 0., 0.]]) 130 | bbox_cats = ['none'] 131 | 132 | if get_caption: 133 | return img, target, caption, bboxes, bbox_cats 134 | else: 135 | return img, target, bboxes -------------------------------------------------------------------------------- /src/ds/flo.py: -------------------------------------------------------------------------------- 1 | """FLO Caption image-to-caption retrieval dataset code 2 | 3 | PCME 4 | Copyright (c) 2021-present NAVER Corp. 5 | MIT license 6 | """ 7 | 8 | import os 9 | from PIL import Image 10 | import numpy as np 11 | from torch.utils.data import Dataset 12 | from torch.utils.data.sampler import Sampler 13 | import scipy.io 14 | import glob 15 | 16 | def pad_text(num): 17 | if num<10: 18 | return '0000'+str(num) 19 | if num<100: 20 | return '000'+str(num) 21 | 22 | if num<1000: 23 | return '00'+str(num) 24 | 25 | 26 | class FLOCaption(Dataset): 27 | """FLO Captions Dataset. 28 | 29 | Args: 30 | image_root (string): Root directory where images are downloaded to. 31 | caption_root (string): Root directory where captions are downloaded to. 32 | target_classes (str or list): target class ids 33 | - if str, it is the name of the file with target classes (line by line) 34 | - if list, it is directly used to get classes 35 | transform (callable, optional): A function/transform that takes in a PIL image 36 | and returns a transformed version. E.g., ``transforms.ToTensor`` 37 | target_transform (callable, optional): A function/transform that takes in the 38 | target and transforms it. 39 | omit_ids (str, optional): Path of file with the list of image ids to omit, 40 | if not specified, use all images in the target classes. 41 | ids (str, optional): Path of file with the list of target image ids, 42 | if not specified, use all images in the target classes. 
43 | """ 44 | def __init__(self, image_root, caption_root, 45 | target_classes, 46 | transform=None, target_transform=None, 47 | ): 48 | 49 | self.image_root = os.path.expanduser(image_root) 50 | self.caption_root = os.path.expanduser(caption_root) 51 | 52 | if isinstance(target_classes, str): 53 | with open(target_classes) as fin: 54 | _classes = [int(line.strip().split('_')[1]) - 1 for line in fin] 55 | target_classes = _classes 56 | 57 | target_classes = set(list(target_classes)) 58 | if (target_classes - set(range(102))): 59 | raise ValueError(f'target classes should be an integer array between 0-102, but {target_classes}') 60 | print(f'prepare flo dataset with {len(target_classes)} classes') 61 | 62 | targets = [] 63 | index_to_class = {} 64 | class_to_indices = {} 65 | class_to_img_indices = {} 66 | idx = 0 67 | n_images = 0 68 | label_path = image_root+'/imagelabels.mat' 69 | jpg_path = image_root+'/jpg/' 70 | class_labels = np.array(scipy.io.loadmat(label_path)['labels'])[0] 71 | images = glob.glob(jpg_path+'*') 72 | images.sort() 73 | n_images=0 74 | for i in range(len(images)): 75 | img_name = images[i] 76 | cls_num = class_labels[i] - 1 77 | if cls_num in target_classes: 78 | _target = [] 79 | 80 | class_txt = 'class_'+pad_text(cls_num+1) 81 | #print(caption_root,class_txt,img_name) 82 | caption_img = img_name.split('/')[-1] 83 | txt_fname = os.path.join(caption_root, class_txt, caption_img.replace('jpg', 'txt')) 84 | with open(txt_fname) as fin: 85 | captions = [line.strip() for line in fin] 86 | 87 | for caption in captions: 88 | _target.append( 89 | (os.path.join(img_name), caption) 90 | ) 91 | index_to_class[idx] = cls_num 92 | class_to_indices.setdefault(cls_num, []).append(idx) 93 | idx += 1 94 | targets.extend(_target) 95 | n_images+=1 96 | self.targets = targets 97 | self.target_classes = target_classes 98 | self.index_to_class = index_to_class 99 | self.class_to_indices = class_to_indices 100 | self.class_to_img_indices = class_to_img_indices 101 | 102 | self.n_images = n_images 103 | 104 | self.transform = transform 105 | self.target_transform = target_transform 106 | 107 | def __getitem__(self, index): 108 | img_path, target = self.targets[index] 109 | 110 | img = Image.open(img_path).convert('RGB') 111 | if self.transform is not None: 112 | img = self.transform(img) 113 | if self.target_transform is not None: 114 | target = self.target_transform(target) 115 | target = target.squeeze(0) 116 | 117 | return img, target, self.index_to_class[index], index 118 | 119 | def __len__(self): 120 | return len(self.targets) 121 | 122 | 123 | class FLOSampler(Sampler): 124 | """ Sampler for CUB Captions training. 125 | 126 | Args: 127 | dataset (CUBCaption object): dataset object to apply the sampler. 128 | batch_size (int): batch size. 129 | adjust_epoch (bool): if true, the iterations for one epoch is re-calculated. 
130 | """ 131 | def __init__(self, dataset, batch_size, adjust_epoch=True): 132 | self.dataset = dataset 133 | self.batch_size = batch_size 134 | print("Batch:",self.batch_size) 135 | self.target_classes = dataset.target_classes 136 | if batch_size != len(self.target_classes): 137 | raise ValueError(f'{batch_size} != {len(self.target_classes)}') 138 | self.index_to_class = dataset.index_to_class 139 | self.class_to_indices = dataset.class_to_indices 140 | self.n_items = len(self.index_to_class) 141 | 142 | if adjust_epoch: 143 | self.n_iters = int(self.n_items / len(self.target_classes)) 144 | else: 145 | self.n_iters = self.n_items 146 | 147 | def __iter__(self): 148 | batch = [] 149 | indices = list(range(self.n_items)) 150 | 151 | np.random.shuffle(indices) 152 | for cur_iter, idx in enumerate(indices): 153 | batch = [idx] 154 | pos_cls = self.index_to_class[idx] 155 | for cls_num, _indices in self.class_to_indices.items(): 156 | if cls_num == pos_cls: 157 | continue 158 | else: 159 | batch.append(np.random.choice(_indices)) 160 | np.random.shuffle(batch) 161 | if cur_iter > self.n_iters: 162 | return 163 | yield batch 164 | 165 | 166 | def __len__(self): 167 | return self.n_iters 168 | -------------------------------------------------------------------------------- /src/ds_lavis/flo.py: -------------------------------------------------------------------------------- 1 | """CUB Caption image-to-caption retrieval dataset code 2 | 3 | PCME 4 | Copyright (c) 2021-present NAVER Corp. 5 | MIT license 6 | """ 7 | 8 | import os 9 | from PIL import Image 10 | import numpy as np 11 | from torch.utils.data import Dataset 12 | from torch.utils.data.sampler import Sampler 13 | import scipy.io 14 | import glob 15 | 16 | def pad_text(num): 17 | if num<10: 18 | return '0000'+str(num) 19 | if num<100: 20 | return '000'+str(num) 21 | 22 | if num<1000: 23 | return '00'+str(num) 24 | 25 | 26 | class FLOCaption(Dataset): 27 | """CUB Captions Dataset. 28 | 29 | Args: 30 | image_root (string): Root directory where images are downloaded to. 31 | caption_root (string): Root directory where captions are downloaded to. 32 | target_classes (str or list): target class ids 33 | - if str, it is the name of the file with target classes (line by line) 34 | - if list, it is directly used to get classes 35 | transform (callable, optional): A function/transform that takes in an PIL image 36 | and returns a transformed version. E.g, ``transforms.ToTensor`` 37 | target_transform (callable, optional): A function/transform that takes in the 38 | target and transforms it. 39 | omit_ids (str, optional): Path of file with the list of image ids to omit, 40 | if not specified, use all images in the target classes. 41 | ids (str, optional): Path of file with the list of target image ids, 42 | if not specified, use all images in the target classes. 
43 | """ 44 | def __init__(self, image_root, caption_root, 45 | target_classes, 46 | transform=None, target_transform=None, 47 | ): 48 | 49 | self.image_root = os.path.expanduser(image_root) 50 | self.caption_root = os.path.expanduser(caption_root) 51 | 52 | if isinstance(target_classes, str): 53 | with open(target_classes) as fin: 54 | _classes = [int(line.strip().split('_')[1]) - 1 for line in fin] 55 | target_classes = _classes 56 | 57 | target_classes = set(list(target_classes)) 58 | if (target_classes - set(range(102))): 59 | raise ValueError(f'target classes should be an integer array between 0-102, but {target_classes}') 60 | print(f'prepare flo dataset with {len(target_classes)} classes') 61 | 62 | targets = [] 63 | index_to_class = {} 64 | class_to_indices = {} 65 | class_to_img_indices = {} 66 | idx = 0 67 | n_images = 0 68 | label_path = image_root+'/imagelabels.mat' 69 | jpg_path = image_root+'/jpg/' 70 | class_labels = np.array(scipy.io.loadmat(label_path)['labels'])[0] 71 | images = glob.glob(jpg_path+'*') 72 | images.sort() 73 | n_images=0 74 | for i in range(len(images)): 75 | img_name = images[i] 76 | cls_num = class_labels[i] - 1 77 | if cls_num in target_classes: 78 | _target = [] 79 | 80 | class_txt = 'class_'+pad_text(cls_num+1) 81 | #print(caption_root,class_txt,img_name) 82 | caption_img = img_name.split('/')[-1] 83 | txt_fname = os.path.join(caption_root, class_txt, caption_img.replace('jpg', 'txt')) 84 | with open(txt_fname) as fin: 85 | captions = [line.strip() for line in fin] 86 | 87 | for caption in captions: 88 | _target.append( 89 | (os.path.join(img_name), caption) 90 | ) 91 | index_to_class[idx] = cls_num 92 | class_to_indices.setdefault(cls_num, []).append(idx) 93 | idx += 1 94 | targets.extend(_target) 95 | n_images+=1 96 | self.targets = targets 97 | self.target_classes = target_classes 98 | self.index_to_class = index_to_class 99 | self.class_to_indices = class_to_indices 100 | self.class_to_img_indices = class_to_img_indices 101 | 102 | self.n_images = n_images 103 | 104 | self.transform = transform 105 | self.target_transform = target_transform 106 | 107 | def __getitem__(self, index): 108 | img_path, target = self.targets[index] 109 | 110 | img = Image.open(img_path).convert('RGB') 111 | if self.transform is not None: 112 | img = self.transform(img) 113 | if self.target_transform is not None: 114 | target = self.target_transform(target) 115 | target = target.squeeze(0) 116 | 117 | return img, target, self.index_to_class[index], index 118 | 119 | def __len__(self): 120 | return len(self.targets) 121 | 122 | 123 | class FLOSampler(Sampler): 124 | """ Sampler for CUB Captions training. 125 | 126 | Args: 127 | dataset (CUBCaption object): dataset object to apply the sampler. 128 | batch_size (int): batch size. 129 | adjust_epoch (bool): if true, the iterations for one epoch is re-calculated. 
130 | """ 131 | def __init__(self, dataset, batch_size, adjust_epoch=True): 132 | self.dataset = dataset 133 | self.batch_size = batch_size 134 | print("Batch:",self.batch_size) 135 | self.target_classes = dataset.target_classes 136 | if batch_size != len(self.target_classes): 137 | raise ValueError(f'{batch_size} != {len(self.target_classes)}') 138 | self.index_to_class = dataset.index_to_class 139 | self.class_to_indices = dataset.class_to_indices 140 | self.n_items = len(self.index_to_class) 141 | 142 | if adjust_epoch: 143 | self.n_iters = int(self.n_items / len(self.target_classes)) 144 | else: 145 | self.n_iters = self.n_items 146 | 147 | def __iter__(self): 148 | batch = [] 149 | indices = list(range(self.n_items)) 150 | 151 | np.random.shuffle(indices) 152 | for cur_iter, idx in enumerate(indices): 153 | batch = [idx] 154 | pos_cls = self.index_to_class[idx] 155 | for cls_num, _indices in self.class_to_indices.items(): 156 | if cls_num == pos_cls: 157 | continue 158 | else: 159 | batch.append(np.random.choice(_indices)) 160 | np.random.shuffle(batch) 161 | if cur_iter > self.n_iters: 162 | return 163 | yield batch 164 | 165 | 166 | def __len__(self): 167 | return self.n_iters 168 | -------------------------------------------------------------------------------- /src/ds/cub.py: -------------------------------------------------------------------------------- 1 | """CUB Caption image-to-caption retrieval dataset code 2 | 3 | PCME 4 | Copyright (c) 2021-present NAVER Corp. 5 | MIT license 6 | """ 7 | 8 | import os 9 | from PIL import Image 10 | import numpy as np 11 | from torch.utils.data import Dataset 12 | from torch.utils.data.sampler import Sampler 13 | 14 | 15 | class CUBCaption(Dataset): 16 | """CUB Captions Dataset. 17 | 18 | Args: 19 | image_root (string): Root directory where images are downloaded to. 20 | caption_root (string): Root directory where captions are downloaded to. 21 | target_classes (str or list): target class ids 22 | - if str, it is the name of the file with target classes (line by line) 23 | - if list, it is directly used to get classes 24 | transform (callable, optional): A function/transform that takes in an PIL image 25 | and returns a transformed version. E.g, ``transforms.ToTensor`` 26 | target_transform (callable, optional): A function/transform that takes in the 27 | target and transforms it. 28 | omit_ids (str, optional): Path of file with the list of image ids to omit, 29 | if not specified, use all images in the target classes. 30 | ids (str, optional): Path of file with the list of target image ids, 31 | if not specified, use all images in the target classes. 
32 | """ 33 | def __init__(self, image_root, caption_root, 34 | target_classes, 35 | transform=None, target_transform=None, 36 | omit_ids=None, ids=None): 37 | if omit_ids and ids: 38 | raise ValueError('omit ids and ids cannot be defined at the same time.') 39 | if omit_ids: 40 | with open(omit_ids) as fin: 41 | omit_ids = set([line.strip() for line in fin]) 42 | else: 43 | omit_ids = set() 44 | if ids: 45 | with open(ids) as fin: 46 | ids = set([line.strip() for line in fin]) 47 | 48 | self.image_root = os.path.expanduser(image_root) 49 | self.caption_root = os.path.expanduser(caption_root) 50 | 51 | if isinstance(target_classes, str): 52 | with open(target_classes) as fin: 53 | _classes = [int(line.strip().split('.')[0]) - 1 for line in fin] 54 | target_classes = _classes 55 | 56 | target_classes = set(list(target_classes)) 57 | if (target_classes - set(range(200))): 58 | raise ValueError(f'target classes should be an integer array between 0-199, but {target_classes}') 59 | print(f'prepare cub dataset with {len(target_classes)} classes') 60 | 61 | targets = [] 62 | index_to_class = {} 63 | class_to_indices = {} 64 | class_to_img_indices = {} 65 | idx = 0 66 | n_images = 0 67 | for bird_name in os.listdir(image_root): 68 | cls_num = int(bird_name.split('.')[0]) - 1 69 | if cls_num in target_classes: 70 | _target = [] 71 | for fname in os.listdir(os.path.join(image_root, bird_name)): 72 | if os.path.join(bird_name, fname) in omit_ids: 73 | continue 74 | 75 | if ids and os.path.join(bird_name, fname) not in ids: 76 | continue 77 | 78 | txt_fname = os.path.join(caption_root, bird_name, fname.replace('jpg', 'txt')) 79 | with open(txt_fname) as fin: 80 | captions = [line.strip() for line in fin] 81 | 82 | n_images += 1 83 | class_to_img_indices.setdefault(cls_num, []).append(n_images) 84 | for caption in captions: 85 | _target.append( 86 | (os.path.join(image_root, bird_name, fname), caption) 87 | ) 88 | index_to_class[idx] = cls_num 89 | class_to_indices.setdefault(cls_num, []).append(idx) 90 | idx += 1 91 | targets.extend(_target) 92 | self.targets = targets 93 | self.target_classes = target_classes 94 | self.index_to_class = index_to_class 95 | self.class_to_indices = class_to_indices 96 | self.class_to_img_indices = class_to_img_indices 97 | 98 | self.n_images = n_images 99 | 100 | self.transform = transform 101 | self.target_transform = target_transform 102 | 103 | def __getitem__(self, index): 104 | img_path, target = self.targets[index] 105 | 106 | img = Image.open(img_path).convert('RGB') 107 | if self.transform is not None: 108 | img = self.transform(img) 109 | if self.target_transform is not None: 110 | target = self.target_transform(target) 111 | target = target.squeeze(0) 112 | 113 | return img, target, self.index_to_class[index], index 114 | 115 | def __len__(self): 116 | return len(self.targets) 117 | 118 | 119 | class CUBSampler(Sampler): 120 | """ Sampler for CUB Captions training. 121 | 122 | Args: 123 | dataset (CUBCaption object): dataset object to apply the sampler. 124 | batch_size (int): batch size. 125 | adjust_epoch (bool): if true, the iterations for one epoch is re-calculated. 
126 | """ 127 | def __init__(self, dataset, batch_size, adjust_epoch=True): 128 | self.dataset = dataset 129 | self.batch_size = batch_size 130 | self.target_classes = dataset.target_classes 131 | if batch_size != len(self.target_classes): 132 | raise ValueError(f'{batch_size} != {len(self.target_classes)}') 133 | self.index_to_class = dataset.index_to_class 134 | self.class_to_indices = dataset.class_to_indices 135 | self.n_items = len(self.index_to_class) 136 | 137 | if adjust_epoch: 138 | self.n_iters = int(self.n_items / len(self.target_classes)) 139 | else: 140 | self.n_iters = self.n_items 141 | 142 | def __iter__(self): 143 | batch = [] 144 | indices = list(range(self.n_items)) 145 | 146 | np.random.shuffle(indices) 147 | for cur_iter, idx in enumerate(indices): 148 | batch = [idx] 149 | pos_cls = self.index_to_class[idx] 150 | for cls_num, _indices in self.class_to_indices.items(): 151 | if cls_num == pos_cls: 152 | continue 153 | else: 154 | batch.append(np.random.choice(_indices)) 155 | np.random.shuffle(batch) 156 | if cur_iter > self.n_iters: 157 | return 158 | yield batch 159 | 160 | 161 | def __len__(self): 162 | return self.n_iters 163 | -------------------------------------------------------------------------------- /src/ds_lavis/cub.py: -------------------------------------------------------------------------------- 1 | """CUB Caption image-to-caption retrieval dataset code 2 | 3 | PCME 4 | Copyright (c) 2021-present NAVER Corp. 5 | MIT license 6 | """ 7 | 8 | import os 9 | from PIL import Image 10 | import numpy as np 11 | from torch.utils.data import Dataset 12 | from torch.utils.data.sampler import Sampler 13 | 14 | 15 | class CUBCaption(Dataset): 16 | """CUB Captions Dataset. 17 | 18 | Args: 19 | image_root (string): Root directory where images are downloaded to. 20 | caption_root (string): Root directory where captions are downloaded to. 21 | target_classes (str or list): target class ids 22 | - if str, it is the name of the file with target classes (line by line) 23 | - if list, it is directly used to get classes 24 | transform (callable, optional): A function/transform that takes in an PIL image 25 | and returns a transformed version. E.g, ``transforms.ToTensor`` 26 | target_transform (callable, optional): A function/transform that takes in the 27 | target and transforms it. 28 | omit_ids (str, optional): Path of file with the list of image ids to omit, 29 | if not specified, use all images in the target classes. 30 | ids (str, optional): Path of file with the list of target image ids, 31 | if not specified, use all images in the target classes. 
32 | """ 33 | def __init__(self, image_root, caption_root, 34 | target_classes, 35 | transform=None, target_transform=None, 36 | omit_ids=None, ids=None): 37 | if omit_ids and ids: 38 | raise ValueError('omit ids and ids cannot be defined at the same time.') 39 | if omit_ids: 40 | with open(omit_ids) as fin: 41 | omit_ids = set([line.strip() for line in fin]) 42 | else: 43 | omit_ids = set() 44 | if ids: 45 | with open(ids) as fin: 46 | ids = set([line.strip() for line in fin]) 47 | 48 | self.image_root = os.path.expanduser(image_root) 49 | self.caption_root = os.path.expanduser(caption_root) 50 | 51 | if isinstance(target_classes, str): 52 | with open(target_classes) as fin: 53 | _classes = [int(line.strip().split('.')[0]) - 1 for line in fin] 54 | target_classes = _classes 55 | 56 | target_classes = set(list(target_classes)) 57 | if (target_classes - set(range(200))): 58 | raise ValueError(f'target classes should be an integer array between 0-199, but {target_classes}') 59 | print(f'prepare cub dataset with {len(target_classes)} classes') 60 | 61 | targets = [] 62 | index_to_class = {} 63 | class_to_indices = {} 64 | class_to_img_indices = {} 65 | idx = 0 66 | n_images = 0 67 | for bird_name in os.listdir(image_root): 68 | cls_num = int(bird_name.split('.')[0]) - 1 69 | if cls_num in target_classes: 70 | _target = [] 71 | for fname in os.listdir(os.path.join(image_root, bird_name)): 72 | if os.path.join(bird_name, fname) in omit_ids: 73 | continue 74 | 75 | if ids and os.path.join(bird_name, fname) not in ids: 76 | continue 77 | 78 | txt_fname = os.path.join(caption_root, bird_name, fname.replace('jpg', 'txt')) 79 | with open(txt_fname) as fin: 80 | captions = [line.strip() for line in fin] 81 | 82 | n_images += 1 83 | class_to_img_indices.setdefault(cls_num, []).append(n_images) 84 | for caption in captions: 85 | _target.append( 86 | (os.path.join(image_root, bird_name, fname), caption) 87 | ) 88 | index_to_class[idx] = cls_num 89 | class_to_indices.setdefault(cls_num, []).append(idx) 90 | idx += 1 91 | targets.extend(_target) 92 | self.targets = targets 93 | self.target_classes = target_classes 94 | self.index_to_class = index_to_class 95 | self.class_to_indices = class_to_indices 96 | self.class_to_img_indices = class_to_img_indices 97 | 98 | self.n_images = n_images 99 | 100 | self.transform = transform 101 | self.target_transform = target_transform 102 | 103 | def __getitem__(self, index): 104 | img_path, target = self.targets[index] 105 | 106 | img = Image.open(img_path).convert('RGB') 107 | if self.transform is not None: 108 | img = self.transform(img) 109 | if self.target_transform is not None: 110 | target = self.target_transform(target) 111 | target = target.squeeze(0) 112 | 113 | return img, target, self.index_to_class[index], index 114 | 115 | def __len__(self): 116 | return len(self.targets) 117 | 118 | 119 | class CUBSampler(Sampler): 120 | """ Sampler for CUB Captions training. 121 | 122 | Args: 123 | dataset (CUBCaption object): dataset object to apply the sampler. 124 | batch_size (int): batch size. 125 | adjust_epoch (bool): if true, the iterations for one epoch is re-calculated. 
126 | """ 127 | def __init__(self, dataset, batch_size, adjust_epoch=True): 128 | self.dataset = dataset 129 | self.batch_size = batch_size 130 | self.target_classes = dataset.target_classes 131 | if batch_size != len(self.target_classes): 132 | raise ValueError(f'{batch_size} != {len(self.target_classes)}') 133 | self.index_to_class = dataset.index_to_class 134 | self.class_to_indices = dataset.class_to_indices 135 | self.n_items = len(self.index_to_class) 136 | 137 | if adjust_epoch: 138 | self.n_iters = int(self.n_items / len(self.target_classes)) 139 | else: 140 | self.n_iters = self.n_items 141 | 142 | def __iter__(self): 143 | batch = [] 144 | indices = list(range(self.n_items)) 145 | 146 | np.random.shuffle(indices) 147 | for cur_iter, idx in enumerate(indices): 148 | batch = [idx] 149 | pos_cls = self.index_to_class[idx] 150 | for cls_num, _indices in self.class_to_indices.items(): 151 | if cls_num == pos_cls: 152 | continue 153 | else: 154 | batch.append(np.random.choice(_indices)) 155 | np.random.shuffle(batch) 156 | if cur_iter > self.n_iters: 157 | return 158 | yield batch 159 | 160 | 161 | def __len__(self): 162 | return self.n_iters 163 | -------------------------------------------------------------------------------- /src/networks.py: -------------------------------------------------------------------------------- 1 | import os 2 | from os.path import join as ospj 3 | from os.path import expanduser 4 | 5 | import numpy as np 6 | import torch 7 | import torch.nn as nn 8 | 9 | import clip 10 | from tqdm import tqdm 11 | from utils_lavis import * 12 | 13 | 14 | class BayesCap_MLP(nn.Module): 15 | ''' 16 | Baseclass to create a simple MLP 17 | Inputs 18 | inp_dim: int, Input dimension 19 | out_dim: int, Output dimension 20 | hid_dim: int, hidden dimension 21 | num_layers: Number of hidden layers 22 | p_drop: dropout probability 23 | ''' 24 | def __init__( 25 | self, 26 | inp_dim, 27 | out_dim, 28 | hid_dim=512, 29 | num_layers=1, 30 | p_drop=0, 31 | ): 32 | super(BayesCap_MLP, self).__init__() 33 | mod = [] 34 | for layer in range(num_layers): 35 | if layer==0: 36 | incoming = inp_dim 37 | outgoing = hid_dim 38 | mod.append(nn.Linear(incoming, outgoing)) 39 | mod.append(nn.ReLU()) 40 | elif layer==num_layers//2: 41 | incoming = hid_dim 42 | outgoing = hid_dim 43 | mod.append(nn.Linear(incoming, outgoing)) 44 | mod.append(nn.ReLU()) 45 | mod.append(nn.Dropout(p=p_drop)) 46 | elif layer==num_layers-1: 47 | incoming = hid_dim 48 | outgoing = out_dim 49 | mod.append(nn.Linear(incoming, outgoing)) 50 | self.mod = nn.Sequential(*mod) 51 | 52 | self.block_mu = nn.Sequential( 53 | nn.Linear(out_dim, out_dim), 54 | nn.ReLU(), 55 | nn.Linear(out_dim, out_dim), 56 | ) 57 | 58 | self.block_alpha = nn.Sequential( 59 | nn.Linear(out_dim, out_dim), 60 | nn.ReLU(), 61 | # nn.Linear(out_dim, out_dim), 62 | # nn.ReLU(), 63 | nn.Linear(out_dim, out_dim), 64 | nn.ReLU(), 65 | ) 66 | 67 | self.block_beta = nn.Sequential( 68 | nn.Linear(out_dim, out_dim), 69 | nn.ReLU(), 70 | # nn.Linear(out_dim, out_dim), 71 | # nn.ReLU(), 72 | nn.Linear(out_dim, out_dim), 73 | nn.ReLU(), 74 | ) 75 | 76 | def forward(self, x): 77 | x_intr = self.mod(x) 78 | # print('dbg', x_intr.shape, x.shape) 79 | x_intr = x_intr + x 80 | x_mu = self.block_mu(x_intr) 81 | x_1alpha = self.block_alpha(x_intr) 82 | x_beta = self.block_beta(x_intr) 83 | return x_mu, x_1alpha, x_beta 84 | 85 | class BayesCap_HF_MLP(nn.Module): 86 | ''' 87 | Baseclass to create a simple MLP 88 | Inputs 89 | inp_dim: int, Input dimension 90 | out_dim: 
int, Output dimension
91 | hid_dim: int, hidden dimension
92 | num_layers: Number of hidden layers
93 | p_drop: dropout probability
94 | '''
95 | def __init__(
96 | self,
97 | inp_dim,
98 | out_dim,
99 | hid_dim=512,
100 | num_layers=1,
101 | p_drop=0,
102 | ):
103 | super(BayesCap_HF_MLP, self).__init__()
104 | mod = []
105 | for layer in range(num_layers):
106 | if layer==0:
107 | incoming = inp_dim
108 | outgoing = hid_dim
109 | mod.append(nn.Linear(incoming, outgoing))
110 | mod.append(nn.ReLU())
111 | elif layer==num_layers//2:
112 | incoming = hid_dim
113 | outgoing = hid_dim
114 | mod.append(nn.Linear(incoming, outgoing))
115 | mod.append(nn.ReLU())
116 | mod.append(nn.Dropout(p=p_drop))
117 | elif layer==num_layers-1:
118 | incoming = hid_dim
119 | outgoing = out_dim
120 | mod.append(nn.Linear(incoming, outgoing))
121 | self.mod = nn.Sequential(*mod)
122 |
123 | self.block_mu = nn.Sequential(
124 | nn.Linear(out_dim, 128),
125 | nn.ReLU(),
126 | nn.Linear(128, out_dim),
127 | )
128 |
129 | self.block_alpha = nn.Sequential(
130 | nn.Linear(out_dim, 128),
131 | nn.ReLU(),
132 | # nn.Linear(out_dim, out_dim),
133 | # nn.ReLU(),
134 | nn.Linear(128, out_dim),
135 | nn.ReLU(),
136 | )
137 |
138 | self.block_beta = nn.Sequential(
139 | nn.Linear(out_dim, 128),
140 | nn.ReLU(),
141 | # nn.Linear(out_dim, out_dim),
142 | # nn.ReLU(),
143 | nn.Linear(128, out_dim),
144 | nn.ReLU(),
145 | )
146 |
147 | def forward(self, x):
148 | x_intr = self.mod(x)
149 | # print('dbg', x_intr.shape, x.shape)
150 | x_intr = x_intr + x
151 | x_mu = self.block_mu(x_intr)
152 | x_1alpha = self.block_alpha(x_intr)
153 | x_beta = self.block_beta(x_intr)
154 | return x_mu, x_1alpha, x_beta
155 |
156 |
157 | class BayesCLIP(nn.Module):
158 | def __init__(
159 | self,
160 | model_path=None,
161 | device='cuda',
162 | ):
163 | super(BayesCLIP, self).__init__()
164 | self.clip_model = load_model(device, model_path)
165 | self.clip_model.eval()
166 | for param in self.clip_model.parameters():
167 | param.requires_grad = False
168 |
169 | self.img_BayesCap = BayesCap_MLP(inp_dim=512, out_dim=512, hid_dim=512, num_layers=3, p_drop=0.3).to(device)
170 | self.txt_BayesCap = BayesCap_MLP(inp_dim=512, out_dim=512, hid_dim=512, num_layers=3, p_drop=0.3).to(device)
171 |
172 | def forward(self, i_inputs, t_inputs):
173 | i_features, t_features = self.clip_model(i_inputs, t_inputs)
174 |
175 | img_mu, img_1alpha, img_beta = self.img_BayesCap(i_features)
176 | txt_mu, txt_1alpha, txt_beta = self.txt_BayesCap(t_features)
177 |
178 | return (img_mu, img_1alpha, img_beta), (txt_mu, txt_1alpha, txt_beta), (i_features, t_features)
179 |
180 |
181 | class BayesCap_for_CLIP(nn.Module):
182 | def __init__(
183 | self,
184 | inp_dim=512,
185 | out_dim=512,
186 | hid_dim=256,
187 | num_layers=3,
188 | p_drop=0.1,
189 | ):
190 | super(BayesCap_for_CLIP, self).__init__()
191 | self.img_BayesCap = BayesCap_MLP(inp_dim=inp_dim, out_dim=out_dim, hid_dim=hid_dim, num_layers=num_layers, p_drop=p_drop)
192 | self.txt_BayesCap = BayesCap_MLP(inp_dim=inp_dim, out_dim=out_dim, hid_dim=hid_dim, num_layers=num_layers, p_drop=p_drop)
193 |
194 | def forward(self, i_features, t_features):
195 |
196 | # print('dbg', i_features.shape, t_features.shape)
197 | img_mu, img_1alpha, img_beta = self.img_BayesCap(i_features)
198 | txt_mu, txt_1alpha, txt_beta = self.txt_BayesCap(t_features)
199 |
200 | return (img_mu, img_1alpha, img_beta), (txt_mu, txt_1alpha, txt_beta)
201 |
202 |
203 | class BayesCap_for_HF_CLIP(nn.Module):
204 | def __init__(
205 | self,
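# Editor's note: the three heads above parameterise a generalised Gaussian
# GGD(mu, alpha, beta) per embedding dimension, with the networks emitting
# (mu, 1/alpha, beta). A minimal sketch of the matching negative log-likelihood
# follows (the repo's actual training loss lives in losses.py); eps guards the
# ReLU heads, which can output exactly 0:
#
#     def ggd_nll(mu, one_over_alpha, beta, y, eps=1e-4):
#         oa, b = one_over_alpha + eps, beta + eps
#         # -log f(y) = (|y - mu| * oa)**b - log(b * oa) + lgamma(1/b) + log 2
#         return ((torch.abs(y - mu) * oa) ** b
#                 - torch.log(b * oa) + torch.lgamma(1.0 / b)).mean()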
206 | # inp_i_dim=512,
207 | # out_i_dim=512,
208 | # hid_i_dim=256,
209 | inp_t_dim=512,
210 | out_t_dim=512,
211 | hid_t_dim=256,
212 | num_layers=3,
213 | p_drop=0.1,
214 | ):
215 | super(BayesCap_for_HF_CLIP, self).__init__()
216 | # self.img_BayesCap = BayesCap_MLP(inp_dim=inp_i_dim, out_dim=out_i_dim, hid_dim=hid_i_dim, num_layers=num_layers, p_drop=p_drop)
217 | self.txt_BayesCap = BayesCap_MLP(inp_dim=inp_t_dim, out_dim=out_t_dim, hid_dim=hid_t_dim, num_layers=num_layers, p_drop=p_drop)
218 |
219 | def forward(self, i_features, t_features):
220 | # img_mu, img_1alpha, img_beta = self.img_BayesCap(i_features)
221 | txt_mu, txt_1alpha, txt_beta = self.txt_BayesCap(t_features)
222 |
223 | # return (img_mu, img_1alpha, img_beta), (txt_mu, txt_1alpha, txt_beta)
224 | return (None, None, None), (txt_mu, txt_1alpha, txt_beta)
--------------------------------------------------------------------------------
/src/ds/fashion200k.py:
--------------------------------------------------------------------------------
1 | import os
2 | from PIL import Image
3 | import numpy as np
4 | from torch.utils.data import Dataset
5 | import random
6 | import PIL
7 | import torch  # used by BaseDataset.get_loader
8 | class BaseDataset(Dataset):
9 | """Base class for a dataset."""
10 |
11 | def __init__(self):
12 | super(BaseDataset, self).__init__()
13 | self.imgs = []
14 | self.test_queries = []
15 |
16 | def get_loader(self,
17 | batch_size,
18 | shuffle=False,
19 | drop_last=False,
20 | num_workers=0):
21 | return torch.utils.data.DataLoader(
22 | self,
23 | batch_size=batch_size,
24 | shuffle=shuffle,
25 | num_workers=num_workers,
26 | drop_last=drop_last,
27 | collate_fn=lambda i: i)
28 |
29 | def get_test_queries(self):
30 | return self.test_queries
31 |
32 | def get_all_texts(self):
33 | raise NotImplementedError
34 |
35 | def __getitem__(self, idx):
36 | return self.generate_random_query_target()
37 |
38 | def generate_random_query_target(self):
39 | raise NotImplementedError
40 |
41 | def get_img(self, idx, raw_img=False):
42 | raise NotImplementedError
43 |
44 |
45 | class Fashion200k(BaseDataset):
46 | """Fashion200k dataset."""
47 |
48 | def __init__(self, path, split='train', transform=None, target_transform=None):
49 | super(Fashion200k, self).__init__()
50 |
51 | self.split = split
52 | self.transform = transform
53 | self.target_transform = target_transform
54 | self.img_path = path + '/'
55 |
56 | # get label files for the split
57 | label_path = path + '/labels/'
58 | from os import listdir
59 | from os.path import isfile
60 | from os.path import join
61 | label_files = [
62 | f for f in listdir(label_path) if isfile(join(label_path, f))
63 | ]
64 | label_files = [f for f in label_files if split in f]
65 |
66 | # read image info from label files
67 | self.imgs = []
68 |
69 | def caption_post_process(s):
70 | return s.strip().replace('.',
71 | 'dotmark').replace('?', 'questionmark').replace(
72 | '&', 'andmark').replace('*', 'starmark')
73 |
74 | for filename in label_files:
75 | print('read ' + filename)
76 | with open(label_path + '/' + filename) as f:
77 | lines = f.readlines()
78 | for line in lines:
79 | line = line.split(' ')
80 | img = {
81 | 'file_path': line[0],
82 | 'detection_score': line[1],
83 | 'captions': [caption_post_process(line[2])],
84 | 'split': split,
85 | 'modifiable': False
86 | }
87 | self.imgs += [img]
88 | print('Fashion200k:', len(self.imgs), 'images')
89 |
90 | # generate query for training or testing
91 | if split == 'train':
92 | self.caption_index_init_()
93 | else:
94 | self.generate_test_queries_()
95 |
96 | def get_different_word(self, source_caption, target_caption):
97 | source_words = source_caption.split()
98 | target_words = target_caption.split()
99 | for source_word in source_words:
100 | if source_word not in target_words:
101 | break
102 | for target_word in target_words:
103 | if target_word not in source_words:
104 | break
105 | mod_str = 'replace ' + source_word + ' with ' + target_word
106 | return source_word, target_word, mod_str
107 |
108 | def generate_test_queries_(self):
109 | file2imgid = {}
110 | for i, img in enumerate(self.imgs):
111 | file2imgid[img['file_path']] = i
112 | with open(self.img_path + '/test_queries.txt') as f:
113 | lines = f.readlines()
114 | self.test_queries = []
115 | for line in lines:
116 | source_file, target_file = line.split()
117 | idx = file2imgid[source_file]
118 | target_idx = file2imgid[target_file]
119 | source_caption = self.imgs[idx]['captions'][0]
120 | target_caption = self.imgs[target_idx]['captions'][0]
121 | source_word, target_word, mod_str = self.get_different_word(
122 | source_caption, target_caption)
123 | self.test_queries += [{
124 | 'source_img_id': idx,
125 | 'source_caption': source_caption,
126 | 'target_caption': target_caption,
127 | 'mod': {
128 | 'str': mod_str
129 | }
130 | }]
131 |
132 | def caption_index_init_(self):
133 | """Index captions to generate training query-target examples on the fly."""
134 |
135 | # index caption 2 caption_id and caption 2 image_ids
136 | caption2id = {}
137 | id2caption = {}
138 | caption2imgids = {}
139 | for i, img in enumerate(self.imgs):
140 | for c in img['captions']:
141 | if c not in caption2id:
142 | id2caption[len(caption2id)] = c
143 | caption2id[c] = len(caption2id)
144 | caption2imgids[c] = []
145 | caption2imgids[c].append(i)
146 | self.caption2imgids = caption2imgids
147 | print(len(caption2imgids), 'unique captions')
148 |
149 | # parent captions are 1-word shorter than their children
150 | parent2children_captions = {}
151 | for c in caption2id.keys():
152 | for w in c.split():
153 | p = c.replace(w, '')
154 | p = p.replace('  ', ' ').strip()  # collapse the double space left by removing w
155 | if p not in parent2children_captions:
156 | parent2children_captions[p] = []
157 | if c not in parent2children_captions[p]:
158 | parent2children_captions[p].append(c)
159 | self.parent2children_captions = parent2children_captions
160 |
161 | # identify parent captions for each image
162 | for img in self.imgs:
163 | img['modifiable'] = False
164 | img['parent_captions'] = []
165 | for p in parent2children_captions:
166 | if len(parent2children_captions[p]) >= 2:
167 | for c in parent2children_captions[p]:
168 | for imgid in caption2imgids[c]:
169 | self.imgs[imgid]['modifiable'] = True
170 | self.imgs[imgid]['parent_captions'] += [p]
171 | num_modifiable_imgs = 0
172 | for img in self.imgs:
173 | if img['modifiable']:
174 | num_modifiable_imgs += 1
175 | print('Modifiable images', num_modifiable_imgs)
176 |
177 | def caption_index_sample_(self, idx):
178 | while not self.imgs[idx]['modifiable']:
179 | idx = np.random.randint(0, len(self.imgs))
180 |
181 | # find random target image (same parent)
182 | img = self.imgs[idx]
183 | while True:
184 | p = random.choice(img['parent_captions'])
185 | c = random.choice(self.parent2children_captions[p])
186 | if c not in img['captions']:
187 | break
188 | target_idx = random.choice(self.caption2imgids[c])
189 |
190 | # find the word difference between query and target (not in parent caption)
191 | source_caption = self.imgs[idx]['captions'][0]
192 | target_caption = self.imgs[target_idx]['captions'][0]
193 | source_word, target_word, mod_str = self.get_different_word(
194 | source_caption, target_caption)
195 | return idx, target_idx, source_word, target_word, mod_str
196 |
197 | def get_all_texts(self):
198 | texts = []
199 | for img in self.imgs:
200 | for c in img['captions']:
201 | texts.append(c)
202 | return texts
203 |
204 | def __len__(self):
205 | return len(self.imgs)
206 |
207 | def __getitem__(self, idx):
208 | idx, target_idx, source_word, target_word, mod_str = self.caption_index_sample_(
209 | idx)
210 | out = {}
211 | out['source_img_id'] = idx
212 | out['source_img_data'] = self.get_img(idx)
213 | out['source_caption'] = self.imgs[idx]['captions'][0]
214 | out['target_img_id'] = target_idx
215 | out['target_img_data'] = self.get_img(target_idx)
216 | out['target_caption'] = self.imgs[target_idx]['captions'][0]
217 | out['mod'] = {'str': mod_str}
218 | if self.target_transform:
219 | out['mod']['str'] = self.target_transform(mod_str)
220 | out['mod']['str'] = out['mod']['str'].squeeze(0)
221 | out['source_caption'] = self.target_transform(out['source_caption'])
222 | out['target_caption'] = self.target_transform(out['target_caption'])
223 |
224 | return out
225 |
226 | def get_img(self, idx, raw_img=False):
227 | img_path = self.img_path + self.imgs[idx]['file_path']
228 | with open(img_path, 'rb') as f:
229 | img = PIL.Image.open(f)
230 | img = img.convert('RGB')
231 |
232 | if self.transform:
233 | img = self.transform(img)
234 | return img
--------------------------------------------------------------------------------
/src/ds_lavis/fashion200k.py:
--------------------------------------------------------------------------------
1 | import os
2 | from PIL import Image
3 | import numpy as np
4 | from torch.utils.data import Dataset
5 | import random
6 | import PIL
7 | import torch  # used by BaseDataset.get_loader
8 | class BaseDataset(Dataset):
9 | """Base class for a dataset."""
10 |
11 | def __init__(self):
12 | super(BaseDataset, self).__init__()
13 | self.imgs = []
14 | self.test_queries = []
15 |
16 | def get_loader(self,
17 | batch_size,
18 | shuffle=False,
19 | drop_last=False,
20 | num_workers=0):
21 | return torch.utils.data.DataLoader(
22 | self,
23 | batch_size=batch_size,
24 | shuffle=shuffle,
25 | num_workers=num_workers,
26 | drop_last=drop_last,
27 | collate_fn=lambda i: i)
28 |
29 | def get_test_queries(self):
30 | return self.test_queries
31 |
32 | def get_all_texts(self):
33 | raise NotImplementedError
34 |
35 | def __getitem__(self, idx):
36 | return self.generate_random_query_target()
37 |
38 | def generate_random_query_target(self):
39 | raise NotImplementedError
40 |
41 | def get_img(self, idx, raw_img=False):
42 | raise NotImplementedError
43 |
44 |
45 | class Fashion200k(BaseDataset):
46 | """Fashion200k dataset."""
47 |
48 | def __init__(self, path, split='train', transform=None, target_transform=None):
49 | super(Fashion200k, self).__init__()
50 |
51 | self.split = split
52 | self.transform = transform
53 | self.target_transform = target_transform
54 | self.img_path = path + '/'
55 |
56 | # get label files for the split
57 | label_path = path + '/labels/'
58 | from os import listdir
59 | from os.path import isfile
60 | from os.path import join
61 | label_files = [
62 | f for f in listdir(label_path) if isfile(join(label_path, f))
63 | ]
64 | label_files = [f for f in label_files if split in f]
65 |
66 | # read image info from label files
67 | self.imgs = []
68 |
69 | def caption_post_process(s):
70 | return s.strip().replace('.',
71 | 'dotmark').replace('?', 'questionmark').replace(
72 | '&', 'andmark').replace('*', 'starmark')
73 |
74 | for filename in label_files:
75 | print('read ' + filename)
76 | with open(label_path + '/' + filename) as f:
77 | lines = f.readlines()
78 | for line in lines:
79 | line = line.split(' ')
80 | img = {
81 | 'file_path': line[0],
82 | 'detection_score': line[1],
83 | 'captions': [caption_post_process(line[2])],
84 | 'split': split,
85 | 'modifiable': False
86 | }
87 | self.imgs += [img]
88 | print('Fashion200k:', len(self.imgs), 'images')
89 |
90 | # generate query for training or testing
91 | if split == 'train':
92 | self.caption_index_init_()
93 | else:
94 | self.generate_test_queries_()
95 |
96 | def get_different_word(self, source_caption, target_caption):
97 | source_words = source_caption.split()
98 | target_words = target_caption.split()
99 | for source_word in source_words:
100 | if source_word not in target_words:
101 | break
102 | for target_word in target_words:
103 | if target_word not in source_words:
104 | break
105 | mod_str = 'replace ' + source_word + ' with ' + target_word
106 | return source_word, target_word, mod_str
107 |
108 | def generate_test_queries_(self):
109 | file2imgid = {}
110 | for i, img in enumerate(self.imgs):
111 | file2imgid[img['file_path']] = i
112 | with open(self.img_path + '/test_queries.txt') as f:
113 | lines = f.readlines()
114 | self.test_queries = []
115 | for line in lines:
116 | source_file, target_file = line.split()
117 | idx = file2imgid[source_file]
118 | target_idx = file2imgid[target_file]
119 | source_caption = self.imgs[idx]['captions'][0]
120 | target_caption = self.imgs[target_idx]['captions'][0]
121 | source_word, target_word, mod_str = self.get_different_word(
122 | source_caption, target_caption)
123 | self.test_queries += [{
124 | 'source_img_id': idx,
125 | 'source_caption': source_caption,
126 | 'target_caption': target_caption,
127 | 'mod': {
128 | 'str': mod_str
129 | }
130 | }]
131 |
132 | def caption_index_init_(self):
133 | """Index captions to generate training query-target examples on the fly."""
134 |
135 | # index caption 2 caption_id and caption 2 image_ids
136 | caption2id = {}
137 | id2caption = {}
138 | caption2imgids = {}
139 | for i, img in enumerate(self.imgs):
140 | for c in img['captions']:
141 | if c not in caption2id:
142 | id2caption[len(caption2id)] = c
143 | caption2id[c] = len(caption2id)
144 | caption2imgids[c] = []
145 | caption2imgids[c].append(i)
146 | self.caption2imgids = caption2imgids
147 | print(len(caption2imgids), 'unique captions')
148 |
149 | # parent captions are 1-word shorter than their children
150 | parent2children_captions = {}
151 | for c in caption2id.keys():
152 | for w in c.split():
153 | p = c.replace(w, '')
154 | p = p.replace('  ', ' ').strip()  # collapse the double space left by removing w
155 | if p not in parent2children_captions:
156 | parent2children_captions[p] = []
157 | if c not in parent2children_captions[p]:
158 | parent2children_captions[p].append(c)
159 | self.parent2children_captions = parent2children_captions
160 |
161 | # identify parent captions for each image
162 | for img in self.imgs:
163 | img['modifiable'] = False
164 | img['parent_captions'] = []
165 | for p in parent2children_captions:
166 | if len(parent2children_captions[p]) >= 2:
167 | for c in parent2children_captions[p]:
168 | for imgid in caption2imgids[c]:
169 | self.imgs[imgid]['modifiable'] = True
170 | self.imgs[imgid]['parent_captions'] += [p]
171 | num_modifiable_imgs = 0
172 | for img in self.imgs:
173 | if img['modifiable']:
174 |
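# Editor's note, illustrating the parent-caption indexing above with
# hypothetical captions: "black leather jacket" yields the parents
# "leather jacket", "black jacket" and "black leather"; if another image is
# captioned "black denim jacket", the two captions share the parent
# "black jacket", both images become modifiable, and get_different_word()
# later yields mod_str = 'replace leather with denim' for that training pair.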
num_modifiable_imgs += 1 175 | print('Modifiable images', num_modifiable_imgs) 176 | 177 | def caption_index_sample_(self, idx): 178 | while not self.imgs[idx]['modifiable']: 179 | idx = np.random.randint(0, len(self.imgs)) 180 | 181 | # find random target image (same parent) 182 | img = self.imgs[idx] 183 | while True: 184 | p = random.choice(img['parent_captions']) 185 | c = random.choice(self.parent2children_captions[p]) 186 | if c not in img['captions']: 187 | break 188 | target_idx = random.choice(self.caption2imgids[c]) 189 | 190 | # find the word difference between query and target (not in parent caption) 191 | source_caption = self.imgs[idx]['captions'][0] 192 | target_caption = self.imgs[target_idx]['captions'][0] 193 | source_word, target_word, mod_str = self.get_different_word( 194 | source_caption, target_caption) 195 | return idx, target_idx, source_word, target_word, mod_str 196 | 197 | def get_all_texts(self): 198 | texts = [] 199 | for img in self.imgs: 200 | for c in img['captions']: 201 | texts.append(c) 202 | return texts 203 | 204 | def __len__(self): 205 | return len(self.imgs) 206 | 207 | def __getitem__(self, idx): 208 | idx, target_idx, source_word, target_word, mod_str = self.caption_index_sample_( 209 | idx) 210 | out = {} 211 | out['source_img_id'] = idx 212 | out['source_img_data'] = self.get_img(idx) 213 | out['source_caption'] = self.imgs[idx]['captions'][0] 214 | out['target_img_id'] = target_idx 215 | out['target_img_data'] = self.get_img(target_idx) 216 | out['target_caption'] = self.imgs[target_idx]['captions'][0] 217 | out['mod'] = {'str': mod_str} 218 | if self.target_transform: 219 | out['mod']['str'] = self.target_transform(mod_str) 220 | out['mod']['str'] = out['mod']['str'].squeeze(0) 221 | out['source_caption'] = self.target_transform(out['source_caption']) 222 | out['target_caption'] = self.target_transform(out['target_caption']) 223 | 224 | return out 225 | 226 | def get_img(self, idx, raw_img=False): 227 | img_path = self.img_path + self.imgs[idx]['file_path'] 228 | with open(img_path, 'rb') as f: 229 | img = PIL.Image.open(f) 230 | img = img.convert('RGB') 231 | 232 | if self.transform: 233 | img = self.transform(img) 234 | return img -------------------------------------------------------------------------------- /src/ds/_transforms.py: -------------------------------------------------------------------------------- 1 | """Custom transform functions 2 | 3 | reference codes: 4 | https://github.com/yalesong/pvse/blob/master/data.py 5 | https://github.com/rwightman/pytorch-image-models/blob/master/timm/data/random_erasing.py 6 | """ 7 | 8 | from functools import partial 9 | 10 | from nltk.tokenize import word_tokenize 11 | 12 | import random 13 | import math 14 | from copy import deepcopy 15 | import torch 16 | from torchvision import transforms 17 | 18 | 19 | def imagenet_normalize(): 20 | """Standard ImageNet normalize transform 21 | """ 22 | # return transforms.Normalize(mean=[0.485, 0.456, 0.406], 23 | # std=[0.229, 0.224, 0.225]) 24 | return transforms.Normalize( 25 | mean=(0.48145466, 0.4578275, 0.40821073), 26 | std=(0.26862954, 0.26130258, 0.27577711)) 27 | 28 | 29 | 30 | 31 | def imagenet_transform(resize_size=224, 32 | crop_size=224, 33 | random_resize_crop=False, 34 | random_erasing_prob=0.0, 35 | custom_transforms=None): 36 | """Standard ImageNet transform with resize/crop/normalize. 37 | 38 | Args: 39 | resize_size (int, Default: 256): resize for validation 40 | (only used when random_resize_crop is False). 
41 | crop_size (int, Default: 224): final crop size.
42 | random_resize_crop (bool, Default: False): if True, use random transform (for training),
43 | if False, use center crop (for validation).
44 | custom_transforms (list of transform, Default: None): additional transforms.
45 | """
46 | if custom_transforms is not None:
47 | if not isinstance(custom_transforms, list):
48 | raise TypeError(f'custom_transforms should be list, not {type(custom_transforms)}')
49 | transform = []
50 | if random_resize_crop:
51 | transform.append(transforms.RandomResizedCrop(crop_size))
52 | transform.append(transforms.RandomHorizontalFlip())
53 | else:
54 | transform.append(transforms.Resize(resize_size))
55 | transform.append(transforms.CenterCrop(crop_size))
56 | transform.append(transforms.ToTensor())
57 | transform.append(imagenet_normalize())
58 |
59 | if custom_transforms:
60 | transform.extend(custom_transforms)
61 |
62 | # if random_erasing_prob > 0:
63 | # print(f'adding cutout {random_erasing_prob}')
64 | # transform.append(RandomErasing(random_erasing_prob,
65 | # mode='const',
66 | # max_count=1, num_splits=0, device='cpu'))
67 | transform.append(RandomErasing(random_erasing_prob,
68 | mode='const',
69 | max_count=1, num_splits=0, device='cpu'))
70 |
71 | transform = transforms.Compose(transform)
72 | print("Transform Called")
73 | return transform
74 |
75 |
76 | def tokenize(sentence, vocab, caption_drop_prob):
77 | """nltk word_tokenize for caption transform.
78 | """
79 | tokens = word_tokenize(str(sentence).lower())
80 | tokenized_sentence = []
81 | tokenized_sentence.append(vocab('<start>'))
82 | tokenized = [vocab(token) for token in tokens]
83 | if caption_drop_prob > 0:
84 | unk = vocab('<unk>')
85 | tokenized = [vocab(token) if random.random() > caption_drop_prob else unk for token in tokens]
86 | else:
87 | tokenized = [vocab(token) for token in tokens]
88 | if caption_drop_prob:
89 | N = int(len(tokenized) * caption_drop_prob)
90 | for _ in range(N):
91 | tokenized.pop(random.randrange(len(tokenized)))
92 | tokenized_sentence.extend(tokenized)
93 | tokenized_sentence.append(vocab('<end>'))
94 | return torch.Tensor(tokenized_sentence)
95 |
96 |
97 | def caption_transform(vocab, caption_drop_prob=0):
98 | """Transform for captions.
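Example (editor's sketch; ``vocab`` is assumed to be the repo's Vocabulary
callable mapping tokens to ids):
    >>> transform = caption_transform(vocab, caption_drop_prob=0.1)
    >>> ids = transform('A small bird with red wings.')
    # 1-D tensor: <start>, word ids (some replaced by <unk> or dropped), <end>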
99 | "caption drop augmentation" randomly alters the given input tokens as 100 | """ 101 | transform = [] 102 | if caption_drop_prob < 0 or caption_drop_prob is None: 103 | print('warning: wrong caption drop prob', caption_drop_prob, 'set to zero') 104 | caption_drop_prob = 0 105 | elif caption_drop_prob > 0: 106 | print('adding caption drop prob', caption_drop_prob) 107 | transform.append(partial(tokenize, vocab=vocab, caption_drop_prob=caption_drop_prob)) 108 | transform = transforms.Compose(transform) 109 | return transform 110 | 111 | 112 | def _get_pixels(per_pixel, rand_color, patch_size, dtype=torch.float32, device='cuda'): 113 | # NOTE I've seen CUDA illegal memory access errors being caused by the normal_() 114 | # paths, flip the order so normal is run on CPU if this becomes a problem 115 | # Issue has been fixed in master https://github.com/pytorch/pytorch/issues/19508 116 | if per_pixel: 117 | return torch.empty(patch_size, dtype=dtype, device=device).normal_() 118 | elif rand_color: 119 | return torch.empty((patch_size[0], 1, 1), dtype=dtype, device=device).normal_() 120 | else: 121 | return torch.zeros((patch_size[0], 1, 1), dtype=dtype, device=device) 122 | 123 | 124 | class RandomErasing: 125 | """ Randomly selects a rectangle region in an image and erases its pixels. 126 | 'Random Erasing Data Augmentation' by Zhong et al. 127 | See https://arxiv.org/pdf/1708.04896.pdf 128 | 129 | This variant of RandomErasing is intended to be applied to either a batch 130 | or single image tensor after it has been normalized by dataset mean and std. 131 | Args: 132 | probability: Probability that the Random Erasing operation will be performed. 133 | min_area: Minimum percentage of erased area wrt input image area. 134 | max_area: Maximum percentage of erased area wrt input image area. 135 | min_aspect: Minimum aspect ratio of erased area. 136 | mode: pixel color mode, one of 'const', 'rand', or 'pixel' 137 | 'const' - erase block is constant color of 0 for all channels 138 | 'rand' - erase block is same per-channel random (normal) color 139 | 'pixel' - erase block is per-pixel random (normal) color 140 | max_count: maximum number of erasing blocks per image, area per box is scaled by count. 141 | per-image count is randomly chosen between 1 and this value. 
142 | """ 143 | 144 | def __init__( 145 | self, 146 | probability=0.5, min_area=0.02, max_area=1 / 3, min_aspect=0.3, max_aspect=None, 147 | # probability=0.5, min_area=0.1, max_area=1 / 2, min_aspect=0.3, max_aspect=None, 148 | mode='const', min_count=1, max_count=None, num_splits=0, device='cuda'): 149 | self.probability = probability 150 | self.min_area = min_area 151 | self.max_area = max_area 152 | max_aspect = max_aspect or 1 / min_aspect 153 | self.log_aspect_ratio = (math.log(min_aspect), math.log(max_aspect)) 154 | self.min_count = min_count 155 | self.max_count = max_count or min_count 156 | self.num_splits = num_splits 157 | mode = mode.lower() 158 | self.rand_color = False 159 | self.per_pixel = False 160 | if mode == 'rand': 161 | self.rand_color = True # per block random normal 162 | elif mode == 'pixel': 163 | self.per_pixel = True # per pixel random normal 164 | else: 165 | assert not mode or mode == 'const' 166 | self.device = device 167 | 168 | def _erase(self, input, chan, img_h, img_w, dtype): 169 | if random.random() > self.probability: 170 | return input, 0 171 | img = deepcopy(input) 172 | area = img_h * img_w 173 | count = self.min_count if self.min_count == self.max_count else \ 174 | random.randint(self.min_count, self.max_count) 175 | for _ in range(count): 176 | for attempt in range(10): 177 | target_area = random.uniform(self.min_area, self.max_area) * area / count 178 | aspect_ratio = math.exp(random.uniform(*self.log_aspect_ratio)) 179 | h = int(round(math.sqrt(target_area * aspect_ratio))) 180 | w = int(round(math.sqrt(target_area / aspect_ratio))) 181 | if w < img_w and h < img_h: 182 | top = random.randint(0, img_h - h) 183 | left = random.randint(0, img_w - w) 184 | img[:, top:top + h, left:left + w] = _get_pixels( 185 | self.per_pixel, self.rand_color, (chan, h, w), 186 | dtype=dtype, device=self.device) 187 | break 188 | return img, 1 189 | 190 | def __call__(self, input): 191 | if len(input.size()) == 3: 192 | input_masked, is_masked = self._erase(input, *input.size(), input.dtype) 193 | else: 194 | raise ValueError("TODO") 195 | # batch_size, chan, img_h, img_w = input.size() 196 | # # skip first slice of batch if num_splits is set (for clean portion of samples) 197 | # batch_start = batch_size // self.num_splits if self.num_splits > 1 else 0 198 | # for i in range(batch_start, batch_size): 199 | # ?? 
= self._erase(input[i], chan, img_h, img_w, input.dtype) 200 | return input, input_masked, is_masked 201 | -------------------------------------------------------------------------------- /src/ds_lavis/_transforms.py: -------------------------------------------------------------------------------- 1 | """Custom transform functions 2 | 3 | reference codes: 4 | https://github.com/yalesong/pvse/blob/master/data.py 5 | https://github.com/rwightman/pytorch-image-models/blob/master/timm/data/random_erasing.py 6 | """ 7 | 8 | from functools import partial 9 | 10 | from nltk.tokenize import word_tokenize 11 | 12 | import random 13 | import math 14 | from copy import deepcopy 15 | import torch 16 | from torchvision import transforms 17 | 18 | 19 | def imagenet_normalize(): 20 | """Standard ImageNet normalize transform 21 | """ 22 | # return transforms.Normalize(mean=[0.485, 0.456, 0.406], 23 | # std=[0.229, 0.224, 0.225]) 24 | return transforms.Normalize( 25 | mean=(0.48145466, 0.4578275, 0.40821073), 26 | std=(0.26862954, 0.26130258, 0.27577711)) 27 | 28 | 29 | 30 | 31 | def imagenet_transform(resize_size=224, 32 | crop_size=224, 33 | random_resize_crop=False, 34 | random_erasing_prob=0.0, 35 | custom_transforms=None): 36 | """Standard ImageNet transform with resize/crop/normalize. 37 | 38 | Args: 39 | resize_size (int, Default: 256): resize for validation 40 | (only used when random_resize_crop is False). 41 | crop_size (int, Default: 224): final crop size. 42 | random_resize_crop (bool, Default: False): if True, use random transform (for training), 43 | if False, use center crop (for validation). 44 | custom_transforms (list of transform, Default: None): additional transforms. 45 | """ 46 | if custom_transforms is not None: 47 | if not isinstance(custom_transforms, list): 48 | raise TypeError(f'custom_transforms should be list, not {type(custom_transforms)}') 49 | transform = [] 50 | if random_resize_crop: 51 | transform.append(transforms.RandomResizedCrop(crop_size)) 52 | transform.append(transforms.RandomHorizontalFlip()) 53 | else: 54 | transform.append(transforms.Resize(resize_size)) 55 | transform.append(transforms.CenterCrop(crop_size)) 56 | transform.append(transforms.ToTensor()) 57 | transform.append(imagenet_normalize()) 58 | 59 | if custom_transforms: 60 | transform.extend(custom_transforms) 61 | 62 | # if random_erasing_prob > 0: 63 | # print(f'adding cutout {random_erasing_prob}') 64 | # transform.append(RandomErasing(random_erasing_prob, 65 | # mode='const', 66 | # max_count=1, num_splits=0, device='cpu')) 67 | transform.append(RandomErasing(random_erasing_prob, 68 | mode='const', 69 | max_count=1, num_splits=0, device='cpu')) 70 | 71 | transform = transforms.Compose(transform) 72 | print("Transform Called") 73 | return transform 74 | 75 | 76 | def tokenize(sentence, vocab, caption_drop_prob): 77 | """nltk word_tokenize for caption transform. 
78 | """ 79 | tokens = word_tokenize(str(sentence).lower()) 80 | tokenized_sentence = [] 81 | tokenized_sentence.append(vocab('')) 82 | tokenized = [vocab(token) for token in tokens] 83 | if caption_drop_prob > 0: 84 | unk = vocab('') 85 | tokenized = [vocab(token) if random.random() > caption_drop_prob else unk for token in tokens] 86 | else: 87 | tokenized = [vocab(token) for token in tokens] 88 | if caption_drop_prob: 89 | N = int(len(tokenized) * caption_drop_prob) 90 | for _ in range(N): 91 | tokenized.pop(random.randrange(len(tokenized))) 92 | tokenized_sentence.extend(tokenized) 93 | tokenized_sentence.append(vocab('')) 94 | return torch.Tensor(tokenized_sentence) 95 | 96 | 97 | def caption_transform(vocab, caption_drop_prob=0): 98 | """Transform for captions. 99 | "caption drop augmentation" randomly alters the given input tokens as 100 | """ 101 | transform = [] 102 | if caption_drop_prob < 0 or caption_drop_prob is None: 103 | print('warning: wrong caption drop prob', caption_drop_prob, 'set to zero') 104 | caption_drop_prob = 0 105 | elif caption_drop_prob > 0: 106 | print('adding caption drop prob', caption_drop_prob) 107 | transform.append(partial(tokenize, vocab=vocab, caption_drop_prob=caption_drop_prob)) 108 | transform = transforms.Compose(transform) 109 | return transform 110 | 111 | 112 | def _get_pixels(per_pixel, rand_color, patch_size, dtype=torch.float32, device='cuda'): 113 | # NOTE I've seen CUDA illegal memory access errors being caused by the normal_() 114 | # paths, flip the order so normal is run on CPU if this becomes a problem 115 | # Issue has been fixed in master https://github.com/pytorch/pytorch/issues/19508 116 | if per_pixel: 117 | return torch.empty(patch_size, dtype=dtype, device=device).normal_() 118 | elif rand_color: 119 | return torch.empty((patch_size[0], 1, 1), dtype=dtype, device=device).normal_() 120 | else: 121 | return torch.zeros((patch_size[0], 1, 1), dtype=dtype, device=device) 122 | 123 | 124 | class RandomErasing: 125 | """ Randomly selects a rectangle region in an image and erases its pixels. 126 | 'Random Erasing Data Augmentation' by Zhong et al. 127 | See https://arxiv.org/pdf/1708.04896.pdf 128 | 129 | This variant of RandomErasing is intended to be applied to either a batch 130 | or single image tensor after it has been normalized by dataset mean and std. 131 | Args: 132 | probability: Probability that the Random Erasing operation will be performed. 133 | min_area: Minimum percentage of erased area wrt input image area. 134 | max_area: Maximum percentage of erased area wrt input image area. 135 | min_aspect: Minimum aspect ratio of erased area. 136 | mode: pixel color mode, one of 'const', 'rand', or 'pixel' 137 | 'const' - erase block is constant color of 0 for all channels 138 | 'rand' - erase block is same per-channel random (normal) color 139 | 'pixel' - erase block is per-pixel random (normal) color 140 | max_count: maximum number of erasing blocks per image, area per box is scaled by count. 141 | per-image count is randomly chosen between 1 and this value. 
142 | """ 143 | 144 | def __init__( 145 | self, 146 | probability=0.5, min_area=0.02, max_area=1 / 3, min_aspect=0.3, max_aspect=None, 147 | # probability=0.5, min_area=0.1, max_area=1 / 2, min_aspect=0.3, max_aspect=None, 148 | mode='const', min_count=1, max_count=None, num_splits=0, device='cuda'): 149 | self.probability = probability 150 | self.min_area = min_area 151 | self.max_area = max_area 152 | max_aspect = max_aspect or 1 / min_aspect 153 | self.log_aspect_ratio = (math.log(min_aspect), math.log(max_aspect)) 154 | self.min_count = min_count 155 | self.max_count = max_count or min_count 156 | self.num_splits = num_splits 157 | mode = mode.lower() 158 | self.rand_color = False 159 | self.per_pixel = False 160 | if mode == 'rand': 161 | self.rand_color = True # per block random normal 162 | elif mode == 'pixel': 163 | self.per_pixel = True # per pixel random normal 164 | else: 165 | assert not mode or mode == 'const' 166 | self.device = device 167 | 168 | def _erase(self, input, chan, img_h, img_w, dtype): 169 | if random.random() > self.probability: 170 | return input, 0 171 | img = deepcopy(input) 172 | area = img_h * img_w 173 | count = self.min_count if self.min_count == self.max_count else \ 174 | random.randint(self.min_count, self.max_count) 175 | for _ in range(count): 176 | for attempt in range(10): 177 | target_area = random.uniform(self.min_area, self.max_area) * area / count 178 | aspect_ratio = math.exp(random.uniform(*self.log_aspect_ratio)) 179 | h = int(round(math.sqrt(target_area * aspect_ratio))) 180 | w = int(round(math.sqrt(target_area / aspect_ratio))) 181 | if w < img_w and h < img_h: 182 | top = random.randint(0, img_h - h) 183 | left = random.randint(0, img_w - w) 184 | img[:, top:top + h, left:left + w] = _get_pixels( 185 | self.per_pixel, self.rand_color, (chan, h, w), 186 | dtype=dtype, device=self.device) 187 | break 188 | return img, 1 189 | 190 | def __call__(self, input): 191 | if len(input.size()) == 3: 192 | input_masked, is_masked = self._erase(input, *input.size(), input.dtype) 193 | else: 194 | raise ValueError("TODO") 195 | # batch_size, chan, img_h, img_w = input.size() 196 | # # skip first slice of batch if num_splits is set (for clean portion of samples) 197 | # batch_start = batch_size // self.num_splits if self.num_splits > 1 else 0 198 | # for i in range(batch_start, batch_size): 199 | # ?? 
= self._erase(input[i], chan, img_h, img_w, input.dtype) 200 | return input, input_masked, is_masked 201 | -------------------------------------------------------------------------------- /src/clip/clip.py: -------------------------------------------------------------------------------- 1 | import hashlib 2 | import os 3 | import urllib 4 | import warnings 5 | from typing import Any, Union, List 6 | from pkg_resources import packaging 7 | 8 | import torch 9 | from PIL import Image 10 | from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize 11 | from tqdm import tqdm 12 | 13 | from .model import build_model 14 | from .simple_tokenizer import SimpleTokenizer as _Tokenizer 15 | 16 | try: 17 | from torchvision.transforms import InterpolationMode 18 | BICUBIC = InterpolationMode.BICUBIC 19 | except ImportError: 20 | BICUBIC = Image.BICUBIC 21 | 22 | 23 | if packaging.version.parse(torch.__version__) < packaging.version.parse("1.7.1"): 24 | warnings.warn("PyTorch version 1.7.1 or higher is recommended") 25 | 26 | 27 | __all__ = ["available_models", "load", "tokenize"] 28 | _tokenizer = _Tokenizer() 29 | 30 | _MODELS = { 31 | "RN50": "https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt", 32 | "RN101": "https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt", 33 | "RN50x4": "https://openaipublic.azureedge.net/clip/models/7e526bd135e493cef0776de27d5f42653e6b4c8bf9e0f653bb11773263205fdd/RN50x4.pt", 34 | "RN50x16": "https://openaipublic.azureedge.net/clip/models/52378b407f34354e150460fe41077663dd5b39c54cd0bfd2b27167a4a06ec9aa/RN50x16.pt", 35 | "RN50x64": "https://openaipublic.azureedge.net/clip/models/be1cfb55d75a9666199fb2206c106743da0f6468c9d327f3e0d0a543a9919d9c/RN50x64.pt", 36 | "ViT-B/32": "https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt", 37 | "ViT-B/16": "https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt", 38 | "ViT-L/14": "https://openaipublic.azureedge.net/clip/models/b8cca3fd41ae0c99ba7e8951adf17d267cdb84cd88be6f7c2e0eca1737a03836/ViT-L-14.pt", 39 | } 40 | 41 | 42 | def _download(url: str, root: str): 43 | os.makedirs(root, exist_ok=True) 44 | filename = os.path.basename(url) 45 | 46 | expected_sha256 = url.split("/")[-2] 47 | download_target = os.path.join(root, filename) 48 | 49 | if os.path.exists(download_target) and not os.path.isfile(download_target): 50 | raise RuntimeError(f"{download_target} exists and is not a regular file") 51 | 52 | if os.path.isfile(download_target): 53 | if hashlib.sha256(open(download_target, "rb").read()).hexdigest() == expected_sha256: 54 | return download_target 55 | else: 56 | warnings.warn(f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file") 57 | 58 | with urllib.request.urlopen(url) as source, open(download_target, "wb") as output: 59 | with tqdm(total=int(source.info().get("Content-Length")), ncols=80, unit='iB', unit_scale=True, unit_divisor=1024) as loop: 60 | while True: 61 | buffer = source.read(8192) 62 | if not buffer: 63 | break 64 | 65 | output.write(buffer) 66 | loop.update(len(buffer)) 67 | 68 | if hashlib.sha256(open(download_target, "rb").read()).hexdigest() != expected_sha256: 69 | raise RuntimeError(f"Model has been downloaded but the SHA256 checksum does not match") 70 | 71 | return 
download_target 72 | 73 | 74 | def _convert_image_to_rgb(image): 75 | return image.convert("RGB") 76 | 77 | 78 | def _transform(n_px): 79 | return Compose([ 80 | Resize(n_px, interpolation=BICUBIC), 81 | CenterCrop(n_px), 82 | _convert_image_to_rgb, 83 | ToTensor(), 84 | Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)), 85 | ]) 86 | 87 | 88 | def available_models() -> List[str]: 89 | """Returns the names of available CLIP models""" 90 | return list(_MODELS.keys()) 91 | 92 | 93 | def load(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu", jit: bool = False, download_root: str = None, 94 | loss_type: str = 'contrastive'): 95 | """Load a CLIP model 96 | 97 | Parameters 98 | ---------- 99 | name : str 100 | A model name listed by `clip.available_models()`, or the path to a model checkpoint containing the state_dict 101 | 102 | device : Union[str, torch.device] 103 | The device to put the loaded model 104 | 105 | jit : bool 106 | Whether to load the optimized JIT model or more hackable non-JIT model (default). 107 | 108 | download_root: str 109 | path to download the model files; by default, it uses "~/.cache/clip" 110 | 111 | Returns 112 | ------- 113 | model : torch.nn.Module 114 | The CLIP model 115 | 116 | preprocess : Callable[[PIL.Image], torch.Tensor] 117 | A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input 118 | """ 119 | if name in _MODELS: 120 | model_path = _download(_MODELS[name], download_root or os.path.expanduser("~/.cache/clip")) 121 | elif os.path.isfile(name): 122 | model_path = name 123 | else: 124 | raise RuntimeError(f"Model {name} not found; available models = {available_models()}") 125 | 126 | try: 127 | # loading JIT archive 128 | model = torch.jit.load(model_path, map_location=device if jit else "cpu").eval() 129 | state_dict = None 130 | except RuntimeError: 131 | # loading saved state dict 132 | if jit: 133 | warnings.warn(f"File {model_path} is not a JIT archive. 
Loading as a state dict instead") 134 | jit = False 135 | state_dict = torch.load(model_path, map_location="cpu") 136 | 137 | if not jit: 138 | model = build_model(state_dict or model.state_dict(), loss_type).to(device) 139 | if str(device) == "cpu": 140 | model.float() 141 | return model, _transform(model.visual.input_resolution) 142 | 143 | # patch the device names 144 | device_holder = torch.jit.trace(lambda: torch.ones([]).to(torch.device(device)), example_inputs=[]) 145 | device_node = [n for n in device_holder.graph.findAllNodes("prim::Constant") if "Device" in repr(n)][-1] 146 | 147 | def patch_device(module): 148 | try: 149 | graphs = [module.graph] if hasattr(module, "graph") else [] 150 | except RuntimeError: 151 | graphs = [] 152 | 153 | if hasattr(module, "forward1"): 154 | graphs.append(module.forward1.graph) 155 | 156 | for graph in graphs: 157 | for node in graph.findAllNodes("prim::Constant"): 158 | if "value" in node.attributeNames() and str(node["value"]).startswith("cuda"): 159 | node.copyAttributes(device_node) 160 | 161 | model.apply(patch_device) 162 | patch_device(model.encode_image) 163 | patch_device(model.encode_text) 164 | 165 | # patch dtype to float32 on CPU 166 | if str(device) == "cpu": 167 | float_holder = torch.jit.trace(lambda: torch.ones([]).float(), example_inputs=[]) 168 | float_input = list(float_holder.graph.findNode("aten::to").inputs())[1] 169 | float_node = float_input.node() 170 | 171 | def patch_float(module): 172 | try: 173 | graphs = [module.graph] if hasattr(module, "graph") else [] 174 | except RuntimeError: 175 | graphs = [] 176 | 177 | if hasattr(module, "forward1"): 178 | graphs.append(module.forward1.graph) 179 | 180 | for graph in graphs: 181 | for node in graph.findAllNodes("aten::to"): 182 | inputs = list(node.inputs()) 183 | for i in [1, 2]: # dtype can be the second or third argument to aten::to() 184 | if inputs[i].node()["value"] == 5: 185 | inputs[i].node().copyAttributes(float_node) 186 | 187 | model.apply(patch_float) 188 | patch_float(model.encode_image) 189 | patch_float(model.encode_text) 190 | 191 | model.float() 192 | 193 | return model, _transform(model.input_resolution.item()) 194 | 195 | 196 | def tokenize(texts: Union[str, List[str]], context_length: int = 77, truncate: bool = False) -> torch.LongTensor: 197 | """ 198 | Returns the tokenized representation of given input string(s) 199 | 200 | Parameters 201 | ---------- 202 | texts : Union[str, List[str]] 203 | An input string or a list of input strings to tokenize 204 | 205 | context_length : int 206 | The context length to use; all CLIP models use 77 as the context length 207 | 208 | truncate: bool 209 | Whether to truncate the text in case its encoding is longer than the context length 210 | 211 | Returns 212 | ------- 213 | A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length] 214 | """ 215 | if isinstance(texts, str): 216 | texts = [texts] 217 | 218 | sot_token = _tokenizer.encoder["<|startoftext|>"] 219 | eot_token = _tokenizer.encoder["<|endoftext|>"] 220 | all_tokens = [[sot_token] + _tokenizer.encode(text) + [eot_token] for text in texts] 221 | result = torch.zeros(len(all_tokens), context_length, dtype=torch.long) 222 | 223 | for i, tokens in enumerate(all_tokens): 224 | if len(tokens) > context_length: 225 | if truncate: 226 | tokens = tokens[:context_length] 227 | tokens[-1] = eot_token 228 | else: 229 | raise RuntimeError(f"Input {texts[i]} is too long for context length {context_length}") 230 | 
result[i, :len(tokens)] = torch.tensor(tokens) 231 | 232 | return result 233 | -------------------------------------------------------------------------------- /src/ds/coco.py: -------------------------------------------------------------------------------- 1 | """MS-COCO image-to-caption retrieval dataset code 2 | 3 | reference codes: 4 | https://github.com/pytorch/vision/blob/v0.2.2_branch/torchvision/datasets/coco.py 5 | https://github.com/yalesong/pvse/blob/master/data.py 6 | """ 7 | 8 | import os 9 | from os.path import join as ospj 10 | try: 11 | import ujson as json 12 | except ImportError: 13 | import json 14 | 15 | from PIL import Image 16 | from pycocotools.coco import COCO 17 | 18 | import torch 19 | from torch.utils.data import Dataset 20 | 21 | 22 | class CocoCaptionsCap(Dataset): 23 | """MS Coco Captions Dataset. 24 | Args: 25 | root (string): Root directory where images are downloaded to. 26 | annFile (string): Path to json annotation file. 27 | ids (list, optional): list of target caption ids 28 | extra_annFile (string, optional): Path to extra json annotation file (for training) 29 | extra_ids (list, optional): list of extra target caption ids (for training) 30 | transform (callable, optional): A function/transform that takes in a PIL image 31 | and returns a transformed version. E.g., ``transforms.ToTensor`` 32 | target_transform (callable, optional): A function/transform that takes in the 33 | target and transforms it. 34 | instance_annFile (str, optional): Path to instance annotation json (for PMRP computation) 35 | 36 | Example: 37 | .. code:: python 38 | import torchvision.datasets as dset 39 | import torchvision.transforms as transforms 40 | cap = dset.CocoCaptions(root='dir where images are', 41 | annFile='json annotation file', 42 | transform=transforms.ToTensor()) 43 | print('Number of samples: ', len(cap)) 44 | img, target = cap[3] # load 4th sample 45 | print("Image Size: ", img.size()) 46 | print(target) 47 | Output: :: 48 | Number of samples: 82783 49 | Image Size: (3L, 427L, 640L) 50 | [u'A plane emitting smoke stream flying over a mountain.', 51 | u'A plane darts across a bright blue sky behind a mountain covered in snow', 52 | u'A plane leaves a contrail above the snowy mountain top.', 53 | u'A mountain that has a plane flying overheard in the distance.', 54 | u'A mountain view with a plume of smoke in the background'] 55 | """ 56 | def __init__(self, root, annFile, ids=None, 57 | extra_annFile=None, extra_ids=None, 58 | transform=None, target_transform=None, 59 | instance_annFile=None): 60 | self.root = os.path.expanduser(root) 61 | if extra_annFile: 62 | self.coco = COCO() 63 | with open(annFile, 'r') as fin1, open(extra_annFile, 'r') as fin2: 64 | dataset = json.load(fin1) 65 | extra_dataset = json.load(fin2) 66 | if not isinstance(dataset, dict) or not isinstance(extra_dataset, dict): 67 | raise TypeError('invalid type {} {}'.format(type(dataset), 68 | type(extra_dataset))) 69 | if set(dataset.keys()) != set(extra_dataset.keys()): 70 | raise KeyError('key mismatch {} != {}'.format(list(dataset.keys()), 71 | list(extra_dataset.keys()))) 72 | for key in ['images', 'annotations']: 73 | dataset[key].extend(extra_dataset[key]) 74 | self.coco.dataset = dataset 75 | self.coco.createIndex() 76 | else: 77 | self.coco = COCO(annFile) 78 | self.ids = list(self.coco.anns.keys()) if ids is None else list(ids) 79 | if extra_ids is not None: 80 | self.ids += list(extra_ids) 81 | self.ids = [int(id_) for id_ in self.ids] 82 | self.transform = transform 83 | 
self.target_transform = target_transform 84 | 85 | self.all_image_ids = set([self.coco.loadAnns(annotation_id)[0]['image_id'] for annotation_id in self.ids]) 86 | 87 | iid_to_cls = {} 88 | if instance_annFile: 89 | with open(instance_annFile) as fin: 90 | instance_ann = json.load(fin) 91 | for ann in instance_ann['annotations']: 92 | image_id = int(ann['image_id']) 93 | code = iid_to_cls.get(image_id, [0] * 90) 94 | code[int(ann['category_id']) - 1] = 1 95 | iid_to_cls[image_id] = code 96 | 97 | seen_classes = {} 98 | new_iid_to_cls = {} 99 | idx = 0 100 | for k, v in iid_to_cls.items(): 101 | v = ''.join([str(s) for s in v]) 102 | if v in seen_classes: 103 | new_iid_to_cls[k] = seen_classes[v] 104 | else: 105 | new_iid_to_cls[k] = idx 106 | seen_classes[v] = idx 107 | idx += 1 108 | iid_to_cls = new_iid_to_cls 109 | 110 | if self.all_image_ids - set(iid_to_cls.keys()): 111 | print(f'Found mismatched! {self.all_image_ids - set(iid_to_cls.keys())}') 112 | 113 | self.iid_to_cls = iid_to_cls 114 | self.n_images = len(self.all_image_ids) 115 | 116 | def __getitem__(self, index, get_caption=False): 117 | """ 118 | Args: 119 | index (int): Index 120 | Returns: 121 | tuple: Tuple (image, target). target is a caption for the annotation. 122 | """ 123 | coco = self.coco 124 | annotation_id = self.ids[index] 125 | annotation = coco.loadAnns(annotation_id)[0] 126 | image_id = annotation['image_id'] 127 | # target = annotation['caption'] 128 | caption = annotation['caption'] 129 | 130 | path = coco.loadImgs(image_id)[0]['file_name'] 131 | 132 | img = Image.open(os.path.join(self.root, path)).convert('RGB') 133 | if self.transform is not None: 134 | img = self.transform(img) 135 | 136 | if self.target_transform is not None: 137 | # target = self.target_transform(target) 138 | target = self.target_transform(caption) 139 | target = target.squeeze(0) 140 | img_masked = img 141 | is_img_masked = False 142 | if get_caption: 143 | return img, target, caption, img_masked, is_img_masked 144 | else: 145 | return img, target, img_masked, is_img_masked 146 | # if get_caption: 147 | # return img, target, caption, annotation_id, image_id 148 | # else: 149 | # return img, target, annotation_id, image_id 150 | 151 | def __len__(self): 152 | return len(self.ids) 153 | 154 | 155 | class CocoBboxes(CocoCaptionsCap): 156 | def __init__(self, root, annFile, ids, extra_ids=None, extra_annFile=None, transform=None, target_transform=None, instanceFile=None): 157 | super().__init__(root, annFile, ids, extra_ids=extra_ids, extra_annFile=extra_annFile, transform=transform, target_transform=target_transform) 158 | dirname = os.path.dirname(annFile) 159 | self.coco_for_instance = COCO(instanceFile) 160 | 161 | categories_info = self.coco_for_instance.loadCats(self.coco_for_instance.getCatIds()) 162 | self.category_id2name = {info['id']: info['name'] for info in categories_info} 163 | 164 | def __getitem__(self, index, get_caption=False): 165 | """ 166 | Returns: 167 | bboxes (torch.tensor, size=(#bboxes, 4)): (x_left, y_top, width, height) 168 | """ 169 | coco = self.coco 170 | annotation_id = self.ids[index] 171 | annotation = coco.loadAnns(annotation_id)[0] 172 | image_id = annotation['image_id'] 173 | caption = annotation['caption'] 174 | 175 | path = coco.loadImgs(image_id)[0]['file_name'] 176 | 177 | img = Image.open(os.path.join(self.root, path)).convert('RGB') 178 | W, H = img.size 179 | if self.transform is not None: 180 | img, img_masked, is_img_masked = self.transform(img) 181 | 182 | if self.target_transform is not 
None: 183 | # target = self.target_transform(target) 184 | target = self.target_transform(caption) 185 | target = target.squeeze(0) 186 | 187 | # get bboxes 188 | bbox_ann_ids = self.coco_for_instance.getAnnIds(imgIds=[image_id]) 189 | bbox_anns = self.coco_for_instance.loadAnns(bbox_ann_ids) 190 | bboxes = torch.tensor([ann['bbox'] for ann in bbox_anns]) 191 | bbox_cats = [self.category_id2name[ann['category_id']] for ann in bbox_anns] 192 | if len(bboxes) == 0: 193 | bboxes = torch.tensor([[0., 0., 0., 0.]]) 194 | bbox_cats = ['none'] 195 | else: 196 | # bbox transform 197 | length_ratio = 224 / H if W > H else 224 / W 198 | bboxes *= length_ratio 199 | if W > H: 200 | bboxes[:, 0] -= ((W * length_ratio) - 224) / 2 201 | else: 202 | bboxes[:, 1] -= ((H * length_ratio) - 224) / 2 203 | x_right = torch.clamp(bboxes[:, 0] + bboxes[:,2], 0, 224) 204 | y_bottom = torch.clamp(bboxes[:, 1] + bboxes[:,3], 0, 224) 205 | bboxes[:, 0] = torch.clamp(bboxes[:, 0], 0, 224) 206 | bboxes[:, 1] = torch.clamp(bboxes[:, 1], 0, 224) 207 | bboxes[:, 2] = x_right - bboxes[:, 0] 208 | bboxes[:, 3] = y_bottom - bboxes[:, 1] 209 | is_object = (bboxes[:,2] > 0).logical_and(bboxes[:,3] > 0) 210 | bboxes = bboxes[is_object] 211 | bbox_cats = [cat for i, cat in enumerate(bbox_cats) if is_object[i].item()] 212 | 213 | if get_caption: 214 | return img, target, caption, bboxes, bbox_cats 215 | else: 216 | return img, target, bboxes 217 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | # This file may be used to create an environment using: 2 | # $ conda create --name <env> --file <this file> 3 | # platform: linux-64 4 | _libgcc_mutex=0.1=conda_forge 5 | _openmp_mutex=4.5=2_kmp_llvm 6 | aiohttp=3.8.1=pypi_0 7 | aiosignal=1.2.0=pypi_0 8 | albumentations=1.2.0=pyhd8ed1ab_0 9 | alsa-lib=1.2.6.1=h7f98852_0 10 | analytics-python=1.4.0=pypi_0 11 | anyio=3.6.1=pypi_0 12 | aom=3.3.0=h27087fc_1 13 | argon2-cffi=21.3.0=pypi_0 14 | argon2-cffi-bindings=21.2.0=pypi_0 15 | asttokens=2.0.5=pypi_0 16 | async-timeout=4.0.2=pypi_0 17 | attr=2.5.1=h166bdaf_0 18 | attrs=21.4.0=pypi_0 19 | babel=2.10.1=pypi_0 20 | backcall=0.2.0=pypi_0 21 | backoff=1.10.0=pypi_0 22 | bcrypt=3.2.2=pypi_0 23 | beautifulsoup4=4.11.1=pypi_0 24 | blas=1.0=mkl 25 | bleach=5.0.0=pypi_0 26 | blosc=1.21.1=h83bc5f7_3 27 | brotli=1.0.9=h166bdaf_7 28 | brotli-bin=1.0.9=h166bdaf_7 29 | brotlipy=0.7.0=py310h7f8727e_1002 30 | brunsli=0.1=h9c3ff4c_0 31 | bzip2=1.0.8=h7b6447c_0 32 | c-ares=1.18.1=h7f98852_0 33 | c-blosc2=2.2.0=h7a311fb_0 34 | ca-certificates=2022.6.15=ha878542_0 35 | cairo=1.16.0=ha61ee94_1011 36 | certifi=2022.6.15=py310hff52083_0 37 | cffi=1.15.0=py310hd667e15_1 38 | cfitsio=4.1.0=hd9d235c_0 39 | charls=2.3.4=h9c3ff4c_0 40 | charset-normalizer=2.0.4=pyhd3eb1b0_0 41 | click=8.1.3=pypi_0 42 | cloudpickle=2.1.0=pyhd8ed1ab_0 43 | cryptography=37.0.1=py310h9ce1e76_0 44 | cudatoolkit=10.2.89=hfd86e86_1 45 | cycler=0.11.0=pypi_0 46 | cytoolz=0.11.2=py310h5764c6d_2 47 | dask-core=2022.7.0=pyhd8ed1ab_0 48 | dbus=1.13.6=h5008d03_3 49 | debugpy=1.6.0=pypi_0 50 | decorator=5.1.1=pypi_0 51 | defusedxml=0.7.1=pypi_0 52 | entrypoints=0.4=pypi_0 53 | executing=0.8.3=pypi_0 54 | expat=2.4.8=h27087fc_0 55 | fastapi=0.78.0=pypi_0 56 | fastjsonschema=2.15.3=pypi_0 57 | ffmpeg=4.4.2=habc3f16_0 58 | ffmpy=0.3.0=pypi_0 59 | fftw=3.3.10=nompi_h77c792f_102 60 | fire=0.4.0=pypi_0 61 | font-ttf-dejavu-sans-mono=2.37=hab24e00_0 62 | 
font-ttf-inconsolata=3.000=h77eed37_0 63 | font-ttf-source-code-pro=2.038=h77eed37_0 64 | font-ttf-ubuntu=0.83=hab24e00_0 65 | fontconfig=2.14.0=h8e229c2_0 66 | fonts-conda-ecosystem=1=0 67 | fonts-conda-forge=1=0 68 | fonttools=4.33.3=pypi_0 69 | freeglut=3.2.2=h9c3ff4c_1 70 | freetype=2.11.0=h70c0345_0 71 | frozenlist=1.3.0=pypi_0 72 | fsspec=2022.5.0=pyhd8ed1ab_0 73 | ftfy=6.1.1=pypi_0 74 | gettext=0.19.8.1=h73d1719_1008 75 | giflib=5.2.1=h7b6447c_0 76 | glib=2.70.2=h780b84a_4 77 | glib-tools=2.70.2=h780b84a_4 78 | gmp=6.2.1=h295c915_3 79 | gnutls=3.7.6=hb5d6004_1 80 | gradio=3.0.24=pypi_0 81 | graphite2=1.3.13=h58526e2_1001 82 | gst-plugins-base=1.20.3=hf6a322e_0 83 | gstreamer=1.20.3=hd4edc92_0 84 | h11=0.12.0=pypi_0 85 | harfbuzz=4.4.1=hf9f4e7c_0 86 | hdf5=1.12.1=nompi_h2386368_104 87 | httpcore=0.15.0=pypi_0 88 | httpx=0.23.0=pypi_0 89 | icu=70.1=h27087fc_0 90 | idna=3.3=pyhd3eb1b0_0 91 | imagecodecs=2022.2.22=py310h3ac3b6e_6 92 | imageio=2.19.3=pyhcf75d05_0 93 | intel-openmp=2021.4.0=h06a4308_3561 94 | ipykernel=6.13.0=pypi_0 95 | ipython=8.4.0=pypi_0 96 | ipython-genutils=0.2.0=pypi_0 97 | jack=1.9.18=h8c3723f_1002 98 | jasper=2.0.33=ha77e612_0 99 | jedi=0.18.1=pypi_0 100 | jinja2=3.1.2=pypi_0 101 | joblib=1.1.0=pyhd8ed1ab_0 102 | jpeg=9e=h7f8727e_0 103 | json5=0.9.8=pypi_0 104 | jsonschema=4.6.0=pypi_0 105 | jupyter-client=7.3.1=pypi_0 106 | jupyter-core=4.10.0=pypi_0 107 | jupyter-server=1.17.0=pypi_0 108 | jupyterlab=3.4.2=pypi_0 109 | jupyterlab-pygments=0.2.2=pypi_0 110 | jupyterlab-server=2.14.0=pypi_0 111 | jxrlib=1.1=h7f98852_2 112 | keyutils=1.6.1=h166bdaf_0 113 | kiwisolver=1.4.2=pypi_0 114 | kornia=0.6.5=pypi_0 115 | krb5=1.19.3=h3790be6_0 116 | lame=3.100=h7b6447c_0 117 | lcms2=2.12=h3be6417_0 118 | ld_impl_linux-64=2.38=h1181459_1 119 | lerc=3.0=h9c3ff4c_0 120 | libaec=1.0.6=h9c3ff4c_0 121 | libavif=0.10.1=h166bdaf_0 122 | libblas=3.9.0=12_linux64_mkl 123 | libbrotlicommon=1.0.9=h166bdaf_7 124 | libbrotlidec=1.0.9=h166bdaf_7 125 | libbrotlienc=1.0.9=h166bdaf_7 126 | libcap=2.64=ha37c62d_0 127 | libcblas=3.9.0=12_linux64_mkl 128 | libclang=14.0.6=default_h2e3cab8_0 129 | libclang13=14.0.6=default_h3a83d3e_0 130 | libcups=2.3.3=hf5a7f15_1 131 | libcurl=7.83.1=h7bff187_0 132 | libdb=6.2.32=h9c3ff4c_0 133 | libdeflate=1.12=h166bdaf_0 134 | libdrm=2.4.112=h166bdaf_0 135 | libedit=3.1.20191231=he28a2e2_2 136 | libev=4.33=h516909a_1 137 | libevent=2.1.10=h9b69904_4 138 | libffi=3.4.2=h7f98852_5 139 | libflac=1.3.4=h27087fc_0 140 | libgcc-ng=12.1.0=h8d9b700_16 141 | libgfortran-ng=12.1.0=h69a702a_16 142 | libgfortran5=12.1.0=hdcd56e2_16 143 | libglib=2.70.2=h174f98d_4 144 | libglu=9.0.0=he1b5a44_1001 145 | libiconv=1.16=h7f8727e_2 146 | libidn2=2.3.2=h7f8727e_0 147 | liblapack=3.9.0=12_linux64_mkl 148 | liblapacke=3.9.0=12_linux64_mkl 149 | libllvm14=14.0.6=he0ac6c6_0 150 | libnghttp2=1.47.0=h727a467_0 151 | libnsl=2.0.0=h7f98852_0 152 | libogg=1.3.4=h7f98852_1 153 | libopencv=4.5.5=py310hcb97b83_13 154 | libopus=1.3.1=h7f98852_1 155 | libpciaccess=0.16=h516909a_0 156 | libpng=1.6.37=hbc83047_0 157 | libpq=14.4=hd77ab85_0 158 | libprotobuf=3.20.1=h6239696_0 159 | libsndfile=1.0.31=h9c3ff4c_1 160 | libssh2=1.10.0=ha56f1ee_2 161 | libstdcxx-ng=12.1.0=ha89aaad_16 162 | libtasn1=4.16.0=h27cfd23_0 163 | libtiff=4.4.0=hc85c160_1 164 | libtool=2.4.6=h9c3ff4c_1008 165 | libudev1=249=h166bdaf_4 166 | libunistring=0.9.10=h27cfd23_0 167 | libuuid=2.32.1=h7f98852_1000 168 | libuv=1.40.0=h7b6447c_0 169 | libva=2.15.0=h166bdaf_0 170 | libvorbis=1.3.7=h9c3ff4c_0 171 | 
libvpx=1.11.0=h9c3ff4c_3 172 | libwebp=1.2.2=h55f646e_0 173 | libwebp-base=1.2.2=h7f8727e_0 174 | libxcb=1.13=h7f98852_1004 175 | libxkbcommon=1.0.3=he3ba5ed_0 176 | libxml2=2.9.14=h22db469_3 177 | libzlib=1.2.12=h166bdaf_1 178 | libzopfli=1.0.3=h9c3ff4c_0 179 | linkify-it-py=1.0.3=pypi_0 180 | llvm-openmp=14.0.4=he0ac6c6_0 181 | locket=1.0.0=pyhd8ed1ab_0 182 | lz4-c=1.9.3=h295c915_1 183 | markdown-it-py=2.1.0=pypi_0 184 | markupsafe=2.1.1=pypi_0 185 | matplotlib=3.5.2=pypi_0 186 | matplotlib-inline=0.1.3=pypi_0 187 | mdit-py-plugins=0.3.0=pypi_0 188 | mdurl=0.1.1=pypi_0 189 | mistune=0.8.4=pypi_0 190 | mkl=2021.4.0=h06a4308_640 191 | mkl-service=2.4.0=py310h7f8727e_0 192 | mkl_fft=1.3.1=py310hd6ae3a3_0 193 | mkl_random=1.2.2=py310h00e6091_0 194 | mltk=0.0.5=pypi_0 195 | monotonic=1.6=pypi_0 196 | multidict=6.0.2=pypi_0 197 | munch=2.5.0=pypi_0 198 | mysql-common=8.0.29=haf5c9bc_1 199 | mysql-libs=8.0.29=h28c427c_1 200 | nbclassic=0.3.7=pypi_0 201 | nbclient=0.6.4=pypi_0 202 | nbconvert=6.5.0=pypi_0 203 | nbformat=5.4.0=pypi_0 204 | ncurses=6.3=h7f8727e_2 205 | nest-asyncio=1.5.5=pypi_0 206 | nettle=3.7.3=hbbd107a_1 207 | networkx=2.8.4=pyhd8ed1ab_0 208 | nltk=3.7=pypi_0 209 | notebook=6.4.11=pypi_0 210 | notebook-shim=0.1.0=pypi_0 211 | nspr=4.32=h9c3ff4c_1 212 | nss=3.78=h2350873_0 213 | ntk=1.1.3=pypi_0 214 | numpy=1.22.3=py310hfa59a62_0 215 | numpy-base=1.22.3=py310h9585f30_0 216 | opencv=4.5.5=py310hff52083_13 217 | opencv-python=4.6.0.66=pypi_0 218 | openh264=2.1.1=h4ff587b_0 219 | openjpeg=2.4.0=hb52868f_1 220 | openssl=1.1.1q=h166bdaf_0 221 | orjson=3.7.7=pypi_0 222 | packaging=21.3=pyhd8ed1ab_0 223 | pandas=1.4.2=pypi_0 224 | pandocfilters=1.5.0=pypi_0 225 | paramiko=2.11.0=pypi_0 226 | parso=0.8.3=pypi_0 227 | partd=1.2.0=pyhd8ed1ab_0 228 | pcre=8.45=h9c3ff4c_0 229 | pexpect=4.8.0=pypi_0 230 | pickleshare=0.7.5=pypi_0 231 | pillow=9.0.1=py310h22f2fdc_0 232 | pip=21.2.4=py310h06a4308_0 233 | pixman=0.40.0=h36c2ea0_0 234 | portaudio=19.6.0=h57a0ea0_5 235 | prometheus-client=0.14.1=pypi_0 236 | prompt-toolkit=3.0.29=pypi_0 237 | psutil=5.9.1=pypi_0 238 | pthread-stubs=0.4=h36c2ea0_1001 239 | ptyprocess=0.7.0=pypi_0 240 | pulseaudio=14.0=h7f54b18_8 241 | pure-eval=0.2.2=pypi_0 242 | py-opencv=4.5.5=py310hfdc917e_13 243 | pycocotools=2.0.4=pypi_0 244 | pycparser=2.21=pyhd3eb1b0_0 245 | pycryptodome=3.15.0=pypi_0 246 | pydantic=1.9.1=pypi_0 247 | pydub=0.25.1=pypi_0 248 | pygments=2.12.0=pypi_0 249 | pynacl=1.5.0=pypi_0 250 | pyopenssl=22.0.0=pyhd3eb1b0_0 251 | pyparsing=3.0.9=pyhd8ed1ab_0 252 | pyrsistent=0.18.1=pypi_0 253 | pysocks=1.7.1=py310h06a4308_0 254 | python=3.10.5=h582c2e5_0_cpython 255 | python-dateutil=2.8.2=pypi_0 256 | python-multipart=0.0.5=pypi_0 257 | python_abi=3.10=2_cp310 258 | pytorch=1.11.0=py3.10_cuda10.2_cudnn7.6.5_0 259 | pytorch-mutex=1.0=cuda 260 | pytz=2022.1=pypi_0 261 | pywavelets=1.3.0=py310hde88566_1 262 | pyyaml=6.0=py310h5764c6d_4 263 | pyzmq=23.1.0=pypi_0 264 | qt-main=5.15.4=ha5833f6_2 265 | qudida=0.0.4=pyhd8ed1ab_0 266 | readline=8.1.2=h7f8727e_1 267 | regex=2022.6.2=pypi_0 268 | requests=2.27.1=pyhd3eb1b0_0 269 | rfc3986=1.5.0=pypi_0 270 | scikit-image=0.19.3=py310h769672d_0 271 | scikit-learn=1.1.1=py310hffb9edd_0 272 | scipy=1.8.1=py310h7612f91_0 273 | seaborn=0.11.2=pypi_0 274 | send2trash=1.8.0=pypi_0 275 | setuptools=61.2.0=py310h06a4308_0 276 | six=1.16.0=pyhd3eb1b0_1 277 | snappy=1.1.9=hbd366e4_1 278 | sniffio=1.2.0=pypi_0 279 | soupsieve=2.3.2.post1=pypi_0 280 | sqlite=3.39.0=h4ff8645_0 281 | stack-data=0.2.0=pypi_0 282 | 
starlette=0.19.1=pypi_0 283 | svt-av1=1.1.0=h27087fc_1 284 | termcolor=1.1.0=pypi_0 285 | terminado=0.15.0=pypi_0 286 | threadpoolctl=3.1.0=pyh8a188c0_0 287 | tifffile=2022.5.4=pyhd8ed1ab_0 288 | tinycss2=1.1.1=pypi_0 289 | tk=8.6.12=h1ccaba5_0 290 | toolz=0.11.2=pyhd8ed1ab_0 291 | torchaudio=0.11.0=py310_cu102 292 | torchvision=0.12.0=py310_cu102 293 | tornado=6.1=pypi_0 294 | tqdm=4.64.0=pypi_0 295 | traitlets=5.2.2.post1=pypi_0 296 | typing-extensions=4.1.1=hd3eb1b0_0 297 | typing_extensions=4.1.1=pyh06a4308_0 298 | tzdata=2022a=hda174b7_0 299 | uc-micro-py=1.0.1=pypi_0 300 | urllib3=1.26.9=py310h06a4308_0 301 | uvicorn=0.18.2=pypi_0 302 | wcwidth=0.2.5=pypi_0 303 | webencodings=0.5.1=pypi_0 304 | websocket-client=1.3.2=pypi_0 305 | wheel=0.37.1=pyhd3eb1b0_0 306 | x264=1!161.3030=h7f98852_1 307 | x265=3.5=h924138e_3 308 | xcb-util=0.4.0=h166bdaf_0 309 | xcb-util-image=0.4.0=h166bdaf_0 310 | xcb-util-keysyms=0.4.0=h166bdaf_0 311 | xcb-util-renderutil=0.3.9=h166bdaf_0 312 | xcb-util-wm=0.4.1=h166bdaf_0 313 | xorg-fixesproto=5.0=h7f98852_1002 314 | xorg-inputproto=2.3.2=h7f98852_1002 315 | xorg-kbproto=1.0.7=h7f98852_1002 316 | xorg-libice=1.0.10=h7f98852_0 317 | xorg-libsm=1.2.3=hd9c2040_1000 318 | xorg-libx11=1.7.2=h7f98852_0 319 | xorg-libxau=1.0.9=h7f98852_0 320 | xorg-libxdmcp=1.1.3=h7f98852_0 321 | xorg-libxext=1.3.4=h7f98852_1 322 | xorg-libxfixes=5.0.3=h7f98852_1004 323 | xorg-libxi=1.7.10=h7f98852_0 324 | xorg-libxrender=0.9.10=h7f98852_1003 325 | xorg-renderproto=0.11.1=h7f98852_1002 326 | xorg-xextproto=7.3.0=h7f98852_1002 327 | xorg-xproto=7.0.31=h7f98852_1007 328 | xz=5.2.5=h7f8727e_1 329 | yaml=0.2.5=h7f98852_2 330 | yarl=1.7.2=pypi_0 331 | zfp=0.5.5=h9c3ff4c_8 332 | zlib=1.2.12=h166bdaf_1 333 | zlib-ng=2.0.6=h166bdaf_0 334 | zstd=1.5.2=ha4553b6_0 335 | -------------------------------------------------------------------------------- /src/ds_lavis/coco.py: -------------------------------------------------------------------------------- 1 | """MS-COCO image-to-caption retrieval dataset code 2 | 3 | reference codes: 4 | https://github.com/pytorch/vision/blob/v0.2.2_branch/torchvision/datasets/coco.py 5 | https://github.com/yalesong/pvse/blob/master/data.py 6 | """ 7 | 8 | import os 9 | from os.path import join as ospj 10 | try: 11 | import ujson as json 12 | except ImportError: 13 | import json 14 | 15 | from PIL import Image 16 | from pycocotools.coco import COCO 17 | 18 | import torch 19 | from torch.utils.data import Dataset 20 | 21 | from lavis.models import load_model_and_preprocess 22 | device='cuda' 23 | model, vis_processors, txt_processors = load_model_and_preprocess(name="blip_feature_extractor", model_type="base", is_eval=True, device=device) 24 | 25 | 26 | class CocoCaptionsCap(Dataset): 27 | """MS Coco Captions Dataset. 28 | Args: 29 | root (string): Root directory where images are downloaded to. 30 | annFile (string): Path to json annotation file. 31 | ids (list, optional): list of target caption ids 32 | extra_annFile (string, optional): Path to extra json annotation file (for training) 33 | extra_ids (list, optional): list of extra target caption ids (for training) 34 | transform (callable, optional): A function/transform that takes in a PIL image 35 | and returns a transformed version. E.g., ``transforms.ToTensor`` 36 | target_transform (callable, optional): A function/transform that takes in the 37 | target and transforms it. 38 | instance_annFile (str, optional): Path to instance annotation json (for PMRP computation) 39 | 40 | Example: 41 | .. 
code:: python 42 | import torchvision.datasets as dset 43 | import torchvision.transforms as transforms 44 | cap = dset.CocoCaptions(root='dir where images are', 45 | annFile='json annotation file', 46 | transform=transforms.ToTensor()) 47 | print('Number of samples: ', len(cap)) 48 | img, target = cap[3] # load 4th sample 49 | print("Image Size: ", img.size()) 50 | print(target) 51 | Output: :: 52 | Number of samples: 82783 53 | Image Size: (3L, 427L, 640L) 54 | [u'A plane emitting smoke stream flying over a mountain.', 55 | u'A plane darts across a bright blue sky behind a mountain covered in snow', 56 | u'A plane leaves a contrail above the snowy mountain top.', 57 | u'A mountain that has a plane flying overheard in the distance.', 58 | u'A mountain view with a plume of smoke in the background'] 59 | """ 60 | def __init__(self, root, annFile, ids=None, 61 | extra_annFile=None, extra_ids=None, 62 | transform=None, target_transform=None, 63 | instance_annFile=None): 64 | self.root = os.path.expanduser(root) 65 | if extra_annFile: 66 | self.coco = COCO() 67 | with open(annFile, 'r') as fin1, open(extra_annFile, 'r') as fin2: 68 | dataset = json.load(fin1) 69 | extra_dataset = json.load(fin2) 70 | if not isinstance(dataset, dict) or not isinstance(extra_dataset, dict): 71 | raise TypeError('invalid type {} {}'.format(type(dataset), 72 | type(extra_dataset))) 73 | if set(dataset.keys()) != set(extra_dataset.keys()): 74 | raise KeyError('key mismatch {} != {}'.format(list(dataset.keys()), 75 | list(extra_dataset.keys()))) 76 | for key in ['images', 'annotations']: 77 | dataset[key].extend(extra_dataset[key]) 78 | self.coco.dataset = dataset 79 | self.coco.createIndex() 80 | else: 81 | self.coco = COCO(annFile) 82 | self.ids = list(self.coco.anns.keys()) if ids is None else list(ids) 83 | if extra_ids is not None: 84 | self.ids += list(extra_ids) 85 | self.ids = [int(id_) for id_ in self.ids] 86 | self.transform = transform 87 | self.target_transform = target_transform 88 | 89 | self.all_image_ids = set([self.coco.loadAnns(annotation_id)[0]['image_id'] for annotation_id in self.ids]) 90 | 91 | iid_to_cls = {} 92 | if instance_annFile: 93 | with open(instance_annFile) as fin: 94 | instance_ann = json.load(fin) 95 | for ann in instance_ann['annotations']: 96 | image_id = int(ann['image_id']) 97 | code = iid_to_cls.get(image_id, [0] * 90) 98 | code[int(ann['category_id']) - 1] = 1 99 | iid_to_cls[image_id] = code 100 | 101 | seen_classes = {} 102 | new_iid_to_cls = {} 103 | idx = 0 104 | for k, v in iid_to_cls.items(): 105 | v = ''.join([str(s) for s in v]) 106 | if v in seen_classes: 107 | new_iid_to_cls[k] = seen_classes[v] 108 | else: 109 | new_iid_to_cls[k] = idx 110 | seen_classes[v] = idx 111 | idx += 1 112 | iid_to_cls = new_iid_to_cls 113 | 114 | if self.all_image_ids - set(iid_to_cls.keys()): 115 | print(f'Found mismatched! {self.all_image_ids - set(iid_to_cls.keys())}') 116 | 117 | self.iid_to_cls = iid_to_cls 118 | self.n_images = len(self.all_image_ids) 119 | 120 | def __getitem__(self, index, get_caption=False): 121 | """ 122 | Args: 123 | index (int): Index 124 | Returns: 125 | tuple: Tuple (image, target). target is a caption for the annotation. 
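Unlike the plain ``src/ds`` variant, the image is preprocessed with the module-level LAVIS BLIP ``vis_processors`` and the caption is additionally passed through ``txt_processors``.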
126 | """ 127 | coco = self.coco 128 | annotation_id = self.ids[index] 129 | annotation = coco.loadAnns(annotation_id)[0] 130 | image_id = annotation['image_id'] 131 | # target = annotation['caption'] 132 | caption = annotation['caption'] 133 | target = caption 134 | 135 | path = coco.loadImgs(image_id)[0]['file_name'] 136 | 137 | # print('dbg', caption) 138 | 139 | img = Image.open(os.path.join(self.root, path)).convert('RGB') 140 | # print('dbg', img, target) 141 | # if self.transform is not None: 142 | # img = self.transform(img) 143 | 144 | img = vis_processors["eval"](img) 145 | # print('dbg', img, target) 146 | 147 | if self.target_transform is not None: 148 | # target = self.target_transform(target) 149 | target = self.target_transform(caption) 150 | target = target.squeeze(0) 151 | # print('dbg', target) 152 | target = txt_processors["eval"](target) 153 | img_masked = img 154 | is_img_masked = False 155 | 156 | if get_caption: 157 | return img, target, caption, img_masked, is_img_masked 158 | else: 159 | return img, target, img_masked, is_img_masked 160 | # if get_caption: 161 | # return img, target, caption, annotation_id, image_id 162 | # else: 163 | # return img, target, annotation_id, image_id 164 | 165 | def __len__(self): 166 | return len(self.ids) 167 | 168 | 169 | class CocoBboxes(CocoCaptionsCap): 170 | def __init__(self, root, annFile, ids, extra_ids=None, extra_annFile=None, transform=None, target_transform=None, instanceFile=None): 171 | super().__init__(root, annFile, ids, extra_ids=extra_ids, extra_annFile=extra_annFile, transform=transform, target_transform=target_transform) 172 | dirname = os.path.dirname(annFile) 173 | self.coco_for_instance = COCO(instanceFile) 174 | 175 | categories_info = self.coco_for_instance.loadCats(self.coco_for_instance.getCatIds()) 176 | self.category_id2name = {info['id']: info['name'] for info in categories_info} 177 | 178 | def __getitem__(self, index, get_caption=False): 179 | """ 180 | Returns: 181 | bboxes (torch.tensor, size=(#bboxes, 4)): (x_left, y_top, width, height) 182 | """ 183 | coco = self.coco 184 | annotation_id = self.ids[index] 185 | annotation = coco.loadAnns(annotation_id)[0] 186 | image_id = annotation['image_id'] 187 | caption = annotation['caption'] 188 | 189 | path = coco.loadImgs(image_id)[0]['file_name'] 190 | 191 | img = Image.open(os.path.join(self.root, path)).convert('RGB') 192 | W, H = img.size 193 | if self.transform is not None: 194 | img, img_masked, is_img_masked = self.transform(img) 195 | 196 | if self.target_transform is not None: 197 | # target = self.target_transform(target) 198 | target = self.target_transform(caption) 199 | target = target.squeeze(0) 200 | 201 | # get bboxes 202 | bbox_ann_ids = self.coco_for_instance.getAnnIds(imgIds=[image_id]) 203 | bbox_anns = self.coco_for_instance.loadAnns(bbox_ann_ids) 204 | bboxes = torch.tensor([ann['bbox'] for ann in bbox_anns]) 205 | bbox_cats = [self.category_id2name[ann['category_id']] for ann in bbox_anns] 206 | if len(bboxes) == 0: 207 | bboxes = torch.tensor([[0., 0., 0., 0.]]) 208 | bbox_cats = ['none'] 209 | else: 210 | # bbox transform 211 | length_ratio = 224 / H if W > H else 224 / W 212 | bboxes *= length_ratio 213 | if W > H: 214 | bboxes[:, 0] -= ((W * length_ratio) - 224) / 2 215 | else: 216 | bboxes[:, 1] -= ((H * length_ratio) - 224) / 2 217 | x_right = torch.clamp(bboxes[:, 0] + bboxes[:,2], 0, 224) 218 | y_bottom = torch.clamp(bboxes[:, 1] + bboxes[:,3], 0, 224) 219 | bboxes[:, 0] = torch.clamp(bboxes[:, 0], 0, 224) 220 | bboxes[:, 1] 
= torch.clamp(bboxes[:, 1], 0, 224) 221 | bboxes[:, 2] = x_right - bboxes[:, 0] 222 | bboxes[:, 3] = y_bottom - bboxes[:, 1] 223 | is_object = (bboxes[:,2] > 0).logical_and(bboxes[:,3] > 0) 224 | bboxes = bboxes[is_object] 225 | bbox_cats = [cat for i, cat in enumerate(bbox_cats) if is_object[i].item()] 226 | 227 | if get_caption: 228 | return img, target, caption, bboxes, bbox_cats 229 | else: 230 | return img, target, bboxes 231 | --------------------------------------------------------------------------------
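A minimal usage sketch for the vendored CLIP wrapper in src/clip/clip.py above, assuming src/ is on the Python path; "example.jpg" and the two prompt strings are placeholders, and load()'s extra loss_type argument is left at its 'contrastive' default:

import torch
from PIL import Image

import clip  # the vendored package under src/clip

device = "cuda" if torch.cuda.is_available() else "cpu"
# First call downloads the checkpoint to ~/.cache/clip and returns the model
# together with the matching preprocessing transform built by _transform().
model, preprocess = clip.load("ViT-B/32", device=device)

image = preprocess(Image.open("example.jpg")).unsqueeze(0).to(device)  # placeholder image path
text = clip.tokenize(["a photo of a dog", "a photo of a cat"]).to(device)

with torch.no_grad():
    image_features = model.encode_image(image)  # shape (1, embed_dim)
    text_features = model.encode_text(text)     # shape (2, embed_dim)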