├── .gitignore
├── .idea
├── encodings.xml
├── misc.xml
├── modules.xml
├── segment_indicator.iml
└── workspace.xml
├── README.md
├── __init__.py
├── ceshi.py
├── eval_segm.py
├── unit_tests.py
└── version2.py
/.gitignore:
--------------------------------------------------------------------------------
1 | *.pyc
2 |
--------------------------------------------------------------------------------
/.idea/encodings.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
--------------------------------------------------------------------------------
/.idea/misc.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
--------------------------------------------------------------------------------
/.idea/modules.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
--------------------------------------------------------------------------------
/.idea/segment_indicator.iml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 |
10 |
11 |
12 |
--------------------------------------------------------------------------------
/.idea/workspace.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 |
10 |
11 |
12 |
13 |
14 |
15 |
16 |
17 |
18 |
19 |
20 |
21 |
22 |
23 |
24 |
25 |
26 |
27 |
28 |
29 |
30 |
31 |
32 |
33 |
34 |
35 |
36 |
37 |
38 |
39 |
40 |
41 |
42 |
43 |
44 |
45 |
46 |
47 |
48 |
49 |
50 |
51 |
52 |
53 |
54 |
55 |
56 |
57 |
58 |
59 |
60 |
61 |
62 |
63 |
64 |
65 |
66 |
71 |
72 |
73 |
74 |
75 |
76 |
77 |
78 |
79 |
80 |
81 |
82 |
83 |
84 |
85 |
86 |
87 |
88 |
89 |
90 |
91 |
92 |
93 |
94 |
95 |
96 |
97 |
98 |
99 |
100 |
101 |
102 |
103 |
104 |
105 |
106 |
107 |
108 |
109 |
110 |
111 |
112 |
113 |
114 |
115 |
116 |
117 |
118 |
119 |
120 |
121 |
122 |
123 |
124 |
125 |
126 |
127 |
128 |
129 |
130 |
131 |
132 |
133 |
134 |
135 |
136 |
137 |
138 |
139 |
140 |
141 |
142 |
143 |
144 |
145 |
146 |
147 |
148 |
149 |
150 |
151 |
152 |
153 |
154 |
155 |
156 | 1551515379813
157 |
158 |
159 | 1551515379813
160 |
161 |
162 |
163 |
164 |
165 |
166 |
167 |
168 |
169 |
170 |
171 |
172 |
173 |
174 |
175 |
176 |
177 |
178 |
179 |
180 |
181 |
182 |
183 |
184 |
185 |
186 |
187 |
188 |
189 |
190 |
191 |
192 |
193 |
194 |
195 |
196 |
197 |
198 |
199 |
200 |
201 |
202 |
203 |
204 |
205 |
206 |
207 |
208 |
209 |
210 |
211 |
212 |
213 |
214 |
215 |
216 |
217 |
218 |
219 |
220 |
221 |
222 |
223 |
224 |
225 |
226 |
227 |
228 |
229 |
230 |
231 |
232 |
233 |
234 |
235 |
236 |
237 |
238 |
239 |
240 |
241 |
242 |
243 |
244 |
245 |
246 |
247 |
248 |
249 |
250 |
251 |
252 |
253 |
254 |
255 |
256 |
257 |
258 |
259 |
260 |
261 |
262 |
263 |
264 |
265 |
266 |
267 |
268 |
269 |
270 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Image Segmentation Evaluation
2 |
3 | Usage details can be found in my blog (link below).
4 | 使用详情可见本人博客 https://blog.csdn.net/qq_40994943/article/details/88359871
5 |
6 |
7 |
8 |
9 |
--------------------------------------------------------------------------------
/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ferryer/segment_indicator/63b2d8f79679e652f06cb7f307738ed5773fd745/__init__.py
--------------------------------------------------------------------------------
/ceshi.py:
--------------------------------------------------------------------------------
1 | # 输入:测试集标注文件夹和对应的预测文件夹,图片命名为 biaozhu+x.png;predict+x.png(x 从 1 开始)
2 | # 输出:四个指标的平均值
3 |
4 | from eval_segm import *
5 | import cv2 as cv
6 |
7 |
def threshold_demo(a):
    """Binarize an image with OpenCV's triangle-method automatic threshold.

    a: image as loaded by cv.imread (3-channel, BGR channel order), or an
       image that is already single-channel.
    Returns a single-channel image whose pixels are 0 or 255.
    Prints the threshold value chosen by the triangle method.
    """
    if a.ndim == 2:
        # Already single-channel: no colour conversion needed.
        gray = a
    else:
        # cv.imread yields BGR, so BGR2GRAY applies the correct luminance
        # weights (RGB2GRAY would swap the red and blue coefficients).
        gray = cv.cvtColor(a, cv.COLOR_BGR2GRAY)
    # THRESH_TRIANGLE selects the threshold automatically, so the 0 passed as
    # `thresh` is not used; pixels above the threshold become 255, others 0.
    ret, binary = cv.threshold(gray, 0, 255, cv.THRESH_BINARY | cv.THRESH_TRIANGLE)
    print("threshold value %s" % ret)
    return binary
