├── .github └── workflows │ ├── codeql-analysis.yml │ └── stale.yml ├── .gitignore ├── LICENSE ├── README.md ├── configs ├── eval-diamond.yaml ├── eval-treechop.yaml ├── train-diamond.yaml └── train-treechop.yaml ├── demonstrations └── .gitkeep ├── docker ├── build.sh ├── dockerfile └── requirements.txt ├── hierarchy ├── __init__.py ├── subtask_agent.py └── subtasks_extraction.py ├── main.py ├── policy ├── __init__.py ├── agent.py ├── models.py ├── replay_buffer.py ├── sum_tree.py └── tf1_models.py ├── static └── forging.png ├── train └── .gitkeep └── utils ├── __init__.py ├── config_validation.py ├── discretization.py ├── fake_env.py ├── load_demonstrations.py ├── load_weights.py ├── tf_util.py └── wrappers.py /.github/workflows/codeql-analysis.yml: -------------------------------------------------------------------------------- 1 | # For most projects, this workflow file will not need changing; you simply need 2 | # to commit it to your repository. 3 | # 4 | # You may wish to alter this file to override the set of languages analyzed, 5 | # or to provide custom queries or build logic. 6 | # 7 | # ******** NOTE ******** 8 | # We have attempted to detect the languages in your repository. Please check 9 | # the `language` matrix defined below to confirm you have the correct set of 10 | # supported CodeQL languages. 11 | # 12 | name: "CodeQL" 13 | 14 | on: 15 | push: 16 | branches: [ master ] 17 | pull_request: 18 | # The branches below must be a subset of the branches above 19 | branches: [ master ] 20 | schedule: 21 | - cron: '20 19 * * 0' 22 | 23 | jobs: 24 | analyze: 25 | name: Analyze 26 | runs-on: ubuntu-latest 27 | permissions: 28 | actions: read 29 | contents: read 30 | security-events: write 31 | 32 | strategy: 33 | fail-fast: false 34 | matrix: 35 | language: [ 'python' ] 36 | # CodeQL supports [ 'cpp', 'csharp', 'go', 'java', 'javascript', 'python', 'ruby' ] 37 | # Learn more about CodeQL language support at https://git.io/codeql-language-support 38 | 39 | steps: 40 | - name: Checkout repository 41 | uses: actions/checkout@v3 42 | 43 | # Initializes the CodeQL tools for scanning. 44 | - name: Initialize CodeQL 45 | uses: github/codeql-action/init@v2 46 | with: 47 | languages: ${{ matrix.language }} 48 | # If you wish to specify custom queries, you can do so here or in a config file. 49 | # By default, queries listed here will override any specified in a config file. 50 | # Prefix the list here with "+" to use these queries and those in the config file. 51 | # queries: ./path/to/local/query, your-org/your-repo/queries@main 52 | 53 | # Autobuild attempts to build any compiled languages (C/C++, C#, or Java). 54 | # If this step fails, then you should remove it and run the build manually (see below) 55 | - name: Autobuild 56 | uses: github/codeql-action/autobuild@v2 57 | 58 | # ℹ️ Command-line programs to run using the OS shell. 
59 | # 📚 https://git.io/JvXDl 60 | 61 | # ✏️ If the Autobuild fails above, remove it and uncomment the following three lines 62 | # and modify them (or add more) to build your code if your project 63 | # uses a compiled language 64 | 65 | #- run: | 66 | # make bootstrap 67 | # make release 68 | 69 | - name: Perform CodeQL Analysis 70 | uses: github/codeql-action/analyze@v2 71 | -------------------------------------------------------------------------------- /.github/workflows/stale.yml: -------------------------------------------------------------------------------- 1 | # This workflow warns and then closes issues and PRs that have had no activity for a specified amount of time. 2 | # 3 | # You can adjust the behavior by modifying this file. 4 | # For more information, see: 5 | # https://github.com/actions/stale 6 | name: Mark stale issues and pull requests 7 | 8 | on: 9 | schedule: 10 | - cron: '31 15 * * *' 11 | 12 | jobs: 13 | stale: 14 | 15 | runs-on: ubuntu-latest 16 | permissions: 17 | issues: write 18 | pull-requests: write 19 | 20 | steps: 21 | - uses: actions/stale@v5 22 | with: 23 | repo-token: ${{ secrets.GITHUB_TOKEN }} 24 | stale-issue-message: 'Stale issue message' 25 | stale-pr-message: 'Stale pull request message' 26 | stale-issue-label: 'no-issue-activity' 27 | stale-pr-label: 'no-pr-activity' 28 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__/ 2 | *.idea 3 | *.ipynb_checkpoints 4 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 Tviskaron 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 |
-------------------------------------------------------------------------------- /README.md: --------------------------------------------------------------------------------
1 | # ForgER: Forgetful Experience Replay for Reinforcement Learning from Demonstrations
2 |
3 | This repository is the TF2.0 implementation of Forgetful Experience Replay (ForgER) for Reinforcement Learning from Demonstrations by [Alexey Skrynnik](https://github.com/Tviskaron), [Aleksey Staroverov](https://github.com/alstar8), [Ermek Aitygulov](https://github.com/ermekaitygulov), [Kirill Aksenov](https://github.com/axellkir), [Vasilii Davydov](https://github.com/dexfrost89), [Aleksandr I. Panov](https://github.com/grafft).
4 |
5 |
6 | [[Paper]](https://www.sciencedirect.com/science/article/pii/S0950705121001076) [[arxiv]](https://arxiv.org/abs/2006.09939) [[Webpage]](https://sites.google.com/view/forgetful-experience-replay)
7 |
8 | ![Forging phase](static/forging.png)
9 |
10 |
11 |
12 | ## Installation
13 |
14 | To install requirements:
15 |
16 | ```shell
17 | pip install -r docker/requirements.txt
18 | ```
19 |
20 | ## Evaluating ForgER++
21 |
22 | To download pretrained weights:
23 | ```shell
24 | python utils/load_weights.py
25 | ```
26 |
27 | To run evaluation on the *ObtainDiamond* task:
28 | ```shell
29 | python main.py --config configs/eval-diamond.yaml
30 | ```
31 |
32 | To run evaluation on the *Treechop* task:
33 | ```shell
34 | python main.py --config configs/eval-treechop.yaml
35 | ```
36 |
37 |
38 | ## Training
39 |
40 | To download the MineRL dataset:
41 |
42 | ```shell
43 | python utils/load_demonstrations.py
44 | ```
45 |
46 | To train ForgER on the *Treechop* task:
47 |
48 | ```shell
49 | python main.py --config configs/train-treechop.yaml
50 | ```
51 |
52 | To train ForgER on the *ObtainDiamondDense* task:
53 |
54 | ```shell
55 | python main.py --config configs/train-diamond.yaml
56 | ```
57 | **Caution:** We have not tested whether the results are reproducible after moving to the TF2 version and updating the code for MineRL 0.4.
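Each config describes a `pipeline`: a list of tasks that `main.py` executes in order. For example, `configs/train-treechop.yaml` (reproduced below with added comments) first pre-trains a flat agent from expert demonstrations (`source: expert`) and then continues training that agent online (`source: agent`):

```yaml
pipeline:
  - environment: MineRLTreechop-v0   # stage 1: pre-train from demonstrations
    from_scratch: True
    pretrain_num_updates: 100000
    source: expert
    agent_type: flat
    cfg:
      agent:
        save_dir: train/treechop/

  - environment: MineRLTreechop-v0   # stage 2: continue training online
    max_train_steps: 1000000
    source: agent
    agent_type: flat
    cfg:
      agent:
        save_dir: train/treechop/
```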
58 | 59 | 60 | ## Results on MineRLObtainDiamond-v0 (1000 seeds) 61 | 62 | | Item | MineRL2019 | ForgER | ForgER++| 63 | | --- | --- | --- | --- | 64 | | log | 859 | **882** | 867 | 65 | | planks | 805 | **806** | 792 | 66 | | stick | 718 | 747 | **790** | 67 | | crafting table | 716 | 744 | **790** | 68 | | wooden pickaxe | 713 | 744 | **789** | 69 | | cobblestone | 687 | 730 | **779** | 70 | | stone pickaxe | 642 | 698 | **751** | 71 | | furnace | 19 | 48 | **98** | 72 | | iron ore | 96 | 109 | **231** | 73 | | iron ingot | 19 | 48 | **98** | 74 | | iron pickaxe | 12 | 43 | **83** | 75 | | diamond | 0 | 0 | **1** | 76 | | mean reward | 57.701 | 74.09 | **104.315** | 77 | 78 | ## Citation 79 | If you use this repo in your research, please consider citing the paper as follows: 80 | ``` 81 | @article{skrynnik2021forgetful, 82 | title={Forgetful experience replay in hierarchical reinforcement learning from expert demonstrations}, 83 | author={Skrynnik, Alexey and Staroverov, Aleksey and Aitygulov, Ermek and Aksenov, Kirill and Davydov, Vasilii and Panov, Aleksandr I}, 84 | journal={Knowledge-Based Systems}, 85 | volume={218}, 86 | pages={106844}, 87 | year={2021}, 88 | publisher={Elsevier} 89 | } 90 | ``` 91 | -------------------------------------------------------------------------------- /configs/eval-diamond.yaml: -------------------------------------------------------------------------------- 1 | pipeline: 2 | - environment: MineRLObtainDiamondDense-v0 3 | from_scratch: False 4 | agent_type: hierarchical 5 | max_train_steps: 1000000 6 | source: agent 7 | cfg: 8 | wrappers: 9 | render: True 10 | agent: 11 | save_dir: 'train/forger++/' 12 | initial_epsilon: 0.01 13 | final_epsilon: 0.01 14 | subtasks: 15 | - item_count: 7 16 | item_name: log 17 | 18 | - actions: 19 | - name: craft 20 | target: planks 21 | 22 | item_count: 28 23 | item_name: planks 24 | 25 | - actions: 26 | - name: craft 27 | target: crafting_table 28 | 29 | item_count: 1 30 | item_name: crafting_table 31 | 32 | - actions: 33 | - name: craft 34 | target: stick 35 | item_count: 8 36 | item_name: stick 37 | 38 | - actions: 39 | - name: place 40 | target: crafting_table 41 | - name: nearbyCraft 42 | target: wooden_pickaxe 43 | 44 | item_count: 1 45 | item_name: wooden_pickaxe 46 | 47 | - actions: 48 | - name: craft 49 | target: crafting_table 50 | item_count: 1 51 | item_name: crafting_table 52 | 53 | - actions: 54 | - name: equip 55 | target: wooden_pickaxe 56 | 57 | item_count: 14 58 | item_name: cobblestone 59 | 60 | - actions: 61 | - name: craft 62 | target: crafting_table 63 | item_count: 1 64 | item_name: crafting_table 65 | 66 | - actions: 67 | - name: place 68 | target: crafting_table 69 | - name: nearbyCraft 70 | target: stone_pickaxe 71 | 72 | item_count: 1 73 | item_name: stone_pickaxe 74 | 75 | - actions: 76 | - name: equip 77 | target: stone_pickaxe 78 | 79 | item_count: 3 80 | item_name: iron_ore 81 | 82 | 83 | - actions: 84 | - name: craft 85 | target: crafting_table 86 | item_count: 1 87 | item_name: crafting_table 88 | 89 | - actions: 90 | - name: place 91 | target: crafting_table 92 | - name: nearbyCraft 93 | target: furnace 94 | 95 | item_count: 1 96 | item_name: furnace 97 | 98 | - actions: 99 | - name: place 100 | target: furnace 101 | 102 | item_count: 3 103 | item_name: iron_ore 104 | 105 | - actions: 106 | - name: nearbySmelt 107 | target: iron_ingot 108 | 109 | item_count: 3 110 | item_name: iron_ingot 111 | 112 | - actions: 113 | - name: nearbyCraft 114 | target: iron_pickaxe 115 | 116 | item_count: 1 
117 | item_name: iron_pickaxe 118 | 119 | - action: 120 | - name: equip 121 | target: iron_pickaxe 122 | item_count: 1 123 | item_name: diamond 124 | 125 | 126 | 127 | -------------------------------------------------------------------------------- /configs/eval-treechop.yaml: -------------------------------------------------------------------------------- 1 | pipeline: 2 | 3 | - environment: MineRLTreechop-v0 4 | max_train_steps: 1000000 5 | evaluation: True 6 | source: agent 7 | agent_type: flat 8 | cfg: 9 | wrappers: 10 | render: True 11 | agent: 12 | initial_epsilon: 0.01 13 | save_dir: train/forger++/log/ 14 | -------------------------------------------------------------------------------- /configs/train-diamond.yaml: -------------------------------------------------------------------------------- 1 | pipeline: 2 | - environment: MineRLTreechop-v0 3 | from_scratch: True 4 | pretrain_num_updates: 100000 5 | 6 | source: expert 7 | agent_type: flat 8 | cfg: 9 | agent: 10 | save_dir: "train/diamond/log" 11 | 12 | - environment: MineRLTreechop-v0 13 | from_scratch: False 14 | max_train_steps: 1000000 15 | 16 | source: agent 17 | agent_type: flat 18 | cfg: 19 | agent: 20 | save_dir: "train/diamond/log" 21 | 22 | 23 | - environment: MineRLObtainDiamondDense-v0 24 | from_scratch: False 25 | agent_type: hierarchical 26 | max_train_steps: 1000000 27 | source: agent 28 | cfg: 29 | wrappers: 30 | render: True 31 | agent: 32 | save_dir: 'train/diamond/' 33 | initial_epsilon: 0.05 34 | final_epsilon: 0.01 35 | subtasks: 36 | - item_count: 7 37 | item_name: log 38 | 39 | - actions: 40 | - name: craft 41 | target: planks 42 | 43 | item_count: 28 44 | item_name: planks 45 | 46 | - actions: 47 | - name: craft 48 | target: crafting_table 49 | 50 | item_count: 1 51 | item_name: crafting_table 52 | 53 | - actions: 54 | - name: craft 55 | target: stick 56 | item_count: 8 57 | item_name: stick 58 | 59 | - actions: 60 | - name: place 61 | target: crafting_table 62 | - name: nearbyCraft 63 | target: wooden_pickaxe 64 | 65 | item_count: 1 66 | item_name: wooden_pickaxe 67 | 68 | - actions: 69 | - name: craft 70 | target: crafting_table 71 | item_count: 1 72 | item_name: crafting_table 73 | 74 | - actions: 75 | - name: equip 76 | target: wooden_pickaxe 77 | 78 | item_count: 14 79 | item_name: cobblestone 80 | 81 | - actions: 82 | - name: craft 83 | target: crafting_table 84 | item_count: 1 85 | item_name: crafting_table 86 | 87 | - actions: 88 | - name: place 89 | target: crafting_table 90 | - name: nearbyCraft 91 | target: stone_pickaxe 92 | 93 | item_count: 1 94 | item_name: stone_pickaxe 95 | 96 | - actions: 97 | - name: equip 98 | target: stone_pickaxe 99 | 100 | item_count: 3 101 | item_name: iron_ore 102 | 103 | 104 | - actions: 105 | - name: craft 106 | target: crafting_table 107 | item_count: 1 108 | item_name: crafting_table 109 | 110 | - actions: 111 | - name: place 112 | target: crafting_table 113 | - name: nearbyCraft 114 | target: furnace 115 | 116 | item_count: 1 117 | item_name: furnace 118 | 119 | - actions: 120 | - name: place 121 | target: furnace 122 | 123 | item_count: 3 124 | item_name: iron_ore 125 | 126 | - actions: 127 | - name: nearbySmelt 128 | target: iron_ingot 129 | 130 | item_count: 3 131 | item_name: iron_ingot 132 | 133 | - actions: 134 | - name: nearbyCraft 135 | target: iron_pickaxe 136 | 137 | item_count: 1 138 | item_name: iron_pickaxe 139 | 140 | - action: 141 | - name: equip 142 | target: iron_pickaxe 143 | item_count: 1 144 | item_name: diamond 145 | 146 | 147 | 148 | 
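The `subtasks` entries above are parsed into Pydantic models defined in `utils/config_validation.py`, which is not part of this listing. The following is only a rough sketch of what `Action` and `Subtask` appear to look like, inferred from how they are used in `hierarchy/subtask_agent.py` and `hierarchy/subtasks_extraction.py`; the field names and defaults are assumptions, not the actual file contents.

```python
# Hypothetical sketch, NOT the real utils/config_validation.py: fields are inferred
# from call sites such as Action(name=..., target=...) and
# Subtask(item_name=..., item_count=..., start_idx=..., end_idx=...).
from typing import List, Optional

from pydantic import BaseModel


class Action(BaseModel):
    name: str    # MineRL crafting verb, e.g. craft, place, equip, nearbyCraft, nearbySmelt
    target: str  # item the verb is applied to, e.g. planks or crafting_table


class Subtask(BaseModel):
    item_name: Optional[str] = None  # inventory item the sub-policy must obtain
    item_count: int = 0              # inventory count at which the subtask is done
    start_idx: int = 0               # slice bounds inside a demonstration trajectory
    end_idx: int = 0
    actions: List[Action] = []       # scripted actions looped while the subtask runs
```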
-------------------------------------------------------------------------------- /configs/train-treechop.yaml: -------------------------------------------------------------------------------- 1 | pipeline: 2 | - environment: MineRLTreechop-v0 3 | from_scratch: True 4 | pretrain_num_updates: 100000 5 | 6 | source: expert 7 | agent_type: flat 8 | cfg: 9 | agent: 10 | save_dir: train/treechop/ 11 | 12 | - environment: MineRLTreechop-v0 13 | max_train_steps: 1000000 14 | 15 | source: agent 16 | agent_type: flat 17 | cfg: 18 | agent: 19 | save_dir: train/treechop/ 20 | -------------------------------------------------------------------------------- /demonstrations/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cog-model/forger/4e4258f358094b358db15b20f04c8197eb3bf63e/demonstrations/.gitkeep -------------------------------------------------------------------------------- /docker/build.sh: -------------------------------------------------------------------------------- 1 | docker build -t forger-v2 . 2 | -------------------------------------------------------------------------------- /docker/dockerfile: -------------------------------------------------------------------------------- 1 | FROM tensorflow/tensorflow:latest-gpu 2 | 3 | RUN apt-get update && apt-get install software-properties-common -y 4 | RUN add-apt-repository -y ppa:openjdk-r/ppa && apt-get update && apt-get install -y openjdk-8-jdk 5 | 6 | ADD requirements.txt /tmp/ 7 | RUN pip install -r /tmp/requirements.txt 8 | -------------------------------------------------------------------------------- /docker/requirements.txt: -------------------------------------------------------------------------------- 1 | minerl>=0.4 2 | opencv-python 3 | PyYAML 4 | pydantic 5 | scipy 6 | pyglet 7 | tensorflow>=2.8.0 8 | wandb -------------------------------------------------------------------------------- /hierarchy/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cog-model/forger/4e4258f358094b358db15b20f04c8197eb3bf63e/hierarchy/__init__.py -------------------------------------------------------------------------------- /hierarchy/subtask_agent.py: -------------------------------------------------------------------------------- 1 | import sys 2 | from copy import deepcopy 3 | import gym 4 | import pathlib 5 | 6 | from typing import List, Dict 7 | 8 | from policy.agent import create_flat_agent, Agent 9 | from utils.config_validation import Task, Subtask, Action 10 | from utils.wrappers import wrap_env 11 | 12 | 13 | class LoopCraftingAgent: 14 | def __init__(self, crafting_actions): 15 | self.crafting_actions = crafting_actions 16 | self.current_action_index = 0 17 | 18 | def get_crafting_action(self): 19 | if len(self.crafting_actions) == 0: 20 | return {} 21 | action: Action = self.crafting_actions[self.current_action_index] 22 | self.current_action_index = (self.current_action_index + 1) % len(self.crafting_actions) 23 | return {action.name: action.target} 24 | 25 | 26 | class CraftInnerWrapper(gym.Wrapper): 27 | def __init__(self, env, crafting_agent): 28 | super().__init__(env) 29 | self.crafting_agent = crafting_agent 30 | 31 | def step(self, action): 32 | return self.env.step({**action, **self.crafting_agent.get_crafting_action()}) 33 | 34 | 35 | class InnerEnvWrapper(gym.Wrapper): 36 | def __init__(self, env, item, count, last_observation): 37 | super().__init__(env) 38 | self.item = 
item 39 | self.count = count 40 | self.previous_count = None 41 | self.is_core_env_done = None 42 | self.last_observation = last_observation 43 | 44 | def reset(self, **kwargs): 45 | return self.last_observation 46 | 47 | def step(self, action): 48 | state, reward, done, _ = self.env.step(action) 49 | reward = 0 50 | if done: 51 | self.is_core_env_done = True 52 | if self.item not in state['inventory']: 53 | state['inventory'][self.item] = 0 54 | if state['inventory'][self.item] >= self.count: 55 | done = True 56 | 57 | if self.previous_count is not None and self.previous_count < state['inventory'][self.item]: 58 | reward += 1.0 * (state['inventory'][self.item] - self.previous_count) 59 | self.last_observation = state 60 | self.previous_count = state['inventory'][self.item] 61 | return state, reward, done, _ 62 | 63 | 64 | class InventoryPrintWrapper(gym.Wrapper): 65 | def __init__(self, env, items=("log", "cobblestone", "planks", "stick", "wooden_pickaxe", "crafting_table", 66 | "stone_pickaxe", "iron_ore", "furnace", "iron_pickaxe",)): 67 | super().__init__(env) 68 | self.items = items 69 | self.inventory = None 70 | 71 | def get_inventory_info(self): 72 | inventory_info = ":" 73 | for item in self.items: 74 | inventory_info += f"[{item}:{self.inventory[item]}] " 75 | return inventory_info 76 | 77 | def step(self, action): 78 | state, reward, done, _ = self.env.step(action) 79 | inventory = state['inventory'] 80 | if inventory != self.inventory: 81 | self.inventory = inventory 82 | sys.stdout.write('\r' + self.get_inventory_info()) 83 | sys.stdout.flush() 84 | return state, reward, done, _ 85 | 86 | 87 | class ItemAgent: 88 | 89 | def __init__(self, task: Task, nodes_dict=None): 90 | self.nodes_dict = nodes_dict 91 | self.subtasks: List[Subtask] = task.subtasks 92 | self.cfg = task.cfg 93 | self.task = task 94 | self.pov_agents: Dict[str:Agent] = {} 95 | self.agent_tasks = {} 96 | 97 | def train(self, core_env, task: Task, agents_to_train=("log", "cobblestone",)): 98 | 99 | t_env = wrap_env(core_env, task.cfg.wrappers) 100 | for subtask in self.subtasks: 101 | if subtask.item_name in self.pov_agents: 102 | continue 103 | agent_task = deepcopy(task) 104 | agent_task.max_train_episodes = 1 105 | agent_task.cfg.agent.save_dir = str(pathlib.Path(agent_task.cfg.agent.save_dir) / subtask.item_name) 106 | agent_task.evaluation = subtask.item_name not in agents_to_train 107 | 108 | self.agent_tasks[subtask.item_name] = agent_task 109 | self.pov_agents[subtask.item_name] = create_flat_agent(agent_task, t_env) 110 | 111 | for episode in range(1000): 112 | subtask_idx = 0 113 | obs = core_env.reset() 114 | while True: 115 | current_subtask: Subtask = self.subtasks[subtask_idx] 116 | inner_env = InnerEnvWrapper(env=core_env, item=current_subtask.item_name, count=current_subtask.item_count, last_observation=obs) 117 | inner_env = CraftInnerWrapper(inner_env, crafting_agent=LoopCraftingAgent(current_subtask.actions)) 118 | 119 | t_env = wrap_env(inner_env, self.agent_tasks[current_subtask.item_name].cfg.wrappers) 120 | print('\n', current_subtask.item_name, "agent started") 121 | pov_agent = self.pov_agents[current_subtask.item_name] 122 | pov_agent.train(t_env, self.agent_tasks[current_subtask.item_name]) 123 | 124 | subtask_idx += 1 125 | if inner_env.is_core_env_done or subtask_idx >= len(self.subtasks): 126 | break 127 | obs = inner_env.last_observation 128 | -------------------------------------------------------------------------------- /hierarchy/subtasks_extraction.py: 
-------------------------------------------------------------------------------- 1 | from collections import defaultdict 2 | from time import sleep 3 | 4 | import gym 5 | import minerl 6 | from typing import List 7 | 8 | from utils.fake_env import FakeEnv 9 | from utils.config_validation import Action, Subtask 10 | 11 | 12 | class TrajectoryInformation: 13 | def __init__(self, env_name='MineRLObtainIronPickaxeDense-v0', data_dir='demonstrations/', 14 | trajectory_name='v3_rigid_mustard_greens_monster-11_878-4825'): 15 | data = minerl.data.make(env_name, data_dir) 16 | trajectory = data.load_data(stream_name=trajectory_name) 17 | 18 | self.name = trajectory_name 19 | 20 | current_states, actions, rewards, next_states, dones = [], [], [], [], [] 21 | for current_state, action, reward, next_state, done in trajectory: 22 | current_states.append(current_state) 23 | actions.append(action) 24 | rewards.append(reward) 25 | next_states.append(next_state) 26 | dones.append(done) 27 | 28 | self.trajectory = (current_states, actions, rewards, next_states, dones) 29 | 30 | self.reward = int(sum(rewards)) 31 | self.length = len(rewards) 32 | 33 | if 'Treechop' in env_name: 34 | self.subtasks = [Subtask(item_name='log', item_count=self.reward, start_idx=0, end_idx=len(current_states))] 35 | self.trajectory_by_subtask = {'log': self.trajectory} 36 | else: 37 | self.subtasks = self.extract_subtasks(self.trajectory) 38 | self.trajectory_by_subtask = self.slice_trajectory_by_item(self.extract_subtasks(self.trajectory, compress_items=False)) 39 | 40 | def slice_trajectory_by_item(self, subtasks, minimal_length=4, fix_done=True, scale_rewards=True): 41 | 42 | states, actions, rewards, next_states, dones = self.trajectory 43 | 44 | sliced_states = defaultdict(list) 45 | sliced_actions = defaultdict(list) 46 | sliced_rewards = defaultdict(list) 47 | sliced_next_states = defaultdict(list) 48 | sliced_dones = defaultdict(list) 49 | 50 | result = defaultdict(list) 51 | 52 | for subtask in subtasks: 53 | # skip short ones 54 | if subtask.end_idx - subtask.start_idx < minimal_length: 55 | continue 56 | 57 | if fix_done: 58 | true_dones = [0 for _ in range(subtask.start_idx, subtask.end_idx)] 59 | true_dones[-1] = 1 60 | else: 61 | true_dones = dones[subtask.start_idx:subtask.end_idx] 62 | 63 | if scale_rewards: 64 | true_rewards = [0 for _ in range(subtask.start_idx, subtask.end_idx)] 65 | true_rewards[-1] = next_states[subtask.end_idx]['inventory'][subtask.item_name] - \ 66 | states[subtask.end_idx]['inventory'][subtask.item_name] 67 | else: 68 | true_rewards = rewards[subtask.start_idx:subtask.end_idx] 69 | 70 | sliced_states[subtask.item_name] += states[subtask.start_idx:subtask.end_idx] 71 | sliced_actions[subtask.item_name] += actions[subtask.start_idx:subtask.end_idx] 72 | sliced_rewards[subtask.item_name] += true_rewards 73 | sliced_next_states[subtask.item_name] += next_states[subtask.start_idx:subtask.end_idx] 74 | sliced_dones[subtask.item_name] += true_dones 75 | 76 | unique_item_names = set([item.item_name for item in self.subtasks]) 77 | for item_name in unique_item_names: 78 | result[item_name] = [sliced_states[item_name], 79 | sliced_actions[item_name], 80 | sliced_rewards[item_name], 81 | sliced_next_states[item_name], 82 | sliced_dones[item_name]] 83 | 84 | return result 85 | 86 | @classmethod 87 | def extract_subtasks(cls, trajectory, 88 | excluded_actions=("attack", "back", "camera", 89 | "forward", "jump", "left", 90 | "right", "sneak", "sprint"), 91 | compress_items=True) -> List[Subtask]: 92 | 
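        # The loop below scans the demonstration once. Crafting-style actions (anything
        # outside excluded_actions whose target is not 'none') are appended to the last
        # subtask in the running list (initially a placeholder), skipping immediate repeats.
        # Whenever an inventory count increases, a new Subtask is closed at the current
        # index; consecutive gains of the same item are merged when compress_items is True.
        # The reversed-zip swap at the end shifts every action list one subtask forward, so
        # each subtask ends up holding the crafting actions executed on the way to obtaining
        # its own item, and placeholder subtasks without an item_name are dropped.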
93 | states, actions, rewards, next_states, _ = trajectory 94 | items = states[0].get('inventory', {}).keys() 95 | 96 | # add fake items to deal with crafting actions 97 | result: List[Subtask] = [Subtask(start_idx=0, end_idx=0)] 98 | 99 | for index in range(len(rewards)): 100 | 101 | for action in actions[index]: 102 | if action not in excluded_actions: 103 | target = str(actions[index][action]) 104 | if target == 'none': 105 | continue 106 | 107 | a = Action(name=action, target=target) 108 | last_subtask = result[-1] 109 | if a.target: 110 | if not last_subtask.actions or last_subtask.actions[-1] != a: 111 | last_subtask.actions.append(a) 112 | for item in items: 113 | if next_states[index]['inventory'][item] > states[index]['inventory'][item]: 114 | s = Subtask(item_name=item, start_idx=result[-1].end_idx, end_idx=index, item_count=next_states[index]['inventory'][item]) 115 | last_subtask = result[-1] 116 | if s.item_name == last_subtask.item_name and compress_items: 117 | # update the required number of items 118 | last_subtask.item_count = s.item_count 119 | last_subtask.end_idx = index 120 | else: 121 | # add new subtask 122 | result.append(s) 123 | 124 | result.append(Subtask()) 125 | for item, next_item in zip(reversed(result[:-1]), reversed(result[1:])): 126 | item.actions, next_item.actions = next_item.actions, item.actions 127 | 128 | # remove empty items 129 | result = [subtask for subtask in result if subtask.item_name is not None] 130 | 131 | return result 132 | 133 | 134 | class FrameSkip(gym.Wrapper): 135 | """Return every `skip`-th frame and repeat given action during skip. 136 | Note that this wrapper does not "maximize" over the skipped frames. 137 | """ 138 | 139 | def __init__(self, env, skip=4): 140 | super().__init__(env) 141 | 142 | self._skip = skip 143 | 144 | def step(self, action): 145 | total_reward = 0.0 146 | infos = [] 147 | info = {} 148 | obs = None 149 | done = None 150 | for _ in range(self._skip): 151 | obs, reward, done, info = self.env.step(action) 152 | infos.append(info) 153 | total_reward += reward 154 | if done: 155 | break 156 | if 'expert_action' in infos[0]: 157 | info['expert_action'] = self.env.preprocess_action([info_['expert_action'] for info_ in infos]) 158 | return obs, total_reward, done, info 159 | 160 | 161 | def main(): 162 | trj_info = TrajectoryInformation(env_name='MineRLTreechop-v0', trajectory_name='v3_absolute_grape_changeling-7_14600-16079') 163 | # trj_info = TrajectoryInformation() 164 | env = FakeEnv(data=trj_info.trajectory_by_subtask['log']) 165 | env = FrameSkip(env, 2) 166 | # env = Monitor(env, 'monitor/') 167 | env.reset() 168 | done = False 169 | 170 | while True: 171 | done = False 172 | while not done: 173 | obs, rew, done, info = env.step(None) 174 | env.render(), sleep(0.01) 175 | 176 | if env.reset() is None: 177 | break 178 | 179 | 180 | if __name__ == '__main__': 181 | main() 182 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | from functools import reduce 2 | 3 | import minerl 4 | import gym 5 | import os 6 | import tensorflow as tf 7 | import yaml 8 | 9 | from policy.agent import create_flat_agent 10 | from hierarchy.subtask_agent import ItemAgent 11 | from hierarchy.subtasks_extraction import TrajectoryInformation 12 | from utils.fake_env import FakeEnv 13 | 14 | import argparse 15 | 16 | from utils.config_validation import Pipeline, Task 17 | from utils.wrappers import wrap_env 18 
| from utils.tf_util import config_gpu 19 | 20 | os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" 21 | os.environ["CUDA_VISIBLE_DEVICES"] = "0" 22 | config_gpu() 23 | 24 | 25 | def load_trajectories(task, max_trj=300): 26 | data = minerl.data.make(task.environment, task.data_dir) 27 | 28 | trajectories = [] 29 | for trj_name in data.get_trajectory_names()[:max_trj]: 30 | trajectories.append(TrajectoryInformation(env_name=data.environment, trajectory_name=trj_name)) 31 | 32 | return trajectories 33 | 34 | 35 | def run_task(task: Task): 36 | if task.agent_type == 'flat': 37 | env = wrap_env(gym.make(task.environment), task.cfg.wrappers) 38 | agent = create_flat_agent(task, env) 39 | 40 | if task.source == 'expert': 41 | for trajectory in load_trajectories(task): 42 | # todo replace 'log' with parameter name 43 | agent.add_demo(wrap_env(FakeEnv(data=trajectory.trajectory_by_subtask['log']), task.cfg.wrappers)) 44 | 45 | agent.pre_train(task) 46 | agent.save(task.cfg.agent.save_dir) 47 | 48 | elif task.source == 'agent': 49 | summary_writer = tf.summary.create_file_writer('train/') 50 | with summary_writer.as_default(): 51 | scores_, _ = agent.train(env, task) 52 | env.close() 53 | agent.save(task.cfg.agent.save_dir) 54 | 55 | elif task.agent_type == 'hierarchical': 56 | 57 | if task.source == 'expert': 58 | env = wrap_env(gym.make(task.environment), task.cfg.wrappers) 59 | trajectories = load_trajectories(task) 60 | unique_subtasks = reduce(lambda x, y: x.union(y), 61 | [set(q) for q in [t.trajectory_by_subtask.keys() for t in trajectories]]) 62 | for subtask in unique_subtasks: 63 | if subtask not in ["cobblestone", "iron_ore"]: 64 | continue 65 | agent = create_flat_agent(task, env) 66 | for trj in trajectories: 67 | if not trj.trajectory_by_subtask.get(subtask, None): 68 | continue 69 | 70 | agent.add_demo(wrap_env(FakeEnv(data=trj.trajectory_by_subtask[subtask]), task.cfg.wrappers)) 71 | agent.pre_train(task) 72 | agent.save(task.cfg.agent.save_dir + subtask + '/') 73 | 74 | elif task.source == 'agent': 75 | env = gym.make(task.environment) 76 | item_agent = ItemAgent(task) 77 | item_agent.train(env, task) 78 | 79 | 80 | def run_pipeline(pipeline: Pipeline): 81 | for task in pipeline.pipeline: 82 | run_task(task) 83 | 84 | 85 | def main(): 86 | parser = argparse.ArgumentParser() 87 | parser.add_argument('--config', type=str, action="store", help='yaml file with settings', required=False, default='configs/eval-diamond.yaml') 88 | params = parser.parse_args() 89 | 90 | with open(params.config, "r") as f: 91 | config = yaml.safe_load(f) 92 | with tf.device('/gpu'): 93 | # noinspection Pydantic 94 | run_pipeline(Pipeline(**config)) 95 | 96 | 97 | if __name__ == '__main__': 98 | main() 99 | -------------------------------------------------------------------------------- /policy/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cog-model/forger/4e4258f358094b358db15b20f04c8197eb3bf63e/policy/__init__.py -------------------------------------------------------------------------------- /policy/agent.py: -------------------------------------------------------------------------------- 1 | import random 2 | import timeit 3 | from collections import deque 4 | import pathlib 5 | import numpy as np 6 | import tensorflow as tf 7 | 8 | from policy.models import get_network_builder 9 | from policy.replay_buffer import AggregatedBuff 10 | from utils.config_validation import AgentCfg, Task 11 | from utils.discretization import 
get_dtype_dict 12 | from utils.tf_util import huber_loss, take_vector_elements 13 | 14 | 15 | def create_flat_agent(task: Task, env): 16 | make_model = get_network_builder(task.model_name) 17 | env_dict, dtype_dict = get_dtype_dict(env) 18 | replay_buffer = AggregatedBuff(env_dict, task.cfg.buffer) 19 | agent = Agent(task.cfg.agent, replay_buffer, make_model, env.observation_space, env.action_space, dtype_dict) 20 | if not task.from_scratch: 21 | agent.load(task.cfg.agent.save_dir) 22 | return agent 23 | 24 | 25 | class Agent: 26 | def __init__(self, cfg: AgentCfg, replay_buffer, build_model, obs_space, act_space, 27 | dtype_dict=None, log_freq=100): 28 | 29 | self.cfg = cfg 30 | self.n_deque = deque([], maxlen=cfg.n_step) 31 | 32 | self.replay_buff = replay_buffer 33 | self.priorities_store = list() 34 | if dtype_dict is not None: 35 | ds = tf.data.Dataset.from_generator(self.sample_generator, output_types=dtype_dict) 36 | ds = ds.prefetch(tf.data.experimental.AUTOTUNE) 37 | self.sampler = ds.take 38 | else: 39 | self.sampler = self.sample_generator 40 | self.online_model = build_model('Online_Model', obs_space, act_space, self.cfg.l2) 41 | self.target_model = build_model('Target_Model', obs_space, act_space, self.cfg.l2) 42 | self.optimizer = tf.keras.optimizers.Adam(self.cfg.learning_rate) 43 | self._run_time_deque = deque(maxlen=log_freq) 44 | self._schedule_dict = dict() 45 | self._schedule_dict[self.target_update] = self.cfg.update_target_net_mod 46 | self._schedule_dict[self.update_log] = log_freq 47 | self.avg_metrics = dict() 48 | self.action_dim = act_space.n 49 | self.global_step = 0 50 | 51 | def train(self, env, task: Task): 52 | print('starting from step:', self.global_step) 53 | scores = [] 54 | 55 | epsilon = self.cfg.initial_epsilon 56 | current_episode = 0 57 | while self.global_step < task.max_train_steps and current_episode < task.max_train_episodes: 58 | score = self.train_episode(env, task, epsilon) 59 | print(f'Steps: {self.global_step}, Episode: {current_episode}, Reward: {score}, Eps Greedy: {round(epsilon, 3)}') 60 | current_episode += 1 61 | if self.global_step >= self.cfg.epsilon_time_steps: 62 | epsilon = self.cfg.final_epsilon 63 | else: 64 | epsilon = (self.cfg.initial_epsilon - self.cfg.final_epsilon) * \ 65 | (self.cfg.epsilon_time_steps - self.global_step) / self.cfg.epsilon_time_steps 66 | 67 | scores.append(score) 68 | tf.summary.scalar("reward", score, step=self.global_step) 69 | tf.summary.flush() 70 | 71 | return scores 72 | 73 | def train_episode(self, env, task: Task, epsilon=0.0): 74 | if self.global_step == 0: 75 | self.target_update() 76 | done, score, state = False, 0, env.reset() 77 | while not done: 78 | action = self.choose_act(state, epsilon) 79 | next_state, reward, done, _ = env.step(action) 80 | if task.cfg.wrappers.render: 81 | env.render() 82 | score += reward 83 | 84 | self.global_step += 1 85 | if not task.evaluation: 86 | # print(f'saving to {task.cfg.agent.save_dir}') 87 | self.perceive(to_demo=0, state=state, action=action, reward=reward, next_state=next_state, 88 | done=done, demo=False) 89 | if self.replay_buff.get_stored_size() > self.cfg.replay_start_size: 90 | if self.global_step % self.cfg.frames_to_update == 0: 91 | self.update(task.cfg.agent.update_quantity) 92 | self.save(task.cfg.agent.save_dir) 93 | print(f'saving to {task.cfg.agent.save_dir}') 94 | 95 | state = next_state 96 | return score 97 | 98 | def pre_train(self, task): 99 | """ 100 | pre_train phase in policy alg. 
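        Runs `task.pretrain_num_updates` gradient updates on the transitions currently
        stored in the replay buffer (in the intended usage, demonstrations added
        beforehand via `add_demo`), after first syncing the target network with the
        online network.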
101 | :return: 102 | """ 103 | print('Pre-training ...') 104 | self.target_update() 105 | self.update(task.pretrain_num_updates) 106 | # self.save(os.path.join(self.cfg.save_dir, "pre_trained_model.ckpt")) 107 | print('All pre-train finish.') 108 | 109 | def update(self, num_updates): 110 | start_time = timeit.default_timer() 111 | for batch in self.sampler(num_updates): 112 | indexes = batch.pop('indexes') 113 | priorities = self.q_network_update(gamma=self.cfg.gamma, **batch) 114 | self.schedule() 115 | self.priorities_store.append({'indexes': indexes.numpy(), 'priorities': priorities.numpy()}) 116 | stop_time = timeit.default_timer() 117 | self._run_time_deque.append(stop_time - start_time) 118 | start_time = timeit.default_timer() 119 | while len(self.priorities_store) > 0: 120 | priorities = self.priorities_store.pop(0) 121 | self.replay_buff.update_priorities(**priorities) 122 | 123 | def sample_generator(self, steps=None): 124 | steps_done = 0 125 | finite_loop = bool(steps) 126 | steps = steps if finite_loop else 1 127 | while steps_done < steps: 128 | yield self.replay_buff.sample(self.cfg.batch_size) 129 | if len(self.priorities_store) > 0: 130 | priorities = self.priorities_store.pop(0) 131 | self.replay_buff.update_priorities(**priorities) 132 | steps += int(finite_loop) 133 | 134 | @tf.function 135 | def q_network_update(self, state, action, next_state, done, reward, demo, 136 | n_state, n_done, n_reward, actual_n, weights, 137 | gamma): 138 | print("Q-nn_update tracing") 139 | online_variables = self.online_model.trainable_variables 140 | with tf.GradientTape() as tape: 141 | tape.watch(online_variables) 142 | q_value = self.online_model(state, training=True) 143 | margin = self.margin_loss(q_value, action, demo, weights) 144 | self.update_metrics('margin', margin) 145 | 146 | q_value = take_vector_elements(q_value, action) 147 | 148 | td_loss = self.td_loss(q_value, next_state, done, reward, 1, gamma) 149 | huber_td = huber_loss(td_loss, delta=0.4) 150 | mean_td = tf.reduce_mean(huber_td * weights) 151 | self.update_metrics('TD', mean_td) 152 | 153 | ntd_loss = self.td_loss(q_value, n_state, n_done, n_reward, actual_n, gamma) 154 | huber_ntd = huber_loss(ntd_loss, delta=0.4) 155 | mean_ntd = tf.reduce_mean(huber_ntd * weights) 156 | self.update_metrics('nTD', mean_ntd) 157 | 158 | l2 = tf.add_n(self.online_model.losses) 159 | self.update_metrics('l2', l2) 160 | 161 | all_losses = mean_td + mean_ntd + l2 + margin 162 | self.update_metrics('all_losses', all_losses) 163 | 164 | gradients = tape.gradient(all_losses, online_variables) 165 | self.optimizer.apply_gradients(zip(gradients, online_variables)) 166 | priorities = tf.abs(td_loss) 167 | return priorities 168 | 169 | def td_loss(self, q_value, n_state, n_done, n_reward, actual_n, gamma): 170 | n_target = self.compute_target(n_state, n_done, n_reward, actual_n, gamma) 171 | n_target = tf.stop_gradient(n_target) 172 | ntd_loss = q_value - n_target 173 | return ntd_loss 174 | 175 | def compute_target(self, next_state, done, reward, actual_n, gamma): 176 | print("Compute_target tracing") 177 | q_network = self.online_model(next_state, training=True) 178 | argmax_actions = tf.argmax(q_network, axis=1, output_type='int32') 179 | q_target = self.target_model(next_state, training=True) 180 | target = take_vector_elements(q_target, argmax_actions) 181 | target = tf.where(done, tf.zeros_like(target), target) 182 | target = target * gamma ** actual_n 183 | target = target + reward 184 | return target 185 | 186 | def margin_loss(self, 
q_value, action, demo, weights): 187 | ae = tf.one_hot(action, self.action_dim, on_value=0.0, 188 | off_value=self.cfg.margin) 189 | ae = tf.cast(ae, 'float32') 190 | max_value = tf.reduce_max(q_value + ae, axis=1) 191 | ae = tf.one_hot(action, self.action_dim) 192 | j_e = tf.abs(tf.reduce_sum(q_value * ae, axis=1) - max_value) 193 | j_e = tf.reduce_mean(j_e * weights * demo) 194 | return j_e 195 | 196 | def add_demo(self, expert_env, expert_data=1): 197 | while not expert_env.are_all_frames_used(): 198 | done = False 199 | obs = expert_env.reset() 200 | 201 | while not done: 202 | next_obs, reward, done, info = expert_env.step(0) 203 | action = info['expert_action'] 204 | self.perceive(to_demo=1, state=obs, action=action, reward=reward, next_state=next_obs, done=done, 205 | demo=expert_data) 206 | obs = next_obs 207 | 208 | def perceive(self, **kwargs): 209 | self.n_deque.append(kwargs) 210 | 211 | if len(self.n_deque) == self.n_deque.maxlen or kwargs['done']: 212 | while len(self.n_deque) != 0: 213 | n_state = self.n_deque[-1]['next_state'] 214 | n_done = self.n_deque[-1]['done'] 215 | n_reward = sum([t['reward'] * self.cfg.gamma ** i for i, t in enumerate(self.n_deque)]) 216 | self.n_deque[0]['n_state'] = n_state 217 | self.n_deque[0]['n_reward'] = n_reward 218 | self.n_deque[0]['n_done'] = n_done 219 | self.n_deque[0]['actual_n'] = len(self.n_deque) 220 | self.replay_buff.add(**self.n_deque.popleft()) 221 | if not n_done: 222 | break 223 | 224 | def choose_act(self, state, epsilon=0.01): 225 | nn_input = np.array(state)[None] 226 | q_value = self.online_model(nn_input, training=False) 227 | if random.random() <= epsilon: 228 | return random.randint(0, self.action_dim - 1) 229 | return np.argmax(q_value) 230 | 231 | def schedule(self): 232 | for key, value in self._schedule_dict.items(): 233 | if tf.equal(self.optimizer.iterations % value, 0): 234 | key() 235 | 236 | def target_update(self): 237 | self.target_model.set_weights(self.online_model.get_weights()) 238 | 239 | def save(self, out_dir=None): 240 | self.online_model.save_weights(pathlib.Path(out_dir) / 'model.ckpt') 241 | 242 | def load(self, out_dir=None): 243 | 244 | if pathlib.Path(out_dir).exists(): 245 | self.online_model.load_weights(pathlib.Path(out_dir) / 'model.ckpt') 246 | else: 247 | raise KeyError(f"Can not import weights from {pathlib.Path(out_dir)}") 248 | 249 | def update_log(self): 250 | update_frequency = len(self._run_time_deque) / sum(self._run_time_deque) 251 | print("LearnerEpoch({:.2f}it/sec): ".format(update_frequency), self.optimizer.iterations.numpy()) 252 | for key, metric in self.avg_metrics.items(): 253 | tf.summary.scalar(key, metric.result(), step=self.optimizer.iterations) 254 | print(' {}: {:.5f}'.format(key, metric.result())) 255 | metric.reset_states() 256 | tf.summary.flush() 257 | 258 | def update_metrics(self, key, value): 259 | if key not in self.avg_metrics: 260 | self.avg_metrics[key] = tf.keras.metrics.Mean(name=key, dtype=tf.float32) 261 | self.avg_metrics[key].update_state(value) 262 | -------------------------------------------------------------------------------- /policy/models.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | from tensorflow.keras import Sequential 3 | from tensorflow.keras.regularizers import l2 4 | from tensorflow.keras.layers import Dense, Conv2D, Flatten 5 | import tensorflow.keras.backend as K 6 | from tensorflow.keras.layers import InputSpec 7 | 8 | from policy.tf1_models import OldDuelingModel, 
OldClassicCnn 9 | 10 | #from tensorflow.python.keras.engine.base_layer import InputSpec 11 | 12 | 13 | mapping = dict() 14 | 15 | 16 | def register(name): 17 | def _thunk(func): 18 | mapping[name] = func 19 | return func 20 | return _thunk 21 | 22 | 23 | def get_network_builder(name): 24 | """ 25 | If you want to register your own network outside models.py, you just need: 26 | 27 | Usage Example: 28 | ------------- 29 | from policy.model import register 30 | @register("your_network_name") 31 | def your_network_define(**net_kwargs): 32 | ... 33 | return network_fn 34 | 35 | """ 36 | if callable(name): 37 | return name 38 | elif name in mapping: 39 | return mapping[name] 40 | else: 41 | raise ValueError('Registered networks:', ', '.join(mapping.keys())) 42 | 43 | 44 | class DuelingModel(tf.keras.Model): 45 | def __init__(self, units, action_dim, reg=1e-6, noisy=True): 46 | super(DuelingModel, self).__init__() 47 | reg = {'kernel_regularizer': l2(reg), 'bias_regularizer': l2(reg)} 48 | if noisy: 49 | layer = NoisyDense 50 | else: 51 | layer = Dense 52 | kernel_init = tf.keras.initializers.VarianceScaling(scale=2.) 53 | self.h_layers = Sequential([layer(num, 'relu', use_bias=True, kernel_initializer=kernel_init, 54 | **reg) for num in units[:-1]]) 55 | self.a_head = layer(units[-1]/2, 'relu', use_bias=True, kernel_initializer=kernel_init, **reg) 56 | self.v_head = layer(units[-1]/2, 'relu', use_bias=True, kernel_initializer=kernel_init, **reg) 57 | self.a_head1 = layer(action_dim, use_bias=True, kernel_initializer=kernel_init, **reg) 58 | self.v_head1 = layer(1, use_bias=True, kernel_initializer=kernel_init, **reg) 59 | 60 | @tf.function 61 | def call(self, inputs): 62 | print('Building model') 63 | features = self.h_layers(inputs) 64 | advantage, value = self.a_head(features), self.v_head(features) 65 | advantage, value = self.a_head1(advantage), self.v_head1(value) 66 | advantage = advantage - tf.reduce_mean(advantage, axis=-1, keepdims=True) 67 | out = value + advantage 68 | return out 69 | 70 | 71 | class ClassicCnn(tf.keras.Model): 72 | def __init__(self, filters, kernels, strides, activation='relu', reg=1e-6): 73 | super(ClassicCnn, self).__init__() 74 | reg = l2(reg) 75 | kernel_init = tf.keras.initializers.VarianceScaling(scale=2.) 
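        # Convolutional feature extractor: the first Conv2D is created here, the remaining
        # (filter, kernel, stride) triples are appended in the loop below, and a final
        # Flatten layer turns the feature maps into a vector for the dueling head.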
76 | self.cnn = Sequential(Conv2D(filters[0], kernels[0], strides[0], activation=activation, 77 | kernel_regularizer=reg, kernel_initializer=kernel_init), name='CNN') 78 | for f, k, s in zip(filters[1:], kernels[1:], strides[1:]): 79 | self.cnn.add(Conv2D(f, k, s, activation=activation, kernel_regularizer=reg, 80 | kernel_initializer=kernel_init)) 81 | self.cnn.add(Flatten()) 82 | 83 | @tf.function 84 | def call(self, inputs): 85 | return self.cnn(inputs) 86 | 87 | 88 | class MLP(tf.keras.Model): 89 | def __init__(self, units, activation='relu', reg=1e-6): 90 | super(MLP, self).__init__() 91 | reg = l2(reg) 92 | self.model = Sequential([Dense(num, activation, kernel_regularizer=reg, bias_regularizer=reg) 93 | for num in units]) 94 | 95 | @tf.function 96 | def call(self, inputs): 97 | return self.model(inputs) 98 | 99 | 100 | class NoisyDense(Dense): 101 | 102 | # factorized noise 103 | def __init__(self, units, *args, **kwargs): 104 | self.output_dim = units 105 | self.f = lambda x: tf.multiply(tf.sign(x), tf.pow(tf.abs(x), 0.5)) 106 | super(NoisyDense, self).__init__(units, *args, **kwargs) 107 | 108 | def build(self, input_shape): 109 | assert len(input_shape) >= 2 110 | self.input_dim = input_shape[-1] 111 | 112 | self.kernel = self.add_weight(shape=(self.input_dim, self.units), 113 | initializer=self.kernel_initializer, 114 | name='kernel', 115 | regularizer=self.kernel_regularizer, 116 | constraint=None) 117 | 118 | self.kernel_sigma = self.add_weight(shape=(self.input_dim, self.units), 119 | initializer=self.kernel_initializer, 120 | name='sigma_kernel', 121 | regularizer=self.kernel_regularizer, 122 | constraint=None) 123 | 124 | if self.use_bias: 125 | self.bias = self.add_weight(shape=(1, self.units), 126 | initializer=self.bias_initializer, 127 | name='bias', 128 | regularizer=self.bias_regularizer, 129 | constraint=None) 130 | 131 | self.bias_sigma = self.add_weight(shape=(1, self.units,), 132 | initializer=self.bias_initializer, 133 | name='bias_sigma', 134 | regularizer=self.bias_regularizer, 135 | constraint=None) 136 | else: 137 | self.bias = None 138 | 139 | self.input_spec = InputSpec(min_ndim=2, axes={-1: self.input_dim}) 140 | self.built = True 141 | 142 | def call(self, inputs): 143 | if inputs.shape[0]: 144 | kernel_input = self.f(tf.random.normal(shape=(inputs.shape[0], self.input_dim, 1))) 145 | kernel_output = self.f(tf.random.normal(shape=(inputs.shape[0], 1, self.units))) 146 | else: 147 | kernel_input = self.f(tf.random.normal(shape=(self.input_dim, 1))) 148 | kernel_output = self.f(tf.random.normal(shape=(1, self.units))) 149 | kernel_epsilon = tf.matmul(kernel_input, kernel_output) 150 | 151 | w = self.kernel + self.kernel_sigma * kernel_epsilon 152 | 153 | output = tf.matmul(tf.expand_dims(inputs, axis=1), w) 154 | 155 | if self.use_bias: 156 | b = self.bias + self.bias_sigma * kernel_output 157 | output = output + b 158 | if self.activation is not None: 159 | output = self.activation(output) 160 | output = tf.squeeze(output, axis=1) 161 | return output 162 | 163 | 164 | @register("minerl_dqfd") 165 | def make_model(name, obs_space, action_space, reg=1e-5): 166 | pov = tf.keras.Input(shape=obs_space.shape) 167 | normalized_pov = pov / 255 168 | pov_base = ClassicCnn([32, 64, 64], [8, 4, 3], [4, 2, 1], reg=reg)(normalized_pov) 169 | head = DuelingModel([1024], action_space.n, reg=reg)(pov_base) 170 | model = tf.keras.Model(inputs={'pov': pov}, outputs=head, name=name) 171 | return model 172 | 173 | 174 | @register("flat_dqfd") 175 | def make_model(name, 
obs_space, action_space, reg=1e-5): 176 | features = tf.keras.Input(shape=obs_space.shape) 177 | feat_base = MLP([64,64], activation='tanh', reg=reg)(features) 178 | head = DuelingModel([512], action_space.n, reg=reg, noisy=False)(feat_base) 179 | model = tf.keras.Model(inputs={'features': features}, outputs=head, name=name) 180 | return model 181 | 182 | 183 | @register("tf1_minerl_dqfd") 184 | def make_model(name, obs_space, action_space, reg=1e-5): 185 | pov = tf.keras.Input(shape=obs_space.shape) 186 | normalized_pov = pov / 255 187 | pov_base = OldClassicCnn([32, 64, 64], [8, 4, 3], [4, 2, 1], reg=reg)(normalized_pov) 188 | head = OldDuelingModel(action_space.n, reg=reg)(pov_base) 189 | model = tf.keras.Model(inputs={'pov': pov}, outputs=head, name=name) 190 | return model 191 | 192 | -------------------------------------------------------------------------------- /policy/replay_buffer.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import random 3 | from policy.sum_tree import SumSegmentTree, MinSegmentTree 4 | 5 | from utils.config_validation import BufferCfg 6 | 7 | 8 | class ReplayBuffer(object): 9 | def __init__(self, size, env_dict): 10 | """Create Replay buffer. 11 | Parameters 12 | ---------- 13 | size: int 14 | Max number of transitions to store in the buffer. When the buffer 15 | overflows the old memories are dropped. 16 | """ 17 | self._storage = [] 18 | self._maxsize = size 19 | self._next_idx = 0 20 | self.env_dict = env_dict 21 | 22 | def get_stored_size(self): 23 | return len(self._storage) 24 | 25 | def get_buffer_size(self): 26 | return self._maxsize 27 | 28 | def add(self, **kwargs): 29 | data = kwargs 30 | if self._next_idx >= len(self._storage): 31 | self._storage.append(data) 32 | else: 33 | self._storage[self._next_idx] = data 34 | self._next_idx = (self._next_idx + 1) % self._maxsize 35 | 36 | @property 37 | def first_transition(self): 38 | return self._storage[0] 39 | 40 | def _encode_sample(self, idxes): 41 | batch = {key: list() for key in self.first_transition.keys()} 42 | for i in idxes: 43 | data = self._storage[i] 44 | for key, value in data.items(): 45 | batch[key].append(np.array(value)) 46 | for key, value in batch.items(): 47 | batch[key] = np.array(value, dtype=self.env_dict[key]['dtype']) 48 | return batch 49 | 50 | def sample(self, batch_size): 51 | """Sample a batch of experiences. 52 | Parameters 53 | ---------- 54 | batch_size: int 55 | How many transitions to sample. 56 | Returns 57 | ------- 58 | obs_batch: np.array 59 | batch of observations 60 | act_batch: np.array 61 | batch of actions executed given obs_batch 62 | rew_batch: np.array 63 | rewards received as results of executing act_batch 64 | next_obs_batch: np.array 65 | next set of observations seen after executing act_batch 66 | done_mask: np.array 67 | done_mask[i] = 1 if executing act_batch[i] resulted in 68 | the end of an episode and 0 otherwise. 69 | """ 70 | idxes = [random.randint(0, len(self._storage) - 1) for _ in range(batch_size)] 71 | return self._encode_sample(idxes) 72 | 73 | def update_priorities(self, *args, **kwargs): 74 | pass 75 | 76 | 77 | class PrioritizedReplayBuffer(ReplayBuffer): 78 | def __init__(self, size, env_dict, alpha=0.6, eps=1e-6): 79 | """Create Prioritized Replay buffer. 80 | Parameters 81 | ---------- 82 | size: int 83 | Max number of transitions to store in the buffer. When the buffer 84 | overflows the old memories are dropped. 
85 | alpha: float 86 | how much prioritization is used 87 | (0 - no prioritization, 1 - full prioritization) 88 | See Also 89 | -------- 90 | ReplayBuffer.__init__ 91 | """ 92 | super(PrioritizedReplayBuffer, self).__init__(size, env_dict) 93 | assert alpha >= 0 94 | self._alpha = alpha 95 | self._eps = eps 96 | 97 | it_capacity = 1 98 | while it_capacity < size: 99 | it_capacity *= 2 100 | 101 | self._it_sum = SumSegmentTree(it_capacity) 102 | self._it_min = MinSegmentTree(it_capacity) 103 | self._max_priority = 1.0 104 | 105 | def add(self, **kwargs): 106 | """See ReplayBuffer.store_effect""" 107 | idx = self._next_idx 108 | super().add(**kwargs) 109 | self._it_sum[idx] = self._max_priority ** self._alpha 110 | self._it_min[idx] = self._max_priority ** self._alpha 111 | 112 | def _sample_proportional(self, batch_size): 113 | res = [] 114 | p_total = self._it_sum.sum(0, len(self._storage) - 1) 115 | every_range_len = p_total / batch_size 116 | for i in range(batch_size): 117 | mass = random.random() * every_range_len + i * every_range_len 118 | idx = self._it_sum.find_prefixsum_idx(mass) 119 | res.append(idx) 120 | return res 121 | 122 | def sample(self, batch_size, beta=0.4): 123 | """Sample a batch of experiences. 124 | compared to ReplayBuffer.sample 125 | it also returns importance weights and idxes 126 | of sampled experiences. 127 | Parameters 128 | ---------- 129 | batch_size: int 130 | How many transitions to sample. 131 | beta: float 132 | Priority level 133 | Returns 134 | ------- 135 | encoded_sample: dict of np.array 136 | Array of shape(batch_size, ...) and dtype np.*32 137 | weights: np.array 138 | Array of shape (batch_size,) and dtype np.float32 139 | denoting importance weight of each sampled transition 140 | idxes: np.array 141 | Array of shape (batch_size,) and dtype np.int32 142 | idexes in buffer of sampled experiences 143 | """ 144 | idxes = self._sample_proportional(batch_size) 145 | it_sum = self._it_sum.sum() 146 | it_min = self._it_min.min() 147 | p_min = it_min / it_sum 148 | max_weight = (p_min * len(self._storage)) ** (- beta) 149 | p_sample = np.array([self._it_sum[idx] / it_sum for idx in idxes]) 150 | weights = (p_sample * len(self._storage)) ** (- beta) 151 | weights = weights / max_weight 152 | encoded_sample = self._encode_sample(idxes) 153 | encoded_sample['weights'] = weights 154 | encoded_sample['indexes'] = idxes 155 | return encoded_sample 156 | 157 | def update_priorities(self, idxes, priorities): 158 | """Update priorities of sampled transitions. 159 | sets priority of transition at index idxes[i] in buffer 160 | to priorities[i]. 161 | Parameters 162 | ---------- 163 | idxes: [int] 164 | List of idxes of sampled transitions 165 | priorities: [float] 166 | List of updated priorities corresponding to 167 | transitions at the sampled idxes denoted by 168 | variable `idxes`. 
169 | """ 170 | assert len(idxes) == len(priorities) 171 | for idx, priority in zip(idxes, priorities): 172 | assert priority >= 0 173 | assert 0 <= idx < len(self._storage) 174 | self._it_sum[idx] = (priority + self._eps) ** self._alpha 175 | self._it_min[idx] = (priority + self._eps) ** self._alpha 176 | 177 | self._max_priority = max(self._max_priority, priority) 178 | 179 | 180 | class AggregatedBuff: 181 | def __init__(self, env_dict, cfg: BufferCfg): 182 | buffer_base = PrioritizedReplayBuffer 183 | self._maxsize = cfg.size 184 | self.demo_kwargs = {'size': cfg.size, 'env_dict': env_dict, 'eps': 1.0, } 185 | self.demo_buff = buffer_base(size=cfg.size, env_dict=env_dict, eps=1.0, ) 186 | self.replay_buff = buffer_base(size=cfg.size, env_dict=env_dict, eps=1e-3, ) 187 | self.episodes_to_decay = cfg.episodes_to_decay 188 | self.episodes_done = 0 189 | self.min_demo_proportion = cfg.min_demo_proportion 190 | 191 | def add(self, to_demo=0, **kwargs): 192 | if to_demo: 193 | self.demo_buff.add(**kwargs) 194 | else: 195 | self.replay_buff.add(**kwargs) 196 | if kwargs['done']: 197 | self.episodes_done += 1 198 | if self.demo_buff.get_stored_size() > 0 \ 199 | and self.episodes_done > self.episodes_to_decay \ 200 | and self.min_demo_proportion == 0.: 201 | self.demo_buff = PrioritizedReplayBuffer(**self.demo_kwargs) 202 | 203 | def free_demo(self): 204 | self.demo_buff = PrioritizedReplayBuffer(**self.demo_kwargs) 205 | 206 | @property 207 | def proportion(self): 208 | if self.episodes_to_decay == 0: 209 | proportion = 1. - self.min_demo_proportion 210 | else: 211 | proportion = min(1. - self.min_demo_proportion, self.episodes_done / self.episodes_to_decay) 212 | proportion = max(proportion, float(self.demo_buff.get_stored_size() == 0)) 213 | return proportion 214 | 215 | def sample(self, n=32, beta=0.4): 216 | agent_n = int(n * self.proportion) 217 | demo_n = n - agent_n 218 | if demo_n > 0 and agent_n > 0: 219 | demo_samples = self.demo_buff.sample(demo_n, beta) 220 | replay_samples = self.replay_buff.sample(agent_n, beta) 221 | samples = {key: np.concatenate((replay_samples[key], demo_samples[key])) 222 | for key in replay_samples.keys()} 223 | elif agent_n == 0: 224 | samples = self.demo_buff.sample(demo_n, beta) 225 | else: 226 | samples = self.replay_buff.sample(agent_n, beta) 227 | samples = {key: np.squeeze(value) for key, value in samples.items()} 228 | return samples 229 | 230 | def update_priorities(self, indexes, priorities): 231 | n = len(indexes) 232 | agent_n = int(n * self.proportion) 233 | demo_n = n - agent_n 234 | if demo_n != 0: 235 | self.demo_buff.update_priorities(indexes[agent_n:], priorities[agent_n:]) 236 | if agent_n != 0: 237 | self.replay_buff.update_priorities(indexes[:agent_n], priorities[:agent_n]) 238 | 239 | def get_stored_size(self): 240 | return self.replay_buff.get_stored_size() 241 | 242 | def get_buffer_size(self): 243 | return self._maxsize 244 | -------------------------------------------------------------------------------- /policy/sum_tree.py: -------------------------------------------------------------------------------- 1 | import operator 2 | 3 | 4 | class SegmentTree(object): 5 | def __init__(self, capacity, operation, neutral_element): 6 | """Build a Segment Tree data structure. 7 | https://en.wikipedia.org/wiki/Segment_tree 8 | Can be used as regular array, but with two 9 | important differences: 10 | a) setting item's value is slightly slower. 11 | It is O(lg capacity) instead of O(1). 
12 | b) user has access to an efficient ( O(log segment size) ) 13 | `reduce` operation which reduces `operation` over 14 | a contiguous subsequence of items in the array. 15 | Paramters 16 | --------- 17 | capacity: int 18 | Total size of the array - must be a power of two. 19 | operation: lambda obj, obj -> obj 20 | and operation for combining elements (eg. sum, max) 21 | must form a mathematical group together with the set of 22 | possible values for array elements (i.e. be associative) 23 | neutral_element: obj 24 | neutral element for the operation above. eg. float('-inf') 25 | for max and 0 for sum. 26 | """ 27 | assert capacity > 0 and capacity & (capacity - 1) == 0, "capacity must be positive and a power of 2." 28 | self._capacity = capacity 29 | self._value = [neutral_element for _ in range(2 * capacity)] 30 | self._operation = operation 31 | 32 | def _reduce_helper(self, start, end, node, node_start, node_end): 33 | if start == node_start and end == node_end: 34 | return self._value[node] 35 | mid = (node_start + node_end) // 2 36 | if end <= mid: 37 | return self._reduce_helper(start, end, 2 * node, node_start, mid) 38 | else: 39 | if mid + 1 <= start: 40 | return self._reduce_helper(start, end, 2 * node + 1, mid + 1, node_end) 41 | else: 42 | return self._operation( 43 | self._reduce_helper(start, mid, 2 * node, node_start, mid), 44 | self._reduce_helper(mid + 1, end, 2 * node + 1, mid + 1, node_end) 45 | ) 46 | 47 | def reduce(self, start=0, end=None): 48 | """Returns result of applying `self.operation` 49 | to a contiguous subsequence of the array. 50 | self.operation(arr[start], operation(arr[start+1], operation(... arr[end]))) 51 | Parameters 52 | ---------- 53 | start: int 54 | beginning of the subsequence 55 | end: int 56 | end of the subsequences 57 | Returns 58 | ------- 59 | reduced: obj 60 | result of reducing self.operation over the specified range of array elements. 61 | """ 62 | if end is None: 63 | end = self._capacity 64 | if end < 0: 65 | end += self._capacity 66 | end -= 1 67 | return self._reduce_helper(start, end, 1, 0, self._capacity - 1) 68 | 69 | def __setitem__(self, idx, val): 70 | # index of the leaf 71 | idx += self._capacity 72 | self._value[idx] = val 73 | idx //= 2 74 | while idx >= 1: 75 | self._value[idx] = self._operation( 76 | self._value[2 * idx], 77 | self._value[2 * idx + 1] 78 | ) 79 | idx //= 2 80 | 81 | def __getitem__(self, idx): 82 | assert 0 <= idx < self._capacity 83 | return self._value[self._capacity + idx] 84 | 85 | 86 | class SumSegmentTree(SegmentTree): 87 | def __init__(self, capacity): 88 | super(SumSegmentTree, self).__init__( 89 | capacity=capacity, 90 | operation=operator.add, 91 | neutral_element=0.0 92 | ) 93 | 94 | def sum(self, start=0, end=None): 95 | """Returns arr[start] + ... + arr[end]""" 96 | return super(SumSegmentTree, self).reduce(start, end) 97 | 98 | def find_prefixsum_idx(self, prefixsum): 99 | """Find the highest index `i` in the array such that 100 | sum(arr[0] + arr[1] + ... + arr[i - i]) <= prefixsum 101 | if array values are probabilities, this function 102 | allows to sample indexes according to the discrete 103 | probability efficiently. 
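For example, with leaf values 1.0, 2.0 and 3.0 (chosen only for illustration):

    tree = SumSegmentTree(4)
    tree[0], tree[1], tree[2] = 1.0, 2.0, 3.0
    tree.sum()                    # 6.0
    tree.find_prefixsum_idx(3.5)  # 2, because 1.0 + 2.0 <= 3.5 < 1.0 + 2.0 + 3.0

Drawing `prefixsum` uniformly from [0, tree.sum()) selects index k with
probability proportional to tree[k]; `_sample_proportional` in
policy/replay_buffer.py relies on this, drawing one such value per
stratified slice of the total mass.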
104 | Parameters 105 | ---------- 106 | prefixsum: float 107 | upperbound on the sum of array prefix 108 | Returns 109 | ------- 110 | idx: int 111 | highest index satisfying the prefixsum constraint 112 | """ 113 | assert 0 <= prefixsum <= self.sum() + 1e-5 114 | idx = 1 115 | while idx < self._capacity: # while non-leaf 116 | if self._value[2 * idx] > prefixsum: 117 | idx = 2 * idx 118 | else: 119 | prefixsum -= self._value[2 * idx] 120 | idx = 2 * idx + 1 121 | return idx - self._capacity 122 | 123 | 124 | class MinSegmentTree(SegmentTree): 125 | def __init__(self, capacity): 126 | super(MinSegmentTree, self).__init__( 127 | capacity=capacity, 128 | operation=min, 129 | neutral_element=float('inf') 130 | ) 131 | 132 | def min(self, start=0, end=None): 133 | """Returns min(arr[start], ..., arr[end])""" 134 | 135 | return super(MinSegmentTree, self).reduce(start, end) 136 | -------------------------------------------------------------------------------- /policy/tf1_models.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | from tensorflow.keras import Sequential 3 | from tensorflow.keras.regularizers import l2 4 | from tensorflow.keras.layers import Dense, Conv2D, Flatten 5 | import tensorflow.keras.backend as K 6 | from tensorflow.keras.layers import InputSpec 7 | 8 | 9 | class OldDuelingModel(tf.keras.Model): 10 | def __init__(self, action_dim, reg=1e-6): 11 | super(OldDuelingModel, self).__init__() 12 | reg = {'kernel_regularizer': l2(reg), 'bias_regularizer': l2(reg)} 13 | 14 | kernel_init = tf.keras.initializers.VarianceScaling(scale=2.) 15 | self.h_layer = OldNoisyDense( 16 | 1024, 'relu', 17 | name='Q_network/dense_1', use_bias=True, 18 | kernel_initializer=kernel_init, 19 | **reg 20 | ) 21 | 22 | self.a_head1 = OldNoisyDense(action_dim, name='Q_network/A', use_bias=True, kernel_initializer=kernel_init, **reg) 23 | self.v_head1 = OldNoisyDense(1, name='Q_network/V', use_bias=True, kernel_initializer=kernel_init, **reg) 24 | 25 | @tf.function 26 | def call(self, inputs): 27 | print('Building model') 28 | features = self.h_layer(inputs) 29 | advantage, value = tf.split(features, num_or_size_splits=2, axis=-1) 30 | advantage, value = self.a_head1(advantage), self.v_head1(value) 31 | advantage = advantage - tf.reduce_mean(advantage, axis=-1, keepdims=True) 32 | out = value + advantage 33 | return out 34 | 35 | 36 | class OldConv2D(Conv2D): 37 | def add_weight(self, name, *args, **kwargs): 38 | if name == 'kernel': 39 | name = 'k' 40 | return super().add_weight(name, *args, **kwargs) 41 | 42 | 43 | class OldClassicCnn(tf.keras.Model): 44 | def __init__(self, filters, kernels, strides, activation='relu', reg=1e-6): 45 | super(OldClassicCnn, self).__init__() 46 | reg = l2(reg) 47 | kernel_init = tf.keras.initializers.VarianceScaling(scale=2.) 
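# The torso below is assembled incrementally: a Sequential model starts from the
# first (filters, kernels, strides) triple, each remaining triple is appended as
# another OldConv2D layer in order, and a closing Flatten() turns the spatial
# feature maps into a single feature vector.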
48 | self.cnn = Sequential(OldConv2D( 49 | filters[0], 50 | kernels[0], 51 | strides[0], 52 | activation=activation, 53 | kernel_regularizer=reg, 54 | kernel_initializer=kernel_init, 55 | use_bias=False, 56 | name='Q_network/conv0/conv/', 57 | padding='same' 58 | )) 59 | 60 | for i, (f, k, s) in enumerate(zip(filters[1:], kernels[1:], strides[1:])): 61 | name = f'Q_network/conv{i + 1}_0/conv/' 62 | self.cnn.add(OldConv2D( 63 | f, k, s, 64 | activation=activation, 65 | kernel_regularizer=reg, 66 | kernel_initializer=kernel_init, 67 | use_bias=False, 68 | name=name, 69 | padding='same' 70 | )) 71 | 72 | 73 | self.cnn.add(Flatten()) 74 | 75 | @tf.function 76 | def call(self, inputs): 77 | return self.cnn(inputs) 78 | 79 | 80 | class OldNoisyDense(Dense): 81 | 82 | # factorized noise 83 | def __init__(self, units, *args, **kwargs): 84 | self.output_dim = units 85 | self.f = lambda x: tf.multiply(tf.sign(x), tf.pow(tf.abs(x), 0.5)) 86 | super(OldNoisyDense, self).__init__(units, *args, **kwargs) 87 | 88 | def build(self, input_shape): 89 | assert len(input_shape) >= 2 90 | self.input_dim = input_shape[-1] 91 | 92 | self.kernel = self.add_weight(shape=(self.input_dim, self.units), 93 | initializer=self.kernel_initializer, 94 | name='w_' + self.name[-1], 95 | regularizer=self.kernel_regularizer, 96 | constraint=None) 97 | 98 | self.kernel_sigma = self.add_weight(shape=(self.input_dim, self.units), 99 | initializer=self.kernel_initializer, 100 | name='w_noise_' + self.name[-1], 101 | regularizer=self.kernel_regularizer, 102 | constraint=None) 103 | 104 | if self.use_bias: 105 | self.bias = self.add_weight(shape=(1, self.units), 106 | initializer=self.bias_initializer, 107 | name='b_' + self.name[-1], 108 | regularizer=self.bias_regularizer, 109 | constraint=None) 110 | 111 | self.bias_sigma = self.add_weight(shape=(1, self.units,), 112 | initializer=self.bias_initializer, 113 | name='b_noise_' + self.name[-1], 114 | regularizer=self.bias_regularizer, 115 | constraint=None) 116 | else: 117 | self.bias = None 118 | 119 | self.input_spec = InputSpec(min_ndim=2, axes={-1: self.input_dim}) 120 | self.built = True 121 | 122 | def call(self, inputs): 123 | if inputs.shape[0]: 124 | kernel_input = self.f(tf.random.normal(shape=(inputs.shape[0], self.input_dim, 1))) 125 | kernel_output = self.f(tf.random.normal(shape=(inputs.shape[0], 1, self.units))) 126 | else: 127 | kernel_input = self.f(tf.random.normal(shape=(self.input_dim, 1))) 128 | kernel_output = self.f(tf.random.normal(shape=(1, self.units))) 129 | kernel_epsilon = tf.matmul(kernel_input, kernel_output) 130 | 131 | w = self.kernel + self.kernel_sigma * kernel_epsilon 132 | 133 | output = tf.matmul(tf.expand_dims(inputs, axis=1), w) 134 | 135 | if self.use_bias: 136 | b = self.bias + self.bias_sigma * kernel_output 137 | output = output + b 138 | if self.activation is not None: 139 | output = self.activation(output) 140 | output = tf.squeeze(output, axis=1) 141 | return output 142 | 143 | -------------------------------------------------------------------------------- /static/forging.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cog-model/forger/4e4258f358094b358db15b20f04c8197eb3bf63e/static/forging.png -------------------------------------------------------------------------------- /train/.gitkeep: -------------------------------------------------------------------------------- 
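The noisy layers in policy/tf1_models.py above use factorized Gaussian noise:
two small noise vectors are transformed by f(x) = sign(x) * sqrt(|x|) and their
outer product perturbs the kernel. A standalone NumPy sketch of that
perturbation, kept separate from the repository code and shown only to make the
shapes explicit:

    import numpy as np

    def factorized_kernel_noise(in_dim, units, rng=None):
        rng = rng or np.random.default_rng(0)
        f = lambda x: np.sign(x) * np.sqrt(np.abs(x))   # same f as in OldNoisyDense
        eps_in = f(rng.normal(size=(in_dim, 1)))        # per-input noise
        eps_out = f(rng.normal(size=(1, units)))        # per-output noise
        return eps_in @ eps_out                         # rank-1 noise matrix, (in_dim, units)

    # noisy weights in the forward pass: kernel + kernel_sigma * noise
    # noisy bias:                        bias   + bias_sigma   * eps_out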
https://raw.githubusercontent.com/cog-model/forger/4e4258f358094b358db15b20f04c8197eb3bf63e/train/.gitkeep -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cog-model/forger/4e4258f358094b358db15b20f04c8197eb3bf63e/utils/__init__.py -------------------------------------------------------------------------------- /utils/config_validation.py: -------------------------------------------------------------------------------- 1 | from pydantic import BaseModel, Extra 2 | from typing import List 3 | 4 | 5 | class BufferCfg(BaseModel, extra=Extra.forbid): 6 | size: int = 450000 7 | episodes_to_decay: int = 50 8 | min_demo_proportion: float = 0.0 9 | cpp: bool = False 10 | 11 | 12 | class WrappersCfg(BaseModel, extra=Extra.forbid): 13 | frame_stack: int = 2 14 | frame_skip: int = 4 15 | render: bool = False 16 | 17 | 18 | class AgentCfg(BaseModel, extra=Extra.forbid): 19 | episodes: int = 250 20 | save_dir: str = None 21 | 22 | frames_to_update: int = 2000 23 | update_quantity: int = 600 24 | update_target_net_mod: int = 3000 25 | replay_start_size: int = 20000 26 | 27 | batch_size: int = 32 28 | gamma: float = 0.99 29 | n_step: int = 10 30 | l2: float = 1e-5 31 | margin: float = 0.4 32 | learning_rate: float = 0.0001 33 | 34 | # eps-greedy 35 | initial_epsilon: float = 0.1 36 | final_epsilon: float = 0.01 37 | epsilon_time_steps: int = 100000 38 | 39 | 40 | class Action(BaseModel): 41 | name: str = None 42 | target: str = None 43 | 44 | 45 | class Subtask(BaseModel): 46 | item_name: str = None 47 | item_count: int = None 48 | start_idx: int = None 49 | end_idx: int = None 50 | actions: List[Action] = [] 51 | 52 | 53 | class GlobalCfg(BaseModel, extra=Extra.forbid): 54 | buffer: BufferCfg = BufferCfg() 55 | wrappers: WrappersCfg = WrappersCfg() 56 | agent: AgentCfg = AgentCfg() 57 | 58 | 59 | class Task(BaseModel, extra=Extra.forbid): 60 | evaluation: bool = False 61 | environment: str = "MineRLTreechop-v0" 62 | max_train_steps: int = 1000000 63 | max_train_episodes: int = 1000000000 64 | pretrain_num_updates: int = 1000000 65 | source: str = None 66 | from_scratch: bool = False 67 | agent_type: str = None 68 | cfg: GlobalCfg = GlobalCfg() 69 | subtasks: List[Subtask] = None 70 | data_dir: str = 'demonstrations' 71 | model_name: str = 'tf1_minerl_dqfd' 72 | 73 | 74 | class Pipeline(BaseModel, extra=Extra.forbid): 75 | pipeline: List[Task] = None 76 | -------------------------------------------------------------------------------- /utils/discretization.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import gym 3 | from scipy import stats 4 | 5 | 6 | class ExpertActionPreprocessing: 7 | def __init__(self, always_attack=1): 8 | self.angle = 5 9 | self.always_attack = always_attack 10 | self.ignore_keys = ["place", "nearbySmelt", "nearbyCraft", 11 | "equip", "craft"] 12 | 13 | attack = self.always_attack 14 | self.all_actions_dict = { 15 | "[('attack', {}), ('back', 0), ('camera', [0, 0]), ('forward', 1), ('jump', 0), ('left', 0), ('right', 0), ('sneak', 0), ('sprint', 0)]".format(attack): 0, 16 | "[('attack', {}), ('back', 0), ('camera', [0, {}]), ('forward', 0), ('jump', 0), ('left', 0), ('right', 0), ('sneak', 0), ('sprint', 0)]".format(attack, self.angle): 1, 17 | "[('attack', 1), ('back', 0), ('camera', [0, 0]), ('forward', 0), ('jump', 0), ('left', 0), ('right', 0), 
('sneak', 0), ('sprint', 0)]": 2, 18 | "[('attack', {}), ('back', 0), ('camera', [{}, 0]), ('forward', 0), ('jump', 0), ('left', 0), ('right', 0), ('sneak', 0), ('sprint', 0)]".format(attack, self.angle): 3, 19 | "[('attack', {}), ('back', 0), ('camera', [-{}, 0]), ('forward', 0), ('jump', 0), ('left', 0), ('right', 0), ('sneak', 0), ('sprint', 0)]".format(attack, 20 | self.angle): 4, 21 | "[('attack', {}), ('back', 0), ('camera', [0, -{}]), ('forward', 0), ('jump', 0), ('left', 0), ('right', 0), ('sneak', 0), ('sprint', 0)]".format(attack, 22 | self.angle): 5, 23 | "[('attack', {}), ('back', 0), ('camera', [0, 0]), ('forward', 1), ('jump', 1), ('left', 0), ('right', 0), ('sneak', 0), ('sprint', 0)]".format(attack): 6, 24 | "[('attack', {}), ('back', 0), ('camera', [0, 0]), ('forward', 0), ('jump', 0), ('left', 1), ('right', 0), ('sneak', 0), ('sprint', 0)]".format(attack): 7, 25 | "[('attack', {}), ('back', 0), ('camera', [0, 0]), ('forward', 0), ('jump', 0), ('left', 0), ('right', 1), ('sneak', 0), ('sprint', 0)]".format(attack): 8, 26 | "[('attack', {}), ('back', 1), ('camera', [0, 0]), ('forward', 0), ('jump', 0), ('left', 0), ('right', 0), ('sneak', 0), ('sprint', 0)]".format(attack): 9} 27 | self.key_to_dict = { 28 | 0: {'attack': attack, 'back': 0, 'camera': [0, 0], 'forward': 1, 'jump': 0, 'left': 0, 'right': 0, 'sneak': 0, 'sprint': 0}, 29 | 1: {'attack': attack, 'back': 0, 'camera': [0, self.angle], 'forward': 0, 'jump': 0, 'left': 0, 'right': 0, 'sneak': 0, 'sprint': 0}, 30 | 2: {'attack': 1, 'back': 0, 'camera': [0, 0], 'forward': 0, 'jump': 0, 'left': 0, 'right': 0, 'sneak': 0, 'sprint': 0}, 31 | 3: {'attack': attack, 'back': 0, 'camera': [self.angle, 0], 'forward': 0, 'jump': 0, 'left': 0, 'right': 0, 'sneak': 0, 'sprint': 0}, 32 | 4: {'attack': attack, 'back': 0, 'camera': [-self.angle, 0], 'forward': 0, 'jump': 0, 'left': 0, 'right': 0, 'sneak': 0, 'sprint': 0}, 33 | 5: {'attack': attack, 'back': 0, 'camera': [0, -self.angle], 'forward': 0, 'jump': 0, 'left': 0, 'right': 0, 'sneak': 0, 'sprint': 0}, 34 | 6: {'attack': attack, 'back': 0, 'camera': [0, 0], 'forward': 1, 'jump': 1, 'left': 0, 'right': 0, 'sneak': 0, 'sprint': 0}, 35 | 7: {'attack': attack, 'back': 0, 'camera': [0, 0], 'forward': 0, 'jump': 0, 'left': 1, 'right': 0, 'sneak': 0, 'sprint': 0}, 36 | 8: {'attack': attack, 'back': 0, 'camera': [0, 0], 'forward': 0, 'jump': 0, 'left': 0, 'right': 1, 'sneak': 0, 'sprint': 0}, 37 | 9: {'attack': attack, 'back': 1, 'camera': [0, 0], 'forward': 0, 'jump': 0, 'left': 0, 'right': 0, 'sneak': 0, 'sprint': 0}} 38 | 39 | def to_joint_action(self, action): 40 | joint_action = {} 41 | action_keys = action[0].keys() 42 | for key in action_keys: 43 | value = np.array([action[idx][key] for idx in range(len(action))]) 44 | if key == 'camera': 45 | # print(value) 46 | mean = np.mean(value, axis=0) 47 | mask = np.abs(mean) > 1.2 48 | sign = np.sign(mean) 49 | argmax = np.argmax(np.abs(mean), ) 50 | one_hot = np.eye(2)[argmax] 51 | joint_action[key] = (one_hot * sign * mask * 5 + 0).astype('int').tolist() 52 | else: 53 | most_freq, _ = stats.mode(value, ) 54 | joint_action[key] = most_freq[0] 55 | return joint_action 56 | 57 | def to_compressed_action(self, action, ): 58 | no_action_part = ["sneak", "sprint"] 59 | action_part = ["attack"] if self.always_attack else [] 60 | moving_actions = ["forward", "back", "right", "left"] 61 | if action["camera"] != [0, 0]: 62 | no_action_part.append("attack") 63 | no_action_part.append("jump") 64 | no_action_part += moving_actions 65 | elif 
action["jump"] == 1: 66 | action["forward"] = 1 67 | no_action_part += filter(lambda x: x != "forward", moving_actions) 68 | else: 69 | for a in moving_actions: 70 | if action[a] == 1: 71 | no_action_part += filter(lambda x: x != a, moving_actions) 72 | no_action_part.append("attack") 73 | no_action_part.append("jump") 74 | break 75 | if "attack" not in no_action_part: 76 | action["attack"] = 1 77 | for a in no_action_part: 78 | action[a] = 0 79 | for a in action_part: 80 | action[a] = 1 81 | 82 | return action 83 | 84 | @staticmethod 85 | def dict_to_sorted_str(dict_): 86 | return str(sorted(dict_.items())) 87 | 88 | def to_discrete_action(self, action): 89 | for ignored_key in self.ignore_keys: 90 | action.pop(ignored_key, None) 91 | str_dict = self.dict_to_sorted_str(action) 92 | return self.all_actions_dict[str_dict] 93 | 94 | 95 | class SmartDiscrete: 96 | def __init__(self, ignore_keys=None, always_attack=0): 97 | if ignore_keys is None: 98 | ignore_keys = [] 99 | self.always_attack = always_attack 100 | self.angle = 5 101 | self.all_actions_dict = { 102 | "[('attack', {}), ('back', 0), ('camera', [0, 0]), ('forward', 1), ('jump', 0), ('left', 0), ('right', 0), ('sneak', 0), ('sprint', 0)]".format(always_attack): 0, 103 | "[('attack', {}), ('back', 0), ('camera', [0, {}]), ('forward', 0), ('jump', 0), ('left', 0), ('right', 0), ('sneak', 0), ('sprint', 0)]".format(always_attack, 104 | self.angle): 1, 105 | "[('attack', 1), ('back', 0), ('camera', [0, 0]), ('forward', 0), ('jump', 0), ('left', 0), ('right', 0), ('sneak', 0), ('sprint', 0)]": 2, 106 | "[('attack', {}), ('back', 0), ('camera', [{}, 0]), ('forward', 0), ('jump', 0), ('left', 0), ('right', 0), ('sneak', 0), ('sprint', 0)]".format(always_attack, 107 | self.angle): 3, 108 | "[('attack', {}), ('back', 0), ('camera', [-{}, 0]), ('forward', 0), ('jump', 0), ('left', 0), ('right', 0), ('sneak', 0), ('sprint', 0)]".format(always_attack, 109 | self.angle): 4, 110 | "[('attack', {}), ('back', 0), ('camera', [0, -{}]), ('forward', 0), ('jump', 0), ('left', 0), ('right', 0), ('sneak', 0), ('sprint', 0)]".format(always_attack, 111 | self.angle): 5, 112 | "[('attack', {}), ('back', 0), ('camera', [0, 0]), ('forward', 1), ('jump', 1), ('left', 0), ('right', 0), ('sneak', 0), ('sprint', 0)]".format(always_attack): 6, 113 | "[('attack', {}), ('back', 0), ('camera', [0, 0]), ('forward', 0), ('jump', 0), ('left', 1), ('right', 0), ('sneak', 0), ('sprint', 0)]".format(always_attack): 7, 114 | "[('attack', {}), ('back', 0), ('camera', [0, 0]), ('forward', 0), ('jump', 0), ('left', 0), ('right', 1), ('sneak', 0), ('sprint', 0)]".format(always_attack): 8, 115 | "[('attack', {}), ('back', 1), ('camera', [0, 0]), ('forward', 0), ('jump', 0), ('left', 0), ('right', 0), ('sneak', 0), ('sprint', 0)]".format(always_attack): 9} 116 | self.ignore_keys = ignore_keys 117 | self.key_to_dict = { 118 | 0: {'attack': always_attack, 'back': 0, 'camera': [0, 0], 'forward': 1, 'jump': 0, 'left': 0, 'right': 0, 'sneak': 0, 'sprint': 0}, 119 | 1: {'attack': always_attack, 'back': 0, 'camera': [0, self.angle], 'forward': 0, 'jump': 0, 'left': 0, 'right': 0, 'sneak': 0, 'sprint': 0}, 120 | 2: {'attack': 1, 'back': 0, 'camera': [0, 0], 'forward': 0, 'jump': 0, 'left': 0, 'right': 0, 'sneak': 0, 'sprint': 0}, 121 | 3: {'attack': always_attack, 'back': 0, 'camera': [self.angle, 0], 'forward': 0, 'jump': 0, 'left': 0, 'right': 0, 'sneak': 0, 'sprint': 0}, 122 | 4: {'attack': always_attack, 'back': 0, 'camera': [-self.angle, 0], 'forward': 0, 'jump': 0, 'left': 0, 
'right': 0, 'sneak': 0, 'sprint': 0}, 123 | 5: {'attack': always_attack, 'back': 0, 'camera': [0, -self.angle], 'forward': 0, 'jump': 0, 'left': 0, 'right': 0, 'sneak': 0, 'sprint': 0}, 124 | 6: {'attack': always_attack, 'back': 0, 'camera': [0, 0], 'forward': 1, 'jump': 1, 'left': 0, 'right': 0, 'sneak': 0, 'sprint': 0}, 125 | 7: {'attack': always_attack, 'back': 0, 'camera': [0, 0], 'forward': 0, 'jump': 0, 'left': 1, 'right': 0, 'sneak': 0, 'sprint': 0}, 126 | 8: {'attack': always_attack, 'back': 0, 'camera': [0, 0], 'forward': 0, 'jump': 0, 'left': 0, 'right': 1, 'sneak': 0, 'sprint': 0}, 127 | 9: {'attack': always_attack, 'back': 1, 'camera': [0, 0], 'forward': 0, 'jump': 0, 'left': 0, 'right': 0, 'sneak': 0, 'sprint': 0}} 128 | 129 | @staticmethod 130 | def discrete_camera(camera): 131 | result = list(camera) 132 | if abs(result[1]) >= abs(result[0]): 133 | result[0] = 0 134 | else: 135 | result[1] = 0 136 | 137 | def cut(value, max_value=1.2): 138 | sign = -1 if value < 0 else 1 139 | if abs(value) >= max_value: 140 | return 5 * sign 141 | else: 142 | return 0 143 | 144 | cutten = list(map(cut, result)) 145 | return cutten 146 | 147 | def preprocess_action_dict(self, action_dict): 148 | no_action_part = ["sneak", "sprint"] 149 | action_part = ["attack"] if self.always_attack else [] 150 | moving_actions = ["forward", "back", "right", "left"] 151 | if action_dict["camera"] != [0, 0]: 152 | no_action_part.append("attack") 153 | no_action_part.append("jump") 154 | no_action_part += moving_actions 155 | elif action_dict["jump"] == 1: 156 | action_dict["forward"] = 1 157 | no_action_part += filter(lambda x: x != "forward", moving_actions) 158 | else: 159 | for a in moving_actions: 160 | if action_dict[a] == 1: 161 | no_action_part += filter(lambda x: x != a, moving_actions) 162 | no_action_part.append("attack") 163 | no_action_part.append("jump") 164 | break 165 | if "attack" not in no_action_part: 166 | action_dict["attack"] = 1 167 | for a in no_action_part: 168 | action_dict[a] = 0 169 | for a in action_part: 170 | action_dict[a] = 1 171 | return action_dict 172 | 173 | @staticmethod 174 | def dict_to_sorted_str(dict_): 175 | return str(sorted(dict_.items())) 176 | 177 | def get_key_by_action_dict(self, action_dict): 178 | for ignored_key in self.ignore_keys: 179 | action_dict.pop(ignored_key, None) 180 | str_dict = self.dict_to_sorted_str(action_dict) 181 | return self.all_actions_dict[str_dict] 182 | 183 | def get_action_dict_by_key(self, key): 184 | return self.key_to_dict[key] 185 | 186 | def get_actions_dim(self): 187 | return len(self.key_to_dict) 188 | 189 | 190 | def get_dtype_dict(env): 191 | action_shape = env.action_space.shape 192 | action_shape = action_shape if len(action_shape) > 0 else 1 193 | action_dtype = env.action_space.dtype 194 | action_dtype = 'int32' if np.issubdtype(action_dtype, int) else action_dtype 195 | action_dtype = 'float32' if np.issubdtype(action_dtype, float) else action_dtype 196 | env_dict = {'action': {'shape': action_shape, 197 | 'dtype': action_dtype}, 198 | 'reward': {'dtype': 'float32'}, 199 | 'done': {'dtype': 'bool'}, 200 | 'n_reward': {'dtype': 'float32'}, 201 | 'n_done': {'dtype': 'bool'}, 202 | 'actual_n': {'dtype': 'float32'}, 203 | 'demo': {'dtype': 'float32'} 204 | } 205 | for prefix in ('', 'next_', 'n_'): 206 | if isinstance(env.observation_space, gym.spaces.Dict): 207 | for name, space in env.observation_space.spaces.items(): 208 | env_dict[prefix + name] = {'shape': space.shape, 209 | 'dtype': space.dtype} 210 | else: 211 | 
env_dict[prefix + 'state'] = {'shape': env.observation_space.shape, 212 | 'dtype': env.observation_space.dtype} 213 | dtype_dict = {key: value['dtype'] for key, value in env_dict.items()} 214 | dtype_dict.update(weights='float32', indexes='int32') 215 | return env_dict, dtype_dict 216 | -------------------------------------------------------------------------------- /utils/fake_env.py: -------------------------------------------------------------------------------- 1 | import gym 2 | from gym.spaces import Dict, Box 3 | 4 | from utils.discretization import ExpertActionPreprocessing 5 | 6 | 7 | class FakeEnv(gym.Env): 8 | 9 | def __init__(self, data): 10 | self.index = 0 11 | self.states, self.actions, self.rewards, self.next_states, self.dones = data 12 | self.viewer = None 13 | self.action_preprocessor = ExpertActionPreprocessing() 14 | self.observation_space = Dict(pov=Box(0, 256, (64, 64, 3))) 15 | super().__init__() 16 | 17 | def preprocess_action(self, action): 18 | action = self.action_preprocessor.to_joint_action(action) 19 | action = self.action_preprocessor.to_compressed_action(action) 20 | action = self.action_preprocessor.to_discrete_action(action) 21 | return action 22 | 23 | def step(self, action): 24 | if self.are_all_frames_used(): 25 | raise KeyError("No available data for sampling") 26 | # action = self.preprocess_action(self.actions[self.index]) 27 | 28 | info = {'expert_action': self.actions[self.index]} 29 | result = self.next_states[self.index], self.rewards[self.index], self.dones[self.index], info 30 | self.index += 1 31 | return result 32 | 33 | def reset(self): 34 | if self.are_all_frames_used(): 35 | raise KeyError("No available data for sampling") 36 | return self.states[self.index] 37 | 38 | def _get_image(self): 39 | return self.next_states[min(len(self.states) - 1, self.index)]['pov'] 40 | 41 | def are_all_frames_used(self): 42 | return self.index >= len(self.states) 43 | 44 | def render(self, mode="human"): 45 | img = self._get_image() 46 | if mode == "rgb_array": 47 | return img 48 | elif mode == "human": 49 | from gym.envs.classic_control import rendering 50 | 51 | if self.viewer is None: 52 | self.viewer = rendering.SimpleImageViewer() 53 | self.viewer.imshow(img) 54 | return self.viewer.isopen 55 | -------------------------------------------------------------------------------- /utils/load_demonstrations.py: -------------------------------------------------------------------------------- 1 | import minerl 2 | 3 | 4 | def load_demonstrations(): 5 | minerl.data.download('demonstrations', environment='MineRLTreechop-v0') 6 | minerl.data.download('demonstrations', environment='MineRLObtainDiamondDense-v0') 7 | minerl.data.download('demonstrations', environment='MineRLObtainIronPickaxeDense-v0') 8 | 9 | 10 | if __name__ == "__main__": 11 | load_demonstrations() 12 | -------------------------------------------------------------------------------- /utils/load_weights.py: -------------------------------------------------------------------------------- 1 | import wandb 2 | 3 | load_to = 'train/forger++' 4 | name = 'tviskaron/ForgER/forger:v0' 5 | 6 | run = wandb.init(anonymous='allow') 7 | artifact = run.use_artifact(name, type='weights') 8 | artifact_dir = artifact.download(root=load_to) 9 | wandb.finish() 10 | -------------------------------------------------------------------------------- /utils/tf_util.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | 3 | 4 | def loss_l(ae, a, margin_value): 5 
| margin = tf.where(tf.equal(ae, a), tf.constant(0.0), tf.constant(margin_value)) 6 | return tf.cast(margin, tf.float32) 7 | 8 | 9 | def huber_loss(x, delta=1.0): 10 | """Reference: https://en.wikipedia.org/wiki/Huber_loss""" 11 | return tf.where( 12 | tf.abs(x) < delta, 13 | tf.square(x) * 0.5, 14 | delta * (tf.abs(x) - 0.5 * delta) 15 | ) 16 | 17 | 18 | def saliency_map(f, x): 19 | with tf.GradientTape(watch_accessed_variables=False) as tape: 20 | tape.watch(x) 21 | output = f(x) 22 | max_outp = tf.reduce_max(output, 1) 23 | saliency = tape.gradient(max_outp, x) 24 | return tf.reduce_max(tf.abs(saliency), axis=-1) 25 | 26 | 27 | def take_vector_elements(vectors, indices): 28 | """ 29 | For a batch of vectors, take a single vector component 30 | out of each vector. 31 | Args: 32 | vectors: a [batch x dims] Tensor. 33 | indices: an int32 Tensor with `batch` entries. 34 | Returns: 35 | A Tensor with `batch` entries, one for each vector. 36 | """ 37 | return tf.gather_nd(vectors, tf.stack([tf.range(tf.shape(vectors)[0]), indices], axis=1)) 38 | 39 | 40 | def config_gpu(): 41 | gpus = tf.config.experimental.list_physical_devices('GPU') 42 | if gpus: 43 | try: 44 | # Currently, memory growth needs to be the same across GPUs 45 | for gpu in gpus: 46 | tf.config.experimental.set_memory_growth(gpu, True) 47 | logical_gpus = tf.config.experimental.list_logical_devices('GPU') 48 | print(len(gpus), "Physical GPUs,", len(logical_gpus), "Logical GPUs") 49 | except RuntimeError as e: 50 | # Memory growth must be set before GPUs have been initialized 51 | print(e) 52 | -------------------------------------------------------------------------------- /utils/wrappers.py: -------------------------------------------------------------------------------- 1 | import os 2 | from collections import deque 3 | 4 | import cv2 5 | import numpy as np 6 | import gym 7 | 8 | from utils.config_validation import WrappersCfg 9 | 10 | mapping = dict() 11 | 12 | 13 | class LazyFrames: 14 | 15 | def __init__(self, frames, stack_axis=2): 16 | self.stack_axis = stack_axis 17 | self._frames = frames 18 | 19 | def __array__(self, dtype=None): 20 | out = np.concatenate(self._frames, axis=self.stack_axis) 21 | if dtype is not None: 22 | out = out.astype(dtype) 23 | return out 24 | 25 | 26 | class DiscreteBase(gym.Wrapper): 27 | def __init__(self, env): 28 | super().__init__(env) 29 | self.action_dict = {} 30 | self.action_space = gym.spaces.Discrete(len(self.action_dict)) 31 | 32 | def step(self, action): 33 | s, r, done, info = self.env.step(self.action_dict[action]) 34 | return s, r, done, info 35 | 36 | def sample_action(self): 37 | return self.action_space.sample() 38 | 39 | 40 | class DiscreteWrapper(DiscreteBase): 41 | def __init__(self, env, always_attack=True, angle=5): 42 | super().__init__(env) 43 | self.action_dict = { 44 | 0: {'attack': always_attack, 'back': 0, 'camera': [0, 0], 'forward': 1, 'jump': 0, 'left': 0, 'right': 0, 45 | 'sneak': 0, 'sprint': 0}, 46 | 1: {'attack': always_attack, 'back': 0, 'camera': [0, angle], 'forward': 0, 'jump': 0, 'left': 0, 47 | 'right': 0, 'sneak': 0, 'sprint': 0}, 48 | 2: {'attack': 1, 'back': 0, 'camera': [0, 0], 'forward': 0, 'jump': 0, 'left': 0, 'right': 0, 'sneak': 0, 49 | 'sprint': 0}, 50 | 3: {'attack': always_attack, 'back': 0, 'camera': [angle, 0], 'forward': 0, 'jump': 0, 'left': 0, 51 | 'right': 0, 'sneak': 0, 'sprint': 0}, 52 | 4: {'attack': always_attack, 'back': 0, 'camera': [-angle, 0], 'forward': 0, 'jump': 0, 'left': 0, 53 | 'right': 0, 'sneak': 0, 'sprint': 0}, 54 | 
5: {'attack': always_attack, 'back': 0, 'camera': [0, -angle], 'forward': 0, 'jump': 0, 'left': 0, 55 | 'right': 0, 'sneak': 0, 'sprint': 0}, 56 | 6: {'attack': always_attack, 'back': 0, 'camera': [0, 0], 'forward': 1, 'jump': 1, 'left': 0, 'right': 0, 57 | 'sneak': 0, 'sprint': 0}, 58 | 7: {'attack': always_attack, 'back': 0, 'camera': [0, 0], 'forward': 0, 'jump': 0, 'left': 1, 'right': 0, 59 | 'sneak': 0, 'sprint': 0}, 60 | 8: {'attack': always_attack, 'back': 0, 'camera': [0, 0], 'forward': 0, 'jump': 0, 'left': 0, 'right': 1, 61 | 'sneak': 0, 'sprint': 0}, 62 | 9: {'attack': always_attack, 'back': 1, 'camera': [0, 0], 'forward': 0, 'jump': 0, 'left': 0, 'right': 0, 63 | 'sneak': 0, 'sprint': 0}} 64 | self.action_space = gym.spaces.Discrete(len(self.action_dict)) 65 | 66 | 67 | class ActionNoiseWrapper(gym.Wrapper): 68 | 69 | def step(self, action: dict): 70 | return self.env.step(self.apply_noise(action)) 71 | 72 | @staticmethod 73 | def apply_noise(action): 74 | if 'camera' in action: 75 | x, y = list(np.random.normal(0, 0.02, 2)) 76 | action['camera'] = np.add([x, y], action['camera']) 77 | return action 78 | 79 | 80 | def wrap_env(env, settings: WrappersCfg): 81 | env = ActionNoiseWrapper(env) 82 | env = ObtainPoVWrapper(env) 83 | env = FrameSkip(env, settings.frame_skip) 84 | env = FrameStack(env, settings.frame_stack) 85 | env = DiscreteWrapper(env) 86 | return env 87 | 88 | 89 | def register(name): 90 | def _thunk(func): 91 | mapping[name] = func 92 | return func 93 | 94 | return _thunk 95 | 96 | 97 | class FrameSkip(gym.Wrapper): 98 | def __init__(self, env, skip=4): 99 | super().__init__(env) 100 | 101 | self._skip = skip 102 | 103 | def step(self, action): 104 | total_reward = 0.0 105 | infos = [] 106 | info = {} 107 | obs = None 108 | done = None 109 | for _ in range(self._skip): 110 | obs, reward, done, info = self.env.step(action) 111 | infos.append(info) 112 | total_reward += reward 113 | if done: 114 | break 115 | if 'expert_action' in infos[0]: 116 | info['expert_action'] = self.env.preprocess_action([info_['expert_action'] for info_ in infos]) 117 | return obs, total_reward, done, info 118 | 119 | 120 | class FrameStack(gym.Wrapper): 121 | def __init__(self, env, k, channel_order='hwc', use_tuple=False): 122 | 123 | gym.Wrapper.__init__(self, env) 124 | self.k = k 125 | self.observations = deque([], maxlen=k) 126 | self.stack_axis = {'hwc': 2, 'chw': 0}[channel_order] 127 | self.use_tuple = use_tuple 128 | 129 | if self.use_tuple: 130 | pov_space = env.observation_space[0] 131 | inv_space = env.observation_space[1] 132 | else: 133 | inv_space = None 134 | pov_space = env.observation_space 135 | 136 | low_pov = np.repeat(pov_space.low, k, axis=self.stack_axis) 137 | high_pov = np.repeat(pov_space.high, k, axis=self.stack_axis) 138 | pov_space = gym.spaces.Box(low=low_pov, high=high_pov, dtype=pov_space.dtype) 139 | 140 | if self.use_tuple: 141 | low_inv = np.repeat(inv_space.low, k, axis=0) 142 | high_inv = np.repeat(inv_space.high, k, axis=0) 143 | inv_space = gym.spaces.Box(low=low_inv, high=high_inv, dtype=inv_space.dtype) 144 | self.observation_space = gym.spaces.Tuple( 145 | (pov_space, inv_space)) 146 | else: 147 | self.observation_space = pov_space 148 | 149 | def reset(self): 150 | ob = self.env.reset() 151 | for _ in range(self.k): 152 | self.observations.append(ob) 153 | return self._get_ob() 154 | 155 | def step(self, action): 156 | ob, reward, done, info = self.env.step(action) 157 | self.observations.append(ob) 158 | return self._get_ob(), reward, done, 
info 159 | 160 | def _get_ob(self): 161 | assert len(self.observations) == self.k 162 | if self.use_tuple: 163 | frames = [x[0] for x in self.observations] 164 | inventory = [x[1] for x in self.observations] 165 | return (LazyFrames(list(frames), stack_axis=self.stack_axis), 166 | LazyFrames(list(inventory), stack_axis=0)) 167 | else: 168 | return LazyFrames(list(self.observations), stack_axis=self.stack_axis) 169 | 170 | 171 | class ObtainPoVWrapper(gym.ObservationWrapper): 172 | """Obtain 'pov' value (current game display) of the original observation.""" 173 | 174 | def __init__(self, env): 175 | super().__init__(env) 176 | 177 | self.observation_space = self.env.observation_space.spaces['pov'] 178 | 179 | def observation(self, observation): 180 | return observation['pov'] 181 | 182 | 183 | class DiscreteBase(gym.Wrapper): 184 | def __init__(self, env): 185 | super().__init__(env) 186 | self.action_dict = {} 187 | self.action_space = gym.spaces.Discrete(len(self.action_dict)) 188 | 189 | def step(self, action): 190 | s, r, done, info = self.env.step(self.action_dict[action]) 191 | return s, r, done, info 192 | 193 | def sample_action(self): 194 | return self.action_space.sample() 195 | 196 | 197 | class SaveVideoWrapper(gym.Wrapper): 198 | current_episode = 0 199 | 200 | def __init__(self, env, path='train/', resize=4, reward_threshold=0): 201 | """ 202 | :param env: wrapped environment 203 | :param path: path to save videos 204 | :param resize: resize factor 205 | """ 206 | super().__init__(env) 207 | self.path = path 208 | self.recording = [] 209 | self.rewards = [0] 210 | self.resize = resize 211 | self.reward_threshold = reward_threshold 212 | self.previous_reward = 0 213 | 214 | def step(self, action): 215 | """ 216 | make a step in environment 217 | :param action: agent's action 218 | :return: observation, reward, done, info 219 | """ 220 | observation, reward, done, info = self.env.step(action) 221 | self.rewards.append(reward) 222 | self.recording.append(self.bgr_to_rgb(observation['pov'])) 223 | return observation, reward, done, info 224 | 225 | def get_reward(self): 226 | return sum(map(int, self.rewards)) 227 | 228 | def reset(self, **kwargs): 229 | """ 230 | reset environment and save game video if its not empty 231 | :param kwargs: 232 | :return: current observation 233 | """ 234 | 235 | reward = self.get_reward() 236 | if self.current_episode > 0 and reward >= self.reward_threshold: 237 | name = str(self.current_episode).zfill(4) + "r" + str(reward).zfill(4) + ".mp4" 238 | full_path = os.path.join(self.path, name) 239 | upscaled_video = [self.upscale_image(image, self.resize) for image in self.recording] 240 | self.save_video(full_path, video=upscaled_video) 241 | self.current_episode += 1 242 | self.rewards = [0] 243 | self.recording = [] 244 | self.env.seed(self.current_episode) 245 | observation = self.env.reset(**kwargs) 246 | self.recording.append(self.bgr_to_rgb(observation['pov'])) 247 | return observation 248 | 249 | @staticmethod 250 | def upscale_image(image, resize): 251 | """ 252 | increase image size (for better video quality) 253 | :param image: original image 254 | :param resize: 255 | :return: 256 | """ 257 | size_x, size_y, size_z = image.shape 258 | return cv2.resize(image, dsize=(size_x * resize, size_y * resize)) 259 | 260 | @staticmethod 261 | def save_video(filename, video): 262 | """ 263 | saves video from list of np.array images 264 | :param filename: filename or path to file 265 | :param video: [image, ..., image] 266 | :return: 267 | """ 268 | 
size_x, size_y, size_z = video[0].shape 269 | out = cv2.VideoWriter(filename, cv2.VideoWriter_fourcc(*'mp4v'), 60.0, (size_y, size_x)) 270 | for image in video: 271 | out.write(image) 272 | out.release() 273 | cv2.destroyAllWindows() 274 | 275 | @staticmethod 276 | def bgr_to_rgb(image): 277 | """ 278 | converts BGR image to RGB 279 | :param image: bgr image 280 | :return: rgb image 281 | """ 282 | return image[..., ::-1] 283 | --------------------------------------------------------------------------------
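Putting the pieces in this section together: the wrappers in utils/wrappers.py
produce a discretized, frame-stacked POV environment, get_dtype_dict in
utils/discretization.py describes its observation and action layout, and that
description feeds the prioritized buffers in policy/replay_buffer.py. A rough
sketch of the wiring, for orientation only; the MineRL environment creation and
the training loop are assumed here, and the project's own entry point (main.py)
is not reproduced in this section.

    import gym
    import minerl  # noqa: F401  # registers the MineRL environments

    from policy.replay_buffer import AggregatedBuff
    from utils.config_validation import BufferCfg, WrappersCfg
    from utils.discretization import get_dtype_dict
    from utils.wrappers import wrap_env

    env = wrap_env(gym.make('MineRLTreechop-v0'), WrappersCfg(frame_stack=2, frame_skip=4))
    env_dict, dtype_dict = get_dtype_dict(env)
    buffer = AggregatedBuff(env_dict=env_dict, cfg=BufferCfg(size=450000))

    obs = env.reset()
    action = env.sample_action()                    # DiscreteBase exposes a uniform sampler
    next_obs, reward, done, info = env.step(action)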