├── README.md └── skychainEnv ├── __init__.py └── myEnv.py /README.md: -------------------------------------------------------------------------------- 1 | ## Skychain 2 | The simulation of paper *[SkyChain: A Deep Reinforcement Learning-Empowered Dynamic Blockchain Sharding System](https://dl.acm.org/doi/abs/10.1145/3404397.3404460?casa_token=cORVUGinc7MAAAAA%3ANNmTSDChQXOvPDEf6F72O83Eb8lBqlF3Od8S2_Ve66foL0Vb3_wGxEE1aIWaOoUNi1g-DvmlgArbeA)*. 3 | 4 | ## Installation 5 | 6 | 1. install [openai gym](https://github.com/openai/gym) 7 | 8 | 2. install [openai baselines](https://github.com/openai/baselines) 9 | 10 | ## Register environment 11 | 12 | 1. Create a new folder ```myEnvDir``` in ```$gym_path/gym/envs``` 13 | 14 | 2. Copy ```./skychainEnv/*``` to ```$myEnvDir``` 15 | 16 | 3. Register the environment in ```$gym_path/gym/envs/__init__.py```, adding the following codes: 17 | 18 | ``` 19 | register( 20 | id="skychain-env", 21 | entry_point="gym.envs.myEnvDir.myEnv:myEnvClass", 22 | max_episode_steps=1000, 23 | ) 24 | ``` 25 | 26 | ## Test 27 | 28 | Now, you can test the codes by running 29 | ``` 30 | python -m baselines.run --alg=ddpg --env=skychain-env --num_timesteps=1e6 --num_layers=2 --num_hidden=300 --nsteps=100 --layer_norm=True 31 | ``` -------------------------------------------------------------------------------- /skychainEnv/__init__.py: -------------------------------------------------------------------------------- 1 | from gym.envs.myEnvDir.myEnv import myEnvClass -------------------------------------------------------------------------------- /skychainEnv/myEnv.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import math 3 | from scipy.special import comb, perm 4 | import gym 5 | from gym.spaces import Box 6 | import copy 7 | import time 8 | 9 | '''state space''' 10 | node_number = 100000 11 | consensus_node_number = 8000 12 | tx_queue = 2e8 13 | hash_rate = 300 #kGH/s 14 | 15 | '''action sapce''' 16 | epoch_len = 10.0 17 | shard_size = 200 18 | MAX_SHARD_SIZE = 500 19 | MAX_EPOCH_LEN = 10000 20 | shard_number = 20 21 | MAX_SHARD_NUMBER = 1 22 | MAX_BLOCK_SIZE = 10*np.power(2,20) 23 | 24 | fix_epoch = np.ceil(MAX_EPOCH_LEN*0.1) 25 | fix_shard = np.power(2, np.round(0.6 * 10/2) + 2) 26 | fix_block = np.ceil(MAX_BLOCK_SIZE*0.1) 27 | fix_pow = 1 28 | test_type = 1 #1:our, 2:fixe epoch, 3:fix shard, 4:fix block 29 | 30 | '''poisson''' 31 | tx_speed = 100000 32 | node_speed = 100 33 | poisson_size = 100000000 34 | 35 | '''speed''' 36 | trans_rate = 1 * np.power(2,20) #1MBps 37 | '''size''' 38 | header_size = 548 39 | '''time''' 40 | route_trans_time = 0.0003 41 | valiate_time = 0.1 42 | add_block_time = 0.2 43 | randomness_time = 20 44 | state_block_time = 1 45 | verification_time = 1 46 | shuffle_time = 100.0 47 | 48 | 49 | '''others''' 50 | gamma = 1/2 51 | security_parameter = 5 52 | no_corrupt_bound = 0.9 53 | delay_latency_bound = 200 54 | block_size = 1024 55 | tx_size = 512 56 | MAX_TIME = 2000000 57 | MAX_TX_QUEUE = 2e13 58 | system_tolerance = 1/4 59 | shard_tolerance = 1/3 60 | corrupt_ablity = 1.4e-6 61 | no_safe_punishment = 0 62 | confirmation_latency_punishment = -1000 63 | queue_full_punishment = -1000 64 | illegal_punishment = -1000 65 | assign_punishment = 0 66 | 67 | #observation_space = Box(low=np.array([1, 1]),high=np.array([MAX_TX_QUEUE, node_number])) 68 | observation_space = Box(low=0, high=1, shape = (26, )) #改为二进制输入 1+15+10 69 | action_space = Box(low=np.array([0, 0, 0]),high=np.array([MAX_EPOCH_LEN, MAX_SHARD_NUMBER, MAX_BLOCK_SIZE])) #ours 70 | #action_space = Box(low=np.array([0, 0]),high=np.array([MAX_SHARD_NUMBER, MAX_BLOCK_SIZE])) #fix epoch lenth 71 | #action_space = Box(low=np.array([0, 0]),high=np.array([MAX_EPOCH_LEN, MAX_BLOCK_SIZE])) #fix shards number 72 | #action_space = Box(low=np.array([0, 0]),high=np.array([MAX_EPOCH_LEN, MAX_SHARD_NUMBER])) #fix block size 73 | 74 | 75 | 76 | class myEnvClass(gym.Env): 77 | def __init__(self): 78 | self.metadata = {'render.modes': []} 79 | self.reward_range = (-float('inf'), float('inf')) 80 | self.spec = None 81 | 82 | # Set these in ALL subclasses 83 | self.action_space = action_space 84 | self.observation_space = observation_space 85 | 86 | #initialize state and action 87 | self.node_number = node_number 88 | self.consensus_node_number = consensus_node_number 89 | self.tx_queue = tx_queue 90 | self.hash_rate = hash_rate 91 | 92 | self.observation = np.array([self.tx_queue,self.consensus_node_number, self.hash_rate]) 93 | self.epoch_len = epoch_len 94 | self.shard_size = shard_size 95 | self.hash_rate = hash_rate 96 | #self.action = [self.epoch_len, self.shard_size] 97 | 98 | #initialize hyper-parameter 99 | self.delay_latency_bound = delay_latency_bound 100 | self.security_parameter = security_parameter 101 | self.no_corrupt_bound = no_corrupt_bound 102 | self.block_size = block_size 103 | self.tx_size = tx_size 104 | self.randomness_time = randomness_time 105 | self.state_block_time = state_block_time 106 | self.verification_time = verification_time 107 | self.shuffle_time = shuffle_time 108 | self.MAX_TIME = MAX_TIME 109 | self.step_time = 0. 110 | self.MAX_SHARD_SIZE = MAX_SHARD_SIZE 111 | self.MAX_EPOCH_LEN = MAX_EPOCH_LEN 112 | self.trans_rate = trans_rate 113 | self.header_size = header_size 114 | self.route_trans_time = route_trans_time 115 | self.valiate_time = valiate_time 116 | self.add_block_time = add_block_time 117 | self.tx_poisson = np.random.poisson(lam = tx_speed, size = poisson_size) 118 | self.gamma = gamma 119 | # self.node_poisson = np.random.poisson(lam = node_speed, size = poisson_size) 120 | # temp_index = np.random.randint(0, 2, poisson_size) 121 | # old_int = 0 122 | # new_int = -1 123 | # temp_i = (temp_index == old_int) 124 | # temp_index[temp_i] = new_int 125 | # self.node_poisson = self.node_poisson*temp_index 126 | self.node_poisson = np.random.normal(loc=10, scale=100, size = poisson_size) 127 | self.hashrate_in = np.random.normal(loc=0, scale=20, size = poisson_size) 128 | self.system_tolerance = system_tolerance 129 | self.shard_tolerance = shard_tolerance 130 | self.step_number = 0 131 | 132 | '''fix''' 133 | self.fix_epoch = fix_epoch 134 | self.fix_shard = fix_shard 135 | self.fix_block = fix_block 136 | self.fix_pow = fix_pow 137 | 138 | 139 | def reset(self): 140 | """Resets the state of the environment and returns an initial observation. 141 | Returns: 142 | observation (object): the initial observation. 143 | """ 144 | #initialize state and action 145 | self.node_number = node_number 146 | self.consensus_node_number = consensus_node_number 147 | self.tx_queue = tx_queue 148 | self.observation = np.array([self.tx_queue,self.consensus_node_number, self.hash_rate]) 149 | self.epoch_len = epoch_len 150 | self.shard_size = shard_size 151 | self.step_time = 0. 152 | self.MAX_TIME = MAX_TIME 153 | self.MAX_EPOCH_LEN = MAX_EPOCH_LEN 154 | self.MAX_SHARD_SIZE = MAX_SHARD_SIZE 155 | self.trans_rate = trans_rate 156 | self.header_size = header_size 157 | self.route_trans_time = route_trans_time 158 | self.valiate_time = valiate_time 159 | self.add_block_time = add_block_time 160 | #self.tx_poisson = np.random.poisson(lam = tx_speed, size = poisson_size) 161 | # self.node_poisson = np.random.normal(loc=10, scale=100, size = poisson_size) 162 | #self.node_poisson 163 | #raise NotImplementedError 164 | self.gamma = gamma 165 | self.system_tolerance = system_tolerance 166 | self.shard_tolerance = shard_tolerance 167 | self.step_number = 0 168 | 169 | obs_len = 15 170 | obs_len_2 = 10 171 | obs_temp = int(np.round(self.observation[1])) 172 | obs_computing_difficulty = int(np.round(self.observation[2])) 173 | bi_obs_str = '{:015b}'.format(obs_temp) 174 | bi_obs_2_str = '{:010b}'.format(obs_computing_difficulty) 175 | bi_obs = [] 176 | bi_obs_2 = [] 177 | for i in range(obs_len): 178 | bi_obs.append(int(bi_obs_str[i])) 179 | for i in range(obs_len_2): 180 | bi_obs_2.append(int(bi_obs_2_str[i])) 181 | return np.hstack(([self.observation[0]], bi_obs, bi_obs_2)) 182 | 183 | def step(self, action): 184 | # if test_type == 1 or test_type == 4: 185 | # action[1] = int(action[1] * 10/2) + 2 186 | # action[1] = np.power(2, action[1]) 187 | ## action[1] = 64 188 | # elif test_type == 2: 189 | # action[0] = int(action[0] * 10/2) + 2 190 | # action[0] = np.power(2, action[0]) 191 | # 192 | # for i in range(len(action)): 193 | # action[i] = np.ceil(action[i]) 194 | # if test_type == 1: 195 | # self.fix_epoch = action[0] 196 | # self.fix_shard = action[1] 197 | # self.fix_block = action[2] 198 | # elif test_type == 2: 199 | # self.fix_shard = action[0] 200 | # self.fix_block = action[1] 201 | # elif test_type == 3: 202 | # self.fix_epoch = action[0] 203 | # self.fix_block = action[1] 204 | # elif test_type == 4: 205 | # self.fix_epoch = action[0] 206 | # self.fix_shard = action[1] 207 | 208 | action[1] = int(action[1] * 8) 209 | action[1] = np.power(2, action[1]) 210 | self.fix_epoch = action[0] 211 | self.fix_shard = action[1] 212 | self.fix_block = action[2] 213 | self.fix_pow = 10 214 | 215 | obs_len = 15 216 | obs_len_2 = 10 217 | # print("step_time:", self.step_time) 218 | # print("node_poisson:", self.node_poisson[:10]) 219 | # print("action:", [self.fix_epoch, self.fix_shard, self.fix_block]) 220 | info = {'observation':self.observation, 'action':action, 'step_time':self.step_time} 221 | # print("obs:", self.observation) 222 | # print("before action:", action) 223 | 224 | 225 | if self.fix_epoch >= self.MAX_TIME or self.fix_shard >= self.node_number: 226 | print("too large...") 227 | self.step_time += 100 228 | reward = illegal_punishment 229 | info['step_time'] = self.step_time 230 | new_obs = self.observation.copy() 231 | obs_temp = int(np.round(new_obs[1])) 232 | obs_computing_difficulty = int(np.round(new_obs[2])) 233 | bi_obs_str = '{:015b}'.format(obs_temp) 234 | bi_obs_2_str = '{:010b}'.format(obs_computing_difficulty) 235 | bi_obs = [] 236 | bi_obs_2 = [] 237 | for i in range(obs_len): 238 | bi_obs.append(int(bi_obs_str[i])) 239 | for i in range(obs_len_2): 240 | bi_obs_2.append(int(bi_obs_2_str[i])) 241 | return np.hstack(([self.observation[0]], bi_obs, bi_obs_2)), reward, False, info 242 | self.step_number += 1 243 | reward = 0 244 | done= 0 245 | illegal = False 246 | m = math.ceil(self.observation[1]/self.fix_shard) 247 | new_txs = self._compute_new_txs(self.step_time, self.step_time + self.fix_epoch) 248 | ref_time = self.fix_epoch-self.randomness_time-self.shuffle_time-self.state_block_time-self.verification_time 249 | tx_reduct = math.ceil(self.fix_shard*(ref_time/self._consensus_time(action)*self.fix_block/self.tx_size*(1-self._compute_redundancy_txs(action)))) 250 | # print("new_txs:", new_txs) 251 | # print("ref_time/self._consensus_time(action):", ref_time/self._consensus_time(action)) 252 | # print("self.fix_block/self.tx_size:", self.fix_block/self.tx_size) 253 | # print("tx_reduct:", tx_reduct) 254 | if self._shard_security_hyp(action) and self._corrupt_security(action) and self._consensus_time(action) < ref_time/6 and self._assign_security(action, self.gamma): #self._shard_security(action) 255 | #print("shard_safe") 256 | if self.observation[0] + new_txs < tx_reduct and self.observation[0] + new_txs < MAX_TX_QUEUE: 257 | print("no tx...") 258 | # time.sleep(5) 259 | reward = self.observation[0] + new_txs 260 | # print("reward:", reward) 261 | # print("self.observation[0]:", self.observation[0]) 262 | # print("new_txs:", new_txs) 263 | self.observation[0] = 1 264 | self.observation[2] = np.around(hash_rate+self._compute_hashrate(self.run_time)) 265 | # done=1 266 | new_obs = self.observation.copy() 267 | # print("no tx new_obs:", new_obs) 268 | return new_obs, reward, done, info 269 | 270 | else: 271 | self.observation[0] += (new_txs-tx_reduct) 272 | if self.observation[0] > MAX_TX_QUEUE: 273 | self.observation[0] = MAX_TX_QUEUE 274 | reward = tx_reduct 275 | # if self.observation[0] > MAX_TX_QUEUE: 276 | ## print("queue_full!!!") 277 | # self.observation[0] = MAX_TX_QUEUE 278 | # reward = queue_full_punishment * action[0] 279 | # #reward = 0 280 | # else: 281 | # reward = tx_reduct 282 | else: 283 | #print("punish...") 284 | # print("action:", action) 285 | if self._shard_security_hyp(action) == False: 286 | # print("shard_unsafe") 287 | reward = assign_punishment 288 | # a=1 289 | elif self._corrupt_security(action) == False: 290 | # a=1 291 | # print("corrpt_unsafe") 292 | reward = assign_punishment 293 | elif self._assign_security(action, self.gamma) == False: 294 | # assign_punishment 295 | # print("assign_unsafe") 296 | reward = assign_punishment 297 | else: 298 | reward = assign_punishment 299 | # a=1 300 | # confirmation_latency_punishment 301 | # print("latency too large...") 302 | # print("unsafe") 303 | self.observation[0] += new_txs 304 | # reward = -np.log2(-no_safe_punishment * action[0]) 305 | # print("-no_safe_punishment * action[0]:", -no_safe_punishment * action[0]) 306 | 307 | if self.observation[0] + new_txs > MAX_TX_QUEUE: 308 | self.observation[0] = MAX_TX_QUEUE 309 | self.step_time += self.fix_epoch 310 | new_node = self._compute_new_node(self.step_time, self.step_time + self.fix_epoch) 311 | self.observation[1] += new_node 312 | #self.observation[1] += (action[0]*node_speed-node_reduct) 313 | #done = False 314 | 315 | # if self.step_time > self.MAX_TIME: 316 | # done = 1 317 | if self.observation[1] <= 0: 318 | self.observation[1] = 1 319 | #done = True 320 | if self.observation[1] >= self.node_number: 321 | self.observation[1] = self.node_number 322 | new_obs = self.observation.copy() 323 | self.observation[2] = np.around(hash_rate+self._compute_hashrate(self.step_time)) 324 | # print("reward:", reward/action[0]) 325 | # print("step_next_state:", self.observation) 326 | # print("action:", action) 327 | # print("info_step_time:", info['step_time']) 328 | info['step_time'] = self.step_time 329 | obs_temp = int(np.round(new_obs[1])) 330 | obs_computing_difficulty = int(np.round(new_obs[2])) 331 | bi_obs_str = '{:015b}'.format(obs_temp) 332 | bi_obs_2_str = '{:010b}'.format(obs_computing_difficulty) 333 | bi_obs = [] 334 | bi_obs_2 = [] 335 | for i in range(obs_len): 336 | bi_obs.append(int(bi_obs_str[i])) 337 | for i in range(obs_len_2): 338 | bi_obs_2.append(int(bi_obs_2_str[i])) 339 | return np.hstack(([new_obs[0]], bi_obs, bi_obs_2)), reward, done, info 340 | # return new_obs, reward, done, info 341 | 342 | 343 | def _get_step_time(self): 344 | return self.step_time 345 | 346 | 347 | def _shard_security_bio(self, action): 348 | pro = 0.0 349 | m = math.ceil(self.observation[1]/action[1]) 350 | #print("corrupt node number ", int(self.observation[1]*self.tolerance)) 351 | #print("shard size", int(cur_act[1])) 352 | for x in range(math.floor(m*self.shard_tolerance), m+1): 353 | pro += (comb(m, x)*math.pow(self.system_tolerance, x)*math.pow(1-self.system_tolerance, m-x)) 354 | #print("pro:", pro) 355 | if math.isnan(pro) or math.isinf(pro): 356 | return True 357 | 358 | return pro < math.pow(2, -self.security_parameter) 359 | 360 | def _shard_security_hyp(self, action): 361 | pro = 0.0 362 | m = math.ceil(self.observation[1]/self.fix_shard) 363 | # print("mmmmm:",m) 364 | for x in range(math.floor(m * self.shard_tolerance), m+1, 1): 365 | # test = comb(self.observation[1], m) 366 | # temp = comb(np.floor(self.observation[1]*self.system_tolerance), x, exact=True)*comb(np.floor(self.observation[1]-self.observation[1]*self.system_tolerance,), m-x,exact=True)/comb(np.floor(self.observation[1]), m, exact=True) 367 | temp = comb(np.floor(self.observation[1]*self.system_tolerance), x)*comb(np.floor(self.observation[1]-self.observation[1]*self.system_tolerance,), m-x)/comb(np.floor(self.observation[1]), m) 368 | if math.isnan(temp) : 369 | continue 370 | else: 371 | pro += temp 372 | # print("comb test:", test) 373 | # print("comb temp:", temp) 374 | if math.isnan(pro) or math.isinf(pro): 375 | return True 376 | # print("pro:", pro) 377 | return pro < math.pow(2, -self.security_parameter) 378 | 379 | def _corrupt_security(self, action): 380 | 381 | c_t = self.fix_epoch-self.randomness_time-self.shuffle_time 382 | n_c = self.observation[1] * self.system_tolerance * corrupt_ablity * c_t 383 | # print("n_c", n_c) 384 | # print("self.observation[1]:", self.observation[1]) 385 | # print("action[1]:", action[1]) 386 | # 387 | # print("corrupt number", math.ceil(self.observation[1]/action[1] * self.shard_tolerance)) 388 | return math.ceil(n_c) <= math.ceil(self.observation[1]/self.fix_shard * self.shard_tolerance) 389 | 390 | 391 | def _consensus_time(self, action): 392 | #print("shard-cross-time:", 0.01*math.log(self.observation[1], 2)) 393 | #print("consensus_time:", 0.01*(math.log(self.observation[1], 2)+cur_act[1]*cur_act[1]/self.block_size)) 394 | consensus_t = 3*self.valiate_time + self.add_block_time 395 | gossip_t = (self.fix_block / self.trans_rate) * math.log(self.observation[1]/self.fix_shard ,2) + 2 * (self.header_size / self.trans_rate)* math.log(self.observation[1]/self.fix_shard ,2) 396 | new_tx_consensus_t = consensus_t + gossip_t 397 | # print("new_tx_consensus_t:",new_tx_consensus_t) 398 | total_time = (self.route_trans_time*math.log(self.fix_shard+1, 2) + new_tx_consensus_t) 399 | # print("cross_shard_v time:", self.trans_time*math.log(self.observation[1]/action[1]+1, 2)) 400 | # print("one tx time:", total_time) 401 | # print("total_time:", total_time) 402 | return total_time 403 | 404 | def _compute_new_txs(self, start_time, end_time): 405 | start_time = int(start_time) 406 | end_time = int(end_time) 407 | # if end_time > self.MAX_TIME: 408 | # end_time = self.MAX_TIME 409 | new_txs = 0 410 | for i in range(start_time, end_time): 411 | new_txs += self.tx_poisson[i] 412 | return new_txs 413 | 414 | def _compute_new_node(self, start_time, end_time): 415 | start_time = int(start_time) 416 | end_time = int(end_time) 417 | # new_node = 0 418 | # for i in range(start_time, end_time): 419 | # new_node += (self.node_poisson[i]) 420 | # new_node += (self.node_poisson[end_time]) 421 | # print('node_change:', self.node_poisson[end_time]) 422 | return self.node_poisson[end_time] 423 | # return new_node 424 | 425 | def _compute_redundancy_txs(self, action): 426 | k = math.ceil(self.fix_shard) 427 | rebundance_tx = 1 - math.pow(1/k, 2) 428 | # print("rebundance_tx/(rebundance_tx+1):", rebundance_tx/(rebundance_tx+1)) 429 | return rebundance_tx/(rebundance_tx+1) 430 | def _compute_hashrate(self, end_time): 431 | end_time = int(end_time) 432 | #return 0 433 | #print("self.hashrate_in[end_time]:", self.hashrate_in[end_time]) 434 | if self.hashrate_in[end_time] + hash_rate <= 0: 435 | return 1 436 | else: 437 | return self.hashrate_in[end_time] 438 | 439 | def _assign_security(self, action, _gamma): 440 | pow_times = np.around(self.observation[2]/self.fix_pow) 441 | return pow_times <= np.around(_gamma*self.fix_shard) and pow_times > 0 442 | 443 | def __get_node_poisson__(self): 444 | return self.node_poisson 445 | --------------------------------------------------------------------------------