14 |
15 |
import os

# Image directories; '' means the current working directory.
biaozhu_path = ''   # ground-truth annotations: biaozhu1.png, biaozhu2.png, ...
predict_path = ''   # model predictions: predict1.png, predict2.png, ...

pa = []
mpa = []
miou = []
fwiou = []

number_photo = 1
for i in range(1, number_photo + 1):  # image indices start at 1
    gt_file = os.path.join(biaozhu_path, 'biaozhu' + str(i) + '.png')
    pred_file = os.path.join(predict_path, 'predict' + str(i) + '.png')
    a = cv.imread(gt_file)
    b = cv.imread(pred_file)
    # cv.imread reports a missing/unreadable file by returning None instead
    # of raising, which would otherwise fail later with an obscure error.
    if a is None:
        raise IOError('cannot read annotation image: ' + gt_file)
    if b is None:
        raise IOError('cannot read prediction image: ' + pred_file)

    binary = threshold_demo(a)    # binarized ground truth
    binary1 = threshold_demo(b)   # binarized prediction
    # Both maps must share height/width or the metrics raise EvalSegErr.
    # If they differ, resize the prediction with nearest-neighbour, e.g.
    # binary1 = cv.resize(binary1, (binary.shape[1], binary.shape[0]), interpolation=cv.INTER_NEAREST)
    print(binary.shape)
    print(binary1.shape)

    # Segmentation metrics. eval_segm functions take (eval_segm, gt_segm):
    # prediction first, ground truth second. mean_accuracy and mean_IU are
    # not symmetric, so the order matters.
    pa.append(pixel_accuracy(binary1, binary))
    mpa.append(mean_accuracy(binary1, binary))
    miou.append(mean_IU(binary1, binary))
    fwiou.append(frequency_weighted_IU(binary1, binary))

# Average each metric over all evaluated images.
print(sum(pa)/number_photo)
print(sum(mpa)/number_photo)
print(sum(miou)/number_photo)
print(sum(fwiou)/number_photo)
66 |
--------------------------------------------------------------------------------
/eval_segm.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/python
2 |
3 | '''
4 | Martin Kersner, m.kersner@gmail.com
5 | 2015/11/30
6 |
7 | Evaluation metrics for image segmentation inspired by
8 | paper Fully Convolutional Networks for Semantic Segmentation.
9 | '''
10 |
11 | import numpy as np
12 |
def pixel_accuracy(eval_segm, gt_segm):
    '''
    Fraction of ground-truth pixels that the prediction labels correctly.

    Formula: sum_i(n_ii) / sum_i(t_i), where i ranges over the classes present
    in gt_segm, n_ii counts pixels labelled i in both maps and t_i counts pixels
    labelled i in gt_segm.

    eval_segm: predicted 2-D label map.
    gt_segm:   ground-truth 2-D label map of the same height and width.
    Returns 0 when no ground-truth pixels are counted.
    Raises EvalSegErr if the two maps differ in height or width.
    '''
    check_size(eval_segm, gt_segm)

    classes, num_classes = extract_classes(gt_segm)
    pred_masks, true_masks = extract_both_masks(eval_segm, gt_segm, classes, num_classes)

    correct = 0
    total = 0
    for idx in range(num_classes):
        true_mask = true_masks[idx, :, :]
        correct += np.sum(np.logical_and(pred_masks[idx, :, :], true_mask))
        total += np.sum(true_mask)

    return 0 if total == 0 else correct / total
