├── .idea
│   ├── SGM-Net_pytorch.iml
│   ├── inspectionProfiles
│   │   ├── Project_Default.xml
│   │   └── profiles_settings.xml
│   ├── misc.xml
│   ├── modules.xml
│   ├── vcs.xml
│   └── workspace.xml
├── README.md
├── SGMNet.py
├── calculate_ct_cost
│   ├── 1528.png
│   ├── 25.png
│   ├── __pycache__
│   │   └── ct_list.cpython-35.pyc
│   ├── cal_ct.py
│   ├── ct_list.py
│   └── ct_list.pyc
├── dataloader.py
├── libsave.so
├── loss
│   └── sgm_pathloss.py
├── pair_load.py
├── readpfm.py
├── redme
├── requirements.txt
├── save2ctype.cpp
├── test.py
├── train.py
└── use_result
    └── load_p1p2_demo.cpp
/README.md:
--------------------------------------------------------------------------------
1 | # SGM-Net_pytorch
2 | SGM-Net re-implemented in PyTorch.
3 | I'm trying to re-implement the closed-source SGM-Net method (Seki and Pollefeys, CVPR 2017) in PyTorch. For various reasons, I can only release part of the program, covering the path cost only.
4 |
5 | Contact: [wsywf@bupt.edu.cn](mailto:wsywf@bupt.edu.cn). Questions and discussions are welcome!
6 |
7 | ## Usage
8 |
9 | If you want to train SGM-Net, you need an initial cost volume, just as in a traditional stereo-matching task. There are two ways to get one:
10 | 1. Use the mc-cnn project (https://github.com/jzbontar/mc-cnn) to compute the initial cost, as the original paper did.
11 |
12 | 2. Build a census-transform (CT) cost volume with the provided Python demo ./calculate_ct_cost/cal_ct.py after setting the data paths.
13 |
14 | ./dataloader.py --------------- Sets up the dataset; it needs a left image and a disparity image.
15 |
16 | ./SGMNet.py ------------------- The SGM-Net model.
17 |
18 | ./train.py -------------------- Trains SGM-Net.
19 |
20 | ./loss/sgm_pathloss.py -------- Calculates the path cost and manually derives the backward gradient with a dynamic-programming strategy.
21 |
22 | ./test.py --------------------- Produces the P1/P2 parameters with the trained model.
23 |
24 | If you want to use the parameters for post-processing in C++, set the save_path in save2ctype.cpp, build a shared library with `g++ -fPIC -shared -o libsave.so save2ctype.cpp`, and use it to export a C-type parameter volume (see the example below).
25 |
26 |
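27 | ## Example
28 |
29 | A minimal sketch of handing a P1/P2 volume to the C++ side through libsave.so. It assumes the library was built with the g++ command above, that show_matrix keeps the signature declared in save2ctype.cpp, and an illustrative [H, W, 8] layout (P1/P2 for the four path directions, as read back by use_result/load_p1p2_demo.cpp); adapt the shapes to your own setup.
30 |
31 | ```python
32 | import ctypes
33 | import numpy as np
34 |
35 | lib = ctypes.CDLL('./libsave.so')
36 | # signature from save2ctype.cpp: void show_matrix(float *matrix, int lens, int img_name)
37 | lib.show_matrix.argtypes = [ctypes.POINTER(ctypes.c_float), ctypes.c_int, ctypes.c_int]
38 | lib.show_matrix.restype = None
39 |
40 | p1p2 = np.random.rand(370, 1226, 8).astype(np.float32)  # stand-in for the network output
41 | flat = np.ascontiguousarray(p1p2.ravel())
42 | lib.show_matrix(flat.ctypes.data_as(ctypes.POINTER(ctypes.c_float)), int(flat.size), 25)
43 | ```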
--------------------------------------------------------------------------------
/SGMNet.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.nn import Conv2d, Module, ReLU, MaxPool2d, AvgPool2d ,init, BatchNorm2d
3 | import torch.nn.functional as F
4 | import numpy as np
5 | from torch import nn
6 |
7 |
8 |
9 | ##### input: 7x7 grayscale patches; three 3x3 convs (no padding) reduce 7x7 -> 5x5 -> 3x3 -> 1x1
10 | class SGM_NET(Module):
11 | def __init__(self):
12 | super(SGM_NET, self).__init__()
13 | self.conv_1 = Conv2d(in_channels = 1, out_channels = 32, kernel_size = 3, stride = 1, padding = 0)
14 | self.conv_2 = Conv2d(in_channels = 32, out_channels = 64, kernel_size = 3, stride = 1, padding = 0)
15 | self.conv_3 = Conv2d(in_channels = 64, out_channels = 128, kernel_size = 3, stride = 1, padding = 0)
16 | self.fc1 = nn.Linear(in_features = 128 , out_features = 128)
17 | self.fc2 = nn.Linear(in_features = 128 , out_features = 8)
18 | self.relu = ReLU()
19 |         self.avgpool = AvgPool2d(kernel_size = 2, stride=2)   # defined but unused in forward()
20 |
21 | for m in self.modules():
22 | if isinstance(m, nn.Conv2d):
23 | init.xavier_normal_(m.weight.data)
24 | init.constant_(m.bias.data, 0.01)
25 |
26 |     def forward(self, x, cod):   # cod: normalized (x, y) patch coordinates, only used if the concat below is enabled
27 | x = self.conv_1(x)
28 | x = self.relu(x)
29 |
30 | x = self.conv_2(x)
31 | x = self.relu(x)
32 |
33 | x = self.conv_3(x)
34 | x = self.relu(x)
35 |
36 | x = x.view(-1, 128)
37 |
38 | # x = torch.cat((x,cod),1)
39 | x = self.fc1(x)
40 | x = self.relu(x)
41 |
42 | x = self.fc2(x)
43 |         x = F.elu(x)
44 |
45 |         x = torch.add(x, 1.0)   # ELU(x) + 1 > 0 keeps the eight P1/P2 outputs strictly positive
46 |
47 | return x
48 |
49 |
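50 |
51 | if __name__ == '__main__':
52 |     # minimal smoke test -- an illustrative sketch, not part of the training pipeline:
53 |     # a batch of 7x7 patches plus normalized coordinates -> eight positive P1/P2 values each
54 |     net = SGM_NET()
55 |     patches = torch.randn(4, 1, 7, 7)
56 |     cods = torch.rand(4, 2)
57 |     print(net(patches, cods).size())   # torch.Size([4, 8])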
--------------------------------------------------------------------------------
/calculate_ct_cost/1528.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wsywf/SGM-Net_pytorch/fa13049a9eb9a3c1431686917f28bd2b9b7d931f/calculate_ct_cost/1528.png
--------------------------------------------------------------------------------
/calculate_ct_cost/25.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wsywf/SGM-Net_pytorch/fa13049a9eb9a3c1431686917f28bd2b9b7d931f/calculate_ct_cost/25.png
--------------------------------------------------------------------------------
/calculate_ct_cost/__pycache__/ct_list.cpython-35.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wsywf/SGM-Net_pytorch/fa13049a9eb9a3c1431686917f28bd2b9b7d931f/calculate_ct_cost/__pycache__/ct_list.cpython-35.pyc
--------------------------------------------------------------------------------
/calculate_ct_cost/cal_ct.py:
--------------------------------------------------------------------------------
1 | import cv2
2 | import random
3 | import numpy as np
4 | from ct_list import census_trans,hamin_calcu
5 |
6 | image_name = '0775'
7 | data_root = 'data/driving'
8 | left_path = data_root + '/left/' + image_name + '.png'
9 | left_image = cv2.imread(left_path, 0)   # 0: read as grayscale
10 |
11 | right_path = data_root + '/right/' + image_name + '.png'
12 | right_image = cv2.imread(right_path, 0)
13 |
14 |
15 | img_h , img_w = left_image.shape
16 | # census window half-size; the full window is (2*win_h+1) x (2*win_w+1) = 13 x 13
17 | win_h = 6
18 | win_w = 6
19 |
20 | # max disp
21 | Dmax = 70
22 | #feature calculate
23 | ct_left = census_trans(left_image,win_w,win_h)
24 | ct_right = census_trans(right_image,win_w,win_h)
25 | co_height = img_h-2*win_h
26 | co_width = img_w-2*win_w
27 |
28 | # use the Hamming distance between census codes as the matching cost
29 | d_cost = hamin_calcu(ct_left,ct_right,Dmax,co_height,co_width)
30 |
31 | #file_name = data_root + '/disp/' + image_name + '.npy'
32 | #np.save(file_name, d_cost)   # uncomment to save the cost volume that dataloader.py loads
33 | save_img = np.argmin(d_cost,0)   # winner-take-all disparity map, for a quick visual check
34 | cv2.imwrite('1528.png',save_img)
35 |
--------------------------------------------------------------------------------
/calculate_ct_cost/ct_list.py:
--------------------------------------------------------------------------------
1 | import cv2
2 | import random
3 | import numpy as np
4 | import torch
5 |
6 | def census_trans(img,win_w,win_h):   # census transform: one bit per neighbor, returned as a row-major list of codes
7 | img_h,img_w = img.shape
8 | img_trans = []
9 |
10 | for y in range(win_h,img_h-win_h):
11 | for x in range(win_w,img_w-win_w):
12 | count = 0
13 |
14 | for m in range(y-win_h,y+win_h+1):
15 | for n in range(x-win_w,x+win_w+1):
16 | count<<=1
17 | if img[m][n] > img[y][x]:
18 | count = count | 1
19 |                     # (no else branch needed: the shifted-in 0 bit already encodes "not brighter")
20 |
21 | #print("count",count,type(count))
22 | img_trans.append(count)
23 | #print("changdu",len(img_trans))
24 |
25 | return img_trans
26 |
27 |
28 |
29 | def hamin_calcu(ct_l,ct_r,Dmax,height,width):   # per-disparity Hamming cost volume of shape [Dmax, height, width]
30 |
31 | ct_cost = np.empty((Dmax,height,width))
32 | for y in range(0,height):
33 | for x in range (0,width):
34 | base_ct = ct_l[y*width+x]
35 | for d in range (0,Dmax):
36 | if x >= d:
37 | contrast_ct = ct_r[y*width+x-d]
38 |
39 |                     hamin_dis = bin(base_ct^contrast_ct).count('1')   # Hamming distance = popcount of the XOR
40 |
41 |                 else:
42 |                     hamin_dis = 81   # disparity out of range: assign a large fixed penalty cost
43 | ct_cost[d,y,x] = hamin_dis
44 | return ct_cost
45 |
46 |
47 |
48 |
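49 | if __name__ == '__main__':
50 |     # tiny self-check -- an illustrative sketch, not part of the pipeline
51 |     img = np.random.randint(0, 256, (20, 20), dtype=np.uint8)
52 |     codes = census_trans(img, 3, 3)                      # 7x7 census window
53 |     cost = hamin_calcu(codes, codes, 4, 20 - 6, 20 - 6)  # identical images -> zero cost at d = 0
54 |     print(len(codes), cost.shape)                        # 196, (4, 14, 14)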
--------------------------------------------------------------------------------
/calculate_ct_cost/ct_list.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wsywf/SGM-Net_pytorch/fa13049a9eb9a3c1431686917f28bd2b9b7d931f/calculate_ct_cost/ct_list.pyc
--------------------------------------------------------------------------------
/dataloader.py:
--------------------------------------------------------------------------------
1 | import cv2
2 | import random
3 | import numpy as np
4 | import torch
5 | from readpfm import readPFM
6 |
7 | class KittiDataset():   # despite the name, this demo version loads one fixed SceneFlow-Driving sample
8 | def __init__(self, train_file, mode='train', n_samples=None):
9 | self.mode = mode
10 | self.train_file = open(train_file,'r')
11 | self.image_lists = []
12 | for line in self.train_file:
13 | line = line.strip('\n')
14 | self.image_lists.append(line)
15 | self.data_root = 'data/kitti'
16 | #left image
17 | self.left_image = cv2.imread('data/driving/left/0350.png',0)
18 | self.left_image = self.left_image.astype(np.float32)
19 | self.max_pixel = np.max(self.left_image)
20 | #use disp image .pfm
21 | self.disp_image, _ = readPFM('data/driving/disp/0350.pfm')
22 | self.disp_image = self.disp_image.astype(np.int32)
23 | # #use disp image .png
24 | # disp_path = self.data_root + '/disp_occ_0/' + image_name + '.png'
25 | # disp_image = cv2.imread(disp_path,0)
26 |
27 | # use the ct_cost from calculate_ct_cost in local file
28 | self.d_cost = np.load('data/driving/disp/0350.npy')
29 | self.d_cost = np.array(self.d_cost)
30 | self.d_cost = torch.tensor(self.d_cost)
31 | self.d_cost = self.d_cost.type(torch.FloatTensor)
32 |
33 |         #use the cost volume obtained from mc-cnn https://github.com/jzbontar/mc-cnn
34 | # cost_path = self.data_root + '/leftbin/' + image_name + '.bin'
35 | # d_cost = np.memmap(cost_path, dtype=np.float32, shape=(1, 70, img_h, img_w))
36 | # p = np.isnan(d_cost)
37 | # d_cost[p] = 0.0
38 | # d_cost = torch.tensor(d_cost)
39 | # d_cost = d_cost.squeeze(0)
40 | # d_cost = d_cost.type(torch.FloatTensor) # [70,h,w]
41 |
42 |
43 | self.img_h , self.img_w = self.left_image.shape
44 | self.valid_points = []
45 | #get the valid points list
46 | for i in range(300, self.img_w-200):
47 | for j in range(150, self.img_h -150):
48 | disp_val = self.disp_image[j,i]
49 | if(disp_val>5 and disp_val < 69):
50 | self.valid_points.append([i,j,disp_val])
51 |
52 | random.shuffle(self.valid_points)
53 | self.valid_points = random.sample(self.valid_points, int(len(self.valid_points)/20))
54 |
55 | print(self.d_cost.size())
56 | print(self.img_h,self.img_w)
57 |
58 | def getlen(self):
59 | return len(self.valid_points)
60 |
61 | def getitem(self, i):
62 |
63 | train_point = self.valid_points[i]
64 | x = train_point[0]
65 | y = train_point[1]
66 | disp_val = train_point[2]
67 |
68 | image_patchs_l = []
69 | patch_cods_l = []
70 |
71 | image_patchs_r = []
72 | patch_cods_r = []
73 |
74 | image_patchs_u = []
75 | patch_cods_u = []
76 |
77 | image_patchs_d = []
78 | patch_cods_d = []
79 |
80 | for i in range(120,x):#left
81 | x0 = i-3
82 | x1 = i+4
83 | y0 = y-3
84 | y1 = y+4
85 | img_patch = self.left_image[y0:y1, x0:x1]
86 | patch_mean = np.mean(img_patch)
87 | img_patch_normlz = (img_patch - patch_mean)/self.max_pixel
88 |
89 | image_patchs_l.append(img_patch_normlz)
90 | patch_cods_l.append([i,y])
91 |
92 | for i in range(x+1,self.img_w-8):#right
93 | x0 = i-3
94 | x1 = i+4
95 | y0 = y-3
96 | y1 = y+4
97 | img_patch = self.left_image[y0:y1, x0:x1]
98 | patch_mean = np.mean(img_patch)
99 | img_patch_normlz = (img_patch - patch_mean)/self.max_pixel
100 |
101 | image_patchs_r.append(img_patch_normlz)
102 | patch_cods_r.append([i,y])
103 |
104 | for j in range(8,y):#up
105 | x0 = x-3
106 | x1 = x+4
107 | y0 = j-3
108 | y1 = j+4
109 | img_patch = self.left_image[y0:y1, x0:x1]
110 | patch_mean = np.mean(img_patch)
111 | img_patch_normlz = (img_patch - patch_mean)/self.max_pixel
112 |
113 | image_patchs_u.append(img_patch_normlz)
114 | patch_cods_u.append([x,j])
115 |
116 | for j in range(y+1,self.img_h-8):#down
117 | x0 = x-3
118 | x1 = x+4
119 | y0 = j-3
120 | y1 = j+4
121 | img_patch = self.left_image[y0:y1, x0:x1]
122 | patch_mean = np.mean(img_patch)
123 | img_patch_normlz = (img_patch - patch_mean)/self.max_pixel
124 |
125 | image_patchs_d.append(img_patch_normlz)
126 | patch_cods_d.append([x,j])
127 |
128 | origin_patch = []
129 | origin_cod = []
130 | x0 = x-3
131 | x1 = x+4
132 | y0 = y-3
133 | y1 = y+4
134 | img_patch = self.left_image[y0:y1, x0:x1]
135 | patch_mean = np.mean(img_patch)
136 | img_patch_normlz = (img_patch - patch_mean)/self.max_pixel
137 | origin_patch.append(img_patch_normlz)
138 | origin_cod.append([x,y])
139 |
140 | ### left path
141 | image_patchs_l = np.array(image_patchs_l)
142 | image_patchs_l = torch.tensor(image_patchs_l)
143 |         image_patchs_l = image_patchs_l.unsqueeze(1) ## [n,1,7,7]
144 | image_patchs_l = image_patchs_l.type(torch.FloatTensor)
145 |
146 | patch_cods_l = np.array(patch_cods_l)
147 | patch_cods_l = torch.tensor(patch_cods_l)
148 | patch_cods_l = patch_cods_l.type(torch.FloatTensor)
149 | patch_cods_l[:,0]=patch_cods_l[:,0]/self.img_w
150 | patch_cods_l[:,1]=patch_cods_l[:,1]/self.img_h
151 |
152 | ### right path
153 | image_patchs_r = np.array(image_patchs_r)
154 | image_patchs_r = torch.tensor(image_patchs_r)
155 |         image_patchs_r = image_patchs_r.unsqueeze(1) ## [n,1,7,7]
156 | image_patchs_r = image_patchs_r.type(torch.FloatTensor)
157 |
158 | patch_cods_r = np.array(patch_cods_r)
159 | patch_cods_r = torch.tensor(patch_cods_r)
160 | patch_cods_r = patch_cods_r.type(torch.FloatTensor)
161 | patch_cods_r[:,0]=patch_cods_r[:,0]/self.img_w
162 | patch_cods_r[:,1]=patch_cods_r[:,1]/self.img_h
163 |
164 | ### up path
165 | image_patchs_u = np.array(image_patchs_u)
166 | image_patchs_u = torch.tensor(image_patchs_u)
167 |         image_patchs_u = image_patchs_u.unsqueeze(1) ## [n,1,7,7]
168 | image_patchs_u = image_patchs_u.type(torch.FloatTensor)
169 |
170 | patch_cods_u = np.array(patch_cods_u)
171 | patch_cods_u = torch.tensor(patch_cods_u)
172 | patch_cods_u = patch_cods_u.type(torch.FloatTensor)
173 | patch_cods_u[:,0]=patch_cods_u[:,0]/self.img_w
174 | patch_cods_u[:,1]=patch_cods_u[:,1]/self.img_h
175 |
176 | ### down path
177 | image_patchs_d = np.array(image_patchs_d)
178 | image_patchs_d = torch.tensor(image_patchs_d)
179 |         image_patchs_d = image_patchs_d.unsqueeze(1) ## [n,1,7,7]
180 | image_patchs_d = image_patchs_d.type(torch.FloatTensor)
181 |
182 | patch_cods_d = np.array(patch_cods_d)
183 | patch_cods_d = torch.tensor(patch_cods_d)
184 | patch_cods_d = patch_cods_d.type(torch.FloatTensor)
185 | patch_cods_d[:,0]=patch_cods_d[:,0]/self.img_w
186 | patch_cods_d[:,1]=patch_cods_d[:,1]/self.img_h
187 |
188 | ### x0,y0
189 | origin_patch = np.array(origin_patch)
190 | origin_patch = torch.tensor(origin_patch)
191 |         origin_patch = origin_patch.unsqueeze(1) ## [1,1,7,7]
192 | origin_patch = origin_patch.type(torch.FloatTensor)
193 |
194 | origin_cod = np.array(origin_cod)
195 | origin_cod = torch.tensor(origin_cod)
196 | origin_cod = origin_cod.type(torch.FloatTensor)
197 | origin_cod[:,0]=origin_cod[:,0]/self.img_w
198 | origin_cod[:,1]=origin_cod[:,1]/self.img_h
199 |
200 | ### match cost
201 | d_cost_l = self.d_cost[:,y-8,112:x-7]
202 | d_cost_r = self.d_cost[:,y-8,x-8:self.img_w-16]
203 | d_cost_u = self.d_cost[:,0:y-7,x-8]
204 | d_cost_d = self.d_cost[:,y-8:self.img_h-16,x-8]
205 | d_cost_x0 = self.d_cost[:,y-8,x-8]
206 |
207 | return image_patchs_l, patch_cods_l, image_patchs_r, patch_cods_r, image_patchs_u, patch_cods_u, image_patchs_d, patch_cods_d, origin_patch, origin_cod ,disp_val, d_cost_l , d_cost_r, d_cost_u, d_cost_d,d_cost_x0, x, y
208 |
209 |
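210 | if __name__ == '__main__':
211 |     # minimal usage sketch (illustrative; 'train.txt' is a placeholder list of image names,
212 |     # and the data/driving files referenced in __init__ must exist)
213 |     dataset = KittiDataset('train.txt')
214 |     sample = dataset.getitem(0)
215 |     print(dataset.getlen(), sample[-2:])   # number of valid points, and the (x, y) of the first sample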
--------------------------------------------------------------------------------
/libsave.so:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wsywf/SGM-Net_pytorch/fa13049a9eb9a3c1431686917f28bd2b9b7d931f/libsave.so
--------------------------------------------------------------------------------
/loss/sgm_pathloss.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | import cv2
4 |
5 | def left_aggregation(d_cost_l, p1p2_l, p1_left_ptr,p2_left_ptr): ### [Dmax, nl], [nl , 8]
6 | Dmax, num_pix = d_cost_l.size()
7 | last_Lr=d_cost_l[:,0]
8 | grad = torch.zeros(Dmax,num_pix)
9 |
10 | for i in range(1,num_pix):
11 | minlast=torch.min(last_Lr)
12 | next_Lr=d_cost_l[:,i]
13 |
14 | tmp_grad = grad.clone()
15 | for nextd in range(0,Dmax):
16 | last_Lrp2 = last_Lr.clone()
17 |
18 | min0 = next_Lr[nextd] + last_Lr[nextd]
19 | last_Lrp2[nextd] = 1000.0
20 |
21 | min1_pos = 0
22 | if nextd == 0:
23 | min1 = next_Lr[nextd] + last_Lr[nextd+1] + p1p2_l[i-1,p1_left_ptr]
24 | last_Lrp2[nextd+1] = 1000.0
25 | min1_pos = nextd+1
26 |
27 | elif nextd == Dmax-1:
28 | min1 = next_Lr[nextd] + last_Lr[nextd-1] + p1p2_l[i-1,p1_left_ptr]
29 | last_Lrp2[nextd-1] = 1000.0
30 | min1_pos = nextd-1
31 | else:
32 | min1u = next_Lr[nextd] + last_Lr[nextd+1] + p1p2_l[i-1,p1_left_ptr]
33 | min1d = next_Lr[nextd] + last_Lr[nextd-1] + p1p2_l[i-1,p1_left_ptr]
34 | if(min1u < min1d):
35 | min1 = min1u
36 | last_Lrp2[nextd+1] = 1000.0
37 | min1_pos = nextd+1
38 | else:
39 | min1 = min1d
40 | last_Lrp2[nextd-1] = 1000.0
41 | min1_pos = nextd-1
42 |
43 | min2_pos = torch.argmin(last_Lrp2,dim=0).item()
44 | min2 = last_Lr[min2_pos] + next_Lr[nextd] + p1p2_l[i-1,p2_left_ptr]
45 |
46 | if(min1 < min0 and min1 < min2):
47 | next_Lr[nextd] = min1 - minlast
48 | grad[nextd] = tmp_grad[min1_pos]
49 | grad[nextd,i] = 1
50 | elif(min2 < min0 and min2 < min1):
51 | next_Lr[nextd] = min2 - minlast
52 | grad[nextd] = tmp_grad[min2_pos]
53 | grad[nextd,i] = 2
54 | else:
55 | next_Lr[nextd] = min0 - minlast
56 | grad[nextd] = tmp_grad[nextd]
57 | last_Lr = next_Lr
58 | return last_Lr,grad
59 |
60 | def right_aggregation(d_cost_r, p1p2_r ,p1_right_ptr,p2_right_ptr):
61 | Dmax, num_pix = d_cost_r.size()
62 | last_Lr=d_cost_r[:,num_pix-1]
63 | grad = torch.zeros(Dmax,num_pix)
64 |
65 | for i in range(num_pix-2,-1,-1):
66 | minlast=torch.min(last_Lr)
67 | next_Lr=d_cost_r[:,i]
68 |
69 | tmp_grad = grad.clone()
70 | for nextd in range(0,Dmax):
71 | last_Lrp2 = last_Lr.clone()
72 |
73 | min0 = next_Lr[nextd] + last_Lr[nextd]
74 | last_Lrp2[nextd] = 1000.0
75 |
76 | min1_pos = 0
77 | if nextd == 0:
78 | min1 = next_Lr[nextd] + last_Lr[nextd+1] + p1p2_r[i,p1_right_ptr]
79 | last_Lrp2[nextd+1] = 1000.0
80 | min1_pos = nextd +1
81 | elif nextd == Dmax-1:
82 | min1 = next_Lr[nextd] + last_Lr[nextd-1] + p1p2_r[i,p1_right_ptr]
83 | last_Lrp2[nextd-1] = 1000.0
84 | min1_pos = nextd-1
85 | else:
86 | min1u = next_Lr[nextd] + last_Lr[nextd+1] + p1p2_r[i,p1_right_ptr]
87 | min1d = next_Lr[nextd] + last_Lr[nextd-1] + p1p2_r[i,p1_right_ptr]
88 | if(min1u < min1d):
89 | min1 = min1u
90 | last_Lrp2[nextd+1] = 1000.0
91 | min1_pos = nextd+1
92 | else:
93 | min1 = min1d
94 | last_Lrp2[nextd-1] = 1000.0
95 | min1_pos = nextd-1
96 |
97 | min2_pos = torch.argmin(last_Lrp2,dim=0).item()
98 | min2 = last_Lr[min2_pos] + next_Lr[nextd] + p1p2_r[i,p2_right_ptr]
99 |
100 | if(min1 < min0 and min1 < min2):
101 | next_Lr[nextd] = min1 - minlast
102 | grad[nextd] = tmp_grad[min1_pos]
103 | grad[nextd,i] = 1
104 | elif(min2 < min0 and min2 < min1):
105 | next_Lr[nextd] = min2 - minlast
106 | grad[nextd] = tmp_grad[min2_pos]
107 | grad[nextd,i] = 2
108 | else:
109 | next_Lr[nextd] = min0 - minlast
110 | grad[nextd] = tmp_grad[nextd]
111 | last_Lr = next_Lr
112 | return last_Lr,grad
113 |
114 |
115 | def down_aggregation(d_cost_d,p1p2_d , p1_down_ptr,p2_down_ptr):
116 | Dmax, num_pix = d_cost_d.size()
117 | last_Lr=d_cost_d[:,num_pix-1]
118 | grad = torch.zeros(Dmax,num_pix)
119 |
120 | for i in range(num_pix-2,-1,-1):
121 | minlast=torch.min(last_Lr)
122 | next_Lr=d_cost_d[:,i]
123 |
124 | tmp_grad = grad.clone()
125 | for nextd in range(0,Dmax):
126 | last_Lrp2 = last_Lr.clone()
127 | min0 = next_Lr[nextd] + last_Lr[nextd]
128 | last_Lrp2[nextd] = 1000.0
129 |
130 | min1_pos = 0
131 | if nextd == 0:
132 | min1 = next_Lr[nextd] + last_Lr[nextd+1] + p1p2_d[i,p1_down_ptr]
133 | last_Lrp2[nextd+1] = 1000.0
134 | min1_pos = nextd+1
135 | elif nextd == Dmax-1:
136 | min1 = next_Lr[nextd] + last_Lr[nextd-1] + p1p2_d[i,p1_down_ptr]
137 | last_Lrp2[nextd-1] = 1000.0
138 | min1_pos = nextd-1
139 | else:
140 | min1u = next_Lr[nextd] + last_Lr[nextd+1] + p1p2_d[i,p1_down_ptr]
141 | min1d = next_Lr[nextd] + last_Lr[nextd-1] + p1p2_d[i,p1_down_ptr]
142 | if(min1u < min1d):
143 | min1 = min1u
144 | last_Lrp2[nextd+1] = 1000.0
145 | min1_pos = nextd+1
146 | else:
147 | min1 = min1d
148 | last_Lrp2[nextd-1] = 1000.0
149 | min1_pos = nextd-1
150 |
151 | min2_pos = torch.argmin(last_Lrp2,dim=0).item()
152 | min2 = last_Lr[min2_pos] + next_Lr[nextd] + p1p2_d[i,p2_down_ptr]
153 |
154 | if(min1 < min0 and min1 < min2):
155 | next_Lr[nextd] = min1 - minlast
156 | grad[nextd] = tmp_grad[min1_pos]
157 | grad[nextd,i] = 1
158 | elif(min2 < min0 and min2 < min1):
159 | next_Lr[nextd] = min2 - minlast
160 | grad[nextd] = tmp_grad[min2_pos]
161 | grad[nextd,i] = 2
162 | else:
163 | next_Lr[nextd] = min0 - minlast
164 | grad[nextd] = tmp_grad[nextd]
165 | last_Lr = next_Lr
166 | return last_Lr,grad
167 |
168 | def up_aggregation(d_cost_u,p1p2_u, p1_up_ptr,p2_up_ptr):
169 | Dmax, num_pix = d_cost_u.size()
170 | last_Lr=d_cost_u[:,0]
171 | grad = torch.zeros(Dmax,num_pix)
172 |
173 | for i in range(1,num_pix):
174 | minlast=torch.min(last_Lr)
175 | next_Lr=d_cost_u[:,i]
176 |
177 | tmp_grad = grad.clone()
178 | for nextd in range(0,Dmax):
179 | last_Lrp2 = last_Lr.clone()
180 | min0 = next_Lr[nextd] + last_Lr[nextd]
181 | last_Lrp2[nextd] = 1000.0
182 |
183 | min1_pos = 0
184 | if nextd == 0:
185 | min1 = next_Lr[nextd] + last_Lr[nextd+1] + p1p2_u[i-1,p1_up_ptr]
186 | last_Lrp2[nextd+1] = 1000.0
187 | min1_pos = nextd+1
188 | elif nextd == Dmax-1:
189 | min1 = next_Lr[nextd] + last_Lr[nextd-1] + p1p2_u[i-1,p1_up_ptr]
190 | last_Lrp2[nextd-1] = 1000.0
191 | min1_pos = nextd-1
192 | else:
193 | min1u = next_Lr[nextd] + last_Lr[nextd+1] + p1p2_u[i-1,p1_up_ptr]
194 | min1d = next_Lr[nextd] + last_Lr[nextd-1] + p1p2_u[i-1,p1_up_ptr]
195 | if(min1u < min1d):
196 | min1 = min1u
197 | last_Lrp2[nextd+1] = 1000.0
198 | min1_pos = nextd +1
199 | else:
200 | min1 = min1d
201 | last_Lrp2[nextd-1] = 1000.0
202 | min1_pos = nextd-1
203 |
204 | min2_pos = torch.argmin(last_Lrp2,dim=0).item()
205 | min2 = last_Lr[min2_pos] + next_Lr[nextd] + p1p2_u[i-1,p2_up_ptr]
206 |
207 | if(min1 < min0 and min1 < min2):
208 | next_Lr[nextd] = min1 - minlast
209 | grad[nextd] = tmp_grad[min1_pos]
210 | grad[nextd,i] = 1
211 | elif(min2 < min0 and min2 < min1):
212 | next_Lr[nextd] = min2 - minlast
213 | grad[nextd] = tmp_grad[min2_pos]
214 | grad[nextd,i] = 2
215 | else:
216 | next_Lr[nextd] = min0 - minlast
217 | grad[nextd] = tmp_grad[nextd]
218 | last_Lr = next_Lr
219 | return last_Lr,grad
220 |
221 | def pathloss(x,y,disp_val,d_cost_l,d_cost_r,d_cost_u,d_cost_d,p1p2_l,p1p2_r,p1p2_u,p1p2_d):
222 | p1_left_ptr,p2_left_ptr = 0,1
223 | p1_right_ptr,p2_right_ptr = 2,3
224 | p1_up_ptr,p2_up_ptr = 4,5
225 | p1_down_ptr,p2_down_ptr = 6,7
226 |     m = 5.0   # hinge-loss margin
227 | disp_gt = disp_val
228 |
229 | left_aggr,left_grad = left_aggregation(d_cost_l, p1p2_l , p1_left_ptr,p2_left_ptr)
230 | right_aggr,right_grad = right_aggregation(d_cost_r, p1p2_r, p1_right_ptr,p2_right_ptr)
231 | up_aggr,up_grad= up_aggregation(d_cost_u,p1p2_u ,p1_up_ptr,p2_up_ptr)
232 | down_aggr,down_grad= down_aggregation(d_cost_d,p1p2_d,p1_down_ptr,p2_down_ptr)
233 |
234 | # print(left_aggr)
235 | # print(right_aggr)
236 | # print(up_aggr)
237 | # print(down_aggr)
238 | path_grad = torch.zeros(1,8)
239 | loss = 0
240 | left_cost_gt = left_aggr[disp_gt]
241 | num1 = (left_grad[disp_gt] == 1).nonzero().size()[0]
242 | num2 = (left_grad[disp_gt] == 2).nonzero().size()[0]
243 | for ind in range(left_aggr.size()[0]):
244 | if(ind == disp_gt):
245 | continue
246 | else:
247 | if((left_cost_gt - left_aggr[ind] +m)>0):
248 | loss = loss + left_cost_gt - left_aggr[ind] +m
249 | # loss = loss + left_aggr[ind] - left_cost_gt
250 | path_grad[0,p1_left_ptr] = path_grad[0,p1_left_ptr] + num1 - (left_grad[ind] == 1).nonzero().size()[0]
251 | path_grad[0,p2_left_ptr] = path_grad[0,p2_left_ptr] + num2 - (left_grad[ind] == 2).nonzero().size()[0]
252 |
253 | right_cost_gt = right_aggr[disp_gt]
254 | num1 = (right_grad[disp_gt] == 1).nonzero().size()[0]
255 | num2 = (right_grad[disp_gt] == 2).nonzero().size()[0]
256 | for ind in range(right_aggr.size()[0]):
257 | if(ind == disp_gt):
258 | continue
259 | else:
260 | if((right_cost_gt - right_aggr[ind] +m)>0):
261 | loss = loss + right_cost_gt - right_aggr[ind] +m
262 | # loss = loss + right_aggr[ind] - right_cost_gt
263 | path_grad[0,p1_right_ptr] = path_grad[0,p1_right_ptr] + num1 - (right_grad[ind] == 1).nonzero().size()[0]
264 | path_grad[0,p2_right_ptr] = path_grad[0,p2_right_ptr] + num2 - (right_grad[ind] == 2).nonzero().size()[0]
265 |
266 | up_cost_gt = up_aggr[disp_gt]
267 | num1 = (up_grad[disp_gt] == 1).nonzero().size()[0]
268 | num2 = (up_grad[disp_gt] == 2).nonzero().size()[0]
269 | for ind in range(up_aggr.size()[0]):
270 | if(ind == disp_gt):
271 | continue
272 | else:
273 | if((up_cost_gt - up_aggr[ind] +m)>0):
274 | loss = loss + up_cost_gt - up_aggr[ind] +m
275 | # loss = loss + up_aggr[ind] - up_cost_gt
276 | path_grad[0,p1_up_ptr] = path_grad[0,p1_up_ptr] + num1 - (up_grad[ind] == 1).nonzero().size()[0]
277 | path_grad[0,p2_up_ptr] = path_grad[0,p2_up_ptr] + num2 - (up_grad[ind] == 2).nonzero().size()[0]
278 |
279 | down_cost_gt = down_aggr[disp_gt]
280 | num1 = (down_grad[disp_gt] == 1).nonzero().size()[0]
281 | num2 = (down_grad[disp_gt] == 2).nonzero().size()[0]
282 | for ind in range(down_aggr.size()[0]):
283 | if(ind == disp_gt):
284 | continue
285 | else:
286 | if((down_cost_gt - down_aggr[ind] +m)>0):
287 | loss = loss + down_cost_gt - down_aggr[ind] +m
288 | # loss = loss + down_aggr[ind] - down_cost_gt
289 | path_grad[0,p1_down_ptr] = path_grad[0,p1_down_ptr] + num1 - (down_grad[ind] == 1).nonzero().size()[0]
290 | path_grad[0,p2_down_ptr] = path_grad[0,p2_down_ptr] + num2 - (down_grad[ind] == 2).nonzero().size()[0]
291 |
292 | left_min = torch.argmin(left_aggr,dim=0).item()
293 | right_min = torch.argmin(right_aggr,dim=0).item()
294 | up_min = torch.argmin(up_aggr,dim=0).item()
295 | down_min = torch.argmin(down_aggr,dim=0).item()
296 | disp_abs = abs(left_min-disp_gt) + abs(right_min-disp_gt) + abs(up_min-disp_gt) + abs(down_min-disp_gt)
297 |     # outlier restraint: if every path's winner is far from the ground truth, skip this sample
298 | if(abs(left_min-disp_gt)>=10 and abs(right_min-disp_gt)>=10 and abs(up_min-disp_gt)>=10 and abs(down_min-disp_gt)>=10):
299 | path_grad = torch.zeros(1,8)
300 | loss = 0
301 | print('x: {}\t y: {}\t gt: {}\t left: {}\t right: {}\t up: {}\t down: {}\t disp_abs: {}'.format(x, y, disp_gt, left_min,right_min,up_min,down_min,disp_abs))
302 | return loss,path_grad
303 |
304 |
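305 | # Note on the aggregation routines above: each runs the standard SGM recurrence along
306 | # one path direction,
307 | #   L_r(p, d) = C(p, d) + min( L_r(p-r, d),
308 | #                              L_r(p-r, d +/- 1) + P1,
309 | #                              min_k L_r(p-r, k) + P2 ) - min_k L_r(p-r, k),
310 | # with P1/P2 taken per pixel and direction from the network output. grad[d, i] records
311 | # which penalty the best path to disparity d used at pixel i (1 -> P1, 2 -> P2, 0 -> none);
312 | # pathloss() counts those occurrences to build the manual gradient of the hinge loss
313 | # with respect to the eight P1/P2 outputs.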
--------------------------------------------------------------------------------
/pair_load.py:
--------------------------------------------------------------------------------
1 | import cv2
2 | import random
3 | import numpy as np
4 | import torch
5 | from readpfm import readPFM
6 | import sys
7 | import os
8 | import json
9 |
10 | data_root = '/home/wsy/datasets/training'
11 | imglist = '/home/wsy/datasets/training/image_3/lists.txt'
12 | save_root = '.'
13 |
14 | train_file = open(imglist, 'r')
15 | image_lists = []
16 | for line in train_file:
17 | line = line.strip('\n')
18 | image_lists.append(line)
19 |
20 | pair_num = 0
21 | lable = []
22 |
23 | for image_num in range(len(image_lists)):
24 | image_name = image_lists[image_num]
25 |
26 | print('image', image_name, image_num)
27 |
28 | left_path = data_root + '/image_2/' + image_name + '.png'
29 | left_image = cv2.imread(left_path)
30 | # print(left_image.shape)
31 |
32 | right_path = data_root + '/image_3/' + image_name + '.png'
33 | right_image = cv2.imread(right_path)
34 |
35 | disp_path = data_root + '/disp_occ_0/' + image_name + '.png'
36 | disp_image = cv2.imread(disp_path, 0)
37 |
38 |     for res in range(7, 27, 2):
39 |         res_path = os.path.join(save_root, 'res_' + str(res))
40 |         if not os.path.exists(res_path):
41 |             os.makedirs(res_path)
42 |         res_path = os.path.join(save_root, 'res_' + str((res - 1) * 2 + 1))
43 |         if not os.path.exists(res_path):
44 |             os.makedirs(res_path)
45 |
46 | max_pixel_l = np.max(left_image)
47 | max_pixel_r = np.max(right_image)
48 | img_h, img_w, cha = left_image.shape
49 |
50 | vaid_points = []
51 |
52 | for i in range(200, img_w - 145):
53 | for j in range(30, img_h - 30):
54 | disp_val = disp_image[j, i]
55 | if (disp_val > 0 and disp_val < 127):
56 | vaid_points.append([i, j, disp_val])
57 |
58 | random.shuffle(vaid_points)
59 | print(len(vaid_points))
60 | vaid_points = random.sample(vaid_points, 2000)
61 | print(len(vaid_points))
62 | max_disp = 128
63 |
64 | for k in range(0, len(vaid_points)):
65 | point = vaid_points[k]
66 | x = point[0]
67 | y = point[1]
68 | disp_val = point[2]
69 |
70 | target_left_border = x - max_disp
71 |
72 | for res in range(7, 27, 2):
73 |             half_window = (res - 1) // 2   # integer division so the slice indices below stay ints
74 | y0 = y - half_window
75 | y1 = y + half_window + 1
76 |
77 | left_ori = left_image[y0:y1, x - half_window:x + half_window + 1]
78 | right_strip = right_image[y0:y1, target_left_border - half_window:x + half_window + 1]
79 |
80 | res_path = os.path.join(save_root, 'res_' + str(res))
81 | left_path = os.path.join(res_path,'left_'+str(pair_num) + '.jpg')
82 | right_path = os.path.join(res_path, 'right_' + str(pair_num) + '.jpg')
83 | cv2.imwrite(left_path, left_ori)
84 | cv2.imwrite(right_path, right_strip)
85 |             # positive sample first
86 | lable.append(int(1))
87 | # print(pair_num)
88 | pair_num = pair_num + 1
89 |             # then the negative sample
90 | lable.append(int(0))
91 | # print(pair_num)
92 | pair_num = pair_num + 1
93 | print(pair_num)
94 |
95 | # SceneFlow dataset (alternative data source, kept commented out)
96 | # sc_data_root = '/home/wsy/datasets/sceneflow_driving/frames_cleanpass/35mm_focallength/scene_forwards/fast'
97 | # sc_imglist = '/home/wsy/datasets/sceneflow_driving/frames_cleanpass/35mm_focallength/scene_forwards/fast/left/lists.txt'
98 |
99 | # sc_train_file = open(sc_imglist,'r')
100 | # sc_image_lists = []
101 | # for line in sc_train_file:
102 | # line = line.strip('\n')
103 | # sc_image_lists.append(line)
104 |
105 | #
106 | # for image_num in range(len(sc_image_lists)):
107 | # image_name = sc_image_lists[image_num]
108 | #
109 | # print('image',image_name,image_num)
110 | #
111 | # left_path = sc_data_root + '/left/' + image_name + '.png'
112 | # left_image = cv2.imread(left_path)
113 |
114 | # right_path = sc_data_root + '/right/' + image_name + '.png'
115 | # right_image = cv2.imread(right_path)
116 |
117 | # disp_path = '/home/wsy/datasets/sceneflow_driving/disparity/35mm_focallength/scene_forwards/fast/left/' + image_name + '.pfm'
118 | # disp_image, _ = readPFM(disp_path)
119 | # disp_image = disp_image.astype(np.int32)
120 | #
121 | # max_pixel_l = np.max(left_image)
122 | # max_pixel_r = np.max(right_image)
123 | # img_h,img_w,cha = left_image.shape
124 | #
125 | # vaid_points1 = []
126 | #
127 | # for i in range(200,img_w-145):
128 | # for j in range(15, img_h-15):
129 | # disp_val = disp_image[j,i]
130 | # if(disp_val>0 and disp_val < 5):
131 | # vaid_points1.append([i,j,disp_val])
132 | #
133 | # random.shuffle(vaid_points1)
134 | # print(len(vaid_points1))
135 | # if len(vaid_points1) >= 300:
136 | # vaid_points1 = random.sample(vaid_points1, 300)
137 | # print(len(vaid_points1))
138 | # points2_num = 400-len(vaid_points1)
139 | #
140 | # vaid_points2 = []
141 | # for i in range(200,img_w-145):
142 | # for j in range(15, img_h-15):
143 | # disp_val = disp_image[j,i]
144 | # if(disp_val>0 and disp_val < 70):
145 | # vaid_points2.append([i,j,disp_val])
146 | #
147 | # random.shuffle(vaid_points2)
148 | # print(len(vaid_points2))
149 | # vaid_points2 = random.sample(vaid_points2, points2_num)
150 | # print(len(vaid_points2))
151 | #
152 | # vaid_points = vaid_points1 + vaid_points2
153 | # print(len(vaid_points))
154 | #
155 | #
156 | # for k in range(0,len(vaid_points)):
157 | #
158 | # point = vaid_points[k]
159 | # x = point[0]
160 | # y = point[1]
161 | # disp_val = point[2]
162 | # pos =[n for n in range(0,1)]
163 | # opos = random.sample(pos,1)[0]
164 | # xpos = x-disp_val +opos
165 | #
166 | #
167 | # neg =[n for n in range(1,5)] + [n for n in range(-4,0)]
168 | # oneg = random.sample(neg,1)[0]
169 | # xneg = x-disp_val + oneg
170 | # #13*13
171 | # y0 = y-6
172 | # y1 = y+7
173 | # left_ori = left_image[y0:y1, x-6:x+7]
174 | # right_ori_pos = right_image[y0:y1, xpos-6:xpos+7]
175 | # right_ori_neg = right_image[y0:y1, xneg-6:xneg+7]
176 | # #27*27
177 | # y0 = y-13
178 | # y1 = y+14
179 | # left_downsam = left_image[y0:y1, x-13:x+14]
180 | # right_downsam_pos = right_image[y0:y1, xpos-13:xpos+14]
181 | # right_downsam_neg = right_image[y0:y1, xneg-13:xneg+14]
182 | #
183 | # # positive sample first
184 | # cv2.imwrite('./left_origin/'+str(pair_num)+'left_ori.png',left_ori)
185 | # cv2.imwrite('./right_origin/'+str(pair_num)+'right_origin.png',right_ori_pos)
186 | # cv2.imwrite('./left_downsample/'+str(pair_num)+'left_downsample.png',left_downsam)
187 | # cv2.imwrite('./right_downsample/'+str(pair_num)+'right_downsample.png',right_downsam_pos)
188 | # lable.append(int(1))
189 | ## print(pair_num)
190 | # pair_num = pair_num + 1
191 | # # then the negative sample
192 | # cv2.imwrite('./left_origin/'+str(pair_num)+'left_ori.png',left_ori)
193 | # cv2.imwrite('./right_origin/'+str(pair_num)+'right_origin.png',right_ori_neg)
194 | # cv2.imwrite('./left_downsample/'+str(pair_num)+'left_downsample.png',left_downsam)
195 | # cv2.imwrite('./right_downsample/'+str(pair_num)+'right_downsample.png',right_downsam_neg)
196 | # lable.append(int(0))
197 | ## print(pair_num)
198 | # pair_num = pair_num + 1
199 | # print(pair_num)
200 | # lable = np.array(lable)
201 | # np.save('./lable/lable.npy',lable)
202 |
--------------------------------------------------------------------------------
/readpfm.py:
--------------------------------------------------------------------------------
1 | import re
2 | import numpy as np
3 | import sys
4 |
5 |
6 | def readPFM(file):
7 | file = open(file, 'rb')
8 |
9 | color = None
10 | width = None
11 | height = None
12 | scale = None
13 | endian = None
14 |
15 | header = file.readline().decode('utf-8').rstrip()
16 |
17 | if header == 'PF':
18 | color = True
19 | elif header == 'Pf':
20 | color = False
21 | else:
22 | raise Exception('Not a PFM file.')
23 | dim_match = re.match(r'^(\d+)\s(\d+)\s$', file.readline().decode('utf-8'))
24 | if dim_match:
25 | width, height = map(int, dim_match.groups())
26 | else:
27 | raise Exception('Malformed PFM header.')
28 |
29 | scale = float(file.readline().rstrip())
30 | if scale < 0: # little-endian
31 | endian = '<'
32 | scale = -scale
33 | else:
34 | endian = '>' # big-endian
35 |
36 | data = np.fromfile(file, endian + 'f')
37 | shape = (height, width, 3) if color else (height, width)
38 |
39 | data = np.reshape(data, shape)
40 | data = np.flipud(data)
41 | return data, scale
42 |
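43 | if __name__ == '__main__':
44 |     # quick check -- illustrative; point it at any SceneFlow .pfm disparity file,
45 |     # e.g. the one dataloader.py reads
46 |     disp, scale = readPFM('data/driving/disp/0350.pfm')
47 |     print(disp.shape, disp.dtype, scale)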
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | opencv-python
2 | torch==0.4.1
3 | tensorboardX
4 |
--------------------------------------------------------------------------------
/save2ctype.cpp:
--------------------------------------------------------------------------------
1 | #include <iostream>   // the original include targets were stripped from this dump;
2 | #include <fstream>    // restored here with a minimal plausible set
3 | #include <string>
4 | #include <cstdio>
5 |
6 | using namespace std;
7 | extern "C" {
8 | void show_matrix(float *matrix, int lens,int img_name)
9 |
10 | {
11 | cout<
[... the remainder of save2ctype.cpp and the whole of /test.py did not survive in this dump ...]
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
88 | if(loss>0):
89 | cout_valid = cout_valid + 1
90 | grad_all.append(grad)
91 |
92 | if(cout_valid == 0):
93 | trained_num = trained_num + batch_size
94 | continue
95 |
96 | loss_sum = 1.0*loss_sum/cout_valid
97 |
98 | patch_x0 = torch.cat(patch_x0, dim=0)
99 | cod_x0 = torch.cat(cod_x0, dim=0)
100 | grad_all = torch.cat(grad_all, dim=0)
101 |
102 | patch_x0 = patch_x0.to(device)
103 | cod_x0 = cod_x0.to(device)
104 | grad_all = grad_all.to(device)
105 |
106 | #run one forward pass on the collected patches so the hand-built grads can flow into the network
107 | p1p2_x0s = model(patch_x0, cod_x0)
108 | ### backward: inject the manual [n, 8] gradient directly into autograd
109 | optimizer.zero_grad()
110 | p1p2_x0s.backward(grad_all)
111 |
112 | optimizer.step()
113 |
114 | print(p1p2_x0s)
115 | print('Iter_num: {}\t Loss: {}\t Grad: {}'.format(epoch*iter_num + iter_num,loss_sum,grad_all))
116 | writer.add_scalar('loss',loss_sum,epoch*iter_num + iter_num)
117 | if iter_num%10 == 0:
118 | torch.save(model.state_dict(), 'work_sapce/SGMNet.pkl')
119 |
120 | trained_num = trained_num + batch_size
121 |
122 | torch.save(model.state_dict(), 'work_sapce/SGMNet_fin.pkl')
123 |
124 |
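125 | # Flow of the surviving loop above: for each batch of valid points, pathloss() returns a
126 | # scalar hinge loss and a hand-built [1, 8] gradient per sample; the gradients of the
127 | # valid samples are concatenated and pushed through p1p2_x0s.backward(grad_all), so the
128 | # optimizer updates the network from the manually derived path-cost gradients.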
--------------------------------------------------------------------------------
/use_result/load_p1p2_demo.cpp:
--------------------------------------------------------------------------------
1 | #include <iostream>       // the original include targets were stripped from this dump;
2 | #include <fstream>        // restored here with a minimal plausible set for the demo
3 | #include <string>
4 | #include <opencv2/opencv.hpp>
11 |
12 | using namespace std;
13 | using namespace cv;
14 |
15 | static float *params;
16 |
17 | int main()
18 | {
19 |     Mat img = imread("the path of left image", 0);   // placeholder path, like the param file below
20 |     params = new float[img.rows*img.cols*8];
21 |     std::ifstream ifs("the path of param txt", std::ios::binary | std::ios::in);
22 |     ifs.read((char*)params, sizeof(float) * img.rows*img.cols*8);
23 |
24 |     //the structure of the p1p2 params, if set manually instead, can be:
25 |     // for(int i=0; i< width*height*8 ;i +=8 )
26 |     // {
27 |     //     /// down to up
28 |     //     params[i]   = 5.0;
29 |     //     params[i+1] = 80.0;
30 |     //     /// left to right
31 |     //     params[i+2] = 5.0;
32 |     //     params[i+3] = 80.0;
33 |     //     /// up to down
34 |     //     params[i+4] = 5.0;
35 |     //     params[i+5] = 80.0;
36 |     //     /// right to left
37 |     //     params[i+6] = 5.0;
38 |     //     params[i+7] = 80.0;
39 |     // }
40 |
41 |     delete[] params;
42 |     return 0;
43 | }
--------------------------------------------------------------------------------