├── .gitattributes
├── .gitignore
├── LICENSE
├── README.md
├── data
│   └── download_d4rl_datasets.py
├── decision_transformer
│   ├── __init__.py
│   ├── d4rl_infos.py
│   ├── model.py
│   └── utils.py
├── dt_runs
│   ├── __init__.py
│   ├── dt_halfcheetah-medium-v2_log_22-02-10-11-56-32.csv
│   ├── dt_halfcheetah-medium-v2_log_22-02-11-10-13-57.csv
│   ├── dt_halfcheetah-medium-v2_log_22-02-13-09-03-10.csv
│   ├── dt_halfcheetah-medium-v2_model_22-02-13-09-03-10_best.pt
│   ├── dt_hopper-medium-v2_log_22-02-12-09-43-59.csv
│   ├── dt_hopper-medium-v2_log_22-02-13-05-45-16.csv
│   ├── dt_hopper-medium-v2_log_22-02-13-08-03-24.csv
│   ├── dt_hopper-medium-v2_model_22-02-12-09-43-59_best.pt
│   ├── dt_walker2d-medium-v2_log_22-02-20-06-27-12.csv
│   ├── dt_walker2d-medium-v2_log_22-02-20-09-11-30.csv
│   ├── dt_walker2d-medium-v2_log_22-02-22-09-24-12.csv
│   └── dt_walker2d-medium-v2_model_22-02-22-09-24-12_best.pt
├── media
│   ├── halfcheetah-medium-v2.gif
│   ├── halfcheetah-medium-v2.png
│   ├── hopper-medium-v2.gif
│   ├── hopper-medium-v2.png
│   ├── walker2d-medium-v2.gif
│   └── walker2d-medium-v2.png
├── min_decision_transformer.ipynb
├── requirements.txt
└── scripts
    ├── plot.py
    ├── test.py
    └── train.py
/.gitattributes: -------------------------------------------------------------------------------- 1 | *.ipynb linguist-detectable=false 2 | *.py linguist-detectable=true 3 | 4 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__/ 2 | *.csv 3 | *.log 4 | *.pt 5 | *.pkl 6 | 7 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 Nikhil Barhate 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, 
including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Decision Transformer 2 | 3 | 4 | ## Overview 5 | 6 | Minimal code for [Decision Transformer: Reinforcement Learning via Sequence Modeling](https://arxiv.org/abs/2106.01345) for mujoco control tasks in OpenAI gym. 
7 | Notable differences from the official implementation are: 8 | 9 | - Simple GPT implementation (causal transformer) 10 | - Uses PyTorch's `Dataset` and `DataLoader` classes, and removes redundant returns-to-go and state normalization computations for efficient training 11 | - Can be trained, and the results visualized and rendered, on Google Colab with the provided notebook 12 | 13 | #### [Open `min_decision_transformer.ipynb` in Google Colab](https://colab.research.google.com/github/nikhilbarhate99/min-decision-transformer/blob/master/min_decision_transformer.ipynb) [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/nikhilbarhate99/min-decision-transformer/blob/master/min_decision_transformer.ipynb) 14 | 15 | 16 | 17 | ## Results 18 | 19 | **Note:** these results are the mean and variance over 3 random seeds, obtained after 20k updates (due to GPU time limits on Colab), while the official results are obtained after 100k updates. These numbers are therefore not directly comparable, but along with their corresponding plots they can serve as rough reference points for the learning progress of the model. The variance in returns and scores should decrease as training reaches saturation. 
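The scores in the results table are D4RL normalized scores: a raw episode return is rescaled so that the D4RL reference random policy maps to 0 and the reference expert policy maps to 100. Below is a minimal sketch of that rescaling. `d4rl_score` is a hypothetical helper, not a function in this repo; it assumes the reference returns in `decision_transformer/d4rl_infos.py` and a ×100 scaling of the fraction returned by `get_d4rl_normalized_score`. That scaling is inferred from the logged `eval_d4rl_score` column, whose first two halfcheetah rows it reproduces.

```python
# Reference returns copied from decision_transformer/d4rl_infos.py
# (min ~ random policy, max ~ expert policy in D4RL).
REF_MIN_SCORE = {
    'halfcheetah': -280.178953,
    'walker2d': 1.629008,
    'hopper': -20.272305,
}
REF_MAX_SCORE = {
    'halfcheetah': 12135.0,
    'walker2d': 4592.3,
    'hopper': 3234.3,
}


def d4rl_score(raw_return: float, env_name: str) -> float:
    """Hypothetical helper: map a raw episode return onto the 0-100 D4RL scale."""
    env_key = env_name.split('-')[0].lower()  # 'halfcheetah-medium-v2' -> 'halfcheetah'
    lo, hi = REF_MIN_SCORE[env_key], REF_MAX_SCORE[env_key]
    # Same fraction as get_d4rl_normalized_score, assumed to be multiplied by 100 when logged.
    return 100.0 * (raw_return - lo) / (hi - lo)


if __name__ == '__main__':
    # First eval row of dt_halfcheetah-medium-v2_log_22-02-10-11-56-32.csv
    print(round(d4rl_score(-55.734773632444664, 'halfcheetah-medium-v2'), 4))  # 1.8078
```

Because the scale is anchored to fixed reference returns, a score above 100 would mean beating the D4RL expert reference, and scores from different environments can be compared on one axis.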
20 | 21 | 22 | | Dataset | Environment | DT (this repo) 20k updates | DT (official) 100k updates| 23 | | :---: | :---: | :---: | :---: | 24 | | Medium | HalfCheetah | 42.18 ± 00.59 | 42.60 ± 00.10 | 25 | | Medium | Hopper | 69.43 ± 27.34 | 67.60 ± 01.00 | 26 | | Medium | Walker | 75.47 ± 31.08 | 74.00 ± 01.40 | 27 | 28 | 29 | | ![](https://github.com/nikhilbarhate99/min-decision-transformer/blob/master/media/halfcheetah-medium-v2.png) | ![](https://github.com/nikhilbarhate99/min-decision-transformer/blob/master/media/halfcheetah-medium-v2.gif) | 30 | | :---:|:---: | 31 | 32 | 33 | | ![](https://github.com/nikhilbarhate99/min-decision-transformer/blob/master/media/hopper-medium-v2.png) | ![](https://github.com/nikhilbarhate99/min-decision-transformer/blob/master/media/hopper-medium-v2.gif) | 34 | | :---:|:---: | 35 | 36 | 37 | | ![](https://github.com/nikhilbarhate99/min-decision-transformer/blob/master/media/walker2d-medium-v2.png) | ![](https://github.com/nikhilbarhate99/min-decision-transformer/blob/master/media/walker2d-medium-v2.gif) | 38 | | :---:|:---: | 39 | 40 | 41 | 42 | ## Instructions 43 | 44 | ### Mujoco-py 45 | 46 | Install `mujoco-py` library by following instructions on [mujoco-py repo](https://github.com/openai/mujoco-py) 47 | 48 | 49 | ### D4RL Data 50 | 51 | Datasets are expected to be stored in the `data` directory. Install the [D4RL repo](https://github.com/rail-berkeley/d4rl). 
Then save the formatted data in the `data` directory by running the following script: 52 | ``` 53 | python3 data/download_d4rl_datasets.py 54 | ``` 55 | 56 | 57 | ### Running experiments 58 | 59 | - Example command for training: 60 | ``` 61 | python3 scripts/train.py --env halfcheetah --dataset medium --device cuda 62 | ``` 63 | 64 | 65 | - Example command for testing with a pretrained model: 66 | ``` 67 | python3 scripts/test.py --env halfcheetah --dataset medium --device cpu --num_eval_ep 1 --chk_pt_name dt_halfcheetah-medium-v2_model_22-02-13-09-03-10_best.pt 68 | ``` 69 | The `--dataset` argument must also be specified for testing, so that the same state normalization statistics (mean and std) used for training are loaded. 70 | An additional `--render` flag can be passed to the script to render the test episode. 71 | 72 | 73 | - Example command for plotting graphs from the logged CSV files: 74 | ``` 75 | python3 scripts/plot.py --env_d4rl_name halfcheetah-medium-v2 --smoothing_window 5 76 | ``` 77 | Additionally, the `--plot_avg` and `--save_fig` flags can be passed to the script to average all values in one plot and to save the figure, respectively. 78 | 79 | 80 | ### Note: 81 | 1. If you find it difficult to install `mujoco-py` and `d4rl`, refer to how they are installed in the Colab notebook. 82 | 2. Once the dataset is formatted and saved with `download_d4rl_datasets.py`, the `d4rl` library is no longer required for training. 83 | 3. Evaluation is done on the `v3` control environments in `mujoco-py` so that the results are consistent with the Decision Transformer paper. 
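Since training only needs the formatted files, here is a minimal sketch of reading one back without `d4rl` and recomputing what `D4RLTrajectoryDataset` derives from it. It assumes the layout written by `download_d4rl_datasets.py`: a pickled list of per-episode dicts of NumPy arrays, saved as `data/<env>-<dataset>-v2.pkl`. `returns_to_go` and `load_and_summarize` are hypothetical helpers, not functions from this repo.

```python
import pickle

import numpy as np


def returns_to_go(rewards: np.ndarray, gamma: float = 1.0) -> np.ndarray:
    """Backward cumulative sum, the same recursion as discount_cumsum in utils.py."""
    rtg = np.zeros_like(rewards)
    running = 0.0
    for t in reversed(range(rewards.shape[0])):
        running = rewards[t] + gamma * running  # R_t = r_t + gamma * R_{t+1}
        rtg[t] = running
    return rtg


def load_and_summarize(pkl_path: str):
    """Hypothetical helper: read a formatted .pkl (list of per-episode dicts)."""
    with open(pkl_path, 'rb') as f:
        paths = pickle.load(f)

    # State normalization statistics over every observation in the file,
    # mirroring D4RLTrajectoryDataset (the 1e-6 avoids division by zero).
    states = np.concatenate([p['observations'] for p in paths], axis=0)
    state_mean = states.mean(axis=0)
    state_std = states.std(axis=0) + 1e-6

    # With gamma = 1, the return-to-go at t = 0 is the undiscounted episode return.
    episode_returns = np.array([returns_to_go(p['rewards'])[0] for p in paths])
    return state_mean, state_std, episode_returns
```

For example, `load_and_summarize('data/hopper-medium-v2.pkl')` should return 11-dimensional state statistics (the size of the Hopper entries in `D4RL_DATASET_STATS`), and the mean of the episode returns should match the "Trajectory returns" line printed by the download script.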
84 | 85 | 86 | ## Citing 87 | 88 | Please use this bibtex if you want to cite this repository in your publications: 89 | 90 | @misc{minimal_decision_transformer, 91 | author = {Barhate, Nikhil}, 92 | title = {Minimal Implementation of Decision Transformer}, 93 | year = {2022}, 94 | publisher = {GitHub}, 95 | journal = {GitHub repository}, 96 | howpublished = {\url{https://github.com/nikhilbarhate99/min-decision-transformer}}, 97 | } 98 | 99 | 100 | 101 | ## References 102 | 103 | - Official [code](https://github.com/kzl/decision-transformer) and [paper](https://arxiv.org/abs/2106.01345) 104 | - Minimal GPT (causal transformer) [tweet](https://twitter.com/MishaLaskin/status/1481767788775628801?cxt=HHwWgoCzmYD9pZApAAAA) and [colab notebook](https://colab.research.google.com/drive/1NUBqyboDcGte5qAJKOl8gaJC28V_73Iv?usp=sharing) 105 | -------------------------------------------------------------------------------- /data/download_d4rl_datasets.py: -------------------------------------------------------------------------------- 1 | import os 2 | import gym 3 | import numpy as np 4 | 5 | import collections 6 | import pickle 7 | 8 | import d4rl 9 | 10 | def download_d4rl_data(): 11 | datasets = [] 12 | 13 | data_dir = 'data/' 14 | 15 | print(data_dir) 16 | 17 | if not os.path.exists(data_dir): 18 | os.makedirs(data_dir) 19 | 20 | for env_name in ['walker2d', 'halfcheetah', 'hopper']: 21 | for dataset_type in ['medium', 'medium-expert', 'medium-replay']: 22 | 23 | name = f'{env_name}-{dataset_type}-v2' 24 | pkl_file_path = os.path.join(data_dir, name) 25 | 26 | print("processing: ", name) 27 | 28 | env = gym.make(name) 29 | dataset = env.get_dataset() 30 | 31 | N = dataset['rewards'].shape[0] 32 | data_ = collections.defaultdict(list) 33 | 34 | use_timeouts = False 35 | if 'timeouts' in dataset: 36 | use_timeouts = True 37 | 38 | episode_step = 0 39 | paths = [] 40 | for i in range(N): 41 | done_bool = bool(dataset['terminals'][i]) 42 | if use_timeouts: 43 | final_timestep = 
dataset['timeouts'][i] 44 | else: 45 | final_timestep = (episode_step == 1000-1) 46 | for k in ['observations', 'next_observations', 'actions', 'rewards', 'terminals']: 47 | data_[k].append(dataset[k][i]) 48 | if done_bool or final_timestep: 49 | episode_step = 0 50 | episode_data = {} 51 | for k in data_: 52 | episode_data[k] = np.array(data_[k]) 53 | paths.append(episode_data) 54 | data_ = collections.defaultdict(list) 55 | episode_step += 1 56 | 57 | returns = np.array([np.sum(p['rewards']) for p in paths]) 58 | num_samples = np.sum([p['rewards'].shape[0] for p in paths]) 59 | print(f'Number of samples collected: {num_samples}') 60 | print(f'Trajectory returns: mean = {np.mean(returns)}, std = {np.std(returns)}, max = {np.max(returns)}, min = {np.min(returns)}') 61 | 62 | with open(f'{pkl_file_path}.pkl', 'wb') as f: 63 | pickle.dump(paths, f) 64 | 65 | 66 | if __name__ == "__main__": 67 | download_d4rl_data() 68 | -------------------------------------------------------------------------------- /decision_transformer/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nikhilbarhate99/min-decision-transformer/d6694248b48c57c84fc7487e6e8017dcca861b02/decision_transformer/__init__.py -------------------------------------------------------------------------------- /decision_transformer/d4rl_infos.py: -------------------------------------------------------------------------------- 1 | 2 | ## from infos.py from official d4rl github repo 3 | REF_MIN_SCORE = { 4 | 'halfcheetah' : -280.178953, 5 | 'walker2d' : 1.629008, 6 | 'hopper' : -20.272305, 7 | } 8 | 9 | REF_MAX_SCORE = { 10 | 'halfcheetah' : 12135.0, 11 | 'walker2d' : 4592.3, 12 | 'hopper' : 3234.3, 13 | } 14 | 15 | ## calculated from d4rl datasets 16 | D4RL_DATASET_STATS = { 17 | 'halfcheetah-medium-v2': { 18 | 'state_mean':[-0.06845773756504059, 0.016414547339081764, -0.18354906141757965, 19 | -0.2762460708618164, -0.34061527252197266, 
-0.09339715540409088, 20 | -0.21321271359920502, -0.0877423882484436, 5.173007488250732, 21 | -0.04275195300579071, -0.036108363419771194, 0.14053793251514435, 22 | 0.060498327016830444, 0.09550975263118744, 0.06739100068807602, 23 | 0.005627387668937445, 0.013382787816226482 24 | ], 25 | 'state_std':[0.07472999393939972, 0.3023499846458435, 0.30207309126853943, 26 | 0.34417077898979187, 0.17619241774082184, 0.507205605506897, 27 | 0.2567007839679718, 0.3294812738895416, 1.2574149370193481, 28 | 0.7600541710853577, 1.9800915718078613, 6.565362453460693, 29 | 7.466367721557617, 4.472222805023193, 10.566964149475098, 30 | 5.671932697296143, 7.4982590675354 31 | ] 32 | }, 33 | 'halfcheetah-medium-replay-v2': { 34 | 'state_mean':[-0.12880703806877136, 0.3738119602203369, -0.14995987713336945, 35 | -0.23479078710079193, -0.2841278612613678, -0.13096535205841064, 36 | -0.20157982409000397, -0.06517726927995682, 3.4768247604370117, 37 | -0.02785065770149231, -0.015035249292850494, 0.07697279006242752, 38 | 0.01266712136566639, 0.027325302362442017, 0.02316424623131752, 39 | 0.010438721626996994, -0.015839405357837677 40 | ], 41 | 'state_std':[0.17019015550613403, 1.284424901008606, 0.33442774415016174, 42 | 0.3672759234905243, 0.26092398166656494, 0.4784106910228729, 43 | 0.3181420564651489, 0.33552637696266174, 2.0931615829467773, 44 | 0.8037433624267578, 1.9044333696365356, 6.573209762573242, 45 | 7.572863578796387, 5.069749355316162, 9.10555362701416, 46 | 6.085654258728027, 7.25300407409668 47 | ] 48 | }, 49 | 'halfcheetah-medium-expert-v2': { 50 | 'state_mean':[-0.05667462572455406, 0.024369969964027405, -0.061670560389757156, 51 | -0.22351515293121338, -0.2675151228904724, -0.07545716315507889, 52 | -0.05809682980179787, -0.027675075456500053, 8.110626220703125, 53 | -0.06136331334710121, -0.17986927926540375, 0.25175222754478455, 54 | 0.24186332523822784, 0.2519369423389435, 0.5879552960395813, 55 | -0.24090635776519775, -0.030184272676706314 56 | ], 57 | 
'state_std':[0.06103534251451492, 0.36054104566574097, 0.45544400811195374, 58 | 0.38476887345314026, 0.2218363732099533, 0.5667523741722107, 59 | 0.3196682929992676, 0.2852923572063446, 3.443821907043457, 60 | 0.6728139519691467, 1.8616976737976074, 9.575807571411133, 61 | 10.029894828796387, 5.903450012207031, 12.128185272216797, 62 | 6.4811787605285645, 6.378620147705078 63 | ] 64 | }, 65 | 'walker2d-medium-v2': { 66 | 'state_mean':[1.218966007232666, 0.14163373410701752, -0.03704913705587387, 67 | -0.13814310729503632, 0.5138224363327026, -0.04719110205769539, 68 | -0.47288352251052856, 0.042254164814949036, 2.3948874473571777, 69 | -0.03143199160695076, 0.04466355964541435, -0.023907244205474854, 70 | -0.1013401448726654, 0.09090937674045563, -0.004192637279629707, 71 | -0.12120571732521057, -0.5497063994407654 72 | ], 73 | 'state_std':[0.12311358004808426, 0.3241879940032959, 0.11456084251403809, 74 | 0.2623065710067749, 0.5640279054641724, 0.2271878570318222, 75 | 0.3837319612503052, 0.7373676896095276, 1.2387926578521729, 76 | 0.798020601272583, 1.5664079189300537, 1.8092705011367798, 77 | 3.025604248046875, 4.062486171722412, 1.4586567878723145, 78 | 3.7445690631866455, 5.5851287841796875 79 | ] 80 | }, 81 | 'walker2d-medium-replay-v2': { 82 | 'state_mean':[1.209364652633667, 0.13264022767543793, -0.14371201395988464, 83 | -0.2046516090631485, 0.5577612519264221, -0.03231537342071533, 84 | -0.2784661054611206, 0.19130706787109375, 1.4701707363128662, 85 | -0.12504704296588898, 0.0564953051507473, -0.09991033375263214, 86 | -0.340340256690979, 0.03546293452382088, -0.08934258669614792, 87 | -0.2992438077926636, -0.5984178185462952 88 | ], 89 | 'state_std':[0.11929835379123688, 0.3562574088573456, 0.25852200388908386, 90 | 0.42075422406196594, 0.5202291011810303, 0.15685082972049713, 91 | 0.36770978569984436, 0.7161387801170349, 1.3763766288757324, 92 | 0.8632221817970276, 2.6364643573760986, 3.0134117603302, 93 | 3.720684051513672, 4.867283821105957, 
2.6681625843048096, 94 | 3.845186948776245, 5.4768385887146 95 | ] 96 | }, 97 | 'walker2d-medium-expert-v2': { 98 | 'state_mean':[1.2294334173202515, 0.16869689524173737, -0.07089081406593323, 99 | -0.16197483241558075, 0.37101927399635315, -0.012209027074277401, 100 | -0.42461398243904114, 0.18986578285694122, 3.162475109100342, 101 | -0.018092676997184753, 0.03496946766972542, -0.013921679928898811, 102 | -0.05937029421329498, -0.19549426436424255, -0.0019200450042262673, 103 | -0.062483321875333786, -0.27366524934768677 104 | ], 105 | 'state_std':[0.09932824969291687, 0.25981399416923523, 0.15062759816646576, 106 | 0.24249176681041718, 0.6758718490600586, 0.1650741547346115, 107 | 0.38140663504600525, 0.6962361335754395, 1.3501490354537964, 108 | 0.7641991376876831, 1.534574270248413, 2.1785972118377686, 109 | 3.276582717895508, 4.766193866729736, 1.1716983318328857, 110 | 4.039782524108887, 5.891613960266113 111 | ] 112 | }, 113 | 'hopper-medium-v2': { 114 | 'state_mean':[1.311279058456421, -0.08469521254301071, -0.5382719039916992, 115 | -0.07201576232910156, 0.04932365566492081, 2.1066856384277344, 116 | -0.15017354488372803, 0.008783451281487942, -0.2848185896873474, 117 | -0.18540096282958984, -0.28461286425590515 118 | ], 119 | 'state_std':[0.17790751159191132, 0.05444620922207832, 0.21297138929367065, 120 | 0.14530418813228607, 0.6124444007873535, 0.8517446517944336, 121 | 1.4515252113342285, 0.6751695871353149, 1.5362390279769897, 122 | 1.616074562072754, 5.607253551483154 123 | ] 124 | }, 125 | 'hopper-medium-replay-v2': { 126 | 'state_mean':[1.2305138111114502, -0.04371410980820656, -0.44542956352233887, 127 | -0.09370097517967224, 0.09094487875699997, 1.3694725036621094, 128 | -0.19992674887180328, -0.022861352190375328, -0.5287045240402222, 129 | -0.14465883374214172, -0.19652697443962097 130 | ], 131 | 'state_std':[0.1756512075662613, 0.0636928603053093, 0.3438323438167572, 132 | 0.19566889107227325, 0.5547984838485718, 1.051029920578003, 133 | 
1.158307671546936, 0.7963128685951233, 1.4802359342575073, 134 | 1.6540331840515137, 5.108601093292236 135 | ] 136 | }, 137 | 'hopper-medium-expert-v2': { 138 | 'state_mean':[1.3293815851211548, -0.09836531430482864, -0.5444297790527344, 139 | -0.10201650857925415, 0.02277466468513012, 2.3577215671539307, 140 | -0.06349576264619827, -0.00374026270583272, -0.1766270101070404, 141 | -0.11862941086292267, -0.12097819894552231 142 | ], 143 | 'state_std':[0.17012375593185425, 0.05159067362546921, 0.18141433596611023, 144 | 0.16430604457855225, 0.6023368239402771, 0.7737284898757935, 145 | 1.4986555576324463, 0.7483318448066711, 1.7953159809112549, 146 | 2.0530025959014893, 5.725032806396484 147 | ] 148 | }, 149 | } 150 | -------------------------------------------------------------------------------- /decision_transformer/model.py: -------------------------------------------------------------------------------- 1 | """ 2 | this extremely minimal Decision Transformer model is based on 3 | the following causal transformer (GPT) implementation: 4 | 5 | Misha Laskin's tweet: 6 | https://twitter.com/MishaLaskin/status/1481767788775628801?cxt=HHwWgoCzmYD9pZApAAAA 7 | 8 | and its corresponding notebook: 9 | https://colab.research.google.com/drive/1NUBqyboDcGte5qAJKOl8gaJC28V_73Iv?usp=sharing 10 | 11 | ** the above colab notebook has a bug while applying masked_fill 12 | which is fixed in the following code 13 | """ 14 | 15 | import math 16 | import torch 17 | import torch.nn as nn 18 | import torch.nn.functional as F 19 | 20 | 21 | class MaskedCausalAttention(nn.Module): 22 | def __init__(self, h_dim, max_T, n_heads, drop_p): 23 | super().__init__() 24 | 25 | self.n_heads = n_heads 26 | self.max_T = max_T 27 | 28 | self.q_net = nn.Linear(h_dim, h_dim) 29 | self.k_net = nn.Linear(h_dim, h_dim) 30 | self.v_net = nn.Linear(h_dim, h_dim) 31 | 32 | self.proj_net = nn.Linear(h_dim, h_dim) 33 | 34 | self.att_drop = nn.Dropout(drop_p) 35 | self.proj_drop = nn.Dropout(drop_p) 36 | 37 | 
ones = torch.ones((max_T, max_T)) 38 | mask = torch.tril(ones).view(1, 1, max_T, max_T) 39 | 40 | # registering the mask as a buffer keeps it out of model.parameters(), 41 | # so it is never updated during backpropagation 42 | self.register_buffer('mask',mask) 43 | 44 | def forward(self, x): 45 | B, T, C = x.shape # batch size, seq length, h_dim (= n_heads * D) 46 | 47 | N, D = self.n_heads, C // self.n_heads # N = num heads, D = attention dim 48 | 49 | # rearrange q, k, v as (B, N, T, D) 50 | q = self.q_net(x).view(B, T, N, D).transpose(1,2) 51 | k = self.k_net(x).view(B, T, N, D).transpose(1,2) 52 | v = self.v_net(x).view(B, T, N, D).transpose(1,2) 53 | 54 | # weights (B, N, T, T) 55 | weights = q @ k.transpose(2,3) / math.sqrt(D) 56 | # causal mask applied to weights 57 | weights = weights.masked_fill(self.mask[...,:T,:T] == 0, float('-inf')) 58 | # normalize weights, all -inf -> 0 after softmax 59 | normalized_weights = F.softmax(weights, dim=-1) 60 | 61 | # attention (B, N, T, D) 62 | attention = self.att_drop(normalized_weights @ v) 63 | 64 | # gather heads and project (B, N, T, D) -> (B, T, N*D) 65 | attention = attention.transpose(1, 2).contiguous().view(B,T,N*D) 66 | 67 | out = self.proj_drop(self.proj_net(attention)) 68 | return out 69 | 70 | 71 | class Block(nn.Module): 72 | def __init__(self, h_dim, max_T, n_heads, drop_p): 73 | super().__init__() 74 | self.attention = MaskedCausalAttention(h_dim, max_T, n_heads, drop_p) 75 | self.mlp = nn.Sequential( 76 | nn.Linear(h_dim, 4*h_dim), 77 | nn.GELU(), 78 | nn.Linear(4*h_dim, h_dim), 79 | nn.Dropout(drop_p), 80 | ) 81 | self.ln1 = nn.LayerNorm(h_dim) 82 | self.ln2 = nn.LayerNorm(h_dim) 83 | 84 | def forward(self, x): 85 | # Attention -> LayerNorm -> MLP -> LayerNorm 86 | x = x + self.attention(x) # residual 87 | x = self.ln1(x) 88 | x = x + self.mlp(x) # residual 89 | x = self.ln2(x) 90 | return x 91 | 92 | 93 | class DecisionTransformer(nn.Module): 94 | def __init__(self, state_dim, act_dim, n_blocks, h_dim, context_len, 95 | n_heads, drop_p, 
max_timestep=4096): 96 | super().__init__() 97 | 98 | self.state_dim = state_dim 99 | self.act_dim = act_dim 100 | self.h_dim = h_dim 101 | 102 | ### transformer blocks 103 | input_seq_len = 3 * context_len 104 | blocks = [Block(h_dim, input_seq_len, n_heads, drop_p) for _ in range(n_blocks)] 105 | self.transformer = nn.Sequential(*blocks) 106 | 107 | ### projection heads (project to embedding) 108 | self.embed_ln = nn.LayerNorm(h_dim) 109 | self.embed_timestep = nn.Embedding(max_timestep, h_dim) 110 | self.embed_rtg = torch.nn.Linear(1, h_dim) 111 | self.embed_state = torch.nn.Linear(state_dim, h_dim) 112 | 113 | # # discrete actions 114 | # self.embed_action = torch.nn.Embedding(act_dim, h_dim) 115 | # use_action_tanh = False # False for discrete actions 116 | 117 | # continuous actions 118 | self.embed_action = torch.nn.Linear(act_dim, h_dim) 119 | use_action_tanh = True # True for continuous actions 120 | 121 | ### prediction heads 122 | self.predict_rtg = torch.nn.Linear(h_dim, 1) 123 | self.predict_state = torch.nn.Linear(h_dim, state_dim) 124 | self.predict_action = nn.Sequential( 125 | *([nn.Linear(h_dim, act_dim)] + ([nn.Tanh()] if use_action_tanh else [])) 126 | ) 127 | 128 | 129 | def forward(self, timesteps, states, actions, returns_to_go): 130 | 131 | B, T, _ = states.shape 132 | 133 | time_embeddings = self.embed_timestep(timesteps) 134 | 135 | # time embeddings are treated similar to positional embeddings 136 | state_embeddings = self.embed_state(states) + time_embeddings 137 | action_embeddings = self.embed_action(actions) + time_embeddings 138 | returns_embeddings = self.embed_rtg(returns_to_go) + time_embeddings 139 | 140 | # stack rtg, states and actions and reshape sequence as 141 | # (r_0, s_0, a_0, r_1, s_1, a_1, r_2, s_2, a_2 ...) 
142 | h = torch.stack( 143 | (returns_embeddings, state_embeddings, action_embeddings), dim=1 144 | ).permute(0, 2, 1, 3).reshape(B, 3 * T, self.h_dim) 145 | 146 | h = self.embed_ln(h) 147 | 148 | # transformer and prediction 149 | h = self.transformer(h) 150 | 151 | # get h reshaped such that its size = (B x 3 x T x h_dim) and 152 | # h[:, 0, t] is conditioned on the input sequence r_0, s_0, a_0 ... r_t 153 | # h[:, 1, t] is conditioned on the input sequence r_0, s_0, a_0 ... r_t, s_t 154 | # h[:, 2, t] is conditioned on the input sequence r_0, s_0, a_0 ... r_t, s_t, a_t 155 | # that is, for each timestep (t) we have 3 output embeddings from the transformer, 156 | # each conditioned on all previous timesteps plus 157 | # the 3 input variables at that timestep (r_t, s_t, a_t) in sequence. 158 | h = h.reshape(B, T, 3, self.h_dim).permute(0, 2, 1, 3) 159 | 160 | # get predictions 161 | return_preds = self.predict_rtg(h[:,2]) # predict next rtg given r, s, a 162 | state_preds = self.predict_state(h[:,2]) # predict next state given r, s, a 163 | action_preds = self.predict_action(h[:,1]) # predict action given r, s 164 | 165 | return state_preds, action_preds, return_preds 166 | -------------------------------------------------------------------------------- /decision_transformer/utils.py: -------------------------------------------------------------------------------- 1 | import random 2 | import time 3 | import pickle 4 | import torch 5 | import numpy as np 6 | from torch.utils.data import Dataset 7 | from decision_transformer.d4rl_infos import REF_MIN_SCORE, REF_MAX_SCORE, D4RL_DATASET_STATS 8 | 9 | 10 | def discount_cumsum(x, gamma): 11 | disc_cumsum = np.zeros_like(x) 12 | disc_cumsum[-1] = x[-1] 13 | for t in reversed(range(x.shape[0]-1)): 14 | disc_cumsum[t] = x[t] + gamma * disc_cumsum[t+1] 15 | return disc_cumsum 16 | 17 | 18 | def get_d4rl_normalized_score(score, env_name): 19 | env_key = env_name.split('-')[0].lower() 20 | assert env_key in REF_MAX_SCORE, 
f'no reference score for {env_key} env to calculate d4rl score' 21 | return (score - REF_MIN_SCORE[env_key]) / (REF_MAX_SCORE[env_key] - REF_MIN_SCORE[env_key]) 22 | 23 | 24 | def get_d4rl_dataset_stats(env_d4rl_name): 25 | return D4RL_DATASET_STATS[env_d4rl_name] 26 | 27 | 28 | def evaluate_on_env(model, device, context_len, env, rtg_target, rtg_scale, 29 | num_eval_ep=10, max_test_ep_len=1000, 30 | state_mean=None, state_std=None, render=False): 31 | 32 | eval_batch_size = 1 # required for forward pass 33 | 34 | results = {} 35 | total_reward = 0 36 | total_timesteps = 0 37 | 38 | state_dim = env.observation_space.shape[0] 39 | act_dim = env.action_space.shape[0] 40 | 41 | if state_mean is None: 42 | state_mean = torch.zeros((state_dim,)).to(device) 43 | else: 44 | state_mean = torch.from_numpy(state_mean).to(device) 45 | 46 | if state_std is None: 47 | state_std = torch.ones((state_dim,)).to(device) 48 | else: 49 | state_std = torch.from_numpy(state_std).to(device) 50 | 51 | # same as timesteps used for training the transformer 52 | # also, crashes if device is passed to arange() 53 | timesteps = torch.arange(start=0, end=max_test_ep_len, step=1) 54 | timesteps = timesteps.repeat(eval_batch_size, 1).to(device) 55 | 56 | model.eval() 57 | 58 | with torch.no_grad(): 59 | 60 | for _ in range(num_eval_ep): 61 | 62 | # zero-initialized placeholders 63 | actions = torch.zeros((eval_batch_size, max_test_ep_len, act_dim), 64 | dtype=torch.float32, device=device) 65 | states = torch.zeros((eval_batch_size, max_test_ep_len, state_dim), 66 | dtype=torch.float32, device=device) 67 | rewards_to_go = torch.zeros((eval_batch_size, max_test_ep_len, 1), 68 | dtype=torch.float32, device=device) 69 | 70 | # init episode 71 | running_state = env.reset() 72 | running_reward = 0 73 | running_rtg = rtg_target / rtg_scale 74 | 75 | for t in range(max_test_ep_len): 76 | 77 | total_timesteps += 1 78 | 79 | # add state to placeholder and normalize 80 | states[0, t] = 
torch.from_numpy(running_state).to(device) 81 | states[0, t] = (states[0, t] - state_mean) / state_std 82 | 83 | # calculate running rtg and add it to placeholder 84 | running_rtg = running_rtg - (running_reward / rtg_scale) 85 | rewards_to_go[0, t] = running_rtg 86 | 87 | if t < context_len: 88 | _, act_preds, _ = model.forward(timesteps[:,:context_len], 89 | states[:,:context_len], 90 | actions[:,:context_len], 91 | rewards_to_go[:,:context_len]) 92 | act = act_preds[0, t].detach() 93 | else: 94 | _, act_preds, _ = model.forward(timesteps[:,t-context_len+1:t+1], 95 | states[:,t-context_len+1:t+1], 96 | actions[:,t-context_len+1:t+1], 97 | rewards_to_go[:,t-context_len+1:t+1]) 98 | act = act_preds[0, -1].detach() 99 | 100 | running_state, running_reward, done, _ = env.step(act.cpu().numpy()) 101 | 102 | # add action to placeholder 103 | actions[0, t] = act 104 | 105 | total_reward += running_reward 106 | 107 | if render: 108 | env.render() 109 | if done: 110 | break 111 | 112 | results['eval/avg_reward'] = total_reward / num_eval_ep 113 | results['eval/avg_ep_len'] = total_timesteps / num_eval_ep 114 | 115 | return results 116 | 117 | 118 | class D4RLTrajectoryDataset(Dataset): 119 | def __init__(self, dataset_path, context_len, rtg_scale): 120 | 121 | self.context_len = context_len 122 | 123 | # load dataset 124 | with open(dataset_path, 'rb') as f: 125 | self.trajectories = pickle.load(f) 126 | 127 | # calculate min len of traj, state mean and std 128 | # and returns_to_go for all traj 129 | min_len = 10**6 130 | states = [] 131 | for traj in self.trajectories: 132 | traj_len = traj['observations'].shape[0] 133 | min_len = min(min_len, traj_len) 134 | states.append(traj['observations']) 135 | # calculate returns to go and rescale them 136 | traj['returns_to_go'] = discount_cumsum(traj['rewards'], 1.0) / rtg_scale 137 | 138 | # used for input normalization 139 | states = np.concatenate(states, axis=0) 140 | self.state_mean, self.state_std = np.mean(states, 
axis=0), np.std(states, axis=0) + 1e-6 141 | 142 | # normalize states 143 | for traj in self.trajectories: 144 | traj['observations'] = (traj['observations'] - self.state_mean) / self.state_std 145 | 146 | def get_state_stats(self): 147 | return self.state_mean, self.state_std 148 | 149 | def __len__(self): 150 | return len(self.trajectories) 151 | 152 | def __getitem__(self, idx): 153 | traj = self.trajectories[idx] 154 | traj_len = traj['observations'].shape[0] 155 | 156 | if traj_len >= self.context_len: 157 | # sample random index to slice trajectory 158 | si = random.randint(0, traj_len - self.context_len) 159 | 160 | states = torch.from_numpy(traj['observations'][si : si + self.context_len]) 161 | actions = torch.from_numpy(traj['actions'][si : si + self.context_len]) 162 | returns_to_go = torch.from_numpy(traj['returns_to_go'][si : si + self.context_len]) 163 | timesteps = torch.arange(start=si, end=si+self.context_len, step=1) 164 | 165 | # all ones since no padding 166 | traj_mask = torch.ones(self.context_len, dtype=torch.long) 167 | 168 | else: 169 | padding_len = self.context_len - traj_len 170 | 171 | # padding with zeros 172 | states = torch.from_numpy(traj['observations']) 173 | states = torch.cat([states, 174 | torch.zeros(([padding_len] + list(states.shape[1:])), 175 | dtype=states.dtype)], 176 | dim=0) 177 | 178 | actions = torch.from_numpy(traj['actions']) 179 | actions = torch.cat([actions, 180 | torch.zeros(([padding_len] + list(actions.shape[1:])), 181 | dtype=actions.dtype)], 182 | dim=0) 183 | 184 | returns_to_go = torch.from_numpy(traj['returns_to_go']) 185 | returns_to_go = torch.cat([returns_to_go, 186 | torch.zeros(([padding_len] + list(returns_to_go.shape[1:])), 187 | dtype=returns_to_go.dtype)], 188 | dim=0) 189 | 190 | timesteps = torch.arange(start=0, end=self.context_len, step=1) 191 | 192 | traj_mask = torch.cat([torch.ones(traj_len, dtype=torch.long), 193 | torch.zeros(padding_len, dtype=torch.long)], 194 | dim=0) 195 | 196 | 
return timesteps, states, actions, returns_to_go, traj_mask 197 | -------------------------------------------------------------------------------- /dt_runs/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nikhilbarhate99/min-decision-transformer/d6694248b48c57c84fc7487e6e8017dcca861b02/dt_runs/__init__.py -------------------------------------------------------------------------------- /dt_runs/dt_halfcheetah-medium-v2_log_22-02-10-11-56-32.csv: -------------------------------------------------------------------------------- 1 | ,duration,num_updates,action_loss,eval_avg_reward,eval_avg_ep_len,eval_d4rl_score 2 | 0,0:00:50,100,0.8170432960987091,-55.734773632444664,1000.0,1.8078207347411668 3 | 1,0:01:39,200,0.7662054866552352,-146.62216775046357,1000.0,1.0757540085015351 4 | 2,0:02:21,300,0.6748140245676041,-104.03758409227495,1000.0,1.418758195709795 5 | 3,0:03:03,400,0.5723127049207687,-115.673485436588,1000.0,1.3250350090496352 6 | 4,0:03:45,500,0.4817734429240227,-102.8708835426667,1000.0,1.4281555677011697 7 | 5,0:04:27,600,0.4027451115846634,-59.32679857374095,1000.0,1.7788882082355515 8 | 6,0:05:15,700,0.3367310807108879,36.39476233059897,1000.0,2.5498924866814114 9 | 7,0:06:00,800,0.2922138586640358,229.37399096679923,1000.0,4.104273856186913 10 | 8,0:06:50,900,0.26553354173898697,409.75452221692063,1000.0,5.55717704777993 11 | 9,0:07:36,1000,0.2510985206067562,356.17952231328684,1000.0,5.1256488345624485 12 | 10,0:08:28,1100,0.2387624034285545,408.5377833228312,1000.0,5.547376634119396 13 | 11,0:09:12,1200,0.2262622436881065,391.4304257582649,1000.0,5.409582747866693 14 | 12,0:10:00,1300,0.2181432858109474,382.74223983165354,1000.0,5.339602395916052 15 | 13,0:10:42,1400,0.2103891736268997,598.5660583942955,1000.0,7.0779890867538064 16 | 14,0:11:23,1500,0.2032386972010136,626.5778387855551,1000.0,7.303614351579253 17 | 
15,0:12:05,1600,0.19371537402272224,501.9555451582149,1000.0,6.2998246027635405 18 | 16,0:12:46,1700,0.1856754983961582,257.45450555783293,1000.0,4.330452751371088 19 | 17,0:13:27,1800,0.17992598846554755,363.2553622803873,1000.0,5.182642293890641 20 | 18,0:14:08,1900,0.17226264491677284,178.95847413566102,1000.0,3.698194193364528 21 | 19,0:14:50,2000,0.16773971989750866,328.0737090461648,1000.0,4.8992661672362505 22 | 20,0:15:31,2100,0.16027964174747467,238.1040053514056,1000.0,4.174591121992388 23 | 21,0:16:12,2200,0.15555616930127145,400.4980582968866,1000.0,5.482619411880551 24 | 22,0:16:53,2300,0.15267556935548782,304.03084148692426,1000.0,4.7056091313589645 25 | 23,0:17:34,2400,0.1483520731329918,390.214888940773,1000.0,5.3997920165200615 26 | 24,0:18:15,2500,0.14420259192585946,205.90278702220837,1000.0,3.915221374273882 27 | 25,0:18:56,2600,0.14041977912187575,521.0711700078372,1000.0,6.453794391857906 28 | 26,0:19:37,2700,0.13579027041792868,325.18139217694517,1000.0,4.875969548797087 29 | 27,0:20:23,2800,0.13365695871412753,401.89588365319213,1000.0,5.493878414764015 30 | 28,0:21:04,2900,0.12971090264618396,234.55989653494103,1000.0,4.146044543405955 31 | 29,0:21:45,3000,0.12719877541065214,157.06290232104175,1000.0,3.5218328867936832 32 | 30,0:22:26,3100,0.12344265699386595,665.2384860126018,1000.0,7.615012579292316 33 | 31,0:23:07,3200,0.121046177521348,333.254413635907,1000.0,4.940994962361595 34 | 32,0:23:48,3300,0.11885549604892733,653.8569027877691,1000.0,7.523337837688347 35 | 33,0:24:29,3400,0.11752748496830465,420.1135382491772,1000.0,5.640615362052103 36 | 34,0:25:10,3500,0.11457525938749312,667.2705128613469,1000.0,7.6313798572545375 37 | 35,0:25:52,3600,0.1114836598187685,814.7679330635617,1000.0,8.819420889611736 38 | 36,0:26:34,3700,0.10894531585276128,276.4045279290226,1000.0,4.483088669410842 39 | 37,0:27:15,3800,0.10797606021165848,439.10529453804634,1000.0,5.793587432456932 40 | 
38,0:27:56,3900,0.10484644643962383,780.0500821458184,1000.0,8.539780531231287 41 | 39,0:28:37,4000,0.1026029658317566,340.92145078058763,1000.0,5.002750311790754 42 | 40,0:29:18,4100,0.10138623975217342,956.6432408261252,1000.0,9.96217773830203 43 | 41,0:29:59,4200,0.09861044250428676,1511.5533317453574,1000.0,14.43178782624316 44 | 42,0:30:41,4300,0.09634928107261656,1736.2946085137378,1000.0,16.24200157845069 45 | 43,0:31:22,4400,0.09544318906962873,1330.863097746816,1000.0,12.976390085440727 46 | 44,0:32:04,4500,0.09425550647079944,2456.922416957479,1000.0,22.04641093228936 47 | 45,0:32:45,4600,0.09201115779578686,1377.7767997354424,1000.0,13.354263833102578 48 | 46,0:33:26,4700,0.09015206806361677,3257.6827965449015,1000.0,28.49626060919576 49 | 47,0:34:07,4800,0.08877075530588628,1958.4644860346973,1000.0,18.031503593379558 50 | 48,0:34:49,4900,0.08679009176790714,1764.0255618681476,1000.0,16.4653648780003 51 | 49,0:35:35,5000,0.08509850807487965,2087.1537023810274,1000.0,19.068051007101964 52 | 50,0:36:16,5100,0.08428584985435009,1203.8428783178324,1000.0,11.953285868338076 53 | 51,0:36:58,5200,0.08225006856024265,2829.7521031594756,1000.0,25.04942593202005 54 | 52,0:37:39,5300,0.08057696856558323,3715.552359314317,1000.0,32.18424259079076 55 | 53,0:38:20,5400,0.08013116516172886,3031.077361386993,1000.0,26.671031701777142 56 | 54,0:39:13,5500,0.07883696436882019,2315.3007898452065,1000.0,20.905697394060002 57 | 55,0:40:11,5600,0.07787837184965611,3364.1607256111433,1000.0,29.353903736768327 58 | 56,0:40:52,5700,0.07637658413499594,2917.78094775218,1000.0,25.75846802417154 59 | 57,0:41:33,5800,0.07614435449242593,3574.50920541378,1000.0,31.048188455490077 60 | 58,0:42:14,5900,0.07361038438975813,3821.208862887122,1000.0,33.03526941829593 61 | 59,0:42:56,6000,0.07354964684695005,3678.801249656209,1000.0,31.88822503198443 62 | 60,0:43:37,6100,0.07177457325160504,3727.1953323748403,1000.0,32.278022737694805 63 | 
61,0:44:17,6200,0.07035566125065089,3949.484560129701,1000.0,34.06848607774313 64 | 62,0:44:58,6300,0.06933463383466006,4250.38654209128,1000.0,36.492148137715844 65 | 63,0:45:39,6400,0.06816816948354244,4433.902319967568,1000.0,37.970304663457604 66 | 64,0:46:19,6500,0.06874442987143993,4673.536702470343,1000.0,39.9004772643517 67 | 65,0:46:59,6600,0.06682684503495694,4258.03127141469,1000.0,36.55372380531074 68 | 66,0:47:40,6700,0.066059561483562,4469.7311389226725,1000.0,38.258893487595714 69 | 67,0:48:20,6800,0.06530408855527639,4358.5614100912935,1000.0,37.36345952524824 70 | 68,0:49:01,6900,0.06486820977181196,4103.650243990749,1000.0,35.310237682328705 71 | 69,0:49:41,7000,0.06362807419151068,4401.790216430416,1000.0,37.711652704764795 72 | 70,0:50:21,7100,0.06376377396285532,4768.574301177143,1000.0,40.665972462339454 73 | 71,0:51:02,7200,0.06288216136395931,4785.226311220529,1000.0,40.800098680788864 74 | 72,0:51:42,7300,0.061463375464081774,4213.615908428285,1000.0,36.195973319759545 75 | 73,0:52:23,7400,0.06175081416964531,4660.057141983613,1000.0,39.79190403687138 76 | 74,0:53:03,7500,0.06052427981048824,4451.553831127998,1000.0,38.11248151992705 77 | 75,0:53:44,7600,0.06067767400294542,4721.471393471804,1000.0,40.28657472764985 78 | 76,0:54:25,7700,0.05860625334084034,4599.815690715442,1000.0,39.30667984883328 79 | 77,0:55:06,7800,0.058999049291014675,4766.440443207788,1000.0,40.648784969694894 80 | 78,0:55:47,7900,0.05809070449322463,4341.250235305868,1000.0,37.22402396132314 81 | 79,0:56:28,8000,0.057912974469363686,5045.749876231343,1000.0,42.89852646823416 82 | 80,0:57:08,8100,0.05660187132656575,4736.901942351689,1000.0,40.41086249618144 83 | 81,0:57:49,8200,0.05639633990824223,4064.534357572773,1000.0,34.99517265937531 84 | 82,0:58:30,8300,0.055291385762393466,4985.146454973052,1000.0,42.410386736316354 85 | 83,0:59:11,8400,0.055140086337924006,4995.975019948418,1000.0,42.497607105965145 86 | 
84,0:59:51,8500,0.05466329272836447,4120.352133192484,1000.0,35.44476566025768 87 | 85,1:00:32,8600,0.053507385142147534,4600.6177254626755,1000.0,39.313139963103644 88 | 86,1:01:13,8700,0.0536041934415698,4449.0955080397825,1000.0,38.092680572252256 89 | 87,1:01:53,8800,0.05318384800106287,4697.9763383754325,1000.0,40.097330132905675 90 | 88,1:02:34,8900,0.05281607292592525,4182.503783777959,1000.0,35.945375847358186 91 | 89,1:03:15,9000,0.05215774279087782,4690.228370384115,1000.0,40.034922913318674 92 | 90,1:03:56,9100,0.051580743379890916,4546.9549217455215,1000.0,38.88090452034197 93 | 91,1:04:37,9200,0.05083662379533052,5030.262572351364,1000.0,42.77378155768064 94 | 92,1:05:18,9300,0.05068173661828041,4950.278109725288,1000.0,42.12953419782485 95 | 93,1:05:59,9400,0.050778139866888526,4874.060438698888,1000.0,41.51562704984948 96 | 94,1:06:41,9500,0.04975619949400425,4937.008185986642,1000.0,42.022649522308846 97 | 95,1:07:22,9600,0.05004123408347368,4979.9959717830325,1000.0,42.368901364180225 98 | 96,1:08:03,9700,0.04949270389974117,5008.098722212681,1000.0,42.5952593614031 99 | 97,1:08:44,9800,0.04964465372264385,4357.047187651233,1000.0,37.3512629838549 100 | 98,1:09:24,9900,0.04877568513154983,4721.100458781256,1000.0,40.28358697618891 101 | 99,1:10:05,10000,0.04847250949591398,4595.091355127275,1000.0,39.26862694918478 102 | 100,1:10:47,10100,0.04801385246217251,4605.4994095738975,1000.0,39.352460251032646 103 | 101,1:11:28,10200,0.04814957469701767,4410.500704163474,1000.0,37.781812690102385 104 | 102,1:12:09,10300,0.047178924828767774,4626.5667929447945,1000.0,39.522150784295626 105 | 103,1:12:50,10400,0.046970902644097805,4604.738267216251,1000.0,39.346329510907786 106 | 104,1:13:31,10500,0.04700594186782837,4281.6391776588125,1000.0,36.743877377268866 107 | 105,1:14:12,10600,0.04707200288772583,4965.029561012871,1000.0,42.24835206862177 108 | 106,1:14:53,10700,0.04589315827935934,4958.055447784242,1000.0,42.19217798321365 109 | 
107,1:15:35,10800,0.045325954742729664,4926.084874484196,1000.0,41.93466600194397 110 | 108,1:16:16,10900,0.04499753255397081,4993.3455833598355,1000.0,42.47642789784793 111 | 109,1:16:56,11000,0.04512132585048675,4978.738145676979,1000.0,42.358770007146894 112 | 110,1:17:37,11100,0.044957182966172696,5012.338303347195,1000.0,42.62940773051292 113 | 111,1:18:19,11200,0.04498727105557918,4926.900747106067,1000.0,41.941237575539176 114 | 112,1:18:59,11300,0.04501208584755659,5001.94522345198,1000.0,42.54569504353063 115 | 113,1:19:40,11400,0.044447521641850465,5073.361237459812,1000.0,43.12092649430705 116 | 114,1:20:22,11500,0.044407006725668906,5006.237196044538,1000.0,42.58026540783071 117 | 115,1:21:07,11600,0.044012592993676665,4575.016648381878,1000.0,39.106932085007685 118 | 116,1:21:48,11700,0.043272637724876405,4686.609795595521,1000.0,40.005776536916905 119 | 117,1:22:29,11800,0.04407030057162047,5000.666354490897,1000.0,42.535394193531424 120 | 118,1:23:10,11900,0.04330223489552736,4646.231993464672,1000.0,39.68054721655265 121 | 119,1:23:51,12000,0.042889186851680285,4889.069927497822,1000.0,41.63652332412595 122 | 120,1:24:32,12100,0.04293781075626612,4762.617166954821,1000.0,40.6179897933431 123 | 121,1:25:14,12200,0.04248238772153854,5084.451087944137,1000.0,43.21025142893997 124 | 122,1:25:56,12300,0.042045941166579726,4386.172671682038,1000.0,37.58585874877351 125 | 123,1:26:38,12400,0.042759826742112636,4549.57446993622,1000.0,38.90200408081238 126 | 124,1:27:19,12500,0.042140917107462886,4884.486190975659,1000.0,41.59960290163736 127 | 125,1:28:00,12600,0.04213300243020058,5068.248466131098,1000.0,43.0797448782541 128 | 126,1:28:41,12700,0.04192981451749802,5054.504460351525,1000.0,42.969041634816335 129 | 127,1:29:22,12800,0.041476597338914865,4728.376165063621,1000.0,40.342190290002655 130 | 128,1:30:03,12900,0.0415212594717741,4954.233041748988,1000.0,42.1613898161665 131 | 
129,1:30:44,13000,0.04167712945491076,4971.877180678692,1000.0,42.30350729185089 132 | 130,1:31:25,13100,0.04139833312481642,4931.015595386873,1000.0,41.97438126437672 133 | 131,1:32:07,13200,0.04055118229240179,5008.896728930016,1000.0,42.60168703127686 134 | 132,1:32:48,13300,0.04061872020363808,5106.502566305333,1000.0,43.38786850916632 135 | 133,1:33:30,13400,0.04064277980476618,5028.570076117162,1000.0,42.76014908213914 136 | 134,1:34:11,13500,0.04040551275014877,4948.080302474447,1000.0,42.111831615694044 137 | 135,1:34:53,13600,0.040255818553268916,4959.613116280469,1000.0,42.20472446765922 138 | 136,1:35:34,13700,0.0405152177810669,4983.579081572234,1000.0,42.39776208219939 139 | 137,1:36:16,13800,0.039837098754942415,4658.74401825126,1000.0,39.78132727646121 140 | 138,1:36:56,13900,0.03986068237572908,5035.865647571116,1000.0,42.81891240308339 141 | 139,1:37:37,14000,0.039398337937891485,4631.82048039732,1000.0,39.56446743130018 142 | 140,1:38:21,14100,0.03947044119238854,4624.6949006098575,1000.0,39.507073334812006 143 | 141,1:39:02,14200,0.03965242676436901,5053.9869654893955,1000.0,42.96487339153859 144 | 142,1:39:43,14300,0.039702990129590034,5050.580593671793,1000.0,42.93743623714477 145 | 143,1:40:25,14400,0.03962724547833205,4982.390390401581,1000.0,42.388187583312565 146 | 144,1:41:07,14500,0.04025613836944103,4799.53229629537,1000.0,40.91532847432626 147 | 145,1:41:49,14600,0.03923682786524296,5026.804346262155,1000.0,42.745926734948725 148 | 146,1:42:31,14700,0.039113823771476736,5035.648550394819,1000.0,42.81716375993359 149 | 147,1:43:12,14800,0.03879088185727597,5115.57993776177,1000.0,43.46098361681641 150 | 148,1:43:54,14900,0.03915165562182665,4930.505991272201,1000.0,41.970276578358074 151 | 149,1:44:35,15000,0.03887611918151378,5041.933020696071,1000.0,42.86778300855693 152 | 150,1:45:17,15100,0.03888282496482134,4897.737906095244,1000.0,41.7063409129842 153 | 
151,1:46:00,15200,0.03901642251759768,4986.183308232399,1000.0,42.41873823300659 154 | 152,1:46:42,15300,0.03802666950970888,5015.271267044851,1000.0,42.65303174518688 155 | 153,1:47:23,15400,0.038550314232707016,4963.609308833476,1000.0,42.23691242538529 156 | 154,1:48:05,15500,0.03855738956481218,5016.182503736881,1000.0,42.660371443595416 157 | 155,1:48:46,15600,0.03770638342946768,4659.906851875813,1000.0,39.79069350170013 158 | 156,1:49:27,15700,0.038392423279583456,5031.094471628546,1000.0,42.78048222047682 159 | 157,1:50:09,15800,0.03773642376065254,5014.625519113661,1000.0,42.64783046751191 160 | 158,1:50:50,15900,0.037794077843427656,4719.574327764745,1000.0,40.27129451530464 161 | 159,1:51:32,16000,0.038214647024869916,4727.833671495152,1000.0,40.33782069073613 162 | 160,1:52:13,16100,0.03773270595818758,4901.818058044979,1000.0,41.739205134798326 163 | 161,1:52:55,16200,0.037771521359682085,4893.945859547697,1000.0,41.67579728117751 164 | 162,1:53:36,16300,0.03757071798667312,4914.0133628809335,1000.0,41.837434124345094 165 | 163,1:54:18,16400,0.03746144991368056,4630.994590840279,1000.0,39.5578151747345 166 | 164,1:54:59,16500,0.037571895979344835,5007.498571611704,1000.0,42.59042535455351 167 | 165,1:55:41,16600,0.03721294801682234,4984.72270244139,1000.0,42.40697355529603 168 | 166,1:56:22,16700,0.03721242653205991,5062.798494359099,1000.0,43.03584722850912 169 | 167,1:57:04,16800,0.036894718799740066,5068.453442868844,1000.0,43.081395895436536 170 | 168,1:57:45,16900,0.03663720076903701,5053.902733944525,1000.0,42.96419493539075 171 | 169,1:58:26,17000,0.03701364492997527,4560.249131206947,1000.0,38.98798480900919 172 | 170,1:59:08,17100,0.03619711451232433,5118.870060139568,1000.0,43.487484421921614 173 | 171,1:59:49,17200,0.036422716639935966,5033.526859668583,1000.0,42.800074270251095 174 | 172,2:00:30,17300,0.03665077719837427,5083.130096125875,1000.0,43.19961129380165 175 | 
173,2:01:11,17400,0.036483543422073134,4679.479829619917,1000.0,39.94834711119058 176 | 174,2:01:52,17500,0.036940774358808985,5043.993422684677,1000.0,42.88437883851965 177 | 175,2:02:33,17600,0.036740592364221815,4730.514509146328,1000.0,40.35941391674862 178 | 176,2:03:14,17700,0.03641395907849074,5101.537742726438,1000.0,43.347878561396016 179 | 177,2:03:55,17800,0.03628888323903084,4993.368048949124,1000.0,42.47660885044937 180 | 178,2:04:36,17900,0.03599242318421602,5049.112966638298,1000.0,42.92561500573883 181 | 179,2:05:17,18000,0.03585299048572778,4939.853148566753,1000.0,42.045564718222494 182 | 180,2:05:59,18100,0.03565149124711752,4768.133955127304,1000.0,40.66242562623256 183 | 181,2:06:40,18200,0.03583951679989696,4933.212411119048,1000.0,41.99207586016539 184 | 182,2:07:21,18300,0.035990918334573505,5139.14033255091,1000.0,43.65075450033193 185 | 183,2:08:02,18400,0.036003620736300944,5005.400543907548,1000.0,42.57352646238209 186 | 184,2:08:43,18500,0.03591528914868832,5026.982164561678,1000.0,42.74735900024426 187 | 185,2:09:24,18600,0.03534703999757767,4990.38433725781,1000.0,42.45257607812598 188 | 186,2:10:05,18700,0.03641320426017046,4613.6624282533285,1000.0,39.41821056128056 189 | 187,2:10:46,18800,0.03604630442336201,4927.089999470356,1000.0,41.942761938297096 190 | 188,2:11:27,18900,0.03492267739027738,5117.564069076466,1000.0,43.476965112711135 191 | 189,2:12:08,19000,0.035232288166880614,5050.627848288068,1000.0,42.93781685683985 192 | 190,2:12:49,19100,0.03540768364444375,4850.16958941188,1000.0,41.32319446891407 193 | 191,2:13:30,19200,0.035296422131359584,4961.913704649762,1000.0,42.2232549163785 194 | 192,2:14:11,19300,0.03506270909681916,5052.041123676718,1000.0,42.949200304424465 195 | 193,2:14:52,19400,0.03535914344713092,4995.390585774812,1000.0,42.492899689537104 196 | 194,2:15:33,19500,0.03518122298642993,5040.732719129157,1000.0,42.858114991918114 197 | 
195,2:16:14,19600,0.035614689849317066,5115.5204377188975,1000.0,43.46050436441821 198 | 196,2:16:55,19700,0.035280204694718126,5025.68273982372,1000.0,42.736892580526295 199 | 197,2:17:36,19800,0.034820137321948996,5009.366336443343,1000.0,42.605469558416466 200 | 198,2:18:17,19900,0.03540942640975118,5082.9670591427175,1000.0,43.19829808692985 201 | 199,2:18:58,20000,0.035335704796016214,4745.66566583327,1000.0,40.481451277178934 202 | -------------------------------------------------------------------------------- /dt_runs/dt_halfcheetah-medium-v2_log_22-02-11-10-13-57.csv: -------------------------------------------------------------------------------- 1 | ,duration,num_updates,action_loss,eval_avg_reward,eval_avg_ep_len,eval_d4rl_score 2 | 0,0:01:06,100,0.8700329530239105,-121.80672710369377,1000.0,1.2756338551047404 3 | 1,0:01:55,200,0.8118200784921646,-109.19730906133486,1000.0,1.3771983842194169 4 | 2,0:02:41,300,0.7109269291162491,-82.26911104421941,1000.0,1.594095765393359 5 | 3,0:03:26,400,0.5884400677680969,-4.079822484864903,1000.0,2.2238836150518684 6 | 4,0:04:13,500,0.4842221122980118,-15.613033185554528,1000.0,2.130987566236537 7 | 5,0:04:59,600,0.400865326821804,-6.327428477096876,1000.0,2.205779921172459 8 | 6,0:05:45,700,0.3345118749141693,99.66437970689944,1000.0,3.059507512093607 9 | 7,0:06:31,800,0.2857773834466934,218.6401961137984,1000.0,4.017816827306092 10 | 8,0:07:16,900,0.26098459869623186,115.26421627499157,1000.0,3.185158834778106 11 | 9,0:08:02,1000,0.2428016723692417,159.73046508778057,1000.0,3.5433191881739323 12 | 10,0:08:48,1100,0.23208357617259026,108.38575867088213,1000.0,3.129755222553513 13 | 11,0:09:34,1200,0.22262664556503292,-24.06234111174161,1000.0,2.0629312944890774 14 | 12,0:10:20,1300,0.20967158779501915,537.228384845339,1000.0,6.583935204959898 15 | 13,0:11:06,1400,0.20462087512016294,512.1442710633586,1000.0,6.381891288581882 16 | 14,0:11:51,1500,0.19489261478185652,696.3804419117578,1000.0,7.865850332151532 17 | 
15,0:12:37,1600,0.18769512981176376,701.6367819346145,1000.0,7.908188344698559 18 | 16,0:13:23,1700,0.181135138720274,614.8812061338224,1000.0,7.209401995108094 19 | 17,0:14:14,1800,0.17287482768297194,293.3487575784689,1000.0,4.619568616366032 20 | 18,0:15:00,1900,0.16703419044613838,208.815975106858,1000.0,3.938686103180957 21 | 19,0:15:46,2000,0.16230630606412888,437.099069511589,1000.0,5.777427979306459 22 | 20,0:16:32,2100,0.1581895300745964,406.94623828902496,1000.0,5.534557285805279 23 | 21,0:17:18,2200,0.1523415331542492,517.6283165840375,1000.0,6.426063390663053 24 | 22,0:18:04,2300,0.1482094943523407,643.6343841144256,1000.0,7.440998962735013 25 | 23,0:18:50,2400,0.14304509356617928,802.0734478965138,1000.0,8.71717117404094 26 | 24,0:19:35,2500,0.1396954370290041,299.6217051605537,1000.0,4.670095053446254 27 | 25,0:20:20,2600,0.13703133530914782,435.9919061549858,1000.0,5.768510158944832 28 | 26,0:21:11,2700,0.13186223536729813,460.341761482739,1000.0,5.964639875801384 29 | 27,0:22:00,2800,0.128665576800704,285.47366546149874,1000.0,4.5561374556330065 30 | 28,0:22:45,2900,0.12623580418527125,838.6609899947704,1000.0,9.011871252362528 31 | 29,0:23:31,3000,0.12277084104716778,755.9451416156412,1000.0,8.345623518904432 32 | 30,0:24:17,3100,0.1185512564331293,499.36311531152563,1000.0,6.278943471234921 33 | 31,0:25:02,3200,0.11923709459602833,748.8108618771047,1000.0,8.288159347300102 34 | 32,0:25:48,3300,0.1149604281783104,328.9523677163666,1000.0,4.90634346087437 35 | 33,0:26:34,3400,0.1117021717876196,685.4925558679765,1000.0,7.778152151682291 36 | 34,0:27:20,3500,0.1100981656461954,435.1015024832968,1000.0,5.761338263355894 37 | 35,0:28:05,3600,0.10786762557923794,1185.2172587071007,1000.0,11.803262903053064 38 | 36,0:28:50,3700,0.10584268748760224,1382.1392377020334,1000.0,13.389401771775114 39 | 37,0:29:35,3800,0.1032477230578661,1309.2801508001626,1000.0,12.802546864748061 40 | 
38,0:30:20,3900,0.10154466181993484,766.3927777381748,1000.0,8.429775637549561 41 | 39,0:31:05,4000,0.09906728699803352,953.9242684878557,1000.0,9.94027735048996 42 | 40,0:31:50,4100,0.0981192423403263,1478.772863072714,1000.0,14.167752416067117 43 | 41,0:32:37,4200,0.09676806017756462,1245.8103917034755,1000.0,12.291319766556695 44 | 42,0:33:24,4300,0.09422173857688904,2153.9294708737257,1000.0,19.60590687487069 45 | 43,0:34:09,4400,0.09204089924693107,1693.5069260502155,1000.0,15.897361500160208 46 | 44,0:34:54,4500,0.08984700180590152,2090.788263965116,1000.0,19.097326151647586 47 | 45,0:35:40,4600,0.08803952217102051,1762.2843602944085,1000.0,16.451340097686373 48 | 46,0:36:25,4700,0.0874203984439373,1522.795357693832,1000.0,14.52233848194481 49 | 47,0:37:10,4800,0.08497955657541753,3675.890344714028,1000.0,31.864778693005338 50 | 48,0:37:54,4900,0.08407509446144104,2270.22307076508,1000.0,20.54261185779204 51 | 49,0:38:39,5000,0.08185447856783867,3398.76647083986,1000.0,29.63264112234871 52 | 50,0:39:24,5100,0.0808236364275217,3075.74163317703,1000.0,27.030787062204254 53 | 51,0:40:09,5200,0.0806402637064457,2808.421261264034,1000.0,24.87761333087918 54 | 52,0:40:54,5300,0.07776038005948066,2981.782915606262,1000.0,26.27398188101061 55 | 53,0:41:39,5400,0.0765034330636263,2822.6197145636793,1000.0,24.991976992920584 56 | 54,0:42:24,5500,0.07644095495343209,3728.112803256548,1000.0,32.28541265035882 57 | 55,0:43:09,5600,0.07483840860426426,3535.413232080341,1000.0,30.733283825589503 58 | 56,0:43:54,5700,0.07348211094737052,3952.560642982157,1000.0,34.093262868026216 59 | 57,0:44:39,5800,0.07182960987091064,3426.0782472031783,1000.0,29.85262809528492 60 | 58,0:45:24,5900,0.07146185677498579,4045.8797328691107,1000.0,34.8449160680343 61 | 59,0:46:08,6000,0.07072196155786514,4289.77344783861,1000.0,36.809396128231626 62 | 60,0:46:53,6100,0.06940111685544252,4363.7702985637,1000.0,37.40541533186308 63 | 
61,0:47:38,6200,0.06800241976976394,4660.839823302788,1000.0,39.798208265929524 64 | 62,0:48:22,6300,0.06815596289932728,3840.978770130919,1000.0,33.19450922723175 65 | 63,0:49:07,6400,0.06644017066806555,3873.004201558808,1000.0,33.452463071869246 66 | 64,0:49:51,6500,0.06605536881834269,3973.8535921149683,1000.0,34.26477025598591 67 | 65,0:50:36,6600,0.06511681411415339,4650.667788618051,1000.0,39.716276022155625 68 | 66,0:51:20,6700,0.06448001816868781,4306.370799560644,1000.0,36.94308209268583 69 | 67,0:52:05,6800,0.06306074101477861,4364.931799731248,1000.0,37.41477082461872 70 | 68,0:52:49,6900,0.06270165368914604,4442.807638182052,1000.0,38.042033941369745 71 | 69,0:53:34,7000,0.06191851746290921,4305.2549175195945,1000.0,36.934094046317156 72 | 70,0:54:22,7100,0.06114271230995655,4339.582333553909,1000.0,37.21058958588422 73 | 71,0:55:08,7200,0.060012868233025075,4492.194022353522,1000.0,38.43982429427912 74 | 72,0:55:53,7300,0.059701225869357576,4924.086266034215,1000.0,41.918567897699596 75 | 73,0:56:42,7400,0.059487079307436935,4802.257108410965,1000.0,40.93727589953785 76 | 74,0:57:29,7500,0.058900052942335614,4890.381413161165,1000.0,41.647086890453174 77 | 75,0:58:26,7600,0.058106004483997824,4443.392153703881,1000.0,38.04674201302978 78 | 76,0:59:13,7700,0.057634426765143874,4830.309234873492,1000.0,41.16322613810246 79 | 77,1:00:01,7800,0.05696785513311625,4552.012855325621,1000.0,38.9216444371748 80 | 78,1:00:46,7900,0.05667423911392689,4767.914202565889,1000.0,40.66065559486815 81 | 79,1:01:31,8000,0.055956434085965165,4849.088897240892,1000.0,41.31448986485577 82 | 80,1:02:16,8100,0.054952321499586114,4488.358715058293,1000.0,38.40893221201635 83 | 81,1:03:00,8200,0.05520208325237036,4379.657968278004,1000.0,37.53338505162668 84 | 82,1:03:44,8300,0.0538373476266861,4905.361166695393,1000.0,41.76774365739094 85 | 83,1:04:29,8400,0.05378747802227735,4966.062574731258,1000.0,42.256672639129036 86 | 
84,1:05:13,8500,0.05277358137071133,4948.487363302918,1000.0,42.11511035078125 87 | 85,1:05:57,8600,0.052445767261087885,4609.869500327317,1000.0,39.387659830273215 88 | 86,1:06:41,8700,0.05244192194193602,4723.6756070466245,1000.0,40.30432891051879 89 | 87,1:07:25,8800,0.05154969844967127,4324.530594544666,1000.0,37.08935300068297 90 | 88,1:08:10,8900,0.05146684240549803,4943.560219094873,1000.0,42.07542389739465 91 | 89,1:08:54,9000,0.05206151738762856,4523.854971832925,1000.0,38.69484236207549 92 | 90,1:09:38,9100,0.05054484277963638,4982.915681953748,1000.0,42.392418626249224 93 | 91,1:10:23,9200,0.04981490548700094,5010.356986459847,1000.0,42.61344890386331 94 | 92,1:11:07,9300,0.04969214163720608,4672.340743983453,1000.0,39.89084422972999 95 | 93,1:11:51,9400,0.04987182389944792,5008.033696723017,1000.0,42.59473560342982 96 | 94,1:12:35,9500,0.04836590778082609,4935.176775166829,1000.0,42.00789813751811 97 | 95,1:13:19,9600,0.04833526697009802,4583.415372166562,1000.0,39.17458091887854 98 | 96,1:14:03,9700,0.04898699291050434,5008.0818133580515,1000.0,42.595123166389776 99 | 97,1:14:48,9800,0.04816077660769224,4987.342507733438,1000.0,42.42807518662945 100 | 98,1:15:32,9900,0.0475557267293334,4973.207958081004,1000.0,42.3142262465059 101 | 99,1:16:17,10000,0.04738544635474682,4971.898756765151,1000.0,42.30368107981271 102 | 100,1:17:01,10100,0.04685639087110758,4977.882782441533,1000.0,42.35188035023028 103 | 101,1:17:45,10200,0.04694512948393822,4955.456234760295,1000.0,42.17124221552326 104 | 102,1:18:30,10300,0.04603830847889185,4933.318023899008,1000.0,41.992926534814224 105 | 103,1:19:14,10400,0.046809835396707064,4592.819811436934,1000.0,39.25033044537327 106 | 104,1:19:59,10500,0.045678039342164994,4979.62448849184,1000.0,42.3659091939296 107 | 105,1:20:53,10600,0.045413269065320484,4973.648272668112,1000.0,42.31777282919131 108 | 106,1:21:38,10700,0.04516366574913264,4926.094846821152,1000.0,41.93474632569117 109 | 
107,1:22:35,10800,0.04529505651444197,4505.487870689983,1000.0,38.546901674208854 110 | 108,1:23:35,10900,0.04456073343753815,4960.988971283199,1000.0,42.21580650689473 111 | 109,1:24:36,11000,0.04477448154240846,4921.558586787377,1000.0,41.89820831000128 112 | 110,1:25:22,11100,0.044482230208814144,4991.347623599277,1000.0,42.4603350185739 113 | 111,1:26:06,11200,0.044093156792223455,4558.405590299538,1000.0,38.97313572053138 114 | 112,1:26:57,11300,0.04441542726010085,5022.616639578724,1000.0,42.71219619671577 115 | 113,1:27:48,11400,0.04355392377823592,5082.080303283854,1000.0,43.19115557321965 116 | 114,1:28:32,11500,0.04328299786895514,5057.556689376153,1000.0,42.993626290713635 117 | 115,1:29:17,11600,0.043763098679482935,5000.879221501868,1000.0,42.5371087641532 118 | 116,1:30:01,11700,0.042863015756011015,5086.484155719272,1000.0,43.22662709120655 119 | 117,1:30:46,11800,0.04321520932018757,4888.081964548471,1000.0,41.62856562208162 120 | 118,1:31:36,11900,0.042797876186668866,4664.780502467736,1000.0,39.82994908239189 121 | 119,1:32:20,12000,0.041930680684745314,4933.658390611499,1000.0,41.995668071716594 122 | 120,1:33:04,12100,0.04232073538005352,4578.305382098882,1000.0,39.13342170492739 123 | 121,1:33:49,12200,0.04164695870131254,4585.448972415133,1000.0,39.190960870035646 124 | 122,1:34:33,12300,0.042179904058575635,5011.640605161942,1000.0,42.62378801139414 125 | 123,1:35:18,12400,0.042020661868155,4983.778563586287,1000.0,42.39936884127078 126 | 124,1:36:02,12500,0.04149483133107424,4995.6510520769425,1000.0,42.49499765609171 127 | 125,1:36:47,12600,0.04112834580242634,4850.8898136523885,1000.0,41.32899562766689 128 | 126,1:37:32,12700,0.041288108788430686,5020.629288755573,1000.0,42.696188768786826 129 | 127,1:38:22,12800,0.041735970936715605,4985.045297504776,1000.0,42.40957194767207 130 | 128,1:39:31,12900,0.04124053660780192,4087.712690278192,1000.0,35.18186616410177 131 | 
129,1:40:26,13000,0.04070476491004229,4986.595988716682,1000.0,42.422062232490156 132 | 130,1:41:29,13100,0.0401571773737669,5028.882061541628,1000.0,42.762662017519666 133 | 131,1:42:15,13200,0.04069599751383066,5036.598494126447,1000.0,42.82481523024444 134 | 132,1:42:59,13300,0.04060786869376898,4876.276220206702,1000.0,41.53347440844337 135 | 133,1:43:43,13400,0.039769974350929264,5047.792322000095,1000.0,42.91497766701659 136 | 134,1:44:29,13500,0.04009150326251984,5025.1708440971115,1000.0,42.732769436361025 137 | 135,1:45:13,13600,0.040570413246750835,4977.78461838378,1000.0,42.3510896724791 138 | 136,1:45:58,13700,0.04013428058475256,5028.765526943515,1000.0,42.7617233713789 139 | 137,1:46:42,13800,0.03964978367090225,4963.067308546302,1000.0,42.23254679933007 140 | 138,1:47:34,13900,0.0391604271531105,4620.904499332255,1000.0,39.47654295508933 141 | 139,1:48:29,14000,0.039104795940220365,5070.933873445447,1000.0,43.101374911333075 142 | 140,1:49:13,14100,0.03904504384845496,4983.0647414185,1000.0,42.39361924901364 143 | 141,1:50:05,14200,0.039295772314071664,5046.4863987023655,1000.0,42.904458903632886 144 | 142,1:50:49,14300,0.03873799603432417,4903.894584282997,1000.0,41.755930840048975 145 | 143,1:51:34,14400,0.03885609772056341,4844.691999418566,1000.0,41.27907436388738 146 | 144,1:52:34,14500,0.03889308143407107,5037.838851113429,1000.0,42.83480587952688 147 | 145,1:53:31,14600,0.038433850705623634,5011.426923912381,1000.0,42.62206688236031 148 | 146,1:54:31,14700,0.03825357872992754,4926.728490050982,1000.0,41.93985010415646 149 | 147,1:55:17,14800,0.03919964246451855,5027.718207619962,1000.0,42.753287574138135 150 | 148,1:56:02,14900,0.03823530346155166,5036.715908273308,1000.0,42.825760960848115 151 | 149,1:56:46,15000,0.03864511232823133,4981.209879272717,1000.0,42.37867897185128 152 | 150,1:57:30,15100,0.038505403622984886,4506.23035443601,1000.0,38.55288212562915 153 | 
151,1:58:14,15200,0.03796029964461923,5057.736094401394,1000.0,42.99507133653954 154 | 152,1:58:59,15300,0.038172388933599,5065.133919028413,1000.0,43.054658271653615 155 | 153,1:59:43,15400,0.03762233829125762,4972.431365422457,1000.0,42.307971059516774 156 | 154,2:00:27,15500,0.0374450365267694,5049.750634738626,1000.0,42.930751203152845 157 | 155,2:01:15,15600,0.037612152472138415,4945.13921178919,1000.0,42.08814214092778 158 | 156,2:02:00,15700,0.03784166596829891,4987.82383332195,1000.0,42.43195209883777 159 | 157,2:02:44,15800,0.03797922693192959,4926.011041817995,1000.0,41.934071305190265 160 | 158,2:03:28,15900,0.03733926512300968,4969.351956693261,1000.0,42.28316748043946 161 | 159,2:04:12,16000,0.0375642435066402,5042.578550433512,1000.0,42.872982528756225 162 | 160,2:04:57,16100,0.03728803128004074,5037.89093759949,1000.0,42.835225418272636 163 | 161,2:05:41,16200,0.03707937855273485,4982.512792422401,1000.0,42.3891734895269 164 | 162,2:06:30,16300,0.03712301319465041,5015.132159367889,1000.0,42.65191128065319 165 | 163,2:07:14,16400,0.037342985048890116,5042.71360391316,1000.0,42.87407033812378 166 | 164,2:07:59,16500,0.03713747594505548,5058.4736787599595,1000.0,43.001012325077504 167 | 165,2:08:43,16600,0.037134092897176736,4943.997510094255,1000.0,42.07894612611996 168 | 166,2:09:30,16700,0.03660496266558767,5066.035417948929,1000.0,43.06191953565897 169 | 167,2:10:14,16800,0.03667282532900572,4969.980304358327,1000.0,42.28822860495039 170 | 168,2:10:59,16900,0.0366557383351028,5054.1090716712715,1000.0,42.9658569148719 171 | 169,2:11:44,17000,0.03669415656477213,5029.388508371176,1000.0,42.766741272691625 172 | 170,2:12:29,17100,0.03657138550654054,4316.501373967916,1000.0,37.024680388172534 173 | 171,2:13:14,17200,0.03623724704608322,4814.550051751816,1000.0,41.036291333688155 174 | 172,2:13:59,17300,0.03628178855404258,4975.885088840287,1000.0,42.33578961477807 175 | 173,2:14:44,17400,0.03619858954101801,5029.236041249615,1000.0,42.765513202422646 
176 | 174,2:15:28,17500,0.03632583538070321,4996.298260461638,1000.0,42.50021069721779 177 | 175,2:16:13,17600,0.036661872956901785,4903.088064425253,1000.0,41.74943459975476 178 | 176,2:16:57,17700,0.03617170359939337,5057.494741891011,1000.0,42.99312732500901 179 | 177,2:17:41,17800,0.036089637018740184,4805.98158648173,1000.0,40.967275290483926 180 | 178,2:18:25,17900,0.036258722897619014,5046.858014355642,1000.0,42.90745214001461 181 | 179,2:19:10,18000,0.035698072351515286,5001.3257503794875,1000.0,42.54070540081315 182 | 180,2:19:55,18100,0.0361002010665834,5023.07118033685,1000.0,42.7158573663199 183 | 181,2:20:40,18200,0.03595602897927165,4983.941378794191,1000.0,42.40068026181911 184 | 182,2:21:24,18300,0.035628468245267865,4518.850223100737,1000.0,38.65453083091565 185 | 183,2:22:09,18400,0.03627290891483426,4990.355398085396,1000.0,42.45234298303711 186 | 184,2:22:53,18500,0.035719001572579144,5024.859239608086,1000.0,42.73025956928456 187 | 185,2:23:37,18600,0.0362304119579494,4972.808167484209,1000.0,42.31100607063645 188 | 186,2:24:21,18700,0.03542027980089188,5118.644368591355,1000.0,43.4856665540595 189 | 187,2:25:05,18800,0.035339906495064494,4733.043416955865,1000.0,40.379783400097274 190 | 188,2:25:49,18900,0.03496069768443704,5059.447960055046,1000.0,43.00885982609843 191 | 189,2:26:33,19000,0.03549280324950814,4760.856483377038,1000.0,40.603808092181566 192 | 190,2:27:17,19100,0.035504854600876566,4960.891927673846,1000.0,42.21502485397035 193 | 191,2:28:01,19200,0.035393788777291775,5069.055172730534,1000.0,43.08624262268847 194 | 192,2:28:46,19300,0.034903209898620835,5060.537895874748,1000.0,43.01763888457055 195 | 193,2:29:33,19400,0.034767945259809495,4999.894903099978,1000.0,42.52918041768622 196 | 194,2:30:18,19500,0.03509797116741538,4924.294118905878,1000.0,41.92024208115237 197 | 195,2:31:02,19600,0.0348175454698503,5078.420989239588,1000.0,43.16168105611347 198 | 
196,2:31:46,19700,0.03528429336845875,4988.675373093934,1000.0,42.43881095907013 199 | 197,2:32:30,19800,0.03462814662605524,4973.116132261546,1000.0,42.31348662108604 200 | 198,2:33:13,19900,0.034336135070770976,5006.585915688245,1000.0,42.58307422472354 201 | 199,2:33:57,20000,0.03477673007175326,5049.198323661414,1000.0,42.92630252722717 202 | -------------------------------------------------------------------------------- /dt_runs/dt_halfcheetah-medium-v2_log_22-02-13-09-03-10.csv: -------------------------------------------------------------------------------- 1 | duration,num_updates,action_loss,eval_avg_reward,eval_avg_ep_len,eval_d4rl_score 2 | 0:00:32,100,0.8343127316236496,-283.66523900445765,1000.0,-0.028080835706482045 3 | 0:01:01,200,0.7824837225675583,-244.96124936879588,1000.0,0.2836665002133868 4 | 0:01:29,300,0.6899479562044144,-223.5892289632806,1000.0,0.45581078010192566 5 | 0:01:56,400,0.580738832950592,-216.8118383607401,1000.0,0.5104003323604761 6 | 0:02:24,500,0.48492678731679917,-201.96499517297917,1000.0,0.6299865521319868 7 | 0:02:52,600,0.40410026252269743,-147.01866796690823,1000.0,1.0725603355150586 8 | 0:03:20,700,0.33457408487796786,-14.120819439167178,1000.0,2.1430068351656146 9 | 0:03:48,800,0.2859437969326973,182.63243303823924,1000.0,3.727786669771728 10 | 0:04:16,900,0.25969662472605703,475.19923757789303,1000.0,6.084311739987958 11 | 0:04:44,1000,0.24299988597631456,629.2006983996813,1000.0,7.3247405844274125 12 | 0:05:11,1100,0.23253118202090264,569.6996996773275,1000.0,6.845480487189942 13 | 0:05:39,1200,0.22182855844497681,452.462340135923,1000.0,5.901173844609688 14 | 0:06:07,1300,0.2125807870924473,348.32490145729537,1000.0,5.062382562801673 15 | 0:06:34,1400,0.20150695189833642,689.205723017006,1000.0,7.808060437040773 16 | 0:07:02,1500,0.1956814283132553,335.86558326199963,1000.0,4.962027036373397 17 | 0:07:30,1600,0.1875275845825672,-116.17907595078893,1000.0,1.3209626512035266 18 | 
0:07:57,1700,0.17884221836924552,329.50735444612803,1000.0,4.9108136882618485 19 | 0:08:25,1800,0.17367472559213637,-32.244035667528976,1000.0,1.9970305564750646 20 | 0:08:53,1900,0.16676318034529686,191.20217088657574,1000.0,3.7968129631564542 21 | 0:09:21,2000,0.16205843642354012,84.21929158046684,1000.0,2.935102634927495 22 | 0:09:48,2100,0.15619607627391816,-9.672275962283926,1000.0,2.178838324133466 23 | 0:10:16,2200,0.15302484899759292,194.5000749255822,1000.0,3.82337644686854 24 | 0:10:44,2300,0.14815503224730492,102.74687889403435,1000.0,3.084335983747574 25 | 0:11:12,2400,0.1451341937482357,306.2623822153468,1000.0,4.7235834250592035 26 | 0:11:40,2500,0.14208106234669685,159.5822044553617,1000.0,3.542124999729447 27 | 0:12:08,2600,0.1364733275771141,365.5016671358411,1000.0,5.200735507560437 28 | 0:12:35,2700,0.1341610546410084,614.7893804357589,1000.0,7.208662370666022 29 | 0:13:03,2800,0.13020548678934574,613.4986256114802,1000.0,7.19826578412333 30 | 0:13:31,2900,0.1287697058916092,370.6574310148586,1000.0,5.242263413831749 31 | 0:13:59,3000,0.124820379242301,433.16134599095886,1000.0,5.745710969543355 32 | 0:14:26,3100,0.12218771293759347,450.6340751042693,1000.0,5.886447798061548 33 | 0:14:54,3200,0.11939074948430062,796.4805059344549,1000.0,8.672121948546632 34 | 0:15:22,3300,0.11600389406085014,1193.0977795900485,1000.0,11.866737790630449 35 | 0:15:50,3400,0.11494396232068539,1015.5936039971981,1000.0,10.437002655399406 36 | 0:16:18,3500,0.11073567926883697,724.1555029793099,1000.0,8.089568904172925 37 | 0:16:45,3600,0.10951775871217251,747.674484073056,1000.0,8.279006214603822 38 | 0:17:12,3700,0.10818672023713588,881.2691385694367,1000.0,9.355065246874952 39 | 0:17:40,3800,0.10504961468279361,955.5763756845479,1000.0,9.953584506213986 40 | 0:18:07,3900,0.10323845706880093,1019.9283368328057,1000.0,10.471917438762718 41 | 0:18:35,4000,0.10089175820350647,935.8708019209286,1000.0,9.79486288135285 42 | 
0:19:03,4100,0.09826306015253067,1827.6024771316224,1000.0,16.97745508229101 43 | 0:19:30,4200,0.09640958987176418,1447.657659425143,1000.0,13.917130143401025 44 | 0:19:58,4300,0.0958897591382265,2006.4970445691422,1000.0,18.418389346023805 45 | 0:20:25,4400,0.09307771861553192,1127.0226716684895,1000.0,11.334525502980796 46 | 0:20:53,4500,0.09156623221933842,2080.573035939388,1000.0,19.01504600035537 47 | 0:21:20,4600,0.0904998654127121,1822.9949394166877,1000.0,16.940342949373896 48 | 0:21:48,4700,0.08837514266371727,2692.5022221247937,1000.0,23.943925306098595 49 | 0:22:15,4800,0.08788972571492196,3168.6558342953076,1000.0,27.779179022320356 50 | 0:22:43,4900,0.08491212002933025,2888.8516330442035,1000.0,25.525452335734876 51 | 0:23:10,5000,0.08403476357460021,2901.343631568555,1000.0,25.626071090983128 52 | 0:23:38,5100,0.08166310600936413,2303.3406405236947,1000.0,20.809362501371066 53 | 0:24:05,5200,0.08117769435048103,2749.9052385982727,1000.0,24.40628687729132 54 | 0:24:33,5300,0.07957220941781998,2087.8550107536184,1000.0,19.073699805039116 55 | 0:25:00,5400,0.07859855450689793,1671.6796678280466,1000.0,15.721550436100642 56 | 0:25:28,5500,0.07715587168931962,3002.668240415789,1000.0,26.44220599512601 57 | 0:25:55,5600,0.07579587303102016,2635.3030731989033,1000.0,23.48320581794278 58 | 0:26:22,5700,0.07472566425800324,3569.7755006297325,1000.0,31.010060090188475 59 | 0:26:49,5800,0.07406453892588616,3060.238724509674,1000.0,26.90591646045099 60 | 0:27:17,5900,0.0722687716037035,3219.495055628395,1000.0,28.188671479300226 61 | 0:27:45,6000,0.07141946729272604,3487.2826755695814,1000.0,30.345608732922962 62 | 0:28:12,6100,0.07084653932601213,3702.9596814002493,1000.0,32.082812897656744 63 | 0:28:40,6200,0.06988365158438682,2778.5651147908175,1000.0,24.637132331078508 64 | 0:29:07,6300,0.0692059276252985,3436.4853007610777,1000.0,29.936453335318085 65 | 0:29:35,6400,0.06818865921348333,3897.8833213520447,1000.0,33.65285583211395 66 | 
0:30:02,6500,0.0675203338265419,3447.146254021153,1000.0,30.022323650199844 67 | 0:30:29,6600,0.06643117092549801,3973.9055144264566,1000.0,34.2651884723619 68 | 0:30:57,6700,0.06570949103683234,3767.6232802232166,1000.0,32.60365596458122 69 | 0:31:24,6800,0.064666579477489,3181.57555991698,1000.0,27.883242972349446 70 | 0:31:52,6900,0.06373731590807438,3868.5388173446618,1000.0,33.41649593655004 71 | 0:32:19,7000,0.0633472204953432,4110.952949774364,1000.0,35.36905846784666 72 | 0:32:46,7100,0.0628714295476675,4343.006209745128,1000.0,37.23816773199216 73 | 0:33:14,7200,0.06141019638627768,4183.138998974191,1000.0,35.95049228747263 74 | 0:33:41,7300,0.0611110120266676,4386.341461018108,1000.0,37.58721828887123 75 | 0:34:08,7400,0.05998071689158678,4293.16294911807,1000.0,36.83669739623824 76 | 0:34:36,7500,0.059286071583628655,3835.5393990237185,1000.0,33.15069696219882 77 | 0:35:03,7600,0.05904442846775055,4758.783252884457,1000.0,40.58710893302786 78 | 0:35:30,7700,0.058288184367120265,4791.266003302971,1000.0,40.84874632497752 79 | 0:35:58,7800,0.057520129643380644,4516.280527587427,1000.0,38.63383281662977 80 | 0:36:26,7900,0.05670076467096805,3928.852683739349,1000.0,33.90230340354683 81 | 0:36:54,8000,0.05614825654774904,4894.178204122014,1000.0,41.6776687368786 82 | 0:37:21,8100,0.05526526283472776,4801.8575294707625,1000.0,40.93405742848951 83 | 0:37:48,8200,0.055937583334743975,4971.030745190484,1000.0,42.296689544870254 84 | 0:38:16,8300,0.05533768724650145,4517.420802036202,1000.0,38.64301733546024 85 | 0:38:43,8400,0.05482041105628013,4568.387234987972,1000.0,39.05353443831243 86 | 0:39:11,8500,0.05349003326147795,4954.452425340134,1000.0,42.163156875602176 87 | 0:39:39,8600,0.05391915164887905,4536.235045870422,1000.0,38.79455960404489 88 | 0:40:06,8700,0.05321140967309475,4689.832019547389,1000.0,40.0317304435345 89 | 0:40:33,8800,0.052513382509350774,4930.276006628296,1000.0,41.96842413108546 90 | 
0:41:01,8900,0.05192779049277305,4721.438297757164,1000.0,40.28630815304176 91 | 0:41:28,9000,0.05203000195324421,4507.127305588854,1000.0,38.56010675892876 92 | 0:41:56,9100,0.051698968410491944,4629.809239048805,1000.0,39.54826757339938 93 | 0:42:23,9200,0.050326685421168804,5004.946449897966,1000.0,42.56986889118396 94 | 0:42:50,9300,0.05075975686311722,5050.614239495508,1000.0,42.93770724269243 95 | 0:43:18,9400,0.04968950662761926,5006.24961323108,1000.0,42.580365424001144 96 | 0:43:45,9500,0.05019066274166107,4821.059415216133,1000.0,41.08872201945563 97 | 0:44:13,9600,0.049082524850964544,3869.6482726009185,1000.0,33.42543221737577 98 | 0:44:40,9700,0.04844164222478867,4562.268896413901,1000.0,39.004253323660485 99 | 0:45:08,9800,0.04828441377729178,4986.46689679914,1000.0,42.42102244145671 100 | 0:45:35,9900,0.04819107871502638,3986.7302325314713,1000.0,34.36848716949357 101 | 0:46:03,10000,0.048519133292138576,5001.973166481112,1000.0,42.54592011502769 102 | 0:46:31,10100,0.047430793009698394,4957.43044064356,1000.0,42.18714376547867 103 | 0:46:58,10200,0.04775850534439087,4603.103336939897,1000.0,39.33316070937424 104 | 0:47:25,10300,0.04665050383657217,4847.717344550916,1000.0,41.303442479271006 105 | 0:47:53,10400,0.0467537971585989,4988.290963540482,1000.0,42.43571467221912 106 | 0:48:20,10500,0.04597364723682404,4673.892833131505,1000.0,39.90334577444334 107 | 0:48:48,10600,0.046645872555673124,4875.320244979566,1000.0,41.52577435650891 108 | 0:49:15,10700,0.04584549762308598,4951.737243685532,1000.0,42.14128702044438 109 | 0:49:42,10800,0.04524834435433149,4594.798922818796,1000.0,39.26627150743411 110 | 0:50:10,10900,0.04502670329064131,4871.327594896747,1000.0,41.493614932162835 111 | 0:50:37,11000,0.044666882269084454,4614.337020678764,1000.0,39.423644171444295 112 | 0:51:04,11100,0.043923421390354635,4998.285700320959,1000.0,42.516218842302486 113 | 0:51:32,11200,0.04455430325120688,5035.340093043106,1000.0,42.81467924196667 114 | 
0:51:59,11300,0.04447721507400274,4780.765633393018,1000.0,40.764169453796654 115 | 0:52:26,11400,0.044197061099112034,5034.301049956632,1000.0,42.80631010697145 116 | 0:52:54,11500,0.04418921891599894,5070.789745484528,1000.0,43.100214010137336 117 | 0:53:22,11600,0.04386120047420263,4816.827688628107,1000.0,41.05463691601858 118 | 0:53:54,11700,0.042983717061579226,4911.465209847867,1000.0,41.816909627334525 119 | 0:54:22,11800,0.04287682596594095,4707.078195107113,1000.0,40.1706424610335 120 | 0:54:50,11900,0.04314782101660967,4913.4734033647155,1000.0,41.83308493599863 121 | 0:55:17,12000,0.04271706335246563,4584.5008003828425,1000.0,39.183323670154124 122 | 0:55:45,12100,0.04293158460408449,4979.113313776282,1000.0,42.36179185726056 123 | 0:56:12,12200,0.04302726477384567,5014.386179051639,1000.0,42.64590266556135 124 | 0:56:40,12300,0.042482958771288395,5040.012796780235,1000.0,42.8523162647983 125 | 0:57:07,12400,0.04155253190547228,4979.475599168009,1000.0,42.36470994159183 126 | 0:57:35,12500,0.0420973202958703,4294.676495387038,1000.0,36.8488884913058 127 | 0:58:02,12600,0.04198410116136074,4997.232224770008,1000.0,42.507733458765614 128 | 0:58:29,12700,0.04155763871967792,4935.979828512483,1000.0,42.01436645624872 129 | 0:58:57,12800,0.04128145672380924,4996.712507353155,1000.0,42.50354731357334 130 | 0:59:24,12900,0.040988856367766856,5023.113322071553,1000.0,42.71619680351096 131 | 0:59:52,13000,0.04128929030150175,4856.518105860778,1000.0,41.37432958724729 132 | 1:00:20,13100,0.040933935754001144,5038.606195946469,1000.0,42.84098657845958 133 | 1:00:47,13200,0.0411022624000907,5048.922612630516,1000.0,42.924081769621154 134 | 1:01:15,13300,0.04053767930716276,4964.935925045971,1000.0,42.24759786308632 135 | 1:01:42,13400,0.04094157513231039,5037.990990876015,1000.0,42.8360313130318 136 | 1:02:11,13500,0.03950948935002088,5021.456251108577,1000.0,42.70284966635532 137 | 1:02:39,13600,0.04019030898809433,5003.414024585403,1000.0,42.557525731908015 138 | 
1:03:06,13700,0.03986798405647278,5021.584293934346,1000.0,42.70388100731507 139 | 1:03:34,13800,0.040183973275125025,4977.780171181775,1000.0,42.35105385179521 140 | 1:04:02,13900,0.0401476414874196,4800.162977717427,1000.0,40.92040839644775 141 | 1:04:29,14000,0.0392570922523737,4992.423403289102,1000.0,42.46900005428461 142 | 1:04:57,14100,0.039165541715919974,4650.809752698118,1000.0,39.71741949403472 143 | 1:05:25,14200,0.0393249961733818,4995.184686342919,1000.0,42.49124124037037 144 | 1:05:53,14300,0.03923071805387735,5006.705223721909,1000.0,42.584035209934584 145 | 1:06:20,14400,0.038980237133800985,5017.845371369805,1000.0,42.673765270935476 146 | 1:06:58,14500,0.03978114198893309,5014.477799603504,1000.0,42.64664063762129 147 | 1:07:33,14600,0.03833754047751427,5004.69620201773,1000.0,42.567853230506145 148 | 1:08:06,14700,0.03874169554561377,4531.512498021814,1000.0,38.75652110402418 149 | 1:08:37,14800,0.038895750157535075,4984.017308438135,1000.0,42.40129184900791 150 | 1:09:04,14900,0.03917961470782757,4972.725327620375,1000.0,42.310338824001114 151 | 1:09:32,15000,0.03843617357313633,5055.530138036226,1000.0,42.977303116093275 152 | 1:09:59,15100,0.03888559389859438,4789.969437683505,1000.0,40.838302934476474 153 | 1:10:28,15200,0.03848989751189947,5075.431954428535,1000.0,43.137605407889886 154 | 1:10:56,15300,0.038144539669156076,5020.3384337454445,1000.0,42.69384603163234 155 | 1:11:27,15400,0.038408829756081106,4972.241101543641,1000.0,42.30643854935693 156 | 1:11:55,15500,0.0383669700846076,4949.365308284581,1000.0,42.122181895903445 157 | 1:12:29,15600,0.03871394880115986,4942.25383327071,1000.0,42.06490140851947 158 | 1:12:57,15700,0.03779404971748591,5048.171733012607,1000.0,42.91803369233808 159 | 1:13:25,15800,0.037720352932810786,4572.81975800202,1000.0,39.089236887957554 160 | 1:13:53,15900,0.03757006943225861,4980.890845911657,1000.0,42.3761092677635 161 | 1:14:22,16000,0.03759148234501481,4945.986355676252,1000.0,42.09496559381774 162 
| 1:14:49,16100,0.038014373816549776,4985.164891241461,1000.0,42.41053523412278 163 | 1:15:17,16200,0.03765553265810013,4906.820662878059,1000.0,41.77949939758761 164 | 1:15:45,16300,0.03673014665022492,5065.592936818584,1000.0,43.058355502212336 165 | 1:16:13,16400,0.03692703541368246,4878.639632701242,1000.0,41.552510883902045 166 | 1:16:40,16500,0.03718964777886868,4945.421393651169,1000.0,42.09041501885444 167 | 1:17:08,16600,0.03697293806821108,4943.902495178041,1000.0,42.078180813621664 168 | 1:17:35,16700,0.03740268375724554,4565.880962297828,1000.0,39.033347273071946 169 | 1:18:03,16800,0.0370396657846868,4953.883857657799,1000.0,42.158577258308796 170 | 1:18:31,16900,0.036925401091575626,5073.003081053843,1000.0,43.11804166753715 171 | 1:18:59,17000,0.036611346248537305,5060.83477878757,1000.0,43.020030174409754 172 | 1:19:27,17100,0.03681207664310932,4977.750882958058,1000.0,42.35081794521805 173 | 1:19:54,17200,0.036499656718224284,4736.327397360794,1000.0,40.40623473372171 174 | 1:20:22,17300,0.03616801258176565,5005.126819117355,1000.0,42.5713217032624 175 | 1:20:50,17400,0.0361792298592627,5022.783614270801,1000.0,42.71354112047974 176 | 1:21:17,17500,0.03657400170341134,5083.604758117801,1000.0,43.20343453302941 177 | 1:21:45,17600,0.036434238832443955,5014.755789514096,1000.0,42.64887975082009 178 | 1:22:13,17700,0.036985368244349955,5046.186237262444,1000.0,42.90204120638456 179 | 1:22:40,17800,0.03656367681920528,4808.784758801466,1000.0,40.98985388021145 180 | 1:23:08,17900,0.03556339886039495,5022.534675220311,1000.0,42.711536001975745 181 | 1:23:36,18000,0.035886919256299735,4949.916111626013,1000.0,42.12661842753555 182 | 1:24:03,18100,0.03538292992860079,4865.348617973019,1000.0,41.4454563277129 183 | 1:24:31,18200,0.03578974638134241,5050.283902591207,1000.0,42.9350464924483 184 | 1:24:59,18300,0.036011181976646184,4691.842198740683,1000.0,40.04792174613999 185 | 1:25:26,18400,0.036012686621397734,5040.916527147083,1000.0,42.85959550233704 
186 | 1:25:54,18500,0.0357749555259943,5034.609614536475,1000.0,42.80879548862412 187 | 1:26:21,18600,0.03564588150009513,4968.540989080966,1000.0,42.2766354150108 188 | 1:26:49,18700,0.03614397956058383,4943.274751832965,1000.0,42.07312455670058 189 | 1:27:17,18800,0.03586820861324668,5059.121184251277,1000.0,43.00622775929532 190 | 1:27:45,18900,0.03573339950293303,5022.8939832748765,1000.0,42.71443010487934 191 | 1:28:12,19000,0.035496701188385486,5035.4244684925,1000.0,42.815358857215976 192 | 1:28:40,19100,0.0353499143384397,4956.2236838437,1000.0,42.17742375415682 193 | 1:29:07,19200,0.03541418816894293,5045.210010868442,1000.0,42.8941780382603 194 | 1:29:35,19300,0.03539438188076019,5089.0368458593575,1000.0,43.24718813305499 195 | 1:30:02,19400,0.03515978574752807,4972.192110553764,1000.0,42.30604394376919 196 | 1:30:30,19500,0.03523656824603677,5065.442904731123,1000.0,43.05714704530625 197 | 1:30:57,19600,0.03505357289686799,4987.630544921891,1000.0,42.4303952271987 198 | 1:31:25,19700,0.03516284739598632,5048.201511495101,1000.0,42.91827354778122 199 | 1:31:52,19800,0.03521464221179485,5000.32141739392,1000.0,42.53261584375262 200 | 1:32:20,19900,0.03504663400352001,5068.030420065255,1000.0,43.07798859212509 201 | 1:32:47,20000,0.035235555991530415,4987.542736224073,1000.0,42.42968795831318 202 | -------------------------------------------------------------------------------- /dt_runs/dt_halfcheetah-medium-v2_model_22-02-13-09-03-10_best.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nikhilbarhate99/min-decision-transformer/d6694248b48c57c84fc7487e6e8017dcca861b02/dt_runs/dt_halfcheetah-medium-v2_model_22-02-13-09-03-10_best.pt -------------------------------------------------------------------------------- /dt_runs/dt_hopper-medium-v2_log_22-02-12-09-43-59.csv: -------------------------------------------------------------------------------- 1 | 
duration,num_updates,action_loss,eval_avg_reward,eval_avg_ep_len,eval_d4rl_score 2 | 0:00:50,100,0.648569563627243,9.791683762848228,11.8,0.9237462236331611 3 | 0:01:28,200,0.607141290307045,10.96003875225573,12.6,0.9596451031145774 4 | 0:02:06,300,0.5481627225875855,14.825391533019546,15.3,1.0784119461441661 5 | 0:02:46,400,0.48277618080377577,71.08845254559778,45.8,2.8071509551421006 6 | 0:03:28,500,0.4129638901352882,204.7435669163048,109.6,6.913838465675286 7 | 0:04:12,600,0.36223515719175337,354.63525885582237,168.1,11.519411115243985 8 | 0:04:59,700,0.3224516436457634,594.2275191299143,246.2,18.881123740463778 9 | 0:05:47,800,0.2943602591753006,614.9085918051827,273.8,19.516570451649027 10 | 0:06:32,900,0.27422395676374434,377.9920564021816,174.2,12.237072158155097 11 | 0:07:18,1000,0.2589342734217644,421.65785246476787,179.5,13.57874756034243 12 | 0:08:04,1100,0.2483059647679329,616.1879212801271,230.3,19.555879133560286 13 | 0:08:49,1200,0.23774641558527945,389.59061932566686,183.8,12.593449643014365 14 | 0:09:35,1300,0.22756794542074205,514.4769711184019,205.5,16.430708123978885 15 | 0:10:23,1400,0.21979302570223808,806.700850504056,283.5,25.40958006167437 16 | 0:11:12,1500,0.2124730460345745,818.3928069437771,280.1,25.768827156039386 17 | 0:12:00,1600,0.20533964157104492,814.1046868625456,281.3,25.637070363460413 18 | 0:12:48,1700,0.1989651092886925,788.8404227433434,279.5,24.86080049597618 19 | 0:13:34,1800,0.1932859268784523,623.3517816734192,229.9,19.775995933002296 20 | 0:14:20,1900,0.18721728697419165,604.2032610287299,230.0,19.18763842085634 21 | 0:15:06,2000,0.18397217571735383,642.4768243825004,236.6,20.363632062016833 22 | 0:15:56,2100,0.18052443176507948,918.1469595537612,307.9,28.833873597217902 23 | 0:16:45,2200,0.175783871114254,893.0285447408984,302.5,28.062085096029175 24 | 0:17:34,2300,0.17322211995720863,836.5358952911974,287.4,26.32629175191108 25 | 0:18:24,2400,0.17006334871053697,1092.0053090515057,349.0,34.17584585055042 26 | 
0:19:16,2500,0.16640724003314972,1150.2123894071196,366.1,35.96431680466597 27 | 0:20:04,2600,0.16354369193315507,900.2408471609036,306.1,28.28369032535301 28 | 0:20:55,2700,0.16002631694078445,1080.80069481592,346.6,33.831572834450206 29 | 0:21:50,2800,0.1577116085588932,1501.9704918554128,474.2,46.77243748792401 30 | 0:22:39,2900,0.15638397857546807,902.4929845285993,307.1,28.352889505971486 31 | 0:23:29,3000,0.15354946807026862,1038.9899254449783,341.3,32.546894988863315 32 | 0:24:30,3100,0.1524072003364563,1218.1978591448742,382.5,38.05323858505808 33 | 0:25:23,3200,0.15003470748662948,1447.5262594953404,444.9,45.09958381444963 34 | 0:26:16,3300,0.1457542099058628,1208.6833328336916,385.5,37.7608952164205 35 | 0:27:04,3400,0.14503645561635495,867.3044579610547,293.6,27.27168671586956 36 | 0:27:55,3500,0.1434373440593481,1113.4831141617656,358.3,34.83577296531335 37 | 0:28:47,3600,0.14136259764432907,1268.2816393395283,400.5,39.59211299008237 38 | 0:29:42,3700,0.13946433179080486,1551.0203026384331,495.0,48.27954214519849 39 | 0:30:47,3800,0.13796746104955673,1359.1061573845432,424.6,42.382787448458394 40 | 0:31:42,3900,0.13660455361008644,1337.3335869270675,423.7,41.713803372608346 41 | 0:32:35,4000,0.13492576733231545,1260.5131397154366,403.9,39.35341804352497 42 | 0:33:31,4100,0.1337595248222351,1602.806883457774,491.9,49.870736808158696 43 | 0:34:31,4200,0.1327200772613287,1919.7463622657567,608.1,59.60902033994776 44 | 0:35:28,4300,0.13013287082314492,1718.7765705236027,538.6,53.4340218176103 45 | 0:36:28,4400,0.13039979845285415,1956.564913645633,619.7,60.74030727812123 46 | 0:37:28,4500,0.12748417481780053,1900.1402898916415,604.5,59.00660409176687 47 | 0:38:28,4600,0.12566563069820405,1937.0495269455532,612.3,60.1406774382772 48 | 0:39:27,4700,0.12618642933666707,1619.0080505571393,501.5,50.36853392498647 49 | 0:40:27,4800,0.12464966714382171,1835.7138398829109,574.2,57.02703676398766 50 | 
0:41:23,4900,0.12264418751001357,1636.013629741877,510.4,50.89104741035633 51 | 0:42:19,5000,0.12202038705348968,1671.3287522906096,522.1,51.97613998901799 52 | 0:43:18,5100,0.12134742371737957,1868.2987074032485,597.8,58.02823951711985 53 | 0:44:18,5200,0.12005957908928394,1953.3116203419115,615.0,60.640346576718976 54 | 0:45:17,5300,0.12018336251378059,1840.783557724279,589.3,57.182808932071914 55 | 0:46:16,5400,0.11623787842690944,1836.3442064666137,576.7,57.04640540983813 56 | 0:47:13,5500,0.11780578069388867,1797.5613148297975,564.9,55.85476214607552 57 | 0:48:09,5600,0.11479436106979847,1671.5396884626884,523.8,51.98262121457733 58 | 0:49:06,5700,0.11438754841685295,1727.765210278411,544.7,53.710206794081685 59 | 0:50:00,5800,0.1144917456805706,1508.930770257324,475.4,46.98629902638847 60 | 0:50:57,5900,0.11335794746875763,1772.1084906400379,555.9,55.07269858120537 61 | 0:51:59,6000,0.11140007726848125,1890.3838880351875,593.4,58.70682885428128 62 | 0:52:51,6100,0.11162008993327617,1321.517232965491,418.1,41.227830025595054 63 | 0:53:50,6200,0.11113816753029823,1864.8209146007255,586.8,57.92138084333405 64 | 0:54:49,6300,0.10879709132015705,1830.0396796530142,574.5,56.85269249696433 65 | 0:55:46,6400,0.10864685870707035,1704.361323424118,528.2,52.991098884930686 66 | 0:56:43,6500,0.10770384579896927,1694.8013613882708,530.7,52.69735945805852 67 | 0:57:40,6600,0.10688597314059735,1730.4566345668431,540.8,53.79290350615956 68 | 0:58:33,6700,0.10562076717615128,1342.650621540579,428.8,41.87717459669648 69 | 0:59:29,6800,0.10509590849280358,1643.347321174611,516.6,51.11638243890885 70 | 1:00:23,6900,0.1040635897219181,1428.1700694434235,451.1,44.50484545137255 71 | 1:01:21,7000,0.10416218921542168,1847.0681558745487,575.5,57.37590951676671 72 | 1:02:16,7100,0.10321411937475204,1555.1056931581793,487.7,48.40506986856386 73 | 1:03:15,7200,0.10393177919089794,1873.3177903286244,583.9,58.18245587659864 74 | 
1:04:09,7300,0.10119697332382202,1410.449906103196,449.3,43.960375650747636 75 | 1:05:06,7400,0.10007099233567715,1770.177019307853,552.2,55.013352186312936 76 | 1:06:02,7500,0.09971646152436733,1691.5973830175506,531.6,52.59891400745974 77 | 1:06:58,7600,0.0993777596950531,1603.1716406045,502.3,49.88194433752179 78 | 1:07:51,7700,0.09935993455350399,1446.3983956426814,460.7,45.06492906577724 79 | 1:08:48,7800,0.09820244528353214,1700.3049286909488,533.5,52.86646208620487 80 | 1:09:46,7900,0.09904529996216298,1840.4137680053555,576.5,57.17144677187792 81 | 1:10:46,8000,0.09797265447676182,2054.453893919757,637.4,63.74804442759974 82 | 1:11:51,8100,0.09773713268339634,1746.5531287886138,544.1,54.28748444378512 83 | 1:12:45,8200,0.0960708150267601,1334.7768820199085,427.5,41.63524604870957 84 | 1:13:43,8300,0.0948629393428564,1863.1638469063248,584.1,57.870465775573685 85 | 1:14:39,8400,0.09478869378566741,1588.7113011965953,499.8,49.43763589841631 86 | 1:15:42,8500,0.09367111779749393,1856.3315434318633,582.9,57.66053639517661 87 | 1:16:37,8600,0.09403191953897476,1533.9691151360785,487.5,47.75562729850238 88 | 1:17:33,8700,0.09381491005420685,1623.3251889225658,511.9,50.50118233346688 89 | 1:18:34,8800,0.0926191609352827,1985.244422526174,623.5,61.621513968059595 90 | 1:19:32,8900,0.09174103327095509,1926.85217175701,597.6,59.827353467171164 91 | 1:20:33,9000,0.09157714270055294,2029.5018650797563,643.5,62.98136830239377 92 | 1:21:27,9100,0.09154749773442745,1487.513186840884,475.5,46.32822228359999 93 | 1:22:25,9200,0.089630077034235,1835.1453857581964,576.7,57.00957044056811 94 | 1:23:20,9300,0.09010873667895794,1491.1071680832522,474.4,46.43865096379391 95 | 1:24:12,9400,0.0887365110218525,1379.4699678615286,437.8,43.00848596023214 96 | 1:25:07,9500,0.08870436787605286,1608.0137288287613,505.5,50.030722357196524 97 | 1:26:06,9600,0.08959654949605465,1908.4269511436937,597.6,59.26122007431307 98 | 
1:27:05,9700,0.08740700207650662,1886.0923196610133,595.4,58.57496610944132 99 | 1:28:01,9800,0.08753197766840458,1741.2856899368846,544.4,54.12563709924658 100 | 1:28:58,9900,0.08853863686323166,1715.1616411863065,558.0,53.32294948617854 101 | 1:29:57,10000,0.08578162617981434,1883.3715103037346,600.7,58.491366511635526 102 | 1:30:50,10100,0.08640753485262394,1368.45523023234,436.9,42.670047093402644 103 | 1:31:44,10200,0.08550317868590356,1549.098964908117,488.1,48.220507115392444 104 | 1:32:37,10300,0.0864135567843914,1397.06404551233,448.0,43.54908165152379 105 | 1:33:33,10400,0.08532103762030602,1656.1472298123874,517.1,51.509672476377425 106 | 1:34:31,10500,0.08499579414725304,1813.7486566631512,572.5,56.352134467731574 107 | 1:35:24,10600,0.08438526280224323,1451.7402799344457,457.1,45.229063821166065 108 | 1:36:24,10700,0.08394303046166897,2026.3684046995688,645.2,62.885089587818165 109 | 1:37:30,10800,0.08432248152792454,2047.6504977082448,651.4,63.53900325186491 110 | 1:38:30,10900,0.083306999579072,1507.263740158428,474.5,46.93507785375283 111 | 1:39:30,11000,0.08242204055190086,1986.693184916497,625.3,61.666028646319994 112 | 1:40:25,11100,0.08275152318179607,1611.692642919325,507.5,50.14376068438045 113 | 1:41:23,11200,0.08317864269018173,1738.2116365646696,547.8,54.031183724605235 114 | 1:42:17,11300,0.08129370994865895,1545.6414613147927,488.7,48.11427184791928 115 | 1:43:16,11400,0.08203768216073513,1944.795598015103,614.1,60.37868324498948 116 | 1:44:13,11500,0.08116393819451333,1791.7587685328272,560.1,55.67647308830729 117 | 1:45:10,11600,0.08064296282827854,1765.7079751128088,554.9,54.876036318781644 118 | 1:46:10,11700,0.08014843612909317,2063.0929694438673,639.5,64.01348869229892 119 | 1:47:11,11800,0.08052294842898845,1647.6057412115974,513.7,51.247226667824705 120 | 1:48:10,11900,0.08133552581071854,1938.3617689255602,609.1,60.18099738993386 121 | 1:49:04,12000,0.08001792185008526,1516.4366554049445,478.4,47.21692487962545 122 | 
1:50:03,12100,0.08015075728297233,1985.689528796322,619.4,61.63519030486931 123 | 1:51:01,12200,0.08015875220298767,1832.2115384790377,578.6,56.919425038831264 124 | 1:52:02,12300,0.07988612353801727,2097.2479682436133,657.6,65.06293530460106 125 | 1:53:02,12400,0.07985741242766381,2070.9892126092855,647.4,64.25610868735286 126 | 1:54:01,12500,0.07837412409484386,1967.1626051110143,613.0,61.065931983066335 127 | 1:55:00,12600,0.07846987895667552,1896.2767210775896,594.5,58.8878920629047 128 | 1:56:00,12700,0.07779640451073647,1982.3106103943442,634.3,61.531369646259684 129 | 1:57:01,12800,0.07785647764801978,2079.236804131393,662.9,64.50952421324045 130 | 1:57:58,12900,0.07795421272516251,1814.930237173078,568.4,56.388439714602626 131 | 1:58:58,13000,0.07678936377167701,1962.9161090733062,613.0,60.935454130993904 132 | 1:59:57,13100,0.07757695533335209,1948.8172735817807,614.4,60.50225326248453 133 | 2:00:56,13200,0.0762648443132639,1949.0705922640893,608.5,60.510036733201076 134 | 2:01:54,13300,0.07659673191606998,1914.0503021491063,594.8,59.43400317692761 135 | 2:02:55,13400,0.076354498565197,2099.8988897713652,661.1,65.14438752871293 136 | 2:03:49,13500,0.07569374114274979,1476.731278263293,473.6,45.99693732302232 137 | 2:04:53,13600,0.07693080961704254,1595.9436440485208,504.7,49.659856890121254 138 | 2:05:49,13700,0.07648300837725401,1676.1932520605314,522.4,52.12560662592289 139 | 2:06:47,13800,0.07601972743868828,1897.9028175280557,600.7,58.93785550811585 140 | 2:07:46,13900,0.07537052042782307,1857.7237638598253,591.7,57.70331376490421 141 | 2:08:44,14000,0.07650071486830712,1919.7922731032074,606.1,59.610430996499474 142 | 2:09:40,14100,0.07521079897880555,1692.724417671338,533.8,52.63354327816471 143 | 2:10:35,14200,0.07573584727942943,1586.0177390814351,499.7,49.354873499466926 144 | 2:11:33,14300,0.07553605400025845,1803.2000451872354,564.3,56.02801779471406 145 | 2:12:35,14400,0.07501246888190508,1686.5194107276388,527.8,52.442888213160735 146 | 
2:13:33,14500,0.07581782959401608,1854.8035165860576,580.6,57.61358623698045 147 | 2:14:33,14600,0.07355216234922408,1552.2497942796683,491.2,48.317319509657295 148 | 2:15:35,14700,0.07433047093451023,1879.6757295110558,583.1,58.37780993810354 149 | 2:16:33,14800,0.07427788361907005,1815.40320512342,565.7,56.40297212949521 150 | 2:17:29,14900,0.0743489333987236,1652.5398888765135,514.2,51.398833306194234 151 | 2:18:30,15000,0.07492121480405331,2057.1583625945827,649.6,63.83114194153946 152 | 2:19:26,15100,0.0737347611784935,1721.4787061089912,537.2,53.517047645035845 153 | 2:20:26,15200,0.07354043800383807,1926.6320703256279,601.8,59.82059062982249 154 | 2:21:26,15300,0.07337431252002716,2044.7535484606428,628.0,63.449991579174416 155 | 2:22:24,15400,0.0734033976867795,1883.0631290528704,592.7,58.481891188245406 156 | 2:23:25,15500,0.07303308382630348,2093.071096707642,661.8,64.93459673521194 157 | 2:24:23,15600,0.07291058480739593,1909.5132409500552,591.9,59.294597418693854 158 | 2:25:20,15700,0.07190525654703378,1779.054087871716,551.3,55.28610902598202 159 | 2:26:16,15800,0.07150179386138916,1712.5653478447712,537.8,53.24317576790696 160 | 2:27:19,15900,0.07252746630460023,2321.3799708327538,726.3,71.94961599824569 161 | 2:28:16,16000,0.07233786560595036,1837.2426694570743,567.8,57.07401158681814 162 | 2:29:16,16100,0.0715704045817256,2060.4496216358557,648.5,63.932269178326194 163 | 2:30:12,16200,0.07241638153791427,1636.7950900373085,511.0,50.915058562120606 164 | 2:31:08,16300,0.07132892843335867,1688.3502849395045,528.3,52.499143660583215 165 | 2:32:03,16400,0.07159417893737555,1677.9435934549122,533.3,52.17938762171431 166 | 2:33:02,16500,0.07100436978042125,1936.9506894179685,599.3,60.13764055605974 167 | 2:33:58,16600,0.07182492427527905,1749.9139272963662,551.0,54.39074835046155 168 | 2:34:56,16700,0.07084676377475262,1883.6681378647816,586.6,58.50048069111131 169 | 2:35:55,16800,0.07059317883104085,1952.9326750651348,611.9,60.628703102822435 170 | 
2:36:53,16900,0.07132943969219924,1864.4481523180718,578.8,57.909927348136506 171 | 2:37:52,17000,0.07135576914995909,1980.822164656458,619.4,61.48563565732358 172 | 2:38:49,17100,0.07100418530404567,1819.4893984863684,570.1,56.528524521023606 173 | 2:39:47,17200,0.07062721576541663,1880.379315312543,587.3,58.39942831789515 174 | 2:40:46,17300,0.07054258473217487,1972.3186985592565,617.5,61.2243581283488 175 | 2:41:43,17400,0.07025174267590045,1827.6829657737794,573.8,56.7802800980874 176 | 2:42:42,17500,0.0699254022911191,1978.9739193960581,615.1,61.428846467003225 177 | 2:43:37,17600,0.06960193779319525,1532.9525898071508,485.2,47.724393537698674 178 | 2:44:35,17700,0.07049755524843931,1923.1744428991624,596.2,59.71435155745179 179 | 2:45:33,17800,0.06951983563601971,1835.5939412025334,571.0,57.02335275671601 180 | 2:46:32,17900,0.06925403777509928,1952.8023287874435,614.1,60.62469808263926 181 | 2:47:32,18000,0.06954852517694235,2094.6554663952656,646.9,64.9832780837624 182 | 2:48:31,18100,0.06951434217393399,1907.3823543408605,596.9,59.22912378930418 183 | 2:49:28,18200,0.06924330212175846,1785.7918832203952,561.8,55.493134549376535 184 | 2:50:24,18300,0.06969802789390087,1785.819299689505,559.9,55.49397694790207 185 | 2:51:22,18400,0.06872367691248656,1857.8099430956051,584.9,57.7059617084035 186 | 2:52:17,18500,0.06967751152813434,1608.8758359683316,500.6,50.05721146417521 187 | 2:53:14,18600,0.06898383647203446,1765.5940922060909,556.3,54.872537152192436 188 | 2:54:13,18700,0.06903096206486226,1981.735879356441,616.7,61.5137104583836 189 | 2:55:14,18800,0.06905508708208799,2094.554532237645,655.3,64.98017678048315 190 | 2:56:13,18900,0.0685419424250722,2017.348140125326,630.5,62.60793290703449 191 | 2:57:13,19000,0.0683863053843379,2051.099026514851,640.9,63.64496276001006 192 | 2:58:13,19100,0.06867271322757006,1953.8720688969238,614.4,60.657566920975924 193 | 2:59:10,19200,0.06773304492235184,1836.0740719357268,574.0,57.038105255299484 194 | 
3:00:08,19300,0.0687664358690381,1766.6540816909292,555.3,54.905106392802324 195 | 3:01:07,19400,0.06828755713999271,1991.9291222636632,627.0,61.826908075519405 196 | 3:02:07,19500,0.06813475776463747,2013.755971366938,629.2,62.49755991722967 197 | 3:03:06,19600,0.06790053885430097,1929.1703685857833,605.8,59.898582391021215 198 | 3:04:08,19700,0.06753950744867325,2249.912307559888,708.0,69.75370032714292 199 | 3:05:06,19800,0.06816073544323445,1819.4627829465724,573.1,56.52770673185496 200 | 3:06:05,19900,0.06771818049252033,1952.9916977332282,610.7,60.63051663352823 201 | 3:07:07,20000,0.06740686003118754,2127.4807980789537,668.1,65.99186933961676 202 | -------------------------------------------------------------------------------- /dt_runs/dt_hopper-medium-v2_log_22-02-13-05-45-16.csv: -------------------------------------------------------------------------------- 1 | duration,num_updates,action_loss,eval_avg_reward,eval_avg_ep_len,eval_d4rl_score 2 | 0:00:16,100,0.58705835044384,7.6599875806839055,11.4,0.8582477193015967 3 | 0:00:18,200,0.5506787076592445,7.748187573336812,11.7,0.8609577525836166 4 | 0:00:20,300,0.4965582340955734,9.028784469474434,13.0,0.9003053772828818 5 | 0:00:22,400,0.4449744701385498,12.478886979050909,15.7,1.0063132390309857 6 | 0:00:25,500,0.3942586645483971,29.099152572583666,23.8,1.5169875776529618 7 | 0:00:30,600,0.3558053889870644,205.50673310499246,118.6,6.93728751265191 8 | 0:00:34,700,0.32275239020586016,136.28480308001951,80.0,4.810374249160198 9 | 0:00:38,800,0.2935658752918243,137.68678070114248,81.1,4.853451418438911 10 | 0:00:43,900,0.2751554834842682,162.0794867491835,96.0,5.602941789587418 11 | 0:00:49,1000,0.26037271082401275,369.15330669747465,162.4,11.96549270388616 12 | 0:00:56,1100,0.24912963420152665,379.2982901402144,166.8,12.277207500547892 13 | 0:01:01,1200,0.2368820485472679,305.1180341989391,143.1,9.99794469765019 14 | 0:01:09,1300,0.22530338153243065,595.5705121882936,224.3,18.92238854986181 15 | 
0:01:18,1400,0.21824575945734978,692.0600123557897,249.4,21.887125268700693 16 | 0:01:28,1500,0.2105904772877693,851.894709691083,301.8,26.79820673675532 17 | 0:01:39,1600,0.20288973897695542,965.9479059977399,333.4,30.302605644453173 18 | 0:01:51,1700,0.1970210325717926,1090.6751702807487,368.5,34.13497600203873 19 | 0:02:03,1800,0.19060962677001952,1096.1985655927824,367.2,34.30468786566972 20 | 0:02:16,1900,0.1870059372484684,1021.2478684116247,339.2,32.001752482547005 21 | 0:02:23,2000,0.18280440151691438,571.9131945266656,211.4,18.195493724840308 22 | 0:02:31,2100,0.17874206289649008,582.7225431354138,215.5,18.527621807911064 23 | 0:02:42,2200,0.17602714478969575,1042.1296025791903,344.3,32.64336471944476 24 | 0:02:49,2300,0.17346609622240067,466.857331930528,194.3,14.967546924127348 25 | 0:03:02,2400,0.17019012093544006,1248.2018646440038,411.1,38.97514176272092 26 | 0:03:14,2500,0.16634261429309846,1082.2044193430436,357.1,33.87470367916879 27 | 0:03:26,2600,0.16278548002243043,1165.8532960752375,373.9,36.44489935752825 28 | 0:03:36,2700,0.16118941470980644,898.761443193561,303.3,28.238234153890186 29 | 0:03:49,2800,0.1595885720849037,1224.6379636230563,398.4,38.25111725772693 30 | 0:04:01,2900,0.15592970699071884,1190.467760412031,381.5,37.20120347463078 31 | 0:04:17,3000,0.15390647381544112,1186.1294589005329,381.6,37.06790480725033 32 | 0:04:30,3100,0.1523503193259239,1299.5409831427771,415.4,40.55258769685798 33 | 0:04:45,3200,0.14957921594381332,1201.5413719154517,383.1,37.54145130032536 34 | 0:04:57,3300,0.14762532696127892,1153.2941109932267,370.2,36.05900579287412 35 | 0:05:08,3400,0.14655685007572175,1121.9460962537846,359.9,35.09580658260362 36 | 0:05:22,3500,0.1437293741106987,1206.0090636546186,384.2,37.67872561229266 37 | 0:05:36,3600,0.14141346216201783,1369.8545736393387,424.6,42.71304332380898 38 | 0:05:48,3700,0.13989000879228114,1230.6994755956225,391.8,38.43736329574716 39 | 
0:06:01,3800,0.13820896722376347,1269.8048388788216,402.7,39.63891482444178 40 | 0:06:15,3900,0.1355729253590107,1300.4971754297944,412.8,40.581967664405425 41 | 0:06:28,4000,0.13482317134737967,1435.1801603099339,436.6,44.720237527798105 42 | 0:06:42,4100,0.13249746218323707,1455.0374846589398,453.0,45.330373745036205 43 | 0:06:56,4200,0.13205426439642906,1407.5396292299288,434.8,43.870954473384444 44 | 0:07:09,4300,0.12887104891240597,1335.1557334964264,420.2,41.64688664049903 45 | 0:07:22,4400,0.12901252642273903,1387.453436597311,431.1,43.25378604846547 46 | 0:07:38,4500,0.12618561908602716,1659.9067058255416,513.4,51.62518614947599 47 | 0:08:02,4600,0.12637492649257184,1867.6987469694445,575.5,58.00980513073728 48 | 0:08:22,4700,0.1249912465363741,1693.831326036654,527.5,52.667554148459885 49 | 0:08:41,4800,0.1225405489653349,2005.564625565703,622.5,62.245872597557884 50 | 0:08:56,4900,0.12207958340644837,1552.9959720015906,476.1,48.34024656894481 51 | 0:09:14,5000,0.11972943887114525,1697.4102205267595,527.1,52.77751927305113 52 | 0:09:36,5100,0.11892052635550499,1595.5984026928124,503.4,49.64924900302107 53 | 0:09:59,5200,0.11807537116110325,1837.6974757946268,566.4,57.08798596793279 54 | 0:10:18,5300,0.11767125859856606,1749.7292239389535,540.9,54.38507315445719 55 | 0:10:38,5400,0.11655780456960202,1622.2056858165938,506.1,50.46678447712636 56 | 0:10:52,5500,0.11446010112762452,1471.702727493541,460.4,45.8424300545242 57 | 0:11:06,5600,0.1141236226260662,1352.5667477089817,426.3,42.181857523948345 58 | 0:11:23,5700,0.11399869605898857,1678.632694607805,524.4,52.20056094614266 59 | 0:11:40,5800,0.11210516050457954,1842.587085839903,577.6,57.23822414324585 60 | 0:11:57,5900,0.11147593788802623,1753.9495147580062,549.4,54.514745824889765 61 | 0:12:15,6000,0.11001419901847839,1506.8334354209126,469.5,46.92185630888641 62 | 0:12:31,6100,0.11053332202136516,1674.3811486601487,524.2,52.06992793051954 63 | 
0:12:45,6200,0.10882183827459813,1393.1411987958313,452.4,43.42854824962419 64 | 0:13:01,6300,0.1079296188056469,1768.1413366031652,551.9,54.950803792425354 65 | 0:13:15,6400,0.10718245133757591,1328.6230920772296,422.1,41.44616467745766 66 | 0:13:28,6500,0.10621959753334523,1413.6343216286614,443.2,44.05822001329484 67 | 0:13:41,6600,0.10474265560507774,1303.4333014487079,414.9,40.67218308270794 68 | 0:14:01,6700,0.10399792104959488,1608.6571898780828,503.0,50.050493343643275 69 | 0:14:22,6800,0.10519288182258606,1603.9102434590075,508.6,49.90463865140668 70 | 0:14:38,6900,0.10295578300952911,1641.4586959590276,513.5,51.05835253394463 71 | 0:14:54,7000,0.10291605673730374,1478.1202835071629,460.6,46.039615902992296 72 | 0:15:08,7100,0.10161948412656784,1315.2297862258495,420.8,41.03464191513328 73 | 0:15:21,7200,0.10236643016338348,1400.504443859272,442.3,43.65479134313692 74 | 0:15:34,7300,0.10115573808550835,1313.6194417952177,419.1,40.98516246653852 75 | 0:15:51,7400,0.10036485396325588,1700.4470108479368,539.9,52.87082770305626 76 | 0:16:03,7500,0.09996483229100704,1320.9718447322762,416.1,41.2110724248382 77 | 0:16:16,7600,0.09912870593369007,1309.7613652794228,416.2,40.86661919405176 78 | 0:16:30,7700,0.09863243266940117,1409.2837842612198,442.7,43.92454538696198 79 | 0:16:44,7800,0.09793490894138814,1526.0329049865297,483.9,47.51177927775458 80 | 0:16:57,7900,0.0967340862751007,1350.4121469342103,428.5,42.115655253024414 81 | 0:17:11,8000,0.0969493930786848,1383.68947443956,440.0,43.13813453407236 82 | 0:17:26,8100,0.09643730230629444,1558.202900797308,487.6,48.50023468129119 83 | 0:17:41,8200,0.09531646214425564,1561.4171121163945,492.4,48.59899454949717 84 | 0:17:54,8300,0.09503262110054493,1422.10645071902,447.9,44.31853468128802 85 | 0:18:14,8400,0.09369626365602017,1354.5518784420171,429.7,42.24285265777855 86 | 0:18:30,8500,0.0940320885181427,1721.6240654067515,539.3,53.521513955326036 87 | 
0:18:43,8600,0.09190608508884907,1333.7302891543145,425.9,41.603088432669324 88 | 0:18:58,8700,0.09398079618811607,1412.6350763141293,449.4,44.02751719827374 89 | 0:19:15,8800,0.0921962857991457,1564.8745480327168,499.1,48.705227737526535 90 | 0:19:33,8900,0.09162565976381302,1414.9750772640714,448.8,44.09941607562691 91 | 0:19:50,9000,0.09145431943237782,1293.8200652232363,416.6,40.37680675289948 92 | 0:20:07,9100,0.09073655694723129,1318.6300063043316,419.6,41.13911708912952 93 | 0:20:22,9200,0.09119737952947617,1266.8673832089642,400.3,39.54865855127972 94 | 0:20:36,9300,0.09058749914169312,1422.9437169267812,448.9,44.34426052570927 95 | 0:20:51,9400,0.09033970206975937,1434.9322409113424,452.8,44.71261995549189 96 | 0:21:06,9500,0.08948310613632202,1508.8319152297242,481.5,46.9832616064655 97 | 0:21:22,9600,0.08770075671374798,1632.6088969342065,517.8,50.78643357822731 98 | 0:21:39,9700,0.08766396410763264,1899.5970417786552,596.7,58.98991224835164 99 | 0:21:52,9800,0.08786047413945199,1336.5411497198434,435.1,41.68945494421404 100 | 0:22:06,9900,0.08775643311440945,1358.2675305021426,434.6,42.35701979594344 101 | 0:22:24,10000,0.08655504539608955,2022.2316754797007,633.8,62.757984431373714 102 | 0:22:38,10100,0.0867172234505415,1367.0176711621184,436.2,42.625876648394765 103 | 0:22:53,10200,0.08766613662242889,1581.4980587426928,499.4,49.21600178560767 104 | 0:23:10,10300,0.08593859888613224,1775.9483970799786,557.0,55.19068355987803 105 | 0:23:25,10400,0.08582285724580288,1523.0310248581814,480.5,47.41954349845613 106 | 0:23:43,10500,0.08467190839350223,1795.7904083387482,563.1,55.8003492670521 107 | 0:24:03,10600,0.08521857552230358,2041.2888566250263,644.8,63.34353544574349 108 | 0:24:18,10700,0.08410436064004898,1602.8570254465026,505.8,49.872277471079336 109 | 0:24:37,10800,0.08285078570246697,2035.9774950178967,638.8,63.18033853046926 110 | 0:24:56,10900,0.08321721538901329,1910.315441868111,602.7,59.31924584690126 111 | 
0:25:10,11000,0.08263486288487912,1533.34151869291,485.4,47.73634376799964 112 | 0:25:27,11100,0.0821751532703638,1718.9091486047444,542.7,53.43809541219409 113 | 0:25:41,11200,0.08202179245650769,1357.2624554905592,438.7,42.32613785763039 114 | 0:26:03,11300,0.08097969934344292,1704.5003737226518,540.1,52.99537134488864 115 | 0:26:20,11400,0.08227638863027095,1696.1657059031,534.1,52.73928031238193 116 | 0:26:41,11500,0.0819849994033575,2181.6217040004217,684.9,67.65540300388015 117 | 0:26:56,11600,0.0803079966455698,1617.4956994179001,508.7,50.32206541860498 118 | 0:27:11,11700,0.08127107679843902,1625.5916406300325,510.5,50.5708213365391 119 | 0:27:38,11800,0.08056041635572911,2111.152284674037,668.7,65.49015937976026 120 | 0:28:04,11900,0.07949614979326725,2116.4124414545363,661.2,65.65178297535277 121 | 0:28:25,12000,0.08013886690139771,1948.458225701009,615.8,60.49122115604707 122 | 0:28:44,12100,0.07956540212035179,1961.4530111111933,613.0,60.89049897790466 123 | 0:29:01,12200,0.0802165337651968,1769.0942008397456,553.2,54.98008150228346 124 | 0:29:17,12300,0.07922936201095582,1769.0152809252697,554.4,54.9776566087159 125 | 0:29:33,12400,0.07894258059561253,1701.324602576737,534.0,52.89779258957766 126 | 0:29:48,12500,0.0782620034366846,1598.123807420092,504.0,49.726844597483655 127 | 0:30:07,12600,0.07770861014723777,2018.181032854701,633.4,62.63352437194358 128 | 0:30:26,12700,0.0788294579833746,1904.9747131127847,595.4,59.15514659652905 129 | 0:30:42,12800,0.0776385148614645,1656.168997421394,522.9,51.51034130800771 130 | 0:30:56,12900,0.07836728364229202,1566.2013867953078,492.1,48.745996189975806 131 | 0:31:15,13000,0.07675645492970944,2016.642186385579,628.4,62.58624176996365 132 | 0:31:33,13100,0.07804169245064259,1983.8566420087795,631.7,61.57887301903957 133 | 0:31:54,13200,0.07799001671373844,2235.661913320987,694.2,69.31584266403283 134 | 0:32:10,13300,0.07756067790091038,1732.15815345093,542.5,53.845184381329325 135 | 
0:32:26,13400,0.07694390267133713,1727.923094973831,538.3,53.71505795978408 136 | 0:32:42,13500,0.07698791913688183,1717.807506807916,529.9,53.40424636250065 137 | 0:33:02,13600,0.07623212557286024,2270.2641619468595,711.4,70.37903147605317 138 | 0:33:25,13700,0.07545420832931996,2186.8513324685728,677.5,67.81608858644094 139 | 0:33:43,13800,0.07602385364472866,1925.6310774100925,607.5,59.789834118006866 140 | 0:33:58,13900,0.07715249694883823,1559.6266908681528,486.9,48.54398205997616 141 | 0:34:15,14000,0.07557542137801647,1837.0109300251781,575.4,57.06689116022506 142 | 0:34:32,14100,0.07581040024757385,1881.3651901739322,585.3,58.429720312326324 143 | 0:34:48,14200,0.07588129371404648,1766.2240788910162,549.9,54.89189412527175 144 | 0:35:03,14300,0.07475406162440777,1605.9955647399136,502.7,49.96871224036036 145 | 0:35:21,14400,0.07448151648044586,1999.748274965875,622.7,62.06715938873187 146 | 0:35:38,14500,0.07480694655328989,1968.9696964560358,613.4,61.121456678039166 147 | 0:35:59,14600,0.0742546771094203,2278.0463866457135,707.6,70.61814813930562 148 | 0:36:16,14700,0.07436971351504326,1913.3168254558398,605.5,59.411466369490896 149 | 0:36:36,14800,0.07458377715200186,2164.7934212891346,679.2,67.1383371305722 150 | 0:36:52,14900,0.07368173189461232,1729.276290927939,542.3,53.756636263391876 151 | 0:37:11,15000,0.07456804990768433,1985.7972865890838,629.2,61.63850127118573 152 | 0:37:29,15100,0.07371191531419755,1988.7374369845143,626.7,61.72884034249515 153 | 0:37:42,15200,0.0742776420712471,1397.2916904064753,443.0,43.55607626933565 154 | 0:38:00,15300,0.07216435234993696,1887.0935086436627,590.3,58.60572864561577 155 | 0:38:16,15400,0.07415706232190132,1669.9641660017987,521.2,51.934211705946375 156 | 0:38:32,15500,0.07239442680031061,1776.5558927924417,556.5,55.20934947525898 157 | 0:38:51,15600,0.07274777602404356,2127.553916864036,658.3,65.99411598766235 158 | 0:39:09,15700,0.07198308955878019,1890.068433841982,588.9,58.6971362076401 159 | 
0:39:30,15800,0.0726435686647892,2001.7398734816695,630.3,62.1283532516777 160 | 0:39:55,15900,0.07255262020975352,2205.541554778544,683.0,68.39036442235515 161 | 0:40:12,16000,0.0726834511384368,1822.5378514999363,565.5,56.622191298955826 162 | 0:40:29,16100,0.07219523519277572,1800.3617619948764,560.6,55.94080869544167 163 | 0:40:46,16200,0.0725621822476387,1769.8021092146962,547.9,55.00183269748238 164 | 0:41:07,16300,0.07264695085585117,2161.877464787035,671.6,67.0487414409137 165 | 0:41:22,16400,0.0710860912874341,1591.5870298422753,507.1,49.5259955468181 166 | 0:41:38,16500,0.07069615848362445,1546.3523127594817,487.9,48.13611347188925 167 | 0:41:59,16600,0.07127196535468101,1913.864225291684,598.1,59.42828577875715 168 | 0:42:19,16700,0.07074172489345074,1916.987946875893,602.6,59.524265259053536 169 | 0:42:36,16800,0.07127759106457233,1789.7967486680652,559.1,55.616188059096295 170 | 0:42:56,16900,0.07107611414045095,2237.348077144997,702.3,69.3676517395731 171 | 0:43:15,17000,0.07096689950674773,2152.111853206733,671.7,66.74868322542102 172 | 0:43:34,17100,0.07153069507330656,1953.656043945779,618.5,60.650929337573245 173 | 0:43:48,17200,0.07130783196538687,1509.047860464053,479.0,46.98989674048901 174 | 0:44:04,17300,0.07058477986603975,1704.9411649068693,540.7,53.008915096353014 175 | 0:44:22,17400,0.07002212528139352,1896.3098303418974,592.1,58.88890937827536 176 | 0:44:40,17500,0.07048300188034773,1993.2344112208605,617.8,61.86701438857295 177 | 0:44:58,17600,0.07025888446718455,1993.996368822359,623.7,61.890426300495385 178 | 0:45:15,17700,0.07019284430891276,1745.5170416158476,545.1,54.255649625699355 179 | 0:45:39,17800,0.07050611365586519,1867.6704561487873,585.4,58.008935866883036 180 | 0:45:57,17900,0.06943127881735563,1973.8263888465667,623.8,61.27068342537766 181 | 0:46:17,18000,0.06930315352976323,2131.42590553842,658.6,66.11308672518246 182 | 0:46:33,18100,0.06961404137313366,1838.9008041579782,573.2,57.12495943942404 183 | 
0:46:51,18200,0.0689073083922267,1923.6669795741793,598.2,59.72948524104703 184 | 0:47:10,18300,0.06977932900190353,1999.8375701477264,619.8,62.069903072801026 185 | 0:47:27,18400,0.06863380763679743,1830.1092301026533,573.9,56.85482950432264 186 | 0:47:45,18500,0.06905523050576448,2027.0046160618786,627.8,62.9046378203565 187 | 0:48:03,18600,0.06942136272788048,1927.714094384291,597.9,59.85383690482461 188 | 0:48:21,18700,0.06844715241342783,1944.1314828597108,602.3,60.35827764040752 189 | 0:48:41,18800,0.0683626215532422,2187.889685520358,681.2,67.84799302593333 190 | 0:48:59,18900,0.06905108526349067,1943.8741678622032,608.9,60.35037137889622 191 | 0:49:20,19000,0.06899742167443038,2399.433404744286,746.5,74.34788608097264 192 | 0:49:38,19100,0.06922576077282429,1872.6456133190236,582.6,58.161802563456135 193 | 0:49:56,19200,0.06859909560531378,1971.0211742040472,621.2,61.18449039048303 194 | 0:50:14,19300,0.06815141454339027,2012.8079741748381,632.8,62.46843175219725 195 | 0:50:33,19400,0.06784586634486914,2114.5186809208058,665.0,65.59359528258524 196 | 0:50:51,19500,0.06830030351877213,1925.1327627760627,600.5,59.77452290082284 197 | 0:51:09,19600,0.06879708334803582,1989.2034411212833,620.2,61.74315878722759 198 | 0:51:27,19700,0.06787603754550219,1914.819646727712,599.5,59.45764206113442 199 | 0:51:46,19800,0.06765746429562569,2142.0551940362916,667.3,66.43968228065812 200 | 0:52:04,19900,0.06802730720490217,2031.9133154958054,636.7,63.05546253629185 201 | 0:52:22,20000,0.06749976746737957,2000.280609578123,617.6,62.08351590388534 202 | -------------------------------------------------------------------------------- /dt_runs/dt_hopper-medium-v2_log_22-02-13-08-03-24.csv: -------------------------------------------------------------------------------- 1 | duration,num_updates,action_loss,eval_avg_reward,eval_avg_ep_len,eval_d4rl_score 2 | 0:00:07,100,0.5768070870637894,15.220256738686174,28.2,1.0905445758313295 3 | 
0:00:23,200,0.53927430331707,610.8185384021606,431.2,19.390899456515854 4 | 0:00:30,300,0.48396564841270445,360.4948501864895,172.6,11.699452938916638 5 | 0:00:37,400,0.4222676295042038,378.59498950502086,182.1,12.255597882776824 6 | 0:00:44,500,0.37640556633472444,509.4765978553684,194.6,16.277066637650517 7 | 0:00:52,600,0.3399521344900131,470.49704611960397,187.4,15.079380795001388 8 | 0:01:00,700,0.308564290702343,406.9465616177977,220.9,13.12672838644455 9 | 0:01:09,800,0.28991901576519014,445.41739809911604,246.8,14.308783442410448 10 | 0:01:17,900,0.2728688418865204,426.7074591659936,215.9,13.73390179346449 11 | 0:01:26,1000,0.26000842422246934,447.4678159619154,256.9,14.37178458881759 12 | 0:01:33,1100,0.24738058894872667,380.3828067580843,180.5,12.310530361932896 13 | 0:01:43,1200,0.23677757665514945,674.7998336971638,285.9,21.356788959007737 14 | 0:01:51,1300,0.22842204749584197,423.43259440070625,203.7,13.633278287258893 15 | 0:01:59,1400,0.22122607246041298,422.53672219276984,231.8,13.605751714671763 16 | 0:02:07,1500,0.21123996943235399,405.77435810271993,219.9,13.090711257149959 17 | 0:02:17,1600,0.2050129483640194,733.2039900516027,297.8,23.151315270950864 18 | 0:02:25,1700,0.19885642409324647,382.95026782078423,207.9,12.389418179504364 19 | 0:02:32,1800,0.19252268120646476,457.14776126661684,194.2,14.669210622027244 20 | 0:02:41,1900,0.18899804398417472,682.5002188792366,265.6,21.59339102098199 21 | 0:02:49,2000,0.18367402330040933,489.76287227399797,193.6,15.671342636647859 22 | 0:02:56,2100,0.17949361220002175,522.4507914434218,201.7,16.675711755109457 23 | 0:03:08,2200,0.1766086360812187,1045.4661014488202,353.3,32.74588199535546 24 | 0:03:19,2300,0.17255539551377297,962.1617343469703,325.5,30.186271721100088 25 | 0:03:31,2400,0.16942552357912063,1051.079540740047,352.5,32.918360550605335 26 | 0:03:45,2500,0.16768476024270057,1238.0696257378936,422.8,38.66381855473613 27 | 
0:03:57,2600,0.16458747163414955,1152.3483758405189,376.6,36.02994713127195 28 | 0:04:12,2700,0.16166277781128882,1450.7791354218537,471.6,45.19953169151833 29 | 0:04:27,2800,0.15864664882421495,1481.1768436957143,481.2,46.133531782011346 30 | 0:04:42,2900,0.1561785088479519,1436.2694174933922,457.9,44.75370604781792 31 | 0:04:56,3000,0.1534361757338047,1350.096616712358,437.8,42.10596027032676 32 | 0:05:09,3100,0.15229857057332993,1241.203750401083,411.9,38.76011768007357 33 | 0:05:27,3200,0.14968092769384383,1785.3264042485334,583.6,55.47883224086285 34 | 0:05:44,3300,0.14797610640525818,1732.7254008953291,558.4,53.8626136282853 35 | 0:06:03,3400,0.14538042560219766,1777.9866275867382,578.2,55.253310237540965 36 | 0:06:23,3500,0.1420811629295349,1933.466870738448,636.0,60.030596731156294 37 | 0:06:40,3600,0.14123285003006458,1728.1008234563808,551.4,53.72051884575908 38 | 0:06:55,3700,0.14014941968023778,1571.306280129044,507.0,48.90284916036129 39 | 0:07:13,3800,0.13907427355647087,1732.980225837741,549.3,53.870443380355034 40 | 0:07:31,3900,0.1381451866030693,1877.076164748311,593.2,58.29793570213248 41 | 0:07:50,4000,0.13434841446578502,1882.2193547955728,605.9,58.4559653774714 42 | 0:08:06,4100,0.13365551091730596,1677.6444438023234,534.8,52.170195948445006 43 | 0:08:23,4200,0.13209981605410576,1755.1442175964626,553.8,54.551454268473 44 | 0:08:40,4300,0.13079093053936958,1717.570160249074,540.4,53.396953651305466 45 | 0:08:58,4400,0.12959881737828255,1784.3236187187272,577.1,55.44802064917489 46 | 0:09:14,4500,0.12765885718166828,1649.7948931745989,524.1,51.31449055867876 47 | 0:09:31,4600,0.12609561443328857,1753.3609573012516,558.8,54.4966618064198 48 | 0:09:48,4700,0.1266464975476265,1719.148960826124,542.7,53.445463883345 49 | 0:10:03,4800,0.12485880136489869,1656.8995863931625,519.3,51.53278938730361 50 | 0:10:19,4900,0.12301638580858708,1614.517494041304,508.9,50.230557069811475 51 | 
0:10:36,5000,0.12231802411377429,1698.8070680791411,539.6,52.82043881581979 52 | 0:10:52,5100,0.12060909569263459,1691.1316295657598,536.5,52.58460326527481 53 | 0:11:08,5200,0.11965408965945244,1613.8536859494986,504.9,50.210160900066356 54 | 0:11:24,5300,0.11932086579501629,1633.1058642723042,516.9,50.801703398391815 55 | 0:11:41,5400,0.11757739819586277,1752.4236254981581,559.4,54.467861346167204 56 | 0:11:59,5500,0.11662465669214725,1861.3790192706463,584.3,57.81562515541182 57 | 0:12:17,5600,0.115715788975358,1891.2314748805488,603.3,58.73287181064944 58 | 0:12:34,5700,0.11567795149981976,1750.1048521230482,545.3,54.3966147073524 59 | 0:12:51,5800,0.11443565711379051,1734.2935935550147,547.7,53.91079792141888 60 | 0:13:07,5900,0.11372581712901592,1684.4966346366557,533.7,52.380736387930874 61 | 0:13:24,6000,0.11232412569224834,1808.8193143927888,567.4,56.200675479931874 62 | 0:13:41,6100,0.11012794494628907,1726.8504730623995,547.2,53.68210057519062 63 | 0:13:58,6200,0.11078558504581451,1820.7177899697822,567.2,56.56626808202936 64 | 0:14:14,6300,0.10949359409511089,1704.8040507254257,532.3,53.00470212553553 65 | 0:14:31,6400,0.10799183495342732,1756.3496267378057,549.9,54.588491673956085 66 | 0:14:48,6500,0.10908653169870376,1697.6828180234904,536.0,52.78589510468689 67 | 0:15:05,6600,0.10788251616060734,1819.3401225899802,576.7,56.5239378693103 68 | 0:15:22,6700,0.1067982043325901,1772.9935044962378,562.1,55.099891520039144 69 | 0:15:43,6800,0.10549175329506397,2176.9562466799553,681.6,67.51205214597175 70 | 0:15:59,6900,0.1043413021415472,1661.7817695718156,521.9,51.68279936468689 71 | 0:16:15,7000,0.10414625130593777,1626.719309945039,515.2,50.6054701078469 72 | 0:16:33,7100,0.10268543764948845,1902.462114795279,594.5,59.077944491857856 73 | 0:16:51,7200,0.10405284024775029,1849.5535459979571,583.5,57.45227562237113 74 | 0:17:09,7300,0.10223606891930104,1916.8770956968333,604.4,59.52085924533894 75 | 
0:17:26,7400,0.10285503178834915,1820.4555035963044,567.8,56.558209070002654 76 | 0:17:45,7500,0.10177698850631714,1997.8351334947868,629.4,62.00837619721546 77 | 0:18:04,7600,0.10077178671956062,1930.3034586927447,616.2,59.93339772160154 78 | 0:18:21,7700,0.10022705808281898,1747.021925033215,549.1,54.301888678832555 79 | 0:18:40,7800,0.0995451832562685,1987.3988974256686,620.1,61.68771237133932 80 | 0:18:58,7900,0.09961777150630952,1954.599837771802,612.2,60.67992834996492 81 | 0:19:16,8000,0.09785485245287419,1838.819038422358,581.2,57.12244710514606 82 | 0:19:35,8100,0.0972407142072916,2039.358552259957,643.9,63.28422490708673 83 | 0:19:54,8200,0.09668638557195663,1939.4590803023634,616.8,60.214713383126494 84 | 0:20:13,8300,0.09642652891576291,1970.2254114031498,624.7,61.16003977988591 85 | 0:20:29,8400,0.09541876822710037,1687.6694012004093,534.0,52.47822282443988 86 | 0:20:47,8500,0.0953488041460514,1934.5308033066078,614.4,60.06328712695807 87 | 0:21:04,8600,0.09406395524740219,1652.1134056900569,523.5,51.38572918231898 88 | 0:21:22,8700,0.09480445750057698,1901.7551330493468,604.4,59.056221768265395 89 | 0:21:38,8800,0.09340312771499157,1660.0000354454983,527.0,51.628053795704446 90 | 0:21:58,8900,0.09331169091165066,2090.103959338495,662.4,64.84342846205394 91 | 0:22:16,9000,0.0933695114403963,1816.7174403446083,570.3,56.443353325487365 92 | 0:22:32,9100,0.09140803597867489,1645.4025423089963,522.1,51.1795311706555 93 | 0:22:49,9200,0.09105824142694473,1734.368473729303,541.8,53.91309868991535 94 | 0:23:08,9300,0.09119718864560128,1937.202447970778,608.8,60.14537609023186 95 | 0:23:23,9400,0.09107057221233844,1616.7921393974525,512.6,50.300447830961694 96 | 0:23:43,9500,0.09078037522733212,2022.9265819026382,636.2,62.779336128549716 97 | 0:24:01,9600,0.08873943082988262,1895.5176928014785,600.6,58.86457015744434 98 | 0:24:16,9700,0.08893276780843734,1558.9210359291624,498.1,48.52230010385842 99 | 
0:24:34,9800,0.08921090014278889,1767.7809175971374,563.6,54.939729556788485 100 | 0:24:54,9900,0.0876706263422966,2091.5080736975115,660.3,64.88657128474861 101 | 0:25:13,10000,0.08767062649130822,2000.5026681221698,631.2,62.09033887548458 102 | 0:25:29,10100,0.08835361585021019,1621.2670492796206,516.7,50.437943927616026 103 | 0:25:42,10200,0.08750568851828575,1265.9952219040274,415.7,39.5218605199809 104 | 0:25:58,10300,0.08612973824143409,1592.7495108012936,505.1,49.56171394082128 105 | 0:26:16,10400,0.08652967490255832,1904.707696869175,601.0,59.14694225449616 106 | 0:26:34,10500,0.08505361065268517,1930.706857130918,607.5,59.945792543420474 107 | 0:26:56,10600,0.08613393321633339,2188.8891444795345,695.5,67.87870240540053 108 | 0:27:14,10700,0.084682871773839,1840.011325261324,579.0,57.15908131472052 109 | 0:27:31,10800,0.08451968289911747,1786.8706136958772,561.9,55.52627962573033 110 | 0:27:49,10900,0.08386263430118561,1825.6953680104912,573.4,56.71920916227704 111 | 0:28:09,11000,0.08381616160273551,1776.4396614701914,558.2,55.20577815124593 112 | 0:28:26,11100,0.08333075545728207,1869.5314051371702,586.0,58.06611539199372 113 | 0:28:45,11200,0.08323560245335102,1841.6297490386073,585.0,57.208809009348684 114 | 0:29:03,11300,0.08202484339475631,1942.5294631359768,606.7,60.309053976785954 115 | 0:29:22,11400,0.08235715344548225,1995.7309886384915,625.7,61.943724235018685 116 | 0:29:41,11500,0.0817306423932314,1981.5088682622027,619.0,61.50673531470988 117 | 0:29:58,11600,0.08143431834876537,1848.0459377379927,581.6,57.40595284571478 118 | 0:30:16,11700,0.0817880541831255,1819.1829332836255,569.9,56.51910807013475 119 | 0:30:31,11800,0.08187269032001496,1597.6109208986725,501.1,49.71108564443685 120 | 0:30:48,11900,0.08088715240359307,1717.4859854395436,536.0,53.394367295814114 121 | 0:31:02,12000,0.08056333154439926,1411.3744434401374,453.6,43.98878298818859 122 | 0:31:17,12100,0.08051395416259766,1448.2937980132358,464.0,45.123167205628754 123 | 
0:31:33,12200,0.07931895524263383,1642.1546203364865,518.6,51.0797355087948 124 | 0:31:53,12300,0.08009820871055126,2057.3726457968805,645.5,63.837726007960995 125 | 0:32:11,12400,0.0795274656265974,1881.3909942806335,595.3,58.430513169398864 126 | 0:32:29,12500,0.07978324547410011,1835.2598183238508,574.2,57.013086495979714 127 | 0:32:43,12600,0.07947862554341555,1361.824101141249,441.7,42.46629899781099 128 | 0:33:00,12700,0.07895056452602148,1771.0512805596572,554.8,55.040214740586535 129 | 0:33:15,12800,0.07882151193916798,1572.4379839234653,498.8,48.93762189509768 130 | 0:33:32,12900,0.07945110850036144,1733.8736973600212,545.8,53.89789618946632 131 | 0:33:48,13000,0.07844202049076557,1499.9629029121854,474.7,46.71075230304909 132 | 0:34:04,13100,0.07828368742018937,1716.9041424769835,537.3,53.3764895869162 133 | 0:34:23,13200,0.07742979779839515,1911.4913817644388,601.3,59.35537778025917 134 | 0:34:40,13300,0.07731691360473633,1780.0038679167296,555.9,55.31529197095928 135 | 0:34:56,13400,0.07777290441095829,1740.5348806618629,547.4,54.10256773084237 136 | 0:35:15,13500,0.07769340045750141,2023.0763307219847,628.0,62.78393731129549 137 | 0:35:34,13600,0.07681618005037308,1965.177772435471,613.5,61.00494600735168 138 | 0:35:49,13700,0.07637575112283229,1489.6626600178513,476.3,46.3942670039359 139 | 0:36:04,13800,0.07640257373452186,1608.67123209908,508.3,50.05092480497464 140 | 0:36:23,13900,0.07605163872241974,1919.4144274496052,599.6,59.59882130962843 141 | 0:36:39,14000,0.07583362303674221,1759.746287051664,553.8,54.69285747061207 142 | 0:36:56,14100,0.07503743439912797,1706.4630967892724,541.7,53.05567798068239 143 | 0:37:15,14200,0.07623854793608188,1931.6785002024947,607.5,59.975647251828214 144 | 0:37:33,14300,0.07572923973202705,1926.5911797829735,599.3,59.819334226866204 145 | 0:37:52,14400,0.0752310037612915,1967.2013428206453,622.2,61.067122238067626 146 | 0:38:06,14500,0.0747097148001194,1372.2287218377949,440.3,42.785991409639145 147 | 
0:38:24,14600,0.07507314555346965,1913.3089815157623,599.0,59.411225356560706 148 | 0:38:41,14700,0.07514971129596233,1766.9604787744115,558.9,54.91452075065856 149 | 0:39:00,14800,0.07461729615926743,1909.0306884991528,595.3,59.279770510403594 150 | 0:39:19,14900,0.07507061056792735,2051.3264067058817,648.2,63.65194924455309 151 | 0:39:35,15000,0.07393001288175582,1659.856202205642,518.3,51.62363437507472 152 | 0:39:52,15100,0.07412395142018795,1768.4756512963831,553.6,54.96107594685573 153 | 0:40:08,15200,0.07428098782896995,1684.3524668270952,529.7,52.376306687311256 154 | 0:40:26,15300,0.07315141882747411,1843.1522988143806,581.8,57.25559087907192 155 | 0:40:39,15400,0.07346053391695023,1334.3475146594362,426.8,41.62205330569345 156 | 0:40:58,15500,0.07381334993988276,1932.9348807355557,603.4,60.01425080447108 157 | 0:41:11,15600,0.07387652158737183,1301.4055621756258,419.2,40.60987875872758 158 | 0:41:28,15700,0.0737112408131361,1755.0958259260174,549.1,54.54996738583804 159 | 0:41:47,15800,0.07305116392672062,1958.7944101905632,616.7,60.80881079674041 160 | 0:42:06,15900,0.07280270975083113,2098.499084594034,654.3,65.10137710994974 161 | 0:42:24,16000,0.07291269078850746,1825.9439539098016,566.9,56.72684721348668 162 | 0:42:41,16100,0.0720273320376873,1881.253075622309,587.9,58.42627548022193 163 | 0:42:55,16200,0.07329363979399205,1419.606205612263,449.6,44.241712141413394 164 | 0:43:16,16300,0.07197386760264635,2217.3926095532934,697.1,68.75449997271741 165 | 0:43:34,16400,0.07336676854640245,1811.133909651963,566.1,56.271793741941856 166 | 0:43:52,16500,0.07261673510074615,1843.555330244455,579.1,57.26797442419872 167 | 0:44:07,16600,0.07181379355490208,1598.9064247873634,502.8,49.75089130144133 168 | 0:44:25,16700,0.07054194558411836,1881.5664878110153,587.0,58.43590538422575 169 | 0:44:45,16800,0.07186041820794344,2104.4623522452907,657.3,65.28460449261061 170 | 0:45:04,16900,0.07109503295272589,2008.438913375829,630.0,62.33418797484141 171 | 
0:45:23,17000,0.07123076602816582,1958.966086892267,613.7,60.814085735676 172 | 0:45:37,17100,0.0710426064580679,1392.8096237453688,443.4,43.418360273466675 173 | 0:45:53,17200,0.07215427406132222,1639.6534141455104,515.8,51.00288343864311 174 | 0:46:12,17300,0.07065974619239569,2011.3744838162906,628.5,62.42438632243847 175 | 0:46:29,17400,0.07081293068826199,1842.4171089387078,572.5,57.2330014323866 176 | 0:46:44,17500,0.07054672103375197,1602.7804981402503,503.2,49.86992609280039 177 | 0:47:03,17600,0.07011458091437817,2020.296766095064,629.3,62.69853239886965 178 | 0:47:17,17700,0.07036718409508466,1321.0689975864932,427.1,41.21405754377589 179 | 0:47:30,17800,0.07079787120223045,1402.217811987471,444.7,43.707436298222625 180 | 0:47:47,17900,0.0701261667534709,1721.840070613173,537.7,53.528150932052284 181 | 0:48:04,18000,0.06961971580982208,1759.7807984900294,552.1,54.6939178691877 182 | 0:48:20,18100,0.07083872698247433,1682.4134887329856,523.8,52.316729639625734 183 | 0:48:38,18200,0.069383694678545,2016.9765955365833,630.5,62.59651682670431 184 | 0:49:00,18300,0.06880833342671394,2339.8037048725682,723.2,72.51570371464118 185 | 0:49:13,18400,0.06980048436671496,1344.1140991685083,430.1,41.922141415398926 186 | 0:49:31,18500,0.06896339260041713,1898.839747345763,595.6,58.96664361696406 187 | 0:49:47,18600,0.06866284057497979,1579.3391333669683,496.6,49.14966663697976 188 | 0:50:03,18700,0.06828272823244333,1739.3655442073893,539.6,54.066638694861915 189 | 0:50:21,18800,0.0685167359188199,1891.1633727348606,594.3,58.73077930388339 190 | 0:50:40,18900,0.06871617682278157,1990.0599933171882,614.4,61.76947721298783 191 | 0:50:56,19000,0.06942813083529473,1735.4972835422582,539.9,53.94778250416711 192 | 0:51:14,19100,0.06821426399052143,1893.018181128172,589.7,58.78777015304848 193 | 0:51:34,19200,0.06796323198825122,2176.4466651436383,673.6,67.49639474190874 194 | 0:51:49,19300,0.06855917233973742,1522.877372407131,481.9,47.41482237270899 195 | 
0:52:06,19400,0.06914132315665483,1835.2160661008595,568.5,57.01174216502341 196 | 0:52:25,19500,0.0683505168557167,2078.6063330363986,641.3,64.49015235617568 197 | 0:52:42,19600,0.06851455185562372,1762.2420872660543,549.6,54.769543436708325 198 | 0:53:01,19700,0.06773159082978963,1956.2485957031918,611.6,60.730588091918015 199 | 0:53:18,19800,0.06837044894695282,1870.161995913247,579.8,58.085490926379855 200 | 0:53:33,19900,0.06788079094141722,1435.7998711057567,456.1,44.73927876387299 201 | 0:53:49,20000,0.0680889619141817,1756.0444098443616,546.7,54.57911357862309 202 | -------------------------------------------------------------------------------- /dt_runs/dt_hopper-medium-v2_model_22-02-12-09-43-59_best.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nikhilbarhate99/min-decision-transformer/d6694248b48c57c84fc7487e6e8017dcca861b02/dt_runs/dt_hopper-medium-v2_model_22-02-12-09-43-59_best.pt -------------------------------------------------------------------------------- /dt_runs/dt_walker2d-medium-v2_log_22-02-20-06-27-12.csv: -------------------------------------------------------------------------------- 1 | duration,num_updates,action_loss,eval_avg_reward,eval_avg_ep_len,eval_d4rl_score 2 | 0:00:21,100,0.7826899242401123,0.33698463768095877,15.3,-0.02814454280366867 3 | 0:00:26,200,0.7246889197826385,6.648241470530787,22.3,0.10933550845350555 4 | 0:00:33,300,0.6252569282054901,-18.26857375478725,71.4,-0.4334351511894897 5 | 0:00:40,400,0.4998399719595909,-36.35868205371129,91.0,-0.8274975514453353 6 | 0:00:49,500,0.39157980293035505,-20.941641233664154,100.5,-0.491663403301984 7 | 0:00:57,600,0.31690163880586625,-17.321939169767678,104.5,-0.41281431849053923 8 | 0:01:06,700,0.270312991142273,-16.460312831198866,107.7,-0.39404524660387297 9 | 0:01:15,800,0.24189319476485252,1.955136505365995,132.6,0.0071041576696375675 10 | 
0:01:29,900,0.22416843876242637,62.53526814560409,192.9,1.3267398219507187 11 | 0:01:45,1000,0.20781991079449655,225.80081450530605,280.6,4.883203498921232 12 | 0:02:02,1100,0.19504526033997535,195.53473592383344,292.8,4.223908188187437 13 | 0:02:25,1200,0.1872202667593956,453.3503715769871,433.0,9.839985578670001 14 | 0:02:47,1300,0.17795019537210466,515.3115297121709,417.7,11.189704568402902 15 | 0:03:04,1400,0.17339187026023864,413.6156508203264,279.9,8.974431919392195 16 | 0:03:21,1500,0.16591425925493242,476.89346370760995,319.2,10.352832005078048 17 | 0:03:35,1600,0.16207062676548958,406.6714375673245,233.7,8.82316398350432 18 | 0:03:51,1700,0.1561943657696247,437.5956720132256,275.4,9.496796106124123 19 | 0:04:07,1800,0.15147392615675925,422.9236731622485,254.8,9.177191436642351 20 | 0:04:22,1900,0.14630322769284249,480.8092822811397,264.4,10.43813148701334 21 | 0:04:39,2000,0.14282810620963574,504.00734372786854,285.0,10.943462003775602 22 | 0:04:58,2100,0.1404953298717737,533.3777894184984,332.1,11.583247467421607 23 | 0:05:21,2200,0.13657934591174126,654.8302694639813,413.1,14.228884243769418 24 | 0:05:37,2300,0.13319598153233528,498.3049764727945,266.9,10.819245581709822 25 | 0:06:09,2400,0.13040619723498822,795.2798787909719,634.8,17.28834133777043 26 | 0:06:58,2500,0.12756308875977992,983.7057674571224,1000.0,21.39288049979083 27 | 0:07:19,2600,0.1249207753688097,606.0495878722106,383.0,13.16627963375099 28 | 0:07:44,2700,0.12303725391626358,642.9028823079733,484.0,13.969066296092631 29 | 0:08:34,2800,0.12083647206425667,987.4407860251191,1000.0,21.474241559524923 30 | 0:09:11,2900,0.11853870511054992,834.521963530369,726.6,18.14316375496789 31 | 0:09:44,3000,0.1163690073043108,764.9451945222202,627.2,16.627551568222255 32 | 0:10:26,3100,0.11491844780743123,905.0414836497843,860.7,19.679312179507725 33 | 0:11:03,3200,0.11402832671999931,860.388129735499,730.8,18.706614419374162 34 | 
0:11:31,3300,0.11083751544356346,777.3963053426919,527.2,16.898777949772356 35 | 0:12:13,3400,0.10941306322813034,898.4668792005148,860.7,19.536095546019357 36 | 0:12:35,3500,0.10854245357215404,743.9288918125997,392.8,16.16974697394301 37 | 0:13:24,3600,0.10636503800749779,965.6398227259854,1000.0,20.999344461973703 38 | 0:14:14,3700,0.10563862450420856,962.8998642562723,1000.0,20.939659102807518 39 | 0:15:04,3800,0.10381215110421181,969.5562348585366,1000.0,21.08465687358796 40 | 0:15:53,3900,0.10170185983181,969.3932727200752,1000.0,21.081107019138674 41 | 0:16:42,4000,0.10129092887043953,970.7030718701742,1000.0,21.109638777401937 42 | 0:17:32,4100,0.09916638739407063,963.1477246093966,1000.0,20.945058321212766 43 | 0:18:21,4200,0.09942674212157726,964.4707167141709,1000.0,20.973877465670725 44 | 0:19:10,4300,0.09730922803282738,966.9450782974143,1000.0,21.027777246063515 45 | 0:19:59,4400,0.09595392182469369,971.4755273624862,1000.0,21.126465413282794 46 | 0:20:48,4500,0.09494938559830189,966.4073401363627,1000.0,21.016063530095007 47 | 0:21:38,4600,0.09514373101294041,967.8129412818578,1000.0,21.04668217272796 48 | 0:22:27,4700,0.09250498853623867,962.2598805372611,1000.0,20.92571814036158 49 | 0:23:17,4800,0.0918784200400114,963.4912756673824,1000.0,20.952541999711713 50 | 0:24:06,4900,0.09027835547924042,963.4997449031869,1000.0,20.9527264876835 51 | 0:24:55,5000,0.0894149012863636,958.941508475772,1000.0,20.853433019792675 52 | 0:25:45,5100,0.08964217491447926,959.9216848297895,1000.0,20.87478450317551 53 | 0:26:34,5200,0.0900094460695982,961.5440013649779,1000.0,20.910123923883628 54 | 0:27:23,5300,0.0869937638938427,963.9910338214798,1000.0,20.9634283854921 55 | 0:28:08,5400,0.08623833708465099,2212.003261560959,907.9,48.1492630905787 56 | 0:28:57,5500,0.08533291831612587,962.7168674657294,1000.0,20.935672827361905 57 | 0:29:46,5600,0.08477564133703709,969.4052319349466,1000.0,21.08136753039928 58 | 
0:30:27,5700,0.08518708318471908,3083.158724699978,825.4,67.12591083242626 59 | 0:31:16,5800,0.08372270543128252,965.6223093264438,1000.0,20.99896296219792 60 | 0:32:05,5900,0.08294234909117222,966.9422260576005,1000.0,21.02771511484525 61 | 0:32:54,6000,0.08314431570470333,966.5175939347312,1000.0,21.018465222539547 62 | 0:33:32,6100,0.08186133846640586,1368.193758262896,768.9,29.768300813636174 63 | 0:34:21,6200,0.0801357738673687,1235.3026914793666,1000.0,26.873493779651074 64 | 0:35:09,6300,0.08091344326734543,1984.4570937409244,985.5,43.19255483994232 65 | 0:35:54,6400,0.0790667923539877,3294.6629051623227,931.8,71.73317153202606 66 | 0:36:30,6500,0.07909646958112716,2652.3010431887983,743.7,57.740405265549 67 | 0:37:19,6600,0.07849374234676361,973.738874587962,1000.0,21.175768602934593 68 | 0:38:05,6700,0.0776206536591053,3335.5286591457284,935.6,72.62336283640447 69 | 0:38:43,6800,0.07697978638112545,2739.9236013529066,757.9,59.64911443501039 70 | 0:39:20,6900,0.07629730135202407,2320.2313077171457,748.3,50.506827950809196 71 | 0:40:09,7000,0.07606719747185707,3592.4644867297925,994.6,78.22027509676504 72 | 0:40:52,7100,0.074544787555933,3182.997563438987,886.6,69.3007310038782 73 | 0:41:40,7200,0.07442153081297874,3734.562969355507,1000.0,81.3156501056329 74 | 0:42:23,7300,0.07438390202820301,2911.5611359922304,889.9,63.38794771098313 75 | 0:43:00,7400,0.07274295348674059,2279.7063952535887,745.7,49.624061302226046 76 | 0:43:47,7500,0.07283081274479627,3179.982008772919,958.7,69.23504224789193 77 | 0:44:32,7600,0.07238136503845453,3478.629397481579,919.9,75.74057029006924 78 | 0:45:19,7700,0.07181750372052192,3509.163363453412,978.9,76.40570107432808 79 | 0:46:06,7800,0.07134818200021982,3518.123033439771,974.5,76.60087232493552 80 | 0:46:49,7900,0.07111444298177957,3218.956315694656,879.6,70.08403157842021 81 | 0:47:36,8000,0.07007800105959178,3551.2383569521603,982.3,77.32223361547666 82 | 
0:48:21,8100,0.06990218508988619,3438.515400256109,932.0,74.86675473466622 83 | 0:49:04,8200,0.0692890978604555,3173.3752094060305,885.6,69.09112430259803 84 | 0:49:45,8300,0.0691000546887517,2814.2889194952886,834.0,61.269037062268495 85 | 0:50:24,8400,0.0686588853597641,2955.4654292487717,803.5,64.34432845212208 86 | 0:51:12,8500,0.06748432613909244,3636.8475926695683,1000.0,79.18708596204203 87 | 0:51:54,8600,0.06745146099478007,3144.230456967898,867.2,68.45625518457756 88 | 0:52:28,8700,0.06688137613236904,2420.5153458174855,676.5,52.691346037056306 89 | 0:53:10,8800,0.06623180158436298,3195.6106655836466,896.3,69.57548609233129 90 | 0:53:56,8900,0.06561856273561716,3381.0643843981443,962.3,73.61528156313895 91 | 0:54:43,9000,0.0657417669892311,2934.231001266542,1000.0,63.88177236785393 92 | 0:55:29,9100,0.0656845472380519,3782.955367081243,993.8,82.36979660861836 93 | 0:56:17,9200,0.06493333477526902,3584.8684543871336,985.3,78.05480838490753 94 | 0:57:02,9300,0.06491713963449001,3078.6976039844626,953.5,67.0287328659963 95 | 0:57:48,9400,0.06442555774003267,3402.251137023337,976.9,74.07679912042228 96 | 0:58:32,9500,0.0633901346847415,3337.0477485257893,933.7,72.65645362820175 97 | 0:59:18,9600,0.06384179081767798,3369.5434004356202,1000.0,73.36431642138514 98 | 1:00:04,9700,0.06264458406716585,3439.3097689461365,963.0,74.88405871248149 99 | 1:00:51,9800,0.06235703360289335,3276.841912531104,1000.0,71.34497136123895 100 | 1:01:36,9900,0.06222419396042824,3357.452230694157,955.8,73.10093074720953 101 | 1:02:22,10000,0.06191367510706186,3488.4707950461416,967.0,75.95494848405686 102 | 1:03:08,10100,0.06144244804978371,3374.7640814536776,951.0,73.47804012380588 103 | 1:03:34,10200,0.061042127162218095,1697.4024319682353,537.1,36.93955473880396 104 | 1:04:20,10300,0.061015864685177806,3178.356399532752,961.2,69.1996311011772 105 | 1:05:06,10400,0.06032001305371523,3461.497059439669,980.0,75.36737129428484 106 | 
1:05:43,10500,0.06051487427204847,2819.4717495445952,779.1,61.38193624538003 107 | 1:06:24,10600,0.05998043693602085,3072.000514059341,852.7,66.8828480936658 108 | 1:07:10,10700,0.05976848412305116,3538.288374218121,984.6,77.04014015339658 109 | 1:07:57,10800,0.05943823616951704,3516.9148907969,1000.0,76.5745549816762 110 | 1:08:43,10900,0.059576812498271466,3368.3949546524127,964.3,73.3392994732046 111 | 1:09:27,11000,0.05855456132441759,3331.4007419811096,920.3,72.53344314553983 112 | 1:10:02,11100,0.058611393831670286,2277.049061036177,717.4,49.566175772593404 113 | 1:10:46,11200,0.05875485084950924,3408.6237159341895,928.4,74.21561496938985 114 | 1:11:31,11300,0.058086014539003375,3341.7490632218332,962.1,72.75886381408169 115 | 1:12:17,11400,0.057662168852984905,3551.653178057132,981.0,77.33126979135804 116 | 1:13:02,11500,0.05794393382966519,3406.934167104134,952.7,74.17881100689722 117 | 1:13:49,11600,0.057141586355865,3417.7257015818286,1000.0,74.41388632587565 118 | 1:14:36,11700,0.05722889166325331,3566.025204711937,1000.0,77.64434007410864 119 | 1:15:23,11800,0.056443915516138074,3660.802496486169,1000.0,79.70890300923071 120 | 1:16:11,11900,0.05678156618028879,3559.8538287128918,1000.0,77.5099070901331 121 | 1:16:51,12000,0.05654059439897537,2889.49946733422,826.1,62.907371588310504 122 | 1:17:36,12100,0.056181509606540204,3199.9220299260296,973.2,69.66940186956508 123 | 1:18:20,12200,0.05564858302474022,3239.3159967602323,923.3,70.52753278120856 124 | 1:19:07,12300,0.055608641840517524,3687.8382169057345,1000.0,80.29783043327568 125 | 1:19:51,12400,0.05498258527368307,2957.012491925257,924.8,64.37802859485028 126 | 1:20:37,12500,0.05536829341202974,3568.921351757398,976.7,77.70742773712148 127 | 1:21:20,12600,0.054966694079339504,3229.59375592013,894.9,70.31575021484637 128 | 1:22:06,12700,0.054440183863043785,3542.188366652591,972.7,77.12509488967078 129 | 1:22:52,12800,0.0551599058508873,3571.6799017685967,989.5,77.76751808156145 130 | 
1:23:38,12900,0.053986068814992905,3727.218681427444,1000.0,81.15566721988785 131 | 1:24:23,13000,0.05396680124104023,3529.954703620019,959.0,76.85860524025152 132 | 1:25:10,13100,0.05415511380881071,3604.0690757969955,1000.0,78.4730614342618 133 | 1:25:50,13200,0.053232499472796915,2915.250733865471,856.6,63.468319357735204 134 | 1:26:36,13300,0.054234681762754915,3596.9453637107595,976.3,78.31788342000962 135 | 1:27:23,13400,0.053401131629943845,3656.1700679464557,1000.0,79.60799339170886 136 | 1:28:09,13500,0.05277917128056288,3589.228752194728,978.4,78.14979009488397 137 | 1:28:54,13600,0.05320276319980621,3388.6717900598005,949.0,73.7809960235068 138 | 1:29:34,13700,0.05267050735652447,2410.534162176347,836.3,52.47392283991296 139 | 1:30:20,13800,0.05322949577122927,3502.3105477392623,982.0,76.25642407917658 140 | 1:31:06,13900,0.05300251662731171,3553.4525834792485,992.0,77.37046679382786 141 | 1:31:53,14000,0.05205541007220745,3658.1271471380524,1000.0,79.65062505045782 142 | 1:32:39,14100,0.05241932481527328,3472.6672084049337,1000.0,75.61069408924729 143 | 1:33:17,14200,0.05175412781536579,1670.3765700262254,793.8,36.35084206501168 144 | 1:34:02,14300,0.05139639649540186,3615.7935418808256,979.9,78.72845908973007 145 | 1:34:43,14400,0.05161351673305035,2267.1202001442534,868.3,49.34989233800995 146 | 1:35:31,14500,0.05075718354433775,3521.4303736825022,1000.0,76.67291713599897 147 | 1:36:15,14600,0.051962455883622166,3231.0475392991257,907.4,70.34741842591026 148 | 1:36:50,14700,0.05070974215865135,2157.5380380162173,723.6,46.96283035253983 149 | 1:37:33,14800,0.050801689699292184,3320.160884062846,914.5,72.28860185898606 150 | 1:38:13,14900,0.0511224702000618,2989.5481379737216,838.1,65.08676259267246 151 | 1:38:59,15000,0.050554512478411195,3591.6211013943553,970.3,78.20190337426723 152 | 1:39:45,15100,0.05029684282839298,3590.7647700131142,967.2,78.18324964406672 153 | 1:40:32,15200,0.050704943239688875,3595.931220976034,1000.0,78.29579203649526 154 | 
1:41:20,15300,0.0503308280557394,3573.2712932200893,1000.0,77.80218385164747 155 | 1:42:05,15400,0.049793781340122224,3560.0643852006315,976.1,77.51449370695028 156 | 1:42:45,15500,0.04980747319757938,2606.630209589058,817.4,56.74554343207564 157 | 1:43:28,15600,0.049799401573836805,3239.2796129183257,888.8,70.52674022077524 158 | 1:44:15,15700,0.04940336156636477,3771.494670396286,1000.0,82.12014472319836 159 | 1:44:56,15800,0.05015094149857759,3192.4217130307816,843.0,69.50602015677583 160 | 1:45:40,15900,0.04985287126153708,3194.843245956106,955.6,69.55876915424362 161 | 1:46:27,16000,0.048677744530141356,3466.644836793856,955.7,75.47950691374346 162 | 1:47:11,16100,0.04926906406879425,3310.8376514499614,920.5,72.08551101171925 163 | 1:47:57,16200,0.04900280989706516,3232.0658251828263,991.3,70.36960006091472 164 | 1:48:27,16300,0.0491021379455924,2028.173428611882,627.9,44.14484122568289 165 | 1:49:14,16400,0.048717569150030615,3290.884239650427,998.8,71.65085969747115 166 | 1:49:58,16500,0.048304263390600684,3634.417720637813,969.2,79.1341553112507 167 | 1:50:45,16600,0.04867617111653089,3725.4531437631913,995.2,81.11720796921774 168 | 1:51:30,16700,0.04838556725531817,3468.3507711032485,942.2,75.51666780617869 169 | 1:52:06,16800,0.04806860979646444,2598.365700151918,728.2,56.56551507779929 170 | 1:52:46,16900,0.048396479934453965,2222.4604888802733,817.5,48.37705609377012 171 | 1:53:33,17000,0.048353834114968776,3710.377458836055,1000.0,80.78880968161648 172 | 1:54:15,17100,0.04745920673012734,3266.57033064006,867.1,71.12122232958446 173 | 1:55:00,17200,0.0478622256219387,3303.0512718246914,934.9,71.91589790638369 174 | 1:55:46,17300,0.04780504569411278,3627.5411565032045,962.3,78.98436099694257 175 | 1:56:17,17400,0.0472821632027626,1852.896109394896,616.7,40.326721401316576 176 | 1:57:01,17500,0.04734675843268633,3453.0634716616833,927.2,75.18365985443906 177 | 1:57:46,17600,0.04767358858138323,3553.143534549158,953.8,77.36373468580643 178 | 
1:58:33,17700,0.04775761049240827,3445.6372123422116,982.0,75.02189136063035 179 | 1:59:19,17800,0.04743493478745222,3402.9425478591666,932.8,74.0918603355918 180 | 2:00:07,17900,0.04649154528975487,3679.030497560838,994.1,80.10596917028721 181 | 2:00:55,18000,0.04747708253562451,3575.7917339042447,967.1,77.85708738728634 182 | 2:01:43,18100,0.04699791777879,3622.1354971827654,1000.0,78.86660785519358 183 | 2:02:30,18200,0.046881279163062575,3474.1551692359412,956.8,75.64310679827393 184 | 2:03:15,18300,0.046443770937621594,3393.911023617624,925.5,73.89512386161486 185 | 2:04:04,18400,0.04657957509160042,3480.713315441693,988.1,75.78596491677513 186 | 2:04:47,18500,0.04631125275045633,3229.717055223576,894.7,70.31843608154563 187 | 2:05:20,18600,0.04642774552106857,2203.149467528254,631.7,47.95639816847615 188 | 2:06:08,18700,0.04667449299246073,3704.951975648802,980.9,80.67062471047156 189 | 2:06:52,18800,0.04631829027086496,3397.1939526999486,906.1,73.96663691685332 190 | 2:07:31,18900,0.04581700276583433,2750.5975109707697,800.0,59.88162749544238 191 | 2:08:18,19000,0.04702753737568855,3505.3379503783945,950.0,76.32237092320892 192 | 2:08:56,19100,0.04568715546280146,2484.2786498566657,760.7,54.08032172601982 193 | 2:09:42,19200,0.04589634343981743,3270.2816982746217,932.6,71.20206819375178 194 | 2:10:29,19300,0.04567991957068443,3392.6057962594036,963.6,73.86669169210207 195 | 2:11:17,19400,0.04602051254361868,3473.719674241614,990.1,75.63362027669383 196 | 2:12:04,19500,0.04598451796919108,3402.5400015050154,973.9,74.0830915443878 197 | 2:12:49,19600,0.04538236368447542,3441.2577641095354,922.7,74.92649249104662 198 | 2:13:37,19700,0.04522538248449564,3499.2155999108763,990.8,76.18900587748495 199 | 2:14:13,19800,0.04598742425441742,1935.6937102573531,709.5,42.130327039942074 200 | 2:15:02,19900,0.045358895659446716,3646.923096924764,991.7,79.40656377416916 201 | 2:15:49,20000,0.04559674881398678,3477.531931972746,953.3,75.71666386090571 202 | 
-------------------------------------------------------------------------------- /dt_runs/dt_walker2d-medium-v2_log_22-02-20-09-11-30.csv: -------------------------------------------------------------------------------- 1 | duration,num_updates,action_loss,eval_avg_reward,eval_avg_ep_len,eval_d4rl_score 2 | 0:00:11,100,0.6412723487615586,-0.9869671908196122,20.2,-0.05698459321912591 3 | 0:00:16,200,0.5928776508569717,2.2918990330575575,24.2,0.014439959522535031 4 | 0:00:22,300,0.5136207875609398,16.3356512300006,52.4,0.32035933865941046 5 | 0:00:31,400,0.4269720208644867,-0.0006673745601628811,108.1,-0.03549972057244225 6 | 0:00:40,500,0.35608836501836777,-4.998830198714646,103.5,-0.14437624064684107 7 | 0:00:49,600,0.3063909536600113,1.2892189841149422,120.9,-0.00740172877727888 8 | 0:00:59,700,0.26801478818058966,14.467858096484841,133.2,0.27967262561091066 9 | 0:01:09,800,0.24293274492025374,14.170848121524951,138.0,0.2732027658566944 10 | 0:01:19,900,0.22397424429655075,24.558617823066463,143.5,0.4994827523694267 11 | 0:01:30,1000,0.20768051847815513,33.18592670772834,152.3,0.6874140787418979 12 | 0:01:43,1100,0.19675054728984834,75.96957290840615,197.3,1.6193834199391945 13 | 0:01:56,1200,0.1881781643629074,71.44561739011938,194.9,1.5208367036493426 14 | 0:02:09,1300,0.1781380282342434,88.43304322045917,208.9,1.8908790321007427 15 | 0:02:23,1400,0.1718070162832737,104.386426354008,220.4,2.238396490035546 16 | 0:02:37,1500,0.16713388994336129,104.7402223203305,221.7,2.246103336527901 17 | 0:02:52,1600,0.16217972084879875,136.35221474700197,249.5,2.934717103050498 18 | 0:03:07,1700,0.15744786694645882,113.72073222697448,230.5,2.441728549536936 19 | 0:03:22,1800,0.15293219923973084,122.84885446248704,236.7,2.6405692473656375 20 | 0:03:37,1900,0.1500479534268379,135.16967445914526,249.0,2.9089574637751614 21 | 0:03:54,2000,0.1450935335457325,206.00252276972282,285.6,4.451931212798244 22 | 0:04:17,2100,0.1414915583282709,521.7638049787031,415.1,11.330256467630191 
23 | 0:04:36,2200,0.13643936529755593,316.3491175162175,318.4,6.8556450694172835 24 | 0:04:56,2300,0.13428600400686264,572.6030965227473,348.2,12.437704412225655 25 | 0:05:16,2400,0.13078454852104188,562.0824794710339,346.6,12.208530570971353 26 | 0:05:34,2500,0.12903099298477172,263.9113655539121,311.2,5.713377369255655 27 | 0:05:50,2600,0.1252146426588297,508.145073011947,272.0,11.033595435060246 28 | 0:06:06,2700,0.12404202677309513,484.99319347470964,267.4,10.529270913055004 29 | 0:06:29,2800,0.12149275027215481,573.1233843007185,420.1,12.449038001125357 30 | 0:06:59,2900,0.1194714617729187,563.2760582772596,571.9,12.234530665691835 31 | 0:07:22,3000,0.11703295730054379,569.8151776999568,404.5,12.376974317918116 32 | 0:08:06,3100,0.11572174683213234,876.0593357794971,861.5,19.047985126865676 33 | 0:08:49,3200,0.11370789542794228,839.8116912106503,852.6,18.258391522096044 34 | 0:09:40,3300,0.11149841159582138,945.144375863271,1000.0,20.55288583101472 35 | 0:10:30,3400,0.11081763215363026,959.5237431230782,1000.0,20.866116016424776 36 | 0:10:46,3500,0.10919691719114781,376.4146715687426,248.5,8.164071531631615 37 | 0:11:36,3600,0.1080987086892128,946.0605220513353,1000.0,20.572842525573336 38 | 0:12:26,3700,0.10734239898622036,953.8954098929795,1000.0,20.74351230032517 39 | 0:13:16,3800,0.10534275099635124,945.2894791473476,1000.0,20.556046660538975 40 | 0:13:31,3900,0.10264556467533112,522.9811611519617,243.5,11.356774512059427 41 | 0:14:22,4000,0.1018887285888195,953.4776027353086,1000.0,20.73441107833825 42 | 0:14:39,4100,0.10034056901931762,534.8005212262315,279.9,11.614239272545792 43 | 0:15:29,4200,0.09990146815776825,951.2576774256593,1000.0,20.686053761651475 44 | 0:16:19,4300,0.09921074666082859,949.7816347502642,1000.0,20.653900669478954 45 | 0:16:56,4400,0.09617936171591282,877.4853285145279,705.0,19.079047965773448 46 | 0:17:46,4500,0.09717400215566158,955.451599284516,1000.0,20.777411253098048 47 | 
0:18:36,4600,0.09634312383830547,955.2504396220065,1000.0,20.773029330218804 48 | 0:19:27,4700,0.09385764010250569,950.9792614461851,1000.0,20.67998894062729 49 | 0:20:12,4800,0.09241068437695503,933.1803129736008,872.7,20.292268964536607 50 | 0:21:02,4900,0.09214929096400738,954.7903019604785,1000.0,20.763006009829912 51 | 0:21:48,5000,0.09068352900445462,946.4287616329975,913.0,20.58086400178681 52 | 0:22:38,5100,0.08967191837728024,952.5650992567414,1000.0,20.714533734042455 53 | 0:23:28,5200,0.08965510010719299,951.4362642962633,1000.0,20.689943974452945 54 | 0:24:18,5300,0.08926689602434636,946.7429004905401,1000.0,20.58770698526548 55 | 0:25:09,5400,0.08774110153317452,959.6571896455977,1000.0,20.86902292312212 56 | 0:25:58,5500,0.08755047000944614,958.6075892369212,1000.0,20.846159154176238 57 | 0:26:22,5600,0.08604740157723427,1060.1089300420163,426.5,23.05719412013171 58 | 0:26:49,5700,0.0861504279077053,1052.6448173930398,486.5,22.894601055588772 59 | 0:27:21,5800,0.08460824616253376,1802.4662611414053,629.6,39.2281924860148 60 | 0:27:46,5900,0.08482429087162018,1398.842638637896,449.9,30.435934813729205 61 | 0:28:13,6000,0.08335200525820255,1364.2197267967435,503.7,29.681733262333154 62 | 0:28:51,6100,0.08326162822544575,2724.013940284763,747.8,59.30254938829131 63 | 0:29:24,6200,0.08177453130483628,2275.4408929848487,632.3,49.53114455266649 64 | 0:29:52,6300,0.08180231414735317,1809.9605658348269,527.5,39.391443233160075 65 | 0:30:19,6400,0.0798805160075426,1816.1251110052222,507.3,39.52572741909146 66 | 0:30:52,6500,0.08021790690720082,2262.99170639432,623.0,49.25996008720983 67 | 0:31:26,6600,0.07834820725023746,2485.5133931036753,673.5,54.107218518431246 68 | 0:31:59,6700,0.07825386881828308,2246.739524213547,619.5,48.90593379760871 69 | 0:32:15,6800,0.07799843303859234,859.3917276986089,268.5,18.684909486944317 70 | 0:32:34,6900,0.0782447686046362,1018.0939838703555,328.8,22.14196960840176 71 | 
0:32:56,7000,0.07609723068773747,1319.8627997862645,393.5,28.715492660735297 72 | 0:33:22,7100,0.07508274361491203,1742.9988072451047,481.4,37.932794623699415 73 | 0:34:00,7200,0.07523761853575707,2697.827252900294,729.5,58.73211671232513 74 | 0:34:40,7300,0.07460158471018076,2865.3355979974835,797.4,62.38100258084197 75 | 0:35:14,7400,0.07370971314609051,2369.5904806837398,651.6,51.58203401660242 76 | 0:35:53,7500,0.07390922486782074,2739.8875241630467,765.2,59.64832855447304 77 | 0:36:22,7600,0.07278675150126218,2022.8349065186467,550.8,44.02855055483023 78 | 0:36:52,7700,0.07301000311970711,2065.643421705468,563.5,44.961061624811556 79 | 0:37:35,7800,0.07203142952173948,3093.859900838124,849.3,67.3590178478668 80 | 0:38:10,7900,0.07183183781802654,2548.989024382585,678.7,55.48992774306368 81 | 0:38:43,8000,0.0707525610551238,2217.6425402353907,628.4,48.27210523466306 82 | 0:39:22,8100,0.07156692508608103,2811.9660386698574,756.4,61.21843703387441 83 | 0:40:02,8200,0.07091634552925825,2837.9575161180073,783.6,61.78461739167928 84 | 0:40:37,8300,0.06977598804980517,2362.169326409134,679.7,51.420376727558214 85 | 0:41:14,8400,0.06864094391465186,2626.71761265437,715.5,57.18311351933125 86 | 0:41:55,8500,0.0685340815782547,2827.096156334178,787.9,61.548021046553316 87 | 0:42:37,8600,0.06796952094882727,3182.3891247230104,840.4,69.28747719594823 88 | 0:43:21,8700,0.06771729532629252,3257.418246574041,858.5,70.92185966382235 89 | 0:44:07,8800,0.06635470025241375,3349.623290088861,908.1,72.93039052293862 90 | 0:44:43,8900,0.0672356055304408,2163.0924075737266,701.7,47.08382289518095 91 | 0:45:17,9000,0.06598716668784618,2426.224042116643,696.2,52.815700326638506 92 | 0:46:03,9100,0.0660145130008459,3565.3712112254652,945.4,77.63009393258355 93 | 0:46:40,9200,0.065955640822649,2694.3810380742475,732.8,58.657046753444355 94 | 0:47:20,9300,0.06478048045188188,2791.5353131341512,790.9,60.77338824751376 95 | 
0:48:03,9400,0.06447704419493676,3405.515678521734,895.1,74.14791163325725 96 | 0:48:41,9500,0.06458344988524914,2892.8715390913494,790.3,62.980826465887354 97 | 0:49:15,9600,0.0641660388931632,2319.0726915659952,689.3,50.4815894583716 98 | 0:49:53,9700,0.0638224594667554,2772.759542962325,778.2,60.36438986351834 99 | 0:50:09,9800,0.062431233264505864,351.2949492810444,264.2,7.616880884959842 100 | 0:50:29,9900,0.06320817355066538,1188.385168331039,354.4,25.851474923799962 101 | 0:51:03,10000,0.06284180399030447,2238.9616041793024,673.8,48.73650497014974 102 | 0:51:24,10100,0.06180677805095911,647.4338037593966,381.1,14.067764753449282 103 | 0:52:00,10200,0.06121071372181177,806.2396350922079,719.7,17.52708108453806 104 | 0:52:44,10300,0.06216386370360851,3382.4471964039963,904.3,73.64540378292472 105 | 0:53:29,10400,0.06148246329277754,3487.3942015735506,939.0,75.9314967168867 106 | 0:53:47,10500,0.06088115517050028,912.6559727568787,305.0,19.84518094074903 107 | 0:54:32,10600,0.06051419354975224,3334.2033832958464,950.7,72.5944939444235 108 | 0:55:02,10700,0.05948980920016766,1919.4421798911885,599.2,41.776314949019294 109 | 0:55:48,10800,0.059083571061491966,3333.6011382670395,961.9,72.5813750554886 110 | 0:56:25,10900,0.059086164310574535,2642.8649888313735,737.9,57.53485678747534 111 | 0:57:10,11000,0.0591065426543355,3335.6547462237454,929.9,72.62610943005573 112 | 0:57:35,11100,0.059078799411654476,1395.605559933796,478.8,30.365420531400083 113 | 0:58:15,11200,0.05859747052192688,2745.540477025707,821.2,59.771468567610796 114 | 0:58:56,11300,0.058347235433757305,3079.51694509687,837.1,67.04658082577899 115 | 0:59:28,11400,0.0574583775550127,1140.3093900646893,634.1,24.804225440007073 116 | 1:00:02,11500,0.05798234786838293,2206.234210759199,692.0,48.02359408027904 117 | 1:00:20,11600,0.05746835194528103,769.5007886597352,321.0,16.726787478298448 118 | 1:00:56,11700,0.05647266950458288,2174.8371295630122,763.5,47.33966179127594 119 | 
1:01:44,11800,0.05746171165257692,3629.832279031734,1000.0,79.0342692245747 120 | 1:02:29,11900,0.057151501029729844,2863.20815283906,939.7,62.3346597877703 121 | 1:02:50,12000,0.05620826672762633,410.6020975646337,400.4,8.908786760744489 122 | 1:03:30,12100,0.05650175932794809,822.3166294092146,817.7,17.877291202950456 123 | 1:03:54,12200,0.05610295254737139,1699.9142173943615,480.2,36.99426973429163 124 | 1:04:38,12300,0.055670005977153776,3376.4691736278583,920.9,73.51518267175045 125 | 1:05:21,12400,0.0549100199714303,3283.3317409136544,893.3,71.48634129155764 126 | 1:06:02,12500,0.05478613335639238,2520.9944039135826,858.4,54.880112303930986 127 | 1:06:49,12600,0.05467073820531368,3652.600732961265,1000.0,79.53024146848432 128 | 1:07:27,12700,0.0548620667681098,2916.469970233239,786.2,63.49487835902049 129 | 1:08:12,12800,0.05500482130795717,3421.9460102439225,961.9,74.50581860918345 130 | 1:08:56,12900,0.054540783390402796,3244.8488694621715,915.0,70.64805705122444 131 | 1:09:40,13000,0.054118176139891144,3277.4345788748674,923.7,71.35788159472762 132 | 1:10:27,13100,0.053669474497437475,3593.1546082515897,989.4,78.23530822640991 133 | 1:11:11,13200,0.0537726766243577,3257.1323366192482,917.9,70.91563159922588 134 | 1:11:54,13300,0.05333694871515036,3319.652692897153,911.1,72.27753177431698 135 | 1:12:30,13400,0.05274759169667959,2659.339948543132,742.3,57.89373590864234 136 | 1:13:16,13500,0.05339444946497679,3533.6888726486686,964.6,76.9399477942084 137 | 1:14:04,13600,0.05248939294368029,3628.5502557008995,1000.0,79.00634251553655 138 | 1:14:49,13700,0.05243374984711409,3448.556039829983,950.4,75.08547307870289 139 | 1:15:36,13800,0.05286451142281294,3566.3740082427867,1000.0,77.65193816884158 140 | 1:16:24,13900,0.05191159851849079,3668.4904079200714,1000.0,79.87637115162862 141 | 1:17:11,14000,0.05224816113710404,3399.0742727673364,952.0,74.00759650796024 142 | 1:18:00,14100,0.05235660646110773,3636.1996520712646,1000.0,79.17297167244402 143 | 
1:18:48,14200,0.05183851581066847,3608.9293594616815,998.7,78.57893449014308 144 | 1:19:31,14300,0.05191284615546465,3288.16477394831,899.2,71.59162073857263 145 | 1:20:18,14400,0.051286631897091865,3604.908066851395,972.8,78.49133743478245 146 | 1:21:05,14500,0.05122464053332806,3579.385153297506,957.7,77.93536394858911 147 | 1:21:50,14600,0.05106003262102604,3351.8882968386106,929.3,72.97972986251877 148 | 1:22:33,14700,0.05084971088916063,3315.5620309346646,903.7,72.18842362499377 149 | 1:23:18,14800,0.050918541848659515,3405.224403267658,939.2,74.14156669469416 150 | 1:24:00,14900,0.05047809984534979,3159.2701623244893,850.6,68.78386971811308 151 | 1:24:45,15000,0.05070396676659584,3497.1702672563465,943.7,76.14445176637365 152 | 1:25:34,15100,0.050703027993440626,3675.8660937103655,1000.0,80.03703798667621 153 | 1:26:19,15200,0.05062726873904467,3503.340050642893,933.7,76.27885005798065 154 | 1:27:05,15300,0.04995539590716362,3509.46357587561,964.7,76.41224069397674 155 | 1:27:52,15400,0.05034440916031599,3558.943949030171,990.2,77.49008690079027 156 | 1:28:38,15500,0.049157407954335215,3513.428653995726,968.8,76.49861321178571 157 | 1:29:18,15600,0.04994081273674965,1639.2027972729525,835.9,35.67177417259251 158 | 1:29:44,15700,0.04973868321627378,1690.9064134757984,504.3,36.79804996741527 159 | 1:30:26,15800,0.0486184960976243,3043.762365347184,840.8,66.26772780381349 160 | 1:31:13,15900,0.049726711995899675,2007.7946652141309,949.9,43.700924346576016 161 | 1:31:51,16000,0.04952793996781111,2843.7090959251846,779.3,61.90990582592342 162 | 1:32:38,16100,0.048854658231139186,3466.4604374629735,963.9,75.4754900863297 163 | 1:33:23,16200,0.049671712890267374,3359.2205993463394,921.5,73.13945166616156 164 | 1:34:07,16300,0.04871780026704073,3413.588970964875,919.8,74.32377464886454 165 | 1:34:54,16400,0.04853825241327286,3572.929023627098,973.7,77.79472808769516 166 | 1:35:36,16500,0.04851634033024311,3085.1688480542416,848.1,67.1696979685065 167 | 
1:36:19,16600,0.048122283071279526,3164.2249614665066,869.4,68.89180163374485 168 | 1:36:58,16700,0.04818148322403431,2839.654756447764,795.5,61.82158890047863 169 | 1:37:40,16800,0.048443644009530544,3134.268125608673,855.7,68.23924265249704 170 | 1:38:26,16900,0.04806218471378088,3587.596904645191,966.8,78.11424305715505 171 | 1:39:09,17000,0.04791088093072176,3387.8654647081044,905.4,73.76343159004813 172 | 1:39:47,17100,0.04817161962389946,2755.4973580704022,763.2,59.98836237381139 173 | 1:40:13,17200,0.04783609367907047,1697.4552426125617,501.6,36.94070512933333 174 | 1:40:59,17300,0.04717900436371565,3509.8153004690384,944.1,76.41990241911543 175 | 1:41:42,17400,0.04731346067041159,3224.9008611147083,880.1,70.2135234420369 176 | 1:42:22,17500,0.04703417900949716,2987.8995616563648,815.0,65.05085114704218 177 | 1:43:08,17600,0.04767159968614578,3549.1078813984263,967.3,77.27582481037068 178 | 1:43:52,17700,0.04741339858621359,3462.8976696418476,921.5,75.39788121766247 179 | 1:44:38,17800,0.04693534675985575,3546.8316697119744,947.6,77.2262413902036 180 | 1:45:26,17900,0.04685991924256086,3634.047275354323,1000.0,79.12608578755503 181 | 1:46:12,18000,0.04712756771594286,3572.7736031112154,962.8,77.79134251473309 182 | 1:46:55,18100,0.0465202646702528,3288.5827803168722,869.8,71.60072630003175 183 | 1:47:11,18200,0.046675598435103896,756.6239867202885,276.6,16.446288136004334 184 | 1:47:56,18300,0.045956105068325995,3532.873654331772,954.3,76.92218964255001 185 | 1:48:35,18400,0.04659823406487704,2775.9731886159198,784.0,60.434393696491675 186 | 1:49:24,18500,0.046758133359253405,3643.773719149355,997.1,79.33795990817882 187 | 1:50:08,18600,0.04621251810342073,3271.936226378092,905.1,71.23810928897194 188 | 1:50:56,18700,0.046776832677423955,3612.3178795906774,992.2,78.65274766766987 189 | 1:51:32,18800,0.04604645080864429,2529.699785653142,703.1,55.06974431534566 190 | 1:52:20,18900,0.04541662193834781,3605.624995447715,974.9,78.50695451118739 191 | 
1:53:08,19000,0.046127894744277004,2984.337306972411,949.0,64.97325345620001 192 | 1:53:55,19100,0.04593969512730837,3589.5081881743113,940.0,78.15587713488466 193 | 1:54:41,19200,0.045254308916628364,3389.7754615759127,914.3,73.80503764003815 194 | 1:55:25,19300,0.04580109968781471,3001.478924744676,869.0,65.34665459520859 195 | 1:56:10,19400,0.045522925853729246,2633.242197317093,873.1,57.325240556404765 196 | 1:56:59,19500,0.045008112378418445,3732.4503127116323,984.5,81.26962945532804 197 | 1:57:45,19600,0.04549287892878055,3356.245898377656,924.9,73.07465283884704 198 | 1:58:22,19700,0.04535468854010105,2515.334487950659,710.1,54.75682061143578 199 | 1:59:09,19800,0.045302155017852785,3238.040286858045,928.4,70.49974359953097 200 | 1:59:56,19900,0.04541571393609047,3415.126714531359,937.7,74.35727179054959 201 | 2:00:44,20000,0.045301046781241895,3460.807956474319,952.3,75.35236035216872 202 | -------------------------------------------------------------------------------- /dt_runs/dt_walker2d-medium-v2_log_22-02-22-09-24-12.csv: -------------------------------------------------------------------------------- 1 | duration,num_updates,action_loss,eval_avg_reward,eval_avg_ep_len,eval_d4rl_score 2 | 0:00:21,100,0.7494068479537964,4.606717397848781,30.2,0.06486436085352076 3 | 0:00:26,200,0.6946344923973083,14.652548691326103,41.6,0.28369579771719133 4 | 0:00:35,300,0.6070064491033554,35.68925987752296,109.8,0.7419449561268615 5 | 0:00:46,400,0.4988633921742439,259.64411563073816,159.6,5.620422550001338 6 | 0:00:56,500,0.4002128967642784,23.06346062672421,144.1,0.4669132827004434 7 | 0:01:06,600,0.32265163868665697,12.281529137764652,127.1,0.2320471485830377 8 | 0:01:16,700,0.2724792277812958,17.600136380731676,132.8,0.34790400811916156 9 | 0:01:25,800,0.23924661606550215,22.849816848991885,136.2,0.46225941449458346 10 | 0:01:37,900,0.2202159868180752,92.08859475785208,172.0,1.9705090370338625 11 | 
0:01:49,1000,0.20659611627459526,93.0872986808466,173.1,1.9922641121576286 12 | 0:02:00,1100,0.1943126106262207,81.7246228652983,169.1,1.7447474455232819 13 | 0:02:12,1200,0.18469071850180627,73.96044325523161,165.7,1.575617929955796 14 | 0:02:23,1300,0.17855634063482284,67.10855316687442,165.4,1.4263610979959858 15 | 0:02:35,1400,0.17178114444017412,81.77458732496645,183.0,1.745835836735704 16 | 0:02:47,1500,0.1680697436630726,87.79718456747528,185.4,1.8770279272384691 17 | 0:03:00,1600,0.16065416529774665,118.43606757337575,206.3,2.5444441515615317 18 | 0:03:14,1700,0.15667216792702676,137.8676317093936,229.0,2.9677278974427015 19 | 0:03:30,1800,0.1509039180725813,171.80215483026345,258.7,3.706934065342913 20 | 0:03:46,1900,0.14685582980513573,445.2970160229248,285.0,9.664556854457427 21 | 0:04:03,2000,0.14311599165201186,340.85618752997965,293.3,7.389490122928409 22 | 0:04:16,2100,0.13905903108417988,106.0129678976753,194.7,2.2738279454045283 23 | 0:04:43,2200,0.13523478865623473,406.1300420145707,530.5,8.811370597446002 24 | 0:05:25,2300,0.13459579899907112,799.1460299941725,846.3,17.372558900081867 25 | 0:05:56,2400,0.1310204781591892,534.3074660376803,614.3,11.6034988995282 26 | 0:06:44,2500,0.12758126646280288,965.0037845717137,1000.0,20.9854894469796 27 | 0:07:33,2600,0.12636793226003648,954.1331610693567,1000.0,20.748691307420025 28 | 0:08:23,2700,0.12255082719027996,959.2510062710453,1000.0,20.86017490558263 29 | 0:09:12,2800,0.12107995934784413,955.3401717983336,1000.0,20.774983993850405 30 | 0:10:00,2900,0.11841835044324397,955.5349713588718,1000.0,20.779227372669702 31 | 0:10:49,3000,0.11626955851912499,956.5744064783552,1000.0,20.8018697079905 32 | 0:11:39,3100,0.11559994503855706,955.0500470579551,1000.0,20.768664117281507 33 | 0:12:27,3200,0.11285074077546596,949.8220122799621,1000.0,20.654780225643364 34 | 0:13:17,3300,0.11198795266449452,952.9561336614067,1000.0,20.723051756905488 35 | 
0:14:06,3400,0.10931384190917015,948.5011477898072,1000.0,20.626007427669894 36 | 0:14:56,3500,0.10756863370537757,954.0814781275969,1000.0,20.747565481982964 37 | 0:15:46,3600,0.1070942847430706,953.1023351871443,1000.0,20.726236509766068 38 | 0:16:34,3700,0.1067583317309618,948.4609259804881,1000.0,20.625131263610452 39 | 0:17:23,3800,0.10349807038903236,951.0125402460313,1000.0,20.680713862973153 40 | 0:18:12,3900,0.10242165714502334,947.5412767821887,1000.0,20.60509826190107 41 | 0:19:01,4000,0.10137719236314297,948.3888322523893,1000.0,20.62356082372869 42 | 0:19:50,4100,0.09912252336740494,943.3411523247951,1000.0,20.513605657340364 43 | 0:20:13,4200,0.09891738384962082,942.2157492064262,434.1,20.48909065462442 44 | 0:20:33,4300,0.09791590325534344,716.1864775340173,371.9,15.565425419927747 45 | 0:21:21,4400,0.09603212863206863,946.6338382107275,1000.0,20.585331248036592 46 | 0:22:10,4500,0.09664956159889698,946.1774027564479,1000.0,20.575388573968358 47 | 0:22:59,4600,0.09455271989107132,946.731835708574,1000.0,20.587465957712308 48 | 0:23:38,4700,0.0934526763111353,1274.6431864002375,781.9,27.730459896138804 49 | 0:24:27,4800,0.09207363754510879,1190.21829844377,1000.0,25.89140656159159 50 | 0:25:16,4900,0.09079103358089924,950.202220100163,1000.0,20.663062409682766 51 | 0:26:05,5000,0.08993766807019711,950.534000395545,1000.0,20.670289682034895 52 | 0:26:25,5100,0.08961838528513909,1056.55911399889,357.9,22.979867384033387 53 | 0:26:50,5200,0.08887228973209858,1394.9463172158041,485.4,30.35106004424819 54 | 0:27:10,5300,0.08791158027946949,1169.409667608464,359.3,25.4381257477069 55 | 0:27:43,5400,0.08807892374694347,2313.052289896596,648.0,50.350445194714055 56 | 0:28:03,5500,0.08676943302154541,1116.5205584220048,363.0,24.2860259941278 57 | 0:28:51,5600,0.08543437473475933,1046.7782923029645,1000.0,22.766808732847746 58 | 0:29:30,5700,0.08276517026126384,1322.9285596813893,799.1,28.782275052687744 59 | 
0:30:15,5800,0.08395211078226567,3291.7550561543053,936.2,71.66982896155903 60 | 0:30:59,5900,0.08324780516326427,3328.2595616748504,906.8,72.46501784754454 61 | 0:31:38,6000,0.08193713515996932,2022.663653828234,795.1,44.024820104734566 62 | 0:32:08,6100,0.08238500386476516,2169.2482999841072,619.6,47.217918595376155 63 | 0:32:47,6200,0.08093899175524712,3066.6605201679245,816.2,66.76652536218009 64 | 0:33:23,6300,0.07938482321798801,2590.4250293460545,744.7,56.39254100016006 65 | 0:34:09,6400,0.07909935839474201,3617.235666407665,965.9,78.75987333242689 66 | 0:34:49,6500,0.07864443875849247,3038.5747458911383,824.6,66.15472429157995 67 | 0:35:21,6600,0.07761473648250103,2495.906807109336,645.2,54.33362145655888 68 | 0:35:31,6700,0.07672786429524421,387.55222076072045,166.5,8.406684195692856 69 | 0:36:08,6800,0.07613473787903785,2829.189242414216,760.0,61.59361538523901 70 | 0:36:49,6900,0.07574322041124106,3058.225981224791,824.1,66.58279320281098 71 | 0:37:30,7000,0.07613274045288562,2935.037249731478,813.8,63.89933512646462 72 | 0:38:04,7100,0.07470361799001694,2512.9311978360192,679.9,54.70446900273134 73 | 0:38:48,7200,0.0742376471683383,3276.0741093436336,883.1,71.32824606794722 74 | 0:39:25,7300,0.07340330451726913,2625.5101689279345,744.7,57.15681140078388 75 | 0:39:50,7400,0.07330021433532238,1733.6030394150675,492.9,37.728123719458814 76 | 0:40:31,7500,0.07173143241554499,3126.2839910217212,832.7,68.0653217899289 77 | 0:41:04,7600,0.07224882706999779,2069.5974360233995,629.9,45.04719313641023 78 | 0:41:40,7700,0.07072153046727181,2527.5357119300065,737.3,55.022603630968426 79 | 0:42:13,7800,0.07123743306845426,2057.841302814178,667.3,44.79110566619709 80 | 0:42:52,7900,0.07041781555861235,2976.76446360117,796.4,64.80829187684836 81 | 0:43:35,8000,0.07028338529169559,2976.6673711008148,872.8,64.8061768810117 82 | 0:44:16,8100,0.0702977330237627,3229.747983757081,856.1,70.31910980731595 83 | 
0:45:02,8200,0.06899643313139676,3550.9793980431496,957.3,77.31659263381054 84 | 0:45:37,8300,0.06884092830121517,2645.884212929734,706.4,57.60062547583532 85 | 0:46:21,8400,0.06746281534433365,3258.645179796047,905.1,70.94858632805388 86 | 0:47:07,8500,0.06787878714501858,3497.263343674341,957.5,76.1464792786514 87 | 0:47:43,8600,0.06744561642408371,1599.5847054469505,746.1,34.808761077229256 88 | 0:48:30,8700,0.06655144702643156,3517.9186361732923,971.3,76.59641987633194 89 | 0:49:10,8800,0.06614126697182655,2782.0109568692897,833.6,60.56591626179622 90 | 0:49:39,8900,0.0657840321213007,1985.8184626288905,575.5,43.22220995768738 91 | 0:50:21,9000,0.06542456336319447,3198.1235150629627,882.4,69.63022426641727 92 | 0:51:03,9100,0.06584585767239332,2826.971390177607,866.3,61.5453032269407 93 | 0:51:36,9200,0.06418715581297875,2183.652721779177,657.4,47.53169455144385 94 | 0:52:23,9300,0.06473316632211208,3465.023827797356,997.5,75.44419597555327 95 | 0:53:08,9400,0.0636089513450861,3244.2944089529037,898.6,70.6359790671948 96 | 0:53:44,9500,0.0635768450051546,1250.2898817419693,704.1,27.199964360721264 97 | 0:54:28,9600,0.06302034344524145,3096.8365176579205,885.7,67.42385840875613 98 | 0:55:16,9700,0.06269818644970655,3610.539782432484,986.4,78.61401482967534 99 | 0:55:57,9800,0.06219020787626505,2849.5357965666267,824.0,62.036830640435205 100 | 0:56:40,9900,0.06198899317532778,1075.7531011088724,861.9,23.397975916390227 101 | 0:57:19,10000,0.06184269107878208,2459.6909565185524,798.1,53.544720429804926 102 | 0:57:36,10100,0.062097062319517136,873.6943862112569,301.8,18.99646870209114 103 | 0:58:15,10200,0.0608263423666358,2365.559411167901,765.9,51.49422398789717 104 | 0:58:40,10300,0.060862679332494736,1059.2093641603064,474.5,23.037598599492625 105 | 0:59:18,10400,0.06031456973403692,2837.1617978262902,777.8,61.767284015074765 106 | 0:59:59,10500,0.05958413891494274,2943.2443995556528,809.4,64.07811399862682 107 | 
1:00:43,10600,0.05949912000447512,3072.2868475402115,877.8,66.88908538406125 108 | 1:01:23,10700,0.05919633124023676,2914.1461304970007,825.1,63.44425744237697 109 | 1:01:57,10800,0.05845240533351898,2457.386715268039,714.5,53.49452643301167 110 | 1:02:43,10900,0.058647090084850785,3306.560767209371,945.7,71.99234632516158 111 | 1:03:30,11000,0.05872486475855112,3580.328619223933,1000.0,77.95591575742209 112 | 1:04:01,11100,0.05813798546791077,2032.3342414930466,613.9,44.23547749407189 113 | 1:04:44,11200,0.057435166575014594,2920.2071715699103,891.4,63.57628696667683 114 | 1:05:20,11300,0.05736306045204401,2546.6892811916478,720.0,55.439831728887434 115 | 1:05:49,11400,0.057853570766747,898.2272793691515,567.4,19.53087626910361 116 | 1:06:28,11500,0.05715966358780861,2789.5256230289024,793.5,60.72961054903022 117 | 1:07:16,11600,0.056575519666075706,3373.5248584161077,1000.0,73.45104574673705 118 | 1:07:48,11700,0.05598588481545448,1870.3930556548617,642.6,40.70786277020267 119 | 1:08:22,11800,0.05605641216039658,2100.788792097527,706.5,45.726644051722694 120 | 1:09:00,11900,0.056266498155891895,2491.713257173852,760.4,54.242272066833664 121 | 1:09:37,12000,0.05553261686116457,1220.879507038972,749.8,26.55930911110198 122 | 1:10:12,12100,0.055310980193316935,2509.3450378432985,729.1,54.62635057518621 123 | 1:10:47,12200,0.05541072141379118,2202.605736237166,701.1,47.944553902310375 124 | 1:11:32,12300,0.05550422679632902,3583.4321528053406,952.9,78.02352098521594 125 | 1:12:03,12400,0.05502849627286196,1810.6636283520425,615.0,39.406758260493575 126 | 1:12:48,12500,0.05469878222793341,3442.672426201506,921.0,74.95730851106714 127 | 1:13:30,12600,0.05395554948598146,2995.761290297045,830.7,65.22210560318553 128 | 1:14:07,12700,0.054322229959070684,2013.616424856363,753.3,43.82774152978033 129 | 1:14:45,12800,0.05432156518101692,1613.9260824995595,746.2,35.121163710256134 130 | 1:15:32,12900,0.053966226428747176,3476.2475253206017,940.5,75.68868523524549 131 | 
1:16:19,13000,0.05412171874195337,3609.5151894883115,952.9,78.59169580603024 132 | 1:17:02,13100,0.053484919518232345,2112.4067495820777,876.8,45.97972159757158 133 | 1:17:35,13200,0.05350233294069767,2050.6770102221813,640.5,44.63504367429042 134 | 1:18:23,13300,0.052814999409019944,3625.0935981508637,955.0,78.93104507958307 135 | 1:19:10,13400,0.05339811958372593,3496.5986860040343,941.5,76.13200083592561 136 | 1:19:59,13500,0.05256705578416586,3606.131805319031,961.6,78.51799450669566 137 | 1:20:40,13600,0.052404896467924116,2318.0159350712806,816.1,50.458569806199705 138 | 1:21:28,13700,0.05198549743741751,3560.8887976425226,957.4,77.5324521370649 139 | 1:22:13,13800,0.05160391606390476,3332.1988431184227,900.8,72.5508284284038 140 | 1:23:00,13900,0.05209978446364403,3497.451766861491,948.3,76.15058375896547 141 | 1:23:47,14000,0.0516047441214323,3566.8073266822284,951.8,77.6613772778563 142 | 1:24:35,14100,0.05175862479954958,3628.7424043738406,966.5,79.01052814925494 143 | 1:25:24,14200,0.051176839731633665,3639.1209615268044,1000.0,79.23660745598482 144 | 1:26:10,14300,0.05168245106935501,3045.3764123561573,913.8,66.30288708688529 145 | 1:26:49,14400,0.05087256103754043,2699.6489849911773,784.7,58.77180006349663 146 | 1:27:36,14500,0.051584865637123584,3352.7236093781626,948.6,72.99792573281763 147 | 1:28:22,14600,0.051273541562259196,3526.27379187904,954.4,76.77842280619355 148 | 1:29:03,14700,0.05075372889637947,3087.949756916624,880.4,67.23027536268764 149 | 1:29:52,14800,0.05068669371306896,3653.168443440106,1000.0,79.54260808068176 150 | 1:30:37,14900,0.050353191569447515,3584.6109906806037,958.7,78.04919997369751 151 | 1:31:22,15000,0.05044253397732973,3516.1208661480823,960.5,76.55725849821656 152 | 1:32:09,15100,0.050561795346438884,3652.305165699165,1000.0,79.52380303578866 153 | 1:32:55,15200,0.050162769369781014,3576.6880642852607,977.2,77.87661242801738 154 | 1:33:42,15300,0.0496873714402318,3719.003078548383,1000.0,80.97670421222779 155 | 
1:34:29,15400,0.049916689582169056,3798.2096832440416,1000.0,82.70208607543883 156 | 1:35:15,15500,0.04934413980692625,3669.0213405194263,977.8,79.88793661995079 157 | 1:36:00,15600,0.04988970138132572,3524.722614732548,944.4,76.74463303669809 158 | 1:36:42,15700,0.0494575371965766,3346.7982208338276,904.2,72.86885116932439 159 | 1:37:28,15800,0.04870194889605045,3689.170749260478,987.2,80.32685739593681 160 | 1:38:13,15900,0.048884320706129074,3584.6023808466816,952.9,78.04901242303363 161 | 1:38:58,16000,0.04952946279197931,3445.6371528282116,962.4,75.02189006421855 162 | 1:39:45,16100,0.048653635792434216,3689.487718476959,995.2,80.3337620339959 163 | 1:40:29,16200,0.04895910393446684,3529.9106195019585,942.2,76.8576449423313 164 | 1:41:14,16300,0.04875406168401241,3599.880280415051,958.8,78.3818156144406 165 | 1:41:57,16400,0.048319988027215,3457.046344394226,933.8,75.27041999777067 166 | 1:42:44,16500,0.04863546460866928,3698.028082536486,1000.0,80.5197994144662 167 | 1:43:31,16600,0.048063377067446705,3731.270544411052,1000.0,81.24393019910524 168 | 1:44:16,16700,0.048136615045368675,3602.9238335669716,985.6,78.44811427006685 169 | 1:45:03,16800,0.04824067141860724,3731.0335520541676,1000.0,81.23876772160908 170 | 1:45:46,16900,0.04851640034466982,3411.691126188264,902.8,74.28243331161956 171 | 1:46:30,17000,0.047490435615181924,3524.0773632910436,955.1,76.7305773258308 172 | 1:47:12,17100,0.048166980035603046,3348.475271067876,888.7,72.90538287105102 173 | 1:47:57,17200,0.04686649281531572,3562.7102219148155,962.2,77.57212878292924 174 | 1:48:43,17300,0.047712755538523194,3668.1152882967654,1000.0,79.86819980534918 175 | 1:49:27,17400,0.047162753045558926,3295.0783224688043,915.3,71.74222069514853 176 | 1:50:05,17500,0.04735571425408125,2973.161166879114,806.4,64.72980015465055 177 | 1:50:51,17600,0.047372789904475215,3628.656922372046,1000.0,79.00866606848409 178 | 1:51:37,17700,0.04723315633833408,3522.1038574739628,965.9,76.68758784084004 179 | 
1:52:13,17800,0.04710644524544477,2881.1054315598803,773.2,62.72452172194091 180 | 1:53:00,17900,0.04633235666900873,3643.2769196480285,1000.0,79.32713797164291 181 | 1:53:39,18000,0.047232932448387145,3068.7864523096523,815.3,66.81283519674311 182 | 1:54:13,18100,0.04670387715101242,2604.134519607962,724.0,56.691179048623965 183 | 1:54:58,18200,0.046496619582176206,3663.4346751902,973.7,79.76624056856828 184 | 1:55:41,18300,0.04675097044557333,3359.6542645027716,904.6,73.14889832781925 185 | 1:56:25,18400,0.04648404803127051,3401.7116735396107,932.2,74.06504782121861 186 | 1:57:11,18500,0.046322394274175165,3493.7759771978635,986.5,76.07051290069587 187 | 1:57:53,18600,0.045983492545783516,3246.6756046948653,884.3,70.68784938737481 188 | 1:58:34,18700,0.045901420973241326,3194.913798682191,875.3,69.56030602600394 189 | 1:59:18,18800,0.045691707953810695,3535.1720540801166,948.9,76.97225639210252 190 | 2:00:02,18900,0.04533644154667854,3522.9424058397626,930.9,76.7058541981386 191 | 2:00:47,19000,0.045702502317726615,3548.4087483589938,971.9,77.26059538006189 192 | 2:01:33,19100,0.046290372088551524,3547.4913433254696,967.7,77.24061126368494 193 | 2:02:14,19200,0.045040771886706354,3076.522425220322,844.4,66.98135027709087 194 | 2:03:00,19300,0.045772487185895445,3560.24620831815,969.2,77.51845441591493 195 | 2:03:45,19400,0.045721214152872565,3546.3807132790703,951.4,77.21641806734534 196 | 2:04:31,19500,0.045273531340062616,3684.1985050913245,1000.0,80.21854546990642 197 | 2:05:13,19600,0.045986538603901866,3298.421755705741,887.4,71.81505173102025 198 | 2:05:56,19700,0.04526801057159901,3392.336908912027,928.2,73.86083443620537 199 | 2:06:36,19800,0.045240545570850374,3025.6640998622634,826.4,65.87348771306291 200 | 2:07:15,19900,0.04544916365295649,3157.9841795720085,824.5,68.75585675977383 201 | 2:07:47,20000,0.04539958458393812,2211.1165392693933,642.3,48.12994734581914 202 | -------------------------------------------------------------------------------- 
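The `eval_d4rl_score` column in these logs is the D4RL normalized return, 100 * (R - R_random) / (R_expert - R_random), where R is `eval_avg_reward`. The sketch below is not part of the repository: it reads a log with Python's standard `csv` module, picks the row with the highest score (the same criterion `train.py` uses when it writes its `_best.pt` checkpoint), and recomputes the score from the raw reward. The walker2d reference returns 1.629008 (random policy) and 4592.3 (expert policy) are assumed values taken from D4RL. They reproduce the logged walker2d scores above, but the repository's own `decision_transformer/d4rl_infos.py` should be treated as authoritative.

```python
import csv
import io

# Assumed D4RL reference returns for walker2d (random / expert policy).
# With these values the formula reproduces the eval_d4rl_score column above.
WALKER2D_REF_MIN = 1.629008
WALKER2D_REF_MAX = 4592.3


def d4rl_normalized_score(avg_reward, ref_min=WALKER2D_REF_MIN, ref_max=WALKER2D_REF_MAX):
    """Map a raw average episode return onto the 0-100 D4RL scale."""
    return 100.0 * (avg_reward - ref_min) / (ref_max - ref_min)


def best_row(csv_file):
    """Return the log row (as a dict of strings) with the highest eval_d4rl_score."""
    rows = list(csv.DictReader(csv_file))
    return max(rows, key=lambda r: float(r["eval_d4rl_score"]))


# Three rows copied verbatim from dt_walker2d-medium-v2_log_22-02-22-09-24-12.csv
SAMPLE_LOG = """duration,num_updates,action_loss,eval_avg_reward,eval_avg_ep_len,eval_d4rl_score
1:33:42,15300,0.0496873714402318,3719.003078548383,1000.0,80.97670421222779
1:34:29,15400,0.049916689582169056,3798.2096832440416,1000.0,82.70208607543883
1:35:15,15500,0.04934413980692625,3669.0213405194263,977.8,79.88793661995079
"""

best = best_row(io.StringIO(SAMPLE_LOG))
print(best["num_updates"], best["eval_d4rl_score"])
# Recompute the normalized score from the raw reward of the best row
print(round(d4rl_normalized_score(float(best["eval_avg_reward"])), 3))
```

To inspect a real run, pass an open log file (for example `open("dt_runs/<log>.csv", newline="")`) instead of the `StringIO` sample. Note that `max` keeps the first of several tied rows, whereas `train.py` overwrites the best checkpoint on ties (`>=`), so the two can only disagree when scores are exactly equal.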
/dt_runs/dt_walker2d-medium-v2_model_22-02-22-09-24-12_best.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nikhilbarhate99/min-decision-transformer/d6694248b48c57c84fc7487e6e8017dcca861b02/dt_runs/dt_walker2d-medium-v2_model_22-02-22-09-24-12_best.pt -------------------------------------------------------------------------------- /media/halfcheetah-medium-v2.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nikhilbarhate99/min-decision-transformer/d6694248b48c57c84fc7487e6e8017dcca861b02/media/halfcheetah-medium-v2.gif -------------------------------------------------------------------------------- /media/halfcheetah-medium-v2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nikhilbarhate99/min-decision-transformer/d6694248b48c57c84fc7487e6e8017dcca861b02/media/halfcheetah-medium-v2.png -------------------------------------------------------------------------------- /media/hopper-medium-v2.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nikhilbarhate99/min-decision-transformer/d6694248b48c57c84fc7487e6e8017dcca861b02/media/hopper-medium-v2.gif -------------------------------------------------------------------------------- /media/hopper-medium-v2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nikhilbarhate99/min-decision-transformer/d6694248b48c57c84fc7487e6e8017dcca861b02/media/hopper-medium-v2.png -------------------------------------------------------------------------------- /media/walker2d-medium-v2.gif: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/nikhilbarhate99/min-decision-transformer/d6694248b48c57c84fc7487e6e8017dcca861b02/media/walker2d-medium-v2.gif -------------------------------------------------------------------------------- /media/walker2d-medium-v2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nikhilbarhate99/min-decision-transformer/d6694248b48c57c84fc7487e6e8017dcca861b02/media/walker2d-medium-v2.png -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | gym==0.21.0 2 | matplotlib==3.3.3 3 | numpy==1.18.2 4 | pandas==1.0.3 5 | torch==1.9.0 6 | -------------------------------------------------------------------------------- /scripts/plot.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import glob 3 | import pandas as pd 4 | import matplotlib.pyplot as plt 5 | 6 | 7 | def plot(args): 8 | 9 | # env_d4rl_name = 'halfcheetah-medium-v2' 10 | # log_dir = 'dt_runs/' 11 | # x_key = "num_updates" 12 | # y_key = "eval_d4rl_score" 13 | # y_smoothing_win = 5 14 | # plot_avg = False 15 | # save_fig = False 16 | 17 | env_d4rl_name = args.env_d4rl_name 18 | log_dir = args.log_dir 19 | x_key = args.x_key 20 | y_key = args.y_key 21 | y_smoothing_win = args.smoothing_window 22 | plot_avg = args.plot_avg 23 | save_fig = args.save_fig 24 | 25 | if plot_avg: 26 | save_fig_path = env_d4rl_name + "_avg.png" 27 | else: 28 | save_fig_path = env_d4rl_name + ".png" 29 | 30 | all_files = glob.glob(log_dir + f'/dt_{env_d4rl_name}*.csv') 31 | 32 | ax = plt.gca() 33 | ax.set_title(env_d4rl_name) 34 | 35 | if plot_avg: 36 | name_list = [] 37 | df_list = [] 38 | for filename in all_files: 39 | frame = pd.read_csv(filename, index_col=None, header=0) 40 | print(filename, frame.shape) 41 | frame['y_smooth'] = 
frame[y_key].rolling(window=y_smoothing_win).mean() 42 | df_list.append(frame) 43 | 44 | df_concat = pd.concat(df_list) 45 | df_concat_groupby = df_concat.groupby(df_concat.index) 46 | data_avg = df_concat_groupby.mean() 47 | 48 | data_avg.plot(x=x_key, y='y_smooth', ax=ax) 49 | 50 | ax.set_xlabel(x_key) 51 | ax.set_ylabel(y_key) 52 | ax.legend(['avg of all runs'], loc='lower right') 53 | 54 | if save_fig: 55 | plt.savefig(save_fig_path) 56 | 57 | plt.show() 58 | 59 | else: 60 | name_list = [] 61 | for filename in all_files: 62 | frame = pd.read_csv(filename, index_col=None, header=0) 63 | print(filename, frame.shape) 64 | frame['y_smooth'] = frame[y_key].rolling(window=y_smoothing_win).mean() 65 | frame.plot(x=x_key, y='y_smooth', ax=ax) 66 | name_list.append(filename.split('/')[-1]) 67 | 68 | ax.set_xlabel(x_key) 69 | ax.set_ylabel(y_key) 70 | ax.legend(name_list, loc='lower right') 71 | 72 | if save_fig: 73 | plt.savefig(save_fig_path) 74 | 75 | plt.show() 76 | 77 | 78 | if __name__ == "__main__": 79 | 80 | parser = argparse.ArgumentParser() 81 | parser.add_argument('--env_d4rl_name', type=str, default='halfcheetah-medium-v2') 82 | parser.add_argument('--log_dir', type=str, default='dt_runs/') 83 | parser.add_argument('--x_key', type=str, default='num_updates') 84 | parser.add_argument('--y_key', type=str, default='eval_d4rl_score') 85 | parser.add_argument('--smoothing_window', type=int, default=1) 86 | parser.add_argument("--plot_avg", action="store_true", default=False, 87 | help="plot avg of all logs else plot separately") 88 | parser.add_argument("--save_fig", action="store_true", default=False, 89 | help="save figure if true") 90 | 91 | args = parser.parse_args() 92 | 93 | plot(args) 94 | -------------------------------------------------------------------------------- /scripts/test.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import gym 4 | import torch 5 | import numpy as np 6 | 
from decision_transformer.utils import evaluate_on_env, get_d4rl_normalized_score, get_d4rl_dataset_stats 7 | from decision_transformer.model import DecisionTransformer 8 | 9 | def test(args): 10 | 11 | eval_dataset = args.dataset # medium / medium-replay / medium-expert 12 | eval_rtg_scale = args.rtg_scale # normalize returns to go 13 | 14 | if args.env == 'walker2d': 15 | eval_env_name = 'Walker2d-v3' 16 | eval_rtg_target = 5000 17 | eval_env_d4rl_name = f'walker2d-{eval_dataset}-v2' 18 | 19 | elif args.env == 'halfcheetah': 20 | eval_env_name = 'HalfCheetah-v3' 21 | eval_rtg_target = 6000 22 | eval_env_d4rl_name = f'halfcheetah-{eval_dataset}-v2' 23 | 24 | elif args.env == 'hopper': 25 | eval_env_name = 'Hopper-v3' 26 | eval_rtg_target = 3600 27 | eval_env_d4rl_name = f'hopper-{eval_dataset}-v2' 28 | 29 | else: 30 | raise NotImplementedError 31 | 32 | render = args.render # render the env frames 33 | 34 | num_test_eval_ep = args.num_eval_ep # num of evaluation episodes 35 | eval_max_eval_ep_len = args.max_eval_ep_len # max len of one episode 36 | 37 | context_len = args.context_len # K in decision transformer 38 | n_blocks = args.n_blocks # num of transformer blocks 39 | embed_dim = args.embed_dim # embedding (hidden) dim of transformer 40 | n_heads = args.n_heads # num of transformer heads 41 | dropout_p = args.dropout_p # dropout probability 42 | 43 | 44 | eval_chk_pt_dir = args.chk_pt_dir 45 | 46 | eval_chk_pt_name = args.chk_pt_name 47 | eval_chk_pt_list = [eval_chk_pt_name] 48 | 49 | 50 | ## manually override check point list 51 | ## passing a list will evaluate on all checkpoints 52 | ## and output mean and std score 53 | 54 | # eval_chk_pt_list = [ 55 | # "dt_halfcheetah-medium-v2_model_22-02-09-10-38-54_best.pt", 56 | # "dt_halfcheetah-medium-v2_model_22-02-10-11-56-32_best.pt", 57 | # "dt_halfcheetah-medium-v2_model_22-02-11-10-13-57_best.pt" 58 | # ] 59 | 60 | 61 | device = torch.device(args.device) 62 | print("device set to: ", device) 63 | 64 | 
env_data_stats = get_d4rl_dataset_stats(eval_env_d4rl_name) 65 | eval_state_mean = np.array(env_data_stats['state_mean']) 66 | eval_state_std = np.array(env_data_stats['state_std']) 67 | 68 | eval_env = gym.make(eval_env_name) 69 | 70 | state_dim = eval_env.observation_space.shape[0] 71 | act_dim = eval_env.action_space.shape[0] 72 | 73 | all_scores = [] 74 | 75 | for eval_chk_pt_name in eval_chk_pt_list: 76 | 77 | eval_model = DecisionTransformer( 78 | state_dim=state_dim, 79 | act_dim=act_dim, 80 | n_blocks=n_blocks, 81 | h_dim=embed_dim, 82 | context_len=context_len, 83 | n_heads=n_heads, 84 | drop_p=dropout_p, 85 | ).to(device) 86 | 87 | eval_chk_pt_path = os.path.join(eval_chk_pt_dir, eval_chk_pt_name) 88 | 89 | # load checkpoint 90 | eval_model.load_state_dict(torch.load(eval_chk_pt_path, map_location=device)) 91 | 92 | print("model loaded from: " + eval_chk_pt_path) 93 | 94 | # evaluate on env 95 | results = evaluate_on_env(eval_model, device, context_len, 96 | eval_env, eval_rtg_target, eval_rtg_scale, 97 | num_test_eval_ep, eval_max_eval_ep_len, 98 | eval_state_mean, eval_state_std, render=render) 99 | print(results) 100 | 101 | norm_score = get_d4rl_normalized_score(results['eval/avg_reward'], eval_env_name) * 100 102 | print("normalized d4rl score: " + format(norm_score, ".5f")) 103 | 104 | all_scores.append(norm_score) 105 | 106 | print("=" * 60) 107 | all_scores = np.array(all_scores) 108 | print("evaluated on env: " + eval_env_name) 109 | print("total num of checkpoints evaluated: " + str(len(eval_chk_pt_list))) 110 | print("d4rl score mean: " + format(all_scores.mean(), ".5f")) 111 | print("d4rl score std: " + format(all_scores.std(), ".5f")) 112 | print("d4rl score var: " + format(all_scores.var(), ".5f")) 113 | print("=" * 60) 114 | 115 | 116 | if __name__ == "__main__": 117 | 118 | parser = argparse.ArgumentParser() 119 | 120 | parser.add_argument('--env', type=str, default='halfcheetah') 121 | parser.add_argument('--dataset', type=str, 
default='medium') 122 | parser.add_argument('--rtg_scale', type=int, default=1000) 123 | 124 | parser.add_argument('--max_eval_ep_len', type=int, default=1000) 125 | parser.add_argument('--num_eval_ep', type=int, default=10) 126 | 127 | parser.add_argument("--render", action="store_true", default=False) 128 | 129 | parser.add_argument('--chk_pt_dir', type=str, default='dt_runs/') 130 | parser.add_argument('--chk_pt_name', type=str, 131 | default='dt_halfcheetah-medium-v2_model_22-02-13-09-03-10_best.pt') 132 | 133 | parser.add_argument('--context_len', type=int, default=20) 134 | parser.add_argument('--n_blocks', type=int, default=3) 135 | parser.add_argument('--embed_dim', type=int, default=128) 136 | parser.add_argument('--n_heads', type=int, default=1) 137 | parser.add_argument('--dropout_p', type=float, default=0.1) 138 | 139 | parser.add_argument('--device', type=str, default='cuda') 140 | 141 | args = parser.parse_args() 142 | 143 | test(args) 144 | -------------------------------------------------------------------------------- /scripts/train.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import sys 4 | import random 5 | import csv 6 | from datetime import datetime 7 | 8 | import numpy as np 9 | import gym 10 | 11 | import torch 12 | import torch.nn.functional as F 13 | from torch.utils.data import DataLoader 14 | 15 | from decision_transformer.utils import D4RLTrajectoryDataset, evaluate_on_env, get_d4rl_normalized_score 16 | from decision_transformer.model import DecisionTransformer 17 | 18 | 19 | def train(args): 20 | 21 | dataset = args.dataset # medium / medium-replay / medium-expert 22 | rtg_scale = args.rtg_scale # normalize returns to go 23 | 24 | # use v3 env for evaluation because 25 | # Decision Transformer paper evaluates results on v3 envs 26 | 27 | if args.env == 'walker2d': 28 | env_name = 'Walker2d-v3' 29 | rtg_target = 5000 30 | env_d4rl_name = 
f'walker2d-{dataset}-v2'
31 |
32 |     elif args.env == 'halfcheetah':
33 |         env_name = 'HalfCheetah-v3'
34 |         rtg_target = 6000
35 |         env_d4rl_name = f'halfcheetah-{dataset}-v2'
36 |
37 |     elif args.env == 'hopper':
38 |         env_name = 'Hopper-v3'
39 |         rtg_target = 3600
40 |         env_d4rl_name = f'hopper-{dataset}-v2'
41 |
42 |     else:
43 |         raise NotImplementedError
44 |
45 |     max_eval_ep_len = args.max_eval_ep_len  # max len of one episode
46 |     num_eval_ep = args.num_eval_ep  # num of evaluation episodes
47 |
48 |     batch_size = args.batch_size  # training batch size
49 |     lr = args.lr  # learning rate
50 |     wt_decay = args.wt_decay  # weight decay
51 |     warmup_steps = args.warmup_steps  # warmup steps for lr scheduler
52 |
53 |     # total updates = max_train_iters x num_updates_per_iter
54 |     max_train_iters = args.max_train_iters
55 |     num_updates_per_iter = args.num_updates_per_iter
56 |
57 |     context_len = args.context_len  # K in decision transformer
58 |     n_blocks = args.n_blocks  # num of transformer blocks
59 |     embed_dim = args.embed_dim  # embedding (hidden) dim of transformer
60 |     n_heads = args.n_heads  # num of transformer heads
61 |     dropout_p = args.dropout_p  # dropout probability
62 |
63 |     # load data from this file
64 |     dataset_path = f'{args.dataset_dir}/{env_d4rl_name}.pkl'
65 |
66 |     # saves model and csv in this directory
67 |     log_dir = args.log_dir
68 |     if not os.path.exists(log_dir):
69 |         os.makedirs(log_dir)
70 |
71 |     # training and evaluation device
72 |     device = torch.device(args.device)
73 |
74 |     start_time = datetime.now().replace(microsecond=0)
75 |     start_time_str = start_time.strftime("%y-%m-%d-%H-%M-%S")
76 |
77 |     prefix = "dt_" + env_d4rl_name
78 |
79 |     save_model_name = prefix + "_model_" + start_time_str + ".pt"
80 |     save_model_path = os.path.join(log_dir, save_model_name)
81 |     save_best_model_path = save_model_path[:-3] + "_best.pt"
82 |
83 |     log_csv_name = prefix + "_log_" + start_time_str + ".csv"
84 |     log_csv_path = os.path.join(log_dir, log_csv_name)
85 |
86 |     csv_writer = csv.writer(open(log_csv_path, 'a', 1))
87 |     csv_header = (["duration", "num_updates", "action_loss",
88 |                    "eval_avg_reward", "eval_avg_ep_len", "eval_d4rl_score"])
89 |
90 |     csv_writer.writerow(csv_header)
91 |
92 |     print("=" * 60)
93 |     print("start time: " + start_time_str)
94 |     print("=" * 60)
95 |
96 |     print("device set to: " + str(device))
97 |     print("dataset path: " + dataset_path)
98 |     print("model save path: " + save_model_path)
99 |     print("log csv save path: " + log_csv_path)
100 |
101 |     traj_dataset = D4RLTrajectoryDataset(dataset_path, context_len, rtg_scale)
102 |
103 |     traj_data_loader = DataLoader(
104 |         traj_dataset,
105 |         batch_size=batch_size,
106 |         shuffle=True,
107 |         pin_memory=True,
108 |         drop_last=True
109 |     )
110 |
111 |     data_iter = iter(traj_data_loader)
112 |
113 |     ## get state stats from dataset
114 |     state_mean, state_std = traj_dataset.get_state_stats()
115 |
116 |     env = gym.make(env_name)
117 |
118 |     state_dim = env.observation_space.shape[0]
119 |     act_dim = env.action_space.shape[0]
120 |
121 |     model = DecisionTransformer(
122 |         state_dim=state_dim,
123 |         act_dim=act_dim,
124 |         n_blocks=n_blocks,
125 |         h_dim=embed_dim,
126 |         context_len=context_len,
127 |         n_heads=n_heads,
128 |         drop_p=dropout_p,
129 |     ).to(device)
130 |
131 |     optimizer = torch.optim.AdamW(
132 |         model.parameters(),
133 |         lr=lr,
134 |         weight_decay=wt_decay
135 |     )
136 |
137 |     scheduler = torch.optim.lr_scheduler.LambdaLR(
138 |         optimizer,
139 |         lambda steps: min((steps+1)/warmup_steps, 1)
140 |     )
141 |
142 |     max_d4rl_score = -1.0
143 |     total_updates = 0
144 |
145 |     for i_train_iter in range(max_train_iters):
146 |
147 |         log_action_losses = []
148 |         model.train()
149 |
150 |         for _ in range(num_updates_per_iter):
151 |             try:
152 |                 timesteps, states, actions, returns_to_go, traj_mask = next(data_iter)
153 |             except StopIteration:
154 |                 data_iter = iter(traj_data_loader)
155 |                 timesteps, states, actions, returns_to_go, traj_mask = next(data_iter)
156 |
157 |             timesteps = timesteps.to(device)  # B x T
158 |             states = states.to(device)  # B x T x state_dim
159 |             actions = actions.to(device)  # B x T x act_dim
160 |             returns_to_go = returns_to_go.to(device).unsqueeze(dim=-1)  # B x T x 1
161 |             traj_mask = traj_mask.to(device)  # B x T
162 |             action_target = torch.clone(actions).detach().to(device)
163 |
164 |             state_preds, action_preds, return_preds = model.forward(
165 |                 timesteps=timesteps,
166 |                 states=states,
167 |                 actions=actions,
168 |                 returns_to_go=returns_to_go
169 |             )
170 |             # only consider non padded elements
171 |             action_preds = action_preds.view(-1, act_dim)[traj_mask.view(-1,) > 0]
172 |             action_target = action_target.view(-1, act_dim)[traj_mask.view(-1,) > 0]
173 |
174 |             action_loss = F.mse_loss(action_preds, action_target, reduction='mean')
175 |
176 |             optimizer.zero_grad()
177 |             action_loss.backward()
178 |             torch.nn.utils.clip_grad_norm_(model.parameters(), 0.25)
179 |             optimizer.step()
180 |             scheduler.step()
181 |
182 |             log_action_losses.append(action_loss.detach().cpu().item())
183 |
184 |         # evaluate action accuracy
185 |         results = evaluate_on_env(model, device, context_len, env, rtg_target, rtg_scale,
186 |                                   num_eval_ep, max_eval_ep_len, state_mean, state_std)
187 |
188 |         eval_avg_reward = results['eval/avg_reward']
189 |         eval_avg_ep_len = results['eval/avg_ep_len']
190 |         eval_d4rl_score = get_d4rl_normalized_score(results['eval/avg_reward'], env_name) * 100
191 |
192 |         mean_action_loss = np.mean(log_action_losses)
193 |         time_elapsed = str(datetime.now().replace(microsecond=0) - start_time)
194 |
195 |         total_updates += num_updates_per_iter
196 |
197 |         log_str = ("=" * 60 + '\n' +
198 |                    "time elapsed: " + time_elapsed + '\n' +
199 |                    "num of updates: " + str(total_updates) + '\n' +
200 |                    "action loss: " + format(mean_action_loss, ".5f") + '\n' +
201 |                    "eval avg reward: " + format(eval_avg_reward, ".5f") + '\n' +
202 |                    "eval avg ep len: " + format(eval_avg_ep_len, ".5f") + '\n' +
203 |                    "eval d4rl score: " + format(eval_d4rl_score, ".5f")
204 |                    )
205 |
206 |         print(log_str)
207 |
208 |         log_data = [time_elapsed, total_updates, mean_action_loss,
209 |                     eval_avg_reward, eval_avg_ep_len,
210 |                     eval_d4rl_score]
211 |
212 |         csv_writer.writerow(log_data)
213 |
214 |         # save model
215 |         print("max d4rl score: " + format(max_d4rl_score, ".5f"))
216 |         if eval_d4rl_score >= max_d4rl_score:
217 |             print("saving max d4rl score model at: " + save_best_model_path)
218 |             torch.save(model.state_dict(), save_best_model_path)
219 |             max_d4rl_score = eval_d4rl_score
220 |
221 |         print("saving current model at: " + save_model_path)
222 |         torch.save(model.state_dict(), save_model_path)
223 |
224 |
225 |     print("=" * 60)
226 |     print("finished training!")
227 |     print("=" * 60)
228 |     end_time = datetime.now().replace(microsecond=0)
229 |     time_elapsed = str(end_time - start_time)
230 |     end_time_str = end_time.strftime("%y-%m-%d-%H-%M-%S")
231 |     print("started training at: " + start_time_str)
232 |     print("finished training at: " + end_time_str)
233 |     print("total training time: " + time_elapsed)
234 |     print("max d4rl score: " + format(max_d4rl_score, ".5f"))
235 |     print("saved max d4rl score model at: " + save_best_model_path)
236 |     print("saved last updated model at: " + save_model_path)
237 |     print("=" * 60)
238 |
239 |
240 |
241 | if __name__ == "__main__":
242 |
243 |     parser = argparse.ArgumentParser()
244 |
245 |     parser.add_argument('--env', type=str, default='halfcheetah')
246 |     parser.add_argument('--dataset', type=str, default='medium')
247 |     parser.add_argument('--rtg_scale', type=int, default=1000)
248 |
249 |     parser.add_argument('--max_eval_ep_len', type=int, default=1000)
250 |     parser.add_argument('--num_eval_ep', type=int, default=10)
251 |
252 |     parser.add_argument('--dataset_dir', type=str, default='data/')
253 |     parser.add_argument('--log_dir', type=str, default='dt_runs/')
254 |
255 |     parser.add_argument('--context_len', type=int, default=20)
256 |     parser.add_argument('--n_blocks', type=int, default=3)
257 |     parser.add_argument('--embed_dim', type=int, default=128)
258 |     parser.add_argument('--n_heads', type=int, default=1)
259 |     parser.add_argument('--dropout_p', type=float, default=0.1)
260 |
261 |     parser.add_argument('--batch_size', type=int, default=64)
262 |     parser.add_argument('--lr', type=float, default=1e-4)
263 |     parser.add_argument('--wt_decay', type=float, default=1e-4)
264 |     parser.add_argument('--warmup_steps', type=int, default=10000)
265 |
266 |     parser.add_argument('--max_train_iters', type=int, default=200)
267 |     parser.add_argument('--num_updates_per_iter', type=int, default=100)
268 |
269 |     parser.add_argument('--device', type=str, default='cuda')
270 |
271 |     args = parser.parse_args()
272 |
273 |     train(args)
274 |
--------------------------------------------------------------------------------
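Two details of the training loop in `scripts/train.py` are easy to misread: the multiplier returned by the `LambdaLR` warmup lambda, and the padding mask applied before the MSE action loss. The following is a minimal plain-Python sketch (no torch, so the arithmetic can be checked by hand) that restates both. `warmup_factor` and `masked_mse` are hypothetical helper names, not functions in this repo, and the nested lists only stand in for the `B x T x act_dim` tensors.

```python
from typing import List, Sequence

# Hypothetical helpers restating two pieces of scripts/train.py in plain
# Python; they are illustrations, not functions from this repo.


def warmup_factor(step: int, warmup_steps: int) -> float:
    """Multiplier LambdaLR applies to the base lr after `step` calls to
    scheduler.step(); same formula as the lambda passed in train.py."""
    return min((step + 1) / warmup_steps, 1.0)


def masked_mse(
    preds: Sequence[Sequence[Sequence[float]]],
    targets: Sequence[Sequence[Sequence[float]]],
    mask: Sequence[Sequence[int]],
) -> float:
    """Mean squared error over the action vectors whose mask entry is > 0.

    preds and targets are B x T x act_dim nested lists and mask is B x T.
    Like view(-1, act_dim)[mask.view(-1) > 0] followed by
    F.mse_loss(..., reduction='mean'), the mean is taken over every
    action component of the kept timesteps.
    """
    sq_errors: List[float] = []
    for pred_seq, target_seq, mask_seq in zip(preds, targets, mask):
        for pred, target, m in zip(pred_seq, target_seq, mask_seq):
            if m > 0:  # padded timesteps (mask == 0) contribute nothing
                sq_errors.extend((p - y) ** 2 for p, y in zip(pred, target))
    if not sq_errors:
        return float("nan")  # no unpadded timesteps: the mean is undefined
    return sum(sq_errors) / len(sq_errors)


if __name__ == "__main__":
    # Default flags: lr=1e-4, warmup_steps=10000, 200 iters x 100 updates.
    base_lr, warmup_steps = 1e-4, 10000
    total_updates = 200 * 100
    for step in (0, 4999, 9999, total_updates - 1):
        print(step, base_lr * warmup_factor(step, warmup_steps))

    # One trajectory, two timesteps, act_dim = 2; the second step is padding.
    preds = [[[1.0, 2.0], [0.0, 0.0]]]
    targets = [[[0.0, 2.0], [5.0, 5.0]]]
    mask = [[1, 0]]
    print(masked_mse(preds, targets, mask))  # 0.5: the padded step is ignored
```

With the default flags (`--max_train_iters 200`, `--num_updates_per_iter 100`, `--warmup_steps 10000`), `scheduler.step()` runs once per update. The learning rate therefore ramps linearly over the first 10,000 of the 20,000 total updates and stays at `--lr` for the rest.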