39 |
def mean_accuracy(eval_segm, gt_segm):
    '''
    Per-class pixel accuracy averaged over the ground-truth classes.

    Formula: (1/n_cl) * sum_i(n_ii / t_i), with i ranging over the n_cl
    classes present in gt_segm. A class whose t_i is 0 contributes 0.

    eval_segm: predicted 2-D label map.
    gt_segm:   ground-truth 2-D label map of the same height and width.
    Raises EvalSegErr if the two maps differ in height or width.
    '''
    check_size(eval_segm, gt_segm)

    classes, num_classes = extract_classes(gt_segm)
    pred_masks, true_masks = extract_both_masks(eval_segm, gt_segm, classes, num_classes)

    per_class = []
    for idx in range(num_classes):
        true_mask = true_masks[idx, :, :]
        hits = np.sum(np.logical_and(pred_masks[idx, :, :], true_mask))
        support = np.sum(true_mask)
        per_class.append(hits / support if support != 0 else 0)

    return np.mean(per_class)
64 |
def mean_IU(eval_segm, gt_segm):
    '''
    Mean intersection-over-union.

    Formula: (1/n_cl) * sum_i(n_ii / (t_i + sum_j(n_ji) - n_ii)).
    Classes are taken from the union of both maps. A class contributes its IoU
    only if it appears in both maps; otherwise it contributes 0. The sum is
    divided by the number of classes present in gt_segm.

    eval_segm: predicted 2-D label map.
    gt_segm:   ground-truth 2-D label map of the same height and width.
    Raises EvalSegErr if the two maps differ in height or width.
    '''
    check_size(eval_segm, gt_segm)

    classes, num_classes = union_classes(eval_segm, gt_segm)
    _, num_gt_classes = extract_classes(gt_segm)
    pred_masks, true_masks = extract_both_masks(eval_segm, gt_segm, classes, num_classes)

    ious = [0] * num_classes
    for idx in range(num_classes):
        pred_mask = pred_masks[idx, :, :]
        true_mask = true_masks[idx, :, :]
        pred_area = np.sum(pred_mask)
        true_area = np.sum(true_mask)
        if pred_area == 0 or true_area == 0:
            continue
        inter = np.sum(np.logical_and(pred_mask, true_mask))
        ious[idx] = inter / (true_area + pred_area - inter)

    return np.sum(ious) / num_gt_classes
93 |
def frequency_weighted_IU(eval_segm, gt_segm):
    '''
    Frequency-weighted intersection-over-union.

    Formula: sum_k(t_k)^(-1) * sum_i((t_i*n_ii)/(t_i + sum_j(n_ji) - n_ii)).
    Each class's IoU is weighted by its ground-truth pixel count t_i. The
    result is normalised by the total pixel count (height * width). Classes
    missing from either map contribute 0.

    eval_segm: predicted 2-D label map.
    gt_segm:   ground-truth 2-D label map of the same height and width.
    Raises EvalSegErr if the two maps differ in height or width.
    '''
    check_size(eval_segm, gt_segm)

    classes, num_classes = union_classes(eval_segm, gt_segm)
    pred_masks, true_masks = extract_both_masks(eval_segm, gt_segm, classes, num_classes)

    weighted = [0] * num_classes
    for idx in range(num_classes):
        pred_mask = pred_masks[idx, :, :]
        true_mask = true_masks[idx, :, :]
        pred_area = np.sum(pred_mask)
        true_area = np.sum(true_mask)
        if pred_area == 0 or true_area == 0:
            continue
        inter = np.sum(np.logical_and(pred_mask, true_mask))
        weighted[idx] = (true_area * inter) / (true_area + pred_area - inter)

    return np.sum(weighted) / get_pixel_area(eval_segm)
123 |
124 | '''
125 | Auxiliary functions used during evaluation.
126 | '''
def get_pixel_area(segm):
    '''Return the number of pixels in a label map: shape[0] * shape[1].'''
    height = segm.shape[0]
    width = segm.shape[1]
    return height * width
129 |
def extract_both_masks(eval_segm, gt_segm, cl, n_cl):
    '''Build per-class masks for both maps over the same class list.

    Returns (eval_masks, gt_masks), each as produced by extract_masks.
    '''
    return (extract_masks(eval_segm, cl, n_cl),
            extract_masks(gt_segm, cl, n_cl))
135 |
def extract_classes(segm):
    '''Return (sorted unique labels of segm, number of unique labels).'''
    labels = np.unique(segm)
    return labels, len(labels)
141 |
def union_classes(eval_segm, gt_segm):
    '''Return (sorted labels found in either map, number of such labels).'''
    labels = np.union1d(extract_classes(eval_segm)[0],
                        extract_classes(gt_segm)[0])
    return labels, len(labels)
