├── geometry_translation
│   ├── __init__.py
│   ├── load_debug.py
│   ├── models
│   │   ├── losses.py
│   │   ├── metrics.py
│   │   ├── __init__.py
│   │   ├── translator.py
│   │   ├── networks.py
│   │   ├── t2rnet_model.py
│   │   └── base_model.py
│   ├── options
│   │   ├── test_options.py
│   │   ├── train_options.py
│   │   └── base_options.py
│   ├── test.py
│   ├── region_concatenation.py
│   ├── data_loader.py
│   ├── utils
│   │   ├── util.py
│   │   ├── html.py
│   │   └── visualizer.py
│   ├── train.py
│   └── feature_extraction.py
├── .gitmodules
├── .gitignore
├── data
│   └── conf_tdrive_sample.json
├── LICENSE
├── topology_construction
│   ├── custom_map_matching.py
│   ├── topo_utils.py
│   ├── map_refinement.py
│   ├── main.py
│   ├── graph_extraction.py
│   └── link_generation.py
└── README.md

/geometry_translation/__init__.py:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/.gitmodules:
--------------------------------------------------------------------------------
[submodule "tptk"]
	path = tptk
	url = https://github.com/sjruan/tptk
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
.idea/

data/

__pycache__/

geometry_translation/checkpoints/

geometry_translation/results/
--------------------------------------------------------------------------------
/geometry_translation/load_debug.py:
--------------------------------------------------------------------------------
import numpy as np

if __name__ == '__main__':
    # quick sanity check on one extracted direction feature tile
    data = '../data/tdrive_sample_s5/learning/train/0_0_direction.npy'
    ab = np.load(data)
    print(ab.shape)
--------------------------------------------------------------------------------
/geometry_translation/models/losses.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn


class SoftDiceLoss(nn.Module):
    def __init__(self):
        super(SoftDiceLoss, self).__init__()

    def forward(self, y_pred, y_true):
        # soft Dice loss: 1 - (2*intersection + s) / (sum(pred) + sum(true) + s);
        # the smoothing term s avoids division by zero on empty masks
        smooth = 1.0  # may change
        i = torch.sum(y_true)
        j = torch.sum(y_pred)
        intersection = torch.sum(y_true * y_pred)
        score = (2. * intersection + smooth) / (i + j + smooth)
        loss = 1. - score
        return loss
--------------------------------------------------------------------------------
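A minimal sanity check of `SoftDiceLoss` (not part of the repository; assumes it is run from the repository root so the package imports resolve):

```python
# Illustrative usage of SoftDiceLoss on dummy tensors.
import torch
from geometry_translation.models.losses import SoftDiceLoss

criterion = SoftDiceLoss()
pred = torch.rand(4, 1, 256, 256)                   # sigmoid-like probabilities
true = (torch.rand(4, 1, 256, 256) > 0.5).float()   # binary ground truth
print(criterion(pred, true).item())                 # somewhere in (0, 1)
print(criterion(true, true).item())                 # a perfect prediction gives 0.0
```
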
/data/conf_tdrive_sample.json:
--------------------------------------------------------------------------------
{
  "dataset": {
    "dataset_name": "tdrive_sample",
    "min_lat": 39.8451,
    "min_lng": 116.2810,
    "max_lat": 39.9890,
    "max_lng": 116.4684,
    "nb_rows": 8192,
    "nb_cols": 8192
  },
  "feature_extraction": {
    "nbhd_dist": 300,
    "nbhd_size": 8,
    "tile_pixel_size": 256,
    "test_tile_row_min": 12,
    "test_tile_row_max": 22,
    "test_tile_col_min": 3,
    "test_tile_col_max": 13
  },
  "topology_construction": {
    "link_radius": 50,
    "alpha": 1.4,
    "min_supp": 5
  }
}
--------------------------------------------------------------------------------
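The configuration ties the pipeline together: the region is rasterized into an `nb_rows` x `nb_cols` pixel grid, cut into `tile_pixel_size` tiles, and a tile window is held out for testing. A short sketch (not in the repository) of the implied arithmetic:

```python
# Sketch: how the tile window in the configuration maps to pixels.
import json

with open('data/conf_tdrive_sample.json') as f:
    conf = json.load(f)

fe = conf['feature_extraction']
tile = fe['tile_pixel_size']                               # 256 px per tile side
rows = fe['test_tile_row_max'] - fe['test_tile_row_min']   # 22 - 12 = 10 tiles
cols = fe['test_tile_col_max'] - fe['test_tile_col_min']   # 13 - 3 = 10 tiles
print(rows * tile, 'x', cols * tile)                       # 2560 x 2560 px test region
print(conf['dataset']['nb_rows'] // tile)                  # 8192 / 256 = 32 tiles per side
```
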
""" 7 | return vector / np.linalg.norm(vector) 8 | 9 | 10 | def angle_between(v1, v2): 11 | v1_u = unit_vector(v1) 12 | v2_u = unit_vector(v2) 13 | return np.arccos(np.clip(np.dot(v1_u, v2_u), -1.0, 1.0)) 14 | 15 | 16 | def magnitude(vector): 17 | return np.sqrt(np.dot(np.array(vector),np.array(vector))) 18 | 19 | 20 | def norm(vector): 21 | return np.array(vector)/magnitude(np.array(vector)) 22 | 23 | 24 | def ccw(A,B,C): 25 | return (C.lat-A.lat) * (B.lng-A.lng) > (B.lat-A.lat) * (C.lng-A.lng) 26 | 27 | 28 | def is_line_line_intersected(A, B, C, D): 29 | return ccw(A, C, D) != ccw(B, C, D) and ccw(A, B, C) != ccw(A, B, D) 30 | 31 | 32 | def line_ray_intersection_test(o, f, a, b): 33 | """ 34 | :param o: ray original point SPoint 35 | :param f: ray from point SPoint ray: f->o 36 | :param a: line segment point 1 SPoint 37 | :param b: line segment point 2 SPoint 38 | :return: 39 | """ 40 | o = np.array((o.lng, o.lat), dtype=np.float) 41 | dir = np.array(norm((o[0] - f.lng, o[1] - f.lat)), dtype=np.float) 42 | a = np.array((a.lng, a.lat), dtype=np.float) 43 | b = np.array((b.lng, b.lat), dtype=np.float) 44 | 45 | v1 = o - a 46 | v2 = b - a 47 | v3 = np.asarray([-dir[1], dir[0]]) 48 | t1 = np.cross(v2, v1) / np.dot(v2, v3) 49 | t2 = np.dot(v1, v3) / np.dot(v2, v3) 50 | # t1=inf parallel 51 | if t1 == np.inf or t1 < 0: 52 | # ray has no intersection with line segment 53 | return None 54 | else: 55 | pt = o + t1 * dir 56 | # 1. t2<0, in extension of a; 2. t2 in [0,1], within ab; 3. t2>1, in extension of b 57 | return t2, SPoint(pt[1], pt[0]) 58 | -------------------------------------------------------------------------------- /geometry_translation/test.py: -------------------------------------------------------------------------------- 1 | import sys 2 | sys.path.append('../') 3 | from geometry_translation.options.test_options import TestOptions 4 | from geometry_translation.models import create_model 5 | from geometry_translation.data_loader import get_data_loader 6 | from geometry_translation.utils.visualizer import save_images 7 | from geometry_translation.utils import html 8 | import os 9 | 10 | 11 | class Tester: 12 | def __init__(self, opt, model, test_dl): 13 | self.opt = opt 14 | self.model = model 15 | self.test_dl = test_dl 16 | 17 | def pred(self): 18 | self.model.eval() 19 | # create a website 20 | web_dir = os.path.join(self.opt.results_dir, self.opt.name, '%s_%s' % (self.opt.phase, self.opt.epoch)) # define the website directory 21 | webpage = html.HTML(web_dir, 'Experiment = %s, Phase = %s, Epoch = %s' % (opt.name, opt.phase, opt.epoch)) 22 | tot_loss = 0 23 | tot_metrics = 0 24 | for i, data in enumerate(self.test_dl): 25 | self.model.set_input(data) 26 | _, iter_loss, iter_metrics, iter_road_metrics = self.model.test() 27 | tot_loss += iter_loss.item() 28 | tot_metrics += iter_metrics.numpy() 29 | visuals = model.get_current_visuals() # get image results 30 | img_path = model.get_image_paths() # get image paths 31 | save_images(webpage, visuals, img_path, aspect_ratio=opt.aspect_ratio, width=opt.display_winsize) 32 | tot_loss /= len(self.test_dl) 33 | tot_metrics /= len(self.test_dl) 34 | print('loss\t{:.6f}\tprecision\t{:.4f}\trecall\t{:.4f}\tf1\t{:.4f}\tiou\t{:.4f}\n'.format(tot_loss, tot_metrics[0], tot_metrics[1], tot_metrics[2], tot_metrics[3])) 35 | webpage.save() # save the HTML 36 | 37 | 38 | if __name__ == '__main__': 39 | opt = TestOptions().parse() 40 | model = create_model(opt) 41 | model.setup(opt) 42 | test_dl = get_data_loader(opt.dataroot, 'test') 43 | tester = 
/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2019 Sijie Ruan

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
--------------------------------------------------------------------------------
/topology_construction/custom_map_matching.py:
--------------------------------------------------------------------------------
import os
from tptk.map_matching.hmm.hmm_map_matcher import TIHMMMapMatcher
from tptk.common.trajectory import parse_traj_file
from tptk.common.path import store_path_file
from tptk.common.road_network import load_rn_shp


class CustomMapMatching:
    def __init__(self, rn_path, alpha):
        self.rn_path = rn_path
        self.alpha = alpha

    def execute(self, filename, traj_path, mm_result_path):
        rn = load_rn_shp(self.rn_path, is_directed=True)
        # penalize routing through generated 'virtual' links by alpha (> 1),
        # so the matcher prefers extracted road segments when both exist
        for u, v, data in rn.edges(data=True):
            if data['type'] == 'virtual':
                data['weight'] = data['length'] * self.alpha
            else:
                data['weight'] = data['length']
        map_matcher = TIHMMMapMatcher(rn, routing_weight='weight')
        traj_list = parse_traj_file(os.path.join(traj_path, filename))
        all_paths = []
        for traj in traj_list:
            paths = map_matcher.match_to_path(traj)
            all_paths.extend(paths)
        store_path_file(all_paths, os.path.join(mm_result_path, filename))
--------------------------------------------------------------------------------
/geometry_translation/options/test_options.py:
--------------------------------------------------------------------------------
from geometry_translation.options.base_options import BaseOptions


class TestOptions(BaseOptions):
    """This class includes test options.
    It also includes shared options defined in BaseOptions.
    """

    def initialize(self, parser):
        parser = BaseOptions.initialize(self, parser)  # define shared options
        parser.add_argument('--ntest', type=int, default=float("inf"), help='# of test examples.')
        parser.add_argument('--results_dir', type=str, default='./results/', help='saves results here.')
        parser.add_argument('--aspect_ratio', type=float, default=1.0, help='aspect ratio of result images')
        parser.add_argument('--phase', type=str, default='test', help='train, val, test, etc')
        # Dropout and BatchNorm behave differently during training and test.
        parser.add_argument('--eval', action='store_true', help='use eval mode during test time.')
        # override default values
        parser.set_defaults(model='traj2rn')
        # To avoid cropping, the load_size should be the same as crop_size
        parser.set_defaults(load_size=parser.get_default('crop_size'))
        self.is_train = False
        return parser
--------------------------------------------------------------------------------
/topology_construction/topo_utils.py:
--------------------------------------------------------------------------------
from tptk.common.spatial_func import SPoint
import numpy as np


def unit_vector(vector):
    """Returns the unit vector of the vector."""
    return vector / np.linalg.norm(vector)


def angle_between(v1, v2):
    v1_u = unit_vector(v1)
    v2_u = unit_vector(v2)
    return np.arccos(np.clip(np.dot(v1_u, v2_u), -1.0, 1.0))


def magnitude(vector):
    return np.sqrt(np.dot(np.array(vector), np.array(vector)))


def norm(vector):
    return np.array(vector) / magnitude(np.array(vector))


def ccw(A, B, C):
    return (C.lat - A.lat) * (B.lng - A.lng) > (B.lat - A.lat) * (C.lng - A.lng)


def is_line_line_intersected(A, B, C, D):
    return ccw(A, C, D) != ccw(B, C, D) and ccw(A, B, C) != ccw(A, B, D)


def line_ray_intersection_test(o, f, a, b):
    """
    :param o: ray origin point SPoint
    :param f: ray from point SPoint ray: f->o
    :param a: line segment point 1 SPoint
    :param b: line segment point 2 SPoint
    :return: None if the ray misses the segment's supporting line, otherwise
        (t2, intersection SPoint); t2 < 0 lies in the extension of a,
        t2 in [0, 1] lies within ab, and t2 > 1 lies in the extension of b
    """
    # np.float was removed in recent NumPy releases; plain float is equivalent
    o = np.array((o.lng, o.lat), dtype=float)
    direction = np.array(norm((o[0] - f.lng, o[1] - f.lat)), dtype=float)
    a = np.array((a.lng, a.lat), dtype=float)
    b = np.array((b.lng, b.lat), dtype=float)

    v1 = o - a
    v2 = b - a
    v3 = np.asarray([-direction[1], direction[0]])
    t1 = np.cross(v2, v1) / np.dot(v2, v3)
    t2 = np.dot(v1, v3) / np.dot(v2, v3)
    # t1 = inf means the ray is parallel to the segment
    if t1 == np.inf or t1 < 0:
        # ray has no intersection with the line segment
        return None
    else:
        pt = o + t1 * direction
        return t2, SPoint(pt[1], pt[0])
--------------------------------------------------------------------------------
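A self-contained check of the ray/segment test (not in the repository; assumes the `tptk` submodule is importable and that `SPoint` is constructed as `SPoint(lat, lng)`):

```python
# Illustrative: a ray shot from f through o crossing the segment a-b.
from tptk.common.spatial_func import SPoint
from topology_construction.topo_utils import line_ray_intersection_test

o, f = SPoint(0.0, 1.0), SPoint(0.0, 0.0)    # ray points in the +lng direction
a, b = SPoint(-1.0, 2.0), SPoint(1.0, 2.0)   # vertical segment at lng = 2

hit = line_ray_intersection_test(o, f, a, b)
if hit is not None:
    t2, pt = hit
    print(t2, pt.lat, pt.lng)   # 0.5 0.0 2.0 -> hits the midpoint of ab
```
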
/geometry_translation/test.py:
--------------------------------------------------------------------------------
import sys
sys.path.append('../')
from geometry_translation.options.test_options import TestOptions
from geometry_translation.models import create_model
from geometry_translation.data_loader import get_data_loader
from geometry_translation.utils.visualizer import save_images
from geometry_translation.utils import html
import os


class Tester:
    def __init__(self, opt, model, test_dl):
        self.opt = opt
        self.model = model
        self.test_dl = test_dl

    def pred(self):
        self.model.eval()
        # create a website
        web_dir = os.path.join(self.opt.results_dir, self.opt.name, '%s_%s' % (self.opt.phase, self.opt.epoch))  # define the website directory
        webpage = html.HTML(web_dir, 'Experiment = %s, Phase = %s, Epoch = %s' % (self.opt.name, self.opt.phase, self.opt.epoch))
        tot_loss = 0
        tot_metrics = 0
        for i, data in enumerate(self.test_dl):
            self.model.set_input(data)
            _, iter_loss, iter_metrics, iter_road_metrics = self.model.test()
            tot_loss += iter_loss.item()
            tot_metrics += iter_metrics.numpy()
            visuals = self.model.get_current_visuals()  # get image results
            img_path = self.model.get_image_paths()  # get image paths
            save_images(webpage, visuals, img_path, aspect_ratio=self.opt.aspect_ratio, width=self.opt.display_winsize)
        tot_loss /= len(self.test_dl)
        tot_metrics /= len(self.test_dl)
        print('loss\t{:.6f}\tprecision\t{:.4f}\trecall\t{:.4f}\tf1\t{:.4f}\tiou\t{:.4f}\n'.format(tot_loss, tot_metrics[0], tot_metrics[1], tot_metrics[2], tot_metrics[3]))
        webpage.save()  # save the HTML


if __name__ == '__main__':
    opt = TestOptions().parse()
    model = create_model(opt)
    model.setup(opt)
    test_dl = get_data_loader(opt.dataroot, 'test')
    tester = Tester(opt, model, test_dl)
    tester.pred()
--------------------------------------------------------------------------------
/geometry_translation/region_concatenation.py:
--------------------------------------------------------------------------------
import argparse
from PIL import Image
from skimage.morphology import skeletonize
import numpy as np
import cv2
import json
import os

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--conf_path', help='the configuration file of the dataset')
    parser.add_argument('--tile_path', help='the directory of predicted tiles')
    parser.add_argument('--results_path', help='the path to results')

    opt = parser.parse_args()
    print(opt)
    os.makedirs(opt.results_path, exist_ok=True)
    with open(opt.conf_path, 'r') as f:
        conf = json.load(f)
    tile_height, tile_width = conf['feature_extraction']['tile_pixel_size'], \
                              conf['feature_extraction']['tile_pixel_size']
    row_min, row_max, col_min, col_max = conf['feature_extraction']['test_tile_row_min'], \
                                         conf['feature_extraction']['test_tile_row_max'], \
                                         conf['feature_extraction']['test_tile_col_min'], \
                                         conf['feature_extraction']['test_tile_col_max']

    # stitch the predicted tiles back into one test-region image
    img_tmp = '{}_{}_pred_rn_img.png'
    pred_eval_region_data = np.zeros(((row_max - row_min) * tile_height, (col_max - col_min) * tile_width),
                                     dtype=np.uint8)
    for i in range(row_min, row_max):
        for j in range(col_min, col_max):
            filename = opt.tile_path + img_tmp.format(i, j)
            tile_data = cv2.imread(filename, cv2.IMREAD_GRAYSCALE)
            sl = (slice((i - row_min) * tile_height, (i + 1 - row_min) * tile_height),
                  slice((j - col_min) * tile_width, (j + 1 - col_min) * tile_width))
            pred_eval_region_data[sl] = tile_data
    cv2.imwrite(opt.results_path + 'pred.png', pred_eval_region_data)

    # thin the predicted road region to a one-pixel-wide skeleton
    inp = cv2.imread(opt.results_path + 'pred.png', cv2.IMREAD_GRAYSCALE)
    inp[inp > 0] = 1
    skeleton = skeletonize(inp)
    # skeletonize returns a boolean array; PIL cannot handle bool directly
    skeleton = skeleton.astype(np.uint8) * 255
    skeleton = Image.fromarray(skeleton).convert('RGB')
    skeleton.save(opt.results_path + 'pred_thinned.png')
--------------------------------------------------------------------------------
/geometry_translation/data_loader.py:
--------------------------------------------------------------------------------
import torch.utils.data as data
import numpy as np
from PIL import Image
import os
import torch
import torchvision.transforms as transforms


def get_data_loader(root_dir, mode):
    if mode == 'train':
        dl = data.DataLoader(TrajDataset(os.path.join(root_dir, mode)), shuffle=True)
    else:
        dl = data.DataLoader(TrajDataset(os.path.join(root_dir, mode)), shuffle=False)
    return dl


class TrajDataset(data.Dataset):
    def __init__(self, data_path):
        super(TrajDataset, self).__init__()
        self.img_paths = [os.path.join(data_path, filename) for filename in os.listdir(data_path)
                          if filename.endswith('.png')]

    def __getitem__(self, index):
        img_path = self.img_paths[index]
        img = Image.open(img_path).convert('L')
        speed = np.load(img_path.replace('.png', '_speed.npy'))
        direction = np.load(img_path.replace('.png', '_direction.npy'))
        transition_view = np.load(img_path.replace('.png', '_transition.npy')).astype('float32')

        transition_view[transition_view > 0] = 1.0
        transition_view = torch.from_numpy(transition_view)

        # the png packs four views side by side: point, line, centerline, region
        w, h = img.size
        unit_w = int(w / 4)
        point_img = img.crop((0, 0, unit_w, h))
        line_img = img.crop((unit_w, 0, unit_w * 2, h))
        centerline_img = img.crop((unit_w * 2, 0, unit_w * 3, h))
        region_img = img.crop((unit_w * 3, 0, unit_w * 4, h))

        # normalization & to torch data structure
        spatial_features = []
        img_transform = transforms.ToTensor()
        spatial_features.append(img_transform(point_img))
        line_img = img_transform(line_img)
        spatial_features.append(line_img)
        # map speed from [0, 34] to [-1, 1]
        speed = torch.from_numpy(((speed / 34 - 0.5) * 2.0).astype('float32'))
        spatial_features.append(speed.permute(2, 0, 1))
        direction = direction.astype('float32')
        # normalize each pixel's direction histogram to a distribution;
        # pixels with no observations stay all-zero
        summed = np.sum(direction, axis=2, keepdims=True)
        direction = torch.from_numpy(np.divide(direction, summed, out=np.zeros_like(direction), where=summed != 0))
        for i in range(direction.shape[2]):
            spatial_features.append(direction[:, :, i:i+1].permute(2, 0, 1))
        spatial_view = torch.cat(tuple(spatial_features), 0)
        centerline_img = img_transform(centerline_img)
        region_img = img_transform(region_img)

        return {
            'img_path': img_path,
            'spatial_view': spatial_view, 'transition_view': transition_view,
            'centerline': centerline_img, 'region': region_img
        }

    def __len__(self):
        return len(self.img_paths)
--------------------------------------------------------------------------------
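The masked division above is easy to get wrong; a standalone sketch of the same normalization (not in the repository):

```python
# Sketch: the zero-safe histogram normalization from TrajDataset.__getitem__.
import numpy as np

direction = np.zeros((2, 2, 4), dtype='float32')   # H x W x direction bins
direction[0, 0] = [3, 1, 0, 0]                     # one observed pixel

summed = np.sum(direction, axis=2, keepdims=True)
normed = np.divide(direction, summed,
                   out=np.zeros_like(direction), where=summed != 0)

print(normed[0, 0])   # [0.75 0.25 0.   0.  ] -> a proper distribution
print(normed[1, 1])   # [0. 0. 0. 0.]         -> unobserved pixels remain zero
```
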
/README.md:
--------------------------------------------------------------------------------
# DeepMG: Learning to Generate Maps from Trajectories

In this study, we aim to generate a routable road network from trajectories using a deep learning approach.

## Paper

If you find our code useful for your research, please cite our paper:

*Sijie Ruan, Cheng Long, Jie Bao, Chunyang Li, Zisheng Yu, Ruiyuan Li, Yuxuan Liang, Tianfu He, Yu Zheng. "Learning to Generate Maps from Trajectories". AAAI 2020.*


## Geometry Translation

### Feature Extraction

`python feature_extraction.py ../data/conf_tdrive_sample.json`

### Training

`python train.py --name t2rnet_tdrive_sample --dataroot ../data/tdrive_sample/learning --lam 0.2 --batch_size 8 --model t2rnet --display_id -1`

### Inference

`python test.py --name t2rnet_tdrive_sample --dataroot ../data/tdrive_sample/learning --lam 0.2 --model t2rnet`

### Region Concatenation

`python region_concatenation.py --tile_path ./results/t2rnet_tdrive_sample/test_latest/images/ --conf_path ../data/conf_tdrive_sample.json --results_path ../data/tdrive_sample/results/`

## Topology Construction

### Graph Extraction

`python main.py --phase 1 --conf_path ../data/conf_tdrive_sample.json --results_path ../data/tdrive_sample/results/`

### Link Generation

`python main.py --phase 2 --conf_path ../data/conf_tdrive_sample.json --results_path ../data/tdrive_sample/results/`

### Map Refinement

#### Map Matching

`python main.py --phase 3 --conf_path ../data/conf_tdrive_sample.json --results_path ../data/tdrive_sample/results/`

#### Edge Pruning

`python main.py --phase 4 --conf_path ../data/conf_tdrive_sample.json --results_path ../data/tdrive_sample/results/`

## Requirements

DeepMG uses the following dependencies with Python 3.6:

* gdal==2.3.2
* opencv==3.3.1
* rtree==0.8.3
* networkx==2.3
* scikit-image==0.16.2
* pytorch==1.1.0
* torchvision==0.3.0

Other packages can be easily installed using `conda install`, while the following scripts are recommended for `gdal`, `opencv` and `pytorch`.

`conda install -c conda-forge gdal==2.3.2`

`conda install -c menpo opencv==3.3.1`

`conda install pytorch==1.1.0 torchvision==0.3.0 cudatoolkit=9.0 -c pytorch -c defaults -c numba/label/dev`

Note that `gdal` must be installed first, and a restart might be required after all packages are installed.

## Datasets

* Trajectory Data
    * There are some open-source trajectory datasets, e.g., [TDrive](https://www.microsoft.com/en-us/research/publication/t-drive-trajectory-data-sample/).
    * Data should be organized as text files. A text file can contain several trajectories (refer to [sample.txt.template](https://github.com/sjruan/DeepMG/blob/master/data/tdrive_sample/traj/sample.txt.template)).

* Road Network Data
    * [OpenStreetMap (OSM)](https://www.openstreetmap.org/) is an open-source map data source.
    * The Shapefile road network format can be generated by [osm2rn](https://github.com/sjruan/osm2rn).
--------------------------------------------------------------------------------
/topology_construction/map_refinement.py:
--------------------------------------------------------------------------------
from tptk.common.road_network import store_rn_shp
from tptk.common.path import parse_path_file
import os
import copy


class MapRefiner:
    def __init__(self, min_sup):
        self.min_sup = min_sup
        self.pred_min_sup = 1

    def refine(self, linked_rn, mm_traj_path, final_rn_path):
        # count how many map-matched traversals support each edge, then drop
        # predicted edges with no support and virtual links below min_sup
        edge2cnt = self.get_edge2cnt(linked_rn, mm_traj_path)
        edges_to_del = []
        for u, v, edge_type in linked_rn.edges(data='type'):
            if edge_type == 'pred':
                if (u, v) not in edge2cnt or edge2cnt[(u, v)] < self.pred_min_sup:
                    edges_to_del.append((u, v))
            elif edge_type == 'virtual':
                if (u, v) not in edge2cnt or edge2cnt[(u, v)] < self.min_sup:
                    edges_to_del.append((u, v))
        final_rn = copy.deepcopy(linked_rn)
        print('edges&links to delete:{}'.format(len(edges_to_del)))
        for u, v in edges_to_del:
            final_rn.remove_edge(u, v)
        self.final_refine(final_rn)
        store_rn_shp(final_rn, final_rn_path)

    def get_edge2cnt(self, linked_rn, mm_result_path):
        edge2cnt = {}
        for filename in os.listdir(mm_result_path):
            paths = parse_path_file(os.path.join(mm_result_path, filename))
            for path in paths:
                for path_entity in path.path_entities:
                    edge = linked_rn.edge_idx[path_entity.eid]
                    if edge not in edge2cnt:
                        edge2cnt[edge] = 1
                    else:
                        edge2cnt[edge] += 1
        return edge2cnt

    def final_refine(self, rn):
        # make sure each vertex keeps at most 2 virtual links (the shortest ones)
        rn_undir = rn.to_undirected()
        link_cnt = 0
        links_to_delete = []
        for u, v, data in rn_undir.edges(data=True):
            if data['type'] != 'virtual':
                continue
            link_cnt += 1
            candi_links = []
            for x in rn_undir[u]:
                if rn_undir[x][u]['type'] == 'virtual':
                    candi_links.append((x, u))
            if len(candi_links) > 2:
                candi_links = sorted(candi_links, key=lambda k: rn_undir[k[0]][k[1]]['length'])
                links_to_delete.extend(candi_links[2:])
            candi_links = []
            for y in rn_undir[v]:
                if rn_undir[y][v]['type'] == 'virtual':
                    candi_links.append((y, v))
            if len(candi_links) > 2:
                candi_links = sorted(candi_links, key=lambda k: rn_undir[k[0]][k[1]]['length'])
                links_to_delete.extend(candi_links[2:])
        print('final links to delete:{}'.format(len(links_to_delete)))
        for link in links_to_delete:
            if rn.has_edge(link[0], link[1]):
                rn.remove_edge(link[0], link[1])
            if rn.has_edge(link[1], link[0]):
                rn.remove_edge(link[1], link[0])
--------------------------------------------------------------------------------
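A toy illustration of the support-based pruning rule (not in the repository; the counts stand in for map-matched traversals):

```python
# Toy sketch of MapRefiner's thresholding on a hand-built road network.
import networkx as nx

rn = nx.DiGraph()
rn.add_edge('a', 'b', type='pred')      # inferred road segment
rn.add_edge('b', 'c', type='virtual')   # generated candidate link
edge2cnt = {('a', 'b'): 3, ('b', 'c'): 2}

min_supp, pred_min_supp = 5, 1
to_del = [(u, v) for u, v, t in rn.edges(data='type')
          if edge2cnt.get((u, v), 0) < (pred_min_supp if t == 'pred' else min_supp)]
print(to_del)   # [('b', 'c')] -> the weakly supported virtual link is pruned
```
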
7 | """ 8 | 9 | def initialize(self, parser): 10 | parser = BaseOptions.initialize(self, parser) 11 | # visdom and HTML visualization parameters 12 | parser.add_argument('--display_freq', type=int, default=20, help='frequency of showing training results on screen') 13 | parser.add_argument('--display_ncols', type=int, default=4, help='if positive, display all images in a single visdom web panel with certain number of images per row.') 14 | parser.add_argument('--display_id', type=int, default=1, help='window id of the web display') 15 | parser.add_argument('--display_server', type=str, default="http://localhost", help='visdom server of the web display') 16 | parser.add_argument('--display_env', type=str, default='main', help='visdom display environment name (default is "main")') 17 | parser.add_argument('--display_port', type=int, default=8097, help='visdom port of the web display') 18 | parser.add_argument('--update_html_freq', type=int, default=40, help='frequency of saving training results to html') 19 | parser.add_argument('--print_freq', type=int, default=100, help='frequency of showing training results on console') 20 | parser.add_argument('--no_html', action='store_true', help='do not save intermediate training results to [opt.checkpoints_dir]/[opt.name]/web/') 21 | # network saving and loading parameters 22 | parser.add_argument('--save_latest_freq', type=int, default=5000, help='frequency of saving the latest results') 23 | parser.add_argument('--save_epoch_freq', type=int, default=5, help='frequency of saving checkpoints at the end of epochs') 24 | parser.add_argument('--save_by_iter', action='store_true', help='whether saves model by iteration') 25 | parser.add_argument('--continue_train', action='store_true', help='continue training: load the latest model') 26 | parser.add_argument('--epoch_count', type=int, default=1, help='the starting epoch count, we save the model by , +, ...') 27 | # training parameters 28 | parser.add_argument('--n_epochs', type=int, default=200, help='number of epochs of training') 29 | parser.add_argument('--sample_interval', type=int, default=20, help='validation interval') 30 | parser.add_argument('--beta1', type=float, default=0.5, help='momentum term of adam') 31 | parser.add_argument('--lr', type=float, default=0.0002, help='initial learning rate for adam') 32 | parser.add_argument('--pool_size', type=int, default=50, help='the size of image buffer that stores previously generated images') 33 | parser.add_argument('--lr_policy', type=str, default='step', help='learning rate policy. [linear | step | plateau | cosine]') 34 | parser.add_argument('--lr_decay_iters', type=int, default=10, help='multiply by a gamma every lr_decay_iters iterations') 35 | parser.add_argument('--phase', type=str, default='train') 36 | self.is_train = True 37 | return parser 38 | -------------------------------------------------------------------------------- /geometry_translation/utils/util.py: -------------------------------------------------------------------------------- 1 | """This module contains simple helper functions """ 2 | from __future__ import print_function 3 | import torch 4 | import numpy as np 5 | from PIL import Image 6 | import os 7 | 8 | 9 | def tensor2im(input_image, imtype=np.uint8): 10 | """"Converts a Tensor array into a numpy image array. 
11 | Parameters: 12 | input_image (tensor) -- the input image tensor array 13 | imtype (type) -- the desired type of the converted numpy array 14 | """ 15 | if not isinstance(input_image, np.ndarray): 16 | if isinstance(input_image, torch.Tensor): # get the data from a variable 17 | image_tensor = input_image.data 18 | else: 19 | return input_image 20 | image_numpy = image_tensor[0].cpu().float().numpy() # convert it into a numpy array 21 | if image_numpy.shape[0] == 1: # grayscale to RGB 22 | image_numpy = np.tile(image_numpy, (3, 1, 1)) 23 | image_numpy = (np.transpose(image_numpy, (1, 2, 0))) * 255.0 # post-processing: tranpose and scaling 24 | else: # if it is a numpy array, do nothing 25 | image_numpy = input_image 26 | return image_numpy.astype(imtype) 27 | 28 | 29 | def diagnose_network(net, name='network'): 30 | """Calculate and print the mean of average absolute(gradients) 31 | Parameters: 32 | net (torch network) -- Torch network 33 | name (str) -- the name of the network 34 | """ 35 | mean = 0.0 36 | count = 0 37 | for param in net.parameters(): 38 | if param.grad is not None: 39 | mean += torch.mean(torch.abs(param.grad.data)) 40 | count += 1 41 | if count > 0: 42 | mean = mean / count 43 | print(name) 44 | print(mean) 45 | 46 | 47 | def save_image(image_numpy, image_path): 48 | """Save a numpy image to the disk 49 | Parameters: 50 | image_numpy (numpy array) -- input numpy array 51 | image_path (str) -- the path of the image 52 | """ 53 | image_pil = Image.fromarray(image_numpy) 54 | image_pil.save(image_path) 55 | 56 | 57 | def print_numpy(x, val=True, shp=False): 58 | """Print the mean, min, max, median, std, and size of a numpy array 59 | Parameters: 60 | val (bool) -- if print the values of the numpy array 61 | shp (bool) -- if print the shape of the numpy array 62 | """ 63 | x = x.astype(np.float64) 64 | if shp: 65 | print('shape,', x.shape) 66 | if val: 67 | x = x.flatten() 68 | print('mean = %3.3f, min = %3.3f, max = %3.3f, median = %3.3f, std=%3.3f' % ( 69 | np.mean(x), np.min(x), np.max(x), np.median(x), np.std(x))) 70 | 71 | 72 | def mkdirs(paths): 73 | """create empty directories if they don't exist 74 | Parameters: 75 | paths (str list) -- a list of directory paths 76 | """ 77 | if isinstance(paths, list) and not isinstance(paths, str): 78 | for path in paths: 79 | mkdir(path) 80 | else: 81 | mkdir(paths) 82 | 83 | 84 | def mkdir(path): 85 | """create a single empty directory if it didn't exist 86 | Parameters: 87 | path (str) -- a single directory path 88 | """ 89 | if not os.path.exists(path): 90 | os.makedirs(path) 91 | -------------------------------------------------------------------------------- /geometry_translation/models/__init__.py: -------------------------------------------------------------------------------- 1 | """This package contains modules related to objective functions, optimizations, and network architectures. 2 | To add a custom model class called 'dummy', you need to add a file called 'dummy_model.py' and define a subclass DummyModel inherited from BaseModel. 3 | You need to implement the following five functions: 4 | -- <__init__>: initialize the class; first call BaseModel.__init__(self, opt). 5 | -- : unpack data from dataset and apply preprocessing. 6 | -- : produce intermediate results. 7 | -- : calculate loss, gradients, and update network weights. 8 | -- : (optionally) add model-specific options and set default options. 
9 | In the function <__init__>, you need to define four lists: 10 | -- self.loss_names (str list): specify the training losses that you want to plot and save. 11 | -- self.model_names (str list): define networks used in our training. 12 | -- self.visual_names (str list): specify the images that you want to display and save. 13 | -- self.optimizers (optimizer list): define and initialize optimizers. You can define one optimizer for each network. If two networks are updated at the same time, you can use itertools.chain to group them. See cycle_gan_model.py for an usage. 14 | Now you can use the model class by specifying flag '--model dummy'. 15 | See our template model class 'template_model.py' for more details. 16 | """ 17 | 18 | import importlib 19 | from geometry_translation.models.base_model import BaseModel 20 | 21 | 22 | def find_model_using_name(model_name): 23 | """Import the module "models/[model_name]_model.py". 24 | In the file, the class called DatasetNameModel() will 25 | be instantiated. It has to be a subclass of BaseModel, 26 | and it is case-insensitive. 27 | """ 28 | model_filename = "models." + model_name + "_model" 29 | modellib = importlib.import_module(model_filename) 30 | model = None 31 | target_model_name = model_name.replace('_', '') + 'model' 32 | for name, cls in modellib.__dict__.items(): 33 | if name.lower() == target_model_name.lower() \ 34 | and issubclass(cls, BaseModel): 35 | model = cls 36 | 37 | if model is None: 38 | print("In %s.py, there should be a subclass of BaseModel with class name that matches %s in lowercase." % (model_filename, target_model_name)) 39 | exit(0) 40 | 41 | return model 42 | 43 | 44 | def get_option_setter(model_name): 45 | """Return the static method of the model class.""" 46 | model_class = find_model_using_name(model_name) 47 | return model_class.modify_commandline_options 48 | 49 | 50 | def create_model(opt): 51 | """Create a model given the option. 52 | This function warps the class CustomDatasetDataLoader. 53 | This is the main interface between this package and 'train.py'/'test.py' 54 | Example: 55 | >>> from geometry_translation.models import create_model 56 | >>> model = create_model(opt) 57 | """ 58 | model = find_model_using_name(opt.model) 59 | instance = model(opt) 60 | print("model [%s] was created" % type(instance).__name__) 61 | return instance 62 | -------------------------------------------------------------------------------- /geometry_translation/utils/html.py: -------------------------------------------------------------------------------- 1 | import dominate 2 | from dominate.tags import meta, h3, table, tr, td, p, a, img, br 3 | import os 4 | 5 | 6 | class HTML: 7 | """This HTML class allows us to save images and write texts into a single HTML file. 8 | It consists of functions such as (add a text header to the HTML file), 9 | (add a row of images to the HTML file), and (save the HTML to the disk). 10 | It is based on Python library 'dominate', a Python library for creating and manipulating HTML documents using a DOM API. 11 | """ 12 | 13 | def __init__(self, web_dir, title, refresh=0): 14 | """Initialize the HTML classes 15 | Parameters: 16 | web_dir (str) -- a directory that stores the webpage. 
HTML file will be created at /index.html; images will be saved at 0: 30 | with self.doc.head: 31 | meta(http_equiv="refresh", content=str(refresh)) 32 | 33 | def get_image_dir(self): 34 | """Return the directory that stores images""" 35 | return self.img_dir 36 | 37 | def add_header(self, text): 38 | """Insert a header to the HTML file 39 | Parameters: 40 | text (str) -- the header text 41 | """ 42 | with self.doc: 43 | h3(text) 44 | 45 | def add_images(self, ims, txts, links, width=400): 46 | """add images to the HTML file 47 | Parameters: 48 | ims (str list) -- a list of image paths 49 | txts (str list) -- a list of image names shown on the website 50 | links (str list) -- a list of hyperref links; when you click an image, it will redirect you to a new page 51 | """ 52 | self.t = table(border=1, style="table-layout: fixed;") # Insert a table 53 | self.doc.add(self.t) 54 | with self.t: 55 | with tr(): 56 | for im, txt, link in zip(ims, txts, links): 57 | with td(style="word-wrap: break-word;", halign="center", valign="top"): 58 | with p(): 59 | with a(href=os.path.join('images', link)): 60 | img(style="width:%dpx" % width, src=os.path.join('images', im)) 61 | br() 62 | p(txt) 63 | 64 | def save(self): 65 | """save the current content to the HMTL file""" 66 | html_file = '%s/index.html' % self.web_dir 67 | f = open(html_file, 'wt') 68 | f.write(self.doc.render()) 69 | f.close() 70 | 71 | 72 | if __name__ == '__main__': # we show an example usage here. 73 | html = HTML('web/', 'test_html') 74 | html.add_header('hello world') 75 | 76 | ims, txts, links = [], [], [] 77 | for n in range(4): 78 | ims.append('image_%d.png' % n) 79 | txts.append('text_%d' % n) 80 | links.append('image_%d.png' % n) 81 | html.add_images(ims, txts, links) 82 | html.save() 83 | -------------------------------------------------------------------------------- /geometry_translation/train.py: -------------------------------------------------------------------------------- 1 | import sys 2 | sys.path.append('../') 3 | from geometry_translation.options.train_options import TrainOptions 4 | from geometry_translation.utils.visualizer import Visualizer 5 | from geometry_translation.models import create_model 6 | from geometry_translation.data_loader import get_data_loader 7 | from datetime import datetime 8 | import os 9 | 10 | 11 | class Trainer: 12 | def __init__(self, opt, model, train_dl, val_dl, visualizer): 13 | self.opt = opt 14 | self.model = model 15 | self.train_dl = train_dl 16 | self.val_dl = val_dl 17 | self.visualizer = visualizer 18 | 19 | def fit(self): 20 | best_f1_score = 0.0 21 | # training phase 22 | tot_iters = 0 23 | for epoch in range(1, self.opt.n_epochs + 1): 24 | print(f'epoch {epoch}/{self.opt.n_epochs}') 25 | ep_time = datetime.now() 26 | for i, data in enumerate(self.train_dl): 27 | self.model.train() 28 | self.model.set_input(data) 29 | iter_loss_all, iter_loss_road, iter_loss_cl, iter_metrics, iter_road_metrics = self.model.optimize_parameters() 30 | iter_metrics = iter_metrics.numpy() 31 | iter_road_metrics = iter_road_metrics.numpy() 32 | print("[Epoch %d/%d] [Batch %d/%d] [Loss: %f] [Road Loss: %f] [CL Loss: %f] [Precision: %f] [Recall: %f] [F1: %f] [Road IOU: %f] [CL IOU: %f]" % (epoch, opt.n_epochs, i, len(train_dl), iter_loss_all.item(), iter_loss_road.item(), iter_loss_cl.item(), iter_metrics[0], iter_metrics[1], iter_metrics[2], iter_road_metrics[3], iter_metrics[3])) 33 | tot_iters += 1 34 | 35 | if tot_iters % self.opt.display_freq == 0: # display images on visdom and save images to a 
HTML file 36 | save_result = tot_iters % self.opt.update_html_freq == 0 37 | model.compute_visuals() 38 | visualizer.display_current_results(model.get_current_visuals(), epoch, save_result) 39 | if tot_iters % self.opt.print_freq == 0: # print training losses 40 | losses = model.get_current_losses() 41 | if opt.display_id > 0: 42 | visualizer.plot_current_losses(epoch, i / len(self.train_dl), losses) 43 | 44 | # validating phase 45 | if tot_iters % opt.sample_interval == 0: 46 | self.model.eval() 47 | tot_loss = 0 48 | tot_metrics = 0 49 | tot_road_metrics = 0 50 | for i, data in enumerate(self.val_dl): 51 | self.model.set_input(data) 52 | _, iter_loss, iter_metrics, iter_road_metrics = self.model.test() 53 | tot_loss += iter_loss.item() 54 | tot_metrics += iter_metrics.numpy() 55 | tot_road_metrics += iter_road_metrics.numpy() 56 | tot_loss /= len(self.val_dl) 57 | tot_metrics /= len(self.val_dl) 58 | tot_road_metrics /= len(self.val_dl) 59 | if tot_metrics[2] > best_f1_score: 60 | best_f1_score = tot_metrics[2] 61 | self.model.save_networks('latest') 62 | self.model.save_networks(epoch) 63 | with open(os.path.join(opt.checkpoints_dir, opt.name, 'results.txt'), 'a') as f: 64 | f.write('epoch\t{}\titer\t{}\tloss\t{:.6f}\tprecision\t{:.4f}\trecall\t{:.4f}\tf1\t{:.4f}\troad_iou\t{:.4f}\tcl_iou\t{:.4f}\n'.format(epoch, tot_iters, tot_loss, tot_metrics[0], tot_metrics[1], tot_metrics[2], tot_road_metrics[3], tot_metrics[3])) 65 | f.close() 66 | print('=================time cost: {}==================='.format(datetime.now() - ep_time)) 67 | self.model.update_learning_rate() 68 | 69 | 70 | if __name__ == '__main__': 71 | opt = TrainOptions().parse() 72 | model = create_model(opt) 73 | model.setup(opt) 74 | train_dl = get_data_loader(opt.dataroot, 'train') 75 | val_dl = get_data_loader(opt.dataroot, 'val') 76 | visualizer = Visualizer(opt) 77 | trainer = Trainer(opt, model, train_dl, val_dl, visualizer) 78 | trainer.fit() 79 | -------------------------------------------------------------------------------- /geometry_translation/models/translator.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | class T2RNet_Trans(nn.Module): 6 | def __init__(self, input_nc, base_channels): 7 | super(T2RNet_Trans, self).__init__() 8 | self.max_pool = nn.MaxPool2d(2) 9 | 10 | self.conv1 = self.conv_stage(input_nc, base_channels) 11 | self.conv2 = self.conv_stage(base_channels, base_channels * 2) 12 | self.conv3 = self.conv_stage(base_channels * 2, base_channels * 4) 13 | self.conv4 = self.conv_stage(base_channels * 4, base_channels * 8) 14 | 15 | self.center = self.conv_stage(base_channels * 8, base_channels * 16) 16 | 17 | self.d1_conv4 = self.conv_stage(base_channels * 16, base_channels * 8) 18 | self.d1_conv3 = self.conv_stage(base_channels * 8, base_channels * 4) 19 | self.d1_conv2 = self.conv_stage(base_channels * 4, base_channels * 2) 20 | self.d1_conv1 = self.conv_stage(base_channels * 2, base_channels) 21 | 22 | self.d1_up4 = self.upsample(base_channels * 16, base_channels * 8) 23 | self.d1_up3 = self.upsample(base_channels * 8, base_channels * 4) 24 | self.d1_up2 = self.upsample(base_channels * 4, base_channels * 2) 25 | self.d1_up1 = self.upsample(base_channels * 2, base_channels) 26 | 27 | self.d1_conv_last = nn.Sequential( 28 | nn.Conv2d(64, 1, 3, 1, 1), 29 | nn.Sigmoid() 30 | ) 31 | 32 | self.d2_conv4 = self.conv_stage(base_channels * (16 + 8), base_channels * 8) 33 | self.d2_conv3 = self.conv_stage(base_channels 
* (8 + 4), base_channels * 4) 34 | self.d2_conv2 = self.conv_stage(base_channels * (4 + 2), base_channels * 2) 35 | self.d2_conv1 = self.conv_stage(base_channels * (2 + 1), base_channels) 36 | 37 | self.d2_up4 = self.upsample(base_channels * 16, base_channels * 8) 38 | self.d2_up3 = self.upsample(base_channels * 8, base_channels * 4) 39 | self.d2_up2 = self.upsample(base_channels * 4, base_channels * 2) 40 | self.d2_up1 = self.upsample(base_channels * 2, base_channels) 41 | 42 | self.d2_conv_last = nn.Sequential( 43 | nn.Conv2d(64, 1, 3, 1, 1), 44 | nn.Sigmoid() 45 | ) 46 | 47 | def upsample(self, ch_coarse, ch_fine): 48 | return nn.Sequential( 49 | nn.ConvTranspose2d(ch_coarse, ch_fine, 4, 2, 1, bias=False), 50 | nn.ReLU() 51 | ) 52 | 53 | def conv_stage(self, dim_in, dim_out, kernel_size=3, stride=1, padding=1, bias=True): 54 | return nn.Sequential( 55 | nn.Conv2d(dim_in, dim_out, kernel_size=kernel_size, 56 | stride=stride, padding=padding, bias=bias), 57 | nn.BatchNorm2d(dim_out), 58 | nn.ReLU(), 59 | nn.Conv2d(dim_out, dim_out, kernel_size=kernel_size, 60 | stride=stride, padding=padding, bias=bias), 61 | nn.BatchNorm2d(dim_out), 62 | nn.ReLU(), 63 | ) 64 | 65 | def forward(self, input): 66 | # conv1_out: 256x256x64 67 | conv1_out = self.conv1(input) 68 | # 128 x 128 69 | # conv2_out: 128x128x128 70 | conv2_out = self.conv2(self.max_pool(conv1_out)) 71 | # 64 x 64 72 | # conv3_out: 64x64x256 73 | conv3_out = self.conv3(self.max_pool(conv2_out)) 74 | # 32 * 32 75 | # conv4_out: 32x32x512 76 | conv4_out = self.conv4(self.max_pool(conv3_out)) 77 | # 16 * 16 78 | # out: 16x16x1024 79 | out = self.center(self.max_pool(conv4_out)) 80 | 81 | # d1_up4_out: 32 * 32 * 512 82 | d1_up4_out = self.d1_up4(out) 83 | out1 = self.d1_conv4(torch.cat((d1_up4_out, conv4_out), 1)) 84 | d1_up3_out = self.d1_up3(out1) 85 | out1 = self.d1_conv3(torch.cat((d1_up3_out, conv3_out), 1)) 86 | d1_up2_out = self.d1_up2(out1) 87 | out1 = self.d1_conv2(torch.cat((d1_up2_out, conv2_out), 1)) 88 | d1_up1_out = self.d1_up1(out1) 89 | out1 = self.d1_conv1(torch.cat((d1_up1_out, conv1_out), 1)) 90 | out1 = self.d1_conv_last(out1) 91 | 92 | out2 = self.d2_conv4(torch.cat((self.d2_up4(out), conv4_out, d1_up4_out), 1)) 93 | out2 = self.d2_conv3(torch.cat((self.d2_up3(out2), conv3_out, d1_up3_out), 1)) 94 | out2 = self.d2_conv2(torch.cat((self.d2_up2(out2), conv2_out, d1_up2_out), 1)) 95 | out2 = self.d2_conv1(torch.cat((self.d2_up1(out2), conv1_out, d1_up1_out), 1)) 96 | out2 = self.d2_conv_last(out2) 97 | return out1, out2 98 | -------------------------------------------------------------------------------- /topology_construction/main.py: -------------------------------------------------------------------------------- 1 | import sys 2 | sys.path.append('../') 3 | sys.path.append('../tptk/') 4 | import argparse 5 | import cv2 6 | from tptk.common.mbr import MBR 7 | from tptk.common.grid import Grid 8 | from tptk.common.road_network import load_rn_shp 9 | from topology_construction.graph_extraction import GraphExtractor 10 | from topology_construction.link_generation import LinkGenerator 11 | from topology_construction.custom_map_matching import CustomMapMatching 12 | from topology_construction.map_refinement import MapRefiner 13 | import json 14 | import os 15 | from multiprocessing import Pool 16 | from functools import partial 17 | 18 | 19 | def get_test_mbr(conf): 20 | dataset_conf = conf['dataset'] 21 | feature_extraction_conf = conf['feature_extraction'] 22 | min_lat, min_lng, max_lat, max_lng = dataset_conf['min_lat'], 
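A shape check of the two-decoder translator (not in the repository; `input_nc=19` matches the 11 spatial channels plus the 8-dim transition embedding used in t2rnet_model.py):

```python
# Sketch: verifying T2RNet_Trans output shapes on a dummy batch.
import torch
from geometry_translation.models.translator import T2RNet_Trans

net = T2RNet_Trans(input_nc=19, base_channels=64)
x = torch.randn(1, 19, 256, 256)
road, centerline = net(x)               # decoder 1: region, decoder 2: centerline
print(road.shape, centerline.shape)     # both torch.Size([1, 1, 256, 256])
```
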
/topology_construction/main.py:
--------------------------------------------------------------------------------
import sys
sys.path.append('../')
sys.path.append('../tptk/')
import argparse
import cv2
from tptk.common.mbr import MBR
from tptk.common.grid import Grid
from tptk.common.road_network import load_rn_shp
from topology_construction.graph_extraction import GraphExtractor
from topology_construction.link_generation import LinkGenerator
from topology_construction.custom_map_matching import CustomMapMatching
from topology_construction.map_refinement import MapRefiner
import json
import os
from multiprocessing import Pool
from functools import partial


def get_test_mbr(conf):
    """Convert the test tile window in the configuration into a lat/lng MBR."""
    dataset_conf = conf['dataset']
    feature_extraction_conf = conf['feature_extraction']
    min_lat, min_lng, max_lat, max_lng = dataset_conf['min_lat'], dataset_conf['min_lng'], \
                                         dataset_conf['max_lat'], dataset_conf['max_lng']
    whole_region_mbr = MBR(min_lat, min_lng, max_lat, max_lng)
    whole_region_grid = Grid(whole_region_mbr, dataset_conf['nb_rows'], dataset_conf['nb_cols'])
    test_row_min, test_col_min, test_row_max, test_col_max = feature_extraction_conf['test_tile_row_min'], \
                                                             feature_extraction_conf['test_tile_col_min'], \
                                                             feature_extraction_conf['test_tile_row_max'], \
                                                             feature_extraction_conf['test_tile_col_max']
    tile_pixel_size = feature_extraction_conf['tile_pixel_size']
    test_row_min_idx = test_row_min * tile_pixel_size
    test_row_max_idx = test_row_max * tile_pixel_size
    test_col_min_idx = test_col_min * tile_pixel_size
    test_col_max_idx = test_col_max * tile_pixel_size

    test_region_lower_left_mbr = whole_region_grid.get_mbr_by_matrix_idx(test_row_max_idx, test_col_min_idx)
    test_region_min_lat, test_region_min_lng = test_region_lower_left_mbr.max_lat, test_region_lower_left_mbr.min_lng
    test_region_upper_right_mbr = whole_region_grid.get_mbr_by_matrix_idx(test_row_min_idx, test_col_max_idx)
    test_region_max_lat, test_region_max_lng = test_region_upper_right_mbr.max_lat, test_region_upper_right_mbr.min_lng
    test_region_mbr = MBR(test_region_min_lat, test_region_min_lng, test_region_max_lat, test_region_max_lng)
    return test_region_mbr


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--phase', type=int, default=1, help='1,2,3,4')
    parser.add_argument('--conf_path', help='the configuration file path')
    parser.add_argument('--results_path', help='the path to the results directory')

    opt = parser.parse_args()
    print(opt)
    with open(opt.conf_path, 'r') as f:
        conf = json.load(f)
    results_path = opt.results_path
    traj_path = '../data/{}/traj/'.format(conf['dataset']['dataset_name'])
    extracted_rn_path = results_path + 'extracted_rn/'
    linked_rn_path = results_path + 'linked_rn/'
    mm_on_linked_rn_path = results_path + 'mm_on_linked_rn/'
    final_rn_path = results_path + 'final_rn/'
    os.makedirs(mm_on_linked_rn_path, exist_ok=True)
    topo_params = conf['topology_construction']
    # Graph Extraction
    if opt.phase == 1:
        mbr = get_test_mbr(conf)
        skeleton = cv2.imread(results_path + 'pred_thinned.png', cv2.IMREAD_GRAYSCALE)
        map_extractor = GraphExtractor(epsilon=10, min_road_dist=5)
        map_extractor.extract(skeleton, mbr, extracted_rn_path)
    # Link Generation
    elif opt.phase == 2:
        link_generator = LinkGenerator(radius=topo_params['link_radius'])
        # the extracted rn is undirected, while the linked rn is directed (bi-directional)
        extracted_rn = load_rn_shp(extracted_rn_path, is_directed=False)
        link_generator.generate(extracted_rn, linked_rn_path)
    # Custom Map Matching
    elif opt.phase == 3:
        custom_map_matching = CustomMapMatching(linked_rn_path, topo_params['alpha'])
        filenames = os.listdir(traj_path)
        with Pool() as pool:
            pool.map(partial(custom_map_matching.execute,
                             traj_path=traj_path, mm_result_path=mm_on_linked_rn_path), filenames)
    # Map Refinement
    elif opt.phase == 4:
        linked_rn = load_rn_shp(linked_rn_path, is_directed=True)
        map_refiner = MapRefiner(topo_params['min_supp'])
        map_refiner.refine(linked_rn, mm_on_linked_rn_path, final_rn_path)
    else:
        raise Exception('invalid phase')
--------------------------------------------------------------------------------
/geometry_translation/models/networks.py:
--------------------------------------------------------------------------------
import torch
from torch.nn import init
from torch.optim import lr_scheduler
from geometry_translation.models.translator import T2RNet_Trans


def define_translator(input_nc, output_nc, net_trans, init_type='normal', init_gain=0.02, gpu_ids=[]):
    if net_trans == 'T2RNet_Trans':
        net = T2RNet_Trans(input_nc, 64)
    else:
        raise NotImplementedError('Translator model name {} is not recognized'.format(net_trans))
    return init_net(net, init_type, init_gain, gpu_ids)


def get_scheduler(optimizer, opt):
    """Return a learning rate scheduler

    Parameters:
        optimizer          -- the optimizer of the network
        opt (option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions.
                              opt.lr_policy is the name of learning rate policy: linear | step | plateau | cosine

    For 'linear', we keep the same learning rate for the first <opt.niter> epochs
    and linearly decay the rate to zero over the next <opt.niter_decay> epochs.
    For other schedulers (step, plateau, and cosine), we use the default PyTorch schedulers.
    See https://pytorch.org/docs/stable/optim.html for more details.
    """
    if opt.lr_policy == 'linear':
        def lambda_rule(epoch):
            lr_l = 1.0 - max(0, epoch + opt.epoch_count - opt.niter) / float(opt.niter_decay + 1)
            return lr_l
        scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda_rule)
    elif opt.lr_policy == 'step':
        scheduler = lr_scheduler.StepLR(optimizer, step_size=opt.lr_decay_iters, gamma=0.5)
    elif opt.lr_policy == 'plateau':
        scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.2, threshold=0.01, patience=5)
    elif opt.lr_policy == 'cosine':
        scheduler = lr_scheduler.CosineAnnealingLR(optimizer, T_max=opt.niter, eta_min=0)
    else:
        raise NotImplementedError('learning rate policy [%s] is not implemented' % opt.lr_policy)
    return scheduler


def init_net(net, init_type='normal', init_gain=0.02, gpu_ids=[]):
    """Initialize a network: 1. register CPU/GPU device (with multi-GPU support); 2. initialize the network weights
    Parameters:
        net (network)      -- the network to be initialized
        init_type (str)    -- the name of an initialization method: normal | xavier | kaiming | orthogonal
        gain (float)       -- scaling factor for normal, xavier and orthogonal.
        gpu_ids (int list) -- which GPUs the network runs on: e.g., 0,1,2

    Return an initialized network.
    """
    if len(gpu_ids) > 0:
        assert(torch.cuda.is_available())
        net.to(gpu_ids[0])
        net = torch.nn.DataParallel(net, gpu_ids)  # multi-GPUs
    init_weights(net, init_type, init_gain=init_gain)
    return net


def init_weights(net, init_type='normal', init_gain=0.02):
    """Initialize network weights.

    Parameters:
        net (network)     -- network to be initialized
        init_type (str)   -- the name of an initialization method: normal | xavier | kaiming | orthogonal
        init_gain (float) -- scaling factor for normal, xavier and orthogonal.

    We use 'normal' in the original pix2pix and CycleGAN paper. But xavier and kaiming might
    work better for some applications. Feel free to try yourself.
    """
    def init_func(m):  # define the initialization function
        classname = m.__class__.__name__
        if hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1):
            if init_type == 'normal':
                init.normal_(m.weight.data, 0.0, init_gain)
            elif init_type == 'xavier':
                init.xavier_normal_(m.weight.data, gain=init_gain)
            elif init_type == 'kaiming':
                init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
            elif init_type == 'orthogonal':
                init.orthogonal_(m.weight.data, gain=init_gain)
            else:
                raise NotImplementedError('initialization method [%s] is not implemented' % init_type)
            if hasattr(m, 'bias') and m.bias is not None:
                init.constant_(m.bias.data, 0.0)
        elif classname.find('BatchNorm2d') != -1:  # BatchNorm Layer's weight is not a matrix; only normal distribution applies.
            init.normal_(m.weight.data, 1.0, init_gain)
            init.constant_(m.bias.data, 0.0)

    print('initialize network with %s' % init_type)
    net.apply(init_func)  # apply the initialization function
--------------------------------------------------------------------------------
/geometry_translation/models/t2rnet_model.py:
--------------------------------------------------------------------------------
from geometry_translation.models.networks import define_translator, init_net
from geometry_translation.models.base_model import BaseModel
from geometry_translation.models.losses import SoftDiceLoss
from geometry_translation.models.metrics import Metrics
import torch
import numpy as np
import torch.nn as nn
from itertools import chain


class T2RNetModel(BaseModel):

    @staticmethod
    def modify_commandline_options(parser, is_train=True):
        return parser

    def __init__(self, opt):
        BaseModel.__init__(self, opt)
        self.visual_names = []
        self.visual_names.append('traj_pt')
        self.visual_names.append('traj_line')
        self.visual_names.append('pred_rn_img')
        self.visual_names.append('pred_road_img')
        self.visual_names.append('real_rn')
        self.loss_names = ['centerline', 'region', 'all']
        self.model_names = ['Trans']
        input_nc = 11
        # embed the 128-channel transition view into 8 channels per pixel
        self.model_names.append('EM')
        transit_nc = 128
        transit_embedding_dim = 8
        self.netEM = nn.Sequential(
            nn.Linear(transit_nc, 64),
            nn.ReLU(),
            nn.Linear(64, transit_embedding_dim),
            nn.ReLU()
        )
        self.netEM = init_net(self.netEM, init_type='normal', init_gain=0.02, gpu_ids=opt.gpu_ids)
        input_nc += transit_embedding_dim

        # add the output of dilation
        self.netTrans = define_translator(input_nc, opt.output_nc, opt.net_trans, gpu_ids=opt.gpu_ids)
        self.metrics = Metrics()
        self.criterion = SoftDiceLoss()
        if opt.is_train:
            self.optimizer = torch.optim.Adam(
                chain(self.netEM.parameters(), self.netTrans.parameters()), lr=opt.lr, betas=(opt.beta1, 0.999))
            self.optimizers.append(self.optimizer)
        self.traj = None
        self.traj_pt = None
        self.traj_line = None
        self.transist = None
        self.real_rn = None
        self.pred_rn = None
        self.pred_rn_img = None
        self.pred_road_img = None
        self.loss_all = None
        self.loss_centerline = None
        self.loss_region = None
        self.real_road = None
        self.pred_road = None

    def set_input(self, inp):
        self.traj = inp['spatial_view'].to(self.device)
        self.real_rn = inp['centerline'].to(self.device)
        self.traj_pt = self.traj[:, 0:1, :, :]
        self.traj_line = self.traj[:, 1:2, :, :]
        self.transist = inp['transition_view'].to(self.device)
        self.real_road = inp['region'].to(self.device)
        self.image_paths = inp['img_path']

    def forward(self):
        embedded_trans = self.netEM(self.transist.reshape([-1, 256, 256, 128])).permute(0, 3, 1, 2)
        inp = torch.cat([self.traj, embedded_trans], 1)
        self.pred_road, self.pred_rn = self.netTrans(inp)
        self.pred_rn_img = T2RNetModel.pred2im(self.pred_rn)
        self.pred_road_img = T2RNetModel.pred2im(self.pred_road)

    def _backward(self):
        # lam trades off the auxiliary road-region loss against the centerline loss
        self.loss_centerline = self.criterion(self.real_rn, self.pred_rn)
        self.loss_region = self.criterion(self.real_road, self.pred_road)
        self.loss_all = (1 - self.opt.lam) * self.loss_centerline + self.opt.lam * self.loss_region
        self.loss_all.backward()

    def optimize_parameters(self):
        self.forward()
        metrics = self.metrics(self.pred_rn, self.real_rn)
        road_metrics = self.metrics(self.pred_road, self.real_road)

        self.optimizer.zero_grad()
        self._backward()
        self.optimizer.step()
        return self.loss_all, self.loss_region, self.loss_centerline, metrics, road_metrics

    def test(self):
        BaseModel.test(self)
        metrics = self.metrics(self.pred_rn, self.real_rn)
        road_metrics = self.metrics(self.pred_road, self.real_road)
        self.loss_centerline = self.criterion(self.real_rn, self.pred_rn)
        self.loss_region = self.criterion(self.real_road, self.pred_road)
        self.loss_all = (1 - self.opt.lam) * self.loss_centerline + self.opt.lam * self.loss_region
        return self.pred_rn, self.loss_all, metrics, road_metrics

    @staticmethod
    def pred2im(image_tensor):
        image_numpy = image_tensor[0].cpu().float().detach().numpy()  # convert it into a numpy array
        # handle sigmoid cases
        image_numpy[image_numpy > 0.5] = 1
        image_numpy[image_numpy <= 0.5] = 0
        image_numpy = np.tile(image_numpy, (3, 1, 1))
        image_numpy = (np.transpose(image_numpy, (1, 2, 0))) * 255
        return image_numpy
--------------------------------------------------------------------------------
/geometry_translation/models/base_model.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
from abc import ABC, abstractmethod
import os
from collections import OrderedDict
from geometry_translation.models import networks


class BaseModel(ABC):
    def __init__(self, opt):
        self.opt = opt
        self.gpu_ids = opt.gpu_ids
        self.save_dir = os.path.join(opt.checkpoints_dir, opt.name)
        self.device = torch.device('cuda:{}'.format(self.gpu_ids[0])) if self.gpu_ids else torch.device('cpu')
        self.is_train = opt.is_train
        self.model_names = []
        self.loss_names = []
        self.visual_names = []
        self.optimizers = []
        self.image_paths = []
        self.metric = 0  # used for learning rate policy 'plateau'

    @abstractmethod
    def set_input(self, inp):
        pass

    @abstractmethod
    def forward(self):
        pass

    @abstractmethod
    def optimize_parameters(self):
        pass

    def train(self):
        for name in self.model_names:
            if isinstance(name, str):
                net = getattr(self, 'net' + name)
                net.train()

    def eval(self):
        for name in self.model_names:
            if isinstance(name, str):
                net = getattr(self, 'net' + name)
                net.eval()

    def test(self):
        with torch.no_grad():
            self.forward()
            self.compute_visuals()

    def compute_visuals(self):
        """Calculate additional output images for visdom and HTML visualization"""
        pass

    def get_image_paths(self):
        """Return image paths that are used to load current data"""
        return self.image_paths

    def save_networks(self, epoch):
        for name in self.model_names:
            if isinstance(name, str):
                save_filename = '%s_net_%s.pth' % (epoch, name)
                save_path = os.path.join(self.save_dir, save_filename)
                net = getattr(self, 'net' + name)
                if len(self.gpu_ids) > 0 and torch.cuda.is_available():
                    torch.save(net.module.cpu().state_dict(), save_path)
                    net.cuda(self.gpu_ids[0])
                else:
                    torch.save(net.cpu().state_dict(), save_path)

    def setup(self, opt):
        """Load and print networks; create schedulers
        Parameters:
            opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions
        """
        if self.is_train:
            self.schedulers = [networks.get_scheduler(optimizer, opt) for optimizer in self.optimizers]
        if not self.is_train or opt.continue_train:
            load_suffix = 'iter_%d' % opt.load_iter if opt.load_iter > 0 else opt.epoch
            self.load_networks(load_suffix)
        self.print_networks(opt.verbose)

    def load_networks(self, epoch):
        for name in self.model_names:
            if isinstance(name, str):
                load_filename = '%s_net_%s.pth' % (epoch, name)
                load_path = os.path.join(self.save_dir, load_filename)
                net = getattr(self, 'net' + name)
                if isinstance(net, nn.DataParallel):
                    net = net.module
                print('loading the model from %s' % load_path)
                state_dict = torch.load(load_path, map_location=self.device)
                net.load_state_dict(state_dict)

    def print_networks(self, verbose):
        print('---------- Networks initialized -------------')
        for name in self.model_names:
            if isinstance(name, str):
                net = getattr(self, 'net' + name)
                num_params = 0
                for param in net.parameters():
                    num_params += param.numel()
                if verbose:
                    print(net)
                print('[Network %s] Total number of parameters : %.3f M' % (name, num_params / 1e6))
        print('-----------------------------------------------')

    def get_current_losses(self):
        """Return training losses / errors. train.py will print out these errors on console, and save them to a file"""
        errors_ret = OrderedDict()
        for name in self.loss_names:
            if isinstance(name, str):
                errors_ret[name] = float(getattr(self, 'loss_' + name))
        return errors_ret

    def get_current_visuals(self):
        """Return visualization images. train.py will display these images with visdom, and save the images to an HTML file"""
        visual_ret = OrderedDict()
        for name in self.visual_names:
            if isinstance(name, str):
                visual_ret[name] = getattr(self, name)
        return visual_ret

    def update_learning_rate(self):
        """Update learning rates for all the networks; called at the end of every epoch"""
        for scheduler in self.schedulers:
            if self.opt.lr_policy == 'plateau':
                scheduler.step(self.metric)
            else:
                scheduler.step()
        lr = self.optimizers[0].param_groups[0]['lr']
        print('learning rate = %.7f' % lr)
--------------------------------------------------------------------------------
/geometry_translation/options/base_options.py:
--------------------------------------------------------------------------------
import argparse
import os
import torch
from geometry_translation import models


class BaseOptions:
    """This class defines options used during both training and test time.
    It also implements several helper functions such as parsing, printing, and saving the options.
    It also gathers additional options defined in <modify_commandline_options> functions in both dataset class and model class.
    """

    def __init__(self):
        """Reset the class; indicates the class hasn't been initialized"""
        self.is_train = True
        self.initialized = False

    def initialize(self, parser):
        """Define the common options that are used in both training and test."""
        # basic parameters
        parser.add_argument('--dataroot', required=True, help='path to images (should have subfolders trainA, trainB, valA, valB, etc)')
        parser.add_argument('--name', type=str, default='experiment_name', help='name of the experiment. It decides where to store samples and models')
        parser.add_argument('--gpu_ids', type=str, default='0', help='gpu ids: e.g. 0  0,1,2, 0,2. use -1 for CPU')
        parser.add_argument('--checkpoints_dir', type=str, default='./checkpoints', help='models are saved here')
        # model parameters
        parser.add_argument('--model', type=str, default='t2rnet', help='chooses which model to use.')
        parser.add_argument('--net_trans', type=str, default='T2RNet_Trans')
        parser.add_argument('--lam', type=float, default=0.2)

        # parser.add_argument('--input_nc', type=int, default=11, help='# of input image channels: 3 for RGB and 1 for grayscale')
        parser.add_argument('--output_nc', type=int, default=1, help='# of output image channels: 3 for RGB and 1 for grayscale')
        parser.add_argument('--norm', type=str, default='instance', help='instance normalization or batch normalization [instance | batch | none]')
        parser.add_argument('--init_type', type=str, default='normal', help='network initialization [normal | xavier | kaiming | orthogonal]')
        parser.add_argument('--init_gain', type=float, default=0.02, help='scaling factor for normal, xavier and orthogonal.')
        parser.add_argument('--no_dropout', action='store_true', help='no dropout for the generator')

        parser.add_argument('--num_threads', default=4, type=int, help='# threads for loading data')
        parser.add_argument('--batch_size', type=int, default=4, help='input batch size')
        parser.add_argument('--max_dataset_size', type=int, default=float("inf"), help='Maximum number of samples allowed per dataset.
If the dataset directory contains more than max_dataset_size, only a subset is loaded.') 40 | parser.add_argument('--display_winsize', type=int, default=256, help='display window size for both visdom and HTML') 41 | # additional parameters 42 | parser.add_argument('--epoch', type=str, default='latest', help='which epoch to load? set to latest to use latest cached model') 43 | parser.add_argument('--load_iter', type=int, default='0', help='which iteration to load? if load_iter > 0, the code will load models by iter_[load_iter]; otherwise, the code will load models by [epoch]') 44 | parser.add_argument('--verbose', action='store_true', help='if specified, print more debugging information') 45 | parser.add_argument('--suffix', default='', type=str, help='customized suffix: opt.name = opt.name + suffix: e.g., {model}_{netG}_size{load_size}') 46 | self.initialized = True 47 | return parser 48 | 49 | def gather_options(self): 50 | """Initialize our parser with basic options(only once). 51 | Add additional model-specific and dataset-specific options. 52 | These options are defined in the function 53 | in model and dataset classes. 54 | """ 55 | if not self.initialized: # check if it has been initialized 56 | parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) 57 | parser = self.initialize(parser) 58 | 59 | # get the basic options 60 | opt, _ = parser.parse_known_args() 61 | 62 | # modify model-related parser options 63 | model_name = opt.model 64 | model_option_setter = models.get_option_setter(model_name) 65 | parser = model_option_setter(parser, self.is_train) 66 | opt, _ = parser.parse_known_args() # parse again with new defaults 67 | 68 | # save and return the parser 69 | self.parser = parser 70 | return parser.parse_args() 71 | 72 | def print_options(self, opt): 73 | """Print and save options 74 | It will print both current options and default values(if different). 
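For example, with the '{:>25}: {:<30}{}' row format used below, an option changed on the command line prints with its default appended (illustrative values):
    batch_size: 8                              [default: 4]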
75 | It will save options into a text file / [checkpoints_dir] / opt.txt 76 | """ 77 | message = '' 78 | message += '----------------- Options ---------------\n' 79 | for k, v in sorted(vars(opt).items()): 80 | comment = '' 81 | default = self.parser.get_default(k) 82 | if v != default: 83 | comment = '\t[default: %s]' % str(default) 84 | message += '{:>25}: {:<30}{}\n'.format(str(k), str(v), comment) 85 | message += '----------------- End -------------------' 86 | print(message) 87 | 88 | # save to the disk 89 | expr_dir = os.path.join(opt.checkpoints_dir, opt.name) 90 | os.makedirs(expr_dir, exist_ok=True) 91 | file_name = os.path.join(expr_dir, '{}_opt.txt'.format(opt.phase)) 92 | with open(file_name, 'wt') as opt_file: 93 | opt_file.write(message) 94 | opt_file.write('\n') 95 | 96 | def parse(self): 97 | """Parse our options, create checkpoints directory suffix, and set up gpu device.""" 98 | opt = self.gather_options() 99 | opt.is_train = self.is_train # train or test 100 | 101 | # process opt.suffix 102 | if opt.suffix: 103 | suffix = ('_' + opt.suffix.format(**vars(opt))) if opt.suffix != '' else '' 104 | opt.name = opt.name + suffix 105 | 106 | self.print_options(opt) 107 | 108 | # set gpu ids 109 | str_ids = opt.gpu_ids.split(',') 110 | opt.gpu_ids = [] 111 | for str_id in str_ids: 112 | id = int(str_id) 113 | if id >= 0: 114 | opt.gpu_ids.append(id) 115 | if len(opt.gpu_ids) > 0: 116 | torch.cuda.set_device(opt.gpu_ids[0]) 117 | 118 | self.opt = opt 119 | return self.opt 120 | -------------------------------------------------------------------------------- /geometry_translation/utils/visualizer.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os 3 | import sys 4 | import ntpath 5 | import time 6 | from . import util, html 7 | from subprocess import Popen, PIPE 8 | # from scipy.misc import imresize 9 | from PIL import Image 10 | 11 | if sys.version_info[0] == 2: 12 | VisdomExceptionBase = Exception 13 | else: 14 | VisdomExceptionBase = ConnectionError 15 | 16 | 17 | def save_images(webpage, visuals, image_path, aspect_ratio=1.0, width=256): 18 | """Save images to the disk. 19 | Parameters: 20 | webpage (the HTML class) -- the HTML webpage class that stores these imaegs (see html.py for more details) 21 | visuals (OrderedDict) -- an ordered dictionary that stores (name, images (either tensor or numpy) ) pairs 22 | image_path (str) -- the string is used to create image paths 23 | aspect_ratio (float) -- the aspect ratio of saved images 24 | width (int) -- the images will be resized to width x width 25 | This function will save images stored in 'visuals' to the HTML file specified by 'webpage'. 
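    A minimal usage sketch (the 'model' and 'web_dir' names are assumptions for illustration, not arguments of this function):
        webpage = html.HTML(web_dir, 'experiment results')
        visuals = model.get_current_visuals()  # OrderedDict of name -> image tensor
        save_images(webpage, visuals, model.get_image_paths(), width=256)
        webpage.save()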
26 | """
27 | image_dir = webpage.get_image_dir()
28 | short_path = ntpath.basename(image_path[0])
29 | name = os.path.splitext(short_path)[0]
30 | 
31 | webpage.add_header(name)
32 | ims, txts, links = [], [], []
33 | 
34 | for label, im_data in visuals.items():
35 | im = util.tensor2im(im_data)
36 | image_name = '%s_%s.png' % (name, label)
37 | save_path = os.path.join(image_dir, image_name)
38 | h, w, _ = im.shape
39 | if aspect_ratio > 1.0:
40 | # im = imresize(im, (h, int(w * aspect_ratio)), interp='bicubic')
41 | im = np.array(Image.fromarray(im).resize((h, int(w * aspect_ratio))))
42 | if aspect_ratio < 1.0:
43 | # im = imresize(im, (int(h / aspect_ratio), w), interp='bicubic')
44 | im = np.array(Image.fromarray(im).resize((int(h / aspect_ratio), w)))
45 | util.save_image(im, save_path)
46 | 
47 | ims.append(image_name)
48 | txts.append(label)
49 | links.append(image_name)
50 | webpage.add_images(ims, txts, links, width=width)
51 | 
52 | 
53 | class Visualizer():
54 | """This class includes several functions that can display/save images and print/save logging information.
55 | It uses a Python library 'visdom' for display, and a Python library 'dominate' (wrapped in 'HTML') for creating HTML files with images.
56 | """
57 | 
58 | def __init__(self, opt):
59 | """Initialize the Visualizer class
60 | Parameters:
61 | opt -- stores all the experiment flags; needs to be a subclass of BaseOptions
62 | Step 1: Cache the training/test options
63 | Step 2: connect to a visdom server
64 | Step 3: create an HTML object for saving HTML files
65 | Step 4: create a logging file to store training losses
66 | """
67 | self.opt = opt # cache the option
68 | self.display_id = opt.display_id
69 | self.use_html = opt.is_train and not opt.no_html
70 | self.win_size = opt.display_winsize
71 | self.name = opt.name
72 | self.port = opt.display_port
73 | self.saved = False
74 | if self.display_id > 0: # connect to a visdom server given <display_port> and <display_server>
75 | import visdom
76 | self.ncols = opt.display_ncols
77 | self.vis = visdom.Visdom(server=opt.display_server, port=opt.display_port, env=opt.display_env)
78 | if not self.vis.check_connection():
79 | self.create_visdom_connections()
80 | 
81 | if self.use_html: # create an HTML object at <checkpoints_dir>/web/; images will be saved under <checkpoints_dir>/web/images/
82 | self.web_dir = os.path.join(opt.checkpoints_dir, opt.name, 'web')
83 | self.img_dir = os.path.join(self.web_dir, 'images')
84 | print('create web directory %s...' % self.web_dir)
85 | util.mkdirs([self.web_dir, self.img_dir])
86 | # create a logging file to store training losses
87 | self.log_name = os.path.join(opt.checkpoints_dir, opt.name, 'loss_log.txt')
88 | with open(self.log_name, "a") as log_file:
89 | now = time.strftime("%c")
90 | log_file.write('================ Training Loss (%s) ================\n' % now)
91 | 
92 | def reset(self):
93 | """Reset the self.saved status"""
94 | self.saved = False
95 | 
96 | def create_visdom_connections(self):
97 | """If the program could not connect to Visdom server, this function will start a new server at port <self.port>"""
98 | cmd = sys.executable + ' -m visdom.server -p %d &>/dev/null &' % self.port
99 | print('\n\nCould not connect to Visdom server. \n Trying to start a server....')
100 | print('Command: %s' % cmd)
101 | Popen(cmd, shell=True, stdout=PIPE, stderr=PIPE)
102 | 
103 | def display_current_results(self, visuals, epoch, save_result):
104 | """Display current results on visdom; save current results to an HTML file.
105 | Parameters:
106 | visuals (OrderedDict) - - dictionary of images to display or save
107 | epoch (int) - - the current epoch
108 | save_result (bool) - - if save the current results to an HTML file
109 | """
110 | if self.display_id > 0: # show images in the browser using visdom
111 | ncols = self.ncols
112 | if ncols > 0: # show all the images in one visdom panel
113 | ncols = min(ncols, len(visuals))
114 | h, w = next(iter(visuals.values())).shape[:2]
115 | table_css = """<style>
116 | table {border-collapse: separate; border-spacing: 4px; white-space: nowrap; text-align: center}
117 | table td {width: %dpx; height: %dpx; padding: 4px; outline: 4px solid black}
118 | </style>""" % (w, h) # create a table css
119 | # create a table of images.
120 | title = self.name
121 | label_html = ''
122 | label_html_row = ''
123 | images = []
124 | idx = 0
125 | for label, image in visuals.items():
126 | image_numpy = util.tensor2im(image)
127 | label_html_row += '<td>%s</td>' % label
128 | images.append(image_numpy.transpose([2, 0, 1]))
129 | idx += 1
130 | if idx % ncols == 0:
131 | label_html += '<tr>%s</tr>' % label_html_row
132 | label_html_row = ''
133 | white_image = np.ones_like(image_numpy.transpose([2, 0, 1])) * 255
134 | while idx % ncols != 0:
135 | images.append(white_image)
136 | label_html_row += '<td></td>'
137 | idx += 1
138 | if label_html_row != '':
139 | label_html += '<tr>%s</tr>' % label_html_row
140 | try:
141 | self.vis.images(images, nrow=ncols, win=self.display_id + 1,
142 | padding=2, opts=dict(title=title + ' images'))
143 | label_html = '<table>%s</table>
' % label_html 144 | self.vis.text(table_css + label_html, win=self.display_id + 2, 145 | opts=dict(title=title + ' labels')) 146 | except VisdomExceptionBase: 147 | self.create_visdom_connections() 148 | 149 | else: # show each image in a separate visdom panel; 150 | idx = 1 151 | try: 152 | for label, image in visuals.items(): 153 | image_numpy = util.tensor2im(image) 154 | self.vis.image(image_numpy.transpose([2, 0, 1]), opts=dict(title=label), 155 | win=self.display_id + idx) 156 | idx += 1 157 | except VisdomExceptionBase: 158 | self.create_visdom_connections() 159 | 160 | if self.use_html and (save_result or not self.saved): # save images to an HTML file if they haven't been saved. 161 | self.saved = True 162 | # save images to the disk 163 | for label, image in visuals.items(): 164 | image_numpy = util.tensor2im(image) 165 | img_path = os.path.join(self.img_dir, 'epoch%.3d_%s.png' % (epoch, label)) 166 | util.save_image(image_numpy, img_path) 167 | 168 | # update website 169 | webpage = html.HTML(self.web_dir, 'Experiment name = %s' % self.name, refresh=1) 170 | for n in range(epoch, 0, -1): 171 | webpage.add_header('epoch [%d]' % n) 172 | ims, txts, links = [], [], [] 173 | 174 | for label, image_numpy in visuals.items(): 175 | image_numpy = util.tensor2im(image) 176 | img_path = 'epoch%.3d_%s.png' % (n, label) 177 | ims.append(img_path) 178 | txts.append(label) 179 | links.append(img_path) 180 | webpage.add_images(ims, txts, links, width=self.win_size) 181 | webpage.save() 182 | 183 | def plot_current_losses(self, epoch, counter_ratio, losses): 184 | """display the current losses on visdom display: dictionary of error labels and values 185 | Parameters: 186 | epoch (int) -- current epoch 187 | counter_ratio (float) -- progress (percentage) in the current epoch, between 0 to 1 188 | losses (OrderedDict) -- training losses stored in the format of (name, float) pairs 189 | """ 190 | if not hasattr(self, 'plot_data'): 191 | self.plot_data = {'X': [], 'Y': [], 'legend': list(losses.keys())} 192 | self.plot_data['X'].append(epoch + counter_ratio) 193 | self.plot_data['Y'].append([losses[k] for k in self.plot_data['legend']]) 194 | try: 195 | self.vis.line( 196 | X=np.stack([np.array(self.plot_data['X'])] * len(self.plot_data['legend']), 1), 197 | Y=np.array(self.plot_data['Y']), 198 | opts={ 199 | 'title': self.name + ' loss over time', 200 | 'legend': self.plot_data['legend'], 201 | 'xlabel': 'epoch', 202 | 'ylabel': 'loss'}, 203 | win=self.display_id) 204 | except VisdomExceptionBase: 205 | self.create_visdom_connections() 206 | 207 | # losses: same format as |losses| of plot_current_losses 208 | def print_current_losses(self, epoch, iters, losses, t_comp, t_data): 209 | """print current losses on console; also save the losses to the disk 210 | Parameters: 211 | epoch (int) -- current epoch 212 | iters (int) -- current training iteration during this epoch (reset to 0 at the end of every epoch) 213 | losses (OrderedDict) -- training losses stored in the format of (name, float) pairs 214 | t_comp (float) -- computational time per data point (normalized by batch_size) 215 | t_data (float) -- data loading time per data point (normalized by batch_size) 216 | """ 217 | message = '(epoch: %d, iters: %d, time: %.3f, data: %.3f) ' % (epoch, iters, t_comp, t_data) 218 | for k, v in losses.items(): 219 | message += '%s: %.3f ' % (k, v) 220 | 221 | print(message) # print the message 222 | with open(self.log_name, "a") as log_file: 223 | log_file.write('%s\n' % message) # save the message 224 | 
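# A minimal sketch of how the Visualizer methods above are typically driven from a
# training loop. The names 'opt.display_freq', 'opt.print_freq', 'n_epochs', 't_comp'
# and 't_data' are assumptions for illustration; the real flags and timings are set up
# by the options classes and train.py.
#
# visualizer = Visualizer(opt)
# for epoch in range(1, n_epochs + 1):
#     visualizer.reset()  # allow one HTML result save per epoch
#     for i, data in enumerate(dataset):
#         model.set_input(data)
#         model.optimize_parameters()
#         if i % opt.display_freq == 0:
#             visualizer.display_current_results(model.get_current_visuals(), epoch, save_result=True)
#         if i % opt.print_freq == 0:
#             losses = model.get_current_losses()
#             visualizer.plot_current_losses(epoch, i / len(dataset), losses)
#             visualizer.print_current_losses(epoch, i, losses, t_comp, t_data)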
-------------------------------------------------------------------------------- /topology_construction/graph_extraction.py: -------------------------------------------------------------------------------- 1 | import networkx as nx 2 | from osgeo import ogr 3 | import numpy as np 4 | import sys 5 | import itertools 6 | from collections import deque 7 | from rtree import Rtree 8 | from tptk.common.douglas_peucker import DouglasPeucker 9 | from tptk.common.spatial_func import SPoint 10 | from tptk.common.mbr import MBR 11 | from tptk.common.grid import Grid 12 | from tptk.common.spatial_func import distance 13 | 14 | 15 | class CenterNodePixel: 16 | def __init__(self, node_pixels): 17 | self.node_pixels = node_pixels 18 | self.i = sum([node_pixel[0] for node_pixel in node_pixels]) / len(node_pixels) 19 | self.j = sum([node_pixel[1] for node_pixel in node_pixels]) / len(node_pixels) 20 | 21 | def center_pixel(self): 22 | return int(self.i), int(self.j) 23 | 24 | 25 | class GraphExtractor: 26 | INVALID = -1 27 | BLANK = 0 28 | EDGE = 1 29 | NODE = 2 30 | VISITED_NODE = 3 31 | 32 | def __init__(self, epsilon, min_road_dist): 33 | self.segment_simplifier = DouglasPeucker(epsilon) 34 | self.min_road_dist = min_road_dist 35 | 36 | def extract(self, skeleton, mbr, target_path): 37 | assert skeleton.ndim == 2, 'the grey scale skeleton should only have 1 channel' 38 | 39 | # pad zero for safety 40 | grid = Grid(mbr, skeleton.shape[0], skeleton.shape[1]) 41 | eval_mbr_padded = MBR(mbr.min_lat - 2 * grid.lat_interval, mbr.min_lng - 2 * grid.lng_interval, 42 | mbr.max_lat + 2 * grid.lat_interval, mbr.max_lng + 2 * grid.lng_interval) 43 | mbr = eval_mbr_padded 44 | skeleton_padded = np.zeros((skeleton.shape[0] + 4, skeleton.shape[1] + 4)) 45 | skeleton_padded[2:skeleton.shape[0] + 2, 2:skeleton.shape[1] + 2] = skeleton 46 | skeleton = skeleton_padded 47 | 48 | nb_rows, nb_cols = skeleton.shape 49 | yscale = nb_rows / (mbr.max_lat - mbr.min_lat) 50 | xscale = nb_cols / (mbr.max_lng - mbr.min_lng) 51 | 52 | sys.stdout.write("Identifying road nodes pixels... ") 53 | sys.stdout.flush() 54 | binary_skeleton = (skeleton > 0).astype('int32') 55 | status_matrix = self.identify_node_pixels(binary_skeleton) 56 | sys.stdout.write("done.\n") 57 | sys.stdout.flush() 58 | 59 | sys.stdout.write("Detecting road network component... ") 60 | nodes, segments = self.detect_rn_component(status_matrix) 61 | sys.stdout.write("done.\n") 62 | sys.stdout.flush() 63 | 64 | sys.stdout.write("Constructing and saving road network... ") 65 | sys.stdout.flush() 66 | self.construct_undirected_rn(nodes, segments, target_path, nb_rows, xscale, yscale, mbr) 67 | sys.stdout.write("done.\n") 68 | sys.stdout.flush() 69 | 70 | def identify_node_pixels(self, skeleton): 71 | """ 72 | 22 23 08 09 10 73 | 21 07 00 01 11 74 | 20 06 -1 02 12 75 | 19 05 04 03 13 76 | 18 17 16 15 14 77 | :param skeleton: 78 | :return: 79 | """ 80 | status_matrix = np.copy(skeleton) 81 | road_pixels = np.where(status_matrix == GraphExtractor.EDGE) 82 | nb_road_pixels = len(road_pixels[0]) 83 | print('\n# of road pixels:{}'.format(nb_road_pixels)) 84 | cnt = 1 85 | for i, j in zip(road_pixels[0], road_pixels[1]): 86 | if (cnt % 100 == 0) or (cnt == nb_road_pixels): 87 | sys.stdout.write("\r" + str(cnt) + "/" + str(nb_road_pixels) + "... 
") 88 | sys.stdout.flush() 89 | cnt += 1 90 | # skip boundary 91 | if i < 2 or i >= status_matrix.shape[0] - 2 or j < 2 or j >= status_matrix.shape[1] - 2: 92 | continue 93 | p = [skeleton[i - 1][j], skeleton[i - 1][j + 1], skeleton[i][j + 1], skeleton[i + 1][j + 1], 94 | skeleton[i + 1][j], skeleton[i + 1][j - 1], skeleton[i][j - 1], skeleton[i - 1][j - 1], 95 | skeleton[i - 2][j], skeleton[i - 2][j + 1], skeleton[i - 2][j + 2], skeleton[i - 1][j + 2], 96 | skeleton[i][j + 2], skeleton[i + 1][j + 2], skeleton[i + 2][j + 2], skeleton[i + 2][j + 1], 97 | skeleton[i + 2][j], skeleton[i + 2][j - 1], skeleton[i + 2][j - 2], skeleton[i + 1][j - 2], 98 | skeleton[i][j - 2], skeleton[i - 1][j - 2], skeleton[i - 2][j - 2], skeleton[i - 2][j - 1]] 99 | fringe = [bool(p[8] and bool(p[7] or p[0] or p[1])), 100 | bool(p[9] and bool(p[0] or p[1])), 101 | bool(p[10] and p[1]), 102 | bool(p[11] and bool(p[1] or p[2])), 103 | bool(p[12] and bool(p[1] or p[2] or p[3])), 104 | bool(p[13] and bool(p[2] or p[3])), 105 | bool(p[14] and p[3]), 106 | bool(p[15] and bool(p[3] or p[4])), 107 | bool(p[16] and bool(p[3] or p[4] or p[5])), 108 | bool(p[17] and bool(p[4] or p[5])), 109 | bool(p[18] and p[5]), 110 | bool(p[19] and bool(p[5] or p[6])), 111 | bool(p[20] and bool(p[5] or p[6] or p[7])), 112 | bool(p[21] and bool(p[6] or p[7])), 113 | bool(p[22] and p[7]), 114 | bool(p[23] and bool(p[7] or p[0]))] 115 | connected_component_cnt = 0 116 | for k in range(0, len(fringe)): 117 | connected_component_cnt += int(not bool(fringe[k]) and bool(fringe[(k + 1) % len(fringe)])) 118 | if connected_component_cnt == 0: 119 | status_matrix[i][j] = GraphExtractor.BLANK 120 | elif (connected_component_cnt == 1) or (connected_component_cnt > 2): 121 | status_matrix[i][j] = GraphExtractor.NODE 122 | # if connected_component_cnt == 2, we think it is a normal internal node 123 | return status_matrix 124 | 125 | def detect_rn_component(self, status_matrix): 126 | node_pixels = np.where(status_matrix == GraphExtractor.NODE) 127 | nb_node_pixels = len(node_pixels[0]) 128 | print('\n# of node pixels:{}'.format(nb_node_pixels)) 129 | neighbor_deltas = [dxdy for dxdy in itertools.product([-1, 0, 1], [-1, 0, 1]) 130 | if dxdy[0] != 0 or dxdy[1] != 0] 131 | # node pixel -> center node 132 | nodes = {} 133 | node_pixel_spatial_index = Rtree() 134 | # [node pixel sequence (start and end must be node pixel)] 135 | connected_segments = [] 136 | cnt = 1 137 | center_nodes = [] 138 | node_pixel_id = 0 139 | for i, j in zip(node_pixels[0], node_pixels[1]): 140 | if (cnt % 100 == 0) or (cnt == nb_node_pixels): 141 | sys.stdout.write("\r" + str(cnt) + "/" + str(nb_node_pixels) + "... 
") 142 | sys.stdout.flush() 143 | cnt += 1 144 | if status_matrix[i][j] == GraphExtractor.VISITED_NODE: 145 | continue 146 | # region merge neighbor node pixels 147 | status_matrix[i][j] = GraphExtractor.VISITED_NODE 148 | candidates = [(i, j)] 149 | node_pixels = [] 150 | while len(candidates) > 0: 151 | node_pixel = candidates.pop() 152 | node_pixels.append(node_pixel) 153 | m, n = node_pixel 154 | for dm, dn in neighbor_deltas: 155 | if status_matrix[m + dm][n + dn] == GraphExtractor.NODE: 156 | status_matrix[m + dm][n + dn] = GraphExtractor.VISITED_NODE 157 | candidates.append((m + dm, n + dn)) 158 | center_node = CenterNodePixel(node_pixels) 159 | center_nodes.append(center_node) 160 | for node_pixel in node_pixels: 161 | nodes[node_pixel] = center_node 162 | node_pixel_spatial_index.insert(node_pixel_id, node_pixel, obj=node_pixel) 163 | node_pixel_id += 1 164 | # endregion 165 | 166 | # region find neighbor segments 167 | # mask current nodes, make sure the edge doesn't return to itself 168 | for m, n in node_pixels: 169 | status_matrix[m][n] = GraphExtractor.INVALID 170 | # find new road segment of the current node in each possible direction 171 | for node_pixel in node_pixels: 172 | connected_segment = self.detect_connected_segment(status_matrix, node_pixel) 173 | if connected_segment is not None: 174 | connected_segments.append(connected_segment) 175 | # restore masked nodes 176 | for m, n in node_pixels: 177 | status_matrix[m][n] = GraphExtractor.VISITED_NODE 178 | # endregion 179 | print('\n# of directly connected segments:{}'.format(len(connected_segments))) 180 | 181 | # there might be few edge pixels left, that should be fine 182 | nb_unprocessed_edge_pixels = np.sum(status_matrix[status_matrix == GraphExtractor.EDGE]) 183 | print('unprocessed edge pixels:{}'.format(nb_unprocessed_edge_pixels)) 184 | 185 | print('# of nodes:{}'.format(len(center_nodes))) 186 | print('# of segments:{}'.format(len(connected_segments))) 187 | return nodes, connected_segments 188 | 189 | def detect_connected_segment(self, status_matrix, start_node_pixel): 190 | """ 191 | find a path ended with node pixel 192 | :param status_matrix: status 193 | :param start_node_pixel: start node pixel 194 | :return: [start_node_pixel, edge_pixel,...,end_node_pixel] 195 | """ 196 | # for current implementation, we assume edge pixel has only two arcs 197 | # but it is possible that edge pixel has multiple connected component rather than 2, 198 | # because crossing are detected using outer pixels 199 | s = deque() 200 | neighbor_deltas = [dxdy for dxdy in itertools.product([-1, 0, 1], [-1, 0, 1]) 201 | if dxdy[0] != 0 or dxdy[1] != 0] 202 | # add candidates to stack 203 | m, n = start_node_pixel 204 | for dm, dn in neighbor_deltas: 205 | if status_matrix[m + dm][n + dn] == GraphExtractor.EDGE: 206 | s.appendleft(((m + dm, n + dn), [start_node_pixel])) 207 | while len(s) > 0: 208 | (m, n), path = s.popleft() 209 | # end node pixel 210 | if status_matrix[m][n] == GraphExtractor.NODE or \ 211 | status_matrix[m][n] == GraphExtractor.VISITED_NODE: 212 | path.append((m, n)) 213 | return path 214 | # internal edge pixel 215 | elif status_matrix[m][n] == GraphExtractor.EDGE: 216 | # mark the edge as visited 217 | status_matrix[m][n] = GraphExtractor.BLANK 218 | new_path = path.copy() 219 | new_path.append((m, n)) 220 | for dm, dn in neighbor_deltas: 221 | s.appendleft(((m + dm, n + dn), new_path)) 222 | return None 223 | 224 | def construct_undirected_rn(self, nodes, segments, target_path, nb_rows, xscale, yscale, mbr): 
225 | # node pixel -> road node 226 | road_nodes = {} 227 | eid = 0 228 | # we use coordinate tuples as key, consistent with networkx 229 | # !!! Though we construct DiGraph (compatible with networkx interface, one segment will only add once) 230 | # loading this data, we should call g.to_undirected() 231 | rn = nx.DiGraph() 232 | for segment in segments: 233 | coords = [] 234 | # start node 235 | start_node_pixel = nodes[segment[0]].center_pixel() 236 | if start_node_pixel not in road_nodes: 237 | lat, lng = self.pixels_to_latlng(start_node_pixel, mbr, nb_rows, xscale, yscale) 238 | road_node = SPoint(lat, lng) 239 | geo_pt = ogr.Geometry(ogr.wkbPoint) 240 | geo_pt.AddPoint(lng, lat) 241 | rn.add_node((lng, lat)) 242 | road_nodes[start_node_pixel] = road_node 243 | start_node = road_nodes[start_node_pixel] 244 | coords.append(start_node) 245 | start_node_key = (start_node.lng, start_node.lat) 246 | # internal nodes, we didn't create id for them 247 | for coord in segment[1:-1]: 248 | lat, lng = self.pixels_to_latlng(coord, mbr, nb_rows, xscale, yscale) 249 | coords.append(SPoint(lat, lng)) 250 | # end node 251 | end_node_pixel = nodes[segment[-1]].center_pixel() 252 | if end_node_pixel not in road_nodes: 253 | lat, lng = self.pixels_to_latlng(end_node_pixel, mbr, nb_rows, xscale, yscale) 254 | road_node = SPoint(lat, lng) 255 | geo_pt = ogr.Geometry(ogr.wkbPoint) 256 | geo_pt.AddPoint(lng, lat) 257 | rn.add_node((lng, lat)) 258 | road_nodes[end_node_pixel] = road_node 259 | end_node = road_nodes[end_node_pixel] 260 | coords.append(end_node) 261 | end_node_key = (end_node.lng, end_node.lat) 262 | # region add segment 263 | # skip loop 264 | if start_node_key == end_node_key: 265 | continue 266 | simplified_coords = self.segment_simplifier.simplify(coords) 267 | # skip too short segment 268 | if not self.is_valid(simplified_coords): 269 | continue 270 | # add forward segment 271 | geo_line = ogr.Geometry(ogr.wkbLineString) 272 | for simplified_coord in simplified_coords: 273 | geo_line.AddPoint(simplified_coord.lng, simplified_coord.lat) 274 | rn.add_edge(start_node_key, end_node_key, eid=eid, Wkb=geo_line.ExportToWkb(), type='pred') 275 | eid += 1 276 | # endregion 277 | rn.remove_nodes_from(list(nx.isolates(rn))) 278 | print('\n# of nodes:{}'.format(rn.number_of_nodes())) 279 | print('# of edges:{}'.format(rn.number_of_edges())) 280 | nx.write_shp(rn, target_path) 281 | return rn 282 | 283 | def is_valid(self, coords): 284 | dist = 0.0 285 | for i in range(len(coords) - 1): 286 | dist += distance(coords[i], coords[i + 1]) 287 | return dist > self.min_road_dist 288 | 289 | def pixels_to_latlng(self, pixel, mbr, nb_rows, xscale, yscale): 290 | i, j = pixel 291 | return (((nb_rows - i) / yscale) + mbr.min_lat), ((j / xscale) + mbr.min_lng) 292 | -------------------------------------------------------------------------------- /geometry_translation/feature_extraction.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | sys.path.append('../tptk/') 4 | from tptk.common.road_network import load_rn_shp 5 | from tptk.common.trajectory import parse_traj_file 6 | from tptk.common.grid import Grid 7 | from tptk.common.mbr import MBR 8 | from tptk.common.spatial_func import distance, bearing, LAT_PER_METER 9 | import cv2 10 | from tqdm import tqdm 11 | import numpy as np 12 | import json 13 | import shutil 14 | 15 | 16 | def generate_point_image(traj_dir, grid_idx, feature_path): 17 | pt_cnt = np.zeros((grid_idx.row_num, 
grid_idx.col_num))
18 | for filename in tqdm(os.listdir(traj_dir)):
19 | if not filename.endswith('.txt'):
20 | continue
21 | trajs = parse_traj_file(os.path.join(traj_dir, filename))
22 | for traj in trajs:
23 | for cur_pt in traj.pt_list:
24 | try:
25 | row_idx, col_idx = grid_idx.get_matrix_idx(cur_pt.lat, cur_pt.lng)
26 | pt_cnt[row_idx, col_idx] += 1
27 | except IndexError:
28 | continue
29 | pt_cnt = pt_cnt / 2 * 255 # linear scaling: two or more points already map to full intensity
30 | pt_cnt[pt_cnt > 255] = 255
31 | cv2.imwrite(os.path.join(feature_path, 'point.png'), pt_cnt)
32 | 
33 | 
34 | def generate_line_image(traj_dir, grid_idx, feature_path):
35 | MIN_DISTANCE_IN_METER = 5
36 | MAX_DISTANCE_IN_METER = 300
37 | traj_line_img = np.zeros((grid_idx.row_num, grid_idx.col_num), dtype=np.uint8)
38 | for filename in tqdm(os.listdir(traj_dir)):
39 | if not filename.endswith('.txt'):
40 | continue
41 | traj_list = parse_traj_file(os.path.join(traj_dir, filename))
42 | for traj in traj_list:
43 | one_traj_line_img = np.zeros((grid_idx.row_num, grid_idx.col_num), dtype=np.uint8)
44 | for j in range(len(traj.pt_list) - 1):
45 | cur_pt, next_pt = traj.pt_list[j], traj.pt_list[j + 1]
46 | if MIN_DISTANCE_IN_METER < distance(cur_pt, next_pt) < MAX_DISTANCE_IN_METER:
47 | try:
48 | y1, x1 = grid_idx.get_matrix_idx(cur_pt.lat, cur_pt.lng)
49 | y2, x2 = grid_idx.get_matrix_idx(next_pt.lat, next_pt.lng)
50 | cv2.line(one_traj_line_img, (x1, y1), (x2, y2), 16, 1, lineType=cv2.LINE_AA) # low intensity 16, so cv2.add below accumulates overlapping trajectories
51 | except IndexError:
52 | continue
53 | traj_line_img = cv2.add(traj_line_img, one_traj_line_img)
54 | cv2.imwrite(os.path.join(feature_path, 'line.png'), traj_line_img)
55 | 
56 | 
57 | def generate_speed_data(traj_dir, grid_idx, feature_path):
58 | MIN_DISTANCE_IN_METER = 5
59 | MAX_DISTANCE_IN_METER = 300
60 | speed_data = np.zeros((grid_idx.row_num, grid_idx.col_num, 1), dtype=np.float64) # np.float is a deprecated alias of the builtin float
61 | cnt_data = np.zeros((grid_idx.row_num, grid_idx.col_num, 1), dtype=np.float64)
62 | for filename in tqdm(os.listdir(traj_dir)):
63 | if not filename.endswith('.txt'):
64 | continue
65 | traj_list = parse_traj_file(os.path.join(traj_dir, filename))
66 | for traj in traj_list:
67 | for i in range(len(traj.pt_list) - 1):
68 | cur_pt, next_pt = traj.pt_list[i], traj.pt_list[i + 1]
69 | delta_time = (next_pt.time - cur_pt.time).total_seconds()
70 | if delta_time > 0 and MIN_DISTANCE_IN_METER < distance(cur_pt, next_pt) < MAX_DISTANCE_IN_METER: # skip zero time gaps to avoid division by zero
71 | try:
72 | row_idx, col_idx = grid_idx.get_matrix_idx(cur_pt.lat, cur_pt.lng)
73 | speed = distance(next_pt, cur_pt) / delta_time
74 | # discard implausible speeds above 34 m/s (~120 km/h)
75 | if speed > 34:
76 | continue
77 | speed_data[row_idx, col_idx, 0] += speed
78 | cnt_data[row_idx, col_idx, 0] += 1
79 | except IndexError:
80 | continue
81 | speed_data = np.divide(speed_data, cnt_data, out=np.zeros_like(speed_data), where=cnt_data != 0)
82 | np.save(os.path.join(feature_path, 'speed.npy'), speed_data)
83 | 
84 | 
85 | def generate_dir_dist_data(traj_dir, grid_idx, feature_path):
86 | MIN_DISTANCE_IN_METER = 5
87 | MAX_DISTANCE_IN_METER = 300
88 | dir_data = np.zeros((grid_idx.row_num, grid_idx.col_num, 8), dtype=np.uint8)
89 | for filename in tqdm(os.listdir(traj_dir)):
90 | if not filename.endswith('.txt'):
91 | continue
92 | traj_list = parse_traj_file(os.path.join(traj_dir, filename))
93 | for traj in traj_list:
94 | for i in range(len(traj.pt_list) - 1):
95 | cur_pt, next_pt = traj.pt_list[i], traj.pt_list[i+1]
96 | if MIN_DISTANCE_IN_METER < distance(cur_pt, next_pt) < MAX_DISTANCE_IN_METER:
97 | try:
98 | row_idx, col_idx = grid_idx.get_matrix_idx(cur_pt.lat, cur_pt.lng)
99 | direction =
int(((bearing(cur_pt, next_pt) + 22.5) % 360) // 45) 100 | dir_data[row_idx, col_idx, direction] += 1 101 | except IndexError: 102 | continue 103 | np.save(os.path.join(feature_path, 'direction.npy'), dir_data) 104 | 105 | 106 | def generate_spatial_view(traj_dir, grid_idx, feature_path): 107 | generate_point_image(traj_dir, grid_idx, feature_path) 108 | generate_line_image(traj_dir, grid_idx, feature_path) 109 | generate_speed_data(traj_dir, grid_idx, feature_path) 110 | generate_dir_dist_data(traj_dir, grid_idx, feature_path) 111 | 112 | 113 | def generate_transition_view(traj_dir, grid_idx, nbhd_size, nbhd_dist, feature_path): 114 | MIN_DISTANCE_IN_METER = 5 115 | MAX_DISTANCE_IN_METER = 300 116 | meters_per_grid = grid_idx.lat_interval / LAT_PER_METER 117 | radius = int(nbhd_dist / meters_per_grid) 118 | transit_data = np.zeros((grid_idx.row_num, grid_idx.col_num, nbhd_size, nbhd_size, 2), 119 | dtype=np.uint8) 120 | for filename in tqdm(os.listdir(traj_dir)): 121 | if not filename.endswith('.txt'): 122 | continue 123 | traj_list = parse_traj_file(os.path.join(traj_dir, filename)) 124 | for traj in traj_list: 125 | for idx in range(len(traj.pt_list) - 1): 126 | cur_pt = traj.pt_list[idx] 127 | next_pt = traj.pt_list[idx + 1] 128 | if MIN_DISTANCE_IN_METER < distance(cur_pt, next_pt) < MAX_DISTANCE_IN_METER: 129 | try: 130 | global_cur_i, global_cur_j = grid_idx.get_matrix_idx(cur_pt.lat, cur_pt.lng) 131 | local_idx = get_local_idx(global_cur_i, global_cur_j, radius, grid_idx, nbhd_dist) 132 | local_next_i, local_next_j = local_idx.get_matrix_idx(next_pt.lat, next_pt.lng) 133 | transit_data[global_cur_i, global_cur_j, local_next_i, local_next_j, 0] = 1 134 | 135 | global_next_i, global_next_j = grid_idx.get_matrix_idx(next_pt.lat, next_pt.lng) 136 | local_idx = get_local_idx(global_next_i, global_next_j, radius, grid_idx, nbhd_dist) 137 | local_cur_i, local_cur_j = local_idx.get_matrix_idx(cur_pt.lat, cur_pt.lng) 138 | transit_data[global_next_i, global_next_j, local_cur_i, local_cur_j, 1] = 1 139 | except IndexError: 140 | continue 141 | np.save(os.path.join(feature_path, 'transition.npy'), transit_data) 142 | 143 | 144 | def get_local_idx(i, j, radius, grid_idx, target_region_size): 145 | min_i = i - radius 146 | max_i = i + radius 147 | min_j = j - radius 148 | max_j = j + radius 149 | local_lower_left_mbr = grid_idx.get_mbr_by_matrix_idx(max_i, min_j) 150 | local_upper_right_mbr = grid_idx.get_mbr_by_matrix_idx(min_i, max_j) 151 | local_mbr = MBR(local_lower_left_mbr.min_lat, local_lower_left_mbr.min_lng, 152 | local_upper_right_mbr.max_lat, local_upper_right_mbr.max_lng) 153 | local_idx = Grid(local_mbr, target_region_size, target_region_size) 154 | return local_idx 155 | 156 | 157 | def generate_features(traj_dir, grid_idx, nbhd_size, nbhd_dist, feature_path): 158 | os.makedirs(feature_path, exist_ok=True) 159 | generate_spatial_view(traj_dir, grid_idx, feature_path) 160 | generate_transition_view(traj_dir, grid_idx, nbhd_size, nbhd_dist, feature_path) 161 | 162 | 163 | def generate_road_centerline_label(rn_path, grid_idx, label_path): 164 | rn = load_rn_shp(rn_path) 165 | centerline_img = np.zeros((grid_idx.row_num, grid_idx.col_num), dtype=np.uint8) 166 | for edge_key in tqdm(rn.edges): 167 | coords = rn.edges[edge_key]['coords'] 168 | for i in range(len(coords)-1): 169 | start_node, end_node = coords[i], coords[i+1] 170 | try: 171 | y1, x1 = grid_idx.get_matrix_idx(start_node.lat, start_node.lng) 172 | y2, x2 = grid_idx.get_matrix_idx(end_node.lat, end_node.lng) 173 | 
cv2.line(centerline_img, (x1, y1), (x2, y2), 255, 1, lineType=cv2.LINE_8) 174 | except IndexError: 175 | continue 176 | cv2.imwrite(os.path.join(label_path, 'centerline.png'), centerline_img) 177 | 178 | 179 | def generate_road_region_label(centerline_path, radius, label_path): 180 | centerline_img = cv2.imread(centerline_path, cv2.IMREAD_GRAYSCALE) 181 | centerline_pixels = np.where(centerline_img == 255) 182 | H, W = centerline_img.shape 183 | region_img = np.zeros(centerline_img.shape, dtype=np.uint8) 184 | for i, j in tqdm(list(zip(centerline_pixels[0], centerline_pixels[1]))): 185 | for y in range(max(i-radius, 0), min(i+radius+1, H)): 186 | for x in range(max(j-radius, 0), min(j+radius+1, W)): 187 | region_img[y, x] = 255 188 | cv2.imwrite(os.path.join(label_path, 'region.png'), region_img) 189 | 190 | 191 | def generate_labels(rn_path, grid_idx, label_path): 192 | os.makedirs(label_path, exist_ok=True) 193 | generate_road_centerline_label(rn_path, grid_idx, label_path) 194 | generate_road_region_label(label_path + 'centerline.png', 2, label_path) 195 | 196 | 197 | def generate_samples(feature_path, label_path, grid_idx, tile_pixel_size, dataset_path): 198 | os.makedirs(dataset_path, exist_ok=True) 199 | point = cv2.imread(os.path.join(feature_path, 'point.png')) 200 | line = cv2.imread(os.path.join(feature_path, 'line.png')) 201 | speed = np.load(os.path.join(feature_path, 'speed.npy')) 202 | direction = np.load(os.path.join(feature_path, 'direction.npy')) 203 | transition = np.load(os.path.join(feature_path, 'transition.npy')) 204 | centerline = cv2.imread(os.path.join(label_path, 'centerline.png')) 205 | region = cv2.imread(os.path.join(label_path, 'region.png')) 206 | for i in tqdm(range(grid_idx.row_num // tile_pixel_size)): 207 | for j in range(grid_idx.col_num // tile_pixel_size): 208 | slices = (slice(i * tile_pixel_size, (i + 1) * tile_pixel_size), 209 | slice(j * tile_pixel_size, (j + 1) * tile_pixel_size)) 210 | image_sample = np.concatenate((point[slices], line[slices], centerline[slices], region[slices]), axis=1) 211 | cv2.imwrite(os.path.join(dataset_path, '{}_{}.png'.format(i, j)), image_sample) 212 | np.save(os.path.join(dataset_path, '{}_{}_speed.npy'.format(i, j)), speed[slices]) 213 | np.save(os.path.join(dataset_path, '{}_{}_direction.npy'.format(i, j)), direction[slices]) 214 | np.save(os.path.join(dataset_path, '{}_{}_transition.npy'.format(i, j)), transition[slices]) 215 | 216 | 217 | def split_train_val_test(dataset_path, test_row_min, test_row_max, test_col_min, test_col_max, learning_path): 218 | test_tiles = set() 219 | for row in range(test_row_min, test_row_max): 220 | for col in range(test_col_min, test_col_max): 221 | test_tiles.add('{}_{}'.format(row, col)) 222 | samples = set([name[:-4] for name in os.listdir(dataset_path) if name.endswith('.png')]) 223 | train_val_samples = samples - test_tiles 224 | val_split = 0.1 225 | train_val_samples = list(train_val_samples) 226 | nb_samples = len(train_val_samples) 227 | seed = 2017 228 | idxes = np.random.RandomState(seed=seed).permutation(nb_samples) 229 | train_split = 1 - val_split 230 | train_size = int(nb_samples * train_split) 231 | train_idxes = idxes[:train_size] 232 | val_idxes = idxes[train_size:] 233 | # create train set 234 | train_path = os.path.join(learning_path, 'train') 235 | os.makedirs(train_path, exist_ok=True) 236 | for train_idx in train_idxes: 237 | shutil.move(os.path.join(dataset_path, train_val_samples[idxes[train_idx]] + '.png'), train_path) 238 | 
shutil.move(os.path.join(dataset_path, train_val_samples[idxes[train_idx]] + '_direction.npy'), train_path) 239 | shutil.move(os.path.join(dataset_path, train_val_samples[idxes[train_idx]] + '_speed.npy'), train_path) 240 | shutil.move(os.path.join(dataset_path, train_val_samples[idxes[train_idx]] + '_transition.npy'), train_path) 241 | # create val set 242 | val_path = os.path.join(learning_path, 'val') 243 | os.makedirs(val_path, exist_ok=True) 244 | for val_idx in val_idxes: 245 | shutil.move(os.path.join(dataset_path, train_val_samples[idxes[val_idx]] + '.png'), val_path) 246 | shutil.move(os.path.join(dataset_path, train_val_samples[idxes[val_idx]] + '_direction.npy'), val_path) 247 | shutil.move(os.path.join(dataset_path, train_val_samples[idxes[val_idx]] + '_speed.npy'), val_path) 248 | shutil.move(os.path.join(dataset_path, train_val_samples[idxes[val_idx]] + '_transition.npy'), val_path) 249 | # create test set 250 | test_path = os.path.join(learning_path, 'test') 251 | os.makedirs(test_path, exist_ok=True) 252 | for test_tile in test_tiles: 253 | shutil.move(os.path.join(dataset_path, test_tile + '.png'), test_path) 254 | shutil.move(os.path.join(dataset_path, test_tile + '_direction.npy'), test_path) 255 | shutil.move(os.path.join(dataset_path, test_tile + '_speed.npy'), test_path) 256 | shutil.move(os.path.join(dataset_path, test_tile + '_transition.npy'), test_path) 257 | 258 | 259 | if __name__ == '__main__': 260 | with open(sys.argv[1], 'r') as f: 261 | conf = json.load(f) 262 | traj_dir = '../data/{}/traj/'.format(conf['dataset']['dataset_name']) 263 | rn_path = '../data/{}/rn/'.format(conf['dataset']['dataset_name']) 264 | feature_path = '../data/{}/feature/'.format(conf['dataset']['dataset_name']) 265 | label_path = '../data/{}/label/'.format(conf['dataset']['dataset_name']) 266 | dataset_path = '../data/{}/dataset/'.format(conf['dataset']['dataset_name']) 267 | learning_path = '../data/{}/learning/'.format(conf['dataset']['dataset_name']) 268 | mbr = MBR(conf['dataset']['min_lat'], conf['dataset']['min_lng'], 269 | conf['dataset']['max_lat'], conf['dataset']['max_lng']) 270 | grid_idx = Grid(mbr, conf['dataset']['nb_rows'], conf['dataset']['nb_cols']) 271 | generate_features(traj_dir, grid_idx, conf['feature_extraction']['nbhd_size'], 272 | conf['feature_extraction']['nbhd_dist'], feature_path) 273 | generate_labels(rn_path, grid_idx, label_path) 274 | generate_samples(feature_path, label_path, grid_idx, conf['feature_extraction']['tile_pixel_size'], dataset_path) 275 | split_train_val_test(dataset_path, conf['feature_extraction']['test_tile_row_min'], 276 | conf['feature_extraction']['test_tile_row_max'], 277 | conf['feature_extraction']['test_tile_col_min'], 278 | conf['feature_extraction']['test_tile_col_max'], learning_path) 279 | -------------------------------------------------------------------------------- /topology_construction/link_generation.py: -------------------------------------------------------------------------------- 1 | import copy 2 | from tptk.common.spatial_func import LAT_PER_METER, LNG_PER_METER, SPoint, distance, project_pt_to_line, cal_loc_along_line, angle 3 | from tptk.common.mbr import MBR 4 | from tptk.common.road_network import store_rn_shp 5 | from topology_construction.topo_utils import line_ray_intersection_test, is_line_line_intersected, angle_between 6 | import numpy as np 7 | 8 | 9 | class VirtualLink: 10 | def __init__(self, end_node, target_segment, split_edge_idx, split_edge_offset): 11 | self.end_node = end_node 12 | 
self.target_segment = target_segment 13 | self.split_edge_idx = split_edge_idx 14 | self.split_edge_offset = split_edge_offset 15 | 16 | def __repr__(self): 17 | return '(index:{},offset:{})'.format(self.split_edge_idx, self.split_edge_offset) 18 | 19 | 20 | class LinkGenerator: 21 | def __init__(self, radius): 22 | self.radius = radius 23 | self.NO_NEW_VERTEX_OFFSET = 15 24 | self.SIMILAR_DIRECTION_THRESHOLD = 20 25 | 26 | def generate_pt_to_link(self, extracted_rn, last_edge_of_dead_end, target_coords, dead_segment, target_segment): 27 | """ 28 | 29 | :param extracted_rn: 30 | :param last_edge_of_dead_end: the last edge of the road segment containing the dead end 31 | :param target_coords: the coords of the target segment 32 | :param dead_segment: key 33 | :param target_segment: key 34 | :return: 35 | """ 36 | f, o = last_edge_of_dead_end 37 | opposite_of_o = dead_segment[1] if dead_segment[0][0] == o.lng and dead_segment[0][1] == o.lat else dead_segment[0] 38 | opposite_of_o = SPoint(opposite_of_o[1], opposite_of_o[0]) 39 | # short isolated segment, the direction might be unreliable, we add the perpendicular edge to the neighborhood 40 | if extracted_rn.edges[dead_segment]['length'] < self.NO_NEW_VERTEX_OFFSET and \ 41 | extracted_rn.degree(dead_segment[0]) == 1 and extracted_rn.degree(dead_segment[1]) == 1: 42 | return self.perpendicular_intersection(o, target_coords, extracted_rn, opposite_of_o, target_segment) 43 | return self.extension_intersection(o, f, target_coords) 44 | 45 | def cal_projection(self, pt, target_coords): 46 | split_edge_idx = float('inf') 47 | split_edge_offset = float('inf') 48 | min_dist = float('inf') 49 | candidates = [project_pt_to_line(target_coords[i], target_coords[i + 1], pt) for i in 50 | range(len(target_coords) - 1)] 51 | if len(candidates) > 0: 52 | for i in range(len(candidates)): 53 | if candidates[i][1] <= 0.0 or candidates[i][1] >= 1.0: 54 | continue 55 | if candidates[i][2] < min_dist and candidates[i][2] < self.radius: 56 | min_dist = candidates[i][2] 57 | split_edge_idx = i 58 | split_edge_offset = candidates[i][1] 59 | return min_dist, split_edge_idx, split_edge_offset 60 | 61 | def extension_intersection(self, o, f, target_coords): 62 | min_dist = float('inf') 63 | split_edge_idx = float('inf') 64 | split_edge_offset = float('inf') 65 | # check whether internal edge has intersection (if multiple intersections, select the shortest) 66 | for i in range(0, len(target_coords) - 1): 67 | a = target_coords[i] 68 | b = target_coords[i + 1] 69 | result = line_ray_intersection_test(o, f, a, b) 70 | if result is None or result[0] < 0 or result[0] > 1: 71 | continue 72 | else: 73 | dist_tmp = distance(o, result[1]) 74 | if dist_tmp < self.radius and dist_tmp < min_dist: 75 | nearest_node_with_offset = (target_coords[i], 0.0) if result[0] < 0.5 else \ 76 | (target_coords[i + 1], 1.0) 77 | min_dist = dist_tmp 78 | split_edge_idx = i 79 | # prefer link to existing nodes if too short 80 | if distance(nearest_node_with_offset[0], result[1]) < self.NO_NEW_VERTEX_OFFSET: 81 | split_edge_offset = nearest_node_with_offset[1] 82 | else: 83 | split_edge_offset = result[0] 84 | # doesn't have internal intersection, check whether has smooth transition 85 | if split_edge_idx == float('inf'): 86 | # check start node 87 | tmp_dist = distance(target_coords[0], o) 88 | if tmp_dist < self.radius and \ 89 | angle_between((o.lng - f.lng, o.lat - f.lat), 90 | (target_coords[0].lng - o.lng, target_coords[0].lat - o.lat)) <= 0.5 * np.pi: 91 | min_dist = tmp_dist 92 | 
split_edge_idx = 0 93 | split_edge_offset = 0.0 94 | # check end node 95 | tmp_dist = distance(target_coords[-1], o) 96 | if tmp_dist < self.radius and \ 97 | angle_between((o.lng - f.lng, o.lat - f.lat), 98 | (target_coords[-1].lng - o.lng, target_coords[-1].lat - o.lat)) <= 0.5 * np.pi: 99 | if tmp_dist < min_dist: 100 | split_edge_idx = len(target_coords) - 2 101 | split_edge_offset = 1.0 102 | return split_edge_idx, split_edge_offset 103 | 104 | def perpendicular_intersection(self, o, target_coords, rn, opposite_of_o, target_segment): 105 | split_edge_idx = float('inf') 106 | split_edge_offset = float('inf') 107 | o_min_dist, o_split_edge_idx, o_split_edge_offset = self.cal_projection(o, target_coords) 108 | other_min_dist, other_split_edge_idx, other_split_edge_offset = self.cal_projection(opposite_of_o, 109 | target_coords) 110 | if o_min_dist < other_min_dist: 111 | split_edge_idx = o_split_edge_idx 112 | split_edge_offset = o_split_edge_offset 113 | # no internal intersection 114 | if split_edge_idx == float('inf'): 115 | # the target segment is also a short isolated one 116 | if rn.edges[target_segment]['length'] < self.NO_NEW_VERTEX_OFFSET and \ 117 | rn.degree(target_segment[0]) == 1 and rn.degree(target_segment[1]) == 1: 118 | a = SPoint(target_segment[0][1], target_segment[0][0]) 119 | b = SPoint(target_segment[1][1], target_segment[1][0]) 120 | dead_end_to_target_dist = min(distance(o, a), distance(o, b)) 121 | opposite_to_target_dist = min(distance(opposite_of_o, a), distance(opposite_of_o, b)) 122 | # current dead end is shorter to the target than opposite, link with the nearest vertex 123 | if dead_end_to_target_dist < opposite_to_target_dist and dead_end_to_target_dist < self.radius: 124 | target = a if distance(o, a) < distance(o, b) else b 125 | if target == target_coords[0]: 126 | split_edge_idx = 0 127 | split_edge_offset = 0.0 128 | else: 129 | split_edge_idx = len(target_coords) - 2 130 | split_edge_offset = 1.0 131 | return split_edge_idx, split_edge_offset 132 | 133 | def divide_segment(self, ori_coords, virtual_links): 134 | """ 135 | :param ori_coords: 136 | :param virtual_links: 137 | :return: list of edges to add (start, end, coords) 138 | """ 139 | splitted_segments = [] 140 | link_segments = [] 141 | # aggregate by edge_idx, and offset, i.e., new nodes 142 | loc2virtual_links = {} 143 | for virtual_link in virtual_links: 144 | split_edge_idx = virtual_link.split_edge_idx 145 | split_edge_offset = virtual_link.split_edge_offset 146 | # (big_edge_index, 0.0):will not appear, distance is only updated if next edge distance is smaller 147 | if (split_edge_idx, split_edge_offset) not in loc2virtual_links: 148 | loc2virtual_links[(split_edge_idx, split_edge_offset)] = [] 149 | loc2virtual_links[(split_edge_idx, split_edge_offset)].append(virtual_link) 150 | # (edge_idx,offset), pt, virtual link list 151 | node_seq = [] 152 | for loc in loc2virtual_links: 153 | edge_idx, edge_offset_rate = loc 154 | # if edge_offset_rate == 0.0: 155 | if edge_offset_rate <= 0.0: 156 | new_node = ori_coords[edge_idx] 157 | # elif edge_offset_rate == 1.0: 158 | elif edge_offset_rate >= 1.0: 159 | new_node = ori_coords[edge_idx + 1] 160 | else: 161 | new_node = cal_loc_along_line(ori_coords[edge_idx], ori_coords[edge_idx + 1], edge_offset_rate) 162 | node_seq.append((new_node, loc, loc2virtual_links[loc])) 163 | node_seq = sorted(node_seq, key=lambda data: (data[1][0], data[1][1])) 164 | # add the ori start node if not added 165 | if node_seq[0][1][0] != 0 or node_seq[0][1][1] != 0.0: 
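# node_seq keys are (edge_idx, offset) pairs: (0, 0.0) denotes the original start
# vertex and (len(ori_coords) - 2, 1.0) the original end vertex; they are prepended
# or appended below only when no virtual link landed exactly on them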
166 | node_seq = [(ori_coords[0], (0, 0.0), [])] + node_seq 167 | # add the ori end node if not added 168 | if node_seq[-1][1][0] != len(ori_coords) - 2 or node_seq[-1][1][1] != 1.0: 169 | node_seq.append((ori_coords[-1], (len(ori_coords) - 2, 1.0), [])) 170 | # add splitted segment 171 | for i in range(len(node_seq) - 1): 172 | from_node, (from_edge_idx, from_offset), _ = node_seq[i] 173 | to_node, (to_edge_idx, to_offset), _ = node_seq[i + 1] 174 | shape = [] 175 | shape.append(from_node) 176 | if from_edge_idx != to_edge_idx: 177 | for j in range(from_edge_idx + 1, to_edge_idx + 1): 178 | shape.append(ori_coords[j]) 179 | shape.append(to_node) 180 | splitted_segments.append((from_node, to_node, shape)) 181 | # add links 182 | for node in node_seq: 183 | node_pt, _, node_virtual_links = node 184 | for node_virtual_link in node_virtual_links: 185 | shape = [node_pt, node_virtual_link.end_node] 186 | link_segments.append((node_pt, node_virtual_link.end_node, shape)) 187 | return splitted_segments, link_segments 188 | 189 | def update_link(self, linked_rn, from_pt, to_pt, coords, avail_eid): 190 | """ 191 | make sure two link will not have too similar direction 192 | """ 193 | is_valid = True 194 | link_dist = distance(from_pt, to_pt) 195 | # if the new edge is shorter, add new edge and delete old edge 196 | # check from pt 197 | links_with_from = [edge for edge in list(linked_rn.edges((from_pt.lng, from_pt.lat))) if 198 | linked_rn.edges[edge]['type'] == 'virtual'] 199 | edges_to_delete = [] 200 | for u, v in links_with_from: 201 | other_node = v if u[0] == from_pt.lng and u[1] == from_pt.lat else u 202 | ang = angle(from_pt, to_pt, from_pt, SPoint(other_node[1], other_node[0])) 203 | if ang < self.SIMILAR_DIRECTION_THRESHOLD: 204 | if link_dist >= linked_rn[u][v]['length']: 205 | is_valid = False 206 | break 207 | else: 208 | edges_to_delete.append((u, v)) 209 | # check to pt 210 | links_with_to = [edge for edge in list(linked_rn.edges((to_pt.lng, to_pt.lat))) if 211 | linked_rn.edges[edge]['type'] == 'virtual'] 212 | for u, v in links_with_to: 213 | other_node = v if u[0] == to_pt.lng and u[1] == to_pt.lat else u 214 | ang = angle(to_pt, from_pt, to_pt, SPoint(other_node[1], other_node[0])) 215 | if ang < self.SIMILAR_DIRECTION_THRESHOLD: 216 | if link_dist >= linked_rn[u][v]['length']: 217 | is_valid = False 218 | break 219 | else: 220 | edges_to_delete.append((u, v)) 221 | if is_valid: 222 | linked_rn.add_edge((from_pt.lng, from_pt.lat), (to_pt.lng, to_pt.lat), coords=coords, eid=avail_eid, 223 | type='virtual') 224 | for u, v in edges_to_delete: 225 | if linked_rn.has_edge(u, v): 226 | # didn't destroy the connectivity 227 | if linked_rn.degree(u) == 2 or linked_rn.degree(v) == 2: 228 | continue 229 | linked_rn.remove_edge(u, v) 230 | # linked_rn.add_edge((from_pt.lng, from_pt.lat), (to_pt.lng, to_pt.lat), coords=coords, eid=avail_eid, type='virtual') 231 | 232 | def is_intersected_with_existing_edges(self, rn, virtual_link, candidates): 233 | is_intersected = False 234 | f = virtual_link.end_node 235 | edge_idx = virtual_link.split_edge_idx 236 | edge_offset_rate = virtual_link.split_edge_offset 237 | ori_segment = virtual_link.target_segment 238 | ori_coords = rn[ori_segment[0]][ori_segment[1]]['coords'] 239 | o = cal_loc_along_line(ori_coords[edge_idx], ori_coords[edge_idx + 1], edge_offset_rate) 240 | for candidate in candidates: 241 | if candidate == virtual_link.target_segment: 242 | continue 243 | u, v = candidate 244 | coords = rn[u][v]['coords'] 245 | for i in 
    def is_intersected_with_existing_edges(self, rn, virtual_link, candidates):
        is_intersected = False
        f = virtual_link.end_node
        edge_idx = virtual_link.split_edge_idx
        edge_offset_rate = virtual_link.split_edge_offset
        ori_segment = virtual_link.target_segment
        ori_coords = rn[ori_segment[0]][ori_segment[1]]['coords']
        o = cal_loc_along_line(ori_coords[edge_idx], ori_coords[edge_idx + 1], edge_offset_rate)
        for candidate in candidates:
            if candidate == virtual_link.target_segment:
                continue
            u, v = candidate
            coords = rn[u][v]['coords']
            for i in range(len(coords) - 1):
                if is_line_line_intersected(f, o, coords[i], coords[i + 1]):
                    is_intersected = True
                    break
            if is_intersected:
                break
        return is_intersected

    def remove_similar_links(self, rn, edge_virtual_links):
        new_edge_virtual_links = copy.copy(edge_virtual_links)
        o = edge_virtual_links[0].end_node
        stable = False
        while not stable:
            stable = True
            for i in range(len(new_edge_virtual_links) - 1):
                link_a = new_edge_virtual_links[i]
                a = cal_loc_along_line(rn.edges[link_a.target_segment]['coords'][link_a.split_edge_idx],
                                       rn.edges[link_a.target_segment]['coords'][link_a.split_edge_idx + 1],
                                       link_a.split_edge_offset)
                for j in range(i + 1, len(new_edge_virtual_links)):
                    link_b = new_edge_virtual_links[j]
                    b = cal_loc_along_line(rn.edges[link_b.target_segment]['coords'][link_b.split_edge_idx],
                                           rn.edges[link_b.target_segment]['coords'][link_b.split_edge_idx + 1],
                                           link_b.split_edge_offset)
                    # if the angle between the two links is small, delete the longer one
                    if angle(o, a, o, b) < self.SIMILAR_DIRECTION_THRESHOLD:
                        if distance(o, a) < distance(o, b):
                            new_edge_virtual_links.remove(link_b)
                        else:
                            new_edge_virtual_links.remove(link_a)
                        stable = False
                        break
                if not stable:
                    break
        return new_edge_virtual_links
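    # NOTE (illustrative): a minimal sketch of how this generation step is typically driven,
    # assuming the class in this module is constructed with a linking radius in meters and
    # that a load_rn_shp counterpart of store_rn_shp is available (both names and the
    # constructor signature are assumptions, not part of this file):
    #
    #     rn = load_rn_shp(extracted_rn_path, is_directed=False)   # hypothetical loader
    #     LinkGenerator(radius=50).generate(rn, linked_rn_path)    # hypothetical signature
    #
    # generate() roughly requires an undirected network that supports range_query() and carries
    # 'coords', 'length', 'eid' and 'type' attributes on its edges.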
    def generate(self, init_rn, linked_rn_path):
        """
        :param init_rn: the initial road network; it must be undirected
        :param linked_rn_path: path to store the linked road network; the output is directed
        :return:
        """
        linked_rn = copy.deepcopy(init_rn)
        HALF_DELTA_LAT = LAT_PER_METER * self.radius
        HALF_DELTA_LNG = LNG_PER_METER * self.radius
        dead_end_cnt = 0
        virtual_links = []
        for node, degree in init_rn.degree():
            if degree == 1:
                lng, lat = node
                dead_end_pt = SPoint(lat, lng)
                # get the opposite node of the dead end's last edge
                u = list(init_rn.adj[node])[0]
                v = node
                seg_coords = init_rn[u][v]['coords']
                if seg_coords[0].lat == lat and seg_coords[0].lng == lng:
                    last_edge_of_dead_end = (seg_coords[1], dead_end_pt)
                elif seg_coords[-1].lat == lat and seg_coords[-1].lng == lng:
                    last_edge_of_dead_end = (seg_coords[-2], dead_end_pt)
                else:
                    raise Exception('error: the endpoints of coords are not consistent with the node')
                dead_end_cnt += 1
                query_mbr = MBR(lat - HALF_DELTA_LAT, lng - HALF_DELTA_LNG, lat + HALF_DELTA_LAT, lng + HALF_DELTA_LNG)
                # get nearby candidate road segments to be linked (excluding the dead end's own segment)
                candidates = init_rn.range_query(query_mbr)
                candidates = [candidate for candidate in candidates if
                              not (candidate[0] in [u, v] and candidate[1] in [u, v])]
                if len(candidates) == 0:
                    continue
                edge_virtual_links = []
                # calculate the linking position on each candidate
                for candidate in candidates:
                    split_edge_idx, split_edge_offset = self.generate_pt_to_link(init_rn, last_edge_of_dead_end,
                                                                                 init_rn[candidate[0]][candidate[1]]['coords'],
                                                                                 (u, v), candidate)
                    if split_edge_idx == float('inf'):
                        continue
                    virtual_link = VirtualLink(dead_end_pt, candidate, split_edge_idx, split_edge_offset)
                    # keep the link only if it does not intersect with existing roads
                    if not self.is_intersected_with_existing_edges(init_rn, virtual_link, candidates):
                        edge_virtual_links.append(virtual_link)
                if len(edge_virtual_links) > 0:
                    # remove links from the same dead end that have similar directions (only the shortest link is kept)
                    edge_virtual_links = self.remove_similar_links(init_rn, edge_virtual_links)
                    virtual_links.extend(edge_virtual_links)
        print('number of dead ends: {}'.format(dead_end_cnt))
        segment2infos = {}
        for virtual_link in virtual_links:
            segment = virtual_link.target_segment
            if segment not in segment2infos:
                segment2infos[segment] = []
            segment2infos[segment].append(virtual_link)
        avail_eid = max([eid for u, v, eid in init_rn.edges.data(data='eid')]) + 1
        for segment in segment2infos:
            ori_coords = init_rn[segment[0]][segment[1]]['coords']
            seg_virtual_links = segment2infos[segment]
            splitted_segments, link_segments = self.divide_segment(ori_coords, seg_virtual_links)
            for from_pt, to_pt, coords in splitted_segments:
                linked_rn.add_edge((from_pt.lng, from_pt.lat), (to_pt.lng, to_pt.lat), coords=coords, eid=avail_eid,
                                   type=init_rn[segment[0]][segment[1]]['type'])
                avail_eid += 1
            for from_pt, to_pt, coords in link_segments:
                # decide whether this link should be added and whether similar existing links should be removed
                self.update_link(linked_rn, from_pt, to_pt, coords, avail_eid)
                avail_eid += 1
            # if a segment is split into multiple segments, remove the original segment
            if len(splitted_segments) > 1:
                linked_rn.remove_edge(segment[0], segment[1])
        linked_rn_directed = linked_rn.to_directed()
        store_rn_shp(linked_rn_directed, linked_rn_path)
--------------------------------------------------------------------------------