├── .gitignore ├── CODE_OF_CONDUCT.md ├── CONTRIBUTING.md ├── GETTING_STARTED.md ├── INSTALL.md ├── LICENSE ├── README.md ├── configs ├── ovseg_R101c_vitB_bs32_120k.yaml ├── ovseg_swinB_vitL_bs32_120k.yaml └── ovseg_swinB_vitL_demo.yaml ├── datasets ├── DATASETS.md ├── prepare_ade20k_full_sem_seg.py ├── prepare_ade20k_sem_seg.py ├── prepare_coco_stuff_sem_seg.py ├── prepare_pascal_context.py └── prepare_voc_sem_seg.py ├── demo.py ├── open_clip_training ├── CITATION.cff ├── HISTORY.md ├── LICENSE ├── MANIFEST.in ├── Makefile ├── README.md ├── docs │ ├── CLIP.png │ ├── Interacting_with_open_clip.ipynb │ ├── clip_conceptual_captions.md │ ├── clip_loss.png │ ├── clip_recall.png │ ├── clip_val_loss.png │ ├── clip_zeroshot.png │ ├── effective_robustness.png │ ├── laion2b_clip_zeroshot_b32.png │ ├── laion_clip_zeroshot.png │ ├── laion_clip_zeroshot_b16.png │ ├── laion_clip_zeroshot_b16_plus_240.png │ ├── laion_clip_zeroshot_l14.png │ ├── laion_openai_compare_b32.jpg │ └── scaling.png ├── requirements-test.txt ├── requirements-training.txt ├── requirements.txt ├── setup.py ├── src │ ├── data │ │ └── gather_cc.py │ ├── open_clip │ │ ├── __init__.py │ │ ├── __pycache__ │ │ │ ├── __init__.cpython-38.pyc │ │ │ ├── factory.cpython-38.pyc │ │ │ ├── loss.cpython-38.pyc │ │ │ ├── model.cpython-38.pyc │ │ │ ├── openai.cpython-38.pyc │ │ │ ├── pretrained.cpython-38.pyc │ │ │ ├── timm_model.cpython-38.pyc │ │ │ ├── tokenizer.cpython-38.pyc │ │ │ ├── transform.cpython-38.pyc │ │ │ └── utils.cpython-38.pyc │ │ ├── bpe_simple_vocab_16e6.txt.gz │ │ ├── factory.py │ │ ├── loss.py │ │ ├── model.py │ │ ├── model_configs │ │ │ ├── RN101-quickgelu.json │ │ │ ├── RN101.json │ │ │ ├── RN50-quickgelu.json │ │ │ ├── RN50.json │ │ │ ├── RN50x16.json │ │ │ ├── RN50x4.json │ │ │ ├── ViT-B-16-plus-240.json │ │ │ ├── ViT-B-16-plus.json │ │ │ ├── ViT-B-16.json │ │ │ ├── ViT-B-32-plus-256.json │ │ │ ├── ViT-B-32-quickgelu.json │ │ │ ├── ViT-B-32.json │ │ │ ├── ViT-H-14.json │ │ │ ├── ViT-H-16.json │ │ │ ├── ViT-L-14-280.json │ │ │ ├── ViT-L-14-336.json │ │ │ ├── ViT-L-14.json │ │ │ ├── ViT-L-16-320.json │ │ │ ├── ViT-L-16.json │ │ │ ├── ViT-g-14.json │ │ │ ├── timm-efficientnetv2_rw_s.json │ │ │ ├── timm-resnet50d.json │ │ │ ├── timm-resnetaa50d.json │ │ │ ├── timm-resnetblur50.json │ │ │ ├── timm-swin_base_patch4_window7_224.json │ │ │ ├── timm-vit_base_patch16_224.json │ │ │ ├── timm-vit_base_patch32_224.json │ │ │ └── timm-vit_small_patch16_224.json │ │ ├── openai.py │ │ ├── pretrained.py │ │ ├── timm_model.py │ │ ├── tokenizer.py │ │ ├── transform.py │ │ ├── utils.py │ │ └── version.py │ ├── scripts │ │ ├── coco_gt_171cls_finetune_VitL.sh │ │ ├── coco_proposal_1cap_finetune_VitB.sh │ │ ├── coco_proposal_1cap_finetune_VitL.sh │ │ ├── coco_proposal_1cap_mask_prompt_tuning_VitL.sh │ │ └── launch.sh │ └── training │ │ ├── .gitignore │ │ ├── __init__.py │ │ ├── ade150_zeroshot_data.py │ │ ├── data.py │ │ ├── distributed.py │ │ ├── imagenet_zeroshot_data.py │ │ ├── logger.py │ │ ├── main.py │ │ ├── main_mask_prompt_tuning.py │ │ ├── params.py │ │ ├── plot_transforms.ipynb │ │ ├── scheduler.py │ │ ├── train.py │ │ └── zero_shot.py └── tests │ └── test_simple.py ├── open_vocab_seg ├── __init__.py ├── config.py ├── data │ ├── __init__.py │ ├── augmentations.py │ ├── build.py │ ├── dataset_mappers │ │ ├── __init__.py │ │ └── mask_former_semantic_dataset_mapper.py │ └── datasets │ │ ├── __init__.py │ │ ├── csv_data.py │ │ ├── register_ade20k_full.py │ │ ├── register_cc3m.py │ │ ├── register_coco_stuff.py │ │ ├── 
register_pascal_context.py │ │ └── register_voc_seg.py ├── evaluation │ ├── __init__.py │ └── generalized_sem_seg_evaluation.py ├── mask_former_model.py ├── modeling │ ├── __init__.py │ ├── backbone │ │ ├── __init__.py │ │ ├── clip_resnet.py │ │ └── swin.py │ ├── clip_adapter │ │ ├── __init__.py │ │ ├── adapter.py │ │ ├── text_template.py │ │ └── utils.py │ ├── criterion.py │ ├── heads │ │ ├── __init__.py │ │ ├── mask_former_head.py │ │ ├── open_vocab_mask_former_head.py │ │ └── pixel_decoder.py │ ├── matcher.py │ └── transformer │ │ ├── __init__.py │ │ ├── open_vocab_transformer_predictor.py │ │ ├── position_encoding.py │ │ ├── transformer.py │ │ └── transformer_predictor.py ├── ovseg_model.py ├── test_time_augmentation.py └── utils │ ├── __init__.py │ ├── events.py │ ├── misc.py │ ├── post_process_utils.py │ └── predictor.py ├── requirements.txt ├── resources ├── demo_samples │ └── sample_03.jpeg ├── ovseg.gif ├── proposal.png └── pytorch-logo-dark.png ├── third_party └── CLIP │ ├── .gitignore │ ├── CLIP.png │ ├── LICENSE │ ├── MANIFEST.in │ ├── README.md │ ├── clip │ ├── __init__.py │ ├── bpe_simple_vocab_16e6.txt.gz │ ├── clip.py │ ├── model.py │ └── simple_tokenizer.py │ ├── model-card.md │ ├── requirements.txt │ ├── setup.py │ └── tests │ └── test_consistency.py ├── tools ├── convert-pretrained-clip-model-to-d2.py ├── convert-pretrained-swin-model-to-d2.py ├── convert-torchvision-to-d2.py ├── ovseg_replace_clip.py ├── sanity_check_ft_clip_weights.py ├── search_thr_ensemble_w.sh └── web_demo.py └── train_net.py /.gitignore: -------------------------------------------------------------------------------- 1 | open_clip_training/openclip_data/* 2 | open_clip_training/src/logs/* 3 | *__pycache__* -------------------------------------------------------------------------------- /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | # Code of Conduct 2 | 3 | ## Our Pledge 4 | 5 | In the interest of fostering an open and welcoming environment, we as 6 | contributors and maintainers pledge to make participation in our project and 7 | our community a harassment-free experience for everyone, regardless of age, body 8 | size, disability, ethnicity, sex characteristics, gender identity and expression, 9 | level of experience, education, socio-economic status, nationality, personal 10 | appearance, race, religion, or sexual identity and orientation. 
11 | 12 | ## Our Standards 13 | 14 | Examples of behavior that contributes to creating a positive environment 15 | include: 16 | 17 | * Using welcoming and inclusive language 18 | * Being respectful of differing viewpoints and experiences 19 | * Gracefully accepting constructive criticism 20 | * Focusing on what is best for the community 21 | * Showing empathy towards other community members 22 | 23 | Examples of unacceptable behavior by participants include: 24 | 25 | * The use of sexualized language or imagery and unwelcome sexual attention or 26 | advances 27 | * Trolling, insulting/derogatory comments, and personal or political attacks 28 | * Public or private harassment 29 | * Publishing others' private information, such as a physical or electronic 30 | address, without explicit permission 31 | * Other conduct which could reasonably be considered inappropriate in a 32 | professional setting 33 | 34 | ## Our Responsibilities 35 | 36 | Project maintainers are responsible for clarifying the standards of acceptable 37 | behavior and are expected to take appropriate and fair corrective action in 38 | response to any instances of unacceptable behavior. 39 | 40 | Project maintainers have the right and responsibility to remove, edit, or 41 | reject comments, commits, code, wiki edits, issues, and other contributions 42 | that are not aligned to this Code of Conduct, or to ban temporarily or 43 | permanently any contributor for other behaviors that they deem inappropriate, 44 | threatening, offensive, or harmful. 45 | 46 | ## Scope 47 | 48 | This Code of Conduct applies within all project spaces, and it also applies when 49 | an individual is representing the project or its community in public spaces. 50 | Examples of representing a project or community include using an official 51 | project e-mail address, posting via an official social media account, or acting 52 | as an appointed representative at an online or offline event. Representation of 53 | a project may be further defined and clarified by project maintainers. 54 | 55 | This Code of Conduct also applies outside the project spaces when there is a 56 | reasonable belief that an individual's behavior may have a negative impact on 57 | the project or its community. 58 | 59 | ## Enforcement 60 | 61 | Instances of abusive, harassing, or otherwise unacceptable behavior may be 62 | reported by contacting the project team at . All 63 | complaints will be reviewed and investigated and will result in a response that 64 | is deemed necessary and appropriate to the circumstances. The project team is 65 | obligated to maintain confidentiality with regard to the reporter of an incident. 66 | Further details of specific enforcement policies may be posted separately. 67 | 68 | Project maintainers who do not follow or enforce the Code of Conduct in good 69 | faith may face temporary or permanent repercussions as determined by other 70 | members of the project's leadership. 
71 | 72 | ## Attribution 73 | 74 | This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4, 75 | available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html 76 | 77 | [homepage]: https://www.contributor-covenant.org 78 | 79 | For answers to common questions about this code of conduct, see 80 | https://www.contributor-covenant.org/faq 81 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing to OVSeg 2 | We want to make contributing to this project as easy and transparent as 3 | possible. 4 | 5 | ## Pull Requests 6 | We actively welcome your pull requests. 7 | 8 | 1. Fork the repo and create your branch from `main`. 9 | 2. If you've added code that should be tested, add tests. 10 | 3. If you've changed APIs, update the documentation. 11 | 4. Ensure the test suite passes. 12 | 5. Make sure your code lints. 13 | 6. If you haven't already, complete the Contributor License Agreement ("CLA"). 14 | 15 | ## Contributor License Agreement ("CLA") 16 | In order to accept your pull request, we need you to submit a CLA. You only need 17 | to do this once to work on any of Meta's open source projects. 18 | 19 | Complete your CLA here: 20 | 21 | ## Issues 22 | We use GitHub issues to track public bugs. Please ensure your description is 23 | clear and has sufficient instructions to be able to reproduce the issue. 24 | 25 | Meta has a [bounty program](https://www.facebook.com/whitehat/) for the safe 26 | disclosure of security bugs. In those cases, please go through the process 27 | outlined on that page and do not file a public issue. 28 | 29 | 30 | ## License 31 | By contributing to OVSeg, you agree that your contributions will be licensed 32 | under the LICENSE file in the root directory of this source tree. 33 | -------------------------------------------------------------------------------- /GETTING_STARTED.md: -------------------------------------------------------------------------------- 1 | ## Getting started with OVSeg 2 | 3 | 4 | ### Try demo 5 | 6 | We release our largest model (Swin-Base + CLIP-ViT-L/14) [ovseg_swinbase_vitL14_ft_mpt.pth](https://drive.google.com/file/d/1cn-ohxgXDrDfkzC1QdO-fi8IjbjXmgKy/view?usp=sharing) (md5: 526080). 7 | 8 | - Test on sample image 9 | ```bash 10 | python demo.py --config-file configs/ovseg_swinB_vitL_demo.yaml --class-names 'Oculus' 'Ukulele' --input ./resources/demo_samples/sample_03.jpeg --output ./pred --opts MODEL.WEIGHTS #PATH_of_ovseg_swinbase_vitL14_ft_mpt.pth 11 | ``` 12 | 13 | ### Evaluation with pre-trained weights 14 | 15 | We release our largest model (Swin-Base + CLIP-ViT-L/14) [ovseg_swinbase_vitL14_ft_mpt.pth](https://drive.google.com/file/d/1cn-ohxgXDrDfkzC1QdO-fi8IjbjXmgKy/view?usp=sharing) (md5: 526080). 
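The six-character md5 strings quoted in this document are truncated checksums. If you want to confirm a download before running evaluation, here is a minimal sketch (it assumes the quoted value is the leading six hex characters of the full md5, and that the checkpoint was saved as `ovseg_swinbase_vitL14_ft_mpt.pth` in the current directory):

```python
import hashlib

def md5_prefix(path, n=6, chunk_size=1 << 20):
    """Stream the file and return the first n hex characters of its md5 digest."""
    md5 = hashlib.md5()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            md5.update(chunk)
    return md5.hexdigest()[:n]

# Expected to print 526080 for ovseg_swinbase_vitL14_ft_mpt.pth (see the note above).
print(md5_prefix("ovseg_swinbase_vitL14_ft_mpt.pth"))
```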
16 | 17 | - Test on ADE20K-150 and ADE-847 18 | ```bash 19 | python train_net.py --num-gpu 8 --eval-only --config-file configs/ovseg_swinB_vitL_bs32_120k.yaml MODEL.WEIGHTS #PATH_of_ovseg_swinbase_vitL14_ft_mpt.pth DATASETS.TEST \(\"ade20k_sem_seg_val\",\"ade20k_full_sem_seg_val\"\) 20 | ``` 21 | 22 | - Test on PascalContext-59 and PascalContext-459 23 | ```bash 24 | python train_net.py --num-gpu 8 --eval-only --config-file configs/ovseg_swinB_vitL_bs32_120k.yaml MODEL.WEIGHTS #PATH_of_ovseg_swinbase_vitL14_ft_mpt.pth MODEL.CLIP_ADAPTER.CLIP_ENSEMBLE_WEIGHT 0.6 DATASETS.TEST \(\"pascal_context_59_sem_seg_val\",\"pascal_context_459_sem_seg_val\",\) 25 | ``` 26 | 27 | - Test on PascalVOC-20 28 | ```bash 29 | python train_net.py --num-gpu 8 --eval-only --config-file configs/ovseg_swinB_vitL_bs32_120k.yaml MODEL.WEIGHTS #PATH_of_ovseg_swinbase_vitL14_ft_mpt.pth MODEL.CLIP_ADAPTER.CLIP_ENSEMBLE_WEIGHT 0.45 DATASETS.TEST \(\"pascalvoc20_sem_seg_val\",\) 30 | ``` 31 | 32 | You may also want to try our small model (R101c + CLIP-ViT-B/16) [ovseg_R101c_vitB16_ft_mpt.pth](https://drive.google.com/file/d/1ZmeFMEkhuLqaWkhz4t4IlQ8YflrYja8q/view?usp=drive_link) (md5: c746f4). 33 | 34 | - Test on ADE20K-150 35 | ```bash 36 | python train_net.py --num-gpu 8 --eval-only --config-file configs/ovseg_R101c_vitB_bs32_120k.yaml MODEL.WEIGHTS #PATH_of_ovseg_R101c_vitB16_ft_mpt.pth DATASETS.TEST \(\"ade20k_sem_seg_val\",\) 37 | ``` 38 | 39 | #### Performance benchmark 40 | 41 | | method | backbone | training dataset | A-847 | PC-459 | A-150 | PC-59 | PAS-20 | 42 | |------------------------------------|----------|------------------|:-----:|:------:|:-----:|:-----:|:------:| 43 | | Open-vocabulary generalist models. | | | | | | | | 44 | | SPNet | R-101 | PASCAL-15 | - | - | - | 24.3 | 18.3 | 45 | | ZS3Net | R-101 | PASCAL-15 | - | - | - | 19.4 | 38.3 | 46 | | LSeg | R-101 | PASCAL-15 | - | - | - | - | 47.4 | 47 | | LSeg+ | R-101 | COCO Panoptic | 2.5 | 5.2 | 13.0 | 36.0 | 59.0 | 48 | | SimBaseline | R-101c | COCO-Stuff-156 | - | - | 15.3 | - | 74.5 | 49 | | ZegFormer | R-50 | COCO-Stuff-156 | - | - | 16.4 | - | 80.7 | 50 | | OpenSeg | R-101 | COCO Panoptic | 4.0 | 6.5 | 15.3 | 36.9 | 60.0 | 51 | | OVSeg (Ours) | R-101c | COCO-Stuff-171 | 7.1 | 11.0 | 24.8 | 53.3 | 92.6 | 52 | | LSeg+ | Eff-B7 | COCO Panoptic | 3.8 | 7.8 | 18.0 | 46.5 | - | 53 | | OpenSeg | Eff-B7 | COCO Panoptic | 6.3 | 9.0 | 21.1 | 42.1 | - | 54 | | OVSeg (Ours) | Swin-B | COCO-Stuff-171 | 9.0 | 12.4 | 29.6 | 55.7 | 94.5 | 55 | | Supervised specialist models. | | | | | | | | 56 | | FCN | FCN-8s | Same as test | - | - | 29.4 | 37.8 | - | 57 | | Deeplab | R-101 | Same as test | - | - | - | 45.7 | 77.7 | 58 | | SelfTrain | Eff-L2 | Same as test | - | - | - | - | 90.0 | 59 | 60 | #### Ablation study 61 | 62 | - Mask prompt tuning can bring significant improvement without changing CLIP weights (Table 3 in [paper](https://arxiv.org/pdf/2210.04150.pdf)) 63 | 64 | Download the checkpoint with mpt only [ovseg_swinbase_vitL14_mpt_only.pt](https://drive.google.com/file/d/1LJGWFjHw76OGDNy9r9KQIaACfIm9KMhQ/view?usp=sharing) (md5: 2dd495). 
65 | 66 | ```bash 67 | python train_net.py --num-gpu 8 --eval-only --config-file configs/ovseg_swinB_vitL_bs32_120k.yaml MODEL.WEIGHTS #PATH_of_ovseg_swinbase_vitL14_mpt_only.pt DATASETS.TEST \(\"ade20k_sem_seg_val\",\"ade20k_full_sem_seg_val\"\) 68 | ``` 69 | 70 | - Mask prompt tuning can improve over the fully finetuned model (Table 3 in [paper](https://arxiv.org/pdf/2210.04150.pdf)) 71 | 72 | With the same [ovseg_swinbase_vitL14_ft_mpt.pth](https://drive.google.com/file/d/1cn-ohxgXDrDfkzC1QdO-fi8IjbjXmgKy/view?usp=sharing) checkpoint, set `MASK_PROMPT_FWD` to `False`. 73 | 74 | ```bash 75 | python train_net.py --num-gpu 8 --eval-only --config-file configs/ovseg_swinB_vitL_bs32_120k.yaml MODEL.CLIP_ADAPTER.MASK_PROMPT_FWD False MODEL.WEIGHTS #PATH_of_ovseg_swinbase_vitL14_ft_mpt.pth DATASETS.TEST \(\"ade20k_sem_seg_val\",\"ade20k_full_sem_seg_val\"\) 76 | ``` 77 | 78 | - The effects of class prediction ensemble (Table 6 in [paper](https://arxiv.org/pdf/2210.04150.pdf)) 79 | 80 | With the same [ovseg_swinbase_vitL14_ft_mpt.pth](https://drive.google.com/file/d/1cn-ohxgXDrDfkzC1QdO-fi8IjbjXmgKy/view?usp=sharing) checkpoint, set `CLIP_ENSEMBLE` to `False`. 81 | 82 | ```bash 83 | python train_net.py --num-gpu 8 --eval-only --config-file configs/ovseg_swinB_vitL_bs32_120k.yaml MODEL.CLIP_ADAPTER.CLIP_ENSEMBLE False MODEL.WEIGHTS #PATH_of_ovseg_swinbase_vitL14_ft_mpt.pth DATASETS.TEST \(\"ade20k_sem_seg_val\",\"ade20k_full_sem_seg_val\"\) 84 | ``` 85 | 86 | ### Training Segmentation model 87 | 88 | Our model is trained on COCO-Stuff-171. 89 | 90 | - Training baseline w/ original CLIP 91 | ```bash 92 | python train_net.py --num-gpu 8 --config-file configs/ovseg_swinB_vitL_bs32_120k.yaml MODEL.CLIP_ADAPTER.MASK_PROMPT_FWD False 93 | ``` 94 | 95 | To reproduce our final results, you may want to use our mask-adapted CLIP. 96 | 97 | - Training ovseg w/ mask-adapted CLIP 98 | ```bash 99 | python train_net.py --num-gpu 8 --config-file configs/ovseg_swinB_vitL_bs32_120k.yaml MODEL.CLIP_ADAPTER.CLIP_MODEL_NAME #PATH_TO_MASKADAPTED_CLIP 100 | ``` 101 | 102 | CAUTION: The final results are sensitive to the ensemble hyper-parameters (appendix A.5 in [paper](https://arxiv.org/pdf/2210.04150.pdf)). Thus, you may want to use ```tools/search_thr_ensemble_w.sh``` to search for the best values. 103 | 104 | ### Fine-tuning CLIP with collected mask-category pairs 105 | 106 | Please see [open clip training](./open_clip_training/README.md) -------------------------------------------------------------------------------- /INSTALL.md: -------------------------------------------------------------------------------- 1 | ## Installation 2 | 3 | ### Requirements 4 | - Linux with Python ≥ 3.6 5 | - PyTorch ≥ 1.8 and [torchvision](https://github.com/pytorch/vision/) that matches the PyTorch installation. Install them together at [pytorch.org](https://pytorch.org) to make sure of this. Note: please check that the PyTorch version matches the one required by Detectron2. 8 | - Detectron2: follow [Detectron2 installation instructions](https://detectron2.readthedocs.io/tutorials/install.html). 9 | 10 | ### Usage 11 | 12 | Install the required packages.
13 | 14 | ```bash 15 | conda create --name ovseg python=3.8 16 | conda activate ovseg 17 | conda install pytorch==1.10.1 torchvision==0.11.2 torchaudio==0.10.1 cudatoolkit=11.3 -c pytorch -c conda-forge 18 | pip install -r requirements.txt 19 | ``` 20 | 21 | You need to install `detectron2==0.6` following the [instructions](https://detectron2.readthedocs.io/en/latest/tutorials/install.html): 22 | 23 | ```bash 24 | python -m pip install detectron2 -f https://dl.fbaipublicfiles.com/detectron2/wheels/cu113/torch1.10/index.html 25 | ``` 26 | 27 | 28 | Furthermore, install the modified CLIP package: 29 | 30 | ```bash 31 | cd third_party/CLIP 32 | python -m pip install -Ue . 33 | ``` -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # [OVSeg] Open-Vocabulary Semantic Segmentation with Mask-adapted CLIP 2 | 3 | 4 | 5 | This is the official PyTorch implementation of our paper:<br>
6 | **Open-Vocabulary Semantic Segmentation with Mask-adapted CLIP**
7 | [Feng Liang](https://jeff-liangf.github.io/), [Bichen Wu](https://www.linkedin.com/in/bichenwu), [Xiaoliang Dai](https://sites.google.com/view/xiaoliangdai/), [Kunpeng Li](https://kunpengli1994.github.io/), [Yinan Zhao](https://yinan-zhao.github.io/), [Hang Zhang](https://hangzhang.org/), [Peizhao Zhang](https://www.linkedin.com/in/peizhao-zhang-14846042/), [Peter Vajda](https://sites.google.com/site/vajdap), [Diana Marculescu](https://www.ece.utexas.edu/people/faculty/diana-marculescu)
8 | Computer Vision and Pattern Recognition Conference (CVPR), 2023 9 | 10 | [[arXiv](https://arxiv.org/abs/2210.04150)] [[Project](https://jeff-liangf.github.io/projects/ovseg/)] [[huggingface demo](https://huggingface.co/spaces/facebook/ov-seg)] 11 | 12 |

13 | 14 |

15 | 16 | 17 | ## Installation 18 | 19 | Please see [installation guide](./INSTALL.md). 20 | 21 | ## Data Preparation 22 | 23 | Please see [datasets preparation](./datasets/DATASETS.md). 24 | 25 | ## Getting started 26 | 27 | Please see [getting started instruction](./GETTING_STARTED.md). 28 | 29 | ## Finetuning CLIP 30 | 31 | Please see [open clip training](./open_clip_training/README.md). 32 | 33 | ## LICENSE 34 | 35 | Shield: [![CC BY-NC 4.0][cc-by-nc-shield]][cc-by-nc] 36 | 37 | The majority of OVSeg is licensed under a 38 | [Creative Commons Attribution-NonCommercial 4.0 International License](LICENSE). 39 | 40 | [![CC BY-NC 4.0][cc-by-nc-image]][cc-by-nc] 41 | 42 | [cc-by-nc]: http://creativecommons.org/licenses/by-nc/4.0/ 43 | [cc-by-nc-image]: https://licensebuttons.net/l/by-nc/4.0/88x31.png 44 | [cc-by-nc-shield]: https://img.shields.io/badge/License-CC%20BY--NC%204.0-lightgrey.svg 45 | 46 | However portions of the project are under separate license terms: CLIP and ZSSEG are licensed under the [MIT license](https://github.com/openai/CLIP/blob/main/LICENSE); MaskFormer is licensed under the [CC-BY-NC](https://github.com/facebookresearch/MaskFormer/blob/main/LICENSE); openclip is licensed under the license at [its repo](https://github.com/mlfoundations/open_clip/blob/main/LICENSE). 47 | 48 | 49 | ## Citing OVSeg :pray: 50 | 51 | If you use OVSeg in your research or wish to refer to the baseline results published in the paper, please use the following BibTeX entry. 52 | 53 | ```BibTeX 54 | @inproceedings{liang2023open, 55 | title={Open-vocabulary semantic segmentation with mask-adapted clip}, 56 | author={Liang, Feng and Wu, Bichen and Dai, Xiaoliang and Li, Kunpeng and Zhao, Yinan and Zhang, Hang and Zhang, Peizhao and Vajda, Peter and Marculescu, Diana}, 57 | booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition}, 58 | pages={7061--7070}, 59 | year={2023} 60 | } 61 | ``` 62 | -------------------------------------------------------------------------------- /configs/ovseg_R101c_vitB_bs32_120k.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | META_ARCHITECTURE: "OVSeg" 3 | BACKBONE: 4 | NAME: "build_resnet_deeplab_backbone" 5 | RESNETS: 6 | DEPTH: 101 7 | STEM_TYPE: "deeplab" 8 | STEM_OUT_CHANNELS: 128 9 | STRIDE_IN_1X1: False 10 | OUT_FEATURES: ["res2", "res3", "res4", "res5"] 11 | # NORM: "SyncBN" 12 | RES5_MULTI_GRID: [1, 2, 4] 13 | WEIGHTS: "detectron2://DeepLab/R-103.pkl" 14 | PIXEL_MEAN: [123.675, 116.280, 103.530] 15 | PIXEL_STD: [58.395, 57.120, 57.375] 16 | SEM_SEG_HEAD: 17 | NAME: "OpenVocabMaskFormerHead" 18 | IN_FEATURES: ["res2", "res3", "res4", "res5"] 19 | IGNORE_VALUE: 255 20 | NUM_CLASSES: 171 # number of categories in training set 21 | EMBEDDING_DIM: 512 22 | EMBED_LAYERS: 2 23 | COMMON_STRIDE: 4 # not used, hard-coded 24 | LOSS_WEIGHT: 1.0 25 | CONVS_DIM: 256 26 | MASK_DIM: 256 27 | NORM: "GN" 28 | MASK_FORMER: 29 | TRANSFORMER_IN_FEATURE: "res5" 30 | DEEP_SUPERVISION: True 31 | NO_OBJECT_WEIGHT: 0.1 32 | DICE_WEIGHT: 1.0 33 | MASK_WEIGHT: 20.0 34 | HIDDEN_DIM: 256 35 | NUM_OBJECT_QUERIES: 100 36 | NHEADS: 8 37 | DROPOUT: 0.1 38 | DIM_FEEDFORWARD: 2048 39 | ENC_LAYERS: 0 40 | DEC_LAYERS: 6 41 | PRE_NORM: False 42 | CLIP_ADAPTER: 43 | TEXT_TEMPLATES: "vild" 44 | CLIP_MODEL_NAME: "ViT-B/16" 45 | MASK_FILL: "mean" 46 | MASK_EXPAND_RATIO: 1.0 47 | MASK_THR: 0.5 # choose the foreground objects 48 | MASK_MATTING: False # use soft background, default not used 49 | MASK_PROMPT_DEPTH: 
3 50 | MASK_PROMPT_FWD: True # use mask prompt during forward 51 | REGION_RESIZED: True # resize to the input of clip, e.g., 224 52 | CLIP_ENSEMBLE: True # use ensemble of two classification branches 53 | CLIP_ENSEMBLE_WEIGHT: 0.7 54 | DATASETS: 55 | TRAIN: ("coco_2017_train_stuff_sem_seg",) 56 | TEST: ("ade20k_sem_seg_val",) 57 | SOLVER: 58 | IMS_PER_BATCH: 32 59 | BASE_LR: 0.0002 60 | MAX_ITER: 120000 61 | WARMUP_FACTOR: 1e-6 62 | WARMUP_ITERS: 1500 63 | LR_SCHEDULER_NAME: "WarmupPolyLR" 64 | WEIGHT_DECAY: 0.0001 65 | WEIGHT_DECAY_NORM: 0.0 66 | WEIGHT_DECAY_EMBED: 0.0 67 | BACKBONE_MULTIPLIER: 0.1 68 | TEST_IMS_PER_BATCH: 1 69 | CLIP_GRADIENTS: 70 | ENABLED: True 71 | CLIP_TYPE: "full_model" 72 | CLIP_VALUE: 0.01 73 | NORM_TYPE: 2.0 74 | INPUT: 75 | MIN_SIZE_TEST: 512 76 | MAX_SIZE_TEST: 2048 77 | CROP: 78 | ENABLED: True 79 | TYPE: "absolute" 80 | SIZE: (512, 512) 81 | SINGLE_CATEGORY_MAX_AREA: 1.0 82 | COLOR_AUG_SSD: True 83 | SIZE_DIVISIBILITY: 512 # used in dataset mapper 84 | FORMAT: "RGB" 85 | TEST: 86 | EVAL_PERIOD: 5000 87 | AUG: 88 | ENABLED: False 89 | MIN_SIZES: [256, 384, 512, 640, 768, 896] 90 | MAX_SIZE: 3584 91 | FLIP: True 92 | DATALOADER: 93 | FILTER_EMPTY_ANNOTATIONS: True 94 | NUM_WORKERS: 4 95 | VERSION: 2 -------------------------------------------------------------------------------- /configs/ovseg_swinB_vitL_bs32_120k.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | META_ARCHITECTURE: "OVSeg" 3 | BACKBONE: 4 | FREEZE_AT: 0 5 | NAME: "D2SwinTransformer" 6 | SWIN: 7 | EMBED_DIM: 128 8 | DEPTHS: [2, 2, 18, 2] 9 | NUM_HEADS: [4, 8, 16, 32] 10 | WINDOW_SIZE: 12 11 | APE: False 12 | DROP_PATH_RATE: 0.3 13 | PATCH_NORM: True 14 | PRETRAIN_IMG_SIZE: 384 15 | WEIGHTS: "swin_base_patch4_window12_384_22k.pkl" 16 | PIXEL_MEAN: [123.675, 116.280, 103.530] 17 | PIXEL_STD: [58.395, 57.120, 57.375] 18 | SEM_SEG_HEAD: 19 | NAME: "OpenVocabMaskFormerHead" 20 | IN_FEATURES: ["res2", "res3", "res4", "res5"] 21 | IGNORE_VALUE: 255 22 | NUM_CLASSES: 171 # number of categories in training set 23 | EMBEDDING_DIM: 768 24 | EMBED_LAYERS: 2 25 | COMMON_STRIDE: 4 # not used, hard-coded 26 | LOSS_WEIGHT: 1.0 27 | CONVS_DIM: 256 28 | MASK_DIM: 256 29 | NORM: "GN" 30 | MASK_FORMER: 31 | TRANSFORMER_IN_FEATURE: "res5" 32 | DEEP_SUPERVISION: True 33 | NO_OBJECT_WEIGHT: 0.1 34 | DICE_WEIGHT: 1.0 35 | MASK_WEIGHT: 20.0 36 | HIDDEN_DIM: 256 37 | NUM_OBJECT_QUERIES: 100 38 | NHEADS: 8 39 | DROPOUT: 0.1 40 | DIM_FEEDFORWARD: 2048 41 | ENC_LAYERS: 0 42 | DEC_LAYERS: 6 43 | PRE_NORM: False 44 | CLIP_ADAPTER: 45 | TEXT_TEMPLATES: "vild" 46 | CLIP_MODEL_NAME: "ViT-L/14" 47 | MASK_FILL: "mean" 48 | MASK_EXPAND_RATIO: 1.0 49 | MASK_THR: 0.4 # choose the foreground objects 50 | MASK_MATTING: False # use soft background, default not used 51 | MASK_PROMPT_DEPTH: 3 52 | MASK_PROMPT_FWD: True # use mask prompt during forward 53 | REGION_RESIZED: True # resize to the input of clip, e.g., 224 54 | CLIP_ENSEMBLE: True # use ensemble of two classification branches 55 | CLIP_ENSEMBLE_WEIGHT: 0.7 56 | DATASETS: 57 | TRAIN: ("coco_2017_train_stuff_sem_seg",) 58 | TEST: ("ade20k_sem_seg_val",) 59 | SOLVER: 60 | IMS_PER_BATCH: 32 61 | BASE_LR: 0.00006 62 | MAX_ITER: 120000 63 | WARMUP_FACTOR: 1e-6 64 | WARMUP_ITERS: 1500 65 | LR_SCHEDULER_NAME: "WarmupPolyLR" 66 | WEIGHT_DECAY: 0.01 67 | WEIGHT_DECAY_NORM: 0.0 68 | WEIGHT_DECAY_EMBED: 0.0 69 | BACKBONE_MULTIPLIER: 1.0 70 | TEST_IMS_PER_BATCH: 1 71 | CLIP_GRADIENTS: 72 | ENABLED: True 73 | CLIP_TYPE: "full_model" 74 | 
CLIP_VALUE: 0.01 75 | NORM_TYPE: 2.0 76 | INPUT: 77 | MIN_SIZE_TRAIN: !!python/object/apply:eval ["[int(x * 0.1 * 640) for x in range(5, 21)]"] 78 | MIN_SIZE_TRAIN_SAMPLING: "choice" 79 | MIN_SIZE_TEST: 640 80 | MAX_SIZE_TRAIN: 2560 81 | MAX_SIZE_TEST: 2560 82 | CROP: 83 | ENABLED: True 84 | TYPE: "absolute" 85 | SIZE: (640, 640) 86 | SINGLE_CATEGORY_MAX_AREA: 1.0 87 | COLOR_AUG_SSD: True 88 | SIZE_DIVISIBILITY: 640 # used in dataset mapper 89 | FORMAT: "RGB" 90 | TEST: 91 | EVAL_PERIOD: 5000 92 | AUG: 93 | ENABLED: False 94 | MIN_SIZES: [256, 384, 512, 640, 768, 896] 95 | MAX_SIZE: 3584 96 | FLIP: True 97 | DATALOADER: 98 | FILTER_EMPTY_ANNOTATIONS: True 99 | NUM_WORKERS: 4 100 | VERSION: 2 -------------------------------------------------------------------------------- /configs/ovseg_swinB_vitL_demo.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | META_ARCHITECTURE: "OVSegDEMO" 3 | BACKBONE: 4 | FREEZE_AT: 0 5 | NAME: "D2SwinTransformer" 6 | SWIN: 7 | EMBED_DIM: 128 8 | DEPTHS: [2, 2, 18, 2] 9 | NUM_HEADS: [4, 8, 16, 32] 10 | WINDOW_SIZE: 12 11 | APE: False 12 | DROP_PATH_RATE: 0.3 13 | PATCH_NORM: True 14 | PRETRAIN_IMG_SIZE: 384 15 | WEIGHTS: "swin_base_patch4_window12_384_22k.pkl" 16 | PIXEL_MEAN: [123.675, 116.280, 103.530] 17 | PIXEL_STD: [58.395, 57.120, 57.375] 18 | SEM_SEG_HEAD: 19 | NAME: "OpenVocabMaskFormerHead" 20 | IN_FEATURES: ["res2", "res3", "res4", "res5"] 21 | IGNORE_VALUE: 255 22 | NUM_CLASSES: 171 # number of categories in training set 23 | EMBEDDING_DIM: 768 24 | EMBED_LAYERS: 2 25 | COMMON_STRIDE: 4 # not used, hard-coded 26 | LOSS_WEIGHT: 1.0 27 | CONVS_DIM: 256 28 | MASK_DIM: 256 29 | NORM: "GN" 30 | MASK_FORMER: 31 | TRANSFORMER_IN_FEATURE: "res5" 32 | DEEP_SUPERVISION: True 33 | NO_OBJECT_WEIGHT: 0.1 34 | DICE_WEIGHT: 1.0 35 | MASK_WEIGHT: 20.0 36 | HIDDEN_DIM: 256 37 | NUM_OBJECT_QUERIES: 100 38 | NHEADS: 8 39 | DROPOUT: 0.1 40 | DIM_FEEDFORWARD: 2048 41 | ENC_LAYERS: 0 42 | DEC_LAYERS: 6 43 | PRE_NORM: False 44 | CLIP_ADAPTER: 45 | TEXT_TEMPLATES: "vild" 46 | CLIP_MODEL_NAME: "ViT-L/14" 47 | MASK_FILL: "mean" 48 | MASK_EXPAND_RATIO: 1.0 49 | MASK_THR: 0.35 # choose the foreground objects 50 | MASK_MATTING: False # use soft background, default not used 51 | MASK_PROMPT_DEPTH: 3 52 | MASK_PROMPT_FWD: True # use mask prompt during forward 53 | REGION_RESIZED: True # resize to the input of clip, e.g., 224 54 | CLIP_ENSEMBLE: True # use ensemble of two classification branches 55 | CLIP_ENSEMBLE_WEIGHT: 0.0 56 | DATASETS: 57 | TRAIN: ("coco_2017_train_stuff_sem_seg",) 58 | TEST: ("ade20k_sem_seg_val",) 59 | SOLVER: 60 | IMS_PER_BATCH: 32 61 | BASE_LR: 0.00006 62 | MAX_ITER: 120000 63 | WARMUP_FACTOR: 1e-6 64 | WARMUP_ITERS: 1500 65 | WEIGHT_DECAY: 0.01 66 | WEIGHT_DECAY_NORM: 0.0 67 | WEIGHT_DECAY_EMBED: 0.0 68 | BACKBONE_MULTIPLIER: 1.0 69 | TEST_IMS_PER_BATCH: 1 70 | CLIP_GRADIENTS: 71 | ENABLED: True 72 | CLIP_TYPE: "full_model" 73 | CLIP_VALUE: 0.01 74 | NORM_TYPE: 2.0 75 | INPUT: 76 | MIN_SIZE_TRAIN: !!python/object/apply:eval ["[int(x * 0.1 * 640) for x in range(5, 21)]"] 77 | MIN_SIZE_TRAIN_SAMPLING: "choice" 78 | MIN_SIZE_TEST: 640 79 | MAX_SIZE_TRAIN: 2560 80 | MAX_SIZE_TEST: 2560 81 | CROP: 82 | ENABLED: True 83 | TYPE: "absolute" 84 | SIZE: (640, 640) 85 | SINGLE_CATEGORY_MAX_AREA: 1.0 86 | COLOR_AUG_SSD: True 87 | SIZE_DIVISIBILITY: 640 # used in dataset mapper 88 | FORMAT: "RGB" 89 | TEST: 90 | EVAL_PERIOD: 5000 91 | AUG: 92 | ENABLED: False 93 | MIN_SIZES: [256, 384, 512, 640, 768, 896] 94 | 
MAX_SIZE: 3584 95 | FLIP: True 96 | DATALOADER: 97 | FILTER_EMPTY_ANNOTATIONS: True 98 | NUM_WORKERS: 4 99 | VERSION: 2 -------------------------------------------------------------------------------- /datasets/DATASETS.md: -------------------------------------------------------------------------------- 1 | ## Prepare Datasets for OVSeg 2 | 3 | This doc is a modification/extension of [MaskFormer](https://github.com/facebookresearch/MaskFormer/blob/main/datasets/README.md) following the [Detectron2 format](https://detectron2.readthedocs.io/en/latest/tutorials/datasets.html). 4 | 5 | A dataset can be used by accessing [DatasetCatalog](https://detectron2.readthedocs.io/modules/data.html#detectron2.data.DatasetCatalog) 6 | for its data, or [MetadataCatalog](https://detectron2.readthedocs.io/modules/data.html#detectron2.data.MetadataCatalog) for its metadata (class names, etc.). 7 | This document explains how to set up the builtin datasets so they can be used by the above APIs. 8 | [Use Custom Datasets](https://detectron2.readthedocs.io/tutorials/datasets.html) gives a deeper dive on how to use `DatasetCatalog` and `MetadataCatalog`, 9 | and how to add new datasets to them. 10 | 11 | OVSeg has builtin support for a few datasets. 12 | The datasets are assumed to exist in a directory specified by the environment variable 13 | `DETECTRON2_DATASETS`. 14 | Under this directory, detectron2 will look for datasets in the structure described below, if needed. 15 | ``` 16 | $DETECTRON2_DATASETS/ 17 | coco/ # COCOStuff-171 18 | ADEChallengeData2016/ # ADE20K-150 19 | ADE20K_2021_17_01/ # ADE20K-847 20 | VOCdevkit/ 21 | VOC2012/ # PASCALVOC-20 22 | VOC2010/ # PASCALContext-59, PASCALContext-459 23 | ``` 24 | 25 | You can set the location for builtin datasets by `export DETECTRON2_DATASETS=/path/to/datasets`. 26 | If left unset, the default is `./datasets` relative to your current working directory. 27 | 28 | Unless otherwise specified, our model is trained on COCOStuff-171 and evaluated on ADE20K-150, ADE20K-847, PASCALVOC-20, PASCALContext-59 and PASCALContext-459. 29 | 30 | | dataset | split | # images | # categories | 31 | |:--------------:|:---------:|:--------:|:------------:| 32 | | COCO Stuff | train2017 | 118K | 171 | 33 | | ADE20K | val | 2K | 150/847 | 34 | | Pascal VOC | val | 1.5K | 20 | 35 | | Pascal Context | val | 5K | 59/459 | 36 | 37 | 38 | ### Expected dataset structure for [COCO Stuff](https://github.com/nightrome/cocostuff): 39 | ``` 40 | coco/ 41 | train2017/ # http://images.cocodataset.org/zips/train2017.zip 42 | annotations/ # http://images.cocodataset.org/annotations/annotations_trainval2017.zip 43 | stuffthingmaps/ 44 | stuffthingmaps_trainval2017.zip # http://calvin.inf.ed.ac.uk/wp-content/uploads/data/cocostuffdataset/stuffthingmaps_trainval2017.zip 45 | train2017/ 46 | # below are generated 47 | stuffthingmaps_detectron2/ 48 | train2017/ 49 | ``` 50 | 51 | The directory `stuffthingmaps_detectron2` is generated by running `python datasets/prepare_coco_stuff_sem_seg.py`. 52 | 53 | 54 | 55 | ### Expected dataset structure for [ADE20k Scene Parsing (ADE20K-150)](http://sceneparsing.csail.mit.edu/): 56 | ``` 57 | ADEChallengeData2016/ 58 | annotations/ 59 | images/ 60 | objectInfo150.txt 61 | # below are generated 62 | annotations_detectron2/ 63 | ``` 64 | The directory `annotations_detectron2` is generated by running `python datasets/prepare_ade20k_sem_seg.py`.
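After running the preparation script, you can sanity-check the registration through the `DatasetCatalog` / `MetadataCatalog` APIs mentioned at the top of this document. A minimal sketch, assuming that importing `open_vocab_seg` is what registers the builtin OVSeg datasets and that the metadata follows the usual Detectron2 semantic-segmentation conventions (`stuff_classes`, `sem_seg_file_name`):

```python
from detectron2.data import DatasetCatalog, MetadataCatalog

import open_vocab_seg  # noqa: F401  (assumed to register the builtin OVSeg datasets on import)

# "ade20k_sem_seg_val" is the name used by DATASETS.TEST in the provided configs.
dataset_dicts = DatasetCatalog.get("ade20k_sem_seg_val")
metadata = MetadataCatalog.get("ade20k_sem_seg_val")

print(len(dataset_dicts), "val images")        # roughly 2K, per the table above
print(dataset_dicts[0]["file_name"])           # image path
print(dataset_dicts[0]["sem_seg_file_name"])   # ground-truth mask path
print(len(metadata.stuff_classes), "classes")  # 150 for ADE20K-150
```

Remember to `export DETECTRON2_DATASETS=/path/to/datasets` first; otherwise the default `./datasets` is used.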
65 | 66 | 67 | ### Expected dataset structure for [ADE20k-Full (ADE20K-847)](https://github.com/CSAILVision/ADE20K#download): 68 | ``` 69 | ADE20K_2021_17_01/ 70 | images/ 71 | index_ade20k.pkl 72 | objects.txt 73 | # below are generated 74 | images_detectron2/ 75 | annotations_detectron2/ 76 | ``` 77 | The directories `images_detectron2` and `annotations_detectron2` are generated by running `python datasets/prepare_ade20k_full_sem_seg.py`. 78 | 79 | ### Expected dataset structure for [Pascal VOC 2012 (PASCALVOC-20)](http://host.robots.ox.ac.uk/pascal/VOC/voc2012/#devkit): 80 | ``` 81 | VOCdevkit/VOC2012/ 82 | Annotations/ 83 | ImageSets/ 84 | JPEGImages/ 85 | SegmentationClass/ 86 | SegmentationObject/ 87 | SegmentationClassAug/ # https://github.com/kazuto1011/deeplab-pytorch/blob/master/data/datasets/voc12/README.md 88 | # below are generated 89 | images_detectron2/ 90 | annotations_detectron2/ 91 | ``` 92 | 93 | It starts with a tar file `VOCtrainval_11-May-2012.tar`. 94 | 95 | We use the SBD-augmented training data as `SegmentationClassAug`, following [Deeplab](https://github.com/kazuto1011/deeplab-pytorch/blob/master/data/datasets/voc12/README.md). 96 | 97 | The directories `images_detectron2` and `annotations_detectron2` are generated by running `python datasets/prepare_voc_sem_seg.py`. 98 | 99 | 100 | ### Expected dataset structure for [Pascal Context](https://www.cs.stanford.edu/~roozbeh/pascal-context/): 101 | 102 | ``` 103 | VOCdevkit/VOC2010/ 104 | Annotations/ 105 | ImageSets/ 106 | JPEGImages/ 107 | SegmentationClass/ 108 | SegmentationObject/ 109 | # below are from https://www.cs.stanford.edu/~roozbeh/pascal-context/trainval.tar.gz 110 | trainval/ 111 | labels.txt 112 | 59_labels.txt # https://www.cs.stanford.edu/~roozbeh/pascal-context/59_labels.txt 113 | pascalcontext_val.txt # https://drive.google.com/file/d/1BCbiOKtLvozjVnlTJX51koIveUZHCcUh/view?usp=sharing 114 | # below are generated 115 | annotations_detectron2/ 116 | pc459_val 117 | pc59_val 118 | ``` 119 | It starts with a tar file `VOCtrainval_03-May-2010.tar`. You may want to download the 5K validation set [here](https://drive.google.com/file/d/1BCbiOKtLvozjVnlTJX51koIveUZHCcUh/view?usp=sharing). 120 | 121 | The directory `annotations_detectron2` is generated by running `python datasets/prepare_pascal_context.py`. 122 | 123 | -------------------------------------------------------------------------------- /datasets/prepare_ade20k_sem_seg.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # Copyright (c) Meta Platforms, Inc. All Rights Reserved 3 | 4 | import os 5 | from pathlib import Path 6 | 7 | import numpy as np 8 | import tqdm 9 | from PIL import Image 10 | 11 | 12 | def convert(input, output, index=None): 13 | img = np.asarray(Image.open(input)) 14 | assert img.dtype == np.uint8 15 | img = img - 1 # 0 (ignore) becomes 255.
others are shifted by 1 16 | if index is not None: 17 | mapping = {i: k for k, i in enumerate(index)} 18 | img = np.vectorize(lambda x: mapping[x] if x in mapping else 255)( 19 | img.astype(np.float) 20 | ).astype(np.uint8) 21 | Image.fromarray(img).save(output) 22 | 23 | 24 | if __name__ == "__main__": 25 | dataset_dir = ( 26 | Path(os.getenv("DETECTRON2_DATASETS", "datasets")) / "ADEChallengeData2016" 27 | ) 28 | print('Caution: we only generate the validation set!') 29 | for name in ["validation"]: 30 | annotation_dir = dataset_dir / "annotations" / name 31 | output_dir = dataset_dir / "annotations_detectron2" / name 32 | output_dir.mkdir(parents=True, exist_ok=True) 33 | for file in tqdm.tqdm(list(annotation_dir.iterdir())): 34 | output_file = output_dir / file.name 35 | convert(file, output_file) 36 | -------------------------------------------------------------------------------- /datasets/prepare_coco_stuff_sem_seg.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # Copyright (c) Meta Platforms, Inc. All Rights Reserved 3 | # Modified by Feng Liang from 4 | # https://github.com/MendelXu/zsseg.baseline/blob/master/datasets/prepare_coco_stuff_164k_sem_seg.py 5 | 6 | import os 7 | import os.path as osp 8 | from pathlib import Path 9 | import tqdm 10 | from glob import glob 11 | 12 | import numpy as np 13 | from PIL import Image 14 | 15 | 16 | full_clsID_to_trID = { 17 | 0: 0, 18 | 1: 1, 19 | 2: 2, 20 | 3: 3, 21 | 4: 4, 22 | 5: 5, 23 | 6: 6, 24 | 7: 7, 25 | 8: 8, 26 | 9: 9, 27 | 10: 10, 28 | 12: 11, 29 | 13: 12, 30 | 14: 13, 31 | 15: 14, 32 | 16: 15, 33 | 17: 16, 34 | 18: 17, 35 | 19: 18, 36 | 20: 19, 37 | 21: 20, 38 | 22: 21, 39 | 23: 22, 40 | 24: 23, 41 | 26: 24, 42 | 27: 25, 43 | 30: 26, 44 | 31: 27, 45 | 32: 28, 46 | 33: 29, 47 | 34: 30, 48 | 35: 31, 49 | 36: 32, 50 | 37: 33, 51 | 38: 34, 52 | 39: 35, 53 | 40: 36, 54 | 41: 37, 55 | 42: 38, 56 | 43: 39, 57 | 45: 40, 58 | 46: 41, 59 | 47: 42, 60 | 48: 43, 61 | 49: 44, 62 | 50: 45, 63 | 51: 46, 64 | 52: 47, 65 | 53: 48, 66 | 54: 49, 67 | 55: 50, 68 | 56: 51, 69 | 57: 52, 70 | 58: 53, 71 | 59: 54, 72 | 60: 55, 73 | 61: 56, 74 | 62: 57, 75 | 63: 58, 76 | 64: 59, 77 | 66: 60, 78 | 69: 61, 79 | 71: 62, 80 | 72: 63, 81 | 73: 64, 82 | 74: 65, 83 | 75: 66, 84 | 76: 67, 85 | 77: 68, 86 | 78: 69, 87 | 79: 70, 88 | 80: 71, 89 | 81: 72, 90 | 83: 73, 91 | 84: 74, 92 | 85: 75, 93 | 86: 76, 94 | 87: 77, 95 | 88: 78, 96 | 89: 79, 97 | 91: 80, 98 | 92: 81, 99 | 93: 82, 100 | 94: 83, 101 | 95: 84, 102 | 96: 85, 103 | 97: 86, 104 | 98: 87, 105 | 99: 88, 106 | 100: 89, 107 | 101: 90, 108 | 102: 91, 109 | 103: 92, 110 | 104: 93, 111 | 105: 94, 112 | 106: 95, 113 | 107: 96, 114 | 108: 97, 115 | 109: 98, 116 | 110: 99, 117 | 111: 100, 118 | 112: 101, 119 | 113: 102, 120 | 114: 103, 121 | 115: 104, 122 | 116: 105, 123 | 117: 106, 124 | 118: 107, 125 | 119: 108, 126 | 120: 109, 127 | 121: 110, 128 | 122: 111, 129 | 123: 112, 130 | 124: 113, 131 | 125: 114, 132 | 126: 115, 133 | 127: 116, 134 | 128: 117, 135 | 129: 118, 136 | 130: 119, 137 | 131: 120, 138 | 132: 121, 139 | 133: 122, 140 | 134: 123, 141 | 135: 124, 142 | 136: 125, 143 | 137: 126, 144 | 138: 127, 145 | 139: 128, 146 | 140: 129, 147 | 141: 130, 148 | 142: 131, 149 | 143: 132, 150 | 144: 133, 151 | 145: 134, 152 | 146: 135, 153 | 147: 136, 154 | 148: 137, 155 | 149: 138, 156 | 150: 139, 157 | 151: 140, 158 | 152: 141, 159 | 153: 142, 160 | 154: 143, 161 | 155: 144, 162 | 156: 145, 163 | 157: 146, 164 | 
158: 147, 165 | 159: 148, 166 | 160: 149, 167 | 161: 150, 168 | 162: 151, 169 | 163: 152, 170 | 164: 153, 171 | 165: 154, 172 | 166: 155, 173 | 167: 156, 174 | 168: 157, 175 | 169: 158, 176 | 170: 159, 177 | 171: 160, 178 | 172: 161, 179 | 173: 162, 180 | 174: 163, 181 | 175: 164, 182 | 176: 165, 183 | 177: 166, 184 | 178: 167, 185 | 179: 168, 186 | 180: 169, 187 | 181: 170, 188 | 255: 255, 189 | } 190 | 191 | def convert_to_trainID( 192 | maskpath, out_mask_dir, is_train, clsID_to_trID=full_clsID_to_trID, suffix="" 193 | ): 194 | mask = np.array(Image.open(maskpath)) 195 | mask_copy = np.ones_like(mask, dtype=np.uint8) * 255 196 | for clsID, trID in clsID_to_trID.items(): 197 | mask_copy[mask == clsID] = trID 198 | seg_filename = ( 199 | osp.join(out_mask_dir, "train2017" + suffix, osp.basename(maskpath)) 200 | if is_train 201 | else osp.join(out_mask_dir, "val2017" + suffix, osp.basename(maskpath)) 202 | ) 203 | if len(np.unique(mask_copy)) == 1 and np.unique(mask_copy)[0] == 255: 204 | return 205 | Image.fromarray(mask_copy).save(seg_filename, "PNG") 206 | 207 | 208 | 209 | if __name__ == "__main__": 210 | dataset_dir = Path(os.getenv("DETECTRON2_DATASETS", "datasets")) 211 | print('Caution: we only generate the training set!') 212 | coco_path = dataset_dir / "coco" 213 | mask_dir = coco_path / "stuffthingmaps" 214 | out_mask_dir = coco_path / "stuffthingmaps_detectron2" 215 | for name in ["train2017"]: 216 | os.makedirs((out_mask_dir / name), exist_ok=True) 217 | train_list = glob(osp.join(mask_dir, "train2017", "*.png")) 218 | for file in tqdm.tqdm(train_list): 219 | convert_to_trainID(file, out_mask_dir, is_train=True) 220 | -------------------------------------------------------------------------------- /datasets/prepare_pascal_context.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # Copyright (c) Meta Platforms, Inc. 
All Rights Reserved 3 | 4 | import tqdm 5 | import os 6 | import os.path as osp 7 | from pathlib import Path 8 | 9 | import numpy as np 10 | from PIL import Image 11 | import scipy.io 12 | 13 | def convert_pc59(mask_path, new_mask_path, pc59_dict): 14 | mat = scipy.io.loadmat(mask_path) 15 | mask = mat['LabelMap'] 16 | 17 | mask_copy = np.ones_like(mask, dtype=np.uint8) * 255 18 | for trID, clsID in pc59_dict.items(): 19 | mask_copy[mask == clsID] = trID 20 | 21 | min_value = np.amin(mask_copy) 22 | assert min_value >= 0, min_value 23 | Image.fromarray(mask_copy).save(new_mask_path, "PNG") 24 | 25 | def convert_pc459(mask_path, new_mask_path): 26 | mat = scipy.io.loadmat(mask_path) 27 | mask = mat['LabelMap'] 28 | mask = mask - 1 29 | min_value = np.amin(mask) 30 | assert min_value >= 0, min_value 31 | Image.fromarray(mask).save(new_mask_path, "TIFF") 32 | 33 | 34 | if __name__ == "__main__": 35 | dataset_dir = Path(os.getenv("DETECTRON2_DATASETS", "datasets")) 36 | print('Caution: we only generate the validation set!') 37 | pc_path = dataset_dir / "VOCdevkit/VOC2010" 38 | 39 | val_list = open(pc_path / "pascalcontext_val.txt", "r") 40 | pc459_labels = open(pc_path / "labels.txt", "r") 41 | pc59_labels = open(pc_path / "59_labels.txt", "r") 42 | 43 | pc459_dict = {} 44 | for line in pc459_labels.readlines(): 45 | if ':' in line: 46 | idx, name = line.split(':') 47 | idx = int(idx.strip()) 48 | name = name.strip() 49 | pc459_dict[name] = idx 50 | 51 | pc59_dict = {} 52 | for i, line in enumerate(pc59_labels.readlines()): 53 | name = line.split(':')[-1].strip() 54 | if name != '': 55 | pc59_dict[i] = pc459_dict[name] 56 | 57 | pc459_dir = pc_path / "annotations_detectron2" / "pc459_val" 58 | pc459_dir.mkdir(parents=True, exist_ok=True) 59 | pc59_dir = pc_path / "annotations_detectron2" / "pc59_val" 60 | pc59_dir.mkdir(parents=True, exist_ok=True) 61 | 62 | for line in tqdm.tqdm(val_list.readlines()): 63 | fileid = line.strip() 64 | ori_mask = f'{pc_path}/trainval/{fileid}.mat' 65 | pc459_dst = f'{pc459_dir}/{fileid}.tif' 66 | pc59_dst = f'{pc59_dir}/{fileid}.png' 67 | if osp.exists(ori_mask): 68 | convert_pc459(ori_mask, pc459_dst) 69 | convert_pc59(ori_mask, pc59_dst, pc59_dict) 70 | -------------------------------------------------------------------------------- /datasets/prepare_voc_sem_seg.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # Copyright (c) Meta Platforms, Inc. 
All Rights Reserved 3 | # Modified by Feng Liang from https://github.com/MendelXu/zsseg.baseline/blob/master/datasets/prepare_voc_sem_seg.py 4 | 5 | import os 6 | import os.path as osp 7 | from pathlib import Path 8 | import tqdm 9 | 10 | import numpy as np 11 | from PIL import Image 12 | 13 | 14 | clsID_to_trID = { 15 | 0: 255, 16 | 1: 0, 17 | 2: 1, 18 | 3: 2, 19 | 4: 3, 20 | 5: 4, 21 | 6: 5, 22 | 7: 6, 23 | 8: 7, 24 | 9: 8, 25 | 10: 9, 26 | 11: 10, 27 | 12: 11, 28 | 13: 12, 29 | 14: 13, 30 | 15: 14, 31 | 16: 15, 32 | 17: 16, 33 | 18: 17, 34 | 19: 18, 35 | 20: 19, 36 | 255: 255, 37 | } 38 | 39 | def convert_to_trainID( 40 | maskpath, out_mask_dir, is_train, clsID_to_trID=clsID_to_trID, suffix="" 41 | ): 42 | mask = np.array(Image.open(maskpath)) 43 | mask_copy = np.ones_like(mask, dtype=np.uint8) * 255 44 | for clsID, trID in clsID_to_trID.items(): 45 | mask_copy[mask == clsID] = trID 46 | seg_filename = ( 47 | osp.join(out_mask_dir, "train" + suffix, osp.basename(maskpath)) 48 | if is_train 49 | else osp.join(out_mask_dir, "val" + suffix, osp.basename(maskpath)) 50 | ) 51 | if len(np.unique(mask_copy)) == 1 and np.unique(mask_copy)[0] == 255: 52 | return 53 | Image.fromarray(mask_copy).save(seg_filename, "PNG") 54 | 55 | 56 | 57 | if __name__ == "__main__": 58 | dataset_dir = Path(os.getenv("DETECTRON2_DATASETS", "datasets")) 59 | print('Caution: we only generate the validation set!') 60 | voc_path = dataset_dir / "VOCdevkit" / "VOC2012" 61 | out_mask_dir = voc_path / "annotations_detectron2" 62 | out_image_dir = voc_path / "images_detectron2" 63 | for name in ["val"]: 64 | os.makedirs((out_mask_dir / name), exist_ok=True) 65 | os.makedirs((out_image_dir / name), exist_ok=True) 66 | val_list = [ 67 | osp.join(voc_path, "SegmentationClassAug", f + ".png") 68 | for f in np.loadtxt(osp.join(voc_path, "ImageSets/Segmentation/val.txt"), dtype=str).tolist() 69 | ] 70 | for file in tqdm.tqdm(val_list): 71 | convert_to_trainID(file, out_mask_dir, is_train=False) 72 | -------------------------------------------------------------------------------- /demo.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # Copyright (c) Meta Platforms, Inc. 
All Rights Reserved 3 | 4 | import argparse 5 | import glob 6 | import multiprocessing as mp 7 | import os 8 | import time 9 | import cv2 10 | import tqdm 11 | 12 | from detectron2.config import get_cfg 13 | 14 | from detectron2.projects.deeplab import add_deeplab_config 15 | from detectron2.data.detection_utils import read_image 16 | from detectron2.utils.logger import setup_logger 17 | from open_vocab_seg import add_ovseg_config 18 | 19 | from open_vocab_seg.utils import VisualizationDemo 20 | 21 | # constants 22 | WINDOW_NAME = "Open vocabulary segmentation" 23 | 24 | 25 | def setup_cfg(args): 26 | # load config from file and command-line arguments 27 | cfg = get_cfg() 28 | # for poly lr schedule 29 | add_deeplab_config(cfg) 30 | add_ovseg_config(cfg) 31 | cfg.merge_from_file(args.config_file) 32 | cfg.merge_from_list(args.opts) 33 | cfg.freeze() 34 | return cfg 35 | 36 | 37 | def get_parser(): 38 | parser = argparse.ArgumentParser(description="Detectron2 demo for open vocabulary segmentation") 39 | parser.add_argument( 40 | "--config-file", 41 | default="configs/ovseg_swinB_vitL_demo.yaml", 42 | metavar="FILE", 43 | help="path to config file", 44 | ) 45 | parser.add_argument( 46 | "--input", 47 | nargs="+", 48 | help="A list of space separated input images; " 49 | "or a single glob pattern such as 'directory/*.jpg'", 50 | ) 51 | parser.add_argument( 52 | "--class-names", 53 | nargs="+", 54 | help="A list of user-defined class_names" 55 | ) 56 | parser.add_argument( 57 | "--output", 58 | help="A file or directory to save output visualizations. " 59 | "If not given, will show output in an OpenCV window.", 60 | ) 61 | parser.add_argument( 62 | "--opts", 63 | help="Modify config options using the command-line 'KEY VALUE' pairs", 64 | default=[], 65 | nargs=argparse.REMAINDER, 66 | ) 67 | return parser 68 | 69 | 70 | if __name__ == "__main__": 71 | mp.set_start_method("spawn", force=True) 72 | args = get_parser().parse_args() 73 | setup_logger(name="fvcore") 74 | logger = setup_logger() 75 | logger.info("Arguments: " + str(args)) 76 | 77 | cfg = setup_cfg(args) 78 | 79 | demo = VisualizationDemo(cfg) 80 | class_names = args.class_names 81 | if args.input: 82 | if len(args.input) == 1: 83 | args.input = glob.glob(os.path.expanduser(args.input[0])) 84 | assert args.input, "The input path(s) was not found" 85 | for path in tqdm.tqdm(args.input, disable=not args.output): 86 | # use PIL, to be consistent with evaluation 87 | img = read_image(path, format="BGR") 88 | start_time = time.time() 89 | predictions, visualized_output = demo.run_on_image(img, class_names) 90 | logger.info( 91 | "{}: {} in {:.2f}s".format( 92 | path, 93 | "detected {} instances".format(len(predictions["instances"])) 94 | if "instances" in predictions 95 | else "finished", 96 | time.time() - start_time, 97 | ) 98 | ) 99 | 100 | if args.output: 101 | if os.path.isdir(args.output): 102 | assert os.path.isdir(args.output), args.output 103 | out_filename = os.path.join(args.output, os.path.basename(path)) 104 | else: 105 | assert len(args.input) == 1, "Please specify a directory with args.output" 106 | out_filename = args.output 107 | visualized_output.save(out_filename) 108 | else: 109 | cv2.namedWindow(WINDOW_NAME, cv2.WINDOW_NORMAL) 110 | cv2.imshow(WINDOW_NAME, visualized_output.get_image()[:, :, ::-1]) 111 | if cv2.waitKey(0) == 27: 112 | break # esc to quit 113 | else: 114 | raise NotImplementedError -------------------------------------------------------------------------------- /open_clip_training/CITATION.cff: 
-------------------------------------------------------------------------------- 1 | cff-version: 1.1.0 2 | message: If you use this software, please cite it as below. 3 | authors: 4 | - family-names: Ilharco 5 | given-names: Gabriel 6 | - family-names: Wortsman 7 | given-names: Mitchell 8 | - family-names: Wightman 9 | given-names: Ross 10 | - family-names: Gordon 11 | given-names: Cade 12 | - family-names: Carlini 13 | given-names: Nicholas 14 | - family-names: Taori 15 | given-names: Rohan 16 | - family-names: Dave 17 | given-names: Achal 18 | - family-names: Shankar 19 | given-names: Vaishaal 20 | - family-names: Namkoong 21 | given-names: Hongseok 22 | - family-names: Miller 23 | given-names: John 24 | - family-names: Hajishirzi 25 | given-names: Hannaneh 26 | - family-names: Farhadi 27 | given-names: Ali 28 | - family-names: Schmidt 29 | given-names: Ludwig 30 | title: OpenCLIP 31 | version: v0.1 32 | doi: 10.5281/zenodo.5143773 33 | date-released: 2021-07-28 34 | -------------------------------------------------------------------------------- /open_clip_training/HISTORY.md: -------------------------------------------------------------------------------- 1 | ## 1.2.0 2 | 3 | * ViT-B/32 trained on Laion2B-en 4 | * add missing openai RN50x64 model 5 | 6 | ## 1.1.1 7 | 8 | * ViT-B/16+ 9 | * Add grad checkpointing support 10 | * more robust data loader 11 | -------------------------------------------------------------------------------- /open_clip_training/LICENSE: -------------------------------------------------------------------------------- 1 | Copyright (c) 2012-2021 Gabriel Ilharco, Mitchell Wortsman, 2 | Nicholas Carlini, Rohan Taori, Achal Dave, Vaishaal Shankar, 3 | John Miller, Hongseok Namkoong, Hannaneh Hajishirzi, Ali Farhadi, 4 | Ludwig Schmidt 5 | 6 | Permission is hereby granted, free of charge, to any person obtaining 7 | a copy of this software and associated documentation files (the 8 | "Software"), to deal in the Software without restriction, including 9 | without limitation the rights to use, copy, modify, merge, publish, 10 | distribute, sublicense, and/or sell copies of the Software, and to 11 | permit persons to whom the Software is furnished to do so, subject to 12 | the following conditions: 13 | 14 | The above copyright notice and this permission notice shall be 15 | included in all copies or substantial portions of the Software. 16 | 17 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, 18 | EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF 19 | MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND 20 | NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE 21 | LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION 22 | OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION 23 | WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 24 | -------------------------------------------------------------------------------- /open_clip_training/MANIFEST.in: -------------------------------------------------------------------------------- 1 | include src/open_clip/bpe_simple_vocab_16e6.txt.gz 2 | include src/open_clip/model_configs/*.json 3 | 4 | -------------------------------------------------------------------------------- /open_clip_training/Makefile: -------------------------------------------------------------------------------- 1 | install: ## [Local development] Upgrade pip, install requirements, install package. 
2 | python -m pip install -U pip 3 | python -m pip install -e . 4 | 5 | install-dev: ## [Local development] Install test requirements 6 | python -m pip install -r requirements-test.txt 7 | 8 | test: ## [Local development] Run unit tests 9 | python -m pytest -x -s -v tests 10 | -------------------------------------------------------------------------------- /open_clip_training/docs/CLIP.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/ov-seg/36f49d496714998058d115ffb6172d9d84c59065/open_clip_training/docs/CLIP.png -------------------------------------------------------------------------------- /open_clip_training/docs/clip_conceptual_captions.md: -------------------------------------------------------------------------------- 1 | ## Additional training curves for CLIP on Conceptual Captions 2 | 3 | # Zero shot accuracy 4 | ![](/docs/clip_zeroshot.png) 5 | 6 | # Training loss curve 7 | ![](/docs/clip_loss.png) 8 | 9 | # Validation loss curve 10 | ![](/docs/clip_val_loss.png) 11 | 12 | # Validation recall 13 | ![](/docs/clip_recall.png) -------------------------------------------------------------------------------- /open_clip_training/docs/clip_loss.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/ov-seg/36f49d496714998058d115ffb6172d9d84c59065/open_clip_training/docs/clip_loss.png -------------------------------------------------------------------------------- /open_clip_training/docs/clip_recall.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/ov-seg/36f49d496714998058d115ffb6172d9d84c59065/open_clip_training/docs/clip_recall.png -------------------------------------------------------------------------------- /open_clip_training/docs/clip_val_loss.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/ov-seg/36f49d496714998058d115ffb6172d9d84c59065/open_clip_training/docs/clip_val_loss.png -------------------------------------------------------------------------------- /open_clip_training/docs/clip_zeroshot.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/ov-seg/36f49d496714998058d115ffb6172d9d84c59065/open_clip_training/docs/clip_zeroshot.png -------------------------------------------------------------------------------- /open_clip_training/docs/effective_robustness.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/ov-seg/36f49d496714998058d115ffb6172d9d84c59065/open_clip_training/docs/effective_robustness.png -------------------------------------------------------------------------------- /open_clip_training/docs/laion2b_clip_zeroshot_b32.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/ov-seg/36f49d496714998058d115ffb6172d9d84c59065/open_clip_training/docs/laion2b_clip_zeroshot_b32.png -------------------------------------------------------------------------------- /open_clip_training/docs/laion_clip_zeroshot.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/facebookresearch/ov-seg/36f49d496714998058d115ffb6172d9d84c59065/open_clip_training/docs/laion_clip_zeroshot.png -------------------------------------------------------------------------------- /open_clip_training/docs/laion_clip_zeroshot_b16.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/ov-seg/36f49d496714998058d115ffb6172d9d84c59065/open_clip_training/docs/laion_clip_zeroshot_b16.png -------------------------------------------------------------------------------- /open_clip_training/docs/laion_clip_zeroshot_b16_plus_240.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/ov-seg/36f49d496714998058d115ffb6172d9d84c59065/open_clip_training/docs/laion_clip_zeroshot_b16_plus_240.png -------------------------------------------------------------------------------- /open_clip_training/docs/laion_clip_zeroshot_l14.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/ov-seg/36f49d496714998058d115ffb6172d9d84c59065/open_clip_training/docs/laion_clip_zeroshot_l14.png -------------------------------------------------------------------------------- /open_clip_training/docs/laion_openai_compare_b32.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/ov-seg/36f49d496714998058d115ffb6172d9d84c59065/open_clip_training/docs/laion_openai_compare_b32.jpg -------------------------------------------------------------------------------- /open_clip_training/docs/scaling.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/ov-seg/36f49d496714998058d115ffb6172d9d84c59065/open_clip_training/docs/scaling.png -------------------------------------------------------------------------------- /open_clip_training/requirements-test.txt: -------------------------------------------------------------------------------- 1 | pytest-xdist==2.5.0 2 | pytest==7.0.1 -------------------------------------------------------------------------------- /open_clip_training/requirements-training.txt: -------------------------------------------------------------------------------- 1 | torch>=1.9.0 2 | torchvision 3 | webdataset>=0.2.5 4 | regex 5 | ftfy 6 | tqdm 7 | pandas 8 | braceexpand 9 | -------------------------------------------------------------------------------- /open_clip_training/requirements.txt: -------------------------------------------------------------------------------- 1 | torch>=1.9.0 2 | torchvision 3 | regex 4 | ftfy 5 | tqdm 6 | -------------------------------------------------------------------------------- /open_clip_training/setup.py: -------------------------------------------------------------------------------- 1 | """ Setup 2 | """ 3 | from setuptools import setup, find_packages 4 | from codecs import open 5 | from os import path 6 | 7 | here = path.abspath(path.dirname(__file__)) 8 | 9 | # Get the long description from the README file 10 | with open(path.join(here, 'README.md'), encoding='utf-8') as f: 11 | long_description = f.read() 12 | 13 | exec(open('src/open_clip/version.py').read()) 14 | setup( 15 | name='open_clip_torch', 16 | version=__version__, 17 | description='OpenCLIP', 18 | long_description=long_description, 19 | 
long_description_content_type='text/markdown', 20 | url='https://github.com/mlfoundations/open_clip', 21 | author='', 22 | author_email='', 23 | classifiers=[ 24 | # How mature is this project? Common values are 25 | # 3 - Alpha 26 | # 4 - Beta 27 | # 5 - Production/Stable 28 | 'Development Status :: 3 - Alpha', 29 | 'Intended Audience :: Education', 30 | 'Intended Audience :: Science/Research', 31 | 'License :: OSI Approved :: Apache Software License', 32 | 'Programming Language :: Python :: 3.7', 33 | 'Programming Language :: Python :: 3.8', 34 | 'Programming Language :: Python :: 3.9', 35 | 'Programming Language :: Python :: 3.10', 36 | 'Topic :: Scientific/Engineering', 37 | 'Topic :: Scientific/Engineering :: Artificial Intelligence', 38 | 'Topic :: Software Development', 39 | 'Topic :: Software Development :: Libraries', 40 | 'Topic :: Software Development :: Libraries :: Python Modules', 41 | ], 42 | 43 | # Note that this is a string of words separated by whitespace, not a list. 44 | keywords='CLIP pretrained', 45 | package_dir={'': 'src'}, 46 | packages=find_packages(where='src', exclude=['training']), 47 | include_package_data=True, 48 | install_requires=[ 49 | 'torch >= 1.9', 50 | 'torchvision', 51 | 'ftfy', 52 | 'regex', 53 | 'tqdm', 54 | ], 55 | python_requires='>=3.7', 56 | ) 57 | -------------------------------------------------------------------------------- /open_clip_training/src/data/gather_cc.py: -------------------------------------------------------------------------------- 1 | import requests 2 | import os 3 | import multiprocessing as mp 4 | from io import BytesIO 5 | import numpy as np 6 | import PIL 7 | from PIL import Image 8 | import pickle 9 | import sys 10 | 11 | 12 | def grab(line): 13 | """ 14 | Download a single image from the TSV. 15 | """ 16 | uid, split, line = line 17 | try: 18 | caption, url = line.split("\t")[:2] 19 | except: 20 | print("Parse error") 21 | return 22 | 23 | if os.path.exists(ROOT+"/%s/%d/%d.jpg"%(split,uid%1000,uid)): 24 | print("Finished", uid) 25 | return uid, caption, url 26 | 27 | # Let's not crash if anythign weird happens 28 | try: 29 | dat = requests.get(url, timeout=20) 30 | if dat.status_code != 200: 31 | print("404 file", url) 32 | return 33 | 34 | # Try to parse this as an Image file, we'll fail out if not 35 | im = Image.open(BytesIO(dat.content)) 36 | im.thumbnail((512, 512), PIL.Image.BICUBIC) 37 | if min(*im.size) < max(*im.size)/3: 38 | print("Too small", url) 39 | return 40 | 41 | im.save(ROOT+"/%s/%d/%d.jpg"%(split,uid%1000,uid)) 42 | 43 | # Another try/catch just because sometimes saving and re-loading 44 | # the image is different than loading it once. 
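# Image.open is lazy, so converting the re-opened file to a NumPy array below
# forces a full decode; truncated or otherwise unreadable JPEGs fail here and
# the row is simply dropped from the output CSV rather than surfacing later.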
45 | try: 46 | o = Image.open(ROOT+"/%s/%d/%d.jpg"%(split,uid%1000,uid)) 47 | o = np.array(o) 48 | 49 | print("Success", o.shape, uid, url) 50 | return uid, caption, url 51 | except: 52 | print("Failed", uid, url) 53 | 54 | except Exception as e: 55 | print("Unknown error", e) 56 | pass 57 | 58 | if __name__ == "__main__": 59 | ROOT = "/home/jeffliang/data/cc3m" 60 | 61 | if not os.path.exists(ROOT): 62 | os.mkdir(ROOT) 63 | os.mkdir(os.path.join(ROOT,"train")) 64 | os.mkdir(os.path.join(ROOT,"val")) 65 | for i in range(1000): 66 | os.mkdir(os.path.join(ROOT,"train", str(i))) 67 | os.mkdir(os.path.join(ROOT,"val", str(i))) 68 | 69 | 70 | p = mp.Pool(300) 71 | 72 | for tsv in sys.argv[1:]: 73 | print("Processing file", tsv) 74 | assert 'val' in tsv.lower() or 'train' in tsv.lower() 75 | split = 'val' if 'val' in tsv.lower() else 'train' 76 | results = p.map(grab, 77 | [(i,split,x) for i,x in enumerate(open(tsv).read().split("\n"))]) 78 | 79 | out = open(tsv.replace(".tsv","_output.csv"),"w") 80 | out.write("title\tfilepath\n") 81 | 82 | for row in results: 83 | if row is None: continue 84 | id, caption, url = row 85 | fp = os.path.join(ROOT, split, str(id % 1000), str(id) + ".jpg") 86 | if os.path.exists(fp): 87 | out.write("%s\t%s\n"%(caption,fp)) 88 | else: 89 | print("Drop", id) 90 | out.close() 91 | 92 | p.close() 93 | 94 | -------------------------------------------------------------------------------- /open_clip_training/src/open_clip/__init__.py: -------------------------------------------------------------------------------- 1 | from .factory import list_models, create_model, create_model_and_transforms, add_model_config 2 | from .loss import ClipLoss 3 | from .model import CLIP, CLIPTextCfg, CLIPVisionCfg, convert_weights_to_fp16, trace_model 4 | from .openai import load_openai_model, list_openai_models 5 | from .pretrained import list_pretrained, list_pretrained_tag_models, list_pretrained_model_tags,\ 6 | get_pretrained_url, download_pretrained 7 | from .tokenizer import SimpleTokenizer, tokenize 8 | from .transform import image_transform 9 | -------------------------------------------------------------------------------- /open_clip_training/src/open_clip/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/ov-seg/36f49d496714998058d115ffb6172d9d84c59065/open_clip_training/src/open_clip/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /open_clip_training/src/open_clip/__pycache__/factory.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/ov-seg/36f49d496714998058d115ffb6172d9d84c59065/open_clip_training/src/open_clip/__pycache__/factory.cpython-38.pyc -------------------------------------------------------------------------------- /open_clip_training/src/open_clip/__pycache__/loss.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/ov-seg/36f49d496714998058d115ffb6172d9d84c59065/open_clip_training/src/open_clip/__pycache__/loss.cpython-38.pyc -------------------------------------------------------------------------------- /open_clip_training/src/open_clip/__pycache__/model.cpython-38.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/facebookresearch/ov-seg/36f49d496714998058d115ffb6172d9d84c59065/open_clip_training/src/open_clip/__pycache__/model.cpython-38.pyc -------------------------------------------------------------------------------- /open_clip_training/src/open_clip/__pycache__/openai.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/ov-seg/36f49d496714998058d115ffb6172d9d84c59065/open_clip_training/src/open_clip/__pycache__/openai.cpython-38.pyc -------------------------------------------------------------------------------- /open_clip_training/src/open_clip/__pycache__/pretrained.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/ov-seg/36f49d496714998058d115ffb6172d9d84c59065/open_clip_training/src/open_clip/__pycache__/pretrained.cpython-38.pyc -------------------------------------------------------------------------------- /open_clip_training/src/open_clip/__pycache__/timm_model.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/ov-seg/36f49d496714998058d115ffb6172d9d84c59065/open_clip_training/src/open_clip/__pycache__/timm_model.cpython-38.pyc -------------------------------------------------------------------------------- /open_clip_training/src/open_clip/__pycache__/tokenizer.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/ov-seg/36f49d496714998058d115ffb6172d9d84c59065/open_clip_training/src/open_clip/__pycache__/tokenizer.cpython-38.pyc -------------------------------------------------------------------------------- /open_clip_training/src/open_clip/__pycache__/transform.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/ov-seg/36f49d496714998058d115ffb6172d9d84c59065/open_clip_training/src/open_clip/__pycache__/transform.cpython-38.pyc -------------------------------------------------------------------------------- /open_clip_training/src/open_clip/__pycache__/utils.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/ov-seg/36f49d496714998058d115ffb6172d9d84c59065/open_clip_training/src/open_clip/__pycache__/utils.cpython-38.pyc -------------------------------------------------------------------------------- /open_clip_training/src/open_clip/bpe_simple_vocab_16e6.txt.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/ov-seg/36f49d496714998058d115ffb6172d9d84c59065/open_clip_training/src/open_clip/bpe_simple_vocab_16e6.txt.gz -------------------------------------------------------------------------------- /open_clip_training/src/open_clip/loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.nn import functional as F 4 | 5 | try: 6 | import torch.distributed.nn 7 | from torch import distributed as dist 8 | has_distributed = True 9 | except ImportError: 10 | has_distributed = False 11 | 12 | try: 13 | import horovod.torch as hvd 14 | except ImportError: 15 | hvd = None 16 | 17 | 18 | def gather_features( 19 | 
image_features, 20 | text_features, 21 | local_loss=False, 22 | gather_with_grad=False, 23 | rank=0, 24 | world_size=1, 25 | use_horovod=False 26 | ): 27 | assert has_distributed, 'torch.distributed did not import correctly, please use a PyTorch version with support.' 28 | if use_horovod: 29 | assert hvd is not None, 'Please install horovod' 30 | if gather_with_grad: 31 | all_image_features = hvd.allgather(image_features) 32 | all_text_features = hvd.allgather(text_features) 33 | else: 34 | with torch.no_grad(): 35 | all_image_features = hvd.allgather(image_features) 36 | all_text_features = hvd.allgather(text_features) 37 | if not local_loss: 38 | # ensure grads for local rank when all_* features don't have a gradient 39 | gathered_image_features = list(all_image_features.chunk(world_size, dim=0)) 40 | gathered_text_features = list(all_text_features.chunk(world_size, dim=0)) 41 | gathered_image_features[rank] = image_features 42 | gathered_text_features[rank] = text_features 43 | all_image_features = torch.cat(gathered_image_features, dim=0) 44 | all_text_features = torch.cat(gathered_text_features, dim=0) 45 | else: 46 | # We gather tensors from all gpus 47 | if gather_with_grad: 48 | all_image_features = torch.cat(torch.distributed.nn.all_gather(image_features), dim=0) 49 | all_text_features = torch.cat(torch.distributed.nn.all_gather(text_features), dim=0) 50 | else: 51 | gathered_image_features = [torch.zeros_like(image_features) for _ in range(world_size)] 52 | gathered_text_features = [torch.zeros_like(text_features) for _ in range(world_size)] 53 | dist.all_gather(gathered_image_features, image_features) 54 | dist.all_gather(gathered_text_features, text_features) 55 | if not local_loss: 56 | # ensure grads for local rank when all_* features don't have a gradient 57 | gathered_image_features[rank] = image_features 58 | gathered_text_features[rank] = text_features 59 | all_image_features = torch.cat(gathered_image_features, dim=0) 60 | all_text_features = torch.cat(gathered_text_features, dim=0) 61 | 62 | return all_image_features, all_text_features 63 | 64 | 65 | class ClipLoss(nn.Module): 66 | 67 | def __init__( 68 | self, 69 | local_loss=False, 70 | gather_with_grad=False, 71 | cache_labels=False, 72 | rank=0, 73 | world_size=1, 74 | use_horovod=False, 75 | ): 76 | super().__init__() 77 | self.local_loss = local_loss 78 | self.gather_with_grad = gather_with_grad 79 | self.cache_labels = cache_labels 80 | self.rank = rank 81 | self.world_size = world_size 82 | self.use_horovod = use_horovod 83 | 84 | # cache state 85 | self.prev_num_logits = 0 86 | self.labels = {} 87 | 88 | def forward(self, image_features, text_features, logit_scale): 89 | device = image_features.device 90 | if self.world_size > 1: 91 | all_image_features, all_text_features = gather_features( 92 | image_features, text_features, 93 | self.local_loss, self.gather_with_grad, self.rank, self.world_size, self.use_horovod) 94 | 95 | if self.local_loss: 96 | logits_per_image = logit_scale * image_features @ all_text_features.T 97 | logits_per_text = logit_scale * text_features @ all_image_features.T 98 | else: 99 | logits_per_image = logit_scale * all_image_features @ all_text_features.T 100 | logits_per_text = logits_per_image.T 101 | else: 102 | logits_per_image = logit_scale * image_features @ text_features.T 103 | logits_per_text = logit_scale * text_features @ image_features.T 104 | 105 | # calculated ground-truth and cache if enabled 106 | num_logits = logits_per_image.shape[0] 107 | if self.prev_num_logits != 
num_logits or device not in self.labels: 108 | labels = torch.arange(num_logits, device=device, dtype=torch.long) 109 | if self.world_size > 1 and self.local_loss: 110 | labels = labels + num_logits * self.rank 111 | if self.cache_labels: 112 | self.labels[device] = labels 113 | self.prev_num_logits = num_logits 114 | else: 115 | labels = self.labels[device] 116 | 117 | total_loss = ( 118 | F.cross_entropy(logits_per_image, labels) + 119 | F.cross_entropy(logits_per_text, labels) 120 | ) / 2 121 | return total_loss 122 | -------------------------------------------------------------------------------- /open_clip_training/src/open_clip/model_configs/RN101-quickgelu.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 512, 3 | "quick_gelu": true, 4 | "vision_cfg": { 5 | "image_size": 224, 6 | "layers": [ 7 | 3, 8 | 4, 9 | 23, 10 | 3 11 | ], 12 | "width": 64, 13 | "patch_size": null 14 | }, 15 | "text_cfg": { 16 | "context_length": 77, 17 | "vocab_size": 49408, 18 | "width": 512, 19 | "heads": 8, 20 | "layers": 12 21 | } 22 | } -------------------------------------------------------------------------------- /open_clip_training/src/open_clip/model_configs/RN101.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 512, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": [ 6 | 3, 7 | 4, 8 | 23, 9 | 3 10 | ], 11 | "width": 64, 12 | "patch_size": null 13 | }, 14 | "text_cfg": { 15 | "context_length": 77, 16 | "vocab_size": 49408, 17 | "width": 512, 18 | "heads": 8, 19 | "layers": 12 20 | } 21 | } -------------------------------------------------------------------------------- /open_clip_training/src/open_clip/model_configs/RN50-quickgelu.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 1024, 3 | "quick_gelu": true, 4 | "vision_cfg": { 5 | "image_size": 224, 6 | "layers": [ 7 | 3, 8 | 4, 9 | 6, 10 | 3 11 | ], 12 | "width": 64, 13 | "patch_size": null 14 | }, 15 | "text_cfg": { 16 | "context_length": 77, 17 | "vocab_size": 49408, 18 | "width": 512, 19 | "heads": 8, 20 | "layers": 12 21 | } 22 | } 23 | -------------------------------------------------------------------------------- /open_clip_training/src/open_clip/model_configs/RN50.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 1024, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": [ 6 | 3, 7 | 4, 8 | 6, 9 | 3 10 | ], 11 | "width": 64, 12 | "patch_size": null 13 | }, 14 | "text_cfg": { 15 | "context_length": 77, 16 | "vocab_size": 49408, 17 | "width": 512, 18 | "heads": 8, 19 | "layers": 12 20 | } 21 | } -------------------------------------------------------------------------------- /open_clip_training/src/open_clip/model_configs/RN50x16.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 768, 3 | "vision_cfg": { 4 | "image_size": 384, 5 | "layers": [ 6 | 6, 7 | 8, 8 | 18, 9 | 8 10 | ], 11 | "width": 96, 12 | "patch_size": null 13 | }, 14 | "text_cfg": { 15 | "context_length": 77, 16 | "vocab_size": 49408, 17 | "width": 768, 18 | "heads": 12, 19 | "layers": 12 20 | } 21 | } -------------------------------------------------------------------------------- /open_clip_training/src/open_clip/model_configs/RN50x4.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 640, 3 | "vision_cfg": { 4 | 
"image_size": 288, 5 | "layers": [ 6 | 4, 7 | 6, 8 | 10, 9 | 6 10 | ], 11 | "width": 80, 12 | "patch_size": null 13 | }, 14 | "text_cfg": { 15 | "context_length": 77, 16 | "vocab_size": 49408, 17 | "width": 640, 18 | "heads": 10, 19 | "layers": 12 20 | } 21 | } -------------------------------------------------------------------------------- /open_clip_training/src/open_clip/model_configs/ViT-B-16-plus-240.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 640, 3 | "vision_cfg": { 4 | "image_size": 240, 5 | "layers": 12, 6 | "width": 896, 7 | "patch_size": 16 8 | }, 9 | "text_cfg": { 10 | "context_length": 77, 11 | "vocab_size": 49408, 12 | "width": 640, 13 | "heads": 10, 14 | "layers": 12 15 | } 16 | } -------------------------------------------------------------------------------- /open_clip_training/src/open_clip/model_configs/ViT-B-16-plus.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 640, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 12, 6 | "width": 896, 7 | "patch_size": 16 8 | }, 9 | "text_cfg": { 10 | "context_length": 77, 11 | "vocab_size": 49408, 12 | "width": 640, 13 | "heads": 10, 14 | "layers": 12 15 | } 16 | } -------------------------------------------------------------------------------- /open_clip_training/src/open_clip/model_configs/ViT-B-16.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 512, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 12, 6 | "width": 768, 7 | "patch_size": 16 8 | }, 9 | "text_cfg": { 10 | "context_length": 77, 11 | "vocab_size": 49408, 12 | "width": 512, 13 | "heads": 8, 14 | "layers": 12 15 | } 16 | } -------------------------------------------------------------------------------- /open_clip_training/src/open_clip/model_configs/ViT-B-32-plus-256.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 640, 3 | "vision_cfg": { 4 | "image_size": 256, 5 | "layers": 12, 6 | "width": 896, 7 | "patch_size": 32 8 | }, 9 | "text_cfg": { 10 | "context_length": 77, 11 | "vocab_size": 49408, 12 | "width": 640, 13 | "heads": 10, 14 | "layers": 12 15 | } 16 | } -------------------------------------------------------------------------------- /open_clip_training/src/open_clip/model_configs/ViT-B-32-quickgelu.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 512, 3 | "quick_gelu": true, 4 | "vision_cfg": { 5 | "image_size": 224, 6 | "layers": 12, 7 | "width": 768, 8 | "patch_size": 32 9 | }, 10 | "text_cfg": { 11 | "context_length": 77, 12 | "vocab_size": 49408, 13 | "width": 512, 14 | "heads": 8, 15 | "layers": 12 16 | } 17 | } -------------------------------------------------------------------------------- /open_clip_training/src/open_clip/model_configs/ViT-B-32.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 512, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 12, 6 | "width": 768, 7 | "patch_size": 32 8 | }, 9 | "text_cfg": { 10 | "context_length": 77, 11 | "vocab_size": 49408, 12 | "width": 512, 13 | "heads": 8, 14 | "layers": 12 15 | } 16 | } -------------------------------------------------------------------------------- /open_clip_training/src/open_clip/model_configs/ViT-H-14.json: -------------------------------------------------------------------------------- 1 | { 2 | 
"embed_dim": 1024, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 32, 6 | "width": 1280, 7 | "head_width": 80, 8 | "patch_size": 14 9 | }, 10 | "text_cfg": { 11 | "context_length": 77, 12 | "vocab_size": 49408, 13 | "width": 1024, 14 | "heads": 16, 15 | "layers": 24 16 | } 17 | } -------------------------------------------------------------------------------- /open_clip_training/src/open_clip/model_configs/ViT-H-16.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 1024, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 32, 6 | "width": 1280, 7 | "head_width": 80, 8 | "patch_size": 16 9 | }, 10 | "text_cfg": { 11 | "context_length": 77, 12 | "vocab_size": 49408, 13 | "width": 1024, 14 | "heads": 16, 15 | "layers": 24 16 | } 17 | } -------------------------------------------------------------------------------- /open_clip_training/src/open_clip/model_configs/ViT-L-14-280.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 768, 3 | "vision_cfg": { 4 | "image_size": 280, 5 | "layers": 24, 6 | "width": 1024, 7 | "patch_size": 14 8 | }, 9 | "text_cfg": { 10 | "context_length": 77, 11 | "vocab_size": 49408, 12 | "width": 768, 13 | "heads": 12, 14 | "layers": 12 15 | } 16 | } -------------------------------------------------------------------------------- /open_clip_training/src/open_clip/model_configs/ViT-L-14-336.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 768, 3 | "vision_cfg": { 4 | "image_size": 336, 5 | "layers": 24, 6 | "width": 1024, 7 | "patch_size": 14 8 | }, 9 | "text_cfg": { 10 | "context_length": 77, 11 | "vocab_size": 49408, 12 | "width": 768, 13 | "heads": 12, 14 | "layers": 12 15 | } 16 | } -------------------------------------------------------------------------------- /open_clip_training/src/open_clip/model_configs/ViT-L-14.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 768, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 24, 6 | "width": 1024, 7 | "patch_size": 14 8 | }, 9 | "text_cfg": { 10 | "context_length": 77, 11 | "vocab_size": 49408, 12 | "width": 768, 13 | "heads": 12, 14 | "layers": 12 15 | } 16 | } -------------------------------------------------------------------------------- /open_clip_training/src/open_clip/model_configs/ViT-L-16-320.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 768, 3 | "vision_cfg": { 4 | "image_size": 320, 5 | "layers": 24, 6 | "width": 1024, 7 | "patch_size": 16 8 | }, 9 | "text_cfg": { 10 | "context_length": 77, 11 | "vocab_size": 49408, 12 | "width": 768, 13 | "heads": 12, 14 | "layers": 12 15 | } 16 | } -------------------------------------------------------------------------------- /open_clip_training/src/open_clip/model_configs/ViT-L-16.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 768, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 24, 6 | "width": 1024, 7 | "patch_size": 16 8 | }, 9 | "text_cfg": { 10 | "context_length": 77, 11 | "vocab_size": 49408, 12 | "width": 768, 13 | "heads": 12, 14 | "layers": 12 15 | } 16 | } -------------------------------------------------------------------------------- /open_clip_training/src/open_clip/model_configs/ViT-g-14.json: 
-------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 1024, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 40, 6 | "width": 1408, 7 | "head_width": 88, 8 | "mlp_ratio": 4.3637, 9 | "patch_size": 14 10 | }, 11 | "text_cfg": { 12 | "context_length": 77, 13 | "vocab_size": 49408, 14 | "width": 1024, 15 | "heads": 16, 16 | "layers": 24 17 | } 18 | } -------------------------------------------------------------------------------- /open_clip_training/src/open_clip/model_configs/timm-efficientnetv2_rw_s.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 768, 3 | "vision_cfg": { 4 | "timm_model_name": "efficientnetv2_rw_s", 5 | "timm_model_pretrained": false, 6 | "timm_pool": "abs_attn", 7 | "timm_proj": "", 8 | "image_size": 288 9 | }, 10 | "text_cfg": { 11 | "context_length": 77, 12 | "vocab_size": 49408, 13 | "width": 768, 14 | "heads": 8, 15 | "layers": 12 16 | } 17 | } -------------------------------------------------------------------------------- /open_clip_training/src/open_clip/model_configs/timm-resnet50d.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 1024, 3 | "vision_cfg": { 4 | "timm_model_name": "resnet50d", 5 | "timm_model_pretrained": false, 6 | "timm_pool": "abs_attn", 7 | "timm_proj": "", 8 | "image_size": 224 9 | }, 10 | "text_cfg": { 11 | "context_length": 77, 12 | "vocab_size": 49408, 13 | "width": 512, 14 | "heads": 8, 15 | "layers": 12 16 | } 17 | } 18 | -------------------------------------------------------------------------------- /open_clip_training/src/open_clip/model_configs/timm-resnetaa50d.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 1024, 3 | "vision_cfg": { 4 | "timm_model_name": "resnetaa50d", 5 | "timm_model_pretrained": false, 6 | "timm_pool": "abs_attn", 7 | "timm_proj": "", 8 | "image_size": 224 9 | }, 10 | "text_cfg": { 11 | "context_length": 77, 12 | "vocab_size": 49408, 13 | "width": 512, 14 | "heads": 8, 15 | "layers": 12 16 | } 17 | } 18 | -------------------------------------------------------------------------------- /open_clip_training/src/open_clip/model_configs/timm-resnetblur50.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 1024, 3 | "vision_cfg": { 4 | "timm_model_name": "resnetblur50", 5 | "timm_model_pretrained": false, 6 | "timm_pool": "abs_attn", 7 | "timm_proj": "", 8 | "image_size": 224 9 | }, 10 | "text_cfg": { 11 | "context_length": 77, 12 | "vocab_size": 49408, 13 | "width": 512, 14 | "heads": 8, 15 | "layers": 12 16 | } 17 | } 18 | -------------------------------------------------------------------------------- /open_clip_training/src/open_clip/model_configs/timm-swin_base_patch4_window7_224.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 512, 3 | "vision_cfg": { 4 | "timm_model_name": "swin_base_patch4_window7_224", 5 | "timm_model_pretrained": false, 6 | "timm_pool": "", 7 | "timm_proj": "linear", 8 | "image_size": 224 9 | }, 10 | "text_cfg": { 11 | "context_length": 77, 12 | "vocab_size": 49408, 13 | "width": 512, 14 | "heads": 8, 15 | "layers": 12 16 | } 17 | } -------------------------------------------------------------------------------- /open_clip_training/src/open_clip/model_configs/timm-vit_base_patch16_224.json: 
-------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 512, 3 | "vision_cfg": { 4 | "timm_model_name": "vit_base_patch16_224", 5 | "timm_model_pretrained": false, 6 | "timm_pool": "", 7 | "timm_proj": "linear", 8 | "image_size": 224 9 | }, 10 | "text_cfg": { 11 | "context_length": 77, 12 | "vocab_size": 49408, 13 | "width": 512, 14 | "heads": 8, 15 | "layers": 12 16 | } 17 | } -------------------------------------------------------------------------------- /open_clip_training/src/open_clip/model_configs/timm-vit_base_patch32_224.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 512, 3 | "vision_cfg": { 4 | "timm_model_name": "vit_base_patch32_224", 5 | "timm_model_pretrained": false, 6 | "timm_pool": "", 7 | "timm_proj": "linear", 8 | "image_size": 224 9 | }, 10 | "text_cfg": { 11 | "context_length": 77, 12 | "vocab_size": 49408, 13 | "width": 512, 14 | "heads": 8, 15 | "layers": 12 16 | } 17 | } -------------------------------------------------------------------------------- /open_clip_training/src/open_clip/model_configs/timm-vit_small_patch16_224.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 512, 3 | "vision_cfg": { 4 | "timm_model_name": "vit_small_patch16_224", 5 | "timm_model_pretrained": false, 6 | "timm_pool": "", 7 | "timm_proj": "linear", 8 | "image_size": 224 9 | }, 10 | "text_cfg": { 11 | "context_length": 77, 12 | "vocab_size": 49408, 13 | "width": 512, 14 | "heads": 8, 15 | "layers": 12 16 | } 17 | } -------------------------------------------------------------------------------- /open_clip_training/src/open_clip/openai.py: -------------------------------------------------------------------------------- 1 | """ OpenAI pretrained model functions 2 | 3 | Adapted from https://github.com/openai/CLIP. Originally MIT License, Copyright (c) 2021 OpenAI. 4 | """ 5 | 6 | import os 7 | import warnings 8 | from typing import Union, List 9 | 10 | import torch 11 | 12 | from .model import build_model_from_openai_state_dict 13 | from .pretrained import get_pretrained_url, list_pretrained_tag_models, download_pretrained 14 | 15 | __all__ = ["list_openai_models", "load_openai_model"] 16 | 17 | 18 | def list_openai_models() -> List[str]: 19 | """Returns the names of available CLIP models""" 20 | return list_pretrained_tag_models('openai') 21 | 22 | 23 | def load_openai_model( 24 | name: str, 25 | device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu", 26 | jit=True, 27 | ): 28 | """Load a CLIP model 29 | 30 | Parameters 31 | ---------- 32 | name : str 33 | A model name listed by `clip.available_models()`, or the path to a model checkpoint containing the state_dict 34 | device : Union[str, torch.device] 35 | The device to put the loaded model 36 | jit : bool 37 | Whether to load the optimized JIT model (default) or more hackable non-JIT model. 
38 | 39 | Returns 40 | ------- 41 | model : torch.nn.Module 42 | The CLIP model 43 | preprocess : Callable[[PIL.Image], torch.Tensor] 44 | A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input 45 | """ 46 | if get_pretrained_url(name, 'openai'): 47 | model_path = download_pretrained(get_pretrained_url(name, 'openai')) 48 | elif os.path.isfile(name): 49 | model_path = name 50 | else: 51 | raise RuntimeError(f"Model {name} not found; available models = {list_openai_models()}") 52 | 53 | try: 54 | # loading JIT archive 55 | model = torch.jit.load(model_path, map_location=device if jit else "cpu").eval() 56 | state_dict = None 57 | except RuntimeError: 58 | # loading saved state dict 59 | if jit: 60 | warnings.warn(f"File {model_path} is not a JIT archive. Loading as a state dict instead") 61 | jit = False 62 | state_dict = torch.load(model_path, map_location="cpu") 63 | 64 | if not jit: 65 | try: 66 | model = build_model_from_openai_state_dict(state_dict or model.state_dict()).to(device) 67 | except KeyError: 68 | sd = {k[7:]: v for k, v in state_dict["state_dict"].items()} 69 | model = build_model_from_openai_state_dict(sd).to(device) 70 | 71 | if str(device) == "cpu": 72 | model.float() 73 | return model 74 | 75 | # patch the device names 76 | device_holder = torch.jit.trace(lambda: torch.ones([]).to(torch.device(device)), example_inputs=[]) 77 | device_node = [n for n in device_holder.graph.findAllNodes("prim::Constant") if "Device" in repr(n)][-1] 78 | 79 | def patch_device(module): 80 | try: 81 | graphs = [module.graph] if hasattr(module, "graph") else [] 82 | except RuntimeError: 83 | graphs = [] 84 | 85 | if hasattr(module, "forward1"): 86 | graphs.append(module.forward1.graph) 87 | 88 | for graph in graphs: 89 | for node in graph.findAllNodes("prim::Constant"): 90 | if "value" in node.attributeNames() and str(node["value"]).startswith("cuda"): 91 | node.copyAttributes(device_node) 92 | 93 | model.apply(patch_device) 94 | patch_device(model.encode_image) 95 | patch_device(model.encode_text) 96 | 97 | # patch dtype to float32 on CPU 98 | if str(device) == "cpu": 99 | float_holder = torch.jit.trace(lambda: torch.ones([]).float(), example_inputs=[]) 100 | float_input = list(float_holder.graph.findNode("aten::to").inputs())[1] 101 | float_node = float_input.node() 102 | 103 | def patch_float(module): 104 | try: 105 | graphs = [module.graph] if hasattr(module, "graph") else [] 106 | except RuntimeError: 107 | graphs = [] 108 | 109 | if hasattr(module, "forward1"): 110 | graphs.append(module.forward1.graph) 111 | 112 | for graph in graphs: 113 | for node in graph.findAllNodes("aten::to"): 114 | inputs = list(node.inputs()) 115 | for i in [1, 2]: # dtype can be the second or third argument to aten::to() 116 | if inputs[i].node()["value"] == 5: 117 | inputs[i].node().copyAttributes(float_node) 118 | 119 | model.apply(patch_float) 120 | patch_float(model.encode_image) 121 | patch_float(model.encode_text) 122 | model.float() 123 | 124 | # ensure image_size attr available at consistent location for both jit and non-jit 125 | model.visual.image_size = model.input_resolution.item() 126 | return model 127 | -------------------------------------------------------------------------------- /open_clip_training/src/open_clip/timm_model.py: -------------------------------------------------------------------------------- 1 | """ timm model adapter 2 | 3 | Wraps timm (https://github.com/rwightman/pytorch-image-models) models for use as a vision tower 
in CLIP model. 4 | """ 5 | from collections import OrderedDict 6 | 7 | import torch.nn as nn 8 | 9 | try: 10 | import timm 11 | from timm.models.layers import Mlp, to_2tuple 12 | from timm.models.layers.attention_pool2d import RotAttentionPool2d 13 | from timm.models.layers.attention_pool2d import AttentionPool2d as AbsAttentionPool2d 14 | except ImportError as e: 15 | timm = None 16 | 17 | from .utils import freeze_batch_norm_2d 18 | 19 | 20 | class TimmModel(nn.Module): 21 | """ timm model adapter 22 | # FIXME this adapter is a work in progress, may change in ways that break weight compat 23 | """ 24 | 25 | def __init__( 26 | self, 27 | model_name, 28 | embed_dim, 29 | image_size=224, 30 | pool='avg', 31 | proj='linear', 32 | drop=0., 33 | pretrained=False): 34 | super().__init__() 35 | if timm is None: 36 | raise RuntimeError("Please `pip install timm` to use timm models.") 37 | 38 | self.image_size = to_2tuple(image_size) 39 | self.trunk = timm.create_model(model_name, pretrained=pretrained) 40 | feat_size = self.trunk.default_cfg.get('pool_size', None) 41 | feature_ndim = 1 if not feat_size else 2 42 | if pool in ('abs_attn', 'rot_attn'): 43 | assert feature_ndim == 2 44 | # if attn pooling used, remove both classifier and default pool 45 | self.trunk.reset_classifier(0, global_pool='') 46 | else: 47 | # reset global pool if pool config set, otherwise leave as network default 48 | reset_kwargs = dict(global_pool=pool) if pool else {} 49 | self.trunk.reset_classifier(0, **reset_kwargs) 50 | prev_chs = self.trunk.num_features 51 | 52 | head_layers = OrderedDict() 53 | if pool == 'abs_attn': 54 | head_layers['pool'] = AbsAttentionPool2d(prev_chs, feat_size=feat_size, out_features=embed_dim) 55 | prev_chs = embed_dim 56 | elif pool == 'rot_attn': 57 | head_layers['pool'] = RotAttentionPool2d(prev_chs, out_features=embed_dim) 58 | prev_chs = embed_dim 59 | else: 60 | assert proj, 'projection layer needed if non-attention pooling is used.' 
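# At this point head_layers holds at most an attention pool (which already
# projects to embed_dim); the optional projection below (dropout + linear, or a
# small MLP) maps prev_chs to embed_dim for the plain-pooling case.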
61 | 62 | # NOTE attention pool ends with a projection layer, so proj should usually be set to '' if such pooling is used 63 | if proj == 'linear': 64 | head_layers['drop'] = nn.Dropout(drop) 65 | head_layers['proj'] = nn.Linear(prev_chs, embed_dim) 66 | elif proj == 'mlp': 67 | head_layers['mlp'] = Mlp(prev_chs, 2 * embed_dim, embed_dim, drop=drop) 68 | 69 | self.head = nn.Sequential(head_layers) 70 | 71 | def lock(self, unlocked_groups=0, freeze_bn_stats=False): 72 | """ lock modules 73 | Args: 74 | unlocked_groups (int): leave last n layer groups unlocked (default: 0) 75 | """ 76 | if not unlocked_groups: 77 | # lock full model 78 | for param in self.trunk.parameters(): 79 | param.requires_grad = False 80 | if freeze_bn_stats: 81 | freeze_batch_norm_2d(self.trunk) 82 | else: 83 | # NOTE: partial freeze requires latest timm (master) branch and is subject to change 84 | try: 85 | # FIXME import here until API stable and in an official release 86 | from timm.models.helpers import group_parameters, group_modules 87 | except ImportError: 88 | raise RuntimeError( 89 | 'Please install latest timm `pip install git+https://github.com/rwightman/pytorch-image-models`') 90 | matcher = self.trunk.group_matcher() 91 | gparams = group_parameters(self.trunk, matcher) 92 | max_layer_id = max(gparams.keys()) 93 | max_layer_id = max_layer_id - unlocked_groups 94 | for group_idx in range(max_layer_id + 1): 95 | group = gparams[group_idx] 96 | for param in group: 97 | self.trunk.get_parameter(param).requires_grad = False 98 | if freeze_bn_stats: 99 | gmodules = group_modules(self.trunk, matcher, reverse=True) 100 | gmodules = {k for k, v in gmodules.items() if v <= max_layer_id} 101 | freeze_batch_norm_2d(self.trunk, gmodules) 102 | 103 | def forward(self, x): 104 | x = self.trunk(x) 105 | x = self.head(x) 106 | return x 107 | -------------------------------------------------------------------------------- /open_clip_training/src/open_clip/tokenizer.py: -------------------------------------------------------------------------------- 1 | """ CLIP tokenizer 2 | 3 | Copied from https://github.com/openai/CLIP. Originally MIT License, Copyright (c) 2021 OpenAI. 4 | """ 5 | import gzip 6 | import html 7 | import os 8 | from functools import lru_cache 9 | from typing import Union, List 10 | 11 | import ftfy 12 | import regex as re 13 | import torch 14 | 15 | 16 | @lru_cache() 17 | def default_bpe(): 18 | return os.path.join(os.path.dirname(os.path.abspath(__file__)), "bpe_simple_vocab_16e6.txt.gz") 19 | 20 | 21 | @lru_cache() 22 | def bytes_to_unicode(): 23 | """ 24 | Returns list of utf-8 byte and a corresponding list of unicode strings. 25 | The reversible bpe codes work on unicode strings. 26 | This means you need a large # of unicode characters in your vocab if you want to avoid UNKs. 27 | When you're at something like a 10B token dataset you end up needing around 5K for decent coverage. 28 | This is a signficant percentage of your normal, say, 32K bpe vocab. 29 | To avoid that, we want lookup tables between utf-8 bytes and unicode strings. 30 | And avoids mapping to whitespace/control characters the bpe code barfs on. 31 | """ 32 | bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1)) 33 | cs = bs[:] 34 | n = 0 35 | for b in range(2**8): 36 | if b not in bs: 37 | bs.append(b) 38 | cs.append(2**8+n) 39 | n += 1 40 | cs = [chr(n) for n in cs] 41 | return dict(zip(bs, cs)) 42 | 43 | 44 | def get_pairs(word): 45 | """Return set of symbol pairs in a word. 
46 | Word is represented as tuple of symbols (symbols being variable-length strings). 47 | """ 48 | pairs = set() 49 | prev_char = word[0] 50 | for char in word[1:]: 51 | pairs.add((prev_char, char)) 52 | prev_char = char 53 | return pairs 54 | 55 | 56 | def basic_clean(text): 57 | text = ftfy.fix_text(text) 58 | text = html.unescape(html.unescape(text)) 59 | return text.strip() 60 | 61 | 62 | def whitespace_clean(text): 63 | text = re.sub(r'\s+', ' ', text) 64 | text = text.strip() 65 | return text 66 | 67 | 68 | class SimpleTokenizer(object): 69 | def __init__(self, bpe_path: str = default_bpe(), special_tokens=None): 70 | self.byte_encoder = bytes_to_unicode() 71 | self.byte_decoder = {v: k for k, v in self.byte_encoder.items()} 72 | merges = gzip.open(bpe_path).read().decode("utf-8").split('\n') 73 | merges = merges[1:49152-256-2+1] 74 | merges = [tuple(merge.split()) for merge in merges] 75 | vocab = list(bytes_to_unicode().values()) 76 | vocab = vocab + [v+'</w>' for v in vocab] 77 | for merge in merges: 78 | vocab.append(''.join(merge)) 79 | if not special_tokens: 80 | special_tokens = ['<start_of_text>', '<end_of_text>'] 81 | else: 82 | special_tokens = ['<start_of_text>', '<end_of_text>'] + special_tokens 83 | vocab.extend(special_tokens) 84 | self.encoder = dict(zip(vocab, range(len(vocab)))) 85 | self.decoder = {v: k for k, v in self.encoder.items()} 86 | self.bpe_ranks = dict(zip(merges, range(len(merges)))) 87 | self.cache = {t:t for t in special_tokens} 88 | special = "|".join(special_tokens) 89 | self.pat = re.compile(special + r"""|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", re.IGNORECASE) 90 | 91 | self.vocab_size = len(self.encoder) 92 | self.all_special_ids = [self.encoder[t] for t in special_tokens] 93 | 94 | def bpe(self, token): 95 | if token in self.cache: 96 | return self.cache[token] 97 | word = tuple(token[:-1]) + ( token[-1] + '</w>',) 98 | pairs = get_pairs(word) 99 | 100 | if not pairs: 101 | return token+'</w>' 102 | 103 | while True: 104 | bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf'))) 105 | if bigram not in self.bpe_ranks: 106 | break 107 | first, second = bigram 108 | new_word = [] 109 | i = 0 110 | while i < len(word): 111 | try: 112 | j = word.index(first, i) 113 | new_word.extend(word[i:j]) 114 | i = j 115 | except: 116 | new_word.extend(word[i:]) 117 | break 118 | 119 | if word[i] == first and i < len(word)-1 and word[i+1] == second: 120 | new_word.append(first+second) 121 | i += 2 122 | else: 123 | new_word.append(word[i]) 124 | i += 1 125 | new_word = tuple(new_word) 126 | word = new_word 127 | if len(word) == 1: 128 | break 129 | else: 130 | pairs = get_pairs(word) 131 | word = ' '.join(word) 132 | self.cache[token] = word 133 | return word 134 | 135 | def encode(self, text): 136 | bpe_tokens = [] 137 | text = whitespace_clean(basic_clean(text)).lower() 138 | for token in re.findall(self.pat, text): 139 | token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8')) 140 | bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' ')) 141 | return bpe_tokens 142 | 143 | def decode(self, tokens): 144 | text = ''.join([self.decoder[token] for token in tokens]) 145 | text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('</w>', ' ') 146 | return text 147 | 148 | 149 | _tokenizer = SimpleTokenizer() 150 | 151 | 152 | def tokenize(texts: Union[str, List[str]], context_length: int = 77) -> torch.LongTensor: 153 | """ 154 | Returns the tokenized representation of given input string(s) 155 | 156 | Parameters
157 | ---------- 158 | texts : Union[str, List[str]] 159 | An input string or a list of input strings to tokenize 160 | context_length : int 161 | The context length to use; all CLIP models use 77 as the context length 162 | 163 | Returns 164 | ------- 165 | A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length] 166 | """ 167 | if isinstance(texts, str): 168 | texts = [texts] 169 | 170 | sot_token = _tokenizer.encoder["<start_of_text>"] 171 | eot_token = _tokenizer.encoder["<end_of_text>"] 172 | all_tokens = [[sot_token] + _tokenizer.encode(text) + [eot_token] for text in texts] 173 | result = torch.zeros(len(all_tokens), context_length, dtype=torch.long) 174 | 175 | for i, tokens in enumerate(all_tokens): 176 | if len(tokens) > context_length: 177 | tokens = tokens[:context_length] # Truncate 178 | tokens[-1] = eot_token 179 | result[i, :len(tokens)] = torch.tensor(tokens) 180 | 181 | return result 182 | -------------------------------------------------------------------------------- /open_clip_training/src/open_clip/transform.py: -------------------------------------------------------------------------------- 1 | import random 2 | from typing import Optional, Sequence, Tuple 3 | import numpy as np 4 | from PIL import Image 5 | from scipy.ndimage.morphology import binary_erosion 6 | 7 | import torch 8 | import torch.nn as nn 9 | import torchvision.transforms.functional as F 10 | 11 | 12 | from torchvision.transforms import Normalize, Compose, RandomResizedCrop, InterpolationMode, ToTensor, Resize, \ 13 | CenterCrop 14 | 15 | 16 | class ResizeMaxSize(nn.Module): 17 | 18 | def __init__(self, max_size, interpolation=InterpolationMode.BICUBIC, fn='max', fill=0): 19 | super().__init__() 20 | if not isinstance(max_size, int): 21 | raise TypeError(f"Size should be int.
Got {type(max_size)}") 22 | self.max_size = max_size 23 | self.interpolation = interpolation 24 | self.fn = min if fn == 'min' else min 25 | self.fill = fill 26 | 27 | def forward(self, img): 28 | if isinstance(img, torch.Tensor): 29 | height, width = img.shape[:2] 30 | else: 31 | width, height = img.size 32 | scale = self.max_size / float(max(height, width)) 33 | if scale != 1.0: 34 | new_size = tuple(round(dim * scale) for dim in (height, width)) 35 | img = F.resize(img, new_size, self.interpolation) 36 | pad_h = self.max_size - new_size[0] 37 | pad_w = self.max_size - new_size[1] 38 | img = F.pad(img, padding=[pad_w//2, pad_h//2, pad_w - pad_w//2, pad_h - pad_h//2], fill=self.fill) 39 | return img 40 | 41 | class Erosion(nn.Module): 42 | 43 | def __init__(self, ksize_list=(3, 7, 11, 15, 17, 21), mean=(124, 116, 103)): 44 | super(Erosion, self).__init__() 45 | self.ksize_list = ksize_list 46 | self.mean = np.array(mean).astype(np.uint8) 47 | 48 | def forward(self, img): 49 | ksize = random.choice(self.ksize_list) 50 | imarray = np.array(img) 51 | mask = (imarray == self.mean).all(axis=2) 52 | mask = ~mask 53 | mask = mask.astype(int) 54 | erosion_mask = binary_erosion(mask, structure=np.ones((ksize,ksize))).astype(mask.dtype) 55 | imarray[erosion_mask == 0] = self.mean 56 | img = Image.fromarray(imarray) 57 | return img 58 | 59 | 60 | 61 | def _convert_to_rgb(image): 62 | return image.convert('RGB') 63 | 64 | def _convert_to_rgb_w_mask(inp): 65 | image, mask = inp 66 | return image.convert('RGB'), mask 67 | 68 | def _to_tensor_w_mask(inp): 69 | image, mask = inp 70 | return F.to_tensor(image), mask 71 | 72 | class Maskget(nn.Module): 73 | 74 | def __init__(self, mean=(124, 116, 103)): 75 | super(Maskget, self).__init__() 76 | self.mean = np.array(mean).astype(np.uint8) 77 | 78 | def forward(self, img): 79 | imarray = np.array(img) 80 | mask = (imarray == self.mean).all(axis=2) 81 | mask = ~mask 82 | mask = mask.astype(int) 83 | img = Image.fromarray(imarray) 84 | return img, mask 85 | 86 | class Normalize_w_mask(nn.Module): 87 | def __init__(self, mean, std, inplace=False): 88 | super().__init__() 89 | self.mean = mean 90 | self.std = std 91 | self.inplace = inplace 92 | 93 | def forward(self, inp): 94 | """ 95 | Args: 96 | tensor (Tensor): Tensor image to be normalized. 97 | 98 | Returns: 99 | Tensor: Normalized Tensor image. 
100 | """ 101 | tensor, mask = inp 102 | return F.normalize(tensor, self.mean, self.std, self.inplace), mask 103 | 104 | def _normalize_w_mask(image, mask): 105 | return F.to_tensor(image), mask 106 | 107 | 108 | def image_transform( 109 | image_size: int, 110 | is_train: bool, 111 | mean: Optional[Tuple[float, ...]] = None, 112 | std: Optional[Tuple[float, ...]] = None, 113 | resize_longest_max: bool = False, 114 | fill_color: int = 0, 115 | scale: Optional[Tuple[float, ...]] = None, 116 | erosion: bool = False, 117 | with_mask: bool = False, 118 | ): 119 | mean = mean or (0.48145466, 0.4578275, 0.40821073) # OpenAI dataset mean 120 | std = std or (0.26862954, 0.26130258, 0.27577711) # OpenAI dataset std 121 | scale = scale or (0.9, 1.0) 122 | if isinstance(image_size, (list, tuple)) and image_size[0] == image_size[1]: 123 | # for square size, pass size as int so that Resize() uses aspect preserving shortest edge 124 | image_size = image_size[0] 125 | 126 | normalize = Normalize(mean=mean, std=std) 127 | if is_train: 128 | default_transform = Compose([ 129 | RandomResizedCrop(image_size, scale=scale, interpolation=InterpolationMode.BICUBIC), 130 | _convert_to_rgb, 131 | ToTensor(), 132 | normalize,]) 133 | if with_mask: 134 | default_transform = Compose([ 135 | RandomResizedCrop(image_size, scale=(0.9, 1.0), interpolation=InterpolationMode.BICUBIC), 136 | Maskget(), 137 | _convert_to_rgb_w_mask, 138 | _to_tensor_w_mask, 139 | Normalize_w_mask(mean=mean, std=std), 140 | ]) 141 | return default_transform 142 | else: 143 | if resize_longest_max: 144 | transforms = [ 145 | ResizeMaxSize(image_size, fill=fill_color) 146 | ] 147 | else: 148 | transforms = [ 149 | Resize(image_size, interpolation=InterpolationMode.BICUBIC), 150 | CenterCrop(image_size), 151 | ] 152 | if not with_mask: 153 | transforms.extend([ 154 | _convert_to_rgb, 155 | ToTensor(), 156 | normalize, 157 | ]) 158 | else: 159 | transforms.extend([ 160 | Maskget(), 161 | _convert_to_rgb_w_mask, 162 | _to_tensor_w_mask, 163 | Normalize_w_mask(mean=mean, std=std), 164 | ]) 165 | return Compose(transforms) 166 | -------------------------------------------------------------------------------- /open_clip_training/src/open_clip/utils.py: -------------------------------------------------------------------------------- 1 | from itertools import repeat 2 | import collections.abc 3 | 4 | from torch import nn as nn 5 | from torchvision.ops.misc import FrozenBatchNorm2d 6 | 7 | 8 | def freeze_batch_norm_2d(module, module_match={}, name=''): 9 | """ 10 | Converts all `BatchNorm2d` and `SyncBatchNorm` layers of provided module into `FrozenBatchNorm2d`. If `module` is 11 | itself an instance of either `BatchNorm2d` or `SyncBatchNorm`, it is converted into `FrozenBatchNorm2d` and 12 | returned. Otherwise, the module is walked recursively and submodules are converted in place. 13 | 14 | Args: 15 | module (torch.nn.Module): Any PyTorch module. 
16 | module_match (dict): Dictionary of full module names to freeze (all if empty) 17 | name (str): Full module name (prefix) 18 | 19 | Returns: 20 | torch.nn.Module: Resulting module 21 | 22 | Inspired by https://github.com/pytorch/pytorch/blob/a5895f85be0f10212791145bfedc0261d364f103/torch/nn/modules/batchnorm.py#L762 23 | """ 24 | res = module 25 | is_match = True 26 | if module_match: 27 | is_match = name in module_match 28 | if is_match and isinstance(module, (nn.modules.batchnorm.BatchNorm2d, nn.modules.batchnorm.SyncBatchNorm)): 29 | res = FrozenBatchNorm2d(module.num_features) 30 | res.num_features = module.num_features 31 | res.affine = module.affine 32 | if module.affine: 33 | res.weight.data = module.weight.data.clone().detach() 34 | res.bias.data = module.bias.data.clone().detach() 35 | res.running_mean.data = module.running_mean.data 36 | res.running_var.data = module.running_var.data 37 | res.eps = module.eps 38 | else: 39 | for child_name, child in module.named_children(): 40 | full_child_name = '.'.join([name, child_name]) if name else child_name 41 | new_child = freeze_batch_norm_2d(child, module_match, full_child_name) 42 | if new_child is not child: 43 | res.add_module(child_name, new_child) 44 | return res 45 | 46 | 47 | # From PyTorch internals 48 | def _ntuple(n): 49 | def parse(x): 50 | if isinstance(x, collections.abc.Iterable): 51 | return x 52 | return tuple(repeat(x, n)) 53 | return parse 54 | 55 | 56 | to_1tuple = _ntuple(1) 57 | to_2tuple = _ntuple(2) 58 | to_3tuple = _ntuple(3) 59 | to_4tuple = _ntuple(4) 60 | to_ntuple = lambda n, x: _ntuple(n)(x) 61 | -------------------------------------------------------------------------------- /open_clip_training/src/open_clip/version.py: -------------------------------------------------------------------------------- 1 | __version__ = '1.3.0' 2 | -------------------------------------------------------------------------------- /open_clip_training/src/scripts/coco_gt_171cls_finetune_VitL.sh: -------------------------------------------------------------------------------- 1 | torchrun --nproc_per_node 4 -m training.main \ 2 | --train-data ../openclip_data/coco_gt_171cls.csv \ 3 | --train-num-samples 965036 \ 4 | --lr 0.000005 \ 5 | --warmup 100 \ 6 | --force-quick-gelu \ 7 | --dataset-type csv \ 8 | --batch-size 32 \ 9 | --precision amp \ 10 | --workers 4 \ 11 | --model ViT-L-14 \ 12 | --lock-text \ 13 | --zeroshot-frequency 1 \ 14 | --save-frequency 5 \ 15 | --epoch 5 \ 16 | --pretrained openai \ 17 | --ade-val ../openclip_data/ade_gt_150cls_val -------------------------------------------------------------------------------- /open_clip_training/src/scripts/coco_proposal_1cap_finetune_VitB.sh: -------------------------------------------------------------------------------- 1 | torchrun --nproc_per_node 4 -m training.main \ 2 | --train-data ../openclip_data/coco_proposal_1cap.csv \ 3 | --train-num-samples 442117 \ 4 | --lr 0.000005 \ 5 | --warmup 100 \ 6 | --force-quick-gelu \ 7 | --dataset-type csv \ 8 | --batch-size 128 \ 9 | --precision amp \ 10 | --workers 4 \ 11 | --model ViT-B-16 \ 12 | --lock-text \ 13 | --zeroshot-frequency 1 \ 14 | --save-frequency 1 \ 15 | --epoch 5 \ 16 | --pretrained openai \ 17 | --ade-val ../openclip_data/ade_gt_150cls_val 18 | -------------------------------------------------------------------------------- /open_clip_training/src/scripts/coco_proposal_1cap_finetune_VitL.sh: -------------------------------------------------------------------------------- 1 | torchrun --nproc_per_node 4 -m 
training.main \ 2 | --train-data ../openclip_data/coco_proposal_1cap.csv \ 3 | --train-num-samples 442117 \ 4 | --lr 0.000005 \ 5 | --warmup 100 \ 6 | --force-quick-gelu \ 7 | --dataset-type csv \ 8 | --batch-size 32 \ 9 | --precision amp \ 10 | --workers 4 \ 11 | --model ViT-L-14 \ 12 | --lock-text \ 13 | --zeroshot-frequency 1 \ 14 | --save-frequency 1 \ 15 | --epoch 5 \ 16 | --pretrained openai \ 17 | --ade-val ../openclip_data/ade_gt_150cls_val -------------------------------------------------------------------------------- /open_clip_training/src/scripts/coco_proposal_1cap_mask_prompt_tuning_VitL.sh: -------------------------------------------------------------------------------- 1 | torchrun --nproc_per_node 4 -m training.main_mask_prompt_tuning \ 2 | --train-data ../openclip_data/coco_proposal_1cap.csv \ 3 | --train-num-samples 442117 \ 4 | --lr 0.05 \ 5 | --mask_wd 0.0 \ 6 | --warmup 100 \ 7 | --force-quick-gelu \ 8 | --dataset-type csv \ 9 | --batch-size 32 \ 10 | --precision amp \ 11 | --workers 4 \ 12 | --with-mask \ 13 | --model ViT-L-14 \ 14 | --mask-emb-depth 3 \ 15 | --lock-text \ 16 | --lock-image \ 17 | --lock-image-unlocked-groups 0 \ 18 | --zeroshot-frequency 1 \ 19 | --save-frequency 1 \ 20 | --epoch 5 \ 21 | --pretrained /home/jeffliang/ov-seg/open_clip_training/src/logs/2023_05_28-23_35_23-model_ViT-L-14-lr_5e-06-b_32-j_4-p_amp/checkpoints/epoch_5.pt \ 22 | --ade-val ../openclip_data/ade_gt_150cls_val -------------------------------------------------------------------------------- /open_clip_training/src/scripts/launch.sh: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/ov-seg/36f49d496714998058d115ffb6172d9d84c59065/open_clip_training/src/scripts/launch.sh -------------------------------------------------------------------------------- /open_clip_training/src/training/.gitignore: -------------------------------------------------------------------------------- 1 | logs/ 2 | -------------------------------------------------------------------------------- /open_clip_training/src/training/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/ov-seg/36f49d496714998058d115ffb6172d9d84c59065/open_clip_training/src/training/__init__.py -------------------------------------------------------------------------------- /open_clip_training/src/training/distributed.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import torch 4 | 5 | try: 6 | import horovod.torch as hvd 7 | except ImportError: 8 | hvd = None 9 | 10 | 11 | def is_global_master(args): 12 | return args.rank == 0 13 | 14 | 15 | def is_local_master(args): 16 | return args.local_rank == 0 17 | 18 | 19 | def is_master(args, local=False): 20 | return is_local_master(args) if local else is_global_master(args) 21 | 22 | 23 | def is_using_horovod(): 24 | # NOTE w/ horovod run, OMPI vars should be set, but w/ SLURM PMI vars will be set 25 | # Differentiating between horovod and DDP use via SLURM may not be possible, so horovod arg still required... 
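    # Concretely: the launch is treated as horovod only when the complete set of OMPI
    # rank/size variables, or the complete set of PMI rank/size variables, is present
    # in the environment; a partial set does not count.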
26 | ompi_vars = ["OMPI_COMM_WORLD_RANK", "OMPI_COMM_WORLD_SIZE"] 27 | pmi_vars = ["PMI_RANK", "PMI_SIZE"] 28 | if all([var in os.environ for var in ompi_vars]) or all([var in os.environ for var in pmi_vars]): 29 | return True 30 | else: 31 | return False 32 | 33 | 34 | def is_using_distributed(): 35 | if 'WORLD_SIZE' in os.environ: 36 | return int(os.environ['WORLD_SIZE']) > 1 37 | if 'SLURM_NTASKS' in os.environ: 38 | return int(os.environ['SLURM_NTASKS']) > 1 39 | return False 40 | 41 | 42 | def world_info_from_env(): 43 | local_rank = 0 44 | for v in ('LOCAL_RANK', 'MPI_LOCALRANKID', 'SLURM_LOCALID', 'OMPI_COMM_WORLD_LOCAL_RANK'): 45 | if v in os.environ: 46 | local_rank = int(os.environ[v]) 47 | break 48 | global_rank = 0 49 | for v in ('RANK', 'PMI_RANK', 'SLURM_PROCID', 'OMPI_COMM_WORLD_RANK'): 50 | if v in os.environ: 51 | global_rank = int(os.environ[v]) 52 | break 53 | world_size = 1 54 | for v in ('WORLD_SIZE', 'PMI_SIZE', 'SLURM_NTASKS', 'OMPI_COMM_WORLD_SIZE'): 55 | if v in os.environ: 56 | world_size = int(os.environ[v]) 57 | break 58 | 59 | return local_rank, global_rank, world_size 60 | 61 | 62 | def init_distributed_device(args): 63 | # Distributed training = training on more than one GPU. 64 | # Works in both single and multi-node scenarios. 65 | args.distributed = False 66 | args.world_size = 1 67 | args.rank = 0 # global rank 68 | args.local_rank = 0 69 | if args.horovod: 70 | assert hvd is not None, "Horovod is not installed" 71 | hvd.init() 72 | args.local_rank = int(hvd.local_rank()) 73 | args.rank = hvd.rank() 74 | args.world_size = hvd.size() 75 | args.distributed = True 76 | os.environ['LOCAL_RANK'] = str(args.local_rank) 77 | os.environ['RANK'] = str(args.rank) 78 | os.environ['WORLD_SIZE'] = str(args.world_size) 79 | elif is_using_distributed(): 80 | if 'SLURM_PROCID' in os.environ: 81 | # DDP via SLURM 82 | args.local_rank, args.rank, args.world_size = world_info_from_env() 83 | # SLURM var -> torch.distributed vars in case needed 84 | os.environ['LOCAL_RANK'] = str(args.local_rank) 85 | os.environ['RANK'] = str(args.rank) 86 | os.environ['WORLD_SIZE'] = str(args.world_size) 87 | torch.distributed.init_process_group( 88 | backend=args.dist_backend, 89 | init_method=args.dist_url, 90 | world_size=args.world_size, 91 | rank=args.rank, 92 | ) 93 | else: 94 | # DDP via torchrun, torch.distributed.launch 95 | args.local_rank, _, _ = world_info_from_env() 96 | torch.distributed.init_process_group( 97 | backend=args.dist_backend, 98 | init_method=args.dist_url) 99 | args.world_size = torch.distributed.get_world_size() 100 | args.rank = torch.distributed.get_rank() 101 | args.distributed = True 102 | 103 | if torch.cuda.is_available(): 104 | if args.distributed and not args.no_set_device_rank: 105 | device = 'cuda:%d' % args.local_rank 106 | else: 107 | device = 'cuda:0' 108 | torch.cuda.set_device(device) 109 | else: 110 | device = 'cpu' 111 | args.device = device 112 | device = torch.device(device) 113 | return device 114 | -------------------------------------------------------------------------------- /open_clip_training/src/training/logger.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | 4 | def setup_logging(log_file, level, include_host=False): 5 | if include_host: 6 | import socket 7 | hostname = socket.gethostname() 8 | formatter = logging.Formatter( 9 | f'%(asctime)s | {hostname} | %(levelname)s | %(message)s', datefmt='%Y-%m-%d,%H:%M:%S') 10 | else: 11 | formatter = logging.Formatter('%(asctime)s 
| %(levelname)s | %(message)s', datefmt='%Y-%m-%d,%H:%M:%S') 12 | 13 | logging.root.setLevel(level) 14 | loggers = [logging.getLogger(name) for name in logging.root.manager.loggerDict] 15 | for logger in loggers: 16 | logger.setLevel(level) 17 | 18 | stream_handler = logging.StreamHandler() 19 | stream_handler.setFormatter(formatter) 20 | logging.root.addHandler(stream_handler) 21 | 22 | if log_file: 23 | file_handler = logging.FileHandler(filename=log_file) 24 | file_handler.setFormatter(formatter) 25 | logging.root.addHandler(file_handler) 26 | 27 | -------------------------------------------------------------------------------- /open_clip_training/src/training/scheduler.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | def assign_learning_rate(optimizer, new_lr): 5 | for param_group in optimizer.param_groups: 6 | param_group["lr"] = new_lr 7 | 8 | 9 | def _warmup_lr(base_lr, warmup_length, step): 10 | return base_lr * (step + 1) / warmup_length 11 | 12 | 13 | def cosine_lr(optimizer, base_lr, warmup_length, steps): 14 | def _lr_adjuster(step): 15 | if step < warmup_length: 16 | lr = _warmup_lr(base_lr, warmup_length, step) 17 | else: 18 | e = step - warmup_length 19 | es = steps - warmup_length 20 | lr = 0.5 * (1 + np.cos(np.pi * e / es)) * base_lr 21 | assign_learning_rate(optimizer, lr) 22 | return lr 23 | return _lr_adjuster -------------------------------------------------------------------------------- /open_clip_training/src/training/zero_shot.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from contextlib import suppress 3 | import inspect 4 | 5 | import torch 6 | import torch.nn.functional as F 7 | from torchmetrics import Accuracy 8 | from tqdm import tqdm 9 | 10 | from open_clip import tokenize 11 | from .imagenet_zeroshot_data import imagenet_classnames, openai_imagenet_template 12 | from .ade150_zeroshot_data import ade150_classnames 13 | 14 | 15 | def zero_shot_classifier(model, classnames, templates, args): 16 | with torch.no_grad(): 17 | zeroshot_weights = [] 18 | for classname in tqdm(classnames): 19 | texts = [template(classname) for template in templates] # format with class 20 | texts = tokenize(texts).to(args.device) # tokenize 21 | if args.distributed and not args.horovod: 22 | class_embeddings = model.module.encode_text(texts) 23 | else: 24 | class_embeddings = model.encode_text(texts) 25 | class_embedding = F.normalize(class_embeddings, dim=-1).mean(dim=0) 26 | class_embedding /= class_embedding.norm() 27 | zeroshot_weights.append(class_embedding) 28 | zeroshot_weights = torch.stack(zeroshot_weights, dim=1).to(args.device) 29 | return zeroshot_weights 30 | 31 | 32 | def accuracy(output, target, topk=(1,)): 33 | pred = output.topk(max(topk), 1, True, True)[1].t() 34 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 35 | return [float(correct[:k].reshape(-1).float().sum(0, keepdim=True).cpu().numpy()) for k in topk], pred[0] 36 | 37 | 38 | def run(model, classifier, dataloader, args): 39 | autocast = torch.cuda.amp.autocast if args.precision == 'amp' else suppress 40 | with torch.no_grad(): 41 | top1, top5, n = 0., 0., 0. 
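        # top1 / top5 accumulate correct-prediction counts and n counts processed samples;
        # per-sample predictions and targets are also collected so that a macro (per-class
        # mean) accuracy over the 150 ADE20K classes can be computed with torchmetrics
        # after the loop.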
42 | preds = [] 43 | targets = [] 44 | macc = Accuracy('multiclass', num_classes=150, average='macro').cuda() 45 | for batch, target in tqdm(dataloader, unit_scale=args.batch_size): 46 | if args.with_mask: 47 | images, masks = batch 48 | masks = masks.to(args.device) 49 | else: 50 | images = batch 51 | images = images.to(args.device) 52 | target = target.to(args.device) 53 | 54 | with autocast(): 55 | # predict 56 | if args.distributed and not args.horovod: 57 | if args.with_mask: 58 | image_features = model.module.encode_image(images, masks) 59 | else: 60 | image_features = model.module.encode_image(images) 61 | else: 62 | if args.with_mask: 63 | image_features = model.encode_image(images, masks) 64 | else: 65 | image_features = model.encode_image(images) 66 | image_features = F.normalize(image_features, dim=-1) 67 | logits = 100. * image_features @ classifier 68 | 69 | # measure accuracy 70 | (acc1, acc5), pred = accuracy(logits, target, topk=(1, 5)) 71 | preds.append(pred) 72 | targets.append(target) 73 | top1 += acc1 74 | top5 += acc5 75 | n += images.size(0) 76 | preds = torch.cat(preds) 77 | targets = torch.cat(targets) 78 | top1 = (top1 / n) 79 | top5 = (top5 / n) 80 | return top1, top5, macc(preds, targets).item() 81 | 82 | 83 | def zero_shot_eval(model, data, epoch, args): 84 | if 'imagenet-val' not in data and 'imagenet-v2' not in data and 'ade-val' not in data: 85 | return {} 86 | if args.zeroshot_frequency == 0: 87 | return {} 88 | if (epoch % args.zeroshot_frequency) != 0 and epoch != args.epochs: 89 | return {} 90 | 91 | logging.info('Starting zero-shot imagenet.') 92 | # for i in range(len(openai_imagenet_template)): 93 | # template = openai_imagenet_template[i] 94 | # logging.info(inspect.getsource(template)) 95 | logging.info('Building zero-shot classifier') 96 | if 'ade-val' in data: 97 | classifier = zero_shot_classifier(model, ade150_classnames, openai_imagenet_template, args) 98 | else: 99 | classifier = zero_shot_classifier(model, imagenet_classnames, openai_imagenet_template, args) 100 | 101 | logging.info('Using classifier') 102 | results = {} 103 | if 'imagenet-val' in data: 104 | top1, top5, macc = run(model, classifier, data['imagenet-val'].dataloader, args) 105 | results['imagenet-zeroshot-val-top1'] = top1 106 | results['imagenet-zeroshot-val-top5'] = top5 107 | results['mean-accuracy-top1'] = macc 108 | if 'imagenet-v2' in data: 109 | top1, top5 = run(model, classifier, data['imagenet-v2'].dataloader, args) 110 | results['imagenetv2-zeroshot-val-top1'] = top1 111 | results['imagenetv2-zeroshot-val-top5'] = top5 112 | if 'ade-val' in data: 113 | top1, top5, macc = run(model, classifier, data['ade-val'].dataloader, args) 114 | results['ade150-zeroshot-val-top1'] = top1 115 | results['ade150-zeroshot-val-top5'] = top5 116 | 117 | logging.info('Finished zero-shot imagenet.') 118 | 119 | return results 120 | -------------------------------------------------------------------------------- /open_clip_training/tests/test_simple.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | from PIL import Image 4 | from open_clip import tokenizer 5 | import open_clip 6 | import os 7 | os.environ["CUDA_VISIBLE_DEVICES"] = "" 8 | 9 | def test_inference(): 10 | model, _, preprocess = open_clip.create_model_and_transforms('ViT-B-32-quickgelu', pretrained='laion400m_e32') 11 | 12 | current_dir = os.path.dirname(os.path.realpath(__file__)) 13 | 14 | image = preprocess(Image.open(current_dir + "/../docs/CLIP.png")).unsqueeze(0) 
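    # preprocess is the eval-time transform returned by create_model_and_transforms;
    # unsqueeze(0) adds the batch dimension that encode_image expects.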
15 | text = tokenizer.tokenize(["a diagram", "a dog", "a cat"]) 16 | 17 | with torch.no_grad(): 18 | image_features = model.encode_image(image) 19 | text_features = model.encode_text(text) 20 | 21 | text_probs = (100.0 * image_features @ text_features.T).softmax(dim=-1) 22 | 23 | assert text_probs.cpu().numpy()[0].tolist() == [1.0, 0.0, 0.0] -------------------------------------------------------------------------------- /open_vocab_seg/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # Copyright (c) Meta Platforms, Inc. All Rights Reserved 3 | 4 | from . import data 5 | from . import modeling 6 | from .config import add_ovseg_config 7 | 8 | from .test_time_augmentation import SemanticSegmentorWithTTA 9 | from .ovseg_model import OVSeg, OVSegDEMO 10 | -------------------------------------------------------------------------------- /open_vocab_seg/config.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # Copyright (c) Meta Platforms, Inc. All Rights Reserved 3 | 4 | from detectron2.config import CfgNode as CN 5 | 6 | 7 | def add_mask_former_default_config(cfg): 8 | # data config 9 | # select the dataset mapper 10 | cfg.INPUT.DATASET_MAPPER_NAME = "mask_former_semantic" 11 | # Color augmentation 12 | cfg.INPUT.COLOR_AUG_SSD = False 13 | # We retry random cropping until no single category in semantic segmentation GT occupies more 14 | # than `SINGLE_CATEGORY_MAX_AREA` part of the crop. 15 | cfg.INPUT.CROP.SINGLE_CATEGORY_MAX_AREA = 1.0 16 | # Pad image and segmentation GT in dataset mapper. 17 | cfg.INPUT.SIZE_DIVISIBILITY = -1 18 | 19 | # solver config 20 | # test batch size 21 | cfg.SOLVER.TEST_IMS_PER_BATCH = 1 22 | # weight decay on embedding 23 | cfg.SOLVER.WEIGHT_DECAY_EMBED = 0.0 24 | # optimizer 25 | cfg.SOLVER.OPTIMIZER = "ADAMW" 26 | cfg.SOLVER.BACKBONE_MULTIPLIER = 0.1 27 | 28 | # mask_former model config 29 | cfg.MODEL.MASK_FORMER = CN() 30 | 31 | # loss 32 | cfg.MODEL.MASK_FORMER.DEEP_SUPERVISION = True 33 | cfg.MODEL.MASK_FORMER.NO_OBJECT_WEIGHT = 0.1 34 | cfg.MODEL.MASK_FORMER.DICE_WEIGHT = 1.0 35 | cfg.MODEL.MASK_FORMER.MASK_WEIGHT = 20.0 36 | 37 | # transformer config 38 | cfg.MODEL.MASK_FORMER.NHEADS = 8 39 | cfg.MODEL.MASK_FORMER.DROPOUT = 0.1 40 | cfg.MODEL.MASK_FORMER.DIM_FEEDFORWARD = 2048 41 | cfg.MODEL.MASK_FORMER.ENC_LAYERS = 0 42 | cfg.MODEL.MASK_FORMER.DEC_LAYERS = 6 43 | cfg.MODEL.MASK_FORMER.PRE_NORM = False 44 | 45 | cfg.MODEL.MASK_FORMER.HIDDEN_DIM = 256 46 | cfg.MODEL.MASK_FORMER.NUM_OBJECT_QUERIES = 100 47 | 48 | cfg.MODEL.MASK_FORMER.TRANSFORMER_IN_FEATURE = "res5" 49 | cfg.MODEL.MASK_FORMER.ENFORCE_INPUT_PROJ = False 50 | 51 | # mask_former inference config 52 | cfg.MODEL.MASK_FORMER.TEST = CN() 53 | cfg.MODEL.MASK_FORMER.TEST.PANOPTIC_ON = False 54 | cfg.MODEL.MASK_FORMER.TEST.OBJECT_MASK_THRESHOLD = 0.0 55 | cfg.MODEL.MASK_FORMER.TEST.OVERLAP_THRESHOLD = 0.0 56 | cfg.MODEL.MASK_FORMER.TEST.SEM_SEG_POSTPROCESSING_BEFORE_INFERENCE = False 57 | 58 | # Sometimes `backbone.size_divisibility` is set to 0 for some backbone (e.g. 
ResNet) 59 | # you can use this config to override 60 | cfg.MODEL.MASK_FORMER.SIZE_DIVISIBILITY = 32 61 | 62 | # pixel decoder config 63 | cfg.MODEL.SEM_SEG_HEAD.MASK_DIM = 256 64 | # adding transformer in pixel decoder 65 | cfg.MODEL.SEM_SEG_HEAD.TRANSFORMER_ENC_LAYERS = 0 66 | # pixel decoder 67 | cfg.MODEL.SEM_SEG_HEAD.PIXEL_DECODER_NAME = "BasePixelDecoder" 68 | 69 | # swin transformer backbone 70 | cfg.MODEL.SWIN = CN() 71 | cfg.MODEL.SWIN.PRETRAIN_IMG_SIZE = 224 72 | cfg.MODEL.SWIN.PATCH_SIZE = 4 73 | cfg.MODEL.SWIN.EMBED_DIM = 96 74 | cfg.MODEL.SWIN.DEPTHS = [2, 2, 6, 2] 75 | cfg.MODEL.SWIN.NUM_HEADS = [3, 6, 12, 24] 76 | cfg.MODEL.SWIN.WINDOW_SIZE = 7 77 | cfg.MODEL.SWIN.MLP_RATIO = 4.0 78 | cfg.MODEL.SWIN.QKV_BIAS = True 79 | cfg.MODEL.SWIN.QK_SCALE = None 80 | cfg.MODEL.SWIN.NORM_INDICES = None 81 | cfg.MODEL.SWIN.PROJECTION = False 82 | cfg.MODEL.SWIN.PROJECT_DIM = 256 83 | cfg.MODEL.SWIN.DROP_RATE = 0.0 84 | cfg.MODEL.SWIN.ATTN_DROP_RATE = 0.0 85 | cfg.MODEL.SWIN.DROP_PATH_RATE = 0.3 86 | cfg.MODEL.SWIN.APE = False 87 | cfg.MODEL.SWIN.PATCH_NORM = True 88 | cfg.MODEL.SWIN.OUT_FEATURES = ["res2", "res3", "res4", "res5"] 89 | 90 | 91 | def add_our_config(cfg): 92 | cfg.TEST.SLIDING_WINDOW = False 93 | cfg.TEST.SLIDING_TILE_SIZE = 224 94 | cfg.TEST.SLIDING_OVERLAP = 2 / 3.0 95 | # whether to use dense crf 96 | cfg.TEST.DENSE_CRF = False 97 | cfg.DATASETS.SAMPLE_PER_CLASS = -1 98 | cfg.DATASETS.SAMPLE_SEED = 0 99 | # embedding head 100 | cfg.MODEL.SEM_SEG_HEAD.EMBEDDING_DIM = 512 101 | cfg.MODEL.SEM_SEG_HEAD.EMBED_HIDDEN_DIM = 1024 102 | cfg.MODEL.SEM_SEG_HEAD.EMBED_LAYERS = 2 103 | # clip_adapter 104 | cfg.MODEL.CLIP_ADAPTER = CN() 105 | cfg.MODEL.CLIP_ADAPTER.TEXT_TEMPLATES = "vild" 106 | # for predefined 107 | cfg.MODEL.CLIP_ADAPTER.PREDEFINED_PROMPT_TEMPLATES = ["a photo of a {}."] 108 | # for learnable prompt 109 | cfg.MODEL.CLIP_ADAPTER.PROMPT_CHECKPOINT = "" 110 | cfg.MODEL.CLIP_ADAPTER.CLIP_MODEL_NAME = "ViT-B/16" 111 | cfg.MODEL.CLIP_ADAPTER.MASK_FILL = "mean" 112 | cfg.MODEL.CLIP_ADAPTER.MASK_EXPAND_RATIO = 1.0 113 | cfg.MODEL.CLIP_ADAPTER.MASK_THR = 0.4 114 | cfg.MODEL.CLIP_ADAPTER.MASK_MATTING = False 115 | cfg.MODEL.CLIP_ADAPTER.REGION_RESIZED = True 116 | cfg.MODEL.CLIP_ADAPTER.CLIP_ENSEMBLE = True 117 | cfg.MODEL.CLIP_ADAPTER.CLIP_ENSEMBLE_WEIGHT = 0.7 118 | # for mask prompt 119 | cfg.MODEL.CLIP_ADAPTER.MASK_PROMPT_DEPTH = 3 120 | cfg.MODEL.CLIP_ADAPTER.MASK_PROMPT_FWD = False 121 | 122 | # wandb 123 | cfg.WANDB = CN() 124 | cfg.WANDB.PROJECT = "open_vocab_seg" 125 | cfg.WANDB.NAME = None 126 | 127 | 128 | def add_ovseg_config(cfg): 129 | """ 130 | Add config for open_vocab_seg. 131 | """ 132 | add_mask_former_default_config(cfg) 133 | add_our_config(cfg) 134 | -------------------------------------------------------------------------------- /open_vocab_seg/data/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # Copyright (c) Meta Platforms, Inc. All Rights Reserved 3 | 4 | from .dataset_mappers import * 5 | from . import datasets 6 | from .build import ( 7 | build_detection_train_loader, 8 | build_detection_test_loader, 9 | ) 10 | -------------------------------------------------------------------------------- /open_vocab_seg/data/dataset_mappers/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # Copyright (c) Meta Platforms, Inc. 
All Rights Reserved 3 | 4 | from .mask_former_semantic_dataset_mapper import MaskFormerSemanticDatasetMapper 5 | -------------------------------------------------------------------------------- /open_vocab_seg/data/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | from . import register_coco_stuff, register_voc_seg 3 | from . import register_cc3m 4 | from . import register_ade20k_full 5 | from . import register_pascal_context -------------------------------------------------------------------------------- /open_vocab_seg/data/datasets/register_voc_seg.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | import os 3 | 4 | from detectron2.data import DatasetCatalog, MetadataCatalog 5 | from detectron2.data.datasets import load_sem_seg 6 | 7 | PASCALVOC20_NAMES = ( 8 | "aeroplane", 9 | "bicycle", 10 | "bird", 11 | "boat", 12 | "bottle", 13 | "bus", 14 | "car", 15 | "cat", 16 | "chair", 17 | "cow", 18 | "diningtable", 19 | "dog", 20 | "horse", 21 | "motorbike", 22 | "person", 23 | "pottedplant", 24 | "sheep", 25 | "sofa", 26 | "train", 27 | "tvmonitor", 28 | ) 29 | 30 | def _get_voc_meta(cat_list): 31 | ret = { 32 | "stuff_classes": cat_list, 33 | } 34 | return ret 35 | 36 | 37 | def register_pascalvoc(root): 38 | root = os.path.join(root, "VOCdevkit/VOC2012") 39 | meta = _get_voc_meta(PASCALVOC20_NAMES) 40 | 41 | for name, image_dirname, sem_seg_dirname in [ 42 | ("val", "JPEGImages", "annotations_detectron2/val"), 43 | ]: 44 | image_dir = os.path.join(root, image_dirname) 45 | gt_dir = os.path.join(root, sem_seg_dirname) 46 | all_name = f"pascalvoc20_sem_seg_{name}" 47 | DatasetCatalog.register( 48 | all_name, 49 | lambda x=image_dir, y=gt_dir: load_sem_seg( 50 | y, x, gt_ext="png", image_ext="jpg" 51 | ), 52 | ) 53 | MetadataCatalog.get(all_name).set( 54 | image_root=image_dir, 55 | sem_seg_root=gt_dir, 56 | evaluator_type="sem_seg", 57 | ignore_label=255, 58 | **meta, 59 | ) 60 | 61 | _root = os.getenv("DETECTRON2_DATASETS", "datasets") 62 | register_pascalvoc(_root) 63 | -------------------------------------------------------------------------------- /open_vocab_seg/evaluation/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # Copyright (c) Meta Platforms, Inc. All Rights Reserved 3 | 4 | from .generalized_sem_seg_evaluation import GeneralizedSemSegEvaluator 5 | -------------------------------------------------------------------------------- /open_vocab_seg/evaluation/generalized_sem_seg_evaluation.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # Copyright (c) Meta Platforms, Inc. All Rights Reserved 3 | 4 | import itertools 5 | import json 6 | import numpy as np 7 | import os 8 | from collections import OrderedDict 9 | import PIL.Image as Image 10 | import torch 11 | 12 | from detectron2.data import DatasetCatalog, MetadataCatalog 13 | from detectron2.utils.comm import all_gather, is_main_process, synchronize 14 | from detectron2.utils.file_io import PathManager 15 | 16 | from detectron2.evaluation import SemSegEvaluator 17 | 18 | 19 | class GeneralizedSemSegEvaluator(SemSegEvaluator): 20 | """ 21 | Evaluate semantic segmentation metrics. 
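
    Compared with the stock detectron2 SemSegEvaluator, this variant accepts an optional
    post_process_func that is applied to each predicted "sem_seg" before scoring, and, when the
    dataset metadata defines an evaluation_set (e.g. seen vs. unseen categories), it additionally
    reports per-subset mIoU/pACC together with their harmonic mean (hIoU).

    A minimal usage sketch, following the standard detectron2 DatasetEvaluator protocol
    ("model" and "data_loader" are assumed to be built elsewhere, e.g. with detectron2's
    build_model and this repo's build_detection_test_loader):

        evaluator = GeneralizedSemSegEvaluator("pascalvoc20_sem_seg_val", output_dir="./output/eval")
        evaluator.reset()
        for inputs in data_loader:
            outputs = model(inputs)
            evaluator.process(inputs, outputs)
        results = evaluator.evaluate()  # OrderedDict: {"sem_seg": {"mIoU": ..., "pACC": ..., ...}}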
22 | """ 23 | 24 | def __init__( 25 | self, 26 | dataset_name, 27 | distributed=True, 28 | output_dir=None, 29 | *, 30 | num_classes=None, 31 | ignore_label=None, 32 | post_process_func=None, 33 | ): 34 | super().__init__( 35 | dataset_name, 36 | distributed=distributed, 37 | output_dir=output_dir, 38 | num_classes=num_classes, 39 | ignore_label=ignore_label, 40 | ) 41 | meta = MetadataCatalog.get(dataset_name) 42 | try: 43 | self._evaluation_set = meta.evaluation_set 44 | except AttributeError: 45 | self._evaluation_set = None 46 | self.post_process_func = ( 47 | post_process_func 48 | if post_process_func is not None 49 | else lambda x, **kwargs: x 50 | ) 51 | 52 | def process(self, inputs, outputs): 53 | """ 54 | Args: 55 | inputs: the inputs to a model. 56 | It is a list of dicts. Each dict corresponds to an image and 57 | contains keys like "height", "width", "file_name". 58 | outputs: the outputs of a model. It is either list of semantic segmentation predictions 59 | (Tensor [H, W]) or list of dicts with key "sem_seg" that contains semantic 60 | segmentation prediction in the same format. 61 | """ 62 | for input, output in zip(inputs, outputs): 63 | output = self.post_process_func( 64 | output["sem_seg"], image=np.array(Image.open(input["file_name"])) 65 | ) 66 | output = output.argmax(dim=0).to(self._cpu_device) 67 | pred = np.array(output, dtype=np.int) 68 | with PathManager.open( 69 | self.input_file_to_gt_file[input["file_name"]], "rb" 70 | ) as f: 71 | gt = np.array(Image.open(f), dtype=np.int) 72 | 73 | gt[gt == self._ignore_label] = self._num_classes 74 | 75 | self._conf_matrix += np.bincount( 76 | (self._num_classes + 1) * pred.reshape(-1) + gt.reshape(-1), 77 | minlength=self._conf_matrix.size, 78 | ).reshape(self._conf_matrix.shape) 79 | 80 | self._predictions.extend(self.encode_json_sem_seg(pred, input["file_name"])) 81 | 82 | def evaluate(self): 83 | """ 84 | Evaluates standard semantic segmentation metrics (http://cocodataset.org/#stuff-eval): 85 | 86 | * Mean intersection-over-union averaged across classes (mIoU) 87 | * Frequency Weighted IoU (fwIoU) 88 | * Mean pixel accuracy averaged across classes (mACC) 89 | * Pixel Accuracy (pACC) 90 | """ 91 | if self._distributed: 92 | synchronize() 93 | conf_matrix_list = all_gather(self._conf_matrix) 94 | self._predictions = all_gather(self._predictions) 95 | self._predictions = list(itertools.chain(*self._predictions)) 96 | if not is_main_process(): 97 | return 98 | 99 | self._conf_matrix = np.zeros_like(self._conf_matrix) 100 | for conf_matrix in conf_matrix_list: 101 | self._conf_matrix += conf_matrix 102 | 103 | if self._output_dir: 104 | PathManager.mkdirs(self._output_dir) 105 | file_path = os.path.join(self._output_dir, "sem_seg_predictions.json") 106 | with PathManager.open(file_path, "w") as f: 107 | f.write(json.dumps(self._predictions)) 108 | 109 | acc = np.full(self._num_classes, np.nan, dtype=np.float) 110 | iou = np.full(self._num_classes, np.nan, dtype=np.float) 111 | tp = self._conf_matrix.diagonal()[:-1].astype(np.float) 112 | pos_gt = np.sum(self._conf_matrix[:-1, :-1], axis=0).astype(np.float) 113 | class_weights = pos_gt / np.sum(pos_gt) 114 | pos_pred = np.sum(self._conf_matrix[:-1, :-1], axis=1).astype(np.float) 115 | acc_valid = pos_gt > 0 116 | acc[acc_valid] = tp[acc_valid] / pos_gt[acc_valid] 117 | iou_valid = (pos_gt + pos_pred) > 0 118 | union = pos_gt + pos_pred - tp 119 | iou[acc_valid] = tp[acc_valid] / union[acc_valid] 120 | macc = np.sum(acc[acc_valid]) / np.sum(acc_valid) 121 | miou = 
np.sum(iou[acc_valid]) / np.sum(iou_valid) 122 | fiou = np.sum(iou[acc_valid] * class_weights[acc_valid]) 123 | pacc = np.sum(tp) / np.sum(pos_gt) 124 | 125 | res = {} 126 | res["mIoU"] = 100 * miou 127 | res["fwIoU"] = 100 * fiou 128 | for i, name in enumerate(self._class_names): 129 | res["IoU-{}".format(name)] = 100 * iou[i] 130 | res["mACC"] = 100 * macc 131 | res["pACC"] = 100 * pacc 132 | for i, name in enumerate(self._class_names): 133 | res["ACC-{}".format(name)] = 100 * acc[i] 134 | if self._evaluation_set is not None: 135 | for set_name, set_inds in self._evaluation_set.items(): 136 | iou_list = [] 137 | set_inds = np.array(set_inds, np.int) 138 | mask = np.zeros((len(iou),)).astype(np.bool) 139 | mask[set_inds] = 1 140 | miou = np.sum(iou[mask][acc_valid[mask]]) / np.sum(iou_valid[mask]) 141 | pacc = np.sum(tp[mask]) / np.sum(pos_gt[mask]) 142 | res["mIoU-{}".format(set_name)] = 100 * miou 143 | res["pAcc-{}".format(set_name)] = 100 * pacc 144 | iou_list.append(miou) 145 | miou = np.sum(iou[~mask][acc_valid[~mask]]) / np.sum(iou_valid[~mask]) 146 | pacc = np.sum(tp[~mask]) / np.sum(pos_gt[~mask]) 147 | res["mIoU-un{}".format(set_name)] = 100 * miou 148 | res["pAcc-un{}".format(set_name)] = 100 * pacc 149 | iou_list.append(miou) 150 | res["hIoU-{}".format(set_name)] = ( 151 | 100 * len(iou_list) / sum([1 / iou for iou in iou_list]) 152 | ) 153 | if self._output_dir: 154 | file_path = os.path.join(self._output_dir, "sem_seg_evaluation.pth") 155 | with PathManager.open(file_path, "wb") as f: 156 | torch.save(res, f) 157 | results = OrderedDict({"sem_seg": res}) 158 | self._logger.info(results) 159 | return results 160 | -------------------------------------------------------------------------------- /open_vocab_seg/modeling/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # Copyright (c) Meta Platforms, Inc. All Rights Reserved 3 | 4 | from .backbone.swin import D2SwinTransformer 5 | from .backbone.clip_resnet import D2ModifiedResNet 6 | from .heads.mask_former_head import MaskFormerHead 7 | from .heads.open_vocab_mask_former_head import OpenVocabMaskFormerHead 8 | from .heads.pixel_decoder import BasePixelDecoder 9 | -------------------------------------------------------------------------------- /open_vocab_seg/modeling/backbone/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # Copyright (c) Meta Platforms, Inc. All Rights Reserved 3 | -------------------------------------------------------------------------------- /open_vocab_seg/modeling/clip_adapter/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # Copyright (c) Meta Platforms, Inc. 
All Rights Reserved 3 | 4 | from .text_template import ( 5 | PredefinedPromptExtractor, 6 | ImageNetPromptExtractor, 7 | VILDPromptExtractor, 8 | ) 9 | from .adapter import ClipAdapter, MaskFormerClipAdapter 10 | 11 | 12 | def build_text_prompt(cfg): 13 | if cfg.TEXT_TEMPLATES == "predefined": 14 | text_templates = PredefinedPromptExtractor(cfg.PREDEFINED_PROMPT_TEMPLATES) 15 | elif cfg.TEXT_TEMPLATES == "imagenet": 16 | text_templates = ImageNetPromptExtractor() 17 | elif cfg.TEXT_TEMPLATES == "vild": 18 | text_templates = VILDPromptExtractor() 19 | else: 20 | raise NotImplementedError( 21 | "Prompt learner {} is not supported".format(cfg.TEXT_TEMPLATES) 22 | ) 23 | return text_templates 24 | -------------------------------------------------------------------------------- /open_vocab_seg/modeling/clip_adapter/text_template.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # Copyright (c) Meta Platforms, Inc. All Rights Reserved 3 | # Modified by Feng Liang from 4 | # https://github.com/MendelXu/zsseg.baseline/blob/master/mask_former/modeling/clip_adapter/text_prompt.py 5 | # https://github.com/MendelXu/zsseg.baseline/blob/master/mask_former/modeling/clip_adapter/utils.py 6 | 7 | from typing import List 8 | 9 | import clip 10 | import torch 11 | from torch import nn 12 | 13 | IMAGENET_PROMPT = [ 14 | "a bad photo of a {}.", 15 | "a photo of many {}.", 16 | "a sculpture of a {}.", 17 | "a photo of the hard to see {}.", 18 | "a low resolution photo of the {}.", 19 | "a rendering of a {}.", 20 | "graffiti of a {}.", 21 | "a bad photo of the {}.", 22 | "a cropped photo of the {}.", 23 | "a tattoo of a {}.", 24 | "the embroidered {}.", 25 | "a photo of a hard to see {}.", 26 | "a bright photo of a {}.", 27 | "a photo of a clean {}.", 28 | "a photo of a dirty {}.", 29 | "a dark photo of the {}.", 30 | "a drawing of a {}.", 31 | "a photo of my {}.", 32 | "the plastic {}.", 33 | "a photo of the cool {}.", 34 | "a close-up photo of a {}.", 35 | "a black and white photo of the {}.", 36 | "a painting of the {}.", 37 | "a painting of a {}.", 38 | "a pixelated photo of the {}.", 39 | "a sculpture of the {}.", 40 | "a bright photo of the {}.", 41 | "a cropped photo of a {}.", 42 | "a plastic {}.", 43 | "a photo of the dirty {}.", 44 | "a jpeg corrupted photo of a {}.", 45 | "a blurry photo of the {}.", 46 | "a photo of the {}.", 47 | "a good photo of the {}.", 48 | "a rendering of the {}.", 49 | "a {} in a video game.", 50 | "a photo of one {}.", 51 | "a doodle of a {}.", 52 | "a close-up photo of the {}.", 53 | "a photo of a {}.", 54 | "the origami {}.", 55 | "the {} in a video game.", 56 | "a sketch of a {}.", 57 | "a doodle of the {}.", 58 | "a origami {}.", 59 | "a low resolution photo of a {}.", 60 | "the toy {}.", 61 | "a rendition of the {}.", 62 | "a photo of the clean {}.", 63 | "a photo of a large {}.", 64 | "a rendition of a {}.", 65 | "a photo of a nice {}.", 66 | "a photo of a weird {}.", 67 | "a blurry photo of a {}.", 68 | "a cartoon {}.", 69 | "art of a {}.", 70 | "a sketch of the {}.", 71 | "a embroidered {}.", 72 | "a pixelated photo of a {}.", 73 | "itap of the {}.", 74 | "a jpeg corrupted photo of the {}.", 75 | "a good photo of a {}.", 76 | "a plushie {}.", 77 | "a photo of the nice {}.", 78 | "a photo of the small {}.", 79 | "a photo of the weird {}.", 80 | "the cartoon {}.", 81 | "art of the {}.", 82 | "a drawing of the {}.", 83 | "a photo of the large {}.", 84 | "a black and white photo 
of a {}.", 85 | "the plushie {}.", 86 | "a dark photo of a {}.", 87 | "itap of a {}.", 88 | "graffiti of the {}.", 89 | "a toy {}.", 90 | "itap of my {}.", 91 | "a photo of a cool {}.", 92 | "a photo of a small {}.", 93 | "a tattoo of the {}.", 94 | ] 95 | 96 | VILD_PROMPT = [ 97 | "a photo of a {}.", 98 | "This is a photo of a {}", 99 | "There is a {} in the scene", 100 | "There is the {} in the scene", 101 | "a photo of a {} in the scene", 102 | "a photo of a small {}.", 103 | "a photo of a medium {}.", 104 | "a photo of a large {}.", 105 | "This is a photo of a small {}.", 106 | "This is a photo of a medium {}.", 107 | "This is a photo of a large {}.", 108 | "There is a small {} in the scene.", 109 | "There is a medium {} in the scene.", 110 | "There is a large {} in the scene.", 111 | ] 112 | 113 | class PromptExtractor(nn.Module): 114 | def __init__(self): 115 | super().__init__() 116 | self._buffer_init = False 117 | 118 | def init_buffer(self, clip_model): 119 | self._buffer_init = True 120 | 121 | def forward(self, noun_list: List[str], clip_model: nn.Module): 122 | raise NotImplementedError() 123 | 124 | 125 | class PredefinedPromptExtractor(PromptExtractor): 126 | def __init__(self, templates: List[str]): 127 | super().__init__() 128 | self.templates = templates 129 | 130 | def forward(self, noun_list: List[str], clip_model: nn.Module): 131 | text_features_bucket = [] 132 | for template in self.templates: 133 | noun_tokens = [clip.tokenize(template.format(noun)) for noun in noun_list] 134 | text_inputs = torch.cat(noun_tokens).to( 135 | clip_model.text_projection.data.device 136 | ) 137 | text_features = clip_model.encode_text(text_inputs) 138 | text_features /= text_features.norm(dim=-1, keepdim=True) 139 | text_features_bucket.append(text_features) 140 | del text_inputs 141 | # ensemble by averaging 142 | text_features = torch.stack(text_features_bucket).mean(dim=0) 143 | text_features = text_features / text_features.norm(dim=-1, keepdim=True) 144 | 145 | return text_features 146 | 147 | 148 | class ImageNetPromptExtractor(PredefinedPromptExtractor): 149 | def __init__(self): 150 | super().__init__(IMAGENET_PROMPT) 151 | 152 | 153 | class VILDPromptExtractor(PredefinedPromptExtractor): 154 | def __init__(self): 155 | super().__init__(VILD_PROMPT) 156 | -------------------------------------------------------------------------------- /open_vocab_seg/modeling/clip_adapter/utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # Copyright (c) Meta Platforms, Inc. 
All Rights Reserved 3 | 4 | from typing import Tuple 5 | import numpy as np 6 | import torch 7 | import clip 8 | from detectron2.utils.comm import get_local_rank, synchronize 9 | 10 | 11 | def expand_box( 12 | x1: float, 13 | y1: float, 14 | x2: float, 15 | y2: float, 16 | expand_ratio: float = 1.0, 17 | max_h: int = None, 18 | max_w: int = None, 19 | ): 20 | cx = 0.5 * (x1 + x2) 21 | cy = 0.5 * (y1 + y2) 22 | w = x2 - x1 23 | h = y2 - y1 24 | w = w * expand_ratio 25 | h = h * expand_ratio 26 | box = [cx - 0.5 * w, cy - 0.5 * h, cx + 0.5 * w, cy + 0.5 * h] 27 | if max_h is not None: 28 | box[1] = max(0, box[1]) 29 | box[3] = min(max_h - 1, box[3]) 30 | if max_w is not None: 31 | box[0] = max(0, box[0]) 32 | box[2] = min(max_w - 1, box[2]) 33 | return [int(b) for b in box] 34 | 35 | 36 | def mask2box(mask: torch.Tensor): 37 | # use naive way 38 | row = torch.nonzero(mask.sum(dim=0))[:, 0] 39 | if len(row) == 0: 40 | return None 41 | x1 = row.min() 42 | x2 = row.max() 43 | col = np.nonzero(mask.sum(dim=1))[:, 0] 44 | y1 = col.min() 45 | y2 = col.max() 46 | return x1, y1, x2 + 1, y2 + 1 47 | 48 | 49 | def crop_with_mask( 50 | image: torch.Tensor, 51 | mask: torch.Tensor, 52 | bbox: torch.Tensor, 53 | fill: Tuple[float, float, float] = (0, 0, 0), 54 | expand_ratio: float = 1.0, 55 | ): 56 | l, t, r, b = expand_box(*bbox, expand_ratio) 57 | _, h, w = image.shape 58 | l = max(l, 0) 59 | t = max(t, 0) 60 | r = min(r, w) 61 | b = min(b, h) 62 | new_image = torch.cat( 63 | [image.new_full((1, b - t, r - l), fill_value=val) for val in fill] 64 | ) 65 | # return image[:, t:b, l:r], mask[None, t:b, l:r] 66 | return image[:, t:b, l:r] * mask[None, t:b, l:r] + (1 - mask[None, t:b, l:r]) * new_image, mask[None, t:b, l:r] 67 | 68 | 69 | def build_clip_model(model: str, mask_prompt_depth: int = 0, frozen: bool = True): 70 | rank = get_local_rank() 71 | if rank == 0: 72 | # download on rank 0 only 73 | model, _ = clip.load(model, mask_prompt_depth=mask_prompt_depth, device="cpu") 74 | synchronize() 75 | if rank != 0: 76 | model, _ = clip.load(model, mask_prompt_depth=mask_prompt_depth, device="cpu") 77 | synchronize() 78 | if frozen: 79 | for param in model.parameters(): 80 | param.requires_grad = False 81 | return model 82 | -------------------------------------------------------------------------------- /open_vocab_seg/modeling/heads/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # Copyright (c) Meta Platforms, Inc. All Rights Reserved -------------------------------------------------------------------------------- /open_vocab_seg/modeling/heads/mask_former_head.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # Copyright (c) Meta Platforms, Inc. 
All Rights Reserved 3 | 4 | import logging 5 | from copy import deepcopy 6 | from typing import Callable, Dict, List, Optional, Tuple, Union 7 | 8 | import fvcore.nn.weight_init as weight_init 9 | from torch import nn 10 | from torch.nn import functional as F 11 | 12 | from detectron2.config import configurable 13 | from detectron2.layers import Conv2d, ShapeSpec, get_norm 14 | from detectron2.modeling import SEM_SEG_HEADS_REGISTRY 15 | 16 | from ..transformer.transformer_predictor import TransformerPredictor 17 | from .pixel_decoder import build_pixel_decoder 18 | 19 | 20 | @SEM_SEG_HEADS_REGISTRY.register() 21 | class MaskFormerHead(nn.Module): 22 | 23 | _version = 2 24 | 25 | def _load_from_state_dict( 26 | self, 27 | state_dict, 28 | prefix, 29 | local_metadata, 30 | strict, 31 | missing_keys, 32 | unexpected_keys, 33 | error_msgs, 34 | ): 35 | version = local_metadata.get("version", None) 36 | if version is None or version < 2: 37 | # Do not warn if train from scratch 38 | scratch = True 39 | logger = logging.getLogger(__name__) 40 | for k in list(state_dict.keys()): 41 | newk = k 42 | if "sem_seg_head" in k and not k.startswith(prefix + "predictor"): 43 | newk = k.replace(prefix, prefix + "pixel_decoder.") 44 | # logger.debug(f"{k} ==> {newk}") 45 | if newk != k: 46 | state_dict[newk] = state_dict[k] 47 | del state_dict[k] 48 | scratch = False 49 | 50 | if not scratch: 51 | logger.warning( 52 | f"Weight format of {self.__class__.__name__} have changed! " 53 | "Please upgrade your models. Applying automatic conversion now ..." 54 | ) 55 | 56 | @configurable 57 | def __init__( 58 | self, 59 | input_shape: Dict[str, ShapeSpec], 60 | *, 61 | num_classes: int, 62 | pixel_decoder: nn.Module, 63 | loss_weight: float = 1.0, 64 | ignore_value: int = -1, 65 | # extra parameters 66 | transformer_predictor: nn.Module, 67 | transformer_in_feature: str, 68 | ): 69 | """ 70 | NOTE: this interface is experimental. 71 | Args: 72 | input_shape: shapes (channels and stride) of the input features 73 | num_classes: number of classes to predict 74 | pixel_decoder: the pixel decoder module 75 | loss_weight: loss weight 76 | ignore_value: category id to be ignored during training. 
77 | transformer_predictor: the transformer decoder that makes prediction 78 | transformer_in_feature: input feature name to the transformer_predictor 79 | """ 80 | super().__init__() 81 | input_shape = sorted(input_shape.items(), key=lambda x: x[1].stride) 82 | self.in_features = [k for k, v in input_shape] 83 | feature_strides = [v.stride for k, v in input_shape] 84 | feature_channels = [v.channels for k, v in input_shape] 85 | 86 | self.ignore_value = ignore_value 87 | self.common_stride = 4 88 | self.loss_weight = loss_weight 89 | 90 | self.pixel_decoder = pixel_decoder 91 | self.predictor = transformer_predictor 92 | self.transformer_in_feature = transformer_in_feature 93 | 94 | self.num_classes = num_classes 95 | 96 | @classmethod 97 | def from_config(cls, cfg, input_shape: Dict[str, ShapeSpec]): 98 | return { 99 | "input_shape": { 100 | k: v 101 | for k, v in input_shape.items() 102 | if k in cfg.MODEL.SEM_SEG_HEAD.IN_FEATURES 103 | }, 104 | "ignore_value": cfg.MODEL.SEM_SEG_HEAD.IGNORE_VALUE, 105 | "num_classes": cfg.MODEL.SEM_SEG_HEAD.NUM_CLASSES, 106 | "pixel_decoder": build_pixel_decoder(cfg, input_shape), 107 | "loss_weight": cfg.MODEL.SEM_SEG_HEAD.LOSS_WEIGHT, 108 | "transformer_in_feature": cfg.MODEL.MASK_FORMER.TRANSFORMER_IN_FEATURE, 109 | "transformer_predictor": TransformerPredictor( 110 | cfg, 111 | cfg.MODEL.SEM_SEG_HEAD.CONVS_DIM 112 | if cfg.MODEL.MASK_FORMER.TRANSFORMER_IN_FEATURE == "transformer_encoder" 113 | else input_shape[cfg.MODEL.MASK_FORMER.TRANSFORMER_IN_FEATURE].channels, 114 | mask_classification=True, 115 | ), 116 | } 117 | 118 | def forward(self, features): 119 | return self.layers(features) 120 | 121 | def layers(self, features): 122 | ( 123 | mask_features, 124 | transformer_encoder_features, 125 | ) = self.pixel_decoder.forward_features(features) 126 | if self.transformer_in_feature == "transformer_encoder": 127 | assert ( 128 | transformer_encoder_features is not None 129 | ), "Please use the TransformerEncoderPixelDecoder." 130 | predictions = self.predictor(transformer_encoder_features, mask_features) 131 | else: 132 | predictions = self.predictor( 133 | features[self.transformer_in_feature], mask_features 134 | ) 135 | return predictions 136 | -------------------------------------------------------------------------------- /open_vocab_seg/modeling/heads/open_vocab_mask_former_head.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # Copyright (c) Meta Platforms, Inc. 
All Rights Reserved 3 | # Modified by Feng Liang from 4 | # https://github.com/MendelXu/zsseg.baseline/blob/master/mask_former/modeling/heads/zero_shot_mask_former_head.py 5 | 6 | import logging 7 | from copy import deepcopy 8 | from typing import Callable, Dict, List, Optional, Tuple, Union 9 | 10 | import fvcore.nn.weight_init as weight_init 11 | from torch import nn 12 | from torch.nn import functional as F 13 | 14 | from detectron2.config import configurable 15 | from detectron2.layers import Conv2d, ShapeSpec, get_norm 16 | from detectron2.modeling import SEM_SEG_HEADS_REGISTRY 17 | 18 | from ..transformer.open_vocab_transformer_predictor import OpenVocabTransformerPredictor 19 | from .pixel_decoder import build_pixel_decoder 20 | 21 | 22 | @SEM_SEG_HEADS_REGISTRY.register() 23 | class OpenVocabMaskFormerHead(nn.Module): 24 | 25 | _version = 2 26 | 27 | def _load_from_state_dict( 28 | self, 29 | state_dict, 30 | prefix, 31 | local_metadata, 32 | strict, 33 | missing_keys, 34 | unexpected_keys, 35 | error_msgs, 36 | ): 37 | version = local_metadata.get("version", None) 38 | if version is None or version < 2: 39 | # Do not warn if train from scratch 40 | scratch = True 41 | logger = logging.getLogger(__name__) 42 | for k in list(state_dict.keys()): 43 | newk = k 44 | if "sem_seg_head" in k and not k.startswith(prefix + "predictor"): 45 | newk = k.replace(prefix, prefix + "pixel_decoder.") 46 | # logger.debug(f"{k} ==> {newk}") 47 | if newk != k: 48 | state_dict[newk] = state_dict[k] 49 | del state_dict[k] 50 | scratch = False 51 | 52 | if not scratch: 53 | logger.warning( 54 | f"Weight format of {self.__class__.__name__} have changed! " 55 | "Please upgrade your models. Applying automatic conversion now ..." 56 | ) 57 | 58 | @configurable 59 | def __init__( 60 | self, 61 | input_shape: Dict[str, ShapeSpec], 62 | *, 63 | num_classes: int, 64 | pixel_decoder: nn.Module, 65 | loss_weight: float = 1.0, 66 | ignore_value: int = -1, 67 | # extra parameters 68 | transformer_predictor: nn.Module, 69 | transformer_in_feature: str, 70 | ): 71 | """ 72 | NOTE: this interface is experimental. 73 | Args: 74 | input_shape: shapes (channels and stride) of the input features 75 | num_classes: number of classes to predict 76 | pixel_decoder: the pixel decoder module 77 | loss_weight: loss weight 78 | ignore_value: category id to be ignored during training. 
79 | transformer_predictor: the transformer decoder that makes prediction 80 | transformer_in_feature: input feature name to the transformer_predictor 81 | """ 82 | super().__init__() 83 | input_shape = sorted(input_shape.items(), key=lambda x: x[1].stride) 84 | self.in_features = [k for k, v in input_shape] 85 | feature_strides = [v.stride for k, v in input_shape] 86 | feature_channels = [v.channels for k, v in input_shape] 87 | 88 | self.ignore_value = ignore_value 89 | self.common_stride = 4 90 | self.loss_weight = loss_weight 91 | 92 | self.pixel_decoder = pixel_decoder 93 | self.predictor = transformer_predictor 94 | self.transformer_in_feature = transformer_in_feature 95 | 96 | self.num_classes = num_classes 97 | 98 | @classmethod 99 | def from_config(cls, cfg, input_shape: Dict[str, ShapeSpec]): 100 | return { 101 | "input_shape": { 102 | k: v 103 | for k, v in input_shape.items() 104 | if k in cfg.MODEL.SEM_SEG_HEAD.IN_FEATURES 105 | }, 106 | "ignore_value": cfg.MODEL.SEM_SEG_HEAD.IGNORE_VALUE, 107 | "num_classes": cfg.MODEL.SEM_SEG_HEAD.NUM_CLASSES, 108 | "pixel_decoder": build_pixel_decoder(cfg, input_shape), 109 | "loss_weight": cfg.MODEL.SEM_SEG_HEAD.LOSS_WEIGHT, 110 | "transformer_in_feature": cfg.MODEL.MASK_FORMER.TRANSFORMER_IN_FEATURE, 111 | "transformer_predictor": OpenVocabTransformerPredictor( 112 | cfg, 113 | cfg.MODEL.SEM_SEG_HEAD.CONVS_DIM 114 | if cfg.MODEL.MASK_FORMER.TRANSFORMER_IN_FEATURE == "transformer_encoder" 115 | else input_shape[cfg.MODEL.MASK_FORMER.TRANSFORMER_IN_FEATURE].channels, 116 | mask_classification=True, 117 | ), 118 | } 119 | 120 | def forward(self, features): 121 | return self.layers(features) 122 | 123 | def layers(self, features): 124 | ( 125 | mask_features, 126 | transformer_encoder_features, 127 | ) = self.pixel_decoder.forward_features(features) 128 | if self.transformer_in_feature == "transformer_encoder": 129 | assert ( 130 | transformer_encoder_features is not None 131 | ), "Please use the TransformerEncoderPixelDecoder." 132 | predictions = self.predictor(transformer_encoder_features, mask_features) 133 | else: 134 | predictions = self.predictor( 135 | features[self.transformer_in_feature], mask_features 136 | ) 137 | return predictions 138 | 139 | def freeze_pretrained(self): 140 | for name, module in self.named_children(): 141 | if name not in ["predictor"]: 142 | for param in module.parameters(): 143 | param.requires_grad = False 144 | else: 145 | module.freeze_pretrained() 146 | -------------------------------------------------------------------------------- /open_vocab_seg/modeling/transformer/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # Copyright (c) Meta Platforms, Inc. All Rights Reserved 3 | -------------------------------------------------------------------------------- /open_vocab_seg/modeling/transformer/open_vocab_transformer_predictor.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # Modified by Bowen Cheng from: https://github.com/facebookresearch/detr/blob/master/models/detr.py 3 | # Copyright (c) Meta Platforms, Inc. 
All Rights Reserved 4 | 5 | from torch import nn 6 | from detectron2.config import configurable 7 | from .transformer_predictor import TransformerPredictor, MLP 8 | 9 | 10 | class OpenVocabTransformerPredictor(TransformerPredictor): 11 | @configurable 12 | def __init__( 13 | self, 14 | in_channels, 15 | mask_classification=True, 16 | *, 17 | embedding_dim: int, 18 | embed_hidden_dim: int, 19 | embed_layers: int, 20 | hidden_dim: int, 21 | num_queries: int, 22 | nheads: int, 23 | dropout: float, 24 | dim_feedforward: int, 25 | enc_layers: int, 26 | dec_layers: int, 27 | pre_norm: bool, 28 | deep_supervision: bool, 29 | mask_dim: int, 30 | enforce_input_project: bool, 31 | ): 32 | super().__init__( 33 | in_channels, 34 | False, 35 | num_classes=embedding_dim, 36 | hidden_dim=hidden_dim, 37 | num_queries=num_queries, 38 | nheads=nheads, 39 | dropout=dropout, 40 | dim_feedforward=dim_feedforward, 41 | enc_layers=enc_layers, 42 | dec_layers=dec_layers, 43 | pre_norm=pre_norm, 44 | deep_supervision=deep_supervision, 45 | mask_dim=mask_dim, 46 | enforce_input_project=enforce_input_project, 47 | ) 48 | self.mask_classification = mask_classification 49 | # output FFNs 50 | if self.mask_classification: 51 | self.class_embed = MLP( 52 | hidden_dim, embed_hidden_dim, embedding_dim, embed_layers 53 | ) 54 | 55 | def freeze_pretrained(self): 56 | for name, module in self.named_children(): 57 | if name not in ["class_embed"]: 58 | for param in module.parameters(): 59 | param.requires_grad = False 60 | 61 | @classmethod 62 | def from_config(cls, cfg, in_channels, mask_classification): 63 | ret = {} 64 | ret["in_channels"] = in_channels 65 | ret["mask_classification"] = mask_classification 66 | 67 | ret["embedding_dim"] = cfg.MODEL.SEM_SEG_HEAD.EMBEDDING_DIM 68 | ret["embed_hidden_dim"] = cfg.MODEL.SEM_SEG_HEAD.EMBED_HIDDEN_DIM 69 | ret["embed_layers"] = cfg.MODEL.SEM_SEG_HEAD.EMBED_LAYERS 70 | ret["hidden_dim"] = cfg.MODEL.MASK_FORMER.HIDDEN_DIM 71 | ret["num_queries"] = cfg.MODEL.MASK_FORMER.NUM_OBJECT_QUERIES 72 | # Transformer parameters: 73 | ret["nheads"] = cfg.MODEL.MASK_FORMER.NHEADS 74 | ret["dropout"] = cfg.MODEL.MASK_FORMER.DROPOUT 75 | ret["dim_feedforward"] = cfg.MODEL.MASK_FORMER.DIM_FEEDFORWARD 76 | ret["enc_layers"] = cfg.MODEL.MASK_FORMER.ENC_LAYERS 77 | ret["dec_layers"] = cfg.MODEL.MASK_FORMER.DEC_LAYERS 78 | ret["pre_norm"] = cfg.MODEL.MASK_FORMER.PRE_NORM 79 | ret["deep_supervision"] = cfg.MODEL.MASK_FORMER.DEEP_SUPERVISION 80 | ret["enforce_input_project"] = cfg.MODEL.MASK_FORMER.ENFORCE_INPUT_PROJ 81 | 82 | ret["mask_dim"] = cfg.MODEL.SEM_SEG_HEAD.MASK_DIM 83 | 84 | return ret 85 | -------------------------------------------------------------------------------- /open_vocab_seg/modeling/transformer/position_encoding.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # # Modified by Bowen Cheng from: https://github.com/facebookresearch/detr/blob/master/models/position_encoding.py 3 | # Copyright (c) Meta Platforms, Inc. All Rights Reserved 4 | 5 | """ 6 | Various positional encodings for the transformer. 7 | """ 8 | import math 9 | 10 | import torch 11 | from torch import nn 12 | 13 | 14 | class PositionEmbeddingSine(nn.Module): 15 | """ 16 | This is a more standard version of the position embedding, very similar to the one 17 | used by the Attention is all you need paper, generalized to work on images. 
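    Concretely, the x and y coordinates are cumulative sums over the non-padded region
    (optionally normalized and scaled); channel pair i then encodes a coordinate p as
    sin(p / temperature**(2i / num_pos_feats)) and cos(p / temperature**(2i / num_pos_feats)),
    and the y and x encodings are concatenated along the channel dimension.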
18 | """ 19 | 20 | def __init__( 21 | self, num_pos_feats=64, temperature=10000, normalize=False, scale=None 22 | ): 23 | super().__init__() 24 | self.num_pos_feats = num_pos_feats 25 | self.temperature = temperature 26 | self.normalize = normalize 27 | if scale is not None and normalize is False: 28 | raise ValueError("normalize should be True if scale is passed") 29 | if scale is None: 30 | scale = 2 * math.pi 31 | self.scale = scale 32 | 33 | def forward(self, x, mask=None): 34 | if mask is None: 35 | mask = torch.zeros( 36 | (x.size(0), x.size(2), x.size(3)), device=x.device, dtype=torch.bool 37 | ) 38 | not_mask = ~mask 39 | y_embed = not_mask.cumsum(1, dtype=torch.float32) 40 | x_embed = not_mask.cumsum(2, dtype=torch.float32) 41 | if self.normalize: 42 | eps = 1e-6 43 | y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale 44 | x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale 45 | 46 | dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device) 47 | dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats) 48 | 49 | pos_x = x_embed[:, :, :, None] / dim_t 50 | pos_y = y_embed[:, :, :, None] / dim_t 51 | pos_x = torch.stack( 52 | (pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4 53 | ).flatten(3) 54 | pos_y = torch.stack( 55 | (pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4 56 | ).flatten(3) 57 | pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2) 58 | return pos 59 | -------------------------------------------------------------------------------- /open_vocab_seg/modeling/transformer/transformer_predictor.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # Modified by Bowen Cheng from: https://github.com/facebookresearch/detr/blob/master/models/detr.py 3 | # Copyright (c) Meta Platforms, Inc. All Rights Reserved 4 | 5 | import fvcore.nn.weight_init as weight_init 6 | import torch 7 | from torch import nn 8 | from torch.nn import functional as F 9 | 10 | from detectron2.config import configurable 11 | from detectron2.layers import Conv2d 12 | 13 | from .position_encoding import PositionEmbeddingSine 14 | from .transformer import Transformer 15 | 16 | 17 | class TransformerPredictor(nn.Module): 18 | @configurable 19 | def __init__( 20 | self, 21 | in_channels, 22 | mask_classification=True, 23 | *, 24 | num_classes: int, 25 | hidden_dim: int, 26 | num_queries: int, 27 | nheads: int, 28 | dropout: float, 29 | dim_feedforward: int, 30 | enc_layers: int, 31 | dec_layers: int, 32 | pre_norm: bool, 33 | deep_supervision: bool, 34 | mask_dim: int, 35 | enforce_input_project: bool, 36 | ): 37 | """ 38 | NOTE: this interface is experimental. 
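        The predictor applies a DETR-style Transformer (with sine positional encodings) to the
        selected feature map and decodes num_queries learned query embeddings into per-query
        class logits (when mask_classification is enabled) and, via an MLP followed by an einsum
        with the pixel-decoder mask features, per-query mask predictions.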
39 | Args: 40 | in_channels: channels of the input features 41 | mask_classification: whether to add mask classifier or not 42 | num_classes: number of classes 43 | hidden_dim: Transformer feature dimension 44 | num_queries: number of queries 45 | nheads: number of heads 46 | dropout: dropout in Transformer 47 | dim_feedforward: feature dimension in feedforward network 48 | enc_layers: number of Transformer encoder layers 49 | dec_layers: number of Transformer decoder layers 50 | pre_norm: whether to use pre-LayerNorm or not 51 | deep_supervision: whether to add supervision to every decoder layers 52 | mask_dim: mask feature dimension 53 | enforce_input_project: add input project 1x1 conv even if input 54 | channels and hidden dim is identical 55 | """ 56 | super().__init__() 57 | 58 | self.mask_classification = mask_classification 59 | 60 | # positional encoding 61 | N_steps = hidden_dim // 2 62 | self.pe_layer = PositionEmbeddingSine(N_steps, normalize=True) 63 | 64 | transformer = Transformer( 65 | d_model=hidden_dim, 66 | dropout=dropout, 67 | nhead=nheads, 68 | dim_feedforward=dim_feedforward, 69 | num_encoder_layers=enc_layers, 70 | num_decoder_layers=dec_layers, 71 | normalize_before=pre_norm, 72 | return_intermediate_dec=deep_supervision, 73 | ) 74 | 75 | self.num_queries = num_queries 76 | self.transformer = transformer 77 | hidden_dim = transformer.d_model 78 | 79 | self.query_embed = nn.Embedding(num_queries, hidden_dim) 80 | 81 | if in_channels != hidden_dim or enforce_input_project: 82 | self.input_proj = Conv2d(in_channels, hidden_dim, kernel_size=1) 83 | weight_init.c2_xavier_fill(self.input_proj) 84 | else: 85 | self.input_proj = nn.Sequential() 86 | self.aux_loss = deep_supervision 87 | 88 | # output FFNs 89 | if self.mask_classification: 90 | self.class_embed = nn.Linear(hidden_dim, num_classes + 1) 91 | self.mask_embed = MLP(hidden_dim, hidden_dim, mask_dim, 3) 92 | 93 | @classmethod 94 | def from_config(cls, cfg, in_channels, mask_classification): 95 | ret = {} 96 | ret["in_channels"] = in_channels 97 | ret["mask_classification"] = mask_classification 98 | 99 | ret["num_classes"] = cfg.MODEL.SEM_SEG_HEAD.NUM_CLASSES 100 | ret["hidden_dim"] = cfg.MODEL.MASK_FORMER.HIDDEN_DIM 101 | ret["num_queries"] = cfg.MODEL.MASK_FORMER.NUM_OBJECT_QUERIES 102 | # Transformer parameters: 103 | ret["nheads"] = cfg.MODEL.MASK_FORMER.NHEADS 104 | ret["dropout"] = cfg.MODEL.MASK_FORMER.DROPOUT 105 | ret["dim_feedforward"] = cfg.MODEL.MASK_FORMER.DIM_FEEDFORWARD 106 | ret["enc_layers"] = cfg.MODEL.MASK_FORMER.ENC_LAYERS 107 | ret["dec_layers"] = cfg.MODEL.MASK_FORMER.DEC_LAYERS 108 | ret["pre_norm"] = cfg.MODEL.MASK_FORMER.PRE_NORM 109 | ret["deep_supervision"] = cfg.MODEL.MASK_FORMER.DEEP_SUPERVISION 110 | ret["enforce_input_project"] = cfg.MODEL.MASK_FORMER.ENFORCE_INPUT_PROJ 111 | 112 | ret["mask_dim"] = cfg.MODEL.SEM_SEG_HEAD.MASK_DIM 113 | 114 | return ret 115 | 116 | def forward(self, x, mask_features): 117 | pos = self.pe_layer(x) 118 | 119 | src = x 120 | mask = None 121 | hs, memory = self.transformer( 122 | self.input_proj(src), mask, self.query_embed.weight, pos 123 | ) 124 | 125 | if self.mask_classification: 126 | outputs_class = self.class_embed(hs) 127 | out = {"pred_logits": outputs_class[-1]} 128 | else: 129 | out = {} 130 | 131 | if self.aux_loss: 132 | # [l, bs, queries, embed] 133 | mask_embed = self.mask_embed(hs) 134 | outputs_seg_masks = torch.einsum( 135 | "lbqc,bchw->lbqhw", mask_embed, mask_features 136 | ) 137 | out["pred_masks"] = outputs_seg_masks[-1] 138 | 
out["aux_outputs"] = self._set_aux_loss( 139 | outputs_class if self.mask_classification else None, outputs_seg_masks 140 | ) 141 | else: 142 | # FIXME h_boxes takes the last one computed, keep this in mind 143 | # [bs, queries, embed] 144 | mask_embed = self.mask_embed(hs[-1]) 145 | outputs_seg_masks = torch.einsum( 146 | "bqc,bchw->bqhw", mask_embed, mask_features 147 | ) 148 | out["pred_masks"] = outputs_seg_masks 149 | return out 150 | 151 | @torch.jit.unused 152 | def _set_aux_loss(self, outputs_class, outputs_seg_masks): 153 | # this is a workaround to make torchscript happy, as torchscript 154 | # doesn't support dictionary with non-homogeneous values, such 155 | # as a dict having both a Tensor and a list. 156 | if self.mask_classification: 157 | return [ 158 | {"pred_logits": a, "pred_masks": b} 159 | for a, b in zip(outputs_class[:-1], outputs_seg_masks[:-1]) 160 | ] 161 | else: 162 | return [{"pred_masks": b} for b in outputs_seg_masks[:-1]] 163 | 164 | 165 | class MLP(nn.Module): 166 | """Very simple multi-layer perceptron (also called FFN)""" 167 | 168 | def __init__(self, input_dim, hidden_dim, output_dim, num_layers): 169 | super().__init__() 170 | self.num_layers = num_layers 171 | h = [hidden_dim] * (num_layers - 1) 172 | self.layers = nn.ModuleList( 173 | nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]) 174 | ) 175 | 176 | def forward(self, x): 177 | for i, layer in enumerate(self.layers): 178 | x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x) 179 | return x 180 | -------------------------------------------------------------------------------- /open_vocab_seg/utils/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # Copyright (c) Meta Platforms, Inc. All Rights Reserved 3 | 4 | from .events import setup_wandb, WandbWriter 5 | from .predictor import VisualizationDemo -------------------------------------------------------------------------------- /open_vocab_seg/utils/events.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # Copyright (c) Meta Platforms, Inc. 
All Rights Reserved 3 | 4 | import os 5 | import wandb 6 | from detectron2.utils import comm 7 | from detectron2.utils.events import EventWriter, get_event_storage 8 | 9 | 10 | def setup_wandb(cfg, args): 11 | if comm.is_main_process(): 12 | init_args = { 13 | k.lower(): v 14 | for k, v in cfg.WANDB.items() 15 | if isinstance(k, str) and k not in ["config", "name"] 16 | } 17 | # only include most related part to avoid too big table 18 | # TODO: add configurable params to select which part of `cfg` should be saved in config 19 | if "config_exclude_keys" in init_args: 20 | init_args["config"] = cfg 21 | init_args["config"]["cfg_file"] = args.config_file 22 | else: 23 | init_args["config"] = { 24 | "model": cfg.MODEL, 25 | "solver": cfg.SOLVER, 26 | "cfg_file": args.config_file, 27 | } 28 | if ("name" not in init_args) or (init_args["name"] is None): 29 | init_args["name"] = os.path.basename(args.config_file) 30 | wandb.init(**init_args) 31 | 32 | 33 | class BaseRule(object): 34 | def __call__(self, target): 35 | return target 36 | 37 | 38 | class IsIn(BaseRule): 39 | def __init__(self, keyword: str): 40 | self.keyword = keyword 41 | 42 | def __call__(self, target): 43 | return self.keyword in target 44 | 45 | 46 | class Prefix(BaseRule): 47 | def __init__(self, keyword: str): 48 | self.keyword = keyword 49 | 50 | def __call__(self, target): 51 | return "/".join([self.keyword, target]) 52 | 53 | 54 | class WandbWriter(EventWriter): 55 | """ 56 | Write all scalars to a tensorboard file. 57 | """ 58 | 59 | def __init__(self): 60 | """ 61 | Args: 62 | log_dir (str): the directory to save the output events 63 | kwargs: other arguments passed to `torch.utils.tensorboard.SummaryWriter(...)` 64 | """ 65 | self._last_write = -1 66 | self._group_rules = [ 67 | (IsIn("/"), BaseRule()), 68 | (IsIn("loss"), Prefix("train")), 69 | ] 70 | 71 | def write(self): 72 | 73 | storage = get_event_storage() 74 | 75 | def _group_name(scalar_name): 76 | for (rule, op) in self._group_rules: 77 | if rule(scalar_name): 78 | return op(scalar_name) 79 | return scalar_name 80 | 81 | stats = { 82 | _group_name(name): scalars[0] 83 | for name, scalars in storage.latest().items() 84 | if scalars[1] > self._last_write 85 | } 86 | if len(stats) > 0: 87 | self._last_write = max([v[1] for k, v in storage.latest().items()]) 88 | 89 | # storage.put_{image,histogram} is only meant to be used by 90 | # tensorboard writer. So we access its internal fields directly from here. 91 | if len(storage._vis_data) >= 1: 92 | stats["image"] = [ 93 | wandb.Image(img, caption=img_name) 94 | for img_name, img, step_num in storage._vis_data 95 | ] 96 | # Storage stores all image data and rely on this writer to clear them. 97 | # As a result it assumes only one writer will use its image data. 98 | # An alternative design is to let storage store limited recent 99 | # data (e.g. only the most recent image) that all writers can access. 100 | # In that case a writer may not see all image data if its period is long. 
101 | storage.clear_images() 102 | 103 | if len(storage._histograms) >= 1: 104 | 105 | def create_bar(tag, bucket_limits, bucket_counts, **kwargs): 106 | data = [ 107 | [label, val] for (label, val) in zip(bucket_limits, bucket_counts) 108 | ] 109 | table = wandb.Table(data=data, columns=["label", "value"]) 110 | return wandb.plot.bar(table, "label", "value", title=tag) 111 | 112 | stats["hist"] = [create_bar(**params) for params in storage._histograms] 113 | 114 | storage.clear_histograms() 115 | 116 | if len(stats) == 0: 117 | return 118 | wandb.log(stats, step=storage.iter) 119 | 120 | def close(self): 121 | wandb.finish() 122 | -------------------------------------------------------------------------------- /open_vocab_seg/utils/misc.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # Modified by Bowen Cheng from https://github.com/facebookresearch/detr/blob/master/util/misc.py 3 | # Copyright (c) Meta Platforms, Inc. All Rights Reserved 4 | 5 | """ 6 | Misc functions, including distributed helpers. 7 | 8 | Mostly copy-paste from torchvision references. 9 | """ 10 | from typing import List, Optional 11 | 12 | import torch 13 | import torch.distributed as dist 14 | import torchvision 15 | from torch import Tensor 16 | 17 | 18 | def _max_by_axis(the_list): 19 | # type: (List[List[int]]) -> List[int] 20 | maxes = the_list[0] 21 | for sublist in the_list[1:]: 22 | for index, item in enumerate(sublist): 23 | maxes[index] = max(maxes[index], item) 24 | return maxes 25 | 26 | 27 | class NestedTensor(object): 28 | def __init__(self, tensors, mask: Optional[Tensor]): 29 | self.tensors = tensors 30 | self.mask = mask 31 | 32 | def to(self, device): 33 | # type: (Device) -> NestedTensor # noqa 34 | cast_tensor = self.tensors.to(device) 35 | mask = self.mask 36 | if mask is not None: 37 | assert mask is not None 38 | cast_mask = mask.to(device) 39 | else: 40 | cast_mask = None 41 | return NestedTensor(cast_tensor, cast_mask) 42 | 43 | def decompose(self): 44 | return self.tensors, self.mask 45 | 46 | def __repr__(self): 47 | return str(self.tensors) 48 | 49 | 50 | def nested_tensor_from_tensor_list(tensor_list: List[Tensor]): 51 | # TODO make this more general 52 | if tensor_list[0].ndim == 3: 53 | if torchvision._is_tracing(): 54 | # nested_tensor_from_tensor_list() does not export well to ONNX 55 | # call _onnx_nested_tensor_from_tensor_list() instead 56 | return _onnx_nested_tensor_from_tensor_list(tensor_list) 57 | 58 | # TODO make it support different-sized images 59 | max_size = _max_by_axis([list(img.shape) for img in tensor_list]) 60 | # min_size = tuple(min(s) for s in zip(*[img.shape for img in tensor_list])) 61 | batch_shape = [len(tensor_list)] + max_size 62 | b, c, h, w = batch_shape 63 | dtype = tensor_list[0].dtype 64 | device = tensor_list[0].device 65 | tensor = torch.zeros(batch_shape, dtype=dtype, device=device) 66 | mask = torch.ones((b, h, w), dtype=torch.bool, device=device) 67 | for img, pad_img, m in zip(tensor_list, tensor, mask): 68 | pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img) 69 | m[: img.shape[1], : img.shape[2]] = False 70 | else: 71 | raise ValueError("not supported") 72 | return NestedTensor(tensor, mask) 73 | 74 | 75 | # _onnx_nested_tensor_from_tensor_list() is an implementation of 76 | # nested_tensor_from_tensor_list() that is supported by ONNX tracing. 
77 | @torch.jit.unused 78 | def _onnx_nested_tensor_from_tensor_list(tensor_list: List[Tensor]) -> NestedTensor: 79 | max_size = [] 80 | for i in range(tensor_list[0].dim()): 81 | max_size_i = torch.max( 82 | torch.stack([img.shape[i] for img in tensor_list]).to(torch.float32) 83 | ).to(torch.int64) 84 | max_size.append(max_size_i) 85 | max_size = tuple(max_size) 86 | 87 | # work around for 88 | # pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img) 89 | # m[: img.shape[1], :img.shape[2]] = False 90 | # which is not yet supported in onnx 91 | padded_imgs = [] 92 | padded_masks = [] 93 | for img in tensor_list: 94 | padding = [(s1 - s2) for s1, s2 in zip(max_size, tuple(img.shape))] 95 | padded_img = torch.nn.functional.pad( 96 | img, (0, padding[2], 0, padding[1], 0, padding[0]) 97 | ) 98 | padded_imgs.append(padded_img) 99 | 100 | m = torch.zeros_like(img[0], dtype=torch.int, device=img.device) 101 | padded_mask = torch.nn.functional.pad( 102 | m, (0, padding[2], 0, padding[1]), "constant", 1 103 | ) 104 | padded_masks.append(padded_mask.to(torch.bool)) 105 | 106 | tensor = torch.stack(padded_imgs) 107 | mask = torch.stack(padded_masks) 108 | 109 | return NestedTensor(tensor, mask=mask) 110 | 111 | 112 | def is_dist_avail_and_initialized(): 113 | if not dist.is_available(): 114 | return False 115 | if not dist.is_initialized(): 116 | return False 117 | return True 118 | 119 | def get_gt_binary_masks(gt_semseg): 120 | mask_ids = torch.unique(gt_semseg) 121 | gt_masks = [] 122 | for id in mask_ids: 123 | if id != 255: 124 | gt_masks.append(gt_semseg == id) 125 | gt_masks = torch.stack(gt_masks).float() 126 | return gt_masks 127 | -------------------------------------------------------------------------------- /open_vocab_seg/utils/post_process_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # Copyright (c) Meta Platforms, Inc. All Rights Reserved 3 | 4 | import torch 5 | from torch.nn import functional as F 6 | import numpy as np 7 | 8 | try: 9 | import pydensecrf.densecrf as dcrf 10 | from pydensecrf.utils import ( 11 | unary_from_softmax, 12 | unary_from_labels, 13 | create_pairwise_bilateral, 14 | create_pairwise_gaussian, 15 | ) 16 | except: 17 | dcrf = None 18 | 19 | 20 | def dense_crf_post_process( 21 | logits, 22 | image, 23 | n_labels=None, 24 | max_iters=5, 25 | pos_xy_std=(3, 3), 26 | pos_w=3, 27 | bi_xy_std=(80, 80), 28 | bi_rgb_std=(13, 13, 13), 29 | bi_w=10, 30 | ): 31 | """ 32 | logits : [C,H,W] 33 | image : [3,H,W] 34 | """ 35 | if dcrf is None: 36 | raise FileNotFoundError( 37 | "pydensecrf is required to perform dense crf inference." 38 | ) 39 | if isinstance(logits, torch.Tensor): 40 | logits = F.softmax(logits, dim=0).detach().cpu().numpy() 41 | U = unary_from_softmax(logits) 42 | n_labels = logits.shape[0] 43 | elif logits.ndim == 3: 44 | U = unary_from_softmax(logits) 45 | n_labels = logits.shape[0] 46 | else: 47 | assert n_labels is not None 48 | U = unary_from_labels(logits, n_labels, zero_unsure=False) 49 | 50 | d = dcrf.DenseCRF2D(image.shape[1], image.shape[0], n_labels) 51 | 52 | d.setUnaryEnergy(U) 53 | 54 | # This adds the color-independent term, features are the locations only. 55 | d.addPairwiseGaussian( 56 | sxy=pos_xy_std, 57 | compat=pos_w, 58 | kernel=dcrf.DIAG_KERNEL, 59 | normalization=dcrf.NORMALIZE_SYMMETRIC, 60 | ) 61 | 62 | # This adds the color-dependent term, i.e. features are (x,y,r,g,b). 
63 | d.addPairwiseBilateral( 64 | sxy=bi_xy_std, 65 | srgb=bi_rgb_std, 66 | rgbim=image, 67 | compat=bi_w, 68 | kernel=dcrf.DIAG_KERNEL, 69 | normalization=dcrf.NORMALIZE_SYMMETRIC, 70 | ) 71 | # Run five inference steps. 72 | logits = d.inference(max_iters) 73 | logits = np.asarray(logits).reshape((n_labels, image.shape[0], image.shape[1])) 74 | return torch.from_numpy(logits) 75 | -------------------------------------------------------------------------------- /open_vocab_seg/utils/predictor.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # Copyright (c) Meta Platforms, Inc. All Rights Reserved 3 | 4 | import numpy as np 5 | import torch 6 | 7 | from detectron2.data import MetadataCatalog 8 | from detectron2.engine.defaults import DefaultPredictor 9 | from detectron2.utils.visualizer import ColorMode, Visualizer 10 | 11 | 12 | class OVSegPredictor(DefaultPredictor): 13 | def __init__(self, cfg): 14 | super().__init__(cfg) 15 | 16 | def __call__(self, original_image, class_names): 17 | """ 18 | Args: 19 | original_image (np.ndarray): an image of shape (H, W, C) (in BGR order). 20 | 21 | Returns: 22 | predictions (dict): 23 | the output of the model for one image only. 24 | See :doc:`/tutorials/models` for details about the format. 25 | """ 26 | with torch.no_grad(): # https://github.com/sphinx-doc/sphinx/issues/4258 27 | # Apply pre-processing to image. 28 | if self.input_format == "RGB": 29 | # whether the model expects BGR inputs or RGB 30 | original_image = original_image[:, :, ::-1] 31 | height, width = original_image.shape[:2] 32 | image = self.aug.get_transform(original_image).apply_image(original_image) 33 | image = torch.as_tensor(image.astype("float32").transpose(2, 0, 1)) 34 | 35 | inputs = {"image": image, "height": height, "width": width, "class_names": class_names} 36 | predictions = self.model([inputs])[0] 37 | return predictions 38 | 39 | class OVSegVisualizer(Visualizer): 40 | def __init__(self, img_rgb, metadata=None, scale=1.0, instance_mode=ColorMode.IMAGE, class_names=None): 41 | super().__init__(img_rgb, metadata, scale, instance_mode) 42 | self.class_names = class_names 43 | 44 | def draw_sem_seg(self, sem_seg, area_threshold=None, alpha=0.8): 45 | """ 46 | Draw semantic segmentation predictions/labels. 47 | 48 | Args: 49 | sem_seg (Tensor or ndarray): the segmentation of shape (H, W). 50 | Each value is the integer label of the pixel. 51 | area_threshold (int): segments with less than `area_threshold` are not drawn. 52 | alpha (float): the larger it is, the more opaque the segmentations are. 53 | 54 | Returns: 55 | output (VisImage): image object with visualizations. 
56 | """ 57 | if isinstance(sem_seg, torch.Tensor): 58 | sem_seg = sem_seg.numpy() 59 | labels, areas = np.unique(sem_seg, return_counts=True) 60 | sorted_idxs = np.argsort(-areas).tolist() 61 | labels = labels[sorted_idxs] 62 | class_names = self.class_names if self.class_names is not None else self.metadata.stuff_classes 63 | 64 | for label in filter(lambda l: l < len(class_names), labels): 65 | try: 66 | mask_color = [x / 255 for x in self.metadata.stuff_colors[label]] 67 | except (AttributeError, IndexError): 68 | mask_color = None 69 | 70 | binary_mask = (sem_seg == label).astype(np.uint8) 71 | text = class_names[label] 72 | self.draw_binary_mask( 73 | binary_mask, 74 | color=mask_color, 75 | edge_color=(1.0, 1.0, 240.0 / 255), 76 | text=text, 77 | alpha=alpha, 78 | area_threshold=area_threshold, 79 | ) 80 | return self.output 81 | 82 | 83 | 84 | class VisualizationDemo(object): 85 | def __init__(self, cfg, instance_mode=ColorMode.IMAGE, parallel=False): 86 | """ 87 | Args: 88 | cfg (CfgNode): 89 | instance_mode (ColorMode): 90 | parallel (bool): whether to run the model in different processes from visualization. 91 | Useful since the visualization logic can be slow. 92 | """ 93 | self.metadata = MetadataCatalog.get( 94 | cfg.DATASETS.TEST[0] if len(cfg.DATASETS.TEST) else "__unused" 95 | ) 96 | 97 | self.cpu_device = torch.device("cpu") 98 | self.instance_mode = instance_mode 99 | 100 | self.parallel = parallel 101 | if parallel: 102 | raise NotImplementedError 103 | else: 104 | self.predictor = OVSegPredictor(cfg) 105 | 106 | def run_on_image(self, image, class_names): 107 | """ 108 | Args: 109 | image (np.ndarray): an image of shape (H, W, C) (in BGR order). 110 | This is the format used by OpenCV. 111 | Returns: 112 | predictions (dict): the output of the model. 113 | vis_output (VisImage): the visualized image output. 114 | """ 115 | predictions = self.predictor(image, class_names) 116 | # Convert image from OpenCV BGR format to Matplotlib RGB format. 
117 | image = image[:, :, ::-1] 118 | visualizer = OVSegVisualizer(image, self.metadata, instance_mode=self.instance_mode, class_names=class_names) 119 | if "sem_seg" in predictions: 120 | r = predictions["sem_seg"] 121 | blank_area = (r[0] == 0) 122 | pred_mask = r.argmax(dim=0).to('cpu') 123 | pred_mask[blank_area] = 255 124 | pred_mask = np.array(pred_mask, dtype=np.int) 125 | 126 | vis_output = visualizer.draw_sem_seg( 127 | pred_mask 128 | ) 129 | else: 130 | raise NotImplementedError 131 | 132 | return predictions, vis_output -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | cython 2 | scipy 3 | shapely 4 | timm 5 | h5py 6 | wandb 7 | fire 8 | opencv-python 9 | pandas 10 | braceexpand 11 | torch-ema 12 | torchmetrics==0.11.4 13 | setuptools==59.5.0 14 | webdataset>=0.2.5 -------------------------------------------------------------------------------- /resources/demo_samples/sample_03.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/ov-seg/36f49d496714998058d115ffb6172d9d84c59065/resources/demo_samples/sample_03.jpeg -------------------------------------------------------------------------------- /resources/ovseg.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/ov-seg/36f49d496714998058d115ffb6172d9d84c59065/resources/ovseg.gif -------------------------------------------------------------------------------- /resources/proposal.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/ov-seg/36f49d496714998058d115ffb6172d9d84c59065/resources/proposal.png -------------------------------------------------------------------------------- /resources/pytorch-logo-dark.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/ov-seg/36f49d496714998058d115ffb6172d9d84c59065/resources/pytorch-logo-dark.png -------------------------------------------------------------------------------- /third_party/CLIP/.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__/ 2 | *.py[cod] 3 | *$py.class 4 | *.egg-info 5 | .pytest_cache 6 | .ipynb_checkpoints 7 | 8 | thumbs.db 9 | .DS_Store 10 | .idea 11 | data/ 12 | *.pkl 13 | .theia 14 | tmp 15 | */tmp 16 | wandb/ 17 | */wadb 18 | .history -------------------------------------------------------------------------------- /third_party/CLIP/CLIP.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/ov-seg/36f49d496714998058d115ffb6172d9d84c59065/third_party/CLIP/CLIP.png -------------------------------------------------------------------------------- /third_party/CLIP/LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 OpenAI 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit 
persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | 23 | -------------------------------------------------------------------------------- /third_party/CLIP/MANIFEST.in: -------------------------------------------------------------------------------- 1 | include clip/bpe_simple_vocab_16e6.txt.gz 2 | -------------------------------------------------------------------------------- /third_party/CLIP/clip/__init__.py: -------------------------------------------------------------------------------- 1 | from .clip import * 2 | -------------------------------------------------------------------------------- /third_party/CLIP/clip/bpe_simple_vocab_16e6.txt.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/ov-seg/36f49d496714998058d115ffb6172d9d84c59065/third_party/CLIP/clip/bpe_simple_vocab_16e6.txt.gz -------------------------------------------------------------------------------- /third_party/CLIP/clip/simple_tokenizer.py: -------------------------------------------------------------------------------- 1 | import gzip 2 | import html 3 | import os 4 | from functools import lru_cache 5 | 6 | import ftfy 7 | import regex as re 8 | 9 | 10 | @lru_cache() 11 | def default_bpe(): 12 | return os.path.join( 13 | os.path.dirname(os.path.abspath(__file__)), "bpe_simple_vocab_16e6.txt.gz" 14 | ) 15 | 16 | 17 | @lru_cache() 18 | def bytes_to_unicode(): 19 | """ 20 | Returns list of utf-8 byte and a corresponding list of unicode strings. 21 | The reversible bpe codes work on unicode strings. 22 | This means you need a large # of unicode characters in your vocab if you want to avoid UNKs. 23 | When you're at something like a 10B token dataset you end up needing around 5K for decent coverage. 24 | This is a signficant percentage of your normal, say, 32K bpe vocab. 25 | To avoid that, we want lookup tables between utf-8 bytes and unicode strings. 26 | And avoids mapping to whitespace/control characters the bpe code barfs on. 27 | """ 28 | bs = ( 29 | list(range(ord("!"), ord("~") + 1)) 30 | + list(range(ord("¡"), ord("¬") + 1)) 31 | + list(range(ord("®"), ord("ÿ") + 1)) 32 | ) 33 | cs = bs[:] 34 | n = 0 35 | for b in range(2 ** 8): 36 | if b not in bs: 37 | bs.append(b) 38 | cs.append(2 ** 8 + n) 39 | n += 1 40 | cs = [chr(n) for n in cs] 41 | return dict(zip(bs, cs)) 42 | 43 | 44 | def get_pairs(word): 45 | """Return set of symbol pairs in a word. 46 | Word is represented as tuple of symbols (symbols being variable-length strings). 
47 | """ 48 | pairs = set() 49 | prev_char = word[0] 50 | for char in word[1:]: 51 | pairs.add((prev_char, char)) 52 | prev_char = char 53 | return pairs 54 | 55 | 56 | def basic_clean(text): 57 | text = ftfy.fix_text(text) 58 | text = html.unescape(html.unescape(text)) 59 | return text.strip() 60 | 61 | 62 | def whitespace_clean(text): 63 | text = re.sub(r"\s+", " ", text) 64 | text = text.strip() 65 | return text 66 | 67 | 68 | class SimpleTokenizer(object): 69 | def __init__(self, bpe_path: str = default_bpe()): 70 | self.byte_encoder = bytes_to_unicode() 71 | self.byte_decoder = {v: k for k, v in self.byte_encoder.items()} 72 | merges = gzip.open(bpe_path).read().decode("utf-8").split("\n") 73 | merges = merges[1 : 49152 - 256 - 2 + 1] 74 | merges = [tuple(merge.split()) for merge in merges] 75 | vocab = list(bytes_to_unicode().values()) 76 | vocab = vocab + [v + "" for v in vocab] 77 | for merge in merges: 78 | vocab.append("".join(merge)) 79 | vocab.extend(["<|startoftext|>", "<|endoftext|>"]) 80 | self.encoder = dict(zip(vocab, range(len(vocab)))) 81 | self.decoder = {v: k for k, v in self.encoder.items()} 82 | self.bpe_ranks = dict(zip(merges, range(len(merges)))) 83 | self.cache = { 84 | "<|startoftext|>": "<|startoftext|>", 85 | "<|endoftext|>": "<|endoftext|>", 86 | } 87 | self.pat = re.compile( 88 | r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", 89 | re.IGNORECASE, 90 | ) 91 | 92 | def bpe(self, token): 93 | if token in self.cache: 94 | return self.cache[token] 95 | word = tuple(token[:-1]) + (token[-1] + "",) 96 | pairs = get_pairs(word) 97 | 98 | if not pairs: 99 | return token + "" 100 | 101 | while True: 102 | bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float("inf"))) 103 | if bigram not in self.bpe_ranks: 104 | break 105 | first, second = bigram 106 | new_word = [] 107 | i = 0 108 | while i < len(word): 109 | try: 110 | j = word.index(first, i) 111 | new_word.extend(word[i:j]) 112 | i = j 113 | except: 114 | new_word.extend(word[i:]) 115 | break 116 | 117 | if word[i] == first and i < len(word) - 1 and word[i + 1] == second: 118 | new_word.append(first + second) 119 | i += 2 120 | else: 121 | new_word.append(word[i]) 122 | i += 1 123 | new_word = tuple(new_word) 124 | word = new_word 125 | if len(word) == 1: 126 | break 127 | else: 128 | pairs = get_pairs(word) 129 | word = " ".join(word) 130 | self.cache[token] = word 131 | return word 132 | 133 | def encode(self, text): 134 | bpe_tokens = [] 135 | text = whitespace_clean(basic_clean(text)).lower() 136 | for token in re.findall(self.pat, text): 137 | token = "".join(self.byte_encoder[b] for b in token.encode("utf-8")) 138 | bpe_tokens.extend( 139 | self.encoder[bpe_token] for bpe_token in self.bpe(token).split(" ") 140 | ) 141 | return bpe_tokens 142 | 143 | def decode(self, tokens): 144 | text = "".join([self.decoder[token] for token in tokens]) 145 | text = ( 146 | bytearray([self.byte_decoder[c] for c in text]) 147 | .decode("utf-8", errors="replace") 148 | .replace("", " ") 149 | ) 150 | return text 151 | -------------------------------------------------------------------------------- /third_party/CLIP/requirements.txt: -------------------------------------------------------------------------------- 1 | ftfy 2 | regex 3 | tqdm 4 | torch 5 | torchvision 6 | -------------------------------------------------------------------------------- /third_party/CLIP/setup.py: -------------------------------------------------------------------------------- 1 | 
import os 2 | 3 | import pkg_resources 4 | from setuptools import setup, find_packages 5 | 6 | setup( 7 | name="clip", 8 | py_modules=["clip"], 9 | version="1.0", 10 | description="", 11 | author="OpenAI", 12 | packages=find_packages(exclude=["tests*"]), 13 | install_requires=[ 14 | str(r) 15 | for r in pkg_resources.parse_requirements( 16 | open(os.path.join(os.path.dirname(__file__), "requirements.txt")) 17 | ) 18 | ], 19 | include_package_data=True, 20 | extras_require={"dev": ["pytest"]}, 21 | ) 22 | -------------------------------------------------------------------------------- /third_party/CLIP/tests/test_consistency.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pytest 3 | import torch 4 | from PIL import Image 5 | 6 | import clip 7 | 8 | 9 | @pytest.mark.parametrize("model_name", clip.available_models()) 10 | def test_consistency(model_name): 11 | device = "cpu" 12 | jit_model, transform = clip.load(model_name, device=device, jit=True) 13 | py_model, _ = clip.load(model_name, device=device, jit=False) 14 | 15 | image = transform(Image.open("CLIP.png")).unsqueeze(0).to(device) 16 | text = clip.tokenize(["a diagram", "a dog", "a cat"]).to(device) 17 | 18 | with torch.no_grad(): 19 | logits_per_image, _ = jit_model(image, text) 20 | jit_probs = logits_per_image.softmax(dim=-1).cpu().numpy() 21 | 22 | logits_per_image, _ = py_model(image, text) 23 | py_probs = logits_per_image.softmax(dim=-1).cpu().numpy() 24 | 25 | assert np.allclose(jit_probs, py_probs, atol=0.01, rtol=0.1) 26 | -------------------------------------------------------------------------------- /tools/convert-pretrained-clip-model-to-d2.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # Copyright (c) Meta Platforms, Inc. 
All Rights Reserved 3 | 4 | import pickle as pkl 5 | import sys 6 | 7 | import torch 8 | 9 | """ 10 | Usage: 11 | # download pretrained swin model: 12 | wget https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_tiny_patch4_window7_224.pth 13 | # run the conversion 14 | ./convert-pretrained-model-to-d2.py swin_tiny_patch4_window7_224.pth swin_tiny_patch4_window7_224.pkl 15 | # Then, use swin_tiny_patch4_window7_224.pkl with the following changes in config: 16 | MODEL: 17 | WEIGHTS: "/path/to/swin_tiny_patch4_window7_224.pkl" 18 | INPUT: 19 | FORMAT: "RGB" 20 | """ 21 | 22 | 23 | def transform(path): 24 | model = torch.load(path, map_location="cpu") 25 | print(f"loading {path}......") 26 | state_dict = model["model"] 27 | state_dict = { 28 | k.replace("visual_model.", ""): v 29 | for k, v in state_dict.items() 30 | if k.startswith("visual_model") 31 | } 32 | source_keys = [k for k in state_dict.keys() if "relative_coords" in k] 33 | for k in source_keys: 34 | state_dict[ 35 | k.replace("relative_coords", "relative_position_index") 36 | ] = state_dict[k] 37 | del state_dict[k] 38 | 39 | source_keys = [k for k in state_dict.keys() if "atten_mask_matrix" in k] 40 | for k in source_keys: 41 | state_dict[k.replace("atten_mask_matrix", "attn_mask")] = state_dict[k] 42 | del state_dict[k] 43 | 44 | source_keys = [k for k in state_dict.keys() if "rel_pos_embed_table" in k] 45 | for k in source_keys: 46 | state_dict[ 47 | k.replace("rel_pos_embed_table", "relative_position_bias_table") 48 | ] = state_dict[k] 49 | del state_dict[k] 50 | 51 | source_keys = [k for k in state_dict.keys() if "channel_reduction" in k] 52 | for k in source_keys: 53 | state_dict[k.replace("channel_reduction", "reduction")] = state_dict[k] 54 | del state_dict[k] 55 | return { 56 | k if k.startswith("backbone.") else "backbone." + k: v 57 | for k, v in state_dict.items() 58 | } 59 | 60 | 61 | if __name__ == "__main__": 62 | input = sys.argv[1] 63 | res = { 64 | "model": transform(input), 65 | "__author__": "third_party", 66 | "matching_heuristics": True, 67 | } 68 | with open(sys.argv[2], "wb") as f: 69 | pkl.dump(res, f) 70 | -------------------------------------------------------------------------------- /tools/convert-pretrained-swin-model-to-d2.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # Copyright (c) Meta Platforms, Inc. 
All Rights Reserved 3 | 4 | import pickle as pkl 5 | import sys 6 | 7 | import torch 8 | 9 | """ 10 | Usage: 11 | # download pretrained swin model: 12 | wget https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_tiny_patch4_window7_224.pth 13 | # run the conversion 14 | ./convert-pretrained-model-to-d2.py swin_tiny_patch4_window7_224.pth swin_tiny_patch4_window7_224.pkl 15 | # Then, use swin_tiny_patch4_window7_224.pkl with the following changes in config: 16 | MODEL: 17 | WEIGHTS: "/path/to/swin_tiny_patch4_window7_224.pkl" 18 | INPUT: 19 | FORMAT: "RGB" 20 | """ 21 | 22 | if __name__ == "__main__": 23 | input = sys.argv[1] 24 | 25 | obj = torch.load(input, map_location="cpu")["model"] 26 | 27 | res = {"model": obj, "__author__": "third_party", "matching_heuristics": True} 28 | 29 | with open(sys.argv[2], "wb") as f: 30 | pkl.dump(res, f) 31 | -------------------------------------------------------------------------------- /tools/convert-torchvision-to-d2.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # Copyright (c) Meta Platforms, Inc. All Rights Reserved 3 | 4 | import pickle as pkl 5 | import sys 6 | 7 | import torch 8 | 9 | """ 10 | Usage: 11 | # download one of the ResNet{18,34,50,101,152} models from torchvision: 12 | wget https://download.pytorch.org/models/resnet50-19c8e357.pth -O r50.pth 13 | # run the conversion 14 | ./convert-torchvision-to-d2.py r50.pth r50.pkl 15 | # Then, use r50.pkl with the following changes in config: 16 | MODEL: 17 | WEIGHTS: "/path/to/r50.pkl" 18 | PIXEL_MEAN: [123.675, 116.280, 103.530] 19 | PIXEL_STD: [58.395, 57.120, 57.375] 20 | RESNETS: 21 | DEPTH: 50 22 | STRIDE_IN_1X1: False 23 | INPUT: 24 | FORMAT: "RGB" 25 | These models typically produce slightly worse results than the 26 | pre-trained ResNets we use in official configs, which are the 27 | original ResNet models released by MSRA. 28 | """ 29 | 30 | if __name__ == "__main__": 31 | input = sys.argv[1] 32 | 33 | obj = torch.load(input, map_location="cpu") 34 | 35 | newmodel = {} 36 | for k in list(obj.keys()): 37 | old_k = k 38 | if "layer" not in k: 39 | k = "stem." + k 40 | for t in [1, 2, 3, 4]: 41 | k = k.replace("layer{}".format(t), "res{}".format(t + 1)) 42 | for t in [1, 2, 3]: 43 | k = k.replace("bn{}".format(t), "conv{}.norm".format(t)) 44 | k = k.replace("downsample.0", "shortcut") 45 | k = k.replace("downsample.1", "shortcut.norm") 46 | print(old_k, "->", k) 47 | newmodel[k] = obj.pop(old_k).detach().numpy() 48 | 49 | res = {"model": newmodel, "__author__": "torchvision", "matching_heuristics": True} 50 | 51 | with open(sys.argv[2], "wb") as f: 52 | pkl.dump(res, f) 53 | if obj: 54 | print("Unconverted keys:", obj.keys()) 55 | -------------------------------------------------------------------------------- /tools/ovseg_replace_clip.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # Copyright (c) Meta Platforms, Inc. 
All Rights Reserved 3 | 4 | import torch 5 | from collections import OrderedDict 6 | 7 | 8 | # PATH to finetuned CLIP model 9 | clip_ckpt = torch.load('/home/jeffliang/ov-seg/open_clip_training/src/logs/2023_05_29-16_27_38-mask_prompt_tuning-model_ViT-L-14-lr_0.05-b_32-j_4-p_amp/checkpoints/epoch_2.pt') 10 | 11 | new_model = OrderedDict() 12 | state_dict = clip_ckpt['state_dict'] 13 | 14 | for k, v in state_dict.items(): 15 | new_key = k.replace('module.','') 16 | new_model[new_key] = v 17 | 18 | # PATH to trained MaskFormer model 19 | ovseg_model = torch.load('/home/jeffliang/ov-seg/weights/ovseg_swinbase_vitL14_mpt_only.pth', 'cpu') 20 | 21 | for k, v in new_model.items(): 22 | new_k = 'clip_adapter.clip_model.' + k 23 | if new_k in ovseg_model['model'].keys(): 24 | ovseg_model['model'][new_k] = v 25 | else: 26 | print(f'{new_k} does not exist in ckpt') 27 | try: 28 | ovseg_model['model']['clip_adapter.clip_model.visual.mask_embedding'] = new_model['visual.mask_embedding'] 29 | print('clip_ckpt has mask_embedding, remember to set MODEL.CLIP_ADAPTER.MASK_PROMPT_FWD True during OVSeg evaluation') 30 | except: 31 | print('clip_ckpt does not have mask_embedding, remember to set MODEL.CLIP_ADAPTER.MASK_PROMPT_FWD False during OVSeg evaluation') 32 | 33 | torch.save(ovseg_model, '/home/jeffliang/ov-seg/weights/new_ovseg.pth') 34 | -------------------------------------------------------------------------------- /tools/sanity_check_ft_clip_weights.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from collections import OrderedDict 3 | 4 | # PATH to trained MaskFormer model (containing a CLIP classifier) 5 | ovseg_model = torch.load('/home/jeffliang/ov-seg/weights/ovseg_swinbase_vitL14_mpt_only.pth', 'cpu') 6 | 7 | # PATH to finetuned CLIP weights 8 | clip = OrderedDict() 9 | new_clip_model = torch.load('/home/jeffliang/ov-seg/open_clip_training/src/logs/2023_05_28-22_15_08-model_ViT-L-14-lr_5e-06-b_32-j_4-p_amp/checkpoints/epoch_1.pt', 'cpu') 10 | 11 | new_clip_model = new_clip_model['state_dict'] 12 | new_clip = OrderedDict() 13 | 14 | 15 | for k, v in new_clip_model.items(): 16 | k = k.replace('module.', '') 17 | # if not k == 'visual.mask_embedding': 18 | new_clip[k] = v 19 | new_k = 'clip_adapter.clip_model.' + k 20 | if new_k in ovseg_model['model'].keys(): 21 | clip[k] = ovseg_model['model'][new_k] 22 | else: 23 | print(f'{new_k} does not exist in ckpt') 24 | 25 | for c_p, new_c_p in zip(clip.items(), new_clip.items()): 26 | k1, v1 = c_p 27 | k2, v2 = new_c_p 28 | assert k1 == k2 29 | diff = (v1-v2).abs().mean() 30 | print(f'{k1} difference {diff}') -------------------------------------------------------------------------------- /tools/search_thr_ensemble_w.sh: -------------------------------------------------------------------------------- 1 | for MASK_THR in 0.35 0.4 0.45 2 | do 3 | for ENSEMBLE_WEIGHT in 0.6 0.65 0.7 0.75 0.8 4 | do 5 | python train_net.py --num-gpu 8 --eval-only --config-file configs/ovseg_swinB_vitL_bs32_120k.yaml \ 6 | MODEL.WEIGHTS #PATH_of_ovseg_swinbase_vitL14_ft_mpt.pth DATASETS.TEST \(\"ade20k_sem_seg_val\"\) \ 7 | MODEL.CLIP_ADAPTER.CLIP_ENSEMBLE_WEIGHT $ENSEMBLE_WEIGHT MODEL.CLIP_ADAPTER.MASK_THR $MASK_THR 8 | done 9 | done 10 | 11 | 12 | -------------------------------------------------------------------------------- /tools/web_demo.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | # Copyright (c) Meta Platforms, Inc. All Rights Reserved 3 | 4 | import multiprocessing as mp 5 | 6 | import numpy as np 7 | from PIL import Image 8 | 9 | from detectron2.config import get_cfg 10 | 11 | from detectron2.projects.deeplab import add_deeplab_config 12 | from detectron2.data.detection_utils import read_image 13 | from open_vocab_seg import add_ovseg_config 14 | from open_vocab_seg.utils import VisualizationDemo 15 | 16 | import gradio as gr 17 | 18 | def setup_cfg(config_file): 19 | # load config from file and command-line arguments 20 | cfg = get_cfg() 21 | add_deeplab_config(cfg) 22 | add_ovseg_config(cfg) 23 | cfg.merge_from_file(config_file) 24 | cfg.freeze() 25 | return cfg 26 | 27 | 28 | def inference(class_names, input_img): 29 | mp.set_start_method("spawn", force=True) 30 | config_file = './configs/ovseg_swinB_vitL_demo.yaml' 31 | cfg = setup_cfg(config_file) 32 | 33 | demo = VisualizationDemo(cfg) 34 | 35 | class_names = class_names.split(',') 36 | img = read_image(input_img, format="BGR") 37 | _, visualized_output = demo.run_on_image(img, class_names) 38 | 39 | return Image.fromarray(np.uint8(visualized_output.get_image())).convert('RGB') 40 | 41 | # demo = gr.Interface(fn=greet, inputs="text", outputs="text") 42 | # demo.launch() 43 | 44 | 45 | examples = [['Oculus, Ukulele', './resources/demo_samples/sample_03.jpeg'],] 46 | output_labels = ['segmentation map'] 47 | 48 | title = 'OVSeg' 49 | 50 | description = """ 51 | Gradio Demo for Open-Vocabulary Semantic Segmentation with Mask-adapted CLIP \n 52 | You may click one of the examples or upload your own image. \n 53 | OVSeg can perform open-vocabulary segmentation; you may input more classes (separated by commas). 54 | """ 55 | 56 | article = """ 57 |
58 | 59 | Open-Vocabulary Semantic Segmentation with Mask-adapted CLIP 60 | 61 | | 62 | Github Repo
63 | """ 64 | 65 | gr.Interface( 66 | inference, 67 | inputs=[ 68 | gr.inputs.Textbox( 69 | lines=1, placeholder=None, default='', label='class names'), 70 | gr.inputs.Image(type='filepath') 71 | ], 72 | outputs=gr.outputs.Image(label='segmentation map'), 73 | title=title, 74 | description=description, 75 | article=article, 76 | examples=examples).launch(enable_queue=True) 77 | --------------------------------------------------------------------------------