150 |
def extract_masks(segm, cl, n_cl):
    '''Stack one binary mask per class label.

    Returns a float array of shape (n_cl, H, W). Slice k is 1.0 where
    segm equals the k-th label of cl and 0.0 elsewhere.
    '''
    height, width = segm_size(segm)
    stack = np.zeros((n_cl, height, width))

    for k, label in enumerate(cl):
        stack[k] = (segm == label)

    return stack
159 |
def segm_size(segm):
    '''Return (height, width) of a label map, taken from its first two axes.

    Raises IndexError if segm has fewer than two dimensions. The unit tests
    rely on this exception for 1-D input.
    '''
    # Indexing the shape tuple raises IndexError directly; the former
    # `try/except IndexError: raise` wrapper added nothing.
    height = segm.shape[0]
    width = segm.shape[1]
    return height, width
168 |
def check_size(eval_segm, gt_segm):
    '''Raise EvalSegErr unless both label maps have the same height and width.'''
    if segm_size(eval_segm) != segm_size(gt_segm):
        raise EvalSegErr("DiffDim: Different dimensions of matrices!")
175 |
176 | '''
177 | Exceptions
178 | '''
class EvalSegErr(Exception):
    '''Error raised by the segmentation metrics.

    The offending message or value is kept in `value`. str() returns its repr().
    '''

    def __init__(self, value):
        self.value = value

    def __str__(self):
        return '%r' % (self.value,)
185 |
--------------------------------------------------------------------------------
/unit_tests.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/python
2 |
3 | '''
4 | Martin Kersner, m.kersner@gmail.com
5 | 2015/11/30
6 |
7 | Unit tests for eval_segm.py.
8 | '''
9 |
10 | import numpy as np
11 | import eval_segm as es
12 | import unittest
13 |
class pixel_accuracy_UnitTests(unittest.TestCase):
    '''Unit tests for es.pixel_accuracy.'''

    # ---- Wrong inputs ----

    def test1dInput(self):
        mat = np.array([0])
        self.assertRaises(IndexError, es.pixel_accuracy, mat, mat)

    def testDiffDim(self):
        mat0 = np.array([[0,0], [0,0]])
        mat1 = np.array([[0,0,0], [0,0,0]])
        # assertRaisesRegex: the assertRaisesRegexp alias was deprecated in
        # Python 3.2 and removed in Python 3.12.
        self.assertRaisesRegex(es.EvalSegErr, "DiffDim", es.pixel_accuracy, mat0, mat1)

    # ---- Correct inputs ----

    def testOneClass(self):
        segm = np.array([[0,0], [0,0]])
        gt = np.array([[0,0], [0,0]])

        res = es.pixel_accuracy(segm, gt)
        self.assertEqual(res, 1.0)

    def testTwoClasses0(self):
        segm = np.array([[1,1,1,1,1], [1,1,1,1,1]])
        gt = np.array([[0,0,0,0,0], [0,0,0,0,0]])

        res = es.pixel_accuracy(segm, gt)
        self.assertEqual(res, 0)

    def testTwoClasses1(self):
        segm = np.array([[1,0,0,0,0], [0,0,0,0,0]])
        gt = np.array([[0,0,0,0,0], [0,0,0,0,0]])

        res = es.pixel_accuracy(segm, gt)
        self.assertEqual(res, (9.0)/(10.0))

    def testTwoClasses2(self):
        segm = np.array([[0,0,0,0,0], [0,0,0,0,0]])
        gt = np.array([[1,0,0,0,0], [0,0,0,0,0]])

        res = es.pixel_accuracy(segm, gt)
        self.assertEqual(res, (9.0+0.0)/(9.0+1.0))

    def testThreeClasses0(self):
        segm = np.array([[0,0,0,0,0], [0,0,0,0,0]])
        gt = np.array([[1,2,0,0,0], [0,0,0,0,0]])

        res = es.pixel_accuracy(segm, gt)
        self.assertEqual(res, (8.0+0.0+0.0)/(8.0+1.0+1.0))

    def testThreeClasses1(self):
        segm = np.array([[0,2,0,0,0], [0,0,0,0,0]])
        gt = np.array([[1,0,0,0,0], [0,0,0,0,0]])

        res = es.pixel_accuracy(segm, gt)
        self.assertEqual(res, (8.0+0.0)/(9.0+1.0))

    def testFourClasses0(self):
        segm = np.array([[0,2,3,0,0], [0,0,0,0,0]])
        gt = np.array([[1,0,0,0,0], [0,0,0,0,0]])

        res = es.pixel_accuracy(segm, gt)
        self.assertEqual(res, (7.0+0.0)/(9.0+1.0))

    def testFourClasses1(self):
        segm = np.array([[1,2,3,0,0], [0,0,0,0,0]])
        gt = np.array([[1,0,0,0,0], [0,0,0,0,0]])

        res = es.pixel_accuracy(segm, gt)
        self.assertEqual(res, (7.0+1.0)/(9.0+1.0))

    def testFiveClasses0(self):
        segm = np.array([[1,2,3,4,3], [0,0,0,0,0]])
        gt = np.array([[1,0,3,0,0], [0,0,0,0,0]])

        res = es.pixel_accuracy(segm, gt)
        self.assertEqual(res, (5.0+1.0+1.0)/(8.0+1.0+1.0))
