├── .gitignore
├── LICENSE
├── README.md
├── datagen
│   ├── decoder.py
│   └── gen_data.py
├── figs
│   ├── Network.png
│   └── sample_data.jpg
├── model
│   └── train_fcn_pytorch.py
├── requirements.txt
├── setup_pytorch.sh
├── setup_virtualenv.sh
└── start_jupyter_env.sh

/.gitignore:
--------------------------------------------------------------------------------
1 | *.swp
2 | *.pyc
3 | .env/*
4 | data/*.hdf5
5 | model/data/*.hdf5
6 | model/training/*/saved_model_params
7 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2017 Sambhav R. Jain
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Deep-Route
2 |
3 | This repository contains a PyTorch implementation (with pretrained weights) and the dataset generation code for the paper
4 |
5 | **[Training a Fully Convolutional Neural Network to Route Integrated Circuits](https://arxiv.org/abs/1706.08948)**
6 |
7 | [Sambhav R. Jain](https://bit.ly/sjain-stanford), Kye Okabe
8 |
9 | arXiv-cs.CV (Computer Vision and Pattern Recognition) 2017
10 |
11 | We design and train a deep, fully convolutional neural network that learns to route a circuit layout net with an appropriate choice of metal tracks and wire class combinations. Inputs to the network are the encoded layouts containing the spatial locations of the pins to be routed. After 15 fully convolutional stages followed by a score comparator, the network outputs 8 layout layers (corresponding to 4 route layers, 3 via layers and an identity-mapped pin layer), which are then decoded to obtain the routed layouts.
12 |
13 | **Proposed FCN Model**
14 |
15 |
16 | **Training samples (left: data, right: labels) from the generated dataset**
17 |
18 |
19 | ## Install (Linux)
20 | 1. Fork [this GitHub repository](https://github.com/sjain-stanford/deep-route)
21 | 2. Set up a virtualenv and install the dependencies
22 |    * `./setup_virtualenv.sh`
23 | 3. Install PyTorch
24 |    * `./setup_pytorch.sh`
25 | 4. Activate the virtualenv and start a Jupyter notebook
26 |    * `./start_jupyter_env.sh`
27 |
28 | ## Generate Dataset
29 | Run the script `./datagen/gen_data.py` to generate training data of shape (N, 1, H, W) and labels of shape (N, 8, H, W), stored using [HDF5 (h5py)](https://github.com/h5py/h5py). The default parameters used for the paper are `H = W = 32` and `pin_range = (2, 6)`, but these can be modified as desired. Generating a 50,000-image dataset should take less than a minute.
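To make the data/label encoding concrete, here is a simplified NumPy sketch of how a single net could map to a data sample X of shape (1, H, W) and a label Y of shape (8, H, W). It is not the repository's generator. It follows the channel order documented in `gen_data.py` (`pin, m3, via3, m4, via4, m5, via5, m6`) but makes one simplifying assumption: it always draws a horizontal trunk on `m4` at the average pin row, with vertical `m3` legs. By contrast, `gen_data.py` chooses the trunk direction and wire classes from the net's extent and estimated wire resistance. The helper name `encode_sample` is hypothetical.

```python
import numpy as np

# Label channel order as documented in gen_data.py
LAYERS = ['pin', 'm3', 'via3', 'm4', 'via4', 'm5', 'via5', 'm6']
IDX = {name: i for i, name in enumerate(LAYERS)}


def encode_sample(xpins, ypins, H=32, W=32, xwire='m4', ywire='m3'):
    """Hypothetical helper: encode one net as X (1, H, W) and Y (8, H, W).

    Simplification (assumption): always a horizontal trunk on `xwire` at the
    average pin row, plus vertical legs on `ywire` from each pin to the trunk.
    """
    X = np.zeros((1, H, W), dtype=np.uint8)
    Y = np.zeros((len(LAYERS), H, W), dtype=np.uint8)

    # Pins appear in the single input channel and in label channel 0
    for x, y in zip(xpins, ypins):
        X[0, y, x] = 1
        Y[IDX['pin'], y, x] = 1

    # Horizontal trunk spanning all pins at the average pin row
    yavg = int(np.average(ypins))
    Y[IDX[xwire], yavg, min(xpins):max(xpins) + 1] = 1

    # Vertical leg from every pin to the trunk (both ends inclusive)
    for x, y in zip(xpins, ypins):
        lo, hi = sorted((y, yavg))
        Y[IDX[ywire], lo:hi + 1, x] = 1

    # A via is placed wherever two adjacent metal layers overlap
    Y[IDX['via3']] = Y[IDX['m3']] & Y[IDX['m4']]
    Y[IDX['via4']] = Y[IDX['m4']] & Y[IDX['m5']]
    Y[IDX['via5']] = Y[IDX['m5']] & Y[IDX['m6']]
    return X, Y


X, Y = encode_sample([2, 10, 20], [5, 15, 5])
print(X.shape, Y.shape, int(Y[IDX['via3']].sum()))  # (1, 32, 32) (8, 32, 32) 3
```

Because the pin layer of Y is an identity copy of X, the network only has to learn the seven route and via layers.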
30 | ```
31 | python ./datagen/gen_data.py
32 | >> Enter the number of images to be generated: 50000
33 | mv ./data/layout_data.hdf5 ./model/data/train_50k_32pix.hdf5
34 |
35 | python ./datagen/gen_data.py
36 | >> Enter the number of images to be generated: 10000
37 | mv ./data/layout_data.hdf5 ./model/data/val_10k_32pix.hdf5
38 | ```
39 |
40 | ## Train FCN Network (in PyTorch)
41 | Switch to the `./model/` directory and run the script `./train_fcn_pytorch.py` to train the FCN model with the default options, or use the switch `--help` to display a list of options and their defaults.
42 | ```
43 | cd ./model/
44 | python ./train_fcn_pytorch.py --help
45 | ```
46 |
47 | ```
48 | usage: train_fcn_pytorch.py [-h] [--data PATH] [--batch_size N]
49 |                             [--num_workers N] [--num_epochs N] [--use_gpu]
50 |                             [--pretrained] [--lr LR] [--adapt_lr] [--reg REG]
51 |                             [--print-freq N]
52 |
53 | Deep-Route: Training a deep FCN network to route circuit layouts.
54 |
55 | optional arguments:
56 |   -h, --help       show this help message and exit
57 |   --data PATH      path to dataset (default: ./data/)
58 |   --batch_size N   mini-batch size (default: 100)
59 |   --num_workers N  number of data loading workers (default: 4)
60 |   --num_epochs N   number of total epochs to run (default: 200)
61 |   --use_gpu        use GPU if available
62 |   --pretrained     use pre-trained model
63 |   --lr LR          initial learning rate (default: 5e-4)
64 |   --adapt_lr       use learning rate schedule
65 |   --reg REG        regularization strength (default: 1e-5)
66 |   --print-freq N   print frequency (default: 10)
67 | ```
68 |
69 | To run on GPU, provide the switch `--use_gpu`. The best model parameters (based on the F1 score on the validation set) are saved to a run-specific subdirectory of `./model/training/` whenever the validation F1 score improves, and the loss, F1, precision and recall curves are re-plotted to `training_history.jpg` in that subdirectory every epoch. If the switch `--pretrained` is provided, the model is initialized with the parameters saved at `./model/training/saved_model_params` before training. Pretrained weights (for batch sizes 10 and 100) are available [here](https://github.com/sjain-stanford/pretrained-weights).
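The final convolution emits 16 score channels. When computing the cross-entropy loss, `train()` in `model/train_fcn_pytorch.py` pairs the even channels (class 0, "off") with the odd channels (class 1, "on"). Read that way, the score comparator mentioned in the overview reduces to a per-pixel comparison within each channel pair. The NumPy sketch below illustrates this idea. `score_comparator` is a hypothetical name, and the repository's own evaluation code (`check_accuracy`) is not reproduced here.

```python
import numpy as np


def score_comparator(scores):
    """Hypothetical helper: map scores (N, 2*C, H, W) to binary layers (N, C, H, W).

    Assumes channel 2k is the 'off' score and channel 2k+1 the 'on' score for
    layer k, mirroring how train() splits channels for CrossEntropyLoss.
    A pixel is 'on' only if its 'on' score is strictly larger (ties -> off).
    """
    off = scores[:, 0::2]  # (N, C, H, W)
    on = scores[:, 1::2]   # (N, C, H, W)
    return (on > off).astype(np.uint8)


# Toy example: 16 score channels -> 8 layout layers on a 2x2 grid
scores = np.zeros((1, 16, 2, 2), dtype=np.float32)
scores[0, 1] = 1.0         # layer 0 (pin): 'on' score wins at every pixel
scores[0, 6, 0, 0] = -1.0  # layer 3 (m4): 'off' score is lower at (0, 0) only
layers = score_comparator(scores)
print(layers.shape, layers.sum(axis=(2, 3))[0])  # (1, 8, 2, 2) [4 0 0 1 0 0 0 0]
```

With a PyTorch output tensor, the same comparison can be written directly as `scores[:, 1::2] > scores[:, 0::2]`. The resulting 8-channel map can then be passed to `decodeData` for visualization.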
With `--adapt_lr`, a learning rate decay factor of 10 is applied every 30 epochs. 70 | 71 | ## Cite 72 | If you find this work useful in your research, please cite: 73 | ``` 74 | @article{jain2017route, 75 | title={Training a Fully Convolutional Neural Network to Route Integrated Circuits}, 76 | author={Jain, Sambhav R and Okabe, Kye}, 77 | journal={arXiv preprint arXiv:1706.08948}, 78 | year={2017} 79 | } 80 | ``` 81 | -------------------------------------------------------------------------------- /datagen/decoder.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os 3 | import h5py 4 | 5 | import matplotlib 6 | matplotlib.use('Agg') 7 | import matplotlib.pyplot as plt 8 | 9 | 10 | def decodeData(inp): 11 | """ 12 | Generates decoded (RGB) images from 8 channel data 13 | 14 | Input: 15 | - inp: data of shape (N, C, H, W) (range: 0 or 1) 16 | where C = 8 (Y) or 1 (X) 17 | 18 | Output: 19 | - out: decoded (RGB) data (N, 3, H, W) (range: 0-255) 20 | """ 21 | 22 | # Define RGB colors (http://www.rapidtables.com/web/color/RGB_Color.htm) 23 | M6_clr = np.array([[ 0.0, 191.0, 255.0]]) # layer 7 m6 dodger blue 24 | VIA5_clr = np.array([[255.0, 0.0, 0.0]]) # layer 6 via5 red 25 | M5_clr = np.array([[169.0, 169.0, 169.0]]) # layer 5 m5 dark gray 26 | VIA4_clr = np.array([[ 0.0, 255.0, 0.0]]) # layer 4 via4 green 27 | M4_clr = np.array([[255.0, 99.0, 71.0]]) # layer 3 m4 tomato red 28 | VIA3_clr = np.array([[ 0.0, 0.0, 255.0]]) # layer 2 via3 blue 29 | M3_clr = np.array([[ 50.0, 205.0, 50.0]]) # layer 1 m3 lime green 30 | PIN_clr = np.array([[ 0.0, 0.0, 0.0]]) # layer 0 pin black 31 | BGND_clr = np.array([[255.0, 255.0, 255.0]]) # background white 32 | 33 | # If inp is 1 image, reshape 34 | if (len(inp.shape) == 3): 35 | C, H, W = inp.shape 36 | inp = inp.reshape(1, C, H, W) 37 | 38 | N, C, H, W = inp.shape 39 | D = N * H * W 40 | 41 | # Warn if any input elements are out of range 42 | if 
np.any(np.greater(inp, 1)): 43 | print("Invalid input range detected (some input elements > 1)") 44 | if np.any(np.less(inp, 0)): 45 | print("Invalid input range detected (some input elements < 0)") 46 | 47 | # Initialize pseudo output & activepixelcount matrix 48 | out = np.zeros([N, 3, H, W]) 49 | 50 | inp_swap = np.swapaxes(inp, 1, 0).reshape(C, D) # (C, N*H*W) 51 | out_swap = np.swapaxes(out, 1, 0).reshape(3, D) # (3, N*H*W) 52 | 53 | if C != 1: 54 | #layer 7 processing 55 | temp7 = np.broadcast_to((inp_swap[7, :] == 1), (3, D)) 56 | out_swap += M6_clr.T * temp7 57 | 58 | #layer 6 processing 59 | temp6 = np.broadcast_to((inp_swap[6, :] == 1), (3, D)) 60 | out_swap += VIA5_clr.T * temp6 61 | 62 | #layer 5 processing 63 | temp5 = np.broadcast_to((inp_swap[5, :] == 1), (3, D)) 64 | out_swap += M5_clr.T * temp5 65 | 66 | #layer 4 processing 67 | temp4 = np.broadcast_to((inp_swap[4, :] == 1), (3, D)) 68 | out_swap += VIA4_clr.T * temp4 69 | 70 | #layer 3 processing 71 | temp3 = np.broadcast_to((inp_swap[3, :] == 1), (3, D)) 72 | out_swap += M4_clr.T * temp3 73 | 74 | #layer 2 processing 75 | temp2 = np.broadcast_to((inp_swap[2, :] == 1), (3, D)) 76 | out_swap += VIA3_clr.T * temp2 77 | 78 | #layer 1 processing 79 | temp1 = np.broadcast_to((inp_swap[1, :] == 1), (3, D)) 80 | out_swap += M3_clr.T * temp1 81 | 82 | #layer 0 processing 83 | temp0 = np.broadcast_to((inp_swap[0, :] == 1), (3, D)) 84 | out_swap += PIN_clr.T * temp0 85 | 86 | # For every pixel, count the number of active layers that overlap 87 | overlap = 1.0 * np.sum(inp_swap, axis=0, keepdims=True) 88 | n_layers = overlap + (overlap == 0.0) * 1.0 # To avoid division by zero 89 | 90 | # If no active layers, fill in white 91 | temp_white = np.broadcast_to((overlap == 0), (3, D)) 92 | out_swap += BGND_clr.T * temp_white 93 | 94 | # Average the colors for all active layers 95 | out_swap /= (n_layers * 1.0) 96 | 97 | out_swap = out_swap.reshape(3, N, H, W) 98 | out = np.swapaxes(out_swap, 1, 0) # (N, 3, H, W) 
99 | 100 | # Warn if any output elements are out of range 101 | if np.any(np.greater(out, 255)): 102 | print ("Invalid output range detected (some output elements > 255)") 103 | if np.any(np.less(out, 0)): 104 | print ("Invalid output range detected (some output elements < 0)") 105 | 106 | return out 107 | 108 | 109 | ############################ 110 | # MAIN # 111 | ############################ 112 | def main(): 113 | plt.rcParams['figure.figsize'] = (15.0, 10.0) # set default size of plots 114 | plt.rcParams['image.interpolation'] = 'nearest' 115 | plt.rcParams['image.cmap'] = 'gray' 116 | 117 | data = {} 118 | data_file = os.getcwd() + "/data/layout_data.hdf5" 119 | 120 | #data = h5py.File(data_file, 'r') 121 | #X = np.array(data['X']) # (N, 1, H, W) 122 | #Y = np.array(data['Y']) # (N, C, H, W) 123 | 124 | with h5py.File(data_file, 'r') as f: 125 | for k, v in f.items(): 126 | data[k] = np.asarray(v) 127 | 128 | X_dec = decodeData(data['X']) # (N, 3, H, W) 129 | X_dec = np.swapaxes(X_dec, 1, 2) # (N, H, 3, W) 130 | X_dec = np.swapaxes(X_dec, 2, 3) # (N, H, W, 3) 131 | 132 | Y_dec = decodeData(data['Y']) # (N, 3, H, W) 133 | Y_dec = np.swapaxes(Y_dec, 1, 2) # (N, H, 3, W) 134 | Y_dec = np.swapaxes(Y_dec, 2, 3) # (N, H, W, 3) 135 | 136 | num_images = 6 137 | 138 | for n in range(num_images): 139 | plt.subplot(2, 3, n+1) 140 | plt.imshow(X_dec[n].astype('uint8')) 141 | #plt.axis('off') 142 | plt.tight_layout(pad=2.0, w_pad=1.0, h_pad=1.0) 143 | 144 | plt.savefig("data/sample_X.jpg") 145 | plt.savefig("data/sample_X.eps") 146 | 147 | for n in range(num_images): 148 | plt.subplot(2, 3, n+1) 149 | plt.imshow(Y_dec[n].astype('uint8')) 150 | #plt.axis('off') 151 | plt.tight_layout(pad=2.0, w_pad=1.0, h_pad=1.0) 152 | 153 | plt.savefig("data/sample_Y.jpg") 154 | plt.savefig("data/sample_Y.eps") 155 | 156 | 157 | if __name__ == '__main__': 158 | main() 159 | -------------------------------------------------------------------------------- /datagen/gen_data.py: 
-------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | import numpy as np 3 | import random 4 | import os 5 | import h5py 6 | 7 | 8 | def drawPins(input_data, xpins, ypins, map): 9 | """ 10 | Draw pins on input_data 11 | 12 | Inputs: 13 | - input_data: np.array of size (C, H, W) (all zeros to start with) 14 | - xpins & ypins: np.array of x-coordinates and y-coordinates for all pins 15 | e.g., [x1, x2 ... xm] and [y1, y2, ... ym] for m pins 16 | - map: layout layer map (dictionary from layer name to index) 17 | 18 | Outputs: 19 | - output_data: np.array of size (C, H, W) (with pixels corresponding to pins in layer 'pins' set to '1') 20 | """ 21 | 22 | output_data = input_data 23 | 24 | for (x, y) in zip(xpins, ypins): 25 | output_data[map['pin'], y, x] = 1 26 | 27 | return output_data 28 | 29 | 30 | def drawRoutes(input_data, xpins, ypins, xlen, ylen, xwire, ywire, map): 31 | """ 32 | Draw routes on input_data 33 | 34 | Inputs: 35 | - input_data: np.array of size (C, H, W) 36 | - xpins & ypins: np.array of x-coordinates and y-coordinates for all pins 37 | e.g., [x1, x2 ... xm] and [y1, y2, ... 
ym] for m pins 38 | - xlen & ylen: horizontal and vertical span of the pins (in um) 39 | - xwire & ywire: horizontal and vertical wire class to route; such as ('m4', 'm3') 40 | (xavg & yavg, the average pin coordinates at which the route's branch is placed, are computed internally) 41 | - map: layout layer map (dictionary from layer name to index) 42 | 43 | Outputs: 44 | - output_data: np.array of size (C, H, W) 45 | """ 46 | 47 | output_data = input_data 48 | 49 | xavg = int(np.average(xpins)) 50 | yavg = int(np.average(ypins)) 51 | 52 | if xlen > ylen: 53 | # Draw horizontal branch first 54 | xmin = min(xpins) 55 | xmax = max(xpins) 56 | output_data[map[xwire], yavg, xmin:xmax+1] = 1 57 | 58 | # Draw vertical legs from all pins to branch 59 | # Except in the special case of nPins=2 and pins being x-aligned, where only horizontal branch is needed 60 | if ywire is not None: 61 | for n in range(len(xpins)): 62 | if ypins[n] > yavg: 63 | output_data[map[ywire], yavg:ypins[n]+1, xpins[n]] = 1 64 | else: 65 | output_data[map[ywire], ypins[n]:yavg+1, xpins[n]] = 1 66 | else: 67 | # Draw vertical branch first 68 | ymin = min(ypins) 69 | ymax = max(ypins) 70 | output_data[map[ywire], ymin:ymax+1, xavg] = 1 71 | 72 | # Draw horizontal legs from all pins to branch 73 | # Except in the special case of nPins=2 and pins being y-aligned, where only vertical branch is needed 74 | if xwire is not None: 75 | for n in range(len(xpins)): 76 | if xpins[n] > xavg: 77 | output_data[map[xwire], ypins[n], xavg:xpins[n]+1] = 1 78 | else: 79 | output_data[map[xwire], ypins[n], xpins[n]:xavg+1] = 1 80 | 81 | # Draw vias 82 | output_data[map['via3']] = (output_data[map['m3']] == 1) * (output_data[map['m4']] == 1) * 1 83 | output_data[map['via4']] = (output_data[map['m4']] == 1) * (output_data[map['m5']] == 1) * 1 84 | output_data[map['via5']] = (output_data[map['m5']] == 1) * (output_data[map['m6']] == 1) * 1 85 | 86 | return output_data 87 | 88 | 89 | def selectWireClass(xlen, ylen): 90 | """ 91 | Select the wire classes (metal layers) for the horizontal and vertical segments of a route
92 | 93 | Inputs: 94 | - xlen & ylen: horizontal and vertical span of the pins (in um) 95 | 96 | Outputs: 97 | - wire_class: tuple of (xwire, ywire); such as ('m4', 'm3') 98 | """ 99 | # 7FF Wire Performance (from RCCalc) 100 | # R = rho * l / (t * w) 101 | # rho / (t * w) = R / l 102 | # Multiply rho_tw by length to get wire resistance, add via resistance to this 103 | 104 | # m2 1x w=20nm horz 105 | rho_tw_m2_1x = 106.45 # ohm/um 106 | # m3 2x w=40nm vert 107 | rho_tw_m3_2x = 30.555 # ohm/um 108 | # m4 1x w=40nm horz 109 | rho_tw_m4_1x = 24.077 # ohm/um 110 | # m5 1x w=38nm vert 111 | rho_tw_m5_1x = 18.368 # ohm/um 112 | # m6 1x w=40nm horz 113 | rho_tw_m6_1x = 16.950 # ohm/um 114 | # m7 1x w=76nm vert 115 | rho_tw_m7_1x = 8.854 # ohm/um 116 | 117 | # via resistance 118 | r_via2 = 40 # ohm/ct 119 | r_via3 = 30 # ohm/ct 120 | r_via4 = 12 # ohm/ct 121 | r_via5 = 12 # ohm/ct 122 | r_via6 = 12 # ohm/ct 123 | 124 | # Assumption - All routes run to/from two m2 pins 125 | R_m2 = (rho_tw_m2_1x * xlen) 126 | R_m3 = (rho_tw_m3_2x * ylen) + (r_via2 * 2) 127 | R_m4 = (rho_tw_m4_1x * xlen) + (r_via2 * 2) + (r_via3 * 2) 128 | R_m5 = (rho_tw_m5_1x * ylen) + (r_via2 * 2) + (r_via3 * 2) + (r_via4 * 2) 129 | R_m6 = (rho_tw_m6_1x * xlen) + (r_via2 * 2) + (r_via3 * 2) + (r_via4 * 2) + (r_via5 * 2) 130 | R_m7 = (rho_tw_m7_1x * ylen) + (r_via2 * 2) + (r_via3 * 2) + (r_via4 * 2) + (r_via5 * 2) + (r_via6 * 2) 131 | 132 | # Choose the lower-resistance wire class, depending on whether the horizontal or vertical length dominates 133 | if xlen > ylen: 134 | if R_m4 > R_m6: 135 | xwire = 'm6' 136 | ywire = 'm5' 137 | else: 138 | xwire = 'm4' 139 | if R_m3 > R_m5: 140 | ywire = 'm5' 141 | else: 142 | ywire = 'm3' 143 | else: 144 | if R_m5 > R_m3: 145 | ywire = 'm3' 146 | xwire = 'm4' 147 | else: 148 | ywire = 'm5' 149 | if R_m4 > R_m6: 150 | xwire = 'm6' 151 | else: 152 | xwire = 'm4' 153 | 154 | # Special case: 2 pins that are either x-aligned or y-aligned 155 | if xlen == 0: 156 | xwire = None 157 | if
ylen == 0: 158 | ywire = None 159 | 160 | wire_class = (xwire, ywire) 161 | 162 | return wire_class 163 | 164 | 165 | def genData(N=10, H=32, W=32, pin_range=(2, 6)): 166 | """ 167 | Generates the encoded training dataset (data X and labels Y) and saves it to ./data/layout_data.hdf5 168 | 169 | Inputs: 170 | - N: number of images to generate 171 | - H, W: image height, width (in px) 172 | - pin_range: tuple (low, high) for allowed range of number of pins; "half-open" interval [low, high) 173 | (e.g. (2, 6) pins means 2 or 3 or 4 or 5 pins) 174 | 175 | Outputs: 176 | - Saves X: training data (N, 1, H, W) 177 | - Saves Y: training labels (N, C, H, W) 178 | where C = 8 (each corresponds to a layout layer, viz. 179 | [pin, m3, via3, m4, via4, m5, via5, m6]) 180 | 181 | X[:, 0, :, :] pin (m2) 182 | 183 | Y[:, 0, :, :] pin (m2) - same as X[:, 0, :, :] 184 | Y[:, 1, :, :] m3 (vert) 185 | Y[:, 2, :, :] via3 186 | Y[:, 3, :, :] m4 (horz) 187 | Y[:, 4, :, :] via4 188 | Y[:, 5, :, :] m5 (vert) 189 | Y[:, 6, :, :] via5 190 | Y[:, 7, :, :] m6 (horz) 191 | """ 192 | 193 | # 8 layout layers for now 194 | C = 8 195 | 196 | data_dir = os.getcwd() + "/data/" 197 | if os.path.exists(data_dir): 198 | for file in os.listdir(data_dir): 199 | if file.endswith(".hdf5"): 200 | os.remove(data_dir+file) 201 | else: 202 | os.makedirs(data_dir) 203 | 204 | data = h5py.File(data_dir + "layout_data.hdf5", "w") # explicit 'w' creates the file; h5py >= 3.0 defaults to read-only mode 'r' 205 | 206 | # numpy arrays no longer needed; use HDF5 instead 207 | #X = np.zeros([N, 1, H, W], dtype = np.int8) 208 | #Y = np.zeros([N, C, H, W], dtype = np.int8) 209 | 210 | X = data.create_dataset("X", shape=(N, 1, H, W), dtype='uint8', compression='lzf', chunks=(1, 1, H, W)) 211 | Y = data.create_dataset("Y", shape=(N, C, H, W), dtype='uint8', compression='lzf', chunks=(1, 1, H, W)) 212 | 213 | # Set physical size represented by HxW pixels 214 | microns = 11.0 # To have balanced dataset covering from m3 to m6 (based on resistance plots from resistance_vs_distance.ipynb) 215 | microns_per_xpixel = microns/W 216 |
microns_per_ypixel = microns/H 217 | 218 | # Layer map 219 | l_map = { 220 | # Pins 221 | 'pin' : 0, 222 | 223 | # Vias 224 | 'via3' : 2, 225 | 'via4' : 4, 226 | 'via5' : 6, 227 | 228 | # Vertical tracks 229 | 'm3' : 1, 230 | 'm5' : 5, 231 | 232 | # Horizontal tracks 233 | 'm4' : 3, 234 | 'm6' : 7 235 | } 236 | 237 | #m3_m4 = m5_m4 = m5_m6 = 0 238 | 239 | n = 0 240 | print_every = 5000 241 | 242 | while n < N: 243 | # Randomly select number of pins from given range 244 | # Uniform distribution over pin range 245 | nPins = np.random.randint(*pin_range) 246 | # Non-uniform distribution (skewed exponentially towards smaller number of pins) 247 | #p_range = np.array(range(*pin_range)) 248 | #p = np.exp(-p_range) / np.sum(np.exp(-p_range)) 249 | #nPins = np.random.choice(p_range, p=p) 250 | 251 | # Randomly pick x and y co-ords for nPins from [0, W) and [0, H) pixels 252 | x_pins = np.random.randint(W, size=nPins) 253 | y_pins = np.random.randint(H, size=nPins) 254 | 255 | max_xlen = (max(x_pins) - min(x_pins)) * microns_per_xpixel # length in um 256 | max_ylen = (max(y_pins) - min(y_pins)) * microns_per_ypixel # length in um 257 | 258 | # Corner case when pins overlap each other (invalid case) 259 | # Bug fix for https://github.com/sjain-stanford/RouteAI/issues/4 260 | if (max_xlen == 0) and (max_ylen == 0): 261 | continue 262 | 263 | # Draw pins on layer 'pin (m2)' of both X (data) and Y (labels) 264 | X[n] = drawPins(X[n], x_pins, y_pins, l_map) 265 | Y[n] = drawPins(Y[n], x_pins, y_pins, l_map) 266 | 267 | # Add routes to Y (labels) 268 | x_wire, y_wire = selectWireClass(max_xlen, max_ylen) 269 | Y[n] = drawRoutes(Y[n], x_pins, y_pins, max_xlen, max_ylen, x_wire, y_wire, l_map) 270 | 271 | n += 1 272 | 273 | if (n % print_every == 0): 274 | print("Finished generating %d samples." 
%(n)) 275 | 276 | #if x_wire == 'm4' and y_wire == 'm3': 277 | # m3_m4 += 1 278 | #elif x_wire == 'm4' and y_wire == 'm5': 279 | # m5_m4 += 1 280 | #elif x_wire == 'm6' and y_wire == 'm5': 281 | # m5_m6 += 1 282 | #else: 283 | # print(x_wire, y_wire) 284 | 285 | #print(m3_m4, m5_m4, m5_m6) 286 | 287 | # Storing as .npy using np.save -> Issue: RAM out of memory, disk memory limitation 288 | #data_dir = os.getcwd() + '/data/' 289 | #if os.path.exists(data_dir): 290 | # for file in os.listdir(data_dir): 291 | # if file.endswith(".npy"): 292 | # os.remove(data_dir+file) 293 | #else: 294 | # os.makedirs(data_dir) 295 | #X_save = data_dir + 'X_save.npy' 296 | #Y_save = data_dir + 'Y_save.npy' 297 | #np.save(X_save, X, allow_pickle=False) 298 | #np.save(Y_save, Y, allow_pickle=False) 299 | 300 | print("Dataset generated as follows:") 301 | for ds in data: 302 | print(ds, data[ds]) 303 | 304 | 305 | ############################ 306 | # MAIN # 307 | ############################ 308 | def main(): 309 | N = int(input('Enter the number of images to be generated: ')) 310 | genData(N) 311 | 312 | 313 | if __name__ == '__main__': 314 | main() -------------------------------------------------------------------------------- /figs/Network.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sjain-stanford/deep-route/ef8934457a37255f62fcdc99780ff11f2e7d3d4f/figs/Network.png -------------------------------------------------------------------------------- /figs/sample_data.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sjain-stanford/deep-route/ef8934457a37255f62fcdc99780ff11f2e7d3d4f/figs/sample_data.jpg -------------------------------------------------------------------------------- /model/train_fcn_pytorch.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import time 4 | 
import sys 5 | 6 | import h5py 7 | import numpy as np 8 | 9 | import matplotlib 10 | matplotlib.use('Agg') 11 | import matplotlib.pyplot as plt 12 | plt.rcParams['figure.figsize'] = (10.0, 10.0) # set default size of plots 13 | plt.rcParams['image.interpolation'] = 'nearest' 14 | plt.rcParams['image.cmap'] = 'gray' 15 | 16 | import torch 17 | import torch.nn as nn 18 | import torch.optim as optim 19 | 20 | from torch.autograd import Variable 21 | from torch.utils.data import DataLoader, TensorDataset 22 | 23 | sys.path.append('../') 24 | from datagen.decoder import decodeData 25 | 26 | parser = argparse.ArgumentParser(description='Deep-Route: Training a deep FCN network to route circuit layouts.') 27 | parser.add_argument('--data', metavar='PATH', default=os.getcwd()+'/data/', help='path to dataset (default: ./data/)') 28 | parser.add_argument('--batch_size', metavar='N', default=100, type=int, help='mini-batch size (default: 100)') 29 | parser.add_argument('--num_workers', metavar='N', default=4, type=int, help='number of data loading workers (default: 4)') 30 | parser.add_argument('--num_epochs', metavar='N', default=200, type=int, help='number of total epochs to run (default: 200)') 31 | parser.add_argument('--use_gpu', action='store_true', help='use GPU if available') 32 | parser.add_argument('--pretrained', action='store_true', help='use pre-trained model') 33 | parser.add_argument('--lr', metavar='LR', default=5e-4, type=float, help='initial learning rate (default: 5e-4)') 34 | parser.add_argument('--adapt_lr', action='store_true', help='use learning rate schedule') 35 | parser.add_argument('--reg', metavar='REG', default=1e-5, type=float, help='regularization strength (default: 1e-5)') 36 | parser.add_argument('--print-freq', metavar='N', default=10, type=int, help='print frequency (default: 10)') 37 | 38 | 39 | def main(args): 40 | # Unutilized GPU notification 41 | if torch.cuda.is_available() and not args.use_gpu: 42 | print("GPU is available. 
Provide command line flag --use_gpu to use it!") 43 | 44 | # To run on GPU, specify command-line flag --use_gpu 45 | if args.use_gpu and torch.cuda.is_available(): 46 | dtype = torch.cuda.FloatTensor 47 | else: 48 | dtype = torch.FloatTensor 49 | 50 | # Dataset filenames 51 | train_fname = 'train_50k_32pix.hdf5' 52 | val_fname = 'val_10k_32pix.hdf5' 53 | 54 | # Save dir 55 | train_id = 'train50k_val10k_pix32' + '_lr' + str(args.lr) + '_reg' + str(args.reg) + '_batchsize' + str(args.batch_size) + '_epochs' + str(args.num_epochs) + '_gpu' + str(args.use_gpu) 56 | save_dir = os.getcwd() + '/training/' + train_id + '/' 57 | if not os.path.exists(save_dir): 58 | os.makedirs(save_dir) 59 | 60 | # Weighted loss to overcome unbalanced dataset (>98% pixels are off ('0')) 61 | weight = torch.Tensor([1, 3]).type(dtype) 62 | 63 | # Read dataset at path provided by --data command-line flag 64 | train_data = h5py.File(args.data + train_fname, 'r') 65 | X_train = np.asarray(train_data['X']) # Data: X_train.shape = (N, 1, H, W); X_train.dtype = uint8 66 | Y_train = np.asarray(train_data['Y']) # Labels: Y_train.shape = (N, 8, H, W); Y_train.dtype = uint8 67 | print("X_train: %s \nY_train: %s\n" %(X_train.shape, Y_train.shape)) 68 | 69 | val_data = h5py.File(args.data + val_fname, 'r') 70 | X_val = np.asarray(val_data['X']) 71 | Y_val = np.asarray(val_data['Y']) 72 | print("X_val: %s \nY_val: %s\n" %(X_val.shape, Y_val.shape)) 73 | 74 | # Dimensions 75 | N_train = X_train.shape[0] 76 | N_val = X_val.shape[0] 77 | C = Y_train.shape[1] 78 | H = X_train.shape[2] 79 | W = X_train.shape[3] 80 | dims_X = [-1, 1, H, W] 81 | dims_Y = [-1, C, H, W] 82 | 83 | # Set up DataLoaders 84 | # https://stackoverflow.com/questions/41924453/pytorch-how-to-use-dataloaders-for-custom-datasets 85 | # PyTorch tensors are of type torch.ByteTensor (8 bit unsigned int) 86 | # Stored as 2D --> data X: (N, 1*H*W), labels Y: (N, 8*H*W) 87 | train_dset = TensorDataset(torch.from_numpy(X_train).view(N_train, -1), 88 |
torch.from_numpy(Y_train).view(N_train, -1)) 89 | 90 | train_loader = DataLoader(train_dset, batch_size=args.batch_size, 91 | # Disable shuffling in debug mode 92 | #num_workers=args.num_workers, shuffle=False) 93 | num_workers=args.num_workers, shuffle=True) 94 | 95 | val_dset = TensorDataset(torch.from_numpy(X_val).view(N_val, -1), 96 | torch.from_numpy(Y_val).view(N_val, -1)) 97 | 98 | val_loader = DataLoader(val_dset, batch_size=args.batch_size, 99 | num_workers=args.num_workers, shuffle=False) 100 | 101 | # Define NN architecture 102 | model = nn.Sequential( # Input (N, 1, 32, 32) 103 | nn.Conv2d(1, 16, kernel_size=33, stride=1, padding=16, bias=True), # Output (N, 16, 32, 32) 104 | nn.BatchNorm2d(16), 105 | nn.LeakyReLU(inplace=True), 106 | 107 | nn.Conv2d(16, 16, kernel_size=3, stride=1, padding=1, bias=True), # Output (N, 16, 32, 32) 108 | nn.BatchNorm2d(16), 109 | nn.LeakyReLU(inplace=True), 110 | 111 | nn.Conv2d(16, 16, kernel_size=3, stride=1, padding=1, bias=True), # Output (N, 16, 32, 32) 112 | nn.BatchNorm2d(16), 113 | nn.LeakyReLU(inplace=True), 114 | 115 | nn.Conv2d(16, 16, kernel_size=3, stride=1, padding=1, bias=True), # Output (N, 16, 32, 32) 116 | nn.BatchNorm2d(16), 117 | nn.LeakyReLU(inplace=True), 118 | 119 | # Layer 5 120 | nn.Conv2d(16, 16, kernel_size=3, stride=1, padding=1, bias=True), # Output (N, 16, 32, 32) 121 | nn.BatchNorm2d(16), 122 | nn.LeakyReLU(inplace=True), 123 | 124 | nn.Conv2d(16, 16, kernel_size=3, stride=1, padding=1, bias=True), # Output (N, 16, 32, 32) 125 | nn.BatchNorm2d(16), 126 | nn.LeakyReLU(inplace=True), 127 | 128 | nn.Conv2d(16, 16, kernel_size=3, stride=1, padding=1, bias=True), # Output (N, 16, 32, 32) 129 | nn.BatchNorm2d(16), 130 | nn.LeakyReLU(inplace=True), 131 | 132 | nn.Conv2d(16, 16, kernel_size=3, stride=1, padding=1, bias=True), # Output (N, 16, 32, 32) 133 | nn.BatchNorm2d(16), 134 | nn.LeakyReLU(inplace=True), 135 | 136 | nn.Conv2d(16, 16, kernel_size=3, stride=1, padding=1, bias=True), # Output (N, 
16, 32, 32)
        nn.BatchNorm2d(16),
        nn.LeakyReLU(inplace=True),

        # Layer 10
        nn.Conv2d(16, 16, kernel_size=3, stride=1, padding=1, bias=True), # Output (N, 16, 32, 32)
        nn.BatchNorm2d(16),
        nn.LeakyReLU(inplace=True),

        nn.Conv2d(16, 16, kernel_size=3, stride=1, padding=1, bias=True), # Output (N, 16, 32, 32)
        nn.BatchNorm2d(16),
        nn.LeakyReLU(inplace=True),

        nn.Conv2d(16, 16, kernel_size=3, stride=1, padding=1, bias=True), # Output (N, 16, 32, 32)
        nn.BatchNorm2d(16),
        nn.LeakyReLU(inplace=True),

        nn.Conv2d(16, 16, kernel_size=3, stride=1, padding=1, bias=True), # Output (N, 16, 32, 32)
        nn.BatchNorm2d(16),
        nn.LeakyReLU(inplace=True),

        nn.Conv2d(16, 16, kernel_size=3, stride=1, padding=1, bias=True), # Output (N, 16, 32, 32)
        nn.BatchNorm2d(16),
        nn.LeakyReLU(inplace=True),

        # Layer 15
        nn.Conv2d(16, 16, kernel_size=3, stride=1, padding=1, bias=True), # Output (N, 16, 32, 32)
    )

    # Load pretrained model parameters
    if args.pretrained:
        model.load_state_dict(torch.load(save_dir + '../saved_model_params'))

    # Cast model to the correct datatype
    model.type(dtype)

    loss_fn = nn.CrossEntropyLoss(weight=weight).type(dtype)

    # Use Adam optimizer with default betas
    optimizer = optim.Adam(model.parameters(), lr=args.lr,
                           betas=(0.9, 0.999), weight_decay=args.reg)
    loss_history = []
    train_precision = [0]
    train_recall = [0]
    train_f1score = [0]
    val_precision = [0]
    val_recall = [0]
    val_f1score = [0]

    best_val_f1score = 0

    epoch_time = AverageMeter()
    end = time.time()

    # Run the model for given epochs
    for epoch in range(args.num_epochs):
        # Adaptive learning rate schedule
        if args.adapt_lr:
            adjust_learning_rate(optimizer, epoch)

        # Run an epoch over the training data
        loss = train(model, train_loader, loss_fn, optimizer, dtype, dims_X, dims_Y, epoch)
        loss_history.extend(loss)

        # Check precision/recall/accuracy/F1_score on the train and val sets
        prec, rec, f1 = check_accuracy(model, train_loader, dtype, dims_X, dims_Y, epoch, save_dir, 'train')
        train_precision.append(prec)
        train_recall.append(rec)
        train_f1score.append(f1)
        prec, rec, f1 = check_accuracy(model, val_loader, dtype, dims_X, dims_Y, epoch, save_dir, 'val')
        val_precision.append(prec)
        val_recall.append(rec)
        val_f1score.append(f1)

        plt.subplot(2, 2, 1)
        plt.title('Training loss')
        plt.plot(loss_history, 'o')
        plt.yscale('log')
        plt.xlabel('Iteration')

        plt.subplot(2, 2, 2)
        plt.title('Accuracy (F1 Score)')
        plt.plot(train_f1score, '-o', label='train')
        plt.plot(val_f1score, '-o', label='val')
        plt.xlabel('Epoch')
        plt.legend(loc='lower right')

        plt.subplot(2, 2, 3)
        plt.title('Precision')
        plt.plot(train_precision, '-o', label='train')
        plt.plot(val_precision, '-o', label='val')
        plt.xlabel('Epoch')
        plt.legend(loc='lower right')

        plt.subplot(2, 2, 4)
        plt.title('Recall')
        plt.plot(train_recall, '-o', label='train')
        plt.plot(val_recall, '-o', label='val')
        plt.xlabel('Epoch')
        plt.legend(loc='lower right')

        plt.tight_layout(pad=1.0, w_pad=1.0, h_pad=1.0)
        plt.savefig(save_dir + 'training_history.jpg')
        #plt.savefig(save_dir + 'training_history.eps', format='eps')
        plt.close()

        # Save the model parameters whenever the validation F1 score improves
        # (f1 here holds the value returned by the 'val' check above)
        if f1 > best_val_f1score:
            best_val_f1score = f1
            print('Saving best model parameters with Val F1 score = %.4f' %(best_val_f1score))
            torch.save(model.state_dict(), save_dir + 'saved_model_params')

        # Measure elapsed time
        epoch_time.update(time.time() - end)
        end = time.time()

        print('Timer Epoch [{0}/{1}]\t'
              't_epoch {epoch_time.val:.3f} ({epoch_time.avg:.3f})'.format(
              epoch+1, args.num_epochs, epoch_time=epoch_time))


def train(model, loader, loss_fn, optimizer, dtype, dims_X, dims_Y, epoch):
    """
    Train the model for one epoch
    """
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()

    # Set the model to training mode
    model.train()

    loss_hist = []

    end = time.time()
    for i, (x, y) in enumerate(loader):
        # The DataLoader produces 2D Torch Tensors, so we need to reshape them to 4D,
        # cast them to the correct datatype and wrap them in Variables.
        #
        # Note that the labels should be a torch.LongTensor on CPU and a
        # torch.cuda.LongTensor on GPU; to accomplish this we first cast to dtype
        # (either torch.FloatTensor or torch.cuda.FloatTensor) and then cast to
        # long; this ensures that y has the correct type in both cases.

        # Measure data loading time
        data_time.update(time.time() - end)

        x = x.view(dims_X) # (N_batch, 1, H, W)
        y = y.view(dims_Y) # (N_batch, 8, H, W)
        x_var = Variable(x.type(dtype), requires_grad=False)
        y_var = Variable(y.type(dtype).long(), requires_grad=False)

        # Run the model forward to compute scores and loss
        scores = model(x_var) # (N_batch, 16, H, W)

        # Convert scores from (N_batch, 16, H, W) to (N_batch*H*W*8, 2), where 2 = number of classes (off/on),
        # for PyTorch's cross entropy loss format (http://pytorch.org/docs/nn.html#crossentropyloss)
        _, twoC, _, _ = scores.size()
        scores = scores.permute(0, 2, 3, 1).contiguous().view(-1, twoC) # (N_batch*H*W, twoC)
        scores = torch.cat((scores[:, 0:twoC:2].contiguous().view(-1, 1),
                            scores[:, 1:twoC:2].contiguous().view(-1, 1)), 1) # (N_batch*H*W*8, 2)

        # Convert y_var from (N_batch, 8, H, W) to (N_batch*H*W*8)
        # for PyTorch's cross entropy loss format (http://pytorch.org/docs/nn.html#crossentropyloss)
        y_var = y_var.permute(0, 2, 3, 1).contiguous().view(-1) # (N_batch*H*W*8)

        # Use cross entropy loss - 16 filter case
        loss = loss_fn(scores, y_var)

        losses.update(loss.data[0], y_var.size(0))
        loss_hist.append(loss.data[0])

        # Run the model backward and take a step using the optimizer
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if (i % args.print_freq == 0) or (i+1 == len(loader)):
            print('Train Epoch [{0}/{1}]\t'
                  'Batch [{2}/{3}]\t'
                  't_total {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                  't_data {data_time.val:.3f} ({data_time.avg:.3f})\t'
                  'Loss {loss.val:.4f} ({loss.avg:.4f})'.format(
                  epoch+1, args.num_epochs, i+1, len(loader), batch_time=batch_time,
                  data_time=data_time, loss=losses))

    return loss_hist


def check_accuracy(model, loader, dtype, dims_X, dims_Y, epoch, save_dir, train_val):
    """
    Compute precision, recall, accuracy and F1 score of the model over all
    mini-batches in loader, and periodically save preview images of the
    actual vs. predicted routes from the first mini-batch
    """
    # Filenames
    y_pred_fname = 'Y_' + train_val + '_pred'
    y_act_fname = 'Y_' + train_val + '_act'

    # Set the model to eval mode
    model.eval()

    tp, tn, fp, fn = 0, 0, 0, 0

    for i, (x, y) in enumerate(loader):
        # Reshape 2D torch tensors from DataLoader to 4D, cast them to the
        # correct type and wrap them in Variables.
        #
        # At test-time when we do not need to compute gradients, marking
        # the Variable as volatile can reduce memory usage and slightly improve speed.
        x = x.view(dims_X) # (N_batch, 1, H, W)
        y = y.view(dims_Y) # (N_batch, 8, H, W)
        x_var = Variable(x.type(dtype), volatile=True)
        y_var = Variable(y.type(dtype), volatile=True)

        # Run the model forward, and compute the y_pred to compare with the ground-truth
        scores = model(x_var) # (N, 16, H, W)

        _, twoC, _, _ = scores.size()
        scores_off = scores[:, 0:twoC:2, :, :]
        scores_on = scores[:, 1:twoC:2, :, :]

        y_pred = (scores_on > scores_off) # (N_batch, 8, H, W)

        # Precision / Recall / F-1 Score
        # https://en.wikipedia.org/wiki/Precision_and_recall
        # tp = true_pos, tn = true_neg, fp = false_pos, fn = false_neg
        tp += ((y_pred.data == 1) * (y_var.data == 1)).sum()
        tn += ((y_pred.data == 0) * (y_var.data == 0)).sum()
        fp += ((y_pred.data == 1) * (y_var.data == 0)).sum()
        fn += ((y_pred.data == 0) * (y_var.data == 1)).sum()

        # Preview images from first mini-batch after every 5% of epochs
        # E.g., if num_epochs = 20, preview every 1 epoch
        #       if num_epochs = 200, preview every 10 epochs
        # max(1, ...) avoids a modulo by zero when num_epochs < 20
        preview_freq = max(1, args.num_epochs*5//100)
        if i == 0 and ((epoch % preview_freq == 0) or (epoch+1 == args.num_epochs)):
            Y_act_dec = decodeData(y_var.data.cpu().numpy()) # (N_batch, 3, H, W)
            Y_act_dec = np.swapaxes(Y_act_dec, 1, 2) # (N_batch, H, 3, W)
            Y_act_dec = np.swapaxes(Y_act_dec, 2, 3) # (N_batch, H, W, 3)

            Y_pred_dec = decodeData(y_pred.data.cpu().numpy()) # (N_batch, 3, H, W)
            Y_pred_dec = np.swapaxes(Y_pred_dec, 1, 2) # (N_batch, H, 3, W)
            Y_pred_dec = np.swapaxes(Y_pred_dec, 2, 3) # (N_batch, H, W, 3)

            num_images = 9
            for n in range(num_images):
                plt.subplot(3, 3, n+1)
                plt.imshow(Y_act_dec[n].astype('uint8'))
                #plt.axis('off')
                plt.tight_layout(pad=1.0, w_pad=1.0, h_pad=1.0)
                plt.title('Y_%s_actual (epoch %d)' % (train_val, epoch+1))
            plt.savefig(save_dir + 'epoch_' + str(epoch+1) + '_' + y_act_fname + '.jpg')
            #plt.savefig(save_dir + 'epoch_' + str(epoch+1) + '_' + y_act_fname + '.eps', format='eps')
            plt.close()

            for n in range(num_images):
                plt.subplot(3, 3, n+1)
                plt.imshow(Y_pred_dec[n].astype('uint8'))
                #plt.axis('off')
                plt.tight_layout(pad=1.0, w_pad=1.0, h_pad=1.0)
                plt.title('Y_%s_predicted (epoch %d)' % (train_val, epoch+1))
            plt.savefig(save_dir + 'epoch_' + str(epoch+1) + '_' + y_pred_fname + '.jpg')
            #plt.savefig(save_dir + 'epoch_' + str(epoch+1) + '_' + y_pred_fname + '.eps', format='eps')
            plt.close()

    # 1e-8 to avoid division by zero
    precision = tp / (tp + fp + 1e-8)
    recall = tp / (tp + fn + 1e-8)
    accuracy = (tp + tn) / (tp + tn + fp + fn)
    f1_score = 2 * (precision*recall) / (precision + recall + 1e-8)

    print('{0}\t'
          'Check Epoch [{1}/{2}]\t'
          'Precision {p:.4f}\t'
          'Recall {r:.4f}\t'
          'Accuracy {a:.4f}\t'
          'F1 score {f1:.4f}'.format(
          train_val, epoch+1, args.num_epochs, p=precision, r=recall, a=accuracy, f1=f1_score))

    return precision, recall, f1_score


def bce_loss(input, target):
    """
    Numerically stable version of the binary cross-entropy loss function.

    As per https://github.com/pytorch/pytorch/issues/751
    See the TensorFlow docs for a derivation of this formula:
    https://www.tensorflow.org/api_docs/python/tf/nn/sigmoid_cross_entropy_with_logits

    Inputs:
    - input: PyTorch Variable of shape (N, 8, H, W) giving scores.
    - target: PyTorch Variable of shape (N, 8, H, W) containing 0 and 1 giving targets.

    Returns:
    - A PyTorch Variable containing the mean BCE loss over the minibatch of input data.
    """
    # bce_loss(input, target) = target * -log(sigmoid(input)) + (1 - target) * -log(1 - sigmoid(input))

    neg_abs = - input.abs()
    bce_loss = input.clamp(min=0) - input * target + (1 + neg_abs.exp()).log() # (N, 8, H, W)
    return bce_loss.mean()


def wt_bce_loss(input, target, weight):
    """
    Numerically stable version of the weighted binary cross-entropy loss function.

    See the TensorFlow docs for a derivation of this formula:
    https://www.tensorflow.org/api_docs/python/tf/nn/weighted_cross_entropy_with_logits

    Inputs:
    - input: PyTorch Variable of shape (N, 8, H, W) giving scores.
    - target: PyTorch Variable of shape (N, 8, H, W) containing 0 and 1 giving targets.
    - weight: scalar weight applied to the positive (target == 1) term.

    Returns:
    - A PyTorch Variable containing the mean weighted BCE loss over the minibatch of input data.
456 | """ 457 | # wt_bce_loss(input, target, weight) = weight * target * -log(sigmoid(input)) + (1 - target) * -log(1 - sigmoid(input)) 458 | 459 | neg_abs = - input.abs() 460 | wt_bce_loss = (-input).clamp(min=0) + (1 - target) * input + (1 + (weight - 1) * target) * (1 + neg_abs.exp()).log() # (N, 8, H, W) 461 | return wt_bce_loss.mean() 462 | 463 | 464 | def adjust_learning_rate(optimizer, epoch): 465 | """Sets the learning rate to the initial LR decayed by 10 every 30 epochs""" 466 | lr = args.lr * (0.1 ** (epoch // 30)) 467 | print("Adaptive learning rate: %e" %(lr)) 468 | for param_group in optimizer.param_groups: 469 | param_group['lr'] = lr 470 | 471 | 472 | class AverageMeter(object): 473 | """Computes and stores the average and current value""" 474 | def __init__(self): 475 | self.reset() 476 | 477 | def reset(self): 478 | self.val = 0 479 | self.avg = 0 480 | self.sum = 0 481 | self.count = 0 482 | 483 | def update(self, val, n=1): 484 | self.val = val 485 | self.sum += val * n 486 | self.count += n 487 | self.avg = self.sum / self.count 488 | 489 | 490 | if __name__ == '__main__': 491 | args = parser.parse_args() 492 | main(args) 493 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | Cython==0.23.4 2 | Jinja2==2.8 3 | MarkupSafe==0.23 4 | Pillow==3.0.0 5 | Pygments==2.0.2 6 | appnope==0.1.0 7 | argparse==1.2.1 8 | backports-abc==0.4 9 | backports.ssl-match-hostname==3.5.0.1 10 | certifi==2015.11.20.1 11 | cycler==0.10.0 12 | decorator==4.0.6 13 | future==0.16.0 14 | gnureadline==6.3.3 15 | h5py==2.7.0 16 | ipykernel==4.2.2 17 | ipython==4.0.1 18 | ipython-genutils==0.1.0 19 | ipywidgets==4.1.1 20 | jsonschema==2.5.1 21 | jupyter==1.0.0 22 | jupyter-client==4.1.1 23 | jupyter-console==4.0.3 24 | jupyter-core==4.0.6 25 | matplotlib==2.0.0 26 | mistune==0.8.1 27 | nbconvert==4.1.0 28 | nbformat==4.0.1 29 | notebook==5.7.8 
numpy==1.10.4
path.py==8.1.2
pexpect==4.0.1
pickleshare==0.5
ptyprocess==0.5
pyparsing==2.0.7
python-dateutil==2.4.2
pytz==2015.7
pyzmq==15.1.0
qtconsole==4.1.1
scipy==0.16.1
simplegeneric==0.8.1
singledispatch==3.4.0.3
site==0.0.1
six==1.10.0
terminado==0.5
tornado==4.3
traitlets==4.0.0
--------------------------------------------------------------------------------
/setup_pytorch.sh:
--------------------------------------------------------------------------------
#!/usr/bin/env bash
source .env/bin/activate

pip install numpy
pip install http://download.pytorch.org/whl/cu80/torch-0.1.12.post2-cp35-cp35m-linux_x86_64.whl
pip install torchvision
deactivate
--------------------------------------------------------------------------------
/setup_virtualenv.sh:
--------------------------------------------------------------------------------
#!/usr/bin/env bash

sudo apt-get update
sudo apt-get install libncurses5-dev
sudo apt-get install python-dev
sudo apt-get install python-pip
sudo apt-get install libjpeg8-dev
sudo ln -s /usr/lib/x86_64-linux-gnu/libjpeg.so /usr/lib
pip install pillow
sudo apt-get build-dep python-imaging
sudo apt-get install libjpeg8 libjpeg62-dev libfreetype6 libfreetype6-dev
sudo pip install virtualenv
virtualenv -p python3 .env # Create a virtual environment
source .env/bin/activate # Activate the virtual environment
pip install -r requirements.txt # Install dependencies
deactivate
echo "If you had no errors, you can proceed to work with your virtualenv as normal."
echo "(run 'source .env/bin/activate' to load the venv,"
echo " and run 'deactivate' to exit the venv.)"
--------------------------------------------------------------------------------
/start_jupyter_env.sh:
--------------------------------------------------------------------------------
source .env/bin/activate
jupyter-notebook --no-browser --port=7001 &
--------------------------------------------------------------------------------
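The stable loss formulas in `bce_loss` and `wt_bce_loss` in `model/train_fcn_pytorch.py` can be sanity-checked numerically. The sketch below is a minimal check under stated assumptions: it re-implements the formulas on plain Python floats with the standard `math` module (no PyTorch), and the helper names `stable_bce`, `stable_wt_bce`, `naive_wt_bce` and `unscaled_variant` are hypothetical. The weighted form follows the TensorFlow `weighted_cross_entropy_with_logits` formula, in which `max(-x, 0)` is multiplied by `l = 1 + (q - 1) * z` together with the log term.

```python
import math

# Scalar sketches (hypothetical helpers) of the stable BCE formulas, checked
# against the textbook definitions quoted in the bce_loss / wt_bce_loss comments.


def stable_bce(x, z):
    """max(x, 0) - x*z + log(1 + exp(-|x|)), the form used by bce_loss."""
    return max(x, 0.0) - x * z + math.log1p(math.exp(-abs(x)))


def stable_wt_bce(x, z, q):
    """(1 - z)*x + l*(log(1 + exp(-|x|)) + max(-x, 0)), with l = 1 + (q - 1)*z."""
    l = 1.0 + (q - 1.0) * z
    return (1.0 - z) * x + l * (math.log1p(math.exp(-abs(x))) + max(-x, 0.0))


def unscaled_variant(x, z, q):
    """Same as stable_wt_bce but with max(-x, 0) NOT scaled by l (for comparison)."""
    l = 1.0 + (q - 1.0) * z
    return max(-x, 0.0) + (1.0 - z) * x + l * math.log1p(math.exp(-abs(x)))


def naive_wt_bce(x, z, q):
    """q*z*(-log(sigmoid(x))) + (1 - z)*(-log(1 - sigmoid(x))); q = 1 gives plain BCE."""
    s = 1.0 / (1.0 + math.exp(-x))
    return -q * z * math.log(s) - (1.0 - z) * math.log(1.0 - s)


# Stable and naive forms agree on moderate inputs, for both target values.
for x in (-5.0, -0.3, 0.0, 2.0, 7.0):
    for z in (0.0, 1.0):
        assert math.isclose(stable_bce(x, z), naive_wt_bce(x, z, 1.0), rel_tol=1e-9)
        assert math.isclose(stable_wt_bce(x, z, 4.0), naive_wt_bce(x, z, 4.0), rel_tol=1e-9)

# With q != 1, target = 1 and a negative score, leaving max(-x, 0) unscaled is wrong.
assert not math.isclose(unscaled_variant(-5.0, 1.0, 4.0), naive_wt_bce(-5.0, 1.0, 4.0), rel_tol=1e-3)

# For a large-magnitude score the stable form stays finite; the naive form would
# fail here because math.exp(800.0) raises OverflowError.
print(stable_bce(-800.0, 1.0))
```

The check uses scalars only to keep it dependency-free; the PyTorch versions apply the same expressions elementwise to `(N, 8, H, W)` tensors before taking the mean.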
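In both `train` and `check_accuracy`, the 16 output channels are read as 8 (off, on) score pairs: even channels (`0:twoC:2`) hold the "off" scores and odd channels (`1:twoC:2`) the "on" scores, one pair per routing channel of the label. The sketch below mimics that layout on plain Python lists for two pixels with made-up scores (the helper names `to_pairs`, `flatten_labels` and `predict` are hypothetical). It illustrates that row `p*8 + k` of the `(N_batch*H*W*8, 2)` score matrix built in `train` lines up with label `k` of pixel `p` in the flattened `y_var`, and that picking the larger column of each row gives the same prediction as `scores_on > scores_off` in `check_accuracy`.

```python
# Hypothetical per-pixel scores in channel-last order, i.e. what
# scores.permute(0, 2, 3, 1).contiguous().view(-1, twoC) yields: (num_pixels, 16).
scores = [
    [0.1, 0.9, 1.2, -0.4, 0.0, 0.0, -1.0, 2.0, 0.3, 0.2, 0.5, 0.6, -0.2, -0.1, 3.0, 1.0],
    [2.0, 1.0, 0.0, 0.5, 0.7, 0.7, -0.3, 0.4, 1.1, 1.0, -2.0, -1.0, 0.0, 1.0, 0.9, 0.8],
]


def to_pairs(pixel_scores):
    """Mimic torch.cat((s[:, 0::2].view(-1, 1), s[:, 1::2].view(-1, 1)), 1)."""
    off_col = [v for row in pixel_scores for v in row[0::2]]  # row-major: index p*8 + k
    on_col = [v for row in pixel_scores for v in row[1::2]]
    return [[off, on] for off, on in zip(off_col, on_col)]  # (num_pixels*8, 2)


def flatten_labels(pixel_labels):
    """Mimic y.permute(0, 2, 3, 1).contiguous().view(-1): index p*8 + k."""
    return [v for row in pixel_labels for v in row]


def predict(pixel_scores):
    """Mimic check_accuracy: 1 where the 'on' score beats the 'off' score."""
    return [[int(row[2 * k + 1] > row[2 * k]) for k in range(len(row) // 2)]
            for row in pixel_scores]


pairs = to_pairs(scores)
# Column 0 = class 0 (off), column 1 = class 1 (on); ties resolve to column 0,
# matching the strict '>' comparison in predict().
pair_argmax = [max(range(2), key=lambda c: pair[c]) for pair in pairs]
assert pair_argmax == flatten_labels(predict(scores))
print(predict(scores))
```

Because both the scores and the labels are flattened pixel-major with the channel index varying fastest, each row of the score matrix and each label entry refer to the same (pixel, channel) position.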
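The bookkeeping in `check_accuracy` and `adjust_learning_rate` reduces to a few scalar formulas: precision, recall, accuracy and F1 from the accumulated `tp`, `tn`, `fp`, `fn` counts (with a `1e-8` guard against division by zero), a preview interval of roughly 5% of the epochs, and a step decay of the learning rate by a factor of 10 every 30 epochs. The sketch below restates them as standalone functions for illustration; the names `metrics`, `preview_every` and `step_lr` are hypothetical, and the `max(1, ...)` floor on the preview interval is an assumption that keeps `epoch % interval` defined when there are fewer than 20 epochs.

```python
def metrics(tp, tn, fp, fn, eps=1e-8):
    """Precision, recall, accuracy and F1 from confusion counts (eps guards 0/0)."""
    precision = tp / (tp + fp + eps)
    recall = tp / (tp + fn + eps)
    accuracy = (tp + tn) / (tp + tn + fp + fn)
    f1 = 2 * precision * recall / (precision + recall + eps)
    return precision, recall, accuracy, f1


def preview_every(num_epochs):
    """Preview interval of about 5% of the epochs, floored at 1 (assumption)."""
    return max(1, num_epochs * 5 // 100)


def step_lr(base_lr, epoch):
    """Initial LR decayed by a factor of 10 every 30 epochs."""
    return base_lr * (0.1 ** (epoch // 30))


# Usage with made-up counts: 80 routed segments found, 20 spurious, none missed.
print(metrics(80, 900, 20, 0))
# With no positives predicted or present, the eps terms keep every ratio finite.
print(metrics(0, 10, 0, 0))
print([preview_every(n) for n in (10, 20, 200)])
print([step_lr(1e-3, e) for e in (0, 29, 30, 60)])
```

Precision and recall ignore true negatives, which matters here because most (pixel, channel) positions in a routing grid are empty; accuracy alone would look high even for a model that predicts no wires at all.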