├── model ├── 2450k.pth ├── 2799k.pth └── 4595k.pth ├── utils ├── Scheduler.py ├── PoE.py ├── TWQ.py └── Transqer.py ├── play_with_keyboard.py ├── README.md ├── main.py └── Sparrow_V3.py /model/2450k.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/XinJingHao/Sparrow-V3/HEAD/model/2450k.pth -------------------------------------------------------------------------------- /model/2799k.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/XinJingHao/Sparrow-V3/HEAD/model/2799k.pth -------------------------------------------------------------------------------- /model/4595k.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/XinJingHao/Sparrow-V3/HEAD/model/4595k.pth -------------------------------------------------------------------------------- /utils/Scheduler.py: -------------------------------------------------------------------------------- 1 | class LinearSchedule(object): 2 | def __init__(self, schedule_timesteps, initial_p, final_p): 3 | """Linear interpolation between initial_p and final_p over""" 4 | self.schedule_timesteps = schedule_timesteps 5 | self.initial_p = initial_p 6 | self.final_p = final_p 7 | 8 | def value(self, t): 9 | fraction = min(float(t) / self.schedule_timesteps, 1.0) 10 | return self.initial_p + fraction * (self.final_p - self.initial_p) -------------------------------------------------------------------------------- /play_with_keyboard.py: -------------------------------------------------------------------------------- 1 | from Sparrow_V3 import Sparrow 2 | import pygame 3 | import torch 4 | 5 | 6 | def main_dicrete_action(): 7 | envs = Sparrow() 8 | envs.reset() 9 | while True: 10 | keys = pygame.key.get_pressed() 11 | if keys[pygame.K_LEFT]: a = 0 12 | elif keys[pygame.K_UP]: a = 2 13 | elif keys[pygame.K_RIGHT]: a = 4 14 | elif keys[pygame.K_DOWN]: a = 5 15 | else: a = 7 # Stop 16 | 17 | a = torch.ones(envs.N, dtype=torch.long, device=envs.dvc) * a 18 | s_next, r, terminated, truncated, info = envs.step(a) 19 | 20 | 21 | if __name__ == '__main__': 22 | main_dicrete_action() 23 | -------------------------------------------------------------------------------- /utils/PoE.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch 3 | import math 4 | 5 | class PositionalEncoding_NTD(nn.Module): 6 | """Batch First Positional Encoding. Note that emb_size must be even numbers""" 7 | def __init__(self, maxlen: int, emb_size: int): 8 | super(PositionalEncoding_NTD, self).__init__() 9 | 10 | den = torch.exp(- torch.arange(0, emb_size, 2)* math.log(10000) / emb_size) 11 | pos = torch.arange(0, maxlen).reshape(maxlen, 1) 12 | pos_embedding = torch.zeros((maxlen, emb_size)) 13 | pos_embedding[:, 0::2] = torch.sin(pos * den) 14 | pos_embedding[:, 1::2] = torch.cos(pos * den) 15 | pos_embedding = pos_embedding.unsqueeze(0) # (T,D) -> (1,T,D) 用于匹配 (N,T,D) 16 | 17 | self.register_buffer('pos_embedding', pos_embedding) # 同时保存;在同一个dvc;不参与训练 18 | 19 | def forward(self, token_embedding: torch.tensor): 20 | '''token_embedding的维度必须严格为 (N,T,D)''' 21 | return token_embedding + self.pos_embedding # (N,T,D) + (1,T,D) -> (N,T,D) 22 | 23 | 24 | 25 | # pe = BF_PositionalEncoding(maxlen=5, emb_size=6) 26 | # pos_embedding.shape = (1,T,D) 27 | # tensor([[[ 0.0000, 1.0000, 0.0000, 1.0000, 0.0000, 1.0000], 28 | # [ 0.8415, 0.5403, 0.0464, 0.9989, 0.0022, 1.0000], 29 | # [ 0.9093, -0.4161, 0.0927, 0.9957, 0.0043, 1.0000], 30 | # [ 0.1411, -0.9900, 0.1388, 0.9903, 0.0065, 1.0000], 31 | # [-0.7568, -0.6536, 0.1846, 0.9828, 0.0086, 1.0000]]]) 32 | -------------------------------------------------------------------------------- /utils/TWQ.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | class TimeWindowQueue_NTD: 5 | def __init__(self, N, T, D, device, padding): 6 | self.device = device 7 | self.padding = padding 8 | self.T = T 9 | 10 | # 初始化缓冲区,形状为 (N, T, D), 即transformer(batch_first=True)要求的(batch_size, seq_len, emb_dim) 11 | if padding == -1: self.window = -torch.ones(N, T, D, device=self.device) # 便于可视化 12 | elif padding == 0: self.window = torch.zeros(N, T, D, device=self.device) 13 | else: raise ValueError("Wrong padding value") 14 | self.ptr = 0 15 | 16 | 17 | def append(self, batched_transition: torch.tensor): 18 | """ batched_transition, shape=(B,D): batched transition from vectorized envs """ 19 | 20 | # 将数据写入缓冲区 21 | self.window[:, self.T - 1 - self.ptr, :] = batched_transition # (B,D), 由下往上写入,保证roll的输出顺序 22 | 23 | # 更新写指针和计数器 24 | self.ptr = (self.ptr + 1) % self.T 25 | 26 | def get(self) -> torch.tensor: 27 | """ 28 | 获取时间窗口buffer中的所有数据, shape=(N, T, D), 使用roll保证数据按时序正确排列 29 | t=0为最近时刻的数据, t=T-1为最远时刻的数据 30 | """ 31 | TimeWindow_data = torch.roll(self.window, shifts=self.ptr, dims=1) # (N, T, D) 32 | 33 | return TimeWindow_data 34 | 35 | 36 | def padding_with_done(self, done_flag: torch.tensor): 37 | """ 38 | 根据done_flag,将buffer中对应batch位置置零 39 | :param done_flag: shape=(N,) 40 | """ 41 | self.window[done_flag, :, :] = self.padding 42 | 43 | def clear(self): 44 | self.window.fill_(self.padding) 45 | 46 | 47 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |
2 | 3 |
4 | 5 | ## Sparrow-V3.0: A Reinforcement Learning Friendly Simulator for Multiple Mobile Robots 6 | 7 |

8 | 9 | 10 | 11 |