92 |
class mean_accuracy_UnitTests(unittest.TestCase):
    '''Unit tests for es.mean_accuracy.'''

    # ---- Wrong inputs ----

    def test1dInput(self):
        mat = np.array([0])
        self.assertRaises(IndexError, es.mean_accuracy, mat, mat)

    def testDiffDim(self):
        mat0 = np.array([[0,0], [0,0]])
        mat1 = np.array([[0,0,0], [0,0,0]])
        # assertRaisesRegex: the assertRaisesRegexp alias was deprecated in
        # Python 3.2 and removed in Python 3.12.
        self.assertRaisesRegex(es.EvalSegErr, "DiffDim", es.mean_accuracy, mat0, mat1)

    # ---- Correct inputs ----

    def testOneClass(self):
        segm = np.array([[0,0], [0,0]])
        gt = np.array([[0,0], [0,0]])

        res = es.mean_accuracy(segm, gt)
        self.assertEqual(res, 1.0)

    def testTwoClasses0(self):
        segm = np.array([[1,1,1,1,1], [1,1,1,1,1]])
        gt = np.array([[0,0,0,0,0], [0,0,0,0,0]])

        res = es.mean_accuracy(segm, gt)
        self.assertEqual(res, 0)

    def testTwoClasses1(self):
        segm = np.array([[1,0,0,0,0], [0,0,0,0,0]])
        gt = np.array([[0,0,0,0,0], [0,0,0,0,0]])

        res = es.mean_accuracy(segm, gt)
        self.assertEqual(res, 9.0/10.0)

    def testTwoClasses2(self):
        segm = np.array([[0,0,0,0,0], [0,0,0,0,0]])
        gt = np.array([[1,0,0,0,0], [0,0,0,0,0]])

        res = es.mean_accuracy(segm, gt)
        self.assertEqual(res, np.mean([9.0/9.0, 0.0/1.0]))

    def testThreeClasses0(self):
        segm = np.array([[0,0,0,0,0], [0,0,0,0,0]])
        gt = np.array([[1,2,0,0,0], [0,0,0,0,0]])

        res = es.mean_accuracy(segm, gt)
        self.assertEqual(res, np.mean([8.0/8.0, 0.0/1.0, 0.0/1.0]))

    def testThreeClasses1(self):
        segm = np.array([[0,2,0,0,0], [0,0,0,0,0]])
        gt = np.array([[1,0,0,0,0], [0,0,0,0,0]])

        res = es.mean_accuracy(segm, gt)
        self.assertEqual(res, np.mean([8.0/9.0, 0.0/1.0]))

    def testFourClasses0(self):
        segm = np.array([[0,2,3,0,0], [0,0,0,0,0]])
        gt = np.array([[1,0,0,0,0], [0,0,0,0,0]])

        res = es.mean_accuracy(segm, gt)
        self.assertEqual(res, np.mean([7.0/9.0, 0.0/1.0]))

    def testFourClasses1(self):
        segm = np.array([[1,2,3,0,0], [0,0,0,0,0]])
        gt = np.array([[1,0,0,0,0], [0,0,0,0,0]])

        res = es.mean_accuracy(segm, gt)
        self.assertEqual(res, np.mean([7.0/9.0, 1.0/1.0]))

    def testFiveClasses0(self):
        segm = np.array([[1,2,3,4,3], [0,0,0,0,0]])
        gt = np.array([[1,0,3,0,0], [0,0,0,0,0]])

        res = es.mean_accuracy(segm, gt)
        self.assertEqual(res, np.mean([5.0/8.0, 1.0, 1.0]))
