├── .DS_Store
├── 1.jpg
├── 3.jpg
├── 4.jpg
├── LICENSE.md
├── README.md
├── __pycache__
│   └── utils_metrics.cpython-39.pyc
├── data
│   ├── test_in
│   │   ├── 1.jpg
│   │   ├── 10.jpg
│   │   ├── 100.jpg
│   │   ├── 11.jpg
│   │   ├── 12.jpg
│   │   ├── 13.jpg
│   │   ├── 14.jpg
│   │   ├── 15.jpg
│   │   ├── 16.jpg
│   │   ├── 17.jpg
│   │   ├── 18.jpg
│   │   └── 19.jpg
│   └── test_out
│       ├── 1.png
│       ├── 10.png
│       ├── 100.png
│       ├── 11.png
│       ├── 12.png
│       ├── 13.png
│       ├── 14.png
│       ├── 15.png
│       ├── 16.png
│       ├── 17.png
│       ├── 18.png
│       └── 19.png
├── pre_tongue.py
├── predict.py
├── pretrained_model
│   └── .DS_Store
├── segment
│   ├── __pycache__
│   │   ├── deeplab.cpython-39.pyc
│   │   ├── pspnet.cpython-39.pyc
│   │   ├── unet.cpython-39.pyc
│   │   ├── yolox.cpython-39.pyc
│   │   └── yolox_new.cpython-39.pyc
│   ├── tongue_classes.txt
│   ├── utils_yolox
│   │   ├── __init__.py
│   │   ├── __pycache__
│   │   │   ├── __init__.cpython-310.pyc
│   │   │   ├── __init__.cpython-39.pyc
│   │   │   ├── callbacks.cpython-310.pyc
│   │   │   ├── callbacks.cpython-39.pyc
│   │   │   ├── dataloader.cpython-310.pyc
│   │   │   ├── dataloader.cpython-39.pyc
│   │   │   ├── utils.cpython-310.pyc
│   │   │   ├── utils.cpython-39.pyc
│   │   │   ├── utils_bbox.cpython-310.pyc
│   │   │   ├── utils_bbox.cpython-39.pyc
│   │   │   ├── utils_fit.cpython-310.pyc
│   │   │   ├── utils_fit.cpython-39.pyc
│   │   │   ├── utils_map.cpython-310.pyc
│   │   │   └── utils_map.cpython-39.pyc
│   │   ├── callbacks.py
│   │   ├── dataloader.py
│   │   ├── utils.py
│   │   ├── utils_bbox.py
│   │   ├── utils_fit.py
│   │   └── utils_map.py
│   ├── yolox.pth
│   ├── yolox.py
│   └── yolox_nets
│       ├── __init__.py
│       ├── __pycache__
│       │   ├── __init__.cpython-310.pyc
│       │   ├── __init__.cpython-39.pyc
│       │   ├── darknet.cpython-310.pyc
│       │   ├── darknet.cpython-39.pyc
│       │   ├── yolo.cpython-310.pyc
│       │   ├── yolo.cpython-39.pyc
│       │   ├── yolo_training.cpython-310.pyc
│       │   └── yolo_training.cpython-39.pyc
│       ├── darknet.py
│       ├── yolo.py
│       └── yolo_training.py
├── segment_anything
│   ├── __init__.py
│   ├── __pycache__
│   │   ├── __init__.cpython-310.pyc
│   │   ├── __init__.cpython-39.pyc
│   │   ├── automatic_mask_generator.cpython-310.pyc
│   │   ├── automatic_mask_generator.cpython-39.pyc
│   │   ├── build_sam.cpython-310.pyc
│   │   ├── build_sam.cpython-39.pyc
│   │   ├── predictor.cpython-310.pyc
│   │   └── predictor.cpython-39.pyc
│   ├── automatic_mask_generator.py
│   ├── build_sam.py
│   ├── modeling
│   │   ├── __init__.py
│   │   ├── __pycache__
│   │   │   ├── __init__.cpython-310.pyc
│   │   │   ├── __init__.cpython-39.pyc
│   │   │   ├── common.cpython-310.pyc
│   │   │   ├── common.cpython-39.pyc
│   │   │   ├── image_encoder.cpython-310.pyc
│   │   │   ├── image_encoder.cpython-39.pyc
│   │   │   ├── mask_decoder.cpython-310.pyc
│   │   │   ├── mask_decoder.cpython-39.pyc
│   │   │   ├── prompt_encoder.cpython-310.pyc
│   │   │   ├── prompt_encoder.cpython-39.pyc
│   │   │   ├── sam.cpython-310.pyc
│   │   │   ├── sam.cpython-39.pyc
│   │   │   ├── transformer.cpython-310.pyc
│   │   │   └── transformer.cpython-39.pyc
│   │   ├── common.py
│   │   ├── image_encoder.py
│   │   ├── mask_decoder.py
│   │   ├── prompt_encoder.py
│   │   ├── sam.py
│   │   └── transformer.py
│   ├── predictor.py
│   └── utils
│       ├── __init__.py
│       ├── __pycache__
│       │   ├── __init__.cpython-310.pyc
│       │   ├── __init__.cpython-39.pyc
│       │   ├── amg.cpython-310.pyc
│       │   ├── amg.cpython-39.pyc
│       │   ├── transforms.cpython-310.pyc
│       │   └── transforms.cpython-39.pyc
│       ├── amg.py
│       ├── onnx.py
│       └── transforms.py
├── split.py
├── train.py
├── utils
│   ├── SurfaceDice.py
│   ├── __init__.py
│   ├── __pycache__
│   │   ├── SurfaceDice.cpython-39.pyc
│   │   ├── __init__.cpython-310.pyc
│   │   ├── __init__.cpython-39.pyc
│   │   ├── callbacks.cpython-310.pyc
│   │   ├── callbacks.cpython-39.pyc
│   │   ├── dataloader.cpython-310.pyc
│   │   ├── dataloader.cpython-39.pyc
│   │   ├── utils.cpython-310.pyc
│   │   ├── utils.cpython-39.pyc
│   │   ├── utils_bbox.cpython-310.pyc
│   │   ├── utils_bbox.cpython-39.pyc
│   │   ├── utils_fit.cpython-310.pyc
│   │   ├── utils_fit.cpython-39.pyc
│   │   ├── utils_map.cpython-310.pyc
│   │   └── utils_map.cpython-39.pyc
│   ├── callbacks.py
│   ├── dataloader.py
│   ├── dataset.py
│   ├── precompute_img_embed.py
│   ├── utils.py
│   ├── utils_bbox.py
│   ├── utils_fit.py
│   └── utils_map.py
└── utils_metrics.py
/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cshan-github/TongueSAM/3f6e8c620e4d89e669a92a22f3be3d007932ce45/.DS_Store
--------------------------------------------------------------------------------
/1.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cshan-github/TongueSAM/3f6e8c620e4d89e669a92a22f3be3d007932ce45/1.jpg
--------------------------------------------------------------------------------
/3.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cshan-github/TongueSAM/3f6e8c620e4d89e669a92a22f3be3d007932ce45/3.jpg
--------------------------------------------------------------------------------
/4.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cshan-github/TongueSAM/3f6e8c620e4d89e669a92a22f3be3d007932ce45/4.jpg
--------------------------------------------------------------------------------
/LICENSE.md:
--------------------------------------------------------------------------------
1 | MIT License
2 | 
3 | Copyright (c) [2023] [Shan Cao]
4 | 
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 | 
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # TongueSAM: An Universal Tongue Segmentation Model Based on SAM with Zero-Shot
2 | This is the public repository for the paper "TongueSAM: An Universal Tongue Segmentation Model Based on SAM with Zero-Shot". The paper is available at: https://arxiv.org/abs/2308.06444.
3 | 
4 | ## Abstract
5 | 
6 | Tongue segmentation serves as the primary step in automated TCM tongue diagnosis and plays a significant role in the diagnostic results. Currently, numerous deep learning-based methods have achieved promising results. However, most of these methods exhibit mediocre performance on tongues different from the training set. To address this issue, this paper proposes a universal tongue segmentation model named TongueSAM based on SAM (Segment Anything Model). SAM is a large-scale pretrained interactive segmentation model known for its powerful zero-shot generalization capability. Applying SAM to tongue segmentation enables the segmentation of various types of tongue images with zero-shot. In this study, a Prompt Generator based on object detection
7 | is integrated into SAM to enable an end-to-end automated tongue segmentation method. Experiments demonstrate that TongueSAM achieves exceptional performance across various tongue segmentation datasets, particularly in the zero-shot setting. TongueSAM can be directly applied to other datasets without fine-tuning. To the best of our knowledge, this is the first application of a large-scale pretrained model to tongue segmentation.
8 | 
9 | ## Method
10 | 
11 | TongueSAM consists primarily of two components: SAM and the Prompt Generator. For a given tongue image, TongueSAM first utilizes the pretrained Image Encoder in SAM for encoding. Meanwhile, the Prompt Generator generates a bounding box prompt based on the tongue image. Finally, the image embedding and prompts are jointly fed into the Mask Decoder to generate the segmentation result. The entire segmentation process is end-to-end and does not require any additional manual prompts. The following sections introduce the different components of TongueSAM.
12 | 
13 | 
14 |
15 |
16 |
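The end-to-end flow described above can be sketched in a few lines of Python (adapted from ```predict.py``` in this repository; the device, checkpoint path and input image below are placeholders):

```python
import numpy as np
import torch
from skimage import io, transform
from segment_anything import sam_model_registry
from segment_anything.utils.transforms import ResizeLongestSide
from segment.yolox import YOLOX

device = 'cuda:0'                                   # placeholder
sam = sam_model_registry['vit_b'](checkpoint='./pretrained_model/tonguesam.pth').to(device)
sam.eval()
prompt_generator = YOLOX()                          # object-detection based Prompt Generator

image = io.imread('./data/test_in/1.jpg')
image = np.uint8(transform.resize(image, (400, 400), order=3,
                                  preserve_range=True, anti_aliasing=True))

with torch.no_grad():
    # 1) encode the tongue image with SAM's pretrained Image Encoder
    sam_resize = ResizeLongestSide(sam.image_encoder.img_size)
    x = torch.as_tensor(sam_resize.apply_image(image).transpose(2, 0, 1))[None].to(device)
    embedding = sam.image_encoder(sam.preprocess(x))

    # 2) generate a bounding-box prompt from the same image
    box = prompt_generator.get_prompt(image)
    box_torch = None if box is None else torch.as_tensor(
        sam_resize.apply_boxes(box, (400, 400)), dtype=torch.float, device=device)

    # 3) feed the embedding and the prompt into the Mask Decoder
    sparse, dense = sam.prompt_encoder(points=None, boxes=box_torch, masks=None)
    prob, _ = sam.mask_decoder(image_embeddings=embedding,
                               image_pe=sam.prompt_encoder.get_dense_pe(),
                               sparse_prompt_embeddings=sparse,
                               dense_prompt_embeddings=dense,
                               multimask_output=False)
    mask = (prob.squeeze().cpu().numpy() > 0.5).astype(np.uint8)   # low-resolution binary mask
```

```predict.py``` additionally resizes this mask back to 400 x 400 and overlays it, together with the box prompt, on the input image.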
17 | ## Result
18 |
19 |
20 |
21 |
22 | ## DataSet
23 |
24 | In our experiments, we used three tongue image segmentation datasets: TongueSet1, TongueSet2 (BioHit), and TongueSet3. TongueSet1 cannot be made public at the moment due to privacy concerns. [TongueSet2](https://github.com/BioHit/TongeImageDataset) is already publicly available, and we are now releasing TongueSet3 [here](https://pan.baidu.com/s/1TCcbwMYraSPzWeI60EME0A?pwd=ttm4).<br>
25 |
26 | TongueSet3 is a dataset we compiled by selecting 1,000 tongue images from this [website](https://aistudio.baidu.com/datasetdetail/196398) and manually segmenting them with the [Labelme](https://github.com/wkentaro/labelme) tool. The dataset covers a wide range of tongue images from various sources, including images captured with mobile devices and from non-standard angles. To our knowledge, this is the first publicly available tongue image segmentation dataset collected in an unconstrained environment. Because the original tongue images from the website vary in size, we resized each image to [400, 400] pixels to ensure input consistency. In the released files, the "img" folder contains the original input tongue images, and the "gt" folder contains our manually annotated ground truth segmentations. **Note that the images in the "gt" folder may appear completely black: pixels with a value of [1, 1, 1] represent the tongue region, while pixels with a value of [0, 0, 0] represent the background. Please be mindful of this distinction.**<br>
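Because the tongue pixels are stored as [1, 1, 1], a ground-truth file looks black in an ordinary image viewer. A minimal sketch for turning one of these masks into a visible binary image (the file names are placeholders):

```python
import numpy as np
from skimage import io

gt = io.imread('gt/1.png')            # placeholder path inside the released "gt" folder
if gt.ndim == 3:                      # masks may be stored with three identical channels
    gt = gt[:, :, 0]
tongue = (gt == 1).astype(np.uint8)   # 1 = tongue, 0 = background
io.imsave('1_visible.png', tongue * 255)
print('tongue pixels:', int(tongue.sum()))
```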
27 |
28 |
29 |
30 |
31 | ## Project Description
32 |
33 | **1.Zero-Shot Segmentation**
34 |
35 | The most crucial capability of TongueSAM lies in its Zero-Shot segmentation. To facilitate user adoption, we employed the three datasets mentioned in the paper for fine-tuning TongueSAM and openly released the pre-trained model. Users can perform tongue image segmentation directly using TongueSAM with just a few straightforward steps.
36 |
37 | Download the pre-trained weights: [TongueSAM](https://pan.baidu.com/s/1zG0jpYshlBs3lcdy4F37dQ?pwd=xtfg)<br>
38 |
39 | Put the ```tonguesam.pth``` into the ```./pretrained_model/``` folder.
40 |
41 | Place the tongue image files that need to be segmented into the ```./data/test_in/``` folder.
42 |
43 | Run ```predict.py```<br>
44 |
45 | The segmented tongue images will be located in the ```./data/test_out/``` folder.
46 |
47 | **2.Fine-tune**
48 |
49 | If you wish to further fine-tune the model, please follow these steps:
50 |
51 | To train the Prompt Generator based on YOLOX, please refer to the following guidelines: [YOLOX](https://github.com/bubbliiiing/yolox-pytorch)<br>
52 | <br>
53 | Replace ```./segment/yolox.pth``` with your trained model.<br>
54 | <br>
55 | Run ```split.py``` twice, setting ```src_folder``` to your image folder and to your ground-truth folder, respectively.<br>
56 | <br>
57 | Run ```pre_tongue.py```, setting ```img_path``` and ```gt_path``` to your processed image and ground-truth folder paths, respectively. For other parameter settings, refer to [MedSAM](https://github.com/bowang-lab/MedSAM).<br>
58 | <br>
59 | Run ```train.py```; please refer to the following guidelines: [MedSAM](https://github.com/bowang-lab/MedSAM). A sketch of how these steps chain together is given below.<br>
60 |
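A sketch of chaining the fine-tuning steps above from Python. The ```-i```/```-gt```/```-o``` flags are the ones defined in ```pre_tongue.py```; the folder paths are placeholders, and ```src_folder``` is set inside ```split.py``` itself before each run:

```python
import subprocess

# 1) split the images and the ground-truth masks into sub-folders
#    (edit src_folder inside split.py before each of the two runs)
subprocess.run(['python', 'split.py'], check=True)

# 2) precompute the npz training data (images, masks, box prompts, SAM embeddings)
subprocess.run(['python', 'pre_tongue.py',
                '-i', '/path/to/tongueset_split/img/',
                '-gt', '/path/to/tongueset_split/gt/',
                '-o', '/path/to/tongueset_npz/'], check=True)

# 3) fine-tune following the MedSAM training recipe
subprocess.run(['python', 'train.py'], check=True)
```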
61 | ## Acknowledgements
62 |
63 | The project is based on [YOLOX](https://github.com/bubbliiiing/yolox-pytorch) and [MedSAM](https://github.com/bowang-lab/MedSAM), and we appreciate their contributions.
64 |
65 | ## License
66 |
67 | This project is licensed under the [MIT LICENSE](https://github.com/cshan-github/TongueSAM/blob/main/LICENSE.md).
68 |
69 |
--------------------------------------------------------------------------------
/__pycache__/utils_metrics.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cshan-github/TongueSAM/3f6e8c620e4d89e669a92a22f3be3d007932ce45/__pycache__/utils_metrics.cpython-39.pyc
--------------------------------------------------------------------------------
/data/test_in/1.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cshan-github/TongueSAM/3f6e8c620e4d89e669a92a22f3be3d007932ce45/data/test_in/1.jpg
--------------------------------------------------------------------------------
/data/test_in/10.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cshan-github/TongueSAM/3f6e8c620e4d89e669a92a22f3be3d007932ce45/data/test_in/10.jpg
--------------------------------------------------------------------------------
/data/test_in/100.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cshan-github/TongueSAM/3f6e8c620e4d89e669a92a22f3be3d007932ce45/data/test_in/100.jpg
--------------------------------------------------------------------------------
/data/test_in/11.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cshan-github/TongueSAM/3f6e8c620e4d89e669a92a22f3be3d007932ce45/data/test_in/11.jpg
--------------------------------------------------------------------------------
/data/test_in/12.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cshan-github/TongueSAM/3f6e8c620e4d89e669a92a22f3be3d007932ce45/data/test_in/12.jpg
--------------------------------------------------------------------------------
/data/test_in/13.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cshan-github/TongueSAM/3f6e8c620e4d89e669a92a22f3be3d007932ce45/data/test_in/13.jpg
--------------------------------------------------------------------------------
/data/test_in/14.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cshan-github/TongueSAM/3f6e8c620e4d89e669a92a22f3be3d007932ce45/data/test_in/14.jpg
--------------------------------------------------------------------------------
/data/test_in/15.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cshan-github/TongueSAM/3f6e8c620e4d89e669a92a22f3be3d007932ce45/data/test_in/15.jpg
--------------------------------------------------------------------------------
/data/test_in/16.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cshan-github/TongueSAM/3f6e8c620e4d89e669a92a22f3be3d007932ce45/data/test_in/16.jpg
--------------------------------------------------------------------------------
/data/test_in/17.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cshan-github/TongueSAM/3f6e8c620e4d89e669a92a22f3be3d007932ce45/data/test_in/17.jpg
--------------------------------------------------------------------------------
/data/test_in/18.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cshan-github/TongueSAM/3f6e8c620e4d89e669a92a22f3be3d007932ce45/data/test_in/18.jpg
--------------------------------------------------------------------------------
/data/test_in/19.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cshan-github/TongueSAM/3f6e8c620e4d89e669a92a22f3be3d007932ce45/data/test_in/19.jpg
--------------------------------------------------------------------------------
/data/test_out/1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cshan-github/TongueSAM/3f6e8c620e4d89e669a92a22f3be3d007932ce45/data/test_out/1.png
--------------------------------------------------------------------------------
/data/test_out/10.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cshan-github/TongueSAM/3f6e8c620e4d89e669a92a22f3be3d007932ce45/data/test_out/10.png
--------------------------------------------------------------------------------
/data/test_out/100.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cshan-github/TongueSAM/3f6e8c620e4d89e669a92a22f3be3d007932ce45/data/test_out/100.png
--------------------------------------------------------------------------------
/data/test_out/11.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cshan-github/TongueSAM/3f6e8c620e4d89e669a92a22f3be3d007932ce45/data/test_out/11.png
--------------------------------------------------------------------------------
/data/test_out/12.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cshan-github/TongueSAM/3f6e8c620e4d89e669a92a22f3be3d007932ce45/data/test_out/12.png
--------------------------------------------------------------------------------
/data/test_out/13.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cshan-github/TongueSAM/3f6e8c620e4d89e669a92a22f3be3d007932ce45/data/test_out/13.png
--------------------------------------------------------------------------------
/data/test_out/14.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cshan-github/TongueSAM/3f6e8c620e4d89e669a92a22f3be3d007932ce45/data/test_out/14.png
--------------------------------------------------------------------------------
/data/test_out/15.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cshan-github/TongueSAM/3f6e8c620e4d89e669a92a22f3be3d007932ce45/data/test_out/15.png
--------------------------------------------------------------------------------
/data/test_out/16.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cshan-github/TongueSAM/3f6e8c620e4d89e669a92a22f3be3d007932ce45/data/test_out/16.png
--------------------------------------------------------------------------------
/data/test_out/17.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cshan-github/TongueSAM/3f6e8c620e4d89e669a92a22f3be3d007932ce45/data/test_out/17.png
--------------------------------------------------------------------------------
/data/test_out/18.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cshan-github/TongueSAM/3f6e8c620e4d89e669a92a22f3be3d007932ce45/data/test_out/18.png
--------------------------------------------------------------------------------
/data/test_out/19.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cshan-github/TongueSAM/3f6e8c620e4d89e669a92a22f3be3d007932ce45/data/test_out/19.png
--------------------------------------------------------------------------------
/pre_tongue.py:
--------------------------------------------------------------------------------
1 | #%% import packages
2 | import numpy as np
3 | import os
4 | join = os.path.join
5 | from skimage import transform, io, segmentation,util,exposure
6 | from tqdm import tqdm
7 | import torch
8 | from segment_anything import sam_model_registry
9 | from segment_anything.utils.transforms import ResizeLongestSide
10 | import argparse
11 | import torch.nn as nn
12 | import cv2
13 |
14 | from PIL import Image
15 | from tqdm import tqdm
16 | # from segment.nets.deeplabv3_plus import DeepLab
17 | from segment.deeplab import DeeplabV3
18 |
19 | # set up the parser
20 | parser = argparse.ArgumentParser(description='preprocess grey and RGB images')
21 | parser.add_argument('-i', '--img_path', type=str, default='/home/disk/cs/project/dataset/segmentation/tongueset3_split/img/', help='path to the images')
22 | parser.add_argument('-gt', '--gt_path', type=str, default='/home/disk/cs/project/dataset/segmentation/tongueset3_split/gt/', help='path to the ground truth (gt)')
23 | parser.add_argument('-o', '--npz_path', type=str, default='/home/disk/cs/project/dataset/segmentation/tongueset3_npz2/', help='path to save the npz files')
24 | parser.add_argument('--data_name', type=str, default='tongue', help='dataset name; used to name the final npz file, e.g., demo2d.npz')
25 | parser.add_argument('--image_size', type=int, default=400, help='image size')
26 | parser.add_argument('--img_name_suffix', type=str, default='.jpg', help='image name suffix')
27 | parser.add_argument('--label_id', type=int, default=1, help='label id')
28 | parser.add_argument('--model_type', type=str, default='vit_b', help='model type')
29 | parser.add_argument('--checkpoint', type=str, default='./pretrained_model/sam.pth', help='checkpoint')
30 | parser.add_argument('--device', type=str, default='cuda:0', help='device')
31 | parser.add_argument('--seed', type=int, default=2023, help='random seed')
32 | args = parser.parse_args()
33 |
34 |
35 | def semantic_segmentation_augmentation(image, label, rotation_range=(-90, 90)):
36 | stretch_range=(0.8, 1.2)
37 | # random crop
38 | h, w, _ = image.shape
39 | top = np.random.randint(0, h - 300)
40 | left = np.random.randint(0, w - 300)
41 | image = image[top:top+300, left:left+300]
42 | label = label[top:top+300, left:left+300]
43 |
44 | # random horizontal flip
45 | if np.random.rand() > 0.5:
46 | image = np.fliplr(image)
47 | label = np.fliplr(label)
48 |
49 | # random rotation
50 | rotation_angle = np.random.uniform(rotation_range[0], rotation_range[1])
51 | image = rotate_image(image, rotation_angle)
52 | label = rotate_image(label, rotation_angle)
53 |
54 | # random color jitter (optional)
55 | image = augment_colors(image)
56 | # random stretching (disabled)
57 | # stretch_factor_x = np.random.uniform(stretch_range[0], stretch_range[1])
58 | # stretch_factor_y = np.random.uniform(stretch_range[0], stretch_range[1])
59 | # image = stretch_image(image, stretch_factor_x, stretch_factor_y)
60 | # label = stretch_image(label, stretch_factor_x, stretch_factor_y)
61 |
62 | return image, label
63 |
64 | def augment_colors(image):
65 | # randomly draw contrast and brightness gains
66 | contrast_factor = np.random.uniform(0.5, 1.5) # adjust this range to get the desired effect
67 | brightness_factor = np.random.randint(-50, 51) # range of the brightness gain
68 |
69 | # apply the contrast and brightness change
70 | image = cv2.convertScaleAbs(image, alpha=contrast_factor, beta=brightness_factor)
71 |
72 |
73 | return image
74 |
75 | def rotate_image(image, angle):
76 | # rotate the image
77 | image = transform.rotate(image, angle, resize=False, mode='reflect')
78 | return util.img_as_ubyte(image)
79 |
80 |
81 | def deal(img_path,gt_path,num):
82 | names = sorted(os.listdir(gt_path))
83 | save_path = args.npz_path
84 | os.makedirs(save_path, exist_ok=True)
85 | print('image number:', len(names))
86 | imgs = []
87 | gts = []
88 | boxes=[]
89 | img_embeddings = []
90 | for gt_name in tqdm(names):
91 | sam_model = sam_model_registry[args.model_type](checkpoint=args.checkpoint).to(args.device)
92 | image_name = gt_name.split('.')[0] + args.img_name_suffix
93 | gt_data = io.imread(join(gt_path, gt_name))
94 | image_data = io.imread(join(img_path, image_name))
95 | image_data,gt_data=semantic_segmentation_augmentation(image_data, gt_data)
96 | # cv2.imwrite(gt_name,image_data)
97 | gt_data=cv2.resize(gt_data,(args.image_size,args.image_size))
98 | if len(gt_data.shape)==3:
99 | gt_data = gt_data[:,:,0]
100 | assert len(gt_data.shape)==2, 'ground truth should be 2D'
101 | gt_data = transform.resize(gt_data==args.label_id, (args.image_size, args.image_size), order=0, preserve_range=True, mode='constant')
102 | gt_data = np.uint8(gt_data)
103 |
104 | if image_data.shape[-1]>3 and len(image_data.shape)==3:
105 | image_data = image_data[:,:,:3]
106 | if len(image_data.shape)==2:
107 | image_data = np.repeat(image_data[:,:,None], 3, axis=-1)
108 | if gt_data.shape[-1]==3:
109 | gt=gt_data
110 | z=np.zeros([gt.shape[0],gt.shape[1]])
111 | for i in range(gt.shape[0]):
112 | for j in range(gt.shape[1]):
113 | if gt[i][j][0]==1:
114 | z[i][j]=1
115 | gt=z
116 | gt_data=gt
117 |
118 | lower_bound, upper_bound = np.percentile(image_data, 0.5), np.percentile(image_data, 99.5)
119 | image_data_pre = np.clip(image_data, lower_bound, upper_bound)
120 | image_data_pre = (image_data_pre - np.min(image_data_pre))/(np.max(image_data_pre)-np.min(image_data_pre))*255.0
121 | image_data_pre[image_data==0] = 0
122 | image_data_pre = transform.resize(image_data_pre, (args.image_size,args.image_size), order=3, preserve_range=True, mode='constant', anti_aliasing=True)
123 | image_data_pre = np.uint8(image_data_pre)
124 | imgs.append(image_data_pre)
125 | gts.append(gt_data)
126 | H, W, _ = image_data_pre.shape
127 | sam_transform = ResizeLongestSide(sam_model.image_encoder.img_size)
128 | resize_img = sam_transform.apply_image(image_data_pre)
129 | resize_img_tensor = torch.as_tensor(resize_img.transpose(2, 0, 1)).to(args.device)
130 | input_image = sam_model.preprocess(resize_img_tensor[None,:,:,:]) # (1, 3, 1024, 1024)
131 | assert input_image.shape == (1, 3, sam_model.image_encoder.img_size, sam_model.image_encoder.img_size), 'input image should be resized to 1024*1024'
132 | embedding = sam_model.image_encoder(input_image)
133 | img_embeddings.append(embedding.cpu().detach().numpy()[0])
134 | ##########################################################################################
135 | # pic=Image.open(join(img_path, image_name))
136 | # pic=Image.fromarray(image_data_pre)
137 | # pic= model.get_miou_png(pic)
138 | y_indices, x_indices = np.where(gt_data > 0)
139 | xmin, xmax = np.min(x_indices), np.max(x_indices)
140 | ymin, ymax = np.min(y_indices), np.max(y_indices)
141 | box=np.array([xmin,ymin,xmax,ymax])
142 | boxes.append(box)
143 | ##########################################################################################
144 | del sam_model
145 | print('Num. of images:', len(imgs))
146 | if len(imgs)>0:
147 | imgs = np.stack(imgs, axis=0) # (n, 256, 256, 3)
148 | gts = np.stack(gts, axis=0) # (n, 256, 256)
149 | img_embeddings = np.stack(img_embeddings, axis=0) # (n, 1, 256, 64, 64)
150 | np.savez_compressed(join(save_path, args.data_name + str(num)+'.npz'), imgs=imgs, boxes=boxes,gts=gts, img_embeddings=img_embeddings)
151 | # save an example image for sanity check
152 | idx = np.random.randint(imgs.shape[0])
153 | img_idx = imgs[idx,:,:,:]
154 | gt_idx = gts[idx,:,:]
155 | bd = segmentation.find_boundaries(gt_idx, mode='inner')
156 | img_idx[bd, :] = [args.image_size-1, 0, 0]
157 | # io.imsave(save_path + '.png', img_idx, check_contrast=False)
158 | else:
159 | print('Do not find image and ground-truth pairs. Please check your dataset and argument settings')
160 |
161 | num=0
162 | # for i in range(10):
163 | for f in os.listdir(args.img_path):
164 | print(f)
165 | deal(args.img_path+'/'+f,args.gt_path+'/'+f,num)  # build the npz for this split sub-folder
166 | num+=1
167 |
--------------------------------------------------------------------------------
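As a quick sanity check of what the ```deal()``` function above writes, the resulting ```.npz``` can be inspected as follows (the file name is a placeholder following the ```data_name + str(num) + '.npz'``` pattern, and the shapes assume the default ```image_size=400``` with a ViT-B image encoder):

```python
import numpy as np

data = np.load('/path/to/tongueset_npz/tongue0.npz')
print(list(data.keys()))             # ['imgs', 'boxes', 'gts', 'img_embeddings']
print(data['imgs'].shape)            # (n, 400, 400, 3)  uint8 images
print(data['gts'].shape)             # (n, 400, 400)     binary masks
print(data['boxes'].shape)           # (n, 4)            xmin, ymin, xmax, ymax prompts
print(data['img_embeddings'].shape)  # (n, 256, 64, 64)  SAM image embeddings
```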
/predict.py:
--------------------------------------------------------------------------------
1 | from PIL import Image, ImageDraw
2 | import numpy as np
3 | import matplotlib.pyplot as plt
4 | import os
5 | from skimage import io
6 | join = os.path.join
7 | from tqdm import tqdm
8 | import torch
9 | from torch.utils.data import Dataset, DataLoader
10 | import monai
11 | from segment_anything import SamPredictor, sam_model_registry
12 | from segment_anything.utils.transforms import ResizeLongestSide
13 | from utils.SurfaceDice import compute_dice_coefficient
14 | import cv2
15 | from sklearn.metrics import accuracy_score, confusion_matrix, precision_score, recall_score, f1_score, jaccard_score
16 | # set seeds
17 | torch.manual_seed(2023)
18 | np.random.seed(2023)
19 | from skimage import io
20 | from utils_metrics import *
21 | from skimage import transform, io, segmentation
22 | from segment.yolox import YOLOX
23 | import random
24 | import warnings
25 |
26 | # permanently suppress warnings of the specified category
27 | warnings.filterwarnings("ignore", category=UserWarning)
28 | #########################################################################################################
29 | ts_img_path = './data/test_in/'
30 | model_type = 'vit_b'
31 | checkpoint = './pretrained_model/tonguesam.pth'
32 | device = 'cuda:1'
33 | path_out='./data/test_out/'
34 | segment=YOLOX()
35 | ##############################################################################################################
36 | def get_bbox_from_mask(mask):
37 | '''Returns a bounding box from a mask'''
38 | y_indices, x_indices = np.where(mask > 0)
39 | x_min, x_max = np.min(x_indices), np.max(x_indices)
40 | y_min, y_max = np.min(y_indices), np.max(y_indices)
41 | # add perturbation to bounding box coordinates
42 | H, W = mask.shape
43 | x_min = max(0, x_min - np.random.randint(0, 20))
44 | x_max = min(W, x_max + np.random.randint(0, 20))
45 | y_min = max(0, y_min - np.random.randint(0, 20))
46 | y_max = min(H, y_max + np.random.randint(0, 20))
47 |
48 | return np.array([x_min, y_min, x_max, y_max])
49 | def show_mask(mask, ax, random_color=False):
50 | if random_color:
51 | color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
52 | else:
53 | color = np.array([251/255, 252/255, 30/255, 0.6])
54 | h, w = mask.shape[-2:]
55 | mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
56 | ax.imshow(mask_image)
57 |
58 | def show_box(box, ax):
59 | x0, y0 = box[0], box[1]
60 | w, h = box[2] - box[0], box[3] - box[1]
61 | ax.add_patch(plt.Rectangle((x0, y0), w, h, edgecolor='blue', facecolor=(0,0,0,0), lw=2))
62 | best_iou=0
63 | test_names = sorted(os.listdir(ts_img_path))
64 |
65 | sam_model = sam_model_registry[model_type](checkpoint=checkpoint).to(device)
66 | #################################
67 | # prune_threshold =0.005
68 | # for param in sam_model.image_encoder.parameters():
69 | # param.data[torch.abs(param.data) < prune_threshold] = 0
70 | sam_model.eval()
71 | #################################
72 | val_gts=[]
73 | val_preds=[]
74 | for f in os.listdir(ts_img_path):
75 | with torch.no_grad():
76 | image_data = io.imread(join(ts_img_path, f))
77 |
78 | if image_data.shape[-1] > 3 and len(image_data.shape) == 3:
79 | image_data = image_data[:, :, :3]
80 | if len(image_data.shape) == 2:
81 | image_data = np.repeat(image_data[:, :, None], 3, axis=-1)
82 |
83 | lower_bound, upper_bound = np.percentile(image_data, 0.5), np.percentile(image_data, 99.5)
84 | image_data_pre = np.clip(image_data, lower_bound, upper_bound)
85 | image_data_pre = (image_data_pre - np.min(image_data_pre)) / (np.max(image_data_pre) - np.min(image_data_pre)) * 255.0
86 | image_data_pre[image_data == 0] = 0
87 | image_data_pre = transform.resize(image_data_pre, (400, 400), order=3, preserve_range=True, mode='constant', anti_aliasing=True)
88 | image_data_pre = np.uint8(image_data_pre)
89 |
90 | H, W, _ = image_data_pre.shape
91 | sam_transform = ResizeLongestSide(sam_model.image_encoder.img_size)
92 | resize_img = sam_transform.apply_image(image_data_pre)
93 | resize_img_tensor = torch.as_tensor(resize_img.transpose(2, 0, 1)).to(device)
94 | input_image = sam_model.preprocess(resize_img_tensor[None, :, :, :])
95 | ts_img_embedding = sam_model.image_encoder(input_image)
96 |
97 | img = image_data_pre
98 | boxes = segment.get_prompt(img)
99 |
100 | if boxes is not None:
101 | sam_trans = ResizeLongestSide(sam_model.image_encoder.img_size)
102 | box = sam_trans.apply_boxes(boxes, (400,400))
103 | box_torch = torch.as_tensor(box, dtype=torch.float, device=device)
104 | else:
105 | box_torch = None
106 | sparse_embeddings, dense_embeddings = sam_model.prompt_encoder(
107 | points=None,
108 | boxes=box_torch,
109 | masks=None,
110 | )
111 |
112 | # use the Mask Decoder to generate the segmentation result
113 | medsam_seg_prob, _ = sam_model.mask_decoder(
114 | image_embeddings=ts_img_embedding.to(device),
115 | image_pe=sam_model.prompt_encoder.get_dense_pe(),
116 | sparse_prompt_embeddings=sparse_embeddings,
117 | dense_prompt_embeddings=dense_embeddings,
118 | multimask_output=False,
119 | )
120 | medsam_seg_prob =medsam_seg_prob.cpu().detach().numpy().squeeze()
121 | medsam_seg = (medsam_seg_prob > 0.5).astype(np.uint8)
122 |
123 | medsam_seg=cv2.resize(medsam_seg,(400,400))
124 |
125 |
126 | pred = cv2.Canny(cv2.resize((medsam_seg != 0).astype(np.uint8) * 255, (400, 400)), 100, 200)
127 |
128 | for i in range(pred.shape[0]):
129 | for j in range(pred.shape[1]):
130 | if pred[i, j] != 0:
131 | img[max(i - 1, 0):min(i + 2, 400), max(j - 1, 0):min(j + 2, 400), :] = [0, 0, 255]
132 |
133 | image1 = Image.fromarray(medsam_seg)
134 | image2 = Image.fromarray(img)
135 |
136 | image1 = image1.resize(image2.size).convert("RGBA")
137 | image2 = image2.convert("RGBA")
138 | data1 = image1.getdata()
139 |
140 | new_image = Image.new("RGBA", image2.size)
141 | new_data = [(0, 0, 128, 96) if pixel1[0] != 0 else (0, 0, 0, 0) for pixel1 in data1]
142 |
143 | new_image.putdata(new_data)
144 | if boxes is not None:
145 | draw = ImageDraw.Draw(image2)
146 | draw.rectangle([boxes[0],boxes[1],boxes[2],boxes[3]],fill=None, outline=(0, 255, 0), width=5) # draw the prompt box outline in green, line width 5
147 | image2.paste(new_image, (0, 0), mask=new_image)
148 | image2.save(path_out + f.split('.')[0] + '.png')
149 | print(f)
150 |
--------------------------------------------------------------------------------
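If ground-truth masks are available for the test images, predictions can be scored with the ```compute_dice_coefficient``` helper that ```predict.py``` already imports. A sketch, assuming the MedSAM-style signature ```compute_dice_coefficient(mask_gt, mask_pred)``` and that the binary ```medsam_seg``` mask is saved separately (the script itself only saves RGBA overlays); the paths are placeholders and both masks are assumed to have the same size:

```python
import numpy as np
from skimage import io
from utils.SurfaceDice import compute_dice_coefficient

gt = io.imread('gt/1.png')        # 0 / 1 ground truth, as in TongueSet3
pred = io.imread('mask/1.png')    # binary prediction saved from medsam_seg (0 / 255)
if gt.ndim == 3:
    gt = gt[:, :, 0]
if pred.ndim == 3:
    pred = pred[:, :, 0]
dice = compute_dice_coefficient(gt > 0, pred > 0)   # both passed as boolean masks
print('Dice:', float(dice))
```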
/pretrained_model/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cshan-github/TongueSAM/3f6e8c620e4d89e669a92a22f3be3d007932ce45/pretrained_model/.DS_Store
--------------------------------------------------------------------------------
/segment/__pycache__/deeplab.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cshan-github/TongueSAM/3f6e8c620e4d89e669a92a22f3be3d007932ce45/segment/__pycache__/deeplab.cpython-39.pyc
--------------------------------------------------------------------------------
/segment/__pycache__/pspnet.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cshan-github/TongueSAM/3f6e8c620e4d89e669a92a22f3be3d007932ce45/segment/__pycache__/pspnet.cpython-39.pyc
--------------------------------------------------------------------------------
/segment/__pycache__/unet.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cshan-github/TongueSAM/3f6e8c620e4d89e669a92a22f3be3d007932ce45/segment/__pycache__/unet.cpython-39.pyc
--------------------------------------------------------------------------------
/segment/__pycache__/yolox.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cshan-github/TongueSAM/3f6e8c620e4d89e669a92a22f3be3d007932ce45/segment/__pycache__/yolox.cpython-39.pyc
--------------------------------------------------------------------------------
/segment/__pycache__/yolox_new.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cshan-github/TongueSAM/3f6e8c620e4d89e669a92a22f3be3d007932ce45/segment/__pycache__/yolox_new.cpython-39.pyc
--------------------------------------------------------------------------------
/segment/tongue_classes.txt:
--------------------------------------------------------------------------------
1 | tongue
--------------------------------------------------------------------------------
/segment/utils_yolox/__init__.py:
--------------------------------------------------------------------------------
1 | #
--------------------------------------------------------------------------------
/segment/utils_yolox/__pycache__/__init__.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cshan-github/TongueSAM/3f6e8c620e4d89e669a92a22f3be3d007932ce45/segment/utils_yolox/__pycache__/__init__.cpython-310.pyc
--------------------------------------------------------------------------------
/segment/utils_yolox/__pycache__/__init__.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cshan-github/TongueSAM/3f6e8c620e4d89e669a92a22f3be3d007932ce45/segment/utils_yolox/__pycache__/__init__.cpython-39.pyc
--------------------------------------------------------------------------------
/segment/utils_yolox/__pycache__/callbacks.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cshan-github/TongueSAM/3f6e8c620e4d89e669a92a22f3be3d007932ce45/segment/utils_yolox/__pycache__/callbacks.cpython-310.pyc
--------------------------------------------------------------------------------
/segment/utils_yolox/__pycache__/callbacks.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cshan-github/TongueSAM/3f6e8c620e4d89e669a92a22f3be3d007932ce45/segment/utils_yolox/__pycache__/callbacks.cpython-39.pyc
--------------------------------------------------------------------------------
/segment/utils_yolox/__pycache__/dataloader.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cshan-github/TongueSAM/3f6e8c620e4d89e669a92a22f3be3d007932ce45/segment/utils_yolox/__pycache__/dataloader.cpython-310.pyc
--------------------------------------------------------------------------------
/segment/utils_yolox/__pycache__/dataloader.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cshan-github/TongueSAM/3f6e8c620e4d89e669a92a22f3be3d007932ce45/segment/utils_yolox/__pycache__/dataloader.cpython-39.pyc
--------------------------------------------------------------------------------
/segment/utils_yolox/__pycache__/utils.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cshan-github/TongueSAM/3f6e8c620e4d89e669a92a22f3be3d007932ce45/segment/utils_yolox/__pycache__/utils.cpython-310.pyc
--------------------------------------------------------------------------------
/segment/utils_yolox/__pycache__/utils.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cshan-github/TongueSAM/3f6e8c620e4d89e669a92a22f3be3d007932ce45/segment/utils_yolox/__pycache__/utils.cpython-39.pyc
--------------------------------------------------------------------------------
/segment/utils_yolox/__pycache__/utils_bbox.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cshan-github/TongueSAM/3f6e8c620e4d89e669a92a22f3be3d007932ce45/segment/utils_yolox/__pycache__/utils_bbox.cpython-310.pyc
--------------------------------------------------------------------------------
/segment/utils_yolox/__pycache__/utils_bbox.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cshan-github/TongueSAM/3f6e8c620e4d89e669a92a22f3be3d007932ce45/segment/utils_yolox/__pycache__/utils_bbox.cpython-39.pyc
--------------------------------------------------------------------------------
/segment/utils_yolox/__pycache__/utils_fit.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cshan-github/TongueSAM/3f6e8c620e4d89e669a92a22f3be3d007932ce45/segment/utils_yolox/__pycache__/utils_fit.cpython-310.pyc
--------------------------------------------------------------------------------
/segment/utils_yolox/__pycache__/utils_fit.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cshan-github/TongueSAM/3f6e8c620e4d89e669a92a22f3be3d007932ce45/segment/utils_yolox/__pycache__/utils_fit.cpython-39.pyc
--------------------------------------------------------------------------------
/segment/utils_yolox/__pycache__/utils_map.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cshan-github/TongueSAM/3f6e8c620e4d89e669a92a22f3be3d007932ce45/segment/utils_yolox/__pycache__/utils_map.cpython-310.pyc
--------------------------------------------------------------------------------
/segment/utils_yolox/__pycache__/utils_map.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cshan-github/TongueSAM/3f6e8c620e4d89e669a92a22f3be3d007932ce45/segment/utils_yolox/__pycache__/utils_map.cpython-39.pyc
--------------------------------------------------------------------------------
/segment/utils_yolox/callbacks.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import torch
4 | import matplotlib
5 | matplotlib.use('Agg')
6 | import scipy.signal
7 | from matplotlib import pyplot as plt
8 | from torch.utils.tensorboard import SummaryWriter
9 |
10 | import shutil
11 | import numpy as np
12 |
13 | from PIL import Image
14 | from tqdm import tqdm
15 | from .utils import cvtColor, preprocess_input, resize_image
16 | from .utils_bbox import decode_outputs, non_max_suppression
17 | from .utils_map import get_coco_map, get_map
18 |
19 |
20 | class LossHistory():
21 | def __init__(self, log_dir, model, input_shape):
22 | self.log_dir = log_dir
23 | self.losses = []
24 | self.val_loss = []
25 |
26 | os.makedirs(self.log_dir)
27 | self.writer = SummaryWriter(self.log_dir)
28 | try:
29 | dummy_input = torch.randn(2, 3, input_shape[0], input_shape[1])
30 | self.writer.add_graph(model, dummy_input)
31 | except:
32 | pass
33 |
34 | def append_loss(self, epoch, loss, val_loss):
35 | if not os.path.exists(self.log_dir):
36 | os.makedirs(self.log_dir)
37 |
38 | self.losses.append(loss)
39 | self.val_loss.append(val_loss)
40 |
41 | with open(os.path.join(self.log_dir, "epoch_loss.txt"), 'a') as f:
42 | f.write(str(loss))
43 | f.write("\n")
44 | with open(os.path.join(self.log_dir, "epoch_val_loss.txt"), 'a') as f:
45 | f.write(str(val_loss))
46 | f.write("\n")
47 |
48 | self.writer.add_scalar('loss', loss, epoch)
49 | self.writer.add_scalar('val_loss', val_loss, epoch)
50 | self.loss_plot()
51 |
52 | def loss_plot(self):
53 | iters = range(len(self.losses))
54 |
55 | plt.figure()
56 | plt.plot(iters, self.losses, 'red', linewidth = 2, label='train loss')
57 | plt.plot(iters, self.val_loss, 'coral', linewidth = 2, label='val loss')
58 | try:
59 | if len(self.losses) < 25:
60 | num = 5
61 | else:
62 | num = 15
63 |
64 | plt.plot(iters, scipy.signal.savgol_filter(self.losses, num, 3), 'green', linestyle = '--', linewidth = 2, label='smooth train loss')
65 | plt.plot(iters, scipy.signal.savgol_filter(self.val_loss, num, 3), '#8B4513', linestyle = '--', linewidth = 2, label='smooth val loss')
66 | except:
67 | pass
68 |
69 | plt.grid(True)
70 | plt.xlabel('Epoch')
71 | plt.ylabel('Loss')
72 | plt.legend(loc="upper right")
73 |
74 | plt.savefig(os.path.join(self.log_dir, "epoch_loss.png"))
75 |
76 | plt.cla()
77 | plt.close("all")
78 |
79 | class EvalCallback():
80 | def __init__(self, net, input_shape, class_names, num_classes, val_lines, log_dir, cuda, \
81 | map_out_path=".temp_map_out", max_boxes=100, confidence=0.05, nms_iou=0.5, letterbox_image=True, MINOVERLAP=0.5, eval_flag=True, period=1):
82 | super(EvalCallback, self).__init__()
83 |
84 | self.net = net
85 | self.input_shape = input_shape
86 | self.class_names = class_names
87 | self.num_classes = num_classes
88 | self.val_lines = val_lines
89 | self.log_dir = log_dir
90 | self.cuda = cuda
91 | self.map_out_path = map_out_path
92 | self.max_boxes = max_boxes
93 | self.confidence = confidence
94 | self.nms_iou = nms_iou
95 | self.letterbox_image = letterbox_image
96 | self.MINOVERLAP = MINOVERLAP
97 | self.eval_flag = eval_flag
98 | self.period = period
99 |
100 | self.maps = [0]
101 | self.epoches = [0]
102 | if self.eval_flag:
103 | with open(os.path.join(self.log_dir, "epoch_map.txt"), 'a') as f:
104 | f.write(str(0))
105 | f.write("\n")
106 |
107 | def get_map_txt(self, image_id, image, class_names, map_out_path):
108 | f = open(os.path.join(map_out_path, "detection-results/"+image_id+".txt"),"w")
109 | image_shape = np.array(np.shape(image)[0:2])
110 | #---------------------------------------------------------#
111 | # Convert the image to RGB here to avoid errors when predicting on grayscale images.
112 | # The code only supports prediction on RGB images; all other image types are converted to RGB.
113 | #---------------------------------------------------------#
114 | image = cvtColor(image)
115 | #---------------------------------------------------------#
116 | # Pad the image with gray bars for a distortion-free resize;
117 | # a plain resize could also be used for detection.
118 | #---------------------------------------------------------#
119 | image_data = resize_image(image, (self.input_shape[1],self.input_shape[0]), self.letterbox_image)
120 | #---------------------------------------------------------#
121 | # add the batch_size dimension
122 | #---------------------------------------------------------#
123 | image_data = np.expand_dims(np.transpose(preprocess_input(np.array(image_data, dtype='float32')), (2, 0, 1)), 0)
124 |
125 | with torch.no_grad():
126 | images = torch.from_numpy(image_data)
127 | if self.cuda:
128 | images = images.cuda()
129 | #---------------------------------------------------------#
130 | # feed the image into the network for prediction
131 | #---------------------------------------------------------#
132 | outputs = self.net(images)
133 | outputs = decode_outputs(outputs, self.input_shape)
134 | #---------------------------------------------------------#
135 | # stack the predicted boxes, then apply non-maximum suppression
136 | #---------------------------------------------------------#
137 | results = non_max_suppression(outputs, self.num_classes, self.input_shape,
138 | image_shape, self.letterbox_image, conf_thres = self.confidence, nms_thres = self.nms_iou)
139 |
140 | if results[0] is None:
141 | return
142 |
143 | top_label = np.array(results[0][:, 6], dtype = 'int32')
144 | top_conf = results[0][:, 4] * results[0][:, 5]
145 | top_boxes = results[0][:, :4]
146 |
147 | top_100 = np.argsort(top_conf)[::-1][:self.max_boxes]
148 | top_boxes = top_boxes[top_100]
149 | top_conf = top_conf[top_100]
150 | top_label = top_label[top_100]
151 |
152 | for i, c in list(enumerate(top_label)):
153 | predicted_class = self.class_names[int(c)]
154 | box = top_boxes[i]
155 | score = str(top_conf[i])
156 |
157 | top, left, bottom, right = box
158 | if predicted_class not in class_names:
159 | continue
160 |
161 | f.write("%s %s %s %s %s %s\n" % (predicted_class, score[:6], str(int(left)), str(int(top)), str(int(right)),str(int(bottom))))
162 |
163 | f.close()
164 | return
165 |
166 | def on_epoch_end(self, epoch, model_eval):
167 | if epoch % self.period == 0 and self.eval_flag:
168 | self.net = model_eval
169 | if not os.path.exists(self.map_out_path):
170 | os.makedirs(self.map_out_path)
171 | if not os.path.exists(os.path.join(self.map_out_path, "ground-truth")):
172 | os.makedirs(os.path.join(self.map_out_path, "ground-truth"))
173 | if not os.path.exists(os.path.join(self.map_out_path, "detection-results")):
174 | os.makedirs(os.path.join(self.map_out_path, "detection-results"))
175 | print("Get map.")
176 | for annotation_line in tqdm(self.val_lines):
177 | line = annotation_line.split()
178 | image_id = os.path.basename(line[0]).split('.')[0]
179 | #------------------------------#
180 | # read the image and convert it to RGB
181 | #------------------------------#
182 | image = Image.open(line[0])
183 | #------------------------------#
184 | # parse the ground-truth boxes from the annotation line
185 | #------------------------------#
186 | gt_boxes = np.array([np.array(list(map(int,box.split(',')))) for box in line[1:]])
187 | #------------------------------#
188 | # write the detection-results txt
189 | #------------------------------#
190 | self.get_map_txt(image_id, image, self.class_names, self.map_out_path)
191 |
192 | #------------------------------#
193 | # write the ground-truth txt
194 | #------------------------------#
195 | with open(os.path.join(self.map_out_path, "ground-truth/"+image_id+".txt"), "w") as new_f:
196 | for box in gt_boxes:
197 | left, top, right, bottom, obj = box
198 | obj_name = self.class_names[obj]
199 | new_f.write("%s %s %s %s %s\n" % (obj_name, left, top, right, bottom))
200 |
201 | print("Calculate Map.")
202 | try:
203 | temp_map = get_coco_map(class_names = self.class_names, path = self.map_out_path)[1]
204 | except:
205 | temp_map = get_map(self.MINOVERLAP, False, path = self.map_out_path)
206 | self.maps.append(temp_map)
207 | self.epoches.append(epoch)
208 |
209 | with open("./epoch_map.txt", 'a') as f:
210 | f.write(str(temp_map))
211 | f.write("\n")
212 |
213 | plt.figure()
214 | plt.plot(self.epoches, self.maps, 'red', linewidth = 2, label='train map')
215 |
216 | plt.grid(True)
217 | plt.xlabel('Epoch')
218 | plt.ylabel('Map %s'%str(self.MINOVERLAP))
219 | plt.title('A Map Curve')
220 | plt.legend(loc="upper right")
221 |
222 | plt.savefig(os.path.join(self.log_dir, "epoch_map.png"))
223 | plt.cla()
224 | plt.close("all")
225 |
226 | print("Get map done.")
227 | shutil.rmtree(self.map_out_path)
228 |
--------------------------------------------------------------------------------
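A minimal usage sketch for the ```LossHistory``` helper above. Note that the constructor calls ```os.makedirs(log_dir)```, so the log directory must not exist yet; the tiny convolution here is only a stand-in model for the graph dump, and the import path assumes the repository root:

```python
import torch.nn as nn
from segment.utils_yolox.callbacks import LossHistory

model = nn.Conv2d(3, 8, kernel_size=3)
history = LossHistory('logs/run1', model, input_shape=(640, 640))

for epoch, (train_loss, val_loss) in enumerate([(1.20, 1.35), (0.95, 1.10), (0.80, 1.02)]):
    # appends to epoch_loss.txt / epoch_val_loss.txt, logs TensorBoard scalars
    # and refreshes logs/run1/epoch_loss.png
    history.append_loss(epoch, train_loss, val_loss)
```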
/segment/utils_yolox/utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from PIL import Image
3 |
4 |
5 | #---------------------------------------------------------#
6 | # Convert the image to RGB to avoid errors when predicting on grayscale images.
7 | # The code only supports prediction on RGB images; all other image types are converted to RGB.
8 | #---------------------------------------------------------#
9 | def cvtColor(image):
10 | if len(np.shape(image)) == 3 and np.shape(image)[2] == 3:
11 | return image
12 | else:
13 | image = image.convert('RGB')
14 | return image
15 |
16 | #---------------------------------------------------#
17 | # resize the input image
18 | #---------------------------------------------------#
19 | def resize_image(image, size, letterbox_image):
20 | iw, ih = image.size
21 | w, h = size
22 | if letterbox_image:
23 | scale = min(w/iw, h/ih)
24 | nw = int(iw*scale)
25 | nh = int(ih*scale)
26 |
27 | image = image.resize((nw,nh), Image.BICUBIC)
28 | new_image = Image.new('RGB', size, (128,128,128))
29 | new_image.paste(image, ((w-nw)//2, (h-nh)//2))
30 | else:
31 | new_image = image.resize((w, h), Image.BICUBIC)
32 | return new_image
33 |
34 | #---------------------------------------------------#
35 | # read the class names
36 | #---------------------------------------------------#
37 | def get_classes(classes_path):
38 | with open(classes_path, encoding='utf-8') as f:
39 | class_names = f.readlines()
40 | class_names = [c.strip() for c in class_names]
41 | return class_names, len(class_names)
42 |
43 | def preprocess_input(image):
44 | image /= 255.0
45 | image -= np.array([0.485, 0.456, 0.406])
46 | image /= np.array([0.229, 0.224, 0.225])
47 | return image
48 |
49 | #---------------------------------------------------#
50 | # get the current learning rate
51 | #---------------------------------------------------#
52 | def get_lr(optimizer):
53 | for param_group in optimizer.param_groups:
54 | return param_group['lr']
55 |
56 | def show_config(**kwargs):
57 | print('Configurations:')
58 | print('-' * 70)
59 | print('|%25s | %40s|' % ('keys', 'values'))
60 | print('-' * 70)
61 | for key, value in kwargs.items():
62 | print('|%25s | %40s|' % (str(key), str(value)))
63 | print('-' * 70)
64 |
--------------------------------------------------------------------------------
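A quick sketch of the two small helpers above, using the class file shipped in ```segment/tongue_classes.txt``` (run from the repository root; the import path is an assumption):

```python
import numpy as np
from segment.utils_yolox.utils import get_classes, preprocess_input

class_names, num_classes = get_classes('segment/tongue_classes.txt')
print(class_names, num_classes)          # ['tongue'] 1

# preprocess_input expects a float RGB image scaled 0-255 and applies
# ImageNet mean/std normalisation
img = np.random.randint(0, 256, (640, 640, 3)).astype('float32')
img = preprocess_input(img)
print(round(float(img.mean()), 3), round(float(img.std()), 3))
```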
/segment/utils_yolox/utils_bbox.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | from torchvision.ops import nms, boxes
4 |
5 | def yolo_correct_boxes(box_xy, box_wh, input_shape, image_shape, letterbox_image):
6 | #-----------------------------------------------------------------#
7 | # the y axis is put first so the boxes can be multiplied by the image height/width conveniently
8 | #-----------------------------------------------------------------#
9 | box_yx = box_xy[..., ::-1]
10 | box_hw = box_wh[..., ::-1]
11 | input_shape = np.array(input_shape)
12 | image_shape = np.array(image_shape)
13 |
14 | if letterbox_image:
15 | #-----------------------------------------------------------------#
16 | # offset is the shift of the valid image region relative to the top-left corner of the image
17 | # new_shape is the scaled width and height
18 | #-----------------------------------------------------------------#
19 | new_shape = np.round(image_shape * np.min(input_shape/image_shape))
20 | offset = (input_shape - new_shape)/2./input_shape
21 | scale = input_shape/new_shape
22 |
23 | box_yx = (box_yx - offset) * scale
24 | box_hw *= scale
25 |
26 | box_mins = box_yx - (box_hw / 2.)
27 | box_maxes = box_yx + (box_hw / 2.)
28 | boxes = np.concatenate([box_mins[..., 0:1], box_mins[..., 1:2], box_maxes[..., 0:1], box_maxes[..., 1:2]], axis=-1)
29 | boxes *= np.concatenate([image_shape, image_shape], axis=-1)
30 | return boxes
31 |
32 | def decode_outputs(outputs, input_shape):
33 | grids = []
34 | strides = []
35 | hw = [x.shape[-2:] for x in outputs]
36 | #---------------------------------------------------#
37 | # before this step, outputs holds the prediction of each feature level:
38 | # batch_size, 4 + 1 + num_classes, 80, 80 => batch_size, 4 + 1 + num_classes, 6400
39 | # batch_size, 5 + num_classes, 40, 40
40 | # batch_size, 5 + num_classes, 20, 20
41 | # batch_size, 4 + 1 + num_classes, 6400 + 1600 + 400 -> batch_size, 4 + 1 + num_classes, 8400
42 | # after stacking: batch_size, 8400, 5 + num_classes
43 | #---------------------------------------------------#
44 | outputs = torch.cat([x.flatten(start_dim=2) for x in outputs], dim=2).permute(0, 2, 1)
45 | #---------------------------------------------------#
46 | # probability of each feature point belonging to each class
47 | #---------------------------------------------------#
48 | outputs[:, :, 4:] = torch.sigmoid(outputs[:, :, 4:])
49 | for h, w in hw:
50 | #---------------------------#
51 | # generate grid points from the feature-map height and width
52 | #---------------------------#
53 | grid_y, grid_x = torch.meshgrid([torch.arange(h), torch.arange(w)])
54 | #---------------------------#
55 | # 1, 6400, 2
56 | # 1, 1600, 2
57 | # 1, 400, 2
58 | #---------------------------#
59 | grid = torch.stack((grid_x, grid_y), 2).view(1, -1, 2)
60 | shape = grid.shape[:2]
61 |
62 | grids.append(grid)
63 | strides.append(torch.full((shape[0], shape[1], 1), input_shape[0] / h))
64 | #---------------------------#
65 | # stack the grid points together
66 | # 1, 6400, 2
67 | # 1, 1600, 2
68 | # 1, 400, 2
69 | #
70 | # 1, 8400, 2
71 | #---------------------------#
72 | grids = torch.cat(grids, dim=1).type(outputs.type())
73 | strides = torch.cat(strides, dim=1).type(outputs.type())
74 | #------------------------#
75 | # decode the boxes using the grid points
76 | #------------------------#
77 | outputs[..., :2] = (outputs[..., :2] + grids) * strides
78 | outputs[..., 2:4] = torch.exp(outputs[..., 2:4]) * strides
79 | #-----------------#
80 | # normalize
81 | #-----------------#
82 | outputs[..., [0,2]] = outputs[..., [0,2]] / input_shape[1]
83 | outputs[..., [1,3]] = outputs[..., [1,3]] / input_shape[0]
84 | return outputs
85 |
86 | def non_max_suppression(prediction, num_classes, input_shape, image_shape, letterbox_image, conf_thres=0.5, nms_thres=0.4):
87 | #----------------------------------------------------------#
88 | # convert the predictions to top-left / bottom-right corner format
89 | # prediction [batch_size, num_anchors, 85]
90 | #----------------------------------------------------------#
91 | box_corner = prediction.new(prediction.shape)
92 | box_corner[:, :, 0] = prediction[:, :, 0] - prediction[:, :, 2] / 2
93 | box_corner[:, :, 1] = prediction[:, :, 1] - prediction[:, :, 3] / 2
94 | box_corner[:, :, 2] = prediction[:, :, 0] + prediction[:, :, 2] / 2
95 | box_corner[:, :, 3] = prediction[:, :, 1] + prediction[:, :, 3] / 2
96 | prediction[:, :, :4] = box_corner[:, :, :4]
97 |
98 | output = [None for _ in range(len(prediction))]
99 | #----------------------------------------------------------#
100 | # loop over the input images; usually this runs only once
101 | #----------------------------------------------------------#
102 | for i, image_pred in enumerate(prediction):
103 |
104 | #----------------------------------------------------------#
105 | # take the max over the class predictions
106 | # class_conf [num_anchors, 1] class confidence
107 | # class_pred [num_anchors, 1] class index
108 | #----------------------------------------------------------#
109 | class_conf, class_pred = torch.max(image_pred[:, 5:5 + num_classes], 1, keepdim=True)
113 | #----------------------------------------------------------#
114 | # First round of filtering using the confidence
115 | #----------------------------------------------------------#
116 | conf_mask = (image_pred[:, 4] * class_conf[:, 0] >= conf_thres).squeeze()
117 | if not image_pred.size(0):
118 | continue
119 | #-------------------------------------------------------------------------#
120 | # detections [num_anchors, 7]
121 | # The 7 values are: x1, y1, x2, y2, obj_conf, class_conf, class_pred
122 | #-------------------------------------------------------------------------#
123 | detections = torch.cat((image_pred[:, :5], class_conf, class_pred.float()), 1)
124 | detections = detections[conf_mask]
125 | nms_out_index = boxes.batched_nms(
126 | detections[:, :4],
127 | detections[:, 4] * detections[:, 5],
128 | detections[:, 6],
129 | nms_thres,
130 | )
131 |
132 | output[i] = detections[nms_out_index]
133 |
134 | # #------------------------------------------#
135 | # # Get all the classes present in the predictions
136 | # #------------------------------------------#
137 | # unique_labels = detections[:, -1].cpu().unique()
138 |
139 | # if prediction.is_cuda:
140 | # unique_labels = unique_labels.cuda()
141 | # detections = detections.cuda()
142 |
143 | # for c in unique_labels:
144 | # #------------------------------------------#
145 | # # Get all predictions of this class after the score filtering
146 | # #------------------------------------------#
147 | # detections_class = detections[detections[:, -1] == c]
148 |
149 | # #------------------------------------------#
150 | # # The official built-in non-maximum suppression is faster!
151 | # #------------------------------------------#
152 | # keep = nms(
153 | # detections_class[:, :4],
154 | # detections_class[:, 4] * detections_class[:, 5],
155 | # nms_thres
156 | # )
157 | # max_detections = detections_class[keep]
158 |
159 | # # # Sort by objectness confidence
160 | # # _, conf_sort_index = torch.sort(detections_class[:, 4]*detections_class[:, 5], descending=True)
161 | # # detections_class = detections_class[conf_sort_index]
162 | # # # Perform non-maximum suppression
163 | # # max_detections = []
164 | # # while detections_class.size(0):
165 | # # # Take the box with the highest confidence of this class, then walk down the rest and discard any box whose overlap exceeds nms_thres
166 | # # max_detections.append(detections_class[0].unsqueeze(0))
167 | # # if len(detections_class) == 1:
168 | # # break
169 | # # ious = bbox_iou(max_detections[-1], detections_class[1:])
170 | # # detections_class = detections_class[1:][ious < nms_thres]
171 | # # # Stack the kept detections
172 | # # max_detections = torch.cat(max_detections).data
173 |
174 | # # Add max detections to outputs
175 | # output[i] = max_detections if output[i] is None else torch.cat((output[i], max_detections))
176 |
177 | if output[i] is not None:
178 | output[i] = output[i].cpu().numpy()
179 | box_xy, box_wh = (output[i][:, 0:2] + output[i][:, 2:4])/2, output[i][:, 2:4] - output[i][:, 0:2]
180 | output[i][:, :4] = yolo_correct_boxes(box_xy, box_wh, input_shape, image_shape, letterbox_image)
181 | return output
182 |
--------------------------------------------------------------------------------
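
The two helpers above form the post-processing stage of the YOLOX prompt detector: decode_outputs maps the three raw head outputs onto one [batch, 8400, 4 + 1 + num_classes] tensor of decoded boxes, and non_max_suppression keeps only the confident, non-overlapping detections and rescales them to the original image. The sketch below is not part of the repository; the import path, the low conf_thres, and the random tensors standing in for real head outputs are all illustrative assumptions.

import numpy as np
import torch
from utils_bbox import decode_outputs, non_max_suppression  # import path is an assumption

num_classes = 1  # a single "tongue" class
# Dummy head outputs for a 640x640 input, one map per stride (8, 16, 32).
feats = [torch.randn(1, 5 + num_classes, s, s) for s in (80, 40, 20)]

decoded = decode_outputs(feats, input_shape=[640, 640])
print(decoded.shape)  # torch.Size([1, 8400, 6]): cx, cy, w, h, objectness, class score

results = non_max_suppression(decoded, num_classes,
                              input_shape=np.array([640, 640]),
                              image_shape=np.array([480, 640]),  # original image (height, width)
                              letterbox_image=True,
                              conf_thres=0.05, nms_thres=0.4)
print(None if results[0] is None else results[0].shape)  # (kept_boxes, 7) for the single image
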
/segment/utils_yolox/utils_fit.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import torch
4 | from tqdm import tqdm
5 |
6 | from utils.utils import get_lr
7 |
8 |
9 | def fit_one_epoch(model_train, model, ema, yolo_loss, loss_history, eval_callback, optimizer, epoch, epoch_step, epoch_step_val, gen, gen_val, Epoch, cuda, fp16, scaler, save_period, save_dir, local_rank=0):
10 | loss = 0
11 | val_loss = 0
12 |
13 | if local_rank == 0:
14 | print('Start Train')
15 | pbar = tqdm(total=epoch_step,desc=f'Epoch {epoch + 1}/{Epoch}',postfix=dict,mininterval=0.3)
16 | model_train.train()
17 | for iteration, batch in enumerate(gen):
18 | if iteration >= epoch_step:
19 | break
20 |
21 | images, targets = batch[0], batch[1]
22 | with torch.no_grad():
23 | if cuda:
24 | images = images.cuda(local_rank)
25 | targets = [ann.cuda(local_rank) for ann in targets]
26 | #----------------------#
27 | # Zero the gradients
28 | #----------------------#
29 | optimizer.zero_grad()
30 | if not fp16:
31 | #----------------------#
32 | # Forward pass
33 | #----------------------#
34 | outputs = model_train(images)
35 |
36 | #----------------------#
37 | # Compute the loss
38 | #----------------------#
39 | loss_value = yolo_loss(outputs, targets)
40 |
41 | #----------------------#
42 | # Backward pass
43 | #----------------------#
44 | loss_value.backward()
45 | optimizer.step()
46 | else:
47 | from torch.cuda.amp import autocast
48 | with autocast():
49 | outputs = model_train(images)
50 | #----------------------#
51 | # Compute the loss
52 | #----------------------#
53 | loss_value = yolo_loss(outputs, targets)
54 |
55 | #----------------------#
56 | # Backward pass
57 | #----------------------#
58 | scaler.scale(loss_value).backward()
59 | scaler.step(optimizer)
60 | scaler.update()
61 | if ema:
62 | ema.update(model_train)
63 |
64 | loss += loss_value.item()
65 |
66 | if local_rank == 0:
67 | pbar.set_postfix(**{'loss' : loss / (iteration + 1),
68 | 'lr' : get_lr(optimizer)})
69 | pbar.update(1)
70 |
71 | if local_rank == 0:
72 | pbar.close()
73 | print('Finish Train')
74 | print('Start Validation')
75 | pbar = tqdm(total=epoch_step_val, desc=f'Epoch {epoch + 1}/{Epoch}',postfix=dict,mininterval=0.3)
76 |
77 | if ema:
78 | model_train_eval = ema.ema
79 | else:
80 | model_train_eval = model_train.eval()
81 |
82 | for iteration, batch in enumerate(gen_val):
83 | if iteration >= epoch_step_val:
84 | break
85 | images, targets = batch[0], batch[1]
86 | with torch.no_grad():
87 | if cuda:
88 | images = images.cuda(local_rank)
89 | targets = [ann.cuda(local_rank) for ann in targets]
90 | #----------------------#
91 | # Zero the gradients
92 | #----------------------#
93 | optimizer.zero_grad()
94 | #----------------------#
95 | # Forward pass
96 | #----------------------#
97 | outputs = model_train_eval(images)
98 |
99 | #----------------------#
100 | # Compute the loss
101 | #----------------------#
102 | loss_value = yolo_loss(outputs, targets)
103 |
104 | val_loss += loss_value.item()
105 | if local_rank == 0:
106 | pbar.set_postfix(**{'val_loss': val_loss / (iteration + 1)})
107 | pbar.update(1)
108 |
109 | if local_rank == 0:
110 | pbar.close()
111 | print('Finish Validation')
112 | loss_history.append_loss(epoch + 1, loss / epoch_step, val_loss / epoch_step_val)
113 | eval_callback.on_epoch_end(epoch + 1, model_train_eval)
114 | print('Epoch:'+ str(epoch + 1) + '/' + str(Epoch))
115 | print('Total Loss: %.3f || Val Loss: %.3f ' % (loss / epoch_step, val_loss / epoch_step_val))
116 |
117 | #-----------------------------------------------#
118 | # Save the weights
119 | #-----------------------------------------------#
120 | if ema:
121 | save_state_dict = ema.ema.state_dict()
122 | else:
123 | save_state_dict = model.state_dict()
124 |
125 | if (epoch + 1) % save_period == 0 or epoch + 1 == Epoch:
126 | torch.save(save_state_dict, os.path.join(save_dir, "ep%03d-loss%.3f-val_loss%.3f.pth" % (epoch + 1, loss / epoch_step, val_loss / epoch_step_val)))
127 |
128 | if len(loss_history.val_loss) <= 1 or (val_loss / epoch_step_val) <= min(loss_history.val_loss):
129 | print('Save best model to best_epoch_weights.pth')
130 | torch.save(save_state_dict, os.path.join(save_dir, "best_epoch_weights.pth"))
131 |
132 | torch.save(save_state_dict, os.path.join(save_dir, "last_epoch_weights.pth"))
--------------------------------------------------------------------------------
/segment/yolox.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cshan-github/TongueSAM/3f6e8c620e4d89e669a92a22f3be3d007932ce45/segment/yolox.pth
--------------------------------------------------------------------------------
/segment/yolox_nets/__init__.py:
--------------------------------------------------------------------------------
1 | #
--------------------------------------------------------------------------------
/segment/yolox_nets/__pycache__/__init__.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cshan-github/TongueSAM/3f6e8c620e4d89e669a92a22f3be3d007932ce45/segment/yolox_nets/__pycache__/__init__.cpython-310.pyc
--------------------------------------------------------------------------------
/segment/yolox_nets/__pycache__/__init__.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cshan-github/TongueSAM/3f6e8c620e4d89e669a92a22f3be3d007932ce45/segment/yolox_nets/__pycache__/__init__.cpython-39.pyc
--------------------------------------------------------------------------------
/segment/yolox_nets/__pycache__/darknet.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cshan-github/TongueSAM/3f6e8c620e4d89e669a92a22f3be3d007932ce45/segment/yolox_nets/__pycache__/darknet.cpython-310.pyc
--------------------------------------------------------------------------------
/segment/yolox_nets/__pycache__/darknet.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cshan-github/TongueSAM/3f6e8c620e4d89e669a92a22f3be3d007932ce45/segment/yolox_nets/__pycache__/darknet.cpython-39.pyc
--------------------------------------------------------------------------------
/segment/yolox_nets/__pycache__/yolo.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cshan-github/TongueSAM/3f6e8c620e4d89e669a92a22f3be3d007932ce45/segment/yolox_nets/__pycache__/yolo.cpython-310.pyc
--------------------------------------------------------------------------------
/segment/yolox_nets/__pycache__/yolo.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cshan-github/TongueSAM/3f6e8c620e4d89e669a92a22f3be3d007932ce45/segment/yolox_nets/__pycache__/yolo.cpython-39.pyc
--------------------------------------------------------------------------------
/segment/yolox_nets/__pycache__/yolo_training.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cshan-github/TongueSAM/3f6e8c620e4d89e669a92a22f3be3d007932ce45/segment/yolox_nets/__pycache__/yolo_training.cpython-310.pyc
--------------------------------------------------------------------------------
/segment/yolox_nets/__pycache__/yolo_training.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cshan-github/TongueSAM/3f6e8c620e4d89e669a92a22f3be3d007932ce45/segment/yolox_nets/__pycache__/yolo_training.cpython-39.pyc
--------------------------------------------------------------------------------
/segment/yolox_nets/darknet.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding:utf-8 -*-
3 | # Copyright (c) Megvii, Inc. and its affiliates.
4 |
5 | import torch
6 | from torch import nn
7 |
8 | class SiLU(nn.Module):
9 | @staticmethod
10 | def forward(x):
11 | return x * torch.sigmoid(x)
12 |
13 | def get_activation(name="silu", inplace=True):
14 | if name == "silu":
15 | module = SiLU()
16 | elif name == "relu":
17 | module = nn.ReLU(inplace=inplace)
18 | elif name == "lrelu":
19 | module = nn.LeakyReLU(0.1, inplace=inplace)
20 | else:
21 | raise AttributeError("Unsupported act type: {}".format(name))
22 | return module
23 |
24 | class Focus(nn.Module):
25 | def __init__(self, in_channels, out_channels, ksize=1, stride=1, act="silu"):
26 | super().__init__()
27 | self.conv = BaseConv(in_channels * 4, out_channels, ksize, stride, act=act)
28 |
29 | def forward(self, x):
30 | patch_top_left = x[..., ::2, ::2]
31 | patch_bot_left = x[..., 1::2, ::2]
32 | patch_top_right = x[..., ::2, 1::2]
33 | patch_bot_right = x[..., 1::2, 1::2]
34 | x = torch.cat((patch_top_left, patch_bot_left, patch_top_right, patch_bot_right,), dim=1,)
35 | return self.conv(x)
36 |
37 | class BaseConv(nn.Module):
38 | def __init__(self, in_channels, out_channels, ksize, stride, groups=1, bias=False, act="silu"):
39 | super().__init__()
40 | pad = (ksize - 1) // 2
41 | self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=ksize, stride=stride, padding=pad, groups=groups, bias=bias)
42 | self.bn = nn.BatchNorm2d(out_channels, eps=0.001, momentum=0.03)
43 | self.act = get_activation(act, inplace=True)
44 |
45 | def forward(self, x):
46 | return self.act(self.bn(self.conv(x)))
47 |
48 | def fuseforward(self, x):
49 | return self.act(self.conv(x))
50 |
51 | class DWConv(nn.Module):
52 | def __init__(self, in_channels, out_channels, ksize, stride=1, act="silu"):
53 | super().__init__()
54 | self.dconv = BaseConv(in_channels, in_channels, ksize=ksize, stride=stride, groups=in_channels, act=act,)
55 | self.pconv = BaseConv(in_channels, out_channels, ksize=1, stride=1, groups=1, act=act)
56 |
57 | def forward(self, x):
58 | x = self.dconv(x)
59 | return self.pconv(x)
60 |
61 | class SPPBottleneck(nn.Module):
62 | def __init__(self, in_channels, out_channels, kernel_sizes=(5, 9, 13), activation="silu"):
63 | super().__init__()
64 | hidden_channels = in_channels // 2
65 | self.conv1 = BaseConv(in_channels, hidden_channels, 1, stride=1, act=activation)
66 | self.m = nn.ModuleList([nn.MaxPool2d(kernel_size=ks, stride=1, padding=ks // 2) for ks in kernel_sizes])
67 | conv2_channels = hidden_channels * (len(kernel_sizes) + 1)
68 | self.conv2 = BaseConv(conv2_channels, out_channels, 1, stride=1, act=activation)
69 |
70 | def forward(self, x):
71 | x = self.conv1(x)
72 | x = torch.cat([x] + [m(x) for m in self.m], dim=1)
73 | x = self.conv2(x)
74 | return x
75 |
76 | #--------------------------------------------------#
77 | # Residual structure (the small residual block)
78 | #--------------------------------------------------#
79 | class Bottleneck(nn.Module):
80 | # Standard bottleneck
81 | def __init__(self, in_channels, out_channels, shortcut=True, expansion=0.5, depthwise=False, act="silu",):
82 | super().__init__()
83 | hidden_channels = int(out_channels * expansion)
84 | Conv = DWConv if depthwise else BaseConv
85 | #--------------------------------------------------#
86 | # Use a 1x1 convolution to reduce the number of channels, usually by 50%
87 | #--------------------------------------------------#
88 | self.conv1 = BaseConv(in_channels, hidden_channels, 1, stride=1, act=act)
89 | #--------------------------------------------------#
90 | # Use a 3x3 convolution to expand the channels back and perform feature extraction
91 | #--------------------------------------------------#
92 | self.conv2 = Conv(hidden_channels, out_channels, 3, stride=1, act=act)
93 | self.use_add = shortcut and in_channels == out_channels
94 |
95 | def forward(self, x):
96 | y = self.conv2(self.conv1(x))
97 | if self.use_add:
98 | y = y + x
99 | return y
100 |
101 | class CSPLayer(nn.Module):
102 | def __init__(self, in_channels, out_channels, n=1, shortcut=True, expansion=0.5, depthwise=False, act="silu",):
103 | # ch_in, ch_out, number, shortcut, groups, expansion
104 | super().__init__()
105 | hidden_channels = int(out_channels * expansion)
106 | #--------------------------------------------------#
107 | # Initial convolution of the main branch
108 | #--------------------------------------------------#
109 | self.conv1 = BaseConv(in_channels, hidden_channels, 1, stride=1, act=act)
110 | #--------------------------------------------------#
111 | # Initial convolution of the large residual branch
112 | #--------------------------------------------------#
113 | self.conv2 = BaseConv(in_channels, hidden_channels, 1, stride=1, act=act)
114 | #-----------------------------------------------#
115 | # Convolve the concatenated result
116 | #-----------------------------------------------#
117 | self.conv3 = BaseConv(2 * hidden_channels, out_channels, 1, stride=1, act=act)
118 |
119 | #--------------------------------------------------#
120 | # Stack n of the Bottleneck residual blocks defined above
121 | #--------------------------------------------------#
122 | module_list = [Bottleneck(hidden_channels, hidden_channels, shortcut, 1.0, depthwise, act=act) for _ in range(n)]
123 | self.m = nn.Sequential(*module_list)
124 |
125 | def forward(self, x):
126 | #-------------------------------#
127 | # x_1 is the main branch
128 | #-------------------------------#
129 | x_1 = self.conv1(x)
130 | #-------------------------------#
131 | # x_2 is the large residual branch
132 | #-------------------------------#
133 | x_2 = self.conv2(x)
134 |
135 | #-----------------------------------------------#
136 | # The main branch continues feature extraction through the stacked residual blocks
137 | #-----------------------------------------------#
138 | x_1 = self.m(x_1)
139 | #-----------------------------------------------#
140 | # Concatenate the main branch and the large residual branch
141 | #-----------------------------------------------#
142 | x = torch.cat((x_1, x_2), dim=1)
143 | #-----------------------------------------------#
144 | # Convolve the concatenated result
145 | #-----------------------------------------------#
146 | return self.conv3(x)
147 |
148 | class CSPDarknet(nn.Module):
149 | def __init__(self, dep_mul, wid_mul, out_features=("dark3", "dark4", "dark5"), depthwise=False, act="silu",):
150 | super().__init__()
151 | assert out_features, "please provide output features of Darknet"
152 | self.out_features = out_features
153 | Conv = DWConv if depthwise else BaseConv
154 |
155 | #-----------------------------------------------#
156 | # The input image is 640, 640, 3
157 | # The initial base channel count is 64
158 | #-----------------------------------------------#
159 | base_channels = int(wid_mul * 64) # 64
160 | base_depth = max(round(dep_mul * 3), 1) # 3
161 |
162 | #-----------------------------------------------#
163 | # Use the Focus structure for feature extraction
164 | # 640, 640, 3 -> 320, 320, 12 -> 320, 320, 64
165 | #-----------------------------------------------#
166 | self.stem = Focus(3, base_channels, ksize=3, act=act)
167 |
168 | #-----------------------------------------------#
169 | # After the convolution: 320, 320, 64 -> 160, 160, 128
170 | # After the CSPLayer: 160, 160, 128 -> 160, 160, 128
171 | #-----------------------------------------------#
172 | self.dark2 = nn.Sequential(
173 | Conv(base_channels, base_channels * 2, 3, 2, act=act),
174 | CSPLayer(base_channels * 2, base_channels * 2, n=base_depth, depthwise=depthwise, act=act),
175 | )
176 |
177 | #-----------------------------------------------#
178 | # After the convolution: 160, 160, 128 -> 80, 80, 256
179 | # After the CSPLayer: 80, 80, 256 -> 80, 80, 256
180 | #-----------------------------------------------#
181 | self.dark3 = nn.Sequential(
182 | Conv(base_channels * 2, base_channels * 4, 3, 2, act=act),
183 | CSPLayer(base_channels * 4, base_channels * 4, n=base_depth * 3, depthwise=depthwise, act=act),
184 | )
185 |
186 | #-----------------------------------------------#
187 | # After the convolution: 80, 80, 256 -> 40, 40, 512
188 | # After the CSPLayer: 40, 40, 512 -> 40, 40, 512
189 | #-----------------------------------------------#
190 | self.dark4 = nn.Sequential(
191 | Conv(base_channels * 4, base_channels * 8, 3, 2, act=act),
192 | CSPLayer(base_channels * 8, base_channels * 8, n=base_depth * 3, depthwise=depthwise, act=act),
193 | )
194 |
195 | #-----------------------------------------------#
196 | # After the convolution: 40, 40, 512 -> 20, 20, 1024
197 | # After the SPP: 20, 20, 1024 -> 20, 20, 1024
198 | # After the CSPLayer: 20, 20, 1024 -> 20, 20, 1024
199 | #-----------------------------------------------#
200 | self.dark5 = nn.Sequential(
201 | Conv(base_channels * 8, base_channels * 16, 3, 2, act=act),
202 | SPPBottleneck(base_channels * 16, base_channels * 16, activation=act),
203 | CSPLayer(base_channels * 16, base_channels * 16, n=base_depth, shortcut=False, depthwise=depthwise, act=act),
204 | )
205 |
206 | def forward(self, x):
207 | outputs = {}
208 | x = self.stem(x)
209 | outputs["stem"] = x
210 | x = self.dark2(x)
211 | outputs["dark2"] = x
212 | #-----------------------------------------------#
213 | # The output of dark3 is 80, 80, 256; it is one of the effective feature layers
214 | #-----------------------------------------------#
215 | x = self.dark3(x)
216 |
217 | outputs["dark3"] = x
218 | #-----------------------------------------------#
219 | # The output of dark4 is 40, 40, 512; it is one of the effective feature layers
220 | #-----------------------------------------------#
221 | x = self.dark4(x)
222 |
223 | outputs["dark4"] = x
224 | #-----------------------------------------------#
225 | # The output of dark5 is 20, 20, 1024; it is one of the effective feature layers
226 | #-----------------------------------------------#
227 | x = self.dark5(x)
228 |
229 | outputs["dark5"] = x
230 | return {k: v for k, v in outputs.items() if k in self.out_features}
231 |
232 |
233 | if __name__ == '__main__':
234 | print(CSPDarknet(1, 1))
--------------------------------------------------------------------------------
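
A quick shape check of the backbone above (a sketch, not part of the repository; the import path and the YOLOX-s depth/width multipliers are assumptions): the three returned feature maps are the effective feature layers consumed by the PAFPN in yolo.py.

import torch
from darknet import CSPDarknet  # import path is an assumption

backbone = CSPDarknet(dep_mul=0.33, wid_mul=0.50)  # YOLOX-s sized backbone
x = torch.randn(1, 3, 640, 640)
with torch.no_grad():
    feats = backbone(x)
for name, feat in feats.items():
    print(name, tuple(feat.shape))
# dark3 (1, 128, 80, 80), dark4 (1, 256, 40, 40), dark5 (1, 512, 20, 20)
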
/segment/yolox_nets/yolo.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding:utf-8 -*-
3 | # Copyright (c) Megvii, Inc. and its affiliates.
4 |
5 | import torch
6 | import torch.nn as nn
7 |
8 | from .darknet import BaseConv, CSPDarknet, CSPLayer, DWConv
9 |
10 |
11 | class YOLOXHead(nn.Module):
12 | def __init__(self, num_classes, width = 1.0, in_channels = [256, 512, 1024], act = "silu", depthwise = False,):
13 | super().__init__()
14 | Conv = DWConv if depthwise else BaseConv
15 |
16 | self.cls_convs = nn.ModuleList()
17 | self.reg_convs = nn.ModuleList()
18 | self.cls_preds = nn.ModuleList()
19 | self.reg_preds = nn.ModuleList()
20 | self.obj_preds = nn.ModuleList()
21 | self.stems = nn.ModuleList()
22 |
23 | for i in range(len(in_channels)):
24 | self.stems.append(BaseConv(in_channels = int(in_channels[i] * width), out_channels = int(256 * width), ksize = 1, stride = 1, act = act))
25 | self.cls_convs.append(nn.Sequential(*[
26 | Conv(in_channels = int(256 * width), out_channels = int(256 * width), ksize = 3, stride = 1, act = act),
27 | Conv(in_channels = int(256 * width), out_channels = int(256 * width), ksize = 3, stride = 1, act = act),
28 | ]))
29 | self.cls_preds.append(
30 | nn.Conv2d(in_channels = int(256 * width), out_channels = num_classes, kernel_size = 1, stride = 1, padding = 0)
31 | )
32 |
33 |
34 | self.reg_convs.append(nn.Sequential(*[
35 | Conv(in_channels = int(256 * width), out_channels = int(256 * width), ksize = 3, stride = 1, act = act),
36 | Conv(in_channels = int(256 * width), out_channels = int(256 * width), ksize = 3, stride = 1, act = act)
37 | ]))
38 | self.reg_preds.append(
39 | nn.Conv2d(in_channels = int(256 * width), out_channels = 4, kernel_size = 1, stride = 1, padding = 0)
40 | )
41 | self.obj_preds.append(
42 | nn.Conv2d(in_channels = int(256 * width), out_channels = 1, kernel_size = 1, stride = 1, padding = 0)
43 | )
44 |
45 | def forward(self, inputs):
46 | #---------------------------------------------------#
47 | # inputs:
48 | # P3_out 80, 80, 256
49 | # P4_out 40, 40, 512
50 | # P5_out 20, 20, 1024
51 | #---------------------------------------------------#
52 | outputs = []
53 | for k, x in enumerate(inputs):
54 | #---------------------------------------------------#
55 | # Use a 1x1 convolution for channel integration
56 | #---------------------------------------------------#
57 | x = self.stems[k](x)
58 | #---------------------------------------------------#
59 | # Use two Conv-BN-activation blocks for feature extraction
60 | #---------------------------------------------------#
61 | cls_feat = self.cls_convs[k](x)
62 | #---------------------------------------------------#
63 | # Predict the class of each feature point
64 | # 80, 80, num_classes
65 | # 40, 40, num_classes
66 | # 20, 20, num_classes
67 | #---------------------------------------------------#
68 | cls_output = self.cls_preds[k](cls_feat)
69 |
70 | #---------------------------------------------------#
71 | # Use two Conv-BN-activation blocks for feature extraction
72 | #---------------------------------------------------#
73 | reg_feat = self.reg_convs[k](x)
74 | #---------------------------------------------------#
75 | # Regression coefficients of each feature point
76 | # reg_pred 80, 80, 4
77 | # reg_pred 40, 40, 4
78 | # reg_pred 20, 20, 4
79 | #---------------------------------------------------#
80 | reg_output = self.reg_preds[k](reg_feat)
81 | #---------------------------------------------------#
82 | # Predict whether each feature point corresponds to an object
83 | # obj_pred 80, 80, 1
84 | # obj_pred 40, 40, 1
85 | # obj_pred 20, 20, 1
86 | #---------------------------------------------------#
87 | obj_output = self.obj_preds[k](reg_feat)
88 |
89 | output = torch.cat([reg_output, obj_output, cls_output], 1)
90 | outputs.append(output)
91 | return outputs
92 |
93 | class YOLOPAFPN(nn.Module):
94 | def __init__(self, depth = 1.0, width = 1.0, in_features = ("dark3", "dark4", "dark5"), in_channels = [256, 512, 1024], depthwise = False, act = "silu"):
95 | super().__init__()
96 | Conv = DWConv if depthwise else BaseConv
97 | self.backbone = CSPDarknet(depth, width, depthwise = depthwise, act = act)
98 | self.in_features = in_features
99 |
100 | self.upsample = nn.Upsample(scale_factor=2, mode="nearest")
101 |
102 | #-------------------------------------------#
103 | # 20, 20, 1024 -> 20, 20, 512
104 | #-------------------------------------------#
105 | self.lateral_conv0 = BaseConv(int(in_channels[2] * width), int(in_channels[1] * width), 1, 1, act=act)
106 |
107 | #-------------------------------------------#
108 | # 40, 40, 1024 -> 40, 40, 512
109 | #-------------------------------------------#
110 | self.C3_p4 = CSPLayer(
111 | int(2 * in_channels[1] * width),
112 | int(in_channels[1] * width),
113 | round(3 * depth),
114 | False,
115 | depthwise = depthwise,
116 | act = act,
117 | )
118 |
119 | #-------------------------------------------#
120 | # 40, 40, 512 -> 40, 40, 256
121 | #-------------------------------------------#
122 | self.reduce_conv1 = BaseConv(int(in_channels[1] * width), int(in_channels[0] * width), 1, 1, act=act)
123 | #-------------------------------------------#
124 | # 80, 80, 512 -> 80, 80, 256
125 | #-------------------------------------------#
126 | self.C3_p3 = CSPLayer(
127 | int(2 * in_channels[0] * width),
128 | int(in_channels[0] * width),
129 | round(3 * depth),
130 | False,
131 | depthwise = depthwise,
132 | act = act,
133 | )
134 |
135 | #-------------------------------------------#
136 | # 80, 80, 256 -> 40, 40, 256
137 | #-------------------------------------------#
138 | self.bu_conv2 = Conv(int(in_channels[0] * width), int(in_channels[0] * width), 3, 2, act=act)
139 | #-------------------------------------------#
140 | # 40, 40, 256 -> 40, 40, 512
141 | #-------------------------------------------#
142 | self.C3_n3 = CSPLayer(
143 | int(2 * in_channels[0] * width),
144 | int(in_channels[1] * width),
145 | round(3 * depth),
146 | False,
147 | depthwise = depthwise,
148 | act = act,
149 | )
150 |
151 | #-------------------------------------------#
152 | # 40, 40, 512 -> 20, 20, 512
153 | #-------------------------------------------#
154 | self.bu_conv1 = Conv(int(in_channels[1] * width), int(in_channels[1] * width), 3, 2, act=act)
155 | #-------------------------------------------#
156 | # 20, 20, 1024 -> 20, 20, 1024
157 | #-------------------------------------------#
158 | self.C3_n4 = CSPLayer(
159 | int(2 * in_channels[1] * width),
160 | int(in_channels[2] * width),
161 | round(3 * depth),
162 | False,
163 | depthwise = depthwise,
164 | act = act,
165 | )
166 |
167 | def forward(self, input):
168 | out_features = self.backbone.forward(input)
169 | [feat1, feat2, feat3] = [out_features[f] for f in self.in_features]
170 |
171 | #-------------------------------------------#
172 | # 20, 20, 1024 -> 20, 20, 512
173 | #-------------------------------------------#
174 | P5 = self.lateral_conv0(feat3)
175 | #-------------------------------------------#
176 | # 20, 20, 512 -> 40, 40, 512
177 | #-------------------------------------------#
178 | P5_upsample = self.upsample(P5)
179 | #-------------------------------------------#
180 | # 40, 40, 512 + 40, 40, 512 -> 40, 40, 1024
181 | #-------------------------------------------#
182 | P5_upsample = torch.cat([P5_upsample, feat2], 1)
183 | #-------------------------------------------#
184 | # 40, 40, 1024 -> 40, 40, 512
185 | #-------------------------------------------#
186 | P5_upsample = self.C3_p4(P5_upsample)
187 |
188 | #-------------------------------------------#
189 | # 40, 40, 512 -> 40, 40, 256
190 | #-------------------------------------------#
191 | P4 = self.reduce_conv1(P5_upsample)
192 | #-------------------------------------------#
193 | # 40, 40, 256 -> 80, 80, 256
194 | #-------------------------------------------#
195 | P4_upsample = self.upsample(P4)
196 | #-------------------------------------------#
197 | # 80, 80, 256 + 80, 80, 256 -> 80, 80, 512
198 | #-------------------------------------------#
199 | P4_upsample = torch.cat([P4_upsample, feat1], 1)
200 | #-------------------------------------------#
201 | # 80, 80, 512 -> 80, 80, 256
202 | #-------------------------------------------#
203 | P3_out = self.C3_p3(P4_upsample)
204 |
205 | #-------------------------------------------#
206 | # 80, 80, 256 -> 40, 40, 256
207 | #-------------------------------------------#
208 | P3_downsample = self.bu_conv2(P3_out)
209 | #-------------------------------------------#
210 | # 40, 40, 256 + 40, 40, 256 -> 40, 40, 512
211 | #-------------------------------------------#
212 | P3_downsample = torch.cat([P3_downsample, P4], 1)
213 | #-------------------------------------------#
214 | # 40, 40, 256 -> 40, 40, 512
215 | #-------------------------------------------#
216 | P4_out = self.C3_n3(P3_downsample)
217 |
218 | #-------------------------------------------#
219 | # 40, 40, 512 -> 20, 20, 512
220 | #-------------------------------------------#
221 | P4_downsample = self.bu_conv1(P4_out)
222 | #-------------------------------------------#
223 | # 20, 20, 512 + 20, 20, 512 -> 20, 20, 1024
224 | #-------------------------------------------#
225 | P4_downsample = torch.cat([P4_downsample, P5], 1)
226 | #-------------------------------------------#
227 | # 20, 20, 1024 -> 20, 20, 1024
228 | #-------------------------------------------#
229 | P5_out = self.C3_n4(P4_downsample)
230 |
231 | return (P3_out, P4_out, P5_out)
232 |
233 | class YoloBody(nn.Module):
234 | def __init__(self, num_classes, phi):
235 | super().__init__()
236 | depth_dict = {'nano': 0.33, 'tiny': 0.33, 's' : 0.33, 'm' : 0.67, 'l' : 1.00, 'x' : 1.33,}
237 | width_dict = {'nano': 0.25, 'tiny': 0.375, 's' : 0.50, 'm' : 0.75, 'l' : 1.00, 'x' : 1.25,}
238 | depth, width = depth_dict[phi], width_dict[phi]
239 | depthwise = True if phi == 'nano' else False
240 |
241 | self.backbone = YOLOPAFPN(depth, width, depthwise=depthwise)
242 | self.head = YOLOXHead(num_classes, width, depthwise=depthwise)
243 |
244 | def forward(self, x):
245 | fpn_outs = self.backbone.forward(x)
246 | outputs = self.head.forward(fpn_outs)
247 | return outputs
248 |
--------------------------------------------------------------------------------
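
A minimal sanity check of the full detector (a sketch, not part of the repository; the import path is an assumption): with a single tongue class and the 's' configuration, each of the three output maps carries 4 box regressions, 1 objectness score and 1 class score per feature point, which is exactly what decode_outputs in utils_bbox.py expects.

import torch
from yolo import YoloBody  # import path is an assumption

model = YoloBody(num_classes=1, phi='s')
x = torch.randn(1, 3, 640, 640)
with torch.no_grad():
    outputs = model(x)
for out in outputs:
    print(tuple(out.shape))
# (1, 6, 80, 80), (1, 6, 40, 40), (1, 6, 20, 20)
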
/segment_anything/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | from .build_sam import (
8 | build_sam,
9 | build_sam_vit_h,
10 | build_sam_vit_l,
11 | build_sam_vit_b,
12 | sam_model_registry,
13 | )
14 | from .predictor import SamPredictor
15 | from .automatic_mask_generator import SamAutomaticMaskGenerator
16 |
--------------------------------------------------------------------------------
/segment_anything/__pycache__/__init__.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cshan-github/TongueSAM/3f6e8c620e4d89e669a92a22f3be3d007932ce45/segment_anything/__pycache__/__init__.cpython-310.pyc
--------------------------------------------------------------------------------
/segment_anything/__pycache__/__init__.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cshan-github/TongueSAM/3f6e8c620e4d89e669a92a22f3be3d007932ce45/segment_anything/__pycache__/__init__.cpython-39.pyc
--------------------------------------------------------------------------------
/segment_anything/__pycache__/automatic_mask_generator.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cshan-github/TongueSAM/3f6e8c620e4d89e669a92a22f3be3d007932ce45/segment_anything/__pycache__/automatic_mask_generator.cpython-310.pyc
--------------------------------------------------------------------------------
/segment_anything/__pycache__/automatic_mask_generator.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cshan-github/TongueSAM/3f6e8c620e4d89e669a92a22f3be3d007932ce45/segment_anything/__pycache__/automatic_mask_generator.cpython-39.pyc
--------------------------------------------------------------------------------
/segment_anything/__pycache__/build_sam.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cshan-github/TongueSAM/3f6e8c620e4d89e669a92a22f3be3d007932ce45/segment_anything/__pycache__/build_sam.cpython-310.pyc
--------------------------------------------------------------------------------
/segment_anything/__pycache__/build_sam.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cshan-github/TongueSAM/3f6e8c620e4d89e669a92a22f3be3d007932ce45/segment_anything/__pycache__/build_sam.cpython-39.pyc
--------------------------------------------------------------------------------
/segment_anything/__pycache__/predictor.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cshan-github/TongueSAM/3f6e8c620e4d89e669a92a22f3be3d007932ce45/segment_anything/__pycache__/predictor.cpython-310.pyc
--------------------------------------------------------------------------------
/segment_anything/__pycache__/predictor.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cshan-github/TongueSAM/3f6e8c620e4d89e669a92a22f3be3d007932ce45/segment_anything/__pycache__/predictor.cpython-39.pyc
--------------------------------------------------------------------------------
/segment_anything/build_sam.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | from functools import partial
7 | from pathlib import Path
8 | import urllib.request
9 | import torch
10 |
11 | from .modeling import (
12 | ImageEncoderViT,
13 | MaskDecoder,
14 | PromptEncoder,
15 | Sam,
16 | TwoWayTransformer,
17 | )
18 |
19 |
20 | def build_sam_vit_h(checkpoint=None):
21 | return _build_sam(
22 | encoder_embed_dim=1280,
23 | encoder_depth=32,
24 | encoder_num_heads=16,
25 | encoder_global_attn_indexes=[7, 15, 23, 31],
26 | checkpoint=checkpoint,
27 | )
28 |
29 |
30 | build_sam = build_sam_vit_h
31 |
32 |
33 | def build_sam_vit_l(checkpoint=None):
34 | return _build_sam(
35 | encoder_embed_dim=1024,
36 | encoder_depth=24,
37 | encoder_num_heads=16,
38 | encoder_global_attn_indexes=[5, 11, 17, 23],
39 | checkpoint=checkpoint,
40 | )
41 |
42 |
43 | def build_sam_vit_b(checkpoint=None):
44 | return _build_sam(
45 | encoder_embed_dim=768,
46 | encoder_depth=12,
47 | encoder_num_heads=12,
48 | encoder_global_attn_indexes=[2, 5, 8, 11],
49 | checkpoint=checkpoint,
50 | )
51 |
52 |
53 | sam_model_registry = {
54 | "default": build_sam_vit_h,
55 | "vit_h": build_sam_vit_h,
56 | "vit_l": build_sam_vit_l,
57 | "vit_b": build_sam_vit_b,
58 | }
59 |
60 |
61 |
62 |
63 | def _build_sam(
64 | encoder_embed_dim,
65 | encoder_depth,
66 | encoder_num_heads,
67 | encoder_global_attn_indexes,
68 | checkpoint=None,
69 | ):
70 | prompt_embed_dim = 256
71 | image_size = 1024
72 | vit_patch_size = 16
73 | image_embedding_size = image_size // vit_patch_size
74 | sam = Sam(
75 | image_encoder=ImageEncoderViT(
76 | depth=encoder_depth,
77 | embed_dim=encoder_embed_dim,
78 | img_size=image_size,
79 | mlp_ratio=4,
80 | norm_layer=partial(torch.nn.LayerNorm, eps=1e-6),
81 | num_heads=encoder_num_heads,
82 | patch_size=vit_patch_size,
83 | qkv_bias=True,
84 | use_rel_pos=True,
85 | global_attn_indexes=encoder_global_attn_indexes,
86 | window_size=14,
87 | out_chans=prompt_embed_dim,
88 | ),
89 | prompt_encoder=PromptEncoder(
90 | embed_dim=prompt_embed_dim,
91 | image_embedding_size=(image_embedding_size, image_embedding_size),
92 | input_image_size=(image_size, image_size),
93 | mask_in_chans=16,
94 | ),
95 | mask_decoder=MaskDecoder(
96 | num_multimask_outputs=3,
97 | transformer=TwoWayTransformer(
98 | depth=2,
99 | embedding_dim=prompt_embed_dim,
100 | mlp_dim=2048,
101 | num_heads=8,
102 | ),
103 | transformer_dim=prompt_embed_dim,
104 | iou_head_depth=3,
105 | iou_head_hidden_dim=256,
106 | ),
107 | pixel_mean=[123.675, 116.28, 103.53],
108 | pixel_std=[58.395, 57.12, 57.375],
109 | )
110 | sam.eval()
111 | checkpoint = Path(checkpoint)
112 | if checkpoint.name == "sam_vit_b_01ec64.pth" and not checkpoint.exists():
113 | cmd = input("Download sam_vit_b_01ec64.pth from facebook AI? [y]/n: ")
114 | if len(cmd) == 0 or cmd.lower() == 'y':
115 | checkpoint.parent.mkdir(parents=True, exist_ok=True)
116 | print("Downloading SAM ViT-B checkpoint...")
117 | urllib.request.urlretrieve(
118 | "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth",
119 | checkpoint,
120 | )
121 | print(checkpoint.name, " is downloaded!")
122 | elif checkpoint.name == "sam_vit_h_4b8939.pth" and not checkpoint.exists():
123 | cmd = input("Download sam_vit_h_4b8939.pth from facebook AI? [y]/n: ")
124 | if len(cmd) == 0 or cmd.lower() == 'y':
125 | checkpoint.parent.mkdir(parents=True, exist_ok=True)
126 | print("Downloading SAM ViT-H checkpoint...")
127 | urllib.request.urlretrieve(
128 | "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth",
129 | checkpoint,
130 | )
131 | print(checkpoint.name, " is downloaded!")
132 | elif checkpoint.name == "sam_vit_l_0b3195.pth" and not checkpoint.exists():
133 | cmd = input("Download sam_vit_l_0b3195.pth from facebook AI? [y]/n: ")
134 | if len(cmd) == 0 or cmd.lower() == 'y':
135 | checkpoint.parent.mkdir(parents=True, exist_ok=True)
136 | print("Downloading SAM ViT-L checkpoint...")
137 | urllib.request.urlretrieve(
138 | "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_l_0b3195.pth",
139 | checkpoint,
140 | )
141 | print(checkpoint.name, " is downloaded!")
142 |
143 |
144 | if checkpoint is not None:
145 | with open(checkpoint, "rb") as f:
146 | state_dict = torch.load(f)
147 | sam.load_state_dict(state_dict)
148 | return sam
149 |
--------------------------------------------------------------------------------
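
A build sketch for the SAM backbone (not part of the repository): the checkpoint filename and location are illustrative assumptions; if the file is missing, _build_sam above offers to download the official weights.

import torch
from segment_anything import sam_model_registry

sam = sam_model_registry["vit_b"](checkpoint="pretrained_model/sam_vit_b_01ec64.pth")
sam.to("cuda" if torch.cuda.is_available() else "cpu")

# The three sub-modules assembled by _build_sam:
print(type(sam.image_encoder).__name__)         # ImageEncoderViT
print(type(sam.prompt_encoder).__name__)        # PromptEncoder
print(type(sam.mask_decoder).__name__)          # MaskDecoder
print(sam.prompt_encoder.image_embedding_size)  # (64, 64): 1024 input / 16x16 patches
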
/segment_anything/modeling/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | from .sam import Sam
8 | from .image_encoder import ImageEncoderViT
9 | from .mask_decoder import MaskDecoder
10 | from .prompt_encoder import PromptEncoder
11 | from .transformer import TwoWayTransformer
12 |
--------------------------------------------------------------------------------
/segment_anything/modeling/__pycache__/__init__.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cshan-github/TongueSAM/3f6e8c620e4d89e669a92a22f3be3d007932ce45/segment_anything/modeling/__pycache__/__init__.cpython-310.pyc
--------------------------------------------------------------------------------
/segment_anything/modeling/__pycache__/__init__.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cshan-github/TongueSAM/3f6e8c620e4d89e669a92a22f3be3d007932ce45/segment_anything/modeling/__pycache__/__init__.cpython-39.pyc
--------------------------------------------------------------------------------
/segment_anything/modeling/__pycache__/common.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cshan-github/TongueSAM/3f6e8c620e4d89e669a92a22f3be3d007932ce45/segment_anything/modeling/__pycache__/common.cpython-310.pyc
--------------------------------------------------------------------------------
/segment_anything/modeling/__pycache__/common.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cshan-github/TongueSAM/3f6e8c620e4d89e669a92a22f3be3d007932ce45/segment_anything/modeling/__pycache__/common.cpython-39.pyc
--------------------------------------------------------------------------------
/segment_anything/modeling/__pycache__/image_encoder.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cshan-github/TongueSAM/3f6e8c620e4d89e669a92a22f3be3d007932ce45/segment_anything/modeling/__pycache__/image_encoder.cpython-310.pyc
--------------------------------------------------------------------------------
/segment_anything/modeling/__pycache__/image_encoder.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cshan-github/TongueSAM/3f6e8c620e4d89e669a92a22f3be3d007932ce45/segment_anything/modeling/__pycache__/image_encoder.cpython-39.pyc
--------------------------------------------------------------------------------
/segment_anything/modeling/__pycache__/mask_decoder.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cshan-github/TongueSAM/3f6e8c620e4d89e669a92a22f3be3d007932ce45/segment_anything/modeling/__pycache__/mask_decoder.cpython-310.pyc
--------------------------------------------------------------------------------
/segment_anything/modeling/__pycache__/mask_decoder.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cshan-github/TongueSAM/3f6e8c620e4d89e669a92a22f3be3d007932ce45/segment_anything/modeling/__pycache__/mask_decoder.cpython-39.pyc
--------------------------------------------------------------------------------
/segment_anything/modeling/__pycache__/prompt_encoder.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cshan-github/TongueSAM/3f6e8c620e4d89e669a92a22f3be3d007932ce45/segment_anything/modeling/__pycache__/prompt_encoder.cpython-310.pyc
--------------------------------------------------------------------------------
/segment_anything/modeling/__pycache__/prompt_encoder.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cshan-github/TongueSAM/3f6e8c620e4d89e669a92a22f3be3d007932ce45/segment_anything/modeling/__pycache__/prompt_encoder.cpython-39.pyc
--------------------------------------------------------------------------------
/segment_anything/modeling/__pycache__/sam.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cshan-github/TongueSAM/3f6e8c620e4d89e669a92a22f3be3d007932ce45/segment_anything/modeling/__pycache__/sam.cpython-310.pyc
--------------------------------------------------------------------------------
/segment_anything/modeling/__pycache__/sam.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cshan-github/TongueSAM/3f6e8c620e4d89e669a92a22f3be3d007932ce45/segment_anything/modeling/__pycache__/sam.cpython-39.pyc
--------------------------------------------------------------------------------
/segment_anything/modeling/__pycache__/transformer.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cshan-github/TongueSAM/3f6e8c620e4d89e669a92a22f3be3d007932ce45/segment_anything/modeling/__pycache__/transformer.cpython-310.pyc
--------------------------------------------------------------------------------
/segment_anything/modeling/__pycache__/transformer.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cshan-github/TongueSAM/3f6e8c620e4d89e669a92a22f3be3d007932ce45/segment_anything/modeling/__pycache__/transformer.cpython-39.pyc
--------------------------------------------------------------------------------
/segment_anything/modeling/common.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | import torch
8 | import torch.nn as nn
9 |
10 | from typing import Type
11 |
12 |
13 | class MLPBlock(nn.Module):
14 | def __init__(
15 | self,
16 | embedding_dim: int,
17 | mlp_dim: int,
18 | act: Type[nn.Module] = nn.GELU,
19 | ) -> None:
20 | super().__init__()
21 | self.lin1 = nn.Linear(embedding_dim, mlp_dim)
22 | self.lin2 = nn.Linear(mlp_dim, embedding_dim)
23 | self.act = act()
24 |
25 | def forward(self, x: torch.Tensor) -> torch.Tensor:
26 | return self.lin2(self.act(self.lin1(x)))
27 |
28 |
29 | # From https://github.com/facebookresearch/detectron2/blob/main/detectron2/layers/batch_norm.py # noqa
30 | # Itself from https://github.com/facebookresearch/ConvNeXt/blob/d1fa8f6fef0a165b27399986cc2bdacc92777e40/models/convnext.py#L119 # noqa
31 | class LayerNorm2d(nn.Module):
32 | def __init__(self, num_channels: int, eps: float = 1e-6) -> None:
33 | super().__init__()
34 | self.weight = nn.Parameter(torch.ones(num_channels))
35 | self.bias = nn.Parameter(torch.zeros(num_channels))
36 | self.eps = eps
37 |
38 | def forward(self, x: torch.Tensor) -> torch.Tensor:
39 | u = x.mean(1, keepdim=True)
40 | s = (x - u).pow(2).mean(1, keepdim=True)
41 | x = (x - u) / torch.sqrt(s + self.eps)
42 | x = self.weight[:, None, None] * x + self.bias[:, None, None]
43 | return x
44 |
--------------------------------------------------------------------------------
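
A small sketch (not part of the repository; it assumes the repository root is on PYTHONPATH) showing that LayerNorm2d normalizes each spatial position of an NCHW tensor over its channels, unlike BatchNorm2d, which normalizes over the batch.

import torch
from segment_anything.modeling.common import LayerNorm2d

x = torch.randn(2, 16, 8, 8)
ln = LayerNorm2d(16)
y = ln(x)
print(y.shape)                    # torch.Size([2, 16, 8, 8])
print(y.mean(dim=1).abs().max())  # close to 0: each (n, h, w) position is normalized over its 16 channels
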
/segment_anything/modeling/mask_decoder.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | import torch
8 | from torch import nn
9 | from torch.nn import functional as F
10 |
11 | from typing import List, Tuple, Type
12 |
13 | from .common import LayerNorm2d
14 |
15 |
16 | class MaskDecoder(nn.Module):
17 | def __init__(
18 | self,
19 | *,
20 | transformer_dim: int,
21 | transformer: nn.Module,
22 | num_multimask_outputs: int = 3,
23 | activation: Type[nn.Module] = nn.GELU,
24 | iou_head_depth: int = 3,
25 | iou_head_hidden_dim: int = 400,
26 | ) -> None:
27 | """
28 | Predicts masks given an image and prompt embeddings, using a
29 | transformer architecture.
30 |
31 | Arguments:
32 | transformer_dim (int): the channel dimension of the transformer
33 | transformer (nn.Module): the transformer used to predict masks
34 | num_multimask_outputs (int): the number of masks to predict
35 | when disambiguating masks
36 | activation (nn.Module): the type of activation to use when
37 | upscaling masks
38 | iou_head_depth (int): the depth of the MLP used to predict
39 | mask quality
40 | iou_head_hidden_dim (int): the hidden dimension of the MLP
41 | used to predict mask quality
42 | """
43 | super().__init__()
44 | self.transformer_dim = transformer_dim
45 | self.transformer = transformer
46 |
47 | self.num_multimask_outputs = num_multimask_outputs
48 |
49 | self.iou_token = nn.Embedding(1, transformer_dim)
50 | self.num_mask_tokens = num_multimask_outputs + 1
51 | self.mask_tokens = nn.Embedding(self.num_mask_tokens, transformer_dim)
52 |
53 | self.output_upscaling = nn.Sequential(
54 | nn.ConvTranspose2d(transformer_dim, transformer_dim // 4, kernel_size=2, stride=2),
55 | LayerNorm2d(transformer_dim // 4),
56 | activation(),
57 | nn.ConvTranspose2d(transformer_dim // 4, transformer_dim // 8, kernel_size=2, stride=2),
58 | activation(),
59 | )
60 | self.output_hypernetworks_mlps = nn.ModuleList(
61 | [
62 | MLP(transformer_dim, transformer_dim, transformer_dim // 8, 3)
63 | for i in range(self.num_mask_tokens)
64 | ]
65 | )
66 |
67 | self.iou_prediction_head = MLP(
68 | transformer_dim, iou_head_hidden_dim, self.num_mask_tokens, iou_head_depth
69 | )
70 |
71 | def forward(
72 | self,
73 | image_embeddings: torch.Tensor,
74 | image_pe: torch.Tensor,
75 | sparse_prompt_embeddings: torch.Tensor,
76 | dense_prompt_embeddings: torch.Tensor,
77 | multimask_output: bool,
78 | ) -> Tuple[torch.Tensor, torch.Tensor]:
79 | """
80 | Predict masks given image and prompt embeddings.
81 |
82 | Arguments:
83 | image_embeddings (torch.Tensor): the embeddings from the image encoder
84 | image_pe (torch.Tensor): positional encoding with the shape of image_embeddings
85 | sparse_prompt_embeddings (torch.Tensor): the embeddings of the points and boxes
86 | dense_prompt_embeddings (torch.Tensor): the embeddings of the mask inputs
87 | multimask_output (bool): Whether to return multiple masks or a single
88 | mask.
89 |
90 | Returns:
91 | torch.Tensor: batched predicted masks
92 | torch.Tensor: batched predictions of mask quality
93 | """
94 | masks, iou_pred = self.predict_masks(
95 | image_embeddings=image_embeddings,
96 | image_pe=image_pe,
97 | sparse_prompt_embeddings=sparse_prompt_embeddings,
98 | dense_prompt_embeddings=dense_prompt_embeddings,
99 | )
100 |
101 | # Select the correct mask or masks for output
102 | if multimask_output:
103 | mask_slice = slice(1, None)
104 | else:
105 | mask_slice = slice(0, 1)
106 | masks = masks[:, mask_slice, :, :]
107 | iou_pred = iou_pred[:, mask_slice]
108 | masks = torch.sigmoid(masks.view(masks.shape[0], 1, -1)).view(masks.shape[0], 1, 256, 256)  # apply sigmoid to the selected 256x256 low-resolution mask logits
109 | return masks, iou_pred
110 |
111 | def predict_masks(
112 | self,
113 | image_embeddings: torch.Tensor,
114 | image_pe: torch.Tensor,
115 | sparse_prompt_embeddings: torch.Tensor,
116 | dense_prompt_embeddings: torch.Tensor,
117 | ) -> Tuple[torch.Tensor, torch.Tensor]:
118 | """Predicts masks. See 'forward' for more details."""
119 | # Concatenate output tokens
120 | output_tokens = torch.cat([self.iou_token.weight, self.mask_tokens.weight], dim=0)
121 | output_tokens = output_tokens.unsqueeze(0).expand(sparse_prompt_embeddings.size(0), -1, -1)
122 | tokens = torch.cat((output_tokens, sparse_prompt_embeddings), dim=1)
123 |
124 | # Expand per-image data in batch direction to be per-mask
125 | if image_embeddings.shape[0] != tokens.shape[0]:
126 | src = torch.repeat_interleave(image_embeddings, tokens.shape[0], dim=0)
127 | else:
128 | src = image_embeddings
129 | src = src + dense_prompt_embeddings
130 | pos_src = torch.repeat_interleave(image_pe, tokens.shape[0], dim=0)
131 | b, c, h, w = src.shape
132 |
133 | # Run the transformer
134 | hs, src = self.transformer(src, pos_src, tokens)
135 | iou_token_out = hs[:, 0, :]
136 | mask_tokens_out = hs[:, 1 : (1 + self.num_mask_tokens), :]
137 |
138 | # Upscale mask embeddings and predict masks using the mask tokens
139 | src = src.transpose(1, 2).view(b, c, h, w)
140 | upscaled_embedding = self.output_upscaling(src)
141 | hyper_in_list: List[torch.Tensor] = []
142 | for i in range(self.num_mask_tokens):
143 | hyper_in_list.append(self.output_hypernetworks_mlps[i](mask_tokens_out[:, i, :]))
144 | hyper_in = torch.stack(hyper_in_list, dim=1)
145 | b, c, h, w = upscaled_embedding.shape
146 | masks = (hyper_in @ upscaled_embedding.view(b, c, h * w)).view(b, -1, h, w)
147 | # Generate mask quality predictions
148 | iou_pred = self.iou_prediction_head(iou_token_out)
149 |
150 | return masks, iou_pred
151 |
152 |
153 | # Lightly adapted from
154 | # https://github.com/facebookresearch/MaskFormer/blob/main/mask_former/modeling/transformer/transformer_predictor.py # noqa
155 | class MLP(nn.Module):
156 | def __init__(
157 | self,
158 | input_dim: int,
159 | hidden_dim: int,
160 | output_dim: int,
161 | num_layers: int,
162 | sigmoid_output: bool = False,
163 | ) -> None:
164 | super().__init__()
165 | self.num_layers = num_layers
166 | h = [hidden_dim] * (num_layers - 1)
167 | self.layers = nn.ModuleList(
168 | nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim])
169 | )
170 | self.sigmoid_output = sigmoid_output
171 |
172 | def forward(self, x):
173 | for i, layer in enumerate(self.layers):
174 | x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
175 | if self.sigmoid_output:
176 | x = F.sigmoid(x)
177 | return x
178 |
--------------------------------------------------------------------------------
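
The small MLP class at the end of the file serves both the per-token hypernetworks (output_dim = transformer_dim // 8) and the IoU prediction head. A minimal shape sketch (not part of the repository; it assumes the repository root is on PYTHONPATH):

import torch
from segment_anything.modeling.mask_decoder import MLP

# transformer_dim is 256 in _build_sam, so each hypernetwork maps a 256-d mask token to 256 // 8 = 32 dims.
hyper_mlp = MLP(input_dim=256, hidden_dim=256, output_dim=32, num_layers=3)
tokens = torch.randn(4, 256)    # e.g. the four mask tokens of one prompt
print(hyper_mlp(tokens).shape)  # torch.Size([4, 32])
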
/segment_anything/modeling/prompt_encoder.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | import numpy as np
8 | import torch
9 | from torch import nn
10 |
11 | from typing import Any, Optional, Tuple, Type
12 |
13 | from .common import LayerNorm2d
14 |
15 |
16 | class PromptEncoder(nn.Module):
17 | def __init__(
18 | self,
19 | embed_dim: int,
20 | image_embedding_size: Tuple[int, int],
21 | input_image_size: Tuple[int, int],
22 | mask_in_chans: int,
23 | activation: Type[nn.Module] = nn.GELU,
24 | ) -> None:
25 | """
26 | Encodes prompts for input to SAM's mask decoder.
27 |
28 | Arguments:
29 | embed_dim (int): The prompts' embedding dimension
30 | image_embedding_size (tuple(int, int)): The spatial size of the
31 | image embedding, as (H, W).
32 | input_image_size (int): The padded size of the image as input
33 | to the image encoder, as (H, W).
34 | mask_in_chans (int): The number of hidden channels used for
35 | encoding input masks.
36 | activation (nn.Module): The activation to use when encoding
37 | input masks.
38 | """
39 | super().__init__()
40 | self.embed_dim = embed_dim
41 | self.input_image_size = input_image_size
42 | self.image_embedding_size = image_embedding_size
43 | self.pe_layer = PositionEmbeddingRandom(embed_dim // 2)
44 |
45 | self.num_point_embeddings: int = 4 # pos/neg point + 2 box corners
46 | point_embeddings = [nn.Embedding(1, embed_dim) for i in range(self.num_point_embeddings)]
47 | self.point_embeddings = nn.ModuleList(point_embeddings)
48 | self.not_a_point_embed = nn.Embedding(1, embed_dim)
49 |
50 | self.mask_input_size = (4 * image_embedding_size[0], 4 * image_embedding_size[1])
51 | self.mask_downscaling = nn.Sequential(
52 | nn.Conv2d(1, mask_in_chans // 4, kernel_size=2, stride=2),
53 | LayerNorm2d(mask_in_chans // 4),
54 | activation(),
55 | nn.Conv2d(mask_in_chans // 4, mask_in_chans, kernel_size=2, stride=2),
56 | LayerNorm2d(mask_in_chans),
57 | activation(),
58 | nn.Conv2d(mask_in_chans, embed_dim, kernel_size=1),
59 | )
60 | self.no_mask_embed = nn.Embedding(1, embed_dim)
61 |
62 | def get_dense_pe(self) -> torch.Tensor:
63 | """
64 | Returns the positional encoding used to encode point prompts,
65 | applied to a dense set of points the shape of the image encoding.
66 |
67 | Returns:
68 | torch.Tensor: Positional encoding with shape
69 | 1x(embed_dim)x(embedding_h)x(embedding_w)
70 | """
71 | return self.pe_layer(self.image_embedding_size).unsqueeze(0)
72 |
73 | def _embed_points(
74 | self,
75 | points: torch.Tensor,
76 | labels: torch.Tensor,
77 | pad: bool,
78 | ) -> torch.Tensor:
79 | """Embeds point prompts."""
80 | points = points + 0.5 # Shift to center of pixel
81 | if pad:
82 | padding_point = torch.zeros((points.shape[0], 1, 2), device=points.device)
83 | padding_label = -torch.ones((labels.shape[0], 1), device=labels.device)
84 | points = torch.cat([points, padding_point], dim=1)
85 | labels = labels.reshape(labels.shape[0], -1)  # flatten labels to BxN before padding
86 | labels = torch.cat([labels, padding_label], dim=1)
87 | point_embedding = self.pe_layer.forward_with_coords(points, self.input_image_size)
88 | point_embedding[labels == -1] = 0.0
89 | point_embedding[labels == -1] += self.not_a_point_embed.weight
90 | point_embedding[labels == 0] += self.point_embeddings[0].weight
91 | point_embedding[labels == 1] += self.point_embeddings[1].weight
92 | return point_embedding
93 |
94 | def _embed_boxes(self, boxes: torch.Tensor) -> torch.Tensor:
95 | """Embeds box prompts."""
96 | boxes = boxes + 0.5 # Shift to center of pixel
97 | coords = boxes.reshape(-1, 2, 2)
98 | corner_embedding = self.pe_layer.forward_with_coords(coords, self.input_image_size)
99 | corner_embedding[:, 0, :] += self.point_embeddings[2].weight
100 | corner_embedding[:, 1, :] += self.point_embeddings[3].weight
101 | return corner_embedding
102 |
103 | def _embed_masks(self, masks: torch.Tensor) -> torch.Tensor:
104 | """Embeds mask inputs."""
105 | mask_embedding = self.mask_downscaling(masks)
106 | return mask_embedding
107 |
108 | def _get_batch_size(
109 | self,
110 | points: Optional[Tuple[torch.Tensor, torch.Tensor]],
111 | boxes: Optional[torch.Tensor],
112 | masks: Optional[torch.Tensor],
113 | ) -> int:
114 | """
115 | Gets the batch size of the output given the batch size of the input prompts.
116 | """
117 | if points is not None:
118 | return points.shape[0]
119 | elif boxes is not None:
120 | return boxes.shape[0]
121 | elif masks is not None:
122 | return masks.shape[0]
123 | else:
124 | return 1
125 |
126 | def _get_device(self) -> torch.device:
127 | return self.point_embeddings[0].weight.device
128 |
129 | def forward(
130 | self,
131 | points: Optional[Tuple[torch.Tensor, torch.Tensor]],
132 | boxes: Optional[torch.Tensor],
133 | masks: Optional[torch.Tensor],
134 | ) -> Tuple[torch.Tensor, torch.Tensor]:
135 | """
136 | Embeds different types of prompts, returning both sparse and dense
137 | embeddings.
138 |
139 | Arguments:
140 | points (tuple(torch.Tensor, torch.Tensor) or none): point coordinates
141 | and labels to embed.
142 | boxes (torch.Tensor or none): boxes to embed
143 | masks (torch.Tensor or none): masks to embed
144 |
145 | Returns:
146 | torch.Tensor: sparse embeddings for the points and boxes, with shape
147 | BxNx(embed_dim), where N is determined by the number of input points
148 | and boxes.
149 | torch.Tensor: dense embeddings for the masks, in the shape
150 | Bx(embed_dim)x(embed_H)x(embed_W)
151 | """
152 | bs = self._get_batch_size(points, boxes, masks)
153 | sparse_embeddings = torch.empty((bs, 0, self.embed_dim), device=self._get_device())
154 | if points is not None:
155 | coords = points.to(self._get_device())
156 | labels = torch.ones([points.shape[0], points.shape[1], 1], device=self._get_device())
157 | point_embeddings = self._embed_points(coords, labels, pad=(boxes is None))
158 | sparse_embeddings = torch.cat([sparse_embeddings, point_embeddings], dim=1)
159 | if boxes is not None:
160 | box_embeddings = self._embed_boxes(boxes)
161 | sparse_embeddings = torch.cat([sparse_embeddings, box_embeddings], dim=1)
162 |
163 | if masks is not None:
164 | dense_embeddings = self._embed_masks(masks)
165 | else:
166 | dense_embeddings = self.no_mask_embed.weight.reshape(1, -1, 1, 1).expand(
167 | bs, -1, self.image_embedding_size[0], self.image_embedding_size[1]
168 | )
169 |
170 | return sparse_embeddings, dense_embeddings
171 |
172 |
173 | class PositionEmbeddingRandom(nn.Module):
174 | """
175 | Positional encoding using random spatial frequencies.
176 | """
177 |
178 | def __init__(self, num_pos_feats: int = 64, scale: Optional[float] = None) -> None:
179 | super().__init__()
180 | if scale is None or scale <= 0.0:
181 | scale = 1.0
182 | self.register_buffer(
183 | "positional_encoding_gaussian_matrix",
184 | scale * torch.randn((2, num_pos_feats)),
185 | )
186 |
187 | def _pe_encoding(self, coords: torch.Tensor) -> torch.Tensor:
188 | """Positionally encode points that are normalized to [0,1]."""
189 | # assuming coords are in [0, 1]^2 square and have d_1 x ... x d_n x 2 shape
190 | coords = 2 * coords - 1
191 | coords = coords @ self.positional_encoding_gaussian_matrix
192 | coords = 2 * np.pi * coords
193 | # outputs d_1 x ... x d_n x C shape
194 | return torch.cat([torch.sin(coords), torch.cos(coords)], dim=-1)
195 |
196 | def forward(self, size: Tuple[int, int]) -> torch.Tensor:
197 | """Generate positional encoding for a grid of the specified size."""
198 | h, w = size
199 | device: Any = self.positional_encoding_gaussian_matrix.device
200 | grid = torch.ones((h, w), device=device, dtype=torch.float32)
201 | y_embed = grid.cumsum(dim=0) - 0.5
202 | x_embed = grid.cumsum(dim=1) - 0.5
203 | y_embed = y_embed / h
204 | x_embed = x_embed / w
205 |
206 | pe = self._pe_encoding(torch.stack([x_embed, y_embed], dim=-1))
207 | return pe.permute(2, 0, 1) # C x H x W
208 |
209 | def forward_with_coords(
210 | self, coords_input: torch.Tensor, image_size: Tuple[int, int]
211 | ) -> torch.Tensor:
212 | """Positionally encode points that are not normalized to [0,1]."""
213 | coords = coords_input.clone()
214 | coords[:, :, 0] = coords[:, :, 0] / image_size[1]
215 | coords[:, :, 1] = coords[:, :, 1] / image_size[0]
216 | return self._pe_encoding(coords.to(torch.float)) # B x N x C
217 |
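To make the prompt-encoding flow concrete, here is a minimal usage sketch for the `PromptEncoder` above. The configuration values (256-dim embeddings, a 64x64 embedding grid, 1024x1024 padded inputs, 16 mask channels) are SAM's standard settings and are assumed here for illustration; the box coordinates are arbitrary.

```python
import torch
from segment_anything.modeling.prompt_encoder import PromptEncoder

encoder = PromptEncoder(
    embed_dim=256,
    image_embedding_size=(64, 64),
    input_image_size=(1024, 1024),
    mask_in_chans=16,
)

# One box prompt per image, as (x_min, y_min, x_max, y_max) in the 1024x1024 input frame.
boxes = torch.tensor([[100.0, 150.0, 600.0, 700.0]])

sparse, dense = encoder(points=None, boxes=boxes, masks=None)
print(sparse.shape)  # torch.Size([1, 2, 256])      - one token per box corner
print(dense.shape)   # torch.Size([1, 256, 64, 64]) - the learned "no mask" embedding
```

Note that `forward` above routes point prompts through the all-ones labels added in this repository, so point behaviour differs from upstream SAM; box and mask prompts follow the original code path.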
--------------------------------------------------------------------------------
/segment_anything/modeling/sam.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | import torch
8 | from torch import nn
9 | from torch.nn import functional as F
10 |
11 | from typing import Any, Dict, List, Tuple
12 |
13 | from .image_encoder import ImageEncoderViT
14 | from .mask_decoder import MaskDecoder
15 | from .prompt_encoder import PromptEncoder
16 |
17 |
18 | class Sam(nn.Module):
19 | mask_threshold: float = 0.0
20 | image_format: str = "RGB"
21 |
22 | def __init__(
23 | self,
24 | image_encoder: ImageEncoderViT,
25 | prompt_encoder: PromptEncoder,
26 | mask_decoder: MaskDecoder,
27 | pixel_mean: List[float] = [123.675, 116.28, 103.53],
28 | pixel_std: List[float] = [58.395, 57.12, 57.375],
29 | ) -> None:
30 | """
31 | SAM predicts object masks from an image and input prompts.
32 |
33 | Arguments:
34 | image_encoder (ImageEncoderViT): The backbone used to encode the
35 | image into image embeddings that allow for efficient mask prediction.
36 | prompt_encoder (PromptEncoder): Encodes various types of input prompts.
37 | mask_decoder (MaskDecoder): Predicts masks from the image embeddings
38 | and encoded prompts.
39 | pixel_mean (list(float)): Mean values for normalizing pixels in the input image.
40 | pixel_std (list(float)): Std values for normalizing pixels in the input image.
41 | """
42 | super().__init__()
43 | self.image_encoder = image_encoder
44 | self.prompt_encoder = prompt_encoder
45 | self.mask_decoder = mask_decoder
46 | self.register_buffer("pixel_mean", torch.Tensor(pixel_mean).view(-1, 1, 1), False)
47 | self.register_buffer("pixel_std", torch.Tensor(pixel_std).view(-1, 1, 1), False)
48 |
49 | @property
50 | def device(self) -> Any:
51 | return self.pixel_mean.device
52 |
53 | @torch.no_grad()
54 | def forward(
55 | self,
56 | batched_input: List[Dict[str, Any]],
57 | multimask_output: bool,
58 | ) -> List[Dict[str, torch.Tensor]]:
59 | """
60 | Predicts masks end-to-end from provided images and prompts.
61 | If prompts are not known in advance, using SamPredictor is
62 | recommended over calling the model directly.
63 |
64 | Arguments:
65 | batched_input (list(dict)): A list over input images, each a
66 | dictionary with the following keys. A prompt key can be
67 | excluded if it is not present.
68 | 'image': The image as a torch tensor in 3xHxW format,
69 | already transformed for input to the model.
70 | 'original_size': (tuple(int, int)) The original size of
71 | the image before transformation, as (H, W).
72 | 'point_coords': (torch.Tensor) Batched point prompts for
73 | this image, with shape BxNx2. Already transformed to the
74 | input frame of the model.
75 | 'point_labels': (torch.Tensor) Batched labels for point prompts,
76 | with shape BxN.
77 | 'boxes': (torch.Tensor) Batched box inputs, with shape Bx4.
78 | Already transformed to the input frame of the model.
79 | 'mask_inputs': (torch.Tensor) Batched mask inputs to the model,
80 | in the form Bx1xHxW.
81 | multimask_output (bool): Whether the model should predict multiple
82 | disambiguating masks, or return a single mask.
83 |
84 | Returns:
85 | (list(dict)): A list over input images, where each element is
86 | a dictionary with the following keys.
87 | 'masks': (torch.Tensor) Batched binary mask predictions,
88 | with shape BxCxHxW, where B is the number of input prompts,
89 | C is determined by multimask_output, and (H, W) is the
90 | original size of the image.
91 | 'iou_predictions': (torch.Tensor) The model's predictions
92 | of mask quality, in shape BxC.
93 | 'low_res_logits': (torch.Tensor) Low resolution logits with
94 | shape BxCxHxW, where H=W=256. Can be passed as mask input
95 | to subsequent iterations of prediction.
96 | """
97 | input_images = torch.stack([self.preprocess(x["image"]) for x in batched_input], dim=0)
98 | image_embeddings = self.image_encoder(input_images)
99 |
100 | outputs = []
101 | for image_record, curr_embedding in zip(batched_input, image_embeddings):
102 | if "point_coords" in image_record:
103 | points = (image_record["point_coords"], image_record["point_labels"])
104 | else:
105 | points = None
106 | sparse_embeddings, dense_embeddings = self.prompt_encoder(
107 |
108 | points=points,
109 | boxes=image_record.get("boxes", None),
110 | masks=image_record.get("mask_inputs", None),
111 | )
112 | low_res_masks, iou_predictions = self.mask_decoder(
113 | image_embeddings=curr_embedding.unsqueeze(0),
114 | image_pe=self.prompt_encoder.get_dense_pe(),
115 | sparse_prompt_embeddings=sparse_embeddings,
116 | dense_prompt_embeddings=dense_embeddings,
117 | multimask_output=multimask_output,
118 | )
119 | masks = self.postprocess_masks(
120 | low_res_masks,
121 | input_size=image_record["image"].shape[-2:],
122 | original_size=image_record["original_size"],
123 | )
124 | masks = masks > self.mask_threshold
125 | outputs.append(
126 | {
127 | "masks": masks,
128 | "iou_predictions": iou_predictions,
129 | "low_res_logits": low_res_masks,
130 | }
131 | )
132 | return outputs
133 |
134 | def postprocess_masks(
135 | self,
136 | masks: torch.Tensor,
137 | input_size: Tuple[int, ...],
138 | original_size: Tuple[int, ...],
139 | ) -> torch.Tensor:
140 | """
141 | Remove padding and upscale masks to the original image size.
142 |
143 | Arguments:
144 | masks (torch.Tensor): Batched masks from the mask_decoder,
145 | in BxCxHxW format.
146 | input_size (tuple(int, int)): The size of the image input to the
147 | model, in (H, W) format. Used to remove padding.
148 | original_size (tuple(int, int)): The original size of the image
149 | before resizing for input to the model, in (H, W) format.
150 |
151 | Returns:
152 | (torch.Tensor): Batched masks in BxCxHxW format, where (H, W)
153 | is given by original_size.
154 | """
155 | masks = F.interpolate(
156 | masks,
157 | (self.image_encoder.img_size, self.image_encoder.img_size),
158 | mode="bilinear",
159 | align_corners=False,
160 | )
161 | masks = masks[..., : input_size[0], : input_size[1]]
162 | masks = F.interpolate(masks, original_size, mode="bilinear", align_corners=False)
163 | return masks
164 |
165 | def preprocess(self, x: torch.Tensor) -> torch.Tensor:
166 | """Normalize pixel values and pad to a square input."""
167 | # Normalize colors
168 | x = (x - self.pixel_mean) / self.pixel_std
169 |
170 | # Pad
171 | h, w = x.shape[-2:]
172 | padh = self.image_encoder.img_size - h
173 | padw = self.image_encoder.img_size - w
174 | x = F.pad(x, (0, padw, 0, padh))
175 | return x
176 |
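The `batched_input` format documented in `forward` can be exercised end to end; the sketch below builds one entry with a box prompt. The checkpoint path is a hypothetical placeholder, `'vit_b'` simply mirrors the `model_type` used in `train.py`, and the image and box values are illustrative.

```python
import numpy as np
import torch
from segment_anything import sam_model_registry
from segment_anything.utils.transforms import ResizeLongestSide

sam = sam_model_registry['vit_b'](checkpoint='pretrained_model/sam_vit_b.pth')  # hypothetical path
transform = ResizeLongestSide(sam.image_encoder.img_size)

image = np.random.randint(0, 256, (720, 960, 3), dtype=np.uint8)  # stand-in HxWxC image
input_image = transform.apply_image(image)                        # resized so the long side is 1024
image_tensor = torch.as_tensor(input_image).permute(2, 0, 1)      # 3xHxW, as the docstring requires

batched_input = [{
    'image': image_tensor,
    'original_size': image.shape[:2],
    'boxes': transform.apply_boxes_torch(torch.tensor([[100.0, 120.0, 800.0, 650.0]]), image.shape[:2]),
}]
outputs = sam(batched_input, multimask_output=False)
print(outputs[0]['masks'].shape)  # torch.Size([1, 1, 720, 960]) boolean masks
```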
--------------------------------------------------------------------------------
/segment_anything/modeling/transformer.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | import torch
8 | from torch import Tensor, nn
9 |
10 | import math
11 | from typing import Tuple, Type
12 |
13 | from .common import MLPBlock
14 |
15 |
16 | class TwoWayTransformer(nn.Module):
17 | def __init__(
18 | self,
19 | depth: int,
20 | embedding_dim: int,
21 | num_heads: int,
22 | mlp_dim: int,
23 | activation: Type[nn.Module] = nn.ReLU,
24 | attention_downsample_rate: int = 2,
25 | ) -> None:
26 | """
27 | A transformer decoder that attends to an input image using
28 | queries whose positional embedding is supplied.
29 |
30 | Args:
31 | depth (int): number of layers in the transformer
32 | embedding_dim (int): the channel dimension for the input embeddings
33 | num_heads (int): the number of heads for multihead attention. Must
34 | divide embedding_dim
35 | mlp_dim (int): the channel dimension internal to the MLP block
36 | activation (nn.Module): the activation to use in the MLP block
37 | """
38 | super().__init__()
39 | self.depth = depth
40 | self.embedding_dim = embedding_dim
41 | self.num_heads = num_heads
42 | self.mlp_dim = mlp_dim
43 | self.layers = nn.ModuleList()
44 |
45 | for i in range(depth):
46 | self.layers.append(
47 | TwoWayAttentionBlock(
48 | embedding_dim=embedding_dim,
49 | num_heads=num_heads,
50 | mlp_dim=mlp_dim,
51 | activation=activation,
52 | attention_downsample_rate=attention_downsample_rate,
53 | skip_first_layer_pe=(i == 0),
54 | )
55 | )
56 |
57 | self.final_attn_token_to_image = Attention(
58 | embedding_dim, num_heads, downsample_rate=attention_downsample_rate
59 | )
60 | self.norm_final_attn = nn.LayerNorm(embedding_dim)
61 |
62 | def forward(
63 | self,
64 | image_embedding: Tensor,
65 | image_pe: Tensor,
66 | point_embedding: Tensor,
67 | ) -> Tuple[Tensor, Tensor]:
68 | """
69 | Args:
70 | image_embedding (torch.Tensor): image to attend to. Should be shape
71 | B x embedding_dim x h x w for any h and w.
72 | image_pe (torch.Tensor): the positional encoding to add to the image. Must
73 | have the same shape as image_embedding.
74 | point_embedding (torch.Tensor): the embedding to add to the query points.
75 | Must have shape B x N_points x embedding_dim for any N_points.
76 |
77 | Returns:
78 | torch.Tensor: the processed point_embedding
79 | torch.Tensor: the processed image_embedding
80 | """
81 | # BxCxHxW -> BxHWxC == B x N_image_tokens x C
82 | bs, c, h, w = image_embedding.shape
83 | image_embedding = image_embedding.flatten(2).permute(0, 2, 1)
84 | image_pe = image_pe.flatten(2).permute(0, 2, 1)
85 |
86 | # Prepare queries
87 | queries = point_embedding
88 | keys = image_embedding
89 |
90 | # Apply transformer blocks and final layernorm
91 | for layer in self.layers:
92 | queries, keys = layer(
93 | queries=queries,
94 | keys=keys,
95 | query_pe=point_embedding,
96 | key_pe=image_pe,
97 | )
98 |
99 | # Apply the final attention layer from the points to the image
100 | q = queries + point_embedding
101 | k = keys + image_pe
102 | attn_out = self.final_attn_token_to_image(q=q, k=k, v=keys)
103 | queries = queries + attn_out
104 | queries = self.norm_final_attn(queries)
105 |
106 | return queries, keys
107 |
108 |
109 | class TwoWayAttentionBlock(nn.Module):
110 | def __init__(
111 | self,
112 | embedding_dim: int,
113 | num_heads: int,
114 | mlp_dim: int = 2048,
115 | activation: Type[nn.Module] = nn.ReLU,
116 | attention_downsample_rate: int = 2,
117 | skip_first_layer_pe: bool = False,
118 | ) -> None:
119 | """
120 | A transformer block with four layers: (1) self-attention of sparse
121 | inputs, (2) cross attention of sparse inputs to dense inputs, (3) mlp
122 | block on sparse inputs, and (4) cross attention of dense inputs to sparse
123 | inputs.
124 |
125 | Arguments:
126 | embedding_dim (int): the channel dimension of the embeddings
127 | num_heads (int): the number of heads in the attention layers
128 | mlp_dim (int): the hidden dimension of the mlp block
129 | activation (nn.Module): the activation of the mlp block
130 | skip_first_layer_pe (bool): skip the PE on the first layer
131 | """
132 | super().__init__()
133 | self.self_attn = Attention(embedding_dim, num_heads)
134 | self.norm1 = nn.LayerNorm(embedding_dim)
135 |
136 | self.cross_attn_token_to_image = Attention(
137 | embedding_dim, num_heads, downsample_rate=attention_downsample_rate
138 | )
139 | self.norm2 = nn.LayerNorm(embedding_dim)
140 |
141 | self.mlp = MLPBlock(embedding_dim, mlp_dim, activation)
142 | self.norm3 = nn.LayerNorm(embedding_dim)
143 |
144 | self.norm4 = nn.LayerNorm(embedding_dim)
145 | self.cross_attn_image_to_token = Attention(
146 | embedding_dim, num_heads, downsample_rate=attention_downsample_rate
147 | )
148 |
149 | self.skip_first_layer_pe = skip_first_layer_pe
150 |
151 | def forward(
152 | self, queries: Tensor, keys: Tensor, query_pe: Tensor, key_pe: Tensor
153 | ) -> Tuple[Tensor, Tensor]:
154 | # Self attention block
155 | if self.skip_first_layer_pe:
156 | queries = self.self_attn(q=queries, k=queries, v=queries)
157 | else:
158 | q = queries + query_pe
159 | attn_out = self.self_attn(q=q, k=q, v=queries)
160 | queries = queries + attn_out
161 | queries = self.norm1(queries)
162 |
163 | # Cross attention block, tokens attending to image embedding
164 | q = queries + query_pe
165 | k = keys + key_pe
166 | attn_out = self.cross_attn_token_to_image(q=q, k=k, v=keys)
167 | queries = queries + attn_out
168 | queries = self.norm2(queries)
169 |
170 | # MLP block
171 | mlp_out = self.mlp(queries)
172 | queries = queries + mlp_out
173 | queries = self.norm3(queries)
174 |
175 | # Cross attention block, image embedding attending to tokens
176 | q = queries + query_pe
177 | k = keys + key_pe
178 | attn_out = self.cross_attn_image_to_token(q=k, k=q, v=queries)
179 | keys = keys + attn_out
180 | keys = self.norm4(keys)
181 |
182 | return queries, keys
183 |
184 |
185 | class Attention(nn.Module):
186 | """
187 | An attention layer that allows for downscaling the size of the embedding
188 | after projection to queries, keys, and values.
189 | """
190 |
191 | def __init__(
192 | self,
193 | embedding_dim: int,
194 | num_heads: int,
195 | downsample_rate: int = 1,
196 | ) -> None:
197 | super().__init__()
198 | self.embedding_dim = embedding_dim
199 | self.internal_dim = embedding_dim // downsample_rate
200 | self.num_heads = num_heads
201 | assert self.internal_dim % num_heads == 0, "num_heads must divide embedding_dim."
202 |
203 | self.q_proj = nn.Linear(embedding_dim, self.internal_dim)
204 | self.k_proj = nn.Linear(embedding_dim, self.internal_dim)
205 | self.v_proj = nn.Linear(embedding_dim, self.internal_dim)
206 | self.out_proj = nn.Linear(self.internal_dim, embedding_dim)
207 |
208 | def _separate_heads(self, x: Tensor, num_heads: int) -> Tensor:
209 | b, n, c = x.shape
210 | x = x.reshape(b, n, num_heads, c // num_heads)
211 | return x.transpose(1, 2) # B x N_heads x N_tokens x C_per_head
212 |
213 | def _recombine_heads(self, x: Tensor) -> Tensor:
214 | b, n_heads, n_tokens, c_per_head = x.shape
215 | x = x.transpose(1, 2)
216 | return x.reshape(b, n_tokens, n_heads * c_per_head) # B x N_tokens x C
217 |
218 | def forward(self, q: Tensor, k: Tensor, v: Tensor) -> Tensor:
219 | # Input projections
220 | q = self.q_proj(q)
221 | k = self.k_proj(k)
222 | v = self.v_proj(v)
223 |
224 | # Separate into heads
225 | q = self._separate_heads(q, self.num_heads)
226 | k = self._separate_heads(k, self.num_heads)
227 | v = self._separate_heads(v, self.num_heads)
228 |
229 | # Attention
230 | _, _, _, c_per_head = q.shape
231 | attn = q @ k.permute(0, 1, 3, 2) # B x N_heads x N_tokens x N_tokens
232 | attn = attn / math.sqrt(c_per_head)
233 | attn = torch.softmax(attn, dim=-1)
234 |
235 | # Get output
236 | out = attn @ v
237 | out = self._recombine_heads(out)
238 | out = self.out_proj(out)
239 |
240 | return out
241 |
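A quick shape check of the two-way transformer above, using the depth and width that SAM's mask decoder conventionally passes in (assumed here); random tensors stand in for real embeddings.

```python
import torch
from segment_anything.modeling.transformer import TwoWayTransformer

transformer = TwoWayTransformer(depth=2, embedding_dim=256, num_heads=8, mlp_dim=2048)

image_embedding = torch.randn(1, 256, 64, 64)  # B x C x H x W image tokens
image_pe = torch.randn(1, 256, 64, 64)         # positional encoding, same shape
point_embedding = torch.randn(1, 7, 256)       # B x N_points x C prompt/output tokens

queries, keys = transformer(image_embedding, image_pe, point_embedding)
print(queries.shape)  # torch.Size([1, 7, 256])    - refined prompt/output tokens
print(keys.shape)     # torch.Size([1, 4096, 256]) - refined image tokens (64*64)
```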
--------------------------------------------------------------------------------
/segment_anything/predictor.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | import numpy as np
8 | import torch
9 |
10 | from segment_anything.modeling import Sam
11 |
12 | from typing import Optional, Tuple
13 |
14 | from .utils.transforms import ResizeLongestSide
15 |
16 |
17 | class SamPredictor:
18 | def __init__(
19 | self,
20 | sam_model: Sam,
21 | ) -> None:
22 | """
23 | Uses SAM to calculate the image embedding for an image, and then
24 | allows repeated, efficient mask prediction given prompts.
25 |
26 | Arguments:
27 | sam_model (Sam): The model to use for mask prediction.
28 | """
29 | super().__init__()
30 | self.model = sam_model
31 | self.transform = ResizeLongestSide(sam_model.image_encoder.img_size)
32 | self.reset_image()
33 |
34 | def set_image(
35 | self,
36 | image: np.ndarray,
37 | image_format: str = "RGB",
38 | ) -> None:
39 | """
40 | Calculates the image embeddings for the provided image, allowing
41 | masks to be predicted with the 'predict' method.
42 |
43 | Arguments:
44 | image (np.ndarray): The image for calculating masks. Expects an
45 | image in HWC uint8 format, with pixel values in [0, 255].
46 | image_format (str): The color format of the image, in ['RGB', 'BGR'].
47 | """
48 | assert image_format in [
49 | "RGB",
50 | "BGR",
51 | ], f"image_format must be in ['RGB', 'BGR'], is {image_format}."
52 | if image_format != self.model.image_format:
53 | image = image[..., ::-1]
54 |
55 | # Transform the image to the form expected by the model
56 | input_image = self.transform.apply_image(image)
57 | input_image_torch = torch.as_tensor(input_image, device=self.device)
58 | input_image_torch = input_image_torch.permute(2, 0, 1).contiguous()[None, :, :, :]
59 |
60 | self.set_torch_image(input_image_torch, image.shape[:2])
61 |
62 | @torch.no_grad()
63 | def set_torch_image(
64 | self,
65 | transformed_image: torch.Tensor,
66 | original_image_size: Tuple[int, ...],
67 | ) -> None:
68 | """
69 | Calculates the image embeddings for the provided image, allowing
70 | masks to be predicted with the 'predict' method. Expects the input
71 | image to be already transformed to the format expected by the model.
72 |
73 | Arguments:
74 | transformed_image (torch.Tensor): The input image, with shape
75 | 1x3xHxW, which has been transformed with ResizeLongestSide.
76 | original_image_size (tuple(int, int)): The size of the image
77 | before transformation, in (H, W) format.
78 | """
79 | assert (
80 | len(transformed_image.shape) == 4
81 | and transformed_image.shape[1] == 3
82 | and max(*transformed_image.shape[2:]) == self.model.image_encoder.img_size
83 | ), f"set_torch_image input must be BCHW with long side {self.model.image_encoder.img_size}."
84 | self.reset_image()
85 |
86 | self.original_size = original_image_size
87 | self.input_size = tuple(transformed_image.shape[-2:])
88 | input_image = self.model.preprocess(transformed_image)
89 | self.features = self.model.image_encoder(input_image)
90 | self.is_image_set = True
91 |
92 | def predict(
93 | self,
94 | point_coords: Optional[np.ndarray] = None,
95 | point_labels: Optional[np.ndarray] = None,
96 | box: Optional[np.ndarray] = None,
97 | mask_input: Optional[np.ndarray] = None,
98 | multimask_output: bool = True,
99 | return_logits: bool = False,
100 | ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
101 | """
102 | Predict masks for the given input prompts, using the currently set image.
103 |
104 | Arguments:
105 | point_coords (np.ndarray or None): A Nx2 array of point prompts to the
106 | model. Each point is in (X,Y) in pixels.
107 | point_labels (np.ndarray or None): A length N array of labels for the
108 | point prompts. 1 indicates a foreground point and 0 indicates a
109 | background point.
110 | box (np.ndarray or None): A length 4 array giving a box prompt to the
111 | model, in XYXY format.
112 | mask_input (np.ndarray): A low resolution mask input to the model, typically
113 | coming from a previous prediction iteration. Has form 1xHxW, where
114 | for SAM, H=W=256.
115 | multimask_output (bool): If true, the model will return three masks.
116 | For ambiguous input prompts (such as a single click), this will often
117 | produce better masks than a single prediction. If only a single
118 | mask is needed, the model's predicted quality score can be used
119 | to select the best mask. For non-ambiguous prompts, such as multiple
120 | input prompts, multimask_output=False can give better results.
121 | return_logits (bool): If true, returns un-thresholded masks logits
122 | instead of a binary mask.
123 |
124 | Returns:
125 | (np.ndarray): The output masks in CxHxW format, where C is the
126 | number of masks, and (H, W) is the original image size.
127 | (np.ndarray): An array of length C containing the model's
128 | predictions for the quality of each mask.
129 | (np.ndarray): An array of shape CxHxW, where C is the number
130 | of masks and H=W=256. These low resolution logits can be passed to
131 | a subsequent iteration as mask input.
132 | """
133 | if not self.is_image_set:
134 | raise RuntimeError("An image must be set with .set_image(...) before mask prediction.")
135 |
136 | # Transform input prompts
137 | coords_torch, labels_torch, box_torch, mask_input_torch = None, None, None, None
138 | if point_coords is not None:
139 | assert (
140 | point_labels is not None
141 | ), "point_labels must be supplied if point_coords is supplied."
142 | point_coords = self.transform.apply_coords(point_coords, self.original_size)
143 | coords_torch = torch.as_tensor(point_coords, dtype=torch.float, device=self.device)
144 | labels_torch = torch.as_tensor(point_labels, dtype=torch.int, device=self.device)
145 | coords_torch, labels_torch = coords_torch[None, :, :], labels_torch[None, :]
146 | if box is not None:
147 | box = self.transform.apply_boxes(box, self.original_size)
148 | box_torch = torch.as_tensor(box, dtype=torch.float, device=self.device)
149 | box_torch = box_torch[None, :]
150 | if mask_input is not None:
151 | mask_input_torch = torch.as_tensor(mask_input, dtype=torch.float, device=self.device)
152 | mask_input_torch = mask_input_torch[None, :, :, :]
153 |
154 | masks, iou_predictions, low_res_masks = self.predict_torch(
155 | coords_torch,
156 | labels_torch,
157 | box_torch,
158 | mask_input_torch,
159 | multimask_output,
160 | return_logits=return_logits,
161 | )
162 |
163 | masks_np = masks[0].detach().cpu().numpy()
164 | iou_predictions_np = iou_predictions[0].detach().cpu().numpy()
165 | low_res_masks_np = low_res_masks[0].detach().cpu().numpy()
166 | return masks_np, iou_predictions_np, low_res_masks_np
167 |
168 | @torch.no_grad()
169 | def predict_torch(
170 | self,
171 | point_coords: Optional[torch.Tensor],
172 | point_labels: Optional[torch.Tensor],
173 | boxes: Optional[torch.Tensor] = None,
174 | mask_input: Optional[torch.Tensor] = None,
175 | multimask_output: bool = True,
176 | return_logits: bool = False,
177 | ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
178 | """
179 | Predict masks for the given input prompts, using the currently set image.
180 | Input prompts are batched torch tensors and are expected to already be
181 | transformed to the input frame using ResizeLongestSide.
182 |
183 | Arguments:
184 | point_coords (torch.Tensor or None): A BxNx2 array of point prompts to the
185 | model. Each point is in (X,Y) in pixels.
186 | point_labels (torch.Tensor or None): A BxN array of labels for the
187 | point prompts. 1 indicates a foreground point and 0 indicates a
188 | background point.
189 | boxes (torch.Tensor or None): A Bx4 tensor giving a box prompt to the
190 | model, in XYXY format.
191 | mask_input (torch.Tensor): A low resolution mask input to the model, typically
192 | coming from a previous prediction iteration. Has form Bx1xHxW, where
193 | for SAM, H=W=256. Masks returned by a previous iteration of the
194 | predict method do not need further transformation.
195 | multimask_output (bool): If true, the model will return three masks.
196 | For ambiguous input prompts (such as a single click), this will often
197 | produce better masks than a single prediction. If only a single
198 | mask is needed, the model's predicted quality score can be used
199 | to select the best mask. For non-ambiguous prompts, such as multiple
200 | input prompts, multimask_output=False can give better results.
201 | return_logits (bool): If true, returns un-thresholded masks logits
202 | instead of a binary mask.
203 |
204 | Returns:
205 | (torch.Tensor): The output masks in BxCxHxW format, where C is the
206 | number of masks, and (H, W) is the original image size.
207 | (torch.Tensor): An array of shape BxC containing the model's
208 | predictions for the quality of each mask.
209 | (torch.Tensor): An array of shape BxCxHxW, where C is the number
210 | of masks and H=W=256. These low res logits can be passed to
211 | a subsequent iteration as mask input.
212 | """
213 | if not self.is_image_set:
214 | raise RuntimeError("An image must be set with .set_image(...) before mask prediction.")
215 |
216 | if point_coords is not None:
217 | points = (point_coords, point_labels)
218 | else:
219 | points = None
220 |
221 | # Embed prompts
222 | sparse_embeddings, dense_embeddings = self.model.prompt_encoder(
223 | points=points,
224 | boxes=boxes,
225 | masks=mask_input,
226 | )
227 |
228 | # Predict masks
229 | low_res_masks, iou_predictions = self.model.mask_decoder(
230 | image_embeddings=self.features,
231 | image_pe=self.model.prompt_encoder.get_dense_pe(),
232 | sparse_prompt_embeddings=sparse_embeddings,
233 | dense_prompt_embeddings=dense_embeddings,
234 | multimask_output=multimask_output,
235 | )
236 |
237 | # Upscale the masks to the original image resolution
238 | masks = self.model.postprocess_masks(low_res_masks, self.input_size, self.original_size)
239 |
240 | if not return_logits:
241 | masks = masks > self.model.mask_threshold
242 |
243 | return masks, iou_predictions, low_res_masks
244 |
245 | def get_image_embedding(self) -> torch.Tensor:
246 | """
247 | Returns the image embeddings for the currently set image, with
248 | shape 1xCxHxW, where C is the embedding dimension and (H,W) are
249 | the embedding spatial dimension of SAM (typically C=256, H=W=64).
250 | """
251 | if not self.is_image_set:
252 | raise RuntimeError(
253 | "An image must be set with .set_image(...) to generate an embedding."
254 | )
255 | assert self.features is not None, "Features must exist if an image has been set."
256 | return self.features
257 |
258 | @property
259 | def device(self) -> torch.device:
260 | return self.model.device
261 |
262 | def reset_image(self) -> None:
263 | """Resets the currently set image."""
264 | self.is_image_set = False
265 | self.features = None
266 | self.orig_h = None
267 | self.orig_w = None
268 | self.input_h = None
269 | self.input_w = None
270 |
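Typical use of `SamPredictor` follows the set-image-then-predict pattern described in the docstrings. The sketch below uses one of the repository's sample images; the checkpoint path and box coordinates are illustrative placeholders (in TongueSAM the box would come from the YOLOX prompt generator).

```python
import cv2
import numpy as np
from segment_anything import SamPredictor, sam_model_registry

sam = sam_model_registry['vit_b'](checkpoint='pretrained_model/sam_vit_b.pth')  # hypothetical path
predictor = SamPredictor(sam)

image = cv2.cvtColor(cv2.imread('data/test_in/1.jpg'), cv2.COLOR_BGR2RGB)
predictor.set_image(image)  # runs the heavy image encoder once

box = np.array([50, 80, 400, 360])  # illustrative XYXY box prompt
masks, scores, low_res = predictor.predict(box=box, multimask_output=False)
print(masks.shape, scores)  # (1, H, W) boolean mask and its predicted quality
```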
--------------------------------------------------------------------------------
/segment_anything/utils/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
--------------------------------------------------------------------------------
/segment_anything/utils/__pycache__/__init__.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cshan-github/TongueSAM/3f6e8c620e4d89e669a92a22f3be3d007932ce45/segment_anything/utils/__pycache__/__init__.cpython-310.pyc
--------------------------------------------------------------------------------
/segment_anything/utils/__pycache__/__init__.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cshan-github/TongueSAM/3f6e8c620e4d89e669a92a22f3be3d007932ce45/segment_anything/utils/__pycache__/__init__.cpython-39.pyc
--------------------------------------------------------------------------------
/segment_anything/utils/__pycache__/amg.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cshan-github/TongueSAM/3f6e8c620e4d89e669a92a22f3be3d007932ce45/segment_anything/utils/__pycache__/amg.cpython-310.pyc
--------------------------------------------------------------------------------
/segment_anything/utils/__pycache__/amg.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cshan-github/TongueSAM/3f6e8c620e4d89e669a92a22f3be3d007932ce45/segment_anything/utils/__pycache__/amg.cpython-39.pyc
--------------------------------------------------------------------------------
/segment_anything/utils/__pycache__/transforms.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cshan-github/TongueSAM/3f6e8c620e4d89e669a92a22f3be3d007932ce45/segment_anything/utils/__pycache__/transforms.cpython-310.pyc
--------------------------------------------------------------------------------
/segment_anything/utils/__pycache__/transforms.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cshan-github/TongueSAM/3f6e8c620e4d89e669a92a22f3be3d007932ce45/segment_anything/utils/__pycache__/transforms.cpython-39.pyc
--------------------------------------------------------------------------------
/segment_anything/utils/onnx.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | import torch
8 | import torch.nn as nn
9 | from torch.nn import functional as F
10 |
11 | from typing import Tuple
12 |
13 | from ..modeling import Sam
14 | from .amg import calculate_stability_score
15 |
16 |
17 | class SamOnnxModel(nn.Module):
18 | """
19 | This model should not be called directly, but is used in ONNX export.
20 | It combines the prompt encoder, mask decoder, and mask postprocessing of Sam,
21 | with some functions modified to enable model tracing. Also supports extra
22 | options controlling what information is returned. See the ONNX export script for details.
23 | """
24 |
25 | def __init__(
26 | self,
27 | model: Sam,
28 | return_single_mask: bool,
29 | use_stability_score: bool = False,
30 | return_extra_metrics: bool = False,
31 | ) -> None:
32 | super().__init__()
33 | self.mask_decoder = model.mask_decoder
34 | self.model = model
35 | self.img_size = model.image_encoder.img_size
36 | self.return_single_mask = return_single_mask
37 | self.use_stability_score = use_stability_score
38 | self.stability_score_offset = 1.0
39 | self.return_extra_metrics = return_extra_metrics
40 |
41 | @staticmethod
42 | def resize_longest_image_size(
43 | input_image_size: torch.Tensor, longest_side: int
44 | ) -> torch.Tensor:
45 | input_image_size = input_image_size.to(torch.float32)
46 | scale = longest_side / torch.max(input_image_size)
47 | transformed_size = scale * input_image_size
48 | transformed_size = torch.floor(transformed_size + 0.5).to(torch.int64)
49 | return transformed_size
50 |
51 | def _embed_points(self, point_coords: torch.Tensor, point_labels: torch.Tensor) -> torch.Tensor:
52 | point_coords = point_coords + 0.5
53 | point_coords = point_coords / self.img_size
54 | point_embedding = self.model.prompt_encoder.pe_layer._pe_encoding(point_coords)
55 | point_labels = point_labels.unsqueeze(-1).expand_as(point_embedding)
56 |
57 | point_embedding = point_embedding * (point_labels != -1)
58 | point_embedding = point_embedding + self.model.prompt_encoder.not_a_point_embed.weight * (
59 | point_labels == -1
60 | )
61 |
62 | for i in range(self.model.prompt_encoder.num_point_embeddings):
63 | point_embedding = point_embedding + self.model.prompt_encoder.point_embeddings[
64 | i
65 | ].weight * (point_labels == i)
66 |
67 | return point_embedding
68 |
69 | def _embed_masks(self, input_mask: torch.Tensor, has_mask_input: torch.Tensor) -> torch.Tensor:
70 | mask_embedding = has_mask_input * self.model.prompt_encoder.mask_downscaling(input_mask)
71 | mask_embedding = mask_embedding + (
72 | 1 - has_mask_input
73 | ) * self.model.prompt_encoder.no_mask_embed.weight.reshape(1, -1, 1, 1)
74 | return mask_embedding
75 |
76 | def mask_postprocessing(self, masks: torch.Tensor, orig_im_size: torch.Tensor) -> torch.Tensor:
77 | masks = F.interpolate(
78 | masks,
79 | size=(self.img_size, self.img_size),
80 | mode="bilinear",
81 | align_corners=False,
82 | )
83 |
84 | prepadded_size = self.resize_longest_image_size(orig_im_size, self.img_size).to(torch.int64)
85 | masks = masks[..., : prepadded_size[0], : prepadded_size[1]] # type: ignore
86 |
87 | orig_im_size = orig_im_size.to(torch.int64)
88 | h, w = orig_im_size[0], orig_im_size[1]
89 | masks = F.interpolate(masks, size=(h, w), mode="bilinear", align_corners=False)
90 | return masks
91 |
92 | def select_masks(
93 | self, masks: torch.Tensor, iou_preds: torch.Tensor, num_points: int
94 | ) -> Tuple[torch.Tensor, torch.Tensor]:
95 | # Determine if we should return the multiclick mask or not from the number of points.
96 | # The reweighting is used to avoid control flow.
97 | score_reweight = torch.tensor(
98 | [[1000] + [0] * (self.model.mask_decoder.num_mask_tokens - 1)]
99 | ).to(iou_preds.device)
100 | score = iou_preds + (num_points - 2.5) * score_reweight
101 | best_idx = torch.argmax(score, dim=1)
102 | masks = masks[torch.arange(masks.shape[0]), best_idx, :, :].unsqueeze(1)
103 | iou_preds = iou_preds[torch.arange(masks.shape[0]), best_idx].unsqueeze(1)
104 |
105 | return masks, iou_preds
106 |
107 | @torch.no_grad()
108 | def forward(
109 | self,
110 | image_embeddings: torch.Tensor,
111 | point_coords: torch.Tensor,
112 | point_labels: torch.Tensor,
113 | mask_input: torch.Tensor,
114 | has_mask_input: torch.Tensor,
115 | orig_im_size: torch.Tensor,
116 | ):
117 | sparse_embedding = self._embed_points(point_coords, point_labels)
118 | dense_embedding = self._embed_masks(mask_input, has_mask_input)
119 |
120 | masks, scores = self.model.mask_decoder.predict_masks(
121 | image_embeddings=image_embeddings,
122 | image_pe=self.model.prompt_encoder.get_dense_pe(),
123 | sparse_prompt_embeddings=sparse_embedding,
124 | dense_prompt_embeddings=dense_embedding,
125 | )
126 |
127 | if self.use_stability_score:
128 | scores = calculate_stability_score(
129 | masks, self.model.mask_threshold, self.stability_score_offset
130 | )
131 |
132 | if self.return_single_mask:
133 | masks, scores = self.select_masks(masks, scores, point_coords.shape[1])
134 |
135 | upscaled_masks = self.mask_postprocessing(masks, orig_im_size)
136 |
137 | if self.return_extra_metrics:
138 | stability_scores = calculate_stability_score(
139 | upscaled_masks, self.model.mask_threshold, self.stability_score_offset
140 | )
141 | areas = (upscaled_masks > self.model.mask_threshold).sum(-1).sum(-1)
142 | return upscaled_masks, scores, stability_scores, areas, masks
143 |
144 | return upscaled_masks, scores, masks
145 |
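For reference, `SamOnnxModel` is meant to be traced rather than called directly. Below is a sketch of an export call that loosely follows the official SAM export flow; the checkpoint path, output filename, and dummy input shapes are illustrative assumptions, not values taken from this repository.

```python
import torch
from segment_anything import sam_model_registry
from segment_anything.utils.onnx import SamOnnxModel

sam = sam_model_registry['vit_b'](checkpoint='pretrained_model/sam_vit_b.pth')  # hypothetical path
onnx_model = SamOnnxModel(sam, return_single_mask=True)

embed_dim = sam.prompt_encoder.embed_dim
embed_size = sam.prompt_encoder.image_embedding_size
dummy_inputs = {
    'image_embeddings': torch.randn(1, embed_dim, *embed_size),
    'point_coords': torch.randint(0, 1024, (1, 5, 2), dtype=torch.float),
    'point_labels': torch.randint(0, 4, (1, 5), dtype=torch.float),
    'mask_input': torch.randn(1, 1, 4 * embed_size[0], 4 * embed_size[1]),
    'has_mask_input': torch.tensor([1.0]),
    'orig_im_size': torch.tensor([720, 960], dtype=torch.float),
}
torch.onnx.export(
    onnx_model,
    tuple(dummy_inputs.values()),
    'sam_prompt_decoder.onnx',          # hypothetical output filename
    input_names=list(dummy_inputs.keys()),
    output_names=['masks', 'iou_predictions', 'low_res_masks'],
)
```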
--------------------------------------------------------------------------------
/segment_anything/utils/transforms.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | import numpy as np
8 | import torch
9 | from torch.nn import functional as F
10 | from torchvision.transforms.functional import resize, to_pil_image # type: ignore
11 |
12 | from copy import deepcopy
13 | from typing import Tuple
14 |
15 |
16 | class ResizeLongestSide:
17 | """
18 | Resizes images to the longest side 'target_length', as well as provides
19 | methods for resizing coordinates and boxes. Provides methods for
20 | transforming both numpy array and batched torch tensors.
21 | """
22 |
23 | def __init__(self, target_length: int) -> None:
24 | self.target_length = target_length
25 |
26 | def apply_image(self, image: np.ndarray) -> np.ndarray:
27 | """
28 | Expects a numpy array with shape HxWxC in uint8 format.
29 | """
30 | target_size = self.get_preprocess_shape(image.shape[0], image.shape[1], self.target_length)
31 | return np.array(resize(to_pil_image(image), target_size))
32 |
33 | def apply_coords(self, coords: np.ndarray, original_size: Tuple[int, ...]) -> np.ndarray:
34 | """
35 | Expects a numpy array of length 2 in the final dimension. Requires the
36 | original image size in (H, W) format.
37 | """
38 | old_h, old_w = original_size
39 | new_h, new_w = self.get_preprocess_shape(old_h, old_w, self.target_length)
40 | new_coords = np.empty_like(coords)
41 | new_coords[..., 0] = coords[..., 0] * (new_w / old_w)
42 | new_coords[..., 1] = coords[..., 1] * (new_h / old_h)
43 | return new_coords
44 |
45 |
46 | def apply_boxes(self, boxes: np.ndarray, original_size: Tuple[int, ...]) -> np.ndarray:
47 | """
48 | Expects a numpy array shape Bx4. Requires the original image size
49 | in (H, W) format.
50 | """
51 | boxes = self.apply_coords(boxes.reshape(-1, 2, 2), original_size)
52 | return boxes.reshape(-1, 4)
53 |
54 | def apply_image_torch(self, image: torch.Tensor) -> torch.Tensor:
55 | """
56 | Expects batched images with shape BxCxHxW and float format. This
57 | transformation may not exactly match apply_image. apply_image is
58 | the transformation expected by the model.
59 | """
60 | # Expects an image in BCHW format. May not exactly match apply_image.
61 | target_size = self.get_preprocess_shape(image.shape[2], image.shape[3], self.target_length)
62 | return F.interpolate(
63 | image, target_size, mode="bilinear", align_corners=False, antialias=True
64 | )
65 |
66 | def apply_coords_torch(
67 | self, coords: torch.Tensor, original_size: Tuple[int, ...]
68 | ) -> torch.Tensor:
69 | """
70 | Expects a torch tensor with length 2 in the last dimension. Requires the
71 | original image size in (H, W) format.
72 | """
73 | old_h, old_w = original_size
74 | new_h, new_w = self.get_preprocess_shape(
75 | original_size[0], original_size[1], self.target_length
76 | )
77 | coords = deepcopy(coords).to(torch.float)
78 | coords[..., 0] = coords[..., 0] * (new_w / old_w)
79 | coords[..., 1] = coords[..., 1] * (new_h / old_h)
80 | return coords
81 |
82 | def apply_boxes_torch(
83 | self, boxes: torch.Tensor, original_size: Tuple[int, ...]
84 | ) -> torch.Tensor:
85 | """
86 | Expects a torch tensor with shape Bx4. Requires the original image
87 | size in (H, W) format.
88 | """
89 | boxes = self.apply_coords_torch(boxes.reshape(-1, 2, 2), original_size)
90 | return boxes.reshape(-1, 4)
91 |
92 | @staticmethod
93 | def get_preprocess_shape(oldh: int, oldw: int, long_side_length: int) -> Tuple[int, int]:
94 | """
95 | Compute the output size given input size and target long side length.
96 | """
97 | scale = long_side_length * 1.0 / max(oldh, oldw)
98 | newh, neww = oldh * scale, oldw * scale
99 | neww = int(neww + 0.5)
100 | newh = int(newh + 0.5)
101 | return (newh, neww)
102 |
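A small worked example of the resizing arithmetic above (values chosen for illustration): a 600x800 (HxW) image is scaled by 1024/800 = 1.28, and coordinates scale by the same per-axis factors.

```python
import numpy as np
from segment_anything.utils.transforms import ResizeLongestSide

t = ResizeLongestSide(target_length=1024)

print(t.get_preprocess_shape(600, 800, 1024))   # (768, 1024)

coords = np.array([[400.0, 300.0]])             # (x, y) in the original 600x800 frame
print(t.apply_coords(coords, (600, 800)))       # [[512. 384.]]
```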
--------------------------------------------------------------------------------
/split.py:
--------------------------------------------------------------------------------
1 |
2 | import os
3 | import shutil
4 |
5 | # Source folder path
6 | src_folder = '/home/disk1/cs/project/dataset/COCO/val_gt/'
7 |
8 | # Destination folder path
9 | dst_folder = '/home/disk1/cs/project/dataset/COCO/split/val_gt/'
10 |
11 | # Number of files per group
12 | group_size = 8
13 |
14 | # Iterate over the image files in the source folder in sorted order and split them into groups
15 | file_groups = []
16 | for i, file_name in enumerate(sorted(os.listdir(src_folder))):
17 | # if file_name.endswith('.jpg'):
18 | group_idx = i // group_size
19 | if len(file_groups) <= group_idx:
20 | file_groups.append([])
21 | file_groups[group_idx].append(file_name)
22 |
23 | # Create the destination folder
24 | os.makedirs(dst_folder, exist_ok=True)
25 |
26 | # Copy each group of files into its own subfolder of the destination folder
27 | for i, group in enumerate(file_groups):
28 | group_folder = os.path.join(dst_folder, f'group_{i}')
29 | os.makedirs(group_folder, exist_ok=True)
30 | for file_name in group:
31 | src_file = os.path.join(src_folder, file_name)
32 | dst_file = os.path.join(group_folder, file_name)
33 | shutil.copy(src_file, dst_file)
34 |
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import matplotlib.pyplot as plt
3 | import os
4 | from PIL import Image  # needed by Image.fromarray in NpzDataset
5 | join = os.path.join
6 | from tqdm import tqdm
7 | import torch
8 | from torch.utils.data import Dataset, DataLoader
9 | import monai
10 | from segment_anything import SamPredictor, sam_model_registry
11 | from segment_anything.utils.transforms import ResizeLongestSide
12 | from utils.SurfaceDice import compute_dice_coefficient
13 | import cv2
14 | from sklearn.metrics import accuracy_score, confusion_matrix, precision_score, recall_score, f1_score, jaccard_score
15 | # set seeds
16 | torch.manual_seed(2023)
17 | np.random.seed(2023)
18 | from torch.nn import functional as F  # needed by F.interpolate in the training loop
19 | from utils_metrics import *
20 | from skimage import transform, io, segmentation
21 | from segment.yolox import YOLOX
22 | import random
23 | import math
24 | from functools import partial
25 | ##############################################################################################################
26 | num_epochs = 10
27 | ts_npz_path='/home/cs/project/medsam_tongue/data/tongueset3_npz/test/'
28 | npz_tr_path = '/home/cs/project/medsam_tongue/data/tongue_train_npz/'
29 | model_type = 'vit_b'
30 | checkpoint = '/home/cs/project/medsam_tongue/pretrained_model/final.pth'
31 | device = 'cuda:1'
32 | model_save_path = './logs/'
33 | if_save=False
34 | if_onlytest=True
35 | batch_size=32
36 | prompt_type='no'
37 | lr_decay_type= "cos"
38 | Init_lr= 1e-4
39 | point_num=3
40 | segment=None
41 | ###############################################################################################################
42 | def get_lr_scheduler(lr_decay_type, lr, min_lr, total_iters, warmup_iters_ratio = 0.05, warmup_lr_ratio = 0.1, no_aug_iter_ratio = 0.05, step_num = 10):
43 | def yolox_warm_cos_lr(lr, min_lr, total_iters, warmup_total_iters, warmup_lr_start, no_aug_iter, iters):
44 | if iters <= warmup_total_iters:
45 | # lr = (lr - warmup_lr_start) * iters / float(warmup_total_iters) + warmup_lr_start
46 | lr = (lr - warmup_lr_start) * pow(iters / float(warmup_total_iters), 2) + warmup_lr_start
47 | elif iters >= total_iters - no_aug_iter:
48 | lr = min_lr
49 | else:
50 | lr = min_lr + 0.5 * (lr - min_lr) * (
51 | 1.0 + math.cos(math.pi* (iters - warmup_total_iters) / (total_iters - warmup_total_iters - no_aug_iter))
52 | )
53 | return lr
54 |
55 | def step_lr(lr, decay_rate, step_size, iters):
56 | if step_size < 1:
57 | raise ValueError("step_size must be above 1.")
58 | n = iters // step_size
59 | out_lr = lr * decay_rate ** n
60 | return out_lr
61 |
62 | if lr_decay_type == "cos":
63 | warmup_total_iters = min(max(warmup_iters_ratio * total_iters, 1), 3)
64 | warmup_lr_start = max(warmup_lr_ratio * lr, 1e-6)
65 | no_aug_iter = min(max(no_aug_iter_ratio * total_iters, 1), 15)
66 | func = partial(yolox_warm_cos_lr ,lr, min_lr, total_iters, warmup_total_iters, warmup_lr_start, no_aug_iter)
67 | else:
68 | decay_rate = (min_lr / lr) ** (1 / (step_num - 1))
69 | step_size = total_iters / step_num
70 | func = partial(step_lr, lr, decay_rate, step_size)
71 |
72 | return func
73 | #%% create a dataset class to load npz data and return back image embeddings and ground truth
74 | class NpzDataset(Dataset):
75 | def __init__(self, data_root):
76 | self.npz_data=np.load(data_root)
77 | self.ori_gts = self.npz_data['gts']
78 | self.img_embeddings = self.npz_data['img_embeddings']
79 | self.imgs=self.npz_data['imgs']
80 | self.model=segment
81 | self.point_num=point_num
82 | def __len__(self):
83 | return self.ori_gts.shape[0]
84 |
85 | def __getitem__(self, index):
86 | img_embed = self.img_embeddings[index]
87 | gt2D = self.ori_gts[index]
88 | img=self.imgs[index]
89 | H, W = gt2D.shape
90 |
91 | # ############################box##############################################################
92 | if self.model!=None:
93 | img=Image.fromarray(img)
94 | img= self.model.get_miou_png(img)
95 | y_indices, x_indices = np.where(img > 0)
96 | x_min, x_max = np.min(x_indices), np.max(x_indices)
97 | y_min, y_max = np.min(y_indices), np.max(y_indices)
98 | bboxes = np.array([x_min, y_min, x_max, y_max])
99 | # box prompt derived from the YOLOX-predicted mask rather than the ground truth
100 | points=np.where(img > 0)
101 | random_points = random.choices(range(len(points[0])), k=self.point_num)
102 | random_points = [(points[0][i], points[1][i]) for i in random_points]
103 |
104 | else:
105 | y_indices, x_indices = np.where(gt2D > 0)
106 | x_min, x_max = np.min(x_indices), np.max(x_indices)
107 | y_min, y_max = np.min(y_indices), np.max(y_indices)
108 | bboxes = np.array([x_min, y_min, x_max, y_max])
109 | points=np.where(gt2D > 0)
110 | random_points = random.choices(range(len(points[0])), k=self.point_num)
111 | random_points = [(points[0][i], points[1][i]) for i in random_points]
112 |
113 | return torch.tensor(img_embed).float(), torch.tensor(gt2D[None, :,:]).long(), torch.tensor(bboxes).float(),torch.tensor(img).float(),torch.tensor(random_points).float()
114 | #####################################################Begin############################################################################
115 | Min_lr=Init_lr*0.01
116 | lr_limit_max = Init_lr
117 | lr_limit_min = 3e-4
118 | Init_lr_fit = min(max(batch_size / batch_size * Init_lr, lr_limit_min), lr_limit_max)
119 | Min_lr_fit = min(max(batch_size / batch_size * Min_lr, lr_limit_min * 1e-2), lr_limit_max * 1e-2)
120 | lr_scheduler_func = get_lr_scheduler(lr_decay_type, Init_lr_fit, Min_lr_fit, num_epochs)
121 | train_losses = []
122 | val_losses = []
123 | best_iou=0
124 | best_pa=0
125 | best_acc=0
126 |
127 |
128 | sam_model = sam_model_registry[model_type](checkpoint=checkpoint).to(device)
129 | seg_loss = monai.losses.DiceCELoss(sigmoid=True, squared_pred=True, reduction='mean')#%% train
130 | os.makedirs(model_save_path, exist_ok=True)
131 |
132 | for epoch in range(num_epochs):
133 | print(f'EPOCH: {epoch}')
134 | epoch_loss = 0
135 | ###############################################################Test##################################################################
136 | sam_model.eval()
137 | val_gts=[]
138 | val_preds=[]
139 | with torch.no_grad():
140 | for f in os.listdir(ts_npz_path):
141 | ts_dataset = NpzDataset(join(ts_npz_path,f))
142 | ts_dataloader = DataLoader(ts_dataset, batch_size=batch_size, shuffle=True)
143 | for step, (image_embedding, gt2D, boxes,img,points) in enumerate(ts_dataloader):
144 | if prompt_type=='box':
145 | box_np = boxes.numpy()
146 | sam_trans = ResizeLongestSide(sam_model.image_encoder.img_size)
147 | box = sam_trans.apply_boxes(box_np, (img.shape[-2], img.shape[-1]))
148 | box_torch = torch.as_tensor(box, dtype=torch.float, device=device)
149 | if len(box_torch.shape) == 2:
150 | box_torch = box_torch[:, None, :]
151 | sparse_embeddings, dense_embeddings = sam_model.prompt_encoder(
152 | points=None,
153 | boxes=box_torch,
154 | masks=None,
155 | )
156 | elif prompt_type=='point':
157 | sparse_embeddings, dense_embeddings = sam_model.prompt_encoder(
158 | points=points,
159 | boxes=None,
160 | masks=None,
161 | )
162 | elif prompt_type=='no':
163 | sparse_embeddings, dense_embeddings = sam_model.prompt_encoder(
164 | points=None,
165 | boxes=None,
166 | masks=None,
167 | )
168 | mask_predictions, _ = sam_model.mask_decoder(
169 | image_embeddings=image_embedding.to(device), # (B, 256, 64, 64)
170 | image_pe=sam_model.prompt_encoder.get_dense_pe(), # (1, 256, 64, 64)
171 | sparse_prompt_embeddings=sparse_embeddings, # (B, 2, 256)
172 | dense_prompt_embeddings=dense_embeddings, # (B, 256, 64, 64)
173 | multimask_output=False,
174 | )
175 | for i in range(mask_predictions.shape[0]):
176 | mask = mask_predictions[i]
177 | mask = mask.cpu().detach().numpy().squeeze()
178 | mask = cv2.resize((mask > 0.5).astype(np.uint8),(gt2D.shape[2], gt2D.shape[3]))
179 | gt_data=gt2D[i].cpu().numpy().astype(np.uint8)
180 | val_gts.append(gt_data.astype(np.uint8))
181 | val_preds.append(mask.astype(np.uint8))
182 | iou,pa,acc=compute_mIoU(val_gts,val_preds)
183 | if iou> best_iou:
184 | best_iou=iou
185 | best_pa=pa
186 | best_acc=acc
187 | if if_onlytest:
188 | continue
189 | if if_save==True:
190 | torch.save(sam_model.state_dict(), join(model_save_path, 'best.pth'))# plot loss
191 | print('best_miou:'+str(best_iou))
192 | print('best_pa:'+str(best_pa))
193 | print('best_acc:'+str(best_acc))
194 | if if_onlytest:
195 | continue
196 | ###############################################################Train##################################################################
197 | sam_model.train()
198 | lr = lr_scheduler_func(epoch)
199 | optimizer = torch.optim.Adam(sam_model.mask_decoder.parameters(), lr,weight_decay=0)
200 | for f in os.listdir(npz_tr_path):
201 | train_dataset = NpzDataset(join(npz_tr_path,f))
202 | train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
203 | for step, (image_embedding, gt2D, boxes,img,points) in enumerate(train_dataloader):
204 | with torch.no_grad():
205 | if prompt_type=='box':
206 | box_np = boxes.numpy()
207 | sam_trans = ResizeLongestSide(sam_model.image_encoder.img_size)
208 | box = sam_trans.apply_boxes(box_np, (img.shape[-2], img.shape[-1]))
209 | box_torch = torch.as_tensor(box, dtype=torch.float, device=device)
210 | if len(box_torch.shape) == 2:
211 | box_torch = box_torch[:, None, :]
212 | sparse_embeddings, dense_embeddings = sam_model.prompt_encoder(
213 | points=None,
214 | boxes=box_torch,
215 | masks=None,
216 | )
217 | elif prompt_type=='point':
218 | sparse_embeddings, dense_embeddings = sam_model.prompt_encoder(
219 | points=points,
220 | boxes=None,
221 | masks=None,
222 | )
223 | elif prompt_type=='no':
224 | sparse_embeddings, dense_embeddings = sam_model.prompt_encoder(
225 | points=None,
226 | boxes=None,
227 | masks=None,
228 | )
229 | mask_predictions, _ = sam_model.mask_decoder(
230 | image_embeddings=image_embedding.to(device), # (B, 256, 64, 64)
231 | image_pe=sam_model.prompt_encoder.get_dense_pe(), # (1, 256, 64, 64)
232 | sparse_prompt_embeddings=sparse_embeddings, # (B, 2, 256)
233 | dense_prompt_embeddings=dense_embeddings, # (B, 256, 64, 64)
234 | multimask_output=False,
235 | )
236 | mask_predictions= F.interpolate(mask_predictions, size=(gt2D.shape[2],gt2D.shape[3]), mode='bilinear', align_corners=False)
237 | gt2D=gt2D.to(device)
238 | loss = seg_loss(mask_predictions, gt2D)
239 | optimizer.zero_grad()
240 | loss.backward()
241 | optimizer.step()
242 | ################################################################################################################################
243 | if if_onlytest is False:
244 | plt.plot(train_losses)
245 | plt.title('Train Loss')
246 | plt.xlabel('Epoch')
247 | plt.ylabel('train_loss')
248 | plt.show()
249 | plt.savefig(join(model_save_path, 'train_loss.png'))
250 | plt.close()
251 | plt.plot(val_losses)
252 | plt.title('Val Loss')
253 | plt.xlabel('Epoch')
254 | plt.ylabel('val_loss')
255 | plt.show()
256 | plt.savefig(join(model_save_path, 'val_loss.png'))
257 | plt.close()
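As a side note, the cosine schedule built by `get_lr_scheduler` can be inspected on its own. A minimal sketch, assuming the function above is in scope (for example, run at the bottom of this script):

```python
# With total_iters=10, the warmup covers min(max(0.05*10, 1), 3) = 1 epoch and the
# last min(max(0.05*10, 1), 15) = 1 epoch is pinned to min_lr; the rest follows a cosine.
sched = get_lr_scheduler('cos', lr=1e-4, min_lr=1e-6, total_iters=10)
for epoch in range(10):
    print(f'epoch {epoch}: lr = {sched(epoch):.2e}')
```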
--------------------------------------------------------------------------------
/utils/__init__.py:
--------------------------------------------------------------------------------
1 | #
--------------------------------------------------------------------------------
/utils/__pycache__/SurfaceDice.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cshan-github/TongueSAM/3f6e8c620e4d89e669a92a22f3be3d007932ce45/utils/__pycache__/SurfaceDice.cpython-39.pyc
--------------------------------------------------------------------------------
/utils/__pycache__/__init__.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cshan-github/TongueSAM/3f6e8c620e4d89e669a92a22f3be3d007932ce45/utils/__pycache__/__init__.cpython-310.pyc
--------------------------------------------------------------------------------
/utils/__pycache__/__init__.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cshan-github/TongueSAM/3f6e8c620e4d89e669a92a22f3be3d007932ce45/utils/__pycache__/__init__.cpython-39.pyc
--------------------------------------------------------------------------------
/utils/__pycache__/callbacks.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cshan-github/TongueSAM/3f6e8c620e4d89e669a92a22f3be3d007932ce45/utils/__pycache__/callbacks.cpython-310.pyc
--------------------------------------------------------------------------------
/utils/__pycache__/callbacks.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cshan-github/TongueSAM/3f6e8c620e4d89e669a92a22f3be3d007932ce45/utils/__pycache__/callbacks.cpython-39.pyc
--------------------------------------------------------------------------------
/utils/__pycache__/dataloader.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cshan-github/TongueSAM/3f6e8c620e4d89e669a92a22f3be3d007932ce45/utils/__pycache__/dataloader.cpython-310.pyc
--------------------------------------------------------------------------------
/utils/__pycache__/dataloader.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cshan-github/TongueSAM/3f6e8c620e4d89e669a92a22f3be3d007932ce45/utils/__pycache__/dataloader.cpython-39.pyc
--------------------------------------------------------------------------------
/utils/__pycache__/utils.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cshan-github/TongueSAM/3f6e8c620e4d89e669a92a22f3be3d007932ce45/utils/__pycache__/utils.cpython-310.pyc
--------------------------------------------------------------------------------
/utils/__pycache__/utils.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cshan-github/TongueSAM/3f6e8c620e4d89e669a92a22f3be3d007932ce45/utils/__pycache__/utils.cpython-39.pyc
--------------------------------------------------------------------------------
/utils/__pycache__/utils_bbox.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cshan-github/TongueSAM/3f6e8c620e4d89e669a92a22f3be3d007932ce45/utils/__pycache__/utils_bbox.cpython-310.pyc
--------------------------------------------------------------------------------
/utils/__pycache__/utils_bbox.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cshan-github/TongueSAM/3f6e8c620e4d89e669a92a22f3be3d007932ce45/utils/__pycache__/utils_bbox.cpython-39.pyc
--------------------------------------------------------------------------------
/utils/__pycache__/utils_fit.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cshan-github/TongueSAM/3f6e8c620e4d89e669a92a22f3be3d007932ce45/utils/__pycache__/utils_fit.cpython-310.pyc
--------------------------------------------------------------------------------
/utils/__pycache__/utils_fit.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cshan-github/TongueSAM/3f6e8c620e4d89e669a92a22f3be3d007932ce45/utils/__pycache__/utils_fit.cpython-39.pyc
--------------------------------------------------------------------------------
/utils/__pycache__/utils_map.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cshan-github/TongueSAM/3f6e8c620e4d89e669a92a22f3be3d007932ce45/utils/__pycache__/utils_map.cpython-310.pyc
--------------------------------------------------------------------------------
/utils/__pycache__/utils_map.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cshan-github/TongueSAM/3f6e8c620e4d89e669a92a22f3be3d007932ce45/utils/__pycache__/utils_map.cpython-39.pyc
--------------------------------------------------------------------------------
/utils/callbacks.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import torch
4 | import matplotlib
5 | matplotlib.use('Agg')
6 | import scipy.signal
7 | from matplotlib import pyplot as plt
8 | from torch.utils.tensorboard import SummaryWriter
9 |
10 | import shutil
11 | import numpy as np
12 |
13 | from PIL import Image
14 | from tqdm import tqdm
15 | from .utils import cvtColor, preprocess_input, resize_image
16 | from .utils_bbox import decode_outputs, non_max_suppression
17 | from .utils_map import get_coco_map, get_map
18 |
19 |
20 | class LossHistory():
21 | def __init__(self, log_dir, model, input_shape):
22 | self.log_dir = log_dir
23 | self.losses = []
24 | self.val_loss = []
25 |
26 | os.makedirs(self.log_dir)
27 | self.writer = SummaryWriter(self.log_dir)
28 | try:
29 | dummy_input = torch.randn(2, 3, input_shape[0], input_shape[1])
30 | self.writer.add_graph(model, dummy_input)
31 | except:
32 | pass
33 |
34 | def append_loss(self, epoch, loss, val_loss):
35 | if not os.path.exists(self.log_dir):
36 | os.makedirs(self.log_dir)
37 |
38 | self.losses.append(loss)
39 | self.val_loss.append(val_loss)
40 |
41 | with open(os.path.join(self.log_dir, "epoch_loss.txt"), 'a') as f:
42 | f.write(str(loss))
43 | f.write("\n")
44 | with open(os.path.join(self.log_dir, "epoch_val_loss.txt"), 'a') as f:
45 | f.write(str(val_loss))
46 | f.write("\n")
47 |
48 | self.writer.add_scalar('loss', loss, epoch)
49 | self.writer.add_scalar('val_loss', val_loss, epoch)
50 | self.loss_plot()
51 |
52 | def loss_plot(self):
53 | iters = range(len(self.losses))
54 |
55 | plt.figure()
56 | plt.plot(iters, self.losses, 'red', linewidth = 2, label='train loss')
57 | plt.plot(iters, self.val_loss, 'coral', linewidth = 2, label='val loss')
58 | try:
59 | if len(self.losses) < 25:
60 | num = 5
61 | else:
62 | num = 15
63 |
64 | plt.plot(iters, scipy.signal.savgol_filter(self.losses, num, 3), 'green', linestyle = '--', linewidth = 2, label='smooth train loss')
65 | plt.plot(iters, scipy.signal.savgol_filter(self.val_loss, num, 3), '#8B4513', linestyle = '--', linewidth = 2, label='smooth val loss')
66 | except:
67 | pass
68 |
69 | plt.grid(True)
70 | plt.xlabel('Epoch')
71 | plt.ylabel('Loss')
72 | plt.legend(loc="upper right")
73 |
74 | plt.savefig(os.path.join(self.log_dir, "epoch_loss.png"))
75 |
76 | plt.cla()
77 | plt.close("all")
78 |
79 | class EvalCallback():
80 | def __init__(self, net, input_shape, class_names, num_classes, val_lines, log_dir, cuda, \
81 | map_out_path=".temp_map_out", max_boxes=100, confidence=0.05, nms_iou=0.5, letterbox_image=True, MINOVERLAP=0.5, eval_flag=True, period=1):
82 | super(EvalCallback, self).__init__()
83 |
84 | self.net = net
85 | self.input_shape = input_shape
86 | self.class_names = class_names
87 | self.num_classes = num_classes
88 | self.val_lines = val_lines
89 | self.log_dir = log_dir
90 | self.cuda = cuda
91 | self.map_out_path = map_out_path
92 | self.max_boxes = max_boxes
93 | self.confidence = confidence
94 | self.nms_iou = nms_iou
95 | self.letterbox_image = letterbox_image
96 | self.MINOVERLAP = MINOVERLAP
97 | self.eval_flag = eval_flag
98 | self.period = period
99 |
100 | self.maps = [0]
101 | self.epoches = [0]
102 | if self.eval_flag:
103 | with open(os.path.join(self.log_dir, "epoch_map.txt"), 'a') as f:
104 | f.write(str(0))
105 | f.write("\n")
106 |
107 | def get_map_txt(self, image_id, image, class_names, map_out_path):
108 | f = open(os.path.join(map_out_path, "detection-results/"+image_id+".txt"),"w")
109 | image_shape = np.array(np.shape(image)[0:2])
110 | #---------------------------------------------------------#
111 | #   Convert the image to RGB to avoid errors when predicting on grayscale images.
112 | #   The code only supports prediction on RGB images; all other image types are converted to RGB.
113 | #---------------------------------------------------------#
114 | image = cvtColor(image)
115 | #---------------------------------------------------------#
116 | #   Pad the image with gray bars to resize without distortion.
117 | #   A plain resize could also be used for detection.
118 | #---------------------------------------------------------#
119 | image_data = resize_image(image, (self.input_shape[1],self.input_shape[0]), self.letterbox_image)
120 | #---------------------------------------------------------#
121 | #   Add the batch_size dimension
122 | #---------------------------------------------------------#
123 | image_data = np.expand_dims(np.transpose(preprocess_input(np.array(image_data, dtype='float32')), (2, 0, 1)), 0)
124 |
125 | with torch.no_grad():
126 | images = torch.from_numpy(image_data)
127 | if self.cuda:
128 | images = images.cuda()
129 | #---------------------------------------------------------#
130 | #   Feed the image into the network for prediction!
131 | #---------------------------------------------------------#
132 | outputs = self.net(images)
133 | outputs = decode_outputs(outputs, self.input_shape)
134 | #---------------------------------------------------------#
135 | #   Stack the prediction boxes, then apply non-maximum suppression
136 | #---------------------------------------------------------#
137 | results = non_max_suppression(outputs, self.num_classes, self.input_shape,
138 | image_shape, self.letterbox_image, conf_thres = self.confidence, nms_thres = self.nms_iou)
139 |
140 | if results[0] is None:
141 | return
142 |
143 | top_label = np.array(results[0][:, 6], dtype = 'int32')
144 | top_conf = results[0][:, 4] * results[0][:, 5]
145 | top_boxes = results[0][:, :4]
146 |
147 | top_100 = np.argsort(top_conf)[::-1][:self.max_boxes]
148 | top_boxes = top_boxes[top_100]
149 | top_conf = top_conf[top_100]
150 | top_label = top_label[top_100]
151 |
152 | for i, c in list(enumerate(top_label)):
153 | predicted_class = self.class_names[int(c)]
154 | box = top_boxes[i]
155 | score = str(top_conf[i])
156 |
157 | top, left, bottom, right = box
158 | if predicted_class not in class_names:
159 | continue
160 |
161 | f.write("%s %s %s %s %s %s\n" % (predicted_class, score[:6], str(int(left)), str(int(top)), str(int(right)),str(int(bottom))))
162 |
163 | f.close()
164 | return
165 |
166 | def on_epoch_end(self, epoch, model_eval):
167 | if epoch % self.period == 0 and self.eval_flag:
168 | self.net = model_eval
169 | if not os.path.exists(self.map_out_path):
170 | os.makedirs(self.map_out_path)
171 | if not os.path.exists(os.path.join(self.map_out_path, "ground-truth")):
172 | os.makedirs(os.path.join(self.map_out_path, "ground-truth"))
173 | if not os.path.exists(os.path.join(self.map_out_path, "detection-results")):
174 | os.makedirs(os.path.join(self.map_out_path, "detection-results"))
175 | print("Get map.")
176 | for annotation_line in tqdm(self.val_lines):
177 | line = annotation_line.split()
178 | image_id = os.path.basename(line[0]).split('.')[0]
179 | #------------------------------#
180 | #   Read the image and convert it to RGB
181 | #------------------------------#
182 | image = Image.open(line[0])
183 | #------------------------------#
184 | #   Obtain the ground-truth boxes
185 | #------------------------------#
186 | gt_boxes = np.array([np.array(list(map(int,box.split(',')))) for box in line[1:]])
187 | #------------------------------#
188 | #   Generate the prediction-results txt
189 | #------------------------------#
190 | self.get_map_txt(image_id, image, self.class_names, self.map_out_path)
191 |
192 | #------------------------------#
193 | #   Write the ground-truth txt
194 | #------------------------------#
195 | with open(os.path.join(self.map_out_path, "ground-truth/"+image_id+".txt"), "w") as new_f:
196 | for box in gt_boxes:
197 | left, top, right, bottom, obj = box
198 | obj_name = self.class_names[obj]
199 | new_f.write("%s %s %s %s %s\n" % (obj_name, left, top, right, bottom))
200 |
201 | print("Calculate Map.")
202 | try:
203 | temp_map = get_coco_map(class_names = self.class_names, path = self.map_out_path)[1]
204 | except:
205 | temp_map = get_map(self.MINOVERLAP, False, path = self.map_out_path)
206 | self.maps.append(temp_map)
207 | self.epoches.append(epoch)
208 |
209 | with open(os.path.join(self.log_dir, "epoch_map.txt"), 'a') as f:
210 | f.write(str(temp_map))
211 | f.write("\n")
212 |
213 | plt.figure()
214 | plt.plot(self.epoches, self.maps, 'red', linewidth = 2, label='train map')
215 |
216 | plt.grid(True)
217 | plt.xlabel('Epoch')
218 | plt.ylabel('Map %s'%str(self.MINOVERLAP))
219 | plt.title('mAP Curve')
220 | plt.legend(loc="upper right")
221 |
222 | plt.savefig(os.path.join(self.log_dir, "epoch_map.png"))
223 | plt.cla()
224 | plt.close("all")
225 |
226 | print("Get map done.")
227 | shutil.rmtree(self.map_out_path)
228 |
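A hypothetical sketch of how LossHistory is driven from a training loop (the log directory and loss values are made up; the constructor calls os.makedirs without exist_ok, so the directory must not exist yet, and the repo root is assumed to be on PYTHONPATH):

import torch.nn as nn
from utils.callbacks import LossHistory

model = nn.Conv2d(3, 1, 3)                          # any nn.Module works for the optional graph dump
history = LossHistory('logs/demo_run', model, input_shape=[640, 640])
for epoch, (tr, va) in enumerate([(0.9, 1.0), (0.7, 0.8), (0.6, 0.75)]):
    history.append_loss(epoch + 1, tr, va)          # writes txt logs, TensorBoard scalars and epoch_loss.png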
--------------------------------------------------------------------------------
/utils/dataset.py:
--------------------------------------------------------------------------------
1 | import os
2 | import pandas as pd
3 | import numpy as np
4 | import cv2
5 | from typing import Any, Tuple
6 | import torch
7 | from torch.utils.data import Dataset
8 | import torchvision.transforms.functional as TF
9 |
10 |
11 | class MedSamDataset(Dataset):
12 | def __init__(
13 | self,
14 | df: pd.DataFrame,
15 | image_col: str,
16 | mask_col: str,
17 | image_dir: Any = None,
18 | mask_dir: str = None,
19 | image_size: Tuple = (256, 256),
20 | ):
21 | """
22 | PyTorch dataset class for loading image, mask and bbox triplets from a dataframe.
23 | The dataframe needs at least two columns holding the image and mask file names. The columns can contain either the full or relative
24 | paths of the images or just the file names.
25 | If only file names are given in the columns, the `image_dir` and `mask_dir` arguments should be specified.
26 |
27 | Args:
28 | df (pd.DataFrame): the pandas dataframe object
29 | image_col (str): the name of the column on the dataframe that holds the image file names.
30 | mask_col (str): the name of the column on the dataframe that holds the mask file names.
31 | image_dir (Any, optional): Path to the input image directory. Defaults to None.
32 | mask_dir (str, optional): Path to the mask images directory. Defaults to None.
33 | image_size (Tuple, optional): image size. Defaults to (256, 256).
34 | """
35 | self.df = df
36 | self.image_dir = image_dir
37 | self.mask_dir = mask_dir
38 | self.image_col = image_col
39 | self.mask_col = mask_col
40 | self.image_size = image_size
41 |
42 | def __len__(self):
43 | return len(self.df)
44 |
45 | def __getitem__(self, idx) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
46 | # read dataframe row
47 | row = self.df.iloc[idx]
48 | # If the `image_dir` attribute is set, the path will be relative to that directory.
49 | # Otherwise, the path will be the value of the `row[self.image_col]` attribute.
50 | image_file = (
51 | os.path.join(self.image_dir, row[self.image_col])
52 | if self.image_dir
53 | else row[self.image_col]
54 | )
55 | mask_file = (
56 | os.path.join(self.mask_dir, row[self.mask_col])
57 | if self.mask_dir
58 | else row[self.mask_col]
59 | )
60 |
61 | if not os.path.exists(image_file):
62 | raise FileNotFoundError(f"Couldn't find image {image_file}")
63 | if not os.path.exists(mask_file):
64 | raise FileNotFoundError(f"Couldn't find mask {mask_file}")
65 |
66 | # read image and mask files
67 | image_data = cv2.imread(image_file)
68 | # read mask as gray scale
69 | mask_data = cv2.imread(mask_file, cv2.IMREAD_GRAYSCALE)
70 |
71 | return self._preprocess(image_data, mask_data)
72 |
73 | def _preprocess(
74 | self, image: np.ndarray, mask: np.ndarray
75 | ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
76 | # Threshold mask to binary
77 | mask = cv2.threshold(mask, 127.0, 255.0, cv2.THRESH_BINARY)[1]
78 | # convert to tensor
79 | image = TF.to_tensor(image)
80 | mask = TF.to_tensor(mask)
81 | # min-max normalize and scale
82 | image = (image - image.min()) / (image.max() - image.min()) * 255.0
83 | # resize
84 | image = TF.resize(image, self.image_size, antialias=True)
85 | mask = TF.resize(mask, self.image_size, antialias=True)
86 |
87 | bbox = self._get_bbox(mask)
88 |
89 | return image, mask, bbox
90 |
91 | def _get_bbox(self, mask: torch.Tensor) -> torch.Tensor:
92 | _, y_indices, x_indices = torch.where(mask > 0)
93 |
94 | x_min, y_min = (x_indices.min(), y_indices.min())
95 | x_max, y_max = (x_indices.max(), y_indices.max())
96 |
97 | # add random perturbation to the bounding box coordinates
98 | H, W = mask.shape[1:]
99 | # the resized mask is expected to be square
100 | assert H == W, f"{W} and {H} are not of equal size!"
101 | x_min = max(0, x_min - np.random.randint(0, 10))
102 | x_max = min(W, x_max + np.random.randint(0, 10))
103 | y_min = max(0, y_min - np.random.randint(0, 10))
104 | y_max = min(H, y_max + np.random.randint(0, 10))
105 |
106 | return torch.tensor([x_min, y_min, x_max, y_max])
107 |
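A hypothetical usage sketch of MedSamDataset (the dataframe contents and directories are invented for illustration):

import pandas as pd
from torch.utils.data import DataLoader
from utils.dataset import MedSamDataset

df = pd.DataFrame({'image': ['1.jpg', '2.jpg'], 'mask': ['1.png', '2.png']})
dataset = MedSamDataset(df, image_col='image', mask_col='mask',
                        image_dir='data/images', mask_dir='data/masks',
                        image_size=(256, 256))
image, mask, bbox = dataset[0]                      # (3, 256, 256), (1, 256, 256), (4,)
loader = DataLoader(dataset, batch_size=2, shuffle=True)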
--------------------------------------------------------------------------------
/utils/precompute_img_embed.py:
--------------------------------------------------------------------------------
1 | #%% import packages
2 | # precompute image embeddings and save them to disk for model training
3 |
4 | import numpy as np
5 | import os
6 | join = os.path.join
7 | from skimage import io, segmentation
8 | from tqdm import tqdm
9 | import torch
10 | from segment_anything import sam_model_registry
11 | from segment_anything.utils.transforms import ResizeLongestSide
12 | import argparse
13 |
14 | #%% parse arguments
15 | parser = argparse.ArgumentParser()
16 | parser.add_argument('-i', '--img_path', type=str, default='./data/Tr_Release_Part1', help='path to the training npz files (run again on Tr_Release_Part2 once part 1 is done)')
17 | parser.add_argument('-o', '--save_path', type=str, default='./data/Tr_npy', help='path to save the image embeddings')
18 | parser.add_argument('--model_type', type=str, default='vit_b', help='model type')
19 | parser.add_argument('--checkpoint', type=str, default='../work_dir/SAM/sam_vit_b_01ec64.pth', help='path to the pre-trained SAM model')
20 | args = parser.parse_args()
21 |
22 | pre_img_path = args.img_path
23 | save_img_emb_path = join(args.save_path, 'npy_embs')
24 | save_gt_path = join(args.save_path, 'npy_gts')
25 | os.makedirs(save_img_emb_path, exist_ok=True)
26 | os.makedirs(save_gt_path, exist_ok=True)
27 | npz_files = sorted(os.listdir(pre_img_path))
28 | #%% set up the model
29 | sam_model = sam_model_registry[args.model_type](checkpoint=args.checkpoint).to('cuda:0')
30 | sam_transform = ResizeLongestSide(sam_model.image_encoder.img_size)
31 |
32 | # compute image embeddings
33 | for name in tqdm(npz_files):
34 | img = np.load(join(pre_img_path, name))['img'] # (256, 256, 3)
35 | gt = np.load(join(pre_img_path, name))['gt']
36 | resize_img = sam_transform.apply_image(img)
37 | resize_img_tensor = torch.as_tensor(resize_img.transpose(2, 0, 1)).to('cuda:0')
38 | # model input: (1, 3, 1024, 1024)
39 | input_image = sam_model.preprocess(resize_img_tensor[None,:,:,:]) # (1, 3, 1024, 1024)
40 | assert input_image.shape == (1, 3, sam_model.image_encoder.img_size, sam_model.image_encoder.img_size), 'input image should be resized to 1024*1024'
41 | with torch.no_grad():
42 | embedding = sam_model.image_encoder(input_image)
43 |
44 | # save as npy
45 | np.save(join(save_img_emb_path, name.split('.npz')[0]+'.npy'), embedding.cpu().numpy()[0])
46 | np.save(join(save_gt_path, name.split('.npz')[0]+'.npy'), gt)
47 | # sanity check
48 | img_idx = img.copy()
49 | bd = segmentation.find_boundaries(gt, mode='inner')
50 | img_idx[bd, :] = [255, 0, 0]
51 | io.imsave(save_img_emb_path + '.png', img_idx)
52 |
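A short sketch of reading one precomputed pair back (the file name is hypothetical; the embedding shape assumes the ViT-B image encoder, whose neck outputs 256 channels on a 64x64 grid for a 1024-pixel input):

import numpy as np
from os.path import join

save_path = './data/Tr_npy'
emb = np.load(join(save_path, 'npy_embs', 'case_0001.npy'))   # expected shape (256, 64, 64)
gt = np.load(join(save_path, 'npy_gts', 'case_0001.npy'))     # ground-truth mask, e.g. (256, 256)
print(emb.shape, gt.shape)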
--------------------------------------------------------------------------------
/utils/utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from PIL import Image
3 |
4 |
5 | #---------------------------------------------------------#
6 | #   Convert the image to RGB to avoid errors when predicting on grayscale images.
7 | #   The code only supports prediction on RGB images; all other image types are converted to RGB.
8 | #---------------------------------------------------------#
9 | def cvtColor(image):
10 | if len(np.shape(image)) == 3 and np.shape(image)[2] == 3:
11 | return image
12 | else:
13 | image = image.convert('RGB')
14 | return image
15 |
16 | #---------------------------------------------------#
17 | #   Resize the input image
18 | #---------------------------------------------------#
19 | def resize_image(image, size, letterbox_image):
20 | iw, ih = image.size
21 | w, h = size
22 | if letterbox_image:
23 | scale = min(w/iw, h/ih)
24 | nw = int(iw*scale)
25 | nh = int(ih*scale)
26 |
27 | image = image.resize((nw,nh), Image.BICUBIC)
28 | new_image = Image.new('RGB', size, (128,128,128))
29 | new_image.paste(image, ((w-nw)//2, (h-nh)//2))
30 | else:
31 | new_image = image.resize((w, h), Image.BICUBIC)
32 | return new_image
33 |
34 | #---------------------------------------------------#
35 | #   Get the classes
36 | #---------------------------------------------------#
37 | def get_classes(classes_path):
38 | with open(classes_path, encoding='utf-8') as f:
39 | class_names = f.readlines()
40 | class_names = [c.strip() for c in class_names]
41 | return class_names, len(class_names)
42 |
43 | def preprocess_input(image):
44 | image /= 255.0
45 | image -= np.array([0.485, 0.456, 0.406])
46 | image /= np.array([0.229, 0.224, 0.225])
47 | return image
48 |
49 | #---------------------------------------------------#
50 | #   Get the learning rate
51 | #---------------------------------------------------#
52 | def get_lr(optimizer):
53 | for param_group in optimizer.param_groups:
54 | return param_group['lr']
55 |
56 | def show_config(**kwargs):
57 | print('Configurations:')
58 | print('-' * 70)
59 | print('|%25s | %40s|' % ('keys', 'values'))
60 | print('-' * 70)
61 | for key, value in kwargs.items():
62 | print('|%25s | %40s|' % (str(key), str(value)))
63 | print('-' * 70)
64 |
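A minimal sketch of the preprocessing pipeline these helpers implement, mirroring what EvalCallback.get_map_txt does before inference (the image path is hypothetical, and (640, 640) merely stands in for the detector's input_shape):

import numpy as np
from PIL import Image
from utils.utils import cvtColor, resize_image, preprocess_input

image = Image.open('some_tongue_image.jpg')
image = cvtColor(image)                                         # force RGB
image = resize_image(image, (640, 640), letterbox_image=True)   # distortion-free letterbox resize
x = preprocess_input(np.array(image, dtype='float32'))          # scale to [0, 1], normalize with ImageNet stats
x = np.expand_dims(np.transpose(x, (2, 0, 1)), 0)               # HWC -> 1 x C x H x W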
--------------------------------------------------------------------------------
/utils/utils_bbox.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | from torchvision.ops import nms, boxes
4 |
5 | def yolo_correct_boxes(box_xy, box_wh, input_shape, image_shape, letterbox_image):
6 | #-----------------------------------------------------------------#
7 | #   Put the y axis first so the boxes can be conveniently multiplied by the image height and width
8 | #-----------------------------------------------------------------#
9 | box_yx = box_xy[..., ::-1]
10 | box_hw = box_wh[..., ::-1]
11 | input_shape = np.array(input_shape)
12 | image_shape = np.array(image_shape)
13 |
14 | if letterbox_image:
15 | #-----------------------------------------------------------------#
16 | #   The offset here is the shift of the valid image area relative to the top-left corner of the image
17 | #   new_shape is the scaled height and width
18 | #-----------------------------------------------------------------#
19 | new_shape = np.round(image_shape * np.min(input_shape/image_shape))
20 | offset = (input_shape - new_shape)/2./input_shape
21 | scale = input_shape/new_shape
22 |
23 | box_yx = (box_yx - offset) * scale
24 | box_hw *= scale
25 |
26 | box_mins = box_yx - (box_hw / 2.)
27 | box_maxes = box_yx + (box_hw / 2.)
28 | boxes = np.concatenate([box_mins[..., 0:1], box_mins[..., 1:2], box_maxes[..., 0:1], box_maxes[..., 1:2]], axis=-1)
29 | boxes *= np.concatenate([image_shape, image_shape], axis=-1)
30 | return boxes
31 |
32 | def decode_outputs(outputs, input_shape):
33 | grids = []
34 | strides = []
35 | hw = [x.shape[-2:] for x in outputs]
36 | #---------------------------------------------------#
37 | #   Before this step, outputs holds the prediction of each feature level
38 | #   batch_size, 4 + 1 + num_classes, 80, 80 => batch_size, 4 + 1 + num_classes, 6400
39 | #   batch_size, 5 + num_classes, 40, 40
40 | #   batch_size, 5 + num_classes, 20, 20
41 | #   batch_size, 4 + 1 + num_classes, 6400 + 1600 + 400 -> batch_size, 4 + 1 + num_classes, 8400
42 | #   After stacking: batch_size, 8400, 5 + num_classes
43 | #---------------------------------------------------#
44 | outputs = torch.cat([x.flatten(start_dim=2) for x in outputs], dim=2).permute(0, 2, 1)
45 | #---------------------------------------------------#
46 | #   Get the probability of each feature point belonging to each class
47 | #---------------------------------------------------#
48 | outputs[:, :, 4:] = torch.sigmoid(outputs[:, :, 4:])
49 | for h, w in hw:
50 | #---------------------------#
51 | #   Generate grid points from the feature map height and width
52 | #---------------------------#
53 | grid_y, grid_x = torch.meshgrid([torch.arange(h), torch.arange(w)])
54 | #---------------------------#
55 | # 1, 6400, 2
56 | # 1, 1600, 2
57 | # 1, 400, 2
58 | #---------------------------#
59 | grid = torch.stack((grid_x, grid_y), 2).view(1, -1, 2)
60 | shape = grid.shape[:2]
61 |
62 | grids.append(grid)
63 | strides.append(torch.full((shape[0], shape[1], 1), input_shape[0] / h))
64 | #---------------------------#
65 | #   Stack the grid points together
66 | #   1, 6400, 2
67 | #   1, 1600, 2
68 | #   1, 400, 2
69 | #
70 | #   1, 8400, 2
71 | #---------------------------#
72 | grids = torch.cat(grids, dim=1).type(outputs.type())
73 | strides = torch.cat(strides, dim=1).type(outputs.type())
74 | #------------------------#
75 | #   Decode using the grid points
76 | #------------------------#
77 | outputs[..., :2] = (outputs[..., :2] + grids) * strides
78 | outputs[..., 2:4] = torch.exp(outputs[..., 2:4]) * strides
79 | #-----------------#
80 | #   Normalize
81 | #-----------------#
82 | outputs[..., [0,2]] = outputs[..., [0,2]] / input_shape[1]
83 | outputs[..., [1,3]] = outputs[..., [1,3]] / input_shape[0]
84 | return outputs
85 |
86 | def non_max_suppression(prediction, num_classes, input_shape, image_shape, letterbox_image, conf_thres=0.5, nms_thres=0.4):
87 | #----------------------------------------------------------#
88 | #   Convert the predictions to top-left / bottom-right corner format.
89 | #   prediction  [batch_size, num_anchors, 85]
90 | #----------------------------------------------------------#
91 | box_corner = prediction.new(prediction.shape)
92 | box_corner[:, :, 0] = prediction[:, :, 0] - prediction[:, :, 2] / 2
93 | box_corner[:, :, 1] = prediction[:, :, 1] - prediction[:, :, 3] / 2
94 | box_corner[:, :, 2] = prediction[:, :, 0] + prediction[:, :, 2] / 2
95 | box_corner[:, :, 3] = prediction[:, :, 1] + prediction[:, :, 3] / 2
96 | prediction[:, :, :4] = box_corner[:, :, :4]
97 |
98 | output = [None for _ in range(len(prediction))]
99 | #----------------------------------------------------------#
100 | #   Loop over the input images; usually this runs only once
101 | #----------------------------------------------------------#
102 | for i, image_pred in enumerate(prediction):
103 |
104 | #----------------------------------------------------------#
105 | #   Take the max over the class predictions.
106 | #   class_conf  [num_anchors, 1]    class confidence
107 | #   class_pred  [num_anchors, 1]    class index
108 | #----------------------------------------------------------#
109 | class_conf, class_pred = torch.max(image_pred[:, 5:5 + num_classes], 1, keepdim=True)
110 | # for i in class_pred:
111 | #     if class_pred[i] != 0:
112 | #         print(1)
113 | #----------------------------------------------------------#
114 | #   First round of filtering using the confidence threshold
115 | #----------------------------------------------------------#
116 | conf_mask = (image_pred[:, 4] * class_conf[:, 0] >= conf_thres).squeeze()
117 | if not image_pred.size(0):
118 | continue
119 | #-------------------------------------------------------------------------#
120 | #   detections  [num_anchors, 7]
121 | #   The 7 values are: x1, y1, x2, y2, obj_conf, class_conf, class_pred
122 | #-------------------------------------------------------------------------#
123 | detections = torch.cat((image_pred[:, :5], class_conf, class_pred.float()), 1)
124 | detections = detections[conf_mask]
125 | nms_out_index = boxes.batched_nms(
126 | detections[:, :4],
127 | detections[:, 4] * detections[:, 5],
128 | detections[:, 6],
129 | nms_thres,
130 | )
131 |
132 | output[i] = detections[nms_out_index]
133 |
134 | # #------------------------------------------#
135 | # #   Get all classes present in the predictions
136 | # #------------------------------------------#
137 | # unique_labels = detections[:, -1].cpu().unique()
138 |
139 | # if prediction.is_cuda:
140 | # unique_labels = unique_labels.cuda()
141 | # detections = detections.cuda()
142 |
143 | # for c in unique_labels:
144 | # #------------------------------------------#
145 | # #   Get all predictions of this class after the score filtering
146 | # #------------------------------------------#
147 | # detections_class = detections[detections[:, -1] == c]
148 |
149 | # #------------------------------------------#
150 | # #   Using the built-in non-maximum suppression is a bit faster!
151 | # #------------------------------------------#
152 | # keep = nms(
153 | # detections_class[:, :4],
154 | # detections_class[:, 4] * detections_class[:, 5],
155 | # nms_thres
156 | # )
157 | # max_detections = detections_class[keep]
158 |
159 | # # # Sort by objectness confidence
160 | # # _, conf_sort_index = torch.sort(detections_class[:, 4]*detections_class[:, 5], descending=True)
161 | # # detections_class = detections_class[conf_sort_index]
162 | # # # Apply non-maximum suppression
163 | # # max_detections = []
164 | # # while detections_class.size(0):
165 | # # # Take the highest-confidence box of this class, then check the remaining boxes one by one and drop those whose overlap exceeds nms_thres
166 | # # max_detections.append(detections_class[0].unsqueeze(0))
167 | # # if len(detections_class) == 1:
168 | # # break
169 | # # ious = bbox_iou(max_detections[-1], detections_class[1:])
170 | # # detections_class = detections_class[1:][ious < nms_thres]
171 | # # # Stack the kept boxes
172 | # # max_detections = torch.cat(max_detections).data
173 |
174 | # # Add max detections to outputs
175 | # output[i] = max_detections if output[i] is None else torch.cat((output[i], max_detections))
176 |
177 | if output[i] is not None:
178 | output[i] = output[i].cpu().numpy()
179 | box_xy, box_wh = (output[i][:, 0:2] + output[i][:, 2:4])/2, output[i][:, 2:4] - output[i][:, 0:2]
180 | output[i][:, :4] = yolo_correct_boxes(box_xy, box_wh, input_shape, image_shape, letterbox_image)
181 | return output
182 |
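A hedged sketch of how decode_outputs and non_max_suppression are chained (shapes, input size and thresholds are illustrative; in this repo the same chain runs inside EvalCallback.get_map_txt):

import torch
from utils.utils_bbox import decode_outputs, non_max_suppression

input_shape = (640, 640)
image_shape = (480, 360)                      # original image height, width
num_classes = 1                               # a single "tongue" class
# fake multi-scale head outputs: (batch, 5 + num_classes, H, W) per feature level
outputs = [torch.randn(1, 5 + num_classes, s, s) for s in (80, 40, 20)]
decoded = decode_outputs(outputs, input_shape)
results = non_max_suppression(decoded, num_classes, input_shape, image_shape,
                              letterbox_image=True, conf_thres=0.05, nms_thres=0.5)
print(None if results[0] is None else results[0].shape)   # (num_kept_boxes, 7) or None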
--------------------------------------------------------------------------------
/utils/utils_fit.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import torch
4 | from tqdm import tqdm
5 |
6 | from utils.utils import get_lr
7 |
8 |
9 | def fit_one_epoch(model_train, model, ema, yolo_loss, loss_history, eval_callback, optimizer, epoch, epoch_step, epoch_step_val, gen, gen_val, Epoch, cuda, fp16, scaler, save_period, save_dir, local_rank=0):
10 | loss = 0
11 | val_loss = 0
12 |
13 | if local_rank == 0:
14 | print('Start Train')
15 | pbar = tqdm(total=epoch_step,desc=f'Epoch {epoch + 1}/{Epoch}',postfix=dict,mininterval=0.3)
16 | model_train.train()
17 | for iteration, batch in enumerate(gen):
18 | if iteration >= epoch_step:
19 | break
20 |
21 | images, targets = batch[0], batch[1]
22 | with torch.no_grad():
23 | if cuda:
24 | images = images.cuda(local_rank)
25 | targets = [ann.cuda(local_rank) for ann in targets]
26 | #----------------------#
27 | #   Zero the gradients
28 | #----------------------#
29 | optimizer.zero_grad()
30 | if not fp16:
31 | #----------------------#
32 | #   Forward pass
33 | #----------------------#
34 | outputs = model_train(images)
35 |
36 | #----------------------#
37 | #   Compute the loss
38 | #----------------------#
39 | loss_value = yolo_loss(outputs, targets)
40 |
41 | #----------------------#
42 | #   Backward pass
43 | #----------------------#
44 | loss_value.backward()
45 | optimizer.step()
46 | else:
47 | from torch.cuda.amp import autocast
48 | with autocast():
49 | outputs = model_train(images)
50 | #----------------------#
51 | #   Compute the loss
52 | #----------------------#
53 | loss_value = yolo_loss(outputs, targets)
54 |
55 | #----------------------#
56 | #   Backward pass
57 | #----------------------#
58 | scaler.scale(loss_value).backward()
59 | scaler.step(optimizer)
60 | scaler.update()
61 | if ema:
62 | ema.update(model_train)
63 |
64 | loss += loss_value.item()
65 |
66 | if local_rank == 0:
67 | pbar.set_postfix(**{'loss' : loss / (iteration + 1),
68 | 'lr' : get_lr(optimizer)})
69 | pbar.update(1)
70 |
71 | if local_rank == 0:
72 | pbar.close()
73 | print('Finish Train')
74 | print('Start Validation')
75 | pbar = tqdm(total=epoch_step_val, desc=f'Epoch {epoch + 1}/{Epoch}',postfix=dict,mininterval=0.3)
76 |
77 | if ema:
78 | model_train_eval = ema.ema
79 | else:
80 | model_train_eval = model_train.eval()
81 |
82 | for iteration, batch in enumerate(gen_val):
83 | if iteration >= epoch_step_val:
84 | break
85 | images, targets = batch[0], batch[1]
86 | with torch.no_grad():
87 | if cuda:
88 | images = images.cuda(local_rank)
89 | targets = [ann.cuda(local_rank) for ann in targets]
90 | #----------------------#
91 | #   Zero the gradients
92 | #----------------------#
93 | optimizer.zero_grad()
94 | #----------------------#
95 | #   Forward pass
96 | #----------------------#
97 | outputs = model_train_eval(images)
98 |
99 | #----------------------#
100 | #   Compute the loss
101 | #----------------------#
102 | loss_value = yolo_loss(outputs, targets)
103 |
104 | val_loss += loss_value.item()
105 | if local_rank == 0:
106 | pbar.set_postfix(**{'val_loss': val_loss / (iteration + 1)})
107 | pbar.update(1)
108 |
109 | if local_rank == 0:
110 | pbar.close()
111 | print('Finish Validation')
112 | loss_history.append_loss(epoch + 1, loss / epoch_step, val_loss / epoch_step_val)
113 | eval_callback.on_epoch_end(epoch + 1, model_train_eval)
114 | print('Epoch:'+ str(epoch + 1) + '/' + str(Epoch))
115 | print('Total Loss: %.3f || Val Loss: %.3f ' % (loss / epoch_step, val_loss / epoch_step_val))
116 |
117 | #-----------------------------------------------#
118 | #   Save the weights
119 | #-----------------------------------------------#
120 | if ema:
121 | save_state_dict = ema.ema.state_dict()
122 | else:
123 | save_state_dict = model.state_dict()
124 |
125 | if (epoch + 1) % save_period == 0 or epoch + 1 == Epoch:
126 | torch.save(save_state_dict, os.path.join(save_dir, "ep%03d-loss%.3f-val_loss%.3f.pth" % (epoch + 1, loss / epoch_step, val_loss / epoch_step_val)))
127 |
128 | if len(loss_history.val_loss) <= 1 or (val_loss / epoch_step_val) <= min(loss_history.val_loss):
129 | print('Save best model to best_epoch_weights.pth')
130 | torch.save(save_state_dict, os.path.join(save_dir, "best_epoch_weights.pth"))
131 |
132 | torch.save(save_state_dict, os.path.join(save_dir, "last_epoch_weights.pth"))
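The fp16 branch above follows the standard torch.cuda.amp recipe; a stripped-down sketch of that pattern with placeholder model and data (requires a CUDA device):

import torch
from torch.cuda.amp import GradScaler, autocast

model = torch.nn.Linear(10, 1).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)
scaler = GradScaler()
x, y = torch.randn(4, 10).cuda(), torch.randn(4, 1).cuda()

optimizer.zero_grad()
with autocast():                                  # run the forward pass in mixed precision
    loss = torch.nn.functional.mse_loss(model(x), y)
scaler.scale(loss).backward()                     # scale the loss to avoid fp16 gradient underflow
scaler.step(optimizer)                            # unscales gradients, then steps the optimizer
scaler.update()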
--------------------------------------------------------------------------------
/utils_metrics.py:
--------------------------------------------------------------------------------
1 | import csv
2 | import os
3 | from os.path import join
4 |
5 | import matplotlib.pyplot as plt
6 | import numpy as np
7 | import torch
8 | import torch.nn.functional as F
9 | from PIL import Image
10 |
11 |
12 | def f_score(inputs, target, beta=1, smooth = 1e-5, threshold = 0.5):
13 | n, c, h, w = inputs.size()
14 | nt, ht, wt, ct = target.size()
15 | if h != ht and w != wt:
16 | inputs = F.interpolate(inputs, size=(ht, wt), mode="bilinear", align_corners=True)
17 |
18 | temp_inputs = torch.softmax(inputs.transpose(1, 2).transpose(2, 3).contiguous().view(n, -1, c),-1)
19 | temp_target = target.view(n, -1, ct)
20 |
21 | #--------------------------------------------#
22 | #   Compute the Dice coefficient
23 | #--------------------------------------------#
24 | temp_inputs = torch.gt(temp_inputs, threshold).float()
25 | tp = torch.sum(temp_target[...,:-1] * temp_inputs, axis=[0,1])
26 | fp = torch.sum(temp_inputs , axis=[0,1]) - tp
27 | fn = torch.sum(temp_target[...,:-1] , axis=[0,1]) - tp
28 |
29 | score = ((1 + beta ** 2) * tp + smooth) / ((1 + beta ** 2) * tp + beta ** 2 * fn + fp + smooth)
30 | score = torch.mean(score)
31 | return score
32 |
33 | # Let the label have width W and height H
34 | def fast_hist(a, b, n):
35 |
36 | #--------------------------------------------------------------------------------#
37 | #   a is the label flattened to a 1-D array, shape (H×W,); b is the prediction flattened to a 1-D array, shape (H×W,)
38 | #--------------------------------------------------------------------------------#
39 | k = (a >= 0) & (a < n)
40 | #--------------------------------------------------------------------------------#
41 | #   np.bincount counts how often each of the n**2 values from 0 to n**2-1 occurs; the result is reshaped to (n, n)
42 | #   In the returned matrix, the diagonal entries are the correctly classified pixels
43 | #--------------------------------------------------------------------------------#
44 | return np.bincount(n * a[k].astype(int) + b[k], minlength=n ** 2).reshape(n, n)
45 |
46 | def per_class_iu(hist):
47 | return np.diag(hist) / np.maximum((hist.sum(1) + hist.sum(0) - np.diag(hist)), 1)
48 |
49 | def per_class_PA_Recall(hist):
50 | return np.diag(hist) / np.maximum(hist.sum(1), 1)
51 |
52 | def per_class_Precision(hist):
53 | return np.diag(hist) / np.maximum(hist.sum(0), 1)
54 |
55 | def per_Accuracy(hist):
56 | return np.sum(np.diag(hist)) / np.maximum(np.sum(hist), 1)
57 |
58 | def compute_mIoU(gts, preds, num_classes=2,name_classes= ["background","tongue"]):
59 | hist = np.zeros((num_classes, num_classes))
60 | for ind in range(len(gts)):
61 | pred = preds[ind]
62 | pred[pred== 2] = 1
63 | label = gts[ind]
64 | label[label== 2] = 1
65 | hist += fast_hist(label.flatten(), pred.flatten(), num_classes)
66 | IoUs = per_class_iu(hist)
67 | PA_Recall = per_class_PA_Recall(hist)
68 | Precision = per_class_Precision(hist)
69 | #------------------------------------------------#
70 | #   Print the per-class IoU values
71 | #------------------------------------------------#
72 | if name_classes is not None:
73 | for ind_class in range(num_classes):
74 | print('===>' + name_classes[ind_class] + ':\tIou-' + str(round(IoUs[ind_class] * 100, 2)) \
75 | + '; Recall (equal to the PA)-' + str(round(PA_Recall[ind_class] * 100, 2))+ '; Precision-' + str(round(Precision[ind_class] * 100, 2)))
76 |
77 | #-----------------------------------------------------------------#
78 | #   Average over all classes and all validation images to get mIoU, ignoring NaN values
79 | #-----------------------------------------------------------------#
80 | print('===> mIoU: ' + str(round(np.nanmean(IoUs) * 100, 2)) + '; mPA: ' + str(round(np.nanmean(PA_Recall) * 100, 2)) + '; Accuracy: ' + str(round(per_Accuracy(hist) * 100, 2)))
81 | return round(np.nanmean(IoUs) * 100, 2), round(np.nanmean(PA_Recall) * 100, 2), round(per_Accuracy(hist) * 100, 2)
82 |
83 | def adjust_axes(r, t, fig, axes):
84 | bb = t.get_window_extent(renderer=r)
85 | text_width_inches = bb.width / fig.dpi
86 | current_fig_width = fig.get_figwidth()
87 | new_fig_width = current_fig_width + text_width_inches
88 | proportion = new_fig_width / current_fig_width
89 | x_lim = axes.get_xlim()
90 | axes.set_xlim([x_lim[0], x_lim[1] * proportion])
91 |
92 | def draw_plot_func(values, name_classes, plot_title, x_label, output_path, tick_font_size = 12, plt_show = True):
93 | fig = plt.gcf()
94 | axes = plt.gca()
95 | plt.barh(range(len(values)), values, color='royalblue')
96 | plt.title(plot_title, fontsize=tick_font_size + 2)
97 | plt.xlabel(x_label, fontsize=tick_font_size)
98 | plt.yticks(range(len(values)), name_classes, fontsize=tick_font_size)
99 | r = fig.canvas.get_renderer()
100 | for i, val in enumerate(values):
101 | str_val = " " + str(val)
102 | if val < 1.0:
103 | str_val = " {0:.2f}".format(val)
104 | t = plt.text(val, i, str_val, color='royalblue', va='center', fontweight='bold')
105 | if i == (len(values)-1):
106 | adjust_axes(r, t, fig, axes)
107 |
108 | fig.tight_layout()
109 | fig.savefig(output_path)
110 | if plt_show:
111 | plt.show()
112 | plt.close()
113 |
114 | def show_results(miou_out_path, hist, IoUs, PA_Recall, Precision, name_classes, tick_font_size = 12):
115 | draw_plot_func(IoUs, name_classes, "mIoU = {0:.2f}%".format(np.nanmean(IoUs)*100), "Intersection over Union", \
116 | os.path.join(miou_out_path, "mIoU.png"), tick_font_size = tick_font_size, plt_show = True)
117 | print("Save mIoU out to " + os.path.join(miou_out_path, "mIoU.png"))
118 |
119 | draw_plot_func(PA_Recall, name_classes, "mPA = {0:.2f}%".format(np.nanmean(PA_Recall)*100), "Pixel Accuracy", \
120 | os.path.join(miou_out_path, "mPA.png"), tick_font_size = tick_font_size, plt_show = False)
121 | print("Save mPA out to " + os.path.join(miou_out_path, "mPA.png"))
122 |
123 | draw_plot_func(PA_Recall, name_classes, "mRecall = {0:.2f}%".format(np.nanmean(PA_Recall)*100), "Recall", \
124 | os.path.join(miou_out_path, "Recall.png"), tick_font_size = tick_font_size, plt_show = False)
125 | print("Save Recall out to " + os.path.join(miou_out_path, "Recall.png"))
126 |
127 | draw_plot_func(Precision, name_classes, "mPrecision = {0:.2f}%".format(np.nanmean(Precision)*100), "Precision", \
128 | os.path.join(miou_out_path, "Precision.png"), tick_font_size = tick_font_size, plt_show = False)
129 | print("Save Precision out to " + os.path.join(miou_out_path, "Precision.png"))
130 |
131 | with open(os.path.join(miou_out_path, "confusion_matrix.csv"), 'w', newline='') as f:
132 | writer = csv.writer(f)
133 | writer_list = []
134 | writer_list.append([' '] + [str(c) for c in name_classes])
135 | for i in range(len(hist)):
136 | writer_list.append([name_classes[i]] + [str(x) for x in hist[i]])
137 | writer.writerows(writer_list)
138 | print("Save confusion_matrix out to " + os.path.join(miou_out_path, "confusion_matrix.csv"))
139 |
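A tiny worked example for compute_mIoU: two 2x2 masks where the prediction gets three of the four pixels right (run from the repo root):

import numpy as np
from utils_metrics import compute_mIoU

gt = np.array([[0, 1], [1, 1]], dtype=np.uint8)
pred = np.array([[0, 1], [0, 1]], dtype=np.uint8)
miou, mpa, acc = compute_mIoU([gt], [pred])       # expects lists of H x W label maps
print(miou, mpa, acc)                             # 58.33 83.33 75.0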
--------------------------------------------------------------------------------