├── .gitignore ├── Dockerfile ├── LICENSE ├── README.md ├── hpc_rll ├── __init__.py ├── origin │ ├── __init__.py │ ├── gae.py │ ├── padding.py │ ├── ppo.py │ ├── rnn.py │ ├── scatter_connection.py │ ├── td.py │ ├── upgo.py │ └── vtrace.py ├── rl_utils │ ├── __init__.py │ ├── gae.py │ ├── padding.py │ ├── ppo.py │ ├── td.py │ ├── upgo.py │ └── vtrace.py └── torch_utils │ ├── __init__.py │ └── network │ ├── __init__.py │ ├── rnn.py │ └── scatter_connection.py ├── include └── hpc │ └── rll │ └── cuda │ ├── basic_math.h │ ├── common.h │ ├── models │ ├── actor_critic_kernel.h │ └── entry.h │ ├── reduce.h │ ├── rl_utils │ ├── dist_nstep_td_kernel.h │ ├── entry.h │ ├── gae_kernel.h │ ├── iqn_nstep_td_error_kernel.h │ ├── padding_kernel.h │ ├── ppo_kernel.h │ ├── q_nstep_td_kernel.h │ ├── q_nstep_td_rescale_kernel.h │ ├── qrdqn_nstep_td_error_kernel.h │ ├── td_lambda_kernel.h │ ├── upgo_kernel.h │ └── vtrace_kernel.h │ ├── status.h │ └── torch_utils │ └── network │ ├── entry.h │ ├── lstm_kernel.h │ └── scatter_connection_kernel.h ├── setup.py ├── src ├── models │ ├── actor_critic.cu │ └── entry.cpp ├── rl_utils │ ├── dist_nstep_td.cu │ ├── entry.cpp │ ├── gae.cu │ ├── iqn_nstep_td_error.cu │ ├── padding.cu │ ├── ppo.cu │ ├── q_nstep_td.cu │ ├── q_nstep_td_rescale.cu │ ├── qrdqn_nstep_td_error.cu │ ├── td_lambda.cu │ ├── upgo.cu │ └── vtrace.cu └── torch_utils │ └── network │ ├── entry.cpp │ ├── lstm.cu │ └── scatter_connection.cu ├── tests ├── test_actor_critic.py ├── test_dntd.py ├── test_gae.py ├── test_iqn_nstep_td_error.py ├── test_lstm.py ├── test_padding.py ├── test_ppo.py ├── test_qntd.py ├── test_qntd_rescale.py ├── test_qrdqn_nstep_td_error.py ├── test_scatter.py ├── test_tdlambda.py ├── test_upgo.py ├── test_vtrace.py └── testbase.py └── triton_rl ├── README.md └── requirements.txt /.gitignore: -------------------------------------------------------------------------------- 1 | # Python: 2 | *.py[cod] 3 | *.so 4 | *.egg 5 | *.egg-info 6 | dist 7 | build 8 | __pycache__ 9 | -------------------------------------------------------------------------------- /Dockerfile: -------------------------------------------------------------------------------- 1 | FROM nvidia/cuda:11.6.0-devel-ubuntu20.04 AS di-hpc-develop 2 | 3 | ENV TZ="Asia/Beijing" 4 | 5 | ARG DEBIAN_FRONTEND="noninteractive" 6 | 7 | RUN mkdir -p /workspace 8 | WORKDIR /workspace 9 | 10 | RUN apt update \ 11 | && apt-get install build-essential checkinstall -y \ 12 | && apt-get install libreadline-gplv2-dev libncursesw5-dev libssl-dev libsqlite3-dev tk-dev libgdbm-dev libc6-dev libbz2-dev -y \ 13 | && apt install libffi-dev gnupg pciutils make wget git vim locales -y \ 14 | && apt clean \ 15 | && rm -rf /var/cache/apt/* \ 16 | && sed -i '/en_US.UTF-8/s/^# //g' /etc/locale.gen \ 17 | && locale-gen 18 | 19 | RUN cd /workspace/ \ 20 | && wget https://www.python.org/ftp/python/3.8.13/Python-3.8.13.tgz \ 21 | && tar xzf Python-3.8.13.tgz \ 22 | && cd Python-3.8.13 \ 23 | && ./configure --enable-optimizations \ 24 | && make altinstall \ 25 | && ln -s /usr/local/bin/python3.8 /usr/bin/python3.8.13 \ 26 | && ln -s /usr/local/bin/python3.8 /usr/bin/python3.8 \ 27 | && ln -s /usr/local/bin/python3.8 /usr/bin/python3 \ 28 | && ln -s /usr/local/bin/python3.8 /usr/bin/python \ 29 | && ln -s /usr/local/bin/pip3.8 /usr/bin/pip3 \ 30 | && ln -s /usr/local/bin/pip3.8 /usr/bin/pip \ 31 | && cd /workspace/ \ 32 | && rm -rf ./Python-3.8.13* 33 | 34 | RUN cd /workspace/ \ 35 | && wget 
https://download.pytorch.org/whl/cu113/torch-1.11.0%2Bcu113-cp38-cp38-linux_x86_64.whl -O torch-1.11.0+cu113-cp38-cp38-linux_x86_64.whl \ 36 | && pip install --no-cache-dir /workspace/torch-1.11.0+cu113-cp38-cp38-linux_x86_64.whl \ 37 | && rm /workspace/torch-1.11.0+cu113-cp38-cp38-linux_x86_64.whl 38 | 39 | ADD setup.py /workspace/setup.py 40 | ADD hpc_rll /workspace/hpc_rll 41 | ADD include /workspace/include 42 | ADD src /workspace/src 43 | ADD tests /workspace/tests 44 | 45 | RUN python /workspace/setup.py install 46 | 47 | FROM nvidia/cuda:11.6.0-runtime-ubuntu20.04 AS di-hpc-runtime 48 | 49 | COPY --from=di-hpc-develop /usr/local/bin/ /usr/local/bin/ 50 | COPY --from=di-hpc-develop /usr/local/include/ /usr/local/include/ 51 | COPY --from=di-hpc-develop /usr/local/lib/ /usr/local/lib/ 52 | COPY --from=di-hpc-develop /usr/local/share/ /usr/local/share/ 53 | 54 | RUN ln -s /usr/local/bin/python3.8 /usr/bin/python3 \ 55 | && ln -s /usr/local/bin/python3.8 /usr/bin/python \ 56 | && ln -s /usr/local/bin/pip3.8 /usr/bin/pip3 \ 57 | && ln -s /usr/local/bin/pip3.8 /usr/bin/pip 58 | 59 | FROM ubuntu:20.04 AS di-hpc-nightly 60 | 61 | COPY --from=di-hpc-develop /usr/local/bin/ /usr/local/bin/ 62 | COPY --from=di-hpc-develop /usr/local/include/ /usr/local/include/ 63 | COPY --from=di-hpc-develop /usr/local/lib/ /usr/local/lib/ 64 | COPY --from=di-hpc-develop /usr/local/share/ /usr/local/share/ 65 | 66 | RUN ln -s /usr/local/bin/python3.8 /usr/bin/python3 \ 67 | && ln -s /usr/local/bin/python3.8 /usr/bin/python \ 68 | && ln -s /usr/local/bin/pip3.8 /usr/bin/pip3 \ 69 | && ln -s /usr/local/bin/pip3.8 /usr/bin/pip 70 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ## DI-HPC: Decision Intelligence - High Performance Computation 2 | **DI-HPC** is a set of accelerated operators for common algorithm modules in reinforcement learning, such as GAE, n-step TD and LSTM. The operators support forward and backward propagation, and can be used in training, data collection and testing.
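As a quick illustration (a minimal sketch, assuming a CUDA device and an installed `hpc_rll` build; `T` and `B` are placeholder sizes), the CUDA GAE operator mirrors the pure-PyTorch reference implementation:
```
import torch
from hpc_rll.origin.gae import gae, gae_data   # pure-PyTorch reference
from hpc_rll.rl_utils.gae import GAE           # CUDA operator

T, B = 16, 4                                   # trajectory length and batch size (illustrative)
value = torch.randn(T + 1, B).cuda()
reward = torch.randn(T, B).cuda()

adv_ref = gae(gae_data(value, reward), gamma=0.99, lambda_=0.97)
hpc_gae = GAE(T, B).cuda()
adv_hpc = hpc_gae(value, reward, gamma=0.99, lambda_=0.97)
```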
3 | 4 | ## Requirements 5 | #### Setting 1 6 | * CUDA 9.2 7 | * PyTorch 1.5 (recommended) 8 | * python 3.6, python 3.7 or python 3.8 9 | * Linux Platform 10 | 11 | #### Setting 2 12 | * CUDA 9.0 13 | * gcc 5.4.0 14 | * PyTorch 1.1.0 15 | * python 3.6 or python 3.7 16 | * Linux Platform 17 | 18 | *Note: We recommend that DI-HPC and DI-Engine share the same environment, and it should be fine with PyTorch from 1.1.0 to 1.10.0.* 19 | 20 | ## Quick Start 21 | #### Install from whl 22 | The easiest way to get DI-HPC is to use pip, and you can get the `.whl` from 23 | * [di_hpc_rll-0.0.2-cp36-cp36m-linux_x86_64.whl](http://opendilab.org/download/DI-hpc/di_hpc_rll-0.0.2-cp36-cp36m-linux_x86_64.whl) 24 | * [di_hpc_rll-0.0.2-cp37-cp37m-linux_x86_64.whl](http://opendilab.org/download/DI-hpc/di_hpc_rll-0.0.2-cp37-cp37m-linux_x86_64.whl) 25 | * [di_hpc_rll-0.0.2-cp38-cp38-linux_x86_64.whl](http://opendilab.org/download/DI-hpc/di_hpc_rll-0.0.2-cp38-cp38-linux_x86_64.whl) 26 | 27 | and then call 28 | ``` 29 | $ pip install <whl_name> 30 | ``` 31 | 32 | #### Install from source code 33 | Alternatively, you can install the latest DI-HPC from the git master branch: 34 | ``` 35 | $ python3 setup.py install 36 | ``` 37 | 38 | #### Run on Linux 39 | You can get a benchmark result with the following command: 40 | ``` 41 | $ python3 tests/test_gae.py 42 | ``` 43 | ## TODO 44 | - [ ] Triton Kernel for Reinforcement Learning 45 | 46 | ## Feedback and Contribution 47 | 48 | - [File an issue](https://github.com/opendilab/DI-hpc/issues/new/choose) on GitHub 49 | - Discuss on DI-engine's (also for DI-hpc) [discord server](https://discord.gg/dkZS2JF56X) 50 | - Contact us by email (opendilab@pjlab.org.cn) 51 | 52 | We appreciate all feedback and contributions that improve DI-engine, in both algorithms and system design. `CONTRIBUTING.md` offers the necessary information. 53 | 54 | 55 | ## License 56 | DI-hpc is released under the Apache 2.0 license. 57 | -------------------------------------------------------------------------------- /hpc_rll/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/opendilab/DI-hpc/a8a773480571491e70bb3021cbb2c1adcb7dce12/hpc_rll/__init__.py -------------------------------------------------------------------------------- /hpc_rll/origin/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | -------------------------------------------------------------------------------- /hpc_rll/origin/gae.py: -------------------------------------------------------------------------------- 1 | from collections import namedtuple 2 | import torch 3 | 4 | gae_data = namedtuple('gae_data', ['value', 'reward']) 5 | 6 | def gae(data: namedtuple, gamma: float = 0.99, lambda_: float = 0.97) -> torch.FloatTensor: 7 | """ 8 | Overview: 9 | Implementation of Generalized Advantage Estimator (arXiv:1506.02438) 10 | Arguments: 11 | - data (:obj:`namedtuple`): gae input data with fields ['value', 'reward'], which contains some episodes or\ 12 | trajectories data 13 | - gamma (:obj:`float`): the future discount factor, should be in [0, 1], defaults to 0.99. 14 | - lambda (:obj:`float`): the gae parameter lambda, should be in [0, 1], defaults to 0.97; when lambda -> 0,\ 15 | it induces bias, but when lambda -> 1, it has high variance due to the sum of terms.
16 | Returns: 17 | - adv (:obj:`torch.FloatTensor`): the calculated advantage 18 | Shapes: 19 | - value (:obj:`torch.FloatTensor`): :math:`(T+1, B)`, where T is trajectory length and B is batch size 20 | - reward (:obj:`torch.FloatTensor`): :math:`(T, B)` 21 | - adv (:obj:`torch.FloatTensor`): :math:`(T, B)` 22 | 23 | .. note:: 24 | value_{T+1} should be 0 if this trajectory reached a terminal state(done=True), otherwise we use value 25 | function, this operation is implemented in collector for packing trajectory. 26 | """ 27 | value, reward = data 28 | delta = reward + gamma * value[1:] - value[:-1] 29 | factor = gamma * lambda_ 30 | adv = torch.zeros_like(reward) 31 | gae_item = 0. 32 | denom = 0. 33 | for t in reversed(range(reward.shape[0])): 34 | denom = 1 + lambda_ * denom 35 | gae_item = denom * delta[t] + factor * gae_item 36 | adv[t] += gae_item / denom 37 | return adv 38 | 39 | -------------------------------------------------------------------------------- /hpc_rll/origin/padding.py: -------------------------------------------------------------------------------- 1 | from typing import Tuple, List, Union 2 | from functools import reduce 3 | import numpy as np 4 | import torch 5 | 6 | 7 | def cum(t: List[int]) -> int: 8 | return reduce(lambda x, y: x * y, t) 9 | 10 | 11 | def oracle_split_group(x: List[torch.Tensor], group: int) -> Tuple[List[Tuple], List[int]]: 12 | arr = [None] + [cum(t.shape) for t in x] 13 | N, M = len(arr) - 1, group 14 | 15 | def p(start, end): # cost of [start, end] in arr 16 | return arr[end] * (end - start + 1) # total cost is enough 17 | # return arr[end] * (end - start + 1) - sum(arr[start:end + 1]) 18 | 19 | # DP, time complex O(MN^2), space complex O(MN) 20 | f = {(0, 0): (0, 0)} 21 | for i, length_ in enumerate(arr[1:], start=1): 22 | for j in range(1, M + 1): 23 | ress = [] 24 | for k in range(0, i): 25 | if (k, j - 1) in f: 26 | last_cost, _ = f[(k, j - 1)] 27 | ress.append((last_cost + p(k + 1, i), k)) 28 | 29 | if ress: 30 | f[(i, j)] = min(ress) 31 | 32 | min_cost, _ = f[(N, M)] 33 | last_position, last_cnt = N, M 34 | positions = [N] 35 | while last_position > 0: 36 | _, last_position = f[(last_position, last_cnt)] 37 | last_cnt -= 1 38 | positions.append(last_position) 39 | 40 | assert len(positions) == M + 1 41 | positions = positions[::-1] 42 | 43 | # print(min_cost) # minimal cost 44 | # for i in range(0, M): # solution 45 | # start = positions[i] + 1 46 | # end = positions[i + 1] 47 | # cost = p(start, end) 48 | # print(i, arr[start:end + 1], start, end, cost) 49 | shapes = [x[i - 1].shape for i in positions[1:]] 50 | return shapes, positions 51 | 52 | 53 | def _Padding1D(x: List[torch.Tensor], value: int = 0) -> Tuple[torch.Tensor, torch.Tensor, List]: 54 | shapes = [t.shape for t in x] 55 | max_shape = [max(t) for t in list(zip(*shapes))] 56 | new_shape = [len(x)] + max_shape 57 | mask = torch.full(new_shape, fill_value=value, dtype=x[0].dtype, device=x[0].device) 58 | new_x = torch.full(new_shape, fill_value=value, dtype=x[0].dtype, device=x[0].device) 59 | for i in range(mask.shape[0]): 60 | idx = [i] + list(shapes[i]) 61 | mask[idx[0], :idx[1]] = 1 62 | new_x[idx[0], :idx[1]] = x[i] 63 | return new_x, mask, shapes 64 | 65 | 66 | def Padding1D(x: List[torch.Tensor], mode='constant', value: int = 0, group: int = 1, group_mode='sample') -> Tuple: 67 | assert mode in ['constant'], mode 68 | assert group_mode in ['sample', 'oracle'], group_mode 69 | assert group >= 1, group 70 | if group > 1: 71 | x = sorted(x, key=lambda t: cum(t.shape)) 
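        # sorting by total element count (cum of the shape) keeps similarly sized tensors adjacent,
        # so that each of the `group` padding groups covers a contiguous range of sizes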
72 | if group_mode == 'sample': 73 | sampled_idx = np.random.choice(len(x), group - 1) 74 | group_shape = [t.shape for i, t in enumerate(x) if i in sampled_idx] 75 | group_shape += [x[-1].shape] # max shape 76 | print('sample group_shape', group_shape) 77 | group_shape = list(set(group_shape)) # remove repeat shape 78 | group_shape = sorted(group_shape, key=lambda t: cum(t)) 79 | group_shape_idx = 0 80 | group_idx = [0] 81 | for i, t in enumerate(x): 82 | if cum(t.shape) > cum(group_shape[group_shape_idx]): 83 | group_idx.append(i) 84 | group_shape_idx += 1 85 | group_idx.append(len(x)) 86 | elif group_mode == 'oracle': 87 | group_shape, group_idx = oracle_split_group(x, group) 88 | print('group_shape', group_shape) 89 | assert len(group_idx) == len(group_shape) + 1 90 | 91 | ret = [_Padding1D(x[group_idx[i]:group_idx[i + 1]], value) for i in range(len(group_shape))] 92 | return list(zip(*ret)) 93 | else: 94 | return _Padding1D(x, value) 95 | 96 | 97 | def _UnPadding1D(x, shapes, deepcopy: bool = False): 98 | new_x = [] 99 | for i in range(x.shape[0]): 100 | idx = [i] + list(shapes[i]) 101 | item = x[idx[0], :idx[1]] 102 | if deepcopy: 103 | item = item.clone() 104 | new_x.append(item) 105 | return new_x 106 | 107 | 108 | def UnPadding1D(x: Union[torch.Tensor, List[torch.Tensor]], 109 | shapes: Union[List, List[List]], 110 | deepcopy: bool = False) -> List[torch.Tensor]: 111 | if isinstance(x, torch.Tensor): 112 | return _UnPadding1D(x, shapes, deepcopy) 113 | else: 114 | ret = [_UnPadding1D(t, s, deepcopy) for t, s in zip(x, shapes)] 115 | return sum(ret, []) 116 | 117 | 118 | def Padding2D(x: List[torch.Tensor], 119 | mode='constant', 120 | value: int = 0, 121 | group: int = 1) -> Tuple[torch.Tensor, torch.Tensor, List]: 122 | assert mode in ['constant'], mode 123 | assert group >= 1, group 124 | shapes = [t.shape for t in x] 125 | max_shape = [max(t) for t in list(zip(*shapes))] 126 | new_shape = [len(x)] + max_shape 127 | mask = torch.full(new_shape, fill_value=value, dtype=x[0].dtype, device=x[0].device) 128 | new_x = torch.full(new_shape, fill_value=value, dtype=x[0].dtype, device=x[0].device) 129 | for i in range(mask.shape[0]): 130 | idx = [i] + list(shapes[i]) 131 | mask[idx[0], :idx[1], :idx[2]] = 1 132 | new_x[idx[0], :idx[1], :idx[2]] = x[i] 133 | return new_x, mask, shapes 134 | 135 | 136 | def UnPadding2D(x: torch.Tensor, shapes: List, deepcopy: bool = False) -> List[torch.Tensor]: 137 | new_x = [] 138 | for i in range(x.shape[0]): 139 | idx = [i] + list(shapes[i]) 140 | item = x[idx[0], :idx[1], :idx[2]] 141 | if deepcopy: 142 | item = item.clone() 143 | new_x.append(item) 144 | return new_x 145 | 146 | 147 | def Padding3D(x: List[torch.Tensor], 148 | mode='constant', 149 | value: int = 0, 150 | group: int = 1) -> Tuple[torch.Tensor, torch.Tensor, List]: 151 | assert mode in ['constant'], mode 152 | assert group >= 1, group 153 | shapes = [t.shape for t in x] 154 | max_shape = [max(t) for t in list(zip(*shapes))] 155 | new_shape = [len(x)] + max_shape 156 | mask = torch.full(new_shape, fill_value=value, dtype=x[0].dtype, device=x[0].device) 157 | new_x = torch.full(new_shape, fill_value=value, dtype=x[0].dtype, device=x[0].device) 158 | for i in range(mask.shape[0]): 159 | idx = [i] + list(shapes[i]) 160 | mask[idx[0], :idx[1], :idx[2], :idx[3]] = 1 161 | new_x[idx[0], :idx[1], :idx[2], :idx[3]] = x[i] 162 | return new_x, mask, shapes 163 | 164 | 165 | def UnPadding3D(x: torch.Tensor, shapes: List, deepcopy: bool = False) -> List[torch.Tensor]: 166 | new_x = [] 167 | for i 
in range(x.shape[0]): 168 | idx = [i] + list(shapes[i]) 169 | item = x[idx[0], :idx[1], :idx[2], :idx[3]] 170 | if deepcopy: 171 | item = item.clone() 172 | new_x.append(item) 173 | return new_x 174 | 175 | 176 | -------------------------------------------------------------------------------- /hpc_rll/origin/ppo.py: -------------------------------------------------------------------------------- 1 | from collections import namedtuple 2 | from typing import Optional, Tuple 3 | import torch 4 | from torch.distributions import Independent, Normal 5 | 6 | ppo_data = namedtuple( 7 | 'ppo_data', ['logit_new', 'logit_old', 'action', 'value_new', 'value_old', 'adv', 'return_', 'weight'] 8 | ) 9 | ppo_loss = namedtuple('ppo_loss', ['policy_loss', 'value_loss', 'entropy_loss']) 10 | ppo_info = namedtuple('ppo_info', ['approx_kl', 'clipfrac']) 11 | 12 | 13 | def ppo_error( 14 | data: namedtuple, 15 | clip_ratio: float = 0.2, 16 | use_value_clip: bool = True, 17 | dual_clip: Optional[float] = None 18 | ) -> Tuple[namedtuple, namedtuple]: 19 | """ 20 | Overview: 21 | Implementation of Proximal Policy Optimization (arXiv:1707.06347) with value_clip and dual_clip 22 | Arguments: 23 | - data (:obj:`namedtuple`): the ppo input data with fields shown in ``ppo_data`` 24 | - clip_ratio (:obj:`float`): the ppo clip ratio for the constraint of policy update, defaults to 0.2 25 | - use_value_clip (:obj:`bool`): whether to use clip in value loss with the same ratio as policy 26 | - dual_clip (:obj:`float`): a parameter c mentioned in arXiv:1912.09729 Equ. 5, should be in [1, inf),\ 27 | defaults to 5.0, if you don't want to use it, set this parameter to None 28 | Returns: 29 | - ppo_loss (:obj:`namedtuple`): the ppo loss item, all of them are differentiable 0-dim tensors 30 | - ppo_info (:obj:`namedtuple`): the ppo optim information for monitoring, all of them are Python scalars 31 | Shapes: 32 | - logit_new (:obj:`torch.FloatTensor`): :math:`(B, N)`, where B is batch size and N is action dim 33 | - logit_old (:obj:`torch.FloatTensor`): :math:`(B, N)` 34 | - action (:obj:`torch.LongTensor`): :math:`(B, )` 35 | - value_new (:obj:`torch.FloatTensor`): :math:`(B, )` 36 | - value_old (:obj:`torch.FloatTensor`): :math:`(B, )` 37 | - adv (:obj:`torch.FloatTensor`): :math:`(B, )` 38 | - return (:obj:`torch.FloatTensor`): :math:`(B, )` 39 | - weight (:obj:`torch.FloatTensor` or :obj:`None`): :math:`(B, )` 40 | - policy_loss (:obj:`torch.FloatTensor`): :math:`()`, 0-dim tensor 41 | - value_loss (:obj:`torch.FloatTensor`): :math:`()` 42 | 43 | .. note:: 44 | adv is an already normalized value (adv - adv.mean()) / (adv.std() + 1e-8), and there are many 45 | ways to calculate this mean and std, such as across the data buffer or the train batch, so we don't couple 46 | this part into ppo_error; you can refer to our examples for different ways.
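        Examples:
            A minimal illustrative call (B and N below are placeholder sizes, not part of the API):

            >>> B, N = 4, 8
            >>> data = ppo_data(torch.randn(B, N), torch.randn(B, N), torch.randint(0, N, (B, )),
            ...                 torch.randn(B), torch.randn(B), torch.randn(B), torch.randn(B), None)
            >>> loss, info = ppo_error(data)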
47 | """ 48 | assert dual_clip is None or dual_clip > 1.0, "dual_clip value must be greater than 1.0, but get value: {}".format( 49 | dual_clip 50 | ) 51 | logit_new, logit_old, action, value_new, value_old, adv, return_, weight = data 52 | if weight is None: 53 | weight = torch.ones_like(adv) 54 | dist_new = torch.distributions.categorical.Categorical(logits=logit_new) 55 | dist_old = torch.distributions.categorical.Categorical(logits=logit_old) 56 | logp_new = dist_new.log_prob(action) 57 | logp_old = dist_old.log_prob(action) 58 | entropy_loss = (dist_new.entropy() * weight).mean() 59 | # policy_loss 60 | ratio = torch.exp(logp_new - logp_old) 61 | surr1 = ratio * adv 62 | surr2 = ratio.clamp(1 - clip_ratio, 1 + clip_ratio) * adv 63 | if dual_clip is not None: 64 | policy_loss = (-torch.max(torch.min(surr1, surr2), dual_clip * adv) * weight).mean() 65 | else: 66 | policy_loss = (-torch.min(surr1, surr2) * weight).mean() 67 | with torch.no_grad(): 68 | approx_kl = (logp_old - logp_new).mean().item() 69 | clipped = ratio.gt(1 + clip_ratio) | ratio.lt(1 - clip_ratio) 70 | clipfrac = torch.as_tensor(clipped).float().mean().item() 71 | # value_loss 72 | if use_value_clip: 73 | value_clip = value_old + (value_new - value_old).clamp(-clip_ratio, clip_ratio) 74 | v1 = (return_ - value_new).pow(2) 75 | v2 = (return_ - value_clip).pow(2) 76 | value_loss = 0.5 * (torch.max(v1, v2) * weight).mean() 77 | else: 78 | value_loss = 0.5 * ((return_ - value_new).pow(2) * weight).mean() 79 | 80 | return ppo_loss(policy_loss, value_loss, entropy_loss), ppo_info(approx_kl, clipfrac) 81 | 82 | -------------------------------------------------------------------------------- /hpc_rll/origin/scatter_connection.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from typing import Tuple 4 | 5 | 6 | class ScatterConnection(nn.Module): 7 | r""" 8 | Overview: 9 | Scatter feature to its corresponding location 10 | In alphastar, each entity is embedded into a tensor, these tensors are scattered into a feature map 11 | with map size 12 | """ 13 | 14 | def __init__(self, scatter_type) -> None: 15 | r""" 16 | Overview: 17 | Init class 18 | Arguments: 19 | - scatter_type (:obj:`str`): add or cover, if two entities have same location, scatter type decides the 20 | first one should be covered or added to second one 21 | """ 22 | super(ScatterConnection, self).__init__() 23 | self.scatter_type = scatter_type 24 | assert self.scatter_type in ['cover', 'add'] 25 | 26 | def forward(self, x: torch.Tensor, spatial_size: Tuple[int, int], location: torch.Tensor) -> torch.Tensor: 27 | """ 28 | Overview: 29 | scatter x into a spatial feature map 30 | Arguments: 31 | - x (:obj:`tensor`): input tensor :math: `(B, M, N)` where `M` means the number of entity, `N` means\ 32 | the dimension of entity attributes 33 | - spatial_size (:obj:`tuple`): Tuple[H, W], the size of spatial feature x will be scattered into 34 | - location (:obj:`tensor`): :math: `(B, M, 2)` torch.LongTensor, each location should be (y, x) 35 | Returns: 36 | - output (:obj:`tensor`): :math: `(B, N, H, W)` where `H` and `W` are spatial_size, return the\ 37 | scattered feature map 38 | Shapes: 39 | - Input: :math: `(B, M, N)` where `M` means the number of entity, `N` means\ 40 | the dimension of entity attributes 41 | - Size: Tuple[H, W] 42 | - Location: :math: `(B, M, 2)` torch.LongTensor, each location should be (y, x) 43 | - Output: :math: `(B, N, H, W)` where `H` and `W` are spatial_size 
44 | 45 | .. note:: 46 | when there are some overlapping in locations, ``cover`` mode will result in the loss of information, we 47 | use the addition as temporal substitute. 48 | """ 49 | device = x.device 50 | B, M, N = x.shape 51 | H, W = spatial_size 52 | index = location.view(-1, 2) 53 | bias = torch.arange(B).mul_(H * W).unsqueeze(1).repeat(1, M).view(-1).to(device) 54 | index = index[:, 0] * W + index[:, 1] 55 | index += bias 56 | index = index.repeat(N, 1) 57 | x = x.view(-1, N).permute(1, 0) 58 | output = torch.zeros(N, B * H * W, device=device) 59 | if self.scatter_type == 'cover': 60 | output.scatter_(dim=1, index=index, src=x) 61 | elif self.scatter_type == 'add': 62 | output.scatter_add_(dim=1, index=index, src=x) 63 | output = output.reshape(N, B, H, W) 64 | output = output.permute(1, 0, 2, 3).contiguous() 65 | return output 66 | 67 | -------------------------------------------------------------------------------- /hpc_rll/origin/upgo.py: -------------------------------------------------------------------------------- 1 | from functools import reduce 2 | import torch 3 | import torch.nn.functional as F 4 | from .td import generalized_lambda_returns 5 | 6 | 7 | def tb_cross_entropy(logit, label, mask=None): 8 | assert (len(label.shape) >= 2) 9 | T, B = label.shape[:2] 10 | # Special 2D case 11 | assert len(label.shape) == 2 12 | assert mask is None 13 | 14 | label = label.reshape(-1) 15 | logit = logit.reshape(-1, logit.shape[-1]) 16 | ce = -F.cross_entropy(logit, label, reduction='none') 17 | ce = ce.reshape(T, B, -1) 18 | return ce.mean(dim=2) 19 | 20 | 21 | def upgo_returns(rewards: torch.Tensor, bootstrap_values: torch.Tensor) -> torch.Tensor: 22 | r""" 23 | Overview: 24 | Computing UPGO return targets. Also notice there is no special handling for the terminal state. 25 | Arguments: 26 | - rewards (:obj:`torch.Tensor`): the returns from time step 0 to T-1, \ 27 | of size [T_traj, batchsize] 28 | - bootstrap_values (:obj:`torch.Tensor`): estimation of the state value at step 0 to T, \ 29 | of size [T_traj+1, batchsize] 30 | Returns: 31 | - ret (:obj:`torch.Tensor`): Computed lambda return value for each state from 0 to T-1, \ 32 | of size [T_traj, batchsize] 33 | """ 34 | # UPGO can be viewed as a lambda return! The trace continues for V_t (i.e. lambda = 1.0) if r_tp1 + V_tp2 > V_tp1. 35 | # as the lambdas[-1, :] is ignored in generalized_lambda_returns, we don't care about bootstrap_values_tp2[-1] 36 | lambdas = (rewards + bootstrap_values[1:]) >= bootstrap_values[:-1] 37 | lambdas = torch.cat([lambdas[1:], torch.ones_like(lambdas[-1:])], dim=0) 38 | return generalized_lambda_returns(bootstrap_values, rewards, 1.0, lambdas) 39 | 40 | def upgo_loss( 41 | target_output: torch.Tensor, 42 | rhos: torch.Tensor, 43 | action: torch.Tensor, 44 | rewards: torch.Tensor, 45 | bootstrap_values: torch.Tensor, 46 | mask=None 47 | ) -> torch.Tensor: 48 | r""" 49 | Overview: 50 | Computing UPGO loss given constant gamma and lambda. There is no special handling for terminal state value, 51 | if the last state in trajectory is the terminal, just pass a 0 as bootstrap_terminal_value. 
52 | Arguments: 53 | - target_output (:obj:`torch.Tensor`): the output computed by the target policy network, \ 54 | of size [T_traj, batchsize, n_output] 55 | - rhos (:obj:`torch.Tensor`): the importance sampling ratio, of size [T_traj, batchsize] 56 | - action (:obj:`torch.Tensor`): the action taken, of size [T_traj, batchsize] 57 | - rewards (:obj:`torch.Tensor`): the returns from time step 0 to T-1, of size [T_traj, batchsize] 58 | - bootstrap_values (:obj:`torch.Tensor`): estimation of the state value at step 0 to T, \ 59 | of size [T_traj+1, batchsize] 60 | Returns: 61 | - loss (:obj:`torch.Tensor`): Computed importance sampled UPGO loss, averaged over the samples, of size [] 62 | """ 63 | # discard the value at T as it should be considered in the next slice 64 | with torch.no_grad(): 65 | returns = upgo_returns(rewards, bootstrap_values) 66 | advantages = rhos * (returns - bootstrap_values[:-1]) 67 | metric = tb_cross_entropy(target_output, action, mask) 68 | assert (metric.shape == action.shape[:2]) 69 | losses = advantages * metric 70 | return -losses.mean() 71 | 72 | -------------------------------------------------------------------------------- /hpc_rll/origin/vtrace.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from collections import namedtuple 4 | 5 | def vtrace_nstep_return(clipped_rhos, clipped_cs, reward, bootstrap_values, gamma=0.99, lambda_=0.95): 6 | deltas = clipped_rhos * (reward + gamma * bootstrap_values[1:] - bootstrap_values[:-1]) 7 | factor = gamma * lambda_ 8 | result = bootstrap_values[:-1].clone() 9 | vtrace_item = 0. 10 | for t in reversed(range(reward.size()[0])): 11 | vtrace_item = deltas[t] + factor * clipped_cs[t] * vtrace_item 12 | result[t] += vtrace_item 13 | return result 14 | 15 | 16 | def vtrace_advantage(clipped_pg_rhos, reward, return_, bootstrap_values, gamma): 17 | return clipped_pg_rhos * (reward + gamma * return_ - bootstrap_values) 18 | 19 | 20 | vtrace_data = namedtuple('vtrace_data', ['target_output', 'behaviour_output', 'action', 'value', 'reward', 'weight']) 21 | vtrace_loss = namedtuple('vtrace_loss', ['policy_loss', 'value_loss', 'entropy_loss']) 22 | 23 | 24 | def vtrace_error( 25 | data: namedtuple, 26 | gamma: float = 0.99, 27 | lambda_: float = 0.95, 28 | rho_clip_ratio: float = 1.0, 29 | c_clip_ratio: float = 1.0, 30 | rho_pg_clip_ratio: float = 1.0 31 | ): 32 | """ 33 | Overview: 34 | Implementation of vtrace (IMPALA: Scalable Distributed Deep-RL with Importance Weighted Actor-Learner\ 35 | Architectures), (arXiv:1802.01561) 36 | Arguments: 37 | - data (:obj:`namedtuple`): input data with fields shown in ``vtrace_data`` 38 | - target_output (:obj:`torch.Tensor`): the output taking the action by the current policy network,\ 39 | usually this output is network output logit 40 | - behaviour_output (:obj:`torch.Tensor`): the output taking the action by the behaviour policy network,\ 41 | usually this output is network output logit, which is used to produce the trajectory (actor) 42 | - action (:obj:`torch.Tensor`): the chosen action (index for the discrete action space) in trajectory,\ 43 | i.e.: behaviour_action 44 | - gamma: (:obj:`float`): the future discount factor, defaults to 0.99 45 | - lambda: (:obj:`float`): mix factor between 1-step (lambda_=0) and n-step (lambda_=1) return, defaults to 0.95 46 | - rho_clip_ratio (:obj:`float`): the clipping threshold for importance weights (rho) when calculating\ 47 | the baseline targets (vs) 48 | - c_clip_ratio
(:obj:`float`): the clipping threshold for importance weights (c) when calculating\ 49 | the baseline targets (vs) 50 | - rho_pg_clip_ratio (:obj:`float`): the clipping threshold for importance weights (rho) when calculating\ 51 | the policy gradient advantage 52 | Returns: 53 | - trace_loss (:obj:`namedtuple`): the vtrace loss item, all of them are the differentiable 0-dim tensor 54 | Shapes: 55 | - target_output (:obj:`torch.FloatTensor`): :math:`(T, B, N)`, where T is timestep, B is batch size and\ 56 | N is action dim 57 | - behaviour_output (:obj:`torch.FloatTensor`): :math:`(T, B, N)` 58 | - action (:obj:`torch.LongTensor`): :math:`(T, B)` 59 | - value (:obj:`torch.FloatTensor`): :math:`(T+1, B)` 60 | - reward (:obj:`torch.LongTensor`): :math:`(T, B)` 61 | - weight (:obj:`torch.LongTensor`): :math:`(T, B)` 62 | """ 63 | target_output, behaviour_output, action, value, reward, weight = data 64 | with torch.no_grad(): 65 | IS = compute_importance_weights(target_output, behaviour_output, action) 66 | rhos = torch.clamp(IS, max=rho_clip_ratio) 67 | cs = torch.clamp(IS, max=c_clip_ratio) 68 | return_ = vtrace_nstep_return(rhos, cs, reward, value, gamma, lambda_) 69 | pg_rhos = torch.clamp(IS, max=rho_pg_clip_ratio) 70 | return_t_plus_1 = torch.cat([return_[1:], value[-1:]], 0) 71 | adv = vtrace_advantage(pg_rhos, reward, return_t_plus_1, value[:-1], gamma) 72 | 73 | if weight is None: 74 | weight = torch.ones_like(reward) 75 | dist_target = torch.distributions.Categorical(logits=target_output) 76 | pg_loss = -(dist_target.log_prob(action) * adv * weight).mean() 77 | value_loss = (F.mse_loss(value[:-1], return_, reduction='none') * weight).mean() 78 | entropy_loss = (dist_target.entropy() * weight).mean() 79 | return vtrace_loss(pg_loss, value_loss, entropy_loss) 80 | 81 | def compute_importance_weights(target_output, behaviour_output, action, requires_grad=False): 82 | """ 83 | Overview: 84 | Computing importance sampling weight with given output and action 85 | Arguments: 86 | - target_output (:obj:`torch.Tensor`): the output taking the action by the current policy network,\ 87 | usually this output is network output logit 88 | - behaviour_output (:obj:`torch.Tensor`): the output taking the action by the behaviour policy network,\ 89 | usually this output is network output logit, which is used to produce the trajectory(actor) 90 | - action (:obj:`torch.Tensor`): the chosen action(index for the discrete action space) in trajectory,\ 91 | i.e.: behaviour_action 92 | - requires_grad (:obj:`bool`): whether requires grad computation 93 | Returns: 94 | - rhos (:obj:`torch.Tensor`): Importance sampling weight 95 | Shapes: 96 | - target_output (:obj:`torch.FloatTensor`): :math:`(T, B, N)`, where T is timestep, B is batch size and\ 97 | N is action dim 98 | - behaviour_output (:obj:`torch.FloatTensor`): :math:`(T, B, N)` 99 | - action (:obj:`torch.LongTensor`): :math:`(T, B)` 100 | - rhos (:obj:`torch.FloatTensor`): :math:`(T, B)` 101 | """ 102 | grad_context = torch.enable_grad() if requires_grad else torch.no_grad() 103 | assert isinstance(action, torch.Tensor) 104 | device = action.device 105 | 106 | with grad_context: 107 | dist_target = torch.distributions.Categorical(logits=target_output) 108 | dist_behaviour = torch.distributions.Categorical(logits=behaviour_output) 109 | rhos = dist_target.log_prob(action) - dist_behaviour.log_prob(action) 110 | rhos = torch.exp(rhos) 111 | return rhos 112 | 113 | -------------------------------------------------------------------------------- 
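A minimal illustrative call of `vtrace_error` (T, B and N below are placeholder sizes chosen for the example, not part of the API; only the pure-PyTorch reference module is needed):
```
import torch
from hpc_rll.origin.vtrace import vtrace_data, vtrace_error

T, B, N = 16, 4, 8
data = vtrace_data(
    target_output=torch.randn(T, B, N, requires_grad=True),
    behaviour_output=torch.randn(T, B, N),
    action=torch.randint(0, N, (T, B)),
    value=torch.randn(T + 1, B, requires_grad=True),
    reward=torch.randn(T, B),
    weight=None,
)
policy_loss, value_loss, entropy_loss = vtrace_error(data, gamma=0.99, lambda_=0.95)
(policy_loss + value_loss + entropy_loss).backward()
```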
/hpc_rll/rl_utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/opendilab/DI-hpc/a8a773480571491e70bb3021cbb2c1adcb7dce12/hpc_rll/rl_utils/__init__.py -------------------------------------------------------------------------------- /hpc_rll/rl_utils/gae.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import hpc_rl_utils 3 | 4 | # hpc version only support cuda 5 | 6 | class GAEFunction(torch.autograd.Function): 7 | @staticmethod 8 | def forward(ctx, value, reward, gamma, lambda_, adv): 9 | 10 | inputs = [value, reward] 11 | outputs = [adv] 12 | hpc_rl_utils.GaeForward(inputs, outputs, gamma, lambda_) 13 | 14 | return adv 15 | 16 | @staticmethod 17 | def backward(ctx, grad_adv): 18 | return None, None, None, None, None 19 | 20 | class GAE(torch.nn.Module): 21 | """ 22 | Overview: 23 | Implementation of Generalized Advantage Estimator (arXiv:1506.02438) 24 | 25 | Interface: 26 | __init__, forward 27 | """ 28 | def __init__(self, T, B): 29 | r""" 30 | Overview 31 | initialization of gae 32 | 33 | Arguments: 34 | - T (:obj:`int`): trajectory length 35 | - B (:obj:`int`): batch size 36 | """ 37 | 38 | super().__init__() 39 | self.register_buffer('adv', torch.zeros(T, B)) 40 | 41 | def forward(self, value, reward, gamma: float = 0.99, lambda_: float = 0.97) -> torch.FloatTensor: 42 | """ 43 | Overview: 44 | forward of gae 45 | Arguments: 46 | - value (:obj:`torch.FloatTensor`): :math:`(T + 1, B)`, gae input data 47 | - reward (:obj:`torch.FloatTensor`): :math:`(T, B)`, gae input data 48 | - gamma (:obj:`float`): the future discount factor, should be in [0, 1], defaults to 0.99. 49 | - lambda (:obj:`float`): the gae parameter lambda, should be in [0, 1], defaults to 0.97, when lambda -> 0,\ 50 | it induces bias, but when lambda -> 1, it has high variance due to the sum of terms. 51 | Returns: 52 | - adv (:obj:`torch.FloatTensor`): :math:`(T, B)`, the calculated advantage 53 | 54 | .. note:: 55 | value_{T+1} should be 0 if this trajectory reached a terminal state(done=True), otherwise we use value 56 | function, this operation is implemented in actor for packing trajectory. 
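        Examples:
            An illustrative call (sizes are placeholders; requires a CUDA device and the compiled ``hpc_rl_utils`` extension):

            >>> T, B = 16, 4
            >>> hpc_gae = GAE(T, B).cuda()
            >>> adv = hpc_gae(torch.randn(T + 1, B).cuda(), torch.randn(T, B).cuda())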
57 | """ 58 | assert(value.is_cuda) 59 | assert(reward.is_cuda) 60 | 61 | return GAEFunction.apply(value, reward, gamma, lambda_, self.adv) 62 | 63 | -------------------------------------------------------------------------------- /hpc_rll/rl_utils/padding.py: -------------------------------------------------------------------------------- 1 | from fcntl import DN_DELETE 2 | from typing import Tuple, List, Union 3 | from functools import reduce 4 | import torch 5 | import numpy as np 6 | import hpc_rl_utils 7 | import pdb 8 | 9 | # hpc version only support cuda 10 | 11 | def cum(t: List[int]) -> int: 12 | return reduce(lambda x, y: x * y, t) 13 | 14 | def Padding1D(inputs: List[torch.Tensor], mode='constant', value: int = 0, group: int = 1, group_mode='sample'): 15 | assert mode in ['constant'], mode 16 | assert group_mode in ['sample', 'oracle'], group_mode 17 | assert group >= 1, group 18 | if group > 1: 19 | inputs = sorted(inputs, key=lambda t: cum(t.shape)) 20 | if group_mode == 'sample': 21 | res = hpc_rl_utils.sample_split_group(inputs, group) 22 | group_idx = res[-1] 23 | group_shape = res[:-1] 24 | elif group_mode == 'oracle': 25 | res = hpc_rl_utils.oracle_split_group(inputs, group) 26 | group_idx = res[-1] 27 | group_shape = res[:-1] 28 | assert len(group_idx) == len(group_shape) + 1 29 | max_shape = [s[0] for s in group_shape] 30 | group_num = [(group_idx[i+1] - group_idx[i]) for i in range(len(group_shape))] 31 | shapes = [] 32 | group_id = [] 33 | k = 0 34 | for i in range(len(group_num)): 35 | shape = [] 36 | for j in range(group_num[i]): 37 | shape.append(inputs[k].shape[0]) 38 | group_id.append(i) 39 | k = k + 1 40 | shapes.append(shape) 41 | assert len(group_id) == len(inputs) 42 | result = hpc_rl_utils.GroupPad1DForward(inputs, group_num, max_shape, group_id, group_idx, value) 43 | new_x = result[0] 44 | mask = result[1] 45 | return [tuple(new_x), tuple(mask), tuple(shapes)] 46 | else: 47 | shapes = [t.shape[0] for t in inputs] 48 | result = hpc_rl_utils.Pad1DForward(inputs, value) 49 | new_x = result[0] 50 | mask = result[1] 51 | return new_x, mask, shapes 52 | 53 | 54 | def UnPadding1D(x: Union[torch.Tensor, List[torch.Tensor]], 55 | shapes: Union[List, List[List]]) -> List[torch.Tensor]: 56 | if isinstance(x, torch.Tensor): 57 | return hpc_rl_utils.Unpad1DForward(x, shapes) 58 | else: 59 | ret = [] 60 | for t, s in zip(x, shapes): 61 | ret.append(hpc_rl_utils.Unpad1DForward(t, s)) 62 | return sum(ret, []) 63 | 64 | def Padding2D(inputs: List[torch.Tensor], mode='constant', value: int = 0, group: int = 1, group_mode='sample'): 65 | assert mode in ['constant'], mode 66 | assert group_mode in ['sample', 'oracle'], group_mode 67 | assert group >= 1, group 68 | if group > 1: 69 | inputs = sorted(inputs, key=lambda t: cum(t.shape)) 70 | if group_mode == 'sample': 71 | res = hpc_rl_utils.sample_split_group(inputs, group) 72 | group_idx = res[-1] 73 | group_shape = res[:-1] 74 | elif group_mode == 'oracle': 75 | res = hpc_rl_utils.oracle_split_group(inputs, group) 76 | group_idx = res[-1] 77 | group_shape = res[:-1] 78 | assert len(group_idx) == len(group_shape) + 1 79 | max_shape = [] 80 | for s in group_shape: 81 | max_shape.append(s[0]) 82 | max_shape.append(s[1]) 83 | group_cnt = [(group_idx[i+1] - group_idx[i]) for i in range(len(group_shape))] 84 | shapes = [] 85 | group_id = [] 86 | k = 0 87 | for i in range(len(group_cnt)): 88 | shape = [] 89 | for j in range(group_cnt[i]): 90 | shape.append(inputs[k].shape[0]) 91 | shape.append(inputs[k].shape[1]) 92 | 
group_id.append(i) 93 | k = k + 1 94 | shapes.append(shape) 95 | assert len(group_id) == len(inputs) 96 | result = hpc_rl_utils.GroupPad2DForward(inputs, group_cnt, max_shape, group_id, group_idx, value) 97 | new_x = result[0] 98 | mask = result[1] 99 | return [tuple(new_x), tuple(mask), tuple(shapes)] 100 | else: 101 | shapes = [] 102 | for t in inputs: 103 | shapes.append(t.shape[0]) 104 | shapes.append(t.shape[1]) 105 | result = hpc_rl_utils.Pad2DForward(inputs, value) 106 | new_x = result[0] 107 | mask = result[1] 108 | return new_x, mask, shapes 109 | 110 | 111 | def UnPadding2D(x: Union[torch.Tensor, List[torch.Tensor]], 112 | shapes: Union[List, List[List]]) -> List[torch.Tensor]: 113 | if isinstance(x, torch.Tensor): 114 | return hpc_rl_utils.Unpad2DForward(x, shapes) 115 | else: 116 | ret = [] 117 | for t, s in zip(x, shapes): 118 | ret.append(hpc_rl_utils.Unpad2DForward(t, s)) 119 | return sum(ret, []) 120 | 121 | def Padding3D(inputs: List[torch.Tensor], mode='constant', value: int = 0, group: int = 1, group_mode='sample'): 122 | assert mode in ['constant'], mode 123 | assert group_mode in ['sample', 'oracle'], group_mode 124 | assert group >= 1, group 125 | if group > 1: 126 | inputs = sorted(inputs, key=lambda t: cum(t.shape)) 127 | if group_mode == 'sample': 128 | res = hpc_rl_utils.sample_split_group(inputs, group) 129 | group_idx = res[-1] 130 | group_shape = res[:-1] 131 | elif group_mode == 'oracle': 132 | res = hpc_rl_utils.oracle_split_group(inputs, group) 133 | group_idx = res[-1] 134 | group_shape = res[:-1] 135 | assert len(group_idx) == len(group_shape) + 1 136 | max_shape = [] 137 | for s in group_shape: 138 | max_shape.append(s[0]) 139 | max_shape.append(s[1]) 140 | max_shape.append(s[2]) 141 | 142 | group_cnt = [(group_idx[i+1] - group_idx[i]) for i in range(len(group_shape))] 143 | shapes = [] 144 | group_id = [] 145 | k = 0 146 | for i in range(len(group_cnt)): 147 | shape = [] 148 | for j in range(group_cnt[i]): 149 | shape.append(inputs[k].shape[0]) 150 | shape.append(inputs[k].shape[1]) 151 | shape.append(inputs[k].shape[2]) 152 | group_id.append(i) 153 | k = k + 1 154 | shapes.append(shape) 155 | assert len(group_id) == len(inputs) 156 | result = hpc_rl_utils.GroupPad3DForward(inputs, group_cnt, max_shape, group_id, group_idx, value) 157 | new_x = result[0] 158 | mask = result[1] 159 | return [tuple(new_x), tuple(mask), tuple(shapes)] 160 | else: 161 | shapes = [] 162 | for t in inputs: 163 | shapes.append(t.shape[0]) 164 | shapes.append(t.shape[1]) 165 | shapes.append(t.shape[2]) 166 | result = hpc_rl_utils.Pad3DForward(inputs, value) 167 | new_x = result[0] 168 | mask = result[1] 169 | return new_x, mask, shapes 170 | 171 | 172 | def UnPadding3D(x: Union[torch.Tensor, List[torch.Tensor]], 173 | shapes: Union[List, List[List]]) -> List[torch.Tensor]: 174 | if isinstance(x, torch.Tensor): 175 | return hpc_rl_utils.Unpad3DForward(x, shapes) 176 | else: 177 | ret = [] 178 | for t, s in zip(x, shapes): 179 | ret.append(hpc_rl_utils.Unpad3DForward(t, s)) 180 | return sum(ret, []) 181 | 182 | 183 | -------------------------------------------------------------------------------- /hpc_rll/rl_utils/upgo.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import hpc_rl_utils 3 | 4 | # hpc version only supports cuda 5 | # need to exclude the spe2d case 6 | 7 | class UpgoFunction(torch.autograd.Function): @staticmethod def forward(ctx, target_output, rho, action, reward, value, advantage, metric, loss, grad_buf,
grad_target_output): 10 | inputs = [target_output, rho, action, reward, value] 11 | outputs = [advantage, metric, loss, grad_buf] 12 | hpc_rl_utils.UpgoForward(inputs, outputs) 13 | 14 | ctx.bp_inputs = [grad_buf, advantage] 15 | ctx.bp_outputs = [grad_target_output] 16 | 17 | return loss 18 | 19 | @staticmethod 20 | def backward(ctx, grad_loss): 21 | inputs = [grad_loss] 22 | for var in ctx.bp_inputs: 23 | inputs.append(var) 24 | outputs = ctx.bp_outputs 25 | 26 | hpc_rl_utils.UpgoBackward(inputs, outputs) 27 | grad_target_output = outputs[0] 28 | return grad_target_output, None, None, None, None, None, None, None, None, None 29 | 30 | class UPGO(torch.nn.Module): 31 | """ 32 | Overview: 33 | Computing UPGO loss given constant gamma and lambda. There is no special handling for terminal state value, 34 | if the last state in trajectory is the terminal, just pass a 0 as bootstrap_terminal_value. 35 | 36 | Interface: 37 | __init__, forward 38 | """ 39 | 40 | def __init__(self, T, B, N): 41 | r""" 42 | Overview 43 | initialization of UPGO 44 | 45 | Arguments: 46 | - T (:obj:`int`): trajectory length 47 | - B (:obj:`int`): batch size 48 | - N (:obj:`int`): number of output 49 | """ 50 | 51 | super().__init__() 52 | self.register_buffer('loss', torch.zeros(1)) 53 | self.register_buffer('advantage', torch.zeros(T, B)) 54 | self.register_buffer('metric', torch.zeros(T, B)) 55 | self.register_buffer('grad_buf', torch.zeros(T, B, N)) 56 | self.register_buffer('grad_target_output', torch.zeros(T, B, N)) 57 | 58 | def forward(self, target_output, rhos, action, rewards, bootstrap_values): 59 | """ 60 | Overview: 61 | forward of UPGO 62 | Arguments: 63 | - target_output (:obj:`torch.Tensor`): :math:`(T, B, N)`, the output computed by the target policy network 64 | - rhos (:obj:`torch.Tensor`): :math:`(T, B)`, the importance sampling ratio 65 | - action (:obj:`torch.Tensor`): :math:`(T, B)`, the action taken 66 | - rewards (:obj:`torch.Tensor`): :math:`(T, B)`, the returns from time step 0 to T-1 67 | - bootstrap_values (:obj:`torch.Tensor`): :math:`(T + 1, B)`, estimation of the state value at step 0 to T 68 | Returns: 69 | - loss (:obj:`torch.Tensor`): :math:`()`, 0-dim tensor, Computed importance sampled UPGO loss, averaged over the samples 70 | """ 71 | assert(target_output.is_cuda) 72 | assert(rhos.is_cuda) 73 | assert(action.is_cuda) 74 | assert(rewards.is_cuda) 75 | assert(bootstrap_values.is_cuda) 76 | 77 | loss = UpgoFunction.apply(target_output, rhos, action, rewards, bootstrap_values, 78 | self.advantage, self.metric, self.loss, self.grad_buf, self.grad_target_output) 79 | return loss 80 | 81 | -------------------------------------------------------------------------------- /hpc_rll/rl_utils/vtrace.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from collections import namedtuple 4 | import hpc_rl_utils 5 | 6 | # hpc version only support cuda 7 | 8 | hpc_vtrace_loss = namedtuple('hpc_vtrace_loss', ['policy_loss', 'value_loss', 'entropy_loss']) 9 | 10 | class VtraceFunction(torch.autograd.Function): 11 | @staticmethod 12 | def forward(ctx, target_output, behaviour_output, 13 | action, value, reward, weight, gamma, lambda_, rho_clip_ratio, c_clip_ratio, rho_pg_clip_ratio, 14 | target_output_prob, target_output_entropy, 15 | target_output_grad_logits, target_output_grad_prob, target_output_grad_entropy, behaviour_output_prob, 16 | importance_weights, returns, advantages, pg_loss, value_loss, 
entropy_loss, grad_value, grad_target_output): 17 | 18 | inputs = [target_output, behaviour_output, action, value, reward, weight] 19 | outputs = [target_output_prob, target_output_entropy, 20 | target_output_grad_logits, target_output_grad_prob, target_output_grad_entropy, behaviour_output_prob, 21 | importance_weights, returns, advantages, pg_loss, value_loss, entropy_loss] 22 | 23 | hpc_rl_utils.VTraceForward(inputs, outputs, gamma, lambda_, rho_clip_ratio, c_clip_ratio, rho_pg_clip_ratio) 24 | 25 | bp_inputs = [value, action, weight, returns, advantages, target_output_grad_logits, target_output_grad_prob, target_output_grad_entropy] 26 | bp_outputs = [grad_value, grad_target_output] 27 | ctx.bp_inputs = bp_inputs 28 | ctx.bp_outputs = bp_outputs 29 | 30 | return pg_loss, value_loss, entropy_loss 31 | 32 | @staticmethod 33 | def backward(ctx, grad_pg_loss, grad_value_loss, grad_entropy_loss): 34 | inputs = [grad_pg_loss, grad_value_loss, grad_entropy_loss] 35 | for var in ctx.bp_inputs: 36 | inputs.append(var) 37 | outputs = ctx.bp_outputs 38 | 39 | hpc_rl_utils.VTraceBackward(inputs, outputs) 40 | 41 | grad_value = outputs[0] 42 | grad_target_output = outputs[1] 43 | return grad_target_output, None, None, grad_value, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None 44 | 45 | class VTrace(torch.nn.Module): 46 | """ 47 | Overview: 48 | Implementation of vtrace(IMPALA: Scalable Distributed Deep-RL with Importance Weighted Actor-Learner\ 49 | Architectures), (arXiv:1802.01561) 50 | 51 | Interface: 52 | __init__, forward 53 | """ 54 | 55 | def __init__(self, T, B, N): 56 | r""" 57 | Overview 58 | initialization of Vtrace 59 | 60 | Arguments: 61 | - T (:obj:`int`): trajectory length 62 | - B (:obj:`int`): batch size 63 | - N (:obj:`int`): number of output 64 | """ 65 | 66 | super().__init__() 67 | self.register_buffer('weight', torch.ones(T, B)) 68 | self.register_buffer('target_output_prob', torch.zeros(T, B)) 69 | self.register_buffer('target_output_entropy', torch.zeros(T, B)) 70 | self.register_buffer('target_output_grad_logits', torch.zeros(T, B, N)) 71 | self.register_buffer('target_output_grad_prob', torch.zeros(T, B, N)) 72 | self.register_buffer('target_output_grad_entropy', torch.zeros(T, B, N)) 73 | self.register_buffer('behaviour_output_prob', torch.zeros(T, B)) 74 | self.register_buffer('importance_weights', torch.zeros(T, B)) 75 | self.register_buffer('returns', torch.zeros(T, B)) 76 | self.register_buffer('advantages', torch.zeros(T, B)) 77 | self.register_buffer('pg_loss', torch.zeros(1)) 78 | self.register_buffer('value_loss', torch.zeros(1)) 79 | self.register_buffer('entropy_loss', torch.zeros(1)) 80 | self.register_buffer('grad_value', torch.zeros(T + 1, B)) 81 | self.register_buffer('grad_target_output', torch.zeros(T, B, N)) 82 | 83 | def forward(self, target_output, behaviour_output, action, value, reward, 84 | weight = None, 85 | gamma: float = 0.99, 86 | lambda_: float = 0.95, 87 | rho_clip_ratio: float = 1.0, 88 | c_clip_ratio: float = 1.0, 89 | rho_pg_clip_ratio: float = 1.0 90 | ): 91 | """ 92 | Overview: 93 | forward of Vtrace 94 | Arguments: 95 | - target_output (:obj:`torch.Tensor`): :math:`(T, B, N)`, the output taking the action by the current policy network,\ 96 | usually this output is network output logit 97 | - behaviour_output (:obj:`torch.Tensor`): :math:`(T, B, N)`, the output taking the action by the behaviour policy network,\ 98 | usually this output is network output logit, 
which is used to produce the trajectory(actor) 99 | - action (:obj:`torch.Tensor`): :math:`(T, B)`, the chosen action(index for the discrete action space) in trajectory,\ 100 | i.e.: behaviour_action 101 | - value (:obj:`torch.Tensor`): :math:`(T + 1, B)`, estimation of the state value at step 0 to T 102 | - reward (:obj:`torch.Tensor`): :math:`(T, B)`, the returns from time step 0 to T-1 103 | - gamma: (:obj:`float`): the future discount factor, defaults to 0.95 104 | - lambda: (:obj:`float`): mix factor between 1-step (lambda_=0) and n-step, defaults to 1.0 105 | - rho_clip_ratio (:obj:`float`): the clipping threshold for importance weights (rho) when calculating\ 106 | the baseline targets (vs) 107 | - c_clip_ratio (:obj:`float`): the clipping threshold for importance weights (c) when calculating\ 108 | the baseline targets (vs) 109 | - rho_pg_clip_ratio (:obj:`float`): the clipping threshold for importance weights (rho) when calculating\ 110 | the policy gradient advantage 111 | 112 | Returns: 113 | - trace_loss (:obj:`namedtuple`): the vtrace loss item, all of them are the differentiable 0-dim tensor 114 | """ 115 | 116 | assert(target_output.is_cuda) 117 | assert(behaviour_output.is_cuda) 118 | assert(action.is_cuda) 119 | assert(value.is_cuda) 120 | assert(reward.is_cuda) 121 | if weight is None: 122 | weight = self.weight 123 | else: 124 | assert(weight.is_cuda) 125 | 126 | pg_loss, value_loss, entropy_loss = VtraceFunction.apply(target_output, behaviour_output, 127 | action, value, reward, weight, gamma, lambda_, rho_clip_ratio, c_clip_ratio, rho_pg_clip_ratio, 128 | self.target_output_prob, self.target_output_entropy, 129 | self.target_output_grad_logits, self.target_output_grad_prob, self.target_output_grad_entropy, self.behaviour_output_prob, 130 | self.importance_weights, self.returns, self.advantages, 131 | self.pg_loss, self.value_loss, self.entropy_loss, self.grad_value, self.grad_target_output) 132 | 133 | return hpc_vtrace_loss(pg_loss, value_loss, entropy_loss) 134 | -------------------------------------------------------------------------------- /hpc_rll/torch_utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/opendilab/DI-hpc/a8a773480571491e70bb3021cbb2c1adcb7dce12/hpc_rll/torch_utils/__init__.py -------------------------------------------------------------------------------- /hpc_rll/torch_utils/network/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | -------------------------------------------------------------------------------- /hpc_rll/torch_utils/network/scatter_connection.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from typing import Tuple 3 | import hpc_torch_utils_network 4 | 5 | # hpc version only support cuda 6 | 7 | class ScatterConnectionFunction(torch.autograd.Function): 8 | @staticmethod 9 | def forward(ctx, input, location, output, grad_in, scatter_type): 10 | inputs = [input, location] 11 | outputs = [output] 12 | hpc_torch_utils_network.ScatterConnectionForward(inputs, outputs, scatter_type) 13 | 14 | ctx.bp_inputs = [location] 15 | ctx.bp_outputs = [grad_in] 16 | 17 | return output 18 | 19 | @staticmethod 20 | def backward(ctx, grad_out): 21 | inputs = [grad_out] 22 | for var in ctx.bp_inputs: 23 | inputs.append(var) 24 | outputs = ctx.bp_outputs 25 | 26 | hpc_torch_utils_network.ScatterConnectionBackward(inputs, outputs) 27 | grad_in = outputs[0] 28 
| return grad_in, None, None, None, None 29 | 30 | class ScatterConnection(torch.nn.Module): 31 | r""" 32 | Overview: 33 | Scatter feature to its corresponding location 34 | In alphastar, each entity is embedded into a tensor, these tensors are scattered into a feature map 35 | with map size 36 | 37 | Interface: 38 | __init__, forward 39 | """ 40 | 41 | def __init__(self, B, M, N, H, W, scatter_type) -> None: 42 | r""" 43 | Overview 44 | initialization of scatter connection 45 | 46 | Arguments: 47 | - B (:obj:`int`): batch size 48 | - M (:obj:`int`): the number of entity 49 | - N (:obj:`int`): the dimension of entity attributes 50 | - H (:obj:`int`): height of spatial feature 51 | - W (:obj:`int`): width of spatial feature 52 | - scatter_type (:obj:`str`): add or cover, if two entities have same location, scatter type decides the 53 | first one should be covered or added to second one 54 | """ 55 | 56 | super().__init__() 57 | self.B = B 58 | self.M = M 59 | self.N = N 60 | self.H = H 61 | self.W = W 62 | self.scatter_type = scatter_type 63 | assert self.scatter_type in ['cover', 'add'] 64 | 65 | self.register_buffer('output', torch.zeros(B, N, H, W)) 66 | self.register_buffer('grad_in', torch.zeros(B, M, N)) 67 | 68 | def forward(self, x: torch.Tensor, location: torch.Tensor) -> torch.Tensor: 69 | """ 70 | Overview: 71 | forward of scatter connection, scatter x into a spatial feature map 72 | Arguments: 73 | - x (:obj:`torch.FloatTensor`): :math: `(B, M, N)`, the input tensor 74 | - location (:obj:`torch.LongTensor`): :math: `(B, M, 2)`, each location should be (y, x) 75 | Returns: 76 | - output (:obj:`FloatTensor`): :math: `(B, N, H, W)`, the scattered feature map 77 | 78 | .. note:: 79 | when there are some overlapping in locations, ``cover`` mode will result in the loss of information, we 80 | use the addition as temporal substitute. 
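        Examples:
            An illustrative call (sizes are placeholders; requires a CUDA device and the compiled ``hpc_torch_utils_network`` extension):

            >>> B, M, N, H, W = 4, 8, 16, 10, 10
            >>> scatter = ScatterConnection(B, M, N, H, W, 'add').cuda()
            >>> x = torch.randn(B, M, N).cuda()
            >>> location = torch.randint(0, 10, (B, M, 2)).cuda()  # (y, x) coordinates inside (H, W)
            >>> out = scatter(x, location)  # (B, N, H, W)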
81 | """ 82 | 83 | assert(x.is_cuda) 84 | assert(location.is_cuda) 85 | 86 | output = ScatterConnectionFunction.apply(x, location, self.output, self.grad_in, self.scatter_type) 87 | return output 88 | 89 | -------------------------------------------------------------------------------- /include/hpc/rll/cuda/basic_math.h: -------------------------------------------------------------------------------- 1 | #ifndef HPC_RLL_CUDA_BASIC_MATH_H_ 2 | #define HPC_RLL_CUDA_BASIC_MATH_H_ 3 | 4 | #include 5 | #include 6 | #include 7 | #include 8 | 9 | namespace hpc { 10 | namespace rll { 11 | namespace cuda { 12 | 13 | template 14 | __forceinline__ __device__ 15 | T clamp(T in, T min, T max) { 16 | if (in < min) 17 | return min; 18 | else if (in <= max) 19 | return in; 20 | else 21 | return max; 22 | } 23 | 24 | template 25 | __forceinline__ __device__ 26 | T sigmoid(T in) { 27 | T one = static_cast(1.0); 28 | return one / (one + ::exp(-in)); 29 | } 30 | 31 | } // namespace cuda 32 | } // namespace rll 33 | } // namespace hpc 34 | 35 | #endif // HPC_RLL_CUDA_BASIC_MATH_H_ 36 | -------------------------------------------------------------------------------- /include/hpc/rll/cuda/common.h: -------------------------------------------------------------------------------- 1 | #ifndef HPC_RLL_CUDA_COMMON_H_ 2 | #define HPC_RLL_CUDA_COMMON_H_ 3 | 4 | #include 5 | #include 6 | #include 7 | #include 8 | 9 | #include 10 | 11 | #include "hpc/rll/cuda/status.h" 12 | 13 | namespace hpc { 14 | namespace rll { 15 | namespace cuda { 16 | 17 | #define TRACE { \ 18 | int err = cudaDeviceSynchronize(); \ 19 | fprintf(stderr, "TRACE: %s %d, err = %d\n", __FILE__, __LINE__, err); \ 20 | } 21 | 22 | static void print_tensor(const char* tensor_name, float* ptr, int len) { 23 | float hostbuf[len]; 24 | checkCudaErr(cudaMemcpy(hostbuf, ptr, len * sizeof(float), cudaMemcpyDeviceToHost)); 25 | 26 | fprintf(stderr, "%s\n", tensor_name); 27 | for (int i = 0; i < len; i++) 28 | fprintf(stderr, "%lf\n", hostbuf[i]); 29 | } 30 | 31 | static void save_tensor(const char* tensor_name, float* ptr, int len) { 32 | float hostbuf[len]; 33 | checkCudaErr(cudaMemcpy(hostbuf, ptr, len * sizeof(float), cudaMemcpyDeviceToHost)); 34 | 35 | char filename[256]; 36 | sprintf(filename, "%s.dat", tensor_name); 37 | std::ofstream outfile; 38 | outfile.open(filename); 39 | for (int i = 0; i < len; i++) 40 | outfile << hostbuf[i] << std::endl; 41 | outfile.close(); 42 | } 43 | 44 | const unsigned int DEFAULT_WARP_NUM = 8; 45 | const unsigned int WARP_SIZE = 32; 46 | const float CUDA_FLOAT_INF_POS = FLT_MAX; 47 | const float CUDA_FLOAT_INF_NEG = -FLT_MAX; 48 | 49 | // torch.finfo(torch.float32).eps say epsilon = 1.19209e-07, but pytorch layernorm userguide say epsilon = 1e-5 50 | const float EPSILON = 1e-5; 51 | 52 | } // namespace cuda 53 | } // namespace rll 54 | } // namespace hpc 55 | 56 | #endif // HPC_RLL_CUDA_COMMON_H_ 57 | -------------------------------------------------------------------------------- /include/hpc/rll/cuda/models/actor_critic_kernel.h: -------------------------------------------------------------------------------- 1 | #ifndef HPC_RLL_CUDA_ACTOR_CRITIC_KERNEL_H_ 2 | #define HPC_RLL_CUDA_ACTOR_CRITIC_KERNEL_H_ 3 | 4 | #include "hpc/rll/cuda/common.h" 5 | #include "hpc/rll/cuda/reduce.h" 6 | #include "hpc/rll/cuda/basic_math.h" 7 | 8 | namespace hpc { 9 | namespace rll { 10 | namespace cuda { 11 | 12 | // key_embeddings: [batch_size, max_entity_num, input_dim] 13 | // autoregressive_embedding: [batch_size, input_dim] 14 | 
__global__ void autoregressive_embedding_fp(int64_t batch_size, int64_t max_entity_num, int64_t input_dim, 15 | int64_t* sample_entity, int64_t* entity_num, 16 | const float* key_embeddings, float* autoregressive_embedding) { 17 | unsigned int gidx = blockIdx.x * blockDim.x + threadIdx.x; // input_dim 18 | unsigned int gidy = blockIdx.y; // batch_size 19 | 20 | int64_t entity_index = sample_entity[gidy]; 21 | bool end_flag = (entity_index == entity_num[gidy]); 22 | 23 | float ke = 0.f; 24 | if (!end_flag) { 25 | int64_t ke_index = gidy * max_entity_num * input_dim + entity_index * input_dim + gidx; 26 | ke = key_embeddings[ke_index]; 27 | } 28 | 29 | int64_t ae_index = gidy * input_dim + gidx; 30 | autoregressive_embedding[ae_index] += ke; 31 | } 32 | 33 | __global__ void lstm_activation_fp(unsigned int batch_size, unsigned int hidden_size, 34 | const float* in_x, const float* in_h, const float* bias, float* h, float* c) { 35 | unsigned int gidx = blockIdx.x * blockDim.x + threadIdx.x; // hidden_size 36 | unsigned int gidy = blockIdx.y; // batch_size 37 | unsigned int start = gidy * hidden_size * 4; 38 | if (gidx < hidden_size) { 39 | float val[4]; 40 | for (int i = 0; i < 4; i++) { 41 | val[i] = in_x[start + i * hidden_size + gidx] + in_h[start + i * hidden_size + gidx] 42 | + bias[i * hidden_size + gidx]; 43 | } 44 | 45 | float i = sigmoid(val[0]); 46 | float f = sigmoid(val[1]); 47 | float g = tanh(val[2]); 48 | float o = sigmoid(val[3]); 49 | float pre_c = c[gidy * hidden_size + gidx]; 50 | float new_c = f * pre_c + i * g; 51 | float new_h = o * tanh(new_c); 52 | 53 | h[gidy * hidden_size + gidx] = new_h; 54 | c[gidy * hidden_size + gidx] = new_c; 55 | } 56 | } 57 | 58 | __global__ void pre_sample_fp(unsigned int batch_size, unsigned int max_entity_num, unsigned int hidden_size, 59 | const float mask_value, const float div_factor, 60 | const float* mat, const float* vec, const bool* mask, float* output) { 61 | unsigned int tidx = threadIdx.x; // hidden_size 62 | unsigned int gidy = blockIdx.y; // max_entity_num 63 | unsigned int gidz = blockIdx.z; // batch_size 64 | 65 | if (mask[gidz * max_entity_num + gidy]) { 66 | float mul_val = 0.f; 67 | for (int i = tidx; i < hidden_size; i += blockDim.x) { 68 | float mat_val = mat[gidz * max_entity_num * hidden_size + gidy * hidden_size + i]; 69 | float vec_val = vec[gidz * hidden_size + i]; 70 | mul_val += mat_val * vec_val; 71 | } 72 | float reduce_sum = blockReduceSum(mul_val); 73 | if (tidx == 0) 74 | output[gidz * max_entity_num + gidy] = reduce_sum / div_factor; 75 | } else { 76 | if (tidx == 0) 77 | output[gidz * max_entity_num + gidy] = mask_value / div_factor; 78 | } 79 | } 80 | 81 | } // namespace cuda 82 | } // namespace rll 83 | } // namespace hpc 84 | #endif // HPC_RLL_CUDA_ACTOR_CRITIC_KERNEL_H_ 85 | -------------------------------------------------------------------------------- /include/hpc/rll/cuda/models/entry.h: -------------------------------------------------------------------------------- 1 | #ifndef HPC_RLL_CUDA_NETWORK_H_ 2 | #define HPC_RLL_CUDA_NETWORK_H_ 3 | 4 | #include "hpc/rll/cuda/common.h" 5 | 6 | namespace hpc { 7 | namespace rll { 8 | namespace cuda { 9 | 10 | void actor_critic_update_ae( 11 | const std::vector& inputs, 12 | std::vector& outputs); 13 | 14 | void actor_critic_lstm_activation( 15 | const std::vector& inputs, 16 | std::vector& outputs); 17 | 18 | void actor_critic_pre_sample( 19 | const std::vector& inputs, 20 | std::vector& outputs); 21 | 22 | } // namespace cuda 23 | } // namespace rll 24 | } // 
namespace hpc 25 | 26 | #endif // HPC_RLL_CUDA_NETWORK_H_ 27 | -------------------------------------------------------------------------------- /include/hpc/rll/cuda/reduce.h: -------------------------------------------------------------------------------- 1 | #ifndef HPC_RLL_CUDA_REDUCE_H_ 2 | #define HPC_RLL_CUDA_REDUCE_H_ 3 | 4 | #include "hpc/rll/cuda/common.h" 5 | 6 | namespace hpc { 7 | namespace rll { 8 | namespace cuda { 9 | 10 | const unsigned int WARP_REDUCE_MASK = 0xffffffff; 11 | 12 | // reduce to all the threads in the warp 13 | template 14 | __forceinline__ __device__ T warpReduceSum(T val) { 15 | for (int mask = (WARP_SIZE >> 1); mask > 0; mask >>= 1) 16 | val += __shfl_xor_sync(WARP_REDUCE_MASK, val, mask, WARP_SIZE); 17 | return val; 18 | } 19 | 20 | // reduce to all the threads in the warp 21 | template 22 | __forceinline__ __device__ T warpReduceMax(T val) { 23 | for (int mask = (WARP_SIZE >> 1); mask > 0; mask >>= 1) 24 | val = max(val, __shfl_xor_sync(WARP_REDUCE_MASK, val, mask, WARP_SIZE)); 25 | return val; 26 | } 27 | 28 | // reduce to all the threads in the warp 29 | template 30 | __forceinline__ __device__ T warpReduceMin(T val) { 31 | for (int mask = (WARP_SIZE >> 1); mask > 0; mask >>= 1) 32 | val = min(val, __shfl_xor_sync(WARP_REDUCE_MASK, val, mask, WARP_SIZE)); 33 | return val; 34 | } 35 | 36 | // Calculate the sum of all elements in a block, 37 | // reduce to thread 0, not all the threads in the block 38 | template 39 | __forceinline__ __device__ T blockReduceSum(T val) { 40 | static __shared__ T shared[32]; 41 | int lane = threadIdx.x & 0x1f; 42 | int wid = threadIdx.x >> 5; 43 | 44 | val = warpReduceSum(val); 45 | 46 | if (lane == 0) shared[wid] = val; 47 | __syncthreads(); 48 | 49 | if (wid == 0) { 50 | val = (threadIdx.x < ((blockDim.x + 31) >> 5)) ? shared[lane] : (T)0.0f; 51 | val = warpReduceSum(val); 52 | return val; 53 | } 54 | return (T)0.0f; 55 | } 56 | 57 | // Calculate the maximum of all elements in a block 58 | // reduce to thread 0, not all the threads in the block 59 | template 60 | __forceinline__ __device__ T blockReduceMax(T val) { 61 | static __shared__ T shared[32]; 62 | int lane = threadIdx.x & 0x1f; 63 | int wid = threadIdx.x >> 5; 64 | 65 | val = warpReduceMax(val); 66 | 67 | if (lane == 0) shared[wid] = val; 68 | __syncthreads(); 69 | 70 | if (wid == 0) { 71 | val = (threadIdx.x < ((blockDim.x + 31) >> 5)) ? shared[lane] 72 | : CUDA_FLOAT_INF_NEG; 73 | val = warpReduceMax(val); 74 | return val; 75 | } 76 | return CUDA_FLOAT_INF_NEG; 77 | } 78 | 79 | // Calculate the minimum of all elements in a block 80 | // reduce to thread 0, not all the threads in the block 81 | template 82 | __forceinline__ __device__ T blockReduceMin(T val) { 83 | static __shared__ T shared[32]; 84 | int lane = threadIdx.x & 0x1f; 85 | int wid = threadIdx.x >> 5; 86 | 87 | val = warpReduceMin(val); 88 | 89 | if (lane == 0) shared[wid] = val; 90 | __syncthreads(); 91 | 92 | if (wid == 0) { 93 | val = (threadIdx.x < ((blockDim.x + 31) >> 5)) ? 
shared[lane] 94 | : CUDA_FLOAT_INF_POS; 95 | val = warpReduceMin(val); 96 | return val; 97 | } 98 | return CUDA_FLOAT_INF_POS; 99 | } 100 | 101 | } // namespace cuda 102 | } // namespace rll 103 | } // namespace hpc 104 | 105 | #endif // HPC_RLL_CUDA_REDUCE_H_ 106 | -------------------------------------------------------------------------------- /include/hpc/rll/cuda/rl_utils/dist_nstep_td_kernel.h: -------------------------------------------------------------------------------- 1 | #ifndef HPC_RLL_CUDA_DIST_NSTEP_TD_KERNEL_H_ 2 | #define HPC_RLL_CUDA_DIST_NSTEP_TD_KERNEL_H_ 3 | 4 | #include "hpc/rll/cuda/common.h" 5 | #include "hpc/rll/cuda/reduce.h" 6 | 7 | namespace hpc { 8 | namespace rll { 9 | namespace cuda { 10 | 11 | void __global__ distNStepTdRewardKernel(unsigned int time_step, unsigned int batch_size, float gamma, 12 | const float* reward, float* reward_buf) { 13 | unsigned int gid = threadIdx.x + blockIdx.x*blockDim.x; // batch_size 14 | 15 | if (gid < batch_size) { 16 | unsigned int batch_id = gid; 17 | 18 | float sum_reward = 0; 19 | float factor = 1; 20 | for (int t = 0; t < time_step; ++t) { 21 | float rw = reward[t * batch_size + batch_id]; 22 | sum_reward += (factor * rw); 23 | factor *= gamma; 24 | } 25 | 26 | reward_buf[batch_id] = sum_reward; 27 | } 28 | } 29 | 30 | void __global__ distNStepTdProjKernel(unsigned int batch_size, unsigned int action_dim, unsigned int n_atom, 31 | float gamma_nstep, float v_min, float v_max, float delta, 32 | const float* next_n_dist, const int64_t* next_n_action, 33 | const float* reward_buf, const float* done, float* proj_dist) { 34 | unsigned int gidx = threadIdx.x + blockIdx.x*blockDim.x; // n_atom 35 | unsigned int gidy = blockIdx.y; // batch_size 36 | 37 | if (gidx < n_atom) { 38 | unsigned int atom_id = gidx; 39 | unsigned int batch_id = gidy; 40 | 41 | float reward = reward_buf[batch_id]; 42 | float support = v_min + atom_id * delta; 43 | float target = reward + (1 - done[batch_id]) * gamma_nstep * support; 44 | target = min(v_max, target); 45 | target = max(v_min, target); 46 | 47 | float local_box_id = (target - v_min) / delta; 48 | unsigned int local_box_id_l = floor(local_box_id); 49 | unsigned int local_box_id_u = ceil(local_box_id); 50 | unsigned int global_box_id_l = batch_id * n_atom + local_box_id_l; 51 | unsigned int global_box_id_u = batch_id * n_atom + local_box_id_u; 52 | 53 | unsigned int next_n_action_id = next_n_action[batch_id]; 54 | float target_dist_sa = next_n_dist[batch_id * action_dim * n_atom + next_n_action_id * n_atom + atom_id]; 55 | 56 | float proj_dist_l = target_dist_sa * ((float)local_box_id_u - local_box_id); 57 | float proj_dist_u = target_dist_sa * (local_box_id - (float)local_box_id_l); 58 | atomicAdd(&proj_dist[global_box_id_l], proj_dist_l); 59 | atomicAdd(&proj_dist[global_box_id_u], proj_dist_u); 60 | } 61 | } 62 | 63 | void __global__ distNStepTdLossKernel(unsigned int batch_size, unsigned int action_dim, unsigned int n_atom, 64 | const float* dist, const int64_t* action, const float* proj_dist, const float* weight, 65 | float* td_err, float* loss, float* grad_buf) { 66 | unsigned int gidx = threadIdx.x + blockIdx.x*blockDim.x; // n_atom 67 | unsigned int gidy = blockIdx.y; // batch_size 68 | 69 | float sum_val = 0; 70 | float sum_td_err = 0; 71 | if (gidx < n_atom) { 72 | unsigned int atom_id = gidx; 73 | unsigned int batch_id = gidy; 74 | 75 | unsigned int action_id = action[batch_id]; 76 | float dist_sa = dist[batch_id * action_dim * n_atom + action_id * n_atom + atom_id]; 77 | float log_p 
= log(dist_sa); 78 | 79 | float w = weight[batch_id]; 80 | float proj = proj_dist[batch_id * n_atom + atom_id]; 81 | sum_val = log_p * proj * w; 82 | sum_td_err = log_p * proj; 83 | 84 | grad_buf[batch_id * n_atom + atom_id] = (-1.f) / (float)batch_size * w * proj * (1.f / dist_sa); 85 | } 86 | 87 | float reduced_sum_val = blockReduceSum(sum_val); 88 | float reduced_sum_td_err = blockReduceSum(sum_td_err); 89 | if (threadIdx.x == 0) { 90 | td_err[gidy] = reduced_sum_td_err * (-1.f); 91 | atomicAdd(loss, reduced_sum_val * (-1.f) / (float)batch_size); 92 | } 93 | } 94 | 95 | void __global__ distNStepTdBackwardKernel(unsigned int batch_size, unsigned int action_dim, unsigned int n_atom, 96 | const float* grad_loss, const float* grad_buf, const int64_t* action, float* grad_dist) { 97 | unsigned int gid = threadIdx.x + blockIdx.x*blockDim.x; 98 | 99 | if (gid < batch_size * action_dim * n_atom) { 100 | unsigned int atom_id = gid % n_atom; 101 | unsigned int action_id = (gid / n_atom) % action_dim; 102 | unsigned int batch_id = (gid / n_atom) / action_dim; 103 | 104 | float grad = (action_id == action[batch_id]) ? grad_buf[batch_id * n_atom + atom_id] : 0; 105 | grad_dist[gid] = (*grad_loss) * grad; 106 | } 107 | } 108 | 109 | } // namespace cuda 110 | } // namespace rll 111 | } // namespace hpc 112 | #endif // HPC_RLL_CUDA_DIST_NSTEP_TD_KERNEL_H_ 113 | -------------------------------------------------------------------------------- /include/hpc/rll/cuda/rl_utils/entry.h: -------------------------------------------------------------------------------- 1 | #ifndef HPC_RLL_CUDA_LOSS_H_ 2 | #define HPC_RLL_CUDA_LOSS_H_ 3 | 4 | #include "hpc/rll/cuda/common.h" 5 | 6 | namespace hpc { 7 | namespace rll { 8 | namespace cuda { 9 | 10 | std::vector> sample_split_group(const std::vector& x, int group); 11 | std::vector> oracle_split_group(const std::vector& x, int group); 12 | 13 | std::vector Pad1DForward( 14 | const std::vector& inputs, 15 | const int& value); 16 | 17 | std::vector> GroupPad1DForward( 18 | const std::vector& inputs, 19 | const std::vector& group_cnt, 20 | const std::vector& max_shape, 21 | const std::vector& group_id, 22 | const std::vector& group_idx, 23 | const int& value); 24 | 25 | std::vector Unpad1DForward( 26 | const torch::Tensor& inputs, 27 | const std::vector& shape); 28 | 29 | std::vector Pad2DForward( 30 | const std::vector& inputs, 31 | const int& value); 32 | 33 | std::vector> GroupPad2DForward( 34 | const std::vector& inputs, 35 | const std::vector& group_cnt, 36 | const std::vector& max_shape, 37 | const std::vector& group_id, 38 | const std::vector& group_idx, 39 | const int& value); 40 | 41 | std::vector Unpad2DForward( 42 | const torch::Tensor& inputs, 43 | const std::vector& shape); 44 | 45 | std::vector Pad3DForward( 46 | const std::vector& inputs, 47 | const int& value); 48 | 49 | std::vector> GroupPad3DForward( 50 | const std::vector& inputs, 51 | const std::vector& group_cnt, 52 | const std::vector& max_shape, 53 | const std::vector& group_id, 54 | const std::vector& group_idx, 55 | const int& value); 56 | 57 | std::vector Unpad3DForward( 58 | const torch::Tensor& inputs, 59 | const std::vector& shape); 60 | 61 | // gae 62 | void GaeForward( 63 | const std::vector& inputs, 64 | std::vector& outputs, 65 | float gamma, 66 | float lambda); 67 | 68 | // td_lambda 69 | void TdLambdaForward( 70 | const std::vector& inputs, 71 | std::vector& outputs, 72 | float gamma, 73 | float lambda); 74 | 75 | void TdLambdaBackward( 76 | const std::vector& inputs, 77 | std::vector& 
outputs); 78 | 79 | // dist_nstep_td 80 | void DistNStepTdForward( 81 | const std::vector& inputs, 82 | std::vector& outputs, 83 | float gamma, 84 | float v_min, 85 | float v_max); 86 | 87 | void DistNStepTdBackward( 88 | const std::vector& inputs, 89 | std::vector& outputs); 90 | 91 | // q_nstep_td 92 | void QNStepTdForward( 93 | const std::vector& inputs, 94 | std::vector& outputs, 95 | float gamma); 96 | 97 | void QNStepTdBackward( 98 | const std::vector& inputs, 99 | std::vector& outputs); 100 | 101 | // q_nstep_td_with_rescale 102 | void QNStepTdRescaleForward( 103 | const std::vector& inputs, 104 | std::vector& outputs, 105 | float gamma); 106 | 107 | void QNStepTdRescaleBackward( 108 | const std::vector& inputs, 109 | std::vector& outputs); 110 | 111 | // upgo 112 | void UpgoForward( 113 | const std::vector& inputs, 114 | std::vector& outputs); 115 | 116 | void UpgoBackward( 117 | const std::vector& inputs, 118 | std::vector& outputs); 119 | 120 | // vtrace 121 | void VTraceForward( 122 | const std::vector& inputs, 123 | std::vector& outputs, 124 | float gamma, 125 | float lambda, 126 | float rho_clip_ratio, 127 | float c_clip_ratio, 128 | float rho_pg_clip_ratio); 129 | 130 | void VTraceBackward( 131 | const std::vector& inputs, 132 | std::vector& outputs); 133 | 134 | // ppo 135 | void PPOForward( 136 | const std::vector& inputs, 137 | std::vector& outputs, 138 | bool use_value_clip, 139 | float clip_ratio, 140 | float dual_clip); 141 | 142 | void PPOBackward( 143 | const std::vector& inputs, 144 | std::vector& outputs); 145 | 146 | // iqn_nstep_td_error 147 | void IQNNStepTDErrorForward( 148 | const std::vector& inputs, 149 | std::vector& outputs, 150 | float gamma, 151 | float kappa); 152 | 153 | void IQNNStepTDErrorBackward( 154 | const std::vector& inputs, 155 | std::vector& outputs); 156 | 157 | // qrdqn_nstep_td_error 158 | void QRDQNNStepTDErrorForward( 159 | const std::vector& inputs, 160 | std::vector& outputs, 161 | float gamma); 162 | 163 | void QRDQNNStepTDErrorBackward( 164 | const std::vector& inputs, 165 | std::vector& outputs); 166 | 167 | } // namespace cuda 168 | } // namespace rll 169 | } // namespace hpc 170 | 171 | #endif // HPC_RLL_CUDA_LOSS_H_ 172 | -------------------------------------------------------------------------------- /include/hpc/rll/cuda/rl_utils/gae_kernel.h: -------------------------------------------------------------------------------- 1 | #ifndef HPC_RLL_CUDA_GAE_KERNEL_H_ 2 | #define HPC_RLL_CUDA_GAE_KERNEL_H_ 3 | 4 | #include "hpc/rll/cuda/common.h" 5 | 6 | namespace hpc { 7 | namespace rll { 8 | namespace cuda { 9 | 10 | void __global__ gaeForwardKernel(unsigned int time_step, unsigned int batch_size, float gamma, float lambda, 11 | const float* value, const float* reward, float* adv) { 12 | unsigned int gid = threadIdx.x + blockIdx.x*blockDim.x; 13 | if (gid < batch_size) { 14 | float gae_item = 0; 15 | float denom = 0; 16 | float factor = gamma * lambda; 17 | for (int t = time_step - 1; t >= 0; --t) { 18 | unsigned int index = t * batch_size + gid; 19 | 20 | denom = 1 + lambda * denom; 21 | float reward_data = reward[index]; 22 | float value_data = value[index]; 23 | float next_value_data = value[index + batch_size]; 24 | float delta = reward_data + gamma * next_value_data - value_data; 25 | gae_item = denom * delta + factor * gae_item; 26 | adv[index] = gae_item / denom; 27 | } 28 | } 29 | } 30 | 31 | } // namespace cuda 32 | } // namespace rll 33 | } // namespace hpc 34 | 35 | #endif // HPC_RLL_CUDA_GAE_KERNEL_H_ 36 | 
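For reference, a minimal NumPy sketch of the recursion that `gaeForwardKernel` above implements (an illustration under the stated shape assumptions, not part of the library): the kernel reads `value[t + 1]`, so `value` is expected to carry one more time step than `reward`.

```python
import numpy as np

def gae_reference(value, reward, gamma=0.99, lambda_=0.95):
    # value: (T + 1, B), reward: (T, B); mirrors the per-thread loop in gaeForwardKernel
    T, B = reward.shape
    adv = np.zeros((T, B), dtype=np.float32)
    gae_item = np.zeros(B, dtype=np.float32)
    denom = np.zeros(B, dtype=np.float32)
    for t in reversed(range(T)):
        denom = 1.0 + lambda_ * denom
        delta = reward[t] + gamma * value[t + 1] - value[t]
        gae_item = denom * delta + gamma * lambda_ * gae_item
        adv[t] = gae_item / denom
    return adv
```

The CUDA kernel assigns one thread per batch column and runs this same backward scan in registers, which is why `GaeForward` in `src/rl_utils/gae.cu` launches only a single warp per block.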
-------------------------------------------------------------------------------- /include/hpc/rll/cuda/rl_utils/iqn_nstep_td_error_kernel.h: -------------------------------------------------------------------------------- 1 | #ifndef HPC_RLL_CUDA_IQN_NSTEP_TD_ERROR_KERNEL_H_ 2 | #define HPC_RLL_CUDA_IQN_NSTEP_TD_ERROR_KERNEL_H_ 3 | 4 | #include "hpc/rll/cuda/common.h" 5 | #include "hpc/rll/cuda/reduce.h" 6 | 7 | namespace hpc { 8 | namespace rll { 9 | namespace cuda { 10 | 11 | void __global__ bellmanErrorKernel(unsigned int tau, unsigned int tau_prime, 12 | unsigned int time_step, unsigned int batch_size, unsigned int action_dim, float gamma, float kappa, 13 | const float* q, const float* next_n_q, const int64_t* action, const int64_t* next_n_action, 14 | const float* reward, const float* done, const float* value_gamma, float* bellman_err_buf) { 15 | unsigned int gidx = threadIdx.x + blockIdx.x*blockDim.x; // tau 16 | unsigned int gidy = blockIdx.y; // batch_size 17 | if (gidx >= tau) return; 18 | 19 | int64_t action_data = action[gidy]; 20 | int64_t next_n_action_data = next_n_action[gidy]; 21 | float value_gamma_data = value_gamma[gidy]; 22 | float done_data = done[gidy]; 23 | float not_done_data = 1.f - done_data; 24 | 25 | // for i in range(1, nstep): 26 | // reward_factor[i] = gamma * reward_factor[i - 1] 27 | // reward = torch.matmul(reward_factor, reward) 28 | float reward_factor = 1; 29 | float reward_data = 0; 30 | for (int t = 0; t < time_step; t++) { 31 | reward_data += reward_factor * reward[t * batch_size + gidy]; 32 | reward_factor *= gamma; 33 | } 34 | 35 | float qsa = q[gidx * batch_size * action_dim + gidy * action_dim + action_data]; 36 | for (int t = 0; t < tau_prime; t++) { 37 | float target_qsa = next_n_q[t * batch_size * action_dim + gidy * action_dim + next_n_action_data]; 38 | float target_qsa_transform = reward_data + value_gamma_data * target_qsa * not_done_data; 39 | 40 | bellman_err_buf[gidy * tau_prime * tau + t * tau + gidx] = target_qsa_transform - qsa; 41 | } 42 | } 43 | 44 | void __global__ quantileHuberErrorKernel(unsigned int tau, unsigned int tau_prime, unsigned int batch_size, float kappa, 45 | const float* bellman_err_buf, const float* replay_quantiles, 46 | float* quantile_huber_loss_buf, float* grad_buf) { 47 | unsigned int gidx = threadIdx.x + blockIdx.x*blockDim.x; // tau 48 | unsigned int gidy = threadIdx.y + blockIdx.y*blockDim.y; // tau_prime 49 | unsigned int gidz = blockIdx.z; // batch_size 50 | if (gidx >= tau || gidy >= tau_prime) return; 51 | 52 | float bellman_error_data = bellman_err_buf[gidz * tau_prime * tau + gidy * tau + gidx]; 53 | float huber_loss_data = 0.f; 54 | float grad_buf_data = 0.f; 55 | if (abs(bellman_error_data) <= kappa) { 56 | huber_loss_data = 0.5 * bellman_error_data * bellman_error_data; 57 | grad_buf_data = bellman_error_data; 58 | } else { 59 | huber_loss_data = kappa * (abs(bellman_error_data) - 0.5 * kappa); 60 | grad_buf_data = (bellman_error_data >= 0) ? kappa : (-kappa); 61 | } 62 | 63 | float r_q_data = replay_quantiles[gidx * batch_size + gidz]; 64 | float tmp = abs(r_q_data - ((bellman_error_data < 0) ? 
1.f : 0.f)) / kappa; 65 | float quantile_huber_loss_data = tmp * huber_loss_data; 66 | grad_buf_data *= tmp; 67 | 68 | quantile_huber_loss_buf[gidz * tau_prime * tau + gidy * tau + gidx] = quantile_huber_loss_data; 69 | grad_buf[gidz * tau_prime * tau + gidy * tau + gidx] = grad_buf_data; 70 | } 71 | 72 | void __global__ lossKernel(unsigned int tau, unsigned int tau_prime, unsigned int batch_size, 73 | const float* quantile_huber_loss_buf, const float* weight, 74 | float* td_err, float* loss) { 75 | unsigned int block_start = blockIdx.x * tau * tau_prime; 76 | unsigned int start = block_start + threadIdx.x; 77 | unsigned int end = block_start + tau * tau_prime; 78 | 79 | float partial_sum_huber = 0.f; 80 | for (int i = start; i < end; i += blockDim.x) { 81 | partial_sum_huber += quantile_huber_loss_buf[i]; 82 | } 83 | float sum_huber = blockReduceSum(partial_sum_huber); 84 | float mean_huber = sum_huber / tau_prime; 85 | 86 | if (threadIdx.x == 0) { 87 | td_err[blockIdx.x] = mean_huber; 88 | atomicAdd(loss, mean_huber * weight[blockIdx.x] / batch_size); 89 | } 90 | } 91 | 92 | void __global__ backwardKernel(unsigned int tau, unsigned int tau_prime, unsigned int batch_size, unsigned int action_dim, 93 | const float* grad_loss, const float* grad_buf, 94 | const float* weight, const int64_t* action, float* grad_q) { 95 | unsigned int gidx = threadIdx.x + blockIdx.x*blockDim.x; // tau 96 | unsigned int gidy = blockIdx.y; // batch_size 97 | if (gidx >= tau) return; 98 | 99 | float grad_data = 0.f; 100 | for (int t = 0; t < tau_prime; t++) { 101 | grad_data += grad_buf[gidy * tau_prime * tau + t * tau + gidx];; 102 | } 103 | grad_data *= (grad_loss[0] / batch_size * weight[gidy] / tau_prime); // mean, weight, mean 104 | grad_data *= -1.f; // target_qsa - qsa 105 | 106 | int output_index = gidx * batch_size * action_dim + gidy * action_dim + action[gidy]; 107 | grad_q[output_index] = grad_data; 108 | } 109 | 110 | } // namespace cuda 111 | } // namespace rll 112 | } // namespace hpc 113 | #endif // HPC_RLL_CUDA_IQN_NSTEP_TD_ERROR_KERNEL_H_ 114 | -------------------------------------------------------------------------------- /include/hpc/rll/cuda/rl_utils/q_nstep_td_kernel.h: -------------------------------------------------------------------------------- 1 | #ifndef HPC_RLL_CUDA_Q_NSTEP_TD_KERNEL_H_ 2 | #define HPC_RLL_CUDA_Q_NSTEP_TD_KERNEL_H_ 3 | 4 | #include "hpc/rll/cuda/common.h" 5 | #include "hpc/rll/cuda/reduce.h" 6 | 7 | namespace hpc { 8 | namespace rll { 9 | namespace cuda { 10 | 11 | void __global__ qNStepTdForwardKernel(unsigned int time_step, unsigned int batch_size, unsigned int num_output, float gamma, 12 | const float* q, const float* next_n_q, const int64_t* action, const int64_t* next_n_action, 13 | const float* reward, const float* done, const float* weight, 14 | float* td_err, float* loss, float* grad_buf) { 15 | unsigned int gid = threadIdx.x + blockIdx.x*blockDim.x; 16 | 17 | float sum_square = 0; 18 | if (gid < batch_size) { 19 | unsigned int batch_id = gid; 20 | unsigned int num_out_id = action[batch_id]; 21 | 22 | float qsa = q[batch_id * num_output + num_out_id]; 23 | unsigned int next_n_num_out_id = next_n_action[batch_id]; 24 | float target_qsa = next_n_q[batch_id * num_output + next_n_num_out_id]; 25 | 26 | // nstep_return 27 | float sum_reward = 0; 28 | float factor = 1; 29 | for (int t = 0; t < time_step; ++t) { 30 | float rw = reward[t * batch_size + batch_id]; 31 | sum_reward += (factor * rw); 32 | factor *= gamma; 33 | } 34 | float done_ = done[batch_id]; 35 | 
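        // n-step target computed below (sketch): after the reward loop, factor == gamma^nstep, so
        //   target_qsa = sum_{t < nstep} gamma^t * r_t + gamma^nstep * Q'(s_{t+n}, a') * (1 - done)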
target_qsa = sum_reward + factor * target_qsa * (1.f - done_); 36 | 37 | float diff = qsa - target_qsa; 38 | sum_square = diff * diff; 39 | td_err[batch_id] = sum_square; 40 | 41 | float w = weight[batch_id]; 42 | sum_square *= w; 43 | grad_buf[batch_id] = 1.f / batch_size * (2.f * diff) * w; 44 | } 45 | 46 | float reduced_sum_square = blockReduceSum(sum_square); 47 | if (threadIdx.x == 0) { 48 | float mean_loss = reduced_sum_square / batch_size; 49 | atomicAdd(loss, mean_loss); 50 | } 51 | } 52 | 53 | void __global__ qNStepTdBackwardKernel(unsigned int batch_size, unsigned int num_output, 54 | const float* grad_loss, const float* grad_buf, const int64_t* action, float* grad_q) { 55 | unsigned int gid = threadIdx.x + blockIdx.x*blockDim.x; // num_output 56 | 57 | if (gid < num_output) { 58 | unsigned int batch_id = blockIdx.y; 59 | float grad = (gid == action[batch_id]) ? grad_buf[batch_id] : 0; 60 | grad_q[batch_id * num_output + gid] = (*grad_loss) * grad; 61 | } 62 | } 63 | 64 | } // namespace cuda 65 | } // namespace rll 66 | } // namespace hpc 67 | #endif // HPC_RLL_CUDA_Q_NSTEP_TD_KERNEL_H_ 68 | -------------------------------------------------------------------------------- /include/hpc/rll/cuda/rl_utils/q_nstep_td_rescale_kernel.h: -------------------------------------------------------------------------------- 1 | #ifndef HPC_RLL_CUDA_Q_NSTEP_TD_RESCALE_KERNEL_H_ 2 | #define HPC_RLL_CUDA_Q_NSTEP_TD_RESCALE_KERNEL_H_ 3 | 4 | #include "hpc/rll/cuda/common.h" 5 | #include "hpc/rll/cuda/reduce.h" 6 | 7 | namespace hpc { 8 | namespace rll { 9 | namespace cuda { 10 | 11 | void __global__ qNStepTdRescaleForwardKernel(unsigned int time_step, unsigned int batch_size, unsigned int num_output, float gamma, 12 | const float* q, const float* next_n_q, const int64_t* action, const int64_t* next_n_action, 13 | const float* reward, const float* done, const float* weight, 14 | float* td_err, float* loss, float* grad_buf) { 15 | unsigned int gid = threadIdx.x + blockIdx.x*blockDim.x; 16 | 17 | float sum_square = 0; 18 | if (gid < batch_size) { 19 | unsigned int batch_id = gid; 20 | unsigned int num_out_id = action[batch_id]; 21 | 22 | float qsa = q[batch_id * num_output + num_out_id]; 23 | unsigned int next_n_num_out_id = next_n_action[batch_id]; 24 | float target_qsa = next_n_q[batch_id * num_output + next_n_num_out_id]; 25 | 26 | // value_inv_transform 27 | float eps = 1e-2; 28 | float sign = (target_qsa > 0.f ? 1.f : (target_qsa < 0.f ? -1.f : 0.f)); 29 | float tmp = (sqrt(1.f + 4.f * eps * (abs(target_qsa) + 1.f + eps)) - 1) / (2 * eps); 30 | target_qsa = sign * (tmp * tmp - 1); 31 | 32 | // nstep_return 33 | float sum_reward = 0; 34 | float factor = 1; 35 | for (int t = 0; t < time_step; ++t) { 36 | float rw = reward[t * batch_size + batch_id]; 37 | sum_reward += (factor * rw); 38 | factor *= gamma; 39 | } 40 | float done_ = done[batch_id]; 41 | target_qsa = sum_reward + factor * target_qsa * (1.f - done_); 42 | 43 | // value_transform 44 | sign = (target_qsa > 0.f ? 1.f : (target_qsa < 0.f ? 
-1.f : 0.f)); 45 | target_qsa = sign * (sqrt(abs(target_qsa) + 1) - 1) + eps * target_qsa; 46 | 47 | float diff = qsa - target_qsa; 48 | sum_square = diff * diff; 49 | td_err[batch_id] = sum_square; 50 | 51 | float w = weight[batch_id]; 52 | sum_square *= w; 53 | grad_buf[batch_id] = 1.f / batch_size * (2.f * diff) * w; 54 | } 55 | 56 | float reduced_sum_square = blockReduceSum(sum_square); 57 | if (threadIdx.x == 0) { 58 | float mean_loss = reduced_sum_square / batch_size; 59 | atomicAdd(loss, mean_loss); 60 | } 61 | } 62 | 63 | void __global__ qNStepTdRescaleBackwardKernel(unsigned int batch_size, unsigned int num_output, 64 | const float* grad_loss, const float* grad_buf, const int64_t* action, float* grad_q) { 65 | unsigned int gid = threadIdx.x + blockIdx.x*blockDim.x; // num_output 66 | 67 | if (gid < num_output) { 68 | unsigned int batch_id = blockIdx.y; 69 | float grad = (gid == action[batch_id]) ? grad_buf[batch_id] : 0; 70 | grad_q[batch_id * num_output + gid] = (*grad_loss) * grad; 71 | } 72 | } 73 | 74 | } // namespace cuda 75 | } // namespace rll 76 | } // namespace hpc 77 | #endif // HPC_RLL_CUDA_Q_NSTEP_TD_RESCALE_KERNEL_H_ 78 | -------------------------------------------------------------------------------- /include/hpc/rll/cuda/rl_utils/qrdqn_nstep_td_error_kernel.h: -------------------------------------------------------------------------------- 1 | #ifndef HPC_RLL_CUDA_QRDQN_NSTEP_TD_ERROR_KERNEL_H_ 2 | #define HPC_RLL_CUDA_QRDQN_NSTEP_TD_ERROR_KERNEL_H_ 3 | 4 | #include "hpc/rll/cuda/common.h" 5 | #include "hpc/rll/cuda/reduce.h" 6 | 7 | namespace hpc { 8 | namespace rll { 9 | namespace cuda { 10 | 11 | void __global__ bellmanErrorKernel(unsigned int tau, unsigned int time_step, unsigned int batch_size, unsigned int action_dim, float gamma, 12 | const float* q, const float* next_n_q, const int64_t* action, const int64_t* next_n_action, 13 | const float* reward, const float* done, const float* value_gamma, float* bellman_err_buf) { 14 | unsigned int gidx = threadIdx.x + blockIdx.x*blockDim.x; // tau 15 | unsigned int gidy = blockIdx.y; // batch_size 16 | if (gidx >= tau) return; 17 | 18 | int64_t action_data = action[gidy]; 19 | int64_t next_n_action_data = next_n_action[gidy]; 20 | float value_gamma_data = value_gamma[gidy]; 21 | float done_data = done[gidy]; 22 | float not_done_data = 1.f - done_data; 23 | 24 | // for i in range(1, nstep): 25 | // reward_factor[i] = gamma * reward_factor[i - 1] 26 | // reward = torch.matmul(reward_factor, reward) 27 | float reward_factor = 1; 28 | float reward_data = 0; 29 | for (int t = 0; t < time_step; t++) { 30 | reward_data += reward_factor * reward[t * batch_size + gidy]; 31 | reward_factor *= gamma; 32 | } 33 | 34 | float target_qsa = next_n_q[gidy * action_dim * tau + next_n_action_data * tau + gidx]; 35 | float target_qsa_transform = reward_data + value_gamma_data * target_qsa * not_done_data; 36 | for (int t = 0; t < tau; t++) { 37 | float qsa = q[gidy * action_dim * tau + action_data * tau + t]; 38 | bellman_err_buf[gidy * tau * tau + t * tau + gidx] = target_qsa_transform - qsa; 39 | } 40 | } 41 | 42 | void __global__ smoothL1LossKernel(unsigned int tau, unsigned int batch_size, 43 | const float* bellman_err_buf, float* quantile_huber_loss_buf, float* grad_buf) { 44 | unsigned int gidx = threadIdx.x + blockIdx.x*blockDim.x; // tau 45 | unsigned int gidy = threadIdx.y + blockIdx.y*blockDim.y; // tau 46 | unsigned int gidz = blockIdx.z; // batch_size 47 | if (gidx >= tau || gidy >= tau) return; 48 | 49 | float 
bellman_error_data = bellman_err_buf[gidz * tau * tau + gidy * tau + gidx]; 50 | float huber_loss_data = 0.f; 51 | float grad_buf_data = 0.f; 52 | if (abs(bellman_error_data) < 1) { 53 | huber_loss_data = 0.5f * bellman_error_data * bellman_error_data; 54 | grad_buf_data = bellman_error_data; 55 | } else { 56 | huber_loss_data = (abs(bellman_error_data) - 0.5f); 57 | grad_buf_data = (bellman_error_data >= 0) ? 1.f : (-1.f); 58 | } 59 | 60 | float tmp1 = tau - ((bellman_error_data <= 0) ? 1.f : 0.f); 61 | float tmp2 = huber_loss_data * tmp1; 62 | float quantile_huber_loss_data = abs(tmp2); 63 | grad_buf_data *= (tmp2 >= 0 ? tmp1 : -tmp1); 64 | 65 | quantile_huber_loss_buf[gidz * tau * tau + gidy * tau + gidx] = quantile_huber_loss_data; 66 | grad_buf[gidz * tau * tau + gidy * tau + gidx] = grad_buf_data; 67 | 68 | } 69 | 70 | void __global__ lossKernel(unsigned int tau, unsigned int batch_size, 71 | const float* quantile_huber_loss_buf, const float* weight, 72 | float* td_err, float* loss) { 73 | unsigned int block_start = blockIdx.x * tau * tau; 74 | unsigned int start = block_start + threadIdx.x; 75 | unsigned int end = block_start + tau * tau; 76 | 77 | float partial_sum_huber = 0.f; 78 | for (int i = start; i < end; i += blockDim.x) { 79 | partial_sum_huber += quantile_huber_loss_buf[i]; 80 | } 81 | float sum_huber = blockReduceSum(partial_sum_huber); 82 | float mean_huber = sum_huber / tau; 83 | 84 | if (threadIdx.x == 0) { 85 | td_err[blockIdx.x] = mean_huber; 86 | atomicAdd(loss, mean_huber * weight[blockIdx.x] / batch_size); 87 | } 88 | } 89 | 90 | void __global__ backwardKernel(unsigned int tau, unsigned int batch_size, unsigned int action_dim, 91 | const float* grad_loss, const float* grad_buf, 92 | const float* weight, const int64_t* action, float* grad_q) { 93 | unsigned int gidx = threadIdx.x + blockIdx.x*blockDim.x; // tau 94 | unsigned int gidy = blockIdx.y; // batch_size 95 | if (gidx >= tau) return; 96 | 97 | float grad_data = 0.f; 98 | for (int t = 0; t < tau; t++) { 99 | grad_data += grad_buf[gidy * tau * tau + gidx * tau + t]; 100 | } 101 | grad_data *= (grad_loss[0] / batch_size * weight[gidy] / tau); // mean, weight, mean 102 | grad_data *= -1.f; // target_qsa - qsa 103 | 104 | int output_index = gidy * action_dim * tau + action[gidy] * tau + gidx; 105 | grad_q[output_index] = grad_data; 106 | } 107 | 108 | } // namespace cuda 109 | } // namespace rll 110 | } // namespace hpc 111 | #endif // HPC_RLL_CUDA_QRDQN_NSTEP_TD_ERROR_KERNEL_H_ 112 | -------------------------------------------------------------------------------- /include/hpc/rll/cuda/rl_utils/td_lambda_kernel.h: -------------------------------------------------------------------------------- 1 | #ifndef HPC_RLL_CUDA_TD_LAMBDA_KERNEL_H_ 2 | #define HPC_RLL_CUDA_TD_LAMBDA_KERNEL_H_ 3 | 4 | #include "hpc/rll/cuda/common.h" 5 | #include "hpc/rll/cuda/reduce.h" 6 | 7 | namespace hpc { 8 | namespace rll { 9 | namespace cuda { 10 | 11 | void __global__ tdLambdaForwardKernel(unsigned int time_step, unsigned int batch_size, float gamma, float lambda, 12 | const float* value, const float* reward, const float* weight, float* loss, float* grad_buf) { 13 | unsigned int gid = threadIdx.x + blockIdx.x*blockDim.x; 14 | 15 | float sum_square = 0; 16 | if (gid < batch_size) { 17 | float rt = 0.f; 18 | for (int t = time_step - 1; t >= 0; --t) { 19 | unsigned int index = t * batch_size + gid; 20 | 21 | float value_data = value[index]; 22 | float next_value_data = value[index + batch_size]; 23 | float reward_data = reward[index]; 
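        // TD(lambda) return computed below (sketch):
        //   R_t = r_t + gamma * ((1 - lambda) * V_{t+1} + lambda * R_{t+1}), with R_{T-1} bootstrapping from V_T
        //   loss = 0.5 * mean over (t, b) of weight * (R_t - V_t)^2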
24 | float weight_data = weight[index]; 25 | 26 | float tmp = (t == time_step - 1) ? next_value_data : (lambda * rt + (1.f - lambda) * next_value_data); 27 | rt = reward_data + gamma * tmp; 28 | 29 | float loss = (rt - value_data); 30 | grad_buf[index] = weight_data * (2.f * loss * (-1.f)); 31 | sum_square += loss * loss * weight_data; 32 | } 33 | } 34 | 35 | float reduced_sum_square = blockReduceSum(sum_square); 36 | if (threadIdx.x == 0) { 37 | float mean_loss = 0.5 * reduced_sum_square / (time_step * batch_size); 38 | atomicAdd(loss, mean_loss); 39 | } 40 | } 41 | 42 | void __global__ tdLambdaBackwardKernel(unsigned int time_step, unsigned int batch_size, 43 | const float* grad_loss, const float* grad_buf, float* grad_value) { 44 | unsigned int gid = threadIdx.x + blockIdx.x*blockDim.x; 45 | 46 | if (gid < (time_step + 1) * batch_size) { 47 | float grad = *grad_loss; 48 | float grad_mean = 1.f / (time_step * batch_size); 49 | grad_value[gid] = (gid < time_step * batch_size) ? (grad * 0.5 * grad_mean * grad_buf[gid]) : 0.f; 50 | } 51 | } 52 | 53 | } // namespace cuda 54 | } // namespace rll 55 | } // namespace hpc 56 | #endif // HPC_RLL_CUDA_TD_LAMBDA_KERNEL_H_ 57 | -------------------------------------------------------------------------------- /include/hpc/rll/cuda/rl_utils/upgo_kernel.h: -------------------------------------------------------------------------------- 1 | #ifndef HPC_RLL_CUDA_UPGO_KERNEL_H_ 2 | #define HPC_RLL_CUDA_UPGO_KERNEL_H_ 3 | 4 | #include "hpc/rll/cuda/common.h" 5 | #include "hpc/rll/cuda/reduce.h" 6 | 7 | namespace hpc { 8 | namespace rll { 9 | namespace cuda { 10 | 11 | void __global__ upgoAdvantageKernel(unsigned int time_step, unsigned int batch_size, 12 | const float* rho, const float* reward, const float* value, float* advantage) { 13 | unsigned int gid = threadIdx.x + blockIdx.x*blockDim.x; 14 | 15 | if (gid < batch_size) { 16 | float item = 0; 17 | for (int t = time_step - 1; t >= 0; --t) { 18 | unsigned int index = t * batch_size + gid; 19 | 20 | float rho_data = rho[index]; 21 | 22 | float reward_data0 = reward[index]; 23 | // Note: when t == time_step - 1, reward_data1 is not used. Just avoid accessing out of memory bound. 24 | float reward_data1 = (t == time_step - 1) ? 0.f : reward[index + batch_size]; 25 | 26 | float value0 = value[index]; 27 | float value1 = value[index + batch_size]; 28 | // Note: when t == time_step - 1, value2 is not used. Just avoid accessing out of memory bound. 29 | float value2 = (t == time_step - 1) ? 0.f : value[index + batch_size * 2]; 30 | 31 | float value_data = ((t < time_step - 1) && (reward_data1 + value2 >= value1)) ? 
item : value1; 32 | 33 | float rt = reward_data0 + value_data; 34 | advantage[index] = (rt - value0) * rho_data; 35 | item = rt; 36 | } 37 | } 38 | } 39 | 40 | void __global__ crossEntropyKernel(unsigned int num, 41 | const float* input, const int64_t* target, float* output, float* grad) { 42 | unsigned int block_start = blockIdx.x * num; 43 | unsigned int start = block_start + threadIdx.x; 44 | unsigned int end = block_start + num; 45 | 46 | // step 1 get max_x 47 | float max_x = CUDA_FLOAT_INF_NEG; 48 | for (int i = start; i < end; i += blockDim.x) { 49 | max_x = max(max_x, input[i]); 50 | } 51 | static __shared__ float s_max_x; 52 | float reduce_max_x = blockReduceMax(max_x); 53 | if (threadIdx.x == 0) { 54 | s_max_x = reduce_max_x; 55 | } 56 | __syncthreads(); 57 | 58 | // step 2 compute sum(exp(x - max_x)) 59 | static __shared__ float s_sum_exp_x; 60 | float sum_exp_x = 0.0; 61 | for (int i = start; i < end; i += blockDim.x) { 62 | sum_exp_x += std::exp(input[i] - s_max_x); 63 | } 64 | float reduce_sum_exp_x = blockReduceSum(sum_exp_x); 65 | if (threadIdx.x == 0) { 66 | s_sum_exp_x = reduce_sum_exp_x; 67 | } 68 | __syncthreads(); 69 | 70 | // step 2 compute cross entropy and grad 71 | for (int i = start; i < end; i += blockDim.x) { 72 | bool flag = (i - block_start == target[blockIdx.x]); 73 | 74 | float softmax_data = std::exp(input[i] - s_max_x) / s_sum_exp_x; 75 | 76 | if (flag) 77 | output[blockIdx.x] = std::log(softmax_data); 78 | 79 | grad[i] = flag ? (1 - softmax_data) : (-softmax_data); 80 | } 81 | } 82 | 83 | void __global__ upgoLossKernel(unsigned int time_step, unsigned int batch_size, 84 | const float* advantage, const float* metric, float* loss) { 85 | unsigned int gid = threadIdx.x + blockIdx.x*blockDim.x; 86 | 87 | float sum_loss = (gid < time_step * batch_size) ? 
(advantage[gid] * metric[gid]) : 0.f; 88 | float reduced_sum_loss = blockReduceSum(sum_loss); 89 | 90 | if (threadIdx.x == 0) { 91 | float mean_loss = reduced_sum_loss / (time_step * batch_size); 92 | atomicAdd(loss, -mean_loss); 93 | } 94 | } 95 | 96 | void __global__ upgoBackwardKernel(unsigned int time_step, unsigned int batch_size, unsigned int num_output, 97 | const float* grad_loss, const float* grad_buf, 98 | const float* advantages, float* grad_target_output) { 99 | unsigned int gid = threadIdx.x + blockIdx.x*blockDim.x; 100 | 101 | if (gid < time_step * batch_size * num_output) { 102 | unsigned int tb_id = gid / num_output; 103 | 104 | float grad = (*grad_loss); 105 | float grad_mean = 1.f / (time_step * batch_size); 106 | grad_target_output[gid] = grad * (-1.f) * grad_mean * advantages[tb_id] * grad_buf[gid]; 107 | } 108 | } 109 | 110 | } // namespace cuda 111 | } // namespace rll 112 | } // namespace hpc 113 | #endif // HPC_RLL_CUDA_UPGO_KERNEL_H_ 114 | -------------------------------------------------------------------------------- /include/hpc/rll/cuda/status.h: -------------------------------------------------------------------------------- 1 | #ifndef HPC_RLL_CUDA_STATUS_H_ 2 | #define HPC_RLL_CUDA_STATUS_H_ 3 | 4 | #include 5 | #include 6 | #include 7 | #include 8 | 9 | #include 10 | 11 | namespace hpc { 12 | namespace rll { 13 | namespace cuda { 14 | 15 | #define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor") 16 | #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") 17 | #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x) 18 | 19 | static int checkCudaError(cudaError_t code, const char* expr, const char* file, int line, bool abort = true) { 20 | if (code) { 21 | fprintf(stderr, "CUDA error at %s:%d, code=%d (%s) in '%s'", file, line, (int) code, cudaGetErrorString(code), expr); 22 | if (abort) 23 | throw std::logic_error("CUDA Error."); 24 | } 25 | return 0; 26 | } 27 | 28 | #define checkCudaErr(...) do { int err = checkCudaError(__VA_ARGS__, #__VA_ARGS__, __FILE__, __LINE__); } while (0) 29 | 30 | static const char* cublasGetErrorString(cublasStatus_t status) 31 | { 32 | switch(status) 33 | { 34 | case CUBLAS_STATUS_SUCCESS: return "CUBLAS_STATUS_SUCCESS"; 35 | case CUBLAS_STATUS_NOT_INITIALIZED: return "CUBLAS_STATUS_NOT_INITIALIZED"; 36 | case CUBLAS_STATUS_ALLOC_FAILED: return "CUBLAS_STATUS_ALLOC_FAILED"; 37 | case CUBLAS_STATUS_INVALID_VALUE: return "CUBLAS_STATUS_INVALID_VALUE"; 38 | case CUBLAS_STATUS_ARCH_MISMATCH: return "CUBLAS_STATUS_ARCH_MISMATCH"; 39 | case CUBLAS_STATUS_MAPPING_ERROR: return "CUBLAS_STATUS_MAPPING_ERROR"; 40 | case CUBLAS_STATUS_EXECUTION_FAILED: return "CUBLAS_STATUS_EXECUTION_FAILED"; 41 | case CUBLAS_STATUS_INTERNAL_ERROR: return "CUBLAS_STATUS_INTERNAL_ERROR"; 42 | } 43 | return "unknown error"; 44 | } 45 | 46 | static int checkCublasError(cublasStatus_t code, const char* expr, const char* file, int line, bool abort = true) { 47 | if (code) { 48 | fprintf(stderr, "CUBLAS error at %s:%d, code=%d (%s) in '%s'", file, line, (int) code, cublasGetErrorString(code), expr); 49 | if (abort) 50 | throw std::logic_error("CUBLAS Error."); 51 | } 52 | return 0; 53 | } 54 | 55 | #define checkCublasErr(...) 
do { int err = checkCublasError(__VA_ARGS__, #__VA_ARGS__, __FILE__, __LINE__); } while (0) 56 | 57 | static const char* curandGetErrorString(curandStatus_t status) 58 | { 59 | switch (status) 60 | { 61 | case CURAND_STATUS_SUCCESS: return "CURAND_STATUS_SUCCESS"; 62 | case CURAND_STATUS_VERSION_MISMATCH: return "CURAND_STATUS_VERSION_MISMATCH"; 63 | case CURAND_STATUS_NOT_INITIALIZED: return "CURAND_STATUS_NOT_INITIALIZED"; 64 | case CURAND_STATUS_ALLOCATION_FAILED: return "CURAND_STATUS_ALLOCATION_FAILED"; 65 | case CURAND_STATUS_TYPE_ERROR: return "CURAND_STATUS_TYPE_ERROR"; 66 | case CURAND_STATUS_OUT_OF_RANGE: return "CURAND_STATUS_OUT_OF_RANGE"; 67 | case CURAND_STATUS_LENGTH_NOT_MULTIPLE: return "CURAND_STATUS_LENGTH_NOT_MULTIPLE"; 68 | case CURAND_STATUS_DOUBLE_PRECISION_REQUIRED: return "CURAND_STATUS_DOUBLE_PRECISION_REQUIRED"; 69 | case CURAND_STATUS_LAUNCH_FAILURE: return "CURAND_STATUS_LAUNCH_FAILURE"; 70 | case CURAND_STATUS_PREEXISTING_FAILURE: return "CURAND_STATUS_PREEXISTING_FAILURE"; 71 | case CURAND_STATUS_INITIALIZATION_FAILED: return "CURAND_STATUS_INITIALIZATION_FAILED"; 72 | case CURAND_STATUS_ARCH_MISMATCH: return "CURAND_STATUS_ARCH_MISMATCH"; 73 | case CURAND_STATUS_INTERNAL_ERROR: return "CURAND_STATUS_INTERNAL_ERROR"; 74 | } 75 | return "unknown error"; 76 | } 77 | 78 | static int checkCurandError(curandStatus_t code, const char* expr, const char* file, int line, bool abort = true) { 79 | if (code) { 80 | fprintf(stderr, "CURAND error at %s:%d, code=%d (%s) in '%s'", file, line, (int) code, curandGetErrorString(code), expr); 81 | if (abort) 82 | throw std::logic_error("CURAND Error."); 83 | } 84 | return 0; 85 | } 86 | 87 | #define checkCurandErr(...) do { int err = checkCurandError(__VA_ARGS__, #__VA_ARGS__, __FILE__, __LINE__); } while (0) 88 | 89 | } // namespace cuda 90 | } // namespace rll 91 | } // namespace hpc 92 | 93 | #endif // HPC_RLL_CUDA_STATUS_H_ 94 | -------------------------------------------------------------------------------- /include/hpc/rll/cuda/torch_utils/network/entry.h: -------------------------------------------------------------------------------- 1 | #ifndef HPC_RLL_CUDA_NETWORK_H_ 2 | #define HPC_RLL_CUDA_NETWORK_H_ 3 | 4 | #include "hpc/rll/cuda/common.h" 5 | 6 | namespace hpc { 7 | namespace rll { 8 | namespace cuda { 9 | 10 | // lstm 11 | void LstmForward( 12 | const std::vector& inputs, 13 | std::vector& outputs, 14 | float dropout_threshold); 15 | 16 | void LstmBackward( 17 | const std::vector& inputs, 18 | std::vector& outputs, 19 | float dropout_threshold); 20 | 21 | // scatter_connection 22 | void ScatterConnectionForward( 23 | const std::vector& inputs, 24 | std::vector& outputs, 25 | const char* scatter_type); 26 | 27 | void ScatterConnectionBackward( 28 | const std::vector& inputs, 29 | std::vector& outputs); 30 | 31 | } // namespace cuda 32 | } // namespace rll 33 | } // namespace hpc 34 | 35 | #endif // HPC_RLL_CUDA_NETWORK_H_ 36 | -------------------------------------------------------------------------------- /include/hpc/rll/cuda/torch_utils/network/scatter_connection_kernel.h: -------------------------------------------------------------------------------- 1 | #ifndef HPC_RLL_CUDA_SCATTER_CONNECTION_KERNEL_H_ 2 | #define HPC_RLL_CUDA_SCATTER_CONNECTION_KERNEL_H_ 3 | 4 | #include "hpc/rll/cuda/common.h" 5 | #include "hpc/rll/cuda/reduce.h" 6 | 7 | namespace hpc { 8 | namespace rll { 9 | namespace cuda { 10 | 11 | // deprecated 12 | // single block deal with one out element 13 | // single warp in block 14 | // 
consider input data overlap, and keep sequence 15 | void __global__ scatterConnectionCoverKeepSeqForwardKernel(unsigned int B, unsigned int M, unsigned int N, unsigned int H, unsigned int W, 16 | const float* in, const int64_t* location, float* out) { 17 | unsigned int lid = threadIdx.x; // each out element use 1 block 18 | unsigned int out_id = blockIdx.x; // B * N * H * W 19 | unsigned int bid = out_id / (N * H * W); 20 | unsigned int nid = (out_id % (N * H * W)) / (H * W); 21 | unsigned int hwid = out_id % (H * W); 22 | unsigned int bhw_id = bid * (H * W) + hwid; 23 | 24 | int in_id = -1; 25 | int max_id = -1; 26 | for (int b = 0; b < B; b++) { 27 | for (int m = lid; m < M; m += WARP_SIZE) { 28 | unsigned int y = location[b * M * 2 + m * 2 + 0]; 29 | unsigned int x = location[b * M * 2 + m * 2 + 1]; 30 | unsigned int target_id = b * H * W + y * W + x; 31 | 32 | if (bhw_id == target_id) { 33 | max_id = b * M + m; 34 | in_id = b * M * N + m * N + nid; 35 | } 36 | } 37 | } 38 | 39 | // reduce max idx 40 | float max_id_fp = max_id; 41 | int reduced_max_id = __float2int_rn(blockReduceMax(max_id_fp)); 42 | static __shared__ int s_max_id; 43 | if (lid == 0) { 44 | s_max_id = reduced_max_id; 45 | } 46 | __syncthreads(); 47 | 48 | if (s_max_id == -1) { 49 | if (lid == 0) 50 | out[out_id] = 0; 51 | } else { 52 | if (((s_max_id % M) % WARP_SIZE) == lid) { 53 | out[out_id] = in[in_id]; 54 | } 55 | } 56 | } 57 | 58 | // assuming no overlap of input data 59 | void __global__ scatterConnectionCoverForwardKernel(unsigned int B, unsigned int M, unsigned int N, unsigned int H, unsigned int W, 60 | const float* in, const int64_t* location, float* out) { 61 | unsigned int nid = threadIdx.x + blockIdx.x*blockDim.x; // N 62 | unsigned int mid = threadIdx.y + blockIdx.y*blockDim.y; // M 63 | unsigned int bid = blockIdx.z; // B 64 | if (nid >= N || mid >= M || bid >= B) return; 65 | 66 | unsigned int in_id = bid * M * N + mid * N + nid; 67 | 68 | unsigned int yid = location[bid * M * 2 + mid * 2 + 0]; 69 | unsigned int xid = location[bid * M * 2 + mid * 2 + 1]; 70 | 71 | unsigned int out_id = bid * N * H * W + nid * H * W + yid * W + xid; 72 | out[out_id] = in[in_id]; 73 | } 74 | 75 | void __global__ scatterConnectionAddForwardKernel(unsigned int B, unsigned int M, unsigned int N, unsigned int H, unsigned int W, 76 | const float* in, const int64_t* location, float* out) { 77 | unsigned int nid = threadIdx.x + blockIdx.x*blockDim.x; // N 78 | unsigned int mid = threadIdx.y + blockIdx.y*blockDim.y; // M 79 | unsigned int bid = blockIdx.z; // B 80 | if (nid >= N || mid >= M || bid >= B) return; 81 | 82 | unsigned int in_id = bid * M * N + mid * N + nid; 83 | 84 | unsigned int yid = location[bid * M * 2 + mid * 2 + 0]; 85 | unsigned int xid = location[bid * M * 2 + mid * 2 + 1]; 86 | 87 | unsigned int out_id = bid * N * H * W + nid * H * W + yid * W + xid; 88 | atomicAdd(&out[out_id], in[in_id]); 89 | } 90 | 91 | void __global__ scatterConnectionBackwardKernel(unsigned int B, unsigned int M, unsigned int N, unsigned int H, unsigned int W, 92 | const float* grad_out, const int64_t* location, float* grad_in) { 93 | unsigned int nid = threadIdx.x + blockIdx.x*blockDim.x; // N 94 | unsigned int mid = threadIdx.y + blockIdx.y*blockDim.y; // M 95 | unsigned int bid = blockIdx.z; // B 96 | if (nid >= N || mid >= M || bid >= B) return; 97 | 98 | unsigned int in_id = bid * M * N + mid * N + nid; 99 | 100 | unsigned int yid = location[bid * M * 2 + mid * 2 + 0]; 101 | unsigned int xid = location[bid * M * 2 + mid * 2 + 1]; 
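    // Backward (sketch): each input element reads back the gradient of the output cell it was scattered to,
    //   grad_in[b, m, n] = grad_out[b, n, y, x] for the entity's location (y, x)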
102 | 103 | unsigned int out_id = bid * N * H * W + nid * H * W + yid * W + xid; 104 | 105 | grad_in[in_id] = grad_out[out_id]; 106 | } 107 | 108 | } // namespace cuda 109 | } // namespace rll 110 | } // namespace hpc 111 | #endif // HPC_RLL_CUDA_SCATTER_CONNECTION_KERNEL_H_ 112 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | import setuptools 2 | from setuptools import setup 3 | import os 4 | import glob 5 | import torch 6 | import warnings 7 | from torch.utils.cpp_extension import BuildExtension, CUDAExtension 8 | 9 | NAME = 'di_hpc_rll' 10 | VERSION = '0.0.2' 11 | DESC = 'GPU-Accelerated library for reinforcement learning' 12 | PLATFORMS = 'linux-x86_64' 13 | PACKAGES = ['hpc_rll', 'hpc_rll.origin', 'hpc_rll.rl_utils', 'hpc_rll.torch_utils', 'hpc_rll.torch_utils.network'] 14 | include_dirs = [os.path.join(os.getcwd(), 'include')] 15 | print('include_dirs', include_dirs) 16 | 17 | ext_modules = [] 18 | ext_modules.append( 19 | CUDAExtension('hpc_rl_utils', sources=[ 20 | 'src/rl_utils/entry.cpp', 21 | 'src/rl_utils/dist_nstep_td.cu', 22 | 'src/rl_utils/gae.cu', 23 | 'src/rl_utils/padding.cu', 24 | 'src/rl_utils/ppo.cu', 25 | 'src/rl_utils/q_nstep_td.cu', 26 | 'src/rl_utils/q_nstep_td_rescale.cu', 27 | 'src/rl_utils/td_lambda.cu', 28 | 'src/rl_utils/upgo.cu', 29 | 'src/rl_utils/vtrace.cu', 30 | 'src/rl_utils/iqn_nstep_td_error.cu', 31 | 'src/rl_utils/qrdqn_nstep_td_error.cu', 32 | 'src/models/actor_critic.cu', 33 | ], include_dirs=include_dirs) 34 | ) 35 | ext_modules.append( 36 | CUDAExtension('hpc_torch_utils_network', sources=[ 37 | 'src/torch_utils/network/entry.cpp', 38 | 'src/torch_utils/network/lstm.cu', 39 | 'src/torch_utils/network/scatter_connection.cu' 40 | ], include_dirs=include_dirs), 41 | ) 42 | 43 | if int("".join(list(filter(str.isdigit, torch.__version__)))) >= 120: 44 | ext_modules.append( 45 | CUDAExtension('hpc_models', sources=[ 46 | 'src/models/entry.cpp', 47 | 'src/models/actor_critic.cu', 48 | ], include_dirs=include_dirs), 49 | ) 50 | else: 51 | warnings.warn("Torch version is less than 1.2. BoolTensor is not yet well implemented. 
Thus we skip the compiliation of hpc_models.") 52 | 53 | setup( 54 | name = NAME, 55 | version = VERSION, 56 | description = DESC, 57 | platforms = PLATFORMS, 58 | packages = PACKAGES, 59 | ext_modules=ext_modules, 60 | cmdclass={ 61 | 'build_ext': BuildExtension 62 | } 63 | ) 64 | -------------------------------------------------------------------------------- /src/models/actor_critic.cu: -------------------------------------------------------------------------------- 1 | #include "hpc/rll/cuda/models/entry.h" 2 | #include "hpc/rll/cuda/models/actor_critic_kernel.h" 3 | 4 | namespace hpc { 5 | namespace rll { 6 | namespace cuda { 7 | 8 | void actor_critic_update_ae( 9 | const std::vector& inputs, 10 | std::vector& outputs) { 11 | unsigned index = 0; 12 | const torch::Tensor& key_embeddings = inputs[index++]; 13 | const torch::Tensor& sample_entity = inputs[index++]; 14 | const torch::Tensor& entity_num = inputs[index++]; 15 | index = 0; 16 | torch::Tensor& autoregressive_embedding = outputs[index++]; 17 | 18 | int64_t batch_size = key_embeddings.size(0); 19 | int64_t max_entity_num = key_embeddings.size(1); 20 | int64_t input_dim = key_embeddings.size(2); 21 | { 22 | dim3 block_size = {DEFAULT_WARP_NUM * WARP_SIZE, 1, 1}; 23 | unsigned int grid_size_x = (input_dim + block_size.x - 1) / block_size.x; 24 | unsigned int grid_size_y = batch_size; 25 | dim3 grid_size = {grid_size_x, grid_size_y, 1}; 26 | autoregressive_embedding_fp<<>>(batch_size, max_entity_num, input_dim, 27 | (int64_t*)(sample_entity.data_ptr()), (int64_t*)(entity_num.data_ptr()), 28 | (float*)(key_embeddings.data_ptr()), (float*)(autoregressive_embedding.data_ptr())); 29 | } 30 | } 31 | 32 | void actor_critic_lstm_activation( 33 | const std::vector& inputs, 34 | std::vector& outputs) { 35 | unsigned index = 0; 36 | 37 | const torch::Tensor& lstm_ih = inputs[index++]; 38 | const torch::Tensor& lstm_hh = inputs[index++]; 39 | const torch::Tensor& lstm_bias = inputs[index++]; 40 | index = 0; 41 | torch::Tensor& lstm_hx = outputs[index++]; 42 | torch::Tensor& lstm_cx = outputs[index++]; 43 | 44 | int64_t batch_size = lstm_ih.size(0); 45 | int64_t hidden_size = lstm_ih.size(1) / 4; 46 | { 47 | dim3 block_size = {DEFAULT_WARP_NUM * WARP_SIZE, 1, 1}; 48 | unsigned int grid_size_x = (hidden_size + block_size.x - 1) / block_size.x; 49 | unsigned int grid_size_y = batch_size; 50 | dim3 grid_size = {grid_size_x, grid_size_y, 1}; 51 | lstm_activation_fp<<>>(batch_size, hidden_size, 52 | (float*)(lstm_ih.data_ptr()), (float*)(lstm_hh.data_ptr()), (float*)(lstm_bias.data_ptr()), 53 | (float*)(lstm_hx.data_ptr()), (float*)(lstm_cx.data_ptr())); 54 | } 55 | } 56 | 57 | void actor_critic_pre_sample( 58 | const std::vector& inputs, 59 | std::vector& outputs) { 60 | unsigned index = 0; 61 | 62 | const torch::Tensor& mat = inputs[index++]; 63 | const torch::Tensor& vec = inputs[index++]; 64 | const torch::Tensor& mask = inputs[index++]; 65 | index = 0; 66 | torch::Tensor& output = outputs[index++]; 67 | 68 | int64_t batch_size = mat.size(0); 69 | int64_t max_entity_num = mat.size(1); 70 | int64_t hidden_size = mat.size(2); 71 | { 72 | dim3 block_size = {WARP_SIZE, 1, 1}; 73 | unsigned int grid_size_x = 1; 74 | unsigned int grid_size_y = max_entity_num; 75 | unsigned int grid_size_z = batch_size; 76 | dim3 grid_size = {grid_size_x, grid_size_y, grid_size_z}; 77 | const float mask_value = -1e9; 78 | const float div_factor = 0.8; 79 | pre_sample_fp<<>>(batch_size, max_entity_num, hidden_size, mask_value, div_factor, 80 | 
(float*)(mat.data_ptr()), (float*)(vec.data_ptr()), (bool*)(mask.data_ptr()), 81 | (float*)(output.data_ptr())); 82 | } 83 | } 84 | 85 | } // namespace cuda 86 | } // namespace rll 87 | } // namespace hpc 88 | -------------------------------------------------------------------------------- /src/models/entry.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include "hpc/rll/cuda/models/entry.h" 3 | 4 | namespace hpc { 5 | namespace rll { 6 | namespace cuda { 7 | 8 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 9 | m.def("actor_critic_update_ae", &actor_critic_update_ae, "actor critic model update autoregressive embedding (CUDA)"); 10 | m.def("actor_critic_lstm_activation", &actor_critic_lstm_activation, "actor critic model lstm activation (CUDA)"); 11 | m.def("actor_critic_pre_sample", &actor_critic_pre_sample, "actor critic model pre sample (CUDA)"); 12 | } 13 | 14 | } // namespace cuda 15 | } // namespace rll 16 | } // namespace hpc 17 | -------------------------------------------------------------------------------- /src/rl_utils/dist_nstep_td.cu: -------------------------------------------------------------------------------- 1 | #include "hpc/rll/cuda/rl_utils/entry.h" 2 | #include "hpc/rll/cuda/rl_utils/dist_nstep_td_kernel.h" 3 | 4 | namespace hpc { 5 | namespace rll { 6 | namespace cuda { 7 | 8 | void DistNStepTdForward( 9 | const std::vector& inputs, 10 | std::vector& outputs, 11 | float gamma, 12 | float v_min, 13 | float v_max) { 14 | 15 | unsigned int index = 0; 16 | const torch::Tensor& dist = inputs[index++]; 17 | const torch::Tensor& next_n_dist = inputs[index++]; 18 | const torch::Tensor& action = inputs[index++]; 19 | const torch::Tensor& next_n_action = inputs[index++]; 20 | const torch::Tensor& reward = inputs[index++]; 21 | const torch::Tensor& done = inputs[index++]; 22 | const torch::Tensor& weight = inputs[index++]; 23 | index = 0; 24 | torch::Tensor& td_err = outputs[index++]; 25 | torch::Tensor& loss = outputs[index++]; 26 | torch::Tensor& buf = outputs[index++]; 27 | 28 | // set zero for atomic add 29 | checkCudaErr(cudaMemsetAsync(loss.data_ptr(), 0, sizeof(float) * loss.numel())); 30 | checkCudaErr(cudaMemsetAsync(buf.data_ptr(), 0, sizeof(float) * buf.numel())); 31 | 32 | const unsigned int time_step = reward.size(0); 33 | const unsigned int batch_size = dist.size(0); 34 | const unsigned int action_dim = dist.size(1); 35 | const unsigned int n_atom = dist.size(2); 36 | 37 | // buf0: B for reward x fp reward_factor 38 | // buf1: (B * n_atom) for fp proj_dist and bp grad 39 | float* buf0 = (float*)(buf.data_ptr()); 40 | float* buf1 = (float*)(buf.data_ptr()) + batch_size; 41 | 42 | { 43 | unsigned int block_size = DEFAULT_WARP_NUM * WARP_SIZE; 44 | unsigned int grid_size = (batch_size + block_size - 1) / block_size; 45 | distNStepTdRewardKernel<<>>( 46 | time_step, batch_size, gamma, (float*)(reward.data_ptr()), buf0); 47 | } 48 | 49 | { 50 | float gamma_nstep = 1.f; 51 | for (int t = 0; t < time_step; t++) 52 | gamma_nstep *= gamma; 53 | float delta = (v_max - v_min) / (n_atom - 1); 54 | 55 | dim3 block_size = {DEFAULT_WARP_NUM * WARP_SIZE, 1, 1}; 56 | dim3 grid_size = {(n_atom + block_size.x - 1) / block_size.x, batch_size, 1}; 57 | distNStepTdProjKernel<<>>( 58 | batch_size, action_dim, n_atom, gamma_nstep, v_min, v_max, delta, 59 | (float*)(next_n_dist.data_ptr()), (int64_t*)(next_n_action.data_ptr()), 60 | buf0, (float*)(done.data_ptr()), buf1); 61 | } 62 | 63 | { 64 | dim3 block_size = {DEFAULT_WARP_NUM * 
WARP_SIZE, 1, 1}; 65 | dim3 grid_size = {(n_atom + block_size.x - 1) / block_size.x, batch_size, 1}; 66 | distNStepTdLossKernel<<<grid_size, block_size>>>( 67 | batch_size, action_dim, n_atom, (float*)(dist.data_ptr()), (int64_t*)(action.data_ptr()), 68 | (const float*)buf1, (float*)(weight.data_ptr()), 69 | (float*)(td_err.data_ptr()), (float*)(loss.data_ptr()), buf1); 70 | } 71 | } 72 | 73 | void DistNStepTdBackward( 74 | const std::vector<torch::Tensor>& inputs, 75 | std::vector<torch::Tensor>& outputs) { 76 | 77 | unsigned int index = 0; 78 | const torch::Tensor& grad_loss = inputs[index++]; 79 | const torch::Tensor& buf = inputs[index++]; 80 | const torch::Tensor& action = inputs[index++]; 81 | index = 0; 82 | torch::Tensor& grad_dist = outputs[index++]; 83 | 84 | const unsigned int batch_size = grad_dist.size(0); 85 | const unsigned int action_dim = grad_dist.size(1); 86 | const unsigned int n_atom = grad_dist.size(2); 87 | 88 | // buf0: B for reward x fp reward_factor 89 | // buf1: (B * n_atom) for fp proj_dist and bp grad, here used for bp grad 90 | float* grad_buf = (float*)(buf.data_ptr()) + batch_size; 91 | 92 | unsigned int block_size = DEFAULT_WARP_NUM * WARP_SIZE; 93 | unsigned int grid_size = (batch_size * action_dim * n_atom + block_size - 1) / block_size; 94 | distNStepTdBackwardKernel<<<grid_size, block_size>>>( 95 | batch_size, action_dim, n_atom, (float*)(grad_loss.data_ptr()), grad_buf, 96 | (int64_t*)(action.data_ptr()), (float*)(grad_dist.data_ptr())); 97 | } 98 | 99 | } // namespace cuda 100 | } // namespace rll 101 | } // namespace hpc 102 | 103 | -------------------------------------------------------------------------------- /src/rl_utils/entry.cpp: -------------------------------------------------------------------------------- 1 | #include <torch/extension.h> 2 | #include "hpc/rll/cuda/rl_utils/entry.h" 3 | 4 | namespace hpc { 5 | namespace rll { 6 | namespace cuda { 7 | 8 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 9 | m.def("sample_split_group", &sample_split_group, "sample_split_group"); 10 | m.def("oracle_split_group", &oracle_split_group, "oracle_split_group"); 11 | m.def("Pad1DForward", &Pad1DForward, "Pad1D forward (CUDA)"); 12 | m.def("GroupPad1DForward", &GroupPad1DForward, "GroupPad1D forward (CUDA)"); 13 | m.def("Unpad1DForward", &Unpad1DForward, "Unpad1D forward (CUDA)"); 14 | m.def("Pad2DForward", &Pad2DForward, "Pad2D forward (CUDA)"); 15 | m.def("GroupPad2DForward", &GroupPad2DForward, "GroupPad2D forward (CUDA)"); 16 | m.def("Unpad2DForward", &Unpad2DForward, "Unpad2D forward (CUDA)"); 17 | m.def("Pad3DForward", &Pad3DForward, "Pad3D forward (CUDA)"); 18 | m.def("GroupPad3DForward", &GroupPad3DForward, "GroupPad3D forward (CUDA)"); 19 | m.def("Unpad3DForward", &Unpad3DForward, "Unpad3D forward (CUDA)"); 20 | m.def("DistNStepTdForward", &DistNStepTdForward, "dist_nstep_td forward (CUDA)"); 21 | m.def("DistNStepTdBackward", &DistNStepTdBackward, "dist_nstep_td backward (CUDA)"); 22 | m.def("GaeForward", &GaeForward, "gae forward (CUDA)"); 23 | m.def("PPOForward", &PPOForward, "ppo forward (CUDA)"); 24 | m.def("PPOBackward", &PPOBackward, "ppo backward (CUDA)"); 25 | m.def("QNStepTdForward", &QNStepTdForward, "q_nstep_td forward (CUDA)"); 26 | m.def("QNStepTdBackward", &QNStepTdBackward, "q_nstep_td backward (CUDA)"); 27 | m.def("QNStepTdRescaleForward", &QNStepTdRescaleForward, "q_nstep_td_with_rescale forward (CUDA)"); 28 | m.def("QNStepTdRescaleBackward", &QNStepTdRescaleBackward, "q_nstep_td_with_rescale backward (CUDA)"); 29 | m.def("TdLambdaForward", &TdLambdaForward, "td_lambda forward (CUDA)"); 30 | m.def("TdLambdaBackward", &TdLambdaBackward,
"td_lambda backward (CUDA)"); 31 | m.def("UpgoForward", &UpgoForward, "upgo forward (CUDA)"); 32 | m.def("UpgoBackward", &UpgoBackward, "upgo backward (CUDA)"); 33 | m.def("VTraceForward", &VTraceForward, "vtrace forward (CUDA)"); 34 | m.def("VTraceBackward", &VTraceBackward, "vtrace backward (CUDA)"); 35 | m.def("IQNNStepTDErrorForward", &IQNNStepTDErrorForward, "iqn_nstep_td_error forward (CUDA)"); 36 | m.def("IQNNStepTDErrorBackward", &IQNNStepTDErrorBackward, "iqn_nstep_td_error backward (CUDA)"); 37 | m.def("QRDQNNStepTDErrorForward", &QRDQNNStepTDErrorForward, "qrdqn_nstep_td_error forward (CUDA)"); 38 | m.def("QRDQNNStepTDErrorBackward", &QRDQNNStepTDErrorBackward, "qrdqn_nstep_td_error backward (CUDA)"); 39 | } 40 | 41 | } // namespace cuda 42 | } // namespace rll 43 | } // namespace hpc 44 | 45 | -------------------------------------------------------------------------------- /src/rl_utils/gae.cu: -------------------------------------------------------------------------------- 1 | #include "hpc/rll/cuda/rl_utils/entry.h" 2 | #include "hpc/rll/cuda/rl_utils/gae_kernel.h" 3 | 4 | namespace hpc { 5 | namespace rll { 6 | namespace cuda { 7 | 8 | void GaeForward( 9 | const std::vector& inputs, 10 | std::vector& outputs, 11 | float gamma, 12 | float lambda) { 13 | 14 | unsigned int index = 0; 15 | const torch::Tensor& value = inputs[index++]; 16 | const torch::Tensor& reward = inputs[index++]; 17 | index = 0; 18 | torch::Tensor& adv = outputs[index++]; 19 | 20 | const unsigned int time_step = reward.size(0); 21 | const unsigned int batch_size = reward.size(1); 22 | 23 | unsigned int block_size = 1 * WARP_SIZE; // single warp to utilize more blocks 24 | unsigned int grid_size = (batch_size + block_size - 1) / block_size; 25 | gaeForwardKernel<<>>( 26 | time_step, batch_size, gamma, lambda, 27 | (float*)(value.data_ptr()), (float*)(reward.data_ptr()), (float*)(adv.data_ptr())); 28 | } 29 | 30 | } // namespace cuda 31 | } // namespace rll 32 | } // namespace hpc 33 | -------------------------------------------------------------------------------- /src/rl_utils/iqn_nstep_td_error.cu: -------------------------------------------------------------------------------- 1 | #include "hpc/rll/cuda/rl_utils/entry.h" 2 | #include "hpc/rll/cuda/rl_utils/iqn_nstep_td_error_kernel.h" 3 | 4 | namespace hpc { 5 | namespace rll { 6 | namespace cuda { 7 | 8 | void IQNNStepTDErrorForward( 9 | const std::vector& inputs, 10 | std::vector& outputs, 11 | float gamma, 12 | float kappa) { 13 | 14 | unsigned int index = 0; 15 | const torch::Tensor& q = inputs[index++]; 16 | const torch::Tensor& next_n_q = inputs[index++]; 17 | const torch::Tensor& action = inputs[index++]; 18 | const torch::Tensor& next_n_action = inputs[index++]; 19 | const torch::Tensor& reward = inputs[index++]; 20 | const torch::Tensor& done = inputs[index++]; 21 | const torch::Tensor& replay_quantiles = inputs[index++]; 22 | const torch::Tensor& weight = inputs[index++]; 23 | const torch::Tensor& value_gamma = inputs[index++]; 24 | index = 0; 25 | torch::Tensor& loss = outputs[index++]; 26 | torch::Tensor& td_err = outputs[index++]; 27 | torch::Tensor& bellman_err_buf = outputs[index++]; 28 | torch::Tensor& quantile_huber_loss_buf = outputs[index++]; 29 | torch::Tensor& grad_buf = outputs[index++]; 30 | 31 | // set zero for atomic add 32 | checkCudaErr(cudaMemsetAsync((float*)(loss.data_ptr()), 0, sizeof(float))); 33 | 34 | const unsigned int tau = q.size(0); 35 | const unsigned int tau_prime = next_n_q.size(0); 36 | const unsigned int 
time_step = reward.size(0); 37 | const unsigned int batch_size = q.size(1); 38 | const unsigned int action_dim = q.size(2); 39 | 40 | { 41 | dim3 block_size = {DEFAULT_WARP_NUM * WARP_SIZE, 1, 1}; 42 | dim3 grid_size = {(tau + block_size.x - 1) / block_size.x, batch_size, 1}; 43 | 44 | bellmanErrorKernel<<>>( 45 | tau, tau_prime, time_step, batch_size, action_dim, gamma, kappa, 46 | (float*)(q.data_ptr()), (float*)(next_n_q.data_ptr()), 47 | (int64_t*)(action.data_ptr()), (int64_t*)(next_n_action.data_ptr()), 48 | (float*)(reward.data_ptr()), (float*)(done.data_ptr()), 49 | (float*)(value_gamma.data_ptr()), (float*)(bellman_err_buf.data_ptr())); 50 | } 51 | 52 | { 53 | dim3 block_size = {WARP_SIZE, DEFAULT_WARP_NUM, 1}; 54 | dim3 grid_size = {(tau + block_size.x - 1) / block_size.x, (tau_prime + block_size.y - 1) / block_size.y, batch_size}; 55 | 56 | quantileHuberErrorKernel<<>> ( 57 | tau, tau_prime, batch_size, kappa, 58 | (float*)(bellman_err_buf.data_ptr()), (float*)(replay_quantiles.data_ptr()), 59 | (float*)(quantile_huber_loss_buf.data_ptr()), (float*)(grad_buf.data_ptr())); 60 | } 61 | 62 | { 63 | dim3 block_size = {DEFAULT_WARP_NUM * WARP_SIZE, 1, 1}; 64 | dim3 grid_size = {batch_size, 1, 1}; 65 | 66 | lossKernel<<>> ( 67 | tau, tau_prime, batch_size, 68 | (float*)(quantile_huber_loss_buf.data_ptr()), (float*)(weight.data_ptr()), 69 | (float*)(td_err.data_ptr()), (float*)(loss.data_ptr())); 70 | } 71 | } 72 | 73 | void IQNNStepTDErrorBackward( 74 | const std::vector& inputs, 75 | std::vector& outputs) { 76 | 77 | unsigned int index = 0; 78 | const torch::Tensor& grad_loss = inputs[index++]; 79 | const torch::Tensor& grad_buf = inputs[index++]; 80 | const torch::Tensor& weight = inputs[index++]; 81 | const torch::Tensor& action = inputs[index++]; 82 | index = 0; 83 | torch::Tensor& grad_q = outputs[index++]; 84 | 85 | const unsigned int batch_size = grad_buf.size(0); 86 | const unsigned int tau_prime = grad_buf.size(1); 87 | const unsigned int tau = grad_buf.size(2); 88 | const unsigned int action_dim = grad_q.size(2); 89 | 90 | // set zero 91 | checkCudaErr(cudaMemsetAsync((float*)(grad_q.data_ptr()), 0, tau * batch_size * action_dim * sizeof(float))); 92 | 93 | dim3 block_size = {DEFAULT_WARP_NUM * WARP_SIZE, 1, 1}; 94 | dim3 grid_size = {(tau + block_size.x - 1) / block_size.x, batch_size, 1}; 95 | backwardKernel<<>>( 96 | tau, tau_prime, batch_size, action_dim, 97 | (float*)(grad_loss.data_ptr()), (float*)(grad_buf.data_ptr()), 98 | (float*)(weight.data_ptr()), (int64_t*)(action.data_ptr()), 99 | (float*)(grad_q.data_ptr())); 100 | } 101 | 102 | } // namespace cuda 103 | } // namespace rll 104 | } // namespace hpc 105 | -------------------------------------------------------------------------------- /src/rl_utils/ppo.cu: -------------------------------------------------------------------------------- 1 | #include "hpc/rll/cuda/rl_utils/entry.h" 2 | #include "hpc/rll/cuda/rl_utils/ppo_kernel.h" 3 | 4 | namespace hpc { 5 | namespace rll { 6 | namespace cuda { 7 | 8 | void PPOForward( 9 | const std::vector& inputs, 10 | std::vector& outputs, 11 | bool use_value_clip, 12 | float clip_ratio, 13 | float dual_clip) { 14 | 15 | unsigned int index = 0; 16 | const torch::Tensor& logits_new = inputs[index++]; 17 | const torch::Tensor& logits_old = inputs[index++]; 18 | const torch::Tensor& action = inputs[index++]; 19 | const torch::Tensor& value_new = inputs[index++]; 20 | const torch::Tensor& value_old = inputs[index++]; 21 | const torch::Tensor& adv = inputs[index++]; 22 | const 
torch::Tensor& return_ = inputs[index++]; 23 | const torch::Tensor& weight = inputs[index++]; 24 | 25 | index = 0; 26 | torch::Tensor& logits_new_prob = outputs[index++]; 27 | torch::Tensor& logits_new_entropy = outputs[index++]; 28 | torch::Tensor& logits_new_grad_logits = outputs[index++]; 29 | torch::Tensor& logits_new_grad_prob = outputs[index++]; 30 | torch::Tensor& logits_new_grad_entropy = outputs[index++]; 31 | torch::Tensor& logits_old_prob = outputs[index++]; 32 | torch::Tensor& grad_policy_loss_buf = outputs[index++]; 33 | torch::Tensor& grad_value_loss_buf = outputs[index++]; 34 | torch::Tensor& grad_entropy_loss_buf = outputs[index++]; 35 | torch::Tensor& policy_loss = outputs[index++]; 36 | torch::Tensor& value_loss = outputs[index++]; 37 | torch::Tensor& entropy_loss = outputs[index++]; 38 | torch::Tensor& approx_kl = outputs[index++]; 39 | torch::Tensor& clipfrac = outputs[index++]; 40 | 41 | checkCudaErr(cudaMemsetAsync((float*)(policy_loss.data_ptr()), 0, sizeof(float))); 42 | checkCudaErr(cudaMemsetAsync((float*)(value_loss.data_ptr()), 0, sizeof(float))); 43 | checkCudaErr(cudaMemsetAsync((float*)(entropy_loss.data_ptr()), 0, sizeof(float))); 44 | checkCudaErr(cudaMemsetAsync((float*)(approx_kl.data_ptr()), 0, sizeof(float))); 45 | checkCudaErr(cudaMemsetAsync((float*)(clipfrac.data_ptr()), 0, sizeof(float))); 46 | 47 | const unsigned int batch_size = logits_new.size(0); 48 | const unsigned int num_output = logits_new.size(1); 49 | { 50 | unsigned int block_size = DEFAULT_WARP_NUM * WARP_SIZE; 51 | unsigned int grid_size = batch_size; 52 | categoricalProbEntropy<<>>( 53 | num_output, (float*)(logits_new.data_ptr()), (int64_t*)(action.data_ptr()), 54 | (float*)(logits_new_prob.data_ptr()), (float*)(logits_new_entropy.data_ptr()), 55 | (float*)(logits_new_grad_logits.data_ptr()), (float*)(logits_new_grad_prob.data_ptr()), 56 | (float*)(logits_new_grad_entropy.data_ptr())); 57 | categoricalProb<<>>( 58 | num_output, (float*)(logits_old.data_ptr()), (int64_t*)(action.data_ptr()), (float*)(logits_old_prob.data_ptr())); 59 | } 60 | { 61 | unsigned int block_size = DEFAULT_WARP_NUM * WARP_SIZE; 62 | unsigned int grid_size = (batch_size + block_size - 1) / block_size; 63 | ppoLoss<<>>( 64 | batch_size, (float*)(value_new.data_ptr()), (float*)(value_old.data_ptr()), 65 | (float*)(logits_new_prob.data_ptr()), (float*)(logits_old_prob.data_ptr()), (float*)(logits_new_entropy.data_ptr()), 66 | (float*)(adv.data_ptr()), (float*)(return_.data_ptr()), (float*)(weight.data_ptr()), 67 | use_value_clip, clip_ratio, dual_clip, 68 | (float*)(policy_loss.data_ptr()), (float*)(value_loss.data_ptr()), (float*)(entropy_loss.data_ptr()), 69 | (float*)(approx_kl.data_ptr()), (float*)(clipfrac.data_ptr()), 70 | (float*)(grad_policy_loss_buf.data_ptr()), (float*)(grad_value_loss_buf.data_ptr()), 71 | (float*)(grad_entropy_loss_buf.data_ptr())); 72 | } 73 | } 74 | 75 | void PPOBackward( 76 | const std::vector& inputs, 77 | std::vector& outputs) { 78 | 79 | unsigned int index = 0; 80 | const torch::Tensor& grad_policy_loss = inputs[index++]; 81 | const torch::Tensor& grad_value_loss = inputs[index++]; 82 | const torch::Tensor& grad_entropy_loss = inputs[index++]; 83 | const torch::Tensor& grad_policy_loss_buf = inputs[index++]; 84 | const torch::Tensor& grad_value_loss_buf = inputs[index++]; 85 | const torch::Tensor& grad_entropy_loss_buf = inputs[index++]; 86 | const torch::Tensor& logits_new_grad_logits = inputs[index++]; 87 | const torch::Tensor& logits_new_grad_prob = inputs[index++]; 88 | const 
torch::Tensor& logits_new_grad_entropy = inputs[index++]; 89 | 90 | index = 0; 91 | torch::Tensor& grad_value = outputs[index++]; 92 | torch::Tensor& grad_logits_new = outputs[index++]; 93 | 94 | const unsigned int batch_size = grad_logits_new.size(0); 95 | const unsigned int num_output = grad_logits_new.size(1); 96 | { 97 | unsigned int block_size = DEFAULT_WARP_NUM * WARP_SIZE; 98 | unsigned int grid_size = (batch_size + block_size - 1) / block_size; 99 | ppoBackwardValueNew<<>>( 100 | batch_size, (float*)(grad_value_loss.data_ptr()), (float*)(grad_value_loss_buf.data_ptr()), (float*)(grad_value.data_ptr())); 101 | } 102 | { 103 | unsigned int block_size = DEFAULT_WARP_NUM * WARP_SIZE; 104 | unsigned int grid_size = batch_size; 105 | ppoBackwardLogitsNew<<>>( 106 | batch_size, num_output, (float*)(grad_policy_loss.data_ptr()), (float*)(grad_entropy_loss.data_ptr()), 107 | (float*)(grad_policy_loss_buf.data_ptr()), (float*)(grad_entropy_loss_buf.data_ptr()), 108 | (float*)(logits_new_grad_logits.data_ptr()), (float*)(logits_new_grad_prob.data_ptr()), 109 | (float*)(logits_new_grad_entropy.data_ptr()), (float*)(grad_logits_new.data_ptr())); 110 | } 111 | } 112 | 113 | } // namespace cuda 114 | } // namespace rll 115 | } // namespace hpc 116 | 117 | -------------------------------------------------------------------------------- /src/rl_utils/q_nstep_td.cu: -------------------------------------------------------------------------------- 1 | #include "hpc/rll/cuda/rl_utils/entry.h" 2 | #include "hpc/rll/cuda/rl_utils/q_nstep_td_kernel.h" 3 | 4 | namespace hpc { 5 | namespace rll { 6 | namespace cuda { 7 | 8 | void QNStepTdForward( 9 | const std::vector& inputs, 10 | std::vector& outputs, 11 | float gamma) { 12 | 13 | unsigned int index = 0; 14 | const torch::Tensor& q = inputs[index++]; 15 | const torch::Tensor& next_n_q = inputs[index++]; 16 | const torch::Tensor& action = inputs[index++]; 17 | const torch::Tensor& next_n_action = inputs[index++]; 18 | const torch::Tensor& reward = inputs[index++]; 19 | const torch::Tensor& done = inputs[index++]; 20 | const torch::Tensor& weight = inputs[index++]; 21 | index = 0; 22 | torch::Tensor& td_err = outputs[index++]; 23 | torch::Tensor& loss = outputs[index++]; 24 | torch::Tensor& grad_buf = outputs[index++]; 25 | 26 | // set zero for atomic add 27 | checkCudaErr(cudaMemsetAsync((float*)(loss.data_ptr()), 0, sizeof(float))); 28 | 29 | const unsigned int time_step = reward.size(0); 30 | const unsigned int batch_size = q.size(0); 31 | const unsigned int num_output = q.size(1); 32 | 33 | unsigned int block_size = DEFAULT_WARP_NUM * WARP_SIZE; 34 | unsigned int grid_size = (batch_size + block_size - 1) / block_size; 35 | qNStepTdForwardKernel<<>>( 36 | time_step, batch_size, num_output, gamma, 37 | (float*)(q.data_ptr()), (float*)(next_n_q.data_ptr()), 38 | (int64_t*)(action.data_ptr()), (int64_t*)(next_n_action.data_ptr()), 39 | (float*)(reward.data_ptr()), (float*)(done.data_ptr()), (float*)(weight.data_ptr()), 40 | (float*)(td_err.data_ptr()), (float*)(loss.data_ptr()), (float*)(grad_buf.data_ptr())); 41 | } 42 | 43 | void QNStepTdBackward( 44 | const std::vector& inputs, 45 | std::vector& outputs) { 46 | 47 | unsigned int index = 0; 48 | const torch::Tensor& grad_loss = inputs[index++]; 49 | const torch::Tensor& grad_buf = inputs[index++]; 50 | const torch::Tensor& action = inputs[index++]; 51 | index = 0; 52 | torch::Tensor& grad_q = outputs[index++]; 53 | 54 | const unsigned int batch_size = grad_q.size(0); 55 | const unsigned int num_output = 
grad_q.size(1); 56 | 57 | dim3 block_size = {DEFAULT_WARP_NUM * WARP_SIZE, 1, 1}; 58 | dim3 grid_size = {(num_output + block_size.x - 1) / block_size.x, batch_size, 1}; 59 | qNStepTdBackwardKernel<<>>( 60 | batch_size, num_output, 61 | (float*)(grad_loss.data_ptr()), (float*)(grad_buf.data_ptr()), 62 | (int64_t*)(action.data_ptr()), (float*)(grad_q.data_ptr())); 63 | } 64 | 65 | } // namespace cuda 66 | } // namespace rll 67 | } // namespace hpc 68 | -------------------------------------------------------------------------------- /src/rl_utils/q_nstep_td_rescale.cu: -------------------------------------------------------------------------------- 1 | #include "hpc/rll/cuda/rl_utils/entry.h" 2 | #include "hpc/rll/cuda/rl_utils/q_nstep_td_rescale_kernel.h" 3 | 4 | namespace hpc { 5 | namespace rll { 6 | namespace cuda { 7 | 8 | void QNStepTdRescaleForward( 9 | const std::vector& inputs, 10 | std::vector& outputs, 11 | float gamma) { 12 | 13 | unsigned int index = 0; 14 | const torch::Tensor& q = inputs[index++]; 15 | const torch::Tensor& next_n_q = inputs[index++]; 16 | const torch::Tensor& action = inputs[index++]; 17 | const torch::Tensor& next_n_action = inputs[index++]; 18 | const torch::Tensor& reward = inputs[index++]; 19 | const torch::Tensor& done = inputs[index++]; 20 | const torch::Tensor& weight = inputs[index++]; 21 | index = 0; 22 | torch::Tensor& td_err = outputs[index++]; 23 | torch::Tensor& loss = outputs[index++]; 24 | torch::Tensor& grad_buf = outputs[index++]; 25 | 26 | // set zero for atomic add 27 | checkCudaErr(cudaMemsetAsync((float*)(loss.data_ptr()), 0, sizeof(float))); 28 | 29 | const unsigned int time_step = reward.size(0); 30 | const unsigned int batch_size = q.size(0); 31 | const unsigned int num_output = q.size(1); 32 | 33 | unsigned int block_size = DEFAULT_WARP_NUM * WARP_SIZE; 34 | unsigned int grid_size = (batch_size + block_size - 1) / block_size; 35 | qNStepTdRescaleForwardKernel<<>>( 36 | time_step, batch_size, num_output, gamma, 37 | (float*)(q.data_ptr()), (float*)(next_n_q.data_ptr()), 38 | (int64_t*)(action.data_ptr()), (int64_t*)(next_n_action.data_ptr()), 39 | (float*)(reward.data_ptr()), (float*)(done.data_ptr()), (float*)(weight.data_ptr()), 40 | (float*)(td_err.data_ptr()), (float*)(loss.data_ptr()), (float*)(grad_buf.data_ptr())); 41 | } 42 | 43 | void QNStepTdRescaleBackward( 44 | const std::vector& inputs, 45 | std::vector& outputs) { 46 | 47 | unsigned int index = 0; 48 | const torch::Tensor& grad_loss = inputs[index++]; 49 | const torch::Tensor& grad_buf = inputs[index++]; 50 | const torch::Tensor& action = inputs[index++]; 51 | index = 0; 52 | torch::Tensor& grad_q = outputs[index++]; 53 | 54 | const unsigned int batch_size = grad_q.size(0); 55 | const unsigned int num_output = grad_q.size(1); 56 | 57 | dim3 block_size = {DEFAULT_WARP_NUM * WARP_SIZE, 1, 1}; 58 | dim3 grid_size = {(num_output + block_size.x - 1) / block_size.x, batch_size, 1}; 59 | qNStepTdRescaleBackwardKernel<<>>( 60 | batch_size, num_output, 61 | (float*)(grad_loss.data_ptr()), (float*)(grad_buf.data_ptr()), 62 | (int64_t*)(action.data_ptr()), (float*)(grad_q.data_ptr())); 63 | } 64 | 65 | } // namespace cuda 66 | } // namespace rll 67 | } // namespace hpc 68 | -------------------------------------------------------------------------------- /src/rl_utils/qrdqn_nstep_td_error.cu: -------------------------------------------------------------------------------- 1 | #include "hpc/rll/cuda/rl_utils/entry.h" 2 | #include "hpc/rll/cuda/rl_utils/qrdqn_nstep_td_error_kernel.h" 3 | 
4 | namespace hpc { 5 | namespace rll { 6 | namespace cuda { 7 | 8 | void QRDQNNStepTDErrorForward( 9 | const std::vector& inputs, 10 | std::vector& outputs, 11 | float gamma) { 12 | 13 | unsigned int index = 0; 14 | const torch::Tensor& q = inputs[index++]; 15 | const torch::Tensor& next_n_q = inputs[index++]; 16 | const torch::Tensor& action = inputs[index++]; 17 | const torch::Tensor& next_n_action = inputs[index++]; 18 | const torch::Tensor& reward = inputs[index++]; 19 | const torch::Tensor& done = inputs[index++]; 20 | const torch::Tensor& weight = inputs[index++]; 21 | const torch::Tensor& value_gamma = inputs[index++]; 22 | index = 0; 23 | torch::Tensor& loss = outputs[index++]; 24 | torch::Tensor& td_err = outputs[index++]; 25 | torch::Tensor& bellman_err_buf = outputs[index++]; 26 | torch::Tensor& quantile_huber_loss_buf = outputs[index++]; 27 | torch::Tensor& grad_buf = outputs[index++]; 28 | 29 | // set zero for atomic add 30 | checkCudaErr(cudaMemsetAsync((float*)(loss.data_ptr()), 0, sizeof(float))); 31 | 32 | const unsigned int batch_size = q.size(0); 33 | const unsigned int action_dim = q.size(1); 34 | const unsigned int tau = q.size(2); 35 | const unsigned int time_step = reward.size(0); 36 | 37 | { 38 | dim3 block_size = {DEFAULT_WARP_NUM * WARP_SIZE, 1, 1}; 39 | dim3 grid_size = {(tau + block_size.x - 1) / block_size.x, batch_size, 1}; 40 | 41 | bellmanErrorKernel<<>>( 42 | tau, time_step, batch_size, action_dim, gamma, 43 | (float*)(q.data_ptr()), (float*)(next_n_q.data_ptr()), 44 | (int64_t*)(action.data_ptr()), (int64_t*)(next_n_action.data_ptr()), 45 | (float*)(reward.data_ptr()), (float*)(done.data_ptr()), 46 | (float*)(value_gamma.data_ptr()), (float*)(bellman_err_buf.data_ptr())); 47 | } 48 | 49 | { 50 | dim3 block_size = {WARP_SIZE, DEFAULT_WARP_NUM, 1}; 51 | dim3 grid_size = {(tau + block_size.x - 1) / block_size.x, (tau + block_size.y - 1) / block_size.y, batch_size}; 52 | 53 | smoothL1LossKernel<<>> ( 54 | tau, batch_size, 55 | (float*)(bellman_err_buf.data_ptr()), (float*)(quantile_huber_loss_buf.data_ptr()), (float*)(grad_buf.data_ptr())); 56 | } 57 | 58 | { 59 | dim3 block_size = {DEFAULT_WARP_NUM * WARP_SIZE, 1, 1}; 60 | dim3 grid_size = {batch_size, 1, 1}; 61 | 62 | lossKernel<<>> ( 63 | tau, batch_size, 64 | (float*)(quantile_huber_loss_buf.data_ptr()), (float*)(weight.data_ptr()), 65 | (float*)(td_err.data_ptr()), (float*)(loss.data_ptr())); 66 | } 67 | } 68 | 69 | void QRDQNNStepTDErrorBackward( 70 | const std::vector& inputs, 71 | std::vector& outputs) { 72 | 73 | unsigned int index = 0; 74 | const torch::Tensor& grad_loss = inputs[index++]; 75 | const torch::Tensor& grad_buf = inputs[index++]; 76 | const torch::Tensor& weight = inputs[index++]; 77 | const torch::Tensor& action = inputs[index++]; 78 | index = 0; 79 | torch::Tensor& grad_q = outputs[index++]; 80 | 81 | const unsigned int batch_size = grad_q.size(0); 82 | const unsigned int action_dim = grad_q.size(1); 83 | const unsigned int tau = grad_q.size(2); 84 | 85 | // set zero 86 | checkCudaErr(cudaMemsetAsync((float*)(grad_q.data_ptr()), 0, tau * batch_size * action_dim * sizeof(float))); 87 | 88 | dim3 block_size = {DEFAULT_WARP_NUM * WARP_SIZE, 1, 1}; 89 | dim3 grid_size = {(tau + block_size.x - 1) / block_size.x, batch_size, 1}; 90 | backwardKernel<<>>( 91 | tau, batch_size, action_dim, 92 | (float*)(grad_loss.data_ptr()), (float*)(grad_buf.data_ptr()), 93 | (float*)(weight.data_ptr()), (int64_t*)(action.data_ptr()), 94 | (float*)(grad_q.data_ptr())); 95 | } 96 | 97 | } // namespace cuda 98 
| } // namespace rll 99 | } // namespace hpc 100 | -------------------------------------------------------------------------------- /src/rl_utils/td_lambda.cu: -------------------------------------------------------------------------------- 1 | #include "hpc/rll/cuda/rl_utils/entry.h" 2 | #include "hpc/rll/cuda/rl_utils/td_lambda_kernel.h" 3 | 4 | namespace hpc { 5 | namespace rll { 6 | namespace cuda { 7 | 8 | void TdLambdaForward( 9 | const std::vector& inputs, 10 | std::vector& outputs, 11 | float gamma, 12 | float lambda) { 13 | 14 | unsigned int index = 0; 15 | const torch::Tensor& value = inputs[index++]; 16 | const torch::Tensor& reward = inputs[index++]; 17 | const torch::Tensor& weight = inputs[index++]; 18 | index = 0; 19 | torch::Tensor& loss = outputs[index++]; 20 | torch::Tensor& grad_buf = outputs[index++]; 21 | 22 | checkCudaErr(cudaMemsetAsync((float*)(loss.data_ptr()), 0, sizeof(float))); 23 | 24 | const unsigned int time_step = reward.size(0); 25 | const unsigned int batch_size = reward.size(1); 26 | 27 | unsigned int block_size = 1 * WARP_SIZE; // in order to use as many sm processors as possible 28 | unsigned int grid_size = (batch_size + block_size - 1) / block_size; 29 | tdLambdaForwardKernel<<>>( 30 | time_step, batch_size, gamma, lambda, 31 | (float*)(value.data_ptr()), (float*)(reward.data_ptr()), (float*)(weight.data_ptr()), 32 | (float*)(loss.data_ptr()), (float*)(grad_buf.data_ptr())); 33 | } 34 | 35 | void TdLambdaBackward( 36 | const std::vector& inputs, 37 | std::vector& outputs) { 38 | 39 | unsigned int index = 0; 40 | const torch::Tensor& grad_loss = inputs[index++]; 41 | const torch::Tensor& grad_buf = inputs[index++]; 42 | index = 0; 43 | torch::Tensor& grad_value = outputs[index++]; 44 | 45 | const unsigned int time_step = grad_value.size(0) - 1; 46 | const unsigned int batch_size = grad_value.size(1); 47 | 48 | unsigned int block_size = DEFAULT_WARP_NUM * WARP_SIZE; 49 | unsigned int grid_size = ((time_step + 1) * batch_size + block_size - 1) / block_size; 50 | tdLambdaBackwardKernel<<>>( 51 | time_step, batch_size, (float*)(grad_loss.data_ptr()), (float*)(grad_buf.data_ptr()), (float*)(grad_value.data_ptr())); 52 | } 53 | 54 | } // namespace cuda 55 | } // namespace rll 56 | } // namespace hpc 57 | -------------------------------------------------------------------------------- /src/rl_utils/upgo.cu: -------------------------------------------------------------------------------- 1 | #include "hpc/rll/cuda/rl_utils/entry.h" 2 | #include "hpc/rll/cuda/rl_utils/upgo_kernel.h" 3 | 4 | namespace hpc { 5 | namespace rll { 6 | namespace cuda { 7 | 8 | void UpgoForward( 9 | const std::vector& inputs, 10 | std::vector& outputs) { 11 | 12 | unsigned int index = 0; 13 | const torch::Tensor& target_output = inputs[index++]; 14 | const torch::Tensor& rho = inputs[index++]; 15 | const torch::Tensor& action = inputs[index++]; 16 | const torch::Tensor& reward = inputs[index++]; 17 | const torch::Tensor& value = inputs[index++]; 18 | index = 0; 19 | torch::Tensor& advantage = outputs[index++]; 20 | torch::Tensor& metric = outputs[index++]; 21 | torch::Tensor& loss = outputs[index++]; 22 | torch::Tensor& grad_buf = outputs[index++]; 23 | 24 | checkCudaErr(cudaMemsetAsync((float*)(loss.data_ptr()), 0, sizeof(float))); 25 | 26 | const unsigned int time_step = target_output.size(0); 27 | const unsigned int batch_size = target_output.size(1); 28 | const unsigned int num_output = target_output.size(2); 29 | { 30 | unsigned int block_size = DEFAULT_WARP_NUM * WARP_SIZE; 31 | 
unsigned int grid_size = (batch_size + block_size - 1) / block_size; 32 | upgoAdvantageKernel<<>>(time_step, batch_size, 33 | (float*)(rho.data_ptr()), (float*)(reward.data_ptr()), (float*)(value.data_ptr()), (float*)(advantage.data_ptr())); 34 | } 35 | { 36 | unsigned int block_size = DEFAULT_WARP_NUM * WARP_SIZE; 37 | unsigned int grid_size = time_step * batch_size; 38 | crossEntropyKernel<<>>(num_output, 39 | (float*)(target_output.data_ptr()), (int64_t*)(action.data_ptr()), (float*)(metric.data_ptr()), (float*)(grad_buf.data_ptr())); 40 | } 41 | { 42 | unsigned int block_size = DEFAULT_WARP_NUM * WARP_SIZE; 43 | unsigned int grid_size = (time_step * batch_size + block_size - 1) / block_size; 44 | upgoLossKernel<<>>(time_step, batch_size, 45 | (float*)(advantage.data_ptr()), (float*)(metric.data_ptr()), (float*)(loss.data_ptr())); 46 | } 47 | } 48 | 49 | void UpgoBackward( 50 | const std::vector& inputs, 51 | std::vector& outputs) { 52 | 53 | unsigned int index = 0; 54 | const torch::Tensor& grad_loss = inputs[index++]; 55 | const torch::Tensor& grad_buf = inputs[index++]; 56 | const torch::Tensor& advantage = inputs[index++]; 57 | index = 0; 58 | torch::Tensor& grad_target_output = outputs[index++]; 59 | 60 | const unsigned int time_step = grad_target_output.size(0); 61 | const unsigned int batch_size = grad_target_output.size(1); 62 | const unsigned int num_output = grad_target_output.size(2); 63 | 64 | unsigned int block_size = DEFAULT_WARP_NUM * WARP_SIZE; 65 | unsigned int grid_size = (time_step * batch_size * num_output + block_size - 1) / block_size; 66 | upgoBackwardKernel<<>>(time_step, batch_size, num_output, 67 | (float*)(grad_loss.data_ptr()), (float*)(grad_buf.data_ptr()), 68 | (float*)(advantage.data_ptr()), (float*)(grad_target_output.data_ptr())); 69 | } 70 | 71 | } // namespace cuda 72 | } // namespace rll 73 | } // namespace hpc 74 | -------------------------------------------------------------------------------- /src/rl_utils/vtrace.cu: -------------------------------------------------------------------------------- 1 | #include "hpc/rll/cuda/rl_utils/entry.h" 2 | #include "hpc/rll/cuda/rl_utils/vtrace_kernel.h" 3 | 4 | namespace hpc { 5 | namespace rll { 6 | namespace cuda { 7 | 8 | void VTraceForward( 9 | const std::vector& inputs, 10 | std::vector& outputs, 11 | float gamma, 12 | float lambda, 13 | float rho_clip_ratio, 14 | float c_clip_ratio, 15 | float rho_pg_clip_ratio) { 16 | 17 | unsigned int index = 0; 18 | const torch::Tensor& target_output = inputs[index++]; 19 | const torch::Tensor& behaviour_output = inputs[index++]; 20 | const torch::Tensor& action = inputs[index++]; 21 | const torch::Tensor& value = inputs[index++]; 22 | const torch::Tensor& reward = inputs[index++]; 23 | const torch::Tensor& weight = inputs[index++]; 24 | 25 | index = 0; 26 | torch::Tensor& target_output_prob = outputs[index++]; 27 | torch::Tensor& target_output_entropy = outputs[index++]; 28 | torch::Tensor& target_output_grad_logits = outputs[index++]; 29 | torch::Tensor& target_output_grad_prob = outputs[index++]; 30 | torch::Tensor& target_output_grad_entropy = outputs[index++]; 31 | torch::Tensor& behaviour_output_prob = outputs[index++]; 32 | torch::Tensor& is = outputs[index++]; 33 | torch::Tensor& ret = outputs[index++]; 34 | torch::Tensor& adv = outputs[index++]; 35 | torch::Tensor& pg_loss = outputs[index++]; 36 | torch::Tensor& value_loss = outputs[index++]; 37 | torch::Tensor& entropy_loss = outputs[index++]; 38 | 39 | 
checkCudaErr(cudaMemsetAsync((float*)(pg_loss.data_ptr()), 0, sizeof(float))); 40 | checkCudaErr(cudaMemsetAsync((float*)(value_loss.data_ptr()), 0, sizeof(float))); 41 | checkCudaErr(cudaMemsetAsync((float*)(entropy_loss.data_ptr()), 0, sizeof(float))); 42 | 43 | const unsigned int time_step = target_output.size(0); 44 | const unsigned int batch_size = target_output.size(1); 45 | const unsigned int num_output = target_output.size(2); 46 | { 47 | unsigned int block_size = DEFAULT_WARP_NUM * WARP_SIZE; 48 | unsigned int grid_size = time_step * batch_size; 49 | categoricalTarget<<>>(num_output, 50 | (float*)(target_output.data_ptr()), (int64_t*)(action.data_ptr()), 51 | (float*)(target_output_prob.data_ptr()), (float*)(target_output_entropy.data_ptr()), 52 | (float*)(target_output_grad_logits.data_ptr()), (float*)(target_output_grad_prob.data_ptr()), 53 | (float*)(target_output_grad_entropy.data_ptr())); 54 | categoricalBehaviour<<>>(num_output, 55 | (float*)(behaviour_output.data_ptr()), (int64_t*)(action.data_ptr()), (float*)(behaviour_output_prob.data_ptr())); 56 | } 57 | { 58 | unsigned int block_size = DEFAULT_WARP_NUM * WARP_SIZE; 59 | unsigned int grid_size = (time_step * batch_size + block_size - 1) / block_size; 60 | computeImportanceWeights<<>>(time_step, batch_size, 61 | (float*)(target_output_prob.data_ptr()), (float*)(behaviour_output_prob.data_ptr()), (float*)(is.data_ptr())); 62 | } 63 | { 64 | unsigned int block_size = DEFAULT_WARP_NUM * WARP_SIZE; 65 | unsigned int grid_size = (batch_size + block_size - 1) / block_size; 66 | vtraceNStepReturn<<>>( 67 | time_step, batch_size, gamma, lambda, rho_clip_ratio, c_clip_ratio, 68 | (float*)(is.data_ptr()), (float*)(reward.data_ptr()), (float*)(value.data_ptr()), (float*)(ret.data_ptr())); 69 | } 70 | { 71 | unsigned int block_size = DEFAULT_WARP_NUM * WARP_SIZE; 72 | unsigned int grid_size = (time_step * batch_size + block_size - 1) / block_size; 73 | vtraceAdvantage<<>>( 74 | time_step, batch_size, gamma, rho_pg_clip_ratio, 75 | (float*)(is.data_ptr()), (float*)(reward.data_ptr()), (float*)(value.data_ptr()), 76 | (float*)(ret.data_ptr()), (float*)(adv.data_ptr())); 77 | } 78 | { 79 | unsigned int block_size = DEFAULT_WARP_NUM * WARP_SIZE; 80 | unsigned int grid_size = (time_step * batch_size + block_size - 1) / block_size; 81 | vtraceLoss<<>>(time_step, batch_size, 82 | (float*)(value.data_ptr()), (float*)(target_output_prob.data_ptr()), (float*)(target_output_entropy.data_ptr()), 83 | (float*)(ret.data_ptr()), (float*)(adv.data_ptr()), (float*)(weight.data_ptr()), 84 | (float*)(pg_loss.data_ptr()), (float*)(value_loss.data_ptr()), (float*)(entropy_loss.data_ptr())); 85 | } 86 | } 87 | 88 | void VTraceBackward( 89 | const std::vector& inputs, 90 | std::vector& outputs) { 91 | 92 | unsigned int index = 0; 93 | const torch::Tensor& grad_pg_loss = inputs[index++]; 94 | const torch::Tensor& grad_value_loss = inputs[index++]; 95 | const torch::Tensor& grad_entropy_loss = inputs[index++]; 96 | const torch::Tensor& value = inputs[index++]; 97 | const torch::Tensor& action = inputs[index++]; 98 | const torch::Tensor& weight = inputs[index++]; 99 | const torch::Tensor& ret = inputs[index++]; 100 | const torch::Tensor& adv = inputs[index++]; 101 | const torch::Tensor& target_output_grad_logits = inputs[index++]; 102 | const torch::Tensor& target_output_grad_prob = inputs[index++]; 103 | const torch::Tensor& target_output_grad_entropy = inputs[index++]; 104 | 105 | index = 0; 106 | torch::Tensor& grad_value = outputs[index++]; 107 | 
torch::Tensor& grad_target_output = outputs[index++]; 108 | 109 | const unsigned int time_step = grad_target_output.size(0); 110 | const unsigned int batch_size = grad_target_output.size(1); 111 | const unsigned int num_output = grad_target_output.size(2); 112 | { 113 | unsigned int block_size = DEFAULT_WARP_NUM * WARP_SIZE; 114 | unsigned int grid_size = ((time_step + 1) * batch_size + block_size - 1) / block_size; 115 | vtraceBackwardValue<<>>(time_step, batch_size, 116 | (float*)(grad_value_loss.data_ptr()), (float*)(value.data_ptr()), (float*)(ret.data_ptr()), 117 | (float*)(weight.data_ptr()), (float*)(grad_value.data_ptr())); 118 | } 119 | { 120 | unsigned int block_size = DEFAULT_WARP_NUM * WARP_SIZE; 121 | unsigned int grid_size = time_step * batch_size; 122 | vtraceBackwardTargetOutput<<>>( 123 | time_step, batch_size, num_output, 124 | (float*)(grad_entropy_loss.data_ptr()), (float*)(grad_pg_loss.data_ptr()), 125 | (float*)(target_output_grad_logits.data_ptr()), 126 | (float*)(target_output_grad_entropy.data_ptr()), 127 | (float*)(target_output_grad_prob.data_ptr()), 128 | (float*)(adv.data_ptr()), (float*)(weight.data_ptr()), (float*)(grad_target_output.data_ptr())); 129 | } 130 | } 131 | 132 | } // namespace cuda 133 | } // namespace rll 134 | } // namespace hpc 135 | 136 | -------------------------------------------------------------------------------- /src/torch_utils/network/entry.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include "hpc/rll/cuda/torch_utils/network/entry.h" 3 | 4 | namespace hpc { 5 | namespace rll { 6 | namespace cuda { 7 | 8 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 9 | m.def("LstmForward", &LstmForward, "lstm forward (CUDA)"); 10 | m.def("LstmBackward", &LstmBackward, "lstm backward (CUDA)"); 11 | m.def("ScatterConnectionForward", &ScatterConnectionForward, "scatter_connection forward (CUDA)"); 12 | m.def("ScatterConnectionBackward", &ScatterConnectionBackward, "scatter_connection backward (CUDA)"); 13 | } 14 | 15 | } // namespace cuda 16 | } // namespace rll 17 | } // namespace hpc 18 | -------------------------------------------------------------------------------- /src/torch_utils/network/scatter_connection.cu: -------------------------------------------------------------------------------- 1 | #include "hpc/rll/cuda/torch_utils/network/entry.h" 2 | #include "hpc/rll/cuda/torch_utils/network/scatter_connection_kernel.h" 3 | 4 | namespace hpc { 5 | namespace rll { 6 | namespace cuda { 7 | 8 | void ScatterConnectionForward( 9 | const std::vector& inputs, 10 | std::vector& outputs, 11 | const char* scatter_type) { 12 | 13 | unsigned int index = 0; 14 | const torch::Tensor& in = inputs[index++]; 15 | const torch::Tensor& loc = inputs[index++]; 16 | index = 0; 17 | torch::Tensor& out = outputs[index++]; 18 | 19 | const unsigned int B = in.size(0); 20 | const unsigned int M = in.size(1); 21 | const unsigned int N = in.size(2); 22 | const unsigned int H = out.size(2); 23 | const unsigned int W = out.size(3); 24 | 25 | // forward kernel is launched according to input size, some output element will not be set value 26 | unsigned int out_size = B * N * H * W; 27 | checkCudaErr(cudaMemsetAsync((float*)(out.data_ptr()), 0, out_size * sizeof(float))); 28 | 29 | if (std::string(scatter_type) == "cover") { 30 | /* 31 | // consider that the input data may overlap, keep sequence like cpu when cover data 32 | // note: even thought this kernel is launched according to output size, 33 | // there still may be 
some output element will not be set value according to the mapping index 34 | dim3 block_size = WARP_SIZE; 35 | dim3 grid_size = B * N * H * W; 36 | scatterConnectionCoverKeepSeqForwardKernel<<>>( 37 | B, M, N, H, W, (float*)(in.data_ptr()), (int64_t*)(loc.data_ptr()), (float*)(out.data_ptr())); 38 | */ 39 | dim3 block_size = {WARP_SIZE, DEFAULT_WARP_NUM, 1}; 40 | dim3 grid_size = {(N + block_size.x - 1) / block_size.x, (M + block_size.y - 1) / block_size.y, B}; 41 | scatterConnectionCoverForwardKernel<<>>( 42 | B, M, N, H, W, (float*)(in.data_ptr()), (int64_t*)(loc.data_ptr()), (float*)(out.data_ptr())); 43 | } else if (std::string(scatter_type) == "add") { 44 | dim3 block_size = {WARP_SIZE, DEFAULT_WARP_NUM, 1}; 45 | dim3 grid_size = {(N + block_size.x - 1) / block_size.x, (M + block_size.y - 1) / block_size.y, B}; 46 | scatterConnectionAddForwardKernel<<>>( 47 | B, M, N, H, W, (float*)(in.data_ptr()), (int64_t*)(loc.data_ptr()), (float*)(out.data_ptr())); 48 | } 49 | } 50 | 51 | void ScatterConnectionBackward( 52 | const std::vector& inputs, 53 | std::vector& outputs) { 54 | 55 | unsigned int index = 0; 56 | const torch::Tensor& grad_out = inputs[index++]; 57 | const torch::Tensor& loc = inputs[index++]; 58 | index = 0; 59 | torch::Tensor& grad_in = outputs[index++]; 60 | 61 | const unsigned int B = grad_in.size(0); 62 | const unsigned int M = grad_in.size(1); 63 | const unsigned int N = grad_in.size(2); 64 | const unsigned int H = grad_out.size(2); 65 | const unsigned int W = grad_out.size(3); 66 | 67 | dim3 block_size = {WARP_SIZE, DEFAULT_WARP_NUM, 1}; 68 | dim3 grid_size = {(N + block_size.x - 1) / block_size.x, (M + block_size.y - 1) / block_size.y, B}; 69 | scatterConnectionBackwardKernel<<>>( 70 | B, M, N, H, W, (float*)(grad_out.data_ptr()), (int64_t*)(loc.data_ptr()), (float*)(grad_in.data_ptr())); 71 | } 72 | 73 | } // namespace cuda 74 | } // namespace rll 75 | } // namespace hpc 76 | -------------------------------------------------------------------------------- /tests/test_dntd.py: -------------------------------------------------------------------------------- 1 | import time 2 | import torch 3 | from hpc_rll.origin.td import dist_nstep_td_error, dist_nstep_td_data 4 | from hpc_rll.rl_utils.td import DistNStepTD 5 | from testbase import mean_relative_error, times 6 | 7 | assert torch.cuda.is_available() 8 | use_cuda = True 9 | 10 | T = 128 11 | B = 128 12 | N = 128 13 | gamma = 0.95 14 | v_min = -10.0 15 | v_max = 10.0 16 | n_atom = 51 17 | 18 | def dntd_val(): 19 | ori_dist = torch.randn(B, N, n_atom).abs() 20 | ori_next_n_dist = torch.randn(B, N, n_atom).abs() 21 | ori_action = torch.randint(0, N, size=(B, )) 22 | ori_next_n_action = torch.randint(0, N, size=(B, )) 23 | ori_reward = torch.randn(T, B) 24 | ori_done = torch.randn(B) 25 | ori_weight = torch.randn(B) 26 | 27 | hpc_dist = ori_dist.clone().detach() 28 | hpc_next_n_dist = ori_next_n_dist.clone().detach() 29 | hpc_action = ori_action.clone().detach() 30 | hpc_next_n_action = ori_next_n_action.clone().detach() 31 | hpc_reward = ori_reward.clone().detach() 32 | hpc_done = ori_done.clone().detach() 33 | hpc_weight = ori_weight.clone().detach() 34 | hpc_dntd = DistNStepTD(T, B, N, n_atom) 35 | 36 | if use_cuda: 37 | ori_dist = ori_dist.cuda() 38 | ori_next_n_dist = ori_next_n_dist.cuda() 39 | ori_action = ori_action.cuda() 40 | ori_next_n_action = ori_next_n_action.cuda() 41 | ori_reward = ori_reward.cuda() 42 | ori_done = ori_done.cuda() 43 | ori_weight = ori_weight.cuda() 44 | 45 | hpc_dist = hpc_dist.cuda() 46 | 
hpc_next_n_dist = hpc_next_n_dist.cuda() 47 | hpc_action = hpc_action.cuda() 48 | hpc_next_n_action = hpc_next_n_action.cuda() 49 | hpc_reward = hpc_reward.cuda() 50 | hpc_done = hpc_done.cuda() 51 | hpc_weight = hpc_weight.cuda() 52 | hpc_dntd = hpc_dntd.cuda() 53 | 54 | ori_dist.requires_grad_(True) 55 | ori_loss, ori_td_err = dist_nstep_td_error( 56 | dist_nstep_td_data(ori_dist, ori_next_n_dist, ori_action, ori_next_n_action, ori_reward, ori_done, ori_weight), gamma, v_min, v_max, n_atom, T) 57 | ori_loss = ori_loss.mean() 58 | ori_loss.backward() 59 | 60 | hpc_dist.requires_grad_(True) 61 | hpc_loss, hpc_td_err = hpc_dntd(hpc_dist, hpc_next_n_dist, hpc_action, hpc_next_n_action, hpc_reward, hpc_done, hpc_weight, gamma, v_min, v_max) 62 | hpc_loss = hpc_loss.mean() 63 | hpc_loss.backward() 64 | 65 | mre = mean_relative_error(torch.flatten(ori_loss).cpu().detach().numpy(), torch.flatten(hpc_loss).cpu().detach().numpy()) 66 | print("dntd fp mean_relative_error: " + str(mre)) 67 | mre = mean_relative_error(torch.flatten(ori_td_err).cpu().detach().numpy(), torch.flatten(hpc_td_err).cpu().detach().numpy()) 68 | print("dntd fp td_err mean_relative_error: " + str(mre)) 69 | mre = mean_relative_error(torch.flatten(ori_dist.grad).cpu().detach().numpy(), torch.flatten(hpc_dist.grad).cpu().detach().numpy()) 70 | print("dntd bp mean_relative_error: " + str(mre)) 71 | 72 | 73 | def dntd_perf(): 74 | ori_dist = torch.randn(B, N, n_atom).abs() 75 | ori_next_n_dist = torch.randn(B, N, n_atom).abs() 76 | ori_action = torch.randint(0, N, size=(B, )) 77 | ori_next_n_action = torch.randint(0, N, size=(B, )) 78 | ori_reward = torch.randn(T, B) 79 | ori_done = torch.randn(B) 80 | ori_weight = torch.randn(B) 81 | 82 | hpc_dist = ori_dist.clone().detach() 83 | hpc_next_n_dist = ori_next_n_dist.clone().detach() 84 | hpc_action = ori_action.clone().detach() 85 | hpc_next_n_action = ori_next_n_action.clone().detach() 86 | hpc_reward = ori_reward.clone().detach() 87 | hpc_done = ori_done.clone().detach() 88 | hpc_weight = ori_weight.clone().detach() 89 | hpc_dntd = DistNStepTD(T, B, N, n_atom) 90 | 91 | if use_cuda: 92 | ori_dist = ori_dist.cuda() 93 | ori_next_n_dist = ori_next_n_dist.cuda() 94 | ori_action = ori_action.cuda() 95 | ori_next_n_action = ori_next_n_action.cuda() 96 | ori_reward = ori_reward.cuda() 97 | ori_done = ori_done.cuda() 98 | ori_weight = ori_weight.cuda() 99 | 100 | hpc_dist = hpc_dist.cuda() 101 | hpc_next_n_dist = hpc_next_n_dist.cuda() 102 | hpc_action = hpc_action.cuda() 103 | hpc_next_n_action = hpc_next_n_action.cuda() 104 | hpc_reward = hpc_reward.cuda() 105 | hpc_done = hpc_done.cuda() 106 | hpc_weight = hpc_weight.cuda() 107 | hpc_dntd = hpc_dntd.cuda() 108 | 109 | ori_dist.requires_grad_(True) 110 | for i in range(times): 111 | t = time.time() 112 | ori_loss, ori_td_err = dist_nstep_td_error(dist_nstep_td_data( 113 | ori_dist, ori_next_n_dist, ori_action, ori_next_n_action, ori_reward, ori_done, ori_weight), gamma, v_min, v_max, n_atom, T) 114 | ori_loss = ori_loss.mean() 115 | ori_loss.backward() 116 | if use_cuda: 117 | torch.cuda.synchronize() 118 | print('epoch: {}, origin dntd cost time: {}'.format(i, time.time() - t)) 119 | 120 | hpc_dist.requires_grad_(True) 121 | for i in range(times): 122 | t = time.time() 123 | hpc_loss, hpc_td_err = hpc_dntd(hpc_dist, hpc_next_n_dist, hpc_action, hpc_next_n_action, hpc_reward, hpc_done, hpc_weight, gamma, v_min, v_max) 124 | hpc_loss = hpc_loss.mean() 125 | hpc_loss.backward() 126 | if use_cuda: 127 | torch.cuda.synchronize() 128 | 
print('epoch: {}, hpc dntd cost time: {}'.format(i, time.time() - t)) 129 | 130 | 131 | if __name__ == '__main__': 132 | print("target problem: T = {}, B = {}, N = {}, gamma = {}, v_min = {}, v_max = {}, n_atom = {}".format(T, B, N, gamma, v_min, v_max, n_atom)) 133 | print("================run dntd validation test================") 134 | dntd_val() 135 | print("================run dntd performance test================") 136 | dntd_perf() 137 | -------------------------------------------------------------------------------- /tests/test_gae.py: -------------------------------------------------------------------------------- 1 | import time 2 | import torch 3 | from hpc_rll.origin.gae import gae, gae_data 4 | from hpc_rll.rl_utils.gae import GAE 5 | from testbase import mean_relative_error, times 6 | 7 | assert torch.cuda.is_available() 8 | use_cuda = True 9 | 10 | T = 1024 11 | B = 64 12 | 13 | def gae_val(): 14 | value = torch.randn(T + 1, B) 15 | reward = torch.randn(T, B) 16 | 17 | hpc_gae = GAE(T, B) 18 | 19 | if use_cuda: 20 | value = value.cuda() 21 | reward = reward.cuda() 22 | hpc_gae = hpc_gae.cuda() 23 | ori_adv = gae(gae_data(value, reward)) 24 | hpc_adv = hpc_gae(value, reward) 25 | if use_cuda: 26 | torch.cuda.synchronize() 27 | 28 | mre = mean_relative_error(torch.flatten(ori_adv).cpu().detach().numpy(), torch.flatten(hpc_adv).cpu().detach().numpy()) 29 | print("gae mean_relative_error: " + str(mre)) 30 | 31 | def gae_perf(): 32 | value = torch.randn(T + 1, B) 33 | reward = torch.randn(T, B) 34 | 35 | hpc_gae = GAE(T, B) 36 | 37 | if use_cuda: 38 | value = value.cuda() 39 | reward = reward.cuda() 40 | hpc_gae = hpc_gae.cuda() 41 | for i in range(times): 42 | t = time.time() 43 | adv = gae(gae_data(value, reward)) 44 | if use_cuda: 45 | torch.cuda.synchronize() 46 | print('epoch: {}, original gae cost time: {}'.format(i, time.time() - t)) 47 | for i in range(times): 48 | t = time.time() 49 | hpc_adv = hpc_gae(value, reward) 50 | if use_cuda: 51 | torch.cuda.synchronize() 52 | print('epoch: {}, hpc gae cost time: {}'.format(i, time.time() - t)) 53 | 54 | 55 | if __name__ == '__main__': 56 | print("target problem: T = {}, B = {}".format(T, B)) 57 | print("================run gae validation test================") 58 | gae_val() 59 | print("================run gae performance test================") 60 | gae_perf() 61 | -------------------------------------------------------------------------------- /tests/test_iqn_nstep_td_error.py: -------------------------------------------------------------------------------- 1 | import time 2 | import torch 3 | from hpc_rll.origin.td import iqn_nstep_td_error, iqn_nstep_td_data 4 | from hpc_rll.rl_utils.td import IQNNStepTDError 5 | from testbase import mean_relative_error, times 6 | 7 | assert torch.cuda.is_available() 8 | use_cuda = True 9 | 10 | tau = 33 11 | tauPrime = 34 12 | T = 10 13 | B = 64 14 | N = 8 15 | gamma = 0.95 16 | kappa = 0.9 17 | 18 | def iqn_val(): 19 | ori_q = torch.randn(tau, B, N) 20 | ori_next_n_q = torch.randn(tauPrime, B, N) 21 | ori_action = torch.randint(0, N, size=(B, )) 22 | ori_next_n_action = torch.randint(0, N, size=(B, )) 23 | ori_reward = torch.randn(T, B) 24 | ori_done = torch.randn(B) 25 | ori_r_q = torch.randn(tau, B) 26 | ori_weight = torch.randn(B) 27 | ori_value_gamma = torch.randn(B) 28 | 29 | hpc_q = ori_q.clone().detach() 30 | hpc_next_n_q = ori_next_n_q.clone().detach() 31 | hpc_action = ori_action.clone().detach() 32 | hpc_next_n_action = ori_next_n_action.clone().detach() 33 | hpc_reward = 
ori_reward.clone().detach() 34 | hpc_done = ori_done.clone().detach() 35 | hpc_r_q = ori_r_q.clone().detach() 36 | hpc_weight = ori_weight.clone().detach() 37 | hpc_value_gamma = ori_value_gamma.clone().detach() 38 | hpc_iqn = IQNNStepTDError(tau, tauPrime, T, B, N) 39 | 40 | if use_cuda: 41 | ori_q = ori_q.cuda() 42 | ori_next_n_q = ori_next_n_q.cuda() 43 | ori_action = ori_action.cuda() 44 | ori_next_n_action = ori_next_n_action.cuda() 45 | ori_reward = ori_reward.cuda() 46 | ori_done = ori_done.cuda() 47 | ori_r_q = ori_r_q.cuda() 48 | ori_weight = ori_weight.cuda() 49 | ori_value_gamma = ori_value_gamma.cuda() 50 | 51 | hpc_q = hpc_q.cuda() 52 | hpc_next_n_q = hpc_next_n_q.cuda() 53 | hpc_action = hpc_action.cuda() 54 | hpc_next_n_action = hpc_next_n_action.cuda() 55 | hpc_reward = hpc_reward.cuda() 56 | hpc_done = hpc_done.cuda() 57 | hpc_r_q = hpc_r_q.cuda() 58 | hpc_weight = hpc_weight.cuda() 59 | hpc_value_gamma = hpc_value_gamma.cuda() 60 | hpc_iqn = hpc_iqn.cuda() 61 | 62 | ori_q.requires_grad_(True) 63 | ori_loss, ori_ = iqn_nstep_td_error(iqn_nstep_td_data(ori_q, ori_next_n_q, ori_action, ori_next_n_action, ori_reward, ori_done, ori_r_q, ori_weight), gamma, T, kappa, ori_value_gamma) 64 | ori_loss = ori_loss.mean() 65 | ori_loss.backward() 66 | if use_cuda: 67 | torch.cuda.synchronize() 68 | 69 | torch.cuda.cudart().cudaProfilerStart() 70 | hpc_q.requires_grad_(True) 71 | hpc_loss, hpc_ = hpc_iqn(hpc_q, hpc_next_n_q, hpc_action, hpc_next_n_action, hpc_reward, hpc_done, hpc_r_q, gamma, kappa, hpc_weight, hpc_value_gamma) 72 | hpc_loss = hpc_loss.mean() 73 | hpc_loss.backward() 74 | if use_cuda: 75 | torch.cuda.synchronize() 76 | torch.cuda.cudart().cudaProfilerStop() 77 | 78 | mre = mean_relative_error(torch.flatten(ori_loss).cpu().detach().numpy(), torch.flatten(hpc_loss).cpu().detach().numpy()) 79 | print("iqn fp mean_relative_error: " + str(mre)) 80 | mre = mean_relative_error(torch.flatten(ori_q.grad).cpu().detach().numpy(), torch.flatten(hpc_q.grad).cpu().detach().numpy()) 81 | print("iqn bp mean_relative_error: " + str(mre)) 82 | 83 | def iqn_perf(): 84 | ori_q = torch.randn(tau, B, N) 85 | ori_next_n_q = torch.randn(tauPrime, B, N) 86 | ori_action = torch.randint(0, N, size=(B, )) 87 | ori_next_n_action = torch.randint(0, N, size=(B, )) 88 | ori_reward = torch.randn(T, B) 89 | ori_done = torch.randn(B) 90 | ori_r_q = torch.randn(tau, B) 91 | ori_weight = torch.randn(B) 92 | ori_value_gamma = torch.randn(B) 93 | 94 | hpc_q = ori_q.clone().detach() 95 | hpc_next_n_q = ori_next_n_q.clone().detach() 96 | hpc_action = ori_action.clone().detach() 97 | hpc_next_n_action = ori_next_n_action.clone().detach() 98 | hpc_reward = ori_reward.clone().detach() 99 | hpc_done = ori_done.clone().detach() 100 | hpc_r_q = ori_r_q.clone().detach() 101 | hpc_weight = ori_weight.clone().detach() 102 | hpc_value_gamma = ori_value_gamma.clone().detach() 103 | hpc_iqn = IQNNStepTDError(tau, tauPrime, T, B, N) 104 | 105 | if use_cuda: 106 | ori_q = ori_q.cuda() 107 | ori_next_n_q = ori_next_n_q.cuda() 108 | ori_action = ori_action.cuda() 109 | ori_next_n_action = ori_next_n_action.cuda() 110 | ori_reward = ori_reward.cuda() 111 | ori_done = ori_done.cuda() 112 | ori_r_q = ori_r_q.cuda() 113 | ori_weight = ori_weight.cuda() 114 | ori_value_gamma = ori_value_gamma.cuda() 115 | 116 | hpc_q = hpc_q.cuda() 117 | hpc_next_n_q = hpc_next_n_q.cuda() 118 | hpc_action = hpc_action.cuda() 119 | hpc_next_n_action = hpc_next_n_action.cuda() 120 | hpc_reward = hpc_reward.cuda() 121 | hpc_done = hpc_done.cuda() 
122 | hpc_r_q = hpc_r_q.cuda() 123 | hpc_weight = hpc_weight.cuda() 124 | hpc_iqn = hpc_iqn.cuda() 125 | hpc_value_gamma = hpc_value_gamma.cuda() 126 | 127 | ori_q.requires_grad_(True) 128 | for i in range(times): 129 | t = time.time() 130 | ori_loss, ori_ = iqn_nstep_td_error(iqn_nstep_td_data(ori_q, ori_next_n_q, ori_action, ori_next_n_action, ori_reward, ori_done, ori_r_q, ori_weight), gamma, T, kappa, ori_value_gamma) 131 | ori_loss = ori_loss.mean() 132 | ori_loss.backward() 133 | if use_cuda: 134 | torch.cuda.synchronize() 135 | print('epoch: {}, original iqn cost time: {}'.format(i, time.time() - t)) 136 | 137 | #torch.cuda.cudart().cudaProfilerStart() 138 | hpc_q.requires_grad_(True) 139 | for i in range(times): 140 | t = time.time() 141 | hpc_loss, hpc_ = hpc_iqn(hpc_q, hpc_next_n_q, hpc_action, hpc_next_n_action, hpc_reward, hpc_done, hpc_r_q, gamma, kappa, hpc_weight, hpc_value_gamma) 142 | hpc_loss = hpc_loss.mean() 143 | hpc_loss.backward() 144 | if use_cuda: 145 | torch.cuda.synchronize() 146 | print('epoch: {}, hpc iqn cost time: {}'.format(i, time.time() - t)) 147 | #torch.cuda.cudart().cudaProfilerStop() 148 | 149 | mre = mean_relative_error(torch.flatten(ori_loss).cpu().detach().numpy(), torch.flatten(hpc_loss).cpu().detach().numpy()) 150 | print("iqn fp mean_relative_error: " + str(mre)) 151 | mre = mean_relative_error(torch.flatten(ori_q.grad).cpu().detach().numpy(), torch.flatten(hpc_q.grad).cpu().detach().numpy()) 152 | print("iqn bp mean_relative_error: " + str(mre)) 153 | 154 | if __name__ == '__main__': 155 | print("target problem: tau = {}, tauPrime = {}, T = {}, B = {}, N = {}, gamma = {}, kappa = {}".format(tau, tauPrime, T, B, N, gamma, kappa)) 156 | print("================run iqn validation test================") 157 | iqn_val() 158 | print("================run iqn performance test================") 159 | iqn_perf() 160 | -------------------------------------------------------------------------------- /tests/test_lstm.py: -------------------------------------------------------------------------------- 1 | import time 2 | import torch 3 | from hpc_rll.origin.rnn import get_lstm 4 | from hpc_rll.torch_utils.network.rnn import LSTM 5 | from testbase import mean_relative_error, times 6 | 7 | assert torch.cuda.is_available() 8 | use_cuda = True 9 | 10 | seq_len = 64 11 | batch_size = 3 12 | input_size = 1792 13 | hidden_size = 384 14 | num_layers = 3 15 | norm_type = 'LN' 16 | dropout = 0#0.1 17 | 18 | # Note: need open load_params for hpc_lstm to validation 19 | # Note: only used to case of num_layers = 3 20 | def lstm_val(): 21 | ori_lstm = get_lstm('normal', input_size, hidden_size, num_layers, norm_type, dropout) 22 | hpc_lstm = LSTM(seq_len, batch_size, input_size, hidden_size, num_layers, norm_type, dropout) 23 | 24 | ori_x = torch.randn(seq_len, batch_size, input_size) 25 | ori_h0 = torch.randn(num_layers, batch_size, hidden_size) 26 | ori_c0 = torch.randn(num_layers, batch_size, hidden_size) 27 | 28 | if use_cuda: 29 | ori_x = ori_x.cuda() 30 | ori_h0 = ori_h0.cuda() 31 | ori_c0 = ori_c0.cuda() 32 | ori_lstm = ori_lstm.cuda() 33 | hpc_lstm = hpc_lstm.cuda() 34 | 35 | ori_x.requires_grad_(True) 36 | ori_output, ori_next_state = ori_lstm(ori_x, [ori_h0, ori_c0]) 37 | ori_loss = ori_output.mean() 38 | ori_loss.backward() 39 | 40 | hpc_x = ori_x.clone().detach() 41 | hpc_h0 = ori_h0.clone().detach() 42 | hpc_c0 = ori_c0.clone().detach() 43 | hpc_x.requires_grad_(True) 44 | hpc_output, hpc_next_state = hpc_lstm(hpc_x, [hpc_h0, hpc_c0]) 45 | hpc_loss = 
hpc_output.mean() 46 | hpc_loss.backward() 47 | torch.cuda.synchronize() 48 | 49 | mre = mean_relative_error(torch.flatten(ori_loss).cpu().detach().numpy(), torch.flatten(hpc_loss).cpu().detach().numpy()) 50 | print("lstm fp mean_relative_error: " + str(mre)) 51 | mre = mean_relative_error(torch.flatten(ori_x.grad).cpu().detach().numpy(), torch.flatten(hpc_x.grad).cpu().detach().numpy()) 52 | print("lstm bp mean_relative_error: " + str(mre)) 53 | 54 | ori_wx_grad = torch.cat((ori_lstm.wx[0].grad, ori_lstm.wx[1].grad, ori_lstm.wx[2].grad)) 55 | hpc_wx_grad = hpc_lstm.wx.grad 56 | mre = mean_relative_error(torch.flatten(ori_wx_grad).cpu().numpy(), torch.flatten(hpc_wx_grad).cpu().numpy()) 57 | print("wx grad mean_relative_error: " + str(mre)) 58 | 59 | ori_wh_grad = torch.cat((ori_lstm.wh[0].grad, ori_lstm.wh[1].grad, ori_lstm.wh[2].grad)) 60 | hpc_wh_grad = hpc_lstm.wh.grad 61 | mre = mean_relative_error(torch.flatten(ori_wh_grad).cpu().numpy(), torch.flatten(hpc_wh_grad).cpu().numpy()) 62 | print("wh grad mean_relative_error: " + str(mre)) 63 | 64 | ori_bias_grad = ori_lstm.bias.grad 65 | hpc_bias_grad = hpc_lstm.bias.grad 66 | mre = mean_relative_error(torch.flatten(ori_bias_grad).cpu().numpy(), torch.flatten(hpc_bias_grad).cpu().numpy()) 67 | print("bias grad mean_relative_error: " + str(mre)) 68 | 69 | params = list(ori_lstm.parameters()) 70 | gamma_0_x = params[1] 71 | beta_0_x = params[2] 72 | gamma_0_h = params[3] 73 | beta_0_h = params[4] 74 | gamma_1_x = params[5] 75 | beta_1_x = params[6] 76 | gamma_1_h = params[7] 77 | beta_1_h = params[8] 78 | gamma_2_x = params[9] 79 | beta_2_x = params[10] 80 | gamma_2_h = params[11] 81 | beta_2_h = params[12] 82 | ori_gamma_grad = torch.cat((gamma_0_x.grad, gamma_0_h.grad, gamma_1_x.grad, gamma_1_h.grad, gamma_2_x.grad, gamma_2_h.grad)) 83 | ori_beta_grad = torch.cat((beta_0_x.grad, beta_0_h.grad, beta_1_x.grad, beta_1_h.grad, beta_2_x.grad, beta_2_h.grad)) 84 | hpc_gamma_grad = hpc_lstm.ln_gamma.grad 85 | hpc_beta_grad = hpc_lstm.ln_beta.grad 86 | mre = mean_relative_error(torch.flatten(ori_gamma_grad).cpu().numpy(), torch.flatten(hpc_gamma_grad).cpu().numpy()) 87 | print("ln gamma grad mean_relative_error: " + str(mre)) 88 | mre = mean_relative_error(torch.flatten(ori_beta_grad).cpu().numpy(), torch.flatten(hpc_beta_grad).cpu().numpy()) 89 | print("ln beta grad mean_relative_error: " + str(mre)) 90 | 91 | def lstm_perf(): 92 | ori_lstm = get_lstm('normal', input_size, hidden_size, num_layers, norm_type, dropout) 93 | hpc_lstm = LSTM(seq_len, batch_size, input_size, hidden_size, num_layers, norm_type, dropout) 94 | 95 | lstms = {'normal': ori_lstm, 'hpc': hpc_lstm} 96 | 97 | for lstm_type, lstm in lstms.items(): 98 | x = torch.rand(seq_len, batch_size, input_size) 99 | h0 = torch.randn(num_layers, batch_size, hidden_size) 100 | c0 = torch.randn(num_layers, batch_size, hidden_size) 101 | if use_cuda: 102 | x = x.cuda() 103 | h0 = h0.cuda() 104 | c0 = c0.cuda() 105 | lstm = lstm.cuda() 106 | 107 | prev_state = [h0, c0] 108 | x.requires_grad_(True) 109 | for i in range(times): 110 | t = time.time() 111 | output, _ = lstm(x, prev_state) 112 | loss = output.mean() 113 | loss.backward() 114 | if use_cuda: 115 | torch.cuda.synchronize() 116 | print('epoch: {}, {} lstm cost time: {}'.format(i, lstm_type, time.time() - t)) 117 | 118 | if __name__ == '__main__': 119 | print("target problem: seq_len = {}, batch_size = {}, input_size = {}, hidden_size = {}, num_layers = {}, norm_type = {}, dropout = {}".format( 120 | seq_len, batch_size, input_size, 
hidden_size, num_layers, norm_type, dropout)) 121 | print("==============lstm has no validation test================") 122 | #print("===============run lstm validation test==================") 123 | #lstm_val() 124 | print("===============run lstm performance test=================") 125 | lstm_perf() 126 | -------------------------------------------------------------------------------- /tests/test_ppo.py: -------------------------------------------------------------------------------- 1 | import time 2 | import torch 3 | import torch.nn.functional as F 4 | from hpc_rll.origin.ppo import ppo_error, ppo_data 5 | from hpc_rll.rl_utils.ppo import PPO 6 | from testbase import mean_relative_error, times 7 | 8 | assert torch.cuda.is_available() 9 | use_cuda = True 10 | 11 | B = 128 12 | N = 128 13 | clip_ratio = 0.2 14 | use_value_clip = True 15 | dual_clip = None 16 | 17 | def ppo_val(): 18 | ori_logits_new = torch.randn(B, N) 19 | ori_logits_old = torch.randn(B, N) 20 | ori_action = torch.randint(0, N, size=(B, )) 21 | ori_value_new = torch.randn(B) 22 | ori_value_old = torch.randn(B) 23 | ori_adv = torch.randn(B) 24 | ori_return = torch.randn(B) 25 | ori_weight = torch.randn(B) 26 | 27 | hpc_logits_new = ori_logits_new.clone().detach() 28 | hpc_logits_old = ori_logits_old.clone().detach() 29 | hpc_action = ori_action.clone().detach() 30 | hpc_value_new = ori_value_new.clone().detach() 31 | hpc_value_old = ori_value_old.clone().detach() 32 | hpc_adv = ori_adv.clone().detach() 33 | hpc_return = ori_return.clone().detach() 34 | hpc_weight = ori_weight.clone().detach() 35 | hpc_ppo = PPO(B, N) 36 | 37 | if use_cuda: 38 | ori_logits_new = ori_logits_new.cuda() 39 | ori_logits_old = ori_logits_old.cuda() 40 | ori_action = ori_action.cuda() 41 | ori_value_new = ori_value_new.cuda() 42 | ori_value_old = ori_value_old.cuda() 43 | ori_adv = ori_adv.cuda() 44 | ori_return = ori_return.cuda() 45 | ori_weight = ori_weight.cuda() 46 | 47 | hpc_logits_new = hpc_logits_new.cuda() 48 | hpc_logits_old = hpc_logits_old.cuda() 49 | hpc_action = hpc_action.cuda() 50 | hpc_value_new = hpc_value_new.cuda() 51 | hpc_value_old = hpc_value_old.cuda() 52 | hpc_adv = hpc_adv.cuda() 53 | hpc_return = hpc_return.cuda() 54 | hpc_weight = hpc_weight.cuda() 55 | hpc_ppo = hpc_ppo.cuda() 56 | 57 | ori_logits_new.requires_grad_(True) 58 | ori_value_new.requires_grad_(True) 59 | ori_loss, ori_info = ppo_error(ppo_data(ori_logits_new, ori_logits_old, ori_action, ori_value_new, ori_value_old, ori_adv, ori_return, ori_weight), clip_ratio, use_value_clip, dual_clip) 60 | ori_loss = sum(ori_loss) 61 | ori_loss.backward() 62 | 63 | hpc_logits_new.requires_grad_(True) 64 | hpc_value_new.requires_grad_(True) 65 | hpc_loss, hpc_info = hpc_ppo(hpc_logits_new, hpc_logits_old, hpc_action, hpc_value_new, hpc_value_old, hpc_adv, hpc_return, hpc_weight, clip_ratio, use_value_clip, dual_clip) 66 | hpc_loss = sum(hpc_loss) 67 | hpc_loss.backward() 68 | 69 | print("ori_info: " + str(ori_info)) 70 | print("hpc_info: " + str(hpc_info)) 71 | mre = mean_relative_error(torch.flatten(ori_loss).cpu().detach().numpy(), torch.flatten(hpc_loss).cpu().detach().numpy()) 72 | print("ppo fp loss mean_relative_error: " + str(mre)) 73 | mre = mean_relative_error(torch.flatten(ori_logits_new.grad).cpu().detach().numpy(), torch.flatten(hpc_logits_new.grad).cpu().detach().numpy()) 74 | print("ppo bp logits_new mean_relative_error: " + str(mre)) 75 | mre = mean_relative_error(torch.flatten(ori_value_new.grad).cpu().detach().numpy(), 
torch.flatten(hpc_value_new.grad).cpu().detach().numpy()) 76 | print("ppo bp value_new mean_relative_error: " + str(mre)) 77 | 78 | 79 | def ppo_perf(): 80 | ori_logits_new = torch.randn(B, N) 81 | ori_logits_old = torch.randn(B, N) 82 | ori_action = torch.randint(0, N, size=(B, )) 83 | ori_value_new = torch.randn(B) 84 | ori_value_old = torch.randn(B) 85 | ori_adv = torch.randn(B) 86 | ori_return = torch.randn(B) 87 | ori_weight = torch.randn(B) 88 | 89 | hpc_logits_new = ori_logits_new.clone().detach() 90 | hpc_logits_old = ori_logits_old.clone().detach() 91 | hpc_action = ori_action.clone().detach() 92 | hpc_value_new = ori_value_new.clone().detach() 93 | hpc_value_old = ori_value_old.clone().detach() 94 | hpc_adv = ori_adv.clone().detach() 95 | hpc_return = ori_return.clone().detach() 96 | hpc_weight = ori_weight.clone().detach() 97 | hpc_ppo = PPO(B, N) 98 | 99 | if use_cuda: 100 | ori_logits_new = ori_logits_new.cuda() 101 | ori_logits_old = ori_logits_old.cuda() 102 | ori_action = ori_action.cuda() 103 | ori_value_new = ori_value_new.cuda() 104 | ori_value_old = ori_value_old.cuda() 105 | ori_adv = ori_adv.cuda() 106 | ori_return = ori_return.cuda() 107 | ori_weight = ori_weight.cuda() 108 | 109 | hpc_logits_new = hpc_logits_new.cuda() 110 | hpc_logits_old = hpc_logits_old.cuda() 111 | hpc_action = hpc_action.cuda() 112 | hpc_value_new = hpc_value_new.cuda() 113 | hpc_value_old = hpc_value_old.cuda() 114 | hpc_adv = hpc_adv.cuda() 115 | hpc_return = hpc_return.cuda() 116 | hpc_weight = hpc_weight.cuda() 117 | hpc_ppo = hpc_ppo.cuda() 118 | 119 | ori_logits_new.requires_grad_(True) 120 | ori_value_new.requires_grad_(True) 121 | for i in range(times): 122 | t = time.time() 123 | ori_loss, ori_info = ppo_error(ppo_data(ori_logits_new, ori_logits_old, ori_action, ori_value_new, ori_value_old, ori_adv, ori_return, ori_weight), clip_ratio, use_value_clip, dual_clip) 124 | ori_loss = sum(ori_loss) 125 | ori_loss.backward() 126 | if use_cuda: 127 | torch.cuda.synchronize() 128 | print('epoch: {}, origin ppo cost time: {}'.format(i, time.time() - t)) 129 | 130 | hpc_logits_new.requires_grad_(True) 131 | hpc_value_new.requires_grad_(True) 132 | for i in range(times): 133 | t = time.time() 134 | hpc_loss, hpc_info = hpc_ppo(hpc_logits_new, hpc_logits_old, hpc_action, hpc_value_new, hpc_value_old, hpc_adv, hpc_return, hpc_weight, clip_ratio, use_value_clip, dual_clip) 135 | hpc_loss = sum(hpc_loss) 136 | hpc_loss.backward() 137 | if use_cuda: 138 | torch.cuda.synchronize() 139 | print('epoch: {}, hpc ppo cost time: {}'.format(i, time.time() - t)) 140 | 141 | 142 | if __name__ == '__main__': 143 | print("target problem: B = {}, N = {}, clip_ratio = {}, use_value_clip = {}, dual_clip = {}".format(B, N, clip_ratio, use_value_clip, dual_clip)) 144 | print("================run ppo validation test================") 145 | ppo_val() 146 | print("================run ppo performance test================") 147 | ppo_perf() 148 | -------------------------------------------------------------------------------- /tests/test_qntd.py: -------------------------------------------------------------------------------- 1 | import time 2 | import torch 3 | from hpc_rll.origin.td import q_nstep_td_error, q_nstep_td_data 4 | from hpc_rll.rl_utils.td import QNStepTD 5 | from testbase import mean_relative_error, times 6 | 7 | assert torch.cuda.is_available() 8 | use_cuda = True 9 | 10 | T = 1024 11 | B = 64 12 | N = 64 13 | gamma = 0.95 14 | 15 | def qntd_val(): 16 | ori_q = torch.randn(B, N) 17 | ori_next_n_q = 
torch.randn(B, N) 18 | ori_action = torch.randint(0, N, size=(B, )) 19 | ori_next_n_action = torch.randint(0, N, size=(B, )) 20 | ori_reward = torch.randn(T, B) 21 | ori_done = torch.randn(B) 22 | ori_weight = torch.randn(B) 23 | 24 | hpc_q = ori_q.clone().detach() 25 | hpc_next_n_q = ori_next_n_q.clone().detach() 26 | hpc_action = ori_action.clone().detach() 27 | hpc_next_n_action = ori_next_n_action.clone().detach() 28 | hpc_reward = ori_reward.clone().detach() 29 | hpc_done = ori_done.clone().detach() 30 | hpc_weight = ori_weight.clone().detach() 31 | hpc_qntd = QNStepTD(T, B, N) 32 | 33 | if use_cuda: 34 | ori_q = ori_q.cuda() 35 | ori_next_n_q = ori_next_n_q.cuda() 36 | ori_action = ori_action.cuda() 37 | ori_next_n_action = ori_next_n_action.cuda() 38 | ori_reward = ori_reward.cuda() 39 | ori_done = ori_done.cuda() 40 | ori_weight = ori_weight.cuda() 41 | 42 | hpc_q = hpc_q.cuda() 43 | hpc_next_n_q = hpc_next_n_q.cuda() 44 | hpc_action = hpc_action.cuda() 45 | hpc_next_n_action = hpc_next_n_action.cuda() 46 | hpc_reward = hpc_reward.cuda() 47 | hpc_done = hpc_done.cuda() 48 | hpc_weight = hpc_weight.cuda() 49 | hpc_qntd = hpc_qntd.cuda() 50 | 51 | ori_q.requires_grad_(True) 52 | ori_loss, _ = q_nstep_td_error(q_nstep_td_data(ori_q, ori_next_n_q, ori_action, ori_next_n_action, ori_reward, ori_done, ori_weight), gamma, T) 53 | ori_loss = ori_loss.mean() 54 | ori_loss.backward() 55 | if use_cuda: 56 | torch.cuda.synchronize() 57 | 58 | hpc_q.requires_grad_(True) 59 | hpc_loss, _ = hpc_qntd(hpc_q, hpc_next_n_q, hpc_action, hpc_next_n_action, hpc_reward, hpc_done, hpc_weight, gamma) 60 | hpc_loss = hpc_loss.mean() 61 | hpc_loss.backward() 62 | if use_cuda: 63 | torch.cuda.synchronize() 64 | 65 | mre = mean_relative_error(torch.flatten(ori_loss).cpu().detach().numpy(), torch.flatten(hpc_loss).cpu().detach().numpy()) 66 | print("qntd fp mean_relative_error: " + str(mre)) 67 | mre = mean_relative_error(torch.flatten(ori_q.grad).cpu().detach().numpy(), torch.flatten(hpc_q.grad).cpu().detach().numpy()) 68 | print("qntd bp mean_relative_error: " + str(mre)) 69 | 70 | def qntd_perf(): 71 | ori_q = torch.randn(B, N) 72 | ori_next_n_q = torch.randn(B, N) 73 | ori_action = torch.randint(0, N, size=(B, )) 74 | ori_next_n_action = torch.randint(0, N, size=(B, )) 75 | ori_reward = torch.randn(T, B) 76 | ori_done = torch.randn(B) 77 | ori_weight = torch.randn(B) 78 | 79 | hpc_q = ori_q.clone().detach() 80 | hpc_next_n_q = ori_next_n_q.clone().detach() 81 | hpc_action = ori_action.clone().detach() 82 | hpc_next_n_action = ori_next_n_action.clone().detach() 83 | hpc_reward = ori_reward.clone().detach() 84 | hpc_done = ori_done.clone().detach() 85 | hpc_weight = ori_weight.clone().detach() 86 | hpc_qntd = QNStepTD(T, B, N) 87 | 88 | if use_cuda: 89 | ori_q = ori_q.cuda() 90 | ori_next_n_q = ori_next_n_q.cuda() 91 | ori_action = ori_action.cuda() 92 | ori_next_n_action = ori_next_n_action.cuda() 93 | ori_reward = ori_reward.cuda() 94 | ori_done = ori_done.cuda() 95 | ori_weight = ori_weight.cuda() 96 | 97 | hpc_q = hpc_q.cuda() 98 | hpc_next_n_q = hpc_next_n_q.cuda() 99 | hpc_action = hpc_action.cuda() 100 | hpc_next_n_action = hpc_next_n_action.cuda() 101 | hpc_reward = hpc_reward.cuda() 102 | hpc_done = hpc_done.cuda() 103 | hpc_weight = hpc_weight.cuda() 104 | hpc_qntd = hpc_qntd.cuda() 105 | 106 | ori_q.requires_grad_(True) 107 | for i in range(times): 108 | t = time.time() 109 | ori_loss, _ = q_nstep_td_error(q_nstep_td_data(ori_q, ori_next_n_q, ori_action, ori_next_n_action, ori_reward, ori_done, 
ori_weight), gamma, T) 110 | ori_loss = ori_loss.mean() 111 | ori_loss.backward() 112 | if use_cuda: 113 | torch.cuda.synchronize() 114 | print('epoch: {}, original qntd cost time: {}'.format(i, time.time() - t)) 115 | 116 | hpc_q.requires_grad_(True) 117 | for i in range(times): 118 | t = time.time() 119 | hpc_loss, _ = hpc_qntd(hpc_q, hpc_next_n_q, hpc_action, hpc_next_n_action, hpc_reward, hpc_done, hpc_weight, gamma) 120 | hpc_loss = hpc_loss.mean() 121 | hpc_loss.backward() 122 | if use_cuda: 123 | torch.cuda.synchronize() 124 | print('epoch: {}, hpc qntd cost time: {}'.format(i, time.time() - t)) 125 | 126 | 127 | if __name__ == '__main__': 128 | print("target problem: T = {}, B = {}, N = {}, gamma = {}".format(T, B, N, gamma)) 129 | print("================run qntd validation test================") 130 | qntd_val() 131 | print("================run qntd performance test================") 132 | qntd_perf() 133 | -------------------------------------------------------------------------------- /tests/test_qntd_rescale.py: -------------------------------------------------------------------------------- 1 | import time 2 | import torch 3 | from hpc_rll.origin.td import q_nstep_td_error_with_rescale, q_nstep_td_data 4 | from hpc_rll.rl_utils.td import QNStepTDRescale 5 | from testbase import mean_relative_error, times 6 | 7 | assert torch.cuda.is_available() 8 | use_cuda = True 9 | 10 | T = 1024 11 | B = 64 12 | N = 64 13 | gamma = 0.95 14 | 15 | def qntd_rescale_val(): 16 | ori_q = torch.randn(B, N) 17 | ori_next_n_q = torch.randn(B, N) 18 | ori_action = torch.randint(0, N, size=(B, )) 19 | ori_next_n_action = torch.randint(0, N, size=(B, )) 20 | ori_reward = torch.randn(T, B) 21 | ori_done = torch.randn(B) 22 | ori_weight = torch.randn(B) 23 | 24 | hpc_q = ori_q.clone().detach() 25 | hpc_next_n_q = ori_next_n_q.clone().detach() 26 | hpc_action = ori_action.clone().detach() 27 | hpc_next_n_action = ori_next_n_action.clone().detach() 28 | hpc_reward = ori_reward.clone().detach() 29 | hpc_done = ori_done.clone().detach() 30 | hpc_weight = ori_weight.clone().detach() 31 | hpc_qntd_rescale = QNStepTDRescale(T, B, N) 32 | 33 | if use_cuda: 34 | ori_q = ori_q.cuda() 35 | ori_next_n_q = ori_next_n_q.cuda() 36 | ori_action = ori_action.cuda() 37 | ori_next_n_action = ori_next_n_action.cuda() 38 | ori_reward = ori_reward.cuda() 39 | ori_done = ori_done.cuda() 40 | ori_weight = ori_weight.cuda() 41 | 42 | hpc_q = hpc_q.cuda() 43 | hpc_next_n_q = hpc_next_n_q.cuda() 44 | hpc_action = hpc_action.cuda() 45 | hpc_next_n_action = hpc_next_n_action.cuda() 46 | hpc_reward = hpc_reward.cuda() 47 | hpc_done = hpc_done.cuda() 48 | hpc_weight = hpc_weight.cuda() 49 | hpc_qntd_rescale = hpc_qntd_rescale.cuda() 50 | 51 | ori_q.requires_grad_(True) 52 | ori_loss, _ = q_nstep_td_error_with_rescale( 53 | q_nstep_td_data(ori_q, ori_next_n_q, ori_action, ori_next_n_action, ori_reward, ori_done, ori_weight), gamma, T) 54 | ori_loss = ori_loss.mean() 55 | ori_loss.backward() 56 | if use_cuda: 57 | torch.cuda.synchronize() 58 | 59 | hpc_q.requires_grad_(True) 60 | hpc_loss, _ = hpc_qntd_rescale(hpc_q, hpc_next_n_q, hpc_action, hpc_next_n_action, hpc_reward, hpc_done, hpc_weight, gamma) 61 | hpc_loss = hpc_loss.mean() 62 | hpc_loss.backward() 63 | if use_cuda: 64 | torch.cuda.synchronize() 65 | 66 | mre = mean_relative_error(torch.flatten(ori_loss).cpu().detach().numpy(), torch.flatten(hpc_loss).cpu().detach().numpy()) 67 | print("qntd rescale fp mean_relative_error: " + str(mre)) 68 | mre = 
mean_relative_error(torch.flatten(ori_q.grad).cpu().detach().numpy(), torch.flatten(hpc_q.grad).cpu().detach().numpy()) 69 | print("qntd rescale bp mean_relative_error: " + str(mre)) 70 | 71 | def qntd_rescale_perf(): 72 | ori_q = torch.randn(B, N) 73 | ori_next_n_q = torch.randn(B, N) 74 | ori_action = torch.randint(0, N, size=(B, )) 75 | ori_next_n_action = torch.randint(0, N, size=(B, )) 76 | ori_reward = torch.randn(T, B) 77 | ori_done = torch.randn(B) 78 | ori_weight = torch.randn(B) 79 | 80 | hpc_q = ori_q.clone().detach() 81 | hpc_next_n_q = ori_next_n_q.clone().detach() 82 | hpc_action = ori_action.clone().detach() 83 | hpc_next_n_action = ori_next_n_action.clone().detach() 84 | hpc_reward = ori_reward.clone().detach() 85 | hpc_done = ori_done.clone().detach() 86 | hpc_weight = ori_weight.clone().detach() 87 | hpc_qntd_rescale = QNStepTDRescale(T, B, N) 88 | 89 | if use_cuda: 90 | ori_q = ori_q.cuda() 91 | ori_next_n_q = ori_next_n_q.cuda() 92 | ori_action = ori_action.cuda() 93 | ori_next_n_action = ori_next_n_action.cuda() 94 | ori_reward = ori_reward.cuda() 95 | ori_done = ori_done.cuda() 96 | ori_weight = ori_weight.cuda() 97 | 98 | hpc_q = hpc_q.cuda() 99 | hpc_next_n_q = hpc_next_n_q.cuda() 100 | hpc_action = hpc_action.cuda() 101 | hpc_next_n_action = hpc_next_n_action.cuda() 102 | hpc_reward = hpc_reward.cuda() 103 | hpc_done = hpc_done.cuda() 104 | hpc_weight = hpc_weight.cuda() 105 | hpc_qntd_rescale = hpc_qntd_rescale.cuda() 106 | 107 | ori_q.requires_grad_(True) 108 | for i in range(times): 109 | t = time.time() 110 | ori_loss, _ = q_nstep_td_error_with_rescale( 111 | q_nstep_td_data(ori_q, ori_next_n_q, ori_action, ori_next_n_action, ori_reward, ori_done, ori_weight), gamma, T) 112 | ori_loss = ori_loss.mean() 113 | ori_loss.backward() 114 | if use_cuda: 115 | torch.cuda.synchronize() 116 | print('epoch: {}, original qntd rescale cost time: {}'.format(i, time.time() - t)) 117 | 118 | hpc_q.requires_grad_(True) 119 | for i in range(times): 120 | t = time.time() 121 | hpc_loss, _ = hpc_qntd_rescale(hpc_q, hpc_next_n_q, hpc_action, hpc_next_n_action, hpc_reward, hpc_done, hpc_weight, gamma) 122 | hpc_loss = hpc_loss.mean() 123 | hpc_loss.backward() 124 | if use_cuda: 125 | torch.cuda.synchronize() 126 | print('epoch: {}, hpc qntd rescale cost time: {}'.format(i, time.time() - t)) 127 | 128 | 129 | if __name__ == '__main__': 130 | print("target problem: T = {}, B = {}, N = {}, gamma = {}".format(T, B, N, gamma)) 131 | print("================run qntd rescale validation test================") 132 | qntd_rescale_val() 133 | print("================run qntd rescale performance test================") 134 | qntd_rescale_perf() 135 | -------------------------------------------------------------------------------- /tests/test_qrdqn_nstep_td_error.py: -------------------------------------------------------------------------------- 1 | import time 2 | import torch 3 | from hpc_rll.origin.td import qrdqn_nstep_td_error, qrdqn_nstep_td_data 4 | from hpc_rll.rl_utils.td import QRDQNNStepTDError 5 | from testbase import mean_relative_error, times 6 | 7 | assert torch.cuda.is_available() 8 | use_cuda = True 9 | 10 | tau = 39 11 | T = 10 12 | B = 89 13 | N = 67 14 | gamma = 0.95 15 | 16 | def qrdqn_val(): 17 | ori_q = torch.randn(B, N, tau) 18 | ori_next_n_q = torch.randn(B, N, tau) 19 | ori_action = torch.randint(0, N, size=(B, )) 20 | ori_next_n_action = torch.randint(0, N, size=(B, )) 21 | ori_reward = torch.randn(T, B) 22 | ori_done = torch.randn(B) 23 | ori_weight = torch.randn(B) 24 
| ori_value_gamma = torch.randn(B) 25 | 26 | hpc_q = ori_q.clone().detach() 27 | hpc_next_n_q = ori_next_n_q.clone().detach() 28 | hpc_action = ori_action.clone().detach() 29 | hpc_next_n_action = ori_next_n_action.clone().detach() 30 | hpc_reward = ori_reward.clone().detach() 31 | hpc_done = ori_done.clone().detach() 32 | hpc_weight = ori_weight.clone().detach() 33 | hpc_value_gamma = ori_value_gamma.clone().detach() 34 | hpc_qrdqn = QRDQNNStepTDError(tau, T, B, N) 35 | 36 | if use_cuda: 37 | ori_q = ori_q.cuda() 38 | ori_next_n_q = ori_next_n_q.cuda() 39 | ori_action = ori_action.cuda() 40 | ori_next_n_action = ori_next_n_action.cuda() 41 | ori_reward = ori_reward.cuda() 42 | ori_done = ori_done.cuda() 43 | ori_weight = ori_weight.cuda() 44 | ori_value_gamma = ori_value_gamma.cuda() 45 | 46 | hpc_q = hpc_q.cuda() 47 | hpc_next_n_q = hpc_next_n_q.cuda() 48 | hpc_action = hpc_action.cuda() 49 | hpc_next_n_action = hpc_next_n_action.cuda() 50 | hpc_reward = hpc_reward.cuda() 51 | hpc_done = hpc_done.cuda() 52 | hpc_weight = hpc_weight.cuda() 53 | hpc_value_gamma = hpc_value_gamma.cuda() 54 | hpc_qrdqn = hpc_qrdqn.cuda() 55 | 56 | ori_q.requires_grad_(True) 57 | ori_loss, ori_ = qrdqn_nstep_td_error(qrdqn_nstep_td_data(ori_q, ori_next_n_q, ori_action, ori_next_n_action, ori_reward, ori_done, tau, ori_weight), gamma, T, ori_value_gamma) 58 | ori_loss = ori_loss.mean() 59 | ori_loss.backward() 60 | if use_cuda: 61 | torch.cuda.synchronize() 62 | 63 | torch.cuda.cudart().cudaProfilerStart() 64 | hpc_q.requires_grad_(True) 65 | hpc_loss, hpc_ = hpc_qrdqn(hpc_q, hpc_next_n_q, hpc_action, hpc_next_n_action, hpc_reward, hpc_done, gamma, hpc_weight, hpc_value_gamma) 66 | hpc_loss = hpc_loss.mean() 67 | hpc_loss.backward() 68 | if use_cuda: 69 | torch.cuda.synchronize() 70 | torch.cuda.cudart().cudaProfilerStop() 71 | 72 | mre = mean_relative_error(torch.flatten(ori_loss).cpu().detach().numpy(), torch.flatten(hpc_loss).cpu().detach().numpy()) 73 | print("qrdqn fp mean_relative_error: " + str(mre)) 74 | mre = mean_relative_error(torch.flatten(ori_q.grad).cpu().detach().numpy(), torch.flatten(hpc_q.grad).cpu().detach().numpy()) 75 | print("qrdqn bp mean_relative_error: " + str(mre)) 76 | 77 | def qrdqn_perf(): 78 | ori_q = torch.randn(B, N, tau) 79 | ori_next_n_q = torch.randn(B, N, tau) 80 | ori_action = torch.randint(0, N, size=(B, )) 81 | ori_next_n_action = torch.randint(0, N, size=(B, )) 82 | ori_reward = torch.randn(T, B) 83 | ori_done = torch.randn(B) 84 | ori_weight = torch.randn(B) 85 | ori_value_gamma = torch.randn(B) 86 | 87 | hpc_q = ori_q.clone().detach() 88 | hpc_next_n_q = ori_next_n_q.clone().detach() 89 | hpc_action = ori_action.clone().detach() 90 | hpc_next_n_action = ori_next_n_action.clone().detach() 91 | hpc_reward = ori_reward.clone().detach() 92 | hpc_done = ori_done.clone().detach() 93 | hpc_weight = ori_weight.clone().detach() 94 | hpc_value_gamma = ori_value_gamma.clone().detach() 95 | hpc_qrdqn = QRDQNNStepTDError(tau, T, B, N) 96 | 97 | if use_cuda: 98 | ori_q = ori_q.cuda() 99 | ori_next_n_q = ori_next_n_q.cuda() 100 | ori_action = ori_action.cuda() 101 | ori_next_n_action = ori_next_n_action.cuda() 102 | ori_reward = ori_reward.cuda() 103 | ori_done = ori_done.cuda() 104 | ori_weight = ori_weight.cuda() 105 | ori_value_gamma = ori_value_gamma.cuda() 106 | 107 | hpc_q = hpc_q.cuda() 108 | hpc_next_n_q = hpc_next_n_q.cuda() 109 | hpc_action = hpc_action.cuda() 110 | hpc_next_n_action = hpc_next_n_action.cuda() 111 | hpc_reward = hpc_reward.cuda() 112 | hpc_done = 
hpc_done.cuda() 113 | hpc_weight = hpc_weight.cuda() 114 | hpc_value_gamma = hpc_value_gamma.cuda() 115 | hpc_qrdqn = hpc_qrdqn.cuda() 116 | 117 | ori_q.requires_grad_(True) 118 | for i in range(times): 119 | t = time.time() 120 | ori_loss, ori_ = qrdqn_nstep_td_error(qrdqn_nstep_td_data(ori_q, ori_next_n_q, ori_action, ori_next_n_action, ori_reward, ori_done, tau, ori_weight), gamma, T, ori_value_gamma) 121 | ori_loss = ori_loss.mean() 122 | ori_loss.backward() 123 | if use_cuda: 124 | torch.cuda.synchronize() 125 | print('epoch: {}, original qrdqn cost time: {}'.format(i, time.time() - t)) 126 | 127 | #torch.cuda.cudart().cudaProfilerStart() 128 | hpc_q.requires_grad_(True) 129 | for i in range(times): 130 | t = time.time() 131 | hpc_loss, hpc_ = hpc_qrdqn(hpc_q, hpc_next_n_q, hpc_action, hpc_next_n_action, hpc_reward, hpc_done, gamma, hpc_weight, hpc_value_gamma) 132 | hpc_loss = hpc_loss.mean() 133 | hpc_loss.backward() 134 | if use_cuda: 135 | torch.cuda.synchronize() 136 | print('epoch: {}, hpc qrdqn cost time: {}'.format(i, time.time() - t)) 137 | #torch.cuda.cudart().cudaProfilerStop() 138 | 139 | mre = mean_relative_error(torch.flatten(ori_loss).cpu().detach().numpy(), torch.flatten(hpc_loss).cpu().detach().numpy()) 140 | print("qrdqn fp mean_relative_error: " + str(mre)) 141 | mre = mean_relative_error(torch.flatten(ori_q.grad).cpu().detach().numpy(), torch.flatten(hpc_q.grad).cpu().detach().numpy()) 142 | print("qrdqn bp mean_relative_error: " + str(mre)) 143 | 144 | if __name__ == '__main__': 145 | print("target problem: tau = {}, T = {}, B = {}, N = {}, gamma = {}".format(tau, T, B, N, gamma)) 146 | print("================run qrdqn validation test================") 147 | qrdqn_val() 148 | print("================run qrdqn performance test================") 149 | qrdqn_perf() 150 | -------------------------------------------------------------------------------- /tests/test_scatter.py: -------------------------------------------------------------------------------- 1 | import time 2 | import torch 3 | from typing import Tuple 4 | from hpc_rll.origin.scatter_connection import ScatterConnection 5 | from hpc_rll.torch_utils.network.scatter_connection import ScatterConnection as HPCScatterConnection 6 | from testbase import mean_relative_error, times 7 | 8 | assert torch.cuda.is_available() 9 | use_cuda = True 10 | 11 | B = 256 12 | M = 256 13 | N = 256 14 | H = 16 15 | W = 16 16 | 17 | # Note: the origin GPU version of cover mode is not deterministic, so the validation test uses the origin CPU version instead 18 | def scatter_val(): 19 | for scatter_type in ['add', 'cover']: 20 | ori_input = torch.randn(B, M, N) 21 | h = torch.randint(low=0, high=H, size=(B, M, )).unsqueeze(dim=2) 22 | w = torch.randint(low=0, high=W, size=(B, M, )).unsqueeze(dim=2) 23 | ori_location = torch.cat([h, w], dim=2) 24 | ori_scatter = ScatterConnection(scatter_type) 25 | 26 | hpc_input = ori_input.clone().detach() 27 | hpc_location = ori_location.clone().detach() 28 | hpc_scatter = HPCScatterConnection(B, M, N, H, W, scatter_type) 29 | 30 | if use_cuda: 31 | #ori_input = ori_input.cuda() 32 | #ori_location = ori_location.cuda() 33 | #ori_scatter = ori_scatter.cuda() 34 | 35 | hpc_input = hpc_input.cuda() 36 | hpc_location = hpc_location.cuda() 37 | hpc_scatter = hpc_scatter.cuda() 38 | 39 | ori_input.requires_grad_(True) 40 | ori_output = ori_scatter(ori_input, (H, W), ori_location) 41 | ori_loss = ori_output * ori_output 42 | ori_loss = ori_loss.mean() 43 | ori_loss.backward() 44 | if use_cuda: 45 |
torch.cuda.synchronize() 46 | 47 | hpc_input.requires_grad_(True) 48 | hpc_output = hpc_scatter(hpc_input, hpc_location) 49 | hpc_loss = hpc_output * hpc_output 50 | hpc_loss = hpc_loss.mean() 51 | hpc_loss.backward() 52 | if use_cuda: 53 | torch.cuda.synchronize() 54 | 55 | mre = mean_relative_error(torch.flatten(ori_loss).cpu().detach().numpy(), torch.flatten(hpc_loss).cpu().detach().numpy()) 56 | print("scatter type {} fp mean_relative_error: {}".format(scatter_type, str(mre))) 57 | mre = mean_relative_error(torch.flatten(ori_input.grad).cpu().detach().numpy(), torch.flatten(hpc_input.grad).cpu().detach().numpy()) 58 | print("scatter type {} bp mean_relative_error: {}".format(scatter_type, str(mre))) 59 | 60 | 61 | # Note: performance test use origin gpu version 62 | def scatter_perf(): 63 | for scatter_type in ['add', 'cover']: 64 | ori_input = torch.randn(B, M, N) 65 | h = torch.randint(low=0, high=H, size=(B, M, )).unsqueeze(dim=2) 66 | w = torch.randint(low=0, high=W, size=(B, M, )).unsqueeze(dim=2) 67 | ori_location = torch.cat([h, w], dim=2) 68 | ori_scatter = ScatterConnection(scatter_type) 69 | 70 | hpc_input = ori_input.clone().detach() 71 | hpc_location = ori_location.clone().detach() 72 | hpc_scatter = HPCScatterConnection(B, M, N, H, W, scatter_type) 73 | 74 | if use_cuda: 75 | ori_input = ori_input.cuda() 76 | ori_location = ori_location.cuda() 77 | ori_scatter = ori_scatter.cuda() 78 | 79 | hpc_input = hpc_input.cuda() 80 | hpc_location = hpc_location.cuda() 81 | hpc_scatter = hpc_scatter.cuda() 82 | 83 | for i in range(times): 84 | t = time.time() 85 | ori_input.requires_grad_(True) 86 | ori_output = ori_scatter(ori_input, (H, W), ori_location) 87 | ori_loss = ori_output * ori_output 88 | ori_loss = ori_loss.mean() 89 | ori_loss.backward() 90 | if use_cuda: 91 | torch.cuda.synchronize() 92 | print('epoch: {}, original scatter type {} cost time: {}'.format(i, scatter_type, time.time() - t)) 93 | 94 | for i in range(times): 95 | t = time.time() 96 | hpc_input.requires_grad_(True) 97 | hpc_output = hpc_scatter(hpc_input, hpc_location) 98 | hpc_loss = hpc_output * hpc_output 99 | hpc_loss = hpc_loss.mean() 100 | hpc_loss.backward() 101 | if use_cuda: 102 | torch.cuda.synchronize() 103 | print('epoch: {}, hpc scatter type {} cost time: {}'.format(i, scatter_type, time.time() - t)) 104 | 105 | 106 | if __name__ == '__main__': 107 | print("target problem: B = {}, M = {}, N = {}, H = {}, W = {}".format(B, M, N, H, W)) 108 | print("================run scatter validation test================") 109 | scatter_val() 110 | print("================run scatter performance test================") 111 | scatter_perf() 112 | -------------------------------------------------------------------------------- /tests/test_tdlambda.py: -------------------------------------------------------------------------------- 1 | import time 2 | import torch 3 | from hpc_rll.origin.td import td_lambda_error, td_lambda_data 4 | from hpc_rll.rl_utils.td import TDLambda 5 | from testbase import mean_relative_error, times 6 | 7 | assert torch.cuda.is_available() 8 | use_cuda = True 9 | 10 | T = 1024 11 | B = 64 12 | 13 | def td_val(): 14 | ori_value = torch.randn(T + 1, B) 15 | ori_reward = torch.randn(T, B) 16 | ori_weight = torch.randn(T, B) 17 | 18 | hpc_value = ori_value.clone().detach() 19 | hpc_reward = ori_reward.clone().detach() 20 | hpc_weight = ori_weight.clone().detach() 21 | hpc_td = TDLambda(T, B) 22 | 23 | if use_cuda: 24 | ori_value = ori_value.cuda() 25 | ori_reward = ori_reward.cuda() 26 | 
ori_weight = ori_weight.cuda() 27 | 28 | hpc_value = hpc_value.cuda() 29 | hpc_reward = hpc_reward.cuda() 30 | hpc_weight = hpc_weight.cuda() 31 | hpc_td = hpc_td.cuda() 32 | 33 | ori_value.requires_grad_(True) 34 | ori_loss = td_lambda_error(td_lambda_data(ori_value, ori_reward, ori_weight)) 35 | ori_loss = ori_loss.mean() 36 | ori_loss.backward() 37 | if use_cuda: 38 | torch.cuda.synchronize() 39 | 40 | hpc_value.requires_grad_(True) 41 | hpc_loss = hpc_td(hpc_value, hpc_reward, hpc_weight) 42 | hpc_loss = hpc_loss.mean() 43 | hpc_loss.backward() 44 | if use_cuda: 45 | torch.cuda.synchronize() 46 | 47 | mre = mean_relative_error(torch.flatten(ori_loss).cpu().detach().numpy(), torch.flatten(hpc_loss).cpu().detach().numpy()) 48 | print("td fp mean_relative_error: " + str(mre)) 49 | mre = mean_relative_error(torch.flatten(ori_value.grad).cpu().detach().numpy(), torch.flatten(hpc_value.grad).cpu().detach().numpy()) 50 | print("td bp mean_relative_error: " + str(mre)) 51 | 52 | def td_perf(): 53 | ori_value = torch.randn(T + 1, B) 54 | ori_reward = torch.randn(T, B) 55 | ori_weight = torch.randn(T, B) 56 | 57 | hpc_value = ori_value.clone().detach() 58 | hpc_reward = ori_reward.clone().detach() 59 | hpc_weight = ori_weight.clone().detach() 60 | hpc_td = TDLambda(T, B) 61 | 62 | if use_cuda: 63 | ori_value = ori_value.cuda() 64 | ori_reward = ori_reward.cuda() 65 | ori_weight = ori_weight.cuda() 66 | 67 | hpc_value = hpc_value.cuda() 68 | hpc_reward = hpc_reward.cuda() 69 | hpc_weight = hpc_weight.cuda() 70 | hpc_td = hpc_td.cuda() 71 | 72 | ori_value.requires_grad_(True) 73 | for i in range(times): 74 | t = time.time() 75 | ori_loss = td_lambda_error(td_lambda_data(ori_value, ori_reward, ori_weight)) 76 | ori_loss = ori_loss.mean() 77 | ori_loss.backward() 78 | if use_cuda: 79 | torch.cuda.synchronize() 80 | print('epoch: {}, original td cost time: {}'.format(i, time.time() - t)) 81 | 82 | hpc_value.requires_grad_(True) 83 | for i in range(times): 84 | t = time.time() 85 | hpc_loss = hpc_td(hpc_value, hpc_reward, hpc_weight) 86 | hpc_loss = hpc_loss.mean() 87 | hpc_loss.backward() 88 | if use_cuda: 89 | torch.cuda.synchronize() 90 | print('epoch: {}, hpc td cost time: {}'.format(i, time.time() - t)) 91 | 92 | 93 | if __name__ == '__main__': 94 | print("target problem: T = {}, B = {}".format(T, B)) 95 | print("================run td validation test================") 96 | td_val() 97 | print("================run td performance test================") 98 | td_perf() 99 | -------------------------------------------------------------------------------- /tests/test_upgo.py: -------------------------------------------------------------------------------- 1 | import time 2 | import torch 3 | from hpc_rll.origin.upgo import upgo_loss 4 | from hpc_rll.rl_utils.upgo import UPGO 5 | from testbase import mean_relative_error, times 6 | 7 | assert torch.cuda.is_available() 8 | use_cuda = True 9 | 10 | T = 256 11 | B = 256 12 | N = 256 13 | 14 | def upgo_val(): 15 | ori_target_output = torch.randn(T, B, N) 16 | ori_rhos = torch.randn(T, B) 17 | ori_action = torch.randint(0, N, size=(T, B, )) 18 | ori_rewards = torch.randn(T, B) 19 | ori_bootstrap_values = torch.randn(T + 1, B) 20 | 21 | hpc_target_output = ori_target_output.clone().detach() 22 | hpc_rhos = ori_rhos.clone().detach() 23 | hpc_action = ori_action.clone().detach() 24 | hpc_rewards = ori_rewards.clone().detach() 25 | hpc_bootstrap_values = ori_bootstrap_values.clone().detach() 26 | hpc_upgo = UPGO(T, B, N) 27 | 28 | if use_cuda: 29 | 
ori_target_output = ori_target_output.cuda() 30 | ori_rhos = ori_rhos.cuda() 31 | ori_action = ori_action.cuda() 32 | ori_rewards = ori_rewards.cuda() 33 | ori_bootstrap_values = ori_bootstrap_values.cuda() 34 | 35 | hpc_target_output = hpc_target_output.cuda() 36 | hpc_rhos = hpc_rhos.cuda() 37 | hpc_action = hpc_action.cuda() 38 | hpc_rewards = hpc_rewards.cuda() 39 | hpc_bootstrap_values = hpc_bootstrap_values.cuda() 40 | hpc_upgo = hpc_upgo.cuda() 41 | 42 | ori_target_output.requires_grad_(True) 43 | ori_loss = upgo_loss(ori_target_output, ori_rhos, ori_action, ori_rewards, ori_bootstrap_values) 44 | ori_loss = ori_loss.mean() 45 | ori_loss.backward() 46 | if use_cuda: 47 | torch.cuda.synchronize() 48 | 49 | hpc_target_output.requires_grad_(True) 50 | hpc_loss = hpc_upgo(hpc_target_output, hpc_rhos, hpc_action, hpc_rewards, hpc_bootstrap_values) 51 | hpc_loss = hpc_loss.mean() 52 | hpc_loss.backward() 53 | if use_cuda: 54 | torch.cuda.synchronize() 55 | 56 | mre = mean_relative_error(torch.flatten(ori_loss).cpu().detach().numpy(), torch.flatten(hpc_loss).cpu().detach().numpy()) 57 | print("upgo fp mean_relative_error: " + str(mre)) 58 | mre = mean_relative_error(torch.flatten(ori_target_output.grad).cpu().detach().numpy(), torch.flatten(hpc_target_output.grad).cpu().detach().numpy()) 59 | print("upgo bp mean_relative_error: " + str(mre)) 60 | 61 | def upgo_perf(): 62 | ori_target_output = torch.randn(T, B, N) 63 | ori_rhos = torch.randn(T, B) 64 | ori_action = torch.randint(0, N, size=(T, B, )) 65 | ori_rewards = torch.randn(T, B) 66 | ori_bootstrap_values = torch.randn(T + 1, B) 67 | 68 | hpc_target_output = ori_target_output.clone().detach() 69 | hpc_rhos = ori_rhos.clone().detach() 70 | hpc_action = ori_action.clone().detach() 71 | hpc_rewards = ori_rewards.clone().detach() 72 | hpc_bootstrap_values = ori_bootstrap_values.clone().detach() 73 | hpc_upgo = UPGO(T, B, N) 74 | 75 | if use_cuda: 76 | ori_target_output = ori_target_output.cuda() 77 | ori_rhos = ori_rhos.cuda() 78 | ori_action = ori_action.cuda() 79 | ori_rewards = ori_rewards.cuda() 80 | ori_bootstrap_values = ori_bootstrap_values.cuda() 81 | 82 | hpc_target_output = hpc_target_output.cuda() 83 | hpc_rhos = hpc_rhos.cuda() 84 | hpc_action = hpc_action.cuda() 85 | hpc_rewards = hpc_rewards.cuda() 86 | hpc_bootstrap_values = hpc_bootstrap_values.cuda() 87 | hpc_upgo = hpc_upgo.cuda() 88 | 89 | ori_target_output.requires_grad_(True) 90 | for i in range(times): 91 | t = time.time() 92 | ori_loss = upgo_loss(ori_target_output, ori_rhos, ori_action, ori_rewards, ori_bootstrap_values) 93 | ori_loss = ori_loss.mean() 94 | ori_loss.backward() 95 | if use_cuda: 96 | torch.cuda.synchronize() 97 | print('epoch: {}, original upgo cost time: {}'.format(i, time.time() - t)) 98 | 99 | hpc_target_output.requires_grad_(True) 100 | for i in range(times): 101 | t = time.time() 102 | hpc_loss = hpc_upgo(hpc_target_output, hpc_rhos, hpc_action, hpc_rewards, hpc_bootstrap_values) 103 | hpc_loss = hpc_loss.mean() 104 | hpc_loss.backward() 105 | if use_cuda: 106 | torch.cuda.synchronize() 107 | print('epoch: {}, hpc upgo cost time: {}'.format(i, time.time() - t)) 108 | 109 | if __name__ == '__main__': 110 | print("target problem: T = {}, B = {}, N = {}".format(T, B, N)) 111 | print("================run upgo validation test================") 112 | upgo_val() 113 | print("================run upgo performance test================") 114 | upgo_perf() 115 | -------------------------------------------------------------------------------- 
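The test above also doubles as a usage reference for the HPC UPGO operator. Below is a minimal sketch distilled from tests/test_upgo.py, assuming a CUDA device and the test's shapes; it is illustrative only and not a file in the repository.

```python
# Sketch: build the hpc_rll UPGO operator once for fixed shapes, then call it like a function.
import torch
from hpc_rll.rl_utils.upgo import UPGO

T, B, N = 256, 256, 256                  # trajectory length, batch size, action dim (as in the test)
upgo = UPGO(T, B, N).cuda()              # the operator is shape-specialized and CUDA-only

target_output = torch.randn(T, B, N, device='cuda', requires_grad=True)
rhos = torch.randn(T, B, device='cuda')  # importance-sampling ratios
action = torch.randint(0, N, size=(T, B), device='cuda')
rewards = torch.randn(T, B, device='cuda')
bootstrap_values = torch.randn(T + 1, B, device='cuda')

loss = upgo(target_output, rhos, action, rewards, bootstrap_values).mean()
loss.backward()                          # gradients flow to target_output, as checked by upgo_val()
```
--------------------------------------------------------------------------------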
/tests/test_vtrace.py: -------------------------------------------------------------------------------- 1 | import time 2 | import torch 3 | import torch.nn.functional as F 4 | from hpc_rll.origin.vtrace import vtrace_error, vtrace_data 5 | from hpc_rll.rl_utils.vtrace import VTrace 6 | from testbase import mean_relative_error, times 7 | 8 | assert torch.cuda.is_available() 9 | use_cuda = True 10 | 11 | T = 128 12 | B = 128 13 | N = 128 14 | 15 | def vtrace_val(): 16 | ori_target_output = torch.randn(T, B, N) 17 | ori_behaviour_output = torch.randn(T, B, N) 18 | ori_action = torch.randint(0, N, size=(T, B, )) 19 | ori_value = torch.randn(T + 1, B) 20 | ori_reward = torch.randn(T, B) 21 | 22 | hpc_target_output = ori_target_output.clone().detach() 23 | hpc_behaviour_output = ori_behaviour_output.clone().detach() 24 | hpc_action = ori_action.clone().detach() 25 | hpc_value = ori_value.clone().detach() 26 | hpc_reward = ori_reward.clone().detach() 27 | hpc_vtrace = VTrace(T, B, N) 28 | 29 | if use_cuda: 30 | ori_target_output = ori_target_output.cuda() 31 | ori_behaviour_output = ori_behaviour_output.cuda() 32 | ori_action = ori_action.cuda() 33 | ori_value = ori_value.cuda() 34 | ori_reward = ori_reward.cuda() 35 | 36 | hpc_target_output = hpc_target_output.cuda() 37 | hpc_behaviour_output = hpc_behaviour_output.cuda() 38 | hpc_action = hpc_action.cuda() 39 | hpc_value = hpc_value.cuda() 40 | hpc_reward = hpc_reward.cuda() 41 | hpc_vtrace = hpc_vtrace.cuda() 42 | 43 | ori_target_output.requires_grad_(True) 44 | ori_value.requires_grad_(True) 45 | ori_loss = vtrace_error(vtrace_data(ori_target_output, ori_behaviour_output, ori_action, ori_value, ori_reward, None)) 46 | ori_loss = sum(ori_loss) 47 | ori_loss.backward() 48 | 49 | hpc_target_output.requires_grad_(True) 50 | hpc_value.requires_grad_(True) 51 | hpc_loss = hpc_vtrace(hpc_target_output, hpc_behaviour_output, hpc_action, hpc_value, hpc_reward) 52 | hpc_loss = sum(hpc_loss) 53 | hpc_loss.backward() 54 | 55 | mre = mean_relative_error(torch.flatten(ori_loss).cpu().detach().numpy(), torch.flatten(hpc_loss).cpu().detach().numpy()) 56 | print("vtrace fp mean_relative_error: " + str(mre)) 57 | mre = mean_relative_error(torch.flatten(ori_target_output.grad).cpu().detach().numpy(), torch.flatten(hpc_target_output.grad).cpu().detach().numpy()) 58 | print("vtrace bp target_output mean_relative_error: " + str(mre)) 59 | mre = mean_relative_error(torch.flatten(ori_value.grad).cpu().detach().numpy(), torch.flatten(hpc_value.grad).cpu().detach().numpy()) 60 | print("vtrace bp value mean_relative_error: " + str(mre)) 61 | 62 | 63 | def vtrace_perf(): 64 | ori_target_output = torch.randn(T, B, N) 65 | ori_behaviour_output = torch.randn(T, B, N) 66 | ori_action = torch.randint(0, N, size=(T, B, )) 67 | ori_value = torch.randn(T + 1, B) 68 | ori_reward = torch.randn(T, B) 69 | 70 | hpc_target_output = ori_target_output.clone().detach() 71 | hpc_behaviour_output = ori_behaviour_output.clone().detach() 72 | hpc_action = ori_action.clone().detach() 73 | hpc_value = ori_value.clone().detach() 74 | hpc_reward = ori_reward.clone().detach() 75 | hpc_vtrace = VTrace(T, B, N) 76 | 77 | if use_cuda: 78 | ori_target_output = ori_target_output.cuda() 79 | ori_behaviour_output = ori_behaviour_output.cuda() 80 | ori_action = ori_action.cuda() 81 | ori_value = ori_value.cuda() 82 | ori_reward = ori_reward.cuda() 83 | 84 | hpc_target_output = hpc_target_output.cuda() 85 | hpc_behaviour_output = hpc_behaviour_output.cuda() 86 | hpc_action = hpc_action.cuda() 87 | 
hpc_value = hpc_value.cuda() 88 | hpc_reward = hpc_reward.cuda() 89 | hpc_vtrace = hpc_vtrace.cuda() 90 | 91 | ori_target_output.requires_grad_(True) 92 | ori_value.requires_grad_(True) 93 | for i in range(times): 94 | t = time.time() 95 | ori_loss = vtrace_error(vtrace_data(ori_target_output, ori_behaviour_output, ori_action, ori_value, ori_reward, None)) 96 | ori_loss = sum(ori_loss) 97 | ori_loss.backward() 98 | if use_cuda: 99 | torch.cuda.synchronize() 100 | print('epoch: {}, original vtrace cost time: {}'.format(i, time.time() - t)) 101 | 102 | hpc_target_output.requires_grad_(True) 103 | hpc_value.requires_grad_(True) 104 | for i in range(times): 105 | t = time.time() 106 | hpc_loss = hpc_vtrace(hpc_target_output, hpc_behaviour_output, hpc_action, hpc_value, hpc_reward) 107 | hpc_loss = sum(hpc_loss) 108 | hpc_loss.backward() 109 | if use_cuda: 110 | torch.cuda.synchronize() 111 | print('epoch: {}, hpc vtrace cost time: {}'.format(i, time.time() - t)) 112 | 113 | 114 | if __name__ == '__main__': 115 | print("target problem: T = {}, B = {}, N = {}".format(T, B, N)) 116 | print("================run vtrace validation test================") 117 | vtrace_val() 118 | print("================run vtrace performance test================") 119 | vtrace_perf() 120 | -------------------------------------------------------------------------------- /tests/testbase.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | torch.set_printoptions(precision=6) 5 | 6 | times = 6 7 | 8 | def mean_relative_error(y_true, y_pred): 9 | eps = 1e-5 10 | relative_error = np.average(np.abs(y_true - y_pred) / (y_true + eps)) 11 | return relative_error 12 | 13 | -------------------------------------------------------------------------------- /triton_rl/README.md: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/opendilab/DI-hpc/a8a773480571491e70bb3021cbb2c1adcb7dce12/triton_rl/README.md -------------------------------------------------------------------------------- /triton_rl/requirements.txt: -------------------------------------------------------------------------------- 1 | torch>2 2 | triton 3 | --------------------------------------------------------------------------------
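All of the tests above follow the same validation recipe built on tests/testbase.py: run the origin (pure PyTorch) operator and the hpc_rll operator on cloned inputs, then compare the forward loss and the backward gradients with mean_relative_error. The helper below, compare_ops, is an illustrative sketch of that recipe and not part of the repository; it assumes a single-tensor loss, inputs already placed on the GPU, and the differentiable input listed first.

```python
# Sketch of the shared origin-vs-hpc validation pattern used by the tests above.
import torch
from testbase import mean_relative_error

def compare_ops(ori_fn, hpc_fn, ori_inputs, name):
    # Clone/detach so both operators see identical but independent leaf tensors.
    hpc_inputs = [x.clone().detach() for x in ori_inputs]
    ori_inputs[0].requires_grad_(True)
    hpc_inputs[0].requires_grad_(True)

    ori_loss = ori_fn(*ori_inputs).mean()
    ori_loss.backward()
    hpc_loss = hpc_fn(*hpc_inputs).mean()
    hpc_loss.backward()
    torch.cuda.synchronize()

    fp = mean_relative_error(ori_loss.detach().cpu().numpy().flatten(),
                             hpc_loss.detach().cpu().numpy().flatten())
    bp = mean_relative_error(ori_inputs[0].grad.cpu().numpy().flatten(),
                             hpc_inputs[0].grad.cpu().numpy().flatten())
    print("{} fp mean_relative_error: {}".format(name, fp))
    print("{} bp mean_relative_error: {}".format(name, bp))
```

For operators that return a tuple of losses (e.g. ppo_error, vtrace_error), the tests sum the components before calling backward instead of taking a mean.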