171 |
class mean_IU_UnitTests(unittest.TestCase):
    '''Unit tests for es.mean_IU.'''

    # ---- Wrong inputs ----

    def test1dInput(self):
        mat = np.array([0])
        self.assertRaises(IndexError, es.mean_IU, mat, mat)

    def testDiffDim(self):
        mat0 = np.array([[0,0], [0,0]])
        mat1 = np.array([[0,0,0], [0,0,0]])
        # assertRaisesRegex: the assertRaisesRegexp alias was deprecated in
        # Python 3.2 and removed in Python 3.12.
        self.assertRaisesRegex(es.EvalSegErr, "DiffDim", es.mean_IU, mat0, mat1)

    # ---- Correct inputs ----

    def testOneClass(self):
        segm = np.array([[0,0], [0,0]])
        gt = np.array([[0,0], [0,0]])

        res = es.mean_IU(segm, gt)
        self.assertEqual(res, 1.0)

    def testTwoClasses0(self):
        segm = np.array([[1,1,1,1,1], [1,1,1,1,1]])
        gt = np.array([[0,0,0,0,0], [0,0,0,0,0]])

        res = es.mean_IU(segm, gt)
        self.assertEqual(res, 0)

    def testTwoClasses1(self):
        segm = np.array([[1,0,0,0,0], [0,0,0,0,0]])
        gt = np.array([[0,0,0,0,0], [0,0,0,0,0]])

        res = es.mean_IU(segm, gt)
        self.assertEqual(res, np.mean([0.9]))

    def testTwoClasses2(self):
        segm = np.array([[0,0,0,0,0], [0,0,0,0,0]])
        gt = np.array([[1,0,0,0,0], [0,0,0,0,0]])

        res = es.mean_IU(segm, gt)
        self.assertEqual(res, np.mean([0.9, 0]))

    def testThreeClasses0(self):
        segm = np.array([[0,0,0,0,0], [0,0,0,0,0]])
        gt = np.array([[1,2,0,0,0], [0,0,0,0,0]])

        res = es.mean_IU(segm, gt)
        self.assertEqual(res, np.mean([8.0/10.0, 0, 0]))

    def testThreeClasses1(self):
        segm = np.array([[0,2,0,0,0], [0,0,0,0,0]])
        gt = np.array([[1,0,0,0,0], [0,0,0,0,0]])

        res = es.mean_IU(segm, gt)
        self.assertEqual(res, np.mean([8.0/10.0, 0]))

    def testFourClasses0(self):
        segm = np.array([[0,2,3,0,0], [0,0,0,0,0]])
        gt = np.array([[1,0,0,0,0], [0,0,0,0,0]])

        res = es.mean_IU(segm, gt)
        self.assertEqual(res, np.mean([7.0/10.0, 0]))

    def testFourClasses1(self):
        segm = np.array([[1,2,3,0,0], [0,0,0,0,0]])
        gt = np.array([[1,0,0,0,0], [0,0,0,0,0]])

        res = es.mean_IU(segm, gt)
        self.assertEqual(res, np.mean([7.0/9.0, 1]))

    def testFiveClasses0(self):
        segm = np.array([[1,2,3,4,3], [0,0,0,0,0]])
        gt = np.array([[1,0,3,0,0], [0,0,0,0,0]])

        res = es.mean_IU(segm, gt)
        self.assertEqual(res, np.mean([5.0/8.0, 1, 1.0/2.0]))
