├── .gitattributes ├── .gitignore ├── .gitmodules ├── LICENSE.TXT ├── README.md ├── cule ├── atari.hpp ├── atari │ ├── accessors.hpp │ ├── actions.hpp │ ├── ale.hpp │ ├── common.hpp │ ├── controller.hpp │ ├── cuda │ │ ├── dispatch.hpp │ │ ├── frame_state.hpp │ │ ├── kernels.hpp │ │ ├── state.hpp │ │ ├── tables.hpp │ │ └── timer.hpp │ ├── debug.hpp │ ├── dispatch.hpp │ ├── environment.hpp │ ├── flags.hpp │ ├── frame_state.hpp │ ├── functors.hpp │ ├── games.hpp │ ├── games │ │ ├── adventure.hpp │ │ ├── airraid.hpp │ │ ├── alien.hpp │ │ ├── amidar.hpp │ │ ├── assault.hpp │ │ ├── asterix.hpp │ │ ├── asteroids.hpp │ │ ├── atlantis.hpp │ │ ├── bankheist.hpp │ │ ├── battlezone.hpp │ │ ├── beamrider.hpp │ │ ├── berzerk.hpp │ │ ├── bowling.hpp │ │ ├── boxing.hpp │ │ ├── breakout.hpp │ │ ├── carnival.hpp │ │ ├── centipede.hpp │ │ ├── chopper.hpp │ │ ├── crazyclimber.hpp │ │ ├── defender.hpp │ │ ├── demonattack.hpp │ │ ├── detail │ │ │ ├── attributes.hpp │ │ │ ├── types.hpp │ │ │ └── utils.hpp │ │ ├── doubledunk.hpp │ │ ├── elevatoraction.hpp │ │ ├── enduro.hpp │ │ ├── fishingderby.hpp │ │ ├── freeway.hpp │ │ ├── frostbite.hpp │ │ ├── gopher.hpp │ │ ├── gravitar.hpp │ │ ├── hero.hpp │ │ ├── icehockey.hpp │ │ ├── jamesbond.hpp │ │ ├── journeyescape.hpp │ │ ├── kaboom.hpp │ │ ├── kangaroo.hpp │ │ ├── krull.hpp │ │ ├── kungfumaster.hpp │ │ ├── montezumarevenge.hpp │ │ ├── mspacman.hpp │ │ ├── namethisgame.hpp │ │ ├── phoenix.hpp │ │ ├── pinball.hpp │ │ ├── pitfall.hpp │ │ ├── pong.hpp │ │ ├── pooyan.hpp │ │ ├── privateeye.hpp │ │ ├── qbert.hpp │ │ ├── riverraid.hpp │ │ ├── roadrunner.hpp │ │ ├── robotank.hpp │ │ ├── seaquest.hpp │ │ ├── skiing.hpp │ │ ├── solaris.hpp │ │ ├── spaceinvaders.hpp │ │ ├── stargunner.hpp │ │ ├── tennis.hpp │ │ ├── timepilot.hpp │ │ ├── tutankham.hpp │ │ ├── upndown.hpp │ │ ├── venture.hpp │ │ ├── wizard.hpp │ │ ├── yarsrevenge.hpp │ │ └── zaxxon.hpp │ ├── internals.hpp │ ├── interrupt.hpp │ ├── joystick.hpp │ ├── m6502.hpp │ ├── m6532.hpp │ ├── 
mmc.hpp │ ├── opcodes.cpp │ ├── opcodes.hpp │ ├── paddles.hpp │ ├── palettes.hpp │ ├── png.hpp │ ├── preprocess.hpp │ ├── prng.hpp │ ├── ram.hpp │ ├── rom.cpp │ ├── rom.hpp │ ├── stack.hpp │ ├── state.hpp │ ├── tables.cpp │ ├── tables.hpp │ ├── tia.hpp │ ├── types │ │ ├── bitfield.hpp │ │ ├── flagset.hpp │ │ ├── types.hpp │ │ └── valueobj.hpp │ ├── wrapper.cpp │ └── wrapper.hpp ├── config.hpp ├── cuda.hpp ├── cuda │ ├── errchk.hpp │ └── parallel_execution_policy.hpp ├── cule.hpp ├── macros.hpp └── md5.hpp ├── envs ├── Dockerfile └── environment.yml ├── examples ├── __init__.py ├── a2c │ ├── __init__.py │ ├── a2c_main.py │ ├── helper.py │ ├── model.py │ ├── test.py │ └── train.py ├── dqn │ ├── LICENSE.md │ ├── README.md │ ├── agent.py │ ├── benchmark.config │ ├── dqn_main.py │ ├── memory.py │ ├── model.py │ ├── test.py │ └── train.py ├── ppo │ ├── benchmark.config │ ├── ppo_main.py │ └── train.py ├── utils │ ├── __init__.py │ ├── initializers.py │ ├── launcher.py │ ├── openai │ │ ├── LICENSE.md │ │ ├── README.md │ │ ├── atari_wrappers.py │ │ ├── envs.py │ │ ├── subproc_vec_env.py │ │ └── vec_normalize.py │ └── runtime.py ├── visualize │ ├── animate.py │ └── play.py └── vtrace │ ├── benchmark_vtrace.py │ ├── test_vtrace.py │ ├── train.py │ └── vtrace_main.py ├── media └── images │ └── System.png ├── setup.py └── torchcule ├── .pylintrc ├── __init__.py ├── atari ├── __init__.py ├── env.py ├── rom.py └── state.py ├── atari_env.hpp ├── atari_state.cpp ├── atari_state.hpp ├── backend.cu └── frontend.cpp /.gitattributes: -------------------------------------------------------------------------------- 1 | ############################################################################### 2 | # Set default behavior to automatically normalize line endings. 
3 | ############################################################################### 4 | * text=auto 5 | 6 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # build artifacts 2 | *.o 3 | *.os 4 | *.pyc 5 | *.so 6 | __pycache__ 7 | 8 | # media files 9 | *.png 10 | *.mov 11 | 12 | # build file 13 | build 14 | dist 15 | *.egg-info 16 | 17 | # model files 18 | *.pth 19 | 20 | # data files 21 | *.csv 22 | examples/*/runs 23 | -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "third_party/agency"] 2 | path = third_party/agency 3 | url = https://github.com/agency-library/agency.git 4 | [submodule "third_party/pybind11"] 5 | path = third_party/pybind11 6 | url = https://github.com/pybind/pybind11.git 7 | -------------------------------------------------------------------------------- /LICENSE.TXT: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. 3 | * 4 | * Redistribution and use in source and binary forms, with or without 5 | * modification, are permitted provided that the following conditions 6 | * are met: 7 | * * Redistributions of source code must retain the above copyright 8 | * notice, this list of conditions and the following disclaimer. 9 | * * Redistributions in binary form must reproduce the above copyright 10 | * notice, this list of conditions and the following disclaimer in the 11 | * documentation and/or other materials provided with the distribution. 12 | * * Neither the name of NVIDIA CORPORATION nor the names of its 13 | * contributors may be used to endorse or promote products derived 14 | * from this software without specific prior written permission. 
15 | * 16 | * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY 17 | * EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 18 | * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR 19 | * PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR 20 | * CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, 21 | * EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, 22 | * PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR 23 | * PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY 24 | * OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 25 | * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 26 | * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 27 | */ 28 | -------------------------------------------------------------------------------- /cule/atari.hpp: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | #include 5 | #include 6 | #include 7 | #include 8 | #include 9 | #include 10 | #include 11 | #include 12 | #include 13 | #include 14 | #include 15 | #include 16 | #include 17 | #include 18 | #include 19 | #include 20 | #include 21 | 22 | -------------------------------------------------------------------------------- /cule/atari/actions.hpp: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | 5 | #include 6 | 7 | #include 8 | #include 9 | 10 | namespace cule 11 | { 12 | namespace atari 13 | { 14 | 15 | static const Action allActions[_ACTION_MAX] = 16 | { 17 | ACTION_NOOP, 18 | ACTION_FIRE, 19 | ACTION_UP, 20 | ACTION_RIGHT, 21 | ACTION_LEFT, 22 | ACTION_DOWN, 23 | 24 | ACTION_UPRIGHT, 25 | ACTION_UPLEFT, 26 | ACTION_DOWNRIGHT, 27 | ACTION_DOWNLEFT, 28 | ACTION_UPFIRE, 29 | ACTION_RIGHTFIRE, 30 | ACTION_LEFTFIRE, 31 | ACTION_DOWNFIRE, 32 | ACTION_UPRIGHTFIRE, 33 | 
ACTION_UPLEFTFIRE, 34 | ACTION_DOWNRIGHTFIRE, 35 | ACTION_DOWNLEFTFIRE, 36 | 37 | ACTION_RESET, 38 | }; 39 | 40 | static std::map action_to_string_map = 41 | { 42 | {ACTION_NOOP, "ACTION_NOOP"}, 43 | {ACTION_RIGHT, "ACTION_RIGHT"}, 44 | {ACTION_LEFT, "ACTION_LEFT"}, 45 | {ACTION_DOWN, "ACTION_DOWN"}, 46 | {ACTION_UP, "ACTION_UP"}, 47 | {ACTION_FIRE, "ACTION_FIRE"}, 48 | 49 | {ACTION_UPRIGHT, "ACTION_UPRIGHT"}, 50 | {ACTION_UPLEFT, "ACTION_UPLEFT"}, 51 | {ACTION_DOWNRIGHT, "ACTION_DOWNRIGHT"}, 52 | {ACTION_DOWNLEFT, "ACTION_DOWNLEFT"}, 53 | {ACTION_UPFIRE, "ACTION_UPFIRE"}, 54 | {ACTION_RIGHTFIRE, "ACTION_RIGHTFIRE"}, 55 | {ACTION_LEFTFIRE, "ACTION_LEFTFIRE"}, 56 | {ACTION_DOWNFIRE, "ACTION_DOWNFIRE"}, 57 | {ACTION_UPRIGHTFIRE, "ACTION_UPRIGHTFIRE"}, 58 | {ACTION_UPLEFTFIRE, "ACTION_UPLEFTFIRE"}, 59 | {ACTION_DOWNRIGHTFIRE, "ACTION_DOWNRIGHTFIRE"}, 60 | {ACTION_DOWNLEFTFIRE, "ACTION_DOWNLEFTFIRE"}, 61 | 62 | {ACTION_RESET, "ACTION_RESET"}, 63 | }; 64 | 65 | } // end namespace atari 66 | } // end namespace cule 67 | 68 | -------------------------------------------------------------------------------- /cule/atari/common.hpp: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | 5 | #include 6 | 7 | namespace cule 8 | { 9 | namespace atari 10 | { 11 | 12 | // utils 13 | CULE_ANNOTATION 14 | word_t makeWord(const uint8_t lo, const uint8_t hi) 15 | { 16 | return word_t(lo) | (word_t(hi) << 8); 17 | } 18 | 19 | CULE_ANNOTATION 20 | int16_t clamp(int16_t value) 21 | { 22 | if(value >= 160) 23 | value -= 160; 24 | else if(value < 0) 25 | value += 160; 26 | return value; 27 | } 28 | 29 | std::string get_frame_name(const size_t proc_id, const size_t frame_index) 30 | { 31 | std::ostringstream png_filename; 32 | png_filename << "frames/"; 33 | png_filename << proc_id << "/"; 34 | png_filename << std::setfill('0') << std::setw(6) << frame_index; 35 | png_filename << ".png"; 36 | return png_filename.str(); 37 | } 
38 | 39 | } // end namespace atari 40 | } // end namespace cule 41 | 42 | -------------------------------------------------------------------------------- /cule/atari/controller.hpp: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | 5 | #include 6 | #include 7 | #include 8 | 9 | #include 10 | 11 | namespace cule 12 | { 13 | namespace atari 14 | { 15 | 16 | struct controller 17 | { 18 | 19 | template 20 | static 21 | CULE_ANNOTATION 22 | void set_flags(State& s, 23 | const bool& use_paddles, 24 | const bool& swap_paddles, 25 | const bool& left_difficulty_B, 26 | const bool& right_difficulty_B) 27 | { 28 | s.sysFlags.template change(use_paddles); 29 | s.sysFlags.template change(swap_paddles); 30 | 31 | s.sysFlags.template change(!left_difficulty_B); 32 | s.sysFlags.template change(!right_difficulty_B); 33 | } 34 | 35 | template 36 | static 37 | CULE_ANNOTATION 38 | void set_action(State& s, const Action& player_a_action) 39 | { 40 | UPDATE_FIELD(s.sysFlags.asBitField(), FIELD_SYS_CON_RESET, player_a_action); 41 | } 42 | 43 | template 44 | static 45 | CULE_ANNOTATION 46 | void set_actions(State& s, const Action& player_a_action, const Action&) 47 | { 48 | set_action(s, player_a_action); 49 | } 50 | 51 | template 52 | static 53 | CULE_ANNOTATION 54 | Action get_action(State& s) 55 | { 56 | return Action(SELECT_FIELD(s.sysFlags.asBitField(), FIELD_SYS_CON_RESET)); 57 | } 58 | 59 | template 60 | static 61 | CULE_ANNOTATION 62 | void reset(State& s) 63 | { 64 | UPDATE_FIELD(s.sysFlags.asBitField(), FIELD_SYS_CON, 0); 65 | 66 | if(s.sysFlags[FLAG_CON_PADDLES]) 67 | paddles::reset(s); 68 | else 69 | joystick::reset(s); 70 | } 71 | 72 | template 73 | static 74 | CULE_ANNOTATION 75 | void applyAction(State& s) 76 | { 77 | // Handle reset 78 | s.sysFlags.template change(!s.sysFlags[FLAG_CON_RESET]); 79 | 80 | if(s.sysFlags[FLAG_CON_PADDLES]) 81 | paddles::applyAction(s); 82 | else 83 | 
joystick::applyAction(s); 84 | } 85 | 86 | template 87 | static 88 | CULE_ANNOTATION 89 | bool read(State& s, const Control_Jack& jack, const Control_DigitalPin& pin) 90 | { 91 | bool value = false; 92 | 93 | if(s.sysFlags[FLAG_CON_PADDLES]) 94 | value = paddles::read(s, jack, pin); 95 | else 96 | value = joystick::read(s, jack, pin); 97 | 98 | return value; 99 | } 100 | 101 | template 102 | static 103 | CULE_ANNOTATION 104 | int32_t read(State& s, const Control_Jack& jack, const Control_AnalogPin& pin) 105 | { 106 | int32_t value = 0; 107 | 108 | if(s.sysFlags[FLAG_CON_PADDLES]) 109 | value = paddles::read(s, jack, pin); 110 | else 111 | value = joystick::read(s, jack, pin); 112 | 113 | return value; 114 | } 115 | 116 | }; // end struct controller 117 | 118 | } // end namespace atari 119 | } // end namespace cule 120 | 121 | -------------------------------------------------------------------------------- /cule/atari/cuda/frame_state.hpp: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | 5 | #include 6 | 7 | namespace cule 8 | { 9 | namespace atari 10 | { 11 | 12 | CULE_ANNOTATION 13 | void state_store_load_helper(frame_state& t, const frame_state& s) 14 | { 15 | t.Color = s.Color; 16 | t.GRP = s.GRP; 17 | t.HM = s.HM; 18 | t.PF = s.PF; 19 | t.POS = s.POS; 20 | t.CurrentGRP0 = s.CurrentGRP0; 21 | t.CurrentGRP1 = s.CurrentGRP1; 22 | 23 | t.clockWhenFrameStarted = s.clockWhenFrameStarted; 24 | t.clockAtLastUpdate = s.clockAtLastUpdate; 25 | t.lastHMOVEClock = s.lastHMOVEClock; 26 | 27 | t.playfieldPriorityAndScore = s.playfieldPriorityAndScore; 28 | 29 | t.tiaFlags = s.tiaFlags; 30 | 31 | t.CurrentPFMask = s.CurrentPFMask; 32 | t.CurrentP0Mask = s.CurrentP0Mask; 33 | t.CurrentP1Mask = s.CurrentP1Mask; 34 | t.CurrentM0Mask = s.CurrentM0Mask; 35 | t.CurrentM1Mask = s.CurrentM1Mask; 36 | t.CurrentBLMask = s.CurrentBLMask; 37 | } 38 | 39 | } // end namespace atari 40 | } // end namespace cule 41 | 42 
| -------------------------------------------------------------------------------- /cule/atari/cuda/state.hpp: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | #include 5 | 6 | namespace cule 7 | { 8 | namespace atari 9 | { 10 | 11 | template 12 | CULE_ANNOTATION 13 | void state_store_load_helper(State& t, const State& s) 14 | { 15 | t.A = s.A; 16 | t.X = s.X; 17 | t.Y = s.Y; 18 | t.SP = s.SP; 19 | t.PC = s.PC; 20 | t.addr = s.addr; 21 | t.value = s.value; 22 | t.noise = s.noise; 23 | 24 | t.cpuCycles = s.cpuCycles; 25 | t.bank = s.bank; 26 | 27 | t.resistance = s.resistance; 28 | 29 | t.GRP = s.GRP; 30 | t.HM = s.HM; 31 | t.PF = s.PF; 32 | t.POS = s.POS; 33 | t.CurrentGRP0 = s.CurrentGRP0; 34 | t.CurrentGRP1 = s.CurrentGRP1; 35 | 36 | t.collision = s.collision; 37 | t.clockWhenFrameStarted = s.clockWhenFrameStarted; 38 | t.clockAtLastUpdate = s.clockAtLastUpdate; 39 | t.dumpDisabledCycle = s.dumpDisabledCycle; 40 | t.VSYNCFinishClock = s.VSYNCFinishClock; 41 | t.lastHMOVEClock = s.lastHMOVEClock; 42 | 43 | t.riotData = s.riotData; 44 | t.cyclesWhenTimerSet = s.cyclesWhenTimerSet; 45 | t.cyclesWhenInterruptReset = s.cyclesWhenInterruptReset; 46 | 47 | t.sysFlags = s.sysFlags; 48 | t.tiaFlags = s.tiaFlags; 49 | 50 | t.frameData = s.frameData; 51 | // t.rand = s.rand; 52 | t.score = s.score; 53 | 54 | t.CurrentPFMask = s.CurrentPFMask; 55 | t.CurrentP0Mask = s.CurrentP0Mask; 56 | t.CurrentP1Mask = s.CurrentP1Mask; 57 | t.CurrentM0Mask = s.CurrentM0Mask; 58 | t.CurrentM1Mask = s.CurrentM1Mask; 59 | t.CurrentBLMask = s.CurrentBLMask; 60 | } 61 | 62 | } // end namespace atari 63 | } // end namespace cule 64 | 65 | -------------------------------------------------------------------------------- /cule/atari/cuda/tables.hpp: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | #include 5 | 6 | #include 7 | #include 8 | #include 9 | 10 | 
#include 11 | 12 | #define COPY_DATA_FROM_VECTOR(ARR_NAME) CULE_ERRCHK(cudaMemcpyToSymbol(gpu_##ARR_NAME, cule::atari::ARR_NAME, sizeof(cule::atari::ARR_NAME))); 13 | 14 | namespace cule 15 | { 16 | namespace atari 17 | { 18 | namespace cuda 19 | { 20 | __constant__ uint8_t gpu_rom[rom::MAX_ROM_SIZE]; 21 | __constant__ uint32_t gpu_NTSCPalette[256]; 22 | } 23 | 24 | __device__ int8_t gpu_ourPlayerPositionResetWhenTable[8][160][160]; 25 | __device__ uint8_t gpu_ourBallMaskTable[4][4][320]; 26 | __device__ uint8_t gpu_ourDisabledMaskTable[640]; 27 | __device__ uint8_t gpu_ourPlayerMaskTable[4][2][8][320]; 28 | __device__ uint8_t gpu_ourPlayerReflectTable[256]; 29 | __device__ uint8_t gpu_ourMissleMaskTable[4][8][4][320]; 30 | __device__ uint8_t gpu_ourPriorityEncoder[2][256]; 31 | __device__ uint16_t gpu_ourCollisionTable[64]; 32 | __device__ uint32_t gpu_ourPlayfieldTable[2][160]; 33 | 34 | __device__ bool gpu_ourHMOVEBlankEnableCycles[76]; 35 | __device__ int16_t gpu_ourPokeDelayTable[64]; 36 | __device__ int16_t gpu_ourCompleteMotionTable[76][16]; 37 | 38 | void initialize_tables(const rom& cart) 39 | { 40 | assert(cart.data() != nullptr); 41 | assert(cart.rom_size() != 0); 42 | 43 | CULE_ERRCHK(cudaMemcpyToSymbol(cule::atari::opcode::gpu_opdata, cule::atari::opcode::opdata, sizeof(cule::atari::opcode::M6502_OPCODE) * 256)); 44 | CULE_ERRCHK(cudaMemcpyToSymbol(cuda::gpu_rom, cart.data(), sizeof(uint8_t) * cart.rom_size())); 45 | CULE_ERRCHK(cudaMemcpyToSymbol(cuda::gpu_NTSCPalette, cule::atari::NTSCPalette, sizeof(uint32_t) * 256)); 46 | CULE_CUDA_PEEK_AND_SYNC; 47 | 48 | COPY_DATA_FROM_VECTOR(ourHMOVEBlankEnableCycles); 49 | COPY_DATA_FROM_VECTOR(ourPokeDelayTable); 50 | COPY_DATA_FROM_VECTOR(ourCompleteMotionTable); 51 | 52 | COPY_DATA_FROM_VECTOR(ourPlayerPositionResetWhenTable); 53 | COPY_DATA_FROM_VECTOR(ourBallMaskTable); 54 | COPY_DATA_FROM_VECTOR(ourDisabledMaskTable); 55 | COPY_DATA_FROM_VECTOR(ourPlayerMaskTable); 56 | 
COPY_DATA_FROM_VECTOR(ourPlayerReflectTable); 57 | COPY_DATA_FROM_VECTOR(ourMissleMaskTable); 58 | COPY_DATA_FROM_VECTOR(ourPriorityEncoder); 59 | COPY_DATA_FROM_VECTOR(ourCollisionTable); 60 | COPY_DATA_FROM_VECTOR(ourPlayfieldTable); 61 | CULE_CUDA_PEEK_AND_SYNC; 62 | } 63 | } // end namespace atari 64 | } // end namespace cule 65 | -------------------------------------------------------------------------------- /cule/atari/cuda/timer.hpp: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | // A simple timer class 4 | 5 | #include 6 | 7 | namespace cule 8 | { 9 | namespace atari 10 | { 11 | 12 | class timer 13 | { 14 | public: 15 | 16 | timer() 17 | { 18 | cudaEventCreate(&start); 19 | cudaEventCreate(&end); 20 | cudaEventRecord(start,0); 21 | } 22 | 23 | ~timer() 24 | { 25 | cudaEventDestroy(start); 26 | cudaEventDestroy(end); 27 | } 28 | 29 | float milliseconds_elapsed() 30 | { 31 | float elapsed_time; 32 | cudaEventRecord(end, 0); 33 | cudaEventSynchronize(end); 34 | cudaEventElapsedTime(&elapsed_time, start, end); 35 | return elapsed_time; 36 | } 37 | 38 | float seconds_elapsed() 39 | { 40 | return milliseconds_elapsed() / 1000.0; 41 | } 42 | 43 | private: 44 | 45 | cudaEvent_t start; 46 | cudaEvent_t end; 47 | }; 48 | 49 | } // end namespace atari 50 | } // end namespace cule 51 | 52 | -------------------------------------------------------------------------------- /cule/atari/frame_state.hpp: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | 5 | #include 6 | 7 | namespace cule 8 | { 9 | namespace atari 10 | { 11 | 12 | struct frame_state 13 | { 14 | // TIA vars 15 | uint32_t Color; 16 | uint32_t GRP; 17 | uint32_t HM; 18 | uint32_t PF; 19 | uint32_t POS; 20 | uint8_t CurrentGRP0; 21 | uint8_t CurrentGRP1; 22 | 23 | int32_t clockWhenFrameStarted; 24 | int32_t clockAtLastUpdate; 25 | 26 | // Color clock when last HMOVE occurred 27 | 
int32_t lastHMOVEClock; 28 | 29 | uint8_t playfieldPriorityAndScore; 30 | uint16_t cpuCycles; 31 | 32 | tia_flag_t tiaFlags; 33 | 34 | uint8_t M0CosmicArkCounter; 35 | 36 | uint8_t* framePointer; 37 | const uint32_t* srcBuffer; 38 | 39 | uint32_t* CurrentPFMask; 40 | uint8_t * CurrentP0Mask; 41 | uint8_t * CurrentP1Mask; 42 | uint8_t * CurrentM0Mask; 43 | uint8_t * CurrentM1Mask; 44 | uint8_t * CurrentBLMask; 45 | }; 46 | 47 | } // end namespace atari 48 | } // end namespace cule 49 | 50 | -------------------------------------------------------------------------------- /cule/atari/games.hpp: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | 5 | #include 6 | #include 7 | #include 8 | 9 | #include 10 | #include 11 | #include 12 | #include 13 | #include 14 | #include 15 | #include 16 | #include 17 | #include 18 | #include 19 | #include 20 | #include 21 | #include 22 | #include 23 | #include 24 | #include 25 | #include 26 | #include 27 | #include 28 | #include 29 | #include 30 | #include 31 | #include 32 | #include 33 | #include 34 | #include 35 | #include 36 | #include 37 | #include 38 | #include 39 | #include 40 | #include 41 | #include 42 | #include 43 | #include 44 | #include 45 | #include 46 | #include 47 | #include 48 | #include 49 | #include 50 | #include 51 | #include 52 | #include 53 | #include 54 | #include 55 | #include 56 | #include 57 | #include 58 | #include 59 | #include 60 | #include 61 | #include 62 | #include 63 | #include 64 | #include 65 | #include 66 | #include 67 | #include 68 | #include 69 | #include 70 | #include 71 | #include 72 | 73 | -------------------------------------------------------------------------------- /cule/atari/games/airraid.hpp: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | #include 5 | #include 6 | #include 7 | 8 | namespace cule 9 | { 10 | namespace atari 11 | { 12 | namespace games 13 
| { 14 | namespace airraid 15 | { 16 | 17 | template 18 | CULE_ANNOTATION 19 | void reset(State& s) 20 | { 21 | s.m_reward = 0; 22 | s.m_score = 0; 23 | s.tiaFlags.clear(FLAG_ALE_TERMINAL); 24 | } 25 | 26 | template 27 | CULE_ANNOTATION 28 | void step(State& s) 29 | { 30 | using cule::atari::games::getDecimalScore; 31 | using cule::atari::ram::read; 32 | 33 | // update the reward 34 | int score = getDecimalScore(s, 0xAA, 0xA9, 0xA8); 35 | s.m_reward = score - s.m_score; 36 | s.m_score = score; 37 | 38 | // update terminal status 39 | int lives = ram::read(s, 0xA7); 40 | s.tiaFlags.template change(lives == 0xFF); 41 | } 42 | 43 | CULE_ANNOTATION 44 | bool isMinimal(const Action &a) 45 | { 46 | switch (a) 47 | { 48 | case ACTION_NOOP: 49 | case ACTION_FIRE: 50 | case ACTION_RIGHT: 51 | case ACTION_LEFT: 52 | case ACTION_RIGHTFIRE: 53 | case ACTION_LEFTFIRE: 54 | return true; 55 | default: 56 | return false; 57 | } 58 | } 59 | 60 | template 61 | CULE_ANNOTATION 62 | int32_t lives(State& s) 63 | { 64 | return cule::atari::ram::read(s.ram, 0xA7); 65 | } 66 | 67 | template 68 | CULE_ANNOTATION 69 | void setTerminal(State& s) 70 | { 71 | // update terminal status 72 | s.tiaFlags.template change(lives(s) == 0xFF); 73 | } 74 | 75 | template 76 | CULE_ANNOTATION 77 | int32_t score(State& s) 78 | { 79 | return cule::atari::games::getDecimalScore(s, 0xAA, 0xA9, 0xA8); 80 | } 81 | 82 | template 83 | CULE_ANNOTATION 84 | int32_t reward(State& s) 85 | { 86 | return score(s) - s.score; 87 | } 88 | 89 | } // end namespace airraid 90 | } // end namespace games 91 | } // end namespace atari 92 | } // end namespace cule 93 | 94 | -------------------------------------------------------------------------------- /cule/atari/games/amidar.hpp: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | #include 5 | #include 6 | 7 | namespace cule 8 | { 9 | namespace atari 10 | { 11 | namespace games 12 | { 13 | namespace amidar 14 
| { 15 | 16 | template 17 | CULE_ANNOTATION 18 | void reset(State& s) 19 | { 20 | s.m_reward = 0; 21 | s.m_score = 0; 22 | s.tiaFlags.clear(FLAG_ALE_TERMINAL); 23 | s.m_lives = 3; 24 | } 25 | 26 | template 27 | CULE_ANNOTATION 28 | void step(State& s) 29 | { 30 | using cule::atari::games::getDecimalScore; 31 | using cule::atari::ram::read; 32 | 33 | // update the reward 34 | int score = getDecimalScore(s, 0xD9, 0xDA, 0xDB); 35 | s.m_reward = score - s.m_score; 36 | s.m_score = score; 37 | 38 | // update terminal status 39 | int livesByte = ram::read(s, 0xD6); 40 | 41 | // MGB it takes one step for the system to reset; this assumes we've 42 | // reset 43 | s.tiaFlags.template change(livesByte == 0x80); 44 | s.m_lives = (livesByte & 0xF); 45 | } 46 | 47 | CULE_ANNOTATION 48 | bool isMinimal(const Action &a) 49 | { 50 | switch (a) 51 | { 52 | case ACTION_NOOP: 53 | case ACTION_FIRE: 54 | case ACTION_UP: 55 | case ACTION_RIGHT: 56 | case ACTION_LEFT: 57 | case ACTION_DOWN: 58 | case ACTION_UPFIRE: 59 | case ACTION_RIGHTFIRE: 60 | case ACTION_LEFTFIRE: 61 | case ACTION_DOWNFIRE: 62 | return true; 63 | default: 64 | return false; 65 | } 66 | } 67 | 68 | template 69 | CULE_ANNOTATION 70 | int32_t lives(State& s) 71 | { 72 | return cule::atari::ram::read(s.ram, 0xD6) & 0xF; 73 | } 74 | 75 | template 76 | CULE_ANNOTATION 77 | void setTerminal(State& s) 78 | { 79 | // MGB it takes one step for the system to reset; this assumes we've reset 80 | const int32_t livesByte = cule::atari::ram::read(s.ram, 0xD6); 81 | s.tiaFlags.template change(livesByte == 0x80); 82 | } 83 | 84 | template 85 | CULE_ANNOTATION 86 | int32_t score(State& s) 87 | { 88 | return cule::atari::games::getDecimalScore(s, 0xD9, 0xDA, 0xDB); 89 | } 90 | 91 | template 92 | CULE_ANNOTATION 93 | int32_t reward(State& s) 94 | { 95 | return score(s) - s.score; 96 | } 97 | 98 | } // end namespace amidar 99 | } // end namespace games 100 | } // end namespace atari 101 | } // end namespace cule 102 | 103 | 
-------------------------------------------------------------------------------- /cule/atari/games/assault.hpp: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | #include 5 | #include 6 | 7 | namespace cule 8 | { 9 | namespace atari 10 | { 11 | namespace games 12 | { 13 | namespace assault 14 | { 15 | 16 | template 17 | CULE_ANNOTATION 18 | void reset(State& s) 19 | { 20 | s.m_reward = 0; 21 | s.m_score = 0; 22 | s.tiaFlags.clear(FLAG_ALE_TERMINAL); 23 | s.m_lives = 4; 24 | } 25 | 26 | template 27 | CULE_ANNOTATION 28 | void step(State& s) 29 | { 30 | using cule::atari::games::getDecimalScore; 31 | using cule::atari::ram::read; 32 | 33 | // update the reward 34 | int score = getDecimalScore(s, 0x82, 0x81, 0x80); 35 | s.m_reward = score - s.m_score; 36 | s.m_score = score; 37 | 38 | // update terminal status 39 | s.m_lives = ram::read(s, 0xE5); 40 | s.tiaFlags.template change(s.m_lives == 0); 41 | } 42 | 43 | CULE_ANNOTATION 44 | bool isMinimal(const Action &a) 45 | { 46 | switch (a) 47 | { 48 | case ACTION_NOOP: 49 | case ACTION_FIRE: 50 | case ACTION_UP: 51 | case ACTION_RIGHT: 52 | case ACTION_LEFT: 53 | case ACTION_RIGHTFIRE: 54 | case ACTION_LEFTFIRE: 55 | return true; 56 | default: 57 | return false; 58 | } 59 | } 60 | 61 | template 62 | CULE_ANNOTATION 63 | int32_t lives(State& s) 64 | { 65 | return cule::atari::ram::read(s.ram, 0xE5); 66 | } 67 | 68 | template 69 | CULE_ANNOTATION 70 | void setTerminal(State& s) 71 | { 72 | // update terminal status 73 | s.tiaFlags.template change(lives(s) == 0); 74 | } 75 | 76 | template 77 | CULE_ANNOTATION 78 | int32_t score(State& s) 79 | { 80 | return cule::atari::games::getDecimalScore(s, 0x82, 0x81, 0x80); 81 | } 82 | 83 | template 84 | CULE_ANNOTATION 85 | int32_t reward(State& s) 86 | { 87 | return score(s) - s.score; 88 | } 89 | 90 | } // end namespace assault 91 | } // end namespace games 92 | } // end namespace atari 93 | } // end namespace cule 
94 | 95 | -------------------------------------------------------------------------------- /cule/atari/games/asterix.hpp: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | #include 5 | #include 6 | #include 7 | 8 | namespace cule 9 | { 10 | namespace atari 11 | { 12 | namespace games 13 | { 14 | namespace asterix 15 | { 16 | 17 | template 18 | CULE_ANNOTATION 19 | void reset(State& s) 20 | { 21 | s.m_reward = 0; 22 | s.m_score = 0; 23 | s.tiaFlags.clear(FLAG_ALE_TERMINAL); 24 | s.m_lives = 3; 25 | } 26 | 27 | template 28 | CULE_ANNOTATION 29 | void step(State& s) 30 | { 31 | using cule::atari::games::getDecimalScore; 32 | using cule::atari::ram::read; 33 | 34 | // update the reward 35 | int score = getDecimalScore(s, 0xE0, 0xDF, 0xDE); 36 | s.m_reward = score - s.m_score; 37 | s.m_score = score; 38 | 39 | // update terminal status 40 | s.m_lives = ram::read(s, 0xD3) & 0xF; 41 | int death_counter = ram::read(s, 0xC7); 42 | 43 | // we cannot wait for lives to be set to 0, because the agent has the 44 | // option of restarting the game on the very last frame (when lives==1 45 | // and death_counter == 0x01) by holding 'fire' 46 | s.tiaFlags.template change((death_counter == 0x01) && (s.m_lives == 1)); 47 | } 48 | 49 | CULE_ANNOTATION 50 | bool isMinimal(const Action &a) 51 | { 52 | switch (a) 53 | { 54 | case ACTION_NOOP: 55 | case ACTION_UP: 56 | case ACTION_RIGHT: 57 | case ACTION_LEFT: 58 | case ACTION_DOWN: 59 | case ACTION_UPRIGHT: 60 | case ACTION_UPLEFT: 61 | case ACTION_DOWNRIGHT: 62 | case ACTION_DOWNLEFT: 63 | return true; 64 | default: 65 | return false; 66 | } 67 | } 68 | 69 | template 70 | CULE_ANNOTATION 71 | int32_t lives(State& s) 72 | { 73 | return cule::atari::ram::read(s.ram, 0xD3) & 0xF; 74 | } 75 | 76 | template 77 | CULE_ANNOTATION 78 | void setTerminal(State& s) 79 | { 80 | // update terminal status 81 | int death_counter = cule::atari::ram::read(s.ram, 0xC7); 82 | 83 | // 
we cannot wait for lives to be set to 0, because the agent has the 84 | // option of restarting the game on the very last frame (when lives==1 85 | // and death_counter == 0x01) by holding 'fire' 86 | s.tiaFlags.template change((death_counter == 0x01) && (lives(s) == 1)); 87 | } 88 | 89 | template 90 | CULE_ANNOTATION 91 | int32_t score(State& s) 92 | { 93 | return cule::atari::games::getDecimalScore(s, 0xE0, 0xDF, 0xDE); 94 | } 95 | 96 | template 97 | CULE_ANNOTATION 98 | int32_t reward(State& s) 99 | { 100 | return score(s) - s.score; 101 | } 102 | 103 | } // end namespace asterix 104 | } // end namespace games 105 | } // end namespace atari 106 | } // end namespace cule 107 | 108 | -------------------------------------------------------------------------------- /cule/atari/games/asteroids.hpp: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | #include 5 | #include 6 | 7 | namespace cule 8 | { 9 | namespace atari 10 | { 11 | namespace games 12 | { 13 | namespace asteroids 14 | { 15 | 16 | template 17 | CULE_ANNOTATION 18 | void reset(State& s) 19 | { 20 | s.m_reward = 0; 21 | s.m_score = 0; 22 | s.tiaFlags.clear(FLAG_ALE_TERMINAL); 23 | s.m_lives = 4; 24 | } 25 | 26 | template 27 | CULE_ANNOTATION 28 | void step(State& s) 29 | { 30 | using cule::atari::games::getDecimalScore; 31 | using cule::atari::ram::read; 32 | 33 | // update the reward 34 | int score = getDecimalScore(s, 0xBE, 0xBD); 35 | score *= 10; 36 | s.m_reward = score - s.m_score; 37 | 38 | // Deal with score wrapping. In truth this should be done for all games and in a more 39 | // uniform fashion. 
40 | if (s.m_reward < 0) 41 | { 42 | const int WRAP_SCORE = 100000; 43 | s.m_reward += WRAP_SCORE; 44 | } 45 | s.m_score = score; 46 | 47 | // update terminal status 48 | int byte = ram::read(s, 0xBC); 49 | s.m_lives = (byte - (byte & 15)) >> 4; 50 | s.tiaFlags.template change(s.m_lives == 0); 51 | } 52 | 53 | CULE_ANNOTATION 54 | bool isMinimal(const Action &a) 55 | { 56 | switch (a) 57 | { 58 | case ACTION_NOOP: 59 | case ACTION_FIRE: 60 | case ACTION_UP: 61 | case ACTION_RIGHT: 62 | case ACTION_LEFT: 63 | case ACTION_DOWN: 64 | case ACTION_UPRIGHT: 65 | case ACTION_UPLEFT: 66 | case ACTION_UPFIRE: 67 | case ACTION_RIGHTFIRE: 68 | case ACTION_LEFTFIRE: 69 | case ACTION_DOWNFIRE: 70 | case ACTION_UPRIGHTFIRE: 71 | case ACTION_UPLEFTFIRE: 72 | return true; 73 | default: 74 | return false; 75 | } 76 | } 77 | 78 | template 79 | CULE_ANNOTATION 80 | int32_t lives(State& s) 81 | { 82 | // update terminal status 83 | int byte = cule::atari::ram::read(s.ram, 0xBC); 84 | return (byte - (byte & 0xF)) >> 4; 85 | } 86 | 87 | template 88 | CULE_ANNOTATION 89 | void setTerminal(State& s) 90 | { 91 | // update terminal status 92 | s.tiaFlags.template change(lives(s) == 0); 93 | } 94 | 95 | template 96 | CULE_ANNOTATION 97 | int32_t score(State& s) 98 | { 99 | return 10 * cule::atari::games::getDecimalScore(s, 0xBE, 0xBD); 100 | } 101 | 102 | template 103 | CULE_ANNOTATION 104 | int32_t reward(State& s) 105 | { 106 | int32_t m_reward = score(s) - s.score; 107 | /* s.score = m_score; */ 108 | 109 | // Deal with score wrapping. In truth this should be done for all games and in a more 110 | // uniform fashion. 
111 | if (m_reward < 0) 112 | { 113 | const int WRAP_SCORE = 100000; 114 | m_reward += WRAP_SCORE; 115 | } 116 | 117 | return m_reward; 118 | } 119 | 120 | } // end namespace asteroids 121 | } // end namespace games 122 | } // end namespace atari 123 | } // end namespace cule 124 | 125 | -------------------------------------------------------------------------------- /cule/atari/games/bankheist.hpp: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | #include 5 | #include 6 | 7 | namespace cule 8 | { 9 | namespace atari 10 | { 11 | namespace games 12 | { 13 | namespace bankheist 14 | { 15 | 16 | template 17 | CULE_ANNOTATION 18 | void reset(State& s) 19 | { 20 | s.m_reward = 0; 21 | s.m_score = 0; 22 | s.tiaFlags.clear(FLAG_ALE_TERMINAL); 23 | s.m_lives = 5; 24 | } 25 | 26 | template 27 | CULE_ANNOTATION 28 | void step(State& s) 29 | { 30 | using cule::atari::games::getDecimalScore; 31 | using cule::atari::ram::read; 32 | 33 | // update the reward 34 | int score = getDecimalScore(s, 0xDA, 0xD9, 0xD8); 35 | s.m_reward = score - s.m_score; 36 | s.m_score = score; 37 | 38 | // update terminal status 39 | int death_timer = ram::read(s, 0xCE); 40 | s.m_lives = ram::read(s, 0xD5); 41 | 42 | s.tiaFlags.template change((death_timer == 0x01) && (s.m_lives == 0x00)); 43 | } 44 | 45 | CULE_ANNOTATION 46 | bool isMinimal(const Action &a) 47 | { 48 | switch (a) 49 | { 50 | case ACTION_NOOP: 51 | case ACTION_FIRE: 52 | case ACTION_UP: 53 | case ACTION_RIGHT: 54 | case ACTION_LEFT: 55 | case ACTION_DOWN: 56 | case ACTION_UPRIGHT: 57 | case ACTION_UPLEFT: 58 | case ACTION_DOWNRIGHT: 59 | case ACTION_DOWNLEFT: 60 | case ACTION_UPFIRE: 61 | case ACTION_RIGHTFIRE: 62 | case ACTION_LEFTFIRE: 63 | case ACTION_DOWNFIRE: 64 | case ACTION_UPRIGHTFIRE: 65 | case ACTION_UPLEFTFIRE: 66 | case ACTION_DOWNRIGHTFIRE: 67 | case ACTION_DOWNLEFTFIRE: 68 | return true; 69 | default: 70 | return false; 71 | } 72 | } 73 | 74 | 
template 75 | CULE_ANNOTATION 76 | int32_t lives(State& s) 77 | { 78 | return cule::atari::ram::read(s.ram, 0xD5); 79 | } 80 | /* Terminal when the value at RAM 0xCE (death_timer) is 0x01 and lives(s) is zero. */ 81 | template 82 | CULE_ANNOTATION 83 | void setTerminal(State& s) 84 | { 85 | // update terminal status 86 | int death_timer = cule::atari::ram::read(s.ram, 0xCE); 87 | s.tiaFlags.template change((death_timer == 0x01) && (lives(s) == 0x00)); 88 | } 89 | /* Six-digit decimal score decoded from RAM 0xDA, 0xD9 and 0xD8 (lowest digits first). */ 90 | template 91 | CULE_ANNOTATION 92 | int32_t score(State& s) 93 | { 94 | return cule::atari::games::getDecimalScore(s, 0xDA, 0xD9, 0xD8); 95 | } 96 | /* Reward is the current score minus the score stored in s.score. */ 97 | template 98 | CULE_ANNOTATION 99 | int32_t reward(State& s) 100 | { 101 | return score(s) - s.score; 102 | } 103 | 104 | } // end namespace bankheist 105 | } // end namespace games 106 | } // end namespace atari 107 | } // end namespace cule 108 | 109 | -------------------------------------------------------------------------------- /cule/atari/games/beamrider.hpp: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | #include 5 | #include 6 | #include 7 | 8 | namespace cule 9 | { 10 | namespace atari 11 | { 12 | namespace games 13 | { 14 | namespace beamrider 15 | { 16 | /* Clears reward, score and the terminal flag, and sets the lives counter to 3. */ 17 | template 18 | CULE_ANNOTATION 19 | void reset(State& s) 20 | { 21 | s.m_reward = 0; 22 | s.m_score = 0; 23 | s.tiaFlags.clear(FLAG_ALE_TERMINAL); 24 | s.m_lives = 3; 25 | } 26 | /* Refreshes reward, score, lives and the terminal flag from the current RAM contents. */ 27 | template 28 | CULE_ANNOTATION 29 | void step(State& s) 30 | { 31 | using cule::atari::games::getDecimalScore; 32 | using cule::atari::ram::read; 33 | 34 | // update the reward 35 | int score = getDecimalScore(s, 9, 10, 11); 36 | s.m_reward = score - s.m_score; 37 | s.m_score = score; 38 | int new_lives = ram::read(s, 0x85) + 1; 39 | 40 | // Decrease lives *after* the death animation; this is necessary as the lives counter 41 | // blinks during death 42 | if (new_lives == s.m_lives - 1) 43 | { 44 | if (ram::read(s, 0x8C) == 0x01) 45 | { 46 | s.m_lives = new_lives; 47 | } 48 | } 49 | else 50 | { 51 | s.m_lives = 
new_lives; 52 | } 53 | 54 | // update terminal status 55 | const int byte_val = ram::read(s, 5); 56 | s.tiaFlags.template change(byte_val == 255); 57 | // The former follow-up test ((byte_val & 15) < 0) could never be true for a masked 58 | // nibble, so it was removed; the comparison with 255 alone decides the terminal flag. 59 | } 60 | /* Returns true for the actions listed below (the minimal action set), false otherwise. */ 61 | CULE_ANNOTATION 62 | bool isMinimal(const Action &a) 63 | { 64 | switch (a) 65 | { 66 | case ACTION_NOOP: 67 | case ACTION_FIRE: 68 | case ACTION_UP: 69 | case ACTION_RIGHT: 70 | case ACTION_LEFT: 71 | case ACTION_UPRIGHT: 72 | case ACTION_UPLEFT: 73 | case ACTION_RIGHTFIRE: 74 | case ACTION_LEFTFIRE: 75 | return true; 76 | default: 77 | return false; 78 | } 79 | } 80 | /* Lives are the byte at RAM 0x85 plus one. */ 81 | template 82 | CULE_ANNOTATION 83 | int32_t lives(State& s) 84 | { 85 | return cule::atari::ram::read(s.ram, 0x85) + 1; 86 | } 87 | /* Sets the terminal flag exactly when the byte at RAM 5 reads 255. */ 88 | template 89 | CULE_ANNOTATION 90 | void setTerminal(State& s) 91 | { 92 | // update terminal status 93 | const int byte_val = cule::atari::ram::read(s.ram, 5); 94 | s.tiaFlags.template change(byte_val == 255); 95 | // The former follow-up test ((byte_val & 0xF) < 0) could never be true for a masked 96 | // nibble, so it was removed; the comparison with 255 alone decides the terminal flag. 97 | } 98 | /* Six-digit decimal score decoded from RAM 9, 10 and 11 (lowest digits first). */ 99 | template 100 | CULE_ANNOTATION 101 | int32_t score(State& s) 102 | { 103 | return cule::atari::games::getDecimalScore(s, 9, 10, 11); 104 | } 105 | /* Reward is the current score minus the score stored in s.score. */ 106 | template 107 | CULE_ANNOTATION 108 | int32_t reward(State& s) 109 | { 110 | return score(s) - s.score; 111 | } 112 | 113 | } // end namespace beamrider 114 | } // end namespace games 115 | } // end namespace atari 116 | } // end namespace cule 117 | 118 | -------------------------------------------------------------------------------- /cule/atari/games/berzerk.hpp: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | #include 5 | #include 6 | 7 | namespace cule 8 | { 9 | namespace atari 10 | { 11 | namespace games 12 | { 13 | namespace berzerk 14 | { 15 | 16 | template 17 | CULE_ANNOTATION 18 | void reset(State& s) 19 | { 20 | s.m_reward = 0; 21 | 
s.m_score = 0; 22 | s.tiaFlags.clear(FLAG_ALE_TERMINAL); 23 | s.m_lives = 3; 24 | } 25 | /* Returns true for the actions listed below (the minimal action set), false otherwise. */ 26 | CULE_ANNOTATION 27 | bool isMinimal(const Action &a) 28 | { 29 | switch (a) 30 | { 31 | case ACTION_NOOP: 32 | case ACTION_FIRE: 33 | case ACTION_UP: 34 | case ACTION_RIGHT: 35 | case ACTION_LEFT: 36 | case ACTION_DOWN: 37 | case ACTION_UPRIGHT: 38 | case ACTION_UPLEFT: 39 | case ACTION_DOWNRIGHT: 40 | case ACTION_DOWNLEFT: 41 | case ACTION_UPFIRE: 42 | case ACTION_RIGHTFIRE: 43 | case ACTION_LEFTFIRE: 44 | case ACTION_DOWNFIRE: 45 | case ACTION_UPRIGHTFIRE: 46 | case ACTION_UPLEFTFIRE: 47 | case ACTION_DOWNRIGHTFIRE: 48 | case ACTION_DOWNLEFTFIRE: 49 | return true; 50 | default: 51 | return false; 52 | } 53 | } 54 | /* Lives are the byte at RAM 0xDA plus one. */ 55 | template 56 | CULE_ANNOTATION 57 | int32_t lives(State& s) 58 | { 59 | return cule::atari::ram::read(s.ram, 0xDA) + 1; 60 | } 61 | /* Terminal when lives(s) - 1, i.e. the value lives() derives from RAM 0xDA before adding one, equals 0xFF. */ 62 | template 63 | CULE_ANNOTATION 64 | void setTerminal(State& s) 65 | { 66 | s.tiaFlags.template change((lives(s) - 1) == 0xFF); 67 | } 68 | /* Six-digit decimal score decoded from RAM 95, 94 and 93 (lowest digits first). */ 69 | template 70 | CULE_ANNOTATION 71 | int32_t score(State& s) 72 | { 73 | return cule::atari::games::getDecimalScore(s, 95, 94, 93); 74 | } 75 | /* Reward is the current score minus the score stored in s.score. */ 76 | template 77 | CULE_ANNOTATION 78 | int32_t reward(State& s) 79 | { 80 | return score(s) - s.score; 81 | } 82 | 83 | } // end namespace berzerk 84 | } // end namespace games 85 | } // end namespace atari 86 | } // end namespace cule 87 | 88 | -------------------------------------------------------------------------------- /cule/atari/games/bowling.hpp: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | #include 5 | #include 6 | 7 | namespace cule 8 | { 9 | namespace atari 10 | { 11 | namespace games 12 | { 13 | namespace bowling 14 | { 15 | 16 | template 17 | CULE_ANNOTATION 18 | void reset(State& s) 19 | { 20 | s.m_reward = 0; 21 | s.m_score = 0; 22 | s.tiaFlags.clear(FLAG_ALE_TERMINAL); 23 | } 24 | 25 | template 26 | CULE_ANNOTATION 27 | void step(State& s) 28 | { 29 | 
// update the reward: points gained since the previously stored score 30 | const int current_score = cule::atari::games::getDecimalScore(s, 0xA1, 0xA6); 31 | s.m_reward = current_score - s.m_score; 32 | s.m_score = current_score; 33 | setTerminal(s); // update terminal status 34 | } 35 | /* Returns true for the actions listed below (the minimal action set), false otherwise. */ 36 | CULE_ANNOTATION 37 | bool isMinimal(const Action &a) 38 | { 39 | switch (a) 40 | { 41 | case ACTION_NOOP: 42 | case ACTION_FIRE: 43 | case ACTION_UP: 44 | case ACTION_DOWN: 45 | case ACTION_UPFIRE: 46 | case ACTION_DOWNFIRE: 47 | return true; 48 | default: 49 | return false; 50 | } 51 | } 52 | /* No lives are tracked for this game; always returns 0. */ 53 | template 54 | CULE_ANNOTATION 55 | int32_t lives(State&) 56 | { 57 | return 0; 58 | } 59 | /* Terminal once the byte at RAM 0xA4 exceeds 0x10. */ 60 | template 61 | CULE_ANNOTATION 62 | void setTerminal(State& s) 63 | { 64 | s.tiaFlags.template change(cule::atari::ram::read(s.ram, 0xA4) > 0x10); 65 | } 66 | /* Four-digit decimal score decoded from RAM 0xA1 (low digits) and 0xA6 (high digits). */ 67 | template 68 | CULE_ANNOTATION 69 | int32_t score(State& s) 70 | { 71 | return cule::atari::games::getDecimalScore(s, 0xA1, 0xA6); 72 | } 73 | /* Reward is the current score minus the score stored in s.score. */ 74 | template 75 | CULE_ANNOTATION 76 | int32_t reward(State& s) 77 | { 78 | return score(s) - s.score; 79 | } 80 | 81 | } // end namespace bowling 82 | } // end namespace games 83 | } // end namespace atari 84 | } // end namespace cule 85 | 86 | -------------------------------------------------------------------------------- /cule/atari/games/breakout.hpp: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | 5 | #include 6 | #include 7 | 8 | namespace cule 9 | { 10 | namespace atari 11 | { 12 | namespace games 13 | { 14 | namespace breakout 15 | { 16 | 17 | // reset 18 | template 19 | CULE_ANNOTATION 20 | void reset(State& s) 21 | { 22 | s.m_reward = 0; 23 | s.m_score = 0; 24 | s.m_lives = 5; 25 | s.tiaFlags.clear(FLAG_ALE_STARTED); 26 | s.tiaFlags.clear(FLAG_ALE_TERMINAL); 27 | } 28 | 29 | // process the latest information from ALE 30 | template 31 | CULE_ANNOTATION 32 | void step(State& s) 33 | { 34 | // update the reward 35 | uint8_t x = cule::atari::ram::read(s, 77); 36 | uint8_t y = cule::atari::ram::read(s, 76); 37 | 38 | uint32_t score = 1 * (x & 0x0F) + 10 * ((x & 0xF0) >> 4) + 100 
* (y & 0x0F); 39 | s.m_reward = score - s.m_score; 40 | s.m_score = score; 41 | 42 | // update terminal status 43 | s.m_lives = cule::atari::ram::read(s, 57); 44 | 45 | if (!s.tiaFlags[FLAG_ALE_STARTED] && (s.m_lives == 5)) 46 | { 47 | s.tiaFlags.set(FLAG_ALE_STARTED); 48 | } 49 | 50 | s.tiaFlags.template change(s.tiaFlags[FLAG_ALE_STARTED] && (s.m_lives == 0)); 51 | } 52 | 53 | // is an action part of the minimal set? 54 | CULE_ANNOTATION 55 | bool isMinimal(const Action& a) 56 | { 57 | switch (a) 58 | { 59 | case ACTION_NOOP: 60 | case ACTION_FIRE: 61 | case ACTION_RIGHT: 62 | case ACTION_LEFT: 63 | return true; 64 | default: 65 | return false; 66 | } 67 | } 68 | /* Lives are the byte at RAM 57. */ 69 | template 70 | CULE_ANNOTATION 71 | int32_t lives(State& s) 72 | { 73 | return cule::atari::ram::read(s.ram, 57); 74 | } 75 | /* Zero lives is terminal only after FLAG_ALE_STARTED has been set, which happens once the lives counter reads 5. */ 76 | template 77 | CULE_ANNOTATION 78 | void setTerminal(State& s) 79 | { 80 | int m_lives = lives(s); 81 | 82 | if (!s.tiaFlags[FLAG_ALE_STARTED] && (m_lives == 5)) 83 | { 84 | s.tiaFlags.set(FLAG_ALE_STARTED); 85 | } 86 | 87 | s.tiaFlags.template change(s.tiaFlags[FLAG_ALE_STARTED] && (m_lives == 0)); 88 | } 89 | 90 | template 91 | CULE_ANNOTATION 92 | int32_t score(State& s) 93 | { 94 | // decode three decimal digits: units and tens from RAM 77, hundreds from RAM 76 95 | uint8_t x = cule::atari::ram::read(s.ram, 77); 96 | uint8_t y = cule::atari::ram::read(s.ram, 76); 97 | 98 | return 1 * (x & 0x0F) + 10 * ((x & 0xF0) >> 4) + 100 * (y & 0x0F); 99 | } 100 | /* Reward is the current score minus the score stored in s.score. */ 101 | template 102 | CULE_ANNOTATION 103 | int32_t reward(State& s) 104 | { 105 | return score(s) - s.score; 106 | } 107 | 108 | } // end namespace breakout 109 | } // end namespace games 110 | } // end namespace atari 111 | } // end namespace cule 112 | 113 | -------------------------------------------------------------------------------- /cule/atari/games/carnival.hpp: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | #include 5 | #include 6 | 7 | namespace cule 8 | { 9 | namespace atari 10 | { 11 | namespace 
games 12 | { 13 | namespace carnival 14 | { 15 | /* Clears reward, score and the terminal flag. */ 16 | template 17 | CULE_ANNOTATION 18 | void reset(State& s) 19 | { 20 | s.m_reward = 0; 21 | s.m_score = 0; 22 | s.tiaFlags.clear(FLAG_ALE_TERMINAL); 23 | } 24 | /* Refreshes reward, score and the terminal flag from the current RAM contents. */ 25 | template 26 | CULE_ANNOTATION 27 | void step(State& s) 28 | { 29 | using cule::atari::games::getDecimalScore; 30 | using cule::atari::ram::read; 31 | 32 | // update the reward 33 | int score = getDecimalScore(s, 0xAE, 0xAD); 34 | score *= 10; 35 | s.m_reward = score - s.m_score; 36 | s.m_score = score; 37 | 38 | // update terminal status 39 | int ammo = ram::read(s, 0x83); 40 | s.tiaFlags.template change(ammo < 1); 41 | } 42 | /* Returns true for the actions listed below (the minimal action set), false otherwise. */ 43 | CULE_ANNOTATION 44 | bool isMinimal(const Action &a) 45 | { 46 | switch (a) 47 | { 48 | case ACTION_NOOP: 49 | case ACTION_FIRE: 50 | case ACTION_RIGHT: 51 | case ACTION_LEFT: 52 | case ACTION_RIGHTFIRE: 53 | case ACTION_LEFTFIRE: 54 | return true; 55 | default: 56 | return false; 57 | } 58 | } 59 | /* No lives are tracked for this game; always returns 0. */ 60 | template 61 | CULE_ANNOTATION 62 | int32_t lives(State&) 63 | { 64 | return 0; 65 | } 66 | /* Terminal once the byte at RAM 0x83 (named ammo here) drops below 1. */ 67 | template 68 | CULE_ANNOTATION 69 | void setTerminal(State& s) 70 | { 71 | // update terminal status 72 | int ammo = cule::atari::ram::read(s.ram, 0x83); 73 | s.tiaFlags.template change(ammo < 1); 74 | } 75 | /* Score is ten times the decimal value decoded from RAM 0xAE (low digits) and 0xAD (high digits). */ 76 | template 77 | CULE_ANNOTATION 78 | int32_t score(State& s) 79 | { 80 | return 10 * cule::atari::games::getDecimalScore(s, 0xAE, 0xAD); 81 | } 82 | /* Reward is the current score minus the score stored in s.score. */ 83 | template 84 | CULE_ANNOTATION 85 | int32_t reward(State& s) 86 | { 87 | return score(s) - s.score; 88 | } 89 | 90 | } // end namespace carnival 91 | } // end namespace games 92 | } // end namespace atari 93 | } // end namespace cule 94 | 95 | -------------------------------------------------------------------------------- /cule/atari/games/centipede.hpp: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | #include 5 | #include 6 | 7 | namespace cule 8 | { 9 | namespace atari 10 | { 11 | namespace games 12 | 
{ 13 | namespace centipede 14 | { 15 | /* Clears reward, score and the terminal flag, and sets the lives counter to 3. */ 16 | template 17 | CULE_ANNOTATION 18 | void reset(State& s) 19 | { 20 | s.m_reward = 0; 21 | s.m_score = 0; 22 | s.tiaFlags.clear(FLAG_ALE_TERMINAL); 23 | s.m_lives = 3; 24 | } 25 | /* Refreshes reward (clamped at zero), score, lives and the terminal flag from RAM. */ 26 | template 27 | CULE_ANNOTATION 28 | void step(State& s) 29 | { 30 | using cule::atari::games::getDecimalScore; 31 | using cule::atari::ram::read; 32 | 33 | // update the reward 34 | int score = getDecimalScore(s, 118, 117, 116); 35 | s.m_reward = score - s.m_score; 36 | s.m_score = score; 37 | 38 | // HACK: the score sometimes gets reset before termination; ignoring for now. 39 | if (s.m_reward < 0) s.m_reward = 0; 40 | 41 | // Maximum of 8 lives 42 | s.m_lives = ((ram::read(s, 0xED) >> 4) & 0x7) + 1; 43 | 44 | // update terminal status 45 | int some_bit = ram::read(s, 0xA6) & 0x40; 46 | s.tiaFlags.template change(some_bit != 0); 47 | } 48 | /* Returns true for the actions listed below (the minimal action set), false otherwise. */ 49 | CULE_ANNOTATION 50 | bool isMinimal(const Action &a) 51 | { 52 | switch (a) 53 | { 54 | case ACTION_NOOP: 55 | case ACTION_FIRE: 56 | case ACTION_UP: 57 | case ACTION_RIGHT: 58 | case ACTION_LEFT: 59 | case ACTION_DOWN: 60 | case ACTION_UPRIGHT: 61 | case ACTION_UPLEFT: 62 | case ACTION_DOWNRIGHT: 63 | case ACTION_DOWNLEFT: 64 | case ACTION_UPFIRE: 65 | case ACTION_RIGHTFIRE: 66 | case ACTION_LEFTFIRE: 67 | case ACTION_DOWNFIRE: 68 | case ACTION_UPRIGHTFIRE: 69 | case ACTION_UPLEFTFIRE: 70 | case ACTION_DOWNRIGHTFIRE: 71 | case ACTION_DOWNLEFTFIRE: 72 | return true; 73 | default: 74 | return false; 75 | } 76 | } 77 | /* Lives are bits 4-6 of the byte at RAM 0xED plus one, giving a value from 1 to 8. */ 78 | template 79 | CULE_ANNOTATION 80 | int32_t lives(State& s) 81 | { 82 | // Maximum of 8 lives 83 | return ((cule::atari::ram::read(s.ram, 0xED) >> 4) & 0x7) + 1; 84 | } 85 | /* Terminal when bit 0x40 of the byte at RAM 0xA6 is set. */ 86 | template 87 | CULE_ANNOTATION 88 | void setTerminal(State& s) 89 | { 90 | // update terminal status 91 | int some_bit = cule::atari::ram::read(s.ram, 0xA6) & 0x40; 92 | s.tiaFlags.template change(some_bit != 0); 93 | } 94 | 95 | template 96 | CULE_ANNOTATION 97 | int32_t score(State& s) 98 | { 99 | return 
cule::atari::games::getDecimalScore(s, 118, 117, 116); 100 | } 101 | /* Reward is the score gained since s.score; negative differences are reported as zero. */ 102 | template 103 | CULE_ANNOTATION 104 | int32_t reward(State& s) 105 | { 106 | int32_t m_reward = score(s) - s.score; 107 | 108 | // HACK: the score sometimes gets reset before termination; ignoring for now. 109 | if (m_reward < 0) 110 | m_reward = 0; 111 | 112 | return m_reward; 113 | } 114 | 115 | } // end namespace centipede 116 | } // end namespace games 117 | } // end namespace atari 118 | } // end namespace cule 119 | 120 | -------------------------------------------------------------------------------- /cule/atari/games/chopper.hpp: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | #include 5 | #include 6 | 7 | namespace cule 8 | { 9 | namespace atari 10 | { 11 | namespace games 12 | { 13 | namespace chopper 14 | { 15 | /* Clears reward, score and the terminal flag, and sets the lives counter to 3. */ 16 | template 17 | CULE_ANNOTATION 18 | void reset(State& s) 19 | { 20 | s.m_reward = 0; 21 | s.m_score = 0; 22 | s.tiaFlags.clear(FLAG_ALE_TERMINAL); 23 | s.m_lives = 3; 24 | } 25 | /* Returns true for the actions listed below (the minimal action set), false otherwise. */ 26 | CULE_ANNOTATION 27 | bool isMinimal(const Action &a) 28 | { 29 | switch (a) 30 | { 31 | case ACTION_NOOP: 32 | case ACTION_FIRE: 33 | case ACTION_UP: 34 | case ACTION_RIGHT: 35 | case ACTION_LEFT: 36 | case ACTION_DOWN: 37 | case ACTION_UPRIGHT: 38 | case ACTION_UPLEFT: 39 | case ACTION_DOWNRIGHT: 40 | case ACTION_DOWNLEFT: 41 | case ACTION_UPFIRE: 42 | case ACTION_RIGHTFIRE: 43 | case ACTION_LEFTFIRE: 44 | case ACTION_DOWNFIRE: 45 | case ACTION_UPRIGHTFIRE: 46 | case ACTION_UPLEFTFIRE: 47 | case ACTION_DOWNRIGHTFIRE: 48 | case ACTION_DOWNLEFTFIRE: 49 | return true; 50 | default: 51 | return false; 52 | } 53 | } 54 | /* Lives are the lower nibble of the byte at RAM 0xE4. */ 55 | template 56 | CULE_ANNOTATION 57 | int32_t lives(State& s) 58 | { 59 | return cule::atari::ram::read(s.ram, 0xE4) & 0xF; 60 | } 61 | /* Sets the terminal flag exactly when lives(s) is zero. */ 62 | template 63 | CULE_ANNOTATION 64 | void setTerminal(State& s) 65 | { 66 | // update terminal status 67 | s.tiaFlags.template change(lives(s) == 0); 68 | } 69 | 70 | 
template 71 | CULE_ANNOTATION 72 | int32_t score(State& s) 73 | { 74 | return 100 * cule::atari::games::getDecimalScore(s, 0xEE, 0xEC); 75 | } 76 | /* Reward is the current score minus the score stored in s.score. */ 77 | template 78 | CULE_ANNOTATION 79 | int32_t reward(State& s) 80 | { 81 | return score(s) - s.score; 82 | } 83 | 84 | } // end namespace chopper 85 | } // end namespace games 86 | } // end namespace atari 87 | } // end namespace cule 88 | 89 | -------------------------------------------------------------------------------- /cule/atari/games/crazyclimber.hpp: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | #include 5 | #include 6 | 7 | namespace cule 8 | { 9 | namespace atari 10 | { 11 | namespace games 12 | { 13 | namespace crazyclimber 14 | { 15 | /* Clears reward, score and the terminal flag, and sets the lives counter to 5. */ 16 | template 17 | CULE_ANNOTATION 18 | void reset(State& s) 19 | { 20 | s.m_reward = 0; 21 | s.m_score = 0; 22 | s.tiaFlags.clear(FLAG_ALE_TERMINAL); 23 | s.m_lives = 5; 24 | } 25 | /* Refreshes reward (clamped at zero), score, lives and the terminal flag from RAM. */ 26 | template 27 | CULE_ANNOTATION 28 | void step(State& s) 29 | { 30 | using cule::atari::games::getDecimalScore; 31 | using cule::atari::ram::read; 32 | 33 | // update the reward 34 | int score = 0; 35 | int digit = ram::read(s, 0x82); 36 | score += digit; 37 | digit = ram::read(s, 0x83); 38 | score += 10 * digit; 39 | digit = ram::read(s, 0x84); 40 | score += 100 * digit; 41 | digit = ram::read(s, 0x85); 42 | score += 1000 * digit; 43 | score *= 100; 44 | s.m_reward = score - s.m_score; 45 | if (s.m_reward < 0) s.m_reward = 0; 46 | s.m_score = score; 47 | 48 | // update terminal status 49 | s.m_lives = ram::read(s, 0xAA); 50 | s.tiaFlags.template change(s.m_lives == 0); 51 | } 52 | /* Returns true for the actions listed below (the minimal action set), false otherwise. */ 53 | CULE_ANNOTATION 54 | bool isMinimal(const Action &a) 55 | { 56 | switch (a) 57 | { 58 | case ACTION_NOOP: 59 | case ACTION_UP: 60 | case ACTION_RIGHT: 61 | case ACTION_LEFT: 62 | case ACTION_DOWN: 63 | case ACTION_UPRIGHT: 64 | case ACTION_UPLEFT: 65 | case ACTION_DOWNRIGHT: 66 | case ACTION_DOWNLEFT: 67 | return true; 68 | default: 69 | 
return false; 70 | } 71 | } 72 | 73 | template 74 | CULE_ANNOTATION 75 | int32_t lives(State& s) 76 | { 77 | return cule::atari::ram::read(s.ram, 0xAA); 78 | } 79 | 80 | template 81 | CULE_ANNOTATION 82 | void setTerminal(State& s) 83 | { 84 | // update terminal status 85 | s.tiaFlags.template change(lives(s) == 0); 86 | } 87 | /* Combines the bytes at RAM 0x82..0x85, weighted 1, 10, 100 and 1000, then multiplies the sum by 100. */ 88 | template 89 | CULE_ANNOTATION 90 | int32_t score(State& s) 91 | { 92 | int32_t m_score = 0; 93 | int32_t digit = cule::atari::ram::read(s.ram, 0x82); 94 | m_score += digit; 95 | digit = cule::atari::ram::read(s.ram, 0x83); 96 | m_score += 10 * digit; 97 | digit = cule::atari::ram::read(s.ram, 0x84); 98 | m_score += 100 * digit; 99 | digit = cule::atari::ram::read(s.ram, 0x85); 100 | m_score += 1000 * digit; 101 | m_score *= 100; 102 | 103 | return m_score; 104 | } 105 | /* Reward is the score gained since s.score; negative differences are reported as zero. */ 106 | template 107 | CULE_ANNOTATION 108 | int32_t reward(State& s) 109 | { 110 | int32_t m_reward = score(s) - s.score; 111 | 112 | if(m_reward < 0) 113 | m_reward = 0; 114 | 115 | return m_reward; 116 | } 117 | 118 | } // end namespace crazyclimber 119 | } // end namespace games 120 | } // end namespace atari 121 | } // end namespace cule 122 | 123 | -------------------------------------------------------------------------------- /cule/atari/games/defender.hpp: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | #include 5 | #include 6 | 7 | namespace cule 8 | { 9 | namespace atari 10 | { 11 | namespace games 12 | { 13 | namespace defender 14 | { 15 | /* Clears reward, score and the terminal flag, and sets the lives counter to 3. */ 16 | template 17 | CULE_ANNOTATION 18 | void reset(State& s) 19 | { 20 | s.m_reward = 0; 21 | s.m_score = 0; 22 | s.tiaFlags.clear(FLAG_ALE_TERMINAL); 23 | s.m_lives = 3; 24 | } 25 | /* Refreshes reward, score, lives and the terminal flag from the current RAM contents. */ 26 | template 27 | CULE_ANNOTATION 28 | void step(State& s) 29 | { 30 | using cule::atari::games::getDecimalScore; 31 | using cule::atari::ram::read; 32 | 33 | // update the reward 34 | int mult = 1, score = 0; 35 | for (int digit = 0; digit < 6; digit++) 36 | { 37 | int v = 
ram::read(s, 0x9C + digit) & 0xF; 38 | // A indicates a 0 which we don't display 39 | if (v == 0xA) v = 0; 40 | score += v * mult; 41 | mult *= 10; 42 | } 43 | s.m_reward = score - s.m_score; 44 | s.m_score = score; 45 | 46 | // update terminal status 47 | s.m_lives = ram::read(s, 0xC2); 48 | s.tiaFlags.template change(s.m_lives == 0); 49 | } 50 | /* Returns true for the actions listed below (the minimal action set), false otherwise. */ 51 | CULE_ANNOTATION 52 | bool isMinimal(const Action &a) 53 | { 54 | switch (a) 55 | { 56 | case ACTION_NOOP: 57 | case ACTION_FIRE: 58 | case ACTION_UP: 59 | case ACTION_RIGHT: 60 | case ACTION_LEFT: 61 | case ACTION_DOWN: 62 | case ACTION_UPRIGHT: 63 | case ACTION_UPLEFT: 64 | case ACTION_DOWNRIGHT: 65 | case ACTION_DOWNLEFT: 66 | case ACTION_UPFIRE: 67 | case ACTION_RIGHTFIRE: 68 | case ACTION_LEFTFIRE: 69 | case ACTION_DOWNFIRE: 70 | case ACTION_UPRIGHTFIRE: 71 | case ACTION_UPLEFTFIRE: 72 | case ACTION_DOWNRIGHTFIRE: 73 | case ACTION_DOWNLEFTFIRE: 74 | return true; 75 | default: 76 | return false; 77 | } 78 | } 79 | /* Lives are the byte at RAM 0xC2. */ 80 | template 81 | CULE_ANNOTATION 82 | int32_t lives(State& s) 83 | { 84 | return cule::atari::ram::read(s.ram, 0xC2); 85 | } 86 | /* Sets the terminal flag exactly when lives(s) is zero. */ 87 | template 88 | CULE_ANNOTATION 89 | void setTerminal(State& s) 90 | { 91 | s.tiaFlags.template change(lives(s) == 0); 92 | } 93 | /* Builds the score from six digits: the lower nibbles of RAM 0x9C..0xA1, least significant first. */ 94 | template 95 | CULE_ANNOTATION 96 | int32_t score(State& s) 97 | { 98 | int mult = 1, m_score = 0; 99 | for (int digit = 0; digit < 6; digit++) 100 | { 101 | int v = cule::atari::ram::read(s.ram, 0x9C + digit) & 0xF; 102 | // a nibble value of 0xA indicates a 0 which we don't display 103 | if (v == 0xA) v = 0; 104 | m_score += v * mult; 105 | mult *= 10; 106 | } 107 | 108 | return m_score; 109 | } 110 | /* Reward is the current score minus the score stored in s.score. */ 111 | template 112 | CULE_ANNOTATION 113 | int32_t reward(State& s) 114 | { 115 | return score(s) - s.score; 116 | } 117 | 118 | } // end namespace defender 119 | } // end namespace games 120 | } // end namespace atari 121 | } // end namespace cule 122 | 123 | -------------------------------------------------------------------------------- 
/cule/atari/games/demonattack.hpp: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | #include 5 | #include 6 | 7 | namespace cule 8 | { 9 | namespace atari 10 | { 11 | namespace games 12 | { 13 | namespace demonattack 14 | { 15 | /* Clears reward, score and the terminal flag, and sets the lives counter to 4. */ 16 | template 17 | CULE_ANNOTATION 18 | void reset(State& s) 19 | { 20 | s.m_reward = 0; 21 | s.m_score = 0; 22 | s.tiaFlags.clear(FLAG_ALE_TERMINAL); 23 | s.m_lives = 4; 24 | } 25 | /* Refreshes reward, score, lives and the terminal flag from the current RAM contents. */ 26 | template 27 | CULE_ANNOTATION 28 | void step(State& s) 29 | { 30 | using cule::atari::games::getDecimalScore; 31 | using cule::atari::ram::read; 32 | 33 | // update the reward 34 | int score = getDecimalScore(s, 0x85, 0x83, 0x81); 35 | 36 | // MGB: something funny with the RAM; it is not initialized to 0? 37 | if (ram::read(s, 0x81) == 0xAB && 38 | ram::read(s, 0x83) == 0xCD && 39 | ram::read(s, 0x85) == 0xEA) 40 | { 41 | score = 0; 42 | } 43 | s.m_reward = score - s.m_score; 44 | s.m_score = score; 45 | 46 | // update terminal status 47 | int lives_displayed = ram::read(s, 0xF2); 48 | int display_flag = ram::read(s, 0xF1); 49 | s.tiaFlags.template change((lives_displayed == 0) && (display_flag == 0xBD)); 50 | s.m_lives = lives_displayed + 1; // stored count is one more than the displayed count; it is not reset to 0 on terminal 51 | } 52 | /* Returns true for the actions listed below (the minimal action set), false otherwise. */ 53 | CULE_ANNOTATION 54 | bool isMinimal(const Action &a) 55 | { 56 | switch (a) 57 | { 58 | case ACTION_NOOP: 59 | case ACTION_FIRE: 60 | case ACTION_RIGHT: 61 | case ACTION_LEFT: 62 | case ACTION_RIGHTFIRE: 63 | case ACTION_LEFTFIRE: 64 | return true; 65 | default: 66 | return false; 67 | } 68 | } 69 | 70 | template 71 | CULE_ANNOTATION 72 | int32_t lives(State& s) 73 | { 74 | return cule::atari::ram::read(s.ram, 0xF2) + 1; // byte at RAM 0xF2 plus one; no terminal check is applied here 75 | } 76 | 77 | template 78 | CULE_ANNOTATION 79 | void setTerminal(State& s) 80 | { 81 | // update terminal status 82 | int lives_displayed = lives(s) - 1; 83 | int display_flag = 
cule::atari::ram::read(s.ram, 0xF1); 84 | s.tiaFlags.template change((lives_displayed == 0) && (display_flag == 0xBD)); 85 | } 86 | /* Six-digit decimal score from RAM 0x85, 0x83 and 0x81; the byte pattern 0xAB/0xCD/0xEA is reported as 0. */ 87 | template 88 | CULE_ANNOTATION 89 | int32_t score(State& s) 90 | { 91 | int32_t m_score = cule::atari::games::getDecimalScore(s, 0x85, 0x83, 0x81); 92 | 93 | // MGB: something funny with the RAM; it is not initialized to 0? 94 | if (cule::atari::ram::read(s.ram, 0x81) == 0xAB && 95 | cule::atari::ram::read(s.ram, 0x83) == 0xCD && 96 | cule::atari::ram::read(s.ram, 0x85) == 0xEA) 97 | { 98 | m_score = 0; 99 | } 100 | 101 | return m_score; 102 | } 103 | /* Reward is the current score minus the score stored in s.score. */ 104 | template 105 | CULE_ANNOTATION 106 | int32_t reward(State& s) 107 | { 108 | return score(s) - s.score; 109 | } 110 | 111 | } // end namespace demonattack 112 | } // end namespace games 113 | } // end namespace atari 114 | } // end namespace cule 115 | 116 | -------------------------------------------------------------------------------- /cule/atari/games/detail/attributes.hpp: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | 5 | namespace cule 6 | { 7 | namespace atari 8 | { 9 | namespace games 10 | { 11 | /* One enumerator per game listed here, stored in a uint8_t. */ 12 | enum GAME_TYPE : uint8_t 13 | { 14 | GAME_BOWLING, 15 | GAME_BOXING, 16 | GAME_BREAKOUT, 17 | GAME_FISHING_DERBY, 18 | GAME_FREEWAY, 19 | GAME_KABOOM, 20 | GAME_PONG, 21 | GAME_SKIING, 22 | GAME_TENNIS, 23 | 24 | GAME_ADVENTURE, 25 | GAME_AIR_RAID, 26 | GAME_ALIEN, 27 | GAME_AMIDAR, 28 | GAME_ASSAULT, 29 | GAME_ATLANTIS, 30 | GAME_BANK_HEIST, 31 | GAME_BERZERK, 32 | GAME_CARNIVAL, 33 | GAME_CHOPPER, 34 | GAME_DEFENDER, 35 | GAME_DEMON_ATTACK, 36 | GAME_ENDURO, 37 | GAME_FROSTBITE, 38 | GAME_GOPHER, 39 | GAME_ICE_HOCKEY, 40 | GAME_JOURNEY_ESCAPE, 41 | GAME_NAME_THIS_GAME, 42 | GAME_PITFALL, 43 | GAME_POOYAN, 44 | GAME_QBERT, 45 | GAME_RIVERRAID, 46 | GAME_SEAQUEST, 47 | GAME_SPACE_INVADERS, 48 | GAME_STAR_GUNNER, 49 | GAME_VENTURE, 50 | GAME_PINBALL, 51 | GAME_WIZARD, 52 | GAME_YARS_REVENGE, 53 | 54 | 
GAME_ASTERIX, 55 | GAME_ASTEROIDS, 56 | GAME_BATTLE_ZONE, 57 | GAME_BEAM_RIDER, 58 | GAME_CENTIPEDE, 59 | GAME_CRAZY_CLIMBER, 60 | GAME_ELEVATOR_ACTION, 61 | GAME_GRAVITAR, 62 | GAME_HERO, 63 | GAME_JAMESBOND, 64 | GAME_KANGAROO, 65 | GAME_KRULL, 66 | GAME_KUNG_FU_MASTER, 67 | GAME_MONTEZUMA_REVENGE, 68 | GAME_MS_PACMAN, 69 | GAME_PHOENIX, 70 | GAME_PRIVATE_EYE, 71 | GAME_ROBOTANK, 72 | GAME_TIME_PILOT, 73 | GAME_TUTANKHAM, 74 | GAME_UP_N_DOWN, 75 | GAME_ZAXXON, 76 | 77 | GAME_DOUBLE_DUNK, 78 | GAME_ROAD_RUNNER, 79 | GAME_SOLARIS, 80 | }; 81 | /* ROM attribute indices; _ROM_ATTR_MAX follows the last attribute and so equals their count. */ 82 | enum ROM_ATTR 83 | { 84 | ROM_ATTR_Manufacturer, 85 | ROM_ATTR_ModelNo, 86 | ROM_ATTR_Name, 87 | ROM_ATTR_Note, 88 | ROM_ATTR_Rarity, 89 | ROM_ATTR_Sound, 90 | ROM_ATTR_Type, 91 | ROM_ATTR_LeftDifficulty, 92 | ROM_ATTR_RightDifficulty, 93 | ROM_ATTR_TelevisionType, 94 | ROM_ATTR_SwapPorts, 95 | ROM_ATTR_ControllerLeft, 96 | ROM_ATTR_ControllerRight, 97 | ROM_ATTR_SwapPaddles, 98 | ROM_ATTR_Format, 99 | ROM_ATTR_YStart, 100 | ROM_ATTR_Height, 101 | ROM_ATTR_Phosphor, 102 | ROM_ATTR_PPBlend, 103 | ROM_ATTR_HmoveBlanks, 104 | _ROM_ATTR_MAX 105 | }; 106 | 107 | } // end namespace games 108 | } // end namespace atari 109 | } // end namespace cule 110 | 111 | -------------------------------------------------------------------------------- /cule/atari/games/detail/utils.hpp: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | 5 | #include 6 | #include 7 | 8 | namespace cule 9 | { 10 | namespace atari 11 | { 12 | namespace games 13 | { 14 | /* Decodes one RAM byte as two decimal digits: upper nibble = tens, lower nibble = units. */ 15 | template 16 | CULE_ANNOTATION 17 | int getDecimalScore(State& s, const int idx) 18 | { 19 | int score = 0; 20 | int digits_val = ram::read(s.ram, idx); 21 | int right_digit = digits_val & 15; 22 | int left_digit = digits_val >> 4; 23 | score += ((10 * left_digit) + right_digit); 24 | 25 | return score; 26 | } 27 | /* Two-byte variant: lower_index gives units/tens, higher_index gives hundreds/thousands; a negative higher_index is skipped. */ 28 | template 29 | CULE_ANNOTATION 30 | int getDecimalScore(State& s, int lower_index, int higher_index) 31 | 
{ 32 | int score = 0; 33 | int lower_digits_val = ram::read(s.ram, lower_index); 34 | int lower_right_digit = lower_digits_val & 15; 35 | int lower_left_digit = (lower_digits_val - lower_right_digit) >> 4; 36 | score += ((10 * lower_left_digit) + lower_right_digit); 37 | /* a negative higher_index means only the lower byte is decoded */ 38 | if (higher_index < 0) 39 | { 40 | return score; 41 | } 42 | 43 | int higher_digits_val = ram::read(s.ram, higher_index); 44 | int higher_right_digit = higher_digits_val & 15; 45 | int higher_left_digit = (higher_digits_val - higher_right_digit) >> 4; 46 | score += ((1000 * higher_left_digit) + 100 * higher_right_digit); 47 | 48 | return score; 49 | } 50 | /* Three-byte variant: adds the ten-thousands and hundred-thousands digits read from higher_index. */ 51 | template 52 | CULE_ANNOTATION 53 | int getDecimalScore(State& s, int lower_index, int middle_index, int higher_index) 54 | { 55 | int score = getDecimalScore(s, lower_index, middle_index); 56 | int higher_digits_val = ram::read(s.ram, higher_index); 57 | int higher_right_digit = higher_digits_val & 15; 58 | int higher_left_digit = (higher_digits_val - higher_right_digit) >> 4; 59 | score += ((100000 * higher_left_digit) + 10000 * higher_right_digit); 60 | 61 | return score; 62 | } 63 | 64 | } // end namespace games 65 | } // end namespace atari 66 | } // end namespace cule 67 | 68 | -------------------------------------------------------------------------------- /cule/atari/games/doubledunk.hpp: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | #include 5 | #include 6 | #include 7 | 8 | namespace cule 9 | { 10 | namespace atari 11 | { 12 | namespace games 13 | { 14 | namespace doubledunk 15 | { 16 | /* Clears reward, score and the terminal flag. */ 17 | template 18 | CULE_ANNOTATION 19 | void reset(State& s) 20 | { 21 | s.m_reward = 0; 22 | s.m_score = 0; 23 | s.tiaFlags.clear(FLAG_ALE_TERMINAL); 24 | } 25 | /* Refreshes reward, score and the terminal flag from the current RAM contents. */ 26 | template 27 | CULE_ANNOTATION 28 | void step(State& s) 29 | { 30 | using cule::atari::games::getDecimalScore; 31 | using cule::atari::ram::read; 32 | 33 | // update the reward 34 | int my_score = 
getDecimalScore(s, 0xF6); 35 | int oppt_score = getDecimalScore(s, 0xF7); 36 | int score = my_score - oppt_score; 37 | s.m_reward = score - s.m_score; 38 | s.m_score = score; 39 | 40 | // update terminal status 41 | int some_value = ram::read(s, 0xFE); 42 | s.tiaFlags.template change((my_score >= 24 || oppt_score >= 24) && (some_value == 0xE7)); 43 | } 44 | /* Returns true for the actions listed below (the minimal action set), false otherwise. */ 45 | CULE_ANNOTATION 46 | bool isMinimal(const Action &a) 47 | { 48 | switch (a) 49 | { 50 | case ACTION_NOOP: 51 | case ACTION_FIRE: 52 | case ACTION_UP: 53 | case ACTION_RIGHT: 54 | case ACTION_LEFT: 55 | case ACTION_DOWN: 56 | case ACTION_UPRIGHT: 57 | case ACTION_UPLEFT: 58 | case ACTION_DOWNRIGHT: 59 | case ACTION_DOWNLEFT: 60 | case ACTION_UPFIRE: 61 | case ACTION_RIGHTFIRE: 62 | case ACTION_LEFTFIRE: 63 | case ACTION_DOWNFIRE: 64 | case ACTION_UPRIGHTFIRE: 65 | case ACTION_UPLEFTFIRE: 66 | case ACTION_DOWNRIGHTFIRE: 67 | case ACTION_DOWNLEFTFIRE: 68 | return true; 69 | default: 70 | return false; 71 | } 72 | } 73 | /* No lives are tracked for this game; always returns 0. */ 74 | template 75 | CULE_ANNOTATION 76 | int32_t lives(State&) 77 | { 78 | return 0; 79 | } 80 | /* Terminal when either side has at least 24 points and the byte at RAM 0xFE reads 0xE7. */ 81 | template 82 | CULE_ANNOTATION 83 | void setTerminal(State& s) 84 | { 85 | int my_score = cule::atari::games::getDecimalScore(s, 0xF6); 86 | int oppt_score = cule::atari::games::getDecimalScore(s, 0xF7); 87 | 88 | // update terminal status 89 | int some_value = cule::atari::ram::read(s.ram, 0xFE); 90 | s.tiaFlags.template change((my_score >= 24 || oppt_score >= 24) && (some_value == 0xE7)); 91 | } 92 | /* Score is the points decoded from RAM 0xF6 (my_score) minus those from RAM 0xF7 (oppt_score). */ 93 | template 94 | CULE_ANNOTATION 95 | int32_t score(State& s) 96 | { 97 | int my_score = cule::atari::games::getDecimalScore(s, 0xF6); 98 | int oppt_score = cule::atari::games::getDecimalScore(s, 0xF7); 99 | return my_score - oppt_score; 100 | } 101 | /* Reward is the current score minus the score stored in s.score. */ 102 | template 103 | CULE_ANNOTATION 104 | int32_t reward(State& s) 105 | { 106 | return score(s) - s.score; 107 | } 108 | 109 | } // end namespace doubledunk 110 | } // end namespace games 111 | } // end namespace atari 112 | } // end namespace 
cule 113 | 114 | -------------------------------------------------------------------------------- /cule/atari/games/elevatoraction.hpp: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | #include 5 | #include 6 | #include 7 | 8 | namespace cule 9 | { 10 | namespace atari 11 | { 12 | namespace games 13 | { 14 | namespace elevatoraction 15 | { 16 | /* Clears reward, score and the terminal flag, and sets the lives counter to 4. */ 17 | template 18 | CULE_ANNOTATION 19 | void reset(State& s) 20 | { 21 | s.m_reward = 0; 22 | s.m_score = 0; 23 | s.tiaFlags.clear(FLAG_ALE_TERMINAL); 24 | s.m_lives = 4; 25 | } 26 | /* Refreshes reward, score, lives and the terminal flag from the current RAM contents. */ 27 | template 28 | CULE_ANNOTATION 29 | void step(State& s) 30 | { 31 | using cule::atari::games::getDecimalScore; 32 | using cule::atari::ram::read; 33 | 34 | // update the reward 35 | int score = getDecimalScore(s, 0x89, 0x88, 0x87); 36 | s.m_reward = score - s.m_score; 37 | s.m_score = score; 38 | 39 | // update terminal status: zero lives only counts when RAM 0x81 is non-zero 40 | s.m_lives = ram::read(s, 0x83); 41 | const bool is_start_screen = ram::read(s, 0x81) == 0x00; 42 | s.tiaFlags.template change((s.m_lives == 0) && !is_start_screen); 43 | } 44 | /* Returns true for the actions listed below (the minimal action set), false otherwise. */ 45 | CULE_ANNOTATION 46 | bool isMinimal(const Action &a) 47 | { 48 | switch (a) 49 | { 50 | case ACTION_NOOP: 51 | case ACTION_FIRE: 52 | case ACTION_UP: 53 | case ACTION_RIGHT: 54 | case ACTION_LEFT: 55 | case ACTION_DOWN: 56 | case ACTION_UPRIGHT: 57 | case ACTION_UPLEFT: 58 | case ACTION_DOWNRIGHT: 59 | case ACTION_DOWNLEFT: 60 | case ACTION_UPFIRE: 61 | case ACTION_RIGHTFIRE: 62 | case ACTION_LEFTFIRE: 63 | case ACTION_DOWNFIRE: 64 | case ACTION_UPRIGHTFIRE: 65 | case ACTION_UPLEFTFIRE: 66 | case ACTION_DOWNRIGHTFIRE: 67 | case ACTION_DOWNLEFTFIRE: 68 | return true; 69 | default: 70 | return false; 71 | } 72 | } 73 | template 74 | CULE_ANNOTATION 75 | int32_t lives(State& s) 76 | { 77 | return cule::atari::ram::read(s.ram, 0x83); 78 | } 79 | /* Terminal when lives(s) is zero and the byte at RAM 0x81 is non-zero. */ 80 | template 81 | CULE_ANNOTATION 82 | void setTerminal(State& s) 83 | { 84 | // update terminal status 85 | bool is_start_screen = 
cule::atari::ram::read(s.ram, 0x81) == 0x00; 86 | s.tiaFlags.template change((lives(s) == 0) && !is_start_screen); 87 | } 88 | 89 | template 90 | CULE_ANNOTATION 91 | int32_t score(State& s) 92 | { 93 | return cule::atari::games::getDecimalScore(s, 0x89, 0x88, 0x87); 94 | } 95 | 96 | template 97 | CULE_ANNOTATION 98 | int32_t reward(State& s) 99 | { 100 | return score(s) - s.score; 101 | } 102 | 103 | } // end namespace elevatoraction 104 | } // end namespace games 105 | } // end namespace atari 106 | } // end namespace cule 107 | 108 | -------------------------------------------------------------------------------- /cule/atari/games/fishingderby.hpp: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | #include 5 | #include 6 | 7 | namespace cule 8 | { 9 | namespace atari 10 | { 11 | namespace games 12 | { 13 | namespace fishingderby 14 | { 15 | 16 | template 17 | CULE_ANNOTATION 18 | void reset(State& s) 19 | { 20 | s.m_reward = 0; 21 | s.m_score = 0; 22 | s.tiaFlags.clear(FLAG_ALE_TERMINAL); 23 | } 24 | template 25 | CULE_ANNOTATION 26 | void step(State& s) 27 | { 28 | using cule::atari::games::getDecimalScore; 29 | using cule::atari::ram::read; 30 | 31 | // update the reward 32 | int my_score = max(getDecimalScore(s, 0xBD), 0); 33 | int oppt_score = max(getDecimalScore(s, 0xBE), 0); 34 | int score = my_score - oppt_score; 35 | s.m_reward = score - s.m_score; 36 | s.m_score = score; 37 | 38 | // update terminal status 39 | int my_score_byte = ram::read(s, 0xBD); 40 | int my_oppt_score_byte = ram::read(s, 0xBE); 41 | 42 | s.tiaFlags.template change((my_score_byte == 0x99) || (my_oppt_score_byte == 0x99)); 43 | } 44 | 45 | CULE_ANNOTATION 46 | bool isMinimal(const Action &a) 47 | { 48 | switch (a) 49 | { 50 | case ACTION_NOOP: 51 | case ACTION_FIRE: 52 | case ACTION_UP: 53 | case ACTION_RIGHT: 54 | case ACTION_LEFT: 55 | case ACTION_DOWN: 56 | case ACTION_UPRIGHT: 57 | case ACTION_UPLEFT: 58 | case 
ACTION_DOWNRIGHT: 59 | case ACTION_DOWNLEFT: 60 | case ACTION_UPFIRE: 61 | case ACTION_RIGHTFIRE: 62 | case ACTION_LEFTFIRE: 63 | case ACTION_DOWNFIRE: 64 | case ACTION_UPRIGHTFIRE: 65 | case ACTION_UPLEFTFIRE: 66 | case ACTION_DOWNRIGHTFIRE: 67 | case ACTION_DOWNLEFTFIRE: 68 | return true; 69 | default: 70 | return false; 71 | } 72 | } 73 | 74 | template 75 | CULE_ANNOTATION 76 | int32_t lives(State&) 77 | { 78 | return 0; 79 | } 80 | 81 | template 82 | CULE_ANNOTATION 83 | void setTerminal(State& s) 84 | { 85 | // update terminal status 86 | int my_score_byte = cule::atari::ram::read(s.ram, 0xBD); 87 | int my_oppt_score_byte = cule::atari::ram::read(s.ram, 0xBE); 88 | 89 | s.tiaFlags.template change((my_score_byte == 0x99) || (my_oppt_score_byte == 0x99)); 90 | } 91 | 92 | template 93 | CULE_ANNOTATION 94 | int32_t score(State& s) 95 | { 96 | int my_score = max(cule::atari::games::getDecimalScore(s, 0xBD), 0); 97 | int oppt_score = max(cule::atari::games::getDecimalScore(s, 0xBE), 0); 98 | return my_score - oppt_score; 99 | } 100 | 101 | template 102 | CULE_ANNOTATION 103 | int32_t reward(State& s) 104 | { 105 | return score(s) - s.score; 106 | } 107 | 108 | } // end namespace fishingderby 109 | } // end namespace games 110 | } // end namespace atari 111 | } // end namespace cule 112 | 113 | -------------------------------------------------------------------------------- /cule/atari/games/freeway.hpp: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | #include 5 | #include 6 | 7 | namespace cule 8 | { 9 | namespace atari 10 | { 11 | namespace games 12 | { 13 | namespace freeway 14 | { 15 | 16 | template 17 | CULE_ANNOTATION 18 | void reset(State& s) 19 | { 20 | s.m_reward = 0; 21 | s.m_score = 0; 22 | s.tiaFlags.clear(FLAG_ALE_TERMINAL); 23 | } 24 | 25 | template 26 | CULE_ANNOTATION 27 | void step(State& s) 28 | { 29 | using cule::atari::games::getDecimalScore; 30 | using cule::atari::ram::read; 
31 | 32 | // update the reward 33 | int score = getDecimalScore(s, 103, -1); 34 | int reward = score - s.m_score; 35 | if (reward < 0) reward = 0; 36 | if (reward > 1) reward = 1; 37 | s.m_reward = reward; 38 | s.m_score = score; 39 | 40 | // update terminal status 41 | s.tiaFlags.template change(ram::read(s, 22) == 1); 42 | } 43 | 44 | CULE_ANNOTATION 45 | bool isMinimal(const Action &a) 46 | { 47 | switch (a) 48 | { 49 | case ACTION_NOOP: 50 | case ACTION_UP: 51 | case ACTION_DOWN: 52 | return true; 53 | default: 54 | return false; 55 | } 56 | } 57 | 58 | template 59 | CULE_ANNOTATION 60 | int32_t lives(State&) 61 | { 62 | return 0; 63 | } 64 | 65 | template 66 | CULE_ANNOTATION 67 | void setTerminal(State& s) 68 | { 69 | // update terminal status 70 | s.tiaFlags.template change(cule::atari::ram::read(s.ram, 22) == 1); 71 | } 72 | 73 | template 74 | CULE_ANNOTATION 75 | int32_t score(State& s) 76 | { 77 | return cule::atari::games::getDecimalScore(s, 103, -1); 78 | } 79 | 80 | template 81 | CULE_ANNOTATION 82 | int32_t reward(State& s) 83 | { 84 | int32_t m_reward = score(s) - s.score; 85 | 86 | if (m_reward < 0) m_reward = 0; 87 | if (m_reward > 1) m_reward = 1; 88 | 89 | return m_reward; 90 | } 91 | 92 | } // end namespace freeway 93 | } // end namespace games 94 | } // end namespace atari 95 | } // end namespace cule 96 | 97 | -------------------------------------------------------------------------------- /cule/atari/games/frostbite.hpp: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | #include 5 | #include 6 | 7 | namespace cule 8 | { 9 | namespace atari 10 | { 11 | namespace games 12 | { 13 | namespace frostbite 14 | { 15 | 16 | template 17 | CULE_ANNOTATION 18 | void reset(State& s) 19 | { 20 | s.m_reward = 0; 21 | s.m_score = 0; 22 | s.tiaFlags.clear(FLAG_ALE_TERMINAL); 23 | s.m_lives = 4; 24 | } 25 | 26 | template 27 | CULE_ANNOTATION 28 | void step(State& s) 29 | { 30 | using 
cule::atari::games::getDecimalScore; 31 | using cule::atari::ram::read; 32 | 33 | // update the reward 34 | int score = getDecimalScore(s, 0xCA, 0xC9, 0xC8); 35 | int reward = score - s.m_score; 36 | s.m_reward = reward; 37 | s.m_score = score; 38 | 39 | // update terminal status 40 | // MGB: the maximum achievable life is 9. The system will actually let us set the byte to 41 | // higher values & properly decrement, but we do not gain lives beyond 9. 42 | int lives_byte = (ram::read(s, 0xCC) & 0xF); 43 | int flag = ram::read(s, 0xF1) & 0x80; 44 | s.tiaFlags.template change((lives_byte == 0) && (flag != 0)); 45 | 46 | s.m_lives = lives_byte + 1; 47 | } 48 | 49 | CULE_ANNOTATION 50 | bool isMinimal(const Action &a) 51 | { 52 | switch (a) 53 | { 54 | case ACTION_NOOP: 55 | case ACTION_FIRE: 56 | case ACTION_UP: 57 | case ACTION_RIGHT: 58 | case ACTION_LEFT: 59 | case ACTION_DOWN: 60 | case ACTION_UPRIGHT: 61 | case ACTION_UPLEFT: 62 | case ACTION_DOWNRIGHT: 63 | case ACTION_DOWNLEFT: 64 | case ACTION_UPFIRE: 65 | case ACTION_RIGHTFIRE: 66 | case ACTION_LEFTFIRE: 67 | case ACTION_DOWNFIRE: 68 | case ACTION_UPRIGHTFIRE: 69 | case ACTION_UPLEFTFIRE: 70 | case ACTION_DOWNRIGHTFIRE: 71 | case ACTION_DOWNLEFTFIRE: 72 | return true; 73 | default: 74 | return false; 75 | } 76 | } 77 | 78 | template 79 | CULE_ANNOTATION 80 | int32_t lives(State& s) 81 | { 82 | return (cule::atari::ram::read(s.ram, 0xCC) & 0xF) + 1; 83 | } 84 | 85 | template 86 | CULE_ANNOTATION 87 | void setTerminal(State& s) 88 | { 89 | // update terminal status 90 | // MGB: the maximum achievable life is 9. The system will actually let us set the byte to 91 | // higher values & properly decrement, but we do not gain lives beyond 9. 
92 | int lives_byte = lives(s) - 1; 93 | int flag = cule::atari::ram::read(s.ram, 0xF1) & 0x80; 94 | s.tiaFlags.template change((lives_byte == 0) && (flag != 0)); 95 | } 96 | 97 | template 98 | CULE_ANNOTATION 99 | int32_t score(State& s) 100 | { 101 | return cule::atari::games::getDecimalScore(s, 0xCA, 0xC9, 0xC8); 102 | } 103 | 104 | template 105 | CULE_ANNOTATION 106 | int32_t reward(State& s) 107 | { 108 | return score(s) - s.score; 109 | } 110 | 111 | } // end namespace frostbite 112 | } // end namespace games 113 | } // end namespace atari 114 | } // end namespace cule 115 | 116 | -------------------------------------------------------------------------------- /cule/atari/games/gopher.hpp: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | #include 5 | #include 6 | #include 7 | 8 | namespace cule 9 | { 10 | namespace atari 11 | { 12 | namespace games 13 | { 14 | namespace gopher 15 | { 16 | 17 | template 18 | CULE_ANNOTATION 19 | void reset(State& s) 20 | { 21 | s.m_reward = 0; 22 | s.m_score = 0; 23 | s.tiaFlags.clear(FLAG_ALE_TERMINAL); 24 | s.m_lives = 3; 25 | } 26 | 27 | template 28 | CULE_ANNOTATION 29 | void step(State& s) 30 | { 31 | using cule::atari::games::getDecimalScore; 32 | using cule::atari::ram::read; 33 | 34 | // update the reward 35 | int score = getDecimalScore(s, 0xB2, 0xB1, 0xB0); 36 | int reward = score - s.m_score; 37 | s.m_reward = reward; 38 | s.m_score = score; 39 | 40 | // update terminal status 41 | int carrot_bits = ram::read(s, 0xB4) & 0x7; 42 | s.tiaFlags.template change(carrot_bits == 0); 43 | 44 | // A very crude popcount 45 | static int livesFromCarrots[] = { 0, 1, 1, 2, 1, 2, 2, 3}; 46 | s.m_lives = livesFromCarrots[carrot_bits]; 47 | } 48 | 49 | CULE_ANNOTATION 50 | bool isMinimal(const Action &a) 51 | { 52 | switch (a) 53 | { 54 | case ACTION_NOOP: 55 | case ACTION_FIRE: 56 | case ACTION_UP: 57 | case ACTION_RIGHT: 58 | case ACTION_LEFT: 59 | case 
ACTION_UPFIRE: 60 | case ACTION_RIGHTFIRE: 61 | case ACTION_LEFTFIRE: 62 | return true; 63 | default: 64 | return false; 65 | } 66 | } 67 | 68 | template 69 | CULE_ANNOTATION 70 | int32_t lives(State& s) 71 | { 72 | // update terminal status 73 | // A very crude popcount 74 | uint32_t carrot_bits = cule::atari::ram::read(s.ram, 0xB4) & 0x7; 75 | static int livesFromCarrots[] = { 0, 1, 1, 2, 1, 2, 2, 3}; 76 | return livesFromCarrots[carrot_bits]; 77 | } 78 | 79 | template 80 | CULE_ANNOTATION 81 | void setTerminal(State& s) 82 | { 83 | // update terminal status 84 | int carrot_bits = cule::atari::ram::read(s.ram, 0xB4) & 0x7; 85 | s.tiaFlags.template change(carrot_bits == 0); 86 | } 87 | 88 | template 89 | CULE_ANNOTATION 90 | int32_t score(State& s) 91 | { 92 | return cule::atari::games::getDecimalScore(s, 0xB2, 0xB1, 0xB0); 93 | } 94 | 95 | template 96 | CULE_ANNOTATION 97 | int32_t reward(State& s) 98 | { 99 | return score(s) - s.score; 100 | } 101 | 102 | } // end namespace gopher 103 | } // end namespace games 104 | } // end namespace atari 105 | } // end namespace cule 106 | 107 | -------------------------------------------------------------------------------- /cule/atari/games/hero.hpp: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | #include 5 | #include 6 | 7 | namespace cule 8 | { 9 | namespace atari 10 | { 11 | namespace games 12 | { 13 | namespace hero 14 | { 15 | 16 | template 17 | CULE_ANNOTATION 18 | void reset(State& s) 19 | { 20 | s.m_reward = 0; 21 | s.m_score = 0; 22 | s.tiaFlags.clear(FLAG_ALE_TERMINAL); 23 | s.m_lives = 4; 24 | } 25 | 26 | template 27 | CULE_ANNOTATION 28 | void step(State& s) 29 | { 30 | using cule::atari::games::getDecimalScore; 31 | using cule::atari::ram::read; 32 | 33 | // update the reward 34 | int score = getDecimalScore(s, 0xB9, 0xB8, 0xB7); 35 | int reward = score - s.m_score; 36 | s.m_reward = reward; 37 | s.m_score = score; 38 | 39 | // update 
terminal status 40 | s.m_lives = ram::read(s, 0xB3); 41 | s.tiaFlags.template change(s.m_lives == 0); 42 | } 43 | 44 | CULE_ANNOTATION 45 | bool isMinimal(const Action &a) 46 | { 47 | switch (a) 48 | { 49 | case ACTION_NOOP: 50 | case ACTION_FIRE: 51 | case ACTION_UP: 52 | case ACTION_RIGHT: 53 | case ACTION_LEFT: 54 | case ACTION_DOWN: 55 | case ACTION_UPRIGHT: 56 | case ACTION_UPLEFT: 57 | case ACTION_DOWNRIGHT: 58 | case ACTION_DOWNLEFT: 59 | case ACTION_UPFIRE: 60 | case ACTION_RIGHTFIRE: 61 | case ACTION_LEFTFIRE: 62 | case ACTION_DOWNFIRE: 63 | case ACTION_UPRIGHTFIRE: 64 | case ACTION_UPLEFTFIRE: 65 | case ACTION_DOWNRIGHTFIRE: 66 | case ACTION_DOWNLEFTFIRE: 67 | return true; 68 | default: 69 | return false; 70 | } 71 | } 72 | 73 | template 74 | CULE_ANNOTATION 75 | int32_t lives(State& s) 76 | { 77 | return cule::atari::ram::read(s.ram, 0xB3); 78 | } 79 | 80 | template 81 | CULE_ANNOTATION 82 | void setTerminal(State& s) 83 | { 84 | // update terminal status 85 | s.tiaFlags.template change(lives(s) == 0); 86 | } 87 | 88 | template 89 | CULE_ANNOTATION 90 | int32_t score(State& s) 91 | { 92 | return cule::atari::games::getDecimalScore(s, 0xB9, 0xB8, 0xB7); 93 | } 94 | 95 | template 96 | CULE_ANNOTATION 97 | int32_t reward(State& s) 98 | { 99 | return score(s) - s.score; 100 | } 101 | 102 | } // end namespace hero 103 | } // end namespace games 104 | } // end namespace atari 105 | } // end namespace cule 106 | 107 | -------------------------------------------------------------------------------- /cule/atari/games/icehockey.hpp: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | #include 5 | #include 6 | 7 | namespace cule 8 | { 9 | namespace atari 10 | { 11 | namespace games 12 | { 13 | namespace icehockey 14 | { 15 | 16 | template 17 | CULE_ANNOTATION 18 | void reset(State& s) 19 | { 20 | s.m_reward = 0; 21 | s.m_score = 0; 22 | s.tiaFlags.clear(FLAG_ALE_TERMINAL); 23 | } 24 | 25 | template 
26 | CULE_ANNOTATION 27 | void step(State& s) 28 | { 29 | using cule::atari::games::getDecimalScore; 30 | using cule::atari::ram::read; 31 | 32 | // update the reward 33 | int my_score = max(getDecimalScore(s, 0x8A), 0); 34 | int oppt_score = max(getDecimalScore(s, 0x8B), 0); 35 | int score = my_score - oppt_score; 36 | int reward = min(score - s.m_score, 1); 37 | s.m_reward = reward; 38 | s.m_score = score; 39 | 40 | // update terminal status 41 | int minutes = ram::read(s, 0x87); 42 | int seconds = ram::read(s, 0x86); 43 | 44 | // end of game when out of time 45 | s.tiaFlags.template change((minutes == 0) && (seconds == 0)); 46 | } 47 | 48 | CULE_ANNOTATION 49 | bool isMinimal(const Action &a) 50 | { 51 | switch (a) 52 | { 53 | case ACTION_NOOP: 54 | case ACTION_FIRE: 55 | case ACTION_UP: 56 | case ACTION_RIGHT: 57 | case ACTION_LEFT: 58 | case ACTION_DOWN: 59 | case ACTION_UPRIGHT: 60 | case ACTION_UPLEFT: 61 | case ACTION_DOWNRIGHT: 62 | case ACTION_DOWNLEFT: 63 | case ACTION_UPFIRE: 64 | case ACTION_RIGHTFIRE: 65 | case ACTION_LEFTFIRE: 66 | case ACTION_DOWNFIRE: 67 | case ACTION_UPRIGHTFIRE: 68 | case ACTION_UPLEFTFIRE: 69 | case ACTION_DOWNRIGHTFIRE: 70 | case ACTION_DOWNLEFTFIRE: 71 | return true; 72 | default: 73 | return false; 74 | } 75 | } 76 | 77 | template 78 | CULE_ANNOTATION 79 | int32_t lives(State&) 80 | { 81 | return 0; 82 | } 83 | 84 | template 85 | CULE_ANNOTATION 86 | void setTerminal(State& s) 87 | { 88 | // update terminal status 89 | int minutes = cule::atari::ram::read(s.ram, 0x87); 90 | int seconds = cule::atari::ram::read(s.ram, 0x86); 91 | 92 | // end of game when out of time 93 | s.tiaFlags.template change((minutes == 0) && (seconds == 0)); 94 | } 95 | 96 | template 97 | CULE_ANNOTATION 98 | int32_t score(State& s) 99 | { 100 | int my_score = max(cule::atari::games::getDecimalScore(s, 0x8A), 0); 101 | int oppt_score = max(cule::atari::games::getDecimalScore(s, 0x8B), 0); 102 | return my_score - oppt_score; 103 | } 104 | 105 | template 
106 | CULE_ANNOTATION 107 | int32_t reward(State& s) 108 | { 109 | return min(score(s) - s.score, 1); 110 | } 111 | 112 | } // end namespace icehockey 113 | } // end namespace games 114 | } // end namespace atari 115 | } // end namespace cule 116 | 117 | -------------------------------------------------------------------------------- /cule/atari/games/jamesbond.hpp: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | #include 5 | #include 6 | 7 | namespace cule 8 | { 9 | namespace atari 10 | { 11 | namespace games 12 | { 13 | namespace jamesbond 14 | { 15 | 16 | template 17 | CULE_ANNOTATION 18 | void reset(State& s) 19 | { 20 | s.m_reward = 0; 21 | s.m_score = 0; 22 | s.tiaFlags.clear(FLAG_ALE_TERMINAL); 23 | s.m_lives = 6; 24 | } 25 | 26 | template 27 | CULE_ANNOTATION 28 | void step(State& s) 29 | { 30 | using cule::atari::games::getDecimalScore; 31 | using cule::atari::ram::read; 32 | 33 | // update the reward 34 | int score = getDecimalScore(s, 0xDC, 0xDD, 0xDE); 35 | int reward = score - s.m_score; 36 | s.m_reward = reward; 37 | s.m_score = score; 38 | 39 | // update terminal status 40 | int lives_byte = ram::read(s, 0x86) & 0xF; 41 | int screen_byte = ram::read(s, 0x8C); 42 | 43 | // byte 0x8C is 0x68 when we die; it does not remain so forever, as 44 | // the system loops back to start state after a while (where fire will 45 | // start a new game) 46 | s.tiaFlags.template change((lives_byte == 0) && (screen_byte == 0x68)); 47 | s.m_lives = lives_byte + 1; 48 | } 49 | 50 | CULE_ANNOTATION 51 | bool isMinimal(const Action &a) 52 | { 53 | switch (a) 54 | { 55 | case ACTION_NOOP: 56 | case ACTION_FIRE: 57 | case ACTION_UP: 58 | case ACTION_RIGHT: 59 | case ACTION_LEFT: 60 | case ACTION_DOWN: 61 | case ACTION_UPRIGHT: 62 | case ACTION_UPLEFT: 63 | case ACTION_DOWNRIGHT: 64 | case ACTION_DOWNLEFT: 65 | case ACTION_UPFIRE: 66 | case ACTION_RIGHTFIRE: 67 | case ACTION_LEFTFIRE: 68 | case 
ACTION_DOWNFIRE: 69 | case ACTION_UPRIGHTFIRE: 70 | case ACTION_UPLEFTFIRE: 71 | case ACTION_DOWNRIGHTFIRE: 72 | case ACTION_DOWNLEFTFIRE: 73 | return true; 74 | default: 75 | return false; 76 | } 77 | } 78 | 79 | template 80 | CULE_ANNOTATION 81 | int32_t lives(State&) 82 | { 83 | return 0; 84 | } 85 | 86 | template 87 | CULE_ANNOTATION 88 | void setTerminal(State& s) 89 | { 90 | using cule::atari::ram::read; 91 | 92 | // update terminal status 93 | int lives_byte = ram::read(s.ram, 0x86) & 0xF; 94 | int screen_byte = ram::read(s.ram, 0x8C); 95 | 96 | // byte 0x8C is 0x68 when we die; it does not remain so forever, as 97 | // the system loops back to start state after a while (where fire will 98 | // start a new game) 99 | s.tiaFlags.template change((lives_byte == 0) && (screen_byte == 0x68)); 100 | } 101 | 102 | template 103 | CULE_ANNOTATION 104 | int32_t score(State& s) 105 | { 106 | return cule::atari::games::getDecimalScore(s, 0xDC, 0xDD, 0xDE); 107 | } 108 | 109 | template 110 | CULE_ANNOTATION 111 | int32_t reward(State& s) 112 | { 113 | return score(s) - s.score; 114 | } 115 | 116 | } // end namespace jamesbond 117 | } // end namespace games 118 | } // end namespace atari 119 | } // end namespace cule 120 | 121 | -------------------------------------------------------------------------------- /cule/atari/games/journeyescape.hpp: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | #include 5 | #include 6 | #include 7 | 8 | namespace cule 9 | { 10 | namespace atari 11 | { 12 | namespace games 13 | { 14 | namespace journeyescape 15 | { 16 | 17 | template 18 | CULE_ANNOTATION 19 | void reset(State& s) 20 | { 21 | s.m_reward = 0; 22 | s.m_score = 0; 23 | s.tiaFlags.clear(FLAG_ALE_TERMINAL); 24 | } 25 | 26 | template 27 | CULE_ANNOTATION 28 | void step(State& s) 29 | { 30 | using cule::atari::games::getDecimalScore; 31 | using cule::atari::ram::read; 32 | 33 | // update the reward 34 | int score 
= getDecimalScore(s, 0x92, 0x91, 0x90); 35 | int reward = score - s.m_score; 36 | if (reward == 50000) reward = 0; // HACK: ignoring starting cash 37 | s.m_reward = reward; 38 | s.m_score = score; 39 | 40 | // update terminal status 41 | int minutes = ram::read(s, 0x95); 42 | int seconds = ram::read(s, 0x96); 43 | s.tiaFlags.template change((minutes == 0) && (seconds == 0)); 44 | } 45 | 46 | CULE_ANNOTATION 47 | bool isMinimal(const Action &a) 48 | { 49 | switch (a) 50 | { 51 | case ACTION_NOOP: 52 | case ACTION_UP: 53 | case ACTION_RIGHT: 54 | case ACTION_LEFT: 55 | case ACTION_DOWN: 56 | case ACTION_UPRIGHT: 57 | case ACTION_UPLEFT: 58 | case ACTION_DOWNRIGHT: 59 | case ACTION_DOWNLEFT: 60 | case ACTION_RIGHTFIRE: 61 | case ACTION_LEFTFIRE: 62 | case ACTION_DOWNFIRE: 63 | case ACTION_UPRIGHTFIRE: 64 | case ACTION_UPLEFTFIRE: 65 | case ACTION_DOWNRIGHTFIRE: 66 | case ACTION_DOWNLEFTFIRE: 67 | return true; 68 | default: 69 | return false; 70 | } 71 | } 72 | 73 | template 74 | CULE_ANNOTATION 75 | int32_t lives(State&) 76 | { 77 | return 0; 78 | } 79 | 80 | template 81 | CULE_ANNOTATION 82 | void setTerminal(State& s) 83 | { 84 | // update terminal status 85 | int minutes = cule::atari::ram::read(s.ram, 0x95); 86 | int seconds = cule::atari::ram::read(s.ram, 0x96); 87 | s.tiaFlags.template change((minutes == 0) && (seconds == 0)); 88 | } 89 | 90 | template 91 | CULE_ANNOTATION 92 | int32_t score(State& s) 93 | { 94 | return cule::atari::games::getDecimalScore(s, 0x92, 0x91, 0x90); 95 | } 96 | 97 | template 98 | CULE_ANNOTATION 99 | int32_t reward(State& s) 100 | { 101 | int32_t m_score = score(s); 102 | int32_t m_reward = m_score - s.score; 103 | 104 | if (m_reward == 50000) m_reward = 0; // HACK: ignoring starting cash 105 | 106 | return m_reward; 107 | } 108 | 109 | } // end namespace journeyescape 110 | } // end namespace games 111 | } // end namespace atari 112 | } // end namespace cule 113 | 114 | 
-------------------------------------------------------------------------------- /cule/atari/games/kaboom.hpp: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | #include 5 | #include 6 | #include 7 | 8 | namespace cule 9 | { 10 | namespace atari 11 | { 12 | namespace games 13 | { 14 | namespace kaboom 15 | { 16 | 17 | template 18 | CULE_ANNOTATION 19 | void reset(State& s) 20 | { 21 | s.m_reward = 0; 22 | s.m_score = 0; 23 | s.tiaFlags.clear(FLAG_ALE_TERMINAL); 24 | } 25 | 26 | template 27 | CULE_ANNOTATION 28 | void step(State& s) 29 | { 30 | using cule::atari::games::getDecimalScore; 31 | using cule::atari::ram::read; 32 | 33 | // update the reward 34 | int score = getDecimalScore(s, 0xA5, 0xA4, 0xA3); 35 | s.m_reward = score - s.m_score; 36 | s.m_score = score; 37 | 38 | // update terminal status 39 | int lives = ram::read(s, 0xA1); 40 | s.tiaFlags.template change((lives == 0x0) || (s.m_score == 999999)); 41 | } 42 | 43 | CULE_ANNOTATION 44 | bool isMinimal(const Action &a) 45 | { 46 | switch (a) 47 | { 48 | case ACTION_NOOP: 49 | case ACTION_FIRE: 50 | case ACTION_RIGHT: 51 | case ACTION_LEFT: 52 | return true; 53 | default: 54 | return false; 55 | } 56 | } 57 | 58 | template 59 | CULE_ANNOTATION 60 | int32_t lives(State& s) 61 | { 62 | return cule::atari::ram::read(s.ram, 0xA1); 63 | } 64 | 65 | template 66 | CULE_ANNOTATION 67 | void setTerminal(State& s) 68 | { 69 | int32_t score = cule::atari::games::getDecimalScore(s, 0xA5, 0xA4, 0xA3); 70 | s.tiaFlags.template change((lives(s) == 0x0) || (score == 999999)); 71 | } 72 | 73 | template 74 | CULE_ANNOTATION 75 | int32_t score(State& s) 76 | { 77 | return cule::atari::games::getDecimalScore(s, 0xA5, 0xA4, 0xA3); 78 | } 79 | 80 | template 81 | CULE_ANNOTATION 82 | int32_t reward(State& s) 83 | { 84 | return score(s) - s.score; 85 | } 86 | 87 | } // end namespace kaboom 88 | } // end namespace games 89 | } // end namespace atari 90 | } // end 
namespace cule 91 | 92 | -------------------------------------------------------------------------------- /cule/atari/games/kangaroo.hpp: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | #include 5 | #include 6 | 7 | namespace cule 8 | { 9 | namespace atari 10 | { 11 | namespace games 12 | { 13 | namespace kangaroo 14 | { 15 | 16 | template 17 | CULE_ANNOTATION 18 | void reset(State& s) 19 | { 20 | s.m_reward = 0; 21 | s.m_score = 0; 22 | s.tiaFlags.clear(FLAG_ALE_TERMINAL); 23 | s.m_lives = 3; 24 | } 25 | 26 | CULE_ANNOTATION 27 | bool isMinimal(const Action &a) 28 | { 29 | switch (a) 30 | { 31 | case ACTION_NOOP: 32 | case ACTION_FIRE: 33 | case ACTION_UP: 34 | case ACTION_RIGHT: 35 | case ACTION_LEFT: 36 | case ACTION_DOWN: 37 | case ACTION_UPRIGHT: 38 | case ACTION_UPLEFT: 39 | case ACTION_DOWNRIGHT: 40 | case ACTION_DOWNLEFT: 41 | case ACTION_UPFIRE: 42 | case ACTION_RIGHTFIRE: 43 | case ACTION_LEFTFIRE: 44 | case ACTION_DOWNFIRE: 45 | case ACTION_UPRIGHTFIRE: 46 | case ACTION_UPLEFTFIRE: 47 | case ACTION_DOWNRIGHTFIRE: 48 | case ACTION_DOWNLEFTFIRE: 49 | return true; 50 | default: 51 | return false; 52 | } 53 | } 54 | 55 | template 56 | CULE_ANNOTATION 57 | void step(State& s) 58 | { 59 | using cule::atari::games::getDecimalScore; 60 | using cule::atari::ram::read; 61 | 62 | // update the reward 63 | int score = getDecimalScore(s, 0xA8, 0xA7); 64 | score *= 100; 65 | int reward = score - s.m_score; 66 | s.m_reward = reward; 67 | s.m_score = score; 68 | 69 | // update terminal status 70 | int lives_byte = ram::read(s, 0xAD); 71 | s.tiaFlags.template change(lives_byte == 0xFF); 72 | s.m_lives = (lives_byte & 0x7) + 1; 73 | } 74 | 75 | template 76 | CULE_ANNOTATION 77 | int32_t lives(State& s) 78 | { 79 | return cule::atari::ram::read(s.ram, 0xAD); 80 | } 81 | 82 | template 83 | CULE_ANNOTATION 84 | void setTerminal(State& s) 85 | { 86 | // update terminal status 87 | s.tiaFlags.template 
change(lives(s) == 0xFF); 88 | } 89 | 90 | template 91 | CULE_ANNOTATION 92 | int32_t score(State& s) 93 | { 94 | return 100 * cule::atari::games::getDecimalScore(s, 0xA8, 0xA7); 95 | } 96 | 97 | template 98 | CULE_ANNOTATION 99 | int32_t reward(State& s) 100 | { 101 | return score(s) - s.score; 102 | } 103 | 104 | } // end namespace kangaroo 105 | } // end namespace games 106 | } // end namespace atari 107 | } // end namespace cule 108 | 109 | -------------------------------------------------------------------------------- /cule/atari/games/krull.hpp: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | #include 5 | #include 6 | 7 | namespace cule 8 | { 9 | namespace atari 10 | { 11 | namespace games 12 | { 13 | namespace krull 14 | { 15 | 16 | template 17 | CULE_ANNOTATION 18 | void reset(State& s) 19 | { 20 | s.m_reward = 0; 21 | s.m_score = 0; 22 | s.tiaFlags.clear(FLAG_ALE_TERMINAL); 23 | s.m_lives = 3; 24 | } 25 | 26 | template 27 | CULE_ANNOTATION 28 | void step(State& s) 29 | { 30 | using cule::atari::games::getDecimalScore; 31 | using cule::atari::ram::read; 32 | 33 | // update the reward 34 | int score = getDecimalScore(s, 0x9E, 0x9D, 0x9C); 35 | int reward = score - s.m_score; 36 | s.m_reward = reward; 37 | s.m_score = score; 38 | 39 | // update terminal status 40 | int lives = ram::read(s, 0x9F); 41 | int byte1 = ram::read(s, 0xA2); 42 | int byte2 = ram::read(s, 0x80); 43 | s.tiaFlags.template change((lives == 0) && (byte1 == 0x03) && (byte2 == 0x80)); 44 | s.m_lives = (lives & 0x7) + 1; 45 | } 46 | 47 | CULE_ANNOTATION 48 | bool isMinimal(const Action &a) 49 | { 50 | switch (a) 51 | { 52 | case ACTION_NOOP: 53 | case ACTION_FIRE: 54 | case ACTION_UP: 55 | case ACTION_RIGHT: 56 | case ACTION_LEFT: 57 | case ACTION_DOWN: 58 | case ACTION_UPRIGHT: 59 | case ACTION_UPLEFT: 60 | case ACTION_DOWNRIGHT: 61 | case ACTION_DOWNLEFT: 62 | case ACTION_UPFIRE: 63 | case ACTION_RIGHTFIRE: 64 | case 
ACTION_LEFTFIRE: 65 | case ACTION_DOWNFIRE: 66 | case ACTION_UPRIGHTFIRE: 67 | case ACTION_UPLEFTFIRE: 68 | case ACTION_DOWNRIGHTFIRE: 69 | case ACTION_DOWNLEFTFIRE: 70 | return true; 71 | default: 72 | return false; 73 | } 74 | } 75 | 76 | template 77 | CULE_ANNOTATION 78 | int32_t lives(State& s) 79 | { 80 | return cule::atari::ram::read(s.ram, 0x9F) + 1; 81 | } 82 | 83 | template 84 | CULE_ANNOTATION 85 | void setTerminal(State& s) 86 | { 87 | using cule::atari::ram::read; 88 | 89 | // update terminal status 90 | int lives = ram::read(s.ram, 0x9F); 91 | int byte1 = ram::read(s.ram, 0xA2); 92 | int byte2 = ram::read(s.ram, 0x80); 93 | s.tiaFlags.template change((lives == 0) && (byte1 == 0x03) && (byte2 == 0x80)); 94 | } 95 | 96 | template 97 | CULE_ANNOTATION 98 | int32_t score(State& s) 99 | { 100 | return cule::atari::games::getDecimalScore(s, 0x9E, 0x9D, 0x9C); 101 | } 102 | 103 | template 104 | CULE_ANNOTATION 105 | int32_t reward(State& s) 106 | { 107 | return score(s) - s.score; 108 | } 109 | 110 | } // end namespace krull 111 | } // end namespace games 112 | } // end namespace atari 113 | } // end namespace cule 114 | 115 | -------------------------------------------------------------------------------- /cule/atari/games/kungfumaster.hpp: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | #include 5 | #include 6 | 7 | namespace cule 8 | { 9 | namespace atari 10 | { 11 | namespace games 12 | { 13 | namespace kungfumaster 14 | { 15 | 16 | template 17 | CULE_ANNOTATION 18 | void reset(State& s) 19 | { 20 | s.m_reward = 0; 21 | s.m_score = 0; 22 | s.tiaFlags.clear(FLAG_ALE_TERMINAL); 23 | s.m_lives = 4; 24 | } 25 | 26 | template 27 | CULE_ANNOTATION 28 | void step(State& s) 29 | { 30 | using cule::atari::games::getDecimalScore; 31 | using cule::atari::ram::read; 32 | 33 | // update the reward 34 | int score = getDecimalScore(s, 0x9A, 0x99, 0x98); 35 | int reward = score - s.m_score; 36 | 
s.m_reward = reward; 37 | s.m_score = score; 38 | 39 | // update terminal status 40 | int lives_byte = ram::read(s, 0x9D); 41 | s.tiaFlags.template change(lives_byte == 0xFF); 42 | s.m_lives = (lives_byte & 0x7) + 1; 43 | } 44 | 45 | CULE_ANNOTATION 46 | bool isMinimal(const Action &a) 47 | { 48 | switch (a) 49 | { 50 | case ACTION_NOOP: 51 | case ACTION_UP: 52 | case ACTION_RIGHT: 53 | case ACTION_LEFT: 54 | case ACTION_DOWN: 55 | case ACTION_DOWNRIGHT: 56 | case ACTION_DOWNLEFT: 57 | case ACTION_RIGHTFIRE: 58 | case ACTION_LEFTFIRE: 59 | case ACTION_DOWNFIRE: 60 | case ACTION_UPRIGHTFIRE: 61 | case ACTION_UPLEFTFIRE: 62 | case ACTION_DOWNRIGHTFIRE: 63 | case ACTION_DOWNLEFTFIRE: 64 | return true; 65 | default: 66 | return false; 67 | } 68 | } 69 | 70 | template 71 | CULE_ANNOTATION 72 | int32_t lives(State& s) 73 | { 74 | return cule::atari::ram::read(s.ram, 0x9D) + 1; 75 | } 76 | 77 | template 78 | CULE_ANNOTATION 79 | void setTerminal(State& s) 80 | { 81 | // update terminal status 82 | int lives_byte = cule::atari::ram::read(s.ram, 0x9D); 83 | s.tiaFlags.template change(lives_byte == 0xFF); 84 | } 85 | 86 | template 87 | CULE_ANNOTATION 88 | int32_t score(State& s) 89 | { 90 | return cule::atari::games::getDecimalScore(s, 0x9A, 0x99, 0x98); 91 | } 92 | 93 | template 94 | CULE_ANNOTATION 95 | int32_t reward(State& s) 96 | { 97 | return score(s) - s.score; 98 | } 99 | 100 | } // end namespace kungfumaster 101 | } // end namespace games 102 | } // end namespace atari 103 | } // end namespace cule 104 | 105 | -------------------------------------------------------------------------------- /cule/atari/games/montezumarevenge.hpp: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | #include 5 | #include 6 | 7 | namespace cule 8 | { 9 | namespace atari 10 | { 11 | namespace games 12 | { 13 | namespace montezumarevenge 14 | { 15 | 16 | template 17 | CULE_ANNOTATION 18 | void reset(State& s) 19 | { 20 | 
s.m_reward = 0; 21 | s.m_score = 0; 22 | s.tiaFlags.clear(FLAG_ALE_TERMINAL); 23 | s.m_lives = 6; 24 | } 25 | 26 | template 27 | CULE_ANNOTATION 28 | void step(State& s) 29 | { 30 | using cule::atari::games::getDecimalScore; 31 | using cule::atari::ram::read; 32 | 33 | // update the reward 34 | int score = getDecimalScore(s, 0x95, 0x94, 0x93); 35 | int reward = score - s.m_score; 36 | s.m_reward = reward; 37 | s.m_score = score; 38 | 39 | // update terminal status 40 | int new_lives = ram::read(s, 0xBA); 41 | int some_byte = ram::read(s, 0xFE); 42 | s.tiaFlags.template change((new_lives == 0) && (some_byte == 0x60)); 43 | 44 | // Actually does not go up to 8, but that's alright 45 | s.m_lives = (new_lives & 0x7) + 1; 46 | } 47 | 48 | CULE_ANNOTATION 49 | bool isMinimal(const Action &a) 50 | { 51 | switch (a) 52 | { 53 | case ACTION_NOOP: 54 | case ACTION_FIRE: 55 | case ACTION_UP: 56 | case ACTION_RIGHT: 57 | case ACTION_LEFT: 58 | case ACTION_DOWN: 59 | case ACTION_UPRIGHT: 60 | case ACTION_UPLEFT: 61 | case ACTION_DOWNRIGHT: 62 | case ACTION_DOWNLEFT: 63 | case ACTION_UPFIRE: 64 | case ACTION_RIGHTFIRE: 65 | case ACTION_LEFTFIRE: 66 | case ACTION_DOWNFIRE: 67 | case ACTION_UPRIGHTFIRE: 68 | case ACTION_UPLEFTFIRE: 69 | case ACTION_DOWNRIGHTFIRE: 70 | case ACTION_DOWNLEFTFIRE: 71 | return true; 72 | default: 73 | return false; 74 | } 75 | } 76 | 77 | template 78 | CULE_ANNOTATION 79 | int32_t lives(State& s) 80 | { 81 | // update terminal status 82 | int new_lives = cule::atari::ram::read(s.ram, 0xBA); 83 | 84 | // Actually does not go up to 8, but that's alright 85 | return (new_lives & 0x7) + 1; 86 | } 87 | 88 | template 89 | CULE_ANNOTATION 90 | void setTerminal(State& s) 91 | { 92 | using cule::atari::ram::read; 93 | 94 | // update terminal status 95 | int new_lives = ram::read(s.ram, 0xBA); 96 | int some_byte = ram::read(s.ram, 0xFE); 97 | s.tiaFlags.template change((new_lives == 0) && (some_byte == 0x60)); 98 | } 99 | 100 | template 101 | CULE_ANNOTATION 102 
| int32_t score(State& s) 103 | { 104 | return cule::atari::games::getDecimalScore(s, 0x95, 0x94, 0x93); 105 | } 106 | 107 | template 108 | CULE_ANNOTATION 109 | int32_t reward(State& s) 110 | { 111 | return score(s) - s.score; 112 | } 113 | 114 | } // end namespace montezumarevenge 115 | } // end namespace games 116 | } // end namespace atari 117 | } // end namespace cule 118 | 119 | -------------------------------------------------------------------------------- /cule/atari/games/mspacman.hpp: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | #include 5 | #include 6 | 7 | namespace cule 8 | { 9 | namespace atari 10 | { 11 | namespace games 12 | { 13 | namespace mspacman 14 | { 15 | 16 | template 17 | CULE_ANNOTATION 18 | void reset(State& s) 19 | { 20 | s.m_reward = 0; 21 | s.m_score = 0; 22 | s.tiaFlags.clear(FLAG_ALE_TERMINAL); 23 | s.m_lives = 3; 24 | } 25 | 26 | template 27 | CULE_ANNOTATION 28 | void step(State& s) 29 | { 30 | using cule::atari::games::getDecimalScore; 31 | using cule::atari::ram::read; 32 | 33 | // update the reward 34 | int score = getDecimalScore(s, 0xF8, 0xF9, 0xFA); 35 | int reward = score - s.m_score; 36 | s.m_reward = reward; 37 | s.m_score = score; 38 | 39 | // update terminal status 40 | int lives_byte = ram::read(s, 0xFB) & 0xF; 41 | // MGB Did not work int black_screen_byte = ram::read(&system, 0x94); 42 | int death_timer = ram::read(s, 0xA7); 43 | s.tiaFlags.template change((lives_byte == 0) && (death_timer == 0x53)); 44 | 45 | s.m_lives = (lives_byte & 0x7) + 1; 46 | } 47 | 48 | 49 | CULE_ANNOTATION 50 | bool isMinimal(const Action &a) 51 | { 52 | switch (a) 53 | { 54 | case ACTION_NOOP: 55 | case ACTION_UP: 56 | case ACTION_RIGHT: 57 | case ACTION_LEFT: 58 | case ACTION_DOWN: 59 | case ACTION_UPRIGHT: 60 | case ACTION_UPLEFT: 61 | case ACTION_DOWNRIGHT: 62 | case ACTION_DOWNLEFT: 63 | return true; 64 | default: 65 | return false; 66 | } 67 | } 68 | 69 | 
template 70 | CULE_ANNOTATION 71 | int32_t lives(State& s) 72 | { 73 | int lives_byte = cule::atari::ram::read(s.ram, 0xFB) & 0xF; 74 | return (lives_byte & 0x7) + 1; 75 | } 76 | 77 | template 78 | CULE_ANNOTATION 79 | void setTerminal(State& s) 80 | { 81 | // update terminal status 82 | int lives_byte = cule::atari::ram::read(s.ram, 0xFB) & 0xF; 83 | // MGB Did not work int black_screen_byte = ram::read(&system, 0x94); 84 | int death_timer = cule::atari::ram::read(s.ram, 0xA7); 85 | s.tiaFlags.template change((lives_byte == 0) && (death_timer == 0x53)); 86 | } 87 | 88 | template 89 | CULE_ANNOTATION 90 | int32_t score(State& s) 91 | { 92 | return cule::atari::games::getDecimalScore(s, 0xF8, 0xF9, 0xFA); 93 | } 94 | 95 | template 96 | CULE_ANNOTATION 97 | int32_t reward(State& s) 98 | { 99 | return score(s) - s.score; 100 | } 101 | 102 | } // end namespace mspacman 103 | } // end namespace games 104 | } // end namespace atari 105 | } // end namespace cule 106 | 107 | -------------------------------------------------------------------------------- /cule/atari/games/namethisgame.hpp: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | #include 5 | #include 6 | 7 | namespace cule 8 | { 9 | namespace atari 10 | { 11 | namespace games 12 | { 13 | namespace namethisgame 14 | { 15 | 16 | template 17 | CULE_ANNOTATION 18 | void reset(State& s) 19 | { 20 | s.m_reward = 0; 21 | s.m_score = 0; 22 | s.tiaFlags.clear(FLAG_ALE_TERMINAL); 23 | s.m_lives = 3; 24 | } 25 | 26 | template 27 | CULE_ANNOTATION 28 | void step(State& s) 29 | { 30 | using cule::atari::games::getDecimalScore; 31 | using cule::atari::ram::read; 32 | 33 | // update the reward 34 | int score = getDecimalScore(s, 0xC6, 0xC5, 0xC4); 35 | int reward = score - s.m_score; 36 | s.m_reward = reward; 37 | s.m_score = score; 38 | 39 | // update terminal status 40 | s.m_lives = (ram::read(s, 0xC7) & 0x7); 41 | s.tiaFlags.template change(s.m_lives == 0); 
42 | } 43 | 44 | CULE_ANNOTATION 45 | bool isMinimal(const Action &a) 46 | { 47 | switch (a) 48 | { 49 | case ACTION_NOOP: 50 | case ACTION_FIRE: 51 | case ACTION_RIGHT: 52 | case ACTION_LEFT: 53 | case ACTION_RIGHTFIRE: 54 | case ACTION_LEFTFIRE: 55 | return true; 56 | default: 57 | return false; 58 | } 59 | } 60 | 61 | template 62 | CULE_ANNOTATION 63 | int32_t lives(State& s) 64 | { 65 | return cule::atari::ram::read(s.ram, 0xC7) & 0x7; 66 | } 67 | 68 | template 69 | CULE_ANNOTATION 70 | void setTerminal(State& s) 71 | { 72 | // update terminal status 73 | s.tiaFlags.template change(lives(s) == 0); 74 | } 75 | 76 | template 77 | CULE_ANNOTATION 78 | int32_t score(State& s) 79 | { 80 | return cule::atari::games::getDecimalScore(s, 0xC6, 0xC5, 0xC4); 81 | } 82 | 83 | template 84 | CULE_ANNOTATION 85 | int32_t reward(State& s) 86 | { 87 | return score(s) - s.score; 88 | } 89 | 90 | } // end namespace namethisgame 91 | } // end namespace games 92 | } // end namespace atari 93 | } // end namespace cule 94 | 95 | -------------------------------------------------------------------------------- /cule/atari/games/phoenix.hpp: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | #include 5 | #include 6 | 7 | namespace cule 8 | { 9 | namespace atari 10 | { 11 | namespace games 12 | { 13 | namespace phoenix 14 | { 15 | 16 | template 17 | CULE_ANNOTATION 18 | void reset(State& s) 19 | { 20 | s.m_reward = 0; 21 | s.m_score = 0; 22 | s.tiaFlags.clear(FLAG_ALE_TERMINAL); 23 | s.m_lives = 5; 24 | } 25 | 26 | template 27 | CULE_ANNOTATION 28 | void step(State& s) 29 | { 30 | using cule::atari::games::getDecimalScore; 31 | using cule::atari::ram::read; 32 | 33 | // update the reward 34 | int score = getDecimalScore(s, 0xC8, 0xC9) * 10; 35 | score += ram::read(s, 0xC7) >> 4; 36 | score *= 10; 37 | int reward = score - s.m_score; 38 | s.m_reward = reward; 39 | s.m_score = score; 40 | 41 | // update terminal status 42 
| int state_byte = ram::read(s, 0xCC); 43 | s.tiaFlags.template change(state_byte == 0x80); 44 | // Technically seems to only go up to 5 45 | s.m_lives = ram::read(s, 0xCB) & 0x7; 46 | } 47 | 48 | CULE_ANNOTATION 49 | bool isMinimal(const Action &a) 50 | { 51 | switch (a) 52 | { 53 | case ACTION_NOOP: 54 | case ACTION_FIRE: 55 | case ACTION_RIGHT: 56 | case ACTION_LEFT: 57 | case ACTION_DOWN: 58 | case ACTION_RIGHTFIRE: 59 | case ACTION_LEFTFIRE: 60 | case ACTION_DOWNFIRE: 61 | return true; 62 | default: 63 | return false; 64 | } 65 | } 66 | 67 | template 68 | CULE_ANNOTATION 69 | int32_t lives(State& s) 70 | { 71 | // Technically seems to only go up to 5 72 | return cule::atari::ram::read(s.ram, 0xCB) & 0x7; 73 | } 74 | 75 | template 76 | CULE_ANNOTATION 77 | void setTerminal(State& s) 78 | { 79 | // update terminal status 80 | int32_t state_byte = cule::atari::ram::read(s.ram, 0xCC); 81 | s.tiaFlags.template change(state_byte == 0x80); 82 | } 83 | 84 | template 85 | CULE_ANNOTATION 86 | int32_t score(State& s) 87 | { 88 | // compute the current score 89 | int m_score = cule::atari::games::getDecimalScore(s, 0xC8, 0xC9) * 10; 90 | m_score += cule::atari::ram::read(s.ram, 0xC7) >> 4; 91 | m_score *= 10; 92 | 93 | return m_score; 94 | } 95 | 96 | template 97 | CULE_ANNOTATION 98 | int32_t reward(State& s) 99 | { 100 | return score(s) - s.score; 101 | } 102 | 103 | } // end namespace phoenix 104 | } // end namespace games 105 | } // end namespace atari 106 | } // end namespace cule 107 | 108 | -------------------------------------------------------------------------------- /cule/atari/games/pinball.hpp: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | #include 5 | #include 6 | 7 | namespace cule 8 | { 9 | namespace atari 10 | { 11 | namespace games 12 | { 13 | namespace pinball 14 | { 15 | 16 | template 17 | CULE_ANNOTATION 18 | void reset(State& s) 19 | { 20 | s.m_reward = 0; 21 | s.m_score = 0; 22 | 
s.tiaFlags.clear(FLAG_ALE_TERMINAL); 23 | s.m_lives = 3; 24 | } 25 | 26 | template 27 | CULE_ANNOTATION 28 | void step(State& s) 29 | { 30 | using cule::atari::games::getDecimalScore; 31 | using cule::atari::ram::read; 32 | 33 | // update the reward 34 | int score = getDecimalScore(s, 0xB0, 0xB2, 0xB4); 35 | int reward = score - s.m_score; 36 | s.m_reward = reward; 37 | s.m_score = score; 38 | 39 | // update terminal status 40 | int flag = ram::read(s, 0xAF) & 0x1; 41 | s.tiaFlags.template change(flag != 0); 42 | 43 | // The lives in video pinball are displayed as ball number; so #1 == 3 lives 44 | int lives_byte = ram::read(s, 0x99) & 0x7; 45 | // And of course, we keep the 'extra ball' counter in a different memory location 46 | int extra_ball = ram::read(s, 0xA8) & 0x1; 47 | 48 | s.m_lives = 4 + extra_ball - lives_byte; 49 | } 50 | 51 | CULE_ANNOTATION 52 | bool isMinimal(const Action &a) 53 | { 54 | switch (a) 55 | { 56 | case ACTION_NOOP: 57 | case ACTION_FIRE: 58 | case ACTION_UP: 59 | case ACTION_RIGHT: 60 | case ACTION_LEFT: 61 | case ACTION_DOWN: 62 | case ACTION_UPFIRE: 63 | case ACTION_RIGHTFIRE: 64 | case ACTION_LEFTFIRE: 65 | return true; 66 | default: 67 | return false; 68 | } 69 | } 70 | 71 | template 72 | CULE_ANNOTATION 73 | int32_t lives(State& s) 74 | { 75 | // The lives in video pinball are displayed as ball number; so #1 == 3 lives 76 | int32_t lives_byte = cule::atari::ram::read(s.ram, 0x99) & 0x7; 77 | // And of course, we keep the 'extra ball' counter in a different memory location 78 | int32_t extra_ball = cule::atari::ram::read(s.ram, 0xA8) & 0x1; 79 | 80 | return 4 + extra_ball - lives_byte; 81 | } 82 | 83 | template 84 | CULE_ANNOTATION 85 | void setTerminal(State& s) 86 | { 87 | // update terminal status 88 | int flag = cule::atari::ram::read(s.ram, 0xAF) & 0x1; 89 | s.tiaFlags.template change(flag != 0); 90 | } 91 | 92 | template 93 | CULE_ANNOTATION 94 | int32_t score(State& s) 95 | { 96 | return cule::atari::games::getDecimalScore(s, 
0xB0, 0xB2, 0xB4); 97 | } 98 | 99 | template 100 | CULE_ANNOTATION 101 | int32_t reward(State& s) 102 | { 103 | return score(s) - s.score; 104 | } 105 | 106 | } // end namespace pinball 107 | } // end namespace games 108 | } // end namespace atari 109 | } // end namespace cule 110 | 111 | -------------------------------------------------------------------------------- /cule/atari/games/pitfall.hpp: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | #include 5 | #include 6 | #include 7 | 8 | namespace cule 9 | { 10 | namespace atari 11 | { 12 | namespace games 13 | { 14 | namespace pitfall 15 | { 16 | 17 | template 18 | CULE_ANNOTATION 19 | void reset(State& s) 20 | { 21 | s.m_reward = 0; 22 | s.m_score = 2000; 23 | s.tiaFlags.clear(FLAG_ALE_TERMINAL); 24 | s.m_lives = 3; 25 | } 26 | 27 | template 28 | CULE_ANNOTATION 29 | void step(State& s) 30 | { 31 | using cule::atari::games::getDecimalScore; 32 | using cule::atari::ram::read; 33 | 34 | // update the reward 35 | int score = getDecimalScore(s, 0xD7, 0xD6, 0xD5); 36 | int reward = score - s.m_score; 37 | s.m_reward = reward; 38 | s.m_score = score; 39 | 40 | // update terminal status 41 | int lives_byte = ram::read(s, 0x80) >> 4; 42 | // The value at 0x9E will be nonzero if we cannot control the player 43 | int logo_timer = ram::read(s, 0x9E); 44 | s.tiaFlags.template change((lives_byte == 0) && (logo_timer != 0)); 45 | 46 | s.m_lives = (lives_byte == 0xA) ? 3 : ((lives_byte == 0x8) ? 
2 : 1); 47 | } 48 | 49 | CULE_ANNOTATION 50 | bool isMinimal(const Action &a) 51 | { 52 | switch (a) 53 | { 54 | case ACTION_NOOP: 55 | case ACTION_FIRE: 56 | case ACTION_UP: 57 | case ACTION_RIGHT: 58 | case ACTION_LEFT: 59 | case ACTION_DOWN: 60 | case ACTION_UPRIGHT: 61 | case ACTION_UPLEFT: 62 | case ACTION_DOWNRIGHT: 63 | case ACTION_DOWNLEFT: 64 | case ACTION_UPFIRE: 65 | case ACTION_RIGHTFIRE: 66 | case ACTION_LEFTFIRE: 67 | case ACTION_DOWNFIRE: 68 | case ACTION_UPRIGHTFIRE: 69 | case ACTION_UPLEFTFIRE: 70 | case ACTION_DOWNRIGHTFIRE: 71 | case ACTION_DOWNLEFTFIRE: 72 | return true; 73 | default: 74 | return false; 75 | } 76 | } 77 | 78 | template 79 | CULE_ANNOTATION 80 | int32_t lives(State& s) 81 | { 82 | // decode lives from the high nibble of RAM 0x80 83 | int32_t lives_byte = cule::atari::ram::read(s.ram, 0x80) >> 4; 84 | return (lives_byte == 0xA) ? 3 : ((lives_byte == 0x8) ? 2 : 1); 85 | } 86 | 87 | template 88 | CULE_ANNOTATION 89 | void setTerminal(State& s) 90 | { 91 | // update terminal status 92 | int lives_byte = cule::atari::ram::read(s.ram, 0x80) >> 4; 93 | // The value at 0x9E will be nonzero if we cannot control the player 94 | int logo_timer = cule::atari::ram::read(s.ram, 0x9E); 95 | s.tiaFlags.template change((lives_byte == 0) && (logo_timer != 0)); 96 | } 97 | 98 | template 99 | CULE_ANNOTATION 100 | int32_t score(State& s) 101 | { 102 | return cule::atari::games::getDecimalScore(s, 0xD7, 0xD6, 0xD5); 103 | } 104 | 105 | template 106 | CULE_ANNOTATION 107 | int32_t reward(State& s) 108 | { 109 | return score(s) - s.score; 110 | } 111 | 112 | } // end namespace pitfall 113 | } // end namespace games 114 | } // end namespace atari 115 | } // end namespace cule 116 | 117 | -------------------------------------------------------------------------------- /cule/atari/games/pong.hpp: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | 5 | #include 6 | #include 7 | 8 | namespace cule 9 | { 10 | 
namespace atari 11 | { 12 | namespace games 13 | { 14 | namespace pong 15 | { 16 | 17 | template 18 | CULE_ANNOTATION 19 | void reset(State& s) 20 | { 21 | s.m_reward = 0; 22 | s.m_score = 0; 23 | s.m_lives = 0; 24 | s.tiaFlags.clear(FLAG_ALE_TERMINAL); 25 | } 26 | 27 | template 28 | CULE_ANNOTATION 29 | void step(State& s) 30 | { 31 | // update the reward 32 | int x = cule::atari::ram::read(s, 13); // cpu score 33 | int y = cule::atari::ram::read(s, 14); // player score 34 | int score = y - x; 35 | s.m_reward = score - s.m_score; 36 | s.m_score = score; 37 | 38 | // update terminal status 39 | // (game over when a player reaches 21) 40 | s.tiaFlags.template change(x == 21 || y == 21); 41 | } 42 | 43 | CULE_ANNOTATION 44 | bool isMinimal(const Action& a) 45 | { 46 | switch (a) 47 | { 48 | case ACTION_NOOP: 49 | case ACTION_FIRE: 50 | case ACTION_RIGHT: 51 | case ACTION_LEFT: 52 | case ACTION_RIGHTFIRE: 53 | case ACTION_LEFTFIRE: 54 | return true; 55 | default: 56 | return false; 57 | } 58 | } 59 | 60 | template 61 | CULE_ANNOTATION 62 | int32_t lives(State&) 63 | { 64 | return 0; 65 | } 66 | 67 | template 68 | CULE_ANNOTATION 69 | void setTerminal(State& s) 70 | { 71 | // read both scores 72 | const int32_t x = cule::atari::ram::read(s.ram, 13); // cpu score 73 | const int32_t y = cule::atari::ram::read(s.ram, 14); // player score 74 | 75 | // update terminal status 76 | // (game over when a player reaches 21) 77 | s.tiaFlags.template change(x == 21 || y == 21); 78 | } 79 | 80 | template 81 | CULE_ANNOTATION 82 | int32_t score(State& s) 83 | { 84 | // score is the player's points minus the cpu's 85 | int32_t x = cule::atari::ram::read(s.ram, 13); // cpu score 86 | int32_t y = cule::atari::ram::read(s.ram, 14); // player score 87 | 88 | return (y - x); 89 | } 90 | 91 | template 92 | CULE_ANNOTATION 93 | int32_t reward(State& s) 94 | { 95 | return score(s) - s.score; 96 | } 97 | 98 | } // end namespace pong 99 | } // end namespace games 100 | } // end namespace atari 101 | } // end namespace 
cule 102 | 103 | -------------------------------------------------------------------------------- /cule/atari/games/pooyan.hpp: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | #include 5 | #include 6 | 7 | namespace cule 8 | { 9 | namespace atari 10 | { 11 | namespace games 12 | { 13 | namespace pooyan 14 | { 15 | 16 | template 17 | CULE_ANNOTATION 18 | void reset(State& s) 19 | { 20 | s.m_reward = 0; 21 | s.m_score = 0; 22 | s.tiaFlags.clear(FLAG_ALE_TERMINAL); 23 | s.m_lives = 3; 24 | } 25 | 26 | template 27 | CULE_ANNOTATION 28 | void step(State& s) 29 | { 30 | using cule::atari::games::getDecimalScore; 31 | using cule::atari::ram::read; 32 | 33 | // update the reward 34 | int score = getDecimalScore(s, 0x8A, 0x89, 0x88); 35 | int reward = score - s.m_score; 36 | s.m_reward = reward; 37 | s.m_score = score; 38 | 39 | // update terminal status 40 | int lives_byte = ram::read(s, 0x96); 41 | int some_byte = ram::read(s, 0x98); 42 | s.tiaFlags.template change((lives_byte == 0x0) && (some_byte == 0x05)); 43 | 44 | s.m_lives = (lives_byte & 0x7) + 1; 45 | } 46 | 47 | CULE_ANNOTATION 48 | bool isMinimal(const Action &a) 49 | { 50 | switch (a) 51 | { 52 | case ACTION_NOOP: 53 | case ACTION_FIRE: 54 | case ACTION_UP: 55 | case ACTION_DOWN: 56 | case ACTION_UPFIRE: 57 | case ACTION_DOWNFIRE: 58 | return true; 59 | default: 60 | return false; 61 | } 62 | } 63 | 64 | template 65 | CULE_ANNOTATION 66 | int32_t lives(State& s) 67 | { 68 | // read the lives counter (RAM 0x96) 69 | int lives_byte = cule::atari::ram::read(s.ram, 0x96); 70 | return (lives_byte & 0x7) + 1; 71 | } 72 | 73 | template 74 | CULE_ANNOTATION 75 | void setTerminal(State& s) 76 | { 77 | // update terminal status 78 | int32_t lives_byte = cule::atari::ram::read(s.ram, 0x96); 79 | int32_t some_byte = cule::atari::ram::read(s.ram, 0x98); 80 | s.tiaFlags.template change((lives_byte == 0x0) && (some_byte == 0x05)); 81 | } 82 | 83 | template 84 | 
CULE_ANNOTATION 85 | int32_t score(State& s) 86 | { 87 | return cule::atari::games::getDecimalScore(s, 0x8A, 0x89, 0x88); 88 | } 89 | 90 | template 91 | CULE_ANNOTATION 92 | int32_t reward(State& s) 93 | { 94 | return score(s) - s.score; 95 | } 96 | 97 | } // end namespace pooyan 98 | } // end namespace games 99 | } // end namespace atari 100 | } // end namespace cule 101 | 102 | -------------------------------------------------------------------------------- /cule/atari/games/privateeye.hpp: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | #include 5 | #include 6 | #include 7 | 8 | namespace cule 9 | { 10 | namespace atari 11 | { 12 | namespace games 13 | { 14 | namespace privateeye 15 | { 16 | 17 | template 18 | CULE_ANNOTATION 19 | void reset(State& s) 20 | { 21 | s.m_reward = 0; 22 | s.m_score = 1000; 23 | s.tiaFlags.clear(FLAG_ALE_TERMINAL); 24 | } 25 | 26 | template 27 | CULE_ANNOTATION 28 | void step(State& s) 29 | { 30 | using cule::atari::games::getDecimalScore; 31 | using cule::atari::ram::read; 32 | 33 | // update the reward 34 | int score = getDecimalScore(s, 0xCA, 0xC9, 0xC8); 35 | int reward = score - s.m_score; 36 | s.m_reward = reward; 37 | s.m_score = score; 38 | 39 | // update terminal status 40 | int copyright_timer = ram::read(s, 0xC2); 41 | // 00 when the game is running; 01 at start of game. 
42 | s.tiaFlags.template change((copyright_timer != 0x00) && (copyright_timer != 0x01)); 43 | } 44 | 45 | CULE_ANNOTATION 46 | bool isMinimal(const Action &a) 47 | { 48 | switch (a) 49 | { 50 | case ACTION_NOOP: 51 | case ACTION_FIRE: 52 | case ACTION_UP: 53 | case ACTION_RIGHT: 54 | case ACTION_LEFT: 55 | case ACTION_DOWN: 56 | case ACTION_UPRIGHT: 57 | case ACTION_UPLEFT: 58 | case ACTION_DOWNRIGHT: 59 | case ACTION_DOWNLEFT: 60 | case ACTION_UPFIRE: 61 | case ACTION_RIGHTFIRE: 62 | case ACTION_LEFTFIRE: 63 | case ACTION_DOWNFIRE: 64 | case ACTION_UPRIGHTFIRE: 65 | case ACTION_UPLEFTFIRE: 66 | case ACTION_DOWNRIGHTFIRE: 67 | case ACTION_DOWNLEFTFIRE: 68 | return true; 69 | default: 70 | return false; 71 | } 72 | } 73 | 74 | template 75 | CULE_ANNOTATION 76 | int32_t lives(State&) 77 | { 78 | return 0; 79 | } 80 | 81 | template 82 | CULE_ANNOTATION 83 | void setTerminal(State& s) 84 | { 85 | // update terminal status 86 | int copyright_timer = cule::atari::ram::read(s.ram, 0xC2); 87 | // 00 when the game is running; 01 at start of game. 
88 | s.tiaFlags.template change((copyright_timer != 0x00) && (copyright_timer != 0x01)); 89 | } 90 | 91 | template 92 | CULE_ANNOTATION 93 | int32_t score(State& s) 94 | { 95 | return cule::atari::games::getDecimalScore(s, 0xCA, 0xC9, 0xC8); 96 | } 97 | 98 | template 99 | CULE_ANNOTATION 100 | int32_t reward(State& s) 101 | { 102 | return score(s) - s.score; 103 | } 104 | 105 | } // end namespace privateeye 106 | } // end namespace games 107 | } // end namespace atari 108 | } // end namespace cule 109 | 110 | -------------------------------------------------------------------------------- /cule/atari/games/robotank.hpp: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | #include 5 | #include 6 | 7 | namespace cule 8 | { 9 | namespace atari 10 | { 11 | namespace games 12 | { 13 | namespace robotank 14 | { 15 | 16 | template 17 | CULE_ANNOTATION 18 | void reset(State& s) 19 | { 20 | s.m_reward = 0; 21 | s.m_score = 0; 22 | s.tiaFlags.clear(FLAG_ALE_TERMINAL); 23 | s.m_lives = 4; 24 | } 25 | 26 | template 27 | CULE_ANNOTATION 28 | void step(State& s) 29 | { 30 | using cule::atari::games::getDecimalScore; 31 | using cule::atari::ram::read; 32 | 33 | // update the reward 34 | int dead_squadrons = ram::read(s, 0xB6); 35 | int dead_tanks = ram::read(s, 0xB5); 36 | int score = dead_squadrons * 12 + dead_tanks; 37 | int reward = score - s.m_score; 38 | s.m_reward = reward; 39 | s.m_score = score; 40 | 41 | // update terminal status 42 | int termination_flag = ram::read(s, 0xB4); 43 | int lives = ram::read(s, 0xA8); 44 | s.tiaFlags.template change((lives == 0) && (termination_flag == 0xFF)); 45 | 46 | s.m_lives = (lives & 0xF) + 1; 47 | } 48 | 49 | CULE_ANNOTATION 50 | bool isMinimal(const Action &a) 51 | { 52 | switch (a) 53 | { 54 | case ACTION_NOOP: 55 | case ACTION_FIRE: 56 | case ACTION_UP: 57 | case ACTION_RIGHT: 58 | case ACTION_LEFT: 59 | case ACTION_DOWN: 60 | case ACTION_UPRIGHT: 61 | case 
ACTION_UPLEFT: 62 | case ACTION_DOWNRIGHT: 63 | case ACTION_DOWNLEFT: 64 | case ACTION_UPFIRE: 65 | case ACTION_RIGHTFIRE: 66 | case ACTION_LEFTFIRE: 67 | case ACTION_DOWNFIRE: 68 | case ACTION_UPRIGHTFIRE: 69 | case ACTION_UPLEFTFIRE: 70 | case ACTION_DOWNRIGHTFIRE: 71 | case ACTION_DOWNLEFTFIRE: 72 | return true; 73 | default: 74 | return false; 75 | } 76 | } 77 | 78 | template 79 | CULE_ANNOTATION 80 | int32_t lives(State& s) 81 | { 82 | int lives = cule::atari::ram::read(s.ram, 0xA8); 83 | return (lives & 0xF) + 1; 84 | } 85 | 86 | template 87 | CULE_ANNOTATION 88 | void setTerminal(State& s) 89 | { 90 | // update terminal status 91 | int termination_flag = cule::atari::ram::read(s.ram, 0xB4); 92 | int lives = cule::atari::ram::read(s.ram, 0xA8); 93 | s.tiaFlags.template change((lives == 0) && (termination_flag == 0xFF)); 94 | } 95 | 96 | template 97 | CULE_ANNOTATION 98 | int32_t score(State& s) 99 | { 100 | // compute the score: each dead squadron counts as 12 tanks 101 | int dead_squadrons = cule::atari::ram::read(s.ram, 0xB6); 102 | int dead_tanks = cule::atari::ram::read(s.ram, 0xB5); 103 | return dead_squadrons * 12 + dead_tanks; 104 | } 105 | 106 | template 107 | CULE_ANNOTATION 108 | int32_t reward(State& s) 109 | { 110 | return score(s) - s.score; 111 | } 112 | 113 | } // end namespace robotank 114 | } // end namespace games 115 | } // end namespace atari 116 | } // end namespace cule 117 | 118 | -------------------------------------------------------------------------------- /cule/atari/games/seaquest.hpp: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | #include 5 | #include 6 | 7 | namespace cule 8 | { 9 | namespace atari 10 | { 11 | namespace games 12 | { 13 | namespace seaquest 14 | { 15 | 16 | template 17 | CULE_ANNOTATION 18 | void reset(State& s) 19 | { 20 | s.m_reward = 0; 21 | s.m_score = 0; 22 | s.tiaFlags.clear(FLAG_ALE_TERMINAL); 23 | s.m_lives = 4; 24 | } 25 | 26 | template 27 | CULE_ANNOTATION 28 | void 
step(State& s) 29 | { 30 | using cule::atari::games::getDecimalScore; 31 | using cule::atari::ram::read; 32 | 33 | // update the reward 34 | int score = getDecimalScore(s, 0xBA, 0xB9, 0xB8); 35 | s.m_reward = score - s.m_score; 36 | s.m_score = score; 37 | 38 | // update terminal status 39 | s.tiaFlags.template change(ram::read(s, 0xA3) != 0); 40 | s.m_lives = ram::read(s, 0xBB) + 1; 41 | } 42 | 43 | CULE_ANNOTATION 44 | bool isMinimal(const Action &a) 45 | { 46 | switch (a) 47 | { 48 | case ACTION_NOOP: 49 | case ACTION_FIRE: 50 | case ACTION_UP: 51 | case ACTION_RIGHT: 52 | case ACTION_LEFT: 53 | case ACTION_DOWN: 54 | case ACTION_UPRIGHT: 55 | case ACTION_UPLEFT: 56 | case ACTION_DOWNRIGHT: 57 | case ACTION_DOWNLEFT: 58 | case ACTION_UPFIRE: 59 | case ACTION_RIGHTFIRE: 60 | case ACTION_LEFTFIRE: 61 | case ACTION_DOWNFIRE: 62 | case ACTION_UPRIGHTFIRE: 63 | case ACTION_UPLEFTFIRE: 64 | case ACTION_DOWNRIGHTFIRE: 65 | case ACTION_DOWNLEFTFIRE: 66 | return true; 67 | default: 68 | return false; 69 | } 70 | } 71 | 72 | template 73 | CULE_ANNOTATION 74 | int32_t lives(State& s) 75 | { 76 | return cule::atari::ram::read(s.ram, 0xBB) + 1; 77 | } 78 | 79 | template 80 | CULE_ANNOTATION 81 | void setTerminal(State& s) 82 | { 83 | // update terminal status 84 | s.tiaFlags.template change(cule::atari::ram::read(s.ram, 0xA3) != 0); 85 | } 86 | 87 | template 88 | CULE_ANNOTATION 89 | int32_t score(State& s) 90 | { 91 | return cule::atari::games::getDecimalScore(s, 0xBA, 0xB9, 0xB8); 92 | } 93 | 94 | template 95 | CULE_ANNOTATION 96 | int32_t reward(State& s) 97 | { 98 | return score(s) - s.score; 99 | } 100 | 101 | } // end namespace seaquest 102 | } // end namespace games 103 | } // end namespace atari 104 | } // end namespace cule 105 | 106 | -------------------------------------------------------------------------------- /cule/atari/games/skiing.hpp: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | 
#include 5 | #include 6 | 7 | namespace cule 8 | { 9 | namespace atari 10 | { 11 | namespace games 12 | { 13 | namespace skiing 14 | { 15 | 16 | template 17 | CULE_ANNOTATION 18 | void reset(State& s) 19 | { 20 | s.m_reward = 0; 21 | s.m_score = 0; 22 | s.tiaFlags.clear(FLAG_ALE_TERMINAL); 23 | } 24 | 25 | template 26 | CULE_ANNOTATION 27 | void step(State& s) 28 | { 29 | using cule::atari::games::getDecimalScore; 30 | using cule::atari::ram::read; 31 | 32 | // update the reward 33 | int centiseconds = getDecimalScore(s, 0xEA, 0xE9); 34 | int minutes = ram::read(s, 0xE8); 35 | int score = minutes * 6000 + centiseconds; 36 | int reward = s.m_score - score; // negative reward for time 37 | s.m_reward = reward; 38 | s.m_score = score; 39 | 40 | // update terminal status 41 | int end_flag = ram::read(s, 0x91); 42 | s.tiaFlags.template change(end_flag == 0xFF); 43 | } 44 | 45 | CULE_ANNOTATION 46 | bool isMinimal(const Action &a) 47 | { 48 | switch (a) 49 | { 50 | case ACTION_NOOP: 51 | case ACTION_RIGHT: 52 | case ACTION_LEFT: 53 | return true; 54 | default: 55 | return false; 56 | } 57 | } 58 | 59 | template 60 | CULE_ANNOTATION 61 | int32_t lives(State&) 62 | { 63 | return 0; 64 | } 65 | 66 | template 67 | CULE_ANNOTATION 68 | void setTerminal(State& s) 69 | { 70 | // update terminal status 71 | int end_flag = cule::atari::ram::read(s.ram, 0x91); 72 | s.tiaFlags.template change(end_flag == 0xFF); 73 | } 74 | 75 | template 76 | CULE_ANNOTATION 77 | int32_t score(State& s) 78 | { 79 | // compute the elapsed time in centiseconds (6000 per minute) 80 | int centiseconds = cule::atari::games::getDecimalScore(s, 0xEA, 0xE9); 81 | int minutes = cule::atari::ram::read(s.ram, 0xE8); 82 | return minutes * 6000 + centiseconds; 83 | } 84 | 85 | template 86 | CULE_ANNOTATION 87 | int32_t reward(State& s) 88 | { 89 | return s.score - score(s); 90 | } 91 | 92 | } // end namespace skiing 93 | } // end namespace games 94 | } // end namespace atari 95 | } // end namespace cule 96 | 97 | 
-------------------------------------------------------------------------------- /cule/atari/games/solaris.hpp: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | #include 5 | #include 6 | 7 | namespace cule 8 | { 9 | namespace atari 10 | { 11 | namespace games 12 | { 13 | namespace solaris 14 | { 15 | 16 | template 17 | CULE_ANNOTATION 18 | void reset(State& s) 19 | { 20 | s.m_reward = 0; 21 | s.m_score = 0; 22 | s.tiaFlags.clear(FLAG_ALE_TERMINAL); 23 | s.m_lives = 3; 24 | } 25 | 26 | template 27 | CULE_ANNOTATION 28 | void step(State& s) 29 | { 30 | using cule::atari::games::getDecimalScore; 31 | using cule::atari::ram::read; 32 | 33 | // update the reward 34 | // only 5 digits are displayed but we keep track of 6 digits 35 | int score = getDecimalScore(s, 0xDC, 0xDD, 0xDE); 36 | score *= 10; 37 | int reward = score - s.m_score; 38 | s.m_reward = reward; 39 | s.m_score = score; 40 | 41 | // update terminal status 42 | int lives_byte = ram::read(s, 0xD9); 43 | s.tiaFlags.template change(lives_byte == 0); 44 | 45 | s.m_lives = lives_byte & 0xF; 46 | } 47 | 48 | CULE_ANNOTATION 49 | bool isMinimal(const Action &a) 50 | { 51 | switch (a) 52 | { 53 | case ACTION_NOOP: 54 | case ACTION_FIRE: 55 | case ACTION_UP: 56 | case ACTION_RIGHT: 57 | case ACTION_LEFT: 58 | case ACTION_DOWN: 59 | case ACTION_UPRIGHT: 60 | case ACTION_UPLEFT: 61 | case ACTION_DOWNRIGHT: 62 | case ACTION_DOWNLEFT: 63 | case ACTION_UPFIRE: 64 | case ACTION_RIGHTFIRE: 65 | case ACTION_LEFTFIRE: 66 | case ACTION_DOWNFIRE: 67 | case ACTION_UPRIGHTFIRE: 68 | case ACTION_UPLEFTFIRE: 69 | case ACTION_DOWNRIGHTFIRE: 70 | case ACTION_DOWNLEFTFIRE: 71 | return true; 72 | default: 73 | return false; 74 | } 75 | } 76 | 77 | template 78 | CULE_ANNOTATION 79 | int32_t lives(State& s) 80 | { 81 | // read the lives counter (low nibble of RAM 0xD9) 82 | int32_t lives_byte = cule::atari::ram::read(s.ram, 0xD9); 83 | return lives_byte & 0xF; 84 | } 85 | 86 | template 87 | 
CULE_ANNOTATION 88 | void setTerminal(State& s) 89 | { 90 | // update terminal status 91 | s.tiaFlags.template change(lives(s) == 0); 92 | } 93 | 94 | template 95 | CULE_ANNOTATION 96 | int32_t score(State& s) 97 | { 98 | return 10 * cule::atari::games::getDecimalScore(s, 0xDC, 0xDD, 0xDE); 99 | } 100 | 101 | template 102 | CULE_ANNOTATION 103 | int32_t reward(State& s) 104 | { 105 | return score(s) - s.score; 106 | } 107 | 108 | } // end namespace solaris 109 | } // end namespace games 110 | } // end namespace atari 111 | } // end namespace cule 112 | 113 | -------------------------------------------------------------------------------- /cule/atari/games/spaceinvaders.hpp: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | #include 5 | #include 6 | 7 | namespace cule 8 | { 9 | namespace atari 10 | { 11 | namespace games 12 | { 13 | namespace spaceinvaders 14 | { 15 | 16 | template 17 | CULE_ANNOTATION 18 | void reset(State& s) 19 | { 20 | s.m_reward = 0; 21 | s.m_score = 0; 22 | s.m_lives = 3; 23 | s.tiaFlags.clear(FLAG_ALE_TERMINAL); 24 | } 25 | 26 | template 27 | CULE_ANNOTATION 28 | void step(State& s) 29 | { 30 | using cule::atari::games::getDecimalScore; 31 | using cule::atari::ram::read; 32 | 33 | // update the reward 34 | int score = getDecimalScore(s, 0xE8, 0xE6); 35 | 36 | // reward cannot get negative in this game. 
When it does, it means that the score has looped 37 | // (overflow) 38 | s.m_reward = score - s.m_score; 39 | if(s.m_reward < 0) 40 | { 41 | // 10000 is the highest possible score 42 | const int maximumScore = 10000; 43 | s.m_reward = (maximumScore - s.m_score) + score; 44 | } 45 | s.m_score = score; 46 | s.m_lives = ram::read(s, 0xC9); 47 | 48 | // update terminal status 49 | // If bit 0x80 is on, then game is over 50 | int some_byte = ram::read(s, 0x98); 51 | s.tiaFlags.template change((some_byte & 0x80) || (s.m_lives == 0)); 52 | } 53 | 54 | CULE_ANNOTATION 55 | bool isMinimal(const Action &a) 56 | { 57 | switch (a) 58 | { 59 | case ACTION_NOOP: 60 | case ACTION_LEFT: 61 | case ACTION_RIGHT: 62 | case ACTION_FIRE: 63 | case ACTION_LEFTFIRE: 64 | case ACTION_RIGHTFIRE: 65 | return true; 66 | default: 67 | return false; 68 | } 69 | } 70 | 71 | template 72 | CULE_ANNOTATION 73 | int32_t lives(State& s) 74 | { 75 | return cule::atari::ram::read(s.ram, 0xC9); 76 | } 77 | 78 | template 79 | CULE_ANNOTATION 80 | void setTerminal(State& s) 81 | { 82 | // update terminal status 83 | // If bit 0x80 is on, then game is over 84 | int some_byte = cule::atari::ram::read(s.ram, 0x98); 85 | s.tiaFlags.template change((some_byte & 0x80) || (lives(s) == 0)); 86 | } 87 | 88 | template 89 | CULE_ANNOTATION 90 | int32_t score(State& s) 91 | { 92 | return cule::atari::games::getDecimalScore(s, 0xE8, 0xE6); 93 | } 94 | 95 | template 96 | CULE_ANNOTATION 97 | int32_t reward(State& s) 98 | { 99 | return score(s) - s.score; 100 | } 101 | 102 | } // end namespace spaceinvaders 103 | } // end namespace games 104 | } // end namespace atari 105 | } // end namespace cule 106 | 107 | -------------------------------------------------------------------------------- /cule/atari/games/timepilot.hpp: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | #include 5 | #include 6 | 7 | namespace cule 8 | { 9 | namespace atari 10 | { 11 | 
namespace games 12 | { 13 | namespace timepilot 14 | { 15 | 16 | template 17 | CULE_ANNOTATION 18 | void reset(State& s) 19 | { 20 | s.m_reward = 0; 21 | s.m_score = 0; 22 | s.tiaFlags.clear(FLAG_ALE_TERMINAL); 23 | s.m_lives = 5; 24 | } 25 | 26 | template 27 | CULE_ANNOTATION 28 | void step(State& s) 29 | { 30 | using cule::atari::games::getDecimalScore; 31 | using cule::atari::ram::read; 32 | 33 | // update the reward 34 | int score = getDecimalScore(s, 0x8D, 0x8F); 35 | score *= 100; 36 | int reward = score - s.m_score; 37 | s.m_reward = reward; 38 | s.m_score = score; 39 | 40 | int lives_byte = ram::read(s, 0x8B) & 0x7; 41 | int screen_byte = ram::read(s, 0x80) & 0xF; 42 | 43 | // update terminal status 44 | s.tiaFlags.template change(ram::read(s, 0xA0)); 45 | // Only update lives when actually flying; otherwise funny stuff happens 46 | s.m_lives = (screen_byte == 2) ? (lives_byte + 1) : s.m_lives; 47 | } 48 | 49 | CULE_ANNOTATION 50 | bool isMinimal(const Action &a) 51 | { 52 | switch (a) 53 | { 54 | case ACTION_NOOP: 55 | case ACTION_FIRE: 56 | case ACTION_UP: 57 | case ACTION_RIGHT: 58 | case ACTION_LEFT: 59 | case ACTION_DOWN: 60 | case ACTION_UPFIRE: 61 | case ACTION_RIGHTFIRE: 62 | case ACTION_LEFTFIRE: 63 | case ACTION_DOWNFIRE: 64 | return true; 65 | default: 66 | return false; 67 | } 68 | } 69 | 70 | template 71 | CULE_ANNOTATION 72 | int32_t lives(State& s) 73 | { 74 | int lives_byte = cule::atari::ram::read(s.ram, 0x8B) & 0x7; 75 | return lives_byte + 1; 76 | } 77 | 78 | template 79 | CULE_ANNOTATION 80 | void setTerminal(State& s) 81 | { 82 | // update terminal status 83 | s.tiaFlags.template change(cule::atari::ram::read(s.ram, 0xA0)); 84 | } 85 | 86 | template 87 | CULE_ANNOTATION 88 | int32_t score(State& s) 89 | { 90 | return 100 * cule::atari::games::getDecimalScore(s, 0x8D, 0x8F); 91 | } 92 | 93 | template 94 | CULE_ANNOTATION 95 | int32_t reward(State& s) 96 | { 97 | return score(s) - s.score; 98 | } 99 | 100 | } // end namespace timepilot 
101 | } // end namespace games 102 | } // end namespace atari 103 | } // end namespace cule 104 | 105 | -------------------------------------------------------------------------------- /cule/atari/games/tutankham.hpp: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | #include 5 | #include 6 | 7 | namespace cule 8 | { 9 | namespace atari 10 | { 11 | namespace games 12 | { 13 | namespace tutankham 14 | { 15 | 16 | template 17 | CULE_ANNOTATION 18 | void reset(State& s) 19 | { 20 | s.m_reward = 0; 21 | s.m_score = 0; 22 | s.tiaFlags.clear(FLAG_ALE_TERMINAL); 23 | s.m_lives = 3; 24 | } 25 | 26 | template 27 | CULE_ANNOTATION 28 | void step(State& s) 29 | { 30 | using cule::atari::games::getDecimalScore; 31 | using cule::atari::ram::read; 32 | 33 | // update the reward 34 | int score = getDecimalScore(s, 0x9C, 0x9A); 35 | int reward = score - s.m_score; 36 | s.m_reward = reward; 37 | s.m_score = score; 38 | 39 | // update terminal status 40 | int lives_byte = ram::read(s, 0x9E); 41 | // byte 0x81 is set to 0x84 when the game is loaded, but not reset 42 | int some_byte = ram::read(s, 0x81); 43 | 44 | s.tiaFlags.template change((lives_byte == 0) && (some_byte != 0x84)); 45 | 46 | s.m_lives = (lives_byte & 0x3); 47 | } 48 | 49 | CULE_ANNOTATION 50 | bool isMinimal(const Action &a) 51 | { 52 | switch (a) 53 | { 54 | case ACTION_NOOP: 55 | case ACTION_UP: 56 | case ACTION_RIGHT: 57 | case ACTION_LEFT: 58 | case ACTION_DOWN: 59 | case ACTION_UPFIRE: 60 | case ACTION_RIGHTFIRE: 61 | case ACTION_LEFTFIRE: 62 | return true; 63 | default: 64 | return false; 65 | } 66 | } 67 | 68 | template 69 | CULE_ANNOTATION 70 | int32_t lives(State& s) 71 | { 72 | // update terminal status 73 | int lives_byte = cule::atari::ram::read(s.ram, 0x9E); 74 | return (lives_byte & 0x3); 75 | } 76 | 77 | template 78 | CULE_ANNOTATION 79 | void setTerminal(State& s) 80 | { 81 | // update terminal status 82 | int lives_byte = 
cule::atari::ram::read(s.ram, 0x9E); 83 | // byte 0x81 is set to 0x84 when the game is loaded, but not reset 84 | int some_byte = cule::atari::ram::read(s.ram, 0x81); 85 | 86 | s.tiaFlags.template change((lives_byte == 0) && (some_byte != 0x84)); 87 | } 88 | 89 | template 90 | CULE_ANNOTATION 91 | int32_t score(State& s) 92 | { 93 | return cule::atari::games::getDecimalScore(s, 0x9C, 0x9A); 94 | } 95 | 96 | template 97 | CULE_ANNOTATION 98 | int32_t reward(State& s) 99 | { 100 | return score(s) - s.score; 101 | } 102 | 103 | } // end namespace tutankham 104 | } // end namespace games 105 | } // end namespace atari 106 | } // end namespace cule 107 | 108 | -------------------------------------------------------------------------------- /cule/atari/games/upndown.hpp: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | #include 5 | #include 6 | #include 7 | 8 | namespace cule 9 | { 10 | namespace atari 11 | { 12 | namespace games 13 | { 14 | namespace upndown 15 | { 16 | 17 | template 18 | CULE_ANNOTATION 19 | void reset(State& s) 20 | { 21 | s.m_reward = 0; 22 | s.m_score = 0; 23 | s.tiaFlags.clear(FLAG_ALE_TERMINAL); 24 | s.m_lives = 5; 25 | } 26 | 27 | template 28 | CULE_ANNOTATION 29 | void step(State& s) 30 | { 31 | using cule::atari::games::getDecimalScore; 32 | using cule::atari::ram::read; 33 | 34 | // update the reward 35 | int score = getDecimalScore(s, 0x82, 0x81, 0x80); 36 | int reward = score - s.m_score; 37 | s.m_reward = reward; 38 | s.m_score = score; 39 | 40 | // update terminal status 41 | int lives_byte = ram::read(s, 0x86) & 0xF; 42 | int death_timer = ram::read(s, 0x94); 43 | s.tiaFlags.template change((death_timer > 0x40) && (lives_byte == 0)); 44 | 45 | s.m_lives = lives_byte + 1; 46 | } 47 | 48 | CULE_ANNOTATION 49 | bool isMinimal(const Action &a) 50 | { 51 | switch (a) 52 | { 53 | case ACTION_NOOP: 54 | case ACTION_FIRE: 55 | case ACTION_UP: 56 | case ACTION_DOWN: 57 | case 
ACTION_UPFIRE: 58 | case ACTION_DOWNFIRE: 59 | return true; 60 | default: 61 | return false; 62 | } 63 | } 64 | 65 | template 66 | CULE_ANNOTATION 67 | int32_t lives(State& s) 68 | { 69 | int lives_byte = cule::atari::ram::read(s.ram, 0x86) & 0xF; 70 | return lives_byte + 1; 71 | } 72 | 73 | 74 | template 75 | CULE_ANNOTATION 76 | void setTerminal(State& s) 77 | { 78 | // update terminal status 79 | int lives_byte = cule::atari::ram::read(s.ram, 0x86) & 0xF; 80 | int death_timer = cule::atari::ram::read(s.ram, 0x94); 81 | s.tiaFlags.template change((death_timer > 0x40) && (lives_byte == 0)); 82 | } 83 | 84 | template 85 | CULE_ANNOTATION 86 | int32_t score(State& s) 87 | { 88 | return cule::atari::games::getDecimalScore(s, 0x82, 0x81, 0x80); 89 | } 90 | 91 | template 92 | CULE_ANNOTATION 93 | int32_t reward(State& s) 94 | { 95 | return score(s) - s.score; 96 | } 97 | 98 | } // end namespace upndown 99 | } // end namespace games 100 | } // end namespace atari 101 | } // end namespace cule 102 | 103 | -------------------------------------------------------------------------------- /cule/atari/games/venture.hpp: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | #include 5 | #include 6 | 7 | namespace cule 8 | { 9 | namespace atari 10 | { 11 | namespace games 12 | { 13 | namespace venture 14 | { 15 | 16 | template 17 | CULE_ANNOTATION 18 | void reset(State& s) 19 | { 20 | s.m_reward = 0; 21 | s.m_score = 0; 22 | s.tiaFlags.clear(FLAG_ALE_TERMINAL); 23 | s.m_lives = 4; 24 | } 25 | 26 | template 27 | CULE_ANNOTATION 28 | void step(State& s) 29 | { 30 | using cule::atari::games::getDecimalScore; 31 | using cule::atari::ram::read; 32 | 33 | // update the reward 34 | int score = getDecimalScore(s, 0xC8, 0xC7); 35 | score *= 100; 36 | int reward = score - s.m_score; 37 | s.m_reward = reward; 38 | s.m_score = score; 39 | 40 | // update terminal status 41 | int lives_byte = ram::read(s, 0xC6); 42 | int 
audio_byte = ram::read(s, 0xCD); 43 | int death_byte = ram::read(s, 0xBF); 44 | s.tiaFlags.template change((lives_byte == 0) && (audio_byte == 0xFF) && (death_byte & 0x80)); 45 | 46 | s.m_lives = (lives_byte & 0x7) + 1; 47 | } 48 | 49 | CULE_ANNOTATION 50 | bool isMinimal(const Action &a) 51 | { 52 | switch (a) 53 | { 54 | case ACTION_NOOP: 55 | case ACTION_FIRE: 56 | case ACTION_UP: 57 | case ACTION_RIGHT: 58 | case ACTION_LEFT: 59 | case ACTION_DOWN: 60 | case ACTION_UPRIGHT: 61 | case ACTION_UPLEFT: 62 | case ACTION_DOWNRIGHT: 63 | case ACTION_DOWNLEFT: 64 | case ACTION_UPFIRE: 65 | case ACTION_RIGHTFIRE: 66 | case ACTION_LEFTFIRE: 67 | case ACTION_DOWNFIRE: 68 | case ACTION_UPRIGHTFIRE: 69 | case ACTION_UPLEFTFIRE: 70 | case ACTION_DOWNRIGHTFIRE: 71 | case ACTION_DOWNLEFTFIRE: 72 | return true; 73 | default: 74 | return false; 75 | } 76 | } 77 | 78 | template 79 | CULE_ANNOTATION 80 | int32_t lives(State& s) 81 | { 82 | // update terminal status 83 | int lives_byte = cule::atari::ram::read(s.ram, 0xC6); 84 | return (lives_byte & 0x7) + 1; 85 | } 86 | 87 | template 88 | CULE_ANNOTATION 89 | void setTerminal(State& s) 90 | { 91 | // update terminal status 92 | int lives_byte = cule::atari::ram::read(s.ram, 0xC6); 93 | int audio_byte = cule::atari::ram::read(s.ram, 0xCD); 94 | int death_byte = cule::atari::ram::read(s.ram, 0xBF); 95 | s.tiaFlags.template change((lives_byte == 0) && (audio_byte == 0xFF) && (death_byte & 0x80)); 96 | } 97 | 98 | template 99 | CULE_ANNOTATION 100 | int32_t score(State& s) 101 | { 102 | return 100 * cule::atari::games::getDecimalScore(s, 0xC8, 0xC7); 103 | } 104 | 105 | template 106 | CULE_ANNOTATION 107 | int32_t reward(State& s) 108 | { 109 | return score(s) - s.score; 110 | } 111 | 112 | } // end namespace venture 113 | } // end namespace games 114 | } // end namespace atari 115 | } // end namespace cule 116 | 117 | -------------------------------------------------------------------------------- /cule/atari/games/wizard.hpp: 
-------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | #include 5 | #include 6 | 7 | namespace cule 8 | { 9 | namespace atari 10 | { 11 | namespace games 12 | { 13 | namespace wizard 14 | { 15 | 16 | template 17 | CULE_ANNOTATION 18 | void reset(State& s) 19 | { 20 | s.m_reward = 0; 21 | s.m_score = 0; 22 | s.tiaFlags.clear(FLAG_ALE_TERMINAL); 23 | s.m_lives = 3; 24 | } 25 | 26 | template 27 | CULE_ANNOTATION 28 | void step(State& s) 29 | { 30 | using cule::atari::games::getDecimalScore; 31 | using cule::atari::ram::read; 32 | 33 | // update the reward 34 | int score = getDecimalScore(s, 0x86, 0x88); 35 | if (score >= 8000) score -= 8000; // MGB score does not go beyond 999 36 | score *= 100; 37 | s.m_reward = score - s.m_score; 38 | s.m_score = score; 39 | 40 | // update terminal status 41 | int newLives = ram::read(s, 0x8D) & 15; 42 | int byte1 = ram::read(s, 0xF4); 43 | 44 | bool isWaiting = (ram::read(s, 0xD7) & 0x1) == 0; 45 | 46 | s.tiaFlags.template change((newLives == 0) && (byte1 == 0xF8)); 47 | 48 | // Wizard of Wor decreases the life total when we move into the play field; we only 49 | // change the life total when we actually are waiting 50 | s.m_lives = isWaiting ? 
newLives : s.m_lives; 51 | } 52 | 53 | CULE_ANNOTATION 54 | bool isMinimal(const Action &a) 55 | { 56 | switch (a) 57 | { 58 | case ACTION_NOOP: 59 | case ACTION_FIRE: 60 | case ACTION_UP: 61 | case ACTION_RIGHT: 62 | case ACTION_LEFT: 63 | case ACTION_DOWN: 64 | case ACTION_UPFIRE: 65 | case ACTION_RIGHTFIRE: 66 | case ACTION_LEFTFIRE: 67 | case ACTION_DOWNFIRE: 68 | return true; 69 | default: 70 | return false; 71 | } 72 | } 73 | 74 | template 75 | CULE_ANNOTATION 76 | int32_t lives(State& s) 77 | { 78 | return cule::atari::ram::read(s.ram, 0x8D) & 15; 79 | } 80 | 81 | template 82 | CULE_ANNOTATION 83 | void setTerminal(State& s) 84 | { 85 | // update terminal status 86 | int newLives = cule::atari::ram::read(s.ram, 0x8D) & 15; 87 | int byte1 = cule::atari::ram::read(s.ram, 0xF4); 88 | 89 | s.tiaFlags.template change((newLives == 0) && (byte1 == 0xF8)); 90 | } 91 | 92 | template 93 | CULE_ANNOTATION 94 | int32_t score(State& s) 95 | { 96 | // update the reward 97 | int m_score = cule::atari::games::getDecimalScore(s, 0x86, 0x88); 98 | if (m_score >= 8000) m_score -= 8000; // MGB score does not go beyond 999 99 | m_score *= 100; 100 | 101 | return m_score; 102 | } 103 | 104 | template 105 | CULE_ANNOTATION 106 | int32_t reward(State& s) 107 | { 108 | return score(s) - s.score; 109 | } 110 | 111 | } // end namespace wizard 112 | } // end namespace games 113 | } // end namespace atari 114 | } // end namespace cule 115 | 116 | -------------------------------------------------------------------------------- /cule/atari/games/yarsrevenge.hpp: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | #include 5 | #include 6 | #include 7 | 8 | namespace cule 9 | { 10 | namespace atari 11 | { 12 | namespace games 13 | { 14 | namespace yarsrevenge 15 | { 16 | 17 | template 18 | CULE_ANNOTATION 19 | void reset(State& s) 20 | { 21 | s.m_reward = 0; 22 | s.m_score = 0; 23 | s.tiaFlags.clear(FLAG_ALE_TERMINAL); 24 | 
s.m_lives = 4; 25 | } 26 | 27 | template 28 | CULE_ANNOTATION 29 | void step(State& s) 30 | { 31 | using cule::atari::games::getDecimalScore; 32 | using cule::atari::ram::read; 33 | 34 | // update the reward 35 | int score = getDecimalScore(s, 0xE2, 0xE1, 0xE0); 36 | int reward = score - s.m_score; 37 | s.m_reward = reward; 38 | s.m_score = score; 39 | 40 | // update terminal status 41 | int lives_byte = ram::read(s, 0x9E) >> 4; 42 | s.tiaFlags.template change(lives_byte == 0); 43 | 44 | s.m_lives = lives_byte; 45 | } 46 | 47 | CULE_ANNOTATION 48 | bool isMinimal(const Action &a) 49 | { 50 | switch (a) 51 | { 52 | case ACTION_NOOP: 53 | case ACTION_FIRE: 54 | case ACTION_UP: 55 | case ACTION_RIGHT: 56 | case ACTION_LEFT: 57 | case ACTION_DOWN: 58 | case ACTION_UPRIGHT: 59 | case ACTION_UPLEFT: 60 | case ACTION_DOWNRIGHT: 61 | case ACTION_DOWNLEFT: 62 | case ACTION_UPFIRE: 63 | case ACTION_RIGHTFIRE: 64 | case ACTION_LEFTFIRE: 65 | case ACTION_DOWNFIRE: 66 | case ACTION_UPRIGHTFIRE: 67 | case ACTION_UPLEFTFIRE: 68 | case ACTION_DOWNRIGHTFIRE: 69 | case ACTION_DOWNLEFTFIRE: 70 | return true; 71 | default: 72 | return false; 73 | } 74 | } 75 | 76 | template 77 | CULE_ANNOTATION 78 | int32_t lives(State& s) 79 | { 80 | return cule::atari::ram::read(s.ram, 0x9E) >> 4; 81 | } 82 | 83 | template 84 | CULE_ANNOTATION 85 | void setTerminal(State& s) 86 | { 87 | // update terminal status 88 | s.tiaFlags.template change(lives(s) == 0); 89 | } 90 | 91 | template 92 | CULE_ANNOTATION 93 | int32_t score(State& s) 94 | { 95 | return cule::atari::games::getDecimalScore(s, 0xE2, 0xE1, 0xE0); 96 | } 97 | 98 | template 99 | CULE_ANNOTATION 100 | int32_t reward(State& s) 101 | { 102 | return score(s) - s.score; 103 | } 104 | 105 | } // end namespace yarsrevenge 106 | } // end namespace games 107 | } // end namespace atari 108 | } // end namespace cule 109 | 110 | -------------------------------------------------------------------------------- /cule/atari/games/zaxxon.hpp: 
-------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | #include 5 | #include 6 | 7 | namespace cule 8 | { 9 | namespace atari 10 | { 11 | namespace games 12 | { 13 | namespace zaxxon 14 | { 15 | 16 | template 17 | CULE_ANNOTATION 18 | void reset(State& s) 19 | { 20 | s.m_reward = 0; 21 | s.m_score = 0; 22 | s.tiaFlags.clear(FLAG_ALE_TERMINAL); 23 | s.m_lives = 5; 24 | } 25 | 26 | template 27 | CULE_ANNOTATION 28 | void step(State& s) 29 | { 30 | using cule::atari::games::getDecimalScore; 31 | using cule::atari::ram::read; 32 | 33 | // update the reward 34 | int score = getDecimalScore(s, 0xE9, 0xE8); 35 | score *= 100; 36 | int reward = score - s.m_score; 37 | s.m_reward = reward; 38 | s.m_score = score; 39 | 40 | // update terminal status 41 | int lives_byte = ram::read(s, 0xEA) & 0x7; 42 | // Note - this *requires* a reset at load time; lives are set to 0 before 43 | // reset is pushed 44 | s.tiaFlags.template change(lives_byte == 0); 45 | 46 | s.m_lives = lives_byte; 47 | } 48 | 49 | CULE_ANNOTATION 50 | bool isMinimal(const Action &a) 51 | { 52 | switch (a) 53 | { 54 | case ACTION_NOOP: 55 | case ACTION_FIRE: 56 | case ACTION_UP: 57 | case ACTION_RIGHT: 58 | case ACTION_LEFT: 59 | case ACTION_DOWN: 60 | case ACTION_UPRIGHT: 61 | case ACTION_UPLEFT: 62 | case ACTION_DOWNRIGHT: 63 | case ACTION_DOWNLEFT: 64 | case ACTION_UPFIRE: 65 | case ACTION_RIGHTFIRE: 66 | case ACTION_LEFTFIRE: 67 | case ACTION_DOWNFIRE: 68 | case ACTION_UPRIGHTFIRE: 69 | case ACTION_UPLEFTFIRE: 70 | case ACTION_DOWNRIGHTFIRE: 71 | case ACTION_DOWNLEFTFIRE: 72 | return true; 73 | default: 74 | return false; 75 | } 76 | } 77 | 78 | template 79 | CULE_ANNOTATION 80 | int32_t lives(State& s) 81 | { 82 | return cule::atari::ram::read(s.ram, 0xEA) & 0x7; 83 | } 84 | 85 | template 86 | CULE_ANNOTATION 87 | void setTerminal(State& s) 88 | { 89 | // update terminal status 90 | int lives_byte = cule::atari::ram::read(s.ram, 0xEA) & 0x7; 
91 | // Note - this *requires* a reset at load time; lives are set to 0 before 92 | // reset is pushed 93 | s.tiaFlags.template change(lives_byte == 0); 94 | } 95 | 96 | template 97 | CULE_ANNOTATION 98 | int32_t score(State& s) 99 | { 100 | return 100 * cule::atari::games::getDecimalScore(s, 0xE9, 0xE8); 101 | } 102 | 103 | template 104 | CULE_ANNOTATION 105 | int32_t reward(State& s) 106 | { 107 | return score(s) - s.score; 108 | } 109 | 110 | } // end namespace zaxxon 111 | } // end namespace games 112 | } // end namespace atari 113 | } // end namespace cule 114 | 115 | -------------------------------------------------------------------------------- /cule/atari/internals.hpp: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | 5 | #include 6 | #include 7 | 8 | namespace cule 9 | { 10 | namespace atari 11 | { 12 | 13 | // internal type definitions 14 | #ifdef FAST_TYPE 15 | typedef unsigned _addr16_t; 16 | typedef unsigned _addr15_t; 17 | typedef unsigned _addr14_t; 18 | typedef unsigned _addr8_t; 19 | typedef unsigned _reg8_t; 20 | typedef unsigned _alutemp_t; 21 | typedef unsigned byte_t; 22 | typedef unsigned word_t; 23 | typedef unsigned uint_t; 24 | #else // FAST_TYPE 25 | typedef uint16_t _addr16_t; 26 | typedef uint16_t _addr15_t; 27 | typedef uint16_t _addr14_t; 28 | typedef uint8_t _addr8_t; 29 | typedef uint8_t _reg8_t; 30 | typedef uint16_t _alutemp_t; 31 | typedef uint8_t byte_t; 32 | typedef uint16_t word_t; 33 | typedef uint32_t uint_t; 34 | #endif // EXACT_TYPE 35 | 36 | // address 37 | typedef bit_field<_addr16_t,16> maddr_t; 38 | typedef bit_field<_addr15_t,15> scroll_t, addr15_t; 39 | typedef bit_field<_addr14_t,14> vaddr_t, addr14_t; 40 | typedef bit_field<_addr8_t,8> maddr8_t, saddr_t; 41 | 42 | // cpu 43 | typedef uint8_t opcode_t; 44 | typedef uint8_t operand_t; 45 | 46 | // alu 47 | typedef bit_field operandb_t; 48 | typedef bit_field operandw_t; 49 | typedef 
bit_field<_alutemp_t,8> alu_t; 50 | 51 | // color 52 | typedef uint32_t rgb32_t; 53 | typedef uint16_t rgb16_t, rgb15_t; 54 | typedef bit_field<_reg8_t, 8> reg_bit_field_t; 55 | 56 | // others 57 | typedef bit_field offset3_t; 58 | typedef bit_field offset10_t; 59 | 60 | using sys_flag_t = flag_set; 61 | using tia_flag_t = flag_set; 62 | using collision_t = flag_set; 63 | 64 | } // end namespace atari 65 | } // end namespace cule 66 | 67 | -------------------------------------------------------------------------------- /cule/atari/joystick.hpp: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | 5 | #include 6 | 7 | #include 8 | 9 | namespace cule 10 | { 11 | namespace atari 12 | { 13 | namespace joystick 14 | { 15 | 16 | template 17 | CULE_ANNOTATION 18 | void reset(State&){} 19 | 20 | template 21 | CULE_ANNOTATION 22 | void applyAction(State&){} 23 | 24 | /** 25 | Read the value of the specified digital pin for this controller. 26 | 27 | @param pin The pin of the controller jack to read 28 | @return The State of the pin 29 | */ 30 | template 31 | CULE_ANNOTATION 32 | bool read(State& s, const Control_Jack& jack, const Control_DigitalPin& pin) 33 | { 34 | if(s.sysFlags[FLAG_CON_SWAP] == (jack == Control_Left)) return true; 35 | 36 | switch(pin) 37 | { 38 | case Control_One: 39 | return s.sysFlags[FLAG_CON_UP] == 0; 40 | 41 | case Control_Two: 42 | return s.sysFlags[FLAG_CON_DOWN] == 0; 43 | 44 | case Control_Three: 45 | return s.sysFlags[FLAG_CON_LEFT] == 0; 46 | 47 | case Control_Four: 48 | return s.sysFlags[FLAG_CON_RIGHT] == 0; 49 | 50 | case Control_Six: 51 | return s.sysFlags[FLAG_CON_FIRE] == 0; 52 | } 53 | 54 | return false; 55 | } 56 | 57 | /** 58 | Read the resistance at the specified analog pin for this controller. 59 | The returned value is the resistance measured in ohms. 
60 | 61 | @param pin The pin of the controller jack to read 62 | @return The resistance at the specified pin 63 | */ 64 | template 65 | CULE_ANNOTATION 66 | int32_t read(State&, const Control_Jack&, const Control_AnalogPin&) 67 | { 68 | return Control_maximumResistance; 69 | } 70 | 71 | } // end namespace joystick 72 | } // end namespace atari 73 | } // end namespace cule 74 | 75 | -------------------------------------------------------------------------------- /cule/atari/paddles.hpp: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | #include 5 | 6 | #include 7 | 8 | #define PADDLE_DELTA 23000 9 | 10 | // MGB Values taken from Paddles.cxx (Stella 3.3) - 1400000 * [5,235] / 255 11 | #define PADDLE_MIN 27450 12 | 13 | // MGB - was 1290196; updated to 790196... seems to be fine for breakout and pong; 14 | // avoids pong paddle going off screen 15 | #define PADDLE_MAX 790196 16 | 17 | #define PADDLE_DEFAULT_VALUE (((PADDLE_MAX - PADDLE_MIN) / 2) + PADDLE_MIN) 18 | 19 | namespace cule 20 | { 21 | namespace atari 22 | { 23 | namespace paddles 24 | { 25 | 26 | template 27 | CULE_ANNOTATION 28 | void reset(State& s) 29 | { 30 | s.resistance = PADDLE_DEFAULT_VALUE; 31 | } 32 | 33 | /** Applies paddle actions. This actually modifies the game State by updating the paddle 34 | * resistances. */ 35 | template 36 | CULE_ANNOTATION 37 | void applyAction(State& s) 38 | { 39 | // First compute whether we should increase or decrease the paddle position 40 | s.resistance += (s.sysFlags[FLAG_CON_LEFT] - s.sysFlags[FLAG_CON_RIGHT]) * PADDLE_DELTA; 41 | 42 | // Now update the paddle position 43 | s.resistance = cule::max(cule::min(s.resistance, PADDLE_MAX), PADDLE_MIN); 44 | } 45 | 46 | /** 47 | Read the value of the specified digital pin for this controller. 
48 | 49 | @param pin The pin of the controller jack to read 50 | @return The State of the pin 51 | */ 52 | template 53 | CULE_ANNOTATION 54 | bool read(State& s, const Control_Jack& jack, const Control_DigitalPin& pin) 55 | { 56 | return (jack == Control_Right) || 57 | ((pin != Control_Three) && (pin != Control_Four)) || 58 | (s.sysFlags[FLAG_CON_SWAP] == (pin == Control_Four)) || 59 | (s.sysFlags[FLAG_CON_FIRE]==0); 60 | } 61 | 62 | /** 63 | Read the resistance at the specified analog pin for this controller. 64 | The returned value is the resistance measured in ohms. 65 | 66 | @param pin The pin of the controller jack to read 67 | @return The resistance at the specified pin 68 | */ 69 | template 70 | CULE_ANNOTATION 71 | int32_t read(State& s, const Control_Jack& jack, const Control_AnalogPin& pin) 72 | { 73 | return (jack != Control_Right) * ((s.sysFlags[FLAG_CON_SWAP] == (pin == Control_Five)) ? s.resistance : PADDLE_DEFAULT_VALUE); 74 | } 75 | 76 | } // end namespace paddles 77 | } // end namespace atari 78 | } // end namespace cule 79 | 80 | -------------------------------------------------------------------------------- /cule/atari/prng.hpp: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | 5 | #define PCG_DEFAULT_MULTIPLIER_32 747796405U 6 | #define PCG_DEFAULT_INCREMENT_32 2891336453U 7 | 8 | #ifndef CURAND_2POW32_INV 9 | #define CURAND_2POW32_INV (2.3283064e-10f) 10 | #endif 11 | 12 | namespace cule 13 | { 14 | namespace atari 15 | { 16 | 17 | class prng 18 | { 19 | public: 20 | 21 | CULE_ANNOTATION 22 | prng(uint32_t& state) 23 | : state(state) 24 | {} 25 | 26 | CULE_ANNOTATION 27 | void initialize(const uint32_t seed) 28 | { 29 | state = 0U; 30 | sample(); 31 | state += seed; 32 | sample(); 33 | } 34 | 35 | CULE_ANNOTATION 36 | uint32_t sample() 37 | { 38 | uint32_t oldstate = state; 39 | 40 | // Advance internal state 41 | state = oldstate * PCG_DEFAULT_MULTIPLIER_32 + 
PCG_DEFAULT_INCREMENT_32; 42 | 43 | uint32_t word = ((oldstate >> ((oldstate >> 28U) + 4U)) ^ oldstate) * 277803737U; 44 | 45 | return (word >> 22U) ^ word; 46 | } 47 | 48 | CULE_ANNOTATION 49 | float sample_float() 50 | { 51 | return sample() * CURAND_2POW32_INV + (CURAND_2POW32_INV / 2.0f); 52 | } 53 | 54 | private: 55 | 56 | uint32_t& state; 57 | }; 58 | 59 | } // end namespace atari 60 | } // end namespace cule 61 | 62 | -------------------------------------------------------------------------------- /cule/atari/rom.hpp: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | 5 | #include 6 | #include 7 | 8 | #include 9 | #include 10 | 11 | namespace cule 12 | { 13 | namespace atari 14 | { 15 | 16 | class rom 17 | { 18 | public: 19 | 20 | enum 21 | { 22 | MAX_ROM_SIZE = 16 * 1024, 23 | }; 24 | 25 | static const char* names[_ROM_MAX]; 26 | 27 | rom(const std::string& filename = ""); 28 | 29 | rom(const rom& other); 30 | 31 | void reset(const std::string& filename); 32 | 33 | std::vector const& minimal_actions() const; 34 | 35 | std::string file_name() const; 36 | 37 | std::string game_name() const; 38 | 39 | std::string type_name() const; 40 | 41 | std::string md5() const; 42 | 43 | bool swap_ports() const; 44 | 45 | bool use_paddles() const; 46 | 47 | bool swap_paddles() const; 48 | 49 | bool allow_hmove_blanks() const; 50 | 51 | bool player_left_difficulty_B() const; 52 | 53 | bool player_right_difficulty_B() const; 54 | 55 | bool is_supported() const; 56 | 57 | bool is_ntsc() const; 58 | 59 | size_t ram_size() const; 60 | 61 | size_t rom_size() const; 62 | 63 | size_t screen_height() const; 64 | 65 | size_t screen_width() const; 66 | 67 | size_t screen_size() const; 68 | 69 | ROM_FORMAT type() const; 70 | 71 | bool has_banks() const; 72 | 73 | games::GAME_TYPE game_id() const; 74 | 75 | uint8_t const* data() const; 76 | 77 | template 78 | CULE_ANNOTATION 79 | static void write(State& s, const 
maddr_t& addr, const uint8_t& value); 80 | 81 | template 82 | CULE_ANNOTATION 83 | static uint8_t read(State& s, const maddr_t& addr); 84 | 85 | private: 86 | 87 | std::string value_or_default(const games::ROM_ATTR attr) const; 88 | 89 | void set_game_id(); 90 | 91 | void compute_md5(); 92 | 93 | size_t _ram_size; 94 | size_t _rom_size; 95 | ROM_FORMAT _type; 96 | games::GAME_TYPE _gameId; 97 | 98 | std::string _md5; 99 | std::string _filename; 100 | std::vector image; 101 | std::vector _minimal_actions; 102 | }; // end rom class 103 | 104 | } // end namespace atari 105 | } // end namespace cule 106 | 107 | -------------------------------------------------------------------------------- /cule/atari/stack.hpp: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | 5 | #include 6 | #include 7 | #include 8 | 9 | namespace cule 10 | { 11 | namespace atari 12 | { 13 | 14 | template 15 | struct stack 16 | { 17 | 18 | template 19 | static 20 | CULE_ANNOTATION 21 | void pushByte(State_t& s, const uint8_t byte) 22 | { 23 | #ifdef MONITOR_STACK 24 | printf("[S] Push 0x%02X to $%02X\n",byte,valueOf(s.SP)); 25 | #endif 26 | MMC_t::write(s, maddr_t(0x0100 + s.SP), byte); 27 | dec(s.SP); 28 | } 29 | 30 | template 31 | static 32 | CULE_ANNOTATION 33 | void pushReg(State_t& s, const reg_bit_field_t& reg) 34 | { 35 | pushByte(s, reg); 36 | } 37 | 38 | template 39 | static 40 | CULE_ANNOTATION 41 | void pushWord(State_t& s, const word_t& word) 42 | { 43 | #ifdef MONITOR_STACK 44 | printf("[S] Push 0x%04X to $%02X\n",word,valueOf(s.SP)); 45 | #endif 46 | // FATAL_ERROR_IF(s.SP.reachMax(), INVALID_MEMORY_ACCESS, ILLEGAL_ADDRESS_WARP); 47 | 48 | uint8_t* word_ptr = (uint8_t*)&word; 49 | pushByte(s, word_ptr[1]); 50 | pushByte(s, word_ptr[0]); 51 | } 52 | 53 | template 54 | static 55 | CULE_ANNOTATION 56 | void pushPC(State_t& s) 57 | { 58 | pushWord(s, s.PC); 59 | } 60 | 61 | template 62 | static 63 | CULE_ANNOTATION 64 | 
uint8_t popByte(State_t& s) 65 | { 66 | return MMC_t::read(s, maddr_t(0x0100 + inc(s.SP))); 67 | } 68 | 69 | template 70 | static 71 | CULE_ANNOTATION 72 | word_t popWord(State_t& s) 73 | { 74 | // FATAL_ERROR_UNLESS(s.SP.belowMax(), INVALID_MEMORY_ACCESS, ILLEGAL_ADDRESS_WARP); 75 | 76 | word_t word; 77 | uint8_t* word_ptr = (uint8_t*)&word; 78 | word_ptr[0] = popByte(s); 79 | word_ptr[1] = popByte(s); 80 | 81 | return word; 82 | } 83 | 84 | template 85 | static 86 | CULE_ANNOTATION 87 | void reset(State_t& s) 88 | { 89 | // move stack pointer to the top of the stack 90 | s.SP.selfSetMax(); 91 | } 92 | 93 | }; // end namespace stack 94 | 95 | } // end namespace atari 96 | } // end namespace cule 97 | 98 | -------------------------------------------------------------------------------- /cule/atari/state.hpp: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | #include 5 | 6 | #include 7 | #include 8 | 9 | #include 10 | 11 | namespace cule 12 | { 13 | namespace atari 14 | { 15 | 16 | struct state 17 | { 18 | // m6502 vars 19 | _reg8_t A; // accumulator 20 | _reg8_t X, Y; // index 21 | 22 | maddr8_t SP; // stack pointer 23 | maddr_t PC; // program counter 24 | maddr_t addr; // effective address 25 | 26 | uint8_t value; // operand 27 | uint8_t noise; 28 | 29 | uint16_t cpuCycles; 30 | uint16_t bank; 31 | 32 | // controller vars 33 | int32_t resistance; 34 | 35 | // TIA vars 36 | uint32_t GRP; 37 | uint32_t HM; 38 | uint32_t PF; 39 | uint32_t POS; 40 | uint8_t CurrentGRP0; 41 | uint8_t CurrentGRP1; 42 | 43 | uint16_t collision; 44 | int16_t clockWhenFrameStarted; 45 | int32_t clockAtLastUpdate; 46 | int32_t dumpDisabledCycle; 47 | int32_t VSYNCFinishClock; 48 | int32_t lastHMOVEClock; // Color clock when last HMOVE occurred 49 | 50 | // m6532 vars 51 | uint32_t riotData; 52 | int32_t cyclesWhenTimerSet; 53 | int32_t cyclesWhenInterruptReset; 54 | 55 | // state flags 56 | sys_flag_t sysFlags; 57 | 
tia_flag_t tiaFlags; 58 | 59 | // frame data 60 | uint32_t frameData; 61 | uint32_t rand; 62 | int32_t score; 63 | uint8_t M0CosmicArkCounter; 64 | 65 | // pointers 66 | uint32_t * ram; 67 | const uint8_t * rom; 68 | uint32_t * tia_update_buffer; 69 | 70 | uint32_t* CurrentPFMask; 71 | uint8_t * CurrentP0Mask; 72 | uint8_t * CurrentP1Mask; 73 | uint8_t * CurrentM0Mask; 74 | uint8_t * CurrentM1Mask; 75 | uint8_t * CurrentBLMask; 76 | }; 77 | 78 | } // end namespace atari 79 | } // end namespace cule 80 | 81 | -------------------------------------------------------------------------------- /cule/atari/tables.hpp: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | 5 | #include 6 | 7 | namespace cule 8 | { 9 | namespace atari 10 | { 11 | 12 | // Compute the ball mask table 13 | void computeBallMaskTable(); 14 | 15 | // Compute the collision decode table 16 | void computeCollisionTable(); 17 | 18 | // Compute the missle mask table 19 | void computeMissleMaskTable(); 20 | 21 | // Compute the player mask table 22 | void computePlayerMaskTable(); 23 | 24 | // Compute the player position reset when table 25 | void computePlayerPositionResetWhenTable(); 26 | 27 | // Compute the player reflect table 28 | void computePlayerReflectTable(); 29 | 30 | // Compute playfield mask table 31 | void computePlayfieldMaskTable(); 32 | 33 | void computePriorityEncoding(); 34 | 35 | CULE_ANNOTATION 36 | uint32_t playfield_mask(const uint8_t side, const uint8_t x); 37 | 38 | CULE_ANNOTATION 39 | bool missle_mask(const uint8_t align, const uint8_t number, const uint8_t size, int16_t x); 40 | 41 | CULE_ANNOTATION 42 | uint8_t player_mask(const uint8_t align, const bool enable, const uint8_t mode, int16_t x); 43 | 44 | CULE_ANNOTATION 45 | bool ball_mask(const uint8_t align, const uint8_t size, int16_t x); 46 | 47 | CULE_ANNOTATION 48 | uint8_t reflect_mask(uint8_t b); 49 | 50 | // Compute the collision decode table 51 | 
uint16_t collision_mask(const uint8_t i); 52 | 53 | int8_t position_mask(const uint8_t mode, const uint8_t oldx, const uint8_t newx); 54 | 55 | } // end namespace atari 56 | } // end namespace cule 57 | 58 | #include 59 | 60 | #ifdef __CUDACC__ 61 | #include 62 | #endif 63 | -------------------------------------------------------------------------------- /cule/atari/types/types.hpp: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | 5 | namespace cule 6 | { 7 | namespace atari 8 | { 9 | 10 | // forward declaration 11 | template class bit_field; 12 | template class flag_set; 13 | 14 | } // end namespace atari 15 | } // end namespace cule 16 | 17 | #include 18 | #include 19 | -------------------------------------------------------------------------------- /cule/atari/types/valueobj.hpp: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | 5 | namespace cule 6 | { 7 | namespace atari 8 | { 9 | 10 | template 11 | class value_object 12 | { 13 | public: 14 | typedef DT DataTp; 15 | typedef VT ValueTp; 16 | 17 | CULE_ANNOTATION 18 | value_object() {} 19 | 20 | CULE_ANNOTATION 21 | value_object(const VT& value): _value(value) {} 22 | 23 | CULE_ANNOTATION 24 | friend const VT& valueOf(const value_object& vo) 25 | { 26 | return vo._value; 27 | } 28 | 29 | protected: 30 | DT _value; 31 | }; 32 | 33 | template 34 | class transparent_value_object : public value_object 35 | { 36 | private: 37 | using super_t = value_object; 38 | 39 | public: 40 | CULE_ANNOTATION 41 | transparent_value_object() {} 42 | 43 | CULE_ANNOTATION 44 | transparent_value_object(const VT& value): super_t(value) {} // forward the initial value to the base instead of discarding it 45 | 46 | // transparent value getter 47 | CULE_ANNOTATION 48 | operator const VT&() const 49 | { 50 | return super_t::_value; 51 | } 52 | 53 | // transparent value setter; returns *this so the declared reference return is honoured 54 | CULE_ANNOTATION 55 | transparent_value_object& operator = (const value_object& other) 56 | { 57 |
super_t::_value = other._value; return *this; 58 | } 59 | }; 60 | 61 | } // end namespace atari 62 | } // end namespace cule 63 | 64 | -------------------------------------------------------------------------------- /cule/config.hpp: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | #include 5 | 6 | // macro with optional string parameter 7 | // taken from http://stackoverflow.com/questions/3046889/optional-parameters-with-c-macros 8 | #define CULE_ASSERT_0() CULE_ASSERT_1(true) 9 | #define CULE_ASSERT_1(A) CULE_ASSERT_2(A, std::string()) 10 | #define CULE_ASSERT_2(A,B) CULE_ASSERT_3(A, B, std::runtime_error) 11 | #define CULE_ASSERT_3(A,B,C) \ 12 | if(not (A)) \ 13 | { \ 14 | std::ostringstream m; \ 15 | m << __FILE__ << ":" << __LINE__ << " in " << __func__ << "\n"; \ 16 | m << B << "\n"; \ 17 | throw C(m.str()); \ 18 | } 19 | 20 | // The interim macro that simply strips the excess and ends up with the required macro 21 | #define CULE_ASSERT_X(x,A,B,C,FUNC,...) FUNC 22 | 23 | // The macro that the programmer uses 24 | #ifdef __CUDA_ARCH__ 25 | #define CULE_ASSERT(...) 26 | #else 27 | #define CULE_ASSERT(...) CULE_ASSERT_X(,##__VA_ARGS__, \ 28 | CULE_ASSERT_3(__VA_ARGS__), \ 29 | CULE_ASSERT_2(__VA_ARGS__), \ 30 | CULE_ASSERT_1(__VA_ARGS__), \ 31 | CULE_ASSERT_0(__VA_ARGS__) \ 32 | ) 33 | #endif 34 | 35 | // Macro for unimplemented functions 36 | #define CULE_NOT_IMPLEMENTED CULE_ASSERT(false, std::string(__PRETTY_FUNCTION__)); 37 | 38 | // CULE_RETURNS() is used to avoid writing boilerplate "->decltype(x) { return x; }" phrases. 39 | // see https://gist.github.com/dabrahams/1457531 for details 40 | #define CULE_RETURNS(...)
-> decltype(__VA_ARGS__) { return (__VA_ARGS__); } typedef int CULE_RETURNS_CAT(CULE_RETURNS_, __LINE__) 41 | #define CULE_RETURNS_CAT_0(x, y) x ## y 42 | #define CULE_RETURNS_CAT(x, y) CULE_RETURNS_CAT_0(x,y) 43 | 44 | #ifdef __CUDACC__ 45 | #define CULE_ERRCHK(ans) { cule::cuda::gpuAssert((ans), __FILE__, __LINE__); } 46 | 47 | #define CULE_CUDA_PEEK_AND_SYNC \ 48 | CULE_ERRCHK(cudaPeekAtLastError()); \ 49 | CULE_ERRCHK(cudaDeviceSynchronize()); 50 | 51 | #define CULE_ANNOTATION __host__ __device__ __forceinline__ 52 | #else 53 | #define CULE_ANNOTATION inline 54 | #define CULE_ERRCHK(ans) 55 | #endif 56 | 57 | #ifdef __CUDA_ARCH__ 58 | #define CULE_ARRAY_ACCESSOR(name) cule::atari::gpu_##name 59 | #else 60 | #define CULE_ARRAY_ACCESSOR(name) cule::atari::name 61 | #endif 62 | 63 | -------------------------------------------------------------------------------- /cule/cuda.hpp: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | #include 5 | 6 | #include 7 | #include 8 | -------------------------------------------------------------------------------- /cule/cuda/errchk.hpp: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | #include 5 | #include 6 | 7 | namespace cule 8 | { 9 | namespace cuda 10 | { 11 | 12 | void gpuAssert(cudaError_t code, const char *file, int line, bool abort=true) 13 | { 14 | if (code != cudaSuccess) 15 | { 16 | fprintf(stderr,"GPUassert: %s %s %d\n", cudaGetErrorString(code), file, line); 17 | if (abort) exit(code); 18 | } 19 | } 20 | 21 | } // end namespace cuda 22 | } // end namespace cule 23 | 24 | -------------------------------------------------------------------------------- /cule/cuda/parallel_execution_policy.hpp: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | 5 | #include 6 | 7 | #include 8 | 9 | namespace cule 10 | { 11 | 
namespace cuda 12 | { 13 | 14 | class parallel_execution_policy : public agency::cuda::parallel_execution_policy 15 | { 16 | private: 17 | using super_t = agency::cuda::parallel_execution_policy; 18 | 19 | public: 20 | parallel_execution_policy() 21 | { 22 | CULE_ERRCHK(cudaStreamCreate(&stream)); 23 | } 24 | 25 | ~parallel_execution_policy() 26 | { 27 | CULE_ERRCHK(cudaStreamDestroy(stream)); 28 | } 29 | 30 | void sync() const 31 | { 32 | CULE_ERRCHK(cudaStreamSynchronize(stream)); 33 | } 34 | 35 | void insert_other_stream(const cudaStream_t& otherStream) const 36 | { 37 | cudaEvent_t event; 38 | 39 | CULE_ERRCHK(cudaEventCreate(&event)); 40 | CULE_ERRCHK(cudaEventRecord(event, otherStream)); 41 | CULE_ERRCHK(cudaStreamWaitEvent(stream, event, 0)); 42 | CULE_ERRCHK(cudaEventDestroy(event)); 43 | } 44 | 45 | void insert_this_stream(const cudaStream_t& otherStream) const 46 | { 47 | cudaEvent_t event; 48 | 49 | CULE_ERRCHK(cudaEventCreate(&event)); 50 | CULE_ERRCHK(cudaEventRecord(event, stream)); 51 | CULE_ERRCHK(cudaStreamWaitEvent(otherStream, event, 0)); 52 | CULE_ERRCHK(cudaEventDestroy(event)); 53 | } 54 | 55 | cudaStream_t getStream() const 56 | { 57 | return stream; 58 | } 59 | 60 | private: 61 | cudaStream_t stream; 62 | }; 63 | 64 | const parallel_execution_policy par{}; 65 | 66 | } // end namespace cuda 67 | } // end namespace cule 68 | 69 | -------------------------------------------------------------------------------- /cule/cule.hpp: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | 5 | #include 6 | -------------------------------------------------------------------------------- /cule/macros.hpp: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | 5 | namespace cule 6 | { 7 | 8 | // e.g. SIZE_IN_BITS(int) is 32 and SIZE_IN_BITS(bool) is 1 9 | #define SIZE_IN_BITS(TP) (__BITSOFLAG_CLASS::NUM_BITS) 10 | 11 | // e.g. 
BIT_MASK(3)=7 12 | // bits must be positive 13 | #define BIT_MASK(TP, bits) (1|((((TP)1<<((bits)-1))-1)<<1)) 14 | 15 | // e.g. TYPE_MAX(int)=0xFFFFFFFF 16 | #define TYPE_MAX(TP) BIT_MASK(TP,SIZE_IN_BITS(TP)) 17 | 18 | // bit manipulation 19 | #define LOW_BIT(x) ((x)&(-(x))) 20 | #define RTRIM(x) ((x)/LOW_BIT(x)) 21 | 22 | #define SINGLE_BIT(x) (((x)&((x)-1))==0) 23 | 24 | #define SELECT_FIELD(x, f) (((x)&(f))/LOW_BIT(f)) 25 | #define UPDATE_FIELD(x, f, y) x=((x)&(~f))|(((y)*LOW_BIT(f))&(f)) 26 | #define INC_FIELD(x, f) x=((x)&(~(f))) | ( (((x)&(f)) + LOW_BIT(f)) & (f) ); 27 | 28 | #define CASE_ENUM_RETURN_STRING(ENUM) case ENUM: return #ENUM 29 | 30 | #define fast_cast(VAR,TP) ((TP&)*(TP*)(&(VAR))) 31 | #define fast_constcast(VAR,TP) ((const TP&)*(const TP*)(&(VAR))) 32 | 33 | // type casting 34 | template 35 | destType& safe_cast(srcType& source) 36 | { 37 | return *(destType*)(&source); 38 | } 39 | 40 | template 41 | CULE_ANNOTATION 42 | T min(const T& x,const T& y) 43 | { 44 | return x 48 | CULE_ANNOTATION 49 | T max(const T& x,const T& y) 50 | { 51 | return x>y?x:y; 52 | } 53 | 54 | // Type info 55 | template 56 | class __BITSOFLAG_CLASS 57 | { 58 | public: 59 | __BITSOFLAG_CLASS(); 60 | enum :int { 61 | NUM_BITS=sizeof(T)<<3 62 | }; 63 | }; 64 | 65 | template <> 66 | class __BITSOFLAG_CLASS 67 | { 68 | public: 69 | __BITSOFLAG_CLASS(); 70 | enum :int { 71 | NUM_BITS=1 72 | }; 73 | }; 74 | 75 | } // end namespace cule 76 | 77 | -------------------------------------------------------------------------------- /envs/Dockerfile: -------------------------------------------------------------------------------- 1 | FROM nvidia/cuda:11.7.1-devel-ubuntu22.04 2 | 3 | ARG DEBIAN_FRONTEND=noninteractive 4 | ENV TZ=US 5 | 6 | RUN apt-get -y update -qq && apt-get install -y --no-install-recommends \ 7 | build-essential \ 8 | ca-certificates \ 9 | clang \ 10 | gcc \ 11 | cmake \ 12 | htop \ 13 | curl \ 14 | git \ 15 | libomp-dev \ 16 | libsm6 \ 17 | libssl-dev \ 18 | 
libxrender-dev \ 19 | libxext-dev \ 20 | iproute2 \ 21 | python3.9 \ 22 | python3-dev \ 23 | python3-setuptools \ 24 | python3-pip \ 25 | vim \ 26 | ssh \ 27 | wget \ 28 | vim \ 29 | zip \ 30 | && \ 31 | rm -rf /var/lib/apt/lists/* && \ 32 | ln -s /usr/bin/python3.9 /usr/bin/python 33 | ENV LD_LIBRARY_PATH="$LD_LIBRARY_PATH:/usr/local/cuda/lib64" 34 | 35 | RUN pip install --upgrade cython \ 36 | cloudpickle \ 37 | gym[atari] \ 38 | opencv-python \ 39 | psutil \ 40 | torch==1.11.0 \ 41 | torchvision==0.12.0 \ 42 | tqdm 43 | 44 | RUN git clone -b master --recursive https://github.com/NVLabs/cule && \ 45 | cd cule && \ 46 | python setup.py install 47 | -------------------------------------------------------------------------------- /envs/environment.yml: -------------------------------------------------------------------------------- 1 | name: cule 2 | channels: 3 | - pytorch 4 | dependencies: 5 | - matplotlib 6 | - pytorch 7 | - pip: 8 | - gym[atari] 9 | - opencv-python 10 | -------------------------------------------------------------------------------- /examples/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/cule/dd0382b99ded6be23cd3c3e79e37938e7c873de0/examples/__init__.py -------------------------------------------------------------------------------- /examples/a2c/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/cule/dd0382b99ded6be23cd3c3e79e37938e7c873de0/examples/a2c/__init__.py -------------------------------------------------------------------------------- /examples/a2c/a2c_main.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | 4 | _path = os.path.abspath(os.path.pardir) 5 | if not _path in sys.path: 6 | sys.path = [_path] + sys.path 7 | 8 | from utils.launcher import main 9 | 10 | from train import worker 11 | 12 | def 
a2c_parser_options(parser): 13 | parser.add_argument('--entropy-coef', type=float, default=0.01, help='entropy term coefficient (default: 0.01)') 14 | parser.add_argument('--lr-scale', action='store_true', default=False, help='Scale the learning rate with the batch-size') 15 | parser.add_argument('--num-stack', type=int, default=4, help='number of images in a stack (default: 4)') 16 | parser.add_argument('--num-steps', type=int, default=5, help='number of forward steps in A2C (default: 5)') 17 | parser.add_argument('--tau', type=float, default=1.00, help='parameter for GAE (default: 1.00)') 18 | parser.add_argument('--use-gae', action='store_true', default=False, help='use generalized advantage estimation') 19 | parser.add_argument('--value-loss-coef', type=float, default=0.5, help='value loss coefficient (default: 0.5)') 20 | 21 | return parser 22 | 23 | def a2c_main(): 24 | if sys.version_info.major == 3: 25 | from train import worker 26 | else: 27 | worker = None 28 | 29 | sys.exit(main(a2c_parser_options, worker)) 30 | 31 | if __name__ == '__main__': 32 | a2c_main() 33 | -------------------------------------------------------------------------------- /examples/a2c/helper.py: -------------------------------------------------------------------------------- 1 | from datetime import datetime 2 | import pytz 3 | import torch 4 | 5 | total_time = 0 6 | last_save = 0 7 | 8 | def gen_data(x): 9 | return [f(x.float()).item() for f in [torch.mean, torch.median, torch.min, torch.max, torch.std]] 10 | 11 | def format_time(f): 12 | return datetime.fromtimestamp(f, tz=pytz.utc).strftime('%H:%M:%S.%f s') 13 | 14 | def callback(args, model, frames, iter_time, rewards, lengths, 15 | value_loss, policy_loss, entropy, csv_writer, csv_file): 16 | global last_save, total_time 17 | 18 | if not hasattr(args, 'num_steps_per_update'): 19 | args.num_steps_per_update = args.num_steps 20 | 21 | total_time += iter_time 22 | fps = (args.world_size * args.num_steps_per_update * 
args.num_ales) / iter_time 23 | lmean, lmedian, lmin, lmax, lstd = gen_data(lengths) 24 | rmean, rmedian, rmin, rmax, rstd = gen_data(rewards) 25 | 26 | if frames >= last_save: 27 | last_save += args.save_interval 28 | 29 | # torch.save(model.state_dict(), args.model_name) 30 | 31 | if csv_writer and csv_file: 32 | csv_writer.writerow([frames, fps, total_time, 33 | rmean, rmedian, rmin, rmax, rstd, 34 | lmean, lmedian, lmin, lmax, lstd, 35 | entropy, value_loss, policy_loss]) 36 | csv_file.flush() 37 | 38 | str_template = '{fps:8.2f}f/s, ' \ 39 | 'min/max/mean/median reward: {rmin:5.1f}/{rmax:5.1f}/{rmean:5.1f}/{rmedian:5.1f}, ' \ 40 | 'entropy/value/policy: {entropy:6.4f}/{value:6.4f}/{policy: 6.4f}' 41 | 42 | return str_template.format(fps=fps, rmin=rmin, rmax=rmax, rmean=rmean, rmedian=rmedian, 43 | entropy=entropy, value=value_loss, policy=policy_loss) 44 | -------------------------------------------------------------------------------- /examples/a2c/model.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os 3 | import re 4 | import sys 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | 9 | _path = os.path.abspath(os.path.pardir) 10 | if not _path in sys.path: 11 | sys.path = [_path] + sys.path 12 | 13 | from utils.openai.vec_normalize import RunningMeanStd 14 | 15 | def weights_init(m): 16 | if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear): 17 | nn.init.kaiming_normal_(m.weight, mode='fan_in', nonlinearity='relu') 18 | if m.bias is not None: 19 | nn.init.zeros_(m.bias) 20 | 21 | class ActorCritic(nn.Module): 22 | 23 | def __init__(self, num_inputs, action_space, normalize=False, name=None): 24 | super(ActorCritic, self).__init__() 25 | 26 | self._name = name 27 | 28 | self.conv1 = nn.Conv2d(in_channels=num_inputs, out_channels=32, kernel_size=8, stride=4) 29 | self.conv2 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=4, stride=2) 30 | 
self.conv3 = nn.Conv2d(in_channels=64, out_channels=32, kernel_size=3, stride=1) 31 | 32 | conv_out_size = self._get_conv_out((num_inputs, 84, 84)) 33 | self.linear1 = nn.Linear(in_features=conv_out_size, out_features=512) 34 | 35 | self.critic_linear = nn.Linear(in_features=512, out_features=1) 36 | self.actor_linear = nn.Linear(in_features=512, out_features=action_space.n) 37 | 38 | self.apply(weights_init) 39 | 40 | relu_gain = nn.init.calculate_gain('relu') 41 | self.conv1.weight.data.mul_(relu_gain) 42 | self.conv2.weight.data.mul_(relu_gain) 43 | self.conv3.weight.data.mul_(relu_gain) 44 | self.linear1.weight.data.mul_(relu_gain) 45 | 46 | self.ob_rms = RunningMeanStd(shape=(84, 84)) if normalize else None 47 | 48 | def _get_conv_out(self, shape): 49 | o = self.conv1(torch.zeros(1, *shape)) 50 | o = self.conv2(o) 51 | o = self.conv3(o) 52 | return int(np.prod(o.size())) 53 | 54 | def forward(self, x): 55 | with torch.no_grad(): 56 | if self.ob_rms: 57 | if self.training: 58 | self.ob_rms.update(x) 59 | mean = self.ob_rms.mean.to(dtype=torch.float32, device=x.device) 60 | std = torch.sqrt(self.ob_rms.var.to(dtype=torch.float32, device=x.device) + float(np.finfo(np.float32).eps)) 61 | x = (x - mean) / std 62 | 63 | x = x.to(dtype=self.conv1.weight.dtype) 64 | x = F.relu(self.conv1(x)) 65 | x = F.relu(self.conv2(x)) 66 | x = F.relu(self.conv3(x)) 67 | 68 | x = x.view(x.size(0), -1) 69 | x = F.relu(self.linear1(x)) 70 | 71 | return self.critic_linear(x), self.actor_linear(x) 72 | 73 | def name(self): 74 | return self._name 75 | 76 | def save(self): 77 | if self.name(): 78 | name = '{}.pth'.format(self.name()) 79 | torch.save(self.state_dict(), name) 80 | 81 | def load(self, name=None): 82 | self.load_state_dict(torch.load(name if name else self.name())) 83 | 84 | -------------------------------------------------------------------------------- /examples/a2c/test.py: -------------------------------------------------------------------------------- 1 | import torch 2 
| 3 | import torch.nn.functional as F 4 | import numpy as np 5 | 6 | def test(args, policy_net, env): 7 | device = next(policy_net.parameters()).device 8 | 9 | width, height = 84, 84 10 | num_ales = args.evaluation_episodes 11 | 12 | if args.use_openai_test_env: 13 | observation = torch.from_numpy(env.reset()).squeeze(1) 14 | else: 15 | observation = env.reset(initial_steps=50).squeeze(-1) 16 | 17 | lengths = torch.zeros(num_ales, dtype=torch.int32) 18 | rewards = torch.zeros(num_ales, dtype=torch.float32) 19 | all_done = torch.zeros(num_ales, dtype=torch.bool) 20 | not_done = torch.ones(num_ales, dtype=torch.bool) 21 | 22 | fire_reset = torch.zeros(num_ales, dtype=torch.bool) 23 | actions = torch.ones(num_ales, dtype=torch.uint8) 24 | 25 | maybe_npy = lambda a: a.numpy() if args.use_openai_test_env else a 26 | 27 | info = env.step(maybe_npy(actions))[-1] 28 | if args.use_openai_test_env: 29 | lives = torch.IntTensor([d['ale.lives'] for d in info]) 30 | else: 31 | lives = info['ale.lives'].clone() 32 | 33 | states = torch.zeros((num_ales, args.num_stack, width, height), device=device, dtype=torch.float32) 34 | states[:, -1] = observation.to(device=device, dtype=torch.float32) 35 | 36 | policy_net.eval() 37 | 38 | while not all_done.all(): 39 | logit = policy_net(states)[1] 40 | 41 | actions = F.softmax(logit, dim=1).multinomial(1).cpu() 42 | actions[fire_reset] = 1 43 | 44 | observation, reward, done, info = env.step(maybe_npy(actions)) 45 | 46 | if args.use_openai_test_env: 47 | # convert back to pytorch tensors 48 | observation = torch.from_numpy(observation) 49 | reward = torch.from_numpy(reward.astype(np.float32)) 50 | done = torch.from_numpy(done.astype(np.bool)) 51 | new_lives = torch.IntTensor([d['ale.lives'] for d in info]) 52 | else: 53 | new_lives = info['ale.lives'].clone() 54 | 55 | fire_reset = new_lives < lives 56 | lives.copy_(new_lives) 57 | 58 | observation = observation.to(device=device, dtype=torch.float32) 59 | 60 | states[:, 
:-1].copy_(states[:, 1:].clone()) 61 | states *= (1.0 - done.to(device=device, dtype=torch.float32)).view(-1, *[1] * (observation.dim() - 1)) 62 | states[:, -1].copy_(observation.view(-1, *states.size()[-2:])) 63 | 64 | # update episodic reward counters 65 | lengths += not_done.int() 66 | rewards += reward.cpu() * not_done.float().cpu() 67 | 68 | all_done |= done.cpu() 69 | all_done |= (lengths >= args.max_episode_length) 70 | not_done = (all_done == False).int() 71 | 72 | policy_net.train() 73 | 74 | return lengths, rewards 75 | -------------------------------------------------------------------------------- /examples/dqn/LICENSE.md: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2017 Kai Arulkumaran 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: 6 | 7 | The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. 8 | 9 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 
10 | -------------------------------------------------------------------------------- /examples/dqn/README.md: -------------------------------------------------------------------------------- 1 | Rainbow 2 | ======= 3 | Rainbow: Combining Improvements in Deep Reinforcement Learning. 4 | 5 | This implementation is based on the excellent example code provided by 6 | [Kaixhin](https://github.com/Kaixhin/Rainbow). 7 | -------------------------------------------------------------------------------- /examples/dqn/benchmark.config: -------------------------------------------------------------------------------- 1 | [Defaults] 2 | adam_eps=0.00015 3 | ale_start_steps=4000 4 | atoms=51 5 | batch_size=32 6 | categorical=False 7 | discount=0.99 8 | double_q=False 9 | dueling=False 10 | evaluate=False 11 | evaluation_episodes=10 12 | evaluation_interval=200000 13 | evaluation_size=500 14 | gpu=0 15 | hidden_size=512 16 | history_length=4 17 | learn_start=80000 18 | log_interval=100 19 | lr=0.0000625 20 | max_episode_length=18000 21 | max_grad_norm=1 22 | memory_capacity=500000 23 | multi_step=3 24 | noisy_linear=False 25 | noisy_std=0.1 26 | normalize=False 27 | num_ales=32 28 | plot=True 29 | priority_exponent=0.7 30 | priority_replay=False 31 | priority_weight=0.5 32 | rainbow=False 33 | replay_frequency=4 34 | reward_clip=True 35 | t_max=12500000 36 | target_update=32000 37 | v_max=10 38 | v_min=-10 39 | verbose=False 40 | use_openai=False 41 | -------------------------------------------------------------------------------- /examples/ppo/benchmark.config: -------------------------------------------------------------------------------- 1 | [Defaults] 2 | ale_start_steps=400 3 | alpha=0.99 4 | batch_size=256 5 | clip_epsilon=0.1 6 | entropy_coef=0.01 7 | env_name=PongNoFrameskip-v4 8 | episodic_life=True 9 | eps=1e-05 10 | evaluation_episodes=10 11 | evaluation_interval=50000 12 | gamma=0.99 13 | gpu=0 14 | local_rank=0 15 | log_dir=runs 16 | loss_scale=None 17 | lr=0.00025 
18 | lr_scale=False 19 | max_episode_length=18000 20 | max_grad_norm=0.5 21 | multiprocessing_distributed=False 22 | no_cuda_train=True 23 | normalize=False 24 | num_ales=8 25 | num_gpus_per_node=-1 26 | num_stack=4 27 | num_steps=128 28 | opt_level=O0 29 | output_filename=None 30 | plot=False 31 | ppo_epoch=3 32 | profile=False 33 | save_interval=0 34 | seed=1567992099 35 | t_max=50000000 36 | tau=1.0 37 | use_adam=True 38 | use_cuda_env=False 39 | use_gae=False 40 | use_openai=False 41 | use_openai_test_env=False 42 | value_loss_coef=0.5 43 | verbose=False 44 | -------------------------------------------------------------------------------- /examples/ppo/ppo_main.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | 4 | current_path = os.path.dirname(os.path.realpath(__file__)) 5 | _path = os.path.abspath(os.path.join(current_path, os.pardir)) 6 | if not _path in sys.path: 7 | sys.path = [_path] + sys.path 8 | 9 | from a2c.a2c_main import a2c_parser_options 10 | from utils.launcher import main 11 | 12 | def ppo_parser_options(parser): 13 | parser = a2c_parser_options(parser) 14 | 15 | parser.add_argument('--batch-size', type=int, default=256) 16 | parser.add_argument('--clip-epsilon', type=float, default=0.1, help='ppo clip parameter (default: 0.1)') 17 | parser.add_argument('--ppo-epoch', type=int, default=3, help='Number of ppo epochs (default: 3)') 18 | 19 | return parser 20 | 21 | def ppo_main(): 22 | if sys.version_info.major == 3: 23 | from train import worker 24 | else: 25 | worker = None 26 | 27 | sys.exit(main(ppo_parser_options, worker)) 28 | 29 | if __name__ == '__main__': 30 | ppo_main() 31 | -------------------------------------------------------------------------------- /examples/utils/__init__.py: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/NVlabs/cule/dd0382b99ded6be23cd3c3e79e37938e7c873de0/examples/utils/__init__.py -------------------------------------------------------------------------------- /examples/utils/openai/LICENSE.md: -------------------------------------------------------------------------------- 1 | # gym 2 | 3 | The MIT License 4 | 5 | Copyright (c) 2016 OpenAI (https://openai.com) 6 | 7 | Permission is hereby granted, free of charge, to any person obtaining a copy 8 | of this software and associated documentation files (the "Software"), to deal 9 | in the Software without restriction, including without limitation the rights 10 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 11 | copies of the Software, and to permit persons to whom the Software is 12 | furnished to do so, subject to the following conditions: 13 | 14 | The above copyright notice and this permission notice shall be included in 15 | all copies or substantial portions of the Software. 16 | 17 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 18 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 19 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 20 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 21 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 22 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN 23 | THE SOFTWARE. 24 | 25 | # Mujoco models 26 | This work is derived from [MuJuCo models](http://www.mujoco.org/forum/index.php?resources/) used under the following license: 27 | ``` 28 | This file is part of MuJoCo. 29 | Copyright 2009-2015 Roboti LLC. 
30 | Mujoco :: Advanced physics simulation engine 31 | Source : www.roboti.us 32 | Version : 1.31 33 | Released : 23Apr16 34 | Author :: Vikash Kumar 35 | Contacts : kumar@roboti.us 36 | ``` 37 | -------------------------------------------------------------------------------- /examples/utils/openai/README.md: -------------------------------------------------------------------------------- 1 | Based on code from [OpenAI Gym](https://github.com/openai/gym). 2 | -------------------------------------------------------------------------------- /examples/utils/openai/envs.py: -------------------------------------------------------------------------------- 1 | import gym 2 | from gym.spaces import Box 3 | 4 | from .atari_wrappers import make_atari, wrap_deepmind 5 | from .subproc_vec_env import SubprocVecEnv 6 | 7 | class WrapPyTorch(gym.ObservationWrapper): 8 | def __init__(self, env=None): 9 | super(WrapPyTorch, self).__init__(env) 10 | self.observation_space = Box(0.0, 1.0, [1, 84, 84]) 11 | 12 | def observation(self, observation): 13 | return observation.transpose(2, 0, 1) 14 | 15 | def create_atari_env(env_id, seed=0, rank=0, episode_life=False, clip_rewards=False, deepmind=True, max_frames=18000): 16 | def _thunk(): 17 | env = make_atari(env_id) 18 | env.seed(seed + rank) 19 | if deepmind: 20 | env = wrap_deepmind(env, episode_life=episode_life, clip_rewards=clip_rewards) 21 | env = WrapPyTorch(env) 22 | return env 23 | return _thunk 24 | 25 | def create_vectorize_atari_env(env_id, seed, num_envs, episode_life=False, clip_rewards=False, deepmind=True, max_frames=18000): 26 | return SubprocVecEnv([create_atari_env(env_id, 27 | seed=seed, 28 | rank=proc_id, 29 | episode_life=episode_life, 30 | clip_rewards=clip_rewards, 31 | deepmind=deepmind, 32 | max_frames=max_frames) for proc_id in range(num_envs)]) 33 | 34 | -------------------------------------------------------------------------------- /examples/utils/openai/vec_normalize.py: 
-------------------------------------------------------------------------------- 1 | import torch 2 | 3 | # https://github.com/openai/baselines/blob/master/baselines/common/vec_env/vec_normalize.py 4 | class RunningMeanStd(object): 5 | def __init__(self, epsilon=1e-4, shape=()): 6 | self.mean = torch.zeros(shape, dtype=torch.double) 7 | self.var = torch.ones(shape, dtype=torch.double) 8 | self.count = epsilon 9 | 10 | def update(self, x): 11 | self.mean = self.mean.to(device=x.device) 12 | self.var = self.var.to(device=x.device) 13 | 14 | batch_mean = torch.mean(x, dim=0).double() 15 | batch_var = torch.var(x, dim=0).double() 16 | batch_count = x.size(0) 17 | self.update_from_moments(batch_mean, batch_var, batch_count) 18 | 19 | def update_from_moments(self, batch_mean, batch_var, batch_count): 20 | self.mean, self.var, self.count = update_mean_var_count_from_moments( 21 | self.mean, self.var, self.count, batch_mean, batch_var, batch_count) 22 | 23 | def update_mean_var_count_from_moments(mean, var, count, batch_mean, batch_var, batch_count): 24 | delta = batch_mean - mean 25 | tot_count = count + batch_count 26 | 27 | new_mean = mean + delta * batch_count / tot_count 28 | m_a = var * count 29 | m_b = batch_var * batch_count 30 | M2 = m_a + m_b + torch.pow(delta, 2).double() * count * batch_count / tot_count 31 | new_var = M2 / tot_count 32 | new_count = tot_count 33 | 34 | return new_mean, new_var, new_count 35 | -------------------------------------------------------------------------------- /examples/vtrace/benchmark_vtrace.py: -------------------------------------------------------------------------------- 1 | import re 2 | import gym 3 | import os 4 | 5 | def atari_games(): 6 | pattern = re.compile('\w+NoFrameskip-v4') 7 | return [env_spec.id for env_spec in gym.envs.registry.all() if pattern.match(env_spec.id)] 8 | 9 | env_names = atari_games() 10 | env_names.remove('QbertNoFrameskip-v4') 11 | env_names.remove('ElevatorActionNoFrameskip-v4') 12 | 
env_names.remove('DefenderNoFrameskip-v4') 13 | num_ales_list = [1024, 2048, 16, 4096] #[1, 32, 64, 128, 256, 512, 1024, 2048, 4096] 14 | 15 | for num_ales in num_ales_list: 16 | for env_name in env_names: 17 | 18 | if num_ales < 1025: 19 | os.system('python vtrace_main.py --benchmark --num-ales ' + str(num_ales) + ' --env-name ' + env_name + ' --num-steps 5 --num-minibatches 1 --num-steps-per-update 5 --normalize --use-openai') 20 | os.system('python vtrace_main.py --benchmark --num-ales ' + str(num_ales) + ' --env-name ' + env_name + ' --num-steps 5 --num-minibatches 1 --num-steps-per-update 5 --normalize') 21 | os.system('python vtrace_main.py --benchmark --num-ales ' + str(num_ales) + ' --env-name ' + env_name + ' --num-steps 5 --num-minibatches 1 --num-steps-per-update 5 --normalize --use-cuda-env') 22 | -------------------------------------------------------------------------------- /examples/vtrace/test_vtrace.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import subprocess 3 | import os 4 | 5 | parser = argparse.ArgumentParser(description='Test A2C+V-trace, multiple configurations') 6 | parser.add_argument('--game-name', default='PongNoFrameskip-v4', help='name of the game (default = PongNoFrameskip-v4)') 7 | parser.add_argument('--t-max', default=20, type=int, help='number of training frames (default=20M)') 8 | args = parser.parse_args() 9 | 10 | envs = [120, 120, 120, 1200, 1200, 1200, 1200, 1200*4] 11 | n_steps = [5, 5, 20, 20, 5, 5, 20, 20] 12 | n_steps_per_update = [5, 1, 1, 1, 5, 1, 1, 1] 13 | n_minibatches = [1, 5, 20, 20, 1, 5, 20, 20] 14 | n_gpus = [0, 0, 0, 0, 1, 1, 1, 4] 15 | n_configs = len(n_gpus) 16 | 17 | for n_test in range(0, 3): 18 | for n_config in range(0, n_configs): 19 | 20 | t_max = args.t_max 21 | if n_gpus[n_config] == 0: 22 | base_cmd_string = ' --use-openai --use-openai-test-env' 23 | if n_gpus[n_config] == 1: 24 | base_cmd_string = ' --use-openai-test-env 
def vtrace_parser_options(parser):
    """Add the V-trace command-line options on top of the A2C options.

    Args:
        parser: argument parser to extend; it is first passed through
            `a2c_parser_options`.

    Returns:
        The parser with the V-trace arguments registered.
    """
    parser = a2c_parser_options(parser)

    # The truncation levels are real-valued (their defaults are 1.0), so they
    # are parsed as floats; parsing them as int rejects values such as 0.5.
    parser.add_argument('--c-hat', type=float, default=1.0, help='Trace cutting truncation level (default: 1.0)')
    parser.add_argument('--rho-hat', type=float, default=1.0, help='Temporal difference truncation level (default: 1.0)')
    parser.add_argument('--num-minibatches', type=int, default=16, help='number of mini-batches in the set of environments (default: 16)')
    parser.add_argument('--num-steps-per-update', type=int, default=1, help='number of steps per update (default: 1)')

    parser.add_argument('--benchmark', action='store_true', help='Special case: benchmark')

    return parser
class Rom(AtariRom):
    """ROM wrapper that locates a game's ROM file through ``atari_py``."""

    def __init__(self, env_name):
        """Load the ROM whose file ``atari_py`` reports for `env_name`.

        The game name handed to ``atari_py`` is the lower-cased text that
        precedes the first 'No' in `env_name`.

        Raises:
            IOError: when the resolved ROM path does not exist on disk.
        """
        ### TODO improve method to get base game name ###
        base_name = env_name.split('No')[0].lower()
        game_path = atari_py.get_game_path(base_name)

        rom_found = os.path.exists(game_path)
        if not rom_found:
            # Only enumerate the installed games when the error is reported.
            known_games = ', '.join(sorted(atari_py.list_games()))
            raise IOError('Requested environment (%s) does not exist '
                          'in valid list of environments:\n%s'
                          % (env_name, known_games))

        super(Rom, self).__init__(game_path)

    def __repr__(self):
        """Return a multi-line, human-readable summary of the ROM's properties."""
        template = ''.join((
            'Name : {}\n',
            'Controller : {}\n',
            'Swapped : {}\n',
            'Left Diff : {}\n',
            'Right Diff : {}\n',
            'Type : {}\n',
            'Display : {}\n',
            'ROM Size : {}\n',
            'RAM Size : {}\n',
            'MD5 : {}\n',
        ))

        # Values are gathered in the same order as the fields of the template.
        name = self.game_name()
        controller = 'Paddles' if self.use_paddles() else 'Joystick'
        swapped = 'Yes' if self.swap_paddles() or self.swap_ports() else 'No'
        left_difficulty = 'B' if self.player_left_difficulty_B() else 'A'
        right_difficulty = 'B' if self.player_right_difficulty_B() else 'A'
        cart_type = self.type()
        display = 'NTSC' if self.is_ntsc() else 'PAL'
        rom_bytes = self.rom_size()
        ram_bytes = self.ram_size()
        digest = self.md5()

        return template.format(name, controller, swapped, left_difficulty,
                               right_difficulty, cart_type, display,
                               rom_bytes, ram_bytes, digest)
// Aggregate of plain fields describing the state of one emulated Atari
// environment. The fields are grouped by the component they belong to:
// ALE bookkeeping, system, M6502, M6532, TIA, and game/episode values.
struct AtariState
{
    // User-provided empty constructor: no member is initialized here, so every
    // field holds an indeterminate value until it is explicitly assigned.
    AtariState(){}

    // ale vars
    int32_t left_paddle;
    int32_t right_paddle;
    int32_t frame_number;
    int32_t episode_frame_number;
    int32_t string_length;
    bool save_system;
    std::string md5;

    // system vars
    int32_t cycles;

    // m6502 vars
    int32_t A; // accumulator
    int32_t X, Y; // index

    int32_t SP; // stack pointer
    int32_t PC; // program counter
    int32_t IR; // interrupt

    // Status flags, one bool each (presumably the 6502 N/V/B/D/I/Z/C bits,
    // with the zero flag stored inverted as notZ -- confirm against m6502.hpp).
    bool N;
    bool V;
    bool B;
    bool D;
    bool I;
    bool notZ;
    bool C;

    int32_t executionStatus;

    // m6532 vars
    std::array ram;

    int32_t timer;
    int32_t intervalShift;
    int32_t cyclesWhenTimerSet;
    int32_t cyclesWhenInterruptReset;
    bool timerReadAfterInterrupt;
    int32_t DDRA;
    int32_t DDRB;

    // TIA vars
    // Clock, scanline and frame-timing values.
    int32_t clockWhenFrameStarted;
    int32_t clockStartDisplay;
    int32_t clockStopDisplay;
    int32_t clockAtLastUpdate;
    int32_t clocksToEndOfScanLine;
    int32_t scanlineCountForLastFrame;
    int32_t currentScanline;
    int32_t VSYNCFinishClock;
    int32_t enabledObjects;
    // Fields named after TIA registers, plus values derived from them.
    int32_t VSYNC;
    int32_t VBLANK;
    int32_t NUSIZ0;
    int32_t NUSIZ1;
    int32_t COLUP0;
    int32_t COLUP1;
    int32_t COLUPF;
    int32_t COLUBK;
    int32_t CTRLPF;
    int32_t playfieldPriorityAndScore;
    bool REFP0;
    bool REFP1;
    int32_t PF;
    int32_t GRP0;
    int32_t GRP1;
    int32_t DGRP0;
    int32_t DGRP1;
    int32_t ENAM0;
    int32_t ENAM1;
    int32_t ENABL;
    int32_t DENABL;
    int32_t HMP0;
    int32_t HMP1;
    int32_t HMM0;
    int32_t HMM1;
    int32_t HMBL;
    int32_t VDELP0;
    int32_t VDELP1;
    int32_t VDELBL;
    int32_t RESMP0;
    int32_t RESMP1;
    int32_t collision;
    int32_t POSP0;
    int32_t POSP1;
    int32_t POSM0;
    int32_t POSM1;
    int32_t POSBL;
    int32_t currentGRP0;
    int32_t currentGRP1;
    int32_t lastHMOVEClock;
    int32_t HMOVEBlankEnabled;
    int32_t M0CosmicArkMotionEnabled;
    int32_t M0CosmicArkCounter;
    int32_t dumpEnabled;
    int32_t dumpDisabledCycle;

    // Remaining fields: bank plus per-game/episode values
    // (reward, score, terminal/started flags, lives, points).
    int32_t bank;
    int32_t reward;
    int32_t score;
    bool terminal;
    bool started;
    int32_t lives;
    int32_t points;
    int32_t last_lives;
};