├── .gitignore ├── DEBUGtest ├── RL_planner │ ├── train_dqn_planner.py │ └── train_ppo_planner.py ├── curve_gallery.py ├── data.pt ├── debug_frenet.py ├── planner_arena.py ├── profile_cubicspline.py ├── profile_polynomial.py ├── test_SAT.py ├── test_carla_interface.py ├── test_disk_check.py ├── test_dummybenchmark.py ├── test_exp_buffer.py ├── test_frenet.py ├── test_frenetoptimizer.py ├── test_gru_planner.py ├── test_highway_env.py ├── test_import_spider.py ├── test_log_buffer.py ├── test_mlp_planner.py ├── test_prob_planner.py ├── test_relative_transform.py ├── test_twopt_boundary.py └── test_visualize │ ├── test_elements.py │ ├── test_line.py │ ├── test_surface.py │ └── test_surface3d.py ├── README.md ├── __init__.py ├── _misc.py ├── _virtual_import.py ├── constraints ├── ConstraintCollection.py ├── __init__.py ├── constraint_checker │ ├── BaseConstraintChecker.py │ ├── CartConstraintChecker.py │ ├── ControlConstraintChecker.py │ ├── FrenetConstraintChecker.py │ └── __init__.py └── constraint_formulator │ └── __init__.py ├── control ├── IDMController.py ├── SimpleController.py ├── __init__.py ├── lateral │ ├── PurePursuitController.py │ └── __init__.py ├── longitudinal │ ├── IDMLonController.py │ ├── PIDLonController.py │ └── __init__.py └── vehicle_model │ ├── __init__.py │ ├── bicycle.png │ ├── bicycle.py │ └── steer_curvature.jpg ├── data ├── DataBuffer.py ├── Dataset.py ├── __init__.py ├── common.py ├── data_factory.py └── decorators.py ├── display_assests ├── DQNPlanner0.gif ├── GRU_log_replay.gif ├── LatticePlanner.gif ├── LatticePlanner_highway.gif ├── OptimizedLatticePlanner.gif ├── carla_test.png ├── common_tools.png ├── framework.png └── planner_arena.png ├── elements ├── __init__.py ├── _poly_calc.c ├── _poly_calc.pyx ├── box.py ├── curves.py ├── graph.py ├── grid.py ├── map.py ├── trajectory.py └── vehicle.py ├── evaluator ├── CostEvaluator.py └── __init__.py ├── interface ├── BaseBenchmark.py ├── BaseInterface.py ├── __init__.py ├── carla │ ├── 
CarlaInterface.py │ ├── __init__.py │ ├── _control_utils.py │ ├── _light_utils.py │ ├── _route_utils.py │ ├── _weather_utils.py │ ├── common.py │ ├── presets.py │ └── visualize.py ├── highway_env │ ├── HighwayEnvBenchmark.py │ ├── HighwayEnvInterface.py │ └── __init__.py ├── metrics_collection.py └── nuplan │ └── __init__.py ├── optimize ├── APFGradDescend.py ├── BaseOptimizer.py ├── TrajectoryOptimizer.py ├── __init__.py └── common.py ├── param.py ├── planner_zoo ├── BaseNeuralPlanner.py ├── BasePlanner.py ├── BaseSamplerPlanner.py ├── BezierPlanner.py ├── DDQNPlanner.py ├── DQNPlanner.py ├── DiscretePPOPlanner.py ├── FallbackDummyPlanner.py ├── FallbackPlanner.py ├── GRUPlanner.py ├── IDMPlanner.py ├── ImaginaryPlanner.py ├── LatticePlanner.py ├── MlpPlanner.py ├── OptimizedGRUPlanner.py ├── OptimizedLatticePlanner.py ├── PiecewiseLatticePlanner.py ├── ProbabilisticPlanner.py └── __init__.py ├── rl ├── PlannerGym.py ├── __init__.py ├── _tensor_utils.py ├── action │ ├── ActionConverter.py │ └── __init__.py ├── convert.py ├── policy │ ├── BasePolicy.py │ ├── ClassificationILPolicy.py │ ├── DDQNPolicy.py │ ├── DQNPolicy.py │ ├── PPOPolicy.py │ ├── RegressionILPolicy.py │ └── __init__.py ├── reward │ ├── BaseReward.py │ ├── TerminateReward.py │ ├── TrajectoryReward.py │ ├── __init__.py │ └── reward_collection.py ├── state │ ├── StateConverter.py │ └── __init__.py └── transition │ ├── DeterministicTransition.py │ ├── GaussianTransition.py │ └── __init__.py ├── sampler ├── BaseSampler.py ├── Combiner.py ├── LatticeSampler.py ├── PathSampler.py ├── PolynomialSampler.py ├── __init__.py └── common.py ├── teaser.py ├── tutorial ├── data_engine.ipynb ├── env_wrapper.ipynb └── interface.ipynb ├── utils ├── ImaginaryEngine.py ├── __init__.py ├── collision │ ├── AABB.py │ ├── CollisionChecker.py │ ├── CollisionConstraints.py │ ├── SAT.py │ ├── __init__.py │ ├── disk_num_calculation.jpg │ ├── disks.py │ └── ray_cast.py ├── geometry.py ├── lane_decision.py ├── potential_field │ 
├── __init__.py │ ├── potential_field.py │ ├── static_risk.py │ └── velocity_oriented_risk.py ├── predict │ ├── __init__.py │ ├── common.py │ └── linear.py ├── transform │ ├── __init__.py │ ├── frenet.py │ ├── gps.py │ ├── grid.py │ ├── polar.py │ ├── reference.zip │ └── relative.py └── vector.py └── visualize ├── __init__.py ├── common.py ├── dashboard └── __init__.py ├── elements.py ├── line.py ├── point.py ├── surface.py └── surface3d.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Ignore compiled Python files 2 | __pycache__/ 3 | **/__pycache__/ 4 | *.pyc 5 | *.pyo 6 | *.pyd 7 | *.pt 8 | *.pth 9 | 10 | # Ignore virtual environment 11 | venv/ 12 | 13 | planner_zoo/*.pth 14 | planner_zoo/*.pt 15 | planner_zoo/*.mp4 16 | planner_zoo/*.avi 17 | 18 | rl/backup/ 19 | 20 | elements/_pyxbld/ 21 | visualize/matplotlib_cheatsheet/ 22 | DEBUGtest/videos/ 23 | DEBUGtest/*.avi 24 | DEBUGtest/*.mp4 25 | DEBUGtest/*.png 26 | **/dataset* 27 | **/tensorboard 28 | 29 | interface/carla/example 30 | interface/carla/carla_veh_light 31 | interface/carla/PythonAPI 32 | 33 | *.zip 34 | *.bak 35 | *.bak* 36 | upload.txt 37 | todo.txt 38 | 39 | 40 | -------------------------------------------------------------------------------- /DEBUGtest/RL_planner/train_dqn_planner.py: -------------------------------------------------------------------------------- 1 | import spider.visualize as vis 2 | import tqdm 3 | 4 | class Trainner: 5 | ''' 6 | todo:以后加一个把环境打包成gym环境的功能 7 | ''' 8 | def __init__(self, env_interface, reward_function, visualize=False): 9 | self.env_interface = env_interface 10 | self.reward_function = reward_function 11 | self._visualize = visualize 12 | 13 | 14 | def train(self, planner, train_steps, batch_size=64): 15 | # todo: 是一个step触发训练,还是一个episode触发训练? 
16 | # 以及一轮训练的次数是1吗?可以参考stable baselines3 17 | 18 | policy = planner.policy 19 | exp_buffer = planner.exp_buffer 20 | 21 | exp_buffer.apply_to(policy, self.reward_function) # 开始监听 22 | 23 | obs, done = None, True 24 | 25 | policy.set_exploration(enable=True) 26 | 27 | for i in tqdm.tqdm(range(train_steps)): 28 | if done: 29 | obs = self.env_interface.reset() 30 | 31 | # forward 32 | plan = planner.plan(*obs) # 监听exp_buffer记录了obs, plan 33 | self.env_interface.conduct_trajectory(plan) 34 | obs2 = self.env_interface.wrap_observation() 35 | 36 | # feedback 37 | reward, done = self.reward_function.evaluate_log(obs, plan, obs2) # 监听exp_buffer记录了reward, done 38 | policy.try_write_reward(reward, done, i) 39 | 40 | # 学习 41 | batched_data = exp_buffer.sample(batch_size) 42 | policy.learn_batch(*batched_data) 43 | 44 | # visualize 45 | if self._visualize: 46 | vis.cla() 47 | vis.lazy_draw(*obs, plan) 48 | vis.title(f"Step {i}, Reward {reward}") 49 | vis.pause(0.001) 50 | 51 | obs = obs2 52 | 53 | policy.set_exploration(enable=False) 54 | if __name__ == '__main__': 55 | from spider.interface import DummyInterface, DummyBenchmark 56 | from spider.planner_zoo.DQNPlanner import DQNPlanner 57 | from spider.planner_zoo.DDQNPlanner import DDQNPlanner 58 | from spider.rl.reward.TrajectoryReward import TrajectoryReward 59 | 60 | # presets 61 | ego_size = (5.,2.) 
62 | 63 | # setup env 64 | env_interface = DummyInterface() 65 | 66 | # setup reward 67 | reward_function = TrajectoryReward( 68 | (-10., 280.), (-15, 15), (240., 280.), (-10,10), ego_size 69 | ) 70 | 71 | # setup_planner 72 | planner_dqn = DQNPlanner({ 73 | "ego_veh_width": ego_size[1], 74 | "ego_veh_length": ego_size[0], 75 | "enable_tensorboard": True, 76 | }) 77 | 78 | planner_school = Trainner(env_interface, reward_function, visualize=False) 79 | planner_school.train(planner_dqn, 50000, 64) 80 | planner_dqn.policy.save_model('./q_net.pth') 81 | 82 | planner_dqn.policy.load_model('./q_net.pth') 83 | DummyBenchmark({"save_video": True,}).test(planner_dqn) -------------------------------------------------------------------------------- /DEBUGtest/RL_planner/train_ppo_planner.py: -------------------------------------------------------------------------------- 1 | import spider.visualize as vis 2 | import tqdm 3 | 4 | class Trainner: 5 | ''' 6 | todo:以后加一个把环境打包成gym环境的功能 7 | ''' 8 | def __init__(self, env_interface, reward_function, visualize=False): 9 | self.env_interface = env_interface 10 | self.reward_function = reward_function 11 | 12 | self.max_eps_len = 150 13 | 14 | self.n_epochs = 10 15 | self._visualize = visualize 16 | 17 | 18 | def train(self, planner, train_steps, batch_size=16): 19 | # todo: 是一个step触发训练,还是一个episode触发训练? 
20 | # 以及一轮训练的次数是1吗?可以参考stable baselines3 21 | 22 | policy = planner.policy 23 | exp_buffer = planner.exp_buffer 24 | 25 | exp_buffer.apply_to(policy, self.reward_function) # 开始监听 26 | 27 | obs, done = self.env_interface.reset(), False 28 | 29 | policy.set_exploration(enable=True) 30 | 31 | for i in tqdm.tqdm(range(train_steps)): 32 | 33 | # forward 34 | plan = planner.plan(*obs) # 监听exp_buffer记录了obs, plan 35 | self.env_interface.conduct_trajectory(plan) 36 | obs2 = self.env_interface.wrap_observation() 37 | 38 | # feedback 39 | reward, done = self.reward_function.evaluate_log(obs, plan, obs2) # 监听exp_buffer记录了reward, done 40 | policy.try_write_reward(reward, done, i) 41 | 42 | 43 | # visualize 44 | if self._visualize: 45 | vis.cla() 46 | vis.lazy_draw(*obs, plan) 47 | vis.title(f"Step {i}, Reward {reward}") 48 | vis.pause(0.001) 49 | 50 | if done: 51 | # 一个episode结束,更新网络参数,学习轨迹 52 | policy.learn_buffer(exp_buffer, batch_size,self.n_epochs) 53 | obs = self.env_interface.reset() 54 | exp_buffer.clear() 55 | else: 56 | obs = obs2 57 | 58 | policy._activate_exp_buffer = False 59 | policy.set_exploration(enable=False) 60 | 61 | 62 | if __name__ == '__main__': 63 | from spider.interface import DummyInterface, DummyBenchmark 64 | from spider.rl.reward.TrajectoryReward import TrajectoryReward 65 | from spider.planner_zoo.DiscretePPOPlanner import DiscretePPOPlanner 66 | 67 | # presets 68 | ego_size = (5.,2.) 
69 | 70 | # setup env 71 | env_interface = DummyInterface() 72 | 73 | # setup reward 74 | reward_function = TrajectoryReward( 75 | (-10., 280.), (-15, 15), (240., 280.), (-10,10), ego_size 76 | ) 77 | 78 | # setup_planner 79 | planner = DiscretePPOPlanner({ 80 | "ego_veh_width": ego_size[1], 81 | "ego_veh_length": ego_size[0], 82 | "enable_tensorboard": True, 83 | }) 84 | 85 | planner_school = Trainner(env_interface, reward_function, visualize=False) 86 | planner_school.train(planner, 10000) 87 | planner.policy.save_model('./ppo.pth') 88 | 89 | planner.policy.load_model('./ppo.pth') 90 | DummyBenchmark({"save_video": True,}).test(planner) -------------------------------------------------------------------------------- /DEBUGtest/curve_gallery.py: -------------------------------------------------------------------------------- 1 | import math 2 | from spider.elements.curves import * 3 | import matplotlib.pyplot as plt 4 | 5 | ######## 显式曲线 ####### 6 | 7 | x_0, x_end = 0., 30. 8 | y_0, y_end = 0., 3.5 9 | yaw0 = -10*math.pi/180 10 | y_prime_0 = math.tan(yaw0) 11 | 12 | yaw_end = 0*math.pi/180 13 | y_prime_end = math.tan(yaw_end) 14 | 15 | y_2prime_0, y_2prime_end = 0., 0. 
16 | 17 | poly3 = CubicPolynomial.from_kine_states(x_0, y_0, yaw0, x_end, y_end, yaw_end) 18 | poly5 = QuinticPolynomial.from_kine_states(y_0, y_prime_0, y_2prime_0, y_end, y_prime_end, y_2prime_end, x_end) 19 | poly4 = QuarticPolynomial.from_kine_states(y_0, y_prime_0, y_2prime_0, y_prime_end, y_2prime_end, x_end) 20 | 21 | all_points_with_derivatives = np.array([ 22 | [x_0, y_0, y_prime_0, y_2prime_0], 23 | [10, 3.5, 0., 0.], 24 | [20, 1.0, 0., 0.], 25 | [x_end, y_end, y_prime_end, y_2prime_end] 26 | ]) 27 | piecewise_poly5 = PiecewiseQuinticPolynomial(all_points_with_derivatives) 28 | 29 | 30 | 31 | xx = np.linspace(x_0-2, x_end+2, 1000) 32 | rows, cols = 2,2 33 | plt.figure(figsize=(12,8)) 34 | plt.subplot(rows,cols,1) 35 | # plt.title 36 | plt.plot(xx, poly3(xx, order=0), label='cubic') 37 | plt.plot(xx, poly4(xx, order=0), label='quartic') 38 | plt.plot(xx, poly5(xx, order=0), label='quintic') 39 | plt.plot(xx, piecewise_poly5(xx, order=0), label='piecewise_quintic') 40 | plt.plot(x_0, poly3(x_0, order=0),'xr') 41 | plt.plot(x_0, poly4(x_0, order=0),'xr') 42 | plt.plot(x_0, poly5(x_0, order=0),'xr') 43 | plt.plot(x_0, piecewise_poly5(x_0, order=0), 'xr') 44 | plt.legend() 45 | plt.subplot(rows,cols,2) 46 | plt.plot(xx, poly3(xx, order=1), label='cubic') 47 | plt.plot(xx, poly4(xx, order=1), label='quartic') 48 | plt.plot(xx, poly5(xx, order=1), label='quintic') 49 | plt.plot(xx, piecewise_poly5(xx, order=1), label='piecewise_quintic') 50 | plt.legend() 51 | plt.subplot(rows,cols,3) 52 | plt.plot(xx, poly3(xx, order=2), label='cubic') 53 | plt.plot(xx, poly4(xx, order=2), label='quartic') 54 | plt.plot(xx, poly5(xx, order=2), label='quintic') 55 | plt.plot(xx, piecewise_poly5(xx, order=2), label='piecewise_quintic') 56 | plt.legend() 57 | plt.subplot(rows,cols,4) 58 | plt.plot(xx, poly3(xx, order=3), label='cubic') 59 | plt.plot(xx, poly4(xx, order=3), label='quartic') 60 | plt.plot(xx, poly5(xx, order=3), label='quintic') 61 | plt.plot(xx, piecewise_poly5(xx, 
order=3), label='piecewise_quintic') 62 | plt.legend() 63 | 64 | 65 | 66 | 67 | 68 | plt.show() 69 | 70 | 71 | 72 | 73 | -------------------------------------------------------------------------------- /DEBUGtest/data.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Thu-ADLab/SPIDER/a1a2f81a29da0c63f1a46ad04740d6031c78e885/DEBUGtest/data.pt -------------------------------------------------------------------------------- /DEBUGtest/debug_frenet.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from spider.utils.transform.frenet import FrenetTransformer 3 | from spider.elements.trajectory import FrenetTrajectory, Trajectory 4 | import matplotlib.pyplot as plt 5 | import torch 6 | 7 | data = torch.load('data.pt') 8 | 9 | traj_arr = data['traj'].numpy() 10 | cline = data['base_centerline'].numpy() 11 | # indices_to_remove = np.where(np.arange(len(cline)) % 10 == 0) 12 | # cline = np.delete(cline, indices_to_remove, axis=0) 13 | 14 | plt.figure() 15 | plt.plot(cline[:,0], cline[:,1], '.-') 16 | plt.plot(traj_arr[:,0], traj_arr[:,1]) 17 | 18 | 19 | transformer = FrenetTransformer() 20 | transformer.set_reference_line(cline) 21 | 22 | plt.figure() 23 | ss = np.linspace(0,168,1000) 24 | cline_interp = transformer.refer_line_csp.calc_point(ss) 25 | plt.subplot(131) 26 | plt.plot(ss, cline_interp[:, 0],lw=2) 27 | plt.plot(transformer.refer_line_csp.s, transformer.refer_line_csp.x, '.') 28 | plt.subplot(132) 29 | plt.plot(ss, cline_interp[:, 1],lw=2) 30 | plt.plot(transformer.refer_line_csp.s, transformer.refer_line_csp.y , '.') 31 | plt.subplot(133) 32 | plt.plot(cline_interp[:,0], cline_interp[:,1]) 33 | 34 | traj = Trajectory.from_trajectory_array(traj_arr, dt=0.1,calc_derivative=False) 35 | frenet_traj = transformer.cart2frenet4traj(traj, order=0) 36 | plt.figure() 37 | plt.plot(frenet_traj.s, frenet_traj.l) 38 | 39 | 40 | plt.figure() 41 | 
plt.plot(cline[:50,0], cline[:50,1], '.-') 42 | plt.plot(traj_arr[:,0], traj_arr[:,1]) 43 | for i, (x,y) in enumerate(cline): 44 | plt.text(x,y , str(i)) 45 | 46 | 47 | 48 | plt.show() 49 | pass 50 | # import pickle 51 | # 52 | # f = open('data.pt',) 53 | # data = pickle.load('data.pt') 54 | -------------------------------------------------------------------------------- /DEBUGtest/planner_arena.py: -------------------------------------------------------------------------------- 1 | from spider.interface.BaseBenchmark import DummyBenchmark 2 | from spider.planner_zoo import * 3 | from spider.planner_zoo.DQNPlanner import DQNPlanner 4 | from spider.planner_zoo.DDQNPlanner import DDQNPlanner 5 | from spider.planner_zoo.DiscretePPOPlanner import DiscretePPOPlanner 6 | from spider.planner_zoo.ProbabilisticPlanner import ProbabilisticPlanner 7 | from spider.planner_zoo.MlpPlanner import MlpPlanner 8 | from spider.planner_zoo.GRUPlanner import GRUPlanner 9 | 10 | planners = [] 11 | 12 | planner = FallbackDummyPlanner({ 13 | "steps": 20, 14 | "dt": 0.2, 15 | "print_info": False 16 | }) 17 | planners.append(planner) 18 | 19 | planner = LatticePlanner({ 20 | "steps": 20, 21 | "dt": 0.2, 22 | "print_info": False 23 | }) 24 | planners.append(planner) 25 | 26 | planner = OptimizedLatticePlanner({ 27 | "steps": 20, 28 | "dt": 0.2, 29 | "print_info": False 30 | }) 31 | planners.append(planner) 32 | 33 | 34 | planner = BezierPlanner({ 35 | "steps": 20, 36 | "dt": 0.2, 37 | "end_s_candidates": (20,30), 38 | "end_l_candidates": (-3.5, 0, 3.5), 39 | "end_v_candidates": tuple(index * 60 / 3.6 / 3 for index in range(4)), # 改这一项的时候,要连着限速一起改了 40 | "end_T_candidates": (2, 4, 8), # s_dot, T采样生成纵向轨迹 41 | "print_info": False 42 | }) 43 | planners.append(planner) 44 | 45 | planner = PiecewiseLatticePlanner({ 46 | "steps": 20, 47 | "dt": 0.2, 48 | "print_info": False 49 | }) 50 | planners.append(planner) 51 | 52 | planner = ImaginaryPlanner({ 53 | "steps": 20, 54 | "dt": 0.2, 55 | 
"print_info": False 56 | }) 57 | planners.append(planner) 58 | 59 | planner = DQNPlanner({ 60 | "steps": 20, 61 | "dt": 0.2, 62 | "print_info": False 63 | }) 64 | planner.policy.load_model('./RL_planner/q_net.pth') 65 | planners.append(planner) 66 | 67 | planner = DiscretePPOPlanner({ 68 | "steps": 20, 69 | "dt": 0.2, 70 | "print_info": False 71 | }) 72 | planner.policy.load_model('./RL_planner/ppo.pth') 73 | planners.append(planner) 74 | # 75 | # 76 | # planner = DDQNPlanner({ 77 | # "steps": 20, 78 | # "dt": 0.2, 79 | # "num_object": 5, 80 | # "print_info": False 81 | # }) 82 | # planner.policy.load_model('./RL_planner/q_net_bes_ddqn.pth') 83 | # planners.append(planner) 84 | 85 | ########## IL ########## 86 | il_cfg = { 87 | "steps": 20, 88 | "dt": 0.2, 89 | "num_object": 5, 90 | "normalize": False, 91 | "relative": False, 92 | "longitudinal_range": (-50, 100), 93 | "lateral_range": (-20,20), 94 | 95 | "learning_rate": 0.0001, 96 | "enable_tensorboard": False, 97 | } 98 | 99 | planner = MlpPlanner(il_cfg) 100 | planner.load_state_dict('./mlp.pth') 101 | planners.append(planner) 102 | 103 | planner = GRUPlanner(il_cfg) 104 | planner.load_state_dict('./gru.pth') 105 | planners.append(planner) 106 | 107 | planner = ProbabilisticPlanner(il_cfg) 108 | planner.load_state_dict('./prob.pth') 109 | planners.append(planner) 110 | 111 | 112 | benchmark = DummyBenchmark({ 113 | # "save_video": True, 114 | # "debug_mode": True, 115 | "snapshot": False, 116 | "evaluation": True, 117 | "rendering": False, 118 | }) 119 | 120 | import random 121 | 122 | for planner in planners: 123 | random.seed(0) 124 | 125 | print("--------------------------------------") 126 | print("Planner: ", planner.__class__.__name__) 127 | benchmark.test(planner, 10) 128 | -------------------------------------------------------------------------------- /DEBUGtest/profile_cubicspline.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import 
matplotlib.pyplot as plt 3 | from tqdm import tqdm 4 | 5 | from spider.elements.curves import myCubicSpline as mycsp 6 | from spider.elements.curves import spCubicSpline as my_spcsp 7 | from scipy.interpolate import CubicSpline as spcsp 8 | 9 | 10 | def test_spcsp(): 11 | x = np.arange(10) 12 | y = np.sin(x) 13 | cs = spcsp(x, y) 14 | xs = np.arange(-0.5, 9.6, 0.1) 15 | fig, ax = plt.subplots(figsize=(6.5, 4)) 16 | ax.plot(x, y, 'o', label='data') 17 | ax.plot(xs, np.sin(xs), label='true') 18 | ax.plot(xs, cs(xs), label="S") 19 | ax.plot(xs, cs(xs, 1), label="S'") 20 | ax.plot(xs, cs(xs, 2), label="S''") 21 | ax.plot(xs, cs(xs, 3), label="S'''") 22 | ax.set_xlim(-0.5, 9.5) 23 | ax.legend(loc='lower left', ncol=2) 24 | plt.show() 25 | 26 | 27 | if __name__ == '__main__': 28 | # test_spcsp() 29 | # assert 0 30 | 31 | # 生成一些二维点序列,这里以随机点为例 32 | np.random.seed(6) 33 | x = np.arange(10) + np.random.normal(0, 0.3, 10) 34 | y = np.sin(x) + np.random.normal(0, 0.1, len(x)) 35 | 36 | x = np.append(x, 20) 37 | y = np.append(y, y[-1]) 38 | 39 | 40 | # 使用三次样条插值 41 | csp1 = mycsp(x, y) 42 | csp2 = my_spcsp(x, y) 43 | csp3 = spcsp(x, y, bc_type='natural') 44 | 45 | x_interp = np.linspace(min(x), max(x)+5, 100) 46 | 47 | ############## 计时 #################### 48 | for csp in [csp1, csp2, csp3]: 49 | for _ in tqdm(range(10000),desc=str(csp.__class__)): 50 | y_interp = csp(x_interp) 51 | ########################################## 52 | # : 100%|██████████| 100000/100000 [00:04<00:00, 21689.01it/s] 53 | # : 100%|██████████| 100000/100000 [00:00<00:00, 282696.19it/s] 54 | # 离谱,scipy自带的比我设计的要快10倍以上 55 | 56 | 57 | # 生成插值点 58 | plt.figure(figsize=(12,8)) 59 | for order in range(4): 60 | plt.subplot(2,2,order+1) 61 | for csp in [csp1, csp2, csp3]: 62 | 63 | y_interp = csp(x_interp, order) 64 | # 绘制原始点和插值曲线 65 | plt.plot(x_interp, y_interp, label=str(csp.__class__).split('.')[-1]) 66 | 67 | if order == 0: plt.scatter(x, y, label='Original Points') 68 | plt.xlabel('X') 69 | plt.ylabel('Y') 
70 | plt.legend() 71 | plt.title('Cubic Spline for order '+str(order)) 72 | plt.show() 73 | 74 | 75 | 76 | 77 | -------------------------------------------------------------------------------- /DEBUGtest/profile_polynomial.py: -------------------------------------------------------------------------------- 1 | import time 2 | 3 | import matplotlib.pyplot as plt 4 | import numpy as np 5 | from tqdm import tqdm 6 | 7 | valid_x_range = [0,20] 8 | 9 | class tic: 10 | t1 = -1 11 | i = 0 12 | def toc(self): 13 | if self.t1 < 0: 14 | self.t1 = time.time() 15 | else: 16 | t2 = time.time() 17 | print(t2-self.t1,'--'+str(self.i)) 18 | self.t1 = t2 19 | 20 | 21 | def _out_of_range_flag(x): 22 | ''' 23 | if x is a scalar, return bool 24 | if x is an array, return a boolean array with the same size 25 | ''' 26 | return (x < valid_x_range[0]) | (x > valid_x_range[1]) # 不可以用or, 因为要考虑array情况 27 | 28 | 29 | coef = [1,2,3,4] 30 | x = np.linspace(0,19,50) 31 | t = tic() 32 | 33 | poly_func = np.poly1d(coef) 34 | 35 | 36 | # from scipy.interpolate import CubicSpline as spcsp 37 | # csp = spcsp(x, np.random.rand(x.size),bc_type='natural') 38 | # yy = csp(np.linspace(0,20,50)) 39 | 40 | 41 | def calculate1(x): 42 | x = np.asarray(x, dtype=float) 43 | extra_mask = _out_of_range_flag(x) 44 | val = np.empty_like(x) 45 | val[extra_mask] = np.polyval(coef, x[-1]) * np.ones(np.sum(extra_mask)) # 数据范围外,按外推规则计算 46 | val[~extra_mask] = np.polyval(coef, x[~extra_mask]) # 数据范围内,直接计算 47 | return val 48 | 49 | def calculate2(x): 50 | x = np.asarray(x, dtype=float) 51 | extra_mask = _out_of_range_flag(x) 52 | val = np.empty_like(x) 53 | val[extra_mask] = poly_func(x[-1]) * np.ones(np.sum(extra_mask)) # 数据范围外,按外推规则计算 54 | val[~extra_mask] = poly_func(x[~extra_mask]) # 数据范围内,直接计算 55 | return val 56 | 57 | def calculate3(x): 58 | val = poly_func(x) 59 | return val 60 | 61 | def calculate4(x): 62 | val = np.polyval(coef, x) 63 | return val 64 | 65 | def calculate5(x): 66 | # val = np.polyval(coef, x) 67 
| val = np.where((x < valid_x_range[0]) | (x > valid_x_range[1]), poly_func(x[-1]), poly_func(x)) 68 | return val 69 | # extra_mask = _out_of_range_flag(x) 70 | # val[extra_mask] = poly_func(x[-1]) * np.ones(np.sum(extra_mask)) # 数据范围外,按外推规则计算 71 | 72 | def calculate6(x): 73 | x = np.asarray(x, dtype=float) 74 | if (x.min()< valid_x_range[0]) or (x.max() > valid_x_range[1]): 75 | val = np.where((x < valid_x_range[0]) | (x > valid_x_range[1]), poly_func(x), poly_func(x[-1])) 76 | # extra_mask = _out_of_range_flag(x) 77 | # val = np.empty_like(x) 78 | # val[extra_mask] = np.polyval(coef, x[-1]) * np.ones(np.sum(extra_mask)) # 数据范围外,按外推规则计算 79 | # val[~extra_mask] = np.polyval(coef, x[~extra_mask]) # 数据范围内,直接计算 80 | else: 81 | val = np.polyval(coef, x) 82 | return val 83 | 84 | import pyximport 85 | pyximport.install( language_level=3)#'../elements/poly_calc.pyx', 86 | poly_calc = pyximport.load_module('poly_calc', '../elements/poly_calc.pyx', language_level=3) 87 | # import poly_calc 88 | # # 导入编译后的模块 89 | 90 | y_0, y_prime_0, y_end, y_prime_end = 1,-1,10,1 91 | 92 | def calculate0(x): 93 | # x = np.asarray(x) 94 | if np.isscalar(x) and np.ndim(x) == 0: 95 | assert 0 96 | x = np.ascontiguousarray(x, dtype=np.float64) 97 | c = np.ascontiguousarray(coef, dtype=np.float64) 98 | val = poly_calc.evaluate(x, c, valid_x_range[0], valid_x_range[1], 0, y_0, y_prime_0, y_end, y_prime_end) 99 | return np.asarray(val) 100 | 101 | 102 | 103 | # import cProfile 104 | # cProfile.run('for _ in range(10000): calculate1(x)') 105 | all_funcs = [calculate0, calculate1, calculate2, calculate3, calculate4,calculate5,calculate6] 106 | 107 | all_val = [func(x) for func in all_funcs] 108 | 109 | for func in all_funcs: 110 | for _ in tqdm(range(100000),desc=str(func.__name__)): 111 | val = func(x) 112 | # print(val) 113 | 114 | 115 | -------------------------------------------------------------------------------- /DEBUGtest/test_SAT.py: 
-------------------------------------------------------------------------------- 1 | import numpy as np 2 | import matplotlib.pyplot as plt 3 | 4 | from spider.elements import FrenetTrajectory, TrackingBoxList, TrackingBox 5 | from spider.elements.box import obb2vertices 6 | from spider.utils.collision import BoxCollisionChecker 7 | from spider.elements.curves import QuinticPolynomial 8 | 9 | np.random.seed(0) 10 | 11 | def check_and_draw(tb): 12 | tb_list = TrackingBoxList() 13 | tb_list.append(tb) 14 | 15 | predicted_obstacles = tb_list.predict(traj.dt * np.arange(traj.steps)) 16 | vertices = tb.vertices 17 | vertices = np.vstack((vertices, vertices[0])) # recurrent to close polyline 18 | if checker.check_trajectory(traj, predicted_obstacles): 19 | plt.plot(vertices[:, 0], vertices[:, 1], color='red', linestyle='-', linewidth=1.5) # 画他车 20 | return True 21 | else: 22 | plt.plot(vertices[:, 0], vertices[:, 1], color='green', linestyle='-', linewidth=1.5) # 画他车 23 | return False 24 | 25 | 26 | qp = QuinticPolynomial() 27 | qp.two_point_boundary_value(100.,100.,0,0,150.,120,0,0) 28 | 29 | traj = FrenetTrajectory(50) 30 | traj.x = np.linspace(100., 150., 50) 31 | traj.y = qp.calc_point(traj.x)#np.linspace(400000., 400020., 50) 32 | traj.heading = qp.calc_yaw(traj.x)#np.ones_like(traj.x) * np.arctan2(20,50) 33 | 34 | 35 | plt.figure(figsize=(12,8)) 36 | # 画自车轨迹 37 | plt.plot(traj.x, traj.y) 38 | # 画自车脚印 39 | for x, y, yaw in zip(traj.x, traj.y, traj.heading): 40 | vertices = obb2vertices((x,y,5.,2.,yaw)) 41 | vertices = np.vstack((vertices, vertices[0])) # recurrent to close polyline 42 | plt.plot(vertices[:, 0], vertices[:, 1], color='gray', linestyle='-', linewidth=1) 43 | 44 | 45 | checker = BoxCollisionChecker(5.,2.) 
46 | 47 | tb = TrackingBox(obb=(119,112.5,4,2,np.arctan(3/5)), vx=0,vy=0) 48 | check_and_draw(tb) 49 | # plt.show() 50 | 51 | 52 | count = 0 53 | np.random.seed(0) 54 | for i in range(500): 55 | x = np.random.random()*80 + 100-20 56 | y = np.random.random()*50 + 100-10 57 | heading = np.random.random()*3.14*2 - 3.14 58 | tb = TrackingBox(obb=(x,y,4,2,heading), vx=0,vy=0) 59 | if check_and_draw(tb): 60 | count += 1 61 | 62 | plt.title("SAT: collision report num is " + str(count)) 63 | plt.gca().set_aspect("equal") 64 | plt.show() 65 | 66 | 67 | 68 | -------------------------------------------------------------------------------- /DEBUGtest/test_carla_interface.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | import random 3 | import carla 4 | import pygame 5 | from tqdm import tqdm 6 | 7 | import spider.planner_zoo 8 | from spider.interface.carla.CarlaInterface import CarlaInterface 9 | from spider.planner_zoo import LatticePlanner 10 | import spider.visualize as vis 11 | 12 | import pygame 13 | 14 | 15 | # host = '192.168.3.107' # 无线局域网 16 | host = '192.168.189.13' # 有线局域网 17 | # host = "101.5.228.94" 18 | # host = '127.0.0.1' 19 | cport = 2000 20 | tmport = 8000 21 | # random.seed(123) 22 | # random.seed(40) 23 | random.seed(11) 24 | env = CarlaInterface( 25 | host, 26 | cport, 27 | tmport, 28 | recording=True 29 | ) 30 | 31 | planner = LatticePlanner({ 32 | "steps": 20, 33 | "dt": 0.2, 34 | "ego_veh_length": env.ego_size[0], # !!!这边还没有spawn hero 35 | "ego_veh_width": env.ego_size[1], 36 | "end_s_candidates": (10, 30),#(20,40,60), 37 | "end_l_candidates": (0,),#(-0.5,0,0.5), 38 | "end_v_candidates": tuple(i*60/3.6/3 for i in range(4)), 39 | "constraint_flags": {}, 40 | "print_info": False 41 | }) 42 | 43 | SPIDER_PLOT = 0 44 | # planner = spider.planner_zoo.DummyPlanner() 45 | 46 | try: 47 | # if 1: 48 | 49 | maps = env.client.get_available_maps() 50 | print(maps) 51 | 52 | map_name = 
'Town10HD' 53 | env.load_map(map_name) 54 | 55 | # print(env.map.name) 56 | # if env.map is not None and not (map_name in env.map.name): 57 | # print("loading map...") 58 | # # env.load_map('Town10HD_Opt', map_layers=carla.MapLayer.Ground) 59 | # env.load_map(map_name) 60 | # else: 61 | # env.destroy() 62 | # env.random_weather(True) 63 | 64 | env.spawn_hero(autopilot=False) 65 | env.generate_traffic(50,5) 66 | 67 | env.bev_spectator(30, 5, 5, True) 68 | # env.side_view_spectator(left=False) 69 | # env.third_person_spectator() 70 | # env.first_person_spectator() 71 | 72 | env.viewer.change_view("first_person") 73 | vis.figure() 74 | # display = None 75 | display = pygame.display.set_mode( 76 | env.viewer.image_size, 77 | pygame.HWSURFACE | pygame.DOUBLEBUF) 78 | 79 | snapshot = vis.SnapShot(False, record_video=True, video_path='carla_visualize_recording.avi', fps=20) 80 | 81 | 82 | ############## 83 | # test:不允许自动换道 84 | for actor in env.npc_vehicles: 85 | env.traffic_manager.auto_lane_change(actor, False) 86 | ########### 87 | 88 | delta_steps = 1 # int(planner.dt /0.05)#/ 0.05) 89 | traj = None 90 | for i in tqdm(range(1000)): 91 | 92 | env.world.tick() 93 | 94 | 95 | env.render(display) 96 | pygame.display.flip() 97 | 98 | 99 | ego_veh_state, tb_list, routed_local_map = env.wrap_observation() 100 | if traj is None or i%delta_steps==0: 101 | traj = planner.plan(ego_veh_state, tb_list, routed_local_map) 102 | else: 103 | traj.x = traj.x[1:] 104 | traj.y = traj.y[1:] 105 | traj.heading = traj.heading[1:] 106 | traj.v = traj.v[1:] 107 | 108 | # assert traj is not None 109 | 110 | if traj is not None: 111 | for x, y in zip(traj.x, traj.y): 112 | loc = carla.Location(x=x, y=y, z=env.hero.get_location().z + 0.5) 113 | env.world.debug.draw_point(loc, size=0.1, life_time=0.1) 114 | 115 | env.conduct_trajectory(traj, ego_veh_state) 116 | 117 | 118 | if (traj is not None) and SPIDER_PLOT: 119 | # pass 120 | plt.cla() 121 | vis.draw_ego_vehicle(ego_veh_state, color='C0', 
fill=True, alpha=0.3, linestyle='-', linewidth=1.5) 122 | vis.draw_trackingbox_list(tb_list,draw_prediction=False) 123 | vis.draw_local_map(routed_local_map) 124 | vis.draw_trajectory(traj, '.-', show_footprint=True, color='C2') # 画轨迹 125 | vis.ego_centric_view(ego_veh_state.x(), ego_veh_state.y(),(-50,50),(-50,50)) 126 | plt.pause(0.01) 127 | snapshot.snap(plt.gca()) 128 | 129 | if env.has_arrived(): 130 | print("Hero has arrived!") 131 | break 132 | 133 | 134 | except Exception as e: 135 | print(e) 136 | 137 | finally: 138 | env.destroy() 139 | 140 | if "snapshot" in dir(): 141 | snapshot.release_writer() 142 | 143 | 144 | -------------------------------------------------------------------------------- /DEBUGtest/test_disk_check.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import matplotlib.pyplot as plt 3 | 4 | import spider 5 | from spider.elements import FrenetTrajectory, TrackingBoxList, TrackingBox 6 | from spider.elements.box import obb2vertices 7 | from spider.utils.collision import BoxCollisionChecker 8 | from spider.elements.curves import QuinticPolynomial 9 | 10 | np.random.seed(0) 11 | 12 | def check_and_draw(tb): 13 | tb_list = TrackingBoxList() 14 | tb_list.append(tb) 15 | 16 | predicted_obstacles = tb_list.predict(traj.dt * np.arange(traj.steps)) 17 | vertices = tb.vertices 18 | vertices = np.vstack((vertices, vertices[0])) # recurrent to close polyline 19 | if checker.check_trajectory(traj, predicted_obstacles): 20 | plt.plot(vertices[:, 0], vertices[:, 1], color='red', linestyle='-', linewidth=1.5) # 画他车 21 | return True 22 | else: 23 | plt.plot(vertices[:, 0], vertices[:, 1], color='green', linestyle='-', linewidth=1.5) # 画他车 24 | return False 25 | 26 | qp = QuinticPolynomial() 27 | qp.two_point_boundary_value(100.,100.,0,0,150.,120,0,0) 28 | 29 | traj = FrenetTrajectory(50) 30 | traj.x = np.linspace(100., 150., 50) 31 | traj.y = qp.calc_point(traj.x)#np.linspace(400000., 400020., 
50) 32 | traj.heading = qp.calc_yaw(traj.x)#np.ones_like(traj.x) * np.arctan2(20,50) 33 | 34 | 35 | plt.figure(figsize=(12,8)) 36 | # 画自车轨迹 37 | plt.plot(traj.x, traj.y) 38 | # 画自车脚印 39 | for x, y, yaw in zip(traj.x, traj.y, traj.heading): 40 | vertices = obb2vertices((x,y,5.,2.,yaw)) 41 | vertices = np.vstack((vertices, vertices[0])) # recurrent to close polyline 42 | plt.plot(vertices[:, 0], vertices[:, 1], color='gray', linestyle='-', linewidth=1) 43 | 44 | 45 | checker = BoxCollisionChecker(5.,2.,spider.COLLISION_CHECKER_DISK) 46 | 47 | 48 | count = 0 49 | np.random.seed(0) 50 | for i in range(500): 51 | x = np.random.random()*80 + 100-20 52 | y = np.random.random()*50 + 100-10 53 | heading = np.random.random()*3.14*2 - 3.14 54 | tb = TrackingBox(obb=(x,y,4,2,heading), vx=0,vy=0) 55 | if check_and_draw(tb): 56 | count += 1 57 | 58 | plt.title("disk: collision report num is " + str(count)) 59 | plt.gca().set_aspect("equal") 60 | plt.show() 61 | 62 | 63 | 64 | 65 | -------------------------------------------------------------------------------- /DEBUGtest/test_dummybenchmark.py: -------------------------------------------------------------------------------- 1 | from spider.interface.BaseBenchmark import DummyBenchmark 2 | from spider.planner_zoo import * 3 | 4 | benchmark = DummyBenchmark({ 5 | "save_video": True, 6 | "debug_mode": True 7 | }) 8 | 9 | planner = FallbackDummyPlanner({ 10 | "steps": 20, 11 | "dt": 0.2, 12 | "print_info": False 13 | }) 14 | 15 | # planner = LatticePlanner({ 16 | # "steps": 20, 17 | # "dt": 0.2, 18 | # }) 19 | 20 | # planner = DummyPlanner() 21 | 22 | # planner = BezierPlanner({ 23 | # "steps": 20, 24 | # "dt": 0.2, 25 | # "end_s_candidates": (20,30), 26 | # "end_l_candidates": (-3.5, 0, 3.5), 27 | # "end_v_candidates": tuple(index * 60 / 3.6 / 3 for index in range(4)), # 改这一项的时候,要连着限速一起改了 28 | # "end_T_candidates": (2, 4, 8), # s_dot, T采样生成纵向轨迹 29 | # }) 30 | 31 | # planner = PiecewiseLatticePlanner({ 32 | # "steps": 20, 33 | # 
"dt": 0.2, 34 | # }) 35 | 36 | # planner = ImaginaryPlanner() 37 | 38 | # planner = OptimizedLatticePlanner({}) 39 | 40 | # planner = FallbackPlanner() 41 | 42 | 43 | benchmark.test(planner) 44 | 45 | -------------------------------------------------------------------------------- /DEBUGtest/test_exp_buffer.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | import torch.nn as nn 4 | 5 | import spider 6 | from spider.data.DataBuffer import ExperienceBuffer 7 | 8 | class MlpPolicy(nn.Module): 9 | def __init__(self, obs_space, action_space): 10 | super().__init__() 11 | 12 | self.obs_space = obs_space 13 | self.action_space = action_space 14 | self.device = torch.device("cpu") 15 | 16 | self.actor = nn.Sequential( 17 | nn.Linear(obs_space, 16), 18 | nn.ReLU(), 19 | nn.Linear(16, 16), 20 | nn.ReLU(), 21 | nn.Linear(16, action_space), 22 | nn.Tanh() 23 | ) 24 | # self.action_head = nn.Linear(16, action_space * 2) 25 | 26 | def forward(self, x): 27 | return self.actor(x) 28 | 29 | 30 | def test1(): 31 | exp_buffer = ExperienceBuffer( 32 | forward_only=True, 33 | subdir_prefix='exp', 34 | data_root="./dataset_exp/", 35 | autosave_max_intervals=50, 36 | file_format=spider.DATA_FORMAT_TENSOR 37 | # file_format=spider.DATASET_FORMAT_RAW 38 | ) 39 | 40 | policy = MlpPolicy(64, 64) 41 | 42 | exp_buffer.apply_to(policy) 43 | 44 | for step in range(168): 45 | obs = torch.rand(64) 46 | action = policy(obs) 47 | 48 | exp_buffer.release() 49 | 50 | 51 | 52 | if __name__ == '__main__': 53 | test1() 54 | 55 | 56 | 57 | 58 | 59 | -------------------------------------------------------------------------------- /DEBUGtest/test_frenet.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from spider.utils.transform.frenet import FrenetTransformer 3 | from spider.elements.trajectory import FrenetTrajectory 4 | 5 | 6 | xs = np.linspace(0,100,101) 7 | ys = 
np.linspace(0,-1,101) 8 | centerline = np.column_stack((xs, ys)) 9 | transformer = FrenetTransformer() 10 | transformer.set_reference_line(centerline) 11 | 12 | 13 | # 转某个点的坐标 14 | x, y, speed, yaw, acc, kappa = 50, 1, 5, np.pi/4, 3, 0. 15 | state = transformer.cart2frenet(x, y, speed, yaw, acc, kappa, order=2) 16 | print("s,l,s_dot, l_dot, l_prime,s_2dot, l_2dot, l_2prime") 17 | print(state.s, state.l, state.s_dot, state.l_dot, state.l_prime,state.s_2dot, state.l_2dot, state.l_2prime) 18 | 19 | print("======================") 20 | s, l, s_dot, l_dot, l_prime, s_2dot, l_2dot, l_2prime = \ 21 | state.s, state.l, state.s_dot, state.l_dot, state.l_prime, state.s_2dot, state.l_2dot, state.l_2prime 22 | state = transformer.frenet2cart(s, l, s_dot, l_dot, l_prime, s_2dot, l_2dot, l_2prime, order=2) 23 | print("x, y, speed, yaw, acc, kappa") 24 | print(state.x, state.y, state.speed, state.yaw, state.acceleration, state.curvature) 25 | 26 | 27 | # 转某个轨迹的坐标 28 | xs = np.linspace(0,50,50) 29 | ys = np.linspace(0,1,50) 30 | 31 | 32 | traj = FrenetTrajectory(steps=50, dt=0.1) 33 | traj.x = xs 34 | traj.y = ys 35 | 36 | print("======================") 37 | frenet_traj = transformer.cart2frenet4traj(traj,order=0) 38 | print(frenet_traj.s, frenet_traj.l) 39 | 40 | 41 | import matplotlib.pyplot as plt 42 | plt.subplot(1,2,1) 43 | plt.plot(traj.x, traj.y) 44 | plt.subplot(1,2,2) 45 | plt.plot(frenet_traj.s, frenet_traj.l) 46 | 47 | 48 | 49 | frenet_traj = FrenetTrajectory(steps=50, dt=0.1) 50 | frenet_traj.s = xs 51 | frenet_traj.l = ys 52 | 53 | print("======================") 54 | cart_traj = transformer.frenet2cart4traj(frenet_traj, order=0) 55 | print(cart_traj.x, cart_traj.y) 56 | 57 | plt.show() 58 | -------------------------------------------------------------------------------- /DEBUGtest/test_gru_planner.py: -------------------------------------------------------------------------------- 1 | 2 | from spider.data.Dataset import OfflineLogDataset 3 | from 
spider.planner_zoo.GRUPlanner import GRUPlanner 4 | 5 | train = 0 6 | test_mode_closed_loop = 1 7 | 8 | # setup the planner 9 | planner = GRUPlanner({ 10 | "steps": 20, 11 | "dt": 0.2, 12 | "num_object": 5, 13 | "normalize": False, 14 | "relative": False, 15 | "longitudinal_range": (-50, 100), 16 | "lateral_range": (-20,20), 17 | 18 | "learning_rate": 0.0001, 19 | "enable_tensorboard": True, 20 | "tensorboard_root": './tensorboard/' 21 | }) 22 | 23 | # setup the dataset 24 | dataset = OfflineLogDataset('./dataset_map/', planner.state_encoder, planner.action_encoder) 25 | train_loader = dataset.get_dataloader(batch_size=32, shuffle=True) #DataLoader(dataset, batch_size=64, shuffle=True) 26 | 27 | if train: 28 | # train_mode the planner 29 | planner.policy.learn_dataset(50, train_loader=train_loader) 30 | 31 | # save the model 32 | planner.save_state_dict('gru.pth') 33 | 34 | # load the model 35 | planner.load_state_dict('gru.pth') 36 | # planner.load_state_dict('gru_best.pth') 37 | 38 | # test the planner 39 | 40 | if test_mode_closed_loop: 41 | from spider.interface.BaseBenchmark import DummyBenchmark 42 | benchmark = DummyBenchmark({ 43 | "save_video": True, 44 | }) 45 | benchmark.test(planner) 46 | else: 47 | dataset.replay(planner, 0, recording=True) 48 | -------------------------------------------------------------------------------- /DEBUGtest/test_highway_env.py: -------------------------------------------------------------------------------- 1 | from spider.interface.highway_env import HighwayEnvBenchmark, HighwayEnvBenchmarkGUI 2 | from spider.planner_zoo import LatticePlanner 3 | 4 | from tqdm import tqdm 5 | 6 | 7 | def test_benchmark_api(): 8 | steps, dt = 20, 0.2 9 | 10 | benchmark = HighwayEnvBenchmark(dt=dt, config={"max_steps":200}) 11 | 12 | 13 | planner = LatticePlanner({ 14 | "steps": steps, 15 | 'dt' : dt, 16 | "max_speed": 120/3.6, 17 | "end_s_candidates": (20, 30, 40, 80), 18 | "end_l_candidates": (-4,0,4), # s,d采样生成横向轨迹 (-3.5, 0, 3.5), # 19 | 
"end_v_candidates": tuple(i*120/3.6/4 for i in range(5)), # 改这一项的时候,要连着限速一起改了 20 | "end_T_candidates": (1,2,4,8), # s_dot, T采样生成纵向轨迹 21 | "print_info": False 22 | }) 23 | 24 | 25 | obs, info = benchmark.initial_environment() 26 | ego_veh_state, perception, local_map = benchmark.interface.wrap_observation(obs) 27 | benchmark.test(planner) 28 | 29 | # import cProfile 30 | # cProfile.run('for _ in range(10): output = planner.plan(ego_veh_state, perception, local_map)') 31 | # assert 0 32 | 33 | # cProfile.run('benchmark.test(planner)') 34 | 35 | 36 | def test_benchmark_gui(): 37 | HighwayEnvBenchmarkGUI.launch() 38 | # import tkinter as tk 39 | # root = tk.Tk() 40 | # app = HighwayEnvBenchmarkGUI(root) 41 | # root.mainloop() 42 | 43 | if __name__ == '__main__': 44 | test_benchmark_gui() 45 | 46 | 47 | -------------------------------------------------------------------------------- /DEBUGtest/test_import_spider.py: -------------------------------------------------------------------------------- 1 | import spider 2 | 3 | spider.teaser() 4 | -------------------------------------------------------------------------------- /DEBUGtest/test_log_buffer.py: -------------------------------------------------------------------------------- 1 | import spider 2 | from spider.interface.BaseBenchmark import DummyBenchmark 3 | from spider.planner_zoo import * 4 | from spider.data.DataBuffer import LogBuffer 5 | from spider.data.decorators import logbuffer_plan 6 | 7 | 8 | def test1(): 9 | # 用log_buffer.apply_to(planner)启用 10 | benchmark = DummyBenchmark({ 11 | "snapshot": False, 12 | "map_frequency": 1, # 记录地图数据 13 | # "racetrack": "straight", 14 | }) 15 | 16 | planner = LatticePlanner({ 17 | "steps": 20, 18 | "dt": 0.2, 19 | "print_info": False 20 | }) 21 | 22 | log_buffer = LogBuffer( 23 | autosave_max_intervals=100, 24 | file_format=spider.DATA_FORMAT_RAW, 25 | # file_format=spider.DATA_FORMAT_JSON, 26 | data_root='./dataset_map/' 27 | ) 28 | 29 | log_buffer.apply_to(planner) 30 | 
31 | for episode in range(10): 32 | benchmark.test(planner) 33 | 34 | log_buffer.release() 35 | 36 | 37 | def test2(): 38 | """ 39 | 用logbuffer_plan装饰器启用 40 | """ 41 | class ClosedLoopPlanner(LatticePlanner): 42 | @logbuffer_plan 43 | def plan(self, *args, **kwargs): 44 | return super(ClosedLoopPlanner, self).plan(*args, **kwargs) 45 | 46 | benchmark = DummyBenchmark({ 47 | "snapshot": False 48 | }) 49 | 50 | planner = ClosedLoopPlanner({ 51 | "steps": 20, 52 | "dt": 0.2, 53 | "print_info": False 54 | }) 55 | 56 | log_buffer = LogBuffer( 57 | autosave_max_intervals=50, 58 | file_format=spider.DATA_FORMAT_JSON 59 | ) 60 | 61 | planner.set_log_buffer(log_buffer) 62 | 63 | for episode in range(3): 64 | benchmark.test(planner) 65 | 66 | log_buffer.release() 67 | 68 | 69 | if __name__ == '__main__': 70 | test1() 71 | -------------------------------------------------------------------------------- /DEBUGtest/test_mlp_planner.py: -------------------------------------------------------------------------------- 1 | 2 | from spider.data.Dataset import OfflineLogDataset 3 | from spider.planner_zoo.MlpPlanner import MlpPlanner 4 | 5 | # import cProfile 6 | # cProfile.run('planner.policy.learn_dataset(100, train_loader=train_loader)') 7 | 8 | # test the planner 9 | train = 0 10 | test_mode_closed_loop = 1 11 | 12 | # setup the planner 13 | planner = MlpPlanner({ 14 | "steps": 20, 15 | "dt": 0.2, 16 | "num_object": 5, 17 | "normalize": False, 18 | "relative": False, 19 | "longitudinal_range": (-50, 100), 20 | "lateral_range": (-20,20), 21 | 22 | "learning_rate": 0.0001, 23 | "enable_tensorboard": True, 24 | "tensorboard_root": './tensorboard/' 25 | }) 26 | 27 | # setup the dataset 28 | dataset = OfflineLogDataset('./dataset_map/', planner.state_encoder, planner.action_encoder) 29 | train_loader = dataset.get_dataloader(batch_size=64, shuffle=True) #DataLoader(dataset, batch_size=64, shuffle=True) 30 | 31 | # train_mode the planner 32 | if train: 33 | 
planner.policy.learn_dataset(100, train_loader=train_loader) 34 | 35 | # save the model 36 | planner.save_state_dict('mlp.pth') 37 | 38 | # load the model 39 | planner.load_state_dict('mlp.pth') 40 | # planner.load_state_dict('mlp_good.pth') 41 | # planner.load_state_dict('mlp_best.pth') 42 | 43 | 44 | if test_mode_closed_loop: 45 | from spider.interface.BaseBenchmark import DummyBenchmark 46 | benchmark = DummyBenchmark({ 47 | "save_video": True, 48 | }) 49 | benchmark.test(planner) 50 | else: 51 | dataset.replay(planner, 0, recording=True) 52 | 53 | 54 | -------------------------------------------------------------------------------- /DEBUGtest/test_prob_planner.py: -------------------------------------------------------------------------------- 1 | 2 | from spider.data.Dataset import OfflineLogDataset 3 | from spider.planner_zoo.ProbabilisticPlanner import ProbabilisticPlanner 4 | 5 | train = 0 6 | test_mode_closed_loop = 0 7 | 8 | # setup the planner 9 | planner = ProbabilisticPlanner({ 10 | "steps": 20, 11 | "dt": 0.2, 12 | "num_object": 5, 13 | 14 | "learning_rate": 0.0001, 15 | "enable_tensorboard": True, 16 | "tensorboard_root": './tensorboard/' 17 | }) 18 | 19 | # setup the dataset 20 | dataset = OfflineLogDataset('./dataset_map/', planner.state_encoder, planner.action_encoder, use_cache=True) 21 | train_loader = dataset.get_dataloader(batch_size=32, shuffle=True) #DataLoader(dataset, batch_size=64, shuffle=True) 22 | 23 | if train: 24 | # train_mode the planner 25 | planner.policy.learn_dataset(50, train_loader=train_loader) 26 | 27 | # save the model 28 | planner.save_state_dict('prob.pth') 29 | 30 | # load the model 31 | planner.load_state_dict('prob.pth') 32 | 33 | # test the planner 34 | 35 | if test_mode_closed_loop: 36 | from spider.interface.BaseBenchmark import DummyBenchmark 37 | benchmark = DummyBenchmark({ 38 | "map_frequency": 1, 39 | "save_video": True, 40 | }) 41 | benchmark.test(planner) 42 | else: 43 | dataset.replay(planner, 0, 
recording=True) 44 | -------------------------------------------------------------------------------- /DEBUGtest/test_relative_transform.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | from spider.utils.transform.relative import RelativeTransformer 5 | import spider.visualize as vis 6 | 7 | tf = RelativeTransformer() 8 | tf.abs2rel(2, 2, 3.14 / 3, 1, 1, (1, 1, 3.14 / 6), (1, 0)) 9 | print(tf) 10 | 11 | 12 | n = 20 13 | ego_pose = (1, 1, 3.14 / 6) 14 | xs = np.random.rand(n) * 10 15 | ys = np.random.rand(n) * 10 16 | yaws = np.random.rand(n) * 3.14 17 | # colors = np.random.rand(n,3) 18 | xs_rel, ys_rel, yaws_rel,_,_ = tf.abs2rel(xs, ys, yaws, ego_pose=ego_pose) 19 | vis.figure(figsize=(13,6)) 20 | vis.subplot(1,2,1) 21 | vis.draw_obb([ego_pose[0], ego_pose[1], 5,2, ego_pose[2]]) 22 | vis.plot(xs, ys, '.b') 23 | vis.subplot(1,2,2) 24 | vis.plot(xs_rel, ys_rel, '.b') 25 | 26 | vis.show() 27 | -------------------------------------------------------------------------------- /DEBUGtest/test_twopt_boundary.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | def test1(xs, vxs, axs, xe, vxe, axe, te): 4 | a0 = xs 5 | a1 = vxs 6 | a2 = axs / 2.0 7 | 8 | A = np.array([[te ** 3, te ** 4, te ** 5], 9 | [3 * te ** 2, 4 * te ** 3, 5 * te ** 4], 10 | [6 * te, 12 * te ** 2, 20 * te ** 3]]) 11 | b = np.array([xe - a0 - a1 * te - a2 * te ** 2, 12 | vxe - a1 - 2 * a2 * te, 13 | axe - 2 * a2]) 14 | x = np.linalg.solve(A, b) 15 | 16 | a3 = x[0] 17 | a4 = x[1] 18 | a5 = x[2] 19 | t_range = te 20 | 21 | def test2(xs, vxs, axs, xe, vxe, axe,ts, te): 22 | 23 | 24 | A = np.array([ 25 | [ts**5, ts**4,ts**3,ts**2,ts, 1], 26 | [5 * ts ** 4, 4 * ts ** 3, 3 * ts ** 2, 2*ts, 1, 0], 27 | [20 * ts ** 3, 12 * ts ** 2, 6 * ts, 2,0,0], 28 | [te ** 5, te ** 4, te ** 3, te ** 2, te, 1], 29 | [5 * te ** 4, 4 * te ** 3, 3 * te ** 2, 2 * te, 1, 0], 30 | [20 * te ** 3, 12 * 
te ** 2, 6 * te, 2,0,0] 31 | ]) 32 | 33 | b = np.array([ 34 | xs, 35 | vxs, 36 | axs, 37 | xe, 38 | vxe, 39 | axe 40 | ]) 41 | x = np.linalg.solve(A, b) 42 | coef = x 43 | t_range = te 44 | 45 | 46 | if __name__ == '__main__': 47 | from time import time 48 | 49 | xs, vxs, axs, xe, vxe, axe, ts, te = 3.5, 0, 0, 0, 0, 0, 0, 2 50 | 51 | t = 0. 52 | for _ in range(10000): 53 | t1 = time() 54 | test1(xs, vxs, axs, xe, vxe, axe, te) 55 | t += time() - t1 56 | print(t) 57 | 58 | t = 0. 59 | for _ in range(10000): 60 | t1 = time() 61 | test2(xs, vxs, axs, xe, vxe, axe,ts, te) 62 | t += time() - t1 63 | print(t) 64 | -------------------------------------------------------------------------------- /DEBUGtest/test_visualize/test_elements.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | import spider.visualize as vis 4 | import spider.elements as elm 5 | from spider.interface.BaseBenchmark import DummyBenchmark 6 | 7 | ego, obs, lmap = DummyBenchmark.get_environment_presets() 8 | xx = np.linspace(0,40,20) 9 | dummy_traj = elm.Trajectory.from_trajectory_array( 10 | np.array([xx + ego.x(), np.sin(xx/5) + ego.y()]).T, dt=0.2, calc_derivative=True) 11 | 12 | 13 | vis.figure() 14 | vis.lazy_draw(ego, obs, lmap, dummy_traj) 15 | vis.show() 16 | 17 | -------------------------------------------------------------------------------- /DEBUGtest/test_visualize/test_line.py: -------------------------------------------------------------------------------- 1 | from spider.visualize import * 2 | from spider.elements import Trajectory 3 | 4 | from copy import deepcopy 5 | plt.figure(figsize=(15,3)) 6 | 7 | steps, dt = 30, 0.2 8 | traj = Trajectory(steps, dt) 9 | s0, s_d0, s_dd0 = 0., 60 / 3.6, 0. 10 | l0, l_d0, l_dd0 = 0., 0., 0. 
11 | traj.x = xs = np.array([s0 + s_d0 * i * dt for i in range(steps)]) 12 | traj.y = ys = np.array([l0 + 0.12 * i * i * dt * dt for i in range(steps)]) 13 | traj.heading = np.insert(np.arctan2(np.diff(traj.y), np.diff(traj.x)), 0, 0.0) 14 | draw_trajectory(traj, '.-', show_footprint=True, footprint_fill=False) 15 | 16 | traj2 = deepcopy(traj) 17 | traj2.y *= -1 18 | traj2.heading *= -1 19 | draw_trajectory(traj2, '.-', show_footprint=True, footprint_fill=True) 20 | 21 | traj3 = deepcopy(traj) 22 | traj3.y *= 0 23 | traj3.heading *= 0 24 | # draw_trajectory(traj3, '.-', show_footprint=False) 25 | draw_polyline(np.column_stack((traj3.x, traj3.y)),show_buffer=True) 26 | 27 | plt.show() 28 | 29 | -------------------------------------------------------------------------------- /DEBUGtest/test_visualize/test_surface.py: -------------------------------------------------------------------------------- 1 | from spider.visualize.surface import * 2 | 3 | from spider.elements import TrackingBox, TrackingBoxList 4 | tb = TrackingBox(obb=[1,6,5,2,3.14/6]) 5 | tb_list = TrackingBoxList([tb]) 6 | draw_polygon(tb.vertices, fill=True,lw=1.5, alpha=0.2)#,color='red', 7 | 8 | draw_circle([2, 5], 1,mark_center=True, fill=True, alpha=0.6) 9 | draw_circle([3, 7], 1, mark_center=True, color='blue') 10 | plt.show() 11 | 12 | -------------------------------------------------------------------------------- /DEBUGtest/test_visualize/test_surface3d.py: -------------------------------------------------------------------------------- 1 | from spider.visualize.surface3d import * 2 | 3 | # 创建一个三维坐标轴 4 | fig = plt.figure() 5 | axes = plt.axes(projection="3d") 6 | top_vertices = np.array([[2, 0], [2, 1], [3, 1],[3, 0]]) 7 | bottom_vertices = np.array([[0, 0], [0, 1], [1, 1], [1, 0]]) 8 | 9 | draw_prism(bottom_vertices, 0, top_vertices, 8) 10 | plt.show() 11 | -------------------------------------------------------------------------------- /__init__.py: 
-------------------------------------------------------------------------------- 1 | import importlib as _importlib 2 | from spider.param import * 3 | from spider._virtual_import import _virtual_import, _try_import, _try_import_from 4 | from spider._misc import * 5 | from spider.teaser import teaser 6 | 7 | 8 | submodules = [ 9 | 'elements', 'utils', 'sampler', 'evaluator', 'interface', 'control', 'planner_zoo', 'RL', 10 | 'data', 11 | #'teaser', 'param' 12 | ] 13 | 14 | __all__ = submodules + ['teaser', 'param'] 15 | 16 | 17 | def __dir__(): 18 | return __all__ 19 | 20 | 21 | def __getattr__(name): 22 | if name in submodules: 23 | return _importlib.import_module(f'spider.{name}') 24 | elif name =="vehicle_model": 25 | return AttributeError("sub module 'spider.vehicle_model' has been moved to 'spider.control.vehicle_model'") 26 | else: 27 | try: 28 | return globals()[name] 29 | except KeyError: 30 | raise AttributeError( 31 | f"Module 'spider' has no attribute '{name}'" 32 | ) 33 | -------------------------------------------------------------------------------- /_misc.py: -------------------------------------------------------------------------------- 1 | import time 2 | 3 | def get_timestamp(): 4 | return time.strftime("%Y%m%d_%H%M%S", time.localtime()) 5 | 6 | def get_class_name(obj): 7 | return obj.__class__.__name__ 8 | 9 | 10 | -------------------------------------------------------------------------------- /_virtual_import.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | def _try_import(package_name): 4 | ''' 5 | qzl: if you use this function, you will no longer receive the coding hint 6 | ''' 7 | try: 8 | return __import__(package_name) 9 | except (ModuleNotFoundError, ImportError) as e: 10 | return _virtual_import(package_name, e) 11 | 12 | 13 | def _try_import_from(module_name, *object_names): 14 | """ 15 | qzl: if you use this function, you will no longer receive the coding hint 16 | Imports multiple objects 
from a module and returns them as a tuple. 17 | 18 | Args: 19 | module_name (str): Name of the module. 20 | *object_names (str): Names of the objects to import from the module. 21 | 22 | Returns: 23 | tuple: Tuple containing the imported objects. 24 | """ 25 | assert len(object_names) > 0, "Must provide at least one object name. " \ 26 | "If you only want to import {}, use _try_import() instead.".format(module_name) 27 | try: 28 | module = __import__(module_name, fromlist=object_names) 29 | imported_objects = tuple(getattr(module, obj_name) for obj_name in object_names) 30 | except (ModuleNotFoundError, ImportError) as e: 31 | imported_objects = [_virtual_import(module_name, e)] * len(object_names) 32 | 33 | if len(imported_objects) == 1: 34 | return imported_objects[0] 35 | else: 36 | return imported_objects 37 | 38 | 39 | 40 | def _virtual_import(package_name, error_message=None): 41 | """ 42 | qzl: 43 | If a package exsits, it will return the package directly. 44 | If a package does not exist, this method is used to import it temporally for no error. 45 | When you actually call some property or method of it, it will raise an error. 
46 | """ 47 | 48 | return _VirtualPackage(package_name, error_message) 49 | 50 | 51 | 52 | 53 | class _VirtualPackage: 54 | def __init__(self, package_name, error_message=None): 55 | self._name = package_name 56 | if error_message is None: 57 | self._error_message = 'No module named %s' % self._name 58 | else: 59 | self._error_message = error_message 60 | 61 | def __getattr__(self, item): 62 | raise ModuleNotFoundError(self._error_message) 63 | 64 | def __call__(self, *args, **kwargs): 65 | raise ModuleNotFoundError(self._error_message) 66 | 67 | 68 | 69 | -------------------------------------------------------------------------------- /constraints/ConstraintCollection.py: -------------------------------------------------------------------------------- 1 | from spider.param import * 2 | from spider.elements.trajectory import Trajectory 3 | import numpy as np 4 | 5 | 6 | 7 | class ConstraintCollection: 8 | ''' 9 | todo: 未来改个形式,变成cvt/transform那种的变换的叠加的形式 10 | ''' 11 | # qzl: 这种形式有个弊端,只能处理上下界约束,如果是过程中可变约束怎么办 12 | traj_constraint_functions = { 13 | CONSTRIANT_SPEED_LB: lambda traj, config: np.all(np.array(traj.v) >= config['min_speed']), 14 | CONSTRIANT_SPEED_UB: lambda traj, config: np.all(np.array(traj.v) <= config['max_speed']), 15 | CONSTRIANT_ACCELERATION: lambda traj, config: np.all(np.array(traj.a) <= config["max_acceleration"]), 16 | CONSTRIANT_DECELERATION: lambda traj, config: np.all(np.array(traj.a) >= -config["max_deceleration"]), 17 | CONSTRIANT_CURVATURE: lambda traj, config: np.all(np.array(traj.curvature) <= config["max_curvature"]), 18 | CONSTRIANT_LATERAL_JERK: lambda traj, config: np.all(np.abs(np.array(traj.l_3dot)) <= config["max_lateral_jerk"]), 19 | CONSTRIANT_LONGITUDINAL_JERK: lambda traj, config: np.all(np.abs(np.array(traj.s_3dot)) <= config["max_longitudinal_jerk"]), 20 | } 21 | 22 | control_constraint_functions = { 23 | } 24 | 25 | 26 | def __init__(self, config: dict): 27 | self.config = config 28 | 29 | assert "constraint_flags" in config 
30 | self.constraint_flags: set = config["constraint_flags"] 31 | 32 | # if "constraint_flags" in config: 33 | # self.constraint_flags:set = config["constraint_flags"] 34 | # else: 35 | # self.constraint_flags:set = { 36 | # CONSTRIANT_SPEED_UB, 37 | # CONSTRIANT_SPEED_LB, 38 | # CONSTRIANT_ACCELERATION, 39 | # CONSTRIANT_DECELERATION, 40 | # CONSTRIANT_CURVATURE 41 | # } # default 42 | 43 | def aggregate(self): 44 | ''' 45 | 把多个判断可行性的函数聚合成一个大函数 46 | todo: 想想看碰撞检测能不能融进来 47 | ''' 48 | 49 | if self.config.get("output", OUTPUT_TRAJECTORY) == OUTPUT_TRAJECTORY: 50 | all_funcs = self.traj_constraint_functions 51 | else: 52 | all_funcs = self.control_constraint_functions 53 | 54 | funcs = [all_funcs[key] for key in self.constraint_flags] 55 | feasibility_function = lambda traj: np.all([func(traj, self.config) for func in funcs]) 56 | return feasibility_function 57 | 58 | 59 | def formulate(self): 60 | ''' 61 | 形成优化算法里面的约束条件 62 | ''' 63 | pass 64 | 65 | 66 | 67 | if __name__ == '__main__': 68 | config = { 69 | "output": OUTPUT_TRAJECTORY, 70 | "steps": 50, 71 | "dt": 0.1, 72 | "ego_veh_length": 5.0, 73 | "ego_veh_width": 2.0, 74 | "max_speed": 60 / 3.6, 75 | "min_speed": 0, 76 | "max_acceleration": 10, 77 | "max_deceleration": 10, 78 | # "max_centripetal_acceleration" : 100, 79 | "max_curvature": 100, 80 | "end_s_candidates": (10, 20, 40, 60), 81 | "end_l_candidates": (-0.8, 0, 0.8), # s,d采样生成横向轨迹 (-3.5, 0, 3.5), # 82 | "end_v_candidates": tuple(i * 60 / 3.6 / 3 for i in range(4)), # 改这一项的时候,要连着限速一起改了 83 | "end_T_candidates": (1, 2, 4, 8) # s_dot, T采样生成纵向轨迹 84 | } 85 | 86 | temp = ConstraintCollection(config) 87 | func = temp.aggregate() 88 | 89 | traj = Trajectory 90 | traj.v = [3,4,5000,6,7,8,6] 91 | 92 | x = func(traj) 93 | print(x) 94 | 95 | 96 | 97 | # 98 | # import timeit 99 | # 100 | # my_dict = {'name': 'John', 'age': 30, 'city': 'New York'} 101 | # my_dict.update({key:0 for key in range(10000)}) 102 | # class my_object: name = "John" 103 | # 104 | # 105 | # # 
使用字符串进行字典键索引 106 | # def dict_index(): 107 | # return my_dict['name'] 108 | # 109 | # 110 | # # 使用对象属性访问 111 | # def object_attr(): 112 | # return my_object.name 113 | # 114 | # 115 | # # 测试性能 116 | # dict_time = timeit.timeit(dict_index, number=1000000) 117 | # object_time = timeit.timeit(object_attr, number=1000000) 118 | # 119 | # print(f"字典键索引耗时: {dict_time} 秒") 120 | # print(f"对象属性访问耗时: {object_time} 秒") 121 | -------------------------------------------------------------------------------- /constraints/__init__.py: -------------------------------------------------------------------------------- 1 | ''' 2 | qzl: 包含两部分内容: 3 | 4 | 第一是ConstraintChecker,主要针对采样算法,即给定一条轨迹,需要给出是否可行 5 | 第二是ConstriantFormulator ,主要针对优化算法,即列出约束条件. 6 | 另外,可以考虑是不是建立一个ConstraintCollection?保存各类约束条件?可以后面再考虑 7 | 8 | 9 | ''' 10 | from spider.constraints.constraint_checker import * 11 | from spider.constraints.constraint_formulator import * 12 | from spider.constraints.ConstraintCollection import ConstraintCollection 13 | 14 | -------------------------------------------------------------------------------- /constraints/constraint_checker/BaseConstraintChecker.py: -------------------------------------------------------------------------------- 1 | from abc import abstractmethod 2 | 3 | class BaseConstraintChecker: 4 | def __init__(self, config: dict): 5 | self.config = config 6 | 7 | @abstractmethod 8 | def check(self, *args) -> bool: 9 | pass 10 | 11 | -------------------------------------------------------------------------------- /constraints/constraint_checker/CartConstraintChecker.py: -------------------------------------------------------------------------------- 1 | ''' 2 | 笛卡尔坐标下的约束检查,利用的是笛卡尔坐标下的参数 3 | ''' 4 | 5 | from spider.param import * 6 | from spider.constraints.constraint_checker import BaseConstraintChecker 7 | from spider.constraints.ConstraintCollection import ConstraintCollection 8 | 9 | from spider.elements.trajectory import Trajectory 10 | 11 | 12 | class 
CartConstriantChecker(BaseConstraintChecker): 13 | def __init__(self, config:dict, collision_checker=None): 14 | super(CartConstriantChecker, self).__init__(config) 15 | # todo:config的存在很混乱,以后要修改constraintchecker的输入定义 16 | 17 | if not ("constraint_flags" in config): 18 | self.config.update({ 19 | "constraint_flags": set() 20 | }) 21 | # self.config.update({ 22 | # "constraint_flags":{ 23 | # CONSTRIANT_SPEED_UB, 24 | # CONSTRIANT_SPEED_LB, 25 | # CONSTRIANT_ACCELERATION, 26 | # CONSTRIANT_DECELERATION, 27 | # CONSTRIANT_CURVATURE} 28 | # }) 29 | 30 | self.kinematics_feasibility_check = ConstraintCollection(self.config).aggregate() # 这是一个函数 31 | self.collision_checker = collision_checker 32 | 33 | def check_kinematics(self, trajectory:Trajectory): 34 | return self.kinematics_feasibility_check(trajectory) 35 | 36 | def check_collision(self, trajectory:Trajectory, predicted_perception): 37 | if self.collision_checker is None: 38 | collision = False 39 | else: 40 | collision: bool = self.collision_checker.check_trajectory(trajectory, predicted_perception) 41 | return collision 42 | 43 | 44 | def check(self, trajectory:Trajectory, predicted_perception) -> bool: 45 | feasible: bool = self.kinematics_feasibility_check(trajectory) 46 | collision: bool = self.check_collision(trajectory, predicted_perception) 47 | 48 | return feasible and (not collision) 49 | 50 | 51 | 52 | -------------------------------------------------------------------------------- /constraints/constraint_checker/ControlConstraintChecker.py: -------------------------------------------------------------------------------- 1 | ''' 2 | 专门针对控制量的约束检查,别的都是针对轨迹的 3 | 在hybrid A*这类图搜索的算法里面用得上 4 | ''' 5 | 6 | from BaseConstraintChecker import BaseConstraintChecker 7 | 8 | 9 | class ControlConstraintChecker(BaseConstraintChecker): 10 | def __init__(self, config): 11 | super(ControlConstraintChecker, self).__init__(config) 12 | pass 13 | 14 | 15 | 
# ---- /constraints/constraint_checker/FrenetConstraintChecker.py ---- (content not inlined)
# https://raw.githubusercontent.com/Thu-ADLab/SPIDER/a1a2f81a29da0c63f1a46ad04740d6031c78e885/constraints/constraint_checker/FrenetConstraintChecker.py

# ---- /constraints/constraint_checker/__init__.py ----
from spider.constraints.constraint_checker.BaseConstraintChecker import BaseConstraintChecker
from spider.constraints.constraint_checker.CartConstraintChecker import CartConstriantChecker


# ---- /constraints/constraint_formulator/__init__.py ---- (content not inlined)
# https://raw.githubusercontent.com/Thu-ADLab/SPIDER/a1a2f81a29da0c63f1a46ad04740d6031c78e885/constraints/constraint_formulator/__init__.py

# ---- /control/IDMController.py ----
import warnings

import numpy as np

from typing import TypeVar
from spider.control.lateral import PurePursuitController
from spider.control.longitudinal import IDMLonController

Trajectory = TypeVar("Trajectory")  # placeholder type, avoids a circular import
VehicleState = TypeVar("VehicleState")


class IDMController(object):
    """IDM car-following (longitudinal) + pure-pursuit path tracking (lateral)."""

    def __init__(self):
        self._max_brake = 2.0
        self._max_throt = 1.0
        self._max_steer = 0.8

        # The max throttle is passed as the IDM maximum acceleration ``a``.
        self._lon_controller = IDMLonController(self._max_throt)
        self._lat_controller = PurePursuitController()

    def get_control(self, reference_line: np.ndarray, front_veh_speed, front_veh_dist, desired_speed,
                    current_pose, current_speed):
        """Compute one control command.

        :param reference_line: polyline to track, convertible to an array of [x, y] points
        :param front_veh_speed: speed of the vehicle in front
        :param front_veh_dist: gap to the vehicle in front
        :param desired_speed: free-road target speed
        :param current_pose: [x, y, yaw]
        :param current_speed: current ego speed
        :return: (acc, steering) as Python floats, clipped to
                 [-max_brake, max_throt] and [-max_steer, max_steer]
        """
        acc = self._lon_controller.run_step(current_speed, desired_speed, front_veh_speed, front_veh_dist)
        steering = self._lat_controller.run_step(np.asarray(reference_line), current_pose, current_speed)

        acc = np.clip(acc, -self._max_brake, self._max_throt).item()
        steering = np.clip(steering, -self._max_steer, self._max_steer).item()
        return acc, steering

    def get_fallback_control(self, brake=None):
        '''
        Returns a control that will bring the vehicle to a halt in case of
        failure of control module or planning module.
        :param brake: amount of brake to apply (defaults to the maximum brake)
        :return: (acc, steering) with zero steering
        '''
        acc = -self._max_brake if brake is None else -brake
        steering = 0.0

        return acc, steering


# ---- /control/SimpleController.py ----
import warnings

import numpy as np

from typing import TypeVar
from spider.control.lateral import PurePursuitController
from spider.control.longitudinal import PIDLonController

Trajectory = TypeVar("Trajectory")  # placeholder type, avoids a circular import
VehicleState = TypeVar("VehicleState")


class SimpleController(object):
    """PID speed tracking (longitudinal) + pure-pursuit path tracking (lateral)."""

    def __init__(self):
        self._lon_controller = PIDLonController()
        self._lat_controller = PurePursuitController()

        self._max_brake = 0.5
        self._max_throt = 0.75
        self._max_steer = 0.8

    def get_control(self, trajectory: Trajectory, ego_veh_state: VehicleState = None):
        """Track ``trajectory``, aiming for the speed at its last point.

        If ``ego_veh_state`` is omitted, the trajectory's first point is used as
        the current state (and a warning is emitted).

        :return: (acc, steering) as Python floats
        """
        # Target: the speed at the end of the trajectory.
        target_speed = trajectory.v[-1]
        traj_arr = trajectory.trajectory_array

        # Current state.
        if ego_veh_state is None:
            warnings.warn("Ego_state is not provided. Using the trajectory starting point instead.")
            current_speed = trajectory.v[0]
            ego_pose = np.array([trajectory.x[0], trajectory.y[0], trajectory.heading[0]])
        else:
            current_speed = ego_veh_state.v()
            ego_pose = np.array([ego_veh_state.x(), ego_veh_state.y(), ego_veh_state.yaw()])

        return self.get_control_(traj_arr, target_speed, ego_pose, current_speed)

    def get_control_(self, trajectory_array: np.ndarray, target_speed, current_pose, current_speed):
        '''
        :param trajectory_array: [[x1,y1],[x2,y2]]
        :param target_speed: speed to track
        :param current_pose: [x,y,yaw]
        :param current_speed: current ego speed
        :return: (acc, steering) as Python floats, clipped to the controller limits
        '''
        acc = self._lon_controller.run_step(target_speed, current_speed)
        steering = self._lat_controller.run_step(np.asarray(trajectory_array), current_pose, current_speed)

        acc = np.clip(acc, -self._max_brake, self._max_throt).item()
        steering = np.clip(steering, -self._max_steer, self._max_steer).item()
        return acc, steering

    def get_fallback_control(self, brake=None):
        '''
        Returns a control that will bring the vehicle to a halt in case of
        failure of control module or planning module.
        :param brake: amount of brake to apply (defaults to the maximum brake)
        :return: (acc, steering) with zero steering
        '''
        acc = -self._max_brake if brake is None else -brake
        steering = 0.0

        return acc, steering


if __name__ == '__main__':
    ctl = SimpleController()
    a, st = ctl.get_control_(np.array([[0, 0], [1, 1]]), 10, [0, 0, 0], 10)
    print(a, st)


# ---- /control/__init__.py ----
from spider.control.SimpleController import SimpleController
from spider.control.vehicle_model import *


# ---- /control/lateral/PurePursuitController.py ----
import numpy as np
import math


class PurePursuitController(object):
    """Geometric pure-pursuit lateral controller.

    :param lf: distance from the pose reference point to the front axle [m]
    :param lr: distance from the pose reference point to the rear axle [m]
    """

    # Distances below this are treated as "target reached": the bearing is undefined there.
    _EPS = 1e-9

    def __init__(self, lf=1.2, lr=1.95):
        self.lf = lf
        self.lr = lr

    def run_step(self, trajectory_array, ego_pose, current_speed):
        """
        Execute one step of lateral control to steer the vehicle towards a look-ahead waypoint.

        :param trajectory_array: polyline of [x, y] points
        :param ego_pose: [x, y, yaw]
        :param current_speed: current speed, used to size the look-ahead distance
        :return: steering angle in radians (0.0 when no usable target exists)
        """
        ego_loc = ego_pose[:2]
        control_point = self._control_point(trajectory_array, ego_loc, current_speed)

        if len(control_point) < 2:  # defensive: no valid target point
            return 0.0
        return self._purepersuit_control(control_point, ego_pose)

    def _control_point(self, trajectory_arr, ego_loc, current_speed, resolution=0.1):
        """Pick the look-ahead target on the trajectory resampled at ``resolution``.

        Look-ahead time is 0.5 s, shortened by 0.01 s per unit of speed above 10
        (speed presumably in m/s — confirm with callers); the look-ahead distance is
        floored at 3 m.
        """
        from spider.utils.geometry import resample_polyline

        if current_speed > 10:
            control_target_dt = 0.5 - (current_speed - 10) * 0.01
        else:
            control_target_dt = 0.5

        control_target_distance = control_target_dt * current_speed  # [m]
        if control_target_distance < 3:
            control_target_distance = 3

        trajectory_dense = resample_polyline(trajectory_arr, resolution)

        end_idx = self.get_next_idx(ego_loc, trajectory_dense, control_target_distance)
        return trajectory_dense[end_idx]

    def _purepersuit_control(self, waypoint, ego_pose):
        """
        Compute the pure-pursuit steering angle towards a waypoint.

        :param waypoint: target point [x, y]
        :param ego_pose: current pose [x, y, yaw]; the rear axle is assumed to lie ``lr`` behind it
        :return: steering angle in radians, within (-pi/2, pi/2); 0.0 if the target
                 coincides with the ego position or the rear axle
        """
        ego_x, ego_y, ego_yaw = ego_pose
        target_x, target_y = waypoint[0], waypoint[1]

        v_vec = np.array([math.cos(ego_yaw), math.sin(ego_yaw), 0.0])  # unit heading vector
        w_vec = np.array([target_x - ego_x, target_y - ego_y, 0.0])    # line of sight to the target
        w_norm = np.linalg.norm(w_vec)
        if w_norm < self._EPS:
            return 0.0

        # Signed angle between heading and line of sight (positive when the target is to the left).
        alpha = math.acos(np.clip(np.dot(w_vec, v_vec) / w_norm, -1.0, 1.0))
        if np.cross(v_vec, w_vec)[2] < 0:
            alpha *= -1.0

        wheelbase = self.lf + self.lr
        rear_x = ego_x - v_vec[0] * self.lr
        rear_y = ego_y - v_vec[1] * self.lr
        look_ahead = math.hypot(target_x - rear_x, target_y - rear_y)
        if look_ahead < self._EPS:
            return 0.0

        theta = np.arctan(2 * np.sin(alpha) * wheelbase / look_ahead)

        k = 1.0  # steering gain  XXX: np.pi/180*50
        return theta * k

    # def convert_trajectory_to_ndarray(self, trajectory):
    #     trajectory_array = [(pose.pose.position.x, pose.pose.position.y) for pose in trajectory.poses]
    #     return np.array(trajectory_array)

    def get_idx(self, loc, trajectory):
        """Index of the trajectory point closest to ``loc``."""
        dist = np.linalg.norm(trajectory - loc, axis=1)
        return np.argmin(dist)

    def get_next_idx(self, start_loc, trajectory, distance):
        """Index of the first point lying more than ``distance`` of arc length past the
        point closest to ``start_loc``.

        If the remaining trajectory is shorter than ``distance``, the last index is
        returned (previously the function fell off its loop and returned ``None``).
        """
        start_idx = self.get_idx(start_loc, trajectory)
        # arc[i] = arc length from trajectory[0] to trajectory[i]
        seg_len = np.linalg.norm(np.diff(trajectory, axis=0), axis=1)
        arc = np.concatenate(([0.0], np.cumsum(seg_len)))
        ahead = np.nonzero(arc[start_idx:] > arc[start_idx] + distance)[0]
        if ahead.size == 0:
            return len(trajectory) - 1
        return int(start_idx + ahead[0])


# ---- /control/lateral/__init__.py ----
from spider.control.lateral.PurePursuitController import PurePursuitController


# ---- /control/longitudinal/IDMLonController.py ----
import numpy as np
# import math


class IDMLonController(object):

    # Smallest gap used in the interaction term; avoids division by zero at contact.
    _MIN_GAP = 1e-3

    def __init__(self, a=1.0, b=2, delta=2, s0=8.0, T=2.5):
        """
        Initialize the IDM controller with the given parameters.

        a: maximum acceleration of the vehicle
        b: comfortable deceleration of the vehicle
        delta: acceleration exponent
        s0: minimum distance to the front vehicle
        T: safe time headway to the front vehicle
        """
        self.a = a
        self.b = b
        self.delta = delta
        self.s0 = s0
        self.T = T

    def run_step(self, v, v_desired, v_front, s):
        """
        Execute one step of longitudinal control.

        v: current speed of the vehicle
        v_desired: desired speed of the vehicle (must be non-zero)
        v_front: speed of the vehicle in front
        s: distance to the vehicle in front; values <= 0 are clamped to a tiny
           positive gap, which yields a very strong braking command

        return: IDM acceleration command, NOT clipped (callers clip it)
        """
        return self._calc_acceleration(v, v_desired, v_front, s)  # np.clip(acc, -self.a, self.a)

    def _calc_acceleration(self, v, v_desired, v_front, s):
        delta_v = v - v_front  # approach rate: positive when closing in on the leader
        s_star = self.s0 + max(0, v * self.T + v * delta_v / (2 * np.sqrt(self.a * self.b)))
        s = max(s, self._MIN_GAP)
        return self.a * (1 - (v / v_desired) ** self.delta - (s_star / s) ** 2)


# ---- /control/longitudinal/PIDLonController.py ----
from collections import deque
import numpy as np


class PIDLonController(object):

    def __init__(self):
        """
        K_P: Proportional term
        K_D: Differential term
        K_I: Integral term
        dt: time differential in seconds
        Gains act on the speed error expressed in km/h.
        """
        self._K_P = 0.25 / 3.6
        self._K_D = 0  # .01
        self._K_I = 0  # 0.012  FIXME: stop accumulating error when control is repeatedly requested in the same state.
        self._dt = 0.1  # TODO: timestep
        self._integ = 0.0
        self._e_buffer = deque(maxlen=30)

    def run_step(self, target_speed, current_speed, debug=False):
        """
        Execute one step of longitudinal control to reach a given target speed.

        target_speed: target speed in m/s (converted to km/h internally)
        current_speed: current speed in m/s
        debug: unused
        return: longitudinal command in the range [-1, 1]
        """
        return self._pid_control(target_speed, current_speed)

    def _pid_control(self, target_speed, current_speed):
        """
        Estimate the longitudinal command based on the PID equations.

        :param target_speed: target speed in m/s
        :param current_speed: current speed of the vehicle in m/s
        :return: command in the range [-1, 1]; exactly -1 when the target speed is 0
        """
        if target_speed == 0:
            return -1

        # Work in km/h from here on (gains and thresholds below are in km/h).
        target_speed = target_speed * 3.6
        current_speed = current_speed * 3.6

        _e = (target_speed - current_speed)
        self._integ += _e * self._dt
        self._e_buffer.append(_e)

        if current_speed < 2:
            self._integ = 0  # reset the integral near standstill

        if len(self._e_buffer) >= 2:
            _de = (self._e_buffer[-1] - self._e_buffer[-2]) / self._dt
            _ie = self._integ
        else:
            _de = 0.0
            _ie = 0.0

        kp = self._K_P
        ki = self._K_I
        kd = self._K_D

        if target_speed < 5:
            ki = 0
            kd = 0

        return np.clip((kp * _e) + (kd * _de) + (ki * _ie), -1.0, 1.0)


# ---- /control/longitudinal/__init__.py ----
from spider.control.longitudinal.PIDLonController import PIDLonController
from spider.control.longitudinal.IDMLonController import IDMLonController


# ---- /control/vehicle_model/__init__.py ----
from spider.control.vehicle_model.bicycle import *


# ---- /control/vehicle_model/bicycle.png ---- (binary asset)
# https://raw.githubusercontent.com/Thu-ADLab/SPIDER/a1a2f81a29da0c63f1a46ad04740d6031c78e885/control/vehicle_model/bicycle.png

# ---- /control/vehicle_model/bicycle.py ----
import numpy as np
import math
from typing import List


# class BicycleState(np.ndarray):
#     def __init__(self, shape, dtype=None, buffer=None, offset=0, strides=None, order=None):
#         super(BicycleState, self).__init__(shape, dtype, buffer, offset, strides, order)

def curvature2steer(curvature, wheelbase=3.0):
    """Steering angle [rad] producing ``curvature`` for a bicycle with ``wheelbase``."""
    return np.arctan(wheelbase * curvature)


def steer2curvature(steer, wheelbase=3.0):
    """Path curvature produced by steering angle ``steer`` [rad]."""
    return np.tan(steer) / wheelbase


class Bicycle:
    """Kinematic bicycle model: heading rate = v * tan(steer) / wheelbase."""

    def __init__(self,
                 x=0., y=0., v=0., a=0., heading=0., steer=0., steer_velocity=0., *,
                 dt=0.1, wheelbase=3.,):
        self.dt = dt
        self.wheelbase = wheelbase
        self.setKinematics(x, y, v, a, heading, steer, steer_velocity)

    def setKinematics(self, x, y, v, a, heading, steer, steer_velocity=0.):
        """Overwrite the full state; the curvature is derived from ``steer``."""
        self.x = x
        self.y = y
        self.heading = heading
        self.steer = steer
        self.velocity = v
        self.acceleration = a
        self.steer_velocity = steer_velocity
        self.curvature = steer2curvature(steer, wheelbase=self.wheelbase)

    def step(self, a, steer=None, steer_velocity=None, dt=0.):
        """Advance the model by one time step.

        :param a: longitudinal acceleration applied during the step
        :param steer: absolute steering angle for the step (takes precedence)
        :param steer_velocity: steering rate, used only when ``steer`` is None;
                               the steering angle is then ``self.steer + steer_velocity * dt``
        :param dt: step length; 0 means "use self.dt"
        :raises ValueError: if neither ``steer`` nor ``steer_velocity`` is given
        """
        if dt == 0.:
            dt = self.dt

        if steer is None:
            if steer_velocity is None:
                raise ValueError("Invalid input")
            steer = self.steer + steer_velocity * dt  # integrate the steering rate

        v = self.velocity + a * dt
        heading = self.heading + dt * self.velocity * math.tan(steer) / self.wheelbase
        # Position is advanced with the updated heading and the previous speed.
        x = self.x + self.velocity * math.cos(heading) * dt
        y = self.y + self.velocity * math.sin(heading) * dt
        steer_velocity = (steer - self.steer) / dt
        self.setKinematics(x, y, v, a, heading, steer, steer_velocity)

    def derivative(self, x, y, dt=0.):
        """Infer the state from the next position (x, y) reached after ``dt`` and store it.

        Speed is the travelled distance over ``dt`` (always >= 0). Heading is the
        direction of travel; it is kept unchanged when the position does not move.
        The heading change is wrapped to [-pi, pi) before the steering angle is derived.
        """
        if dt == 0.:
            dt = self.dt
        dx = x - self.x
        dy = y - self.y
        if dx == 0 and dy == 0:
            heading = self.heading           # no displacement: direction undefined, keep the old one
        else:
            heading = math.atan2(dy, dx)     # heading at step t+1
        v = math.hypot(dx, dy) / dt          # speed at step t+1
        dv = v - self.velocity
        a = dv / dt
        dheading = (heading - self.heading + math.pi) % (2 * math.pi) - math.pi
        steer = math.atan(dheading * self.wheelbase / dt / self.velocity) if self.velocity != 0 else 0
        steer_velocity = (steer - self.steer) / dt
        self.setKinematics(x, y, v, a, heading, steer, steer_velocity)

    def accsteer2state(self, acc: np.ndarray, steer: np.ndarray):
        """Roll out N control steps in closed form from the current state.

        :param acc: accelerations for steps 0..N-1, shape (N,)
        :param steer: steering angles for steps 0..N-1, shape (N,)
        :return: X, Y, V, heading — each of length N+1 (index 0 is the current state)
        """
        assert acc.shape[0] == steer.shape[0]
        N = acc.shape[0]
        # Row k sums the first k increments, so row 0 is zero (the current state).
        accumulation = np.zeros((N + 1, N))
        accumulation[1:, :] = np.tril(np.ones((N, N)))
        V = accumulation @ acc * self.dt + self.velocity  # 0..N
        heading = accumulation @ (V[:-1] * np.tan(steer) * self.dt / self.wheelbase) + self.heading  # 0..N
        Y = accumulation @ (V[:-1] * np.sin(heading[:-1]) * self.dt) + self.y  # 0..N
        X = accumulation @ (V[:-1] * np.cos(heading[:-1]) * self.dt) + self.x  # 0..N
        return X, Y, V, heading


# ---- /control/vehicle_model/steer_curvature.jpg ---- (binary asset)
# https://raw.githubusercontent.com/Thu-ADLab/SPIDER/a1a2f81a29da0c63f1a46ad04740d6031c78e885/control/vehicle_model/steer_curvature.jpg

# ---- /data/__init__.py ----
from spider.data.DataBuffer import BaseBuffer, LogBuffer, ExperienceBuffer
from spider.data.Dataset import OfflineLogDataset, OfflineExpDataset
# ---- /data/common.py ---- (empty file)

# ---- /data/decorators.py ----
import functools

import spider


def _bind_instance(func, args):
    """Split a wrapped call into (owning instance, arguments after ``self``).

    If ``func`` is a bound method (has ``__self__``), the instance comes from the
    method and ``args`` already exclude ``self``. Otherwise ``func`` was decorated
    in the class body, so ``args[0]`` is ``self``.
    """
    if hasattr(func, "__self__"):
        return func.__self__, args
    return args[0], args[1:]


def logbuffer_plan(plan_func):
    '''
    Wrap a planner's ``plan`` function so that its inputs and outputs are
    recorded into the planner's log buffer.

    Works in two ways:
      * on an already-instantiated planner's bound ``plan`` method, e.g. attached
        from outside via ``logbuffer.apply_to()``;
      * as ``@logbuffer_plan`` on ``plan`` in the class body. In that case,
        set ``buffer.STORE_FORWARD_ONLY`` to True.

    Recording happens only when the planner's ``_activate_log_buffer`` flag is
    truthy (a missing flag counts as False). The shared timestamp
    ``logbuffer_plan.t`` advances by ``planner.dt`` on every call regardless.
    '''
    if not hasattr(logbuffer_plan, "t"):
        logbuffer_plan.t = 0.0

    @functools.wraps(plan_func)
    def wrapper(*args, **kwargs):
        plan = plan_func(*args, **kwargs)

        planner_instance, call_args = _bind_instance(plan_func, args)
        if getattr(planner_instance, "_activate_log_buffer", False):
            timestamp = logbuffer_plan.t
            # TODO: observation items passed as kwargs are not captured here.
            observation = call_args[:3]
            planner_instance.log_buffer.record_forward(timestamp, observation, plan)

        logbuffer_plan.t += planner_instance.dt  # TODO: this timestamp bookkeeping is known to be flawed
        return plan

    return wrapper


def expbuffer_policy(forward_func):
    '''
    Wrap a policy network's ``forward`` so that (state, action, extra data) is
    recorded into the policy's experience buffer.

    Works both on an already-instantiated policy's bound method
    (``expbuffer.apply_to()``) and as ``@expbuffer_policy`` in the class body.
    Recording happens only when ``_activate_exp_buffer`` is truthy (a missing
    flag counts as False). ``_exp_extra_data`` is consumed and reset to ``[]``
    on every recorded call.
    '''
    @functools.wraps(forward_func)
    def wrapper(*args, **kwargs):
        action = forward_func(*args, **kwargs)

        policy_instance, call_args = _bind_instance(forward_func, args)
        if getattr(policy_instance, "_activate_exp_buffer", False):
            # TODO: assumes the state is the first positional argument; the policy I/O format may need unifying.
            state = call_args[0]
            others = getattr(policy_instance, "_exp_extra_data", [])
            policy_instance._exp_extra_data = []
            policy_instance._exp_buffer.record_forward(state, action, others=others)

        return action

    return wrapper


def expbuffer_reward(reward_func):
    '''
    Wrap a reward function returning ``(reward, done)`` so that the feedback is
    recorded into the owner's experience buffer.

    Works both on a bound method and as ``@expbuffer_reward`` in the class body.
    Recording happens only when ``_activate_exp_buffer`` is truthy (a missing
    flag counts as False).
    '''
    @functools.wraps(reward_func)
    def wrapper(*args, **kwargs):
        reward, done = reward_func(*args, **kwargs)

        reward_instance, _ = _bind_instance(reward_func, args)
        if getattr(reward_instance, "_activate_exp_buffer", False):
            reward_instance._exp_buffer.record_feedback(reward, done)

        return reward, done

    return wrapper


# ---- /display_assests/DQNPlanner0.gif ---- (binary asset)
# https://raw.githubusercontent.com/Thu-ADLab/SPIDER/a1a2f81a29da0c63f1a46ad04740d6031c78e885/display_assests/DQNPlanner0.gif
# ---- /display_assests/GRU_log_replay.gif ---- (binary asset)
# https://raw.githubusercontent.com/Thu-ADLab/SPIDER/a1a2f81a29da0c63f1a46ad04740d6031c78e885/display_assests/GRU_log_replay.gif
# ---- /display_assests/LatticePlanner.gif ---- (binary asset)
# https://raw.githubusercontent.com/Thu-ADLab/SPIDER/a1a2f81a29da0c63f1a46ad04740d6031c78e885/display_assests/LatticePlanner.gif
# ---- /display_assests/LatticePlanner_highway.gif ---- (binary asset)
# https://raw.githubusercontent.com/Thu-ADLab/SPIDER/a1a2f81a29da0c63f1a46ad04740d6031c78e885/display_assests/LatticePlanner_highway.gif
# ---- /display_assests/OptimizedLatticePlanner.gif ---- (binary asset)
# https://raw.githubusercontent.com/Thu-ADLab/SPIDER/a1a2f81a29da0c63f1a46ad04740d6031c78e885/display_assests/OptimizedLatticePlanner.gif
# ---- /display_assests/carla_test.png ---- (binary asset)
# https://raw.githubusercontent.com/Thu-ADLab/SPIDER/a1a2f81a29da0c63f1a46ad04740d6031c78e885/display_assests/carla_test.png
# ---- /display_assests/common_tools.png ---- (binary asset)
# https://raw.githubusercontent.com/Thu-ADLab/SPIDER/a1a2f81a29da0c63f1a46ad04740d6031c78e885/display_assests/common_tools.png
# ---- /display_assests/framework.png ---- (binary asset)
# https://raw.githubusercontent.com/Thu-ADLab/SPIDER/a1a2f81a29da0c63f1a46ad04740d6031c78e885/display_assests/framework.png
# ---- /display_assests/planner_arena.png ---- (binary asset)
# https://raw.githubusercontent.com/Thu-ADLab/SPIDER/a1a2f81a29da0c63f1a46ad04740d6031c78e885/display_assests/planner_arena.png

# ---- /elements/__init__.py ----
from spider.elements.box import TrackingBoxList, TrackingBox, BoundingBox
from spider.elements.map import ScenarioType, TrafficLight, Lane, LocalMap, RoutedLocalMap
from spider.elements.grid import OccupancyGrid2D
from spider.elements.vehicle import VehicleState, Location, Rotation, Transform, Vector3D
from spider.elements.trajectory import Trajectory, FrenetTrajectory, Path

from typing import Tuple, Union

Observation = Tuple[
    VehicleState,
    Union[TrackingBoxList, OccupancyGrid2D],
    Union[RoutedLocalMap, LocalMap]
]

Plan = Union[Trajectory, FrenetTrajectory]  # will add control in the future version


def __getattr__(name):
    """Module-level attribute hook: explains the renamed ``Box`` submodule, otherwise a normal AttributeError."""
    if name == "Box":
        raise ValueError("spider.elements.Box has been renamed as spider.elements.box (small case).")
    raise AttributeError("module 'spider.elements' has no attribute '{}'".format(name))


# ---- /elements/graph.py ----
import numpy as np


class Node:
    """Placeholder graph node (not implemented yet)."""

    def __init__(self):
        pass


class Edge:
    """Placeholder graph edge (not implemented yet)."""

    def __init__(self):
        pass

#
# class Graph:
#     def __init__(self):
#         self.edges = None
#         self.nodes = None
#         self.


# import numpy as np
# from spider.elements.trajectory import Path
# from scipy.spatial import cKDTree
#
#
# class hybridAStarPlanner:
#     def __init__(self, x_reso, y_reso, yaw_reso):
#         self.x_reso = x_reso
#         self.y_reso = y_reso
#         self.yaw_reso = yaw_reso
#
#     # def planWithCostMap(self, start_pose, end_pose, costmap):
#     #     path = Path()
#     #
#     #     return path
#
#
#
# if __name__ == '__main__':
#     XY_GRID_RESO = 0.1
#     YAW_GRID_RESO = np.deg2rad(15.)
#
#     pass


# ---- /evaluator/CostEvaluator.py ----
from spider.elements.trajectory import Trajectory, FrenetTrajectory
import numpy as np
from typing import List


def _rank_by_cost(evaluate, trajectory_list):
    """Evaluate every candidate and return ``(trajectories, costs)`` sorted by ascending cost.

    Equal costs keep the original candidate order. An empty input returns ``([], ())``
    instead of failing while unpacking an empty ``zip``.
    """
    all_cost = [evaluate(t) for t in trajectory_list]
    if not all_cost:
        return [], ()
    sorted_cost, sorted_idx = zip(*sorted(zip(all_cost, range(len(all_cost)))))
    sorted_trajectories = [trajectory_list[i] for i in sorted_idx]
    return sorted_trajectories, sorted_cost


class CartCostEvaluator:
    """Weighted sum of comfort, efficiency and safety costs (lower is better)."""

    def __init__(self, weight_comfort=1.0, weight_efficiency=40., weight_safety=1.0):
        # self.weight_lat_comfort = 5.0
        # self.weight_long_comfort = 1.0

        # Previously these were hard-coded and the arguments were silently ignored.
        self.weight_comfort = weight_comfort
        self.weight_efficiency = weight_efficiency
        self.weight_safety = weight_safety

    def evaluate(self, traj: FrenetTrajectory):
        '''
        Evaluate comfort, traffic efficiency and safety.

        comfort: sum of squared jerk + sum of |centripetal acceleration|
        efficiency: negative sum of speeds
        safety: sum of squared lateral offsets ``traj.l``
        (so the trajectory must expose Frenet ``l`` as well as Cartesian attributes)
        '''
        comfort = np.sum(np.asarray(traj.jerk) ** 2) + \
            np.sum(np.abs(traj.centripetal_acceleration))

        efficiency = -np.sum(traj.v)

        safety = np.sum(np.array(traj.l) ** 2)

        return self.weight_comfort * comfort + \
            self.weight_efficiency * efficiency + \
            self.weight_safety * safety

    def evaluate_candidates(self, trajectory_list: List[FrenetTrajectory]):
        """Return ``(sorted_trajectories, sorted_costs)`` in ascending cost order."""
        return _rank_by_cost(self.evaluate, trajectory_list)


class FrenetCostEvaluator:
    """Weighted sum of Frenet-frame comfort, efficiency and safety costs (lower is better)."""

    def __init__(self):
        self.weight_lat_comfort = 5.0
        self.weight_long_comfort = 1.0

        self.weight_comfort = 1.0
        self.weight_efficiency = 1.0
        self.weight_safety = 1.0

    def evaluate(self, traj: FrenetTrajectory):
        '''
        Evaluate comfort, traffic efficiency and safety.
        TODO: make the weights configurable and add more optional cost terms
              (e.g. centripetal acceleration).
        TODO: the lateral terms l_3dot and l_3prime are hard to reconcile; l_3prime is used for now.
        '''
        comfort = np.sum(
            self.weight_long_comfort * np.array(traj.s_3dot) ** 2 +
            self.weight_lat_comfort * np.array(traj.l_3prime) ** 2
            # self.weight_lat_comfort * np.array(traj.l_3dot) ** 2
        )

        efficiency = -(traj.s[-1] - traj.s[0]) ** 2 - \
            5 * traj.s_dot[-1] ** 2

        safety = np.sum(np.array(traj.l) ** 2)
        if traj.l[-1] * traj.l[0] < 0:
            # Opposite signs: the trajectory crosses from one side of the reference line to the other.
            safety *= 5

        return self.weight_comfort * comfort + \
            self.weight_efficiency * efficiency + \
            self.weight_safety * safety

    def evaluate_candidates(self, trajectory_list: List[FrenetTrajectory]):
        """Return ``(sorted_trajectories, sorted_costs)`` in ascending cost order."""
        return _rank_by_cost(self.evaluate, trajectory_list)


# ---- /evaluator/__init__.py ----
from spider.evaluator.CostEvaluator import FrenetCostEvaluator, CartCostEvaluator


# ---- /interface/__init__.py ----
# TODO: should this package be called "wrapper" or "interface"?
from spider.interface.BaseInterface import BaseInterface, DummyInterface
from spider.interface.BaseBenchmark import BaseBenchmark, DummyBenchmark

import spider
try:  # not particularly elegant; kept this way for now
    from spider.interface.highway_env import HighwayEnvInterface, HighwayEnvBenchmark, HighwayEnvBenchmarkGUI
except (ModuleNotFoundError, ImportError) as e:
    # Optional dependency missing: expose placeholders that report the import error on use.
    HighwayEnvInterface = HighwayEnvBenchmark = HighwayEnvBenchmarkGUI = spider._virtual_import("highway_env", e)

try:
    from spider.interface.carla import CarlaInterface
except (ModuleNotFoundError, ImportError) as e:
    CarlaInterface = spider._virtual_import("carla", e)


# __all__ = [
#     "BaseInterface",
#     "DummyInterface",
#     "BaseBenchmark",
#     "DummyBenchmark",
#     "HighwayEnvInterface",
#     # "HighwayEnvBenchmark",
#     # "HighwayEnvBenchmarkGUI",
#     # "CarlaInterface"
# ]

# def __getattr__(name):
#     if name == "HighwayEnvInterface":
#         from spider.interface.highway_env import HighwayEnvInterface
#         return HighwayEnvInterface
#     elif name == "HighwayEnvBenchmark":
#         return HighwayEnvBenchmark
#     elif name == "HighwayEnvBenchmarkGUI":
#         return HighwayEnvBenchmarkGUI
#     elif name == "CarlaInterface":
#         return CarlaInterface
#     else:
#         raise AttributeError("Module 'spider.interfaces' has no attribute '{}'".format(name))


# ---- /interface/carla/__init__.py ----
# __all__ = ['CarlaInterface', 'presets', 'common', 'visualize']
from spider.interface.carla.common import *
from spider.interface.carla.visualize import *
import spider.interface.carla.presets as presets

from spider.interface.carla.CarlaInterface import CarlaInterface


# ---- /interface/carla/_light_utils.py ----
import carla
from typing import Union
import numpy as np

from spider.interface.carla.presets import CARLIGHT_TO_LABEL


###### car light ########
def set_autolight(traffic_manager, actors):
    """Let the traffic manager drive the lights of every vehicle actor in ``actors``."""
    for actor in actors:
        if actor.type_id.startswith("vehicle"):
            traffic_manager.update_vehicle_lights(actor, True)


def set_car_light(actor, light_state: carla.VehicleLightState):
    """Set ``light_state`` on ``actor`` if it is a vehicle; other actors are ignored."""
    if actor.type_id.startswith("vehicle"):
        actor.set_light_state(light_state)


def get_car_light(actor) -> carla.VehicleLightState:
    """Return the light state of a vehicle actor, or ``VehicleLightState.NONE`` for non-vehicles."""
    if actor.type_id.startswith("vehicle"):
        return actor.get_light_state()
    return carla.VehicleLightState.NONE


def decompose_car_light(car_light: Union[carla.VehicleLightState, int]):
    '''
    Split a light-state bitmask into a multi-hot int array (one slot per entry
    of CARLIGHT_TO_LABEL, in its iteration order) and the list of active labels.

    example:
    >>> veh.get_light_state()
    carla.libcarla.VehicleLightState(11)
    >>> decompose_car_light(veh.get_light_state())
    (array([0, 1, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0]), ['Position', 'LowBeam', 'Brake'])
    >>> decompose_car_light(11)
    (array([0, 1, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0]), ['Position', 'LowBeam', 'Brake'])
    '''
    multi_hot = np.zeros(len(CARLIGHT_TO_LABEL), dtype=int)
    labels = []
    for i, light in enumerate(CARLIGHT_TO_LABEL):
        if light & car_light:
            multi_hot[i] = 1
            labels.append(CARLIGHT_TO_LABEL[light])
    return multi_hot, labels


# ---- /interface/carla/_weather_utils.py ----
import carla
import sys

SUN_PRESETS = {
    'day': (45.0, 0.0),
    'night': (-90.0, 0.0),
    'sunset': (0.5, 0.0)
}

WEATHER_PRESETS = {
    'clear': [10.0, 0.0, 0.0, 5.0, 0.0, 0.0, 0.2, 0.0, 0.0, 0.0, 0.0331],
    'overcast': [80.0, 0.0, 0.0, 50.0, 2.0, 0.75, 0.1, 10.0, 0.0, 0.03, 0.0331],
    'rain': [100.0, 80.0, 90.0, 100.0, 7.0, 0.75, 0.1, 100.0, 0.0, 0.03, 0.0331]}

# Weather attribute that each position of a WEATHER_PRESETS entry is assigned to, in order.
_WEATHER_PRESET_FIELDS = (
    'cloudiness',
    'precipitation',
    'precipitation_deposits',
    'wind_intensity',
    'fog_density',
    'fog_distance',
    'fog_falloff',
    'wetness',
    'scattering_intensity',
    'mie_scattering_scale',
    'rayleigh_scattering_scale',
)

# (command-line argument name, weather attribute) pairs applied by apply_weather_values, in order.
_WEATHER_ARG_FIELDS = (
    ('azimuth', 'sun_azimuth_angle'),
    ('altitude', 'sun_altitude_angle'),
    ('clouds', 'cloudiness'),
    ('rain', 'precipitation'),
    ('puddles', 'precipitation_deposits'),
    ('wind', 'wind_intensity'),
    ('fog', 'fog_density'),
    ('fogdist', 'fog_distance'),
    ('fogfalloff', 'fog_falloff'),
    ('wetness', 'wetness'),
    ('scatteringintensity', 'scattering_intensity'),
    ('miescatteringscale', 'mie_scattering_scale'),
    ('rayleighscatteringscale', 'rayleigh_scattering_scale'),
)

CAR_LIGHTS = {
    'None': [carla.VehicleLightState.NONE],
    'Position': [carla.VehicleLightState.Position],
    'LowBeam': [carla.VehicleLightState.LowBeam],
    'HighBeam': [carla.VehicleLightState.HighBeam],
    'Brake': [carla.VehicleLightState.Brake],
    'RightBlinker': [carla.VehicleLightState.RightBlinker],
    'LeftBlinker': [carla.VehicleLightState.LeftBlinker],
    'Reverse': [carla.VehicleLightState.Reverse],
    'Fog': [carla.VehicleLightState.Fog],
    'Interior': [carla.VehicleLightState.Interior],
    'Special1': [carla.VehicleLightState.Special1],
    'Special2': [carla.VehicleLightState.Special2],
    'All': [carla.VehicleLightState.All]}

LIGHT_GROUP = {
    'None': [carla.LightGroup.NONE],
    # 'Vehicle' : [carla.LightGroup.Vehicle],
    'Street': [carla.LightGroup.Street],
    'Building': [carla.LightGroup.Building],
    'Other': [carla.LightGroup.Other]}


def apply_sun_presets(args, weather):
    """Uses sun presets to set the sun position; exits the process on an unknown preset name."""
    if args.sun is not None:
        if args.sun in SUN_PRESETS:
            weather.sun_altitude_angle = SUN_PRESETS[args.sun][0]
            weather.sun_azimuth_angle = SUN_PRESETS[args.sun][1]
        else:
            print("[ERROR]: Command [--sun | -s] '" + args.sun + "' not known")
            sys.exit(1)


def apply_weather_presets(args, weather):
    """Uses weather presets to set the weather parameters; exits the process on an unknown preset name."""
    if args.weather is not None:
        if args.weather in WEATHER_PRESETS:
            for field, value in zip(_WEATHER_PRESET_FIELDS, WEATHER_PRESETS[args.weather]):
                setattr(weather, field, value)
        else:
            print("[ERROR]: Command [--weather | -w] '" + args.weather + "' not known")
            sys.exit(1)


def apply_weather_values(args, weather):
    """Set weather values individually; only arguments that are not None are applied."""
    for arg_name, field in _WEATHER_ARG_FIELDS:
        value = getattr(args, arg_name)
        if value is not None:
            setattr(weather, field, value)


# ---- /interface/carla/presets.py ---- (continues beyond this chunk)
import carla
from carla import ColorConverter as cc

# Sensors offered by the viewer. Each entry maps a key to
# [blueprint id, carla.ColorConverter for that sensor (None for the lidar), description].
viewer_sensor_presets = {
    "camera_rgb": ['sensor.camera.rgb', cc.Raw, 'Camera RGB'],
    "camera_depth": ['sensor.camera.depth', cc.Raw, 'Camera Depth (Raw)'],
    "camera_gray_depth": ['sensor.camera.depth', cc.Depth, 'Camera Depth (Gray Scale)'],
    "camera_log_gray_depth": ['sensor.camera.depth', cc.LogarithmicDepth,
                              'Camera Depth (Logarithmic Gray Scale)'],
    "camera_seg": ['sensor.camera.semantic_segmentation', cc.Raw,
                   'Camera Semantic Segmentation (Raw)'],
    "camera_seg_city": ['sensor.camera.semantic_segmentation', cc.CityScapesPalette,
                        'Camera Semantic Segmentation (CityScapes Palette)'],
    "lidar": ['sensor.lidar.ray_cast', None, 'Lidar (Ray-Cast)'],
}

# Sun position presets: name -> (sun_altitude_angle, sun_azimuth_angle), the order in
# which _weather_utils.apply_sun_presets reads them.
SUN_PRESETS = {
    'day': (45.0, 0.0),
    'night': (-90.0, 0.0),
    'sunset': (0.5, 0.0)}

# Weather presets: name -> 11 values, read by _weather_utils.apply_weather_presets as
# [cloudiness, precipitation, precipitation_deposits, wind_intensity, fog_density,
#  fog_distance, fog_falloff, wetness, scattering_intensity, mie_scattering_scale,
#  rayleigh_scattering_scale].
WEATHER_PRESETS = {
    'clear': [10.0, 0.0, 0.0, 5.0, 0.0, 0.0, 0.2, 0.0, 0.0, 0.0, 0.0331],
    'overcast': [80.0, 0.0, 0.0, 50.0, 2.0, 0.75, 0.1, 10.0, 0.0, 0.03, 0.0331],
    'rain': [100.0, 80.0, 90.0, 100.0, 7.0, 0.75, 0.1, 100.0, 0.0, 0.03, 0.0331]}

# Label -> single vehicle-light flag.
LABEL_TO_CARLIGHT = {
    'None': carla.VehicleLightState.NONE,
    'Position': carla.VehicleLightState.Position,
    'LowBeam': carla.VehicleLightState.LowBeam,
    'HighBeam': carla.VehicleLightState.HighBeam,
    'Brake': carla.VehicleLightState.Brake,
    'RightBlinker': carla.VehicleLightState.RightBlinker,
    'LeftBlinker': carla.VehicleLightState.LeftBlinker,
    'Reverse': carla.VehicleLightState.Reverse,
    'Fog': carla.VehicleLightState.Fog,
    'Interior': carla.VehicleLightState.Interior,
    'Special1': carla.VehicleLightState.Special1,
    'Special2': carla.VehicleLightState.Special2,
    # 'All' is left out, presumably because decompose_car_light tests each key with a
    # bitwise AND and 'All' would match any lit state.
    # 'All' : carla.VehicleLightState.All
}

# Inverse mapping. Its iteration order defines the column order of the multi-hot vector
# returned by _light_utils.decompose_car_light.
CARLIGHT_TO_LABEL = {value: key for key, value in LABEL_TO_CARLIGHT.items()}

# Vehicle blueprint ids listed as having lights (per the original note, for carla 0.9.13).
VEHICLES_WITH_LIGHT = [  # for carla 0.9.13
    "vehicle.chevrolet.impala",
    "vehicle.dodge.charger_police",
    "vehicle.audi.tt",
    "vehicle.mercedes.coupe",
    "vehicle.mercedes.coupe_2020",
    "vehicle.dodge.charger_2020",
    "vehicle.lincoln.mkz_2020",
    "vehicle.dodge.charger_police_2020",
    "vehicle.audi.etron",
    "vehicle.volkswagen.t2_2021",
    "vehicle.tesla.cybertruck",
    "vehicle.lincoln.mkz_2017",
    "vehicle.ford.mustang",
    "vehicle.volkswagen.t2",
    "vehicle.tesla.model3",
]


# ---- /interface/highway_env/__init__.py ----
from spider.interface.highway_env.HighwayEnvInterface import HighwayEnvInterface
from spider.interface.highway_env.HighwayEnvBenchmark import HighwayEnvBenchmark, HighwayEnvBenchmarkGUI

# ---- /interface/nuplan/__init__.py ----
# (content not included in this dump; source URL:)
# https://raw.githubusercontent.com/Thu-ADLab/SPIDER/a1a2f81a29da0c63f1a46ad04740d6031c78e885/interface/nuplan/__init__.py

# ---- /optimize/APFGradDescend.py ----
import numpy as np
import matplotlib.pyplot as plt

# Demo script: optimise a trajectory by gradient descent on an artificial potential field.
# Requirement: the potential field must be differentiable everywhere!

# Risk potential field stored as an image, i.e. a 2-D array indexed [y, x]
# (rows = y, columns = x). A toy example would be
#   risk_potential_field = np.array([[0.1, 0.2, 0.3],
#                                    [0.2, 0.5, 0.4],
#                                    [0.3, 0.4, 0.6]])
x, y = np.arange(600), np.arange(200)
xx, yy = np.meshgrid(x, y)
dist = np.sqrt((xx - 380) ** 2 + (yy - 100) ** 2)
dist[dist < 0.001] = 0.001  # avoid division by zero at the obstacle centre
risk_potential_field = 10 / dist  # inverse-distance falloff (not inverse-square)
risk_potential_field[risk_potential_field > 1] = 1.0

# Normalise the field to [0, 1].
risk_potential_field -= risk_potential_field.min()
risk_potential_field /= risk_potential_field.max()

# Initial trajectory: 50 (x, y) points on a straight line.
initial_trajectory0 = initial_trajectory = np.column_stack([
    np.linspace(200., 500, 50),
    np.linspace(50., 150, 50),
])

# Gradient-descent learning rate / step size.
learning_rate = 5.

# Maximum number of iterations.
max_iterations = 1000

# Stopping criteria: stop once the change in cost (resp. trajectory) stays below its
# threshold for more than `*_endure` consecutive iterations.
stop_dcost_thresh = 0.01
stop_dtraj_thresh = 0.01
stop_dcost_endure = 30
stop_dtraj_endure = stop_dcost_endure  # same patience as before, now nameable separately
dcost_endure_count = 0
dtraj_endure_count = 0

# The field is static, so its gradient (d/dx along axis 1, d/dy along axis 0) is computed once.
field_grad_x = np.gradient(risk_potential_field, axis=1)
field_grad_y = np.gradient(risk_potential_field, axis=0)
field_height, field_width = risk_potential_field.shape


def _field_indices(traj):
    """Return (row, col) grid indices for the (x, y) points of *traj*.

    Coordinates are truncated to int and then clipped into the field. Points that drift
    off the grid therefore neither raise IndexError nor wrap around via negative indices.
    """
    rows = np.clip(traj[:, 1].astype(int), 0, field_height - 1)
    cols = np.clip(traj[:, 0].astype(int), 0, field_width - 1)
    return rows, cols


cost_record = []
for i in range(max_iterations):
    rows, cols = _field_indices(initial_trajectory)

    # Total risk potential of the current trajectory.
    total_potential = risk_potential_field[rows, cols].sum()

    if cost_record and abs(total_potential - cost_record[-1]) < stop_dcost_thresh:
        dcost_endure_count += 1
        if dcost_endure_count > stop_dcost_endure:
            print("Converged at iteration %d" % i)
            break
    else:
        # Reset so that only *consecutive* small cost changes count towards convergence
        # (previously an unused `endure_count` was assigned and this counter never reset).
        dcost_endure_count = 0
    cost_record.append(total_potential)

    # Gradient step on every trajectory point.
    gradient = np.column_stack((field_grad_x[rows, cols], field_grad_y[rows, cols]))
    updated_trajectory = initial_trajectory - learning_rate * gradient

    if np.linalg.norm(updated_trajectory - initial_trajectory) < stop_dtraj_thresh:
        dtraj_endure_count += 1
        if dtraj_endure_count > stop_dtraj_endure:
            print("Converged at iteration %d" % i)
            break
    else:
        dtraj_endure_count = 0

    initial_trajectory = updated_trajectory

    plt.cla()
    plt.plot(cost_record)
    plt.pause(0.01)

plt.figure()
plt.plot(np.diff(cost_record))

plt.figure()
plt.imshow(risk_potential_field, cmap='Reds')
plt.plot(initial_trajectory0[:, 0], initial_trajectory0[:, 1], '-k')
plt.plot(initial_trajectory[:, 0], initial_trajectory[:, 1], '-r')
plt.show()


# ---- /optimize/BaseOptimizer.py ----
from abc import abstractmethod
from .common import Objective, Constraints


class BaseOptimizer:
    """Base class that stores an objective and constraints for an optimizer.

    Note: the class does not use ``abc.ABC``/``ABCMeta``, so ``@abstractmethod`` on
    :meth:`optimize` is not enforced. Subclasses that do not override it can still be
    instantiated, and the base implementation returns None.
    """

    def __init__(self, objective=None, constraints=None):
        self.objective = objective
        self.constraints = constraints

    def set_constraints(self, constraints):
        """Replace the stored constraints."""
        self.constraints = constraints

    def set_objective(self, objective):
        """Replace the stored objective."""
        self.objective = objective

    @abstractmethod
    def optimize(self, *args, **kwargs):
        """Run the optimization; to be provided by subclasses."""
        pass


# ---- /optimize/__init__.py ----
from .common import Objective, Constraints


# ---- /param.py ----
def _assignment(flag=None):
    """Hand out consecutive integer codes for the enum-like constants below.

    ``_assignment(k)`` restarts the sequence and returns ``k``. Each following
    ``_assignment()`` returns the previous value + 1. A first call without *flag*
    returns 0. The counter lives in the ``_assignment.flag`` function attribute; the
    helper is deleted at the end of the module.
    """
    if flag is None:
        _assignment.flag = getattr(_assignment, "flag", 0) + 1
    else:
        _assignment.flag = flag + 1
    return _assignment.flag - 1


####################### common ######################
######### direction
DIRECTION_LEFT = _assignment(0)
DIRECTION_RIGHT = _assignment()

####################### perception ######################
PERCEPTION_BOX = _assignment(0)
PERCEPTION_OCC = _assignment()

####################### prediction ######################
PREDICTION_LINEAR = _assignment(0)


####################### output ######################
OUTPUT_TRAJECTORY = _assignment(0)
OUTPUT_CONTROL = _assignment()

####################### constraint ######################
# NOTE: the CONSTRIANT_* names are misspelled but are public API; kept as-is.
CONSTRAINT_COLLISION = _assignment(0)
# cartesian
CONSTRIANT_SPEED_UB = _assignment()
CONSTRIANT_SPEED_LB = _assignment()
CONSTRIANT_ACCELERATION = _assignment()
CONSTRIANT_DECELERATION = _assignment()
CONSTRIANT_JERK = _assignment()
CONSTRIANT_CURVATURE = _assignment()
CONSTRIANT_HEADING = _assignment()
CONSTRIANT_STEER = _assignment()


# frenet
CONSTRIANT_LATERAL_OFFSET = _assignment()
CONSTRIANT_LATERAL_VELOCITY = _assignment()
CONSTRIANT_LATERAL_ACCELERATION = _assignment()
CONSTRIANT_LATERAL_JERK = _assignment()

CONSTRIANT_LONGITUDINAL_PROGRESS = _assignment()
CONSTRIANT_LONGITUDINAL_VELOCITY = _assignment()
CONSTRIANT_LONGITUDINAL_ACCELERATION = _assignment()
CONSTRIANT_LONGITUDINAL_JERK = _assignment()

####################### collision ######################
# for collision_checker
COLLISION_CHECKER_SAT = _assignment(0)
COLLISION_CHECKER_AABB = _assignment()
COLLISION_CHECKER_DISK = _assignment()
COLLISION_CHECKER_OCC = _assignment()

####################### DATA ######################
# for data closed-loop engine
DATA_FORMAT_JSON = _assignment(0)  # store the log buffer as JSON
DATA_FORMAT_TENSOR = _assignment()  # store the log buffer as tensors
DATA_FORMAT_RAW = _assignment()  # store the log buffer in its raw data format

DATA_TYPE_LOG = _assignment(0)
DATA_TYPE_EXP = _assignment()

DATA_GT_PLAN = _assignment(0)  # ground-truth data is the plan
DATA_GT_TRACE = _assignment()  # ground-truth data is the driven trace

# LOAD_TYPE_DATALOADER = _assignment(0)  # load data and convert it to a dataloader
# LOAD_TYPE_RAW = _assignment()  # load data and keep the raw data format


####################### RL ######################
# NN MODE FLAG
NN_TRAIN_MODE = _assignment(0)
NN_EVAL_MODE = _assignment()

####################### interface ######################
########### for highway-env
HIGHWAYENV_OBS_KINEMATICS = _assignment(0)
HIGHWAYENV_OBS_GRAYIMG = _assignment()
HIGHWAYENV_OBS_OCCUPANCY = _assignment()
HIGHWAYENV_OBS_TTC = _assignment()

HIGHWAYENV_ACT_META = _assignment(0)
HIGHWAYENV_ACT_DISCRETE = _assignment()
HIGHWAYENV_ACT_CONTINUOUS = _assignment()

del _assignment


# ---- /planner_zoo/BaseNeuralPlanner.py ----
/planner_zoo/BaseNeuralPlanner.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from typing import Union 4 | from abc import abstractmethod 5 | 6 | import spider 7 | import spider.elements as elm 8 | from spider.planner_zoo.BasePlanner import BasePlanner 9 | 10 | 11 | class BaseNeuralPlanner(BasePlanner): 12 | def __init__(self, config=None, policy:nn.Module=None, state_encoder:nn.Module=None, action_decoder:nn.Module=None): 13 | super(BaseNeuralPlanner, self).__init__(config) 14 | 15 | # self.state_dim = 0 16 | # self.action_dim = 0 17 | 18 | self.state_encoder: nn.Module = state_encoder # Observation -> tensor (state) 19 | self.action_decoder: nn.Module = action_decoder # tensor(state) -> tensor (action) 20 | self.policy: nn.Module = policy # tensor (action) -> Plan 21 | 22 | self.device = self.config.get("device", torch.device("cuda" if torch.cuda.is_available() else "cpu")) 23 | if self.policy is not None: 24 | self.policy.to(self.device) 25 | 26 | 27 | @classmethod 28 | def default_config(cls) -> dict: 29 | """ 30 | :return: a configuration dict 31 | """ 32 | config = super().default_config() 33 | config.update({ 34 | "output": spider.OUTPUT_TRAJECTORY, 35 | "steps": 15, 36 | "dt": 0.2, 37 | "device": torch.device("cuda" if torch.cuda.is_available() else "cpu"), 38 | "print_info": True, 39 | 40 | "model_path": './model.pth' # todo:以后加上若干epoch自动保存的功能 41 | }) 42 | return config 43 | 44 | @property 45 | def training(self): 46 | return self.policy.training 47 | 48 | 49 | def act(self, state: torch.Tensor, *args, **kwargs) -> torch.Tensor: 50 | if self.training: 51 | action = self.policy(state) 52 | else: 53 | with torch.no_grad(): 54 | action = self.policy(state) 55 | return action 56 | 57 | def plan(self, ego_veh_state:elm.VehicleState, obstacles:elm.TrackingBoxList, routed_local_map:elm.RoutedLocalMap)\ 58 | -> Union[elm.Trajectory, elm.FrenetTrajectory]: 59 | state = 
self.state_encoder(ego_veh_state, obstacles, routed_local_map).unsqueeze(0) 60 | action = self.act(state.to(self.device)).squeeze(0) 61 | traj = self.action_decoder(action.detach().cpu(), ego_veh_state, obstacles, routed_local_map) 62 | return traj 63 | 64 | def to(self, device): 65 | self.device = device 66 | self.policy.to(device) 67 | 68 | def configure(self, config: dict): 69 | raise RuntimeError("Neural planner does not support. Re-instantiate a planner instead! ") 70 | 71 | 72 | def load_state_dict(self, path): 73 | path = self.config["model_path"] if path is None else path 74 | self.policy.load_state_dict(torch.load(path)) 75 | 76 | def save_state_dict(self, path=None): 77 | path = self.config["model_path"] if path is None else path 78 | torch.save(self.policy.state_dict(), path) 79 | 80 | def train_mode(self): 81 | self.policy.train() 82 | 83 | def eval_mode(self): 84 | self.policy.eval() 85 | 86 | def set_action_decoder(self, action_decoder:nn.Module): 87 | self.action_decoder = action_decoder 88 | 89 | def set_state_encoder(self, state_encoder:nn.Module): 90 | self.state_encoder = state_encoder 91 | 92 | def set_policy(self, policy: nn.Module): 93 | self.policy = policy 94 | self.policy.to(self.device) 95 | 96 | 97 | -------------------------------------------------------------------------------- /planner_zoo/BaseSamplerPlanner.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Thu-ADLab/SPIDER/a1a2f81a29da0c63f1a46ad04740d6031c78e885/planner_zoo/BaseSamplerPlanner.py -------------------------------------------------------------------------------- /planner_zoo/DDQNPlanner.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | from spider.rl.policy.DDQNPolicy import DDQNPolicy 4 | 5 | from spider.planner_zoo.DQNPlanner import DQNPlanner, MLP_q_network 6 | 7 | class DDQNPlanner(DQNPlanner): 8 | def __init__(self, config=None): 9 | 
super().__init__(config) 10 | 11 | self.policy = DDQNPolicy( 12 | MLP_q_network(self.state_encoder.state_dim, self.action_decoder.action_dim), 13 | self.action_decoder.action_dim, 14 | lr=self.config["learning_rate"], 15 | enable_tensorboard=self.config["enable_tensorboard"], 16 | tensorboard_root=self.config["tensorboard_root"] 17 | ) 18 | 19 | 20 | -------------------------------------------------------------------------------- /planner_zoo/DQNPlanner.py: -------------------------------------------------------------------------------- 1 | import math 2 | from typing import Union 3 | 4 | import numpy as np 5 | import torch.nn as nn 6 | import torch 7 | 8 | 9 | import spider.elements as elm 10 | import spider.utils.lane_decision 11 | from spider.sampler.LatticeSampler import LatticeSampler 12 | from spider.data.DataBuffer import ExperienceBuffer 13 | 14 | from spider.rl.state.StateConverter import KineStateEncoder 15 | from spider.rl.action.ActionConverter import DiscreteTrajActionDecoder, DiscreteTrajActionEncoder 16 | from spider.rl.policy.DQNPolicy import DQNPolicy 17 | from spider.planner_zoo.BaseNeuralPlanner import BaseNeuralPlanner 18 | from spider.utils.lane_decision import ConstLaneDecision 19 | 20 | 21 | class MLP_q_network(nn.Module): 22 | def __init__(self, input_dim, output_dim, hidden_size=64): 23 | super().__init__() 24 | self.mlp = nn.Sequential( # 5-layer MLP 25 | nn.Linear(input_dim, hidden_size), nn.ReLU(), 26 | nn.Linear(hidden_size, hidden_size), nn.ReLU(), 27 | nn.Linear(hidden_size, hidden_size), nn.ReLU(), 28 | nn.Linear(hidden_size, hidden_size), nn.ReLU(), 29 | nn.Linear(hidden_size, output_dim)#, nn.Sigmoid() 30 | ) 31 | 32 | def forward(self, x): 33 | return self.mlp(x) 34 | 35 | 36 | class DQNPlanner(BaseNeuralPlanner): 37 | def __init__(self, config=None): 38 | super().__init__(config) 39 | 40 | self.state_encoder = KineStateEncoder( 41 | normalize=self.config["normalize"], 42 | relative=self.config["relative"], 43 | 
num_object=self.config["num_object"], 44 | x_range=self.config["longitudinal_range"], 45 | y_range=self.config["lateral_range"]) 46 | 47 | 48 | self.trajectory_sampler = LatticeSampler( 49 | self.steps, self.dt, 50 | self.config["end_T_candidates"], self.config["end_v_candidates"], 51 | self.config["end_s_candidates"], self.config["end_l_candidates"], 52 | lane_decision=ConstLaneDecision(1), # todo:这个以后要改过来, 现在太笨了 53 | calc_by_need=True 54 | ) 55 | 56 | self.action_decoder = DiscreteTrajActionDecoder(sampler=self.trajectory_sampler) 57 | # self.action_encoder = DiscreteTrajActionEncoder(sampler=self.trajectory_sampler) 58 | 59 | self.policy = DQNPolicy( 60 | MLP_q_network(self.state_encoder.state_dim, self.action_decoder.action_dim), 61 | self.action_decoder.action_dim, 62 | lr = self.config["learning_rate"], 63 | enable_tensorboard=self.config["enable_tensorboard"], 64 | tensorboard_root=self.config["tensorboard_root"] 65 | ) 66 | 67 | # self.reward = None 68 | # 69 | self.exp_buffer = ExperienceBuffer(maxlen=self.config["exp_buffer_maxlen"], forward_only=False, autosave=False) 70 | # self.exp_buffer.apply_to(self.policy, self.reward) 71 | 72 | 73 | @classmethod 74 | def default_config(cls) -> dict: 75 | cfg = super().default_config() 76 | cfg.update({ 77 | "steps": 20, 78 | "dt": 0.2, 79 | 80 | ####### 观测空间到状态空间变换 参数 ######### 81 | "num_object": 5, 82 | "normalize": False, 83 | "relative": False, 84 | "longitudinal_range": (-50, 100), 85 | "lateral_range": (-20,20), 86 | 87 | ####### 离散动作空间 采样器参数 ######### 88 | "end_s_candidates": (10, 20, 40), 89 | "end_l_candidates": (-3.5, 0, 3.5), 90 | "end_v_candidates": tuple(i * 60 / 3.6 / 2 for i in range(3)), 91 | "end_T_candidates": (2, 4), 92 | 93 | ####### 经验回放池 参数 ######### 94 | "exp_buffer_maxlen": 100000, 95 | 96 | ####### 训练 参数 ######### 97 | "batch_size": 64, 98 | 99 | "learning_rate": 0.0001, 100 | "enable_tensorboard": False, 101 | "tensorboard_root": './tensorboard/' 102 | # "epochs": 100, 103 | # 
"batch_size": 64 104 | }) 105 | return cfg 106 | 107 | 108 | def plan(self, ego_veh_state:elm.VehicleState, obstacles:elm.TrackingBoxList, routed_local_map:elm.RoutedLocalMap)\ 109 | -> Union[elm.Trajectory, elm.FrenetTrajectory]: 110 | 111 | if not (routed_local_map is None): 112 | self.set_local_map(routed_local_map) 113 | 114 | state = self.state_encoder(ego_veh_state, obstacles, self.local_map).unsqueeze(0) 115 | action = self.act(state.to(self.device)).squeeze(0) 116 | 117 | traj = self.action_decoder(action.detach().cpu(), ego_veh_state, obstacles, self.local_map) 118 | return traj 119 | 120 | 121 | if __name__ == '__main__': 122 | from spider.interface import DummyBenchmark 123 | # 124 | # planner = MlpPlanner() 125 | # 126 | # # obs = DummyBenchmark.get_environment_presets() 127 | # # traj = planner.plan(*obs) 128 | # 129 | # bm = DummyBenchmark() 130 | # bm.test(planner) 131 | 132 | 133 | pass 134 | -------------------------------------------------------------------------------- /planner_zoo/FallbackDummyPlanner.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | from spider.planner_zoo.BasePlanner import DummyPlanner 4 | from spider.planner_zoo.FallbackPlanner import FallbackPlanner 5 | from spider.utils.collision.CollisionChecker import BoxCollisionChecker 6 | import numpy as np 7 | 8 | 9 | class FallbackDummyPlanner(DummyPlanner): 10 | def __init__(self, config=None): 11 | super().__init__(config) 12 | self.fallback_planner = FallbackPlanner(config) 13 | self.collision_checker = BoxCollisionChecker(self.config["ego_veh_length"], self.config["ego_veh_width"]) 14 | 15 | @classmethod 16 | def default_config(cls) -> dict: 17 | return cls._update_config(super().default_config(), FallbackPlanner.default_config()) 18 | 19 | def plan(self, ego_veh_state, obstacles, routed_local_map=None): 20 | traj = super().plan(ego_veh_state, obstacles, routed_local_map) 21 | if (traj is None): 22 | traj = 
self.fallback_planner.plan(ego_veh_state, obstacles, routed_local_map) 23 | else: 24 | ts = np.arange(self.steps) * self.dt 25 | obstacles.predict(ts[ts < self.config["min_TTC"]]) 26 | if self.collision_checker.check_trajectory(traj, obstacles): 27 | traj = self.fallback_planner.plan(ego_veh_state, obstacles, routed_local_map) 28 | 29 | return traj 30 | -------------------------------------------------------------------------------- /planner_zoo/FallbackPlanner.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | 3 | import numpy as np 4 | 5 | import spider 6 | from spider.planner_zoo import BasePlanner 7 | from spider.utils.collision import BoxCollisionChecker 8 | from spider.constraints import CartConstriantChecker 9 | 10 | 11 | 12 | class FallbackPlanner(BasePlanner): 13 | def __init__(self, config=None): 14 | super().__init__(config) 15 | self.constraint_checker = CartConstriantChecker( 16 | self.config, BoxCollisionChecker(self.config["ego_veh_length"], self.config["ego_veh_width"]) 17 | ) 18 | 19 | @classmethod 20 | def default_config(cls) -> dict: 21 | return cls._update_config(super().default_config(),{ 22 | # "acceleration": , 23 | "min_TTC": 2.0, 24 | "max_deceleration": 10.0, 25 | }) 26 | 27 | 28 | def plan(self, ego_veh_state:spider.elements.VehicleState, obstacles:spider.elements.TrackingBoxList, 29 | routed_local_map=None): 30 | 31 | if self.config["print_info"]: 32 | print("WARNING: Fallback planner activated!") 33 | 34 | # 初始状态 35 | ego = ego_veh_state 36 | x, y = ego.x(), ego.y() 37 | vx, vy = ego.velocity.x, ego.velocity.y 38 | cosyaw = vx / ego.v() 39 | sinyaw = vy / ego.v() 40 | # vx, vy = ego.v() * cosyaw, ego.v() * sinyaw 41 | # ax, ay = ego.a() * cosyaw, ego.a() * sinyaw 42 | 43 | acc = -self.config["max_deceleration"] 44 | ax, ay = acc * cosyaw, acc * sinyaw 45 | 46 | # 计算轨迹 47 | ts = np.arange(self.steps) * self.dt 48 | t_stop = ego.v() / abs(acc) 49 | if t_stop > self.horizon: # can 
not stop in time 50 | xs = x + vx * ts + 0.5 * ax * ts ** 2 51 | ys = y + vy * ts + 0.5 * ay * ts ** 2 52 | else: # has stopped at some timepoint during the horizon 53 | pre_idx = ts dict: 23 | return cls._update_config(super().default_config(),{ 24 | # "min_TTC": 2.0, 25 | # "max_deceleration": 10.0, 26 | }) 27 | 28 | 29 | def plan(self, ego_veh_state:spider.elements.VehicleState, obstacles:spider.elements.TrackingBoxList, 30 | local_map=None): 31 | 32 | if not (local_map is None): 33 | self.set_local_map(local_map) 34 | 35 | ego_lane_idx , (ego_s, _) = self.local_map.match_lane(ego_veh_state.x(), ego_veh_state.y(), return_frenet=True) 36 | target_lane = self.local_map.lanes[ego_lane_idx] 37 | ref_line = target_lane.centerline 38 | 39 | self.frenet_tf.set_reference_line(target_lane.centerline, target_lane.centerline_csp) 40 | roi_range = target_lane.width / 2.0 + 0.5 41 | boxes_with_frenet = self.frenet_tf.cart2frenet4boxes(obstacles, order=1) 42 | 43 | front_veh_speed, front_veh_dist = 60 / 3.6, 1000 44 | for tb in boxes_with_frenet: 45 | if -roi_range < tb.frenet_state.l < roi_range and tb.frenet_state.s > ego_s: 46 | dist = tb.frenet_state.s - ego_s 47 | if dist < front_veh_dist: 48 | front_veh_dist = dist 49 | front_veh_speed = tb.frenet_state.s_dot 50 | 51 | 52 | # 初始状态 53 | ego = ego_veh_state 54 | x, y = ego.x(), ego.y() 55 | vx, vy = ego.velocity.x, ego.velocity.y 56 | cosyaw = vx / ego.v() 57 | sinyaw = vy / ego.v() 58 | 59 | # 控制量计算 60 | current_pose = (x, y, ego.yaw()) 61 | current_speed = ego.v() 62 | desired_speed = target_lane.speed_limit 63 | acc, steer = self.idm_controller.get_control(ref_line, front_veh_speed, front_veh_dist, desired_speed, 64 | current_pose, current_speed) 65 | # ax, ay = acc * cosyaw, acc * sinyaw 66 | 67 | # 计算轨迹 68 | accs = acc*np.ones(self.steps-1) 69 | steers = steer*np.ones(self.steps-1) 70 | veh_model = spider.control.Bicycle(x, y, ego.v(), ego.a(), ego.yaw(), dt=self.dt) 71 | traj = spider.elements.Trajectory(self.steps, 
self.dt) 72 | traj.step(veh_model, accs, steers) 73 | 74 | if self.config["print_info"]: 75 | print("IDMPlanner: front_veh: dist, speed = ", front_veh_dist, front_veh_speed) 76 | print("IDMPlanner: acc, steer = ", acc, steer) 77 | 78 | return traj 79 | 80 | if __name__ == '__main__': 81 | from spider.interface import DummyBenchmark 82 | planner = IDMPlanner() 83 | benchmark = DummyBenchmark({ 84 | "debug_mode": True, 85 | }) 86 | 87 | benchmark.test(planner) 88 | 89 | -------------------------------------------------------------------------------- /planner_zoo/ImaginaryPlanner.py: -------------------------------------------------------------------------------- 1 | import time 2 | 3 | import warnings 4 | import spider 5 | from spider.planner_zoo.BasePlanner import BasePlanner 6 | 7 | from spider.elements.map import RoutedLocalMap 8 | from spider.elements.trajectory import FrenetTrajectory 9 | from spider.elements.vehicle import VehicleState 10 | from spider.elements.box import TrackingBoxList, TrackingBox 11 | 12 | from spider.utils.ImaginaryEngine import ImaginaryEngine 13 | from spider.planner_zoo.LatticePlanner import LatticePlanner 14 | 15 | 16 | class ImaginaryPlanner(BasePlanner): 17 | def __init__(self, config=None): 18 | super(ImaginaryPlanner, self).__init__(config) 19 | 20 | self.config = self.default_config() 21 | if not (config is None): 22 | self.config.update(config) 23 | 24 | self._atom_planner = LatticePlanner(self.config["planner_config"]) 25 | self._predictor = None 26 | self._tracker = None 27 | self.track_steps = self.config["track_steps"] 28 | 29 | self.imaginary_engine = ImaginaryEngine( 30 | self.steps, self.dt, self._atom_planner, 31 | predictor=self._predictor, tracker=self._tracker, track_steps=self.track_steps 32 | ) 33 | 34 | @classmethod 35 | def default_config(cls) -> dict: 36 | """ 37 | :return: a configuration dict 38 | """ 39 | config = { 40 | "steps": 30, 41 | "dt": 0.2, 42 | "track_steps": 5, 43 | "print_info": True, 44 | 45 | 
"planner_config": LatticePlanner.default_config() 46 | } 47 | 48 | config["planner_config"].update({ 49 | "steps": 10, 50 | "dt": 0.2, 51 | "end_s_candidates": (20, 40, 60), 52 | "end_l_candidates": (-3.5, 0, 3.5), 53 | "print_info": False 54 | }) 55 | return config 56 | 57 | # todo: 以后atom_planner, predictor和tracker用 @property和@xxx.setter来设置 58 | 59 | def set_atom_planner(self, atom_planner): 60 | self._atom_planner = atom_planner 61 | self.imaginary_engine.atom_planner = self._atom_planner 62 | 63 | def set_predictor(self, predictor): 64 | self._predictor = predictor 65 | self.imaginary_engine.predictor = self._predictor 66 | 67 | def set_tracker(self,tracker): 68 | self._tracker = tracker 69 | self.imaginary_engine.tracker = self._tracker 70 | 71 | def set_local_map(self, local_map:RoutedLocalMap): 72 | self.local_map = local_map 73 | 74 | 75 | def plan(self, ego_veh_state:VehicleState, obstacles:TrackingBoxList, local_map:RoutedLocalMap=None) -> FrenetTrajectory: 76 | """ 77 | 输入定位、物体、(地图optional,更新频率比较慢。建议在外面单独写set地图的逻辑) 78 | 输出轨迹(FrenetTrajectory) 79 | """ 80 | t1 = time.time() 81 | # 储存地图。每条车道代表了一个frenet坐标系,储存地图即储存了lanes,即储存了可供选择的几个frenet坐标系 82 | if not (local_map is None): 83 | self.set_local_map(local_map) 84 | 85 | traj, truncated = self.imaginary_engine.imagine( 86 | ego_veh_state, obstacles, self.local_map) 87 | 88 | if truncated: 89 | if self.config["print_info"]: 90 | warnings.warn("trajectory is truncated! No solution during imagination...") 91 | 92 | t2 = time.time() 93 | if self.config["print_info"]: 94 | print("Planning Succeed! 
Time: %.2f seconds, FPS: %.2f" % (t2 - t1, 1 / (t2 - t1))) 95 | 96 | return traj 97 | 98 | 99 | 100 | 101 | -------------------------------------------------------------------------------- /planner_zoo/OptimizedGRUPlanner.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | 3 | import numpy as np 4 | import copy 5 | 6 | from spider.planner_zoo.GRUPlanner import GRUPlanner 7 | from spider.optimize.TrajectoryOptimizer import FrenetTrajectoryOptimizer 8 | 9 | from spider.elements import RoutedLocalMap, TrackingBoxList, VehicleState 10 | from spider.control.vehicle_model import Bicycle 11 | 12 | from spider.utils.collision import BoxCollisionChecker 13 | from spider.constraints import CartConstriantChecker 14 | 15 | class OptimizedGRUPlanner(GRUPlanner): 16 | def __init__(self, config=None): 17 | if "steps" in config or "dt" in config: 18 | warnings.warn("ATTENTION! Change of steps and dt needs the adjustment of optimizer parameters!") 19 | 20 | super(OptimizedLatticePlanner, self).__init__(config) 21 | 22 | self.optimizer = FrenetTrajectoryOptimizer(self.steps, self.dt) 23 | self.initial_traj = None 24 | 25 | @property 26 | def corridor(self): 27 | return self.optimizer.corridor 28 | 29 | @classmethod 30 | def default_config(cls) -> dict: 31 | """ 32 | :return: a configuration dict 33 | """ 34 | config = super().default_config() 35 | config.update({ 36 | "steps": 20, 37 | "dt": 0.2, 38 | "wheelbase": 3.0, 39 | "print_info": False 40 | }) 41 | return config 42 | 43 | 44 | def plan(self, ego_veh_state:VehicleState, obstacles:TrackingBoxList, local_map:RoutedLocalMap=None): 45 | traj = super().plan(ego_veh_state, obstacles, local_map) 46 | self.initial_traj = copy.deepcopy(traj) 47 | 48 | if traj is None: 49 | print("No feasible sampling initial trajectory! 
") 50 | return None 51 | 52 | raise NotImplementedError 53 | # todo: 补充全局坐标下的optimizer 54 | 55 | # 补充l_dot和l_2dot信息,因为本身以l_prime和l_2prime存储 56 | traj.l_dot = np.asarray(traj.l_prime) * np.asarray(traj.s_dot) 57 | traj.l_2dot = np.asarray(traj.l_2prime) * np.asarray(traj.s_dot) ** 2 \ 58 | + np.asarray(traj.l_prime) * np.asarray(traj.s_2dot) 59 | 60 | 61 | # 将obstacles转化到frenet坐标下 62 | obstacles_with_fstate = self.coordinate_transformer.cart2frenet4boxes(obstacles, convert_prediction=True, order=0) 63 | 64 | # 优化s,l序列 65 | opt_traj = self.optimizer.optimize_traj(traj, obstacles_with_fstate) # 仅有s, l信息 66 | # 将sl序列转化为xy序列 67 | opt_traj = self.coordinate_transformer.frenet2cart4traj(opt_traj, order=0) # s,l -> x,y 68 | # xy序列微分,提取高阶运动学信息 69 | opt_traj.derivative( 70 | Bicycle(ego_veh_state.x(), ego_veh_state.y(), ego_veh_state.v(), ego_veh_state.a(), 71 | ego_veh_state.yaw(), dt=self.dt, wheelbase=self.config["wheelbase"]), 72 | opt_traj.x, opt_traj.y) # 补充高阶导数信息 73 | 74 | # 附加debug信息corridor,供调试和可视化 75 | opt_traj.debug_info["corridor"] = self.corridor 76 | opt_traj.debug_info["initial_trajectory"] = self.initial_traj 77 | return opt_traj 78 | 79 | 80 | if __name__ == '__main__': 81 | from spider.interface.BaseBenchmark import DummyBenchmark 82 | 83 | bm = DummyBenchmark() 84 | planner = OptimizedGRUPlanner({ 85 | "steps": 20, 86 | "dt": 0.2, 87 | }) 88 | bm.test(planner) 89 | 90 | -------------------------------------------------------------------------------- /planner_zoo/OptimizedLatticePlanner.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | 3 | import numpy as np 4 | import copy 5 | 6 | from spider.planner_zoo.LatticePlanner import LatticePlanner 7 | from spider.optimize.TrajectoryOptimizer import FrenetTrajectoryOptimizer 8 | 9 | from spider.elements import RoutedLocalMap, TrackingBoxList, VehicleState 10 | from spider.control.vehicle_model import Bicycle 11 | 12 | from spider.utils.collision import 
BoxCollisionChecker 13 | from spider.constraints import CartConstriantChecker 14 | 15 | # todo:把优化器的打印输出增加一个可关闭选项,把lattice planner内部的计时和输出打印关闭,增加一个全部的计时打印输出功能 16 | class OptimizedLatticePlanner(LatticePlanner): 17 | def __init__(self, config=None): 18 | if "steps" in config or "dt" in config: 19 | warnings.warn("ATTENTION! Change of steps and dt needs the adjustment of optimizer parameters!") 20 | 21 | super(OptimizedLatticePlanner, self).__init__(config) 22 | 23 | self.constraint_checker = CartConstriantChecker( # 由于优化问题暂时视作质点优化问题,为了保证采样的初始解是可行解,需要使用质点约束检查器 24 | self.config, BoxCollisionChecker(0.1, 0.1) # 留着0.2是为了安全冗余 25 | ) 26 | self.optimizer = FrenetTrajectoryOptimizer(self.steps, self.dt) 27 | self.initial_traj = None 28 | 29 | @property 30 | def corridor(self): 31 | return self.optimizer.corridor 32 | 33 | @classmethod 34 | def default_config(cls) -> dict: 35 | """ 36 | :return: a configuration dict 37 | """ 38 | config = super().default_config() 39 | config.update({ 40 | "steps": 20, 41 | "dt": 0.2, 42 | "end_l_candidates": (-3.5, 0, 3.5), 43 | "wheelbase": 3.0 44 | }) 45 | return config 46 | 47 | # def configure(self, config: dict): 48 | # 49 | # super(OptimizedLatticePlanner, self).configure(config) 50 | 51 | 52 | 53 | def plan(self, ego_veh_state:VehicleState, obstacles:TrackingBoxList, local_map:RoutedLocalMap=None): 54 | # 将obstacles膨胀,以使得优化问题变为一个质点的轨迹规划问题 55 | obstacles = obstacles.dilate(self.length, self.width) 56 | 57 | traj = super().plan(ego_veh_state, obstacles, local_map) 58 | self.initial_traj = copy.deepcopy(traj) 59 | 60 | if traj is None: 61 | if self.config["print_info"]: 62 | print("No feasible sampling initial trajectory! 
") 63 | return None 64 | 65 | 66 | # 补充l_dot和l_2dot信息,因为本身以l_prime和l_2prime存储 67 | traj.l_dot = np.asarray(traj.l_prime) * np.asarray(traj.s_dot) 68 | traj.l_2dot = np.asarray(traj.l_2prime) * np.asarray(traj.s_dot) ** 2 \ 69 | + np.asarray(traj.l_prime) * np.asarray(traj.s_2dot) 70 | 71 | 72 | # 将obstacles转化到frenet坐标下 73 | obstacles_with_fstate = self.coordinate_transformer.cart2frenet4boxes(obstacles, convert_prediction=True, order=0) 74 | 75 | # 优化s,l序列 76 | opt_traj = self.optimizer.optimize_traj(traj, obstacles_with_fstate) # 仅有s, l信息 77 | # 将sl序列转化为xy序列 78 | opt_traj = self.coordinate_transformer.frenet2cart4traj(opt_traj, order=0) # s,l -> x,y 79 | # xy序列微分,提取高阶运动学信息 80 | opt_traj.derivative( 81 | Bicycle(ego_veh_state.x(), ego_veh_state.y(), ego_veh_state.v(), ego_veh_state.a(), 82 | ego_veh_state.yaw(), dt=self.dt, wheelbase=self.config["wheelbase"]), 83 | opt_traj.x, opt_traj.y) # 补充高阶导数信息 84 | 85 | # 附加debug信息corridor,供调试和可视化 86 | opt_traj.debug_info["corridor"] = self.corridor 87 | opt_traj.debug_info["initial_trajectory"] = self.initial_traj 88 | return opt_traj 89 | 90 | 91 | if __name__ == '__main__': 92 | from spider.interface.BaseBenchmark import DummyBenchmark 93 | 94 | bm = DummyBenchmark() 95 | # bm._init_obstacles = TrackingBoxList([ 96 | # TrackingBox(obb=(50, 0, 5, 2, np.arctan2(0.2, 5))), 97 | # TrackingBox(obb=(100, 0, 5, 2, np.arctan2(-0.2, 5))), 98 | # TrackingBox(obb=(200, -10, 1, 1, np.pi / 2)) # 横穿马路 99 | # ]) 100 | 101 | # planner = LatticePlanner({ 102 | # "steps": 20, 103 | # "dt": 0.2, 104 | # }) 105 | planner = OptimizedLatticePlanner({ 106 | "steps": 20, 107 | "dt": 0.2, 108 | }) 109 | bm.test(planner) 110 | 111 | -------------------------------------------------------------------------------- /planner_zoo/__init__.py: -------------------------------------------------------------------------------- 1 | from spider.planner_zoo.BasePlanner import BasePlanner, DummyPlanner 2 | from spider.planner_zoo.LatticePlanner import 
LatticePlanner 3 | # from spider.planner_zoo.MFRLPlanner import MFRLPlanner 4 | from spider.planner_zoo.BezierPlanner import BezierPlanner 5 | from spider.planner_zoo.PiecewiseLatticePlanner import PiecewiseLatticePlanner 6 | from spider.planner_zoo.ImaginaryPlanner import ImaginaryPlanner 7 | from spider.planner_zoo.OptimizedLatticePlanner import OptimizedLatticePlanner 8 | from spider.planner_zoo.FallbackPlanner import FallbackPlanner 9 | from spider.planner_zoo.FallbackDummyPlanner import FallbackDummyPlanner 10 | 11 | # import spider 12 | # FallbackPlanner = spider._try_import_from("spider.planner_zoo.FallbackPlanner", "FallbackPlanner") 13 | # ImaginaryPlanner = spider._try_import_from("spider.planner_zoo.ImaginaryPlanner", "ImaginaryPlanner") 14 | -------------------------------------------------------------------------------- /rl/PlannerGym.py: -------------------------------------------------------------------------------- 1 | import spider.visualize as vis 2 | import tqdm 3 | # The bare string below is a pseudocode sketch of the closed-loop planner/environment interaction; it has no runtime effect. 4 | ''' 5 | qzl: 基本的伪代码 6 | 7 | state, action, next_state, done = None, None, None, False 8 | _env.reset() 9 | closed_loop = True # 默认闭环,开环的话不会计算奖励函数并储存经验池,纯做policy.act(state) 10 | 11 | while True: 12 | observation = _env.observe() 13 | 14 | ------------------------- Planner内部 ------------------------------ 15 | next_state = Encoder.encode(observation) 16 | 17 | if closed_loop: 18 | reward, done = RewardFunction.evaluate(state, action, next_state) 19 | agent.experience_buffer.record(state, action, reward, next_state, done) 20 | # 注意,当state是none的时候,reward的计算以及经验池的record都是无效的 21 | 22 | if done: 23 | state, action, next_state = None, None, None 24 | plan = None 25 | else: 26 | state = next_state 27 | action = agent.policy.act(state) 28 | plan = Decoder.decode(action) 29 | --------------------------------------------------------------------- 30 | 31 | if plan is None: _env.reset() 32 | else: _env.step(plan) 33 | ''' 34 | 35 | class PlannerGym: 36 | ''' 37 | TODO: add a feature that wraps the environment as a gym environment 
''' 39 | def __init__(self, env_interface, reward_function, visualize=False): 40 | self.env_interface = env_interface 41 | self.reward_function = reward_function 42 | self._visualize = visualize 43 | 44 | # train(): run train_steps environment steps, learning from one sampled batch after every step 45 | def train(self, planner, train_steps, batch_size=64): 46 | # TODO: should training be triggered every step or every episode? 47 | # And should each training round be a single update? See stable-baselines3 for reference. 48 | 49 | policy = planner.policy 50 | exp_buffer = planner.exp_buffer 51 | 52 | exp_buffer.apply_to(policy, self.reward_function) # start listening 53 | 54 | obs, done = None, True # done=True forces an environment reset on the first iteration 55 | 56 | policy.set_exploration(enable=True) 57 | 58 | for i in tqdm.tqdm(range(train_steps)): 59 | if done: 60 | obs = self.env_interface.reset() 61 | 62 | # forward 63 | plan = planner.plan(*obs) # the listening exp_buffer records obs, plan 64 | self.env_interface.conduct_trajectory(plan) 65 | obs2 = self.env_interface.wrap_observation() 66 | 67 | # feedback 68 | reward, done = self.reward_function.evaluate_log(obs, plan, obs2) # the listening exp_buffer records reward, done 69 | policy.try_write_reward(reward, done, i) 70 | 71 | # learn 72 | batched_data = exp_buffer.sample(batch_size) 73 | policy.learn_batch(*batched_data) 74 | 75 | # visualize 76 | if self._visualize: 77 | vis.cla() 78 | vis.lazy_draw(*obs, plan) 79 | vis.title(f"Step {i}, Reward {reward}") 80 | vis.pause(0.001) 81 | 82 | obs = obs2 83 | # NOTE(review): if an exception escapes the loop, exploration is left enabled 84 | policy.set_exploration(enable=False) 85 | 86 | 87 | if __name__ == '__main__': 88 | from spider.interface import DummyInterface, DummyBenchmark 89 | from spider.planner_zoo.DQNPlanner import DQNPlanner 90 | from spider.planner_zoo.DDQNPlanner import DDQNPlanner 91 | from spider.rl.reward.TrajectoryReward import TrajectoryReward 92 | 93 | # presets 94 | ego_size = (5.,2.)
95 | 96 | # setup env 97 | env_interface = DummyInterface() 98 | 99 | # setup reward 100 | reward_function = TrajectoryReward( 101 | (-10., 280.), (-15, 15), (240., 280.), (-10,10), ego_size 102 | ) 103 | 104 | # setup_planner 105 | planner_dqn = DDQNPlanner({ 106 | "ego_veh_width": ego_size[1], 107 | "ego_veh_length": ego_size[0], 108 | "enable_tensorboard": True, 109 | }) 110 | 111 | planner_school = PlannerGym(env_interface, reward_function, visualize=False) 112 | planner_school.train(planner_dqn, 10000, 64) 113 | planner_dqn.policy.save_model('./q_net.pth') 114 | 115 | planner_dqn.policy.load_model('./q_net.pth') 116 | DummyBenchmark({"save_video": True,}).test(planner_dqn) 117 | -------------------------------------------------------------------------------- /rl/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Thu-ADLab/SPIDER/a1a2f81a29da0c63f1a46ad04740d6031c78e885/rl/__init__.py -------------------------------------------------------------------------------- /rl/_tensor_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def pad_at_end(x, dim=0, padding_channels=1, padding_value=0.0): 5 | """ 6 | Pad a tensor at the end of a dimension. 7 | qzl: 没有用torch.nn.functional.pad,参数有点复杂 8 | 9 | :param x: The tensor to pad. 10 | :param dim: The dimension to pad. 11 | :param padding_channels: The number of channels to pad. 12 | :param padding_value: The value to pad with. 
13 | """ 14 | if dim >= len(x.size()): 15 | raise ValueError("Dimension specified for padding is out of range for the input tensor.") 16 | 17 | # 构造 pad 的 shape 18 | pad_shape = list(x.shape) 19 | pad_shape[dim] = padding_channels 20 | 21 | # 构造 padding tensor,并填充为指定的值 22 | padding_tensor = torch.full(pad_shape, padding_value, dtype=x.dtype, device=x.device) 23 | 24 | # 使用 torch.cat 拼接原始张量和 padding 张量 25 | padded_tensor = torch.cat([x, padding_tensor], dim=dim) 26 | 27 | return padded_tensor 28 | 29 | 30 | def normalize_to_range(x, min_val, max_val): 31 | """ 32 | Min-Max normalize a tensor to [0, 1] given a specific range. 33 | 34 | :param x: The input tensor to normalize. 35 | :param min_val: The minimum value of the range. 36 | :param max_val: The maximum value of the range. 37 | """ 38 | 39 | clamped_x = torch.clamp(x, min=min_val, max=max_val) # Clip to the specified range 40 | normalized_x = (clamped_x - min_val) / (max_val - min_val) 41 | return normalized_x 42 | 43 | def convert_to_relative(): 44 | pass 45 | 46 | -------------------------------------------------------------------------------- /rl/action/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Thu-ADLab/SPIDER/a1a2f81a29da0c63f1a46ad04740d6031c78e885/rl/action/__init__.py -------------------------------------------------------------------------------- /rl/policy/ClassificationILPolicy.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.optim as optim 3 | 4 | import tqdm 5 | import matplotlib.pyplot as plt 6 | 7 | import spider 8 | from spider.rl.policy.BasePolicy import BasePolicy, DataLoader 9 | 10 | class ClassificationILPolicy(BasePolicy): 11 | def __init__(self, critic:torch.nn.Module, criterion:torch.nn.Module=None, lr=1e-4, 12 | enable_tensorboard=False, tensorboard_root='./tensorboard/'): 13 | super().__init__(enable_tensorboard, tensorboard_root) 
14 | self.critic = critic 15 | 16 | self.criterion = torch.nn.CrossEntropyLoss() if criterion is None else criterion 17 | self.optimizer = torch.optim.Adam(self.parameters(), lr=lr) 18 | 19 | self._plot_train_curve = True 20 | 21 | def forward(self, state:torch.Tensor) -> torch.Tensor: 22 | prob = self.critic(state.to(self.device)) 23 | return torch.argmax(prob, dim=-1) 24 | 25 | def learn_batch(self, batched_state, batched_target_action, *args): 26 | self.train() 27 | # forward 28 | batched_state = batched_state.to(self.device) 29 | batched_prob = self.critic(batched_state.to(self.device)) 30 | 31 | # calculate loss 32 | batched_target_action = batched_target_action.to(self.device) 33 | loss = self.criterion(batched_prob, batched_target_action) 34 | 35 | # backward 36 | self.optimizer.zero_grad() 37 | loss.backward() 38 | self.optimizer.step() 39 | return loss.item() 40 | 41 | def validate_batch(self, batched_state, batched_target_action): 42 | self.eval() 43 | batched_state = batched_state.to(self.device) 44 | batched_prob = self.critic(batched_state.to(self.device)) 45 | # calculate loss 46 | 47 | batched_target_action = batched_target_action.to(self.device) 48 | loss = self.criterion(batched_prob, batched_target_action) 49 | 50 | return loss.item() 51 | 52 | def learn_dataset(self, epochs:int, train_loader:DataLoader, val_loader:DataLoader=None): 53 | ''' 54 | Optionally implemented 55 | learn from the dataloader of all the dataset 56 | ''' 57 | for epoch in tqdm.tqdm(range(epochs), desc="Training with dataset..."): 58 | avg_train_loss, count = 0.0, 0 59 | for batch_data in train_loader: 60 | avg_train_loss += self.learn_batch(*batch_data) 61 | count += 1 62 | avg_train_loss /= count 63 | if self.enable_tensorboard: 64 | self.writer.add_scalar('loss/train', avg_train_loss, epoch) 65 | 66 | if val_loader is not None: 67 | avg_val_loss, count = 0.0, 0 68 | for batch_data in val_loader: 69 | avg_val_loss += self.learn_batch(*batch_data) 70 | count += 1 71 | 
avg_val_loss /= count 72 | if self.enable_tensorboard: 73 | self.writer.add_scalar('loss/val', avg_val_loss, epoch) 74 | else: 75 | avg_val_loss = None 76 | 77 | if self._plot_train_curve: 78 | self._update_train_curve(avg_train_loss, avg_val_loss) 79 | 80 | # if self.enable_tensorboard: 81 | # self.start_tensorboard() 82 | plt.savefig('./train_curve.png') 83 | plt.close() 84 | 85 | 86 | def _update_train_curve(self, train_loss, val_loss=None): 87 | if not hasattr(self, "_loss_record"): 88 | self._loss_record = {"train":[], "val":[]} 89 | plt.cla() 90 | ax = plt.gca() 91 | 92 | self._loss_record["train"].append(train_loss) 93 | ax.plot(self._loss_record["train"],label="train") 94 | 95 | if val_loss is not None: 96 | self._loss_record["val"].append(val_loss) 97 | if len(self._loss_record["val"]) > 0: 98 | ax.plot(self._loss_record["val"], label="val") 99 | 100 | plt.legend() 101 | plt.pause(0.01) -------------------------------------------------------------------------------- /rl/policy/DDQNPolicy.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | from spider.rl.policy.DQNPolicy import DQNPolicy 5 | 6 | class DDQNPolicy(DQNPolicy): 7 | def learn_batch(self, states, actions, rewards, dones, next_states): 8 | self.train() 9 | 10 | states, actions, rewards, dones, next_states = \ 11 | (x.to(self.device) for x in (states, actions, rewards, dones, next_states)) 12 | 13 | # Get current Q-values estimates 14 | q_values = self.q_network(states) # Qs for current state 15 | # Retrieve the q-values for the actions from the replay buffer 16 | q_values = torch.gather(q_values, dim=1, index=actions.long()) 17 | 18 | # DDQN: use the Q-network to select the action, and use the target Q-network to evaluate the action 19 | next_action = self.q_network(next_states).argmax(dim=1, keepdim=True) 20 | 21 | with torch.no_grad(): 22 | # max(Qs for current state) 23 | q_values_next = self.target_q_network(next_states).gather(1, 
next_action) 24 | # 1-step TD 25 | q_targets = rewards + (1 - dones) * self.gamma * q_values_next # (1 - dones) drops the bootstrap term on terminal transitions 26 | 27 | loss = self.criterion(q_values, q_targets) 28 | 29 | self.optimizer.zero_grad() 30 | loss.backward() 31 | if getattr(self, "max_grad_norm", None) is not None: 32 | torch.nn.utils.clip_grad_norm_(self.parameters(), self.max_grad_norm) 33 | self.optimizer.step() 34 | 35 | if self.enable_tensorboard: 36 | self.writer.add_scalar('loss/Q_network', loss.item(), self._learning_count) 37 | 38 | if self._plot_train_curve: 39 | self._update_train_curve(loss.item()) 40 | 41 | self._learning_count += 1 42 | if self._learning_count % self.target_update_frequency == 0: # periodically copy the online network into the target network 43 | self.target_q_network.load_state_dict(self.q_network.state_dict()) 44 | print("Target Q Network has updated at steps ", self._learning_count) 45 | 46 | return loss.item() 47 | 48 | -------------------------------------------------------------------------------- /rl/policy/RegressionILPolicy.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.optim as optim 3 | 4 | import tqdm 5 | import matplotlib.pyplot as plt 6 | 7 | import spider 8 | from spider.rl.policy.BasePolicy import BasePolicy, DataLoader 9 | 10 | 11 | class RegressionILPolicy(BasePolicy): 12 | def __init__(self, actor:torch.nn.Module, criterion:torch.nn.Module=None, lr=1e-4, 13 | enable_tensorboard=False, tensorboard_root='./tensorboard/'): 14 | super().__init__(enable_tensorboard, tensorboard_root) 15 | self.actor = actor 16 | 17 | self.criterion = torch.nn.L1Loss() if criterion is None else criterion 18 | self.optimizer = torch.optim.Adam(self.parameters(), lr=lr) 19 | 20 | self._plot_train_curve = True 21 | 22 | def forward(self, state:torch.Tensor) -> torch.Tensor: 23 | return self.actor(state.to(self.device)) 24 | 25 | def learn_batch(self, batched_state, batched_target_action, *args): 26 | self.train() 27 | # forward 28 | batched_state = batched_state.to(self.device) 29 
| batched_pred_action = self.forward(batched_state) 30 | 31 | # calculate loss 32 | batched_target_action = batched_target_action.to(self.device) 33 | loss = self.criterion(batched_pred_action, batched_target_action) 34 | 35 | # backward 36 | self.optimizer.zero_grad() 37 | loss.backward() 38 | self.optimizer.step() 39 | return loss.item() 40 | 41 | def validate_batch(self, batched_state, batched_target_action): 42 | self.eval() 43 | # forward 44 | batched_state = batched_state.to(self.device) 45 | batched_pred_action = self.forward(batched_state) 46 | 47 | # calculate loss 48 | batched_target_action = batched_target_action.to(self.device) 49 | loss = self.criterion(batched_pred_action, batched_target_action) 50 | return loss.item() 51 | 52 | def learn_dataset(self, epochs:int, train_loader:DataLoader, val_loader:DataLoader=None): 53 | ''' 54 | Optionally implemented 55 | learn from the dataloader of all the dataset 56 | ''' 57 | for epoch in tqdm.tqdm(range(epochs), desc="Training with dataset..."): 58 | avg_train_loss, count = 0.0, 0 59 | for batch_data in train_loader: 60 | avg_train_loss += self.learn_batch(*batch_data) 61 | count += 1 62 | avg_train_loss /= count 63 | if self.enable_tensorboard: 64 | self.writer.add_scalar('loss/train', avg_train_loss, epoch) 65 | 66 | if val_loader is not None: 67 | avg_val_loss, count = 0.0, 0 68 | for batch_data in val_loader: 69 | avg_val_loss += self.learn_batch(*batch_data) 70 | count += 1 71 | avg_val_loss /= count 72 | if self.enable_tensorboard: 73 | self.writer.add_scalar('loss/val', avg_val_loss, epoch) 74 | else: 75 | avg_val_loss = None 76 | 77 | if self._plot_train_curve: 78 | self._update_train_curve(avg_train_loss, avg_val_loss) 79 | 80 | # if self.enable_tensorboard: 81 | # self.start_tensorboard() 82 | plt.savefig('./train_curve.png') 83 | plt.close() 84 | 85 | 86 | def _update_train_curve(self, train_loss, val_loss=None): 87 | if not hasattr(self, "_loss_record"): 88 | self._loss_record = {"train":[], 
"val":[]} 89 | plt.cla() 90 | ax = plt.gca() 91 | 92 | self._loss_record["train"].append(train_loss) 93 | ax.plot(self._loss_record["train"],label="train") 94 | 95 | if val_loss is not None: 96 | self._loss_record["val"].append(val_loss) 97 | if len(self._loss_record["val"]) > 0: 98 | ax.plot(self._loss_record["val"], label="val") 99 | 100 | plt.legend() 101 | plt.pause(0.01) 102 | 103 | 104 | 105 | 106 | -------------------------------------------------------------------------------- /rl/policy/__init__.py: -------------------------------------------------------------------------------- 1 | ''' 2 | policy 本身应该是nn.Module, 3 | 该文件夹下可以视作是nn.Module的一种wrapper, 4 | 用于在policy本身附加一个训练的流程(用于套用实现好的rl或il算法) 5 | 6 | ''' 7 | -------------------------------------------------------------------------------- /rl/reward/BaseReward.py: -------------------------------------------------------------------------------- 1 | from abc import abstractmethod 2 | from typing import Tuple 3 | 4 | class BaseReward: 5 | def __init__(self): 6 | pass 7 | 8 | 9 | @abstractmethod 10 | def evaluate_log(self, observation, plan, next_observation) -> Tuple[float, bool]: 11 | pass 12 | 13 | @abstractmethod 14 | def evaluate_exp(self, state, action, next_action) -> Tuple[float, bool]: 15 | pass 16 | 17 | -------------------------------------------------------------------------------- /rl/reward/TerminateReward.py: -------------------------------------------------------------------------------- 1 | from typing import Tuple, Sequence 2 | import numpy as np 3 | import math 4 | import torch 5 | 6 | 7 | import spider 8 | from spider.rl.reward.BaseReward import BaseReward 9 | 10 | from spider.evaluator import CartCostEvaluator 11 | from spider.constraints import CartConstriantChecker 12 | from spider.utils.collision import BoxCollisionChecker 13 | 14 | 15 | 16 | class TerminateReward(BaseReward): 17 | ''' 18 | 仅检查done 19 | ''' 20 | def __init__(self, 21 | valid_x_range=None, valid_y_range=None, 22 | 
des_x_range=None, des_y_range=None, 23 | collision_done=False, ego_size=(5.,2.), 24 | kinematics_done=False, constraints={} 25 | ): 26 | super().__init__() 27 | 28 | self.valid_x_range = (-math.inf, math.inf) if valid_x_range is None else valid_x_range # (-10., 250.0) 29 | self.valid_y_range = (-math.inf, math.inf) if valid_y_range is None else valid_y_range # (-10.,10.) 30 | 31 | self.destination_x_range = (math.inf, math.inf) if des_x_range is None else des_x_range #(245., 255.0) 32 | self.destination_y_range = (math.inf, math.inf) if des_y_range is None else des_y_range #(-5., 5.) 33 | 34 | # self.max_reward = 10.0 35 | self.finish_reward = 10.0 36 | self.punish_reward = -10.0 37 | 38 | self.collision_done = collision_done 39 | if collision_done: 40 | self.collision_checker = BoxCollisionChecker(*ego_size) 41 | 42 | if kinematics_done: # todo: 没有启用 43 | self.constraint_checker = CartConstriantChecker({}, None) # check kinematics only 44 | 45 | 46 | def evaluate_log(self, observation, plan, next_observation) -> Tuple[float, bool]: 47 | """ 48 | Evaluate the reward and termination condition for the given log. 
49 | 50 | :param observation: the initial observation 51 | :param plan: the plan that led to the next observation 52 | :param next_observation: the observation we are evaluating the log for 53 | :return: a tuple of (reward, done) 54 | """ 55 | ego_, perc_, lmap_ = next_observation # unpacked as (ego state, perception, local map) 56 | ego_x_, ego_y_ = ego_.x(), ego_.y() 57 | 58 | range_reward, done = self._out_range_reward(ego_x_, ego_y_) 59 | if done: 60 | return range_reward, done 61 | 62 | if self.collision_done: 63 | c_reward, done = self._collision_reward(ego_, perc_) 64 | if done: 65 | return c_reward, done 66 | 67 | # _break_kinematics_reward, done = self._break_kinematics_reward(plan) 68 | d_reward, done = self._destination_reward(ego_x_, ego_y_) 69 | delay_reward, _ = self._delay_reward() 70 | 71 | return d_reward+delay_reward, done 72 | 73 | def evaluate_exp(self, *args) -> Tuple[float, bool]: 74 | raise NotImplementedError("Not implemented. Ready for model-based reward.") 75 | 76 | def _delay_reward(self): # constant per-step penalty, never terminates 77 | return -0.5, False 78 | 79 | # NOTE(review): self.constraint_checker only exists when kinematics_done=True 80 | def _break_kinematics_reward(self, traj): 81 | if self.constraint_checker.check_kinematics(traj): 82 | return self.punish_reward, False 83 | else: 84 | return 0.0, False 85 | 86 | 87 | def _out_range_reward(self, ego_x, ego_y) -> Tuple[float, bool]: 88 | if not self.valid_x_range[0] <= ego_x <= self.valid_x_range[1]: 89 | return self.punish_reward, True 90 | elif not self.valid_y_range[0] <= ego_y <= self.valid_y_range[1]: 91 | return self.punish_reward, True 92 | else: 93 | return 0.0, False 94 | 95 | def _destination_reward(self, ego_x, ego_y) -> Tuple[float, bool]: 96 | if self.destination_x_range[0] <= ego_x <= self.destination_x_range[1]: 97 | if self.destination_y_range[0] <= ego_y <= self.destination_y_range[1]: 98 | return self.finish_reward, True 99 | return 0.0, False 100 | 101 | def _collision_reward(self, ego_state, perception): 102 | collision = self.collision_checker.check_state(ego_state, perception) 103 | if collision: 104 | return 
self.punish_reward, True 105 | else: 106 | return 0.0, False -------------------------------------------------------------------------------- /rl/reward/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Thu-ADLab/SPIDER/a1a2f81a29da0c63f1a46ad04740d6031c78e885/rl/reward/__init__.py -------------------------------------------------------------------------------- /rl/reward/reward_collection.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | ''' 5 | TODO: build a collection of rewards, e.g. a comfort reward, a collision reward, etc. 6 | 7 | ''' 8 | 9 | -------------------------------------------------------------------------------- /rl/state/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Thu-ADLab/SPIDER/a1a2f81a29da0c63f1a46ad04740d6031c78e885/rl/state/__init__.py -------------------------------------------------------------------------------- /rl/transition/DeterministicTransition.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import math 4 | 5 | from spider.control.vehicle_model import Bicycle 6 | 7 | # class ObjectDeterministicTransition(nn.Module): 8 | # ''' 9 | # A simple MLP network 10 | # ''' 11 | # def __init__(self, state_veh_num, state_feat_num, action_dim, hidden_dim): 12 | # super(ObjectDeterministicTransition, self).__init__() 13 | # state_dim = state_feat_num * state_veh_num 14 | # self.MLP = nn.Sequential( 15 | # nn.Linear(state_dim + action_dim, hidden_dim), 16 | # nn.ReLU(), 17 | # nn.Linear(hidden_dim, hidden_dim), 18 | # nn.ReLU(), 19 | # nn.Linear(hidden_dim, state_dim), 20 | # ) 21 | # 22 | # 23 | # def forward(self, state, action): 24 | # ego_control = torch.tensor(action) 25 | # x = torch.cat((state, ego_control)) 26 | # delta = self.MLP(x) 27 | # next_state = delta + state 28 | # 
next_state = self.MLP(state, action) # qzl: this does not seem good 29 | # return next_state 30 | 31 | 32 | class KineTransition(nn.Module): 33 | ''' 34 | Treats every object as a vehicle: each present object is stepped with a kinematic Bicycle model. 35 | ''' 36 | 37 | def __init__(self, state_veh_num, state_feat_num, action_dim, hidden_dim, delta_t, 38 | max_accleration=5.0, max_steer=35*math.pi/180): 39 | super(KineTransition, self).__init__() 40 | self.state_veh_num = state_veh_num 41 | self.state_feat_num = state_feat_num 42 | 43 | state_dim = state_feat_num * state_veh_num 44 | self.MLP = nn.Sequential( 45 | nn.Linear(state_dim + action_dim, hidden_dim), 46 | nn.Sigmoid(), 47 | nn.Linear(hidden_dim, hidden_dim), 48 | nn.Sigmoid(), 49 | nn.Linear(hidden_dim, (state_veh_num-1) * 2), 50 | nn.Tanh() # squash to [-1, 1] 51 | ) 52 | 53 | self.dt = delta_t 54 | 55 | self.max_accleration = max_accleration # TODO: for now the positive and negative limits are the same 56 | self.max_steer = max_steer 57 | 58 | def forward(self, state, action): 59 | ''' 60 | action is (acceleration, steering angle). 61 | Note: action does not need to be normalized. 62 | ''' 63 | controls = self._forward_control(state, action) # controls for every vehicle (ego first) 64 | next_state = self._conduct_controls(state, controls) 65 | return next_state 66 | 67 | def _forward_control(self, state, action): 68 | # qzl: need to check whether normalization introduces a scaling problem for action 69 | ego_control = torch.as_tensor(action, dtype=state.dtype, device=state.device) # accepts tensors or sequences; matches state's dtype/device 70 | 71 | x = torch.cat((state, ego_control), dim=-1) 72 | results = self.MLP(x) 73 | 74 | batch_size = state.shape[0] if len(state.shape)>1 else 1 75 | 76 | controls = results.view(batch_size, self.state_veh_num-1, 2) 77 | limits = torch.tensor([self.max_accleration, self.max_steer], dtype=controls.dtype, device=controls.device) 78 | controls = controls * limits # out-of-place: scaling the Tanh output in place breaks autograd's backward pass 79 | 80 | controls = torch.cat((ego_control.view(batch_size, 1, 2), controls), dim=1) 81 | return controls 82 | 83 | def _conduct_controls(self, state, controls): 84 | batch_size = state.shape[0] if len(state.shape) > 1 else 1 85 | # clone so the per-vehicle writes below do not mutate the caller's state tensor 86 | next_state = state.reshape(batch_size, self.state_veh_num, self.state_feat_num).clone() 87 | 88 | for batch in range(len(controls)): 89 | batch_controls = 
controls[batch] 90 | for i, control in enumerate(batch_controls): 91 | acc, steer = control 92 | presence, x, y, length, width, heading, speed = next_state[batch, i, :]#.detach() 93 | if not presence: 94 | continue 95 | veh = Bicycle(x, y, speed, 0.0, heading, dt=self.dt) 96 | veh.step(acc, steer) 97 | next_state[batch, i, :] = torch.tensor([presence, veh.x, veh.y, length, width, veh.heading, veh.velocity]) 98 | 99 | if len(state.shape) > 1: 100 | next_state = next_state.view(batch_size, -1) 101 | else: 102 | next_state = next_state.flatten() 103 | return next_state 104 | 105 | 106 | if __name__ == '__main__': 107 | trans = KineTransition(10, 7, 2, 64, 0.1) 108 | state = torch.rand(70) 109 | action = [1.0, 0.0] 110 | next_state = trans(state, action) 111 | print(next_state) 112 | print(next_state.shape) 113 | pass 114 | -------------------------------------------------------------------------------- /rl/transition/GaussianTransition.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import math 4 | import numpy as np 5 | 6 | 7 | class GaussianTransition(nn.Module): 8 | """Predict gaussion trajectory, in offset format""" 9 | 10 | def __init__(self, in_channels, out_channels, hidden_unit, normalize_state_function, max_sigma=1e1, min_sigma=1e-4): 11 | super().__init__() 12 | self.fc = nn.Sequential( 13 | nn.Linear(in_channels, hidden_unit), 14 | nn.Sigmoid(), 15 | nn.Linear(hidden_unit, hidden_unit), 16 | nn.Sigmoid(), 17 | nn.Linear(hidden_unit, hidden_unit), 18 | nn.Sigmoid(), 19 | # nn.Linear(hidden_unit, hidden_unit) 20 | ) 21 | self.fc_mu = nn.Linear(hidden_unit, out_channels) 22 | self.fc_sigma = nn.Linear(hidden_unit, out_channels) 23 | 24 | self.max_sigma = max_sigma 25 | self.min_sigma = min_sigma 26 | assert(self.max_sigma >= self.min_sigma) 27 | 28 | # Vehicle Model 29 | self.wheelbase = 2.96 30 | self.max_steer = np.deg2rad(80) 31 | self.dt = 0.1 32 | self.c_r = 0.01 33 | 
self.c_a = 0.05 34 | self.vehicle_model_torch = KinematicBicycleModel_Pytorch(self.wheelbase, self.max_steer, self.dt, self.c_r, self.c_a) 35 | 36 | self.normalize_state_function = normalize_state_function 37 | 38 | 39 | def forward(self, x): 40 | """ 41 | qzl: 42 | transition model的本质是,输入s,a, 输出s' 43 | 在这里的函数的意思是, 44 | 输入x是观测量/状态s 45 | 46 | """ 47 | 48 | normalize_x = self.normalize_state_function(x) 49 | normalize_x = self.fc(normalize_x) 50 | mu = self.fc_mu(normalize_x) 51 | sigma = torch.sigmoid(self.fc_sigma(normalize_x)) # range (0, 1.) 52 | sigma = self.min_sigma + (self.max_sigma - self.min_sigma) * sigma # scaled range (min_sigma, max_sigma) 53 | 54 | pred_state = self.forward_torch_vehicle_model(x, mu) # it is quite slow to use that 55 | 56 | return pred_state, sigma 57 | 58 | def sample_prediction(self, x): 59 | mu, sigma = self(x) 60 | eps = torch.randn_like(sigma) 61 | return mu + sigma * eps 62 | 63 | def forward_torch_vehicle_model(self, obs, pred_action): 64 | pred_state = [] 65 | for i in range(len(pred_action[0])): 66 | x = torch.mul(obs[i][0], self.obs_scale) 67 | y = torch.mul(obs[i][1], self.obs_scale) 68 | yaw = torch.mul(obs[i][4], self.obs_scale) 69 | v = torch.tensor \ 70 | (math.sqrt(torch.mul(obs[i][2], self.obs_scale) ** 2 + torch.mul(obs[i][3], self.obs_scale) ** 2)) 71 | x, y, yaw, v, _, _ = self.vehicle_model_torch.kinematic_model(x, y, yaw, v, pred_action[0][i][0], pred_action[0][i][1]) 72 | tensor_list = [torch.div(x, self.obs_scale), torch.div(y, self.obs_scale), torch.div(torch.mul(v, torch.cos(yaw)), self.obs_scale), 73 | torch.div(torch.mul(v, torch.sin(yaw)), self.obs_scale), torch.div(yaw, self.obs_scale)] 74 | next_vehicle_state = torch.stack(tensor_list) 75 | # print("next_vehicle_state",next_vehicle_state) 76 | 77 | pred_state.append(next_vehicle_state) 78 | 79 | print("pred_state" ,pred_state) 80 | pred_state = torch.stack(pred_state) 81 | return pred_state 82 | 
-------------------------------------------------------------------------------- /rl/transition/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Thu-ADLab/SPIDER/a1a2f81a29da0c63f1a46ad04740d6031c78e885/rl/transition/__init__.py -------------------------------------------------------------------------------- /sampler/BaseSampler.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | class BaseSampler: 5 | def __init__(self): 6 | pass 7 | 8 | # class BaseSampler: 9 | # def __init__(self, steps, dt): 10 | # self.steps = steps 11 | # self.dt = dt 12 | # pass 13 | -------------------------------------------------------------------------------- /sampler/PathSampler.py: -------------------------------------------------------------------------------- 1 | ''' 2 | including some samplers that generates a 2D curve as path 3 | ''' 4 | import numpy as np 5 | from spider.sampler.BaseSampler import BaseSampler 6 | from spider.elements.curves import ParametricCurve, BezierCurve 7 | from spider.utils.transform import FrenetTransformer 8 | 9 | class BezierCurveSampler(BaseSampler): 10 | def __init__(self, end_s_candidates, end_l_candidates, num_control_points=5, max_transition_length=5.0): 11 | ''' 12 | end_dx_candidates: x一阶导的终值候选项 13 | ''' 14 | super(BezierCurveSampler, self).__init__() 15 | assert num_control_points == 5, "Only support 5 points sampling for now..." 
16 | self.end_s_candidates = end_s_candidates 17 | self.end_l_candidates = end_l_candidates 18 | self.num_control_points = num_control_points 19 | self.max_transition_length = max_transition_length #transition_length是为了保证初始yaw和末尾yaw而延申的长度 20 | 21 | def sample(self, start_x, start_y, start_yaw, frenet_transformer: FrenetTransformer): 22 | ''' 23 | referline应该是参数化曲线,输入s,输出x,y 24 | ''' 25 | 26 | start_s = frenet_transformer.cart2frenet(start_x, start_y,order=0).s 27 | 28 | samples = [] 29 | for delta_s in self.end_s_candidates: 30 | end_s = delta_s+start_s 31 | for end_l in self.end_l_candidates: 32 | end_x, end_y, end_yaw = self._calc_pose(end_s,end_l, frenet_transformer.refer_line_csp) 33 | control_points = self._calc_control_points(start_x, start_y, start_yaw, end_x, end_y, end_yaw) 34 | samples.append(BezierCurve(control_points)) 35 | 36 | return samples 37 | 38 | def _calc_control_points(self, start_x, start_y, start_yaw, end_x, end_y, end_yaw): 39 | ''' 40 | only support 5 points 41 | ''' 42 | dist = np.array((end_y-start_y)**2 + (end_x-start_x)**2) 43 | transition_length = min([dist/(self.num_control_points-1), self.max_transition_length]) 44 | control_points = np.array([ 45 | [start_x, start_y], 46 | [start_x + transition_length*np.cos(start_yaw), start_y + transition_length*np.sin(start_yaw)], 47 | 48 | [end_x - transition_length*np.cos(end_yaw), end_y - transition_length*np.sin(end_yaw)], 49 | [end_x, end_y] 50 | ]) 51 | control_points = np.insert(control_points, 2, (control_points[1]+control_points[-2])/2, axis=0) 52 | return control_points 53 | 54 | 55 | # def _control_points_sl(self, start_s, end_s, start_l, end_l): 56 | # ss = np.linspace(start_s, end_s, self.num_control_points) 57 | # ls = np.array([start_l, start_l, 0.5*(start_l+end_l), end_l, end_l]) 58 | # return ss, ls 59 | 60 | 61 | def _calc_pose(self, s, l, reference_line_curve:ParametricCurve): 62 | # todo: 用frenet transformer替代 63 | x, y = reference_line_curve(s) 64 | theta = 
reference_line_curve.calc_yaw(s) 65 | if np.all(l == 0): 66 | return x,y, theta 67 | else: 68 | x = x - l * np.sin(theta) 69 | y = y + l * np.cos(theta) 70 | return x, y, theta 71 | 72 | -------------------------------------------------------------------------------- /sampler/__init__.py: -------------------------------------------------------------------------------- 1 | from spider.sampler.PolynomialSampler import \ 2 | (QuinticPolyminalSampler, QuarticPolyminalSampler, PiecewiseQuinticPolyminalSampler) 3 | from spider.sampler.Combiner import PVDCombiner, LatLonCombiner 4 | from spider.sampler.PathSampler import BezierCurveSampler 5 | 6 | 7 | -------------------------------------------------------------------------------- /sampler/common.py: -------------------------------------------------------------------------------- 1 | import copy 2 | from types import FunctionType 3 | 4 | 5 | def closure(function, *args, **kwargs): 6 | ''' 7 | 创建闭包函数,避免参数被修改。 8 | ''' 9 | new_function = lambda: function(*args, **kwargs) 10 | return new_function 11 | 12 | 13 | class LazyList(list): 14 | ''' 15 | 惰性计算的list。 16 | self 本身存放的是函数,调用 self 时,执行 self 所包含的函数,并将结果缓存下来。 17 | 在后续调用时,直接返回缓存的结果。 18 | call by need 19 | 目前仅检查了索引、切片操作,能完成惰性计算。 20 | ''' 21 | def __init__(self, iterable_of_callable=None): 22 | if iterable_of_callable is None: 23 | super().__init__() 24 | else: 25 | super().__init__(iterable_of_callable) 26 | 27 | self._cache = {} 28 | 29 | def __iter__(self): # 返回的也是一个生成器 30 | return (self[i] for i in range(len(self))) 31 | 32 | def __getitem__(self, index): 33 | if isinstance(index, slice): # 切片。LazyList切片操作并不高效。 34 | return [self.__getitem__(i) for i in range(index.start, index.stop, index.step)] 35 | 36 | else: # 索引 37 | if index in self._cache: 38 | return self._cache[index] 39 | else: 40 | generate_func = super().__getitem__(index) 41 | assert callable(generate_func), \ 42 | 'LazyList is used to store callable. 
Please disable calc_by_need mode if you want to store the result.' 43 | self._cache[index] = result = generate_func() 44 | return result 45 | 46 | 47 | @staticmethod 48 | def wrap_generator(function, *args, **kwargs): 49 | return closure(function, *args, **kwargs) 50 | 51 | def to_instance_list(self): 52 | return list(self) 53 | 54 | 55 | if __name__ == '__main__': 56 | samples = LazyList() 57 | for x in [1, 2, 3]: 58 | for y in [4,5,6]: 59 | samples.append( # append一个callable函数 60 | closure(lambda x, y: x + y, x, y) 61 | ) 62 | 63 | for val in samples: 64 | print(val) 65 | print(samples._cache) 66 | 67 | print("--------------------------------") 68 | samples = LazyList() 69 | for x in [1, 2]: 70 | for y in [4, 5]: 71 | samples.append( # append一个callable函数 72 | closure(lambda x, y: x + y, x, y) 73 | ) 74 | a,b,c,d = samples 75 | print(a,b,c,d) 76 | print(samples._cache) 77 | 78 | 79 | 80 | 81 | -------------------------------------------------------------------------------- /teaser.py: -------------------------------------------------------------------------------- 1 | # import matplotlib.pyplot as plt 2 | # import numpy as np 3 | # 4 | # from spider.planner_zoo.LatticePlanner import LatticePlanner 5 | # 6 | # from spider.elements.map import RoutedLocalMap 7 | # from spider.elements.Box import TrackingBoxList, TrackingBox 8 | # from spider.elements.map import Lane 9 | # from spider.elements.vehicle import VehicleState, Transform, Location, Rotation,Vector3D 10 | # import spider.visualize as vis 11 | 12 | 13 | __all__ = ['demo', 'teaser'] 14 | 15 | def teaser(): 16 | return demo() # 执行spider.teaser(), 就会执行demo() 17 | 18 | def demo(): 19 | from spider.interface.BaseBenchmark import DummyBenchmark 20 | from spider.planner_zoo import LatticePlanner 21 | 22 | planner = LatticePlanner({ 23 | "steps": 20, 24 | "dt": 0.2, 25 | "end_s_candidates": (20, 40, 60), 26 | "end_l_candidates": (-3.5, 0, 3.5), 27 | }) 28 | 29 | benchmark = DummyBenchmark() 30 | 
benchmark.test(planner) 31 | 32 | 33 | 34 | if __name__ == '__main__': 35 | demo() 36 | 37 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Thu-ADLab/SPIDER/a1a2f81a29da0c63f1a46ad04740d6031c78e885/utils/__init__.py -------------------------------------------------------------------------------- /utils/collision/AABB.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from spider.elements.box import aabb2vertices, vertices2aabb 3 | 4 | def AABB_check(vertices1:np.ndarray ,vertices2:np.ndarray): 5 | xmin1, ymin1, xmax1, ymax1 = vertices2aabb(vertices1) 6 | xmin2, ymin2, xmax2, ymax2 = vertices2aabb(vertices2) 7 | length2, width2 = xmax2-xmin2, ymax2-ymin2 8 | expanded_aabb1 = [xmin1, ymin1, xmax1+length2, ymax1+width2] 9 | if expanded_aabb1[0] bool: 19 | pass 20 | 21 | @abstractmethod 22 | def check_trajectory(self, trajectory, observation) -> bool: 23 | pass 24 | 25 | 26 | 27 | class BoxCollisionChecker(BaseCollisionChecker): 28 | def __init__(self, ego_veh_length=5., ego_veh_width=2., 29 | method=spider.COLLISION_CHECKER_SAT, safe_dist=(0.0, 0.0)): 30 | super(BoxCollisionChecker, self).__init__(method) 31 | # self.method = method_flag 32 | self.ego_box_vertices = None 33 | 34 | self.bboxes_vertices = None 35 | # self.ogm = None 36 | 37 | self.ego_length = ego_veh_length #0. 
38 | self.ego_width = ego_veh_width 39 | 40 | self.safe_dist = safe_dist # 分别为纵向、横向的安全距离阈值 41 | 42 | # todo: 这里以后考虑一下如何和vertices统一 43 | self.ego_box_obb = None 44 | self.bboxes_obb = None 45 | 46 | def set_ego_veh_size(self,length,width): 47 | self.ego_length = length 48 | self.ego_width = width 49 | 50 | # setEgoVehicleBox 51 | def set_ego_box(self, ego_box_vertices=None): 52 | self.ego_box_vertices = ego_box_vertices 53 | if self.method == spider.COLLISION_CHECKER_DISK: # todo:以后去掉 54 | self.ego_box_obb = vertices2obb(ego_box_vertices) 55 | 56 | def set_obstacles(self, bboxes_vertices=None): 57 | self.bboxes_vertices = bboxes_vertices 58 | if self.method == spider.COLLISION_CHECKER_DISK: # todo:以后去掉 59 | self.bboxes_obb = [vertices2obb(vs) for vs in bboxes_vertices] 60 | 61 | def check(self, ego_box_vertices=None, bboxes_vertices=None, safe_dilate=False): 62 | if not (ego_box_vertices is None): 63 | self.set_ego_box(ego_box_vertices) 64 | if not (bboxes_vertices is None): 65 | self.set_obstacles(bboxes_vertices) 66 | 67 | if safe_dilate: 68 | self.set_ego_box(dilate(ego_box_vertices, 2 * self.safe_dist[0], 2 * self.safe_dist[1])) 69 | 70 | collision = False 71 | if self.method == spider.COLLISION_CHECKER_SAT: 72 | for bbox_vertices in self.bboxes_vertices: 73 | if SAT_check(self.ego_box_vertices, bbox_vertices): 74 | collision = True 75 | break 76 | elif self.method == spider.COLLISION_CHECKER_DISK: 77 | for obb in self.bboxes_obb: 78 | # todo: 这里的逻辑要改的很多,因为现在输入和预测只接收了顶点,但disk check需要Obb。来回转换很蠢 79 | if disk_check_for_box(self.ego_box_obb, obb): 80 | collision = True 81 | break 82 | elif self.method == spider.COLLISION_CHECKER_AABB: 83 | for bbox_vertices in self.bboxes_vertices: 84 | if AABB_check(self.ego_box_vertices, bbox_vertices): 85 | collision = True 86 | break 87 | else: 88 | raise ValueError("INVALID method for box collision checker") 89 | 90 | return collision # True就是撞了,False就是没撞 91 | 92 | def check_trajectory(self, traj:Trajectory, 
predicted_obstacles:TrackingBoxList, ego_length=0.0, ego_width=0.0): 93 | if ego_length and ego_width: 94 | self.set_ego_veh_size(ego_length,ego_width) 95 | 96 | for i in range(traj.steps): 97 | x, y, heading = traj.x[i], traj.y[i], traj.heading[i] 98 | ego_box_vertices = obb2vertices( 99 | [x, y, self.ego_length+2*self.safe_dist[0], self.ego_width+2*self.safe_dist[1], heading]) 100 | collision = self.check(ego_box_vertices, predicted_obstacles.get_vertices_at(i)) 101 | if collision: 102 | return True 103 | return False 104 | 105 | 106 | def check_state(self, ego_veh_state:VehicleState, obstacles:TrackingBoxList): 107 | ego_box_vertices = obb2vertices(ego_veh_state.obb) 108 | return self.check(ego_box_vertices, obstacles.get_vertices_at(0)) 109 | 110 | 111 | class GridCollisionChecker(BaseCollisionChecker): 112 | pass 113 | 114 | -------------------------------------------------------------------------------- /utils/collision/CollisionConstraints.py: -------------------------------------------------------------------------------- 1 | # 碰撞躲避约束的建模方法 2 | import spider.utils.collision.CollisionChecker 3 | from spider.elements.trajectory import Trajectory 4 | from spider.utils.collision.CollisionChecker import BoxCollisionChecker 5 | # from utils.collision.SAT import SAT_check 6 | import numpy as np 7 | from spider.elements.box import obb2vertices,aabb2vertices,TrackingBox,TrackingBoxList 8 | import spider 9 | 10 | 11 | 12 | 13 | 14 | def triangle_area(): 15 | # 三角面积法 16 | # TODO:补充这个方法 17 | pass 18 | 19 | 20 | def generate_corridor(initial_guess:Trajectory, observation, delta=0.2, max_expand=5.0): 21 | ''' 22 | 23 | :param initial_guess: 24 | :param observation: 25 | :param delta: 26 | :param max_expand: 27 | :return: corridor: [ [x1_min, y1_min, x1_max, y1_max],... 
] 28 | ''' 29 | if isinstance(observation,TrackingBoxList): 30 | corridor = generate_corridor_bboxes(initial_guess, observation, delta,max_expand) 31 | # elif isinstance(observation, ): # TODO:补充OGM 32 | else: 33 | raise ValueError("Invalid input") 34 | 35 | return corridor 36 | 37 | 38 | def generate_corridor_bboxes(initial_guess:Trajectory, bboxes:TrackingBoxList, 39 | delta=0.2, max_expand=5.0): 40 | # lon_offset为负表示在几何中心后方 41 | 42 | # bboxes.dilate(radius) 43 | # bboxes.predict(initial_guess.x) # TODO:QZL:是不是要把预测放到外面 44 | collision_checker = BoxCollisionChecker(method=spider.COLLISION_CHECKER_SAT) 45 | 46 | corridor = [] 47 | for i in range(len(initial_guess.t)): 48 | x, y, heading, t = initial_guess.x[i], initial_guess.y[i], initial_guess.heading[i], initial_guess.t[i] 49 | 50 | if t == 0: 51 | continue 52 | 53 | # collision_checker.set_ego_box(obb2vertices((x,y,ego_veh_size[0],ego_veh_size[1],heading))) 54 | collision_checker.set_obstacles(bboxes_vertices=bboxes.get_vertices_at(step=i)) 55 | 56 | seed = np.float32([x-0.1, y-0.1, x+0.1, y+0.1]) # 坍缩为一个小区域,四个方向发散以扩展 57 | sign = [-1, -1, 1, 1] 58 | space = seed.copy() 59 | StopIterFlag = [False, False, False, False] 60 | 61 | while not np.all(StopIterFlag): # 全部方向都停止迭代才break 62 | for j in range(4): # 每个方向尝试拓展 63 | if StopIterFlag[j]: 64 | continue 65 | 66 | temp_space = space.copy() 67 | temp_space[j] += sign[j] * delta 68 | 69 | collision_checker.set_ego_box(aabb2vertices(temp_space)) 70 | 71 | if np.abs(temp_space[j] - seed[j]) > max_expand or collision_checker.check(): 72 | # 超界 或者 碰撞 73 | # TODO:记得加上道路边界的碰撞 74 | StopIterFlag[j] = True 75 | continue 76 | space = temp_space 77 | corridor.append(space) 78 | return np.array(corridor) 79 | 80 | def generate_corridor_ogm(initial_guess:Trajectory, ogm, 81 | delta=0.2, max_expand=5.0): 82 | # TODO:补充OGM 83 | return [] 84 | 85 | 86 | 87 | 88 | 89 | if __name__ == '__main__': 90 | import matplotlib.pyplot as plt 91 | 92 | from utils.Visualize import * 93 | 94 | traj = 
Trajectory() 95 | traj.x = np.arange(11) 96 | traj.y = 0.05 *(traj.x-5) ** 3 97 | traj.t = np.array(list(range(11)))*0.1 98 | traj.heading = 45 * np.pi / 180 * np.ones_like(traj.t) 99 | 100 | bboxes = TrackingBoxList() 101 | vertices1 = np.array([ 102 | [4,-4], 103 | [6,-4], 104 | [6,-1.5], 105 | [4,-1.5] 106 | ]) 107 | bboxes.append(TrackingBox.from_vertices(vertices=vertices1, vx=0, vy=0)) 108 | 109 | vertices2 = np.array([ 110 | [-2, -7.5], 111 | [-2, -5], 112 | [-1, -5], 113 | [-1, -7.5] 114 | ]) 115 | bboxes.append(TrackingBox.from_vertices(vertices=vertices2, vx=4, vy=14)) 116 | 117 | corridor = generate_corridor(traj, bboxes) 118 | for i, rect in enumerate(corridor): 119 | x1, y1, x2, y2 = rect 120 | draw_rectangle(x1,y1,x2,y2) 121 | bboxes_vertices = bboxes.get_vertices_at(i) 122 | for vertice in bboxes_vertices: 123 | draw_polygon(vertice,color='red') 124 | plt.pause(0.01) 125 | 126 | draw_trajectory(traj) 127 | plt.show() 128 | -------------------------------------------------------------------------------- /utils/collision/SAT.py: -------------------------------------------------------------------------------- 1 | from spider.utils.vector import * 2 | import numpy as np 3 | from spider.utils.collision.AABB import AABB_check 4 | 5 | 6 | def SAT_check(vertices1:np.ndarray ,vertices2:np.ndarray): 7 | ''' 8 | Separating Axis Theorem,SAT,分离轴定理,用于检测凸多边形碰撞 9 | qzl:已经解决:两个矩形的话只需要检查4条边对应的轴(因为有平行),目前是默认8条全部检测。可以尝试用集合保证唯一性。后期可以改进。 10 | :param vertices1: 多边形1 11 | :param vertices2: 多边形2 12 | :return: 碰撞与否 13 | ''' 14 | if not AABB_check(vertices1,vertices2): 15 | # 先粗检 16 | return False 17 | 18 | separating_axis_vec = np.empty((0, 2)) 19 | # TODO:qzl:下面的内容写成矩阵计算形式更快 20 | # 获取所有边的单位向量并储存 21 | for polygon in [vertices1,vertices2]: 22 | for i in range(len(polygon)): 23 | v1 = polygon[i] 24 | if i == len(polygon)-1: 25 | v2 = polygon[0] 26 | else: 27 | v2 = polygon[i+1] 28 | vertical_vec = rotate90(v2-v1) # v2-v1是1-2的边对应的向量 29 | if vec_in(-vertical_vec, 
separating_axis_vec) or vec_in(vertical_vec, separating_axis_vec): 30 | continue # 如果已经存在方向相同的向量,则不用添加 todo:这里没有加长度上的判断,不过暂时无关紧要了 31 | # separating_axis_vec.append(vertical_vec) 32 | separating_axis_vec = np.append(separating_axis_vec, [vertical_vec], axis=0) 33 | 34 | # TODO: qzl:投影的时候其实也不用单位化法向量 35 | # 计算每个顶点向量在分离轴上的投影,并检查是否重叠 36 | collision = True 37 | for axis_vec in separating_axis_vec: 38 | proj1 = project(vertices1, axis_vec) 39 | min1, max1 = np.min(proj1), np.max(proj1) 40 | proj2 = project(vertices2, axis_vec) 41 | min2, max2 = np.min(proj2), np.max(proj2) 42 | if min1 > max2 or min2 > max1: # 说明没有重叠, 也就是存在分割线 43 | collision = False 44 | break 45 | 46 | return collision 47 | -------------------------------------------------------------------------------- /utils/collision/__init__.py: -------------------------------------------------------------------------------- 1 | from spider.utils.collision.CollisionChecker import BoxCollisionChecker 2 | from spider.utils.collision.CollisionConstraints import * 3 | -------------------------------------------------------------------------------- /utils/collision/disk_num_calculation.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Thu-ADLab/SPIDER/a1a2f81a29da0c63f1a46ad04740d6031c78e885/utils/collision/disk_num_calculation.jpg -------------------------------------------------------------------------------- /utils/collision/disks.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from typing import Type, Union, Sequence 3 | import spider.elements as elm 4 | 5 | _target_theta = 45 / 180 * np.pi 6 | _tan_target_theta = np.tan(_target_theta) 7 | 8 | #bounding_box: Union[Type[elm.BoundingBox], Sequence] 9 | def disks_approximate(bounding_box: Union[Type[elm.BoundingBox], Sequence],disk_num:int=-1): 10 | ''' 11 | 12 | :param length: 13 | :param width: 14 | :param N: 15 | :return: 
返回几个圆盘的圆心在heading方向上与车辆质心的距离 16 | ''' 17 | if isinstance(bounding_box, elm.BoundingBox): 18 | x,y, length, width, heading = bounding_box.obb 19 | else: 20 | x, y, length, width, heading = bounding_box 21 | 22 | if disk_num < 0: # 自动以45度theta角为目标计算disk_num 23 | disk_num = _tan_target_theta * length / width 24 | disk_num = int(np.ceil(disk_num)) 25 | 26 | radius = np.sqrt((width/2) ** 2 + (length/(2*disk_num)) ** 2) 27 | delta_length = length / disk_num 28 | 29 | center_offset = delta_length*np.arange(disk_num) + (delta_length/2. - length/2.) # 把yaw对齐x轴时,圆心的x 30 | centers = np.column_stack((center_offset, np.zeros_like(center_offset))) # 把yaw对齐x轴时,圆心的x,y 31 | centers = elm.vector.rotate(centers, (0.,0.), heading) 32 | 33 | centers[:, 0] += x 34 | centers[:, 1] += y 35 | 36 | return centers, radius 37 | 38 | def disk_check_for_box(obb1, obb2): 39 | ''' 40 | :param obb1: 多边形1 41 | :param obb2: 多边形2 42 | :return: 碰撞与否 43 | ''' 44 | centers1, radius1 = disks_approximate(obb1) 45 | centers2, radius2 = disks_approximate(obb2) 46 | 47 | # 使用广播计算所有点对之间的距离的平方 48 | all_dist_2 = np.sum((centers1[:, np.newaxis, :] - centers2)**2, axis=2) 49 | 50 | threshold = (radius1 + radius2) ** 2 51 | if np.any(all_dist_2 <= threshold): 52 | return True 53 | else: 54 | return False 55 | 56 | 57 | def disk_check(center1, radius1, center2, radius2): 58 | ''' 59 | 检查两个disk之间是否碰撞 60 | return True: 碰撞 61 | ''' 62 | distance = np.linalg.norm(center2-center1) 63 | if distance > radius1+radius2: 64 | return False 65 | else: 66 | return True 67 | 68 | 69 | if __name__ == '__main__': 70 | import matplotlib.pyplot as plt 71 | obb1 = [100, 200, 8, 2, 30/180*3.14] 72 | cs, r = disks_approximate(obb1) 73 | 74 | vs = elm.Box.obb2vertices(obb1) 75 | vs = np.vstack((vs,vs[0,:])) 76 | plt.plot(vs[:,0],vs[:,1]) 77 | 78 | for pt in cs: 79 | plt.plot(pt[0], pt[1], 'r.') 80 | temp = plt.Circle(pt, r, fill=False, color='r') 81 | plt.gca().add_artist(temp) 82 | 83 | obb2 = [105, 199, 5, 2, 150 / 180 * 3.14] 84 | 
cs, r = disks_approximate(obb2) 85 | 86 | vs = elm.Box.obb2vertices(obb2) 87 | vs = np.vstack((vs, vs[0, :])) 88 | plt.plot(vs[:, 0], vs[:, 1]) 89 | 90 | for pt in cs: 91 | plt.plot(pt[0], pt[1], 'r.') 92 | temp = plt.Circle(pt, r, fill=False, color='r') 93 | plt.gca().add_artist(temp) 94 | 95 | 96 | if disk_check_for_box(obb1, obb2): 97 | print("COLLIDE!") 98 | else: 99 | print("SAFE!") 100 | 101 | plt.gca().set_aspect("equal") 102 | plt.show() 103 | 104 | -------------------------------------------------------------------------------- /utils/collision/ray_cast.py: -------------------------------------------------------------------------------- 1 | def ray_cast_check(point, polygon): 2 | # Check if the point is inside the polygon using the ray-casting algorithm 3 | x, y = point 4 | inside = False 5 | n = len(polygon) 6 | p1x, p1y = polygon[0] 7 | for i in range(n + 1): 8 | p2x, p2y = polygon[i % n] 9 | if y > min(p1y, p2y): 10 | if y <= max(p1y, p2y): 11 | if x <= max(p1x, p2x): 12 | if p1y != p2y: 13 | xinters = (y - p1y) * (p2x - p1x) / (p2y - p1y) + p1x 14 | if p1x == p2x or x <= xinters: 15 | inside = not inside 16 | p1x, p1y = p2x, p2y 17 | return inside 18 | 19 | if __name__== "__main__": 20 | polygon = [(0, 0), (0, 5), (5, 5), (5, 0)] 21 | point = (2, 2) 22 | is_inside = ray_cast_check(point, polygon) 23 | print(is_inside) -------------------------------------------------------------------------------- /utils/lane_decision.py: -------------------------------------------------------------------------------- 1 | 2 | # lane_id = function(ego_state, perception, local_map) 3 | 4 | class ConstLaneDecision: 5 | def __init__(self, lane_id=1): 6 | self.lane_id = lane_id 7 | 8 | def __call__(self, *args, **kwargs): 9 | return self.decide(*args, **kwargs) 10 | 11 | def decide(self, ego_state, perception, local_map): 12 | idx = min([self.lane_id, len(local_map.lanes)]) 13 | idx = max([self.lane_id, 0]) 14 | return idx 15 | 16 | class NearestLaneDecision: 17 | def 
__call__(self, *args, **kwargs): 18 | return self.decide(*args, **kwargs) 19 | 20 | def decide(self, ego_state, perception, local_map): 21 | ego_lane_idx = local_map.match_lane(ego_state) 22 | return ego_lane_idx 23 | 24 | class UtilityLaneDecision: 25 | ''' 26 | todo: 未完成 27 | ''' 28 | def __call__(self, *args, **kwargs): 29 | return self.decide(*args, **kwargs) 30 | 31 | def decide(self, ego_state, perception, local_map): 32 | raise NotImplementedError 33 | -------------------------------------------------------------------------------- /utils/potential_field/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Thu-ADLab/SPIDER/a1a2f81a29da0c63f1a46ad04740d6031c78e885/utils/potential_field/__init__.py -------------------------------------------------------------------------------- /utils/potential_field/potential_field.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import matplotlib.pyplot as plt 3 | # from pyheatmap.heatmap import HeatMap 4 | # import matplotlib.cm as cm 5 | from matplotlib.colors import LogNorm 6 | 7 | from spider.utils.potential_field.static_risk import static_risk 8 | from spider.utils.potential_field.velocity_oriented_risk import vel_oriented_risk 9 | 10 | 11 | class PotentialField: 12 | def __init__(self, obs_xs, obs_ys, obs_vxs, obs_vys, obs_classes, ego_radius, 13 | risk_type="velocity_oriented"): 14 | self.obs_info = list(zip(obs_xs, obs_ys, obs_vxs, obs_vys, obs_classes)) 15 | # self.k = 1 # 常数 16 | # self.a_b = 4 # 长短轴比例 17 | self.ego_radius = ego_radius 18 | # self.max_risk = 2 19 | 20 | self.risk_type = risk_type 21 | # self.build_grid() 22 | 23 | def build_grid(self, x_range, y_range, dx=0.5, dy=0.5, visualize=False): 24 | xs = np.arange(x_range[0], x_range[1], dx) 25 | ys = np.arange(y_range[0], y_range[1], dy) 26 | xxs, yys = np.meshgrid(xs, ys) 27 | xxs = xxs 28 | yys = yys 29 | risk = 
self.calc_risk_npts(xxs, yys) 30 | # risk[idx] = nocollision 31 | # risk[np.bitwise_not(idx)] = np.max(nocollision)*1 32 | # risk[risk>self.max_risk] = self.max_risk 33 | risk_grid_map = np.reshape(risk[::-1],(ys.shape[0],xs.shape[0])) 34 | if visualize: 35 | plt.imshow(risk_grid_map, extent=(np.amin(xs), np.amax(xs), np.amin(ys), np.amax(ys)), 36 | cmap="Reds", norm=LogNorm()) 37 | plt.colorbar() 38 | return risk_grid_map 39 | 40 | def show_grid(self, risk_grid_map, extent=None): 41 | plt.imshow(risk_grid_map, extent=extent, cmap="Reds", norm=LogNorm()) 42 | plt.colorbar() 43 | 44 | def calc_risk_npts(self, xs, ys): 45 | 46 | # 用于计算一堆点的risk,返回一个数组 47 | risk_sum = np.zeros_like(xs,dtype=np.float32) 48 | 49 | for obs_x, obs_y, obs_vx, obs_vy, obs_class in self.obs_info: 50 | if self.risk_type == "static": 51 | risk= static_risk(xs,ys,obs_x,obs_y,obs_class,self.ego_radius) 52 | elif self.risk_type == "velocity_oriented": 53 | risk = vel_oriented_risk(xs,ys,obs_x,obs_y,obs_vx,obs_vy,obs_class,self.ego_radius) 54 | else: 55 | raise ValueError("Invalid risk type") 56 | risk_sum += risk 57 | return risk_sum 58 | 59 | 60 | if __name__ == '__main__': 61 | pfield= PotentialField([0,10,40],[0,20,-10],[8,0,10],[4,-4,0],[0,0,0],1, )#"static" 62 | pfield.build_grid([-50,100],[-50,50],0.2,0.2, True) 63 | plt.show() 64 | 65 | -------------------------------------------------------------------------------- /utils/potential_field/static_risk.py: -------------------------------------------------------------------------------- 1 | 2 | import numpy as np 3 | 4 | 5 | G, K_risk = 9.0, 0.3 6 | Ms = {i:1 for i in range(50)} # qzl: a ratio determined by the class of boundary points 7 | max_risk = 1.0 8 | _epsilon = 0.001 9 | 10 | ''' 11 | Euclidean Signed Distance Field (ESDF) 12 | ''' 13 | def static_risk(target_xs, target_ys, obs_x, obs_y, obs_class=0, ego_radius=1.0): 14 | ''' 15 | 计算静态风险 16 | ''' 17 | dx = target_xs - obs_x 18 | dy = target_ys - obs_y 19 | dist = np.sqrt(dx ** 2 + dy 
** 2) 20 | 21 | rr = (dist - ego_radius)**2 22 | rr[rr < _epsilon] = _epsilon # avoid division by zero 23 | 24 | M = Ms.get(int(obs_class), 1.0) 25 | risk = G * M / rr 26 | 27 | risk[dist <= ego_radius] = max_risk 28 | risk[risk > max_risk] = max_risk 29 | return risk 30 | -------------------------------------------------------------------------------- /utils/potential_field/velocity_oriented_risk.py: -------------------------------------------------------------------------------- 1 | 2 | import numpy as np 3 | 4 | 5 | G, K_risk = 9, 0.3 6 | Ms = {i:1 for i in range(50)} # qzl: a ratio determined by the class of boundary points 7 | max_risk = 1.0 8 | _epsilon = 0.001 9 | 10 | def vel_oriented_risk(target_xs, target_ys, obs_x, obs_y, obs_vx, obs_vy, obs_class=0, ego_radius=1.0): 11 | dx = target_xs - obs_x 12 | dy = target_ys - obs_y 13 | dist = np.sqrt(dx ** 2 + dy ** 2) 14 | 15 | # if np.any(dist <= ego_radius): 16 | # return max_risk # *10 collision!! 17 | 18 | rr = (dist - ego_radius) ** 2 # dist < ego radius的在后面会被置为max risk 19 | # rr = dist ** 2 20 | rr[rr < _epsilon] = _epsilon # avoid division by zero 21 | 22 | v = np.sqrt(obs_vx ** 2 + obs_vy ** 2) 23 | vtheta = np.arctan2(obs_vy, obs_vx) # if obvx != 0 else 0.5*np.pi*obvy/abs(obvy) 24 | rtheta = np.arctan2(dy, dx) 25 | theta = vtheta - rtheta 26 | M = Ms.get(int(obs_class), 1.0) 27 | risk = np.exp(K_risk * v * np.cos(theta)) * G * M / rr 28 | # risk = np.exp(K_risk * v * np.cos(theta) * 0.5 * (1 + np.cos(2 * theta))) * G * M[int(class_)]/ rr 29 | 30 | risk[dist<= ego_radius] = max_risk 31 | 32 | risk[risk > max_risk] = max_risk 33 | # risk = np.exp(k2*v*np.cos(theta)*(1 - np.abs(theta)//(2*np.pi))) * G / rr 34 | return risk 35 | 36 | -------------------------------------------------------------------------------- /utils/predict/__init__.py: -------------------------------------------------------------------------------- 1 | from spider.utils.predict.linear import vertices_linear_predict 2 | from 
spider.utils.predict.common import BasePrediction, BasePredictor 3 | -------------------------------------------------------------------------------- /utils/predict/common.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | class BasePrediction: 4 | def __init__(self): 5 | self.pred_traj = None 6 | self.pred_vertices = None 7 | 8 | 9 | class BasePredictor: 10 | def __init__(self): 11 | pass 12 | 13 | def predict(self, *args, **kwargs): 14 | pass 15 | 16 | def predict_box(self): 17 | pass 18 | 19 | def predict_vertices(self): 20 | pass 21 | 22 | def predict_occupancy(self): 23 | pass 24 | -------------------------------------------------------------------------------- /utils/predict/linear.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from spider.utils.predict.common import BasePredictor 3 | 4 | 5 | def vertices_linear_predict(vertices:np.ndarray, vx, vy, ts): 6 | pred_vertices = [] 7 | for t in ts: 8 | dx, dy = vx * t, vy * t 9 | pred_vertice = vertices.copy() 10 | pred_vertice[:,0] += dx 11 | pred_vertice[:,1] += dy 12 | pred_vertices.append(pred_vertice) 13 | return pred_vertices 14 | 15 | 16 | # 目前,将predict结果暂时统一为x,y,theta的序列,以后可能再改 17 | class LinearPredictor(BasePredictor): 18 | def __init__(self): 19 | super(LinearPredictor, self).__init__() 20 | pass 21 | 22 | def predict(self, trackingbox_list, ts): 23 | if len(trackingbox_list) == 0: 24 | return np.empty((0,3),dtype=float) 25 | for tb in trackingbox_list: 26 | vx, vy, yaw = tb.vx, tb.vy, tb.box_heading 27 | xs, ys = tb.x + np.asarray(ts) * tb.vx, tb.y + np.asarray(ts) * tb.vy 28 | yaws = np.ones_like(xs) * yaw 29 | return np.column_stack((xs, ys, yaws)) 30 | 31 | def predict_box(self): 32 | pass 33 | 34 | def predict_vertices(self): 35 | pass 36 | 37 | def predict_occupancy(self): 38 | pass 39 | -------------------------------------------------------------------------------- 
/utils/transform/__init__.py: -------------------------------------------------------------------------------- 1 | from spider.utils.transform.frenet import FrenetTransformer 2 | -------------------------------------------------------------------------------- /utils/transform/gps.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Thu-ADLab/SPIDER/a1a2f81a29da0c63f1a46ad04740d6031c78e885/utils/transform/gps.py -------------------------------------------------------------------------------- /utils/transform/grid.py: -------------------------------------------------------------------------------- 1 | # Conversion between world (Cartesian) coordinates and BEV occupancy-grid cell indices around the ego vehicle. 2 | import sys 3 | import math 4 | import numpy as np 5 | from typing import List 6 | 7 | from spider.utils.transform.relative import RelativeTransformer 8 | 9 | 10 | class GridTransformer: 11 | def __init__(self, longitudinal_range, lateral_range, grid_resolution, 12 | ego_x=None, ego_y=None, ego_yaw=None, ego_vx=None, ego_vy=None): 13 | ''' 14 | longitudinal_range: [longitudinal_range_front,longitudinal_range_back] 15 | lateral_range: [lateral_range_left, lateral_range_right] 16 | grid_resolution: [longitudinal_resolution, lateral_resolution] 17 | ''' 18 | # ego_* are forwarded to the internal RelativeTransformer used for world <-> ego-relative conversion 19 | self.longitudinal_range = longitudinal_range 20 | self.lateral_range = lateral_range 21 | self.grid_resolution = grid_resolution 22 | self.lon_resolution, self.lat_resolution = grid_resolution 23 | 24 | self.width = int(math.ceil(sum(longitudinal_range) / self.lon_resolution)) # cells along the longitudinal axis (front + back range) 25 | self.height = int(math.ceil(sum(lateral_range) / self.lat_resolution)) # cells along the lateral axis (left + right range) 26 | 27 | self.ego_grid_pos = self._calc_ego_grid_pos() # (int, int): ego cell indices (x_grid, y_grid)
28 | 29 | self.rel_tf = RelativeTransformer(ego_x, ego_y, ego_yaw, ego_vx, ego_vy) 30 | 31 | def _calc_ego_grid_pos(self): 32 | ''' 33 | 计算自车在栅格坐标系下的位置 34 | ''' 35 | ego_grid_x = int(self.longitudinal_range[1] / self.grid_resolution[0]) # rear range 36 | ego_grid_y = int(self.lateral_range[0] / self.grid_resolution[1]) # left range 37 | return ego_grid_x, ego_grid_y 38 | 39 | def set_ego_pose(self, ego_x, ego_y, ego_yaw): 40 | self.rel_tf.set_ego_pose(ego_x, ego_y, ego_yaw) 41 | 42 | def set_ego_velocity(self, ego_vx, ego_vy): 43 | self.rel_tf.set_ego_velocity(ego_vx, ego_vy) 44 | 45 | def cart2grid(self, x, y, yaw=None, vx=None, vy=None, ego_pose=None, ego_velocity=None): 46 | 47 | # 先把坐标转为车辆的相对坐标 48 | rel_x, rel_y, rel_yaw, rel_vx, rel_vy = self.rel_tf.abs2rel(x, y, yaw, vx, vy, ego_pose, ego_velocity) 49 | 50 | # 然后缩放和取整来使得相对坐标转为栅格坐标 51 | x_grid = int(rel_x / self.grid_resolution[0] + self.ego_grid_pos[0]) 52 | y_grid = int(-rel_y / self.grid_resolution[1] + self.ego_grid_pos[1]) # *-1 because the y-axis in grid is reversed 53 | 54 | return x_grid, y_grid, rel_yaw, rel_vx, rel_vy 55 | 56 | def grid2cart(self, x_grid, y_grid, rel_yaw=None, rel_vx=None, rel_vy=None, ego_pose=None, ego_velocity=None): 57 | rel_x = (x_grid - self.ego_grid_pos[0]) * self.grid_resolution[0] 58 | rel_y = -(y_grid - self.ego_grid_pos[1]) * self.grid_resolution[1]# *-1 because the y-axis in grid is reversed 59 | 60 | x, y, yaw, vx, vy = self.rel_tf.rel2abs(rel_x, rel_y, rel_yaw, rel_vx, rel_vy, ego_pose, ego_velocity) 61 | 62 | return x, y, yaw, vx, vy 63 | 64 | def cart2grid4boxes(self): 65 | pass 66 | 67 | 68 | 69 | 70 | if __name__ == '__main__': 71 | from spider.elements.grid import OccupancyGrid2D 72 | import cv2 73 | occ_grid = OccupancyGrid2D([50, 30], [30, 30], [0.1, 0.1], 1) 74 | 75 | grid_tf = GridTransformer([50, 30], [30, 30], [0.1, 0.1]) 76 | grid_tf.set_ego_pose(50, 15, 0.) 
77 | # Demo: map a world point to a grid cell, draw it on the grid, then map the cell back to world coordinates. 78 | gx, gy = grid_tf.cart2grid(55, 0)[:2] 79 | # occ_grid.grid[0, x, y] = 1 80 | cv2.circle(occ_grid.grid[0], (gx, gy), 10, 1, -1) 81 | print(gx, gy) 82 | 83 | x,y = grid_tf.grid2cart(gx, gy)[:2] 84 | print(x, y) 85 | 86 | occ_grid.show(0, 0) 87 | 88 | 89 | # Legacy draft of GridTransformer, kept commented out for reference: 90 | # class GridTransformer: 91 | # ''' 92 | # Converts between the world frame (local Cartesian coordinates) and the BEV grid frame. 93 | # Only 2D is supported for now. 94 | # ''' 95 | # def __init__(self): 96 | # pass 97 | # 98 | # # def set_param(self, grid: OccupancyGrid2D): 99 | # # self.grid_resolution = 100 | # # self.ego_x_grid = 101 | # # self.ego_y_grid = 102 | # 103 | # 104 | # def cart2grid(self, grid: OccupancyGrid2D, x_cart, y_cart, vx_cart=0, vy_cart=0): 105 | # # Note: grid x, y are image coordinates -- x increases left-to-right across the width and y increases top-to-bottom down the height (the height direction is especially easy to mix up). 106 | # # In the grid frame the ego vehicle stays fixed at one cell near the centre (determined by the grid object's lon_range, lat_range and grid_resolution), with its heading always pointing up in the image. 107 | # # The C++ code has a calc_ogm_idx-style function that can serve as a reference. 108 | # 109 | # if vx_cart == 0 and vy_cart == 0: 110 | # vx_grid, vy_grid = 0, 0 111 | # # zeroth-order (position-only) transform... 112 | # else: 113 | # # zeroth- and first-order (position and velocity) transform... 114 | # pass 115 | # 116 | # return x_grid, y_grid, vx_grid, vy_grid 117 | # 118 | # 119 | # def grid2cart(self, grid: OccupancyGrid2D, x_grid, y_grid, vx_grid=0, vy_grid=0): 120 | # 121 | # if vx_grid == 0 and vy_grid == 0: 122 | # vx_cart, vy_cart = 0, 0 123 | # # zeroth-order (position-only) transform...
124 | # else: 125 | # # 0阶和1阶坐标变换。。。 126 | # pass 127 | # 128 | # return x_cart, y_cart, vx_cart, vy_cart 129 | 130 | -------------------------------------------------------------------------------- /utils/transform/polar.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | def cart2polar(x, y): 4 | r = np.sqrt(x**2 + y**2) 5 | theta = np.atan2(y, x) 6 | return r, theta 7 | 8 | def polar2cart(r, theta): 9 | x = r * np.cos(theta) 10 | y = r * np.sin(theta) 11 | return x, y 12 | 13 | 14 | 15 | class PolarTransformer: 16 | def __init__(self): 17 | pass 18 | 19 | def cart2polar(self, x, y, vx=None, vy=None, ax=None, ay=None, *, order=0): 20 | r = np.sqrt(x**2 + y**2) 21 | theta = np.arctan2(y, x) 22 | return r, theta 23 | 24 | def polar2cart(self, r, theta): 25 | x = r * np.cos(theta) 26 | y = r * np.sin(theta) 27 | return x, y 28 | -------------------------------------------------------------------------------- /utils/transform/reference.zip: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Thu-ADLab/SPIDER/a1a2f81a29da0c63f1a46ad04740d6031c78e885/utils/transform/reference.zip -------------------------------------------------------------------------------- /utils/transform/relative.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from spider.utils.vector import rotate 4 | 5 | 6 | class RelativeTransformer: 7 | def __init__(self, ego_x=None, ego_y=None, ego_yaw=None, ego_vx=None, ego_vy=None): 8 | self.ego_x = 0. 9 | self.ego_y = 0. 10 | self.ego_yaw = 0. 11 | self.ego_vx = 0. 12 | self.ego_vy = 0. 
13 | 14 | self.set_ego_pose(ego_x, ego_y, ego_yaw) 15 | self.set_ego_velocity(ego_vx, ego_vy) 16 | 17 | @property 18 | def ego_pose(self): 19 | return self.ego_x, self.ego_y, self.ego_yaw 20 | 21 | @property 22 | def ego_velocity(self): 23 | return self.ego_vx, self.ego_vy 24 | 25 | def set_ego_pose(self, ego_x, ego_y, ego_yaw): 26 | self.ego_x = ego_x 27 | self.ego_y = ego_y 28 | self.ego_yaw = ego_yaw 29 | 30 | def set_ego_velocity(self, ego_vx, ego_vy): 31 | self.ego_vx = ego_vx 32 | self.ego_vy = ego_vy 33 | 34 | def abs2rel(self, x, y, yaw=None, vx=None, vy=None, ego_pose=None, ego_velocity=None): 35 | ''' 36 | ego_pose: (ego_x, ego_y, ego_yaw) 37 | ''' 38 | if ego_pose is not None: 39 | self.set_ego_pose(*ego_pose) 40 | if ego_velocity is not None: 41 | self.set_ego_velocity(*ego_velocity) 42 | 43 | delta_vec = np.array([x - self.ego_x, y - self.ego_y]).T 44 | rel_x, rel_y = rotate(delta_vec, (0, 0), -self.ego_yaw).T 45 | 46 | rel_yaw = yaw - self.ego_yaw if yaw is not None else None 47 | 48 | if vx is None or vy is None: 49 | rel_vx = rel_vy = None 50 | else: 51 | delta_vel_vec = np.array([vx - self.ego_vx, vy - self.ego_vy]).T 52 | rel_vx, rel_vy = rotate(delta_vel_vec, (0, 0), -self.ego_yaw).T 53 | 54 | return rel_x, rel_y, rel_yaw, rel_vx, rel_vy 55 | 56 | 57 | def rel2abs(self, rel_x, rel_y, rel_yaw=None, rel_vx=None, rel_vy=None, ego_pose=None, ego_velocity=None): 58 | ''' 59 | ego_pose: (ego_x, ego_y, ego_yaw) 60 | ''' 61 | if ego_pose is not None: 62 | self.set_ego_pose(*ego_pose) 63 | if ego_velocity is not None: 64 | self.set_ego_velocity(*ego_velocity) 65 | 66 | delta_vec = rotate(np.array([rel_x, rel_y]), (0, 0), self.ego_yaw).T 67 | abs_x = self.ego_x + delta_vec[..., 0] 68 | abs_y = self.ego_y + delta_vec[..., 1] 69 | 70 | abs_yaw = self.ego_yaw + rel_yaw if rel_yaw is not None else None 71 | 72 | if rel_vx is None or rel_vy is None: 73 | abs_vx = abs_vy = None 74 | else: 75 | delta_vel_vec = rotate(np.array([rel_vx, rel_vy]), (0, 0), 
self.ego_yaw).T 76 | abs_vx = self.ego_vx + delta_vel_vec[...,0] 77 | abs_vy = self.ego_vy + delta_vel_vec[...,1] 78 | 79 | return abs_x, abs_y, abs_yaw, abs_vx, abs_vy 80 | 81 | if __name__ == '__main__': 82 | tf = RelativeTransformer() 83 | tf.abs2rel(2, 2, 3.14 / 3, 1, 1, (1, 1, 3.14 / 6), (1, 0)) 84 | print(tf) 85 | -------------------------------------------------------------------------------- /utils/vector.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | # vector定义成形状为(n,)的一维ndarray,matrix定义为形状为(m,n)的二维ndarray 4 | 5 | def vec_in(vector, matrix): 6 | ''' 7 | 判断 一个向量 是否在 一堆向量里面 8 | ''' 9 | return (matrix == vector).all(1).any() 10 | 11 | 12 | def find_vec(vector, matrix, find_all=False): 13 | ''' 14 | 在一堆向量里, 找到一个某个向量的索引, 并返回第一个 15 | ''' 16 | idxs = np.where((matrix == vector).all(1))#[0] 17 | if find_all: 18 | return idxs 19 | else: 20 | if len(idxs) == 0: 21 | return None 22 | else: 23 | return idxs[0] 24 | 25 | def normalize(vector: np.ndarray): 26 | ''' 27 | 求一个矢量方向的单位矢量 28 | :param vector: 任意维度向量 29 | :return: 单位向量 30 | ''' 31 | # vector = np.array(vector) 32 | if modulus(vector) == 0: 33 | return np.zeros_like(vector) 34 | vector_norm = vector / modulus(vector) 35 | return vector_norm 36 | 37 | def modulus(vector): 38 | return np.linalg.norm(vector) 39 | 40 | def project(vector, segment_vector, calc_distance:bool=False): 41 | ''' 42 | project vector or matrix to axis_vector 43 | :param array: vector or matrix 44 | :param segment_vector: 45 | :return: projection 46 | ''' 47 | # 正值表示在目标向量的左侧(从目标向量到vector逆时针),负值表示在参考线的右侧(从目标向量到vector顺时针)。 48 | # 这种约定符合参考线在Frenet坐标系中的定义 49 | # 即:一般认为沿着参考线s增加方向的左边为正,右边为负 50 | segment_unit_vector = normalize(segment_vector) 51 | projection = np.dot(vector, normalize(segment_vector)) # projection length 52 | 53 | if calc_distance: 54 | distance_signed = np.cross(segment_unit_vector, vector) # 叉乘 55 | return projection, float(distance_signed) 56 | else: 57 | 
return projection 58 | 59 | 60 | def rotate90(vec:np.ndarray): 61 | ''' 62 | 逆时针旋转90度 63 | ''' 64 | return np.flip(vec)*[1,-1] # 离谱bug,1写成0了 65 | 66 | def rotate(array, anchor, angle, clockwise=False): 67 | ''' 68 | 69 | :param array: vector or 2-column matrix 70 | :param angle: rad required 71 | :param clockwise: default False 72 | :return: 73 | ''' 74 | delta_array = np.asarray(array).copy() 75 | delta_array[..., 0] -= anchor[0] 76 | delta_array[..., 1] -= anchor[1] 77 | if clockwise: 78 | angle = -angle 79 | transposed_rot = np.array([ 80 | [np.cos(angle), np.sin(angle)], 81 | [-np.sin(angle), np.cos(angle)] 82 | ]) 83 | delta_array = delta_array @ transposed_rot 84 | delta_array[..., 0] += anchor[0] 85 | delta_array[..., 1] += anchor[1] 86 | return delta_array 87 | 88 | # if __name__ == '__main__': 89 | # temp = Vector3D() 90 | # pass 91 | 92 | 93 | -------------------------------------------------------------------------------- /visualize/__init__.py: -------------------------------------------------------------------------------- 1 | from spider.visualize.common import * 2 | from spider.visualize.line import * 3 | from spider.visualize.point import * 4 | from spider.visualize.surface import * 5 | from spider.visualize.surface3d import * 6 | from spider.visualize.elements import * 7 | 8 | import matplotlib.pyplot as plt 9 | 10 | def __getattr__(name): 11 | return getattr(plt, name) 12 | # if name == 'show': 13 | # return plt.show 14 | # elif name == 'savefig': 15 | # return plt.savefig 16 | # elif name == 'close': 17 | # return plt.close 18 | # else: 19 | # import warnings 20 | # warnings.warn("{} is actually from matplotlib.pyplot. 
Please try to import plt directly.".format(name)) 21 | # return getattr(plt, name) 22 | 23 | -------------------------------------------------------------------------------- /visualize/dashboard/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Thu-ADLab/SPIDER/a1a2f81a29da0c63f1a46ad04740d6031c78e885/visualize/dashboard/__init__.py -------------------------------------------------------------------------------- /visualize/line.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | import numpy as np 3 | from typing import Union 4 | 5 | import spider 6 | # from spider.elements.trajectory import Trajectory, FrenetTrajectory 7 | from spider.visualize.surface import draw_obb, draw_polygon 8 | from spider.utils.geometry import generate_parallel_line 9 | 10 | # Plot an (N, 2) polyline; with show_buffer=True also shade, in the line's colour, the band between its parallel offsets at buffer_dist on each side (built with generate_parallel_line). Returns the list returned by plt.plot. 11 | def draw_polyline(polyline:np.ndarray, *args, show_buffer=False, buffer_dist=1.0, buffer_alpha=0.2, **kwargs): 12 | polyline = np.asarray(polyline) 13 | lines = plt.plot(polyline[:,0], polyline[:,1], *args, **kwargs) 14 | 15 | if show_buffer: 16 | color = lines[0].get_color() # reuse the plotted line's colour for the buffer band 17 | left_bound = generate_parallel_line(polyline, buffer_dist, left_or_right=spider.DIRECTION_LEFT) 18 | right_bound = generate_parallel_line(polyline, buffer_dist, left_or_right=spider.DIRECTION_RIGHT) 19 | polygon_vertices = np.concatenate((left_bound, np.flip(right_bound, axis=0))) # left bound followed by the reversed right bound -> one closed outline 20 | draw_polygon(polygon_vertices, fill=True, color=color, alpha=buffer_alpha) 21 | return lines 22 | 23 | 24 | # def draw_trajectory(traj: Union[Trajectory, FrenetTrajectory], *args, 25 | # show_footprint=False, footprint_size=(5., 2.), footprint_fill=True, footprint_alpha=0.1, **kwargs): 26 | # lines = plt.plot(traj.x, traj.y, *args, **kwargs) 27 | # 28 | # if show_footprint: 29 | # length, width = footprint_size 30 | # color = lines[0].get_color() 31 | # footprint_alpha = footprint_alpha if footprint_fill
else 0.8 # use the configured alpha when filled, otherwise default to 0.8 32 | # 33 | # for x, y, yaw in zip(traj.x, traj.y, traj.heading): 34 | # draw_obb((x, y, length, width, yaw), fill=footprint_fill, alpha=footprint_alpha, color=color) 35 | # 36 | # return lines 37 | 38 | 39 | if __name__ == '__main__': 40 | pass 41 | 42 | 43 | -------------------------------------------------------------------------------- /visualize/point.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Thu-ADLab/SPIDER/a1a2f81a29da0c63f1a46ad04740d6031c78e885/visualize/point.py -------------------------------------------------------------------------------- /visualize/surface.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | from matplotlib.colors import to_rgba 3 | import numpy as np 4 | from spider.elements.box import obb2vertices 5 | 6 | # draw_polygon: build a matplotlib Polygon from the vertices, add it to the current axes and return it. With fill=True, the kwargs 'alpha' (default 1.0) and 'edge_alpha' (default 0.8) set the face and edge alpha separately. 7 | def draw_polygon(vertices, *args, fill=False, **kwargs): 8 | # vertices = np.vstack((vertices, vertices[0])) # recurrent to close polyline 9 | # plt.plot(vertices[:, 0], vertices[:, 1], *args, **kwargs) 10 | 11 | if fill: 12 | # NOTE: these kwarg keys are not well designed; consider reworking them. 13 | face_alpha = kwargs['alpha'] if 'alpha' in kwargs else 1.0 14 | edge_alpha = kwargs['edge_alpha'] if 'edge_alpha' in kwargs else 0.8 15 | if 'alpha' in kwargs: kwargs.pop('alpha') 16 | if 'edge_alpha' in kwargs: kwargs.pop('edge_alpha') 17 | # if fill: 18 | # if 'color' in kwargs and 'alpha' in kwargs: 19 | # kwargs['edgecolor'] = to_rgba(kwargs['color'], 0.8) 20 | # kwargs['facecolor'] = to_rgba(kwargs['color'], kwargs['alpha']) 21 | # del kwargs['color'], kwargs['alpha'] 22 | # NOTE(review): with fill=False an 'edge_alpha' kwarg is not removed and is passed straight to plt.Polygon -- confirm that is never done 23 | 24 | polygon = plt.Polygon(vertices, *args, fill=fill, **kwargs) 25 | 26 | if fill: 27 | polygon.set_facecolor(to_rgba(polygon.get_facecolor(), face_alpha)) 28 | if 'color' in kwargs or 'edgecolor' in kwargs: # edge colour was given explicitly: only adjust its alpha 29 | polygon.set_edgecolor(to_rgba(polygon.get_edgecolor(), edge_alpha)) 30 | else: #
如果边框没有被设置过颜色,就把颜色改成facecolor 31 | polygon.set_edgecolor(to_rgba(polygon.get_facecolor(), edge_alpha)) 32 | 33 | plt.gca().add_patch(polygon) 34 | #closed=True, edgecolor='black', color='green' 35 | 36 | return polygon 37 | 38 | 39 | def draw_obb(obb, *args, fill=False, **kwargs): 40 | vertices = obb2vertices(obb) 41 | return draw_polygon(vertices, *args, fill=fill, **kwargs) 42 | 43 | 44 | def draw_aabb(aabb_rect, *args, fill=False, **kwargs): 45 | xmin, ymin, xmax, ymax = aabb_rect 46 | vertices = np.array([ 47 | [xmin, ymin], 48 | [xmin, ymax], 49 | [xmax, ymax], 50 | [xmax, ymin], 51 | # [xmin, ymin] 52 | ]) 53 | # return plt.plot(vertices[:, 0], vertices[:, 1], *args, **kwargs) 54 | return draw_polygon(vertices, *args, fill=fill, **kwargs) 55 | 56 | def draw_circle(center, radius, *args, mark_center=False, fill=False, alpha=1.0, **kwargs): 57 | # fill=False, color='r' 58 | circle = plt.Circle(center, radius, *args, fill=fill,alpha=alpha, **kwargs) 59 | plt.gca().add_artist(circle) 60 | 61 | if mark_center: 62 | marker_color = 'black' if fill else circle.get_edgecolor() 63 | plt.plot(center[0], center[1],'.', color=marker_color) 64 | 65 | 66 | 67 | 68 | if __name__ == '__main__': 69 | pass 70 | -------------------------------------------------------------------------------- /visualize/surface3d.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | import numpy as np 3 | import matplotlib.patches as patches 4 | import mpl_toolkits.mplot3d.art3d as art3d 5 | 6 | from spider.visualize.surface import draw_polygon 7 | 8 | 9 | def draw_prism(bottom_vertices, bottom_z, top_vertices, top_z, color=(0.8, 0.4, 0.6), alpha=0.4): 10 | ''' 11 | 画棱柱 12 | ''' 13 | 14 | ax = plt.gca() 15 | 16 | # 定义顶部和底部顶点坐标 17 | top_vertices = np.array(top_vertices) 18 | bottom_vertices = np.array(bottom_vertices) 19 | 20 | # 绘制底部的面 21 | bottom = draw_polygon(bottom_vertices, fill=True, facecolor=color, alpha=alpha) 22 | 
# (draw_polygon has already added this patch to the axes; adding it again would draw the face twice) 23 | art3d.pathpatch_2d_to_3d(bottom, z=bottom_z) # lift the 2-D patch to z = bottom_z 24 | 25 | 26 | vertices = np.vstack((bottom_vertices, bottom_vertices[0])) # repeat the first vertex to close the outline 27 | plt.plot(vertices[:, 0], vertices[:, 1], bottom_z, color=color, linestyle='-') 28 | 29 | # Draw the top face 30 | # x_top = top_vertices[:, 0] 31 | # y_top = top_vertices[:, 1] 32 | top = draw_polygon(top_vertices, fill=True, facecolor=color, alpha=alpha) 33 | # (draw_polygon has already added this patch to the axes) 34 | art3d.pathpatch_2d_to_3d(top, z=top_z) # lift the 2-D patch to z = top_z 35 | # 36 | vertices = np.vstack((top_vertices, top_vertices[0])) # repeat the first vertex to close the outline 37 | plt.plot(vertices[:, 0], vertices[:, 1], top_z, color=color, linestyle='-') 38 | 39 | # Draw the side edges joining corresponding bottom and top vertices 40 | for i in range(len(top_vertices)): 41 | x = [bottom_vertices[i][0], top_vertices[i][0]] 42 | y = [bottom_vertices[i][1], top_vertices[i][1]] 43 | z = [bottom_z,top_z] 44 | ax.plot(x, y, z, 'r-', alpha=0.6) # NOTE(review): side edges are hard-coded red instead of using `color` -- confirm intent 45 | 46 | 47 | if __name__ == '__main__': 48 | pass 49 | --------------------------------------------------------------------------------