├── .gitignore ├── .gitlab-ci.yml ├── Dockerfile ├── LICENSE ├── Makefile ├── README.md ├── assets ├── deck │ ├── BestowedDragon.ydk │ ├── BestowedDragon2.ydk │ ├── BestowedDragonF.ydk │ ├── Blackwing.ydk │ ├── BlueEyes.ydk │ ├── Branded.ydk │ ├── Branded60.ydk │ ├── CenturIon.ydk │ ├── CenturIon2.ydk │ ├── Chimera.ydk │ ├── CyberDragon.ydk │ ├── Eld.ydk │ ├── Floowandereeze.ydk │ ├── Floowandereeze2.ydk │ ├── Hero.ydk │ ├── Labrynth.ydk │ ├── MaxDragon.ydk │ ├── NatRunick.ydk │ ├── Pachycephalo.ydk │ ├── Shaddoll.ydk │ ├── Shiranui.ydk │ ├── SinfulSnake.ydk │ ├── SinfulSnake2.ydk │ ├── SinfulSnakeKash.ydk │ ├── SkyStrikerAce.ydk │ ├── SnakeEyeAlter.ydk │ ├── SnakeEyeFire.ydk │ ├── SnakeEyeFire2.ydk │ ├── SnakeEyeTear.ydk │ ├── TenyiSword.ydk │ ├── Voiceless.ydk │ ├── unsupported │ │ ├── Magician.ydk │ │ └── _tokens.ydk │ └── unused │ │ └── OldSchool.ydk └── log_conf.yaml ├── docs ├── action.md ├── feature_engineering.md ├── network_design.md └── support.md ├── mcts ├── mcts │ ├── __init__.py │ ├── alphazero │ │ ├── __init__.py │ │ ├── alphazero.py │ │ ├── alphazero_mcts.pyi │ │ ├── cnode.cpp │ │ ├── cnode.h │ │ └── tree.cpp │ └── core │ │ ├── __init__.py │ │ ├── array.h │ │ ├── common.h │ │ ├── minimax.h │ │ ├── spec.h │ │ └── state.py └── setup.py ├── repo └── packages │ ├── e │ └── edopro-core │ │ └── xmake.lua │ └── y │ └── ygopro-core │ └── xmake.lua ├── scripts ├── battle.py ├── card │ ├── code_list.py │ └── embedding.py ├── cleanba.py ├── cleanba_g.py ├── cleanba_l.py ├── cleanba_nnx.py ├── cleanba_rnd.py ├── code_list.txt ├── eval.py ├── impala.py └── torch │ ├── ppo.py │ ├── ppo_c.py │ ├── ppo_osfp.py │ └── ppo_xla.py ├── setup.py ├── xmake.lua ├── ygoai ├── __init__.py ├── _version.py ├── constants.py ├── embed.py ├── rl │ ├── __init__.py │ ├── agent.py │ ├── agent2.py │ ├── buffer.py │ ├── ckpt.py │ ├── dist.py │ ├── env.py │ ├── eval.py │ ├── jax │ │ ├── __init__.py │ │ ├── agent.py │ │ ├── agent2.py │ │ ├── eval.py │ │ ├── modules.py │ │ ├── nnx │ │ │ 
├── agent.py │ │ │ ├── modules.py │ │ │ ├── rnn.py │ │ │ └── transformer.py │ │ ├── rwkv.py │ │ ├── switch.py │ │ ├── transformer.py │ │ └── utils.py │ ├── ppo.py │ └── utils.py └── utils.py ├── ygoenv ├── MANIFEST.in ├── setup.py └── ygoenv │ ├── __init__.py │ ├── core │ ├── BS_thread_pool.h │ ├── ThreadPool.h │ ├── action_buffer_queue.h │ ├── array.h │ ├── async_envpool.h │ ├── circular_buffer.h │ ├── dict.h │ ├── env.h │ ├── env_spec.h │ ├── envpool.h │ ├── py_envpool.h │ ├── spec.h │ ├── state_buffer.h │ ├── state_buffer_queue.h │ ├── tuple_utils.h │ └── type_utils.h │ ├── dummy │ ├── __init__.py │ ├── dummy_envpool.cpp │ ├── dummy_envpool.h │ └── registration.py │ ├── edopro │ ├── __init__.py │ ├── edopro.cpp │ ├── edopro.h │ └── registration.py │ ├── entry.py │ ├── python │ ├── __init__.py │ ├── api.py │ ├── data.py │ ├── dm_envpool.py │ ├── env_spec.py │ ├── envpool.py │ ├── gym_envpool.py │ ├── gymnasium_envpool.py │ ├── protocol.py │ └── utils.py │ ├── registration.py │ ├── ygopro │ ├── __init__.py │ ├── registration.py │ ├── ygopro.cpp │ └── ygopro.h │ └── ygopro0 │ ├── __init__.py │ ├── registration.py │ ├── ygopro.cpp │ └── ygopro.h └── ygoinf ├── setup.py └── ygoinf ├── __init__.py ├── features.py ├── jax_inf.py ├── models.py ├── server.py └── tflite_inf.py /.gitignore: -------------------------------------------------------------------------------- 1 | *.pt 2 | *.ptj 3 | *.pkl 4 | 5 | # Xmake cache 6 | .xmake/ 7 | 8 | # MacOS Cache 9 | .DS_Store 10 | 11 | 12 | *.out 13 | *.npy 14 | 15 | .vscode/ 16 | checkpoints 17 | runs 18 | logs 19 | k8s_job/ 20 | script 21 | assets/locale/*/*.cdb 22 | assets/locale/*/strings.conf 23 | 24 | 25 | # Byte-compiled / optimized / DLL files 26 | __pycache__/ 27 | *.py[cod] 28 | *$py.class 29 | 30 | # C extensions 31 | *.so 32 | *.o 33 | 34 | # Distribution / packaging 35 | .Python 36 | build/ 37 | develop-eggs/ 38 | dist/ 39 | downloads/ 40 | eggs/ 41 | .eggs/ 42 | lib/ 43 | lib64/ 44 | parts/ 45 | sdist/ 46 | var/ 47 | 
wheels/ 48 | share/python-wheels/ 49 | *.egg-info/ 50 | .installed.cfg 51 | *.egg 52 | MANIFEST 53 | 54 | # PyInstaller 55 | # Usually these files are written by a python script from a template 56 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 57 | *.manifest 58 | *.spec 59 | 60 | # Installer logs 61 | pip-log.txt 62 | pip-delete-this-directory.txt 63 | 64 | # Unit test / coverage reports 65 | htmlcov/ 66 | .tox/ 67 | .nox/ 68 | .coverage 69 | .coverage.* 70 | .cache 71 | nosetests.xml 72 | coverage.xml 73 | *.cover 74 | *.py,cover 75 | .hypothesis/ 76 | .pytest_cache/ 77 | cover/ 78 | 79 | # Translations 80 | *.mo 81 | *.pot 82 | 83 | # Django stuff: 84 | *.log 85 | local_settings.py 86 | db.sqlite3 87 | db.sqlite3-journal 88 | 89 | # Flask stuff: 90 | instance/ 91 | .webassets-cache 92 | 93 | # Scrapy stuff: 94 | .scrapy 95 | 96 | # Sphinx documentation 97 | docs/_build/ 98 | 99 | # PyBuilder 100 | .pybuilder/ 101 | target/ 102 | 103 | # Jupyter Notebook 104 | .ipynb_checkpoints 105 | 106 | # IPython 107 | profile_default/ 108 | ipython_config.py 109 | 110 | # pyenv 111 | # For a library or package, you might want to ignore these files since the code is 112 | # intended to run in multiple environments; otherwise, check them in: 113 | # .python-version 114 | 115 | # pipenv 116 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 117 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 118 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 119 | # install all needed dependencies. 120 | #Pipfile.lock 121 | 122 | # poetry 123 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 124 | # This is especially recommended for binary packages to ensure reproducibility, and is more 125 | # commonly ignored for libraries. 
126 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 127 | #poetry.lock 128 | 129 | # pdm 130 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 131 | #pdm.lock 132 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 133 | # in version control. 134 | # https://pdm.fming.dev/#use-with-ide 135 | .pdm.toml 136 | 137 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 138 | __pypackages__/ 139 | 140 | # Celery stuff 141 | celerybeat-schedule 142 | celerybeat.pid 143 | 144 | # SageMath parsed files 145 | *.sage.py 146 | 147 | # Environments 148 | .env 149 | .venv 150 | env/ 151 | venv/ 152 | ENV/ 153 | env.bak/ 154 | venv.bak/ 155 | 156 | # Spyder project settings 157 | .spyderproject 158 | .spyproject 159 | 160 | # Rope project settings 161 | .ropeproject 162 | 163 | # mkdocs documentation 164 | /site 165 | 166 | # mypy 167 | .mypy_cache/ 168 | .dmypy.json 169 | dmypy.json 170 | 171 | # Pyre type checker 172 | .pyre/ 173 | 174 | # pytype static type analyzer 175 | .pytype/ 176 | 177 | # Cython debug symbols 178 | cython_debug/ 179 | 180 | # PyCharm 181 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 182 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 183 | # and can be added to the global gitignore or merged into this file. For a more nuclear 184 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 
185 | #.idea/ -------------------------------------------------------------------------------- /.gitlab-ci.yml: -------------------------------------------------------------------------------- 1 | stages: 2 | - build 3 | - deploy 4 | variables: 5 | GIT_DEPTH: "1" 6 | 7 | before_script: 8 | - echo "$CI_REGISTRY_PASSWORD" | docker login -u "$CI_REGISTRY_USER" --password-stdin "$CI_REGISTRY" 9 | 10 | .build-image: 11 | stage: build 12 | script: 13 | - docker build --pull -t $TARGET_IMAGE . 14 | - docker push $TARGET_IMAGE 15 | 16 | build-x86: 17 | extends: .build-image 18 | tags: 19 | - docker 20 | variables: 21 | TARGET_IMAGE: $CI_REGISTRY_IMAGE:$CI_COMMIT_REF_SLUG 22 | 23 | .deploy: 24 | stage: deploy 25 | tags: 26 | - docker 27 | script: 28 | - docker pull $CI_REGISTRY_IMAGE:$CI_COMMIT_REF_SLUG 29 | - docker tag $CI_REGISTRY_IMAGE:$CI_COMMIT_REF_SLUG $TARGET_IMAGE 30 | - docker push $TARGET_IMAGE 31 | 32 | deploy_latest: 33 | extends: .deploy 34 | variables: 35 | TARGET_IMAGE: $CI_REGISTRY_IMAGE:latest 36 | only: 37 | - main 38 | 39 | deploy_branch: 40 | extends: .deploy 41 | variables: 42 | TARGET_IMAGE: $CI_REGISTRY_IMAGE:$CI_COMMIT_REF_SLUG 43 | -------------------------------------------------------------------------------- /Dockerfile: -------------------------------------------------------------------------------- 1 | FROM python:3.10.14-bookworm AS base 2 | LABEL Author="Hastur " 3 | 4 | WORKDIR /usr/src/app 5 | COPY ./ygoinf ./ 6 | COPY ./assets/log_conf.yaml ./ 7 | COPY ./scripts/code_list.txt ./ 8 | RUN pip install -e . 
9 | RUN wget https://github.com/sbl1996/ygo-agent/releases/download/v0.1/0546_26550M.tflite 10 | ENV CHECKPOINT=0546_26550M.tflite 11 | 12 | EXPOSE 3000 13 | CMD [ "uvicorn", "ygoinf.server:app", "--host", "0.0.0.0", "--port", "3000", "--log-config=log_conf.yaml" ] 14 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | ygo-agent 2 | --------- 3 | MIT License 4 | 5 | Copyright (c) 2024 Hastur 6 | 7 | Permission is hereby granted, free of charge, to any person obtaining a copy 8 | of this software and associated documentation files (the "Software"), to deal 9 | in the Software without restriction, including without limitation the rights 10 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 11 | copies of the Software, and to permit persons to whom the Software is 12 | furnished to do so, subject to the following conditions: 13 | 14 | The above copyright notice and this permission notice shall be included in all 15 | copies or substantial portions of the Software. 16 | 17 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 18 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 19 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 20 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 21 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 22 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 23 | SOFTWARE. 24 | 25 | 26 | Modified from sail-sg's envpool 27 | ------------------------------------------------------------------------------ 28 | Copyright 2021 Garena Online Private Limited 29 | 30 | Licensed under the Apache License, Version 2.0 (the "License"); 31 | you may not use this file except in compliance with the License. 
32 | You may obtain a copy of the License at 33 | 34 | http://www.apache.org/licenses/LICENSE-2.0 35 | 36 | Unless required by applicable law or agreed to in writing, software 37 | distributed under the License is distributed on an "AS IS" BASIS, 38 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 39 | See the License for the specific language governing permissions and 40 | limitations under the License. 41 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | SCRIPTS_REPO := "https://github.com/mycard/ygopro-scripts.git" 2 | SCRIPTS_DIR := "../ygopro-scripts" 3 | DATABASE_REPO := "https://github.com/mycard/ygopro-database/raw/7b1874301fc1aa52bd60585589f771e372ff52cc/locales" 4 | LOCALES := en zh 5 | 6 | .PHONY: all assets script py_install ygoenv_so clean dev 7 | 8 | all: assets script py_install 9 | 10 | dev: assets script py_install ygoenv_so 11 | 12 | py_install: 13 | pip install -e ygoenv 14 | pip install -e ygoinf 15 | pip install -e . 16 | 17 | ygoenv_so: ygoenv/ygoenv/ygopro/ygopro_ygoenv.so 18 | 19 | ygoenv/ygoenv/ygopro/ygopro_ygoenv.so: 20 | xmake b ygopro_ygoenv 21 | 22 | script : scripts/script 23 | 24 | scripts/script: 25 | if [ ! 
-d $(SCRIPTS_DIR) ] ; then git clone $(SCRIPTS_REPO) $(SCRIPTS_DIR); fi 26 | cd $(SCRIPTS_DIR) && git checkout 8e7fde9 27 | ln -sf "../$(SCRIPTS_DIR)" scripts/script 28 | 29 | assets: $(LOCALES) 30 | 31 | $(LOCALES): % : assets/locale/%/cards.cdb assets/locale/%/strings.conf 32 | 33 | assets/locale/en assets/locale/zh: 34 | mkdir -p $@ 35 | 36 | assets/locale/en/cards.cdb: | assets/locale/en 37 | wget -nv $(DATABASE_REPO)/en-US/cards.cdb -O $@ 38 | 39 | assets/locale/en/strings.conf: | assets/locale/en 40 | wget -nv $(DATABASE_REPO)/en-US/strings.conf -O $@ 41 | 42 | assets/locale/zh/cards.cdb: | assets/locale/zh 43 | wget -nv $(DATABASE_REPO)/zh-CN/cards.cdb -O $@ 44 | 45 | assets/locale/zh/strings.conf: | assets/locale/zh 46 | wget -nv $(DATABASE_REPO)/zh-CN/strings.conf -O $@ 47 | 48 | clean: 49 | rm -rf scripts/script 50 | rm -rf assets/locale/en assets/locale/zh -------------------------------------------------------------------------------- /assets/deck/BestowedDragon.ydk: -------------------------------------------------------------------------------- 1 | #created by ... 
2 | #main 3 | 27204311 4 | 46565218 5 | 46565218 6 | 91800273 7 | 91800273 8 | 33854624 9 | 91810826 10 | 91810826 11 | 91810826 12 | 39931513 13 | 39931513 14 | 39931513 15 | 65326118 16 | 65326118 17 | 14558127 18 | 14558127 19 | 14558127 20 | 23434538 21 | 23434538 22 | 23434538 23 | 18144506 24 | 14532163 25 | 14532163 26 | 12580477 27 | 12580477 28 | 35269904 29 | 91880660 30 | 73628505 31 | 24299458 32 | 24299458 33 | 66730191 34 | 66730191 35 | 66730191 36 | 47355498 37 | 30336082 38 | 30336082 39 | 30336082 40 | 10045474 41 | 10045474 42 | 10045474 43 | #extra 44 | 84815190 45 | 39402797 46 | 73218989 47 | 18969888 48 | 50954680 49 | 9012916 50 | 73580471 51 | 25862681 52 | 82570174 53 | 82570174 54 | 1686814 55 | 93039339 56 | 29301450 57 | 24361622 58 | 73539069 59 | !side 60 | 36668118 61 | 27204312 62 | -------------------------------------------------------------------------------- /assets/deck/BestowedDragon2.ydk: -------------------------------------------------------------------------------- 1 | #created by ... 
2 | #main 3 | 27204311 4 | 53804307 5 | 91810826 6 | 91810826 7 | 91810826 8 | 39931513 9 | 39931513 10 | 39931513 11 | 65326118 12 | 65326118 13 | 14558127 14 | 14558127 15 | 14558127 16 | 59438930 17 | 59438930 18 | 23434538 19 | 23434538 20 | 23434538 21 | 14532163 22 | 14532163 23 | 25311006 24 | 73628505 25 | 75500286 26 | 24224830 27 | 24224830 28 | 24299458 29 | 24299458 30 | 18144506 31 | 66730191 32 | 66730191 33 | 66730191 34 | 30336082 35 | 30336082 36 | 30336082 37 | 47355498 38 | 91880660 39 | 10045474 40 | 10045474 41 | 10045474 42 | 23002292 43 | #extra 44 | 73218989 45 | 40139997 46 | 18969888 47 | 39402797 48 | 84815190 49 | 50954680 50 | 9012916 51 | 82570174 52 | 82570174 53 | 33698022 54 | 25862681 55 | 68431965 56 | 1686814 57 | 24361622 58 | !side 59 | 73580471 60 | 27204312 61 | -------------------------------------------------------------------------------- /assets/deck/BestowedDragonF.ydk: -------------------------------------------------------------------------------- 1 | #created by ... 
2 | #main 3 | 27204311 4 | 55063751 5 | 78661338 6 | 33854624 7 | 46565218 8 | 46565218 9 | 46565218 10 | 91800273 11 | 91800273 12 | 91810826 13 | 91810826 14 | 91810826 15 | 97682931 16 | 39931513 17 | 39931513 18 | 39931513 19 | 65326118 20 | 65326118 21 | 14558127 22 | 14558127 23 | 14558127 24 | 23434538 25 | 23434538 26 | 23434538 27 | 94145021 28 | 4031928 29 | 4031928 30 | 25311006 31 | 73628505 32 | 8267140 33 | 48130397 34 | 66730191 35 | 66730191 36 | 66730191 37 | 73468603 38 | 30336082 39 | 30336082 40 | 30336082 41 | 56111151 42 | 10045474 43 | 10045474 44 | 10045474 45 | #extra 46 | 11765832 47 | 54757758 48 | 27572350 49 | 18969888 50 | 18969888 51 | 39402797 52 | 84815190 53 | 33698022 54 | 82570174 55 | 82570174 56 | 25862681 57 | 74997493 58 | 29301450 59 | 58699500 60 | 24361622 61 | !side 62 | 27204312 63 | -------------------------------------------------------------------------------- /assets/deck/Blackwing.ydk: -------------------------------------------------------------------------------- 1 | #created by ... 
2 | #main 3 | 75498415 4 | 81105204 5 | 81105204 6 | 81105204 7 | 58820853 8 | 58820853 9 | 58820853 10 | 49003716 11 | 49003716 12 | 49003716 13 | 14785765 14 | 85215458 15 | 85215458 16 | 85215458 17 | 2009101 18 | 2009101 19 | 2009101 20 | 22835145 21 | 22835145 22 | 22835145 23 | 73652465 24 | 1475311 25 | 1475311 26 | 53129443 27 | 5318639 28 | 5318639 29 | 14087893 30 | 27243130 31 | 27243130 32 | 91351370 33 | 91351370 34 | 91351370 35 | 53567095 36 | 53567095 37 | 53567095 38 | 53582587 39 | 59839761 40 | 72930878 41 | 72930878 42 | 84749824 43 | #extra 44 | 52687916 45 | 33236860 46 | 16051717 47 | 23338098 48 | 81983656 49 | 69031175 50 | 73580471 51 | 95040215 52 | 76913983 53 | 17377751 54 | 16195942 55 | 86848580 56 | 82633039 57 | 73347079 58 | 78156759 59 | !side 60 | -------------------------------------------------------------------------------- /assets/deck/BlueEyes.ydk: -------------------------------------------------------------------------------- 1 | #created by ... 
2 | #main 3 | 89631139 4 | 89631139 5 | 89631139 6 | 38517737 7 | 38517737 8 | 38517737 9 | 45467446 10 | 65681983 11 | 65681983 12 | 65681983 13 | 23434538 14 | 23434538 15 | 23434538 16 | 71039903 17 | 71039903 18 | 71039903 19 | 45644898 20 | 45644898 21 | 79814787 22 | 8240199 23 | 8240199 24 | 8240199 25 | 97268402 26 | 97268402 27 | 2295440 28 | 6853254 29 | 6853254 30 | 6853254 31 | 38120068 32 | 38120068 33 | 38120068 34 | 39701395 35 | 39701395 36 | 41620959 37 | 41620959 38 | 41620959 39 | 48800175 40 | 48800175 41 | 48800175 42 | 43898403 43 | 43898403 44 | 63356631 45 | 63356631 46 | #extra 47 | 40908371 48 | 40908371 49 | 59822133 50 | 59822133 51 | 50954680 52 | 83994433 53 | 33698022 54 | 39030163 55 | 31801517 56 | 02978414 57 | 63767246 58 | 64332231 59 | 10443957 60 | 41999284 61 | !side 62 | 8233522 63 | 8233522 64 | 14558127 65 | 14558127 66 | 14558127 67 | 56399890 68 | 53129443 69 | 25789292 70 | 25789292 71 | 43455065 72 | 43455065 73 | 43898403 74 | 11109820 75 | 11109820 76 | 11109820 77 | -------------------------------------------------------------------------------- /assets/deck/Branded.ydk: -------------------------------------------------------------------------------- 1 | #created by wxapp_ygo 2 | #main 3 | 25451383 4 | 25451383 5 | 32731036 6 | 35984222 7 | 60242223 8 | 60242223 9 | 48048590 10 | 62962630 11 | 62962630 12 | 62962630 13 | 68468459 14 | 68468459 15 | 68468459 16 | 45484331 17 | 19096726 18 | 95515789 19 | 45883110 20 | 14558127 21 | 14558127 22 | 14558127 23 | 23434538 24 | 23434538 25 | 23434538 26 | 36577931 27 | 25311006 28 | 35269904 29 | 44362883 30 | 06498706 31 | 06498706 32 | 06498706 33 | 01984618 34 | 01984618 35 | 75500286 36 | 81439173 37 | 82738008 38 | 29948294 39 | 36637374 40 | 24224830 41 | 24224830 42 | 18973184 43 | 01041278 44 | 10045474 45 | 10045474 46 | 10045474 47 | 17751597 48 | 17751597 49 | #extra 50 | 11321089 51 | 03410461 52 | 72272462 53 | 87746184 54 | 87746184 55 | 41373230 56 | 
51409648 57 | 01906812 58 | 24915933 59 | 70534340 60 | 44146295 61 | 44146295 62 | 92892239 63 | 38811586 64 | 53971455 65 | !side 66 | 27204312 -------------------------------------------------------------------------------- /assets/deck/Branded60.ydk: -------------------------------------------------------------------------------- 1 | #created by ... 2 | #main 3 | 27204311 4 | 55204071 5 | 69680031 6 | 25451383 7 | 25451383 8 | 32731036 9 | 33854624 10 | 60242223 11 | 60242223 12 | 19096726 13 | 82489470 14 | 62962630 15 | 62962630 16 | 60303688 17 | 55273560 18 | 95515789 19 | 95515789 20 | 45883110 21 | 45484331 22 | 68468459 23 | 68468459 24 | 68468459 25 | 14558127 26 | 14558127 27 | 14558127 28 | 23434538 29 | 23434538 30 | 23434538 31 | 36577931 32 | 94145021 33 | 94145021 34 | 35269904 35 | 75500286 36 | 81439173 37 | 6498706 38 | 6498706 39 | 6498706 40 | 1984618 41 | 1984618 42 | 1984618 43 | 11110587 44 | 11110587 45 | 34995106 46 | 44362883 47 | 24224830 48 | 24224830 49 | 27204311 50 | 36637374 51 | 29948294 52 | 29948294 53 | 82738008 54 | 18973184 55 | 10045474 56 | 81767888 57 | 6763530 58 | 1041278 59 | 19271881 60 | 32756828 61 | 17751597 62 | 17751597 63 | #extra 64 | 11321089 65 | 92892239 66 | 41373230 67 | 51409648 68 | 1906812 69 | 24915933 70 | 24915933 71 | 70534340 72 | 3410461 73 | 38811586 74 | 44146295 75 | 44146295 76 | 87746184 77 | 87746184 78 | 53971455 79 | !side 80 | 27204312 81 | 81767889 82 | -------------------------------------------------------------------------------- /assets/deck/CenturIon.ydk: -------------------------------------------------------------------------------- 1 | #created by ... 
2 | #main 3 | 27204311 4 | 32731036 5 | 78888899 6 | 97698279 7 | 6637331 8 | 33854624 9 | 15005145 10 | 15005145 11 | 15005145 12 | 42493140 13 | 42493140 14 | 42493140 15 | 14558127 16 | 14558127 17 | 14558127 18 | 77202120 19 | 23434538 20 | 23434538 21 | 23434538 22 | 94145021 23 | 94145021 24 | 97268402 25 | 97268402 26 | 97268402 27 | 25311006 28 | 4160316 29 | 24224830 30 | 24224830 31 | 77765207 32 | 77765207 33 | 77765207 34 | 34090915 35 | 41371602 36 | 41371602 37 | 41371602 38 | 10045474 39 | 10045474 40 | 10045474 41 | 40155014 42 | 77543769 43 | #extra 44 | 21123811 45 | 26268488 46 | 99585850 47 | 15982593 48 | 71858682 49 | 71858682 50 | 63436931 51 | 27572350 52 | 22850702 53 | 84815190 54 | 72444406 55 | 30983281 56 | 93039339 57 | 93854893 58 | 29301450 59 | !side 60 | 27204312 61 | -------------------------------------------------------------------------------- /assets/deck/CenturIon2.ydk: -------------------------------------------------------------------------------- 1 | #created by wxapp_ygo 2 | #main 3 | 42493140 4 | 42493140 5 | 42493140 6 | 15005145 7 | 15005145 8 | 15005145 9 | 97698279 10 | 23434538 11 | 23434538 12 | 23434538 13 | 14558127 14 | 14558127 15 | 14558127 16 | 97268402 17 | 97268402 18 | 27204311 19 | 27204311 20 | 41371602 21 | 41371602 22 | 41371602 23 | 77765207 24 | 77765207 25 | 77765207 26 | 92907248 27 | 81674782 28 | 81674782 29 | 81674782 30 | 35059553 31 | 35059553 32 | 35059553 33 | 24224830 34 | 24224830 35 | 84211599 36 | 73628505 37 | 40155014 38 | 40155014 39 | 10045474 40 | 10045474 41 | 10045474 42 | 82732705 43 | #extra 44 | 72444406 45 | 63436931 46 | 63436931 47 | 71858682 48 | 71858682 49 | 71858682 50 | 15982593 51 | 15982593 52 | 21123811 53 | 21123811 54 | 93039339 55 | 34755994 56 | 65741786 57 | 29301450 58 | 02857636 59 | !side 60 | 27204312 61 | 92907249 62 | -------------------------------------------------------------------------------- /assets/deck/Chimera.ydk: 
-------------------------------------------------------------------------------- 1 | #created by ... 2 | #main 3 | 27204311 4 | 27204311 5 | 27204311 6 | 58143852 7 | 58143852 8 | 58143852 9 | 55461744 10 | 55461744 11 | 28954097 12 | 28954097 13 | 28954097 14 | 23076639 15 | 23076639 16 | 23076639 17 | 61173621 18 | 61173621 19 | 84478195 20 | 92565383 21 | 92565383 22 | 92565383 23 | 34541543 24 | 14558127 25 | 14558127 26 | 14558127 27 | 23434538 28 | 23434538 29 | 23434538 30 | 94145021 31 | 94145021 32 | 97268402 33 | 97268402 34 | 24094653 35 | 24094653 36 | 25311006 37 | 25311006 38 | 34773082 39 | 34773082 40 | 34773082 41 | 24224830 42 | 24224830 43 | 63136489 44 | 63136489 45 | 63136489 46 | 53329234 47 | #extra 48 | 11321089 49 | 11321089 50 | 38264974 51 | 43227 52 | 69946549 53 | 69601012 54 | 69601012 55 | 1769875 56 | 1769875 57 | 11765832 58 | 22850702 59 | 93039339 60 | 27552504 61 | 29301450 62 | 71607202 63 | !side 64 | 5818798 65 | 77207191 66 | 4796100 67 | 27204312 68 | -------------------------------------------------------------------------------- /assets/deck/CyberDragon.ydk: -------------------------------------------------------------------------------- 1 | #created by wxapp_ygo 2 | #main 3 | 63941210 4 | 63941210 5 | 10604644 6 | 46659709 7 | 46659709 8 | 70095154 9 | 70095154 10 | 70095154 11 | 05370235 12 | 23434538 13 | 23434538 14 | 23434538 15 | 23893227 16 | 23893227 17 | 23893227 18 | 01142880 19 | 01142880 20 | 56364287 21 | 56364287 22 | 56364287 23 | 14532163 24 | 14532163 25 | 86686671 26 | 60600126 27 | 60600126 28 | 60600126 29 | 18144506 30 | 12580477 31 | 63995093 32 | 63995093 33 | 03659803 34 | 37630732 35 | 39973386 36 | 84797028 37 | 84797028 38 | 84797028 39 | 64753988 40 | 10045474 41 | 10045474 42 | 10045474 43 | #extra 44 | 01546123 45 | 87116928 46 | 74157028 47 | 79229522 48 | 79229522 49 | 84058253 50 | 84058253 51 | 22850702 52 | 90448279 53 | 10443957 54 | 73964868 55 | 58069384 56 | 70369116 57 | 46724542 58 
| 60303245 59 | !side -------------------------------------------------------------------------------- /assets/deck/Eld.ydk: -------------------------------------------------------------------------------- 1 | #created by ... 2 | #main 3 | 27204311 4 | 10000080 5 | 10000080 6 | 10000080 7 | 95440946 8 | 95440946 9 | 95440946 10 | 14558127 11 | 14558127 12 | 14558127 13 | 23434538 14 | 23434538 15 | 23434538 16 | 68829754 17 | 68829754 18 | 94224458 19 | 31434645 20 | 31434645 21 | 31434645 22 | 10045474 23 | 10045474 24 | 20612097 25 | 20612097 26 | 20612097 27 | 58921041 28 | 58921041 29 | 53334471 30 | 53334471 31 | 90846359 32 | 90846359 33 | 23516703 34 | 23516703 35 | 82732705 36 | 67007102 37 | 93191801 38 | 93191801 39 | 20590515 40 | 20590515 41 | 20590515 42 | 56984514 43 | #extra 44 | 74889525 45 | 62541668 46 | 90448279 47 | 26556950 48 | 26096328 49 | 49032236 50 | 56910167 51 | 56910167 52 | 56910167 53 | 73082255 54 | 37129797 55 | 03814632 56 | 72860663 57 | !side 58 | 27204312 59 | -------------------------------------------------------------------------------- /assets/deck/Floowandereeze.ydk: -------------------------------------------------------------------------------- 1 | #created by ... 
2 | #main 3 | 24508238 4 | 54334420 5 | 54334420 6 | 54334420 7 | 18940725 8 | 18940725 9 | 18940725 10 | 80433039 11 | 17827173 12 | 17827173 13 | 23434538 14 | 23434538 15 | 23434538 16 | 14558127 17 | 14558127 18 | 14558127 19 | 91800273 20 | 91800273 21 | 29587993 22 | 69327790 23 | 80611581 24 | 53212882 25 | 75500286 26 | 98645731 27 | 49238328 28 | 51697825 29 | 51697825 30 | 51697825 31 | 84211599 32 | 28126717 33 | 55521751 34 | 24224830 35 | 24224830 36 | 25311006 37 | 69087397 38 | 69087397 39 | 69087397 40 | 87639778 41 | 87639778 42 | 41215808 43 | #extra 44 | 35809262 45 | 35809262 46 | 81003500 47 | 81003500 48 | 81003500 49 | 76913983 50 | 76913983 51 | 76913983 52 | 48608796 53 | 48608796 54 | 48608796 55 | 90448279 56 | 94259633 57 | 4280258 58 | 98127546 59 | !side 60 | 21844576 61 | 58932615 62 | 86188410 63 | 89252153 64 | -------------------------------------------------------------------------------- /assets/deck/Floowandereeze2.ydk: -------------------------------------------------------------------------------- 1 | #created by wxapp_ygo 2 | #main 3 | 53212882 4 | 80611581 5 | 69327790 6 | 29587993 7 | 91800273 8 | 91800273 9 | 14558127 10 | 14558127 11 | 14558127 12 | 18940725 13 | 18940725 14 | 18940725 15 | 54334420 16 | 54334420 17 | 54334420 18 | 80433039 19 | 80433039 20 | 17827173 21 | 24508238 22 | 84211599 23 | 75500286 24 | 98645731 25 | 49238328 26 | 51697825 27 | 69087397 28 | 69087397 29 | 69087397 30 | 25311006 31 | 24224830 32 | 24224830 33 | 81674782 34 | 55521751 35 | 28126717 36 | 41215808 37 | 15693423 38 | 15693423 39 | 15693423 40 | 10045474 41 | 10045474 42 | 10045474 43 | #extra 44 | 90448279 45 | 72167543 46 | 72971064 47 | 48608796 48 | 98127546 49 | 86066372 50 | 04280258 51 | 21887175 52 | 38342335 53 | 08264361 54 | 48815792 55 | 65741786 56 | 02857636 57 | 75452921 58 | 94259633 59 | !side -------------------------------------------------------------------------------- /assets/deck/Hero.ydk: 
-------------------------------------------------------------------------------- 1 | #created by ... 2 | #main 3 | 18094166 4 | 18094166 5 | 18094166 6 | 22865492 7 | 22865492 8 | 27780618 9 | 27780618 10 | 40044918 11 | 40044918 12 | 9411399 13 | 9411399 14 | 89943723 15 | 16605586 16 | 50720316 17 | 59392529 18 | 14124483 19 | 83965310 20 | 23434538 21 | 23434538 22 | 23434538 23 | 14558127 24 | 14558127 25 | 14558127 26 | 94145021 27 | 94145021 28 | 8949584 29 | 8949584 30 | 8949584 31 | 21143940 32 | 21143940 33 | 21143940 34 | 45906428 35 | 52947044 36 | 24094653 37 | 24094653 38 | 24224830 39 | 24224830 40 | 81439173 41 | 32807846 42 | 75047173 43 | #extra 44 | 30757127 45 | 89870349 46 | 58481572 47 | 58481572 48 | 22908820 49 | 93347961 50 | 46759931 51 | 40854197 52 | 60461804 53 | 56733747 54 | 32828466 55 | 90590303 56 | 58004362 57 | 19324993 58 | 1948619 59 | !side 60 | -------------------------------------------------------------------------------- /assets/deck/Labrynth.ydk: -------------------------------------------------------------------------------- 1 | #created by ... 
2 | #main 3 | 2347656 4 | 81497285 5 | 81497285 6 | 75730490 7 | 1225009 8 | 1225009 9 | 1225009 10 | 73642296 11 | 62015408 12 | 62015408 13 | 14558127 14 | 14558127 15 | 14558127 16 | 37629703 17 | 37629703 18 | 23434538 19 | 23434538 20 | 23434538 21 | 74018812 22 | 74018812 23 | 97268402 24 | 2511 25 | 2511 26 | 84211599 27 | 24224830 28 | 24224830 29 | 33407125 30 | 4931121 31 | 83326048 32 | 6351147 33 | 30748475 34 | 10045474 35 | 10045474 36 | 10045474 37 | 5380979 38 | 5380979 39 | 92714517 40 | 92714517 41 | 92714517 42 | 82732705 43 | #extra 44 | 87746184 45 | 84815190 46 | 22850702 47 | 22850702 48 | 90590303 49 | 54498517 50 | 98127546 51 | 4280258 52 | 2772337 53 | 65741786 54 | 8264361 55 | 27381364 56 | 71607202 57 | 71607202 58 | 94259633 59 | !side 60 | 68468459 61 | -------------------------------------------------------------------------------- /assets/deck/MaxDragon.ydk: -------------------------------------------------------------------------------- 1 | #created by ... 
2 | #main 3 | 89631139 4 | 55410871 5 | 89631139 6 | 80701178 7 | 31036355 8 | 38517737 9 | 80701178 10 | 80701178 11 | 95492061 12 | 95492061 13 | 95492061 14 | 53303460 15 | 53303460 16 | 53303460 17 | 14558127 18 | 14558127 19 | 23434538 20 | 55410871 21 | 55410871 22 | 31036355 23 | 31036355 24 | 48800175 25 | 48800175 26 | 48800175 27 | 70368879 28 | 70368879 29 | 70368879 30 | 21082832 31 | 46052429 32 | 46052429 33 | 46052429 34 | 24224830 35 | 24224830 36 | 24224830 37 | 73915051 38 | 10045474 39 | 10045474 40 | 37576645 41 | 37576645 42 | 37576645 43 | #extra 44 | 31833038 45 | 85289965 46 | 74997493 47 | 5043010 48 | 65330383 49 | 38342335 50 | 2857636 51 | 75452921 52 | 3987233 53 | 3987233 54 | 99111753 55 | 98978921 56 | 41999284 57 | 41999284 58 | !side 59 | 73915052 60 | 73915053 61 | 73915054 62 | 73915055 63 | -------------------------------------------------------------------------------- /assets/deck/NatRunick.ydk: -------------------------------------------------------------------------------- 1 | #created by wxapp_ygo 2 | #main 3 | 93454062 4 | 93454062 5 | 93454062 6 | 07478431 7 | 23434538 8 | 23434538 9 | 23434538 10 | 14558127 11 | 14558127 12 | 14558127 13 | 29942771 14 | 29942771 15 | 29942771 16 | 35726888 17 | 92107604 18 | 24224830 19 | 24299458 20 | 94445733 21 | 66712905 22 | 66712905 23 | 68957034 24 | 68957034 25 | 68957034 26 | 30430448 27 | 30430448 28 | 20618850 29 | 20618850 30 | 67835547 31 | 67835547 32 | 93229151 33 | 93229151 34 | 93229151 35 | 31562086 36 | 31562086 37 | 34813545 38 | 34813545 39 | 34813545 40 | 03734202 41 | 03734202 42 | 03734202 43 | #extra 44 | 55990317 45 | 55990317 46 | 55990317 47 | 28373620 48 | 28373620 49 | 33198837 50 | 42566602 51 | 52445243 52 | 80666118 53 | 87188910 54 | 84815190 55 | 96633955 56 | 66011101 57 | 90590303 58 | 08728498 59 | !side 60 | 70902743 61 | -------------------------------------------------------------------------------- /assets/deck/Pachycephalo.ydk: 
-------------------------------------------------------------------------------- 1 | #created by ... 2 | #main 3 | 27204311 4 | 27204311 5 | 32909498 6 | 68304193 7 | 68304193 8 | 68304193 9 | 91800273 10 | 91800273 11 | 60303688 12 | 47961808 13 | 10963799 14 | 10963799 15 | 10963799 16 | 42009836 17 | 42009836 18 | 42009836 19 | 14558127 20 | 14558127 21 | 23434538 22 | 23434538 23 | 23434538 24 | 1984618 25 | 1984618 26 | 35261759 27 | 84211599 28 | 49238328 29 | 84797028 30 | 84797028 31 | 69540484 32 | 71832012 33 | 82956214 34 | 82956214 35 | 36975314 36 | 36975314 37 | 36975314 38 | 41420027 39 | 41420027 40 | 41420027 41 | 84749824 42 | 84749824 43 | #extra 44 | 11765832 45 | 11765832 46 | 80532587 47 | 80532587 48 | 96633955 49 | 84815190 50 | 98506199 51 | 90448279 52 | 48626373 53 | 48626373 54 | 80117527 55 | 80117527 56 | 10019086 57 | 34755994 58 | 34755994 59 | !side 60 | 73542331 61 | 27204311 62 | -------------------------------------------------------------------------------- /assets/deck/Shaddoll.ydk: -------------------------------------------------------------------------------- 1 | #created by ... 
2 | #main 3 | 3717252 4 | 3717252 5 | 3717252 6 | 77723643 7 | 77723643 8 | 30328508 9 | 30328508 10 | 59546797 11 | 97518132 12 | 34710660 13 | 51023024 14 | 51023024 15 | 4939890 16 | 4939890 17 | 4939890 18 | 59438930 19 | 59438930 20 | 69764158 21 | 24635329 22 | 37445295 23 | 37445295 24 | 23434538 25 | 23434538 26 | 1475311 27 | 11827244 28 | 44394295 29 | 44394295 30 | 44394295 31 | 53129443 32 | 81439173 33 | 6417578 34 | 6417578 35 | 48130397 36 | 23912837 37 | 23912837 38 | 77505534 39 | 77505534 40 | 77505534 41 | 4904633 42 | 40605147 43 | 40605147 44 | 84749824 45 | #extra 46 | 84433295 47 | 74822425 48 | 74822425 49 | 19261966 50 | 20366274 51 | 48424886 52 | 50907446 53 | 50907446 54 | 94977269 55 | 52687916 56 | 73580471 57 | 56832966 58 | 84013237 59 | 82633039 60 | !side 61 | -------------------------------------------------------------------------------- /assets/deck/Shiranui.ydk: -------------------------------------------------------------------------------- 1 | #created by ... 
2 | #main 3 | 55623480 4 | 52467217 5 | 52467217 6 | 52467217 7 | 92826944 8 | 92826944 9 | 92826944 10 | 41562624 11 | 41562624 12 | 99423156 13 | 99423156 14 | 94801854 15 | 94801854 16 | 94801854 17 | 49959355 18 | 49959355 19 | 49959355 20 | 79783880 21 | 14558127 22 | 14558127 23 | 14558127 24 | 36630403 25 | 36630403 26 | 23434538 27 | 23434538 28 | 23434538 29 | 97268402 30 | 12580477 31 | 18144506 32 | 75500286 33 | 81439173 34 | 13965201 35 | 13965201 36 | 24224830 37 | 24224830 38 | 40364916 39 | 40364916 40 | 4333086 41 | 4333086 42 | 10045474 43 | 10045474 44 | 40605147 45 | 40605147 46 | 41420027 47 | #extra 48 | 59843383 49 | 27548199 50 | 50954680 51 | 83283063 52 | 74586817 53 | 52711246 54 | 57288064 55 | 26326541 56 | 98558751 57 | 86066372 58 | 72860663 59 | 86926989 60 | 37129797 61 | 91420202 62 | 41999284 63 | !side 64 | -------------------------------------------------------------------------------- /assets/deck/SinfulSnake.ydk: -------------------------------------------------------------------------------- 1 | #created by ... 
2 | #main 3 | 27204311 4 | 48452496 5 | 48452496 6 | 72270339 7 | 72270339 8 | 72270339 9 | 14558127 10 | 14558127 11 | 14558127 12 | 23434538 13 | 23434538 14 | 23434538 15 | 35405755 16 | 45663742 17 | 9674034 18 | 9674034 19 | 9674034 20 | 90241276 21 | 90241276 22 | 9742784 23 | 12058741 24 | 94145021 25 | 94145021 26 | 97268402 27 | 97268402 28 | 97268402 29 | 2295440 30 | 24081957 31 | 89023486 32 | 89023486 33 | 24224830 34 | 24224830 35 | 26700718 36 | 80845034 37 | 80845034 38 | 80845034 39 | 53639887 40 | 10045474 41 | 10045474 42 | 38511382 43 | #extra 44 | 84815190 45 | 27548199 46 | 79606837 47 | 50091196 48 | 98127546 49 | 20665527 50 | 45112597 51 | 4280258 52 | 2772337 53 | 61245672 54 | 48815792 55 | 87871125 56 | 27381364 57 | 65741786 58 | 41999284 59 | !side 60 | 27204312 61 | -------------------------------------------------------------------------------- /assets/deck/SinfulSnake2.ydk: -------------------------------------------------------------------------------- 1 | #created by wxapp_ygo 2 | #main 3 | 97268402 4 | 97268402 5 | 97268402 6 | 09742784 7 | 09674034 8 | 09674034 9 | 09674034 10 | 45663742 11 | 90241276 12 | 90241276 13 | 90241276 14 | 23434538 15 | 23434538 16 | 23434538 17 | 14558127 18 | 14558127 19 | 14558127 20 | 52038441 21 | 52038441 22 | 52038441 23 | 06637331 24 | 33854624 25 | 72270339 26 | 72270339 27 | 72270339 28 | 48452496 29 | 48452496 30 | 27204311 31 | 27204311 32 | 27204311 33 | 89023486 34 | 89023486 35 | 02295440 36 | 53639887 37 | 80845034 38 | 80845034 39 | 80845034 40 | 10045474 41 | 10045474 42 | 10045474 43 | #extra 44 | 27548199 45 | 02772337 46 | 41999284 47 | 02857636 48 | 48815792 49 | 87871125 50 | 65741786 51 | 08264361 52 | 38342335 53 | 61245672 54 | 20665527 55 | 04280258 56 | 86066372 57 | 45112597 58 | 98127546 59 | !side 60 | 27204312 61 | -------------------------------------------------------------------------------- /assets/deck/SinfulSnakeKash.ydk: 
-------------------------------------------------------------------------------- 1 | #created by ... 2 | #main 3 | 97268402 4 | 97268402 5 | 9674034 6 | 9674034 7 | 9674034 8 | 45663742 9 | 90241276 10 | 90241276 11 | 90241276 12 | 23434538 13 | 23434538 14 | 23434538 15 | 14558127 16 | 14558127 17 | 14558127 18 | 6637331 19 | 32909498 20 | 68304193 21 | 68304193 22 | 72270339 23 | 72270339 24 | 72270339 25 | 48452496 26 | 48452496 27 | 27204311 28 | 89023486 29 | 89023486 30 | 25311006 31 | 24081957 32 | 53639887 33 | 69540484 34 | 80845034 35 | 80845034 36 | 80845034 37 | 24224830 38 | 24224830 39 | 24299458 40 | 10045474 41 | 10045474 42 | 10045474 43 | #extra 44 | 2772337 45 | 41999284 46 | 94259633 47 | 2857636 48 | 48815792 49 | 87871125 50 | 65741786 51 | 8264361 52 | 38342335 53 | 45819647 54 | 20665527 55 | 4280258 56 | 86066372 57 | 45112597 58 | 98127546 59 | !side 60 | 27204312 61 | -------------------------------------------------------------------------------- /assets/deck/SkyStrikerAce.ydk: -------------------------------------------------------------------------------- 1 | #created by wxapp_ygo 2 | #main 3 | 27204311 4 | 27204311 5 | 26077387 6 | 26077387 7 | 26077387 8 | 37351133 9 | 14558127 10 | 14558127 11 | 14558127 12 | 23434538 13 | 23434538 14 | 23434538 15 | 20357457 16 | 20357457 17 | 94145021 18 | 94145021 19 | 73594093 20 | 35726888 21 | 35726888 22 | 99550630 23 | 35261759 24 | 70368879 25 | 70368879 26 | 25311006 27 | 32807846 28 | 63166095 29 | 63166095 30 | 63166095 31 | 24224830 32 | 24224830 33 | 52340444 34 | 98338152 35 | 98338152 36 | 98338152 37 | 51227866 38 | 09726840 39 | 09726840 40 | 24010609 41 | 24010609 42 | 50005218 43 | #extra 44 | 86066372 45 | 75147529 46 | 29301450 47 | 98462037 48 | 63013339 49 | 63288573 50 | 63288573 51 | 63288573 52 | 90673288 53 | 90673288 54 | 90673288 55 | 08491308 56 | 08491308 57 | 08491308 58 | 12421694 59 | !side 60 | 52340445 61 | 27204312 62 | 
-------------------------------------------------------------------------------- /assets/deck/SnakeEyeAlter.ydk: -------------------------------------------------------------------------------- 1 | #created by ... 2 | #main 3 | 48452496 4 | 72270339 5 | 53143898 6 | 27132400 7 | 27132400 8 | 42790071 9 | 14558127 10 | 14558127 11 | 14558127 12 | 89538537 13 | 23434538 14 | 23434538 15 | 23434538 16 | 9674034 17 | 90241276 18 | 25533642 19 | 25533642 20 | 25533642 21 | 54126514 22 | 59185998 23 | 59185998 24 | 24508238 25 | 2295440 26 | 51405049 27 | 84211599 28 | 89023486 29 | 24224830 30 | 24224830 31 | 26700718 32 | 52340444 33 | 80845034 34 | 80845034 35 | 80845034 36 | 10045474 37 | 10045474 38 | 10045474 39 | 22024279 40 | 35146019 41 | 53329234 42 | 27541563 43 | #extra 44 | 93039339 45 | 61470213 46 | 20665527 47 | 86066372 48 | 4280258 49 | 2772337 50 | 93503294 51 | 45819647 52 | 48815792 53 | 29301450 54 | 1508649 55 | 1508649 56 | 63013339 57 | 41999284 58 | 94259633 59 | !side 60 | 52340445 61 | 27204312 62 | -------------------------------------------------------------------------------- /assets/deck/SnakeEyeFire.ydk: -------------------------------------------------------------------------------- 1 | #created by ... 
2 | #main 3 | 27204311 4 | 48452496 5 | 66431519 6 | 2526224 7 | 2526224 8 | 2526224 9 | 72270339 10 | 18621798 11 | 14558127 12 | 14558127 13 | 14558127 14 | 23434538 15 | 23434538 16 | 23434538 17 | 9674034 18 | 9674034 19 | 9674034 20 | 90241276 21 | 90241276 22 | 90241276 23 | 90681088 24 | 94145021 25 | 97268402 26 | 97268402 27 | 24081957 28 | 85106525 29 | 85106525 30 | 85106525 31 | 89023486 32 | 24224830 33 | 24224830 34 | 80845034 35 | 80845034 36 | 80845034 37 | 91703676 38 | 65305978 39 | 57554544 40 | 10045474 41 | 10045474 42 | 10045474 43 | #extra 44 | 93039339 45 | 64182380 46 | 57134592 47 | 20665527 48 | 45112597 49 | 4280258 50 | 2772337 51 | 2772337 52 | 61245672 53 | 8264361 54 | 48815792 55 | 87871125 56 | 29301450 57 | 65741786 58 | 41999284 59 | !side 60 | 27204312 61 | -------------------------------------------------------------------------------- /assets/deck/SnakeEyeFire2.ydk: -------------------------------------------------------------------------------- 1 | #created by ... 
2 | #main 3 | 66431519 4 | 66431519 5 | 66431519 6 | 2526224 7 | 2526224 8 | 2526224 9 | 72270339 10 | 18621798 11 | 18621798 12 | 18621798 13 | 14558127 14 | 14558127 15 | 14558127 16 | 23434538 17 | 23434538 18 | 23434538 19 | 9674034 20 | 9674034 21 | 90241276 22 | 90241276 23 | 90681088 24 | 90681088 25 | 90681088 26 | 12058741 27 | 12058741 28 | 22993208 29 | 22993208 30 | 73628505 31 | 85106525 32 | 85106525 33 | 85106525 34 | 89023486 35 | 89023486 36 | 59388357 37 | 59388357 38 | 91703676 39 | 65305978 40 | 65305978 41 | 57554544 42 | 57554544 43 | 57554544 44 | 38798785 45 | 38798785 46 | #extra 47 | 63767246 48 | 63767246 49 | 64182380 50 | 64182380 51 | 93854893 52 | 98127546 53 | 86066372 54 | 86066372 55 | 4280258 56 | 2772337 57 | 2772337 58 | 2772337 59 | 38342335 60 | 48815792 61 | 48815792 62 | !side 63 | -------------------------------------------------------------------------------- /assets/deck/SnakeEyeTear.ydk: -------------------------------------------------------------------------------- 1 | #created by ... 
2 | #main 3 | 27204311 4 | 48452496 5 | 72270339 6 | 72270339 7 | 32909498 8 | 4928565 9 | 5560911 10 | 572850 11 | 572850 12 | 73956664 13 | 73956664 14 | 73956664 15 | 42009836 16 | 37961969 17 | 97682931 18 | 14558127 19 | 14558127 20 | 14558127 21 | 23434538 22 | 23434538 23 | 23434538 24 | 45663742 25 | 9674034 26 | 9674034 27 | 90241276 28 | 90241276 29 | 24081957 30 | 33878367 31 | 89023486 32 | 24224830 33 | 24224830 34 | 60362066 35 | 80845034 36 | 80845034 37 | 80845034 38 | 6767771 39 | 53639887 40 | 71832012 41 | 77103950 42 | 7436169 43 | 38436986 44 | 74920585 45 | #extra 46 | 28226490 47 | 84330567 48 | 92731385 49 | 84815190 50 | 27548199 51 | 63533837 52 | 73082255 53 | 20665527 54 | 4280258 55 | 2772337 56 | 98095162 57 | 84271823 58 | 65741786 59 | 50277355 60 | 41999284 61 | !side 62 | 56099748 63 | 27204312 64 | -------------------------------------------------------------------------------- /assets/deck/TenyiSword.ydk: -------------------------------------------------------------------------------- 1 | #created by ... 
2 | #main 3 | 27204311 4 | 06728559 5 | 87052196 6 | 87052196 7 | 23431858 8 | 93490856 9 | 93490856 10 | 93490856 11 | 56495147 12 | 56495147 13 | 56495147 14 | 20001443 15 | 20001443 16 | 20001443 17 | 55273560 18 | 55273560 19 | 55273560 20 | 14558127 21 | 14558127 22 | 14558127 23 | 23434538 24 | 23434538 25 | 23434538 26 | 97268402 27 | 97268402 28 | 97268402 29 | 98159737 30 | 35261759 31 | 35261759 32 | 56465981 33 | 56465981 34 | 56465981 35 | 93850690 36 | 24224830 37 | 24224830 38 | 10045474 39 | 10045474 40 | 10045474 41 | 14821890 42 | 14821890 43 | #extra 44 | 42632209 45 | 60465049 46 | 96633955 47 | 84815190 48 | 47710198 49 | 9464441 50 | 5041348 51 | 69248256 52 | 69248256 53 | 83755611 54 | 43202238 55 | 32519092 56 | 32519092 57 | 32519092 58 | 78917791 59 | !side 60 | 20001444 61 | 27204312 62 | 93490857 63 | 56495148 64 | 14821891 65 | -------------------------------------------------------------------------------- /assets/deck/Voiceless.ydk: -------------------------------------------------------------------------------- 1 | #created by ... 
2 | #main 3 | 27204311 4 | 27204311 5 | 27204311 6 | 26866984 7 | 88284599 8 | 51296484 9 | 51296484 10 | 51296484 11 | 14558127 12 | 14558127 13 | 14558127 14 | 92919429 15 | 92919429 16 | 92919429 17 | 23434538 18 | 23434538 19 | 23434538 20 | 25801745 21 | 25801745 22 | 25801745 23 | 4810828 24 | 10804018 25 | 10774240 26 | 10774240 27 | 13048472 28 | 13048472 29 | 13048472 30 | 25311006 31 | 49238328 32 | 49238328 33 | 52472775 34 | 52472775 35 | 24224830 36 | 24224830 37 | 39114494 38 | 98477480 39 | 98477480 40 | 98477480 41 | 10045474 42 | 10045474 43 | 86310763 44 | #extra 45 | 80532587 46 | 22850702 47 | 22850702 48 | 79606837 49 | 93039339 50 | 98127546 51 | 73898890 52 | 73898890 53 | 9839945 54 | 29301450 55 | 29301450 56 | 29301450 57 | 71818935 58 | 41999284 59 | 94259633 60 | !side 61 | 27204312 62 | -------------------------------------------------------------------------------- /assets/deck/unsupported/Magician.ydk: -------------------------------------------------------------------------------- 1 | #created by ... 
2 | #main 3 | 14513273 4 | 14513273 5 | 14513273 6 | 76794549 7 | 22211622 8 | 96227613 9 | 96227613 10 | 14920218 11 | 69610326 12 | 69610326 13 | 40318957 14 | 40318957 15 | 40318957 16 | 72714461 17 | 72714461 18 | 72714461 19 | 49684352 20 | 49684352 21 | 49684352 22 | 73941492 23 | 27204311 24 | 14558127 25 | 14558127 26 | 14558127 27 | 23434538 28 | 23434538 29 | 23434538 30 | 01845204 31 | 25311006 32 | 41620959 33 | 41620959 34 | 81439173 35 | 24224830 36 | 24224830 37 | 65681983 38 | 82190203 39 | 82190203 40 | 55795155 41 | 74850403 42 | 10045474 43 | 01344018 44 | #extra 45 | 43387895 46 | 43387895 47 | 53262004 48 | 76815942 49 | 58074177 50 | 84815190 51 | 30095833 52 | 16691074 53 | 20665527 54 | 04280258 55 | 02772337 56 | 92812851 57 | 45819647 58 | 24094258 59 | 22125101 60 | !side 61 | 27204312 62 | -------------------------------------------------------------------------------- /assets/deck/unsupported/_tokens.ydk: -------------------------------------------------------------------------------- 1 | #main 2 | 176393 3 | 645088 4 | 904186 5 | 1426715 6 | 1799465 7 | 2625940 8 | 2819436 9 | 3285552 10 | 4417408 11 | 7392746 12 | 7610395 13 | 8025951 14 | 8198621 15 | 9047461 16 | 9396663 17 | 9925983 18 | 9929399 19 | 10389143 20 | 11050416 21 | 11050417 22 | 11050418 23 | 11050419 24 | 11654068 25 | 11738490 26 | 12958920 27 | 12965762 28 | 13536607 29 | 13764603 30 | 13935002 31 | 14089429 32 | 14470846 33 | 14470847 34 | 14821891 35 | 14957441 36 | 15341822 37 | 15341823 38 | 15394084 39 | 15590356 40 | 15629802 41 | 16943771 42 | 16946850 43 | 17000166 44 | 17228909 45 | 17418745 46 | 18027139 47 | 18027140 48 | 18027141 49 | 18494512 50 | 19280590 51 | 20001444 52 | 20368764 53 | 21179144 54 | 21770261 55 | 21830680 56 | 22110648 57 | 22404676 58 | 22411610 59 | 22493812 60 | 22953212 61 | 23116809 62 | 23331401 63 | 23837055 64 | 24874631 65 | 25415053 66 | 25419324 67 | 26326542 68 | 27198002 69 | 27204312 70 | 27450401 71 | 27882994 72 | 
28053764 73 | 28062326 74 | 28355719 75 | 28674153 76 | 29491335 77 | 29843092 78 | 29843093 79 | 29843094 80 | 30069399 81 | 30327675 82 | 30327676 83 | 30650148 84 | 30765616 85 | 30811117 86 | 31480216 87 | 31533705 88 | 31600514 89 | 31986289 90 | 32056071 91 | 32335698 92 | 32446631 93 | 33676147 94 | 34479659 95 | 34690954 96 | 34767866 97 | 34822851 98 | 35263181 99 | 35268888 100 | 35514097 101 | 35834120 102 | 36629636 103 | 38030233 104 | 38041941 105 | 38053382 106 | 39972130 107 | 40551411 108 | 40633085 109 | 40703223 110 | 40844553 111 | 41329459 112 | 41456842 113 | 42427231 114 | 42671152 115 | 42956964 116 | 43140792 117 | 43664495 118 | 44026394 119 | 44052075 120 | 44092305 121 | 44097051 122 | 44308318 123 | 44330099 124 | 44586427 125 | 44689689 126 | 46173680 127 | 46173681 128 | 46647145 129 | 47658965 130 | 48068379 131 | 48115278 132 | 48411997 133 | 49752796 134 | 49808197 135 | 51208047 136 | 51611042 137 | 51987572 138 | 52340445 139 | 52900001 140 | 53855410 141 | 53855411 142 | 54537490 143 | 55326323 144 | 56051649 145 | 56495148 146 | 56597273 147 | 58371672 148 | 59160189 149 | 59900656 150 | 60025884 151 | 60406592 152 | 60514626 153 | 60764582 154 | 60764583 155 | 62125439 156 | 62481204 157 | 62543394 158 | 63184228 159 | 63442605 160 | 64213018 161 | 64382840 162 | 64583601 163 | 65500516 164 | 65810490 165 | 66200211 166 | 66661679 167 | 67284108 168 | 67489920 169 | 67922703 170 | 67949764 171 | 68815402 172 | 69550260 173 | 69811711 174 | 69868556 175 | 69890968 176 | 70391589 177 | 70465811 178 | 70875956 179 | 70950699 180 | 71645243 181 | 72291079 182 | 73915052 183 | 73915053 184 | 73915054 185 | 73915055 186 | 74440056 187 | 74627017 188 | 74659583 189 | 74983882 190 | 75119041 191 | 75524093 192 | 75622825 193 | 75732623 194 | 76524507 195 | 76589547 196 | 77672445 197 | 78394033 198 | 78789357 199 | 78836196 200 | 79387393 201 | 81767889 202 | 82255873 203 | 82324106 204 | 82340057 205 | 82556059 206 | 82994510 207 | 
83239740 208 | 84816245 209 | 85243785 210 | 85771020 211 | 85771021 212 | 85969518 213 | 86801872 214 | 86871615 215 | 87240372 216 | 87669905 217 | 88923964 218 | 89907228 219 | 90884404 220 | 91512836 221 | 93104633 222 | 93130022 223 | 93224849 224 | 93490857 225 | 93912846 226 | 94703022 227 | 94973029 228 | 97452818 229 | 98596597 230 | 98875864 231 | 99092625 232 | 99137267 233 | #extra 234 | !side 235 | -------------------------------------------------------------------------------- /assets/deck/unused/OldSchool.ydk: -------------------------------------------------------------------------------- 1 | #created by ... 2 | #main 3 | 6631034 4 | 6631034 5 | 6631034 6 | 43096270 7 | 43096270 8 | 43096270 9 | 69247929 10 | 69247929 11 | 69247929 12 | 77542832 13 | 77542832 14 | 77542832 15 | 11091375 16 | 11091375 17 | 11091375 18 | 35052053 19 | 35052053 20 | 35052053 21 | 49881766 22 | 83104731 23 | 83104731 24 | 30190809 25 | 30190809 26 | 26412047 27 | 26412047 28 | 26412047 29 | 43422537 30 | 43422537 31 | 43422537 32 | 53129443 33 | 66788016 34 | 66788016 35 | 66788016 36 | 72302403 37 | 72302403 38 | 44095762 39 | 44095762 40 | 44095762 41 | 70342110 42 | 70342110 43 | #extra 44 | !side 45 | -------------------------------------------------------------------------------- /assets/log_conf.yaml: -------------------------------------------------------------------------------- 1 | version: 1 2 | disable_existing_loggers: False 3 | formatters: 4 | default: 5 | # "()": uvicorn.logging.DefaultFormatter 6 | format: '%(asctime)s - %(name)s - %(levelname)s - %(message)s' 7 | access: 8 | # "()": uvicorn.logging.AccessFormatter 9 | format: '%(asctime)s - %(name)s - %(levelname)s - %(message)s' 10 | handlers: 11 | default: 12 | formatter: default 13 | class: logging.StreamHandler 14 | stream: ext://sys.stderr 15 | access: 16 | formatter: access 17 | class: logging.StreamHandler 18 | stream: ext://sys.stdout 19 | loggers: 20 | uvicorn.error: 21 | level: INFO 22 | 
handlers: 23 | - default 24 | propagate: no 25 | uvicorn.access: 26 | level: INFO 27 | handlers: 28 | - access 29 | propagate: no 30 | root: 31 | level: INFO 32 | handlers: 33 | - default 34 | propagate: no -------------------------------------------------------------------------------- /docs/action.md: -------------------------------------------------------------------------------- 1 | # Action 2 | 3 | ## Types 4 | - Set + card 5 | - Reposition + card 6 | - Special summon + card 7 | - Summon Face-up Attack + card 8 | - Summon Face-down Defense + card 9 | - Attack + card 10 | - DirectAttack + card 11 | - Activate + card + effect 12 | - Cancel 13 | - Switch + phase 14 | - SelectPosition + card + position 15 | - AnnounceNumber + card + effect + number 16 | - SelectPlace + card + place 17 | - AnnounceAttrib + card + effect + attrib 18 | 19 | ## Effect 20 | 21 | ### MSG_SELECT_BATTLECMD | MSG_SELECT_IDLECMD | MSG_SELECT_CHAIN | MSG_SELECT_EFFECTYN 22 | - desc == 0: default effect of card 23 | - desc < LIMIT: system string 24 | - desc > LIMIT: card + effect 25 | 26 | ### MSG_SELECT_OPTION | MSG_SELECT_YESNO 27 | - desc == 0: error 28 | - desc < LIMIT: system string 29 | - desc > LIMIT: card + effect 30 | -------------------------------------------------------------------------------- /docs/feature_engineering.md: -------------------------------------------------------------------------------- 1 | # Features 2 | 3 | ## Definitions 4 | 5 | ### Float transform 6 | - float transform: max 65535 -> 2 bytes 7 | 8 | ### Card ID 9 | The card id is the index of the card code in `code_list.txt`. 
10 | 11 | ## Card 12 | - 0,1: card id, uint16 -> 2 uint8, name+desc 13 | - 2: location, discrete, 0: N/A, 1+: same as location2str (9) 14 | - 3: seq, discrete, 0: N/A, 1+: seq in location 15 | - 4: owner, discrete, 0: me, 1: oppo (2) 16 | - 5: position, discrete, 0: N/A, 1+: same as position2str 17 | - 6: overlay, discrete, 0: not, 1: xyz material 18 | - 7: attribute, discrete, 0: N/A, 1+: same as attribute2str 19 | - 8: race, discrete, 0: N/A, 1+: same as race2str 20 | - 9: level, discrete, 0: N/A 21 | - 10: counter, discrete, 0: N/A 22 | - 11: negated, discrete, 0: False, 1: True 23 | - 12,13: atk, float transform 24 | - 14,15: def: float transform 25 | - 16-40: type, multi-hot, same as type2str (25) 26 | 27 | ## Global 28 | - 0,1: my_lp, float transform 29 | - 2,3: op_lp, float transform 30 | - 4: turn, discrete, trunc to 16 31 | - 5: phase, discrete (10) 32 | - 6: is_first, discrete, 0: False, 1: True 33 | - 7: is_my_turn, discrete, 0: False, 1: True 34 | - 8: n_my_decks, count 35 | - 9: n_my_hands, count 36 | - 10: n_my_monsters, count 37 | - 11: n_my_spell_traps, count 38 | - 12: n_my_graves, count 39 | - 13: n_my_removes, count 40 | - 14: n_my_extras, count 41 | - 15: n_op_decks, count 42 | - 16: n_op_hands, count 43 | - 17: n_op_monsters, count 44 | - 18: n_op_spell_traps, count 45 | - 19: n_op_graves, count 46 | - 20: n_op_removes, count 47 | - 21: n_op_extras, count 48 | - 22: is_end, discrete, 0: False, 1: True 49 | 50 | 51 | ## Legal Actions 52 | - 0: spec index 53 | - 1,2: code, uint16 -> 2 uint8 54 | - 3: msg, discrete, 0: N/A, 1+: same as msg2str (15) 55 | - 4: act, discrete (11) 56 | - N/A 57 | - Set 58 | - Reposition 59 | - Special Summon 60 | - Summon Face-up Attack 61 | - Summon Face-down Defense 62 | - Attack 63 | - DirectAttack 64 | - Activate 65 | - Cancel 66 | - 5: finish, discrete (2) 67 | - N/A 68 | - Finish 69 | - 6: effect, discrete, 0: N/A 70 | - 7: phase, discrete (4) 71 | - N/A 72 | - Battle (b) 73 | - Main Phase 2 (m) 74 | - End Phase 
(e) 75 | - 8: position, discrete, 0: N/A, same as position2str 76 | - 9: number, discrete, 0: N/A 77 | - 10: place, discrete 78 | - 0: N/A 79 | - 1-7: m 80 | - 8-15: s 81 | - 16-22: om 82 | - 23-30: os 83 | - 11: attribute, discrete, 0: N/A, same as attribute2id 84 | 85 | 86 | ## History Actions 87 | - 0,1: card id, uint16 -> 2 uint8 88 | - 2-11 same as legal actions 89 | - 12: turn, discrete, trunc to 3 90 | - 13: phase, discrete (10) 91 | -------------------------------------------------------------------------------- /docs/network_design.md: -------------------------------------------------------------------------------- 1 | # Dimensions 2 | B: batch size 3 | C: number of channels 4 | H: number of history actions 5 | 6 | # Features 7 | f_cards: (B, n_cards, C), features of cards 8 | f_global: (B, C), global features 9 | f_h_actions: (B, H, C), features of history actions 10 | f_actions: (B, max_actions, C), features of current legal actions 11 | 12 | output: (B, max_actions, 1), value of each action 13 | 14 | # Fusion 15 | 16 | ## Method 1 17 | ``` 18 | f_cards -> n encoder layers -> f_cards 19 | f_global -> ResMLP -> f_global 20 | f_cards = f_cards + f_global 21 | f_actions -> n encoder layers -> f_actions 22 | 23 | f_cards[id] -> f_a_cards -> ResMLP -> f_a_cards 24 | f_actions = f_a_cards + f_a_feats 25 | 26 | f_actions, f_cards -> n decoder layers -> f_actions 27 | 28 | f_h_actions -> n encoder layers -> f_h_actions 29 | f_actions, f_h_actions -> n decoder layers -> f_actions 30 | 31 | f_actions -> MLP -> output 32 | ``` -------------------------------------------------------------------------------- /docs/support.md: -------------------------------------------------------------------------------- 1 | # Deck 2 | 3 | ## Unsupported 4 | - Magician (pendulum) 5 | 6 | # Messgae 7 | 8 | ## announce_attrib 9 | Only 1 attribute is announced at a time. 10 | Not supported: 11 | - DNA Checkup 12 | 13 | ## announce_number 14 | Only 1-12 is supported. 
-------------------------------------------------------------------------------- /mcts/mcts/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sbl1996/ygo-agent/dbf5142d49aab2e6beb4150788d4fffec39ae3e5/mcts/mcts/__init__.py -------------------------------------------------------------------------------- /mcts/mcts/alphazero/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sbl1996/ygo-agent/dbf5142d49aab2e6beb4150788d4fffec39ae3e5/mcts/mcts/alphazero/__init__.py -------------------------------------------------------------------------------- /mcts/mcts/alphazero/alphazero_mcts.pyi: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | import numpy 3 | __all__ = ['MinMaxStatsList', 'Roots', 'SearchResults', 'batch_backpropagate', 'batch_expand', 'batch_traverse', 'init_module'] 4 | class MinMaxStatsList: 5 | def __init__(self, arg0: int) -> None: 6 | ... 7 | def set_delta(self, arg0: float) -> None: 8 | ... 9 | class Roots: 10 | def __init__(self, arg0: int) -> None: 11 | ... 12 | def get_distributions(self) -> list[list[int]]: 13 | ... 14 | def get_values(self) -> list[float]: 15 | ... 16 | def prepare(self, arg0: numpy.ndarray[numpy.float32], arg1: numpy.ndarray[numpy.float32], arg2: numpy.ndarray[numpy.int32], arg3: numpy.ndarray[numpy.int32], arg4: float, arg5: float) -> None: 17 | ... 18 | @property 19 | def num(self) -> int: 20 | ... 21 | class SearchResults: 22 | def __init__(self, arg0: int) -> None: 23 | ... 24 | def get_search_len(self) -> list[int]: 25 | ... 26 | def batch_backpropagate(arg0: float, arg1: numpy.ndarray[numpy.float32], arg2: MinMaxStatsList, arg3: SearchResults) -> None: 27 | ... 
28 | def batch_expand(arg0: int, arg1: numpy.ndarray[bool], arg2: numpy.ndarray[numpy.float32], arg3: numpy.ndarray[numpy.float32], arg4: numpy.ndarray[numpy.int32], arg5: numpy.ndarray[numpy.int32], arg6: SearchResults) -> None: 29 | ... 30 | def batch_traverse(arg0: Roots, arg1: int, arg2: float, arg3: float, arg4: MinMaxStatsList, arg5: SearchResults) -> tuple: 31 | ... 32 | def init_module(seed: int) -> None: 33 | ... 34 | -------------------------------------------------------------------------------- /mcts/mcts/alphazero/cnode.h: -------------------------------------------------------------------------------- 1 | #ifndef AZ_CNODE_H 2 | #define AZ_CNODE_H 3 | 4 | #include 5 | #include 6 | #include 7 | #include 8 | #include 9 | #include 10 | #include 11 | #include 12 | #include 13 | #include 14 | 15 | #include "mcts/core/minimax.h" 16 | #include "mcts/core/array.h" 17 | 18 | const int DEBUG_MODE = 0; 19 | 20 | namespace tree { 21 | 22 | void init_module(int seed); 23 | 24 | using Action = int; 25 | 26 | class Node { 27 | public: 28 | int visit_count, state_index, batch_index, best_action; 29 | float reward, prior, value_sum; 30 | std::map children; 31 | 32 | Node(); 33 | Node(float prior); 34 | ~Node(); 35 | 36 | void expand( 37 | int state_index, int batch_index, float reward, const Array &logits, const Array &legal_actions); 38 | void add_exploration_noise(float exploration_fraction, float dirichlet_alpha); 39 | float compute_mean_q(int isRoot, float parent_q, float discount_factor); 40 | 41 | int expanded() const; 42 | float value() const; 43 | std::vector get_trajectory(); 44 | std::vector get_children_distribution(); 45 | Node* get_child(int action); 46 | }; 47 | 48 | class Roots{ 49 | public: 50 | int root_num; 51 | std::vector roots; 52 | 53 | Roots(); 54 | Roots(int root_num); 55 | ~Roots(); 56 | 57 | void prepare( 58 | const Array &rewards, const Array &logits, 59 | const Array &all_legal_actions, const Array &n_legal_actions, 60 | float 
exploration_fraction, float dirichlet_alpha); 61 | void clear(); 62 | std::vector > get_trajectories(); 63 | std::vector > get_distributions(); 64 | std::vector get_values(); 65 | 66 | }; 67 | 68 | class SearchResults{ 69 | public: 70 | int num; 71 | std::vector state_index_in_search_path, state_index_in_batch, last_actions, search_lens; 72 | std::vector nodes; 73 | std::vector > search_paths; 74 | 75 | SearchResults(); 76 | SearchResults(int num); 77 | ~SearchResults(); 78 | 79 | }; 80 | 81 | 82 | void update_tree_q(Node* root, MinMaxStats &min_max_stats, float discount_factor); 83 | void backpropagate(std::vector &search_path, MinMaxStats &min_max_stats, float value, float discount_factor); 84 | void batch_expand( 85 | int state_index, const Array &game_over, const Array &rewards, const Array &logits /* 2D array */, 86 | const Array &all_legal_actions, const Array &n_legal_actions, SearchResults &results); 87 | void batch_backpropagate(float discount_factor, const Array &values, MinMaxStatsList &min_max_stats_lst, SearchResults &results); 88 | int select_child(Node* root, const MinMaxStats &min_max_stats, int pb_c_base, float pb_c_init, float discount_factor, float mean_q); 89 | float ucb_score(const Node &child, const MinMaxStats &min_max_stats, float parent_mean_q, float total_children_visit_counts, float pb_c_base, float pb_c_init, float discount_factor); 90 | void batch_traverse(Roots &roots, int pb_c_base, float pb_c_init, float discount_factor, MinMaxStatsList &min_max_stats_lst, SearchResults &results); 91 | } 92 | 93 | #endif // AZ_CNODE_H -------------------------------------------------------------------------------- /mcts/mcts/alphazero/tree.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include 5 | 6 | #include "mcts/core/common.h" 7 | #include "mcts/core/minimax.h" 8 | #include "mcts/core/array.h" 9 | 10 | #include "mcts/alphazero/cnode.h" 11 | 12 | namespace py = 
pybind11; 13 | 14 | /* Python bindings for the alphazero_mcts extension: exposes MinMaxStatsList, SearchResults, Roots and the batch_* tree operations to Python. */ PYBIND11_MODULE(alphazero_mcts, m) { 15 | using namespace pybind11::literals; 16 | 17 | py::class_(m, "MinMaxStatsList") 18 | .def(py::init()) 19 | .def("set_delta", &MinMaxStatsList::set_delta); 20 | 21 | py::class_(m, "SearchResults") 22 | .def(py::init()) 23 | .def("get_search_len", [](tree::SearchResults &results) { 24 | return results.search_lens; 25 | }); 26 | 27 | py::class_(m, "Roots") 28 | .def(py::init()) 29 | .def_readonly("num", &tree::Roots::root_num) 30 | .def("prepare", []( 31 | tree::Roots &roots, const py::array_t &rewards, 32 | const py::array_t &logits, const py::array_t &all_legal_actions, 33 | const py::array_t &n_legal_actions, float exploration_fraction, 34 | float dirichlet_alpha) { 35 | /* Wrap the numpy inputs as non-owning Array views (no copy) and forward them to Roots::prepare. */ Array rewards_ = NumpyToArray(rewards); 36 | Array logits_ = NumpyToArray(logits); 37 | Array all_legal_actions_ = NumpyToArray(all_legal_actions); 38 | Array n_legal_actions_ = NumpyToArray(n_legal_actions); 39 | roots.prepare(rewards_, logits_, all_legal_actions_, n_legal_actions_, exploration_fraction, dirichlet_alpha); 40 | }) 41 | .def("get_distributions", &tree::Roots::get_distributions) 42 | .def("get_values", &tree::Roots::get_values); 43 | 44 | /* Same pattern: numpy buffers -> non-owning Array views -> tree::batch_expand. */ m.def("batch_expand", []( 45 | int state_index, const py::array_t &game_over, const py::array_t &rewards, const py::array_t &logits, 46 | const py::array_t &all_legal_actions, const py::array_t &n_legal_actions, tree::SearchResults &results) { 47 | Array game_over_ = NumpyToArray(game_over); 48 | Array rewards_ = NumpyToArray(rewards); 49 | Array logits_ = NumpyToArray(logits); 50 | Array all_legal_actions_ = NumpyToArray(all_legal_actions); 51 | Array n_legal_actions_ = NumpyToArray(n_legal_actions); 52 | tree::batch_expand(state_index, game_over_, rewards_, 
logits_, all_legal_actions_, n_legal_actions_, results); 53 | }); 54 | 55 | m.def("batch_backpropagate", []( 56 | float discount_factor, const py::array_t &values, 57 | MinMaxStatsList &min_max_stats_lst, tree::SearchResults &results)
{ 58 | Array values_ = NumpyToArray(values); 59 | tree::batch_backpropagate(discount_factor, values_, min_max_stats_lst, results); 60 | }); 61 | 62 | /* Runs tree::batch_traverse and returns (state_index_in_search_path, state_index_in_batch, last_actions) taken from `results`. */ m.def("batch_traverse", []( 63 | tree::Roots &roots, int pb_c_base, float pb_c_init, float discount_factor, 64 | MinMaxStatsList &min_max_stats_lst, tree::SearchResults &results) { 65 | tree::batch_traverse(roots, pb_c_base, pb_c_init, discount_factor, min_max_stats_lst, results); 66 | return py::make_tuple(results.state_index_in_search_path, results.state_index_in_batch, results.last_actions); 67 | }); 68 | 69 | m.def("init_module", &tree::init_module, "", "seed"_a); 70 | } -------------------------------------------------------------------------------- /mcts/mcts/core/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sbl1996/ygo-agent/dbf5142d49aab2e6beb4150788d4fffec39ae3e5/mcts/mcts/core/__init__.py -------------------------------------------------------------------------------- /mcts/mcts/core/array.h: -------------------------------------------------------------------------------- 1 | #ifndef MCTS_CORE_ARRAY_H_ 2 | #define MCTS_CORE_ARRAY_H_ 3 | 4 | #include 5 | #include 6 | #include 7 | #include 8 | #include 9 | 10 | #include "mcts/core/spec.h" 11 | 12 | /* Type-erased N-d array: raw char storage held in a shared_ptr plus a shape and a per-element byte size. Whether the memory is owned depends on the deleter given at construction. */ class Array { 13 | public: 14 | std::size_t size; 15 | std::size_t ndim; 16 | std::size_t element_size; 17 | 18 | protected: 19 | std::vector shape_; 20 | std::shared_ptr ptr_; 21 | 22 | template 23 | Array(char *ptr, Shape &&shape, std::size_t element_size,// NOLINT 24 | Deleter &&deleter) 25 | : size(Prod(shape.data(), shape.size())), 26 | ndim(shape.size()), 27 | element_size(element_size), 28 | shape_(std::forward(shape)), 29 | ptr_(ptr, std::forward(deleter)) {} 30 | 31 | template 32 | 
Array(std::shared_ptr ptr, Shape &&shape, std::size_t element_size) 33 | : size(Prod(shape.data(), shape.size())), 34 | ndim(shape.size()), 35 | element_size(element_size), 36 | 
shape_(std::forward(shape)), 37 | ptr_(std::move(ptr)) {} 38 | 39 | public: 40 | Array() = default; 41 | 42 | /** 43 | * Construct an `Array` of shape defined by `spec`, with `data` as pointer 44 | * to its raw memory. Ownership is decided by `deleter`; the two-argument 45 | * overload below uses an empty deleter, i.e. the Array does not own the memory. 46 | */ 47 | template 48 | Array(const ShapeSpec &spec, char *data, Deleter &&deleter)// NOLINT 49 | : Array(data, spec.Shape(), spec.element_size, 50 | std::forward(deleter)) {} 51 | 52 | Array(const ShapeSpec &spec, char *data) 53 | : Array(data, spec.Shape(), spec.element_size, [](char * /*unused*/) {}) {} 54 | 55 | /** 56 | * Construct an `Array` of shape defined by `spec`. This constructor 57 | * allocates (zero-initialized) and owns the memory. 58 | */ 59 | explicit Array(const ShapeSpec &spec) 60 | : Array(spec, nullptr, [](char * /*unused*/) {}) { 61 | ptr_.reset(new char[size * element_size](), 62 | [](const char *p) { delete[] p; }); 63 | } 64 | 65 | /** 66 | * Take a multidimensional, row-major index into the Array and return a non-owning view of the remaining axes (indices are not bounds-checked). 67 | */ 68 | template 69 | inline Array operator()(Index... index) const { 70 | constexpr std::size_t num_index = sizeof...(Index); 71 | std::size_t offset = 0; 72 | std::size_t i = 0; 73 | for (((offset = offset * shape_[i++] + index), ...); i < ndim; ++i) { 74 | offset *= shape_[i]; 75 | } 76 | return Array( 77 | ptr_.get() + offset * element_size, 78 | std::vector(shape_.begin() + num_index, shape_.end()), 79 | element_size, [](char * /*unused*/) {}); 80 | } 81 | 82 | /** 83 | * Index operator of array, takes the index along the first axis. 84 | */ 85 | inline Array operator[](int index) const { return this->operator()(index); } 86 | 87 | /** 88 | * Take a slice [start, end) along the first axis; returns a non-owning view (bounds are not checked).
89 | */ 90 | [[nodiscard]] Array Slice(std::size_t start, std::size_t end) const { 91 | std::vector new_shape(shape_); 92 | new_shape[0] = end - start; 93 | std::size_t offset = 0; 94 | if (shape_[0] > 0) { 95 | offset = start * size / shape_[0]; 96 | } 97 | return {ptr_.get() + offset * element_size, std::move(new_shape), 98 | element_size, [](char *p) {}}; 99 | } 100 | 101 | /** 102 | * Copy the content of another Array to this Array. Copies size * element_size bytes (this Array's size), so `value` is assumed to be at least that large (not checked). 103 | */ 104 | void Assign(const Array &value) const { 105 | std::memcpy(ptr_.get(), value.ptr_.get(), size * element_size); 106 | } 107 | 108 | /** 109 | * Return a deep copy of this array in newly allocated, owned memory. 110 | */ 111 | Array Clone() const { 112 | std::vector shape; 113 | for (int i = 0; i < ndim; i++) { 114 | shape.push_back(shape_[i]); 115 | } 116 | auto spec = ShapeSpec(element_size, shape); 117 | Array ret(spec); 118 | ret.Assign(*this); 119 | return ret; 120 | } 121 | 122 | /** 123 | * Assign to this Array a scalar value. This Array needs to have a scalar 124 | * shape. 125 | */ 126 | template 127 | void operator=(const T &value) const { 128 | *reinterpret_cast(ptr_.get()) = value; 129 | } 130 | 131 | /** 132 | * Fills this array with a scalar value of type T. 133 | */ 134 | template 135 | void Fill(const T &value) const { 136 | auto *data = reinterpret_cast(ptr_.get()); 137 | std::fill(data, data + size, value); 138 | } 139 | 140 | /** 141 | * Copy `sz` elements of type T from `buff` into the start of this Array's 142 | * memory (no size check). 143 | */ 144 | template 145 | void Assign(const T *buff, std::size_t sz) const { 146 | std::memcpy(ptr_.get(), buff, sz * sizeof(T)); 147 | } 148 | /* Like Assign(buff, sz), but starts writing `offset` elements (measured in this Array's element_size) into the buffer. 
*/ 149 | template 150 | void Assign(const T *buff, std::size_t sz, ptrdiff_t offset) const { 151 | offset = offset * (element_size / sizeof(char)); 152 | std::memcpy(ptr_.get() + offset, buff, sz * sizeof(T)); 153 | } 154 | 155 | /** 156 | * Cast the Array to a scalar value of type `T`. This Array needs to have a 157 | * scalar shape.
158 | */ 159 | template 160 | operator const T &() const {// NOLINT 161 | return *reinterpret_cast(ptr_.get()); 162 | } 163 | 164 | /** 165 | * Cast the Array to a scalar value of type `T`. This Array needs to have a 166 | * scalar shape. 167 | */ 168 | template 169 | operator T &() {// NOLINT 170 | return *reinterpret_cast(ptr_.get()); 171 | } 172 | 173 | /** 174 | * Size of axis `dim`. 175 | */ 176 | [[nodiscard]] inline std::size_t Shape(std::size_t dim) const { 177 | return shape_[dim]; 178 | } 179 | 180 | /** 181 | * Full shape (one entry per axis). 182 | */ 183 | [[nodiscard]] inline const std::vector &Shape() const { 184 | return shape_; 185 | } 186 | 187 | /** 188 | * Pointer to the raw memory. 189 | */ 190 | [[nodiscard]] inline void *Data() const { return ptr_.get(); } 191 | 192 | /** 193 | * Truncate the Array. Return a new Array that shares the same memory 194 | * location but with a truncated shape. Unlike Slice, the result also shares ownership of the buffer (it copies ptr_). 195 | */ 196 | [[nodiscard]] Array Truncate(std::size_t end) const { 197 | auto new_shape = std::vector(shape_); 198 | new_shape[0] = end; 199 | Array ret(ptr_, std::move(new_shape), element_size); 200 | return ret; 201 | } 202 | 203 | void Zero() const { std::memset(ptr_.get(), 0, size * element_size); } 204 | [[nodiscard]] std::shared_ptr SharedPtr() const { return ptr_; } 205 | }; 206 | 207 | template 208 | class TArray : public Array { 209 | public: 210 | explicit TArray(const Spec &spec) : Array(spec) {} 211 | }; 212 | 213 | #endif // MCTS_CORE_ARRAY_H_ -------------------------------------------------------------------------------- /mcts/mcts/core/common.h: -------------------------------------------------------------------------------- 1 | #ifndef MCTS_CORE_COMMON_H_ 2 | #define MCTS_CORE_COMMON_H_ 3 | 4 | #include 5 | #include 6 | #include 7 | 8 | #include "mcts/core/array.h" 9 | 10 | namespace py = pybind11; 11 | /* Expose an Array to Python without copying: the capsule owns a heap-allocated copy of the Array's shared_ptr, so the buffer 
stays alive as long as the numpy array does. */ 12 | template 13 | py::array_t ArrayToNumpy(const Array &a) { 14 | auto *ptr = new std::shared_ptr(a.SharedPtr()); 15 | auto capsule = py::capsule(ptr, [](void
*ptr) { 16 | delete reinterpret_cast *>(ptr); 17 | }); 18 | return py::array(a.Shape(), reinterpret_cast(a.Data()), capsule); 19 | } 20 | /* Wrap a numpy array as a non-owning Array view (empty deleter, no copy); the numpy array must outlive the returned Array. */ 21 | template 22 | Array NumpyToArray(const py::array_t &arr) { 23 | using ArrayT = py::array_t; 24 | ArrayT arr_t(arr); 25 | ShapeSpec spec(arr_t.itemsize(), 26 | std::vector(arr_t.shape(), arr_t.shape() + arr_t.ndim())); 27 | return {spec, reinterpret_cast(arr_t.mutable_data())}; 28 | } 29 | 30 | /* Forward a sequence onto the heap and return it as a 1-D numpy array. NOTE(review): pybind11's array_t constructor copies the data when no base handle is given, so freeing seq_ptr on return should be safe -- confirm for the pybind11 version in use. */ 31 | template 32 | inline py::array_t as_pyarray(Sequence &&seq) { 33 | using T = typename Sequence::value_type; 34 | std::unique_ptr seq_ptr = std::make_unique(std::forward(seq)); 35 | return py::array_t({seq_ptr->size()}, {sizeof(T)}, seq_ptr->data()); 36 | } 37 | 38 | 39 | #endif // MCTS_CORE_COMMON_H_ -------------------------------------------------------------------------------- /mcts/mcts/core/minimax.h: -------------------------------------------------------------------------------- 1 | #ifndef MCTS_CORE_MINIMAX_H_ 2 | #define MCTS_CORE_MINIMAX_H_ 3 | 4 | #include 5 | #include 6 | /* Sentinels for an empty MinMaxStats: maximum starts at FLOAT_MIN and minimum at FLOAT_MAX. */ 7 | const float FLOAT_MAX = 1000000.0; 8 | const float FLOAT_MIN = -FLOAT_MAX; 9 | /* Running minimum/maximum of observed values, used by normalize(). */ 10 | class MinMaxStats { 11 | public: 12 | float maximum, minimum, value_delta_max; 13 | 14 | MinMaxStats() { 15 | this->maximum = FLOAT_MIN; 16 | this->minimum = FLOAT_MAX; 17 | this->value_delta_max = 0.; 18 | } 19 | ~MinMaxStats() {} 20 | 21 | void set_delta(float value_delta_max) { 22 | this->value_delta_max = value_delta_max; 23 | } 24 | void update(float value) { 25 | if(value > this->maximum){ 26 | this->maximum = 
value; 27 | } 28 | if(value < this->minimum){ 29 | this->minimum = value; 30 | } 31 | } 32 | void clear() { 33 | this->maximum = FLOAT_MIN; 34 | this->minimum = FLOAT_MAX; 35 | } 36 | /* Returns (value - minimum) / max(maximum - minimum, value_delta_max) once a positive range has been observed; otherwise returns value unchanged. */ float normalize(float value) const { 37 | float norm_value = value; 38 | float delta = this->maximum - this->minimum; 39 | if(delta > 0){ 40 | if(delta < this->value_delta_max){ 41 | norm_value = (norm_value - this->minimum) / this->value_delta_max; 42 | } 43 | else{ 44 | 
norm_value = (norm_value - this->minimum) / delta; 45 | } 46 | } 47 | return norm_value; 48 | } 49 | }; 50 | /* Fixed-size list of `num` independent MinMaxStats. */ 51 | class MinMaxStatsList { 52 | public: 53 | int num; 54 | std::vector stats_lst; 55 | 56 | MinMaxStatsList() { 57 | this->num = 0; 58 | } 59 | MinMaxStatsList(int num) { 60 | this->num = num; 61 | for(int i = 0; i < num; ++i){ 62 | this->stats_lst.push_back(MinMaxStats()); 63 | } 64 | } 65 | ~MinMaxStatsList() {} 66 | 67 | void set_delta(float value_delta_max) { 68 | for(int i = 0; i < this->num; ++i){ 69 | this->stats_lst[i].set_delta(value_delta_max); 70 | } 71 | } 72 | }; 73 | 74 | #endif // MCTS_CORE_MINIMAX_H_ -------------------------------------------------------------------------------- /mcts/mcts/core/spec.h: -------------------------------------------------------------------------------- 1 | #ifndef MCTS_CORE_SPEC_H_ 2 | #define MCTS_CORE_SPEC_H_ 3 | 4 | #include 5 | #include 6 | #include 7 | #include 8 | #include 9 | #include 10 | #include 11 | #include 12 | #include 13 | /* Number of elements for a shape: product of its first ndim entries (1 when ndim == 0). */ 14 | static std::size_t Prod(const std::size_t *shape, std::size_t ndim) { 15 | return std::accumulate(shape, shape + ndim, static_cast(1), 16 | std::multiplies<>()); 17 | } 18 | /* Untyped array description: per-element byte size plus shape. Batch() prepends a leading batch dimension. 
*/ 19 | class ShapeSpec { 20 | public: 21 | int element_size; 22 | std::vector shape; 23 | ShapeSpec() = default; 24 | ShapeSpec(int element_size, std::vector shape_vec) 25 | : element_size(element_size), shape(std::move(shape_vec)) {} 26 | [[nodiscard]] ShapeSpec Batch(int batch_size) const { 27 | std::vector new_shape = {batch_size}; 28 | new_shape.insert(new_shape.end(), shape.begin(), shape.end()); 29 | return {element_size, std::move(new_shape)}; 30 | } 31 | [[nodiscard]] std::vector Shape() const { 32 | auto s = std::vector(shape.size()); 33 | for (std::size_t i = 0; i < shape.size(); ++i) { 34 | s[i] = shape[i]; 35 | } 36 | return s; 37 | } 38 | }; 39 | 40 | /* Typed spec: element_size is sizeof(D); optionally carries scalar `bounds` or per-element `elementwise_bounds`. */ 41 | template 42 | class Spec : public ShapeSpec { 43 | public: 44 | using dtype = D;// NOLINT 45 | std::tuple bounds = 
{std::numeric_limits::min(), 46 | std::numeric_limits::max()}; 47 | std::tuple, std::vector> elementwise_bounds; 48 | explicit Spec(std::vector &&shape) 49 | : ShapeSpec(sizeof(dtype), std::move(shape)) {} 50 | explicit Spec(const std::vector &shape) 51 | : ShapeSpec(sizeof(dtype), shape) {} 52 | 53 | /* init with constant bounds */ 54 | Spec(std::vector &&shape, std::tuple &&bounds) 55 | : ShapeSpec(sizeof(dtype), std::move(shape)), bounds(std::move(bounds)) {} 56 | Spec(const std::vector &shape, const std::tuple &bounds) 57 | : ShapeSpec(sizeof(dtype), shape), bounds(bounds) {} 58 | 59 | /* init with elementwise bounds */ 60 | Spec(std::vector &&shape, 61 | std::tuple, std::vector> &&elementwise_bounds) 62 | : ShapeSpec(sizeof(dtype), std::move(shape)), 63 | elementwise_bounds(std::move(elementwise_bounds)) {} 64 | Spec(const std::vector &shape, 65 | const std::tuple, std::vector> & 66 | elementwise_bounds) 67 | : ShapeSpec(sizeof(dtype), shape), 68 | elementwise_bounds(elementwise_bounds) {} 69 | /* NOTE(review): Batch() rebuilds the Spec from the shape only, so bounds / elementwise_bounds are not carried over. */ 70 | [[nodiscard]] Spec Batch(int batch_size) const { 71 | std::vector new_shape = {batch_size}; 72 | new_shape.insert(new_shape.end(), shape.begin(), shape.end()); 73 | return Spec(std::move(new_shape)); 74 | } 75 | }; 76 | 77 | template 78 | class TArray; 79 | 80 | template 81 | using Container = std::unique_ptr>; 82 | /* Specialization of Spec for Container element types: keeps the inner Spec and uses sizeof(Container) as element_size. */ 83 | template 84 | class Spec> : public ShapeSpec { 85 | public: 86 | using dtype = Container;// NOLINT 87 | Spec inner_spec; 88 | explicit Spec(const std::vector &shape, const Spec &inner_spec) 89 | : ShapeSpec(sizeof(Container), shape), inner_spec(inner_spec) {} 90 | explicit Spec(std::vector &&shape, Spec &&inner_spec) 91 | : ShapeSpec(sizeof(Container), std::move(shape)), 92 | inner_spec(std::move(inner_spec)) {} 93 | }; 94 | 95 | #endif // MCTS_CORE_SPEC_H_ 
-------------------------------------------------------------------------------- /mcts/mcts/core/state.py: -------------------------------------------------------------------------------- 1 | 
from typing import Tuple, Sequence 2 | 3 | import numpy as np 4 | 5 | class State: 6 | 7 | def __init__(self, batch_shape: Tuple[int, ...], store): 8 | assert isinstance(store, np.ndarray) 9 | self.store = store 10 | 11 | self.batch_shape = batch_shape 12 | self.ndim = len(batch_shape) 13 | 14 | def get_state_keys(self): 15 | return self.store 16 | 17 | @classmethod 18 | def from_item(cls, item): 19 | return cls((1,), np.array([item], dtype=np.int32)) 20 | 21 | def reshape(self, batch_shape: Tuple[int, ...]): 22 | self.batch_shape = batch_shape 23 | self.ndim = len(batch_shape) 24 | return self 25 | 26 | def item(self): 27 | assert self.ndim == 1 and self.batch_shape[0] == 1 28 | return self.store[0] 29 | 30 | @classmethod 31 | def from_state_list(cls, state_list, batch_shape=None): 32 | if isinstance(state_list[0], State): 33 | batch_shape_ = (len(state_list),) 34 | elif isinstance(state_list[0], Sequence): 35 | batch_shape_ = (len(state_list), len(state_list[0])) 36 | assert isinstance(state_list[0][0], State) 37 | else: 38 | raise ValueError("Invalid dim of states") 39 | if batch_shape is None: 40 | batch_shape = batch_shape_ 41 | else: 42 | assert len(batch_shape) == 2 and len(batch_shape_) == 1 43 | if len(batch_shape) == 2: 44 | states = [s for ss in state_list for s in ss] 45 | else: 46 | states = state_list 47 | state_keys = np.concatenate([s.store for s in states], dtype=np.int32, axis=0) 48 | return State(batch_shape, state_keys) 49 | 50 | def _get_by_index(self, batch_shape, indices): 51 | state_keys = self.store[indices] 52 | return State(batch_shape, state_keys) 53 | 54 | def __getitem__(self, item): 55 | return self.get(item) 56 | 57 | def get(self, i): 58 | if self.ndim == 2: 59 | assert isinstance(i, tuple) 60 | i = i[0] * self.batch_shape[1] + i[1] 61 | i = np.array([i], dtype=np.int32) 62 | return self._get_by_index((1,), i) 63 | 64 | def __len__(self) -> int: 65 | return len(self.store) 66 | 67 | def __repr__(self) -> str: 68 | return 
f'State(batch_shape={self.batch_shape}, ndim={self.ndim})' 69 | 70 | def __str__(self) -> str: 71 | return f'State(batch_shape={self.batch_shape}, ndim={self.ndim})' 72 | -------------------------------------------------------------------------------- /mcts/setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages 2 | 3 | __version__ = "0.0.1" 4 | 5 | INSTALL_REQUIRES = [ 6 | "setuptools", 7 | "wheel", 8 | "pybind11-stubgen", 9 | "numpy", 10 | ] 11 | 12 | setup( 13 | name="mcts", 14 | version=__version__, 15 | packages=find_packages(include='mcts*'), 16 | long_description="", 17 | install_requires=INSTALL_REQUIRES, 18 | python_requires=">=3.7", 19 | include_package_data=True, 20 | ) -------------------------------------------------------------------------------- /repo/packages/e/edopro-core/xmake.lua: -------------------------------------------------------------------------------- 1 | package("edopro-core") 2 | 3 | set_homepage("https://github.com/edo9300/ygopro-core") 4 | 5 | set_urls("https://github.com/edo9300/ygopro-core.git") 6 | 7 | -- set_sourcedir(path.join(os.scriptdir(), "edopro-core")) 8 | -- set_policy("package.install_always", true) 9 | 10 | add_deps("lua") 11 | 12 | on_install("linux", function (package) 13 | io.writefile("xmake.lua", [[ 14 | add_rules("mode.debug", "mode.release") 15 | add_requires("lua") 16 | target("edopro-core") 17 | set_kind("static") 18 | set_languages("c++17") 19 | add_files("*.cpp") 20 | add_headerfiles("*.h") 21 | add_headerfiles("RNG/*.hpp") 22 | add_packages("lua") 23 | ]]) 24 | 25 | local check_and_insert = function(file, line, insert) 26 | local lines = table.to_array(io.lines(file)) 27 | if lines[line] ~= insert then 28 | table.insert(lines, line, insert) 29 | io.writefile(file, table.concat(lines, "\n")) 30 | end 31 | end 32 | 33 | check_and_insert("interpreter.h", 12, "extern \"C\" {") 34 | check_and_insert("interpreter.h", 14, "}") 35 | 
36 | check_and_insert("interpreter.h", 16, "extern \"C\" {") 37 | check_and_insert("interpreter.h", 19, "}") 38 | 39 | local configs = {} 40 | if package:config("shared") then 41 | configs.kind = "shared" 42 | end 43 | import("package.tools.xmake").install(package) 44 | os.cp("*.h", package:installdir("include", "edopro-core")) 45 | os.cp("RNG", package:installdir("include", "edopro-core")) 46 | end) 47 | package_end() -------------------------------------------------------------------------------- /repo/packages/y/ygopro-core/xmake.lua: -------------------------------------------------------------------------------- 1 | package("ygopro-core") 2 | 3 | set_homepage("https://github.com/Fluorohydride/ygopro-core") 4 | 5 | add_urls("https://github.com/Fluorohydride/ygopro-core.git") 6 | add_versions("0.0.1", "6ed45241ab9360fd832dbc5fe913aa0017f577fc") 7 | add_versions("0.0.2", "f96929650ff8685b82fd48670126eae406366734") 8 | 9 | add_deps("lua") 10 | 11 | on_install("linux", function (package) 12 | io.writefile("xmake.lua", [[ 13 | add_rules("mode.debug", "mode.release") 14 | add_requires("lua") 15 | target("ygopro-core") 16 | set_kind("static") 17 | add_files("*.cpp") 18 | add_headerfiles("*.h") 19 | add_packages("lua") 20 | ]]) 21 | 22 | local check_and_insert = function(file, line, insert) 23 | local lines = table.to_array(io.lines(file)) 24 | if lines[line] ~= insert then 25 | table.insert(lines, line, insert) 26 | io.writefile(file, table.concat(lines, "\n")) 27 | end 28 | end 29 | 30 | check_and_insert("field.h", 14, "#include ") 31 | check_and_insert("interpreter.h", 11, "extern \"C\" {") 32 | check_and_insert("interpreter.h", 15, "}") 33 | local configs = {} 34 | if package:config("shared") then 35 | configs.kind = "shared" 36 | end 37 | import("package.tools.xmake").install(package) 38 | os.cp("*.h", package:installdir("include", "ygopro-core")) 39 | end) 40 | package_end() -------------------------------------------------------------------------------- 
/scripts/card/code_list.py: -------------------------------------------------------------------------------- 1 | import os 2 | from glob import glob 3 | 4 | from dataclasses import dataclass 5 | import tyro 6 | 7 | from ygoai.embed import read_cards 8 | 9 | @dataclass 10 | class Args: 11 | output: str = "code_list.txt" 12 | """the file containing the list of card codes""" 13 | cdb: str = "../assets/locale/en/cards.cdb" 14 | """the cards database file""" 15 | script_dir: str = "script" 16 | """path to the scripts directory""" 17 | 18 | if __name__ == "__main__": 19 | args = tyro.cli(Args) 20 | cards = read_cards(args.cdb)[1] 21 | 22 | pattern = os.path.join(args.script_dir, "c*.lua") 23 | # list all c*.lua files 24 | script_files = glob(pattern) 25 | 26 | codes = sorted([os.path.basename(f).split(".")[0][1:] for f in script_files]) 27 | # exclude constant.lua 28 | codes_s = set([int(c) for c in codes[:-1]]) 29 | codes_c = sorted([ c.code for c in cards ]) 30 | 31 | difference = codes_s.difference(codes_c) 32 | if len(difference) > 0: 33 | raise ValueError("Missing in cards.cdb: {difference}") 34 | 35 | print(f"Total {len(codes_c)} cards, {len(codes_s)} scripts") 36 | 37 | lines = [] 38 | for c in codes_c: 39 | line = f"{c} {1 if c in codes_s else 0}" 40 | lines.append(line) 41 | with open(args.output, "w") as f: 42 | f.write("\n".join(lines)) -------------------------------------------------------------------------------- /scripts/card/embedding.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | from dataclasses import dataclass 4 | from typing import Optional 5 | import tyro 6 | 7 | import pickle 8 | import numpy as np 9 | 10 | import voyageai 11 | 12 | from ygoai.embed import read_cards 13 | from ygoai.utils import load_deck 14 | 15 | 16 | @dataclass 17 | class Args: 18 | deck_dir: str = "../assets/deck" 19 | """the directory of ydk files""" 20 | code_list_file: str = "code_list.txt" 21 | """the 
file containing the list of card codes""" 22 | embeddings_file: Optional[str] = None 23 | """the pickle file containing the embeddings of the cards""" 24 | cards_db: str = "../assets/locale/en/cards.cdb" 25 | """the cards database file""" 26 | batch_size: int = 64 27 | """the batch size for embedding generation""" 28 | wait_time: float = 0.1 29 | """the time to wait between each batch""" 30 | 31 | 32 | def get_embeddings(texts, batch_size=64, wait_time=0.1, verbose=False): 33 | vo = voyageai.Client() 34 | 35 | embeddings = [] 36 | for i in range(0, len(texts), batch_size): 37 | if verbose: 38 | print(f"Embedding {i} / {len(texts)}") 39 | embeddings += vo.embed( 40 | texts[i : i + batch_size], model="voyage-2", truncation=False).embeddings 41 | time.sleep(wait_time) 42 | embeddings = np.array(embeddings, dtype=np.float32) 43 | return embeddings 44 | 45 | 46 | def read_decks(d): 47 | # iterate over ydk files 48 | codes = [] 49 | for file in os.listdir(d): 50 | if file.endswith(".ydk"): 51 | file = os.path.join(d, file) 52 | codes += load_deck(file) 53 | return set(codes) 54 | 55 | 56 | def read_texts(cards_db, codes): 57 | df, cards = read_cards(cards_db) 58 | code2card = {c.code: c for c in cards} 59 | texts = [] 60 | for code in codes: 61 | texts.append(code2card[code].format()) 62 | return texts 63 | 64 | 65 | if __name__ == "__main__": 66 | args = tyro.cli(Args) 67 | 68 | deck_dir = args.deck_dir 69 | code_list_file = args.code_list_file 70 | embeddings_file = args.embeddings_file 71 | cards_db = args.cards_db 72 | 73 | # read code_list file 74 | if not os.path.exists(code_list_file): 75 | with open(code_list_file, "w") as f: 76 | f.write("") 77 | with open(code_list_file, "r") as f: 78 | code_list = f.readlines() 79 | code_list = [int(code.strip()) for code in code_list] 80 | print(f"The code list contains {len(code_list)} cards.") 81 | 82 | all_codes = set(code_list) 83 | 84 | new_codes = [] 85 | for code in read_decks(deck_dir): 86 | if code not in all_codes: 
87 | new_codes.append(code) 88 | 89 | if new_codes == []: 90 | print("No new cards have been added to the code list.") 91 | else: 92 | # update code_list 93 | code_list += new_codes 94 | 95 | with open(code_list_file, "w") as f: 96 | f.write("\n".join(map(str, code_list)) + "\n") 97 | 98 | print(f"{len(new_codes)} new cards have been added to the code list.") 99 | 100 | if embeddings_file is not None: 101 | if not os.path.exists(embeddings_file): 102 | all_embeddings = {} 103 | else: 104 | all_embeddings = pickle.load(open(embeddings_file, "rb")) 105 | 106 | codes_not_in_embeddings = [code for code in code_list if code not in all_embeddings] 107 | if codes_not_in_embeddings == []: 108 | print("All cards have embeddings.") 109 | exit() 110 | print(f"{len(codes_not_in_embeddings)} cards do not have embeddings.") 111 | new_texts = read_texts(cards_db, codes_not_in_embeddings) 112 | print(new_texts) 113 | embeddings = get_embeddings(new_texts, args.batch_size, args.wait_time, verbose=True) 114 | embeddings = np.array(embeddings, dtype=np.float32) 115 | for code, embedding in zip(codes_not_in_embeddings, embeddings): 116 | all_embeddings[code] = embedding 117 | print(f"Embeddings of {len(codes_not_in_embeddings)} cards have been added.") 118 | pickle.dump(all_embeddings, open(embeddings_file, "wb")) -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | import io 2 | import os 3 | 4 | from setuptools import find_packages, setup 5 | 6 | NAME = 'ygoai' 7 | IMPORT_NAME = 'ygoai' 8 | DESCRIPTION = "A Yu-Gi-Oh! AI." 
9 | URL = 'https://github.com/sbl1996/ygo-agent' 10 | EMAIL = 'sbl1996@gmail.com' 11 | AUTHOR = 'Hastur' 12 | REQUIRES_PYTHON = '>=3.10.0' 13 | VERSION = None 14 | 15 | REQUIRED = [ 16 | "tyro", 17 | "pandas", 18 | "tensorboardX", 19 | "tqdm", 20 | ] 21 | 22 | here = os.path.dirname(os.path.abspath(__file__)) 23 | 24 | try: 25 | with io.open(os.path.join(here, 'README.md'), encoding='utf-8') as f: 26 | long_description = '\n' + f.read() 27 | except FileNotFoundError: 28 | long_description = DESCRIPTION 29 | 30 | about = {} 31 | if not VERSION: 32 | with open(os.path.join(here, IMPORT_NAME, '_version.py')) as f: 33 | exec(f.read(), about) 34 | else: 35 | about['__version__'] = VERSION 36 | 37 | 38 | setup( 39 | name=NAME, 40 | version=about['__version__'], 41 | description=DESCRIPTION, 42 | long_description=long_description, 43 | long_description_content_type='text/markdown', 44 | author=AUTHOR, 45 | author_email=EMAIL, 46 | python_requires=REQUIRES_PYTHON, 47 | url=URL, 48 | packages=find_packages(include='ygoai*'), 49 | install_requires=REQUIRED, 50 | dependency_links=[], 51 | license='MIT', 52 | ) -------------------------------------------------------------------------------- /xmake.lua: -------------------------------------------------------------------------------- 1 | add_rules("mode.debug", "mode.release") 2 | 3 | add_repositories("my-repo repo") 4 | 5 | add_requires( 6 | "ygopro-core 0.0.2", "edopro-core", "pybind11 2.13.*", "fmt 10.2.*", "glog 0.6.0", 7 | "sqlite3 3.43.0+200", "concurrentqueue 1.0.4", "unordered_dense 4.4.*", 8 | "sqlitecpp 3.2.1") 9 | 10 | 11 | target("ygopro0_ygoenv") 12 | add_rules("python.library") 13 | add_files("ygoenv/ygoenv/ygopro0/*.cpp") 14 | add_packages("pybind11", "fmt", "glog", "concurrentqueue", "sqlitecpp", "unordered_dense", "ygopro-core") 15 | set_languages("c++17") 16 | if is_mode("release") then 17 | set_policy("build.optimization.lto", true) 18 | add_cxxflags("-march=native") 19 | end 20 | add_includedirs("ygoenv") 21 
| 22 | after_build(function (target) 23 | local install_target = "$(projectdir)/ygoenv/ygoenv/ygopro0" 24 | os.cp(target:targetfile(), install_target) 25 | print("Copy target to " .. install_target) 26 | end) 27 | 28 | 29 | target("ygopro_ygoenv") 30 | add_rules("python.library") 31 | add_files("ygoenv/ygoenv/ygopro/*.cpp") 32 | add_packages("pybind11", "fmt", "glog", "concurrentqueue", "sqlitecpp", "unordered_dense", "ygopro-core") 33 | set_languages("c++17") 34 | if is_mode("release") then 35 | set_policy("build.optimization.lto", true) 36 | add_cxxflags("-march=native") 37 | end 38 | add_includedirs("ygoenv") 39 | 40 | after_build(function (target) 41 | local install_target = "$(projectdir)/ygoenv/ygoenv/ygopro" 42 | os.cp(target:targetfile(), install_target) 43 | print("Copy target to " .. install_target) 44 | end) 45 | 46 | target("edopro_ygoenv") 47 | add_rules("python.library") 48 | add_files("ygoenv/ygoenv/edopro/*.cpp") 49 | add_packages("pybind11", "fmt", "glog", "concurrentqueue", "sqlitecpp", "unordered_dense", "edopro-core") 50 | set_languages("c++17") 51 | if is_mode("release") then 52 | set_policy("build.optimization.lto", true) 53 | add_cxxflags("-march=native") 54 | end 55 | add_includedirs("ygoenv") 56 | 57 | after_build(function (target) 58 | local install_target = "$(projectdir)/ygoenv/ygoenv/edopro" 59 | os.cp(target:targetfile(), install_target) 60 | print("Copy target to " .. install_target) 61 | end) 62 | 63 | 64 | target("alphazero_mcts") 65 | add_rules("python.library") 66 | add_files("mcts/mcts/alphazero/*.cpp") 67 | add_packages("pybind11") 68 | set_languages("c++17") 69 | if is_mode("release") then 70 | set_policy("build.optimization.lto", true) 71 | add_cxxflags("-march=native") 72 | end 73 | add_includedirs("mcts") 74 | 75 | after_build(function (target) 76 | local install_target = "$(projectdir)/mcts/mcts/alphazero" 77 | os.cp(target:targetfile(), install_target) 78 | print("Copy target to " .. 
install_target) 79 | os.run("pybind11-stubgen mcts.alphazero.alphazero_mcts -o %s", "$(projectdir)/mcts") 80 | end) 81 | -------------------------------------------------------------------------------- /ygoai/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sbl1996/ygo-agent/dbf5142d49aab2e6beb4150788d4fffec39ae3e5/ygoai/__init__.py -------------------------------------------------------------------------------- /ygoai/_version.py: -------------------------------------------------------------------------------- 1 | __version__ = "0.0.1" -------------------------------------------------------------------------------- /ygoai/embed.py: -------------------------------------------------------------------------------- 1 | from typing import List, Union 2 | from dataclasses import dataclass 3 | 4 | import sqlite3 5 | 6 | import pandas as pd 7 | 8 | from ygoai.constants import TYPE, type2str, attribute2str, race2str 9 | 10 | 11 | def parse_types(value): 12 | types = [] 13 | all_types = list(type2str.keys()) 14 | for t in all_types: 15 | if value & t: 16 | types.append(type2str[t]) 17 | return types 18 | 19 | 20 | def parse_attribute(value): 21 | attribute = attribute2str.get(value, None) 22 | assert attribute, "Invalid attribute, value: " + str(value) 23 | return attribute 24 | 25 | 26 | def parse_race(value): 27 | race = race2str.get(value, None) 28 | assert race, "Invalid race, value: " + str(value) 29 | return race 30 | 31 | 32 | @dataclass 33 | class Card: 34 | code: int 35 | name: str 36 | desc: str 37 | types: List[str] 38 | 39 | def format(self): 40 | return format_card(self) 41 | 42 | 43 | @dataclass 44 | class MonsterCard(Card): 45 | atk: int 46 | def_: int 47 | level: int 48 | race: str 49 | attribute: str 50 | 51 | 52 | @dataclass 53 | class SpellCard(Card): 54 | pass 55 | 56 | 57 | @dataclass 58 | class TrapCard(Card): 59 | pass 60 | 61 | 62 | def format_monster_card(card: 
MonsterCard): 63 | name = card.name 64 | typ = "/".join(card.types) 65 | 66 | attribute = card.attribute 67 | race = card.race 68 | 69 | level = str(card.level) 70 | 71 | atk = str(card.atk) 72 | if atk == '-2': 73 | atk = '?' 74 | 75 | def_ = str(card.def_) 76 | if def_ == '-2': 77 | def_ = '?' 78 | 79 | if typ == 'Monster/Normal': 80 | desc = "-" 81 | else: 82 | desc = card.desc 83 | 84 | columns = [name, typ, attribute, race, level, atk, def_, desc] 85 | return " | ".join(columns) 86 | 87 | 88 | def format_spell_trap_card(card: Union[SpellCard, TrapCard]): 89 | name = card.name 90 | typ = "/".join(card.types) 91 | desc = card.desc 92 | 93 | columns = [name, typ, desc] 94 | return " | ".join(columns) 95 | 96 | 97 | def format_card(card: Card): 98 | if isinstance(card, MonsterCard): 99 | return format_monster_card(card) 100 | elif isinstance(card, (SpellCard, TrapCard)): 101 | return format_spell_trap_card(card) 102 | else: 103 | raise ValueError("Invalid card type: " + str(card)) 104 | 105 | 106 | ## For analyzing cards.db 107 | 108 | def parse_monster_card(data) -> MonsterCard: 109 | code = int(data['id']) 110 | name = data['name'] 111 | desc = data['desc'] 112 | 113 | types = parse_types(int(data['type'])) 114 | 115 | atk = int(data['atk']) 116 | def_ = int(data['def']) 117 | level = int(data['level']) 118 | 119 | if level >= 16: 120 | # pendulum monster 121 | level = level % 16 122 | 123 | race = parse_race(int(data['race'])) 124 | attribute = parse_attribute(int(data['attribute'])) 125 | return MonsterCard(code, name, desc, types, atk, def_, level, race, attribute) 126 | 127 | 128 | def parse_spell_card(data) -> SpellCard: 129 | code = int(data['id']) 130 | name = data['name'] 131 | desc = data['desc'] 132 | 133 | types = parse_types(int(data['type'])) 134 | return SpellCard(code, name, desc, types) 135 | 136 | 137 | def parse_trap_card(data) -> TrapCard: 138 | code = int(data['id']) 139 | name = data['name'] 140 | desc = data['desc'] 141 | 142 | types = 
parse_types(int(data['type'])) 143 | return TrapCard(code, name, desc, types) 144 | 145 | 146 | def parse_card(data) -> Card: 147 | type_ = data['type'] 148 | if type_ & TYPE.MONSTER: 149 | return parse_monster_card(data) 150 | elif type_ & TYPE.SPELL: 151 | return parse_spell_card(data) 152 | elif type_ & TYPE.TRAP: 153 | return parse_trap_card(data) 154 | else: 155 | raise ValueError("Invalid card type: " + str(type_)) 156 | 157 | 158 | def read_cards(cards_path): 159 | conn = sqlite3.connect(cards_path) 160 | cursor = conn.cursor() 161 | 162 | cursor.execute("SELECT * FROM datas") 163 | datas_rows = cursor.fetchall() 164 | datas_columns = [description[0] for description in cursor.description] 165 | datas_df = pd.DataFrame(datas_rows, columns=datas_columns) 166 | 167 | cursor.execute("SELECT * FROM texts") 168 | texts_rows = cursor.fetchall() 169 | texts_columns = [description[0] for description in cursor.description] 170 | texts_df = pd.DataFrame(texts_rows, columns=texts_columns) 171 | 172 | cursor.close() 173 | conn.close() 174 | 175 | texts_df = texts_df.loc[:, ['id', 'name', 'desc']] 176 | merged_df = pd.merge(texts_df, datas_df, on='id') 177 | 178 | cards_data = merged_df.to_dict('records') 179 | cards = [parse_card(data) for data in cards_data] 180 | return merged_df, cards 181 | -------------------------------------------------------------------------------- /ygoai/rl/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sbl1996/ygo-agent/dbf5142d49aab2e6beb4150788d4fffec39ae3e5/ygoai/rl/__init__.py -------------------------------------------------------------------------------- /ygoai/rl/ckpt.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | 3 | import os 4 | import shutil 5 | from pathlib import Path 6 | import zipfile 7 | 8 | 9 | class ModelCheckpoint(object): 10 | """ ModelCheckpoint handler can be used to 
periodically save objects to disk. 11 | 12 | Args: 13 | dirname (str): 14 | Directory path where objects will be saved. 15 | save_fn (callable): 16 | Function that will be called to save the object. It should have the signature `save_fn(obj, path)`. 17 | n_saved (int, optional): 18 | Number of objects that should be kept on disk. Older files will be removed. 19 | """ 20 | 21 | def __init__(self, dirname, save_fn, n_saved=1): 22 | self._dirname = Path(dirname).expanduser().absolute() 23 | self._n_saved = n_saved 24 | self._save_fn = save_fn 25 | self._saved: List[Path] = [] 26 | 27 | def _check_dir(self): 28 | self._dirname.mkdir(parents=True, exist_ok=True) 29 | 30 | # Ensure that dirname exists 31 | if not self._dirname.exists(): 32 | raise ValueError( 33 | "Directory path '{}' is not found".format(self._dirname)) 34 | 35 | def save(self, obj, name): 36 | self._check_dir() 37 | path = self._dirname / name 38 | self._save_fn(obj, str(path)) 39 | self._saved.append(path) 40 | print(f"Saved model to {path}") 41 | 42 | if len(self._saved) > self._n_saved: 43 | to_remove = self._saved.pop(0) 44 | if to_remove != path: 45 | if to_remove.is_dir(): 46 | shutil.rmtree(to_remove) 47 | else: 48 | if to_remove.exists(): 49 | os.remove(to_remove) 50 | 51 | def get_latest(self): 52 | path = self._saved[-1] 53 | return path 54 | 55 | 56 | def sync_to_gcs(bucket, source, dest=None): 57 | if bucket.startswith("gs://"): 58 | bucket = bucket[5:] 59 | if dest is None: 60 | dest = Path(source).name 61 | gcs_url = Path(bucket) / dest 62 | gcs_url = f"gs://{gcs_url}" 63 | os.system(f"gsutil cp {source} {gcs_url} > /dev/null 2>&1 &") 64 | print(f"Sync to GCS: {gcs_url}") 65 | 66 | 67 | def zip_files(zip_file_path, files_to_zip): 68 | """ 69 | Creates a zip file at the specified path, containing the files and directories 70 | specified in files_to_zip. 71 | 72 | Args: 73 | zip_file_path (str): The path to the zip file to be created. 
74 | files_to_zip (list): A list of paths to files and directories to be zipped. 75 | """ 76 | with zipfile.ZipFile(zip_file_path, mode='w') as zip_file: 77 | for file_path in files_to_zip: 78 | # Check if the path is a file or a directory 79 | if os.path.isfile(file_path): 80 | # If it's a file, add it to the zip file 81 | zip_file.write(file_path) 82 | elif os.path.isdir(file_path): 83 | # If it's a directory, add all its files and subdirectories to the zip file 84 | for root, dirs, files in os.walk(file_path): 85 | for file in files: 86 | file_path = os.path.join(root, file) 87 | zip_file.write(file_path) 88 | -------------------------------------------------------------------------------- /ygoai/rl/dist.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import datetime 4 | import torch 5 | import torch.distributed as dist 6 | import torch.multiprocessing as mp 7 | 8 | 9 | def reduce_gradidents(params, world_size): 10 | if world_size == 1: 11 | return 12 | all_grads_list = [] 13 | for param in params: 14 | if param.grad is not None: 15 | all_grads_list.append(param.grad.view(-1)) 16 | all_grads = torch.cat(all_grads_list) 17 | dist.all_reduce(all_grads, op=dist.ReduceOp.SUM) 18 | offset = 0 19 | for param in params: 20 | if param.grad is not None: 21 | param.grad.data.copy_( 22 | all_grads[offset : offset + param.numel()].view_as(param.grad.data) / world_size 23 | ) 24 | offset += param.numel() 25 | 26 | 27 | def test_nccl(local_rank): 28 | # manual init nccl 29 | x = torch.rand(4, device=f'cuda:{local_rank}') 30 | dist.all_reduce(x, op=dist.ReduceOp.SUM) 31 | x.mean().item() 32 | dist.barrier() 33 | 34 | 35 | def torchrun_setup(backend, local_rank): 36 | dist.init_process_group( 37 | backend, timeout=datetime.timedelta(seconds=60 * 30)) 38 | test_nccl(local_rank) 39 | 40 | 41 | def setup(backend, rank, world_size, port): 42 | os.environ['MASTER_ADDR'] = '127.0.0.1' 43 | os.environ['MASTER_PORT'] 
= str(port) 44 | dist.init_process_group( 45 | backend, rank=rank, world_size=world_size, 46 | timeout=datetime.timedelta(seconds=60 * 30)) 47 | 48 | test_nccl(rank) 49 | 50 | 51 | def mp_start(run): 52 | world_size = int(os.getenv("WORLD_SIZE", "1")) 53 | if world_size == 1: 54 | run(local_rank=0, world_size=world_size) 55 | else: 56 | # mp.set_start_method('spawn') 57 | children = [] 58 | for i in range(world_size): 59 | subproc = mp.Process(target=run, args=(i, world_size)) 60 | children.append(subproc) 61 | subproc.start() 62 | 63 | for i in range(world_size): 64 | children[i].join() 65 | 66 | 67 | def fprint(msg): 68 | sys.stdout.flush() 69 | sys.stdout.write(msg + os.linesep) 70 | sys.stdout.flush() 71 | -------------------------------------------------------------------------------- /ygoai/rl/env.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import gymnasium as gym 3 | 4 | 5 | class RecordEpisodeStatistics(gym.Wrapper): 6 | def __init__(self, env): 7 | super().__init__(env) 8 | self.num_envs = getattr(env, "num_envs", 1) 9 | self.episode_returns = None 10 | self.episode_lengths = None 11 | 12 | def reset(self, **kwargs): 13 | observations, infos = self.env.reset(**kwargs) 14 | self.episode_returns = np.zeros(self.num_envs, dtype=np.float32) 15 | self.episode_lengths = np.zeros(self.num_envs, dtype=np.int32) 16 | self.returned_episode_returns = np.zeros(self.num_envs, dtype=np.float32) 17 | self.returned_episode_lengths = np.zeros(self.num_envs, dtype=np.int32) 18 | return observations, infos 19 | 20 | def step(self, action): 21 | return self.update_stats_and_infos(*super().step(action)) 22 | 23 | def update_stats_and_infos(self, *args): 24 | observations, rewards, terminated, truncated, infos = args 25 | dones = np.logical_or(terminated, truncated) 26 | self.episode_returns += infos.get("reward", rewards) 27 | self.episode_lengths += 1 28 | self.returned_episode_returns = np.where( 29 | dones, 
self.episode_returns, self.returned_episode_returns 30 | ) 31 | self.returned_episode_lengths = np.where( 32 | dones, self.episode_lengths, self.returned_episode_lengths 33 | ) 34 | self.episode_returns *= 1 - dones 35 | self.episode_lengths *= 1 - dones 36 | infos["r"] = self.returned_episode_returns 37 | infos["l"] = self.returned_episode_lengths 38 | 39 | return ( 40 | observations, 41 | rewards, 42 | dones, 43 | infos, 44 | ) 45 | 46 | def async_reset(self): 47 | self.env.async_reset() 48 | self.episode_returns = np.zeros(self.num_envs, dtype=np.float32) 49 | self.episode_lengths = np.zeros(self.num_envs, dtype=np.int32) 50 | self.returned_episode_returns = np.zeros(self.num_envs, dtype=np.float32) 51 | self.returned_episode_lengths = np.zeros(self.num_envs, dtype=np.int32) 52 | 53 | def recv(self): 54 | return self.update_stats_and_infos(*self.env.recv()) 55 | 56 | def send(self, action): 57 | return self.env.send(action) 58 | 59 | 60 | class CompatEnv(gym.Wrapper): 61 | 62 | def reset(self, **kwargs): 63 | observations, infos = self.env.reset(**kwargs) 64 | return observations, infos 65 | 66 | def step(self, action): 67 | observations, rewards, terminated, truncated, infos = super().step(action) 68 | dones = np.logical_or(terminated, truncated) 69 | return ( 70 | observations, 71 | rewards, 72 | dones, 73 | infos, 74 | ) 75 | 76 | 77 | class EnvPreprocess(gym.Wrapper): 78 | 79 | def __init__(self, env, skip_mask): 80 | super().__init__(env) 81 | self.num_envs = env.num_envs 82 | self.skip_mask = skip_mask 83 | 84 | def reset(self, **kwargs): 85 | observations, infos = self.env.reset(**kwargs) 86 | if self.skip_mask: 87 | observations['mask_'] = None 88 | return observations, infos 89 | 90 | def step(self, action): 91 | observations, rewards, terminated, truncated, infos = super().step(action) 92 | if self.skip_mask: 93 | observations['mask_'] = None 94 | return ( 95 | observations, 96 | rewards, 97 | terminated, 98 | truncated, 99 | infos, 100 | ) 
-------------------------------------------------------------------------------- /ygoai/rl/eval.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from torch.cuda.amp import autocast 4 | 5 | from ygoai.rl.utils import to_tensor 6 | 7 | 8 | def evaluate(envs, model, num_episodes, device, fp16_eval=False): 9 | episode_lengths = [] 10 | episode_rewards = [] 11 | eval_win_rates = [] 12 | obs = envs.reset()[0] 13 | while True: 14 | obs = to_tensor(obs, device, dtype=torch.uint8) 15 | with torch.no_grad(): 16 | with autocast(enabled=fp16_eval): 17 | logits = model(obs)[0] 18 | probs = torch.softmax(logits, dim=-1) 19 | probs = probs.cpu().numpy() 20 | actions = probs.argmax(axis=1) 21 | 22 | obs, rewards, dones, info = envs.step(actions) 23 | 24 | for idx, d in enumerate(dones): 25 | if d: 26 | episode_length = info['l'][idx] 27 | episode_reward = info['r'][idx] 28 | win = 1 if episode_reward > 0 else 0 29 | 30 | episode_lengths.append(episode_length) 31 | episode_rewards.append(episode_reward) 32 | eval_win_rates.append(win) 33 | if len(episode_lengths) >= num_episodes: 34 | break 35 | 36 | eval_return = np.mean(episode_rewards[:num_episodes]) 37 | eval_ep_len = np.mean(episode_lengths[:num_episodes]) 38 | eval_win_rate = np.mean(eval_win_rates[:num_episodes]) 39 | return eval_return, eval_ep_len, eval_win_rate -------------------------------------------------------------------------------- /ygoai/rl/jax/eval.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | def evaluate(envs, num_episodes, predict_fn, rnn_state=None): 5 | episode_lengths = [] 6 | episode_rewards = [] 7 | win_rates = [] 8 | obs = envs.reset()[0] 9 | collected = np.zeros((num_episodes,), dtype=np.bool_) 10 | while True: 11 | if rnn_state is None: 12 | actions = predict_fn(obs) 13 | else: 14 | rnn_state, actions = predict_fn(obs, rnn_state) 15 | actions = 
np.array(actions) 16 | 17 | obs, rewards, dones, info = envs.step(actions) 18 | 19 | for idx, d in enumerate(dones): 20 | if not d or collected[idx]: 21 | continue 22 | collected[idx] = True 23 | episode_length = info['l'][idx] 24 | episode_reward = info['r'][idx] 25 | win = 1 if episode_reward > 0 else 0 26 | 27 | episode_lengths.append(episode_length) 28 | episode_rewards.append(episode_reward) 29 | win_rates.append(win) 30 | if len(episode_lengths) >= num_episodes: 31 | break 32 | 33 | eval_return = np.mean(episode_rewards[:num_episodes]) 34 | eval_ep_len = np.mean(episode_lengths[:num_episodes]) 35 | eval_win_rate = np.mean(win_rates[:num_episodes]) 36 | return eval_return, eval_ep_len, eval_win_rate 37 | 38 | 39 | def battle(envs, num_episodes, predict_fn, rstate1=None, rstate2=None): 40 | assert num_episodes == envs.num_envs 41 | num_envs = envs.num_envs 42 | episode_rewards = [] 43 | episode_lengths = [] 44 | win_rates = [] 45 | 46 | obs, infos = envs.reset() 47 | next_to_play = infos['to_play'] 48 | dones = np.zeros(num_envs, dtype=np.bool_) 49 | collected = np.zeros((num_episodes,), dtype=np.bool_) 50 | 51 | main_player = np.concatenate([ 52 | np.zeros(num_envs // 2, dtype=np.int64), 53 | np.ones(num_envs - num_envs // 2, dtype=np.int64) 54 | ]) 55 | 56 | while True: 57 | main = next_to_play == main_player 58 | rstate1, rstate2, actions = predict_fn(obs, rstate1, rstate2, main, dones) 59 | actions = np.array(actions) 60 | 61 | obs, rewards, dones, infos = envs.step(actions) 62 | next_to_play = infos['to_play'] 63 | 64 | for idx, d in enumerate(dones): 65 | if not d or collected[idx]: 66 | continue 67 | collected[idx] = True 68 | episode_length = infos['l'][idx] 69 | episode_reward = infos['r'][idx] * (1 if main[idx] else -1) 70 | win = 1 if episode_reward > 0 else 0 71 | 72 | episode_lengths.append(episode_length) 73 | episode_rewards.append(episode_reward) 74 | win_rates.append(win) 75 | if len(episode_lengths) >= num_episodes: 76 | break 77 | 78 | 
eval_return = np.mean(episode_rewards[:num_episodes]) 79 | eval_ep_len = np.mean(episode_lengths[:num_episodes]) 80 | eval_win_rate = np.mean(win_rates[:num_episodes]) 81 | return eval_return, eval_ep_len, eval_win_rate 82 | -------------------------------------------------------------------------------- /ygoai/rl/jax/nnx/rnn.py: -------------------------------------------------------------------------------- 1 | import jax 2 | import jax.numpy as jnp 3 | from flax import nnx 4 | 5 | 6 | default_kernel_init = nnx.initializers.lecun_normal() 7 | default_bias_init = nnx.initializers.zeros_init() 8 | 9 | 10 | class OptimizedLSTMCell(nnx.Module): 11 | 12 | def __init__( 13 | self, in_features, features: int, *, 14 | gate_fn=nnx.sigmoid, activation_fn=nnx.tanh, 15 | kernel_init=default_kernel_init, bias_init=default_bias_init, 16 | recurrent_kernel_init=nnx.initializers.orthogonal(), 17 | dtype=None, param_dtype=jnp.float32, rngs, 18 | ): 19 | self.features = features 20 | self.gate_fn = gate_fn 21 | self.activation_fn = activation_fn 22 | 23 | self.fc_i = nnx.Linear( 24 | in_features, 4 * features, 25 | use_bias=False, kernel_init=kernel_init, 26 | bias_init=bias_init, dtype=dtype, 27 | param_dtype=param_dtype, rngs=rngs, 28 | ) 29 | self.fc_h = nnx.Linear( 30 | features, 4 * features, 31 | use_bias=True, kernel_init=recurrent_kernel_init, 32 | bias_init=bias_init, dtype=dtype, 33 | param_dtype=param_dtype, rngs=rngs, 34 | ) 35 | 36 | def __call__(self, carry, inputs): 37 | c, h = carry 38 | 39 | dense_i = self.fc_i(inputs) 40 | dense_h = self.fc_h(h) 41 | 42 | i, f, g, o = jnp.split(dense_i + dense_h, indices_or_sections=4, axis=-1) 43 | i, f, g, o = self.gate_fn(i), self.gate_fn(f), self.activation_fn(g), self.gate_fn(o) 44 | 45 | new_c = f * c + i * g 46 | new_h = o * self.activation_fn(new_c) 47 | return (new_c, new_h), new_h 48 | 49 | 50 | class GRUCell(nnx.Module): 51 | 52 | def __init__( 53 | self, in_features: int, features: int, *, 54 | gate_fn=nnx.sigmoid, 
activation_fn=nnx.tanh, 55 | kernel_init=default_kernel_init, bias_init=default_bias_init, 56 | recurrent_kernel_init=nnx.initializers.orthogonal(), 57 | dtype=None, param_dtype=jnp.float32, rngs, 58 | ): 59 | self.features = features 60 | self.gate_fn = gate_fn 61 | self.activation_fn = activation_fn 62 | 63 | self.fc_i = nnx.Linear( 64 | in_features, 3 * features, 65 | use_bias=True, kernel_init=kernel_init, 66 | bias_init=bias_init, dtype=dtype, 67 | param_dtype=param_dtype, rngs=rngs, 68 | ) 69 | self.fc_h = nnx.Linear( 70 | features, 3 * features, 71 | use_bias=True, kernel_init=recurrent_kernel_init, 72 | bias_init=bias_init, dtype=dtype, 73 | param_dtype=param_dtype, rngs=rngs, 74 | ) 75 | 76 | def __call__(self, carry, inputs): 77 | h = carry 78 | 79 | dense_i = self.fc_i(inputs) 80 | dense_h = self.fc_h(h) 81 | 82 | ir, iz, in_ = jnp.split(dense_i, indices_or_sections=3, axis=-1) 83 | hr, hz, hn = jnp.split(dense_h, indices_or_sections=3, axis=-1) 84 | 85 | r = self.gate_fn(ir + hr) 86 | z = self.gate_fn(iz + hz) 87 | n = self.activation_fn(in_ + r * hn) 88 | new_h = (1.0 - z) * n + z * h 89 | return new_h, new_h 90 | -------------------------------------------------------------------------------- /ygoai/rl/jax/switch.py: -------------------------------------------------------------------------------- 1 | import jax 2 | import jax.numpy as jnp 3 | 4 | 5 | def truncated_gae_sep( 6 | next_value, values, rewards, next_dones, switch, gamma, gae_lambda, upgo 7 | ): 8 | def body_fn(carry, inp): 9 | boot_value, boot_done, next_value, lastgaelam, next_q, last_return = carry 10 | next_done, cur_value, reward, switch = inp 11 | 12 | next_done = jnp.where(switch, boot_done, next_done) 13 | next_value = jnp.where(switch, -boot_value, next_value) 14 | lastgaelam = jnp.where(switch, 0, lastgaelam) 15 | next_q = jnp.where(switch, -boot_value * gamma, next_q) 16 | last_return = jnp.where(switch, -boot_value, last_return) 17 | 18 | discount = gamma * (1.0 - next_done) 19 | 
last_return = reward + discount * jnp.where( 20 | next_q >= next_value, last_return, next_value) 21 | next_q = reward + discount * next_value 22 | delta = next_q - cur_value 23 | lastgaelam = delta + gae_lambda * discount * lastgaelam 24 | carry = boot_value, boot_done, cur_value, lastgaelam, next_q, last_return 25 | return carry, (lastgaelam, last_return) 26 | 27 | next_done = next_dones[-1] 28 | lastgaelam = jnp.zeros_like(next_value) 29 | next_q = last_return = next_value 30 | carry = next_value, next_done, next_value, lastgaelam, next_q, last_return 31 | 32 | _, (advantages, returns) = jax.lax.scan( 33 | body_fn, carry, (next_dones, values, rewards, switch), reverse=True 34 | ) 35 | targets = values + advantages 36 | if upgo: 37 | advantages += returns - values 38 | targets = jax.lax.stop_gradient(targets) 39 | return targets, advantages 40 | -------------------------------------------------------------------------------- /ygoai/rl/jax/utils.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Callable 2 | 3 | import jax 4 | import jax.numpy as jnp 5 | 6 | from flax import core, struct 7 | from flax.linen.fp8_ops import OVERWRITE_WITH_GRADIENT 8 | 9 | import optax 10 | 11 | import numpy as np 12 | 13 | 14 | def masked_mean(x, valid): 15 | x = jnp.where(valid, x, jnp.zeros_like(x)) 16 | return x.sum() / valid.sum() 17 | 18 | 19 | def masked_normalize(x, valid, eps=1e-8): 20 | x = jnp.where(valid, x, jnp.zeros_like(x)) 21 | n = valid.sum() 22 | mean = x.sum() / n 23 | variance = jnp.square(x - mean).sum() / n 24 | return (x - mean) / jnp.sqrt(variance + eps) 25 | 26 | 27 | def categorical_sample(logits, key): 28 | # sample action: Gumbel-softmax trick 29 | # see https://stats.stackexchange.com/questions/359442/sampling-from-a-categorical-distribution 30 | key, subkey = jax.random.split(key) 31 | u = jax.random.uniform(subkey, shape=logits.shape) 32 | action = jnp.argmax(logits - jnp.log(-jnp.log(u)), 
axis=-1) 33 | return action, key 34 | 35 | 36 | class RunningMeanStd: 37 | """Tracks the mean, variance and count of values.""" 38 | 39 | # https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Parallel_algorithm 40 | def __init__(self, epsilon=1e-4, shape=()): 41 | """Tracks the mean, variance and count of values.""" 42 | self.mean = np.zeros(shape, "float64") 43 | self.var = np.ones(shape, "float64") 44 | self.count = epsilon 45 | 46 | def update(self, x): 47 | """Updates the mean, var and count from a batch of samples.""" 48 | batch_mean = np.mean(x, axis=0) 49 | batch_var = np.var(x, axis=0) 50 | batch_count = x.shape[0] 51 | self.update_from_moments(batch_mean, batch_var, batch_count) 52 | 53 | def update_from_moments(self, batch_mean, batch_var, batch_count): 54 | """Updates from batch mean, variance and count moments.""" 55 | self.mean, self.var, self.count = update_mean_var_count_from_moments( 56 | self.mean, self.var, self.count, batch_mean, batch_var, batch_count 57 | ) 58 | 59 | 60 | def update_mean_var_count_from_moments( 61 | mean, var, count, batch_mean, batch_var, batch_count 62 | ): 63 | """Updates the mean, var and count using the previous mean, var, count and batch values.""" 64 | delta = batch_mean - mean 65 | tot_count = count + batch_count 66 | 67 | new_mean = mean + delta * batch_count / tot_count 68 | m_a = var * count 69 | m_b = batch_var * batch_count 70 | M2 = m_a + m_b + np.square(delta) * count * batch_count / tot_count 71 | new_var = M2 / tot_count 72 | new_count = tot_count 73 | 74 | return new_mean, new_var, new_count 75 | 76 | 77 | class TrainState(struct.PyTreeNode): 78 | step: int 79 | apply_fn: Callable = struct.field(pytree_node=False) 80 | params: core.FrozenDict[str, Any] = struct.field(pytree_node=True) 81 | tx: optax.GradientTransformation = struct.field(pytree_node=False) 82 | opt_state: optax.OptState = struct.field(pytree_node=True) 83 | batch_stats: core.FrozenDict[str, Any] = struct.field(pytree_node=True) 84 
| 85 | def apply_gradients(self, *, grads, **kwargs): 86 | """Updates ``step``, ``params``, ``opt_state`` and ``**kwargs`` in return value. 87 | 88 | Note that internally this function calls ``.tx.update()`` followed by a call 89 | to ``optax.apply_updates()`` to update ``params`` and ``opt_state``. 90 | 91 | Args: 92 | grads: Gradients that have the same pytree structure as ``.params``. 93 | **kwargs: Additional dataclass attributes that should be ``.replace()``-ed. 94 | 95 | Returns: 96 | An updated instance of ``self`` with ``step`` incremented by one, ``params`` 97 | and ``opt_state`` updated by applying ``grads``, and additional attributes 98 | replaced as specified by ``kwargs``. 99 | """ 100 | if OVERWRITE_WITH_GRADIENT in grads: 101 | grads_with_opt = grads['params'] 102 | params_with_opt = self.params['params'] 103 | else: 104 | grads_with_opt = grads 105 | params_with_opt = self.params 106 | 107 | updates, new_opt_state = self.tx.update( 108 | grads_with_opt, self.opt_state, params_with_opt 109 | ) 110 | new_params_with_opt = optax.apply_updates(params_with_opt, updates) 111 | 112 | # As implied by the OWG name, the gradients are used directly to update the 113 | # parameters. 114 | if OVERWRITE_WITH_GRADIENT in grads: 115 | new_params = { 116 | 'params': new_params_with_opt, 117 | OVERWRITE_WITH_GRADIENT: grads[OVERWRITE_WITH_GRADIENT], 118 | } 119 | else: 120 | new_params = new_params_with_opt 121 | return self.replace( 122 | step=self.step + 1, 123 | params=new_params, 124 | opt_state=new_opt_state, 125 | **kwargs, 126 | ) 127 | 128 | @classmethod 129 | def create(cls, *, apply_fn, params, tx, **kwargs): 130 | """Creates a new instance with ``step=0`` and initialized ``opt_state``.""" 131 | # We exclude OWG params when present because they do not need opt states. 
132 | params_with_opt = ( 133 | params['params'] if OVERWRITE_WITH_GRADIENT in params else params 134 | ) 135 | opt_state = tx.init(params_with_opt) 136 | return cls( 137 | step=0, 138 | apply_fn=apply_fn, 139 | params=params, 140 | tx=tx, 141 | opt_state=opt_state, 142 | **kwargs, 143 | ) -------------------------------------------------------------------------------- /ygoai/rl/utils.py: -------------------------------------------------------------------------------- 1 | import re 2 | 3 | import optree 4 | import torch 5 | 6 | from ygoai.rl.env import RecordEpisodeStatistics, EnvPreprocess 7 | 8 | 9 | def split_param_groups(model, regex): 10 | embed_params = [] 11 | other_params = [] 12 | for name, param in model.named_parameters(): 13 | if re.search(regex, name): 14 | embed_params.append(param) 15 | else: 16 | other_params.append(param) 17 | return [ 18 | {'params': embed_params}, {'params': other_params} 19 | ] 20 | 21 | 22 | class Elo: 23 | 24 | def __init__(self, k = 10, r0 = 1500, r1 = 1500): 25 | self.r0 = r0 26 | self.r1 = r1 27 | self.k = k 28 | 29 | def update(self, winner): 30 | diff = self.k * (1 - self.expect_result(self.r0, self.r1)) 31 | if winner == 1: 32 | diff = -diff 33 | self.r0 += diff 34 | self.r1 -= diff 35 | 36 | def expect_result(self, p0, p1): 37 | exp = (p0 - p1) / 400.0 38 | return 1 / ((10.0 ** (exp)) + 1) 39 | 40 | 41 | def masked_mean(x, valid): 42 | x = x.masked_fill(~valid, 0) 43 | return x.sum() / valid.float().sum() 44 | 45 | 46 | def masked_normalize(x, valid, eps=1e-8): 47 | x = x.masked_fill(~valid, 0) 48 | n = valid.float().sum() 49 | mean = x.sum() / n 50 | var = ((x - mean) ** 2).sum() / n 51 | std = (var + eps).sqrt() 52 | return (x - mean) / std 53 | 54 | 55 | def to_tensor(x, device, dtype=None): 56 | return optree.tree_map(lambda x: torch.from_numpy(x).to(device=device, dtype=dtype, non_blocking=True), x) 57 | 58 | -------------------------------------------------------------------------------- /ygoai/utils.py: 
-------------------------------------------------------------------------------- 1 | import pickle 2 | import numpy as np 3 | from pathlib import Path 4 | 5 | 6 | def load_deck(fn): 7 | with open(fn) as f: 8 | lines = f.readlines() 9 | deck = [int(line) for line in lines if line[:-1].isdigit()] 10 | return deck 11 | 12 | 13 | def get_root_directory(): 14 | cur = Path(__file__).resolve() 15 | return str(cur.parent.parent) 16 | 17 | 18 | def extract_deck_name(path): 19 | return Path(path).stem 20 | 21 | _languages = { 22 | "english": "en", 23 | "chinese": "zh", 24 | } 25 | 26 | def init_ygopro(env_id, lang, deck, code_list_file, preload_tokens=False, return_deck_names=False): 27 | short = _languages[lang] 28 | db_path = Path(get_root_directory(), 'assets', 'locale', short, 'cards.cdb') 29 | deck_fp = Path(deck) 30 | if deck_fp.is_dir(): 31 | decks = {f.stem: str(f) for f in deck_fp.glob("*.ydk")} 32 | deck_dir = deck_fp 33 | deck_name = 'random' 34 | else: 35 | deck_name = deck_fp.stem 36 | decks = {deck_name: deck} 37 | deck_dir = deck_fp.parent 38 | if preload_tokens: 39 | token_deck = deck_dir / "_tokens.ydk" 40 | if not token_deck.exists(): 41 | raise FileNotFoundError(f"Token deck not found: {token_deck}") 42 | decks["_tokens"] = str(token_deck) 43 | if 'YGOPro' in env_id: 44 | if env_id == 'YGOPro-v1': 45 | from ygoenv.ygopro import init_module 46 | elif env_id == 'YGOPro-v0': 47 | from ygoenv.ygopro0 import init_module 48 | else: 49 | raise ValueError(f"Unknown YGOPro environment: {env_id}") 50 | elif 'EDOPro' in env_id: 51 | from ygoenv.edopro import init_module 52 | init_module(str(db_path), code_list_file, decks) 53 | if return_deck_names: 54 | if "_tokens" in decks: 55 | del decks["_tokens"] 56 | return deck_name, list(decks.keys()) 57 | return deck_name 58 | 59 | 60 | def load_embeddings(embedding_file, code_list_file, pad_to=999): 61 | with open(embedding_file, "rb") as f: 62 | embeddings = pickle.load(f) 63 | with open(code_list_file, "r") as f: 64 | 
code_list = f.readlines() 65 | code_list = [int(code.strip()) for code in code_list] 66 | assert len(embeddings) == len(code_list), f"len(embeddings)={len(embeddings)}, len(code_list)={len(code_list)}" 67 | embeddings = np.array([embeddings[code] for code in code_list], dtype=np.float32) 68 | if pad_to is not None: 69 | assert pad_to >= len(embeddings), f"pad_to={pad_to} < len(embeddings)={len(embeddings)}" 70 | pad = np.zeros((pad_to - len(embeddings), embeddings.shape[1]), dtype=np.float32) 71 | embeddings = np.concatenate([embeddings, pad], axis=0) 72 | return embeddings -------------------------------------------------------------------------------- /ygoenv/MANIFEST.in: -------------------------------------------------------------------------------- 1 | include ygoenv/*/*.so -------------------------------------------------------------------------------- /ygoenv/setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages 2 | 3 | __version__ = "0.0.1" 4 | 5 | INSTALL_REQUIRES = [ 6 | "setuptools", 7 | "wheel", 8 | "numpy", 9 | "dm-env", 10 | "gym>=0.26", 11 | "gymnasium>=0.26,!=0.27.0", 12 | "optree>=0.6.0", 13 | "packaging", 14 | ] 15 | 16 | setup( 17 | name="ygoenv", 18 | version=__version__, 19 | packages=find_packages(include='ygoenv*'), 20 | long_description="", 21 | install_requires=INSTALL_REQUIRES, 22 | python_requires=">=3.10", 23 | include_package_data=True, 24 | ) -------------------------------------------------------------------------------- /ygoenv/ygoenv/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 Garena Online Private Limited 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """EnvPool package for efficient RL environment simulation.""" 15 | 16 | import ygoenv.entry # noqa: F401 17 | from ygoenv.registration import ( 18 | list_all_envs, 19 | make, 20 | make_dm, 21 | make_gym, 22 | make_gymnasium, 23 | make_spec, 24 | register, 25 | ) 26 | 27 | __version__ = "0.8.4" 28 | __all__ = [ 29 | "register", 30 | "make", 31 | "make_dm", 32 | "make_gym", 33 | "make_gymnasium", 34 | "make_spec", 35 | "list_all_envs", 36 | ] 37 | -------------------------------------------------------------------------------- /ygoenv/ygoenv/core/ThreadPool.h: -------------------------------------------------------------------------------- 1 | #ifndef THREAD_POOL_H 2 | #define THREAD_POOL_H 3 | 4 | #include 5 | #include 6 | #include 7 | #include 8 | #include 9 | #include 10 | #include 11 | #include 12 | #include 13 | 14 | class ThreadPool { 15 | public: 16 | ThreadPool(size_t); 17 | template 18 | auto enqueue(F&& f, Args&&... 
args) 19 | -> std::future::type>; 20 | ~ThreadPool(); 21 | private: 22 | // need to keep track of threads so we can join them 23 | std::vector< std::thread > workers; 24 | // the task queue 25 | std::queue< std::function > tasks; 26 | 27 | // synchronization 28 | std::mutex queue_mutex; 29 | std::condition_variable condition; 30 | bool stop; 31 | }; 32 | 33 | // the constructor just launches some amount of workers 34 | inline ThreadPool::ThreadPool(size_t threads) 35 | : stop(false) 36 | { 37 | for(size_t i = 0;i task; 44 | 45 | { 46 | std::unique_lock lock(this->queue_mutex); 47 | this->condition.wait(lock, 48 | [this]{ return this->stop || !this->tasks.empty(); }); 49 | if(this->stop && this->tasks.empty()) 50 | return; 51 | task = std::move(this->tasks.front()); 52 | this->tasks.pop(); 53 | } 54 | 55 | task(); 56 | } 57 | } 58 | ); 59 | } 60 | 61 | // add new work item to the pool 62 | template 63 | auto ThreadPool::enqueue(F&& f, Args&&... args) 64 | -> std::future::type> 65 | { 66 | using return_type = typename std::result_of::type; 67 | 68 | auto task = std::make_shared< std::packaged_task >( 69 | std::bind(std::forward(f), std::forward(args)...) 
70 | ); 71 | 72 | std::future res = task->get_future(); 73 | { 74 | std::unique_lock lock(queue_mutex); 75 | 76 | // don't allow enqueueing after stopping the pool 77 | if(stop) 78 | throw std::runtime_error("enqueue on stopped ThreadPool"); 79 | 80 | tasks.emplace([task](){ (*task)(); }); 81 | } 82 | condition.notify_one(); 83 | return res; 84 | } 85 | 86 | // the destructor joins all threads 87 | inline ThreadPool::~ThreadPool() 88 | { 89 | { 90 | std::unique_lock lock(queue_mutex); 91 | stop = true; 92 | } 93 | condition.notify_all(); 94 | for(std::thread &worker: workers) 95 | worker.join(); 96 | } 97 | 98 | #endif -------------------------------------------------------------------------------- /ygoenv/ygoenv/core/action_buffer_queue.h: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2021 Garena Online Private Limited 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | */ 16 | 17 | #ifndef YGOENV_CORE_ACTION_BUFFER_QUEUE_H_ 18 | #define YGOENV_CORE_ACTION_BUFFER_QUEUE_H_ 19 | 20 | #ifndef MOODYCAMEL_DELETE_FUNCTION 21 | #define MOODYCAMEL_DELETE_FUNCTION = delete 22 | #endif 23 | 24 | #include 25 | #include 26 | #include 27 | #include 28 | 29 | #include "ygoenv/core/array.h" 30 | #include "concurrentqueue/moodycamel/lightweightsemaphore.h" 31 | 32 | /** 33 | * Lock-free action buffer queue. 
34 | */ 35 | class ActionBufferQueue { 36 | public: 37 | struct ActionSlice { 38 | int env_id; 39 | int order; 40 | bool force_reset; 41 | }; 42 | 43 | protected: 44 | std::atomic alloc_ptr_, done_ptr_; 45 | std::size_t queue_size_; 46 | std::vector queue_; 47 | moodycamel::LightweightSemaphore sem_, sem_enqueue_, sem_dequeue_; 48 | 49 | public: 50 | explicit ActionBufferQueue(std::size_t num_envs) 51 | : alloc_ptr_(0), 52 | done_ptr_(0), 53 | queue_size_(num_envs * 2), 54 | queue_(queue_size_), 55 | sem_(0), 56 | sem_enqueue_(1), 57 | sem_dequeue_(1) {} 58 | 59 | void EnqueueBulk(const std::vector& action) { 60 | // ensure only one enqueue_bulk happens at any time 61 | while (!sem_enqueue_.wait()) { 62 | } 63 | uint64_t pos = alloc_ptr_.fetch_add(action.size()); 64 | for (std::size_t i = 0; i < action.size(); ++i) { 65 | queue_[(pos + i) % queue_size_] = action[i]; 66 | } 67 | sem_.signal(action.size()); 68 | sem_enqueue_.signal(1); 69 | } 70 | 71 | ActionSlice Dequeue() { 72 | while (!sem_.wait()) { 73 | } 74 | while (!sem_dequeue_.wait()) { 75 | } 76 | auto ptr = done_ptr_.fetch_add(1); 77 | auto ret = queue_[ptr % queue_size_]; 78 | sem_dequeue_.signal(1); 79 | return ret; 80 | } 81 | 82 | std::size_t SizeApprox() { 83 | return static_cast(alloc_ptr_ - done_ptr_); 84 | } 85 | }; 86 | 87 | #endif // YGOENV_CORE_ACTION_BUFFER_QUEUE_H_ 88 | -------------------------------------------------------------------------------- /ygoenv/ygoenv/core/async_envpool.h: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2021 Garena Online Private Limited 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 
6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | */ 16 | 17 | #ifndef YGOENV_CORE_ASYNC_ENVPOOL_H_ 18 | #define YGOENV_CORE_ASYNC_ENVPOOL_H_ 19 | 20 | #include 21 | #include 22 | #include 23 | #include 24 | #include 25 | #include 26 | 27 | #include "ThreadPool.h" 28 | #include "ygoenv/core/action_buffer_queue.h" 29 | #include "ygoenv/core/array.h" 30 | #include "ygoenv/core/envpool.h" 31 | #include "ygoenv/core/state_buffer_queue.h" 32 | /** 33 | * Async EnvPool 34 | * 35 | * batch-action -> action buffer queue -> threadpool -> state buffer queue 36 | * 37 | * ThreadPool is tailored with EnvPool, so here we don't use the existing 38 | * third_party ThreadPool (which is really slow). 
39 | */ 40 | template 41 | class AsyncEnvPool : public EnvPool { 42 | protected: 43 | std::size_t num_envs_; 44 | std::size_t batch_; 45 | std::size_t max_num_players_; 46 | std::size_t num_threads_; 47 | bool is_sync_; 48 | std::atomic stop_; 49 | std::atomic stepping_env_num_; 50 | std::vector workers_; 51 | std::unique_ptr action_buffer_queue_; 52 | std::unique_ptr state_buffer_queue_; 53 | std::vector> envs_; 54 | std::vector> stepping_env_; 55 | std::chrono::duration dur_send_, dur_recv_, dur_send_all_; 56 | 57 | template 58 | void SendImpl(V&& action) { 59 | int* env_id = static_cast(action[0].Data()); 60 | int shared_offset = action[0].Shape(0); 61 | std::vector actions; 62 | std::shared_ptr> action_batch = 63 | std::make_shared>(std::forward(action)); 64 | for (int i = 0; i < shared_offset; ++i) { 65 | int eid = env_id[i]; 66 | envs_[eid]->SetAction(action_batch, i); 67 | actions.emplace_back(ActionSlice{ 68 | .env_id = eid, 69 | .order = is_sync_ ? i : -1, 70 | .force_reset = false, 71 | }); 72 | } 73 | if (is_sync_) { 74 | stepping_env_num_ += shared_offset; 75 | } 76 | // add to abq 77 | auto start = std::chrono::system_clock::now(); 78 | action_buffer_queue_->EnqueueBulk(actions); 79 | dur_send_ += std::chrono::system_clock::now() - start; 80 | } 81 | 82 | public: 83 | using Spec = typename Env::Spec; 84 | using Action = typename Env::Action; 85 | using State = typename Env::State; 86 | using ActionSlice = typename ActionBufferQueue::ActionSlice; 87 | 88 | explicit AsyncEnvPool(const Spec& spec) 89 | : EnvPool(spec), 90 | num_envs_(spec.config["num_envs"_]), 91 | batch_(spec.config["batch_size"_] <= 0 ? 
num_envs_ 92 | : spec.config["batch_size"_]), 93 | max_num_players_(spec.config["max_num_players"_]), 94 | num_threads_(spec.config["num_threads"_]), 95 | is_sync_(batch_ == num_envs_ && max_num_players_ == 1), 96 | stop_(0), 97 | stepping_env_num_(0), 98 | action_buffer_queue_(new ActionBufferQueue(num_envs_)), 99 | state_buffer_queue_(new StateBufferQueue( 100 | batch_, num_envs_, max_num_players_, 101 | spec.state_spec.template AllValues())), 102 | envs_(num_envs_) { 103 | std::size_t processor_count = std::thread::hardware_concurrency(); 104 | ThreadPool init_pool(std::min(processor_count, num_envs_)); 105 | std::vector> result; 106 | for (std::size_t i = 0; i < num_envs_; ++i) { 107 | result.emplace_back(init_pool.enqueue( 108 | [i, spec, this] { envs_[i].reset(new Env(spec, i)); })); 109 | } 110 | for (auto& f : result) { 111 | f.get(); 112 | } 113 | if (num_threads_ == 0) { 114 | num_threads_ = std::min(batch_, processor_count); 115 | } 116 | for (std::size_t i = 0; i < num_threads_; ++i) { 117 | workers_.emplace_back([this] { 118 | for (;;) { 119 | ActionSlice raw_action = action_buffer_queue_->Dequeue(); 120 | if (stop_ == 1) { 121 | break; 122 | } 123 | int env_id = raw_action.env_id; 124 | int order = raw_action.order; 125 | bool reset = raw_action.force_reset || envs_[env_id]->IsDone(); 126 | envs_[env_id]->EnvStep(state_buffer_queue_.get(), order, reset); 127 | } 128 | }); 129 | } 130 | if (spec.config["thread_affinity_offset"_] >= 0) { 131 | std::size_t thread_affinity_offset = 132 | spec.config["thread_affinity_offset"_]; 133 | for (std::size_t tid = 0; tid < num_threads_; ++tid) { 134 | cpu_set_t cpuset; 135 | CPU_ZERO(&cpuset); 136 | std::size_t cid = (thread_affinity_offset + tid) % processor_count; 137 | CPU_SET(cid, &cpuset); 138 | pthread_setaffinity_np(workers_[tid].native_handle(), sizeof(cpu_set_t), 139 | &cpuset); 140 | } 141 | } 142 | } 143 | 144 | ~AsyncEnvPool() override { 145 | stop_ = 1; 146 | // LOG(INFO) << "envpool send: " << 
dur_send_.count(); 147 | // LOG(INFO) << "envpool recv: " << dur_recv_.count(); 148 | // send n actions to clear threadpool 149 | std::vector empty_actions(workers_.size()); 150 | action_buffer_queue_->EnqueueBulk(empty_actions); 151 | for (auto& worker : workers_) { 152 | worker.join(); 153 | } 154 | } 155 | 156 | void Send(const Action& action) { 157 | SendImpl(action.template AllValues()); 158 | } 159 | void Send(const std::vector& action) override { SendImpl(action); } 160 | void Send(std::vector&& action) override { SendImpl(action); } 161 | 162 | std::vector Recv() override { 163 | int additional_wait = 0; 164 | if (is_sync_ && stepping_env_num_ < batch_) { 165 | additional_wait = batch_ - stepping_env_num_; 166 | } 167 | auto start = std::chrono::system_clock::now(); 168 | auto ret = state_buffer_queue_->Wait(additional_wait); 169 | dur_recv_ += std::chrono::system_clock::now() - start; 170 | if (is_sync_) { 171 | stepping_env_num_ -= ret[0].Shape(0); 172 | } 173 | return ret; 174 | } 175 | 176 | void Reset(const Array& env_ids) override { 177 | TArray tenv_ids(env_ids); 178 | int shared_offset = tenv_ids.Shape(0); 179 | std::vector actions(shared_offset); 180 | for (int i = 0; i < shared_offset; ++i) { 181 | actions[i].force_reset = true; 182 | actions[i].env_id = tenv_ids[i]; 183 | actions[i].order = is_sync_ ? i : -1; 184 | } 185 | if (is_sync_) { 186 | stepping_env_num_ += shared_offset; 187 | } 188 | action_buffer_queue_->EnqueueBulk(actions); 189 | } 190 | }; 191 | 192 | #endif // YGOENV_CORE_ASYNC_ENVPOOL_H_ 193 | -------------------------------------------------------------------------------- /ygoenv/ygoenv/core/circular_buffer.h: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2021 Garena Online Private Limited 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 
6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | */ 16 | 17 | #ifndef YGOENV_CORE_CIRCULAR_BUFFER_H_ 18 | #define YGOENV_CORE_CIRCULAR_BUFFER_H_ 19 | 20 | #ifndef MOODYCAMEL_DELETE_FUNCTION 21 | #define MOODYCAMEL_DELETE_FUNCTION = delete 22 | #endif 23 | 24 | #include 25 | #include 26 | #include 27 | #include 28 | #include 29 | #include 30 | #include 31 | 32 | #include "concurrentqueue/moodycamel/lightweightsemaphore.h" 33 | 34 | template 35 | class CircularBuffer { 36 | protected: 37 | std::size_t size_; 38 | moodycamel::LightweightSemaphore sem_get_; 39 | moodycamel::LightweightSemaphore sem_put_; 40 | std::vector buffer_; 41 | std::atomic head_; 42 | std::atomic tail_; 43 | 44 | public: 45 | explicit CircularBuffer(std::size_t size) 46 | : size_(size), sem_put_(size), buffer_(size), head_(0), tail_(0) {} 47 | 48 | template 49 | void Put(T&& v) { 50 | while (!sem_put_.wait()) { 51 | } 52 | uint64_t tail = tail_.fetch_add(1); 53 | auto offset = tail % size_; 54 | buffer_[offset] = std::forward(v); 55 | sem_get_.signal(); 56 | } 57 | 58 | V Get() { 59 | while (!sem_get_.wait()) { 60 | } 61 | uint64_t head = head_.fetch_add(1); 62 | auto offset = head % size_; 63 | V v = std::move(buffer_[offset]); 64 | sem_put_.signal(); 65 | return v; 66 | } 67 | }; 68 | 69 | #endif // YGOENV_CORE_CIRCULAR_BUFFER_H_ 70 | -------------------------------------------------------------------------------- /ygoenv/ygoenv/core/env_spec.h: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2021 Garena Online Private Limited 3 | * 4 | * 
Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | */ 16 | 17 | #ifndef YGOENV_CORE_ENV_SPEC_H_ 18 | #define YGOENV_CORE_ENV_SPEC_H_ 19 | 20 | #include 21 | #include 22 | 23 | #include "ygoenv/core/dict.h" 24 | 25 | auto common_config = 26 | MakeDict("num_envs"_.Bind(1), "batch_size"_.Bind(0), "num_threads"_.Bind(0), 27 | "max_num_players"_.Bind(1), "thread_affinity_offset"_.Bind(-1), 28 | "base_path"_.Bind(std::string("ygoenv")), "seed"_.Bind(42), 29 | "gym_reset_return_info"_.Bind(false), 30 | "max_episode_steps"_.Bind(std::numeric_limits::max())); 31 | // Note: this action order is hardcoded in async_envpool Send function 32 | // and env ParseAction function for performance 33 | auto common_action_spec = MakeDict("env_id"_.Bind(Spec({})), 34 | "players.env_id"_.Bind(Spec({-1}))); 35 | // Note: this state order is hardcoded in async_envpool Recv function 36 | auto common_state_spec = 37 | MakeDict("info:env_id"_.Bind(Spec({})), 38 | "info:players.env_id"_.Bind(Spec({-1})), 39 | "elapsed_step"_.Bind(Spec({})), "done"_.Bind(Spec({})), 40 | "reward"_.Bind(Spec({-1})), 41 | "discount"_.Bind(Spec({-1}, {0.0, 1.0})), 42 | "step_type"_.Bind(Spec({})), "trunc"_.Bind(Spec({}))); 43 | 44 | /** 45 | * EnvSpec funciton, it constructs the env spec when a Config is passed. 
46 | */ 47 | template 48 | class EnvSpec { 49 | public: 50 | using EnvFnsType = EnvFns; 51 | using Config = decltype(ConcatDict(common_config, EnvFns::DefaultConfig())); 52 | using ConfigKeys = typename Config::Keys; 53 | using ConfigValues = typename Config::Values; 54 | using StateSpec = decltype(ConcatDict( 55 | common_state_spec, EnvFns::StateSpec(std::declval()))); 56 | using ActionSpec = decltype(ConcatDict( 57 | common_action_spec, EnvFns::ActionSpec(std::declval()))); 58 | using StateKeys = typename StateSpec::Keys; 59 | using ActionKeys = typename ActionSpec::Keys; 60 | 61 | // For C++ 62 | Config config; 63 | StateSpec state_spec; 64 | ActionSpec action_spec; 65 | static inline const Config kDefaultConfig = 66 | ConcatDict(common_config, EnvFns::DefaultConfig()); 67 | 68 | EnvSpec() : EnvSpec(kDefaultConfig) {} 69 | explicit EnvSpec(const ConfigValues& conf) 70 | : config(conf), 71 | state_spec(ConcatDict(common_state_spec, EnvFns::StateSpec(config))), 72 | action_spec( 73 | ConcatDict(common_action_spec, EnvFns::ActionSpec(config))) { 74 | if (config["batch_size"_] > config["num_envs"_]) { 75 | throw std::invalid_argument( 76 | "It is required that batch_size <= num_envs, got num_envs = " + 77 | std::to_string(config["num_envs"_]) + 78 | ", batch_size = " + std::to_string(config["batch_size"_])); 79 | } 80 | if (config["batch_size"_] == 0) { 81 | config["batch_size"_] = config["num_envs"_]; 82 | } 83 | } 84 | }; 85 | 86 | #endif // YGOENV_CORE_ENV_SPEC_H_ 87 | -------------------------------------------------------------------------------- /ygoenv/ygoenv/core/envpool.h: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2021 Garena Online Private Limited 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 
6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | */ 16 | 17 | #ifndef YGOENV_CORE_ENVPOOL_H_ 18 | #define YGOENV_CORE_ENVPOOL_H_ 19 | 20 | #include 21 | #include 22 | 23 | #include "ygoenv/core/env_spec.h" 24 | 25 | /** 26 | * Templated subclass of EnvPool, to be overrided by the real EnvPool. 27 | */ 28 | template 29 | class EnvPool { 30 | public: 31 | EnvSpec spec; 32 | using Spec = EnvSpec; 33 | using State = NamedVector>; 34 | using Action = NamedVector>; 35 | explicit EnvPool(EnvSpec spec) : spec(std::move(spec)) {} 36 | virtual ~EnvPool() = default; 37 | 38 | protected: 39 | virtual void Send(const std::vector& action) { 40 | throw std::runtime_error("send not implemented"); 41 | } 42 | virtual void Send(std::vector&& action) { 43 | throw std::runtime_error("send not implemented"); 44 | } 45 | virtual std::vector Recv() { 46 | throw std::runtime_error("recv not implemented"); 47 | } 48 | virtual void Reset(const Array& env_ids) { 49 | throw std::runtime_error("reset not implemented"); 50 | } 51 | }; 52 | 53 | #endif // YGOENV_CORE_ENVPOOL_H_ 54 | -------------------------------------------------------------------------------- /ygoenv/ygoenv/core/spec.h: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2021 Garena Online Private Limited 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 
6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | */ 16 | 17 | #ifndef YGOENV_CORE_SPEC_H_ 18 | #define YGOENV_CORE_SPEC_H_ 19 | 20 | #include 21 | 22 | #include 23 | #include 24 | #include 25 | #include 26 | #include 27 | #include 28 | #include 29 | #include 30 | 31 | static std::size_t Prod(const std::size_t* shape, std::size_t ndim) { 32 | return std::accumulate(shape, shape + ndim, static_cast(1), 33 | std::multiplies<>()); 34 | } 35 | 36 | class ShapeSpec { 37 | public: 38 | int element_size; 39 | std::vector shape; 40 | ShapeSpec() = default; 41 | ShapeSpec(int element_size, std::vector shape_vec) 42 | : element_size(element_size), shape(std::move(shape_vec)) {} 43 | [[nodiscard]] ShapeSpec Batch(int batch_size) const { 44 | std::vector new_shape = {batch_size}; 45 | new_shape.insert(new_shape.end(), shape.begin(), shape.end()); 46 | return {element_size, std::move(new_shape)}; 47 | } 48 | [[nodiscard]] std::vector Shape() const { 49 | auto s = std::vector(shape.size()); 50 | for (std::size_t i = 0; i < shape.size(); ++i) { 51 | s[i] = shape[i]; 52 | } 53 | return s; 54 | } 55 | }; 56 | 57 | template 58 | class Spec : public ShapeSpec { 59 | public: 60 | using dtype = D; // NOLINT 61 | std::tuple bounds = {std::numeric_limits::min(), 62 | std::numeric_limits::max()}; 63 | std::tuple, std::vector> elementwise_bounds; 64 | explicit Spec(std::vector&& shape) 65 | : ShapeSpec(sizeof(dtype), std::move(shape)) {} 66 | explicit Spec(const std::vector& shape) 67 | : ShapeSpec(sizeof(dtype), shape) {} 68 | 69 | /* init with constant bounds */ 70 | Spec(std::vector&& shape, 
std::tuple&& bounds) 71 | : ShapeSpec(sizeof(dtype), std::move(shape)), bounds(std::move(bounds)) {} 72 | Spec(const std::vector& shape, const std::tuple& bounds) 73 | : ShapeSpec(sizeof(dtype), shape), bounds(bounds) {} 74 | 75 | /* init with elementwise bounds */ 76 | Spec(std::vector&& shape, 77 | std::tuple, std::vector>&& elementwise_bounds) 78 | : ShapeSpec(sizeof(dtype), std::move(shape)), 79 | elementwise_bounds(std::move(elementwise_bounds)) {} 80 | Spec(const std::vector& shape, 81 | const std::tuple, std::vector>& 82 | elementwise_bounds) 83 | : ShapeSpec(sizeof(dtype), shape), 84 | elementwise_bounds(elementwise_bounds) {} 85 | 86 | [[nodiscard]] Spec Batch(int batch_size) const { 87 | std::vector new_shape = {batch_size}; 88 | new_shape.insert(new_shape.end(), shape.begin(), shape.end()); 89 | return Spec(std::move(new_shape)); 90 | } 91 | }; 92 | 93 | template 94 | class TArray; 95 | 96 | template 97 | using Container = std::unique_ptr>; 98 | 99 | template 100 | class Spec> : public ShapeSpec { 101 | public: 102 | using dtype = Container; // NOLINT 103 | Spec inner_spec; 104 | explicit Spec(const std::vector& shape, const Spec& inner_spec) 105 | : ShapeSpec(sizeof(Container), shape), inner_spec(inner_spec) {} 106 | explicit Spec(std::vector&& shape, Spec&& inner_spec) 107 | : ShapeSpec(sizeof(Container), std::move(shape)), 108 | inner_spec(std::move(inner_spec)) {} 109 | }; 110 | 111 | #endif // YGOENV_CORE_SPEC_H_ 112 | -------------------------------------------------------------------------------- /ygoenv/ygoenv/core/state_buffer.h: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2021 Garena Online Private Limited 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 
6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | */ 16 | 17 | #ifndef YGOENV_CORE_STATE_BUFFER_H_ 18 | #define YGOENV_CORE_STATE_BUFFER_H_ 19 | 20 | #ifndef MOODYCAMEL_DELETE_FUNCTION 21 | #define MOODYCAMEL_DELETE_FUNCTION = delete 22 | #endif 23 | 24 | #include 25 | #include 26 | #include 27 | #include 28 | 29 | #include "ygoenv/core/array.h" 30 | #include "ygoenv/core/dict.h" 31 | #include "ygoenv/core/spec.h" 32 | #include "concurrentqueue/moodycamel/lightweightsemaphore.h" 33 | 34 | /** 35 | * Buffer of a batch of states, which is used as an intermediate storage device 36 | * for the environments to write their state outputs of each step. 37 | * There's a quota for how many envs' results are stored in this buffer, 38 | * which is controlled by the batch argments in the constructor. 39 | */ 40 | class StateBuffer { 41 | protected: 42 | std::size_t batch_; 43 | std::size_t max_num_players_; 44 | std::vector arrays_; 45 | std::vector is_player_state_; 46 | std::atomic offsets_{0}; 47 | std::atomic alloc_count_{0}; 48 | std::atomic done_count_{0}; 49 | moodycamel::LightweightSemaphore sem_; 50 | 51 | public: 52 | /** 53 | * Return type of StateBuffer.Allocate is a slice of each state arrays that 54 | * can be written by the caller. When writing is done, the caller should 55 | * invoke done write. 56 | */ 57 | struct WritableSlice { 58 | std::vector arr; 59 | std::function done_write; 60 | }; 61 | 62 | /** 63 | * Create a StateBuffer instance with the player_specs and shared_specs 64 | * provided. 
65 | */ 66 | StateBuffer(std::size_t batch, std::size_t max_num_players, 67 | const std::vector& specs, 68 | std::vector is_player_state) 69 | : batch_(batch), 70 | max_num_players_(max_num_players), 71 | arrays_(MakeArray(specs)), 72 | is_player_state_(std::move(is_player_state)) {} 73 | 74 | /** 75 | * Tries to allocate a piece of memory without lock. 76 | * If this buffer runs out of quota, an out_of_range exception is thrown. 77 | * Externally, caller has to catch the exception and handle accordingly. 78 | */ 79 | WritableSlice Allocate(std::size_t num_players, int order = -1) { 80 | DCHECK_LE(num_players, max_num_players_); 81 | std::size_t alloc_count = alloc_count_.fetch_add(1); 82 | if (alloc_count < batch_) { 83 | // Make a increment atomically on two uint32_t simultaneously 84 | // This avoids lock 85 | uint64_t increment = static_cast(num_players) << 32 | 1; 86 | uint64_t offsets = offsets_.fetch_add(increment); 87 | uint32_t player_offset = offsets >> 32; 88 | uint32_t shared_offset = offsets; 89 | DCHECK_LE((std::size_t)shared_offset + 1, batch_); 90 | DCHECK_LE((std::size_t)(player_offset + num_players), 91 | batch_ * max_num_players_); 92 | if (order != -1 && max_num_players_ == 1) { 93 | // single player with sync setting: return ordered data 94 | player_offset = shared_offset = order; 95 | } 96 | std::vector state; 97 | state.reserve(arrays_.size()); 98 | for (std::size_t i = 0; i < arrays_.size(); ++i) { 99 | const Array& a = arrays_[i]; 100 | if (is_player_state_[i]) { 101 | state.emplace_back( 102 | a.Slice(player_offset, player_offset + num_players)); 103 | } else { 104 | state.emplace_back(a[shared_offset]); 105 | } 106 | } 107 | return WritableSlice{.arr = std::move(state), 108 | .done_write = [this]() { Done(); }}; 109 | } 110 | DLOG(INFO) << "Allocation failed, continue to the next block of memory"; 111 | throw std::out_of_range("StateBuffer out of storage"); 112 | } 113 | 114 | [[nodiscard]] std::pair Offsets() const { 115 | uint32_t 
player_offset = offsets_ >> 32; 116 | uint32_t shared_offset = offsets_; 117 | return {player_offset, shared_offset}; 118 | } 119 | 120 | /** 121 | * When the allocated memory has been filled, the user of the memory will 122 | * call this callback to notify StateBuffer that its part has been written. 123 | */ 124 | void Done(std::size_t num = 1) { 125 | std::size_t done_count = done_count_.fetch_add(num); 126 | if (done_count + num == batch_) { 127 | sem_.signal(); 128 | } 129 | } 130 | 131 | /** 132 | * Blocks until the entire buffer is ready, aka, all quota has been 133 | * distributed out, and all user has called done. 134 | */ 135 | std::vector Wait(std::size_t additional_done_count = 0) { 136 | if (additional_done_count > 0) { 137 | Done(additional_done_count); 138 | } 139 | while (!sem_.wait()) { 140 | } 141 | // when things are all done, compact the buffer. 142 | uint64_t offsets = offsets_; 143 | uint32_t player_offset = (offsets >> 32); 144 | uint32_t shared_offset = offsets; 145 | DCHECK_EQ((std::size_t)shared_offset, batch_ - additional_done_count); 146 | std::vector ret; 147 | ret.reserve(arrays_.size()); 148 | for (std::size_t i = 0; i < arrays_.size(); ++i) { 149 | const Array& a = arrays_[i]; 150 | if (is_player_state_[i]) { 151 | ret.emplace_back(a.Truncate(player_offset)); 152 | } else { 153 | ret.emplace_back(a.Truncate(shared_offset)); 154 | } 155 | } 156 | return ret; 157 | } 158 | }; 159 | 160 | #endif // YGOENV_CORE_STATE_BUFFER_H_ 161 | -------------------------------------------------------------------------------- /ygoenv/ygoenv/core/state_buffer_queue.h: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2021 Garena Online Private Limited 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 
6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | */ 16 | 17 | #ifndef YGOENV_CORE_STATE_BUFFER_QUEUE_H_ 18 | #define YGOENV_CORE_STATE_BUFFER_QUEUE_H_ 19 | 20 | #include 21 | #include 22 | #include 23 | #include 24 | #include 25 | #include 26 | 27 | #include "ygoenv/core/array.h" 28 | #include "ygoenv/core/circular_buffer.h" 29 | #include "ygoenv/core/spec.h" 30 | #include "ygoenv/core/state_buffer.h" 31 | 32 | class StateBufferQueue { 33 | protected: 34 | std::size_t batch_; 35 | std::size_t max_num_players_; 36 | std::vector is_player_state_; 37 | std::vector specs_; 38 | std::size_t queue_size_; 39 | std::vector> queue_; 40 | std::atomic alloc_count_, done_ptr_, alloc_tail_; 41 | 42 | // Create stock statebuffers in a background thread 43 | CircularBuffer> stock_buffer_; 44 | std::vector create_buffer_thread_; 45 | std::atomic quit_; 46 | 47 | public: 48 | StateBufferQueue(std::size_t batch_env, std::size_t num_envs, 49 | std::size_t max_num_players, 50 | const std::vector& specs) 51 | : batch_(batch_env), 52 | max_num_players_(max_num_players), 53 | is_player_state_(Transform(specs, 54 | [](const ShapeSpec& s) { 55 | return (!s.shape.empty() && 56 | s.shape[0] == -1); 57 | })), 58 | specs_(Transform(specs, 59 | [=](ShapeSpec s) { 60 | if (!s.shape.empty() && s.shape[0] == -1) { 61 | // If first dim is num_players 62 | s.shape[0] = batch_ * max_num_players_; 63 | return s; 64 | } 65 | return s.Batch(batch_); 66 | })), 67 | // two times enough buffer for all the envs 68 | queue_size_((num_envs / batch_env + 2) * 2), 69 | queue_(queue_size_), // circular buffer 70 | 
alloc_count_(0), 71 | done_ptr_(0), 72 | stock_buffer_((num_envs / batch_env + 2) * 2), 73 | quit_(false) { 74 | // Only initialize first half of the buffer 75 | // At the consumption of each block, the first consumping thread 76 | // will allocate a new state buffer and append to the tail. 77 | // alloc_tail_ = num_envs / batch_env + 2; 78 | for (auto& q : queue_) { 79 | q = std::make_unique(batch_, max_num_players_, specs_, 80 | is_player_state_); 81 | } 82 | std::size_t processor_count = std::thread::hardware_concurrency(); 83 | // hardcode here :( 84 | std::size_t create_buffer_thread_num = std::max(1UL, processor_count / 64); 85 | for (std::size_t i = 0; i < create_buffer_thread_num; ++i) { 86 | create_buffer_thread_.emplace_back(std::thread([&]() { 87 | while (true) { 88 | stock_buffer_.Put(std::make_unique( 89 | batch_, max_num_players_, specs_, is_player_state_)); 90 | if (quit_) { 91 | break; 92 | } 93 | } 94 | })); 95 | } 96 | } 97 | 98 | ~StateBufferQueue() { 99 | // stop the thread 100 | quit_ = true; 101 | for (std::size_t i = 0; i < create_buffer_thread_.size(); ++i) { 102 | stock_buffer_.Get(); 103 | } 104 | for (auto& t : create_buffer_thread_) { 105 | t.join(); 106 | } 107 | } 108 | 109 | /** 110 | * Allocate slice of memory for the current env to write. 111 | * This function is used from the producer side. 112 | * It is safe to access from multiple threads. 113 | */ 114 | StateBuffer::WritableSlice Allocate(std::size_t num_players, int order = -1) { 115 | std::size_t pos = alloc_count_.fetch_add(1); 116 | std::size_t offset = (pos / batch_) % queue_size_; 117 | // if (pos % batch_ == 0) { 118 | // // At the time a new statebuffer is accessed, the first visitor 119 | // allocate 120 | // // a new state buffer and put it at the back of the queue. 
121 | // std::size_t insert_pos = alloc_tail_.fetch_add(1); 122 | // std::size_t insert_offset = insert_pos % queue_size_; 123 | // queue_[insert_offset].reset( 124 | // new StateBuffer(batch_, max_num_players_, specs_, 125 | // is_player_state_)); 126 | // } 127 | return queue_[offset]->Allocate(num_players, order); 128 | } 129 | 130 | /** 131 | * Wait for the state buffer at the head to be ready. 132 | * This function can only be accessed from one thread. 133 | * 134 | * BIG CAVEATE: 135 | * Wait should be accessed from only one thread. 136 | * If Wait is accessed from multiple threads, it is only safe if the finish 137 | * time of each state buffer is in the same order as the allocation time. 138 | */ 139 | std::vector Wait(std::size_t additional_done_count = 0) { 140 | std::unique_ptr newbuf = stock_buffer_.Get(); 141 | std::size_t pos = done_ptr_.fetch_add(1); 142 | std::size_t offset = pos % queue_size_; 143 | auto arr = queue_[offset]->Wait(additional_done_count); 144 | if (additional_done_count > 0) { 145 | // move pointer to the next block 146 | alloc_count_.fetch_add(additional_done_count); 147 | } 148 | std::swap(queue_[offset], newbuf); 149 | return arr; 150 | } 151 | }; 152 | 153 | #endif // YGOENV_CORE_STATE_BUFFER_QUEUE_H_ 154 | -------------------------------------------------------------------------------- /ygoenv/ygoenv/core/tuple_utils.h: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2021 Garena Online Private Limited 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 
6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | */ 16 | 17 | #ifndef YGOENV_CORE_TUPLE_UTILS_H_ 18 | #define YGOENV_CORE_TUPLE_UTILS_H_ 19 | 20 | #include 21 | #include 22 | #include 23 | #include 24 | #include 25 | 26 | template 27 | struct Index; 28 | 29 | template 30 | struct Index> { 31 | static constexpr std::size_t kValue = 0; 32 | }; 33 | 34 | template 35 | struct Index> { 36 | static constexpr std::size_t kValue = 37 | 1 + Index>::kValue; 38 | }; 39 | 40 | template 41 | decltype(auto) ApplyZip(F&& f, K&& k, V&& v, 42 | std::index_sequence /*unused*/) { 43 | return std::invoke(std::forward(f), 44 | std::make_tuple(I, std::get(std::forward(k)), 45 | std::get(std::forward(v)))...); 46 | } 47 | 48 | template 49 | using tuple_cat_t = decltype(std::tuple_cat(std::declval()...)); // NOLINT 50 | 51 | template 52 | decltype(auto) TupleFromVectorImpl(std::index_sequence /*unused*/, 53 | const std::vector& arguments) { 54 | return TupleType(arguments[Is]...); 55 | } 56 | 57 | template 58 | decltype(auto) TupleFromVectorImpl(std::index_sequence /*unused*/, 59 | std::vector&& arguments) { 60 | return TupleType(std::move(arguments[Is])...); 61 | } 62 | 63 | template 64 | decltype(auto) TupleFromVector(V&& arguments) { 65 | return TupleFromVectorImpl( 66 | std::make_index_sequence>{}, 67 | std::forward(arguments)); 68 | } 69 | 70 | #endif // YGOENV_CORE_TUPLE_UTILS_H_ 71 | -------------------------------------------------------------------------------- /ygoenv/ygoenv/core/type_utils.h: -------------------------------------------------------------------------------- 1 | /* 2 
| * Copyright 2021 Garena Online Private Limited 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | */ 16 | 17 | #ifndef YGOENV_CORE_TYPE_UTILS_H_ 18 | #define YGOENV_CORE_TYPE_UTILS_H_ 19 | 20 | #include 21 | #include 22 | #include 23 | #include 24 | 25 | template 26 | struct any_match; 27 | 28 | template class Tuple, typename... Ts> 29 | struct any_match> : std::disjunction...> {}; 30 | 31 | template 32 | struct all_match; 33 | 34 | template class Tuple, typename... Ts> 35 | struct all_match> : std::conjunction...> {}; 36 | 37 | template 38 | struct all_convertible; 39 | 40 | template class Tuple, typename... Fs> 41 | struct all_convertible> 42 | : std::conjunction...> {}; 43 | 44 | template 45 | constexpr bool is_tuple_v = false; // NOLINT 46 | template 47 | constexpr bool is_tuple_v> = true; // NOLINT 48 | 49 | template 50 | constexpr bool is_vector_v = false; // NOLINT 51 | template 52 | constexpr bool is_vector_v> = true; // NOLINT 53 | 54 | #endif // YGOENV_CORE_TYPE_UTILS_H_ 55 | -------------------------------------------------------------------------------- /ygoenv/ygoenv/dummy/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright 2021 Garena Online Private Limited 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 
6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """Dummy Env in EnvPool.""" 16 | 17 | from ygoenv.python.api import py_env 18 | 19 | from .dummy_envpool import _DummyEnvPool, _DummyEnvSpec 20 | 21 | DummyEnvSpec, DummyDMEnvPool, DummyGymEnvPool, DummyGymnasiumEnvPool = py_env( 22 | _DummyEnvSpec, _DummyEnvPool 23 | ) 24 | 25 | __all__ = [ 26 | "DummyEnvSpec", 27 | "DummyDMEnvPool", 28 | "DummyGymEnvPool", 29 | "DummyGymnasiumEnvPool", 30 | ] 31 | -------------------------------------------------------------------------------- /ygoenv/ygoenv/dummy/dummy_envpool.cpp: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 2021 Garena Online Private Limited 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | */ 16 | 17 | #include "ygoenv/dummy/dummy_envpool.h" 18 | 19 | #include "ygoenv/core/py_envpool.h" 20 | 21 | /** 22 | * Wrap the `DummyEnvSpec` and `DummyEnvPool` with the corresponding `PyEnvSpec` 23 | * and `PyEnvPool` template. 
24 | */ 25 | using DummyEnvSpec = PyEnvSpec; 26 | using DummyEnvPool = PyEnvPool; 27 | 28 | /** 29 | * Finally, call the REGISTER macro to expose them to python 30 | */ 31 | PYBIND11_MODULE(dummy_ygoenv, m) { REGISTER(m, DummyEnvSpec, DummyEnvPool) } 32 | -------------------------------------------------------------------------------- /ygoenv/ygoenv/dummy/registration.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 Garena Online Private Limited 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | """Classic control env registration.""" 15 | 16 | from ygoenv.registration import register 17 | 18 | register( 19 | task_id="Dummy-v0", 20 | import_path="ygoenv.dummy", 21 | spec_cls="DummyEnvSpec", 22 | dm_cls="DummyDMEnvPool", 23 | gym_cls="DummyGymEnvPool", 24 | gymnasium_cls="DummyGymnasiumEnvPool", 25 | ) 26 | -------------------------------------------------------------------------------- /ygoenv/ygoenv/edopro/__init__.py: -------------------------------------------------------------------------------- 1 | from ygoenv.python.api import py_env 2 | 3 | from .edopro_ygoenv import ( 4 | _EDOProEnvPool, 5 | _EDOProEnvSpec, 6 | init_module, 7 | ) 8 | 9 | ( 10 | EDOProEnvSpec, 11 | EDOProDMEnvPool, 12 | EDOProGymEnvPool, 13 | EDOProGymnasiumEnvPool, 14 | ) = py_env(_EDOProEnvSpec, _EDOProEnvPool) 15 | 16 | 17 | __all__ = [ 18 | "EDOProEnvSpec", 19 | "EDOProDMEnvPool", 20 | "EDOProGymEnvPool", 21 | "EDOProGymnasiumEnvPool", 22 | ] 23 | -------------------------------------------------------------------------------- /ygoenv/ygoenv/edopro/edopro.cpp: -------------------------------------------------------------------------------- 1 | #include "ygoenv/edopro/edopro.h" 2 | #include "ygoenv/core/py_envpool.h" 3 | 4 | using EDOProEnvSpec = PyEnvSpec; 5 | using EDOProEnvPool = PyEnvPool; 6 | 7 | PYBIND11_MODULE(edopro_ygoenv, m) { 8 | REGISTER(m, EDOProEnvSpec, EDOProEnvPool) 9 | 10 | m.def("init_module", &edopro::init_module); 11 | } 12 | -------------------------------------------------------------------------------- /ygoenv/ygoenv/edopro/registration.py: -------------------------------------------------------------------------------- 1 | from ygoenv.registration import register 2 | 3 | register( 4 | task_id="EDOPro-v0", 5 | import_path="ygoenv.edopro", 6 | spec_cls="EDOProEnvSpec", 7 | dm_cls="EDOProDMEnvPool", 8 | gym_cls="EDOProGymEnvPool", 9 | gymnasium_cls="EDOProGymnasiumEnvPool", 10 | ) 11 | 
-------------------------------------------------------------------------------- /ygoenv/ygoenv/entry.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 Garena Online Private Limited 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """Entry point for all envs' registration.""" 15 | 16 | try: 17 | import ygoenv.ygopro.registration # noqa: F401 18 | except ImportError: 19 | pass 20 | 21 | try: 22 | import ygoenv.ygopro0.registration # noqa: F401 23 | except ImportError: 24 | pass 25 | 26 | try: 27 | import ygoenv.edopro.registration # noqa: F401 28 | except ImportError: 29 | pass 30 | 31 | try: 32 | import ygoenv.dummy.registration # noqa: F401 33 | except ImportError: 34 | pass 35 | -------------------------------------------------------------------------------- /ygoenv/ygoenv/python/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 Garena Online Private Limited 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """Python interface for EnvPool.""" 15 | -------------------------------------------------------------------------------- /ygoenv/ygoenv/python/api.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 Garena Online Private Limited 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | """Api wrapper layer for EnvPool.""" 15 | 16 | from typing import Tuple, Type 17 | 18 | from .dm_envpool import DMEnvPoolMeta 19 | from .env_spec import EnvSpecMeta 20 | from .gym_envpool import GymEnvPoolMeta 21 | from .gymnasium_envpool import GymnasiumEnvPoolMeta 22 | from .protocol import EnvPool, EnvSpec 23 | 24 | 25 | def py_env( 26 | envspec: Type[EnvSpec], envpool: Type[EnvPool] 27 | ) -> Tuple[Type[EnvSpec], Type[EnvPool], Type[EnvPool], Type[EnvPool]]: 28 | """Initialize EnvPool for users.""" 29 | # remove the _ prefix added when registering cpp class via pybind 30 | spec_name = envspec.__name__[1:] 31 | pool_name = envpool.__name__[1:] 32 | return ( 33 | EnvSpecMeta(spec_name, (envspec,), {}), # type: ignore[return-value] 34 | DMEnvPoolMeta(pool_name.replace("EnvPool", "DMEnvPool"), (envpool,), {}), 35 | GymEnvPoolMeta(pool_name.replace("EnvPool", "GymEnvPool"), (envpool,), {}), 36 | GymnasiumEnvPoolMeta( 37 | pool_name.replace("EnvPool", "GymnasiumEnvPool"), (envpool,), {} 38 | ), 39 | ) 40 | -------------------------------------------------------------------------------- /ygoenv/ygoenv/python/data.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 Garena Online Private Limited 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | """Helper function for data convertion.""" 15 | 16 | from collections import namedtuple 17 | from typing import Any, Dict, List, Tuple, Type 18 | 19 | import dm_env 20 | import gym 21 | import gymnasium 22 | import numpy as np 23 | import optree 24 | from optree import PyTreeSpec 25 | 26 | from .protocol import ArraySpec 27 | 28 | ACTION_THRESHOLD = 2**20 29 | 30 | 31 | def to_nested_dict(flatten_dict: Dict[str, Any], 32 | generator: Type = dict) -> Dict[str, Any]: 33 | """Convert a flat dict to a hierarchical dict. 34 | 35 | The input dict's hierarchy is denoted by ``.``. 36 | 37 | Example: 38 | :: 39 | 40 | >>> to_nested_dict({"a.b": 2333, "a.c": 666}) 41 | {"a": {"b": 2333, "c": 666}} 42 | 43 | Args: 44 | flatten_dict: a dict whose keys list contains ``.`` for hierarchical 45 | representation. 46 | generator: a type of mapping. Default to ``dict``. 47 | """ 48 | ret: Dict[str, Any] = generator() 49 | for k, v in flatten_dict.items(): 50 | segments = k.split(".") 51 | ptr = ret 52 | for s in segments[:-1]: 53 | if s not in ptr: 54 | ptr[s] = generator() 55 | ptr = ptr[s] 56 | ptr[segments[-1]] = v 57 | return ret 58 | 59 | 60 | def to_namedtuple(name: str, hdict: Dict) -> Tuple: 61 | """Convert a hierarchical dict to namedtuple.""" 62 | return namedtuple(name, hdict.keys())( 63 | *[ 64 | to_namedtuple(k, v) if isinstance(v, Dict) else v 65 | for k, v in hdict.items() 66 | ] 67 | ) 68 | 69 | 70 | def dm_spec_transform( 71 | name: str, spec: ArraySpec, spec_type: str 72 | ) -> dm_env.specs.Array: 73 | """Transform ArraySpec to dm_env compatible specs.""" 74 | if np.prod(np.abs(spec.shape)) == 1 and \ 75 | np.isclose(spec.minimum, 0) and spec.maximum < ACTION_THRESHOLD: 76 | # special treatment for discrete action space 77 | return dm_env.specs.DiscreteArray( 78 | name=name, 79 | dtype=spec.dtype, 80 | num_values=int(spec.maximum - spec.minimum + 1), 81 | ) 82 | return dm_env.specs.BoundedArray( 83 | name=name, 84 | shape=[s for s in spec.shape if s != -1], 85 
| dtype=spec.dtype, 86 | minimum=spec.minimum, 87 | maximum=spec.maximum, 88 | ) 89 | 90 | 91 | def gym_spec_transform(name: str, spec: ArraySpec, spec_type: str) -> gym.Space: 92 | """Transform ArraySpec to gym.Env compatible spaces.""" 93 | if np.prod(np.abs(spec.shape)) == 1 and \ 94 | np.isclose(spec.minimum, 0) and spec.maximum < ACTION_THRESHOLD: 95 | # special treatment for discrete action space 96 | discrete_range = int(spec.maximum - spec.minimum + 1) 97 | try: 98 | return gym.spaces.Discrete(n=discrete_range, start=int(spec.minimum)) 99 | except TypeError: # old gym version doesn't have `start` 100 | return gym.spaces.Discrete(n=discrete_range) 101 | return gym.spaces.Box( 102 | shape=[s for s in spec.shape if s != -1], 103 | dtype=spec.dtype, 104 | low=spec.minimum, 105 | high=spec.maximum, 106 | ) 107 | 108 | 109 | def gymnasium_spec_transform( 110 | name: str, spec: ArraySpec, spec_type: str 111 | ) -> gymnasium.Space: 112 | """Transform ArraySpec to gymnasium.Env compatible spaces.""" 113 | if np.prod(np.abs(spec.shape)) == 1 and \ 114 | np.isclose(spec.minimum, 0) and spec.maximum < ACTION_THRESHOLD: 115 | # special treatment for discrete action space 116 | discrete_range = int(spec.maximum - spec.minimum + 1) 117 | return gymnasium.spaces.Discrete(n=discrete_range, start=int(spec.minimum)) 118 | return gymnasium.spaces.Box( 119 | shape=[s for s in spec.shape if s != -1], 120 | dtype=spec.dtype, 121 | low=spec.minimum, 122 | high=spec.maximum, 123 | ) 124 | 125 | 126 | def dm_structure( 127 | root_name: str, 128 | keys: List[str], 129 | ) -> Tuple[List[Tuple[int, ...]], List[int], PyTreeSpec]: 130 | """Convert flat keys into tree structure for namedtuple construction.""" 131 | new_keys = [] 132 | for key in keys: 133 | if key in ["obs", "info"]: # special treatment for single-node obs/info 134 | key = f"obs:{key}" 135 | key = key.replace("info:", "obs:") # merge obs and info together 136 | key = key.replace("obs:", f"{root_name}:") # compatible with 
to_namedtuple 137 | new_keys.append(key.replace(":", ".")) 138 | dict_tree = to_nested_dict(dict(zip(new_keys, list(range(len(new_keys)))))) 139 | structure = to_namedtuple(root_name, dict_tree) 140 | paths, indices, treespec = optree.tree_flatten_with_path(structure) 141 | return paths, indices, treespec 142 | 143 | 144 | def gym_structure( 145 | keys: List[str] 146 | ) -> Tuple[List[Tuple[str, ...]], List[int], PyTreeSpec]: 147 | """Convert flat keys into tree structure for dict construction.""" 148 | keys = [k.replace(":", ".") for k in keys] 149 | dict_tree = to_nested_dict(dict(zip(keys, list(range(len(keys)))))) 150 | paths, indices, treespec = optree.tree_flatten_with_path(dict_tree) 151 | return paths, indices, treespec 152 | 153 | 154 | gymnasium_structure = gym_structure 155 | -------------------------------------------------------------------------------- /ygoenv/ygoenv/python/dm_envpool.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 Garena Online Private Limited 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | """EnvPool meta class for dm_env API.""" 15 | 16 | from abc import ABC, ABCMeta 17 | from typing import Any, Dict, List, Tuple, Union 18 | 19 | import dm_env 20 | import numpy as np 21 | import optree 22 | from dm_env import TimeStep 23 | 24 | from .data import dm_structure 25 | from .envpool import EnvPoolMixin 26 | from .utils import check_key_duplication 27 | 28 | 29 | class DMEnvPoolMixin(ABC): 30 | """Special treatment for dm_env API.""" 31 | 32 | def observation_spec(self: Any) -> Tuple: 33 | """Observation spec from EnvSpec.""" 34 | if not hasattr(self, "_dm_observation_spec"): 35 | self._dm_observation_spec = self.spec.observation_spec() 36 | return self._dm_observation_spec 37 | 38 | def action_spec(self: Any) -> Union[dm_env.specs.Array, Tuple]: 39 | """Action spec from EnvSpec.""" 40 | if not hasattr(self, "_dm_action_spec"): 41 | self._dm_action_spec = self.spec.action_spec() 42 | return self._dm_action_spec 43 | 44 | 45 | class DMEnvPoolMeta(ABCMeta): 46 | """Additional wrapper for EnvPool dm_env API.""" 47 | 48 | def __new__(cls: Any, name: str, parents: Tuple, attrs: Dict) -> Any: 49 | """Check internal config and initialize data format convertion.""" 50 | base = parents[0] 51 | try: 52 | from .lax import XlaMixin 53 | 54 | parents = ( 55 | base, DMEnvPoolMixin, EnvPoolMixin, XlaMixin, dm_env.Environment 56 | ) 57 | except ImportError: 58 | 59 | def _xla(self: Any) -> None: 60 | raise RuntimeError("XLA is disabled. 
To enable XLA please install jax.") 61 | 62 | attrs["xla"] = _xla 63 | parents = (base, DMEnvPoolMixin, EnvPoolMixin, dm_env.Environment) 64 | 65 | state_keys = base._state_keys 66 | action_keys = base._action_keys 67 | check_key_duplication(name, "state", state_keys) 68 | check_key_duplication(name, "action", action_keys) 69 | 70 | state_paths, state_idx, treepsec = dm_structure("State", state_keys) 71 | 72 | def _to_dm( 73 | self: Any, 74 | state_values: List[np.ndarray], 75 | reset: bool, 76 | return_info: bool, 77 | ) -> TimeStep: 78 | values = (state_values[i] for i in state_idx) 79 | state = optree.tree_unflatten(treepsec, values) 80 | timestep = TimeStep( 81 | step_type=state.step_type, 82 | observation=state.State, 83 | reward=state.reward, 84 | discount=state.discount, 85 | ) 86 | return timestep 87 | 88 | attrs["_to"] = _to_dm 89 | subcls = super().__new__(cls, name, parents, attrs) 90 | 91 | def init(self: Any, spec: Any) -> None: 92 | """Set self.spec to EnvSpecMeta.""" 93 | super(subcls, self).__init__(spec) 94 | self.spec = spec 95 | 96 | setattr(subcls, "__init__", init) # noqa: B010 97 | return subcls 98 | -------------------------------------------------------------------------------- /ygoenv/ygoenv/python/envpool.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 Garena Online Private Limited 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | """EnvPool Mixin class for meta class definition.""" 15 | 16 | import pprint 17 | import warnings 18 | from abc import ABC 19 | from typing import Any, Dict, List, Optional, Tuple, Union 20 | 21 | import numpy as np 22 | import optree 23 | from dm_env import TimeStep 24 | 25 | from .protocol import EnvPool, EnvSpec 26 | 27 | 28 | class EnvPoolMixin(ABC): 29 | """Mixin class for EnvPool, exposed to EnvPoolMeta.""" 30 | 31 | _spec: EnvSpec 32 | 33 | def _check_action(self: EnvPool, actions: List[np.ndarray]) -> None: 34 | if hasattr(self, "_check_action_finished"): # only check once 35 | return 36 | self._check_action_finished = True 37 | for a, (k, v) in zip(actions, self.spec.action_array_spec.items()): 38 | if v.dtype != a.dtype: 39 | raise RuntimeError( 40 | f"Expected dtype {v.dtype} with action \"{k}\", got {a.dtype}" 41 | ) 42 | shape = tuple(v.shape) 43 | if len(shape) > 0 and shape[0] == -1: 44 | if a.shape[1:] != shape[1:]: 45 | raise RuntimeError( 46 | f"Expected shape {shape} with action \"{k}\", got {a.shape}" 47 | ) 48 | else: 49 | if len(a.shape) == 0 or a.shape[1:] != shape: 50 | raise RuntimeError( 51 | f"Expected shape {('num_env', *shape)} with action \"{k}\", " 52 | f"got {a.shape}" 53 | ) 54 | 55 | def _from( 56 | self: EnvPool, 57 | action: Union[Dict[str, Any], np.ndarray], 58 | env_id: Optional[np.ndarray] = None, 59 | ) -> List[np.ndarray]: 60 | """Convert action to C++-acceptable format.""" 61 | if isinstance(action, dict): 62 | paths, values, _ = optree.tree_flatten_with_path(action) 63 | adict = {'.'.join(p): v for p, v in zip(paths, values)} 64 | else: # only 3 keys in action_keys 65 | if not hasattr(self, "_last_action_type"): 66 | self._last_action_type = self._spec._action_spec[-1][0] 67 | if not hasattr(self, "_last_action_name"): 68 | self._last_action_name = self._spec._action_keys[-1] 69 | if isinstance(action, np.ndarray): 70 | # else it could be a jax array, when using xla 71 | action = action.astype( 72 | 
self._last_action_type, # type: ignore 73 | order='C', 74 | ) 75 | adict = {self._last_action_name: action} # type: ignore 76 | if env_id is None: 77 | if "env_id" not in adict: 78 | adict["env_id"] = self.all_env_ids 79 | else: 80 | adict["env_id"] = env_id.astype(np.int32) 81 | if "players.env_id" not in adict: 82 | adict["players.env_id"] = adict["env_id"] 83 | if not hasattr(self, "_action_names"): 84 | self._action_names = self._spec._action_keys 85 | return list(map(lambda k: adict[k], self._action_names)) # type: ignore 86 | 87 | def __len__(self: EnvPool) -> int: 88 | """Return the number of environments.""" 89 | return self.config["num_envs"] 90 | 91 | @property 92 | def all_env_ids(self: EnvPool) -> np.ndarray: 93 | """All env_id in numpy ndarray with dtype=np.int32.""" 94 | if not hasattr(self, "_all_env_ids"): 95 | self._all_env_ids = np.arange(self.config["num_envs"], dtype=np.int32) 96 | return self._all_env_ids # type: ignore 97 | 98 | @property 99 | def is_async(self: EnvPool) -> bool: 100 | """Return if this env is in sync mode or async mode.""" 101 | return self.config["batch_size"] > 0 and self.config[ 102 | "num_envs"] != self.config["batch_size"] 103 | 104 | def seed(self: EnvPool, seed: Optional[Union[int, List[int]]] = None) -> None: 105 | """Set the seed for all environments (abandoned).""" 106 | warnings.warn( 107 | "The `seed` function in envpool is abandoned. 
" 108 | "You can set seed by envpool.make(..., seed=seed) instead.", 109 | stacklevel=2 110 | ) 111 | 112 | def send( 113 | self: EnvPool, 114 | action: Union[Dict[str, Any], np.ndarray], 115 | env_id: Optional[np.ndarray] = None, 116 | ) -> None: 117 | """Send actions into EnvPool.""" 118 | action = self._from(action, env_id) 119 | self._check_action(action) 120 | self._send(action) 121 | 122 | def recv( 123 | self: EnvPool, 124 | reset: bool = False, 125 | return_info: bool = True, 126 | ) -> Union[TimeStep, Tuple]: 127 | """Recv a batch state from EnvPool.""" 128 | state_list = self._recv() 129 | return self._to(state_list, reset, return_info) 130 | 131 | def async_reset(self: EnvPool) -> None: 132 | """Follows the async semantics, reset the envs in env_ids.""" 133 | self._reset(self.all_env_ids) 134 | 135 | def step( 136 | self: EnvPool, 137 | action: Union[Dict[str, Any], np.ndarray], 138 | env_id: Optional[np.ndarray] = None, 139 | ) -> Union[TimeStep, Tuple]: 140 | """Perform one step with multiple environments in EnvPool.""" 141 | self.send(action, env_id) 142 | return self.recv(reset=False, return_info=True) 143 | 144 | def reset( 145 | self: EnvPool, 146 | env_id: Optional[np.ndarray] = None, 147 | ) -> Union[TimeStep, Tuple]: 148 | """Reset envs in env_id. 149 | 150 | This behavior is not defined in async mode. 
151 | """ 152 | if env_id is None: 153 | env_id = self.all_env_ids 154 | self._reset(env_id) 155 | return self.recv( 156 | reset=True, return_info=self.config["gym_reset_return_info"] 157 | ) 158 | 159 | @property 160 | def config(self: EnvPool) -> Dict[str, Any]: 161 | """Config dict of this class.""" 162 | return dict(zip(self._spec._config_keys, self._spec._config_values)) 163 | 164 | def __repr__(self: EnvPool) -> str: 165 | """Prettify the debug information.""" 166 | config = self.config 167 | config_str = ", ".join( 168 | [f"{k}={pprint.pformat(v)}" for k, v in config.items()] 169 | ) 170 | return f"{self.__class__.__name__}({config_str})" 171 | 172 | def __str__(self: EnvPool) -> str: 173 | """Prettify the debug information.""" 174 | return self.__repr__() 175 | -------------------------------------------------------------------------------- /ygoenv/ygoenv/python/gym_envpool.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 Garena Online Private Limited 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | """EnvPool meta class for gym.Env API.""" 15 | 16 | from abc import ABC, ABCMeta 17 | from typing import Any, Dict, List, Tuple, Union 18 | 19 | import gym 20 | import numpy as np 21 | import optree 22 | from packaging import version 23 | 24 | from .data import gym_structure 25 | from .envpool import EnvPoolMixin 26 | from .utils import check_key_duplication 27 | 28 | 29 | class GymEnvPoolMixin(ABC): 30 | """Special treatment for gym API.""" 31 | 32 | @property 33 | def observation_space(self: Any) -> Union[gym.Space, Dict[str, Any]]: 34 | """Observation space from EnvSpec.""" 35 | if not hasattr(self, "_gym_observation_space"): 36 | self._gym_observation_space = self.spec.observation_space 37 | return self._gym_observation_space 38 | 39 | @property 40 | def action_space(self: Any) -> Union[gym.Space, Dict[str, Any]]: 41 | """Action space from EnvSpec.""" 42 | if not hasattr(self, "_gym_action_space"): 43 | self._gym_action_space = self.spec.action_space 44 | return self._gym_action_space 45 | 46 | 47 | class GymEnvPoolMeta(ABCMeta, gym.Env.__class__): 48 | """Additional wrapper for EnvPool gym.Env API.""" 49 | 50 | def __new__(cls: Any, name: str, parents: Tuple, attrs: Dict) -> Any: 51 | """Check internal config and initialize data format convertion.""" 52 | base = parents[0] 53 | try: 54 | from .lax import XlaMixin 55 | 56 | parents = (base, GymEnvPoolMixin, EnvPoolMixin, XlaMixin, gym.Env) 57 | except ImportError: 58 | 59 | def _xla(self: Any) -> None: 60 | raise RuntimeError("XLA is disabled. 
To enable XLA please install jax.") 61 | 62 | attrs["xla"] = _xla 63 | parents = (base, GymEnvPoolMixin, EnvPoolMixin, gym.Env) 64 | 65 | state_keys = base._state_keys 66 | action_keys = base._action_keys 67 | check_key_duplication(name, "state", state_keys) 68 | check_key_duplication(name, "action", action_keys) 69 | 70 | state_paths, state_idx, treepsec = gym_structure(state_keys) 71 | 72 | new_gym_api = version.parse(gym.__version__) >= version.parse("0.26.0") 73 | 74 | def _to_gym( 75 | self: Any, state_values: List[np.ndarray], reset: bool, return_info: bool 76 | ) -> Union[ 77 | Any, 78 | Tuple[Any, Any], 79 | Tuple[Any, np.ndarray, np.ndarray, Any], 80 | Tuple[Any, np.ndarray, np.ndarray, np.ndarray, Any], 81 | ]: 82 | values = (state_values[i] for i in state_idx) 83 | state = optree.tree_unflatten(treepsec, values) 84 | if reset and not (return_info or new_gym_api): 85 | return state["obs"] 86 | info = state["info"] 87 | if not new_gym_api: 88 | info["TimeLimit.truncated"] = state["trunc"] 89 | info["elapsed_step"] = state["elapsed_step"] 90 | if reset: 91 | return state["obs"], info 92 | if new_gym_api: 93 | terminated = state["done"] & ~state["trunc"] 94 | return state["obs"], state["reward"], terminated, state["trunc"], info 95 | return state["obs"], state["reward"], state["done"], info 96 | 97 | attrs["_to"] = _to_gym 98 | subcls = super().__new__(cls, name, parents, attrs) 99 | 100 | def init(self: Any, spec: Any) -> None: 101 | """Set self.spec to EnvSpecMeta.""" 102 | super(subcls, self).__init__(spec) 103 | self.spec = spec 104 | 105 | setattr(subcls, "__init__", init) # noqa: B010 106 | return subcls 107 | -------------------------------------------------------------------------------- /ygoenv/ygoenv/python/gymnasium_envpool.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 Garena Online Private Limited 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may 
not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """EnvPool meta class for gymnasium.Env API.""" 15 | 16 | from abc import ABC, ABCMeta 17 | from typing import Any, Dict, List, Tuple, Union 18 | 19 | import gymnasium 20 | import numpy as np 21 | import optree 22 | 23 | from .data import gymnasium_structure 24 | from .envpool import EnvPoolMixin 25 | from .utils import check_key_duplication 26 | 27 | 28 | class GymnasiumEnvPoolMixin(ABC): 29 | """Special treatment for gymnasim API.""" 30 | 31 | @property 32 | def observation_space(self: Any) -> Union[gymnasium.Space, Dict[str, Any]]: 33 | """Observation space from EnvSpec.""" 34 | if not hasattr(self, "_gym_observation_space"): 35 | self._gym_observation_space = self.spec.gymnasium_observation_space 36 | return self._gym_observation_space 37 | 38 | @property 39 | def action_space(self: Any) -> Union[gymnasium.Space, Dict[str, Any]]: 40 | """Action space from EnvSpec.""" 41 | if not hasattr(self, "_gym_action_space"): 42 | self._gym_action_space = self.spec.gymnasium_action_space 43 | return self._gym_action_space 44 | 45 | 46 | class GymnasiumEnvPoolMeta(ABCMeta, gymnasium.Env.__class__): 47 | """Additional wrapper for EnvPool gymnasium.Env API.""" 48 | 49 | def __new__(cls: Any, name: str, parents: Tuple, attrs: Dict) -> Any: 50 | """Check internal config and initialize data format convertion.""" 51 | base = parents[0] 52 | try: 53 | from .lax import XlaMixin 54 | 55 | parents = ( 56 | base, GymnasiumEnvPoolMixin, EnvPoolMixin, XlaMixin, gymnasium.Env 57 | ) 58 | 
except ImportError: 59 | 60 | def _xla(self: Any) -> None: 61 | raise RuntimeError("XLA is disabled. To enable XLA please install jax.") 62 | 63 | attrs["xla"] = _xla 64 | parents = (base, GymnasiumEnvPoolMixin, EnvPoolMixin, gymnasium.Env) 65 | 66 | state_keys = base._state_keys 67 | action_keys = base._action_keys 68 | check_key_duplication(name, "state", state_keys) 69 | check_key_duplication(name, "action", action_keys) 70 | 71 | state_paths, state_idx, treepsec = gymnasium_structure(state_keys) 72 | 73 | def _to_gymnasium( 74 | self: Any, state_values: List[np.ndarray], reset: bool, return_info: bool 75 | ) -> Union[ 76 | Any, 77 | Tuple[Any, Any], 78 | Tuple[Any, np.ndarray, np.ndarray, Any], 79 | Tuple[Any, np.ndarray, np.ndarray, np.ndarray, Any], 80 | ]: 81 | values = (state_values[i] for i in state_idx) 82 | state = optree.tree_unflatten(treepsec, values) 83 | info = state["info"] 84 | info["elapsed_step"] = state["elapsed_step"] 85 | if reset: 86 | return state["obs"], info 87 | terminated = state["done"] & ~state["trunc"] 88 | return state["obs"], state["reward"], terminated, state["trunc"], info 89 | 90 | attrs["_to"] = _to_gymnasium 91 | subcls = super().__new__(cls, name, parents, attrs) 92 | 93 | def init(self: Any, spec: Any) -> None: 94 | """Set self.spec to EnvSpecMeta.""" 95 | super(subcls, self).__init__(spec) 96 | self.spec = spec 97 | 98 | setattr(subcls, "__init__", init) # noqa: B010 99 | return subcls 100 | -------------------------------------------------------------------------------- /ygoenv/ygoenv/python/protocol.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 Garena Online Private Limited 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """Protocol of C++ EnvPool.""" 15 | 16 | from typing import ( 17 | Any, 18 | Callable, 19 | Dict, 20 | List, 21 | NamedTuple, 22 | Optional, 23 | Tuple, 24 | Type, 25 | Union, 26 | ) 27 | 28 | import dm_env 29 | import gym 30 | import numpy as np 31 | from dm_env import TimeStep 32 | 33 | try: 34 | from typing import Protocol 35 | except ImportError: 36 | from typing_extensions import Protocol # type: ignore 37 | 38 | 39 | class EnvSpec(Protocol): 40 | """Cpp EnvSpec class.""" 41 | 42 | _config_keys: List[str] 43 | _default_config_values: Tuple 44 | gen_config: Type 45 | 46 | def __init__(self, config: Tuple): 47 | """Protocol for constructor of EnvSpec.""" 48 | 49 | @property 50 | def _state_spec(self) -> Tuple: 51 | """Cpp private _state_spec.""" 52 | 53 | @property 54 | def _action_spec(self) -> Tuple: 55 | """Cpp private _action_spec.""" 56 | 57 | @property 58 | def _state_keys(self) -> List: 59 | """Cpp private _state_keys.""" 60 | 61 | @property 62 | def _action_keys(self) -> List: 63 | """Cpp private _action_keys.""" 64 | 65 | @property 66 | def _config_values(self) -> Tuple: 67 | """Cpp private _config_values.""" 68 | 69 | @property 70 | def config(self) -> NamedTuple: 71 | """Configuration used to create the current EnvSpec.""" 72 | 73 | @property 74 | def state_array_spec(self) -> Dict[str, Any]: 75 | """Specs of the states of the environment in ArraySpec format.""" 76 | 77 | @property 78 | def action_array_spec(self) -> Dict[str, Any]: 79 | """Specs of the actions of the environment in ArraySpec format.""" 80 | 
81 | def observation_spec(self) -> Dict[str, Any]: 82 | """Specs of the observations of the environment in dm_env format.""" 83 | 84 | def action_spec(self) -> Union[dm_env.specs.Array, Dict[str, Any]]: 85 | """Specs of the actions of the environment in dm_env format.""" 86 | 87 | @property 88 | def observation_space(self) -> Dict[str, Any]: 89 | """Specs of the observations of the environment in gym.Env format.""" 90 | 91 | @property 92 | def action_space(self) -> Union[gym.Space, Dict[str, Any]]: 93 | """Specs of the actions of the environment in gym.Env format.""" 94 | 95 | @property 96 | def reward_threshold(self) -> Optional[float]: 97 | """Reward threshold, None for no threshold.""" 98 | 99 | 100 | class ArraySpec(object): 101 | """Spec of numpy array.""" 102 | 103 | def __init__( 104 | self, dtype: Type, shape: List[int], bounds: Tuple[Any, Any], 105 | element_wise_bounds: Tuple[Any, Any] 106 | ): 107 | """Constructor of ArraySpec.""" 108 | self.dtype = dtype 109 | self.shape = shape 110 | if element_wise_bounds[0]: 111 | self.minimum = np.array(element_wise_bounds[0]) 112 | else: 113 | self.minimum = bounds[0] 114 | if element_wise_bounds[1]: 115 | self.maximum = np.array(element_wise_bounds[1]) 116 | else: 117 | self.maximum = bounds[1] 118 | 119 | def __repr__(self) -> str: 120 | """Beautify debug info.""" 121 | return ( 122 | f"ArraySpec(shape={self.shape}, dtype={self.dtype}, " 123 | f"minimum={self.minimum}, maximum={self.maximum})" 124 | ) 125 | 126 | 127 | class EnvPool(Protocol): 128 | """Cpp PyEnvpool class interface.""" 129 | 130 | _state_keys: List[str] 131 | _action_keys: List[str] 132 | spec: Any 133 | 134 | def __init__(self, spec: EnvSpec): 135 | """Constructor of EnvPool.""" 136 | 137 | def __len__(self) -> int: 138 | """Return the number of environments.""" 139 | 140 | @property 141 | def _spec(self) -> EnvSpec: 142 | """Cpp env spec.""" 143 | 144 | @property 145 | def _action_spec(self) -> List: 146 | """Cpp action spec.""" 147 | 148 | def 
_check_action(self, actions: List) -> None: 149 | """Check action shapes.""" 150 | 151 | def _recv(self) -> List[np.ndarray]: 152 | """Cpp private _recv method.""" 153 | 154 | def _send(self, action: List[np.ndarray]) -> None: 155 | """Cpp private _send method.""" 156 | 157 | def _reset(self, env_id: np.ndarray) -> None: 158 | """Cpp private _reset method.""" 159 | 160 | def _from( 161 | self, 162 | action: Union[Dict[str, Any], np.ndarray], 163 | env_id: Optional[np.ndarray] = None, 164 | ) -> List[np.ndarray]: 165 | """Convertion for input action.""" 166 | 167 | def _to( 168 | self, 169 | state: List[np.ndarray], 170 | reset: bool, 171 | return_info: bool, 172 | ) -> Union[TimeStep, Tuple]: 173 | """A switch of to_dm and to_gym for output state.""" 174 | 175 | @property 176 | def all_env_ids(self) -> np.ndarray: 177 | """All env_id in numpy ndarray with dtype=np.int32.""" 178 | 179 | @property 180 | def is_async(self) -> bool: 181 | """Return if this env is in sync mode or async mode.""" 182 | 183 | @property 184 | def observation_space(self) -> Union[gym.Space, Dict[str, Any]]: 185 | """Gym observation space.""" 186 | 187 | @property 188 | def action_space(self) -> Union[gym.Space, Dict[str, Any]]: 189 | """Gym action space.""" 190 | 191 | def observation_spec(self) -> Tuple: 192 | """Dm observation spec.""" 193 | 194 | def action_spec(self) -> Union[dm_env.specs.Array, Tuple]: 195 | """Dm action spec.""" 196 | 197 | def seed(self, seed: Optional[Union[int, List[int]]] = None) -> None: 198 | """Set the seed for all environments.""" 199 | 200 | @property 201 | def config(self) -> Dict[str, Any]: 202 | """Envpool config.""" 203 | 204 | def send( 205 | self, 206 | action: Union[Dict[str, Any], np.ndarray], 207 | env_id: Optional[np.ndarray] = None, 208 | ) -> None: 209 | """Envpool send wrapper.""" 210 | 211 | def recv( 212 | self, 213 | reset: bool = False, 214 | return_info: bool = True, 215 | ) -> Union[TimeStep, Tuple]: 216 | """Envpool recv wrapper.""" 217 | 
218 | def async_reset(self) -> None: 219 | """Envpool async reset interface.""" 220 | 221 | def step( 222 | self, 223 | action: Union[Dict[str, Any], np.ndarray], 224 | env_id: Optional[np.ndarray] = None, 225 | ) -> Union[TimeStep, Tuple]: 226 | """Envpool step interface that performs send/recv.""" 227 | 228 | def reset( 229 | self, 230 | env_id: Optional[np.ndarray] = None, 231 | ) -> Union[TimeStep, Tuple]: 232 | """Envpool reset interface.""" 233 | -------------------------------------------------------------------------------- /ygoenv/ygoenv/python/utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 Garena Online Private Limited 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """Helper function for Python API.""" 15 | 16 | from typing import Any, List 17 | 18 | import numpy as np 19 | 20 | 21 | def check_key_duplication(cls: Any, keytype: str, keys: List[str]) -> None: 22 | """Check if there's any duplicated keys in ``keys``.""" 23 | ukeys, counts = np.unique(keys, return_counts=True) 24 | if not np.all(counts == 1): 25 | dup_keys = ukeys[counts > 1] 26 | raise SystemError( 27 | f"{cls} c++ code error. {keytype} keys {list(dup_keys)} are duplicated. " 28 | f"Please report to the author of {cls}." 
29 | ) 30 | -------------------------------------------------------------------------------- /ygoenv/ygoenv/registration.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 Garena Online Private Limited 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """Global env registry.""" 15 | 16 | import importlib 17 | import os 18 | from typing import Any, Dict, List, Tuple 19 | 20 | import gym 21 | from packaging import version 22 | 23 | base_path = os.path.abspath(os.path.dirname(__file__)) 24 | 25 | 26 | class EnvRegistry: 27 | """A collection of available envs.""" 28 | 29 | def __init__(self) -> None: 30 | """Constructor of EnvRegistry.""" 31 | self.specs: Dict[str, Tuple[str, str, Dict[str, Any]]] = {} 32 | self.envpools: Dict[str, Dict[str, Tuple[str, str]]] = {} 33 | 34 | def register( 35 | self, task_id: str, import_path: str, spec_cls: str, dm_cls: str, 36 | gym_cls: str, gymnasium_cls: str, **kwargs: Any 37 | ) -> None: 38 | """Register EnvSpec and EnvPool in global EnvRegistry.""" 39 | assert task_id not in self.specs 40 | if "base_path" not in kwargs: 41 | kwargs["base_path"] = base_path 42 | self.specs[task_id] = (import_path, spec_cls, kwargs) 43 | self.envpools[task_id] = { 44 | "dm": (import_path, dm_cls), 45 | "gym": (import_path, gym_cls), 46 | "gymnasium": (import_path, gymnasium_cls) 47 | } 48 | 49 | def make(self, task_id: str, env_type: str, **kwargs: Any) -> 
Any: 50 | """Make envpool.""" 51 | new_gym_api = version.parse(gym.__version__) >= version.parse("0.26.0") 52 | if "gym_reset_return_info" not in kwargs: 53 | kwargs["gym_reset_return_info"] = new_gym_api 54 | if new_gym_api and not kwargs["gym_reset_return_info"]: 55 | raise ValueError( 56 | "You are using gym>=0.26.0 but passed `gym_reset_return_info=False`. " 57 | "The new gym API requires environments to return an info dictionary " 58 | "after resets." 59 | ) 60 | 61 | assert task_id in self.specs, \ 62 | f"{task_id} is not supported, `envpool.list_all_envs()` may help." 63 | assert env_type in ["dm", "gym", "gymnasium"] 64 | 65 | spec = self.make_spec(task_id, **kwargs) 66 | import_path, envpool_cls = self.envpools[task_id][env_type] 67 | return getattr(importlib.import_module(import_path), envpool_cls)(spec) 68 | 69 | def make_dm(self, task_id: str, **kwargs: Any) -> Any: 70 | """Make dm_env compatible envpool.""" 71 | return self.make(task_id, "dm", **kwargs) 72 | 73 | def make_gym(self, task_id: str, **kwargs: Any) -> Any: 74 | """Make gym.Env compatible envpool.""" 75 | return self.make(task_id, "gym", **kwargs) 76 | 77 | def make_gymnasium(self, task_id: str, **kwargs: Any) -> Any: 78 | """Make gymnasium.Env compatible envpool.""" 79 | return self.make(task_id, "gymnasium", **kwargs) 80 | 81 | def make_spec(self, task_id: str, **make_kwargs: Any) -> Any: 82 | """Make EnvSpec.""" 83 | import_path, spec_cls, kwargs = self.specs[task_id] 84 | kwargs = {**kwargs, **make_kwargs} 85 | 86 | # check arguments 87 | if "seed" in kwargs: # Issue 214 88 | INT_MAX = 2**31 89 | assert -INT_MAX <= kwargs["seed"] < INT_MAX, \ 90 | f"Seed should be in range of int32, got {kwargs['seed']}" 91 | if "num_envs" in kwargs: 92 | assert kwargs["num_envs"] >= 1 93 | if "batch_size" in kwargs: 94 | assert 0 <= kwargs["batch_size"] <= kwargs["num_envs"] 95 | if "max_num_players" in kwargs: 96 | assert 1 <= kwargs["max_num_players"] 97 | 98 | spec_cls = 
getattr(importlib.import_module(import_path), spec_cls) 99 | config = spec_cls.gen_config(**kwargs) 100 | return spec_cls(config) 101 | 102 | def list_all_envs(self) -> List[str]: 103 | """Return all available task_id.""" 104 | return list(self.specs.keys()) 105 | 106 | 107 | # use a global EnvRegistry 108 | registry = EnvRegistry() 109 | register = registry.register 110 | make = registry.make 111 | make_dm = registry.make_dm 112 | make_gym = registry.make_gym 113 | make_gymnasium = registry.make_gymnasium 114 | make_spec = registry.make_spec 115 | list_all_envs = registry.list_all_envs 116 | -------------------------------------------------------------------------------- /ygoenv/ygoenv/ygopro/__init__.py: -------------------------------------------------------------------------------- 1 | from ygoenv.python.api import py_env 2 | 3 | from .ygopro_ygoenv import ( 4 | _YGOProEnvPool, 5 | _YGOProEnvSpec, 6 | init_module, 7 | ) 8 | 9 | ( 10 | YGOProEnvSpec, 11 | YGOProDMEnvPool, 12 | YGOProGymEnvPool, 13 | YGOProGymnasiumEnvPool, 14 | ) = py_env(_YGOProEnvSpec, _YGOProEnvPool) 15 | 16 | 17 | __all__ = [ 18 | "YGOProEnvSpec", 19 | "YGOProDMEnvPool", 20 | "YGOProGymEnvPool", 21 | "YGOProGymnasiumEnvPool", 22 | ] 23 | -------------------------------------------------------------------------------- /ygoenv/ygoenv/ygopro/registration.py: -------------------------------------------------------------------------------- 1 | from ygoenv.registration import register 2 | 3 | register( 4 | task_id="YGOPro-v1", 5 | import_path="ygoenv.ygopro", 6 | spec_cls="YGOProEnvSpec", 7 | dm_cls="YGOProDMEnvPool", 8 | gym_cls="YGOProGymEnvPool", 9 | gymnasium_cls="YGOProGymnasiumEnvPool", 10 | ) 11 | -------------------------------------------------------------------------------- /ygoenv/ygoenv/ygopro/ygopro.cpp: -------------------------------------------------------------------------------- 1 | #include "ygoenv/ygopro/ygopro.h" 2 | #include "ygoenv/core/py_envpool.h" 3 | 4 | using 
YGOProEnvSpec = PyEnvSpec; 5 | using YGOProEnvPool = PyEnvPool; 6 | 7 | PYBIND11_MODULE(ygopro_ygoenv, m) { 8 | REGISTER(m, YGOProEnvSpec, YGOProEnvPool) 9 | 10 | m.def("init_module", &ygopro::init_module); 11 | } 12 | -------------------------------------------------------------------------------- /ygoenv/ygoenv/ygopro0/__init__.py: -------------------------------------------------------------------------------- 1 | from ygoenv.python.api import py_env 2 | 3 | from .ygopro0_ygoenv import ( 4 | _YGOPro0EnvPool, 5 | _YGOPro0EnvSpec, 6 | init_module, 7 | ) 8 | 9 | ( 10 | YGOPro0EnvSpec, 11 | YGOPro0DMEnvPool, 12 | YGOPro0GymEnvPool, 13 | YGOPro0GymnasiumEnvPool, 14 | ) = py_env(_YGOPro0EnvSpec, _YGOPro0EnvPool) 15 | 16 | 17 | __all__ = [ 18 | "YGOPro0EnvSpec", 19 | "YGOPro0DMEnvPool", 20 | "YGOPro0GymEnvPool", 21 | "YGOPro0GymnasiumEnvPool", 22 | ] 23 | -------------------------------------------------------------------------------- /ygoenv/ygoenv/ygopro0/registration.py: -------------------------------------------------------------------------------- 1 | from ygoenv.registration import register 2 | 3 | register( 4 | task_id="YGOPro-v0", 5 | import_path="ygoenv.ygopro0", 6 | spec_cls="YGOPro0EnvSpec", 7 | dm_cls="YGOPro0DMEnvPool", 8 | gym_cls="YGOPro0GymEnvPool", 9 | gymnasium_cls="YGOPro0GymnasiumEnvPool", 10 | ) 11 | -------------------------------------------------------------------------------- /ygoenv/ygoenv/ygopro0/ygopro.cpp: -------------------------------------------------------------------------------- 1 | #include "ygoenv/ygopro0/ygopro.h" 2 | #include "ygoenv/core/py_envpool.h" 3 | 4 | using YGOPro0EnvSpec = PyEnvSpec; 5 | using YGOPro0EnvPool = PyEnvPool; 6 | 7 | PYBIND11_MODULE(ygopro0_ygoenv, m) { 8 | REGISTER(m, YGOPro0EnvSpec, YGOPro0EnvPool) 9 | 10 | m.def("init_module", &ygopro0::init_module); 11 | } 12 | -------------------------------------------------------------------------------- /ygoinf/setup.py: 
-------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages 2 | 3 | __version__ = "0.0.1" 4 | 5 | INSTALL_REQUIRES = [ 6 | "numpy==1.26.4", 7 | "optree", 8 | "fastapi", 9 | "uvicorn[standard]", 10 | "pydantic_settings", 11 | "tflite-runtime", 12 | ] 13 | 14 | setup( 15 | name="ygoinf", 16 | version=__version__, 17 | packages=find_packages(include='ygoinf*'), 18 | long_description="", 19 | install_requires=INSTALL_REQUIRES, 20 | python_requires=">=3.10", 21 | ) -------------------------------------------------------------------------------- /ygoinf/ygoinf/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sbl1996/ygo-agent/dbf5142d49aab2e6beb4150788d4fffec39ae3e5/ygoinf/ygoinf/__init__.py -------------------------------------------------------------------------------- /ygoinf/ygoinf/jax_inf.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | 4 | import jax 5 | from jax.experimental.compilation_cache import compilation_cache as cc 6 | cc.set_cache_dir(os.path.expanduser("~/.cache/jax")) 7 | 8 | import jax.numpy as jnp 9 | import flax 10 | from ygoai.rl.jax.agent import RNNAgent 11 | 12 | def create_agent(): 13 | return RNNAgent( 14 | num_layers=2, 15 | rnn_channels=512, 16 | use_history=True, 17 | rnn_type='lstm', 18 | num_channels=128, 19 | film=True, 20 | noam=True, 21 | version=2, 22 | ) 23 | 24 | 25 | @jax.jit 26 | def get_probs_and_value(params, rstate, obs): 27 | agent = create_agent() 28 | next_rstate, logits, value = agent.apply(params, obs, rstate)[:3] 29 | probs = jax.nn.softmax(logits, axis=-1) 30 | return next_rstate, probs, value 31 | 32 | 33 | def predict_fn(params, rstate, obs): 34 | obs = jax.tree.map(lambda x: jnp.array([x]), obs) 35 | rstate, probs, value = get_probs_and_value(params, rstate, obs) 36 | return rstate, 
np.array(probs)[0].tolist(), float(np.array(value)[0]) 37 | 38 | def load_model(checkpoint, rstate, sample_obs, **kwargs): 39 | agent = create_agent() 40 | key = jax.random.PRNGKey(0) 41 | key, agent_key = jax.random.split(key, 2) 42 | sample_obs_ = jax.tree.map(lambda x: jnp.array([x]), sample_obs) 43 | params = jax.jit(agent.init)(agent_key, sample_obs_, rstate) 44 | with open(checkpoint, "rb") as f: 45 | params = flax.serialization.from_bytes(params, f.read()) 46 | 47 | params = jax.device_put(params) 48 | return params 49 | -------------------------------------------------------------------------------- /ygoinf/ygoinf/server.py: -------------------------------------------------------------------------------- 1 | import os 2 | os.environ.setdefault("JAX_PLATFORMS", "cpu") 3 | from typing import Union, Dict 4 | 5 | import time 6 | import threading 7 | import uuid 8 | from contextlib import asynccontextmanager 9 | 10 | from fastapi import FastAPI, Path 11 | from fastapi.responses import PlainTextResponse 12 | from fastapi.middleware.cors import CORSMiddleware 13 | from pydantic import Field 14 | from pydantic_settings import BaseSettings 15 | 16 | 17 | from .models import ( 18 | DuelCreateResponse, 19 | DuelPredictRequest, 20 | DuelPredictResponse, 21 | DuelPredictErrorResponse, 22 | ) 23 | from .features import predict, init_code_list, PredictState, Predictor 24 | 25 | 26 | class Settings(BaseSettings): 27 | code_list: str = "code_list.txt" 28 | checkpoint: str = "latest.flax_model" 29 | enable_cors: bool = Field(default=True, description="Enable CORS") 30 | state_expire: int = Field(default=3600, description="Duel state expire time in seconds") 31 | test_duel_id: str = Field(default="9654823a-23fd-4850-bb-6fec241740b0", description="Test duel id") 32 | ygo_num_threads: int = Field(default=1, description="Number of threads to use for YGO prediction") 33 | 34 | settings = Settings() 35 | 36 | all_models = {} 37 | duel_states: Dict[str, PredictState] = {} 38 | 39 | 
def delete_outdated_states(): 40 | while True: 41 | current_time = time.time() 42 | for k, v in list(duel_states.items()): 43 | if k == settings.test_duel_id: 44 | continue 45 | if current_time - v._timestamp > settings.state_expire: 46 | del duel_states[k] 47 | time.sleep(600) 48 | 49 | # Start the thread to delete outdated states 50 | thread = threading.Thread(target=delete_outdated_states) 51 | thread.daemon = True 52 | thread.start() 53 | 54 | @asynccontextmanager 55 | async def lifespan(app: FastAPI): 56 | init_code_list(settings.code_list) 57 | 58 | checkpoint = settings.checkpoint 59 | predictor = Predictor.load(checkpoint, settings.ygo_num_threads) 60 | all_models["default"] = predictor 61 | print(f"loaded checkpoint from {checkpoint}") 62 | 63 | state = new_state() 64 | test_duel_id = settings.test_duel_id 65 | duel_states[test_duel_id] = state 66 | 67 | yield 68 | # Clean up the ML models and release the resources 69 | all_models.clear() 70 | 71 | 72 | app = FastAPI( 73 | lifespan=lifespan, 74 | ) 75 | 76 | if settings.enable_cors: 77 | app.add_middleware( 78 | CORSMiddleware, 79 | allow_origins=["*"], 80 | allow_credentials=True, 81 | allow_methods=["*"], 82 | allow_headers=["*"], 83 | ) 84 | 85 | def new_state(): 86 | return PredictState() 87 | 88 | @app.get('/', status_code=200, response_class=PlainTextResponse) 89 | async def root(): 90 | return "OK" 91 | 92 | 93 | @app.post('/v0/duels', response_model=DuelCreateResponse) 94 | async def create_duel() -> DuelCreateResponse: 95 | """ 96 | Create duel 97 | """ 98 | duel_id = str(uuid.uuid4()) 99 | state = new_state() 100 | duel_states[duel_id] = state 101 | return DuelCreateResponse(duelId=duel_id, index=state.index) 102 | 103 | 104 | @app.delete('/v0/duels/{duelId}', status_code=204) 105 | async def delete_duel( 106 | duel_id: str = Path(..., alias='duelId') 107 | ) -> None: 108 | """ 109 | Delete duel 110 | """ 111 | if duel_id in duel_states: 112 | duel_states.pop(duel_id) 113 | 114 | 115 | @app.post( 
116 | '/v0/duels/{duelId}/predict', 117 | ) 118 | async def duel_predict( 119 | duel_id: str = Path(..., alias='duelId'), body: DuelPredictRequest = None 120 | ) -> Union[DuelPredictResponse, DuelPredictErrorResponse]: 121 | index = body.index 122 | if duel_id not in duel_states: 123 | return DuelPredictErrorResponse( 124 | error=f"duel {duel_id} not found" 125 | ) 126 | duel_state = duel_states[duel_id] 127 | if index != duel_state.index: 128 | return DuelPredictErrorResponse( 129 | error=f"index mismatch: expected {duel_state.index}, got {index}" 130 | ) 131 | 132 | predictor = all_models["default"] 133 | model_fn = predictor.predict 134 | 135 | _start = time.time() 136 | try: 137 | predict_results = predict(model_fn, body.input, body.prev_action_idx, duel_state) 138 | except (KeyError, NotImplementedError) as e: 139 | return DuelPredictErrorResponse( 140 | error=f"{e}" 141 | ) 142 | predict_time = time.time() - _start 143 | 144 | print(f"predict time: {predict_time:.3f}") 145 | return DuelPredictResponse( 146 | index=duel_state.index, 147 | predict_results=predict_results, 148 | ) 149 | -------------------------------------------------------------------------------- /ygoinf/ygoinf/tflite_inf.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import optree 3 | import tflite_runtime.interpreter as tf_lite 4 | 5 | def tflite_predict(interpreter, rstate, obs): 6 | input_details = interpreter.get_input_details() 7 | output_details = interpreter.get_output_details() 8 | 9 | inputs = rstate, obs 10 | for i, x in enumerate(optree.tree_leaves(inputs)): 11 | interpreter.set_tensor(input_details[i]["index"], x) 12 | interpreter.invoke() 13 | results = [ 14 | interpreter.get_tensor(o["index"]) for o in output_details] 15 | rstate1, rstate2, probs, value = results 16 | rstate = (rstate1, rstate2) 17 | return rstate, probs, value 18 | 19 | def predict_fn(interpreter, rstate, obs): 20 | obs = optree.tree_map(lambda 
x: np.array([x]), obs) 21 | rstate, probs, value = tflite_predict(interpreter, rstate, obs) 22 | prob = probs[0].tolist() 23 | value = float(value[0]) 24 | return rstate, prob, value 25 | 26 | def load_model(checkpoint, *args, **kwargs): 27 | with open(checkpoint, "rb") as f: 28 | tflite_model = f.read() 29 | interpreter = tf_lite.Interpreter( 30 | model_content=tflite_model, num_threads=kwargs.get("num_threads", 1)) 31 | interpreter.allocate_tensors() 32 | return interpreter 33 | --------------------------------------------------------------------------------