250 |
class frequency_weighted_IU_UnitTests(unittest.TestCase):
    '''Unit tests for es.frequency_weighted_IU.'''

    # ---- Wrong inputs ----

    def test1dInput(self):
        mat = np.array([0])
        self.assertRaises(IndexError, es.frequency_weighted_IU, mat, mat)

    def testDiffDim(self):
        mat0 = np.array([[0,0], [0,0]])
        mat1 = np.array([[0,0,0], [0,0,0]])
        # assertRaisesRegex: the assertRaisesRegexp alias was deprecated in
        # Python 3.2 and removed in Python 3.12.
        self.assertRaisesRegex(es.EvalSegErr, "DiffDim", es.frequency_weighted_IU, mat0, mat1)

    # ---- Correct inputs ----

    def testOneClass(self):
        segm = np.array([[0,0], [0,0]])
        gt = np.array([[0,0], [0,0]])

        res = es.frequency_weighted_IU(segm, gt)
        self.assertEqual(res, 1.0)

    def testTwoClasses0(self):
        segm = np.array([[1,1,1,1,1], [1,1,1,1,1]])
        gt = np.array([[0,0,0,0,0], [0,0,0,0,0]])

        res = es.frequency_weighted_IU(segm, gt)
        self.assertEqual(res, 0)

    def testTwoClasses1(self):
        segm = np.array([[1,0,0,0,0], [0,0,0,0,0]])
        gt = np.array([[0,0,0,0,0], [0,0,0,0,0]])

        res = es.frequency_weighted_IU(segm, gt)
        self.assertEqual(res, (1.0/10.0)*(10.0*9.0/10.0))

    def testTwoClasses2(self):
        segm = np.array([[0,0,0,0,0], [0,0,0,0,0]])
        gt = np.array([[1,0,0,0,0], [0,0,0,0,0]])

        res = es.frequency_weighted_IU(segm, gt)
        # Almost equal!
        self.assertAlmostEqual(res, (1.0/10.0)*((9.0*9.0/10.0)+(1.0*0.0/1.0)))

    def testThreeClasses0(self):
        segm = np.array([[0,0,0,0,0], [0,0,0,0,0]])
        gt = np.array([[1,2,0,0,0], [0,0,0,0,0]])

        res = es.frequency_weighted_IU(segm, gt)
        # Almost equal!
        self.assertAlmostEqual(res, (1.0/10.0)*((8.0*8.0/10.0)+(1.0*0.0/1.0)+(1.0*0.0/1.0)))

    def testThreeClasses1(self):
        segm = np.array([[0,2,0,0,0], [0,0,0,0,0]])
        gt = np.array([[1,0,0,0,0], [0,0,0,0,0]])

        res = es.frequency_weighted_IU(segm, gt)
        # Almost equal!
        self.assertAlmostEqual(res, (1.0/10.0)*((9.0*8.0/10.0)+(1.0*0.0/1.0)))

    def testFourClasses0(self):
        segm = np.array([[0,2,3,0,0], [0,0,0,0,0]])
        gt = np.array([[1,0,0,0,0], [0,0,0,0,0]])

        res = es.frequency_weighted_IU(segm, gt)
        self.assertEqual(res, (1.0/10.0)*((9.0*7.0/10.0)+(1.0*0.0/1.0)))

    def testFourClasses1(self):
        segm = np.array([[1,2,3,0,0], [0,0,0,0,0]])
        gt = np.array([[1,0,0,0,0], [0,0,0,0,0]])

        res = es.frequency_weighted_IU(segm, gt)
        self.assertEqual(res, (1.0/10.0)*((9.0*7.0/9.0)+(1.0*1.0/1.0)))

    def testFiveClasses0(self):
        segm = np.array([[1,2,3,4,3], [0,0,0,0,0]])
        gt = np.array([[1,0,3,0,0], [0,0,0,0,0]])

        res = es.frequency_weighted_IU(segm, gt)
        self.assertEqual(res, (1.0/10.0)*((8.0*5.0/8.0)+(1.0*1.0/1.0)+(1.0*1.0/2.0)))


if __name__ == "__main__":
    # Run every TestCase defined in this module.
    unittest.main()
336 |
--------------------------------------------------------------------------------
/version2.py:
--------------------------------------------------------------------------------
1 | import _init_paths
2 |
3 | import os
4 | import numpy as np
5 | from PIL import Image
6 | import matplotlib.pyplot as plt
7 | from skimage import io
8 | from timer import Timer
9 | import cv2
10 | from datetime import datetime
11 |
12 | import caffe
13 |
test_file = 'test.txt'
file_path_img = 'JPEGImages'
file_path_label = 'SegmentationClass'
save_path = 'output/results'

test_prototxt = 'Models/test.prototxt'
weight = 'Training/Seg_iter_10000.caffemodel'

layer = 'conv_seg'
# Set to True to write prediction overlays into save_path.
save_results = False

# save_dir is the output directory when saving is enabled, otherwise False.
save_dir = save_path if save_results else False
if save_dir and not os.path.isdir(save_dir):
    os.makedirs(save_dir)

# load net
net = caffe.Net(test_prototxt, weight, caffe.TEST)