12 | 13 | 14 | 15 |
16 | 17 | 18 | 19 |
20 | 21 | ## Dependency 22 | 23 | ```bash 24 | numpy>=1.24.4 25 | scipy>=1.10.1 26 | pygame>=2.5.2 27 | torch>=2.2.0 # GPU version is recommended 28 | torchaudio>=2.2.0 29 | torchvision>=0.17.0 30 | ``` 31 | 32 | ## Play with keyboard 33 | 34 | ```bash 35 | python play_with_keyboard.py 36 | ``` 37 | 38 | ## Play with trained model 39 | 40 | ```bash 41 | python main.py 42 | ``` 43 | 44 | ## Train your own model 45 | Please refer to [ColorDynamic](https://github.com/XinJingHao/ColorDynamic) 46 | 47 |
48 | 49 | ## Documentations 50 | - Please refer to this [repo](https://github.com/XinJingHao/Sparrow-V2). 51 | 52 |
53 | 54 | ## Main differences from Sparrow-V2 55 | - Multiple robots are supported 56 | 57 |
58 | 59 | ## The Sparrow families 60 | - [Sparrow-V1](https://github.com/XinJingHao/Sparrow-V1): Single Robot, Static environments 61 | - [Sparrow-V2](https://github.com/XinJingHao/Sparrow-V2): Single Robot, Dynamic/Static environments 62 | - [Sparrow-V3](https://github.com/XinJingHao/Sparrow-V3): Multiple/Single Robot, Dynamic/Static environments 63 | 64 | 65 |
66 | 67 | ## Citing the Project 68 | 69 | To cite this repository in publications: 70 | 71 | ```bibtex 72 | @article{ColorDynamic, 73 | title={ColorDynamic: Generalizable, Scalable, Real-time, End-to-end Local Planner for Unstructured and Dynamic Environments}, 74 | author={Jinghao Xin, Zhichao Liang, Zihuan Zhang, Peng Wang, and Ning Li}, 75 | journal={arXiv preprint arXiv:2502.19892}, 76 | year={2025} 77 | } 78 | ``` 79 | 80 | 81 | 82 | ## Writing in the end 83 | 84 | The name "Sparrow" actually comes from an old saying *“麻雀虽小,五脏俱全.”* 85 | 86 | Hope you enjoy using Sparrow! 87 | 88 | Additionally, we have made detailed comments on the source code (`Sparrow_V3.py`) so that you can modify Sparrow to fit your own problem. But only for non-commercial purposes, and all rights are reserved by [Jinghao Xin](https://github.com/XinJingHao). 89 | 90 | 91 | -------------------------------------------------------------------------------- /utils/Transqer.py: -------------------------------------------------------------------------------- 1 | from utils.PoE import PositionalEncoding_NTD 2 | from utils.TWQ import TimeWindowQueue_NTD 3 | import torch.nn.functional as F 4 | import torch.nn as nn 5 | import torch 6 | 7 | def orthogonal_init(layer, gain=1.414): 8 | for name, param in layer.named_parameters(): 9 | if 'bias' in name: 10 | nn.init.constant_(param, 0) 11 | elif 'weight' in name: 12 | nn.init.orthogonal_(param, gain=gain) 13 | return layer 14 | 15 | 16 | class Transqer_networks(nn.Module): 17 | def __init__(self, opt): 18 | super(Transqer_networks, self).__init__() 19 | self.d = opt.state_dim - 8 # s[0:7] is robot state, s[8:] is lidar results 20 | # Define the Transformer Encoder block(note that state_dim should be a even number): 21 | self.pe = PositionalEncoding_NTD(maxlen=opt.T, emb_size=self.d) # for (N,T,d) 22 | encoder_layer = nn.TransformerEncoderLayer(d_model=self.d, nhead=opt.H, dropout=0, 23 | dim_feedforward=opt.net_width, batch_first=True) 24 | self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=opt.L) 25 | 26 | self.fc1 = orthogonal_init(nn.Linear(int(self.d + opt.state_dim), opt.net_width)) 27 | self.fc2 = orthogonal_init(nn.Linear(opt.net_width, int(opt.net_width/2))) 28 | self.fc3 = orthogonal_init(nn.Linear(int(opt.net_width/2), opt.action_dim)) 29 | 30 | def forward(self, TW_s): 31 | '''TW_s.shape = (B,T,D)''' 32 | temporal_ld = TW_s[:,:,8:] # s[0:7] is robot state, s[8:] is lidar results 33 | temporal_ld = self.pe(temporal_ld) # (N,T,d) 34 | temporal_ld = self.transformer_encoder(temporal_ld) # (N,T,d) 35 | temporal_ld = temporal_ld.mean(dim=1) # (N,T,d) -> (N,d) 36 | 37 | aug_s = torch.cat((temporal_ld,TW_s[:,0,:]),dim=-1) # (N,d+S_dim) 38 | 39 | q = F.relu(self.fc1(aug_s)) # (N,256) 40 | q = F.relu(self.fc2(q)) # (N,128) 41 | q = self.fc3(q) # (N,a_dim) 42 | return q 43 | 44 | 45 | class Transqer_agent(object): 46 | '''For Evaluation and Play, not for Training''' 47 | def __init__(self,opt): 48 | self.action_dim = opt.action_dim 49 | self.dvc = opt.dvc 50 | self.N = opt.N 51 | 52 | # Build Transqer 53 | self.q_net = Transqer_networks(opt).to(self.dvc) 54 | 55 | # vectorized e-greedy exploration 56 | self.p = torch.ones(opt.N, device=self.dvc) * 0.01 57 | 58 | # temporal window queue for interaction: 59 | self.queue = TimeWindowQueue_NTD(opt.N, opt.T, opt.state_dim, opt.dvc, padding=0) 60 | 61 | def select_action(self, TW_s, deterministic): 62 | '''Input: batched state in (N, T, s_dim) on device 63 | Output: batched action, (N,), torch.tensor, on device ''' 64 | with torch.no_grad(): 65 | a = self.q_net(TW_s).argmax(dim=-1) 66 | if deterministic: 67 | return a 68 | else: 69 | replace = torch.rand(self.N, device=self.dvc) < self.p # [n] 70 | rd_a = torch.randint(0, self.action_dim, (self.N,), device=self.dvc) 71 | a[replace] = rd_a[replace] 72 | return a 73 | 74 | def load(self,steps): 75 | self.q_net.load_state_dict(torch.load("./model/{}k.pth".format(steps), map_location=self.dvc, weights_only=True)) -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | from utils.Transqer import Transqer_agent 2 | from Sparrow_V3 import Sparrow, str2bool 3 | import argparse 4 | import torch 5 | 6 | 7 | parser = argparse.ArgumentParser() 8 | '''Hyperparameter Setting for Agent''' 9 | parser.add_argument('--ModelIdex', type=int, default=2799, help='which model(e.g. 2799k.pth) to load') 10 | parser.add_argument('--net_width', type=int, default=64, help='Linear net width') 11 | parser.add_argument('--T', type=int, default=10, help='length of time window') 12 | parser.add_argument('--H', type=int, default=8, help='Number of Head') 13 | parser.add_argument('--L', type=int, default=3, help='Number of Transformer Encoder Layers') 14 | 15 | '''Hyperparameter Setting for Sparrow''' 16 | parser.add_argument('--dvc', type=str, default='cuda', help='running device of Sparrow: cuda / cpu') 17 | parser.add_argument('--action_type', type=str, default='Discrete', help='Action type: Discrete / Continuous') 18 | parser.add_argument('--window_size', type=int, default=800, help='size of the map') 19 | parser.add_argument('--ET', type=str2bool, default=False, help='if True: DT2 Exceed D will be regarded as Termination') 20 | parser.add_argument('--D', type=int, default=400, help='maximal local planning distance') 21 | parser.add_argument('--N', type=int, default=10, help='number of agents') 22 | parser.add_argument('--O', type=int, default=5, help='number of obstacles in each environment') 23 | parser.add_argument('--RSEO', type=str2bool, default=True, help='Robot Scan Each Other') 24 | parser.add_argument('--RdOV', type=str2bool, default=True, help='whether to randomize the Velocity of dynamic obstacles') 25 | parser.add_argument('--RdOT', type=str2bool, default=True, help='whether to randomize the Type of dynamic obstacles') 26 | parser.add_argument('--RdOR', type=str2bool, default=True, help='whether to randomize the Radius of obstacles') 27 | parser.add_argument('--Obs_R', type=int, default=14, help='maximal obstacle radius, cm') 28 | parser.add_argument('--Obs_V', type=int, default=25, help='maximal obstacle velocity, cm/s') 29 | parser.add_argument('--MapObs', type=str, default=None, help="name of map file, e.g. 'map.png' or None") 30 | parser.add_argument('--ld_a_range', type=int, default=360, help='max scanning angle of lidar (degree)') 31 | parser.add_argument('--ld_d_range', type=int, default=300, help='max scanning distance of lidar (cm)') 32 | parser.add_argument('--ld_num', type=int, default=72, help='number of lidar streams in each world') 33 | parser.add_argument('--ld_GN', type=int, default=3, help='how many lidar streams are grouped for one group') 34 | parser.add_argument('--basic_ctrl_interval', type=float, default=0.1, help='control interval (s), 0.1 means 10 Hz control frequency') 35 | parser.add_argument('--ctrl_delay', type=int, default=0, help='control delay, in basic_ctrl_interval, 0 means no control delay') 36 | parser.add_argument('--K', type=tuple, default=(0.55,0.6), help='K_linear, K_angular') 37 | parser.add_argument('--show_ld', type=int, default=False, help='whether to render lidar streams') 38 | parser.add_argument('--draw_auxiliary', type=str2bool, default=False, help='whether to draw auxiliary infos') 39 | parser.add_argument('--render_speed', type=str, default='fast', help='fast / slow / real') 40 | parser.add_argument('--max_ep_steps', type=int, default=500, help='maximum episodic steps') 41 | parser.add_argument('--noise', type=str2bool, default=True, help='whether to add noise to the observations') 42 | parser.add_argument('--DR', type=str2bool, default=True, help='whether to use Domain Randomization') 43 | parser.add_argument('--DR_freq', type=int, default=int(3.2e3), help='frequency of Domain Randomization, in total steps') 44 | parser.add_argument('--compile', type=str2bool, default=True, help='whether to torch.compile to boost simulation speed') 45 | opt = parser.parse_args() 46 | opt.render_mode = 'human' 47 | opt.dvc = torch.device(opt.dvc) 48 | 49 | 50 | def main(): 51 | # Build env 52 | env = Sparrow(**vars(opt)) # for test 53 | opt.state_dim = env.state_dim 54 | opt.action_dim = env.action_dim 55 | 56 | # Init agent 57 | agent = Transqer_agent(opt) 58 | agent.load(opt.ModelIdex) 59 | 60 | # Play 61 | while True: 62 | test_ep_steps, test_ep_r, test_arrival_rate = evaluate(env, agent, deterministic=False, turns=100) 63 | print(f'ArrivalRate:{test_arrival_rate}, Reward:{test_ep_r}, Steps: {test_ep_steps}\n') 64 | 65 | 66 | 67 | def evaluate(envs, agent, deterministic, turns): 68 | step_collector, total_steps = torch.zeros(opt.N, device=opt.dvc), 0 69 | r_collector, total_r = torch.zeros(opt.N, device=opt.dvc), 0 70 | arrived, finished = 0, 0 71 | 72 | agent.queue.clear() 73 | s, info = envs.reset() 74 | ct = torch.ones(opt.N, device=opt.dvc, dtype=torch.bool) 75 | while finished < turns: 76 | '''单步state -> 时序窗口state:''' 77 | agent.queue.append(s) # 将s加入时序窗口队列 78 | TW_s = agent.queue.get() # 取出队列所有数据及 79 | a = agent.select_action(TW_s, deterministic) 80 | s, r, dw, tr, info = envs.step(a) 81 | 82 | '''解析dones, wins, deads, truncateds, consistents信号:''' 83 | agent.queue.padding_with_done(~ct) # 根据上一时刻的ct去padding 84 | dones = dw + tr 85 | wins = (r == envs.AWARD) 86 | dead_and_tr = dones^wins # dones-wins = deads and truncateds 87 | ct = ~dones 88 | 89 | '''统计回合步数:''' 90 | step_collector += 1 91 | total_steps += step_collector[wins].sum() # 到达,总步数加上真实步数 92 | total_steps += (envs.max_ep_steps * dead_and_tr).sum() # 未到达,总步数加上回合最大步数 93 | step_collector[dones] = 0 94 | 95 | '''统计总奖励:''' 96 | r_collector += r 97 | total_r += r_collector[dones].sum() 98 | r_collector[dones] = 0 99 | 100 | '''统计到达率:''' 101 | finished += dones.sum() 102 | arrived += wins.sum() 103 | 104 | return int(total_steps.item() / finished.item()), \ 105 | round(total_r.item() / finished.item(), 2), \ 106 | round(arrived.item() / finished.item(), 2) 107 | 108 | 109 | if __name__ == '__main__': 110 | main() 111 | -------------------------------------------------------------------------------- /Sparrow_V3.py: -------------------------------------------------------------------------------- 1 | from collections import deque 2 | from scipy import ndimage 3 | import numpy as np 4 | import torch 5 | import copy 6 | import os 7 | 8 | os.environ['PYGAME_HIDE_SUPPORT_PROMPT'] = '1' 9 | import pygame 10 | 11 | 12 | # 19 editable configurations: 13 | default_cfg = dict( 14 | dvc=torch.device('cpu'), # running device of Sparrow: cuda / cpu 15 | action_type='Discrete', # Action type: Discrete / Continuous 16 | window_size=800, # size of the map 17 | D=400, # maximal local planning distance 18 | N=2, # number of robots 19 | O=2, # number of obstacles in each environment 20 | RSEO=True, # Robot Scan Each Other 21 | RdOV=True, # whether to randomize the velocity of obstacles 22 | RdOT=True, # whether to randomize the type of obstacles 23 | RdOR=True, # whether to generate obstacles with random radius between [10, Obs_R] 24 | Obs_R=14, # maximal obstacle radius, cm 25 | Obs_V=25, # maximal obstacle velocity, cm/s 26 | MapObs=None, # None /or/ the name of .png file, e.g. 'map.png' or None 27 | ld_a_range=360, # max scanning angle of lidar (degree) 28 | ld_d_range=100, # max scanning distance of lidar (cm) 29 | ld_num=12, # number of lidar streams in each world 30 | ld_GN=3, # how many lidar streams are grouped for each group 31 | basic_ctrl_interval=0.1, # control interval (s), 0.1 means 10 Hz control frequency 32 | ctrl_delay=0, # control delay, in basic_ctrl_interval, 0 means no control delay 33 | K=(0.55, 0.6), # K_linear, K_angular 34 | show_ld=True, # whether to render lidar streams 35 | draw_auxiliary = False, # draw auxiliary area 36 | render_mode='human', # 'human' / 'rgb_array' / None 37 | render_speed='fast', # 'real' / 'fast' / 'slow' 38 | max_ep_steps=500, # maximum episodic steps 39 | noise=False, # whether to add noise to the observations 40 | DR=False, # whether to use Domain Randomization 41 | DR_freq=int(3.2e3), # frequency of re-Domain Randomization, in total steps 42 | compile=False) # whether to torch.compile to boost simulation speed 43 | 44 | 45 | class Sparrow(): 46 | def __init__(self, **params): 47 | if len(params)==0: self.__dict__.update(**default_cfg) # Use default configration 48 | else: self.__dict__.update(params) # Use user's configration 49 | if self.DR: self.noise = True # 开启Domain Randomization后默认开启state noise 50 | self.version = "V3.0-MultiRobot" 51 | 52 | '''State/Action/Reward initialization''' 53 | assert self.ld_num % self.ld_GN == 0 # ld_num must be divisible by ld_GN 54 | self.grouped_ld_num = int(self.ld_num / self.ld_GN) 55 | self.absolute_state_dim = 5 + self.grouped_ld_num # [dx,dy,orientation,v_linear,v_angular] + [lidar result] 56 | self.state_dim = 8 + self.grouped_ld_num # [cAl,cAa,rAl,rAa,D2T,alpha,v_linear,v_angular] + [lidar result] 57 | if self.action_type == 'Discrete': self.action_dim = 7 # 5:前进+转弯,6:前进+转弯+后退,7:前进+转弯+后退+减速,8:前进+转弯+后退+减速+原地静止 58 | else: self.action_dim = 2 # [V_linear, V_angular] 59 | self.AWARD, self.PUNISH = 200, -200 # 到达奖励,碰撞、越界惩罚 60 | 61 | '''Car initialization''' 62 | self.car_radius = 9 # cm 63 | self.scan_radius = self.car_radius + 3 # 雷达扫描起始距离 64 | self.collision_trsd = self.car_radius + 5 # collision threshould, in cm 65 | self.v_linear_max = 50 # max linear velocity, in cm/s 66 | self.v_angular_max = 2 # max angular velocity, in rad/s 67 | if self.action_type == 'Continuous': self.continuous_scale = torch.tensor([[self.v_linear_max, self.v_angular_max]], device=self.dvc) # (1,2) 68 | self.a_space = torch.tensor([[0.2*self.v_linear_max , self.v_angular_max],[self.v_linear_max , self.v_angular_max], 69 | [self.v_linear_max, 0], # v_linear, v_angular 70 | [self.v_linear_max, -self.v_angular_max],[0.2*self.v_linear_max, -self.v_angular_max], 71 | [-self.v_linear_max, 0], # -v_linear, v_angular 72 | [0.1*self.v_linear_max, 0], # slow down 73 | [0., 0.]], device=self.dvc) # stop 74 | self.a_space = self.a_space.unsqueeze(dim=0).repeat((self.N, 1, 1)) # (action_dim,2) -> (N,action_dim,2) 75 | self.a_state = torch.tensor([[0.2,1], [1,1], [1,0], [1,-1], [0.2,-1], [-1,0], [0.1,0], [0,0]], device=self.dvc) # (action_dim,2) 76 | if self.action_type == 'Discrete': self.init_pre_action = (self.action_dim-1)*torch.ones(self.N, dtype=torch.int64, device=self.dvc) 77 | else: self.init_pre_action = torch.zeros((self.N,2), device=self.dvc) # the pre_action of init state 78 | self.arange_constant = torch.arange(self.N, device=self.dvc) # 仅用于索引 79 | self.K = torch.tensor([self.K], device=self.dvc) # K_linear, K_angular 80 | self.ctrl_interval = self.basic_ctrl_interval * torch.ones((self.N,1), device=self.dvc) # control interval, in second; (N,1) 81 | self.ctrl_pipe_init = deque() # holding the delayed action 82 | for i in range(self.ctrl_delay): self.ctrl_pipe_init.append(self.init_pre_action) # 控制指令管道初始化 83 | 84 | 85 | '''Map initialization''' 86 | self.obs_canvas_torch = torch.ones((self.window_size, self.window_size), dtype=torch.int64) # 用于将障碍物坐标转化为2D栅格图 87 | self.target_area = 30 # # enter the central circle (radius=target_area) will be considered win. 88 | self.R_map = int(self.window_size/2) 89 | self.target_point = torch.tensor([[self.R_map,self.R_map]]).repeat(self.N,1).to(self.dvc) # (N,2) 90 | 91 | # Dynamic Obstacle Related: 92 | self.w = 6 # 动态障碍物线条粗细 93 | self.b_kernel = np.ones((self.w + 3, self.w + 3)) # 用于腐蚀的卷积核 94 | self.Obs_refresh_interval = 1 # 障碍物运动刷新周期(相对于小车控制而言,1表示与小车同频率刷新); 增大时,有助于增加Obs速度分辨率,但不利于时序感知 95 | self.Obs_refresh_counter = self.Obs_refresh_interval 96 | self.max_Obs_V = int(self.Obs_V * self.Obs_refresh_interval * self.basic_ctrl_interval) # 障碍物x,y轴最大速度 (cm per fresh), 标量 97 | self.Dynamic_obs_canvas = pygame.Surface((2 * (self.Obs_R + self.w), 2 * (self.Obs_R + self.w))) # 用于画动态障碍物(L4),方便后续转化为栅格点 98 | self.l_margin = 50 # left&up margin of dynamic obstacle's moving space 99 | self.h_margin = self.window_size-50 # right&bottom margin of dynamic obstacle's moving space 100 | 101 | 102 | # Static Obstacle Related: 103 | self.Static_obs_canvas = pygame.Surface((self.window_size, self.window_size)) # 用于画静态障碍物,方便后续转化为栅格点 104 | self.area = 6 # 横/纵向被切分区域的数量 105 | self.sn = 2 # 每个矩形区域的障碍物的最大数量 106 | self.generate_rate = 0.1 # 每个区域,每次产生静态障碍物的概率 107 | self.d_rect = int(self.window_size/self.area) # 生成静态障碍物的矩形区域的边长 108 | self.rect_offsets = torch.cartesian_prod(torch.arange(0, self.area), torch.arange(0, self.area)).numpy() * self.d_rect 109 | 110 | 111 | # Robot Obstacle Related: 112 | if self.RSEO: 113 | self.Robot_obs_canvas = pygame.Surface((4*self.car_radius, 4*self.car_radius)) # 用于画机器人障碍物(L2),方便后续转化为栅格点 114 | self._robot_obstacle_init() 115 | 116 | # Map Obstacle Related: 117 | if self.MapObs: 118 | # self.MapObs should be the name of the .png file, e.g. self.MapObs = 'map.png' 119 | # 'map.png' should be of shape (window_size,window_size,3), where 0 represents obstacles and 255 represents free space. 120 | self._map_obstacle_init() 121 | 122 | 123 | '''Lidar initialization''' 124 | self.ld_acc = 3 # lidar scan accuracy (cm). Reducing accuracy can accelerate simulation; 125 | self.ld_scan_result = torch.zeros((self.N, self.ld_num), device=self.dvc) # used to hold lidar scan result, (N, ld_num) 126 | self.ld_result_grouped = torch.zeros((self.N, self.grouped_ld_num), device=self.dvc) # the grouped lidar scan result, (N, grouped_ld_num) 127 | self.ld_angle_interval = torch.arange(self.ld_num, device=self.dvc) * (self.ld_a_range/180) * torch.pi / (self.ld_num) - (self.ld_a_range/360) * torch.pi #(ld_num, ) 128 | self.ld_angle_interval = self.ld_angle_interval.unsqueeze(dim=0).repeat((self.N, 1)) # (N, ld_num) 129 | 130 | '''State noise initialization (unormalized magnitude)''' 131 | if self.noise: 132 | self.noise_magnitude = torch.hstack((torch.tensor([2,2,torch.pi/50,1,torch.pi/50]), torch.ones(self.grouped_ld_num))).to(self.dvc) #(abs_state_dim,) 133 | 134 | '''Domain Randomization initialization''' 135 | if self.DR: 136 | # 创建基准值,后续在基准值上随机化 137 | self.ctrl_interval_base = self.ctrl_interval.clone() # (N,1) 138 | self.K_base = self.K.clone() # (1,2) 139 | self.a_space_base = self.a_space.clone() # (N,A,2) 140 | self.noise_magnitude_base = self.noise_magnitude.clone() # (abs_state_dim,) 141 | 142 | '''Pygame initialization''' 143 | self.COLORs = np.random.randint(0, 256, size=(self.N, 3)) 144 | self.ri = 0 145 | assert self.render_mode is None or self.render_mode == 'human' 146 | # "human": will render in a pygame window 147 | # None: not render anything 148 | self.window = None 149 | self.clock = None 150 | self.canvas = None 151 | self.render_rate = self.ctrl_interval[self.ri].item() # FPS = 1/self.render_rate 152 | 153 | '''Internal variables initialization''' 154 | # 提前声明变量的数据格式,速度更快 155 | self.step_counter_DR = 0 # 用于记录DR的持续步数 156 | self.step_counter_vec = torch.zeros(self.N, dtype=torch.long, device=self.dvc) # 用于truncate 157 | self.car_state = torch.zeros((self.N, 5), device=self.dvc, dtype=torch.float32) 158 | self.reward_vec = torch.zeros(self.N, device=self.dvc) # vectorized reward signal 159 | self.dw_vec = torch.zeros(self.N, dtype=torch.bool, device=self.dvc) # vectorized terminated signal 160 | self.tr_vec = torch.zeros(self.N, dtype=torch.bool, device=self.dvc) # vectorized truncated signal 161 | self.done_vec = torch.zeros(self.N, dtype=torch.bool, device=self.dvc) # vectorized done signal 162 | # for state normalization: 163 | self.state_upperbound = torch.ones(self.state_dim-4, device=self.dvc) # -4 for exclusion of act_state 164 | self.state_upperbound[0] *= self.D 165 | self.state_upperbound[1] *= 1 # 仅用于补位,后面单独归一化 166 | self.state_upperbound[2] *= self.v_linear_max 167 | self.state_upperbound[3] *= self.v_angular_max 168 | self.state_upperbound[4:self.state_dim] *= self.ld_d_range 169 | 170 | '''Logging''' 171 | if self.dvc.type == 'cpu': 172 | print("Although Sparrow can be deployed on CPU, we strongly recommend you use GPU to accelerate simulation! " 173 | "Please try to use ' dvc=torch.device('cuda') ' when instantiate Sparrow.") 174 | else: 175 | # 编译雷达扫描函数,加速仿真. 但有些显卡上会报错. 176 | if self.compile == True: 177 | self._ld_scan_vec = torch.compile(self._ld_scan_vec) 178 | else: 179 | print("When instantiate Sparrow, you can set 'compile=True' to boost the simulation speed. ") 180 | print(f"Sparrow-{self.version}, N={self.N}, State dimension={self.state_dim}, {self.action_type} action dimension={self.action_dim}.") 181 | 182 | def _random_noise(self, magnitude:float, size:tuple, device:torch.device): 183 | '''Generate uniform random noise in [-magnitude,magnitude)''' 184 | return (torch.rand(size=size, device=device)-0.5) * 2 * magnitude 185 | 186 | def _world_2_grid(self, coordinate_wd): 187 | ''' Convert world coordinates (denoted by _wd, continuous, unit: cm) to grid coordinates (denoted by _gd, discrete, 1 grid = 1 cm) 188 | Input: torch.tensor; Output: torch.tensor; Shape: Any shape ''' 189 | return coordinate_wd.floor().int() 190 | 191 | def _Domain_Randomization(self): 192 | # 1) randomize the control interval; ctrl_interval.shape: (N,1) 193 | self.ctrl_interval = self.ctrl_interval_base + self._random_noise(0.01, (self.N,1), self.dvc)# control interval, in second 194 | 195 | # 2) randomize the kinematic parameter; K.shape: (N,2) 196 | self.K = self.K_base + self._random_noise(0.05, (self.N,2), self.dvc)# control interval, in second; 197 | 198 | # 3) randomize the max velocity; a_space.shape: (N,6,2) 199 | self.a_space = self.a_space_base * (1 + self._random_noise(0.05, (self.N, 1, 2), self.dvc)) # Random the maximal speed of each env copy by 0.9~1.1 200 | 201 | # 4) randomize the magnitude of state noise; noise_magnitude.shape: (N,abs_state_dim) 202 | self.noise_magnitude = self.noise_magnitude_base * (1+self._random_noise(0.25, (self.N, self.absolute_state_dim), device=self.dvc)) 203 | 204 | def _map_obstacle_init(self): 205 | '''Init the bound points of the map obstacles 206 | even_obs_P 丨 (O*P,2) 丨 pygame转换得来 207 | ↓↓↓ 208 | [并行N份, 然后reshape] 209 | ↓↓↓ 210 | vec_map_obs_P_shaped 丨 (N,O*P,2,1) 丨 用于编码 丨 用于pygame渲染 211 | ↓↓↓ 212 | [每次初始化时,编码 (x*window_size+y) ] 213 | ↓↓↓ 214 | vec_map_bound_code 丨 (N,1,O*P) 丨 雷达扫描 215 | ''' 216 | map_pyg = pygame.image.load(self.MapObs) # 不能用plt.imread读, 有bug 217 | map_np = pygame.surfarray.array3d(map_pyg)[:, :, 0] 218 | x_, y_ = np.where(map_np == 0) # 障碍物栅格的x,y坐标 219 | '''注意: 静态障碍物无需对P补齐''' 220 | even_obs_P = torch.tensor(np.stack((x_, y_), axis=1)) # 障碍物栅格点, (O*P, 2), on cpu 221 | '''地图障碍物并行N份;然后重塑维度,方便后续扫描''' 222 | self.vec_map_obs_P_shaped = even_obs_P[None,:,:,None].repeat(self.N,1,1,1).to(self.dvc) # (N,O*P,2,1), on dvc 223 | # 第ri个env中,所有地图障碍物的x坐标=self.vec_map_obs_P_shaped[self.ri,:,0,0] ; y坐标=self.vec_map_obs_P_shaped[self.ri,:,1,0] 224 | 225 | '''对地图障碍物的x,y坐标进行编码,方便后续扫描''' 226 | self.vec_map_bound_code = (self.vec_map_obs_P_shaped[:,:,0,0]*self.window_size +self.vec_map_obs_P_shaped[:,:,1,0]).unsqueeze(1) # (N,1,O*P) 227 | 228 | def _static_obstacle_init(self): 229 | '''Init the bound points of the static obstacles 230 | even_obs_P 丨 (O*P,2) 丨 pygame绘制得来 231 | ↓↓↓ 232 | [并行N份, 然后reshape] 233 | ↓↓↓ 234 | vec_static_obs_P_shaped 丨 (N,O*P,2,1) 丨 用于编码 丨 用于pygame渲染 235 | ↓↓↓ 236 | [每次初始化时,编码 (x*window_size+y) ] 237 | ↓↓↓ 238 | vec_static_bound_code 丨 (N,1,O*P) 丨 雷达扫描 239 | 240 | 注:静态障碍物在N个并行环境中是完全一致的 241 | ''' 242 | self.Static_obs_canvas.fill((0, 0, 0)) 243 | 244 | '''绘制地图边界''' 245 | pygame.draw.line(self.Static_obs_canvas, (1, 1, 1), (0, 0),(0, self.window_size), width=self.w-1) 246 | pygame.draw.line(self.Static_obs_canvas, (1, 1, 1), (0, self.window_size),(self.window_size, self.window_size), width=self.w) 247 | pygame.draw.line(self.Static_obs_canvas, (1, 1, 1), (self.window_size, self.window_size),(self.window_size, 0), width=self.w+1) 248 | pygame.draw.line(self.Static_obs_canvas, (1, 1, 1), (self.window_size, 0),(0, 0), width=self.w-1) 249 | 250 | '''对Env0绘制4*sn个障碍:''' 251 | cdnts = np.random.rand(2, self.sn, self.area**2, 2)*self.d_rect + self.rect_offsets # (start/end,staticObs_numbers,4_rect,x/y) 252 | for i in range(self.sn): 253 | for j in range(self.area**2): 254 | if np.random.rand() < self.generate_rate: # 以概率在每个区域生成障碍物 255 | pygame.draw.line(self.Static_obs_canvas, (1, 1, 1), cdnts[0,i,j], cdnts[1,i,j], width=2*self.Obs_R) 256 | 257 | obs_np = pygame.surfarray.array3d(self.Static_obs_canvas)[:, :, 0] 258 | b_obs_np = ndimage.binary_erosion(obs_np, self.b_kernel).astype(obs_np.dtype) # 腐蚀障碍物图像 259 | obs_np -= b_obs_np # 减去腐蚀图像,提取轮廓线 260 | x_, y_ = np.where(obs_np == 1) # 障碍物栅格的x,y坐标 261 | '''注意: 静态障碍物无需对P补齐''' 262 | even_obs_P = torch.tensor(np.stack((x_, y_), axis=1)) # 障碍物栅格点, (O*P, 2), on cpu 263 | '''将Env0的障碍物并行N份;然后重塑维度,方便后续扫描''' 264 | self.vec_static_obs_P_shaped = even_obs_P[None,:,:,None].repeat(self.N,1,1,1).to(self.dvc) # (N,O*P,2,1), on dvc 265 | # 第ri个env中,所有静态障碍物的x坐标=self.vec_static_obs_P_shaped[self.ri,:,0,0] ; y坐标=self.vec_static_obs_P_shaped[self.ri,:,1,0] 266 | 267 | '''对静态障碍物的x,y坐标进行编码,方便后续扫描''' 268 | self.vec_static_bound_code = (self.vec_static_obs_P_shaped[:,:,0,0]*self.window_size +self.vec_static_obs_P_shaped[:,:,1,0]).unsqueeze(1) # (N,1,O*P) 269 | 270 | def _dynamic_obstacle_init(self): 271 | '''Init the bound points of the dynamic obstacles: 272 | vec_dynamic_obs_P 丨 (N,O,P,2) 丨 障碍物运动 丨 障碍物反向 273 | ↓↓↓ 274 | [reshape -> 数据联动] 275 | ↓↓↓ 276 | vec_dynamic_obs_P_shaped 丨 (N,O*P,2,1) 丨 用于编码 丨 用于pygame渲染 277 | ↓↓↓ 278 | [每次obs移动后,编码 (x*window_size+y) ] 279 | ↓↓↓ 280 | vec_dynamic_bound_code 丨 (N,1,O*P) 丨 雷达扫描 281 | 282 | 注: 动态障碍物在N个并行环境中完全一致 283 | ''' 284 | 285 | '''变量初始化''' 286 | self.Obs_V_tensor = (self._random_noise(self.Obs_V, (1, self.O, 1, 2), self.dvc).repeat(self.N,1,1,1) * 287 | self.Obs_refresh_interval * self.ctrl_interval.reshape(self.N,1,1,1)).to(self.dvc).round().long() # 障碍物的速度, (N,O,1,2) 288 | 289 | '''对Env0依次绘制O个障碍:''' 290 | uneven_obs_P_list = [] # 未补齐的障碍物栅格点坐标 291 | P_np = np.zeros(self.O, dtype=np.int64) # 记录每个障碍物有多少个Point, 用于后续补齐 292 | for _ in range(self.O): 293 | self.Dynamic_obs_canvas.fill((0, 0, 0)) 294 | if self.RdOT and np.random.rand() < 0.5: # 不规则块状障碍物 295 | thi = np.random.randint(low=20, high=40) 296 | end_pose = np.random.randint(low=(self.Obs_R+self.w), high=(2*(self.Obs_R+self.w)), size=(2,)) 297 | pygame.draw.line(self.Dynamic_obs_canvas, (1, 1, 1), (0, 0),end_pose, width=thi) 298 | else: # 环形障碍物 299 | if self.RdOR: outer_R = 10 + (self.Obs_R-10)*np.random.rand() # 障碍物最小半径10,最大半径Obs_R 300 | else: outer_R = self.Obs_R 301 | pygame.draw.circle(self.Dynamic_obs_canvas, (1, 1, 1), (self.Obs_R + self.w, self.Obs_R + self.w), outer_R) 302 | 303 | obs_np = pygame.surfarray.array3d(self.Dynamic_obs_canvas)[:, :, 0] 304 | b_obs_np = ndimage.binary_erosion(obs_np, self.b_kernel).astype(obs_np.dtype) #腐蚀障碍物图像 305 | obs_np -= b_obs_np #减去腐蚀图像,提取轮廓线 306 | if np.random.rand() < 0.5: obs_np = np.flip(obs_np, (0,)) #水平翻转障碍物 307 | x_, y_ = np.where(obs_np == 1) # 障碍物栅格的x,y坐标 308 | bound_gd = torch.tensor(np.stack((x_,y_), axis=1)) # 障碍物栅格点, (unevenP, 2), on cpu 309 | uneven_obs_P_list.append(bound_gd) 310 | P_np[_] = bound_gd.shape[0] 311 | 312 | self.P = P_np.max() # 障碍物最大Point数量 313 | cP_np = self.P - P_np # 各个障碍物需要补的长度 314 | '''将各个障碍物栅格点bound_gd统一补齐至P个点,方便存储、运算''' 315 | even_obs_P = torch.zeros((self.O, self.P, 2), dtype=torch.long) # (O,P,2) 316 | for _ in range(self.O): 317 | conpensate = torch.ones(size=(cP_np[_],2), dtype=torch.long)*uneven_obs_P_list[_][0] # on cpu 318 | even_obs_P[_] = torch.cat((uneven_obs_P_list[_], conpensate)) 319 | 320 | '''将Env0的障碍物并行N份, 并统一分散''' 321 | self.vec_dynamic_obs_P = even_obs_P[None,:,:,:].repeat(self.N,1,1,1).to(self.dvc) # (N,O,P,2), on cpu 322 | self.vec_dynamic_obs_P += torch.ones((self.N, self.O, 1, 2),dtype=torch.long,device=self.dvc)*(self.R_map-self.Obs_R-self.w) # 平移至中心分散 323 | disperse = self.R_map - 2*(self.Obs_R + self.w) 324 | self.vec_dynamic_obs_P += torch.randint(-disperse, disperse, (1,self.O, 1, 2)).repeat(self.N,1,1,1).to(self.dvc) # N个并行环境统一分散 325 | self.vec_dynamic_obs_P_shaped = self.vec_dynamic_obs_P.reshape(self.N, self.O*self.P, 2, 1) # (N,O,P,2) -> (N,O*P,2,1); on cpu; 与vec_obs_P数据联动 326 | # 第ri个env中,所有动态障碍物的x坐标=self.vec_dynamic_obs_P_shaped[self.ri,:,0,0] ; y坐标=self.vec_dynamic_obs_P_shaped[self.ri,:,1,0] 327 | 328 | def _robot_obstacle_init(self): 329 | '''Init the bound points of the robot obstacles: 330 | robot_obs_P_base 丨 (N,Pr,2) 丨 位置置零 丨 331 | ↓↓↓ 332 | [clone] 333 | ↓↓↓ 334 | robot_obs_P 丨 (N,Pr,2) 丨 更新位置 丨 335 | ↓↓↓ 336 | [reshape, 数据联动] 337 | ↓↓↓ 338 | robot_obs_P_shaped 丨 (N*Pr,2) 丨 用于编码 丨 用于pygame渲染(已注释) 339 | 340 | 注: 动态障碍物在N个并行环境中完全一致 341 | ''' 342 | self.Robot_obs_canvas.fill((0, 0, 0)) 343 | pygame.draw.circle(self.Robot_obs_canvas, (1, 1, 1), (2*self.car_radius, 2*self.car_radius), self.car_radius) 344 | 345 | obs_np = pygame.surfarray.array3d(self.Robot_obs_canvas)[:, :, 0] 346 | b_obs_np = ndimage.binary_erosion(obs_np, self.b_kernel).astype(obs_np.dtype) # 腐蚀障碍物图像 347 | obs_np -= b_obs_np # 减去腐蚀图像,提取轮廓线 348 | x_, y_ = np.where(obs_np == 1) # 障碍物栅格的x,y坐标 349 | bound_gd = torch.tensor(np.stack((x_, y_), axis=1)) # 障碍物栅格点, (P, 2), on cpu 350 | bound_gd -= 2*self.car_radius # 让robot obs的中心处于世界坐标系原点 351 | 352 | '''注意,这里的N不是N个并行环境,而是N个机器人''' 353 | self.robot_obs_P_base = bound_gd[None,:,:].repeat(self.N,1,1).to(self.dvc) # (N,Pr,2), on cpu 354 | self.robot_obs_P = self.robot_obs_P_base.clone()# (N,Pr,2), on dvc, 用于robot obs移动 355 | self.robot_obs_P_shaped = self.robot_obs_P.reshape(-1, 2) # (N*Pr,2), on dvc, 用于编码扫描, 会和vec_robot_obs_P数据联动 356 | 357 | 358 | def _rect_in_bound(self, x:int, y:int, range:int) -> bool: 359 | '''Check whether the rectangle(center=(x,y), D=2*range) has obstacle. 360 | All input should be int.''' 361 | x_min, y_min = max(0, x-range), max(0, y-range) 362 | x_max, y_max = min(self.window_size, x+range), min(self.window_size, y+range) 363 | 364 | rect = torch.cartesian_prod(torch.arange(x_min, x_max), torch.arange(y_min, y_max)) # (X,2) 365 | rect_code = (rect[:,0]*self.window_size + rect[:,1]).unsqueeze(-1).to(self.dvc) # (X*2,1) 366 | 367 | return ((rect_code - self.vec_bound_code[0])==0).any() # (X*2,1)-(1,O1*P1+O2*P2+O3*P3) 368 | 369 | def _target_point_init(self, N:int): 370 | '''Init target point for Envs.N''' 371 | cnt = 0 372 | while True: 373 | cnt += 1 374 | if cnt > 10000: print("The current map is too crowded to find free space for target init.") 375 | d, a = self.D * np.random.uniform(0.3,0.9), 6.28 * torch.rand(1) # dont change 0.9 to 1.0. (Reset Error) 376 | x, y = (self.car_state[N, 0].item() + d * torch.cos(a)).int().item(), (self.car_state[N, 1].item() + d * torch.sin(a)).int().item() 377 | if not ( (self.target_area 10000: print("The current map is too crowded to find free space for robot init.") 388 | loc = torch.randint(low=4*self.car_radius,high=self.window_size-4*self.car_radius,size=(2,),device=self.dvc) 389 | if self._rect_in_bound(loc[0].item(), loc[1].item(), 4*self.car_radius): continue # 与障碍物重合,重新生成 390 | self.car_state[N,0:2] = loc 391 | # 朝向不用管,因为target point也会随机生成 392 | return loc 393 | 394 | def reset(self): 395 | '''Reset all vectorized Env''' 396 | #障碍物初始化 397 | self._static_obstacle_init() 398 | self._dynamic_obstacle_init() 399 | 400 | # 对动态障碍物进行编码,以用于后续生层小车位置和目标位置时的判断: 401 | self.vec_bound_code = (self.vec_dynamic_obs_P_shaped[:,:,0,0]*self.window_size +self.vec_dynamic_obs_P_shaped[:,:,1,0]).unsqueeze(1) # (N,1,O*P) 402 | # 加入静态障碍物: 403 | self.vec_bound_code = torch.cat((self.vec_static_bound_code,self.vec_bound_code), dim=-1) # (N,1,O1*P1)<->(N,1,O2*P2) => (N,1,O1*P1+O2*P2) 404 | # 加入地图障碍物 405 | if self.MapObs: self.vec_bound_code = torch.cat((self.vec_map_bound_code,self.vec_bound_code), dim=-1) 406 | 407 | #小车位置初始化 408 | self.d2target_pre = torch.zeros(self.N, device=self.dvc) # Reset() 不产生奖励信号,这里d2target_pre随便赋值即可 409 | self.car_state.fill_(0) 410 | for i in range(self.N): self._car_loc_init(i) 411 | if self.RSEO: self._robot_obstacle_move() # L2机器人障碍物 412 | 413 | # 目标点初始化: 414 | for i in range(self.N): self._target_point_init(i) 415 | self.d2target_now = (self.car_state[:, 0:2] - self.target_point).pow(2).sum(dim=-1).pow(0.5) # (N,), Reset后离目标点的距离,_reward_function和_Normalize会用 416 | 417 | #步数初始化 418 | self.step_counter_vec.fill_(0) 419 | 420 | #控制指令管道初始化: action5:[0,0] 421 | self.ctrl_pipe = copy.deepcopy(self.ctrl_pipe_init) 422 | 423 | # 获取初始状态 424 | observation_vec = self._get_obs() # absolute car state: (N,abs_state_dim) 425 | # calculate dw, tr, done signals: 426 | self._reward_function(self.init_pre_action) 427 | # add noise to unormalized state: 428 | if self.noise: 429 | observation_vec += self.noise_magnitude*self._random_noise(1, (self.N,self.absolute_state_dim), self.dvc) # (N, abs_state_dim) 430 | 431 | # Normalize the observation: 432 | # absolute coordinates will be transformed to relative distance to target 433 | # absolute orientation will be transformed to relative orientation 434 | relative_observation_vec = self._Normalize(observation_vec) # (N,abs_state_dim) -> (N,abs_state_dim-1) 435 | 436 | # stack action_state to relative_observation_vec 437 | act_relative_observation_vec = self._stack_A_to_S(self.init_pre_action, self.init_pre_action, relative_observation_vec) # (N,abs_state_dim-1) -> (N,state_dim) 438 | 439 | # 由于robot初始化时,未考虑互相重合的现象,有可能会出现reset后马上碰撞的情况。此时我们可以递归地reset(),直到满足需求。 440 | if self.done_vec.any(): return self.reset() 441 | 442 | if self.render_mode == "human": self._render_frame() 443 | return act_relative_observation_vec, dict(abs_car_state=self.car_state.clone(), step_cnt = self.step_counter_vec) 444 | 445 | 446 | def _AutoReset(self): 447 | '''Reset done掉的env(没有done的不受影响)''' 448 | if self.done_vec.any(): 449 | # 1) reset the car pose (only for collided cases) 450 | CollideEnv_idx = torch.where(self.collide_vec)[0] 451 | for i in CollideEnv_idx: self._car_loc_init(i) 452 | 453 | # 2) reset the target point 454 | DoneEnv_idx = torch.where(self.done_vec)[0] 455 | for i in DoneEnv_idx: self._target_point_init(i) 456 | 457 | # 3) reset the step counter 458 | self.step_counter_vec[self.done_vec] = 0 459 | 460 | 461 | def _dynamic_obstacle_move(self): 462 | # 随机障碍物速度(对于每个Env的每一个Obs都随机) 463 | if self.RdOV: self.Obs_V_tensor += torch.randint(-1,2,(1, self.O, 1, 2),device=self.dvc).repeat(self.N,1,1,1) # 每次速度改变量∈[-1,0,1] 464 | 465 | # 限速 466 | self.Obs_V_tensor.clip_(-self.max_Obs_V, self.max_Obs_V) # max_Obs_V是标量 467 | 468 | # 移动障碍物, 注意vec_dynamic_obs_P_shaped会与vec_dynamic_obs_P数据联动 469 | self.vec_dynamic_obs_P += self.Obs_V_tensor # (N,O,P,2) += (N,O,1,2) 470 | 471 | 472 | # 对动态障碍物进行编码,以用于后续雷达扫描: 473 | self.vec_bound_code = (self.vec_dynamic_obs_P_shaped[:,:,0,0]*self.window_size +self.vec_dynamic_obs_P_shaped[:,:,1,0]).unsqueeze(1) # (N,1,O*P) 474 | # 加入静态障碍物: 475 | self.vec_bound_code = torch.cat((self.vec_static_bound_code,self.vec_bound_code), dim=-1) # (N,1,O1*P1)<->(N,1,O2*P2) => (N,1,O1*P1+O2*P2) 476 | # 加入地图障碍物 477 | if self.MapObs: self.vec_bound_code = torch.cat((self.vec_map_bound_code,self.vec_bound_code), dim=-1) 478 | 479 | 480 | # 查看哪些环境的哪些障碍物的x轴/y轴速度需要反向: 481 | Vx_reverse = ((self.vec_dynamic_obs_P[:, :, :, 0] < self.l_margin) + (self.vec_dynamic_obs_P[:, :, :, 0] > self.h_margin)).any(dim=-1) # (N,O) 482 | Vy_reverse = ((self.vec_dynamic_obs_P[:, :, :, 1] < self.l_margin) + (self.vec_dynamic_obs_P[:, :, :, 1] > self.h_margin)).any(dim=-1) # (N,O) 483 | # 对越界的障碍物速度反向: 484 | V_reverse = torch.stack([Vx_reverse, Vy_reverse], dim=2).unsqueeze(2) 485 | self.Obs_V_tensor[V_reverse] *= -1 486 | 487 | 488 | def _robot_obstacle_move(self): 489 | # 根据robot位置,配置机器人障碍物位置: 490 | self.robot_obs_P.copy_(self.robot_obs_P_base) # 先恢复到初始状态(左上角原点); (N,Pr,2) 491 | car_xy = self.car_state[:,0:2].unsqueeze(1).long() # (N,1,2) 492 | self.robot_obs_P.add_(car_xy) # 再添加机器人位置偏置; (N,Pr,2) 493 | # 注: 上面操作后,self.robot_obs_P_shaped的数据会联动改变 494 | 495 | # 对机器人障碍物进行编码,方便后续扫描: 496 | self.robot_obs_bound_code = (self.robot_obs_P_shaped[:,0]*self.window_size + self.robot_obs_P_shaped[:,1]).unsqueeze(0) # (1,N*Pr) 497 | 498 | 499 | def _ld_not_in_bound_vec(self): 500 | '''Check whether ld_end_code is not in bound_code in a vectorized way => goon''' 501 | if self.RSEO: bound_code = torch.cat((self.vec_bound_code[0], self.robot_obs_bound_code),dim=-1) 502 | else: bound_code = self.vec_bound_code[0] 503 | pre_goon = self.ld_end_code[:, :, None] - bound_code # (N,ld_num,1)-(1,N*P) 504 | 505 | # 判断是否存在零值,存在即A中的元素在B中存在 506 | return ~torch.any(pre_goon == 0, dim=2) # goon 507 | 508 | def _ld_scan_vec(self): 509 | '''Get the scan result (in vectorized worlds) of lidars. ''' 510 | # 扫描前首先同步雷达与小车位置: 511 | self.ld_angle = self.ld_angle_interval + self.car_state[:,2,None]# 雷达-小车方向同步, (N, ld_num) + (N, 1) = (N, ld_num) 512 | self.ld_vectors_wd = torch.stack((torch.cos(self.ld_angle), -torch.sin(self.ld_angle)), dim=2) # 雷达射线方向, (N,ld_num,2), 注意在unified_cs中是-sin 513 | self.ld_end_wd = self.car_state[:,None,0:2] + self.scan_radius * self.ld_vectors_wd # 扫描过程中,雷达射线末端世界坐标(初始化于小车轮廓), (N,1,2)+(N,ld_num,2)=(N,ld_num,2) 514 | self.ld_end_gd = self._world_2_grid(self.ld_end_wd) # 扫描过程中,雷达射线末端栅格坐标, (N,ld_num,2) 515 | self.ld_end_code = self.ld_end_gd[:,:,0]*self.window_size + self.ld_end_gd[:,:,1]# 扫描过程中,雷达射线末端栅格坐标的编码值, (N,ld_num) 516 | 517 | # 扫描初始化 518 | self.ld_scan_result.fill_(0) # 结果归零, (N, ld_num) 519 | increment = self.ld_vectors_wd * self.ld_acc # 每次射出的增量, (N,ld_num,2) 520 | 521 | # 并行式烟花式扫描(PS:当射线穿过地图边界后,会对称地进行扫描。比如穿过上边界,射线会从下边界再射出。这可以模拟地图之外的障碍物。) 522 | for i in range( int((self.ld_d_range-self.scan_radius)/self.ld_acc) + 2 ): # 多扫2次,让最大值超过self.ld_d_range,便于clamp 523 | # 更新雷达末端位置 524 | goon = self._ld_not_in_bound_vec() # 计算哪些ld_end_code不在bound_code里, 即还没有扫到障碍 #(N, ld_num) 525 | self.ld_end_wd += (goon[:,:,None] * increment) # 更新雷达末端世界坐标,每次射 ld_acc cm #(N, ld_num,1)*(N,ld_num,2)=(N,ld_num,2) 526 | self.ld_end_gd = self._world_2_grid(self.ld_end_wd)# 更新雷达末端栅格坐标(必须更新,下一轮会调用), (N,ld_num,2) 527 | self.ld_end_code = self.ld_end_gd[:, :, 0] * self.window_size + self.ld_end_gd[:, :, 1]# 更新雷达末端栅格坐标编码值, (N,ld_num) 528 | self.ld_scan_result += (goon * self.ld_acc)# 累计扫描距离 (N, ld_num) 529 | 530 | if (~goon).all(): break # 如果所有ld射线都被挡,则扫描结束 531 | 532 | # 扫描的时候从小车轮廓开始扫的,最后要补偿小车半径的距离; (ld_num, ); torch.tensor 533 | self.ld_scan_result = (self.ld_scan_result + self.scan_radius).clamp(0,self.ld_d_range) #(N, ld_num) 534 | 535 | # 将雷达结果按ld_GN分组,并取每组的最小值作为最终结果 536 | self.ld_result_grouped, _ = torch.min(self.ld_scan_result.reshape(self.N, self.grouped_ld_num, self.ld_GN), dim=-1, keepdim=False) 537 | 538 | def _reward_function(self, current_a): 539 | '''Calculate vectorized reward, terminated(dw), truncated(tr), done(dw+tr) signale''' 540 | self.tr_vec = (self.step_counter_vec > self.max_ep_steps)# truncated signal (N,) 541 | self.exceed_vec = self.d2target_now > self.D # (N,) 542 | self.win_vec = self.d2target_now < self.target_area # (N,) 543 | self.collide_vec = (self.ld_result_grouped < self.collision_trsd).any(dim=-1) # (N,) 544 | self.dw_vec = self.exceed_vec + self.win_vec + self.collide_vec # terminated signal (N,) 545 | self.done_vec = self.tr_vec + self.dw_vec # (N,), used for AutoReset 546 | 547 | xy_in_target = self.car_state[:, 0:2] - self.target_point # 小车在以target为原点的坐标系下的坐标, (N,2), 注意这里是无噪声的 548 | beta = torch.arctan(xy_in_target[:,0] / xy_in_target[:,1]) + torch.pi / 2 + (xy_in_target[:,1] < 0) * torch.pi # (N,) 549 | alpha = (beta - self.car_state[:, 2]) / torch.pi # (N,) 550 | alpha += (2 * (alpha < -1) - 2 * (alpha > 1)) # 修复1/2象限、3/4象限, (N,) 551 | 552 | R_distance = ((self.d2target_pre - self.d2target_now)/(self.v_linear_max*self.ctrl_interval.view(-1))).clamp_(-1,1) # 朝目标点移动时得分,背离时扣分。 (N,)∈[-1,1] 553 | R_orientation = (0.25-alpha.abs().clamp(0,0.25))/0.25 # (-0.25~0~0.25) -> (0,1,0), 朝着目标点时奖励最大, (N,) 554 | if self.action_type == 'Discrete': 555 | R_forward = (current_a==2) # 鼓励使用前进动作(提升移动速度、防止原地滞留) (N,) = 0 or 1 556 | R_retreat_slowdown = (current_a==5) + (current_a==6) # 惩罚后退和减速 557 | else: 558 | # R_forward = (current_a[:,0] > 0.5) 559 | R_forward = current_a[:,0].clip(0., 1.) # 向前的线速度越大,奖励越高 560 | R_retreat_slowdown = (current_a[:,0] <= 0) 561 | self.reward_vec = 0.5 * R_distance + R_orientation * R_forward - 0.5 * R_retreat_slowdown - 0.5 # -0.5为了防止agent太猥琐,到处逗留 562 | # self.reward_vec = (0.5 * R_distance + R_orientation * R_forward - 0.5 * R_retreat_slowdown - 0.25) / 1.25 # Normalized reward, maybe better 563 | 564 | self.reward_vec[self.win_vec] = self.AWARD 565 | self.reward_vec[self.exceed_vec] = self.PUNISH 566 | self.reward_vec[self.collide_vec] = self.PUNISH 567 | 568 | def _Normalize(self, observation) -> torch.tensor: 569 | '''Normalize the raw observations (N,abs_state_dim) to relative observations (N,abs_state_dim-1)''' 570 | # 1) Normalize the orientation: 571 | xy_in_target = observation[:,0:2] - self.target_point # 小车在以target为原点的坐标系下的坐标, (N,2), 注意这里可能带噪声的 572 | beta = torch.arctan(xy_in_target[:,0] / xy_in_target[:,1]) + torch.pi / 2 + (xy_in_target[:,1] < 0) * torch.pi # (N,) 573 | observation[:, 2] = (beta - observation[:, 2]) / torch.pi 574 | observation[:, 2] += ( 2*(observation[:, 2] < -1) - 2*(observation[:, 2] > 1) ) # 修复1/2象限、3/4象限 575 | 576 | # 2) Stack d2target_now with observation[:,2:] 577 | new_obs = torch.hstack((self.d2target_now.unsqueeze(-1), observation[:,2:])) # (N,abs_state_dim-1), [D2T,alpha,Vlinear,Vangle,ld0,...ldn] 578 | 579 | # 3) Normalize new_obs: 580 | return new_obs/self.state_upperbound 581 | 582 | def _stack_A_to_S(self, current_a, real_a, observation) -> torch.tensor: 583 | """ 584 | transform action (N,) to action_state (N,2) and 585 | stack action_state (N,2) to the observation""" 586 | if self.action_type == 'Discrete': 587 | return torch.cat((self.a_state[current_a], self.a_state[real_a], observation), dim=1) # (N,2)+(N,2)->(N,abs_state_dim-1) => (N,state_dim) 588 | else: 589 | return torch.cat((current_a, real_a, observation), dim=1) # (N,2)+(N,2)->(N,abs_state_dim-1) => (N,state_dim) 590 | 591 | 592 | def _get_obs(self) -> torch.tensor: 593 | '''Return: Un-normalized and un-noised observation [dx, dy, theta, v_linear, v_angular, lidar_results(0), ..., lidar_results(n-1)] in shape (N,abs_state_dim) ''' 594 | # 1.障碍物运动: 595 | self.Obs_refresh_counter += 1 596 | if self.Obs_refresh_counter > self.Obs_refresh_interval: 597 | self._dynamic_obstacle_move() # L4动态障碍物 598 | self.Obs_refresh_counter = 1 599 | 600 | # 2.雷达扫描: 601 | self._ld_scan_vec() 602 | 603 | # 3.制作observation 604 | observation_vec = torch.concat((self.car_state, self.ld_result_grouped), dim=-1) #(N, 5) cat (N, grouped_ld_num) = (N, abs_state_dim) 605 | return observation_vec 606 | 607 | def _Discrete_Kinematic_model_vec(self, a): 608 | ''' V_now = K*V_previous + (1-K)*V_target 609 | Input: discrete action index, (N,) 610 | Output: [v_l, v_l, v_a], (N,3)''' 611 | self.car_state[:,3:5] = self.K * self.car_state[:,3:5] + (1-self.K)*self.a_space[self.arange_constant,a] # self.a_space[a] is (N,2) 612 | return torch.stack((self.car_state[:,3],self.car_state[:,3],self.car_state[:,4]),dim=1) # [v_l, v_l, v_a], (N,3) 613 | 614 | 615 | def _Continuous_Kinematic_model_vec(self, a): 616 | ''' V_now = K*V_previous + (1-K)*V_target 617 | Input: continuous action, (N,2) 618 | Output: [v_l, v_l, v_a], (N,3)''' 619 | self.car_state[:,3:5] = self.K * self.car_state[:,3:5] + (1-self.K)*self.continuous_scale*a # a.shape = (N,2) 620 | return torch.stack((self.car_state[:,3],self.car_state[:,3],self.car_state[:,4]),dim=1) # [v_l, v_l, v_a], (N,3) 621 | 622 | def step(self,current_a): 623 | """ 624 | When self.action_type=='Discrete', 'current_a' should be a vectorized discrete action of dim (N,) on self.dvc 625 | For self.action_type=='Continuous', 'current_a' should be a vectorized continuous action of dim (N,2) on self.dvc 626 | """ 627 | 628 | '''Domain randomization''' 629 | self.step_counter_vec += 1 630 | # domain randomization in a fixed frequency 631 | self.step_counter_DR += self.N 632 | if self.DR and (self.step_counter_DR > self.DR_freq): 633 | self.step_counter_DR = 0 634 | self._Domain_Randomization() 635 | 636 | '''Update car state: [dx, dy, theta, v_linear, v_angular]''' 637 | # control delay mechanism 638 | self.ctrl_pipe.append(current_a) # current_a is the action mapped by the current state 639 | real_a = self.ctrl_pipe.popleft() # real_a is the delayed action, 640 | 641 | # calculate and update the velocity of the car based on the delayed action and the Kinematic_model 642 | if self.action_type == "Discrete": velocity = self._Discrete_Kinematic_model_vec(real_a) # [v_l, v_l, v_a], (N,3) 643 | else: velocity = self._Continuous_Kinematic_model_vec(real_a) # [v_l, v_l, v_a], (N,3) 644 | 645 | # calculate and update the [dx,dy,orientation] of the car 646 | self.d2target_pre = (self.car_state[:, 0:2] - self.target_point).pow(2).sum(dim=-1).pow(0.5) # (N,), 执行动作前离目标点的距离 647 | self.car_state[:,0:3] += self.ctrl_interval * velocity * torch.stack((torch.cos(self.car_state[:,2]), 648 | -torch.sin(self.car_state[:,2]), 649 | torch.ones(self.N,device=self.dvc)), dim=1) 650 | self.d2target_now = (self.car_state[:, 0:2] - self.target_point).pow(2).sum(dim=-1).pow(0.5) # (N,), 执行动作后离目标点的距离,_reward_function和_Normalize会用 651 | if self.RSEO: self._robot_obstacle_move() # L2机器人障碍物 652 | 653 | # keep the orientation between [0,2pi] 654 | self.car_state[:,2] %= (2 * torch.pi) 655 | 656 | '''Update observation: observation -> add_noise -> normalize -> stack[cA,rA,O]''' 657 | # get next obervation 658 | observation_vec = self._get_obs() 659 | 660 | # calculate reward, dw, tr, done signals 661 | self._reward_function(current_a) 662 | 663 | # add noise to unormalized state: 664 | if self.noise: 665 | observation_vec += self.noise_magnitude*self._random_noise(1, (self.N,self.absolute_state_dim), self.dvc) # (N, 23) 666 | 667 | # Normalize the observation: 668 | # absolute coordinates will be transformed to relative distance to target 669 | # absolute orientation will be transformed to relative orientation 670 | relative_observation_vec = self._Normalize(observation_vec) # (N,22) 671 | 672 | # stack action_state to relative_observation_vec 673 | act_relative_observation_vec = self._stack_A_to_S(current_a, real_a, relative_observation_vec) # (N,22) -> (N,26) 674 | 675 | '''Render and AutoReset''' 676 | # render the current frame 677 | if self.render_mode == "human": self._render_frame() 678 | 679 | # reset some of the envs based on the done_vec signal 680 | self._AutoReset() 681 | 682 | return act_relative_observation_vec, \ 683 | self.reward_vec.clone(), \ 684 | self.dw_vec.clone(), \ 685 | self.tr_vec.clone(), \ 686 | dict(abs_car_state=self.car_state.clone(), step_cnt = self.step_counter_vec) 687 | 688 | def occupied_grid_map(self) -> np.ndarray: 689 | """Get the occupied grid map (render_mode must be "human") 690 | The ogm can be rendered via 'plt.imshow(self.ogm)' """ 691 | return self.ogm # (window_size, window_size, 3) 692 | 693 | def _render_frame(self): 694 | if self.window is None and self.render_mode == "human": 695 | pygame.init() 696 | pygame.display.init() 697 | self.window = pygame.display.set_mode((self.window_size , self.window_size )) 698 | if self.clock is None and self.render_mode == "human": 699 | self.clock = pygame.time.Clock() 700 | 701 | # init canvas 702 | if self.canvas is None : 703 | self.canvas = pygame.Surface((self.window_size , self.window_size )) 704 | 705 | # draw obstacles: 706 | self.obs_canvas_torch.fill_(255) 707 | if self.MapObs: self.obs_canvas_torch[self.vec_map_obs_P_shaped[self.ri,:,0],self.vec_map_obs_P_shaped[self.ri,:,1]] = 0 708 | self.obs_canvas_torch[self.vec_static_obs_P_shaped[self.ri,:,0],self.vec_static_obs_P_shaped[self.ri,:,1]] = 0 709 | self.obs_canvas_torch[self.vec_dynamic_obs_P_shaped[self.ri,:,0],self.vec_dynamic_obs_P_shaped[self.ri,:,1]] = 105 #101, 104, 105 710 | obstacles = pygame.surfarray.make_surface(self.obs_canvas_torch.numpy()) 711 | self.canvas.blit(obstacles, self.canvas.get_rect()) 712 | self.ogm = np.transpose(np.array(pygame.surfarray.pixels3d(self.canvas)), axes=(1, 0, 2)) # occupied grid maps 713 | 714 | # 绘制辅助显示内容: 715 | if self.draw_auxiliary: 716 | # 静态障碍物生成区域: 717 | for i in range(self.area**2): 718 | pygame.draw.rect(self.canvas, (128, 128, 128), (self.rect_offsets[i,0], self.rect_offsets[i,1], self.d_rect, self.d_rect),width=2) 719 | 720 | # prepare data for plot 721 | if self.show_ld: 722 | ld_result = self.ld_scan_result.cpu().clone() # (N,ld_num) 723 | ld_real_end_gd = self._world_2_grid(self.car_state[:,0:2].cpu().unsqueeze(1) + ld_result.unsqueeze(-1) * self.ld_vectors_wd.cpu()).numpy() #(N,ld_num,2) 724 | target_point_np = self.target_point.cpu().numpy() 725 | car_center_np = self._world_2_grid(self.car_state[:,0:2]).cpu().numpy() 726 | car_head = self.car_state[:,0:2] + self.car_radius * torch.stack([torch.cos(self.car_state[:,2]), -torch.sin(self.car_state[:,2])],dim=1) 727 | car_head_np = self._world_2_grid(car_head).cpu().numpy() 728 | 729 | # plot 730 | for n in range(self.N): 731 | # draw target area 732 | pygame.draw.circle(self.canvas, self.COLORs[n], target_point_np[n], self.target_area, 4) 733 | 734 | # draw lidar rays on canvas 735 | if self.show_ld: 736 | for i in range(self.ld_num): 737 | e = 255 * ld_result[n,i] / self.ld_d_range 738 | pygame.draw.aaline(self.canvas, (255 - e, 0, e), car_center_np[n], ld_real_end_gd[n,i]) 739 | 740 | # draw car to target lines: 741 | pygame.draw.aaline(self.canvas, self.COLORs[n], car_center_np[n], target_point_np[n]) 742 | 743 | #draw collision threshold on canvas 744 | pygame.draw.circle( 745 | self.canvas, 746 | self.COLORs[n], 747 | car_center_np[n], 748 | self.collision_trsd, 749 | ) 750 | 751 | #draw robot on canvas 752 | pygame.draw.circle( 753 | self.canvas, 754 | (200, 128, 250), 755 | car_center_np[n], 756 | self.car_radius, 757 | ) 758 | 759 | # draw robot orientation on canvas 760 | pygame.draw.line( 761 | self.canvas, 762 | (0, 255, 255), 763 | car_center_np[n], 764 | car_head_np[n], 765 | width=2 766 | ) 767 | 768 | if self.render_mode == "human": 769 | # The following line copies our drawings from `canvas` to the visible window 770 | self.window.blit(self.canvas, self.canvas.get_rect()) 771 | pygame.event.pump() 772 | pygame.display.update() 773 | 774 | if self.render_speed == 'real': 775 | self.clock.tick(int(1 / self.render_rate)) 776 | elif self.render_speed == 'fast': 777 | self.clock.tick(0) 778 | elif self.render_speed == 'slow': 779 | self.clock.tick(30) 780 | else: 781 | print('Wrong Render Speed, only "real"; "fast"; "slow" is acceptable.') 782 | 783 | else: #rgb_array 784 | return np.transpose(np.array(pygame.surfarray.pixels3d(self.canvas)), axes=(1, 0, 2)) 785 | 786 | def close(self): 787 | if self.window is not None: 788 | pygame.display.quit() 789 | pygame.quit() 790 | 791 | """--------------------------------------------------------------------------------------------------------""" 792 | 793 | def str2bool(v): 794 | '''Fix the bool BUG for argparse: transfer string to bool''' 795 | if isinstance(v, bool): return v 796 | if v.lower() in ('yes', 'True','true','TRUE', 't', 'y', '1', 'T'): return True 797 | elif v.lower() in ('no', 'False','false','FALSE', 'f', 'n', '0', 'F'): return False 798 | else: print('Wrong Input Type!') 799 | 800 | """--------------------------------------------------------------------------------------------------------""" 801 | 802 | --------------------------------------------------------------------------------