├── .gitattributes
├── .gitignore
├── LICENSE
├── README.md
├── SECURITY.md
├── ai_challenge
├── README.md
└── pig_chase
│ ├── README.md
│ ├── agent.py
│ ├── common.py
│ ├── environment.py
│ ├── evaluation.py
│ ├── pig-chase-overview.png
│ ├── pig_chase.xml
│ ├── pig_chase_baseline.py
│ ├── pig_chase_dqn.py
│ ├── pig_chase_dqn_top_down.py
│ ├── pig_chase_eval_sample.py
│ └── pig_chase_human_vs_agent.py
├── docker
├── README.md
├── malmo
│ ├── Dockerfile
│ ├── options.txt
│ └── run.sh
├── malmopy-ai-challenge
│ └── docker-compose.yml
├── malmopy-chainer-cpu
│ └── Dockerfile
├── malmopy-chainer-gpu
│ └── Dockerfile
├── malmopy-cntk-cpu-py27
│ └── Dockerfile
└── malmopy-cntk-gpu-py27
│ └── Dockerfile
├── malmopy
├── README.md
├── __init__.py
├── agent
│ ├── __init__.py
│ ├── agent.py
│ ├── astar.py
│ ├── explorer.py
│ ├── gui.py
│ └── qlearner.py
├── environment
│ ├── __init__.py
│ ├── environment.py
│ ├── gym
│ │ ├── __init__.py
│ │ └── gym.py
│ └── malmo
│ │ ├── __init__.py
│ │ └── malmo.py
├── model
│ ├── __init__.py
│ ├── chainer
│ │ ├── __init__.py
│ │ └── qlearning.py
│ ├── cntk
│ │ ├── __init__.py
│ │ ├── base.py
│ │ └── qlearning.py
│ └── model.py
├── util
│ ├── __init__.py
│ ├── images.py
│ └── util.py
├── version.py
└── visualization
│ ├── __init__.py
│ ├── tensorboard
│ ├── __init__.py
│ ├── cntk
│ │ ├── __init__.py
│ │ └── cntk.py
│ └── tensorboard.py
│ └── visualizer.py
├── samples
└── atari
│ └── gym_atari_dqn.py
└── setup.py
/.gitattributes:
--------------------------------------------------------------------------------
1 | ###############################################################################
2 | # Set default behavior to automatically normalize line endings.
3 | ###############################################################################
4 | * text=auto
5 |
6 | ###############################################################################
7 | # Set default behavior for command prompt diff.
8 | #
9 | # This is needed for earlier builds of msysgit that do not have it on by
10 | # default for csharp files.
11 | # Note: This is only used by command line
12 | ###############################################################################
13 | #*.cs diff=csharp
14 |
15 | ###############################################################################
16 | # Set the merge driver for project and solution files
17 | #
18 | # Merging from the command prompt will add diff markers to the files if there
19 | # are conflicts (Merging from VS is not affected by the settings below, in VS
20 | # the diff markers are never inserted). Diff markers may cause the following
21 | # file extensions to fail to load in VS. An alternative would be to treat
22 | # these files as binary and thus will always conflict and require user
23 | # intervention with every merge. To do so, just uncomment the entries below
24 | ###############################################################################
25 | #*.sln merge=binary
26 | #*.csproj merge=binary
27 | #*.vbproj merge=binary
28 | #*.vcxproj merge=binary
29 | #*.vcproj merge=binary
30 | #*.dbproj merge=binary
31 | #*.fsproj merge=binary
32 | #*.lsproj merge=binary
33 | #*.wixproj merge=binary
34 | #*.modelproj merge=binary
35 | #*.sqlproj merge=binary
36 | #*.wwaproj merge=binary
37 |
38 | ###############################################################################
39 | # behavior for image files
40 | #
41 | # image files are treated as binary by default.
42 | ###############################################################################
43 | #*.jpg binary
44 | #*.png binary
45 | #*.gif binary
46 |
47 | ###############################################################################
48 | # diff behavior for common document formats
49 | #
50 | # Convert binary document formats to text before diffing them. This feature
51 | # is only available from the command line. Turn it on by uncommenting the
52 | # entries below.
53 | ###############################################################################
54 | #*.doc diff=astextplain
55 | #*.DOC diff=astextplain
56 | #*.docx diff=astextplain
57 | #*.DOCX diff=astextplain
58 | #*.dot diff=astextplain
59 | #*.DOT diff=astextplain
60 | #*.pdf diff=astextplain
61 | #*.PDF diff=astextplain
62 | #*.rtf diff=astextplain
63 | #*.RTF diff=astextplain
64 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | ## Ignore Visual Studio temporary files, build results, and
2 | ## files generated by popular Visual Studio add-ons.
3 |
4 | # User-specific files
5 | *.suo
6 | *.user
7 | *.userosscache
8 | *.sln.docstates
9 |
10 | # User-specific files (MonoDevelop/Xamarin Studio)
11 | *.userprefs
12 |
13 | # Build results
14 | [Dd]ebug/
15 | [Dd]ebugPublic/
16 | [Rr]elease/
17 | [Rr]eleases/
18 | x64/
19 | x86/
20 | bld/
21 | [Bb]in/
22 | [Oo]bj/
23 | [Ll]og/
24 |
25 | # Visual Studio 2015 cache/options directory
26 | .vs/
27 | # Uncomment if you have tasks that create the project's static files in wwwroot
28 | #wwwroot/
29 |
30 | # MSTest test Results
31 | [Tt]est[Rr]esult*/
32 | [Bb]uild[Ll]og.*
33 |
34 | # NUNIT
35 | *.VisualState.xml
36 | TestResult.xml
37 |
38 | # Build Results of an ATL Project
39 | [Dd]ebugPS/
40 | [Rr]eleasePS/
41 | dlldata.c
42 |
43 | # DNX
44 | project.lock.json
45 | project.fragment.lock.json
46 | artifacts/
47 |
48 | *_i.c
49 | *_p.c
50 | *_i.h
51 | *.ilk
52 | *.meta
53 | *.obj
54 | *.pch
55 | *.pdb
56 | *.pgc
57 | *.pgd
58 | *.rsp
59 | *.sbr
60 | *.tlb
61 | *.tli
62 | *.tlh
63 | *.tmp
64 | *.tmp_proj
65 | *.log
66 | *.vspscc
67 | *.vssscc
68 | .builds
69 | *.pidb
70 | *.svclog
71 | *.scc
72 |
73 | # Chutzpah Test files
74 | _Chutzpah*
75 |
76 | # Visual C++ cache files
77 | ipch/
78 | *.aps
79 | *.ncb
80 | *.opendb
81 | *.opensdf
82 | *.sdf
83 | *.cachefile
84 | *.VC.db
85 | *.VC.VC.opendb
86 |
87 | # Visual Studio profiler
88 | *.psess
89 | *.vsp
90 | *.vspx
91 | *.sap
92 |
93 | # TFS 2012 Local Workspace
94 | $tf/
95 |
96 | # Guidance Automation Toolkit
97 | *.gpState
98 |
99 | # ReSharper is a .NET coding add-in
100 | _ReSharper*/
101 | *.[Rr]e[Ss]harper
102 | *.DotSettings.user
103 |
104 | # JustCode is a .NET coding add-in
105 | .JustCode
106 |
107 | # TeamCity is a build add-in
108 | _TeamCity*
109 |
110 | # DotCover is a Code Coverage Tool
111 | *.dotCover
112 |
113 | # NCrunch
114 | _NCrunch_*
115 | .*crunch*.local.xml
116 | nCrunchTemp_*
117 |
118 | # MightyMoose
119 | *.mm.*
120 | AutoTest.Net/
121 |
122 | # Web workbench (sass)
123 | .sass-cache/
124 |
125 | # Installshield output folder
126 | [Ee]xpress/
127 |
128 | # DocProject is a documentation generator add-in
129 | DocProject/buildhelp/
130 | DocProject/Help/*.HxT
131 | DocProject/Help/*.HxC
132 | DocProject/Help/*.hhc
133 | DocProject/Help/*.hhk
134 | DocProject/Help/*.hhp
135 | DocProject/Help/Html2
136 | DocProject/Help/html
137 |
138 | # Click-Once directory
139 | publish/
140 |
141 | # Publish Web Output
142 | *.[Pp]ublish.xml
143 | *.azurePubxml
144 | # TODO: Comment the next line if you want to checkin your web deploy settings
145 | # but database connection strings (with potential passwords) will be unencrypted
146 | #*.pubxml
147 | *.publishproj
148 |
149 | # Microsoft Azure Web App publish settings. Comment the next line if you want to
150 | # checkin your Azure Web App publish settings, but sensitive information contained
151 | # in these scripts will be unencrypted
152 | PublishScripts/
153 |
154 | # NuGet Packages
155 | *.nupkg
156 | # The packages folder can be ignored because of Package Restore
157 | **/packages/*
158 | # except build/, which is used as an MSBuild target.
159 | !**/packages/build/
160 | # Uncomment if necessary however generally it will be regenerated when needed
161 | #!**/packages/repositories.config
162 | # NuGet v3's project.json files produces more ignoreable files
163 | *.nuget.props
164 | *.nuget.targets
165 |
166 | # Microsoft Azure Build Output
167 | csx/
168 | *.build.csdef
169 |
170 | # Microsoft Azure Emulator
171 | ecf/
172 | rcf/
173 |
174 | # Windows Store app package directories and files
175 | AppPackages/
176 | BundleArtifacts/
177 | Package.StoreAssociation.xml
178 | _pkginfo.txt
179 |
180 | # Visual Studio cache files
181 | # files ending in .cache can be ignored
182 | *.[Cc]ache
183 | # but keep track of directories ending in .cache
184 | !*.[Cc]ache/
185 |
186 | # Others
187 | ClientBin/
188 | ~$*
189 | *~
190 | *.dbmdl
191 | *.dbproj.schemaview
192 | *.jfm
193 | *.pfx
194 | *.publishsettings
195 | node_modules/
196 | orleans.codegen.cs
197 |
198 | # Since there are multiple workflows, uncomment next line to ignore bower_components
199 | # (https://github.com/github/gitignore/pull/1529#issuecomment-104372622)
200 | #bower_components/
201 |
202 | # RIA/Silverlight projects
203 | Generated_Code/
204 |
205 | # Backup & report files from converting an old project file
206 | # to a newer Visual Studio version. Backup files are not needed,
207 | # because we have git ;-)
208 | _UpgradeReport_Files/
209 | Backup*/
210 | UpgradeLog*.XML
211 | UpgradeLog*.htm
212 |
213 | # SQL Server files
214 | *.mdf
215 | *.ldf
216 |
217 | # Business Intelligence projects
218 | *.rdl.data
219 | *.bim.layout
220 | *.bim_*.settings
221 |
222 | # Microsoft Fakes
223 | FakesAssemblies/
224 |
225 | # GhostDoc plugin setting file
226 | *.GhostDoc.xml
227 |
228 | # Node.js Tools for Visual Studio
229 | .ntvs_analysis.dat
230 |
231 | # Visual Studio 6 build log
232 | *.plg
233 |
234 | # Visual Studio 6 workspace options file
235 | *.opt
236 |
237 | # Visual Studio LightSwitch build output
238 | **/*.HTMLClient/GeneratedArtifacts
239 | **/*.DesktopClient/GeneratedArtifacts
240 | **/*.DesktopClient/ModelManifest.xml
241 | **/*.Server/GeneratedArtifacts
242 | **/*.Server/ModelManifest.xml
243 | _Pvt_Extensions
244 |
245 | # Paket dependency manager
246 | .paket/paket.exe
247 | paket-files/
248 |
249 | # FAKE - F# Make
250 | .fake/
251 |
252 | # JetBrains Rider
253 | .idea/
254 | *.sln.iml
255 |
256 | # CodeRush
257 | .cr/
258 |
259 | # Python Tools for Visual Studio (PTVS)
260 | __pycache__/
261 | *.pyc
262 |
263 | # Tests cache
264 | */tests/.cache/v/cache/
265 |
266 | # Library
267 | *.pyd
268 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) Microsoft Corporation. All rights reserved.
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # The Malmo Collaborative AI Challenge
2 |
3 | This repository contains the task definition and example code for the [Malmo Collaborative AI Challenge](https://www.microsoft.com/en-us/research/academic-program/collaborative-ai-challenge/).
4 | This challenge is organized to encourage research in collaborative AI - to work towards AI agents
5 | that learn to collaborate to solve problems and achieve goals.
6 | You can find additional details, including terms and conditions, prizes and information on how to participate at the [Challenge Homepage](https://www.microsoft.com/en-us/research/academic-program/collaborative-ai-challenge/).
7 |
8 | [](https://gitter.im/malmo-challenge/Lobby?utm_source=badge&utm_medium=badge&utm_campaign=pr-badge&utm_content=badge)
9 | [](https://github.com/Microsoft/malmo-challenge/blob/master/LICENSE)
10 |
11 | ----
12 |
13 | **Notes for challenge participants:** Once you and your team decide to participate in the challenge, please make sure to register your team at our [Registration Page](https://www.surveygizmo.com/s3/3299773/The-Collaborative-AI-Challenge). On the registration form, you need to provide a link to the GitHub repository that will
14 | contain your solution. We recommend that you fork this repository (learn how),
15 | and provide the address of the forked repo. You can then update your submission as you make progress on the challenge task.
16 | We will consider the version of the code on branch master at the time of the submission deadline as your challenge submission. Your submission needs to contain code in working order, a 1-page description of your approach, and a 1-minute video that shows off your agent. Please see the [challenge terms and conditions]() for further details.
17 |
18 | ----
19 |
20 | **Jump to:**
21 |
22 | - [Installation](#installation)
23 | - [Prerequisites](#prerequisites)
24 | - [Minimal installation](#minimal-installation)
25 | - [Optional extensions](#optional-extensions)
26 |
27 | - [Getting started](#getting-started)
28 | - [Play the challenge task](#play-the-challenge-task)
29 | - [Run your first experiment](#run-your-first-experiment)
30 |
31 | - [Next steps](#next-steps)
32 | - [Run an experiment in Docker on Azure](#run-an-experiment-in-docker-on-azure)
33 |   - [Compare your results against other teams](#compare-your-results-against-other-teams)
34 | - [Resources](#resources)
35 |
36 | # Installation
37 |
38 | ## Prerequisites
39 |
40 | - [Python](https://www.python.org/) 2.7+ (recommended) or 3.5+
41 | - [Project Malmo](https://github.com/Microsoft/malmo) - we recommend downloading the [Malmo-0.21.0 release](https://github.com/Microsoft/malmo/releases) and installing dependencies for [Windows](https://github.com/Microsoft/malmo/blob/master/doc/install_windows.md), [Linux](https://github.com/Microsoft/malmo/blob/master/doc/install_linux.md) or [MacOS](https://github.com/Microsoft/malmo/blob/master/doc/install_macosx.md). Test your Malmo installation by [launching Minecraft with Malmo](https://github.com/Microsoft/malmo#launching-minecraft-with-our-mod) and [launching an agent](https://github.com/Microsoft/malmo#launch-an-agent).
42 |
43 | ## Minimal installation
44 |
45 | ```
46 | pip install -e git+https://github.com/Microsoft/malmo-challenge#egg=malmopy
47 | ```
48 |
49 | or
50 |
51 | ```
52 | git clone https://github.com/Microsoft/malmo-challenge
53 | cd malmo-challenge
54 | pip install -e .
55 | ```
56 |
57 | ## Optional extensions
58 |
59 | Some of the example code uses additional dependencies to provide 'extra' functionality. These can be installed using:
60 |
61 | ```
62 | pip install -e '.[extra1, extra2]'
63 | ```
64 | For example to install gym and chainer:
65 |
66 | ```
67 | pip install -e '.[gym]'
68 | ```
69 |
70 | Or to install all extras:
71 |
72 | ```
73 | pip install -e '.[all]'
74 | ```
75 |
76 | The following extras are available:
77 | - `gym`: [OpenAI Gym](https://gym.openai.com/) is an interface to a wide range of reinforcement learning environments. Installing this extra enables the Atari example agents in [samples/atari](samples/atari) to train on the gym environments. *Note that OpenAI gym atari environments are currently not available on Windows.*
78 | - `tensorflow`: [TensorFlow](https://www.tensorflow.org/) is a popular deep learning framework developed by Google. In our examples it enables visualizations through [TensorBoard](https://www.tensorflow.org/get_started/summaries_and_tensorboard).
79 |
80 |
81 | # Getting started
82 |
83 | ## Play the challenge task
84 |
85 | The challenge task takes the form of a mini game, called Pig Chase. Learn about the game, and try playing it yourself on our [Pig Chase Challenge page](ai_challenge/pig_chase/README.md).
86 |
87 | ## Run your first experiment
88 |
89 | See how to [run your first baseline experiment](ai_challenge/pig_chase/README.md#run-your-first-experiment) on the [Pig Chase Challenge page](ai_challenge/pig_chase/README.md).
90 |
91 | # Next steps
92 |
93 | ## Run an experiment in Docker on Azure
94 |
95 | Docker is a virtualization platform that makes it easy to deploy software with all its dependencies.
96 | We use docker to run experiments locally or in the cloud. Details on how to run an example experiment using docker are in the [docker README](docker/README.md).
97 |
98 | ## Compare your results against other teams:
99 |
100 | We provide you a [leaderboard website](https://malmo-leaderboard.azurewebsites.net/) where you can compare your results against the other participants.
101 |
102 |
103 | ## Resources
104 |
105 | - [Malmo Platform Tutorial](https://github.com/Microsoft/malmo/blob/master/Malmo/samples/Python_examples/Tutorial.pdf)
106 | - [Azure Portal](https://portal.azure.com/)
107 | - [Docker Documentation](https://docs.docker.com/)
108 | - [Docker Machine on Azure](https://docs.microsoft.com/en-us/azure/virtual-machines/virtual-machines-linux-docker-machine)
109 | - [CNTK Tutorials](https://www.microsoft.com/en-us/research/product/cognitive-toolkit/tutorials/)
110 | - [CNTK Documentation](https://github.com/Microsoft/CNTK/wiki)
111 | - [Chainer Documentation](http://docs.chainer.org/en/stable/)
112 | - [TensorBoard Documentation](https://www.tensorflow.org/get_started/summaries_and_tensorboard)
113 |
--------------------------------------------------------------------------------
/SECURITY.md:
--------------------------------------------------------------------------------
1 |
2 |
3 | ## Security
4 |
5 | Microsoft takes the security of our software products and services seriously, which includes all source code repositories managed through our GitHub organizations, which include [Microsoft](https://github.com/Microsoft), [Azure](https://github.com/Azure), [DotNet](https://github.com/dotnet), [AspNet](https://github.com/aspnet), [Xamarin](https://github.com/xamarin), and [our GitHub organizations](https://opensource.microsoft.com/).
6 |
7 | If you believe you have found a security vulnerability in any Microsoft-owned repository that meets [Microsoft's definition of a security vulnerability](https://aka.ms/opensource/security/definition), please report it to us as described below.
8 |
9 | ## Reporting Security Issues
10 |
11 | **Please do not report security vulnerabilities through public GitHub issues.**
12 |
13 | Instead, please report them to the Microsoft Security Response Center (MSRC) at [https://msrc.microsoft.com/create-report](https://aka.ms/opensource/security/create-report).
14 |
15 | If you prefer to submit without logging in, send email to [secure@microsoft.com](mailto:secure@microsoft.com). If possible, encrypt your message with our PGP key; please download it from the [Microsoft Security Response Center PGP Key page](https://aka.ms/opensource/security/pgpkey).
16 |
17 | You should receive a response within 24 hours. If for some reason you do not, please follow up via email to ensure we received your original message. Additional information can be found at [microsoft.com/msrc](https://aka.ms/opensource/security/msrc).
18 |
19 | Please include the requested information listed below (as much as you can provide) to help us better understand the nature and scope of the possible issue:
20 |
21 | * Type of issue (e.g. buffer overflow, SQL injection, cross-site scripting, etc.)
22 | * Full paths of source file(s) related to the manifestation of the issue
23 | * The location of the affected source code (tag/branch/commit or direct URL)
24 | * Any special configuration required to reproduce the issue
25 | * Step-by-step instructions to reproduce the issue
26 | * Proof-of-concept or exploit code (if possible)
27 | * Impact of the issue, including how an attacker might exploit the issue
28 |
29 | This information will help us triage your report more quickly.
30 |
31 | If you are reporting for a bug bounty, more complete reports can contribute to a higher bounty award. Please visit our [Microsoft Bug Bounty Program](https://aka.ms/opensource/security/bounty) page for more details about our active programs.
32 |
33 | ## Preferred Languages
34 |
35 | We prefer all communications to be in English.
36 |
37 | ## Policy
38 |
39 | Microsoft follows the principle of [Coordinated Vulnerability Disclosure](https://aka.ms/opensource/security/cvd).
40 |
41 |
42 |
--------------------------------------------------------------------------------
/ai_challenge/README.md:
--------------------------------------------------------------------------------
1 | # Malmo Collaborative AI Challenge
2 |
3 | This folder contains task definitions for the Malmo Collaborative
4 | AI Challenge. For installation instructions see [Installation](../README.md#installation).
5 |
6 | ## Available challenges
7 |
8 | - [Malmo Collaborative AI Challenge - Pig Chase](pig_chase/README.md) : Try to build collaborative
9 | AI agents trying to catch a pig.
10 |
11 | ## Further reading
12 |
13 | Once you have familiarized yourself with a challenge task, you may want to head back to the [Overview Page](../README.md) to learn about [Installation](../README.md#installation) and [Getting started](../README.md#getting-started). Or dive into the code examples to learn how to [Write your first agent](../malmopy/README.md#write-your-first-agent).
14 |
15 |
--------------------------------------------------------------------------------
/ai_challenge/pig_chase/README.md:
--------------------------------------------------------------------------------
1 | # Malmo Collaborative AI Challenge - Pig Chase
2 |
3 | This repository contains Malmo Collaborative AI challenge task definition. The challenge task takes the form of a collaborative mini game, called Pig Chase.
4 |
5 | 
6 |
7 | ## Overview of the game
8 |
9 | Two Minecraft agents and a pig are wandering a small meadow. The agents have two choices:
10 |
11 | - _Catch the pig_ (i.e., the agents pinch or corner the pig, and no escape path is available), and receive a high reward (25 points)
12 | - _Give up_ and leave the pig pen through the exits to the left and right of the pen, marked by blue squares, and receive a small reward (5 points)
13 |
14 | The pig chase is inspired by a variant of the _stag hunt_ presented in [Yoshida et al. 2008]. The [stag hunt](https://en.wikipedia.org/wiki/Stag_hunt) is a classical game-theoretic formulation that captures conflicts between collaboration and individual safety.
15 |
16 | [Yoshida et al. 2008] Yoshida, Wako, Ray J. Dolan, and Karl J. Friston. ["Game theory of mind."](http://journals.plos.org/ploscompbiol/article?id=10.1371/journal.pcbi.1000254) PLoS Comput Biol 4.12 (2008): e1000254.
17 |
18 |
19 | ## How to play (human players)
20 |
21 | To familiarize yourself with the game, we recommend that you play it yourself. The following instructions allow you to play the game with a "focused agent" — a baseline agent that tries to move towards the pig whenever possible.
22 |
23 | ### Prerequisites
24 |
25 | * Install the [Malmo Platform](https://github.com/Microsoft/malmo) and the `malmopy` framework as described under [Installation](../../README.md#installation), and verify that you can run the Malmo platform and python example agents
26 |
27 | ### Steps
28 |
29 | * Start two instances of the Malmo Client on ports `10000` and `10001`
30 | * `cd malmo-challenge/ai_challenge/pig_chase`
31 | * `python pig_chase_human_vs_agent.py`
32 |
33 | Wait for a few seconds for the human player interface to appear.
34 |
35 | Note: the script assumes that two Malmo clients are running on the default ports on localhost. You can specify alternative clients on the command line. See the script's usage instructions (`python pig_chase_human_vs_agent.py -h`) for details.
36 |
37 | ### How to play
38 |
39 | * The game is played over 10 rounds at a time. Goal is to accumulate the highest score over these 10 rounds.
40 | * In each round a "collaborator" agent is selected to play with you. Different collaborators may have different behaviors.
41 | * Once the game has started, use the left/right arrow keys to turn, and the forward/backward keys to move. You can see your agent move in the first person view, and shown as a red arrow in the top-down rendering on the left.
42 | * You and your collaborator move in turns and try to catch the pig (25 points if caught). You can give up on catching the pig in the current round by moving to the blue "exit squares" (5 points). You have a maximum of 25 steps available, and will get -1 point for each step taken.
43 |
44 | ## Run your first experiment
45 |
46 | An example experiment is provided in `pig_chase_baseline.py`. To run it, start two instances of the Malmo Client as [above](#steps). Then run:
47 |
48 | ```
49 | python pig_chase_baseline.py
50 | ```
51 |
52 | Depending on whether `tensorboard` is available on your system, this script will output performance statistics to either tensorboard or to console. If using tensorboard, you can plot the stored data by pointing a tensorboard instance to the results folder:
53 |
54 | ```
55 | cd ai_challenge/pig_chase
56 | tensorboard --logdir=results --port=6006
57 | ```
58 |
59 | You can then navigate to http://127.0.0.1:6006 to view the results.
60 |
61 | The baseline script runs a `FocusedAgent` by default - it uses a simple planning algorithm to find a shortest path to the pig. You can also run a `RandomAgent` baseline. Switch agents using the command line arguments:
62 |
63 | ```
64 | python pig_chase_baseline.py -t random
65 | ```
66 |
67 | For additional command line options, see the usage instructions: `python pig_chase_baseline.py -h`.
68 |
69 | ## Evaluate your agent
70 |
71 | We provide a commodity evaluator PigChaseEvaluator, which allows you to quickly evaluate
72 | the performance of your agent.
73 |
74 | PigChaseEvaluator takes 2 arguments:
75 | - agent_100k : Your agent trained with 100k steps (100k train calls)
76 | - agent_500k : Your agent trained with 500k steps (500k train calls)
77 |
78 | To evaluate your agent:
79 |
80 | ``` python
81 | # Creates an agent trained with 100k train calls
82 | my_agent_100k = MyCustomAgent()
83 |
84 | # Creates an agent trained with 500k train calls
85 | my_agent_500k = MyCustomAgent()
86 |
87 | # You can pass a custom StateBuilder for your agent.
88 | # It will be used by the environment to generate state for your agent
89 | eval = PigChaseEvaluator(my_agent_100k, my_agent_500k, MyStateBuilder())
90 |
91 | # Run and save
92 | eval.run()
93 | eval.save('My experiment 1', 'path/to/save.json')
94 | ```
95 |
96 | ## Compare against other teams:
97 |
98 | Submit your evaluation results on the [leaderboard website](https://malmo-leaderboard.azurewebsites.net/) to compare your results against the other participants.
99 |
100 |
101 | ## Next steps
102 |
103 | To participate in the Collaborative AI Challenge, implement and train an agent that can effectively collaborate with any collaborator. Your agent can use either the first-person visual view, or the symbolic view (as demonstrated in the `FocusedAgent`). You can use any AI/learning approach you like - originality of the chosen approach is part of the criteria for the challenge prizes. Can you come up with an agent that learns to outperform the A-star baseline agent? Can an agent learn to play with a copy of itself? Can it outperform your own (human) score?
104 |
105 | For more inspiration, you can look at more [code samples](../../samples/README.md) or learn how to [run experiments on Azure using docker](../../docker/README.md).
106 |
107 |
108 |
109 |
--------------------------------------------------------------------------------
/ai_challenge/pig_chase/common.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2017 Microsoft Corporation.
2 | #
3 | # Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated
4 | # documentation files (the "Software"), to deal in the Software without restriction, including without limitation the
5 | # rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software,
6 | # and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
7 | #
8 | # The above copyright notice and this permission notice shall be included in all copies or substantial portions of
9 | # the Software.
10 | #
11 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO
12 | # THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
13 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
14 | # TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
15 | # SOFTWARE.
16 | # ===================================================================================================================
17 |
# --- Pig Chase environment constants ---

# Names of the two player agents and of the target entity (the pig).
ENV_AGENT_NAMES = ['Agent_1', 'Agent_2']
ENV_TARGET_NAMES = ['Pig']
# All entities in the environment, agents first then targets.
# NOTE(review): the name mixes plural/singular ('ENTITIES_NAME') but is kept
# as-is because other modules may import it under this exact name.
ENV_ENTITIES_NAME = ENV_AGENT_NAMES + ENV_TARGET_NAMES
# Discrete action strings understood by the Malmo mission.
ENV_ACTIONS = ["move 1", "turn -1", "turn 1"]
# Keys used in observation dictionaries.
ENV_ENTITIES = 'entities'
ENV_BOARD = 'board'
# The symbolic board view is a 9x9 grid.
ENV_BOARD_SHAPE = (9, 9)
# Reward for leaving through an exit square alone (5) vs. catching the
# pig collaboratively (25) -- see the Pig Chase README.
ENV_INDIVIDUAL_REWARD = 5
ENV_CAUGHT_REWARD = 25

class ENV_AGENT_TYPES:
    """Enumeration of the built-in agent behaviour types."""
    RANDOM, FOCUSED, TABQ, DEEPQ, HUMAN, OTHER = range(6)
29 |
def parse_clients_args(args_clients):
    """
    Split each 'ip:port' client string into its components.

    :param args_clients: iterable of 'ip:port' strings
    :return: list of [ip, port] pairs; both elements remain strings
             (the port is NOT converted to int)
    """
    # str(client) tolerates non-string inputs (e.g. values parsed by argparse).
    return [str(client).split(':') for client in args_clients]
37 |
def visualize_training(visualizer, step, rewards, tag='Training'):
    """Record per-episode reward statistics on *visualizer* under *tag*.

    Emits the episode's total, max and min reward, plus the number of
    actions taken (one fewer than the number of reward samples).
    """
    entries = (
        ('reward per episode', sum(rewards)),
        ('max.reward', max(rewards)),
        ('min.reward', min(rewards)),
        ('actions per episode', len(rewards) - 1),
    )
    for label, value in entries:
        visualizer.add_entry(step, '%s/%s' % (tag, label), value)
43 |
class Entity(object):
    """Wrap entity attributes (name, position and orientation).

    Coordinates and pitch are truncated to int on assignment; yaw is
    additionally normalised into the range [0, 360).
    """

    def __init__(self, x, y, z, yaw, pitch, name=''):
        self._name = name
        self._x = int(x)
        self._y = int(y)
        self._z = int(z)
        self._yaw = int(yaw) % 360
        self._pitch = int(pitch)

    @property
    def name(self):
        """Entity name (read-only)."""
        return self._name

    @property
    def x(self):
        return self._x

    @x.setter
    def x(self, value):
        self._x = int(value)

    @property
    def y(self):
        return self._y

    @y.setter
    def y(self, value):
        self._y = int(value)

    @property
    def z(self):
        return self._z

    @z.setter
    def z(self, value):
        self._z = int(value)

    @property
    def yaw(self):
        return self._yaw

    @yaw.setter
    def yaw(self, value):
        # Normalise so that e.g. -90 becomes 270 and 365 becomes 5.
        self._yaw = int(value) % 360

    @property
    def pitch(self):
        return self._pitch

    @pitch.setter
    def pitch(self, value):
        self._pitch = int(value)

    @property
    def position(self):
        """The (x, y, z) integer tuple of this entity."""
        return self._x, self._y, self._z

    def __eq__(self, other):
        # An entity compares equal to the (x, y, z) tuple of its position.
        if isinstance(other, tuple):
            return self.position == other
        # BUGFIX: the original fell off the end and implicitly returned None
        # for any other operand type; NotImplemented lets Python fall back to
        # its default comparison rather than yielding a falsy non-boolean.
        return NotImplemented

    def __getitem__(self, item):
        # Dict-style access to attributes, e.g. entity['x'].
        return getattr(self, item)

    @classmethod
    def create(cls, obj):
        """Build an Entity from a mapping with x/y/z/yaw/pitch keys.

        BUGFIX: the original discarded the entity name; it is now propagated
        when present ('' default preserves the old behaviour otherwise).
        """
        return cls(obj['x'], obj['y'], obj['z'], obj['yaw'], obj['pitch'],
                   obj.get('name', ''))
113 |
--------------------------------------------------------------------------------
/ai_challenge/pig_chase/evaluation.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2017 Microsoft Corporation.
2 | #
3 | # Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated
4 | # documentation files (the "Software"), to deal in the Software without restriction, including without limitation the
5 | # rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software,
6 | # and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
7 | #
8 | # The above copyright notice and this permission notice shall be included in all copies or substantial portions of
9 | # the Software.
10 | #
11 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO
12 | # THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
13 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
14 | # TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
15 | # SOFTWARE.
16 | # ===================================================================================================================
17 |
18 | import os
19 | import sys
20 | from time import sleep
21 |
22 | from common import parse_clients_args, ENV_AGENT_NAMES
23 | from agent import PigChaseChallengeAgent
24 | from common import ENV_AGENT_NAMES
25 | from environment import PigChaseEnvironment, PigChaseSymbolicStateBuilder
26 |
27 | # Enforce path
28 | sys.path.insert(0, os.getcwd())
29 | sys.path.insert(1, os.path.join(os.path.pardir, os.getcwd()))
30 |
31 |
class PigChaseEvaluator(object):
    """Evaluate two checkpoints of a submitted agent against the challenge
    agent, accumulating per-step rewards for the leaderboard."""

    def __init__(self, clients, agent_100k, agent_500k, state_builder):
        """
        :param clients: list of (ip, port) Minecraft client endpoints (>= 2)
        :param agent_100k: agent checkpoint after 100k training steps
        :param agent_500k: agent checkpoint after 500k training steps
        :param state_builder: state builder used by the evaluated agents
        """
        assert len(clients) >= 2, 'Not enough clients provided'

        self._clients = clients
        self._agent_100k = agent_100k
        self._agent_500k = agent_500k
        self._state_builder = state_builder
        # Per-checkpoint reward buffers, filled by run().
        self._accumulators = {'100k': [], '500k': []}

    def save(self, experiment_name, filepath):
        """
        Save the evaluation results in a JSON file
        understandable by the leaderboard.

        Note: The leaderboard will not accept a submission if you already
        uploaded a file with the same experiment name.

        :param experiment_name: An identifier for the experiment
        :param filepath: Path where to store the results file
        :return:
        """

        assert experiment_name is not None, 'experiment_name cannot be None'

        from json import dump
        from os.path import abspath, dirname, exists
        from os import makedirs
        from numpy import mean, var

        # Compute mean/variance/count per checkpoint. An empty buffer would
        # make numpy return NaN; report zeros instead. Cast to plain floats
        # so the JSON output never depends on numpy scalar types.
        metrics = {}
        for key, rewards in self._accumulators.items():
            metrics[key] = {'mean': float(mean(rewards)) if rewards else 0.0,
                            'var': float(var(rewards)) if rewards else 0.0,
                            'count': len(rewards)}

        metrics['experimentname'] = experiment_name

        try:
            filepath = abspath(filepath)
            # Bug fix: the parent directory is dirname(filepath). The
            # original joined os.pardir with the absolute file path, which
            # (on POSIX) yields the file path itself -- makedirs then created
            # a *directory* at the results path and the open() below failed.
            parent = dirname(filepath)
            if not exists(parent):
                makedirs(parent)

            with open(filepath, 'w') as f_out:
                dump(metrics, f_out)

            print('==================================')
            print('Evaluation done, results written at %s' % filepath)

        except Exception as e:
            # Best-effort: report the failure without killing the evaluation.
            print('Unable to save the results: %s' % e)

    def run(self):
        """Run both checkpoints, each against a fresh challenge agent."""
        from multiprocessing import Process

        env = PigChaseEnvironment(self._clients, self._state_builder,
                                  role=1, randomize_positions=True)
        for key, agent in (('100k', self._agent_100k),
                           ('500k', self._agent_500k)):
            print('==================================')
            print('Starting evaluation of Agent @%s' % key)

            # The challenge (opponent) agent runs in a separate process and
            # is terminated once this checkpoint's evaluation completes.
            p = Process(target=run_challenge_agent, args=(self._clients,))
            p.start()
            sleep(5)  # give the opponent time to connect to its client
            agent_loop(agent, env, self._accumulators[key])
            p.terminate()
107 |
108 |
def run_challenge_agent(clients):
    """Run the standard challenge (opponent) agent in role 0.

    Intended as a multiprocessing target: agent_loop runs its fixed number
    of episodes with no metrics accumulation.
    """
    state_builder = PigChaseSymbolicStateBuilder()
    challenge_env = PigChaseEnvironment(clients, state_builder, role=0,
                                        randomize_positions=True)
    challenge_agent = PigChaseChallengeAgent(ENV_AGENT_NAMES[0])
    agent_loop(challenge_agent, challenge_env, None)
115 |
116 |
def agent_loop(agent, env, metrics_acc):
    """Run `agent` in `env` for a fixed number of evaluation episodes.

    :param agent: object exposing act(obs, reward, done, is_training)
    :param env: object exposing reset(), do(action) and a `done` flag
    :param metrics_acc: list that collects every per-step reward, or None
        to discard rewards (used for the challenge agent)
    """
    EVAL_EPISODES = 100
    agent_done = False
    reward = 0
    episode = 0
    obs = env.reset()

    while episode < EVAL_EPISODES:
        # check if env needs reset
        if env.done:
            # Bug fix: use float math so the progress percentage is not
            # truncated to 0 by Python 2 integer division.
            print('Episode %d (%.2f)%%'
                  % (episode, (episode * 100.) / EVAL_EPISODES))

            obs = env.reset()
            while obs is None:
                # this can happen if the episode ended with the first
                # action of the other agent
                print('Warning: received obs == None.')
                obs = env.reset()

            episode += 1

        # select an action
        action = agent.act(obs, reward, agent_done, is_training=True)
        # take a step
        obs, reward, agent_done = env.do(action)

        if metrics_acc is not None:
            metrics_acc.append(reward)
145 |
--------------------------------------------------------------------------------
/ai_challenge/pig_chase/pig-chase-overview.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/malmo-challenge/72634708638fd6fe84521894891f054933717ecf/ai_challenge/pig_chase/pig-chase-overview.png
--------------------------------------------------------------------------------
/ai_challenge/pig_chase/pig_chase.xml:
--------------------------------------------------------------------------------
1 |
2 |
13 |
14 |
15 |
16 |
17 | Catch the pig!
18 |
19 |
20 |
21 | 4
22 |
23 |
24 |
25 |
26 |
30 | clear
31 | false
32 |
33 |
34 |
35 |
36 |
37 |
38 |
39 |
40 |
41 |
42 |
43 |
44 |
45 |
46 |
47 |
48 |
49 |
50 |
51 |
52 |
53 |
54 |
55 |
56 |
57 |
58 |
59 |
60 |
61 |
62 |
63 |
64 |
65 |
66 |
67 |
68 |
69 |
70 |
71 |
72 |
73 |
74 |
75 |
76 |
77 |
78 |
79 |
80 |
81 |
82 |
83 |
84 |
85 |
86 |
87 | Agent_1
88 |
89 |
90 |
91 |
92 |
93 |
94 |
95 |
96 |
97 |
98 |
99 |
100 |
101 |
102 |
103 |
104 |
105 |
106 |
107 |
108 |
109 | attack
110 |
111 |
112 |
113 |
114 |
115 |
116 |
117 |
118 |
119 |
120 |
121 | 640
122 | 480
123 |
124 |
125 |
126 |
127 |
128 | Agent_2
129 |
130 |
131 |
132 |
133 |
134 |
135 |
136 |
137 |
138 |
139 |
140 |
141 |
142 |
143 |
144 |
145 |
146 |
147 |
148 |
149 |
150 | attack
151 |
152 |
153 |
154 |
155 |
156 |
157 |
158 |
159 |
160 |
161 |
162 |
163 |
164 |
165 |
166 |
167 |
168 |
169 | 640
170 | 480
171 |
172 |
173 |
174 |
175 |
--------------------------------------------------------------------------------
/ai_challenge/pig_chase/pig_chase_baseline.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2017 Microsoft Corporation.
2 | #
3 | # Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated
4 | # documentation files (the "Software"), to deal in the Software without restriction, including without limitation the
5 | # rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software,
6 | # and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
7 | #
8 | # The above copyright notice and this permission notice shall be included in all copies or substantial portions of
9 | # the Software.
10 | #
11 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO
12 | # THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
13 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
14 | # TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
15 | # SOFTWARE.
16 | # ===================================================================================================================
17 |
18 | import numpy as np
19 | import os
20 | import sys
21 |
22 | from argparse import ArgumentParser
23 | from datetime import datetime
24 |
25 | import six
26 | from os import path
27 | from threading import Thread, active_count
28 | from time import sleep
29 | from builtins import range
30 |
31 | from malmopy.agent import RandomAgent
32 | try:
33 | from malmopy.visualization.tensorboard import TensorboardVisualizer
34 | from malmopy.visualization.tensorboard.cntk import CntkConverter
35 | except ImportError:
36 | print('Cannot import tensorboard, using ConsoleVisualizer.')
37 | from malmopy.visualization import ConsoleVisualizer
38 |
39 | from common import parse_clients_args, visualize_training, ENV_AGENT_NAMES, ENV_TARGET_NAMES
40 | from agent import PigChaseChallengeAgent, FocusedAgent, TabularQLearnerAgent, get_agent_type
41 | from environment import PigChaseEnvironment, PigChaseSymbolicStateBuilder
42 |
43 | # Enforce path
44 | sys.path.insert(0, os.getcwd())
45 | sys.path.insert(1, os.path.join(os.path.pardir, os.getcwd()))
46 |
47 | BASELINES_FOLDER = 'results/baselines/pig_chase/%s/%s'
48 | EPOCH_SIZE = 100
49 |
50 |
def agent_factory(name, role, baseline_agent, clients, max_epochs,
                  logdir, visualizer):
    """Thread entry point: build the agent for `role` and run it.

    Role 0 runs the challenge (opponent) agent forever; any other role runs
    the selected baseline agent for EPOCH_SIZE * max_epochs steps and logs
    training metrics to `visualizer`.

    :param name: agent name (one of ENV_AGENT_NAMES)
    :param role: 0 for the challenge agent, otherwise the baseline agent
    :param baseline_agent: 'tabq', 'astar', or anything else for random
    :param clients: list of 'ip:port' Minecraft client endpoint strings
    :param max_epochs: number of EPOCH_SIZE-step epochs to run
    :param logdir: log directory (not used inside this function)
    :param visualizer: metrics sink exposing add_entry(step, tag, value)
    """
    assert len(clients) >= 2, 'Not enough clients (need at least 2)'
    clients = parse_clients_args(clients)

    builder = PigChaseSymbolicStateBuilder()
    env = PigChaseEnvironment(clients, builder, role=role,
                              randomize_positions=True)

    if role == 0:
        # Challenge agent: loops forever; the caller daemonizes this thread.
        agent = PigChaseChallengeAgent(name)
        obs = env.reset(get_agent_type(agent))

        reward = 0
        agent_done = False

        while True:
            if env.done:
                # Keep resetting until a valid (non-falsy) observation arrives.
                while True:
                    obs = env.reset(get_agent_type(agent))
                    if obs:
                        break

            # select an action
            action = agent.act(obs, reward, agent_done, is_training=True)

            # reset if needed
            # NOTE(review): the episode may end while act() runs, so check
            # again before issuing the action -- confirm against environment.py.
            if env.done:
                obs = env.reset(get_agent_type(agent))

            # take a step
            obs, reward, agent_done = env.do(action)


    else:

        # Baseline agent under evaluation, selected by `baseline_agent`.
        if baseline_agent == 'tabq':
            agent = TabularQLearnerAgent(name, visualizer)
        elif baseline_agent == 'astar':
            agent = FocusedAgent(name, ENV_TARGET_NAMES[0])
        else:
            agent = RandomAgent(name, env.available_actions)

        obs = env.reset()
        reward = 0
        agent_done = False
        viz_rewards = []  # per-episode reward buffer for visualize_training

        max_training_steps = EPOCH_SIZE * max_epochs
        for step in six.moves.range(1, max_training_steps+1):

            # check if env needs reset
            if env.done:
                while True:
                    # Pad with an explicit 0 so visualize_training's max/min
                    # never see an empty list.
                    if len(viz_rewards) == 0:
                        viz_rewards.append(0)
                    visualize_training(visualizer, step, viz_rewards)
                    # Log which condition ended the episode (booleans, one hot).
                    tag = "Episode End Conditions"
                    visualizer.add_entry(step, '%s/timeouts per episode' % tag, env.end_result == "command_quota_reached")
                    visualizer.add_entry(step, '%s/agent_1 defaults per episode' % tag, env.end_result == "Agent_1_defaulted")
                    visualizer.add_entry(step, '%s/agent_2 defaults per episode' % tag, env.end_result == "Agent_2_defaulted")
                    visualizer.add_entry(step, '%s/pig caught per episode' % tag, env.end_result == "caught_the_pig")
                    agent.inject_summaries(step)
                    viz_rewards = []
                    obs = env.reset()
                    if obs:
                        break

            # select an action
            action = agent.act(obs, reward, agent_done, is_training=True)
            # take a step
            obs, reward, agent_done = env.do(action)
            viz_rewards.append(reward)

        #agent.inject_summaries(step)
127 |
128 |
def run_experiment(agents_def):
    """Launch both agents, each in its own daemon thread, and block until
    the experiment winds down or the user interrupts it."""
    assert len(agents_def) == 2, 'Not enough agents (required: 2, got: %d)'\
                                 % len(agents_def)

    threads = []
    for agent_def in agents_def:
        worker = Thread(target=agent_factory, kwargs=agent_def)
        worker.daemon = True
        worker.start()

        # Let the server-side (role 0) agent come up before its opponent.
        if agent_def['role'] == 0:
            sleep(1)

        threads.append(worker)

    try:
        # Main thread + the never-ending challenge-agent thread remain once
        # the baseline agent finishes.
        while active_count() > 2:
            sleep(0.1)
    except KeyboardInterrupt:
        print('Caught control-c - shutting down.')
151 |
152 |
if __name__ == '__main__':
    parser = ArgumentParser('Pig Chase baseline experiment')
    parser.add_argument('-t', '--type', type=str, default='astar',
                        choices=['tabq', 'astar', 'random'],
                        help='The type of baseline to run.')
    parser.add_argument('-e', '--epochs', type=int, default=5,
                        help='Number of epochs to run.')
    parser.add_argument('clients', nargs='*',
                        default=['127.0.0.1:10000', '127.0.0.1:10001'],
                        help='Minecraft clients endpoints (ip(:port)?)+')
    args = parser.parse_args()

    logdir = BASELINES_FOLDER % (args.type, datetime.utcnow().isoformat())
    # Prefer Tensorboard when its optional import at the top succeeded.
    if 'malmopy.visualization.tensorboard' in sys.modules:
        visualizer = TensorboardVisualizer()
        visualizer.initialize(logdir, None)
    else:
        visualizer = ConsoleVisualizer()

    agents = []
    for role, agent_name in enumerate(ENV_AGENT_NAMES):
        agents.append({'name': agent_name, 'role': role,
                       'baseline_agent': args.type, 'clients': args.clients,
                       'max_epochs': args.epochs, 'logdir': logdir,
                       'visualizer': visualizer})

    run_experiment(agents)
178 |
179 |
--------------------------------------------------------------------------------
/ai_challenge/pig_chase/pig_chase_dqn.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2017 Microsoft Corporation.
2 | #
3 | # Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated
4 | # documentation files (the "Software"), to deal in the Software without restriction, including without limitation the
5 | # rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software,
6 | # and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
7 | #
8 | # The above copyright notice and this permission notice shall be included in all copies or substantial portions of
9 | # the Software.
10 | #
11 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO
12 | # THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
13 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
14 | # TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
15 | # SOFTWARE.
16 | # ===================================================================================================================
17 |
18 | import os
19 | import sys
20 | from argparse import ArgumentParser
21 | from datetime import datetime
22 |
23 | import six
24 | from os import path
25 | from threading import Thread, active_count
26 | from time import sleep
27 |
28 | from malmopy.agent import LinearEpsilonGreedyExplorer
29 |
30 | from common import parse_clients_args, visualize_training, ENV_AGENT_NAMES
31 | from agent import PigChaseChallengeAgent, PigChaseQLearnerAgent
32 | from environment import PigChaseEnvironment, PigChaseSymbolicStateBuilder
33 |
34 | from malmopy.environment.malmo import MalmoALEStateBuilder
35 | from malmopy.agent import TemporalMemory, RandomAgent
36 |
37 | try:
38 | from malmopy.visualization.tensorboard import TensorboardVisualizer
39 | from malmopy.visualization.tensorboard.cntk import CntkConverter
40 | except ImportError:
41 | print('Cannot import tensorboard, using ConsoleVisualizer.')
42 | from malmopy.visualization import ConsoleVisualizer
43 |
44 | # Enforce path
45 | sys.path.insert(0, os.getcwd())
46 | sys.path.insert(1, os.path.join(os.path.pardir, os.getcwd()))
47 |
48 | DQN_FOLDER = 'results/baselines/%s/dqn/%s-%s'
49 | EPOCH_SIZE = 100000
50 |
51 |
def agent_factory(name, role, clients, backend,
                  device, max_epochs, logdir, visualizer):
    """Thread entry point: run the challenge agent (role 0) or train a DQN.

    :param name: agent name (one of ENV_AGENT_NAMES)
    :param role: 0 for the challenge agent, otherwise the DQN learner
    :param clients: list of 'ip:port' Minecraft client endpoint strings
    :param backend: 'cntk' or 'chainer' -- selects the Q-network backend
    :param device: device id passed to the model (-1 presumably CPU --
        confirm against the backend's QNeuralNetwork)
    :param max_epochs: number of EPOCH_SIZE-step epochs to train
    :param logdir: log directory (not used inside this function)
    :param visualizer: metrics sink for training statistics
    """
    assert len(clients) >= 2, 'Not enough clients (need at least 2)'
    clients = parse_clients_args(clients)

    if role == 0:

        builder = PigChaseSymbolicStateBuilder()
        env = PigChaseEnvironment(clients, builder, role=role,
                                  randomize_positions=True)
        agent = PigChaseChallengeAgent(name)
        # The challenge agent alternates behaviors; report the matching
        # type to the environment on every reset.
        if type(agent.current_agent) == RandomAgent:
            agent_type = PigChaseEnvironment.AGENT_TYPE_1
        else:
            agent_type = PigChaseEnvironment.AGENT_TYPE_2

        obs = env.reset(agent_type)
        reward = 0
        agent_done = False

        while True:
            if env.done:
                # Recompute the type: current_agent may have switched.
                if type(agent.current_agent) == RandomAgent:
                    agent_type = PigChaseEnvironment.AGENT_TYPE_1
                else:
                    agent_type = PigChaseEnvironment.AGENT_TYPE_2

                obs = env.reset(agent_type)
                while obs is None:
                    # this can happen if the episode ended with the first
                    # action of the other agent
                    print('Warning: received obs == None.')
                    obs = env.reset(agent_type)

            # select an action
            action = agent.act(obs, reward, agent_done, is_training=True)
            # take a step
            obs, reward, agent_done = env.do(action)

    else:
        # DQN learner trained on raw ALE-style frames (84x84).
        env = PigChaseEnvironment(clients, MalmoALEStateBuilder(),
                                  role=role, randomize_positions=True)
        memory = TemporalMemory(100000, (84, 84))

        if backend == 'cntk':
            from malmopy.model.cntk import QNeuralNetwork
            model = QNeuralNetwork((memory.history_length, 84, 84), env.available_actions, device)
        else:
            # Chainer needs separate online and target networks.
            from malmopy.model.chainer import QNeuralNetwork, DQNChain
            chain = DQNChain((memory.history_length, 84, 84), env.available_actions)
            target_chain = DQNChain((memory.history_length, 84, 84), env.available_actions)
            model = QNeuralNetwork(chain, target_chain, device)

        # Epsilon annealed linearly from 1 to 0.1 over 1M steps.
        explorer = LinearEpsilonGreedyExplorer(1, 0.1, 1000000)
        agent = PigChaseQLearnerAgent(name, env.available_actions,
                                      model, memory, 0.99, 32, 50000,
                                      explorer=explorer, visualizer=visualizer)

        obs = env.reset()
        reward = 0
        agent_done = False
        viz_rewards = []  # per-episode reward buffer for visualize_training

        max_training_steps = EPOCH_SIZE * max_epochs
        for step in six.moves.range(1, max_training_steps+1):

            # check if env needs reset
            if env.done:

                visualize_training(visualizer, step, viz_rewards)
                agent.inject_summaries(step)
                viz_rewards = []

                obs = env.reset()
                while obs is None:
                    # this can happen if the episode ended with the first
                    # action of the other agent
                    print('Warning: received obs == None.')
                    obs = env.reset()

            # select an action
            action = agent.act(obs, reward, agent_done, is_training=True)
            # take a step
            obs, reward, agent_done = env.do(action)
            viz_rewards.append(reward)

            # Checkpoint the model once per epoch. NOTE(review): under
            # Python 3, step / EPOCH_SIZE is a float; %d truncates it, so
            # the filename is still an integer epoch index.
            if (step % EPOCH_SIZE) == 0:
                if 'model' in locals():
                    model.save('pig_chase-dqn_%d.model' % (step / EPOCH_SIZE))
142 |
143 |
def run_experiment(agents_def):
    """Start one daemon thread per agent definition and block until the
    experiment finishes or the user hits Ctrl-C."""
    assert len(agents_def) == 2, 'Not enough agents (required: 2, got: %d)' \
                                 % len(agents_def)

    workers = []
    for agent_kwargs in agents_def:
        t = Thread(target=agent_factory, kwargs=agent_kwargs)
        t.daemon = True
        t.start()

        # The challenge agent hosts the mission; give it a head start.
        if agent_kwargs['role'] == 0:
            sleep(1)

        workers.append(t)

    try:
        # Only the main thread and the endless challenge-agent thread
        # should remain once the learner is done.
        while active_count() > 2:
            sleep(0.1)
    except KeyboardInterrupt:
        print('Caught control-c - shutting down.')
166 |
167 |
if __name__ == '__main__':
    parser = ArgumentParser('Pig Chase DQN experiment')
    parser.add_argument('-b', '--backend', type=str, choices=['cntk', 'chainer'],
                        default='cntk', help='Neural network backend')
    parser.add_argument('-e', '--epochs', type=int, default=5,
                        help='Number of epochs to run.')
    parser.add_argument('clients', nargs='*',
                        default=['127.0.0.1:10000', '127.0.0.1:10001'],
                        help='Minecraft clients endpoints (ip(:port)?)+')
    parser.add_argument('-d', '--device', type=int, default=-1,
                        help='GPU device on which to run the experiment.')
    args = parser.parse_args()

    logdir = path.join('results/pig_chase/dqn', datetime.utcnow().isoformat())
    # Prefer Tensorboard when its optional import at the top succeeded.
    if 'malmopy.visualization.tensorboard' in sys.modules:
        visualizer = TensorboardVisualizer()
        visualizer.initialize(logdir, None)
    else:
        visualizer = ConsoleVisualizer()

    agents = []
    for role, agent_name in enumerate(ENV_AGENT_NAMES):
        agents.append({'name': agent_name, 'role': role,
                       'clients': args.clients, 'backend': args.backend,
                       'device': args.device, 'max_epochs': args.epochs,
                       'logdir': logdir, 'visualizer': visualizer})

    run_experiment(agents)
195 |
--------------------------------------------------------------------------------
/ai_challenge/pig_chase/pig_chase_dqn_top_down.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2017 Microsoft Corporation.
2 | #
3 | # Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated
4 | # documentation files (the "Software"), to deal in the Software without restriction, including without limitation the
5 | # rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software,
6 | # and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
7 | #
8 | # The above copyright notice and this permission notice shall be included in all copies or substantial portions of
9 | # the Software.
10 | #
11 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO
12 | # THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
13 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
14 | # TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
15 | # SOFTWARE.
16 | # ===================================================================================================================
17 |
18 | import os
19 | import sys
20 | from argparse import ArgumentParser
21 | from datetime import datetime
22 |
23 | import six
24 | from os import path
25 | from threading import Thread, active_count
26 | from time import sleep
27 |
28 | from malmopy.agent import LinearEpsilonGreedyExplorer, RandomAgent
29 |
30 | from common import parse_clients_args, visualize_training, ENV_AGENT_NAMES
31 | from agent import PigChaseChallengeAgent, PigChaseQLearnerAgent
32 | from environment import PigChaseEnvironment, PigChaseSymbolicStateBuilder, \
33 | PigChaseTopDownStateBuilder
34 |
35 | from malmopy.agent import TemporalMemory
36 |
37 | try:
38 | from malmopy.visualization.tensorboard import TensorboardVisualizer
39 | from malmopy.visualization.tensorboard.cntk import CntkConverter
40 | except ImportError:
41 | print('Cannot import tensorboard, using ConsoleVisualizer.')
42 | from malmopy.visualization import ConsoleVisualizer
43 |
44 | # Enforce path
45 | sys.path.insert(0, os.getcwd())
46 | sys.path.insert(1, os.path.join(os.path.pardir, os.getcwd()))
47 |
48 | DQN_FOLDER = 'results/baselines/%s/dqn/%s-%s'
49 | EPOCH_SIZE = 100000
50 |
51 |
def agent_factory(name, role, clients, backend, device, max_epochs, logdir, visualizer):
    """Thread entry point: run the challenge agent (role 0) or train a DQN
    on the top-down (18x18) state representation.

    :param name: agent name (one of ENV_AGENT_NAMES)
    :param role: 0 for the challenge agent, otherwise the DQN learner
    :param clients: list of 'ip:port' Minecraft client endpoint strings
    :param backend: 'cntk' or 'chainer' -- selects the Q-network backend
    :param device: device id passed to the model (-1 presumably CPU --
        confirm against the backend's QNeuralNetwork)
    :param max_epochs: number of EPOCH_SIZE-step epochs to train
    :param logdir: log directory (not used inside this function)
    :param visualizer: metrics sink for training statistics
    """
    assert len(clients) >= 2, 'Not enough clients (need at least 2)'
    clients = parse_clients_args(clients)

    if role == 0:

        builder = PigChaseSymbolicStateBuilder()
        env = PigChaseEnvironment(clients, builder, role=role,
                                  randomize_positions=True)

        agent = PigChaseChallengeAgent(name)
        # The challenge agent alternates behaviors; report the matching
        # type to the environment on every reset.
        if type(agent.current_agent) == RandomAgent:
            agent_type = PigChaseEnvironment.AGENT_TYPE_1
        else:
            agent_type = PigChaseEnvironment.AGENT_TYPE_2

        obs = env.reset(agent_type)
        reward = 0
        agent_done = False

        while True:
            if env.done:
                # Recompute the type: current_agent may have switched.
                if type(agent.current_agent) == RandomAgent:
                    agent_type = PigChaseEnvironment.AGENT_TYPE_1
                else:
                    agent_type = PigChaseEnvironment.AGENT_TYPE_2

                obs = env.reset(agent_type)
                while obs is None:
                    # this can happen if the episode ended with the first
                    # action of the other agent
                    print('Warning: received obs == None.')
                    obs = env.reset(agent_type)

            # select an action
            action = agent.act(obs, reward, agent_done, is_training=True)
            # take a step
            obs, reward, agent_done = env.do(action)

    else:
        # DQN learner on the compact top-down view (18x18 grid).
        env = PigChaseEnvironment(clients, PigChaseTopDownStateBuilder(True),
                                  role=role, randomize_positions=True)
        memory = TemporalMemory(100000, (18, 18))

        if backend == 'cntk':
            from malmopy.model.cntk import QNeuralNetwork
            model = QNeuralNetwork((memory.history_length, 18, 18), env.available_actions, device)
        else:
            # Chainer needs separate online and target networks.
            from malmopy.model.chainer import QNeuralNetwork, ReducedDQNChain
            chain = ReducedDQNChain((memory.history_length, 18, 18), env.available_actions)
            target_chain = ReducedDQNChain((memory.history_length, 18, 18), env.available_actions)
            model = QNeuralNetwork(chain, target_chain, device)

        # Epsilon annealed linearly from 1 to 0.1 over 1M steps.
        explorer = LinearEpsilonGreedyExplorer(1, 0.1, 1000000)
        agent = PigChaseQLearnerAgent(name, env.available_actions,
                                      model, memory, 0.99, 32, 50000,
                                      explorer=explorer, visualizer=visualizer)

        obs = env.reset()
        reward = 0
        agent_done = False
        viz_rewards = []  # per-episode reward buffer for visualize_training

        max_training_steps = EPOCH_SIZE * max_epochs
        for step in six.moves.range(1, max_training_steps+1):

            # check if env needs reset
            if env.done:

                visualize_training(visualizer, step, viz_rewards)
                agent.inject_summaries(step)
                viz_rewards = []

                obs = env.reset()
                while obs is None:
                    # this can happen if the episode ended with the first
                    # action of the other agent
                    print('Warning: received obs == None.')
                    obs = env.reset()

            # select an action
            action = agent.act(obs, reward, agent_done, is_training=True)
            # take a step
            obs, reward, agent_done = env.do(action)
            viz_rewards.append(reward)

            # Checkpoint the model once per epoch. NOTE(review): under
            # Python 3, step / EPOCH_SIZE is a float; %d truncates it, so
            # the filename is still an integer epoch index.
            if (step % EPOCH_SIZE) == 0:
                if 'model' in locals():
                    model.save('pig_chase-dqn_%d.model' % (step / EPOCH_SIZE))
142 |
143 |
def run_experiment(agents_def):
    """Spawn a daemon thread for each of the two agent definitions, then
    wait until the experiment is over (or Ctrl-C is pressed)."""
    assert len(agents_def) == 2, 'Not enough agents (required: 2, got: %d)' \
                                 % len(agents_def)

    spawned = []
    for definition in agents_def:
        thread = Thread(target=agent_factory, kwargs=definition)
        thread.daemon = True
        thread.start()

        # Give the server-side (role 0) agent time to start first.
        if definition['role'] == 0:
            sleep(1)

        spawned.append(thread)

    try:
        # Wait until only the main thread and the challenge-agent thread
        # are still alive.
        while active_count() > 2:
            sleep(0.1)
    except KeyboardInterrupt:
        print('Caught control-c - shutting down.')
166 |
167 |
if __name__ == '__main__':
    parser = ArgumentParser('Pig Chase DQN experiment')
    parser.add_argument('-b', '--backend', type=str, choices=['cntk', 'chainer'],
                        default='cntk', help='Neural network backend')
    parser.add_argument('-e', '--epochs', type=int, default=5,
                        help='Number of epochs to run.')
    parser.add_argument('clients', nargs='*',
                        default=['127.0.0.1:10000', '127.0.0.1:10001'],
                        help='Minecraft clients endpoints (ip(:port)?)+')
    parser.add_argument('-d', '--device', type=int, default=-1,
                        help='GPU device on which to run the experiment.')
    args = parser.parse_args()

    logdir = path.join('results/pig_chase/dqn', datetime.utcnow().isoformat())
    # Prefer Tensorboard when its optional import at the top succeeded.
    if 'malmopy.visualization.tensorboard' in sys.modules:
        visualizer = TensorboardVisualizer()
        visualizer.initialize(logdir, None)
    else:
        visualizer = ConsoleVisualizer()

    agents = []
    for role, agent_name in enumerate(ENV_AGENT_NAMES):
        agents.append({'name': agent_name, 'role': role,
                       'clients': args.clients, 'backend': args.backend,
                       'device': args.device, 'max_epochs': args.epochs,
                       'logdir': logdir, 'visualizer': visualizer})

    run_experiment(agents)
195 |
--------------------------------------------------------------------------------
/ai_challenge/pig_chase/pig_chase_eval_sample.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2017 Microsoft Corporation.
2 | #
3 | # Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated
4 | # documentation files (the "Software"), to deal in the Software without restriction, including without limitation the
5 | # rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software,
6 | # and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
7 | #
8 | # The above copyright notice and this permission notice shall be included in all copies or substantial portions of
9 | # the Software.
10 | #
11 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO
12 | # THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
13 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
14 | # TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
15 | # SOFTWARE.
16 | # ===================================================================================================================
17 |
18 | from common import ENV_AGENT_NAMES
19 | from evaluation import PigChaseEvaluator
20 | from environment import PigChaseTopDownStateBuilder
21 | from malmopy.agent import RandomAgent
22 |
23 |
if __name__ == '__main__':
    # NOTE: the evaluated agent must use the second agent name
    # (ENV_AGENT_NAMES[1]); the challenge agent takes the first one.

    clients = [('127.0.0.1', 10000), ('127.0.0.1', 10001)]
    agent = RandomAgent(ENV_AGENT_NAMES[1], 3)

    # Renamed from `eval`, which shadowed the Python builtin of that name.
    evaluator = PigChaseEvaluator(clients, agent, agent,
                                  PigChaseTopDownStateBuilder())
    evaluator.run()

    evaluator.save('My Exp 1', 'pig_chase_results.json')
34 |
--------------------------------------------------------------------------------
/ai_challenge/pig_chase/pig_chase_human_vs_agent.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2017 Microsoft Corporation.
2 | #
3 | # Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated
4 | # documentation files (the "Software"), to deal in the Software without restriction, including without limitation the
5 | # rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software,
6 | # and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
7 | #
8 | # The above copyright notice and this permission notice shall be included in all copies or substantial portions of
9 | # the Software.
10 | #
11 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO
12 | # THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
13 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
14 | # TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
15 | # SOFTWARE.
16 | # ===================================================================================================================
17 |
import os
import sys
from argparse import ArgumentParser
from datetime import datetime
from multiprocessing import Process, Event
from os import path
from time import sleep

from malmopy.agent import RandomAgent
from malmopy.agent.gui import ARROW_KEYS_MAPPING
from malmopy.visualization import ConsoleVisualizer

# Enforce path
sys.path.insert(0, os.getcwd())
sys.path.insert(1, os.path.join(os.path.pardir, os.getcwd()))

from common import parse_clients_args, ENV_AGENT_NAMES, ENV_ACTIONS, ENV_AGENT_TYPES, ENV_TARGET_NAMES
from agent import PigChaseChallengeAgent, PigChaseHumanAgent, TabularQLearnerAgent, FocusedAgent, get_agent_type
from environment import PigChaseEnvironment, PigChaseSymbolicStateBuilder
37 |
38 | MAX_ACTIONS = 25 # this should match the mission definition, used for display only
39 |
def agent_factory(name, role, kind, clients, max_episodes, max_actions, logdir, quit, model_file):
    """Build and run one side of the human-vs-agent match.

    Role 0 runs the computer-controlled agent in an endless act/step loop;
    any other role attaches the human (GUI) agent, which blocks in show().

    :param name: agent name used inside the Malmo mission.
    :param role: 0 for the computer agent, otherwise the human player.
    :param kind: one of 'challenge', 'astar', 'tabq', 'random'.
    :param clients: raw Malmo client endpoints ('ip:port' strings).
    :param max_episodes: episode budget (used by the human agent display).
    :param max_actions: action budget (used by the human agent display).
    :param logdir: results directory (currently unused in this function).
    :param quit: multiprocessing.Event signalling shutdown (parameter name
        shadows the builtin but is part of the kwargs interface — keep it).
    :param model_file: optional model to load when kind == 'tabq'.
    """
    assert len(clients) >= 2, 'There are not enough Malmo clients in the pool (need at least 2)'

    clients = parse_clients_args(clients)
    visualizer = ConsoleVisualizer(prefix='Agent %d' % role)

    if role == 0:
        env = PigChaseEnvironment(clients, PigChaseSymbolicStateBuilder(),
                                  actions=ENV_ACTIONS, role=role,
                                  human_speed=True, randomize_positions=True)
        if kind == 'challenge':
            agent = PigChaseChallengeAgent(name)
        elif kind == 'astar':
            # NOTE(review): requires ENV_TARGET_NAMES to be imported from
            # common (missing from the module's original import list).
            agent = FocusedAgent(name, ENV_TARGET_NAMES[0])
        elif kind == 'tabq':
            agent = TabularQLearnerAgent(name)
            if model_file != '':
                agent.load(model_file)
        else:
            agent = RandomAgent(name, env.available_actions)

        obs = env.reset(get_agent_type(agent))
        reward = 0
        rewards = []
        done = False
        episode = 0

        while True:

            # select an action
            action = agent.act(obs, reward, done, True)

            if done:
                # episode finished: report the score and start a new episode
                visualizer << (episode + 1, 'Reward', sum(rewards))
                rewards = []
                episode += 1

                obs = env.reset(get_agent_type(agent))

            # take a step
            obs, reward, done = env.do(action)
            rewards.append(reward)

    else:
        # Human player: drive the environment through the GUI arrow keys.
        env = PigChaseEnvironment(clients, PigChaseSymbolicStateBuilder(),
                                  actions=list(ARROW_KEYS_MAPPING.values()),
                                  role=role, randomize_positions=True)
        env.reset(ENV_AGENT_TYPES.HUMAN)

        agent = PigChaseHumanAgent(name, env, list(ARROW_KEYS_MAPPING.keys()),
                                   max_episodes, max_actions, visualizer, quit)
        agent.show()
92 |
93 |
def run_mission(agents_def):
    """Spawn one daemon process per agent definition and wait for shutdown.

    A shared Event is injected into each definition under the 'quit' key
    (agent_factory's parameter name) so either side can end the run; once
    it is set, all child processes are terminated.

    :param agents_def: exactly two kwargs dicts for agent_factory.
    """
    assert len(agents_def) == 2, 'Incompatible number of agents (required: 2, got: %d)' % len(agents_def)
    quit_event = Event()  # local renamed from 'quit' to avoid shadowing the builtin
    processes = []
    for agent in agents_def:
        agent['quit'] = quit_event
        p = Process(target=agent_factory, kwargs=agent)
        p.daemon = True
        p.start()

        if agent['role'] == 0:
            sleep(1)  # Just to let time for the server to start

        processes.append(p)
    quit_event.wait()
    for process in processes:
        process.terminate()
111 |
112 |
if __name__ == '__main__':
    # Command-line options for the human-vs-agent match.
    parser = ArgumentParser()
    parser.add_argument('-e', '--episodes', type=int, default=10,
                        help='Number of episodes to run.')
    parser.add_argument('-k', '--kind', type=str, default='challenge',
                        choices=['astar', 'random', 'tabq', 'challenge'],
                        help='The kind of agent to play with (random, astar, tabq or challenge).')
    parser.add_argument('-m', '--model_file', type=str, default='',
                        help='Model file with which to initialise agent, if appropriate')
    parser.add_argument('clients', nargs='*',
                        default=['127.0.0.1:10000', '127.0.0.1:10001'],
                        help='Malmo clients (ip(:port)?)+')
    args = parser.parse_args()

    logdir = path.join('results/pig-human', datetime.utcnow().isoformat())

    # One kwargs dict per agent role; shared settings come from the CLI.
    agents = []
    for role, agent_name in enumerate(ENV_AGENT_NAMES):
        agents.append({'name': agent_name, 'role': role, 'kind': args.kind,
                       'model_file': args.model_file, 'clients': args.clients,
                       'max_episodes': args.episodes,
                       'max_actions': MAX_ACTIONS, 'logdir': logdir})

    run_mission(agents)
131 |
--------------------------------------------------------------------------------
/docker/README.md:
--------------------------------------------------------------------------------
1 |
2 |
3 | # Run experiments in docker
4 |
5 | [Docker](https://www.docker.com/) is a container solution that makes it easy to build and deploy
6 | software in a virtual environment. The examples in this folder use docker to easily deploy an experiment
7 | with all its dependencies, either on a local machine or on the cloud.
8 |
9 | ## Prerequisites
10 |
11 | Install docker on your local machine by following the installation instructions for
12 | [Windows](https://docs.docker.com/docker-for-windows/install/),
13 | [Linux](https://docs.docker.com/engine/installation/),
14 | [MacOS](https://docs.docker.com/docker-for-mac/install/).
15 |
16 | Prepare a docker machine on Azure, follow the local installation steps above, then run:
17 | ```
docker-machine create --driver azure --azure-size Standard_D12 --azure-subscription-id <subscription-id> <machine-name>
```
Replace `<subscription-id>` with your Azure subscription id - you can find this on the Azure dashboard after
logging on to https://portal.azure.com. The `<machine-name>` is arbitrary.
22 |
23 | Additional `docker-machine` options are listed here: https://docs.docker.com/machine/drivers/azure/
24 | Azure machine sizes are detailed on: https://docs.microsoft.com/en-us/azure/virtual-machines/virtual-machines-linux-sizes (we recommend to use at least size Standard_D12)
25 |
Configure docker to deploy to `<machine-name>`. Run:
```
docker-machine env <machine-name>
```
This will provide a script / instructions on how to prepare your environment to work with `<machine-name>`.
31 |
32 | ## Build the docker images
33 |
34 | Build the required docker images:
35 | ```
36 | cd docker
37 | docker build malmo -t malmo:latest
38 | docker build malmopy-cntk-cpu-py27 -t malmopy-cntk-cpu-py27:latest
39 |
40 | ```
41 |
42 | Check to make sure that the images have been compiled:
43 | ```
44 | docker images
45 | ```
46 | You should see a list that includes the compiled images, e.g.,
47 | ```
48 | REPOSITORY TAG IMAGE ID CREATED SIZE
49 | malmopy-cntk-cpu-py27 latest 0161af81632d 29 minutes ago 5.62 GB
50 | malmo latest 1b67b8e2cfa8 41 minutes ago 1.04 GB
51 | ...
52 | ```
53 |
54 | ## Run the experiment
55 |
56 | Run the challenge task with an example agent:
57 | ```
58 | cd malmopy-ai-challenge
59 | docker-compose up
60 | ```
61 |
62 | The experiment is set up to start a tensorboard process alongside the experiment.
63 | You can view it by pointing your browser to http://127.0.0.1:6006.
64 |
65 | ## Write your own
66 |
67 | The provided docker files load malmopy and sample code directly from the
68 | `malmo-challenge` git repository. To include your own code, create a file
69 | called `Dockerfile` with the following content:
70 |
71 | ```
72 | FROM malmopy-cntk-cpu-py27:latest
73 |
74 | # add your own experiment code here
75 | # ADD copies content from your local machine into the docker image
76 | ADD ai_challenge/pig_chase /local/malmo-challenge/ai_challenge/pig_chase
77 | ```
78 |
79 | Build this new image using:
80 | ```
81 | docker build . -t my_malmo_experiment:latest
82 | ```
83 |
Point the `agents` service in `docker-compose.yml` to the new image by replacing
85 | `image: malmopy-cntk-cpu-py27:latest` with the name of the image you have just
86 | built (e.g., `image:my_malmo_experiment:latest`). Also check if the working
87 | directory or command need to be changed.
88 |
89 | Then run the new experiment:
90 | ```
91 | docker-compose up
92 | ```
93 |
94 | ## Cleaning up
95 |
If you are using a docker machine on Azure, make sure to shut down and decommission
the machine when your experiments have completed, to avoid incurring costs.
98 |
99 | To shut a machine down:
100 | ```
docker-machine stop <machine-name>
102 | ```
103 |
To remove (decommission) a machine:
```
docker-machine rm <machine-name>
107 | ```
108 |
109 | ## Further reading
110 |
111 | - [docker documentation](https://docs.docker.com/)
112 | - [docker on Azure](https://docs.docker.com/machine/drivers/azure/)
113 | - [docker compose](https://docs.docker.com/compose/overview/)
114 |
--------------------------------------------------------------------------------
/docker/malmo/Dockerfile:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2017 Microsoft Corporation.
2 | #
3 | # Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated
4 | # documentation files (the "Software"), to deal in the Software without restriction, including without limitation the
5 | # rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software,
6 | # and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
7 | #
8 | # The above copyright notice and this permission notice shall be included in all copies or substantial portions of
9 | # the Software.
10 | #
11 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO
12 | # THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
13 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
14 | # TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
15 | # SOFTWARE.
16 | # ===================================================================================================================
17 |
FROM ubuntu:16.04

# Malmo release to download from GitHub (used throughout the build)
ENV MALMO_VERSION 0.21.0

# Install Malmo dependencies
RUN apt-get update && apt-get install -y --no-install-recommends \
    openjdk-8-jdk \
    libxerces-c3.1 \
    libav-tools \
    wget \
    unzip \
    xvfb && \
    rm -rf /var/lib/apt/lists/* /var/cache/apt/archives/*

# Download and unpack Malmo
WORKDIR /root
RUN wget https://github.com/Microsoft/malmo/releases/download/$MALMO_VERSION/Malmo-$MALMO_VERSION-Linux-Ubuntu-16.04-64bit_withBoost.zip && \
    unzip Malmo-$MALMO_VERSION-Linux-Ubuntu-16.04-64bit_withBoost.zip && \
    rm Malmo-$MALMO_VERSION-Linux-Ubuntu-16.04-64bit_withBoost.zip && \
    mv Malmo-$MALMO_VERSION-Linux-Ubuntu-16.04-64bit_withBoost Malmo
# Mission-schema path required by Malmo at runtime
ENV MALMO_XSD_PATH /root/Malmo/Schemas

# Precompile Malmo mod
# (gradle daemon enabled to speed up the two gradlew invocations below)
RUN mkdir ~/.gradle && echo 'org.gradle.daemon=true\n' > ~/.gradle/gradle.properties
WORKDIR /root/Malmo/Minecraft
RUN ./gradlew setupDecompWorkspace
RUN ./gradlew build

# Unlimited framerate settings
COPY options.txt /root/Malmo/Minecraft/run

# run.sh wraps the launch command in xvfb-run (headless X server)
COPY run.sh /root/
RUN chmod +x /root/run.sh

# Expose port
EXPOSE 10000

# Run Malmo
ENTRYPOINT ["/root/run.sh", "/root/Malmo/Minecraft/launchClient.sh"]
--------------------------------------------------------------------------------
/docker/malmo/options.txt:
--------------------------------------------------------------------------------
1 | invertYMouse:false
2 | mouseSensitivity:0.5
3 | fov:0.0
4 | gamma:0.0
5 | saturation:0.0
6 | renderDistance:6
7 | guiScale:0
8 | particles:2
9 | bobView:true
10 | anaglyph3d:false
11 | maxFps:200
12 | fboEnable:true
13 | difficulty:2
14 | fancyGraphics:false
15 | ao:0
16 | renderClouds:true
17 | resourcePacks:[]
18 | lastServer:
19 | lang:en_US
20 | chatVisibility:0
21 | chatColors:true
22 | chatLinks:true
23 | chatLinksPrompt:true
24 | chatOpacity:1.0
25 | snooperEnabled:true
26 | fullscreen:false
27 | enableVsync:true
28 | useVbo:false
29 | hideServerAddress:false
30 | advancedItemTooltips:false
31 | pauseOnLostFocus:false
32 | touchscreen:false
33 | overrideWidth:0
34 | overrideHeight:0
35 | heldItemTooltips:true
36 | chatHeightFocused:1.0
37 | chatHeightUnfocused:0.44366196
38 | chatScale:1.0
39 | chatWidth:1.0
40 | showInventoryAchievementHint:false
41 | mipmapLevels:4
42 | streamBytesPerPixel:0.5
43 | streamMicVolume:1.0
44 | streamSystemVolume:1.0
45 | streamKbps:0.5412844
46 | streamFps:0.31690142
47 | streamCompression:1
48 | streamSendMetadata:true
49 | streamPreferredServer:
50 | streamChatEnabled:0
51 | streamChatUserFilter:0
52 | streamMicToggleBehavior:0
53 | forceUnicodeFont:false
54 | allowBlockAlternatives:true
55 | reducedDebugInfo:false
56 | key_key.attack:-100
57 | key_key.use:-99
58 | key_key.forward:17
59 | key_key.left:30
60 | key_key.back:31
61 | key_key.right:32
62 | key_key.jump:57
63 | key_key.sneak:42
64 | key_key.drop:16
65 | key_key.inventory:18
66 | key_key.chat:20
67 | key_key.playerlist:15
68 | key_key.pickItem:-98
69 | key_key.command:53
70 | key_key.screenshot:60
71 | key_key.togglePerspective:63
72 | key_key.smoothCamera:0
73 | key_key.sprint:29
74 | key_key.streamStartStop:64
75 | key_key.streamPauseUnpause:65
76 | key_key.streamCommercial:0
77 | key_key.streamToggleMic:0
78 | key_key.fullscreen:87
79 | key_key.spectatorOutlines:0
80 | key_key.hotbar.1:2
81 | key_key.hotbar.2:3
82 | key_key.hotbar.3:4
83 | key_key.hotbar.4:5
84 | key_key.hotbar.5:6
85 | key_key.hotbar.6:7
86 | key_key.hotbar.7:8
87 | key_key.hotbar.8:9
88 | key_key.hotbar.9:10
89 | key_key.toggleMalmo:28
90 | key_key.handyTestHook:22
91 | soundCategory_master:0.0
92 | soundCategory_music:1.0
93 | soundCategory_record:1.0
94 | soundCategory_weather:1.0
95 | soundCategory_block:1.0
96 | soundCategory_hostile:1.0
97 | soundCategory_neutral:1.0
98 | soundCategory_player:1.0
99 | soundCategory_ambient:1.0
100 | modelPart_cape:true
101 | modelPart_jacket:true
102 | modelPart_left_sleeve:true
103 | modelPart_right_sleeve:true
104 | modelPart_left_pants_leg:true
105 | modelPart_right_pants_leg:true
106 | modelPart_hat:true
107 |
--------------------------------------------------------------------------------
/docker/malmo/run.sh:
--------------------------------------------------------------------------------
#!/bin/sh
# Run the given command (e.g. launchClient.sh) under a virtual framebuffer
# so Minecraft can render without a physical display.
#   -a              : pick a free display number automatically
#   -e /dev/stdout  : send Xvfb errors to stdout so container logs capture them
# "$@" (not $*) preserves argument boundaries if any argument contains spaces.
xvfb-run -a -e /dev/stdout -s '-screen 0 1400x900x24' "$@"
--------------------------------------------------------------------------------
/docker/malmopy-ai-challenge/docker-compose.yml:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2017 Microsoft Corporation.
2 | #
3 | # Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated
4 | # documentation files (the "Software"), to deal in the Software without restriction, including without limitation the
5 | # rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software,
6 | # and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
7 | #
8 | # The above copyright notice and this permission notice shall be included in all copies or substantial portions of
9 | # the Software.
10 | #
11 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO
12 | # THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
13 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
14 | # TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
15 | # SOFTWARE.
16 | # ===================================================================================================================
17 |
version: '3'
services:
  # Two Malmo (Minecraft) instances, one per agent in the pig-chase pair
  malmo1:
    image: malmo:latest
    expose:
      - "10000"
  malmo2:
    image: malmo:latest
    expose:
      - "10000"
  # Runs the baseline experiment against both Malmo instances and serves
  # TensorBoard on port 6006 alongside it
  agents:
    image: malmopy-cntk-cpu-py27:latest
    working_dir: /root/malmo-challenge/ai_challenge/pig_chase
    command: bash -c "python pig_chase_baseline.py malmo1:10000 malmo2:10000 & tensorboard --logdir 'results' --port 6006"
    ports:
      - "6006:6006"
    links:
      - malmo1
      - malmo2
37 |
--------------------------------------------------------------------------------
/docker/malmopy-chainer-cpu/Dockerfile:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2017 Microsoft Corporation.
2 | #
3 | # Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated
4 | # documentation files (the "Software"), to deal in the Software without restriction, including without limitation the
5 | # rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software,
6 | # and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
7 | #
8 | # The above copyright notice and this permission notice shall be included in all copies or substantial portions of
9 | # the Software.
10 | #
11 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO
12 | # THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
13 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
14 | # TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
15 | # SOFTWARE.
16 | # ===================================================================================================================
17 |
FROM ubuntu:16.04

# Version variables
ENV MALMO_VERSION 0.21.0
ENV MALMOPY_VERSION 0.1.0

# NOTE(review): the blank line and inline comment inside this RUN rely on
# Docker stripping comment lines from continuations; empty continuation
# lines are deprecated in newer Docker releases — verify when upgrading.
RUN apt-get update -y && \
    apt-get install -y --no-install-recommends \
    build-essential \
    python-dev \
    python-pip \
    python-setuptools \
    cmake \
    ssh \
    git-all \
    zlib1g-dev \

    # install Malmo dependencies
    libpython2.7 \
    lua5.1 \
    libxerces-c3.1 \
    liblua5.1-0-dev \
    libav-tools \
    python-tk \
    python-imaging-tk \
    wget \
    unzip && \
    rm -rf /var/lib/apt/lists/* /var/cache/apt/archives/*

# chainer pinned to the version the samples were written against
RUN pip install -U pip setuptools && pip install wheel && pip install chainer==1.21.0

# download and unpack Malmo
WORKDIR /root
RUN wget https://github.com/Microsoft/malmo/releases/download/$MALMO_VERSION/Malmo-$MALMO_VERSION-Linux-Ubuntu-16.04-64bit_withBoost.zip && \
    unzip Malmo-$MALMO_VERSION-Linux-Ubuntu-16.04-64bit_withBoost.zip && \
    rm Malmo-$MALMO_VERSION-Linux-Ubuntu-16.04-64bit_withBoost.zip && \
    mv Malmo-$MALMO_VERSION-Linux-Ubuntu-16.04-64bit_withBoost Malmo

# Mission-schema path and Malmo Python bindings
ENV MALMO_XSD_PATH /root/Malmo/Schemas
ENV PYTHONPATH /root/Malmo/Python_Examples

# add and install malmopy, malmo challenge task and samples
WORKDIR /root
RUN git clone https://github.com/Microsoft/malmo-challenge.git && \
    cd malmo-challenge && \
    git checkout tags/$MALMOPY_VERSION -b latest
WORKDIR /root/malmo-challenge
RUN pip install -e '.[all]'
66 |
--------------------------------------------------------------------------------
/docker/malmopy-chainer-gpu/Dockerfile:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2017 Microsoft Corporation.
2 | #
3 | # Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated
4 | # documentation files (the "Software"), to deal in the Software without restriction, including without limitation the
5 | # rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software,
6 | # and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
7 | #
8 | # The above copyright notice and this permission notice shall be included in all copies or substantial portions of
9 | # the Software.
10 | #
11 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO
12 | # THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
13 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
14 | # TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
15 | # SOFTWARE.
16 | # ===================================================================================================================
17 |
FROM nvidia/cuda:8.0-cudnn5-devel-ubuntu16.04


# Version variables
ENV MALMO_VERSION 0.21.0
ENV MALMOPY_VERSION 0.1.0

# NOTE(review): the blank line and inline comment inside this RUN rely on
# Docker stripping comment lines from continuations; empty continuation
# lines are deprecated in newer Docker releases — verify when upgrading.
RUN apt-get update -y && \
    apt-get install -y --no-install-recommends \
    build-essential \
    python-dev \
    python-pip \
    python-setuptools \
    cmake \
    ssh \
    git-all \
    zlib1g-dev \

    # install Malmo dependencies
    libpython2.7 \
    lua5.1 \
    libxerces-c3.1 \
    liblua5.1-0-dev \
    libav-tools \
    python-tk \
    python-imaging-tk \
    wget \
    unzip && \
    rm -rf /var/lib/apt/lists/* /var/cache/apt/archives/*

# chainer pinned to the version the samples were written against
RUN pip install -U pip setuptools && pip install wheel && pip install chainer==1.21.0

# download and unpack Malmo
WORKDIR /root
RUN wget https://github.com/Microsoft/malmo/releases/download/$MALMO_VERSION/Malmo-$MALMO_VERSION-Linux-Ubuntu-16.04-64bit_withBoost.zip && \
    unzip Malmo-$MALMO_VERSION-Linux-Ubuntu-16.04-64bit_withBoost.zip && \
    rm Malmo-$MALMO_VERSION-Linux-Ubuntu-16.04-64bit_withBoost.zip && \
    mv Malmo-$MALMO_VERSION-Linux-Ubuntu-16.04-64bit_withBoost Malmo

# Mission-schema path and Malmo Python bindings
ENV MALMO_XSD_PATH /root/Malmo/Schemas
ENV PYTHONPATH /root/Malmo/Python_Examples

# add and install malmopy, malmo challenge task and samples
WORKDIR /root
RUN git clone https://github.com/Microsoft/malmo-challenge.git && \
    cd malmo-challenge && \
    git checkout tags/$MALMOPY_VERSION -b latest
WORKDIR /root/malmo-challenge
RUN pip install -e '.[all]'
67 |
--------------------------------------------------------------------------------
/docker/malmopy-cntk-cpu-py27/Dockerfile:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2017 Microsoft Corporation.
2 | #
3 | # Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated
4 | # documentation files (the "Software"), to deal in the Software without restriction, including without limitation the
5 | # rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software,
6 | # and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
7 | #
8 | # The above copyright notice and this permission notice shall be included in all copies or substantial portions of
9 | # the Software.
10 | #
11 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO
12 | # THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
13 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
14 | # TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
15 | # SOFTWARE.
16 | # ===================================================================================================================
17 |
FROM microsoft/cntk:2.0.beta15.0-cpu-python2.7

# Version variables
ENV MALMO_VERSION 0.21.0
ENV MALMOPY_VERSION 0.1.0

# NOTE(review): the blank line and inline comment inside this RUN rely on
# Docker stripping comment lines from continuations; empty continuation
# lines are deprecated in newer Docker releases — verify when upgrading.
RUN apt-get update -y && \
    apt-get install -y --no-install-recommends \
    build-essential \
    cmake \
    ssh \
    git-all \
    zlib1g-dev \
    python-dev \
    python-pip \

    # install Malmo dependencies
    libpython2.7 \
    openjdk-7-jdk \
    lua5.1 \
    libxerces-c3.1 \
    liblua5.1-0-dev \
    libav-tools \
    python-tk \
    python-imaging-tk \
    wget \
    unzip && \
    rm -rf /var/lib/apt/lists/* /var/cache/apt/archives/*

# Set CNTK Python PATH at first position to be picked automatically
ENV PATH=/root/anaconda3/envs/cntk-py27/bin:$PATH

# Update pip
RUN /root/anaconda3/envs/cntk-py27/bin/pip install --upgrade pip

# download and unpack Malmo
WORKDIR /root
RUN wget https://github.com/Microsoft/malmo/releases/download/$MALMO_VERSION/Malmo-$MALMO_VERSION-Linux-Ubuntu-16.04-64bit_withBoost.zip && \
    unzip Malmo-$MALMO_VERSION-Linux-Ubuntu-16.04-64bit_withBoost.zip && \
    rm Malmo-$MALMO_VERSION-Linux-Ubuntu-16.04-64bit_withBoost.zip && \
    mv Malmo-$MALMO_VERSION-Linux-Ubuntu-16.04-64bit_withBoost Malmo

# Mission-schema path and Malmo Python bindings
ENV MALMO_XSD_PATH /root/Malmo/Schemas
ENV PYTHONPATH /root/Malmo/Python_Examples

# add and install malmopy, malmo challenge task and samples
WORKDIR /root
RUN git clone https://github.com/Microsoft/malmo-challenge.git && \
    cd malmo-challenge && \
    git checkout tags/$MALMOPY_VERSION -b latest
WORKDIR /root/malmo-challenge
RUN pip install -e '.[all]'
70 |
--------------------------------------------------------------------------------
/docker/malmopy-cntk-gpu-py27/Dockerfile:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2017 Microsoft Corporation.
2 | #
3 | # Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated
4 | # documentation files (the "Software"), to deal in the Software without restriction, including without limitation the
5 | # rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software,
6 | # and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
7 | #
8 | # The above copyright notice and this permission notice shall be included in all copies or substantial portions of
9 | # the Software.
10 | #
11 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO
12 | # THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
13 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
14 | # TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
15 | # SOFTWARE.
16 | # ===================================================================================================================
17 |
FROM microsoft/cntk:2.0.beta15.0-gpu-python2.7-cuda8.0-cudnn5.1

# Version variables
ENV MALMO_VERSION 0.21.0
ENV MALMOPY_VERSION 0.1.0

# NOTE(review): the blank line and inline comment inside this RUN rely on
# Docker stripping comment lines from continuations; empty continuation
# lines are deprecated in newer Docker releases — verify when upgrading.
RUN apt-get update -y && \
    apt-get install -y --no-install-recommends \
    build-essential \
    cmake \
    ssh \
    git-all \
    zlib1g-dev \
    python-dev \
    python-pip \

    # install Malmo dependencies
    libpython2.7 \
    openjdk-7-jdk \
    lua5.1 \
    libxerces-c3.1 \
    liblua5.1-0-dev \
    libav-tools \
    python-tk \
    python-imaging-tk \
    wget \
    unzip && \
    rm -rf /var/lib/apt/lists/* /var/cache/apt/archives/*

# Set CNTK Python PATH at first position to be picked automatically
ENV PATH=/root/anaconda3/envs/cntk-py27/bin:$PATH

# Update pip
RUN /root/anaconda3/envs/cntk-py27/bin/pip install --upgrade pip

# download and unpack Malmo
WORKDIR /root
RUN wget https://github.com/Microsoft/malmo/releases/download/$MALMO_VERSION/Malmo-$MALMO_VERSION-Linux-Ubuntu-16.04-64bit_withBoost.zip && \
    unzip Malmo-$MALMO_VERSION-Linux-Ubuntu-16.04-64bit_withBoost.zip && \
    rm Malmo-$MALMO_VERSION-Linux-Ubuntu-16.04-64bit_withBoost.zip && \
    mv Malmo-$MALMO_VERSION-Linux-Ubuntu-16.04-64bit_withBoost Malmo

# Mission-schema path and Malmo Python bindings
ENV MALMO_XSD_PATH /root/Malmo/Schemas
ENV PYTHONPATH /root/Malmo/Python_Examples

# add and install malmopy, malmo challenge task and samples
WORKDIR /root
RUN git clone https://github.com/Microsoft/malmo-challenge.git && \
    cd malmo-challenge && \
    git checkout tags/$MALMOPY_VERSION -b latest
WORKDIR /root/malmo-challenge
RUN pip install -e '.[all]'
70 |
--------------------------------------------------------------------------------
/malmopy/README.md:
--------------------------------------------------------------------------------
1 | ## Writing your first experiment
2 |
3 | The framework is designed to give you the flexibility you need to design and run your experiment.
4 | In this section you will see how easy it is to write a simple Atari/DQN experiment based on CNTK backend.
5 |
6 |
7 | ### Using with Microsoft Cognitive Network ToolKit (CNTK)
To be able to use CNTK from the framework, you will first need to install CNTK from the
9 | official repository [release page](https://github.com/Microsoft/CNTK/releases). Pick the
10 | right distribution according to your OS / Hardware configuration and plans to use distributed
11 | training sessions.
12 |
13 | The CNTK Python binding can be installed by running the installation script
14 | ([more information here](https://github.com/Microsoft/CNTK/wiki/Setup-CNTK-on-your-machine)).
15 | After following the installation process you should be able to import CNTK.
16 |
17 | ___Note that every time you will want to run experiment with CNTK you will need to activate the cntk-pyXX environment.___
18 |
19 | ### Getting started
20 |
21 | First of all, you need to import all the dependencies :
22 | ```python
23 | from malmopy.agent.qlearner import QLearnerAgent, TemporalMemory
24 | from malmopy.model.cntk import QNeuralNetwork
25 | from malmopy.environment.gym import GymEnvironment
26 |
27 | # In this example we will use the Breakout-v3 environment.
28 | env = GymEnvironment('Breakout-v3', monitoring_path='/directory/where/to/put/records')
29 |
30 | # Q Neural Network needs a Replay Memory to randomly sample minibatch.
31 | memory = TemporalMemory(1000000, (84, 84), 4)
32 |
33 | #Here a simple Deep Q Neural Network backed by CNTK runtime
34 | model = QNeuralNetwork((4, 84, 84), env.available_actions, device_id=-1)
35 |
36 | # We provide the number of action available, our model and the memory
37 | agent = QLearnerAgent("DQN Agent", env.available_actions, model, memory, 0.99, 32)
38 |
39 | reward = 0
40 | done = False
41 |
# Replace range with xrange if running Python 2
43 | while True:
44 |
45 | # Reset environment if needed
46 | if env.done:
47 | current_state = env.reset()
48 |
49 | action = agent.act(current_state, reward, done, True)
50 | new_state, reward, done = env.do(action)
51 | ```
52 |
53 | ## Some comments:
54 | - The GymEnvironment monitoring_path is used to record short epsiode videos of the agent
55 | - Temporal Memory generates a sample w.r.t. the history_length previous states
56 | - For example with history_length = 4 a sample is [s(t-3), s(t-2), s(t-1), s(t)]
57 | - QNeuralNetwork input_shape is the shape of a sample from the TemporalMemory (history_length, width, height)
58 | - QNeuralNetwork output_shape is the number of actions available for the environment (one neuron per action)
59 | - QNeuralNetwork device_id == -1 indicates 'Run on CPU'; anything >= 0 refers to a GPU device ID
60 |
--------------------------------------------------------------------------------
/malmopy/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2017 Microsoft Corporation.
2 | #
3 | # Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated
4 | # documentation files (the "Software"), to deal in the Software without restriction, including without limitation the
5 | # rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software,
6 | # and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
7 | #
8 | # The above copyright notice and this permission notice shall be included in all copies or substantial portions of
9 | # the Software.
10 | #
11 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO
12 | # THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
13 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
14 | # TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
15 | # SOFTWARE.
16 | # ===================================================================================================================
17 |
--------------------------------------------------------------------------------
/malmopy/agent/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2017 Microsoft Corporation.
2 | #
3 | # Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated
4 | # documentation files (the "Software"), to deal in the Software without restriction, including without limitation the
5 | # rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software,
6 | # and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
7 | #
8 | # The above copyright notice and this permission notice shall be included in all copies or substantial portions of
9 | # the Software.
10 | #
11 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO
12 | # THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
13 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
14 | # TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
15 | # SOFTWARE.
16 | # ===================================================================================================================
17 |
18 | from __future__ import absolute_import
19 |
20 | from .agent import BaseAgent, RandomAgent, ConsoleAgent, ReplayMemory
21 | from .astar import AStarAgent
22 | from .explorer import BaseExplorer, LinearEpsilonGreedyExplorer
23 | from .qlearner import QLearnerAgent, History, ReplayMemory, TemporalMemory
24 |
25 | __all__ = ['agent', 'astar', 'qlearner', 'explorer']
26 |
--------------------------------------------------------------------------------
/malmopy/agent/agent.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2017 Microsoft Corporation.
2 | #
3 | # Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated
4 | # documentation files (the "Software"), to deal in the Software without restriction, including without limitation the
5 | # rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software,
6 | # and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
7 | #
8 | # The above copyright notice and this permission notice shall be included in all copies or substantial portions of
9 | # the Software.
10 | #
11 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO
12 | # THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
13 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
14 | # TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
15 | # SOFTWARE.
16 | # ===================================================================================================================
17 |
18 | from __future__ import absolute_import
19 |
20 | import os
21 | import sys
22 | from collections import Iterable
23 |
24 | import numpy as np
25 |
26 | from ..visualization import Visualizable
27 |
28 |
class BaseAgent(Visualizable):
    """
    Abstract base class for an agent that interacts with an environment.

    Subclasses must implement act(); save/load/inject_summaries are
    optional hooks and default to no-ops.
    """

    def __init__(self, name, nb_actions, visualizer=None):
        """
        :param name: Human-readable identifier for this agent
        :param nb_actions: Number of actions available to the agent (> 0)
        :param visualizer: Optional visualizer forwarded to Visualizable
        """
        assert nb_actions > 0, 'Agent should at least have 1 action (got %d)' % nb_actions

        super(BaseAgent, self).__init__(visualizer)

        self.nb_actions = nb_actions
        self.name = name

    def act(self, new_state, reward, done, is_training=False):
        """Select the next action given the latest observation."""
        raise NotImplementedError()

    def save(self, out_dir):
        """Persist agent state to out_dir (no-op by default)."""
        pass

    def load(self, out_dir):
        """Restore agent state from out_dir (no-op by default)."""
        pass

    def inject_summaries(self, idx):
        """Emit visualization summaries for step idx (no-op by default)."""
        pass
53 |
54 |
class RandomAgent(BaseAgent):
    """
    Agent that picks one of its actions uniformly at random,
    optionally pausing between actions.
    """

    def __init__(self, name, nb_actions, delay_between_action=0, visualizer=None):
        """
        :param delay_between_action: Seconds to sleep before each action (0 = no delay)
        """
        super(RandomAgent, self).__init__(name, nb_actions, visualizer)
        self._delay = delay_between_action

    def act(self, new_state, reward, done, is_training=False):
        """Ignore the observation and return a uniformly random action index."""
        delay = self._delay
        if delay > 0:
            # Imported lazily so agents with no delay never touch time.sleep
            from time import sleep
            sleep(delay)

        return np.random.randint(0, self.nb_actions)
71 |
72 |
class ConsoleAgent(BaseAgent):
    """ Provide a console interface for mediating human interaction with
    an environment

    Users are prompted for input when an action is required:

        Agent-1, what do you want to do?
        1: action1
        2: action2
        3: action3
        ...
        N: actionN
        Agent-1: 2
        ...
    """

    def __init__(self, name, actions, stdin=None):
        """
        :param name: Agent name used in prompts
        :param actions: Iterable of action labels (one entry per action)
        :param stdin: Optional file descriptor to reopen as standard input
        """
        assert isinstance(actions, Iterable), 'actions need to be iterable (e.g., list, tuple)'
        assert len(actions) > 0, 'actions need at least one element'

        super(ConsoleAgent, self).__init__(name, len(actions))

        self._actions = actions

        if stdin is not None:
            sys.stdin = os.fdopen(stdin)

    def act(self, new_state, reward, done, is_training=False):
        """Prompt the user until a valid action index in [0, N-1] is entered."""
        while True:
            self._print_choices()
            try:
                user_action = int(input("%s: " % self.name))
                if 0 <= user_action < len(self._actions):
                    return user_action
            except ValueError:
                # Non-numeric input falls through to the shared error message
                pass
            # Single error path for both out-of-range and non-numeric input
            # (the original duplicated this message in two branches)
            print("Provided input is not valid should be [0, %d]" % (len(self._actions) - 1))

    def _print_choices(self):
        """Display the numbered list of available actions."""
        print("\n%s What do you want to do?" % self.name)

        for idx, action in enumerate(self._actions):
            print("\t%d : %s" % (idx, action))
122 |
123 |
class ReplayMemory(object):
    """
    Circular buffer of (state, action, reward, terminal) transitions.

    Once max_size entries have been stored, the oldest entries are
    overwritten first.
    """

    def __init__(self, max_size, state_shape):
        """
        :param max_size: Maximum number of transitions kept (> 0)
        :param state_shape: Shape tuple of a single state sample
        """
        assert max_size > 0, 'size should be > 0 (got %d)' % max_size

        self._pos = 0        # next write slot (wraps around)
        self._count = 0      # number of valid entries, capped at max_size
        self._max_size = max_size
        self._state_shape = state_shape
        self._states = np.empty((max_size,) + state_shape, dtype=np.float32)
        self._actions = np.empty(max_size, dtype=np.uint8)
        self._rewards = np.empty(max_size, dtype=np.float32)
        # BUGFIX: np.bool was deprecated (NumPy 1.20) and removed (1.24);
        # the builtin bool is the equivalent dtype.
        self._terminals = np.empty(max_size, dtype=bool)

    def append(self, state, action, reward, is_terminal):
        """
        Appends the specified memory to the history.
        :param state: The state to append (should have the same shape as defined at initialization time)
        :param action: An integer representing the action done
        :param reward: A number representing the reward received for doing this action
        :param is_terminal: A boolean specifying if this state is a terminal (episode has finished)
        :return:
        """
        assert state.shape == self._state_shape, \
            'Invalid state shape (required: %s, got: %s)' % (self._state_shape, state.shape)

        self._states[self._pos, ...] = state
        self._actions[self._pos] = action
        self._rewards[self._pos] = reward
        self._terminals[self._pos] = is_terminal

        # Grow count until the buffer is full, then advance the ring pointer
        self._count = max(self._count, self._pos + 1)
        self._pos = (self._pos + 1) % self._max_size

    def __len__(self):
        """
        Number of elements currently stored in the memory (same as #size())
        See #size()
        :return: Integer : max_size >= size() >= 0
        """
        return self.size

    @property
    def last(self):
        """
        Return the last observation from the memory
        :return: Tuple (state, action, reward, terminal)
        """
        # BUGFIX: _pos points at the NEXT slot to be written, so the most
        # recently appended entry is one position behind it (wrapping).
        # The previous code returned self._states[self._pos], i.e. an
        # unwritten (or oldest) entry instead of the last one.
        idx = (self._pos - 1) % self._max_size
        return self._states[idx], self._actions[idx], self._rewards[idx], self._terminals[idx]

    @property
    def size(self):
        """
        Number of elements currently stored in the memory
        :return: Integer : max_size >= size >= 0
        """
        return self._count

    @property
    def max_size(self):
        """
        Maximum number of elements that can fit in the memory
        :return: Integer > 0
        """
        return self._max_size

    @property
    def history_length(self):
        """
        Number of states stacked along the first axis
        :return: int >= 1
        """
        return 1

    def sample(self, size, replace=False):
        """
        Generate a random sample of desired size (if available) from the current memory
        :param size: Number of samples
        :param replace: True if sampling with replacement
        :return: Integer[size] representing the sampled indices
        """
        return np.random.choice(self._count, size, replace=replace)

    def get_state(self, index):
        """
        Return the specified state
        :param index: State's index (wrapped modulo the current size)
        :return: state : (input_shape)
        """
        index %= self.size
        return self._states[index]

    def get_action(self, index):
        """
        Return the specified action
        :param index: Action's index (wrapped modulo the current size)
        :return: Integer
        """
        index %= self.size
        return self._actions[index]

    def get_reward(self, index):
        """
        Return the specified reward
        :param index: Reward's index (wrapped modulo the current size)
        :return: Float
        """
        index %= self.size
        return self._rewards[index]

    def minibatch(self, size):
        """
        Generate a minibatch with the number of samples specified by the size parameter.
        :param size: Minibatch size
        :return: Tuple (pre_states, actions, post_states, rewards, terminals)
        """
        indexes = self.sample(size)

        pre_states = np.array([self.get_state(index) for index in indexes], dtype=np.float32)
        # Successor state; index + 1 wraps via the modulo in get_state()
        post_states = np.array([self.get_state(index + 1) for index in indexes], dtype=np.float32)
        actions = self._actions[indexes]
        rewards = self._rewards[indexes]
        terminals = self._terminals[indexes]

        return pre_states, actions, post_states, rewards, terminals

    def save(self, out_file):
        """
        Save the current memory into a file in Numpy format
        :param out_file: File storage path
        :return:
        """
        np.savez_compressed(out_file, states=self._states, actions=self._actions,
                            rewards=self._rewards, terminals=self._terminals)

    def load(self, in_dir):
        """Restore memory contents (not implemented)."""
        pass
265 |
--------------------------------------------------------------------------------
/malmopy/agent/astar.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2017 Microsoft Corporation.
2 | #
3 | # Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated
4 | # documentation files (the "Software"), to deal in the Software without restriction, including without limitation the
5 | # rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software,
6 | # and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
7 | #
8 | # The above copyright notice and this permission notice shall be included in all copies or substantial portions of
9 | # the Software.
10 | #
11 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO
12 | # THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
13 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
14 | # TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
15 | # SOFTWARE.
16 | # ===================================================================================================================
17 |
18 | from __future__ import absolute_import
19 |
20 | from heapq import heapify, heappop, heappush
21 | from collections import deque
22 |
23 | from . import BaseAgent
24 |
25 |
class AStarAgent(BaseAgent):
    """
    Base class implementing A* shortest-path search over an abstract graph.

    Subclasses define the graph by implementing neighbors() and heuristic();
    matches() may be overridden when goal matching is not plain equality.
    """

    def __init__(self, name, nb_actions, visualizer=None):
        super(AStarAgent, self).__init__(name, nb_actions, visualizer)

    def _find_shortest_path(self, start, end, **kwargs):
        """
        Run A* from start towards end.

        :param start: Initial node
        :param end: Goal node (compared via matches())
        :param kwargs: Passed through to neighbors() and heuristic()
        :return: Tuple (path, cost_so_far) where path is a deque of nodes from
                 the node after start up to the goal (start excluded), and
                 cost_so_far maps each reached node to its best known path cost.
        NOTE(review): if end is unreachable, the reconstruction loop below may
        raise KeyError while walking came_from — callers appear to assume a
        reachable goal. Heap entries are (priority, node) tuples, so on equal
        priorities Python compares the nodes themselves; node types should
        therefore support ordering.
        """
        came_from, cost_so_far = {}, {}
        explorer = []
        heapify(explorer)

        heappush(explorer, (0, start))
        came_from[start] = None
        cost_so_far[start] = 0
        current = None

        while len(explorer) > 0:
            # Pop the frontier node with the lowest f = g + h
            _, current = heappop(explorer)

            if self.matches(current, end):
                break

            for nb in self.neighbors(current, **kwargs):
                # Edge weight defaults to 1 unless the neighbor carries a cost
                cost = nb.cost if hasattr(nb, "cost") else 1
                new_cost = cost_so_far[current] + cost

                # Relax: first visit, or a cheaper route was found
                if nb not in cost_so_far or new_cost < cost_so_far[nb]:
                    cost_so_far[nb] = new_cost
                    priority = new_cost + self.heuristic(end, nb, **kwargs)
                    heappush(explorer, (priority, nb))
                    came_from[nb] = current

        # build path: walk came_from backwards from the goal to start
        path = deque()
        while current is not start:
            path.appendleft(current)
            current = came_from[current]
        return path, cost_so_far

    def neighbors(self, pos, **kwargs):
        """Return the iterable of nodes adjacent to pos."""
        raise NotImplementedError()

    def heuristic(self, a, b, **kwargs):
        """Estimate of the remaining cost between a and b (should be admissible)."""
        raise NotImplementedError()

    def matches(self, a, b):
        """Goal test; defaults to equality."""
        return a == b
71 |
--------------------------------------------------------------------------------
/malmopy/agent/explorer.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2017 Microsoft Corporation.
2 | #
3 | # Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated
4 | # documentation files (the "Software"), to deal in the Software without restriction, including without limitation the
5 | # rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software,
6 | # and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
7 | #
8 | # The above copyright notice and this permission notice shall be included in all copies or substantial portions of
9 | # the Software.
10 | #
11 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO
12 | # THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
13 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
14 | # TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
15 | # SOFTWARE.
16 | # ===================================================================================================================
17 |
18 | from __future__ import absolute_import
19 |
20 | import numpy as np
21 |
22 |
class BaseExplorer:
    """Interface wrapping the explore/exploit decision logic."""

    def __call__(self, step, nb_actions):
        # Calling the explorer directly is shorthand for explore()
        return self.explore(step, nb_actions)

    def is_exploring(self, step):
        """Return True to explore at this step, False to exploit."""
        raise NotImplementedError()

    def explore(self, step, nb_actions):
        """Produce an exploratory action index."""
        raise NotImplementedError()
36 |
37 |
class LinearEpsilonGreedyExplorer(BaseExplorer):
    """ Epsilon-greedy exploration with linearly annealed epsilon.

    Epsilon decays linearly from eps_max down to eps_min over the first
    eps_min_time steps, then stays at eps_min:
      - step < 0                   -> eps_max
      - 0 <= step <= eps_min_time  -> linear interpolation
      - step > eps_min_time        -> eps_min
    """

    def __init__(self, eps_max, eps_min, eps_min_time):
        assert eps_max > eps_min
        assert eps_min_time > 0

        self._eps_min_time = eps_min_time
        self._eps_min = eps_min
        self._eps_max = eps_max

        # Slope of the annealing line (negative since eps_max > eps_min)
        self._a = -(eps_max - eps_min) / eps_min_time

    def _epsilon(self, step):
        """Current epsilon for the given step."""
        if step < 0:
            return self._eps_max
        if step > self._eps_min_time:
            return self._eps_min
        return self._eps_max + self._a * step

    def is_exploring(self, step):
        """Explore with probability epsilon(step)."""
        return np.random.rand() < self._epsilon(step)

    def explore(self, step, nb_actions):
        """Uniformly sample an action index in [0, nb_actions)."""
        return np.random.randint(0, nb_actions)
74 |
--------------------------------------------------------------------------------
/malmopy/agent/gui.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2017 Microsoft Corporation.
2 | #
3 | # Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated
4 | # documentation files (the "Software"), to deal in the Software without restriction, including without limitation the
5 | # rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software,
6 | # and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
7 | #
8 | # The above copyright notice and this permission notice shall be included in all copies or substantial portions of
9 | # the Software.
10 | #
11 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO
12 | # THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
13 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
14 | # TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
15 | # SOFTWARE.
16 | # ===================================================================================================================
17 |
18 | from __future__ import absolute_import
19 |
20 | import six
21 | if six.PY2:
22 | from Tkinter import Tk
23 | else:
24 | from tkinter import Tk
25 |
26 | from . import BaseAgent
27 | from ..environment import VideoCapableEnvironment
28 |
# First-person (WASD/ZQSD) keyboard mapping
FPS_KEYS_MAPPING = {'w': 'move 1', 'a': 'strafe -1', 's': 'move -1', 'd': 'strafe 1', ' ': 'jump 1',
                    'q': 'strafe -1', 'z': 'move 1'}

ARROW_KEYS_MAPPING = {'Left': 'turn -1', 'Right': 'turn 1', 'Up': 'move 1', 'Down': 'move -1'}

# BUGFIX: dict.update() returns None, so the previous chained
# "{...}.update(ARROW_KEYS_MAPPING)" left CONTINUOUS_KEYS_MAPPING set to None.
# Build the dict first, then merge in the arrow-key bindings.
CONTINUOUS_KEYS_MAPPING = {'Shift_L': 'crouch 1', 'Shift_R': 'crouch 1',
                           '1': 'hotbar.1 1', '2': 'hotbar.2 1', '3': 'hotbar.3 1', '4': 'hotbar.4 1',
                           '5': 'hotbar.5 1',
                           '6': 'hotbar.6 1', '7': 'hotbar.7 1', '8': 'hotbar.8 1', '9': 'hotbar.9 1'}
CONTINUOUS_KEYS_MAPPING.update(ARROW_KEYS_MAPPING)

DISCRETE_KEYS_MAPPING = {'Left': 'turn -1', 'Right': 'turn 1', 'Up': 'move 1', 'Down': 'move -1',
                         '1': 'hotbar.1 1', '2': 'hotbar.2 1', '3': 'hotbar.3 1', '4': 'hotbar.4 1', '5': 'hotbar.5 1',
                         '6': 'hotbar.6 1', '7': 'hotbar.7 1', '8': 'hotbar.8 1', '9': 'hotbar.9 1'}
43 |
44 |
class GuiAgent(BaseAgent):
    """
    Tkinter-based agent window displaying frames from a video-capable
    environment; subclasses build the actual widget layout.
    """

    def __init__(self, name, environment, keymap, win_name="Gui Agent", size=(640, 480), visualizer=None):
        """
        :param name: Agent name
        :param environment: VideoCapableEnvironment providing the frames
        :param keymap: list of key characters/keysyms handled by this GUI
        :param win_name: Window title
        :param size: Window size as (width, height)
        :param visualizer: Optional visualizer
        """
        assert isinstance(keymap, list), 'keymap should be a list[character]'
        assert isinstance(environment, VideoCapableEnvironment), 'environment should inherit from BaseEnvironment'

        super(GuiAgent, self).__init__(name, environment.available_actions, visualizer)

        # Frames are needed to drive the display
        if not environment.recording:
            environment.recording = True

        self._env = environment
        self._keymap = keymap
        self._tick = 20  # GUI refresh period

        self._root = Tk()
        # BUGFIX: wm_title and geometry are Tk methods; the previous code
        # assigned to them (self._root.wm_title = win_name), which silently
        # replaced the bound methods instead of configuring the window.
        self._root.wm_title(win_name)
        self._root.resizable(width=False, height=False)
        self._root.geometry("%dx%d" % size)

        self._build_layout(self._root)

    def act(self, new_state, reward, done, is_training=False):
        pass

    def show(self):
        """Enter the Tk main event loop (blocks until the window is closed)."""
        self._root.mainloop()

    def _build_layout(self, root):
        """
        Build the window layout
        :param root: Tk root window
        :return:
        """
        raise NotImplementedError()

    def _get_keymapping_help(self):
        """Return the list of keys this GUI responds to."""
        return self._keymap
82 |
--------------------------------------------------------------------------------
/malmopy/agent/qlearner.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2017 Microsoft Corporation.
2 | #
3 | # Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated
4 | # documentation files (the "Software"), to deal in the Software without restriction, including without limitation the
5 | # rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software,
6 | # and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
7 | #
8 | # The above copyright notice and this permission notice shall be included in all copies or substantial portions of
9 | # the Software.
10 | #
11 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO
12 | # THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
13 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
14 | # TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
15 | # SOFTWARE.
16 | # ===================================================================================================================
17 |
18 | from __future__ import absolute_import
19 |
20 | from collections import namedtuple
21 |
22 | import numpy as np
23 |
24 | from . import BaseAgent, ReplayMemory, BaseExplorer, LinearEpsilonGreedyExplorer
25 | from ..model import QModel
26 | from ..util import get_rank
27 |
28 |
class TemporalMemory(ReplayMemory):
    """
    Temporal memory adds a new dimension to store N previous samples (t, t-1, t-2, ..., t-N)
    when sampling from memory
    """

    def __init__(self, max_size, sample_shape, history_length=4,
                 unflicker=False):
        """
        :param max_size: Maximum number of elements in the memory
        :param sample_shape: Shape of each sample
        :param history_length: Length of the visual memory (n previous frames) included with each state
        :param unflicker: Indicate if we need to compute the difference between consecutive frames
        """
        super(TemporalMemory, self).__init__(max_size, sample_shape)

        self._unflicker = unflicker
        # Clamp so every state carries at least one frame
        self._history_length = max(1, history_length)
        # Previous raw frame, kept for the unflicker preprocessing below
        self._last = np.zeros(sample_shape)

    def append(self, state, action, reward, is_terminal):
        # Unflicker: store the element-wise max of the current and previous
        # frame. NOTE(review): the constructor docstring says "difference",
        # but this is a maximum (the standard DQN/Atari frame preprocessing).
        if self._unflicker:
            max_diff_buffer = np.maximum(self._last, state)
            self._last = state
            state = max_diff_buffer

        super(TemporalMemory, self).append(state, action, reward, is_terminal)

        if is_terminal:
            if self._unflicker:
                # Reset the buffer so frames do not bleed across episodes
                self._last.fill(0)

    def sample(self, size, replace=True):
        """
        Generate a random minibatch. The returned indices can be retrieved using #get_state().
        See the method #minibatch() if you want to retrieve samples directly
        :param size: The minibatch size
        :param replace: Indicate if one index can appear multiple times (True), only once (False)
        :return: Indexes of the sampled states
        """

        if not replace:
            assert (self._count - 1) - self._history_length >= size, \
                'Cannot sample %d from %d elements' % (
                    size, (self._count - 1) - self._history_length)

        # Local variable access are faster in loops
        count, pos, history_len, terminals = self._count - 1, self._pos, \
                                             self._history_length, self._terminals
        indexes = []

        # Rejection-sample indices whose preceding history window is valid
        while len(indexes) < size:
            index = np.random.randint(history_len, count)

            # Check if replace=False to not include same index multiple times
            if replace or index not in indexes:

                # if not wrapping over current pointer,
                # then check if there is terminal state wrapped inside
                if not (index >= pos > index - history_len):
                    if not terminals[(index - history_len):index].any():
                        indexes.append(index)

        assert len(indexes) == size
        return indexes

    def get_state(self, index):
        """
        Return the specified state with the visual memory
        :param index: State's index
        :return: Tensor[history_length, input_shape...]
        """
        index %= self._count
        history_length = self._history_length

        # If index > history_length, take from a slice
        if index >= history_length:
            return self._states[(index - (history_length - 1)):index + 1, ...]
        else:
            # Window starts before 0: wrap around the ring buffer
            indexes = np.arange(index - self._history_length + 1, index + 1)
            return self._states.take(indexes, mode='wrap', axis=0)

    @property
    def unflicker(self):
        """
        Indicate if samples added to the replay memory are preprocessed
        by taking the maximum between current frame and previous one
        :return: True if preprocessed, False otherwise
        """
        return self._unflicker

    @property
    def history_length(self):
        """
        Visual memory length
        (ie. the number of previous frames included for each sample)
        :return: Integer >= 0
        """
        return self._history_length
128 |
129 |
class History(object):
    """
    Rolling window over the N most recent frames, used by the agent at
    evaluation time to build the model input.
    """

    def __init__(self, shape):
        # shape[0] is the number of frames kept; the rest is the frame shape
        self._frames = np.zeros(shape, dtype=np.float32)

    @property
    def value(self):
        """The stacked frames (oldest first, newest last)."""
        return self._frames

    def append(self, state):
        """Push state as the newest frame, shifting the rest one slot older."""
        # Shift in place so external references to the buffer stay valid
        self._frames[:-1] = self._frames[1:]
        self._frames[-1, ...] = state

    def reset(self):
        """Zero out the whole history."""
        self._frames.fill(0)
149 |
150 |
# Snapshot of the (state, action) pair from the previous act() call; used by
# QLearnerAgent.act() to build the transition once the next reward arrives.
Tracker = namedtuple('Tracker', ['state', 'action'])
153 |
154 |
class QLearnerAgent(BaseAgent):
    """
    Deep Q-learning agent: epsilon-greedy action selection over a Q model,
    with experience replay and a target network used to bootstrap targets.
    """

    def __init__(self, name, nb_actions, model, memory, gamma=.99,
                 minibatch_size=32, train_after=50000, train_frequency=4,
                 explorer=None, reward_clipping=None, visualizer=None):
        """
        :param name: Agent name
        :param nb_actions: Number of available actions
        :param model: QModel used for evaluation and training
        :param memory: ReplayMemory storing observed transitions
        :param gamma: Discount factor, strictly between 0 and 1
        :param minibatch_size: Transitions per training update (> 0)
        :param train_after: Number of actions taken before training starts
        :param train_frequency: Train once every N actions afterwards
        :param explorer: BaseExplorer deciding explore vs exploit
                         (default: linear epsilon-greedy 1 -> 0.1 over 1e6 steps)
        :param reward_clipping: (min_reward, max_reward) clip bounds, or None
                                for effectively unbounded clipping
        :param visualizer: Optional visualizer
        """

        assert isinstance(model, QModel), 'model should inherit from QModel'
        assert get_rank(model.input_shape) > 1, 'input_shape rank should be > 1'
        assert isinstance(memory, ReplayMemory), 'memory should inherit from ' \
                                                 'ReplayMemory'
        # NOTE(review): '%d' truncates the float gamma in this message
        assert 0 < gamma < 1, 'gamma should be 0 < gamma < 1 (got: %d)' % gamma
        assert minibatch_size > 0, 'minibatch_size should be > 0 (got: %d)' % minibatch_size
        assert train_after >= 0, 'train_after should be >= 0 (got %d)' % train_after
        assert train_frequency > 0, 'train_frequency should be > 0'

        super(QLearnerAgent, self).__init__(name, nb_actions, visualizer)

        self._model = model
        self._memory = memory
        self._gamma = gamma
        self._minibatch_size = minibatch_size
        self._train_after = train_after
        self._train_frequency = train_frequency
        # Stacks the last frames to form the model input
        self._history = History(model.input_shape)
        self._actions_taken = 0
        # (state, action) of the previous step; None until the first act()
        self._tracker = None

        # Rewards clipping related
        reward_clipping = reward_clipping or (-2 ** 31 - 1, 2 ** 31 - 1)
        assert isinstance(reward_clipping, tuple) and len(reward_clipping) == 2, \
            'clip_reward should be None or (min_reward, max_reward)'

        assert reward_clipping[0] <= reward_clipping[1], \
            'max reward_clipping should be >= min (got %d < %d)' % (
                reward_clipping[1], reward_clipping[0])

        self._reward_clipping = reward_clipping

        # Explorer related
        explorer = explorer or LinearEpsilonGreedyExplorer(1, 0.1, 1e6)
        assert isinstance(explorer, BaseExplorer), \
            'explorer should inherit from BaseExplorer'

        self._explorer = explorer

        # Stats related
        self._stats_rewards = []
        self._stats_mean_qvalues = []
        self._stats_stddev_qvalues = []
        self._stats_loss = []

    def act(self, new_state, reward, done, is_training=False):
        """
        Record the outcome of the previous action (if any), optionally run a
        training step, then select the next action for new_state.
        :param new_state: Latest observation from the environment
        :param reward: Reward received for the previous action
        :param done: True if the previous action ended the episode
        :param is_training: Enables replay-based training updates
        :return: Index of the selected action
        """

        if self._tracker is not None:
            # Complete the previous transition now that its reward is known
            self.observe(self._tracker.state, self._tracker.action,
                         reward, new_state, done)

        if is_training:
            # learn() itself throttles updates to every train_frequency steps
            if self._actions_taken > self._train_after:
                self.learn()

        # Append the new state to the history
        self._history.append(new_state)

        # select the next action
        if self._explorer.is_exploring(self._actions_taken):
            new_action = self._explorer(self._actions_taken, self.nb_actions)
        else:
            # Greedy action over the Q-values of the stacked history
            q_values = self._model.evaluate(self._history.value)
            new_action = q_values.argmax()

            # Q-value stats are only collected on exploitation steps
            self._stats_mean_qvalues.append(q_values.max())
            self._stats_stddev_qvalues.append(np.std(q_values))

        self._tracker = Tracker(new_state, new_action)
        self._actions_taken += 1

        return new_action

    def observe(self, old_state, action, reward, new_state, is_terminal):
        """
        Clip the reward and store the transition in replay memory; resets the
        frame history at episode boundaries.
        """
        if is_terminal:
            self._history.reset()

        # Clamp reward into [min_val, max_val]
        min_val, max_val = self._reward_clipping
        reward = max(min_val, min(max_val, reward))
        self._memory.append(old_state, int(action), reward, is_terminal)

    def learn(self):
        """Run one DQN training step every train_frequency actions."""
        if (self._actions_taken % self._train_frequency) == 0:
            minibatch = self._memory.minibatch(self._minibatch_size)
            q_t_target = self._compute_q(*minibatch)

            # minibatch[0] = pre_states, minibatch[1] = actions
            self._model.train(minibatch[0], q_t_target, minibatch[1])
            self._stats_loss.append(self._model.loss_val)

    def inject_summaries(self, idx):
        """Flush accumulated episode statistics to the visualizer, then reset them."""
        if len(self._stats_mean_qvalues) > 0:
            self.visualize(idx, "%s/episode mean q" % self.name,
                           np.asscalar(np.mean(self._stats_mean_qvalues)))
            self.visualize(idx, "%s/episode mean stddev.q" % self.name,
                           np.asscalar(np.mean(self._stats_stddev_qvalues)))

        if len(self._stats_loss) > 0:
            self.visualize(idx, "%s/episode mean loss" % self.name,
                           np.asscalar(np.mean(self._stats_loss)))

        if len(self._stats_rewards) > 0:
            self.visualize(idx, "%s/episode mean reward" % self.name,
                           np.asscalar(np.mean(self._stats_rewards)))

        # Reset
        self._stats_mean_qvalues = []
        self._stats_stddev_qvalues = []
        self._stats_loss = []
        self._stats_rewards = []

    def _compute_q(self, pres, actions, posts, rewards, terminals):
        """ Compute the Q Values from input states """

        # Bootstrap from the target network; the fancy indexing below selects
        # each row's maximum, i.e. max_a' Q_target(s', a') (standard DQN)
        q_hat = self._model.evaluate(posts, model=QModel.TARGET_NETWORK)
        q_hat_eval = q_hat[np.arange(len(actions)), q_hat.argmax(axis=1)]

        # Terminal transitions contribute only their immediate reward
        q_targets = (1 - terminals) * (self._gamma * q_hat_eval) + rewards
        return np.array(q_targets, dtype=np.float32)
278 |
--------------------------------------------------------------------------------
/malmopy/environment/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2017 Microsoft Corporation.
2 | #
3 | # Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated
4 | # documentation files (the "Software"), to deal in the Software without restriction, including without limitation the
5 | # rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software,
6 | # and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
7 | #
8 | # The above copyright notice and this permission notice shall be included in all copies or substantial portions of
9 | # the Software.
10 | #
11 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO
12 | # THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
13 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
14 | # TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
15 | # SOFTWARE.
16 | # ===================================================================================================================
17 |
18 | from __future__ import absolute_import
19 |
20 | from .environment import BaseEnvironment, VideoCapableEnvironment
21 |
--------------------------------------------------------------------------------
/malmopy/environment/environment.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2017 Microsoft Corporation.
2 | #
3 | # Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated
4 | # documentation files (the "Software"), to deal in the Software without restriction, including without limitation the
5 | # rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software,
6 | # and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
7 | #
8 | # The above copyright notice and this permission notice shall be included in all copies or substantial portions of
9 | # the Software.
10 | #
11 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO
12 | # THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
13 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
14 | # TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
15 | # SOFTWARE.
16 | # ===================================================================================================================
17 |
18 | from __future__ import absolute_import
19 |
20 | import numpy as np
21 |
22 | from ..util import check_rank, get_rank, resize, rgb2gray
23 |
24 |
class StateBuilder(object):
    """
    Base class for objects that map a raw environment state into another
    representation consumed by an agent.

    Subclasses override build(); the malmo package ships several concrete
    builders specific to Malmo environments.
    """

    def build(self, environment):
        """Transform the given environment state; subclasses must override."""
        raise NotImplementedError()

    def __call__(self, *args, **kwargs):
        # Builders are callable shortcuts for build(); note that keyword
        # arguments are accepted but intentionally not forwarded.
        return self.build(*args)
38 |
39 |
class ALEStateBuilder(StateBuilder):
    """
    State builder for Atari (ALE) environments.

    Expects environment states to be numpy arrays; converts RGB frames to a
    single gray channel, resizes them to `shape`, and rescales pixel values
    to [0, 1] as float32.

    NOTE: the `normalize` flag is stored but not consulted by build() here.
    """

    SCALE_FACTOR = 1. / 255.

    def __init__(self, shape=(84, 84), normalize=True):
        self._shape = shape
        self._normalize = bool(normalize)

    def build(self, environment):
        """Convert a raw ALE frame into a normalized float32 array."""
        if not isinstance(environment, np.ndarray):
            raise ValueError(
                'environment type is not a numpy.ndarray (got %s)' % str(
                    type(environment)))

        frame = environment

        # Collapse RGB frames down to a single luminance channel
        if check_rank(environment.shape, 3):
            frame = rgb2gray(environment)
        elif get_rank(frame) > 3:
            raise ValueError('Cannot handle data with more than 3 dimensions')

        # Resample to the target resolution when needed
        if frame.shape != self._shape:
            frame = resize(frame, self._shape)

        return (frame * ALEStateBuilder.SCALE_FACTOR).astype(np.float32)
72 |
73 |
class BaseEnvironment(object):
    """
    Abstract base for interactive environments.

    Keeps per-episode bookkeeping (score, last reward, terminal flag, latest
    raw state); concrete environments implement do() and available_actions.
    """

    def __init__(self):
        self._score = 0.
        self._reward = 0.
        self._done = False
        self._state = None

    def do(self, action):
        """
        Execute the specified action in the environment.

        :param action: The action to be executed
        :return: Tuple (new state, reward, done flag)
        """
        raise NotImplementedError()

    def reset(self):
        """
        Reset all per-episode bookkeeping to its initial values.
        :return:
        """
        self._score = 0.
        self._reward = 0.
        self._done = False
        self._state = None

    @property
    def available_actions(self):
        """
        Number of actions available in this environment.
        :return: Integer > 0
        """
        raise NotImplementedError()

    @property
    def done(self):
        """
        Whether the environment reached a terminal state.
        :return: True if terminal, False otherwise
        """
        return self._done

    @property
    def state(self):
        """
        Latest raw state observed from the environment.
        :return:
        """
        return self._state

    @property
    def reward(self):
        """
        Reward accumulated since the last state.
        :return: Float
        """
        return self._reward

    @property
    def score(self):
        """
        Current score. Commonly the sum of observed rewards, although
        subclasses may redefine what "score" means.
        :return: Number
        """
        return self._score

    @property
    def is_turn_based(self):
        """
        Whether this environment runs a turn-based scenario (agents wait for
        other agents' turns before taking the next action). Turn-based is not
        the default, so this returns False; subclasses override as needed.
        :return: False
        """
        return False
154 |
155 |
class VideoCapableEnvironment(BaseEnvironment):
    """
    Base for environments able to stream their current state.

    Streaming relies on two properties:
      - fps: number of frames this environment can generate each second
      - frame: the latest frame generated by this environment
    A display adapter should poll `frame` roughly every 1/fps seconds; when
    no fresh frame is available, `frame` may return None.
    """

    def __init__(self):
        super(VideoCapableEnvironment, self).__init__()
        self._recording = False

    @property
    def recording(self):
        """
        Whether the environment is currently dispatching its video stream.
        :return: True if streaming, False otherwise
        """
        return self._recording

    @recording.setter
    def recording(self, val):
        """
        Enable (truthy val) or disable (falsy val) video streaming.
        :param val: Coerced to bool before storing
        :return:
        """
        self._recording = bool(val)

    @property
    def frame(self):
        """
        Most recent frame from the environment as a PIL Image; concrete
        subclasses must override.
        """
        raise NotImplementedError()
194 |
--------------------------------------------------------------------------------
/malmopy/environment/gym/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2017 Microsoft Corporation.
2 | #
3 | # Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated
4 | # documentation files (the "Software"), to deal in the Software without restriction, including without limitation the
5 | # rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software,
6 | # and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
7 | #
8 | # The above copyright notice and this permission notice shall be included in all copies or substantial portions of
9 | # the Software.
10 | #
11 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO
12 | # THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
13 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
14 | # TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
15 | # SOFTWARE.
16 | # ===================================================================================================================
17 |
18 | from __future__ import absolute_import
19 |
20 | from .gym import GymEnvironment
21 |
--------------------------------------------------------------------------------
/malmopy/environment/gym/gym.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2017 Microsoft Corporation.
2 | #
3 | # Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated
4 | # documentation files (the "Software"), to deal in the Software without restriction, including without limitation the
5 | # rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software,
6 | # and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
7 | #
8 | # The above copyright notice and this permission notice shall be included in all copies or substantial portions of
9 | # the Software.
10 | #
11 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO
12 | # THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
13 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
14 | # TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
15 | # SOFTWARE.
16 | # ===================================================================================================================
17 |
18 | from __future__ import absolute_import
19 |
20 | import gym
21 | import numpy as np
22 | import six
23 | from PIL import Image
24 | from gym.wrappers import Monitor
25 |
26 | from ..environment import VideoCapableEnvironment, StateBuilder, ALEStateBuilder
27 |
28 |
def need_record(episode_id):
    """Record a video only for every 1000th episode (including episode 0)."""
    return not episode_id % 1000
31 |
32 |
class GymEnvironment(VideoCapableEnvironment):
    """
    Wraps an OpenAI Gym environment behind the malmopy
    BaseEnvironment/VideoCapableEnvironment interface.
    """

    def __init__(self, env_name, state_builder=ALEStateBuilder(), repeat_action=4, no_op=30, monitoring_path=None):
        """
        :param env_name: Gym environment id passed to gym.make()
        :param state_builder: StateBuilder converting raw observations into
            agent states. NOTE(review): the default instance is created once
            at import time and shared by every GymEnvironment using it.
        :param repeat_action: Frameskip; a fixed int >= 1 or a (min, max) tuple
        :param no_op: Upper bound on the random initial no-op steps at reset
        :param monitoring_path: When set, wraps the env in a gym Monitor
            recording every 1000th episode (see need_record)
        """
        assert isinstance(state_builder, StateBuilder), 'state_builder should inherit from StateBuilder'
        assert isinstance(repeat_action, (int, tuple)), 'repeat_action should be int or tuple'
        if isinstance(repeat_action, int):
            assert repeat_action >= 1, "repeat_action should be >= 1"
        elif isinstance(repeat_action, tuple):
            assert len(repeat_action) == 2, 'repeat_action should be a length-2 tuple: (min frameskip, max frameskip)'
            assert repeat_action[0] < repeat_action[1], 'repeat_action[0] should be < repeat_action[1]'

        super(GymEnvironment, self).__init__()

        self._state_builder = state_builder
        self._env = gym.make(env_name)
        # Reaches into the unwrapped env; assumes an ALE-style env exposing
        # `frameskip` — TODO confirm for non-Atari environments
        self._env.env.frameskip = repeat_action
        self._no_op = max(0, no_op)
        # Start "done" so callers must reset() before stepping
        self._done = True

        if monitoring_path is not None:
            self._env = Monitor(self._env, monitoring_path, video_callable=need_record)

    @property
    def available_actions(self):
        # Size of the discrete action space of the wrapped env
        return self._env.action_space.n

    @property
    def state(self):
        # Raw observation run through the state builder (None before first reset)
        return None if self._state is None else self._state_builder(self._state)

    @property
    def lives(self):
        # ALE-specific: remaining lives; assumes an Atari env underneath
        return self._env.env.ale.lives()

    @property
    def frame(self):
        # Latest raw observation as a PIL image (for video streaming)
        return Image.fromarray(self._state)

    def do(self, action):
        """
        Step the wrapped env once, accumulating reward into the score.

        :param action: Action index forwarded to env.step()
        :return: (built state, reward, done)
        """
        self._state, self._reward, self._done, _ = self._env.step(action)
        self._score += self._reward
        return self.state, self._reward, self._done

    def reset(self):
        """
        Reset bookkeeping and the wrapped env, then take a random number of
        no-op steps to de-correlate initial states across episodes.

        :return: Built state after the no-op steps
        """
        super(GymEnvironment, self).reset()

        self._state = self._env.reset()

        # Random number of initial no-op to introduce stochasticity
        # NOTE(review): np.random.randint(1, self._no_op) raises ValueError
        # when self._no_op == 1 (low must be < high) — confirm no_op is
        # always 0 or > 1.
        if self._no_op > 0:
            for _ in six.moves.range(np.random.randint(1, self._no_op)):
                self._state, _, _, _ = self._env.step(0)

        return self.state
90 |
--------------------------------------------------------------------------------
/malmopy/environment/malmo/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2017 Microsoft Corporation.
2 | #
3 | # Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated
4 | # documentation files (the "Software"), to deal in the Software without restriction, including without limitation the
5 | # rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software,
6 | # and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
7 | #
8 | # The above copyright notice and this permission notice shall be included in all copies or substantial portions of
9 | # the Software.
10 | #
11 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO
12 | # THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
13 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
14 | # TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
15 | # SOFTWARE.
16 | # ===================================================================================================================
17 |
18 | from __future__ import absolute_import
19 |
20 | from .malmo import MalmoEnvironment, allocate_remotes
21 | from .malmo import MalmoStateBuilder, MalmoRGBStateBuilder, MalmoALEStateBuilder
22 |
--------------------------------------------------------------------------------
/malmopy/model/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2017 Microsoft Corporation.
2 | #
3 | # Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated
4 | # documentation files (the "Software"), to deal in the Software without restriction, including without limitation the
5 | # rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software,
6 | # and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
7 | #
8 | # The above copyright notice and this permission notice shall be included in all copies or substantial portions of
9 | # the Software.
10 | #
11 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO
12 | # THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
13 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
14 | # TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
15 | # SOFTWARE.
16 | # ===================================================================================================================
17 |
18 | from __future__ import absolute_import
19 |
20 | from .model import *
21 |
22 |
--------------------------------------------------------------------------------
/malmopy/model/chainer/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2017 Microsoft Corporation.
2 | #
3 | # Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated
4 | # documentation files (the "Software"), to deal in the Software without restriction, including without limitation the
5 | # rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software,
6 | # and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
7 | #
8 | # The above copyright notice and this permission notice shall be included in all copies or substantial portions of
9 | # the Software.
10 | #
11 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO
12 | # THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
13 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
14 | # TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
15 | # SOFTWARE.
16 | # ===================================================================================================================
17 |
18 | from __future__ import absolute_import
19 |
20 | from .qlearning import *
21 |
--------------------------------------------------------------------------------
/malmopy/model/chainer/qlearning.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2017 Microsoft Corporation.
2 | #
3 | # Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated
4 | # documentation files (the "Software"), to deal in the Software without restriction, including without limitation the
5 | # rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software,
6 | # and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
7 | #
8 | # The above copyright notice and this permission notice shall be included in all copies or substantial portions of
9 | # the Software.
10 | #
11 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO
12 | # THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
13 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
14 | # TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
15 | # SOFTWARE.
16 | # ===================================================================================================================
17 |
18 | from __future__ import absolute_import
19 |
20 | import chainer.cuda as cuda
21 | import chainer.functions as F
22 | import chainer.links as L
23 | import numpy as np
24 | from chainer import ChainList
25 | from chainer.initializers import HeUniform
26 | from chainer.optimizers import Adam
27 | from chainer.serializers import save_npz, load_npz
28 |
29 | from ..model import QModel
30 | from ...util import check_rank, get_rank
31 |
32 |
class ChainerModel(ChainList):
    """
    Thin wrapper around a Chainer ChainList that enforces a callable model.

    Subclasses implement _build_model() (returning the sequence of links)
    and __call__() (the forward pass).
    """

    def __init__(self, input_shape, output_shape):
        # Record the shapes first so _build_model() can consult them
        self.input_shape = input_shape
        self.output_shape = output_shape

        super(ChainerModel, self).__init__(*self._build_model())

    def __call__(self, *args, **kwargs):
        """Forward pass; must be provided by subclasses."""
        raise NotImplementedError()

    def _build_model(self):
        """Return the links composing the network; must be overridden."""
        raise NotImplementedError()
50 |
51 |
class MLPChain(ChainerModel):
    """
    Multi-layer perceptron.

    hidden_layer_sizes gives the width of each hidden layer; e.g. for 128
    units on the first hidden layer, 256 on the second and 512 on the third:

    >>> MLPChain(input_shape=(28, 28), output_shape=10, hidden_layer_sizes=(128, 256, 512))

    A final linear layer of size output_shape is always appended, so the
    network holds len(hidden_layer_sizes) + 1 links in total.
    """

    def __init__(self, in_shape, output_shape,
                 hidden_layer_sizes=(512, 512, 512), activation=F.relu):
        self._activation = activation
        self._hidden_layer_sizes = hidden_layer_sizes

        super(MLPChain, self).__init__(in_shape, output_shape)

    @property
    def hidden_layer_sizes(self):
        return self._hidden_layer_sizes

    def __call__(self, x):
        # Activation after every hidden layer; the last link stays linear
        act = self._activation
        last = len(self) - 1
        for i in range(last):
            x = act(self[i](x))
        return self[last](x)

    def _build_model(self):
        # Lazily-sized (None) linear links, one per requested hidden width
        layers = [L.Linear(None, width) for width in self._hidden_layer_sizes]
        layers.append(L.Linear(None, self.output_shape))
        return layers
88 |
89 |
class ReducedDQNChain(ChainerModel):
    """
    Simplified DQN topology:

        Convolution(64, kernel=(4, 4), strides=(2, 2))
        Convolution(64, kernel=(3, 3), strides=(1, 1))
        Dense(512)
        Dense(output_shape)
    """

    def __init__(self, in_shape, output_shape):
        super(ReducedDQNChain, self).__init__(in_shape, output_shape)

    def __call__(self, x):
        # ReLU between layers; the last layer stays linear (Q-values)
        h = x
        for layer in self[:-1]:
            h = F.relu(layer(h))
        return self[-1](h)

    def _build_model(self):
        he = HeUniform()
        in_channels = self.input_shape[0]

        return [
            L.Convolution2D(in_channels, 64, ksize=4, stride=2, initialW=he),
            L.Convolution2D(64, 64, ksize=3, stride=1, initialW=he),
            L.Linear(None, 512, initialW=HeUniform(0.1)),
            L.Linear(512, self.output_shape, initialW=HeUniform(0.1)),
        ]
117 |
118 |
class DQNChain(ChainerModel):
    """
    DQN topology as in (Mnih et al. 2015): "Human-level control through deep
    reinforcement learning", Nature 518.7540 (2015): 529-533.

        Convolution(32, kernel=(8, 8), strides=(4, 4))
        Convolution(64, kernel=(4, 4), strides=(2, 2))
        Convolution(64, kernel=(3, 3), strides=(1, 1))
        Dense(512)
        Dense(output_shape)
    """

    def __init__(self, in_shape, output_shape):
        super(DQNChain, self).__init__(in_shape, output_shape)

    def __call__(self, x):
        # ReLU after every hidden layer; linear output head
        h = x
        for layer in self[:-1]:
            h = F.relu(layer(h))
        return self[-1](h)

    def _build_model(self):
        he = HeUniform()
        in_channels = self.input_shape[0]

        return [
            L.Convolution2D(in_channels, 32, ksize=8, stride=4, initialW=he),
            L.Convolution2D(32, 64, ksize=4, stride=2, initialW=he),
            L.Convolution2D(64, 64, ksize=3, stride=1, initialW=he),
            # 7 * 7 * 64 is the flattened conv output for 84x84 inputs
            L.Linear(7 * 7 * 64, 512, initialW=HeUniform(0.01)),
            L.Linear(512, self.output_shape, initialW=HeUniform(0.01)),
        ]
152 |
153 |
class QNeuralNetwork(QModel):
    """
    Chainer-backed Q network with a separate target network.

    The online network is trained with Adam on a Huber loss; the target
    network is a frozen copy, re-synced every `update_interval` train steps.
    """

    def __init__(self, model, target, device_id=-1,
                 learning_rate=0.00025, momentum=.9,
                 minibatch_size=32, update_interval=10000):
        """
        :param model: ChainerModel used as the online action-value network
        :param target: Network of the same topology, kept frozen between syncs
        :param device_id: GPU device id; < 0 keeps everything on the CPU
        :param learning_rate: Adam alpha
        :param momentum: Adam beta1 (beta2 is fixed at 0.999)
        :param minibatch_size: Size of the training minibatches
        :param update_interval: Train steps between target-network syncs
        """

        assert isinstance(model, ChainerModel), \
            'model should inherit from ChainerModel'

        super(QNeuralNetwork, self).__init__(model.input_shape,
                                             model.output_shape)

        self._gpu_device = None
        self._loss_val = 0

        # Target model update method
        self._steps = 0
        self._target_update_interval = update_interval

        # Setup model and target network (start perfectly in sync)
        self._minibatch_size = minibatch_size
        self._model = model
        self._target = target
        self._target.copyparams(self._model)

        # If GPU move to GPU memory
        if device_id >= 0:
            with cuda.get_device(device_id) as device:
                self._gpu_device = device
                self._model.to_gpu(device)
                self._target.to_gpu(device)

        # Setup optimizer (positional: alpha, beta1, beta2)
        self._optimizer = Adam(learning_rate, momentum, 0.999)
        self._optimizer.setup(self._model)

    def evaluate(self, environment, model=QModel.ACTION_VALUE_NETWORK):
        """
        Forward a state (or batch of states) through the requested network.

        :param environment: Input array; a single sample gains a batch axis
        :param model: QModel.ACTION_VALUE_NETWORK for the online network,
            otherwise the target network is used
        :return: Q values as a CPU numpy array
        """
        # Single sample -> add a leading batch dimension
        if check_rank(environment.shape, get_rank(self._input_shape)):
            environment = environment.reshape((1,) + environment.shape)

        # Move data if necessary
        if self._gpu_device is not None:
            environment = cuda.to_gpu(environment, self._gpu_device)

        if model == QModel.ACTION_VALUE_NETWORK:
            output = self._model(environment)
        else:
            output = self._target(environment)

        return cuda.to_cpu(output.data)

    def train(self, x, y, actions=None):
        """
        One gradient step on the Huber loss between Q(x, actions) and y.

        :param x: Batch of input states
        :param y: Batch of target Q values
        :param actions: Actions taken in each sample (selects the Q column)
        """
        actions = actions.astype(np.int32)
        batch_size = len(actions)

        # Explicit None check, consistent with evaluate() (the original
        # truth-tested the device object itself)
        if self._gpu_device is not None:
            x = cuda.to_gpu(x, self._gpu_device)
            y = cuda.to_gpu(y, self._gpu_device)
            actions = cuda.to_gpu(actions, self._gpu_device)

        q = self._model(x)
        # Keep only the Q value of the action actually taken per sample
        q_subset = F.reshape(F.select_item(q, actions), (batch_size, 1))
        y = y.reshape(batch_size, 1)

        loss = F.sum(F.huber_loss(q_subset, y, 1.0))

        self._model.cleargrads()
        loss.backward()
        self._optimizer.update()

        # float() replaces np.asscalar, removed in NumPy 1.23
        self._loss_val = float(cuda.to_cpu(loss.data))

        # Keeps track of the number of train() calls
        self._steps += 1
        if self._steps % self._target_update_interval == 0:
            # copy weights
            self._target.copyparams(self._model)

    @property
    def loss_val(self):
        """Loss of the most recent train() call (summed over the minibatch)."""
        return self._loss_val  # / self._minibatch_size

    def save(self, output_file):
        """Serialize the online network weights to `output_file` (npz)."""
        save_npz(output_file, self._model)

    def load(self, input_file):
        """Load online network weights and re-sync the target network."""
        load_npz(input_file, self._model)

        # Copy parameter from model to target
        self._target.copyparams(self._model)
243 |
--------------------------------------------------------------------------------
/malmopy/model/cntk/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2017 Microsoft Corporation.
2 | #
3 | # Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated
4 | # documentation files (the "Software"), to deal in the Software without restriction, including without limitation the
5 | # rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software,
6 | # and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
7 | #
8 | # The above copyright notice and this permission notice shall be included in all copies or substantial portions of
9 | # the Software.
10 | #
11 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO
12 | # THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
13 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
14 | # TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
15 | # SOFTWARE.
16 | # ===================================================================================================================
17 |
18 | from __future__ import absolute_import
19 |
20 | from .base import *
21 | from .qlearning import *
22 |
--------------------------------------------------------------------------------
/malmopy/model/cntk/base.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2017 Microsoft Corporation.
2 | #
3 | # Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated
4 | # documentation files (the "Software"), to deal in the Software without restriction, including without limitation the
5 | # rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software,
6 | # and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
7 | #
8 | # The above copyright notice and this permission notice shall be included in all copies or substantial portions of
9 | # the Software.
10 | #
11 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO
12 | # THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
13 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
14 | # TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
15 | # SOFTWARE.
16 | # ===================================================================================================================
17 |
18 | import numpy as np
19 | from cntk.device import cpu, gpu, try_set_default_device
20 | from cntk.train.distributed import Communicator
21 | from cntk.learners import set_default_unit_gain_value
22 | from cntk.ops import abs, element_select, less, square, sqrt, reduce_sum, reduce_mean
23 |
24 | from ...visualization import Visualizable
25 |
26 |
def rmse(y, y_hat, axis=0):
    """
    Root Mean Squared error node for the model graph.

    :param y: CNTK Variable holding the true value of Y
    :param y_hat: CNTK variable holding the estimated value of Y
    :param axis: The axis over which to compute the mean, 0 by default
    :return: Node computing sqrt(mean((y_hat - y)^2))
    """
    squared_error = square(y_hat - y)
    return sqrt(reduce_mean(squared_error, axis=axis))
37 |
38 |
def huber_loss(y_hat, y, delta):
    """
    Build the Huber Loss node of the model graph.

    The Huber loss is quadratic for residuals smaller than delta and linear
    beyond, making it more robust to outliers:
        if |y - y_hat| < delta :
            0.5 * (y - y_hat)**2
        else :
            delta * |y - y_hat| - 0.5 * delta**2

    :param y_hat: Estimated value
    :param y: Target value
    :param delta: Outliers threshold
    :return: CNTK node computing the summed per-sample loss, named 'loss'
    """
    residual = y - y_hat
    abs_residual = abs(residual)

    quadratic_branch = 0.5 * square(residual)
    linear_branch = (delta * abs_residual) - (0.5 * delta * delta)

    per_sample = element_select(less(abs_residual, delta),
                                quadratic_branch, linear_branch)

    return reduce_sum(per_sample, name='loss')
64 |
65 |
def as_learning_rate_by_sample(learning_rate_per_minibatch, minibatch_size, momentum=0, momentum_as_unit_gain=False):
    """
    Compute the scale parameter for the learning rate to match the learning rate
    definition used in other deep learning frameworks.
    In CNTK, gradients are calculated as follows:
        g(t + 1) = momentum * v(t) + (1-momentum) * gradient(t)

    Whereas in other frameworks they are computed this way :
        g(t + 1) = momentum * v(t)

    According to the above equations we need to scale the learning rate with regard to the momentum by a
    factor of 1/(1 - momentum)
    :param learning_rate_per_minibatch: The current learning rate, must be > 0
    :param minibatch_size: Size of the minibatch, must be >= 1
    :param momentum: The current momentum (0 by default, used only when momentum_as_unit_gain is True)
    :param momentum_as_unit_gain: Indicate whether the momentum is a unit gain factor (CNTK) or not (TensorFlow, etc.)
    :return: Scaled learning rate according to momentum and minibatch size
    """
    # Fixed assert messages: the previous text ("cannot be < 0") did not match
    # the actual conditions (strictly positive values are required).
    assert learning_rate_per_minibatch > 0, "learning_rate_per_minibatch should be > 0"
    assert minibatch_size > 0, "minibatch_size should be >= 1"
    # momentum == 1 would divide by zero in the unit-gain compensation below.
    assert 0 <= momentum < 1, "momentum should be 0 <= momentum < 1"

    learning_rate_per_sample = learning_rate_per_minibatch / minibatch_size

    if momentum_as_unit_gain:
        # CNTK's unit-gain filter scales gradient contributions by (1 - momentum);
        # compensate so the effective step size matches other frameworks.
        learning_rate_per_sample /= (1. - momentum)

    return learning_rate_per_sample
93 |
94 |
def as_momentum_as_time_constant(momentum, minibatch_size):
    """ Convert a per-minibatch momentum value to CNTK's
    "momentum as time constant" representation: the number of samples after
    which a gradient's contribution has decayed to 1/e.

        momentum_as_time_constant = -minibatch_size / log(momentum)

    :param momentum: Per-minibatch momentum, must satisfy 0 < momentum < 1
    :param minibatch_size: Number of samples per minibatch, must be >= 1
    :return: Time constant, rounded up to the next integer (as a float)
    """
    # Guard against silent inf/nan: log(momentum) is 0 at momentum == 1 and
    # undefined for momentum <= 0.
    assert 0 < momentum < 1, 'momentum should be 0 < momentum < 1'
    assert minibatch_size > 0, 'minibatch_size should be >= 1'

    return np.ceil(-minibatch_size / (np.log(momentum)))
102 |
103 |
def prepend_batch_seq_axis(tensor):
    """
    CNTK uses 2 dynamic axes (batch, sequence, input_shape...).
    To have a single sample with length 1 you need to pass (1, 1, input_shape...)
    This method adds a batch axis and a sequence axis of size 1 in front
    of the tensor's existing shape.
    :param tensor: The tensor to be reshaped
    :return: Reshaped tensor with batch and sequence axis = 1
    """
    # np.newaxis indexing yields the same (1, 1, ...) view as a reshape.
    return tensor[np.newaxis, np.newaxis, ...]
113 |
114 |
def prepend_batch_axis(tensor):
    """
    CNTK uses 2 dynamic axes (batch, sequence, input_shape...).
    If you define variables with dynamic_axes=[Axis.default_batch_axis()]
    you can get rid of the sequence axis.

    To have a single sample you then need to pass (1, input_shape...).
    This method adds a batch axis of size 1 in front of the tensor's
    existing shape.
    :param tensor: The tensor to be reshaped
    :return: Reshaped tensor with a leading batch axis = 1
    """
    # np.newaxis indexing yields the same (1, ...) view as a reshape.
    return tensor[np.newaxis, ...]
126 |
127 |
class CntkModel(Visualizable):
    """ Base class for CNTK based neural networks.

    It handles the management of the CPU/GPU device and provides commodity methods for exporting the model
    """

    def __init__(self, device_id=None, unit_gain=False, n_workers=1, visualizer=None):
        """
        Abstract constructor of CNTK model.
        This constructor wraps CNTK initialization and tuning
        :param device_id: Use None if you want CNTK to use the best available device, -1 for CPU, >= 0 for GPU
        :param unit_gain: Whether momentum filters should act as unit-gain filters (CNTK-specific behaviour)
        :param n_workers: Number of concurrent workers for distributed training. Keep set to 1 for non distributed mode
        :param visualizer: Optional visualizer allowing model to save summary data
        """
        assert n_workers >= 1, 'n_workers should be at least 1 (not distributed) or > 1 if distributed'

        Visualizable.__init__(self, visualizer)

        self._model = None      # CNTK function representing the network (set by subclasses)
        self._learner = None    # CNTK learner used during training (set by subclasses)
        self._loss = None       # Loss node (set by subclasses)
        self._distributed = n_workers > 1

        # device_id None means "let CNTK pick the best available device"
        if isinstance(device_id, int):
            try_set_default_device(cpu() if device_id == -1 else gpu(device_id))

        set_default_unit_gain_value(unit_gain)

    def _build_model(self):
        """ Build and return the CNTK network graph. Subclasses must override. """
        raise NotImplementedError()

    @property
    def loss_val(self):
        """ Loss value of the latest training step. Subclasses must override. """
        raise NotImplementedError()

    @property
    def model(self):
        """ The underlying CNTK model (None until built by a subclass). """
        return self._model

    @property
    def distributed_training(self):
        """ True when the model was created with more than one worker. """
        return self._distributed

    @property
    def distributed_rank(self):
        """ Rank of the current worker.

        Returns the distributed learner's communicator rank when available,
        otherwise 0. Bug fix: non-distributed mode previously fell through
        and returned None implicitly, breaking rank == 0 checks; a single
        worker is always rank 0.
        """
        if self._distributed:
            if self._learner and hasattr(self._learner, 'communicator'):
                return self._learner.communicator().rank()
            return 0
        return 0

    def load(self, input_file):
        """ Restore the model parameters from the specified file.

        :param input_file: Path to a checkpoint previously written by save()
        :raises ValueError: if the model has not been built yet
        """
        if self._model is None:
            raise ValueError("cannot load to a model that equals None")

        self._model.restore(input_file)

    def save(self, output_file):
        """ Persist the model parameters to the specified file.

        :param output_file: Destination path for the checkpoint
        :raises ValueError: if the model has not been built yet
        """
        if self._model is None:
            raise ValueError("cannot save a model that equals None")

        self._model.save(output_file)

    def finalize(self):
        """ Shut down the distributed communicator (no-op when not distributed). """
        if self._distributed:
            Communicator.finalize()
194 |
--------------------------------------------------------------------------------
/malmopy/model/cntk/qlearning.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2017 Microsoft Corporation.
2 | #
3 | # Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated
4 | # documentation files (the "Software"), to deal in the Software without restriction, including without limitation the
5 | # rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software,
6 | # and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
7 | #
8 | # The above copyright notice and this permission notice shall be included in all copies or substantial portions of
9 | # the Software.
10 | #
11 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO
12 | # THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
13 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
14 | # TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
15 | # SOFTWARE.
16 | # ===================================================================================================================
17 |
18 | from __future__ import absolute_import
19 |
20 | from cntk import Value
21 | from cntk.axis import Axis
22 | from cntk.initializer import he_uniform, he_normal
23 | from cntk.layers import Convolution, Dense, default_options
24 | from cntk.layers.higher_order_layers import Sequential
25 | from cntk.learners import adam, momentum_schedule, learning_rate_schedule, \
26 | UnitType
27 | from cntk.ops import input, relu, reduce_sum
28 | from cntk.ops.functions import CloneMethod
29 | from cntk.train.trainer import Trainer
30 |
31 | from . import CntkModel, prepend_batch_axis, huber_loss
32 | from ..model import QModel
33 | from ...util import check_rank
34 |
35 |
class QNeuralNetwork(CntkModel, QModel):
    """
    Deep Q-Network (DQN) implemented with CNTK: an online network trained
    with Huber loss plus a periodically-synchronized frozen target network.
    """

    def __init__(self, in_shape, output_shape, device_id=None,
                 learning_rate=0.00025, momentum=0.9,
                 minibatch_size=32, update_interval=10000,
                 n_workers=1, visualizer=None):

        """
        Q Neural Network following Mnih and al. implementation and default options.

        The network has the following topology:
        Convolution(32, (8, 8))
        Convolution(64, (4, 4))
        Convolution(64, (3, 3))
        Dense(512)

        :param in_shape: Shape of the observations perceived by the learner (the neural net input)
        :param output_shape: Size of the action space (mapped to the number of output neurons)

        :param device_id: Use None to let CNTK select the best available device,
                          -1 for CPU, >= 0 for GPU
                          (default: None)

        :param learning_rate: Learning rate
                              (default: 0.00025, as per Mnih et al.)

        :param momentum: Momentum, provided as momentum value for
                         averaging gradients without unit gain filter
                         Note that CNTK does not currently provide an implementation
                         of Graves' RmsProp with momentum.
                         It uses AdamSGD optimizer instead.
                         (default: 0, no momentum with RProp optimizer)

        :param minibatch_size: Minibatch size
                               (default: 32, as per Mnih et al.)

        :param update_interval: Number of train() calls between two
                                synchronizations of the frozen target network
                                (default: 10000, as per Mnih et al.)

        :param n_workers: Number of concurrent worker for distributed training.
                          (default: 1, not distributed)

        :param visualizer: Optional visualizer allowing the model to save summary data
                           (default: None, no visualization)

        Ref: Mnih et al.: "Human-level control through deep reinforcement learning."
        Nature 518.7540 (2015): 529-533.
        """

        assert learning_rate > 0, 'learning_rate should be > 0'
        assert 0. <= momentum < 1, 'momentum should be 0 <= momentum < 1'

        QModel.__init__(self, in_shape, output_shape)
        CntkModel.__init__(self, device_id, False, n_workers, visualizer)

        self._nb_actions = output_shape
        self._steps = 0
        self._target_update_interval = update_interval
        self._target = None

        # Input variables (batch axis only — no sequence axis)
        self._environment = input(in_shape, name='env',
                                  dynamic_axes=(Axis.default_batch_axis()))
        self._q_targets = input(1, name='q_targets',
                                dynamic_axes=(Axis.default_batch_axis()))
        self._actions = input(output_shape, name='actions',
                              dynamic_axes=(Axis.default_batch_axis()))

        # Define the neural network graph, then clone a frozen copy
        # to serve as the target network
        self._model = self._build_model()(self._environment)
        self._target = self._model.clone(
            CloneMethod.freeze, {self._environment: self._environment}
        )

        # Define the learning rate
        lr_schedule = learning_rate_schedule(learning_rate, UnitType.minibatch)

        # AdamSGD optimizer
        m_schedule = momentum_schedule(momentum)
        vm_schedule = momentum_schedule(0.999)
        l_sgd = adam(self._model.parameters, lr_schedule,
                     momentum=m_schedule,
                     unit_gain=True,
                     variance_momentum=vm_schedule)

        if self.distributed_training:
            raise NotImplementedError('ASGD not implemented yet.')

        # _actions is a sparse 1-hot encoding of the actions done by the agent;
        # the product selects the Q-value of the action actually taken
        q_acted = reduce_sum(self._model * self._actions, axis=0)

        # Define the trainer with Huber Loss function
        criterion = huber_loss(q_acted, self._q_targets, 1.0)

        self._learner = l_sgd
        self._trainer = Trainer(self._model, (criterion, None), l_sgd)

    @property
    def loss_val(self):
        """ Average loss over the last minibatch processed by the trainer. """
        return self._trainer.previous_minibatch_loss_average

    def _build_model(self):
        """ Build the Mnih et al. DQN topology (3 conv layers + 2 dense). """
        with default_options(init=he_uniform(), activation=relu, bias=True):
            model = Sequential([
                Convolution((8, 8), 32, strides=(4, 4)),
                Convolution((4, 4), 64, strides=(2, 2)),
                Convolution((3, 3), 64, strides=(1, 1)),
                Dense(512, init=he_normal(0.01)),
                Dense(self._nb_actions, activation=None, init=he_normal(0.01))
            ])
        return model

    def train(self, x, q_value_targets, actions=None):
        """ Run one training step on a minibatch.

        :param x: Batch of observations
        :param q_value_targets: Target Q-values for the taken actions
        :param actions: Indices of the actions taken (required)
        """
        assert actions is not None, 'actions cannot be None'

        # We need to add an extra dimension to shape [N] => [N, 1]
        if check_rank(q_value_targets.shape, 1):
            q_value_targets = q_value_targets.reshape((-1, 1))

        # Add extra dimensions to match shape [N, 1] required by one_hot
        if check_rank(actions.shape, 1):
            actions = actions.reshape((-1, 1))

        # We need a batch axis when a single (un-batched) sample is passed
        if check_rank(x.shape, len(self._environment.shape)):
            x = prepend_batch_axis(x)

        self._trainer.train_minibatch({
            self._environment: x,
            self._actions: Value.one_hot(actions, self._nb_actions),
            self._q_targets: q_value_targets
        })

        # Count the number of train calls to know when to sync the target net
        self._steps += 1

        # Periodically refresh the frozen target network from the online model
        if (self._steps % self._target_update_interval) == 0:
            self._target = self._model.clone(
                CloneMethod.freeze, {self._environment: self._environment}
            )

    def evaluate(self, data, model=QModel.ACTION_VALUE_NETWORK):
        """ Evaluate either the online or the frozen target network on data. """
        # If evaluating a single sample, expand the minibatch axis
        # (minibatch = 1, input_shape...)
        if len(data.shape) == len(self.input_shape):
            data = prepend_batch_axis(data)  # Prepend minibatch dim

        if model == QModel.TARGET_NETWORK:
            predictions = self._target.eval({self._environment: data})
        else:
            predictions = self._model.eval({self._environment: data})
        return predictions.squeeze()
189 |
190 |
class ReducedQNeuralNetwork(QNeuralNetwork):
    """
    Represents a learning capable entity using CNTK, reduced model
    (smaller topology than QNeuralNetwork: 2 conv layers instead of 3)
    """

    def __init__(self, in_shape, output_shape, device_id=None,
                 learning_rate=0.00025, momentum=0.9,
                 minibatch_size=32, update_interval=10000,
                 n_workers=1, visualizer=None):
        """ See QNeuralNetwork for parameter semantics; this variant only
        swaps in the reduced topology built by _build_model below. """
        QNeuralNetwork.__init__(self, in_shape, output_shape, device_id,
                                learning_rate, momentum, minibatch_size, update_interval,
                                n_workers, visualizer)

    def _build_model(self):
        """ Build the reduced topology: 2 conv layers + 2 dense layers. """
        with default_options(init=he_uniform(), activation=relu, bias=True):
            model = Sequential([
                Convolution((4, 4), 64, strides=(2, 2), name='conv1'),
                Convolution((3, 3), 64, strides=(1, 1), name='conv2'),
                Dense(512, name='dense1', init=he_normal(0.01)),
                Dense(self._nb_actions, activation=None, init=he_normal(0.01), name='qvalues')
            ])
        return model
--------------------------------------------------------------------------------
/malmopy/model/model.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2017 Microsoft Corporation.
2 | #
3 | # Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated
4 | # documentation files (the "Software"), to deal in the Software without restriction, including without limitation the
5 | # rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software,
6 | # and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
7 | #
8 | # The above copyright notice and this permission notice shall be included in all copies or substantial portions of
9 | # the Software.
10 | #
11 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO
12 | # THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
13 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
14 | # TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
15 | # SOFTWARE.
16 | # ===================================================================================================================
17 |
18 | from __future__ import absolute_import
19 |
20 |
class BaseModel(object):
    """Represents a learning capable entity.

    Concrete subclasses implement the training/evaluation/persistence
    methods; this base class only stores the input and output shapes.
    """

    def __init__(self, in_shape, output_shape):
        # Shapes are exposed read-only through the properties below
        self._input_shape = in_shape
        self._output_shape = output_shape

    @property
    def input_shape(self):
        """Shape of the observations fed to the model."""
        return self._input_shape

    @property
    def output_shape(self):
        """Shape of the model's output."""
        return self._output_shape

    @property
    def loss_val(self):
        """Loss value of the latest training step (abstract)."""
        raise NotImplementedError()

    def evaluate(self, environment):
        """Run a forward pass on the given input (abstract)."""
        raise NotImplementedError()

    def train(self, x, y):
        """Run one training step (abstract)."""
        raise NotImplementedError()

    def load(self, input_file):
        """Restore model parameters from a file (abstract)."""
        raise NotImplementedError()

    def save(self, output_file):
        """Persist model parameters to a file (abstract)."""
        raise NotImplementedError()
51 |
52 |
class QModel(BaseModel):
    """Base class for Q-function approximators.

    Exposes two selectable networks: the online action-value network and
    the (periodically synchronized) target network.
    """

    # Bit flags selecting which network evaluate() should query
    ACTION_VALUE_NETWORK = 1 << 0
    TARGET_NETWORK = 1 << 1

    def evaluate(self, environment, model=ACTION_VALUE_NETWORK):
        """Evaluate the selected network on the given input (abstract)."""
        raise NotImplementedError()

    def train(self, x, y, actions=None):
        """Run one training step for the taken actions (abstract)."""
        raise NotImplementedError()
62 |
63 |
64 |
--------------------------------------------------------------------------------
/malmopy/util/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2017 Microsoft Corporation.
2 | #
3 | # Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated
4 | # documentation files (the "Software"), to deal in the Software without restriction, including without limitation the
5 | # rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software,
6 | # and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
7 | #
8 | # The above copyright notice and this permission notice shall be included in all copies or substantial portions of
9 | # the Software.
10 | #
11 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO
12 | # THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
13 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
14 | # TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
15 | # SOFTWARE.
16 | # ===================================================================================================================
17 |
18 | from __future__ import absolute_import
19 |
20 | from .images import resize, rgb2gray
21 | from .util import *
22 |
23 |
24 |
--------------------------------------------------------------------------------
/malmopy/util/images.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2017 Microsoft Corporation.
2 | #
3 | # Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated
4 | # documentation files (the "Software"), to deal in the Software without restriction, including without limitation the
5 | # rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software,
6 | # and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
7 | #
8 | # The above copyright notice and this permission notice shall be included in all copies or substantial portions of
9 | # the Software.
10 | #
11 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO
12 | # THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
13 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
14 | # TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
15 | # SOFTWARE.
16 | # ===================================================================================================================
17 |
18 | import sys
19 | import numpy as np
20 |
# Backend availability flags; OpenCV is preferred when both are installed.
OPENCV_AVAILABLE = False
PILLOW_AVAILABLE = False

try:
    import cv2

    OPENCV_AVAILABLE = True
    print('OpenCV found, setting as default backend.')
except ImportError:
    # OpenCV is optional; fall back to Pillow below.
    pass

try:
    import PIL

    PILLOW_AVAILABLE = True

    if not OPENCV_AVAILABLE:
        print('Pillow found, setting as default backend.')
except ImportError:
    # Pillow is optional as well.
    pass


# At least one backend is required for resize()/rgb2gray() in this module.
if not (OPENCV_AVAILABLE or PILLOW_AVAILABLE):
    raise ValueError('No image library backend found.'' Install either '
                     'OpenCV or Pillow to support image processing.')
46 |
47 |
def resize(img, shape):
    """
    Resize the specified image using the available backend
    (OpenCV when present, Pillow otherwise).

    :param img: Image to reshape
    :param shape: New image shape
    :return: The resized image
    """
    if OPENCV_AVAILABLE:
        import cv2
        return cv2.resize(img, shape)
    elif PILLOW_AVAILABLE:
        from PIL import Image
        resized = Image.fromarray(img).resize(shape)
        return np.array(resized)
61 |
62 |
def rgb2gray(img):
    """
    Convert an RGB image to grayscale using the available backend
    (OpenCV when present, Pillow otherwise).

    :param img: image to convert
    :return: The grayscale image
    """
    if OPENCV_AVAILABLE:
        import cv2
        return cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
    elif PILLOW_AVAILABLE:
        from PIL import Image
        return np.array(Image.fromarray(img).convert('L'))
75 |
--------------------------------------------------------------------------------
/malmopy/util/util.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2017 Microsoft Corporation.
2 | #
3 | # Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated
4 | # documentation files (the "Software"), to deal in the Software without restriction, including without limitation the
5 | # rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software,
6 | # and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
7 | #
8 | # The above copyright notice and this permission notice shall be included in all copies or substantial portions of
9 | # the Software.
10 | #
11 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO
12 | # THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
13 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
14 | # TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
15 | # SOFTWARE.
16 | # ===================================================================================================================
17 |
18 | from __future__ import absolute_import
19 |
20 | import os
21 | from math import sqrt
22 |
23 | import numpy as np
24 |
25 |
def euclidean(a, b):
    """ Compute the Euclidean distance between points a and b.

    :param a: First point (sequence of coordinates)
    :param b: Second point (sequence of coordinates)
    :return: Euclidean distance as a float
    :raises AssertionError: if a and b have different lengths
    """
    assert len(a) == len(b), 'cannot compute distance when a and b have different shapes'
    # Generator expression avoids materializing an intermediate list and no
    # longer shadows the parameters a and b inside the comprehension.
    return sqrt(sum((ai - bi) ** 2 for ai, bi in zip(a, b)))
29 |
30 |
def get_rank(x):
    """ Get a shape's rank (number of dimensions).

    :param x: A numpy array or a shape tuple
    :return: The rank as an int
    :raises ValueError: when x is neither an ndarray nor a tuple
    """
    if isinstance(x, np.ndarray):
        return len(x.shape)
    elif isinstance(x, tuple):
        return len(x)
    else:
        # Bug fix: the ValueError was previously *returned* instead of
        # raised, silently handing callers an exception instance.
        raise ValueError('Unable to determine rank of type: %s' % str(type(x)))
39 |
40 |
def check_rank(shape, required_rank):
    """ Check whether a shape tuple has exactly the expected rank.

    Non-tuple inputs are reported as not matching (returns False).

    :param shape: Candidate shape (expected to be a tuple)
    :param required_rank: Expected number of dimensions
    :return: True when shape is a tuple of length required_rank
    """
    return isinstance(shape, tuple) and len(shape) == required_rank
47 |
48 |
def isclose(a, b, atol=1e-01):
    """ Check if a and b are closer than tolerance level atol.

    Equivalent to abs(a - b) < atol, expressed as a chained comparison.
    """
    return -atol < a - b < atol
55 |
56 |
def ensure_path_exists(path):
    """ Ensure that the specified path exists on the filesystem.

    Creates any missing intermediate directories. Race-safe: instead of a
    check-then-create (which could fail if another process created the
    directory in between), it attempts creation and ignores the
    'already exists' error.

    :param path: Absolute or relative directory path
    """
    import errno

    if not os.path.isabs(path):
        path = os.path.abspath(path)
    try:
        os.makedirs(path)
    except OSError as e:
        # Re-raise everything except "directory already exists"
        if e.errno != errno.EEXIST or not os.path.isdir(path):
            raise
63 |
--------------------------------------------------------------------------------
/malmopy/version.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2017 Microsoft Corporation.
2 | #
3 | # Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated
4 | # documentation files (the "Software"), to deal in the Software without restriction, including without limitation the
5 | # rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software,
6 | # and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
7 | #
8 | # The above copyright notice and this permission notice shall be included in all copies or substantial portions of
9 | # the Software.
10 | #
11 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO
12 | # THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
13 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
14 | # TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
15 | # SOFTWARE.
16 | # ===================================================================================================================
17 |
# malmopy package version (semantic versioning: MAJOR.MINOR.PATCH)
VERSION = '0.1.0'
19 |
--------------------------------------------------------------------------------
/malmopy/visualization/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2017 Microsoft Corporation.
2 | #
3 | # Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated
4 | # documentation files (the "Software"), to deal in the Software without restriction, including without limitation the
5 | # rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software,
6 | # and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
7 | #
8 | # The above copyright notice and this permission notice shall be included in all copies or substantial portions of
9 | # the Software.
10 | #
11 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO
12 | # THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
13 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
14 | # TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
15 | # SOFTWARE.
16 | # ===================================================================================================================
17 |
18 | from __future__ import absolute_import
19 |
20 | from .visualizer import BaseVisualizer, ConsoleVisualizer, EmptyVisualizer, Visualizable
21 |
--------------------------------------------------------------------------------
/malmopy/visualization/tensorboard/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2017 Microsoft Corporation.
2 | #
3 | # Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated
4 | # documentation files (the "Software"), to deal in the Software without restriction, including without limitation the
5 | # rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software,
6 | # and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
7 | #
8 | # The above copyright notice and this permission notice shall be included in all copies or substantial portions of
9 | # the Software.
10 | #
11 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO
12 | # THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
13 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
14 | # TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
15 | # SOFTWARE.
16 | # ===================================================================================================================
17 |
18 | from __future__ import absolute_import
19 |
20 | from .tensorboard import TensorboardVisualizer, TensorflowConverter
21 |
--------------------------------------------------------------------------------
/malmopy/visualization/tensorboard/cntk/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2017 Microsoft Corporation.
2 | #
3 | # Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated
4 | # documentation files (the "Software"), to deal in the Software without restriction, including without limitation the
5 | # rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software,
6 | # and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
7 | #
8 | # The above copyright notice and this permission notice shall be included in all copies or substantial portions of
9 | # the Software.
10 | #
11 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO
12 | # THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
13 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
14 | # TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
15 | # SOFTWARE.
16 | # ===================================================================================================================
17 |
18 | from __future__ import absolute_import
19 |
20 | from .cntk import CntkConverter
21 |
--------------------------------------------------------------------------------
/malmopy/visualization/tensorboard/cntk/cntk.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2017 Microsoft Corporation.
2 | #
3 | # Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated
4 | # documentation files (the "Software"), to deal in the Software without restriction, including without limitation the
5 | # rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software,
6 | # and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
7 | #
8 | # The above copyright notice and this permission notice shall be included in all copies or substantial portions of
9 | # the Software.
10 | #
11 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO
12 | # THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
13 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
14 | # TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
15 | # SOFTWARE.
16 | # ===================================================================================================================
17 |
18 | from __future__ import absolute_import
19 |
20 | import six
21 | import tensorflow as tf
22 | from tensorflow.core.framework import attr_value_pb2, tensor_shape_pb2
23 |
24 | from ..tensorboard import TensorflowConverter
25 |
26 |
class CntkConverter(TensorflowConverter):
    """Converts a CNTK network into TensorFlow graph operations so that
    TensorBoard can render the network structure.

    NOTE(review): relies on private TensorFlow APIs (``Operation._add_input``)
    and on CNTK uid naming conventions (e.g. 'Plus1'), so it is tightly
    coupled to specific library versions.
    """

    def convert(self, network, graph):
        """
        Converts a function from CNTK to the Tensorflow graph format

        Args:
            network: CNTK function that defines the network structure
            graph: destination Tensorflow graph
        """
        # Walk every node of the network iteratively (depth-first from the
        # model root, using an explicit stack).
        stack = [network.model]
        # NOTE(review): 'visited' is checked but never populated; termination
        # instead relies on the get_operation_by_name() short-circuit below.
        # Confirm this is intended.
        visited = set()

        while stack:
            node = stack.pop()

            if node in visited:
                continue

            try:

                # Function node
                # (accessing root_function raises AttributeError for variable
                # nodes, which are handled in the except branch below)
                node = node.root_function
                stack.extend(node.inputs)
                try:
                    # TF graph already has the current node
                    graph.get_operation_by_name(node.uid.split('_')[0])
                    continue

                except KeyError:
                    # New network node that has to be converted to TF format
                    # define TF operation attributes based on CNTK network node
                    # Only the first two output dimensions are kept; a missing
                    # dimension defaults to size 1.
                    try:
                        dim_x = tensor_shape_pb2.TensorShapeProto.Dim(size=node.outputs[0].shape[0])
                    except IndexError:
                        dim_x = tensor_shape_pb2.TensorShapeProto.Dim(size=1)
                    try:
                        dim_y = tensor_shape_pb2.TensorShapeProto.Dim(size=node.outputs[0].shape[1])
                    except IndexError:
                        dim_y = tensor_shape_pb2.TensorShapeProto.Dim(size=1)
                    shape = tensor_shape_pb2.TensorShapeProto(dim=(dim_x, dim_y))
                    shape_attr = attr_value_pb2.AttrValue(shape=shape)
                    attrs = {"shape": shape_attr}

                    # Use name scope based on the node's name (e.g. Plus1) to
                    # group the operation and its inputs
                    with graph.name_scope(node.uid) as _:

                        # Create a TF placeholder operation with type, name and shape of the current node
                        op = graph.create_op("Placeholder", inputs=[],
                                             dtypes=[node.outputs[0].dtype], attrs=attrs,
                                             name=node.uid)

                        # Add inputs to the created TF operation
                        for i in six.moves.range(len(node.inputs)):
                            child = node.inputs[i]
                            name = child.uid
                            try:
                                # The input tensor already exists in the graph
                                tf_input = graph.get_tensor_by_name(name + ":0")
                            except KeyError:
                                # A new tensor that needs to be converted from CNTK to TF
                                shape = self.convert_shape(child.shape)
                                dtype = child.dtype
                                # Create a new placeholder tensor with the corresponding attributes
                                tf_input = tf.placeholder(shape=shape, dtype=dtype, name=name)

                            # Update TF operator's inputs (private TF API)
                            op._add_input(tf_input)

                        # Update TF operation's outputs: wire this op into every
                        # existing operation whose name mentions the output uid
                        output = node.outputs[0]
                        for o in graph.get_operations():
                            if output.uid in o.name:
                                o._add_input(op.outputs[0])

            except AttributeError:
                # OutputVariable node
                try:
                    if node.is_output:
                        try:
                            # Owner of the node is already added to the TF graph
                            owner_name = node.owner.uid + '/' + node.owner.uid
                            graph.get_operation_by_name(owner_name)
                        except KeyError:
                            # Unknown network node
                            stack.append(node.owner)

                except AttributeError:
                    # Not an output variable (e.g. input/parameter); skip it
                    pass

        # Add missing connections in the graph
        CntkConverter.update_outputs(graph.get_operations())
        # Freeze the graph; it is now read-only and safe to hand to a writer
        graph.finalize()

    @staticmethod
    def convert_shape(shape):
        # Pad a CNTK shape tuple up to rank 2, the rank expected by the
        # placeholder creation in convert() above.
        if len(shape) == 0:
            shape = (1, 1)
        else:
            if len(shape) == 1:
                shape += (1,)
        return shape

    @staticmethod
    def update_outputs(ops):
        """Updates the inputs/outputs of the Tensorflow operations
        by adding missing connections

        Args:
            ops: a list of Tensorflow operations
        """
        # O(n^2) pass connecting op i to op j when i's second name component
        # occurs in j's. NOTE(review): assumes every op name contains a '/'
        # ('scope/op'); names without one would raise IndexError here.
        for i in six.moves.range(len(ops)):
            for j in six.moves.range(i + 1, len(ops)):
                if ops[i].name.split('/')[1] in ops[j].name.split('/')[1]:
                    ops[i]._add_input(ops[j].outputs[0])
143 |
--------------------------------------------------------------------------------
/malmopy/visualization/tensorboard/tensorboard.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2017 Microsoft Corporation.
2 | #
3 | # Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated
4 | # documentation files (the "Software"), to deal in the Software without restriction, including without limitation the
5 | # rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software,
6 | # and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
7 | #
8 | # The above copyright notice and this permission notice shall be included in all copies or substantial portions of
9 | # the Software.
10 | #
11 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO
12 | # THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
13 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
14 | # TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
15 | # SOFTWARE.
16 | # ===================================================================================================================
17 |
18 | from __future__ import absolute_import
19 |
20 | import six
21 | import tensorflow as tf
22 | from tensorflow.core.framework.summary_pb2 import Summary
23 |
24 | from ..visualizer import BaseVisualizer
25 |
26 |
class TensorboardVisualizer(BaseVisualizer):
    """
    Visualize the generated results in Tensorboard
    """

    def __init__(self):
        super(TensorboardVisualizer, self).__init__()

        # Keep the session's GPU footprint negligible: it only hosts the
        # graph/summaries for Tensorboard and does no training itself.
        gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0.01)
        self._session = tf.Session(config=tf.ConfigProto(gpu_options=gpu_options))
        self._train_writer = None

    def initialize(self, logdir, model, converter=None):
        """Create the event writer, optionally converting the model's graph.

        Args:
            logdir: directory where Tensorboard event files are written
            model: network whose structure should be displayed (may be None
                when no converter is given)
            converter: optional TensorflowConverter used to translate the
                model into this session's graph
        """
        assert logdir is not None, "logdir cannot be None"
        assert isinstance(logdir, six.string_types), "logdir should be a string"

        if converter is not None:
            assert isinstance(converter, TensorflowConverter), \
                "converter should derive from TensorflowConverter"
            converter.convert(model, self._session.graph)

        self._train_writer = tf.summary.FileWriter(logdir=logdir,
                                                   graph=self._session.graph,
                                                   flush_secs=30)

    def add_entry(self, index, tag, value, **kwargs):
        """Append one summary point: a scalar, or an image when kwargs
        contains 'image'.

        Args:
            index: global step used as the x-axis of the entry
            tag: name of the series the entry belongs to
            value: scalar value, or an HxWxC image array for image entries
        """
        if "image" in kwargs and value is not None:
            # Bug fix: encode_jpeg builds a graph op that returns a string
            # tensor; it must be evaluated to get the raw JPEG bytes that the
            # protobuf field expects.
            image_string = self._session.run(
                tf.image.encode_jpeg(value, optimize_size=True, quality=80))
            image = Summary.Image(width=value.shape[1],
                                  height=value.shape[0],
                                  colorspace=value.shape[2],
                                  encoded_image_string=image_string)
            # Bug fix: Summary.value is a repeated Summary.Value field, so the
            # image must be wrapped in (and tagged via) a Summary.Value rather
            # than placed in the summary directly.
            summary_value = Summary.Value(tag=tag, image=image)
        else:
            summary_value = Summary.Value(tag=tag, simple_value=value)

        if summary_value is not None:
            entry = Summary(value=[summary_value])
            self._train_writer.add_summary(entry, index)

    def close(self):
        """Flush and close the underlying event writer, if one was created."""
        if self._train_writer is not None:
            self._train_writer.close()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        # Implicitly returns None (falsy) so exceptions are not suppressed.
        self.close()
75 |
76 |
class TensorflowConverter(object):
    """Interface for converters that translate a model into a TF graph."""

    def convert(self, network, graph):
        """Populate *graph* with operations mirroring *network*.

        Subclasses must override this method.
        """
        raise NotImplementedError()
80 |
--------------------------------------------------------------------------------
/malmopy/visualization/visualizer.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2017 Microsoft Corporation.
2 | #
3 | # Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated
4 | # documentation files (the "Software"), to deal in the Software without restriction, including without limitation the
5 | # rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software,
6 | # and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
7 | #
8 | # The above copyright notice and this permission notice shall be included in all copies or substantial portions of
9 | # the Software.
10 | #
11 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO
12 | # THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
13 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
14 | # TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
15 | # SOFTWARE.
16 | # ===================================================================================================================
17 |
18 | from __future__ import absolute_import
19 |
20 | from os import path
21 |
22 |
class Visualizable(object):
    """Mixin that lets an object push metrics to an optional visualizer."""

    def __init__(self, visualizer=None):
        """Keep a reference to the visualizer, validating it when provided."""
        if visualizer is not None:
            assert isinstance(visualizer, BaseVisualizer), "visualizer should derive from BaseVisualizer"

        self._visualizer = visualizer

    def visualize(self, index, tag, value, **kwargs):
        """Forward one entry to the visualizer; no-op when none is attached."""
        if self._visualizer is None:
            return
        self._visualizer << (index, tag, value, kwargs)

    @property
    def can_visualize(self):
        """True when a visualizer is attached."""
        return self._visualizer is not None
37 |
38 |
class BaseVisualizer(object):
    """ Provide a unified interface for observing the training progress """

    def add_entry(self, index, key, result, **kwargs):
        """Record one data point. Subclasses must override this."""
        raise NotImplementedError()

    def __lshift__(self, other):
        """Stream operator: visualizer << (index, tag, value[, kwargs_dict]).

        Raises:
            ValueError: when *other* is not a tuple of at least 3 elements.
        """
        if not isinstance(other, tuple):
            raise ValueError("Trying to use stream operator without a tuple (key, value)")
        if len(other) < 3:
            # Message fixed: three elements are required, not two.
            raise ValueError("Provided tuple should be of the form (index, key, value)")
        # Bug fix: the optional 4th element (a dict of extra options, as sent
        # by Visualizable.visualize) used to be silently dropped, so options
        # such as 'image' never reached add_entry. Forward it when present.
        kwargs = other[3] if len(other) > 3 and isinstance(other[3], dict) else {}
        self.add_entry(other[0], str(other[1]), other[2], **kwargs)
53 |
54 |
class EmptyVisualizer(BaseVisualizer):
    """ A boilerplate visualizer that does nothing """

    def add_entry(self, index, key, result, **kwargs):
        """Intentionally discard every entry (no-op sink)."""
        return None
60 |
61 |
class ConsoleVisualizer(BaseVisualizer):
    """Write each entry to stdout, one line per entry:

        [prefix] index : key -> value
    """

    # %-style template taking (prefix, index, key, value).
    CONSOLE_DEFAULT_FORMAT = "[%s] %d : %s -> %.3f"

    def __init__(self, format=None, prefix=None):
        """Use the given %-format string and prefix, or the defaults."""
        self._format = format if format else ConsoleVisualizer.CONSOLE_DEFAULT_FORMAT
        self._prefix = prefix if prefix else '-'

    def add_entry(self, index, key, result, **kwargs):
        """Render one entry with the configured format and print it."""
        line = self._format % (self._prefix, index, key, result)
        print(line)
74 |
75 |
class CsvVisualizer(BaseVisualizer):
    """ Write data to file. The following formats are supported: CSV, JSON, Excel. """

    def __init__(self, output_file, override=False):
        """Buffer entries in memory for later export to *output_file*.

        Raises:
            Exception: when the file already exists and override is False.
        """
        if path.exists(output_file) and not override:
            raise Exception('%s already exists and override is False' % output_file)

        super(CsvVisualizer, self).__init__()
        self._file = output_file
        self._data = {}  # maps index -> {key: result}

    def add_entry(self, index, key, result, **kwargs):
        """Store one value; later values for the same (index, key) win."""
        # Bug fix: self._data[index] raised KeyError for any index not seen
        # before; create the per-index row on demand instead.
        row = self._data.setdefault(index, {})
        if key in row:
            print('Warning: Found previous value for %s in visualizer' % key)

        row[key] = result

    def close(self, format='csv'):
        """Export the buffered data as 'csv', 'json', or (otherwise) Excel."""
        import pandas as pd

        if format == 'csv':
            pd.DataFrame.from_dict(self._data, orient='index').to_csv(self._file)
        elif format == 'json':
            pd.DataFrame.from_dict(self._data, orient='index').to_json(self._file)
        else:
            writer = pd.ExcelWriter(self._file)
            pd.DataFrame.from_dict(self._data, orient='index').to_excel(writer)
            writer.save()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.close()
        # Bug fix: returning self (truthy) from __exit__ told Python to
        # suppress any exception raised inside the with-block; return False
        # so exceptions propagate normally.
        return False
110 |
--------------------------------------------------------------------------------
/samples/atari/gym_atari_dqn.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2017 Microsoft Corporation.
2 | #
3 | # Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated
4 | # documentation files (the "Software"), to deal in the Software without restriction, including without limitation the
5 | # rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software,
6 | # and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
7 | #
8 | # The above copyright notice and this permission notice shall be included in all copies or substantial portions of
9 | # the Software.
10 | #
11 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO
12 | # THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
13 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
14 | # TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
15 | # SOFTWARE.
16 | # ===================================================================================================================
17 |
18 | from argparse import ArgumentParser
19 | from datetime import datetime
20 | from subprocess import Popen
21 |
22 | from malmopy.agent import QLearnerAgent, TemporalMemory
23 | from malmopy.environment.gym import GymEnvironment
24 |
25 | try:
26 | from malmopy.visualization.tensorboard import TensorboardVisualizer
27 | from malmopy.visualization.tensorboard.cntk import CntkConverter
28 |
29 | TENSORBOARD_AVAILABLE = True
30 | except ImportError:
31 | print('Cannot import tensorboard, using ConsoleVisualizer.')
32 | from malmopy.visualization import ConsoleVisualizer
33 |
34 | TENSORBOARD_AVAILABLE = False
35 |
36 |
37 | ROOT_FOLDER = 'results/baselines/%s/dqn/%s-%s'
38 | EPOCH_SIZE = 250000
39 |
40 |
def visualize_training(visualizer, step, rewards, tag='Training'):
    """Push per-episode reward statistics to the visualizer.

    Args:
        visualizer: object exposing add_entry(step, tag, value)
        step: global training step used as the x-axis index
        rewards: list of rewards collected during the episode
        tag: name-scope prefix for the emitted entries
    """
    # Robustness fix: max()/min() raise ValueError on an empty sequence,
    # which can happen when the environment reports done before any reward
    # was collected; emit nothing in that case.
    if not rewards:
        return
    visualizer.add_entry(step, '%s/reward per episode' % tag, sum(rewards))
    visualizer.add_entry(step, '%s/max.reward' % tag, max(rewards))
    visualizer.add_entry(step, '%s/min.reward' % tag, min(rewards))
    # NOTE(review): 'len(rewards) - 1' presumably excludes one bookkeeping
    # entry; confirm against how viz_rewards is filled in run_experiment.
    visualizer.add_entry(step, '%s/actions per episode' % tag, len(rewards) - 1)
46 |
47 |
def run_experiment(environment, backend, device_id, max_epoch, record, logdir,
                   visualizer):
    """Train a DQN agent on an OpenAI Gym environment.

    Args:
        environment: Gym environment id (e.g. 'Breakout-v3')
        backend: 'cntk' or 'chainer' -- selects the neural-network backend
        device_id: GPU device id; -1 presumably means CPU -- confirm in model
        max_epoch: number of training epochs (EPOCH_SIZE steps each)
        record: when True, Gym monitoring output is written to logdir
        logdir: output folder used for recordings
        visualizer: visualizer receiving training metrics
    """

    env = GymEnvironment(environment,
                         monitoring_path=logdir if record else None)

    # Both backends use the same (4, 84, 84) input: 4 stacked 84x84 frames.
    if backend == 'cntk':
        from malmopy.model.cntk import QNeuralNetwork as CntkDQN
        model = CntkDQN((4, 84, 84), env.available_actions, momentum=0.95,
                        device_id=device_id, visualizer=visualizer)
    else:
        from malmopy.model.chainer import DQNChain, QNeuralNetwork as ChainerDQN
        chain = DQNChain((4, 84, 84), env.available_actions)
        target_chain = DQNChain((4, 84, 84), env.available_actions)
        model = ChainerDQN(chain, target_chain,
                          momentum=0.95, device_id=device_id)

    # 1M-transition replay buffer; agent uses gamma 0.99, batch size 32,
    # starts learning after 10k steps, clips rewards to [-1, 1].
    memory = TemporalMemory(1000000, model.input_shape[1:])
    agent = QLearnerAgent("DQN Agent", env.available_actions, model, memory,
                          0.99, 32, train_after=10000, reward_clipping=(-1, 1),
                          visualizer=visualizer)

    state = env.reset()
    reward = 0
    agent_done = False
    viz_rewards = []  # rewards accumulated over the current episode

    max_training_steps = max_epoch * EPOCH_SIZE
    for step in range(1, max_training_steps + 1):

        # check if env needs reset
        if env.done:
            visualize_training(visualizer, step, viz_rewards)
            agent.inject_summaries(step)
            viz_rewards = []
            state = env.reset()

        # select an action
        action = agent.act(state, reward, agent_done, is_training=True)

        # take a step
        state, reward, agent_done = env.do(action)
        viz_rewards.append(reward)

        # Checkpoint the model once per epoch.
        # NOTE(review): step / EPOCH_SIZE is a float under Python 3; %d
        # truncates it, but '//' would be more explicit.
        if (step % EPOCH_SIZE) == 0:
            model.save('%s-%s-dqn_%d.model' %
                       (backend, environment, step / EPOCH_SIZE))
95 |
96 |
if __name__ == '__main__':
    # Command-line entry point: parse arguments, set up visualization and
    # launch the DQN training experiment.
    arg_parser = ArgumentParser(description='OpenAI Gym DQN example')
    arg_parser.add_argument('-b', '--backend', type=str, default='cntk',
                            choices=['cntk', 'chainer'],
                            help='Neural network backend to use.')
    arg_parser.add_argument('-d', '--device', type=int, default=-1,
                            help='GPU device on which to run the experiment.')
    arg_parser.add_argument('-r', '--record', action='store_true',
                            help='Setting this will record runs')
    arg_parser.add_argument('-e', '--epochs', type=int, default=50,
                            help='Number of epochs. One epoch is 250k actions.')
    arg_parser.add_argument('-p', '--port', type=int, default=6006,
                            help='Port for running tensorboard.')
    arg_parser.add_argument('env', type=str, metavar='environment',
                            nargs='?', default='Breakout-v3',
                            help='Gym environment to run')

    args = arg_parser.parse_args()

    # Unique, timestamped output folder per run:
    # results/baselines/<env>/dqn/<backend>-<utc-timestamp>
    logdir = ROOT_FOLDER % (args.env, args.backend, datetime.utcnow().isoformat())
    if TENSORBOARD_AVAILABLE:
        visualizer = TensorboardVisualizer()
        visualizer.initialize(logdir, None)
        print('Starting tensorboard ...')
        # Launch tensorboard as a detached child process; it is not waited on
        # or terminated when the experiment ends.
        p = Popen(['tensorboard', '--logdir=results', '--port=%d' % args.port])

    else:
        visualizer = ConsoleVisualizer()

    print('Starting experiment')
    run_experiment(args.env, args.backend, int(args.device), args.epochs,
                   args.record, logdir, visualizer)
129 |
130 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | from distutils.core import setup
4 |
5 | from setuptools import find_packages
6 |
7 | sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'malmopy'))
8 | from version import VERSION
9 |
# Optional dependency groups, installable via e.g. pip install malmopy[gym].
extras = {
    'chainer': ['chainer>=1.21.0'],
    'gym': ['gym[atari]>=0.7.0'],
    'tensorflow': ['tensorflow'],
}

# Meta group: 'all' aggregates every optional dependency above.
extras['all'] = [dep for group in extras.values() for dep in group]

setup(
    name='malmopy',
    version=VERSION,

    # Ship only the malmopy package tree, not samples or docker helpers.
    packages=[pkg for pkg in find_packages()
              if pkg.startswith('malmopy')],

    url='https://github.com/Microsoft/malmo-challenge',
    license='MIT',
    author='Microsoft Research Cambridge',
    author_email='',
    description='Malmo Collaborative AI Challenge task and example code',
    install_requires=['future', 'numpy>=1.11.0', 'six>=0.10.0', 'pandas', 'Pillow'],
    extras_require=extras
)
37 |
--------------------------------------------------------------------------------