# load test.txt (one image name per line). ndmin=1 keeps a single-entry file
# as a 1-element array; without it loadtxt returns a 0-d array, which cannot
# be iterated or passed to len().
test_img = np.loadtxt(test_file, dtype=str, ndmin=1)
36 |
def fast_hist(a, b, n):
    """Accumulate an n x n confusion matrix from flattened label arrays.

    a: ground-truth labels; entries outside [0, n) are ignored.
    b: predicted labels aligned with a (assumed to lie in [0, n)).
    Returns counts where entry [r, c] is the number of pixels with
    ground truth r and prediction c.
    """
    valid = (a >= 0) & (a < n)
    flat_index = n * a[valid].astype(int) + b[valid]
    counts = np.bincount(flat_index, minlength=n ** 2)
    return counts.reshape(n, n)
40 |
41 |
# seg test
# A single pre-formatted string prints identically under Python 2 and 3.
print('>>> {} Begin seg tests'.format(datetime.now()))

n_cl = net.blobs[layer].channels
hist = np.zeros((n_cl, n_cl))

# timers
_t = {'im_seg': Timer()}

# Network input size as (width, height); adjust to match the model.
input_size = (512, 384)
# Dataset mean in BGR order, subtracted from every input; adjust per dataset.
mean_bgr = np.array([[[68.2117, 78.2288, 75.4916]]])

# load image and label
for i, img_name in enumerate(test_img):
    _t['im_seg'].tic()
    # convert('RGB') guarantees 3 channels so the BGR swap and mean
    # subtraction below are valid for grayscale/RGBA inputs too.
    img = Image.open(os.path.join(file_path_img, img_name + '.jpg')).convert('RGB')
    img = img.resize(input_size, Image.LANCZOS)

    in_ = np.array(img, dtype=np.float32)
    in_ = in_[:, :, ::-1]  # rgb to bgr
    in_ -= mean_bgr
    in_ = in_.transpose((2, 0, 1))

    label = Image.open(os.path.join(file_path_label, img_name + '.png'))
    # NEAREST keeps every value a real class index; smoothing filters
    # (the former ANTIALIAS) would blend neighbouring labels.
    label = label.resize(input_size, Image.NEAREST)
    label = np.array(label, dtype=np.uint8)

    # shape for input (data blob is N x C x H x W), set data
    net.blobs['data'].reshape(1, *in_.shape)
    net.blobs['data'].data[...] = in_

    net.forward()
    _t['im_seg'].toc()

    print('im_seg: {:d}/{:d} {:.3f}s'.format(i + 1, len(test_img), _t['im_seg'].average_time))

    # Per-pixel predicted class, computed once for both metrics and saving.
    seg = net.blobs[layer].data[0].argmax(axis=0)
    hist += fast_hist(label.flatten(), seg.flatten(), n_cl)

    if save_dir:
        # Paint pixels predicted as class 1 red on top of the input image.
        result = np.array(img, dtype=np.uint8)
        result[seg == 1] = (255, 0, 0)
        Image.fromarray(result).save(os.path.join(save_dir, img_name + '.jpg'))
92 |
num_images = len(test_img)

# Rows of hist are ground-truth classes, columns are predictions (see
# fast_hist). Classes absent from the ground truth (and, for IU, also from the
# prediction) give 0/0 = NaN; np.nanmean skips those, so the warnings are
# silenced rather than printed.
with np.errstate(divide='ignore', invalid='ignore'):
    # overall accuracy
    acc = np.diag(hist).sum() / hist.sum()
    # per-class accuracy
    per_class_acc = np.diag(hist) / hist.sum(1)
    # per-class IU
    iu = np.diag(hist) / (hist.sum(1) + hist.sum(0) - np.diag(hist))
    # ground-truth frequency of each class
    freq = hist.sum(1) / hist.sum()

print('>>> {} Iteration {} overall accuracy {}'.format(datetime.now(), num_images, acc))
print('>>> {} Iteration {} mean accuracy {}'.format(datetime.now(), num_images, np.nanmean(per_class_acc)))
print('>>> {} Iteration {} mean IU {}'.format(datetime.now(), num_images, np.nanmean(iu)))
print('>>> {} Iteration {} fwavacc {}'.format(
    datetime.now(), num_images, (freq[freq > 0] * iu[freq > 0]).sum()))
110 |
--------------------------------------------------------------------------------