├── .DS_Store ├── 1.jpg ├── 3.jpg ├── 4.jpg ├── LICENSE.md ├── README.md ├── __pycache__ └── utils_metrics.cpython-39.pyc ├── data ├── test_in │ ├── 1.jpg │ ├── 10.jpg │ ├── 100.jpg │ ├── 11.jpg │ ├── 12.jpg │ ├── 13.jpg │ ├── 14.jpg │ ├── 15.jpg │ ├── 16.jpg │ ├── 17.jpg │ ├── 18.jpg │ └── 19.jpg └── test_out │ ├── 1.png │ ├── 10.png │ ├── 100.png │ ├── 11.png │ ├── 12.png │ ├── 13.png │ ├── 14.png │ ├── 15.png │ ├── 16.png │ ├── 17.png │ ├── 18.png │ └── 19.png ├── pre_tongue.py ├── predict.py ├── pretrained_model └── .DS_Store ├── segment ├── __pycache__ │ ├── deeplab.cpython-39.pyc │ ├── pspnet.cpython-39.pyc │ ├── unet.cpython-39.pyc │ ├── yolox.cpython-39.pyc │ └── yolox_new.cpython-39.pyc ├── tongue_classes.txt ├── utils_yolox │ ├── __init__.py │ ├── __pycache__ │ │ ├── __init__.cpython-310.pyc │ │ ├── __init__.cpython-39.pyc │ │ ├── callbacks.cpython-310.pyc │ │ ├── callbacks.cpython-39.pyc │ │ ├── dataloader.cpython-310.pyc │ │ ├── dataloader.cpython-39.pyc │ │ ├── utils.cpython-310.pyc │ │ ├── utils.cpython-39.pyc │ │ ├── utils_bbox.cpython-310.pyc │ │ ├── utils_bbox.cpython-39.pyc │ │ ├── utils_fit.cpython-310.pyc │ │ ├── utils_fit.cpython-39.pyc │ │ ├── utils_map.cpython-310.pyc │ │ └── utils_map.cpython-39.pyc │ ├── callbacks.py │ ├── dataloader.py │ ├── utils.py │ ├── utils_bbox.py │ ├── utils_fit.py │ └── utils_map.py ├── yolox.pth ├── yolox.py └── yolox_nets │ ├── __init__.py │ ├── __pycache__ │ ├── __init__.cpython-310.pyc │ ├── __init__.cpython-39.pyc │ ├── darknet.cpython-310.pyc │ ├── darknet.cpython-39.pyc │ ├── yolo.cpython-310.pyc │ ├── yolo.cpython-39.pyc │ ├── yolo_training.cpython-310.pyc │ └── yolo_training.cpython-39.pyc │ ├── darknet.py │ ├── yolo.py │ └── yolo_training.py ├── segment_anything ├── __init__.py ├── __pycache__ │ ├── __init__.cpython-310.pyc │ ├── __init__.cpython-39.pyc │ ├── automatic_mask_generator.cpython-310.pyc │ ├── automatic_mask_generator.cpython-39.pyc │ ├── build_sam.cpython-310.pyc │ ├── build_sam.cpython-39.pyc │ ├── predictor.cpython-310.pyc │ └── predictor.cpython-39.pyc ├── automatic_mask_generator.py ├── build_sam.py ├── modeling │ ├── __init__.py │ ├── __pycache__ │ │ ├── __init__.cpython-310.pyc │ │ ├── __init__.cpython-39.pyc │ │ ├── common.cpython-310.pyc │ │ ├── common.cpython-39.pyc │ │ ├── image_encoder.cpython-310.pyc │ │ ├── image_encoder.cpython-39.pyc │ │ ├── mask_decoder.cpython-310.pyc │ │ ├── mask_decoder.cpython-39.pyc │ │ ├── prompt_encoder.cpython-310.pyc │ │ ├── prompt_encoder.cpython-39.pyc │ │ ├── sam.cpython-310.pyc │ │ ├── sam.cpython-39.pyc │ │ ├── transformer.cpython-310.pyc │ │ └── transformer.cpython-39.pyc │ ├── common.py │ ├── image_encoder.py │ ├── mask_decoder.py │ ├── prompt_encoder.py │ ├── sam.py │ └── transformer.py ├── predictor.py └── utils │ ├── __init__.py │ ├── __pycache__ │ ├── __init__.cpython-310.pyc │ ├── __init__.cpython-39.pyc │ ├── amg.cpython-310.pyc │ ├── amg.cpython-39.pyc │ ├── transforms.cpython-310.pyc │ └── transforms.cpython-39.pyc │ ├── amg.py │ ├── onnx.py │ └── transforms.py ├── split.py ├── train.py ├── utils ├── SurfaceDice.py ├── __init__.py ├── __pycache__ │ ├── SurfaceDice.cpython-39.pyc │ ├── __init__.cpython-310.pyc │ ├── __init__.cpython-39.pyc │ ├── callbacks.cpython-310.pyc │ ├── callbacks.cpython-39.pyc │ ├── dataloader.cpython-310.pyc │ ├── dataloader.cpython-39.pyc │ ├── utils.cpython-310.pyc │ ├── utils.cpython-39.pyc │ ├── utils_bbox.cpython-310.pyc │ ├── utils_bbox.cpython-39.pyc │ ├── utils_fit.cpython-310.pyc │ ├── utils_fit.cpython-39.pyc │ ├── 
utils_map.cpython-310.pyc │ └── utils_map.cpython-39.pyc ├── callbacks.py ├── dataloader.py ├── dataset.py ├── precompute_img_embed.py ├── utils.py ├── utils_bbox.py ├── utils_fit.py └── utils_map.py └── utils_metrics.py /.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cshan-github/TongueSAM/3f6e8c620e4d89e669a92a22f3be3d007932ce45/.DS_Store -------------------------------------------------------------------------------- /1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cshan-github/TongueSAM/3f6e8c620e4d89e669a92a22f3be3d007932ce45/1.jpg -------------------------------------------------------------------------------- /3.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cshan-github/TongueSAM/3f6e8c620e4d89e669a92a22f3be3d007932ce45/3.jpg -------------------------------------------------------------------------------- /4.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cshan-github/TongueSAM/3f6e8c620e4d89e669a92a22f3be3d007932ce45/4.jpg -------------------------------------------------------------------------------- /LICENSE.md: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) [2023] [Shan Cao] 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # TongueSAM: An Universal Tongue Segmentation Model Based on SAM with Zero-Shot 2 | This is the public repository for the paper "TongueSAM: An Universal Tongue Segmentation Model Based on SAM with Zero-Shot". The paper is available at https://arxiv.org/abs/2308.06444. 3 | 4 | ## Abstract 5 | 6 | Tongue segmentation serves as the primary step in automated TCM tongue diagnosis and plays a significant role in the diagnostic results. Currently, numerous deep-learning-based methods have achieved promising results. However, most of these methods exhibit mediocre performance on tongue images that differ from the training set. To address this issue, this paper proposes a universal tongue segmentation model named TongueSAM based on SAM (Segment Anything Model).
SAM is a large-scale pretrained interactive segmentation model known for its powerful zero-shot generalization capability. Applying SAM to tongue segmentation enables the segmentation of various types of tongue images in a zero-shot manner. In this study, a Prompt Generator based on object detection 7 | is integrated into SAM to enable an end-to-end automated tongue segmentation method. Experiments demonstrate that TongueSAM achieves exceptional performance across a variety of tongue segmentation datasets, particularly in the zero-shot setting. TongueSAM can be directly applied to other datasets without fine-tuning. To the best of our knowledge, this is the first application of a large-scale pretrained model to tongue segmentation. 8 | 9 | ## Method 10 | 11 | TongueSAM consists primarily of two components: SAM and the Prompt Generator. For a given tongue image, TongueSAM first uses the pretrained Image Encoder in SAM to compute an image embedding. Meanwhile, the Prompt Generator produces a bounding-box prompt from the tongue image. Finally, the image embedding and the prompt are jointly fed into the Mask Decoder to generate the segmentation result; a minimal code sketch of this pipeline is given below. The entire segmentation process is end-to-end and does not require any additional manual prompts. The following sections introduce the different components of TongueSAM. 12 | 13 |
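A minimal code sketch of this pipeline, condensed from the inference logic in ```predict.py``` in this repository. The checkpoint path, device, and the 400x400 working resolution follow the repository defaults; treat them as assumptions to adapt to your own setup.

```python
# Minimal sketch of the TongueSAM pipeline, condensed from predict.py.
# Checkpoint path, device, and the 400x400 input size are repository defaults (assumptions).
import numpy as np
import torch
from skimage import io, transform

from segment_anything import sam_model_registry
from segment_anything.utils.transforms import ResizeLongestSide
from segment.yolox import YOLOX

device = 'cuda:0'
sam_model = sam_model_registry['vit_b'](checkpoint='./pretrained_model/tonguesam.pth').to(device)
sam_model.eval()
prompt_generator = YOLOX()  # Prompt Generator: YOLOX-based tongue detector

image = transform.resize(io.imread('./data/test_in/1.jpg'), (400, 400),
                         order=3, preserve_range=True, anti_aliasing=True).astype(np.uint8)

with torch.no_grad():
    # 1. Image Encoder: resize to SAM's input size and compute the image embedding.
    sam_trans = ResizeLongestSide(sam_model.image_encoder.img_size)
    resized = sam_trans.apply_image(image)
    input_image = sam_model.preprocess(
        torch.as_tensor(resized.transpose(2, 0, 1))[None].to(device))
    image_embedding = sam_model.image_encoder(input_image)

    # 2. Prompt Generator: predict a bounding box [x_min, y_min, x_max, y_max] around the tongue.
    #    (predict.py additionally handles the case where no box is detected.)
    box = prompt_generator.get_prompt(image)
    box_torch = torch.as_tensor(sam_trans.apply_boxes(box, (400, 400)),
                                dtype=torch.float, device=device)
    sparse_emb, dense_emb = sam_model.prompt_encoder(points=None, boxes=box_torch, masks=None)

    # 3. Mask Decoder: fuse the image embedding and the box prompt into a segmentation mask.
    mask_prob, _ = sam_model.mask_decoder(
        image_embeddings=image_embedding,
        image_pe=sam_model.prompt_encoder.get_dense_pe(),
        sparse_prompt_embeddings=sparse_emb,
        dense_prompt_embeddings=dense_emb,
        multimask_output=False)
    mask = (mask_prob.squeeze().cpu().numpy() > 0.5).astype(np.uint8)
```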

14 | The model structure of TongueSAM. 15 | 16 | 17 | ## Result 18 | 19 |

20 | Segmentation results of TongueSAM. 21 | 22 | ## DataSet 23 | 24 | In our experiments, we used three tongue image segmentation datasets: TongueSet1, TongueSet2 (BioHit), and TongueSet3. TongueSet1 cannot be made public at the moment due to privacy concerns. [TongueSet2](https://github.com/BioHit/TongeImageDataset) has already been made public. We are now releasing TongueSet3 [here](https://pan.baidu.com/s/1TCcbwMYraSPzWeI60EME0A?pwd=ttm4). 25 | 26 | TongueSet3 is a dataset we compiled by selecting 1000 tongue images from this [website](https://aistudio.baidu.com/datasetdetail/196398) and manually segmenting them with the [Labelme](https://github.com/wkentaro/labelme) tool. The dataset encompasses a wide range of tongue images from various sources, including images captured with mobile devices and from non-standard angles. To our knowledge, this is the first publicly available tongue image segmentation dataset collected in an unconstrained environment. The original tongue images from the website vary in size; to ensure input consistency, we resized each image to [400, 400] pixels. In the released files, the "img" folder contains the original input tongue images, and the "gt" folder contains our manually annotated ground-truth segmentations. **It's important to note that the images in the "gt" folder may appear completely black: pixels with a value of [1, 1, 1] represent the tongue region, while pixels with a value of [0, 0, 0] represent the background. Please be mindful of this distinction.** A short snippet for loading and checking these masks is given below. 27 | 28 |
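Because the masks store 1 for tongue pixels and 0 for background, they look black in an ordinary image viewer. Below is a minimal snippet for loading one mask and scaling it for visual inspection; the file name ```gt/1.png``` is only a placeholder.

```python
# Load a TongueSet3 ground-truth mask and make it viewable.
# "gt/1.png" is a placeholder file name; masks store 1 for tongue, 0 for background.
import numpy as np
from skimage import io

gt = io.imread('gt/1.png')
if gt.ndim == 3:                      # some masks are saved as 3-channel images
    gt = gt[:, :, 0]
binary = (gt == 1).astype(np.uint8)   # 1 = tongue region, 0 = background
io.imsave('gt_visible.png', binary * 255)   # scale to 0/255 so the tongue is visible
print('tongue pixels:', int(binary.sum()))
```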

29 | Example tongue images from TongueSet3. 30 | 31 | ## Project Description 32 | 33 | **1. Zero-Shot Segmentation** 34 | 35 | The most crucial capability of TongueSAM lies in its zero-shot segmentation. To facilitate user adoption, we fine-tuned TongueSAM on the three datasets mentioned in the paper and openly released the pre-trained model. Users can segment tongue images directly with TongueSAM in just a few straightforward steps. 36 | 37 | Download the pre-trained weights: [TongueSAM](https://pan.baidu.com/s/1zG0jpYshlBs3lcdy4F37dQ?pwd=xtfg) 38 | 39 | Put ```tonguesam.pth``` into the ```./pretrained_model/``` folder. 40 | 41 | Place the tongue images to be segmented into the ```./data/test_in/``` folder. 42 | 43 | Run ```predict.py``` 44 | 45 | The segmented tongue images will be written to the ```./data/test_out/``` folder. 46 | 47 | **2. Fine-tuning** 48 | 49 | If you wish to further fine-tune the model, please follow these steps: 50 | 51 | To train the YOLOX-based Prompt Generator, please refer to the following guidelines: [YOLOX](https://github.com/bubbliiiing/yolox-pytorch) 52 | 53 | Replace the pre-trained model ```./segment/yolox.pth``` with your trained model. 54 | 55 | Run ```split.py``` twice, setting ```src_folder``` to your image folder and your ground-truth folder respectively. 56 | 57 | Run ```pre_tongue.py```, setting ```img_path``` and ```gt_path``` to the processed folder paths; a sketch of the bounding-box prompts it precomputes is shown below. For other parameter settings, refer to [MedSAM](https://github.com/bowang-lab/MedSAM). 58 | 59 | Run ```train.py```; for details, please refer to the following guidelines: [MedSAM](https://github.com/bowang-lab/MedSAM) 60 | 61 | ## Acknowledgements 62 | 63 | This project is based on [YOLOX](https://github.com/bubbliiiing/yolox-pytorch) and [MedSAM](https://github.com/bowang-lab/MedSAM), and we appreciate their contributions. 64 | 65 | ## License 66 | 67 | This project is licensed under the [MIT LICENSE](https://github.com/cshan-github/TongueSAM/blob/main/LICENSE.md).
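As referenced in the fine-tuning steps above, ```pre_tongue.py``` derives one bounding-box prompt per training image directly from its ground-truth mask while building the ```.npz``` files. The core of that step is extracted here as a stand-alone sketch; the mask below is a synthetic placeholder, not data from the repository.

```python
# Derive a bounding-box prompt [x_min, y_min, x_max, y_max] from a binary
# ground-truth mask, as done in pre_tongue.py when building the .npz training files.
import numpy as np

def box_from_mask(gt: np.ndarray) -> np.ndarray:
    """gt: 2-D array with 1 for tongue pixels and 0 for background."""
    y_indices, x_indices = np.where(gt > 0)
    x_min, x_max = np.min(x_indices), np.max(x_indices)
    y_min, y_max = np.min(y_indices), np.max(y_indices)
    return np.array([x_min, y_min, x_max, y_max])

# Placeholder example: a 400x400 mask with a rectangular "tongue" region.
gt = np.zeros((400, 400), dtype=np.uint8)
gt[120:300, 100:280] = 1
print(box_from_mask(gt))   # -> [100 120 279 299]
```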
68 | 69 | -------------------------------------------------------------------------------- /__pycache__/utils_metrics.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cshan-github/TongueSAM/3f6e8c620e4d89e669a92a22f3be3d007932ce45/__pycache__/utils_metrics.cpython-39.pyc -------------------------------------------------------------------------------- /data/test_in/1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cshan-github/TongueSAM/3f6e8c620e4d89e669a92a22f3be3d007932ce45/data/test_in/1.jpg -------------------------------------------------------------------------------- /data/test_in/10.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cshan-github/TongueSAM/3f6e8c620e4d89e669a92a22f3be3d007932ce45/data/test_in/10.jpg -------------------------------------------------------------------------------- /data/test_in/100.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cshan-github/TongueSAM/3f6e8c620e4d89e669a92a22f3be3d007932ce45/data/test_in/100.jpg -------------------------------------------------------------------------------- /data/test_in/11.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cshan-github/TongueSAM/3f6e8c620e4d89e669a92a22f3be3d007932ce45/data/test_in/11.jpg -------------------------------------------------------------------------------- /data/test_in/12.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cshan-github/TongueSAM/3f6e8c620e4d89e669a92a22f3be3d007932ce45/data/test_in/12.jpg -------------------------------------------------------------------------------- /data/test_in/13.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cshan-github/TongueSAM/3f6e8c620e4d89e669a92a22f3be3d007932ce45/data/test_in/13.jpg -------------------------------------------------------------------------------- /data/test_in/14.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cshan-github/TongueSAM/3f6e8c620e4d89e669a92a22f3be3d007932ce45/data/test_in/14.jpg -------------------------------------------------------------------------------- /data/test_in/15.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cshan-github/TongueSAM/3f6e8c620e4d89e669a92a22f3be3d007932ce45/data/test_in/15.jpg -------------------------------------------------------------------------------- /data/test_in/16.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cshan-github/TongueSAM/3f6e8c620e4d89e669a92a22f3be3d007932ce45/data/test_in/16.jpg -------------------------------------------------------------------------------- /data/test_in/17.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cshan-github/TongueSAM/3f6e8c620e4d89e669a92a22f3be3d007932ce45/data/test_in/17.jpg -------------------------------------------------------------------------------- /data/test_in/18.jpg: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/cshan-github/TongueSAM/3f6e8c620e4d89e669a92a22f3be3d007932ce45/data/test_in/18.jpg -------------------------------------------------------------------------------- /data/test_in/19.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cshan-github/TongueSAM/3f6e8c620e4d89e669a92a22f3be3d007932ce45/data/test_in/19.jpg -------------------------------------------------------------------------------- /data/test_out/1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cshan-github/TongueSAM/3f6e8c620e4d89e669a92a22f3be3d007932ce45/data/test_out/1.png -------------------------------------------------------------------------------- /data/test_out/10.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cshan-github/TongueSAM/3f6e8c620e4d89e669a92a22f3be3d007932ce45/data/test_out/10.png -------------------------------------------------------------------------------- /data/test_out/100.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cshan-github/TongueSAM/3f6e8c620e4d89e669a92a22f3be3d007932ce45/data/test_out/100.png -------------------------------------------------------------------------------- /data/test_out/11.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cshan-github/TongueSAM/3f6e8c620e4d89e669a92a22f3be3d007932ce45/data/test_out/11.png -------------------------------------------------------------------------------- /data/test_out/12.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cshan-github/TongueSAM/3f6e8c620e4d89e669a92a22f3be3d007932ce45/data/test_out/12.png -------------------------------------------------------------------------------- /data/test_out/13.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cshan-github/TongueSAM/3f6e8c620e4d89e669a92a22f3be3d007932ce45/data/test_out/13.png -------------------------------------------------------------------------------- /data/test_out/14.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cshan-github/TongueSAM/3f6e8c620e4d89e669a92a22f3be3d007932ce45/data/test_out/14.png -------------------------------------------------------------------------------- /data/test_out/15.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cshan-github/TongueSAM/3f6e8c620e4d89e669a92a22f3be3d007932ce45/data/test_out/15.png -------------------------------------------------------------------------------- /data/test_out/16.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cshan-github/TongueSAM/3f6e8c620e4d89e669a92a22f3be3d007932ce45/data/test_out/16.png -------------------------------------------------------------------------------- /data/test_out/17.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cshan-github/TongueSAM/3f6e8c620e4d89e669a92a22f3be3d007932ce45/data/test_out/17.png 
-------------------------------------------------------------------------------- /data/test_out/18.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cshan-github/TongueSAM/3f6e8c620e4d89e669a92a22f3be3d007932ce45/data/test_out/18.png -------------------------------------------------------------------------------- /data/test_out/19.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cshan-github/TongueSAM/3f6e8c620e4d89e669a92a22f3be3d007932ce45/data/test_out/19.png -------------------------------------------------------------------------------- /pre_tongue.py: -------------------------------------------------------------------------------- 1 | #%% import packages 2 | import numpy as np 3 | import os 4 | join = os.path.join 5 | from skimage import transform, io, segmentation,util,exposure 6 | from tqdm import tqdm 7 | import torch 8 | from segment_anything import sam_model_registry 9 | from segment_anything.utils.transforms import ResizeLongestSide 10 | import argparse 11 | import torch.nn as nn 12 | import cv2 13 | 14 | from PIL import Image 15 | from tqdm import tqdm 16 | # from segment.nets.deeplabv3_plus import DeepLab 17 | from segment.deeplab import DeeplabV3 18 | 19 | # set up the parser 20 | parser = argparse.ArgumentParser(description='preprocess grey and RGB images') 21 | parser.add_argument('-i', '--img_path', type=str, default='/home/disk/cs/project/dataset/segmentation/tongueset3_split/img/', help='path to the images') 22 | parser.add_argument('-gt', '--gt_path', type=str, default='/home/disk/cs/project/dataset/segmentation/tongueset3_split/gt/', help='path to the ground truth (gt)') 23 | parser.add_argument('-o', '--npz_path', type=str, default='/home/disk/cs/project/dataset/segmentation/tongueset3_npz2/', help='path to save the npz files') 24 | parser.add_argument('--data_name', type=str, default='tongue', help='dataset name; used to name the final npz file, e.g., demo2d.npz') 25 | parser.add_argument('--image_size', type=int, default=400, help='image size') 26 | parser.add_argument('--img_name_suffix', type=str, default='.jpg', help='image name suffix') 27 | parser.add_argument('--label_id', type=int, default=1, help='label id') 28 | parser.add_argument('--model_type', type=str, default='vit_b', help='model type') 29 | parser.add_argument('--checkpoint', type=str, default='./pretrained_model/sam.pth', help='checkpoint') 30 | parser.add_argument('--device', type=str, default='cuda:0', help='device') 31 | parser.add_argument('--seed', type=int, default=2023, help='random seed') 32 | args = parser.parse_args() 33 | 34 | 35 | def semantic_segmentation_augmentation(image, label, rotation_range=(-90, 90)): 36 | stretch_range=(0.8, 1.2) 37 | # 随机裁剪 38 | h, w, _ = image.shape 39 | top = np.random.randint(0, h - 300) 40 | left = np.random.randint(0, w - 300) 41 | image = image[top:top+300, left:left+300] 42 | label = label[top:top+300, left:left+300] 43 | 44 | # 随机水平翻转 45 | if np.random.rand() > 0.5: 46 | image = np.fliplr(image) 47 | label = np.fliplr(label) 48 | 49 | # 随机旋转 50 | rotation_angle = np.random.uniform(rotation_range[0], rotation_range[1]) 51 | image = rotate_image(image, rotation_angle) 52 | label = rotate_image(label, rotation_angle) 53 | 54 | # 随机颜色抖动(可选) 55 | image = augment_colors(image) 56 | # 随机拉伸 57 | # stretch_factor_x = np.random.uniform(stretch_range[0], stretch_range[1]) 58 | # stretch_factor_y = 
np.random.uniform(stretch_range[0], stretch_range[1]) 59 | # image = stretch_image(image, stretch_factor_x, stretch_factor_y) 60 | # label = stretch_image(label, stretch_factor_x, stretch_factor_y) 61 | 62 | return image, label 63 | 64 | def augment_colors(image): 65 | # 随机生成对比度和亮度的增益 66 | contrast_factor = np.random.uniform(0.5, 1.5) # 可以调整范围以获得所需的效果 67 | brightness_factor = np.random.randint(-50, 51) # 亮度增益的范围 68 | 69 | # 修改对比度和亮度 70 | image = cv2.convertScaleAbs(image, alpha=contrast_factor, beta=brightness_factor) 71 | 72 | 73 | return image 74 | 75 | def rotate_image(image, angle): 76 | # 旋转图像 77 | image = transform.rotate(image, angle, resize=False, mode='reflect') 78 | return util.img_as_ubyte(image) 79 | 80 | 81 | def deal(img_path,gt_path,num): 82 | names = sorted(os.listdir(gt_path)) 83 | save_path = args.npz_path 84 | os.makedirs(save_path, exist_ok=True) 85 | print('image number:', len(names)) 86 | imgs = [] 87 | gts = [] 88 | boxes=[] 89 | img_embeddings = [] 90 | for gt_name in tqdm(names): 91 | sam_model = sam_model_registry[args.model_type](checkpoint=args.checkpoint).to(args.device) 92 | image_name = gt_name.split('.')[0] + args.img_name_suffix 93 | gt_data = io.imread(join(gt_path, gt_name)) 94 | image_data = io.imread(join(img_path, image_name)) 95 | image_data,gt_data=semantic_segmentation_augmentation(image_data, gt_data) 96 | # cv2.imwrite(gt_name,image_data) 97 | gt_data=cv2.resize(gt_data,(args.image_size,args.image_size)) 98 | if len(gt_data.shape)==3: 99 | gt_data = gt_data[:,:,0] 100 | assert len(gt_data.shape)==2, 'ground truth should be 2D' 101 | gt_data = transform.resize(gt_data==args.label_id, (args.image_size, args.image_size), order=0, preserve_range=True, mode='constant') 102 | gt_data = np.uint8(gt_data) 103 | 104 | if image_data.shape[-1]>3 and len(image_data.shape)==3: 105 | image_data = image_data[:,:,:3] 106 | if len(image_data.shape)==2: 107 | image_data = np.repeat(image_data[:,:,None], 3, axis=-1) 108 | if gt_data.shape[-1]==3: 109 | gt=gt_data 110 | z=np.zeros([gt.shape[0],gt.shape[1]]) 111 | for i in range(gt.shape[0]): 112 | for j in range(gt.shape[1]): 113 | if gt[i][j][0]==1: 114 | z[i][j]=1 115 | gt=z 116 | gt_data=gt 117 | 118 | lower_bound, upper_bound = np.percentile(image_data, 0.5), np.percentile(image_data, 99.5) 119 | image_data_pre = np.clip(image_data, lower_bound, upper_bound) 120 | image_data_pre = (image_data_pre - np.min(image_data_pre))/(np.max(image_data_pre)-np.min(image_data_pre))*255.0 121 | image_data_pre[image_data==0] = 0 122 | image_data_pre = transform.resize(image_data_pre, (args.image_size,args.image_size), order=3, preserve_range=True, mode='constant', anti_aliasing=True) 123 | image_data_pre = np.uint8(image_data_pre) 124 | imgs.append(image_data_pre) 125 | gts.append(gt_data) 126 | H, W, _ = image_data_pre.shape 127 | sam_transform = ResizeLongestSide(sam_model.image_encoder.img_size) 128 | resize_img = sam_transform.apply_image(image_data_pre) 129 | resize_img_tensor = torch.as_tensor(resize_img.transpose(2, 0, 1)).to(args.device) 130 | input_image = sam_model.preprocess(resize_img_tensor[None,:,:,:]) # (1, 3, 1024, 1024) 131 | assert input_image.shape == (1, 3, sam_model.image_encoder.img_size, sam_model.image_encoder.img_size), 'input image should be resized to 1024*1024' 132 | embedding = sam_model.image_encoder(input_image) 133 | img_embeddings.append(embedding.cpu().detach().numpy()[0]) 134 | ########################################################################################## 135 | # 
pic=Image.open(join(img_path, image_name)) 136 | # pic=Image.fromarray(image_data_pre) 137 | # pic= model.get_miou_png(pic) 138 | y_indices, x_indices = np.where(gt_data > 0) 139 | xmin, xmax = np.min(x_indices), np.max(x_indices) 140 | ymin, ymax = np.min(y_indices), np.max(y_indices) 141 | box=np.array([xmin,ymin,xmax,ymax]) 142 | boxes.append(box) 143 | ########################################################################################## 144 | del sam_model 145 | print('Num. of images:', len(imgs)) 146 | if len(imgs)>0: 147 | imgs = np.stack(imgs, axis=0) # (n, 256, 256, 3) 148 | gts = np.stack(gts, axis=0) # (n, 256, 256) 149 | img_embeddings = np.stack(img_embeddings, axis=0) # (n, 1, 256, 64, 64) 150 | np.savez_compressed(join(save_path, args.data_name + str(num)+'.npz'), imgs=imgs, boxes=boxes,gts=gts, img_embeddings=img_embeddings) 151 | # save an example image for sanity check 152 | idx = np.random.randint(imgs.shape[0]) 153 | img_idx = imgs[idx,:,:,:] 154 | gt_idx = gts[idx,:,:] 155 | bd = segmentation.find_boundaries(gt_idx, mode='inner') 156 | img_idx[bd, :] = [args.image_size-1, 0, 0] 157 | # io.imsave(save_path + '.png', img_idx, check_contrast=False) 158 | else: 159 | print('Do not find image and ground-truth pairs. Please check your dataset and argument settings') 160 | 161 | num=0 162 | # for i in range(10): 163 | for f in os.listdir(args.img_path): 164 | print(f) 165 | # deal(args.img_path+'/'+f,args.gt_path+'/'+f,num) 166 | num+=1 167 | -------------------------------------------------------------------------------- /predict.py: -------------------------------------------------------------------------------- 1 | from PIL import ImageDraw 2 | import numpy as np 3 | import matplotlib.pyplot as plt 4 | import os 5 | from skimage import io 6 | join = os.path.join 7 | from tqdm import tqdm 8 | import torch 9 | from torch.utils.data import Dataset, DataLoader 10 | import monai 11 | from segment_anything import SamPredictor, sam_model_registry 12 | from segment_anything.utils.transforms import ResizeLongestSide 13 | from utils.SurfaceDice import compute_dice_coefficient 14 | import cv2 15 | from sklearn.metrics import accuracy_score, confusion_matrix, precision_score, recall_score, f1_score, jaccard_score 16 | # set seeds 17 | torch.manual_seed(2023) 18 | np.random.seed(2023) 19 | from skimage import io 20 | from utils_metrics import * 21 | from skimage import transform, io, segmentation 22 | from segment.yolox import YOLOX 23 | import random 24 | import warnings 25 | 26 | # 永久性地忽略指定类型的警告 27 | warnings.filterwarnings("ignore", category=UserWarning) 28 | ######################################################################################################### 29 | ts_img_path = './data/test_in/' 30 | model_type = 'vit_b' 31 | checkpoint = './pretrained_model/tonguesam.pth' 32 | device = 'cuda:1' 33 | path_out='./data/test_out/' 34 | segment=YOLOX() 35 | ############################################################################################################## 36 | def get_bbox_from_mask(mask): 37 | '''Returns a bounding box from a mask''' 38 | y_indices, x_indices = np.where(mask > 0) 39 | x_min, x_max = np.min(x_indices), np.max(x_indices) 40 | y_min, y_max = np.min(y_indices), np.max(y_indices) 41 | # add perturbation to bounding box coordinates 42 | H, W = mask.shape 43 | x_min = max(0, x_min - np.random.randint(0, 20)) 44 | x_max = min(W, x_max + np.random.randint(0, 20)) 45 | y_min = max(0, y_min - np.random.randint(0, 20)) 46 | y_max = min(H, y_max + 
np.random.randint(0, 20)) 47 | 48 | return np.array([x_min, y_min, x_max, y_max]) 49 | def show_mask(mask, ax, random_color=False): 50 | if random_color: 51 | color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0) 52 | else: 53 | color = np.array([251/255, 252/255, 30/255, 0.6]) 54 | h, w = mask.shape[-2:] 55 | mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1) 56 | ax.imshow(mask_image) 57 | 58 | def show_box(box, ax): 59 | x0, y0 = box[0], box[1] 60 | w, h = box[2] - box[0], box[3] - box[1] 61 | ax.add_patch(plt.Rectangle((x0, y0), w, h, edgecolor='blue', facecolor=(0,0,0,0), lw=2)) 62 | best_iou=0 63 | test_names = sorted(os.listdir(ts_img_path)) 64 | 65 | sam_model = sam_model_registry[model_type](checkpoint=checkpoint).to(device) 66 | ################################# 67 | # prune_threshold =0.005 68 | # for param in sam_model.image_encoder.parameters(): 69 | # param.data[torch.abs(param.data) < prune_threshold] = 0 70 | sam_model.eval() 71 | ################################# 72 | val_gts=[] 73 | val_preds=[] 74 | for f in os.listdir(ts_img_path): 75 | with torch.no_grad(): 76 | image_data = io.imread(join(ts_img_path, f)) 77 | 78 | if image_data.shape[-1] > 3 and len(image_data.shape) == 3: 79 | image_data = image_data[:, :, :3] 80 | if len(image_data.shape) == 2: 81 | image_data = np.repeat(image_data[:, :, None], 3, axis=-1) 82 | 83 | lower_bound, upper_bound = np.percentile(image_data, 0.5), np.percentile(image_data, 99.5) 84 | image_data_pre = np.clip(image_data, lower_bound, upper_bound) 85 | image_data_pre = (image_data_pre - np.min(image_data_pre)) / (np.max(image_data_pre) - np.min(image_data_pre)) * 255.0 86 | image_data_pre[image_data == 0] = 0 87 | image_data_pre = transform.resize(image_data_pre, (400, 400), order=3, preserve_range=True, mode='constant', anti_aliasing=True) 88 | image_data_pre = np.uint8(image_data_pre) 89 | 90 | H, W, _ = image_data_pre.shape 91 | sam_transform = ResizeLongestSide(sam_model.image_encoder.img_size) 92 | resize_img = sam_transform.apply_image(image_data_pre) 93 | resize_img_tensor = torch.as_tensor(resize_img.transpose(2, 0, 1)).to(device) 94 | input_image = sam_model.preprocess(resize_img_tensor[None, :, :, :]) 95 | ts_img_embedding = sam_model.image_encoder(input_image) 96 | 97 | img = image_data_pre 98 | boxes = segment.get_prompt(img) 99 | 100 | if boxes is not None: 101 | sam_trans = ResizeLongestSide(sam_model.image_encoder.img_size) 102 | box = sam_trans.apply_boxes(boxes, (400,400)) 103 | box_torch = torch.as_tensor(box, dtype=torch.float, device=device) 104 | else: 105 | box_torch = None 106 | sparse_embeddings, dense_embeddings = sam_model.prompt_encoder( 107 | points=None, 108 | boxes=box_torch, 109 | masks=None, 110 | ) 111 | 112 | # 使用Mask_Decoder生成分割结果 113 | medsam_seg_prob, _ = sam_model.mask_decoder( 114 | image_embeddings=ts_img_embedding.to(device), 115 | image_pe=sam_model.prompt_encoder.get_dense_pe(), 116 | sparse_prompt_embeddings=sparse_embeddings, 117 | dense_prompt_embeddings=dense_embeddings, 118 | multimask_output=False, 119 | ) 120 | medsam_seg_prob =medsam_seg_prob.cpu().detach().numpy().squeeze() 121 | medsam_seg = (medsam_seg_prob > 0.5).astype(np.uint8) 122 | 123 | medsam_seg=cv2.resize(medsam_seg,(400,400)) 124 | 125 | 126 | pred = cv2.Canny(cv2.resize((medsam_seg != 0).astype(np.uint8) * 255, (400, 400)), 100, 200) 127 | 128 | for i in range(pred.shape[0]): 129 | for j in range(pred.shape[1]): 130 | if pred[i, j] != 0: 131 | img[max(i - 1, 0):min(i + 2, 400), max(j - 1, 0):min(j + 
2, 400), :] = [0, 0, 255] 132 | 133 | image1 = Image.fromarray(medsam_seg) 134 | image2 = Image.fromarray(img) 135 | 136 | image1 = image1.resize(image2.size).convert("RGBA") 137 | image2 = image2.convert("RGBA") 138 | data1 = image1.getdata() 139 | 140 | new_image = Image.new("RGBA", image2.size) 141 | new_data = [(0, 0, 128, 96) if pixel1[0] != 0 else (0, 0, 0, 0) for pixel1 in data1] 142 | 143 | new_image.putdata(new_data) 144 | if boxes is not None: 145 | draw = ImageDraw.Draw(image2) 146 | draw.rectangle([boxes[0],boxes[1],boxes[2],boxes[3]],fill=None, outline=(0, 255, 0), width=5) # 用红色绘制方框的边框,线宽为2 147 | image2.paste(new_image, (0, 0), mask=new_image) 148 | image2.save(path_out + f.split('.')[0] + '.png') 149 | print(f) 150 | -------------------------------------------------------------------------------- /pretrained_model/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cshan-github/TongueSAM/3f6e8c620e4d89e669a92a22f3be3d007932ce45/pretrained_model/.DS_Store -------------------------------------------------------------------------------- /segment/__pycache__/deeplab.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cshan-github/TongueSAM/3f6e8c620e4d89e669a92a22f3be3d007932ce45/segment/__pycache__/deeplab.cpython-39.pyc -------------------------------------------------------------------------------- /segment/__pycache__/pspnet.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cshan-github/TongueSAM/3f6e8c620e4d89e669a92a22f3be3d007932ce45/segment/__pycache__/pspnet.cpython-39.pyc -------------------------------------------------------------------------------- /segment/__pycache__/unet.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cshan-github/TongueSAM/3f6e8c620e4d89e669a92a22f3be3d007932ce45/segment/__pycache__/unet.cpython-39.pyc -------------------------------------------------------------------------------- /segment/__pycache__/yolox.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cshan-github/TongueSAM/3f6e8c620e4d89e669a92a22f3be3d007932ce45/segment/__pycache__/yolox.cpython-39.pyc -------------------------------------------------------------------------------- /segment/__pycache__/yolox_new.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cshan-github/TongueSAM/3f6e8c620e4d89e669a92a22f3be3d007932ce45/segment/__pycache__/yolox_new.cpython-39.pyc -------------------------------------------------------------------------------- /segment/tongue_classes.txt: -------------------------------------------------------------------------------- 1 | tongue -------------------------------------------------------------------------------- /segment/utils_yolox/__init__.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------------------------------- /segment/utils_yolox/__pycache__/__init__.cpython-310.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/cshan-github/TongueSAM/3f6e8c620e4d89e669a92a22f3be3d007932ce45/segment/utils_yolox/__pycache__/__init__.cpython-310.pyc -------------------------------------------------------------------------------- /segment/utils_yolox/__pycache__/__init__.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cshan-github/TongueSAM/3f6e8c620e4d89e669a92a22f3be3d007932ce45/segment/utils_yolox/__pycache__/__init__.cpython-39.pyc -------------------------------------------------------------------------------- /segment/utils_yolox/__pycache__/callbacks.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cshan-github/TongueSAM/3f6e8c620e4d89e669a92a22f3be3d007932ce45/segment/utils_yolox/__pycache__/callbacks.cpython-310.pyc -------------------------------------------------------------------------------- /segment/utils_yolox/__pycache__/callbacks.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cshan-github/TongueSAM/3f6e8c620e4d89e669a92a22f3be3d007932ce45/segment/utils_yolox/__pycache__/callbacks.cpython-39.pyc -------------------------------------------------------------------------------- /segment/utils_yolox/__pycache__/dataloader.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cshan-github/TongueSAM/3f6e8c620e4d89e669a92a22f3be3d007932ce45/segment/utils_yolox/__pycache__/dataloader.cpython-310.pyc -------------------------------------------------------------------------------- /segment/utils_yolox/__pycache__/dataloader.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cshan-github/TongueSAM/3f6e8c620e4d89e669a92a22f3be3d007932ce45/segment/utils_yolox/__pycache__/dataloader.cpython-39.pyc -------------------------------------------------------------------------------- /segment/utils_yolox/__pycache__/utils.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cshan-github/TongueSAM/3f6e8c620e4d89e669a92a22f3be3d007932ce45/segment/utils_yolox/__pycache__/utils.cpython-310.pyc -------------------------------------------------------------------------------- /segment/utils_yolox/__pycache__/utils.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cshan-github/TongueSAM/3f6e8c620e4d89e669a92a22f3be3d007932ce45/segment/utils_yolox/__pycache__/utils.cpython-39.pyc -------------------------------------------------------------------------------- /segment/utils_yolox/__pycache__/utils_bbox.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cshan-github/TongueSAM/3f6e8c620e4d89e669a92a22f3be3d007932ce45/segment/utils_yolox/__pycache__/utils_bbox.cpython-310.pyc -------------------------------------------------------------------------------- /segment/utils_yolox/__pycache__/utils_bbox.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cshan-github/TongueSAM/3f6e8c620e4d89e669a92a22f3be3d007932ce45/segment/utils_yolox/__pycache__/utils_bbox.cpython-39.pyc 
-------------------------------------------------------------------------------- /segment/utils_yolox/__pycache__/utils_fit.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cshan-github/TongueSAM/3f6e8c620e4d89e669a92a22f3be3d007932ce45/segment/utils_yolox/__pycache__/utils_fit.cpython-310.pyc -------------------------------------------------------------------------------- /segment/utils_yolox/__pycache__/utils_fit.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cshan-github/TongueSAM/3f6e8c620e4d89e669a92a22f3be3d007932ce45/segment/utils_yolox/__pycache__/utils_fit.cpython-39.pyc -------------------------------------------------------------------------------- /segment/utils_yolox/__pycache__/utils_map.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cshan-github/TongueSAM/3f6e8c620e4d89e669a92a22f3be3d007932ce45/segment/utils_yolox/__pycache__/utils_map.cpython-310.pyc -------------------------------------------------------------------------------- /segment/utils_yolox/__pycache__/utils_map.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cshan-github/TongueSAM/3f6e8c620e4d89e669a92a22f3be3d007932ce45/segment/utils_yolox/__pycache__/utils_map.cpython-39.pyc -------------------------------------------------------------------------------- /segment/utils_yolox/callbacks.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import torch 4 | import matplotlib 5 | matplotlib.use('Agg') 6 | import scipy.signal 7 | from matplotlib import pyplot as plt 8 | from torch.utils.tensorboard import SummaryWriter 9 | 10 | import shutil 11 | import numpy as np 12 | 13 | from PIL import Image 14 | from tqdm import tqdm 15 | from .utils import cvtColor, preprocess_input, resize_image 16 | from .utils_bbox import decode_outputs, non_max_suppression 17 | from .utils_map import get_coco_map, get_map 18 | 19 | 20 | class LossHistory(): 21 | def __init__(self, log_dir, model, input_shape): 22 | self.log_dir = log_dir 23 | self.losses = [] 24 | self.val_loss = [] 25 | 26 | os.makedirs(self.log_dir) 27 | self.writer = SummaryWriter(self.log_dir) 28 | try: 29 | dummy_input = torch.randn(2, 3, input_shape[0], input_shape[1]) 30 | self.writer.add_graph(model, dummy_input) 31 | except: 32 | pass 33 | 34 | def append_loss(self, epoch, loss, val_loss): 35 | if not os.path.exists(self.log_dir): 36 | os.makedirs(self.log_dir) 37 | 38 | self.losses.append(loss) 39 | self.val_loss.append(val_loss) 40 | 41 | with open(os.path.join(self.log_dir, "epoch_loss.txt"), 'a') as f: 42 | f.write(str(loss)) 43 | f.write("\n") 44 | with open(os.path.join(self.log_dir, "epoch_val_loss.txt"), 'a') as f: 45 | f.write(str(val_loss)) 46 | f.write("\n") 47 | 48 | self.writer.add_scalar('loss', loss, epoch) 49 | self.writer.add_scalar('val_loss', val_loss, epoch) 50 | self.loss_plot() 51 | 52 | def loss_plot(self): 53 | iters = range(len(self.losses)) 54 | 55 | plt.figure() 56 | plt.plot(iters, self.losses, 'red', linewidth = 2, label='train loss') 57 | plt.plot(iters, self.val_loss, 'coral', linewidth = 2, label='val loss') 58 | try: 59 | if len(self.losses) < 25: 60 | num = 5 61 | else: 62 | num = 15 63 | 64 | plt.plot(iters, scipy.signal.savgol_filter(self.losses, 
num, 3), 'green', linestyle = '--', linewidth = 2, label='smooth train loss') 65 | plt.plot(iters, scipy.signal.savgol_filter(self.val_loss, num, 3), '#8B4513', linestyle = '--', linewidth = 2, label='smooth val loss') 66 | except: 67 | pass 68 | 69 | plt.grid(True) 70 | plt.xlabel('Epoch') 71 | plt.ylabel('Loss') 72 | plt.legend(loc="upper right") 73 | 74 | plt.savefig(os.path.join(self.log_dir, "epoch_loss.png")) 75 | 76 | plt.cla() 77 | plt.close("all") 78 | 79 | class EvalCallback(): 80 | def __init__(self, net, input_shape, class_names, num_classes, val_lines, log_dir, cuda, \ 81 | map_out_path=".temp_map_out", max_boxes=100, confidence=0.05, nms_iou=0.5, letterbox_image=True, MINOVERLAP=0.5, eval_flag=True, period=1): 82 | super(EvalCallback, self).__init__() 83 | 84 | self.net = net 85 | self.input_shape = input_shape 86 | self.class_names = class_names 87 | self.num_classes = num_classes 88 | self.val_lines = val_lines 89 | self.log_dir = log_dir 90 | self.cuda = cuda 91 | self.map_out_path = map_out_path 92 | self.max_boxes = max_boxes 93 | self.confidence = confidence 94 | self.nms_iou = nms_iou 95 | self.letterbox_image = letterbox_image 96 | self.MINOVERLAP = MINOVERLAP 97 | self.eval_flag = eval_flag 98 | self.period = period 99 | 100 | self.maps = [0] 101 | self.epoches = [0] 102 | if self.eval_flag: 103 | with open(os.path.join(self.log_dir, "epoch_map.txt"), 'a') as f: 104 | f.write(str(0)) 105 | f.write("\n") 106 | 107 | def get_map_txt(self, image_id, image, class_names, map_out_path): 108 | f = open(os.path.join(map_out_path, "detection-results/"+image_id+".txt"),"w") 109 | image_shape = np.array(np.shape(image)[0:2]) 110 | #---------------------------------------------------------# 111 | # 在这里将图像转换成RGB图像,防止灰度图在预测时报错。 112 | # 代码仅仅支持RGB图像的预测,所有其它类型的图像都会转化成RGB 113 | #---------------------------------------------------------# 114 | image = cvtColor(image) 115 | #---------------------------------------------------------# 116 | # 给图像增加灰条,实现不失真的resize 117 | # 也可以直接resize进行识别 118 | #---------------------------------------------------------# 119 | image_data = resize_image(image, (self.input_shape[1],self.input_shape[0]), self.letterbox_image) 120 | #---------------------------------------------------------# 121 | # 添加上batch_size维度 122 | #---------------------------------------------------------# 123 | image_data = np.expand_dims(np.transpose(preprocess_input(np.array(image_data, dtype='float32')), (2, 0, 1)), 0) 124 | 125 | with torch.no_grad(): 126 | images = torch.from_numpy(image_data) 127 | if self.cuda: 128 | images = images.cuda() 129 | #---------------------------------------------------------# 130 | # 将图像输入网络当中进行预测! 
131 | #---------------------------------------------------------# 132 | outputs = self.net(images) 133 | outputs = decode_outputs(outputs, self.input_shape) 134 | #---------------------------------------------------------# 135 | # 将预测框进行堆叠,然后进行非极大抑制 136 | #---------------------------------------------------------# 137 | results = non_max_suppression(outputs, self.num_classes, self.input_shape, 138 | image_shape, self.letterbox_image, conf_thres = self.confidence, nms_thres = self.nms_iou) 139 | 140 | if results[0] is None: 141 | return 142 | 143 | top_label = np.array(results[0][:, 6], dtype = 'int32') 144 | top_conf = results[0][:, 4] * results[0][:, 5] 145 | top_boxes = results[0][:, :4] 146 | 147 | top_100 = np.argsort(top_conf)[::-1][:self.max_boxes] 148 | top_boxes = top_boxes[top_100] 149 | top_conf = top_conf[top_100] 150 | top_label = top_label[top_100] 151 | 152 | for i, c in list(enumerate(top_label)): 153 | predicted_class = self.class_names[int(c)] 154 | box = top_boxes[i] 155 | score = str(top_conf[i]) 156 | 157 | top, left, bottom, right = box 158 | if predicted_class not in class_names: 159 | continue 160 | 161 | f.write("%s %s %s %s %s %s\n" % (predicted_class, score[:6], str(int(left)), str(int(top)), str(int(right)),str(int(bottom)))) 162 | 163 | f.close() 164 | return 165 | 166 | def on_epoch_end(self, epoch, model_eval): 167 | if epoch % self.period == 0 and self.eval_flag: 168 | self.net = model_eval 169 | if not os.path.exists(self.map_out_path): 170 | os.makedirs(self.map_out_path) 171 | if not os.path.exists(os.path.join(self.map_out_path, "ground-truth")): 172 | os.makedirs(os.path.join(self.map_out_path, "ground-truth")) 173 | if not os.path.exists(os.path.join(self.map_out_path, "detection-results")): 174 | os.makedirs(os.path.join(self.map_out_path, "detection-results")) 175 | print("Get map.") 176 | for annotation_line in tqdm(self.val_lines): 177 | line = annotation_line.split() 178 | image_id = os.path.basename(line[0]).split('.')[0] 179 | #------------------------------# 180 | # 读取图像并转换成RGB图像 181 | #------------------------------# 182 | image = Image.open(line[0]) 183 | #------------------------------# 184 | # 获得预测框 185 | #------------------------------# 186 | gt_boxes = np.array([np.array(list(map(int,box.split(',')))) for box in line[1:]]) 187 | #------------------------------# 188 | # 获得预测txt 189 | #------------------------------# 190 | self.get_map_txt(image_id, image, self.class_names, self.map_out_path) 191 | 192 | #------------------------------# 193 | # 获得真实框txt 194 | #------------------------------# 195 | with open(os.path.join(self.map_out_path, "ground-truth/"+image_id+".txt"), "w") as new_f: 196 | for box in gt_boxes: 197 | left, top, right, bottom, obj = box 198 | obj_name = self.class_names[obj] 199 | new_f.write("%s %s %s %s %s\n" % (obj_name, left, top, right, bottom)) 200 | 201 | print("Calculate Map.") 202 | try: 203 | temp_map = get_coco_map(class_names = self.class_names, path = self.map_out_path)[1] 204 | except: 205 | temp_map = get_map(self.MINOVERLAP, False, path = self.map_out_path) 206 | self.maps.append(temp_map) 207 | self.epoches.append(epoch) 208 | 209 | with open("./epoch_map.txt", 'a') as f: 210 | f.write(str(temp_map)) 211 | f.write("\n") 212 | 213 | plt.figure() 214 | plt.plot(self.epoches, self.maps, 'red', linewidth = 2, label='train map') 215 | 216 | plt.grid(True) 217 | plt.xlabel('Epoch') 218 | plt.ylabel('Map %s'%str(self.MINOVERLAP)) 219 | plt.title('A Map Curve') 220 | plt.legend(loc="upper right") 221 | 222 | 
plt.savefig(os.path.join(self.log_dir, "epoch_map.png")) 223 | plt.cla() 224 | plt.close("all") 225 | 226 | print("Get map done.") 227 | shutil.rmtree(self.map_out_path) 228 | -------------------------------------------------------------------------------- /segment/utils_yolox/utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from PIL import Image 3 | 4 | 5 | #---------------------------------------------------------# 6 | # 将图像转换成RGB图像,防止灰度图在预测时报错。 7 | # 代码仅仅支持RGB图像的预测,所有其它类型的图像都会转化成RGB 8 | #---------------------------------------------------------# 9 | def cvtColor(image): 10 | if len(np.shape(image)) == 3 and np.shape(image)[2] == 3: 11 | return image 12 | else: 13 | image = image.convert('RGB') 14 | return image 15 | 16 | #---------------------------------------------------# 17 | # 对输入图像进行resize 18 | #---------------------------------------------------# 19 | def resize_image(image, size, letterbox_image): 20 | iw, ih = image.size 21 | w, h = size 22 | if letterbox_image: 23 | scale = min(w/iw, h/ih) 24 | nw = int(iw*scale) 25 | nh = int(ih*scale) 26 | 27 | image = image.resize((nw,nh), Image.BICUBIC) 28 | new_image = Image.new('RGB', size, (128,128,128)) 29 | new_image.paste(image, ((w-nw)//2, (h-nh)//2)) 30 | else: 31 | new_image = image.resize((w, h), Image.BICUBIC) 32 | return new_image 33 | 34 | #---------------------------------------------------# 35 | # 获得类 36 | #---------------------------------------------------# 37 | def get_classes(classes_path): 38 | with open(classes_path, encoding='utf-8') as f: 39 | class_names = f.readlines() 40 | class_names = [c.strip() for c in class_names] 41 | return class_names, len(class_names) 42 | 43 | def preprocess_input(image): 44 | image /= 255.0 45 | image -= np.array([0.485, 0.456, 0.406]) 46 | image /= np.array([0.229, 0.224, 0.225]) 47 | return image 48 | 49 | #---------------------------------------------------# 50 | # 获得学习率 51 | #---------------------------------------------------# 52 | def get_lr(optimizer): 53 | for param_group in optimizer.param_groups: 54 | return param_group['lr'] 55 | 56 | def show_config(**kwargs): 57 | print('Configurations:') 58 | print('-' * 70) 59 | print('|%25s | %40s|' % ('keys', 'values')) 60 | print('-' * 70) 61 | for key, value in kwargs.items(): 62 | print('|%25s | %40s|' % (str(key), str(value))) 63 | print('-' * 70) 64 | -------------------------------------------------------------------------------- /segment/utils_yolox/utils_bbox.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from torchvision.ops import nms, boxes 4 | 5 | def yolo_correct_boxes(box_xy, box_wh, input_shape, image_shape, letterbox_image): 6 | #-----------------------------------------------------------------# 7 | # 把y轴放前面是因为方便预测框和图像的宽高进行相乘 8 | #-----------------------------------------------------------------# 9 | box_yx = box_xy[..., ::-1] 10 | box_hw = box_wh[..., ::-1] 11 | input_shape = np.array(input_shape) 12 | image_shape = np.array(image_shape) 13 | 14 | if letterbox_image: 15 | #-----------------------------------------------------------------# 16 | # 这里求出来的offset是图像有效区域相对于图像左上角的偏移情况 17 | # new_shape指的是宽高缩放情况 18 | #-----------------------------------------------------------------# 19 | new_shape = np.round(image_shape * np.min(input_shape/image_shape)) 20 | offset = (input_shape - new_shape)/2./input_shape 21 | scale = input_shape/new_shape 22 | 23 | box_yx = (box_yx - offset) * 
scale 24 | box_hw *= scale 25 | 26 | box_mins = box_yx - (box_hw / 2.) 27 | box_maxes = box_yx + (box_hw / 2.) 28 | boxes = np.concatenate([box_mins[..., 0:1], box_mins[..., 1:2], box_maxes[..., 0:1], box_maxes[..., 1:2]], axis=-1) 29 | boxes *= np.concatenate([image_shape, image_shape], axis=-1) 30 | return boxes 31 | 32 | def decode_outputs(outputs, input_shape): 33 | grids = [] 34 | strides = [] 35 | hw = [x.shape[-2:] for x in outputs] 36 | #---------------------------------------------------# 37 | # outputs输入前代表每个特征层的预测结果 38 | # batch_size, 4 + 1 + num_classes, 80, 80 => batch_size, 4 + 1 + num_classes, 6400 39 | # batch_size, 5 + num_classes, 40, 40 40 | # batch_size, 5 + num_classes, 20, 20 41 | # batch_size, 4 + 1 + num_classes, 6400 + 1600 + 400 -> batch_size, 4 + 1 + num_classes, 8400 42 | # 堆叠后为batch_size, 8400, 5 + num_classes 43 | #---------------------------------------------------# 44 | outputs = torch.cat([x.flatten(start_dim=2) for x in outputs], dim=2).permute(0, 2, 1) 45 | #---------------------------------------------------# 46 | # 获得每一个特征点属于每一个种类的概率 47 | #---------------------------------------------------# 48 | outputs[:, :, 4:] = torch.sigmoid(outputs[:, :, 4:]) 49 | for h, w in hw: 50 | #---------------------------# 51 | # 根据特征层的高宽生成网格点 52 | #---------------------------# 53 | grid_y, grid_x = torch.meshgrid([torch.arange(h), torch.arange(w)]) 54 | #---------------------------# 55 | # 1, 6400, 2 56 | # 1, 1600, 2 57 | # 1, 400, 2 58 | #---------------------------# 59 | grid = torch.stack((grid_x, grid_y), 2).view(1, -1, 2) 60 | shape = grid.shape[:2] 61 | 62 | grids.append(grid) 63 | strides.append(torch.full((shape[0], shape[1], 1), input_shape[0] / h)) 64 | #---------------------------# 65 | # 将网格点堆叠到一起 66 | # 1, 6400, 2 67 | # 1, 1600, 2 68 | # 1, 400, 2 69 | # 70 | # 1, 8400, 2 71 | #---------------------------# 72 | grids = torch.cat(grids, dim=1).type(outputs.type()) 73 | strides = torch.cat(strides, dim=1).type(outputs.type()) 74 | #------------------------# 75 | # 根据网格点进行解码 76 | #------------------------# 77 | outputs[..., :2] = (outputs[..., :2] + grids) * strides 78 | outputs[..., 2:4] = torch.exp(outputs[..., 2:4]) * strides 79 | #-----------------# 80 | # 归一化 81 | #-----------------# 82 | outputs[..., [0,2]] = outputs[..., [0,2]] / input_shape[1] 83 | outputs[..., [1,3]] = outputs[..., [1,3]] / input_shape[0] 84 | return outputs 85 | 86 | def non_max_suppression(prediction, num_classes, input_shape, image_shape, letterbox_image, conf_thres=0.5, nms_thres=0.4): 87 | #----------------------------------------------------------# 88 | # 将预测结果的格式转换成左上角右下角的格式。 89 | # prediction [batch_size, num_anchors, 85] 90 | #----------------------------------------------------------# 91 | box_corner = prediction.new(prediction.shape) 92 | box_corner[:, :, 0] = prediction[:, :, 0] - prediction[:, :, 2] / 2 93 | box_corner[:, :, 1] = prediction[:, :, 1] - prediction[:, :, 3] / 2 94 | box_corner[:, :, 2] = prediction[:, :, 0] + prediction[:, :, 2] / 2 95 | box_corner[:, :, 3] = prediction[:, :, 1] + prediction[:, :, 3] / 2 96 | prediction[:, :, :4] = box_corner[:, :, :4] 97 | 98 | output = [None for _ in range(len(prediction))] 99 | #----------------------------------------------------------# 100 | # 对输入图片进行循环,一般只会进行一次 101 | #----------------------------------------------------------# 102 | for i, image_pred in enumerate(prediction): 103 | 104 | #----------------------------------------------------------# 105 | # 对种类预测部分取max。 106 | # class_conf [num_anchors, 1] 种类置信度 107 | # 
class_pred [num_anchors, 1] 种类 108 | #----------------------------------------------------------# 109 | class_conf,class_pred= torch.max(image_pred[:, 5:5 + num_classes], 1, keepdim=True) 110 | for i in class_pred: 111 | if class_pred[i]!=0: 112 | print(1) 113 | #----------------------------------------------------------# 114 | # 利用置信度进行第一轮筛选 115 | #----------------------------------------------------------# 116 | conf_mask = (image_pred[:, 4] * class_conf[:, 0] >= conf_thres).squeeze() 117 | if not image_pred.size(0): 118 | continue 119 | #-------------------------------------------------------------------------# 120 | # detections [num_anchors, 7] 121 | # 7的内容为:x1, y1, x2, y2, obj_conf, class_conf, class_pred 122 | #-------------------------------------------------------------------------# 123 | detections = torch.cat((image_pred[:, :5], class_conf, class_pred.float()), 1) 124 | detections = detections[conf_mask] 125 | nms_out_index = boxes.batched_nms( 126 | detections[:, :4], 127 | detections[:, 4] * detections[:, 5], 128 | detections[:, 6], 129 | nms_thres, 130 | ) 131 | 132 | output[i] = detections[nms_out_index] 133 | 134 | # #------------------------------------------# 135 | # # 获得预测结果中包含的所有种类 136 | # #------------------------------------------# 137 | # unique_labels = detections[:, -1].cpu().unique() 138 | 139 | # if prediction.is_cuda: 140 | # unique_labels = unique_labels.cuda() 141 | # detections = detections.cuda() 142 | 143 | # for c in unique_labels: 144 | # #------------------------------------------# 145 | # # 获得某一类得分筛选后全部的预测结果 146 | # #------------------------------------------# 147 | # detections_class = detections[detections[:, -1] == c] 148 | 149 | # #------------------------------------------# 150 | # # 使用官方自带的非极大抑制会速度更快一些! 151 | # #------------------------------------------# 152 | # keep = nms( 153 | # detections_class[:, :4], 154 | # detections_class[:, 4] * detections_class[:, 5], 155 | # nms_thres 156 | # ) 157 | # max_detections = detections_class[keep] 158 | 159 | # # # 按照存在物体的置信度排序 160 | # # _, conf_sort_index = torch.sort(detections_class[:, 4]*detections_class[:, 5], descending=True) 161 | # # detections_class = detections_class[conf_sort_index] 162 | # # # 进行非极大抑制 163 | # # max_detections = [] 164 | # # while detections_class.size(0): 165 | # # # 取出这一类置信度最高的,一步一步往下判断,判断重合程度是否大于nms_thres,如果是则去除掉 166 | # # max_detections.append(detections_class[0].unsqueeze(0)) 167 | # # if len(detections_class) == 1: 168 | # # break 169 | # # ious = bbox_iou(max_detections[-1], detections_class[1:]) 170 | # # detections_class = detections_class[1:][ious < nms_thres] 171 | # # # 堆叠 172 | # # max_detections = torch.cat(max_detections).data 173 | 174 | # # Add max detections to outputs 175 | # output[i] = max_detections if output[i] is None else torch.cat((output[i], max_detections)) 176 | 177 | if output[i] is not None: 178 | output[i] = output[i].cpu().numpy() 179 | box_xy, box_wh = (output[i][:, 0:2] + output[i][:, 2:4])/2, output[i][:, 2:4] - output[i][:, 0:2] 180 | output[i][:, :4] = yolo_correct_boxes(box_xy, box_wh, input_shape, image_shape, letterbox_image) 181 | return output 182 | -------------------------------------------------------------------------------- /segment/utils_yolox/utils_fit.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import torch 4 | from tqdm import tqdm 5 | 6 | from utils.utils import get_lr 7 | 8 | 9 | def fit_one_epoch(model_train, model, ema, yolo_loss, loss_history, eval_callback, 
optimizer, epoch, epoch_step, epoch_step_val, gen, gen_val, Epoch, cuda, fp16, scaler, save_period, save_dir, local_rank=0): 10 | loss = 0 11 | val_loss = 0 12 | 13 | if local_rank == 0: 14 | print('Start Train') 15 | pbar = tqdm(total=epoch_step,desc=f'Epoch {epoch + 1}/{Epoch}',postfix=dict,mininterval=0.3) 16 | model_train.train() 17 | for iteration, batch in enumerate(gen): 18 | if iteration >= epoch_step: 19 | break 20 | 21 | images, targets = batch[0], batch[1] 22 | with torch.no_grad(): 23 | if cuda: 24 | images = images.cuda(local_rank) 25 | targets = [ann.cuda(local_rank) for ann in targets] 26 | #----------------------# 27 | # 清零梯度 28 | #----------------------# 29 | optimizer.zero_grad() 30 | if not fp16: 31 | #----------------------# 32 | # 前向传播 33 | #----------------------# 34 | outputs = model_train(images) 35 | 36 | #----------------------# 37 | # 计算损失 38 | #----------------------# 39 | loss_value = yolo_loss(outputs, targets) 40 | 41 | #----------------------# 42 | # 反向传播 43 | #----------------------# 44 | loss_value.backward() 45 | optimizer.step() 46 | else: 47 | from torch.cuda.amp import autocast 48 | with autocast(): 49 | outputs = model_train(images) 50 | #----------------------# 51 | # 计算损失 52 | #----------------------# 53 | loss_value = yolo_loss(outputs, targets) 54 | 55 | #----------------------# 56 | # 反向传播 57 | #----------------------# 58 | scaler.scale(loss_value).backward() 59 | scaler.step(optimizer) 60 | scaler.update() 61 | if ema: 62 | ema.update(model_train) 63 | 64 | loss += loss_value.item() 65 | 66 | if local_rank == 0: 67 | pbar.set_postfix(**{'loss' : loss / (iteration + 1), 68 | 'lr' : get_lr(optimizer)}) 69 | pbar.update(1) 70 | 71 | if local_rank == 0: 72 | pbar.close() 73 | print('Finish Train') 74 | print('Start Validation') 75 | pbar = tqdm(total=epoch_step_val, desc=f'Epoch {epoch + 1}/{Epoch}',postfix=dict,mininterval=0.3) 76 | 77 | if ema: 78 | model_train_eval = ema.ema 79 | else: 80 | model_train_eval = model_train.eval() 81 | 82 | for iteration, batch in enumerate(gen_val): 83 | if iteration >= epoch_step_val: 84 | break 85 | images, targets = batch[0], batch[1] 86 | with torch.no_grad(): 87 | if cuda: 88 | images = images.cuda(local_rank) 89 | targets = [ann.cuda(local_rank) for ann in targets] 90 | #----------------------# 91 | # 清零梯度 92 | #----------------------# 93 | optimizer.zero_grad() 94 | #----------------------# 95 | # 前向传播 96 | #----------------------# 97 | outputs = model_train_eval(images) 98 | 99 | #----------------------# 100 | # 计算损失 101 | #----------------------# 102 | loss_value = yolo_loss(outputs, targets) 103 | 104 | val_loss += loss_value.item() 105 | if local_rank == 0: 106 | pbar.set_postfix(**{'val_loss': val_loss / (iteration + 1)}) 107 | pbar.update(1) 108 | 109 | if local_rank == 0: 110 | pbar.close() 111 | print('Finish Validation') 112 | loss_history.append_loss(epoch + 1, loss / epoch_step, val_loss / epoch_step_val) 113 | eval_callback.on_epoch_end(epoch + 1, model_train_eval) 114 | print('Epoch:'+ str(epoch + 1) + '/' + str(Epoch)) 115 | print('Total Loss: %.3f || Val Loss: %.3f ' % (loss / epoch_step, val_loss / epoch_step_val)) 116 | 117 | #-----------------------------------------------# 118 | # 保存权值 119 | #-----------------------------------------------# 120 | if ema: 121 | save_state_dict = ema.ema.state_dict() 122 | else: 123 | save_state_dict = model.state_dict() 124 | 125 | if (epoch + 1) % save_period == 0 or epoch + 1 == Epoch: 126 | torch.save(save_state_dict, os.path.join(save_dir, 
"ep%03d-loss%.3f-val_loss%.3f.pth" % (epoch + 1, loss / epoch_step, val_loss / epoch_step_val))) 127 | 128 | if len(loss_history.val_loss) <= 1 or (val_loss / epoch_step_val) <= min(loss_history.val_loss): 129 | print('Save best model to best_epoch_weights.pth') 130 | torch.save(save_state_dict, os.path.join(save_dir, "best_epoch_weights.pth")) 131 | 132 | torch.save(save_state_dict, os.path.join(save_dir, "last_epoch_weights.pth")) -------------------------------------------------------------------------------- /segment/yolox.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cshan-github/TongueSAM/3f6e8c620e4d89e669a92a22f3be3d007932ce45/segment/yolox.pth -------------------------------------------------------------------------------- /segment/yolox_nets/__init__.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------------------------------- /segment/yolox_nets/__pycache__/__init__.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cshan-github/TongueSAM/3f6e8c620e4d89e669a92a22f3be3d007932ce45/segment/yolox_nets/__pycache__/__init__.cpython-310.pyc -------------------------------------------------------------------------------- /segment/yolox_nets/__pycache__/__init__.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cshan-github/TongueSAM/3f6e8c620e4d89e669a92a22f3be3d007932ce45/segment/yolox_nets/__pycache__/__init__.cpython-39.pyc -------------------------------------------------------------------------------- /segment/yolox_nets/__pycache__/darknet.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cshan-github/TongueSAM/3f6e8c620e4d89e669a92a22f3be3d007932ce45/segment/yolox_nets/__pycache__/darknet.cpython-310.pyc -------------------------------------------------------------------------------- /segment/yolox_nets/__pycache__/darknet.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cshan-github/TongueSAM/3f6e8c620e4d89e669a92a22f3be3d007932ce45/segment/yolox_nets/__pycache__/darknet.cpython-39.pyc -------------------------------------------------------------------------------- /segment/yolox_nets/__pycache__/yolo.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cshan-github/TongueSAM/3f6e8c620e4d89e669a92a22f3be3d007932ce45/segment/yolox_nets/__pycache__/yolo.cpython-310.pyc -------------------------------------------------------------------------------- /segment/yolox_nets/__pycache__/yolo.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cshan-github/TongueSAM/3f6e8c620e4d89e669a92a22f3be3d007932ce45/segment/yolox_nets/__pycache__/yolo.cpython-39.pyc -------------------------------------------------------------------------------- /segment/yolox_nets/__pycache__/yolo_training.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cshan-github/TongueSAM/3f6e8c620e4d89e669a92a22f3be3d007932ce45/segment/yolox_nets/__pycache__/yolo_training.cpython-310.pyc 
-------------------------------------------------------------------------------- /segment/yolox_nets/__pycache__/yolo_training.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cshan-github/TongueSAM/3f6e8c620e4d89e669a92a22f3be3d007932ce45/segment/yolox_nets/__pycache__/yolo_training.cpython-39.pyc -------------------------------------------------------------------------------- /segment/yolox_nets/darknet.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding:utf-8 -*- 3 | # Copyright (c) Megvii, Inc. and its affiliates. 4 | 5 | import torch 6 | from torch import nn 7 | 8 | class SiLU(nn.Module): 9 | @staticmethod 10 | def forward(x): 11 | return x * torch.sigmoid(x) 12 | 13 | def get_activation(name="silu", inplace=True): 14 | if name == "silu": 15 | module = SiLU() 16 | elif name == "relu": 17 | module = nn.ReLU(inplace=inplace) 18 | elif name == "lrelu": 19 | module = nn.LeakyReLU(0.1, inplace=inplace) 20 | else: 21 | raise AttributeError("Unsupported act type: {}".format(name)) 22 | return module 23 | 24 | class Focus(nn.Module): 25 | def __init__(self, in_channels, out_channels, ksize=1, stride=1, act="silu"): 26 | super().__init__() 27 | self.conv = BaseConv(in_channels * 4, out_channels, ksize, stride, act=act) 28 | 29 | def forward(self, x): 30 | patch_top_left = x[..., ::2, ::2] 31 | patch_bot_left = x[..., 1::2, ::2] 32 | patch_top_right = x[..., ::2, 1::2] 33 | patch_bot_right = x[..., 1::2, 1::2] 34 | x = torch.cat((patch_top_left, patch_bot_left, patch_top_right, patch_bot_right,), dim=1,) 35 | return self.conv(x) 36 | 37 | class BaseConv(nn.Module): 38 | def __init__(self, in_channels, out_channels, ksize, stride, groups=1, bias=False, act="silu"): 39 | super().__init__() 40 | pad = (ksize - 1) // 2 41 | self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=ksize, stride=stride, padding=pad, groups=groups, bias=bias) 42 | self.bn = nn.BatchNorm2d(out_channels, eps=0.001, momentum=0.03) 43 | self.act = get_activation(act, inplace=True) 44 | 45 | def forward(self, x): 46 | return self.act(self.bn(self.conv(x))) 47 | 48 | def fuseforward(self, x): 49 | return self.act(self.conv(x)) 50 | 51 | class DWConv(nn.Module): 52 | def __init__(self, in_channels, out_channels, ksize, stride=1, act="silu"): 53 | super().__init__() 54 | self.dconv = BaseConv(in_channels, in_channels, ksize=ksize, stride=stride, groups=in_channels, act=act,) 55 | self.pconv = BaseConv(in_channels, out_channels, ksize=1, stride=1, groups=1, act=act) 56 | 57 | def forward(self, x): 58 | x = self.dconv(x) 59 | return self.pconv(x) 60 | 61 | class SPPBottleneck(nn.Module): 62 | def __init__(self, in_channels, out_channels, kernel_sizes=(5, 9, 13), activation="silu"): 63 | super().__init__() 64 | hidden_channels = in_channels // 2 65 | self.conv1 = BaseConv(in_channels, hidden_channels, 1, stride=1, act=activation) 66 | self.m = nn.ModuleList([nn.MaxPool2d(kernel_size=ks, stride=1, padding=ks // 2) for ks in kernel_sizes]) 67 | conv2_channels = hidden_channels * (len(kernel_sizes) + 1) 68 | self.conv2 = BaseConv(conv2_channels, out_channels, 1, stride=1, act=activation) 69 | 70 | def forward(self, x): 71 | x = self.conv1(x) 72 | x = torch.cat([x] + [m(x) for m in self.m], dim=1) 73 | x = self.conv2(x) 74 | return x 75 | 76 | #--------------------------------------------------# 77 | # 残差结构的构建,小的残差结构 78 | 
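# (a standard bottleneck: a 1x1 conv reduces the channels (by default to 50%),
#  a 3x3 conv expands them back, and an identity shortcut is added when
#  in_channels == out_channels)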
#--------------------------------------------------# 79 | class Bottleneck(nn.Module): 80 | # Standard bottleneck 81 | def __init__(self, in_channels, out_channels, shortcut=True, expansion=0.5, depthwise=False, act="silu",): 82 | super().__init__() 83 | hidden_channels = int(out_channels * expansion) 84 | Conv = DWConv if depthwise else BaseConv 85 | #--------------------------------------------------# 86 | # 利用1x1卷积进行通道数的缩减。缩减率一般是50% 87 | #--------------------------------------------------# 88 | self.conv1 = BaseConv(in_channels, hidden_channels, 1, stride=1, act=act) 89 | #--------------------------------------------------# 90 | # 利用3x3卷积进行通道数的拓张。并且完成特征提取 91 | #--------------------------------------------------# 92 | self.conv2 = Conv(hidden_channels, out_channels, 3, stride=1, act=act) 93 | self.use_add = shortcut and in_channels == out_channels 94 | 95 | def forward(self, x): 96 | y = self.conv2(self.conv1(x)) 97 | if self.use_add: 98 | y = y + x 99 | return y 100 | 101 | class CSPLayer(nn.Module): 102 | def __init__(self, in_channels, out_channels, n=1, shortcut=True, expansion=0.5, depthwise=False, act="silu",): 103 | # ch_in, ch_out, number, shortcut, groups, expansion 104 | super().__init__() 105 | hidden_channels = int(out_channels * expansion) 106 | #--------------------------------------------------# 107 | # 主干部分的初次卷积 108 | #--------------------------------------------------# 109 | self.conv1 = BaseConv(in_channels, hidden_channels, 1, stride=1, act=act) 110 | #--------------------------------------------------# 111 | # 大的残差边部分的初次卷积 112 | #--------------------------------------------------# 113 | self.conv2 = BaseConv(in_channels, hidden_channels, 1, stride=1, act=act) 114 | #-----------------------------------------------# 115 | # 对堆叠的结果进行卷积的处理 116 | #-----------------------------------------------# 117 | self.conv3 = BaseConv(2 * hidden_channels, out_channels, 1, stride=1, act=act) 118 | 119 | #--------------------------------------------------# 120 | # 根据循环的次数构建上述Bottleneck残差结构 121 | #--------------------------------------------------# 122 | module_list = [Bottleneck(hidden_channels, hidden_channels, shortcut, 1.0, depthwise, act=act) for _ in range(n)] 123 | self.m = nn.Sequential(*module_list) 124 | 125 | def forward(self, x): 126 | #-------------------------------# 127 | # x_1是主干部分 128 | #-------------------------------# 129 | x_1 = self.conv1(x) 130 | #-------------------------------# 131 | # x_2是大的残差边部分 132 | #-------------------------------# 133 | x_2 = self.conv2(x) 134 | 135 | #-----------------------------------------------# 136 | # 主干部分利用残差结构堆叠继续进行特征提取 137 | #-----------------------------------------------# 138 | x_1 = self.m(x_1) 139 | #-----------------------------------------------# 140 | # 主干部分和大的残差边部分进行堆叠 141 | #-----------------------------------------------# 142 | x = torch.cat((x_1, x_2), dim=1) 143 | #-----------------------------------------------# 144 | # 对堆叠的结果进行卷积的处理 145 | #-----------------------------------------------# 146 | return self.conv3(x) 147 | 148 | class CSPDarknet(nn.Module): 149 | def __init__(self, dep_mul, wid_mul, out_features=("dark3", "dark4", "dark5"), depthwise=False, act="silu",): 150 | super().__init__() 151 | assert out_features, "please provide output features of Darknet" 152 | self.out_features = out_features 153 | Conv = DWConv if depthwise else BaseConv 154 | 155 | #-----------------------------------------------# 156 | # 输入图片是640, 640, 3 157 | # 初始的基本通道是64 158 | #-----------------------------------------------# 159 | 
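        # dep_mul / wid_mul scale block depth and channel width per model size;
        # YoloBody maps e.g. phi='s' to dep_mul=0.33, wid_mul=0.50, giving
        # base_channels=32 and base_depth=1. The shape comments in this file
        # are written for width 1.0.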
base_channels = int(wid_mul * 64) # 64 160 | base_depth = max(round(dep_mul * 3), 1) # 3 161 | 162 | #-----------------------------------------------# 163 | # 利用focus网络结构进行特征提取 164 | # 640, 640, 3 -> 320, 320, 12 -> 320, 320, 64 165 | #-----------------------------------------------# 166 | self.stem = Focus(3, base_channels, ksize=3, act=act) 167 | 168 | #-----------------------------------------------# 169 | # 完成卷积之后,320, 320, 64 -> 160, 160, 128 170 | # 完成CSPlayer之后,160, 160, 128 -> 160, 160, 128 171 | #-----------------------------------------------# 172 | self.dark2 = nn.Sequential( 173 | Conv(base_channels, base_channels * 2, 3, 2, act=act), 174 | CSPLayer(base_channels * 2, base_channels * 2, n=base_depth, depthwise=depthwise, act=act), 175 | ) 176 | 177 | #-----------------------------------------------# 178 | # 完成卷积之后,160, 160, 128 -> 80, 80, 256 179 | # 完成CSPlayer之后,80, 80, 256 -> 80, 80, 256 180 | #-----------------------------------------------# 181 | self.dark3 = nn.Sequential( 182 | Conv(base_channels * 2, base_channels * 4, 3, 2, act=act), 183 | CSPLayer(base_channels * 4, base_channels * 4, n=base_depth * 3, depthwise=depthwise, act=act), 184 | ) 185 | 186 | #-----------------------------------------------# 187 | # 完成卷积之后,80, 80, 256 -> 40, 40, 512 188 | # 完成CSPlayer之后,40, 40, 512 -> 40, 40, 512 189 | #-----------------------------------------------# 190 | self.dark4 = nn.Sequential( 191 | Conv(base_channels * 4, base_channels * 8, 3, 2, act=act), 192 | CSPLayer(base_channels * 8, base_channels * 8, n=base_depth * 3, depthwise=depthwise, act=act), 193 | ) 194 | 195 | #-----------------------------------------------# 196 | # 完成卷积之后,40, 40, 512 -> 20, 20, 1024 197 | # 完成SPP之后,20, 20, 1024 -> 20, 20, 1024 198 | # 完成CSPlayer之后,20, 20, 1024 -> 20, 20, 1024 199 | #-----------------------------------------------# 200 | self.dark5 = nn.Sequential( 201 | Conv(base_channels * 8, base_channels * 16, 3, 2, act=act), 202 | SPPBottleneck(base_channels * 16, base_channels * 16, activation=act), 203 | CSPLayer(base_channels * 16, base_channels * 16, n=base_depth, shortcut=False, depthwise=depthwise, act=act), 204 | ) 205 | 206 | def forward(self, x): 207 | outputs = {} 208 | x = self.stem(x) 209 | outputs["stem"] = x 210 | x = self.dark2(x) 211 | outputs["dark2"] = x 212 | #-----------------------------------------------# 213 | # dark3的输出为80, 80, 256,是一个有效特征层 214 | #-----------------------------------------------# 215 | x = self.dark3(x) 216 | 217 | outputs["dark3"] = x 218 | #-----------------------------------------------# 219 | # dark4的输出为40, 40, 512,是一个有效特征层 220 | #-----------------------------------------------# 221 | x = self.dark4(x) 222 | 223 | outputs["dark4"] = x 224 | #-----------------------------------------------# 225 | # dark5的输出为20, 20, 1024,是一个有效特征层 226 | #-----------------------------------------------# 227 | x = self.dark5(x) 228 | 229 | outputs["dark5"] = x 230 | return {k: v for k, v in outputs.items() if k in self.out_features} 231 | 232 | 233 | if __name__ == '__main__': 234 | print(CSPDarknet(1, 1)) -------------------------------------------------------------------------------- /segment/yolox_nets/yolo.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding:utf-8 -*- 3 | # Copyright (c) Megvii, Inc. and its affiliates. 
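#
# This module defines the decoupled YOLOX detection head (YOLOXHead), the PAFPN
# neck built on CSPDarknet (YOLOPAFPN), and the assembled detector (YoloBody)
# used in this repository to localize the tongue region that prompts SAM.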
4 | 5 | import torch 6 | import torch.nn as nn 7 | 8 | from .darknet import BaseConv, CSPDarknet, CSPLayer, DWConv 9 | 10 | 11 | class YOLOXHead(nn.Module): 12 | def __init__(self, num_classes, width = 1.0, in_channels = [256, 512, 1024], act = "silu", depthwise = False,): 13 | super().__init__() 14 | Conv = DWConv if depthwise else BaseConv 15 | 16 | self.cls_convs = nn.ModuleList() 17 | self.reg_convs = nn.ModuleList() 18 | self.cls_preds = nn.ModuleList() 19 | self.reg_preds = nn.ModuleList() 20 | self.obj_preds = nn.ModuleList() 21 | self.stems = nn.ModuleList() 22 | 23 | for i in range(len(in_channels)): 24 | self.stems.append(BaseConv(in_channels = int(in_channels[i] * width), out_channels = int(256 * width), ksize = 1, stride = 1, act = act)) 25 | self.cls_convs.append(nn.Sequential(*[ 26 | Conv(in_channels = int(256 * width), out_channels = int(256 * width), ksize = 3, stride = 1, act = act), 27 | Conv(in_channels = int(256 * width), out_channels = int(256 * width), ksize = 3, stride = 1, act = act), 28 | ])) 29 | self.cls_preds.append( 30 | nn.Conv2d(in_channels = int(256 * width), out_channels = num_classes, kernel_size = 1, stride = 1, padding = 0) 31 | ) 32 | 33 | 34 | self.reg_convs.append(nn.Sequential(*[ 35 | Conv(in_channels = int(256 * width), out_channels = int(256 * width), ksize = 3, stride = 1, act = act), 36 | Conv(in_channels = int(256 * width), out_channels = int(256 * width), ksize = 3, stride = 1, act = act) 37 | ])) 38 | self.reg_preds.append( 39 | nn.Conv2d(in_channels = int(256 * width), out_channels = 4, kernel_size = 1, stride = 1, padding = 0) 40 | ) 41 | self.obj_preds.append( 42 | nn.Conv2d(in_channels = int(256 * width), out_channels = 1, kernel_size = 1, stride = 1, padding = 0) 43 | ) 44 | 45 | def forward(self, inputs): 46 | #---------------------------------------------------# 47 | # inputs输入 48 | # P3_out 80, 80, 256 49 | # P4_out 40, 40, 512 50 | # P5_out 20, 20, 1024 51 | #---------------------------------------------------# 52 | outputs = [] 53 | for k, x in enumerate(inputs): 54 | #---------------------------------------------------# 55 | # 利用1x1卷积进行通道整合 56 | #---------------------------------------------------# 57 | x = self.stems[k](x) 58 | #---------------------------------------------------# 59 | # 利用两个卷积标准化激活函数来进行特征提取 60 | #---------------------------------------------------# 61 | cls_feat = self.cls_convs[k](x) 62 | #---------------------------------------------------# 63 | # 判断特征点所属的种类 64 | # 80, 80, num_classes 65 | # 40, 40, num_classes 66 | # 20, 20, num_classes 67 | #---------------------------------------------------# 68 | cls_output = self.cls_preds[k](cls_feat) 69 | 70 | #---------------------------------------------------# 71 | # 利用两个卷积标准化激活函数来进行特征提取 72 | #---------------------------------------------------# 73 | reg_feat = self.reg_convs[k](x) 74 | #---------------------------------------------------# 75 | # 特征点的回归系数 76 | # reg_pred 80, 80, 4 77 | # reg_pred 40, 40, 4 78 | # reg_pred 20, 20, 4 79 | #---------------------------------------------------# 80 | reg_output = self.reg_preds[k](reg_feat) 81 | #---------------------------------------------------# 82 | # 判断特征点是否有对应的物体 83 | # obj_pred 80, 80, 1 84 | # obj_pred 40, 40, 1 85 | # obj_pred 20, 20, 1 86 | #---------------------------------------------------# 87 | obj_output = self.obj_preds[k](reg_feat) 88 | 89 | output = torch.cat([reg_output, obj_output, cls_output], 1) 90 | outputs.append(output) 91 | return outputs 92 | 93 | class YOLOPAFPN(nn.Module): 94 | def __init__(self, 
depth = 1.0, width = 1.0, in_features = ("dark3", "dark4", "dark5"), in_channels = [256, 512, 1024], depthwise = False, act = "silu"): 95 | super().__init__() 96 | Conv = DWConv if depthwise else BaseConv 97 | self.backbone = CSPDarknet(depth, width, depthwise = depthwise, act = act) 98 | self.in_features = in_features 99 | 100 | self.upsample = nn.Upsample(scale_factor=2, mode="nearest") 101 | 102 | #-------------------------------------------# 103 | # 20, 20, 1024 -> 20, 20, 512 104 | #-------------------------------------------# 105 | self.lateral_conv0 = BaseConv(int(in_channels[2] * width), int(in_channels[1] * width), 1, 1, act=act) 106 | 107 | #-------------------------------------------# 108 | # 40, 40, 1024 -> 40, 40, 512 109 | #-------------------------------------------# 110 | self.C3_p4 = CSPLayer( 111 | int(2 * in_channels[1] * width), 112 | int(in_channels[1] * width), 113 | round(3 * depth), 114 | False, 115 | depthwise = depthwise, 116 | act = act, 117 | ) 118 | 119 | #-------------------------------------------# 120 | # 40, 40, 512 -> 40, 40, 256 121 | #-------------------------------------------# 122 | self.reduce_conv1 = BaseConv(int(in_channels[1] * width), int(in_channels[0] * width), 1, 1, act=act) 123 | #-------------------------------------------# 124 | # 80, 80, 512 -> 80, 80, 256 125 | #-------------------------------------------# 126 | self.C3_p3 = CSPLayer( 127 | int(2 * in_channels[0] * width), 128 | int(in_channels[0] * width), 129 | round(3 * depth), 130 | False, 131 | depthwise = depthwise, 132 | act = act, 133 | ) 134 | 135 | #-------------------------------------------# 136 | # 80, 80, 256 -> 40, 40, 256 137 | #-------------------------------------------# 138 | self.bu_conv2 = Conv(int(in_channels[0] * width), int(in_channels[0] * width), 3, 2, act=act) 139 | #-------------------------------------------# 140 | # 40, 40, 256 -> 40, 40, 512 141 | #-------------------------------------------# 142 | self.C3_n3 = CSPLayer( 143 | int(2 * in_channels[0] * width), 144 | int(in_channels[1] * width), 145 | round(3 * depth), 146 | False, 147 | depthwise = depthwise, 148 | act = act, 149 | ) 150 | 151 | #-------------------------------------------# 152 | # 40, 40, 512 -> 20, 20, 512 153 | #-------------------------------------------# 154 | self.bu_conv1 = Conv(int(in_channels[1] * width), int(in_channels[1] * width), 3, 2, act=act) 155 | #-------------------------------------------# 156 | # 20, 20, 1024 -> 20, 20, 1024 157 | #-------------------------------------------# 158 | self.C3_n4 = CSPLayer( 159 | int(2 * in_channels[1] * width), 160 | int(in_channels[2] * width), 161 | round(3 * depth), 162 | False, 163 | depthwise = depthwise, 164 | act = act, 165 | ) 166 | 167 | def forward(self, input): 168 | out_features = self.backbone.forward(input) 169 | [feat1, feat2, feat3] = [out_features[f] for f in self.in_features] 170 | 171 | #-------------------------------------------# 172 | # 20, 20, 1024 -> 20, 20, 512 173 | #-------------------------------------------# 174 | P5 = self.lateral_conv0(feat3) 175 | #-------------------------------------------# 176 | # 20, 20, 512 -> 40, 40, 512 177 | #-------------------------------------------# 178 | P5_upsample = self.upsample(P5) 179 | #-------------------------------------------# 180 | # 40, 40, 512 + 40, 40, 512 -> 40, 40, 1024 181 | #-------------------------------------------# 182 | P5_upsample = torch.cat([P5_upsample, feat2], 1) 183 | #-------------------------------------------# 184 | # 40, 40, 1024 -> 40, 40, 512 185 
| #-------------------------------------------# 186 | P5_upsample = self.C3_p4(P5_upsample) 187 | 188 | #-------------------------------------------# 189 | # 40, 40, 512 -> 40, 40, 256 190 | #-------------------------------------------# 191 | P4 = self.reduce_conv1(P5_upsample) 192 | #-------------------------------------------# 193 | # 40, 40, 256 -> 80, 80, 256 194 | #-------------------------------------------# 195 | P4_upsample = self.upsample(P4) 196 | #-------------------------------------------# 197 | # 80, 80, 256 + 80, 80, 256 -> 80, 80, 512 198 | #-------------------------------------------# 199 | P4_upsample = torch.cat([P4_upsample, feat1], 1) 200 | #-------------------------------------------# 201 | # 80, 80, 512 -> 80, 80, 256 202 | #-------------------------------------------# 203 | P3_out = self.C3_p3(P4_upsample) 204 | 205 | #-------------------------------------------# 206 | # 80, 80, 256 -> 40, 40, 256 207 | #-------------------------------------------# 208 | P3_downsample = self.bu_conv2(P3_out) 209 | #-------------------------------------------# 210 | # 40, 40, 256 + 40, 40, 256 -> 40, 40, 512 211 | #-------------------------------------------# 212 | P3_downsample = torch.cat([P3_downsample, P4], 1) 213 | #-------------------------------------------# 214 | # 40, 40, 256 -> 40, 40, 512 215 | #-------------------------------------------# 216 | P4_out = self.C3_n3(P3_downsample) 217 | 218 | #-------------------------------------------# 219 | # 40, 40, 512 -> 20, 20, 512 220 | #-------------------------------------------# 221 | P4_downsample = self.bu_conv1(P4_out) 222 | #-------------------------------------------# 223 | # 20, 20, 512 + 20, 20, 512 -> 20, 20, 1024 224 | #-------------------------------------------# 225 | P4_downsample = torch.cat([P4_downsample, P5], 1) 226 | #-------------------------------------------# 227 | # 20, 20, 1024 -> 20, 20, 1024 228 | #-------------------------------------------# 229 | P5_out = self.C3_n4(P4_downsample) 230 | 231 | return (P3_out, P4_out, P5_out) 232 | 233 | class YoloBody(nn.Module): 234 | def __init__(self, num_classes, phi): 235 | super().__init__() 236 | depth_dict = {'nano': 0.33, 'tiny': 0.33, 's' : 0.33, 'm' : 0.67, 'l' : 1.00, 'x' : 1.33,} 237 | width_dict = {'nano': 0.25, 'tiny': 0.375, 's' : 0.50, 'm' : 0.75, 'l' : 1.00, 'x' : 1.25,} 238 | depth, width = depth_dict[phi], width_dict[phi] 239 | depthwise = True if phi == 'nano' else False 240 | 241 | self.backbone = YOLOPAFPN(depth, width, depthwise=depthwise) 242 | self.head = YOLOXHead(num_classes, width, depthwise=depthwise) 243 | 244 | def forward(self, x): 245 | fpn_outs = self.backbone.forward(x) 246 | outputs = self.head.forward(fpn_outs) 247 | return outputs 248 | -------------------------------------------------------------------------------- /segment_anything/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
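#
# Public entry points of the vendored Segment Anything package: the SAM builders
# (build_sam_vit_b / build_sam_vit_l / build_sam_vit_h, collected in
# sam_model_registry), SamPredictor, and SamAutomaticMaskGenerator.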
6 | 7 | from .build_sam import ( 8 | build_sam, 9 | build_sam_vit_h, 10 | build_sam_vit_l, 11 | build_sam_vit_b, 12 | sam_model_registry, 13 | ) 14 | from .predictor import SamPredictor 15 | from .automatic_mask_generator import SamAutomaticMaskGenerator 16 | -------------------------------------------------------------------------------- /segment_anything/__pycache__/__init__.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cshan-github/TongueSAM/3f6e8c620e4d89e669a92a22f3be3d007932ce45/segment_anything/__pycache__/__init__.cpython-310.pyc -------------------------------------------------------------------------------- /segment_anything/__pycache__/__init__.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cshan-github/TongueSAM/3f6e8c620e4d89e669a92a22f3be3d007932ce45/segment_anything/__pycache__/__init__.cpython-39.pyc -------------------------------------------------------------------------------- /segment_anything/__pycache__/automatic_mask_generator.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cshan-github/TongueSAM/3f6e8c620e4d89e669a92a22f3be3d007932ce45/segment_anything/__pycache__/automatic_mask_generator.cpython-310.pyc -------------------------------------------------------------------------------- /segment_anything/__pycache__/automatic_mask_generator.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cshan-github/TongueSAM/3f6e8c620e4d89e669a92a22f3be3d007932ce45/segment_anything/__pycache__/automatic_mask_generator.cpython-39.pyc -------------------------------------------------------------------------------- /segment_anything/__pycache__/build_sam.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cshan-github/TongueSAM/3f6e8c620e4d89e669a92a22f3be3d007932ce45/segment_anything/__pycache__/build_sam.cpython-310.pyc -------------------------------------------------------------------------------- /segment_anything/__pycache__/build_sam.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cshan-github/TongueSAM/3f6e8c620e4d89e669a92a22f3be3d007932ce45/segment_anything/__pycache__/build_sam.cpython-39.pyc -------------------------------------------------------------------------------- /segment_anything/__pycache__/predictor.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cshan-github/TongueSAM/3f6e8c620e4d89e669a92a22f3be3d007932ce45/segment_anything/__pycache__/predictor.cpython-310.pyc -------------------------------------------------------------------------------- /segment_anything/__pycache__/predictor.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cshan-github/TongueSAM/3f6e8c620e4d89e669a92a22f3be3d007932ce45/segment_anything/__pycache__/predictor.cpython-39.pyc -------------------------------------------------------------------------------- /segment_anything/build_sam.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 
2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | from functools import partial 7 | from pathlib import Path 8 | import urllib.request 9 | import torch 10 | 11 | from .modeling import ( 12 | ImageEncoderViT, 13 | MaskDecoder, 14 | PromptEncoder, 15 | Sam, 16 | TwoWayTransformer, 17 | ) 18 | 19 | 20 | def build_sam_vit_h(checkpoint=None): 21 | return _build_sam( 22 | encoder_embed_dim=1280, 23 | encoder_depth=32, 24 | encoder_num_heads=16, 25 | encoder_global_attn_indexes=[7, 15, 23, 31], 26 | checkpoint=checkpoint, 27 | ) 28 | 29 | 30 | build_sam = build_sam_vit_h 31 | 32 | 33 | def build_sam_vit_l(checkpoint=None): 34 | return _build_sam( 35 | encoder_embed_dim=1024, 36 | encoder_depth=24, 37 | encoder_num_heads=16, 38 | encoder_global_attn_indexes=[5, 11, 17, 23], 39 | checkpoint=checkpoint, 40 | ) 41 | 42 | 43 | def build_sam_vit_b(checkpoint=None): 44 | return _build_sam( 45 | encoder_embed_dim=768, 46 | encoder_depth=12, 47 | encoder_num_heads=12, 48 | encoder_global_attn_indexes=[2, 5, 8, 11], 49 | checkpoint=checkpoint, 50 | ) 51 | 52 | 53 | sam_model_registry = { 54 | "default": build_sam_vit_h, 55 | "vit_h": build_sam_vit_h, 56 | "vit_l": build_sam_vit_l, 57 | "vit_b": build_sam_vit_b, 58 | } 59 | 60 | 61 | 62 | 63 | def _build_sam( 64 | encoder_embed_dim, 65 | encoder_depth, 66 | encoder_num_heads, 67 | encoder_global_attn_indexes, 68 | checkpoint=None, 69 | ): 70 | prompt_embed_dim = 256 71 | image_size = 1024 72 | vit_patch_size = 16 73 | image_embedding_size = image_size // vit_patch_size 74 | sam = Sam( 75 | image_encoder=ImageEncoderViT( 76 | depth=encoder_depth, 77 | embed_dim=encoder_embed_dim, 78 | img_size=image_size, 79 | mlp_ratio=4, 80 | norm_layer=partial(torch.nn.LayerNorm, eps=1e-6), 81 | num_heads=encoder_num_heads, 82 | patch_size=vit_patch_size, 83 | qkv_bias=True, 84 | use_rel_pos=True, 85 | global_attn_indexes=encoder_global_attn_indexes, 86 | window_size=14, 87 | out_chans=prompt_embed_dim, 88 | ), 89 | prompt_encoder=PromptEncoder( 90 | embed_dim=prompt_embed_dim, 91 | image_embedding_size=(image_embedding_size, image_embedding_size), 92 | input_image_size=(image_size, image_size), 93 | mask_in_chans=16, 94 | ), 95 | mask_decoder=MaskDecoder( 96 | num_multimask_outputs=3, 97 | transformer=TwoWayTransformer( 98 | depth=2, 99 | embedding_dim=prompt_embed_dim, 100 | mlp_dim=2048, 101 | num_heads=8, 102 | ), 103 | transformer_dim=prompt_embed_dim, 104 | iou_head_depth=3, 105 | iou_head_hidden_dim=256, 106 | ), 107 | pixel_mean=[123.675, 116.28, 103.53], 108 | pixel_std=[58.395, 57.12, 57.375], 109 | ) 110 | sam.eval() 111 | checkpoint = Path(checkpoint) 112 | if checkpoint.name == "sam_vit_b_01ec64.pth" and not checkpoint.exists(): 113 | cmd = input("Download sam_vit_b_01ec64.pth from facebook AI? [y]/n: ") 114 | if len(cmd) == 0 or cmd.lower() == 'y': 115 | checkpoint.parent.mkdir(parents=True, exist_ok=True) 116 | print("Downloading SAM ViT-B checkpoint...") 117 | urllib.request.urlretrieve( 118 | "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth", 119 | checkpoint, 120 | ) 121 | print(checkpoint.name, " is downloaded!") 122 | elif checkpoint.name == "sam_vit_h_4b8939.pth" and not checkpoint.exists(): 123 | cmd = input("Download sam_vit_h_4b8939.pth from facebook AI? 
[y]/n: ") 124 | if len(cmd) == 0 or cmd.lower() == 'y': 125 | checkpoint.parent.mkdir(parents=True, exist_ok=True) 126 | print("Downloading SAM ViT-H checkpoint...") 127 | urllib.request.urlretrieve( 128 | "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth", 129 | checkpoint, 130 | ) 131 | print(checkpoint.name, " is downloaded!") 132 | elif checkpoint.name == "sam_vit_l_0b3195.pth" and not checkpoint.exists(): 133 | cmd = input("Download sam_vit_l_0b3195.pth from facebook AI? [y]/n: ") 134 | if len(cmd) == 0 or cmd.lower() == 'y': 135 | checkpoint.parent.mkdir(parents=True, exist_ok=True) 136 | print("Downloading SAM ViT-L checkpoint...") 137 | urllib.request.urlretrieve( 138 | "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_l_0b3195.pth", 139 | checkpoint, 140 | ) 141 | print(checkpoint.name, " is downloaded!") 142 | 143 | 144 | if checkpoint is not None: 145 | with open(checkpoint, "rb") as f: 146 | state_dict = torch.load(f) 147 | sam.load_state_dict(state_dict) 148 | return sam 149 | -------------------------------------------------------------------------------- /segment_anything/modeling/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | from .sam import Sam 8 | from .image_encoder import ImageEncoderViT 9 | from .mask_decoder import MaskDecoder 10 | from .prompt_encoder import PromptEncoder 11 | from .transformer import TwoWayTransformer 12 | -------------------------------------------------------------------------------- /segment_anything/modeling/__pycache__/__init__.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cshan-github/TongueSAM/3f6e8c620e4d89e669a92a22f3be3d007932ce45/segment_anything/modeling/__pycache__/__init__.cpython-310.pyc -------------------------------------------------------------------------------- /segment_anything/modeling/__pycache__/__init__.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cshan-github/TongueSAM/3f6e8c620e4d89e669a92a22f3be3d007932ce45/segment_anything/modeling/__pycache__/__init__.cpython-39.pyc -------------------------------------------------------------------------------- /segment_anything/modeling/__pycache__/common.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cshan-github/TongueSAM/3f6e8c620e4d89e669a92a22f3be3d007932ce45/segment_anything/modeling/__pycache__/common.cpython-310.pyc -------------------------------------------------------------------------------- /segment_anything/modeling/__pycache__/common.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cshan-github/TongueSAM/3f6e8c620e4d89e669a92a22f3be3d007932ce45/segment_anything/modeling/__pycache__/common.cpython-39.pyc -------------------------------------------------------------------------------- /segment_anything/modeling/__pycache__/image_encoder.cpython-310.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/cshan-github/TongueSAM/3f6e8c620e4d89e669a92a22f3be3d007932ce45/segment_anything/modeling/__pycache__/image_encoder.cpython-310.pyc -------------------------------------------------------------------------------- /segment_anything/modeling/__pycache__/image_encoder.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cshan-github/TongueSAM/3f6e8c620e4d89e669a92a22f3be3d007932ce45/segment_anything/modeling/__pycache__/image_encoder.cpython-39.pyc -------------------------------------------------------------------------------- /segment_anything/modeling/__pycache__/mask_decoder.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cshan-github/TongueSAM/3f6e8c620e4d89e669a92a22f3be3d007932ce45/segment_anything/modeling/__pycache__/mask_decoder.cpython-310.pyc -------------------------------------------------------------------------------- /segment_anything/modeling/__pycache__/mask_decoder.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cshan-github/TongueSAM/3f6e8c620e4d89e669a92a22f3be3d007932ce45/segment_anything/modeling/__pycache__/mask_decoder.cpython-39.pyc -------------------------------------------------------------------------------- /segment_anything/modeling/__pycache__/prompt_encoder.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cshan-github/TongueSAM/3f6e8c620e4d89e669a92a22f3be3d007932ce45/segment_anything/modeling/__pycache__/prompt_encoder.cpython-310.pyc -------------------------------------------------------------------------------- /segment_anything/modeling/__pycache__/prompt_encoder.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cshan-github/TongueSAM/3f6e8c620e4d89e669a92a22f3be3d007932ce45/segment_anything/modeling/__pycache__/prompt_encoder.cpython-39.pyc -------------------------------------------------------------------------------- /segment_anything/modeling/__pycache__/sam.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cshan-github/TongueSAM/3f6e8c620e4d89e669a92a22f3be3d007932ce45/segment_anything/modeling/__pycache__/sam.cpython-310.pyc -------------------------------------------------------------------------------- /segment_anything/modeling/__pycache__/sam.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cshan-github/TongueSAM/3f6e8c620e4d89e669a92a22f3be3d007932ce45/segment_anything/modeling/__pycache__/sam.cpython-39.pyc -------------------------------------------------------------------------------- /segment_anything/modeling/__pycache__/transformer.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cshan-github/TongueSAM/3f6e8c620e4d89e669a92a22f3be3d007932ce45/segment_anything/modeling/__pycache__/transformer.cpython-310.pyc -------------------------------------------------------------------------------- /segment_anything/modeling/__pycache__/transformer.cpython-39.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/cshan-github/TongueSAM/3f6e8c620e4d89e669a92a22f3be3d007932ce45/segment_anything/modeling/__pycache__/transformer.cpython-39.pyc -------------------------------------------------------------------------------- /segment_anything/modeling/common.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import torch 8 | import torch.nn as nn 9 | 10 | from typing import Type 11 | 12 | 13 | class MLPBlock(nn.Module): 14 | def __init__( 15 | self, 16 | embedding_dim: int, 17 | mlp_dim: int, 18 | act: Type[nn.Module] = nn.GELU, 19 | ) -> None: 20 | super().__init__() 21 | self.lin1 = nn.Linear(embedding_dim, mlp_dim) 22 | self.lin2 = nn.Linear(mlp_dim, embedding_dim) 23 | self.act = act() 24 | 25 | def forward(self, x: torch.Tensor) -> torch.Tensor: 26 | return self.lin2(self.act(self.lin1(x))) 27 | 28 | 29 | # From https://github.com/facebookresearch/detectron2/blob/main/detectron2/layers/batch_norm.py # noqa 30 | # Itself from https://github.com/facebookresearch/ConvNeXt/blob/d1fa8f6fef0a165b27399986cc2bdacc92777e40/models/convnext.py#L119 # noqa 31 | class LayerNorm2d(nn.Module): 32 | def __init__(self, num_channels: int, eps: float = 1e-6) -> None: 33 | super().__init__() 34 | self.weight = nn.Parameter(torch.ones(num_channels)) 35 | self.bias = nn.Parameter(torch.zeros(num_channels)) 36 | self.eps = eps 37 | 38 | def forward(self, x: torch.Tensor) -> torch.Tensor: 39 | u = x.mean(1, keepdim=True) 40 | s = (x - u).pow(2).mean(1, keepdim=True) 41 | x = (x - u) / torch.sqrt(s + self.eps) 42 | x = self.weight[:, None, None] * x + self.bias[:, None, None] 43 | return x 44 | -------------------------------------------------------------------------------- /segment_anything/modeling/mask_decoder.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import torch 8 | from torch import nn 9 | from torch.nn import functional as F 10 | 11 | from typing import List, Tuple, Type 12 | 13 | from .common import LayerNorm2d 14 | 15 | 16 | class MaskDecoder(nn.Module): 17 | def __init__( 18 | self, 19 | *, 20 | transformer_dim: int, 21 | transformer: nn.Module, 22 | num_multimask_outputs: int = 3, 23 | activation: Type[nn.Module] = nn.GELU, 24 | iou_head_depth: int = 3, 25 | iou_head_hidden_dim: int = 400, 26 | ) -> None: 27 | """ 28 | Predicts masks given an image and prompt embeddings, using a 29 | transformer architecture. 
30 | 31 | Arguments: 32 | transformer_dim (int): the channel dimension of the transformer 33 | transformer (nn.Module): the transformer used to predict masks 34 | num_multimask_outputs (int): the number of masks to predict 35 | when disambiguating masks 36 | activation (nn.Module): the type of activation to use when 37 | upscaling masks 38 | iou_head_depth (int): the depth of the MLP used to predict 39 | mask quality 40 | iou_head_hidden_dim (int): the hidden dimension of the MLP 41 | used to predict mask quality 42 | """ 43 | super().__init__() 44 | self.transformer_dim = transformer_dim 45 | self.transformer = transformer 46 | 47 | self.num_multimask_outputs = num_multimask_outputs 48 | 49 | self.iou_token = nn.Embedding(1, transformer_dim) 50 | self.num_mask_tokens = num_multimask_outputs + 1 51 | self.mask_tokens = nn.Embedding(self.num_mask_tokens, transformer_dim) 52 | 53 | self.output_upscaling = nn.Sequential( 54 | nn.ConvTranspose2d(transformer_dim, transformer_dim // 4, kernel_size=2, stride=2), 55 | LayerNorm2d(transformer_dim // 4), 56 | activation(), 57 | nn.ConvTranspose2d(transformer_dim // 4, transformer_dim // 8, kernel_size=2, stride=2), 58 | activation(), 59 | ) 60 | self.output_hypernetworks_mlps = nn.ModuleList( 61 | [ 62 | MLP(transformer_dim, transformer_dim, transformer_dim // 8, 3) 63 | for i in range(self.num_mask_tokens) 64 | ] 65 | ) 66 | 67 | self.iou_prediction_head = MLP( 68 | transformer_dim, iou_head_hidden_dim, self.num_mask_tokens, iou_head_depth 69 | ) 70 | 71 | def forward( 72 | self, 73 | image_embeddings: torch.Tensor, 74 | image_pe: torch.Tensor, 75 | sparse_prompt_embeddings: torch.Tensor, 76 | dense_prompt_embeddings: torch.Tensor, 77 | multimask_output: bool, 78 | ) -> Tuple[torch.Tensor, torch.Tensor]: 79 | """ 80 | Predict masks given image and prompt embeddings. 81 | 82 | Arguments: 83 | image_embeddings (torch.Tensor): the embeddings from the image encoder 84 | image_pe (torch.Tensor): positional encoding with the shape of image_embeddings 85 | sparse_prompt_embeddings (torch.Tensor): the embeddings of the points and boxes 86 | dense_prompt_embeddings (torch.Tensor): the embeddings of the mask inputs 87 | multimask_output (bool): Whether to return multiple masks or a single 88 | mask. 89 | 90 | Returns: 91 | torch.Tensor: batched predicted masks 92 | torch.Tensor: batched predictions of mask quality 93 | """ 94 | masks, iou_pred = self.predict_masks( 95 | image_embeddings=image_embeddings, 96 | image_pe=image_pe, 97 | sparse_prompt_embeddings=sparse_prompt_embeddings, 98 | dense_prompt_embeddings=dense_prompt_embeddings, 99 | ) 100 | 101 | # Select the correct mask or masks for output 102 | if multimask_output: 103 | mask_slice = slice(1, None) 104 | else: 105 | mask_slice = slice(0, 1) 106 | masks = masks[:, mask_slice, :, :] 107 | iou_pred = iou_pred[:, mask_slice] 108 | masks=torch.sigmoid(masks.view(masks.shape[0], 1, -1)).view(masks.shape[0], 1, 256, 256) 109 | return masks, iou_pred 110 | 111 | def predict_masks( 112 | self, 113 | image_embeddings: torch.Tensor, 114 | image_pe: torch.Tensor, 115 | sparse_prompt_embeddings: torch.Tensor, 116 | dense_prompt_embeddings: torch.Tensor, 117 | ) -> Tuple[torch.Tensor, torch.Tensor]: 118 | """Predicts masks. 
See 'forward' for more details.""" 119 | # Concatenate output tokens 120 | output_tokens = torch.cat([self.iou_token.weight, self.mask_tokens.weight], dim=0) 121 | output_tokens = output_tokens.unsqueeze(0).expand(sparse_prompt_embeddings.size(0), -1, -1) 122 | tokens = torch.cat((output_tokens, sparse_prompt_embeddings), dim=1) 123 | 124 | # Expand per-image data in batch direction to be per-mask 125 | if image_embeddings.shape[0] != tokens.shape[0]: 126 | src = torch.repeat_interleave(image_embeddings, tokens.shape[0], dim=0) 127 | else: 128 | src = image_embeddings 129 | src = src + dense_prompt_embeddings 130 | pos_src = torch.repeat_interleave(image_pe, tokens.shape[0], dim=0) 131 | b, c, h, w = src.shape 132 | 133 | # Run the transformer 134 | hs, src = self.transformer(src, pos_src, tokens) 135 | iou_token_out = hs[:, 0, :] 136 | mask_tokens_out = hs[:, 1 : (1 + self.num_mask_tokens), :] 137 | 138 | # Upscale mask embeddings and predict masks using the mask tokens 139 | src = src.transpose(1, 2).view(b, c, h, w) 140 | upscaled_embedding = self.output_upscaling(src) 141 | hyper_in_list: List[torch.Tensor] = [] 142 | for i in range(self.num_mask_tokens): 143 | hyper_in_list.append(self.output_hypernetworks_mlps[i](mask_tokens_out[:, i, :])) 144 | hyper_in = torch.stack(hyper_in_list, dim=1) 145 | b, c, h, w = upscaled_embedding.shape 146 | masks = (hyper_in @ upscaled_embedding.view(b, c, h * w)).view(b, -1, h, w) 147 | # Generate mask quality predictions 148 | iou_pred = self.iou_prediction_head(iou_token_out) 149 | 150 | return masks, iou_pred 151 | 152 | 153 | # Lightly adapted from 154 | # https://github.com/facebookresearch/MaskFormer/blob/main/mask_former/modeling/transformer/transformer_predictor.py # noqa 155 | class MLP(nn.Module): 156 | def __init__( 157 | self, 158 | input_dim: int, 159 | hidden_dim: int, 160 | output_dim: int, 161 | num_layers: int, 162 | sigmoid_output: bool = False, 163 | ) -> None: 164 | super().__init__() 165 | self.num_layers = num_layers 166 | h = [hidden_dim] * (num_layers - 1) 167 | self.layers = nn.ModuleList( 168 | nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]) 169 | ) 170 | self.sigmoid_output = sigmoid_output 171 | 172 | def forward(self, x): 173 | for i, layer in enumerate(self.layers): 174 | x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x) 175 | if self.sigmoid_output: 176 | x = F.sigmoid(x) 177 | return x 178 | -------------------------------------------------------------------------------- /segment_anything/modeling/prompt_encoder.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import numpy as np 8 | import torch 9 | from torch import nn 10 | 11 | from typing import Any, Optional, Tuple, Type 12 | 13 | from .common import LayerNorm2d 14 | 15 | 16 | class PromptEncoder(nn.Module): 17 | def __init__( 18 | self, 19 | embed_dim: int, 20 | image_embedding_size: Tuple[int, int], 21 | input_image_size: Tuple[int, int], 22 | mask_in_chans: int, 23 | activation: Type[nn.Module] = nn.GELU, 24 | ) -> None: 25 | """ 26 | Encodes prompts for input to SAM's mask decoder. 27 | 28 | Arguments: 29 | embed_dim (int): The prompts' embedding dimension 30 | image_embedding_size (tuple(int, int)): The spatial size of the 31 | image embedding, as (H, W). 
32 | input_image_size (int): The padded size of the image as input 33 | to the image encoder, as (H, W). 34 | mask_in_chans (int): The number of hidden channels used for 35 | encoding input masks. 36 | activation (nn.Module): The activation to use when encoding 37 | input masks. 38 | """ 39 | super().__init__() 40 | self.embed_dim = embed_dim 41 | self.input_image_size = input_image_size 42 | self.image_embedding_size = image_embedding_size 43 | self.pe_layer = PositionEmbeddingRandom(embed_dim // 2) 44 | 45 | self.num_point_embeddings: int = 4 # pos/neg point + 2 box corners 46 | point_embeddings = [nn.Embedding(1, embed_dim) for i in range(self.num_point_embeddings)] 47 | self.point_embeddings = nn.ModuleList(point_embeddings) 48 | self.not_a_point_embed = nn.Embedding(1, embed_dim) 49 | 50 | self.mask_input_size = (4 * image_embedding_size[0], 4 * image_embedding_size[1]) 51 | self.mask_downscaling = nn.Sequential( 52 | nn.Conv2d(1, mask_in_chans // 4, kernel_size=2, stride=2), 53 | LayerNorm2d(mask_in_chans // 4), 54 | activation(), 55 | nn.Conv2d(mask_in_chans // 4, mask_in_chans, kernel_size=2, stride=2), 56 | LayerNorm2d(mask_in_chans), 57 | activation(), 58 | nn.Conv2d(mask_in_chans, embed_dim, kernel_size=1), 59 | ) 60 | self.no_mask_embed = nn.Embedding(1, embed_dim) 61 | 62 | def get_dense_pe(self) -> torch.Tensor: 63 | """ 64 | Returns the positional encoding used to encode point prompts, 65 | applied to a dense set of points the shape of the image encoding. 66 | 67 | Returns: 68 | torch.Tensor: Positional encoding with shape 69 | 1x(embed_dim)x(embedding_h)x(embedding_w) 70 | """ 71 | return self.pe_layer(self.image_embedding_size).unsqueeze(0) 72 | 73 | def _embed_points( 74 | self, 75 | points: torch.Tensor, 76 | labels: torch.Tensor, 77 | pad: bool, 78 | ) -> torch.Tensor: 79 | """Embeds point prompts.""" 80 | points = points + 0.5 # Shift to center of pixel 81 | if pad: 82 | padding_point = torch.zeros((points.shape[0], 1, 2), device=points.device) 83 | padding_label = -torch.ones((labels.shape[0], 1), device=labels.device) 84 | points = torch.cat([points, padding_point], dim=1) 85 | labels=torch.squeeze(labels) 86 | labels = torch.cat([labels, padding_label], dim=1) 87 | point_embedding = self.pe_layer.forward_with_coords(points, self.input_image_size) 88 | point_embedding[labels == -1] = 0.0 89 | point_embedding[labels == -1] += self.not_a_point_embed.weight 90 | point_embedding[labels == 0] += self.point_embeddings[0].weight 91 | point_embedding[labels == 1] += self.point_embeddings[1].weight 92 | return point_embedding 93 | 94 | def _embed_boxes(self, boxes: torch.Tensor) -> torch.Tensor: 95 | """Embeds box prompts.""" 96 | boxes = boxes + 0.5 # Shift to center of pixel 97 | coords = boxes.reshape(-1, 2, 2) 98 | corner_embedding = self.pe_layer.forward_with_coords(coords, self.input_image_size) 99 | corner_embedding[:, 0, :] += self.point_embeddings[2].weight 100 | corner_embedding[:, 1, :] += self.point_embeddings[3].weight 101 | return corner_embedding 102 | 103 | def _embed_masks(self, masks: torch.Tensor) -> torch.Tensor: 104 | """Embeds mask inputs.""" 105 | mask_embedding = self.mask_downscaling(masks) 106 | return mask_embedding 107 | 108 | def _get_batch_size( 109 | self, 110 | points: Optional[Tuple[torch.Tensor, torch.Tensor]], 111 | boxes: Optional[torch.Tensor], 112 | masks: Optional[torch.Tensor], 113 | ) -> int: 114 | """ 115 | Gets the batch size of the output given the batch size of the input prompts. 
116 | """ 117 | if points is not None: 118 | return points.shape[0] 119 | elif boxes is not None: 120 | return boxes.shape[0] 121 | elif masks is not None: 122 | return masks.shape[0] 123 | else: 124 | return 1 125 | 126 | def _get_device(self) -> torch.device: 127 | return self.point_embeddings[0].weight.device 128 | 129 | def forward( 130 | self, 131 | points: Optional[Tuple[torch.Tensor, torch.Tensor]], 132 | boxes: Optional[torch.Tensor], 133 | masks: Optional[torch.Tensor], 134 | ) -> Tuple[torch.Tensor, torch.Tensor]: 135 | """ 136 | Embeds different types of prompts, returning both sparse and dense 137 | embeddings. 138 | 139 | Arguments: 140 | points (tuple(torch.Tensor, torch.Tensor) or none): point coordinates 141 | and labels to embed. 142 | boxes (torch.Tensor or none): boxes to embed 143 | masks (torch.Tensor or none): masks to embed 144 | 145 | Returns: 146 | torch.Tensor: sparse embeddings for the points and boxes, with shape 147 | BxNx(embed_dim), where N is determined by the number of input points 148 | and boxes. 149 | torch.Tensor: dense embeddings for the masks, in the shape 150 | Bx(embed_dim)x(embed_H)x(embed_W) 151 | """ 152 | bs = self._get_batch_size(points, boxes, masks) 153 | sparse_embeddings = torch.empty((bs, 0, self.embed_dim), device=self._get_device()) 154 | if points is not None: 155 | coords=points.to('cuda:1') 156 | labels=torch.ones([points.shape[0],points.shape[1],1]).to('cuda:1') 157 | point_embeddings = self._embed_points(coords, labels, pad=(boxes is None)) 158 | sparse_embeddings = torch.cat([sparse_embeddings, point_embeddings], dim=1) 159 | if boxes is not None: 160 | box_embeddings = self._embed_boxes(boxes) 161 | sparse_embeddings = torch.cat([sparse_embeddings, box_embeddings], dim=1) 162 | 163 | if masks is not None: 164 | dense_embeddings = self._embed_masks(masks) 165 | else: 166 | dense_embeddings = self.no_mask_embed.weight.reshape(1, -1, 1, 1).expand( 167 | bs, -1, self.image_embedding_size[0], self.image_embedding_size[1] 168 | ) 169 | 170 | return sparse_embeddings, dense_embeddings 171 | 172 | 173 | class PositionEmbeddingRandom(nn.Module): 174 | """ 175 | Positional encoding using random spatial frequencies. 176 | """ 177 | 178 | def __init__(self, num_pos_feats: int = 64, scale: Optional[float] = None) -> None: 179 | super().__init__() 180 | if scale is None or scale <= 0.0: 181 | scale = 1.0 182 | self.register_buffer( 183 | "positional_encoding_gaussian_matrix", 184 | scale * torch.randn((2, num_pos_feats)), 185 | ) 186 | 187 | def _pe_encoding(self, coords: torch.Tensor) -> torch.Tensor: 188 | """Positionally encode points that are normalized to [0,1].""" 189 | # assuming coords are in [0, 1]^2 square and have d_1 x ... x d_n x 2 shape 190 | coords = 2 * coords - 1 191 | coords = coords @ self.positional_encoding_gaussian_matrix 192 | coords = 2 * np.pi * coords 193 | # outputs d_1 x ... 
x d_n x C shape 194 | return torch.cat([torch.sin(coords), torch.cos(coords)], dim=-1) 195 | 196 | def forward(self, size: Tuple[int, int]) -> torch.Tensor: 197 | """Generate positional encoding for a grid of the specified size.""" 198 | h, w = size 199 | device: Any = self.positional_encoding_gaussian_matrix.device 200 | grid = torch.ones((h, w), device=device, dtype=torch.float32) 201 | y_embed = grid.cumsum(dim=0) - 0.5 202 | x_embed = grid.cumsum(dim=1) - 0.5 203 | y_embed = y_embed / h 204 | x_embed = x_embed / w 205 | 206 | pe = self._pe_encoding(torch.stack([x_embed, y_embed], dim=-1)) 207 | return pe.permute(2, 0, 1) # C x H x W 208 | 209 | def forward_with_coords( 210 | self, coords_input: torch.Tensor, image_size: Tuple[int, int] 211 | ) -> torch.Tensor: 212 | """Positionally encode points that are not normalized to [0,1].""" 213 | coords = coords_input.clone() 214 | coords[:, :, 0] = coords[:, :, 0] / image_size[1] 215 | coords[:, :, 1] = coords[:, :, 1] / image_size[0] 216 | return self._pe_encoding(coords.to(torch.float)) # B x N x C 217 | -------------------------------------------------------------------------------- /segment_anything/modeling/sam.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import torch 8 | from torch import nn 9 | from torch.nn import functional as F 10 | 11 | from typing import Any, Dict, List, Tuple 12 | 13 | from .image_encoder import ImageEncoderViT 14 | from .mask_decoder import MaskDecoder 15 | from .prompt_encoder import PromptEncoder 16 | 17 | 18 | class Sam(nn.Module): 19 | mask_threshold: float = 0.0 20 | image_format: str = "RGB" 21 | 22 | def __init__( 23 | self, 24 | image_encoder: ImageEncoderViT, 25 | prompt_encoder: PromptEncoder, 26 | mask_decoder: MaskDecoder, 27 | pixel_mean: List[float] = [123.675, 116.28, 103.53], 28 | pixel_std: List[float] = [58.395, 57.12, 57.375], 29 | ) -> None: 30 | """ 31 | SAM predicts object masks from an image and input prompts. 32 | 33 | Arguments: 34 | image_encoder (ImageEncoderViT): The backbone used to encode the 35 | image into image embeddings that allow for efficient mask prediction. 36 | prompt_encoder (PromptEncoder): Encodes various types of input prompts. 37 | mask_decoder (MaskDecoder): Predicts masks from the image embeddings 38 | and encoded prompts. 39 | pixel_mean (list(float)): Mean values for normalizing pixels in the input image. 40 | pixel_std (list(float)): Std values for normalizing pixels in the input image. 41 | """ 42 | super().__init__() 43 | self.image_encoder = image_encoder 44 | self.prompt_encoder = prompt_encoder 45 | self.mask_decoder = mask_decoder 46 | self.register_buffer("pixel_mean", torch.Tensor(pixel_mean).view(-1, 1, 1), False) 47 | self.register_buffer("pixel_std", torch.Tensor(pixel_std).view(-1, 1, 1), False) 48 | 49 | @property 50 | def device(self) -> Any: 51 | return self.pixel_mean.device 52 | 53 | @torch.no_grad() 54 | def forward( 55 | self, 56 | batched_input: List[Dict[str, Any]], 57 | multimask_output: bool, 58 | ) -> List[Dict[str, torch.Tensor]]: 59 | """ 60 | Predicts masks end-to-end from provided images and prompts. 61 | If prompts are not known in advance, using SamPredictor is 62 | recommended over calling the model directly. 
63 | 64 | Arguments: 65 | batched_input (list(dict)): A list over input images, each a 66 | dictionary with the following keys. A prompt key can be 67 | excluded if it is not present. 68 | 'image': The image as a torch tensor in 3xHxW format, 69 | already transformed for input to the model. 70 | 'original_size': (tuple(int, int)) The original size of 71 | the image before transformation, as (H, W). 72 | 'point_coords': (torch.Tensor) Batched point prompts for 73 | this image, with shape BxNx2. Already transformed to the 74 | input frame of the model. 75 | 'point_labels': (torch.Tensor) Batched labels for point prompts, 76 | with shape BxN. 77 | 'boxes': (torch.Tensor) Batched box inputs, with shape Bx4. 78 | Already transformed to the input frame of the model. 79 | 'mask_inputs': (torch.Tensor) Batched mask inputs to the model, 80 | in the form Bx1xHxW. 81 | multimask_output (bool): Whether the model should predict multiple 82 | disambiguating masks, or return a single mask. 83 | 84 | Returns: 85 | (list(dict)): A list over input images, where each element is 86 | as dictionary with the following keys. 87 | 'masks': (torch.Tensor) Batched binary mask predictions, 88 | with shape BxCxHxW, where B is the number of input prompts, 89 | C is determined by multimask_output, and (H, W) is the 90 | original size of the image. 91 | 'iou_predictions': (torch.Tensor) The model's predictions 92 | of mask quality, in shape BxC. 93 | 'low_res_logits': (torch.Tensor) Low resolution logits with 94 | shape BxCxHxW, where H=W=256. Can be passed as mask input 95 | to subsequent iterations of prediction. 96 | """ 97 | input_images = torch.stack([self.preprocess(x["image"]) for x in batched_input], dim=0) 98 | image_embeddings = self.image_encoder(input_images) 99 | 100 | outputs = [] 101 | for image_record, curr_embedding in zip(batched_input, image_embeddings): 102 | if "point_coords" in image_record: 103 | points = (image_record["point_coords"], image_record["point_labels"]) 104 | else: 105 | points = None 106 | sparse_embeddings, dense_embeddings = self.prompt_encoder( 107 | 108 | points=points, 109 | boxes=image_record.get("boxes", None), 110 | masks=image_record.get("mask_inputs", None), 111 | ) 112 | low_res_masks, iou_predictions = self.mask_decoder( 113 | image_embeddings=curr_embedding.unsqueeze(0), 114 | image_pe=self.prompt_encoder.get_dense_pe(), 115 | sparse_prompt_embeddings=sparse_embeddings, 116 | dense_prompt_embeddings=dense_embeddings, 117 | multimask_output=multimask_output, 118 | ) 119 | masks = self.postprocess_masks( 120 | low_res_masks, 121 | input_size=image_record["image"].shape[-2:], 122 | original_size=image_record["original_size"], 123 | ) 124 | masks = masks > self.mask_threshold 125 | outputs.append( 126 | { 127 | "masks": masks, 128 | "iou_predictions": iou_predictions, 129 | "low_res_logits": low_res_masks, 130 | } 131 | ) 132 | return outputs 133 | 134 | def postprocess_masks( 135 | self, 136 | masks: torch.Tensor, 137 | input_size: Tuple[int, ...], 138 | original_size: Tuple[int, ...], 139 | ) -> torch.Tensor: 140 | """ 141 | Remove padding and upscale masks to the original image size. 142 | 143 | Arguments: 144 | masks (torch.Tensor): Batched masks from the mask_decoder, 145 | in BxCxHxW format. 146 | input_size (tuple(int, int)): The size of the image input to the 147 | model, in (H, W) format. Used to remove padding. 148 | original_size (tuple(int, int)): The original size of the image 149 | before resizing for input to the model, in (H, W) format. 
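            For example (illustrative numbers): with input_size=(683, 1024) and
            original_size=(800, 1200), masks are first upsampled to the padded
            (1024, 1024) encoder resolution, cropped to (683, 1024) to drop the
            padding, and then resized to (800, 1200).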
150 | 151 | Returns: 152 | (torch.Tensor): Batched masks in BxCxHxW format, where (H, W) 153 | is given by original_size. 154 | """ 155 | masks = F.interpolate( 156 | masks, 157 | (self.image_encoder.img_size, self.image_encoder.img_size), 158 | mode="bilinear", 159 | align_corners=False, 160 | ) 161 | masks = masks[..., : input_size[0], : input_size[1]] 162 | masks = F.interpolate(masks, original_size, mode="bilinear", align_corners=False) 163 | return masks 164 | 165 | def preprocess(self, x: torch.Tensor) -> torch.Tensor: 166 | """Normalize pixel values and pad to a square input.""" 167 | # Normalize colors 168 | x = (x - self.pixel_mean) / self.pixel_std 169 | 170 | # Pad 171 | h, w = x.shape[-2:] 172 | padh = self.image_encoder.img_size - h 173 | padw = self.image_encoder.img_size - w 174 | x = F.pad(x, (0, padw, 0, padh)) 175 | return x 176 | -------------------------------------------------------------------------------- /segment_anything/modeling/transformer.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import torch 8 | from torch import Tensor, nn 9 | 10 | import math 11 | from typing import Tuple, Type 12 | 13 | from .common import MLPBlock 14 | 15 | 16 | class TwoWayTransformer(nn.Module): 17 | def __init__( 18 | self, 19 | depth: int, 20 | embedding_dim: int, 21 | num_heads: int, 22 | mlp_dim: int, 23 | activation: Type[nn.Module] = nn.ReLU, 24 | attention_downsample_rate: int = 2, 25 | ) -> None: 26 | """ 27 | A transformer decoder that attends to an input image using 28 | queries whose positional embedding is supplied. 29 | 30 | Args: 31 | depth (int): number of layers in the transformer 32 | embedding_dim (int): the channel dimension for the input embeddings 33 | num_heads (int): the number of heads for multihead attention. Must 34 | divide embedding_dim 35 | mlp_dim (int): the channel dimension internal to the MLP block 36 | activation (nn.Module): the activation to use in the MLP block 37 | """ 38 | super().__init__() 39 | self.depth = depth 40 | self.embedding_dim = embedding_dim 41 | self.num_heads = num_heads 42 | self.mlp_dim = mlp_dim 43 | self.layers = nn.ModuleList() 44 | 45 | for i in range(depth): 46 | self.layers.append( 47 | TwoWayAttentionBlock( 48 | embedding_dim=embedding_dim, 49 | num_heads=num_heads, 50 | mlp_dim=mlp_dim, 51 | activation=activation, 52 | attention_downsample_rate=attention_downsample_rate, 53 | skip_first_layer_pe=(i == 0), 54 | ) 55 | ) 56 | 57 | self.final_attn_token_to_image = Attention( 58 | embedding_dim, num_heads, downsample_rate=attention_downsample_rate 59 | ) 60 | self.norm_final_attn = nn.LayerNorm(embedding_dim) 61 | 62 | def forward( 63 | self, 64 | image_embedding: Tensor, 65 | image_pe: Tensor, 66 | point_embedding: Tensor, 67 | ) -> Tuple[Tensor, Tensor]: 68 | """ 69 | Args: 70 | image_embedding (torch.Tensor): image to attend to. Should be shape 71 | B x embedding_dim x h x w for any h and w. 72 | image_pe (torch.Tensor): the positional encoding to add to the image. Must 73 | have the same shape as image_embedding. 74 | point_embedding (torch.Tensor): the embedding to add to the query points. 75 | Must have shape B x N_points x embedding_dim for any N_points. 
76 | 77 | Returns: 78 | torch.Tensor: the processed point_embedding 79 | torch.Tensor: the processed image_embedding 80 | """ 81 | # BxCxHxW -> BxHWxC == B x N_image_tokens x C 82 | bs, c, h, w = image_embedding.shape 83 | image_embedding = image_embedding.flatten(2).permute(0, 2, 1) 84 | image_pe = image_pe.flatten(2).permute(0, 2, 1) 85 | 86 | # Prepare queries 87 | queries = point_embedding 88 | keys = image_embedding 89 | 90 | # Apply transformer blocks and final layernorm 91 | for layer in self.layers: 92 | queries, keys = layer( 93 | queries=queries, 94 | keys=keys, 95 | query_pe=point_embedding, 96 | key_pe=image_pe, 97 | ) 98 | 99 | # Apply the final attention layer from the points to the image 100 | q = queries + point_embedding 101 | k = keys + image_pe 102 | attn_out = self.final_attn_token_to_image(q=q, k=k, v=keys) 103 | queries = queries + attn_out 104 | queries = self.norm_final_attn(queries) 105 | 106 | return queries, keys 107 | 108 | 109 | class TwoWayAttentionBlock(nn.Module): 110 | def __init__( 111 | self, 112 | embedding_dim: int, 113 | num_heads: int, 114 | mlp_dim: int = 2048, 115 | activation: Type[nn.Module] = nn.ReLU, 116 | attention_downsample_rate: int = 2, 117 | skip_first_layer_pe: bool = False, 118 | ) -> None: 119 | """ 120 | A transformer block with four layers: (1) self-attention of sparse 121 | inputs, (2) cross attention of sparse inputs to dense inputs, (3) mlp 122 | block on sparse inputs, and (4) cross attention of dense inputs to sparse 123 | inputs. 124 | 125 | Arguments: 126 | embedding_dim (int): the channel dimension of the embeddings 127 | num_heads (int): the number of heads in the attention layers 128 | mlp_dim (int): the hidden dimension of the mlp block 129 | activation (nn.Module): the activation of the mlp block 130 | skip_first_layer_pe (bool): skip the PE on the first layer 131 | """ 132 | super().__init__() 133 | self.self_attn = Attention(embedding_dim, num_heads) 134 | self.norm1 = nn.LayerNorm(embedding_dim) 135 | 136 | self.cross_attn_token_to_image = Attention( 137 | embedding_dim, num_heads, downsample_rate=attention_downsample_rate 138 | ) 139 | self.norm2 = nn.LayerNorm(embedding_dim) 140 | 141 | self.mlp = MLPBlock(embedding_dim, mlp_dim, activation) 142 | self.norm3 = nn.LayerNorm(embedding_dim) 143 | 144 | self.norm4 = nn.LayerNorm(embedding_dim) 145 | self.cross_attn_image_to_token = Attention( 146 | embedding_dim, num_heads, downsample_rate=attention_downsample_rate 147 | ) 148 | 149 | self.skip_first_layer_pe = skip_first_layer_pe 150 | 151 | def forward( 152 | self, queries: Tensor, keys: Tensor, query_pe: Tensor, key_pe: Tensor 153 | ) -> Tuple[Tensor, Tensor]: 154 | # Self attention block 155 | if self.skip_first_layer_pe: 156 | queries = self.self_attn(q=queries, k=queries, v=queries) 157 | else: 158 | q = queries + query_pe 159 | attn_out = self.self_attn(q=q, k=q, v=queries) 160 | queries = queries + attn_out 161 | queries = self.norm1(queries) 162 | 163 | # Cross attention block, tokens attending to image embedding 164 | q = queries + query_pe 165 | k = keys + key_pe 166 | attn_out = self.cross_attn_token_to_image(q=q, k=k, v=keys) 167 | queries = queries + attn_out 168 | queries = self.norm2(queries) 169 | 170 | # MLP block 171 | mlp_out = self.mlp(queries) 172 | queries = queries + mlp_out 173 | queries = self.norm3(queries) 174 | 175 | # Cross attention block, image embedding attending to tokens 176 | q = queries + query_pe 177 | k = keys + key_pe 178 | attn_out = self.cross_attn_image_to_token(q=k, k=q, 
v=queries) 179 | keys = keys + attn_out 180 | keys = self.norm4(keys) 181 | 182 | return queries, keys 183 | 184 | 185 | class Attention(nn.Module): 186 | """ 187 | An attention layer that allows for downscaling the size of the embedding 188 | after projection to queries, keys, and values. 189 | """ 190 | 191 | def __init__( 192 | self, 193 | embedding_dim: int, 194 | num_heads: int, 195 | downsample_rate: int = 1, 196 | ) -> None: 197 | super().__init__() 198 | self.embedding_dim = embedding_dim 199 | self.internal_dim = embedding_dim // downsample_rate 200 | self.num_heads = num_heads 201 | assert self.internal_dim % num_heads == 0, "num_heads must divide embedding_dim." 202 | 203 | self.q_proj = nn.Linear(embedding_dim, self.internal_dim) 204 | self.k_proj = nn.Linear(embedding_dim, self.internal_dim) 205 | self.v_proj = nn.Linear(embedding_dim, self.internal_dim) 206 | self.out_proj = nn.Linear(self.internal_dim, embedding_dim) 207 | 208 | def _separate_heads(self, x: Tensor, num_heads: int) -> Tensor: 209 | b, n, c = x.shape 210 | x = x.reshape(b, n, num_heads, c // num_heads) 211 | return x.transpose(1, 2) # B x N_heads x N_tokens x C_per_head 212 | 213 | def _recombine_heads(self, x: Tensor) -> Tensor: 214 | b, n_heads, n_tokens, c_per_head = x.shape 215 | x = x.transpose(1, 2) 216 | return x.reshape(b, n_tokens, n_heads * c_per_head) # B x N_tokens x C 217 | 218 | def forward(self, q: Tensor, k: Tensor, v: Tensor) -> Tensor: 219 | # Input projections 220 | q = self.q_proj(q) 221 | k = self.k_proj(k) 222 | v = self.v_proj(v) 223 | 224 | # Separate into heads 225 | q = self._separate_heads(q, self.num_heads) 226 | k = self._separate_heads(k, self.num_heads) 227 | v = self._separate_heads(v, self.num_heads) 228 | 229 | # Attention 230 | _, _, _, c_per_head = q.shape 231 | attn = q @ k.permute(0, 1, 3, 2) # B x N_heads x N_tokens x N_tokens 232 | attn = attn / math.sqrt(c_per_head) 233 | attn = torch.softmax(attn, dim=-1) 234 | 235 | # Get output 236 | out = attn @ v 237 | out = self._recombine_heads(out) 238 | out = self.out_proj(out) 239 | 240 | return out 241 | -------------------------------------------------------------------------------- /segment_anything/predictor.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import numpy as np 8 | import torch 9 | 10 | from segment_anything.modeling import Sam 11 | 12 | from typing import Optional, Tuple 13 | 14 | from .utils.transforms import ResizeLongestSide 15 | 16 | 17 | class SamPredictor: 18 | def __init__( 19 | self, 20 | sam_model: Sam, 21 | ) -> None: 22 | """ 23 | Uses SAM to calculate the image embedding for an image, and then 24 | allow repeated, efficient mask prediction given prompts. 25 | 26 | Arguments: 27 | sam_model (Sam): The model to use for mask prediction. 28 | """ 29 | super().__init__() 30 | self.model = sam_model 31 | self.transform = ResizeLongestSide(sam_model.image_encoder.img_size) 32 | self.reset_image() 33 | 34 | def set_image( 35 | self, 36 | image: np.ndarray, 37 | image_format: str = "RGB", 38 | ) -> None: 39 | """ 40 | Calculates the image embeddings for the provided image, allowing 41 | masks to be predicted with the 'predict' method. 42 | 43 | Arguments: 44 | image (np.ndarray): The image for calculating masks. 
Expects an 45 | image in HWC uint8 format, with pixel values in [0, 255]. 46 | image_format (str): The color format of the image, in ['RGB', 'BGR']. 47 | """ 48 | assert image_format in [ 49 | "RGB", 50 | "BGR", 51 | ], f"image_format must be in ['RGB', 'BGR'], is {image_format}." 52 | if image_format != self.model.image_format: 53 | image = image[..., ::-1] 54 | 55 | # Transform the image to the form expected by the model 56 | input_image = self.transform.apply_image(image) 57 | input_image_torch = torch.as_tensor(input_image, device=self.device) 58 | input_image_torch = input_image_torch.permute(2, 0, 1).contiguous()[None, :, :, :] 59 | 60 | self.set_torch_image(input_image_torch, image.shape[:2]) 61 | 62 | @torch.no_grad() 63 | def set_torch_image( 64 | self, 65 | transformed_image: torch.Tensor, 66 | original_image_size: Tuple[int, ...], 67 | ) -> None: 68 | """ 69 | Calculates the image embeddings for the provided image, allowing 70 | masks to be predicted with the 'predict' method. Expects the input 71 | image to be already transformed to the format expected by the model. 72 | 73 | Arguments: 74 | transformed_image (torch.Tensor): The input image, with shape 75 | 1x3xHxW, which has been transformed with ResizeLongestSide. 76 | original_image_size (tuple(int, int)): The size of the image 77 | before transformation, in (H, W) format. 78 | """ 79 | assert ( 80 | len(transformed_image.shape) == 4 81 | and transformed_image.shape[1] == 3 82 | and max(*transformed_image.shape[2:]) == self.model.image_encoder.img_size 83 | ), f"set_torch_image input must be BCHW with long side {self.model.image_encoder.img_size}." 84 | self.reset_image() 85 | 86 | self.original_size = original_image_size 87 | self.input_size = tuple(transformed_image.shape[-2:]) 88 | input_image = self.model.preprocess(transformed_image) 89 | self.features = self.model.image_encoder(input_image) 90 | self.is_image_set = True 91 | 92 | def predict( 93 | self, 94 | point_coords: Optional[np.ndarray] = None, 95 | point_labels: Optional[np.ndarray] = None, 96 | box: Optional[np.ndarray] = None, 97 | mask_input: Optional[np.ndarray] = None, 98 | multimask_output: bool = True, 99 | return_logits: bool = False, 100 | ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: 101 | """ 102 | Predict masks for the given input prompts, using the currently set image. 103 | 104 | Arguments: 105 | point_coords (np.ndarray or None): A Nx2 array of point prompts to the 106 | model. Each point is in (X,Y) in pixels. 107 | point_labels (np.ndarray or None): A length N array of labels for the 108 | point prompts. 1 indicates a foreground point and 0 indicates a 109 | background point. 110 | box (np.ndarray or None): A length 4 array given a box prompt to the 111 | model, in XYXY format. 112 | mask_input (np.ndarray): A low resolution mask input to the model, typically 113 | coming from a previous prediction iteration. Has form 1xHxW, where 114 | for SAM, H=W=256. 115 | multimask_output (bool): If true, the model will return three masks. 116 | For ambiguous input prompts (such as a single click), this will often 117 | produce better masks than a single prediction. If only a single 118 | mask is needed, the model's predicted quality score can be used 119 | to select the best mask. For non-ambiguous prompts, such as multiple 120 | input prompts, multimask_output=False can give better results. 121 | return_logits (bool): If true, returns un-thresholded masks logits 122 | instead of a binary mask. 
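          For example (illustrative): after set_image(image), a single box prompt
          can be run as predict(box=np.array([x_min, y_min, x_max, y_max]),
          multimask_output=False), which returns one mask together with its
          predicted quality score and its low resolution logits.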
123 | 124 | Returns: 125 | (np.ndarray): The output masks in CxHxW format, where C is the 126 | number of masks, and (H, W) is the original image size. 127 | (np.ndarray): An array of length C containing the model's 128 | predictions for the quality of each mask. 129 | (np.ndarray): An array of shape CxHxW, where C is the number 130 | of masks and H=W=256. These low resolution logits can be passed to 131 | a subsequent iteration as mask input. 132 | """ 133 | if not self.is_image_set: 134 | raise RuntimeError("An image must be set with .set_image(...) before mask prediction.") 135 | 136 | # Transform input prompts 137 | coords_torch, labels_torch, box_torch, mask_input_torch = None, None, None, None 138 | if point_coords is not None: 139 | assert ( 140 | point_labels is not None 141 | ), "point_labels must be supplied if point_coords is supplied." 142 | point_coords = self.transform.apply_coords(point_coords, self.original_size) 143 | coords_torch = torch.as_tensor(point_coords, dtype=torch.float, device=self.device) 144 | labels_torch = torch.as_tensor(point_labels, dtype=torch.int, device=self.device) 145 | coords_torch, labels_torch = coords_torch[None, :, :], labels_torch[None, :] 146 | if box is not None: 147 | box = self.transform.apply_boxes(box, self.original_size) 148 | box_torch = torch.as_tensor(box, dtype=torch.float, device=self.device) 149 | box_torch = box_torch[None, :] 150 | if mask_input is not None: 151 | mask_input_torch = torch.as_tensor(mask_input, dtype=torch.float, device=self.device) 152 | mask_input_torch = mask_input_torch[None, :, :, :] 153 | 154 | masks, iou_predictions, low_res_masks = self.predict_torch( 155 | coords_torch, 156 | labels_torch, 157 | box_torch, 158 | mask_input_torch, 159 | multimask_output, 160 | return_logits=return_logits, 161 | ) 162 | 163 | masks_np = masks[0].detach().cpu().numpy() 164 | iou_predictions_np = iou_predictions[0].detach().cpu().numpy() 165 | low_res_masks_np = low_res_masks[0].detach().cpu().numpy() 166 | return masks_np, iou_predictions_np, low_res_masks_np 167 | 168 | @torch.no_grad() 169 | def predict_torch( 170 | self, 171 | point_coords: Optional[torch.Tensor], 172 | point_labels: Optional[torch.Tensor], 173 | boxes: Optional[torch.Tensor] = None, 174 | mask_input: Optional[torch.Tensor] = None, 175 | multimask_output: bool = True, 176 | return_logits: bool = False, 177 | ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: 178 | """ 179 | Predict masks for the given input prompts, using the currently set image. 180 | Input prompts are batched torch tensors and are expected to already be 181 | transformed to the input frame using ResizeLongestSide. 182 | 183 | Arguments: 184 | point_coords (torch.Tensor or None): A BxNx2 array of point prompts to the 185 | model. Each point is in (X,Y) in pixels. 186 | point_labels (torch.Tensor or None): A BxN array of labels for the 187 | point prompts. 1 indicates a foreground point and 0 indicates a 188 | background point. 189 | boxes (np.ndarray or None): A Bx4 array given a box prompt to the 190 | model, in XYXY format. 191 | mask_input (np.ndarray): A low resolution mask input to the model, typically 192 | coming from a previous prediction iteration. Has form Bx1xHxW, where 193 | for SAM, H=W=256. Masks returned by a previous iteration of the 194 | predict method do not need further transformation. 195 | multimask_output (bool): If true, the model will return three masks. 
196 | For ambiguous input prompts (such as a single click), this will often 197 | produce better masks than a single prediction. If only a single 198 | mask is needed, the model's predicted quality score can be used 199 | to select the best mask. For non-ambiguous prompts, such as multiple 200 | input prompts, multimask_output=False can give better results. 201 | return_logits (bool): If true, returns un-thresholded masks logits 202 | instead of a binary mask. 203 | 204 | Returns: 205 | (torch.Tensor): The output masks in BxCxHxW format, where C is the 206 | number of masks, and (H, W) is the original image size. 207 | (torch.Tensor): An array of shape BxC containing the model's 208 | predictions for the quality of each mask. 209 | (torch.Tensor): An array of shape BxCxHxW, where C is the number 210 | of masks and H=W=256. These low res logits can be passed to 211 | a subsequent iteration as mask input. 212 | """ 213 | if not self.is_image_set: 214 | raise RuntimeError("An image must be set with .set_image(...) before mask prediction.") 215 | 216 | if point_coords is not None: 217 | points = (point_coords, point_labels) 218 | else: 219 | points = None 220 | 221 | # Embed prompts 222 | sparse_embeddings, dense_embeddings = self.model.prompt_encoder( 223 | points=points, 224 | boxes=boxes, 225 | masks=mask_input, 226 | ) 227 | 228 | # Predict masks 229 | low_res_masks, iou_predictions = self.model.mask_decoder( 230 | image_embeddings=self.features, 231 | image_pe=self.model.prompt_encoder.get_dense_pe(), 232 | sparse_prompt_embeddings=sparse_embeddings, 233 | dense_prompt_embeddings=dense_embeddings, 234 | multimask_output=multimask_output, 235 | ) 236 | 237 | # Upscale the masks to the original image resolution 238 | masks = self.model.postprocess_masks(low_res_masks, self.input_size, self.original_size) 239 | 240 | if not return_logits: 241 | masks = masks > self.model.mask_threshold 242 | 243 | return masks, iou_predictions, low_res_masks 244 | 245 | def get_image_embedding(self) -> torch.Tensor: 246 | """ 247 | Returns the image embeddings for the currently set image, with 248 | shape 1xCxHxW, where C is the embedding dimension and (H,W) are 249 | the embedding spatial dimension of SAM (typically C=256, H=W=64). 250 | """ 251 | if not self.is_image_set: 252 | raise RuntimeError( 253 | "An image must be set with .set_image(...) to generate an embedding." 254 | ) 255 | assert self.features is not None, "Features must exist if an image has been set." 256 | return self.features 257 | 258 | @property 259 | def device(self) -> torch.device: 260 | return self.model.device 261 | 262 | def reset_image(self) -> None: 263 | """Resets the currently set image.""" 264 | self.is_image_set = False 265 | self.features = None 266 | self.orig_h = None 267 | self.orig_w = None 268 | self.input_h = None 269 | self.input_w = None 270 | -------------------------------------------------------------------------------- /segment_anything/utils/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | -------------------------------------------------------------------------------- /segment_anything/utils/__pycache__/__init__.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cshan-github/TongueSAM/3f6e8c620e4d89e669a92a22f3be3d007932ce45/segment_anything/utils/__pycache__/__init__.cpython-310.pyc -------------------------------------------------------------------------------- /segment_anything/utils/__pycache__/__init__.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cshan-github/TongueSAM/3f6e8c620e4d89e669a92a22f3be3d007932ce45/segment_anything/utils/__pycache__/__init__.cpython-39.pyc -------------------------------------------------------------------------------- /segment_anything/utils/__pycache__/amg.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cshan-github/TongueSAM/3f6e8c620e4d89e669a92a22f3be3d007932ce45/segment_anything/utils/__pycache__/amg.cpython-310.pyc -------------------------------------------------------------------------------- /segment_anything/utils/__pycache__/amg.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cshan-github/TongueSAM/3f6e8c620e4d89e669a92a22f3be3d007932ce45/segment_anything/utils/__pycache__/amg.cpython-39.pyc -------------------------------------------------------------------------------- /segment_anything/utils/__pycache__/transforms.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cshan-github/TongueSAM/3f6e8c620e4d89e669a92a22f3be3d007932ce45/segment_anything/utils/__pycache__/transforms.cpython-310.pyc -------------------------------------------------------------------------------- /segment_anything/utils/__pycache__/transforms.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cshan-github/TongueSAM/3f6e8c620e4d89e669a92a22f3be3d007932ce45/segment_anything/utils/__pycache__/transforms.cpython-39.pyc -------------------------------------------------------------------------------- /segment_anything/utils/onnx.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import torch 8 | import torch.nn as nn 9 | from torch.nn import functional as F 10 | 11 | from typing import Tuple 12 | 13 | from ..modeling import Sam 14 | from .amg import calculate_stability_score 15 | 16 | 17 | class SamOnnxModel(nn.Module): 18 | """ 19 | This model should not be called directly, but is used in ONNX export. 20 | It combines the prompt encoder, mask decoder, and mask postprocessing of Sam, 21 | with some functions modified to enable model tracing. Also supports extra 22 | options controlling what information. See the ONNX export script for details. 
23 | """ 24 | 25 | def __init__( 26 | self, 27 | model: Sam, 28 | return_single_mask: bool, 29 | use_stability_score: bool = False, 30 | return_extra_metrics: bool = False, 31 | ) -> None: 32 | super().__init__() 33 | self.mask_decoder = model.mask_decoder 34 | self.model = model 35 | self.img_size = model.image_encoder.img_size 36 | self.return_single_mask = return_single_mask 37 | self.use_stability_score = use_stability_score 38 | self.stability_score_offset = 1.0 39 | self.return_extra_metrics = return_extra_metrics 40 | 41 | @staticmethod 42 | def resize_longest_image_size( 43 | input_image_size: torch.Tensor, longest_side: int 44 | ) -> torch.Tensor: 45 | input_image_size = input_image_size.to(torch.float32) 46 | scale = longest_side / torch.max(input_image_size) 47 | transformed_size = scale * input_image_size 48 | transformed_size = torch.floor(transformed_size + 0.5).to(torch.int64) 49 | return transformed_size 50 | 51 | def _embed_points(self, point_coords: torch.Tensor, point_labels: torch.Tensor) -> torch.Tensor: 52 | point_coords = point_coords + 0.5 53 | point_coords = point_coords / self.img_size 54 | point_embedding = self.model.prompt_encoder.pe_layer._pe_encoding(point_coords) 55 | point_labels = point_labels.unsqueeze(-1).expand_as(point_embedding) 56 | 57 | point_embedding = point_embedding * (point_labels != -1) 58 | point_embedding = point_embedding + self.model.prompt_encoder.not_a_point_embed.weight * ( 59 | point_labels == -1 60 | ) 61 | 62 | for i in range(self.model.prompt_encoder.num_point_embeddings): 63 | point_embedding = point_embedding + self.model.prompt_encoder.point_embeddings[ 64 | i 65 | ].weight * (point_labels == i) 66 | 67 | return point_embedding 68 | 69 | def _embed_masks(self, input_mask: torch.Tensor, has_mask_input: torch.Tensor) -> torch.Tensor: 70 | mask_embedding = has_mask_input * self.model.prompt_encoder.mask_downscaling(input_mask) 71 | mask_embedding = mask_embedding + ( 72 | 1 - has_mask_input 73 | ) * self.model.prompt_encoder.no_mask_embed.weight.reshape(1, -1, 1, 1) 74 | return mask_embedding 75 | 76 | def mask_postprocessing(self, masks: torch.Tensor, orig_im_size: torch.Tensor) -> torch.Tensor: 77 | masks = F.interpolate( 78 | masks, 79 | size=(self.img_size, self.img_size), 80 | mode="bilinear", 81 | align_corners=False, 82 | ) 83 | 84 | prepadded_size = self.resize_longest_image_size(orig_im_size, self.img_size).to(torch.int64) 85 | masks = masks[..., : prepadded_size[0], : prepadded_size[1]] # type: ignore 86 | 87 | orig_im_size = orig_im_size.to(torch.int64) 88 | h, w = orig_im_size[0], orig_im_size[1] 89 | masks = F.interpolate(masks, size=(h, w), mode="bilinear", align_corners=False) 90 | return masks 91 | 92 | def select_masks( 93 | self, masks: torch.Tensor, iou_preds: torch.Tensor, num_points: int 94 | ) -> Tuple[torch.Tensor, torch.Tensor]: 95 | # Determine if we should return the multiclick mask or not from the number of points. 96 | # The reweighting is used to avoid control flow. 
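        # In detail: score_reweight puts a large weight (1000) on index 0, the token
        # used for the single-mask output. With num_points <= 2 the (num_points - 2.5)
        # factor is negative, so index 0 is heavily penalized and the best multimask
        # candidate (index >= 1) is selected by its predicted IoU; with three or more
        # points the factor is positive and the single-mask output is selected.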
97 | score_reweight = torch.tensor( 98 | [[1000] + [0] * (self.model.mask_decoder.num_mask_tokens - 1)] 99 | ).to(iou_preds.device) 100 | score = iou_preds + (num_points - 2.5) * score_reweight 101 | best_idx = torch.argmax(score, dim=1) 102 | masks = masks[torch.arange(masks.shape[0]), best_idx, :, :].unsqueeze(1) 103 | iou_preds = iou_preds[torch.arange(masks.shape[0]), best_idx].unsqueeze(1) 104 | 105 | return masks, iou_preds 106 | 107 | @torch.no_grad() 108 | def forward( 109 | self, 110 | image_embeddings: torch.Tensor, 111 | point_coords: torch.Tensor, 112 | point_labels: torch.Tensor, 113 | mask_input: torch.Tensor, 114 | has_mask_input: torch.Tensor, 115 | orig_im_size: torch.Tensor, 116 | ): 117 | sparse_embedding = self._embed_points(point_coords, point_labels) 118 | dense_embedding = self._embed_masks(mask_input, has_mask_input) 119 | 120 | masks, scores = self.model.mask_decoder.predict_masks( 121 | image_embeddings=image_embeddings, 122 | image_pe=self.model.prompt_encoder.get_dense_pe(), 123 | sparse_prompt_embeddings=sparse_embedding, 124 | dense_prompt_embeddings=dense_embedding, 125 | ) 126 | 127 | if self.use_stability_score: 128 | scores = calculate_stability_score( 129 | masks, self.model.mask_threshold, self.stability_score_offset 130 | ) 131 | 132 | if self.return_single_mask: 133 | masks, scores = self.select_masks(masks, scores, point_coords.shape[1]) 134 | 135 | upscaled_masks = self.mask_postprocessing(masks, orig_im_size) 136 | 137 | if self.return_extra_metrics: 138 | stability_scores = calculate_stability_score( 139 | upscaled_masks, self.model.mask_threshold, self.stability_score_offset 140 | ) 141 | areas = (upscaled_masks > self.model.mask_threshold).sum(-1).sum(-1) 142 | return upscaled_masks, scores, stability_scores, areas, masks 143 | 144 | return upscaled_masks, scores, masks 145 | -------------------------------------------------------------------------------- /segment_anything/utils/transforms.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import numpy as np 8 | import torch 9 | from torch.nn import functional as F 10 | from torchvision.transforms.functional import resize, to_pil_image # type: ignore 11 | 12 | from copy import deepcopy 13 | from typing import Tuple 14 | 15 | 16 | class ResizeLongestSide: 17 | """ 18 | Resizes images to the longest side 'target_length', as well as provides 19 | methods for resizing coordinates and boxes. Provides methods for 20 | transforming both numpy array and batched torch tensors. 21 | """ 22 | 23 | def __init__(self, target_length: int) -> None: 24 | self.target_length = target_length 25 | 26 | def apply_image(self, image: np.ndarray) -> np.ndarray: 27 | """ 28 | Expects a numpy array with shape HxWxC in uint8 format. 29 | """ 30 | target_size = self.get_preprocess_shape(image.shape[0], image.shape[1], self.target_length) 31 | return np.array(resize(to_pil_image(image), target_size)) 32 | 33 | def apply_coords(self, coords: np.ndarray, original_size: Tuple[int, ...]) -> np.ndarray: 34 | """ 35 | Expects a numpy array of length 2 in the final dimension. Requires the 36 | original image size in (H, W) format. 
37 | """ 38 | old_h, old_w = original_size 39 | new_h, new_w = self.get_preprocess_shape(old_h, old_w, self.target_length) 40 | new_coords = np.empty_like(coords) 41 | new_coords[..., 0] = coords[..., 0] * (new_w / old_w) 42 | new_coords[..., 1] = coords[..., 1] * (new_h / old_h) 43 | return new_coords 44 | 45 | 46 | def apply_boxes(self, boxes: np.ndarray, original_size: Tuple[int, ...]) -> np.ndarray: 47 | """ 48 | Expects a numpy array shape Bx4. Requires the original image size 49 | in (H, W) format. 50 | """ 51 | boxes = self.apply_coords(boxes.reshape(-1, 2, 2), original_size) 52 | return boxes.reshape(-1, 4) 53 | 54 | def apply_image_torch(self, image: torch.Tensor) -> torch.Tensor: 55 | """ 56 | Expects batched images with shape BxCxHxW and float format. This 57 | transformation may not exactly match apply_image. apply_image is 58 | the transformation expected by the model. 59 | """ 60 | # Expects an image in BCHW format. May not exactly match apply_image. 61 | target_size = self.get_preprocess_shape(image.shape[2], image.shape[3], self.target_length) 62 | return F.interpolate( 63 | image, target_size, mode="bilinear", align_corners=False, antialias=True 64 | ) 65 | 66 | def apply_coords_torch( 67 | self, coords: torch.Tensor, original_size: Tuple[int, ...] 68 | ) -> torch.Tensor: 69 | """ 70 | Expects a torch tensor with length 2 in the last dimension. Requires the 71 | original image size in (H, W) format. 72 | """ 73 | old_h, old_w = original_size 74 | new_h, new_w = self.get_preprocess_shape( 75 | original_size[0], original_size[1], self.target_length 76 | ) 77 | coords = deepcopy(coords).to(torch.float) 78 | coords[..., 0] = coords[..., 0] * (new_w / old_w) 79 | coords[..., 1] = coords[..., 1] * (new_h / old_h) 80 | return coords 81 | 82 | def apply_boxes_torch( 83 | self, boxes: torch.Tensor, original_size: Tuple[int, ...] 84 | ) -> torch.Tensor: 85 | """ 86 | Expects a torch tensor with shape Bx4. Requires the original image 87 | size in (H, W) format. 88 | """ 89 | boxes = self.apply_coords_torch(boxes.reshape(-1, 2, 2), original_size) 90 | return boxes.reshape(-1, 4) 91 | 92 | @staticmethod 93 | def get_preprocess_shape(oldh: int, oldw: int, long_side_length: int) -> Tuple[int, int]: 94 | """ 95 | Compute the output size given input size and target long side length. 
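        For example (illustrative): get_preprocess_shape(800, 1200, 1024) returns
        (683, 1024): the scale is 1024 / 1200, and 800 * 1024 / 1200 = 682.67,
        which rounds to 683.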
96 | """ 97 | scale = long_side_length * 1.0 / max(oldh, oldw) 98 | newh, neww = oldh * scale, oldw * scale 99 | neww = int(neww + 0.5) 100 | newh = int(newh + 0.5) 101 | return (newh, neww) 102 | -------------------------------------------------------------------------------- /split.py: -------------------------------------------------------------------------------- 1 | 2 | import os 3 | import shutil 4 | 5 | # 源文件夹路径 6 | src_folder = '/home/disk1/cs/project/dataset/COCO/val_gt/' 7 | 8 | # 目标文件夹路径 9 | dst_folder = '/home/disk1/cs/project/dataset/COCO/split/val_gt/' 10 | 11 | # 每组文件数 12 | group_size = 8 13 | 14 | # 遍历文件夹中所有图片文件,按顺序分组 15 | file_groups = [] 16 | for i, file_name in enumerate(sorted(os.listdir(src_folder))): 17 | # if file_name.endswith('.jpg'): 18 | group_idx = i // group_size 19 | if len(file_groups) <= group_idx: 20 | file_groups.append([]) 21 | file_groups[group_idx].append(file_name) 22 | 23 | # 创建目标文件夹 24 | os.makedirs(dst_folder, exist_ok=True) 25 | 26 | # 将每组文件复制到目标文件夹 27 | for i, group in enumerate(file_groups): 28 | group_folder = os.path.join(dst_folder, f'group_{i}') 29 | os.makedirs(group_folder, exist_ok=True) 30 | for file_name in group: 31 | src_file = os.path.join(src_folder, file_name) 32 | dst_file = os.path.join(group_folder, file_name) 33 | shutil.copy(src_file, dst_file) 34 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import matplotlib.pyplot as plt 3 | import os 4 | from skimage import io 5 | join = os.path.join 6 | from tqdm import tqdm 7 | import torch 8 | from torch.utils.data import Dataset, DataLoader 9 | import monai 10 | from segment_anything import SamPredictor, sam_model_registry 11 | from segment_anything.utils.transforms import ResizeLongestSide 12 | from utils.SurfaceDice import compute_dice_coefficient 13 | import cv2 14 | from sklearn.metrics import accuracy_score, confusion_matrix, precision_score, recall_score, f1_score, jaccard_score 15 | # set seeds 16 | torch.manual_seed(2023) 17 | np.random.seed(2023) 18 | from skimage import io 19 | from utils_metrics import * 20 | from skimage import transform, io, segmentation 21 | from segment.yolox import YOLOX 22 | import random 23 | import math 24 | from functools import partial 25 | ############################################################################################################## 26 | num_epochs = 10 27 | ts_npz_path='/home/cs/project/medsam_tongue/data/tongueset3_npz/test/' 28 | npz_tr_path = '/home/cs/project/medsam_tongue/data/tongue_train_npz/' 29 | model_type = 'vit_b' 30 | checkpoint = '/home/cs/project/medsam_tongue/pretrained_model/final.pth' 31 | device = 'cuda:1' 32 | model_save_path = './logs/' 33 | if_save=False 34 | if_onlytest=True 35 | batch_size=32 36 | prompt_type='no' 37 | lr_decay_type= "cos" 38 | Init_lr= 1e-4 39 | point_num=3 40 | segment=None 41 | ############################################################################################################### 42 | def get_lr_scheduler(lr_decay_type, lr, min_lr, total_iters, warmup_iters_ratio = 0.05, warmup_lr_ratio = 0.1, no_aug_iter_ratio = 0.05, step_num = 10): 43 | def yolox_warm_cos_lr(lr, min_lr, total_iters, warmup_total_iters, warmup_lr_start, no_aug_iter, iters): 44 | if iters <= warmup_total_iters: 45 | # lr = (lr - warmup_lr_start) * iters / float(warmup_total_iters) + warmup_lr_start 46 | lr = (lr - warmup_lr_start) * pow(iters / 
float(warmup_total_iters), 2) + warmup_lr_start 47 | elif iters >= total_iters - no_aug_iter: 48 | lr = min_lr 49 | else: 50 | lr = min_lr + 0.5 * (lr - min_lr) * ( 51 | 1.0 + math.cos(math.pi* (iters - warmup_total_iters) / (total_iters - warmup_total_iters - no_aug_iter)) 52 | ) 53 | return lr 54 | 55 | def step_lr(lr, decay_rate, step_size, iters): 56 | if step_size < 1: 57 | raise ValueError("step_size must above 1.") 58 | n = iters // step_size 59 | out_lr = lr * decay_rate ** n 60 | return out_lr 61 | 62 | if lr_decay_type == "cos": 63 | warmup_total_iters = min(max(warmup_iters_ratio * total_iters, 1), 3) 64 | warmup_lr_start = max(warmup_lr_ratio * lr, 1e-6) 65 | no_aug_iter = min(max(no_aug_iter_ratio * total_iters, 1), 15) 66 | func = partial(yolox_warm_cos_lr ,lr, min_lr, total_iters, warmup_total_iters, warmup_lr_start, no_aug_iter) 67 | else: 68 | decay_rate = (min_lr / lr) ** (1 / (step_num - 1)) 69 | step_size = total_iters / step_num 70 | func = partial(step_lr, lr, decay_rate, step_size) 71 | 72 | return func 73 | #%% create a dataset class to load npz data and return back image embeddings and ground truth 74 | class NpzDataset(Dataset): 75 | def __init__(self, data_root): 76 | self.npz_data=np.load(data_root) 77 | self.ori_gts = self.npz_data['gts'] 78 | self.img_embeddings = self.npz_data['img_embeddings'] 79 | self.imgs=self.npz_data['imgs'] 80 | self.model=segment 81 | self.point_num=point_num 82 | def __len__(self): 83 | return self.ori_gts.shape[0] 84 | 85 | def __getitem__(self, index): 86 | img_embed = self.img_embeddings[index] 87 | gt2D = self.ori_gts[index] 88 | img=self.imgs[index] 89 | H, W = gt2D.shape 90 | 91 | # ############################box############################################################## 92 | if self.model!=None: 93 | img=Image.fromarray(img) 94 | img= self.model.get_miou_png(img) 95 | y_indices, x_indices = np.where(img > 0) 96 | x_min, x_max = np.min(x_indices), np.max(x_indices) 97 | y_min, y_max = np.min(y_indices), np.max(y_indices) 98 | bboxes = np.array([x_min, y_min, x_max, y_max]) 99 | bboxes=np.array([x_min,y_min,x_max,y_max]) 100 | points=np.where(img > 0) 101 | random_points = random.choices(range(len(points[0])), k=self.point_num) 102 | random_points = [(points[0][i], points[1][i]) for i in random_points] 103 | 104 | else: 105 | y_indices, x_indices = np.where(gt2D > 0) 106 | x_min, x_max = np.min(x_indices), np.max(x_indices) 107 | y_min, y_max = np.min(y_indices), np.max(y_indices) 108 | bboxes = np.array([x_min, y_min, x_max, y_max]) 109 | points=np.where(gt2D > 0) 110 | random_points = random.choices(range(len(points[0])), k=self.point_num) 111 | random_points = [(points[0][i], points[1][i]) for i in random_points] 112 | 113 | return torch.tensor(img_embed).float(), torch.tensor(gt2D[None, :,:]).long(), torch.tensor(bboxes).float(),torch.tensor(img).float(),torch.tensor(random_points).float() 114 | #####################################################Begin############################################################################ 115 | Min_lr=Init_lr*0.01 116 | lr_limit_max = Init_lr 117 | lr_limit_min = 3e-4 118 | Init_lr_fit = min(max(batch_size / batch_size * Init_lr, lr_limit_min), lr_limit_max) 119 | Min_lr_fit = min(max(batch_size / batch_size * Min_lr, lr_limit_min * 1e-2), lr_limit_max * 1e-2) 120 | lr_scheduler_func = get_lr_scheduler(lr_decay_type, Init_lr_fit, Min_lr_fit, num_epochs) 121 | train_losses = [] 122 | val_losses = [] 123 | best_iou=0 124 | best_pa=0 125 | best_acc=0 126 | 127 | 128 | sam_model = 
sam_model_registry[model_type](checkpoint=checkpoint).to(device) 129 | seg_loss = monai.losses.DiceCELoss(sigmoid=True, squared_pred=True, reduction='mean')#%% train 130 | os.makedirs(model_save_path, exist_ok=True) 131 | 132 | for epoch in range(num_epochs): 133 | print(f'EPOCH: {epoch}') 134 | epoch_loss = 0 135 | ###############################################################Test################################################################## 136 | sam_model.eval() 137 | val_gts=[] 138 | val_preds=[] 139 | with torch.no_grad(): 140 | for f in os.listdir(ts_npz_path): 141 | ts_dataset = NpzDataset(join(ts_npz_path,f)) 142 | ts_dataloader = DataLoader(ts_dataset, batch_size=batch_size, shuffle=True) 143 | for step, (image_embedding, gt2D, boxes,img,points) in enumerate(ts_dataloader): 144 | if prompt_type=='box': 145 | box_np = boxes.numpy() 146 | sam_trans = ResizeLongestSide(sam_model.image_encoder.img_size) 147 | box = sam_trans.apply_boxes(box_np, (img.shape[-2], img.shape[-1])) 148 | box_torch = torch.as_tensor(box, dtype=torch.float, device=device) 149 | if len(box_torch.shape) == 2: 150 | box_torch = box_torch[:, None, :] 151 | sparse_embeddings, dense_embeddings = sam_model.prompt_encoder( 152 | points=None, 153 | boxes=box_torch, 154 | masks=None, 155 | ) 156 | elif prompt_type=='point': 157 | sparse_embeddings, dense_embeddings = sam_model.prompt_encoder( 158 | points=points, 159 | boxes=None, 160 | masks=None, 161 | ) 162 | elif prompt_type=='no': 163 | sparse_embeddings, dense_embeddings = sam_model.prompt_encoder( 164 | points=None, 165 | boxes=None, 166 | masks=None, 167 | ) 168 | mask_predictions, _ = sam_model.mask_decoder( 169 | image_embeddings=image_embedding.to(device), # (B, 256, 64, 64) 170 | image_pe=sam_model.prompt_encoder.get_dense_pe(), # (1, 256, 64, 64) 171 | sparse_prompt_embeddings=sparse_embeddings, # (B, 2, 256) 172 | dense_prompt_embeddings=dense_embeddings, # (B, 256, 64, 64) 173 | multimask_output=False, 174 | ) 175 | for i in range(mask_predictions.shape[0]): 176 | mask = mask_predictions[i] 177 | mask = mask.cpu().detach().numpy().squeeze() 178 | mask = cv2.resize((mask > 0.5).astype(np.uint8),(gt2D.shape[2], gt2D.shape[3])) 179 | gt_data=gt2D[i].cpu().numpy().astype(np.uint8) 180 | val_gts.append(gt_data.astype(np.uint8)) 181 | val_preds.append(mask.astype(np.uint8)) 182 | iou,pa,acc=compute_mIoU(val_gts,val_preds) 183 | if iou> best_iou: 184 | best_iou=iou 185 | best_pa=pa 186 | best_acc=acc 187 | if if_onlytest: 188 | continue 189 | if if_save==True: 190 | torch.save(sam_model.state_dict(), join(model_save_path, 'best.pth'))# plot loss 191 | print('best_miou:'+str(best_iou)) 192 | print('best_pa:'+str(best_pa)) 193 | print('best_acc:'+str(best_acc)) 194 | if if_onlytest: 195 | continue 196 | ###############################################################Train################################################################## 197 | sam_model.train() 198 | lr = lr_scheduler_func(epoch) 199 | optimizer = torch.optim.Adam(sam_model.mask_decoder.parameters(), lr,weight_decay=0) 200 | for f in os.listdir(npz_tr_path): 201 | train_dataset = NpzDataset(join(npz_tr_path,f)) 202 | train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True) 203 | for step, (image_embedding, gt2D, boxes,img,points) in enumerate(train_dataloader): 204 | with torch.no_grad(): 205 | if prompt_type=='box': 206 | box_np = boxes.numpy() 207 | sam_trans = ResizeLongestSide(sam_model.image_encoder.img_size) 208 | box = sam_trans.apply_boxes(box_np, 
(img.shape[-2], img.shape[-1])) 209 | box_torch = torch.as_tensor(box, dtype=torch.float, device=device) 210 | if len(box_torch.shape) == 2: 211 | box_torch = box_torch[:, None, :] 212 | sparse_embeddings, dense_embeddings = sam_model.prompt_encoder( 213 | points=None, 214 | boxes=box_torch, 215 | masks=None, 216 | ) 217 | elif prompt_type=='point': 218 | sparse_embeddings, dense_embeddings = sam_model.prompt_encoder( 219 | points=points, 220 | boxes=None, 221 | masks=None, 222 | ) 223 | elif prompt_type=='no': 224 | sparse_embeddings, dense_embeddings = sam_model.prompt_encoder( 225 | points=None, 226 | boxes=None, 227 | masks=None, 228 | ) 229 | mask_predictions, _ = sam_model.mask_decoder( 230 | image_embeddings=image_embedding.to(device), # (B, 256, 64, 64) 231 | image_pe=sam_model.prompt_encoder.get_dense_pe(), # (1, 256, 64, 64) 232 | sparse_prompt_embeddings=sparse_embeddings, # (B, 2, 256) 233 | dense_prompt_embeddings=dense_embeddings, # (B, 256, 64, 64) 234 | multimask_output=False, 235 | ) 236 | mask_predictions= F.interpolate(mask_predictions, size=(gt2D.shape[2],gt2D.shape[3]), mode='bilinear', align_corners=False) 237 | gt2D=gt2D.to(device) 238 | loss = seg_loss(mask_predictions, gt2D) 239 | optimizer.zero_grad() 240 | loss.backward() 241 | optimizer.step() 242 | ################################################################################################################################ 243 | if if_onlytest is False: 244 | plt.plot(train_losses) 245 | plt.title('Train Loss') 246 | plt.xlabel('Epoch') 247 | plt.ylabel('train_loss') 248 | plt.show() 249 | plt.savefig(join(model_save_path, 'train_loss.png')) 250 | plt.close() 251 | plt.plot(val_losses) 252 | plt.title('Val Loss') 253 | plt.xlabel('Epoch') 254 | plt.ylabel('val_loss') 255 | plt.show() 256 | plt.savefig(join(model_save_path, 'val_loss.png')) 257 | plt.close() -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------------------------------- /utils/__pycache__/SurfaceDice.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cshan-github/TongueSAM/3f6e8c620e4d89e669a92a22f3be3d007932ce45/utils/__pycache__/SurfaceDice.cpython-39.pyc -------------------------------------------------------------------------------- /utils/__pycache__/__init__.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cshan-github/TongueSAM/3f6e8c620e4d89e669a92a22f3be3d007932ce45/utils/__pycache__/__init__.cpython-310.pyc -------------------------------------------------------------------------------- /utils/__pycache__/__init__.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cshan-github/TongueSAM/3f6e8c620e4d89e669a92a22f3be3d007932ce45/utils/__pycache__/__init__.cpython-39.pyc -------------------------------------------------------------------------------- /utils/__pycache__/callbacks.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cshan-github/TongueSAM/3f6e8c620e4d89e669a92a22f3be3d007932ce45/utils/__pycache__/callbacks.cpython-310.pyc -------------------------------------------------------------------------------- 
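# Illustrative sketch, not a file in this repository: one possible way to assemble an
# .npz shard of the kind NpzDataset in train.py above loads. The key names ('imgs',
# 'gts', 'img_embeddings') are taken from train.py; the shapes and dtypes below are
# assumptions, apart from the (256, 64, 64) embedding size noted in train.py's comments.
# In the repository itself, utils/precompute_img_embed.py appears to handle this step.
import numpy as np

imgs = np.zeros((8, 256, 256, 3), dtype=np.uint8)               # raw RGB tongue images
gts = np.zeros((8, 256, 256), dtype=np.uint8)                   # binary ground-truth masks (same H, W as imgs)
img_embeddings = np.zeros((8, 256, 64, 64), dtype=np.float32)   # precomputed SAM image embeddings

np.savez_compressed('shard_0.npz', imgs=imgs, gts=gts, img_embeddings=img_embeddings)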
/utils/__pycache__/callbacks.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cshan-github/TongueSAM/3f6e8c620e4d89e669a92a22f3be3d007932ce45/utils/__pycache__/callbacks.cpython-39.pyc -------------------------------------------------------------------------------- /utils/__pycache__/dataloader.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cshan-github/TongueSAM/3f6e8c620e4d89e669a92a22f3be3d007932ce45/utils/__pycache__/dataloader.cpython-310.pyc -------------------------------------------------------------------------------- /utils/__pycache__/dataloader.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cshan-github/TongueSAM/3f6e8c620e4d89e669a92a22f3be3d007932ce45/utils/__pycache__/dataloader.cpython-39.pyc -------------------------------------------------------------------------------- /utils/__pycache__/utils.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cshan-github/TongueSAM/3f6e8c620e4d89e669a92a22f3be3d007932ce45/utils/__pycache__/utils.cpython-310.pyc -------------------------------------------------------------------------------- /utils/__pycache__/utils.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cshan-github/TongueSAM/3f6e8c620e4d89e669a92a22f3be3d007932ce45/utils/__pycache__/utils.cpython-39.pyc -------------------------------------------------------------------------------- /utils/__pycache__/utils_bbox.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cshan-github/TongueSAM/3f6e8c620e4d89e669a92a22f3be3d007932ce45/utils/__pycache__/utils_bbox.cpython-310.pyc -------------------------------------------------------------------------------- /utils/__pycache__/utils_bbox.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cshan-github/TongueSAM/3f6e8c620e4d89e669a92a22f3be3d007932ce45/utils/__pycache__/utils_bbox.cpython-39.pyc -------------------------------------------------------------------------------- /utils/__pycache__/utils_fit.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cshan-github/TongueSAM/3f6e8c620e4d89e669a92a22f3be3d007932ce45/utils/__pycache__/utils_fit.cpython-310.pyc -------------------------------------------------------------------------------- /utils/__pycache__/utils_fit.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cshan-github/TongueSAM/3f6e8c620e4d89e669a92a22f3be3d007932ce45/utils/__pycache__/utils_fit.cpython-39.pyc -------------------------------------------------------------------------------- /utils/__pycache__/utils_map.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cshan-github/TongueSAM/3f6e8c620e4d89e669a92a22f3be3d007932ce45/utils/__pycache__/utils_map.cpython-310.pyc -------------------------------------------------------------------------------- /utils/__pycache__/utils_map.cpython-39.pyc: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/cshan-github/TongueSAM/3f6e8c620e4d89e669a92a22f3be3d007932ce45/utils/__pycache__/utils_map.cpython-39.pyc -------------------------------------------------------------------------------- /utils/callbacks.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import torch 4 | import matplotlib 5 | matplotlib.use('Agg') 6 | import scipy.signal 7 | from matplotlib import pyplot as plt 8 | from torch.utils.tensorboard import SummaryWriter 9 | 10 | import shutil 11 | import numpy as np 12 | 13 | from PIL import Image 14 | from tqdm import tqdm 15 | from .utils import cvtColor, preprocess_input, resize_image 16 | from .utils_bbox import decode_outputs, non_max_suppression 17 | from .utils_map import get_coco_map, get_map 18 | 19 | 20 | class LossHistory(): 21 | def __init__(self, log_dir, model, input_shape): 22 | self.log_dir = log_dir 23 | self.losses = [] 24 | self.val_loss = [] 25 | 26 | os.makedirs(self.log_dir) 27 | self.writer = SummaryWriter(self.log_dir) 28 | try: 29 | dummy_input = torch.randn(2, 3, input_shape[0], input_shape[1]) 30 | self.writer.add_graph(model, dummy_input) 31 | except: 32 | pass 33 | 34 | def append_loss(self, epoch, loss, val_loss): 35 | if not os.path.exists(self.log_dir): 36 | os.makedirs(self.log_dir) 37 | 38 | self.losses.append(loss) 39 | self.val_loss.append(val_loss) 40 | 41 | with open(os.path.join(self.log_dir, "epoch_loss.txt"), 'a') as f: 42 | f.write(str(loss)) 43 | f.write("\n") 44 | with open(os.path.join(self.log_dir, "epoch_val_loss.txt"), 'a') as f: 45 | f.write(str(val_loss)) 46 | f.write("\n") 47 | 48 | self.writer.add_scalar('loss', loss, epoch) 49 | self.writer.add_scalar('val_loss', val_loss, epoch) 50 | self.loss_plot() 51 | 52 | def loss_plot(self): 53 | iters = range(len(self.losses)) 54 | 55 | plt.figure() 56 | plt.plot(iters, self.losses, 'red', linewidth = 2, label='train loss') 57 | plt.plot(iters, self.val_loss, 'coral', linewidth = 2, label='val loss') 58 | try: 59 | if len(self.losses) < 25: 60 | num = 5 61 | else: 62 | num = 15 63 | 64 | plt.plot(iters, scipy.signal.savgol_filter(self.losses, num, 3), 'green', linestyle = '--', linewidth = 2, label='smooth train loss') 65 | plt.plot(iters, scipy.signal.savgol_filter(self.val_loss, num, 3), '#8B4513', linestyle = '--', linewidth = 2, label='smooth val loss') 66 | except: 67 | pass 68 | 69 | plt.grid(True) 70 | plt.xlabel('Epoch') 71 | plt.ylabel('Loss') 72 | plt.legend(loc="upper right") 73 | 74 | plt.savefig(os.path.join(self.log_dir, "epoch_loss.png")) 75 | 76 | plt.cla() 77 | plt.close("all") 78 | 79 | class EvalCallback(): 80 | def __init__(self, net, input_shape, class_names, num_classes, val_lines, log_dir, cuda, \ 81 | map_out_path=".temp_map_out", max_boxes=100, confidence=0.05, nms_iou=0.5, letterbox_image=True, MINOVERLAP=0.5, eval_flag=True, period=1): 82 | super(EvalCallback, self).__init__() 83 | 84 | self.net = net 85 | self.input_shape = input_shape 86 | self.class_names = class_names 87 | self.num_classes = num_classes 88 | self.val_lines = val_lines 89 | self.log_dir = log_dir 90 | self.cuda = cuda 91 | self.map_out_path = map_out_path 92 | self.max_boxes = max_boxes 93 | self.confidence = confidence 94 | self.nms_iou = nms_iou 95 | self.letterbox_image = letterbox_image 96 | self.MINOVERLAP = MINOVERLAP 97 | self.eval_flag = eval_flag 98 | self.period = period 99 | 100 | self.maps = [0] 101 | 
self.epoches = [0] 102 | if self.eval_flag: 103 | with open(os.path.join(self.log_dir, "epoch_map.txt"), 'a') as f: 104 | f.write(str(0)) 105 | f.write("\n") 106 | 107 | def get_map_txt(self, image_id, image, class_names, map_out_path): 108 | f = open(os.path.join(map_out_path, "detection-results/"+image_id+".txt"),"w") 109 | image_shape = np.array(np.shape(image)[0:2]) 110 | #---------------------------------------------------------# 111 | # 在这里将图像转换成RGB图像,防止灰度图在预测时报错。 112 | # 代码仅仅支持RGB图像的预测,所有其它类型的图像都会转化成RGB 113 | #---------------------------------------------------------# 114 | image = cvtColor(image) 115 | #---------------------------------------------------------# 116 | # 给图像增加灰条,实现不失真的resize 117 | # 也可以直接resize进行识别 118 | #---------------------------------------------------------# 119 | image_data = resize_image(image, (self.input_shape[1],self.input_shape[0]), self.letterbox_image) 120 | #---------------------------------------------------------# 121 | # 添加上batch_size维度 122 | #---------------------------------------------------------# 123 | image_data = np.expand_dims(np.transpose(preprocess_input(np.array(image_data, dtype='float32')), (2, 0, 1)), 0) 124 | 125 | with torch.no_grad(): 126 | images = torch.from_numpy(image_data) 127 | if self.cuda: 128 | images = images.cuda() 129 | #---------------------------------------------------------# 130 | # 将图像输入网络当中进行预测! 131 | #---------------------------------------------------------# 132 | outputs = self.net(images) 133 | outputs = decode_outputs(outputs, self.input_shape) 134 | #---------------------------------------------------------# 135 | # 将预测框进行堆叠,然后进行非极大抑制 136 | #---------------------------------------------------------# 137 | results = non_max_suppression(outputs, self.num_classes, self.input_shape, 138 | image_shape, self.letterbox_image, conf_thres = self.confidence, nms_thres = self.nms_iou) 139 | 140 | if results[0] is None: 141 | return 142 | 143 | top_label = np.array(results[0][:, 6], dtype = 'int32') 144 | top_conf = results[0][:, 4] * results[0][:, 5] 145 | top_boxes = results[0][:, :4] 146 | 147 | top_100 = np.argsort(top_conf)[::-1][:self.max_boxes] 148 | top_boxes = top_boxes[top_100] 149 | top_conf = top_conf[top_100] 150 | top_label = top_label[top_100] 151 | 152 | for i, c in list(enumerate(top_label)): 153 | predicted_class = self.class_names[int(c)] 154 | box = top_boxes[i] 155 | score = str(top_conf[i]) 156 | 157 | top, left, bottom, right = box 158 | if predicted_class not in class_names: 159 | continue 160 | 161 | f.write("%s %s %s %s %s %s\n" % (predicted_class, score[:6], str(int(left)), str(int(top)), str(int(right)),str(int(bottom)))) 162 | 163 | f.close() 164 | return 165 | 166 | def on_epoch_end(self, epoch, model_eval): 167 | if epoch % self.period == 0 and self.eval_flag: 168 | self.net = model_eval 169 | if not os.path.exists(self.map_out_path): 170 | os.makedirs(self.map_out_path) 171 | if not os.path.exists(os.path.join(self.map_out_path, "ground-truth")): 172 | os.makedirs(os.path.join(self.map_out_path, "ground-truth")) 173 | if not os.path.exists(os.path.join(self.map_out_path, "detection-results")): 174 | os.makedirs(os.path.join(self.map_out_path, "detection-results")) 175 | print("Get map.") 176 | for annotation_line in tqdm(self.val_lines): 177 | line = annotation_line.split() 178 | image_id = os.path.basename(line[0]).split('.')[0] 179 | #------------------------------# 180 | # 读取图像并转换成RGB图像 181 | #------------------------------# 182 | image = Image.open(line[0]) 183 | 
#------------------------------#
 184 |                 #   Parse the ground-truth boxes from the annotation line
 185 |                 #------------------------------#
 186 |                 gt_boxes = np.array([np.array(list(map(int, box.split(',')))) for box in line[1:]])
 187 |                 #------------------------------#
 188 |                 #   Write the detection-results txt
 189 |                 #------------------------------#
 190 |                 self.get_map_txt(image_id, image, self.class_names, self.map_out_path)
 191 | 
 192 |                 #------------------------------#
 193 |                 #   Write the ground-truth txt
 194 |                 #------------------------------#
 195 |                 with open(os.path.join(self.map_out_path, "ground-truth/"+image_id+".txt"), "w") as new_f:
 196 |                     for box in gt_boxes:
 197 |                         left, top, right, bottom, obj = box
 198 |                         obj_name = self.class_names[obj]
 199 |                         new_f.write("%s %s %s %s %s\n" % (obj_name, left, top, right, bottom))
 200 | 
 201 |             print("Calculate Map.")
 202 |             try:
 203 |                 temp_map = get_coco_map(class_names = self.class_names, path = self.map_out_path)[1]
 204 |             except:
 205 |                 temp_map = get_map(self.MINOVERLAP, False, path = self.map_out_path)
 206 |             self.maps.append(temp_map)
 207 |             self.epoches.append(epoch)
 208 | 
 209 |             with open(os.path.join(self.log_dir, "epoch_map.txt"), 'a') as f:
 210 |                 f.write(str(temp_map))
 211 |                 f.write("\n")
 212 | 
 213 |             plt.figure()
 214 |             plt.plot(self.epoches, self.maps, 'red', linewidth = 2, label='train map')
 215 | 
 216 |             plt.grid(True)
 217 |             plt.xlabel('Epoch')
 218 |             plt.ylabel('Map %s'%str(self.MINOVERLAP))
 219 |             plt.title('A Map Curve')
 220 |             plt.legend(loc="upper right")
 221 | 
 222 |             plt.savefig(os.path.join(self.log_dir, "epoch_map.png"))
 223 |             plt.cla()
 224 |             plt.close("all")
 225 | 
 226 |             print("Get map done.")
 227 |             shutil.rmtree(self.map_out_path)
 228 | 
--------------------------------------------------------------------------------
/utils/dataset.py:
--------------------------------------------------------------------------------
   1 | import os
   2 | import pandas as pd
   3 | import numpy as np
   4 | import cv2
   5 | from typing import Any, Tuple
   6 | import torch
   7 | from torch.utils.data import Dataset
   8 | import torchvision.transforms.functional as TF
   9 | 
  10 | 
  11 | class MedSamDataset(Dataset):
  12 |     def __init__(
  13 |         self,
  14 |         df: pd.DataFrame,
  15 |         image_col: str,
  16 |         mask_col: str,
  17 |         image_dir: Any = None,
  18 |         mask_dir: str = None,
  19 |         image_size: Tuple = (256, 256),
  20 |     ):
  21 |         """
  22 |         PyTorch dataset class for loading image, mask and bbox triples from a dataframe.
  23 |         The dataframe needs at least two columns holding the image and mask file names. The columns can either contain the full or relative
  24 |         paths of the files or just the file names.
  25 |         If only file names are given in the columns, the `image_dir` and `mask_dir` arguments should be specified.
  26 | 
  27 |         Args:
  28 |             df (pd.DataFrame): the pandas dataframe object
  29 |             image_col (str): the name of the column on the dataframe that holds the image file names.
  30 |             mask_col (str): the name of the column on the dataframe that holds the mask file names.
  31 |             image_dir (Any, optional): Path to the input image directory. Defaults to None.
  32 |             mask_dir (str, optional): Path to the mask images directory. Defaults to None.
  33 |             image_size (Tuple, optional): image size. Defaults to (256, 256).
34 | """ 35 | self.df = df 36 | self.image_dir = image_dir 37 | self.mask_dir = mask_dir 38 | self.image_col = image_col 39 | self.mask_col = mask_col 40 | self.image_size = image_size 41 | 42 | def __len__(self): 43 | return len(self.df) 44 | 45 | def __getitem__(self, idx) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: 46 | # read dataframe row 47 | row = self.df.iloc[idx] 48 | # If the `image_dir` attribute is set, the path will be relative to that directory. 49 | # Otherwise, the path will be the value of the `row[self.image_col]` attribute. 50 | image_file = ( 51 | os.path.join(self.image_dir, row[self.image_col]) 52 | if self.image_dir 53 | else row[self.image_col] 54 | ) 55 | mask_file = ( 56 | os.path.join(self.mask_dir, row[self.mask_col]) 57 | if self.mask_dir 58 | else row[self.mask_col] 59 | ) 60 | 61 | if not os.path.exists(image_file): 62 | raise FileNotFoundError(f"Couldn't find image {image_file}") 63 | if not os.path.exists(mask_file): 64 | raise FileNotFoundError(f"Couldn't find image {mask_file}") 65 | 66 | # read image and mask files 67 | image_data = cv2.imread(image_file) 68 | # read mask as gray scale 69 | mask_data = cv2.imread(mask_file, cv2.IMREAD_GRAYSCALE) 70 | 71 | return self._preprocess(image_data, mask_data) 72 | 73 | def _preprocess( 74 | self, image: np.ndarray, mask: np.ndarray 75 | ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: 76 | # Threshold mask to binary 77 | mask = cv2.threshold(mask, 127.0, 255.0, cv2.THRESH_BINARY)[1] 78 | # convert to tensor 79 | image = TF.to_tensor(image) 80 | mask = TF.to_tensor(mask) 81 | # min-max normalize and scale 82 | image = (image - image.min()) / (image.max() - image.min()) * 255.0 83 | # resize 84 | image = TF.resize(image, self.image_size, antialias=True) 85 | mask = TF.resize(mask, self.image_size, antialias=True) 86 | 87 | bbox = self._get_bbox(mask) 88 | 89 | return image, mask, bbox 90 | 91 | def _get_bbox(self, mask: torch.Tensor) -> torch.Tensor: 92 | _, y_indices, x_indices = torch.where(mask > 0) 93 | 94 | x_min, y_min = (x_indices.min(), y_indices.min()) 95 | x_max, y_max = (x_indices.max(), y_indices.max()) 96 | 97 | # add perturbation to bounding box coordinates 98 | H, W = mask.shape[1:] 99 | # add perfurbation to the bbox 100 | assert H == W, f"{W} and {H} are not equal size!!" 
101 | x_min = max(0, x_min - np.random.randint(0, 10)) 102 | x_max = min(W, x_max + np.random.randint(0, 10)) 103 | y_min = max(0, y_min - np.random.randint(0, 10)) 104 | y_max = min(H, y_max + np.random.randint(0, 10)) 105 | 106 | return torch.tensor([x_min, y_min, x_max, y_max]) 107 | -------------------------------------------------------------------------------- /utils/precompute_img_embed.py: -------------------------------------------------------------------------------- 1 | #%% import packages 2 | # precompute image embeddings and save them to disk for model training 3 | 4 | import numpy as np 5 | import os 6 | join = os.path.join 7 | from skimage import io, segmentation 8 | from tqdm import tqdm 9 | import torch 10 | from segment_anything import sam_model_registry 11 | from segment_anything.utils.transforms import ResizeLongestSide 12 | import argparse 13 | 14 | #%% parse arguments 15 | parser = argparse.ArgumentParser() 16 | parser.add_argument('-i', '--img_path', type=str, default='./data/Tr_Release_Part1', help='# and also Tr_Release_Part2 when part1 is done') 17 | parser.add_argument('-o', '--save_path', type=str, default='./data/Tr_npy', help='path to save the image embeddings') 18 | parser.add_argument('--model_type', type=str, default='vit_b', help='model type') 19 | parser.add_argument('--checkpoint', type=str, default='../work_dir/SAM/sam_vit_b_01ec64.pth', help='path to the pre-trained SAM model') 20 | args = parser.parse_args() 21 | 22 | pre_img_path = args.img_path 23 | save_img_emb_path = join(args.save_path, 'npy_embs') 24 | save_gt_path = join(args.save_path, 'npy_gts') 25 | os.makedirs(save_img_emb_path, exist_ok=True) 26 | os.makedirs(save_gt_path, exist_ok=True) 27 | npz_files = sorted(os.listdir(pre_img_path)) 28 | #%% set up the model 29 | sam_model = sam_model_registry[args.model_type](checkpoint=args.checkpoint).to('cuda:0') 30 | sam_transform = ResizeLongestSide(sam_model.image_encoder.img_size) 31 | 32 | # compute image embeddings 33 | for name in tqdm(npz_files): 34 | img = np.load(join(pre_img_path, name))['img'] # (256, 256, 3) 35 | gt = np.load(join(pre_img_path, name))['gt'] 36 | resize_img = sam_transform.apply_image(img) 37 | resize_img_tensor = torch.as_tensor(resize_img.transpose(2, 0, 1)).to('cuda:0') 38 | # model input: (1, 3, 1024, 1024) 39 | input_image = sam_model.preprocess(resize_img_tensor[None,:,:,:]) # (1, 3, 1024, 1024) 40 | assert input_image.shape == (1, 3, sam_model.image_encoder.img_size, sam_model.image_encoder.img_size), 'input image should be resized to 1024*1024' 41 | with torch.no_grad(): 42 | embedding = sam_model.image_encoder(input_image) 43 | 44 | # save as npy 45 | np.save(join(save_img_emb_path, name.split('.npz')[0]+'.npy'), embedding.cpu().numpy()[0]) 46 | np.save(join(save_gt_path, name.split('.npz')[0]+'.npy'), gt) 47 | # sanity check 48 | img_idx = img.copy() 49 | bd = segmentation.find_boundaries(gt, mode='inner') 50 | img_idx[bd, :] = [255, 0, 0] 51 | io.imsave(save_img_emb_path + '.png', img_idx) 52 | -------------------------------------------------------------------------------- /utils/utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from PIL import Image 3 | 4 | 5 | #---------------------------------------------------------# 6 | # 将图像转换成RGB图像,防止灰度图在预测时报错。 7 | # 代码仅仅支持RGB图像的预测,所有其它类型的图像都会转化成RGB 8 | #---------------------------------------------------------# 9 | def cvtColor(image): 10 | if len(np.shape(image)) == 3 and np.shape(image)[2] == 3: 11 
| return image 12 | else: 13 | image = image.convert('RGB') 14 | return image 15 | 16 | #---------------------------------------------------# 17 | # 对输入图像进行resize 18 | #---------------------------------------------------# 19 | def resize_image(image, size, letterbox_image): 20 | iw, ih = image.size 21 | w, h = size 22 | if letterbox_image: 23 | scale = min(w/iw, h/ih) 24 | nw = int(iw*scale) 25 | nh = int(ih*scale) 26 | 27 | image = image.resize((nw,nh), Image.BICUBIC) 28 | new_image = Image.new('RGB', size, (128,128,128)) 29 | new_image.paste(image, ((w-nw)//2, (h-nh)//2)) 30 | else: 31 | new_image = image.resize((w, h), Image.BICUBIC) 32 | return new_image 33 | 34 | #---------------------------------------------------# 35 | # 获得类 36 | #---------------------------------------------------# 37 | def get_classes(classes_path): 38 | with open(classes_path, encoding='utf-8') as f: 39 | class_names = f.readlines() 40 | class_names = [c.strip() for c in class_names] 41 | return class_names, len(class_names) 42 | 43 | def preprocess_input(image): 44 | image /= 255.0 45 | image -= np.array([0.485, 0.456, 0.406]) 46 | image /= np.array([0.229, 0.224, 0.225]) 47 | return image 48 | 49 | #---------------------------------------------------# 50 | # 获得学习率 51 | #---------------------------------------------------# 52 | def get_lr(optimizer): 53 | for param_group in optimizer.param_groups: 54 | return param_group['lr'] 55 | 56 | def show_config(**kwargs): 57 | print('Configurations:') 58 | print('-' * 70) 59 | print('|%25s | %40s|' % ('keys', 'values')) 60 | print('-' * 70) 61 | for key, value in kwargs.items(): 62 | print('|%25s | %40s|' % (str(key), str(value))) 63 | print('-' * 70) 64 | -------------------------------------------------------------------------------- /utils/utils_bbox.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from torchvision.ops import nms, boxes 4 | 5 | def yolo_correct_boxes(box_xy, box_wh, input_shape, image_shape, letterbox_image): 6 | #-----------------------------------------------------------------# 7 | # 把y轴放前面是因为方便预测框和图像的宽高进行相乘 8 | #-----------------------------------------------------------------# 9 | box_yx = box_xy[..., ::-1] 10 | box_hw = box_wh[..., ::-1] 11 | input_shape = np.array(input_shape) 12 | image_shape = np.array(image_shape) 13 | 14 | if letterbox_image: 15 | #-----------------------------------------------------------------# 16 | # 这里求出来的offset是图像有效区域相对于图像左上角的偏移情况 17 | # new_shape指的是宽高缩放情况 18 | #-----------------------------------------------------------------# 19 | new_shape = np.round(image_shape * np.min(input_shape/image_shape)) 20 | offset = (input_shape - new_shape)/2./input_shape 21 | scale = input_shape/new_shape 22 | 23 | box_yx = (box_yx - offset) * scale 24 | box_hw *= scale 25 | 26 | box_mins = box_yx - (box_hw / 2.) 27 | box_maxes = box_yx + (box_hw / 2.) 
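    # Stack the corners as (y_min, x_min, y_max, x_max) and rescale the normalized
    # coordinates back to pixel coordinates of the original image.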
28 | boxes = np.concatenate([box_mins[..., 0:1], box_mins[..., 1:2], box_maxes[..., 0:1], box_maxes[..., 1:2]], axis=-1) 29 | boxes *= np.concatenate([image_shape, image_shape], axis=-1) 30 | return boxes 31 | 32 | def decode_outputs(outputs, input_shape): 33 | grids = [] 34 | strides = [] 35 | hw = [x.shape[-2:] for x in outputs] 36 | #---------------------------------------------------# 37 | # outputs输入前代表每个特征层的预测结果 38 | # batch_size, 4 + 1 + num_classes, 80, 80 => batch_size, 4 + 1 + num_classes, 6400 39 | # batch_size, 5 + num_classes, 40, 40 40 | # batch_size, 5 + num_classes, 20, 20 41 | # batch_size, 4 + 1 + num_classes, 6400 + 1600 + 400 -> batch_size, 4 + 1 + num_classes, 8400 42 | # 堆叠后为batch_size, 8400, 5 + num_classes 43 | #---------------------------------------------------# 44 | outputs = torch.cat([x.flatten(start_dim=2) for x in outputs], dim=2).permute(0, 2, 1) 45 | #---------------------------------------------------# 46 | # 获得每一个特征点属于每一个种类的概率 47 | #---------------------------------------------------# 48 | outputs[:, :, 4:] = torch.sigmoid(outputs[:, :, 4:]) 49 | for h, w in hw: 50 | #---------------------------# 51 | # 根据特征层的高宽生成网格点 52 | #---------------------------# 53 | grid_y, grid_x = torch.meshgrid([torch.arange(h), torch.arange(w)]) 54 | #---------------------------# 55 | # 1, 6400, 2 56 | # 1, 1600, 2 57 | # 1, 400, 2 58 | #---------------------------# 59 | grid = torch.stack((grid_x, grid_y), 2).view(1, -1, 2) 60 | shape = grid.shape[:2] 61 | 62 | grids.append(grid) 63 | strides.append(torch.full((shape[0], shape[1], 1), input_shape[0] / h)) 64 | #---------------------------# 65 | # 将网格点堆叠到一起 66 | # 1, 6400, 2 67 | # 1, 1600, 2 68 | # 1, 400, 2 69 | # 70 | # 1, 8400, 2 71 | #---------------------------# 72 | grids = torch.cat(grids, dim=1).type(outputs.type()) 73 | strides = torch.cat(strides, dim=1).type(outputs.type()) 74 | #------------------------# 75 | # 根据网格点进行解码 76 | #------------------------# 77 | outputs[..., :2] = (outputs[..., :2] + grids) * strides 78 | outputs[..., 2:4] = torch.exp(outputs[..., 2:4]) * strides 79 | #-----------------# 80 | # 归一化 81 | #-----------------# 82 | outputs[..., [0,2]] = outputs[..., [0,2]] / input_shape[1] 83 | outputs[..., [1,3]] = outputs[..., [1,3]] / input_shape[0] 84 | return outputs 85 | 86 | def non_max_suppression(prediction, num_classes, input_shape, image_shape, letterbox_image, conf_thres=0.5, nms_thres=0.4): 87 | #----------------------------------------------------------# 88 | # 将预测结果的格式转换成左上角右下角的格式。 89 | # prediction [batch_size, num_anchors, 85] 90 | #----------------------------------------------------------# 91 | box_corner = prediction.new(prediction.shape) 92 | box_corner[:, :, 0] = prediction[:, :, 0] - prediction[:, :, 2] / 2 93 | box_corner[:, :, 1] = prediction[:, :, 1] - prediction[:, :, 3] / 2 94 | box_corner[:, :, 2] = prediction[:, :, 0] + prediction[:, :, 2] / 2 95 | box_corner[:, :, 3] = prediction[:, :, 1] + prediction[:, :, 3] / 2 96 | prediction[:, :, :4] = box_corner[:, :, :4] 97 | 98 | output = [None for _ in range(len(prediction))] 99 | #----------------------------------------------------------# 100 | # 对输入图片进行循环,一般只会进行一次 101 | #----------------------------------------------------------# 102 | for i, image_pred in enumerate(prediction): 103 | 104 | #----------------------------------------------------------# 105 | # 对种类预测部分取max。 106 | # class_conf [num_anchors, 1] 种类置信度 107 | # class_pred [num_anchors, 1] 种类 108 | #----------------------------------------------------------# 109 | 
class_conf, class_pred = torch.max(image_pred[:, 5:5 + num_classes], 1, keepdim=True)
 113 |         #----------------------------------------------------------#
 114 |         #   First round of screening with the confidence threshold
 115 |         #----------------------------------------------------------#
 116 |         conf_mask = (image_pred[:, 4] * class_conf[:, 0] >= conf_thres).squeeze()
 117 |         if not image_pred.size(0):
 118 |             continue
 119 |         #-------------------------------------------------------------------------#
 120 |         #   detections  [num_anchors, 7]
 121 |         #   The 7 columns are: x1, y1, x2, y2, obj_conf, class_conf, class_pred
 122 |         #-------------------------------------------------------------------------#
 123 |         detections = torch.cat((image_pred[:, :5], class_conf, class_pred.float()), 1)
 124 |         detections = detections[conf_mask]
 125 |         nms_out_index = boxes.batched_nms(
 126 |             detections[:, :4],
 127 |             detections[:, 4] * detections[:, 5],
 128 |             detections[:, 6],
 129 |             nms_thres,
 130 |         )
 131 | 
 132 |         output[i] = detections[nms_out_index]
 133 | 
 177 |         if output[i] is not None:
 178 |             output[i] = output[i].cpu().numpy()
 179 |             box_xy, box_wh = (output[i][:, 0:2] + output[i][:, 2:4])/2, output[i][:, 2:4] - output[i][:, 0:2]
 180 |             output[i][:, :4] = yolo_correct_boxes(box_xy, box_wh, input_shape, image_shape, letterbox_image)
 181 |     return output
 182 | 
--------------------------------------------------------------------------------
/utils/utils_fit.py:
--------------------------------------------------------------------------------
   1 | import os
   2 | 
   3 | import torch
   4 | from tqdm import tqdm
   5 | 
   6 | from utils.utils import get_lr
   7 | 
   8 | 
   9 | def fit_one_epoch(model_train, model, ema, yolo_loss, loss_history, eval_callback, optimizer, epoch, epoch_step, epoch_step_val, gen, gen_val, Epoch, cuda, fp16, scaler, save_period, save_dir,
local_rank=0): 10 | loss = 0 11 | val_loss = 0 12 | 13 | if local_rank == 0: 14 | print('Start Train') 15 | pbar = tqdm(total=epoch_step,desc=f'Epoch {epoch + 1}/{Epoch}',postfix=dict,mininterval=0.3) 16 | model_train.train() 17 | for iteration, batch in enumerate(gen): 18 | if iteration >= epoch_step: 19 | break 20 | 21 | images, targets = batch[0], batch[1] 22 | with torch.no_grad(): 23 | if cuda: 24 | images = images.cuda(local_rank) 25 | targets = [ann.cuda(local_rank) for ann in targets] 26 | #----------------------# 27 | # 清零梯度 28 | #----------------------# 29 | optimizer.zero_grad() 30 | if not fp16: 31 | #----------------------# 32 | # 前向传播 33 | #----------------------# 34 | outputs = model_train(images) 35 | 36 | #----------------------# 37 | # 计算损失 38 | #----------------------# 39 | loss_value = yolo_loss(outputs, targets) 40 | 41 | #----------------------# 42 | # 反向传播 43 | #----------------------# 44 | loss_value.backward() 45 | optimizer.step() 46 | else: 47 | from torch.cuda.amp import autocast 48 | with autocast(): 49 | outputs = model_train(images) 50 | #----------------------# 51 | # 计算损失 52 | #----------------------# 53 | loss_value = yolo_loss(outputs, targets) 54 | 55 | #----------------------# 56 | # 反向传播 57 | #----------------------# 58 | scaler.scale(loss_value).backward() 59 | scaler.step(optimizer) 60 | scaler.update() 61 | if ema: 62 | ema.update(model_train) 63 | 64 | loss += loss_value.item() 65 | 66 | if local_rank == 0: 67 | pbar.set_postfix(**{'loss' : loss / (iteration + 1), 68 | 'lr' : get_lr(optimizer)}) 69 | pbar.update(1) 70 | 71 | if local_rank == 0: 72 | pbar.close() 73 | print('Finish Train') 74 | print('Start Validation') 75 | pbar = tqdm(total=epoch_step_val, desc=f'Epoch {epoch + 1}/{Epoch}',postfix=dict,mininterval=0.3) 76 | 77 | if ema: 78 | model_train_eval = ema.ema 79 | else: 80 | model_train_eval = model_train.eval() 81 | 82 | for iteration, batch in enumerate(gen_val): 83 | if iteration >= epoch_step_val: 84 | break 85 | images, targets = batch[0], batch[1] 86 | with torch.no_grad(): 87 | if cuda: 88 | images = images.cuda(local_rank) 89 | targets = [ann.cuda(local_rank) for ann in targets] 90 | #----------------------# 91 | # 清零梯度 92 | #----------------------# 93 | optimizer.zero_grad() 94 | #----------------------# 95 | # 前向传播 96 | #----------------------# 97 | outputs = model_train_eval(images) 98 | 99 | #----------------------# 100 | # 计算损失 101 | #----------------------# 102 | loss_value = yolo_loss(outputs, targets) 103 | 104 | val_loss += loss_value.item() 105 | if local_rank == 0: 106 | pbar.set_postfix(**{'val_loss': val_loss / (iteration + 1)}) 107 | pbar.update(1) 108 | 109 | if local_rank == 0: 110 | pbar.close() 111 | print('Finish Validation') 112 | loss_history.append_loss(epoch + 1, loss / epoch_step, val_loss / epoch_step_val) 113 | eval_callback.on_epoch_end(epoch + 1, model_train_eval) 114 | print('Epoch:'+ str(epoch + 1) + '/' + str(Epoch)) 115 | print('Total Loss: %.3f || Val Loss: %.3f ' % (loss / epoch_step, val_loss / epoch_step_val)) 116 | 117 | #-----------------------------------------------# 118 | # 保存权值 119 | #-----------------------------------------------# 120 | if ema: 121 | save_state_dict = ema.ema.state_dict() 122 | else: 123 | save_state_dict = model.state_dict() 124 | 125 | if (epoch + 1) % save_period == 0 or epoch + 1 == Epoch: 126 | torch.save(save_state_dict, os.path.join(save_dir, "ep%03d-loss%.3f-val_loss%.3f.pth" % (epoch + 1, loss / epoch_step, val_loss / epoch_step_val))) 127 | 128 | if 
len(loss_history.val_loss) <= 1 or (val_loss / epoch_step_val) <= min(loss_history.val_loss): 129 | print('Save best model to best_epoch_weights.pth') 130 | torch.save(save_state_dict, os.path.join(save_dir, "best_epoch_weights.pth")) 131 | 132 | torch.save(save_state_dict, os.path.join(save_dir, "last_epoch_weights.pth")) -------------------------------------------------------------------------------- /utils_metrics.py: -------------------------------------------------------------------------------- 1 | import csv 2 | import os 3 | from os.path import join 4 | 5 | import matplotlib.pyplot as plt 6 | import numpy as np 7 | import torch 8 | import torch.nn.functional as F 9 | from PIL import Image 10 | 11 | 12 | def f_score(inputs, target, beta=1, smooth = 1e-5, threhold = 0.5): 13 | n, c, h, w = inputs.size() 14 | nt, ht, wt, ct = target.size() 15 | if h != ht and w != wt: 16 | inputs = F.interpolate(inputs, size=(ht, wt), mode="bilinear", align_corners=True) 17 | 18 | temp_inputs = torch.softmax(inputs.transpose(1, 2).transpose(2, 3).contiguous().view(n, -1, c),-1) 19 | temp_target = target.view(n, -1, ct) 20 | 21 | #--------------------------------------------# 22 | # 计算dice系数 23 | #--------------------------------------------# 24 | temp_inputs = torch.gt(temp_inputs, threhold).float() 25 | tp = torch.sum(temp_target[...,:-1] * temp_inputs, axis=[0,1]) 26 | fp = torch.sum(temp_inputs , axis=[0,1]) - tp 27 | fn = torch.sum(temp_target[...,:-1] , axis=[0,1]) - tp 28 | 29 | score = ((1 + beta ** 2) * tp + smooth) / ((1 + beta ** 2) * tp + beta ** 2 * fn + fp + smooth) 30 | score = torch.mean(score) 31 | return score 32 | 33 | # 设标签宽W,长H 34 | def fast_hist(a, b, n): 35 | 36 | #--------------------------------------------------------------------------------# 37 | # a是转化成一维数组的标签,形状(H×W,);b是转化成一维数组的预测结果,形状(H×W,) 38 | #--------------------------------------------------------------------------------# 39 | k = (a >= 0) & (a < n) 40 | #--------------------------------------------------------------------------------# 41 | # np.bincount计算了从0到n**2-1这n**2个数中每个数出现的次数,返回值形状(n, n) 42 | # 返回中,写对角线上的为分类正确的像素点 43 | #--------------------------------------------------------------------------------# 44 | return np.bincount(n * a[k].astype(int) + b[k], minlength=n ** 2).reshape(n, n) 45 | 46 | def per_class_iu(hist): 47 | return np.diag(hist) / np.maximum((hist.sum(1) + hist.sum(0) - np.diag(hist)), 1) 48 | 49 | def per_class_PA_Recall(hist): 50 | return np.diag(hist) / np.maximum(hist.sum(1), 1) 51 | 52 | def per_class_Precision(hist): 53 | return np.diag(hist) / np.maximum(hist.sum(0), 1) 54 | 55 | def per_Accuracy(hist): 56 | return np.sum(np.diag(hist)) / np.maximum(np.sum(hist), 1) 57 | 58 | def compute_mIoU(gts, preds, num_classes=2,name_classes= ["background","tongue"]): 59 | hist = np.zeros((num_classes, num_classes)) 60 | for ind in range(len(gts)): 61 | pred = preds[ind] 62 | pred[pred== 2] = 1 63 | label = gts[ind] 64 | label[label== 2] = 1 65 | hist += fast_hist(label.flatten(), pred.flatten(), num_classes) 66 | IoUs = per_class_iu(hist) 67 | PA_Recall = per_class_PA_Recall(hist) 68 | Precision = per_class_Precision(hist) 69 | #------------------------------------------------# 70 | # 逐类别输出一下mIoU值 71 | #------------------------------------------------# 72 | if name_classes is not None: 73 | for ind_class in range(num_classes): 74 | print('===>' + name_classes[ind_class] + ':\tIou-' + str(round(IoUs[ind_class] * 100, 2)) \ 75 | + '; Recall (equal to the PA)-' + str(round(PA_Recall[ind_class] * 100, 
2))+ '; Precision-' + str(round(Precision[ind_class] * 100, 2))) 76 | 77 | #-----------------------------------------------------------------# 78 | # 在所有验证集图像上求所有类别平均的mIoU值,计算时忽略NaN值 79 | #-----------------------------------------------------------------# 80 | print('===> mIoU: ' + str(round(np.nanmean(IoUs) * 100, 2)) + '; mPA: ' + str(round(np.nanmean(PA_Recall) * 100, 2)) + '; Accuracy: ' + str(round(per_Accuracy(hist) * 100, 2))) 81 | return round(np.nanmean(IoUs) * 100, 2), round(np.nanmean(PA_Recall) * 100, 2), round(per_Accuracy(hist) * 100, 2) 82 | 83 | def adjust_axes(r, t, fig, axes): 84 | bb = t.get_window_extent(renderer=r) 85 | text_width_inches = bb.width / fig.dpi 86 | current_fig_width = fig.get_figwidth() 87 | new_fig_width = current_fig_width + text_width_inches 88 | propotion = new_fig_width / current_fig_width 89 | x_lim = axes.get_xlim() 90 | axes.set_xlim([x_lim[0], x_lim[1] * propotion]) 91 | 92 | def draw_plot_func(values, name_classes, plot_title, x_label, output_path, tick_font_size = 12, plt_show = True): 93 | fig = plt.gcf() 94 | axes = plt.gca() 95 | plt.barh(range(len(values)), values, color='royalblue') 96 | plt.title(plot_title, fontsize=tick_font_size + 2) 97 | plt.xlabel(x_label, fontsize=tick_font_size) 98 | plt.yticks(range(len(values)), name_classes, fontsize=tick_font_size) 99 | r = fig.canvas.get_renderer() 100 | for i, val in enumerate(values): 101 | str_val = " " + str(val) 102 | if val < 1.0: 103 | str_val = " {0:.2f}".format(val) 104 | t = plt.text(val, i, str_val, color='royalblue', va='center', fontweight='bold') 105 | if i == (len(values)-1): 106 | adjust_axes(r, t, fig, axes) 107 | 108 | fig.tight_layout() 109 | fig.savefig(output_path) 110 | if plt_show: 111 | plt.show() 112 | plt.close() 113 | 114 | def show_results(miou_out_path, hist, IoUs, PA_Recall, Precision, name_classes, tick_font_size = 12): 115 | draw_plot_func(IoUs, name_classes, "mIoU = {0:.2f}%".format(np.nanmean(IoUs)*100), "Intersection over Union", \ 116 | os.path.join(miou_out_path, "mIoU.png"), tick_font_size = tick_font_size, plt_show = True) 117 | print("Save mIoU out to " + os.path.join(miou_out_path, "mIoU.png")) 118 | 119 | draw_plot_func(PA_Recall, name_classes, "mPA = {0:.2f}%".format(np.nanmean(PA_Recall)*100), "Pixel Accuracy", \ 120 | os.path.join(miou_out_path, "mPA.png"), tick_font_size = tick_font_size, plt_show = False) 121 | print("Save mPA out to " + os.path.join(miou_out_path, "mPA.png")) 122 | 123 | draw_plot_func(PA_Recall, name_classes, "mRecall = {0:.2f}%".format(np.nanmean(PA_Recall)*100), "Recall", \ 124 | os.path.join(miou_out_path, "Recall.png"), tick_font_size = tick_font_size, plt_show = False) 125 | print("Save Recall out to " + os.path.join(miou_out_path, "Recall.png")) 126 | 127 | draw_plot_func(Precision, name_classes, "mPrecision = {0:.2f}%".format(np.nanmean(Precision)*100), "Precision", \ 128 | os.path.join(miou_out_path, "Precision.png"), tick_font_size = tick_font_size, plt_show = False) 129 | print("Save Precision out to " + os.path.join(miou_out_path, "Precision.png")) 130 | 131 | with open(os.path.join(miou_out_path, "confusion_matrix.csv"), 'w', newline='') as f: 132 | writer = csv.writer(f) 133 | writer_list = [] 134 | writer_list.append([' '] + [str(c) for c in name_classes]) 135 | for i in range(len(hist)): 136 | writer_list.append([name_classes[i]] + [str(x) for x in hist[i]]) 137 | writer.writerows(writer_list) 138 | print("Save confusion_matrix out to " + os.path.join(miou_out_path, "confusion_matrix.csv")) 139 | 
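# Example usage: given lists of binary masks (0 = background, 1 = tongue),
#   miou, mpa, acc = compute_mIoU(gt_list, pred_list, num_classes=2, name_classes=["background", "tongue"])
# prints per-class IoU / Recall / Precision and returns the three summary scores in percent.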
--------------------------------------------------------------------------------
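For orientation, the sketch below shows how the pieces dumped above fit together for a single box-prompted prediction. It mirrors the preprocessing in utils/precompute_img_embed.py and the box branch of the training loop in train.py; the checkpoint path, the device string, and the whole-image placeholder box are illustrative assumptions only, and in the full TongueSAM pipeline the box prompt is expected to come from the YOLOX-based detector under segment/ rather than being hard-coded.

# Minimal box-prompt inference sketch (assumed paths; adapt to your checkpoint and image).
import numpy as np
import torch
import torch.nn.functional as F
from skimage import io

from segment_anything import sam_model_registry
from segment_anything.utils.transforms import ResizeLongestSide

device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
# Assumed checkpoint location; any SAM ViT-B style weight file works here.
sam_model = sam_model_registry['vit_b'](checkpoint='./pretrained_model/sam_vit_b_01ec64.pth').to(device)
sam_model.eval()

img = io.imread('./data/test_in/1.jpg')                              # (H, W, 3) RGB image
transform = ResizeLongestSide(sam_model.image_encoder.img_size)
resize_img = transform.apply_image(img)
img_tensor = torch.as_tensor(resize_img.transpose(2, 0, 1), dtype=torch.float, device=device)
input_image = sam_model.preprocess(img_tensor[None])                 # (1, 3, 1024, 1024)

# Box prompt mapped into the resized coordinate frame; a whole-image box as a placeholder.
box = transform.apply_boxes(np.array([[0, 0, img.shape[1], img.shape[0]]]), img.shape[:2])
box_torch = torch.as_tensor(box, dtype=torch.float, device=device)[:, None, :]   # (1, 1, 4)

with torch.no_grad():
    image_embedding = sam_model.image_encoder(input_image)           # (1, 256, 64, 64)
    sparse_emb, dense_emb = sam_model.prompt_encoder(points=None, boxes=box_torch, masks=None)
    low_res_masks, _ = sam_model.mask_decoder(
        image_embeddings=image_embedding,
        image_pe=sam_model.prompt_encoder.get_dense_pe(),
        sparse_prompt_embeddings=sparse_emb,
        dense_prompt_embeddings=dense_emb,
        multimask_output=False,
    )
    # Upsample the low-resolution logits to the original image size and threshold at 0.
    mask_logits = F.interpolate(low_res_masks, size=img.shape[:2], mode='bilinear', align_corners=False)
    mask = (mask_logits[0, 0] > 0).cpu().numpy().astype(np.uint8)

The same mask_decoder call appears in train.py with point and no-prompt branches; only the inputs to the prompt encoder change between the three cases.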