├── .gitignore ├── README.md ├── never_overwrite.py ├── plot_utils.py ├── rl_ray_ppo.py └── yaml_logging.py /.gitignore: -------------------------------------------------------------------------------- 1 | *.png 2 | *.jpg 3 | *.tiff 4 | # Byte-compiled / optimized / DLL files 5 | __pycache__/ 6 | *.py[cod] 7 | *$py.class 8 | 9 | # C extensions 10 | *.so 11 | 12 | # Distribution / packaging 13 | .Python 14 | build/ 15 | develop-eggs/ 16 | dist/ 17 | downloads/ 18 | eggs/ 19 | .eggs/ 20 | lib/ 21 | lib64/ 22 | parts/ 23 | sdist/ 24 | var/ 25 | wheels/ 26 | share/python-wheels/ 27 | *.egg-info/ 28 | .installed.cfg 29 | *.egg 30 | MANIFEST 31 | 32 | # PyInstaller 33 | # Usually these files are written by a python script from a template 34 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 35 | *.manifest 36 | *.spec 37 | 38 | # Installer logs 39 | pip-log.txt 40 | pip-delete-this-directory.txt 41 | 42 | # Unit test / coverage reports 43 | htmlcov/ 44 | .tox/ 45 | .nox/ 46 | .coverage 47 | .coverage.* 48 | .cache 49 | nosetests.xml 50 | coverage.xml 51 | *.cover 52 | *.py,cover 53 | .hypothesis/ 54 | .pytest_cache/ 55 | cover/ 56 | 57 | # Translations 58 | *.mo 59 | *.pot 60 | 61 | # Django stuff: 62 | *.log 63 | local_settings.py 64 | db.sqlite3 65 | db.sqlite3-journal 66 | 67 | # Flask stuff: 68 | instance/ 69 | .webassets-cache 70 | 71 | # Scrapy stuff: 72 | .scrapy 73 | 74 | # Sphinx documentation 75 | docs/_build/ 76 | 77 | # PyBuilder 78 | .pybuilder/ 79 | target/ 80 | 81 | # Jupyter Notebook 82 | .ipynb_checkpoints 83 | 84 | # IPython 85 | profile_default/ 86 | ipython_config.py 87 | 88 | # pyenv 89 | # For a library or package, you might want to ignore these files since the code is 90 | # intended to run in multiple environments; otherwise, check them in: 91 | # .python-version 92 | 93 | # pipenv 94 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 95 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 96 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 97 | # install all needed dependencies. 98 | #Pipfile.lock 99 | 100 | # poetry 101 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 102 | # This is especially recommended for binary packages to ensure reproducibility, and is more 103 | # commonly ignored for libraries. 104 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 105 | #poetry.lock 106 | 107 | # pdm 108 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 109 | #pdm.lock 110 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 111 | # in version control. 112 | # https://pdm.fming.dev/#use-with-ide 113 | .pdm.toml 114 | 115 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 116 | __pypackages__/ 117 | 118 | # Celery stuff 119 | celerybeat-schedule 120 | celerybeat.pid 121 | 122 | # SageMath parsed files 123 | *.sage.py 124 | 125 | # Environments 126 | .env 127 | .venv 128 | env/ 129 | venv/ 130 | ENV/ 131 | env.bak/ 132 | venv.bak/ 133 | 134 | # Spyder project settings 135 | .spyderproject 136 | .spyproject 137 | 138 | # Rope project settings 139 | .ropeproject 140 | 141 | # mkdocs documentation 142 | /site 143 | 144 | # mypy 145 | .mypy_cache/ 146 | .dmypy.json 147 | dmypy.json 148 | 149 | # Pyre type checker 150 | .pyre/ 151 | 152 | # pytype static type analyzer 153 | .pytype/ 154 | 155 | # Cython debug symbols 156 | cython_debug/ 157 | 158 | # PyCharm 159 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 160 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 161 | # and can be added to the global gitignore or merged into this file. For a more nuclear 162 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 163 | #.idea/ 164 | 165 | *.backup 166 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # 深度学习模板代码 2 | 3 | * `yaml_logging.py`: 使用了yaml保存日志(yaml保存日志的好处是可以随时从日志查看当前loss多高,配置的超参数,本次实验的代码,甚至能根据pid反查进程的cpu使用率等性能细节,后续处理实验数据时非常方便),并且能随时Ctrl + C停止训练并且不丢失数据 4 | 5 | * `rl_ray_ppo.py`: 使用Ray和pytorch实现了多卡异步并行PPO强化学习训练 6 | 7 | * `plot_utils.py`: 可以使用@packplot装饰器包装你画图的函数(函数参数需要包含保存图片路径filename),使得图片的元信息中包含可重新绘制此图片的代码和数据,防止你忘了数据在哪或者找不到画图的代码,都隐藏在图片的exif里面了 8 | 9 | * `never_overwrite.py`: 在代码第一行`import never_overwrite`即可避免你在实验过程中不小心写入了同名文件导致原数据丢失 -------------------------------------------------------------------------------- /never_overwrite.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | import builtins 4 | import uuid 5 | import pathlib 6 | import shutil 7 | 8 | original_open = builtins.open 9 | 10 | def safe_open(filename, mode='r', **kwargs): 11 | if 'w' in mode and os.path.isfile(filename) and shutil.disk_usage('/').free > 2**34: 12 | p = pathlib.Path(filename) 13 | suffix = f"{os.path.getctime(filename)}_{time.time()}_{uuid.uuid4().hex}" 14 | newname = str(p.with_stem(f"{p.stem}_{suffix}")) 15 | os.rename(filename, newname) 16 | else: 17 | newname = filename 18 | return original_open(filename, mode=mode, **kwargs) 19 | 20 | builtins.open = safe_open 21 | -------------------------------------------------------------------------------- /plot_utils.py: -------------------------------------------------------------------------------- 1 | import inspect 2 | from PIL import Image 3 | import matplotlib.pyplot as plt 4 | import dill 5 | from functools import wraps 6 | import numpy as np 7 | 8 | def packplot(plotf): 9 | @wraps(plotf) 10 | def decorated_plotf(*args, **kwargs): 11 | import dill 12 | import inspect 13 | from PIL import Image 14 | import os 15 | n_pos_args = len(args) 16 | all_arg_names = list(inspect.signature(plotf).parameters.keys()) 17 | pos_arg_names = all_arg_names[:n_pos_args] 18 | for i in range(n_pos_args): 19 | kwargs[pos_arg_names[i]] = args[i] 20 | assert 'filename' in kwargs 21 | filename = kwargs['filename'] 22 | 23 | ret = plotf(**kwargs) 24 | 25 | plot_function_name = plotf.__name__ 26 | try: 27 | plot_function_code = inspect.getsource(plotf) 28 | except OSError: 29 | print(f"Cannot find the code of function {plot_function_name}") 30 | return ret 31 | plot_function_code = clean_code(plot_function_code) 32 | 33 | img = Image.open(filename) 34 | img_packed = img.copy() 35 | img.close() 36 | os.remove(filename) 37 | 38 | packed_info = (kwargs, plot_function_code, plot_function_name) 39 | packed_bytes = dill.dumps(packed_info) 40 | img_packed.save(filename, exif=packed_bytes) 41 | return ret 42 | 43 | return decorated_plotf 44 | 45 | #def plot_random(x, y, filename, **kwargs): 46 | # plot_args = locals() 47 | # import dill 48 | # import inspect 49 | # from PIL import Image 50 | # 51 | # import matplotlib.pyplot as plt 52 | # plt.plot(x, y) 53 | # plt.show() 54 | # plt.savefig(filename) 55 | # 56 | # try: 57 | # plot_function_code = inspect.getsource(plot_random) 58 | # plot_function_name = 'plot_random' 59 | # img = Image.open(filename) 60 | # packed_info = (plot_args, plot_function_code, plot_function_name) 61 | # packed_bytes = dill.dumps(packed_info) 62 | # img.save(filename, exif=packed_bytes) 63 | # except OSError: 64 | # return 65 | # return 66 | 67 | def clean_code(code_str): 68 | # Split the code into lines 69 | lines = code_str.splitlines() 70 | common_indent = '' 71 | min_len = min(len(l) for l in lines) 72 | should_break = False 73 | for i in range(min_len): 74 | cs = set() 75 | for j in range(len(lines)): 76 | c = lines[j][i] 77 | if c in ['\t',' ']: 78 | cs.add(c) 79 | else: 80 | should_break = True 81 | if should_break: 82 | break 83 | if should_break: 84 | break 85 | if len(cs)==1: 86 | common_indent += list(cs)[0] 87 | 88 | # Strip leading and trailing whitespace from each line 89 | cleaned_lines = [line[len(common_indent):] for line in lines] 90 | 91 | # Join the cleaned lines back into a single string with newlines 92 | cleaned_code_str = '\n'.join(cleaned_lines) 93 | 94 | return cleaned_code_str 95 | 96 | def retrieve_plot(filename): 97 | img = Image.open(filename) 98 | img.load() 99 | packed_info = dill.loads(img.info['exif'][6:]) 100 | plot_args, plot_function_code, plot_function_name = packed_info 101 | #assert plot_function_code.startswith("@packplot") 102 | plot_function_code = plot_function_code.split("\n", 1)[1] 103 | filename = plot_args['filename'] 104 | plot_args['filename'] = f'tmp-{str(hash(filename))}' + filename 105 | exec(plot_function_code) 106 | exec(plot_function_name + '(**plot_args)') 107 | 108 | def retrieve_data(filename): 109 | img = Image.open(filename) 110 | img.load() 111 | packed_info = dill.loads(img.info['exif'][6:]) 112 | plot_args, plot_function_code, plot_function_name = packed_info 113 | return plot_args 114 | 115 | def retrieve_code(filename): 116 | img = Image.open(filename) 117 | img.load() 118 | packed_info = dill.loads(img.info['exif'][6:]) 119 | plot_args, plot_function_code, plot_function_name = packed_info 120 | return plot_function_code 121 | 122 | if __name__ == '__main__': 123 | @packplot 124 | def plot_test_dec(x, y, filename): 125 | plt.plot(x, y) 126 | plt.show() 127 | plt.savefig(filename) 128 | return 0 129 | 130 | x = np.random.rand(10) 131 | y = np.random.rand(10) 132 | #plot_random(x, y, 'test.png') 133 | plot_test_dec(x, y, 'test.png') 134 | retrieve_plot('test.png') 135 | -------------------------------------------------------------------------------- /rl_ray_ppo.py: -------------------------------------------------------------------------------- 1 | import ray 2 | import torch 3 | import torch.nn as nn 4 | import gym 5 | import numpy as np 6 | import time 7 | 8 | ray.init(num_gpus=1, local_mode=False) 9 | 10 | class AgentModel(nn.Module): 11 | def __init__(self, ctx_size, d_obs, d_act, low, high): 12 | super().__init__() 13 | self.flatten = nn.Flatten() 14 | self.act = nn.LeakyReLU(0.01) 15 | self.sig = nn.Sigmoid() 16 | self.tanh = nn.Tanh() 17 | self.d1 = nn.Linear(ctx_size * d_obs, 128) 18 | self.ln1 = nn.LayerNorm(128) 19 | self.ln2 = nn.LayerNorm(128) 20 | self.d2 = nn.Linear(128, 128) 21 | self.dv = nn.Linear(128, 1) 22 | self.dloc = nn.Linear(128, d_act) 23 | self.dcov = nn.Linear(128, d_act) 24 | self.w = torch.tensor((high - low)/2) 25 | self.b = torch.tensor((high + low)/2) 26 | 27 | def forward(self, obs): 28 | obs = self.flatten(obs) 29 | h = self.act(self.d1(obs)) 30 | h = self.ln1(h) 31 | h = self.act(self.d2(h)) 32 | h = self.ln2(h) 33 | aloc = self.w * self.tanh(self.dloc(h)) + self.b 34 | acov = self.sig(self.dcov(h)) 35 | a_dist = torch.distributions.Normal(aloc, acov) 36 | act = a_dist.sample() 37 | act_logprob = a_dist.log_prob(act) 38 | 39 | value = self.dv(h)/(1 - 0.95) 40 | return act, act_logprob, a_dist, value 41 | 42 | class ReplayMemory: 43 | def __init__(self, dtype): 44 | self.dtype = dtype 45 | self.clear() 46 | def add(self, obs, act, act_logprob, r, done, value): 47 | self.obs.append(obs) 48 | self.act.append(act) 49 | self.act_logprob.append(act_logprob) 50 | self.r.append(r) 51 | self.done.append(done) 52 | self.value.append(value) 53 | def clear(self): 54 | self.obs = [] 55 | self.act = [] 56 | self.act_logprob = [] 57 | self.r = [] 58 | self.done = [] 59 | self.value = [] 60 | def sample(self, ctx_size, batch_size, last_value): 61 | M = len(self.r) 62 | i = np.random.choice(range(ctx_size - 1, M), batch_size) 63 | 64 | returns, advantages = compute_advantages(self.r, self.done, self.value + [last_value]) 65 | returns = np.array(returns)[i] 66 | advantages = np.array(advantages)[i] 67 | 68 | obs_np = np.array(self.obs) 69 | act_np = np.array(self.act) 70 | act_logprob_np = np.array(self.act_logprob) 71 | obs_ctx = [] 72 | for ctx_idx in range(ctx_size - 1, -1, -1): 73 | obs_ctx.append(obs_np[i - ctx_idx]) 74 | obs_ctx = np.stack(obs_ctx, axis=1) 75 | 76 | obs_ctx = torch.tensor(obs_ctx, dtype=self.dtype) 77 | act = torch.tensor(act_np[i], dtype=self.dtype) 78 | act_logprob = torch.tensor(act_logprob_np[i], dtype=self.dtype) 79 | returns = torch.tensor(returns, dtype=self.dtype) 80 | advantages = torch.tensor(advantages, dtype=self.dtype) 81 | return obs_ctx, act, act_logprob, returns, advantages 82 | 83 | class Context: 84 | def __init__(self, ctx_size, d_obs, dtype): 85 | self.ctx_size = ctx_size 86 | self.d_obs = d_obs 87 | self.obs_ctx = [np.zeros(d_obs, dtype=np.float32) for i in range(ctx_size)] 88 | self.dtype = dtype 89 | self.normalizer = Normalizer((d_obs,), np.float32) 90 | def add(self, obs): 91 | normalized_obs = self.normalizer.normalize_obs(obs) 92 | self.obs_ctx.append(normalized_obs) 93 | self.obs_ctx.pop(0) 94 | def get(self): 95 | return torch.tensor(np.stack(self.obs_ctx, 0), dtype=self.dtype).unsqueeze(0) 96 | def reset(self): 97 | self.obs_ctx = [np.zeros(self.d_obs) for i in range(self.ctx_size)] 98 | def set_normalizer(self, normalizer): 99 | self.normalizer = normalizer 100 | 101 | class Normalizer: 102 | def __init__(self, obs_shape, dtype): 103 | self.dtype = dtype 104 | self.obs_shape = obs_shape 105 | self.sum_obs = np.zeros(obs_shape, dtype=dtype) 106 | self.sumsq_obs = np.zeros(obs_shape, dtype=dtype) + 1e-3 107 | self.cnt_obs = 1e-3 108 | 109 | self.sum_obs_collect = np.zeros(obs_shape, dtype=dtype) 110 | self.sumsq_obs_collect = np.zeros(obs_shape, dtype=dtype) + 1e-3 111 | self.cnt_obs_collect = 1e-3 112 | def aggregate_collection(self, normalizers): 113 | for normalizer in normalizers: 114 | self.sum_obs += normalizer.sum_obs_collect 115 | self.sumsq_obs += normalizer.sumsq_obs_collect 116 | self.cnt_obs += normalizer.cnt_obs_collect 117 | def mean_std_obs(self): 118 | mean_obs = self.sum_obs/self.cnt_obs 119 | std_obs = np.sqrt(self.sumsq_obs/self.cnt_obs - mean_obs**2) 120 | return mean_obs, std_obs 121 | def normalize_obs(self, obs): 122 | self.collect_obs(obs) 123 | mean_obs, std_obs = self.mean_std_obs() 124 | normalized_obs = (obs - mean_obs) / (1e-3 + std_obs) 125 | return normalized_obs 126 | def collect_obs(self, obs): 127 | self.sum_obs_collect = self.sum_obs_collect + obs 128 | self.sumsq_obs_collect = self.sumsq_obs_collect + obs**2 129 | self.cnt_obs_collect = self.cnt_obs_collect + 1 130 | #self.sum_obs_collect += obs # WHY isn't it working? 131 | #self.sumsq_obs_collect += obs**2 132 | #self.cnt_obs_collect += 1 133 | return 134 | 135 | def compute_advantages(rs, dones, values): 136 | gamma = 0.95 137 | lambd = 0.95 138 | advantages = [] 139 | returns = [] 140 | adv = 0 141 | for tt in range(len(rs)-1, -1, -1): 142 | m = 1 - int(dones[tt]) 143 | delta = rs[tt] + gamma * values[tt+1] * m - values[tt] 144 | adv = delta + gamma * lambd * m * adv 145 | R = adv + values[tt] 146 | advantages.append(adv) 147 | returns.append(R) 148 | advantages.reverse() 149 | returns.reverse() 150 | returns = torch.tensor(np.array(returns)) 151 | advantages = torch.tensor(np.array(advantages)) 152 | return returns, advantages 153 | 154 | def ppo_loss(pred_act_dist, pred_value, acts, act_logprobs, returns, advantages): 155 | act_logprobs = act_logprobs.view(-1, 1) 156 | pred_value = pred_value.view(-1, 1) 157 | returns = returns.view(-1, 1) 158 | advantages = advantages.view(-1, 1) 159 | ratio = torch.exp(pred_act_dist.log_prob(acts) - act_logprobs) 160 | clipped_ratio = torch.clip(ratio, 1-0.2, 1+0.2) 161 | #g = torch.where(A>=0, (1+0.2)*advantages, (1-0.2)*advantages) 162 | #loss_policy = torch.mean(torch.min(ratio*advantages, g)) 163 | loss_policy = -torch.mean(torch.min(ratio * advantages, clipped_ratio * advantages)) 164 | loss_value = torch.mean( (pred_value - returns)**2 ) 165 | loss = loss_policy + loss_value 166 | return loss 167 | 168 | @ray.remote(num_gpus=0.5) 169 | class Worker: 170 | def __init__(self, params): 171 | #self.n_batches = params['n_batches'] 172 | self.batch_size = params['batch_size'] 173 | self.ctx_size = params['ctx_size'] 174 | self.dtype = params['dtype'] 175 | self.env = gym.make('Pendulum-v1') 176 | self.d_obs = self.env.observation_space.shape[0] 177 | self.d_act = self.env.action_space.shape[0] 178 | self.low = self.env.action_space.low 179 | self.high = self.env.action_space.high 180 | # 用于决策的神经网络 181 | self.agent = AgentModel(self.ctx_size, self.d_obs, self.d_act, self.low, self.high) 182 | self.opt = torch.optim.Adam(self.agent.parameters(), params['lr']) 183 | self.memory = ReplayMemory(self.dtype) 184 | self.context = Context(self.ctx_size, self.d_obs, self.dtype) 185 | self.T = 0 186 | self.rewards = [0.0] 187 | self.done = False 188 | def get_info_dims(self): 189 | return self.d_obs, self.d_act 190 | def get_weights(self): 191 | # 异步收集每个worker的权重用于平均 192 | return self.agent.state_dict() 193 | def get_avg_reward(self): 194 | # 异步收集当前任务成功率等信息 195 | avg_reward_finished = np.mean(self.rewards[-4:]) 196 | return avg_reward_finished 197 | def train_get_weights_infos(self): 198 | # 合并多个异步收集任务,防止时间不同步 199 | if self.T > self.ctx_size: 200 | # 若episode时间够长,则训练 201 | self.train_policy() 202 | self.memory.clear() 203 | return self.get_weights(), self.get_avg_reward(), self.context.normalizer 204 | def set_weights(self, w): 205 | # 为每个worker分发平均后的权重 206 | self.agent.load_state_dict(w) 207 | def set_normalizer(self, normalizer): 208 | self.context.set_normalizer(normalizer) 209 | def reset_initialize(self): 210 | # 初始化仿真环境,上下文和log信息 211 | self.context.reset() 212 | obs, _ = self.env.reset() 213 | self.context.add(obs) 214 | self.T = 0 215 | def train_policy(self): 216 | # episode结束,训练策略网络 217 | n_batches = int(self.T / self.batch_size) + 1 218 | 219 | obs_ctx = self.context.get() # 获取状态上下文 220 | _, _, _, last_value = self.agent(obs_ctx) # 获取不完全轨迹最后一个value用于bootstrap 221 | last_value = float(last_value.detach().numpy()[0]) 222 | 223 | # 计算Generalized Advantage Estimation 224 | self.opt.zero_grad() 225 | for _ in range(n_batches): 226 | obs_ctxs, acts, act_logprobs, returns, advantages = self.memory.sample(self.ctx_size, self.batch_size, last_value) # 从重放记忆中采样经验 227 | _, _, pred_act_dist, pred_value = self.agent(obs_ctxs) 228 | loss = ppo_loss(pred_act_dist, pred_value, acts, act_logprobs, returns, advantages) 229 | (loss/n_batches).backward() 230 | self.opt.step() 231 | def rollout(self, T_rollout): 232 | # 仿真循环,一直展开仿真到done为True 233 | for _ in range(T_rollout): 234 | if self.done: 235 | self.done = False 236 | self.reset_initialize() 237 | obs_ctx = self.context.get() # 获取状态上下文 238 | 239 | # 根据状态上下文决策,得到动作,概率,和价值 240 | act, act_logprob, _, value = self.agent(obs_ctx) 241 | act = act.detach().numpy()[0] 242 | act_logprob = float(act_logprob.detach().numpy()[0]) 243 | value = float(value.detach().numpy()[0]) 244 | 245 | # 仿真一步 246 | obs_, r, terminated, truncated, _ = self.env.step(act) 247 | self.done = terminated or truncated 248 | 249 | # 将历史经验加入重放记忆中 250 | self.memory.add(self.context.obs_ctx[-1], act, act_logprob, r, self.done, value) 251 | # 将需要累积的状态向量加入上下文 252 | self.context.add(obs_) 253 | self.rewards.append(r) 254 | self.T += 1 255 | return 256 | 257 | @ray.remote 258 | class WorkerCaller: 259 | def __init__(self, workers, rollout_steps): 260 | # 设置一个对应的worker 261 | self.workers = workers 262 | self.n_workers = len(workers) 263 | self.rollout_steps = rollout_steps 264 | def start(self): 265 | # 对workers持续不断地触发rollout函数 266 | finish_indicators = [worker.rollout.remote(self.rollout_steps) for worker in self.workers] 267 | while True: 268 | for i in range(self.n_workers): 269 | if is_ready(finish_indicators[i]): 270 | finish_indicators[i] = self.workers[i].rollout.remote(self.rollout_steps) 271 | 272 | def is_ready(obj): 273 | ready_oids, _ = ray.wait([obj]) 274 | if ready_oids: 275 | return True 276 | else: 277 | return False 278 | 279 | def run_parallel(): 280 | params = {'batch_size':64, 'ctx_size':8, 'lr':5e-4, 'n_episodes':99999999, 'n_workers':2, 'rollout_steps':128, 'dtype':torch.float32 } 281 | n_episodes = params['n_episodes'] 282 | n_workers = params['n_workers'] 283 | 284 | # 初始化worker 285 | workers = [Worker.remote(params) for i in range(n_workers)] 286 | avg_weight = ray.get(workers[0].get_weights.remote()) 287 | d_obs, d_act = ray.get(workers[0].get_info_dims.remote()) 288 | ray.get([worker.reset_initialize.remote() for worker in workers]) 289 | 290 | # 初始化标准化器 291 | normalizer = Normalizer((d_obs,), np.float32) 292 | 293 | # 初始化持续调用worker的caller 294 | worker_caller = WorkerCaller.remote(workers, params['rollout_steps']) 295 | 296 | # 启动worker的caller,开始持续异步触发worker的rollout函数 297 | worker_caller.start.remote() 298 | time.sleep(1) 299 | 300 | # 主循环 301 | for i_episodes in range(n_episodes): 302 | # 收集worker的权重,只要有一个未收集完就会阻塞在这里 303 | weights_infos = ray.get([worker.train_get_weights_infos.remote() for worker in workers]) 304 | workers_weights, workers_reward, workers_normalizer = zip(*weights_infos) 305 | # 计算平均权重 306 | avg_weight = {k:sum([workers_weights[wid][k] for wid in range(n_workers)])/n_workers for k in avg_weight.keys()} 307 | # 收集标准化器信息 308 | normalizer.aggregate_collection(workers_normalizer) 309 | 310 | # 非阻塞异步地分发权重给每个worker 311 | finish_setting_indicator = [] 312 | for worker in workers: 313 | finish_setting_indicator.append(worker.set_weights.remote(avg_weight)) 314 | finish_setting_indicator.append(worker.set_normalizer.remote(normalizer)) 315 | ray.get(finish_setting_indicator) 316 | 317 | # 处理所有worker的log信息 318 | avg_reward = sum(workers_reward)/n_workers 319 | print(avg_reward) 320 | time.sleep(0.5) 321 | if __name__ == '__main__': 322 | run_parallel() 323 | -------------------------------------------------------------------------------- /yaml_logging.py: -------------------------------------------------------------------------------- 1 | import traceback 2 | import inspect 3 | import yaml 4 | import sys 5 | import os 6 | import tqdm 7 | import datetime 8 | import uuid 9 | import code 10 | 11 | def handle_exception(*args, **kwargs): 12 | vs = globals().copy() 13 | vs.update(locals()) 14 | shell = code.InteractiveConsole(vs) 15 | sys.__excepthook__(*args, **kwargs) 16 | shell.interact() 17 | return 18 | 19 | sys.excepthook = handle_exception 20 | 21 | os.environ['VAR_NAME']=sys.argv[1] # 设置与实验相关的环境变量(如CUDA_VISIBLE_DEVICES) 22 | experiment_id = str(uuid.uuid1())[:8] # 生成本次实验的UUID 23 | experiment_name = name # 设置描述本次实验的名称 24 | logger=dict() # 用字典保存代码、进程ID、配置参数、开始时间、训练时产生的数据等日志信息 25 | logger['experiment_id'] = experiment_id # 保存本次实验的UUID 26 | logger['experiment_name'] = experiment_name # 保存本次实验的名称 27 | logger['code']=inspect.getsource(sys.modules[__name__]) # 保存本次实验代码 28 | logger['pid']=os.getpid() # 保存本次实验进程PID 29 | logger['config']=config # 保存配置参数 30 | logger['datetime']=str(datetime.datetime.now()) # 保存训练开始时间 31 | logger['loss'] = [] # 保存loss日志 32 | logger['info'] = [] # 保存其他日志信息 33 | logger['env_vars'] = os.environ # 保存相关环境变量 34 | batch_cnt = 0 35 | log_freq = 100 36 | try: 37 | for i in tqdm.tqdm(range(N)): 38 | for x,y in dataset: 39 | loss=model.fit(x, y) # 反向传播 40 | logger['loss'].append(loss) 41 | logger['info'].append(info) 42 | batch_cnt += 1 43 | if batch_cnt % log_freq == 0: # 每log_freq个batch保存一次日志 44 | with open(experiment_name + experiment_id + '.log','w') as f: 45 | f.write(yaml.dump(logger, Dumper=yaml.CDumper)) # 使用yaml保存日志 46 | except KeyboardInterrupt: 47 | print('manully stop training...') 48 | except Exception: 49 | print(traceback.format_exc()) 50 | finally: 51 | postprocess(model) # 训练结束后处理部分,比如保存模型权重等信息到磁盘 52 | --------------------------------------------------------------------------------