├── .gitignore
├── LICENSE
├── README.md
├── datagen
│   ├── decoder.py
│   └── gen_data.py
├── figs
│   ├── Network.png
│   └── sample_data.jpg
├── model
│   └── train_fcn_pytorch.py
├── requirements.txt
├── setup_pytorch.sh
├── setup_virtualenv.sh
└── start_jupyter_env.sh
/.gitignore:
--------------------------------------------------------------------------------
1 | *.swp
2 | *.pyc
3 | .env/*
4 | data/*.hdf5
5 | model/data/*.hdf5
6 | model/training/*/saved_model_params
7 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2017 Sambhav R. Jain
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Deep-Route
2 |
3 | This repository contains the PyTorch implementation (with pretrained weights provided) and the dataset generation code for the paper
4 |
5 | **[Training a Fully Convolutional Neural Network to Route Integrated Circuits](https://arxiv.org/abs/1706.08948)**
6 |
7 | [Sambhav R. Jain](https://bit.ly/sjain-stanford), Kye Okabe
8 |
9 | arXiv cs.CV (Computer Vision and Pattern Recognition), 2017
10 |
11 | We design and train a deep, fully convolutional neural network that learns to route a circuit layout net with an appropriate choice of metal tracks and wire class combinations. Inputs to the network are encoded layouts containing the spatial locations of the pins to be routed. After 15 fully convolutional stages followed by a score comparator, the network outputs 8 layout layers (4 route layers, 3 via layers, and an identity-mapped pin layer), which are then decoded to obtain the routed layouts.
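
The score comparator simply picks, per pixel and per layout layer, the larger of the network's two (off/on) scores. A minimal runnable sketch of this decision rule, mirroring `check_accuracy` in `model/train_fcn_pytorch.py` (the random tensor stands in for the network output):

```python
import torch

scores = torch.randn(1, 16, 32, 32)   # stand-in for the network output
scores_off = scores[:, 0:16:2, :, :]  # even channels: 'off' scores, one per layer
scores_on = scores[:, 1:16:2, :, :]   # odd channels: 'on' scores
y_pred = (scores_on > scores_off)     # (1, 8, 32, 32) binary layout layers
```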
12 |
13 | **Proposed FCN Model**
14 |
15 | ![Proposed FCN model](./figs/Network.png)
16 | **Training samples (left: data, right: labels) from the generated dataset**
17 |
18 | ![Training samples (left: data, right: labels)](./figs/sample_data.jpg)
19 | ## Install (Linux)
20 | 1. Fork [this GitHub repository](https://github.com/sjain-stanford/deep-route)
21 | 2. Setup virtualenv and install dependencies
22 | * `./setup_virtualenv.sh`
23 | 3. Install PyTorch
24 | * `./setup_pytorch.sh`
25 | 4. Activate virtualenv, start Jupyter notebook
26 | * `./start_jupyter_env.sh`
27 |
28 | ## Generate Dataset
29 | Run the script `./datagen/gen_data.py` to generate training data of shape (N, 1, H, W) and labels of shape (N, 8, H, W), stored in [HDF5 (h5py)](https://github.com/h5py/h5py) format. The default parameters used for the paper are `H = W = 32` and `pin_range = (2, 6)`, but feel free to modify them as desired. Generating a 50,000-image dataset should take under a minute.
30 | ```
31 | python ./datagen/gen_data.py
32 | >> Enter the number of images to be generated: 50000
33 | mv ./data/layout_data.hdf5 ./model/data/train_50k_32pix.hdf5
34 |
35 | python ./datagen/gen_data.py
36 | >> Enter the number of images to be generated: 10000
37 | mv ./data/layout_data.hdf5 ./model/data/val_10k_32pix.hdf5
38 | ```
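
To sanity-check a generated file, the `X` and `Y` datasets can be read back with h5py (a minimal sketch, assuming the default `H = W = 32` and the filenames used above):

```python
import h5py

with h5py.File('./model/data/train_50k_32pix.hdf5', 'r') as f:
    X = f['X'][:]  # (50000, 1, 32, 32), uint8
    Y = f['Y'][:]  # (50000, 8, 32, 32), uint8
print(X.shape, Y.shape)
```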
39 |
40 | ## Train FCN Network (in PyTorch)
41 | Switch to the `./model/` directory and run the script `./train_fcn_pytorch.py` to train the FCN model with default options, or pass `--help` to display the list of options and their defaults.
42 | ```
43 | cd ./model/
44 | python ./train_fcn_pytorch.py --help
45 | ```
46 |
47 | ```
48 | usage: train_fcn_pytorch.py [-h] [--data PATH] [--batch_size N]
49 | [--num_workers N] [--num_epochs N] [--use_gpu]
50 | [--pretrained] [--lr LR] [--adapt_lr] [--reg REG]
51 | [--print-freq N]
52 |
53 | Deep-Route: Training a deep FCN network to route circuit layouts.
54 |
55 | optional arguments:
56 | -h, --help show this help message and exit
57 | --data PATH path to dataset (default: ./data/)
58 | --batch_size N mini-batch size (default: 100)
59 | --num_workers N number of data loading workers (default: 4)
60 | --num_epochs N number of total epochs to run (default: 200)
61 | --use_gpu use GPU if available
62 | --pretrained use pre-trained model
63 | --lr LR initial learning rate (default: 5e-4)
64 | --adapt_lr use learning rate schedule
65 | --reg REG regularization strength (default: 1e-5)
66 | --print-freq N print frequency (default: 10)
67 | ```
68 |
69 | To run on GPU, pass the switch `--use_gpu`. The best model parameters (selected by F1 score on the validation set) are saved to the `./model/training/` directory after every epoch, along with loss and training curves. If `--pretrained` is passed, the model is initialized from saved parameters before training; pretrained weights (for batch sizes 10 and 100) are available [here](https://github.com/sjain-stanford/pretrained-weights). With `--adapt_lr`, the learning rate is decayed by a factor of 10 every 30 epochs.
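
For example, a GPU run with the learning rate schedule enabled (assuming the datasets generated above are in `./model/data/`):

```
cd ./model/
python ./train_fcn_pytorch.py --use_gpu --adapt_lr
```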
70 |
71 | ## Cite
72 | If you find this work useful in your research, please cite:
73 | ```
74 | @article{jain2017route,
75 | title={Training a Fully Convolutional Neural Network to Route Integrated Circuits},
76 | author={Jain, Sambhav R and Okabe, Kye},
77 | journal={arXiv preprint arXiv:1706.08948},
78 | year={2017}
79 | }
80 | ```
81 |
--------------------------------------------------------------------------------
/datagen/decoder.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import os
3 | import h5py
4 |
5 | import matplotlib
6 | matplotlib.use('Agg')
7 | import matplotlib.pyplot as plt
8 |
9 |
10 | def decodeData(inp):
11 | """
12 | Generates decoded (RGB) images from 8-channel label data or 1-channel pin data
13 |
14 | Input:
15 | - inp: data of shape (N, C, H, W) (range: 0 or 1)
16 | where C = 8 (Y) or 1 (X)
17 |
18 | Output:
19 | - out: decoded (RGB) data (N, 3, H, W) (range: 0-255)
20 | """
21 |
22 | # Define RGB colors (http://www.rapidtables.com/web/color/RGB_Color.htm)
23 | M6_clr = np.array([[ 0.0, 191.0, 255.0]]) # layer 7 m6 dodger blue
24 | VIA5_clr = np.array([[255.0, 0.0, 0.0]]) # layer 6 via5 red
25 | M5_clr = np.array([[169.0, 169.0, 169.0]]) # layer 5 m5 dark gray
26 | VIA4_clr = np.array([[ 0.0, 255.0, 0.0]]) # layer 4 via4 green
27 | M4_clr = np.array([[255.0, 99.0, 71.0]]) # layer 3 m4 tomato red
28 | VIA3_clr = np.array([[ 0.0, 0.0, 255.0]]) # layer 2 via3 blue
29 | M3_clr = np.array([[ 50.0, 205.0, 50.0]]) # layer 1 m3 lime green
30 | PIN_clr = np.array([[ 0.0, 0.0, 0.0]]) # layer 0 pin black
31 | BGND_clr = np.array([[255.0, 255.0, 255.0]]) # background white
32 |
33 | # If inp is 1 image, reshape
34 | if (len(inp.shape) == 3):
35 | C, H, W = inp.shape
36 | inp = inp.reshape(1, C, H, W)
37 |
38 | N, C, H, W = inp.shape
39 | D = N * H * W
40 |
41 | # Warn if any input elements are out of range
42 | if np.any(np.greater(inp, 1)):
43 | print("Invalid input range detected (some input elements > 1)")
44 | if np.any(np.less(inp, 0)):
45 | print("Invalid input range detected (some input elements < 0)")
46 |
47 | # Initialize the RGB output accumulator
48 | out = np.zeros([N, 3, H, W])
49 |
50 | inp_swap = np.swapaxes(inp, 1, 0).reshape(C, D) # (C, N*H*W)
51 | out_swap = np.swapaxes(out, 1, 0).reshape(3, D) # (3, N*H*W)
52 |
53 | if C != 1:
54 | #layer 7 processing
55 | temp7 = np.broadcast_to((inp_swap[7, :] == 1), (3, D))
56 | out_swap += M6_clr.T * temp7
57 |
58 | #layer 6 processing
59 | temp6 = np.broadcast_to((inp_swap[6, :] == 1), (3, D))
60 | out_swap += VIA5_clr.T * temp6
61 |
62 | #layer 5 processing
63 | temp5 = np.broadcast_to((inp_swap[5, :] == 1), (3, D))
64 | out_swap += M5_clr.T * temp5
65 |
66 | #layer 4 processing
67 | temp4 = np.broadcast_to((inp_swap[4, :] == 1), (3, D))
68 | out_swap += VIA4_clr.T * temp4
69 |
70 | #layer 3 processing
71 | temp3 = np.broadcast_to((inp_swap[3, :] == 1), (3, D))
72 | out_swap += M4_clr.T * temp3
73 |
74 | #layer 2 processing
75 | temp2 = np.broadcast_to((inp_swap[2, :] == 1), (3, D))
76 | out_swap += VIA3_clr.T * temp2
77 |
78 | #layer 1 processing
79 | temp1 = np.broadcast_to((inp_swap[1, :] == 1), (3, D))
80 | out_swap += M3_clr.T * temp1
81 |
82 | #layer 0 processing
83 | temp0 = np.broadcast_to((inp_swap[0, :] == 1), (3, D))
84 | out_swap += PIN_clr.T * temp0
85 |
86 | # For every pixel, count the number of active layers that overlap
87 | overlap = 1.0 * np.sum(inp_swap, axis=0, keepdims=True)
88 | n_layers = overlap + (overlap == 0.0) * 1.0 # To avoid division by zero
89 |
90 | # If no active layers, fill in white
91 | temp_white = np.broadcast_to((overlap == 0), (3, D))
92 | out_swap += BGND_clr.T * temp_white
93 |
94 | # Average the colors for all active layers
95 | out_swap /= (n_layers * 1.0)
96 |
97 | out_swap = out_swap.reshape(3, N, H, W)
98 | out = np.swapaxes(out_swap, 1, 0) # (N, 3, H, W)
99 |
100 | # Warn if any output elements are out of range
101 | if np.any(np.greater(out, 255)):
102 | print("Invalid output range detected (some output elements > 255)")
103 | if np.any(np.less(out, 0)):
104 | print("Invalid output range detected (some output elements < 0)")
105 |
106 | return out
107 |
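# Example usage (toy input, hypothetical coordinates): an 8-channel label image
# with a single pin set decodes to a (1, 3, H, W) RGB image where that pixel is
# black (PIN_clr) and all others are white (BGND_clr):
#   img = np.zeros((8, 32, 32)); img[0, 5, 5] = 1
#   rgb = decodeData(img)  # rgb[0, :, 5, 5] == [0, 0, 0]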
108 |
109 | ############################
110 | # MAIN #
111 | ############################
112 | def main():
113 | plt.rcParams['figure.figsize'] = (15.0, 10.0) # set default size of plots
114 | plt.rcParams['image.interpolation'] = 'nearest'
115 | plt.rcParams['image.cmap'] = 'gray'
116 |
117 | data = {}
118 | data_file = os.getcwd() + "/data/layout_data.hdf5"
119 |
120 | #data = h5py.File(data_file, 'r')
121 | #X = np.array(data['X']) # (N, 1, H, W)
122 | #Y = np.array(data['Y']) # (N, C, H, W)
123 |
124 | with h5py.File(data_file, 'r') as f:
125 | for k, v in f.items():
126 | data[k] = np.asarray(v)
127 |
128 | X_dec = decodeData(data['X']) # (N, 3, H, W)
129 | X_dec = np.swapaxes(X_dec, 1, 2) # (N, H, 3, W)
130 | X_dec = np.swapaxes(X_dec, 2, 3) # (N, H, W, 3)
131 |
132 | Y_dec = decodeData(data['Y']) # (N, 3, H, W)
133 | Y_dec = np.swapaxes(Y_dec, 1, 2) # (N, H, 3, W)
134 | Y_dec = np.swapaxes(Y_dec, 2, 3) # (N, H, W, 3)
135 |
136 | num_images = 6
137 |
138 | for n in range(num_images):
139 | plt.subplot(2, 3, n+1)
140 | plt.imshow(X_dec[n].astype('uint8'))
141 | #plt.axis('off')
142 | plt.tight_layout(pad=2.0, w_pad=1.0, h_pad=1.0)
143 |
144 | plt.savefig("data/sample_X.jpg")
145 | plt.savefig("data/sample_X.eps")
146 |
147 | for n in range(num_images):
148 | plt.subplot(2, 3, n+1)
149 | plt.imshow(Y_dec[n].astype('uint8'))
150 | #plt.axis('off')
151 | plt.tight_layout(pad=2.0, w_pad=1.0, h_pad=1.0)
152 |
153 | plt.savefig("data/sample_Y.jpg")
154 | plt.savefig("data/sample_Y.eps")
155 |
156 |
157 | if __name__ == '__main__':
158 | main()
159 |
--------------------------------------------------------------------------------
/datagen/gen_data.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | import numpy as np
3 | import random
4 | import os
5 | import h5py
6 |
7 |
8 | def drawPins(input_data, xpins, ypins, map):
9 | """
10 | Draw pins on input_data
11 |
12 | Inputs:
13 | - input_data: np.array of size (C, H, W) (all zeros to start with)
14 | - xpins & ypins: np.array of x-coordinates and y-coordinates for all pins
15 | e.g., [x1, x2 ... xm] and [y1, y2, ... ym] for m pins
16 | - map: layout layer map (dictionary from layer name to index)
17 |
18 | Outputs:
19 | - output_data: np.array of size (C, H, W) (with pixels corresponding to pins in layer 'pins' set to '1')
20 | """
21 |
22 | output_data = input_data
23 |
24 | for (x, y) in zip(xpins, ypins):
25 | output_data[map['pin'], y, x] = 1
26 |
27 | return output_data
28 |
29 |
30 | def drawRoutes(input_data, xpins, ypins, xlen, ylen, xwire, ywire, map):
31 | """
32 | Draw routes on input_data
33 |
34 | Inputs:
35 | - input_data: np.array of size (C, H, W)
36 | - xpins & ypins: np.array of x-coordinates and y-coordinates for all pins
37 | e.g., [x1, x2 ... xm] and [y1, y2, ... ym] for m pins
38 | - xlen & ylen: horizontal and vertical max length (in um)
39 | - xwire & ywire: horizontal and vertical wire class to route; such as ('m4', 'm3')
40 | (either may be None when the pins are aligned and only one branch is needed)
41 | - map: layout layer map (dictionary from layer name to index)
42 |
43 | Outputs:
44 | - output_data: np.array of size (C, H, W)
45 | """
46 |
47 | output_data = input_data
48 |
49 | xavg = int(np.average(xpins))
50 | yavg = int(np.average(ypins))
51 |
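# The route is a single branch plus per-pin legs. For example, with 3 pins and
# xlen > ylen: a horizontal branch on xwire at y = yavg, and vertical legs on
# ywire from each pin to the branch:
#
#   p1        p2
#    |         |
#    +----+----+    <- branch at y = yavg
#         |
#         p3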
52 | if xlen > ylen:
53 | # Draw horizontal branch first
54 | xmin = min(xpins)
55 | xmax = max(xpins)
56 | output_data[map[xwire], yavg, xmin:xmax+1] = 1
57 |
58 | # Draw vertical legs from all pins to branch
59 | # Except in the special case of nPins=2 and pins being x-aligned, where only horizontal branch is needed
60 | if ywire is not None:
61 | for n in range(len(xpins)):
62 | if ypins[n] > yavg:
63 | output_data[map[ywire], yavg:ypins[n]+1, xpins[n]] = 1
64 | else:
65 | output_data[map[ywire], ypins[n]:yavg+1, xpins[n]] = 1
66 | else:
67 | # Draw vertical branch first
68 | ymin = min(ypins)
69 | ymax = max(ypins)
70 | output_data[map[ywire], ymin:ymax+1, xavg] = 1
71 |
72 | # Draw horizontal legs from all pins to branch
73 | # Except in the special case of nPins=2 and pins being y-aligned, where only vertical branch is needed
74 | if xwire is not None:
75 | for n in range(len(xpins)):
76 | if xpins[n] > xavg:
77 | output_data[map[xwire], ypins[n], xavg:xpins[n]+1] = 1
78 | else:
79 | output_data[map[xwire], ypins[n], xpins[n]:xavg+1] = 1
80 |
81 | # Draw vias
82 | output_data[map['via3']] = (output_data[map['m3']] == 1) * (output_data[map['m4']] == 1) * 1
83 | output_data[map['via4']] = (output_data[map['m4']] == 1) * (output_data[map['m5']] == 1) * 1
84 | output_data[map['via5']] = (output_data[map['m5']] == 1) * (output_data[map['m6']] == 1) * 1
85 |
86 | return output_data
87 |
88 |
89 | def selectWireClass(xlen, ylen):
90 | """
91 | Select the lowest-resistance wire classes for a route of given horizontal and vertical extent
92 |
93 | Inputs:
94 | - xlen & ylen: horizontal and vertical max length (in um)
95 |
96 | Outputs:
97 | - wire_class: tuple of (xwire, ywire); such as ('m4', 'm3')
98 | """
99 | # 7FF Wire Performance (from RCCalc)
100 | # R = rho * l / (t * w)
101 | # rho / (t * w) = R / l
102 | # Multiply rho_tw by length to get wire resistance, add via resistance to this
103 |
104 | # m2 1x w=20nm horz
105 | rho_tw_m2_1x = 106.45 # ohm/um
106 | # m3 2x w=40nm vert
107 | rho_tw_m3_2x = 30.555 # ohm/um
108 | # m4 1x w=40nm horz
109 | rho_tw_m4_1x = 24.077 # ohm/um
110 | # m5 1x w=38nm vert
111 | rho_tw_m5_1x = 18.368 # ohm/um
112 | # m6 1x w=40nm horz
113 | rho_tw_m6_1x = 16.950 # ohm/um
114 | # m7 1x w=76nm vert
115 | rho_tw_m7_1x = 8.854 # ohm/um
116 |
117 | # via resistance
118 | r_via2 = 40 # ohm/ct
119 | r_via3 = 30 # ohm/ct
120 | r_via4 = 12 # ohm/ct
121 | r_via5 = 12 # ohm/ct
122 | r_via6 = 12 # ohm/ct
123 |
124 | # Assumption - All routes run to/from two m2 pins
125 | R_m2 = (rho_tw_m2_1x * xlen)
126 | R_m3 = (rho_tw_m3_2x * ylen) + (r_via2 * 2)
127 | R_m4 = (rho_tw_m4_1x * xlen) + (r_via2 * 2) + (r_via3 * 2)
128 | R_m5 = (rho_tw_m5_1x * ylen) + (r_via2 * 2) + (r_via3 * 2) + (r_via4 * 2)
129 | R_m6 = (rho_tw_m6_1x * xlen) + (r_via2 * 2) + (r_via3 * 2) + (r_via4 * 2) + (r_via5 * 2)
130 | R_m7 = (rho_tw_m7_1x * ylen) + (r_via2 * 2) + (r_via3 * 2) + (r_via4 * 2) + (r_via5 * 2) + (r_via6 * 2)
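# Worked example with the constants above: for the horizontal candidates,
# R_m4 = 24.077*xlen + 140 ohm and R_m6 = 16.950*xlen + 188 ohm, so m6 becomes
# the lower-resistance choice beyond xlen = 48/7.127 ~ 6.7 um (roughly 20 px
# at the 11 um per 32 px scale used in genData below).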
131 |
132 | # Choose the less resistant wire class based on whether vert or horz length is dominant
133 | if xlen > ylen:
134 | if R_m4 > R_m6:
135 | xwire = 'm6'
136 | ywire = 'm5'
137 | else:
138 | xwire = 'm4'
139 | if R_m3 > R_m5:
140 | ywire = 'm5'
141 | else:
142 | ywire = 'm3'
143 | else:
144 | if R_m5 > R_m3:
145 | ywire = 'm3'
146 | xwire = 'm4'
147 | else:
148 | ywire = 'm5'
149 | if R_m4 > R_m6:
150 | xwire = 'm6'
151 | else:
152 | xwire = 'm4'
153 |
154 | # Special case: 2 pins that are either x-aligned or y-aligned
155 | if xlen == 0:
156 | xwire = None
157 | if ylen == 0:
158 | ywire = None
159 |
160 | wire_class = (xwire, ywire)
161 |
162 | return wire_class
163 |
164 |
165 | def genData(N=10, H=32, W=32, pin_range=(2, 6)):
166 | """
167 | Generates an encoded layout dataset: data X of shape (N, 1, H, W) and labels Y of shape (N, C, H, W)
168 |
169 | Inputs:
170 | - N: number of images to generate
171 | - H, W: image height, width (in px)
172 | - pin_range: tuple (low, high) for allowed range of number of pins; "half-open" interval [low, high)
173 | (e.g. (2, 6) pins means 2 or 3 or 4 or 5 pins)
174 |
175 | Outputs:
176 | - Saves X: training data (N, 1, H, W)
177 | - Saves Y: training labels (N, C, H, W)
178 | where C = 8 (each corresponds to a layout layer, viz.
179 | [pin, m3, via3, m4, via4, m5, via5, m6])
180 |
181 | X[:, 0, :, :] pin (m2)
182 |
183 | Y[:, 0, :, :] pin (m2) - same as X[:, 0, :, :]
184 | Y[:, 1, :, :] m3 (vert)
185 | Y[:, 2, :, :] via3
186 | Y[:, 3, :, :] m4 (horz)
187 | Y[:, 4, :, :] via4
188 | Y[:, 5, :, :] m5 (vert)
189 | Y[:, 6, :, :] via5
190 | Y[:, 7, :, :] m6 (horz)
191 | """
192 |
193 | # 8 layout layers for now
194 | C = 8
195 |
196 | data_dir = os.getcwd() + "/data/"
197 | if os.path.exists(data_dir):
198 | for file in os.listdir(data_dir):
199 | if file.endswith(".hdf5"):
200 | os.remove(data_dir+file)
201 | else:
202 | os.makedirs(data_dir)
203 |
204 | data = h5py.File(data_dir + "layout_data.hdf5", "w")
205 |
206 | # numpy arrays no longer needed; use HDF5 instead
207 | #X = np.zeros([N, 1, H, W], dtype = np.int8)
208 | #Y = np.zeros([N, C, H, W], dtype = np.int8)
209 |
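# Chunk per image with LZF compression, so individual samples can be read back
# efficiently without loading the whole dataset into memory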
210 | X = data.create_dataset("X", shape=(N, 1, H, W), dtype='uint8', compression='lzf', chunks=(1, 1, H, W))
211 | Y = data.create_dataset("Y", shape=(N, C, H, W), dtype='uint8', compression='lzf', chunks=(1, C, H, W))
212 |
213 | # Set physical size represented by HxW pixels
214 | microns = 11.0 # To have balanced dataset covering from m3 to m6 (based on resistance plots from resistance_vs_distance.ipynb)
215 | microns_per_xpixel = microns/W
216 | microns_per_ypixel = microns/H
217 |
218 | # Layer map
219 | l_map = {
220 | # Pins
221 | 'pin' : 0,
222 |
223 | # Vias
224 | 'via3' : 2,
225 | 'via4' : 4,
226 | 'via5' : 6,
227 |
228 | # Vertical tracks
229 | 'm3' : 1,
230 | 'm5' : 5,
231 |
232 | # Horizontal tracks
233 | 'm4' : 3,
234 | 'm6' : 7
235 | }
236 |
237 | #m3_m4 = m5_m4 = m5_m6 = 0
238 |
239 | n = 0
240 | print_every = 5000
241 |
242 | while n < N:
243 | # Randomly select number of pins from given range
244 | # Uniform distribution over pin range
245 | nPins = np.random.randint(*pin_range)
246 | # Non-uniform distribution (skewed exponentially towards smaller number of pins)
247 | #p_range = np.array(range(*pin_range))
248 | #p = np.exp(-p_range) / np.sum(np.exp(-p_range))
249 | #nPins = np.random.choice(p_range, p=p)
250 |
251 | # Randomly pick x and y co-ords for nPins from [0, W) and [0, H) pixels
252 | x_pins = np.random.randint(W, size=nPins)
253 | y_pins = np.random.randint(H, size=nPins)
254 |
255 | max_xlen = (max(x_pins) - min(x_pins)) * microns_per_xpixel # length in um
256 | max_ylen = (max(y_pins) - min(y_pins)) * microns_per_ypixel # length in um
257 |
258 | # Corner case when pins overlap each other (invalid case)
259 | # Bug fix for https://github.com/sjain-stanford/RouteAI/issues/4
260 | if (max_xlen == 0) and (max_ylen == 0):
261 | continue
262 |
263 | # Draw pins on layer 'pin (m2)' of both X (data) and Y (labels)
264 | X[n] = drawPins(X[n], x_pins, y_pins, l_map)
265 | Y[n] = drawPins(Y[n], x_pins, y_pins, l_map)
266 |
267 | # Add routes to Y (labels)
268 | x_wire, y_wire = selectWireClass(max_xlen, max_ylen)
269 | Y[n] = drawRoutes(Y[n], x_pins, y_pins, max_xlen, max_ylen, x_wire, y_wire, l_map)
270 |
271 | n += 1
272 |
273 | if (n % print_every == 0):
274 | print("Finished generating %d samples." %(n))
275 |
276 | #if x_wire == 'm4' and y_wire == 'm3':
277 | # m3_m4 += 1
278 | #elif x_wire == 'm4' and y_wire == 'm5':
279 | # m5_m4 += 1
280 | #elif x_wire == 'm6' and y_wire == 'm5':
281 | # m5_m6 += 1
282 | #else:
283 | # print(x_wire, y_wire)
284 |
285 | #print(m3_m4, m5_m4, m5_m6)
286 |
287 | # Storing as .npy using np.save -> Issue: RAM out of memory, disk memory limitation
288 | #data_dir = os.getcwd() + '/data/'
289 | #if os.path.exists(data_dir):
290 | # for file in os.listdir(data_dir):
291 | # if file.endswith(".npy"):
292 | # os.remove(data_dir+file)
293 | #else:
294 | # os.makedirs(data_dir)
295 | #X_save = data_dir + 'X_save.npy'
296 | #Y_save = data_dir + 'Y_save.npy'
297 | #np.save(X_save, X, allow_pickle=False)
298 | #np.save(Y_save, Y, allow_pickle=False)
299 |
300 | print("Dataset generated as follows:")
301 | for ds in data:
302 | print(ds, data[ds])
303 | data.close()
304 |
305 | ############################
306 | # MAIN #
307 | ############################
308 | def main():
309 | N = int(input('Enter the number of images to be generated: '))
310 | genData(N)
311 |
312 |
313 | if __name__ == '__main__':
314 | main()
--------------------------------------------------------------------------------
/figs/Network.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sjain-stanford/deep-route/ef8934457a37255f62fcdc99780ff11f2e7d3d4f/figs/Network.png
--------------------------------------------------------------------------------
/figs/sample_data.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sjain-stanford/deep-route/ef8934457a37255f62fcdc99780ff11f2e7d3d4f/figs/sample_data.jpg
--------------------------------------------------------------------------------
/model/train_fcn_pytorch.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import time
4 | import sys
5 |
6 | import h5py
7 | import numpy as np
8 |
9 | import matplotlib
10 | matplotlib.use('Agg')
11 | import matplotlib.pyplot as plt
12 | plt.rcParams['figure.figsize'] = (10.0, 10.0) # set default size of plots
13 | plt.rcParams['image.interpolation'] = 'nearest'
14 | plt.rcParams['image.cmap'] = 'gray'
15 |
16 | import torch
17 | import torch.nn as nn
18 | import torch.optim as optim
19 |
20 | from torch.autograd import Variable
21 | from torch.utils.data import DataLoader, TensorDataset
22 |
23 | sys.path.append('../')
24 | from datagen.decoder import decodeData
25 |
26 | parser = argparse.ArgumentParser(description='Deep-Route: Training a deep FCN network to route circuit layouts.')
27 | parser.add_argument('--data', metavar='PATH', default=os.getcwd()+'/data/', help='path to dataset (default: ./data/)')
28 | parser.add_argument('--batch_size', metavar='N', default=100, type=int, help='mini-batch size (default: 100)')
29 | parser.add_argument('--num_workers', metavar='N', default=4, type=int, help='number of data loading workers (default: 4)')
30 | parser.add_argument('--num_epochs', metavar='N', default=200, type=int, help='number of total epochs to run (default: 200)')
31 | parser.add_argument('--use_gpu', action='store_true', help='use GPU if available')
32 | parser.add_argument('--pretrained', action='store_true', help='use pre-trained model')
33 | parser.add_argument('--lr', metavar='LR', default=5e-4, type=float, help='initial learning rate (default: 5e-4)')
34 | parser.add_argument('--adapt_lr', action='store_true', help='use learning rate schedule')
35 | parser.add_argument('--reg', metavar='REG', default=1e-5, type=float, help='regularization strength (default: 1e-5)')
36 | parser.add_argument('--print-freq', metavar='N', default=10, type=int, help='print frequency (default: 10)')
37 |
38 |
39 | def main(args):
40 | # Unutilized GPU notification
41 | if torch.cuda.is_available() and not args.use_gpu:
42 | print("GPU is available. Provide command line flag --use_gpu to use it!")
43 |
44 | # To run on GPU, specify command-line flag --use_gpu
45 | if args.use_gpu and torch.cuda.is_available():
46 | dtype = torch.cuda.FloatTensor
47 | else:
48 | dtype = torch.FloatTensor
49 |
50 | # Dataset filenames
51 | train_fname = 'train_50k_32pix.hdf5'
52 | val_fname = 'val_10k_32pix.hdf5'
53 |
54 | # Save dir
55 | train_id = 'train50k_val10k_pix32' + '_lr' + str(args.lr) + '_reg' + str(args.reg) + '_batchsize' + str(args.batch_size) + '_epochs' + str(args.num_epochs) + '_gpu' + str(args.use_gpu)
56 | save_dir = os.getcwd() + '/training/' + train_id + '/'
57 | if not os.path.exists(save_dir):
58 | os.makedirs(save_dir)
59 |
60 | # Weighted loss to overcome unbalanced dataset (>98% pixels are off ('0'))
61 | weight = torch.Tensor([1, 3]).type(dtype)
62 |
63 | # Read dataset at path provided by --data command-line flag
64 | train_data = h5py.File(args.data + train_fname, 'r')
65 | X_train = np.asarray(train_data['X']) # Data: X_train.shape = (N, 1, H, W); X_train.dtype = uint8
66 | Y_train = np.asarray(train_data['Y']) # Labels: Y_train.shape = (N, 8, H, W); Y_train.dtype = uint8
67 | print("X_train: %s \nY_train: %s\n" %(X_train.shape, Y_train.shape))
68 |
69 | val_data = h5py.File(args.data + val_fname, 'r')
70 | X_val = np.asarray(val_data['X'])
71 | Y_val = np.asarray(val_data['Y'])
72 | print("X_val: %s \nY_val: %s\n" %(X_val.shape, Y_val.shape))
73 |
74 | # Dimensions
75 | N_train = X_train.shape[0]
76 | N_val = X_val.shape[0]
77 | C = Y_train.shape[1]
78 | H = X_train.shape[2]
79 | W = X_train.shape[3]
80 | dims_X = [-1, 1, H, W]
81 | dims_Y = [-1, C, H, W]
82 |
83 | # Setup DataLoader
84 | # https://stackoverflow.com/questions/41924453/pytorch-how-to-use-dataloaders-for-custom-datasets
85 | # PyTorch tensors are of type torch.ByteTensor (8 bit unsigned int)
86 | # Stored as 2D --> train: (N, 1*H*W), val: (N, 8*H*W)
87 | train_dset = TensorDataset(torch.from_numpy(X_train).view(N_train, -1),
88 | torch.from_numpy(Y_train).view(N_train, -1))
89 |
90 | train_loader = DataLoader(train_dset, batch_size=args.batch_size,
91 | # Disable shuffling in debug mode
92 | #num_workers=args.num_workers, shuffle=False)
93 | num_workers=args.num_workers, shuffle=True)
94 |
95 | val_dset = TensorDataset(torch.from_numpy(X_val).view(N_val, -1),
96 | torch.from_numpy(Y_val).view(N_val, -1))
97 |
98 | val_loader = DataLoader(val_dset, batch_size=args.batch_size,
99 | num_workers=args.num_workers, shuffle=False)
100 |
101 | # Define NN architecture
102 | model = nn.Sequential( # Input (N, 1, 32, 32)
103 | nn.Conv2d(1, 16, kernel_size=33, stride=1, padding=16, bias=True), # Output (N, 16, 32, 32)
104 | nn.BatchNorm2d(16),
105 | nn.LeakyReLU(inplace=True),
106 |
107 | nn.Conv2d(16, 16, kernel_size=3, stride=1, padding=1, bias=True), # Output (N, 16, 32, 32)
108 | nn.BatchNorm2d(16),
109 | nn.LeakyReLU(inplace=True),
110 |
111 | nn.Conv2d(16, 16, kernel_size=3, stride=1, padding=1, bias=True), # Output (N, 16, 32, 32)
112 | nn.BatchNorm2d(16),
113 | nn.LeakyReLU(inplace=True),
114 |
115 | nn.Conv2d(16, 16, kernel_size=3, stride=1, padding=1, bias=True), # Output (N, 16, 32, 32)
116 | nn.BatchNorm2d(16),
117 | nn.LeakyReLU(inplace=True),
118 |
119 | # Layer 5
120 | nn.Conv2d(16, 16, kernel_size=3, stride=1, padding=1, bias=True), # Output (N, 16, 32, 32)
121 | nn.BatchNorm2d(16),
122 | nn.LeakyReLU(inplace=True),
123 |
124 | nn.Conv2d(16, 16, kernel_size=3, stride=1, padding=1, bias=True), # Output (N, 16, 32, 32)
125 | nn.BatchNorm2d(16),
126 | nn.LeakyReLU(inplace=True),
127 |
128 | nn.Conv2d(16, 16, kernel_size=3, stride=1, padding=1, bias=True), # Output (N, 16, 32, 32)
129 | nn.BatchNorm2d(16),
130 | nn.LeakyReLU(inplace=True),
131 |
132 | nn.Conv2d(16, 16, kernel_size=3, stride=1, padding=1, bias=True), # Output (N, 16, 32, 32)
133 | nn.BatchNorm2d(16),
134 | nn.LeakyReLU(inplace=True),
135 |
136 | nn.Conv2d(16, 16, kernel_size=3, stride=1, padding=1, bias=True), # Output (N, 16, 32, 32)
137 | nn.BatchNorm2d(16),
138 | nn.LeakyReLU(inplace=True),
139 |
140 | # Layer 10
141 | nn.Conv2d(16, 16, kernel_size=3, stride=1, padding=1, bias=True), # Output (N, 16, 32, 32)
142 | nn.BatchNorm2d(16),
143 | nn.LeakyReLU(inplace=True),
144 |
145 | nn.Conv2d(16, 16, kernel_size=3, stride=1, padding=1, bias=True), # Output (N, 16, 32, 32)
146 | nn.BatchNorm2d(16),
147 | nn.LeakyReLU(inplace=True),
148 |
149 | nn.Conv2d(16, 16, kernel_size=3, stride=1, padding=1, bias=True), # Output (N, 16, 32, 32)
150 | nn.BatchNorm2d(16),
151 | nn.LeakyReLU(inplace=True),
152 |
153 | nn.Conv2d(16, 16, kernel_size=3, stride=1, padding=1, bias=True), # Output (N, 16, 32, 32)
154 | nn.BatchNorm2d(16),
155 | nn.LeakyReLU(inplace=True),
156 |
157 | nn.Conv2d(16, 16, kernel_size=3, stride=1, padding=1, bias=True), # Output (N, 16, 32, 32)
158 | nn.BatchNorm2d(16),
159 | nn.LeakyReLU(inplace=True),
160 |
161 | # Layer 15
162 | nn.Conv2d(16, 16, kernel_size=3, stride=1, padding=1, bias=True), # Output (N, 16, 32, 32)
163 | )
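# The 16 output channels hold interleaved (off, on) score pairs for the 8
# layout layers [pin, m3, via3, m4, via4, m5, via5, m6]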
164 |
165 | # Load pretrained model parameters
166 | if args.pretrained:
167 | model.load_state_dict(torch.load(save_dir + '../saved_model_params'))
168 |
169 | # Cast model to the correct datatype
170 | model.type(dtype)
171 |
172 | loss_fn = nn.CrossEntropyLoss(weight=weight).type(dtype)
173 |
174 | # Use Adam optimizer with default betas
175 | optimizer = optim.Adam(model.parameters(), lr=args.lr,
176 | betas=(0.9, 0.999), weight_decay=args.reg)
177 | loss_history = []
178 | train_precision = [0]
179 | train_recall = [0]
180 | train_f1score = [0]
181 | val_precision = [0]
182 | val_recall = [0]
183 | val_f1score = [0]
184 |
185 | best_val_f1score = 0
186 |
187 | epoch_time = AverageMeter()
188 | end = time.time()
189 |
190 | # Run the model for given epochs
191 | for epoch in range(args.num_epochs):
192 | # Adaptive learning rate schedule
193 | if args.adapt_lr:
194 | adjust_learning_rate(optimizer, epoch)
195 |
196 | # Run an epoch over the training data
197 | loss = train(model, train_loader, loss_fn, optimizer, dtype, dims_X, dims_Y, epoch)
198 | loss_history.extend(loss)
199 |
200 | # Check precision/recall/accuracy/F1_score on the train and val sets
201 | prec, rec, f1 = check_accuracy(model, train_loader, dtype, dims_X, dims_Y, epoch, save_dir, 'train')
202 | train_precision.append(prec)
203 | train_recall.append(rec)
204 | train_f1score.append(f1)
205 | prec, rec, f1 = check_accuracy(model, val_loader, dtype, dims_X, dims_Y, epoch, save_dir, 'val')
206 | val_precision.append(prec)
207 | val_recall.append(rec)
208 | val_f1score.append(f1)
209 |
210 | plt.subplot(2, 2, 1)
211 | plt.title('Training loss')
212 | plt.plot(loss_history, 'o')
213 | plt.yscale('log')
214 | plt.xlabel('Iteration')
215 |
216 | plt.subplot(2, 2, 2)
217 | plt.title('Accuracy (F1 Score)')
218 | plt.plot(train_f1score, '-o', label='train')
219 | plt.plot(val_f1score, '-o', label='val')
220 | plt.xlabel('Epoch')
221 | plt.legend(loc='lower right')
222 |
223 | plt.subplot(2, 2, 3)
224 | plt.title('Precision')
225 | plt.plot(train_precision, '-o', label='train')
226 | plt.plot(val_precision, '-o', label='val')
227 | plt.xlabel('Epoch')
228 | plt.legend(loc='lower right')
229 |
230 | plt.subplot(2, 2, 4)
231 | plt.title('Recall')
232 | plt.plot(train_recall, '-o', label='train')
233 | plt.plot(val_recall, '-o', label='val')
234 | plt.xlabel('Epoch')
235 | plt.legend(loc='lower right')
236 |
237 | plt.tight_layout(pad=1.0, w_pad=1.0, h_pad=1.0)
238 | plt.savefig(save_dir + 'training_history.jpg')
239 | #plt.savefig(save_dir + 'training_history.eps', format='eps')
240 | plt.close()
241 |
242 | # Save best model parameters
243 | if f1 > best_val_f1score:
244 | best_val_f1score = f1
245 | print('Saving best model parameters with Val F1 score = %.4f' %(best_val_f1score))
246 | torch.save(model.state_dict(), save_dir + 'saved_model_params')
247 |
248 |
249 | # Measure elapsed time
250 | epoch_time.update(time.time() - end)
251 | end = time.time()
252 |
253 | print('Timer Epoch [{0}/{1}]\t'
254 | 't_epoch {epoch_time.val:.3f} ({epoch_time.avg:.3f})'.format(
255 | epoch+1, args.num_epochs, epoch_time=epoch_time))
256 |
257 |
258 | def train(model, loader, loss_fn, optimizer, dtype, dims_X, dims_Y, epoch):
259 | """
260 | Train the model for one epoch
261 | """
262 | batch_time = AverageMeter()
263 | data_time = AverageMeter()
264 | losses = AverageMeter()
265 |
266 | # Set the model to training mode
267 | model.train()
268 |
269 | loss_hist = []
270 |
271 | end = time.time()
272 | for i, (x, y) in enumerate(loader):
273 | # The DataLoader produces 2D Torch Tensors, so we need to reshape them to 4D,
274 | # cast them to the correct datatype and wrap them in Variables.
275 | #
276 | # Note that the labels should be a torch.LongTensor on CPU and a
277 | # torch.cuda.LongTensor on GPU; to accomplish this we first cast to dtype
278 | # (either torch.FloatTensor or torch.cuda.FloatTensor) and then cast to
279 | # long; this ensures that y has the correct type in both cases.
280 |
281 | # Measure data loading time
282 | data_time.update(time.time() - end)
283 |
284 | x = x.view(dims_X) # (N_batch, 1, H, W)
285 | y = y.view(dims_Y) # (N_batch, 8, H, W)
286 | x_var = Variable(x.type(dtype), requires_grad=False)
287 | y_var = Variable(y.type(dtype).long(), requires_grad=False)
288 |
289 | # Run the model forward to compute scores and loss
290 | scores = model(x_var) # (N_batch, 16, H, W)
291 |
292 | # To convert scores from (N_batch, 16, H, W) to (N_batch*H*W*8, 2) where 2 = number of classes (off/on),
293 | # for PyTorch's cross entropy loss format (http://pytorch.org/docs/nn.html#crossentropyloss)
294 | _, twoC, _, _ = scores.size()
295 | scores = scores.permute(0, 2, 3, 1).contiguous().view(-1, twoC) # (N_batch*H*W, twoC)
296 | scores = torch.cat((scores[:, 0:twoC:2].contiguous().view(-1, 1),
297 | scores[:, 1:twoC:2].contiguous().view(-1, 1)), 1) # (N_batch*H*W*8, 2)
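# E.g. with twoC = 16: even channels hold 'off' scores and odd channels 'on'
# scores, so each row of the final (N_batch*H*W*8, 2) tensor is the [off, on]
# score pair for one (image, pixel, layer) element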
298 |
299 | # To convert y_var from (N_batch, 8, H, W) to (N_batch*H*W*8)
300 | # for PyTorch's cross entropy loss format (http://pytorch.org/docs/nn.html#crossentropyloss)
301 | y_var = y_var.permute(0, 2, 3, 1).contiguous().view(-1) # (N_batch*H*W*8)
302 |
303 | # Use cross entropy loss - 16 filter case
304 | loss = loss_fn(scores, y_var)
305 |
306 | losses.update(loss.data[0], y_var.size(0))
307 | loss_hist.append(loss.data[0])
308 |
309 | # Run the model backward and take a step using the optimizer
310 | optimizer.zero_grad()
311 | loss.backward()
312 | optimizer.step()
313 |
314 | # Measure elapsed time
315 | batch_time.update(time.time() - end)
316 | end = time.time()
317 |
318 | if (i % args.print_freq == 0) or (i+1 == len(loader)):
319 | print('Train Epoch [{0}/{1}]\t'
320 | 'Batch [{2}/{3}]\t'
321 | 't_total {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
322 | 't_data {data_time.val:.3f} ({data_time.avg:.3f})\t'
323 | 'Loss {loss.val:.4f} ({loss.avg:.4f})'.format(
324 | epoch+1, args.num_epochs, i+1, len(loader), batch_time=batch_time,
325 | data_time=data_time, loss=losses))
326 |
327 | return loss_hist
328 |
329 |
330 | def check_accuracy(model, loader, dtype, dims_X, dims_Y, epoch, save_dir, train_val):
331 | """
332 | Check the accuracy of the model
333 | """
334 | # Filenames
335 | y_pred_fname = 'Y_' + train_val + '_pred'
336 | y_act_fname = 'Y_' + train_val + '_act'
337 |
338 | # Set the model to eval mode
339 | model.eval()
340 |
341 | tp, tn, fp, fn = 0, 0, 0, 0
342 |
343 | for i, (x, y) in enumerate(loader):
344 | # Reshape 2D torch tensors from DataLoader to 4D, cast to the
345 | # correct type and wrap it in a Variable.
346 | #
347 | # At test-time when we do not need to compute gradients, marking
348 | # the Variable as volatile can reduce memory usage and slightly improve speed.
349 | x = x.view(dims_X) # (N_batch, 1, H, W)
350 | y = y.view(dims_Y) # (N_batch, 8, H, W)
351 | x_var = Variable(x.type(dtype), volatile=True)
352 | y_var = Variable(y.type(dtype), volatile=True)
353 |
354 | # Run the model forward, and compute the y_pred to compare with the ground-truth
355 | scores = model(x_var) # (N, 16, H, W)
356 |
357 | _, twoC, _, _ = scores.size()
358 | scores_off = scores[:, 0:twoC:2, :, :]
359 | scores_on = scores[:, 1:twoC:2, :, :]
360 |
361 | y_pred = (scores_on > scores_off) # (N_batch, 8, H, W)
362 |
363 | # Precision / Recall / F-1 Score
364 | #https://en.wikipedia.org/wiki/Precision_and_recall
365 | # tp = true_pos, tn = true_neg, fp = false_pos, fn = false_neg
366 | tp += ((y_pred.data == 1) * (y_var.data == 1)).sum()
367 | tn += ((y_pred.data == 0) * (y_var.data == 0)).sum()
368 | fp += ((y_pred.data == 1) * (y_var.data == 0)).sum()
369 | fn += ((y_pred.data == 0) * (y_var.data == 1)).sum()
370 |
371 | # Preview images from first mini-batch after every 5% of epochs
372 | # E.g., if num_epochs = 20, preview every 1 epoch
373 | # if num_epochs = 200, preview every 10 epochs
374 | if i == 0 and ((epoch % max(1, args.num_epochs*5//100) == 0) or (epoch+1 == args.num_epochs)):
375 | Y_act_dec = decodeData(y_var.data.cpu().numpy()) # (N_batch, 3, H, W)
376 | Y_act_dec = np.swapaxes(Y_act_dec, 1, 2) # (N_batch, H, 3, W)
377 | Y_act_dec = np.swapaxes(Y_act_dec, 2, 3) # (N_batch, H, W, 3)
378 |
379 | Y_pred_dec = decodeData(y_pred.data.cpu().numpy()) # (N_batch, 3, H, W)
380 | Y_pred_dec = np.swapaxes(Y_pred_dec, 1, 2) # (N_batch, H, 3, W)
381 | Y_pred_dec = np.swapaxes(Y_pred_dec, 2, 3) # (N_batch, H, W, 3)
382 |
383 | num_images = 9
384 | for n in range(num_images):
385 | plt.subplot(3, 3, n+1)
386 | plt.imshow(Y_act_dec[n].astype('uint8'))
387 | #plt.axis('off')
388 | plt.tight_layout(pad=1.0, w_pad=1.0, h_pad=1.0)
389 | plt.title('Y_%s_actual (epoch %d)' % (train_val, epoch+1))
390 | plt.savefig(save_dir + 'epoch_' + str(epoch+1) + '_' + y_act_fname + '.jpg')
391 | #plt.savefig(save_dir + 'epoch_' + str(epoch+1) + '_' + y_act_fname + '.eps', format='eps')
392 | plt.close()
393 |
394 | for n in range(num_images):
395 | plt.subplot(3, 3, n+1)
396 | plt.imshow(Y_pred_dec[n].astype('uint8'))
397 | #plt.axis('off')
398 | plt.tight_layout(pad=1.0, w_pad=1.0, h_pad=1.0)
399 | plt.title('Y_%s_predicted (epoch %d)' % (train_val, epoch+1))
400 | plt.savefig(save_dir + 'epoch_' + str(epoch+1) + '_' + y_pred_fname + '.jpg')
401 | #plt.savefig(save_dir + 'epoch_' + str(epoch+1) + '_' + y_pred_fname + '.eps', format='eps')
402 | plt.close()
403 |
404 | # 1e-8 to avoid division by zero
405 | precision = tp / (tp + fp + 1e-8)
406 | recall = tp / (tp + fn + 1e-8)
407 | accuracy = (tp + tn) / (tp + tn + fp + fn)
408 | f1_score = 2 * (precision*recall) / (precision + recall + 1e-8)
409 |
410 | print('{0}\t'
411 | 'Check Epoch [{1}/{2}]\t'
412 | 'Precision {p:.4f}\t'
413 | 'Recall {r:.4f}\t'
414 | 'Accuracy {a:.4f}\t'
415 | 'F1 score {f1:.4f}'.format(
416 | train_val, epoch+1, args.num_epochs, p=precision, r=recall, a=accuracy, f1=f1_score))
417 |
418 | return precision, recall, f1_score
419 |
420 |
421 | def bce_loss(input, target):
422 | """
423 | Numerically stable version of the binary cross-entropy loss function.
424 |
425 | As per https://github.com/pytorch/pytorch/issues/751
426 | See the TensorFlow docs for a derivation of this formula:
427 | https://www.tensorflow.org/api_docs/python/tf/nn/sigmoid_cross_entropy_with_logits
428 |
429 | Inputs:
430 | - input: PyTorch Variable of shape (N, 8, H, W) giving scores.
431 | - target: PyTorch Variable of shape (N, 8, H, W) containing 0 and 1 giving targets.
432 |
433 | Returns:
434 | - A PyTorch Variable containing the mean BCE loss over the minibatch of input data.
435 | """
436 | # bce_loss(input, target) = target * -log(sigmoid(input)) + (1 - target) * -log(1 - sigmoid(input))
437 |
438 | neg_abs = - input.abs()
439 | bce_loss = input.clamp(min=0) - input * target + (1 + neg_abs.exp()).log() # (N, 8, H, W)
440 | return bce_loss.mean()
441 |
442 |
443 | def wt_bce_loss(input, target, weight):
444 | """
445 | Numerically stable version of the weighted binary cross-entropy loss function.
446 |
447 | See the TensorFlow docs for a derivation of this formula:
448 | https://www.tensorflow.org/api_docs/python/tf/nn/weighted_cross_entropy_with_logits
449 |
450 | Inputs:
451 | - input: PyTorch Variable of shape (N, 8, H, W) giving scores.
452 | - target: PyTorch Variable of shape (N, 8, H, W) containing 0 and 1 giving targets.
453 |
454 | Returns:
455 | - A PyTorch Variable containing the mean weighted BCE loss over the minibatch of input data.
456 | """
457 | # wt_bce_loss(input, target, weight) = weight * target * -log(sigmoid(input)) + (1 - target) * -log(1 - sigmoid(input))
458 |
459 | neg_abs = - input.abs()
460 | wt_bce_loss = (1 - target) * input + (1 + (weight - 1) * target) * ((1 + neg_abs.exp()).log() + (-input).clamp(min=0)) # (N, 8, H, W)
461 | return wt_bce_loss.mean()
462 |
463 |
464 | def adjust_learning_rate(optimizer, epoch):
465 | """Sets the learning rate to the initial LR decayed by 10 every 30 epochs"""
466 | lr = args.lr * (0.1 ** (epoch // 30))
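# e.g. with the default --lr 5e-4: epochs 1-30 run at 5e-4, epochs 31-60 at 5e-5, etc.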
467 | print("Adaptive learning rate: %e" %(lr))
468 | for param_group in optimizer.param_groups:
469 | param_group['lr'] = lr
470 |
471 |
472 | class AverageMeter(object):
473 | """Computes and stores the average and current value"""
474 | def __init__(self):
475 | self.reset()
476 |
477 | def reset(self):
478 | self.val = 0
479 | self.avg = 0
480 | self.sum = 0
481 | self.count = 0
482 |
483 | def update(self, val, n=1):
484 | self.val = val
485 | self.sum += val * n
486 | self.count += n
487 | self.avg = self.sum / self.count
488 |
489 |
490 | if __name__ == '__main__':
491 | args = parser.parse_args()
492 | main(args)
493 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | Cython==0.23.4
2 | Jinja2==2.8
3 | MarkupSafe==0.23
4 | Pillow==3.0.0
5 | Pygments==2.0.2
6 | appnope==0.1.0
7 | argparse==1.2.1
8 | backports-abc==0.4
9 | backports.ssl-match-hostname==3.5.0.1
10 | certifi==2015.11.20.1
11 | cycler==0.10.0
12 | decorator==4.0.6
13 | future==0.16.0
14 | gnureadline==6.3.3
15 | h5py==2.7.0
16 | ipykernel==4.2.2
17 | ipython==4.0.1
18 | ipython-genutils==0.1.0
19 | ipywidgets==4.1.1
20 | jsonschema==2.5.1
21 | jupyter==1.0.0
22 | jupyter-client==4.1.1
23 | jupyter-console==4.0.3
24 | jupyter-core==4.0.6
25 | matplotlib==2.0.0
26 | mistune==0.8.1
27 | nbconvert==4.1.0
28 | nbformat==4.0.1
29 | notebook==5.7.8
30 | numpy==1.10.4
31 | path.py==8.1.2
32 | pexpect==4.0.1
33 | pickleshare==0.5
34 | ptyprocess==0.5
35 | pyparsing==2.0.7
36 | python-dateutil==2.4.2
37 | pytz==2015.7
38 | pyzmq==15.1.0
39 | qtconsole==4.1.1
40 | scipy==0.16.1
41 | simplegeneric==0.8.1
42 | singledispatch==3.4.0.3
43 | site==0.0.1
44 | six==1.10.0
45 | terminado==0.5
46 | tornado==4.3
47 | traitlets==4.0.0
48 |
--------------------------------------------------------------------------------
/setup_pytorch.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | source .env/bin/activate
3 |
4 | pip install numpy
5 | pip install http://download.pytorch.org/whl/cu80/torch-0.1.12.post2-cp35-cp35m-linux_x86_64.whl
6 | pip install torchvision
7 | deactivate
8 |
--------------------------------------------------------------------------------
/setup_virtualenv.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 |
3 | sudo apt-get update
4 | sudo apt-get install libncurses5-dev
5 | sudo apt-get install python-dev
6 | sudo apt-get install python-pip
7 | sudo apt-get install libjpeg8-dev
8 | sudo ln -s /usr/lib/x86_64-linux-gnu/libjpeg.so /usr/lib
9 | pip install pillow
10 | sudo apt-get build-dep python-imaging
11 | sudo apt-get install libjpeg8 libjpeg62-dev libfreetype6 libfreetype6-dev
12 | sudo pip install virtualenv
13 | virtualenv -p python3 .env # Create a virtual environment
14 | source .env/bin/activate # Activate the virtual environment
15 | pip install -r requirements.txt # Install dependencies
16 | deactivate
17 | echo "If you had no errors, you can proceed to work with your virtualenv as normal."
18 | echo "(run 'source .env/bin/activate' to load the venv,"
19 | echo " and run 'deactivate' to exit the venv.)"
20 |
--------------------------------------------------------------------------------
/start_jupyter_env.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | source .env/bin/activate
3 | jupyter-notebook --no-browser --port=7001 &
3 |
--------------------------------------------------------------------------------