├── .gitignore ├── Readme.md ├── config.py ├── environment.yml ├── main.py ├── models.py ├── pack_model.py ├── pack_step.py ├── plot.py ├── plot3d.py ├── problems ├── __init__.py ├── pack2d │ ├── __init__.py │ ├── pack2d.py │ ├── render.py │ ├── resources │ │ └── MONOFONT.TTF │ ├── state_pack2d.py │ └── viewer.py └── pack3d │ ├── __init__.py │ ├── load_br.py │ ├── pack3d.py │ ├── render.py │ ├── state_pack3d.py │ └── viewer.py ├── resources └── MONOFONT.TTF ├── trainer.py └── utils ├── __init__.py ├── functions.py ├── log_graph.py ├── logger.py ├── math_util.py ├── seeding.py ├── statistic.py ├── truncated_normal.py └── utils_fb.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | 53 | # Translations 54 | *.mo 55 | *.pot 56 | 57 | # Django stuff: 58 | *.log 59 | local_settings.py 60 | db.sqlite3 61 | 62 | # Flask stuff: 63 | instance/ 64 | .webassets-cache 65 | 66 | # Scrapy stuff: 67 | .scrapy 68 | 69 | # Sphinx documentation 70 | docs/_build/ 71 | 72 | # PyBuilder 73 | target/ 74 | 75 | # Jupyter Notebook 76 | .ipynb_checkpoints 77 | 78 | # IPython 79 | profile_default/ 80 | ipython_config.py 81 | 82 | # pyenv 83 | .python-version 84 | 85 | # pipenv 86 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 87 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 88 | # having no cross-platform support, pipenv may install dependencies that don’t work, or not 89 | # install all needed dependencies. 90 | #Pipfile.lock 91 | 92 | # celery beat schedule file 93 | celerybeat-schedule 94 | 95 | # SageMath parsed files 96 | *.sage.py 97 | 98 | # Environments 99 | .env 100 | .venv 101 | env/ 102 | venv/ 103 | ENV/ 104 | env.bak/ 105 | venv.bak/ 106 | 107 | # Spyder project settings 108 | .spyderproject 109 | .spyproject 110 | 111 | # Rope project settings 112 | .ropeproject 113 | 114 | # mkdocs documentation 115 | /site 116 | 117 | # mypy 118 | .mypy_cache/ 119 | .dmypy.json 120 | dmypy.json 121 | 122 | # Pyre type checker 123 | .pyre/ 124 | -------------------------------------------------------------------------------- /Readme.md: -------------------------------------------------------------------------------- 1 | # Recurrent Conditional Query Learning (RCQL) 2 | This repository contains the Pytorch implementation of 3 | 4 | [One Model Packs Thousands of Items with Recurrent Conditional Query Learning](https://www.sciencedirect.com/science/article/pii/S095070512100945X) 5 | 6 | Dongda Li, Zhaoquan Gu, Yuexuan Wang, Changwei Ren, Francis C.M. Lau 7 | 8 | We propose a Recurrent Conditional Query Learning (RCQL) method to solve both 2D and 3D packing problems. We first embed states by a recurrent encoder, and then adopt attention with conditional queries from previous actions. The conditional query mechanism fills the information gap between learning steps, which shapes the problem as a Markov decision process. Benefiting from the recurrence, a single RCQL model is capable of handling different sizes of packing problems. Experiment results show that RCQL can effectively learn strong heuristics for offline and online strip packing problems (SPPs), out- performing a wide range of baselines in space utilization ratio. RCQL reduces the average bin gap ratio by 1.83% in offline 2D 40-box cases and 7.84% in 3D cases compared with state-of-the-art methods. Meanwhile, our method also achieves 5.64% higher space utilization ratio for SPPs with 1000 items than the state of the art. 9 | 10 | ## Usage 11 | 12 | ### Preparation 13 | 14 | 1. Install conda 15 | 2. Run `conda env create -f environment.yml` 16 | 17 | ### Train 18 | 19 | 1. Modify the config file in `config.py` as you need. 20 | 2. Run `python main.py`. 21 | -------------------------------------------------------------------------------- /config.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding: utf-8 3 | 4 | 5 | 6 | #!/usr/bin/env python3 7 | 8 | # command-line arguments with their default values 9 | 10 | PARAMS_CONFIG = { 11 | # env-specific 12 | 'env_params': { 13 | '--run-name': { 14 | 'type': str, 15 | 'default': 'run', 16 | 'help': 'run name', 17 | 'dest': 'run_name' 18 | }, 19 | '--output-dir': { 20 | 'type': str, 21 | 'default': 'outputs', 22 | 'help': 'output dir', 23 | 'dest': 'output_dir' 24 | }, 25 | }, 26 | 'rl_params': { 27 | '--ent-coef': { 28 | 'type': float, 29 | 'default': 3e-3, 30 | 'help': 'entropy coefficient', 31 | 'dest': 'ent_coef' 32 | }, 33 | '--soft-temp': { 34 | 'type': float, 35 | 'default': 5e-3, 36 | 'help': 'soft temperature of entropy regularization', 37 | 'dest': 'soft_temp' 38 | }, 39 | '--gamma': { 40 | 'type': float, 41 | 'default': 0.96, 42 | 'help': 'reward discount factor', 43 | 'dest': 'gamma' 44 | }, 45 | '--nsteps': { 46 | 'type': int, 47 | 'default': 10, 48 | 'help': 'GAE rolling out steps', 49 | 'dest': 'nsteps' 50 | }, 51 | '--lam': { 52 | 'type': float, 53 | 'default': 0.98, 54 | 'help': 'lam for General Advantage Estimation', 55 | 'dest': 'lam' 56 | }, 57 | '--target-entropy': { 58 | 'type': float, 59 | 'default': -0.6, 60 | 'help': 'position target entropy for entropy regularization', 61 | 'dest': 'tgt_entropy' 62 | }, 63 | }, 64 | # model-specific 65 | 'model_params': { 66 | '--hid-sz': { 67 | 'type': int, 68 | 'default': 128, 69 | 'help': 'hidden size (i.e. model size)', 70 | 'dest': 'hidden_size' 71 | }, 72 | '--inner-hid-sz': { 73 | 'type': int, 74 | 'default': 512, 75 | 'help': 'inner hidden size of FF layer', 76 | 'dest': 'inner_hidden_size' 77 | }, 78 | '--encoder-layers': { 79 | 'type': int, 80 | 'default': 3, 81 | 'help': 'number of layers', 82 | 'dest': 'encoder_layers' 83 | }, 84 | '--decoder-layers': { 85 | 'type': int, 86 | 'default': 1, 87 | 'help': 'number of layers', 88 | 'dest': 'decoder_layers' 89 | }, 90 | '--critic-encoder-layers': { 91 | 'type': int, 92 | 'default': 3, 93 | 'help': 'number of layers', 94 | 'dest': 'c_encoder_layers' 95 | }, 96 | '--critic-decoder-layers': { 97 | 'type': int, 98 | 'default': 1, 99 | 'help': 'number of layers', 100 | 'dest': 'c_decoder_layers' 101 | }, 102 | '--block-sz': { 103 | 'type': int, 104 | 'default': 20, 105 | 'help': 'block size ' 106 | '(the length of sequence to process in parallel)', 107 | 'dest': 'block_size' 108 | }, 109 | '--nheads': { 110 | 'type': int, 111 | 'default': 8, 112 | 'help': 'number of self-attention heads', 113 | 'dest': 'nb_heads' 114 | }, 115 | '--attn-span': { 116 | 'type': int, 117 | 'default': 20, 118 | 'help': 'length of the attention span', 119 | 'dest': 'attn_span' 120 | }, 121 | '--dropout': { 122 | 'type': float, 123 | 'default': 0.0, 124 | 'help': 'dropout rate of ReLU and attention', 125 | 'dest': 'dropout' 126 | }, 127 | '--normalization': { 128 | 'type': str, 129 | 'default': 'batch', 130 | 'help': 'Normalization type,' 131 | 'batch or instance or layer', 132 | 'dest': 'normalization' 133 | }, 134 | '--head-hid': { 135 | 'type': int, 136 | 'default': 128, 137 | 'help': 'head hidden dim(select and rotate)', 138 | 'dest': 'head_hidden' 139 | }, 140 | '--head-hid-pos': { 141 | 'type': int, 142 | 'default': 512, 143 | 'help': 'position head hidden dim', 144 | 'dest': 'head_hidden_pos' 145 | }, 146 | }, 147 | # problem 148 | 'problem_params':{ 149 | '--problem-type': { 150 | 'type': str, 151 | 'default': 'pack3d', 152 | 'help': 'problem type (2d or 3d)', 153 | 'dest': 'problem_type' 154 | }, 155 | '--online': { 156 | 'action': 'store_true', 157 | 'default': False, 158 | 'help': 'on-line packing', 159 | 'dest': 'on_line' 160 | }, 161 | '--noquery': { 162 | 'type': bool, 163 | 'default': False, 164 | 'help': 'no query model', 165 | 'dest': 'no_query' 166 | }, 167 | '--block-num': { 168 | 'type': int, 169 | 'default': 10, 170 | 'help': 'number of boxes in each instance', 171 | 'dest': 'block_num' 172 | }, 173 | '--position-options': { 174 | 'type': int, 175 | 'default': 128, 176 | 'help': 'position options', 177 | 'dest': 'p_options' 178 | }, 179 | '--size-p1': { 180 | 'type': float, 181 | # 'default': 0.02, # 2d 182 | 'default': 0.2, # 3d 183 | 'help': 'box mean for normal distribution sampling' 184 | 'box size high bound for uniform distribution sampling', 185 | 'dest': 'size_p1' 186 | }, 187 | '--size-p2': { 188 | 'type': float, 189 | # 'default': 0.4, # 2d 190 | 'default': 0.8, # 3d 191 | 'help': 'box variance for normal distribution sampling' 192 | 'box size high bound for uniform distribution sampling', 193 | 'dest': 'size_p2' 194 | }, 195 | '--data-distribution': { 196 | 'type': str, 197 | 'default': 'uniform', 198 | 'help': 'Data distribution to use during training', 199 | 'dest': 'distribution' 200 | }, 201 | }, 202 | # optimization-specific 203 | 'optim_params': { 204 | '--actor-lr': { 205 | 'type': float, 206 | 'default': 4e-5, 207 | 'help': 'actor learning rate', 208 | 'dest': 'actor_lr' 209 | }, 210 | '--critic-lr': { 211 | 'type': float, 212 | 'default': 1e-4, 213 | 'help': 'critic learning rate', 214 | 'dest': 'critic_lr' 215 | }, 216 | '--lr-warmup': { 217 | 'type': int, 218 | 'default': 100, 219 | 'help': 'linearly increase LR from 0 ' 220 | 'during first lr_warmup updates' 221 | 'warmup_epochs=lr_warmup/(block_size/nsteps)', 222 | 'dest': 'lr_warmup' 223 | }, 224 | '--grad-clip': { 225 | 'type': float, 226 | 'default': 5, 227 | 'help': 'clip gradient of each module parameters by a given ' 228 | 'value', 229 | 'dest': 'grad_clip' 230 | }, 231 | }, 232 | # trainer-specific 233 | 'trainer_params': { 234 | '--batch-sz': { 235 | 'type': int, 236 | 'default': 128, 237 | 'help': 'batch size', 238 | 'dest': 'batch_size' 239 | }, 240 | '--niter': { 241 | 'type': int, 242 | 'default': 100000, 243 | 'help': 'number of iterations to train', 244 | 'dest': 'nb_iter' 245 | }, 246 | '--log-interval': { 247 | 'type': int, 248 | 'default': 5, 249 | 'help': 'number of epoch per command-line print log', 250 | 'dest': 'log_interval' 251 | }, 252 | '--checkpoint-interval': { 253 | 'type': int, 254 | 'default': 200, 255 | 'help': 'number of epoch per checkpoint', 256 | 'dest': 'checkpoint_interval' 257 | }, 258 | '--no-tensorboard': { 259 | 'action': 'store_true', 260 | 'default': False, 261 | 'help': 'disable tensorboard.', 262 | 'dest': 'no_tensorboard' 263 | }, 264 | '--checkpoint': { 265 | 'type': str, 266 | 'default': '', 267 | 'help': 'path to save/load model', 268 | 'dest': 'checkpoint_path' 269 | }, 270 | '--full-eval-mode': { 271 | 'action': 'store_true', 272 | 'default': False, 273 | 'help': 'do evaluation on the whole validation and the test data', 274 | 'dest': 'full_eval_mode' 275 | }, 276 | }, 277 | # adaptive attention span specific params 278 | 'adapt_span_params': { 279 | '--adapt-span': { 280 | 'action': 'store_true', 281 | 'default': False, 282 | 'help': 'enable adaptive attention span', 283 | 'dest': 'adapt_span_enabled' 284 | }, 285 | '--adapt-span-loss': { 286 | 'type': float, 287 | 'default': 0, 288 | 'help': 'the loss coefficient for span lengths', 289 | 'dest': 'adapt_span_loss' 290 | }, 291 | '--adapt-span-ramp': { 292 | 'type': int, 293 | 'default': 16, 294 | 'help': 'ramp length of the soft masking function', 295 | 'dest': 'adapt_span_ramp' 296 | }, 297 | '--adapt-span-init': { 298 | 'type': float, 299 | 'default': 0, 300 | 'help': 'initial attention span ratio', 301 | 'dest': 'adapt_span_init' 302 | }, 303 | '--adapt-span-cache': { 304 | 'action': 'store_true', 305 | 'default': False, 306 | 'help': 'adapt cache size as well to reduce memory usage', 307 | 'dest': 'adapt_span_cache' 308 | }, 309 | }, 310 | } 311 | 312 | -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: base 2 | channels: 3 | - pytorch 4 | - conda-forge 5 | - defaults 6 | dependencies: 7 | - _ipyw_jlab_nb_ext_conf=0.1.0=py38_0 8 | - _libgcc_mutex=0.1=main 9 | - alabaster=0.7.12=py_0 10 | - anaconda=2020.07=py38_0 11 | - anaconda-clean=1.1.0=py38_1 12 | - anaconda-client=1.7.2=py38_0 13 | - anaconda-navigator=1.9.12=py38_0 14 | - anaconda-project=0.8.4=py_0 15 | - argh=0.26.2=py38_0 16 | - asn1crypto=1.3.0=py38_0 17 | - astroid=2.4.2=py38_0 18 | - astropy=4.0.1.post1=py38h7b6447c_1 19 | - atomicwrites=1.4.0=py_0 20 | - attrs=19.3.0=py_0 21 | - automat=20.2.0=py_0 22 | - autopep8=1.5.3=py_0 23 | - babel=2.8.0=py_0 24 | - backcall=0.2.0=py_0 25 | - backports=1.0=py_2 26 | - backports.functools_lru_cache=1.6.1=py_0 27 | - backports.shutil_get_terminal_size=1.0.0=py38_2 28 | - backports.tempfile=1.0=py_1 29 | - backports.weakref=1.0.post1=py_1 30 | - bcrypt=3.2.0=py38h7b6447c_0 31 | - beautifulsoup4=4.9.1=py38_0 32 | - bitarray=1.4.0=py38h7b6447c_0 33 | - bkcharts=0.2=py38_0 34 | - blas=1.0=mkl 35 | - bleach=3.1.5=py_0 36 | - blosc=1.19.0=hd408876_0 37 | - bokeh=2.1.1=py38_0 38 | - boto=2.49.0=py38_0 39 | - bottleneck=1.3.2=py38heb32a55_1 40 | - brotlipy=0.7.0=py38h7b6447c_1000 41 | - bzip2=1.0.8=h7b6447c_0 42 | - ca-certificates=2020.6.24=0 43 | - cairo=1.14.12=h8948797_3 44 | - certifi=2020.6.20=py38_0 45 | - cffi=1.14.0=py38he30daa8_1 46 | - chardet=3.0.4=py38_1003 47 | - click=7.1.2=py_0 48 | - cloudpickle=1.5.0=py_0 49 | - clyent=1.2.2=py38_1 50 | - colorama=0.4.3=py_0 51 | - conda=4.9.2=py38h06a4308_0 52 | - conda-build=3.18.11=py38_0 53 | - conda-env=2.6.0=1 54 | - conda-package-handling=1.6.1=py38h7b6447c_0 55 | - conda-verify=3.4.2=py_1 56 | - constantly=15.1.0=py_0 57 | - contextlib2=0.6.0.post1=py_0 58 | - cryptography=2.9.2=py38h1ba5d50_0 59 | - cudatoolkit=10.2.89=hfd86e86_1 60 | - curl=7.71.1=hbc83047_1 61 | - cycler=0.10.0=py38_0 62 | - cython=0.29.21=py38he6710b0_0 63 | - cytoolz=0.10.1=py38h7b6447c_0 64 | - dask=2.20.0=py_0 65 | - dask-core=2.20.0=py_0 66 | - dbus=1.13.16=hb2f20db_0 67 | - decorator=4.4.2=py_0 68 | - defusedxml=0.6.0=py_0 69 | - diff-match-patch=20200713=py_0 70 | - distributed=2.20.0=py38_0 71 | - docutils=0.16=py38_1 72 | - entrypoints=0.3=py38_0 73 | - et_xmlfile=1.0.1=py_1001 74 | - expat=2.2.9=he6710b0_2 75 | - fastcache=1.1.0=py38h7b6447c_0 76 | - ffmpeg=4.2.2=h20bf706_0 77 | - filelock=3.0.12=py_0 78 | - flake8=3.8.3=py_0 79 | - flask=1.1.2=py_0 80 | - fontconfig=2.13.0=h9420a91_0 81 | - freetype=2.10.2=h5ab3b9f_0 82 | - fribidi=1.0.9=h7b6447c_0 83 | - fsspec=0.7.4=py_0 84 | - future=0.18.2=py38_1 85 | - get_terminal_size=1.0.0=haa9412d_0 86 | - gevent=20.6.2=py38h7b6447c_0 87 | - glib=2.65.0=h3eb4bd4_0 88 | - glob2=0.7=py_0 89 | - gmp=6.1.2=h6c8ec71_1 90 | - gmpy2=2.0.8=py38hd5f6e3b_3 91 | - gnutls=3.6.5=h71b1129_1002 92 | - graphite2=1.3.14=h23475e2_0 93 | - greenlet=0.4.16=py38h7b6447c_0 94 | - gst-plugins-base=1.14.0=hbbd80ab_1 95 | - gstreamer=1.14.0=hb31296c_0 96 | - h5py=2.10.0=py38h7918eee_0 97 | - harfbuzz=2.4.0=hca77d97_1 98 | - hdf5=1.10.4=hb1b8bf9_0 99 | - heapdict=1.0.1=py_0 100 | - html5lib=1.1=py_0 101 | - hyperlink=20.0.1=py_0 102 | - icu=58.2=he6710b0_3 103 | - idna=2.10=py_0 104 | - imageio=2.9.0=py_0 105 | - imagesize=1.2.0=py_0 106 | - importlib-metadata=1.7.0=py38_0 107 | - importlib_metadata=1.7.0=0 108 | - incremental=17.5.0=py38_0 109 | - intel-openmp=2020.1=217 110 | - intervaltree=3.0.2=py_1 111 | - ipykernel=5.3.2=py38h5ca1d4c_0 112 | - ipython=7.16.1=py38h5ca1d4c_0 113 | - ipython_genutils=0.2.0=py38_0 114 | - ipywidgets=7.5.1=py_0 115 | - isort=4.3.21=py38_0 116 | - itsdangerous=1.1.0=py_0 117 | - jbig=2.1=hdba287a_0 118 | - jdcal=1.4.1=py_0 119 | - jedi=0.17.1=py38_0 120 | - jeepney=0.4.3=py_0 121 | - jinja2=2.11.2=py_0 122 | - joblib=0.16.0=py_0 123 | - jpeg=9b=h024ee3a_2 124 | - json5=0.9.5=py_0 125 | - jsonschema=3.2.0=py38_0 126 | - jupyter=1.0.0=py38_7 127 | - jupyter_client=6.1.6=py_0 128 | - jupyter_console=6.1.0=py_0 129 | - jupyter_core=4.6.3=py38_0 130 | - jupyterlab=2.1.5=py_0 131 | - jupyterlab_server=1.2.0=py_0 132 | - keyring=21.2.1=py38_0 133 | - kiwisolver=1.2.0=py38hfd86e86_0 134 | - krb5=1.18.2=h173b8e3_0 135 | - lame=3.100=h7b6447c_0 136 | - lazy-object-proxy=1.4.3=py38h7b6447c_0 137 | - lcms2=2.11=h396b838_0 138 | - ld_impl_linux-64=2.33.1=h53a641e_7 139 | - libarchive=3.4.2=h62408e4_0 140 | - libcurl=7.71.1=h20c2e04_1 141 | - libedit=3.1.20191231=h14c3975_1 142 | - libffi=3.3=he6710b0_2 143 | - libgcc-ng=9.1.0=hdf63c60_0 144 | - libgfortran-ng=7.3.0=hdf63c60_0 145 | - liblief=0.10.1=he6710b0_0 146 | - libllvm9=9.0.1=h4a3c616_1 147 | - libopus=1.3.1=h7b6447c_0 148 | - libpng=1.6.37=hbc83047_0 149 | - libprotobuf=3.13.0=h8b12597_0 150 | - libsodium=1.0.18=h7b6447c_0 151 | - libspatialindex=1.9.3=he6710b0_0 152 | - libssh2=1.9.0=h1ba5d50_1 153 | - libstdcxx-ng=9.1.0=hdf63c60_0 154 | - libtiff=4.1.0=h2733197_1 155 | - libtool=2.4.6=h7b6447c_5 156 | - libuuid=1.0.3=h1bed415_2 157 | - libuv=1.40.0=h7b6447c_0 158 | - libvpx=1.7.0=h439df22_0 159 | - libxcb=1.14=h7b6447c_0 160 | - libxml2=2.9.10=he19cac6_1 161 | - libxslt=1.1.34=hc22bd24_0 162 | - llvmlite=0.33.0=py38hc6ec683_1 163 | - locket=0.2.0=py38_1 164 | - lxml=4.5.2=py38hefd8a0e_0 165 | - lz4-c=1.9.2=he6710b0_0 166 | - lzo=2.10=h7b6447c_2 167 | - markupsafe=1.1.1=py38h7b6447c_0 168 | - matplotlib=3.2.2=0 169 | - matplotlib-base=3.2.2=py38hef1b27d_0 170 | - mccabe=0.6.1=py38_1 171 | - mistune=0.8.4=py38h7b6447c_1000 172 | - mkl=2020.1=217 173 | - mkl-service=2.3.0=py38he904b0f_0 174 | - mkl_fft=1.1.0=py38h23d657b_0 175 | - mkl_random=1.1.1=py38h0573a6f_0 176 | - mock=4.0.2=py_0 177 | - more-itertools=8.4.0=py_0 178 | - mpc=1.1.0=h10f8cd9_1 179 | - mpfr=4.0.2=hb69a4c5_1 180 | - mpmath=1.1.0=py38_0 181 | - msgpack-python=1.0.0=py38hfd86e86_1 182 | - multipledispatch=0.6.0=py38_0 183 | - navigator-updater=0.2.1=py38_0 184 | - nbconvert=5.6.1=py38_0 185 | - nbformat=5.0.7=py_0 186 | - nccl=2.8.3.1=h1a5f58c_0 187 | - ncurses=6.2=he6710b0_1 188 | - nettle=3.4.1=hbb512f6_0 189 | - networkx=2.4=py_1 190 | - ninja=1.7.2=0 191 | - nltk=3.5=py_0 192 | - nose=1.3.7=py38_2 193 | - notebook=6.0.3=py38_0 194 | - numba=0.50.1=py38h0573a6f_1 195 | - numexpr=2.7.1=py38h423224d_0 196 | - numpy=1.18.5=py38ha1c710e_0 197 | - numpy-base=1.18.5=py38hde5b4d6_0 198 | - numpydoc=1.1.0=py_0 199 | - olefile=0.46=py_0 200 | - openh264=2.1.0=hd408876_0 201 | - openpyxl=3.0.4=py_0 202 | - openssl=1.1.1g=h7b6447c_0 203 | - packaging=20.4=py_0 204 | - pandas=1.0.5=py38h0573a6f_0 205 | - pandoc=2.10=0 206 | - pandocfilters=1.4.2=py38_1 207 | - pango=1.45.3=hd140c19_0 208 | - parso=0.7.0=py_0 209 | - partd=1.1.0=py_0 210 | - patchelf=0.11=he6710b0_0 211 | - path=13.1.0=py38_0 212 | - path.py=12.4.0=0 213 | - pathlib2=2.3.5=py38_0 214 | - pathtools=0.1.2=py_1 215 | - patsy=0.5.1=py38_0 216 | - pcre=8.44=he6710b0_0 217 | - pep8=1.7.1=py38_0 218 | - pexpect=4.8.0=py38_0 219 | - pickleshare=0.7.5=py38_1000 220 | - pillow=7.2.0=py38hb39fc2d_0 221 | - pip=20.1.1=py38_1 222 | - pixman=0.40.0=h7b6447c_0 223 | - pkginfo=1.5.0.1=py38_0 224 | - pluggy=0.13.1=py38_0 225 | - ply=3.11=py38_0 226 | - prometheus_client=0.8.0=py_0 227 | - prompt-toolkit=3.0.5=py_0 228 | - prompt_toolkit=3.0.5=0 229 | - protobuf=3.13.0=py38h950e882_0 230 | - psutil=5.7.0=py38h7b6447c_0 231 | - ptyprocess=0.6.0=py38_0 232 | - py=1.9.0=py_0 233 | - py-lief=0.10.1=py38h403a769_0 234 | - pyasn1=0.4.8=py_0 235 | - pycodestyle=2.6.0=py_0 236 | - pycosat=0.6.3=py38h7b6447c_1 237 | - pycparser=2.20=py_2 238 | - pycurl=7.43.0.5=py38h1ba5d50_0 239 | - pydocstyle=5.0.2=py_0 240 | - pyflakes=2.2.0=py_0 241 | - pygments=2.6.1=py_0 242 | - pyhamcrest=2.0.2=py_0 243 | - pylint=2.5.3=py38_0 244 | - pyodbc=4.0.30=py38he6710b0_0 245 | - pyopenssl=19.1.0=py_1 246 | - pyparsing=2.4.7=py_0 247 | - pyqt=5.9.2=py38h05f1152_4 248 | - pyrsistent=0.16.0=py38h7b6447c_0 249 | - pysocks=1.7.1=py38_0 250 | - pytables=3.6.1=py38h9fd0a39_0 251 | - pytest=5.4.3=py38_0 252 | - python=3.8.3=hcff3b4d_2 253 | - python-dateutil=2.8.1=py_0 254 | - python-jsonrpc-server=0.3.4=py_1 255 | - python-language-server=0.34.1=py38_0 256 | - python-libarchive-c=2.9=py_0 257 | - python_abi=3.8=1_cp38 258 | - pytorch=1.7.0=py3.8_cuda10.2.89_cudnn7.6.5_0 259 | - pytz=2020.1=py_0 260 | - pywavelets=1.1.1=py38h7b6447c_0 261 | - pyxdg=0.26=py_0 262 | - pyyaml=5.3.1=py38h7b6447c_1 263 | - pyzmq=19.0.1=py38he6710b0_1 264 | - qdarkstyle=2.8.1=py_0 265 | - qt=5.9.7=h5867ecd_1 266 | - qtawesome=0.7.2=py_0 267 | - qtconsole=4.7.5=py_0 268 | - qtpy=1.9.0=py_0 269 | - readline=8.0=h7b6447c_0 270 | - regex=2020.6.8=py38h7b6447c_0 271 | - requests=2.24.0=py_0 272 | - ripgrep=11.0.2=he32d670_0 273 | - rope=0.17.0=py_0 274 | - rtree=0.9.4=py38_1 275 | - ruamel_yaml=0.15.87=py38h7b6447c_1 276 | - scikit-image=0.16.2=py38h0573a6f_0 277 | - scikit-learn=0.23.1=py38h423224d_0 278 | - scipy=1.5.0=py38h0b6359f_0 279 | - seaborn=0.10.1=py_0 280 | - secretstorage=3.1.2=py38_0 281 | - send2trash=1.5.0=py38_0 282 | - service_identity=18.1.0=py_0 283 | - setuptools=49.2.0=py38_0 284 | - simplegeneric=0.8.1=py38_2 285 | - singledispatch=3.4.0.3=py38_0 286 | - sip=4.19.13=py38he6710b0_0 287 | - six=1.15.0=py_0 288 | - snappy=1.1.8=he6710b0_0 289 | - snowballstemmer=2.0.0=py_0 290 | - sortedcollections=1.2.1=py_0 291 | - sortedcontainers=2.2.2=py_0 292 | - soupsieve=2.0.1=py_0 293 | - sphinx=3.1.2=py_0 294 | - sphinxcontrib=1.0=py38_1 295 | - sphinxcontrib-applehelp=1.0.2=py_0 296 | - sphinxcontrib-devhelp=1.0.2=py_0 297 | - sphinxcontrib-htmlhelp=1.0.3=py_0 298 | - sphinxcontrib-jsmath=1.0.1=py_0 299 | - sphinxcontrib-qthelp=1.0.3=py_0 300 | - sphinxcontrib-serializinghtml=1.1.4=py_0 301 | - sphinxcontrib-websupport=1.2.3=py_0 302 | - spyder=4.1.4=py38_0 303 | - spyder-kernels=1.9.2=py38_0 304 | - sqlalchemy=1.3.18=py38h7b6447c_0 305 | - sqlite=3.32.3=h62c20be_0 306 | - statsmodels=0.11.1=py38h7b6447c_0 307 | - sympy=1.6.1=py38_0 308 | - tbb=2020.0=hfd86e86_0 309 | - tblib=1.6.0=py_0 310 | - tensorboardx=2.1=py_0 311 | - terminado=0.8.3=py38_0 312 | - testpath=0.4.4=py_0 313 | - threadpoolctl=2.1.0=pyh5ca1d4c_0 314 | - tk=8.6.10=hbc83047_0 315 | - toml=0.10.1=py_0 316 | - toolz=0.10.0=py_0 317 | - torchaudio=0.7.0=py38 318 | - torchvision=0.8.1=py38_cu102 319 | - tornado=6.0.4=py38h7b6447c_1 320 | - tqdm=4.47.0=py_0 321 | - traitlets=4.3.3=py38_0 322 | - twisted=20.3.0=py38h7b6447c_0 323 | - typing_extensions=3.7.4.2=py_0 324 | - ujson=1.35=py38h7b6447c_0 325 | - unicodecsv=0.14.1=py38_0 326 | - unixodbc=2.3.7=h14c3975_0 327 | - urllib3=1.25.9=py_0 328 | - watchdog=0.10.3=py38_0 329 | - wcwidth=0.2.5=py_0 330 | - webencodings=0.5.1=py38_1 331 | - werkzeug=1.0.1=py_0 332 | - wheel=0.34.2=py38_0 333 | - widgetsnbextension=3.5.1=py38_0 334 | - wrapt=1.11.2=py38h7b6447c_0 335 | - wurlitzer=2.0.1=py38_0 336 | - x264=1!157.20191217=h7b6447c_0 337 | - xlrd=1.2.0=py_0 338 | - xlsxwriter=1.2.9=py_0 339 | - xlwt=1.3.0=py38_0 340 | - xmltodict=0.12.0=py_0 341 | - xz=5.2.5=h7b6447c_0 342 | - yaml=0.2.5=h7b6447c_0 343 | - yapf=0.30.0=py_0 344 | - zeromq=4.3.2=he6710b0_2 345 | - zict=2.0.0=py_0 346 | - zipp=3.1.0=py_0 347 | - zlib=1.2.11=h7b6447c_3 348 | - zope=1.0=py38_1 349 | - zope.event=4.4=py38_0 350 | - zope.interface=4.7.1=py38h7b6447c_0 351 | - zstd=1.4.5=h0b5b093_0 352 | - pip: 353 | - absl-py==0.10.0 354 | - cachetools==4.1.1 355 | - cheroot==8.5.1 356 | - docopt==0.6.2 357 | - fitlog==0.9.13 358 | - gitdb==4.0.5 359 | - gitpython==3.1.11 360 | - google-auth==1.21.1 361 | - google-auth-oauthlib==0.4.1 362 | - grpcio==1.32.0 363 | - jaraco-functools==3.0.1 364 | - markdown==3.2.2 365 | - oauthlib==3.1.0 366 | - pyasn1-modules==0.2.8 367 | - requests-oauthlib==1.3.0 368 | - rsa==4.6 369 | - smmap==3.0.4 370 | - tensorboard==2.3.0 371 | - tensorboard-plugin-wit==1.7.0 372 | - wsgidav==3.0.3 373 | prefix: /home/dongda/anaconda3 374 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | import os 4 | import sys 5 | import json 6 | import math 7 | import time 8 | import pprint as pp 9 | from tqdm import tqdm 10 | 11 | import torch 12 | import torch.optim as optim 13 | 14 | from tensorboardX import SummaryWriter 15 | 16 | 17 | from config import PARAMS_CONFIG 18 | from models import EncoderSeq, QDecoder 19 | 20 | from pack_model import build_model, get_tgt_entropy 21 | 22 | from problems.pack2d.render import render 23 | 24 | 25 | from trainer import train_epoch, full_eval, epoch_logger 26 | 27 | 28 | 29 | from utils import ( 30 | get_params, 31 | set_up_env, 32 | logger, 33 | log_graph, 34 | get_scheduler, 35 | get_grad_requiring_params, 36 | load_checkpoint, 37 | save_checkpoint) 38 | 39 | 40 | def launch(env_params, 41 | model_params, 42 | problem_params, 43 | adapt_span_params, 44 | optim_params, 45 | trainer_params, 46 | rl_params): 47 | 48 | # print args and prepare directory and logger 49 | parameters_dict = locals() 50 | # print parameters 51 | for params_key, params_val in parameters_dict.items(): 52 | print(params_key) 53 | pp.pprint(params_val) 54 | 55 | writer_name = "{}".format(env_params['run_name']) 56 | run_name = "{}_{}".format(env_params['run_name'], time.strftime("%Y%m%dT%H%M%S")) 57 | save_dir = os.path.join( 58 | env_params['output_dir'], 59 | "{}_{}".format(problem_params['problem_type'], model_params['block_size']), 60 | run_name 61 | ) 62 | 63 | os.makedirs(save_dir) 64 | 65 | 66 | 67 | checkpoint_file = os.path.join(save_dir, 'checkpoint.pt') 68 | 69 | 70 | # Save arguments so exact configuration can always be found 71 | with open(os.path.join(save_dir, "args.json"), 'w') as fp: 72 | json.dump(parameters_dict, fp) 73 | 74 | logger.configure(dir=save_dir, format_strs=os.getenv('OPENAI_LOG_FORMAT', 'log,csv').split(',')) 75 | 76 | if not trainer_params['no_tensorboard']: 77 | tb_writer = SummaryWriter(comment= "-" + writer_name) 78 | else: 79 | tb_writer = None 80 | 81 | # ENV and MODEL 82 | set_up_env(env_params) 83 | device = env_params['device'] 84 | 85 | target_entropy = get_tgt_entropy( 86 | problem_params['problem_type'], 87 | model_params['block_size'], 88 | rl_params['tgt_entropy'], 89 | problem_params['p_options'] 90 | ).to(device) 91 | 92 | 93 | modules = build_model( 94 | device, 95 | problem_params, 96 | model_params, 97 | adapt_span_params) 98 | 99 | # show model size 100 | get_grad_requiring_params(modules) 101 | # print(modules) 102 | 103 | critic_params = [param for name, param in modules['critic'].named_parameters() if 'module.log_alpha' not in name] 104 | # OPTIMIZER AND SCHEDULER 105 | optimizer = optim.Adam([ 106 | {'params': modules['actor'].parameters()}, 107 | {'params': critic_params, 'lr': optim_params['critic_lr']}, 108 | {'params': modules['critic'].module.log_alpha, 'lr': optim_params['critic_lr']} 109 | ], lr=optim_params['actor_lr']) 110 | 111 | 112 | lambda1 = lambda epoch: min(1, epoch / optim_params['lr_warmup']) 113 | 114 | # end_lr = 1e-1 115 | # start_lr = 1e-7 116 | # lr_find_epochs = trainer_params['nb_iter']/2 117 | # search_lambda = lambda epoch: math.exp(epoch * math.log(end_lr / start_lr) / (lr_find_epochs * model_params['block_size'])) 118 | 119 | 120 | scheduler = optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda1) 121 | 122 | # warm up scheduler for both two groups 123 | # scheduler = get_scheduler(optimizer, optim_params['lr_warmup']) 124 | 125 | iter_init = load_checkpoint( 126 | trainer_params['checkpoint_path'], modules, optimizer, scheduler) 127 | 128 | 129 | for epoch in tqdm(range(iter_init, trainer_params['nb_iter'])): 130 | # print("Start train epoch {}, lr={} for run {}".format(epoch, optimizer.param_groups[0]['lr'], run_name)) 131 | # t_sta = time.time() # in seconds 132 | state, values, returns, losses, entropy, grad_norms, log_alpha = train_epoch( 133 | modules, 134 | optimizer, 135 | scheduler, 136 | problem_params, 137 | device, 138 | target_entropy, 139 | **model_params, **trainer_params, **optim_params, **rl_params) 140 | 141 | # for resume and render 142 | if epoch % trainer_params['checkpoint_interval'] == 0: 143 | log_graph.save_train_graph(state, epoch, save_dir) 144 | save_checkpoint(checkpoint_file, epoch, modules, optimizer, scheduler) 145 | # with torch.no_grad(): 146 | # modules.eval() 147 | # trainer_params['full_eval_mode'] = True 148 | # t_sta = time.time() # in seconds 149 | 150 | # state, values, returns, losses, entropy, _, _ = train_epoch( 151 | # modules, 152 | # optimizer, 153 | # scheduler, 154 | # problem_params, 155 | # device, 156 | # target_entropy, 157 | # **model_params, **trainer_params, **optim_params, **rl_params) 158 | 159 | # gap_ratio = state.get_gap_ratio() 160 | # avg_gap_ratio = gap_ratio.mean().item() 161 | 162 | # elapsed = time.time() - t_sta 163 | 164 | # print("Finished evaluation with gap ratio {}, took {} s".format(avg_gap_ratio, time.strftime('%H:%M:%S', time.gmtime(elapsed)))) 165 | 166 | # for monitor 167 | epoch_logger(epoch, state, values, returns, losses, entropy, grad_norms, log_alpha, optimizer, 168 | tb_writer, trainer_params['log_interval'], run_name) 169 | 170 | 171 | # perform a evaluation after training 172 | 173 | with torch.no_grad(): 174 | modules.eval() 175 | trainer_params['full_eval_mode'] = True 176 | t_sta = time.time() # in seconds 177 | 178 | state, values, returns, losses, entropy, _, _ = train_epoch( 179 | modules, 180 | optimizer, 181 | scheduler, 182 | problem_params, 183 | device, 184 | target_entropy, 185 | **model_params, **trainer_params, **optim_params, **rl_params) 186 | 187 | gap_ratio = state.get_gap_ratio() 188 | avg_gap_ratio = gap_ratio.mean().item() 189 | 190 | elapsed = time.time() - t_sta 191 | 192 | print("Finished evaluation with gap ratio {}, took {} s".format(avg_gap_ratio, time.strftime('%H:%M:%S', time.gmtime(elapsed)))) 193 | 194 | 195 | 196 | 197 | 198 | if __name__ == '__main__': 199 | launch(**get_params(params_config=PARAMS_CONFIG)) 200 | 201 | 202 | 203 | 204 | -------------------------------------------------------------------------------- /models.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | import math 4 | 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | 9 | 10 | # Size notations: 11 | # B = batch_size, H = hidden_size, M = block_size, L = attn_span 12 | 13 | 14 | class SeqAttention(nn.Module): 15 | """Sequential self-attention layer. 16 | """ 17 | 18 | def __init__(self, hidden_size, enable_mem, attn_span, 19 | dropout, adapt_span_params, **kargs): 20 | nn.Module.__init__(self) 21 | self.dropout = nn.Dropout(dropout) 22 | self.hidden_size = hidden_size # size of a single head 23 | self.attn_span = attn_span 24 | self.enable_mem = enable_mem 25 | self.adapt_span_enabled = adapt_span_params['adapt_span_enabled'] 26 | if self.adapt_span_enabled and self.enable_mem: 27 | self.adaptive_span = AdaptiveSpan(attn_span=attn_span, 28 | **adapt_span_params, **kargs) 29 | 30 | def forward(self, query, key, value): 31 | # query size = B x M x H 32 | # key, value sizes = B x (M+L) x H 33 | 34 | if self.adapt_span_enabled: 35 | # [optional] trim out memory to reduce unnecessary computation 36 | key, value, key_pe = self.adaptive_span.trim_memory( 37 | query, key, value) 38 | 39 | # compute attention from context 40 | # B x M (dest) x (M+L) (src) 41 | attn = torch.matmul(query, key.transpose(-1, -2)) 42 | 43 | attn = attn / math.sqrt(self.hidden_size) # B x M X (M+L) 44 | attn = F.softmax(attn, dim=-1) 45 | 46 | if self.adapt_span_enabled and self.enable_mem: 47 | # trim attention lengths according to the learned span 48 | attn = self.adaptive_span(attn) 49 | 50 | attn = self.dropout(attn) # B x M X (M+L) 51 | 52 | out = torch.matmul(attn, value) # B x M x H 53 | 54 | return out 55 | 56 | def get_cache_size(self): 57 | if self.adapt_span_enabled: 58 | return self.adaptive_span.get_cache_size() 59 | else: 60 | return self.attn_span 61 | 62 | 63 | class MultiHeadSeqAttention(nn.Module): 64 | def __init__(self, hidden_size, enable_mem, nb_heads, **kargs): 65 | nn.Module.__init__(self) 66 | assert hidden_size % nb_heads == 0 67 | self.nb_heads = nb_heads 68 | self.head_dim = hidden_size // nb_heads 69 | self.attn = SeqAttention( 70 | hidden_size=self.head_dim, enable_mem=enable_mem, nb_heads=nb_heads, **kargs) 71 | self.proj_query = nn.Linear(hidden_size, hidden_size, bias=False) 72 | self.proj_out = nn.Linear(hidden_size, hidden_size, bias=False) 73 | self.proj_val = nn.Linear(hidden_size, hidden_size, bias=False) 74 | self.proj_key = nn.Linear(hidden_size, hidden_size, bias=False) 75 | 76 | # note that the linear layer initialization in current Pytorch is kaiming uniform init 77 | 78 | def head_reshape(self, x): 79 | K = self.nb_heads 80 | D = self.head_dim 81 | x = x.view(x.size()[:-1] + (K, D)) # B x (M+L) x K x D 82 | x = x.transpose(1, 2).contiguous() # B x K x (M+L) x D 83 | x = x.view(-1, x.size(-2), x.size(-1)) # B_K x (M+L) x D 84 | return x 85 | 86 | def forward(self, query, key, value): 87 | B = query.size(0) 88 | K = self.nb_heads 89 | D = self.head_dim 90 | M = query.size(1) 91 | 92 | query = self.proj_query(query) 93 | query = self.head_reshape(query) 94 | value = self.proj_val(value) 95 | value = self.head_reshape(value) 96 | key = self.proj_key(key) 97 | key = self.head_reshape(key) 98 | 99 | out = self.attn(query, key, value) # B_K x M x D 100 | out = out.view(B, K, M, D) # B x K x M x D 101 | out = out.transpose(1, 2).contiguous() # B x M x K x D 102 | out = out.view(B, M, -1) # B x M x K_D 103 | out = self.proj_out(out) 104 | return out 105 | 106 | 107 | class FeedForwardLayer(nn.Module): 108 | def __init__(self, hidden_size, inner_hidden_size, dropout, **kargs): 109 | nn.Module.__init__(self) 110 | self.fc1 = nn.Linear(hidden_size, inner_hidden_size) 111 | self.fc2 = nn.Linear(inner_hidden_size, hidden_size) 112 | self.dropout = nn.Dropout(dropout) 113 | 114 | def forward(self, h): 115 | h1 = F.relu(self.fc1(h)) 116 | h1 = self.dropout(h1) 117 | h2 = self.fc2(h1) 118 | return h2 119 | 120 | 121 | class Normalization(nn.Module): 122 | 123 | def __init__(self, embed_dim, normalization='batch'): 124 | super(Normalization, self).__init__() 125 | 126 | normalizer_class = { 127 | 'batch': nn.BatchNorm1d, 128 | 'instance': nn.InstanceNorm1d, 129 | 'layer': nn.LayerNorm 130 | }.get(normalization, None) 131 | 132 | self.normalizer = normalizer_class(embed_dim, affine=True) 133 | 134 | # Normalization by default initializes affine parameters with bias 0 and weight unif(0,1) which is too large! 135 | self.init_parameters() 136 | 137 | def init_parameters(self): 138 | # xavier_uniform initialization 139 | for name, param in self.named_parameters(): 140 | stdv = 1. / math.sqrt(param.size(-1)) 141 | param.data.uniform_(-stdv, stdv) 142 | 143 | def forward(self, input): 144 | 145 | if isinstance(self.normalizer, nn.BatchNorm1d): 146 | return self.normalizer(input.view(-1, input.size(-1))).view(*input.size()) 147 | elif isinstance(self.normalizer, nn.InstanceNorm1d): 148 | return self.normalizer(input.permute(0, 2, 1)).permute(0, 2, 1) 149 | elif isinstance(self.normalizer, nn.LayerNorm): 150 | return self.normalizer(input) 151 | else: 152 | assert self.normalizer is None, "Unknown normalizer type" 153 | return input 154 | 155 | 156 | class TransformerSeqLayer(nn.Module): 157 | def __init__(self, hidden_size, enable_mem, normalization, **kargs): 158 | nn.Module.__init__(self) 159 | self.attn = MultiHeadSeqAttention( 160 | hidden_size=hidden_size, enable_mem=enable_mem, **kargs) 161 | self.ff = FeedForwardLayer(hidden_size=hidden_size, **kargs) 162 | self.norm1 = Normalization(hidden_size, normalization) 163 | self.norm2 = Normalization(hidden_size, normalization) 164 | 165 | self.enable_mem = enable_mem 166 | 167 | def forward(self, h, h_cache): 168 | # h = B x M x H 169 | # h_cache = B x L x H 170 | if self.enable_mem: 171 | h_all = torch.cat([h_cache, h], dim=1) # B x (M+L) x H 172 | else: 173 | h_all = h_cache # B x M x H 174 | attn_out = self.attn(h, h_all, h_all) 175 | h = self.norm1(h + attn_out) # B x M x H 176 | ff_out = self.ff(h) 177 | out = self.norm2(h + ff_out) # B x M x H 178 | return out 179 | 180 | 181 | class EncoderSeq(nn.Module): 182 | def __init__(self, state_size, hidden_size, nb_heads, encoder_nb_layers, 183 | attn_span, **kargs): 184 | nn.Module.__init__(self) 185 | # init embeddings 186 | self.init_embed = nn.Linear(state_size, hidden_size) 187 | 188 | self.layers = nn.ModuleList() 189 | self.layers.extend( 190 | TransformerSeqLayer( 191 | hidden_size=hidden_size, enable_mem=True, nb_heads=nb_heads, 192 | attn_span=attn_span, **kargs) 193 | for _ in range(encoder_nb_layers)) 194 | 195 | def forward(self, x, h_cache): 196 | # x size = B x M 197 | block_size = x.size(1) 198 | h = self.init_embed(x) # B x M x H 199 | h_cache_next = [] 200 | for l, layer in enumerate(self.layers): 201 | cache_size = layer.attn.attn.get_cache_size() 202 | 203 | # B x L x H 204 | h_cache_next_l = torch.cat( 205 | [h_cache[l][:, -cache_size + 1:, :], h[:, 0:1, :]], 206 | dim=1).detach() 207 | 208 | h_cache_next.append(h_cache_next_l) 209 | 210 | h = layer(h, h_cache[l]) # B x M x H 211 | 212 | return h, h_cache_next 213 | 214 | 215 | class QDecoder(nn.Module): 216 | def __init__(self, state_size, hidden_size, nb_heads, decoder_nb_layers, 217 | attn_span, **kargs): 218 | nn.Module.__init__(self) 219 | # init embeddings 220 | self.init_embed = nn.Linear(state_size, hidden_size) 221 | 222 | self.layers = nn.ModuleList() 223 | self.layers.extend( 224 | TransformerSeqLayer( 225 | hidden_size=hidden_size, enable_mem=False, nb_heads=nb_heads, 226 | attn_span=attn_span, **kargs) 227 | for _ in range(decoder_nb_layers)) 228 | 229 | def forward(self, x, embedding): 230 | # x size = B x Q_M 231 | block_size = x.size(1) 232 | h = self.init_embed(x) # B x Q_M x H 233 | h_cache_next = [] 234 | for l, layer in enumerate(self.layers): 235 | 236 | h = layer(h, embedding) # B x Q_M x H 237 | 238 | return h 239 | -------------------------------------------------------------------------------- /pack_model.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | 4 | import torch 5 | from torch import nn 6 | from models import EncoderSeq, QDecoder 7 | 8 | 9 | 10 | def get_tgt_entropy(problem_type, block_size, tgt_entropy, p_options): 11 | s_tgt_entropy = block_size * tgt_entropy / p_options 12 | if problem_type=='pack2d': 13 | r_tgt_entropy = 2 * tgt_entropy / p_options 14 | target_entropy = torch.tensor([s_tgt_entropy, r_tgt_entropy, tgt_entropy]) 15 | elif problem_type=='pack3d': 16 | r_tgt_entropy = 6 * tgt_entropy / p_options 17 | target_entropy = torch.tensor([s_tgt_entropy, r_tgt_entropy, tgt_entropy, tgt_entropy]) 18 | else: 19 | raise ValueError('Invalided problem type') 20 | 21 | print('target_entropy: ', target_entropy) 22 | return target_entropy 23 | 24 | 25 | 26 | class PackDecoder(nn.Module): 27 | def __init__(self, head_hidden_size, res_size, state_size, hidden_size, decoder_layers, **kargs): 28 | nn.Module.__init__(self) 29 | 30 | self.att_decoder = QDecoder(state_size, hidden_size, decoder_nb_layers=decoder_layers, **kargs) 31 | 32 | self.head = nn.Sequential( 33 | nn.Linear(hidden_size, head_hidden_size), 34 | nn.ReLU(), 35 | nn.Linear(head_hidden_size, res_size) 36 | ) 37 | 38 | 39 | def forward(self, x, embedding): 40 | h = self.att_decoder(x, embedding) 41 | out = self.head(h) 42 | return out 43 | 44 | class Cirtic(nn.Module): 45 | def __init__(self, head_hidden_size, res_size, packed_state_size, box_state_size, hidden_size, c_encoder_layers, c_decoder_layers, **kargs): 46 | nn.Module.__init__(self) 47 | 48 | # add this parameter for entropy temp 49 | # 3D 50 | if packed_state_size==6: 51 | self.log_alpha = nn.Parameter(torch.tensor([-2.0,-2.0,-2.0, -2.0])) 52 | elif packed_state_size==4: 53 | self.log_alpha = nn.Parameter(torch.tensor([-2.0,-2.0,-2.0])) 54 | else: 55 | raise ValueError('Invalided problem type') 56 | 57 | 58 | self.att_encoder = EncoderSeq( 59 | state_size=packed_state_size, 60 | hidden_size=hidden_size, 61 | encoder_nb_layers=c_encoder_layers, 62 | **kargs) 63 | 64 | self.att_decoder = QDecoder( 65 | box_state_size, 66 | hidden_size, 67 | decoder_nb_layers=c_decoder_layers, 68 | **kargs) 69 | 70 | self.head = nn.Sequential( 71 | nn.Linear(hidden_size, head_hidden_size), 72 | nn.ReLU(), 73 | nn.Linear(head_hidden_size, res_size) 74 | ) 75 | 76 | 77 | def forward(self, q, x, h_cache): 78 | 79 | embedding, h_cache = self.att_encoder(x, h_cache) 80 | h = self.att_decoder(q, embedding) # B x Q x H 81 | h = h.mean(dim=1) # B x H 82 | out = self.head(h) 83 | return out, h_cache 84 | 85 | 86 | def set_model(model, device, parallel=True): 87 | if parallel: 88 | model = torch.nn.DataParallel(model) 89 | model = model.to(device) 90 | return model 91 | 92 | 93 | def get_ac_parameters(modules): 94 | 95 | critic_params = modules['critic'].parameters() 96 | 97 | actor_params = modules['actor'].parameters() 98 | 99 | return actor_params, critic_params 100 | 101 | 102 | 103 | 104 | 105 | def build_model( 106 | device, 107 | problem_params, 108 | model_params, 109 | adapt_span_params): 110 | 111 | 112 | if problem_params['problem_type']=='pack2d': 113 | packed_state_size = 4 114 | box_state_size = 2 115 | rotate_out_size = 2 116 | elif problem_params['problem_type']=='pack3d': 117 | packed_state_size = 6 118 | box_state_size = 3 119 | rotate_out_size = 6 120 | else: 121 | raise ValueError('Invalided problem type') 122 | 123 | 124 | encoder = EncoderSeq( 125 | state_size=packed_state_size, 126 | encoder_nb_layers=model_params['encoder_layers'], 127 | **model_params, 128 | adapt_span_params=adapt_span_params) 129 | 130 | 131 | s_decoder = PackDecoder(head_hidden_size=model_params['head_hidden'], 132 | res_size=1, 133 | state_size=box_state_size, 134 | **model_params, 135 | adapt_span_params=adapt_span_params) 136 | 137 | r_decoder = PackDecoder(head_hidden_size=model_params['head_hidden'], 138 | res_size=rotate_out_size, 139 | state_size=box_state_size, 140 | **model_params, 141 | adapt_span_params=adapt_span_params) 142 | 143 | p_decoder = PackDecoder(head_hidden_size=model_params['head_hidden_pos'], 144 | res_size=problem_params['p_options'], 145 | state_size=box_state_size, 146 | **model_params, 147 | adapt_span_params=adapt_span_params) 148 | 149 | q_decoder = PackDecoder(head_hidden_size=model_params['head_hidden_pos'], 150 | res_size=problem_params['p_options'], 151 | state_size=box_state_size, 152 | **model_params, 153 | adapt_span_params=adapt_span_params) 154 | 155 | critic = Cirtic(head_hidden_size=model_params['head_hidden'], 156 | res_size=1, 157 | packed_state_size=packed_state_size, 158 | box_state_size=box_state_size, 159 | **model_params, 160 | adapt_span_params=adapt_span_params) 161 | 162 | 163 | encoder = set_model(encoder, device) 164 | s_decoder = set_model(s_decoder, device) 165 | r_decoder = set_model(r_decoder, device) 166 | p_decoder = set_model(p_decoder, device) 167 | q_decoder = set_model(q_decoder, device) 168 | critic = set_model(critic, device) 169 | 170 | if problem_params['problem_type'] == 'pack2d': 171 | 172 | actor_modules = nn.ModuleDict({ 173 | 'encoder': encoder, 174 | 's_decoder': s_decoder, 175 | 'r_decoder': r_decoder, 176 | 'p_decoder': p_decoder} 177 | ) 178 | else: 179 | actor_modules = nn.ModuleDict({ 180 | 'encoder': encoder, 181 | 's_decoder': s_decoder, 182 | 'r_decoder': r_decoder, 183 | 'p_decoder': p_decoder, 184 | 'q_decoder': q_decoder} 185 | ) 186 | 187 | 188 | packing_modules = nn.ModuleDict({ 189 | 'actor': actor_modules, 190 | 'critic': critic 191 | }) 192 | 193 | 194 | 195 | return packing_modules 196 | 197 | 198 | 199 | 200 | 201 | 202 | -------------------------------------------------------------------------------- /pack_step.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | 4 | import torch 5 | from torch import nn 6 | import copy 7 | from torch.nn import DataParallel 8 | import torch.nn.functional as F 9 | import numpy as np 10 | 11 | 12 | def pack_step(modules, state, h_caches, problem_params): 13 | actor_modules = modules['actor'] 14 | 15 | actor_encoder_out, h_caches[0] = actor_modules['encoder'](state.packed_state, h_caches[0]) 16 | if not state.online: 17 | # (batch, block, 1) 18 | s_out = actor_modules['s_decoder'](state.boxes, actor_encoder_out) 19 | 20 | select_mask = state.get_mask() 21 | # print(state.boxes, state.packed_state) 22 | s_log_p, selected = _select_step(s_out.squeeze(-1), select_mask) 23 | 24 | else: 25 | selected = torch.zeros(state.packed_state.size(0), device=state.packed_state.device) 26 | s_log_p = 0 27 | 28 | # select (batch) 29 | state.update_select(selected) 30 | # (batch, 2) 31 | q_rotation = state.action.get_shape().unsqueeze(1) 32 | 33 | r_out = actor_modules['r_decoder'](q_rotation, actor_encoder_out).squeeze(1) 34 | 35 | r_log_p, rotation = _rotate_step(r_out.squeeze(-1)) 36 | 37 | # rotation 38 | state.update_rotate(rotation) 39 | 40 | 41 | 42 | if problem_params['problem_type'] == 'pack2d': 43 | p_position = state.action.get_shape().unsqueeze(1) 44 | 45 | if not problem_params['no_query']: 46 | 47 | p_out = actor_modules['p_decoder'](p_position, actor_encoder_out).squeeze(1) 48 | 49 | else: 50 | 51 | p_out = actor_modules['p_decoder'](q_rotation, actor_encoder_out).squeeze(1) 52 | 53 | x_log_p, box_xs = _drop_step(p_out.squeeze(-1), state.get_boundx()) 54 | 55 | value, h_caches[1] = modules['critic'](state.boxes, state.packed_state, h_caches[1]) 56 | value = value.squeeze(-1) 57 | # update location and finish one step packing 58 | state.update_pack(box_xs) 59 | 60 | return s_log_p, r_log_p, x_log_p, value, h_caches 61 | else: 62 | 63 | p_position = state.action.get_shape().unsqueeze(1) 64 | q_position = state.action.get_shape().unsqueeze(1) 65 | 66 | if not problem_params['no_query']: 67 | 68 | p_out = actor_modules['p_decoder'](p_position, actor_encoder_out).squeeze(1) 69 | q_out = actor_modules['q_decoder'](q_position, actor_encoder_out).squeeze(1) 70 | else: 71 | 72 | p_out = actor_modules['p_decoder'](q_rotation, actor_encoder_out).squeeze(1) 73 | q_out = actor_modules['q_decoder'](q_rotation, actor_encoder_out).squeeze(1) 74 | 75 | x_log_p, box_xs = _drop_step(p_out.squeeze(-1), state.get_boundx()) 76 | y_log_p, box_ys = _drop_step(q_out.squeeze(-1), state.get_boundy()) 77 | 78 | value, h_caches[1] = modules['critic'](state.boxes, state.packed_state, h_caches[1]) 79 | value = value.squeeze(-1) 80 | 81 | state.update_pack(box_xs, box_ys) 82 | 83 | return s_log_p, r_log_p, x_log_p, y_log_p, value, h_caches 84 | 85 | 86 | 87 | 88 | def _select_step(s_logits, mask): 89 | 90 | s_logits = s_logits.masked_fill(mask, -np.inf) 91 | 92 | s_log_p = F.log_softmax(s_logits, dim=-1) 93 | 94 | # (batch) 95 | selected = _select(s_log_p.exp()).unsqueeze(-1) 96 | 97 | # do not reinforce masked and avoid entropy become nan 98 | s_log_p = s_log_p.masked_fill(mask, 0) 99 | 100 | return s_log_p, selected 101 | 102 | 103 | def _rotate_step(r_logits): 104 | 105 | r_log_p = F.log_softmax(r_logits, dim=-1) 106 | 107 | # rotate (batch, 1) 108 | rotate = _select(r_log_p.exp()).unsqueeze(-1) 109 | 110 | return r_log_p, rotate 111 | 112 | def _drop_step(p_logits, right_bound): 113 | 114 | batch_size, p_options = p_logits.size() 115 | # (-1, 1) ---->(0, DISCRETE_XNUM) (batch, 1) 116 | right_b = ((right_bound + 1.0) * (p_options/2)).floor().long() 117 | 118 | bound_range = torch.arange(p_options, device=p_logits.device).unsqueeze(0) 119 | bound_range = bound_range.repeat(batch_size, 1) 120 | 121 | # bound_mask (batch, DISCRETE_XNUM) 122 | bound_mask = bound_range.gt(right_b.unsqueeze(-1)) 123 | 124 | x_logits_masked = p_logits.masked_fill(bound_mask, -np.inf) 125 | # (batch, DISCRETE_XNUM) 126 | x_log_p = F.log_softmax(x_logits_masked, dim=-1) 127 | assert not torch.isnan(x_log_p).any() 128 | 129 | # (batch, 1) 130 | x_selects = _select(x_log_p.exp()).unsqueeze(1) 131 | 132 | # do not reinforce masked 133 | x_log_p = x_log_p.masked_fill(bound_mask, 0) 134 | 135 | box_xs = x_selects.float()/(p_options/2) - 1.0 136 | 137 | # test continuous and discrete conversion 138 | # test = ((box_xs + 1.0) * (p_options/2)).round().long() 139 | # assert test.eq(x_selects).all(), "conversion error!" 140 | 141 | return x_log_p, box_xs 142 | 143 | 144 | def _select(probs, decode_type="sampling"): 145 | assert (probs == probs).all(), "Probs should not contain any nans" 146 | 147 | if decode_type == "greedy": 148 | _, selected = probs.max(-1) 149 | elif decode_type == "sampling": 150 | selected = probs.multinomial(1).squeeze(1) 151 | 152 | else: 153 | assert False, "Unknown decode type" 154 | 155 | return selected 156 | 157 | 158 | -------------------------------------------------------------------------------- /plot.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding: utf-8 3 | 4 | 5 | import os 6 | import argparse 7 | import numpy as np 8 | import torch 9 | import time 10 | import glob 11 | import pandas as pd 12 | import json 13 | 14 | 15 | from torch.utils.data import DataLoader 16 | from utils import load_model 17 | from problems import Pack2D 18 | from problems.pack2d.render import render, get_render_data 19 | 20 | 21 | 22 | 23 | parser = argparse.ArgumentParser() 24 | 25 | parser.add_argument('--load_path', help='Path to load model parameters and optimizer state from') 26 | # parser.add_argument('--from_val', action='store_true', help='load from training example') 27 | parser.add_argument('--index', type=int, default=0, help="less than 10 is min, 11~20 is normal") 28 | parser.add_argument('--epoch', type=int, help='choose the epoch to plot') 29 | 30 | opts = parser.parse_args() 31 | 32 | if opts.load_path is not None: 33 | load_path = opts.load_path 34 | else: 35 | load_path = max(glob.iglob('outputs/pack2d_20/*'), key=os.path.getctime) 36 | 37 | 38 | load_path = '/Users/phoenix/rcw/rl/packing/unversal/unversal_packing' 39 | 40 | 41 | epoch = max( 42 | int(os.path.splitext(filename)[0].split("-")[1]) 43 | for filename in os.listdir(load_path) 44 | if os.path.splitext(filename)[1] == '.csv' and os.path.splitext(filename)[0].split("-")[0]=='epoch' 45 | ) 46 | graph_filename = os.path.join(load_path, 'epoch-{}.csv'.format(epoch)) 47 | graph_filename = os.path.join(load_path, 'epoch-{}.csv'.format(opts.epoch)) 48 | 49 | 50 | 51 | print(' [*] Loading data from {}'.format(graph_filename)) 52 | print("draw the ", opts.index) 53 | 54 | data_frame = pd.read_csv(graph_filename) 55 | 56 | # index with step size 4 57 | indexs = np.arange(0, data_frame.shape[0], 4) 58 | 59 | # last 3 is statistic 60 | graph_size = data_frame.shape[1]-4 61 | 62 | #(batch*4, 20) the first column is csv index 63 | graphs = data_frame.iloc[:,1:graph_size+1].to_numpy() 64 | assert graphs.shape[1] == graph_size 65 | 66 | 67 | # data to cpu 68 | #(batch, 4, 20) 69 | plot_graphs = torch.from_numpy(graphs.reshape((graphs.shape[0]//4, 4, graph_size)).swapaxes(1,2)) 70 | heights = torch.from_numpy(data_frame.iloc[:,graph_size+1].take(indexs).to_numpy()) 71 | gap_sizes = torch.from_numpy(data_frame.iloc[:,graph_size+2].take(indexs).to_numpy()) 72 | gap_ratios = torch.from_numpy(data_frame.iloc[:,graph_size+3].take(indexs).to_numpy()) 73 | 74 | print("average height: ", heights.mean().data) 75 | print("average gap_ratio: ", gap_ratios.mean().data) 76 | 77 | print("rendering...") 78 | # print(graphs[0]) 79 | 80 | 81 | draw_index = opts.index 82 | 83 | print("index height: ", draw_index, heights[draw_index]) 84 | print("index gap_ratio: ", draw_index, gap_ratios[draw_index]) 85 | 86 | window = render(plot_graphs[draw_index], heights[draw_index], gap_sizes[draw_index], gap_ratios[draw_index], sleep=0.2) 87 | 88 | #for draw_index in range(20): 89 | # window = render(ranked_graph[draw_index], heights[draw_index], gap_sizes[draw_index], gap_ratios[draw_index], sleep=0.2) 90 | 91 | # Plot the results 92 | 93 | 94 | # Plot the results 95 | while True: 96 | print("rendered") 97 | time.sleep(5) 98 | 99 | 100 | 101 | 102 | 103 | -------------------------------------------------------------------------------- /plot3d.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding: utf-8 3 | 4 | 5 | import os 6 | import argparse 7 | import numpy as np 8 | import torch 9 | import time 10 | import glob 11 | import pandas as pd 12 | import json 13 | 14 | 15 | from torch.utils.data import DataLoader 16 | from utils import load_model 17 | from problems import Pack3D 18 | from problems.pack3d.render import render, get_render_data 19 | 20 | 21 | parser = argparse.ArgumentParser() 22 | 23 | parser.add_argument('--load_path', help='Path to load model parameters and optimizer state from') 24 | parser.add_argument('--from_train', action='store_true', help='load from training example') 25 | parser.add_argument('--index', type=int, default=0, help="less than 10 is min, 11~20 is normal") 26 | parser.add_argument('--epoch', type=int, help='choose the epoch to plot') 27 | 28 | opts = parser.parse_args() 29 | 30 | if opts.load_path is not None: 31 | load_path = opts.load_path 32 | else: 33 | load_path = max(glob.iglob('outputs/pack2d_20/*'), key=os.path.getctime) 34 | 35 | load_path = '/Users/phoenix/rcw/rl/packing/unversal/unversal_packing' 36 | 37 | 38 | epoch = max( 39 | int(os.path.splitext(filename)[0].split("-")[1]) 40 | for filename in os.listdir(load_path) 41 | if os.path.splitext(filename)[1] == '.csv' and os.path.splitext(filename)[0].split("-")[0] == 'epoch' 42 | ) 43 | graph_filename = os.path.join(load_path, 'epoch-{}.csv'.format(epoch)) 44 | graph_filename = os.path.join(load_path, 'epoch-{}.csv'.format(opts.epoch)) 45 | 46 | 47 | print(' [*] Loading data from {}'.format(graph_filename)) 48 | 49 | data_frame = pd.read_csv(graph_filename) 50 | 51 | indexs = np.arange(0, data_frame.shape[0], 6) 52 | 53 | graph_size = data_frame.shape[1] - 4 54 | 55 | # (batch*4, 20) 56 | graphs = data_frame.iloc[:, 1:graph_size + 1].to_numpy() 57 | assert graphs.shape[1] == graph_size 58 | 59 | # data to cpu 60 | # (batch, 4, 20) 61 | plot_graphs = torch.from_numpy(graphs.reshape((graphs.shape[0] // 6, 6, graph_size)).swapaxes(1, 2)) 62 | heights = torch.from_numpy(data_frame.iloc[:, graph_size + 1].take(indexs).to_numpy()) 63 | gap_sizes = torch.from_numpy(data_frame.iloc[:, graph_size + 2].take(indexs).to_numpy()) 64 | gap_ratios = torch.from_numpy(data_frame.iloc[:, graph_size + 3].take(indexs).to_numpy()) 65 | # orders = torch.from_numpy(data_frame.iloc[:, graph_size+4].take(indexs).to_numpy()) 66 | # print(orders) 67 | # orders = plot_graphs[:, :, 7] 68 | # plot_graphs = plot_graphs[:, :, 0:7] 69 | 70 | print("average height: ", heights.mean()) 71 | print("average gap_ratio: ", gap_ratios.mean()) 72 | 73 | print("rendering...") 74 | # print(graphs[0]) 75 | # orders = orders.unsqueeze(-1).expand_as(plot_graphs).long() 76 | 77 | # ranked_graph = plot_graphs.gather(1, orders) 78 | 79 | draw_index = opts.index 80 | print("height: ", heights[draw_index]) 81 | print('gap_sizes: ', gap_sizes[draw_index]) 82 | print('gap_ratios: ', gap_ratios[draw_index]) 83 | 84 | window = render(plot_graphs[draw_index], heights[draw_index], gap_sizes[draw_index], gap_ratios[draw_index], sleep=1) 85 | 86 | # Plot the results 87 | 88 | 89 | # Plot the results 90 | while True: 91 | print("rendered") 92 | time.sleep(5) 93 | -------------------------------------------------------------------------------- /problems/__init__.py: -------------------------------------------------------------------------------- 1 | from problems.pack2d.pack2d import Pack2D 2 | from problems.pack3d.pack3d import Pack3D -------------------------------------------------------------------------------- /problems/pack2d/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dongdongbh/RCQL/7d71dec03d1ac9ec12063c4f0d96ebf8e960f2e6/problems/pack2d/__init__.py -------------------------------------------------------------------------------- /problems/pack2d/pack2d.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding: utf-8 3 | 4 | 5 | from torch.utils.data import Dataset 6 | import torch 7 | import os 8 | import pickle 9 | from problems.pack2d.state_pack2d import StatePack2D 10 | 11 | from utils import sample_truncated_normal, generate_normal 12 | 13 | 14 | class Pack2DUpdate(Dataset): 15 | 16 | def __init__(self, block_size=20, batch_size=128, block_num=10, size_p1=0.4, size_p2=10.0, distribution='normal', online=False, **kargs): 17 | super(Pack2DUpdate, self).__init__() 18 | 19 | assert distribution is not None, "Data distribution must be specified for problem" 20 | 21 | if distribution == 'normal': 22 | self.data = [generate_normal(shape=(1, 2), mu=size_p1, sigma=size_p2, a=0.02, b=2.0) 23 | for i in range(batch_size * block_size * block_num)] 24 | else: 25 | assert distribution == 'uniform' 26 | self.data = [torch.FloatTensor(1, 2).uniform_( 27 | size_p1, size_p2) for i in range(batch_size * block_size * block_num)] 28 | if not online: 29 | self.data = sorted( 30 | self.data, key=lambda x: x[0][0].item() * x[0][1].item(), reverse=True) 31 | 32 | self.size = len(self.data) 33 | 34 | def __len__(self): 35 | return self.size 36 | 37 | def __getitem__(self, idx): 38 | return self.data[idx] 39 | 40 | 41 | class Pack2D(object): 42 | NAME = 'pack2d' 43 | 44 | @staticmethod 45 | def make_dataset(*args, **kwargs): 46 | return Pack2DUpdate(*args, **kwargs) 47 | 48 | @staticmethod 49 | def make_state(*args, **kwargs): 50 | return StatePack2D(*args, **kwargs) 51 | -------------------------------------------------------------------------------- /problems/pack2d/render.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding: utf-8 3 | 4 | 5 | import time 6 | import torch 7 | 8 | 9 | 10 | 11 | BIN_HEIGHT = 150.0 12 | DRAW_SCALE = 3.0 13 | 14 | def get_render_data(state): 15 | 16 | # (batch, graph_size, n_feature) 17 | graphs = state.get_graph() 18 | heights = state.get_height() 19 | gap_sizes = state.get_gap_size() 20 | gap_ratios = state.get_gap_ratio() 21 | 22 | 23 | return graphs, heights, gap_sizes, gap_ratios 24 | 25 | # graph (2*block_size, 4) 26 | 27 | def render(graph, height, gap_size, gap_ratio, sleep=0): 28 | from problems.pack2d.viewer import Viewer 29 | 30 | screen_width = 2000 31 | screen_height = 1600 32 | 33 | viewer = Viewer(screen_width, screen_height) 34 | viewer.window.clear() 35 | 36 | min_y, _ = torch.min(graph[:,3], -1) # (1) 37 | 38 | print('min y: ', min_y.data) 39 | 40 | delta_height = height - min_y 41 | 42 | graph[:,3] -= min_y 43 | 44 | print('delta height: ', delta_height.data) 45 | 46 | scale = DRAW_SCALE * (BIN_HEIGHT / delta_height) 47 | 48 | viewer.set_scale(scale) 49 | viewer.draw_background(2*scale) 50 | 51 | 52 | # viewer.draw_text(height, gap_size, gap_ratio) 53 | 54 | for i, row in enumerate(graph): 55 | # print("i, row:", i, row) 56 | if row[0] == 0 or row[1] == 0: 57 | continue 58 | viewer.add_geom(row, i) 59 | if sleep>0: 60 | viewer.render() 61 | time.sleep(sleep) 62 | 63 | 64 | viewer.draw_top_line(height) 65 | viewer.render() 66 | 67 | return viewer 68 | 69 | -------------------------------------------------------------------------------- /problems/pack2d/resources/MONOFONT.TTF: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dongdongbh/RCQL/7d71dec03d1ac9ec12063c4f0d96ebf8e960f2e6/problems/pack2d/resources/MONOFONT.TTF -------------------------------------------------------------------------------- /problems/pack2d/state_pack2d.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | 4 | import torch 5 | import copy 6 | import numpy as np 7 | from typing import NamedTuple 8 | from torch.utils.data import DataLoader 9 | from tqdm import tqdm 10 | 11 | 12 | 13 | class PackAction(): 14 | # (batch, 1) 15 | 16 | def __init__(self, batch_size, device): 17 | self.index = torch.zeros(batch_size, 1, device=device) 18 | self.x = torch.empty(batch_size, 1, device=device).fill_(-2) # set to -2 19 | self.y = torch.empty(batch_size, 1, device=device).fill_(-2) 20 | self.rotate = torch.zeros(batch_size, device=device) 21 | self.updated_shape = torch.empty(batch_size, 2, device=device) 22 | 23 | def set_index(self, selected): 24 | self.index = selected 25 | 26 | def set_rotate(self, rotate): 27 | self.rotate = rotate 28 | 29 | def set_shape(self, width, height): 30 | # (batch, 2) 31 | self.updated_shape = torch.stack([width, height], dim=-1) 32 | 33 | def get_shape(self): 34 | return self.updated_shape 35 | 36 | 37 | def set_pos(self, x, y): 38 | self.x = x 39 | self.y = y 40 | 41 | def get_packed(self): 42 | return torch.cat((self.updated_shape, self.x, self.y), dim=-1) 43 | 44 | def reset(self): 45 | self.__init__(self.index.size(0)) 46 | 47 | def __call__(self): 48 | # (batch, 1) 49 | return {'index': self.index, 50 | 'rotate': self.rotate, 51 | 'x': self.x} 52 | 53 | def __len__(self): 54 | return self.index.size(0) 55 | 56 | 57 | 58 | def push_to_tensor_alternative(tensor, x): 59 | return torch.cat((tensor[:,1:,:], x), dim=1) 60 | 61 | 62 | 63 | class StatePack2D(): 64 | def __init__(self, batch_size, instance_size, block_size, device, position_size=128 ,online=False, cache_size=None): 65 | 66 | if online: 67 | self.boxes = torch.zeros(batch_size, 1, 2, device=device) 68 | else: 69 | self.boxes = torch.zeros(batch_size, block_size, 2, device=device) 70 | 71 | 72 | if cache_size is None: 73 | cache_size = block_size 74 | 75 | self.instance_size = instance_size 76 | self.device = device 77 | self.online = online 78 | self.i = 0 79 | 80 | # {width| height| x| y} 81 | self.packed_state = torch.zeros(batch_size, block_size, 4, dtype=torch.float, device=device) 82 | self.packed_state_cache = torch.zeros(batch_size, cache_size, 4, dtype=torch.float, device=device) 83 | self.packed_cat = torch.cat((self.packed_state_cache, self.packed_state), dim=1) 84 | 85 | self.boxes_area = torch.zeros(batch_size, dtype=torch.float, device=device) 86 | self.total_rewards = torch.zeros(batch_size, dtype=torch.float, device=device) 87 | self.skyline = torch.zeros(batch_size, position_size, dtype=float, device=device) 88 | self.action=PackAction(batch_size, device=device) 89 | 90 | def put_reward(self, reward): 91 | self.total_rewards += reward 92 | 93 | def get_rewards(self): 94 | return self.total_rewards 95 | 96 | def get_mask(self): 97 | 98 | 99 | block_size = self.packed_state.size(1) 100 | mask_array = torch.from_numpy(np.tril(np.ones(block_size), k=-1).astype('bool_')).to(self.packed_state.device) 101 | 102 | # we have one zero block first, so we have to pack one more block 103 | remain_steps = self.instance_size + block_size - self.i 104 | assert remain_steps > 0, 'over packed!!!' 105 | 106 | 107 | if remain_steps // block_size == 0: 108 | # from one to (block_size-1) 109 | mask_num = block_size - remain_steps 110 | else: 111 | mask_num = 0 112 | 113 | return mask_array[mask_num] 114 | 115 | 116 | def update_env(self, new_box): 117 | 118 | # new_box (batch, 1, 2) 119 | if self.online: 120 | self.boxes = new_box 121 | else: 122 | batch_size, block_size, box_state_size = self.boxes.size() 123 | 124 | all_index = torch.arange(block_size, device=self.boxes.device).repeat(batch_size, 1) 125 | 126 | # we initialize index with 0, so it doesn't matter with the first one 127 | mask = (all_index!=self.action.index).unsqueeze(-1).repeat(1, 1, box_state_size) 128 | # selected_box (batch, block-1, box_state_size) 129 | remaining_boxes = torch.masked_select(self.boxes, mask).view(batch_size, -1, box_state_size) 130 | 131 | self.boxes = torch.cat((new_box, remaining_boxes), dim=1) 132 | 133 | 134 | def update_select(self, selected): 135 | # select(batch,1) 136 | self.action.set_index(selected) 137 | 138 | # set raw shape 139 | box_width, box_height = self._get_action_box_shape() 140 | # print("box width, box height", self.boxes[0], selected[0], box_width[0], box_height[0]) 141 | self.action.set_shape(box_width, box_height) 142 | 143 | 144 | def _get_action_box_shape(self): 145 | 146 | if self.online: 147 | box_width = self.boxes[:, :, 0].squeeze(-1).squeeze(-1) # (batch) 148 | box_height = self.boxes[:, :, 1].squeeze(-1).squeeze(-1) 149 | else: 150 | select_index = self.action.index.long() 151 | 152 | box_raw_w = self.boxes[:, :, 0].squeeze(-1) # (batch, graph) 153 | box_raw_h = self.boxes[:, :, 1].squeeze(-1) 154 | # print("box_raw_w: ", box_raw_w) 155 | 156 | # print(box_raw_w.size(), select_index.size()) 157 | box_width = torch.gather(box_raw_w, -1, select_index).squeeze(-1) # (batch) 158 | box_height = torch.gather(box_raw_h, -1, select_index).squeeze(-1) 159 | 160 | return box_width, box_height 161 | 162 | 163 | # update roate action and set width and height according rotate 164 | def update_rotate(self, rotate): 165 | # rotate(batch, 1) 166 | 167 | self.action.set_rotate(rotate) 168 | 169 | rotate_mask = rotate.squeeze(-1).gt(0.5) # (batch) 170 | 171 | # (batch, 2) 172 | box_shape = self.action.get_shape() 173 | 174 | box_width = box_shape[:, 0] # (batch) 175 | box_height = box_shape[:, 1] 176 | 177 | box_width_r = box_height 178 | box_height_r = box_width 179 | 180 | box_width_r = torch.masked_select(box_width_r, rotate_mask) # (all rotated) 181 | box_height_r = torch.masked_select(box_height_r, rotate_mask) 182 | # print("box_width_r: ", box_width_r) 183 | 184 | inbox_width = box_width.masked_scatter(rotate_mask, box_width_r) # (batch) 185 | inbox_height = box_height.masked_scatter(rotate_mask, box_height_r) 186 | 187 | 188 | # save to action 189 | self.action.set_shape(inbox_width, inbox_height) 190 | 191 | 192 | 193 | # set x,y and update packing state 194 | def update_pack(self, x): 195 | 196 | batch_size = self.packed_state.size(0) 197 | select_index = self.action.index.squeeze(-1).long() 198 | 199 | # y = self._get_y(x) 200 | y = self._get_y_skyline(x) 201 | self.action.set_pos(x, y) 202 | 203 | # (batch, 1, 4) 204 | packed_box = self.action.get_packed().unsqueeze(-2) 205 | 206 | inbox_shape = self.action.get_shape() 207 | # add new box area 208 | self.boxes_area += (inbox_shape[:,0] * inbox_shape[:,1]).squeeze(-1) 209 | 210 | # FIFO packed!!! 211 | # print('packed_box 0 and action.get_shape 0:', packed_box[0], self.action.get_shape()[0]) 212 | # print('packed state:', self.packed_state[0]) 213 | self.packed_state_cache = push_to_tensor_alternative(self.packed_state_cache, self.packed_state[:,0:1,:]) 214 | self.packed_state = push_to_tensor_alternative(self.packed_state, packed_box) 215 | 216 | self.packed_cat = torch.cat((self.packed_state_cache, self.packed_state), dim=1) 217 | 218 | self.i += 1 219 | 220 | def _get_y_skyline(self, x): 221 | 222 | inbox_width = self.action.get_packed()[:,0] 223 | inbox_height = self.action.get_packed()[:,1].unsqueeze(-1) 224 | position_size = self.skyline.size(1) 225 | batch_size = self.skyline.size(0) 226 | 227 | in_left = torch.min(x.squeeze(-1), x.squeeze(-1) + inbox_width) 228 | in_right = torch.max(x.squeeze(-1), x.squeeze(-1) + inbox_width) 229 | 230 | # print(in_left, in_right) 231 | 232 | left_idx = ((in_left + 1.0) * (position_size/2)).floor().long().unsqueeze(-1) 233 | right_idx = ((in_right + 1.0) * (position_size/2)).floor().long().unsqueeze(-1) 234 | 235 | 236 | mask = torch.arange(0, position_size, device=self.device).repeat(batch_size,1) 237 | mask_left = mask.ge(left_idx) 238 | mask_right = mask.le(right_idx) 239 | mask = mask_left * mask_right 240 | masked_skyline = mask * self.skyline 241 | non_masked_skyline = (~mask) * self.skyline 242 | # print("mask skyline size:", masked_skyline.size(), non_masked_skyline.size(), masked_skyline[0]) 243 | max_y = torch.max(masked_skyline, 1)[0].unsqueeze(-1).float() 244 | # print("max y size:", max_y.size(), max_y[0]) 245 | # print("inbox height:", inbox_height.size(), inbox_height[0]) 246 | update_skyline = mask * (max_y + inbox_height) + non_masked_skyline 247 | # print("000skyline 0:", self.skyline[0], left_idx[0], right_idx[0]) 248 | self.skyline = update_skyline 249 | # print("111skyline 0",self.skyline[0]) 250 | 251 | return max_y 252 | 253 | 254 | 255 | def _get_y(self, x): 256 | 257 | inbox_width = self.action.get_packed()[:,0] 258 | 259 | in_left = torch.min(x.squeeze(-1), x.squeeze(-1) + inbox_width) # (batch) 260 | in_right = torch.max(x.squeeze(-1), x.squeeze(-1) + inbox_width) 261 | 262 | box_width = self.packed_cat[:, :, 0] 263 | box_height = self.packed_cat[:, :, 1] 264 | 265 | box_x = self.packed_cat[:, :, 2] # (batch, packed) 266 | box_y = self.packed_cat[:, :, 3] 267 | 268 | box_left = torch.min(box_x, box_x + box_width) # (batch, packed) 269 | box_right = torch.max(box_x, box_x + box_width) 270 | box_top = torch.max(box_y, box_y + box_height) 271 | 272 | in_left = in_left.unsqueeze(-1).repeat([1, self.packed_cat.size()[1]]) # (batch, packed) 273 | in_right = in_right.unsqueeze(-1).repeat([1, self.packed_cat.size()[1]]) 274 | 275 | # print(box_right.size(), in_left.size()) 276 | 277 | is_left = torch.gt(box_right, in_left) # box_right > in_left # (batch, packed) 278 | is_right = torch.lt(box_left, in_right) # box_left < in_right 279 | 280 | is_overlaped = is_left * is_right # element wise multiplication just logic &(and) 281 | # print("is_overlaped size", is_overlaped.size()) 282 | non_overlaped = ~is_overlaped 283 | 284 | overlap_box_top = box_top.masked_fill(non_overlaped, 0) # (batch, select) 285 | # print("overlap_box_top size", overlap_box_top.size()) 286 | 287 | max_y, _ = torch.max(overlap_box_top, -1, keepdim=True) # (batch, 1) 288 | 289 | return max_y 290 | 291 | def get_boundx(self): 292 | 293 | batch_size = self.packed_state.size()[0] 294 | 295 | right_b = torch.ones(batch_size, device=self.packed_state.device) - self.action.get_shape()[:,0] 296 | 297 | return right_b 298 | 299 | def get_height(self): 300 | 301 | box_height = self.packed_cat[:, :, 1] 302 | 303 | # (batch, packed) 304 | box_top = self.packed_cat[:, :, 3] + box_height 305 | 306 | heights, _ = torch.max(box_top, -1) 307 | 308 | # (batch) 309 | return heights 310 | 311 | def get_gap_size(self): 312 | 313 | bin_area = self.get_height() * 2.0 314 | 315 | gap_area = bin_area - self.boxes_area 316 | 317 | return gap_area 318 | 319 | def get_gap_ratio(self): 320 | 321 | # (batch) bin width is 2 322 | bin_area = self.get_height() * 2.0 323 | 324 | gap_ratio = self.get_gap_size() / bin_area 325 | 326 | return gap_ratio 327 | 328 | 329 | def get_graph(self): 330 | # we drop boxes of graph for rendering 331 | # drop_height = 6.0 332 | 333 | # graph = copy.deepcopy(self.packed_cat) 334 | 335 | # drop_graph = torch.ge(self.get_height(), drop_height) # (batch) 336 | 337 | # mask = torch.nonzero(drop_graph).squeeze(-1) # (mask) 338 | 339 | # if mask.size(0)!=0: 340 | # min_y, _ = torch.min(graph[mask, :, 3], dim=1) #(mask) 341 | # graph[mask, :, 3] -= min_y.unsqueeze(-1).repeat([1, graph.size()[1]]) 342 | 343 | # graph(width, height, x, y) 344 | return self.packed_cat 345 | 346 | -------------------------------------------------------------------------------- /problems/pack2d/viewer.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding: utf-8 3 | 4 | # In[ ]: 5 | 6 | 7 | #!/usr/bin/env python 8 | # coding: utf-8 9 | 10 | 11 | import pyglet 12 | from pyglet.gl import * 13 | from pyglet.image import ImagePattern, ImageData 14 | from pyglet.window import key 15 | import torch 16 | 17 | LEFT_SPACE = 100 18 | 19 | 20 | pyglet.resource.path = ['resources'] 21 | pyglet.resource.reindex() 22 | pyglet.resource.add_font('MONOFONT.TTF') 23 | 24 | BLACK = (0.2, 0.2, 0.2, 1.0) # color 0 25 | RED = (1.0, 0.5, 0.5, 1.0) # color 1 26 | GREEN = (0.5, 1.0, 0.5, 1.0) # color 2 27 | BLUE = (0.5, 0.5, 1.0, 1.0) # color 3 28 | YELLOW = (0.8, 0.8, 0.5, 1.0) # color 4 29 | GREY = (0.7, 0.7, 0.7, 1.0) # color 5 30 | DARK_GREY = (0.5, 0.5, 0.5, 1.0) # color 6 31 | WHITE = (0.9, 0.9, 0.9, 1.0) # color 7 32 | 33 | COLORS = [BLACK, RED, GREEN, BLUE, YELLOW, GREY, DARK_GREY, WHITE] 34 | 35 | def int_color(float_color): 36 | return tuple(map(lambda c: int(c*255), float_color)) 37 | 38 | 39 | def lighten_color(float_color): 40 | return tuple(map(lambda c: min(c*1.2, 1.0), float_color)) 41 | 42 | 43 | def darken_color(float_color): 44 | return tuple(map(lambda c: c*0.8, float_color)) 45 | 46 | class BlockImagePattern(ImagePattern): 47 | def __init__(self, color): 48 | self.color = int_color(color) 49 | self.frame_color = int_color(BLACK) 50 | 51 | def create_image(self, width, height): 52 | data = b'' 53 | for i in range(width * height): 54 | pos_x = i % width 55 | pos_y = i // width 56 | if pos_x < 2 or pos_x > width - 2: 57 | data += bytes(self.frame_color) 58 | elif pos_y < 2 or pos_y > height - 2: 59 | data += bytes(self.frame_color) 60 | else: 61 | data += bytes(self.color) 62 | return ImageData(width, height, 'RGBA', data) 63 | 64 | 65 | # In[6]: 66 | 67 | 68 | class Viewer(object): 69 | def __init__(self, width, height): 70 | self.width = width 71 | self.height = height 72 | self.bin_width = 100 73 | self.window = pyglet.window.Window(width, 74 | height, 75 | resizable=True, 76 | caption='Packing problem', 77 | visible=True) 78 | self.window.on_close = self.window_closed_by_user 79 | self.isopen = True 80 | self.scale = 1 81 | 82 | self.geoms = [] 83 | 84 | self.batch = pyglet.graphics.Batch() 85 | 86 | 87 | def set_scale(self, scale): 88 | self.scale = scale 89 | self.bin_width = scale*2 90 | 91 | def draw_background(self, width): 92 | 93 | 94 | pyglet.gl.glClearColor(*WHITE) 95 | 96 | 97 | color = lighten_color(BLACK) * 2 98 | 99 | self.batch.add(2, pyglet.gl.GL_LINES, None, 100 | ('v2f', (width+LEFT_SPACE, 0, 101 | width+LEFT_SPACE, self.height)), 102 | ('c4f', color)) 103 | 104 | self.batch.add(2, pyglet.gl.GL_LINES, None, 105 | ('v2f', (LEFT_SPACE, 0, 106 | LEFT_SPACE, self.height)), 107 | ('c4f', color)) 108 | 109 | def close(self): 110 | self.window.close() 111 | 112 | def window_closed_by_user(self): 113 | self.isopen = False 114 | 115 | def draw_text(self, height, gap_size, gap_ratio): 116 | 117 | text = "height: " + str(format(height.data.cpu().numpy(), '.3f')) 118 | height_label = pyglet.text.Label(text=text, color=(0,0,0,255), \ 119 | x=self.width*4/5, y=self.height-40, batch=self.batch) 120 | 121 | text = "gap_size: " + str(format(gap_size.data.cpu().numpy(), '.3f')) 122 | height_label = pyglet.text.Label(text=text, color=(0,0,0,255), \ 123 | x=self.width*4/5, y=self.height-60, batch=self.batch) 124 | 125 | text = "gap_ratio: " + str(format(gap_ratio.data.cpu().numpy(), '.3f')) 126 | height_label = pyglet.text.Label(text=text, color=(0,0,0,255), \ 127 | x=self.width*4/5, y=self.height-80, batch=self.batch) 128 | 129 | 130 | def draw_top_line(self, height): 131 | 132 | g_height = height * self.scale 133 | 134 | black_color = lighten_color(BLACK) * 2 135 | self.batch.add(2, pyglet.gl.GL_LINES, None, 136 | ('v2f', (LEFT_SPACE, g_height, 137 | LEFT_SPACE+self.bin_width, g_height)), 138 | ('c4f', black_color)) 139 | 140 | 141 | def add_geom(self, box, i): 142 | 143 | # (graph) 144 | 145 | # width must be positive, so we need rotate back 146 | rotate = box[0].lt(0) 147 | 148 | color = darken_color(BLUE) 149 | 150 | box_width = box[0] 151 | box_height = box[1] 152 | 153 | box_x = (box[2] + 1) # (0~2) 154 | box_y = box[3] 155 | 156 | box_image = pyglet.image.create(int(box_width*self.scale), int(box_height*self.scale), 157 | BlockImagePattern(color)) 158 | box_obj = pyglet.sprite.Sprite(box_image, LEFT_SPACE+box_x*self.scale, box_y*self.scale, batch=self.batch) 159 | box_obj.rotation = 0 160 | 161 | self.geoms.append(box_obj) 162 | 163 | center_x = (box_x + box[0] /2) * self.scale + LEFT_SPACE 164 | center_y = (box_y + box[1] / 2) * self.scale 165 | 166 | box_order_label = pyglet.text.Label(text=str(i), color=(0, 0, 0, 255), x=center_x, y=center_y, batch=self.batch) 167 | 168 | 169 | def render(self): 170 | self.window.clear() 171 | self.window.switch_to() 172 | self.window.dispatch_events() 173 | self.batch.draw() 174 | self.window.flip() 175 | # self.geoms = [] 176 | 177 | def __del__(self): 178 | self.close() 179 | 180 | 181 | 182 | 183 | -------------------------------------------------------------------------------- /problems/pack3d/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dongdongbh/RCQL/7d71dec03d1ac9ec12063c4f0d96ebf8e960f2e6/problems/pack3d/__init__.py -------------------------------------------------------------------------------- /problems/pack3d/load_br.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding: utf-8 3 | 4 | # In[269]: 5 | 6 | 7 | import os 8 | import numpy as np 9 | from itertools import cycle 10 | 11 | 12 | # In[260]: 13 | 14 | 15 | # h=(w_0*L+l_0*W)/(w_0+l_0) 16 | # w = w_0/W l=l_0/L 17 | def scale_inst_(inst, bin_size): 18 | inst[:,2] /= (inst[:,0]*bin_size[1] + inst[:,1]*bin_size[0])/(inst[:,0]+inst[:,1]) 19 | inst[:,0] /= bin_size[0] 20 | inst[:,1] /= bin_size[1] 21 | return inst 22 | 23 | 24 | # In[261]: 25 | 26 | 27 | def read_instance_(data_cycle): 28 | boxes = [] 29 | instance_num = np.asarray(list(map(int, next(data_cycle)))) 30 | bin_size = np.asarray(list(map(int, next(data_cycle)))) 31 | box_num = int(next(data_cycle)[0]) 32 | 33 | for _ in range(box_num): 34 | boxes.append(np.asarray(list(map(int, next(data_cycle))))) 35 | 36 | boxes = np.asarray(boxes) 37 | # print(boxes) 38 | 39 | box_sizes = boxes[:,[1,3,5]] 40 | box_num = boxes[:,-1] 41 | inst = np.repeat(box_sizes, box_num, axis=0) 42 | 43 | # we drop bin size! 44 | return inst, bin_size 45 | 46 | 47 | # In[270]: 48 | 49 | 50 | def read_ds(file): 51 | values = [] 52 | instances = [] 53 | bin_sizes = [] 54 | 55 | with open(file) as f: 56 | lines = f.readlines() 57 | for line in lines: 58 | values.append(line.split()) 59 | 60 | data_cycle = cycle(values) 61 | 62 | for inst in range(int(next(data_cycle)[0])): 63 | inst, bin_size = read_instance_(data_cycle) 64 | inst = inst.astype(float) 65 | bin_size = bin_size.astype(float) 66 | instances.append(inst) 67 | bin_sizes.append(bin_size) 68 | 69 | instances = [scale_inst_(inst, bin_size) for inst, bin_size in zip(instances, bin_sizes)] 70 | instances = np.vstack(instances) 71 | 72 | return instances 73 | 74 | 75 | # In[275]: 76 | 77 | 78 | def get_br_ds(path, graph_size=200, batch_size=32): 79 | 80 | insts1 = read_ds(os.path.join(path, "br1.txt")) 81 | insts2 = read_ds(os.path.join(path, "br2.txt")) 82 | insts3 = read_ds(os.path.join(path, "br3.txt")) 83 | insts4 = read_ds(os.path.join(path, "br4.txt")) 84 | insts5 = read_ds(os.path.join(path, "br5.txt")) 85 | insts6 = read_ds(os.path.join(path, "br6.txt")) 86 | 87 | insts8 = read_ds(os.path.join(path, "br8.txt")) 88 | insts9 = read_ds(os.path.join(path, "br9.txt")) 89 | 90 | training_ds = np.vstack([insts1,insts2,insts3,insts4,insts5,insts6,insts8,insts9]) 91 | test_ds = read_ds(os.path.join(path, "br7.txt")) 92 | 93 | divide_size = graph_size*batch_size 94 | 95 | b_n = training_ds.shape[0]//divide_size 96 | training_ds = training_ds[0:divide_size*b_n , :] 97 | 98 | b_n = test_ds.shape[0]//graph_size 99 | test_ds = test_ds[0:graph_size*b_n , :] 100 | 101 | return training_ds, test_ds 102 | 103 | 104 | 105 | 106 | 107 | 108 | -------------------------------------------------------------------------------- /problems/pack3d/pack3d.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding: utf-8 3 | 4 | from torch.utils.data import Dataset 5 | from problems.pack3d.load_br import get_br_ds 6 | import torch 7 | import os 8 | from random import randint 9 | from problems.pack3d.state_pack3d import StatePack3D 10 | 11 | from utils import generate_normal 12 | 13 | BIN_LENGTH = 50 14 | BIN_WIDTH = 50 15 | HALF_BIN_LENGTH = (BIN_LENGTH / 2) 16 | HALF_BIN_WIDTH = (BIN_WIDTH / 2) 17 | 18 | 19 | class Pack3DInit(Dataset): 20 | 21 | def __init__(self, block_size=20, batch_size=128, size_p1=0.4, size_p2=10.0, distribution='normal', **kargs): 22 | super(Pack3DInit, self).__init__() 23 | 24 | assert distribution is not None, "Data distribution must be specified for problem" 25 | 26 | if distribution == 'normal': 27 | self.data = [generate_normal(shape=(block_size, 3), mu=size_p1, sigma=size_p2, a=0.02, b=2.0) for i in range(batch_size)] 28 | elif distribution == 'br': 29 | training_ds, _ = get_br_ds('.') 30 | self.data = torch.from_numpy(training_ds) 31 | else: 32 | assert distribution == 'uniform' 33 | self.data = [torch.FloatTensor(block_size, 3).uniform_(size_p1, size_p2) for i in range(batch_size)] 34 | 35 | self.size = len(self.data) 36 | 37 | def __len__(self): 38 | return self.size 39 | 40 | def __getitem__(self, idx): 41 | return self.data[idx] 42 | 43 | class Pack3DUpdate(Dataset): 44 | 45 | def __init__(self, block_size=20, batch_size=128, block_num=10, size_p1=0.4, size_p2=10.0, distribution='normal', **kargs): 46 | super(Pack3DUpdate, self).__init__() 47 | 48 | assert distribution is not None, "Data distribution must be specified for problem" 49 | 50 | if distribution == 'br': 51 | training_ds, _ = get_br_ds('./problems/pack3d/br/') 52 | # n(1*3) n=batch_size*block_size*block_num 53 | training_ds = torch.split(torch.from_numpy(training_ds).float(), 1, 0) 54 | 55 | num_samples = batch_size*block_size*block_num 56 | 57 | assert num_samples < len(training_ds), "dataset size is too small for current setting!!!" 58 | 59 | test_num = len(training_ds)//num_samples 60 | test_id = randint(0, test_num-1) 61 | 62 | self.data = training_ds[test_id*num_samples: test_id*num_samples + num_samples] 63 | 64 | elif distribution == 'normal': 65 | self.data = [generate_normal(shape=(1, 3), mu=size_p1, sigma=size_p2, a=0.02, b=2.0) for i in range(batch_size*block_size*block_num)] 66 | else: 67 | assert distribution == 'uniform' 68 | self.data = [torch.FloatTensor(1, 3).uniform_(size_p1, size_p2) for i in range(batch_size*block_size*block_num)] 69 | 70 | self.size = len(self.data) 71 | 72 | def __len__(self): 73 | return self.size 74 | 75 | def __getitem__(self, idx): 76 | return self.data[idx] 77 | 78 | # one block zero data for last block feeding 79 | class ZeroDataeset(Dataset): 80 | def __init__(self, block_size=20, batch_size=128, **kargs): 81 | super(ZeroDataeset, self).__init__() 82 | 83 | self.data = [torch.zeros(1, 3) for i in range(batch_size*block_size)] 84 | 85 | self.size = len(self.data) 86 | 87 | def __len__(self): 88 | return self.size 89 | 90 | def __getitem__(self, idx): 91 | return self.data[idx] 92 | 93 | 94 | class Pack3D(object): 95 | NAME = 'pack3d' 96 | 97 | @staticmethod 98 | def make_dataset(*args, **kwargs): 99 | return Pack3DUpdate(*args, **kwargs) 100 | 101 | @staticmethod 102 | def make_state(*args, **kwargs): 103 | return StatePack3D(*args, **kwargs) 104 | 105 | # class Pack3dDataset(Dataset): 106 | # """docstring for Pack3dDataset""" 107 | 108 | # def __init__(self, size=20, num_samples=10000, offset=0, distribution=None): 109 | # super(Pack3dDataset, self).__init__() 110 | 111 | 112 | # # 8 max 1 113 | # # 20 max 0.4 114 | # # min_box_size = 0.1 * 10 / size 115 | # # max_box_size = 1.4 * 10 / size 116 | 117 | # min_box_size = 0.6 118 | # max_box_size = 1.2 119 | 120 | # self.data = [torch.FloatTensor(size, 3).uniform_(min_box_size, max_box_size) for i in range(num_samples)] 121 | 122 | # self.size = len(self.data) 123 | 124 | # def __len__(self): 125 | # return self.size 126 | 127 | # def __getitem__(self, idx): 128 | # return self.data[idx] 129 | 130 | 131 | 132 | -------------------------------------------------------------------------------- /problems/pack3d/render.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding: utf-8 3 | 4 | import time 5 | from problems.pack3d.viewer import Viewer 6 | 7 | 8 | 9 | def get_render_data(state): 10 | graphs = state.get_graph() 11 | gap_sizes = state.get_gap_size() 12 | gap_ratios = state.get_gap_ratio() 13 | heights = state.get_height() 14 | orders = state.get_order() 15 | 16 | return graphs, heights, gap_sizes, gap_ratios, orders 17 | 18 | 19 | def render(graph, height, gap_size, gap_ratio, sleep=0): 20 | width = 2 21 | height = 2 22 | length = 2 23 | 24 | viewer = Viewer(width, height, length) 25 | 26 | graph = graph.numpy() 27 | # print(graph) 28 | # for i in range(20): 29 | # row = graph[19-i] 30 | # # print(row) 31 | # viewer.add_geom(row) 32 | for row in graph: 33 | # print("row:", row) 34 | viewer.add_geom(row) 35 | 36 | 37 | # print(1) -------------------------------------------------------------------------------- /problems/pack3d/state_pack3d.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding: utf-8 3 | 4 | 5 | import torch 6 | from typing import NamedTuple 7 | import numpy as np 8 | from torch.utils.data import DataLoader 9 | from tqdm import tqdm 10 | 11 | 12 | class PackAction(): 13 | # (batch, 1) 14 | 15 | def __init__(self, batch_size, device): 16 | self.index = torch.zeros(batch_size, 1, device=device) 17 | self.x = torch.empty( 18 | batch_size, 1, device=device).fill_(-2) # set to -2 19 | self.y = torch.empty(batch_size, 1, device=device).fill_(-2) 20 | self.z = torch.empty(batch_size, 1, device=device).fill_(-2) 21 | self.rotate = torch.zeros(batch_size, device=device) 22 | self.updated_shape = torch.empty(batch_size, 3, device=device) 23 | ''' 24 | 0: no rotate 25 | 1: (x,y,z) -> (y,x,z) 26 | 2: (x,y,z) -> (y,z,x) 27 | 3: (x,y,z) -> (z,y,x) 28 | 4: (x,y,z) -> (z,x,y) 29 | 5: (x,y,z) -> (x,z,y) 30 | ''' 31 | 32 | def set_index(self, selected): 33 | self.index = selected 34 | 35 | def set_rotate(self, rotate): 36 | self.rotate = rotate 37 | 38 | def set_shape(self, length, width, height): 39 | # (batch, 3) 40 | self.updated_shape = torch.stack([length, width, height], dim=-1) 41 | 42 | def get_shape(self): 43 | return self.updated_shape 44 | 45 | def set_pos(self, x, y, z): 46 | self.x = x 47 | self.y = y 48 | self.z = z 49 | 50 | def get_packed(self): 51 | return torch.cat((self.updated_shape, self.x, self.y, self.z), dim=-1) 52 | 53 | def reset(self): 54 | self.__init__(self.index.size(0)) 55 | 56 | def __call__(self): 57 | return {'index': self.index, 58 | 'rotate': self.rotate, 59 | 'x': self.x, 60 | 'y': self.y} 61 | 62 | def __len__(self): 63 | return self.index.size(0) 64 | 65 | 66 | def push_to_tensor_alternative(tensor, x): 67 | return torch.cat((tensor[:, 1:, :], x), dim=1) 68 | 69 | 70 | class StatePack3D(): 71 | 72 | def __init__(self, batch_size, instance_size, block_size, device, position_size=128, online=False, cache_size=None): 73 | 74 | if online: 75 | self.boxes = torch.zeros(batch_size, 1, 3, device=device) 76 | else: 77 | self.boxes = torch.zeros(batch_size, block_size, 3, device=device) 78 | 79 | if cache_size is None: 80 | cache_size = block_size 81 | 82 | self.online = online 83 | self.instance_size = instance_size 84 | self.device = device 85 | self.i = 0 86 | 87 | # {length| width| height| x| y| z} 88 | self.packed_state = torch.zeros( 89 | batch_size, block_size, 6, dtype=torch.float, device=device) 90 | self.packed_state_cache = torch.zeros( 91 | batch_size, cache_size, 6, dtype=torch.float, device=device) 92 | self.packed_cat = torch.cat( 93 | (self.packed_state_cache, self.packed_state), dim=1) 94 | 95 | self.packed_rotate = torch.zeros( 96 | batch_size, block_size, 1, dtype=torch.int64, device=device) 97 | 98 | self.boxes_volume = torch.zeros( 99 | batch_size, dtype=torch.float, device=device) 100 | self.total_rewards = torch.zeros( 101 | batch_size, dtype=torch.float, device=device) 102 | self.skyline = torch.zeros( 103 | batch_size, position_size, position_size, dtype=torch.float, device=device) 104 | self.action = PackAction(batch_size, device=device) 105 | 106 | def put_reward(self, reward): 107 | self.total_rewards += reward 108 | 109 | def get_rewards(self): 110 | return self.total_rewards 111 | 112 | def get_mask(self): 113 | 114 | mask_array = torch.from_numpy( 115 | np.tril(np.ones(20), k=-1).astype('bool_')).to(self.packed_state.device) 116 | 117 | block_size = self.packed_state.size(1) 118 | 119 | remain_steps = self.instance_size + block_size - self.i 120 | assert remain_steps > 0, 'over packed!!!' 121 | if remain_steps // block_size == 0: 122 | mask_num = block_size - remain_steps 123 | else: 124 | mask_num = 0 125 | 126 | return mask_array[mask_num] 127 | 128 | def update_env(self, new_box): 129 | 130 | if self.online: 131 | self.boxes = new_box 132 | else: 133 | batch_size, block_size, box_state_size = self.boxes.size() 134 | 135 | all_index = torch.arange( 136 | block_size, device=self.boxes.device).repeat(batch_size, 1) 137 | 138 | mask = ( 139 | all_index != self.action.index).unsqueeze(-1).repeat(1, 1, box_state_size) 140 | # selected_box (batch, block-1, box_state_size) 141 | remaining_boxes = torch.masked_select( 142 | self.boxes, mask).view(batch_size, -1, box_state_size) 143 | 144 | self.boxes = torch.cat((new_box, remaining_boxes), dim=1) 145 | 146 | def update_select(self, selected): 147 | self.action.set_index(selected) 148 | box_length, box_width, box_height = self._get_action_box_shape() 149 | 150 | self.action.set_shape(box_length, box_width, box_height) 151 | 152 | def _get_action_box_shape(self): 153 | 154 | if self.online: 155 | box_length = self.boxes[:, :, 0].squeeze(-1).squeeze(-1) 156 | box_width = self.boxes[:, :, 1].squeeze(-1).squeeze(-1) 157 | box_height = self.boxes[:, :, 2].squeeze(-1).squeeze(-1) 158 | 159 | else: 160 | 161 | select_index = self.action.index.long() 162 | 163 | box_raw_l = self.boxes[:, :, 0].squeeze(-1) 164 | box_raw_w = self.boxes[:, :, 1].squeeze(-1) 165 | box_raw_h = self.boxes[:, :, 2].squeeze(-1) 166 | 167 | box_length = torch.gather(box_raw_l, -1, select_index).squeeze(-1) 168 | box_width = torch.gather(box_raw_w, -1, select_index).squeeze(-1) 169 | box_height = torch.gather(box_raw_h, -1, select_index).squeeze(-1) 170 | 171 | return box_length, box_width, box_height 172 | 173 | def update_rotate(self, rotate): 174 | 175 | self.action.set_rotate(rotate) 176 | 177 | # there are 5 rotations except the original one 178 | rotate_types = 5 179 | batch_size = rotate.size()[0] 180 | 181 | rotate_mask = torch.empty((rotate_types, batch_size), dtype=torch.bool) 182 | 183 | select_index = self.action.index.long() 184 | 185 | box_raw_x = self.boxes[:, :, 0].squeeze(-1) 186 | box_raw_y = self.boxes[:, :, 1].squeeze(-1) 187 | box_raw_z = self.boxes[:, :, 2].squeeze(-1) 188 | 189 | # (batch) get the original box shape 190 | box_length = torch.gather(box_raw_x, -1, select_index).squeeze(-1) 191 | box_width = torch.gather(box_raw_y, -1, select_index).squeeze(-1) 192 | box_height = torch.gather(box_raw_z, -1, select_index).squeeze(-1) 193 | 194 | for i in range(rotate_types): 195 | rotate_mask[i] = rotate.squeeze(-1).eq(i + 1) 196 | 197 | # rotate in 5 directions one by one 198 | # (x,y,z)->(y,x,z) 199 | # (x,y,z)->(y,z,x) 200 | # (x,y,z)->(z,y,x) 201 | # (x,y,z)->(z,x,y) 202 | # (x,y,z)->(x,z,y) 203 | for i in range(rotate_types): 204 | box_l_rotate = box_width 205 | box_w_rotate = box_length 206 | box_h_rotate = box_height 207 | 208 | box_l_rotate = torch.masked_select( 209 | box_l_rotate, rotate_mask[i]) 210 | box_w_rotate = torch.masked_select( 211 | box_w_rotate, rotate_mask[i]) 212 | box_h_rotate = torch.masked_select( 213 | box_h_rotate, rotate_mask[i]) 214 | 215 | inbox_length = box_length.masked_scatter( 216 | rotate_mask[i], box_l_rotate) 217 | inbox_width = box_width.masked_scatter( 218 | rotate_mask[i], box_w_rotate) 219 | inbox_height = box_height.masked_scatter( 220 | rotate_mask[i], box_h_rotate) 221 | 222 | self.packed_rotate[torch.arange(0, rotate.size( 223 | 0)), select_index.squeeze(-1), 0] = rotate.squeeze(-1) 224 | 225 | self.action.set_shape(inbox_length, inbox_width, inbox_height) 226 | 227 | def update_pack(self, x, y): 228 | batch_size = self.boxes.size(0) 229 | select_index = self.action.index.squeeze(-1).long() 230 | 231 | z = self._get_z_skyline(x, y) 232 | self.action.set_pos(x, y, z) 233 | 234 | packed_box = self.action.get_packed().unsqueeze(-2) 235 | inbox_shape = self.action.get_shape() 236 | 237 | self.boxes_volume += (inbox_shape[:, 0] * 238 | inbox_shape[:, 1] * inbox_shape[:, 2]).squeeze(-1) 239 | 240 | self.packed_state_cache = push_to_tensor_alternative( 241 | self.packed_state_cache, self.packed_state[:, 0:1, :]) 242 | self.packed_state = push_to_tensor_alternative( 243 | self.packed_state, packed_box) 244 | 245 | self.packed_cat = torch.cat( 246 | (self.packed_state_cache, self.packed_state), dim=1) 247 | 248 | self.i += 1 249 | 250 | def _get_z_skyline(self, x, y): 251 | 252 | inbox_length = self.action.get_packed()[:, 0] 253 | inbox_width = self.action.get_packed()[:, 1] 254 | inbox_height = self.action.get_packed( 255 | )[:, 2].unsqueeze(-1).unsqueeze(-1) 256 | 257 | position_size = self.skyline.size(1) 258 | batch_size = self.skyline.size(0) 259 | 260 | in_back = torch.min(x.squeeze(-1), x.squeeze(-1) + inbox_length) 261 | in_front = torch.max(x.squeeze(-1), x.squeeze(-1) + inbox_length) 262 | in_left = torch.min(y.squeeze(-1), y.squeeze(-1) + inbox_width) 263 | in_right = torch.max(y.squeeze(-1), y.squeeze(-1) + inbox_width) 264 | 265 | back_idx = ((in_back + 1.0) * (position_size / 2) 266 | ).floor().long().unsqueeze(-1) 267 | front_idx = ((in_front + 1.0) * (position_size / 2) 268 | ).floor().long().unsqueeze(-1) 269 | left_idx = ((in_left + 1.0) * (position_size / 2) 270 | ).floor().long().unsqueeze(-1) 271 | right_idx = ((in_right + 1.0) * (position_size / 2) 272 | ).floor().long().unsqueeze(-1) 273 | 274 | mask_x = torch.arange( 275 | 0, position_size, device=self.device).repeat(batch_size, 1) 276 | mask_y = torch.arange( 277 | 0, position_size, device=self.device).repeat(batch_size, 1) 278 | 279 | mask_back = mask_x.ge(back_idx) 280 | mask_front = mask_x.le(front_idx) 281 | mask_left = mask_y.ge(left_idx) 282 | mask_right = mask_y.le(right_idx) 283 | 284 | mask_x = mask_back * mask_front 285 | mask_y = mask_left * mask_right 286 | 287 | mask_x = mask_x.view(batch_size, position_size, 288 | 1).float() # (batch, pos_size, 1) 289 | # (batch, 1, pos_size) 290 | mask_y = mask_y.view(batch_size, 1, position_size).float() 291 | 292 | mask = torch.matmul(mask_x, mask_y) # (batch, pos_size, pos_size) 293 | 294 | masked_skyline = mask * self.skyline 295 | non_masked_skyline = (1 - mask) * self.skyline 296 | 297 | max_z = torch.max(masked_skyline.view( 298 | batch_size, -1), 1)[0].unsqueeze(-1).float() 299 | 300 | update_skyline = mask * \ 301 | (max_z.unsqueeze(-1) + inbox_height) + non_masked_skyline 302 | 303 | self.skyline = update_skyline 304 | 305 | return max_z 306 | 307 | def _get_z(self, x, y): 308 | 309 | inbox_length = self.action.get_packed()[:, 0] 310 | inbox_width = self.action.get_packed()[:, 1] 311 | 312 | in_back = torch.min(x.squeeze(-1), x.squeeze(-1) + inbox_length) 313 | in_front = torch.max(x.squeeze(-1), x.squeeze(-1) + inbox_length) 314 | in_left = torch.min(y.squeeze(-1), y.squeeze(-1) + inbox_width) 315 | in_right = torch.max(y.squeeze(-1), y.squeeze(-1) + inbox_width) 316 | 317 | box_length = self.packed_cat[:, :, 0] 318 | box_width = self.packed_cat[:, :, 1] 319 | box_height = self.packed_cat[:, :, 2] 320 | 321 | box_x = self.packed_cat[:, :, 3] 322 | box_y = self.packed_cat[:, :, 4] 323 | box_z = self.packed_cat[:, :, 5] 324 | 325 | box_back = torch.min(box_x, box_x + box_length) 326 | box_front = torch.max(box_x, box_x + box_length) 327 | box_left = torch.min(box_y, box_y + box_width) 328 | box_right = torch.max(box_y, box_y + box_width) 329 | box_top = torch.max(box_z, box_z + box_height) 330 | 331 | in_back = in_back.unsqueeze(-1).repeat([1, self.packed_cat.size()[1]]) 332 | in_front = in_front.unsqueeze(-1).repeat( 333 | [1, self.packed_cat.size()[1]]) 334 | in_left = in_left.unsqueeze(-1).repeat([1, self.packed_cat.size()[1]]) 335 | in_right = in_right.unsqueeze(-1).repeat( 336 | [1, self.packed_cat.size()[1]]) 337 | 338 | is_back = torch.gt(box_front, in_back) 339 | is_front = torch.lt(box_back, in_front) 340 | is_left = torch.gt(box_right, in_left) 341 | is_right = torch.lt(box_left, in_right) 342 | 343 | is_overlaped = is_back * is_front * is_left * is_right 344 | non_overlaped = ~is_overlaped 345 | 346 | overlap_box_top = box_top.masked_fill(non_overlaped, 0) 347 | 348 | max_z, _ = torch.max(overlap_box_top, -1, keepdim=True) 349 | 350 | return max_z 351 | 352 | def get_boundx(self): 353 | batch_size = self.packed_state.size()[0] 354 | front_bound = torch.ones( 355 | batch_size, device=self.packed_state.device) - self.action.get_shape()[:, 0] 356 | 357 | return front_bound 358 | 359 | def get_boundy(self): 360 | 361 | batch_size = self.packed_state.size()[0] 362 | right_bound = torch.ones( 363 | batch_size, device=self.packed_state.device) - self.action.get_shape()[:, 1] 364 | 365 | return right_bound 366 | 367 | def get_height(self): 368 | box_height = self.packed_cat[:, :, 2] 369 | 370 | box_top = self.packed_cat[:, :, 5] + box_height 371 | 372 | heights, _ = torch.max(box_top, -1) 373 | 374 | return heights 375 | 376 | def get_gap_size(self): 377 | 378 | bin_volumn = self.get_height() * 4.0 379 | 380 | gap_volumn = bin_volumn - self.boxes_volume 381 | 382 | return gap_volumn 383 | 384 | def get_gap_ratio(self): 385 | 386 | bin_volumn = self.get_height() * 4.0 387 | 388 | gap_ratio = self.get_gap_size() / bin_volumn 389 | 390 | return gap_ratio 391 | 392 | def get_graph(self): 393 | return self.packed_cat 394 | -------------------------------------------------------------------------------- /problems/pack3d/viewer.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding: utf-8 3 | 4 | 5 | from __future__ import division 6 | from vpython import * 7 | # from visual import * 8 | import torch 9 | import random 10 | import numpy as np 11 | 12 | scene.caption = """Right button drag or Ctrl-drag to rotate "camera" to view scene. 13 | To zoom, drag with middle button or Alt/Option depressed, or use scroll wheel. 14 | On a two-button mouse, middle is left + right. 15 | Shift-drag to pan left/right and up/down. 16 | Touch screen: pinch/extend to zoom, swipe or two-finger rotate.""" 17 | scene.width = 800 18 | scene.height = 800 19 | 20 | BLACK = vector(0, 0, 0) # color 0 21 | BLUE = vector(0, 0, 1) # color 1 22 | GREEN = vector(0, 1, 0) # color 2 23 | CYAN = vector(0, 1, 1) # color 3 24 | RED = vector(1, 0, 0) # color 4 25 | MAGENTA = vector(1, 0, 1) # color 5 26 | YELLOW = vector(1, 1, 0) # color 6 27 | WHITE = vector(1, 1, 1) # color 7 28 | GRAY = vector(0.5, 0.5, 0.5) # color 8 29 | 30 | COLORS = [GRAY, BLUE, GREEN, CYAN, RED, MAGENTA, YELLOW, WHITE, BLACK] 31 | 32 | 33 | 34 | class Viewer(object): 35 | def __init__(self, width, height, length): 36 | self.width= width 37 | self.height = height 38 | self.length = length 39 | self.thick = 0.001 40 | wallR = box(pos=vector(self.width/2, self.height/2-self.height/2, 0), size=vector(self.thick, self.height, self.length), color=COLORS[7], opacity=0.2) 41 | wallL = box(pos=vector(-self.width/2, self.height/2-self.height/2, 0), size=vector(-self.thick, self.height, self.length), color=COLORS[7], opacity=0.2) 42 | wallF = box(pos=vector(0, self.height/2-self.height/2, self.length/2), size=vector(self.width, self.height, self.thick), color=COLORS[7], opacity=0.2) 43 | wallB = box(pos=vector(0, self.height/2-self.height/2, -self.length/2), size=vector(self.width, self.height, -self.thick), color=COLORS[7], opacity=0.2) 44 | wallD = box(pos=vector(0, 0-self.height/2, 0), size=vector(self.width, self.thick, self.length), color=COLORS[7], opacity=0.4) 45 | 46 | 47 | def add_geom(self, boxs): 48 | # color_index = random.randint(1,6) 49 | boxs = boxs.tolist() 50 | # print(boxs[6]) 51 | pos = vector(boxs[4] + boxs[1]/2, boxs[5] + boxs[2]/2 - self.height/2, boxs[3] + boxs[0]/2) 52 | size = vector(boxs[1], boxs[2], boxs[0]) 53 | # print(size) 54 | # color_index = int(boxs[6]) 55 | 56 | 57 | add_box = box(pos=pos, size=size, color=random.choice(COLORS), opacity=0.9) 58 | 59 | 60 | 61 | # viewer = Viewer(2, 2, 2) 62 | 63 | -------------------------------------------------------------------------------- /resources/MONOFONT.TTF: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dongdongbh/RCQL/7d71dec03d1ac9ec12063c4f0d96ebf8e960f2e6/resources/MONOFONT.TTF -------------------------------------------------------------------------------- /trainer.py: -------------------------------------------------------------------------------- 1 | import math 2 | import time 3 | import random 4 | 5 | from tqdm import tqdm 6 | 7 | import torch 8 | import torch.nn.functional as F 9 | from torch.utils.data import DataLoader 10 | 11 | 12 | 13 | 14 | from utils import ( 15 | move_to, 16 | load_problem, 17 | logger, 18 | explained_variance) 19 | 20 | from problems.pack2d.render import render 21 | from pack_step import pack_step 22 | 23 | 24 | def train_epoch( 25 | modules, 26 | optimizer, 27 | scheduler, 28 | problem_params, 29 | device, 30 | target_entropy, 31 | batch_size, 32 | block_size, 33 | hidden_size, 34 | **kargs): 35 | 36 | 37 | problem = load_problem(problem_params['problem_type']) 38 | 39 | dataupdate = problem.make_dataset(block_size=block_size, batch_size=batch_size, online=problem_params['on_line'], **problem_params) 40 | 41 | update_dataloader = DataLoader(dataupdate, batch_size=batch_size) 42 | 43 | # only encoder requires cache 44 | h_cache_a = [ 45 | torch.zeros( 46 | batch_size, 47 | layer.attn.attn.get_cache_size(), 48 | hidden_size).to(device) 49 | for layer in modules['actor']['encoder'].module.layers] 50 | 51 | h_cache_c = [ 52 | torch.zeros( 53 | batch_size, 54 | layer.attn.attn.get_cache_size(), 55 | hidden_size).to(device) 56 | for layer in modules['critic'].module.att_encoder.layers] 57 | 58 | h_caches = [h_cache_a, h_cache_c] 59 | 60 | updatedata_iterator = iter(update_dataloader) 61 | instance_size = len(updatedata_iterator) 62 | 63 | state = problem.make_state( 64 | batch_size, instance_size, block_size, device, online=problem_params['on_line']) 65 | 66 | 67 | state, values, returns, losses, entropy, grad_norms = train_instance( 68 | modules, 69 | optimizer, 70 | scheduler, 71 | h_caches, 72 | block_size, 73 | state, 74 | updatedata_iterator, 75 | target_entropy, 76 | problem_params, 77 | **kargs) 78 | 79 | # For TensorboardX logs 80 | alpha_tlogs = modules['critic'].module.log_alpha.clone() 81 | 82 | return state, values, returns, losses, entropy, grad_norms, alpha_tlogs 83 | 84 | 85 | 86 | 87 | 88 | def train_instance( 89 | modules, 90 | optimizer, 91 | scheduler, 92 | h_caches, 93 | block_size, 94 | state, 95 | updatedata_iterator, 96 | target_entropy, 97 | problem_params, 98 | **kargs): 99 | 100 | device = state.packed_state.device 101 | total_losses = torch.tensor([0,0,0,0], dtype=torch.float, device=device) 102 | total_entropy = torch.zeros(1, dtype=torch.float, device=device) 103 | 104 | # if box number is not inter times of nsteps, then we drop last several data 105 | update_mb_number = int(len(updatedata_iterator) // kargs['nsteps']) 106 | 107 | # pack the last block 108 | if not state.online: 109 | for mb_id in tqdm(range(update_mb_number), disable=True): 110 | state, entropy, values, returns, losses, grad_norms = train_minibatch( 111 | modules, 112 | optimizer, 113 | scheduler, 114 | h_caches, 115 | state, 116 | updatedata_iterator, 117 | target_entropy, 118 | device, 119 | problem_params, 120 | **kargs) 121 | 122 | total_entropy += entropy 123 | total_losses += losses 124 | 125 | # Here we want to pack the last block 126 | last_mb_number = block_size // kargs['nsteps'] 127 | 128 | for mb_id in tqdm(range(last_mb_number), disable=True): 129 | state, entropy, values, returns, losses, grad_norms = train_minibatch( 130 | modules, 131 | optimizer, 132 | scheduler, 133 | h_caches, 134 | state, 135 | None, 136 | target_entropy, 137 | device, 138 | problem_params, 139 | **kargs) 140 | 141 | total_entropy += entropy 142 | total_losses += losses 143 | update_mb_number += last_mb_number 144 | 145 | else: 146 | for mb_id in tqdm(range(update_mb_number), disable=True): 147 | 148 | state, entropy, values, returns, losses, grad_norms = train_minibatch( 149 | modules, 150 | optimizer, 151 | scheduler, 152 | h_caches, 153 | state, 154 | updatedata_iterator, 155 | target_entropy, 156 | device, 157 | problem_params, 158 | **kargs) 159 | 160 | total_entropy += entropy 161 | total_losses += losses 162 | 163 | 164 | average_entropy = total_entropy / update_mb_number 165 | average_losses = total_losses / update_mb_number 166 | 167 | return state, values, returns, average_losses, average_entropy, grad_norms 168 | 169 | 170 | 171 | def train_minibatch(modules, 172 | optimizer, 173 | scheduler, 174 | h_caches, 175 | state, 176 | data_iterator, 177 | target_entropy, 178 | device, 179 | problem_params, 180 | ent_coef, 181 | full_eval_mode, 182 | **kargs): 183 | 184 | h_caches, returns, advs, values, log_likelihoods, entropys = get_mb_data( 185 | modules, h_caches, state, data_iterator, device, problem_params, **kargs) 186 | 187 | alpha_loss = -1 * torch.mv((-1*entropys + target_entropy).detach(), modules['critic'].module.log_alpha).mean() 188 | 189 | # entropy_loss = -1 * ent_coef * entropys.mean() 190 | 191 | value_loss = F.mse_loss(values, returns.detach()) 192 | 193 | advs = returns - values 194 | # Normalize the advantages 195 | advs = (advs - advs.mean()) / (advs.std() + 1e-8) 196 | 197 | # do not backward critic in actor 198 | advantages = advs.detach() 199 | 200 | # Calculate loss (gradient ascent) loss=A*log 201 | actor_loss = -1 * (advantages * log_likelihoods).mean() 202 | 203 | loss = actor_loss + value_loss + alpha_loss 204 | 205 | if not full_eval_mode: 206 | # Perform backward pass and optimization step 207 | optimizer.zero_grad() 208 | loss.backward() 209 | 210 | # Clip gradient norms and get (clipped) gradient norms for logging 211 | grad_norms = clip_grad_norms(optimizer.param_groups, kargs['grad_clip']) 212 | optimizer.step() 213 | 214 | if scheduler is not None: 215 | scheduler.step() 216 | else: 217 | grad_norms = None 218 | 219 | losses = torch.tensor([actor_loss, value_loss, alpha_loss, loss], device=device) 220 | 221 | return state, entropys.mean(), values, returns, losses, grad_norms 222 | 223 | 224 | 225 | def get_mb_data(modules, h_caches, state, updatedata_iterator, device, problem_params, gamma, nsteps, lam, soft_temp, **kargs): 226 | 227 | def sf01(arr): 228 | """ 229 | swap and then flatten axes 0 and 1 230 | """ 231 | s = arr.size() 232 | return arr.transpose(0, 1).reshape(s[0] * s[1], *s[2:]) 233 | 234 | mb_rewards, mb_values, mb_log_likelihoods, mb_entropy = [],[],[],[] 235 | 236 | for _ in range(nsteps): 237 | # pack last block 238 | if updatedata_iterator is None: 239 | if problem_params['problem_type'] == 'pack2d': 240 | batch = torch.zeros(state.packed_state.size(0), 1, 2, device=device) 241 | else: 242 | batch = torch.zeros(state.packed_state.size(0), 1, 3, device=device) 243 | 244 | else: 245 | try: 246 | batch = next(updatedata_iterator) 247 | batch = move_to(batch, device) 248 | 249 | # print(batch[0]) 250 | except StopIteration: 251 | print("-------------------------------------------------------------") 252 | print("No more data in the instance!!!") 253 | print("-------------------------------------------------------------") 254 | # print('batch 0: ',batch[0]) 255 | 256 | 257 | ll, entropy, value, h_caches, reward = _run_batch(modules, h_caches, state, batch, problem_params, soft_temp) 258 | mb_values.append(value) 259 | mb_log_likelihoods.append(ll) 260 | mb_rewards.append(reward) 261 | mb_entropy.append(entropy) 262 | 263 | # batch of steps to batch of roll-outs (nstep, batch) 264 | mb_rewards = torch.stack(mb_rewards) 265 | mb_values = torch.stack(mb_values) 266 | mb_log_likelihoods = torch.stack(mb_log_likelihoods) 267 | # (nstep, batch, 3) 268 | mb_entropys = torch.stack(mb_entropy) 269 | 270 | 271 | last_values, _ = modules['critic'](state.boxes, state.packed_state, h_caches[1]) 272 | last_values = last_values.squeeze(-1) 273 | 274 | mb_returns = torch.zeros_like(mb_rewards) 275 | mb_advs = torch.zeros_like(mb_rewards) 276 | 277 | lastgaelam = 0 278 | 279 | for t in reversed(range(nsteps)): 280 | if t == nsteps - 1: 281 | nextvalues = last_values 282 | else: 283 | nextvalues = mb_values[t+1] 284 | 285 | delta = mb_rewards[t] + gamma * nextvalues - mb_values[t] 286 | mb_advs[t] = lastgaelam = delta + gamma * lam * lastgaelam 287 | 288 | # use return to supervise critic 289 | mb_returns = mb_advs + mb_values 290 | 291 | # (batch * nstep, ) 292 | returns, advs, values, log_likelihoods, entropys = map(sf01, \ 293 | (mb_returns, mb_advs, mb_values, mb_log_likelihoods, mb_entropys)) 294 | 295 | 296 | return h_caches, returns, advs, values, log_likelihoods, entropys 297 | 298 | 299 | def _run_batch(modules, h_caches, state, batch, problem_params, soft_temp): 300 | 301 | # update pack candidates for next packing step 302 | state.update_env(batch) 303 | 304 | last_gap = state.get_gap_size() 305 | 306 | if problem_params['problem_type'] == 'pack2d': 307 | 308 | s_log_p, r_log_p, x_log_p, value, h_caches = pack_step(modules, state, h_caches, problem_params) 309 | 310 | actions = state.action() 311 | # position to discrete 312 | actions['x'] = ((actions['x'] + 1.0) * (x_log_p.size(1)/2)).round().long() 313 | 314 | # ll (batch, 1), entropy (batch) 315 | ll = _calc_log_likelihood(actions, s_log_p, r_log_p, x_log_p, state.online) 316 | entropys = _calc_entropy(s_log_p, r_log_p, x_log_p, state.online) 317 | 318 | elif problem_params['problem_type'] == 'pack3d': 319 | 320 | s_log_p, r_log_p, x_log_p, y_log_p, value, h_caches = pack_step(modules, state, h_caches, problem_params) 321 | 322 | actions = state.action() 323 | 324 | actions['x'] = ((actions['x'] + 1.0) * (x_log_p.size(1)/2)).round().long() 325 | actions['y'] = ((actions['y'] + 1.0) * (y_log_p.size(1)/2)).round().long() 326 | 327 | ll = _calc_log_likelihood_3d(actions, s_log_p, r_log_p, x_log_p, y_log_p, state.online) 328 | entropys = _calc_entropy_3d(s_log_p, r_log_p, x_log_p, y_log_p, state.online) 329 | 330 | # # ll (batch, 1), entropy (batch, 3) 331 | # ll = _calc_log_likelihood(actions, s_log_p, r_log_p, x_log_p, state.online) 332 | # entropys = _calc_entropy(s_log_p, r_log_p, x_log_p, state.online) 333 | 334 | # (batch) 335 | new_gap = state.get_gap_size() 336 | reward = last_gap - new_gap 337 | 338 | alpha = torch.exp(modules['critic'].module.log_alpha) 339 | 340 | 341 | # print('entropy: ', entropys.size(), "alpha: ", alpha.size()) 342 | reward += torch.mv(entropys, alpha).detach() 343 | 344 | state.put_reward(reward) 345 | 346 | return ll, entropys, value, h_caches, reward 347 | 348 | 349 | def _calc_entropy(s_log, r_log, x_log, online): 350 | # log (batch, action_num) 351 | 352 | # S=-/sum_i (p_i \ln p_i) 353 | if online: 354 | s_entropy = torch.zeros(r_log.size(0), device=r_log.device) 355 | else: 356 | s_entropy = -1 * (s_log.exp() *s_log).sum(dim=-1) 357 | 358 | r_entropy = -1 * (r_log.exp() *r_log).sum(dim=-1) 359 | x_entropy = -1 * (x_log.exp() * x_log).sum(dim=-1) 360 | # print(r_entropy, r_entropy.size()) 361 | # (batch, 3) 362 | entropys = torch.stack([s_entropy, r_entropy, x_entropy], dim=-1) 363 | # entropy = x_entropy 364 | 365 | assert not torch.isnan(entropys).any() 366 | 367 | return entropys 368 | 369 | def _calc_entropy_3d(s_log, r_log, x_log, y_log, online): 370 | # log (batch, action_num) 371 | 372 | # S=-/sum_i (p_i \ln p_i) 373 | if online: 374 | s_entropy = torch.zeros(r_log.size(0), device=r_log.device) 375 | else: 376 | s_entropy = -1 * (s_log.exp() *s_log).sum(dim=-1) 377 | 378 | r_entropy = -1 * (r_log.exp() * r_log).sum(dim=-1) 379 | x_entropy = -1 * (x_log.exp() * x_log).sum(dim=-1) 380 | y_entropy = -1 * (y_log.exp() * y_log).sum(dim=-1) 381 | 382 | # (batch) 383 | # entropy = s_entropy + r_entropy + x_entropy + y_entropy 384 | entropys = torch.stack([s_entropy, r_entropy, x_entropy, y_entropy], dim=-1) 385 | 386 | # entropy = x_entropy 387 | 388 | assert not torch.isnan(entropys).any() 389 | 390 | return entropys 391 | 392 | def _calc_log_likelihood(actions, s_log, r_log, x_log, online): 393 | 394 | # actions (batch, 4) 395 | # log (batch, action_num) 396 | 397 | # (batch, 1) 398 | 399 | action_r = actions['rotate'] 400 | action_x = actions['x'] 401 | 402 | #(batch) 403 | if online: 404 | s_log_p = 0 405 | else: 406 | action_s = actions['index'] 407 | s_log_p = s_log.gather(1, action_s).squeeze(-1) 408 | assert (s_log_p > -1000).data.all(), "log probability should not -inf, check sampling" 409 | 410 | r_log_p = r_log.gather(1, action_r).squeeze(-1) 411 | x_log_p = x_log.gather(1, action_x).squeeze(-1) 412 | 413 | 414 | assert (r_log_p > -1000).data.all(), "log probability should not -inf, check sampling" 415 | assert (x_log_p > -1000).data.all(), "log probability should not -inf, check sampling" 416 | 417 | log_likelihood = s_log_p+ r_log_p + x_log_p 418 | 419 | # print(s_log_p.mean(), r_log_p.mean(), x_log_p.mean()) 420 | 421 | return log_likelihood 422 | 423 | def _calc_log_likelihood_3d(actions, s_log, r_log, x_log, y_log, online): 424 | 425 | # actions (batch, 4) 426 | # log (batch, action_num) 427 | 428 | # (batch, 1) 429 | 430 | action_r = actions['rotate'] 431 | action_x = actions['x'] 432 | action_y = actions['y'] 433 | #(batch) 434 | if online: 435 | s_log_p = 0 436 | else: 437 | action_s = actions['index'] 438 | s_log_p = s_log.gather(1, action_s).squeeze(-1) 439 | assert (s_log_p > -1000).data.all(), "log probability should not -inf, check sampling" 440 | 441 | r_log_p = r_log.gather(1, action_r).squeeze(-1) 442 | x_log_p = x_log.gather(1, action_x).squeeze(-1) 443 | y_log_p = y_log.gather(1, action_y).squeeze(-1) 444 | 445 | 446 | assert (r_log_p > -1000).data.all(), "log probability should not -inf, check sampling" 447 | assert (x_log_p > -1000).data.all(), "log probability should not -inf, check sampling" 448 | assert (y_log_p > -1000).data.all(), "log probability should not -inf, check sampling" 449 | 450 | 451 | log_likelihood = s_log_p+ r_log_p + x_log_p + y_log_p 452 | 453 | # print(s_log_p.mean(), r_log_p.mean(), x_log_p.mean()) 454 | 455 | return log_likelihood 456 | 457 | 458 | def clip_grad_norms(param_groups, max_norm=math.inf): 459 | """ 460 | Clips the norms for all param groups to max_norm and returns gradient norms before clipping 461 | :param optimizer: 462 | :param max_norm: 463 | :param gradient_norms_log: 464 | :return: grad_norms, clipped_grad_norms: list with (clipped) gradient norms per group 465 | """ 466 | grad_norms = [ 467 | torch.nn.utils.clip_grad_norm_( 468 | group['params'], 469 | max_norm if max_norm > 0 else math.inf, # Inf so no clipping but still call to calc 470 | norm_type=2 471 | ) 472 | for group in param_groups 473 | ] 474 | grad_norms_clipped = [min(g_norm, max_norm) for g_norm in grad_norms] if max_norm > 0 else grad_norms 475 | return grad_norms, grad_norms_clipped 476 | 477 | 478 | 479 | # do full evaluation 480 | def full_eval( 481 | modules, 482 | optimizer, 483 | scheduler, 484 | h_caches, 485 | problem, 486 | init_dataloader, 487 | update_dataloader, 488 | device, 489 | ent_coef, 490 | **kargs): 491 | 492 | modules.eval() 493 | 494 | return train_epoch(modules, optimizer, scheduler, problem_params, device, 495 | **model_params, **trainer_params, **optim_params, **rl_params) 496 | 497 | 498 | 499 | def epoch_logger(epoch, state, values, returns, losses, entropy, grad_norms, log_alpha, optimizer, tb_writer, log_interval, run_name): 500 | 501 | total_gap = state.get_gap_size() 502 | gap_ratio = state.get_gap_ratio() 503 | rewards = state.get_rewards() 504 | 505 | avg_gap = total_gap.mean().item() 506 | avg_gap_ratio = gap_ratio.mean().item() 507 | var_gap_ratio = gap_ratio.var().item() 508 | avg_rewards = rewards.mean().item() 509 | min_gap = torch.min(gap_ratio) 510 | max_gap = torch.max(gap_ratio) 511 | 512 | grad_norms, grad_norms_clipped = grad_norms 513 | 514 | ev = explained_variance(values.detach().cpu().numpy(), returns.detach().cpu().numpy()) 515 | 516 | # Log values to screen 517 | if epoch % log_interval == 0: 518 | print('\nepoch: {}, run {}, avg_rewards: {}, gap_ratio: {}, var_gap_ratio: {}, ev: {}, loss: {}'.\ 519 | format(epoch, run_name, avg_rewards, avg_gap_ratio, var_gap_ratio, ev, losses[3])) 520 | print('min gap ratio: {}, max gap ratio: {}'.format(min_gap, max_gap)) 521 | print('grad_norm: {}, clipped: {}'.format(grad_norms[0], grad_norms_clipped[0])) 522 | print('grad_norm_c: {}, clipped_c: {}'.format(grad_norms[1], grad_norms_clipped[1])) 523 | 524 | logger.logkv("epoch", epoch) 525 | logger.logkv("explained_variance", float(ev)) 526 | logger.logkv('entropy', entropy.item()) 527 | 528 | logger.logkv('actor_loss', losses[0].item()) 529 | logger.logkv('value_loss', losses[1].item()) 530 | logger.logkv('alpha_loss', losses[2].item()) 531 | logger.logkv('avg_rewards', avg_rewards) 532 | logger.logkv('gap_ratio', avg_gap_ratio) 533 | logger.logkv('var_gap_ratio', var_gap_ratio) 534 | 535 | logger.dumpkvs() 536 | 537 | # Log values to tensorboard 538 | if tb_writer is not None: 539 | tb_writer.add_scalar('avg_rewards', avg_rewards, epoch) 540 | tb_writer.add_scalar('entropy', entropy, epoch) 541 | tb_writer.add_scalar('s_log_alpha', log_alpha[0].item(), epoch) 542 | tb_writer.add_scalar('r_log_alpha', log_alpha[1].item(), epoch) 543 | tb_writer.add_scalar('p_log_alpha', log_alpha[2].item(), epoch) 544 | tb_writer.add_scalar('gap_ratio', avg_gap_ratio, epoch) 545 | tb_writer.add_scalar('var_gap_ratio', var_gap_ratio, epoch) 546 | 547 | tb_writer.add_scalar('min_gap', min_gap, epoch) 548 | tb_writer.add_scalar('explained_variance', float(ev), epoch) 549 | 550 | 551 | 552 | tb_writer.add_scalar('actor_loss', losses[0].item(), epoch) 553 | tb_writer.add_scalar('value_loss', losses[1].item(), epoch) 554 | tb_writer.add_scalar('alpha_loss', losses[2].item(), epoch) 555 | 556 | tb_writer.add_scalar('grad_norm', grad_norms[0], epoch) 557 | tb_writer.add_scalar('grad_norm_c', grad_norms[1], epoch) 558 | 559 | tb_writer.add_scalar('learnrate_pg0', optimizer.param_groups[0]['lr'], epoch) 560 | tb_writer.add_scalar('learnrate_pg1', optimizer.param_groups[1]['lr'], epoch) 561 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .functions import * 2 | from .utils_fb import * 3 | from .math_util import * 4 | from .truncated_normal import * -------------------------------------------------------------------------------- /utils/functions.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding: utf-8 3 | 4 | 5 | 6 | 7 | import warnings 8 | 9 | import torch 10 | import numpy as np 11 | import os 12 | import json 13 | from tqdm import tqdm 14 | from multiprocessing.dummy import Pool as ThreadPool 15 | from multiprocessing import Pool 16 | import torch.nn.functional as F 17 | 18 | 19 | def load_problem(name): 20 | from problems import Pack2D, Pack3D 21 | problem = { 22 | 'pack2d': Pack2D, 23 | 'pack3d': Pack3D, 24 | }.get(name, None) 25 | assert problem is not None, "Currently unsupported problem: {}!".format(name) 26 | return problem 27 | 28 | 29 | def torch_load_cpu(load_path): 30 | return torch.load(load_path, map_location=lambda storage, loc: storage) # Load on CPU 31 | 32 | 33 | def move_to(var, device): 34 | if isinstance(var, dict): 35 | return {k: move_to(v, device) for k, v in var.items()} 36 | return var.to(device) 37 | 38 | 39 | def _load_model_file(load_path, model): 40 | """Loads the model with parameters from the file and returns optimizer state dict if it is in the file""" 41 | 42 | # Load the model parameters from a saved state 43 | load_optimizer_state_dict = None 44 | print(' [*] Loading model from {}'.format(load_path)) 45 | 46 | load_data = torch.load( 47 | os.path.join( 48 | os.getcwd(), 49 | load_path 50 | ), map_location=lambda storage, loc: storage) 51 | 52 | if isinstance(load_data, dict): 53 | load_optimizer_state_dict = load_data.get('optimizer', None) 54 | load_model_state_dict = load_data.get('model', load_data) 55 | else: 56 | load_model_state_dict = load_data.state_dict() 57 | 58 | state_dict = model.state_dict() 59 | 60 | state_dict.update(load_model_state_dict) 61 | 62 | model.load_state_dict(state_dict) 63 | 64 | return model, load_optimizer_state_dict 65 | 66 | 67 | def load_args(filename): 68 | with open(filename, 'r') as f: 69 | args = json.load(f) 70 | 71 | # Backwards compatibility 72 | if 'data_distribution' not in args: 73 | args['data_distribution'] = None 74 | probl, *dist = args['problem'].split("_") 75 | if probl == "op": 76 | args['problem'] = probl 77 | args['data_distribution'] = dist[0] 78 | return args 79 | 80 | 81 | def load_model(path, epoch=None): 82 | from nets.att_model import PackingModel 83 | 84 | if os.path.isfile(path): 85 | model_filename = path 86 | path = os.path.dirname(model_filename) 87 | elif os.path.isdir(path): 88 | if epoch is None: 89 | epoch = max( 90 | int(os.path.splitext(filename)[0].split("-")[1]) 91 | for filename in os.listdir(path) 92 | if os.path.splitext(filename)[1] == '.pt' 93 | ) 94 | model_filename = os.path.join(path, 'epoch-{}.pt'.format(epoch)) 95 | else: 96 | assert False, "{} is not a valid directory or file".format(path) 97 | 98 | args = load_args(os.path.join(path, 'args.json')) 99 | 100 | problem = load_problem(args['problem']) 101 | 102 | model_class = { 103 | 'attention': PackingModel 104 | }.get(args.get('model', 'attention'), None) 105 | assert model_class is not None, "Unknown model: {}".format(model_class) 106 | 107 | model = model_class( 108 | args['graph_size'], 109 | args['embedding_dim'], 110 | problem, 111 | n_encode_layers=args['n_encode_layers'], 112 | n_decode_layers=args['n_decode_layers'], 113 | tanh_clipping=args['tanh_clipping'], 114 | ) 115 | # Overwrite model parameters by parameters to load 116 | load_data = torch_load_cpu(model_filename) 117 | model.load_state_dict({**model.state_dict(), **load_data.get('model', {})}) 118 | 119 | model, *_ = _load_model_file(model_filename, model) 120 | 121 | model.eval() # Put in eval mode 122 | 123 | return model, args 124 | 125 | 126 | def parse_softmax_temperature(raw_temp): 127 | # Load from file 128 | if os.path.isfile(raw_temp): 129 | return np.loadtxt(raw_temp)[-1, 0] 130 | return float(raw_temp) 131 | 132 | 133 | def run_all_in_pool(func, directory, dataset, opts, use_multiprocessing=True): 134 | # # Test 135 | # res = func((directory, 'test', *dataset[0])) 136 | # return [res] 137 | 138 | num_cpus = os.cpu_count() if opts.cpus is None else opts.cpus 139 | 140 | w = len(str(len(dataset) - 1)) 141 | offset = getattr(opts, 'offset', None) 142 | if offset is None: 143 | offset = 0 144 | ds = dataset[offset:(offset + opts.n if opts.n is not None else len(dataset))] 145 | pool_cls = (Pool if use_multiprocessing and num_cpus > 1 else ThreadPool) 146 | with pool_cls(num_cpus) as pool: 147 | results = list(tqdm(pool.imap( 148 | func, 149 | [ 150 | ( 151 | directory, 152 | str(i + offset).zfill(w), 153 | *problem 154 | ) 155 | for i, problem in enumerate(ds) 156 | ] 157 | ), total=len(ds), mininterval=opts.progress_bar_mininterval)) 158 | 159 | failed = [str(i + offset) for i, res in enumerate(results) if res is None] 160 | assert len(failed) == 0, "Some instances failed: {}".format(" ".join(failed)) 161 | return results, num_cpus 162 | 163 | 164 | def do_batch_rep(v, n): 165 | if isinstance(v, dict): 166 | return {k: do_batch_rep(v_, n) for k, v_ in v.items()} 167 | elif isinstance(v, list): 168 | return [do_batch_rep(v_, n) for v_ in v] 169 | elif isinstance(v, tuple): 170 | return tuple(do_batch_rep(v_, n) for v_ in v) 171 | 172 | return v[None, ...].expand(n, *v.size()).contiguous().view(-1, *v.size()[1:]) 173 | 174 | 175 | def sample_many(inner_func, get_cost_func, input, batch_rep=1, iter_rep=1): 176 | """ 177 | :param input: (batch_size, graph_size, node_dim) input node features 178 | :return: 179 | """ 180 | input = do_batch_rep(input, batch_rep) 181 | 182 | costs = [] 183 | pis = [] 184 | for i in range(iter_rep): 185 | _log_p, pi = inner_func(input) 186 | # pi.view(-1, batch_rep, pi.size(-1)) 187 | cost, mask = get_cost_func(input, pi) 188 | 189 | costs.append(cost.view(batch_rep, -1).t()) 190 | pis.append(pi.view(batch_rep, -1, pi.size(-1)).transpose(0, 1)) 191 | 192 | max_length = max(pi.size(-1) for pi in pis) 193 | # (batch_size * batch_rep, iter_rep, max_length) => (batch_size, batch_rep * iter_rep, max_length) 194 | pis = torch.cat( 195 | [F.pad(pi, (0, max_length - pi.size(-1))) for pi in pis], 196 | 1 197 | ) # .view(embeddings.size(0), batch_rep * iter_rep, max_length) 198 | costs = torch.cat(costs, 1) 199 | 200 | # (batch_size) 201 | mincosts, argmincosts = costs.min(-1) 202 | # (batch_size, minlength) 203 | minpis = pis[torch.arange(pis.size(0), out=argmincosts.new()), argmincosts] 204 | 205 | return minpis, mincosts 206 | 207 | -------------------------------------------------------------------------------- /utils/log_graph.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding: utf-8 3 | 4 | 5 | import os 6 | import pandas as pd 7 | import numpy as np 8 | import torch 9 | from problems.pack2d.render import get_render_data 10 | 11 | 12 | def save_train_graph(state, epoch, save_dir): 13 | 14 | dataframe = _get_graph(state) 15 | dataframe.to_csv(os.path.join(save_dir, 'epoch-{}.csv'.format(epoch))) 16 | 17 | 18 | def save_validate_graph(state, save_dir): 19 | dataframe = _get_graph(state) 20 | dataframe.to_csv(os.path.join(save_dir, 'validate.csv')) 21 | 22 | def _get_graph(state): 23 | 24 | # (batch, graph, 4) (batch) 25 | render_data = get_render_data(state) 26 | 27 | # only save 10 examples 28 | def clip(data): return data[:10].data.cpu().numpy() 29 | # graph(n*5) others(n) 30 | clipped_data = list(map(clip, render_data)) 31 | 32 | # min gap ratio 33 | min_gap_ratio, min_index = torch.topk( 34 | render_data[3], 10, -1, largest=False) 35 | 36 | min_graphs = torch.index_select( 37 | render_data[0], 0, min_index).data.cpu().numpy() 38 | min_gap_sizes = torch.index_select( 39 | render_data[2], 0, min_index).data.cpu().numpy() 40 | min_heights = torch.index_select( 41 | render_data[1], 0, min_index).data.cpu().numpy() 42 | min_gap_ratio = min_gap_ratio.data.cpu().numpy() 43 | 44 | min_render_data = [min_graphs, min_heights, min_gap_sizes, min_gap_ratio] 45 | 46 | df_min = list_to_df(min_render_data) 47 | df_random = list_to_df(clipped_data) 48 | 49 | df_save = pd.concat([df_min, df_random], axis=0) 50 | 51 | return df_save 52 | 53 | 54 | def list_to_df(clipped_data): 55 | 56 | df = [] 57 | 58 | graphs = clipped_data[0] 59 | 60 | for i, row in enumerate(graphs): 61 | #(20, 4) + () 62 | graph_df = pd.DataFrame(row.transpose()) 63 | graph_df['heights'] = clipped_data[1][i] 64 | graph_df['gap_sizes'] = clipped_data[2][i] 65 | graph_df['gap_ratios'] = clipped_data[3][i] 66 | df.append(graph_df) 67 | 68 | df = pd.concat(df) 69 | return df 70 | -------------------------------------------------------------------------------- /utils/logger.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding: utf-8 3 | 4 | 5 | import os 6 | import sys 7 | import shutil 8 | import os.path as osp 9 | import json 10 | import time 11 | import datetime 12 | import tempfile 13 | from collections import defaultdict 14 | from contextlib import contextmanager 15 | 16 | DEBUG = 10 17 | INFO = 20 18 | WARN = 30 19 | ERROR = 40 20 | 21 | DISABLED = 50 22 | 23 | 24 | class KVWriter(object): 25 | def writekvs(self, kvs): 26 | raise NotImplementedError 27 | 28 | 29 | class SeqWriter(object): 30 | def writeseq(self, seq): 31 | raise NotImplementedError 32 | 33 | 34 | class HumanOutputFormat(KVWriter, SeqWriter): 35 | def __init__(self, filename_or_file): 36 | if isinstance(filename_or_file, str): 37 | self.file = open(filename_or_file, 'wt') 38 | self.own_file = True 39 | else: 40 | assert hasattr( 41 | filename_or_file, 'read'), 'expected file or str, got %s' % filename_or_file 42 | self.file = filename_or_file 43 | self.own_file = False 44 | 45 | def writekvs(self, kvs): 46 | # Create strings for printing 47 | key2str = {} 48 | for (key, val) in sorted(kvs.items()): 49 | if hasattr(val, '__float__'): 50 | valstr = '%-8.3g' % val 51 | else: 52 | valstr = str(val) 53 | key2str[self._truncate(key)] = self._truncate(valstr) 54 | 55 | # Find max widths 56 | if len(key2str) == 0: 57 | print('WARNING: tried to write empty key-value dict') 58 | return 59 | else: 60 | keywidth = max(map(len, key2str.keys())) 61 | valwidth = max(map(len, key2str.values())) 62 | 63 | # Write out the data 64 | dashes = '-' * (keywidth + valwidth + 7) 65 | lines = [dashes] 66 | for (key, val) in sorted(key2str.items(), key=lambda kv: kv[0].lower()): 67 | lines.append('| %s%s | %s%s |' % ( 68 | key, 69 | ' ' * (keywidth - len(key)), 70 | val, 71 | ' ' * (valwidth - len(val)), 72 | )) 73 | lines.append(dashes) 74 | self.file.write('\n'.join(lines) + '\n') 75 | 76 | # Flush the output to the file 77 | self.file.flush() 78 | 79 | def _truncate(self, s): 80 | maxlen = 30 81 | return s[:maxlen - 3] + '...' if len(s) > maxlen else s 82 | 83 | def writeseq(self, seq): 84 | seq = list(seq) 85 | for (i, elem) in enumerate(seq): 86 | self.file.write(elem) 87 | if i < len(seq) - 1: # add space unless this is the last one 88 | self.file.write(' ') 89 | self.file.write('\n') 90 | self.file.flush() 91 | 92 | def close(self): 93 | if self.own_file: 94 | self.file.close() 95 | 96 | 97 | class JSONOutputFormat(KVWriter): 98 | def __init__(self, filename): 99 | self.file = open(filename, 'wt') 100 | 101 | def writekvs(self, kvs): 102 | for k, v in sorted(kvs.items()): 103 | if hasattr(v, 'dtype'): 104 | kvs[k] = float(v) 105 | self.file.write(json.dumps(kvs) + '\n') 106 | self.file.flush() 107 | 108 | def close(self): 109 | self.file.close() 110 | 111 | 112 | class CSVOutputFormat(KVWriter): 113 | def __init__(self, filename): 114 | self.file = open(filename, 'w+t') 115 | self.keys = [] 116 | self.sep = ',' 117 | 118 | def writekvs(self, kvs): 119 | # Add our current row to the history 120 | extra_keys = list(kvs.keys() - self.keys) 121 | extra_keys.sort() 122 | if extra_keys: 123 | self.keys.extend(extra_keys) 124 | self.file.seek(0) 125 | lines = self.file.readlines() 126 | self.file.seek(0) 127 | for (i, k) in enumerate(self.keys): 128 | if i > 0: 129 | self.file.write(',') 130 | self.file.write(k) 131 | self.file.write('\n') 132 | for line in lines[1:]: 133 | self.file.write(line[:-1]) 134 | self.file.write(self.sep * len(extra_keys)) 135 | self.file.write('\n') 136 | for (i, k) in enumerate(self.keys): 137 | if i > 0: 138 | self.file.write(',') 139 | v = kvs.get(k) 140 | if v is not None: 141 | self.file.write(str(v)) 142 | self.file.write('\n') 143 | self.file.flush() 144 | 145 | def close(self): 146 | self.file.close() 147 | 148 | 149 | class TensorBoardOutputFormat(KVWriter): 150 | """ 151 | Dumps key/value pairs into TensorBoard's numeric format. 152 | """ 153 | 154 | def __init__(self, dir): 155 | os.makedirs(dir, exist_ok=True) 156 | self.dir = dir 157 | self.step = 1 158 | prefix = 'events' 159 | path = osp.join(osp.abspath(dir), prefix) 160 | import tensorflow as tf 161 | from tensorflow.python import pywrap_tensorflow 162 | from tensorflow.core.util import event_pb2 163 | from tensorflow.python.util import compat 164 | self.tf = tf 165 | self.event_pb2 = event_pb2 166 | self.pywrap_tensorflow = pywrap_tensorflow 167 | self.writer = pywrap_tensorflow.EventsWriter(compat.as_bytes(path)) 168 | 169 | def writekvs(self, kvs): 170 | def summary_val(k, v): 171 | kwargs = {'tag': k, 'simple_value': float(v)} 172 | return self.tf.Summary.Value(**kwargs) 173 | summary = self.tf.Summary( 174 | value=[summary_val(k, v) for k, v in kvs.items()]) 175 | event = self.event_pb2.Event(wall_time=time.time(), summary=summary) 176 | event.step = self.step # is there any reason why you'd want to specify the step? 177 | self.writer.WriteEvent(event) 178 | self.writer.Flush() 179 | self.step += 1 180 | 181 | def close(self): 182 | if self.writer: 183 | self.writer.Close() 184 | self.writer = None 185 | 186 | 187 | def make_output_format(format, ev_dir, log_suffix=''): 188 | os.makedirs(ev_dir, exist_ok=True) 189 | if format == 'stdout': 190 | return HumanOutputFormat(sys.stdout) 191 | elif format == 'log': 192 | return HumanOutputFormat(osp.join(ev_dir, 'log%s.txt' % log_suffix)) 193 | elif format == 'json': 194 | return JSONOutputFormat(osp.join(ev_dir, 'progress%s.json' % log_suffix)) 195 | elif format == 'csv': 196 | return CSVOutputFormat(osp.join(ev_dir, 'progress%s.csv' % log_suffix)) 197 | elif format == 'tensorboard': 198 | return TensorBoardOutputFormat(osp.join(ev_dir, 'tb%s' % log_suffix)) 199 | else: 200 | raise ValueError('Unknown format specified: %s' % (format,)) 201 | 202 | # ================================================================ 203 | # API 204 | # ================================================================ 205 | 206 | 207 | def logkv(key, val): 208 | """ 209 | Log a value of some diagnostic 210 | Call this once for each diagnostic quantity, each iteration 211 | If called many times, last value will be used. 212 | """ 213 | get_current().logkv(key, val) 214 | 215 | 216 | def logkv_mean(key, val): 217 | """ 218 | The same as logkv(), but if called many times, values averaged. 219 | """ 220 | get_current().logkv_mean(key, val) 221 | 222 | 223 | def logkvs(d): 224 | """ 225 | Log a dictionary of key-value pairs 226 | """ 227 | for (k, v) in d.items(): 228 | logkv(k, v) 229 | 230 | 231 | def dumpkvs(): 232 | """ 233 | Write all of the diagnostics from the current iteration 234 | """ 235 | return get_current().dumpkvs() 236 | 237 | 238 | def getkvs(): 239 | return get_current().name2val 240 | 241 | 242 | def log(*args, level=INFO): 243 | """ 244 | Write the sequence of args, with no separators, to the console and output files (if you've configured an output file). 245 | """ 246 | get_current().log(*args, level=level) 247 | 248 | 249 | def debug(*args): 250 | log(*args, level=DEBUG) 251 | 252 | 253 | def info(*args): 254 | log(*args, level=INFO) 255 | 256 | 257 | def warn(*args): 258 | log(*args, level=WARN) 259 | 260 | 261 | def error(*args): 262 | log(*args, level=ERROR) 263 | 264 | 265 | def set_level(level): 266 | """ 267 | Set logging threshold on current logger. 268 | """ 269 | get_current().set_level(level) 270 | 271 | 272 | def set_comm(comm): 273 | get_current().set_comm(comm) 274 | 275 | 276 | def get_dir(): 277 | """ 278 | Get directory that log files are being written to. 279 | will be None if there is no output directory (i.e., if you didn't call start) 280 | """ 281 | return get_current().get_dir() 282 | 283 | 284 | record_tabular = logkv 285 | dump_tabular = dumpkvs 286 | 287 | 288 | @contextmanager 289 | def profile_kv(scopename): 290 | logkey = 'wait_' + scopename 291 | tstart = time.time() 292 | try: 293 | yield 294 | finally: 295 | get_current().name2val[logkey] += time.time() - tstart 296 | 297 | 298 | def profile(n): 299 | """ 300 | Usage: 301 | @profile("my_func") 302 | def my_func(): code 303 | """ 304 | def decorator_with_name(func): 305 | def func_wrapper(*args, **kwargs): 306 | with profile_kv(n): 307 | return func(*args, **kwargs) 308 | return func_wrapper 309 | return decorator_with_name 310 | 311 | 312 | # ================================================================ 313 | # Backend 314 | # ================================================================ 315 | 316 | def get_current(): 317 | if Logger.CURRENT is None: 318 | _configure_default_logger() 319 | 320 | return Logger.CURRENT 321 | 322 | 323 | class Logger(object): 324 | # A logger with no output files. (See right below class definition) 325 | DEFAULT = None 326 | # So that you can still log to the terminal without setting up any output files 327 | CURRENT = None # Current logger being used by the free functions above 328 | 329 | def __init__(self, dir, output_formats, comm=None): 330 | self.name2val = defaultdict(float) # values this iteration 331 | self.name2cnt = defaultdict(int) 332 | self.level = INFO 333 | self.dir = dir 334 | self.output_formats = output_formats 335 | self.comm = comm 336 | 337 | # Logging API, forwarded 338 | # ---------------------------------------- 339 | def logkv(self, key, val): 340 | self.name2val[key] = val 341 | 342 | def logkv_mean(self, key, val): 343 | oldval, cnt = self.name2val[key], self.name2cnt[key] 344 | self.name2val[key] = oldval * cnt / (cnt + 1) + val / (cnt + 1) 345 | self.name2cnt[key] = cnt + 1 346 | 347 | def dumpkvs(self): 348 | if self.comm is None: 349 | d = self.name2val 350 | else: 351 | from baselines.common import mpi_util 352 | d = mpi_util.mpi_weighted_mean(self.comm, 353 | {name: (val, self.name2cnt.get(name, 1)) 354 | for (name, val) in self.name2val.items()}) 355 | if self.comm.rank != 0: 356 | d['dummy'] = 1 # so we don't get a warning about empty dict 357 | out = d.copy() # Return the dict for unit testing purposes 358 | for fmt in self.output_formats: 359 | if isinstance(fmt, KVWriter): 360 | fmt.writekvs(d) 361 | self.name2val.clear() 362 | self.name2cnt.clear() 363 | return out 364 | 365 | def log(self, *args, level=INFO): 366 | if self.level <= level: 367 | self._do_log(args) 368 | 369 | # Configuration 370 | # ---------------------------------------- 371 | def set_level(self, level): 372 | self.level = level 373 | 374 | def set_comm(self, comm): 375 | self.comm = comm 376 | 377 | def get_dir(self): 378 | return self.dir 379 | 380 | def close(self): 381 | for fmt in self.output_formats: 382 | fmt.close() 383 | 384 | # Misc 385 | # ---------------------------------------- 386 | def _do_log(self, args): 387 | for fmt in self.output_formats: 388 | if isinstance(fmt, SeqWriter): 389 | fmt.writeseq(map(str, args)) 390 | 391 | 392 | def get_rank_without_mpi_import(): 393 | # check environment variables here instead of importing mpi4py 394 | # to avoid calling MPI_Init() when this module is imported 395 | for varname in ['PMI_RANK', 'OMPI_COMM_WORLD_RANK']: 396 | if varname in os.environ: 397 | return int(os.environ[varname]) 398 | return 0 399 | 400 | 401 | def configure(dir=None, format_strs=None, comm=None, log_suffix=''): 402 | """ 403 | If comm is provided, average all numerical stats across that comm 404 | """ 405 | if dir is None: 406 | dir = os.getenv('OPENAI_LOGDIR') 407 | if dir is None: 408 | dir = osp.join(tempfile.gettempdir(), 409 | datetime.datetime.now().strftime("openai-%Y-%m-%d-%H-%M-%S-%f")) 410 | assert isinstance(dir, str) 411 | dir = os.path.expanduser(dir) 412 | os.makedirs(os.path.expanduser(dir), exist_ok=True) 413 | 414 | rank = get_rank_without_mpi_import() 415 | if rank > 0: 416 | log_suffix = log_suffix + "-rank%03i" % rank 417 | 418 | if format_strs is None: 419 | if rank == 0: 420 | format_strs = os.getenv( 421 | 'OPENAI_LOG_FORMAT', 'stdout,log,csv').split(',') 422 | else: 423 | format_strs = os.getenv('OPENAI_LOG_FORMAT_MPI', 'log').split(',') 424 | format_strs = filter(None, format_strs) 425 | output_formats = [make_output_format( 426 | f, dir, log_suffix) for f in format_strs] 427 | 428 | Logger.CURRENT = Logger(dir=dir, output_formats=output_formats, comm=comm) 429 | if output_formats: 430 | log('Logging to %s' % dir) 431 | 432 | 433 | def _configure_default_logger(): 434 | configure() 435 | Logger.DEFAULT = Logger.CURRENT 436 | 437 | 438 | def reset(): 439 | if Logger.CURRENT is not Logger.DEFAULT: 440 | Logger.CURRENT.close() 441 | Logger.CURRENT = Logger.DEFAULT 442 | log('Reset logger') 443 | 444 | 445 | @contextmanager 446 | def scoped_configure(dir=None, format_strs=None, comm=None): 447 | prevlogger = Logger.CURRENT 448 | configure(dir=dir, format_strs=format_strs, comm=comm) 449 | try: 450 | yield 451 | finally: 452 | Logger.CURRENT.close() 453 | Logger.CURRENT = prevlogger 454 | 455 | # ================================================================ 456 | 457 | 458 | def _demo(): 459 | info("hi") 460 | debug("shouldn't appear") 461 | set_level(DEBUG) 462 | debug("should appear") 463 | dir = "/tmp/testlogging" 464 | if os.path.exists(dir): 465 | shutil.rmtree(dir) 466 | configure(dir=dir) 467 | logkv("a", 3) 468 | logkv("b", 2.5) 469 | dumpkvs() 470 | logkv("b", -2.5) 471 | logkv("a", 5.5) 472 | dumpkvs() 473 | info("^^^ should see a = 5.5") 474 | logkv_mean("b", -22.5) 475 | logkv_mean("b", -44.4) 476 | logkv("a", 5.5) 477 | dumpkvs() 478 | info("^^^ should see b = -33.3") 479 | 480 | logkv("b", -2.5) 481 | dumpkvs() 482 | 483 | logkv("a", "longasslongasslongasslongasslongasslongassvalue") 484 | dumpkvs() 485 | 486 | 487 | # ================================================================ 488 | # Readers 489 | # ================================================================ 490 | 491 | def read_json(fname): 492 | import pandas 493 | ds = [] 494 | with open(fname, 'rt') as fh: 495 | for line in fh: 496 | ds.append(json.loads(line)) 497 | return pandas.DataFrame(ds) 498 | 499 | 500 | def read_csv(fname): 501 | import pandas 502 | return pandas.read_csv(fname, index_col=None, comment='#') 503 | 504 | 505 | def read_tb(path): 506 | """ 507 | path : a tensorboard file OR a directory, where we will find all TB files 508 | of the form events.* 509 | """ 510 | import pandas 511 | import numpy as np 512 | from glob import glob 513 | import tensorflow as tf 514 | if osp.isdir(path): 515 | fnames = glob(osp.join(path, "events.*")) 516 | elif osp.basename(path).startswith("events."): 517 | fnames = [path] 518 | else: 519 | raise NotImplementedError( 520 | "Expected tensorboard file or directory containing them. Got %s" % path) 521 | tag2pairs = defaultdict(list) 522 | maxstep = 0 523 | for fname in fnames: 524 | for summary in tf.train.summary_iterator(fname): 525 | if summary.step > 0: 526 | for v in summary.summary.value: 527 | pair = (summary.step, v.simple_value) 528 | tag2pairs[v.tag].append(pair) 529 | maxstep = max(summary.step, maxstep) 530 | data = np.empty((maxstep, len(tag2pairs))) 531 | data[:] = np.nan 532 | tags = sorted(tag2pairs.keys()) 533 | for (colidx, tag) in enumerate(tags): 534 | pairs = tag2pairs[tag] 535 | for (step, value) in pairs: 536 | data[step - 1, colidx] = value 537 | return pandas.DataFrame(data, columns=tags) 538 | 539 | 540 | if __name__ == "__main__": 541 | _demo() 542 | -------------------------------------------------------------------------------- /utils/math_util.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding: utf-8 3 | 4 | # In[ ]: 5 | 6 | 7 | import numpy as np 8 | import scipy.signal 9 | 10 | 11 | def discount(x, gamma): 12 | """ 13 | computes discounted sums along 0th dimension of x. 14 | 15 | inputs 16 | ------ 17 | x: ndarray 18 | gamma: float 19 | 20 | outputs 21 | ------- 22 | y: ndarray with same shape as x, satisfying 23 | 24 | y[t] = x[t] + gamma*x[t+1] + gamma^2*x[t+2] + ... + gamma^k x[t+k], 25 | where k = len(x) - t - 1 26 | 27 | """ 28 | assert x.ndim >= 1 29 | return scipy.signal.lfilter([1],[1,-gamma],x[::-1], axis=0)[::-1] 30 | 31 | def explained_variance(ypred,y): 32 | """ 33 | Computes fraction of variance that ypred explains about y. 34 | Returns 1 - Var[y-ypred] / Var[y] 35 | 36 | interpretation: 37 | ev=0 => might as well have predicted zero 38 | ev=1 => perfect prediction 39 | ev<0 => worse than just predicting zero 40 | 41 | """ 42 | assert y.ndim == 1 and ypred.ndim == 1 43 | vary = np.var(y) 44 | return np.nan if vary==0 else 1 - np.var(y-ypred)/vary 45 | 46 | def explained_variance_2d(ypred, y): 47 | assert y.ndim == 2 and ypred.ndim == 2 48 | vary = np.var(y, axis=0) 49 | out = 1 - np.var(y-ypred)/vary 50 | out[vary < 1e-10] = 0 51 | return out 52 | 53 | def ncc(ypred, y): 54 | return np.corrcoef(ypred, y)[1,0] 55 | 56 | def flatten_arrays(arrs): 57 | return np.concatenate([arr.flat for arr in arrs]) 58 | 59 | def unflatten_vector(vec, shapes): 60 | i=0 61 | arrs = [] 62 | for shape in shapes: 63 | size = np.prod(shape) 64 | arr = vec[i:i+size].reshape(shape) 65 | arrs.append(arr) 66 | i += size 67 | return arrs 68 | 69 | def discount_with_boundaries(X, New, gamma): 70 | """ 71 | X: 2d array of floats, time x features 72 | New: 2d array of bools, indicating when a new episode has started 73 | """ 74 | Y = np.zeros_like(X) 75 | T = X.shape[0] 76 | Y[T-1] = X[T-1] 77 | for t in range(T-2, -1, -1): 78 | Y[t] = X[t] + gamma * Y[t+1] * (1 - New[t+1]) 79 | return Y 80 | 81 | def test_discount_with_boundaries(): 82 | gamma=0.9 83 | x = np.array([1.0, 2.0, 3.0, 4.0], 'float32') 84 | starts = [1.0, 0.0, 0.0, 1.0] 85 | y = discount_with_boundaries(x, starts, gamma) 86 | assert np.allclose(y, [ 87 | 1 + gamma * 2 + gamma**2 * 3, 88 | 2 + gamma * 3, 89 | 3, 90 | 4 91 | ]) 92 | 93 | -------------------------------------------------------------------------------- /utils/seeding.py: -------------------------------------------------------------------------------- 1 | import hashlib 2 | import numpy as np 3 | import os 4 | import random as _random 5 | from six import integer_types 6 | import struct 7 | import sys 8 | 9 | from gym import error 10 | 11 | def np_random(seed=None): 12 | if seed is not None and not (isinstance(seed, integer_types) and 0 <= seed): 13 | raise error.Error('Seed must be a non-negative integer or omitted, not {}'.format(seed)) 14 | 15 | seed = create_seed(seed) 16 | 17 | rng = np.random.RandomState() 18 | rng.seed(_int_list_from_bigint(hash_seed(seed))) 19 | return rng, seed 20 | 21 | def hash_seed(seed=None, max_bytes=8): 22 | """Any given evaluation is likely to have many PRNG's active at 23 | once. (Most commonly, because the environment is running in 24 | multiple processes.) There's literature indicating that having 25 | linear correlations between seeds of multiple PRNG's can correlate 26 | the outputs: 27 | 28 | http://blogs.unity3d.com/2015/01/07/a-primer-on-repeatable-random-numbers/ 29 | http://stackoverflow.com/questions/1554958/how-different-do-random-seeds-need-to-be 30 | http://dl.acm.org/citation.cfm?id=1276928 31 | 32 | Thus, for sanity we hash the seeds before using them. (This scheme 33 | is likely not crypto-strength, but it should be good enough to get 34 | rid of simple correlations.) 35 | 36 | Args: 37 | seed (Optional[int]): None seeds from an operating system specific randomness source. 38 | max_bytes: Maximum number of bytes to use in the hashed seed. 39 | """ 40 | if seed is None: 41 | seed = create_seed(max_bytes=max_bytes) 42 | hash = hashlib.sha512(str(seed).encode('utf8')).digest() 43 | return _bigint_from_bytes(hash[:max_bytes]) 44 | 45 | def create_seed(a=None, max_bytes=8): 46 | """Create a strong random seed. Otherwise, Python 2 would seed using 47 | the system time, which might be non-robust especially in the 48 | presence of concurrency. 49 | 50 | Args: 51 | a (Optional[int, str]): None seeds from an operating system specific randomness source. 52 | max_bytes: Maximum number of bytes to use in the seed. 53 | """ 54 | # Adapted from https://svn.python.org/projects/python/tags/r32/Lib/random.py 55 | if a is None: 56 | a = _bigint_from_bytes(os.urandom(max_bytes)) 57 | elif isinstance(a, str): 58 | a = a.encode('utf8') 59 | a += hashlib.sha512(a).digest() 60 | a = _bigint_from_bytes(a[:max_bytes]) 61 | elif isinstance(a, integer_types): 62 | a = a % 2**(8 * max_bytes) 63 | else: 64 | raise error.Error('Invalid type for seed: {} ({})'.format(type(a), a)) 65 | 66 | return a 67 | 68 | # TODO: don't hardcode sizeof_int here 69 | def _bigint_from_bytes(bytes): 70 | sizeof_int = 4 71 | padding = sizeof_int - len(bytes) % sizeof_int 72 | bytes += b'\0' * padding 73 | int_count = int(len(bytes) / sizeof_int) 74 | unpacked = struct.unpack("{}I".format(int_count), bytes) 75 | accum = 0 76 | for i, val in enumerate(unpacked): 77 | accum += 2 ** (sizeof_int * 8 * i) * val 78 | return accum 79 | 80 | def _int_list_from_bigint(bigint): 81 | # Special case 0 82 | if bigint < 0: 83 | raise error.Error('Seed must be non-negative, not {}'.format(bigint)) 84 | elif bigint == 0: 85 | return [0] 86 | 87 | ints = [] 88 | while bigint > 0: 89 | bigint, mod = divmod(bigint, 2 ** 32) 90 | ints.append(mod) 91 | return ints 92 | 93 | -------------------------------------------------------------------------------- /utils/statistic.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding: utf-8 3 | 4 | 5 | import os 6 | import pandas as pd 7 | import json 8 | 9 | 10 | def get_row(path): 11 | arg_file = os.path.join(path, 'args.json') 12 | progress_file = os.path.join(path, 'progress.csv') 13 | 14 | with open(arg_file, 'r') as f: 15 | data = json.load(f) 16 | 17 | args_df = pd.DataFrame([data]) 18 | 19 | prog_df = pd.read_csv(progress_file, index_col=0) 20 | 21 | last_20_line = prog_df.tail(5).mean(axis=0, skipna=True).reset_index() 22 | 23 | last_20_line.set_index('index', inplace=True) 24 | 25 | last_20_line = last_20_line.T 26 | 27 | test_row = args_df.join(last_20_line) 28 | 29 | return test_row 30 | 31 | 32 | def get_test_folders(path): 33 | sub_folders = [] 34 | dirs = [] 35 | 36 | for r, d, f in os.walk(path): 37 | dirs[:] = [d for d in d if not d[0] == '.'] 38 | for folder in dirs: 39 | # has epoch-x.pt 40 | folder_path = os.path.join(path, folder) 41 | if len(os.listdir(folder_path)) > 3: 42 | # print("go to ", folder_path) 43 | # for filename in os.listdir(folder_path): 44 | # print("file", filename) 45 | epoch = max( 46 | int(os.path.splitext(filename)[0].split("-")[1]) 47 | for filename in os.listdir(folder_path) 48 | if os.path.splitext(filename)[1] == '.pt' 49 | ) 50 | # print("epoch", epoch) 51 | 52 | if epoch > 2: 53 | sub_folders.append(os.path.join(r, folder)) 54 | return sub_folders 55 | 56 | 57 | def select_coloum(df): 58 | drop_list = ['model', 'no_cuda', 'checkpoint_encoder', 59 | 'data_distribution', 'log_dir', 'output_dir', 60 | 'epoch_start', 'checkpoint_epochs', 'load_path', 61 | 'resume', 'no_tensorboard', 'no_progress_bar', 'use_cuda', 62 | 'save_dir', 'eval_only', 'normalization'] 63 | 64 | df.drop(drop_list, axis=1, inplace=True) 65 | 66 | cols = list(df.columns.values) 67 | cols.pop(cols.index('run_name')) 68 | cols.pop(cols.index('gap_ratio')) 69 | cols.pop(cols.index('misc/step')) 70 | cols.pop(cols.index('problem')) 71 | cols.pop(cols.index('graph_size')) 72 | cols.pop(cols.index('epoch')) 73 | 74 | cols.pop(cols.index('lr_model')) 75 | cols.pop(cols.index('hidden_dim')) 76 | cols.pop(cols.index('batch_size')) 77 | 78 | df = df[['gap_ratio', 'misc/step', 'epoch', 'lr_model', 79 | 'run_name', 'problem', 'graph_size', 80 | 'hidden_dim', 'batch_size'] + cols] 81 | 82 | return df 83 | 84 | 85 | def get_statistic(folders): 86 | test_list = [] 87 | for folder in folders: 88 | test_list.append(get_row(folder)) 89 | 90 | graph_stt = pd.concat(test_list, sort=False).reset_index() 91 | graph_stt = select_coloum(graph_stt) 92 | 93 | return graph_stt 94 | -------------------------------------------------------------------------------- /utils/truncated_normal.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding: utf-8 3 | 4 | import numpy as np 5 | import torch 6 | 7 | 8 | def parameterized_truncated_normal(uniform, mu, sigma, a, b): 9 | normal = torch.distributions.normal.Normal(0, 1) 10 | 11 | alpha = (a - mu) / sigma 12 | beta = (b - mu) / sigma 13 | 14 | alpha_normal_cdf = normal.cdf(alpha) 15 | p = alpha_normal_cdf + (normal.cdf(beta) - alpha_normal_cdf) * uniform 16 | 17 | p = p.numpy() 18 | one = np.array(1, dtype=p.dtype) 19 | epsilon = np.array(np.finfo(p.dtype).eps, dtype=p.dtype) 20 | v = np.clip(2 * p - 1, -one + epsilon, one - epsilon) 21 | x = mu + sigma * np.sqrt(2) * torch.erfinv(torch.from_numpy(v)) 22 | x = torch.clamp(x, a, b) 23 | 24 | return x 25 | 26 | 27 | def truncated_normal(uniform, mu, sigma, a, b): 28 | return parameterized_truncated_normal(uniform, mu, sigma, a, b) 29 | 30 | 31 | def sample_truncated_normal(shape=(), mu=0.4, sigma=1.0, a=0.02, b=2.0): 32 | return truncated_normal(torch.from_numpy(np.random.uniform(0, 1, shape)).float(), mu, sigma, a, b) 33 | 34 | def generate_normal(shape=(), mu=0.4, sigma=1.0, a=0.02, b=2.0): 35 | dataset = [] 36 | for i in range(shape[0] * shape[1]): 37 | data = torch.FloatTensor(1).normal_(mean=mu, std=sigma) 38 | # while not a <= data <= b: 39 | # data = torch.FloatTensor(1).normal_(mean=mu, std=sigma) 40 | # dataset.append(data) 41 | while not a <= data <= b: 42 | data = torch.FloatTensor(1).normal_(mean=mu, std=sigma) 43 | dataset.append(data) 44 | dataset = torch.stack(dataset, 0) 45 | dataset = torch.reshape(dataset, shape) 46 | return dataset 47 | -------------------------------------------------------------------------------- /utils/utils_fb.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | import os 4 | import math 5 | import argparse 6 | 7 | import torch 8 | 9 | 10 | def _parse_args(params_config, args): 11 | parser = argparse.ArgumentParser() 12 | for params_category in params_config: # e.g., 'model_params' 13 | for param_flag, param_config in params_config[params_category].items(): 14 | # e.g., param_flag = '--block-sz' 15 | parser.add_argument(param_flag, **param_config) 16 | return parser.parse_args(args) 17 | 18 | 19 | def get_params(params_config, args=None): 20 | namespace = _parse_args(params_config, args) 21 | return { 22 | params_category: { 23 | param_config['dest']: 24 | namespace.__getattribute__(param_config['dest']) 25 | for param_config in params_config[params_category].values() 26 | } 27 | for params_category in params_config 28 | } 29 | 30 | 31 | ############################################################################## 32 | # ENVIRONMENT 33 | ############################################################################## 34 | 35 | def _torch_distributed_init_process_group(local_rank): 36 | torch.distributed.init_process_group( 37 | backend='nccl', 38 | init_method='env://' 39 | ) 40 | rank = torch.distributed.get_rank() 41 | world_size = torch.distributed.get_world_size() 42 | print('my rank={} local_rank={}'.format(rank, local_rank)) 43 | torch.cuda.set_device(local_rank) 44 | return { 45 | 'rank': rank, 46 | 'world_size': world_size, 47 | } 48 | 49 | def set_up_env(env_params): 50 | if torch.cuda.is_available(): 51 | env_params['device'] = torch.device('cuda') 52 | else: 53 | env_params['device'] = torch.device('cpu') 54 | 55 | 56 | ############################################################################## 57 | # OPTIMIZER AND SCHEDULER 58 | ############################################################################## 59 | 60 | def get_grad_requiring_params(model): 61 | nb_parameters = 0 62 | grad_requiring_params = [] 63 | for param in model.parameters(): 64 | if param.requires_grad: 65 | nb_parameters += param.numel() 66 | grad_requiring_params.append(param) 67 | print('nb_parameters={:.2f}M'.format(nb_parameters / 1e6)) 68 | return grad_requiring_params 69 | 70 | 71 | 72 | 73 | 74 | def get_scheduler(optimizer, lr_warmup): 75 | if lr_warmup > 0: 76 | return torch.optim.lr_scheduler.LambdaLR( 77 | optimizer, lambda ep: min(1, ep / lr_warmup)) 78 | return None 79 | 80 | 81 | 82 | 83 | ############################################################################## 84 | # CHECKPOINT 85 | ############################################################################## 86 | 87 | def _load_checkpoint(checkpoint_path, model, optimizer, scheduler): 88 | print('loading from a checkpoint at {}'.format(checkpoint_path)) 89 | 90 | checkpoint_state = torch.load(checkpoint_path) 91 | iter_init = checkpoint_state['iter_no'] + 1 # next iteration 92 | model.load_state_dict(checkpoint_state['model']) 93 | optimizer.load_state_dict(checkpoint_state['optimizer']) 94 | if 'scheduler_iter' in checkpoint_state: 95 | # we only need the step count 96 | scheduler.step(checkpoint_state['scheduler_iter']) 97 | return iter_init 98 | 99 | 100 | def load_checkpoint(checkpoint_path, model, optimizer, scheduler): 101 | if checkpoint_path and os.path.exists(checkpoint_path): 102 | return _load_checkpoint(checkpoint_path=checkpoint_path, 103 | model=model, 104 | optimizer=optimizer, 105 | scheduler=scheduler) 106 | return 0 107 | 108 | 109 | def save_checkpoint(checkpoint_path, iter_no, modules, 110 | optimizer, scheduler): 111 | if checkpoint_path: 112 | checkpoint_state = { 113 | 'iter_no': iter_no, # last completed iteration 114 | 'model': modules.state_dict(), 115 | 'optimizer': optimizer.state_dict(), 116 | } 117 | if scheduler is not None: 118 | checkpoint_state['scheduler_iter'] = scheduler.last_epoch 119 | torch.save(checkpoint_state, checkpoint_path) 120 | 121 | 122 | --------------------------------------------------------------------------------