├── APLS-Visualizer
│   ├── README.md
│   └── apls_visualizer-dg
│       ├── MANIFEST.txt
│       ├── classes
│       │   ├── geom
│       │   │   ├── Graph$PQNode.class
│       │   │   ├── Graph.class
│       │   │   ├── LineString.class
│       │   │   ├── Metrics.class
│       │   │   ├── P2.class
│       │   │   └── RoadSet.class
│       │   └── visualizer
│       │       ├── RoadVisualizer$1.class
│       │       ├── RoadVisualizer$BandTriplet.class
│       │       ├── RoadVisualizer$MapData.class
│       │       ├── RoadVisualizer$MapView.class
│       │       ├── RoadVisualizer$Results.class
│       │       ├── RoadVisualizer.class
│       │       └── Utils.class
│       ├── data
│       │   ├── band-triplets.txt
│       │   ├── params.txt
│       │   ├── params_dg.txt
│       │   ├── params_sample.txt
│       │   └── solution-example_blank.csv
│       ├── src
│       │   ├── docker
│       │   │   └── BaselinePredictor.java
│       │   ├── geom
│       │   │   ├── Graph.java
│       │   │   ├── LineString.java
│       │   │   ├── Metrics.java
│       │   │   ├── P2.java
│       │   │   └── RoadSet.java
│       │   ├── test
│       │   │   ├── BatchMetricTest.java
│       │   │   ├── LinestringToGeojson.java
│       │   │   ├── MetricsSpeedTest.java
│       │   │   └── MetricsTest.java
│       │   └── visualizer
│       │       ├── RoadVisualizer.java
│       │       └── Utils.java
│       ├── visualizer.jar
│       ├── visualizer_lib
│       │   ├── imageio-ext-geocore-1.1.16.jar
│       │   ├── imageio-ext-streams-1.1.16.jar
│       │   ├── imageio-ext-tiff-1.1.16.jar
│       │   ├── imageio-ext-utilities-1.1.16.jar
│       │   ├── jai_codec-1.1.3.jar
│       │   ├── jai_core-1.1.3.jar
│       │   └── jai_imageio-1.1.jar
│       └── visualizer_readme.html
├── LICENSE
├── README.md
├── assests
│   ├── .DS_Store
│   └── images
│       ├── AOI_2_Vegas_img33.png
│       ├── mask_AOI_2_Vegas_img33.png
│       └── overview.png
├── config.json
├── create_crops.py
├── data
│   ├── deepglobe
│   │   ├── train.txt
│   │   └── val.txt
│   └── spacenet
│       ├── train.txt
│       └── val.txt
├── data_utils
│   ├── __init__.py
│   ├── affinity_utils.py
│   ├── graph_utils.py
│   ├── rdp.py
│   └── sknw.py
├── model
│   ├── __init__.py
│   ├── linknet.py
│   ├── models.py
│   └── stack_module.py
├── preprocessing
│   ├── prepare_spacenet.sh
│   └── spacenet
│       ├── convert_to_8bit_png.py
│       ├── create_gaussian_label.py
│       └── geoTools.py
├── road_dataset.py
├── split_data.sh
├── train_mtl.py
├── train_refine_pre.py
├── utils
│   ├── __init__.py
│   ├── loss.py
│   ├── util.py
│   ├── viz_util.py
│   └── viz_utils.py
├── visualize_dataset.ipynb
├── visualize_dataset_corrupt.ipynb
└── visualize_tasks.ipynb

/APLS-Visualizer/README.md:
--------------------------------------------------------------------------------
1 | ## APLS Visualizer modified for DG ##
2 | #### Borrowed from: https://community.topcoder.com/longcontest/?module=ViewProblemStatement&rd=17036&pm=14735 ####
3 | #### Needs testing with the current directory structure. Will be updated soon! ####
4 | 
5 | ## Potentially Required Changes in "RoadVisualizer.java"
6 | * Change the file extensions in the "collectImageIds" function
7 | * Update the "idToCity" function
8 | * Update the "loadMap" function (line 510)
9 | 
--------------------------------------------------------------------------------
/APLS-Visualizer/apls_visualizer-dg/MANIFEST.txt:
--------------------------------------------------------------------------------
1 | Manifest-Version: 1.0
2 | Class-Path: . 
visualizer_lib/imageio-ext-geocore-1.1.16.jar visualizer_lib/imageio-ext-streams-1.1.16.jar visualizer_lib/imageio-ext-tiff-1.1.16.jar visualizer_lib/imageio-ext-utilities-1.1.16.jar visualizer_lib/jai_codec-1.1.3.jar visualizer_lib/jai_core-1.1.3.jar visualizer_lib/jai_imageio-1.1.jar 3 | Main-Class: visualizer.RoadVisualizer 4 | -------------------------------------------------------------------------------- /APLS-Visualizer/apls_visualizer-dg/classes/geom/Graph$PQNode.class: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/anilbatra2185/road_connectivity/01e41662a43d9e289926ebd58eff4f6e14359ae4/APLS-Visualizer/apls_visualizer-dg/classes/geom/Graph$PQNode.class -------------------------------------------------------------------------------- /APLS-Visualizer/apls_visualizer-dg/classes/geom/Graph.class: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/anilbatra2185/road_connectivity/01e41662a43d9e289926ebd58eff4f6e14359ae4/APLS-Visualizer/apls_visualizer-dg/classes/geom/Graph.class -------------------------------------------------------------------------------- /APLS-Visualizer/apls_visualizer-dg/classes/geom/LineString.class: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/anilbatra2185/road_connectivity/01e41662a43d9e289926ebd58eff4f6e14359ae4/APLS-Visualizer/apls_visualizer-dg/classes/geom/LineString.class -------------------------------------------------------------------------------- /APLS-Visualizer/apls_visualizer-dg/classes/geom/Metrics.class: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/anilbatra2185/road_connectivity/01e41662a43d9e289926ebd58eff4f6e14359ae4/APLS-Visualizer/apls_visualizer-dg/classes/geom/Metrics.class -------------------------------------------------------------------------------- /APLS-Visualizer/apls_visualizer-dg/classes/geom/P2.class: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/anilbatra2185/road_connectivity/01e41662a43d9e289926ebd58eff4f6e14359ae4/APLS-Visualizer/apls_visualizer-dg/classes/geom/P2.class -------------------------------------------------------------------------------- /APLS-Visualizer/apls_visualizer-dg/classes/geom/RoadSet.class: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/anilbatra2185/road_connectivity/01e41662a43d9e289926ebd58eff4f6e14359ae4/APLS-Visualizer/apls_visualizer-dg/classes/geom/RoadSet.class -------------------------------------------------------------------------------- /APLS-Visualizer/apls_visualizer-dg/classes/visualizer/RoadVisualizer$1.class: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/anilbatra2185/road_connectivity/01e41662a43d9e289926ebd58eff4f6e14359ae4/APLS-Visualizer/apls_visualizer-dg/classes/visualizer/RoadVisualizer$1.class -------------------------------------------------------------------------------- /APLS-Visualizer/apls_visualizer-dg/classes/visualizer/RoadVisualizer$BandTriplet.class: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/anilbatra2185/road_connectivity/01e41662a43d9e289926ebd58eff4f6e14359ae4/APLS-Visualizer/apls_visualizer-dg/classes/visualizer/RoadVisualizer$BandTriplet.class -------------------------------------------------------------------------------- /APLS-Visualizer/apls_visualizer-dg/classes/visualizer/RoadVisualizer$MapData.class: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/anilbatra2185/road_connectivity/01e41662a43d9e289926ebd58eff4f6e14359ae4/APLS-Visualizer/apls_visualizer-dg/classes/visualizer/RoadVisualizer$MapData.class -------------------------------------------------------------------------------- /APLS-Visualizer/apls_visualizer-dg/classes/visualizer/RoadVisualizer$MapView.class: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/anilbatra2185/road_connectivity/01e41662a43d9e289926ebd58eff4f6e14359ae4/APLS-Visualizer/apls_visualizer-dg/classes/visualizer/RoadVisualizer$MapView.class -------------------------------------------------------------------------------- /APLS-Visualizer/apls_visualizer-dg/classes/visualizer/RoadVisualizer$Results.class: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/anilbatra2185/road_connectivity/01e41662a43d9e289926ebd58eff4f6e14359ae4/APLS-Visualizer/apls_visualizer-dg/classes/visualizer/RoadVisualizer$Results.class -------------------------------------------------------------------------------- /APLS-Visualizer/apls_visualizer-dg/classes/visualizer/RoadVisualizer.class: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/anilbatra2185/road_connectivity/01e41662a43d9e289926ebd58eff4f6e14359ae4/APLS-Visualizer/apls_visualizer-dg/classes/visualizer/RoadVisualizer.class -------------------------------------------------------------------------------- /APLS-Visualizer/apls_visualizer-dg/classes/visualizer/Utils.class: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/anilbatra2185/road_connectivity/01e41662a43d9e289926ebd58eff4f6e14359ae4/APLS-Visualizer/apls_visualizer-dg/classes/visualizer/Utils.class -------------------------------------------------------------------------------- /APLS-Visualizer/apls_visualizer-dg/data/band-triplets.txt: -------------------------------------------------------------------------------- 1 | # 1. Coastal: 400 - 450 nm 2 | # 2. Blue: 450 - 510 nm 3 | # 3. Green: 510 - 580 nm 4 | # 4. Yellow: 585 - 625 nm 5 | # 5. Red: 630 - 690 nm 6 | # 6. Red Edge: 705 - 745 nm 7 | # 7. Near-IR1: 770 - 895 nm 8 | # 8. 
Near-IR2: 860 - 1040 nm
9 | 
10 | # Format: 3 integers in the 1..8 range, written with no spaces or other separators, followed by a name (spaces allowed in the name)
11 | 
12 | #753 Vegetation
13 | #875 Urban
14 | #781 Blackwater
15 | #777 Infra red 1
16 | #888 Infra red 2
--------------------------------------------------------------------------------
/APLS-Visualizer/apls_visualizer-dg/data/params.txt:
--------------------------------------------------------------------------------
1 | solution = ./data/solution-example_blank.csv
2 | 
3 | truth = /ssd_scratch/cvit/anil.k/spacenet/AOI_2_Vegas_Roads_Train/summaryData/AOI_2_Vegas_Roads_Train.csv
4 | truth = /ssd_scratch/cvit/anil.k/spacenet/AOI_3_Paris_Roads_Train/summaryData/AOI_3_Paris_Roads_Train.csv
5 | truth = /ssd_scratch/cvit/anil.k/spacenet/AOI_4_Shanghai_Roads_Train/summaryData/AOI_4_Shanghai_Roads_Train.csv
6 | truth = /ssd_scratch/cvit/anil.k/spacenet/AOI_5_Khartoum_Roads_Train/summaryData/AOI_5_Khartoum_Roads_Train.csv
7 | 
8 | image-dir = /ssd_scratch/cvit/anil.k/spacenet/AOI_2_Vegas_Roads_Train/
9 | image-dir = /ssd_scratch/cvit/anil.k/spacenet/AOI_3_Paris_Roads_Train/
10 | image-dir = /ssd_scratch/cvit/anil.k/spacenet/AOI_4_Shanghai_Roads_Train/
11 | image-dir = /ssd_scratch/cvit/anil.k/spacenet/AOI_5_Khartoum_Roads_Train/
12 | 
13 | band-triplets = ./data/band-triplets.txt
--------------------------------------------------------------------------------
/APLS-Visualizer/apls_visualizer-dg/data/params_dg.txt:
--------------------------------------------------------------------------------
1 | solution = ./data/solution_dg.csv
2 | 
3 | truth = ./data/dg_truth_dg.csv
4 | image-dir = /Users/anilbatra/Downloads/DeepGlobe/valid
5 | w = 1500
6 | 
7 | band-triplets = ./data/band-triplets.txt
--------------------------------------------------------------------------------
/APLS-Visualizer/apls_visualizer-dg/data/params_sample.txt:
--------------------------------------------------------------------------------
1 | solution = ./data/solution-example.csv
2 | 
3 | truth = ./data/SpaceNet_Roads_Sample/AOI_2_Vegas_Roads_Sample/summaryData/AOI_2_Vegas_Roads_Sample.csv
4 | truth = ./data/SpaceNet_Roads_Sample/AOI_3_Paris_Roads_Sample/summaryData/AOI_3_Paris_Roads_Sample.csv
5 | truth = ./data/SpaceNet_Roads_Sample/AOI_4_Shanghai_Roads_Sample/summaryData/AOI_4_Shanghai_Roads_Sample.csv
6 | truth = ./data/SpaceNet_Roads_Sample/AOI_5_Khartoum_Roads_Sample/summaryData/AOI_5_Khartoum_Roads_Sample.csv
7 | 
8 | 
9 | image-dir = ./data/SpaceNet_Roads_Sample/AOI_2_Vegas_Roads_Sample
10 | image-dir = ./data/SpaceNet_Roads_Sample/AOI_3_Paris_Roads_Sample
11 | image-dir = ./data/SpaceNet_Roads_Sample/AOI_4_Shanghai_Roads_Sample
12 | image-dir = ./data/SpaceNet_Roads_Sample/AOI_5_Khartoum_Roads_Sample
13 | 
14 | band-triplets = ./data/band-triplets.txt
--------------------------------------------------------------------------------
/APLS-Visualizer/apls_visualizer-dg/data/solution-example_blank.csv:
--------------------------------------------------------------------------------
1 | ImageId,WKT_Pix
--------------------------------------------------------------------------------
/APLS-Visualizer/apls_visualizer-dg/src/docker/BaselinePredictor.java:
--------------------------------------------------------------------------------
1 | package docker;
2 | 
3 | import java.io.BufferedWriter;
4 | import java.io.File;
5 | import java.io.FileWriter;
6 | import java.io.PrintWriter;
7 | import java.util.List;
8 | import java.util.Random;
9 | import java.util.Vector;
10 | 
11 | public class BaselinePredictor { 
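// The 0.31 divisor below (like the 50/0.31 and 4/0.31 constants in
// geom.Metrics) is presumably the SpaceNet ground sample distance in
// metres per pixel, so D is a 5-metre spacing expressed in pixels.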
12 | private static final double D = 5 / 0.31; 13 | private static final double junctionRatio = 0.3; 14 | private PrintWriter out; 15 | private int cnt; 16 | private Random rand; 17 | private List poss; 18 | 19 | public static void main(String[] args) { 20 | // TODO remove, test only 21 | //args = new String[] {"../data/train/AOI_3_Paris_Roads_Train", "out"}; 22 | 23 | if (args.length < 2) { 24 | System.out.println("Usage: docker.BaselinePredictor test_dir[...] out_file"); 25 | System.exit(-1); 26 | } 27 | try { 28 | new BaselinePredictor().run(args); 29 | } 30 | catch (Exception e) { 31 | e.printStackTrace(); 32 | System.exit(-1); 33 | } 34 | } 35 | 36 | private void run(String[] args) throws Exception { 37 | int n = args.length; 38 | String outFilePath = args[n-1]; 39 | 40 | out = new PrintWriter(new BufferedWriter(new FileWriter(outFilePath + ".txt"))); 41 | rand = new Random(0); 42 | poss = new Vector<>(); 43 | for (double x = D; x < 1300; x += 2*D) poss.add((int)(x)); 44 | 45 | for (int i = 0; i < n-1; i++) { 46 | File panDir = new File(args[i], "PAN"); 47 | if (!panDir.exists() || !panDir.isDirectory()) { 48 | System.out.println("PAN directory not found at " + panDir.getAbsolutePath()); 49 | } 50 | processDir(panDir); 51 | } 52 | out.close(); 53 | System.out.println("Done."); 54 | } 55 | 56 | private void processDir(File dir) throws Exception{ 57 | for (File f: dir.listFiles()) { 58 | String name = f.getName(); 59 | if (name.startsWith("PAN_") && name.endsWith(".tif")) { 60 | String id = name.replace("PAN_", ""); 61 | id = id.replace(".tif", ""); 62 | cnt++; 63 | System.out.println("Processing image " + id + " (" + cnt + ")"); 64 | 65 | int p1 = poss.get(0); 66 | int p2 = poss.get(poss.size()-1); 67 | // id,"LINESTRING (0.00 541.93, 484.20 687.83, 773.90 772.25)" 68 | for (int i: poss) { 69 | StringBuilder sb = new StringBuilder(); 70 | sb.append(id).append(",\"LINESTRING ("); 71 | for (int j: poss) { 72 | if (j == p1 || j == p2) { 73 | int j2 = rand.nextDouble() < junctionRatio ? j : j-1; 74 | sb.append(i).append(" ").append(j2).append(", "); 75 | } 76 | else if (rand.nextDouble() < junctionRatio) { 77 | sb.append(i).append(" ").append(j).append(", "); 78 | } 79 | } 80 | sb.delete(sb.length()-2, sb.length()); 81 | sb.append(")"); 82 | out.println(sb.toString()); 83 | } 84 | for (int i: poss) { 85 | StringBuilder sb = new StringBuilder(); 86 | sb.append(id).append(",\"LINESTRING ("); 87 | for (int j: poss) { 88 | sb.append(j).append(" ").append(i).append(", "); 89 | } 90 | sb.delete(sb.length()-2, sb.length()); 91 | sb.append(")"); 92 | out.println(sb.toString()); 93 | } 94 | } 95 | } 96 | } 97 | } 98 | 99 | -------------------------------------------------------------------------------- /APLS-Visualizer/apls_visualizer-dg/src/geom/Graph.java: -------------------------------------------------------------------------------- 1 | package geom; 2 | 3 | import java.awt.Color; 4 | import java.awt.Graphics2D; 5 | import java.awt.RenderingHints; 6 | import java.awt.image.BufferedImage; 7 | import java.io.File; 8 | import java.io.IOException; 9 | import java.util.Arrays; 10 | import java.util.HashMap; 11 | import java.util.HashSet; 12 | import java.util.List; 13 | import java.util.Map; 14 | import java.util.PriorityQueue; 15 | import java.util.Set; 16 | 17 | import javax.imageio.ImageIO; 18 | 19 | import visualizer.Utils; 20 | 21 | public class Graph { 22 | public Set nodes = new HashSet<>(); 23 | /* 24 | * This mapping allows multiple nodes to be injected at the same location. 
25 | * All connectivity info and route calculations are done on the canonic nodes, but 26 | * we bookkeep the aliases as well. 27 | */ 28 | public Map aliasToCanonicNode = new HashMap<>(); 29 | public double minx, miny, maxx, maxy; 30 | 31 | // debug only 32 | public static final int M = 10; // drawing margin 33 | 34 | public Graph copy() { // deep copy of everything 35 | Graph g_ = new Graph(); 36 | Map nodeMap = new HashMap<>(); 37 | for (P2 p: nodes) { 38 | P2 p_ = p.copy(); 39 | nodeMap.put(p, p_); 40 | g_.nodes.add(p_); 41 | } 42 | for (P2 p: nodes) { 43 | for (LineString e: p.edges) { 44 | if (p != e.p1) continue; // enough to do this in one direction 45 | LineString e_ = new LineString(); 46 | e_.p1 = nodeMap.get(p); 47 | e_.p2 = nodeMap.get(e.otherEnd(p)); 48 | for (P2 n: e.points) { 49 | P2 n_ = nodeMap.get(n); 50 | if (n_ == null) { 51 | n_ = n.copy(); 52 | } 53 | e_.points.add(n_); 54 | } 55 | e_.updated(); 56 | e_.p1.addEdge(e_); 57 | e_.p2.addEdge(e_); 58 | } 59 | } 60 | // no need to copy aliasToCanonicNode 61 | g_.updated(); 62 | return g_; 63 | } 64 | 65 | // Set up the graph from a set of roads 66 | public static Graph fromRoads(RoadSet rs) { 67 | Graph g = new Graph(); 68 | 69 | // create nodes and connect them with 1-long segments 70 | Map hashToNode = new HashMap<>(); 71 | for (LineString path: rs.roads) { 72 | P2 prevN = null; 73 | for (P2 p: path.points) { 74 | P2 n = new P2(p.x, p.y); 75 | int hash = n.hashCode(); 76 | if (hashToNode.containsKey(hash)) { 77 | n = hashToNode.get(hash); 78 | } 79 | else { 80 | g.nodes.add(n); 81 | hashToNode.put(hash, n); 82 | } 83 | 84 | if (prevN != null) { 85 | LineString e = new LineString(prevN, n); 86 | prevN.addEdge(e); 87 | n.addEdge(e); 88 | } 89 | prevN = n; 90 | } 91 | } 92 | 93 | // Simplify 94 | // Step 1: remove nodes with 2 neighbours, update paths to keep geometry 95 | Set simpleNodes = new HashSet<>(); 96 | for (P2 p: g.nodes) if (p.edges.size() == 2) simpleNodes.add(p); 97 | Set toRemove = new HashSet<>(); 98 | for (P2 p: g.nodes) { 99 | if (simpleNodes.contains(p)) continue; 100 | 101 | LineString[] oldEdges = p.edges.toArray(new LineString[0]); 102 | for (LineString edge: oldEdges) { 103 | P2 nextP = edge.otherEnd(p); 104 | 105 | if (toRemove.contains(nextP)) { // already processed from the other direction 106 | p.edges.remove(edge); 107 | continue; 108 | } 109 | 110 | // keep edge if other end is also non-simple 111 | if (!simpleNodes.contains(nextP)) continue; 112 | 113 | LineString newEdge = new LineString(); 114 | newEdge.points.add(p); 115 | while (true) { 116 | newEdge.points.add(nextP); 117 | List neighbours = nextP.getNeighbours(); 118 | toRemove.add(nextP); 119 | simpleNodes.remove(nextP); 120 | P2 prevP = newEdge.points.get(newEdge.points.size() - 2); 121 | nextP = null; 122 | for (P2 np: neighbours) { 123 | if (np != prevP) { 124 | nextP = np; break; 125 | } 126 | } 127 | if (nextP == null) { // shouldn't happen 128 | System.out.println("Err in simplify"); 129 | } 130 | 131 | if (!simpleNodes.contains(nextP)) { 132 | newEdge.points.add(nextP); 133 | newEdge.p1 = p; 134 | newEdge.p2 = nextP; 135 | newEdge.updated(); 136 | p.edges.remove(edge); 137 | nextP.edges.remove(edge); 138 | p.addEdge(newEdge); 139 | if (p != nextP) { 140 | nextP.addEdge(newEdge); 141 | } 142 | break; 143 | } 144 | } 145 | } 146 | } // for nodes 147 | for (P2 p: toRemove) { 148 | g.nodes.remove(p); 149 | p.edges.clear(); 150 | } 151 | 152 | // Previously we processed closed loops, by keeping just one node of the cycle and 153 | // 
keeping the rest only as path. To mimic networkx.simplify() functionality now we keep 154 | // all 2-connected nodes in cycles. 155 | /* 156 | while (!simpleNodes.isEmpty()) { 157 | // keep the one with smallest coords 158 | double miny = Double.MAX_VALUE; 159 | P2 startP = null; 160 | for (P2 p2: simpleNodes) { 161 | if (p2.y < miny) { 162 | startP = p2; 163 | miny = p2.y; 164 | } 165 | else if (startP != null && p2.y == startP.y) { 166 | if (p2.x < startP.x) { 167 | startP = p2; 168 | } 169 | } 170 | } 171 | simpleNodes.remove(startP); 172 | LineString edge = startP.edges.iterator().next(); 173 | P2 nextP = edge.otherEnd(startP); 174 | startP.edges.clear(); 175 | LineString newEdge = new LineString(); 176 | newEdge.points.add(startP); 177 | while (true) { 178 | newEdge.points.add(nextP); 179 | List neighbours = nextP.getNeighbours(); 180 | g.nodes.remove(nextP); 181 | nextP.edges.clear(); 182 | simpleNodes.remove(nextP); 183 | P2 prevP = newEdge.points.get(newEdge.points.size() - 2); 184 | nextP = null; 185 | for (P2 np: neighbours) { 186 | if (np != prevP) { 187 | nextP = np; break; 188 | } 189 | } 190 | if (nextP == null) { // shouldn't happen 191 | System.out.println("Err in simplify"); 192 | } 193 | 194 | if (nextP == startP) { // finished loop 195 | newEdge.points.add(nextP); 196 | newEdge.p1 = startP; 197 | newEdge.p2 = nextP; 198 | startP.addEdge(newEdge); 199 | newEdge.updated(); 200 | break; 201 | } 202 | } 203 | } // there are simple nodes 204 | */ 205 | 206 | g.updated(); 207 | return g; 208 | } 209 | 210 | public void updated() { 211 | minx = miny = Double.MAX_VALUE; 212 | maxx = maxy = -Double.MAX_VALUE; 213 | for (P2 p: nodes) { 214 | for (LineString e: p.edges) { 215 | minx = Math.min(minx, e.minx); 216 | maxx = Math.max(maxx, e.maxx); 217 | miny = Math.min(miny, e.miny); 218 | maxy = Math.max(maxy, e.maxy); 219 | } 220 | } 221 | } 222 | 223 | public void insertMidpoints(double pathDelta, double minCurvature) { 224 | Set allEdges = new HashSet<>(); 225 | Set processedEdges = new HashSet<>(); 226 | for (P2 p: nodes) { 227 | for (LineString e: p.edges) { 228 | allEdges.add(e); 229 | } 230 | } 231 | for (LineString e: allEdges) { 232 | if (processedEdges.contains(e)) continue; // done from the other end already 233 | 234 | double straight = new P2(e.minx, e.miny).dist(new P2(e.maxx, e.maxy)); 235 | if (Math.abs(straight - e.length) / e.length < minCurvature) continue; 236 | if (e.length < 0.75 * pathDelta) continue; 237 | 238 | int n; 239 | double dist; 240 | if (e.length < pathDelta) { 241 | n = 1; 242 | dist = e.length / 2; 243 | } 244 | else { 245 | n = (int)(Math.floor(e.length / pathDelta)); 246 | dist = e.length / (n+1); 247 | } 248 | 249 | // p1....e....m...e......e...p2 250 | 251 | P2 p1 = e.p1; 252 | P2 p2 = e.p2; 253 | p1.edges.remove(e); 254 | p2.edges.remove(e); 255 | P2 startP = p1; // start node of the new edge to build 256 | LineString remainingEdge = e; 257 | for (int i = 0; i < n; i++) { 258 | LineString[] es = remainingEdge.cut(dist); 259 | LineString e1 = es[0]; 260 | LineString e2 = es[1]; 261 | P2 newP = e2.points.get(0); 262 | startP.addEdge(e1); 263 | newP.addEdge(e1); 264 | nodes.add(newP); 265 | startP = newP; 266 | remainingEdge = e2; 267 | } 268 | // last section 269 | startP.addEdge(remainingEdge); 270 | p2.addEdge(remainingEdge); 271 | 272 | processedEdges.add(e); 273 | } 274 | } 275 | 276 | public P2 injectPoint(P2 externalP, double maxDistance) { 277 | // try special case first: it exactly matches one of the old nodes 278 | for (P2 p: nodes) { 
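// (P2.equals tolerates up to ~0.1 px of difference, so this exact-match
// pass also catches near-coincident nodes, not just identical coordinates)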
279 | if (p.equals(externalP)) { 280 | P2 newP = p.copy(); 281 | // no need to change graph, only mark new node as alias of old 282 | aliasToCanonicNode.put(newP, p); 283 | return newP; 284 | } 285 | } 286 | 287 | double minDist = Double.MAX_VALUE; 288 | LineString bestEdge = null; 289 | P2 bestP1 = null; // start of line segment on bestEdge where newP is injected 290 | P2 bestNewP = null; 291 | P2 oldMatch = null; // non-null if matched to an existing node 292 | NODE_LOOP: 293 | for (P2 p: nodes) { 294 | for (LineString e: p.edges) { 295 | if (p != e.p1) continue; // enough to do this in one direction 296 | if (externalP.x < e.minx - minDist) continue; 297 | if (externalP.y < e.miny - minDist) continue; 298 | if (externalP.x > e.maxx + minDist) continue; 299 | if (externalP.y > e.maxy + minDist) continue; 300 | 301 | P2 p1 = null; 302 | for (P2 p2: e.points) { 303 | if (p1 != null) { 304 | P2 newP = externalP.projectToLineSegment(p1, p2); 305 | double d = newP.distance; 306 | if (d < minDist) { 307 | minDist = d; 308 | bestEdge = e; 309 | if (newP.equals(p2)) { 310 | bestP1 = p2; 311 | } 312 | else { 313 | bestP1 = p1; 314 | } 315 | bestNewP = newP; 316 | if (newP.equals(e.p1)) oldMatch = e.p1; 317 | else if (newP.equals(e.p2)) oldMatch = e.p2; 318 | else oldMatch = null; 319 | } 320 | if (d == 0) { // can't be better, stop 321 | break NODE_LOOP; 322 | } 323 | } 324 | p1 = p2; 325 | } 326 | } 327 | } // for nodes and edges 328 | 329 | if (minDist > maxDistance) { 330 | return null; 331 | } 332 | 333 | if (oldMatch != null) { // no need to change graph, only mark new node as alias of old 334 | aliasToCanonicNode.put(bestNewP, oldMatch); 335 | } 336 | else { // add newP, split edge 337 | 338 | // p1....p....newP...p......p...p2 339 | LineString e1 = new LineString(); 340 | LineString e2 = new LineString(); 341 | for (int i = 0; i < bestEdge.points.size(); i++) { 342 | P2 p = bestEdge.points.get(i); 343 | if (p.equals(bestP1)) { 344 | if (!bestNewP.equals(bestP1)) { 345 | e1.points.add(bestP1); 346 | } 347 | e1.points.add(bestNewP); 348 | e2.points.add(bestNewP); 349 | for (int j = i+1; j < bestEdge.points.size(); j++) { 350 | e2.points.add(bestEdge.points.get(j)); 351 | } 352 | break; 353 | } 354 | e1.points.add(p); 355 | } 356 | e1.updated(); 357 | e2.updated(); 358 | bestEdge.p1.edges.remove(bestEdge); 359 | bestEdge.p2.edges.remove(bestEdge); 360 | nodes.add(bestNewP); 361 | bestNewP.addEdge(e1); 362 | bestNewP.addEdge(e2); 363 | bestEdge.p1.addEdge(e1); 364 | bestEdge.p2.addEdge(e2); 365 | } 366 | 367 | return bestNewP; 368 | } 369 | 370 | // A copy of a node to live in the priority queue during Dijkstra 371 | private class PQNode implements Comparable { 372 | public P2 p; 373 | public double distance; 374 | 375 | public PQNode(P2 p, double d) { 376 | this.p = p; 377 | distance = d; 378 | } 379 | 380 | @Override 381 | public int compareTo(PQNode other) { 382 | return Double.compare(this.distance, other.distance); 383 | } 384 | 385 | @Override 386 | public String toString() { 387 | return p.toString() + ": " + distance; 388 | } 389 | } 390 | 391 | // A priority queue based implementation of Dijkstra algorithm 392 | public void shortestPathsFromNode(P2 startP) { 393 | Set seen = new HashSet<>(); 394 | PriorityQueue q = new PriorityQueue<>(); 395 | for (P2 p: nodes) { 396 | double distance = p == startP ? 
0 : Double.MAX_VALUE; 397 | p.distance = distance; 398 | q.add(new PQNode(p, distance)); 399 | } 400 | while (!q.isEmpty()) { 401 | PQNode node = q.poll(); 402 | if (seen.contains(node.p)) continue; 403 | seen.add(node.p); 404 | 405 | for (LineString e: node.p.edges) { 406 | P2 p2 = e.otherEnd(node.p); 407 | if (seen.contains(p2)) continue; 408 | double dist = node.distance + e.length; 409 | if (dist < p2.distance) { 410 | p2.distance = dist; 411 | q.add(new PQNode(p2, dist)); 412 | } 413 | } 414 | } 415 | } 416 | 417 | @Override 418 | public String toString() { 419 | StringBuilder sb = new StringBuilder(); 420 | P2[] pArr = nodes.toArray(new P2[0]); 421 | Arrays.sort(pArr); 422 | for (P2 p: pArr) { 423 | sb.append(p).append("\n"); 424 | for (LineString e: p.edges) { 425 | sb.append(" ").append(e).append("\n"); 426 | } 427 | } 428 | return sb.toString(); 429 | } 430 | 431 | public void save(String name, boolean mirror) throws IOException { // debug only 432 | int w = 500; 433 | BufferedImage img = new BufferedImage(w, w, BufferedImage.TYPE_INT_ARGB); 434 | draw(img, Color.white, -1, w, 0, 0, mirror); 435 | ImageIO.write(img, "png", new File(name)); 436 | } 437 | 438 | public void draw(BufferedImage img, Color c, double range, int w, int x0, int y0, boolean mirror) { // debug only 439 | Graphics2D g2 = (Graphics2D) img.getGraphics(); 440 | g2.setRenderingHint(RenderingHints.KEY_ANTIALIASING, 441 | RenderingHints.VALUE_ANTIALIAS_ON); 442 | int r = c.getRed(); 443 | int g = c.getGreen(); 444 | int b = c.getBlue(); 445 | int yBase = mirror ? w - M : M; 446 | if (range <= 0) range = Math.max(maxx - minx, maxy - miny); 447 | double scale = (w - 2*M) / range; 448 | double ySign = mirror ? -1 : 1; 449 | g2.setColor(new Color(r,g,b,180)); 450 | for (P2 p: nodes) { 451 | for (LineString e: p.edges) { 452 | if (e.p1 != e.p2) { 453 | if (e.p2 == p) continue; 454 | } 455 | P2 prev = null; 456 | for (P2 p2: e.points) { 457 | if (prev != null) { 458 | int x1 = (int)(x0 + M + (p2.x - minx) * scale); 459 | int y1 = (int)(y0 + yBase + (p2.y - miny)* scale * ySign); 460 | int x2 = (int)(x0 + M + (prev.x - minx) * scale); 461 | int y2 = (int)(y0 + yBase + (prev.y - miny) * scale * ySign); 462 | g2.drawLine(x1, y1, x2, y2); 463 | } 464 | prev = p2; 465 | } 466 | } 467 | } 468 | 469 | Color c1 = new Color(r,g,b,180); 470 | Color c2 = new Color(255,255,0,180); 471 | Color c3 = new Color(255,0,255,180); 472 | 473 | for (P2 p: nodes) { 474 | int n = p.edges.size(); 475 | if (n == 1) g2.setColor(c1); 476 | else if (n > 2) g2.setColor(c3); 477 | else { // two edges 478 | g2.setColor(c2); 479 | // unless there are self loops 480 | for (LineString e: p.edges) { 481 | if (e.p1 == e.p2) g2.setColor(c3); 482 | } 483 | } 484 | int x = (int)(x0 + M + (p.x - minx) * scale); 485 | int y = (int)(y0 + yBase + (p.y - miny) * scale * ySign); 486 | g2.fillOval(x-3, y-3, 6, 6); 487 | } 488 | } 489 | 490 | // test only 491 | public static void main(String[] args) throws IOException { 492 | RoadSet rs = new RoadSet(); 493 | rs.roads.add(LineString.fromText("LINESTRING (1 1, 3 1, 5 1, 5 3, 5 4, 3 3, 3 1)")); 494 | rs.roads.add(LineString.fromText("LINESTRING (6 1, 8 1, 8 3, 6 3, 6 1)")); 495 | rs.roads.add(LineString.fromText("LINESTRING (8 4, 8 6, 6 6, 6 4, 8 4, 10 4, 10 6, 8 6)")); 496 | rs.roads.add(LineString.fromText("LINESTRING (5 4, 6 4)")); 497 | 498 | String err = rs.getError(); 499 | if (err != null) { 500 | System.out.println(err); 501 | System.exit(0); 502 | } 503 | Graph g = fromRoads(rs); 504 | 505 | for (P2 start: 
g.nodes) { 506 | g.shortestPathsFromNode(start); 507 | for (P2 p2: g.nodes) { 508 | if (!start.equals(p2) && p2.distance < Double.MAX_VALUE) { 509 | System.out.println(start + " -> " + p2 + " : " + Utils.f(p2.distance)); 510 | } 511 | } 512 | } 513 | 514 | System.out.println(g); 515 | g.save("out.png", true); 516 | } 517 | } 518 | -------------------------------------------------------------------------------- /APLS-Visualizer/apls_visualizer-dg/src/geom/LineString.java: -------------------------------------------------------------------------------- 1 | package geom; 2 | 3 | import java.util.List; 4 | import java.util.Vector; 5 | 6 | import visualizer.Utils; 7 | 8 | public class LineString { 9 | public List points; 10 | // end points 11 | public P2 p1; 12 | public P2 p2; 13 | 14 | public double length; 15 | public double minx, miny, maxx, maxy; // bounding rect 16 | 17 | public LineString() { 18 | points = new Vector<>(); 19 | } 20 | 21 | public LineString(P2 p1, P2 p2) { 22 | this(); 23 | this.p1 = p1; 24 | this.p2 = p2; 25 | points.add(p1); 26 | points.add(p2); 27 | updated(); 28 | } 29 | 30 | /* 31 | * Not fool-proof parsing, but handles most cases correctly. 32 | * Should be called in try-catch. 33 | */ 34 | public static LineString fromText(String s) { 35 | LineString ret = new LineString(); 36 | // "LINESTRING (250 250, 250 350, 1050 350)" 37 | s = s.replace("\"", ""); 38 | s = s.toUpperCase().trim(); 39 | if (!s.startsWith("LINESTRING")) return null; 40 | 41 | if (s.contains("EMPTY")) return ret; 42 | s = s.replace("LINESTRING", ""); 43 | s = s.replace("(", ""); 44 | s = s.replace(")", ""); 45 | String[] parts = s.split(","); 46 | for (String coords: parts) { 47 | coords = coords.trim(); 48 | String[] xy = coords.split(" "); 49 | double x = Double.parseDouble(xy[0]); 50 | double y = Double.parseDouble(xy[1]); 51 | P2 p = new P2(x, y); 52 | ret.points.add(p); 53 | } 54 | ret.updated(); 55 | return ret; 56 | } 57 | 58 | public P2 otherEnd(P2 p) { 59 | if (p.equals(p1)) return p2; 60 | return p1; 61 | } 62 | 63 | public void updated() { 64 | p1 = points.get(0); 65 | p2 = points.get(points.size() - 1); 66 | double len = 0; 67 | minx = miny = Double.MAX_VALUE; 68 | maxx = maxy = -Double.MAX_VALUE; 69 | P2 prev = null; 70 | for (P2 p: points) { 71 | if (prev != null) len += prev.dist(p); 72 | minx = Math.min(minx, p.x); 73 | maxx = Math.max(maxx, p.x); 74 | miny = Math.min(miny, p.y); 75 | maxy = Math.max(maxy, p.y); 76 | prev = p; 77 | } 78 | length = len; 79 | } 80 | 81 | // Creates two linestrings, cut at distance from p1 82 | public LineString[] cut(double distance) { 83 | if (distance < 0) distance = 0; 84 | if (distance > length) distance = length; 85 | 86 | LineString e1 = new LineString(); 87 | e1.points.add(p1); 88 | LineString e2 = new LineString(); 89 | double total = 0; 90 | // p1....e......e.....m....e...p2 91 | for (int i = 1; i < points.size(); i++) { 92 | P2 prev = points.get(i-1); 93 | P2 next = points.get(i); 94 | double d = prev.dist(next); 95 | if (total + d < distance) { 96 | total += d; 97 | e1.points.add(next); 98 | } 99 | else if (total + d == distance) { 100 | total += d; 101 | 102 | e1.points.add(next); 103 | e1.p1 = p1; 104 | e1.p2 = next; 105 | e1.updated(); 106 | 107 | for (int j = i; j < points.size(); j++) { 108 | e2.points.add(points.get(j)); 109 | } 110 | e2.p1 = next; 111 | e2.p2 = p2; 112 | e2.updated(); 113 | 114 | break; 115 | } 116 | else { // total + d > distance: project on segment 117 | double r = (distance - total) / d; 118 | double x = prev.x + 
r * (next.x - prev.x); 119 | double y = prev.y + r * (next.y - prev.y); 120 | P2 mid = new P2(x, y); 121 | 122 | e1.points.add(mid); 123 | e1.p1 = p1; 124 | e1.p2 = mid; 125 | e1.updated(); 126 | 127 | e2.points.add(mid); 128 | for (int j = i; j < points.size(); j++) { 129 | e2.points.add(points.get(j)); 130 | } 131 | e2.p1 = mid; 132 | e2.p2 = p2; 133 | e2.updated(); 134 | 135 | break; 136 | } 137 | } 138 | 139 | return new LineString[] {e1, e2}; 140 | } 141 | 142 | @Override 143 | public String toString() { 144 | String ret = ""; 145 | for (P2 p: points) ret += p + " "; 146 | ret += " Len: " + Utils.f(length); 147 | return ret; 148 | } 149 | 150 | // test only 151 | public static void main(String[] args) { 152 | LineString e = LineString.fromText("LINESTRING (1 1, 3 1, 5 1, 5 3, 5 4, 3 3, 3 1)"); 153 | LineString[] es = e.cut(5); 154 | System.out.println(es[0]); 155 | System.out.println(es[1]); 156 | } 157 | } -------------------------------------------------------------------------------- /APLS-Visualizer/apls_visualizer-dg/src/geom/Metrics.java: -------------------------------------------------------------------------------- 1 | package geom; 2 | 3 | import java.awt.Color; 4 | import java.awt.Graphics2D; 5 | import java.awt.image.BufferedImage; 6 | import java.io.File; 7 | import java.io.IOException; 8 | import java.util.HashMap; 9 | import java.util.Map; 10 | 11 | import javax.imageio.ImageIO; 12 | 13 | import static visualizer.Utils.*; 14 | 15 | /* 16 | * Implements the scoring logic. Optionally can draw the graphs 17 | * and output detailed log. 18 | */ 19 | public class Metrics { 20 | private static final double PATH_DELTA = 50 / 0.31; 21 | private static final double MIN_CURVATURE = -1; 22 | public static final double MAX_SNAP_DISTANCE = 4 / 0.31; 23 | 24 | // variables for debug only 25 | public static boolean debug = true; 26 | public static boolean draw = true; 27 | private static BufferedImage img; 28 | private static Graphics2D g2d; 29 | public static String imageName = "metrics"; 30 | private static int w = 500; 31 | private static int callCnt = 0; 32 | private static double range; 33 | 34 | /* 35 | * Returns a double[3] with 3 scores: G1->G2, G2->G1, harmonic mean. 36 | * All 3 numbers are in [0..1], where 0 is bad, 1 is perfect. 37 | * The 'debug' and 'draw' switches should be set externally before calling this. 
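 * (Harmonic mean example: s1 = 0.8 and s2 = 0.4 give s = 2*0.8*0.4/(0.8+0.4) ≈ 0.533.)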
38 | */ 39 | public static double[] score(Graph g1, Graph g2) { 40 | if (g1.nodes.isEmpty() && g2.nodes.isEmpty()) { 41 | return new double[]{1, 1, 1}; 42 | } 43 | if (g1.nodes.isEmpty() || g2.nodes.isEmpty()) { 44 | return new double[]{0, 0, 0}; 45 | } 46 | 47 | g1.insertMidpoints(PATH_DELTA, MIN_CURVATURE); 48 | g2.insertMidpoints(PATH_DELTA, MIN_CURVATURE); 49 | 50 | if (draw) { 51 | img = new BufferedImage(w, w + 4 * Graph.M, BufferedImage.TYPE_INT_ARGB); 52 | g2d = (Graphics2D) img.getGraphics(); 53 | g2d.setColor(Color.black); 54 | g2d.fillRect(0, 0, img.getWidth(), img.getHeight()); 55 | double range1 = Math.max(g1.maxx - g1.minx, g1.maxy - g1.miny); 56 | double range2 = Math.max(g2.maxx - g2.minx, g2.maxy - g2.miny); 57 | range = Math.max(range1, range2); 58 | g1.draw(img, Color.white, range, w/2, 0, 2 * Graph.M, true); 59 | g2.draw(img, Color.white, range, w/2, w/2, 2 * Graph.M, true); 60 | g2d.setColor(Color.white); 61 | g2d.drawString("G1", 20, 2 * Graph.M); 62 | g2d.drawString("G2", w/2 + 20, 2 * Graph.M); 63 | callCnt = 0; 64 | } 65 | 66 | double s1 = scoreOneWay(g1, g2); 67 | double s2 = scoreOneWay(g2, g1); 68 | double s = 0; 69 | if (s1 + s2 > 0) { 70 | s = 2 * s1 * s2 / (s1 + s2); 71 | } 72 | log("\nS1: =" + s1); 73 | log("\nS2: =" + s2); 74 | log("\nS: =" + s); 75 | log("\n"); 76 | 77 | if (draw) { 78 | try { 79 | ImageIO.write(img, "png", new File(imageName + ".png")); 80 | } 81 | catch (IOException e) { 82 | e.printStackTrace(); 83 | } 84 | } 85 | 86 | return new double[]{s1, s2, s}; 87 | } 88 | 89 | private static double scoreOneWay(Graph g1, Graph g2) { 90 | // inject points of g1 into g2 91 | Map nodeMap = new HashMap<>(); 92 | Graph g2_ = g2.copy(); 93 | for (P2 p1: g1.nodes) { 94 | P2 p2 = g2_.injectPoint(p1, MAX_SNAP_DISTANCE); 95 | if (p2 != null) { 96 | nodeMap.put(p1, p2); 97 | } 98 | } 99 | 100 | if (debug) { 101 | log("\nG1:\n" + g1); 102 | log("\nG2':\n" + g2_); 103 | log(""); 104 | } 105 | double totalDiff = 0; // sum of error 106 | int routeCnt = 0; // number of compared routes 107 | for (P2 start1: g1.nodes) { 108 | g1.shortestPathsFromNode(start1); 109 | 110 | P2 start2 = nodeMap.get(start1); 111 | 112 | if (start2 == null) { 113 | // CASE 1 114 | // if the start node is missing from proposal, use maximum diff for 115 | // all possible routes from the start node 116 | int missCnt = 0; 117 | for (P2 p: g1.nodes) { 118 | if (p != start1 && p.distance < Double.MAX_VALUE) { 119 | missCnt++; 120 | } 121 | } 122 | totalDiff += missCnt; 123 | routeCnt += missCnt; 124 | if (debug) log(" " + start1 + ": no match for start in G2_, missed " + missCnt); 125 | } 126 | else { 127 | if (debug) log(" " + start1 + " ->"); 128 | // found matching node in g2, compare routes. Use canonic nodes in g2_! 
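// (injectPoint records an exact hit on an existing node as an alias instead
// of splitting an edge, so resolve aliases to canonical nodes before Dijkstra)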
129 | if (g2_.aliasToCanonicNode.containsKey(start2)) { 130 | start2 = g2_.aliasToCanonicNode.get(start2); 131 | } 132 | g2_.shortestPathsFromNode(start2); 133 | for (P2 end1: g1.nodes) { 134 | if (end1 == start1) continue; 135 | 136 | P2 end2 = nodeMap.get(end1); 137 | if (end2 != null && g2_.aliasToCanonicNode.containsKey(end2)) { 138 | end2 = g2_.aliasToCanonicNode.get(end2); 139 | } 140 | 141 | double d1 = end1.distance; 142 | if (d1 < Double.MAX_VALUE) { 143 | // there is route between start1 and end1 144 | routeCnt++; 145 | if (end2 == null) { 146 | // CASE 3: no such node in g2, max penalty 147 | totalDiff++; 148 | if (debug) log(" " + end1 + ": no match for end in G2_"); 149 | continue; 150 | } 151 | double d2 = end2.distance; 152 | if (d2 == Double.MAX_VALUE) { 153 | // CASE 3b: no route in g2, max penalty 154 | totalDiff++; 155 | if (debug) log(" " + end1 + ": no route in G2_"); 156 | } 157 | else { 158 | // CASE 2: both paths exist, compare them 159 | double diff = routeDiff(d1, d2); 160 | totalDiff += diff; 161 | if (debug) log(" " + end1 + ": " + f(d1) + " / " + f(d2) + " => " + f(diff)); 162 | } 163 | } 164 | } // for end1 165 | } // start2 not null 166 | } // for start1 167 | 168 | if (routeCnt > 0) totalDiff /= routeCnt; 169 | if (debug) log("\nAverage diff: " + f(totalDiff)); 170 | 171 | double score = 1 - totalDiff; 172 | if (draw) { 173 | g2_.draw(img, Color.white, range, w/2, w/2 * callCnt, w/2 + 2 * Graph.M, true); 174 | g2d.setColor(Color.white); 175 | String msg = callCnt == 0 ? "G2' (G1->G2): " : "G1' (G2->G1): "; 176 | msg += f(score); 177 | g2d.drawString(msg, 20 + w/2 * callCnt, w + 4*Graph.M-2); 178 | callCnt++; 179 | } 180 | 181 | return score; 182 | } 183 | 184 | private static double routeDiff(double d1, double d2) { 185 | if (d1 == 0 && d2 == 0) return 0; 186 | if (d1 == 0 || d2 == 0) return 1; 187 | return Math.min(1, Math.abs(d1 - d2) / d1); 188 | } 189 | 190 | private static void log(String s) { 191 | System.out.println(s); 192 | } 193 | 194 | // test only 195 | public static void main(String[] args) { 196 | RoadSet rs1 = new RoadSet(); 197 | rs1.roads.add(LineString.fromText("LINESTRING (1 1, 3 1, 5 1, 5 3, 5 4, 3 4, 3 1)")); 198 | Graph g1 = Graph.fromRoads(rs1); 199 | 200 | RoadSet rs2 = new RoadSet(); 201 | rs2.roads.add(LineString.fromText("LINESTRING (1 1, 3 1, 5 1, 5 3, 5 4, 3 4)")); 202 | rs2.roads.add(LineString.fromText("LINESTRING (1 2, 1 3)")); 203 | Graph g2 = Graph.fromRoads(rs2); 204 | 205 | double[] scores = score(g1, g2); 206 | System.out.println(f(scores[0]) + ", " + f(scores[1]) + " : " + f(scores[2])); 207 | 208 | } 209 | } 210 | -------------------------------------------------------------------------------- /APLS-Visualizer/apls_visualizer-dg/src/geom/P2.java: -------------------------------------------------------------------------------- 1 | package geom; 2 | 3 | import static visualizer.Utils.f; 4 | 5 | import java.util.HashSet; 6 | import java.util.List; 7 | import java.util.Set; 8 | import java.util.Vector; 9 | 10 | public class P2 implements Comparable { 11 | public double x; 12 | public double y; 13 | public Set edges; 14 | public double distance; // helper to return various tmp results 15 | 16 | public P2(double x, double y) { 17 | this.x = x; this.y = y; 18 | edges = new HashSet<>(); 19 | } 20 | 21 | public P2 copy() { 22 | P2 p = new P2(x, y); 23 | // NOTE edges are not copied 24 | return p; 25 | } 26 | 27 | public void addEdge(LineString edge) { 28 | edges.add(edge); 29 | } 30 | 31 | public List getNeighbours() { 32 | List 
ret = new Vector<>(); 33 | for (LineString e: edges) { 34 | ret.add(e.otherEnd(this)); 35 | } 36 | return ret; 37 | } 38 | 39 | public double dist(P2 p) { 40 | return Math.hypot(x-p.x, y-p.y); 41 | } 42 | 43 | public static double dist(P2 p, P2 r) { 44 | return p.dist(r); 45 | } 46 | 47 | // Find projection point on a segment. Distance is returned in 'distance'. 48 | public P2 projectToLineSegment(P2 a, P2 b) { 49 | if (this.equals(a)) return a.copy(); 50 | if (this.equals(b)) return b.copy(); 51 | 52 | double dx = b.x - a.x; 53 | double dy = b.y - a.y; 54 | double len2 = dx*dx + dy*dy; 55 | double u = ((x - a.x) * dx + (y - a.y) * dy) / len2; 56 | P2 ret; 57 | if (u > 1) { 58 | ret = b.copy(); 59 | } 60 | else if (u < 0) { 61 | ret = a.copy(); 62 | } 63 | else { 64 | double px = a.x + u * dx; 65 | double py = a.y + u * dy; 66 | ret = new P2(px, py); 67 | } 68 | ret.distance = this.dist(ret); 69 | return ret; 70 | } 71 | 72 | @Override 73 | public String toString() { 74 | return "(" + f(x) + "," + f(y) + ")"; 75 | } 76 | 77 | @Override 78 | public boolean equals(Object o) { 79 | // supporting 0.1 precision 80 | if (!(o instanceof P2)) return false; 81 | P2 p = (P2)o; 82 | double d2 = (x - p.x) * (x - p.x) + (y - p.y) * (y - p.y); 83 | return d2 < 1e-2; 84 | } 85 | 86 | @Override 87 | public int hashCode() { 88 | // 1354.3;1789.6 -> 1354317896 89 | long x1 = Math.round(10 * x); 90 | long y1 = Math.round(10 * y); 91 | return (int)(100000 * x1 + y1); 92 | } 93 | 94 | @Override 95 | public int compareTo(P2 o) { 96 | return this.hashCode() - o.hashCode(); 97 | } 98 | } -------------------------------------------------------------------------------- /APLS-Visualizer/apls_visualizer-dg/src/geom/RoadSet.java: -------------------------------------------------------------------------------- 1 | package geom; 2 | 3 | import java.util.HashSet; 4 | import java.util.List; 5 | import java.util.Set; 6 | import java.util.Vector; 7 | import java.util.regex.Matcher; 8 | import java.util.regex.Pattern; 9 | 10 | /* 11 | * A list of roads (as LineStrings) plus error checking logic 12 | */ 13 | public class RoadSet { 14 | public List roads = new Vector<>(); 15 | 16 | public String getError() { 17 | // There is an EMPTY one and more than one roads 18 | if (roads.size() > 1) { 19 | for (LineString road: roads) { 20 | if (road.points.size() == 0) { 21 | return "LINESTRING EMPTY should be alone in a road network"; 22 | } 23 | } 24 | } 25 | 26 | // stand-alone points 27 | for (LineString road: roads) { 28 | if (road.points.size() == 1) { 29 | return "Unconnected point at " + road.points.get(0); 30 | } 31 | } 32 | 33 | // empty sections 34 | for (LineString road: roads) { 35 | P2 p1 = null; 36 | for (P2 p2: road.points) { 37 | if (p1 != null && p2.dist(p1) < 0.01) return "Empty section in road at " + p2; 38 | p1 = p2; 39 | } 40 | } 41 | 42 | // repeated sections 43 | String err = ""; 44 | Set hashes = new HashSet<>(); 45 | for (LineString road: roads) { 46 | P2 p1 = null; 47 | for (P2 p2: road.points) { 48 | if (p1 != null) { 49 | int i1 = p1.hashCode(); 50 | int i2 = p2.hashCode(); 51 | long h1 = i1 * (long)1e10 + i2; 52 | long h2 = i2 * (long)1e10 + i1; 53 | if (hashes.contains(h1) || hashes.contains(h2)) { 54 | err += "Duplicate section: " + p1 + " - " + p2 + "; "; 55 | } 56 | hashes.add(h1); 57 | hashes.add(h2); 58 | } 59 | p1 = p2; 60 | } 61 | } 62 | if (!err.isEmpty()) return err; 63 | 64 | return null; 65 | } 66 | 67 | /* 68 | * Ignore this method. A private tool to create a RoadSet from a Geogebra export file. 
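 * Expected input is a run of point pairs such as "(10.,10.)(20.,10.)(20.,10.)(20.,20.)",
 * producing one two-point LineString per pair (see the regex below and MetricsTest).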
69 | */ 70 | public static RoadSet fromText(String line) { 71 | // (1,1)(10,1) 72 | RoadSet rs = new RoadSet(); 73 | Pattern p = Pattern.compile("\\(([0-9]+)\\.,([0-9]+)\\.\\)\\(([0-9]+)\\.,([0-9]+)\\.\\)"); 74 | Matcher m = p.matcher(line); 75 | while (m.find()) { 76 | double x1 = Double.parseDouble(m.group(1)); 77 | double y1 = Double.parseDouble(m.group(2)); 78 | double x2 = Double.parseDouble(m.group(3)); 79 | double y2 = Double.parseDouble(m.group(4)); 80 | LineString e = new LineString(new P2(x1, y1), new P2(x2, y2)); 81 | rs.roads.add(e); 82 | } 83 | 84 | return rs; 85 | } 86 | } -------------------------------------------------------------------------------- /APLS-Visualizer/apls_visualizer-dg/src/test/BatchMetricTest.java: -------------------------------------------------------------------------------- 1 | package test; 2 | 3 | import static visualizer.Utils.f; 4 | 5 | import java.awt.BasicStroke; 6 | import java.awt.Color; 7 | import java.awt.Graphics2D; 8 | import java.awt.RenderingHints; 9 | import java.awt.image.BufferedImage; 10 | import java.io.File; 11 | import java.io.FileReader; 12 | import java.io.IOException; 13 | import java.io.LineNumberReader; 14 | import java.util.HashMap; 15 | import java.util.List; 16 | import java.util.Map; 17 | import java.util.Vector; 18 | 19 | import javax.imageio.ImageIO; 20 | 21 | import geom.Graph; 22 | import geom.LineString; 23 | import geom.Metrics; 24 | import geom.P2; 25 | import geom.RoadSet; 26 | 27 | /* 28 | * Standalone tool to run a set of tests found in a pair of truth/proposal files. 29 | * Outputs detailed log of the scoring and also images of the graphs. 30 | */ 31 | public class BatchMetricTest { 32 | private String truthFile = "../data/metrictest/truth1.csv"; 33 | private String proposalFile = "../data/metrictest/proposal1.csv"; 34 | private String singleIdToTest = null; // set to non-null to test a single image 35 | private boolean detailedLog = true; 36 | 37 | private List idList; 38 | 39 | // for drawing 40 | double xmax, xmin, ymax, ymin, scale, ySign; 41 | private int M, w, yBase; 42 | 43 | private void run() { 44 | idList = new Vector<>(); 45 | Map idToTruthRS = load(truthFile); 46 | Map idToProposalRS = load(proposalFile); 47 | Metrics.debug = detailedLog; 48 | Metrics.draw = true; 49 | int cnt = 0; 50 | 51 | for (String id: idList) { 52 | if (singleIdToTest != null && !singleIdToTest.equals(id)) continue; 53 | Graph g1 = Graph.fromRoads(idToTruthRS.get(id)); 54 | Graph g2 = Graph.fromRoads(idToProposalRS.get(id)); 55 | String cntString = "" + cnt++; 56 | while (cntString.length() < 3) cntString = "0" + cntString; 57 | String imageName = cntString + "-" + id; 58 | Metrics.imageName = imageName + "-details"; 59 | log("\n============\n" + id); 60 | double[] scores = Metrics.score(g1, g2); 61 | log("G1->G2: " + f(scores[0]) + "\tG2->G1: " + f(scores[1]) + "\tAvg: " + f(scores[2])); 62 | draw(imageName, g1, g2, scores, true); 63 | } 64 | } 65 | 66 | // Copies most of Graph.draw functionality only to make it possible to draw two 67 | // graphs on the same image using the same scale 68 | private void draw(String name, Graph g1, Graph g2, double[] scores, boolean mirror) { 69 | w = 500; 70 | BufferedImage img = new BufferedImage(w, w + 4 * Graph.M, BufferedImage.TYPE_INT_ARGB); 71 | Graphics2D g2d = (Graphics2D) img.getGraphics(); 72 | g2d.setRenderingHint(RenderingHints.KEY_ANTIALIASING, 73 | RenderingHints.VALUE_ANTIALIAS_ON); 74 | xmax = -Double.MAX_VALUE; 75 | xmin = Double.MAX_VALUE; 76 | ymax = -Double.MAX_VALUE; 77 | 
ymin = Double.MAX_VALUE; 78 | for (P2 p: g1.nodes) { 79 | for (LineString e: p.edges) { 80 | for (P2 p2: e.points) { 81 | xmax = Math.max(xmax, p2.x); 82 | xmin = Math.min(xmin, p2.x); 83 | ymax = Math.max(ymax, p2.y); 84 | ymin = Math.min(ymin, p2.y); 85 | } 86 | } 87 | } 88 | for (P2 p: g2.nodes) { 89 | for (LineString e: p.edges) { 90 | for (P2 p2: e.points) { 91 | xmax = Math.max(xmax, p2.x); 92 | xmin = Math.min(xmin, p2.x); 93 | ymax = Math.max(ymax, p2.y); 94 | ymin = Math.min(ymin, p2.y); 95 | } 96 | } 97 | } 98 | double range = Math.max(xmax - xmin, ymax - ymin); 99 | scale = (w - 2 * Graph.M) / range; 100 | ySign = mirror ? -1 : 1; 101 | yBase = mirror ? w + 2 * Graph.M: 2 * Graph.M; 102 | 103 | g2d.setStroke(new BasicStroke(3)); 104 | draw(g2d, g1, new Color(0,255,255,180)); 105 | draw(g2d, g2, new Color(255,255,0,150)); 106 | 107 | g2d.setColor(Color.white); 108 | String res = "G1->G2: " + f(scores[0]) + "; G2->G1: " + f(scores[1]) + "; Avg: " + f(scores[2]); 109 | g2d.drawString(res, 20, 2 * Graph.M); 110 | try { 111 | ImageIO.write(img, "png", new File(name + ".png")); 112 | } catch (IOException e) { 113 | e.printStackTrace(); 114 | } 115 | } 116 | 117 | private void draw(Graphics2D g2d, Graph g, Color color) { 118 | g2d.setColor(color); 119 | for(P2 p: g.nodes) { 120 | for (LineString e: p.edges) { 121 | if (e.p1 != e.p2) { 122 | if (e.p2 == p) continue; 123 | } 124 | P2 prev = null; 125 | for (P2 p2: e.points) { 126 | if (prev != null) { 127 | int x1 = (int)(M + (p2.x - xmin) * scale); 128 | int y1 = (int)(yBase + (p2.y - ymin)* scale * ySign); 129 | int x2 = (int)(M + (prev.x - xmin) * scale); 130 | int y2 = (int)(yBase + (prev.y - ymin) * scale * ySign); 131 | g2d.drawLine(x1, y1, x2, y2); 132 | } 133 | prev = p2; 134 | } 135 | } 136 | } 137 | } 138 | 139 | private Map load(String path) { 140 | Map ret = new HashMap<>(); 141 | String line = null; 142 | int lineNo = 0; 143 | try { 144 | LineNumberReader lnr = new LineNumberReader(new FileReader(path)); 145 | while (true) { 146 | line = lnr.readLine(); 147 | lineNo++; 148 | if (line == null) break; 149 | line = line.trim(); 150 | if (line.isEmpty() || line.startsWith("#") || 151 | line.toLowerCase().startsWith("imageid")) continue; 152 | // ImageId,LineString_Pix 153 | // AOI_5_Khartoum_img1,LINESTRING (250 250, 250 350, 1050 350) 154 | 155 | int pos = line.indexOf(","); 156 | String imageId = line.substring(0, pos); 157 | if (!idList.contains(imageId)) idList.add(imageId); 158 | RoadSet g = ret.get(imageId); 159 | if (g == null) { 160 | g = new RoadSet(); 161 | ret.put(imageId, g); 162 | } 163 | String roadDef = line.substring(pos + 1); 164 | LineString road = LineString.fromText(roadDef); 165 | if (road == null) { 166 | log("Error reading roads"); 167 | log("Line #" + lineNo + ": " + line); 168 | System.exit(1); 169 | } 170 | g.roads.add(road); 171 | } 172 | lnr.close(); 173 | } 174 | catch (Exception e) { 175 | log("Error reading roads"); 176 | log("Line #" + lineNo + ": " + line); 177 | e.printStackTrace(); 178 | System.exit(1); 179 | } 180 | 181 | for (String id: ret.keySet()) { 182 | String err = ret.get(id).getError(); 183 | if (err != null) { 184 | System.out.println("Error with : " + id + " : " + err); 185 | System.exit(1); 186 | } 187 | } 188 | return ret; 189 | } 190 | 191 | private void log(String s) { 192 | System.out.println(s); 193 | } 194 | 195 | public static void main(String[] args) { 196 | new BatchMetricTest().run(); 197 | } 198 | } 199 | 
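
For reference, the load() method above accepts truth and proposal files in the same layout: an optional header, then one WKT LINESTRING per road, keyed by the ImageId before the first comma, with several rows per image and LINESTRING EMPTY marking an image with no roads. A minimal sketch of that layout (the image ids and coordinates here are illustrative only):

ImageId,WKT_Pix
AOI_5_Khartoum_img1,"LINESTRING (250 250, 250 350, 1050 350)"
AOI_5_Khartoum_img1,"LINESTRING (250 350, 250 450)"
AOI_5_Khartoum_img2,LINESTRING EMPTY
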
-------------------------------------------------------------------------------- /APLS-Visualizer/apls_visualizer-dg/src/test/LinestringToGeojson.java: -------------------------------------------------------------------------------- 1 | package test; 2 | 3 | import java.io.File; 4 | import java.io.FileOutputStream; 5 | 6 | /* 7 | * Standalone tool to create truth/proposal geojson files from LineStrings. 8 | * Needed for testing the Python-based scorer. 9 | */ 10 | public class LinestringToGeojson { 11 | private static String fileName1 = "truth.geojson"; 12 | private static String fileName2 = "solution.geojson"; 13 | private static int id1 = 0; 14 | private static int id2 = 500; 15 | 16 | private static String[] ls1Arr = new String[] { 17 | "3_edges_connected-move_node_outside_buffer,LINESTRING (500 500, 600 500, 600 700, 550 700)" 18 | 19 | }; 20 | 21 | private static String[] ls2Arr = new String[] { 22 | "3_edges_connected-move_node_outside_buffer,LINESTRING (500 500, 600 500, 600 700, 530 700)" 23 | }; 24 | 25 | String header = "{\"type\": \"FeatureCollection\",\"crs\": { \"type\": \"name\", \"properties\": { \"name\": \"urn:ogc:def:crs:OGC:1.3:CRS84\" } },\"features\": [\n"; 26 | String footer = "]}"; 27 | String template = "{\"type\": \"Feature\", \"properties\": { \"osm_id\": xxxid, \"type\": \"residential\", \"class\": \"highway\" }, \"geometry\": { \"type\": \"LineString\", \"coordinates\": [ xxxcoord ] } },\n"; 28 | 29 | 30 | private void out(String[] arr, String fileName, int id) throws Exception { 31 | double scale = 2.7e-6; // This is used in spacenet data. ~ 1 / (111000 / 0.3); 32 | StringBuilder sb = new StringBuilder(); 33 | sb.append(header); 34 | for (String line: arr) { 35 | String lineOut = template; 36 | lineOut = lineOut.replace("xxxid", "" + id++); 37 | String coords = ""; 38 | String ls = line.substring(line.indexOf("(") + 1); 39 | ls = ls.replace(")", ""); 40 | String[] parts = ls.split(","); 41 | for (String point: parts) { 42 | point = point.trim(); 43 | String[] xy = point.split(" "); 44 | double x = Double.parseDouble(xy[0].trim()); 45 | double y = Double.parseDouble(xy[1].trim()); 46 | x = x * scale + 1; 47 | y = y * scale + 1; 48 | String coord = "[" + x + ", " + y + ", 0], "; 49 | coords += coord; 50 | } 51 | coords = coords.substring(0, coords.length()-2); 52 | lineOut = lineOut.replace("xxxcoord", coords); 53 | sb.append(lineOut); 54 | } 55 | sb.append(footer); 56 | FileOutputStream out = new FileOutputStream(new File(fileName)); 57 | out.write(sb.toString().getBytes()); 58 | out.close(); 59 | } 60 | 61 | public static void main(String[] args) throws Exception { 62 | LinestringToGeojson l2g = new LinestringToGeojson(); 63 | l2g.out(ls1Arr, fileName1, id1); 64 | l2g.out(ls2Arr, fileName2, id2); 65 | } 66 | } 67 | -------------------------------------------------------------------------------- /APLS-Visualizer/apls_visualizer-dg/src/test/MetricsSpeedTest.java: -------------------------------------------------------------------------------- 1 | package test; 2 | 3 | import static visualizer.Utils.f6; 4 | 5 | import geom.Graph; 6 | import geom.LineString; 7 | import geom.Metrics; 8 | import geom.P2; 9 | import geom.RoadSet; 10 | 11 | public class MetricsSpeedTest { 12 | 13 | public static void main(String[] args) { 14 | double d = 5 / 0.31; 15 | for (int n = 5; n < 40; n++) { 16 | double[] poss = new double[n]; 17 | for (int i = 0; i < n; i++) poss[i] = (2*i+1) * d; 18 | 19 | RoadSet rs1 = new RoadSet(); 20 | for (int i = 0; i < n; i++) { 21 | LineString e = new 
LineString(); 22 | for (int j = 0; j < n; j++) { 23 | P2 p = new P2(poss[i], poss[j]); 24 | e.points.add(p); 25 | } 26 | rs1.roads.add(e); 27 | } 28 | for (int i = 0; i < n; i++) { 29 | LineString e = new LineString(); 30 | for (int j = 0; j < n; j++) { 31 | P2 p = new P2(poss[j], poss[i]); 32 | e.points.add(p); 33 | } 34 | rs1.roads.add(e); 35 | } 36 | Graph g1 = Graph.fromRoads(rs1); 37 | 38 | RoadSet rs2 = new RoadSet(); 39 | rs2.roads.add(LineString.fromText("LINESTRING (10 10, 30 10, 50 10, 50 30, 30 40, 30 10)")); 40 | Graph g2 = Graph.fromRoads(rs2); 41 | 42 | long startTime = System.currentTimeMillis(); 43 | Metrics.debug = false; 44 | Metrics.draw = false; 45 | double[] scores = Metrics.score(g1, g2); 46 | long time = System.currentTimeMillis() - startTime; 47 | System.out.println(n + ": " + f6(scores[0]) + ", " + f6(scores[1]) + " : " + f6(scores[2]) + " : " + time); 48 | } 49 | } 50 | } 51 | -------------------------------------------------------------------------------- /APLS-Visualizer/apls_visualizer-dg/src/test/MetricsTest.java: -------------------------------------------------------------------------------- 1 | package test; 2 | 3 | import geom.Graph; 4 | import geom.Metrics; 5 | import geom.RoadSet; 6 | import visualizer.Utils; 7 | 8 | public class MetricsTest { 9 | 10 | // test only 11 | public static void main(String[] args) { 12 | RoadSet rs1 = RoadSet.fromText("(10.,10.)(20.,10.)(20.,10.)(20.,20.)(20.,20.)(40.,20.)(40.,20.)(40.,15.)(40.,15.)(40.,10.)(40.,10.)(25.,10.)(25.,10.)(25.,20.)(25.,15.)(40.,15.)(25.,12.)(40.,12.)(32.,12.)(32.,20.)(20.,22.)(47.,22.)(55.,22.)(55.,20.)(55.,20.)(40.,20.)(47.,24.)(47.,29.)(47.,24.)(27.,24.)(27.,24.)(27.,22.)"); 13 | RoadSet rs2 = RoadSet.fromText("(10.,10.)(20.,10.)(20.,10.)(20.,20.)(20.,20.)(40.,20.)(40.,20.)(40.,15.)(40.,15.)(40.,10.)(40.,10.)(25.,10.)(25.,10.)(25.,20.)(25.,15.)(40.,15.)(25.,12.)(40.,12.)(32.,12.)(32.,20.)(20.,22.)(47.,22.)(55.,22.)(55.,20.)(55.,20.)(40.,20.)(47.,24.)(47.,29.)(47.,24.)(27.,24.)(27.,24.)(27.,22.)"); 14 | String err; 15 | err = rs1.getError(); 16 | if (err != null) { 17 | System.out.println(err); 18 | System.exit(0); 19 | } 20 | err = rs2.getError(); 21 | if (err != null) { 22 | System.out.println(err); 23 | System.exit(0); 24 | } 25 | 26 | Graph g1 = Graph.fromRoads(rs1); 27 | Graph g2 = Graph.fromRoads(rs2); 28 | double[] scores = Metrics.score(g1, g2); 29 | System.out.println(Utils.f(scores[0]) + ", " + Utils.f(scores[1]) + " : " + Utils.f(scores[2])); 30 | 31 | } 32 | } 33 | -------------------------------------------------------------------------------- /APLS-Visualizer/apls_visualizer-dg/src/visualizer/Utils.java: -------------------------------------------------------------------------------- 1 | package visualizer; 2 | 3 | import java.io.FileInputStream; 4 | import java.io.InputStream; 5 | import java.io.InputStreamReader; 6 | import java.io.LineNumberReader; 7 | import java.text.DecimalFormat; 8 | import java.text.DecimalFormatSymbols; 9 | import java.util.List; 10 | import java.util.Vector; 11 | 12 | public class Utils { 13 | 14 | private static DecimalFormat df; 15 | private static DecimalFormat df6; 16 | static { 17 | df = new DecimalFormat("0.###"); 18 | df6 = new DecimalFormat("0.######"); 19 | DecimalFormatSymbols dfs = new DecimalFormatSymbols(); 20 | dfs.setDecimalSeparator('.'); 21 | df.setDecimalFormatSymbols(dfs); 22 | df6.setDecimalFormatSymbols(dfs); 23 | } 24 | 25 | /** 26 | * Pretty print a double 27 | */ 28 | public static String f(double d) { 29 | return df.format(d); 30 | 
} 31 | public static String f6(double d) { 32 | return df6.format(d); 33 | } 34 | 35 | // Gets the lines of a text file at the given path 36 | public static List<String> readTextLines(String path) { 37 | List<String> ret = new Vector<>(); 38 | try { 39 | InputStream is = new FileInputStream(path); 40 | InputStreamReader isr = new InputStreamReader(is, "UTF-8"); 41 | LineNumberReader lnr = new LineNumberReader(isr); 42 | while (true) { 43 | String line = lnr.readLine(); 44 | if (line == null) break; 45 | line = line.trim(); 46 | if (line.isEmpty() || line.startsWith("#")) continue; 47 | ret.add(line); 48 | } 49 | lnr.close(); 50 | } 51 | catch (Exception e) { 52 | e.printStackTrace(); 53 | } 54 | return ret; 55 | } 56 | } 57 | 58 | 59 | -------------------------------------------------------------------------------- /APLS-Visualizer/apls_visualizer-dg/visualizer.jar: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/anilbatra2185/road_connectivity/01e41662a43d9e289926ebd58eff4f6e14359ae4/APLS-Visualizer/apls_visualizer-dg/visualizer.jar -------------------------------------------------------------------------------- /APLS-Visualizer/apls_visualizer-dg/visualizer_lib/imageio-ext-geocore-1.1.16.jar: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/anilbatra2185/road_connectivity/01e41662a43d9e289926ebd58eff4f6e14359ae4/APLS-Visualizer/apls_visualizer-dg/visualizer_lib/imageio-ext-geocore-1.1.16.jar -------------------------------------------------------------------------------- /APLS-Visualizer/apls_visualizer-dg/visualizer_lib/imageio-ext-streams-1.1.16.jar: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/anilbatra2185/road_connectivity/01e41662a43d9e289926ebd58eff4f6e14359ae4/APLS-Visualizer/apls_visualizer-dg/visualizer_lib/imageio-ext-streams-1.1.16.jar -------------------------------------------------------------------------------- /APLS-Visualizer/apls_visualizer-dg/visualizer_lib/imageio-ext-tiff-1.1.16.jar: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/anilbatra2185/road_connectivity/01e41662a43d9e289926ebd58eff4f6e14359ae4/APLS-Visualizer/apls_visualizer-dg/visualizer_lib/imageio-ext-tiff-1.1.16.jar -------------------------------------------------------------------------------- /APLS-Visualizer/apls_visualizer-dg/visualizer_lib/imageio-ext-utilities-1.1.16.jar: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/anilbatra2185/road_connectivity/01e41662a43d9e289926ebd58eff4f6e14359ae4/APLS-Visualizer/apls_visualizer-dg/visualizer_lib/imageio-ext-utilities-1.1.16.jar -------------------------------------------------------------------------------- /APLS-Visualizer/apls_visualizer-dg/visualizer_lib/jai_codec-1.1.3.jar: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/anilbatra2185/road_connectivity/01e41662a43d9e289926ebd58eff4f6e14359ae4/APLS-Visualizer/apls_visualizer-dg/visualizer_lib/jai_codec-1.1.3.jar -------------------------------------------------------------------------------- /APLS-Visualizer/apls_visualizer-dg/visualizer_lib/jai_core-1.1.3.jar: --------------------------------------------------------------------------------
https://raw.githubusercontent.com/anilbatra2185/road_connectivity/01e41662a43d9e289926ebd58eff4f6e14359ae4/APLS-Visualizer/apls_visualizer-dg/visualizer_lib/jai_core-1.1.3.jar -------------------------------------------------------------------------------- /APLS-Visualizer/apls_visualizer-dg/visualizer_lib/jai_imageio-1.1.jar: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/anilbatra2185/road_connectivity/01e41662a43d9e289926ebd58eff4f6e14359ae4/APLS-Visualizer/apls_visualizer-dg/visualizer_lib/jai_imageio-1.1.jar -------------------------------------------------------------------------------- /APLS-Visualizer/apls_visualizer-dg/visualizer_readme.html: -------------------------------------------------------------------------------- 1 |

Visualizer for the Road Detector challenge

2 | 3 | The purpose of the visualizer application is to let you view grayscale, 3-band and 8-band images, view ground truth road networks and your solution's road networks as overlays on these images, compare truth to solution and calculate your solution's score.
4 | 5 | Open a command window in the directory where you unzipped the visualizer-1.0.zip package and execute 6 |
 7 | java -jar visualizer.jar 
 8 |      -truth <truth_file_list> 
 9 |      -solution <solution_file> 
10 |      -image-dir <image_directory_list>
11 |      -band-triplets <band_definition_file>
12 | 
13 | (The above is a single line command, line breaks are only for readability.)

14 | 15 | This assumes that you have Java (at least v1.7) installed and available on your path. The meanings of the above parameters are the following: 16 |
    17 |
  • -truth (optional) specifies the location of the truth files. It is a semicolon (';')-separated list of files that should be loaded.
  • 18 |
  • -solution (optional) is your solution file; see ./data/solution-sample.csv for an example.
  • 19 |
  • -image-dir is a semicolon (';')-separated list of directories that should be loaded. Each of these is assumed to contain the PAN, RGB-PanSharpen, etc. subfolders as described in the problem statement.
  • 20 |
  • -band-triplets points to a file that defines the band index triplets used to generate RGB images from the 8-band imagery. See ./data/band-triplets.txt; it describes the required syntax of band triplet definitions. A minimal Python illustration of band-triplet compositing follows the notes below.
  • 21 |
22 | Note that if you use multiple path elements separated by a semicolon in the -truth or -image-dir parameters then it may be necessary to enclose the parameter value in quotes.
23 | All file and directory parameters can be relative or absolute paths. The -truth and -solution parameters are optional; the tool is able to run without them.
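
For intuition about what the -band-triplets file drives: each triplet picks three of the eight MUL bands and maps them to the R, G, and B channels of a preview image. The sketch below is a minimal Python illustration of that idea, assuming a (H, W, 8) numpy array and a simple percentile stretch; it is not the visualizer's exact conversion (see the Colour scaling section below), and the example triplet (5, 3, 2) is an illustrative assumption, not one read from band-triplets.txt.

```python
import numpy as np

def bands_to_rgb(img8band, triplet, lo_pct=2, hi_pct=98):
    """Compose an 8-bit RGB preview from 8-band imagery.

    img8band: ndarray of shape (H, W, 8) with raw multispectral values.
    triplet:  three 1-based band indices, e.g. (5, 3, 2).
    """
    rgb = np.zeros(img8band.shape[:2] + (3,), dtype=np.uint8)
    for channel, band in enumerate(triplet):
        plane = img8band[:, :, band - 1].astype(np.float64)
        lo, hi = np.percentile(plane, [lo_pct, hi_pct])   # robust value range
        scaled = (plane - lo) / max(hi - lo, 1e-9)        # stretch to [0, 1]
        rgb[:, :, channel] = (np.clip(scaled, 0.0, 1.0) * 255).astype(np.uint8)
    return rgb
```
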

24 | 25 | An alternative way of specifying the parameters is via a parameters file; see params.txt for an example. The file contains the description of the required syntax. In this case the command should have exactly two parameters:
26 |
27 | java -jar visualizer.jar -params <params_file>
28 | 
29 | where <params_file> is an absolute or relative path to a parameters file. 30 | 31 | For example, a command line that will run the app using the spacenet sample data: 32 |
33 | java -jar visualizer.jar -params ./data/params.txt
34 | 
35 | This assumes that you have already downloaded the sample data from the spacenet-dataset AWS bucket (see the problem statement for details) and extracted it, so the directory structure is something like this: 36 |
37 | data/
38 |     SpaceNet_Roads_Sample/
39 |         AOI_2_Vegas_Roads_Sample/
40 |             geojson/
41 |             MUL/
42 |             ...
43 |         AOI_3_Paris_Roads_Sample/
44 |             ...
45 |         AOI_4_Shanghai_Roads_Sample/
46 |             ...
47 |         AOI_5_Khartoum_Roads_Sample/
48 |             ...
49 |     band-triplets.txt
50 |     params.txt
51 |     solution-example.csv
52 | visualizer_lib/
53 |     *.jar  
54 | visualizer.jar
55 | 
56 | 
57 | Modify the params.txt file if your paths look different. 58 |

59 | 60 | There are some other optional command line parameters you can use (either directly in the command line or in the parameters file): 61 |
    62 |
  • -w <width> : Width of the tool's screen. Defaults to 1500.
  • 63 |
  • -no-gui: if present then no GUI will be shown; the application just scores the supplied solution file in command-line mode.
  • 64 |
  • -truth-color <r,g,b,a> : with this you can customize the colour of the ground truth roads. The parameter should be 4 integers separated by commas, no spaces in between. E.g. to set it to semi-transparent blue you can use: -truth-color 0,0,255,128
  • 65 |
  • -solution-color <r,g,b,a> : the colour of your solution's road network.
  • 66 |
67 | All these have proper defaults so you can leave them out.
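
Because -no-gui turns the tool into a command-line scorer, it is easy to drive from a script. Below is a hedged Python sketch of batch scoring; all paths are placeholders, and only the flags documented above are used. Passing the arguments as a list avoids the shell, so the semicolon-separated values need no extra quoting.

```python
import subprocess

# Placeholder paths -- substitute your own truth/solution/imagery locations.
cmd = [
    "java", "-jar", "visualizer.jar",
    "-truth", "truth/AOI_2_Vegas.csv;truth/AOI_3_Paris.csv",  # ';'-separated list
    "-solution", "solution.csv",
    "-image-dir", "data/AOI_2_Vegas;data/AOI_3_Paris",        # ';'-separated list
    "-band-triplets", "data/band-triplets.txt",
    "-no-gui",  # score in command-line mode, no window
]
result = subprocess.run(cmd, capture_output=True, text=True)
print(result.stdout)  # the scores are printed to standard output
```
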
68 | 69 |

Operations

70 | Usage of the tool should be straightforward. Select the view type from the top drop-down list: 'RGB Pan-sharpened', 'PAN grayscale' or one of the predefined band triplet combinations (these are generated from the contents of the MUL folder). Select the image to be displayed from the bottom drop-down list. You can also switch to another image by clicking the line containing an image name in the output log window.
71 | 72 | If both truth and solution files are specified then solution and truth are compared automatically, scores are displayed in the log window and also in the command line. Note that the visualizer does not enforce some of the constraints mentioned in the problem statement; for example, it is not necessary to give a prediction for all images in your solution file.
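
The displayed score is the APLS-style comparison computed by Metrics.score(). As a rough illustration of the underlying idea only -- not the tool's actual implementation -- the Python sketch below compares shortest-path lengths between pre-matched node pairs of two networkx graphs; real APLS also snaps control points from one graph onto the other.

```python
import networkx as nx

def apls_like(g_truth, g_prop, matched_pairs):
    """Toy path-length similarity over pre-matched node pairs.

    matched_pairs: iterable of ((a1, b1), (a2, b2)) where (a1, b1) are
    truth-graph nodes and (a2, b2) the corresponding proposal nodes.
    """
    penalties = []
    for (a1, b1), (a2, b2) in matched_pairs:
        try:
            lt = nx.shortest_path_length(g_truth, a1, b1, weight="weight")
            if lt == 0:
                continue  # skip degenerate identical-node pairs
            lp = nx.shortest_path_length(g_prop, a2, b2, weight="weight")
            penalties.append(min(1.0, abs(lt - lp) / lt))
        except (nx.NetworkXNoPath, nx.NodeNotFound):
            penalties.append(1.0)  # missing connectivity gets the maximum penalty
    return 1.0 - sum(penalties) / len(penalties) if penalties else 0.0
```
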
73 | 74 | You can zoom in/out within the image view with the mouse wheel, and pan the view by dragging.
75 | 76 |

Colour scaling

77 | The dataset contains 16-bit grayscale and 48-bit colour images, which the tool converts to 8-bit or 24-bit images so that they can be displayed. A rather simple algorithm is used for colour conversion; see the loadMap() method of the Visualizer class. Note that for machine learning you may do this step differently or may not do this step at all. 78 | 79 |
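
For reference, here is a minimal numpy sketch of this kind of bit-depth conversion, assuming a plain per-band min/max stretch; loadMap() may clip or scale differently.

```python
import numpy as np

def to_8bit(band16):
    """Linearly stretch one 16-bit band to 8 bits for display."""
    band = band16.astype(np.float64)
    lo, hi = band.min(), band.max()
    if hi <= lo:
        return np.zeros(band.shape, dtype=np.uint8)  # constant band: all zeros
    return ((band - lo) / (hi - lo) * 255.0).astype(np.uint8)
```
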

Licenses

80 | The visualizer tool uses the imageio-ext library for reading multiband TIFF files. The imageio-ext library is LGPL-licensed; see here for its license text. See here for details on the library.
81 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 Anil Batra 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ## Improved Road Connectivity by Joint Learning of Orientation and Segmentation ## 2 | #### In CVPR 2019 [[pdf]](https://anilbatra2185.github.io/papers/RoadConnectivityCVPR2019.pdf) [[supplementary]](https://anilbatra2185.github.io/papers/RoadConnectivity_CVPR_Supplementary.pdf) 3 | 4 | ## Overview 5 | 6 | 7 | ## Requirements 8 | * [PyTorch](https://pytorch.org/) (version = 0.3.0) 9 | * [sknw](https://github.com/yxdragon/sknw) 10 | * [networkx](https://networkx.github.io/) (version = 2.4) 11 | * json 12 | * skimage 13 | * numpy 14 | * tqdm 15 | 16 | ## Data Preparation 17 | 18 | #### PreProcess Spacenet Data 19 | - Convert Spacenet 11-bit images to 8-bit Images, country wise. 20 | - Create Gaussian Road Masks, country wise. 21 | - Move all data to single folder. 22 | 23 | *Default Spacenet3 tree structure assumed.* 24 | ``` 25 | spacenet3 26 | │ 27 | └───AOI_2_Vegas_Train 28 | │ └───RGB-PanSharpen 29 | │ └───geojson 30 | │ └───summaryData 31 | │ 32 | └───AOI_3_Paris_Train 33 | │ └───RGB-PanSharpen 34 | │ └───geojson 35 | │ └───summaryData 36 | | 37 | └───AOI_4_Shanghai_Train 38 | | . 39 | | 40 | └───AOI_5_Khartoum_Train 41 | | . 
42 | | 43 | ``` 44 | 45 | ``` 46 | cd preprocessing 47 | bash prepare_spacenet.sh /spacenet3 48 | ``` 49 | #### Split Datasets 50 | *Spacenet tree structure created by preprocessing.* 51 | ``` 52 | spacenet3 53 | | 54 | └───full 55 | │ └───gt 56 | │ └───images 57 | ``` 58 | 59 | *Download DeepGlobe Road dataset in the following tree structure.* 60 | ``` 61 | deepglobe 62 | │ 63 | └───train 64 | │ └───gt 65 | │ └───images 66 | ``` 67 | *Script to split and save in **'/data/spacenet'** and **'/data/deepglobe'**.* 68 | ``` 69 | bash split_data.sh /spacenet3/full /data/spacenet/ .png .png 70 | bash split_data.sh /deepglobe/train /data/deepglobe _sat.jpg _mask.png 71 | ``` 72 | #### Create Crops 73 | 74 | ``` 75 | data/spacenet 76 | | train.txt 77 | | val.txt 78 | | train_crops.txt # created by script 79 | | val_crops.txt # created by script 80 | | 81 | └───train 82 | │ └───gt 83 | │ └───images 84 | └───val 85 | │ └───gt 86 | │ └───images 87 | └───train_crops # created by script 88 | │ └───gt 89 | │ └───images 90 | └───val_crops # created by script 91 | │ └───gt 92 | │ └───images 93 | ``` 94 | ``` 95 | python create_crops.py --base_dir /data/spacenet/ --crop_size 650 --crop_overlap 215 --im_suffix .png --gt_suffix .png 96 | python create_crops.py --base_dir /data/deepglobe/ --crop_size 512 --crop_overlap 256 --im_suffix _sat.jpg --gt_suffix _mask.png 97 | ``` 98 | ## Visualize Data 99 | * Road Orientation - [Notebook](https://github.com/anilbatra2185/road_connectivity/blob/master/visualize_tasks.ipynb) 100 | * Training Dataset - [Notebook](https://github.com/anilbatra2185/road_connectivity/blob/master/visualize_dataset.ipynb) 101 | * Linear Corruption (Connectivity Refinement) - [Notebook](https://github.com/anilbatra2185/road_connectivity/blob/master/visualize_dataset_corrupt.ipynb) 102 | 103 | ## Training 104 | 105 | Train Multi-Task learning framework to predict road segmentation and road orientation. 106 | 107 | __Training MTL Help__ 108 | ``` 109 | usage: train_mtl.py [-h] --config CONFIG 110 | --model_name {LinkNet34MTL,StackHourglassNetMTL} 111 | --dataset {deepglobe,spacenet} 112 | --exp EXP 113 | [--resume RESUME] 114 | [--model_kwargs MODEL_KWARGS] 115 | [--multi_scale_pred MULTI_SCALE_PRED] 116 | 117 | optional arguments: 118 | -h, --help show this help message and exit 119 | --config CONFIG config file path 120 | --model_name {LinkNet34MTL,StackHourglassNetMTL} 121 | Name of Model = ['StackHourglassNetMTL', 122 | 'LinkNet34MTL'] 123 | --exp EXP Experiment Name/Directory 124 | --resume RESUME path to latest checkpoint (default: None) 125 | --dataset {deepglobe,spacenet} 126 | select dataset name from ['deepglobe', 'spacenet']. 127 | (default: Spacenet) 128 | --model_kwargs MODEL_KWARGS 129 | parameters for the model 130 | --multi_scale_pred MULTI_SCALE_PRED 131 | perform multi-scale prediction (default: True) 132 | ``` 133 | 134 | __Sample Usage__ 135 | 136 | * Training with StackModule 137 | ``` 138 | CUDA_VISIBLE_DEVICES=0,1 python train_mtl.py --config config.json --dataset deepglobe --model_name "StackHourglassNetMTL" --exp dg_stak_mtl 139 | ``` 140 | * Training with LinkNet34 141 | ``` 142 | CUDA_VISIBLE_DEVICES=0,1 python train_mtl.py --config config.json --dataset deepglobe --model_name "LinkNet34MTL" --exp dg_L34_mtl --multi_scale_pred false 143 | ``` 144 | 145 | ## Evaluate APLS 146 | 147 | * Please use Java implementation to compute APLS provided by Spacenet Challenge. 
- [Visualizer tool](https://drive.google.com/file/d/1rwbj_o-ELBfruPZuVkCnEQxAX2-Pz5DX/view) 148 | * For more info, refer to issue [#13](https://github.com/anilbatra2185/road_connectivity/issues/13) 149 | 150 | 151 | ## Connectivity Refinement 152 | 153 | * Training with Linear Artifacts/Corruption (using LinkNet34 Architecture) 154 | ``` 155 | CUDA_VISIBLE_DEVICES=0,1 python train_refine_pre.py --config config.json --dataset spacenet --model_name "LinkNet34" --exp spacenet_L34_pre_train_with_corruption --multi_scale_pred false 156 | ``` 157 | 158 | ## Citation 159 | If you find our work useful in your research, please cite: 160 | 161 | @InProceedings{Batra_2019_CVPR, 162 | author = {Batra, Anil and Singh, Suriya and Pang, Guan and Basu, Saikat and Jawahar, C.V. and Paluri, Manohar}, 163 | title = {Improved Road Connectivity by Joint Learning of Orientation and Segmentation}, 164 | booktitle = {The IEEE Conference on Computer Vision and Pattern Recognition (CVPR)}, 165 | month = {June}, 166 | year = {2019} 167 | } 168 | 169 | ## Remaining Tasks 170 | - [x] Dataset for Connectivity Refinement 171 | - [ ] Training file for Road connectivity refinement 172 | - [ ] Dataset for Junction Learning 173 | -------------------------------------------------------------------------------- /assests/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/anilbatra2185/road_connectivity/01e41662a43d9e289926ebd58eff4f6e14359ae4/assests/.DS_Store -------------------------------------------------------------------------------- /assests/images/AOI_2_Vegas_img33.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/anilbatra2185/road_connectivity/01e41662a43d9e289926ebd58eff4f6e14359ae4/assests/images/AOI_2_Vegas_img33.png -------------------------------------------------------------------------------- /assests/images/mask_AOI_2_Vegas_img33.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/anilbatra2185/road_connectivity/01e41662a43d9e289926ebd58eff4f6e14359ae4/assests/images/mask_AOI_2_Vegas_img33.png -------------------------------------------------------------------------------- /assests/images/overview.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/anilbatra2185/road_connectivity/01e41662a43d9e289926ebd58eff4f6e14359ae4/assests/images/overview.png -------------------------------------------------------------------------------- /config.json: -------------------------------------------------------------------------------- 1 | { 2 | "seed": 7, 3 | "task1_classes": 2, 4 | "task2_classes": 37, 5 | "task1_weight": 1, 6 | "task2_weight": 1, 7 | "train_batch_size": 16, 8 | "val_batch_size": 4, 9 | "refinement":3, 10 | "train_dataset": { 11 | "spacenet":{ 12 | "dir": "/data/spacenet/train_crops/", 13 | "file": "/data/spacenet/train_crops.txt", 14 | "image_suffix": ".png", 15 | "gt_suffix": ".png", 16 | "crop_size": 256 17 | }, 18 | "deepglobe":{ 19 | "dir": "/data/deepglobe/train_crops/", 20 | "file": "/data/deepglobe/train_crops.txt", 21 | "image_suffix": ".jpg", 22 | "gt_suffix": ".png", 23 | "crop_size": 256 24 | }, 25 | "crop_size": 256, 26 | "augmentation": true, 27 | "mean" : "[70.95016901, 71.16398124, 71.30953645]", 28 | "std" : "[ 34.00087859, 35.18201658, 36.40463264]", 29 | "normalize_type": "Mean", 30 | "thresh": 0.76, 31 |
"angle_theta": 10, 32 | "angle_bin": 10 33 | }, 34 | "val_dataset": { 35 | "spacenet":{ 36 | "dir": "/data/spacenet/val_crops/", 37 | "file": "/data/spacenet/val_crops.txt", 38 | "image_suffix": ".png", 39 | "gt_suffix": ".png", 40 | "crop_size": 650 41 | }, 42 | "deepglobe":{ 43 | "dir": "/data/deepglobe/val_crops/", 44 | "file": "/data/deepglobe/val_crops.txt", 45 | "image_suffix": ".jpg", 46 | "gt_suffix": ".png", 47 | "crop_size": 512 48 | }, 49 | "crop_size": 512, 50 | "augmentation": false, 51 | "mean" : "[70.95016901, 71.16398124, 71.30953645]", 52 | "std" : "[ 34.00087859, 35.18201658, 36.40463264]", 53 | "normalize_type": "Mean", 54 | "thresh": 0.76, 55 | "angle_theta": 10, 56 | "angle_bin": 10 57 | }, 58 | "optimizer": { 59 | "lr": 0.01, 60 | "d_lr": 0.0001, 61 | "lr_step": 0.1, 62 | "lr_drop_epoch": "[60,90,110]" 63 | }, 64 | "trainer": { 65 | "total_epochs": 120, 66 | "save_dir": "/ssd_scratch/cvit/anil.k/exp/deepglobe100/", 67 | "iter_size": 1, 68 | "test_freq": 1 69 | } 70 | } 71 | -------------------------------------------------------------------------------- /create_crops.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python2 2 | 3 | """ 4 | create_crops.py: script to crops for training and validation. 5 | It will save the crop images and mask in the format: 6 | __. 7 | where x = [0, (Image_Height - crop_size) / stride] 8 | y = [0, (Image_Width - crop_size) / stride] 9 | 10 | It will create following directory structure: 11 | base_dir 12 | | train_crops.txt # created by script 13 | | val_crops.txt # created by script 14 | | 15 | └───train_crops # created by script 16 | │ └───gt 17 | │ └───images 18 | └───val_crops # created by script 19 | │ └───gt 20 | │ └───images 21 | """ 22 | 23 | from __future__ import print_function 24 | 25 | import argparse 26 | import os 27 | import mmap 28 | import cv2 29 | import time 30 | import numpy as np 31 | from skimage import io 32 | from tqdm import tqdm 33 | tqdm.monitor_interval = 0 34 | 35 | 36 | 37 | def verify_image(img_file): 38 | try: 39 | img = io.imread(img_file) 40 | except: 41 | return False 42 | return True 43 | 44 | def CreatCrops(base_dir, crop_type, size, stride, image_suffix, gt_suffix): 45 | 46 | crops = os.path.join(base_dir, '{}_crops'.format(crop_type)) 47 | crops_file = open(os.path.join(base_dir,'{}_crops.txt'.format(crop_type)),'w') 48 | 49 | full_file_path = os.path.join(base_dir,'{}.txt'.format(crop_type)) 50 | full_file = open(full_file_path,'r') 51 | 52 | def get_num_lines(file_path): 53 | fp = open(file_path, "r+") 54 | buf = mmap.mmap(fp.fileno(), 0) 55 | lines = 0 56 | while buf.readline(): 57 | lines += 1 58 | return lines 59 | 60 | failure_images = [] 61 | for name in tqdm(full_file, ncols=100, desc="{}_crops".format(crop_type), 62 | total=get_num_lines(full_file_path)): 63 | 64 | name = name.strip("\n") 65 | image_file = os.path.join(base_dir,'{}/images'.format(crop_type),name+image_suffix) 66 | gt_file = os.path.join(base_dir,'{}/gt'.format(crop_type),name+gt_suffix) 67 | 68 | if not verify_image(image_file): 69 | failure_images.append(image_file) 70 | continue 71 | 72 | image = cv2.imread(image_file) 73 | gt = cv2.imread(gt_file,0) 74 | 75 | if image is None: 76 | failure_images.append(image_file) 77 | continue 78 | 79 | if gt is None: 80 | failure_images.append(image_file) 81 | continue 82 | 83 | H,W,C = image.shape 84 | maxx = (H-size)/stride 85 | maxy = (W-size)/stride 86 | 87 | for x in range(maxx+1): 88 | for y in range(maxy+1): 89 | im_ = 
image[x*stride:x*stride + size,y*stride:y*stride + size,:] 90 | gt_ = gt[x*stride:x*stride + size,y*stride:y*stride + size] 91 | crops_file.write('{}_{}_{}\n'.format(name,x,y)) 92 | cv2.imwrite(crops+'/images/{}_{}_{}.png'.format(name,x,y), im_) 93 | cv2.imwrite(crops+'/gt/{}_{}_{}.png'.format(name,x,y), gt_) 94 | 95 | crops_file.close() 96 | full_file.close() 97 | if len(failure_images) > 0: 98 | print("Unable to process {} images : {}".format(len(failure_images), failure_images)) 99 | 100 | 101 | def main(): 102 | parser = argparse.ArgumentParser() 103 | parser.add_argument('-d', '--base_dir', type=str, required=True, 104 | help='Base directory for Spacenet Dataset.') 105 | parser.add_argument('--crop_size', type=int, required=True, 106 | help='Crop Size of Image') 107 | parser.add_argument('--crop_overlap', type=int, required=True, 108 | help='Crop overlap Size of Image') 109 | parser.add_argument('--im_suffix', type=str, required=True, 110 | help='Dataset specific image suffix.') 111 | parser.add_argument('--gt_suffix', type=str, required=True, 112 | help='Dataset specific gt suffix.') 113 | 114 | args = parser.parse_args() 115 | 116 | start = time.clock() 117 | ## Create overlapping Crops for training 118 | CreatCrops(args.base_dir, 119 | crop_type='train', 120 | size=args.crop_size, 121 | stride=args.crop_overlap, 122 | image_suffix=args.im_suffix, 123 | gt_suffix=args.gt_suffix) 124 | 125 | ## Create non-overlapping Crops for validation 126 | CreatCrops(args.base_dir, 127 | crop_type='val', 128 | size=args.crop_size, 129 | stride=args.crop_size, ## Non-overlapping 130 | image_suffix=args.im_suffix, 131 | gt_suffix=args.gt_suffix) 132 | 133 | end = time.clock() 134 | print('Finished Creating crops, time {0}s'.format(end - start)) 135 | 136 | if __name__ == "__main__": 137 | main() -------------------------------------------------------------------------------- /data_utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/anilbatra2185/road_connectivity/01e41662a43d9e289926ebd58eff4f6e14359ae4/data_utils/__init__.py -------------------------------------------------------------------------------- /data_utils/affinity_utils.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | 4 | import math 5 | import numpy as np 6 | from skimage.morphology import skeletonize 7 | import data_utils.graph_utils as graph_utils 8 | import data_utils.sknw as sknw 9 | 10 | 11 | def getKeypoints(mask, thresh=0.8, is_gaussian=True, is_skeleton=False, smooth_dist=4): 12 | """ 13 | Generate keypoints for a binary prediction mask. 
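    The pipeline is: threshold the (optionally gaussian/probability) mask,
    skeletonize it, build a graph with sknw, simplify the edges with the RDP
    algorithm, and return the resulting linestrings as ordered lists of points.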
14 | 15 | @param mask: Binary road probability mask 16 | @param thresh: Probability threshold used to convert the mask to binary 0/1 mask 17 | @param is_gaussian: Flag to check if the given mask is gaussian/probability mask 18 | from prediction 19 | @param is_skeleton: Flag to perform opencv skeletonization on the binarized 20 | road mask 21 | @param smooth_dist: Tolerance parameter used to smooth the graph using 22 | RDP algorithm 23 | 24 | @return: list of road keypoints 25 | """ 26 | 27 | if is_gaussian: 28 | mask /= 255.0 29 | mask[mask < thresh] = 0 30 | mask[mask >= thresh] = 1 31 | 32 | h, w = mask.shape 33 | if is_skeleton: 34 | ske = mask 35 | else: 36 | ske = skeletonize(mask).astype(np.uint16) 37 | graph = sknw.build_sknw(ske, multi=True) 38 | 39 | segments = graph_utils.simplify_graph(graph, smooth_dist) 40 | linestrings_1 = graph_utils.segmets_to_linestrings(segments) 41 | linestrings = graph_utils.unique(linestrings_1) 42 | 43 | keypoints = [] 44 | for line in linestrings: 45 | linestring = line.rstrip("\n").split("LINESTRING ")[-1] 46 | points_str = linestring.lstrip("(").rstrip(")").split(", ") 47 | ## If there is no road present 48 | if "EMPTY" in points_str: 49 | return keypoints 50 | points = [] 51 | for pt_st in points_str: 52 | x, y = pt_st.split(" ") 53 | x, y = float(x), float(y) 54 | points.append([x, y]) 55 | 56 | x1, y1 = points[0] 57 | x2, y2 = points[-1] 58 | zero_dist1 = math.sqrt((x1) ** 2 + (y1) ** 2) 59 | zero_dist2 = math.sqrt((x2) ** 2 + (y2) ** 2) 60 | 61 | if zero_dist2 > zero_dist1: 62 | keypoints.append(points[::-1]) 63 | else: 64 | keypoints.append(points) 65 | return keypoints 66 | 67 | 68 | def getVectorMapsAngles(shape, keypoints, theta=5, bin_size=10): 69 | """ 70 | Convert Road keypoints obtained from road mask to orientation angle mask. 71 | Reference: Section 3.1 72 | https://anilbatra2185.github.io/papers/RoadConnectivityCVPR2019.pdf 73 | 74 | @param shape: Road Label/PIL image shape i.e. H x W 75 | @param keypoints: road keypoints generated from Road mask using 76 | function getKeypoints() 77 | @param theta: thickness (width) for orientation vectors; it is similar to the 78 | thickness of the road width with which the mask is generated. 79 | @param bin_size: Bin size to quantize the Orientation angles. 80 | 81 | @return: ndarray of shape H x W x 2 with orientation vectors, and ndarray of shape H x W containing quantized orientation angles per pixel.
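    Example: an edge from keypoint (0, 0) to (10, 0) has unit direction (1, 0),
    so nearby pixels get angle 0 degrees (bin 0 for bin_size=10); an edge
    towards (0, 10) gives 90 degrees, i.e. bin 9.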
82 | """ 83 | 84 | im_h, im_w = shape 85 | vecmap = np.zeros((im_h, im_w, 2), dtype=np.float32) 86 | vecmap_angles = np.zeros((im_h, im_w), dtype=np.float32) 87 | vecmap_angles.fill(360) 88 | height, width, channel = vecmap.shape 89 | for j in range(len(keypoints)): 90 | for i in range(1, len(keypoints[j])): 91 | a = keypoints[j][i - 1] 92 | b = keypoints[j][i] 93 | ax, ay = a[0], a[1] 94 | bx, by = b[0], b[1] 95 | bax = bx - ax 96 | bay = by - ay 97 | norm = math.sqrt(1.0 * bax * bax + bay * bay) + 1e-9 98 | bax /= norm 99 | bay /= norm 100 | 101 | min_w = max(int(round(min(ax, bx) - theta)), 0) 102 | max_w = min(int(round(max(ax, bx) + theta)), width) 103 | min_h = max(int(round(min(ay, by) - theta)), 0) 104 | max_h = min(int(round(max(ay, by) + theta)), height) 105 | 106 | for h in range(min_h, max_h): 107 | for w in range(min_w, max_w): 108 | px = w - ax 109 | py = h - ay 110 | dis = abs(bax * py - bay * px) 111 | if dis <= theta: 112 | vecmap[h, w, 0] = bax 113 | vecmap[h, w, 1] = bay 114 | _theta = math.degrees(math.atan2(bay, bax)) 115 | vecmap_angles[h, w] = (_theta + 360) % 360 116 | 117 | vecmap_angles = (vecmap_angles / bin_size).astype(int) 118 | return vecmap, vecmap_angles 119 | 120 | 121 | def convertAngles2VecMap(shape, vecmapAngles): 122 | """ 123 | Helper method to convert Orientation angles mask to Orientation vectors. 124 | 125 | @params shape: Road mask shape i.e. H x W 126 | @params vecmapAngles: Orientation agles mask of shape H x W 127 | @param bin_size: Bin size to quantize the Orientation angles. 128 | 129 | @return: ndarray of shape H x W x 2, containing x and y values of vector 130 | """ 131 | 132 | h, w = shape 133 | vecmap = np.zeros((h, w, 2), dtype=np.float) 134 | 135 | for h1 in range(h): 136 | for w1 in range(w): 137 | angle = vecmapAngles[h1, w1] 138 | if angle < 36.0: 139 | angle *= 10.0 140 | if angle >= 180.0: 141 | angle -= 360.0 142 | vecmap[h1, w1, 0] = math.cos(math.radians(angle)) 143 | vecmap[h1, w1, 1] = math.sin(math.radians(angle)) 144 | 145 | return vecmap 146 | 147 | 148 | def convertVecMap2Angles(shape, vecmap, bin_size=10): 149 | """ 150 | Helper method to convert Orientation vectors to Orientation angles. 151 | 152 | @params shape: Road mask shape i.e. H x W 153 | @params vecmap: Orientation vectors of shape H x W x 2 154 | 155 | @return: ndarray of shape H x W, containing orientation angles per pixel. 
156 | """ 157 | 158 | im_h, im_w = shape 159 | angles = np.zeros((im_h, im_w), dtype=np.float) 160 | angles.fill(360) 161 | 162 | for h in range(im_h): 163 | for w in range(im_w): 164 | x = vecmap[h, w, 0] 165 | y = vecmap[h, w, 1] 166 | angles[h, w] = (math.degrees(math.atan2(y, x)) + 360) % 360 167 | 168 | angles = (angles / bin_size).astype(int) 169 | return angles 170 | -------------------------------------------------------------------------------- /data_utils/graph_utils.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | 4 | import math 5 | 6 | import numpy as np 7 | import data_utils.rdp as rdp 8 | 9 | 10 | def simplify_edge(ps, max_distance=1): 11 | """ 12 | Combine multiple points of graph edges to line segments 13 | so distance from points to segments <= max_distance 14 | 15 | @param ps: array of points in the edge, including node coordinates 16 | @param max_distance: maximum distance, if exceeded new segment started 17 | 18 | @return: ndarray of new nodes coordinates 19 | """ 20 | res_points = [] 21 | cur_idx = 0 22 | for i in range(1, len(ps) - 1): 23 | segment = ps[cur_idx : i + 1, :] - ps[cur_idx, :] 24 | angle = -math.atan2(segment[-1, 1], segment[-1, 0]) 25 | ca = math.cos(angle) 26 | sa = math.sin(angle) 27 | # rotate all the points so line is alongside first column coordinate 28 | # and the second col coordinate means the distance to the line 29 | segment_rotated = np.array([[ca, -sa], [sa, ca]]).dot(segment.T) 30 | distance = np.max(np.abs(segment_rotated[1, :])) 31 | if distance > max_distance: 32 | res_points.append(ps[cur_idx, :]) 33 | cur_idx = i 34 | if len(res_points) == 0: 35 | res_points.append(ps[0, :]) 36 | res_points.append(ps[-1, :]) 37 | 38 | return np.array(res_points) 39 | 40 | 41 | def simplify_graph(graph, max_distance=1): 42 | """ 43 | @params graph: MultiGraph object of networkx 44 | @return: simplified graph after applying RDP algorithm. 45 | """ 46 | all_segments = [] 47 | # Iterate over Graph Edges 48 | for (s, e) in graph.edges(): 49 | for _, val in graph[s][e].items(): 50 | # get all pixel points i.e. (x,y) between the edge 51 | ps = val["pts"] 52 | # create a full segment 53 | full_segments = np.row_stack( 54 | [graph.nodes[s]["o"], ps, graph.nodes[e]["o"]]) 55 | # simply the graph. 56 | segments = rdp.rdp(full_segments.tolist(), max_distance) 57 | all_segments.append(segments) 58 | 59 | return all_segments 60 | 61 | 62 | def segment_to_linestring(segment): 63 | """ 64 | Convert Graph segment to LineString require to calculate the APLS mteric 65 | using utility tool provided by Spacenet. 66 | """ 67 | 68 | if len(segment) < 2: 69 | return [] 70 | linestring = "LINESTRING ({})" 71 | sublinestring = "" 72 | for i, node in enumerate(segment): 73 | if i == 0: 74 | sublinestring = sublinestring + "{:.1f} {:.1f}".format(node[1], node[0]) 75 | else: 76 | if node[0] == segment[i - 1][0] and node[1] == segment[i - 1][1]: 77 | if len(segment) == 2: 78 | return [] 79 | continue 80 | if i > 1 and node[0] == segment[i - 2][0] and node[1] == segment[i - 2][1]: 81 | continue 82 | sublinestring = sublinestring + ", {:.1f} {:.1f}".format(node[1], node[0]) 83 | linestring = linestring.format(sublinestring) 84 | return linestring 85 | 86 | 87 | def segmets_to_linestrings(segments): 88 | """ 89 | Convert multiple segments to LineStrings require to calculate the APLS mteric 90 | using utility tool provided by Spacenet. 
91 | """ 92 | 93 | linestrings = [] 94 | for segment in segments: 95 | linestring = segment_to_linestring(segment) 96 | if len(linestring) > 0: 97 | linestrings.append(linestring) 98 | if len(linestrings) == 0: 99 | linestrings = ["LINESTRING EMPTY"] 100 | return linestrings 101 | 102 | 103 | def unique(list1): 104 | # intilize a null list 105 | unique_list = [] 106 | 107 | # traverse for all elements 108 | for x in list1: 109 | # check if exists in unique_list or not 110 | if x not in unique_list: 111 | unique_list.append(x) 112 | return unique_list 113 | -------------------------------------------------------------------------------- /data_utils/rdp.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | 4 | """ 5 | The Ramer-Douglas-Peucker algorithm roughly ported from the pseudo-code provided 6 | by http://en.wikipedia.org/wiki/Ramer-Douglas-Peucker_algorithm 7 | 8 | The code is taken from 9 | https://github.com/mitroadmaps/roadtracer/blob/master/lib/discoverlib/rdp.py 10 | """ 11 | 12 | from math import sqrt 13 | 14 | 15 | def distance(a, b): 16 | return sqrt((a[0] - b[0]) ** 2 + (a[1] - b[1]) ** 2) 17 | 18 | 19 | def point_line_distance(point, start, end): 20 | """ 21 | Calaculate the prependicuar distance of given point from the line having 22 | start and end points. 23 | """ 24 | if start == end: 25 | return distance(point, start) 26 | else: 27 | n = abs( 28 | (end[0] - start[0]) * (start[1] - point[1]) 29 | - (start[0] - point[0]) * (end[1] - start[1]) 30 | ) 31 | d = sqrt((end[0] - start[0]) ** 2 + (end[1] - start[1]) ** 2) 32 | return n / d 33 | 34 | 35 | def rdp(points, epsilon): 36 | """ 37 | Reduces a series of points to a simplified version that loses detail, but 38 | maintains the general shape of the series. 39 | 40 | @param points: Series of points for a line geometry represnted in graph. 41 | @param epsilon: Tolerance required for RDP algorithm to aproximate the 42 | line geometry. 
43 | 44 | @return: Aproximate series of points for approximate line geometry 45 | """ 46 | dmax = 0.0 47 | index = 0 48 | for i in range(1, len(points) - 1): 49 | d = point_line_distance(points[i], points[0], points[-1]) 50 | if d > dmax: 51 | index = i 52 | dmax = d 53 | if dmax >= epsilon: 54 | results = rdp(points[: index + 1], epsilon)[:-1] + rdp(points[index:], epsilon) 55 | else: 56 | results = [points[0], points[-1]] 57 | return results 58 | -------------------------------------------------------------------------------- /data_utils/sknw.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | 4 | ''' 5 | The methods are taken from 6 | https://github.com/yxdragon/sknw 7 | ''' 8 | 9 | import networkx as nx 10 | import numpy as np 11 | from numba import jit 12 | 13 | 14 | # get neighbors d index 15 | def neighbors(shape): 16 | dim = len(shape) 17 | block = np.ones([3] * dim) 18 | block[tuple([1] * dim)] = 0 19 | idx = np.where(block > 0) 20 | idx = np.array(idx, dtype=np.uint8).T 21 | idx = np.array(idx - [1] * dim) 22 | acc = np.cumprod((1,) + shape[::-1][:-1]) 23 | return np.dot(idx, acc[::-1]) 24 | 25 | 26 | @jit(nopython=True) # my mark 27 | def mark(img, nbs): # mark the array use (0, 1, 2) 28 | img = img.ravel() 29 | for p in range(len(img)): 30 | if img[p] == 0: 31 | continue 32 | s = 0 33 | for dp in nbs: 34 | if img[p + dp] != 0: 35 | s += 1 36 | if s == 2: 37 | img[p] = 1 38 | else: 39 | img[p] = 2 40 | 41 | 42 | @jit(nopython=True) # trans index to r, c... 43 | def idx2rc(idx, acc): 44 | rst = np.zeros((len(idx), len(acc)), dtype=np.int16) 45 | for i in range(len(idx)): 46 | for j in range(len(acc)): 47 | rst[i, j] = idx[i] // acc[j] 48 | idx[i] -= rst[i, j] * acc[j] 49 | rst -= 1 50 | return rst 51 | 52 | 53 | @jit(nopython=True) # fill a node (may be two or more points) 54 | def fill(img, p, num, nbs, acc, buf): 55 | back = img[p] 56 | img[p] = num 57 | buf[0] = p 58 | cur = 0 59 | s = 1 60 | 61 | while True: 62 | p = buf[cur] 63 | for dp in nbs: 64 | cp = p + dp 65 | if img[cp] == back: 66 | img[cp] = num 67 | buf[s] = cp 68 | s += 1 69 | cur += 1 70 | if cur == s: 71 | break 72 | return idx2rc(buf[:s], acc) 73 | 74 | 75 | # trace the edge and use a buffer, then buf.copy, if use [] numba not works 76 | @jit(nopython=True) 77 | def trace(img, p, nbs, acc, buf): 78 | c1 = 0 79 | c2 = 0 80 | newp = 0 81 | cur = 1 82 | while True: 83 | buf[cur] = p 84 | img[p] = 0 85 | cur += 1 86 | for dp in nbs: 87 | cp = p + dp 88 | if img[cp] >= 10: 89 | if c1 == 0: 90 | c1 = img[cp] 91 | buf[0] = cp 92 | else: 93 | c2 = img[cp] 94 | buf[cur] = cp 95 | if img[cp] == 1: 96 | newp = cp 97 | p = newp 98 | if c2 != 0: 99 | break 100 | return (c1 - 10, c2 - 10, idx2rc(buf[:cur + 1], acc)) 101 | 102 | 103 | @jit(nopython=True) # parse the image then get the nodes and edges 104 | def parse_struc(img, pts, nbs, acc): 105 | img = img.ravel() 106 | buf = np.zeros(131072, dtype=np.int64) 107 | num = 10 108 | nodes = [] 109 | for p in pts: 110 | if img[p] == 2: 111 | nds = fill(img, p, num, nbs, acc, buf) 112 | num += 1 113 | nodes.append(nds) 114 | edges = [] 115 | for p in pts: 116 | for dp in nbs: 117 | if img[p + dp] == 1: 118 | edge = trace(img, p + dp, nbs, acc, buf) 119 | edges.append(edge) 120 | return nodes, edges 121 | 122 | # use nodes and edges build a networkx graph 123 | 124 | 125 | def build_graph(nodes, edges, multi=False): 126 | graph = nx.MultiGraph() if multi else nx.Graph() 127 | for i in range(len(nodes)): 128 | 
graph.add_node(i, pts=nodes[i], o=nodes[i].mean(axis=0)) 129 | for s, e, pts in edges: 130 | ln = np.linalg.norm(pts[1:] - pts[:-1], axis=1).sum() 131 | graph.add_edge(s, e, pts=pts, weight=ln) 132 | return graph 133 | 134 | 135 | def buffer(ske): 136 | buf = np.zeros(tuple(np.array(ske.shape) + 2), dtype=np.uint16) 137 | buf[tuple([slice(1, -1)] * buf.ndim)] = ske 138 | return buf 139 | 140 | 141 | def mark_node(ske): 142 | buf = buffer(ske) 143 | nbs = neighbors(buf.shape) 144 | acc = np.cumprod((1,)+buf.shape[::-1][:-1])[::-1] 145 | mark(buf, nbs) 146 | return buf 147 | 148 | 149 | def build_sknw(ske, multi=False): 150 | buf = buffer(ske) 151 | nbs = neighbors(buf.shape) 152 | acc = np.cumprod((1,)+buf.shape[::-1][:-1])[::-1] 153 | mark(buf, nbs) 154 | pts = np.array(np.where(buf.ravel() == 2))[0] 155 | nodes, edges = parse_struc(buf, pts, nbs, acc) 156 | return build_graph(nodes, edges, multi) 157 | 158 | # draw the graph 159 | 160 | 161 | def draw_graph(img, graph, cn=255, ce=128): 162 | acc = np.cumprod((1,) + img.shape[::-1][:-1])[::-1] 163 | img = img.ravel() 164 | for (s, e) in graph.edges(): 165 | eds = graph[s][e] 166 | if isinstance(graph, nx.MultiGraph): 167 | for i in eds: 168 | pts = eds[i]['pts'] 169 | img[np.dot(pts, acc)] = ce 170 | else: 171 | img[np.dot(eds['pts'], acc)] = ce 172 | for idx in graph.nodes(): 173 | pts = graph.nodes[idx]['pts'] 174 | img[np.dot(pts, acc)] = cn 175 | 176 | 177 | if __name__ == '__main__': 178 | import matplotlib.pyplot as plt 179 | 180 | img = np.array([ 181 | [0, 0, 0, 1, 0, 0, 0, 0, 0], 182 | [0, 0, 0, 1, 0, 0, 0, 0, 0], 183 | [0, 0, 0, 1, 0, 0, 0, 0, 0], 184 | [1, 1, 1, 1, 0, 0, 0, 0, 0], 185 | [0, 0, 0, 0, 1, 0, 0, 0, 0], 186 | [0, 0, 0, 0, 0, 1, 0, 0, 0], 187 | [0, 0, 0, 0, 0, 1, 1, 1, 1], 188 | [0, 0, 0, 0, 1, 0, 0, 0, 0], 189 | [0, 0, 0, 1, 0, 0, 0, 0, 0]]) 190 | 191 | node_img = mark_node(img) 192 | graph = build_sknw(img) 193 | 194 | plt.imshow(node_img[1:-1, 1:-1], cmap='gray') 195 | 196 | # draw edges by pts 197 | for (s, e) in graph.edges(): 198 | ps = graph[s][e]['pts'] 199 | plt.plot(ps[:, 1], ps[:, 0], 'green') 200 | 201 | # draw node by o 202 | nodes = graph.nodes() 203 | ps = np.array([nodes[i]['o'] for i in nodes]) 204 | plt.plot(ps[:, 1], ps[:, 0], 'r.') 205 | 206 | # title and show 207 | plt.title('Build Graph') 208 | plt.show() 209 | -------------------------------------------------------------------------------- /model/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/anilbatra2185/road_connectivity/01e41662a43d9e289926ebd58eff4f6e14359ae4/model/__init__.py -------------------------------------------------------------------------------- /model/linknet.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | 3 | import math 4 | import os 5 | import random 6 | 7 | import numpy as np 8 | import torch.nn as nn 9 | from torchvision import models 10 | 11 | 12 | class DecoderBlock(nn.Module): 13 | def __init__(self, in_channels, n_filters, group=1): 14 | super(DecoderBlock, self).__init__() 15 | 16 | # B, C, H, W -> B, C/4, H, W 17 | self.conv1 = nn.Conv2d(in_channels, in_channels // 4, 1, groups=group) 18 | self.norm1 = nn.BatchNorm2d(in_channels // 4) 19 | self.relu1 = nn.ReLU(inplace=True) 20 | 21 | # B, C/4, H, W -> B, C/4, H, W 22 | self.deconv2 = nn.ConvTranspose2d( 23 | in_channels // 4, 24 | in_channels // 4, 25 | 3, 26 | stride=2, 27 | padding=1, 28 | 
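            # with stride=2, kernel_size=3, padding=1 and output_padding=1 the
            # transposed conv exactly doubles H and W: (in - 1)*2 - 2 + 3 + 1 = 2*in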
output_padding=1, 29 | groups=group, 30 | ) 31 | self.norm2 = nn.BatchNorm2d(in_channels // 4) 32 | self.relu2 = nn.ReLU(inplace=True) 33 | 34 | # B, C/4, H, W -> B, C, H, W 35 | self.conv3 = nn.Conv2d(in_channels // 4, n_filters, 1, groups=group) 36 | self.norm3 = nn.BatchNorm2d(n_filters) 37 | self.relu3 = nn.ReLU(inplace=True) 38 | 39 | for m in self.modules(): 40 | if isinstance(m, nn.Conv2d): 41 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 42 | m.weight.data.normal_(0, math.sqrt(2.0 / n)) 43 | if isinstance(m, nn.ConvTranspose2d): 44 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 45 | m.weight.data.normal_(0, math.sqrt(2.0 / n)) 46 | elif isinstance(m, nn.BatchNorm2d): 47 | m.weight.data.fill_(1) 48 | m.bias.data.zero_() 49 | 50 | def forward(self, x): 51 | x = self.conv1(x) 52 | x = self.norm1(x) 53 | x = self.relu1(x) 54 | x = self.deconv2(x) 55 | x = self.norm2(x) 56 | x = self.relu2(x) 57 | x = self.conv3(x) 58 | x = self.norm3(x) 59 | x = self.relu3(x) 60 | return x 61 | 62 | 63 | class LinkNet34(nn.Module): 64 | def __init__(self, in_channels=3, num_classes=2): 65 | super(LinkNet34, self).__init__() 66 | 67 | filters = [64, 128, 256, 512] 68 | resnet = models.resnet34(pretrained=False) 69 | 70 | if in_channels==3: 71 | self.firstconv = resnet.conv1 72 | else: 73 | self.firstconv = nn.Conv2d(in_channels, filters[0], kernel_size=(7, 7), stride=(2, 2), padding=(3, 3)) 74 | 75 | self.firstbn = resnet.bn1 76 | self.firstrelu = resnet.relu 77 | self.firstmaxpool = resnet.maxpool 78 | self.encoder1 = resnet.layer1 79 | self.encoder2 = resnet.layer2 80 | self.encoder3 = resnet.layer3 81 | self.encoder4 = resnet.layer4 82 | 83 | # Decoder 84 | self.decoder4 = DecoderBlock(filters[3], filters[2]) 85 | self.decoder3 = DecoderBlock(filters[2], filters[1]) 86 | self.decoder2 = DecoderBlock(filters[1], filters[0]) 87 | self.decoder1 = DecoderBlock(filters[0], filters[0]) 88 | 89 | # Final Classifier 90 | self.finaldeconv1 = nn.ConvTranspose2d(filters[0], 32, 3, stride=2) 91 | self.finalrelu1 = nn.LeakyReLU(0.2, inplace=True) 92 | self.finalconv2 = nn.Conv2d(32, 32, 3) 93 | self.finalrelu2 = nn.LeakyReLU(0.2, inplace=True) 94 | self.finalconv3 = nn.Conv2d(32, num_classes, 2, padding=1) 95 | self.return_features = False 96 | self.tanh = nn.Tanh() 97 | 98 | for m in [self.finaldeconv1, self.finalconv2]: 99 | if isinstance(m, nn.Conv2d): 100 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 101 | m.weight.data.normal_(0, math.sqrt(2.0 / n)) 102 | if isinstance(m, nn.ConvTranspose2d): 103 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 104 | m.weight.data.normal_(0, math.sqrt(2.0 / n)) 105 | elif isinstance(m, nn.BatchNorm2d): 106 | m.weight.data.fill_(1) 107 | m.bias.data.zero_() 108 | 109 | def forward(self, x): 110 | rows = x.size()[2] 111 | cols = x.size()[3] 112 | 113 | # Encoder 114 | x = self.firstconv(x) 115 | x = self.firstbn(x) 116 | x = self.firstrelu(x) 117 | x = self.firstmaxpool(x) 118 | 119 | e1 = self.encoder1(x) 120 | e2 = self.encoder2(e1) 121 | e3 = self.encoder3(e2) 122 | e4 = self.encoder4(e3) 123 | 124 | # Decoder with Skip Connections 125 | d4 = ( 126 | self.decoder4(e4)[ 127 | :, :, : int(math.ceil(rows / 16.0)), : int(math.ceil(cols / 16.0)) 128 | ] 129 | + e3 130 | ) 131 | d3 = ( 132 | self.decoder3(d4)[ 133 | :, :, : int(math.ceil(rows / 8.0)), : int(math.ceil(cols / 8.0)) 134 | ] 135 | + e2 136 | ) 137 | d2 = ( 138 | self.decoder2(d3)[ 139 | :, :, : int(math.ceil(rows / 4.0)), : int(math.ceil(cols / 4.0)) 140 | ] 141 | + e1 
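            # decoder outputs are cropped to ceil(rows/k) x ceil(cols/k) above so the
            # skip additions stay shape-aligned for inputs not divisible by the stride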
142 | ) 143 | d1 = self.decoder1(d2) 144 | 145 | # Final Classification 146 | f1 = self.finaldeconv1(d1) 147 | f2 = self.finalrelu1(f1) 148 | f3 = self.finalconv2(f2) 149 | f4 = self.finalrelu2(f3) 150 | f5 = self.finalconv3(f4) 151 | 152 | return f5[:, :, :rows, :cols] 153 | 154 | 155 | class LinkNet34MTL(nn.Module): 156 | def __init__(self, task1_classes=2, task2_classes=37): 157 | super(LinkNet34MTL, self).__init__() 158 | 159 | filters = [64, 128, 256, 512] 160 | resnet = models.resnet34(pretrained=False) 161 | 162 | self.firstconv = resnet.conv1 163 | self.firstbn = resnet.bn1 164 | self.firstrelu = resnet.relu 165 | self.firstmaxpool = resnet.maxpool 166 | self.encoder1 = resnet.layer1 167 | self.encoder2 = resnet.layer2 168 | self.encoder3 = resnet.layer3 169 | self.encoder4 = resnet.layer4 170 | 171 | # Decoder 172 | self.decoder4 = DecoderBlock(filters[3], filters[2]) 173 | self.decoder3 = DecoderBlock(filters[2], filters[1]) 174 | self.decoder2 = DecoderBlock(filters[1], filters[0]) 175 | self.decoder1 = DecoderBlock(filters[0], filters[0]) 176 | 177 | # Final Classifier 178 | self.finaldeconv1 = nn.ConvTranspose2d(filters[0], 32, 3, stride=2) 179 | self.finalrelu1 = nn.LeakyReLU(0.2, inplace=True) 180 | self.finalconv2 = nn.Conv2d(32, 32, 3) 181 | self.finalrelu2 = nn.LeakyReLU(0.2, inplace=True) 182 | self.finalconv3 = nn.Conv2d(32, task1_classes, 2, padding=1) 183 | 184 | # Decoder 185 | self.a_decoder4 = DecoderBlock(filters[3], filters[2]) 186 | self.a_decoder3 = DecoderBlock(filters[2], filters[1]) 187 | self.a_decoder2 = DecoderBlock(filters[1], filters[0]) 188 | self.a_decoder1 = DecoderBlock(filters[0], filters[0]) 189 | 190 | # Final Classifier 191 | self.a_finaldeconv1 = nn.ConvTranspose2d(filters[0], 32, 3, stride=2) 192 | self.a_finalrelu1 = nn.LeakyReLU(0.2, inplace=True) 193 | self.a_finalconv2 = nn.Conv2d(32, 32, 3) 194 | self.a_finalrelu2 = nn.LeakyReLU(0.2, inplace=True) 195 | self.a_finalconv3 = nn.Conv2d(32, task2_classes, 2, padding=1) 196 | 197 | for m in [ 198 | self.finaldeconv1, 199 | self.finalconv2, 200 | self.a_finaldeconv1, 201 | self.a_finalconv2, 202 | ]: 203 | if isinstance(m, nn.Conv2d): 204 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 205 | m.weight.data.normal_(0, math.sqrt(2.0 / n)) 206 | if isinstance(m, nn.ConvTranspose2d): 207 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 208 | m.weight.data.normal_(0, math.sqrt(2.0 / n)) 209 | elif isinstance(m, nn.BatchNorm2d): 210 | m.weight.data.fill_(1) 211 | m.bias.data.zero_() 212 | 213 | def forward(self, x): 214 | rows = x.size()[2] 215 | cols = x.size()[3] 216 | 217 | # Encoder 218 | x = self.firstconv(x) 219 | x = self.firstbn(x) 220 | x = self.firstrelu(x) 221 | x = self.firstmaxpool(x) 222 | 223 | e1 = self.encoder1(x) 224 | e2 = self.encoder2(e1) 225 | e3 = self.encoder3(e2) 226 | e4 = self.encoder4(e3) 227 | 228 | # Decoder with Skip Connections 229 | d4 = ( 230 | self.decoder4(e4)[ 231 | :, :, : int(math.ceil(rows / 16.0)), : int(math.ceil(cols / 16.0)) 232 | ] 233 | + e3 234 | ) 235 | d3 = ( 236 | self.decoder3(d4)[ 237 | :, :, : int(math.ceil(rows / 8.0)), : int(math.ceil(cols / 8.0)) 238 | ] 239 | + e2 240 | ) 241 | d2 = ( 242 | self.decoder2(d3)[ 243 | :, :, : int(math.ceil(rows / 4.0)), : int(math.ceil(cols / 4.0)) 244 | ] 245 | + e1 246 | ) 247 | d1 = self.decoder1(d2) 248 | 249 | # Final Classification 250 | f1 = self.finaldeconv1(d1) 251 | f2 = self.finalrelu1(f1) 252 | f3 = self.finalconv2(f2) 253 | f4 = self.finalrelu2(f3) 254 | f5 = self.finalconv3(f4) 
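        # second task head: the a_* (angle/orientation) decoder below mirrors the
        # segmentation decoder on the same shared encoder features e1..e4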
255 | 256 | # Decoder with Skip Connections 257 | a_d4 = ( 258 | self.a_decoder4(e4)[ 259 | :, :, : int(math.ceil(rows / 16.0)), : int(math.ceil(cols / 16.0)) 260 | ] 261 | + e3 262 | ) 263 | a_d3 = ( 264 | self.a_decoder3(a_d4)[ 265 | :, :, : int(math.ceil(rows / 8.0)), : int(math.ceil(cols / 8.0)) 266 | ] 267 | + e2 268 | ) 269 | a_d2 = ( 270 | self.a_decoder2(a_d3)[ 271 | :, :, : int(math.ceil(rows / 4.0)), : int(math.ceil(cols / 4.0)) 272 | ] 273 | + e1 274 | ) 275 | a_d1 = self.a_decoder1(a_d2) 276 | 277 | # Final Classification 278 | a_f1 = self.a_finaldeconv1(a_d1) 279 | a_f2 = self.a_finalrelu1(a_f1) 280 | a_f3 = self.a_finalconv2(a_f2) 281 | a_f4 = self.a_finalrelu2(a_f3) 282 | a_f5 = self.a_finalconv3(a_f4) 283 | 284 | return f5[:, :, :rows, :cols], a_f5[:, :, :rows, :cols] 285 | -------------------------------------------------------------------------------- /model/models.py: -------------------------------------------------------------------------------- 1 | from linknet import LinkNet34, LinkNet34MTL 2 | from stack_module import StackHourglassNetMTL 3 | 4 | 5 | MODELS = {"LinkNet34MTL": LinkNet34MTL, "StackHourglassNetMTL": StackHourglassNetMTL} 6 | 7 | MODELS_REFINE = {"LinkNet34": LinkNet34} 8 | -------------------------------------------------------------------------------- /model/stack_module.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | 3 | import math 4 | 5 | import numpy as np 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | 10 | 11 | affine_par = True 12 | 13 | 14 | class BasicResnetBlock(nn.Module): 15 | expansion = 1 16 | 17 | def __init__(self, inplanes, planes, stride=1, padding=1, downsample=None): 18 | super(BasicResnetBlock, self).__init__() 19 | 20 | self.conv1 = nn.Conv2d( 21 | inplanes, planes, kernel_size=3, stride=stride, padding=padding, bias=False 22 | ) 23 | self.bn1 = nn.BatchNorm2d(planes) 24 | self.relu = nn.ReLU(inplace=True) 25 | 26 | self.conv2 = nn.Conv2d( 27 | planes, planes, kernel_size=3, stride=stride, padding=padding, bias=False 28 | ) 29 | self.bn2 = nn.BatchNorm2d(planes) 30 | self.downsample = downsample 31 | self.stride = stride 32 | 33 | def forward(self, x): 34 | residual = x 35 | 36 | out = self.conv1(x) 37 | out = self.bn1(out) 38 | out = self.relu(out) 39 | 40 | out = self.conv2(out) 41 | out = self.bn2(out) 42 | 43 | if self.downsample is not None: 44 | residual = self.downsample(x) 45 | 46 | out += residual 47 | out = self.relu(out) 48 | 49 | return out 50 | 51 | 52 | class DecoderBlock(nn.Module): 53 | def __init__(self, in_channels, n_filters, group=1): 54 | super(DecoderBlock, self).__init__() 55 | 56 | # B, C, H, W -> B, C/4, H, W 57 | self.conv1 = nn.Conv2d(in_channels, in_channels // 4, 1, groups=group) 58 | self.norm1 = nn.BatchNorm2d(in_channels // 4) 59 | self.relu1 = nn.ReLU(inplace=True) 60 | 61 | # B, C/4, H, W -> B, C/4, H, W 62 | self.deconv2 = nn.ConvTranspose2d( 63 | in_channels // 4, 64 | in_channels // 4, 65 | 3, 66 | stride=2, 67 | padding=1, 68 | output_padding=1, 69 | groups=group, 70 | ) 71 | self.norm2 = nn.BatchNorm2d(in_channels // 4) 72 | self.relu2 = nn.ReLU(inplace=True) 73 | 74 | # B, C/4, H, W -> B, C, H, W 75 | self.conv3 = nn.Conv2d(in_channels // 4, n_filters, 1, groups=group) 76 | self.norm3 = nn.BatchNorm2d(n_filters) 77 | self.relu3 = nn.ReLU(inplace=True) 78 | 79 | for m in self.modules(): 80 | if isinstance(m, nn.Conv2d): 81 | n = m.kernel_size[0] * m.kernel_size[1] * 
m.out_channels 82 | m.weight.data.normal_(0, math.sqrt(2.0 / n)) 83 | if isinstance(m, nn.ConvTranspose2d): 84 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 85 | m.weight.data.normal_(0, math.sqrt(2.0 / n)) 86 | elif isinstance(m, nn.BatchNorm2d): 87 | m.weight.data.fill_(1) 88 | m.bias.data.zero_() 89 | 90 | def forward(self, x): 91 | x = self.conv1(x) 92 | x = self.norm1(x) 93 | x = self.relu1(x) 94 | x = self.deconv2(x) 95 | x = self.norm2(x) 96 | x = self.relu2(x) 97 | x = self.conv3(x) 98 | x = self.norm3(x) 99 | x = self.relu3(x) 100 | return x 101 | 102 | 103 | class HourglassModuleMTL(nn.Module): 104 | def __init__(self, block, num_blocks, planes, depth): 105 | super(HourglassModuleMTL, self).__init__() 106 | self.depth = depth 107 | self.block = block 108 | self.upsample = nn.Upsample(scale_factor=2) 109 | self.hg = self._make_hour_glass(block, num_blocks, planes, depth) 110 | 111 | def _make_residual1(self, block, num_blocks, planes): 112 | layers = [] 113 | for i in range(0, num_blocks): 114 | layers.append(block(planes * block.expansion, planes)) 115 | return nn.Sequential(*layers) 116 | 117 | def _make_hour_glass(self, block, num_blocks, planes, depth): 118 | hg = [] 119 | for i in range(depth): 120 | res = [] 121 | for j in range(4): 122 | res.append(self._make_residual1(block, num_blocks, planes)) 123 | if i == 0: 124 | res.append(self._make_residual1(block, num_blocks, planes)) 125 | res.append(self._make_residual1(block, num_blocks, planes)) 126 | hg.append(nn.ModuleList(res)) 127 | return nn.ModuleList(hg) 128 | 129 | def _hour_glass_forward(self, n, x): 130 | rows = x.size(2) 131 | cols = x.size(3) 132 | 133 | up1 = self.hg[n - 1][0](x) 134 | low1 = F.max_pool2d(x, 2, stride=2, ceil_mode=True) 135 | low1 = self.hg[n - 1][1](low1) 136 | 137 | if n > 1: 138 | low2_1, low2_2 = self._hour_glass_forward(n - 1, low1) 139 | else: 140 | low2_1 = self.hg[n - 1][4](low1) 141 | low2_2 = self.hg[n - 1][5](low1) 142 | low3_1 = self.hg[n - 1][2](low2_1) 143 | low3_2 = self.hg[n - 1][3](low2_2) 144 | up2_1 = self.upsample(low3_1) 145 | up2_2 = self.upsample(low3_2) 146 | out_1 = up1 + up2_1[:, :, :rows, :cols] 147 | out_2 = up1 + up2_2[:, :, :rows, :cols] 148 | 149 | return out_1, out_2 150 | 151 | def forward(self, x): 152 | return self._hour_glass_forward(self.depth, x) 153 | 154 | 155 | class StackHourglassNetMTL(nn.Module): 156 | def __init__( 157 | self, 158 | task1_classes=2, 159 | task2_classes=37, 160 | block=BasicResnetBlock, 161 | in_channels=3, 162 | num_stacks=2, 163 | num_blocks=1, 164 | hg_num_blocks=3, 165 | depth=3, 166 | ): 167 | super(StackHourglassNetMTL, self).__init__() 168 | 169 | self.inplanes = 64 170 | self.num_feats = 128 171 | self.num_stacks = num_stacks 172 | self.conv1 = nn.Conv2d( 173 | in_channels, self.inplanes, kernel_size=7, stride=2, padding=3, bias=True 174 | ) 175 | self.bn1 = nn.BatchNorm2d(self.inplanes) 176 | self.relu = nn.ReLU(inplace=True) 177 | self.layer1 = self._make_residual(block, self.inplanes, 1) 178 | self.layer2 = self._make_residual(block, self.inplanes, num_blocks) 179 | self.layer3 = self._make_residual(block, self.num_feats, num_blocks) 180 | self.maxpool = nn.MaxPool2d(2, stride=2, ceil_mode=True) 181 | 182 | # build hourglass modules 183 | ch = self.num_feats * block.expansion 184 | hg = [] 185 | res_1, fc_1, score_1, _fc_1, _score_1 = [], [], [], [], [] 186 | res_2, fc_2, score_2, _fc_2, _score_2 = [], [], [], [], [] 187 | 188 | for i in range(num_stacks): 189 | hg.append(HourglassModuleMTL(block, hg_num_blocks, 
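                                         # one dual-output hourglass per stack; its
                                         # forward returns a (segmentation, orientation) pair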
self.num_feats, depth))
190 |
191 |             res_1.append(self._make_residual(block, self.num_feats, hg_num_blocks))
192 |             res_2.append(self._make_residual(block, self.num_feats, hg_num_blocks))
193 |
194 |             fc_1.append(self._make_fc(ch, ch))
195 |             fc_2.append(self._make_fc(ch, ch))
196 |
197 |             score_1.append(nn.Conv2d(ch, task1_classes, kernel_size=1, bias=True))
198 |             score_2.append(nn.Conv2d(ch, task2_classes, kernel_size=1, bias=True))
199 |             if i < num_stacks - 1:
200 |                 _fc_1.append(nn.Conv2d(ch, ch, kernel_size=1, bias=True))
201 |                 _fc_2.append(nn.Conv2d(ch, ch, kernel_size=1, bias=True))
202 |                 _score_1.append(nn.Conv2d(task1_classes, ch, kernel_size=1, bias=True))
203 |                 _score_2.append(nn.Conv2d(task2_classes, ch, kernel_size=1, bias=True))
204 |
205 |         self.hg = nn.ModuleList(hg)
206 |         self.res_1 = nn.ModuleList(res_1)
207 |         self.fc_1 = nn.ModuleList(fc_1)
208 |         self.score_1 = nn.ModuleList(score_1)
209 |         self._fc_1 = nn.ModuleList(_fc_1)
210 |         self._score_1 = nn.ModuleList(_score_1)
211 |
212 |         self.res_2 = nn.ModuleList(res_2)
213 |         self.fc_2 = nn.ModuleList(fc_2)
214 |         self.score_2 = nn.ModuleList(score_2)
215 |         self._fc_2 = nn.ModuleList(_fc_2)
216 |         self._score_2 = nn.ModuleList(_score_2)
217 |
218 |         # Final classifier (road segmentation head)
219 |         self.decoder1 = DecoderBlock(self.num_feats, self.inplanes)
220 |         self.decoder1_score = nn.Conv2d(
221 |             self.inplanes, task1_classes, kernel_size=1, bias=True
222 |         )
223 |         self.finaldeconv1 = nn.ConvTranspose2d(self.inplanes, 32, 3, stride=2)
224 |         self.finalrelu1 = nn.ReLU(inplace=True)
225 |         self.finalconv2 = nn.Conv2d(32, 32, 3)
226 |         self.finalrelu2 = nn.ReLU(inplace=True)
227 |         self.finalconv3 = nn.Conv2d(32, task1_classes, 2, padding=1)
228 |
229 |         # Final classifier (orientation head)
230 |         self.angle_decoder1 = DecoderBlock(self.num_feats, self.inplanes)
231 |         self.angle_decoder1_score = nn.Conv2d(
232 |             self.inplanes, task2_classes, kernel_size=1, bias=True
233 |         )
234 |         self.angle_finaldeconv1 = nn.ConvTranspose2d(self.inplanes, 32, 3, stride=2)
235 |         self.angle_finalrelu1 = nn.ReLU(inplace=True)
236 |         self.angle_finalconv2 = nn.Conv2d(32, 32, 3)
237 |         self.angle_finalrelu2 = nn.ReLU(inplace=True)
238 |         self.angle_finalconv3 = nn.Conv2d(32, task2_classes, 2, padding=1)
239 |
240 |     def _make_residual(self, block, planes, blocks, stride=1):
241 |         downsample = None
242 |         if stride != 1 or self.inplanes != planes * block.expansion:
243 |             downsample = nn.Sequential(
244 |                 nn.Conv2d(
245 |                     self.inplanes,
246 |                     planes * block.expansion,
247 |                     kernel_size=1,
248 |                     stride=stride,
249 |                     bias=True,
250 |                 )
251 |             )
252 |
253 |         layers = []
254 |         layers.append(block(self.inplanes, planes, stride, downsample=downsample))
255 |         self.inplanes = planes * block.expansion
256 |         for i in range(1, blocks):
257 |             layers.append(block(self.inplanes, planes))
258 |
259 |         return nn.Sequential(*layers)
260 |
261 |     def _make_fc(self, inplanes, outplanes):
262 |         bn = nn.BatchNorm2d(inplanes)
263 |         conv = nn.Conv2d(inplanes, outplanes, kernel_size=1, bias=True)
264 |         return nn.Sequential(conv, bn, self.relu)
265 |
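    # Note on head ordering: forward() below returns two prediction lists,
    # out_1 (roads) and out_2 (orientations). Each hourglass stack contributes
    # one (rows/4, cols/4) score map per task, followed by one (rows/2, cols/2)
    # decoder score and one full-resolution map, so with num_stacks=2 each list
    # holds four tensors. The multi-scale losses in train_mtl.py depend on this
    # ordering.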
266 |     def forward(self, x):
267 |         out_1 = []
268 |         out_2 = []
269 |
270 |         rows = x.size(2)
271 |         cols = x.size(3)
272 |
273 |         x = self.conv1(x)
274 |         x = self.bn1(x)
275 |         x = self.relu(x)
276 |
277 |         x = self.layer1(x)
278 |         x = self.maxpool(x)
279 |         x = self.layer2(x)
280 |         x = self.layer3(x)
281 |         for i in range(self.num_stacks):
282 |             y1, y2 = self.hg[i](x)
283 |             y1, y2 = self.res_1[i](y1), self.res_2[i](y2)
284 |             y1, y2 = self.fc_1[i](y1), self.fc_2[i](y2)
285 |             score1, score2 = self.score_1[i](y1), self.score_2[i](y2)
286 |             out_1.append(
287 |                 score1[:, :, : int(math.ceil(rows / 4.0)), : int(math.ceil(cols / 4.0))]
288 |             )
289 |             out_2.append(
290 |                 score2[:, :, : int(math.ceil(rows / 4.0)), : int(math.ceil(cols / 4.0))]
291 |             )
292 |             if i < self.num_stacks - 1:
293 |                 _fc_1, _fc_2 = self._fc_1[i](y1), self._fc_2[i](y2)
294 |                 _score_1, _score_2 = self._score_1[i](score1), self._score_2[i](score2)
295 |                 x = x + _fc_1 + _score_1 + _fc_2 + _score_2
296 |
297 |         # Final classification (road segmentation head)
298 |         d1 = self.decoder1(y1)[
299 |             :, :, : int(math.ceil(rows / 2.0)), : int(math.ceil(cols / 2.0))
300 |         ]
301 |         d1_score = self.decoder1_score(d1)
302 |         out_1.append(d1_score)
303 |         f1 = self.finaldeconv1(d1)
304 |         f2 = self.finalrelu1(f1)
305 |         f3 = self.finalconv2(f2)
306 |         f4 = self.finalrelu2(f3)
307 |         f5 = self.finalconv3(f4)
308 |         out_1.append(f5)
309 |
310 |         # Final classification (orientation head)
311 |         a_d1 = self.angle_decoder1(y2)[
312 |             :, :, : int(math.ceil(rows / 2.0)), : int(math.ceil(cols / 2.0))
313 |         ]
314 |         a_d1_score = self.angle_decoder1_score(a_d1)
315 |         out_2.append(a_d1_score)
316 |         a_f1 = self.angle_finaldeconv1(a_d1)
317 |         a_f2 = self.angle_finalrelu1(a_f1)
318 |         a_f3 = self.angle_finalconv2(a_f2)
319 |         a_f4 = self.angle_finalrelu2(a_f3)
320 |         a_f5 = self.angle_finalconv3(a_f4)
321 |         out_2.append(a_f5)
322 |
323 |         return out_1, out_2
324 |
--------------------------------------------------------------------------------
/preprocessing/prepare_spacenet.sh:
--------------------------------------------------------------------------------
1 | : '
2 | Bash Script file to Prepare Spacenet Images and Gaussian Road Masks.
3 | 1) Convert Spacenet 11-bit images to 8-bit Images, country wise.
4 | 2) Create Gaussian Road Masks, country wise.
5 | 3) Move all data to single folder.
6 | '
7 |
8 | spacenet_base_dir=$1
9 |
10 | RED='\033[0;31m'
11 | BLUE='\033[0;34m'
12 | GREEN='\033[0;32m'
13 | NC='\033[0m'
14 | BOLD='\033[1m'
15 | UNDERLINE='\033[4m'
16 |
17 | printf "${BLUE}${BOLD}Spacenet Data Base Dir => %s \n${NC}" $spacenet_base_dir
18 |
19 | printf "${GREEN}${UNDERLINE}${BOLD}\n Converting Spacenet 11-bit RGB images to 8-bit. ${NC}\n"
20 | python spacenet/convert_to_8bit_png.py -d $spacenet_base_dir
21 |
22 | printf "${GREEN}${UNDERLINE}${BOLD}\n Creating Spacenet gaussian road labels. ${NC}\n"
23 | python spacenet/create_gaussian_label.py -d $spacenet_base_dir
24 |
25 | printf "${GREEN}${UNDERLINE}${BOLD}\n Copying data to $spacenet_base_dir/full. ${NC}\n"
26 | for dir in $(find $spacenet_base_dir -maxdepth 1 -type d)
27 | do
28 |     image_folder="$dir/RGB_8bit"
29 |     copy_star="/*"
30 |     if [ -d "$image_folder" ]; then
31 |         mkdir -p "$spacenet_base_dir/full/images/"
32 |         # mkdir -p "$spacenet_base_dir/full/labels/"
33 |         cp $image_folder$copy_star "$spacenet_base_dir/full/images/"
34 |     fi
35 |     label_folder="$dir/gaussian_roads/label_png"
36 |     if [ -d "$label_folder" ]; then
37 |         mkdir -p "$spacenet_base_dir/full/gt/"
38 |         cp $label_folder$copy_star "$spacenet_base_dir/full/gt/"
39 |     fi
40 | done
41 |
42 |
43 |
44 |
45 |
46 |
47 |
--------------------------------------------------------------------------------
/preprocessing/spacenet/convert_to_8bit_png.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python2
2 |
3 | """
4 | convert_to_8bit_png.py: script to convert Spacenet 11-bit RGB images to 8-bit images and
5 | preprocess them with CLAHE (a variant of the adaptive histogram equalization algorithm).
6 | 7 | It will create following directory structure: 8 | base_dir 9 | | ---> RGB_8bit : Save 8-bit png images. 10 | """ 11 | 12 | from __future__ import print_function 13 | 14 | import argparse 15 | import sys 16 | import os 17 | import numpy as np 18 | import cv2 19 | import glob 20 | import tifffile as tif 21 | import time 22 | from tqdm import tqdm 23 | tqdm.monitor_interval = 0 24 | 25 | 26 | def CreatePNG(base_dir): 27 | spacenet_countries = ['AOI_2_Vegas_Roads_Train', 28 | 'AOI_3_Paris_Roads_Train', 29 | 'AOI_4_Shanghai_Roads_Train', 30 | 'AOI_5_Khartoum_Roads_Train'] 31 | 32 | for country in spacenet_countries: 33 | tif_folder = os.path.join(base_dir,'{country}/RGB-PanSharpen/'.format(country=country)) 34 | if os.path.isdir(tif_folder) == False: 35 | print(" ! RGB-PanSharpen folder does not exist for {country}. ! ".format(country=country)) 36 | print('x'*80) 37 | continue 38 | 39 | out_png_dir = os.path.join(base_dir,'{country}/RGB_8bit'.format(country=country)) 40 | 41 | if os.path.isdir(out_png_dir) == False: 42 | os.makedirs(out_png_dir) 43 | 44 | 45 | clahe = cv2.createCLAHE(clipLimit=2, tileGridSize=(8,8)) 46 | print('Processing Images from {}'.format(country)) 47 | print('*'*80) 48 | 49 | progress_bar = tqdm(glob.glob(tif_folder + '/*.tif'), ncols=150) 50 | for file_ in progress_bar: 51 | 52 | file_name = file_.split('/')[-1].replace('.tif','.png') 53 | progress_bar.set_description(" | --> Converting: {}".format(file_name)) 54 | 55 | img=tif.imread(file_) 56 | red = np.asarray(img[:,:,0],dtype=np.float) 57 | green = np.asarray(img[:,:,1],dtype=np.float) 58 | blue = np.asarray(img[:,:,2],dtype=np.float) 59 | 60 | red_ = 255.0 * ((red-np.min(red))/(np.max(red) - np.min(red))) 61 | green_ = 255.0 * ((green-np.min(green))/(np.max(green) - np.min(green))) 62 | blue_ = 255.0 * ((blue-np.min(blue))/(np.max(blue) - np.min(blue))) 63 | 64 | ## The default image size of Spacenet Dataset is 1300x1300. 65 | img_rgb = np.zeros((1300,1300,3),dtype=np.uint8) 66 | img_rgb[:,:,0] = clahe.apply(np.asarray(red_,dtype=np.uint8)) 67 | img_rgb[:,:,1] = clahe.apply(np.asarray(green_,dtype=np.uint8)) 68 | img_rgb[:,:,2] = clahe.apply(np.asarray(blue_,dtype=np.uint8)) 69 | 70 | cv2.imwrite(os.path.join(out_png_dir,file_name),img_rgb[:,:,::-1]) 71 | 72 | # print('\t|--> Processed Images : {}'.format(index), end='\r') 73 | # time.sleep(1) 74 | # sys.stdout.flush() 75 | 76 | 77 | def main(): 78 | parser = argparse.ArgumentParser() 79 | parser.add_argument('-d', '--base_dir', type=str, required=True, 80 | help='Base directory for Spacenent Dataset.') 81 | 82 | args = parser.parse_args() 83 | 84 | start = time.clock() 85 | CreatePNG(args.base_dir) 86 | end = time.clock() 87 | 88 | print('Finished Creating png, time {0}s'.format(end - start)) 89 | 90 | if __name__ == "__main__": 91 | main() -------------------------------------------------------------------------------- /preprocessing/spacenet/create_gaussian_label.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python2 2 | 3 | """ 4 | create_gaussian_label.py: script to convert Spacenet linestring annotation to gaussian road mask. 5 | 6 | It will create following directory structure: 7 | base_dir 8 | | ---> gaussian_roads 9 | | ---> label_tif : Tiff image to raster Linestring as road skeleton image. 10 | | ---> label_png : PNG image to create gaussian road mask. 
11 | """ 12 | 13 | from __future__ import print_function 14 | 15 | import argparse 16 | import os 17 | import sys 18 | import numpy as np 19 | import cv2 20 | from scipy.ndimage.morphology import * 21 | import glob 22 | import math 23 | import time 24 | from osgeo import gdal 25 | import geoTools as gT 26 | from tqdm import tqdm 27 | tqdm.monitor_interval = 0 28 | 29 | 30 | def CreateGaussianLabel(base_dir): 31 | spacenet_countries = ['AOI_2_Vegas_Roads_Train', 32 | 'AOI_3_Paris_Roads_Train', 33 | 'AOI_4_Shanghai_Roads_Train', 34 | 'AOI_5_Khartoum_Roads_Train'] 35 | 36 | for country in spacenet_countries: 37 | tif_folder = os.path.join(base_dir,'{country}/RGB-PanSharpen/'.format(country=country)) 38 | if os.path.isdir(tif_folder) == False: 39 | print(" ! RGB-PanSharpen folder does not exist for {country}. ! ".format(country=country)) 40 | print('x'*80) 41 | continue 42 | 43 | geojson_dir = os.path.join(base_dir,'{country}/geojson/spacenetroads/'.format(country=country).format(country=country)) 44 | rgb_dir = os.path.join(base_dir,'{country}/RGB-PanSharpen/'.format(country=country)) 45 | 46 | roads_dir = os.path.join(base_dir,'{country}/gaussian_roads'.format(country=country)) 47 | 48 | if os.path.isdir(roads_dir) == False: 49 | os.makedirs(roads_dir) 50 | os.makedirs(os.path.join(roads_dir,'label_tif')) 51 | os.makedirs(os.path.join(roads_dir,'label_png')) 52 | 53 | 54 | ## The default image size of Spacenet Dataset is 1300x1300. 55 | black_image = np.zeros((1300,1300),dtype=np.uint8) 56 | 57 | failure_count = 0 58 | index = 0 59 | print('Processing Images from {}'.format(country)) 60 | print('*'*60) 61 | 62 | progress_bar = tqdm(glob.glob(geojson_dir + '/*.geojson'), ncols=150) 63 | for file_ in progress_bar: 64 | name = file_.split('/')[-1] 65 | index += 1 66 | 67 | file_name = name.replace('spacenetroads_','').replace('.geojson','') 68 | progress_bar.set_description(" | --> Creating: {}".format(file_name)) 69 | 70 | geojson_name_format = 'spacenetroads_{0}.geojson'.format(file_name) 71 | rgb_name_format = 'RGB-PanSharpen_{0}.tif'.format(file_name) 72 | road_segment_name_format = 'RGB-PanSharpen_{0}.tif'.format(file_name) 73 | out_tif_file = os.path.join(roads_dir,'label_tif',road_segment_name_format) 74 | out_png_file = os.path.join(roads_dir,'label_png',road_segment_name_format).replace('.tif','.png') 75 | 76 | geojson_file = os.path.join(geojson_dir,geojson_name_format) 77 | tif_file = os.path.join(rgb_dir,rgb_name_format) 78 | 79 | status = gT.ConvertToRoadSegmentation(tif_file,geojson_file,out_tif_file) 80 | 81 | if status != 0: 82 | print("|xxx-> Not able to convert the file {}. 
<-xxx".format(name)) 83 | failure_count += 1 84 | cv2.imwrite(out_png_file,black_image) 85 | else: 86 | gt_dataset = gdal.Open(out_tif_file, gdal.GA_ReadOnly) 87 | if not gt_dataset: 88 | continue 89 | gt_array = gt_dataset.GetRasterBand(1).ReadAsArray() 90 | 91 | distance_array = distance_transform_edt(1-(gt_array/255)) 92 | std = 15 93 | distance_array = np.exp(-0.5*(distance_array*distance_array)/(std*std)) 94 | distance_array *= 255 95 | cv2.imwrite(out_png_file,distance_array) 96 | 97 | # print('\t|--> Processed Images : {}'.format(index), end='\r') 98 | # sys.stdout.flush() 99 | # print('\t|--> Image: {}'.format(file_name)) 100 | 101 | print("Not able to convert {} files.".format(failure_count)) 102 | 103 | 104 | 105 | def main(): 106 | parser = argparse.ArgumentParser() 107 | parser.add_argument('-d', '--base_dir', type=str, required=True, 108 | help='Base directory for Spacenent Dataset.') 109 | 110 | args = parser.parse_args() 111 | 112 | start = time.clock() 113 | CreateGaussianLabel(args.base_dir) 114 | end = time.clock() 115 | print('Finished Creating Labels, time {0}s'.format(end - start)) 116 | 117 | if __name__ == "__main__": 118 | main() -------------------------------------------------------------------------------- /preprocessing/spacenet/geoTools.py: -------------------------------------------------------------------------------- 1 | """ 2 | The code is borrowed from Spacenet Utilities. 3 | https://github.com/SpaceNetChallenge/utilities/blob/spacenetV3/spacenetutilities/geoTools.py 4 | """ 5 | 6 | from osgeo import gdal, osr, ogr 7 | import numpy as np 8 | import os 9 | import csv 10 | import subprocess 11 | import math 12 | import geopandas as gpd 13 | import shapely 14 | from shapely.geometry import Point 15 | from pyproj import Proj, transform 16 | from fiona.crs import from_epsg 17 | from shapely.geometry.polygon import Polygon 18 | from shapely.geometry.multipolygon import MultiPolygon 19 | from shapely.geometry.linestring import LineString 20 | from shapely.geometry.multilinestring import MultiLineString 21 | from xml.etree.ElementTree import Element, SubElement, Comment, tostring 22 | from xml.etree import ElementTree 23 | from xml.dom import minidom 24 | try: 25 | import rtree 26 | import centerline 27 | import osmnx 28 | except: 29 | print("rtree not installed, Will break evaluation code") 30 | 31 | 32 | def import_summary_geojson(geojsonfilename): 33 | # driver = ogr.GetDriverByName('geojson') 34 | datasource = ogr.Open(geojsonfilename, 0) 35 | 36 | layer = datasource.GetLayer() 37 | # print(layer.GetFeatureCount()) 38 | 39 | imagename = geojsonfilename.split('/')[-1].replace('.geojson','.tif') 40 | roadlist = [] 41 | 42 | ###Road Type 43 | #1: Motorway 44 | #2: Primary 45 | #3: Secondary 46 | #4: Tertiary 47 | #5: Residential 48 | #6: Unclassified 49 | #7: Cart track 50 | 51 | for idx, feature in enumerate(layer): 52 | 53 | poly = feature.GetGeometryRef() 54 | 55 | if poly: 56 | roadlist.append({'ImageID':imagename,'RoadID': feature.GetField('road_id'), 57 | 'RoadType': feature.GetField('road_type'), 58 | 'Lanes': feature.GetField('lane_number'), 59 | 'IsBridge': feature.GetField('bridge_typ'),# 1= Bridge, 2=Not Bridge, 3=Unknown 60 | 'Paved': feature.GetField('bridge_typ'),# 1= Paved, 2=Unpaved, 3=Unknown 61 | 'LineString': feature.GetGeometryRef().Clone()}) 62 | 63 | return roadlist 64 | 65 | def latlon2pixel(lat, lon, input_raster='', targetsr='', geom_transform=''): 66 | # type: (object, object, object, object, object) -> object 67 | 68 | sourcesr = 
osr.SpatialReference() 69 | sourcesr.ImportFromEPSG(4326) 70 | 71 | geom = ogr.Geometry(ogr.wkbPoint) 72 | geom.AddPoint(lon, lat) 73 | 74 | if targetsr == '': 75 | src_raster = gdal.Open(input_raster) 76 | targetsr = osr.SpatialReference() 77 | targetsr.ImportFromWkt(src_raster.GetProjectionRef()) 78 | coord_trans = osr.CoordinateTransformation(sourcesr, targetsr) 79 | if geom_transform == '': 80 | src_raster = gdal.Open(input_raster) 81 | transform = src_raster.GetGeoTransform() 82 | else: 83 | transform = geom_transform 84 | 85 | x_origin = transform[0] 86 | # print(x_origin) 87 | y_origin = transform[3] 88 | # print(y_origin) 89 | pixel_width = transform[1] 90 | # print(pixel_width) 91 | pixel_height = transform[5] 92 | # print(pixel_height) 93 | geom.Transform(coord_trans) 94 | # print(geom.GetPoint()) 95 | x_pix = (geom.GetPoint()[0] - x_origin) / pixel_width 96 | y_pix = (geom.GetPoint()[1] - y_origin) / pixel_height 97 | 98 | return (x_pix, y_pix) 99 | 100 | def geoWKTToPixelWKT(geom, inputRaster, targetSR, geomTransform, breakMultiGeo, pixPrecision=2): 101 | # Returns Pixel Coordinate List and GeoCoordinateList 102 | 103 | geom_list = [] 104 | geom_pix_wkt_list = [] 105 | 106 | if geom.GetGeometryName() == 'POLYGON': 107 | polygonPix = ogr.Geometry(ogr.wkbPolygon) 108 | for ring in geom: 109 | # GetPoint returns a tuple not a Geometry 110 | ringPix = ogr.Geometry(ogr.wkbLinearRing) 111 | 112 | for pIdx in xrange(ring.GetPointCount()): 113 | lon, lat, z = ring.GetPoint(pIdx) 114 | xPix, yPix = latlon2pixel(lat, lon, inputRaster, targetSR, geomTransform) 115 | 116 | xPix = round(xPix, pixPrecision) 117 | yPix = round(yPix, pixPrecision) 118 | ringPix.AddPoint(xPix, yPix) 119 | 120 | polygonPix.AddGeometry(ringPix) 121 | polygonPixBuffer = polygonPix.Buffer(0.0) 122 | geom_list.append([polygonPixBuffer, geom]) 123 | 124 | elif geom.GetGeometryName() == 'MULTIPOLYGON': 125 | 126 | for poly in geom: 127 | polygonPix = ogr.Geometry(ogr.wkbPolygon) 128 | for ring in poly: 129 | # GetPoint returns a tuple not a Geometry 130 | ringPix = ogr.Geometry(ogr.wkbLinearRing) 131 | 132 | for pIdx in xrange(ring.GetPointCount()): 133 | lon, lat, z = ring.GetPoint(pIdx) 134 | xPix, yPix = latlon2pixel(lat, lon, inputRaster, targetSR, geomTransform) 135 | 136 | xPix = round(xPix, pixPrecision) 137 | yPix = round(yPix, pixPrecision) 138 | ringPix.AddPoint(xPix, yPix) 139 | 140 | polygonPix.AddGeometry(ringPix) 141 | polygonPixBuffer = polygonPix.Buffer(0.0) 142 | geom_list.append([polygonPixBuffer, geom]) 143 | 144 | elif geom.GetGeometryName() == 'LINESTRING': 145 | line = ogr.Geometry(ogr.wkbLineString) 146 | for pIdx in xrange(geom.GetPointCount()): 147 | lon, lat, z = geom.GetPoint(pIdx) 148 | xPix, yPix = latlon2pixel(lat, lon, inputRaster, targetSR, geomTransform) 149 | 150 | xPix = round(xPix, pixPrecision) 151 | yPix = round(yPix, pixPrecision) 152 | line.AddPoint(xPix, yPix) 153 | geom_list.append([line, geom]) 154 | 155 | elif geom.GetGeometryName() == 'MULTILINESTRING': 156 | 157 | if breakMultiGeo: 158 | for poly in geom: 159 | line = ogr.Geometry(ogr.wkbLineString) 160 | for pIdx in xrange(poly.GetPointCount()): 161 | lon, lat, z = poly.GetPoint(pIdx) 162 | xPix, yPix = latlon2pixel(lat, lon, inputRaster, targetSR, geomTransform) 163 | 164 | xPix = round(xPix, pixPrecision) 165 | yPix = round(yPix, pixPrecision) 166 | line.AddPoint(xPix, yPix) 167 | geom_list.append([line, poly]) 168 | else: 169 | multiline = ogr.Geometry(ogr.wkbMultiLineString) 170 | for poly in geom: 171 | line = 
ogr.Geometry(ogr.wkbLineString) 172 | for pIdx in xrange(poly.GetPointCount()): 173 | lon, lat, z = poly.GetPoint(pIdx) 174 | xPix, yPix = latlon2pixel(lat, lon, inputRaster, targetSR, geomTransform) 175 | 176 | xPix = round(xPix, pixPrecision) 177 | yPix = round(yPix, pixPrecision) 178 | line.AddPoint(xPix, yPix) 179 | multiline.AddGeometry(line) 180 | geom_list.append([multiline, geom]) 181 | 182 | elif geom.GetGeometryName() == 'POINT': 183 | point = ogr.Geometry(ogr.wkbPoint) 184 | for pIdx in xrange(geom.GetPointCount()): 185 | lon, lat, z = geom.GetPoint(pIdx) 186 | xPix, yPix = latlon2pixel(lat, lon, inputRaster, targetSR, geomTransform) 187 | 188 | xPix = round(xPix, pixPrecision) 189 | yPix = round(yPix, pixPrecision) 190 | point.AddPoint(xPix, yPix) 191 | geom_list.append([point, geom]) 192 | 193 | for polygonTest in geom_list: 194 | 195 | if polygonTest[0].GetGeometryName() == 'POLYGON' or \ 196 | polygonTest[0].GetGeometryName() == 'LINESTRING' or \ 197 | polygonTest[0].GetGeometryName() == 'POINT': 198 | geom_pix_wkt_list.append([polygonTest[0].ExportToWkt(), polygonTest[1].ExportToWkt()]) 199 | elif polygonTest[0].GetGeometryName() == 'MULTIPOLYGON' or \ 200 | polygonTest[0].GetGeometryName() == 'MULTILINESTRING': 201 | for (pix,geo) in geom_list: 202 | geom_pix_wkt_list.append([pix.ExportToWkt(),geo.ExportToWkt()]) 203 | 204 | return geom_pix_wkt_list 205 | 206 | 207 | def convert_wgs84geojson_to_pixgeojson(wgs84geojson, inputraster, image_id=[], pixelgeojson=True,pixelgeojson_path='', 208 | breakMultiGeo=False, pixPrecision=2): 209 | 210 | dataSource = ogr.Open(wgs84geojson, 0) 211 | if dataSource is None: 212 | print '='*50 213 | print 'GeoJson {} has no Coordinates.'.format(wgs84geojson) 214 | print '='*50 215 | return 216 | layer = dataSource.GetLayer() 217 | #print(layer.GetFeatureCount()) 218 | building_id = 0 219 | # check if geoJsonisEmpty 220 | feautureList = [] 221 | if not image_id: 222 | image_id = inputraster.split('/')[-1].replace(".tif", "") 223 | 224 | if layer.GetFeatureCount() > 0: 225 | 226 | if len(inputraster)>0: 227 | if os.path.isfile(inputraster): 228 | srcRaster = gdal.Open(inputraster) 229 | targetSR = osr.SpatialReference() 230 | targetSR.ImportFromWkt(srcRaster.GetProjectionRef()) 231 | geomTransform = srcRaster.GetGeoTransform() 232 | 233 | featureId = 0 234 | for feature in layer: 235 | 236 | geom = feature.GetGeometryRef() 237 | road_id = feature.GetField('road_id') 238 | featureName = 'roads' 239 | if len(inputraster)>0: 240 | ## Calculate 3 band 241 | geom_wkt_list = geoWKTToPixelWKT(geom, inputraster, targetSR, geomTransform,breakMultiGeo, 242 | pixPrecision=pixPrecision) 243 | 244 | for geom_wkt in geom_wkt_list: 245 | featureId += 1 246 | feautureList.append({'ImageId': image_id, 247 | 'RoadId': road_id, 248 | 'lineGeo': ogr.CreateGeometryFromWkt(geom_wkt[1]), 249 | 'linePix': ogr.CreateGeometryFromWkt(geom_wkt[0]), 250 | 'featureName': featureName, 251 | 'featureIdNum': featureId 252 | }) 253 | else: 254 | featureId += 1 255 | feautureList.append({'ImageId': image_id, 256 | 'RoadId': road_id, 257 | 'lineGeo': ogr.CreateGeometryFromWkt(geom.ExportToWkt()), 258 | 'linePix': ogr.CreateGeometryFromWkt('LINESTRING EMPTY'), 259 | 'featureName' : featureName, 260 | 'featureIdNum': featureId 261 | }) 262 | else: 263 | #print("no File exists") 264 | pass 265 | if pixelgeojson: 266 | exporttogeojson(os.path.join(pixelgeojson_path,image_id+'.geojson'), buildinglist=feautureList) 267 | 268 | return feautureList 269 | 270 | def 
exporttogeojson(geojsonfilename, buildinglist): 271 | # 272 | # geojsonname should end with .geojson 273 | # building list should be list of dictionaries 274 | # list of Dictionaries {'ImageId': image_id, 'RoadID': road_id, 'linePix': poly, 275 | # 'lineGeo': poly} 276 | # image_id is a string, 277 | # BuildingId is an integer, 278 | # poly is a ogr.Geometry Polygon 279 | # 280 | # returns geojsonfilename 281 | 282 | # print geojsonfilename 283 | driver = ogr.GetDriverByName('geojson') 284 | if os.path.exists(geojsonfilename): 285 | driver.DeleteDataSource(geojsonfilename) 286 | datasource = driver.CreateDataSource(geojsonfilename) 287 | layer = datasource.CreateLayer('roads', geom_type=ogr.wkbLineString) 288 | field_name = ogr.FieldDefn("ImageId", ogr.OFTString) 289 | field_name.SetWidth(75) 290 | layer.CreateField(field_name) 291 | layer.CreateField(ogr.FieldDefn("RoadId", ogr.OFTInteger)) 292 | 293 | # loop through buildings 294 | for building in buildinglist: 295 | # create feature 296 | feature = ogr.Feature(layer.GetLayerDefn()) 297 | feature.SetField("ImageId", building['ImageId']) 298 | feature.SetField("RoadId", building['RoadId']) 299 | feature.SetGeometry(building['linePix']) 300 | 301 | # Create the feature in the layer (geojson) 302 | layer.CreateFeature(feature) 303 | # Destroy the feature to free resources 304 | feature.Destroy() 305 | 306 | datasource.Destroy() 307 | 308 | return geojsonfilename 309 | 310 | def ConvertTo8BitImage(srcFileName,outFileDir,outputFormat='GTiff'): 311 | 312 | outputPixType='Byte' 313 | srcRaster = gdal.Open(srcFileName) 314 | outputRaster = os.path.join(outFileDir, srcFileName.split('/')[-1]) 315 | xmlFileName = outputRaster.replace('.tif','.xml') 316 | 317 | cmd = ['gdal_translate', '-ot', outputPixType, '-of', outputFormat, '-co', '"PHOTOMETRIC=rgb"'] 318 | scaleList = [] 319 | for bandId in range(srcRaster.RasterCount): 320 | bandId = bandId+1 321 | band=srcRaster.GetRasterBand(bandId) 322 | min = band.GetMinimum() 323 | max = band.GetMaximum() 324 | 325 | # if not exist minimum and maximum values 326 | if min is None or max is None: 327 | (min, max) = band.ComputeRasterMinMax(1) 328 | cmd.append('-scale_{}'.format(bandId)) 329 | cmd.append('{}'.format(0)) 330 | cmd.append('{}'.format(max)) 331 | cmd.append('{}'.format(0)) 332 | cmd.append('{}'.format(255)) 333 | 334 | cmd.append(srcFileName) 335 | 336 | if outputFormat == 'JPEG': 337 | outputRaster = xmlFileName.replace('.xml', '.jpg') 338 | else: 339 | outputRaster = xmlFileName.replace('.xml', '.tif') 340 | 341 | outputRaster = outputRaster.replace('_img', '_8bit_img') 342 | 343 | cmd.append(outputRaster) 344 | print(' '.join(cmd)) 345 | subprocess.call(cmd) 346 | 347 | def prettify(elem): 348 | """Return a pretty-printed XML string for the Element. 
349 | """ 350 | rough_string = ElementTree.tostring(elem, 'utf-8') 351 | reparsed = minidom.parseString(rough_string) 352 | return reparsed.toprettyxml(indent=" ") 353 | 354 | def ConvertToRoadSegmentation(tif_file,geojson_file,out_file,isInstance=False): 355 | 356 | #Read Dataset from geo json file 357 | dataset = ogr.Open(geojson_file) 358 | if not dataset: 359 | print 'No Dataset' 360 | return -1 361 | layer = dataset.GetLayerByIndex(0) 362 | if not layer: 363 | print 'No Layer' 364 | return -1 365 | 366 | # First we will open our raster image, to understand how we will want to rasterize our vector 367 | raster_ds = gdal.Open(tif_file, gdal.GA_ReadOnly) 368 | 369 | # Fetch number of rows and columns 370 | ncol = raster_ds.RasterXSize 371 | nrow = raster_ds.RasterYSize 372 | 373 | # Fetch projection and extent 374 | proj = raster_ds.GetProjectionRef() 375 | ext = raster_ds.GetGeoTransform() 376 | 377 | raster_ds = None 378 | 379 | # Create the raster dataset 380 | memory_driver = gdal.GetDriverByName('GTiff') 381 | out_raster_ds = memory_driver.Create(out_file, ncol, nrow, 1, gdal.GDT_Byte) 382 | 383 | # Set the ROI image's projection and extent to our input raster's projection and extent 384 | out_raster_ds.SetProjection(proj) 385 | out_raster_ds.SetGeoTransform(ext) 386 | 387 | # Fill our output band with the 0 blank, no class label, value 388 | b = out_raster_ds.GetRasterBand(1) 389 | 390 | if isInstance: 391 | b.Fill(0) 392 | # Rasterize the shapefile layer to our new dataset 393 | status = gdal.RasterizeLayer(out_raster_ds, # output to our new dataset 394 | [1], # output to our new dataset's first band 395 | layer, # rasterize this layer 396 | None, None, # don't worry about transformations since we're in same projection 397 | [0], # burn value 0 398 | ['ALL_TOUCHED=TRUE', # rasterize all pixels touched by polygons 399 | 'ATTRIBUTE=road_type'] # put raster values according to the 'id' field values 400 | ) 401 | else: 402 | b.Fill(0) 403 | # Rasterize the shapefile layer to our new dataset 404 | status = gdal.RasterizeLayer(out_raster_ds, # output to our new dataset 405 | [1], # output to our new dataset's first band 406 | layer, # rasterize this layer 407 | None, None, # don't worry about transformations since we're in same projection 408 | [255] # burn value 0 409 | ) 410 | 411 | # Close dataset 412 | out_raster_ds = None 413 | 414 | return status 415 | 416 | def CreateEmptyTif(tif_file,out_file): 417 | 418 | #Read Dataset from geo json file 419 | layer = None 420 | 421 | # First we will open our raster image, to understand how we will want to rasterize our vector 422 | raster_ds = gdal.Open(tif_file, gdal.GA_ReadOnly) 423 | 424 | # Fetch number of rows and columns 425 | ncol = raster_ds.RasterXSize 426 | nrow = raster_ds.RasterYSize 427 | 428 | # Fetch projection and extent 429 | proj = raster_ds.GetProjectionRef() 430 | ext = raster_ds.GetGeoTransform() 431 | 432 | raster_ds = None 433 | 434 | # Create the raster dataset 435 | memory_driver = gdal.GetDriverByName('GTiff') 436 | out_raster_ds = memory_driver.Create(out_file, ncol, nrow, 1, gdal.GDT_Byte) 437 | 438 | # Set the ROI image's projection and extent to our input raster's projection and extent 439 | out_raster_ds.SetProjection(proj) 440 | out_raster_ds.SetGeoTransform(ext) 441 | 442 | # Fill our output band with the 0 blank, no class label, value 443 | b = out_raster_ds.GetRasterBand(1) 444 | 445 | b.Fill(0) 446 | # Rasterize the shapefile layer to our new dataset 447 | status = gdal.RasterizeLayer(out_raster_ds, # output 
to our new dataset 448 | [1], # output to our new dataset's first band 449 | #None, # rasterize this layer 450 | None, None, # don't worry about transformations since we're in same projection 451 | [255] # burn value 0 452 | ) 453 | 454 | # Close dataset 455 | out_raster_ds = None 456 | 457 | return status -------------------------------------------------------------------------------- /road_dataset.py: -------------------------------------------------------------------------------- 1 | import collections 2 | import math 3 | import os 4 | import random 5 | 6 | import cv2 7 | import numpy as np 8 | import torch 9 | from data_utils import affinity_utils 10 | from torch.utils import data 11 | 12 | 13 | class RoadDataset(data.Dataset): 14 | def __init__( 15 | self, config, dataset_name, seed=7, multi_scale_pred=True, is_train=True 16 | ): 17 | # Seed 18 | np.random.seed(seed) 19 | torch.manual_seed(seed) 20 | random.seed(seed) 21 | 22 | self.split = "train" if is_train else "val" 23 | self.config = config 24 | # paths 25 | self.dir = self.config[dataset_name]["dir"] 26 | 27 | self.img_root = os.path.join(self.dir, "images/") 28 | self.gt_root = os.path.join(self.dir, "gt/") 29 | self.image_list = self.config[dataset_name]["file"] 30 | 31 | # list of all images 32 | self.images = [line.rstrip("\n") for line in open(self.image_list)] 33 | 34 | # augmentations 35 | self.augmentation = self.config["augmentation"] 36 | self.crop_size = [ 37 | self.config[dataset_name]["crop_size"], 38 | self.config[dataset_name]["crop_size"], 39 | ] 40 | self.multi_scale_pred = multi_scale_pred 41 | 42 | # preprocess 43 | self.angle_theta = self.config["angle_theta"] 44 | self.mean_bgr = np.array(eval(self.config["mean"])) 45 | self.deviation_bgr = np.array(eval(self.config["std"])) 46 | self.normalize_type = self.config["normalize_type"] 47 | 48 | # to avoid Deadloack between CV Threads and Pytorch Threads caused in resizing 49 | cv2.setNumThreads(0) 50 | 51 | self.files = collections.defaultdict(list) 52 | for f in self.images: 53 | self.files[self.split].append( 54 | { 55 | "img": self.img_root 56 | + f 57 | + self.config[dataset_name]["image_suffix"], 58 | "lbl": self.gt_root + f + self.config[dataset_name]["gt_suffix"], 59 | } 60 | ) 61 | 62 | def __len__(self): 63 | return len(self.files[self.split]) 64 | 65 | def getRoadData(self, index): 66 | 67 | image_dict = self.files[self.split][index] 68 | # read each image in list 69 | if os.path.isfile(image_dict["img"]): 70 | image = cv2.imread(image_dict["img"]).astype(np.float) 71 | else: 72 | print("ERROR: couldn't find image -> ", image_dict["img"]) 73 | 74 | if os.path.isfile(image_dict["lbl"]): 75 | gt = cv2.imread(image_dict["lbl"], 0).astype(np.float) 76 | else: 77 | print("ERROR: couldn't find image -> ", image_dict["lbl"]) 78 | 79 | if self.split == "train": 80 | image, gt = self.random_crop(image, gt, self.crop_size) 81 | else: 82 | image = cv2.resize( 83 | image, 84 | (self.crop_size[0], self.crop_size[1]), 85 | interpolation=cv2.INTER_LINEAR, 86 | ) 87 | gt = cv2.resize( 88 | gt, 89 | (self.crop_size[0], self.crop_size[1]), 90 | interpolation=cv2.INTER_LINEAR, 91 | ) 92 | 93 | if self.split == "train" and index == len(self.files[self.split]) - 1: 94 | np.random.shuffle(self.files[self.split]) 95 | 96 | h, w, c = image.shape 97 | if self.augmentation == 1: 98 | flip = np.random.choice(2) * 2 - 1 99 | image = np.ascontiguousarray(image[:, ::flip, :]) 100 | gt = np.ascontiguousarray(gt[:, ::flip]) 101 | rotation = np.random.randint(4) * 90 102 | M = 
cv2.getRotationMatrix2D((w / 2, h / 2), rotation, 1) 103 | image = cv2.warpAffine(image, M, (w, h)) 104 | gt = cv2.warpAffine(gt, M, (w, h)) 105 | 106 | image = self.reshape(image) 107 | image = torch.from_numpy(np.array(image)) 108 | 109 | return image, gt 110 | 111 | def getOrientationGT(self, keypoints, height, width): 112 | vecmap, vecmap_angles = affinity_utils.getVectorMapsAngles( 113 | (height, width), keypoints, theta=self.angle_theta, bin_size=10 114 | ) 115 | vecmap_angles = torch.from_numpy(vecmap_angles) 116 | 117 | return vecmap_angles 118 | 119 | def getCorruptRoad( 120 | self, road_gt, height, width, artifacts_shape="linear", element_counts=8 121 | ): 122 | # False Negative Mask 123 | FNmask = np.ones((height, width), np.float) 124 | # False Positive Mask 125 | FPmask = np.zeros((height, width), np.float) 126 | indices = np.where(road_gt == 1) 127 | 128 | if artifacts_shape == "square": 129 | shapes = [[16, 16], [32, 32]] 130 | ##### FNmask 131 | if len(indices[0]) == 0: ### no road pixel in GT 132 | pass 133 | else: 134 | for c_ in range(element_counts): 135 | c = np.random.choice(len(shapes), 1)[ 136 | 0 137 | ] ### choose random square size 138 | shape_ = shapes[c] 139 | ind = np.random.choice(len(indices[0]), 1)[ 140 | 0 141 | ] ### choose a random road pixel as center for the square 142 | row = indices[0][ind] 143 | col = indices[1][ind] 144 | 145 | FNmask[ 146 | row - shape_[0] / 2 : row + shape_[0] / 2, 147 | col - shape_[1] / 2 : col + shape_[1] / 2, 148 | ] = 0 149 | #### FPmask 150 | for c_ in range(element_counts): 151 | c = np.random.choice(len(shapes), 2)[0] ### choose random square size 152 | shape_ = shapes[c] 153 | row = np.random.choice(height - shape_[0] - 1, 1)[ 154 | 0 155 | ] ### choose random pixel 156 | col = np.random.choice(width - shape_[1] - 1, 1)[ 157 | 0 158 | ] ### choose random pixel 159 | FPmask[ 160 | row - shape_[0] / 2 : row + shape_[0] / 2, 161 | col - shape_[1] / 2 : col + shape_[1] / 2, 162 | ] = 1 163 | 164 | elif artifacts_shape == "linear": 165 | ##### FNmask 166 | if len(indices[0]) == 0: ### no road pixel in GT 167 | pass 168 | else: 169 | for c_ in range(element_counts): 170 | c1 = np.random.choice(len(indices[0]), 1)[ 171 | 0 172 | ] ### choose random 2 road pixels to draw a line 173 | c2 = np.random.choice(len(indices[0]), 1)[0] 174 | cv2.line( 175 | FNmask, 176 | (indices[1][c1], indices[0][c1]), 177 | (indices[1][c2], indices[0][c2]), 178 | 0, 179 | self.angle_theta * 2, 180 | ) 181 | #### FPmask 182 | for c_ in range(element_counts): 183 | row1 = np.random.choice(height, 1) 184 | col1 = np.random.choice(width, 1) 185 | row2, col2 = ( 186 | row1 + np.random.choice(50, 1), 187 | col1 + np.random.choice(50, 1), 188 | ) 189 | cv2.line(FPmask, (col1, row1), (col2, row2), 1, self.angle_theta * 2) 190 | 191 | erased_gt = (road_gt * FNmask) + FPmask 192 | erased_gt[erased_gt > 0] = 1 193 | 194 | return erased_gt 195 | 196 | def reshape(self, image): 197 | 198 | if self.normalize_type == "Std": 199 | image = (image - self.mean_bgr) / (3 * self.deviation_bgr) 200 | elif self.normalize_type == "MinMax": 201 | image = (image - self.min_bgr) / (self.max_bgr - self.min_bgr) 202 | image = image * 2 - 1 203 | elif self.normalize_type == "Mean": 204 | image -= self.mean_bgr 205 | else: 206 | image = (image / 255.0) * 2 - 1 207 | 208 | image = image.transpose(2, 0, 1) 209 | return image 210 | 211 | def random_crop(self, image, gt, size): 212 | 213 | w, h, _ = image.shape 214 | crop_h, crop_w = size 215 | 216 | start_x = np.random.randint(0, w - 
crop_w) 217 | start_y = np.random.randint(0, h - crop_h) 218 | 219 | image = image[start_x : start_x + crop_w, start_y : start_y + crop_h, :] 220 | gt = gt[start_x : start_x + crop_w, start_y : start_y + crop_h] 221 | 222 | return image, gt 223 | 224 | 225 | class SpacenetDataset(RoadDataset): 226 | def __init__(self, config, seed=7, multi_scale_pred=True, is_train=True): 227 | super(SpacenetDataset, self).__init__( 228 | config, "spacenet", seed, multi_scale_pred, is_train 229 | ) 230 | 231 | # preprocess 232 | self.threshold = self.config["thresh"] 233 | print("Threshold is set to {} for {}".format(self.threshold, self.split)) 234 | 235 | def __getitem__(self, index): 236 | 237 | image, gt = self.getRoadData(index) 238 | c, h, w = image.shape 239 | 240 | labels = [] 241 | vecmap_angles = [] 242 | if self.multi_scale_pred: 243 | smoothness = [1, 2, 4] 244 | scale = [4, 2, 1] 245 | else: 246 | smoothness = [4] 247 | scale = [1] 248 | 249 | for i, val in enumerate(scale): 250 | if val != 1: 251 | gt_ = cv2.resize( 252 | gt, 253 | (int(math.ceil(h / (val * 1.0))), int(math.ceil(w / (val * 1.0)))), 254 | interpolation=cv2.INTER_NEAREST, 255 | ) 256 | else: 257 | gt_ = gt 258 | 259 | gt_orig = np.copy(gt_) 260 | gt_orig /= 255.0 261 | gt_orig[gt_orig < self.threshold] = 0 262 | gt_orig[gt_orig >= self.threshold] = 1 263 | labels.append(gt_orig) 264 | 265 | keypoints = affinity_utils.getKeypoints( 266 | gt_, thresh=0.98, smooth_dist=smoothness[i] 267 | ) 268 | vecmap_angle = self.getOrientationGT( 269 | keypoints, 270 | height=int(math.ceil(h / (val * 1.0))), 271 | width=int(math.ceil(w / (val * 1.0))), 272 | ) 273 | vecmap_angles.append(vecmap_angle) 274 | 275 | return image, labels, vecmap_angles 276 | 277 | 278 | class DeepGlobeDataset(RoadDataset): 279 | def __init__(self, config, seed=7, multi_scale_pred=True, is_train=True): 280 | super(DeepGlobeDataset, self).__init__( 281 | config, "deepglobe", seed, multi_scale_pred, is_train 282 | ) 283 | 284 | pass 285 | 286 | def __getitem__(self, index): 287 | 288 | image, gt = self.getRoadData(index) 289 | c, h, w = image.shape 290 | 291 | labels = [] 292 | vecmap_angles = [] 293 | if self.multi_scale_pred: 294 | smoothness = [1, 2, 4] 295 | scale = [4, 2, 1] 296 | else: 297 | smoothness = [4] 298 | scale = [1] 299 | 300 | for i, val in enumerate(scale): 301 | if val != 1: 302 | gt_ = cv2.resize( 303 | gt, 304 | (int(math.ceil(h / (val * 1.0))), int(math.ceil(w / (val * 1.0)))), 305 | interpolation=cv2.INTER_NEAREST, 306 | ) 307 | else: 308 | gt_ = gt 309 | 310 | gt_orig = np.copy(gt_) 311 | gt_orig /= 255.0 312 | labels.append(gt_orig) 313 | 314 | # Create Orientation Ground Truth 315 | keypoints = affinity_utils.getKeypoints( 316 | gt_orig, is_gaussian=False, smooth_dist=smoothness[i] 317 | ) 318 | vecmap_angle = self.getOrientationGT( 319 | keypoints, 320 | height=int(math.ceil(h / (val * 1.0))), 321 | width=int(math.ceil(w / (val * 1.0))), 322 | ) 323 | vecmap_angles.append(vecmap_angle) 324 | 325 | return image, labels, vecmap_angles 326 | 327 | 328 | class SpacenetDatasetCorrupt(RoadDataset): 329 | def __init__(self, config, seed=7, is_train=True): 330 | super(SpacenetDatasetCorrupt, self).__init__( 331 | config, "spacenet", seed, multi_scale_pred=False, is_train=is_train 332 | ) 333 | 334 | # preprocess 335 | self.threshold = self.config["thresh"] 336 | print("Threshold is set to {} for {}".format(self.threshold, self.split)) 337 | 338 | def __getitem__(self, index): 339 | 340 | image, gt = self.getRoadData(index) 341 | c, h, w = 
image.shape
342 |         gt /= 255.0
343 |         gt[gt < self.threshold] = 0
344 |         gt[gt >= self.threshold] = 1
345 |
346 |         erased_gt = self.getCorruptRoad(gt.copy(), h, w)
347 |         erased_gt = torch.from_numpy(erased_gt)
348 |
349 |         return image, [gt], [erased_gt]
350 |
351 |
352 | class DeepGlobeDatasetCorrupt(RoadDataset):
353 |     def __init__(self, config, seed=7, is_train=True):
354 |         super(DeepGlobeDatasetCorrupt, self).__init__(
355 |             config, "deepglobe", seed, multi_scale_pred=False, is_train=is_train
356 |         )
357 |
358 |         pass
359 |
360 |     def __getitem__(self, index):
361 |
362 |         image, gt = self.getRoadData(index)
363 |         c, h, w = image.shape
364 |         gt /= 255.0
365 |
366 |         erased_gt = self.getCorruptRoad(gt, h, w)
367 |         erased_gt = torch.from_numpy(erased_gt)
368 |
369 |         return image, [gt], [erased_gt]
370 |
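The RoadDataset subclasses above are consumed through standard PyTorch DataLoaders in train_mtl.py and train_refine_pre.py. A minimal usage sketch (the config keys mirror those scripts; the concrete config.json contents are assumptions):

    import json
    from torch.utils import data
    from road_dataset import SpacenetDataset

    config = json.load(open("config.json"))
    train_set = SpacenetDataset(
        config["train_dataset"], seed=config["seed"], is_train=True, multi_scale_pred=True
    )
    train_loader = data.DataLoader(train_set, batch_size=config["train_batch_size"], shuffle=True)
    # one sample: image tensor, road masks per scale, orientation-bin maps per scale
    image, labels, vecmap_angles = train_set[0]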
--------------------------------------------------------------------------------
/split_data.sh:
--------------------------------------------------------------------------------
1 | # Deepglobe
2 | # image_postfix = "_sat.jpg"
3 | # gt_postfix = "_mask.png"
4 |
5 | # Spacenet
6 | # image_postfix = ".png"
7 | # gt_postfix = ".png"
8 |
9 | full_train_dir=$1
10 | base_dir=$2
11 | image_postfix=$3
12 | gt_postfix=$4
13 |
14 | train_image="$base_dir/train/images/"
15 | train_gt="$base_dir/train/gt/"
16 | val_image="$base_dir/val/images/"
17 | val_gt="$base_dir/val/gt/"
18 |
19 | RED='\033[0;31m'
20 | BLUE='\033[0;34m'
21 | GREEN='\033[0;32m'
22 | NC='\033[0m'
23 | BOLD='\033[1m'
24 | UNDERLINE='\033[4m'
25 |
26 | printf "${BLUE}${BOLD}Full Data Dir => %s \n${NC}" $full_train_dir
27 | printf "${BLUE}${BOLD}Split Data Dir => %s \n${NC}" $base_dir
28 | printf "${BLUE}${BOLD}Split Sub Dir => %s \n${NC}" $train_image $train_gt $val_image $val_gt
29 |
30 | printf "${GREEN}${UNDERLINE}${BOLD}\n Creating folder structure. ${NC}\n"
31 | mkdir -p $train_image $train_gt $val_image $val_gt
32 |
33 | printf "${GREEN}${UNDERLINE}${BOLD}\n Splitting Data. ${NC}\n"
34 |
35 | i=1
36 | sp="/-\|"
37 | echo -n ' '
38 | while read -r line
39 | do
40 |     cp "$full_train_dir/images/$line$image_postfix" "$train_image"
41 |     cp "$full_train_dir/gt/$line$gt_postfix" "$train_gt"
42 |     printf "\r${sp:i++%${#sp}:1} Copying training data."
43 | done < "$base_dir/train.txt"
44 |
45 | i=1
46 | while read -r line
47 | do
48 |     cp "$full_train_dir/images/$line$image_postfix" "$val_image"
49 |     cp "$full_train_dir/gt/$line$gt_postfix" "$val_gt"
50 |     printf "\r${sp:i++%${#sp}:1} Copying validation data."
51 | done < "$base_dir/val.txt"
52 |
53 | printf "\n${GREEN}${BOLD}\n Finished Split. ${NC}\n"
54 |
55 |
56 |
57 |
58 |
--------------------------------------------------------------------------------
/train_mtl.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function
2 |
3 | import argparse
4 | import json
5 | import os
6 | from datetime import datetime
7 |
8 | import numpy as np
9 | import torch
10 | import torch.nn as nn
11 | import torch.nn.functional as F
12 | import torch.optim as optim
13 | import torch.utils.data as data
14 | from model.models import MODELS
15 | from road_dataset import DeepGlobeDataset, SpacenetDataset
16 | from torch.autograd import Variable
17 | from torch.optim.lr_scheduler import MultiStepLR
18 | from utils.loss import CrossEntropyLoss2d, mIoULoss
19 | from utils import util
20 | from utils import viz_util
21 |
22 |
23 | __dataset__ = {"spacenet": SpacenetDataset, "deepglobe": DeepGlobeDataset}
24 |
25 |
26 | parser = argparse.ArgumentParser()
27 | parser.add_argument(
28 |     "--config", required=True, type=str, help="config file path"
29 | )
30 | parser.add_argument(
31 |     "--model_name",
32 |     required=True,
33 |     choices=sorted(MODELS.keys()),
34 |     help="Name of Model = {}".format(MODELS.keys()),
35 | )
36 | parser.add_argument("--exp", required=True, type=str, help="Experiment Name/Directory")
37 | parser.add_argument(
38 |     "--resume", default=None, type=str, help="path to latest checkpoint (default: None)"
39 | )
40 | parser.add_argument(
41 |     "--dataset",
42 |     required=True,
43 |     choices=sorted(__dataset__.keys()),
44 |     help="select dataset name from {}. (default: Spacenet)".format(__dataset__.keys()),
45 | )
46 | parser.add_argument(
47 |     "--model_kwargs",
48 |     default={},
49 |     type=json.loads,
50 |     help="parameters for the model",
51 | )
52 | parser.add_argument(
53 |     "--multi_scale_pred",
54 |     default=True,
55 |     type=util.str2bool,
56 |     help="perform multi-scale prediction (default: True)",
57 | )
58 |
59 | args = parser.parse_args()
60 | config = None
61 |
62 | if args.resume is not None:
63 |     if args.config is not None:
64 |         print("Warning: --config overridden by --resume")
65 |     config = torch.load(args.resume)["config"]
66 | elif args.config is not None:
67 |     config = json.load(open(args.config))
68 |
69 | assert config is not None
70 |
71 | util.setSeed(config)
72 |
73 | experiment_dir = os.path.join(config["trainer"]["save_dir"], args.exp)
74 | util.ensure_dir(experiment_dir)
75 |
76 | ### Logging Files
77 | train_file = "{}/{}_train_loss.txt".format(experiment_dir, args.dataset)
78 | test_file = "{}/{}_test_loss.txt".format(experiment_dir, args.dataset)
79 |
80 | train_loss_file = open(train_file, "w", 0)
81 | val_loss_file = open(test_file, "w", 0)
82 |
83 | ### Angle Metrics
84 | train_file_angle = "{}/{}_train_angle_loss.txt".format(experiment_dir, args.dataset)
85 | test_file_angle = "{}/{}_test_angle_loss.txt".format(experiment_dir, args.dataset)
86 |
87 | train_loss_angle_file = open(train_file_angle, "w", 0)
88 | val_loss_angle_file = open(test_file_angle, "w", 0)
89 | ################################################################################
90 | num_gpus = torch.cuda.device_count()
91 |
92 | model = MODELS[args.model_name](
93 |     config["task1_classes"], config["task2_classes"], **args.model_kwargs
94 | )
95 |
96 | if num_gpus > 1:
97 |     print("Training with multiple GPUs ({})".format(num_gpus))
98 |     model = nn.DataParallel(model).cuda()
99 | else:
100 |     print("Single Cuda Node is available")
101 |     model.cuda()
102 | ################################################################################
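# Note: with multiple GPUs the network is wrapped in nn.DataParallel above, so
# its attributes must be reached through model.module; the loss loops below rely
# on exactly this distinction:
#     num_stacks = model.module.num_stacks if num_gpus > 1 else model.num_stacks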
103 |
104 | ### Load Dataset from root folder and initialize DataLoader
105 | train_loader = data.DataLoader(
106 |     __dataset__[args.dataset](
107 |         config["train_dataset"],
108 |         seed=config["seed"],
109 |         is_train=True,
110 |         multi_scale_pred=args.multi_scale_pred,
111 |     ),
112 |     batch_size=config["train_batch_size"],
113 |     num_workers=8,
114 |     shuffle=True,
115 |     pin_memory=False,
116 | )
117 |
118 | val_loader = data.DataLoader(
119 |     __dataset__[args.dataset](
120 |         config["val_dataset"],
121 |         seed=config["seed"],
122 |         is_train=False,
123 |         multi_scale_pred=args.multi_scale_pred,
124 |     ),
125 |     batch_size=config["val_batch_size"],
126 |     num_workers=8,
127 |     shuffle=True,
128 |     pin_memory=False,
129 | )
130 |
131 | print("Training with dataset => {}".format(train_loader.dataset.__class__.__name__))
132 | ################################################################################
133 |
134 | best_accuracy = 0
135 | best_miou = 0
136 | start_epoch = 1
137 | total_epochs = config["trainer"]["total_epochs"]
138 | optimizer = optim.SGD(
139 |     model.parameters(), lr=config["optimizer"]["lr"], momentum=0.9, weight_decay=0.0005
140 | )
141 |
142 | if args.resume is not None:
143 |     print("Loading from existing FCN and copying weights to continue....")
144 |     checkpoint = torch.load(args.resume)
145 |     start_epoch = checkpoint["epoch"] + 1
146 |     best_miou = checkpoint["miou"]
147 |     # stat_parallel_dict = util.getParllelNetworkStateDict(checkpoint['state_dict'])
148 |     # model.load_state_dict(stat_parallel_dict)
149 |     model.load_state_dict(checkpoint["state_dict"])
150 |     optimizer.load_state_dict(checkpoint["optimizer"])
151 | else:
152 |     util.weights_init(model, manual_seed=config["seed"])
153 |
154 | viz_util.summary(model, print_arch=False)
155 |
156 | scheduler = MultiStepLR(
157 |     optimizer,
158 |     milestones=eval(config["optimizer"]["lr_drop_epoch"]),
159 |     gamma=config["optimizer"]["lr_step"],
160 | )
161 |
162 |
163 | weights = torch.ones(config["task1_classes"]).cuda()
164 | if config["task1_weight"] < 1:
165 |     print("Roads are weighted.")
166 |     weights[0] = 1 - config["task1_weight"]
167 |     weights[1] = config["task1_weight"]
168 |
169 |
170 | weights_angles = torch.ones(config["task2_classes"]).cuda()
171 | if config["task2_weight"] < 1:
172 |     print("Road angles are weighted.")
173 |     weights_angles[-1] = config["task2_weight"]
174 |
175 |
176 | angle_loss = CrossEntropyLoss2d(
177 |     weight=weights_angles, size_average=True, ignore_index=255, reduce=True
178 | ).cuda()
179 | road_loss = mIoULoss(
180 |     weight=weights, size_average=True, n_classes=config["task1_classes"]
181 | ).cuda()
182 |
183 |
184 | def train(epoch):
185 |     train_loss_iou = 0
186 |     train_loss_vec = 0
187 |     model.train()
188 |     optimizer.zero_grad()
189 |     hist = np.zeros((config["task1_classes"], config["task1_classes"]))
190 |     hist_angles = np.zeros((config["task2_classes"], config["task2_classes"]))
191 |     crop_size = config["train_dataset"][args.dataset]["crop_size"]
192 |     for i, data in enumerate(train_loader, 0):
193 |         inputsBGR, labels, vecmap_angles = data
194 |         inputsBGR = Variable(inputsBGR.float().cuda())
195 |         outputs, pred_vecmaps = model(inputsBGR)
196 |
197 |         if args.multi_scale_pred:
198 |             loss1 = road_loss(outputs[0], util.to_variable(labels[0]), False)
199 |             num_stacks = model.module.num_stacks if num_gpus > 1 else model.num_stacks
200 |             for idx in range(num_stacks - 1):
201 |                 loss1 += road_loss(outputs[idx + 1], util.to_variable(labels[0]), False)
202 |             for idx, output in enumerate(outputs[-2:]):
203 |                 loss1 += road_loss(output, util.to_variable(labels[idx + 1]), False)
204 |
205 |             loss2 = angle_loss(pred_vecmaps[0], util.to_variable(vecmap_angles[0]))
206 |             for idx in range(num_stacks - 1):
207 |                 loss2 += angle_loss(
208 |                     pred_vecmaps[idx + 1], util.to_variable(vecmap_angles[0])
209 |                 )
210 |             for idx, pred_vecmap in enumerate(pred_vecmaps[-2:]):
211 |                 loss2 += angle_loss(pred_vecmap, util.to_variable(vecmap_angles[idx + 1]))
212 |
213 |             outputs = outputs[-1]
214 |             pred_vecmaps = pred_vecmaps[-1]
215 |         else:
216 |             loss1 = road_loss(outputs, util.to_variable(labels[-1]), False)
217 |             loss2 = angle_loss(pred_vecmaps, util.to_variable(vecmap_angles[-1]))
218 |
219 |         train_loss_iou += loss1.data[0]
220 |         train_loss_vec += loss2.data[0]
221 |
222 |         _, predicted = torch.max(outputs.data, 1)
223 |
224 |         correctLabel = labels[-1].view(-1, crop_size, crop_size).long()
225 |         hist += util.fast_hist(
226 |             predicted.view(predicted.size(0), -1).cpu().numpy(),
227 |             correctLabel.view(correctLabel.size(0), -1).cpu().numpy(),
228 |             config["task1_classes"],
229 |         )
230 |
231 |         _, predicted_angle = torch.max(pred_vecmaps.data, 1)
232 |         correct_angles = vecmap_angles[-1].view(-1, crop_size, crop_size).long()
233 |         hist_angles += util.fast_hist(
234 |             predicted_angle.view(predicted_angle.size(0), -1).cpu().numpy(),
235 |             correct_angles.view(correct_angles.size(0), -1).cpu().numpy(),
236 |             config["task2_classes"],
237 |         )
238 |
239 |         p_accu, miou, road_iou, fwacc = util.performMetrics(
240 |             train_loss_file,
241 |             val_loss_file,
242 |             epoch,
243 |             hist,
244 |             train_loss_iou / (i + 1),
245 |             train_loss_vec / (i + 1),
246 |         )
247 |         p_accu_angle, miou_angle, fwacc_angle = util.performAngleMetrics(
248 |             train_loss_angle_file, val_loss_angle_file, epoch, hist_angles
249 |         )
250 |
251 |         viz_util.progress_bar(
252 |             i,
253 |             len(train_loader),
254 |             "Loss: %.6f | VecLoss: %.6f | road miou: %.4f%%(%.4f%%) | angle miou: %.4f%% "
255 |             % (
256 |                 train_loss_iou / (i + 1),
257 |                 train_loss_vec / (i + 1),
258 |                 miou,
259 |                 road_iou,
260 |                 miou_angle,
261 |             ),
262 |         )
263 |
264 |         torch.autograd.backward([loss1, loss2])
265 |
266 |         if i % config["trainer"]["iter_size"] == 0 or i == len(train_loader) - 1:
267 |             optimizer.step()
268 |             optimizer.zero_grad()
269 |
270 |         del (
271 |             outputs,
272 |             pred_vecmaps,
273 |             predicted,
274 |             correct_angles,
275 |             correctLabel,
276 |             inputsBGR,
277 |             labels,
278 |             vecmap_angles,
279 |         )
280 |
281 |     util.performMetrics(
282 |         train_loss_file,
283 |         val_loss_file,
284 |         epoch,
285 |         hist,
286 |         train_loss_iou / len(train_loader),
287 |         train_loss_vec / len(train_loader),
288 |         write=True,
289 |     )
290 |     util.performAngleMetrics(
291 |         train_loss_angle_file, val_loss_angle_file, epoch, hist_angles, write=True
292 |     )
293 |
294 |
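# Both train() above and test() below accumulate per-pixel confusion matrices
# via util.fast_hist and reduce them with util.performMetrics (utils/util.py is
# not shown in this listing). A minimal sketch of the usual bincount-based
# accumulator and the mIoU it supports; illustrative only, not the repo's
# exact implementation:
#
#     def fast_hist_sketch(pred, gt, n):
#         pred, gt = pred.flatten(), gt.flatten()
#         k = (gt >= 0) & (gt < n)
#         # entry (i, j): number of pixels of true class i predicted as class j
#         return np.bincount(n * gt[k].astype(int) + pred[k], minlength=n ** 2).reshape(n, n)
#
#     # per-class IoU from the accumulated histogram:
#     #     iou = np.diag(hist) / (hist.sum(1) + hist.sum(0) - np.diag(hist))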
295 | def test(epoch):
296 |     global best_accuracy
297 |     global best_miou
298 |     model.eval()
299 |     test_loss_iou = 0
300 |     test_loss_vec = 0
301 |     hist = np.zeros((config["task1_classes"], config["task1_classes"]))
302 |     hist_angles = np.zeros((config["task2_classes"], config["task2_classes"]))
303 |     crop_size = config["val_dataset"][args.dataset]["crop_size"]
304 |     for i, (inputsBGR, labels, vecmap_angles) in enumerate(val_loader, 0):
305 |         inputsBGR = Variable(
306 |             inputsBGR.float().cuda(), volatile=True, requires_grad=False
307 |         )
308 |
309 |         outputs, pred_vecmaps = model(inputsBGR)
310 |         if args.multi_scale_pred:
311 |             loss1 = road_loss(outputs[0], util.to_variable(labels[0], True, False), True)
312 |             num_stacks = model.module.num_stacks if num_gpus > 1 else model.num_stacks
313 |             for idx in range(num_stacks - 1):
314 |                 loss1 += road_loss(outputs[idx + 1], util.to_variable(labels[0], True, False), True)
315 |             for idx, output in enumerate(outputs[-2:]):
316 |                 loss1 += road_loss(output, util.to_variable(labels[idx + 1], True, False), True)
317 |
318 |             loss2 = angle_loss(pred_vecmaps[0], util.to_variable(vecmap_angles[0], True, False))
319 |             for idx in range(num_stacks - 1):
320 |                 loss2 += angle_loss(
321 |                     pred_vecmaps[idx + 1], util.to_variable(vecmap_angles[0], True, False)
322 |                 )
323 |             for idx, pred_vecmap in enumerate(pred_vecmaps[-2:]):
324 |                 loss2 += angle_loss(
325 |                     pred_vecmap, util.to_variable(vecmap_angles[idx + 1], True, False)
326 |                 )
327 |
328 |             outputs = outputs[-1]
329 |             pred_vecmaps = pred_vecmaps[-1]
330 |         else:
331 |             loss1 = road_loss(outputs, util.to_variable(labels[0], True, False), True)
332 |             loss2 = angle_loss(pred_vecmaps, util.to_variable(vecmap_angles[0], True, False))
333 |
334 |         test_loss_iou += loss1.data[0]
335 |         test_loss_vec += loss2.data[0]
336 |
337 |         _, predicted = torch.max(outputs.data, 1)
338 |
339 |         correctLabel = labels[-1].view(-1, crop_size, crop_size).long()
340 |         hist += util.fast_hist(
341 |             predicted.view(predicted.size(0), -1).cpu().numpy(),
342 |             correctLabel.view(correctLabel.size(0), -1).cpu().numpy(),
343 |             config["task1_classes"],
344 |         )
345 |
346 |         _, predicted_angle = torch.max(pred_vecmaps.data, 1)
347 |         correct_angles = vecmap_angles[-1].view(-1, crop_size, crop_size).long()
348 |         hist_angles += util.fast_hist(
349 |             predicted_angle.view(predicted_angle.size(0), -1).cpu().numpy(),
350 |             correct_angles.view(correct_angles.size(0), -1).cpu().numpy(),
351 |             config["task2_classes"],
352 |         )
353 |
354 |         p_accu, miou, road_iou, fwacc = util.performMetrics(
355 |             train_loss_file,
356 |             val_loss_file,
357 |             epoch,
358 |             hist,
359 |             test_loss_iou / (i + 1),
360 |             test_loss_vec / (i + 1),
361 |             is_train=False,
362 |         )
363 |         p_accu_angle, miou_angle, fwacc_angle = util.performAngleMetrics(
364 |             train_loss_angle_file, val_loss_angle_file, epoch, hist_angles, is_train=False
365 |         )
366 |
367 |         viz_util.progress_bar(
368 |             i,
369 |             len(val_loader),
370 |             "Loss: %.6f | VecLoss: %.6f | road miou: %.4f%%(%.4f%%) | angle miou: %.4f%%"
371 |             % (
372 |                 test_loss_iou / (i + 1),
373 |                 test_loss_vec / (i + 1),
374 |                 miou,
375 |                 road_iou,
376 |                 miou_angle,
377 |             ),
378 |         )
379 |
380 |         if i % 100 == 0 or i == len(val_loader) - 1:
381 |             images_path = "{}/images/".format(experiment_dir)
382 |             util.ensure_dir(images_path)
383 |             util.savePredictedProb(
384 |                 inputsBGR.data.cpu(),
385 |                 labels[-1].cpu(),
386 |                 predicted.cpu(),
387 |                 F.softmax(outputs, dim=1).data.cpu()[:, 1, :, :],
388 |                 predicted_angle.cpu(),
389 |                 os.path.join(images_path, "validate_pair_{}_{}.png".format(epoch, i)),
390 |                 norm_type=config["val_dataset"]["normalize_type"],
391 |             )
392 |
393 |         del inputsBGR, labels, predicted, outputs, pred_vecmaps, predicted_angle
394 |
395 |     accuracy, miou, road_iou, fwacc = util.performMetrics(
396 |         train_loss_file,
397 |         val_loss_file,
398 |         epoch,
399 |         hist,
400 |         test_loss_iou / len(val_loader),
401 |         test_loss_vec / len(val_loader),
402 |         is_train=False,
403 |         write=True,
404 |     )
405 |     util.performAngleMetrics(
406 |         train_loss_angle_file,
407 |         val_loss_angle_file,
408 |         epoch,
409 |         hist_angles,
410 |         is_train=False,
411 |         write=True,
412 |     )
413 |
414 |     if miou > best_miou:
415 |         best_accuracy = accuracy
accuracy 416 | best_miou = miou 417 | util.save_checkpoint(epoch, test_loss_iou / len(val_loader), model, optimizer, best_accuracy, best_miou, config, experiment_dir) 418 | 419 | return test_loss_iou / len(val_loader) 420 | 421 | 422 | for epoch in range(start_epoch, total_epochs + 1): 423 | start_time = datetime.now() 424 | scheduler.step(epoch) 425 | print("\nTraining Epoch: %d" % epoch) 426 | train(epoch) 427 | if epoch % config["trainer"]["test_freq"] == 0: 428 | print("\nTesting Epoch: %d" % epoch) 429 | val_loss = test(epoch) 430 | 431 | end_time = datetime.now() 432 | print("Time Elapsed for epoch {0} => {1}".format(epoch, end_time - start_time)) 433 | -------------------------------------------------------------------------------- /train_refine_pre.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | 3 | import argparse 4 | import json 5 | import os 6 | from datetime import datetime 7 | 8 | import numpy as np 9 | import torch 10 | import torch.nn as nn 11 | import torch.nn.functional as F 12 | import torch.optim as optim 13 | import torch.utils.data as data 14 | from model.models import MODELS_REFINE 15 | from road_dataset import DeepGlobeDatasetCorrupt, SpacenetDatasetCorrupt 16 | from torch.autograd import Variable 17 | from torch.optim.lr_scheduler import MultiStepLR 18 | from utils.loss import CrossEntropyLoss2d, mIoULoss 19 | from utils import util 20 | from utils import viz_util 21 | 22 | 23 | __dataset__ = {"spacenet": SpacenetDatasetCorrupt, "deepglobe": DeepGlobeDatasetCorrupt} 24 | 25 | 26 | parser = argparse.ArgumentParser() 27 | parser.add_argument( 28 | "--config", required=True, type=str, help="config file path" 29 | ) 30 | parser.add_argument( 31 | "--model_name", 32 | required=True, 33 | choices=sorted(MODELS_REFINE.keys()), 34 | help="Name of Model = {}".format(MODELS_REFINE.keys()), 35 | ) 36 | parser.add_argument("--exp", required=True, type=str, help="Experiment Name/Directory") 37 | parser.add_argument( 38 | "--resume", default=None, type=str, help="path to latest checkpoint (default: None)" 39 | ) 40 | parser.add_argument( 41 | "--dataset", 42 | required=True, 43 | choices=sorted(__dataset__.keys()), 44 | help="select dataset name from {}.
".format(__dataset__.keys()), 45 | ) 46 | parser.add_argument( 47 | "--model_kwargs", 48 | default={}, 49 | type=json.loads, 50 | help="parameters for the model", 51 | ) 52 | parser.add_argument( 53 | "--multi_scale_pred", 54 | default=False, 55 | type=util.str2bool, 56 | help="perform multi-scale prediction (default: False)", 57 | ) 58 | 59 | args = parser.parse_args() 60 | config = None 61 | 62 | if args.resume is not None: 63 | if args.config is not None: 64 | print("Warning: --config overridden by --resume") 65 | config = torch.load(args.resume)["config"] 66 | elif args.config is not None: 67 | config = json.load(open(args.config)) 68 | 69 | assert config is not None 70 | 71 | util.setSeed(config) 72 | 73 | experiment_dir = os.path.join(config["trainer"]["save_dir"], args.exp) 74 | util.ensure_dir(experiment_dir) 75 | 76 | ### Logging Files 77 | train_file = "{}/{}_train_loss.txt".format(experiment_dir, args.dataset) 78 | test_file = "{}/{}_test_loss.txt".format(experiment_dir, args.dataset) 79 | 80 | train_loss_file = open(train_file, "w", 0) 81 | val_loss_file = open(test_file, "w", 0) 82 | 83 | ### Angle Metrics 84 | train_file_angle = "{}/{}_train_angle_loss.txt".format(experiment_dir, args.dataset) 85 | test_file_angle = "{}/{}_test_angle_loss.txt".format(experiment_dir, args.dataset) 86 | 87 | train_loss_angle_file = open(train_file_angle, "w", 0) 88 | val_loss_angle_file = open(test_file_angle, "w", 0) 89 | ################################################################################ 90 | num_gpus = torch.cuda.device_count() 91 | 92 | model = MODELS_REFINE[args.model_name]( 93 | in_channels=5, num_classes=config["task1_classes"] 94 | ) 95 | 96 | if num_gpus > 1: 97 | print("Training with multiple GPUs ({})".format(num_gpus)) 98 | model = nn.DataParallel(model).cuda() 99 | else: 100 | print("Single CUDA device is available") 101 | model.cuda() 102 | ################################################################################ 103 | 104 | ### Load Dataset from root folder and initialize DataLoader 105 | train_loader = data.DataLoader( 106 | __dataset__[args.dataset]( 107 | config["train_dataset"], 108 | seed=config["seed"], 109 | is_train=True 110 | ), 111 | batch_size=config["train_batch_size"], 112 | num_workers=8, 113 | shuffle=True, 114 | pin_memory=False, 115 | ) 116 | 117 | val_loader = data.DataLoader( 118 | __dataset__[args.dataset]( 119 | config["val_dataset"], 120 | seed=config["seed"], 121 | is_train=False 122 | ), 123 | batch_size=config["val_batch_size"], 124 | num_workers=8, 125 | shuffle=True, 126 | pin_memory=False, 127 | ) 128 | 129 | print("Training with dataset => {}".format(train_loader.dataset.__class__.__name__)) 130 | ################################################################################ 131 | 132 | best_accuracy = 0 133 | best_miou = 0 134 | start_epoch = 1 135 | total_epochs = config["trainer"]["total_epochs"] 136 | optimizer = optim.SGD( 137 | model.parameters(), lr=config["optimizer"]["lr"], momentum=0.9, weight_decay=0.0005 138 | ) 139 | 140 | if args.resume is not None: 141 | print("Loading from existing FCN and copying weights to continue....") 142 | checkpoint = torch.load(args.resume) 143 | start_epoch = checkpoint["epoch"] + 1 144 | best_miou = checkpoint["miou"] 145 | # stat_parallel_dict = util.getParllelNetworkStateDict(checkpoint['state_dict']) 146 | # model.load_state_dict(stat_parallel_dict) 147 | model.load_state_dict(checkpoint["state_dict"]) 148 | optimizer.load_state_dict(checkpoint["optimizer"]) 149 | else:
150 | util.weights_init(model, manual_seed=config["seed"]) 151 | 152 | viz_util.summary(model, print_arch=False) 153 | 154 | scheduler = MultiStepLR( 155 | optimizer, 156 | milestones=eval(config["optimizer"]["lr_drop_epoch"]), 157 | gamma=config["optimizer"]["lr_step"], 158 | ) 159 | 160 | 161 | weights = torch.ones(config["task1_classes"]).cuda() 162 | if config["task1_weight"] < 1: 163 | print("Roads are weighted.") 164 | weights[0] = 1 - config["task1_weight"] 165 | weights[1] = config["task1_weight"] 166 | 167 | 168 | road_loss = mIoULoss( 169 | weight=weights, size_average=True, n_classes=config["task1_classes"] 170 | ).cuda() 171 | 172 | 173 | def train(epoch): 174 | train_loss_iou = 0 175 | train_loss_vec = 0 176 | model.train() 177 | optimizer.zero_grad() 178 | hist = np.zeros((config["task1_classes"], config["task1_classes"])) 179 | crop_size = config["train_dataset"][args.dataset]["crop_size"] 180 | for i, data in enumerate(train_loader, 0): 181 | inputs, labels, erased_label = data 182 | batch_size = inputs.size(0) 183 | 184 | inputs = Variable(inputs.float().cuda()) 185 | erased_label = Variable(erased_label[-1].float().cuda()).unsqueeze(dim=1) 186 | temp = erased_label 187 | 188 | for k in range(config['refinement']): 189 | in_ = torch.cat((inputs, erased_label, temp), dim=1) 190 | outputs = model(in_) 191 | if args.multi_scale_pred: 192 | loss1 = road_loss(outputs[0], labels[0].long().cuda(), False) 193 | num_stacks = model.module.num_stacks if num_gpus > 1 else model.num_stacks 194 | for idx in range(num_stacks - 1): 195 | loss1 += road_loss(outputs[idx + 1], labels[0].long().cuda(), False) 196 | for idx, output in enumerate(outputs[-2:]): 197 | loss1 += road_loss(output, labels[idx + 1].long().cuda(), False) 198 | 199 | outputs = outputs[-1] 200 | else: 201 | loss1 = road_loss(outputs, labels[-1].long().cuda(), False) 202 | 203 | loss1.backward() 204 | temp = Variable(torch.max(outputs.data, 1)[1].float()).unsqueeze(dim=1) 205 | 206 | train_loss_iou += loss1.data[0] 207 | 208 | _, predicted = torch.max(outputs.data, 1) 209 | 210 | correctLabel = labels[-1].view(-1, crop_size, crop_size).long() 211 | hist += util.fast_hist( 212 | predicted.view(predicted.size(0), -1).cpu().numpy(), 213 | correctLabel.view(correctLabel.size(0), -1).cpu().numpy(), 214 | config["task1_classes"], 215 | ) 216 | 217 | p_accu, miou, road_iou, fwacc = util.performMetrics( 218 | train_loss_file, 219 | val_loss_file, 220 | epoch, 221 | hist, 222 | train_loss_iou / (i + 1), 223 | 0, 224 | ) 225 | 226 | viz_util.progress_bar( 227 | i, 228 | len(train_loader), 229 | "Loss: %.6f | road miou: %.4f%%(%.4f%%)" 230 | % ( 231 | train_loss_iou / (i + 1), 232 | miou, 233 | road_iou, 234 | ), 235 | ) 236 | 237 | if i % config["trainer"]["iter_size"] == 0 or i == len(train_loader) - 1: 238 | optimizer.step() 239 | optimizer.zero_grad() 240 | 241 | del ( 242 | outputs, 243 | predicted, 244 | correctLabel, 245 | inputs, 246 | labels, 247 | ) 248 | 249 | util.performMetrics( 250 | train_loss_file, 251 | val_loss_file, 252 | epoch, 253 | hist, 254 | train_loss_iou / len(train_loader), 255 | 0, 256 | write=True, 257 | ) 258 | 259 | 260 | def test(epoch): 261 | global best_accuracy 262 | global best_miou 263 | model.eval() 264 | test_loss_iou = 0 265 | test_loss_vec = 0 266 | hist = np.zeros((config["task1_classes"], config["task1_classes"])) 267 | crop_size = config["val_dataset"][args.dataset]["crop_size"] 268 | for i, datas in enumerate(val_loader, 0): 269 | inputs, labels, erased_label = datas  # unpack the loop variable, not the torch.utils.data module 270 | batch_size =
inputs.size(0) 271 | 272 | inputs = Variable(inputs.float().cuda(), volatile=True, requires_grad=False) 273 | erased_label = Variable(erased_label[-1].float().cuda(), volatile=True, requires_grad=False).unsqueeze(dim=1) 274 | temp = erased_label 275 | 276 | for k in range(config['refinement']): 277 | in_ = torch.cat((inputs, erased_label, temp), dim=1) 278 | outputs = model(in_) 279 | if args.multi_scale_pred: 280 | loss1 = road_loss(outputs[0], labels[0].long().cuda(), False) 281 | num_stacks = model.module.num_stacks if num_gpus > 1 else model.num_stacks 282 | for idx in range(num_stacks - 1): 283 | loss1 += road_loss(outputs[idx + 1], labels[0].long().cuda(), False) 284 | for idx, output in enumerate(outputs[-2:]): 285 | loss1 += road_loss(output, labels[idx + 1].long().cuda(), False) 286 | 287 | outputs = outputs[-1] 288 | else: 289 | loss1 = road_loss(outputs, labels[-1].long().cuda(), False) 290 | 291 | temp = Variable(torch.max(outputs.data, 1)[1].float(), volatile=True, requires_grad=False).unsqueeze(dim=1) 292 | 293 | test_loss_iou += loss1.data[0] 294 | 295 | _, predicted = torch.max(outputs.data, 1) 296 | 297 | correctLabel = labels[-1].view(-1, crop_size, crop_size).long() 298 | hist += util.fast_hist( 299 | predicted.view(predicted.size(0), -1).cpu().numpy(), 300 | correctLabel.view(correctLabel.size(0), -1).cpu().numpy(), 301 | config["task1_classes"], 302 | ) 303 | 304 | p_accu, miou, road_iou, fwacc = util.performMetrics( 305 | train_loss_file, 306 | val_loss_file, 307 | epoch, 308 | hist, 309 | test_loss_iou / (i + 1), 310 | 0, 311 | is_train=False, 312 | ) 313 | 314 | viz_util.progress_bar( 315 | i, 316 | len(val_loader), 317 | "Loss: %.6f | road miou: %.4f%%(%.4f%%)" 318 | % ( 319 | test_loss_iou / (i + 1), 320 | miou, 321 | road_iou, 322 | ), 323 | ) 324 | 325 | if i % 100 == 0 or i == len(val_loader) - 1: 326 | images_path = "{}/images/".format(experiment_dir) 327 | util.ensure_dir(images_path) 328 | util.savePredictedProb( 329 | inputs.data.cpu(), 330 | labels[-1].cpu(), 331 | predicted.cpu(), 332 | F.softmax(outputs, dim=1).data.cpu()[:, 1, :, :], 333 | None, 334 | os.path.join(images_path, "validate_pair_{}_{}.png".format(epoch, i)), 335 | norm_type=config["val_dataset"]["normalize_type"], 336 | ) 337 | 338 | del inputs, labels, predicted, outputs 339 | 340 | accuracy, miou, road_iou, fwacc = util.performMetrics( 341 | train_loss_file, 342 | val_loss_file, 343 | epoch, 344 | hist, 345 | test_loss_iou / len(val_loader), 346 | 0, 347 | is_train=False, 348 | write=True, 349 | ) 350 | 351 | if miou > best_miou: 352 | best_accuracy = accuracy 353 | best_miou = miou 354 | util.save_checkpoint(epoch, test_loss_iou / len(val_loader), model, optimizer, best_accuracy, best_miou, config, experiment_dir) 355 | 356 | return test_loss_iou / len(val_loader) 357 | 358 | 359 | for epoch in range(start_epoch, total_epochs + 1): 360 | start_time = datetime.now() 361 | scheduler.step(epoch) 362 | print("\nTraining Epoch: %d" % epoch) 363 | train(epoch) 364 | if epoch % config["trainer"]["test_freq"] == 0: 365 | print("\nTesting Epoch: %d" % epoch) 366 | val_loss = test(epoch) 367 | 368 | end_time = datetime.now() 369 | print("Time Elapsed for epoch {0} => {1}".format(epoch, end_time - start_time)) -------------------------------------------------------------------------------- /utils/__init__.py: --------------------------------------------------------------------------------
https://raw.githubusercontent.com/anilbatra2185/road_connectivity/01e41662a43d9e289926ebd58eff4f6e14359ae4/utils/__init__.py -------------------------------------------------------------------------------- /utils/loss.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import numpy as np 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | from torch.autograd import Variable 7 | 8 | 9 | class CrossEntropyLoss2d(nn.Module): 10 | def __init__(self, weight=None, size_average=True, ignore_index=255, reduce=True): 11 | super(CrossEntropyLoss2d, self).__init__() 12 | self.nll_loss = nn.NLLLoss(weight, size_average, ignore_index, reduce) 13 | 14 | def forward(self, inputs, targets): 15 | log_p = F.log_softmax(inputs, dim=1) 16 | loss = self.nll_loss(log_p, targets) 17 | return loss 18 | 19 | 20 | def to_one_hot_var(tensor, nClasses, requires_grad=False): 21 | 22 | n, h, w = tensor.size() 23 | one_hot = tensor.new(n, nClasses, h, w).fill_(0) 24 | one_hot = one_hot.scatter_(1, tensor.view(n, 1, h, w), 1) 25 | return Variable(one_hot, requires_grad=requires_grad) 26 | 27 | 28 | class mIoULoss(nn.Module): 29 | def __init__(self, weight=None, size_average=True, n_classes=2): 30 | super(mIoULoss, self).__init__() 31 | self.classes = n_classes 32 | self.weights = Variable(weight * weight) 33 | 34 | def forward(self, inputs, target, is_target_variable=False): 35 | # inputs => N x Classes x H x W 36 | # target => N x H x W 37 | # target_oneHot => N x Classes x H x W 38 | 39 | N = inputs.size()[0] 40 | if is_target_variable: 41 | target_oneHot = to_one_hot_var(target.data, self.classes).float() 42 | else: 43 | target_oneHot = to_one_hot_var(target, self.classes).float() 44 | 45 | # predicted probabilities for each pixel along channel 46 | inputs = F.softmax(inputs, dim=1) 47 | 48 | # Numerator Product 49 | inter = inputs * target_oneHot 50 | ## Sum over all pixels N x C x H x W => N x C 51 | inter = inter.view(N, self.classes, -1).sum(2) 52 | 53 | # Denominator 54 | union = inputs + target_oneHot - (inputs * target_oneHot) 55 | ## Sum over all pixels N x C x H x W => N x C 56 | union = union.view(N, self.classes, -1).sum(2) 57 | 58 | loss = (self.weights * inter) / (self.weights * union + 1e-8) 59 | 60 | ## Return average loss over classes and batch 61 | return -torch.mean(loss) 62 | -------------------------------------------------------------------------------- /utils/util.py: -------------------------------------------------------------------------------- 1 | import math 2 | import os 3 | import random 4 | import sys 5 | import time 6 | import argparse  # required by str2bool below 7 | import cv2 8 | import numpy as np 9 | import torch 10 | import torch.nn as nn 11 | from torch.autograd import Variable 12 | from skimage.morphology import skeletonize 13 | 14 | 15 | def str2bool(v): 16 | if isinstance(v, bool): 17 | return v 18 | if v.lower() in ("yes", "true", "t", "y", "1"): 19 | return True 20 | elif v.lower() in ("no", "false", "f", "n", "0"): 21 | return False 22 | else: 23 | raise argparse.ArgumentTypeError("Boolean value expected.") 24 | 25 | 26 | def ensure_dir(path): 27 | if not os.path.exists(path): 28 | os.makedirs(path) 29 | 30 | 31 | def setSeed(config): 32 | if config["seed"] is None: 33 | manualSeed = np.random.randint(1, 10000) 34 | else: 35 | manualSeed = config["seed"] 36 | print("Random Seed: ", manualSeed) 37 | np.random.seed(manualSeed) 38 | torch.manual_seed(manualSeed) 39 | random.seed(manualSeed) 40 | torch.cuda.manual_seed_all(manualSeed) 41 |
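# Illustrative sketch (assumed usage, not from the original sources): a
# minimal CPU check of the soft-IoU criterion defined in utils/loss.py
# above. For a prediction that exactly matches the target, mIoULoss
# should return roughly -1.0, since it is the negated weighted mean of
# the per-class soft IoU. Tensor shapes follow the comments in loss.py
# (N x C x H x W logits against N x H x W integer labels); the toy
# values below are hypothetical.
#
#   import torch
#   from utils.loss import mIoULoss
#
#   logits = torch.zeros(1, 2, 4, 4)
#   logits[:, 1, :, :2] = 10.0                # confident "road" on the left half
#   logits[:, 0, :, 2:] = 10.0                # confident "background" on the right half
#   target = torch.max(logits, 1)[1].long()   # ground truth identical to the prediction
#   criterion = mIoULoss(weight=torch.ones(2), n_classes=2)
#   print(criterion(logits, target))          # ~ -1.0 (negated mean soft IoU)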
42 | 43 | def getParllelNetworkStateDict(state_dict): 44 | from collections import OrderedDict 45 | 46 | new_state_dict = OrderedDict() 47 | for k, v in state_dict.items(): 48 | name = k[7:] # remove `module.` 49 | new_state_dict[name] = v 50 | return new_state_dict 51 | 52 | 53 | def to_variable(tensor, volatile=False, requires_grad=True): 54 | return Variable(tensor.long().cuda(), requires_grad=requires_grad)  # note: the volatile flag is accepted for caller compatibility but not applied here 55 | 56 | 57 | def weights_init(model, manual_seed=7): 58 | np.random.seed(manual_seed) 59 | torch.manual_seed(manual_seed) 60 | random.seed(manual_seed) 61 | torch.cuda.manual_seed_all(manual_seed) 62 | for m in model.modules(): 63 | if isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d): 64 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 65 | m.weight.data.normal_(0, math.sqrt(2.0 / n)) 66 | elif isinstance(m, nn.BatchNorm2d): 67 | m.weight.data.fill_(1) 68 | m.bias.data.zero_() 69 | 70 | 71 | def weights_normal_init(model, manual_seed=7): 72 | np.random.seed(manual_seed) 73 | torch.manual_seed(manual_seed) 74 | random.seed(manual_seed) 75 | torch.cuda.manual_seed_all(manual_seed) 76 | for m in model.modules(): 77 | if isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d): 78 | m.weight.data.normal_(0.0, 0.02) 79 | elif isinstance(m, nn.BatchNorm2d): 80 | m.weight.data.normal_(1.0, 0.02) 81 | m.bias.data.fill_(0) 82 | 83 | 84 | def performAngleMetrics( 85 | train_loss_angle_file, val_loss_angle_file, epoch, hist, is_train=True, write=False 86 | ): 87 | 88 | pixel_accuracy = np.diag(hist).sum() / hist.sum() 89 | mean_accuracy = np.diag(hist) / hist.sum(1) 90 | mean_iou = np.diag(hist) / (hist.sum(1) + hist.sum(0) - np.diag(hist)) 91 | freq = hist.sum(1) / hist.sum() 92 | fwavacc = (freq[freq > 0] * mean_iou[freq > 0]).sum() 93 | if write and is_train: 94 | train_loss_angle_file.write( 95 | "[%d], Pixel Accuracy:%.3f, Mean Accuracy:%.3f, Mean IoU:%.3f, Freq.Weighted Accuracy:%.3f \n" 96 | % ( 97 | epoch, 98 | 100 * pixel_accuracy, 99 | 100 * np.nanmean(mean_accuracy), 100 | 100 * np.nanmean(mean_iou), 101 | 100 * fwavacc, 102 | ) 103 | ) 104 | elif write and not is_train: 105 | val_loss_angle_file.write( 106 | "[%d], Pixel Accuracy:%.3f, Mean Accuracy:%.3f, Mean IoU:%.3f, Freq.Weighted Accuracy:%.3f \n" 107 | % ( 108 | epoch, 109 | 100 * pixel_accuracy, 110 | 100 * np.nanmean(mean_accuracy), 111 | 100 * np.nanmean(mean_iou), 112 | 100 * fwavacc, 113 | ) 114 | ) 115 | 116 | return 100 * pixel_accuracy, 100 * np.nanmean(mean_iou), 100 * fwavacc 117 | 118 | 119 | def performMetrics( 120 | train_loss_file, 121 | val_loss_file, 122 | epoch, 123 | hist, 124 | loss, 125 | loss_vec, 126 | is_train=True, 127 | write=False, 128 | ): 129 | 130 | pixel_accuracy = np.diag(hist).sum() / hist.sum() 131 | mean_accuracy = np.diag(hist) / hist.sum(1) 132 | mean_iou = np.diag(hist) / (hist.sum(1) + hist.sum(0) - np.diag(hist)) 133 | freq = hist.sum(1) / hist.sum() 134 | fwavacc = (freq[freq > 0] * mean_iou[freq > 0]).sum() 135 | 136 | if write and is_train: 137 | train_loss_file.write( 138 | "[%d], Loss:%.5f, Loss(VecMap):%.5f, Pixel Accuracy:%.3f, Mean Accuracy:%.3f, Mean IoU:%.3f, Class IoU:[%.5f/%.5f], Freq.Weighted Accuracy:%.3f \n" 139 | % ( 140 | epoch, 141 | loss, 142 | loss_vec, 143 | 100 * pixel_accuracy, 144 | 100 * np.nanmean(mean_accuracy), 145 | 100 * np.nanmean(mean_iou), 146 | mean_iou[0], 147 | mean_iou[1], 148 | 100 * fwavacc, 149 | ) 150 | ) 151 | elif write and not is_train: 152 | val_loss_file.write( 153 | "[%d], Loss:%.5f, Loss(VecMap):%.5f, Pixel
Accuracy:%.3f, Mean Accuracy:%.3f, Mean IoU:%.3f, Class IoU:[%.5f/%.5f], Freq.Weighted Accuracy:%.3f \n" 154 | % ( 155 | epoch, 156 | loss, 157 | loss_vec, 158 | 100 * pixel_accuracy, 159 | 100 * np.nanmean(mean_accuracy), 160 | 100 * np.nanmean(mean_iou), 161 | mean_iou[0], 162 | mean_iou[1], 163 | 100 * fwavacc, 164 | ) 165 | ) 166 | 167 | return ( 168 | 100 * pixel_accuracy, 169 | 100 * np.nanmean(mean_iou), 170 | 100 * mean_iou[1], 171 | 100 * fwavacc, 172 | ) 173 | 174 | 175 | def fast_hist(a, b, n): 176 | 177 | k = (a >= 0) & (a < n) 178 | return np.bincount(n * a[k].astype(int) + b[k], minlength=n ** 2).reshape(n, n) 179 | 180 | 181 | def save_checkpoint(epoch, loss, model, optimizer, best_accuracy, best_miou, config, experiment_dir): 182 | 183 | if torch.cuda.device_count() > 1: 184 | arch = type(model.module).__name__ 185 | else: 186 | arch = type(model).__name__ 187 | state = { 188 | "arch": arch, 189 | "epoch": epoch, 190 | "state_dict": model.state_dict(), 191 | "optimizer": optimizer.state_dict(), 192 | "pixel_accuracy": best_accuracy, 193 | "miou": best_miou, 194 | "config": config, 195 | } 196 | filename = os.path.join( 197 | experiment_dir, "checkpoint-epoch{:03d}-loss-{:.4f}.pth.tar".format( 198 | epoch, loss) 199 | ) 200 | torch.save(state, filename) 201 | os.rename(filename, os.path.join(experiment_dir, "model_best.pth.tar")) 202 | print("Saving current best: {} ...".format("model_best.pth.tar")) 203 | 204 | 205 | def savePredictedProb( 206 | real, 207 | gt, 208 | predicted, 209 | predicted_prob, 210 | pred_affinity=None, 211 | image_name="", 212 | norm_type="Mean", 213 | ): 214 | b, c, h, w = real.size() 215 | grid = [] 216 | mean_bgr = np.array([70.95016901, 71.16398124, 71.30953645]) 217 | deviation_bgr = np.array([34.00087859, 35.18201658, 36.40463264]) 218 | 219 | for idx in range(b): 220 | # real_ = np.asarray(real[idx].numpy().transpose(1,2,0),dtype=np.float32) 221 | real_ = np.asarray(real[idx].numpy().transpose( 222 | 1, 2, 0), dtype=np.float32) 223 | if norm_type == "Mean": 224 | real_ = real_ + mean_bgr 225 | elif norm_type == "Std": 226 | real_ = (real_ * deviation_bgr) + mean_bgr 227 | 228 | real_ = np.asarray(real_, dtype=np.uint8) 229 | gt_ = gt[idx].numpy() * 255.0 230 | gt_ = np.asarray(gt_, dtype=np.uint8) 231 | gt_ = np.stack((gt_,) * 3).transpose(1, 2, 0) 232 | 233 | predicted_ = (predicted[idx]).numpy() * 255.0 234 | predicted_ = np.asarray(predicted_, dtype=np.uint8) 235 | predicted_ = np.stack((predicted_,) * 3).transpose(1, 2, 0) 236 | 237 | predicted_prob_ = (predicted_prob[idx]).numpy() * 255.0 238 | # predicted_prob_ = predicted_prob_[:,:] 239 | predicted_prob_ = np.asarray(predicted_prob_, dtype=np.uint8) 240 | # predicted_prob_ = np.stack((predicted_prob_,)*3).transpose(1,2,0) 241 | predicted_prob_ = cv2.applyColorMap(predicted_prob_, cv2.COLORMAP_JET) 242 | 243 | if pred_affinity is not None: 244 | hsv = np.zeros_like(real_) 245 | hsv[..., 1] = 255 246 | affinity_ = pred_affinity[idx].numpy() 247 | mag = np.copy(affinity_) 248 | mag[mag < 36] = 1 249 | mag[mag >= 36] = 0 250 | affinity_[affinity_ == 36] = 0 251 | 252 | # mag, ang = cv2.cartToPolar(affinity_[0], affinity_[1]) 253 | hsv[..., 0] = affinity_ * 10 / np.pi / 2 254 | hsv[..., 2] = cv2.normalize(mag, None, 0, 255, cv2.NORM_MINMAX) 255 | affinity_bgr = cv2.cvtColor(hsv, cv2.COLOR_HSV2BGR) 256 | 257 | pair = np.concatenate( 258 | (real_, gt_, predicted_, predicted_prob_, affinity_bgr), axis=1 259 | ) 260 | else: 261 | pair = np.concatenate( 262 | (real_, gt_, predicted_,
predicted_prob_), axis=1) 263 | grid.append(pair) 264 | 265 | if pred_affinity is not None: 266 | cv2.imwrite(image_name, np.array(grid).reshape(b * h, 5 * w, 3)) 267 | else: 268 | cv2.imwrite(image_name, np.array(grid).reshape(b * h, 4 * w, 3)) 269 | 270 | 271 | def get_relaxed_precision(a, b, buffer): 272 | tp = 0 273 | indices = np.where(a == 1) 274 | for ind in range(len(indices[0])): 275 | tp += (np.sum( 276 | b[indices[0][ind]-buffer: indices[0][ind]+buffer+1, 277 | indices[1][ind]-buffer: indices[1][ind]+buffer+1]) > 0).astype(int) 278 | return tp 279 | 280 | 281 | def relaxed_f1(pred, gt, buffer=3): 282 | ''' Usage and Call 283 | # rp_tp, rr_tp, pred_p, gt_p = relaxed_f1(predicted.cpu().numpy(), labels.cpu().numpy(), buffer = 3) 284 | 285 | # rprecision_tp += rp_tp 286 | # rrecall_tp += rr_tp 287 | # pred_positive += pred_p 288 | # gt_positive += gt_p 289 | 290 | # precision = rprecision_tp/(pred_positive + 1e-12) 291 | # recall = rrecall_tp/(gt_positive + 1e-12) 292 | # f1measure = 2*precision*recall/(precision + recall + 1e-12) 293 | # iou = precision*recall/(precision+recall-(precision*recall) + 1e-12) 294 | ''' 295 | 296 | rprecision_tp, rrecall_tp, pred_positive, gt_positive = 0, 0, 0, 0 297 | for b in range(pred.shape[0]): 298 | pred_sk = skeletonize(pred[b]) 299 | gt_sk = skeletonize(gt[b]) 300 | # pred_sk = pred[b] 301 | # gt_sk = gt[b] 302 | rprecision_tp += get_relaxed_precision(pred_sk, gt_sk, buffer) 303 | rrecall_tp += get_relaxed_precision(gt_sk, pred_sk, buffer) 304 | pred_positive += len(np.where(pred_sk == 1)[0]) 305 | gt_positive += len(np.where(gt_sk == 1)[0]) 306 | 307 | return rprecision_tp, rrecall_tp, pred_positive, gt_positive 308 | -------------------------------------------------------------------------------- /utils/viz_util.py: -------------------------------------------------------------------------------- 1 | import math 2 | import os 3 | import random 4 | import sys 5 | import time 6 | 7 | import cv2 8 | import numpy as np 9 | import torch 10 | import torch.nn as nn 11 | 12 | 13 | term_width = int(190) 14 | 15 | TOTAL_BAR_LENGTH = 25.0 16 | last_time = time.time() 17 | begin_time = last_time 18 | 19 | 20 | def progress_bar(current, total, msg=None): 21 | global last_time, begin_time 22 | if current == 0: 23 | begin_time = time.time() # Reset for new bar. 24 | 25 | cur_len = int(TOTAL_BAR_LENGTH * current / total) 26 | rest_len = int(TOTAL_BAR_LENGTH - cur_len) - 1 27 | 28 | sys.stdout.write("[") 29 | for i in range(cur_len): 30 | sys.stdout.write("=") 31 | sys.stdout.write(">") 32 | for i in range(rest_len): 33 | sys.stdout.write(".") 34 | sys.stdout.write("]") 35 | 36 | cur_time = time.time() 37 | step_time = cur_time - last_time 38 | last_time = cur_time 39 | tot_time = cur_time - begin_time 40 | 41 | L = [] 42 | L.append("S:%s" % format_time(step_time)) 43 | L.append("|T:%s" % format_time(tot_time)) 44 | if msg: 45 | L.append("|" + msg) 46 | 47 | msg = "".join(L) 48 | sys.stdout.write(msg) 49 | for i in range(term_width - int(TOTAL_BAR_LENGTH) - len(msg) - 3): 50 | sys.stdout.write(" ") 51 | 52 | # Go back to the center of the bar.
53 | for i in range(term_width - int(TOTAL_BAR_LENGTH / 2) + 2): 54 | sys.stdout.write("\b") 55 | sys.stdout.write(" %d/%d " % (current + 1, total)) 56 | 57 | if current < total - 1: 58 | sys.stdout.write("\r") 59 | else: 60 | sys.stdout.write("\n") 61 | sys.stdout.flush() 62 | 63 | 64 | def format_time(seconds): 65 | days = int(seconds / 3600 / 24) 66 | seconds = seconds - days * 3600 * 24 67 | hours = int(seconds / 3600) 68 | seconds = seconds - hours * 3600 69 | minutes = int(seconds / 60) 70 | seconds = seconds - minutes * 60 71 | secondsf = int(seconds) 72 | seconds = seconds - secondsf 73 | millis = int(seconds * 1000) 74 | 75 | f = "" 76 | i = 1 77 | if days > 0: 78 | f += str(days) + "D" 79 | i += 1 80 | if hours > 0 and i <= 2: 81 | f += str(hours) + "h" 82 | i += 1 83 | if minutes > 0 and i <= 2: 84 | f += str(minutes) + "m" 85 | i += 1 86 | if secondsf > 0 and i <= 2: 87 | f += str(secondsf) + "s" 88 | i += 1 89 | if millis > 0 and i <= 2: 90 | f += str(millis) + "ms" 91 | i += 1 92 | if f == "": 93 | f = "0ms" 94 | return f 95 | 96 | 97 | def summary(model, print_arch=False): 98 | model_parameters = filter(lambda p: p.requires_grad, model.parameters()) 99 | params = sum([np.prod(p.size()) for p in model_parameters]) / 1000000.0 100 | 101 | print("*" * 100) 102 | if print_arch: 103 | print(model) 104 | if model.__class__.__name__ == "DataParallel": 105 | print( 106 | "Trainable parameters for Model {} : {} M".format( 107 | model.module.__class__.__name__, params 108 | ) 109 | ) 110 | else: 111 | print( 112 | "Trainable parameters for Model {} : {} M".format( 113 | model.__class__.__name__, params 114 | ) 115 | ) 116 | print("*" * 100) 117 | -------------------------------------------------------------------------------- /utils/viz_utils.py: -------------------------------------------------------------------------------- 1 | import math 2 | import os 3 | import random 4 | import sys 5 | import time 6 | 7 | import cv2 8 | import numpy as np 9 | import torch 10 | import torch.nn as nn 11 | 12 | 13 | term_width = int(190) 14 | 15 | TOTAL_BAR_LENGTH = 25.0 16 | last_time = time.time() 17 | begin_time = last_time 18 | 19 | 20 | def progress_bar(current, total, msg=None): 21 | global last_time, begin_time 22 | if current == 0: 23 | begin_time = time.time() # Reset for new bar. 24 | 25 | cur_len = int(TOTAL_BAR_LENGTH * current / total) 26 | rest_len = int(TOTAL_BAR_LENGTH - cur_len) - 1 27 | 28 | sys.stdout.write("[") 29 | for i in range(cur_len): 30 | sys.stdout.write("=") 31 | sys.stdout.write(">") 32 | for i in range(rest_len): 33 | sys.stdout.write(".") 34 | sys.stdout.write("]") 35 | 36 | cur_time = time.time() 37 | step_time = cur_time - last_time 38 | last_time = cur_time 39 | tot_time = cur_time - begin_time 40 | 41 | L = [] 42 | L.append("S:%s" % format_time(step_time)) 43 | L.append("|T:%s" % format_time(tot_time)) 44 | if msg: 45 | L.append("|" + msg) 46 | 47 | msg = "".join(L) 48 | sys.stdout.write(msg) 49 | for i in range(term_width - int(TOTAL_BAR_LENGTH) - len(msg) - 3): 50 | sys.stdout.write(" ") 51 | 52 | # Go back to the center of the bar. 
53 | for i in range(term_width - int(TOTAL_BAR_LENGTH / 2) + 2): 54 | sys.stdout.write("\b") 55 | sys.stdout.write(" %d/%d " % (current + 1, total)) 56 | 57 | if current < total - 1: 58 | sys.stdout.write("\r") 59 | else: 60 | sys.stdout.write("\n") 61 | sys.stdout.flush() 62 | 63 | 64 | def format_time(seconds): 65 | days = int(seconds / 3600 / 24) 66 | seconds = seconds - days * 3600 * 24 67 | hours = int(seconds / 3600) 68 | seconds = seconds - hours * 3600 69 | minutes = int(seconds / 60) 70 | seconds = seconds - minutes * 60 71 | secondsf = int(seconds) 72 | seconds = seconds - secondsf 73 | millis = int(seconds * 1000) 74 | 75 | f = "" 76 | i = 1 77 | if days > 0: 78 | f += str(days) + "D" 79 | i += 1 80 | if hours > 0 and i <= 2: 81 | f += str(hours) + "h" 82 | i += 1 83 | if minutes > 0 and i <= 2: 84 | f += str(minutes) + "m" 85 | i += 1 86 | if secondsf > 0 and i <= 2: 87 | f += str(secondsf) + "s" 88 | i += 1 89 | if millis > 0 and i <= 2: 90 | f += str(millis) + "ms" 91 | i += 1 92 | if f == "": 93 | f = "0ms" 94 | return f 95 | 96 | 97 | def summary(model, print_arch=False): 98 | model_parameters = filter(lambda p: p.requires_grad, model.parameters()) 99 | params = sum([np.prod(p.size()) for p in model_parameters]) / 1000000.0 100 | 101 | print("*" * 100) 102 | if print_arch: 103 | print(model) 104 | if model.__class__.__name__ == "DataParallel": 105 | print( 106 | "Trainable parameters for Model {} : {} M".format( 107 | model.module.__class__.__name__, params 108 | ) 109 | ) 110 | else: 111 | print( 112 | "Trainable parameters for Model {} : {} M".format( 113 | model.__class__.__name__, params 114 | ) 115 | ) 116 | print("*" * 100) 117 | --------------------------------------------------------------------------------
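Both training scripts above accumulate a confusion matrix with util.fast_hist and then read mean IoU out of it in util.performMetrics. The sketch below distills that bookkeeping on hypothetical toy label maps; it is a minimal illustration assuming the two-class (road/background) task1_classes setup, not an excerpt from the repository.

import numpy as np
from utils.util import fast_hist

n_classes = 2
hist = np.zeros((n_classes, n_classes))

# Toy flattened predictions and ground truth for a single 8-pixel "image".
pred = np.array([[0, 0, 1, 1, 1, 0, 1, 0]])
gt = np.array([[0, 0, 1, 1, 0, 0, 1, 1]])
hist += fast_hist(pred, gt, n_classes)  # rows: predicted class, cols: true class

# The same reductions performMetrics applies to the accumulated matrix.
iou = np.diag(hist) / (hist.sum(1) + hist.sum(0) - np.diag(hist))
print(hist)                  # 2x2 confusion matrix
print(iou, np.nanmean(iou))  # per-class IoU and mean IoU (0.5 for each class here)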
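Similarly, the four counters returned by util.relaxed_f1 are meant to be accumulated over batches and then combined as outlined in its docstring (with precision normalized by the predicted positives). A self-contained sketch with a hypothetical one-image batch; the mask shapes, shift, and buffer value are illustrative assumptions, and the exact printed values depend on skimage's skeletonize behavior.

import numpy as np
from utils.util import relaxed_f1

# Hypothetical 16x16 binary road masks: the prediction is shifted one row
# from the ground truth, which a buffer of 3 pixels should tolerate.
pred = np.zeros((1, 16, 16), dtype=np.uint8)
gt = np.zeros((1, 16, 16), dtype=np.uint8)
pred[0, 8, 4:12] = 1
gt[0, 9, 4:12] = 1

rp_tp, rr_tp, pred_p, gt_p = relaxed_f1(pred, gt, buffer=3)
precision = rp_tp / (pred_p + 1e-12)
recall = rr_tp / (gt_p + 1e-12)
f1 = 2 * precision * recall / (precision + recall + 1e-12)
print(precision, recall, f1)  # ~1.0 each: a one-pixel shift falls within the buffer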