├── .gitignore ├── LICENSE ├── README.md ├── config ├── __init__.py ├── clip.yaml ├── cluster.yaml ├── insseg.yaml └── sam.yaml ├── data ├── resources │ ├── clip_model.png │ ├── imagenet-classes.txt │ ├── obj365-classes.txt │ ├── sam_model.png │ ├── test_airplane_insseg_result.jpg │ ├── test_baseball_insseg_result.jpg │ ├── test_bear_insseg_result.jpg │ ├── test_bear_result.jpg │ ├── test_bridge_insseg_result.jpg │ ├── test_dog_insseg_result.jpg │ ├── test_fish_insseg_result.jpg │ ├── test_frog_insseg_result_multi_label.jpg │ ├── test_glasses_insseg_result.jpg │ ├── test_horse_insseg_result.jpg │ ├── test_horse_insseg_result_after.jpg │ ├── test_horse_insseg_result_muti_label.jpg │ ├── test_horse_result.jpg │ ├── test_shoes_insseg_result.jpg │ ├── test_strawberry_insseg_result.jpg │ ├── test_strawberry_insseg_result_multi_label.jpg │ ├── test_tv_insseg_result.jpg │ └── test_tv_insseg_result_multi_label.jpg └── test_images │ ├── test_airplane.jpg │ ├── test_baseball.jpg │ ├── test_bear.jpg │ ├── test_bench.jpg │ ├── test_bridge.jpg │ ├── test_dog.jpg │ ├── test_fish.jpg │ ├── test_frog.jpg │ ├── test_glasses.jpg │ ├── test_horse.jpg │ ├── test_road.jpg │ ├── test_shoes.jpg │ ├── test_strawberry.jpg │ ├── test_tv.jpg │ └── test_vatetable.jpg ├── local_utils ├── __init__.py ├── __pycache__ │ └── __init__.cpython-38.pyc ├── config_utils │ ├── __init__.py │ ├── __pycache__ │ │ ├── __init__.cpython-38.pyc │ │ └── parse_config_utils.cpython-38.pyc │ └── parse_config_utils.py └── log_util │ ├── __init__.py │ └── init_logger.py ├── models ├── __init__.py ├── clip │ ├── __init__.py │ ├── bpe_simple_vocab_16e6.txt.gz │ ├── clip.py │ ├── model.py │ └── simple_tokenizer.py ├── cluster │ ├── __init__.py │ ├── cluster_model.py │ └── utils.py ├── detector │ ├── __init__.py │ ├── insseg_model.py │ └── utils.py └── sam │ ├── __init__.py │ ├── automatic_mask_generator.py │ ├── build_sam.py │ ├── modeling │ ├── __init__.py │ ├── common.py │ ├── image_encoder.py │ ├── mask_decoder.py │ ├── prompt_encoder.py │ ├── sam.py │ ├── tiny_vit_sam.py │ └── transformer.py │ ├── predictor.py │ └── utils │ ├── __init__.py │ ├── __pycache__ │ ├── __init__.cpython-38.pyc │ ├── amg.cpython-38.pyc │ └── transforms.cpython-38.pyc │ ├── amg.py │ ├── onnx.py │ └── transforms.py ├── output └── cluster │ ├── test_baseball.jpg │ ├── test_baseball_clustered_mask.png │ ├── test_baseball_clustered_mask_add.png │ ├── test_baseball_ori_mask.png │ ├── test_baseball_ori_mask_add.png │ ├── test_bear.jpg │ ├── test_bear_clustered_mask.png │ ├── test_bear_clustered_mask_add.png │ ├── test_bear_ori_mask.png │ ├── test_bear_ori_mask_add.png │ ├── test_horse.jpg │ ├── test_horse_clustered_mask.png │ ├── test_horse_clustered_mask_add.png │ ├── test_horse_ori_mask.png │ └── test_horse_ori_mask_add.png ├── requirements.txt ├── scripts └── download_pretrained_ckpt.sh └── tools ├── __init__.py ├── cluster_sam.py └── sam_clip_text_seg.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 
32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | 131 | .idea 132 | log/* 133 | output/ -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 MaybeShewill-CV 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Segment-Anything-U-Specify 2 | Use the SAM and CLIP models to segment the specific instances you want. 3 | You may use this repo to segment any instance in a picture with 4 | text prompts. 5 | 6 | The main network architecture is as follows: 7 | 8 | `CLIP Model Architecture` 9 | ![CLIP_MODEL](./data/resources/clip_model.png) 10 | 11 | `SAM Model Architecture` 12 | ![SAM](./data/resources/sam_model.png) 13 | 14 | ## Installation 15 | 16 | Install the required Python packages: 17 | ``` 18 | pip3 install -r requirements.txt 19 | ``` 20 | Download the pretrained model weights: 21 | ``` 22 | cd PROJECT_ROOT_DIR 23 | bash scripts/download_pretrained_ckpt.sh 24 | ``` 25 | 26 | ## Instance Segmentation With Text Prompts 27 | The instance segmentor first uses the SAM model to generate masks for every object in the input image. It then uses the CLIP model to classify each mask by matching its 28 | image features against your text prompt features. 29 | 30 | ``` 31 | cd PROJECT_ROOT_DIR 32 | export PYTHONPATH=$PWD:$PYTHONPATH 33 | python tools/sam_clip_text_seg.py --input_image_path ./data/test_images/test_bear.jpg --text bear 34 | ``` 35 | 36 | `Bear Instance Segmentation Result, Text Prompt: bear` 37 | ![bear_insseg_result](./data/resources/test_bear_insseg_result.jpg) 38 | 39 | `Athlete Instance Segmentation Result, Text Prompt: athlete` 40 | ![athlete_insseg_result](./data/resources/test_baseball_insseg_result.jpg) 41 | 42 | `Horse Instance Segmentation Result, Text Prompt: horse` 43 | ![horse_insseg_result](./data/resources/test_horse_insseg_result_after.jpg) 44 | 45 | `Dog Instance Segmentation Result, Text Prompt: dog` 46 | ![dog_insseg_result](./data/resources/test_dog_insseg_result.jpg) 47 | 48 | `Fish Instance Segmentation Result, Text Prompt: fish` 49 | ![fish_insseg_result](./data/resources/test_fish_insseg_result.jpg) 50 | 51 | `Strawberry Instance Segmentation Result, Text Prompt: strawberry` 52 | ![strawberry_insseg_result](./data/resources/test_strawberry_insseg_result.jpg) 53 | 54 | `Glasses Instance Segmentation Result, Text Prompt: glasses` 55 | ![glasses_insseg_result](./data/resources/test_glasses_insseg_result.jpg) 56 | 57 | `TV Instance Segmentation Result, Text Prompt: television` 58 | ![tv_insseg_result](./data/resources/test_tv_insseg_result.jpg) 59 | 60 | `Shoes Instance Segmentation Result, Text Prompt: shoe` 61 | ![shoes_insseg_result](./data/resources/test_shoes_insseg_result.jpg) 62 | 63 | `Bridge Instance Segmentation Result, Text Prompt: bridge` 64 | ![bridge_insseg_result](./data/resources/test_bridge_insseg_result.jpg) 65 | 66 | `Airplane Instance Segmentation Result, Text Prompt: airplane` 67 | ![airplane_insseg_result](./data/resources/test_airplane_insseg_result.jpg) 68 | 69 | ### Segment Multiple Classes All At Once ---- YOSO ---- You Only Segment Once 70 | ``` 71 | cd PROJECT_ROOT_DIR 72 | export PYTHONPATH=$PWD:$PYTHONPATH 73 | python tools/sam_clip_text_seg.py --input_image_path ./data/test_images/test_horse.jpg --text "horse,mountain,grass,sky,clouds,tree" --cls_score_thresh 0.5 --use_text_prefix 74 | ``` 75 | 76 | `Horse Instance Segmentation Result, Text Prompt: horse,mountain,grass,sky,clouds,tree` 77 | ![horse_insseg_result](./data/resources/test_horse_insseg_result_muti_label.jpg) 78 | `TV Instance Segmentation Result, Text Prompt: television,audio system,tape recorder,box` 79 |
![tv_insseg_result](./data/resources/test_tv_insseg_result_multi_label.jpg) 80 | `Strawberry Instance Segmentation Result, Text Prompt: strawberry,grapefruit,spoon,wolfberry,oatmeal` 81 | ![strawberry_insseg_result](./data/resources/test_strawberry_insseg_result_multi_label.jpg) 82 | `Frog Instance Segmentation Result, Text Prompt: frog,turtle,snail,eye` 83 | ![frog_insseg_result](./data/resources/test_frog_insseg_result_multi_label.jpg) 84 | 85 | #### Instance Segmentation Improvements 86 | 87 | ##### 2023-04-21 Improved background segmentation 88 | 89 | `Before Optimization` 90 | ![before](./data/resources/test_horse_insseg_result.jpg) 91 | `After Optimization` 92 | ![after](./data/resources/test_horse_insseg_result_after.jpg) 93 | 94 | ## Unsupervised Clustering of Semantic Objects From the SAM Model 95 | The cluster first uses the SAM model to generate masks for every object in the input image. It then uses the CLIP model to extract image features for each object, computes the feature distance between every pair of objects, and finally applies a similarity threshold to cluster the source objects. 96 | 97 | To test the cluster, simply run: 98 | 99 | ``` 100 | cd PROJECT_ROOT_DIR 101 | export PYTHONPATH=$PWD:$PYTHONPATH 102 | python tools/cluster_sam.py --input_image_path ./data/test_images/test_bear.jpg --simi_thresh 0.82 103 | ``` 104 | 105 | `Bear Cluster Result` 106 | ![bear_cluster_result](./data/resources/test_bear_result.jpg) 107 | 108 | `Horse Cluster Result` 109 | ![horse_cluster_result](./data/resources/test_horse_result.jpg) 110 | 111 | Each row shows the `source image`, `sam origin mask`, `ori masked image`, `clustered mask`, and `cluster masked image`. 112 | 113 | ## UPDATES 114 | 115 | ### 2023-07-04 Integrate MobileSAM 116 | 117 | Integrated MobileSAM into the pipeline for lighter and faster inference. If you want to use MobileSAM to segment your 118 | images, all you need to do is modify the `./config/sam.yaml` file.
Set the model name field to `vit_t` and set the 119 | model weight file path to `./pretrained/sam/mobile_sam.pt`. 120 | 121 | ## TODO 122 | - [x] Test different kinds of clustering methods 123 | - [x] Use cluster results as input prompts to re-segment the image via the SAM model 124 | - [ ] Merge embedding features of the global image and the masked image 125 | 126 | ## Acknowledgement 127 | 128 | Most of the repo's code is borrowed from OpenAI's CLIP repo and Facebook's segment-anything repo: 129 | 130 | - [CLIP](https://github.com/openai/CLIP) 131 | - [segment-anything](https://github.com/facebookresearch/segment-anything) 132 | 133 | ## Star History 134 | 135 | [![Star History Chart](https://api.star-history.com/svg?repos=MaybeShewill-CV/segment-anything-u-specify&type=Date)](https://star-history.com/#MaybeShewill-CV/segment-anything-u-specify&Date) 136 | 137 | ## Visitor Count 138 | 139 | ![Visitor Count](https://profile-counter.glitch.me/15725187_sam_clip/count.svg) 140 | 141 | ## Contact 142 | 143 | -------------------------------------------------------------------------------- /config/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # @Time : 2023-04-10 下午5:34 4 | # @Author : MaybeShewill-CV 5 | # @Site : 6 | # @File : __init__.py.py 7 | # @IDE: PyCharm Community Edition 8 | -------------------------------------------------------------------------------- /config/clip.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | CKPT_DIR: './pretrained/clip' 3 | DEVICE: 'cuda' 4 | NAME: 'ViT-B/32' 5 | LOG: 6 | SAVE_DIR: './log' 7 | LEVEL: INFO 8 | -------------------------------------------------------------------------------- /config/cluster.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | DEVICE: 'cuda' 3 | SAM: 4 | CFG_PATH: './config/sam.yaml' 5 | CLIP: 6 | CFG_PATH: './config/clip.yaml' 7 | CLUSTER: 8 | MAX_INPUT_SIZE: [1280, 1280] 9 | SIMILARITY_THRESH: 0.83 10 | TOP_K_MASK_COUNT: 50 11 | LOG: 12 | SAVE_DIR: './log' 13 | LEVEL: INFO 14 | -------------------------------------------------------------------------------- /config/insseg.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | DEVICE: 'cuda' 3 | SAM: 4 | CFG_PATH: './config/sam.yaml' 5 | CLIP: 6 | CFG_PATH: './config/clip.yaml' 7 | INS_SEG: 8 | MAX_INPUT_SIZE: [1280, 1280] 9 | TOP_K_MASK_COUNT: 100 10 | CLS_SCORE_THRESH: 0.95 11 | LOG: 12 | SAVE_DIR: './log' 13 | LEVEL: INFO 14 | -------------------------------------------------------------------------------- /config/sam.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | MODEL_NAME: 'vit_h' # only supports ['vit_h', 'vit_l', 'vit_b', 'vit_t', 'default'] 3 | CKPT_PATH: './pretrained/sam/sam_vit_h_4b8939.pth' 4 | DEVICE: 'cuda' 5 | MASK_GENERATOR: 6 | PTS_PER_SIDE: 32 7 | PTS_PER_BATCH: 64 8 | PRED_IOU_THRESH: 0.88 9 | STABILITY_SCORE_THRESH: 0.95 10 | STABILITY_SCORE_OFFSET: 1.0 11 | BOX_NMS_THRESH: 0.7 12 | CROP_N_LAYERS: 0 13 | CROP_NMS_THRESH: 0.7 14 | CROP_OVERLAP_RATIO: 0.3413333 15 | CROP_N_POINTS_DOWNSCALE_FACTOR: 1 16 | POINT_GRIDS: None 17 | MIN_MASK_REGION_AERA: 200 18 | OUTPUT_MODE: 'binary_mask' 19 | LOG: 20 | SAVE_DIR: './log' 21 | LEVEL: INFO 22 | -------------------------------------------------------------------------------- /data/resources/clip_model.png:
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/MaybeShewill-CV/segment-anything-u-specify/606a6bc744fd7cc37824a32a95ec770580467428/data/resources/clip_model.png -------------------------------------------------------------------------------- /data/resources/obj365-classes.txt: -------------------------------------------------------------------------------- 1 | human 1 2 | sneakers 2 3 | chair 3 4 | hat 4 5 | lamp 5 6 | bottle 6 7 | cabinet/shelf 7 8 | cup 8 9 | car 9 10 | glasses 10 11 | picture/frame 11 12 | desk 12 13 | handbag 13 14 | street lights 14 15 | book 15 16 | plate 16 17 | helmet 17 18 | leather shoes 18 19 | pillow 19 20 | glove 20 21 | potted plant 21 22 | bracelet 22 23 | flower 23 24 | monitor 24 25 | storage box 25 26 | plants pot/vase 26 27 | bench 27 28 | wine glass 28 29 | boots 29 30 | dining table 30 31 | umbrella 31 32 | boat 32 33 | flag 33 34 | speaker 34 35 | trash bin/can 35 36 | stool 36 37 | backpack 37 38 | sofa 38 39 | belt 39 40 | carpet 40 41 | basket 41 42 | towel/napkin 42 43 | slippers 43 44 | bowl 44 45 | barrel/bucket 45 46 | coffee table 46 47 | suv 47 48 | toy 48 49 | tie 49 50 | bed 50 51 | traffic light 51 52 | pen/pencil 52 53 | microphone 53 54 | sandals 54 55 | canned 55 56 | necklace 56 57 | mirror 57 58 | faucet 58 59 | bicycle 59 60 | bread 60 61 | high heels 61 62 | ring 62 63 | van 63 64 | watch 64 65 | combine with bowl 65 66 | sink 66 67 | horse 67 68 | fish 68 69 | apple 69 70 | traffic sign 70 71 | camera 71 72 | candle 72 73 | stuffed animal 73 74 | cake 74 75 | motorbike/motorcycle 75 76 | wild bird 76 77 | laptop 77 78 | knife 78 79 | cellphone 79 80 | paddle 80 81 | truck 81 82 | cow 82 83 | power outlet 83 84 | clock 84 85 | drum 85 86 | fork 86 87 | bus 87 88 | hanger 88 89 | nightstand 89 90 | pot/pan 90 91 | sheep 91 92 | guitar 92 93 | traffic cone 93 94 | tea pot 94 95 | keyboard 95 96 | tripod 96 97 | hockey stick 97 98 | fan 98 99 | dog 99 100 | spoon 100 101 | blackboard/whiteboard 101 102 | balloon 102 103 | air conditioner 103 104 | cymbal 104 105 | mouse 105 106 | telephone 106 107 | pickup truck 107 108 | orange 108 109 | banana 109 110 | airplane 110 111 | luggage 111 112 | skis 112 113 | soccer 113 114 | trolley 114 115 | oven 115 116 | remote 116 117 | combine with glove 117 118 | paper towel 118 119 | refrigerator 119 120 | train 120 121 | tomato 121 122 | machinery vehicle 122 123 | tent 123 124 | shampoo/shower gel 124 125 | head phone 125 126 | lantern 126 127 | donut 127 128 | cleaning products 128 129 | sailboat 129 130 | tangerine 130 131 | pizza 131 132 | kite 132 133 | computer box 133 134 | elephant 134 135 | toiletries 135 136 | gas stove 136 137 | broccoli 137 138 | toilet 138 139 | stroller 139 140 | shovel 140 141 | baseball bat 141 142 | microwave 142 143 | skateboard 143 144 | surfboard 144 145 | surveillance camera 145 146 | gun 146 147 | Life saver 147 148 | cat 148 149 | lemon 149 150 | liquid soap 150 151 | zebra 151 152 | duck 152 153 | sports car 153 154 | giraffe 154 155 | pumpkin 155 156 | Accordion/keyboard/piano 156 157 | radiator 157 158 | converter 158 159 | tissue 159 160 | carrot 160 161 | washing machine 161 162 | vent 162 163 | cookies 163 164 | cutting/chopping board 164 165 | tennis racket 165 166 | candy 166 167 | skating and skiing shoes 167 168 | scissors 168 169 | folder 169 170 | baseball 170 171 | strawberry 171 172 | bow tie 172 173 | pigeon 173 174 | pepper 174 175 | coffee machine 175 176 | bathtub 176 177 
| snowboard 177 178 | suitcase 178 179 | grapes 179 180 | ladder 180 181 | pear 181 182 | american football 182 183 | basketball 183 184 | potato 184 185 | paint brush 185 186 | printer 186 187 | billiards 187 188 | fire hydrant 188 189 | goose 189 190 | projector 190 191 | sausage 191 192 | fire extinguisher 192 193 | extension cord 193 194 | facial mask 194 195 | tennis ball 195 196 | chopsticks 196 197 | Electronic stove and gas stove 197 198 | pie 198 199 | frisbee 199 200 | kettle 200 201 | hamburger 201 202 | golf club 202 203 | cucumber 203 204 | clutch 204 205 | blender 205 206 | tong 206 207 | slide 207 208 | hot dog 208 209 | toothbrush 209 210 | facial cleanser 210 211 | mango 211 212 | deer 212 213 | egg 213 214 | violin 214 215 | marker 215 216 | ship 216 217 | chicken 217 218 | onion 218 219 | ice cream 219 220 | tape 220 221 | wheelchair 221 222 | plum 222 223 | bar soap 223 224 | scale 224 225 | watermelon 225 226 | cabbage 226 227 | router/modem 227 228 | golf ball 228 229 | pine apple 229 230 | crane 230 231 | fire truck 231 232 | peach 232 233 | cello 233 234 | notepaper 234 235 | tricycle 235 236 | toaster 236 237 | helicopter 237 238 | green beans 238 239 | brush 239 240 | carriage 240 241 | cigar 241 242 | earphone 242 243 | penguin 243 244 | hurdle 244 245 | swing 245 246 | radio 246 247 | CD 247 248 | parking meter 248 249 | swan 249 250 | garlic 250 251 | french fries 251 252 | horn 252 253 | avocado 253 254 | saxophone 254 255 | trumpet 255 256 | sandwich 256 257 | cue 257 258 | kiwi fruit 258 259 | bear 259 260 | fishing rod 260 261 | cherry 261 262 | tablet 262 263 | green vegetables 263 264 | nuts 264 265 | corn 265 266 | key 266 267 | screwdriver 267 268 | globe 268 269 | broom 269 270 | pliers 270 271 | hammer 271 272 | volleyball 272 273 | eggplant 273 274 | trophy 274 275 | board eraser 275 276 | dates 276 277 | rice 277 278 | tape measure/ruler 278 279 | dumbbell 279 280 | hamimelon 280 281 | stapler 281 282 | camel 282 283 | lettuce 283 284 | goldfish 284 285 | meat balls 285 286 | medal 286 287 | toothpaste 287 288 | antelope 288 289 | shrimp 289 290 | rickshaw 290 291 | trombone 291 292 | pomegranate 292 293 | coconut 293 294 | jellyfish 294 295 | mushroom 295 296 | calculator 296 297 | treadmill 297 298 | butterfly 298 299 | egg tart 299 300 | cheese 300 301 | pomelo 301 302 | pig 302 303 | race car 303 304 | rice cooker 304 305 | tuba 305 306 | crosswalk sign 306 307 | papaya 307 308 | hair dryer 308 309 | green onion 309 310 | chips 310 311 | dolphin 311 312 | sushi 312 313 | urinal 313 314 | donkey 314 315 | electric drill 315 316 | spring rolls 316 317 | tortoise/turtle 317 318 | parrot 318 319 | flute 319 320 | measuring cup 320 321 | shark 321 322 | steak 322 323 | poker card 323 324 | binoculars 324 325 | llama 325 326 | radish 326 327 | noodles 327 328 | mop 328 329 | yak 329 330 | crab 330 331 | microscope 331 332 | barbell 332 333 | Bread/bun 333 334 | baozi 334 335 | lion 335 336 | red cabbage 336 337 | polar bear 337 338 | lighter 338 339 | mangosteen 339 340 | seal 340 341 | comb 341 342 | eraser 342 343 | pitaya 343 344 | scallop 344 345 | pencil case 345 346 | saw 346 347 | table tennis paddle 347 348 | okra 348 349 | starfish 349 350 | monkey 350 351 | eagle 351 352 | durian 352 353 | rabbit 353 354 | game board 354 355 | french horn 355 356 | ambulance 356 357 | asparagus 357 358 | hoverboard 358 359 | pasta 359 360 | target 360 361 | hotair balloon 361 362 | chainsaw 362 363 | lobster 363 364 | iron 364 365 | flashlight 365 366 | 
background 366 -------------------------------------------------------------------------------- /data/resources/sam_model.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MaybeShewill-CV/segment-anything-u-specify/606a6bc744fd7cc37824a32a95ec770580467428/data/resources/sam_model.png -------------------------------------------------------------------------------- /data/resources/test_airplane_insseg_result.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MaybeShewill-CV/segment-anything-u-specify/606a6bc744fd7cc37824a32a95ec770580467428/data/resources/test_airplane_insseg_result.jpg -------------------------------------------------------------------------------- /data/resources/test_baseball_insseg_result.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MaybeShewill-CV/segment-anything-u-specify/606a6bc744fd7cc37824a32a95ec770580467428/data/resources/test_baseball_insseg_result.jpg -------------------------------------------------------------------------------- /data/resources/test_bear_insseg_result.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MaybeShewill-CV/segment-anything-u-specify/606a6bc744fd7cc37824a32a95ec770580467428/data/resources/test_bear_insseg_result.jpg -------------------------------------------------------------------------------- /data/resources/test_bear_result.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MaybeShewill-CV/segment-anything-u-specify/606a6bc744fd7cc37824a32a95ec770580467428/data/resources/test_bear_result.jpg -------------------------------------------------------------------------------- /data/resources/test_bridge_insseg_result.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MaybeShewill-CV/segment-anything-u-specify/606a6bc744fd7cc37824a32a95ec770580467428/data/resources/test_bridge_insseg_result.jpg -------------------------------------------------------------------------------- /data/resources/test_dog_insseg_result.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MaybeShewill-CV/segment-anything-u-specify/606a6bc744fd7cc37824a32a95ec770580467428/data/resources/test_dog_insseg_result.jpg -------------------------------------------------------------------------------- /data/resources/test_fish_insseg_result.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MaybeShewill-CV/segment-anything-u-specify/606a6bc744fd7cc37824a32a95ec770580467428/data/resources/test_fish_insseg_result.jpg -------------------------------------------------------------------------------- /data/resources/test_frog_insseg_result_multi_label.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MaybeShewill-CV/segment-anything-u-specify/606a6bc744fd7cc37824a32a95ec770580467428/data/resources/test_frog_insseg_result_multi_label.jpg -------------------------------------------------------------------------------- /data/resources/test_glasses_insseg_result.jpg: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/MaybeShewill-CV/segment-anything-u-specify/606a6bc744fd7cc37824a32a95ec770580467428/data/resources/test_glasses_insseg_result.jpg -------------------------------------------------------------------------------- /data/resources/test_horse_insseg_result.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MaybeShewill-CV/segment-anything-u-specify/606a6bc744fd7cc37824a32a95ec770580467428/data/resources/test_horse_insseg_result.jpg -------------------------------------------------------------------------------- /data/resources/test_horse_insseg_result_after.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MaybeShewill-CV/segment-anything-u-specify/606a6bc744fd7cc37824a32a95ec770580467428/data/resources/test_horse_insseg_result_after.jpg -------------------------------------------------------------------------------- /data/resources/test_horse_insseg_result_muti_label.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MaybeShewill-CV/segment-anything-u-specify/606a6bc744fd7cc37824a32a95ec770580467428/data/resources/test_horse_insseg_result_muti_label.jpg -------------------------------------------------------------------------------- /data/resources/test_horse_result.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MaybeShewill-CV/segment-anything-u-specify/606a6bc744fd7cc37824a32a95ec770580467428/data/resources/test_horse_result.jpg -------------------------------------------------------------------------------- /data/resources/test_shoes_insseg_result.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MaybeShewill-CV/segment-anything-u-specify/606a6bc744fd7cc37824a32a95ec770580467428/data/resources/test_shoes_insseg_result.jpg -------------------------------------------------------------------------------- /data/resources/test_strawberry_insseg_result.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MaybeShewill-CV/segment-anything-u-specify/606a6bc744fd7cc37824a32a95ec770580467428/data/resources/test_strawberry_insseg_result.jpg -------------------------------------------------------------------------------- /data/resources/test_strawberry_insseg_result_multi_label.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MaybeShewill-CV/segment-anything-u-specify/606a6bc744fd7cc37824a32a95ec770580467428/data/resources/test_strawberry_insseg_result_multi_label.jpg -------------------------------------------------------------------------------- /data/resources/test_tv_insseg_result.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MaybeShewill-CV/segment-anything-u-specify/606a6bc744fd7cc37824a32a95ec770580467428/data/resources/test_tv_insseg_result.jpg -------------------------------------------------------------------------------- /data/resources/test_tv_insseg_result_multi_label.jpg: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/MaybeShewill-CV/segment-anything-u-specify/606a6bc744fd7cc37824a32a95ec770580467428/data/resources/test_tv_insseg_result_multi_label.jpg -------------------------------------------------------------------------------- /data/test_images/test_airplane.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MaybeShewill-CV/segment-anything-u-specify/606a6bc744fd7cc37824a32a95ec770580467428/data/test_images/test_airplane.jpg -------------------------------------------------------------------------------- /data/test_images/test_baseball.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MaybeShewill-CV/segment-anything-u-specify/606a6bc744fd7cc37824a32a95ec770580467428/data/test_images/test_baseball.jpg -------------------------------------------------------------------------------- /data/test_images/test_bear.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MaybeShewill-CV/segment-anything-u-specify/606a6bc744fd7cc37824a32a95ec770580467428/data/test_images/test_bear.jpg -------------------------------------------------------------------------------- /data/test_images/test_bench.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MaybeShewill-CV/segment-anything-u-specify/606a6bc744fd7cc37824a32a95ec770580467428/data/test_images/test_bench.jpg -------------------------------------------------------------------------------- /data/test_images/test_bridge.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MaybeShewill-CV/segment-anything-u-specify/606a6bc744fd7cc37824a32a95ec770580467428/data/test_images/test_bridge.jpg -------------------------------------------------------------------------------- /data/test_images/test_dog.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MaybeShewill-CV/segment-anything-u-specify/606a6bc744fd7cc37824a32a95ec770580467428/data/test_images/test_dog.jpg -------------------------------------------------------------------------------- /data/test_images/test_fish.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MaybeShewill-CV/segment-anything-u-specify/606a6bc744fd7cc37824a32a95ec770580467428/data/test_images/test_fish.jpg -------------------------------------------------------------------------------- /data/test_images/test_frog.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MaybeShewill-CV/segment-anything-u-specify/606a6bc744fd7cc37824a32a95ec770580467428/data/test_images/test_frog.jpg -------------------------------------------------------------------------------- /data/test_images/test_glasses.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MaybeShewill-CV/segment-anything-u-specify/606a6bc744fd7cc37824a32a95ec770580467428/data/test_images/test_glasses.jpg -------------------------------------------------------------------------------- /data/test_images/test_horse.jpg: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/MaybeShewill-CV/segment-anything-u-specify/606a6bc744fd7cc37824a32a95ec770580467428/data/test_images/test_horse.jpg -------------------------------------------------------------------------------- /data/test_images/test_road.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MaybeShewill-CV/segment-anything-u-specify/606a6bc744fd7cc37824a32a95ec770580467428/data/test_images/test_road.jpg -------------------------------------------------------------------------------- /data/test_images/test_shoes.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MaybeShewill-CV/segment-anything-u-specify/606a6bc744fd7cc37824a32a95ec770580467428/data/test_images/test_shoes.jpg -------------------------------------------------------------------------------- /data/test_images/test_strawberry.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MaybeShewill-CV/segment-anything-u-specify/606a6bc744fd7cc37824a32a95ec770580467428/data/test_images/test_strawberry.jpg -------------------------------------------------------------------------------- /data/test_images/test_tv.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MaybeShewill-CV/segment-anything-u-specify/606a6bc744fd7cc37824a32a95ec770580467428/data/test_images/test_tv.jpg -------------------------------------------------------------------------------- /data/test_images/test_vatetable.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MaybeShewill-CV/segment-anything-u-specify/606a6bc744fd7cc37824a32a95ec770580467428/data/test_images/test_vatetable.jpg -------------------------------------------------------------------------------- /local_utils/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # @Time : 2023-04-10 下午5:35 4 | # @Author : MaybeShewill-CV 5 | # @Site : 6 | # @File : __init__.py.py 7 | # @IDE: PyCharm Community Edition 8 | -------------------------------------------------------------------------------- /local_utils/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MaybeShewill-CV/segment-anything-u-specify/606a6bc744fd7cc37824a32a95ec770580467428/local_utils/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /local_utils/config_utils/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | # @Time : 2019/12/16 下午1:24 4 | # @Author : MaybeShewill-CV 5 | # @Site : https://github.com/MaybeShewill-CV/bisenetv2-tensorflow 6 | # @File : __init__.py.py 7 | # @IDE: PyCharm -------------------------------------------------------------------------------- /local_utils/config_utils/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MaybeShewill-CV/segment-anything-u-specify/606a6bc744fd7cc37824a32a95ec770580467428/local_utils/config_utils/__pycache__/__init__.cpython-38.pyc 
-------------------------------------------------------------------------------- /local_utils/config_utils/__pycache__/parse_config_utils.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MaybeShewill-CV/segment-anything-u-specify/606a6bc744fd7cc37824a32a95ec770580467428/local_utils/config_utils/__pycache__/parse_config_utils.cpython-38.pyc -------------------------------------------------------------------------------- /local_utils/config_utils/parse_config_utils.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | # @Time : 2019/12/13 上午11:17 4 | # @Author : PaddlePaddle 5 | # @Site : https://github.com/PaddlePaddle/PaddleSeg 6 | # @File : parse_config_utils.py 7 | # @IDE: PyCharm 8 | """ 9 | Parse config utils 10 | """ 11 | import os 12 | import yaml 13 | import json 14 | import codecs 15 | from ast import literal_eval 16 | 17 | 18 | class Config(dict): 19 | """ 20 | Config class 21 | """ 22 | def __init__(self, *args, **kwargs): 23 | """ 24 | init class 25 | :param args: 26 | :param kwargs: 27 | """ 28 | if 'config_path' in kwargs: 29 | config_content = self._load_config_file(kwargs['config_path']) 30 | super(Config, self).__init__(config_content) 31 | else: 32 | super(Config, self).__init__(*args, **kwargs) 33 | self.immutable = False 34 | 35 | def __setattr__(self, key, value, create_if_not_exist=True): 36 | """ 37 | 38 | :param key: 39 | :param value: 40 | :param create_if_not_exist: 41 | :return: 42 | """ 43 | if key in ["immutable"]: 44 | self.__dict__[key] = value 45 | return 46 | 47 | t = self 48 | keylist = key.split(".") 49 | for k in keylist[:-1]: 50 | t = t.__getattr__(k, create_if_not_exist) 51 | 52 | t.__getattr__(keylist[-1], create_if_not_exist) 53 | t[keylist[-1]] = value 54 | 55 | def __getattr__(self, key, create_if_not_exist=True): 56 | """ 57 | 58 | :param key: 59 | :param create_if_not_exist: 60 | :return: 61 | """ 62 | if key in ["immutable"]: 63 | return self.__dict__[key] 64 | 65 | if key not in self: 66 | if not create_if_not_exist: 67 | raise KeyError 68 | self[key] = Config() 69 | if isinstance(self[key], dict): 70 | self[key] = Config(self[key]) 71 | return self[key] 72 | 73 | def __setitem__(self, key, value): 74 | """ 75 | 76 | :param key: 77 | :param value: 78 | :return: 79 | """ 80 | if self.immutable: 81 | raise AttributeError( 82 | 'Attempted to set "{}" to "{}", but SegConfig is immutable'. 
83 | format(key, value)) 84 | # 85 | if isinstance(value, str): 86 | try: 87 | value = literal_eval(value) 88 | except ValueError: 89 | pass 90 | except SyntaxError: 91 | pass 92 | super(Config, self).__setitem__(key, value) 93 | 94 | @staticmethod 95 | def _load_config_file(config_file_path): 96 | """ 97 | 98 | :param config_file_path 99 | :return: 100 | """ 101 | if not os.access(config_file_path, os.R_OK): 102 | raise OSError('Config file: {:s}, can not be read'.format(config_file_path)) 103 | with open(config_file_path, 'r') as f: 104 | config_content = yaml.safe_load(f) 105 | 106 | return config_content 107 | 108 | def update_from_config(self, other): 109 | """ 110 | 111 | :param other: 112 | :return: 113 | """ 114 | if isinstance(other, dict): 115 | other = Config(other) 116 | assert isinstance(other, Config) 117 | diclist = [("", other)] 118 | while len(diclist): 119 | prefix, tdic = diclist[0] 120 | diclist = diclist[1:] 121 | for key, value in tdic.items(): 122 | key = "{}.{}".format(prefix, key) if prefix else key 123 | if isinstance(value, dict): 124 | diclist.append((key, value)) 125 | continue 126 | try: 127 | self.__setattr__(key, value, create_if_not_exist=False) 128 | except KeyError: 129 | raise KeyError('Non-existent config key: {}'.format(key)) 130 | 131 | def check_and_infer(self): 132 | """ 133 | 134 | :return: 135 | """ 136 | if self.DATASET.IMAGE_TYPE in ['rgb', 'gray']: 137 | self.DATASET.DATA_DIM = 3 138 | elif self.DATASET.IMAGE_TYPE in ['rgba']: 139 | self.DATASET.DATA_DIM = 4 140 | else: 141 | raise KeyError( 142 | 'DATASET.IMAGE_TYPE config error, only support `rgb`, `gray` and `rgba`' 143 | ) 144 | if self.MEAN is not None: 145 | self.DATASET.PADDING_VALUE = [x * 255.0 for x in self.MEAN] 146 | 147 | if not self.TRAIN_CROP_SIZE: 148 | raise ValueError( 149 | 'TRAIN_CROP_SIZE is empty! Please set a pair of values in format (width, height)' 150 | ) 151 | 152 | if not self.EVAL_CROP_SIZE: 153 | raise ValueError( 154 | 'EVAL_CROP_SIZE is empty! Please set a pair of values in format (width, height)' 155 | ) 156 | 157 | # Ensure file list is use UTF-8 encoding 158 | train_sets = codecs.open(self.DATASET.TRAIN_FILE_LIST, 'r', 'utf-8').readlines() 159 | val_sets = codecs.open(self.DATASET.VAL_FILE_LIST, 'r', 'utf-8').readlines() 160 | test_sets = codecs.open(self.DATASET.TEST_FILE_LIST, 'r', 'utf-8').readlines() 161 | self.DATASET.TRAIN_TOTAL_IMAGES = len(train_sets) 162 | self.DATASET.VAL_TOTAL_IMAGES = len(val_sets) 163 | self.DATASET.TEST_TOTAL_IMAGES = len(test_sets) 164 | 165 | if self.MODEL.MODEL_NAME == 'icnet' and \ 166 | len(self.MODEL.MULTI_LOSS_WEIGHT) != 3: 167 | self.MODEL.MULTI_LOSS_WEIGHT = [1.0, 0.4, 0.16] 168 | 169 | def update_from_list(self, config_list): 170 | if len(config_list) % 2 != 0: 171 | raise ValueError( 172 | "Command line options config format error! Please check it: {}". 
173 | format(config_list)) 174 | for key, value in zip(config_list[0::2], config_list[1::2]): 175 | try: 176 | self.__setattr__(key, value, create_if_not_exist=False) 177 | except KeyError: 178 | raise KeyError('Non-existent config key: {}'.format(key)) 179 | 180 | def update_from_file(self, config_file): 181 | """ 182 | 183 | :param config_file: 184 | :return: 185 | """ 186 | with codecs.open(config_file, 'r', 'utf-8') as f: 187 | dic = yaml.safe_load(f) 188 | self.update_from_config(dic) 189 | 190 | def set_immutable(self, immutable): 191 | """ 192 | 193 | :param immutable: 194 | :return: 195 | """ 196 | self.immutable = immutable 197 | for value in self.values(): 198 | if isinstance(value, Config): 199 | value.set_immutable(immutable) 200 | 201 | def is_immutable(self): 202 | """ 203 | 204 | :return: 205 | """ 206 | return self.immutable 207 | 208 | def dump_to_json_file(self, f_obj): 209 | """ 210 | 211 | :param f_obj: 212 | :return: 213 | """ 214 | origin_dict = dict() 215 | for key, val in self.items(): 216 | if isinstance(val, Config): 217 | origin_dict.update({key: dict(val)}) 218 | elif isinstance(val, dict): 219 | origin_dict.update({key: val}) 220 | else: 221 | raise TypeError('Not supported type {}'.format(type(val))) 222 | return json.dump(origin_dict, f_obj) 223 | 224 | 225 | sam_cfg = Config(config_path='./config/sam.yaml') 226 | clip_cfg = Config(config_path='./config/clip.yaml') 227 | cluster_cfg = Config(config_path='./config/cluster.yaml') 228 | -------------------------------------------------------------------------------- /local_utils/log_util/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | # @Time : 2019/11/14 下午8:18 4 | # @Author : MaybeShewill-CV 5 | # @Site : https://github.com/MaybeShewill-CV/bisenetv2-tensorflow 6 | # @File : __init__.py.py 7 | # @IDE: PyCharm 8 | """ 9 | log utils 10 | """ -------------------------------------------------------------------------------- /local_utils/log_util/init_logger.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | # @Time : 2019/11/14 下午9:04 4 | # @Author : MaybeShewill-CV 5 | # @Site : https://github.com/MaybeShewill-CV/bisenetv2-tensorflow 6 | # @File : init_logger.py 7 | # @IDE: PyCharm 8 | """ 9 | Log relative utils 10 | """ 11 | import os.path as ops 12 | import time 13 | 14 | import loguru 15 | 16 | from local_utils.config_utils import parse_config_utils 17 | 18 | CFG = parse_config_utils.sam_cfg 19 | 20 | 21 | def get_logger(log_file_name_prefix): 22 | """ 23 | init logger 24 | :param log_file_name_prefix: log file prefix 25 | :return: 26 | """ 27 | start_time = time.strftime('%Y-%m-%d-%H-%M-%S', time.localtime(time.time())) 28 | log_file_name = '{:s}_{:s}.log'.format(log_file_name_prefix, start_time) 29 | log_file_path = ops.join(CFG.LOG.SAVE_DIR, log_file_name) 30 | 31 | logger = loguru.logger 32 | log_level = 'INFO' 33 | if CFG.LOG.LEVEL == "DEBUG": 34 | log_level = 'DEBUG' 35 | elif CFG.LOG.LEVEL == "WARNING": 36 | log_level = 'WARNING' 37 | elif CFG.LOG.LEVEL == "ERROR": 38 | log_level = 'ERROR' 39 | 40 | logger.add( 41 | log_file_path, 42 | level=log_level, 43 | format="{time} {level} {message}", 44 | retention="10 days", 45 | rotation="1 week" 46 | ) 47 | 48 | return logger 49 | -------------------------------------------------------------------------------- /models/__init__.py: 
-------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # @Time : 2023-04-10 上午11:50 4 | # @Author : MaybeShewill-CV 5 | # @Site : 6 | # @File : __init__.py.py 7 | # @IDE: PyCharm Community Edition 8 | """ 9 | model 10 | """ 11 | import torch 12 | 13 | from local_utils.config_utils import parse_config_utils 14 | from models.clip import clip 15 | import models.sam as sam_builder 16 | 17 | 18 | def build_sam_model(cfg): 19 | """ 20 | Build a SAM model from the config and move it to the configured device. 21 | :param cfg: SAM config object 22 | :return: SAM model 23 | """ 24 | supported_model_name = ['vit_h', 'vit_l', 'vit_b', 'vit_t', 'default'] 25 | model_name = cfg.MODEL.MODEL_NAME 26 | if model_name not in supported_model_name: 27 | raise ValueError('not supported model: {:s}, only supported {}'.format(model_name, supported_model_name)) 28 | ckpt_path = cfg.MODEL.CKPT_PATH 29 | device = torch.device(cfg.MODEL.DEVICE) 30 | build_func = sam_builder.sam_model_registry[model_name] 31 | model = build_func(checkpoint=ckpt_path) 32 | return model.to(device) 33 | 34 | 35 | def build_sam_mask_generator(cfg): 36 | """ 37 | Build a SamAutomaticMaskGenerator from the MASK_GENERATOR settings in the SAM config. 38 | :param cfg: SAM config object 39 | :return: SamAutomaticMaskGenerator instance 40 | """ 41 | sam = build_sam_model(cfg) 42 | points_per_side = cfg.MASK_GENERATOR.PTS_PER_SIDE 43 | points_per_batch = cfg.MASK_GENERATOR.PTS_PER_BATCH 44 | pred_iou_thresh = cfg.MASK_GENERATOR.PRED_IOU_THRESH 45 | stability_score_thresh = cfg.MASK_GENERATOR.STABILITY_SCORE_THRESH 46 | stability_score_offset = cfg.MASK_GENERATOR.STABILITY_SCORE_OFFSET 47 | box_nms_thresh = cfg.MASK_GENERATOR.BOX_NMS_THRESH 48 | crop_n_layers = cfg.MASK_GENERATOR.CROP_N_LAYERS 49 | crop_nms_thresh = cfg.MASK_GENERATOR.CROP_NMS_THRESH 50 | crop_overlap_ratio = cfg.MASK_GENERATOR.CROP_OVERLAP_RATIO 51 | crop_n_points_downscale_factor = cfg.MASK_GENERATOR.CROP_N_POINTS_DOWNSCALE_FACTOR 52 | point_grids = None if cfg.MASK_GENERATOR.POINT_GRIDS.lower() == 'none' else cfg.MASK_GENERATOR.POINT_GRIDS 53 | min_mask_region_area = cfg.MASK_GENERATOR.MIN_MASK_REGION_AERA 54 | output_mode = cfg.MASK_GENERATOR.OUTPUT_MODE 55 | mask_generator = sam_builder.SamAutomaticMaskGenerator( 56 | sam, 57 | points_per_side=points_per_side, 58 | points_per_batch=points_per_batch, 59 | pred_iou_thresh=pred_iou_thresh, 60 | stability_score_thresh=stability_score_thresh, 61 | stability_score_offset=stability_score_offset, 62 | box_nms_thresh=box_nms_thresh, 63 | crop_n_layers=crop_n_layers, 64 | crop_nms_thresh=crop_nms_thresh, 65 | crop_overlap_ratio=crop_overlap_ratio, 66 | crop_n_points_downscale_factor=crop_n_points_downscale_factor, 67 | point_grids=point_grids, 68 | min_mask_region_area=min_mask_region_area, 69 | output_mode=output_mode 70 | ) 71 | return mask_generator 72 | 73 | 74 | def build_clip_model(cfg): 75 | """ 76 | Load a CLIP model and its preprocess transform according to the CLIP config. 77 | :param cfg: CLIP config object 78 | :return: CLIP model and preprocess transform 79 | """ 80 | ckpt_dir = cfg.MODEL.CKPT_DIR 81 | device = cfg.MODEL.DEVICE 82 | model_name = cfg.MODEL.NAME 83 | model, preprocess = clip.load(model_name, device=device, download_root=ckpt_dir) 84 | 85 | return model, preprocess 86 | 87 | 88 | import models.cluster as sam_clip_cluster 89 | 90 | 91 | def build_cluster(cfg): 92 | """ 93 | Build the SamClipCluster model from the cluster config. 94 | :param cfg: cluster config object 95 | :return: SamClipCluster instance 96 | """ 97 | sam_cfg_path = cfg.MODEL.SAM.CFG_PATH 98 | clip_cfg_path = cfg.MODEL.CLIP.CFG_PATH 99 | sam_cfg = parse_config_utils.Config(config_path=sam_cfg_path) 100 | clip_cfg = parse_config_utils.Config(config_path=clip_cfg_path) 101 | model = sam_clip_cluster.cluster_model.SamClipCluster( 102 | sam_cfg=sam_cfg, 103 | clip_cfg=clip_cfg, 104 | cluster_cfg=cfg 105 | ) 106 | 107 | return
model 108 | 109 | 110 | import models.detector as sam_clip_insseg 111 | 112 | 113 | def build_sam_clip_text_ins_segmentor(cfg): 114 | """ 115 | 116 | :param cfg: 117 | :return: 118 | """ 119 | sam_cfg_path = cfg.MODEL.SAM.CFG_PATH 120 | clip_cfg_path = cfg.MODEL.CLIP.CFG_PATH 121 | sam_cfg = parse_config_utils.Config(config_path=sam_cfg_path) 122 | clip_cfg = parse_config_utils.Config(config_path=clip_cfg_path) 123 | model = sam_clip_insseg.insseg_model.SamClipInsSegmentor( 124 | sam_cfg=sam_cfg, 125 | clip_cfg=clip_cfg, 126 | insseg_cfg=cfg 127 | ) 128 | 129 | return model 130 | -------------------------------------------------------------------------------- /models/clip/__init__.py: -------------------------------------------------------------------------------- 1 | from .clip import * 2 | -------------------------------------------------------------------------------- /models/clip/bpe_simple_vocab_16e6.txt.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MaybeShewill-CV/segment-anything-u-specify/606a6bc744fd7cc37824a32a95ec770580467428/models/clip/bpe_simple_vocab_16e6.txt.gz -------------------------------------------------------------------------------- /models/clip/clip.py: -------------------------------------------------------------------------------- 1 | import hashlib 2 | import os 3 | import urllib 4 | import warnings 5 | from typing import Any, Union, List 6 | # from pkg_resources import packaging 7 | 8 | import torch 9 | from PIL import Image 10 | from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize 11 | from tqdm import tqdm 12 | 13 | from models.clip.model import build_model 14 | from models.clip.simple_tokenizer import SimpleTokenizer as _Tokenizer 15 | 16 | try: 17 | from torchvision.transforms import InterpolationMode 18 | BICUBIC = InterpolationMode.BICUBIC 19 | except ImportError: 20 | BICUBIC = Image.BICUBIC 21 | 22 | 23 | __all__ = ["available_models", "load", "tokenize"] 24 | _tokenizer = _Tokenizer() 25 | 26 | _MODELS = { 27 | "RN50": "https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt", 28 | "RN101": "https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt", 29 | "RN50x4": "https://openaipublic.azureedge.net/clip/models/7e526bd135e493cef0776de27d5f42653e6b4c8bf9e0f653bb11773263205fdd/RN50x4.pt", 30 | "RN50x16": "https://openaipublic.azureedge.net/clip/models/52378b407f34354e150460fe41077663dd5b39c54cd0bfd2b27167a4a06ec9aa/RN50x16.pt", 31 | "RN50x64": "https://openaipublic.azureedge.net/clip/models/be1cfb55d75a9666199fb2206c106743da0f6468c9d327f3e0d0a543a9919d9c/RN50x64.pt", 32 | "ViT-B/32": "https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt", 33 | "ViT-B/16": "https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt", 34 | "ViT-L/14": "https://openaipublic.azureedge.net/clip/models/b8cca3fd41ae0c99ba7e8951adf17d267cdb84cd88be6f7c2e0eca1737a03836/ViT-L-14.pt", 35 | "ViT-L/14@336px": "https://openaipublic.azureedge.net/clip/models/3035c92b350959924f9f00213499208652fc7ea050643e8b385c2dac08641f02/ViT-L-14-336px.pt", 36 | } 37 | 38 | 39 | def _download(url: str, root: str): 40 | os.makedirs(root, exist_ok=True) 41 | filename = os.path.basename(url) 42 | 43 | expected_sha256 = url.split("/")[-2] 44 | 
download_target = os.path.join(root, filename) 45 | 46 | if os.path.exists(download_target) and not os.path.isfile(download_target): 47 | raise RuntimeError(f"{download_target} exists and is not a regular file") 48 | 49 | if os.path.isfile(download_target): 50 | if hashlib.sha256(open(download_target, "rb").read()).hexdigest() == expected_sha256: 51 | return download_target 52 | else: 53 | warnings.warn(f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file") 54 | 55 | with urllib.request.urlopen(url) as source, open(download_target, "wb") as output: 56 | with tqdm(total=int(source.info().get("Content-Length")), ncols=80, unit='iB', unit_scale=True, unit_divisor=1024) as loop: 57 | while True: 58 | buffer = source.read(8192) 59 | if not buffer: 60 | break 61 | 62 | output.write(buffer) 63 | loop.update(len(buffer)) 64 | 65 | if hashlib.sha256(open(download_target, "rb").read()).hexdigest() != expected_sha256: 66 | raise RuntimeError("Model has been downloaded but the SHA256 checksum does not not match") 67 | 68 | return download_target 69 | 70 | 71 | def _convert_image_to_rgb(image): 72 | return image.convert("RGB") 73 | 74 | 75 | def _transform(n_px): 76 | return Compose([ 77 | Resize(n_px, interpolation=BICUBIC), 78 | CenterCrop(n_px), 79 | _convert_image_to_rgb, 80 | ToTensor(), 81 | Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)), 82 | ]) 83 | 84 | 85 | def available_models() -> List[str]: 86 | """Returns the names of available CLIP models""" 87 | return list(_MODELS.keys()) 88 | 89 | 90 | def load(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu", jit: bool = False, download_root: str = None): 91 | """Load a CLIP model 92 | 93 | Parameters 94 | ---------- 95 | name : str 96 | A model name listed by `clip.available_models()`, or the path to a model checkpoint containing the state_dict 97 | 98 | device : Union[str, torch.device] 99 | The device to put the loaded model 100 | 101 | jit : bool 102 | Whether to load the optimized JIT model or more hackable non-JIT model (default). 103 | 104 | download_root: str 105 | path to download the model files; by default, it uses "~/.cache/clip" 106 | 107 | Returns 108 | ------- 109 | model : torch.nn.Module 110 | The CLIP model 111 | 112 | preprocess : Callable[[PIL.Image], torch.Tensor] 113 | A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input 114 | """ 115 | if name in _MODELS: 116 | model_path = _download(_MODELS[name], download_root or os.path.expanduser("~/.cache/clip")) 117 | elif os.path.isfile(name): 118 | model_path = name 119 | else: 120 | raise RuntimeError(f"Model {name} not found; available models = {available_models()}") 121 | 122 | with open(model_path, 'rb') as opened_file: 123 | try: 124 | # loading JIT archive 125 | model = torch.jit.load(opened_file, map_location=device if jit else "cpu").eval() 126 | state_dict = None 127 | except RuntimeError: 128 | # loading saved state dict 129 | if jit: 130 | warnings.warn(f"File {model_path} is not a JIT archive. 
Loading as a state dict instead") 131 | jit = False 132 | state_dict = torch.load(opened_file, map_location="cpu") 133 | 134 | if not jit: 135 | model = build_model(state_dict or model.state_dict()).to(device) 136 | if str(device) == "cpu": 137 | model.float() 138 | return model, _transform(model.visual.input_resolution) 139 | 140 | # patch the device names 141 | device_holder = torch.jit.trace(lambda: torch.ones([]).to(torch.device(device)), example_inputs=[]) 142 | device_node = [n for n in device_holder.graph.findAllNodes("prim::Constant") if "Device" in repr(n)][-1] 143 | 144 | def patch_device(module): 145 | try: 146 | graphs = [module.graph] if hasattr(module, "graph") else [] 147 | except RuntimeError: 148 | graphs = [] 149 | 150 | if hasattr(module, "forward1"): 151 | graphs.append(module.forward1.graph) 152 | 153 | for graph in graphs: 154 | for node in graph.findAllNodes("prim::Constant"): 155 | if "value" in node.attributeNames() and str(node["value"]).startswith("cuda"): 156 | node.copyAttributes(device_node) 157 | 158 | model.apply(patch_device) 159 | patch_device(model.encode_image) 160 | patch_device(model.encode_text) 161 | 162 | # patch dtype to float32 on CPU 163 | if str(device) == "cpu": 164 | float_holder = torch.jit.trace(lambda: torch.ones([]).float(), example_inputs=[]) 165 | float_input = list(float_holder.graph.findNode("aten::to").inputs())[1] 166 | float_node = float_input.node() 167 | 168 | def patch_float(module): 169 | try: 170 | graphs = [module.graph] if hasattr(module, "graph") else [] 171 | except RuntimeError: 172 | graphs = [] 173 | 174 | if hasattr(module, "forward1"): 175 | graphs.append(module.forward1.graph) 176 | 177 | for graph in graphs: 178 | for node in graph.findAllNodes("aten::to"): 179 | inputs = list(node.inputs()) 180 | for i in [1, 2]: # dtype can be the second or third argument to aten::to() 181 | if inputs[i].node()["value"] == 5: 182 | inputs[i].node().copyAttributes(float_node) 183 | 184 | model.apply(patch_float) 185 | patch_float(model.encode_image) 186 | patch_float(model.encode_text) 187 | 188 | model.float() 189 | 190 | return model, _transform(model.input_resolution.item()) 191 | 192 | 193 | def tokenize(texts: Union[str, List[str]], context_length: int = 77, truncate: bool = False) -> Union[torch.IntTensor, torch.LongTensor]: 194 | """ 195 | Returns the tokenized representation of given input string(s) 196 | 197 | Parameters 198 | ---------- 199 | texts : Union[str, List[str]] 200 | An input string or a list of input strings to tokenize 201 | 202 | context_length : int 203 | The context length to use; all CLIP models use 77 as the context length 204 | 205 | truncate: bool 206 | Whether to truncate the text in case its encoding is longer than the context length 207 | 208 | Returns 209 | ------- 210 | A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length]. 211 | We return LongTensor when torch version is <1.8.0, since older index_select requires indices to be long. 
212 | """ 213 | if isinstance(texts, str): 214 | texts = [texts] 215 | 216 | sot_token = _tokenizer.encoder["<|startoftext|>"] 217 | eot_token = _tokenizer.encoder["<|endoftext|>"] 218 | all_tokens = [[sot_token] + _tokenizer.encode(text) + [eot_token] for text in texts] 219 | # if packaging.version.parse(torch.__version__) < packaging.version.parse("1.8.0"): 220 | # result = torch.zeros(len(all_tokens), context_length, dtype=torch.long) 221 | # else: 222 | # result = torch.zeros(len(all_tokens), context_length, dtype=torch.int) 223 | result = torch.zeros(len(all_tokens), context_length, dtype=torch.int) 224 | 225 | for i, tokens in enumerate(all_tokens): 226 | if len(tokens) > context_length: 227 | if truncate: 228 | tokens = tokens[:context_length] 229 | tokens[-1] = eot_token 230 | else: 231 | raise RuntimeError(f"Input {texts[i]} is too long for context length {context_length}") 232 | result[i, :len(tokens)] = torch.tensor(tokens) 233 | 234 | return result 235 | -------------------------------------------------------------------------------- /models/clip/simple_tokenizer.py: -------------------------------------------------------------------------------- 1 | import gzip 2 | import html 3 | import os 4 | from functools import lru_cache 5 | 6 | import ftfy 7 | import regex as re 8 | 9 | 10 | @lru_cache() 11 | def default_bpe(): 12 | return os.path.join(os.path.dirname(os.path.abspath(__file__)), "bpe_simple_vocab_16e6.txt.gz") 13 | 14 | 15 | @lru_cache() 16 | def bytes_to_unicode(): 17 | """ 18 | Returns list of utf-8 byte and a corresponding list of unicode strings. 19 | The reversible bpe codes work on unicode strings. 20 | This means you need a large # of unicode characters in your vocab if you want to avoid UNKs. 21 | When you're at something like a 10B token dataset you end up needing around 5K for decent coverage. 22 | This is a signficant percentage of your normal, say, 32K bpe vocab. 23 | To avoid that, we want lookup tables between utf-8 bytes and unicode strings. 24 | And avoids mapping to whitespace/control characters the bpe code barfs on. 25 | """ 26 | bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1)) 27 | cs = bs[:] 28 | n = 0 29 | for b in range(2**8): 30 | if b not in bs: 31 | bs.append(b) 32 | cs.append(2**8+n) 33 | n += 1 34 | cs = [chr(n) for n in cs] 35 | return dict(zip(bs, cs)) 36 | 37 | 38 | def get_pairs(word): 39 | """Return set of symbol pairs in a word. 40 | Word is represented as tuple of symbols (symbols being variable-length strings). 
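
    Example (illustrative; the ordering of the returned set is arbitrary):
        get_pairs(('h', 'e', 'l', 'l', 'o'))
        # -> {('h', 'e'), ('e', 'l'), ('l', 'l'), ('l', 'o')}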
41 | """ 42 | pairs = set() 43 | prev_char = word[0] 44 | for char in word[1:]: 45 | pairs.add((prev_char, char)) 46 | prev_char = char 47 | return pairs 48 | 49 | 50 | def basic_clean(text): 51 | text = ftfy.fix_text(text) 52 | text = html.unescape(html.unescape(text)) 53 | return text.strip() 54 | 55 | 56 | def whitespace_clean(text): 57 | text = re.sub(r'\s+', ' ', text) 58 | text = text.strip() 59 | return text 60 | 61 | 62 | class SimpleTokenizer(object): 63 | def __init__(self, bpe_path: str = default_bpe()): 64 | self.byte_encoder = bytes_to_unicode() 65 | self.byte_decoder = {v: k for k, v in self.byte_encoder.items()} 66 | merges = gzip.open(bpe_path).read().decode("utf-8").split('\n') 67 | merges = merges[1:49152-256-2+1] 68 | merges = [tuple(merge.split()) for merge in merges] 69 | vocab = list(bytes_to_unicode().values()) 70 | vocab = vocab + [v+'' for v in vocab] 71 | for merge in merges: 72 | vocab.append(''.join(merge)) 73 | vocab.extend(['<|startoftext|>', '<|endoftext|>']) 74 | self.encoder = dict(zip(vocab, range(len(vocab)))) 75 | self.decoder = {v: k for k, v in self.encoder.items()} 76 | self.bpe_ranks = dict(zip(merges, range(len(merges)))) 77 | self.cache = {'<|startoftext|>': '<|startoftext|>', '<|endoftext|>': '<|endoftext|>'} 78 | self.pat = re.compile(r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", re.IGNORECASE) 79 | 80 | def bpe(self, token): 81 | if token in self.cache: 82 | return self.cache[token] 83 | word = tuple(token[:-1]) + ( token[-1] + '',) 84 | pairs = get_pairs(word) 85 | 86 | if not pairs: 87 | return token+'' 88 | 89 | while True: 90 | bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf'))) 91 | if bigram not in self.bpe_ranks: 92 | break 93 | first, second = bigram 94 | new_word = [] 95 | i = 0 96 | while i < len(word): 97 | try: 98 | j = word.index(first, i) 99 | new_word.extend(word[i:j]) 100 | i = j 101 | except: 102 | new_word.extend(word[i:]) 103 | break 104 | 105 | if word[i] == first and i < len(word)-1 and word[i+1] == second: 106 | new_word.append(first+second) 107 | i += 2 108 | else: 109 | new_word.append(word[i]) 110 | i += 1 111 | new_word = tuple(new_word) 112 | word = new_word 113 | if len(word) == 1: 114 | break 115 | else: 116 | pairs = get_pairs(word) 117 | word = ' '.join(word) 118 | self.cache[token] = word 119 | return word 120 | 121 | def encode(self, text): 122 | bpe_tokens = [] 123 | text = whitespace_clean(basic_clean(text)).lower() 124 | for token in re.findall(self.pat, text): 125 | token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8')) 126 | bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' ')) 127 | return bpe_tokens 128 | 129 | def decode(self, tokens): 130 | text = ''.join([self.decoder[token] for token in tokens]) 131 | text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('', ' ') 132 | return text 133 | -------------------------------------------------------------------------------- /models/cluster/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # @Time : 2023-04-10 上午11:58 4 | # @Author : MaybeShewill-CV 5 | # @Site : 6 | # @File : __init__.py.py 7 | # @IDE: PyCharm Community Edition 8 | """ 9 | 10 | """ 11 | from models.cluster import cluster_model 12 | from models.cluster import utils 13 | 
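
# Illustrative usage sketch (assumption: the sam/clip/cluster config objects are
# loaded from the YAML files under ./config; see tools/cluster_sam.py for the
# actual entry point):
#
#     from models.cluster.cluster_model import SamClipCluster
#     clusterer = SamClipCluster(sam_cfg=sam_cfg, clip_cfg=clip_cfg, cluster_cfg=cluster_cfg)
#     ret = clusterer.cluster_image('./data/test_images/test_bear.jpg')
#     # ret holds 'source', 'ori_mask', 'ori_mask_add', 'cluster_mask' and 'cluster_mask_add' images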
-------------------------------------------------------------------------------- /models/cluster/cluster_model.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # @Time : 2023-04-10 下午1:57 4 | # @Author : MaybeShewill-CV 5 | # @Site : 6 | # @File : cluster_model.py 7 | # @IDE: PyCharm Community Edition 8 | """ 9 | cluster model 10 | """ 11 | import numpy as np 12 | import cv2 13 | from PIL import Image 14 | import torch 15 | import torch.nn.functional as F 16 | 17 | from models.cluster import utils 18 | from models.clip import tokenize 19 | from models import build_clip_model 20 | from models import build_sam_mask_generator 21 | 22 | 23 | class SamClipCluster(object): 24 | """ 25 | cluster segment objects from sam's output 26 | """ 27 | def __init__(self, sam_cfg, clip_cfg, cluster_cfg): 28 | """ 29 | 30 | :param sam_cfg: 31 | :param clip_cfg: 32 | :param cluster_cfg: 33 | """ 34 | self.mask_generator = build_sam_mask_generator(sam_cfg) 35 | self.clip_model, self.clip_preprocess = build_clip_model(clip_cfg) 36 | self.device = torch.device(cluster_cfg.MODEL.DEVICE) 37 | self.top_k_objs = cluster_cfg.CLUSTER.TOP_K_MASK_COUNT 38 | self.similarity_thresh = cluster_cfg.CLUSTER.SIMILARITY_THRESH 39 | self.max_input_size = cluster_cfg.CLUSTER.MAX_INPUT_SIZE 40 | self.imagenet_cls_text_prompts = utils.generate_imagenet_classification_text_prompts() 41 | self.imagenet_cls_text_token = tokenize(self.imagenet_cls_text_prompts).to(self.device) 42 | 43 | def _generate_sam_mask(self, input_image: np.ndarray): 44 | """ 45 | 46 | :param input_image: 47 | :return: 48 | """ 49 | masks = self.mask_generator.generate(input_image) 50 | most_stable_mask = sorted(masks, key=lambda d: d['area']) 51 | if len(most_stable_mask) > self.top_k_objs: 52 | most_stable_mask = most_stable_mask[-self.top_k_objs:] 53 | sam_masks = { 54 | 'segmentations': [tmp['segmentation'] for tmp in most_stable_mask], 55 | 'bboxes': [tmp['bbox'] for tmp in most_stable_mask], 56 | 'stability_scores': [tmp['stability_score'] for tmp in most_stable_mask], 57 | } 58 | return sam_masks 59 | 60 | def _extract_image_features(self, input_image: np.ndarray, normalize=False): 61 | """ 62 | 63 | :param input_image: 64 | :return: 65 | """ 66 | image = Image.fromarray(input_image) 67 | image = self.clip_preprocess(image).unsqueeze(0).to(self.device) 68 | image_features = self.clip_model.encode_image(image) 69 | image_features = F.normalize(image_features, dim=-1) if normalize else image_features 70 | image_features = image_features.squeeze(0) 71 | 72 | return image_features.cpu().numpy() 73 | 74 | def _classify_image(self, input_image: np.ndarray): 75 | """ 76 | 77 | :param input_image: 78 | :return: 79 | """ 80 | image = Image.fromarray(input_image) 81 | image = self.clip_preprocess(image).unsqueeze(0).to(self.device) 82 | logits_per_image, logits_per_text = self.clip_model(image, self.imagenet_cls_text_token) 83 | probs = logits_per_image.softmax(dim=-1).cpu().numpy()[0, :] 84 | cls_id = np.argmax(probs) 85 | cls_name = self.imagenet_cls_text_prompts[cls_id].replace('a photo of a', '') 86 | 87 | return cls_name 88 | 89 | def _extract_mask_features(self, input_image, mask): 90 | """ 91 | 92 | :param input_image: 93 | :param mask: 94 | :return: 95 | """ 96 | bboxes_features = [] 97 | bboxes_ids = [] 98 | for idx, bbox in enumerate(mask['bboxes']): 99 | bbox = [int(tmp) for tmp in bbox] 100 | roi_image = input_image[bbox[1]:bbox[1] + bbox[3], bbox[0]:bbox[0] + 
bbox[2], :] 101 | roi_features = self._extract_image_features(roi_image, False) 102 | bboxes_features.append(roi_features) 103 | bboxes_ids.append(idx) 104 | 105 | mask['bbox_features'] = np.asarray(bboxes_features, dtype=np.float32) 106 | mask['bbox_ori_ids'] = bboxes_ids 107 | return 108 | 109 | def _classify_mask(self, input_image, mask): 110 | """ 111 | 112 | :param input_image: 113 | :param mask: 114 | :return: 115 | """ 116 | bboxes_cls_names = [] 117 | for idx, bbox in enumerate(mask['bboxes']): 118 | bbox = [int(tmp) for tmp in bbox] 119 | roi_image = input_image[bbox[1]:bbox[1] + bbox[3], bbox[0]:bbox[0] + bbox[2], :] 120 | cls_name = self._classify_image(roi_image) 121 | bboxes_cls_names.append(cls_name) 122 | 123 | mask['bbox_cls_names'] = bboxes_cls_names 124 | return 125 | 126 | def _cluster_bbox_features(self, bbox_features: np.ndarray): 127 | """ 128 | 129 | :param bbox_features: [N, 512] feats 130 | :return: 131 | """ 132 | norm = np.linalg.norm(bbox_features, axis=1) 133 | similarity_matrix = np.dot(bbox_features, bbox_features.T) / np.outer(norm, norm) 134 | similar_obj = np.argwhere(similarity_matrix > self.similarity_thresh) 135 | 136 | obj_classes = [i for i in range(bbox_features.shape[0])] 137 | for i, j in similar_obj: 138 | if i != j and i in obj_classes and j in obj_classes: 139 | obj_classes[j] = obj_classes[i] 140 | 141 | return obj_classes 142 | 143 | def cluster_image(self, input_image_path): 144 | """ 145 | 146 | :param input_image_path: 147 | :return: 148 | """ 149 | # read input image 150 | input_image = cv2.imread(input_image_path, cv2.IMREAD_COLOR) 151 | if input_image.shape[0] > self.max_input_size[0] or input_image.shape[1] > self.max_input_size[1]: 152 | h, w, _ = input_image.shape 153 | hw_ratio = h / w if h > w else w / h 154 | if h > w: 155 | dsize = (int(self.max_input_size[1] / hw_ratio), self.max_input_size[1]) 156 | else: 157 | dsize = (self.max_input_size[0], int(self.max_input_size[0] / hw_ratio)) 158 | input_image = cv2.resize(input_image, dsize=dsize) 159 | input_image = cv2.cvtColor(input_image, cv2.COLOR_BGR2RGB) 160 | 161 | with torch.no_grad(): 162 | # extract mask from sam model 163 | masks = self._generate_sam_mask(input_image) 164 | # extract each mask's features 165 | self._extract_mask_features(input_image, masks) 166 | # classify each mask's label 167 | self._classify_mask(input_image, masks) 168 | 169 | # cluster obj ids 170 | cluster_obj_ids = self._cluster_bbox_features(bbox_features=masks['bbox_features']) 171 | masks['bbox_cluster_ids'] = cluster_obj_ids 172 | 173 | # diff mask image 174 | input_image = cv2.cvtColor(input_image, cv2.COLOR_RGB2BGR) 175 | ori_mask = utils.generate_sam_mask_images( 176 | masks, 177 | bbox_cls_ids=masks['bbox_ori_ids'], 178 | bbox_cls_names=masks['bbox_cls_names'], 179 | draw_bbox=True 180 | ) 181 | ori_mask_add = cv2.addWeighted(input_image, 0.5, ori_mask, 0.5, 0.0) 182 | clustered_mask = utils.generate_sam_mask_images(masks, bbox_cls_ids=masks['bbox_cluster_ids']) 183 | clustered_mask_add = cv2.addWeighted(input_image, 0.5, clustered_mask, 0.5, 0.0) 184 | 185 | ret = { 186 | 'source': input_image, 187 | 'ori_mask': ori_mask, 188 | 'ori_mask_add': ori_mask_add, 189 | 'cluster_mask': clustered_mask, 190 | 'cluster_mask_add': clustered_mask_add 191 | } 192 | 193 | return ret 194 | -------------------------------------------------------------------------------- /models/cluster/utils.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # 
-*- coding: utf-8 -*- 3 | # @Time : 2023-04-10 下午3:25 4 | # @Author : MaybeShewill-CV 5 | # @Site : 6 | # @File : utils.py 7 | # @IDE: PyCharm Community Edition 8 | """ 9 | utils func 10 | """ 11 | import time 12 | 13 | import cv2 14 | import numpy as np 15 | import matplotlib.pyplot as plt 16 | 17 | 18 | def generate_color_pool(color_nums): 19 | """ 20 | generate random color 21 | :param color_nums: 22 | :return: 23 | """ 24 | color_pool = [] 25 | 26 | np.random.seed(int(time.time())) 27 | 28 | for i in range(color_nums): 29 | b = int(np.random.randint(0, 255, dtype=np.uint8)) 30 | g = int(np.random.randint(0, 255, dtype=np.uint8)) 31 | r = int(np.random.randint(0, 255, dtype=np.uint8)) 32 | color_pool.append((b, g, r)) 33 | 34 | return color_pool 35 | 36 | 37 | def generate_sam_mask_images(sam_masks, bbox_cls_ids=None, bbox_cls_names=None, draw_bbox=False): 38 | """ 39 | 40 | :param sam_masks: 41 | :param bbox_cls_ids: 42 | :param bbox_cls_names: 43 | :param draw_bbox: 44 | :return: 45 | """ 46 | seg_images = sam_masks['segmentations'] 47 | bboxes = sam_masks['bboxes'] 48 | cls_ids = bbox_cls_ids if bbox_cls_ids is not None else sam_masks['bbox_ori_ids'] 49 | color_pool = generate_color_pool(color_nums=max(cls_ids) + 1) 50 | 51 | mask_image = np.ones(shape=(seg_images[0].shape[0], seg_images[0].shape[1], 3), dtype=np.uint8) 52 | for idx, cls_id in enumerate(cls_ids): 53 | color = color_pool[cls_id] 54 | # draw mask 55 | mask_image[:, :, 0][seg_images[idx]] = color[0] 56 | mask_image[:, :, 1][seg_images[idx]] = color[1] 57 | mask_image[:, :, 2][seg_images[idx]] = color[2] 58 | # draw bbox 59 | if draw_bbox: 60 | for idx, cls_id in enumerate(cls_ids): 61 | color = color_pool[cls_id] 62 | bbox_pt1 = [bboxes[idx][0], bboxes[idx][1]] 63 | bbox_pt1 = [int(tmp) for tmp in bbox_pt1] 64 | bbox_pt2 = [bboxes[idx][0] + bboxes[idx][2], bboxes[idx][1] + bboxes[idx][3]] 65 | bbox_pt2 = [int(tmp) for tmp in bbox_pt2] 66 | cv2.rectangle(mask_image, bbox_pt1, bbox_pt2, color, 2) 67 | text = bbox_cls_names[idx].split(',')[0] 68 | org = [bbox_pt1[0] - 10, bbox_pt1[1] - 10] 69 | cv2.putText(mask_image, text, org, cv2.FONT_HERSHEY_PLAIN, 1.5, (0, 0, 0), 2) 70 | 71 | return mask_image 72 | 73 | 74 | def show_anns(anns, bbox_cls_ids=None): 75 | if len(anns) == 0: 76 | return 77 | # sorted_anns = sorted(anns, key=(lambda x: x['area']), reverse=True) 78 | ax = plt.gca() 79 | ax.set_autoscale_on(False) 80 | cls_ids = bbox_cls_ids if bbox_cls_ids is not None else anns['bbox_ori_ids'] 81 | color_pool = generate_color_pool(color_nums=max(cls_ids) + 1) 82 | for idx, m in enumerate(anns['segmentations']): 83 | img = np.ones((m.shape[0], m.shape[1], 3)) 84 | color_mask = color_pool[cls_ids[idx]] 85 | for i in range(3): 86 | img[:, :, i] = color_mask[i] / 255 87 | ax.imshow(np.dstack((img, m*0.35))) 88 | 89 | return 90 | 91 | 92 | def generate_imagenet_classification_text_prompts(): 93 | """ 94 | 95 | :return: 96 | """ 97 | text_prefix = 'a photo of a' 98 | text_prompts = open('./data/resources/imagenet-classes.txt', 'r').readlines() 99 | text_prompts = list(map(lambda x: x.rstrip('\r').rstrip('\n'), text_prompts)) 100 | text_prompts = list(map(lambda x: ' '.join([text_prefix, x]), text_prompts)) 101 | 102 | return text_prompts 103 | -------------------------------------------------------------------------------- /models/detector/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # @Time : 2023-04-11 上午11:49 4 | # 
@Author : MaybeShewill-CV 5 | # @Site : 6 | # @File : __init__.py.py 7 | # @IDE: PyCharm Community Edition 8 | """ 9 | instance segmentation model with sam and clip 10 | """ 11 | from models.detector import insseg_model 12 | -------------------------------------------------------------------------------- /models/detector/insseg_model.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # @Time : 2023-04-11 下午1:49 4 | # @Author : MaybeShewill-CV 5 | # @Site : 6 | # @File : insseg_model.py 7 | # @IDE: PyCharm Community Edition 8 | """ 9 | instance segmentation model with sam and clip 10 | """ 11 | import numpy as np 12 | import cv2 13 | from PIL import Image 14 | import torch 15 | 16 | from models.detector import utils 17 | from models.clip import tokenize 18 | from models import build_clip_model 19 | from models import build_sam_mask_generator 20 | 21 | 22 | class SamClipInsSegmentor(object): 23 | """ 24 | 25 | """ 26 | def __init__(self, sam_cfg, clip_cfg, insseg_cfg): 27 | """ 28 | 29 | :param sam_cfg: 30 | :param clip_cfg: 31 | :param insseg_cfg: 32 | """ 33 | self.mask_generator = build_sam_mask_generator(sam_cfg) 34 | self.clip_model, self.clip_preprocess = build_clip_model(clip_cfg) 35 | self.device = torch.device(insseg_cfg.MODEL.DEVICE) 36 | self.top_k_objs = insseg_cfg.INS_SEG.TOP_K_MASK_COUNT 37 | self.cls_score_thresh = insseg_cfg.INS_SEG.CLS_SCORE_THRESH 38 | self.max_input_size = insseg_cfg.INS_SEG.MAX_INPUT_SIZE 39 | self.obj365_text_prompts = utils.generate_object365_text_prompts() 40 | self.obj365_text_token = tokenize(self.obj365_text_prompts).to(self.device) 41 | self.text_token = None 42 | 43 | def _set_text_tokens(self, texts): 44 | """ 45 | 46 | :param texts: 47 | :return: 48 | """ 49 | self.text_token = tokenize(texts=texts).to(self.device) 50 | return 51 | 52 | def _generate_sam_mask(self, input_image: np.ndarray): 53 | """ 54 | 55 | :param input_image: 56 | :return: 57 | """ 58 | masks = self.mask_generator.generate(input_image) 59 | most_stable_mask = sorted(masks, key=lambda d: d['area']) 60 | if len(most_stable_mask) > self.top_k_objs: 61 | most_stable_mask = most_stable_mask[-self.top_k_objs:] 62 | sam_masks = { 63 | 'segmentations': [tmp['segmentation'] for tmp in most_stable_mask], 64 | 'bboxes': [tmp['bbox'] for tmp in most_stable_mask], 65 | 'stability_scores': [tmp['stability_score'] for tmp in most_stable_mask], 66 | } 67 | return sam_masks 68 | 69 | @staticmethod 70 | def _crop_rotate_image_roi(input_image, seg_mask): 71 | """ 72 | 73 | :param input_image: 74 | :param seg_mask: 75 | :return: 76 | """ 77 | y, x = np.where(seg_mask == 1) 78 | fg_pts = np.vstack((x, y)).transpose() 79 | src_image = cv2.bitwise_or(input_image, input_image, mask=np.asarray(seg_mask, dtype=np.uint8)) 80 | roi_x, roi_y, roi_w, roi_h = cv2.boundingRect(fg_pts) 81 | extend_size = 20 82 | if roi_x - extend_size >= 0: 83 | roi_x -= extend_size 84 | roi_w += extend_size 85 | if roi_y - extend_size >= 0: 86 | roi_y -= extend_size 87 | roi_h += extend_size 88 | roi_image = src_image[roi_y:roi_y + roi_h, roi_x:roi_x + roi_w, :] 89 | if np.any(np.shape(roi_image) < (3, 3)): 90 | return None 91 | return roi_image 92 | 93 | def _classify_image(self, input_image: np.ndarray, text=None): 94 | """ 95 | 96 | :param input_image: 97 | :return: 98 | """ 99 | image = Image.fromarray(input_image) 100 | image = self.clip_preprocess(image).unsqueeze(0).to(self.device) 101 | if text is None: 102 | 
logits_per_image, logits_per_text = self.clip_model(image, self.obj365_text_token) 103 | else: 104 | if self.text_token is None: 105 | text_token = tokenize(texts=text).to(self.device) 106 | logits_per_image, logits_per_text = self.clip_model(image, text_token) 107 | else: 108 | logits_per_image, logits_per_text = self.clip_model(image, self.text_token) 109 | probs = logits_per_image.softmax(dim=-1).cpu().numpy()[0, :] 110 | cls_id = np.argmax(probs) 111 | score = probs[cls_id] 112 | if text is None: 113 | if score < 0.15: 114 | cls_id = probs.shape[0] - 1 115 | return cls_id 116 | else: 117 | if score < self.cls_score_thresh: 118 | cls_id = probs.shape[0] - 1 119 | return cls_id 120 | 121 | def _classify_mask(self, input_image, mask, text=None): 122 | """ 123 | 124 | :param input_image: 125 | :param mask: 126 | :return: 127 | """ 128 | bboxes_cls_names = [] 129 | for idx, bbox in enumerate(mask['bboxes']): 130 | roi_image = self._crop_rotate_image_roi(input_image, mask['segmentations'][idx]) 131 | if roi_image is None: 132 | cls_name = 'background' 133 | bboxes_cls_names.append(cls_name) 134 | continue 135 | # cv2.imwrite('{:d}_mask.png'.format(idx), roi_image[:, :, (2, 1, 0)]) 136 | cls_id = self._classify_image(roi_image, text=text) 137 | if text is None: 138 | cls_name = self.obj365_text_prompts[cls_id].split('a photo of')[1].strip(' ') 139 | bboxes_cls_names.append(cls_name) 140 | else: 141 | cls_name = text[cls_id] 142 | if cls_name.startswith('a photo of'): 143 | cls_name = cls_name.split(' ')[3] 144 | bboxes_cls_names.append(cls_name) 145 | 146 | mask['bbox_cls_names'] = bboxes_cls_names 147 | return 148 | 149 | def seg_image(self, input_image_path, unique_label=None, use_text_prefix=False): 150 | """ 151 | 152 | :param input_image_path: 153 | :param unique_label: 154 | :param use_text_prefix: 155 | :return: 156 | """ 157 | # read input image 158 | input_image = cv2.imread(input_image_path, cv2.IMREAD_COLOR) 159 | if input_image.shape[0] > self.max_input_size[0] or input_image.shape[1] > self.max_input_size[1]: 160 | h, w, _ = input_image.shape 161 | hw_ratio = h / w if h > w else w / h 162 | if h > w: 163 | dsize = (int(self.max_input_size[1] / hw_ratio), self.max_input_size[1]) 164 | else: 165 | dsize = (self.max_input_size[0], int(self.max_input_size[0] / hw_ratio)) 166 | input_image = cv2.resize(input_image, dsize=dsize) 167 | input_image = cv2.cvtColor(input_image, cv2.COLOR_BGR2RGB) 168 | 169 | with torch.no_grad(): 170 | # extract mask from sam model 171 | masks = self._generate_sam_mask(input_image) 172 | # classify each mask's label 173 | if unique_label is None: 174 | self._classify_mask(input_image, masks, text=None) 175 | else: 176 | texts = utils.generate_text_prompts_for_instance_seg( 177 | unique_labels=unique_label, 178 | use_text_prefix=use_text_prefix 179 | ) 180 | self._set_text_tokens(texts) 181 | self._classify_mask(input_image, masks, text=texts) 182 | 183 | # visualize segmentation result 184 | input_image = cv2.cvtColor(input_image, cv2.COLOR_RGB2BGR) 185 | ins_seg_mask = utils.visualize_instance_seg_results(masks, draw_bbox=True) 186 | ins_seg_add = cv2.addWeighted(input_image, 0.5, ins_seg_mask, 0.5, 0.0) 187 | 188 | ret = { 189 | 'source': input_image, 190 | 'ins_seg_mask': ins_seg_mask, 191 | 'ins_seg_add': ins_seg_add 192 | } 193 | 194 | return ret 195 | -------------------------------------------------------------------------------- /models/detector/utils.py: -------------------------------------------------------------------------------- 1 | 
#!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # @Time : 2023-04-11 下午1:55 4 | # @Author : MaybeShewill-CV 5 | # @Site : 6 | # @File : utils.py 7 | # @IDE: PyCharm Community Edition 8 | """ 9 | 10 | """ 11 | import time 12 | 13 | import cv2 14 | import numpy as np 15 | 16 | 17 | def generate_color_pool(color_nums): 18 | """ 19 | 随机生成颜色池, 用来给不同的车道线染不同的颜色 20 | :param color_nums: 需要生成的颜色池大小 21 | :return: 22 | """ 23 | color_pool = [] 24 | 25 | np.random.seed(int(time.time())) 26 | 27 | for i in range(color_nums): 28 | b = int(np.random.randint(0, 255, dtype=np.uint8)) 29 | g = int(np.random.randint(0, 255, dtype=np.uint8)) 30 | r = int(np.random.randint(0, 255, dtype=np.uint8)) 31 | color_pool.append((b, g, r)) 32 | 33 | return color_pool 34 | 35 | 36 | def visualize_instance_seg_results(sam_masks, draw_bbox=False): 37 | """ 38 | 39 | :param sam_masks: 40 | :param draw_bbox: 41 | :return: 42 | """ 43 | seg_images = sam_masks['segmentations'] 44 | bboxes = sam_masks['bboxes'] 45 | bboxes_names = sam_masks['bbox_cls_names'] 46 | unique_names = list(np.unique(bboxes_names)) 47 | color_pool = generate_color_pool(color_nums=len(unique_names) + 1) 48 | 49 | mask_image = np.ones(shape=(seg_images[0].shape[0], seg_images[0].shape[1], 3), dtype=np.uint8) 50 | for idx, _ in enumerate(bboxes): 51 | if bboxes_names[idx] == 'background': 52 | continue 53 | color_id = unique_names.index(bboxes_names[idx]) 54 | color = color_pool[color_id] 55 | # draw mask 56 | mask_image[:, :, 0][seg_images[idx]] = color[0] 57 | mask_image[:, :, 1][seg_images[idx]] = color[1] 58 | mask_image[:, :, 2][seg_images[idx]] = color[2] 59 | # draw bbox 60 | if draw_bbox: 61 | for idx, _ in enumerate(bboxes): 62 | if bboxes_names[idx] == 'background': 63 | continue 64 | color_id = unique_names.index(bboxes_names[idx]) 65 | color = color_pool[color_id] 66 | bbox_pt1 = [bboxes[idx][0], bboxes[idx][1]] 67 | bbox_pt1 = [int(tmp) for tmp in bbox_pt1] 68 | bbox_pt2 = [bboxes[idx][0] + bboxes[idx][2], bboxes[idx][1] + bboxes[idx][3]] 69 | bbox_pt2 = [int(tmp) for tmp in bbox_pt2] 70 | cv2.rectangle(mask_image, bbox_pt1, bbox_pt2, color, 2) 71 | text = bboxes_names[idx].split(',')[0] 72 | org = [bbox_pt1[0] - 10, bbox_pt1[1] - 10] 73 | cv2.putText(mask_image, text, org, cv2.FONT_HERSHEY_PLAIN, 1.5, (0, 0, 255), 2) 74 | 75 | return mask_image 76 | 77 | 78 | def generate_imagenet_classification_text_prompts(): 79 | """ 80 | 81 | :return: 82 | """ 83 | text_prefix = 'a photo of' 84 | text_prompts = open('./data/resources/imagenet-classes.txt', 'r').readlines() 85 | text_prompts = list(map(lambda x: x.rstrip('\r').rstrip('\n'), text_prompts)) 86 | text_prompts = list(map(lambda x: ' '.join([text_prefix, x]), text_prompts)) 87 | 88 | return text_prompts 89 | 90 | 91 | def generate_object365_text_prompts(): 92 | """ 93 | 94 | :return: 95 | """ 96 | text_prefix = 'a photo of' 97 | text_prompts = open('./data/resources/obj365-classes.txt', 'r').readlines() 98 | text_prompts = list(map(lambda x: x.rstrip('\r').rstrip('\n').split(' ')[0], text_prompts)) 99 | text_prompts = list(map(lambda x: ' '.join([text_prefix, x]), text_prompts)) 100 | 101 | return text_prompts 102 | 103 | 104 | def generate_text_prompts_for_instance_seg(unique_labels, use_text_prefix=True): 105 | """ 106 | 107 | :param unique_labels: 108 | :param use_text_prefix: 109 | :return: 110 | """ 111 | if unique_labels[-1] != 'background': 112 | unique_labels.append('background') 113 | text_pre = 'a photo of {:s}' 114 | if use_text_prefix: 115 | text_prompts = 
[text_pre.format(tmp) for tmp in unique_labels] 116 | else: 117 | text_prompts = list(unique_labels) 118 | return text_prompts 119 | -------------------------------------------------------------------------------- /models/sam/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | from .build_sam import ( 8 | build_sam, 9 | build_sam_vit_h, 10 | build_sam_vit_l, 11 | build_sam_vit_b, 12 | build_sam_vit_t, 13 | sam_model_registry, 14 | ) 15 | from .predictor import SamPredictor 16 | from .automatic_mask_generator import SamAutomaticMaskGenerator 17 | -------------------------------------------------------------------------------- /models/sam/automatic_mask_generator.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import numpy as np 8 | import torch 9 | from torchvision.ops.boxes import batched_nms, box_area # type: ignore 10 | 11 | from typing import Any, Dict, List, Optional, Tuple 12 | 13 | from .modeling import Sam 14 | from .predictor import SamPredictor 15 | from .utils.amg import ( 16 | MaskData, 17 | area_from_rle, 18 | batch_iterator, 19 | batched_mask_to_box, 20 | box_xyxy_to_xywh, 21 | build_all_layer_point_grids, 22 | calculate_stability_score, 23 | coco_encode_rle, 24 | generate_crop_boxes, 25 | is_box_near_crop_edge, 26 | mask_to_rle_pytorch, 27 | remove_small_regions, 28 | rle_to_mask, 29 | uncrop_boxes_xyxy, 30 | uncrop_masks, 31 | uncrop_points, 32 | ) 33 | 34 | 35 | class SamAutomaticMaskGenerator: 36 | def __init__( 37 | self, 38 | model: Sam, 39 | points_per_side: Optional[int] = 32, 40 | points_per_batch: int = 64, 41 | pred_iou_thresh: float = 0.88, 42 | stability_score_thresh: float = 0.95, 43 | stability_score_offset: float = 1.0, 44 | box_nms_thresh: float = 0.7, 45 | crop_n_layers: int = 0, 46 | crop_nms_thresh: float = 0.7, 47 | crop_overlap_ratio: float = 512 / 1500, 48 | crop_n_points_downscale_factor: int = 1, 49 | point_grids: Optional[List[np.ndarray]] = None, 50 | min_mask_region_area: int = 0, 51 | output_mode: str = "binary_mask", 52 | ) -> None: 53 | """ 54 | Using a SAM model, generates masks for the entire image. 55 | Generates a grid of point prompts over the image, then filters 56 | low quality and duplicate masks. The default settings are chosen 57 | for SAM with a ViT-H backbone. 58 | 59 | Arguments: 60 | model (Sam): The SAM model to use for mask prediction. 61 | points_per_side (int or None): The number of points to be sampled 62 | along one side of the image. The total number of points is 63 | points_per_side**2. If None, 'point_grids' must provide explicit 64 | point sampling. 65 | points_per_batch (int): Sets the number of points run simultaneously 66 | by the model. Higher numbers may be faster but use more GPU memory. 67 | pred_iou_thresh (float): A filtering threshold in [0,1], using the 68 | model's predicted mask quality. 69 | stability_score_thresh (float): A filtering threshold in [0,1], using 70 | the stability of the mask under changes to the cutoff used to binarize 71 | the model's mask predictions. 
72 | stability_score_offset (float): The amount to shift the cutoff when 73 | calculated the stability score. 74 | box_nms_thresh (float): The box IoU cutoff used by non-maximal 75 | suppression to filter duplicate masks. 76 | crops_n_layers (int): If >0, mask prediction will be run again on 77 | crops of the image. Sets the number of layers to run, where each 78 | layer has 2**i_layer number of image crops. 79 | crops_nms_thresh (float): The box IoU cutoff used by non-maximal 80 | suppression to filter duplicate masks between different crops. 81 | crop_overlap_ratio (float): Sets the degree to which crops overlap. 82 | In the first crop layer, crops will overlap by this fraction of 83 | the image length. Later layers with more crops scale down this overlap. 84 | crop_n_points_downscale_factor (int): The number of points-per-side 85 | sampled in layer n is scaled down by crop_n_points_downscale_factor**n. 86 | point_grids (list(np.ndarray) or None): A list over explicit grids 87 | of points used for sampling, normalized to [0,1]. The nth grid in the 88 | list is used in the nth crop layer. Exclusive with points_per_side. 89 | min_mask_region_area (int): If >0, postprocessing will be applied 90 | to remove disconnected regions and holes in masks with area smaller 91 | than min_mask_region_area. Requires opencv. 92 | output_mode (str): The form masks are returned in. Can be 'binary_mask', 93 | 'uncompressed_rle', or 'coco_rle'. 'coco_rle' requires pycocotools. 94 | For large resolutions, 'binary_mask' may consume large amounts of 95 | memory. 96 | """ 97 | 98 | assert (points_per_side is None) != ( 99 | point_grids is None 100 | ), "Exactly one of points_per_side or point_grid must be provided." 101 | if points_per_side is not None: 102 | self.point_grids = build_all_layer_point_grids( 103 | points_per_side, 104 | crop_n_layers, 105 | crop_n_points_downscale_factor, 106 | ) 107 | elif point_grids is not None: 108 | self.point_grids = point_grids 109 | else: 110 | raise ValueError("Can't have both points_per_side and point_grid be None.") 111 | 112 | assert output_mode in [ 113 | "binary_mask", 114 | "uncompressed_rle", 115 | "coco_rle", 116 | ], f"Unknown output_mode {output_mode}." 117 | if output_mode == "coco_rle": 118 | from pycocotools import mask as mask_utils # type: ignore # noqa: F401 119 | 120 | if min_mask_region_area > 0: 121 | import cv2 # type: ignore # noqa: F401 122 | 123 | self.predictor = SamPredictor(model) 124 | self.points_per_batch = points_per_batch 125 | self.pred_iou_thresh = pred_iou_thresh 126 | self.stability_score_thresh = stability_score_thresh 127 | self.stability_score_offset = stability_score_offset 128 | self.box_nms_thresh = box_nms_thresh 129 | self.crop_n_layers = crop_n_layers 130 | self.crop_nms_thresh = crop_nms_thresh 131 | self.crop_overlap_ratio = crop_overlap_ratio 132 | self.crop_n_points_downscale_factor = crop_n_points_downscale_factor 133 | self.min_mask_region_area = min_mask_region_area 134 | self.output_mode = output_mode 135 | 136 | @torch.no_grad() 137 | def generate(self, image: np.ndarray) -> List[Dict[str, Any]]: 138 | """ 139 | Generates masks for the given image. 140 | 141 | Arguments: 142 | image (np.ndarray): The image to generate masks for, in HWC uint8 format. 143 | 144 | Returns: 145 | list(dict(str, any)): A list over records for masks. Each record is 146 | a dict containing the following keys: 147 | segmentation (dict(str, any) or np.ndarray): The mask. If 148 | output_mode='binary_mask', is an array of shape HW. 
Otherwise, 149 | is a dictionary containing the RLE. 150 | bbox (list(float)): The box around the mask, in XYWH format. 151 | area (int): The area in pixels of the mask. 152 | predicted_iou (float): The model's own prediction of the mask's 153 | quality. This is filtered by the pred_iou_thresh parameter. 154 | point_coords (list(list(float))): The point coordinates input 155 | to the model to generate this mask. 156 | stability_score (float): A measure of the mask's quality. This 157 | is filtered on using the stability_score_thresh parameter. 158 | crop_box (list(float)): The crop of the image used to generate 159 | the mask, given in XYWH format. 160 | """ 161 | 162 | # Generate masks 163 | mask_data = self._generate_masks(image) 164 | 165 | # Filter small disconnected regions and holes in masks 166 | if self.min_mask_region_area > 0: 167 | mask_data = self.postprocess_small_regions( 168 | mask_data, 169 | self.min_mask_region_area, 170 | max(self.box_nms_thresh, self.crop_nms_thresh), 171 | ) 172 | 173 | # Encode masks 174 | if self.output_mode == "coco_rle": 175 | mask_data["segmentations"] = [coco_encode_rle(rle) for rle in mask_data["rles"]] 176 | elif self.output_mode == "binary_mask": 177 | mask_data["segmentations"] = [rle_to_mask(rle) for rle in mask_data["rles"]] 178 | else: 179 | mask_data["segmentations"] = mask_data["rles"] 180 | 181 | # Write mask records 182 | curr_anns = [] 183 | for idx in range(len(mask_data["segmentations"])): 184 | ann = { 185 | "segmentation": mask_data["segmentations"][idx], 186 | "area": area_from_rle(mask_data["rles"][idx]), 187 | "bbox": box_xyxy_to_xywh(mask_data["boxes"][idx]).tolist(), 188 | "predicted_iou": mask_data["iou_preds"][idx].item(), 189 | "point_coords": [mask_data["points"][idx].tolist()], 190 | "stability_score": mask_data["stability_score"][idx].item(), 191 | "crop_box": box_xyxy_to_xywh(mask_data["crop_boxes"][idx]).tolist(), 192 | } 193 | curr_anns.append(ann) 194 | 195 | return curr_anns 196 | 197 | def _generate_masks(self, image: np.ndarray) -> MaskData: 198 | orig_size = image.shape[:2] 199 | crop_boxes, layer_idxs = generate_crop_boxes( 200 | orig_size, self.crop_n_layers, self.crop_overlap_ratio 201 | ) 202 | 203 | # Iterate over image crops 204 | data = MaskData() 205 | for crop_box, layer_idx in zip(crop_boxes, layer_idxs): 206 | crop_data = self._process_crop(image, crop_box, layer_idx, orig_size) 207 | data.cat(crop_data) 208 | 209 | # Remove duplicate masks between crops 210 | if len(crop_boxes) > 1: 211 | # Prefer masks from smaller crops 212 | scores = 1 / box_area(data["crop_boxes"]) 213 | scores = scores.to(data["boxes"].device) 214 | keep_by_nms = batched_nms( 215 | data["boxes"].float(), 216 | scores, 217 | torch.zeros(len(data["boxes"])), # categories 218 | iou_threshold=self.crop_nms_thresh, 219 | ) 220 | data.filter(keep_by_nms) 221 | 222 | data.to_numpy() 223 | return data 224 | 225 | def _process_crop( 226 | self, 227 | image: np.ndarray, 228 | crop_box: List[int], 229 | crop_layer_idx: int, 230 | orig_size: Tuple[int, ...], 231 | ) -> MaskData: 232 | # Crop the image and calculate embeddings 233 | x0, y0, x1, y1 = crop_box 234 | cropped_im = image[y0:y1, x0:x1, :] 235 | cropped_im_size = cropped_im.shape[:2] 236 | self.predictor.set_image(cropped_im) 237 | 238 | # Get points for this crop 239 | points_scale = np.array(cropped_im_size)[None, ::-1] 240 | points_for_image = self.point_grids[crop_layer_idx] * points_scale 241 | 242 | # Generate masks for this crop in batches 243 | data = MaskData() 244 | for 
(points,) in batch_iterator(self.points_per_batch, points_for_image): 245 | batch_data = self._process_batch(points, cropped_im_size, crop_box, orig_size) 246 | data.cat(batch_data) 247 | del batch_data 248 | self.predictor.reset_image() 249 | 250 | # Remove duplicates within this crop. 251 | keep_by_nms = batched_nms( 252 | data["boxes"].float(), 253 | data["iou_preds"], 254 | torch.zeros(len(data["boxes"])), # categories 255 | iou_threshold=self.box_nms_thresh, 256 | ) 257 | data.filter(keep_by_nms) 258 | 259 | # Return to the original image frame 260 | data["boxes"] = uncrop_boxes_xyxy(data["boxes"], crop_box) 261 | data["points"] = uncrop_points(data["points"], crop_box) 262 | data["crop_boxes"] = torch.tensor([crop_box for _ in range(len(data["rles"]))]) 263 | 264 | return data 265 | 266 | def _process_batch( 267 | self, 268 | points: np.ndarray, 269 | im_size: Tuple[int, ...], 270 | crop_box: List[int], 271 | orig_size: Tuple[int, ...], 272 | ) -> MaskData: 273 | orig_h, orig_w = orig_size 274 | 275 | # Run model on this batch 276 | transformed_points = self.predictor.transform.apply_coords(points, im_size) 277 | in_points = torch.as_tensor(transformed_points, device=self.predictor.device) 278 | in_labels = torch.ones(in_points.shape[0], dtype=torch.int, device=in_points.device) 279 | masks, iou_preds, _ = self.predictor.predict_torch( 280 | in_points[:, None, :], 281 | in_labels[:, None], 282 | multimask_output=True, 283 | return_logits=True, 284 | ) 285 | 286 | # Serialize predictions and store in MaskData 287 | data = MaskData( 288 | masks=masks.flatten(0, 1), 289 | iou_preds=iou_preds.flatten(0, 1), 290 | points=torch.as_tensor(points.repeat(masks.shape[1], axis=0)), 291 | ) 292 | del masks 293 | 294 | # Filter by predicted IoU 295 | if self.pred_iou_thresh > 0.0: 296 | keep_mask = data["iou_preds"] > self.pred_iou_thresh 297 | data.filter(keep_mask) 298 | 299 | # Calculate stability score 300 | data["stability_score"] = calculate_stability_score( 301 | data["masks"], self.predictor.model.mask_threshold, self.stability_score_offset 302 | ) 303 | if self.stability_score_thresh > 0.0: 304 | keep_mask = data["stability_score"] >= self.stability_score_thresh 305 | data.filter(keep_mask) 306 | 307 | # Threshold masks and calculate boxes 308 | data["masks"] = data["masks"] > self.predictor.model.mask_threshold 309 | data["boxes"] = batched_mask_to_box(data["masks"]) 310 | 311 | # Filter boxes that touch crop boundaries 312 | keep_mask = ~is_box_near_crop_edge(data["boxes"], crop_box, [0, 0, orig_w, orig_h]) 313 | if not torch.all(keep_mask): 314 | data.filter(keep_mask) 315 | 316 | # Compress to RLE 317 | data["masks"] = uncrop_masks(data["masks"], crop_box, orig_h, orig_w) 318 | data["rles"] = mask_to_rle_pytorch(data["masks"]) 319 | del data["masks"] 320 | 321 | return data 322 | 323 | @staticmethod 324 | def postprocess_small_regions( 325 | mask_data: MaskData, min_area: int, nms_thresh: float 326 | ) -> MaskData: 327 | """ 328 | Removes small disconnected regions and holes in masks, then reruns 329 | box NMS to remove any new duplicates. 330 | 331 | Edits mask_data in place. 332 | 333 | Requires open-cv as a dependency. 
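
        Example (illustrative values for min_area and nms_thresh):
            mask_data = SamAutomaticMaskGenerator.postprocess_small_regions(
                mask_data, min_area=100, nms_thresh=0.7
            )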
334 | """ 335 | if len(mask_data["rles"]) == 0: 336 | return mask_data 337 | 338 | # Filter small disconnected regions and holes 339 | new_masks = [] 340 | scores = [] 341 | for rle in mask_data["rles"]: 342 | mask = rle_to_mask(rle) 343 | 344 | mask, changed = remove_small_regions(mask, min_area, mode="holes") 345 | unchanged = not changed 346 | mask, changed = remove_small_regions(mask, min_area, mode="islands") 347 | unchanged = unchanged and not changed 348 | 349 | new_masks.append(torch.as_tensor(mask).unsqueeze(0)) 350 | # Give score=0 to changed masks and score=1 to unchanged masks 351 | # so NMS will prefer ones that didn't need postprocessing 352 | scores.append(float(unchanged)) 353 | 354 | # Recalculate boxes and remove any new duplicates 355 | masks = torch.cat(new_masks, dim=0) 356 | boxes = batched_mask_to_box(masks) 357 | keep_by_nms = batched_nms( 358 | boxes.float(), 359 | torch.as_tensor(scores), 360 | torch.zeros(len(boxes)), # categories 361 | iou_threshold=nms_thresh, 362 | ) 363 | 364 | # Only recalculate RLEs for masks that have changed 365 | for i_mask in keep_by_nms: 366 | if scores[i_mask] == 0.0: 367 | mask_torch = masks[i_mask].unsqueeze(0) 368 | mask_data["rles"][i_mask] = mask_to_rle_pytorch(mask_torch)[0] 369 | mask_data["boxes"][i_mask] = boxes[i_mask] # update res directly 370 | mask_data.filter(keep_by_nms) 371 | 372 | return mask_data 373 | -------------------------------------------------------------------------------- /models/sam/build_sam.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
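
# Illustrative usage sketch (the checkpoint path below is an assumption; see
# scripts/download_pretrained_ckpt.sh for the pretrained weights this repo expects):
#
#     from models.sam import sam_model_registry, SamAutomaticMaskGenerator
#     sam = sam_model_registry['vit_b'](checkpoint='./pretrained/sam_vit_b.pth')
#     mask_generator = SamAutomaticMaskGenerator(sam)
#     masks = mask_generator.generate(image)  # image: HWC uint8 ndarray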
6 | 7 | import torch 8 | 9 | from functools import partial 10 | 11 | from .modeling import ImageEncoderViT, MaskDecoder, PromptEncoder, Sam, TwoWayTransformer, TinyViT 12 | 13 | 14 | def build_sam_vit_h(checkpoint=None): 15 | return _build_sam( 16 | encoder_embed_dim=1280, 17 | encoder_depth=32, 18 | encoder_num_heads=16, 19 | encoder_global_attn_indexes=[7, 15, 23, 31], 20 | checkpoint=checkpoint, 21 | ) 22 | 23 | 24 | build_sam = build_sam_vit_h 25 | 26 | 27 | def build_sam_vit_l(checkpoint=None): 28 | return _build_sam( 29 | encoder_embed_dim=1024, 30 | encoder_depth=24, 31 | encoder_num_heads=16, 32 | encoder_global_attn_indexes=[5, 11, 17, 23], 33 | checkpoint=checkpoint, 34 | ) 35 | 36 | 37 | def build_sam_vit_b(checkpoint=None): 38 | return _build_sam( 39 | encoder_embed_dim=768, 40 | encoder_depth=12, 41 | encoder_num_heads=12, 42 | encoder_global_attn_indexes=[2, 5, 8, 11], 43 | checkpoint=checkpoint, 44 | ) 45 | 46 | 47 | def build_sam_vit_t(checkpoint=None): 48 | prompt_embed_dim = 256 49 | image_size = 1024 50 | vit_patch_size = 16 51 | image_embedding_size = image_size // vit_patch_size 52 | mobile_sam = Sam( 53 | image_encoder=TinyViT(img_size=1024, in_chans=3, num_classes=1000, 54 | embed_dims=[64, 128, 160, 320], 55 | depths=[2, 2, 6, 2], 56 | num_heads=[2, 4, 5, 10], 57 | window_sizes=[7, 7, 14, 7], 58 | mlp_ratio=4., 59 | drop_rate=0., 60 | drop_path_rate=0.0, 61 | use_checkpoint=False, 62 | mbconv_expand_ratio=4.0, 63 | local_conv_size=3, 64 | layer_lr_decay=0.8 65 | ), 66 | prompt_encoder=PromptEncoder( 67 | embed_dim=prompt_embed_dim, 68 | image_embedding_size=(image_embedding_size, image_embedding_size), 69 | input_image_size=(image_size, image_size), 70 | mask_in_chans=16, 71 | ), 72 | mask_decoder=MaskDecoder( 73 | num_multimask_outputs=3, 74 | transformer=TwoWayTransformer( 75 | depth=2, 76 | embedding_dim=prompt_embed_dim, 77 | mlp_dim=2048, 78 | num_heads=8, 79 | ), 80 | transformer_dim=prompt_embed_dim, 81 | iou_head_depth=3, 82 | iou_head_hidden_dim=256, 83 | ), 84 | pixel_mean=[123.675, 116.28, 103.53], 85 | pixel_std=[58.395, 57.12, 57.375], 86 | ) 87 | 88 | mobile_sam.eval() 89 | if checkpoint is not None: 90 | with open(checkpoint, "rb") as f: 91 | state_dict = torch.load(f) 92 | mobile_sam.load_state_dict(state_dict) 93 | return mobile_sam 94 | 95 | 96 | sam_model_registry = { 97 | "default": build_sam, 98 | "vit_h": build_sam, 99 | "vit_l": build_sam_vit_l, 100 | "vit_b": build_sam_vit_b, 101 | "vit_t": build_sam_vit_t, 102 | } 103 | 104 | 105 | def _build_sam( 106 | encoder_embed_dim, 107 | encoder_depth, 108 | encoder_num_heads, 109 | encoder_global_attn_indexes, 110 | checkpoint=None, 111 | ): 112 | prompt_embed_dim = 256 113 | image_size = 1024 114 | vit_patch_size = 16 115 | image_embedding_size = image_size // vit_patch_size 116 | sam = Sam( 117 | image_encoder=ImageEncoderViT( 118 | depth=encoder_depth, 119 | embed_dim=encoder_embed_dim, 120 | img_size=image_size, 121 | mlp_ratio=4, 122 | norm_layer=partial(torch.nn.LayerNorm, eps=1e-6), 123 | num_heads=encoder_num_heads, 124 | patch_size=vit_patch_size, 125 | qkv_bias=True, 126 | use_rel_pos=True, 127 | global_attn_indexes=encoder_global_attn_indexes, 128 | window_size=14, 129 | out_chans=prompt_embed_dim, 130 | ), 131 | prompt_encoder=PromptEncoder( 132 | embed_dim=prompt_embed_dim, 133 | image_embedding_size=(image_embedding_size, image_embedding_size), 134 | input_image_size=(image_size, image_size), 135 | mask_in_chans=16, 136 | ), 137 | mask_decoder=MaskDecoder( 138 | num_multimask_outputs=3, 
139 | transformer=TwoWayTransformer( 140 | depth=2, 141 | embedding_dim=prompt_embed_dim, 142 | mlp_dim=2048, 143 | num_heads=8, 144 | ), 145 | transformer_dim=prompt_embed_dim, 146 | iou_head_depth=3, 147 | iou_head_hidden_dim=256, 148 | ), 149 | pixel_mean=[123.675, 116.28, 103.53], 150 | pixel_std=[58.395, 57.12, 57.375], 151 | ) 152 | sam.eval() 153 | if checkpoint is not None: 154 | with open(checkpoint, "rb") as f: 155 | state_dict = torch.load(f) 156 | sam.load_state_dict(state_dict) 157 | return sam 158 | -------------------------------------------------------------------------------- /models/sam/modeling/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | from .sam import Sam 8 | from .image_encoder import ImageEncoderViT 9 | from .mask_decoder import MaskDecoder 10 | from .prompt_encoder import PromptEncoder 11 | from .transformer import TwoWayTransformer 12 | from .tiny_vit_sam import TinyViT 13 | -------------------------------------------------------------------------------- /models/sam/modeling/common.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import torch 8 | import torch.nn as nn 9 | 10 | from typing import Type 11 | 12 | 13 | class MLPBlock(nn.Module): 14 | def __init__( 15 | self, 16 | embedding_dim: int, 17 | mlp_dim: int, 18 | act: Type[nn.Module] = nn.GELU, 19 | ) -> None: 20 | super().__init__() 21 | self.lin1 = nn.Linear(embedding_dim, mlp_dim) 22 | self.lin2 = nn.Linear(mlp_dim, embedding_dim) 23 | self.act = act() 24 | 25 | def forward(self, x: torch.Tensor) -> torch.Tensor: 26 | return self.lin2(self.act(self.lin1(x))) 27 | 28 | 29 | # From https://github.com/facebookresearch/detectron2/blob/main/detectron2/layers/batch_norm.py # noqa 30 | # Itself from https://github.com/facebookresearch/ConvNeXt/blob/d1fa8f6fef0a165b27399986cc2bdacc92777e40/models/convnext.py#L119 # noqa 31 | class LayerNorm2d(nn.Module): 32 | def __init__(self, num_channels: int, eps: float = 1e-6) -> None: 33 | super().__init__() 34 | self.weight = nn.Parameter(torch.ones(num_channels)) 35 | self.bias = nn.Parameter(torch.zeros(num_channels)) 36 | self.eps = eps 37 | 38 | def forward(self, x: torch.Tensor) -> torch.Tensor: 39 | u = x.mean(1, keepdim=True) 40 | s = (x - u).pow(2).mean(1, keepdim=True) 41 | x = (x - u) / torch.sqrt(s + self.eps) 42 | x = self.weight[:, None, None] * x + self.bias[:, None, None] 43 | return x 44 | -------------------------------------------------------------------------------- /models/sam/modeling/image_encoder.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
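
# Shape sketch (illustrative, using the default arguments of ImageEncoderViT below):
# a (B, 3, 1024, 1024) image batch is split into 16x16 patches and encoded by the
# neck into a (B, 256, 64, 64) feature map, e.g.
#
#     encoder = ImageEncoderViT()                      # img_size=1024, patch_size=16, out_chans=256
#     feats = encoder(torch.zeros(1, 3, 1024, 1024))   # -> torch.Size([1, 256, 64, 64])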
6 | 7 | import torch 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | 11 | from typing import Optional, Tuple, Type 12 | 13 | from .common import LayerNorm2d, MLPBlock 14 | 15 | 16 | # This class and its supporting functions below lightly adapted from the ViTDet backbone available at: https://github.com/facebookresearch/detectron2/blob/main/detectron2/modeling/backbone/vit.py # noqa 17 | class ImageEncoderViT(nn.Module): 18 | def __init__( 19 | self, 20 | img_size: int = 1024, 21 | patch_size: int = 16, 22 | in_chans: int = 3, 23 | embed_dim: int = 768, 24 | depth: int = 12, 25 | num_heads: int = 12, 26 | mlp_ratio: float = 4.0, 27 | out_chans: int = 256, 28 | qkv_bias: bool = True, 29 | norm_layer: Type[nn.Module] = nn.LayerNorm, 30 | act_layer: Type[nn.Module] = nn.GELU, 31 | use_abs_pos: bool = True, 32 | use_rel_pos: bool = False, 33 | rel_pos_zero_init: bool = True, 34 | window_size: int = 0, 35 | global_attn_indexes: Tuple[int, ...] = (), 36 | ) -> None: 37 | """ 38 | Args: 39 | img_size (int): Input image size. 40 | patch_size (int): Patch size. 41 | in_chans (int): Number of input image channels. 42 | embed_dim (int): Patch embedding dimension. 43 | depth (int): Depth of ViT. 44 | num_heads (int): Number of attention heads in each ViT block. 45 | mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. 46 | qkv_bias (bool): If True, add a learnable bias to query, key, value. 47 | norm_layer (nn.Module): Normalization layer. 48 | act_layer (nn.Module): Activation layer. 49 | use_abs_pos (bool): If True, use absolute positional embeddings. 50 | use_rel_pos (bool): If True, add relative positional embeddings to the attention map. 51 | rel_pos_zero_init (bool): If True, zero initialize relative positional parameters. 52 | window_size (int): Window size for window attention blocks. 53 | global_attn_indexes (list): Indexes for blocks using global attention. 54 | """ 55 | super().__init__() 56 | self.img_size = img_size 57 | 58 | self.patch_embed = PatchEmbed( 59 | kernel_size=(patch_size, patch_size), 60 | stride=(patch_size, patch_size), 61 | in_chans=in_chans, 62 | embed_dim=embed_dim, 63 | ) 64 | 65 | self.pos_embed: Optional[nn.Parameter] = None 66 | if use_abs_pos: 67 | # Initialize absolute positional embedding with pretrain image size. 
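            # (Illustrative note: with the default img_size=1024 and patch_size=16 this
            # parameter has shape (1, 64, 64, embed_dim) and is added to the patch
            # embeddings in forward().)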
68 | self.pos_embed = nn.Parameter( 69 | torch.zeros(1, img_size // patch_size, img_size // patch_size, embed_dim) 70 | ) 71 | 72 | self.blocks = nn.ModuleList() 73 | for i in range(depth): 74 | block = Block( 75 | dim=embed_dim, 76 | num_heads=num_heads, 77 | mlp_ratio=mlp_ratio, 78 | qkv_bias=qkv_bias, 79 | norm_layer=norm_layer, 80 | act_layer=act_layer, 81 | use_rel_pos=use_rel_pos, 82 | rel_pos_zero_init=rel_pos_zero_init, 83 | window_size=window_size if i not in global_attn_indexes else 0, 84 | input_size=(img_size // patch_size, img_size // patch_size), 85 | ) 86 | self.blocks.append(block) 87 | 88 | self.neck = nn.Sequential( 89 | nn.Conv2d( 90 | embed_dim, 91 | out_chans, 92 | kernel_size=1, 93 | bias=False, 94 | ), 95 | LayerNorm2d(out_chans), 96 | nn.Conv2d( 97 | out_chans, 98 | out_chans, 99 | kernel_size=3, 100 | padding=1, 101 | bias=False, 102 | ), 103 | LayerNorm2d(out_chans), 104 | ) 105 | 106 | def forward(self, x: torch.Tensor) -> torch.Tensor: 107 | x = self.patch_embed(x) 108 | if self.pos_embed is not None: 109 | x = x + self.pos_embed 110 | 111 | for blk in self.blocks: 112 | x = blk(x) 113 | 114 | x = self.neck(x.permute(0, 3, 1, 2)) 115 | 116 | return x 117 | 118 | 119 | class Block(nn.Module): 120 | """Transformer blocks with support of window attention and residual propagation blocks""" 121 | 122 | def __init__( 123 | self, 124 | dim: int, 125 | num_heads: int, 126 | mlp_ratio: float = 4.0, 127 | qkv_bias: bool = True, 128 | norm_layer: Type[nn.Module] = nn.LayerNorm, 129 | act_layer: Type[nn.Module] = nn.GELU, 130 | use_rel_pos: bool = False, 131 | rel_pos_zero_init: bool = True, 132 | window_size: int = 0, 133 | input_size: Optional[Tuple[int, int]] = None, 134 | ) -> None: 135 | """ 136 | Args: 137 | dim (int): Number of input channels. 138 | num_heads (int): Number of attention heads in each ViT block. 139 | mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. 140 | qkv_bias (bool): If True, add a learnable bias to query, key, value. 141 | norm_layer (nn.Module): Normalization layer. 142 | act_layer (nn.Module): Activation layer. 143 | use_rel_pos (bool): If True, add relative positional embeddings to the attention map. 144 | rel_pos_zero_init (bool): If True, zero initialize relative positional parameters. 145 | window_size (int): Window size for window attention blocks. If it equals 0, then 146 | use global attention. 147 | input_size (int or None): Input resolution for calculating the relative positional 148 | parameter size. 
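
        Example (illustrative): Block(dim=768, num_heads=12, use_rel_pos=True,
        window_size=14, input_size=(64, 64)) builds a windowed-attention block,
        while window_size=0 falls back to global attention.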
149 | """ 150 | super().__init__() 151 | self.norm1 = norm_layer(dim) 152 | self.attn = Attention( 153 | dim, 154 | num_heads=num_heads, 155 | qkv_bias=qkv_bias, 156 | use_rel_pos=use_rel_pos, 157 | rel_pos_zero_init=rel_pos_zero_init, 158 | input_size=input_size if window_size == 0 else (window_size, window_size), 159 | ) 160 | 161 | self.norm2 = norm_layer(dim) 162 | self.mlp = MLPBlock(embedding_dim=dim, mlp_dim=int(dim * mlp_ratio), act=act_layer) 163 | 164 | self.window_size = window_size 165 | 166 | def forward(self, x: torch.Tensor) -> torch.Tensor: 167 | shortcut = x 168 | x = self.norm1(x) 169 | # Window partition 170 | if self.window_size > 0: 171 | H, W = x.shape[1], x.shape[2] 172 | x, pad_hw = window_partition(x, self.window_size) 173 | 174 | x = self.attn(x) 175 | # Reverse window partition 176 | if self.window_size > 0: 177 | x = window_unpartition(x, self.window_size, pad_hw, (H, W)) 178 | 179 | x = shortcut + x 180 | x = x + self.mlp(self.norm2(x)) 181 | 182 | return x 183 | 184 | 185 | class Attention(nn.Module): 186 | """Multi-head Attention block with relative position embeddings.""" 187 | 188 | def __init__( 189 | self, 190 | dim: int, 191 | num_heads: int = 8, 192 | qkv_bias: bool = True, 193 | use_rel_pos: bool = False, 194 | rel_pos_zero_init: bool = True, 195 | input_size: Optional[Tuple[int, int]] = None, 196 | ) -> None: 197 | """ 198 | Args: 199 | dim (int): Number of input channels. 200 | num_heads (int): Number of attention heads. 201 | qkv_bias (bool: If True, add a learnable bias to query, key, value. 202 | rel_pos (bool): If True, add relative positional embeddings to the attention map. 203 | rel_pos_zero_init (bool): If True, zero initialize relative positional parameters. 204 | input_size (int or None): Input resolution for calculating the relative positional 205 | parameter size. 206 | """ 207 | super().__init__() 208 | self.num_heads = num_heads 209 | head_dim = dim // num_heads 210 | self.scale = head_dim**-0.5 211 | 212 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) 213 | self.proj = nn.Linear(dim, dim) 214 | 215 | self.use_rel_pos = use_rel_pos 216 | if self.use_rel_pos: 217 | assert ( 218 | input_size is not None 219 | ), "Input size must be provided if using relative positional encoding." 220 | # initialize relative positional embeddings 221 | self.rel_pos_h = nn.Parameter(torch.zeros(2 * input_size[0] - 1, head_dim)) 222 | self.rel_pos_w = nn.Parameter(torch.zeros(2 * input_size[1] - 1, head_dim)) 223 | 224 | def forward(self, x: torch.Tensor) -> torch.Tensor: 225 | B, H, W, _ = x.shape 226 | # qkv with shape (3, B, nHead, H * W, C) 227 | qkv = self.qkv(x).reshape(B, H * W, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4) 228 | # q, k, v with shape (B * nHead, H * W, C) 229 | q, k, v = qkv.reshape(3, B * self.num_heads, H * W, -1).unbind(0) 230 | 231 | attn = (q * self.scale) @ k.transpose(-2, -1) 232 | 233 | if self.use_rel_pos: 234 | attn = add_decomposed_rel_pos(attn, q, self.rel_pos_h, self.rel_pos_w, (H, W), (H, W)) 235 | 236 | attn = attn.softmax(dim=-1) 237 | x = (attn @ v).view(B, self.num_heads, H, W, -1).permute(0, 2, 3, 1, 4).reshape(B, H, W, -1) 238 | x = self.proj(x) 239 | 240 | return x 241 | 242 | 243 | def window_partition(x: torch.Tensor, window_size: int) -> Tuple[torch.Tensor, Tuple[int, int]]: 244 | """ 245 | Partition into non-overlapping windows with padding if needed. 246 | Args: 247 | x (tensor): input tokens with [B, H, W, C]. 248 | window_size (int): window size. 
249 | 250 | Returns: 251 | windows: windows after partition with [B * num_windows, window_size, window_size, C]. 252 | (Hp, Wp): padded height and width before partition 253 | """ 254 | B, H, W, C = x.shape 255 | 256 | pad_h = (window_size - H % window_size) % window_size 257 | pad_w = (window_size - W % window_size) % window_size 258 | if pad_h > 0 or pad_w > 0: 259 | x = F.pad(x, (0, 0, 0, pad_w, 0, pad_h)) 260 | Hp, Wp = H + pad_h, W + pad_w 261 | 262 | x = x.view(B, Hp // window_size, window_size, Wp // window_size, window_size, C) 263 | windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C) 264 | return windows, (Hp, Wp) 265 | 266 | 267 | def window_unpartition( 268 | windows: torch.Tensor, window_size: int, pad_hw: Tuple[int, int], hw: Tuple[int, int] 269 | ) -> torch.Tensor: 270 | """ 271 | Window unpartition into original sequences and removing padding. 272 | Args: 273 | x (tensor): input tokens with [B * num_windows, window_size, window_size, C]. 274 | window_size (int): window size. 275 | pad_hw (Tuple): padded height and width (Hp, Wp). 276 | hw (Tuple): original height and width (H, W) before padding. 277 | 278 | Returns: 279 | x: unpartitioned sequences with [B, H, W, C]. 280 | """ 281 | Hp, Wp = pad_hw 282 | H, W = hw 283 | B = windows.shape[0] // (Hp * Wp // window_size // window_size) 284 | x = windows.view(B, Hp // window_size, Wp // window_size, window_size, window_size, -1) 285 | x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, Hp, Wp, -1) 286 | 287 | if Hp > H or Wp > W: 288 | x = x[:, :H, :W, :].contiguous() 289 | return x 290 | 291 | 292 | def get_rel_pos(q_size: int, k_size: int, rel_pos: torch.Tensor) -> torch.Tensor: 293 | """ 294 | Get relative positional embeddings according to the relative positions of 295 | query and key sizes. 296 | Args: 297 | q_size (int): size of query q. 298 | k_size (int): size of key k. 299 | rel_pos (Tensor): relative position embeddings (L, C). 300 | 301 | Returns: 302 | Extracted positional embeddings according to relative positions. 303 | """ 304 | max_rel_dist = int(2 * max(q_size, k_size) - 1) 305 | # Interpolate rel pos if needed. 306 | if rel_pos.shape[0] != max_rel_dist: 307 | # Interpolate rel pos. 308 | rel_pos_resized = F.interpolate( 309 | rel_pos.reshape(1, rel_pos.shape[0], -1).permute(0, 2, 1), 310 | size=max_rel_dist, 311 | mode="linear", 312 | ) 313 | rel_pos_resized = rel_pos_resized.reshape(-1, max_rel_dist).permute(1, 0) 314 | else: 315 | rel_pos_resized = rel_pos 316 | 317 | # Scale the coords with short length if shapes for q and k are different. 318 | q_coords = torch.arange(q_size)[:, None] * max(k_size / q_size, 1.0) 319 | k_coords = torch.arange(k_size)[None, :] * max(q_size / k_size, 1.0) 320 | relative_coords = (q_coords - k_coords) + (k_size - 1) * max(q_size / k_size, 1.0) 321 | 322 | return rel_pos_resized[relative_coords.long()] 323 | 324 | 325 | def add_decomposed_rel_pos( 326 | attn: torch.Tensor, 327 | q: torch.Tensor, 328 | rel_pos_h: torch.Tensor, 329 | rel_pos_w: torch.Tensor, 330 | q_size: Tuple[int, int], 331 | k_size: Tuple[int, int], 332 | ) -> torch.Tensor: 333 | """ 334 | Calculate decomposed Relative Positional Embeddings from :paper:`mvitv2`. 335 | https://github.com/facebookresearch/mvit/blob/19786631e330df9f3622e5402b4a419a263a2c80/mvit/models/attention.py # noqa B950 336 | Args: 337 | attn (Tensor): attention map. 338 | q (Tensor): query q in the attention layer with shape (B, q_h * q_w, C). 
339 | rel_pos_h (Tensor): relative position embeddings (Lh, C) for height axis. 340 | rel_pos_w (Tensor): relative position embeddings (Lw, C) for width axis. 341 | q_size (Tuple): spatial sequence size of query q with (q_h, q_w). 342 | k_size (Tuple): spatial sequence size of key k with (k_h, k_w). 343 | 344 | Returns: 345 | attn (Tensor): attention map with added relative positional embeddings. 346 | """ 347 | q_h, q_w = q_size 348 | k_h, k_w = k_size 349 | Rh = get_rel_pos(q_h, k_h, rel_pos_h) 350 | Rw = get_rel_pos(q_w, k_w, rel_pos_w) 351 | 352 | B, _, dim = q.shape 353 | r_q = q.reshape(B, q_h, q_w, dim) 354 | rel_h = torch.einsum("bhwc,hkc->bhwk", r_q, Rh) 355 | rel_w = torch.einsum("bhwc,wkc->bhwk", r_q, Rw) 356 | 357 | attn = ( 358 | attn.view(B, q_h, q_w, k_h, k_w) + rel_h[:, :, :, :, None] + rel_w[:, :, :, None, :] 359 | ).view(B, q_h * q_w, k_h * k_w) 360 | 361 | return attn 362 | 363 | 364 | class PatchEmbed(nn.Module): 365 | """ 366 | Image to Patch Embedding. 367 | """ 368 | 369 | def __init__( 370 | self, 371 | kernel_size: Tuple[int, int] = (16, 16), 372 | stride: Tuple[int, int] = (16, 16), 373 | padding: Tuple[int, int] = (0, 0), 374 | in_chans: int = 3, 375 | embed_dim: int = 768, 376 | ) -> None: 377 | """ 378 | Args: 379 | kernel_size (Tuple): kernel size of the projection layer. 380 | stride (Tuple): stride of the projection layer. 381 | padding (Tuple): padding size of the projection layer. 382 | in_chans (int): Number of input image channels. 383 | embed_dim (int): embed_dim (int): Patch embedding dimension. 384 | """ 385 | super().__init__() 386 | 387 | self.proj = nn.Conv2d( 388 | in_chans, embed_dim, kernel_size=kernel_size, stride=stride, padding=padding 389 | ) 390 | 391 | def forward(self, x: torch.Tensor) -> torch.Tensor: 392 | x = self.proj(x) 393 | # B C H W -> B H W C 394 | x = x.permute(0, 2, 3, 1) 395 | return x 396 | -------------------------------------------------------------------------------- /models/sam/modeling/mask_decoder.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import torch 8 | from torch import nn 9 | from torch.nn import functional as F 10 | 11 | from typing import List, Tuple, Type 12 | 13 | from .common import LayerNorm2d 14 | 15 | 16 | class MaskDecoder(nn.Module): 17 | def __init__( 18 | self, 19 | *, 20 | transformer_dim: int, 21 | transformer: nn.Module, 22 | num_multimask_outputs: int = 3, 23 | activation: Type[nn.Module] = nn.GELU, 24 | iou_head_depth: int = 3, 25 | iou_head_hidden_dim: int = 256, 26 | ) -> None: 27 | """ 28 | Predicts masks given an image and prompt embeddings, using a 29 | tranformer architecture. 
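        A typical instantiation is sketched below (illustrative only; the
        256-channel configuration and the TwoWayTransformer from the sibling
        transformer module are assumptions, not part of this docstring):

            decoder = MaskDecoder(
                transformer_dim=256,
                transformer=TwoWayTransformer(
                    depth=2, embedding_dim=256, mlp_dim=2048, num_heads=8
                ),
                num_multimask_outputs=3,
            )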
30 | 31 | Arguments: 32 | transformer_dim (int): the channel dimension of the transformer 33 | transformer (nn.Module): the transformer used to predict masks 34 | num_multimask_outputs (int): the number of masks to predict 35 | when disambiguating masks 36 | activation (nn.Module): the type of activation to use when 37 | upscaling masks 38 | iou_head_depth (int): the depth of the MLP used to predict 39 | mask quality 40 | iou_head_hidden_dim (int): the hidden dimension of the MLP 41 | used to predict mask quality 42 | """ 43 | super().__init__() 44 | self.transformer_dim = transformer_dim 45 | self.transformer = transformer 46 | 47 | self.num_multimask_outputs = num_multimask_outputs 48 | 49 | self.iou_token = nn.Embedding(1, transformer_dim) 50 | self.num_mask_tokens = num_multimask_outputs + 1 51 | self.mask_tokens = nn.Embedding(self.num_mask_tokens, transformer_dim) 52 | 53 | self.output_upscaling = nn.Sequential( 54 | nn.ConvTranspose2d(transformer_dim, transformer_dim // 4, kernel_size=2, stride=2), 55 | LayerNorm2d(transformer_dim // 4), 56 | activation(), 57 | nn.ConvTranspose2d(transformer_dim // 4, transformer_dim // 8, kernel_size=2, stride=2), 58 | activation(), 59 | ) 60 | self.output_hypernetworks_mlps = nn.ModuleList( 61 | [ 62 | MLP(transformer_dim, transformer_dim, transformer_dim // 8, 3) 63 | for i in range(self.num_mask_tokens) 64 | ] 65 | ) 66 | 67 | self.iou_prediction_head = MLP( 68 | transformer_dim, iou_head_hidden_dim, self.num_mask_tokens, iou_head_depth 69 | ) 70 | 71 | def forward( 72 | self, 73 | image_embeddings: torch.Tensor, 74 | image_pe: torch.Tensor, 75 | sparse_prompt_embeddings: torch.Tensor, 76 | dense_prompt_embeddings: torch.Tensor, 77 | multimask_output: bool, 78 | ) -> Tuple[torch.Tensor, torch.Tensor]: 79 | """ 80 | Predict masks given image and prompt embeddings. 81 | 82 | Arguments: 83 | image_embeddings (torch.Tensor): the embeddings from the image encoder 84 | image_pe (torch.Tensor): positional encoding with the shape of image_embeddings 85 | sparse_prompt_embeddings (torch.Tensor): the embeddings of the points and boxes 86 | dense_prompt_embeddings (torch.Tensor): the embeddings of the mask inputs 87 | multimask_output (bool): Whether to return multiple masks or a single 88 | mask. 89 | 90 | Returns: 91 | torch.Tensor: batched predicted masks 92 | torch.Tensor: batched predictions of mask quality 93 | """ 94 | masks, iou_pred = self.predict_masks( 95 | image_embeddings=image_embeddings, 96 | image_pe=image_pe, 97 | sparse_prompt_embeddings=sparse_prompt_embeddings, 98 | dense_prompt_embeddings=dense_prompt_embeddings, 99 | ) 100 | 101 | # Select the correct mask or masks for outptu 102 | if multimask_output: 103 | mask_slice = slice(1, None) 104 | else: 105 | mask_slice = slice(0, 1) 106 | masks = masks[:, mask_slice, :, :] 107 | iou_pred = iou_pred[:, mask_slice] 108 | 109 | # Prepare output 110 | return masks, iou_pred 111 | 112 | def predict_masks( 113 | self, 114 | image_embeddings: torch.Tensor, 115 | image_pe: torch.Tensor, 116 | sparse_prompt_embeddings: torch.Tensor, 117 | dense_prompt_embeddings: torch.Tensor, 118 | ) -> Tuple[torch.Tensor, torch.Tensor]: 119 | """Predicts masks. 
See 'forward' for more details.""" 120 | # Concatenate output tokens 121 | output_tokens = torch.cat([self.iou_token.weight, self.mask_tokens.weight], dim=0) 122 | output_tokens = output_tokens.unsqueeze(0).expand(sparse_prompt_embeddings.size(0), -1, -1) 123 | tokens = torch.cat((output_tokens, sparse_prompt_embeddings), dim=1) 124 | 125 | # Expand per-image data in batch direction to be per-mask 126 | src = torch.repeat_interleave(image_embeddings, tokens.shape[0], dim=0) 127 | src = src + dense_prompt_embeddings 128 | pos_src = torch.repeat_interleave(image_pe, tokens.shape[0], dim=0) 129 | b, c, h, w = src.shape 130 | 131 | # Run the transformer 132 | hs, src = self.transformer(src, pos_src, tokens) 133 | iou_token_out = hs[:, 0, :] 134 | mask_tokens_out = hs[:, 1 : (1 + self.num_mask_tokens), :] 135 | 136 | # Upscale mask embeddings and predict masks using the mask tokens 137 | src = src.transpose(1, 2).view(b, c, h, w) 138 | upscaled_embedding = self.output_upscaling(src) 139 | hyper_in_list: List[torch.Tensor] = [] 140 | for i in range(self.num_mask_tokens): 141 | hyper_in_list.append(self.output_hypernetworks_mlps[i](mask_tokens_out[:, i, :])) 142 | hyper_in = torch.stack(hyper_in_list, dim=1) 143 | b, c, h, w = upscaled_embedding.shape 144 | masks = (hyper_in @ upscaled_embedding.view(b, c, h * w)).view(b, -1, h, w) 145 | 146 | # Generate mask quality predictions 147 | iou_pred = self.iou_prediction_head(iou_token_out) 148 | 149 | return masks, iou_pred 150 | 151 | 152 | # Lightly adapted from 153 | # https://github.com/facebookresearch/MaskFormer/blob/main/mask_former/modeling/transformer/transformer_predictor.py # noqa 154 | class MLP(nn.Module): 155 | def __init__( 156 | self, 157 | input_dim: int, 158 | hidden_dim: int, 159 | output_dim: int, 160 | num_layers: int, 161 | sigmoid_output: bool = False, 162 | ) -> None: 163 | super().__init__() 164 | self.num_layers = num_layers 165 | h = [hidden_dim] * (num_layers - 1) 166 | self.layers = nn.ModuleList( 167 | nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]) 168 | ) 169 | self.sigmoid_output = sigmoid_output 170 | 171 | def forward(self, x): 172 | for i, layer in enumerate(self.layers): 173 | x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x) 174 | if self.sigmoid_output: 175 | x = F.sigmoid(x) 176 | return x 177 | -------------------------------------------------------------------------------- /models/sam/modeling/prompt_encoder.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import numpy as np 8 | import torch 9 | from torch import nn 10 | 11 | from typing import Any, Optional, Tuple, Type 12 | 13 | from .common import LayerNorm2d 14 | 15 | 16 | class PromptEncoder(nn.Module): 17 | def __init__( 18 | self, 19 | embed_dim: int, 20 | image_embedding_size: Tuple[int, int], 21 | input_image_size: Tuple[int, int], 22 | mask_in_chans: int, 23 | activation: Type[nn.Module] = nn.GELU, 24 | ) -> None: 25 | """ 26 | Encodes prompts for input to SAM's mask decoder. 27 | 28 | Arguments: 29 | embed_dim (int): The prompts' embedding dimension 30 | image_embedding_size (tuple(int, int)): The spatial size of the 31 | image embedding, as (H, W). 32 | input_image_size (int): The padded size of the image as input 33 | to the image encoder, as (H, W). 
34 | mask_in_chans (int): The number of hidden channels used for 35 | encoding input masks. 36 | activation (nn.Module): The activation to use when encoding 37 | input masks. 38 | """ 39 | super().__init__() 40 | self.embed_dim = embed_dim 41 | self.input_image_size = input_image_size 42 | self.image_embedding_size = image_embedding_size 43 | self.pe_layer = PositionEmbeddingRandom(embed_dim // 2) 44 | 45 | self.num_point_embeddings: int = 4 # pos/neg point + 2 box corners 46 | point_embeddings = [nn.Embedding(1, embed_dim) for i in range(self.num_point_embeddings)] 47 | self.point_embeddings = nn.ModuleList(point_embeddings) 48 | self.not_a_point_embed = nn.Embedding(1, embed_dim) 49 | 50 | self.mask_input_size = (4 * image_embedding_size[0], 4 * image_embedding_size[1]) 51 | self.mask_downscaling = nn.Sequential( 52 | nn.Conv2d(1, mask_in_chans // 4, kernel_size=2, stride=2), 53 | LayerNorm2d(mask_in_chans // 4), 54 | activation(), 55 | nn.Conv2d(mask_in_chans // 4, mask_in_chans, kernel_size=2, stride=2), 56 | LayerNorm2d(mask_in_chans), 57 | activation(), 58 | nn.Conv2d(mask_in_chans, embed_dim, kernel_size=1), 59 | ) 60 | self.no_mask_embed = nn.Embedding(1, embed_dim) 61 | 62 | def get_dense_pe(self) -> torch.Tensor: 63 | """ 64 | Returns the positional encoding used to encode point prompts, 65 | applied to a dense set of points the shape of the image encoding. 66 | 67 | Returns: 68 | torch.Tensor: Positional encoding with shape 69 | 1x(embed_dim)x(embedding_h)x(embedding_w) 70 | """ 71 | return self.pe_layer(self.image_embedding_size).unsqueeze(0) 72 | 73 | def _embed_points( 74 | self, 75 | points: torch.Tensor, 76 | labels: torch.Tensor, 77 | pad: bool, 78 | ) -> torch.Tensor: 79 | """Embeds point prompts.""" 80 | points = points + 0.5 # Shift to center of pixel 81 | if pad: 82 | padding_point = torch.zeros((points.shape[0], 1, 2), device=points.device) 83 | padding_label = -torch.ones((labels.shape[0], 1), device=labels.device) 84 | points = torch.cat([points, padding_point], dim=1) 85 | labels = torch.cat([labels, padding_label], dim=1) 86 | point_embedding = self.pe_layer.forward_with_coords(points, self.input_image_size) 87 | point_embedding[labels == -1] = 0.0 88 | point_embedding[labels == -1] += self.not_a_point_embed.weight 89 | point_embedding[labels == 0] += self.point_embeddings[0].weight 90 | point_embedding[labels == 1] += self.point_embeddings[1].weight 91 | return point_embedding 92 | 93 | def _embed_boxes(self, boxes: torch.Tensor) -> torch.Tensor: 94 | """Embeds box prompts.""" 95 | boxes = boxes + 0.5 # Shift to center of pixel 96 | coords = boxes.reshape(-1, 2, 2) 97 | corner_embedding = self.pe_layer.forward_with_coords(coords, self.input_image_size) 98 | corner_embedding[:, 0, :] += self.point_embeddings[2].weight 99 | corner_embedding[:, 1, :] += self.point_embeddings[3].weight 100 | return corner_embedding 101 | 102 | def _embed_masks(self, masks: torch.Tensor) -> torch.Tensor: 103 | """Embeds mask inputs.""" 104 | mask_embedding = self.mask_downscaling(masks) 105 | return mask_embedding 106 | 107 | def _get_batch_size( 108 | self, 109 | points: Optional[Tuple[torch.Tensor, torch.Tensor]], 110 | boxes: Optional[torch.Tensor], 111 | masks: Optional[torch.Tensor], 112 | ) -> int: 113 | """ 114 | Gets the batch size of the output given the batch size of the input prompts. 
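        Points take precedence over boxes, which take precedence over masks;
        if no prompt is supplied at all, a batch size of 1 is assumed. For
        example, two box prompts with no points or masks give a batch size of 2.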
115 | """ 116 | if points is not None: 117 | return points[0].shape[0] 118 | elif boxes is not None: 119 | return boxes.shape[0] 120 | elif masks is not None: 121 | return masks.shape[0] 122 | else: 123 | return 1 124 | 125 | def _get_device(self) -> torch.device: 126 | return self.point_embeddings[0].weight.device 127 | 128 | def forward( 129 | self, 130 | points: Optional[Tuple[torch.Tensor, torch.Tensor]], 131 | boxes: Optional[torch.Tensor], 132 | masks: Optional[torch.Tensor], 133 | ) -> Tuple[torch.Tensor, torch.Tensor]: 134 | """ 135 | Embeds different types of prompts, returning both sparse and dense 136 | embeddings. 137 | 138 | Arguments: 139 | points (tuple(torch.Tensor, torch.Tensor) or none): point coordinates 140 | and labels to embed. 141 | boxes (torch.Tensor or none): boxes to embed 142 | masks (torch.Tensor or none): masks to embed 143 | 144 | Returns: 145 | torch.Tensor: sparse embeddings for the points and boxes, with shape 146 | BxNx(embed_dim), where N is determined by the number of input points 147 | and boxes. 148 | torch.Tensor: dense embeddings for the masks, in the shape 149 | Bx(embed_dim)x(embed_H)x(embed_W) 150 | """ 151 | bs = self._get_batch_size(points, boxes, masks) 152 | sparse_embeddings = torch.empty((bs, 0, self.embed_dim), device=self._get_device()) 153 | if points is not None: 154 | coords, labels = points 155 | point_embeddings = self._embed_points(coords, labels, pad=(boxes is None)) 156 | sparse_embeddings = torch.cat([sparse_embeddings, point_embeddings], dim=1) 157 | if boxes is not None: 158 | box_embeddings = self._embed_boxes(boxes) 159 | sparse_embeddings = torch.cat([sparse_embeddings, box_embeddings], dim=1) 160 | 161 | if masks is not None: 162 | dense_embeddings = self._embed_masks(masks) 163 | else: 164 | dense_embeddings = self.no_mask_embed.weight.reshape(1, -1, 1, 1).expand( 165 | bs, -1, self.image_embedding_size[0], self.image_embedding_size[1] 166 | ) 167 | 168 | return sparse_embeddings, dense_embeddings 169 | 170 | 171 | class PositionEmbeddingRandom(nn.Module): 172 | """ 173 | Positional encoding using random spatial frequencies. 174 | """ 175 | 176 | def __init__(self, num_pos_feats: int = 64, scale: Optional[float] = None) -> None: 177 | super().__init__() 178 | if scale is None or scale <= 0.0: 179 | scale = 1.0 180 | self.register_buffer( 181 | "positional_encoding_gaussian_matrix", 182 | scale * torch.randn((2, num_pos_feats)), 183 | ) 184 | 185 | def _pe_encoding(self, coords: torch.Tensor) -> torch.Tensor: 186 | """Positionally encode points that are normalized to [0,1].""" 187 | # assuming coords are in [0, 1]^2 square and have d_1 x ... x d_n x 2 shape 188 | coords = 2 * coords - 1 189 | coords = coords @ self.positional_encoding_gaussian_matrix 190 | coords = 2 * np.pi * coords 191 | # outputs d_1 x ... 
x d_n x C shape 192 | return torch.cat([torch.sin(coords), torch.cos(coords)], dim=-1) 193 | 194 | def forward(self, size: Tuple[int, int]) -> torch.Tensor: 195 | """Generate positional encoding for a grid of the specified size.""" 196 | h, w = size 197 | device: Any = self.positional_encoding_gaussian_matrix.device 198 | grid = torch.ones((h, w), device=device, dtype=torch.float32) 199 | y_embed = grid.cumsum(dim=0) - 0.5 200 | x_embed = grid.cumsum(dim=1) - 0.5 201 | y_embed = y_embed / h 202 | x_embed = x_embed / w 203 | 204 | pe = self._pe_encoding(torch.stack([x_embed, y_embed], dim=-1)) 205 | return pe.permute(2, 0, 1) # C x H x W 206 | 207 | def forward_with_coords( 208 | self, coords_input: torch.Tensor, image_size: Tuple[int, int] 209 | ) -> torch.Tensor: 210 | """Positionally encode points that are not normalized to [0,1].""" 211 | coords = coords_input.clone() 212 | coords[:, :, 0] = coords[:, :, 0] / image_size[1] 213 | coords[:, :, 1] = coords[:, :, 1] / image_size[0] 214 | return self._pe_encoding(coords.to(torch.float)) # B x N x C 215 | -------------------------------------------------------------------------------- /models/sam/modeling/sam.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import torch 8 | from torch import nn 9 | from torch.nn import functional as F 10 | 11 | from typing import Any, Dict, List, Tuple, Union 12 | 13 | from .tiny_vit_sam import TinyViT 14 | from .image_encoder import ImageEncoderViT 15 | from .mask_decoder import MaskDecoder 16 | from .prompt_encoder import PromptEncoder 17 | 18 | 19 | class Sam(nn.Module): 20 | mask_threshold: float = 0.0 21 | image_format: str = "RGB" 22 | 23 | def __init__( 24 | self, 25 | image_encoder: Union[ImageEncoderViT, TinyViT], 26 | prompt_encoder: PromptEncoder, 27 | mask_decoder: MaskDecoder, 28 | pixel_mean: List[float] = [123.675, 116.28, 103.53], 29 | pixel_std: List[float] = [58.395, 57.12, 57.375], 30 | ) -> None: 31 | """ 32 | SAM predicts object masks from an image and input prompts. 33 | 34 | Arguments: 35 | image_encoder (ImageEncoderViT): The backbone used to encode the 36 | image into image embeddings that allow for efficient mask prediction. 37 | prompt_encoder (PromptEncoder): Encodes various types of input prompts. 38 | mask_decoder (MaskDecoder): Predicts masks from the image embeddings 39 | and encoded prompts. 40 | pixel_mean (list(float)): Mean values for normalizing pixels in the input image. 41 | pixel_std (list(float)): Std values for normalizing pixels in the input image. 42 | """ 43 | super().__init__() 44 | self.image_encoder = image_encoder 45 | self.prompt_encoder = prompt_encoder 46 | self.mask_decoder = mask_decoder 47 | self.register_buffer("pixel_mean", torch.Tensor(pixel_mean).view(-1, 1, 1), False) 48 | self.register_buffer("pixel_std", torch.Tensor(pixel_std).view(-1, 1, 1), False) 49 | 50 | @property 51 | def device(self) -> Any: 52 | return self.pixel_mean.device 53 | 54 | @torch.no_grad() 55 | def forward( 56 | self, 57 | batched_input: List[Dict[str, Any]], 58 | multimask_output: bool, 59 | ) -> List[Dict[str, torch.Tensor]]: 60 | """ 61 | Predicts masks end-to-end from provided images and prompts. 62 | If prompts are not known in advance, using SamPredictor is 63 | recommended over calling the model directly. 
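        A minimal direct call is sketched below (illustrative only, with
        hypothetical names: `sam` is a built Sam instance, `image` is a 3xHxW
        tensor already resized so its long side matches the image encoder
        input size, and the point tensors follow the shapes described under
        'batched_input'):

            outputs = sam(
                batched_input=[{
                    "image": image,
                    "original_size": (orig_h, orig_w),
                    "point_coords": point_coords,  # BxNx2, in the resized input frame
                    "point_labels": point_labels,  # BxN
                }],
                multimask_output=True,
            )
            masks = outputs[0]["masks"]  # boolean masks at the original image size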
64 | 65 | Arguments: 66 | batched_input (list(dict)): A list over input images, each a 67 | dictionary with the following keys. A prompt key can be 68 | excluded if it is not present. 69 | 'image': The image as a torch tensor in 3xHxW format, 70 | already transformed for input to the model. 71 | 'original_size': (tuple(int, int)) The original size of 72 | the image before transformation, as (H, W). 73 | 'point_coords': (torch.Tensor) Batched point prompts for 74 | this image, with shape BxNx2. Already transformed to the 75 | input frame of the model. 76 | 'point_labels': (torch.Tensor) Batched labels for point prompts, 77 | with shape BxN. 78 | 'boxes': (torch.Tensor) Batched box inputs, with shape Bx4. 79 | Already transformed to the input frame of the model. 80 | 'mask_inputs': (torch.Tensor) Batched mask inputs to the model, 81 | in the form Bx1xHxW. 82 | multimask_output (bool): Whether the model should predict multiple 83 | disambiguating masks, or return a single mask. 84 | 85 | Returns: 86 | (list(dict)): A list over input images, where each element is 87 | as dictionary with the following keys. 88 | 'masks': (torch.Tensor) Batched binary mask predictions, 89 | with shape BxCxHxW, where B is the number of input promts, 90 | C is determiend by multimask_output, and (H, W) is the 91 | original size of the image. 92 | 'iou_predictions': (torch.Tensor) The model's predictions 93 | of mask quality, in shape BxC. 94 | 'low_res_logits': (torch.Tensor) Low resolution logits with 95 | shape BxCxHxW, where H=W=256. Can be passed as mask input 96 | to subsequent iterations of prediction. 97 | """ 98 | input_images = torch.stack([self.preprocess(x["image"]) for x in batched_input], dim=0) 99 | image_embeddings = self.image_encoder(input_images) 100 | 101 | outputs = [] 102 | for image_record, curr_embedding in zip(batched_input, image_embeddings): 103 | if "point_coords" in image_record: 104 | points = (image_record["point_coords"], image_record["point_labels"]) 105 | else: 106 | points = None 107 | sparse_embeddings, dense_embeddings = self.prompt_encoder( 108 | points=points, 109 | boxes=image_record.get("boxes", None), 110 | masks=image_record.get("mask_inputs", None), 111 | ) 112 | low_res_masks, iou_predictions = self.mask_decoder( 113 | image_embeddings=curr_embedding.unsqueeze(0), 114 | image_pe=self.prompt_encoder.get_dense_pe(), 115 | sparse_prompt_embeddings=sparse_embeddings, 116 | dense_prompt_embeddings=dense_embeddings, 117 | multimask_output=multimask_output, 118 | ) 119 | masks = self.postprocess_masks( 120 | low_res_masks, 121 | input_size=image_record["image"].shape[-2:], 122 | original_size=image_record["original_size"], 123 | ) 124 | masks = masks > self.mask_threshold 125 | outputs.append( 126 | { 127 | "masks": masks, 128 | "iou_predictions": iou_predictions, 129 | "low_res_logits": low_res_masks, 130 | } 131 | ) 132 | return outputs 133 | 134 | def postprocess_masks( 135 | self, 136 | masks: torch.Tensor, 137 | input_size: Tuple[int, ...], 138 | original_size: Tuple[int, ...], 139 | ) -> torch.Tensor: 140 | """ 141 | Remove padding and upscale masks to the original image size. 142 | 143 | Arguments: 144 | masks (torch.Tensor): Batched masks from the mask_decoder, 145 | in BxCxHxW format. 146 | input_size (tuple(int, int)): The size of the image input to the 147 | model, in (H, W) format. Used to remove padding. 148 | original_size (tuple(int, int)): The original size of the image 149 | before resizing for input to the model, in (H, W) format. 
150 | 151 | Returns: 152 | (torch.Tensor): Batched masks in BxCxHxW format, where (H, W) 153 | is given by original_size. 154 | """ 155 | masks = F.interpolate( 156 | masks, 157 | (self.image_encoder.img_size, self.image_encoder.img_size), 158 | mode="bilinear", 159 | align_corners=False, 160 | ) 161 | masks = masks[..., : input_size[0], : input_size[1]] 162 | masks = F.interpolate(masks, original_size, mode="bilinear", align_corners=False) 163 | return masks 164 | 165 | def preprocess(self, x: torch.Tensor) -> torch.Tensor: 166 | """Normalize pixel values and pad to a square input.""" 167 | # Normalize colors 168 | x = (x - self.pixel_mean) / self.pixel_std 169 | 170 | # Pad 171 | h, w = x.shape[-2:] 172 | padh = self.image_encoder.img_size - h 173 | padw = self.image_encoder.img_size - w 174 | x = F.pad(x, (0, padw, 0, padh)) 175 | return x 176 | -------------------------------------------------------------------------------- /models/sam/modeling/transformer.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import torch 8 | from torch import Tensor, nn 9 | 10 | import math 11 | from typing import Tuple, Type 12 | 13 | from .common import MLPBlock 14 | 15 | 16 | class TwoWayTransformer(nn.Module): 17 | def __init__( 18 | self, 19 | depth: int, 20 | embedding_dim: int, 21 | num_heads: int, 22 | mlp_dim: int, 23 | activation: Type[nn.Module] = nn.ReLU, 24 | attention_downsample_rate: int = 2, 25 | ) -> None: 26 | """ 27 | A transformer decoder that attends to an input image using 28 | queries whose positional embedding is supplied. 29 | 30 | Args: 31 | depth (int): number of layers in the transformer 32 | embedding_dim (int): the channel dimension for the input embeddings 33 | num_heads (int): the number of heads for multihead attention. Must 34 | divide embedding_dim 35 | mlp_dim (int): the channel dimension internal to the MLP block 36 | activation (nn.Module): the activation to use in the MLP block 37 | """ 38 | super().__init__() 39 | self.depth = depth 40 | self.embedding_dim = embedding_dim 41 | self.num_heads = num_heads 42 | self.mlp_dim = mlp_dim 43 | self.layers = nn.ModuleList() 44 | 45 | for i in range(depth): 46 | self.layers.append( 47 | TwoWayAttentionBlock( 48 | embedding_dim=embedding_dim, 49 | num_heads=num_heads, 50 | mlp_dim=mlp_dim, 51 | activation=activation, 52 | attention_downsample_rate=attention_downsample_rate, 53 | skip_first_layer_pe=(i == 0), 54 | ) 55 | ) 56 | 57 | self.final_attn_token_to_image = Attention( 58 | embedding_dim, num_heads, downsample_rate=attention_downsample_rate 59 | ) 60 | self.norm_final_attn = nn.LayerNorm(embedding_dim) 61 | 62 | def forward( 63 | self, 64 | image_embedding: Tensor, 65 | image_pe: Tensor, 66 | point_embedding: Tensor, 67 | ) -> Tuple[Tensor, Tensor]: 68 | """ 69 | Args: 70 | image_embedding (torch.Tensor): image to attend to. Should be shape 71 | B x embedding_dim x h x w for any h and w. 72 | image_pe (torch.Tensor): the positional encoding to add to the image. Must 73 | have the same shape as image_embedding. 74 | point_embedding (torch.Tensor): the embedding to add to the query points. 75 | Must have shape B x N_points x embedding_dim for any N_points. 
76 | 77 | Returns: 78 | torch.Tensor: the processed point_embedding 79 | torch.Tensor: the processed image_embedding 80 | """ 81 | # BxCxHxW -> BxHWxC == B x N_image_tokens x C 82 | bs, c, h, w = image_embedding.shape 83 | image_embedding = image_embedding.flatten(2).permute(0, 2, 1) 84 | image_pe = image_pe.flatten(2).permute(0, 2, 1) 85 | 86 | # Prepare queries 87 | queries = point_embedding 88 | keys = image_embedding 89 | 90 | # Apply transformer blocks and final layernorm 91 | for layer in self.layers: 92 | queries, keys = layer( 93 | queries=queries, 94 | keys=keys, 95 | query_pe=point_embedding, 96 | key_pe=image_pe, 97 | ) 98 | 99 | # Apply the final attenion layer from the points to the image 100 | q = queries + point_embedding 101 | k = keys + image_pe 102 | attn_out = self.final_attn_token_to_image(q=q, k=k, v=keys) 103 | queries = queries + attn_out 104 | queries = self.norm_final_attn(queries) 105 | 106 | return queries, keys 107 | 108 | 109 | class TwoWayAttentionBlock(nn.Module): 110 | def __init__( 111 | self, 112 | embedding_dim: int, 113 | num_heads: int, 114 | mlp_dim: int = 2048, 115 | activation: Type[nn.Module] = nn.ReLU, 116 | attention_downsample_rate: int = 2, 117 | skip_first_layer_pe: bool = False, 118 | ) -> None: 119 | """ 120 | A transformer block with four layers: (1) self-attention of sparse 121 | inputs, (2) cross attention of sparse inputs to dense inputs, (3) mlp 122 | block on sparse inputs, and (4) cross attention of dense inputs to sparse 123 | inputs. 124 | 125 | Arguments: 126 | embedding_dim (int): the channel dimension of the embeddings 127 | num_heads (int): the number of heads in the attention layers 128 | mlp_dim (int): the hidden dimension of the mlp block 129 | activation (nn.Module): the activation of the mlp block 130 | skip_first_layer_pe (bool): skip the PE on the first layer 131 | """ 132 | super().__init__() 133 | self.self_attn = Attention(embedding_dim, num_heads) 134 | self.norm1 = nn.LayerNorm(embedding_dim) 135 | 136 | self.cross_attn_token_to_image = Attention( 137 | embedding_dim, num_heads, downsample_rate=attention_downsample_rate 138 | ) 139 | self.norm2 = nn.LayerNorm(embedding_dim) 140 | 141 | self.mlp = MLPBlock(embedding_dim, mlp_dim, activation) 142 | self.norm3 = nn.LayerNorm(embedding_dim) 143 | 144 | self.norm4 = nn.LayerNorm(embedding_dim) 145 | self.cross_attn_image_to_token = Attention( 146 | embedding_dim, num_heads, downsample_rate=attention_downsample_rate 147 | ) 148 | 149 | self.skip_first_layer_pe = skip_first_layer_pe 150 | 151 | def forward( 152 | self, queries: Tensor, keys: Tensor, query_pe: Tensor, key_pe: Tensor 153 | ) -> Tuple[Tensor, Tensor]: 154 | # Self attention block 155 | if self.skip_first_layer_pe: 156 | queries = self.self_attn(q=queries, k=queries, v=queries) 157 | else: 158 | q = queries + query_pe 159 | attn_out = self.self_attn(q=q, k=q, v=queries) 160 | queries = queries + attn_out 161 | queries = self.norm1(queries) 162 | 163 | # Cross attention block, tokens attending to image embedding 164 | q = queries + query_pe 165 | k = keys + key_pe 166 | attn_out = self.cross_attn_token_to_image(q=q, k=k, v=keys) 167 | queries = queries + attn_out 168 | queries = self.norm2(queries) 169 | 170 | # MLP block 171 | mlp_out = self.mlp(queries) 172 | queries = queries + mlp_out 173 | queries = self.norm3(queries) 174 | 175 | # Cross attention block, image embedding attending to tokens 176 | q = queries + query_pe 177 | k = keys + key_pe 178 | attn_out = self.cross_attn_image_to_token(q=k, k=q, 
v=queries) 179 | keys = keys + attn_out 180 | keys = self.norm4(keys) 181 | 182 | return queries, keys 183 | 184 | 185 | class Attention(nn.Module): 186 | """ 187 | An attention layer that allows for downscaling the size of the embedding 188 | after projection to queries, keys, and values. 189 | """ 190 | 191 | def __init__( 192 | self, 193 | embedding_dim: int, 194 | num_heads: int, 195 | downsample_rate: int = 1, 196 | ) -> None: 197 | super().__init__() 198 | self.embedding_dim = embedding_dim 199 | self.internal_dim = embedding_dim // downsample_rate 200 | self.num_heads = num_heads 201 | assert self.internal_dim % num_heads == 0, "num_heads must divide embedding_dim." 202 | 203 | self.q_proj = nn.Linear(embedding_dim, self.internal_dim) 204 | self.k_proj = nn.Linear(embedding_dim, self.internal_dim) 205 | self.v_proj = nn.Linear(embedding_dim, self.internal_dim) 206 | self.out_proj = nn.Linear(self.internal_dim, embedding_dim) 207 | 208 | def _separate_heads(self, x: Tensor, num_heads: int) -> Tensor: 209 | b, n, c = x.shape 210 | x = x.reshape(b, n, num_heads, c // num_heads) 211 | return x.transpose(1, 2) # B x N_heads x N_tokens x C_per_head 212 | 213 | def _recombine_heads(self, x: Tensor) -> Tensor: 214 | b, n_heads, n_tokens, c_per_head = x.shape 215 | x = x.transpose(1, 2) 216 | return x.reshape(b, n_tokens, n_heads * c_per_head) # B x N_tokens x C 217 | 218 | def forward(self, q: Tensor, k: Tensor, v: Tensor) -> Tensor: 219 | # Input projections 220 | q = self.q_proj(q) 221 | k = self.k_proj(k) 222 | v = self.v_proj(v) 223 | 224 | # Separate into heads 225 | q = self._separate_heads(q, self.num_heads) 226 | k = self._separate_heads(k, self.num_heads) 227 | v = self._separate_heads(v, self.num_heads) 228 | 229 | # Attention 230 | _, _, _, c_per_head = q.shape 231 | attn = q @ k.permute(0, 1, 3, 2) # B x N_heads x N_tokens x N_tokens 232 | attn = attn / math.sqrt(c_per_head) 233 | attn = torch.softmax(attn, dim=-1) 234 | 235 | # Get output 236 | out = attn @ v 237 | out = self._recombine_heads(out) 238 | out = self.out_proj(out) 239 | 240 | return out 241 | -------------------------------------------------------------------------------- /models/sam/predictor.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import numpy as np 8 | import torch 9 | 10 | from models.sam.modeling import Sam 11 | 12 | from typing import Optional, Tuple 13 | 14 | from .utils.transforms import ResizeLongestSide 15 | 16 | 17 | class SamPredictor: 18 | def __init__( 19 | self, 20 | sam_model: Sam, 21 | ) -> None: 22 | """ 23 | Uses SAM to calculate the image embedding for an image, and then 24 | allow repeated, efficient mask prediction given prompts. 25 | 26 | Arguments: 27 | sam_model (Sam): The model to use for mask prediction. 28 | """ 29 | super().__init__() 30 | self.model = sam_model 31 | self.transform = ResizeLongestSide(sam_model.image_encoder.img_size) 32 | self.reset_image() 33 | 34 | def set_image( 35 | self, 36 | image: np.ndarray, 37 | image_format: str = "RGB", 38 | ) -> None: 39 | """ 40 | Calculates the image embeddings for the provided image, allowing 41 | masks to be predicted with the 'predict' method. 42 | 43 | Arguments: 44 | image (np.ndarray): The image for calculating masks. 
Expects an 45 | image in HWC uint8 format, with pixel values in [0, 255]. 46 | image_format (str): The color format of the image, in ['RGB', 'BGR']. 47 | """ 48 | assert image_format in [ 49 | "RGB", 50 | "BGR", 51 | ], f"image_format must be in ['RGB', 'BGR'], is {image_format}." 52 | if image_format != self.model.image_format: 53 | image = image[..., ::-1] 54 | 55 | # Transform the image to the form expected by the model 56 | input_image = self.transform.apply_image(image) 57 | input_image_torch = torch.as_tensor(input_image, device=self.device) 58 | input_image_torch = input_image_torch.permute(2, 0, 1).contiguous()[None, :, :, :] 59 | 60 | self.set_torch_image(input_image_torch, image.shape[:2]) 61 | 62 | @torch.no_grad() 63 | def set_torch_image( 64 | self, 65 | transformed_image: torch.Tensor, 66 | original_image_size: Tuple[int, ...], 67 | ) -> None: 68 | """ 69 | Calculates the image embeddings for the provided image, allowing 70 | masks to be predicted with the 'predict' method. Expects the input 71 | image to be already transformed to the format expected by the model. 72 | 73 | Arguments: 74 | transformed_image (torch.Tensor): The input image, with shape 75 | 1x3xHxW, which has been transformed with ResizeLongestSide. 76 | original_image_size (tuple(int, int)): The size of the image 77 | before transformation, in (H, W) format. 78 | """ 79 | assert ( 80 | len(transformed_image.shape) == 4 81 | and transformed_image.shape[1] == 3 82 | and max(*transformed_image.shape[2:]) == self.model.image_encoder.img_size 83 | ), f"set_torch_image input must be BCHW with long side {self.model.image_encoder.img_size}." 84 | self.reset_image() 85 | 86 | self.original_size = original_image_size 87 | self.input_size = tuple(transformed_image.shape[-2:]) 88 | input_image = self.model.preprocess(transformed_image) 89 | self.features = self.model.image_encoder(input_image) 90 | self.is_image_set = True 91 | 92 | def predict( 93 | self, 94 | point_coords: Optional[np.ndarray] = None, 95 | point_labels: Optional[np.ndarray] = None, 96 | box: Optional[np.ndarray] = None, 97 | mask_input: Optional[np.ndarray] = None, 98 | multimask_output: bool = True, 99 | return_logits: bool = False, 100 | ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: 101 | """ 102 | Predict masks for the given input prompts, using the currently set image. 103 | 104 | Arguments: 105 | point_coords (np.ndarray or None): A Nx2 array of point prompts to the 106 | model. Each point is in (X,Y) in pixels. 107 | point_labels (np.ndarray or None): A length N array of labels for the 108 | point prompts. 1 indicates a foreground point and 0 indicates a 109 | background point. 110 | box (np.ndarray or None): A length 4 array given a box prompt to the 111 | model, in XYXY format. 112 | mask_input (np.ndarray): A low resolution mask input to the model, typically 113 | coming from a previous prediction iteration. Has form 1xHxW, where 114 | for SAM, H=W=256. 115 | multimask_output (bool): If true, the model will return three masks. 116 | For ambiguous input prompts (such as a single click), this will often 117 | produce better masks than a single prediction. If only a single 118 | mask is needed, the model's predicted quality score can be used 119 | to select the best mask. For non-ambiguous prompts, such as multiple 120 | input prompts, multimask_output=False can give better results. 121 | return_logits (bool): If true, returns un-thresholded masks logits 122 | instead of a binary mask. 
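        Example (an illustrative sketch; `predictor` wraps an already-loaded
        SAM model, `image` is an HWC uint8 RGB array, and the click
        coordinates are hypothetical):

            predictor.set_image(image)
            masks, scores, low_res_logits = predictor.predict(
                point_coords=np.array([[500, 375]]),  # one foreground click, (X, Y) in pixels
                point_labels=np.array([1]),
                multimask_output=True,
            )
            best_mask = masks[np.argmax(scores)]  # highest predicted-quality mask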
123 | 124 | Returns: 125 | (np.ndarray): The output masks in CxHxW format, where C is the 126 | number of masks, and (H, W) is the original image size. 127 | (np.ndarray): An array of length C containing the model's 128 | predictions for the quality of each mask. 129 | (np.ndarray): An array of shape CxHxW, where C is the number 130 | of masks and H=W=256. These low resolution logits can be passed to 131 | a subsequent iteration as mask input. 132 | """ 133 | if not self.is_image_set: 134 | raise RuntimeError("An image must be set with .set_image(...) before mask prediction.") 135 | 136 | # Transform input prompts 137 | coords_torch, labels_torch, box_torch, mask_input_torch = None, None, None, None 138 | if point_coords is not None: 139 | assert ( 140 | point_labels is not None 141 | ), "point_labels must be supplied if point_coords is supplied." 142 | point_coords = self.transform.apply_coords(point_coords, self.original_size) 143 | coords_torch = torch.as_tensor(point_coords, dtype=torch.float, device=self.device) 144 | labels_torch = torch.as_tensor(point_labels, dtype=torch.int, device=self.device) 145 | coords_torch, labels_torch = coords_torch[None, :, :], labels_torch[None, :] 146 | if box is not None: 147 | box = self.transform.apply_boxes(box, self.original_size) 148 | box_torch = torch.as_tensor(box, dtype=torch.float, device=self.device) 149 | box_torch = box_torch[None, :] 150 | if mask_input is not None: 151 | mask_input_torch = torch.as_tensor(mask_input, dtype=torch.float, device=self.device) 152 | mask_input_torch = mask_input_torch[None, :, :, :] 153 | 154 | masks, iou_predictions, low_res_masks = self.predict_torch( 155 | coords_torch, 156 | labels_torch, 157 | box_torch, 158 | mask_input_torch, 159 | multimask_output, 160 | return_logits=return_logits, 161 | ) 162 | 163 | masks = masks[0].detach().cpu().numpy() 164 | iou_predictions = iou_predictions[0].detach().cpu().numpy() 165 | low_res_masks = low_res_masks[0].detach().cpu().numpy() 166 | return masks, iou_predictions, low_res_masks 167 | 168 | @torch.no_grad() 169 | def predict_torch( 170 | self, 171 | point_coords: Optional[torch.Tensor], 172 | point_labels: Optional[torch.Tensor], 173 | boxes: Optional[torch.Tensor] = None, 174 | mask_input: Optional[torch.Tensor] = None, 175 | multimask_output: bool = True, 176 | return_logits: bool = False, 177 | ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: 178 | """ 179 | Predict masks for the given input prompts, using the currently set image. 180 | Input prompts are batched torch tensors and are expected to already be 181 | transformed to the input frame using ResizeLongestSide. 182 | 183 | Arguments: 184 | point_coords (torch.Tensor or None): A BxNx2 array of point prompts to the 185 | model. Each point is in (X,Y) in pixels. 186 | point_labels (torch.Tensor or None): A BxN array of labels for the 187 | point prompts. 1 indicates a foreground point and 0 indicates a 188 | background point. 189 | box (np.ndarray or None): A Bx4 array given a box prompt to the 190 | model, in XYXY format. 191 | mask_input (np.ndarray): A low resolution mask input to the model, typically 192 | coming from a previous prediction iteration. Has form Bx1xHxW, where 193 | for SAM, H=W=256. Masks returned by a previous iteration of the 194 | predict method do not need further transformation. 195 | multimask_output (bool): If true, the model will return three masks. 
196 | For ambiguous input prompts (such as a single click), this will often 197 | produce better masks than a single prediction. If only a single 198 | mask is needed, the model's predicted quality score can be used 199 | to select the best mask. For non-ambiguous prompts, such as multiple 200 | input prompts, multimask_output=False can give better results. 201 | return_logits (bool): If true, returns un-thresholded masks logits 202 | instead of a binary mask. 203 | 204 | Returns: 205 | (torch.Tensor): The output masks in BxCxHxW format, where C is the 206 | number of masks, and (H, W) is the original image size. 207 | (torch.Tensor): An array of shape BxC containing the model's 208 | predictions for the quality of each mask. 209 | (torch.Tensor): An array of shape BxCxHxW, where C is the number 210 | of masks and H=W=256. These low res logits can be passed to 211 | a subsequent iteration as mask input. 212 | """ 213 | if not self.is_image_set: 214 | raise RuntimeError("An image must be set with .set_image(...) before mask prediction.") 215 | 216 | if point_coords is not None: 217 | points = (point_coords, point_labels) 218 | else: 219 | points = None 220 | 221 | # Embed prompts 222 | sparse_embeddings, dense_embeddings = self.model.prompt_encoder( 223 | points=points, 224 | boxes=boxes, 225 | masks=mask_input, 226 | ) 227 | 228 | # Predict masks 229 | low_res_masks, iou_predictions = self.model.mask_decoder( 230 | image_embeddings=self.features, 231 | image_pe=self.model.prompt_encoder.get_dense_pe(), 232 | sparse_prompt_embeddings=sparse_embeddings, 233 | dense_prompt_embeddings=dense_embeddings, 234 | multimask_output=multimask_output, 235 | ) 236 | 237 | # Upscale the masks to the original image resolution 238 | masks = self.model.postprocess_masks(low_res_masks, self.input_size, self.original_size) 239 | 240 | if not return_logits: 241 | masks = masks > self.model.mask_threshold 242 | 243 | return masks, iou_predictions, low_res_masks 244 | 245 | def get_image_embedding(self) -> torch.Tensor: 246 | """ 247 | Returns the image embeddings for the currently set image, with 248 | shape 1xCxHxW, where C is the embedding dimension and (H,W) are 249 | the embedding spatial dimension of SAM (typically C=256, H=W=64). 250 | """ 251 | if not self.is_image_set: 252 | raise RuntimeError( 253 | "An image must be set with .set_image(...) to generate an embedding." 254 | ) 255 | assert self.features is not None, "Features must exist if an image has been set." 256 | return self.features 257 | 258 | @property 259 | def device(self) -> torch.device: 260 | return self.model.device 261 | 262 | def reset_image(self) -> None: 263 | """Resets the currently set image.""" 264 | self.is_image_set = False 265 | self.features = None 266 | self.orig_h = None 267 | self.orig_w = None 268 | self.input_h = None 269 | self.input_w = None 270 | -------------------------------------------------------------------------------- /models/sam/utils/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | -------------------------------------------------------------------------------- /models/sam/utils/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MaybeShewill-CV/segment-anything-u-specify/606a6bc744fd7cc37824a32a95ec770580467428/models/sam/utils/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /models/sam/utils/__pycache__/amg.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MaybeShewill-CV/segment-anything-u-specify/606a6bc744fd7cc37824a32a95ec770580467428/models/sam/utils/__pycache__/amg.cpython-38.pyc -------------------------------------------------------------------------------- /models/sam/utils/__pycache__/transforms.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MaybeShewill-CV/segment-anything-u-specify/606a6bc744fd7cc37824a32a95ec770580467428/models/sam/utils/__pycache__/transforms.cpython-38.pyc -------------------------------------------------------------------------------- /models/sam/utils/amg.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import numpy as np 8 | import torch 9 | 10 | import math 11 | from copy import deepcopy 12 | from itertools import product 13 | from typing import Any, Dict, Generator, ItemsView, List, Tuple 14 | 15 | 16 | class MaskData: 17 | """ 18 | A structure for storing masks and their related data in batched format. 19 | Implements basic filtering and concatenation. 20 | """ 21 | 22 | def __init__(self, **kwargs) -> None: 23 | for v in kwargs.values(): 24 | assert isinstance( 25 | v, (list, np.ndarray, torch.Tensor) 26 | ), "MaskData only supports list, numpy arrays, and torch tensors." 27 | self._stats = dict(**kwargs) 28 | 29 | def __setitem__(self, key: str, item: Any) -> None: 30 | assert isinstance( 31 | item, (list, np.ndarray, torch.Tensor) 32 | ), "MaskData only supports list, numpy arrays, and torch tensors." 
33 | self._stats[key] = item 34 | 35 | def __delitem__(self, key: str) -> None: 36 | del self._stats[key] 37 | 38 | def __getitem__(self, key: str) -> Any: 39 | return self._stats[key] 40 | 41 | def items(self) -> ItemsView[str, Any]: 42 | return self._stats.items() 43 | 44 | def filter(self, keep: torch.Tensor) -> None: 45 | for k, v in self._stats.items(): 46 | if v is None: 47 | self._stats[k] = None 48 | elif isinstance(v, torch.Tensor): 49 | self._stats[k] = v[torch.as_tensor(keep, device=v.device)] 50 | elif isinstance(v, np.ndarray): 51 | self._stats[k] = v[keep.detach().cpu().numpy()] 52 | elif isinstance(v, list) and keep.dtype == torch.bool: 53 | self._stats[k] = [a for i, a in enumerate(v) if keep[i]] 54 | elif isinstance(v, list): 55 | self._stats[k] = [v[i] for i in keep] 56 | else: 57 | raise TypeError(f"MaskData key {k} has an unsupported type {type(v)}.") 58 | 59 | def cat(self, new_stats: "MaskData") -> None: 60 | for k, v in new_stats.items(): 61 | if k not in self._stats or self._stats[k] is None: 62 | self._stats[k] = deepcopy(v) 63 | elif isinstance(v, torch.Tensor): 64 | self._stats[k] = torch.cat([self._stats[k], v], dim=0) 65 | elif isinstance(v, np.ndarray): 66 | self._stats[k] = np.concatenate([self._stats[k], v], axis=0) 67 | elif isinstance(v, list): 68 | self._stats[k] = self._stats[k] + deepcopy(v) 69 | else: 70 | raise TypeError(f"MaskData key {k} has an unsupported type {type(v)}.") 71 | 72 | def to_numpy(self) -> None: 73 | for k, v in self._stats.items(): 74 | if isinstance(v, torch.Tensor): 75 | self._stats[k] = v.detach().cpu().numpy() 76 | 77 | 78 | def is_box_near_crop_edge( 79 | boxes: torch.Tensor, crop_box: List[int], orig_box: List[int], atol: float = 20.0 80 | ) -> torch.Tensor: 81 | """Filter masks at the edge of a crop, but not at the edge of the original image.""" 82 | crop_box_torch = torch.as_tensor(crop_box, dtype=torch.float, device=boxes.device) 83 | orig_box_torch = torch.as_tensor(orig_box, dtype=torch.float, device=boxes.device) 84 | boxes = uncrop_boxes_xyxy(boxes, crop_box).float() 85 | near_crop_edge = torch.isclose(boxes, crop_box_torch[None, :], atol=atol, rtol=0) 86 | near_image_edge = torch.isclose(boxes, orig_box_torch[None, :], atol=atol, rtol=0) 87 | near_crop_edge = torch.logical_and(near_crop_edge, ~near_image_edge) 88 | return torch.any(near_crop_edge, dim=1) 89 | 90 | 91 | def box_xyxy_to_xywh(box_xyxy: torch.Tensor) -> torch.Tensor: 92 | box_xywh = deepcopy(box_xyxy) 93 | box_xywh[2] = box_xywh[2] - box_xywh[0] 94 | box_xywh[3] = box_xywh[3] - box_xywh[1] 95 | return box_xywh 96 | 97 | 98 | def batch_iterator(batch_size: int, *args) -> Generator[List[Any], None, None]: 99 | assert len(args) > 0 and all( 100 | len(a) == len(args[0]) for a in args 101 | ), "Batched iteration must have inputs of all the same size." 102 | n_batches = len(args[0]) // batch_size + int(len(args[0]) % batch_size != 0) 103 | for b in range(n_batches): 104 | yield [arg[b * batch_size : (b + 1) * batch_size] for arg in args] 105 | 106 | 107 | def mask_to_rle_pytorch(tensor: torch.Tensor) -> List[Dict[str, Any]]: 108 | """ 109 | Encodes masks to an uncompressed RLE, in the format expected by 110 | pycoco tools. 
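    Counts alternate between runs of 0s and 1s, always starting with the number
    of leading 0s, and the mask is flattened in column-major (Fortran) order.
    For example, a 2x2 mask [[0, 1], [0, 1]] flattens to 0, 0, 1, 1 and is
    encoded as {"size": [2, 2], "counts": [2, 2]}.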
111 | """ 112 | # Put in fortran order and flatten h,w 113 | b, h, w = tensor.shape 114 | tensor = tensor.permute(0, 2, 1).flatten(1) 115 | 116 | # Compute change indices 117 | diff = tensor[:, 1:] ^ tensor[:, :-1] 118 | change_indices = diff.nonzero() 119 | 120 | # Encode run length 121 | out = [] 122 | for i in range(b): 123 | cur_idxs = change_indices[change_indices[:, 0] == i, 1] 124 | cur_idxs = torch.cat( 125 | [ 126 | torch.tensor([0], dtype=cur_idxs.dtype, device=cur_idxs.device), 127 | cur_idxs + 1, 128 | torch.tensor([h * w], dtype=cur_idxs.dtype, device=cur_idxs.device), 129 | ] 130 | ) 131 | btw_idxs = cur_idxs[1:] - cur_idxs[:-1] 132 | counts = [] if tensor[i, 0] == 0 else [0] 133 | counts.extend(btw_idxs.detach().cpu().tolist()) 134 | out.append({"size": [h, w], "counts": counts}) 135 | return out 136 | 137 | 138 | def rle_to_mask(rle: Dict[str, Any]) -> np.ndarray: 139 | """Compute a binary mask from an uncompressed RLE.""" 140 | h, w = rle["size"] 141 | mask = np.empty(h * w, dtype=bool) 142 | idx = 0 143 | parity = False 144 | for count in rle["counts"]: 145 | mask[idx : idx + count] = parity 146 | idx += count 147 | parity ^= True 148 | mask = mask.reshape(w, h) 149 | return mask.transpose() # Put in C order 150 | 151 | 152 | def area_from_rle(rle: Dict[str, Any]) -> int: 153 | return sum(rle["counts"][1::2]) 154 | 155 | 156 | def calculate_stability_score( 157 | masks: torch.Tensor, mask_threshold: float, threshold_offset: float 158 | ) -> torch.Tensor: 159 | """ 160 | Computes the stability score for a batch of masks. The stability 161 | score is the IoU between the binary masks obtained by thresholding 162 | the predicted mask logits at high and low values. 163 | """ 164 | # One mask is always contained inside the other. 165 | # Save memory by preventing unnecesary cast to torch.int64 166 | intersections = ( 167 | (masks > (mask_threshold + threshold_offset)) 168 | .sum(-1, dtype=torch.int16) 169 | .sum(-1, dtype=torch.int32) 170 | ) 171 | unions = ( 172 | (masks > (mask_threshold - threshold_offset)) 173 | .sum(-1, dtype=torch.int16) 174 | .sum(-1, dtype=torch.int32) 175 | ) 176 | return intersections / unions 177 | 178 | 179 | def build_point_grid(n_per_side: int) -> np.ndarray: 180 | """Generates a 2D grid of points evenly spaced in [0,1]x[0,1].""" 181 | offset = 1 / (2 * n_per_side) 182 | points_one_side = np.linspace(offset, 1 - offset, n_per_side) 183 | points_x = np.tile(points_one_side[None, :], (n_per_side, 1)) 184 | points_y = np.tile(points_one_side[:, None], (1, n_per_side)) 185 | points = np.stack([points_x, points_y], axis=-1).reshape(-1, 2) 186 | return points 187 | 188 | 189 | def build_all_layer_point_grids( 190 | n_per_side: int, n_layers: int, scale_per_layer: int 191 | ) -> List[np.ndarray]: 192 | """Generates point grids for all crop layers.""" 193 | points_by_layer = [] 194 | for i in range(n_layers + 1): 195 | n_points = int(n_per_side / (scale_per_layer**i)) 196 | points_by_layer.append(build_point_grid(n_points)) 197 | return points_by_layer 198 | 199 | 200 | def generate_crop_boxes( 201 | im_size: Tuple[int, ...], n_layers: int, overlap_ratio: float 202 | ) -> Tuple[List[List[int]], List[int]]: 203 | """ 204 | Generates a list of crop boxes of different sizes. Each layer 205 | has (2**i)**2 boxes for the ith layer. 
206 | """ 207 | crop_boxes, layer_idxs = [], [] 208 | im_h, im_w = im_size 209 | short_side = min(im_h, im_w) 210 | 211 | # Original image 212 | crop_boxes.append([0, 0, im_w, im_h]) 213 | layer_idxs.append(0) 214 | 215 | def crop_len(orig_len, n_crops, overlap): 216 | return int(math.ceil((overlap * (n_crops - 1) + orig_len) / n_crops)) 217 | 218 | for i_layer in range(n_layers): 219 | n_crops_per_side = 2 ** (i_layer + 1) 220 | overlap = int(overlap_ratio * short_side * (2 / n_crops_per_side)) 221 | 222 | crop_w = crop_len(im_w, n_crops_per_side, overlap) 223 | crop_h = crop_len(im_h, n_crops_per_side, overlap) 224 | 225 | crop_box_x0 = [int((crop_w - overlap) * i) for i in range(n_crops_per_side)] 226 | crop_box_y0 = [int((crop_h - overlap) * i) for i in range(n_crops_per_side)] 227 | 228 | # Crops in XYWH format 229 | for x0, y0 in product(crop_box_x0, crop_box_y0): 230 | box = [x0, y0, min(x0 + crop_w, im_w), min(y0 + crop_h, im_h)] 231 | crop_boxes.append(box) 232 | layer_idxs.append(i_layer + 1) 233 | 234 | return crop_boxes, layer_idxs 235 | 236 | 237 | def uncrop_boxes_xyxy(boxes: torch.Tensor, crop_box: List[int]) -> torch.Tensor: 238 | x0, y0, _, _ = crop_box 239 | offset = torch.tensor([[x0, y0, x0, y0]], device=boxes.device) 240 | # Check if boxes has a channel dimension 241 | if len(boxes.shape) == 3: 242 | offset = offset.unsqueeze(1) 243 | return boxes + offset 244 | 245 | 246 | def uncrop_points(points: torch.Tensor, crop_box: List[int]) -> torch.Tensor: 247 | x0, y0, _, _ = crop_box 248 | offset = torch.tensor([[x0, y0]], device=points.device) 249 | # Check if points has a channel dimension 250 | if len(points.shape) == 3: 251 | offset = offset.unsqueeze(1) 252 | return points + offset 253 | 254 | 255 | def uncrop_masks( 256 | masks: torch.Tensor, crop_box: List[int], orig_h: int, orig_w: int 257 | ) -> torch.Tensor: 258 | x0, y0, x1, y1 = crop_box 259 | if x0 == 0 and y0 == 0 and x1 == orig_w and y1 == orig_h: 260 | return masks 261 | # Coordinate transform masks 262 | pad_x, pad_y = orig_w - (x1 - x0), orig_h - (y1 - y0) 263 | pad = (x0, pad_x - x0, y0, pad_y - y0) 264 | return torch.nn.functional.pad(masks, pad, value=0) 265 | 266 | 267 | def remove_small_regions( 268 | mask: np.ndarray, area_thresh: float, mode: str 269 | ) -> Tuple[np.ndarray, bool]: 270 | """ 271 | Removes small disconnected regions and holes in a mask. Returns the 272 | mask and an indicator of if the mask has been modified. 
273 | """ 274 | import cv2 # type: ignore 275 | 276 | assert mode in ["holes", "islands"] 277 | correct_holes = mode == "holes" 278 | working_mask = (correct_holes ^ mask).astype(np.uint8) 279 | n_labels, regions, stats, _ = cv2.connectedComponentsWithStats(working_mask, 8) 280 | sizes = stats[:, -1][1:] # Row 0 is background label 281 | small_regions = [i + 1 for i, s in enumerate(sizes) if s < area_thresh] 282 | if len(small_regions) == 0: 283 | return mask, False 284 | fill_labels = [0] + small_regions 285 | if not correct_holes: 286 | fill_labels = [i for i in range(n_labels) if i not in fill_labels] 287 | # If every region is below threshold, keep largest 288 | if len(fill_labels) == 0: 289 | fill_labels = [int(np.argmax(sizes)) + 1] 290 | mask = np.isin(regions, fill_labels) 291 | return mask, True 292 | 293 | 294 | def coco_encode_rle(uncompressed_rle: Dict[str, Any]) -> Dict[str, Any]: 295 | from pycocotools import mask as mask_utils # type: ignore 296 | 297 | h, w = uncompressed_rle["size"] 298 | rle = mask_utils.frPyObjects(uncompressed_rle, h, w) 299 | rle["counts"] = rle["counts"].decode("utf-8") # Necessary to serialize with json 300 | return rle 301 | 302 | 303 | def batched_mask_to_box(masks: torch.Tensor) -> torch.Tensor: 304 | """ 305 | Calculates boxes in XYXY format around masks. Return [0,0,0,0] for 306 | an empty mask. For input shape C1xC2x...xHxW, the output shape is C1xC2x...x4. 307 | """ 308 | # torch.max below raises an error on empty inputs, just skip in this case 309 | if torch.numel(masks) == 0: 310 | return torch.zeros(*masks.shape[:-2], 4, device=masks.device) 311 | 312 | # Normalize shape to CxHxW 313 | shape = masks.shape 314 | h, w = shape[-2:] 315 | if len(shape) > 2: 316 | masks = masks.flatten(0, -3) 317 | else: 318 | masks = masks.unsqueeze(0) 319 | 320 | # Get top and bottom edges 321 | in_height, _ = torch.max(masks, dim=-1) 322 | in_height_coords = in_height * torch.arange(h, device=in_height.device)[None, :] 323 | bottom_edges, _ = torch.max(in_height_coords, dim=-1) 324 | in_height_coords = in_height_coords + h * (~in_height) 325 | top_edges, _ = torch.min(in_height_coords, dim=-1) 326 | 327 | # Get left and right edges 328 | in_width, _ = torch.max(masks, dim=-2) 329 | in_width_coords = in_width * torch.arange(w, device=in_width.device)[None, :] 330 | right_edges, _ = torch.max(in_width_coords, dim=-1) 331 | in_width_coords = in_width_coords + w * (~in_width) 332 | left_edges, _ = torch.min(in_width_coords, dim=-1) 333 | 334 | # If the mask is empty the right edge will be to the left of the left edge. 335 | # Replace these boxes with [0, 0, 0, 0] 336 | empty_filter = (right_edges < left_edges) | (bottom_edges < top_edges) 337 | out = torch.stack([left_edges, top_edges, right_edges, bottom_edges], dim=-1) 338 | out = out * (~empty_filter).unsqueeze(-1) 339 | 340 | # Return to original shape 341 | if len(shape) > 2: 342 | out = out.reshape(*shape[:-2], 4) 343 | else: 344 | out = out[0] 345 | 346 | return out 347 | -------------------------------------------------------------------------------- /models/sam/utils/onnx.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
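Editor's note (not part of the repository): the run-length helpers in models/sam/utils/amg.py above store masks as uncompressed RLE dicts. mask_to_rle_pytorch flattens each mask in column-major (Fortran) order and records alternating run lengths, always starting with the run of zeros (a leading 0 means the mask begins with ones); rle_to_mask inverts this and area_from_rle sums the one-runs. A minimal NumPy-only sketch of the format, using a hypothetical 2x2 mask:

import numpy as np

mask = np.array([[0, 1],
                 [0, 1]], dtype=bool)  # H = 2, W = 2
# Column-major flattening gives [0, 0, 1, 1]: a 0-run of 2, then a 1-run of 2.
rle = {"size": [2, 2], "counts": [2, 2]}

# Decode the counts the same way rle_to_mask does: fill alternating runs,
# then undo the Fortran-order flattening.
flat = np.empty(4, dtype=bool)
idx, parity = 0, False
for count in rle["counts"]:
    flat[idx:idx + count] = parity
    idx += count
    parity = not parity
restored = flat.reshape(2, 2, order="F")
assert np.array_equal(restored, mask)
# area_from_rle sums the odd-indexed counts (the 1-runs): 2 pixels here.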
6 | 7 | import torch 8 | import torch.nn as nn 9 | from torch.nn import functional as F 10 | 11 | from typing import Tuple 12 | 13 | from ..modeling import Sam 14 | from .amg import calculate_stability_score 15 | 16 | 17 | class SamOnnxModel(nn.Module): 18 | """ 19 | This model should not be called directly, but is used in ONNX export. 20 | It combines the prompt encoder, mask decoder, and mask postprocessing of Sam, 21 | with some functions modified to enable model tracing. Also supports extra 22 | options controlling what information. See the ONNX export script for details. 23 | """ 24 | 25 | def __init__( 26 | self, 27 | model: Sam, 28 | return_single_mask: bool, 29 | use_stability_score: bool = False, 30 | return_extra_metrics: bool = False, 31 | ) -> None: 32 | super().__init__() 33 | self.mask_decoder = model.mask_decoder 34 | self.model = model 35 | self.img_size = model.image_encoder.img_size 36 | self.return_single_mask = return_single_mask 37 | self.use_stability_score = use_stability_score 38 | self.stability_score_offset = 1.0 39 | self.return_extra_metrics = return_extra_metrics 40 | 41 | @staticmethod 42 | def resize_longest_image_size( 43 | input_image_size: torch.Tensor, longest_side: int 44 | ) -> torch.Tensor: 45 | input_image_size = input_image_size.to(torch.float32) 46 | scale = longest_side / torch.max(input_image_size) 47 | transformed_size = scale * input_image_size 48 | transformed_size = torch.floor(transformed_size + 0.5).to(torch.int64) 49 | return transformed_size 50 | 51 | def _embed_points(self, point_coords: torch.Tensor, point_labels: torch.Tensor) -> torch.Tensor: 52 | point_coords = point_coords + 0.5 53 | point_coords = point_coords / self.img_size 54 | point_embedding = self.model.prompt_encoder.pe_layer._pe_encoding(point_coords) 55 | point_labels = point_labels.unsqueeze(-1).expand_as(point_embedding) 56 | 57 | point_embedding = point_embedding * (point_labels != -1) 58 | point_embedding = point_embedding + self.model.prompt_encoder.not_a_point_embed.weight * ( 59 | point_labels == -1 60 | ) 61 | 62 | for i in range(self.model.prompt_encoder.num_point_embeddings): 63 | point_embedding = point_embedding + self.model.prompt_encoder.point_embeddings[ 64 | i 65 | ].weight * (point_labels == i) 66 | 67 | return point_embedding 68 | 69 | def _embed_masks(self, input_mask: torch.Tensor, has_mask_input: torch.Tensor) -> torch.Tensor: 70 | mask_embedding = has_mask_input * self.model.prompt_encoder.mask_downscaling(input_mask) 71 | mask_embedding = mask_embedding + ( 72 | 1 - has_mask_input 73 | ) * self.model.prompt_encoder.no_mask_embed.weight.reshape(1, -1, 1, 1) 74 | return mask_embedding 75 | 76 | def mask_postprocessing(self, masks: torch.Tensor, orig_im_size: torch.Tensor) -> torch.Tensor: 77 | masks = F.interpolate( 78 | masks, 79 | size=(self.img_size, self.img_size), 80 | mode="bilinear", 81 | align_corners=False, 82 | ) 83 | 84 | prepadded_size = self.resize_longest_image_size(orig_im_size, self.img_size) 85 | masks = masks[..., : int(prepadded_size[0]), : int(prepadded_size[1])] 86 | 87 | orig_im_size = orig_im_size.to(torch.int64) 88 | h, w = orig_im_size[0], orig_im_size[1] 89 | masks = F.interpolate(masks, size=(h, w), mode="bilinear", align_corners=False) 90 | return masks 91 | 92 | def select_masks( 93 | self, masks: torch.Tensor, iou_preds: torch.Tensor, num_points: int 94 | ) -> Tuple[torch.Tensor, torch.Tensor]: 95 | # Determine if we should return the multiclick mask or not from the number of points. 
96 | # The reweighting is used to avoid control flow. 97 | score_reweight = torch.tensor( 98 | [[1000] + [0] * (self.model.mask_decoder.num_mask_tokens - 1)] 99 | ).to(iou_preds.device) 100 | score = iou_preds + (num_points - 2.5) * score_reweight 101 | best_idx = torch.argmax(score, dim=1) 102 | masks = masks[torch.arange(masks.shape[0]), best_idx, :, :].unsqueeze(1) 103 | iou_preds = iou_preds[torch.arange(masks.shape[0]), best_idx].unsqueeze(1) 104 | 105 | return masks, iou_preds 106 | 107 | @torch.no_grad() 108 | def forward( 109 | self, 110 | image_embeddings: torch.Tensor, 111 | point_coords: torch.Tensor, 112 | point_labels: torch.Tensor, 113 | mask_input: torch.Tensor, 114 | has_mask_input: torch.Tensor, 115 | orig_im_size: torch.Tensor, 116 | ): 117 | sparse_embedding = self._embed_points(point_coords, point_labels) 118 | dense_embedding = self._embed_masks(mask_input, has_mask_input) 119 | 120 | masks, scores = self.model.mask_decoder.predict_masks( 121 | image_embeddings=image_embeddings, 122 | image_pe=self.model.prompt_encoder.get_dense_pe(), 123 | sparse_prompt_embeddings=sparse_embedding, 124 | dense_prompt_embeddings=dense_embedding, 125 | ) 126 | 127 | if self.use_stability_score: 128 | scores = calculate_stability_score( 129 | masks, self.model.mask_threshold, self.stability_score_offset 130 | ) 131 | 132 | if self.return_single_mask: 133 | masks, scores = self.select_masks(masks, scores, point_coords.shape[1]) 134 | 135 | upscaled_masks = self.mask_postprocessing(masks, orig_im_size) 136 | 137 | if self.return_extra_metrics: 138 | stability_scores = calculate_stability_score( 139 | upscaled_masks, self.model.mask_threshold, self.stability_score_offset 140 | ) 141 | areas = (upscaled_masks > self.model.mask_threshold).sum(-1).sum(-1) 142 | return upscaled_masks, scores, stability_scores, areas, masks 143 | 144 | return upscaled_masks, scores, masks 145 | -------------------------------------------------------------------------------- /models/sam/utils/transforms.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import numpy as np 8 | import torch 9 | from torch.nn import functional as F 10 | from torchvision.transforms.functional import resize, to_pil_image # type: ignore 11 | 12 | from copy import deepcopy 13 | from typing import Tuple 14 | 15 | 16 | class ResizeLongestSide: 17 | """ 18 | Resizes images to longest side 'target_length', as well as provides 19 | methods for resizing coordinates and boxes. Provides methods for 20 | transforming both numpy array and batched torch tensors. 21 | """ 22 | 23 | def __init__(self, target_length: int) -> None: 24 | self.target_length = target_length 25 | 26 | def apply_image(self, image: np.ndarray) -> np.ndarray: 27 | """ 28 | Expects a numpy array with shape HxWxC in uint8 format. 29 | """ 30 | target_size = self.get_preprocess_shape(image.shape[0], image.shape[1], self.target_length) 31 | return np.array(resize(to_pil_image(image), target_size)) 32 | 33 | def apply_coords(self, coords: np.ndarray, original_size: Tuple[int, ...]) -> np.ndarray: 34 | """ 35 | Expects a numpy array of length 2 in the final dimension. Requires the 36 | original image size in (H, W) format. 
37 | """ 38 | old_h, old_w = original_size 39 | new_h, new_w = self.get_preprocess_shape( 40 | original_size[0], original_size[1], self.target_length 41 | ) 42 | coords = deepcopy(coords).astype(float) 43 | coords[..., 0] = coords[..., 0] * (new_w / old_w) 44 | coords[..., 1] = coords[..., 1] * (new_h / old_h) 45 | return coords 46 | 47 | def apply_boxes(self, boxes: np.ndarray, original_size: Tuple[int, ...]) -> np.ndarray: 48 | """ 49 | Expects a numpy array shape Bx4. Requires the original image size 50 | in (H, W) format. 51 | """ 52 | boxes = self.apply_coords(boxes.reshape(-1, 2, 2), original_size) 53 | return boxes.reshape(-1, 4) 54 | 55 | def apply_image_torch(self, image: torch.Tensor) -> torch.Tensor: 56 | """ 57 | Expects batched images with shape BxCxHxW and float format. This 58 | transformation may not exactly match apply_image. apply_image is 59 | the transformation expected by the model. 60 | """ 61 | # Expects an image in BCHW format, so H and W are the last two dimensions. May not exactly match apply_image. 62 | target_size = self.get_preprocess_shape(image.shape[2], image.shape[3], self.target_length) 63 | return F.interpolate( 64 | image, target_size, mode="bilinear", align_corners=False, antialias=True 65 | ) 66 | 67 | def apply_coords_torch( 68 | self, coords: torch.Tensor, original_size: Tuple[int, ...] 69 | ) -> torch.Tensor: 70 | """ 71 | Expects a torch tensor with length 2 in the last dimension. Requires the 72 | original image size in (H, W) format. 73 | """ 74 | old_h, old_w = original_size 75 | new_h, new_w = self.get_preprocess_shape( 76 | original_size[0], original_size[1], self.target_length 77 | ) 78 | coords = deepcopy(coords).to(torch.float) 79 | coords[..., 0] = coords[..., 0] * (new_w / old_w) 80 | coords[..., 1] = coords[..., 1] * (new_h / old_h) 81 | return coords 82 | 83 | def apply_boxes_torch( 84 | self, boxes: torch.Tensor, original_size: Tuple[int, ...] 85 | ) -> torch.Tensor: 86 | """ 87 | Expects a torch tensor with shape Bx4. Requires the original image 88 | size in (H, W) format. 89 | """ 90 | boxes = self.apply_coords_torch(boxes.reshape(-1, 2, 2), original_size) 91 | return boxes.reshape(-1, 4) 92 | 93 | @staticmethod 94 | def get_preprocess_shape(oldh: int, oldw: int, long_side_length: int) -> Tuple[int, int]: 95 | """ 96 | Compute the output size given input size and target long side length.
97 | """ 98 | scale = long_side_length * 1.0 / max(oldh, oldw) 99 | newh, neww = oldh * scale, oldw * scale 100 | neww = int(neww + 0.5) 101 | newh = int(newh + 0.5) 102 | return (newh, neww) 103 | -------------------------------------------------------------------------------- /output/cluster/test_baseball.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MaybeShewill-CV/segment-anything-u-specify/606a6bc744fd7cc37824a32a95ec770580467428/output/cluster/test_baseball.jpg -------------------------------------------------------------------------------- /output/cluster/test_baseball_clustered_mask.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MaybeShewill-CV/segment-anything-u-specify/606a6bc744fd7cc37824a32a95ec770580467428/output/cluster/test_baseball_clustered_mask.png -------------------------------------------------------------------------------- /output/cluster/test_baseball_clustered_mask_add.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MaybeShewill-CV/segment-anything-u-specify/606a6bc744fd7cc37824a32a95ec770580467428/output/cluster/test_baseball_clustered_mask_add.png -------------------------------------------------------------------------------- /output/cluster/test_baseball_ori_mask.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MaybeShewill-CV/segment-anything-u-specify/606a6bc744fd7cc37824a32a95ec770580467428/output/cluster/test_baseball_ori_mask.png -------------------------------------------------------------------------------- /output/cluster/test_baseball_ori_mask_add.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MaybeShewill-CV/segment-anything-u-specify/606a6bc744fd7cc37824a32a95ec770580467428/output/cluster/test_baseball_ori_mask_add.png -------------------------------------------------------------------------------- /output/cluster/test_bear.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MaybeShewill-CV/segment-anything-u-specify/606a6bc744fd7cc37824a32a95ec770580467428/output/cluster/test_bear.jpg -------------------------------------------------------------------------------- /output/cluster/test_bear_clustered_mask.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MaybeShewill-CV/segment-anything-u-specify/606a6bc744fd7cc37824a32a95ec770580467428/output/cluster/test_bear_clustered_mask.png -------------------------------------------------------------------------------- /output/cluster/test_bear_clustered_mask_add.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MaybeShewill-CV/segment-anything-u-specify/606a6bc744fd7cc37824a32a95ec770580467428/output/cluster/test_bear_clustered_mask_add.png -------------------------------------------------------------------------------- /output/cluster/test_bear_ori_mask.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MaybeShewill-CV/segment-anything-u-specify/606a6bc744fd7cc37824a32a95ec770580467428/output/cluster/test_bear_ori_mask.png 
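Editor's note (not part of the repository): the output/cluster images listed in this directory are produced by tools/cluster_sam.py, which appears further down in this listing. A minimal sketch of the equivalent programmatic call, assuming the checkpoints from scripts/download_pretrained_ckpt.sh have already been downloaded into ./pretrained; the input image path is illustrative:

import os

import cv2

from local_utils.config_utils import parse_config_utils
from models import build_cluster

# config/cluster.yaml is the default config used by tools/cluster_sam.py.
cluster_cfg = parse_config_utils.Config(config_path='./config/cluster.yaml')
cluster = build_cluster(cfg=cluster_cfg)

# Segment with SAM and cluster the resulting masks with CLIP features.
ret = cluster.cluster_image('./data/test_images/test_bear.jpg')

# The returned dict holds the five images the script writes out:
# 'source', 'ori_mask', 'ori_mask_add', 'cluster_mask', 'cluster_mask_add'.
os.makedirs('./output/cluster', exist_ok=True)
cv2.imwrite('./output/cluster/test_bear_clustered_mask_add.png', ret['cluster_mask_add'])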
-------------------------------------------------------------------------------- /output/cluster/test_bear_ori_mask_add.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MaybeShewill-CV/segment-anything-u-specify/606a6bc744fd7cc37824a32a95ec770580467428/output/cluster/test_bear_ori_mask_add.png -------------------------------------------------------------------------------- /output/cluster/test_horse.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MaybeShewill-CV/segment-anything-u-specify/606a6bc744fd7cc37824a32a95ec770580467428/output/cluster/test_horse.jpg -------------------------------------------------------------------------------- /output/cluster/test_horse_clustered_mask.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MaybeShewill-CV/segment-anything-u-specify/606a6bc744fd7cc37824a32a95ec770580467428/output/cluster/test_horse_clustered_mask.png -------------------------------------------------------------------------------- /output/cluster/test_horse_clustered_mask_add.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MaybeShewill-CV/segment-anything-u-specify/606a6bc744fd7cc37824a32a95ec770580467428/output/cluster/test_horse_clustered_mask_add.png -------------------------------------------------------------------------------- /output/cluster/test_horse_ori_mask.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MaybeShewill-CV/segment-anything-u-specify/606a6bc744fd7cc37824a32a95ec770580467428/output/cluster/test_horse_ori_mask.png -------------------------------------------------------------------------------- /output/cluster/test_horse_ori_mask_add.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MaybeShewill-CV/segment-anything-u-specify/606a6bc744fd7cc37824a32a95ec770580467428/output/cluster/test_horse_ori_mask_add.png -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | ftfy 2 | loguru 3 | matplotlib 4 | numpy 5 | opencv_python 6 | Pillow 7 | pycocotools 8 | PyYAML 9 | regex 10 | torch==1.12.0 11 | torchvision==0.13.0 12 | tqdm 13 | -------------------------------------------------------------------------------- /scripts/download_pretrained_ckpt.sh: -------------------------------------------------------------------------------- 1 | echo ----- Start downloading clip and segment-anything pretrained model weights 2 | mkdir pretrained 3 | cd pretrained 4 | # download clip vit-b-32 model 5 | wget -O ViT-B-32.pt https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt 6 | # download clip vit-l-14 model 7 | wget -O ViT-L-14.pt https://openaipublic.azureedge.net/clip/models/b8cca3fd41ae0c99ba7e8951adf17d267cdb84cd88be6f7c2e0eca1737a03836/ViT-L-14.pt 8 | # download default sam model 9 | wget -O sam_vit_h_4b8939.pth https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth 10 | # download sam vit-l model 11 | wget -O sam_vit_l_0b3195.pth https://dl.fbaipublicfiles.com/segment_anything/sam_vit_l_0b3195.pth 12 | # download sam vit-b model 13 | wget -O 
sam_vit_b_01ec64.pth https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth 14 | -------------------------------------------------------------------------------- /tools/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # @Time : 2023-04-10 上午11:49 4 | # @Author : MaybeShewill-CV 5 | # @Site : 6 | # @File : __init__.py.py 7 | # @IDE: PyCharm Community Edition 8 | -------------------------------------------------------------------------------- /tools/cluster_sam.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # @Time : 2023-04-10 下午7:33 4 | # @Author : MaybeShewill-CV 5 | # @Site : 6 | # @File : cluster_sam.py 7 | # @IDE: PyCharm Community Edition 8 | """ 9 | cluster sam segmentation result with clip features 10 | """ 11 | import os 12 | import os.path as ops 13 | import argparse 14 | 15 | import cv2 16 | 17 | from local_utils.log_util import init_logger 18 | from local_utils.config_utils import parse_config_utils 19 | from models import build_cluster 20 | 21 | 22 | LOG = init_logger.get_logger('cluster.log') 23 | 24 | 25 | def init_args(): 26 | """ 27 | 28 | :return: 29 | """ 30 | parser = argparse.ArgumentParser() 31 | parser.add_argument('--input_image_path', type=str, default='./data/test_bear.jpg', required=True) 32 | parser.add_argument('--cluster_cfg_path', type=str, default='./config/cluster.yaml') 33 | parser.add_argument('--save_dir', type=str, default='./output/cluster') 34 | parser.add_argument('--simi_thresh', type=float, default=None) 35 | 36 | return parser.parse_args() 37 | 38 | 39 | def main(): 40 | """ 41 | 42 | :return: 43 | """ 44 | # init args 45 | args = init_args() 46 | input_image_path = args.input_image_path 47 | input_image_name = ops.split(input_image_path)[1] 48 | if not ops.exists(input_image_path): 49 | LOG.error('input image path: {:s} not exists'.format(input_image_path)) 50 | return 51 | cluster_cfg_path = args.cluster_cfg_path 52 | if not ops.exists(cluster_cfg_path): 53 | LOG.error('input cluster cfg path: {:s} not exists'.format(cluster_cfg_path)) 54 | return 55 | cluster_cfg = parse_config_utils.Config(config_path=cluster_cfg_path) 56 | if args.simi_thresh is not None: 57 | cluster_cfg.CLUSTER.SIMILARITY_THRESH = float(args.simi_thresh) 58 | 59 | # init cluster 60 | LOG.info('Start initializing cluster ...') 61 | cluster = build_cluster(cfg=cluster_cfg) 62 | LOG.info('Cluster initialized complete') 63 | LOG.info('Start to segment and cluster input image ...') 64 | ret = cluster.cluster_image(input_image_path) 65 | LOG.info('segment and cluster complete') 66 | 67 | # save cluster result 68 | save_dir = args.save_dir 69 | os.makedirs(save_dir, exist_ok=True) 70 | ori_image_save_path = ops.join(save_dir, input_image_name) 71 | cv2.imwrite(ori_image_save_path, ret['source']) 72 | ori_mask_save_path = ops.join(save_dir, '{:s}_ori_mask.png'.format(input_image_name.split('.')[0])) 73 | cv2.imwrite(ori_mask_save_path, ret['ori_mask']) 74 | ori_mask_add_save_path = ops.join(save_dir, '{:s}_ori_mask_add.png'.format(input_image_name.split('.')[0])) 75 | cv2.imwrite(ori_mask_add_save_path, ret['ori_mask_add']) 76 | clustered_mask_save_path = ops.join(save_dir, '{:s}_clustered_mask.png'.format(input_image_name.split('.')[0])) 77 | cv2.imwrite(clustered_mask_save_path, ret['cluster_mask']) 78 | clustered_mask_add_save_path = ops.join( 79 | save_dir, 80 | 
'{:s}_clustered_mask_add.png'.format(input_image_name.split('.')[0]) 81 | ) 82 | cv2.imwrite(clustered_mask_add_save_path, ret['cluster_mask_add']) 83 | LOG.info('save segment and cluster result into {:s}'.format(save_dir)) 84 | 85 | return 86 | 87 | 88 | if __name__ == '__main__': 89 | """ 90 | main func 91 | """ 92 | main() 93 | -------------------------------------------------------------------------------- /tools/sam_clip_text_seg.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # @Time : 2023-04-11 2:15 PM 4 | # @Author : MaybeShewill-CV 5 | # @Site : 6 | # @File : sam_clip_text_seg.py 7 | # @IDE: PyCharm Community Edition 8 | """ 9 | instance segmentation of an image with sam and clip using text prompts 10 | """ 11 | import os 12 | import os.path as ops 13 | import argparse 14 | 15 | import cv2 16 | 17 | from local_utils.log_util import init_logger 18 | from local_utils.config_utils import parse_config_utils 19 | from models import build_sam_clip_text_ins_segmentor 20 | 21 | 22 | LOG = init_logger.get_logger('instance_seg.log') 23 | 24 | 25 | def init_args(): 26 | """ 27 | 28 | :return: 29 | """ 30 | parser = argparse.ArgumentParser() 31 | parser.add_argument('--input_image_path', type=str, default='./data/test_bear.jpg', required=True) 32 | parser.add_argument('--insseg_cfg_path', type=str, default='./config/insseg.yaml') 33 | parser.add_argument('--text', type=str, default=None) 34 | parser.add_argument('--cls_score_thresh', type=float, default=None) 35 | parser.add_argument('--save_dir', type=str, default='./output/insseg') 36 | parser.add_argument('--use_text_prefix', action='store_true') 37 | 38 | return parser.parse_args() 39 | 40 | 41 | def main(): 42 | """ 43 | 44 | :return: 45 | """ 46 | # init args 47 | args = init_args() 48 | input_image_path = args.input_image_path 49 | input_image_name = ops.split(input_image_path)[1] 50 | if not ops.exists(input_image_path): 51 | LOG.error('input image path: {:s} not exists'.format(input_image_path)) 52 | return 53 | insseg_cfg_path = args.insseg_cfg_path 54 | if not ops.exists(insseg_cfg_path): 55 | LOG.error('input insseg cfg path: {:s} not exists'.format(insseg_cfg_path)) 56 | return 57 | insseg_cfg = parse_config_utils.Config(config_path=insseg_cfg_path) 58 | if args.text is not None: 59 | unique_labels = args.text.split(',') 60 | else: 61 | unique_labels = None 62 | if args.cls_score_thresh is not None: 63 | insseg_cfg.INS_SEG.CLS_SCORE_THRESH = args.cls_score_thresh 64 | use_text_prefix = True if args.use_text_prefix else False 65 | 66 | # init instance segmentor 67 | LOG.info('Start initializing instance segmentor ...') 68 | segmentor = build_sam_clip_text_ins_segmentor(cfg=insseg_cfg) 69 | LOG.info('Segmentor initialized complete') 70 | LOG.info('Start to segment input image ...') 71 | ret = segmentor.seg_image(input_image_path, unique_label=unique_labels, use_text_prefix=use_text_prefix) 72 | LOG.info('segment complete') 73 | 74 | # save instance segmentation result 75 | save_dir = args.save_dir 76 | os.makedirs(save_dir, exist_ok=True) 77 | ori_image_save_path = ops.join(save_dir, input_image_name) 78 | cv2.imwrite(ori_image_save_path, ret['source']) 79 | mask_save_path = ops.join(save_dir, '{:s}_insseg_mask.png'.format(input_image_name.split('.')[0])) 80 | cv2.imwrite(mask_save_path, ret['ins_seg_mask']) 81 | mask_add_save_path = ops.join(save_dir, '{:s}_insseg_add.png'.format(input_image_name.split('.')[0])) 82 | cv2.imwrite(mask_add_save_path, ret['ins_seg_add'])
83 | 84 | LOG.info('save instance segmentation result into {:s}'.format(save_dir)) 85 | 86 | return 87 | 88 | 89 | if __name__ == '__main__': 90 | """ 91 | main func 92 | """ 93 | main() 94 | --------------------------------------------------------------------------------
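Editor's note (not part of the repository): a minimal programmatic equivalent of tools/sam_clip_text_seg.py above, assuming the checkpoints from scripts/download_pretrained_ckpt.sh are in place; the input image path and the 'horse,grass' prompt are illustrative only:

import os

import cv2

from local_utils.config_utils import parse_config_utils
from models import build_sam_clip_text_ins_segmentor

# config/insseg.yaml is the default config used by tools/sam_clip_text_seg.py.
insseg_cfg = parse_config_utils.Config(config_path='./config/insseg.yaml')
# insseg_cfg.INS_SEG.CLS_SCORE_THRESH = 0.5  # optional override mirroring --cls_score_thresh (0.5 is illustrative)
segmentor = build_sam_clip_text_ins_segmentor(cfg=insseg_cfg)

# Comma-separated labels, split the same way the script handles --text.
unique_labels = 'horse,grass'.split(',')
ret = segmentor.seg_image(
    './data/test_images/test_horse.jpg',
    unique_label=unique_labels,
    use_text_prefix=False,
)

os.makedirs('./output/insseg', exist_ok=True)
cv2.imwrite('./output/insseg/test_horse_insseg_mask.png', ret['ins_seg_mask'])
cv2.imwrite('./output/insseg/test_horse_insseg_add.png', ret['ins_seg_add'])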