├── .gitattributes ├── .gitignore ├── LICENSE ├── README.md ├── SECURITY.md ├── ai_challenge ├── README.md └── pig_chase │ ├── README.md │ ├── agent.py │ ├── common.py │ ├── environment.py │ ├── evaluation.py │ ├── pig-chase-overview.png │ ├── pig_chase.xml │ ├── pig_chase_baseline.py │ ├── pig_chase_dqn.py │ ├── pig_chase_dqn_top_down.py │ ├── pig_chase_eval_sample.py │ └── pig_chase_human_vs_agent.py ├── docker ├── README.md ├── malmo │ ├── Dockerfile │ ├── options.txt │ └── run.sh ├── malmopy-ai-challenge │ └── docker-compose.yml ├── malmopy-chainer-cpu │ └── Dockerfile ├── malmopy-chainer-gpu │ └── Dockerfile ├── malmopy-cntk-cpu-py27 │ └── Dockerfile └── malmopy-cntk-gpu-py27 │ └── Dockerfile ├── malmopy ├── README.md ├── __init__.py ├── agent │ ├── __init__.py │ ├── agent.py │ ├── astar.py │ ├── explorer.py │ ├── gui.py │ └── qlearner.py ├── environment │ ├── __init__.py │ ├── environment.py │ ├── gym │ │ ├── __init__.py │ │ └── gym.py │ └── malmo │ │ ├── __init__.py │ │ └── malmo.py ├── model │ ├── __init__.py │ ├── chainer │ │ ├── __init__.py │ │ └── qlearning.py │ ├── cntk │ │ ├── __init__.py │ │ ├── base.py │ │ └── qlearning.py │ └── model.py ├── util │ ├── __init__.py │ ├── images.py │ └── util.py ├── version.py └── visualization │ ├── __init__.py │ ├── tensorboard │ ├── __init__.py │ ├── cntk │ │ ├── __init__.py │ │ └── cntk.py │ └── tensorboard.py │ └── visualizer.py ├── samples └── atari │ └── gym_atari_dqn.py └── setup.py /.gitattributes: -------------------------------------------------------------------------------- 1 | ############################################################################### 2 | # Set default behavior to automatically normalize line endings. 3 | ############################################################################### 4 | * text=auto 5 | 6 | ############################################################################### 7 | # Set default behavior for command prompt diff. 
8 | # 9 | # This is need for earlier builds of msysgit that does not have it on by 10 | # default for csharp files. 11 | # Note: This is only used by command line 12 | ############################################################################### 13 | #*.cs diff=csharp 14 | 15 | ############################################################################### 16 | # Set the merge driver for project and solution files 17 | # 18 | # Merging from the command prompt will add diff markers to the files if there 19 | # are conflicts (Merging from VS is not affected by the settings below, in VS 20 | # the diff markers are never inserted). Diff markers may cause the following 21 | # file extensions to fail to load in VS. An alternative would be to treat 22 | # these files as binary and thus will always conflict and require user 23 | # intervention with every merge. To do so, just uncomment the entries below 24 | ############################################################################### 25 | #*.sln merge=binary 26 | #*.csproj merge=binary 27 | #*.vbproj merge=binary 28 | #*.vcxproj merge=binary 29 | #*.vcproj merge=binary 30 | #*.dbproj merge=binary 31 | #*.fsproj merge=binary 32 | #*.lsproj merge=binary 33 | #*.wixproj merge=binary 34 | #*.modelproj merge=binary 35 | #*.sqlproj merge=binary 36 | #*.wwaproj merge=binary 37 | 38 | ############################################################################### 39 | # behavior for image files 40 | # 41 | # image files are treated as binary by default. 42 | ############################################################################### 43 | #*.jpg binary 44 | #*.png binary 45 | #*.gif binary 46 | 47 | ############################################################################### 48 | # diff behavior for common document formats 49 | # 50 | # Convert binary document formats to text before diffing them. This feature 51 | # is only available from the command line. Turn it on by uncommenting the 52 | # entries below. 
53 | ############################################################################### 54 | #*.doc diff=astextplain 55 | #*.DOC diff=astextplain 56 | #*.docx diff=astextplain 57 | #*.DOCX diff=astextplain 58 | #*.dot diff=astextplain 59 | #*.DOT diff=astextplain 60 | #*.pdf diff=astextplain 61 | #*.PDF diff=astextplain 62 | #*.rtf diff=astextplain 63 | #*.RTF diff=astextplain 64 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | ## Ignore Visual Studio temporary files, build results, and 2 | ## files generated by popular Visual Studio add-ons. 3 | 4 | # User-specific files 5 | *.suo 6 | *.user 7 | *.userosscache 8 | *.sln.docstates 9 | 10 | # User-specific files (MonoDevelop/Xamarin Studio) 11 | *.userprefs 12 | 13 | # Build results 14 | [Dd]ebug/ 15 | [Dd]ebugPublic/ 16 | [Rr]elease/ 17 | [Rr]eleases/ 18 | x64/ 19 | x86/ 20 | bld/ 21 | [Bb]in/ 22 | [Oo]bj/ 23 | [Ll]og/ 24 | 25 | # Visual Studio 2015 cache/options directory 26 | .vs/ 27 | # Uncomment if you have tasks that create the project's static files in wwwroot 28 | #wwwroot/ 29 | 30 | # MSTest test Results 31 | [Tt]est[Rr]esult*/ 32 | [Bb]uild[Ll]og.* 33 | 34 | # NUNIT 35 | *.VisualState.xml 36 | TestResult.xml 37 | 38 | # Build Results of an ATL Project 39 | [Dd]ebugPS/ 40 | [Rr]eleasePS/ 41 | dlldata.c 42 | 43 | # DNX 44 | project.lock.json 45 | project.fragment.lock.json 46 | artifacts/ 47 | 48 | *_i.c 49 | *_p.c 50 | *_i.h 51 | *.ilk 52 | *.meta 53 | *.obj 54 | *.pch 55 | *.pdb 56 | *.pgc 57 | *.pgd 58 | *.rsp 59 | *.sbr 60 | *.tlb 61 | *.tli 62 | *.tlh 63 | *.tmp 64 | *.tmp_proj 65 | *.log 66 | *.vspscc 67 | *.vssscc 68 | .builds 69 | *.pidb 70 | *.svclog 71 | *.scc 72 | 73 | # Chutzpah Test files 74 | _Chutzpah* 75 | 76 | # Visual C++ cache files 77 | ipch/ 78 | *.aps 79 | *.ncb 80 | *.opendb 81 | *.opensdf 82 | *.sdf 83 | *.cachefile 84 | *.VC.db 85 | *.VC.VC.opendb 86 | 87 | 
# Visual Studio profiler 88 | *.psess 89 | *.vsp 90 | *.vspx 91 | *.sap 92 | 93 | # TFS 2012 Local Workspace 94 | $tf/ 95 | 96 | # Guidance Automation Toolkit 97 | *.gpState 98 | 99 | # ReSharper is a .NET coding add-in 100 | _ReSharper*/ 101 | *.[Rr]e[Ss]harper 102 | *.DotSettings.user 103 | 104 | # JustCode is a .NET coding add-in 105 | .JustCode 106 | 107 | # TeamCity is a build add-in 108 | _TeamCity* 109 | 110 | # DotCover is a Code Coverage Tool 111 | *.dotCover 112 | 113 | # NCrunch 114 | _NCrunch_* 115 | .*crunch*.local.xml 116 | nCrunchTemp_* 117 | 118 | # MightyMoose 119 | *.mm.* 120 | AutoTest.Net/ 121 | 122 | # Web workbench (sass) 123 | .sass-cache/ 124 | 125 | # Installshield output folder 126 | [Ee]xpress/ 127 | 128 | # DocProject is a documentation generator add-in 129 | DocProject/buildhelp/ 130 | DocProject/Help/*.HxT 131 | DocProject/Help/*.HxC 132 | DocProject/Help/*.hhc 133 | DocProject/Help/*.hhk 134 | DocProject/Help/*.hhp 135 | DocProject/Help/Html2 136 | DocProject/Help/html 137 | 138 | # Click-Once directory 139 | publish/ 140 | 141 | # Publish Web Output 142 | *.[Pp]ublish.xml 143 | *.azurePubxml 144 | # TODO: Comment the next line if you want to checkin your web deploy settings 145 | # but database connection strings (with potential passwords) will be unencrypted 146 | #*.pubxml 147 | *.publishproj 148 | 149 | # Microsoft Azure Web App publish settings. Comment the next line if you want to 150 | # checkin your Azure Web App publish settings, but sensitive information contained 151 | # in these scripts will be unencrypted 152 | PublishScripts/ 153 | 154 | # NuGet Packages 155 | *.nupkg 156 | # The packages folder can be ignored because of Package Restore 157 | **/packages/* 158 | # except build/, which is used as an MSBuild target. 
159 | !**/packages/build/ 160 | # Uncomment if necessary however generally it will be regenerated when needed 161 | #!**/packages/repositories.config 162 | # NuGet v3's project.json files produces more ignoreable files 163 | *.nuget.props 164 | *.nuget.targets 165 | 166 | # Microsoft Azure Build Output 167 | csx/ 168 | *.build.csdef 169 | 170 | # Microsoft Azure Emulator 171 | ecf/ 172 | rcf/ 173 | 174 | # Windows Store app package directories and files 175 | AppPackages/ 176 | BundleArtifacts/ 177 | Package.StoreAssociation.xml 178 | _pkginfo.txt 179 | 180 | # Visual Studio cache files 181 | # files ending in .cache can be ignored 182 | *.[Cc]ache 183 | # but keep track of directories ending in .cache 184 | !*.[Cc]ache/ 185 | 186 | # Others 187 | ClientBin/ 188 | ~$* 189 | *~ 190 | *.dbmdl 191 | *.dbproj.schemaview 192 | *.jfm 193 | *.pfx 194 | *.publishsettings 195 | node_modules/ 196 | orleans.codegen.cs 197 | 198 | # Since there are multiple workflows, uncomment next line to ignore bower_components 199 | # (https://github.com/github/gitignore/pull/1529#issuecomment-104372622) 200 | #bower_components/ 201 | 202 | # RIA/Silverlight projects 203 | Generated_Code/ 204 | 205 | # Backup & report files from converting an old project file 206 | # to a newer Visual Studio version. 
Backup files are not needed, 207 | # because we have git ;-) 208 | _UpgradeReport_Files/ 209 | Backup*/ 210 | UpgradeLog*.XML 211 | UpgradeLog*.htm 212 | 213 | # SQL Server files 214 | *.mdf 215 | *.ldf 216 | 217 | # Business Intelligence projects 218 | *.rdl.data 219 | *.bim.layout 220 | *.bim_*.settings 221 | 222 | # Microsoft Fakes 223 | FakesAssemblies/ 224 | 225 | # GhostDoc plugin setting file 226 | *.GhostDoc.xml 227 | 228 | # Node.js Tools for Visual Studio 229 | .ntvs_analysis.dat 230 | 231 | # Visual Studio 6 build log 232 | *.plg 233 | 234 | # Visual Studio 6 workspace options file 235 | *.opt 236 | 237 | # Visual Studio LightSwitch build output 238 | **/*.HTMLClient/GeneratedArtifacts 239 | **/*.DesktopClient/GeneratedArtifacts 240 | **/*.DesktopClient/ModelManifest.xml 241 | **/*.Server/GeneratedArtifacts 242 | **/*.Server/ModelManifest.xml 243 | _Pvt_Extensions 244 | 245 | # Paket dependency manager 246 | .paket/paket.exe 247 | paket-files/ 248 | 249 | # FAKE - F# Make 250 | .fake/ 251 | 252 | # JetBrains Rider 253 | .idea/ 254 | *.sln.iml 255 | 256 | # CodeRush 257 | .cr/ 258 | 259 | # Python Tools for Visual Studio (PTVS) 260 | __pycache__/ 261 | *.pyc 262 | 263 | # Tests cache 264 | */tests/.cache/v/cache/ 265 | 266 | # Library 267 | *.pyd 268 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) Microsoft Corporation. All rights reserved. 
4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # The Malmo Collaborative AI Challenge 2 | 3 | This repository contains the task definition and example code for the [Malmo Collaborative AI Challenge](https://www.microsoft.com/en-us/research/academic-program/collaborative-ai-challenge/). 4 | This challenge is organized to encourage research in collaborative AI - to work towards AI agents 5 | that learn to collaborate to solve problems and achieve goals. 6 | You can find additional details, including terms and conditions, prizes and information on how to participate at the [Challenge Homepage](https://www.microsoft.com/en-us/research/academic-program/collaborative-ai-challenge/). 
7 | 8 | [![Join the chat at https://gitter.im/malmo-challenge/Lobby](https://badges.gitter.im/malmo-challenge/Lobby.svg)](https://gitter.im/malmo-challenge/Lobby?utm_source=badge&utm_medium=badge&utm_campaign=pr-badge&utm_content=badge) 9 | [![license](https://img.shields.io/github/license/mashape/apistatus.svg?maxAge=2592000)](https://github.com/Microsoft/malmo-challenge/blob/master/LICENSE) 10 | 11 | ---- 12 | 13 | **Notes for challenge participants:** Once you and your team decide to participate in the challenge, please make sure to register your team at our [Registration Page](https://www.surveygizmo.com/s3/3299773/The-Collaborative-AI-Challenge). On the registration form, you need to provide a link to the GitHub repository that will 14 | contain your solution. We recommend that you fork this repository ([learn how](https://docs.github.com/en/get-started/quickstart/fork-a-repo)), 15 | and provide the address of the forked repo. You can then update your submission as you make progress on the challenge task. 16 | We will consider the version of the code on branch master at the time of the submission deadline as your challenge submission. Your submission needs to contain code in working order, a 1-page description of your approach, and a 1-minute video that shows off your agent. Please see the [challenge terms and conditions](https://www.microsoft.com/en-us/research/academic-program/collaborative-ai-challenge/) for further details.
17 | 18 | ---- 19 | 20 | **Jump to:** 21 | 22 | - [Installation](#installation) 23 | - [Prerequisites](#prerequisites) 24 | - [Minimal installation](#minimal-installation) 25 | - [Optional extensions](#optional-extensions) 26 | 27 | - [Getting started](#getting-started) 28 | - [Play the challenge task](#play-the-challenge-task) 29 | - [Run your first experiment](#run-your-first-experiment) 30 | 31 | - [Next steps](#next-steps) 32 | - [Run an experiment in Docker on Azure](#run-an-experiment-in-docker-on-azure) 33 | - [Compare your results against other teams](#compare-your-results-against-other-teams) 34 | - [Resources](#resources) 35 | 36 | # Installation 37 | 38 | ## Prerequisites 39 | 40 | - [Python](https://www.python.org/) 2.7+ (recommended) or 3.5+ 41 | - [Project Malmo](https://github.com/Microsoft/malmo) - we recommend downloading the [Malmo-0.21.0 release](https://github.com/Microsoft/malmo/releases) and installing dependencies for [Windows](https://github.com/Microsoft/malmo/blob/master/doc/install_windows.md), [Linux](https://github.com/Microsoft/malmo/blob/master/doc/install_linux.md) or [MacOS](https://github.com/Microsoft/malmo/blob/master/doc/install_macosx.md). Test your Malmo installation by [launching Minecraft with Malmo](https://github.com/Microsoft/malmo#launching-minecraft-with-our-mod) and [launching an agent](https://github.com/Microsoft/malmo#launch-an-agent). 42 | 43 | ## Minimal installation 44 | 45 | ``` 46 | pip install -e git+https://github.com/Microsoft/malmo-challenge#egg=malmopy 47 | ``` 48 | 49 | or 50 | 51 | ``` 52 | git clone https://github.com/Microsoft/malmo-challenge 53 | cd malmo-challenge 54 | pip install -e . 55 | ``` 56 | 57 | ## Optional extensions 58 | 59 | Some of the example code uses additional dependencies to provide 'extra' functionality.
These can be installed using: 60 | 61 | ``` 62 | pip install -e '.[extra1, extra2]' 63 | ``` 64 | For example, to install gym: 65 | 66 | ``` 67 | pip install -e '.[gym]' 68 | ``` 69 | 70 | Or to install all extras: 71 | 72 | ``` 73 | pip install -e '.[all]' 74 | ``` 75 | 76 | The following extras are available: 77 | - `gym`: [OpenAI Gym](https://gym.openai.com/) is an interface to a wide range of reinforcement learning environments. Installing this extra enables the Atari example agents in [samples/atari](samples/atari) to train on the gym environments. *Note that OpenAI gym atari environments are currently not available on Windows.* 78 | - `tensorflow`: [TensorFlow](https://www.tensorflow.org/) is a popular deep learning framework developed by Google. In our examples it enables visualizations through [TensorBoard](https://www.tensorflow.org/get_started/summaries_and_tensorboard). 79 | 80 | 81 | # Getting started 82 | 83 | ## Play the challenge task 84 | 85 | The challenge task takes the form of a mini game, called Pig Chase. Learn about the game, and try playing it yourself on our [Pig Chase Challenge page](ai_challenge/pig_chase/README.md). 86 | 87 | ## Run your first experiment 88 | 89 | See how to [run your first baseline experiment](ai_challenge/pig_chase/README.md#run-your-first-experiment) on the [Pig Chase Challenge page](ai_challenge/pig_chase/README.md). 90 | 91 | # Next steps 92 | 93 | ## Run an experiment in Docker on Azure 94 | 95 | Docker is a virtualization platform that makes it easy to deploy software with all its dependencies. 96 | We use docker to run experiments locally or in the cloud. Details on how to run an example experiment using docker are in the [docker README](docker/README.md). 97 | 98 | ## Compare your results against other teams: 99 | 100 | We provide a [leaderboard website](https://malmo-leaderboard.azurewebsites.net/) where you can compare your results against the other participants.
101 | 102 | 103 | ## Resources 104 | 105 | - [Malmo Platform Tutorial](https://github.com/Microsoft/malmo/blob/master/Malmo/samples/Python_examples/Tutorial.pdf) 106 | - [Azure Portal](https://portal.azure.com/) 107 | - [Docker Documentation](https://docs.docker.com/) 108 | - [Docker Machine on Azure](https://docs.microsoft.com/en-us/azure/virtual-machines/virtual-machines-linux-docker-machine) 109 | - [CNTK Tutorials](https://www.microsoft.com/en-us/research/product/cognitive-toolkit/tutorials/) 110 | - [CNTK Documentation](https://github.com/Microsoft/CNTK/wiki) 111 | - [Chainer Documentation](http://docs.chainer.org/en/stable/) 112 | - [TensorBoard Documentation](https://www.tensorflow.org/get_started/summaries_and_tensorboard) 113 | -------------------------------------------------------------------------------- /SECURITY.md: -------------------------------------------------------------------------------- 1 | 2 | 3 | ## Security 4 | 5 | Microsoft takes the security of our software products and services seriously, which includes all source code repositories managed through our GitHub organizations, which include [Microsoft](https://github.com/Microsoft), [Azure](https://github.com/Azure), [DotNet](https://github.com/dotnet), [AspNet](https://github.com/aspnet), [Xamarin](https://github.com/xamarin), and [our GitHub organizations](https://opensource.microsoft.com/). 6 | 7 | If you believe you have found a security vulnerability in any Microsoft-owned repository that meets [Microsoft's definition of a security vulnerability](https://aka.ms/opensource/security/definition), please report it to us as described below. 8 | 9 | ## Reporting Security Issues 10 | 11 | **Please do not report security vulnerabilities through public GitHub issues.** 12 | 13 | Instead, please report them to the Microsoft Security Response Center (MSRC) at [https://msrc.microsoft.com/create-report](https://aka.ms/opensource/security/create-report).
14 | 15 | If you prefer to submit without logging in, send email to [secure@microsoft.com](mailto:secure@microsoft.com). If possible, encrypt your message with our PGP key; please download it from the [Microsoft Security Response Center PGP Key page](https://aka.ms/opensource/security/pgpkey). 16 | 17 | You should receive a response within 24 hours. If for some reason you do not, please follow up via email to ensure we received your original message. Additional information can be found at [microsoft.com/msrc](https://aka.ms/opensource/security/msrc). 18 | 19 | Please include the requested information listed below (as much as you can provide) to help us better understand the nature and scope of the possible issue: 20 | 21 | * Type of issue (e.g. buffer overflow, SQL injection, cross-site scripting, etc.) 22 | * Full paths of source file(s) related to the manifestation of the issue 23 | * The location of the affected source code (tag/branch/commit or direct URL) 24 | * Any special configuration required to reproduce the issue 25 | * Step-by-step instructions to reproduce the issue 26 | * Proof-of-concept or exploit code (if possible) 27 | * Impact of the issue, including how an attacker might exploit the issue 28 | 29 | This information will help us triage your report more quickly. 30 | 31 | If you are reporting for a bug bounty, more complete reports can contribute to a higher bounty award. Please visit our [Microsoft Bug Bounty Program](https://aka.ms/opensource/security/bounty) page for more details about our active programs. 32 | 33 | ## Preferred Languages 34 | 35 | We prefer all communications to be in English. 36 | 37 | ## Policy 38 | 39 | Microsoft follows the principle of [Coordinated Vulnerability Disclosure](https://aka.ms/opensource/security/cvd). 
40 | 41 | 42 | -------------------------------------------------------------------------------- /ai_challenge/README.md: -------------------------------------------------------------------------------- 1 | # Malmo Collaborative AI Challenge 2 | 3 | This folder contains task definitions for the Malmo Collaborative 4 | AI Challenge. For installation instructions see [Installation](../README.md#installation). 5 | 6 | ## Available challenges 7 | 8 | - [Malmo Collaborative AI Challenge - Pig Chase](pig_chase/README.md) : Try to build collaborative 9 | AI agents trying to catch a pig. 10 | 11 | ## Further reading 12 | 13 | Once you have familiarized yourself with a challenge task, you may want to head back to the [Overview Page](../README.md) to learn about [Installation](../README.md#installation) and [Getting started](../README.md#getting-started). Or dive into the code examples to learn how to [Write your first agent](../malmopy/README.md#write-your-first-agent). 14 | 15 | -------------------------------------------------------------------------------- /ai_challenge/pig_chase/README.md: -------------------------------------------------------------------------------- 1 | # Malmo Collaborative AI Challenge - Pig Chase 2 | 3 | This repository contains Malmo Collaborative AI challenge task definition. The challenge task takes the form of a collaborative mini game, called Pig Chase. 4 | 5 | ![Screenshot of the pig chase game](pig-chase-overview.png?raw=true "Screenshot of the Pig Chase game") 6 | 7 | ## Overview of the game 8 | 9 | Two Minecraft agents and a pig are wandering a small meadow. 
The agents have two choices: 10 | 11 | - _Catch the pig_ (i.e., the agents pinch or corner the pig, and no escape path is available), and receive a high reward (25 points) 12 | - _Give up_ and leave the pig pen through the exits to the left and right of the pen, marked by blue squares, and receive a small reward (5 points) 13 | 14 | Pig Chase is inspired by the variant of the _stag hunt_ presented in [Yoshida et al. 2008]. The [stag hunt](https://en.wikipedia.org/wiki/Stag_hunt) is a classical game-theoretic formulation that captures conflicts between collaboration and individual safety. 15 | 16 | [Yoshida et al. 2008] Yoshida, Wako, Ray J. Dolan, and Karl J. Friston. ["Game theory of mind."](http://journals.plos.org/ploscompbiol/article?id=10.1371/journal.pcbi.1000254) PLoS Comput Biol 4.12 (2008): e1000254. 17 | 18 | 19 | ## How to play (human players) 20 | 21 | To familiarize yourself with the game, we recommend that you play it yourself. The following instructions allow you to play the game with a "focused agent", a baseline agent that tries to move towards the pig whenever possible. 22 | 23 | ### Prerequisites 24 | 25 | * Install the [Malmo Platform](https://github.com/Microsoft/malmo) and the `malmopy` framework as described under [Installation](../../README.md#installation), and verify that you can run the Malmo platform and python example agents 26 | 27 | ### Steps 28 | 29 | * Start two instances of the Malmo Client on ports `10000` and `10001` 30 | * `cd malmo-challenge/ai_challenge/pig_chase` 31 | * `python pig_chase_human_vs_agent.py` 32 | 33 | Wait for a few seconds for the human player interface to appear. 34 | 35 | Note: the script assumes that two Malmo clients are running on the default ports on localhost. You can specify alternative clients on the command line. See the script's usage instructions (`python pig_chase_human_vs_agent.py -h`) for details. 36 | 37 | ### How to play 38 | 39 | * The game is played over 10 rounds at a time.
The goal is to accumulate the highest score over these 10 rounds. 40 | * In each round a "collaborator" agent is selected to play with you. Different collaborators may have different behaviors. 41 | * Once the game has started, use the left/right arrow keys to turn, and the forward/backward keys to move. You can see your agent move in the first person view; it is also shown as a red arrow in the top-down rendering on the left. 42 | * You and your collaborator move in turns and try to catch the pig (25 points if caught). You can give up on catching the pig in the current round by moving to the blue "exit squares" (5 points). You have a maximum of 25 steps available, and will get -1 point for each step taken. 43 | 44 | ## Run your first experiment 45 | 46 | An example experiment is provided in `pig_chase_baseline.py`. To run it, start two instances of the Malmo Client as [above](#steps). Then run: 47 | 48 | ``` 49 | python pig_chase_baseline.py 50 | ``` 51 | 52 | Depending on whether `tensorboard` is available on your system, this script will output performance statistics to either tensorboard or to the console. If using tensorboard, you can plot the stored data by pointing a tensorboard instance to the results folder: 53 | 54 | ``` 55 | cd ai_challenge/pig_chase 56 | tensorboard --logdir=results --port=6006 57 | ``` 58 | 59 | You can then navigate to http://127.0.0.1:6006 to view the results. 60 | 61 | The baseline script runs a `FocusedAgent` by default - it uses a simple planning algorithm to find a shortest path to the pig. You can also run a `RandomAgent` baseline. Switch agents using the command line arguments: 62 | 63 | ``` 64 | python pig_chase_baseline.py -t random 65 | ``` 66 | 67 | For additional command line options, see the usage instructions: `python pig_chase_baseline.py -h`. 68 | 69 | ## Evaluate your agent 70 | 71 | We provide a convenience evaluator, `PigChaseEvaluator`, which allows you to quickly evaluate 72 | the performance of your agent.
73 | 74 | PigChaseEvaluator takes 4 arguments: the Malmo `clients` to run on (at least 2 are required), the two agents below, and a `state_builder` that the environment uses to generate state for your agent: 75 | - agent_100k : Your agent trained with 100k steps (100k train calls) 76 | - agent_500k : Your agent trained with 500k steps (500k train calls) 77 | 78 | To evaluate your agent: 79 | 80 | ``` python 81 | # Creates an agent trained with 100k train calls 82 | my_agent_100k = MyCustomAgent() 83 | 84 | # Creates an agent trained with 500k train calls 85 | my_agent_500k = MyCustomAgent() 86 | 87 | # Pass the Malmo clients (at least 2) and a custom StateBuilder for your agent. 88 | # The StateBuilder will be used by the environment to generate state for your agent 89 | evaluator = PigChaseEvaluator(parse_clients_args(['127.0.0.1:10000', '127.0.0.1:10001']), my_agent_100k, my_agent_500k, MyStateBuilder()) 90 | 91 | # Run and save 92 | evaluator.run() 93 | evaluator.save('My experiment 1', 'path/to/save.json') 94 | ``` 95 | 96 | ## Compare against other teams: 97 | 98 | Submit your evaluation results on the [leaderboard website](https://malmo-leaderboard.azurewebsites.net/) to compare your results against the other participants. 99 | 100 | 101 | ## Next steps 102 | 103 | To participate in the Collaborative AI Challenge, implement and train an agent that can effectively collaborate with any collaborator. Your agent can use either the first-person visual view, or the symbolic view (as demonstrated in the `FocusedAgent`). You can use any AI/learning approach you like - originality of the chosen approach is part of the criteria for the challenge prizes. Can you come up with an agent that learns to outperform the A-star baseline agent? Can an agent learn to play with a copy of itself? Can it outperform your own (human) score? 104 | 105 | For more inspiration, you can look at more [code samples](../../samples/README.md) or learn how to [run experiments on Azure using docker](../../docker/README.md).
106 | 107 | 108 | 109 | -------------------------------------------------------------------------------- /ai_challenge/pig_chase/common.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2017 Microsoft Corporation. 2 | # 3 | # Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated 4 | # documentation files (the "Software"), to deal in the Software without restriction, including without limitation the 5 | # rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, 6 | # and to permit persons to whom the Software is furnished to do so, subject to the following conditions: 7 | # 8 | # The above copyright notice and this permission notice shall be included in all copies or substantial portions of 9 | # the Software. 10 | # 11 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO 12 | # THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 13 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, 14 | # TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 15 | # SOFTWARE. 
# ===================================================================================================================

ENV_AGENT_NAMES = ['Agent_1', 'Agent_2']
ENV_TARGET_NAMES = ['Pig']
ENV_ENTITIES_NAME = ENV_AGENT_NAMES + ENV_TARGET_NAMES
ENV_ACTIONS = ["move 1", "turn -1", "turn 1"]
ENV_ENTITIES = 'entities'
ENV_BOARD = 'board'
ENV_BOARD_SHAPE = (9, 9)
# Same value as the 5-point "give up" (exit) reward described in the Pig Chase README.
ENV_INDIVIDUAL_REWARD = 5
# Same value as the 25-point "pig caught" reward described in the Pig Chase README.
ENV_CAUGHT_REWARD = 25


class ENV_AGENT_TYPES:
    """Plain-integer identifiers (0..5) for the kinds of agents used in the challenge."""
    RANDOM, FOCUSED, TABQ, DEEPQ, HUMAN, OTHER = range(0, 6)


def parse_clients_args(args_clients):
    """
    Split each 'ip:port' client specification into an (ip, port) tuple.

    :param args_clients: iterable of client specifications, e.g. ['127.0.0.1:10000']
    :return: list of tuples of strings, e.g. [('127.0.0.1', '10000')]
    """
    return [tuple(str(client).split(':')) for client in args_clients]


def visualize_training(visualizer, step, rewards, tag='Training'):
    """
    Record per-episode reward statistics through a visualizer.

    :param visualizer: object exposing add_entry(step, name, value)
    :param step: step at which the entries are recorded
    :param rewards: rewards collected during one episode; must be non-empty (max/min raise otherwise)
    :param tag: prefix used for every entry name
    """
    visualizer.add_entry(step, '%s/reward per episode' % tag, sum(rewards))
    visualizer.add_entry(step, '%s/max.reward' % tag, max(rewards))
    visualizer.add_entry(step, '%s/min.reward' % tag, min(rewards))
    # NOTE(review): reports len(rewards) - 1 actions; presumably the rewards list carries one
    # extra entry per episode -- confirm against the callers.
    visualizer.add_entry(step, '%s/actions per episode' % tag, len(rewards)-1)


class Entity(object):
    """
    Wrap entity attributes.

    Coordinates and angles are truncated to int on construction and assignment;
    yaw is additionally normalized into [0, 360).
    """

    def __init__(self, x, y, z, yaw, pitch, name=''):
        self._name = name
        self._x = int(x)
        self._y = int(y)
        self._z = int(z)
        self._yaw = int(yaw) % 360
        self._pitch = int(pitch)

    @property
    def name(self):
        return self._name

    @property
    def x(self):
        return self._x

    @x.setter
    def x(self, value):
        self._x = int(value)

    @property
    def y(self):
        return self._y

    @y.setter
    def y(self, value):
        self._y = int(value)

    @property
    def z(self):
        return self._z

    @z.setter
    def z(self, value):
        self._z = int(value)

    @property
    def yaw(self):
        return self._yaw

    @yaw.setter
    def yaw(self, value):
        self._yaw = int(value) % 360

    @property
    def pitch(self):
        return self._pitch

    @pitch.setter
    def pitch(self, value):
        self._pitch = int(value)

    @property
    def position(self):
        """The (x, y, z) integer position tuple."""
        return self._x, self._y, self._z

    def __eq__(self, other):
        """
        Compare this entity's (x, y, z) position with a tuple.

        For any other operand, return NotImplemented so Python falls back to its
        default (identity-based) comparison instead of yielding None.
        """
        if isinstance(other, tuple):
            return self.position == other
        return NotImplemented

    def __ne__(self, other):
        # Python 2 does not derive != from __eq__, so mirror it explicitly.
        result = self.__eq__(other)
        if result is NotImplemented:
            return result
        return not result

    def __getitem__(self, item):
        # Dict-style access to attributes, e.g. entity['x'] or entity['name'].
        return getattr(self, item)

    def __repr__(self):
        return '%s(x=%d, y=%d, z=%d, yaw=%d, pitch=%d, name=%r)' % (
            type(self).__name__, self._x, self._y, self._z, self._yaw, self._pitch, self._name)

    @classmethod
    def create(cls, obj):
        """
        Build an Entity from a mapping (or another Entity) providing
        'x', 'y', 'z', 'yaw', 'pitch' and, optionally, 'name'.
        """
        try:
            name = obj['name']
        except KeyError:
            name = ''
        return cls(obj['x'], obj['y'], obj['z'], obj['yaw'], obj['pitch'], name=name)

# --------------------------------------------------------------------------------
# /ai_challenge/pig_chase/evaluation.py:
# --------------------------------------------------------------------------------
# Copyright (c) 2017 Microsoft Corporation.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated
# documentation files (the "Software"), to deal in the Software without restriction, including without limitation the
# rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software,
# and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all copies or substantial portions of
# the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO
# THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
# TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
16 | # =================================================================================================================== 17 | 18 | import os 19 | import sys 20 | from time import sleep 21 | 22 | from common import parse_clients_args, ENV_AGENT_NAMES 23 | from agent import PigChaseChallengeAgent 24 | from common import ENV_AGENT_NAMES 25 | from environment import PigChaseEnvironment, PigChaseSymbolicStateBuilder 26 | 27 | # Enforce path 28 | sys.path.insert(0, os.getcwd()) 29 | sys.path.insert(1, os.path.join(os.path.pardir, os.getcwd())) 30 | 31 | 32 | class PigChaseEvaluator(object): 33 | def __init__(self, clients, agent_100k, agent_500k, state_builder): 34 | assert len(clients) >= 2, 'Not enough clients provided' 35 | 36 | self._clients = clients 37 | self._agent_100k = agent_100k 38 | self._agent_500k = agent_500k 39 | self._state_builder = state_builder 40 | self._accumulators = {'100k': [], '500k': []} 41 | 42 | def save(self, experiment_name, filepath): 43 | """ 44 | Save the evaluation results in a JSON file 45 | understandable by the leaderboard. 46 | 47 | Note: The leaderboard will not accept a submission if you already 48 | uploaded a file with the same experiment name. 
49 | 50 | :param experiment_name: An identifier for the experiment 51 | :param filepath: Path where to store the results file 52 | :return: 53 | """ 54 | 55 | assert experiment_name is not None, 'experiment_name cannot be None' 56 | 57 | from json import dump 58 | from os.path import exists, join, pardir, abspath 59 | from os import makedirs 60 | from numpy import mean, var 61 | 62 | # Compute metrics 63 | metrics = {key: {'mean': mean(buffer), 64 | 'var': var(buffer), 65 | 'count': len(buffer)} 66 | for key, buffer in self._accumulators.items()} 67 | 68 | metrics['experimentname'] = experiment_name 69 | 70 | try: 71 | filepath = abspath(filepath) 72 | parent = join(pardir, filepath) 73 | if not exists(parent): 74 | makedirs(parent) 75 | 76 | with open(filepath, 'w') as f_out: 77 | dump(metrics, f_out) 78 | 79 | print('==================================') 80 | print('Evaluation done, results written at %s' % filepath) 81 | 82 | except Exception as e: 83 | print('Unable to save the results: %s' % e) 84 | 85 | def run(self): 86 | from multiprocessing import Process 87 | 88 | env = PigChaseEnvironment(self._clients, self._state_builder, 89 | role=1, randomize_positions=True) 90 | print('==================================') 91 | print('Starting evaluation of Agent @100k') 92 | 93 | p = Process(target=run_challenge_agent, args=(self._clients,)) 94 | p.start() 95 | sleep(5) 96 | agent_loop(self._agent_100k, env, self._accumulators['100k']) 97 | p.terminate() 98 | 99 | print('==================================') 100 | print('Starting evaluation of Agent @500k') 101 | 102 | p = Process(target=run_challenge_agent, args=(self._clients,)) 103 | p.start() 104 | sleep(5) 105 | agent_loop(self._agent_500k, env, self._accumulators['500k']) 106 | p.terminate() 107 | 108 | 109 | def run_challenge_agent(clients): 110 | builder = PigChaseSymbolicStateBuilder() 111 | env = PigChaseEnvironment(clients, builder, role=0, 112 | randomize_positions=True) 113 | agent = 
PigChaseChallengeAgent(ENV_AGENT_NAMES[0]) 114 | agent_loop(agent, env, None) 115 | 116 | 117 | def agent_loop(agent, env, metrics_acc): 118 | EVAL_EPISODES = 100 119 | agent_done = False 120 | reward = 0 121 | episode = 0 122 | obs = env.reset() 123 | 124 | while episode < EVAL_EPISODES: 125 | # check if env needs reset 126 | if env.done: 127 | print('Episode %d (%.2f)%%' % (episode, (episode / EVAL_EPISODES) * 100.)) 128 | 129 | obs = env.reset() 130 | while obs is None: 131 | # this can happen if the episode ended with the first 132 | # action of the other agent 133 | print('Warning: received obs == None.') 134 | obs = env.reset() 135 | 136 | episode += 1 137 | 138 | # select an action 139 | action = agent.act(obs, reward, agent_done, is_training=True) 140 | # take a step 141 | obs, reward, agent_done = env.do(action) 142 | 143 | if metrics_acc is not None: 144 | metrics_acc.append(reward) 145 | -------------------------------------------------------------------------------- /ai_challenge/pig_chase/pig-chase-overview.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/malmo-challenge/72634708638fd6fe84521894891f054933717ecf/ai_challenge/pig_chase/pig-chase-overview.png -------------------------------------------------------------------------------- /ai_challenge/pig_chase/pig_chase.xml: -------------------------------------------------------------------------------- 1 | 2 | 13 | 14 | 15 | 16 | 17 | Catch the pig! 
18 | 19 | 20 | 21 | 4 22 | 23 | 24 | 25 | 26 | 30 | clear 31 | false 32 | 33 | 34 | 35 | 36 | 37 | 38 | 39 | 40 | 41 | 42 | 43 | 44 | 45 | 46 | 47 | 48 | 49 | 50 | 51 | 52 | 53 | 54 | 55 | 56 | 57 | 58 | 59 | 60 | 61 | 62 | 63 | 64 | 65 | 66 | 67 | 68 | 69 | 70 | 71 | 72 | 73 | 74 | 75 | 76 | 77 | 78 | 79 | 80 | 81 | 82 | 83 | 84 | 85 | 86 | 87 | Agent_1 88 | 89 | 90 | 91 | 92 | 93 | 94 | 95 | 96 | 97 | 98 | 99 | 100 | 101 | 102 | 103 | 104 | 105 | 106 | 107 | 108 | 109 | attack 110 | 111 | 112 | 113 | 114 | 115 | 116 | 117 | 118 | 119 | 120 | 121 | 640 122 | 480 123 | 124 | 125 | 126 | 127 | 128 | Agent_2 129 | 130 | 131 | 132 | 133 | 134 | 135 | 136 | 137 | 138 | 139 | 140 | 141 | 142 | 143 | 144 | 145 | 146 | 147 | 148 | 149 | 150 | attack 151 | 152 | 153 | 154 | 155 | 156 | 157 | 158 | 159 | 160 | 161 | 162 | 163 | 164 | 165 | 166 | 167 | 168 | 169 | 640 170 | 480 171 | 172 | 173 | 174 | 175 | -------------------------------------------------------------------------------- /ai_challenge/pig_chase/pig_chase_baseline.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2017 Microsoft Corporation. 2 | # 3 | # Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated 4 | # documentation files (the "Software"), to deal in the Software without restriction, including without limitation the 5 | # rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, 6 | # and to permit persons to whom the Software is furnished to do so, subject to the following conditions: 7 | # 8 | # The above copyright notice and this permission notice shall be included in all copies or substantial portions of 9 | # the Software. 10 | # 11 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO 12 | # THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 13 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, 14 | # TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 15 | # SOFTWARE. 16 | # =================================================================================================================== 17 | 18 | import numpy as np 19 | import os 20 | import sys 21 | 22 | from argparse import ArgumentParser 23 | from datetime import datetime 24 | 25 | import six 26 | from os import path 27 | from threading import Thread, active_count 28 | from time import sleep 29 | from builtins import range 30 | 31 | from malmopy.agent import RandomAgent 32 | try: 33 | from malmopy.visualization.tensorboard import TensorboardVisualizer 34 | from malmopy.visualization.tensorboard.cntk import CntkConverter 35 | except ImportError: 36 | print('Cannot import tensorboard, using ConsoleVisualizer.') 37 | from malmopy.visualization import ConsoleVisualizer 38 | 39 | from common import parse_clients_args, visualize_training, ENV_AGENT_NAMES, ENV_TARGET_NAMES 40 | from agent import PigChaseChallengeAgent, FocusedAgent, TabularQLearnerAgent, get_agent_type 41 | from environment import PigChaseEnvironment, PigChaseSymbolicStateBuilder 42 | 43 | # Enforce path 44 | sys.path.insert(0, os.getcwd()) 45 | sys.path.insert(1, os.path.join(os.path.pardir, os.getcwd())) 46 | 47 | BASELINES_FOLDER = 'results/baselines/pig_chase/%s/%s' 48 | EPOCH_SIZE = 100 49 | 50 | 51 | def agent_factory(name, role, baseline_agent, clients, max_epochs, 52 | logdir, visualizer): 53 | 54 | assert len(clients) >= 2, 'Not enough clients (need at least 2)' 55 | clients = parse_clients_args(clients) 56 | 57 | builder = PigChaseSymbolicStateBuilder() 58 | env = PigChaseEnvironment(clients, builder, role=role, 59 | randomize_positions=True) 60 | 61 | if role == 0: 62 | agent = PigChaseChallengeAgent(name) 63 | obs = 
env.reset(get_agent_type(agent)) 64 | 65 | reward = 0 66 | agent_done = False 67 | 68 | while True: 69 | if env.done: 70 | while True: 71 | obs = env.reset(get_agent_type(agent)) 72 | if obs: 73 | break 74 | 75 | # select an action 76 | action = agent.act(obs, reward, agent_done, is_training=True) 77 | 78 | # reset if needed 79 | if env.done: 80 | obs = env.reset(get_agent_type(agent)) 81 | 82 | # take a step 83 | obs, reward, agent_done = env.do(action) 84 | 85 | 86 | else: 87 | 88 | if baseline_agent == 'tabq': 89 | agent = TabularQLearnerAgent(name, visualizer) 90 | elif baseline_agent == 'astar': 91 | agent = FocusedAgent(name, ENV_TARGET_NAMES[0]) 92 | else: 93 | agent = RandomAgent(name, env.available_actions) 94 | 95 | obs = env.reset() 96 | reward = 0 97 | agent_done = False 98 | viz_rewards = [] 99 | 100 | max_training_steps = EPOCH_SIZE * max_epochs 101 | for step in six.moves.range(1, max_training_steps+1): 102 | 103 | # check if env needs reset 104 | if env.done: 105 | while True: 106 | if len(viz_rewards) == 0: 107 | viz_rewards.append(0) 108 | visualize_training(visualizer, step, viz_rewards) 109 | tag = "Episode End Conditions" 110 | visualizer.add_entry(step, '%s/timeouts per episode' % tag, env.end_result == "command_quota_reached") 111 | visualizer.add_entry(step, '%s/agent_1 defaults per episode' % tag, env.end_result == "Agent_1_defaulted") 112 | visualizer.add_entry(step, '%s/agent_2 defaults per episode' % tag, env.end_result == "Agent_2_defaulted") 113 | visualizer.add_entry(step, '%s/pig caught per episode' % tag, env.end_result == "caught_the_pig") 114 | agent.inject_summaries(step) 115 | viz_rewards = [] 116 | obs = env.reset() 117 | if obs: 118 | break 119 | 120 | # select an action 121 | action = agent.act(obs, reward, agent_done, is_training=True) 122 | # take a step 123 | obs, reward, agent_done = env.do(action) 124 | viz_rewards.append(reward) 125 | 126 | #agent.inject_summaries(step) 127 | 128 | 129 | def run_experiment(agents_def): 
130 | assert len(agents_def) == 2, 'Not enough agents (required: 2, got: %d)'\ 131 | % len(agents_def) 132 | 133 | processes = [] 134 | for agent in agents_def: 135 | p = Thread(target=agent_factory, kwargs=agent) 136 | p.daemon = True 137 | p.start() 138 | 139 | # Give the server time to start 140 | if agent['role'] == 0: 141 | sleep(1) 142 | 143 | processes.append(p) 144 | 145 | try: 146 | # wait until only the challenge agent is left 147 | while active_count() > 2: 148 | sleep(0.1) 149 | except KeyboardInterrupt: 150 | print('Caught control-c - shutting down.') 151 | 152 | 153 | if __name__ == '__main__': 154 | arg_parser = ArgumentParser('Pig Chase baseline experiment') 155 | arg_parser.add_argument('-t', '--type', type=str, default='astar', 156 | choices=['tabq', 'astar', 'random'], 157 | help='The type of baseline to run.') 158 | arg_parser.add_argument('-e', '--epochs', type=int, default=5, 159 | help='Number of epochs to run.') 160 | arg_parser.add_argument('clients', nargs='*', 161 | default=['127.0.0.1:10000', '127.0.0.1:10001'], 162 | help='Minecraft clients endpoints (ip(:port)?)+') 163 | args = arg_parser.parse_args() 164 | 165 | logdir = BASELINES_FOLDER % (args.type, datetime.utcnow().isoformat()) 166 | if 'malmopy.visualization.tensorboard' in sys.modules: 167 | visualizer = TensorboardVisualizer() 168 | visualizer.initialize(logdir, None) 169 | else: 170 | visualizer = ConsoleVisualizer() 171 | 172 | agents = [{'name': agent, 'role': role, 'baseline_agent': args.type, 173 | 'clients': args.clients, 'max_epochs': args.epochs, 174 | 'logdir': logdir, 'visualizer': visualizer} 175 | for role, agent in enumerate(ENV_AGENT_NAMES)] 176 | 177 | run_experiment(agents) 178 | 179 | -------------------------------------------------------------------------------- /ai_challenge/pig_chase/pig_chase_dqn.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2017 Microsoft Corporation. 
2 | # 3 | # Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated 4 | # documentation files (the "Software"), to deal in the Software without restriction, including without limitation the 5 | # rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, 6 | # and to permit persons to whom the Software is furnished to do so, subject to the following conditions: 7 | # 8 | # The above copyright notice and this permission notice shall be included in all copies or substantial portions of 9 | # the Software. 10 | # 11 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO 12 | # THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 13 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, 14 | # TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 15 | # SOFTWARE. 
16 | # =================================================================================================================== 17 | 18 | import os 19 | import sys 20 | from argparse import ArgumentParser 21 | from datetime import datetime 22 | 23 | import six 24 | from os import path 25 | from threading import Thread, active_count 26 | from time import sleep 27 | 28 | from malmopy.agent import LinearEpsilonGreedyExplorer 29 | 30 | from common import parse_clients_args, visualize_training, ENV_AGENT_NAMES 31 | from agent import PigChaseChallengeAgent, PigChaseQLearnerAgent 32 | from environment import PigChaseEnvironment, PigChaseSymbolicStateBuilder 33 | 34 | from malmopy.environment.malmo import MalmoALEStateBuilder 35 | from malmopy.agent import TemporalMemory, RandomAgent 36 | 37 | try: 38 | from malmopy.visualization.tensorboard import TensorboardVisualizer 39 | from malmopy.visualization.tensorboard.cntk import CntkConverter 40 | except ImportError: 41 | print('Cannot import tensorboard, using ConsoleVisualizer.') 42 | from malmopy.visualization import ConsoleVisualizer 43 | 44 | # Enforce path 45 | sys.path.insert(0, os.getcwd()) 46 | sys.path.insert(1, os.path.join(os.path.pardir, os.getcwd())) 47 | 48 | DQN_FOLDER = 'results/baselines/%s/dqn/%s-%s' 49 | EPOCH_SIZE = 100000 50 | 51 | 52 | def agent_factory(name, role, clients, backend, 53 | device, max_epochs, logdir, visualizer): 54 | 55 | assert len(clients) >= 2, 'Not enough clients (need at least 2)' 56 | clients = parse_clients_args(clients) 57 | 58 | if role == 0: 59 | 60 | builder = PigChaseSymbolicStateBuilder() 61 | env = PigChaseEnvironment(clients, builder, role=role, 62 | randomize_positions=True) 63 | agent = PigChaseChallengeAgent(name) 64 | if type(agent.current_agent) == RandomAgent: 65 | agent_type = PigChaseEnvironment.AGENT_TYPE_1 66 | else: 67 | agent_type = PigChaseEnvironment.AGENT_TYPE_2 68 | 69 | obs = env.reset(agent_type) 70 | reward = 0 71 | agent_done = False 72 | 73 | while True: 74 
| if env.done: 75 | if type(agent.current_agent) == RandomAgent: 76 | agent_type = PigChaseEnvironment.AGENT_TYPE_1 77 | else: 78 | agent_type = PigChaseEnvironment.AGENT_TYPE_2 79 | 80 | obs = env.reset(agent_type) 81 | while obs is None: 82 | # this can happen if the episode ended with the first 83 | # action of the other agent 84 | print('Warning: received obs == None.') 85 | obs = env.reset(agent_type) 86 | 87 | # select an action 88 | action = agent.act(obs, reward, agent_done, is_training=True) 89 | # take a step 90 | obs, reward, agent_done = env.do(action) 91 | 92 | else: 93 | env = PigChaseEnvironment(clients, MalmoALEStateBuilder(), 94 | role=role, randomize_positions=True) 95 | memory = TemporalMemory(100000, (84, 84)) 96 | 97 | if backend == 'cntk': 98 | from malmopy.model.cntk import QNeuralNetwork 99 | model = QNeuralNetwork((memory.history_length, 84, 84), env.available_actions, device) 100 | else: 101 | from malmopy.model.chainer import QNeuralNetwork, DQNChain 102 | chain = DQNChain((memory.history_length, 84, 84), env.available_actions) 103 | target_chain = DQNChain((memory.history_length, 84, 84), env.available_actions) 104 | model = QNeuralNetwork(chain, target_chain, device) 105 | 106 | explorer = LinearEpsilonGreedyExplorer(1, 0.1, 1000000) 107 | agent = PigChaseQLearnerAgent(name, env.available_actions, 108 | model, memory, 0.99, 32, 50000, 109 | explorer=explorer, visualizer=visualizer) 110 | 111 | obs = env.reset() 112 | reward = 0 113 | agent_done = False 114 | viz_rewards = [] 115 | 116 | max_training_steps = EPOCH_SIZE * max_epochs 117 | for step in six.moves.range(1, max_training_steps+1): 118 | 119 | # check if env needs reset 120 | if env.done: 121 | 122 | visualize_training(visualizer, step, viz_rewards) 123 | agent.inject_summaries(step) 124 | viz_rewards = [] 125 | 126 | obs = env.reset() 127 | while obs is None: 128 | # this can happen if the episode ended with the first 129 | # action of the other agent 130 | print('Warning: 
received obs == None.') 131 | obs = env.reset() 132 | 133 | # select an action 134 | action = agent.act(obs, reward, agent_done, is_training=True) 135 | # take a step 136 | obs, reward, agent_done = env.do(action) 137 | viz_rewards.append(reward) 138 | 139 | if (step % EPOCH_SIZE) == 0: 140 | if 'model' in locals(): 141 | model.save('pig_chase-dqn_%d.model' % (step / EPOCH_SIZE)) 142 | 143 | 144 | def run_experiment(agents_def): 145 | assert len(agents_def) == 2, 'Not enough agents (required: 2, got: %d)' \ 146 | % len(agents_def) 147 | 148 | processes = [] 149 | for agent in agents_def: 150 | p = Thread(target=agent_factory, kwargs=agent) 151 | p.daemon = True 152 | p.start() 153 | 154 | # Give the server time to start 155 | if agent['role'] == 0: 156 | sleep(1) 157 | 158 | processes.append(p) 159 | 160 | try: 161 | # wait until only the challenge agent is left 162 | while active_count() > 2: 163 | sleep(0.1) 164 | except KeyboardInterrupt: 165 | print('Caught control-c - shutting down.') 166 | 167 | 168 | if __name__ == '__main__': 169 | arg_parser = ArgumentParser('Pig Chase DQN experiment') 170 | arg_parser.add_argument('-b', '--backend', type=str, choices=['cntk', 'chainer'], 171 | default='cntk', help='Neural network backend') 172 | arg_parser.add_argument('-e', '--epochs', type=int, default=5, 173 | help='Number of epochs to run.') 174 | arg_parser.add_argument('clients', nargs='*', 175 | default=['127.0.0.1:10000', '127.0.0.1:10001'], 176 | help='Minecraft clients endpoints (ip(:port)?)+') 177 | arg_parser.add_argument('-d', '--device', type=int, default=-1, 178 | help='GPU device on which to run the experiment.') 179 | args = arg_parser.parse_args() 180 | 181 | logdir = path.join('results/pig_chase/dqn', datetime.utcnow().isoformat()) 182 | if 'malmopy.visualization.tensorboard' in sys.modules: 183 | visualizer = TensorboardVisualizer() 184 | visualizer.initialize(logdir, None) 185 | 186 | else: 187 | visualizer = ConsoleVisualizer() 188 | 189 | agents = 
[{'name': agent, 'role': role, 'clients': args.clients, 190 | 'backend': args.backend, 'device': args.device, 191 | 'max_epochs': args.epochs, 'logdir': logdir, 'visualizer': visualizer} 192 | for role, agent in enumerate(ENV_AGENT_NAMES)] 193 | 194 | run_experiment(agents) 195 | -------------------------------------------------------------------------------- /ai_challenge/pig_chase/pig_chase_dqn_top_down.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2017 Microsoft Corporation. 2 | # 3 | # Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated 4 | # documentation files (the "Software"), to deal in the Software without restriction, including without limitation the 5 | # rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, 6 | # and to permit persons to whom the Software is furnished to do so, subject to the following conditions: 7 | # 8 | # The above copyright notice and this permission notice shall be included in all copies or substantial portions of 9 | # the Software. 10 | # 11 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO 12 | # THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 13 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, 14 | # TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 15 | # SOFTWARE. 
16 | # =================================================================================================================== 17 | 18 | import os 19 | import sys 20 | from argparse import ArgumentParser 21 | from datetime import datetime 22 | 23 | import six 24 | from os import path 25 | from threading import Thread, active_count 26 | from time import sleep 27 | 28 | from malmopy.agent import LinearEpsilonGreedyExplorer, RandomAgent 29 | 30 | from common import parse_clients_args, visualize_training, ENV_AGENT_NAMES 31 | from agent import PigChaseChallengeAgent, PigChaseQLearnerAgent 32 | from environment import PigChaseEnvironment, PigChaseSymbolicStateBuilder, \ 33 | PigChaseTopDownStateBuilder 34 | 35 | from malmopy.agent import TemporalMemory 36 | 37 | try: 38 | from malmopy.visualization.tensorboard import TensorboardVisualizer 39 | from malmopy.visualization.tensorboard.cntk import CntkConverter 40 | except ImportError: 41 | print('Cannot import tensorboard, using ConsoleVisualizer.') 42 | from malmopy.visualization import ConsoleVisualizer 43 | 44 | # Enforce path 45 | sys.path.insert(0, os.getcwd()) 46 | sys.path.insert(1, os.path.join(os.path.pardir, os.getcwd())) 47 | 48 | DQN_FOLDER = 'results/baselines/%s/dqn/%s-%s' 49 | EPOCH_SIZE = 100000 50 | 51 | 52 | def agent_factory(name, role, clients, backend, device, max_epochs, logdir, visualizer): 53 | 54 | assert len(clients) >= 2, 'Not enough clients (need at least 2)' 55 | clients = parse_clients_args(clients) 56 | 57 | if role == 0: 58 | 59 | builder = PigChaseSymbolicStateBuilder() 60 | env = PigChaseEnvironment(clients, builder, role=role, 61 | randomize_positions=True) 62 | 63 | agent = PigChaseChallengeAgent(name) 64 | if type(agent.current_agent) == RandomAgent: 65 | agent_type = PigChaseEnvironment.AGENT_TYPE_1 66 | else: 67 | agent_type = PigChaseEnvironment.AGENT_TYPE_2 68 | 69 | obs = env.reset(agent_type) 70 | reward = 0 71 | agent_done = False 72 | 73 | while True: 74 | if env.done: 75 | if 
type(agent.current_agent) == RandomAgent: 76 | agent_type = PigChaseEnvironment.AGENT_TYPE_1 77 | else: 78 | agent_type = PigChaseEnvironment.AGENT_TYPE_2 79 | 80 | obs = env.reset(agent_type) 81 | while obs is None: 82 | # this can happen if the episode ended with the first 83 | # action of the other agent 84 | print('Warning: received obs == None.') 85 | obs = env.reset(agent_type) 86 | 87 | # select an action 88 | action = agent.act(obs, reward, agent_done, is_training=True) 89 | # take a step 90 | obs, reward, agent_done = env.do(action) 91 | 92 | else: 93 | env = PigChaseEnvironment(clients, PigChaseTopDownStateBuilder(True), 94 | role=role, randomize_positions=True) 95 | memory = TemporalMemory(100000, (18, 18)) 96 | 97 | if backend == 'cntk': 98 | from malmopy.model.cntk import QNeuralNetwork 99 | model = QNeuralNetwork((memory.history_length, 18, 18), env.available_actions, device) 100 | else: 101 | from malmopy.model.chainer import QNeuralNetwork, ReducedDQNChain 102 | chain = ReducedDQNChain((memory.history_length, 18, 18), env.available_actions) 103 | target_chain = ReducedDQNChain((memory.history_length, 18, 18), env.available_actions) 104 | model = QNeuralNetwork(chain, target_chain, device) 105 | 106 | explorer = LinearEpsilonGreedyExplorer(1, 0.1, 1000000) 107 | agent = PigChaseQLearnerAgent(name, env.available_actions, 108 | model, memory, 0.99, 32, 50000, 109 | explorer=explorer, visualizer=visualizer) 110 | 111 | obs = env.reset() 112 | reward = 0 113 | agent_done = False 114 | viz_rewards = [] 115 | 116 | max_training_steps = EPOCH_SIZE * max_epochs 117 | for step in six.moves.range(1, max_training_steps+1): 118 | 119 | # check if env needs reset 120 | if env.done: 121 | 122 | visualize_training(visualizer, step, viz_rewards) 123 | agent.inject_summaries(step) 124 | viz_rewards = [] 125 | 126 | obs = env.reset() 127 | while obs is None: 128 | # this can happen if the episode ended with the first 129 | # action of the other agent 130 | 
print('Warning: received obs == None.') 131 | obs = env.reset() 132 | 133 | # select an action 134 | action = agent.act(obs, reward, agent_done, is_training=True) 135 | # take a step 136 | obs, reward, agent_done = env.do(action) 137 | viz_rewards.append(reward) 138 | 139 | if (step % EPOCH_SIZE) == 0: 140 | if 'model' in locals(): 141 | model.save('pig_chase-dqn_%d.model' % (step / EPOCH_SIZE)) 142 | 143 | 144 | def run_experiment(agents_def): 145 | assert len(agents_def) == 2, 'Not enough agents (required: 2, got: %d)' \ 146 | % len(agents_def) 147 | 148 | processes = [] 149 | for agent in agents_def: 150 | p = Thread(target=agent_factory, kwargs=agent) 151 | p.daemon = True 152 | p.start() 153 | 154 | # Give the server time to start 155 | if agent['role'] == 0: 156 | sleep(1) 157 | 158 | processes.append(p) 159 | 160 | try: 161 | # wait until only the challenge agent is left 162 | while active_count() > 2: 163 | sleep(0.1) 164 | except KeyboardInterrupt: 165 | print('Caught control-c - shutting down.') 166 | 167 | 168 | if __name__ == '__main__': 169 | arg_parser = ArgumentParser('Pig Chase DQN experiment') 170 | arg_parser.add_argument('-b', '--backend', type=str, choices=['cntk', 'chainer'], 171 | default='cntk', help='Neural network backend') 172 | arg_parser.add_argument('-e', '--epochs', type=int, default=5, 173 | help='Number of epochs to run.') 174 | arg_parser.add_argument('clients', nargs='*', 175 | default=['127.0.0.1:10000', '127.0.0.1:10001'], 176 | help='Minecraft clients endpoints (ip(:port)?)+') 177 | arg_parser.add_argument('-d', '--device', type=int, default=-1, 178 | help='GPU device on which to run the experiment.') 179 | args = arg_parser.parse_args() 180 | 181 | logdir = path.join('results/pig_chase/dqn', datetime.utcnow().isoformat()) 182 | if 'malmopy.visualization.tensorboard' in sys.modules: 183 | visualizer = TensorboardVisualizer() 184 | visualizer.initialize(logdir, None) 185 | 186 | else: 187 | visualizer = ConsoleVisualizer() 188 | 
189 | agents = [{'name': agent, 'role': role, 'clients': args.clients, 190 | 'backend':args.backend, 'device': args.device, 191 | 'max_epochs': args.epochs, 'logdir': logdir, 'visualizer': visualizer} 192 | for role, agent in enumerate(ENV_AGENT_NAMES)] 193 | 194 | run_experiment(agents) 195 | -------------------------------------------------------------------------------- /ai_challenge/pig_chase/pig_chase_eval_sample.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2017 Microsoft Corporation. 2 | # 3 | # Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated 4 | # documentation files (the "Software"), to deal in the Software without restriction, including without limitation the 5 | # rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, 6 | # and to permit persons to whom the Software is furnished to do so, subject to the following conditions: 7 | # 8 | # The above copyright notice and this permission notice shall be included in all copies or substantial portions of 9 | # the Software. 10 | # 11 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO 12 | # THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 13 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, 14 | # TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 15 | # SOFTWARE. 
16 | # =================================================================================================================== 17 | 18 | from common import ENV_AGENT_NAMES 19 | from evaluation import PigChaseEvaluator 20 | from environment import PigChaseTopDownStateBuilder 21 | from malmopy.agent import RandomAgent 22 | 23 | 24 | if __name__ == '__main__': 25 | # Warn for Agent name !!! 26 | 27 | clients = [('127.0.0.1', 10000), ('127.0.0.1', 10001)] 28 | agent = RandomAgent(ENV_AGENT_NAMES[1], 3) 29 | 30 | eval = PigChaseEvaluator(clients, agent, agent, PigChaseTopDownStateBuilder()) 31 | eval.run() 32 | 33 | eval.save('My Exp 1', 'pig_chase_results.json') 34 | -------------------------------------------------------------------------------- /ai_challenge/pig_chase/pig_chase_human_vs_agent.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2017 Microsoft Corporation. 2 | # 3 | # Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated 4 | # documentation files (the "Software"), to deal in the Software without restriction, including without limitation the 5 | # rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, 6 | # and to permit persons to whom the Software is furnished to do so, subject to the following conditions: 7 | # 8 | # The above copyright notice and this permission notice shall be included in all copies or substantial portions of 9 | # the Software. 10 | # 11 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO 12 | # THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 13 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, 14 | # TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 15 | # SOFTWARE. 16 | # =================================================================================================================== 17 | 18 | import os 19 | import sys 20 | from argparse import ArgumentParser 21 | from datetime import datetime 22 | from multiprocessing import Process, Event 23 | from os import path 24 | from time import sleep 25 | 26 | from malmopy.agent import RandomAgent 27 | from malmopy.agent.gui import ARROW_KEYS_MAPPING 28 | from malmopy.visualization import ConsoleVisualizer 29 | 30 | # Enforce path 31 | sys.path.insert(0, os.getcwd()) 32 | sys.path.insert(1, os.path.join(os.path.pardir, os.getcwd())) 33 | 34 | from common import parse_clients_args, ENV_AGENT_NAMES, ENV_ACTIONS, ENV_AGENT_TYPES 35 | from agent import PigChaseChallengeAgent, PigChaseHumanAgent, TabularQLearnerAgent, FocusedAgent, get_agent_type 36 | from environment import PigChaseEnvironment, PigChaseSymbolicStateBuilder 37 | 38 | MAX_ACTIONS = 25 # this should match the mission definition, used for display only 39 | 40 | def agent_factory(name, role, kind, clients, max_episodes, max_actions, logdir, quit, model_file): 41 | assert len(clients) >= 2, 'There are not enough Malmo clients in the pool (need at least 2)' 42 | 43 | clients = parse_clients_args(clients) 44 | visualizer = ConsoleVisualizer(prefix='Agent %d' % role) 45 | 46 | if role == 0: 47 | env = PigChaseEnvironment(clients, PigChaseSymbolicStateBuilder(), 48 | actions=ENV_ACTIONS, role=role, 49 | human_speed=True, randomize_positions=True) 50 | if kind == 'challenge': 51 | agent = PigChaseChallengeAgent(name) 52 | elif kind == 'astar': 53 | agent = FocusedAgent(name, ENV_TARGET_NAMES[0]) 54 | elif kind == 'tabq': 55 | agent = 
TabularQLearnerAgent(name) 56 | if model_file != '': 57 | agent.load(model_file) 58 | else: 59 | agent = RandomAgent(name ,env.available_actions) 60 | 61 | obs = env.reset(get_agent_type(agent)) 62 | reward = 0 63 | rewards = [] 64 | done = False 65 | episode = 0 66 | 67 | while True: 68 | 69 | # select an action 70 | action = agent.act(obs, reward, done, True) 71 | 72 | if done: 73 | visualizer << (episode + 1, 'Reward', sum(rewards)) 74 | rewards = [] 75 | episode += 1 76 | 77 | obs = env.reset(get_agent_type(agent)) 78 | 79 | # take a step 80 | obs, reward, done = env.do(action) 81 | rewards.append(reward) 82 | 83 | else: 84 | env = PigChaseEnvironment(clients, PigChaseSymbolicStateBuilder(), 85 | actions=list(ARROW_KEYS_MAPPING.values()), 86 | role=role, randomize_positions=True) 87 | env.reset(ENV_AGENT_TYPES.HUMAN) 88 | 89 | agent = PigChaseHumanAgent(name, env, list(ARROW_KEYS_MAPPING.keys()), 90 | max_episodes, max_actions, visualizer, quit) 91 | agent.show() 92 | 93 | 94 | def run_mission(agents_def): 95 | assert len(agents_def) == 2, 'Incompatible number of agents (required: 2, got: %d)' % len(agents_def) 96 | quit = Event() 97 | processes = [] 98 | for agent in agents_def: 99 | agent['quit'] = quit 100 | p = Process(target=agent_factory, kwargs=agent) 101 | p.daemon = True 102 | p.start() 103 | 104 | if agent['role'] == 0: 105 | sleep(1) # Just to let time for the server to start 106 | 107 | processes.append(p) 108 | quit.wait() 109 | for process in processes: 110 | process.terminate() 111 | 112 | 113 | if __name__ == '__main__': 114 | arg_parser = ArgumentParser() 115 | arg_parser.add_argument('-e', '--episodes', type=int, default=10, help='Number of episodes to run.') 116 | arg_parser.add_argument('-k', '--kind', type=str, default='challenge', choices=['astar', 'random', 'tabq', 'challenge'], 117 | help='The kind of agent to play with (random, astar, tabq or challenge).') 118 | arg_parser.add_argument('-m', '--model_file', type=str, default='', 
help='Model file with which to initialise agent, if appropriate') 119 | arg_parser.add_argument('clients', nargs='*', 120 | default=['127.0.0.1:10000', '127.0.0.1:10001'], 121 | help='Malmo clients (ip(:port)?)+') 122 | args = arg_parser.parse_args() 123 | 124 | logdir = path.join('results/pig-human', datetime.utcnow().isoformat()) 125 | agents = [{'name': agent, 'role': role, 'kind': args.kind, 'model_file': args.model_file, 126 | 'clients': args.clients, 'max_episodes': args.episodes, 127 | 'max_actions': MAX_ACTIONS, 'logdir': logdir} 128 | for role, agent in enumerate(ENV_AGENT_NAMES)] 129 | 130 | run_mission(agents) 131 | -------------------------------------------------------------------------------- /docker/README.md: -------------------------------------------------------------------------------- 1 | 2 | 3 | # Run experiments in docker 4 | 5 | [Docker](https://www.docker.com/) is a container solution that makes it easy to build and deploy 6 | software in a virtual environment. The examples in this folder use docker to easily deploy an experiment 7 | with all its dependencies, either on a local machine or on the cloud. 8 | 9 | ## Prerequisites 10 | 11 | Install docker on your local machine by following the installation instructions for 12 | [Windows](https://docs.docker.com/docker-for-windows/install/), 13 | [Linux](https://docs.docker.com/engine/installation/), 14 | [MacOS](https://docs.docker.com/docker-for-mac/install/). 15 | 16 | Prepare a docker machine on Azure, follow the local installation steps above, then run: 17 | ``` 18 | docker-machine create --driver azure --azure-size Standard_D12 --azure-subscription-id 19 | ``` 20 | Replace `` with your Azure subsciption id - you can find this on the Azure dashboard after 21 | logging on to https://portal.azure.com. The `` is arbitrary. 
22 | 23 | Additional `docker-machine` options are listed here: https://docs.docker.com/machine/drivers/azure/ 24 | Azure machine sizes are detailed on: https://docs.microsoft.com/en-us/azure/virtual-machines/virtual-machines-linux-sizes (we recommend using at least size Standard_D12) 25 | 26 | Configure docker to deploy to ``. Run: 27 | ``` 28 | docker-machine env 29 | ``` 30 | This will provide a script / instructions on how to prepare your environment to work with . 31 | 32 | ## Build the docker images 33 | 34 | Build the required docker images: 35 | ``` 36 | cd docker 37 | docker build malmo -t malmo:latest 38 | docker build malmopy-cntk-cpu-py27 -t malmopy-cntk-cpu-py27:latest 39 | 40 | ``` 41 | 42 | Check to make sure that the images have been built: 43 | ``` 44 | docker images 45 | ``` 46 | You should see a list that includes the built images, e.g., 47 | ``` 48 | REPOSITORY TAG IMAGE ID CREATED SIZE 49 | malmopy-cntk-cpu-py27 latest 0161af81632d 29 minutes ago 5.62 GB 50 | malmo latest 1b67b8e2cfa8 41 minutes ago 1.04 GB 51 | ... 52 | ``` 53 | 54 | ## Run the experiment 55 | 56 | Run the challenge task with an example agent: 57 | ``` 58 | cd malmopy-ai-challenge 59 | docker-compose up 60 | ``` 61 | 62 | The experiment is set up to start a tensorboard process alongside the experiment. 63 | You can view it by pointing your browser to http://127.0.0.1:6006. 64 | 65 | ## Write your own 66 | 67 | The provided docker files load malmopy and sample code directly from the 68 | `malmo-challenge` git repository. To include your own code, create a file 69 | called `Dockerfile` with the following content: 70 | 71 | ``` 72 | FROM malmopy-cntk-cpu-py27:latest 73 | 74 | # add your own experiment code here 75 | # ADD copies content from your local machine into the docker image 76 | ADD ai_challenge/pig_chase /root/malmo-challenge/ai_challenge/pig_chase 77 | ``` 78 | 79 | Build this new image using: 80 | ``` 81 | docker build .
-t my_malmo_experiment:latest 82 | ``` 83 | 84 | Point the `agents` service in `docker-compose.yml` to the new image by replacing 85 | `image: malmopy-cntk-cpu-py27:latest` with the name of the image you have just 86 | built (e.g., `image: my_malmo_experiment:latest`). Also check whether the working 87 | directory or command needs to be changed. 88 | 89 | Then run the new experiment: 90 | ``` 91 | docker-compose up 92 | ``` 93 | 94 | ## Cleaning up 95 | 96 | If you are using a docker machine on Azure, make sure to shut down and decommission 97 | the machine when your experiments have completed, to avoid incurring costs. 98 | 99 | To shut a machine down: 100 | ``` 101 | docker-machine stop 102 | ``` 103 | 104 | To remove (decommission) a machine: 105 | ``` 106 | docker-machine rm 107 | ``` 108 | 109 | ## Further reading 110 | 111 | - [docker documentation](https://docs.docker.com/) 112 | - [docker on Azure](https://docs.docker.com/machine/drivers/azure/) 113 | - [docker compose](https://docs.docker.com/compose/overview/) 114 | -------------------------------------------------------------------------------- /docker/malmo/Dockerfile: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2017 Microsoft Corporation. 2 | # 3 | # Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated 4 | # documentation files (the "Software"), to deal in the Software without restriction, including without limitation the 5 | # rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, 6 | # and to permit persons to whom the Software is furnished to do so, subject to the following conditions: 7 | # 8 | # The above copyright notice and this permission notice shall be included in all copies or substantial portions of 9 | # the Software.
10 | # 11 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO 12 | # THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 13 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, 14 | # TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 15 | # SOFTWARE. 16 | # =================================================================================================================== 17 | 18 | FROM ubuntu:16.04 19 | 20 | ENV MALMO_VERSION 0.21.0 21 | 22 | # Install Malmo dependencies 23 | RUN apt-get update && apt-get install -y --no-install-recommends \ 24 | openjdk-8-jdk \ 25 | libxerces-c3.1 \ 26 | libav-tools \ 27 | wget \ 28 | unzip \ 29 | xvfb && \ 30 | rm -rf /var/lib/apt/lists/* /var/cache/apt/archives/* 31 | 32 | # Download and unpack Malmo 33 | WORKDIR /root 34 | RUN wget https://github.com/Microsoft/malmo/releases/download/$MALMO_VERSION/Malmo-$MALMO_VERSION-Linux-Ubuntu-16.04-64bit_withBoost.zip && \ 35 | unzip Malmo-$MALMO_VERSION-Linux-Ubuntu-16.04-64bit_withBoost.zip && \ 36 | rm Malmo-$MALMO_VERSION-Linux-Ubuntu-16.04-64bit_withBoost.zip && \ 37 | mv Malmo-$MALMO_VERSION-Linux-Ubuntu-16.04-64bit_withBoost Malmo 38 | ENV MALMO_XSD_PATH /root/Malmo/Schemas 39 | 40 | # Precompile Malmo mod 41 | RUN mkdir ~/.gradle && echo 'org.gradle.daemon=true\n' > ~/.gradle/gradle.properties 42 | WORKDIR /root/Malmo/Minecraft 43 | RUN ./gradlew setupDecompWorkspace 44 | RUN ./gradlew build 45 | 46 | # Unlimited framerate settings 47 | COPY options.txt /root/Malmo/Minecraft/run 48 | 49 | COPY run.sh /root/ 50 | RUN chmod +x /root/run.sh 51 | 52 | # Expose port 53 | EXPOSE 10000 54 | 55 | # Run Malmo 56 | ENTRYPOINT ["/root/run.sh", "/root/Malmo/Minecraft/launchClient.sh"] 
-------------------------------------------------------------------------------- /docker/malmo/options.txt: -------------------------------------------------------------------------------- 1 | invertYMouse:false 2 | mouseSensitivity:0.5 3 | fov:0.0 4 | gamma:0.0 5 | saturation:0.0 6 | renderDistance:6 7 | guiScale:0 8 | particles:2 9 | bobView:true 10 | anaglyph3d:false 11 | maxFps:200 12 | fboEnable:true 13 | difficulty:2 14 | fancyGraphics:false 15 | ao:0 16 | renderClouds:true 17 | resourcePacks:[] 18 | lastServer: 19 | lang:en_US 20 | chatVisibility:0 21 | chatColors:true 22 | chatLinks:true 23 | chatLinksPrompt:true 24 | chatOpacity:1.0 25 | snooperEnabled:true 26 | fullscreen:false 27 | enableVsync:true 28 | useVbo:false 29 | hideServerAddress:false 30 | advancedItemTooltips:false 31 | pauseOnLostFocus:false 32 | touchscreen:false 33 | overrideWidth:0 34 | overrideHeight:0 35 | heldItemTooltips:true 36 | chatHeightFocused:1.0 37 | chatHeightUnfocused:0.44366196 38 | chatScale:1.0 39 | chatWidth:1.0 40 | showInventoryAchievementHint:false 41 | mipmapLevels:4 42 | streamBytesPerPixel:0.5 43 | streamMicVolume:1.0 44 | streamSystemVolume:1.0 45 | streamKbps:0.5412844 46 | streamFps:0.31690142 47 | streamCompression:1 48 | streamSendMetadata:true 49 | streamPreferredServer: 50 | streamChatEnabled:0 51 | streamChatUserFilter:0 52 | streamMicToggleBehavior:0 53 | forceUnicodeFont:false 54 | allowBlockAlternatives:true 55 | reducedDebugInfo:false 56 | key_key.attack:-100 57 | key_key.use:-99 58 | key_key.forward:17 59 | key_key.left:30 60 | key_key.back:31 61 | key_key.right:32 62 | key_key.jump:57 63 | key_key.sneak:42 64 | key_key.drop:16 65 | key_key.inventory:18 66 | key_key.chat:20 67 | key_key.playerlist:15 68 | key_key.pickItem:-98 69 | key_key.command:53 70 | key_key.screenshot:60 71 | key_key.togglePerspective:63 72 | key_key.smoothCamera:0 73 | key_key.sprint:29 74 | key_key.streamStartStop:64 75 | key_key.streamPauseUnpause:65 76 | 
key_key.streamCommercial:0 77 | key_key.streamToggleMic:0 78 | key_key.fullscreen:87 79 | key_key.spectatorOutlines:0 80 | key_key.hotbar.1:2 81 | key_key.hotbar.2:3 82 | key_key.hotbar.3:4 83 | key_key.hotbar.4:5 84 | key_key.hotbar.5:6 85 | key_key.hotbar.6:7 86 | key_key.hotbar.7:8 87 | key_key.hotbar.8:9 88 | key_key.hotbar.9:10 89 | key_key.toggleMalmo:28 90 | key_key.handyTestHook:22 91 | soundCategory_master:0.0 92 | soundCategory_music:1.0 93 | soundCategory_record:1.0 94 | soundCategory_weather:1.0 95 | soundCategory_block:1.0 96 | soundCategory_hostile:1.0 97 | soundCategory_neutral:1.0 98 | soundCategory_player:1.0 99 | soundCategory_ambient:1.0 100 | modelPart_cape:true 101 | modelPart_jacket:true 102 | modelPart_left_sleeve:true 103 | modelPart_right_sleeve:true 104 | modelPart_left_pants_leg:true 105 | modelPart_right_pants_leg:true 106 | modelPart_hat:true 107 | -------------------------------------------------------------------------------- /docker/malmo/run.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | 3 | xvfb-run -a -e /dev/stdout -s '-screen 0 1400x900x24' $* -------------------------------------------------------------------------------- /docker/malmopy-ai-challenge/docker-compose.yml: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2017 Microsoft Corporation. 
2 | # 3 | # Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated 4 | # documentation files (the "Software"), to deal in the Software without restriction, including without limitation the 5 | # rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, 6 | # and to permit persons to whom the Software is furnished to do so, subject to the following conditions: 7 | # 8 | # The above copyright notice and this permission notice shall be included in all copies or substantial portions of 9 | # the Software. 10 | # 11 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO 12 | # THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 13 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, 14 | # TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 15 | # SOFTWARE. 16 | # =================================================================================================================== 17 | 18 | version: '3' 19 | services: 20 | malmo1: 21 | image: malmo:latest 22 | expose: 23 | - "10000" 24 | malmo2: 25 | image: malmo:latest 26 | expose: 27 | - "10000" 28 | agents: 29 | image: malmopy-cntk-cpu-py27:latest 30 | working_dir: /root/malmo-challenge/ai_challenge/pig_chase 31 | command: bash -c "python pig_chase_baseline.py malmo1:10000 malmo2:10000 & tensorboard --logdir 'results' --port 6006" 32 | ports: 33 | - "6006:6006" 34 | links: 35 | - malmo1 36 | - malmo2 37 | -------------------------------------------------------------------------------- /docker/malmopy-chainer-cpu/Dockerfile: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2017 Microsoft Corporation. 
2 | # 3 | # Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated 4 | # documentation files (the "Software"), to deal in the Software without restriction, including without limitation the 5 | # rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, 6 | # and to permit persons to whom the Software is furnished to do so, subject to the following conditions: 7 | # 8 | # The above copyright notice and this permission notice shall be included in all copies or substantial portions of 9 | # the Software. 10 | # 11 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO 12 | # THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 13 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, 14 | # TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 15 | # SOFTWARE. 
16 | # =================================================================================================================== 17 | 18 | FROM ubuntu:16.04 19 | 20 | # Version variables 21 | ENV MALMO_VERSION 0.21.0 22 | ENV MALMOPY_VERSION 0.1.0 23 | 24 | RUN apt-get update -y && \ 25 | apt-get install -y --no-install-recommends \ 26 | build-essential \ 27 | python-dev \ 28 | python-pip \ 29 | python-setuptools \ 30 | cmake \ 31 | ssh \ 32 | git-all \ 33 | zlib1g-dev \ 34 | 35 | # install Malmo dependencies 36 | libpython2.7 \ 37 | lua5.1 \ 38 | libxerces-c3.1 \ 39 | liblua5.1-0-dev \ 40 | libav-tools \ 41 | python-tk \ 42 | python-imaging-tk \ 43 | wget \ 44 | unzip && \ 45 | rm -rf /var/lib/apt/lists/* /var/cache/apt/archives/* 46 | 47 | RUN pip install -U pip setuptools && pip install wheel && pip install chainer==1.21.0 48 | 49 | # download and unpack Malmo 50 | WORKDIR /root 51 | RUN wget https://github.com/Microsoft/malmo/releases/download/$MALMO_VERSION/Malmo-$MALMO_VERSION-Linux-Ubuntu-16.04-64bit_withBoost.zip && \ 52 | unzip Malmo-$MALMO_VERSION-Linux-Ubuntu-16.04-64bit_withBoost.zip && \ 53 | rm Malmo-$MALMO_VERSION-Linux-Ubuntu-16.04-64bit_withBoost.zip && \ 54 | mv Malmo-$MALMO_VERSION-Linux-Ubuntu-16.04-64bit_withBoost Malmo 55 | 56 | ENV MALMO_XSD_PATH /root/Malmo/Schemas 57 | ENV PYTHONPATH /root/Malmo/Python_Examples 58 | 59 | # add and install malmopy, malmo challenge task and samples 60 | WORKDIR /root 61 | RUN git clone https://github.com/Microsoft/malmo-challenge.git && \ 62 | cd malmo-challenge && \ 63 | git checkout tags/$MALMOPY_VERSION -b latest 64 | WORKDIR /root/malmo-challenge 65 | RUN pip install -e '.[all]' 66 | -------------------------------------------------------------------------------- /docker/malmopy-chainer-gpu/Dockerfile: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2017 Microsoft Corporation. 
2 | # 3 | # Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated 4 | # documentation files (the "Software"), to deal in the Software without restriction, including without limitation the 5 | # rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, 6 | # and to permit persons to whom the Software is furnished to do so, subject to the following conditions: 7 | # 8 | # The above copyright notice and this permission notice shall be included in all copies or substantial portions of 9 | # the Software. 10 | # 11 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO 12 | # THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 13 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, 14 | # TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 15 | # SOFTWARE. 
16 | # =================================================================================================================== 17 | 18 | FROM nvidia/cuda:8.0-cudnn5-devel-ubuntu16.04 19 | 20 | 21 | # Version variables 22 | ENV MALMO_VERSION 0.21.0 23 | ENV MALMOPY_VERSION 0.1.0 24 | 25 | RUN apt-get update -y && \ 26 | apt-get install -y --no-install-recommends \ 27 | build-essential \ 28 | python-dev \ 29 | python-pip \ 30 | python-setuptools \ 31 | cmake \ 32 | ssh \ 33 | git-all \ 34 | zlib1g-dev \ 35 | 36 | # install Malmo dependencies 37 | libpython2.7 \ 38 | lua5.1 \ 39 | libxerces-c3.1 \ 40 | liblua5.1-0-dev \ 41 | libav-tools \ 42 | python-tk \ 43 | python-imaging-tk \ 44 | wget \ 45 | unzip && \ 46 | rm -rf /var/lib/apt/lists/* /var/cache/apt/archives/* 47 | 48 | RUN pip install -U pip setuptools && pip install wheel && pip install chainer==1.21.0 49 | 50 | # download and unpack Malmo 51 | WORKDIR /root 52 | RUN wget https://github.com/Microsoft/malmo/releases/download/$MALMO_VERSION/Malmo-$MALMO_VERSION-Linux-Ubuntu-16.04-64bit_withBoost.zip && \ 53 | unzip Malmo-$MALMO_VERSION-Linux-Ubuntu-16.04-64bit_withBoost.zip && \ 54 | rm Malmo-$MALMO_VERSION-Linux-Ubuntu-16.04-64bit_withBoost.zip && \ 55 | mv Malmo-$MALMO_VERSION-Linux-Ubuntu-16.04-64bit_withBoost Malmo 56 | 57 | ENV MALMO_XSD_PATH /root/Malmo/Schemas 58 | ENV PYTHONPATH /root/Malmo/Python_Examples 59 | 60 | # add and install malmopy, malmo challenge task and samples 61 | WORKDIR /root 62 | RUN git clone https://github.com/Microsoft/malmo-challenge.git && \ 63 | cd malmo-challenge && \ 64 | git checkout tags/$MALMOPY_VERSION -b latest 65 | WORKDIR /root/malmo-challenge 66 | RUN pip install -e '.[all]' 67 | -------------------------------------------------------------------------------- /docker/malmopy-cntk-cpu-py27/Dockerfile: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2017 Microsoft Corporation. 
2 | # 3 | # Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated 4 | # documentation files (the "Software"), to deal in the Software without restriction, including without limitation the 5 | # rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, 6 | # and to permit persons to whom the Software is furnished to do so, subject to the following conditions: 7 | # 8 | # The above copyright notice and this permission notice shall be included in all copies or substantial portions of 9 | # the Software. 10 | # 11 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO 12 | # THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 13 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, 14 | # TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 15 | # SOFTWARE. 
16 | # =================================================================================================================== 17 | 18 | FROM microsoft/cntk:2.0.beta15.0-cpu-python2.7 19 | 20 | # Version variables 21 | ENV MALMO_VERSION 0.21.0 22 | ENV MALMOPY_VERSION 0.1.0 23 | 24 | RUN apt-get update -y && \ 25 | apt-get install -y --no-install-recommends \ 26 | build-essential \ 27 | cmake \ 28 | ssh \ 29 | git-all \ 30 | zlib1g-dev \ 31 | python-dev \ 32 | python-pip \ 33 | 34 | # install Malmo dependencies 35 | libpython2.7 \ 36 | openjdk-7-jdk \ 37 | lua5.1 \ 38 | libxerces-c3.1 \ 39 | liblua5.1-0-dev \ 40 | libav-tools \ 41 | python-tk \ 42 | python-imaging-tk \ 43 | wget \ 44 | unzip && \ 45 | rm -rf /var/lib/apt/lists/* /var/cache/apt/archives/* 46 | 47 | # Set CNTK Python PATH at first position to be picked automatically 48 | ENV PATH=/root/anaconda3/envs/cntk-py27/bin:$PATH 49 | 50 | # Update pip 51 | RUN /root/anaconda3/envs/cntk-py27/bin/pip install --upgrade pip 52 | 53 | # download and unpack Malmo 54 | WORKDIR /root 55 | RUN wget https://github.com/Microsoft/malmo/releases/download/$MALMO_VERSION/Malmo-$MALMO_VERSION-Linux-Ubuntu-16.04-64bit_withBoost.zip && \ 56 | unzip Malmo-$MALMO_VERSION-Linux-Ubuntu-16.04-64bit_withBoost.zip && \ 57 | rm Malmo-$MALMO_VERSION-Linux-Ubuntu-16.04-64bit_withBoost.zip && \ 58 | mv Malmo-$MALMO_VERSION-Linux-Ubuntu-16.04-64bit_withBoost Malmo 59 | 60 | ENV MALMO_XSD_PATH /root/Malmo/Schemas 61 | ENV PYTHONPATH /root/Malmo/Python_Examples 62 | 63 | # add and install malmopy, malmo challenge task and samples 64 | WORKDIR /root 65 | RUN git clone https://github.com/Microsoft/malmo-challenge.git && \ 66 | cd malmo-challenge && \ 67 | git checkout tags/$MALMOPY_VERSION -b latest 68 | WORKDIR /root/malmo-challenge 69 | RUN pip install -e '.[all]' 70 | -------------------------------------------------------------------------------- /docker/malmopy-cntk-gpu-py27/Dockerfile: 
-------------------------------------------------------------------------------- 1 | # Copyright (c) 2017 Microsoft Corporation. 2 | # 3 | # Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated 4 | # documentation files (the "Software"), to deal in the Software without restriction, including without limitation the 5 | # rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, 6 | # and to permit persons to whom the Software is furnished to do so, subject to the following conditions: 7 | # 8 | # The above copyright notice and this permission notice shall be included in all copies or substantial portions of 9 | # the Software. 10 | # 11 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO 12 | # THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 13 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, 14 | # TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 15 | # SOFTWARE. 
16 | # =================================================================================================================== 17 | 18 | FROM microsoft/cntk:2.0.beta15.0-gpu-python2.7-cuda8.0-cudnn5.1 19 | 20 | # Version variables 21 | ENV MALMO_VERSION 0.21.0 22 | ENV MALMOPY_VERSION 0.1.0 23 | 24 | RUN apt-get update -y && \ 25 | apt-get install -y --no-install-recommends \ 26 | build-essential \ 27 | cmake \ 28 | ssh \ 29 | git-all \ 30 | zlib1g-dev \ 31 | python-dev \ 32 | python-pip \ 33 | 34 | # install Malmo dependencies 35 | libpython2.7 \ 36 | openjdk-7-jdk \ 37 | lua5.1 \ 38 | libxerces-c3.1 \ 39 | liblua5.1-0-dev \ 40 | libav-tools \ 41 | python-tk \ 42 | python-imaging-tk \ 43 | wget \ 44 | unzip && \ 45 | rm -rf /var/lib/apt/lists/* /var/cache/apt/archives/* 46 | 47 | # Set CNTK Python PATH at first position to be picked automatically 48 | ENV PATH=/root/anaconda3/envs/cntk-py27/bin:$PATH 49 | 50 | # Update pip 51 | RUN /root/anaconda3/envs/cntk-py27/bin/pip install --upgrade pip 52 | 53 | # download and unpack Malmo 54 | WORKDIR /root 55 | RUN wget https://github.com/Microsoft/malmo/releases/download/$MALMO_VERSION/Malmo-$MALMO_VERSION-Linux-Ubuntu-16.04-64bit_withBoost.zip && \ 56 | unzip Malmo-$MALMO_VERSION-Linux-Ubuntu-16.04-64bit_withBoost.zip && \ 57 | rm Malmo-$MALMO_VERSION-Linux-Ubuntu-16.04-64bit_withBoost.zip && \ 58 | mv Malmo-$MALMO_VERSION-Linux-Ubuntu-16.04-64bit_withBoost Malmo 59 | 60 | ENV MALMO_XSD_PATH /root/Malmo/Schemas 61 | ENV PYTHONPATH /root/Malmo/Python_Examples 62 | 63 | # add and install malmopy, malmo challenge task and samples 64 | WORKDIR /root 65 | RUN git clone https://github.com/Microsoft/malmo-challenge.git && \ 66 | cd malmo-challenge && \ 67 | git checkout tags/$MALMOPY_VERSION -b latest 68 | WORKDIR /root/malmo-challenge 69 | RUN pip install -e '.[all]' 70 | -------------------------------------------------------------------------------- /malmopy/README.md: 
-------------------------------------------------------------------------------- 1 | ## Writing your first experiment 2 | 3 | The framework is designed to give you the flexibility you need to design and run your experiment. 4 | In this section you will see how easy it is to write a simple Atari/DQN experiment based on the CNTK backend. 5 | 6 | 7 | ### Using the Microsoft Cognitive Toolkit (CNTK) 8 | To be able to use CNTK from the framework, you will first need to install CNTK from the 9 | official repository [release page](https://github.com/Microsoft/CNTK/releases). Pick the 10 | right distribution according to your OS / hardware configuration and whether you plan to use distributed 11 | training sessions. 12 | 13 | The CNTK Python binding can be installed by running the installation script 14 | ([more information here](https://github.com/Microsoft/CNTK/wiki/Setup-CNTK-on-your-machine)). 15 | After following the installation process you should be able to import CNTK. 16 | 17 | ___Note that every time you want to run an experiment with CNTK, you will need to activate the cntk-pyXX environment.___ 18 | 19 | ### Getting started 20 | 21 | First of all, you need to import all the dependencies: 22 | ```python 23 | from malmopy.agent.qlearner import QLearnerAgent, TemporalMemory 24 | from malmopy.model.cntk import QNeuralNetwork 25 | from malmopy.environment.gym import GymEnvironment 26 | 27 | # In this example we will use the Breakout-v3 environment. 28 | env = GymEnvironment('Breakout-v3', monitoring_path='/directory/where/to/put/records') 29 | 30 | # Q Neural Network needs a Replay Memory to randomly sample minibatch.
31 | memory = TemporalMemory(1000000, (84, 84), 4) 32 | 33 | # Here, a simple Deep Q Neural Network backed by the CNTK runtime 34 | model = QNeuralNetwork((4, 84, 84), env.available_actions, device_id=-1) 35 | 36 | # We provide the number of available actions, our model and the memory 37 | agent = QLearnerAgent("DQN Agent", env.available_actions, model, memory, 0.99, 32) 38 | 39 | reward = 0 40 | done = False 41 | 42 | # Run the agent / environment interaction loop indefinitely 43 | while True: 44 | 45 | # Reset environment if needed 46 | if env.done: 47 | current_state = env.reset() 48 | 49 | action = agent.act(current_state, reward, done, True) 50 | current_state, reward, done = env.do(action) 51 | ``` 52 | 53 | ## Some comments: 54 | - The GymEnvironment monitoring_path is used to record short episode videos of the agent 55 | - Temporal Memory generates a sample w.r.t. the history_length previous states 56 | - For example with history_length = 4 a sample is [s(t-3), s(t-2), s(t-1), s(t)] 57 | - QNeuralNetwork input_shape is the shape of a sample from the TemporalMemory (history_length, width, height) 58 | - QNeuralNetwork output_shape is the number of actions available for the environment (one neuron per action) 59 | - QNeuralNetwork device_id == -1 indicates 'Run on CPU', anything >=0 refers to a GPU device ID 60 | -------------------------------------------------------------------------------- /malmopy/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2017 Microsoft Corporation.
2 | # 3 | # Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated 4 | # documentation files (the "Software"), to deal in the Software without restriction, including without limitation the 5 | # rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, 6 | # and to permit persons to whom the Software is furnished to do so, subject to the following conditions: 7 | # 8 | # The above copyright notice and this permission notice shall be included in all copies or substantial portions of 9 | # the Software. 10 | # 11 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO 12 | # THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 13 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, 14 | # TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 15 | # SOFTWARE. 16 | # =================================================================================================================== 17 | -------------------------------------------------------------------------------- /malmopy/agent/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2017 Microsoft Corporation. 
2 | # 3 | # Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated 4 | # documentation files (the "Software"), to deal in the Software without restriction, including without limitation the 5 | # rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, 6 | # and to permit persons to whom the Software is furnished to do so, subject to the following conditions: 7 | # 8 | # The above copyright notice and this permission notice shall be included in all copies or substantial portions of 9 | # the Software. 10 | # 11 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO 12 | # THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 13 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, 14 | # TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 15 | # SOFTWARE. 16 | # =================================================================================================================== 17 | 18 | from __future__ import absolute_import 19 | 20 | from .agent import BaseAgent, RandomAgent, ConsoleAgent, ReplayMemory 21 | from .astar import AStarAgent 22 | from .explorer import BaseExplorer, LinearEpsilonGreedyExplorer 23 | from .qlearner import QLearnerAgent, History, ReplayMemory, TemporalMemory 24 | 25 | __all__ = ['agent', 'astar', 'qlearner', 'explorer'] 26 | -------------------------------------------------------------------------------- /malmopy/agent/agent.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2017 Microsoft Corporation. 
2 | # 3 | # Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated 4 | # documentation files (the "Software"), to deal in the Software without restriction, including without limitation the 5 | # rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, 6 | # and to permit persons to whom the Software is furnished to do so, subject to the following conditions: 7 | # 8 | # The above copyright notice and this permission notice shall be included in all copies or substantial portions of 9 | # the Software. 10 | # 11 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO 12 | # THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 13 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, 14 | # TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 15 | # SOFTWARE. 
# ===================================================================================================================

from __future__ import absolute_import

import os
import sys

try:
    # The container ABCs live in collections.abc since Python 3.3; the
    # collections.Iterable alias was removed in Python 3.10.
    from collections.abc import Iterable
except ImportError:  # Python 2.7
    from collections import Iterable

import numpy as np

from ..visualization import Visualizable


class BaseAgent(Visualizable):
    """
    Represents an agent that interacts with an environment.

    Subclasses must implement act(); save(), load() and inject_summaries()
    are no-op hooks that subclasses may override.
    """

    def __init__(self, name, nb_actions, visualizer=None):
        """
        :param name: Name of the agent
        :param nb_actions: Number of discrete actions available (must be > 0)
        :param visualizer: Optional visualizer forwarded to Visualizable
        """
        assert nb_actions > 0, 'Agent should at least have 1 action (got %d)' % nb_actions

        super(BaseAgent, self).__init__(visualizer)

        self.name = name
        self.nb_actions = nb_actions

    def act(self, new_state, reward, done, is_training=False):
        """
        Select the next action.
        :param new_state: Latest observation
        :param reward: Reward received for the previously selected action
        :param done: True if the previously selected action ended the episode
        :param is_training: True when the agent may learn from this step
        :return: Index of the selected action
        """
        raise NotImplementedError()

    def save(self, out_dir):
        """ Persist the agent into out_dir (no-op by default) """
        pass

    def load(self, out_dir):
        """ Restore the agent from out_dir (no-op by default) """
        pass

    def inject_summaries(self, idx):
        """ Push accumulated statistics to the visualizer at step idx (no-op by default) """
        pass


class RandomAgent(BaseAgent):
    """
    An agent that selects actions uniformly at random
    """

    def __init__(self, name, nb_actions, delay_between_action=0, visualizer=None):
        """
        :param delay_between_action: Seconds to sleep before every action (<= 0 disables the delay)
        """
        super(RandomAgent, self).__init__(name, nb_actions, visualizer)

        self._delay = delay_between_action

    def act(self, new_state, reward, done, is_training=False):
        """ Return an action drawn uniformly from [0, nb_actions), after the optional delay """
        if self._delay > 0:
            from time import sleep
            sleep(self._delay)

        return np.random.randint(0, self.nb_actions)


class ConsoleAgent(BaseAgent):
    """ Provide a console interface for mediating human interaction with
    an environment

    Users are prompted for input when an action is required. Actions are
    numbered from 0:

        Agent-1 What do you want to do?
            0 : action0
            1 : action1
            ...
            N-1 : actionN-1
        Agent-1: 1
        ...
    """

    def __init__(self, name, actions, stdin=None):
        """
        :param name: Agent name, also used as the input prompt
        :param actions: Non-empty sized iterable (e.g. list, tuple) of action labels
        :param stdin: Optional file descriptor re-opened as the process-wide sys.stdin
        """
        assert isinstance(actions, Iterable), 'actions need to be iterable (e.g., list, tuple)'
        assert len(actions) > 0, 'actions need at least one element'

        super(ConsoleAgent, self).__init__(name, len(actions))

        self._actions = actions

        if stdin is not None:
            sys.stdin = os.fdopen(stdin)

    def act(self, new_state, reward, done, is_training=False):
        """
        Prompt the user until a valid action index is typed.
        :return: Integer in [0, len(actions) - 1]
        """
        # On Python 2, input() evaluates the typed text, so a non-numeric entry
        # raises NameError/SyntaxError instead of ValueError. raw_input() returns
        # the text verbatim; it only exists on Python 2.
        try:
            read_line = raw_input  # noqa: F821
        except NameError:
            read_line = input

        user_action = None

        while user_action is None:
            self._print_choices()
            try:
                user_input = read_line("%s: " % self.name)
                user_action = int(user_input)
                if user_action < 0 or user_action > len(self._actions) - 1:
                    user_action = None
                    print("Provided input is not valid should be [0, %d]" % (len(self._actions) - 1))
            except ValueError:
                user_action = None
                print("Provided input is not valid should be [0, %d]" % (len(self._actions) - 1))

        return user_action

    def _print_choices(self):
        """ Print the 0-based list of available actions """
        print("\n%s What do you want to do?" % self.name)

        for idx, action in enumerate(self._actions):
            print("\t%d : %s" % (idx, action))


class ReplayMemory(object):
    """
    Simple representation of agent memory.

    A fixed-capacity circular buffer of (state, action, reward, terminal)
    entries, stored in pre-allocated numpy arrays. Once max_size entries have
    been appended, new entries overwrite the oldest ones.
    """

    def __init__(self, max_size, state_shape):
        """
        :param max_size: Capacity of the memory (> 0)
        :param state_shape: Shape (tuple) of every state passed to append()
        """
        assert max_size > 0, 'size should be > 0 (got %d)' % max_size

        self._pos = 0  # Slot written by the next append()
        self._count = 0  # Number of slots filled so far (<= max_size)
        self._max_size = max_size
        self._state_shape = state_shape
        self._states = np.empty((max_size,) + state_shape, dtype=np.float32)
        # NOTE(review): actions are stored as uint8, so indices > 255 would wrap around
        self._actions = np.empty(max_size, dtype=np.uint8)
        self._rewards = np.empty(max_size, dtype=np.float32)
        # Builtin bool: the np.bool alias was deprecated in NumPy 1.20 and removed in 1.24
        self._terminals = np.empty(max_size, dtype=bool)

    def append(self, state, action, reward, is_terminal):
        """
        Appends the specified memory to the history, overwriting the oldest entry when full.
        :param state: The state to append (should have the same shape as defined at initialization time)
        :param action: An integer representing the action done
        :param reward: A number representing the reward received for doing this action (stored as float32)
        :param is_terminal: A boolean specifying if this state is a terminal (episode has finished)
        :return:
        """
        assert state.shape == self._state_shape, \
            'Invalid state shape (required: %s, got: %s)' % (self._state_shape, state.shape)

        self._states[self._pos, ...] = state
        self._actions[self._pos] = action
        self._rewards[self._pos] = reward
        self._terminals[self._pos] = is_terminal

        self._count = max(self._count, self._pos + 1)
        self._pos = (self._pos + 1) % self._max_size

    def __len__(self):
        """
        Number of elements currently stored in the memory (same as #size())
        See #size()
        :return: Integer : max_size >= size() >= 0
        """
        return self.size

    @property
    def last(self):
        """
        Return the most recently appended observation from the memory
        :return: Tuple (state, action, reward, terminal)
        """
        # self._pos is the slot the *next* append() writes to; the last appended
        # entry sits one slot before it, wrapping around the circular buffer.
        idx = (self._pos - 1) % self._max_size
        return self._states[idx], self._actions[idx], self._rewards[idx], self._terminals[idx]

    @property
    def size(self):
        """
        Number of elements currently stored in the memory
        :return: Integer : max_size >= size >= 0
        """
        return self._count

    @property
    def max_size(self):
        """
        Maximum number of elements that can fit in the memory
        :return: Integer > 0
        """
        return self._max_size

    @property
    def history_length(self):
        """
        Number of states stacked along the first axis
        :return: int >= 1
        """
        return 1

    def sample(self, size, replace=False):
        """
        Generate a random sample of desired size (if available) from the current memory
        :param size: Number of samples
        :param replace: True if sampling with replacement
        :return: Integer[size] representing the sampled indices
        """
        return np.random.choice(self._count, size, replace=replace)

    def get_state(self, index):
        """
        Return the specified state
        :param index: State's index (taken modulo the current size)
        :return: state : (input_shape)
        """
        index %= self.size
        return self._states[index]

    def get_action(self, index):
        """
        Return the specified action
        :param index: Action's index (taken modulo the current size)
        :return: Integer
        """
        index %= self.size
        return self._actions[index]

    def get_reward(self, index):
        """
        Return the specified reward
        :param index: Reward's index (taken modulo the current size)
        :return: float32 scalar
        """
        index %= self.size
        return self._rewards[index]

    def minibatch(self, size):
        """
        Generate a minibatch with the number of samples specified by the size parameter.
        The successor of the state at index i is read from index i + 1.
        :param size: Minibatch size
        :return: Tuple (pre_states, actions, post_states, rewards, terminals)
        """
        # NOTE(review): the base sample() may pick the newest entry, whose
        # successor (index + 1) has not been stored yet; TemporalMemory
        # overrides sample() to reject such indices.
        indexes = self.sample(size)

        pre_states = np.array([self.get_state(index) for index in indexes], dtype=np.float32)
        post_states = np.array([self.get_state(index + 1) for index in indexes], dtype=np.float32)
        actions = self._actions[indexes]
        rewards = self._rewards[indexes]
        terminals = self._terminals[indexes]

        return pre_states, actions, post_states, rewards, terminals

    def save(self, out_file):
        """
        Save the current memory into a file in Numpy format
        (arrays stored under the keys states, actions, rewards and terminals;
        the write position and element count are not saved)
        :param out_file: File storage path
        :return:
        """
        np.savez_compressed(out_file, states=self._states, actions=self._actions,
                            rewards=self._rewards, terminals=self._terminals)

    def load(self, in_dir):
        """ Not implemented: currently a no-op """
        pass
# --------------------------------------------------------------------------------
# /malmopy/agent/astar.py:
# --------------------------------------------------------------------------------
# Copyright (c) 2017 Microsoft Corporation.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated
# documentation files (the "Software"), to deal in the Software without restriction, including without limitation the
# rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software,
# and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all copies or substantial portions of
# the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO
# THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
# IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
# TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
# ===================================================================================================================

from __future__ import absolute_import

from heapq import heapify, heappop, heappush
from collections import deque
from itertools import count

from . import BaseAgent


class AStarAgent(BaseAgent):
    """
    Base class for agents planning with A* search.

    Subclasses define the search space through neighbors() and heuristic();
    matches() decides when an explored node is considered to be the goal.
    """

    def __init__(self, name, nb_actions, visualizer=None):
        super(AStarAgent, self).__init__(name, nb_actions, visualizer)

    def _find_shortest_path(self, start, end, **kwargs):
        """
        Run A* from start towards end.

        Nodes must be hashable. A neighbor exposing a `cost` attribute uses it
        as its step cost; otherwise a cost of 1 is assumed.

        :param start: Initial node
        :param end: Goal, compared against explored nodes with matches()
        :param kwargs: Extra arguments forwarded to neighbors() and heuristic()
        :return: Tuple (path, cost_so_far). path is a deque of nodes from the one
                 following start up to the goal (empty if start already matches
                 end). cost_so_far maps every reached node to its best known cost.
                 If the goal is unreachable, path leads to the last node that was
                 expanded instead.
        """
        came_from, cost_so_far = {}, {}
        # Heap entries are (priority, insertion order, node). The counter breaks
        # ties between equal priorities, so nodes are never compared to each
        # other; they may not be orderable, which raises TypeError on Python 3.
        order = count()
        explorer = []
        heapify(explorer)

        heappush(explorer, (0, next(order), start))
        came_from[start] = None
        cost_so_far[start] = 0
        current = None

        while explorer:
            _, _, current = heappop(explorer)

            if self.matches(current, end):
                break

            for nb in self.neighbors(current, **kwargs):
                cost = nb.cost if hasattr(nb, "cost") else 1
                new_cost = cost_so_far[current] + cost

                if nb not in cost_so_far or new_cost < cost_so_far[nb]:
                    cost_so_far[nb] = new_cost
                    priority = new_cost + self.heuristic(end, nb, **kwargs)
                    heappush(explorer, (priority, next(order), nb))
                    came_from[nb] = current

        # Walk the parent links back from the last popped node; start is excluded
        path = deque()
        while current is not start:
            path.appendleft(current)
            current = came_from[current]
        return path, cost_so_far

    def neighbors(self, pos, **kwargs):
        """
        Return the nodes reachable from pos. Must be implemented by subclasses.
        A returned node may expose a `cost` attribute used as its step cost (default 1).
        """
        raise NotImplementedError()

    def heuristic(self, a, b, **kwargs):
        """
        Estimated remaining cost between a and b; called as heuristic(end, node).
        Must be implemented by subclasses.
        """
        raise NotImplementedError()

    def matches(self, a, b):
        """ Return True when node a should be considered to be goal b (equality by default) """
        return a == b
# --------------------------------------------------------------------------------
/malmopy/agent/explorer.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2017 Microsoft Corporation. 2 | # 3 | # Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated 4 | # documentation files (the "Software"), to deal in the Software without restriction, including without limitation the 5 | # rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, 6 | # and to permit persons to whom the Software is furnished to do so, subject to the following conditions: 7 | # 8 | # The above copyright notice and this permission notice shall be included in all copies or substantial portions of 9 | # the Software. 10 | # 11 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO 12 | # THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 13 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, 14 | # TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 15 | # SOFTWARE. 
# ===================================================================================================================

from __future__ import absolute_import

import numpy as np


class BaseExplorer(object):
    """ Explore/exploit logic wrapper """

    def __call__(self, step, nb_actions):
        """ Shortcut for explore(step, nb_actions) """
        return self.explore(step, nb_actions)

    def is_exploring(self, step):
        """ Returns True when exploring, False when exploiting """
        raise NotImplementedError()

    def explore(self, step, nb_actions):
        """ Generate an exploratory action """
        raise NotImplementedError()


class LinearEpsilonGreedyExplorer(BaseExplorer):
    """ Epsilon-greedy explorer with a linearly annealed epsilon.

    Epsilon is interpolated linearly from eps_max (at step 0) down to eps_min
    (at step eps_min_time), as a function of the current step:
        - step < 0                    -> epsilon = eps_max
        - 0 <= step <= eps_min_time   -> epsilon = eps_max - (eps_max - eps_min) * step / eps_min_time
        - step > eps_min_time         -> epsilon = eps_min
    """

    def __init__(self, eps_max, eps_min, eps_min_time):
        """
        :param eps_max: Initial exploration probability (must be > eps_min)
        :param eps_min: Final exploration probability
        :param eps_min_time: Number of steps over which epsilon is annealed (> 0)
        """
        assert eps_max > eps_min
        assert eps_min_time > 0

        self._eps_min_time = eps_min_time
        self._eps_min = eps_min
        self._eps_max = eps_max

        # Slope of the annealing line. float() prevents Python 2 integer floor
        # division when every argument is an int (e.g. (1, 0, 1000) gave -1).
        self._a = -float(eps_max - eps_min) / eps_min_time

    def _epsilon(self, step):
        """ Exploration probability at the given step """
        if step < 0:
            return self._eps_max
        elif step > self._eps_min_time:
            return self._eps_min
        else:
            return self._a * step + self._eps_max

    def is_exploring(self, step):
        """ Explore with probability epsilon(step) """
        return np.random.rand() < self._epsilon(step)

    def explore(self, step, nb_actions):
        """ Return an action drawn uniformly from [0, nb_actions) """
        return np.random.randint(0, nb_actions)
# --------------------------------------------------------------------------------
# /malmopy/agent/gui.py:
# --------------------------------------------------------------------------------
# Copyright (c) 2017 Microsoft
Corporation. 2 | # 3 | # Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated 4 | # documentation files (the "Software"), to deal in the Software without restriction, including without limitation the 5 | # rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, 6 | # and to permit persons to whom the Software is furnished to do so, subject to the following conditions: 7 | # 8 | # The above copyright notice and this permission notice shall be included in all copies or substantial portions of 9 | # the Software. 10 | # 11 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO 12 | # THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 13 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, 14 | # TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 15 | # SOFTWARE. 16 | # =================================================================================================================== 17 | 18 | from __future__ import absolute_import 19 | 20 | import six 21 | if six.PY2: 22 | from Tkinter import Tk 23 | else: 24 | from tkinter import Tk 25 | 26 | from . 
import BaseAgent 27 | from ..environment import VideoCapableEnvironment 28 | 29 | FPS_KEYS_MAPPING = {'w': 'move 1', 'a': 'strafe -1', 's': 'move -1', 'd': 'strafe 1', ' ': 'jump 1', 30 | 'q': 'strafe -1', 'z': 'move 1'} 31 | 32 | ARROW_KEYS_MAPPING = {'Left': 'turn -1', 'Right': 'turn 1', 'Up': 'move 1', 'Down': 'move -1'} 33 | 34 | CONTINUOUS_KEYS_MAPPING = {'Shift_L': 'crouch 1', 'Shift_R': 'crouch 1', 35 | '1': 'hotbar.1 1', '2': 'hotbar.2 1', '3': 'hotbar.3 1', '4': 'hotbar.4 1', 36 | '5': 'hotbar.5 1', 37 | '6': 'hotbar.6 1', '7': 'hotbar.7 1', '8': 'hotbar.8 1', '9': 'hotbar.9 1'} \ 38 | .update(ARROW_KEYS_MAPPING) 39 | 40 | DISCRETE_KEYS_MAPPING = {'Left': 'turn -1', 'Right': 'turn 1', 'Up': 'move 1', 'Down': 'move -1', 41 | '1': 'hotbar.1 1', '2': 'hotbar.2 1', '3': 'hotbar.3 1', '4': 'hotbar.4 1', '5': 'hotbar.5 1', 42 | '6': 'hotbar.6 1', '7': 'hotbar.7 1', '8': 'hotbar.8 1', '9': 'hotbar.9 1'} 43 | 44 | 45 | class GuiAgent(BaseAgent): 46 | def __init__(self, name, environment, keymap, win_name="Gui Agent", size=(640, 480), visualizer=None): 47 | assert isinstance(keymap, list), 'keymap should be a list[character]' 48 | assert isinstance(environment, VideoCapableEnvironment), 'environment should inherit from BaseEnvironment' 49 | 50 | super(GuiAgent, self).__init__(name, environment.available_actions, visualizer) 51 | 52 | if not environment.recording: 53 | environment.recording = True 54 | 55 | self._env = environment 56 | self._keymap = keymap 57 | self._tick = 20 58 | 59 | self._root = Tk() 60 | self._root.wm_title = win_name 61 | self._root.resizable(width=False, height=False) 62 | self._root.geometry = "%dx%d" % size 63 | 64 | self._build_layout(self._root) 65 | 66 | def act(self, new_state, reward, done, is_training=False): 67 | pass 68 | 69 | def show(self): 70 | self._root.mainloop() 71 | 72 | def _build_layout(self, root): 73 | """ 74 | Build the window layout 75 | :param root: 76 | :return: 77 | """ 78 | raise NotImplementedError() 79 | 80 | 
def _get_keymapping_help(self): 81 | return self._keymap 82 | -------------------------------------------------------------------------------- /malmopy/agent/qlearner.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2017 Microsoft Corporation. 2 | # 3 | # Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated 4 | # documentation files (the "Software"), to deal in the Software without restriction, including without limitation the 5 | # rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, 6 | # and to permit persons to whom the Software is furnished to do so, subject to the following conditions: 7 | # 8 | # The above copyright notice and this permission notice shall be included in all copies or substantial portions of 9 | # the Software. 10 | # 11 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO 12 | # THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 13 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, 14 | # TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 15 | # SOFTWARE. 16 | # =================================================================================================================== 17 | 18 | from __future__ import absolute_import 19 | 20 | from collections import namedtuple 21 | 22 | import numpy as np 23 | 24 | from . 
import BaseAgent, ReplayMemory, BaseExplorer, LinearEpsilonGreedyExplorer 25 | from ..model import QModel 26 | from ..util import get_rank 27 | 28 | 29 | class TemporalMemory(ReplayMemory): 30 | """ 31 | Temporal memory adds a new dimension to store N previous samples (t, t-1, t-2, ..., t-N) 32 | when sampling from memory 33 | """ 34 | 35 | def __init__(self, max_size, sample_shape, history_length=4, 36 | unflicker=False): 37 | """ 38 | :param max_size: Maximum number of elements in the memory 39 | :param sample_shape: Shape of each sample 40 | :param history_length: Length of the visual memory (n previous frames) included with each state 41 | :param unflicker: Indicate if we need to compute the difference between consecutive frames 42 | """ 43 | super(TemporalMemory, self).__init__(max_size, sample_shape) 44 | 45 | self._unflicker = unflicker 46 | self._history_length = max(1, history_length) 47 | self._last = np.zeros(sample_shape) 48 | 49 | def append(self, state, action, reward, is_terminal): 50 | if self._unflicker: 51 | max_diff_buffer = np.maximum(self._last, state) 52 | self._last = state 53 | state = max_diff_buffer 54 | 55 | super(TemporalMemory, self).append(state, action, reward, is_terminal) 56 | 57 | if is_terminal: 58 | if self._unflicker: 59 | self._last.fill(0) 60 | 61 | def sample(self, size, replace=True): 62 | """ 63 | Generate a random minibatch. The returned indices can be retrieved using #get_state(). 
64 | See the method #minibatch() if you want to retrieve samples directly 65 | :param size: The minibatch size 66 | :param replace: Indicate if one index can appear multiple times (True), only once (False) 67 | :return: Indexes of the sampled states 68 | """ 69 | 70 | if not replace: 71 | assert (self._count - 1) - self._history_length >= size, \ 72 | 'Cannot sample %d from %d elements' % ( 73 | size, (self._count - 1) - self._history_length) 74 | 75 | # Local variable access are faster in loops 76 | count, pos, history_len, terminals = self._count - 1, self._pos, \ 77 | self._history_length, self._terminals 78 | indexes = [] 79 | 80 | while len(indexes) < size: 81 | index = np.random.randint(history_len, count) 82 | 83 | # Check if replace=False to not include same index multiple times 84 | if replace or index not in indexes: 85 | 86 | # if not wrapping over current pointer, 87 | # then check if there is terminal state wrapped inside 88 | if not (index >= pos > index - history_len): 89 | if not terminals[(index - history_len):index].any(): 90 | indexes.append(index) 91 | 92 | assert len(indexes) == size 93 | return indexes 94 | 95 | def get_state(self, index): 96 | """ 97 | Return the specified state with the visual memory 98 | :param index: State's index 99 | :return: Tensor[history_length, input_shape...] 100 | """ 101 | index %= self._count 102 | history_length = self._history_length 103 | 104 | # If index > history_length, take from a slice 105 | if index >= history_length: 106 | return self._states[(index - (history_length - 1)):index + 1, ...] 
107 | else: 108 | indexes = np.arange(index - self._history_length + 1, index + 1) 109 | return self._states.take(indexes, mode='wrap', axis=0) 110 | 111 | @property 112 | def unflicker(self): 113 | """ 114 | Indicate if samples added to the replay memory are preprocessed 115 | by taking the maximum between current frame and previous one 116 | :return: True if preprocessed, False otherwise 117 | """ 118 | return self._unflicker 119 | 120 | @property 121 | def history_length(self): 122 | """ 123 | Visual memory length 124 | (ie. the number of previous frames included for each sample) 125 | :return: Integer >= 0 126 | """ 127 | return self._history_length 128 | 129 | 130 | class History(object): 131 | """ 132 | Accumulator keeping track of the N previous frames to be used by the agent 133 | for evaluation 134 | """ 135 | 136 | def __init__(self, shape): 137 | self._buffer = np.zeros(shape, dtype=np.float32) 138 | 139 | @property 140 | def value(self): 141 | return self._buffer 142 | 143 | def append(self, state): 144 | self._buffer[:-1] = self._buffer[1:] 145 | self._buffer[-1, ...] 
= state 146 | 147 | def reset(self): 148 | self._buffer.fill(0) 149 | 150 | 151 | # Track previous state and action for observation 152 | Tracker = namedtuple('Tracker', ['state', 'action']) 153 | 154 | 155 | class QLearnerAgent(BaseAgent): 156 | def __init__(self, name, nb_actions, model, memory, gamma=.99, 157 | minibatch_size=32, train_after=50000, train_frequency=4, 158 | explorer=None, reward_clipping=None, visualizer=None): 159 | 160 | assert isinstance(model, QModel), 'model should inherit from QModel' 161 | assert get_rank(model.input_shape) > 1, 'input_shape rank should be > 1' 162 | assert isinstance(memory, ReplayMemory), 'memory should inherit from ' \ 163 | 'ReplayMemory' 164 | assert 0 < gamma < 1, 'gamma should be 0 < gamma < 1 (got: %d)' % gamma 165 | assert minibatch_size > 0, 'minibatch_size should be > 0 (got: %d)' % minibatch_size 166 | assert train_after >= 0, 'train_after should be >= 0 (got %d)' % train_after 167 | assert train_frequency > 0, 'train_frequency should be > 0' 168 | 169 | super(QLearnerAgent, self).__init__(name, nb_actions, visualizer) 170 | 171 | self._model = model 172 | self._memory = memory 173 | self._gamma = gamma 174 | self._minibatch_size = minibatch_size 175 | self._train_after = train_after 176 | self._train_frequency = train_frequency 177 | self._history = History(model.input_shape) 178 | self._actions_taken = 0 179 | self._tracker = None 180 | 181 | # Rewards clipping related 182 | reward_clipping = reward_clipping or (-2 ** 31 - 1, 2 ** 31 - 1) 183 | assert isinstance(reward_clipping, tuple) and len(reward_clipping) == 2, \ 184 | 'clip_reward should be None or (min_reward, max_reward)' 185 | 186 | assert reward_clipping[0] <= reward_clipping[1], \ 187 | 'max reward_clipping should be >= min (got %d < %d)' % ( 188 | reward_clipping[1], reward_clipping[0]) 189 | 190 | self._reward_clipping = reward_clipping 191 | 192 | # Explorer related 193 | explorer = explorer or LinearEpsilonGreedyExplorer(1, 0.1, 1e6) 194 | 
assert isinstance(explorer, BaseExplorer), \ 195 | 'explorer should inherit from BaseExplorer' 196 | 197 | self._explorer = explorer 198 | 199 | # Stats related 200 | self._stats_rewards = [] 201 | self._stats_mean_qvalues = [] 202 | self._stats_stddev_qvalues = [] 203 | self._stats_loss = [] 204 | 205 | def act(self, new_state, reward, done, is_training=False): 206 | 207 | if self._tracker is not None: 208 | self.observe(self._tracker.state, self._tracker.action, 209 | reward, new_state, done) 210 | 211 | if is_training: 212 | if self._actions_taken > self._train_after: 213 | self.learn() 214 | 215 | # Append the new state to the history 216 | self._history.append(new_state) 217 | 218 | # select the next action 219 | if self._explorer.is_exploring(self._actions_taken): 220 | new_action = self._explorer(self._actions_taken, self.nb_actions) 221 | else: 222 | q_values = self._model.evaluate(self._history.value) 223 | new_action = q_values.argmax() 224 | 225 | self._stats_mean_qvalues.append(q_values.max()) 226 | self._stats_stddev_qvalues.append(np.std(q_values)) 227 | 228 | self._tracker = Tracker(new_state, new_action) 229 | self._actions_taken += 1 230 | 231 | return new_action 232 | 233 | def observe(self, old_state, action, reward, new_state, is_terminal): 234 | if is_terminal: 235 | self._history.reset() 236 | 237 | min_val, max_val = self._reward_clipping 238 | reward = max(min_val, min(max_val, reward)) 239 | self._memory.append(old_state, int(action), reward, is_terminal) 240 | 241 | def learn(self): 242 | if (self._actions_taken % self._train_frequency) == 0: 243 | minibatch = self._memory.minibatch(self._minibatch_size) 244 | q_t_target = self._compute_q(*minibatch) 245 | 246 | self._model.train(minibatch[0], q_t_target, minibatch[1]) 247 | self._stats_loss.append(self._model.loss_val) 248 | 249 | def inject_summaries(self, idx): 250 | if len(self._stats_mean_qvalues) > 0: 251 | self.visualize(idx, "%s/episode mean q" % self.name, 252 | 
np.asscalar(np.mean(self._stats_mean_qvalues))) 253 | self.visualize(idx, "%s/episode mean stddev.q" % self.name, 254 | np.asscalar(np.mean(self._stats_stddev_qvalues))) 255 | 256 | if len(self._stats_loss) > 0: 257 | self.visualize(idx, "%s/episode mean loss" % self.name, 258 | np.asscalar(np.mean(self._stats_loss))) 259 | 260 | if len(self._stats_rewards) > 0: 261 | self.visualize(idx, "%s/episode mean reward" % self.name, 262 | np.asscalar(np.mean(self._stats_rewards))) 263 | 264 | # Reset 265 | self._stats_mean_qvalues = [] 266 | self._stats_stddev_qvalues = [] 267 | self._stats_loss = [] 268 | self._stats_rewards = [] 269 | 270 | def _compute_q(self, pres, actions, posts, rewards, terminals): 271 | """ Compute the Q Values from input states """ 272 | 273 | q_hat = self._model.evaluate(posts, model=QModel.TARGET_NETWORK) 274 | q_hat_eval = q_hat[np.arange(len(actions)), q_hat.argmax(axis=1)] 275 | 276 | q_targets = (1 - terminals) * (self._gamma * q_hat_eval) + rewards 277 | return np.array(q_targets, dtype=np.float32) 278 | -------------------------------------------------------------------------------- /malmopy/environment/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2017 Microsoft Corporation. 2 | # 3 | # Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated 4 | # documentation files (the "Software"), to deal in the Software without restriction, including without limitation the 5 | # rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, 6 | # and to permit persons to whom the Software is furnished to do so, subject to the following conditions: 7 | # 8 | # The above copyright notice and this permission notice shall be included in all copies or substantial portions of 9 | # the Software. 
10 | # 11 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO 12 | # THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 13 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, 14 | # TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 15 | # SOFTWARE. 16 | # =================================================================================================================== 17 | 18 | from __future__ import absolute_import 19 | 20 | from .environment import BaseEnvironment, VideoCapableEnvironment 21 | -------------------------------------------------------------------------------- /malmopy/environment/environment.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2017 Microsoft Corporation. 2 | # 3 | # Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated 4 | # documentation files (the "Software"), to deal in the Software without restriction, including without limitation the 5 | # rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, 6 | # and to permit persons to whom the Software is furnished to do so, subject to the following conditions: 7 | # 8 | # The above copyright notice and this permission notice shall be included in all copies or substantial portions of 9 | # the Software. 10 | # 11 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO 12 | # THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 13 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, 14 | # TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 15 | # SOFTWARE. 16 | # =================================================================================================================== 17 | 18 | from __future__ import absolute_import 19 | 20 | import numpy as np 21 | 22 | from ..util import check_rank, get_rank, resize, rgb2gray 23 | 24 | 25 | class StateBuilder(object): 26 | """ 27 | StateBuilder are object that map environment state into another representation. 28 | 29 | Subclasses should override the build() method which can map specific environment behavior. 30 | For concrete examples, malmo package has some predefined state builder specific to Malmo 31 | """ 32 | 33 | def build(self, environment): 34 | raise NotImplementedError() 35 | 36 | def __call__(self, *args, **kwargs): 37 | return self.build(*args) 38 | 39 | 40 | class ALEStateBuilder(StateBuilder): 41 | """ 42 | Atari Environment state builder interface. 43 | 44 | This class assumes the environment.state() returns a numpy array. 45 | """ 46 | 47 | SCALE_FACTOR = 1. / 255. 
48 | 49 | def __init__(self, shape=(84, 84), normalize=True): 50 | self._shape = shape 51 | self._normalize = bool(normalize) 52 | 53 | def build(self, environment): 54 | if not isinstance(environment, np.ndarray): 55 | raise ValueError( 56 | 'environment type is not a numpy.ndarray (got %s)' % str( 57 | type(environment))) 58 | 59 | state = environment 60 | 61 | # Convert to gray 62 | if check_rank(environment.shape, 3): 63 | state = rgb2gray(environment) 64 | elif get_rank(state) > 3: 65 | raise ValueError('Cannot handle data with more than 3 dimensions') 66 | 67 | # Resize 68 | if state.shape != self._shape: 69 | state = resize(state, self._shape) 70 | 71 | return (state * ALEStateBuilder.SCALE_FACTOR).astype(np.float32) 72 | 73 | 74 | class BaseEnvironment(object): 75 | """ 76 | Abstract representation of an interactive environment 77 | """ 78 | 79 | def __init__(self): 80 | self._score = 0. 81 | self._reward = 0. 82 | self._done = False 83 | self._state = None 84 | 85 | def do(self, action): 86 | """ 87 | Do the specified action in the environment 88 | :param action: The action to be executed 89 | :return Tuple holding the new state, the reward and a flag indicating if the environment is done 90 | """ 91 | raise NotImplementedError() 92 | 93 | def reset(self): 94 | """ 95 | Reset the current environment's internal state. 96 | :return: 97 | """ 98 | self._score = 0. 99 | self._reward = 0. 
100 | self._done = False 101 | self._state = None 102 | 103 | @property 104 | def available_actions(self): 105 | """ 106 | Returns the number of actions available in this environment 107 | :return: Integer > 0 108 | """ 109 | raise NotImplementedError() 110 | 111 | @property 112 | def done(self): 113 | """ 114 | Indicate if the current environment is in a terminal state 115 | :return: Boolean True if environment is in a terminal state, False otherwise 116 | """ 117 | return self._done 118 | 119 | @property 120 | def state(self): 121 | """ 122 | Return the current environment state 123 | :return: 124 | """ 125 | return self._state 126 | 127 | @property 128 | def reward(self): 129 | """ 130 | Return accumulated rewards 131 | :return: Float as the current accumulated rewards since last state 132 | """ 133 | return self._reward 134 | 135 | @property 136 | def score(self): 137 | """ 138 | Return the environment's current score. 139 | It is common that the score will the sum of observed rewards, but subclasses can change this behavior. 140 | :return: Number 141 | """ 142 | return self._score 143 | 144 | @property 145 | def is_turn_based(self): 146 | """ 147 | Indicate if this environment is running on a turn-based scenario (i.e., 148 | agents take turns and wait for other agents' turns to complete before taking the next action). 149 | All subclasses should override this accordingly to the running scenario. 150 | As currently turn based is not the default behavior, the value returned is False 151 | :return: False 152 | """ 153 | return False 154 | 155 | 156 | class VideoCapableEnvironment(BaseEnvironment): 157 | """ 158 | Represent the capacity of an environment to stream it's current state. 159 | Streaming relies on 2 properties : 160 | - fps : Number of frame this environment is able to generate each second 161 | - frame : The latest frame generated by this environment 162 | The display adapter should ask for a new frame with a 1/fps millisecond delay. 
163 | If there is no updated frame, the frame property can return None. 164 | """ 165 | 166 | def __init__(self): 167 | super(VideoCapableEnvironment, self).__init__() 168 | self._recording = False 169 | 170 | @property 171 | def recording(self): 172 | """ 173 | Indicate if the current environment is dispatching the video stream 174 | :return: True if streaming, False otherwise 175 | """ 176 | return self._recording 177 | 178 | @recording.setter 179 | def recording(self, val): 180 | """ 181 | Change the internal recording state. 182 | :param val: True to activate video streaming, False otherwise 183 | :return: 184 | """ 185 | self._recording = bool(val) 186 | 187 | @property 188 | def frame(self): 189 | """ 190 | Return the most recent frame from the environment 191 | :return: PIL Image representing the current environment 192 | """ 193 | raise NotImplementedError() 194 | -------------------------------------------------------------------------------- /malmopy/environment/gym/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2017 Microsoft Corporation. 2 | # 3 | # Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated 4 | # documentation files (the "Software"), to deal in the Software without restriction, including without limitation the 5 | # rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, 6 | # and to permit persons to whom the Software is furnished to do so, subject to the following conditions: 7 | # 8 | # The above copyright notice and this permission notice shall be included in all copies or substantial portions of 9 | # the Software. 10 | # 11 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO 12 | # THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 13 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, 14 | # TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 15 | # SOFTWARE. 16 | # =================================================================================================================== 17 | 18 | from __future__ import absolute_import 19 | 20 | from .gym import GymEnvironment 21 | -------------------------------------------------------------------------------- /malmopy/environment/gym/gym.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2017 Microsoft Corporation. 2 | # 3 | # Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated 4 | # documentation files (the "Software"), to deal in the Software without restriction, including without limitation the 5 | # rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, 6 | # and to permit persons to whom the Software is furnished to do so, subject to the following conditions: 7 | # 8 | # The above copyright notice and this permission notice shall be included in all copies or substantial portions of 9 | # the Software. 10 | # 11 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO 12 | # THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 13 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, 14 | # TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 15 | # SOFTWARE. 
16 | # =================================================================================================================== 17 | 18 | from __future__ import absolute_import 19 | 20 | import gym 21 | import numpy as np 22 | import six 23 | from PIL import Image 24 | from gym.wrappers import Monitor 25 | 26 | from ..environment import VideoCapableEnvironment, StateBuilder, ALEStateBuilder 27 | 28 | 29 | def need_record(episode_id): 30 | return episode_id % 1000 == 0 31 | 32 | 33 | class GymEnvironment(VideoCapableEnvironment): 34 | """ 35 | Wraps an Open AI Gym environment 36 | """ 37 | 38 | def __init__(self, env_name, state_builder=ALEStateBuilder(), repeat_action=4, no_op=30, monitoring_path=None): 39 | assert isinstance(state_builder, StateBuilder), 'state_builder should inherit from StateBuilder' 40 | assert isinstance(repeat_action, (int, tuple)), 'repeat_action should be int or tuple' 41 | if isinstance(repeat_action, int): 42 | assert repeat_action >= 1, "repeat_action should be >= 1" 43 | elif isinstance(repeat_action, tuple): 44 | assert len(repeat_action) == 2, 'repeat_action should be a length-2 tuple: (min frameskip, max frameskip)' 45 | assert repeat_action[0] < repeat_action[1], 'repeat_action[0] should be < repeat_action[1]' 46 | 47 | super(GymEnvironment, self).__init__() 48 | 49 | self._state_builder = state_builder 50 | self._env = gym.make(env_name) 51 | self._env.env.frameskip = repeat_action 52 | self._no_op = max(0, no_op) 53 | self._done = True 54 | 55 | if monitoring_path is not None: 56 | self._env = Monitor(self._env, monitoring_path, video_callable=need_record) 57 | 58 | @property 59 | def available_actions(self): 60 | return self._env.action_space.n 61 | 62 | @property 63 | def state(self): 64 | return None if self._state is None else self._state_builder(self._state) 65 | 66 | @property 67 | def lives(self): 68 | return self._env.env.ale.lives() 69 | 70 | @property 71 | def frame(self): 72 | return Image.fromarray(self._state) 73 | 74 | 
def do(self, action): 75 | self._state, self._reward, self._done, _ = self._env.step(action) 76 | self._score += self._reward 77 | return self.state, self._reward, self._done 78 | 79 | def reset(self): 80 | super(GymEnvironment, self).reset() 81 | 82 | self._state = self._env.reset() 83 | 84 | # Random number of initial no-op to introduce stochasticity 85 | if self._no_op > 0: 86 | for _ in six.moves.range(np.random.randint(1, self._no_op)): 87 | self._state, _, _, _ = self._env.step(0) 88 | 89 | return self.state 90 | -------------------------------------------------------------------------------- /malmopy/environment/malmo/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2017 Microsoft Corporation. 2 | # 3 | # Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated 4 | # documentation files (the "Software"), to deal in the Software without restriction, including without limitation the 5 | # rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, 6 | # and to permit persons to whom the Software is furnished to do so, subject to the following conditions: 7 | # 8 | # The above copyright notice and this permission notice shall be included in all copies or substantial portions of 9 | # the Software. 10 | # 11 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO 12 | # THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 13 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, 14 | # TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 15 | # SOFTWARE. 
16 | # =================================================================================================================== 17 | 18 | from __future__ import absolute_import 19 | 20 | from .malmo import MalmoEnvironment, allocate_remotes 21 | from .malmo import MalmoStateBuilder, MalmoRGBStateBuilder, MalmoALEStateBuilder 22 | -------------------------------------------------------------------------------- /malmopy/model/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2017 Microsoft Corporation. 2 | # 3 | # Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated 4 | # documentation files (the "Software"), to deal in the Software without restriction, including without limitation the 5 | # rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, 6 | # and to permit persons to whom the Software is furnished to do so, subject to the following conditions: 7 | # 8 | # The above copyright notice and this permission notice shall be included in all copies or substantial portions of 9 | # the Software. 10 | # 11 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO 12 | # THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 13 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, 14 | # TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 15 | # SOFTWARE. 
16 | # =================================================================================================================== 17 | 18 | from __future__ import absolute_import 19 | 20 | from .model import * 21 | 22 | -------------------------------------------------------------------------------- /malmopy/model/chainer/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2017 Microsoft Corporation. 2 | # 3 | # Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated 4 | # documentation files (the "Software"), to deal in the Software without restriction, including without limitation the 5 | # rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, 6 | # and to permit persons to whom the Software is furnished to do so, subject to the following conditions: 7 | # 8 | # The above copyright notice and this permission notice shall be included in all copies or substantial portions of 9 | # the Software. 10 | # 11 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO 12 | # THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 13 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, 14 | # TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 15 | # SOFTWARE. 
16 | # =================================================================================================================== 17 | 18 | from __future__ import absolute_import 19 | 20 | from .qlearning import * 21 | -------------------------------------------------------------------------------- /malmopy/model/chainer/qlearning.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2017 Microsoft Corporation. 2 | # 3 | # Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated 4 | # documentation files (the "Software"), to deal in the Software without restriction, including without limitation the 5 | # rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, 6 | # and to permit persons to whom the Software is furnished to do so, subject to the following conditions: 7 | # 8 | # The above copyright notice and this permission notice shall be included in all copies or substantial portions of 9 | # the Software. 10 | # 11 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO 12 | # THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 13 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, 14 | # TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 15 | # SOFTWARE. 
16 | # =================================================================================================================== 17 | 18 | from __future__ import absolute_import 19 | 20 | import chainer.cuda as cuda 21 | import chainer.functions as F 22 | import chainer.links as L 23 | import numpy as np 24 | from chainer import ChainList 25 | from chainer.initializers import HeUniform 26 | from chainer.optimizers import Adam 27 | from chainer.serializers import save_npz, load_npz 28 | 29 | from ..model import QModel 30 | from ...util import check_rank, get_rank 31 | 32 | 33 | class ChainerModel(ChainList): 34 | """ 35 | Wraps a Chainer Chain and enforces the model to be callable. 36 | Every model should override the __call__ method as a forward call. 37 | """ 38 | 39 | def __init__(self, input_shape, output_shape): 40 | self.input_shape = input_shape 41 | self.output_shape = output_shape 42 | 43 | super(ChainerModel, self).__init__(*self._build_model()) 44 | 45 | def __call__(self, *args, **kwargs): 46 | raise NotImplementedError() 47 | 48 | def _build_model(self): 49 | raise NotImplementedError() 50 | 51 | 52 | class MLPChain(ChainerModel): 53 | """ 54 | Create a Multi Layer Perceptron neural network. 55 | The number of layers and units for each layer can be specified using hidden_layer_sizes. 56 | For example for a 128 units on the first hidden layer, then 256 on the second and 512 on the third: 57 | 58 | >>> MLPChain(input_shape=(28, 28), output_shape=10, hidden_layer_sizes=(128, 256, 512)) 59 | 60 | Note : The network will contain len(hidden_layer_sizes) + 2 layers because 61 | of the input layer and the output layer. 
62 | """ 63 | 64 | def __init__(self, in_shape, output_shape, 65 | hidden_layer_sizes=(512, 512, 512), activation=F.relu): 66 | self._activation = activation 67 | self._hidden_layer_sizes = hidden_layer_sizes 68 | 69 | super(MLPChain, self).__init__(in_shape, output_shape) 70 | 71 | @property 72 | def hidden_layer_sizes(self): 73 | return self._hidden_layer_sizes 74 | 75 | def __call__(self, x): 76 | f = self._activation 77 | 78 | for layer in self[:-1]: 79 | x = f(layer(x)) 80 | return self[-1](x) 81 | 82 | def _build_model(self): 83 | hidden_layers = [L.Linear(None, units) for units in 84 | self._hidden_layer_sizes] 85 | hidden_layers += [L.Linear(None, self.output_shape)] 86 | 87 | return hidden_layers 88 | 89 | 90 | class ReducedDQNChain(ChainerModel): 91 | """ 92 | Simplified DQN topology: 93 | 94 | Convolution(64, kernel=(4, 4), strides=(2, 2) 95 | Convolution(64, kernel=(3, 3), strides=(1, 1) 96 | Dense(512) 97 | Dense(output_shape) 98 | """ 99 | def __init__(self, in_shape, output_shape): 100 | super(ReducedDQNChain, self).__init__(in_shape, output_shape) 101 | 102 | def __call__(self, x): 103 | for layer in self[:-1]: 104 | x = F.relu(layer(x)) 105 | return self[-1](x) 106 | 107 | def _build_model(self): 108 | initializer = HeUniform() 109 | in_shape = self.input_shape[0] 110 | 111 | return [L.Convolution2D(in_shape, 64, ksize=4, stride=2, 112 | initialW=initializer), 113 | L.Convolution2D(64, 64, ksize=3, stride=1, 114 | initialW=initializer), 115 | L.Linear(None, 512, initialW=HeUniform(0.1)), 116 | L.Linear(512, self.output_shape, initialW=HeUniform(0.1))] 117 | 118 | 119 | class DQNChain(ChainerModel): 120 | """ 121 | DQN topology as in 122 | (Mnih & al. 2015): Human-level control through deep reinforcement learning" 123 | Nature 518.7540 (2015): 529-533. 
124 | 125 | Convolution(32, kernel=(8, 8), strides=(4, 4) 126 | Convolution(64, kernel=(4, 4), strides=(2, 2) 127 | Convolution(64, kernel=(3, 3), strides=(1, 1) 128 | Dense(512) 129 | Dense(output_shape) 130 | """ 131 | 132 | def __init__(self, in_shape, output_shape): 133 | super(DQNChain, self).__init__(in_shape, output_shape) 134 | 135 | def __call__(self, x): 136 | for layer in self[:-1]: 137 | x = F.relu(layer(x)) 138 | return self[-1](x) 139 | 140 | def _build_model(self): 141 | initializer = HeUniform() 142 | in_shape = self.input_shape[0] 143 | 144 | return [L.Convolution2D(in_shape, 32, ksize=8, stride=4, 145 | initialW=initializer), 146 | L.Convolution2D(32, 64, ksize=4, stride=2, 147 | initialW=initializer), 148 | L.Convolution2D(64, 64, ksize=3, stride=1, 149 | initialW=initializer), 150 | L.Linear(7 * 7 * 64, 512, initialW=HeUniform(0.01)), 151 | L.Linear(512, self.output_shape, initialW=HeUniform(0.01))] 152 | 153 | 154 | class QNeuralNetwork(QModel): 155 | def __init__(self, model, target, device_id=-1, 156 | learning_rate=0.00025, momentum=.9, 157 | minibatch_size=32, update_interval=10000): 158 | 159 | assert isinstance(model, ChainerModel), \ 160 | 'model should inherit from ChainerModel' 161 | 162 | super(QNeuralNetwork, self).__init__(model.input_shape, 163 | model.output_shape) 164 | 165 | self._gpu_device = None 166 | self._loss_val = 0 167 | 168 | # Target model update method 169 | self._steps = 0 170 | self._target_update_interval = update_interval 171 | 172 | # Setup model and target network 173 | self._minibatch_size = minibatch_size 174 | self._model = model 175 | self._target = target 176 | self._target.copyparams(self._model) 177 | 178 | # If GPU move to GPU memory 179 | if device_id >= 0: 180 | with cuda.get_device(device_id) as device: 181 | self._gpu_device = device 182 | self._model.to_gpu(device) 183 | self._target.to_gpu(device) 184 | 185 | # Setup optimizer 186 | self._optimizer = Adam(learning_rate, momentum, 0.999) 187 | 
self._optimizer.setup(self._model) 188 | 189 | def evaluate(self, environment, model=QModel.ACTION_VALUE_NETWORK): 190 | if check_rank(environment.shape, get_rank(self._input_shape)): 191 | environment = environment.reshape((1,) + environment.shape) 192 | 193 | # Move data if necessary 194 | if self._gpu_device is not None: 195 | environment = cuda.to_gpu(environment, self._gpu_device) 196 | 197 | if model == QModel.ACTION_VALUE_NETWORK: 198 | output = self._model(environment) 199 | else: 200 | output = self._target(environment) 201 | 202 | return cuda.to_cpu(output.data) 203 | 204 | def train(self, x, y, actions=None): 205 | actions = actions.astype(np.int32) 206 | batch_size = len(actions) 207 | 208 | if self._gpu_device: 209 | x = cuda.to_gpu(x, self._gpu_device) 210 | y = cuda.to_gpu(y, self._gpu_device) 211 | actions = cuda.to_gpu(actions, self._gpu_device) 212 | 213 | q = self._model(x) 214 | q_subset = F.reshape(F.select_item(q, actions), (batch_size, 1)) 215 | y = y.reshape(batch_size, 1) 216 | 217 | loss = F.sum(F.huber_loss(q_subset, y, 1.0)) 218 | 219 | self._model.cleargrads() 220 | loss.backward() 221 | self._optimizer.update() 222 | 223 | self._loss_val = np.asscalar(cuda.to_cpu(loss.data)) 224 | 225 | # Keeps track of the number of train() calls 226 | self._steps += 1 227 | if self._steps % self._target_update_interval == 0: 228 | # copy weights 229 | self._target.copyparams(self._model) 230 | 231 | @property 232 | def loss_val(self): 233 | return self._loss_val # / self._minibatch_size 234 | 235 | def save(self, output_file): 236 | save_npz(output_file, self._model) 237 | 238 | def load(self, input_file): 239 | load_npz(input_file, self._model) 240 | 241 | # Copy parameter from model to target 242 | self._target.copyparams(self._model) 243 | -------------------------------------------------------------------------------- /malmopy/model/cntk/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 
(c) 2017 Microsoft Corporation. 2 | # 3 | # Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated 4 | # documentation files (the "Software"), to deal in the Software without restriction, including without limitation the 5 | # rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, 6 | # and to permit persons to whom the Software is furnished to do so, subject to the following conditions: 7 | # 8 | # The above copyright notice and this permission notice shall be included in all copies or substantial portions of 9 | # the Software. 10 | # 11 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO 12 | # THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 13 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, 14 | # TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 15 | # SOFTWARE. 16 | # =================================================================================================================== 17 | 18 | from __future__ import absolute_import 19 | 20 | from .base import * 21 | from .qlearning import * 22 | -------------------------------------------------------------------------------- /malmopy/model/cntk/base.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2017 Microsoft Corporation. 
2 | # 3 | # Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated 4 | # documentation files (the "Software"), to deal in the Software without restriction, including without limitation the 5 | # rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, 6 | # and to permit persons to whom the Software is furnished to do so, subject to the following conditions: 7 | # 8 | # The above copyright notice and this permission notice shall be included in all copies or substantial portions of 9 | # the Software. 10 | # 11 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO 12 | # THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 13 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, 14 | # TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 15 | # SOFTWARE. 
# ===================================================================================================================

import numpy as np
from cntk.device import cpu, gpu, try_set_default_device
from cntk.train.distributed import Communicator
from cntk.learners import set_default_unit_gain_value
from cntk.ops import abs, element_select, less, square, sqrt, reduce_sum, reduce_mean

from ...visualization import Visualizable


def rmse(y, y_hat, axis=0):
    """
    Compute the Root Mean Squared error as part of the model graph

    :param y: CNTK Variable holding the true value of Y
    :param y_hat: CNTK variable holding the estimated value of Y
    :param axis: The axis over which to compute the mean, 0 by default
    :return: CNTK Function computing the Root Mean Squared error
    """
    return sqrt(reduce_mean(square(y_hat - y), axis=axis))


def huber_loss(y_hat, y, delta):
    """
    Compute the Huber Loss as part of the model graph

    Huber Loss is more robust to outliers. It is defined element-wise as:
    if |y - y_hat| < delta :
        0.5 * (y - y_hat)**2
    else :
        delta * |y - y_hat| - 0.5 * delta**2

    :param y_hat: Estimated value
    :param y: Target value
    :param delta: Outliers threshold
    :return: CNTK Function named 'loss' holding the element-wise Huber loss reduced with reduce_sum
    """
    half_delta_squared = 0.5 * delta * delta
    error = y - y_hat
    abs_error = abs(error)

    # Quadratic branch for small errors, linear branch for outliers
    less_than = 0.5 * square(error)
    more_than = (delta * abs_error) - half_delta_squared

    loss_per_sample = element_select(less(abs_error, delta), less_than, more_than)

    return reduce_sum(loss_per_sample, name='loss')


def as_learning_rate_by_sample(learning_rate_per_minibatch, minibatch_size, momentum=0, momentum_as_unit_gain=False):
    """
    Convert a per-minibatch learning rate into a per-sample learning rate, optionally
    compensating for unit gain momentum, to match the learning rate definition used
    in other deep learning frameworks.

    With unit gain momentum (CNTK), the velocity is updated as follows:
        v(t + 1) = momentum * v(t) + (1 - momentum) * gradient(t)

    Whereas in other frameworks (TensorFlow, etc.) it is computed this way:
        v(t + 1) = momentum * v(t) + gradient(t)

    According to the above equations, the learning rate has to be scaled by a
    factor of 1 / (1 - momentum) when the momentum is unit gain.

    :param learning_rate_per_minibatch: The learning rate for a full minibatch (must be > 0)
    :param minibatch_size: Size of the minibatch (must be >= 1)
    :param momentum: The current momentum, 0 <= momentum < 1
                     (0 by default, used only when momentum_as_unit_gain is True)
    :param momentum_as_unit_gain: Indicate whether the momentum is a unit gain factor (CNTK) or not (TensorFlow, etc.)
    :return: Scaled learning rate according to momentum and minibatch size
    """
    assert learning_rate_per_minibatch > 0, "learning_rate_per_minibatch must be > 0"
    assert minibatch_size > 0, "minibatch_size must be >= 1"

    # float() avoids silent integer (floor) division under Python 2
    learning_rate_per_sample = float(learning_rate_per_minibatch) / minibatch_size

    if momentum_as_unit_gain:
        # momentum == 1 would divide by zero, momentum > 1 would flip the sign
        assert 0 <= momentum < 1, "momentum must be 0 <= momentum < 1 when momentum_as_unit_gain is True"
        learning_rate_per_sample /= (1. - momentum)

    return learning_rate_per_sample


def as_momentum_as_time_constant(momentum, minibatch_size):
    """
    Convert a momentum applied once per minibatch into CNTK's "momentum as time
    constant" form, i.e. the number of samples after which the contribution of a
    gradient has decayed by a factor of e^-1:

        momentum_as_time_constant = ceil(-minibatch_size / log(momentum))

    :param momentum: Momentum applied once per minibatch, expected 0 < momentum < 1
    :param minibatch_size: Size of the minibatch
    :return: The time constant, as returned by np.ceil
    """
    return np.ceil(-minibatch_size / (np.log(momentum)))


def prepend_batch_seq_axis(tensor):
    """
    CNTK uses 2 dynamic axes (batch, sequence, input_shape...).
    To have a single sample with length 1 you need to pass (1, 1, input_shape...)
    This method reshapes a tensor to add a batch and a sequence axis, both of size 1.

    :param tensor: The tensor to be reshaped
    :return: Reshaped tensor with batch and sequence axis = 1
    """
    return tensor.reshape((1, 1,) + tensor.shape)


def prepend_batch_axis(tensor):
    """
    CNTK uses 2 dynamic axes (batch, sequence, input_shape...).
    If you define variables with dynamic_axes=[Axis.default_batch_axis()] you can get rid of the sequence axis

    To have a single sample you then need to pass (1, input_shape...)
    This method reshapes a tensor to add a batch axis of size 1 (no sequence axis).

    :param tensor: The tensor to be reshaped
    :return: Reshaped tensor with batch axis = 1
    """
    return tensor.reshape((1,) + tensor.shape)


class CntkModel(Visualizable):
    """ Base class for CNTK based neural networks.

    It handles the management of the CPU/GPU device and provides commodity methods for saving and loading the model
    """

    def __init__(self, device_id=None, unit_gain=False, n_workers=1, visualizer=None):
        """
        Abstract constructor of CNTK model.
        This constructor wraps CNTK initialization and tuning

        :param device_id: Use None if you want CNTK to use the best available device, -1 for CPU, >= 0 for GPU
        :param unit_gain: Value given to CNTK's set_default_unit_gain_value, i.e. the default unit gain
                          setting of momentum based learners
        :param n_workers: Number of concurrent workers for distributed training.
                          Keep set to 1 for non distributed mode
        :param visualizer: Optional visualizer allowing model to save summary data
        """
        assert n_workers >= 1, 'n_workers should be at least 1 (not distributed) or > 1 if distributed'

        Visualizable.__init__(self, visualizer)

        self._model = None
        self._learner = None
        self._loss = None
        self._distributed = n_workers > 1

        # None lets CNTK pick the device; only an explicit integer id overrides it
        if isinstance(device_id, int):
            try_set_default_device(cpu() if device_id == -1 else gpu(device_id))

        set_default_unit_gain_value(unit_gain)

    def _build_model(self):
        """Build and return the CNTK network. Must be implemented by subclasses."""
        raise NotImplementedError()

    @property
    def loss_val(self):
        """Loss of the last training step. Must be implemented by subclasses."""
        raise NotImplementedError()

    @property
    def model(self):
        """The CNTK Function of the network (None until a subclass assigns it)."""
        return self._model

    @property
    def distributed_training(self):
        """True when the model was created with n_workers > 1."""
        return self._distributed

    @property
    def distributed_rank(self):
        """
        Rank of the current worker.

        :return: The learner's communicator rank in distributed mode when the learner exposes
                 a communicator, 0 otherwise
        """
        if self._distributed and self._learner and hasattr(self._learner, 'communicator'):
            return self._learner.communicator().rank()
        # A single, non distributed worker is rank 0 (this used to fall through and return None)
        return 0

    def load(self, input_file):
        """
        Restore the model parameters from a file.

        :param input_file: Path of the file previously written by save()
        :raises ValueError: if no model has been built yet
        """
        if self._model is None:
            raise ValueError("cannot load to a model that equals None")

        self._model.restore(input_file)

    def save(self, output_file):
        """
        Save the model to a file.

        :param output_file: Destination path
        :raises ValueError: if no model has been built yet
        """
        if self._model is None:
            raise ValueError("cannot save a model that equals None")

        self._model.save(output_file)

    def finalize(self):
        """Release the distributed communicator when running in distributed mode."""
        if self._distributed:
            Communicator.finalize()

# --------------------------------------------------------------------------------
# /malmopy/model/cntk/qlearning.py:
# --------------------------------------------------------------------------------
# Copyright (c) 2017 Microsoft Corporation.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated
# documentation files (the "Software"), to deal in the Software without restriction, including without limitation the
# rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software,
# and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all copies or substantial portions of
# the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO
# THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
# TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
# ===================================================================================================================

from __future__ import absolute_import

from cntk import Value
from cntk.axis import Axis
from cntk.initializer import he_uniform, he_normal
from cntk.layers import Convolution, Dense, default_options
from cntk.layers.higher_order_layers import Sequential
from cntk.learners import adam, momentum_schedule, learning_rate_schedule, \
    UnitType
from cntk.ops import input, relu, reduce_sum
from cntk.ops.functions import CloneMethod
from cntk.train.trainer import Trainer

from . import CntkModel, prepend_batch_axis, huber_loss
from ..model import QModel
from ...util import check_rank


class QNeuralNetwork(CntkModel, QModel):
    """
    Represents a learning capable entity using CNTK
    """

    def __init__(self, in_shape, output_shape, device_id=None,
                 learning_rate=0.00025, momentum=0.9,
                 minibatch_size=32, update_interval=10000,
                 n_workers=1, visualizer=None):

        """
        Q Neural Network following Mnih et al. implementation and default options.

        The network has the following topology:
        Convolution(32, (8, 8), strides=(4, 4))
        Convolution(64, (4, 4), strides=(2, 2))
        Convolution(64, (3, 3), strides=(1, 1))
        Dense(512)
        Dense(output_shape), linear (one Q-value per action)

        :param in_shape: Shape of the observations perceived by the learner (the neural net input)
        :param output_shape: Size of the action space (mapped to the number of output neurons)

        :param device_id: Use None to let CNTK select the best available device,
                          -1 for CPU, >= 0 for GPU
                          (default: None)

        :param learning_rate: Learning rate, expressed per minibatch
                              (default: 0.00025, as per Mnih et al.)

        :param momentum: Momentum (first moment decay) of the AdamSGD optimizer, 0 <= momentum < 1.
                         Note that CNTK does not currently provide an implementation
                         of Graves' RmsProp with momentum.
                         AdamSGD is used instead, with unit gain enabled and a
                         variance momentum of 0.999.
                         (default: 0.9)

        :param minibatch_size: Minibatch size
                               (default: 32, as per Mnih et al.; not used to build the learner)

        :param update_interval: Number of train() calls between two refreshes of the target
                                network from the action-value network
                                (default: 10000)

        :param n_workers: Number of concurrent worker for distributed training.
                          Distributed training (n_workers > 1) is not implemented yet
                          and raises NotImplementedError.
                          (default: 1, not distributed)

        :param visualizer: Optional visualizer allowing the model to save summary data
                           (default: None, no visualization)

        Ref: Mnih et al.: "Human-level control through deep reinforcement learning."
        Nature 518.7540 (2015): 529-533.
        """

        assert learning_rate > 0, 'learning_rate should be > 0'
        assert 0. <= momentum < 1, 'momentum should be 0 <= momentum < 1'

        QModel.__init__(self, in_shape, output_shape)
        CntkModel.__init__(self, device_id, False, n_workers, visualizer)

        self._nb_actions = output_shape
        self._steps = 0
        self._target_update_interval = update_interval
        self._target = None

        # Input vars, only a batch dynamic axis (no sequence axis)
        self._environment = input(in_shape, name='env',
                                  dynamic_axes=[Axis.default_batch_axis()])
        self._q_targets = input(1, name='q_targets',
                                dynamic_axes=[Axis.default_batch_axis()])
        self._actions = input(output_shape, name='actions',
                              dynamic_axes=[Axis.default_batch_axis()])

        # Define the neural network graph
        self._model = self._build_model()(self._environment)
        # The target network is a frozen copy sharing the same input variable
        self._target = self._model.clone(
            CloneMethod.freeze, {self._environment: self._environment}
        )

        # Define the learning rate
        lr_schedule = learning_rate_schedule(learning_rate, UnitType.minibatch)

        # AdamSGD optimizer
        m_schedule = momentum_schedule(momentum)
        vm_schedule = momentum_schedule(0.999)
        l_sgd = adam(self._model.parameters, lr_schedule,
                     momentum=m_schedule,
                     unit_gain=True,
                     variance_momentum=vm_schedule)

        if self.distributed_training:
            raise NotImplementedError('ASGD not implemented yet.')

        # _actions is a sparse 1-hot encoding of the actions done by the agent,
        # so this keeps only the Q-value of the action actually taken
        q_acted = reduce_sum(self._model * self._actions, axis=0)

        # Define the trainer with Huber Loss function
        criterion = huber_loss(q_acted, self._q_targets, 1.0)

        self._learner = l_sgd
        self._trainer = Trainer(self._model, (criterion, None), l_sgd)

    @property
    def loss_val(self):
        """Average loss of the last minibatch given to train()."""
        return self._trainer.previous_minibatch_loss_average

    def _build_model(self):
        """Build the convolutional Q-network described in the constructor docstring."""
        with default_options(init=he_uniform(), activation=relu, bias=True):
            model = Sequential([
                Convolution((8, 8), 32, strides=(4, 4)),
                Convolution((4, 4), 64, strides=(2, 2)),
                Convolution((3, 3), 64, strides=(1, 1)),
                Dense(512, init=he_normal(0.01)),
                Dense(self._nb_actions, activation=None, init=he_normal(0.01))
            ])
            return model

    def train(self, x, q_value_targets, actions=None):
        """
        Run one training step and periodically refresh the target network.

        :param x: Observations, with or without a leading batch axis
        :param q_value_targets: Target Q-values, shape [N] or [N, 1]
        :param actions: Indices of the actions taken, shape [N] or [N, 1]; one-hot encoded here
        """
        assert actions is not None, 'actions cannot be None'

        # Add an extra dimension: [N] => [N, 1]
        if check_rank(q_value_targets.shape, 1):
            q_value_targets = q_value_targets.reshape((-1, 1))

        # Add an extra dimension to match shape [N, 1] required by one_hot
        if check_rank(actions.shape, 1):
            actions = actions.reshape((-1, 1))

        # A single sample needs a batch axis
        if check_rank(x.shape, len(self._environment.shape)):
            x = prepend_batch_axis(x)

        self._trainer.train_minibatch({
            self._environment: x,
            self._actions: Value.one_hot(actions, self._nb_actions),
            self._q_targets: q_value_targets
        })

        # Count the number of train calls
        self._steps += 1

        # Refresh the target network from the action-value one
        if (self._steps % self._target_update_interval) == 0:
            self._target = self._model.clone(
                CloneMethod.freeze, {self._environment: self._environment}
            )

    def evaluate(self, data, model=QModel.ACTION_VALUE_NETWORK):
        """
        Compute Q-values for the given observation(s).

        :param data: A single observation or a minibatch of observations
        :param model: QModel.TARGET_NETWORK to use the target network, otherwise the action-value network
        :return: Predictions with all size-1 axes squeezed out
        """
        # If evaluating a single sample, expand the minibatch axis
        # (minibatch = 1, input_shape...)
        if len(data.shape) == len(self.input_shape):
            data = prepend_batch_axis(data)  # Prepend minibatch dim

        if model == QModel.TARGET_NETWORK:
            predictions = self._target.eval({self._environment: data})
        else:
            predictions = self._model.eval({self._environment: data})
        return predictions.squeeze()


class ReducedQNeuralNetwork(QNeuralNetwork):
    """
    Represents a learning capable entity using CNTK, reduced model

    The network has the following topology:
    Convolution(64, (4, 4), strides=(2, 2))
    Convolution(64, (3, 3), strides=(1, 1))
    Dense(512)
    Dense(output_shape), linear (one Q-value per action)
    """

    def __init__(self, in_shape, output_shape, device_id=None,
                 learning_rate=0.00025, momentum=0.9,
                 minibatch_size=32, update_interval=10000,
                 n_workers=1, visualizer=None):
        """Same parameters as QNeuralNetwork."""
        QNeuralNetwork.__init__(self, in_shape, output_shape, device_id,
                                learning_rate, momentum, minibatch_size, update_interval,
                                n_workers, visualizer)

    def _build_model(self):
        """Build the reduced convolutional Q-network."""
        with default_options(init=he_uniform(), activation=relu, bias=True):
            model = Sequential([
                Convolution((4, 4), 64, strides=(2, 2), name='conv1'),
                Convolution((3, 3), 64, strides=(1, 1), name='conv2'),
                Dense(512, name='dense1', init=he_normal(0.01)),
                Dense(self._nb_actions, activation=None, init=he_normal(0.01), name='qvalues')
            ])
            return model

# --------------------------------------------------------------------------------
# /malmopy/model/model.py:
# --------------------------------------------------------------------------------
# Copyright (c) 2017 Microsoft Corporation.
2 | # 3 | # Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated 4 | # documentation files (the "Software"), to deal in the Software without restriction, including without limitation the 5 | # rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, 6 | # and to permit persons to whom the Software is furnished to do so, subject to the following conditions: 7 | # 8 | # The above copyright notice and this permission notice shall be included in all copies or substantial portions of 9 | # the Software. 10 | # 11 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO 12 | # THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 13 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, 14 | # TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 15 | # SOFTWARE. 
# ===================================================================================================================

from __future__ import absolute_import


class BaseModel(object):
    """
    Represents a learning capable entity.

    Keeps the input and output shapes given at construction time and declares
    the operations (evaluate / train / load / save) concrete models override.
    """

    def __init__(self, in_shape, output_shape):
        """
        :param in_shape: Shape of the model input
        :param output_shape: Shape (or size) of the model output
        """
        self._input_shape, self._output_shape = in_shape, output_shape

    @property
    def input_shape(self):
        """Input shape given to the constructor."""
        return self._input_shape

    @property
    def output_shape(self):
        """Output shape given to the constructor."""
        return self._output_shape

    @property
    def loss_val(self):
        """Current loss value. Abstract: must be overridden."""
        raise NotImplementedError()

    def evaluate(self, environment):
        """Compute the model output for ``environment``. Abstract: must be overridden."""
        raise NotImplementedError()

    def train(self, x, y):
        """Update the model from inputs ``x`` and targets ``y``. Abstract: must be overridden."""
        raise NotImplementedError()

    def load(self, input_file):
        """Restore the model from ``input_file``. Abstract: must be overridden."""
        raise NotImplementedError()

    def save(self, output_file):
        """Persist the model to ``output_file``. Abstract: must be overridden."""
        raise NotImplementedError()


class QModel(BaseModel):
    """
    Interface of a Q-learning model whose evaluate() can target either the
    action-value network or the target network.
    """

    # Network selectors, written as distinct bit flags
    ACTION_VALUE_NETWORK = 0b01
    TARGET_NETWORK = 0b10

    def evaluate(self, environment, model=ACTION_VALUE_NETWORK):
        """Compute Q-values of ``environment`` with the selected network. Abstract."""
        raise NotImplementedError()

    def train(self, x, y, actions=None):
        """Update the model from inputs, targets and the actions taken. Abstract."""
        raise NotImplementedError()

# --------------------------------------------------------------------------------
# /malmopy/util/__init__.py:
# --------------------------------------------------------------------------------
# Copyright (c) 2017 Microsoft Corporation.
2 | # 3 | # Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated 4 | # documentation files (the "Software"), to deal in the Software without restriction, including without limitation the 5 | # rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, 6 | # and to permit persons to whom the Software is furnished to do so, subject to the following conditions: 7 | # 8 | # The above copyright notice and this permission notice shall be included in all copies or substantial portions of 9 | # the Software. 10 | # 11 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO 12 | # THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 13 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, 14 | # TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 15 | # SOFTWARE. 16 | # =================================================================================================================== 17 | 18 | from __future__ import absolute_import 19 | 20 | from .images import resize, rgb2gray 21 | from .util import * 22 | 23 | 24 | -------------------------------------------------------------------------------- /malmopy/util/images.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2017 Microsoft Corporation. 
2 | # 3 | # Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated 4 | # documentation files (the "Software"), to deal in the Software without restriction, including without limitation the 5 | # rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, 6 | # and to permit persons to whom the Software is furnished to do so, subject to the following conditions: 7 | # 8 | # The above copyright notice and this permission notice shall be included in all copies or substantial portions of 9 | # the Software. 10 | # 11 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO 12 | # THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 13 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, 14 | # TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 15 | # SOFTWARE. 
# ===================================================================================================================

import sys
import numpy as np

OPENCV_AVAILABLE = False
PILLOW_AVAILABLE = False

# OpenCV is probed first: when present it is the preferred backend
try:
    import cv2
    OPENCV_AVAILABLE = True
    print('OpenCV found, setting as default backend.')
except ImportError:
    pass

# Pillow only becomes the default backend when OpenCV is missing
try:
    import PIL
    PILLOW_AVAILABLE = True
    if not OPENCV_AVAILABLE:
        print('Pillow found, setting as default backend.')
except ImportError:
    pass


if not OPENCV_AVAILABLE and not PILLOW_AVAILABLE:
    raise ValueError('No image library backend found. Install either '
                     'OpenCV or Pillow to support image processing.')


def resize(img, shape):
    """
    Resize the specified image with the available backend.

    :param img: Image to resize, as a numpy array
    :param shape: Target size; both cv2.resize and PIL's Image.resize interpret it as (width, height)
    :return: The resized image as a numpy array
    """
    if OPENCV_AVAILABLE:
        return cv2.resize(img, shape)
    if PILLOW_AVAILABLE:
        from PIL import Image
        return np.array(Image.fromarray(img).resize(shape))
    return None


def rgb2gray(img):
    """
    Convert an RGB image to a single channel grayscale image with the available backend.

    :param img: RGB image to convert, as a numpy array
    :return: The grayscale image as a numpy array
    """
    if OPENCV_AVAILABLE:
        return cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
    if PILLOW_AVAILABLE:
        from PIL import Image
        return np.array(Image.fromarray(img).convert('L'))
    return None

# --------------------------------------------------------------------------------
# /malmopy/util/util.py:
# --------------------------------------------------------------------------------
# Copyright (c) 2017 Microsoft Corporation.
2 | # 3 | # Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated 4 | # documentation files (the "Software"), to deal in the Software without restriction, including without limitation the 5 | # rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, 6 | # and to permit persons to whom the Software is furnished to do so, subject to the following conditions: 7 | # 8 | # The above copyright notice and this permission notice shall be included in all copies or substantial portions of 9 | # the Software. 10 | # 11 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO 12 | # THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 13 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, 14 | # TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 15 | # SOFTWARE. 
# ===================================================================================================================

from __future__ import absolute_import

import os
from math import sqrt

import numpy as np


def euclidean(a, b):
    """
    Compute the Euclidean distance between two points of the same dimension.

    :param a: First point, as a sequence of numbers
    :param b: Second point, as a sequence of numbers
    :return: The distance between a and b, as a float
    """
    assert len(a) == len(b), 'cannot compute distance when a and b have different shapes'
    # Distinct names for the coordinates avoid shadowing the a/b arguments
    return sqrt(sum((x - y) ** 2 for x, y in zip(a, b)))


def get_rank(x):
    """
    Get a shape's rank.

    :param x: A numpy array (rank = number of dimensions) or a shape tuple (rank = its length)
    :return: The rank, as an int
    :raises ValueError: if the rank of x cannot be determined
    """
    if isinstance(x, np.ndarray):
        return len(x.shape)
    if isinstance(x, tuple):
        return len(x)
    # The exception used to be returned instead of raised
    raise ValueError('Unable to determine rank of type: %s' % str(type(x)))


def check_rank(shape, required_rank):
    """
    Check if the shape's rank equals the expected rank.

    :param shape: Shape tuple; any other type is reported as not matching
    :param required_rank: Expected rank
    :return: True if shape is a tuple of length required_rank, False otherwise
    """
    if isinstance(shape, tuple):
        return len(shape) == required_rank
    return False


def isclose(a, b, atol=1e-01):
    """
    Check if a and b are closer than the absolute tolerance level atol.

    :param a: First value
    :param b: Second value
    :param atol: Absolute tolerance (default: 0.1)
    :return: abs(a - b) < atol
    """
    return abs(a - b) < atol


def ensure_path_exists(path):
    """
    Ensure that the specified path exists on the filesystem, creating it and any
    missing parent directories if needed. Nothing is done when something already
    exists at path.

    :param path: Directory path, absolute or relative to the current working directory
    """
    import errno

    if not os.path.isabs(path):
        path = os.path.abspath(path)
    if not os.path.exists(path):
        try:
            os.makedirs(path)
        except OSError as err:
            # The path may have been created concurrently between the existence
            # check and makedirs; only that case is benign.
            if err.errno != errno.EEXIST:
                raise

# --------------------------------------------------------------------------------
# /malmopy/version.py:
# --------------------------------------------------------------------------------
# Copyright (c) 2017 Microsoft Corporation.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated
# documentation files (the "Software"), to deal in the Software without restriction, including without limitation the
# rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software,
# and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all copies or substantial portions of
# the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO
# THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
# TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
# ===================================================================================================================

# Version string of the malmopy package
VERSION = "0.1.0"

# --------------------------------------------------------------------------------
# /malmopy/visualization/__init__.py:
# --------------------------------------------------------------------------------
# Copyright (c) 2017 Microsoft Corporation.
2 | # 3 | # Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated 4 | # documentation files (the "Software"), to deal in the Software without restriction, including without limitation the 5 | # rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, 6 | # and to permit persons to whom the Software is furnished to do so, subject to the following conditions: 7 | # 8 | # The above copyright notice and this permission notice shall be included in all copies or substantial portions of 9 | # the Software. 10 | # 11 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO 12 | # THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 13 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, 14 | # TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 15 | # SOFTWARE. 16 | # =================================================================================================================== 17 | 18 | from __future__ import absolute_import 19 | 20 | from .visualizer import BaseVisualizer, ConsoleVisualizer, EmptyVisualizer, Visualizable 21 | -------------------------------------------------------------------------------- /malmopy/visualization/tensorboard/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2017 Microsoft Corporation. 
2 | # 3 | # Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated 4 | # documentation files (the "Software"), to deal in the Software without restriction, including without limitation the 5 | # rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, 6 | # and to permit persons to whom the Software is furnished to do so, subject to the following conditions: 7 | # 8 | # The above copyright notice and this permission notice shall be included in all copies or substantial portions of 9 | # the Software. 10 | # 11 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO 12 | # THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 13 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, 14 | # TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 15 | # SOFTWARE. 16 | # =================================================================================================================== 17 | 18 | from __future__ import absolute_import 19 | 20 | from .tensorboard import TensorboardVisualizer, TensorflowConverter 21 | -------------------------------------------------------------------------------- /malmopy/visualization/tensorboard/cntk/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2017 Microsoft Corporation. 
2 | # 3 | # Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated 4 | # documentation files (the "Software"), to deal in the Software without restriction, including without limitation the 5 | # rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, 6 | # and to permit persons to whom the Software is furnished to do so, subject to the following conditions: 7 | # 8 | # The above copyright notice and this permission notice shall be included in all copies or substantial portions of 9 | # the Software. 10 | # 11 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO 12 | # THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 13 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, 14 | # TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 15 | # SOFTWARE. 16 | # =================================================================================================================== 17 | 18 | from __future__ import absolute_import 19 | 20 | from .cntk import CntkConverter 21 | -------------------------------------------------------------------------------- /malmopy/visualization/tensorboard/cntk/cntk.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2017 Microsoft Corporation. 
2 | # 3 | # Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated 4 | # documentation files (the "Software"), to deal in the Software without restriction, including without limitation the 5 | # rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, 6 | # and to permit persons to whom the Software is furnished to do so, subject to the following conditions: 7 | # 8 | # The above copyright notice and this permission notice shall be included in all copies or substantial portions of 9 | # the Software. 10 | # 11 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO 12 | # THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 13 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, 14 | # TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 15 | # SOFTWARE. 
# ===================================================================================================================

from __future__ import absolute_import

import six
import tensorflow as tf
from tensorflow.core.framework import attr_value_pb2, tensor_shape_pb2

from ..tensorboard import TensorflowConverter


class CntkConverter(TensorflowConverter):
    """
    Mirrors the structure of a CNTK network as placeholder operations of a
    Tensorflow graph so that it can be displayed in Tensorboard.
    """

    def convert(self, network, graph):
        """
        Converts a function from CNTK to the Tensorflow graph format

        Args:
            network: object exposing the CNTK function to convert through its
                     ``model`` attribute (e.g. a CntkModel)
            graph: destination Tensorflow graph; it is finalized (made read-only)
                   once the conversion is done
        """
        # Walk every node of the network iteratively, with an explicit stack
        stack = [network.model]
        visited = set()

        while stack:
            node = stack.pop()

            if node in visited:
                continue
            # The set was never filled before, so nodes reachable through several
            # paths were processed again and their inputs pushed again each time.
            visited.add(node)

            try:

                # Function node. Presumably only Function nodes expose root_function;
                # other nodes end up in the AttributeError handler below.
                # NOTE(review): an AttributeError raised by the conversion code itself
                # is caught by that handler too.
                node = node.root_function
                stack.extend(node.inputs)
                try:
                    # TF graph already has the current node
                    graph.get_operation_by_name(node.uid.split('_')[0])
                    continue

                except KeyError:
                    # New network node that has to be converted to TF format
                    # define TF operation attributes based on CNTK network node
                    try:
                        dim_x = tensor_shape_pb2.TensorShapeProto.Dim(size=node.outputs[0].shape[0])
                    except IndexError:
                        dim_x = tensor_shape_pb2.TensorShapeProto.Dim(size=1)
                    try:
                        dim_y = tensor_shape_pb2.TensorShapeProto.Dim(size=node.outputs[0].shape[1])
                    except IndexError:
                        dim_y = tensor_shape_pb2.TensorShapeProto.Dim(size=1)
                    shape = tensor_shape_pb2.TensorShapeProto(dim=(dim_x, dim_y))
                    shape_attr = attr_value_pb2.AttrValue(shape=shape)
                    attrs = {"shape": shape_attr}

                    # Use name scope based on the node's name (e.g. Plus1) to
                    # group the operation and its inputs
                    with graph.name_scope(node.uid) as _:

                        # Create a TF placeholder operation with type, name and shape of the current node
                        op = graph.create_op("Placeholder", inputs=[],
                                             dtypes=[node.outputs[0].dtype], attrs=attrs,
                                             name=node.uid)

                        # Add inputs to the created TF operation
                        for child in node.inputs:
                            name = child.uid
                            try:
                                # The input tensor already exists in the graph
                                tf_input = graph.get_tensor_by_name(name + ":0")
                            except KeyError:
                                # A new tensor that needs to be converted from CNTK to TF
                                shape = self.convert_shape(child.shape)
                                dtype = child.dtype
                                # Create a new placeholder tensor with the corresponding attributes
                                tf_input = tf.placeholder(shape=shape, dtype=dtype, name=name)

                            # Update TF operator's inputs
                            # NOTE(review): _add_input is a private Tensorflow 1.x API
                            op._add_input(tf_input)

                        # Update TF operation's outputs
                        output = node.outputs[0]
                        for o in graph.get_operations():
                            if output.uid in o.name:
                                o._add_input(op.outputs[0])

            except AttributeError:
                # OutputVariable node
                try:
                    if node.is_output:
                        try:
                            # Owner of the node is already added to the TF graph
                            # (ops are created as "<uid>/<uid>" by the name scope above)
                            owner_name = node.owner.uid + '/' + node.owner.uid
                            graph.get_operation_by_name(owner_name)
                        except KeyError:
                            # Unknown network node
                            stack.append(node.owner)

                except AttributeError:
                    pass

        # Add missing connections in the graph
        CntkConverter.update_outputs(graph.get_operations())
        graph.finalize()

    @staticmethod
    def convert_shape(shape):
        """
        Convert a CNTK shape to a Tensorflow placeholder shape of rank >= 2.

        A scalar shape () becomes (1, 1) and a rank 1 shape (n,) becomes (n, 1);
        higher ranks are returned unchanged.
        """
        if len(shape) == 0:
            shape = (1, 1)
        else:
            if len(shape) == 1:
                shape += (1,)
        return shape

    @staticmethod
    def update_outputs(ops):
        """Updates the inputs/outputs of the Tensorflow operations
        by adding missing connections

        For every pair (i, j) with i < j, op i receives op j's first output as an
        input when the second '/'-separated component of op i's name is contained
        in the second component of op j's name.
        NOTE(review): assumes every op name contains at least one '/', which holds
        for ops created inside the name scopes of convert().

        Args:
            ops: a list of Tensorflow operations
        """
        for i in six.moves.range(len(ops)):
            for j in six.moves.range(i + 1, len(ops)):
                if ops[i].name.split('/')[1] in ops[j].name.split('/')[1]:
                    ops[i]._add_input(ops[j].outputs[0])

# --------------------------------------------------------------------------------
# /malmopy/visualization/tensorboard/tensorboard.py:
# --------------------------------------------------------------------------------
# Copyright (c) 2017 Microsoft Corporation.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated
# documentation files (the "Software"), to deal in the Software without restriction, including without limitation the
# rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software,
# and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all copies or substantial portions of
# the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO
# THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
# TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
# ===================================================================================================================

from __future__ import absolute_import

import six
import tensorflow as tf
from tensorflow.core.framework.summary_pb2 import Summary

from ..visualizer import BaseVisualizer


class TensorboardVisualizer(BaseVisualizer):
    """
    Visualize the generated results in Tensorboard.

    Plain entries are written as scalar (``simple_value``) summaries. Entries
    passed with an ``image`` keyword argument are JPEG-encoded and written as
    image summaries. Call :meth:`initialize` before adding entries and
    :meth:`close` (or use the instance as a context manager) when done.
    """

    def __init__(self):
        super(TensorboardVisualizer, self).__init__()

        gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0.01)
        self._config = tf.ConfigProto(gpu_options=gpu_options)
        self._session = tf.Session(config=self._config)
        self._train_writer = None

        # Private graph/session used only for JPEG encoding; built on first use
        # by _encode_jpeg so that encoding never adds ops to the session graph
        # (which a converter may already have finalized).
        self._jpeg_session = None
        self._jpeg_input = None
        self._jpeg_output = None

    def initialize(self, logdir, model, converter=None):
        """Create the summary writer.

        Args:
            logdir: directory the event files are written to.
            model: network handed to ``converter``; unused when converter is None.
            converter: optional TensorflowConverter that converts ``model`` into the
                session graph, which is then written alongside the event file.
        """
        assert logdir is not None, "logdir cannot be None"
        assert isinstance(logdir, six.string_types), "logdir should be a string"

        if converter is not None:
            assert isinstance(converter, TensorflowConverter), \
                "converter should derive from TensorflowConverter"
            converter.convert(model, self._session.graph)

        self._train_writer = tf.summary.FileWriter(logdir=logdir,
                                                   graph=self._session.graph,
                                                   flush_secs=30)

    def _encode_jpeg(self, image):
        """Return the JPEG-encoded bytes of ``image``.

        Assumes ``image`` is a (height, width, channels) uint8 array, as
        tf.image.encode_jpeg requires - TODO confirm at call sites.
        """
        if self._jpeg_session is None:
            graph = tf.Graph()
            with graph.as_default():
                self._jpeg_input = tf.placeholder(tf.uint8, shape=(None, None, None))
                self._jpeg_output = tf.image.encode_jpeg(self._jpeg_input,
                                                         optimize_size=True,
                                                         quality=80)
            self._jpeg_session = tf.Session(graph=graph, config=self._config)
        # Evaluate the op: encode_jpeg only returns a symbolic Tensor.
        return self._jpeg_session.run(self._jpeg_output,
                                      feed_dict={self._jpeg_input: image})

    def add_entry(self, index, tag, value, **kwargs):
        """Write one summary entry at step ``index``.

        Args:
            index: global step the entry is recorded at.
            tag: Tensorboard tag of the entry.
            value: a scalar, or an image array when ``image`` is in ``kwargs``.
            **kwargs: pass ``image=<anything>`` to write ``value`` as an image.
        """
        if "image" in kwargs and value is not None:
            image = Summary.Image(width=value.shape[1],
                                  height=value.shape[0],
                                  colorspace=value.shape[2],
                                  encoded_image_string=self._encode_jpeg(value))
            # Summary.value holds Summary.Value messages; the image must be wrapped.
            summary_value = Summary.Value(tag=tag, image=image)
        else:
            summary_value = Summary.Value(tag=tag, simple_value=value)

        self._train_writer.add_summary(Summary(value=[summary_value]), index)

    def close(self):
        """Flush/close the summary writer and release the TF sessions."""
        if self._train_writer is not None:
            self._train_writer.close()
        if self._jpeg_session is not None:
            self._jpeg_session.close()
            self._jpeg_session = None
        self._session.close()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.close()


class TensorflowConverter(object):
    """Interface for objects that convert a network into a Tensorflow graph."""

    def convert(self, network, graph):
        """Populate ``graph`` from ``network``; must be overridden."""
        raise NotImplementedError()

# --------------------------------------------------------------------------------
# /malmopy/visualization/visualizer.py:
# --------------------------------------------------------------------------------
# Copyright (c) 2017 Microsoft Corporation.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated
# documentation files (the "Software"), to deal in the Software without restriction, including without limitation the
# rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software,
# and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all copies or substantial portions of
# the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO
# THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
# TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
# ===================================================================================================================

from __future__ import absolute_import

from os import path


class Visualizable(object):
    """Mixin that lets an object report values to an optional visualizer."""

    def __init__(self, visualizer=None):
        if visualizer is not None:
            assert isinstance(visualizer, BaseVisualizer), "visualizer should derive from BaseVisualizer"

        self._visualizer = visualizer

    def visualize(self, index, tag, value, **kwargs):
        """Send ``(index, tag, value, kwargs)`` to the visualizer, if one is attached."""
        if self._visualizer is not None:
            self._visualizer << (index, tag, value, kwargs)

    @property
    def can_visualize(self):
        """True when a visualizer is attached."""
        return self._visualizer is not None


class BaseVisualizer(object):
    """ Provide a unified interface for observing the training progress """

    def add_entry(self, index, key, result, **kwargs):
        raise NotImplementedError()

    def __lshift__(self, other):
        """Stream an entry: ``visualizer << (index, key, value[, kwargs])``.

        ``key`` is converted with ``str``. An optional fourth element that is a
        dict is forwarded to :meth:`add_entry` as keyword arguments (this is what
        ``Visualizable.visualize`` sends). Returns ``self`` so entries can be chained.

        Raises:
            ValueError: if ``other`` is not a tuple of at least three items.
        """
        if not isinstance(other, tuple):
            raise ValueError("Trying to use stream operator without a tuple (index, key, value)")
        if len(other) < 3:
            raise ValueError("Provided tuple should be of the form (index, key, value)")

        extra = other[3] if len(other) >= 4 and isinstance(other[3], dict) else {}
        self.add_entry(other[0], str(other[1]), other[2], **extra)
        return self


class EmptyVisualizer(BaseVisualizer):
    """ A boilerplate visualizer that does nothing """

    def add_entry(self, index, key, result, **kwargs):
        pass


class ConsoleVisualizer(BaseVisualizer):
    """ Print visualization to stdout, by default as:
        [prefix] index : key -> value
    """
    CONSOLE_DEFAULT_FORMAT = "[%s] %d : %s -> %.3f"

    def __init__(self, format=None, prefix=None):
        super(ConsoleVisualizer, self).__init__()
        self._format = format or ConsoleVisualizer.CONSOLE_DEFAULT_FORMAT
        self._prefix = prefix or '-'

    def add_entry(self, index, key, result, **kwargs):
        print(self._format % (self._prefix, index, key, result))


class CsvVisualizer(BaseVisualizer):
    """ Write data to file. The following formats are supported: CSV, JSON, Excel.

    Entries are accumulated in memory as ``{index: {key: result}}`` (one row per
    index) and written when :meth:`close` is called.
    """
    def __init__(self, output_file, override=False):
        if path.exists(output_file) and not override:
            raise Exception('%s already exists and override is False' % output_file)

        super(CsvVisualizer, self).__init__()
        self._file = output_file
        self._data = {}

    def add_entry(self, index, key, result, **kwargs):
        # First entry for an index creates its row instead of raising KeyError.
        row = self._data.setdefault(index, {})
        if key in row:
            print('Warning: Found previous value for %s in visualizer' % key)

        row[key] = result

    def close(self, format='csv'):
        """Write the collected rows to the output file.

        Args:
            format: 'csv', 'json'; any other value writes an Excel file.
        """
        import pandas as pd

        frame = pd.DataFrame.from_dict(self._data, orient='index')
        if format == 'csv':
            frame.to_csv(self._file)
        elif format == 'json':
            frame.to_json(self._file)
        else:
            # Passing the path directly replaces ExcelWriter(...).save(); save()
            # was removed in pandas 2.0.
            frame.to_excel(self._file)

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        # Return None: a truthy return value would suppress exceptions raised
        # inside the ``with`` block.
        self.close()

# --------------------------------------------------------------------------------
# /samples/atari/gym_atari_dqn.py:
# --------------------------------------------------------------------------------
# Copyright (c) 2017 Microsoft Corporation.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated
# documentation files (the "Software"), to deal in the Software without restriction, including without limitation the
# rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software,
# and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all copies or substantial portions of
# the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO
# THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
# TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
# ===================================================================================================================

from argparse import ArgumentParser
from datetime import datetime
from subprocess import Popen

from malmopy.agent import QLearnerAgent, TemporalMemory
from malmopy.environment.gym import GymEnvironment

try:
    from malmopy.visualization.tensorboard import TensorboardVisualizer
    from malmopy.visualization.tensorboard.cntk import CntkConverter

    TENSORBOARD_AVAILABLE = True
except ImportError:
    print('Cannot import tensorboard, using ConsoleVisualizer.')
    from malmopy.visualization import ConsoleVisualizer

    TENSORBOARD_AVAILABLE = False


ROOT_FOLDER = 'results/baselines/%s/dqn/%s-%s'
EPOCH_SIZE = 250000


def visualize_training(visualizer, step, rewards, tag='Training'):
    """Report per-episode reward statistics to ``visualizer`` at ``step``.

    Args:
        visualizer: object exposing ``add_entry(index, key, value)``.
        step: global step the entries are recorded at.
        rewards: rewards collected during the episode (one per env step).
        tag: prefix of the reported keys.
    """
    if not rewards:
        # max()/min() would raise on an empty episode; nothing to report.
        return
    visualizer.add_entry(step, '%s/reward per episode' % tag, sum(rewards))
    visualizer.add_entry(step, '%s/max.reward' % tag, max(rewards))
    visualizer.add_entry(step, '%s/min.reward' % tag, min(rewards))
    # NOTE(review): run_experiment appends one reward per env.do() call, so
    # len(rewards) already counts the actions; the -1 looks like an off-by-one.
    # Confirm intent before relying on this metric.
    visualizer.add_entry(step, '%s/actions per episode' % tag, len(rewards)-1)


def run_experiment(environment, backend, device_id, max_epoch, record, logdir,
                   visualizer):
    """Train a DQN agent on a Gym environment.

    Runs ``max_epoch * EPOCH_SIZE`` environment steps, logging episode
    statistics on every episode end and saving the model after every epoch.

    Args:
        environment: Gym environment id.
        backend: 'cntk' or anything else for chainer.
        device_id: GPU device id passed to the model (-1 presumably means CPU).
        max_epoch: number of epochs to train.
        record: when True, runs are monitored/recorded under ``logdir``.
        logdir: output directory for monitoring.
        visualizer: visualizer receiving training statistics.
    """
    env = GymEnvironment(environment,
                         monitoring_path=logdir if record else None)

    if backend == 'cntk':
        from malmopy.model.cntk import QNeuralNetwork as CntkDQN
        model = CntkDQN((4, 84, 84), env.available_actions, momentum=0.95,
                        device_id=device_id, visualizer=visualizer)
    else:
        from malmopy.model.chainer import DQNChain, QNeuralNetwork as ChainerDQN
        chain = DQNChain((4, 84, 84), env.available_actions)
        target_chain = DQNChain((4, 84, 84), env.available_actions)
        model = ChainerDQN(chain, target_chain,
                           momentum=0.95, device_id=device_id)

    memory = TemporalMemory(1000000, model.input_shape[1:])
    agent = QLearnerAgent("DQN Agent", env.available_actions, model, memory,
                          0.99, 32, train_after=10000, reward_clipping=(-1, 1),
                          visualizer=visualizer)

    state = env.reset()
    reward = 0
    agent_done = False
    viz_rewards = []

    max_training_steps = max_epoch * EPOCH_SIZE
    for step in range(1, max_training_steps + 1):

        # check if env needs reset
        if env.done:
            visualize_training(visualizer, step, viz_rewards)
            agent.inject_summaries(step)
            viz_rewards = []
            state = env.reset()

        # select an action
        action = agent.act(state, reward, agent_done, is_training=True)

        # take a step
        state, reward, agent_done = env.do(action)
        viz_rewards.append(reward)

        if (step % EPOCH_SIZE) == 0:
            model.save('%s-%s-dqn_%d.model' %
                       (backend, environment, step // EPOCH_SIZE))


if __name__ == '__main__':
    arg_parser = ArgumentParser(description='OpenAI Gym DQN example')
    arg_parser.add_argument('-b', '--backend', type=str, default='cntk',
                            choices=['cntk', 'chainer'],
                            help='Neural network backend to use.')
    arg_parser.add_argument('-d', '--device', type=int, default=-1,
                            help='GPU device on which to run the experiment.')
    arg_parser.add_argument('-r', '--record', action='store_true',
                            help='Setting this will record runs')
    arg_parser.add_argument('-e', '--epochs', type=int, default=50,
                            help='Number of epochs. One epoch is 250k actions.')
    arg_parser.add_argument('-p', '--port', type=int, default=6006,
                            help='Port for running tensorboard.')
    arg_parser.add_argument('env', type=str, metavar='environment',
                            nargs='?', default='Breakout-v3',
                            help='Gym environment to run')

    args = arg_parser.parse_args()

    logdir = ROOT_FOLDER % (args.env, args.backend, datetime.utcnow().isoformat())
    if TENSORBOARD_AVAILABLE:
        visualizer = TensorboardVisualizer()
        visualizer.initialize(logdir, None)
        print('Starting tensorboard ...')
        p = Popen(['tensorboard', '--logdir=results', '--port=%d' % args.port])

    else:
        visualizer = ConsoleVisualizer()

    print('Starting experiment')
    try:
        run_experiment(args.env, args.backend, args.device, args.epochs,
                       args.record, logdir, visualizer)
    finally:
        # Flush pending summaries (the Tensorboard writer only flushes every
        # 30 s); ConsoleVisualizer has no close().
        close_visualizer = getattr(visualizer, 'close', None)
        if close_visualizer is not None:
            close_visualizer()

# --------------------------------------------------------------------------------
# /setup.py:
# --------------------------------------------------------------------------------
import os
import sys

# setuptools' own setup() understands install_requires / extras_require;
# distutils is deprecated (PEP 632) and removed from the stdlib in Python 3.12.
from setuptools import find_packages, setup

sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'malmopy'))
from version import VERSION

extras = {
    'chainer': ['chainer>=1.21.0'],
    'gym': ['gym[atari]>=0.7.0'],
    'tensorflow': ['tensorflow'],
}

# Meta dependency groups.
extras['all'] = [dep for group_deps in extras.values() for dep in group_deps]

setup(
    name='malmopy',
    version=VERSION,

    packages=[package for package in find_packages()
              if package.startswith('malmopy')],

    url='https://github.com/Microsoft/malmo-challenge',
    license='MIT',
    author='Microsoft Research Cambridge',
    author_email='',
    description='Malmo Collaborative AI Challenge task and example code',
    install_requires=['future', 'numpy>=1.11.0', 'six>=0.10.0', 'pandas', 'Pillow'],
    extras_require=extras
)

# --------------------------------------------------------------------------------