├── python ├── common │ ├── __init__.py │ ├── constants.py │ ├── logging_options.py │ ├── test.py │ └── entity.py ├── requirements.txt ├── agents │ ├── __init__.py │ └── simple_agent.py ├── towerfall │ ├── __init__.py │ ├── connection.py │ └── towerfall.py ├── gym_wrapper │ ├── __init__.py │ ├── objective.py │ ├── observation.py │ ├── actions.py │ ├── blank_env.py │ ├── player_observation.py │ ├── kill_enemy_objective.py │ └── base_env.py ├── gym_wrapper_test.py ├── run_simple_agent_quest.py ├── run_simple_agent_versus.py ├── run_simple_agent_versus_2v2.py ├── run_simple_agent_versus_human.py ├── run_simple_agent_sandbox.py ├── reset_test.py └── end_to_end_test.py ├── .gitignore ├── banner.png └── TowerFallAi ├── TowerfallAiMod ├── Api │ ├── State.cs │ ├── StateLava.cs │ ├── StateHat.cs │ ├── StateItem.cs │ ├── StateChest.cs │ ├── StateFalling.cs │ ├── StateSubType.cs │ ├── StateCrackedWall.cs │ ├── StateKingReaper.cs │ ├── StateProximityBlock.cs │ ├── StateArrow.cs │ ├── StateReaperBeam.cs │ ├── StateSpikeBall.cs │ ├── StateSwitchBlock.cs │ ├── StateShiftBlock.cs │ ├── StateMiasma.cs │ ├── StateScenario.cs │ ├── StateUpdate.cs │ ├── StateInit.cs │ ├── StateArcher.cs │ ├── StateEntity.cs │ ├── Vec2.cs │ └── Types.cs ├── Core │ ├── ConfigException.cs │ ├── MessageException.cs │ ├── OperationException.cs │ ├── botpool.json │ ├── Sound.cs │ ├── AgentConnectionRemote.cs │ ├── KeyboardConfigs.cs │ ├── SandboxLevelSystem.cs │ ├── Server.cs │ ├── AgentConfigExtension.cs │ ├── AgentConnection.cs │ ├── RemoteConnection.cs │ ├── EntityCreator.cs │ └── ConnectionDispatcher.cs ├── Data │ ├── GameModes.cs │ ├── Metadata.cs │ ├── AgentConfig.cs │ ├── MatchConfig.cs │ └── Message.cs ├── Mod │ ├── Enemy.cs │ ├── Orb.cs │ ├── Lava.cs │ ├── Bat.cs │ ├── Icicle.cs │ ├── Lantern.cs │ ├── BombArrow.cs │ ├── Cultist.cs │ ├── CrackedWall.cs │ ├── OrbPickup.cs │ ├── PlayerCorpse.cs │ ├── DefaultHat.cs │ ├── ReaperBomb.cs │ ├── EnemyAttack.cs │ ├── ReaperCrystal.cs │ ├── ProximityBlock.cs 
│ ├── CrackedPlatform.cs │ ├── Ghost.cs │ ├── Slime.cs │ ├── KingReaper.cs │ ├── SwitchBlock.cs │ ├── ArrowTypePickup.cs │ ├── Miasma.cs │ ├── SpikeBall.cs │ ├── EvilCrystal.cs │ ├── FloorMiasma.cs │ ├── TreasureChest.cs │ ├── Session.cs │ ├── SuperBombArrow.cs │ ├── QuestSpawnPortal.cs │ ├── ReaperBeam.cs │ ├── ShiftBlock.cs │ ├── Birdman.cs │ ├── Level.cs │ ├── Player.cs │ ├── Skeleton.cs │ ├── MenuInput.cs │ ├── PauseMenu.cs │ ├── TFGame.cs │ └── Entity.cs ├── Common │ ├── TaskExtensions.cs │ ├── StringExtensions.cs │ ├── DoubleDictionary.cs │ ├── app.config │ ├── packages.config │ ├── Logger.cs │ ├── KeySemaphore.cs │ ├── Properties │ │ └── AssemblyInfo.cs │ ├── CustomLogger.cs │ ├── AsyncQueue.cs │ ├── Util.cs │ ├── Common.csproj │ └── RemoteConnection.cs ├── packages.config ├── app.config └── Properties │ └── AssemblyInfo.cs ├── PatcherLib ├── packages.config ├── PatcherAttribute.cs ├── Properties │ └── AssemblyInfo.cs └── PatcherLib.csproj ├── Patcher ├── Properties │ ├── Pacher.cs │ └── AssemblyInfo.cs ├── packages.config ├── CecilExtensions.cs ├── Program.cs └── Patcher.csproj ├── TowerfallAiModTests ├── packages.config ├── app.config ├── Properties │ └── AssemblyInfo.cs ├── AsyncQueueTest.cs └── TowerfallAiModTests.csproj ├── TowerFallAi.sln └── .gitignore /python/common/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /python/requirements.txt: -------------------------------------------------------------------------------- 1 | psutil -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__ 2 | 3 | .vscode -------------------------------------------------------------------------------- /banner.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/TowerfallAi/towerfall-ai/HEAD/banner.png -------------------------------------------------------------------------------- /python/agents/__init__.py: -------------------------------------------------------------------------------- 1 | from .simple_agent import SimpleAgent 2 | 3 | __all__ = [ 4 | 'SimpleAgent' 5 | ] -------------------------------------------------------------------------------- /TowerFallAi/TowerfallAiMod/Api/State.cs: -------------------------------------------------------------------------------- 1 | namespace TowerfallAi.Api { 2 | public class State { 3 | public string type; 4 | } 5 | } -------------------------------------------------------------------------------- /TowerFallAi/TowerfallAiMod/Api/StateLava.cs: -------------------------------------------------------------------------------- 1 | namespace TowerfallAi.Api { 2 | public class StateLava : StateEntity { 3 | public float height; 4 | } 5 | } 6 | -------------------------------------------------------------------------------- /TowerFallAi/TowerfallAiMod/Api/StateHat.cs: -------------------------------------------------------------------------------- 1 | namespace TowerfallAi.Api { 2 | public class StateHat : StateEntity { 3 | public int playerIndex; 4 | } 5 | } 6 | -------------------------------------------------------------------------------- /TowerFallAi/TowerfallAiMod/Api/StateItem.cs: -------------------------------------------------------------------------------- 1 | namespace TowerfallAi.Api { 2 | public class StateItem : StateEntity { 3 | public string itemType; 4 | } 5 | } 6 | -------------------------------------------------------------------------------- /TowerFallAi/TowerfallAiMod/Api/StateChest.cs: -------------------------------------------------------------------------------- 1 | namespace TowerfallAi.Api { 2 | public class StateChest : StateEntity { 3 | public string chestType; 4 | } 5 | } 6 | 
-------------------------------------------------------------------------------- /TowerFallAi/TowerfallAiMod/Api/StateFalling.cs: -------------------------------------------------------------------------------- 1 | namespace TowerfallAi.Api { 2 | public class StateFalling : StateEntity { 3 | public bool falling; 4 | } 5 | } 6 | -------------------------------------------------------------------------------- /TowerFallAi/TowerfallAiMod/Api/StateSubType.cs: -------------------------------------------------------------------------------- 1 | namespace TowerfallAi.Api { 2 | public class StateSubType : StateEntity { 3 | public string subType; 4 | } 5 | } 6 | -------------------------------------------------------------------------------- /python/towerfall/__init__.py: -------------------------------------------------------------------------------- 1 | from .connection import Connection 2 | from .towerfall import Towerfall 3 | 4 | __all__ = [ 5 | 'Connection', 6 | 'Towerfall', 7 | ] 8 | -------------------------------------------------------------------------------- /TowerFallAi/TowerfallAiMod/Api/StateCrackedWall.cs: -------------------------------------------------------------------------------- 1 | namespace TowerfallAi.Api { 2 | public class StateCrackedWall : StateEntity { 3 | public float count; 4 | } 5 | } 6 | -------------------------------------------------------------------------------- /TowerFallAi/TowerfallAiMod/Api/StateKingReaper.cs: -------------------------------------------------------------------------------- 1 | namespace TowerfallAi.Api { 2 | public class StateKingReaper : StateEntity { 3 | public bool shield; 4 | } 5 | } 6 | -------------------------------------------------------------------------------- /TowerFallAi/PatcherLib/packages.config: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | -------------------------------------------------------------------------------- 
/TowerFallAi/TowerfallAiMod/Api/StateProximityBlock.cs: -------------------------------------------------------------------------------- 1 | namespace TowerfallAi.Api { 2 | public class StateProximityBlock : StateEntity { 3 | public bool collidable; 4 | } 5 | } 6 | -------------------------------------------------------------------------------- /TowerFallAi/TowerfallAiMod/Api/StateArrow.cs: -------------------------------------------------------------------------------- 1 | namespace TowerfallAi.Api { 2 | public class StateArrow : StateEntity { 3 | public string arrowType; 4 | public float timeLeft; 5 | } 6 | } 7 | -------------------------------------------------------------------------------- /TowerFallAi/TowerfallAiMod/Api/StateReaperBeam.cs: -------------------------------------------------------------------------------- 1 | namespace TowerfallAi.Api { 2 | public class StateReaperBeam : StateEntity { 3 | public Vec2 dir; 4 | public float width; 5 | } 6 | } 7 | -------------------------------------------------------------------------------- /TowerFallAi/TowerfallAiMod/Api/StateSpikeBall.cs: -------------------------------------------------------------------------------- 1 | namespace TowerfallAi.Api { 2 | public class StateSpikeBall : StateEntity { 3 | public Vec2 center; 4 | public float radius; 5 | } 6 | } 7 | -------------------------------------------------------------------------------- /TowerFallAi/TowerfallAiMod/Api/StateSwitchBlock.cs: -------------------------------------------------------------------------------- 1 | namespace TowerfallAi.Api { 2 | public class StateSwitchBlock : StateEntity { 3 | public bool warning; 4 | public bool collidable; 5 | } 6 | } 7 | -------------------------------------------------------------------------------- /TowerFallAi/TowerfallAiMod/Api/StateShiftBlock.cs: -------------------------------------------------------------------------------- 1 | namespace TowerfallAi.Api { 2 | public class StateShiftBlock : StateEntity { 3 
| public Vec2 startPosition; 4 | public Vec2 endPosition; 5 | } 6 | } 7 | -------------------------------------------------------------------------------- /python/common/constants.py: -------------------------------------------------------------------------------- 1 | WIDTH = 320 2 | HEIGHT = 240 3 | 4 | HW = WIDTH // 2 5 | HH = HEIGHT // 2 6 | 7 | LEFT = 'l' 8 | RIGHT = 'r' 9 | DOWN = 'd' 10 | UP = 'u' 11 | JUMP = 'j' 12 | DASH = 'z' 13 | SHOOT = 's' 14 | -------------------------------------------------------------------------------- /TowerFallAi/Patcher/Properties/Pacher.cs: -------------------------------------------------------------------------------- 1 | using System; 2 | using System.Collections.Generic; 3 | using System.Linq; 4 | using System.Text; 5 | 6 | namespace Patcher.Properties { 7 | internal class Pacher { 8 | } 9 | } 10 | -------------------------------------------------------------------------------- /TowerFallAi/TowerfallAiMod/Core/ConfigException.cs: -------------------------------------------------------------------------------- 1 | using System; 2 | 3 | namespace TowerfallAi.Core { 4 | public class ConfigException : Exception { 5 | public ConfigException(string message) : base(message) { } 6 | } 7 | } 8 | -------------------------------------------------------------------------------- /TowerFallAi/TowerfallAiMod/Core/MessageException.cs: -------------------------------------------------------------------------------- 1 | using System; 2 | 3 | namespace TowerfallAiMod.Core { 4 | public class MessageException : Exception { 5 | public MessageException(string message) : base(message) { } 6 | } 7 | } 8 | -------------------------------------------------------------------------------- /TowerFallAi/TowerfallAiMod/Api/StateMiasma.cs: -------------------------------------------------------------------------------- 1 | namespace TowerfallAi.Api { 2 | public class StateMiasma : StateEntity { 3 | public float left; 4 | public float right; 5 | public 
float bottom; 6 | public float top; 7 | } 8 | } 9 | -------------------------------------------------------------------------------- /TowerFallAi/TowerfallAiMod/Core/OperationException.cs: -------------------------------------------------------------------------------- 1 | using System; 2 | 3 | namespace TowerfallAi.Core { 4 | public class OperationException : Exception { 5 | public OperationException(string message) : base(message) { } 6 | } 7 | } 8 | -------------------------------------------------------------------------------- /TowerFallAi/TowerfallAiMod/Data/GameModes.cs: -------------------------------------------------------------------------------- 1 | namespace TowerfallAi.Data { 2 | public static class GameModes { 3 | public static string Versus = "versus"; 4 | public static string Quest = "quest"; 5 | public static string Sandbox = "sandbox"; 6 | } 7 | } 8 | -------------------------------------------------------------------------------- /TowerFallAi/PatcherLib/PatcherAttribute.cs: -------------------------------------------------------------------------------- 1 | using System; 2 | 3 | namespace Patcher { 4 | public class PatchAttribute : Attribute { 5 | public string Name; 6 | public PatchAttribute(string name = null) { 7 | this.Name = name; 8 | } 9 | } 10 | } 11 | -------------------------------------------------------------------------------- /TowerFallAi/TowerfallAiMod/Api/StateScenario.cs: -------------------------------------------------------------------------------- 1 | namespace TowerfallAi.Api { 2 | public class StateScenario : State { 3 | public StateScenario() { 4 | type = "scenario"; 5 | } 6 | 7 | public int[,] grid; 8 | public float cellSize = 10; 9 | } 10 | } 11 | -------------------------------------------------------------------------------- /python/common/logging_options.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | 4 | class NoLevelFormatter(logging.Formatter): 5 | def 
format(self, record): 6 | return record.getMessage() 7 | 8 | def default_logging(): 9 | logging.basicConfig(level=logging.INFO) 10 | logging.getLogger().handlers[0].setFormatter(NoLevelFormatter()) -------------------------------------------------------------------------------- /TowerFallAi/Patcher/packages.config: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | -------------------------------------------------------------------------------- /TowerFallAi/TowerfallAiMod/Api/StateUpdate.cs: -------------------------------------------------------------------------------- 1 | using System.Collections.Generic; 2 | 3 | namespace TowerfallAi.Api { 4 | public class StateUpdate : State { 5 | public StateUpdate() { 6 | type = "update"; 7 | } 8 | 9 | public List entities = new List(); 10 | 11 | public float dt; 12 | public int id; 13 | } 14 | } 15 | -------------------------------------------------------------------------------- /TowerFallAi/TowerfallAiMod/Mod/Enemy.cs: -------------------------------------------------------------------------------- 1 | using TowerFall; 2 | using TowerfallAi.Common; 3 | using System.Reflection; 4 | 5 | namespace TowerfallAi.Mod { 6 | public static class ExtEnemy { 7 | public static bool IsDead(this Enemy e) { 8 | return (bool)Util.GetFieldValue("dead", typeof(Enemy), e, BindingFlags.Instance | BindingFlags.NonPublic); 9 | } 10 | } 11 | } 12 | -------------------------------------------------------------------------------- /TowerFallAi/TowerfallAiMod/Api/StateInit.cs: -------------------------------------------------------------------------------- 1 | using TowerfallAi.Core; 2 | 3 | namespace TowerfallAi.Api { 4 | public class StateInit : State { 5 | public StateInit() { 6 | type = "init"; 7 | version = AiMod.Version; 8 | } 9 | 10 | // The version of the mod. 
11 | public string version; 12 | 13 | // The index of the player 14 | public int index; 15 | } 16 | } 17 | -------------------------------------------------------------------------------- /TowerFallAi/TowerfallAiMod/Common/TaskExtensions.cs: -------------------------------------------------------------------------------- 1 | using System.Threading.Tasks; 2 | 3 | namespace TowerfallAi.Common { 4 | public static class TaskExtensions { 5 | public static bool IsAlive(this Task task) { 6 | return 7 | task.Status != TaskStatus.RanToCompletion && 8 | task.Status != TaskStatus.Faulted && 9 | task.Status != TaskStatus.Canceled; 10 | } 11 | } 12 | } 13 | -------------------------------------------------------------------------------- /TowerFallAi/TowerfallAiMod/Data/Metadata.cs: -------------------------------------------------------------------------------- 1 | using System.Runtime.Serialization; 2 | 3 | namespace TowerfallAi.Data { 4 | [DataContract] 5 | public class Metadata { 6 | [DataMember(EmitDefaultValue = true)] 7 | public int port; 8 | 9 | [DataMember(EmitDefaultValue = true)] 10 | public bool fastrun; 11 | 12 | [DataMember(EmitDefaultValue = true)] 13 | public bool nographics; 14 | } 15 | } 16 | -------------------------------------------------------------------------------- /TowerFallAi/TowerfallAiMod/packages.config: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | -------------------------------------------------------------------------------- /TowerFallAi/TowerfallAiMod/Mod/Orb.cs: -------------------------------------------------------------------------------- 1 | using Patcher; 2 | using TowerFall; 3 | using TowerfallAi.Api; 4 | 5 | namespace TowerfallAi.Mod { 6 | [Patch("TowerFall.Orb")] 7 | public static class ModOrb{ 8 | public static StateEntity GetState(this Orb ent) { 9 | var aiState = new StateFalling { type = Types.Orb }; 10 | ExtEntity.SetAiState(ent, aiState); 11 | aiState.falling = 
ent.falling; 12 | return aiState; 13 | } 14 | } 15 | } 16 | -------------------------------------------------------------------------------- /TowerFallAi/TowerfallAiMod/Mod/Lava.cs: -------------------------------------------------------------------------------- 1 | using TowerfallAi.Api; 2 | using Patcher; 3 | using TowerFall; 4 | 5 | namespace TowerfallAi.Mod { 6 | [Patch("TowerFall.Lava")] 7 | public static class ModLava { 8 | public static StateEntity GetState(this Lava ent) { 9 | var aiState = new StateLava { type = Types.Lava }; 10 | ExtEntity.SetAiState(ent, aiState); 11 | aiState.height = ent.Collider.Top; 12 | return aiState; 13 | } 14 | } 15 | } 16 | -------------------------------------------------------------------------------- /TowerFallAi/TowerfallAiMod/Mod/Bat.cs: -------------------------------------------------------------------------------- 1 | using Patcher; 2 | using TowerFall; 3 | using TowerfallAi.Api; 4 | 5 | namespace TowerfallAi.Mod { 6 | [Patch("TowerFall.Bat")] 7 | public static class ModBat { 8 | public static StateEntity GetState(this Bat ent) { 9 | var aiState = new StateEntity(); 10 | 11 | aiState.type = ConversionTypes.BatTypes.GetB(ent.batType); 12 | ExtEntity.SetAiState(ent, aiState); 13 | return aiState; 14 | } 15 | } 16 | } 17 | -------------------------------------------------------------------------------- /TowerFallAi/TowerfallAiMod/Mod/Icicle.cs: -------------------------------------------------------------------------------- 1 | using Patcher; 2 | using TowerFall; 3 | using TowerfallAi.Api; 4 | 5 | namespace TowerfallAi.Mod { 6 | [Patch("TowerFall.Icicle")] 7 | public static class ModIcicle { 8 | public static StateEntity GetState(this Icicle ent) { 9 | var aiState = new StateFalling { type = Types.Icicle }; 10 | ExtEntity.SetAiState(ent, aiState); 11 | aiState.falling = ent.falling; 12 | return aiState; 13 | } 14 | } 15 | } 16 | -------------------------------------------------------------------------------- 
/TowerFallAi/TowerfallAiMod/Api/StateArcher.cs: -------------------------------------------------------------------------------- 1 | using System.Collections.Generic; 2 | 3 | namespace TowerfallAi.Api { 4 | public class StateArcher : StateEntity { 5 | public int playerIndex; 6 | public bool shield; 7 | public bool wing; 8 | public List arrows; 9 | public bool dead; 10 | public bool onGround; 11 | public bool onWall; 12 | public Vec2 aimDirection; 13 | public string team; 14 | public bool dodgeCooldown; 15 | } 16 | } 17 | -------------------------------------------------------------------------------- /TowerFallAi/TowerfallAiMod/Mod/Lantern.cs: -------------------------------------------------------------------------------- 1 | using Patcher; 2 | using TowerFall; 3 | using TowerfallAi.Api; 4 | 5 | namespace TowerfallAi.Mod { 6 | [Patch("TowerFall.Lantern")] 7 | public static class ModLantern { 8 | public static StateEntity GetState(this Lantern ent) { 9 | var aiState = new StateFalling { type = Types.Lantern }; 10 | ExtEntity.SetAiState(ent, aiState); 11 | aiState.falling = ent.falling; 12 | return aiState; 13 | } 14 | } 15 | } 16 | -------------------------------------------------------------------------------- /TowerFallAi/TowerfallAiMod/Mod/BombArrow.cs: -------------------------------------------------------------------------------- 1 | using System; 2 | using Patcher; 3 | using TowerFall; 4 | using TowerfallAi.Api; 5 | 6 | namespace TowerfallAi.Mod { 7 | [Patch("TowerFall.BombArrow")] 8 | public static class ModBombArrow { 9 | public static StateEntity GetState(this BombArrow ent) { 10 | var state = (StateArrow)ExtEntity.GetStateArrow(ent); 11 | state.timeLeft = (float)Math.Ceiling(ent.explodeAlarm.FramesLeft); 12 | return state; 13 | } 14 | } 15 | } 16 | -------------------------------------------------------------------------------- /TowerFallAi/TowerfallAiMod/Mod/Cultist.cs: -------------------------------------------------------------------------------- 1 | 
using Patcher; 2 | using TowerFall; 3 | using TowerfallAi.Api; 4 | 5 | namespace TowerfallAi.Mod { 6 | [Patch("TowerFall.Cultist")] 7 | public static class ModCultist { 8 | public static StateEntity GetState(this Cultist ent) { 9 | var aiState = new StateEntity(); 10 | 11 | aiState.type = ConversionTypes.CultistTypes.GetB(ent.type); 12 | ExtEntity.SetAiState(ent, aiState); 13 | return aiState; 14 | } 15 | } 16 | } 17 | -------------------------------------------------------------------------------- /TowerFallAi/TowerfallAiMod/Common/StringExtensions.cs: -------------------------------------------------------------------------------- 1 | namespace TowerfallAi.Common { 2 | public static class StringExtensions { 3 | public static string Format(this string s, params object[] args) { 4 | return string.Format(s, args); 5 | } 6 | 7 | public static string FirstLower(this string s) { 8 | if (string.IsNullOrWhiteSpace(s)) { 9 | return s; 10 | } 11 | return char.ToLowerInvariant(s[0]) + s.Substring(1, s.Length - 1); 12 | } 13 | } 14 | } 15 | -------------------------------------------------------------------------------- /TowerFallAi/TowerfallAiMod/Mod/CrackedWall.cs: -------------------------------------------------------------------------------- 1 | using Patcher; 2 | using TowerFall; 3 | using TowerfallAi.Api; 4 | 5 | namespace TowerfallAi.Mod { 6 | [Patch("TowerFall.CrackedWall")] 7 | public static class ModCrackedWall{ 8 | public static StateEntity GetState(this CrackedWall ent) { 9 | var state = new StateCrackedWall { type = Types.CrackedWall }; 10 | ExtEntity.SetAiState(ent, state); 11 | state.count = ent.explodeCounter; 12 | return state; 13 | } 14 | } 15 | } 16 | -------------------------------------------------------------------------------- /TowerFallAi/TowerfallAiMod/Mod/OrbPickup.cs: -------------------------------------------------------------------------------- 1 | using Patcher; 2 | using TowerFall; 3 | using TowerfallAi.Api; 4 | 5 | namespace TowerfallAi.Mod 
{ 6 | [Patch("TowerFall.OrbPickup")] 7 | public static class ModOrbPickup{ 8 | public static StateEntity GetState(this OrbPickup ent) { 9 | var aiState = new StateItem { type = Types.Item }; 10 | ExtEntity.SetAiState(ent, aiState); 11 | aiState.itemType = "orb" + ent.orbType.ToString(); 12 | return aiState; 13 | } 14 | } 15 | } 16 | -------------------------------------------------------------------------------- /TowerFallAi/TowerfallAiMod/Mod/PlayerCorpse.cs: -------------------------------------------------------------------------------- 1 | using Patcher; 2 | using TowerFall; 3 | using TowerfallAi.Api; 4 | 5 | namespace TowerfallAi.Mod { 6 | [Patch("TowerFall.PlayerCorpse")] 7 | public static class ModPlayerCorpse { 8 | public static StateEntity GetState(this PlayerCorpse ent) { 9 | if (ent.PlayerIndex < 0) return null; 10 | var state = new StateEntity { type = "playerCorpse" }; 11 | ExtEntity.SetAiState(ent, state); 12 | return state; 13 | } 14 | } 15 | } 16 | -------------------------------------------------------------------------------- /TowerFallAi/TowerfallAiMod/Mod/DefaultHat.cs: -------------------------------------------------------------------------------- 1 | using Patcher; 2 | using TowerFall; 3 | using TowerfallAi.Api; 4 | 5 | namespace TowerfallAi.Mod { 6 | [Patch("TowerFall.DefaultHat")] 7 | public static class ModDefaultHat{ 8 | public static StateEntity GetState(this DefaultHat ent) { 9 | var aiState = new StateHat { type = Types.Hat }; 10 | 11 | aiState.playerIndex = ent.PlayerIndex; 12 | ExtEntity.SetAiState(ent, aiState); 13 | return aiState; 14 | } 15 | } 16 | } 17 | -------------------------------------------------------------------------------- /TowerFallAi/TowerfallAiMod/Core/botpool.json: -------------------------------------------------------------------------------- 1 | [ 2 | { 3 | "ip": "127.0.0.1", 4 | //"ip": "104.45.224.202", 5 | //"ip": "192.168.0.12", 6 | "port": 4511 7 | }, 8 | { 9 | "ip": "104.45.224.202", 10 | //"ip": 
"127.0.0.1", 11 | //"ip": "192.168.0.12", 12 | "port": 4511 13 | }, 14 | { 15 | "ip": "127.0.0.1", 16 | "port": 4513 17 | }, 18 | { 19 | "ip": "127.0.0.1", 20 | "port": 4514 21 | } 22 | ] -------------------------------------------------------------------------------- /TowerFallAi/TowerfallAiMod/Mod/ReaperBomb.cs: -------------------------------------------------------------------------------- 1 | using Patcher; 2 | using TowerFall; 3 | using TowerfallAi.Api; 4 | 5 | namespace TowerfallAi.Mod { 6 | [Patch("TowerFall.KingReaper/ReaperBomb")] 7 | public static class ModReaperBomb{ 8 | public static StateEntity GetState(this KingReaper.ReaperBomb ent) { 9 | var aiState = new StateEntity { type = "kingReaperBomb" }; 10 | ExtEntity.SetAiState(ent, aiState); 11 | aiState.canHurt = true; 12 | return aiState; 13 | } 14 | } 15 | } 16 | -------------------------------------------------------------------------------- /python/gym_wrapper/__init__.py: -------------------------------------------------------------------------------- 1 | from .actions import Actions 2 | from .base_env import TowerfallEnv 3 | from .blank_env import TowerfallBlankEnv 4 | from .kill_enemy_objective import KillEnemyObjective 5 | from .objective import Objective 6 | from .observation import Observation 7 | from .player_observation import PlayerObservation 8 | 9 | __all__ = [ 10 | 'Actions', 11 | 'KillEnemyObjective', 12 | 'Objective', 13 | 'Observation', 14 | 'PlayerObservation', 15 | 'TowerfallBlankEnv', 16 | 'TowerfallEnv', 17 | ] -------------------------------------------------------------------------------- /TowerFallAi/TowerfallAiMod/Mod/EnemyAttack.cs: -------------------------------------------------------------------------------- 1 | using Patcher; 2 | using TowerFall; 3 | using TowerfallAi.Api; 4 | 5 | namespace TowerfallAi.Mod { 6 | [Patch("TowerFall.EnemyAttack")] 7 | public static class ModEnemyAttack { 8 | public static StateEntity GetState(this EnemyAttack ent) { 9 | var aiState = new 
StateEntity { 10 | type = "enemyAttack", 11 | }; 12 | 13 | ExtEntity.SetAiState(ent, aiState); 14 | aiState.canHurt = true; 15 | return aiState; 16 | } 17 | } 18 | } 19 | -------------------------------------------------------------------------------- /TowerFallAi/TowerfallAiMod/Mod/ReaperCrystal.cs: -------------------------------------------------------------------------------- 1 | using Patcher; 2 | using TowerFall; 3 | using TowerfallAi.Api; 4 | 5 | namespace TowerfallAi.Mod { 6 | [Patch("TowerFall.KingReaper/ReaperCrystal")] 7 | public static class ModReaperCrystal { 8 | public static StateEntity GetState(this KingReaper.ReaperCrystal ent) { 9 | var aiState = new StateEntity { type = "kingReaperCrystal" }; 10 | ExtEntity.SetAiState(ent, aiState); 11 | aiState.canHurt = true; 12 | return aiState; 13 | } 14 | } 15 | } 16 | -------------------------------------------------------------------------------- /python/gym_wrapper/objective.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Dict 2 | 3 | from .base_env import TowerfallEnv 4 | from .observation import Observation 5 | 6 | 7 | class Objective(Observation): 8 | ''' 9 | Overite this to define an objective by calculating reward and done. 
10 | ''' 11 | def __init__(self): 12 | self.done: bool 13 | self.reward: float 14 | self.env: TowerfallEnv 15 | 16 | def get_reset_entities(self) -> list[Dict[str, Any]]: 17 | '''Specifies how the environment needs to be reset.''' 18 | return [] 19 | -------------------------------------------------------------------------------- /TowerFallAi/TowerfallAiMod/Api/StateEntity.cs: -------------------------------------------------------------------------------- 1 | namespace TowerfallAi.Api { 2 | public class StateEntity : State { 3 | public Vec2 pos; 4 | public Vec2 vel; 5 | public Vec2 size; 6 | public int id; 7 | public string state; 8 | public bool isEnemy; 9 | public bool canHurt; 10 | public bool canBounceOn; 11 | public bool isDead; 12 | public int facing; 13 | 14 | public StateEntity ReCenter() { 15 | pos.x += size.x / 2; 16 | pos.y -= size.y / 2; 17 | return this; 18 | } 19 | } 20 | } 21 | -------------------------------------------------------------------------------- /TowerFallAi/TowerfallAiMod/Mod/ProximityBlock.cs: -------------------------------------------------------------------------------- 1 | using Patcher; 2 | using TowerFall; 3 | using TowerfallAi.Api; 4 | 5 | namespace TowerfallAi.Mod { 6 | [Patch("TowerFall.ProximityBlock")] 7 | public static class ModProximityBlock { 8 | public static StateEntity GetState(this ProximityBlock ent) { 9 | var aiState = new StateProximityBlock { type = Types.ProximityBlock }; 10 | 11 | ExtEntity.SetAiState(ent, aiState); 12 | aiState.collidable = ent.Collidable; 13 | 14 | return aiState; 15 | } 16 | } 17 | } 18 | -------------------------------------------------------------------------------- /TowerFallAi/TowerfallAiMod/Mod/CrackedPlatform.cs: -------------------------------------------------------------------------------- 1 | using Patcher; 2 | using TowerFall; 3 | using TowerfallAi.Api; 4 | using TowerfallAi.Common; 5 | 6 | namespace TowerfallAi.Mod { 7 | [Patch("TowerFall.CrackedPlatform")] 8 | public static class 
ModCrackedPlatform { 9 | public static StateEntity GetState(this CrackedPlatform ent) { 10 | var state = new StateEntity { type = Types.CrackedPlatform }; 11 | ExtEntity.SetAiState(ent, state); 12 | state.state = ent.state.ToString().FirstLower(); 13 | return state; 14 | } 15 | } 16 | } 17 | -------------------------------------------------------------------------------- /TowerFallAi/TowerfallAiMod/Common/DoubleDictionary.cs: -------------------------------------------------------------------------------- 1 | using System.Collections.Generic; 2 | 3 | namespace TowerfallAi.Common { 4 | public class DoubleDictionary { 5 | private Dictionary first = new Dictionary(); 6 | private Dictionary second = new Dictionary(); 7 | 8 | public void Add(A a, B b) { 9 | first.Add(a, b); 10 | second.Add(b, a); 11 | } 12 | 13 | public B GetB(A a) { 14 | return first[a]; 15 | } 16 | 17 | public A GetA(B b) { 18 | return second[b]; 19 | } 20 | } 21 | } 22 | -------------------------------------------------------------------------------- /TowerFallAi/TowerfallAiMod/Data/AgentConfig.cs: -------------------------------------------------------------------------------- 1 | using System.Runtime.Serialization; 2 | 3 | namespace TowerfallAi.Data { 4 | [DataContract] 5 | public class AgentConfig { 6 | public static class Type { 7 | public static string Human = "human"; 8 | public static string Remote = "remote"; 9 | } 10 | 11 | [DataMember(EmitDefaultValue = true)] 12 | public string type; 13 | 14 | [DataMember(EmitDefaultValue = false)] 15 | public string team; 16 | 17 | [DataMember(EmitDefaultValue = false)] 18 | public string archer; 19 | } 20 | } 21 | -------------------------------------------------------------------------------- /TowerFallAi/TowerfallAiModTests/packages.config: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | -------------------------------------------------------------------------------- 
/TowerFallAi/TowerfallAiMod/Common/app.config: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | -------------------------------------------------------------------------------- /TowerFallAi/TowerfallAiMod/Mod/Ghost.cs: -------------------------------------------------------------------------------- 1 | using Patcher; 2 | using TowerFall; 3 | using TowerfallAi.Api; 4 | using TowerfallAi.Common; 5 | 6 | namespace TowerfallAi.Mod { 7 | [Patch("TowerFall.Ghost")] 8 | public static class ModGhost { 9 | public static StateEntity GetState(this Ghost ent) { 10 | var aiState = new StateSubType { type = "ghost" }; 11 | ExtEntity.SetAiState(ent, aiState); 12 | aiState.subType = ConversionTypes.GhostTypes.GetB((Ghost.GhostTypes)Util.GetPrivateFieldValue("ghostType", ent)); 13 | return aiState; 14 | } 15 | } 16 | } 17 | -------------------------------------------------------------------------------- /TowerFallAi/TowerfallAiMod/Mod/Slime.cs: -------------------------------------------------------------------------------- 1 | using Patcher; 2 | using TowerFall; 3 | using TowerfallAi.Api; 4 | using TowerfallAi.Common; 5 | 6 | namespace TowerfallAi.Mod { 7 | [Patch("TowerFall.Slime")] 8 | public static class ModSlime { 9 | public static StateEntity GetState(this Slime ent) { 10 | var aiState = new StateSubType { type = "slime" }; 11 | ExtEntity.SetAiState(ent, aiState); 12 | aiState.subType = ConversionTypes.SlimeTypes.GetB((Slime.SlimeColors)Util.GetPrivateFieldValue("slimeColor", ent)); 13 | return aiState; 14 | } 15 | } 16 | } 17 | -------------------------------------------------------------------------------- /TowerFallAi/TowerfallAiMod/Mod/KingReaper.cs: -------------------------------------------------------------------------------- 1 | using Monocle; 2 | using Patcher; 3 | using TowerFall; 4 | using TowerfallAi.Api; 5 | using TowerfallAi.Common; 6 | 7 | namespace TowerfallAi.Mod { 8 | 
[Patch("TowerFall.KingReaper")] 9 | public static class ModKingReaper { 10 | public static StateEntity GetState(this KingReaper ent) { 11 | var aiState = new StateKingReaper { type = "kingReaper" }; 12 | ExtEntity.SetAiState(ent, aiState); 13 | aiState.shield = (Counter)Util.GetPrivateFieldValue("shieldCounter", ent); 14 | return aiState; 15 | } 16 | } 17 | } 18 | -------------------------------------------------------------------------------- /TowerFallAi/TowerfallAiMod/Mod/SwitchBlock.cs: -------------------------------------------------------------------------------- 1 | using Patcher; 2 | using TowerFall; 3 | using TowerfallAi.Api; 4 | 5 | namespace TowerfallAi.Mod { 6 | [Patch("TowerFall.SwitchBlock")] 7 | public static class ModSwitchBlock { 8 | public static StateEntity GetState(this SwitchBlock ent) { 9 | var aiState = new StateSwitchBlock { type = Types.SwitchBlock }; 10 | 11 | ExtEntity.SetAiState(ent, aiState); 12 | aiState.collidable = ent.Collidable; 13 | aiState.warning = ent.drawFlicker || ent.DrawWarning > 0; 14 | 15 | return aiState; 16 | } 17 | } 18 | } 19 | -------------------------------------------------------------------------------- /TowerFallAi/TowerfallAiMod/Core/Sound.cs: -------------------------------------------------------------------------------- 1 | using Monocle; 2 | 3 | namespace TowerfallAi.Core { 4 | public static class Sound { 5 | private static bool soundStoped; 6 | private static float musicVolume; 7 | 8 | public static void StopSound() { 9 | if (soundStoped) return; 10 | musicVolume = Music.MasterVolume; 11 | Music.MasterVolume = 0; 12 | soundStoped = true; 13 | } 14 | 15 | public static void ResumeSound() { 16 | if (!soundStoped) return; 17 | Music.MasterVolume = musicVolume; 18 | soundStoped = false; 19 | } 20 | } 21 | } 22 | -------------------------------------------------------------------------------- /TowerFallAi/TowerfallAiMod/Mod/ArrowTypePickup.cs: 
-------------------------------------------------------------------------------- 1 | using Monocle; 2 | using Patcher; 3 | using TowerFall; 4 | using TowerfallAi.Api; 5 | 6 | namespace TowerfallAi.Mod { 7 | [Patch("TowerFall.ArrowTypePickup")] 8 | public static class ModArrowTypePickup { 9 | public static StateEntity GetState(this Entity ent) { 10 | var item = ent as ArrowTypePickup; 11 | var state = new StateItem { 12 | type = Types.Item, 13 | itemType = "arrow" + item.arrowType.ToString() 14 | }; 15 | ExtEntity.SetAiState(ent, state); 16 | return state; 17 | } 18 | } 19 | } 20 | -------------------------------------------------------------------------------- /TowerFallAi/TowerfallAiMod/Mod/Miasma.cs: -------------------------------------------------------------------------------- 1 | using Monocle; 2 | using Patcher; 3 | using TowerFall; 4 | using TowerfallAi.Api; 5 | 6 | namespace TowerfallAi.Mod { 7 | [Patch("TowerFall.Miasma")] 8 | public static class ModMiasma{ 9 | public static StateEntity GetState(this Miasma ent) { 10 | var aiState = new StateMiasma { type = Types.Miasma }; 11 | ExtEntity.SetAiState(ent, aiState); 12 | aiState.left = ((ColliderList)ent.Collider).colliders[0].Right; 13 | aiState.right = ((ColliderList)ent.Collider).colliders[1].Left; 14 | return aiState; 15 | } 16 | } 17 | } 18 | -------------------------------------------------------------------------------- /TowerFallAi/TowerfallAiMod/Mod/SpikeBall.cs: -------------------------------------------------------------------------------- 1 | using Patcher; 2 | using TowerFall; 3 | using TowerfallAi.Api; 4 | 5 | namespace TowerfallAi.Mod { 6 | [Patch("TowerFall.Spikeball")] 7 | public static class ModSpikeball { 8 | public static StateEntity GetState(this Spikeball ent) { 9 | var aiState = new StateSpikeBall { type = Types.SpikeBall }; 10 | 11 | ExtEntity.SetAiState(ent, aiState); 12 | aiState.center = new Vec2 { 13 | x = ent.pivot.X, 14 | y = ent.pivot.Y 15 | }; 16 | 17 | aiState.radius = 
ent.radius; 18 | 19 | return aiState; 20 | } 21 | } 22 | } 23 | -------------------------------------------------------------------------------- /TowerFallAi/TowerfallAiMod/Mod/EvilCrystal.cs: -------------------------------------------------------------------------------- 1 | using Patcher; 2 | using TowerFall; 3 | using TowerfallAi.Api; 4 | using TowerfallAi.Common; 5 | 6 | namespace TowerfallAi.Mod { 7 | [Patch("TowerFall.EvilCrystal")] 8 | public static class ModEvilCrystal { 9 | public static StateEntity GetState(this EvilCrystal ent) { 10 | var aiState = new StateSubType { type = "evilCrystal" }; 11 | ExtEntity.SetAiState(ent, aiState); 12 | aiState.subType = ConversionTypes.CrystalTypes.GetB((EvilCrystal.CrystalColors)Util.GetPrivateFieldValue("crystalColor", ent)); 13 | return aiState; 14 | } 15 | } 16 | } 17 | -------------------------------------------------------------------------------- /TowerFallAi/TowerfallAiMod/Mod/FloorMiasma.cs: -------------------------------------------------------------------------------- 1 | using Patcher; 2 | using TowerFall; 3 | using TowerfallAi.Api; 4 | 5 | namespace TowerfallAi.Mod { 6 | [Patch("TowerFall.FloorMiasma")] 7 | public static class ModFloorMiasma { 8 | public static StateEntity GetState(this FloorMiasma ent) { 9 | if (ent.state == FloorMiasma.States.Invisible) return null; 10 | 11 | var aiState = new StateEntity { type = Types.FloorMiasma }; 12 | 13 | if (ent.state == FloorMiasma.States.Dangerous) { 14 | aiState.canHurt = true; 15 | } 16 | 17 | ExtEntity.SetAiState(ent, aiState); 18 | return aiState; 19 | } 20 | } 21 | } 22 | -------------------------------------------------------------------------------- /TowerFallAi/TowerfallAiMod/Common/packages.config: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | -------------------------------------------------------------------------------- /TowerFallAi/TowerfallAiMod/Mod/TreasureChest.cs: 
-------------------------------------------------------------------------------- 1 | using Patcher; 2 | using TowerFall; 3 | using TowerfallAi.Api; 4 | using TowerfallAi.Common; 5 | 6 | namespace TowerfallAi.Mod { 7 | [Patch("TowerFall.TreasureChest")] 8 | public static class ModTreasureChest { 9 | public static StateEntity GetState(this TreasureChest ent) { 10 | if (ent.State == TreasureChest.States.WaitingToAppear) return null; 11 | 12 | var aiState = new StateChest { type = Types.Chest }; 13 | ExtEntity.SetAiState(ent, aiState); 14 | aiState.state = ent.State.ToString().FirstLower(); 15 | aiState.chestType = ent.type.ToString().FirstLower(); 16 | 17 | return aiState; 18 | } 19 | } 20 | } 21 | -------------------------------------------------------------------------------- /TowerFallAi/TowerfallAiMod/Mod/Session.cs: -------------------------------------------------------------------------------- 1 | using System; 2 | using Patcher; 3 | using TowerFall; 4 | using TowerfallAi.Common; 5 | using TowerfallAi.Core; 6 | 7 | namespace TowerfallAi.Mod { 8 | [Patch] 9 | public class ModSession : Session { 10 | Action originalOnLevelLoadFinish; 11 | 12 | public ModSession(MatchSettings matchSettings) : base(matchSettings) { 13 | originalOnLevelLoadFinish = Util.GetAction("$original_OnLevelLoadFinish", typeof(Session), this); 14 | } 15 | 16 | public override void OnLevelLoadFinish() { 17 | originalOnLevelLoadFinish(); 18 | 19 | if (AiMod.Enabled) { 20 | Agents.NotifyLevelLoad(this.CurrentLevel); 21 | } 22 | } 23 | } 24 | } 25 | -------------------------------------------------------------------------------- /TowerFallAi/TowerfallAiMod/Mod/SuperBombArrow.cs: -------------------------------------------------------------------------------- 1 | using System; 2 | using Patcher; 3 | using TowerFall; 4 | using TowerfallAi.Api; 5 | using TowerfallAi.Common; 6 | 7 | namespace TowerfallAi.Mod { 8 | [Patch("TowerFall.SuperBombArrow")] 9 | public static class ModSuperBombArrow { 10 | 
public static StateEntity GetState(this SuperBombArrow ent) { 11 | var aiState = new StateArrow { type = Types.Arrow }; 12 | 13 | ExtEntity.SetAiState(ent, aiState); 14 | aiState.state = ent.State.ToString().FirstLower(); 15 | aiState.arrowType = ent.ArrowType.ToString().FirstLower(); 16 | aiState.timeLeft = (int)Math.Ceiling(ent.explodeAlarm.FramesLeft); 17 | 18 | return aiState; 19 | } 20 | } 21 | } 22 | -------------------------------------------------------------------------------- /TowerFallAi/TowerfallAiMod/Mod/QuestSpawnPortal.cs: -------------------------------------------------------------------------------- 1 | using System.Reflection; 2 | using Patcher; 3 | using TowerFall; 4 | using TowerfallAi.Api; 5 | using TowerfallAi.Common; 6 | 7 | namespace TowerfallAi.Mod { 8 | [Patch("TowerFall.QuestSpawnPortal")] 9 | public static class ModQuestSpawnPortal { 10 | public static StateEntity GetState(this QuestSpawnPortal ent) { 11 | if (!(bool)Util.GetFieldValue("appeared", typeof(QuestSpawnPortal), ent, BindingFlags.Instance | BindingFlags.NonPublic)) { 12 | return null; 13 | } 14 | 15 | var aiState = new StateEntity { 16 | type = "portal", 17 | }; 18 | 19 | ExtEntity.SetAiState(ent, aiState); 20 | return aiState; 21 | } 22 | } 23 | } 24 | -------------------------------------------------------------------------------- /TowerFallAi/TowerfallAiMod/app.config: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | -------------------------------------------------------------------------------- /TowerFallAi/TowerfallAiModTests/app.config: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | -------------------------------------------------------------------------------- /TowerFallAi/TowerfallAiModTests/Properties/AssemblyInfo.cs: 
-------------------------------------------------------------------------------- 1 | using System.Reflection; 2 | using System.Runtime.CompilerServices; 3 | using System.Runtime.InteropServices; 4 | 5 | [assembly: AssemblyTitle("TowerfallAiModTests")] 6 | [assembly: AssemblyDescription("")] 7 | [assembly: AssemblyConfiguration("")] 8 | [assembly: AssemblyCompany("")] 9 | [assembly: AssemblyProduct("TowerfallAiModTests")] 10 | [assembly: AssemblyCopyright("Copyright © 2023")] 11 | [assembly: AssemblyTrademark("")] 12 | [assembly: AssemblyCulture("")] 13 | 14 | [assembly: ComVisible(false)] 15 | 16 | [assembly: Guid("19cfa5fc-a793-46ef-9f34-88f6c9d8c19f")] 17 | 18 | // [assembly: AssemblyVersion("1.0.*")] 19 | [assembly: AssemblyVersion("1.0.0.0")] 20 | [assembly: AssemblyFileVersion("1.0.0.0")] 21 | -------------------------------------------------------------------------------- /TowerFallAi/TowerfallAiMod/Mod/ReaperBeam.cs: -------------------------------------------------------------------------------- 1 | using Microsoft.Xna.Framework; 2 | using Patcher; 3 | using TowerFall; 4 | using TowerfallAi.Api; 5 | using TowerfallAi.Common; 6 | 7 | namespace TowerfallAi.Mod { 8 | [Patch("TowerFall.KingReaper/ReaperBeam")] 9 | public static class ModReaperBeam { 10 | public static StateEntity GetState(this KingReaper.ReaperBeam ent) { 11 | var aiState = new StateReaperBeam { type = "kingReaperBeam" }; 12 | ExtEntity.SetAiState(ent, aiState); 13 | aiState.canHurt = ent.Collidable; 14 | Vector2 normal = (Vector2)Util.GetPrivateFieldValue("normal", ent); 15 | aiState.dir = new Vec2 { 16 | x = normal.X, 17 | y = normal.Y 18 | }; 19 | aiState.width = 8; 20 | return aiState; 21 | } 22 | } 23 | } 24 | -------------------------------------------------------------------------------- /TowerFallAi/TowerfallAiMod/Common/Logger.cs: -------------------------------------------------------------------------------- 1 | using System; 2 | using System.IO; 3 | 4 | namespace TowerfallAi.Common 
{ 5 | public static class Logger { 6 | static CustomLogger logger; 7 | 8 | public static void Init(string folder) { 9 | string path = Path.Combine(folder, "logs", $"{DateTime.UtcNow.Ticks / 1000000}.log"); 10 | logger = new CustomLogger(path); 11 | } 12 | 13 | public static void WriteLine(string message) { 14 | logger.WriteLine(message); 15 | } 16 | 17 | public static void Log(string message, string level) { 18 | logger.Log(message, level); 19 | } 20 | 21 | public static void Error(string message) { 22 | logger.Error(message); 23 | } 24 | 25 | public static void Info(string message) { 26 | logger.Info(message); 27 | } 28 | } 29 | } 30 | -------------------------------------------------------------------------------- /TowerFallAi/TowerfallAiMod/Data/MatchConfig.cs: -------------------------------------------------------------------------------- 1 | using System; 2 | using System.Collections.Generic; 3 | using System.Runtime.Serialization; 4 | using TowerfallAi.Api; 5 | 6 | namespace TowerfallAi.Data { 7 | [DataContract] 8 | public class MatchConfig { 9 | 10 | [DataMember(EmitDefaultValue = false)] 11 | public int level; 12 | 13 | [DataMember(EmitDefaultValue = false)] 14 | public int skipWaves; 15 | 16 | [DataMember(EmitDefaultValue = false)] 17 | public List agents; 18 | 19 | [DataMember(EmitDefaultValue = false)] 20 | public string mode; 21 | 22 | [DataMember(EmitDefaultValue = false)] 23 | public TimeSpan agentTimeout; 24 | 25 | [DataMember(EmitDefaultValue = false)] 26 | public int fps; 27 | 28 | [DataMember(EmitDefaultValue = false)] 29 | public int[,] solids; 30 | } 31 | } 32 | -------------------------------------------------------------------------------- /TowerFallAi/TowerfallAiMod/Mod/ShiftBlock.cs: -------------------------------------------------------------------------------- 1 | using Patcher; 2 | using TowerFall; 3 | using TowerfallAi.Api; 4 | using TowerfallAi.Common; 5 | 6 | namespace TowerfallAi.Mod { 7 | [Patch("TowerFall.ShiftBlock")] 8 | public 
static class ModShiftBlock { 9 | public static StateEntity GetState(this ShiftBlock ent) { 10 | var aiState = new StateShiftBlock { type = Types.ShiftBlock }; 11 | ExtEntity.SetAiState(ent, aiState); 12 | aiState.startPosition = new Vec2 { 13 | x = ent.startPosition.X, 14 | y = ent.startPosition.Y 15 | }; 16 | 17 | aiState.endPosition = new Vec2 { 18 | x = ent.node.X, 19 | y = ent.node.Y 20 | }; 21 | 22 | aiState.state = ((ShiftBlock.States)Util.GetPrivateFieldValue("state", ent)).ToString().FirstLower(); 23 | return aiState; 24 | } 25 | } 26 | } 27 | -------------------------------------------------------------------------------- /TowerFallAi/Patcher/CecilExtensions.cs: -------------------------------------------------------------------------------- 1 | using System.Collections.Generic; 2 | using System.Linq; 3 | using Mono.Cecil; 4 | 5 | namespace Patcher { 6 | static class CecilExtensions { 7 | public static IEnumerable AllNestedTypes(this TypeDefinition type) { 8 | yield return type; 9 | foreach (TypeDefinition nested in type.NestedTypes) { 10 | foreach (TypeDefinition moreNested in AllNestedTypes(nested)) { 11 | yield return moreNested; 12 | } 13 | } 14 | } 15 | 16 | public static IEnumerable AllNestedTypes(this ModuleDefinition module) { 17 | return module.Types.SelectMany(AllNestedTypes); 18 | } 19 | 20 | public static string Signature(this MethodReference method) { 21 | return string.Format("{0}({1})", method.Name, string.Join(", ", method.Parameters.Select(p => p.ParameterType))); 22 | } 23 | } 24 | } 25 | -------------------------------------------------------------------------------- /python/common/test.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from typing import List 3 | 4 | 5 | def assert_eq(expected, actual): 6 | assert expected == actual, f'{expected} != {actual}' 7 | 8 | 9 | def assert_close(expected, actual, eps): 10 | assert abs(expected - actual) < eps, f'{expected} and {actual} 
not within {eps}' 11 | 12 | 13 | def assert_initial_state(expected_entities: List[dict], actual_entities: List[dict]): 14 | actual_entities.sort(key=lambda e: e['pos']['x']) 15 | for i, (expected, actual) in enumerate(zip(expected_entities, actual_entities)): 16 | try: 17 | assert_eq(expected['type'], actual['type']) 18 | if 'subType' in expected: 19 | assert_eq(expected['subType'], actual['subType']) 20 | assert_close(expected['pos']['x'], actual['pos']['x'], 2) 21 | assert_close(expected['pos']['y'], actual['pos']['y'], 2) 22 | except Exception as ex: 23 | logging.error(f'Error in entity {expected["type"]}') 24 | raise ex -------------------------------------------------------------------------------- /TowerFallAi/TowerfallAiMod/Core/AgentConnectionRemote.cs: -------------------------------------------------------------------------------- 1 | using System; 2 | using System.Threading; 3 | using System.Threading.Tasks; 4 | using Newtonsoft.Json; 5 | using TowerfallAi.Data; 6 | 7 | namespace TowerfallAi.Core { 8 | public class AgentConnectionRemote : AgentConnection { 9 | public RemoteConnection Connection { get; set; } 10 | 11 | public AgentConnectionRemote(int index) : base(index) { } 12 | 13 | public override void Send(string message, int frame) { 14 | Connection.Write(message); 15 | } 16 | 17 | public override async Task ReceiveAsync(TimeSpan timeout = default, CancellationToken cancellationToken = default) { 18 | return JsonConvert.DeserializeObject(await Connection.ReadAsync(timeout, cancellationToken)); 19 | } 20 | 21 | public override void Dispose() { 22 | if (Connection != null) { 23 | Connection.Dispose(); 24 | Connection = null; 25 | } 26 | } 27 | } 28 | } 29 | -------------------------------------------------------------------------------- /python/gym_wrapper_test.py: -------------------------------------------------------------------------------- 1 | from gym_wrapper import (KillEnemyObjective, PlayerObservation, 2 | TowerfallBlankEnv) 3 | from 
towerfall import Towerfall 4 | 5 | ''' 6 | Test the gym wrapper with one agent performing random actions. 7 | ''' 8 | 9 | def main(): 10 | n_episodes = 1000 11 | env = create_env() 12 | env.reset() 13 | for _ in range(n_episodes): 14 | _, _, done, _ = env.step(env.action_space.sample()) 15 | if done: 16 | env.reset() 17 | 18 | def create_env() -> TowerfallBlankEnv: 19 | towerfall = Towerfall( 20 | verbose=1, 21 | config=dict( 22 | mode='sandbox', 23 | level='2', 24 | fps=999, 25 | agents=[dict(type='remote')])) 26 | env = TowerfallBlankEnv( 27 | towerfall=towerfall, 28 | observations= [PlayerObservation()], 29 | objective=KillEnemyObjective( 30 | enemy_count=3, 31 | episode_max_len=60*10), 32 | verbose=1) 33 | return env 34 | 35 | 36 | if __name__ == '__main__': 37 | main() -------------------------------------------------------------------------------- /python/gym_wrapper/observation.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | from typing import Any, Dict, List, Mapping, Optional 3 | 4 | from common.entity import Entity 5 | from gym import Space 6 | 7 | 8 | class Observation(ABC): 9 | ''' 10 | Base class for observations. 11 | ''' 12 | @abstractmethod 13 | def extend_obs_space(self, obs_space_dict: Mapping[str, Space]): 14 | '''Adds the new definitions to observations to obs_space.''' 15 | raise NotImplementedError() 16 | 17 | @abstractmethod 18 | def post_reset(self, state_scenario: Mapping[str, Any], player: Optional[Entity], entities: List[Entity], obs_dict: Dict[str, Any]): 19 | '''Hook for a gym reset call. Adds observations to obs_dict.''' 20 | raise NotImplementedError 21 | 22 | @abstractmethod 23 | def post_step(self, player: Optional[Entity], entities: List[Entity], actions: str, obs_dict: Mapping[str, Any]): 24 | '''Hook for a gym step call. 
Adds observations to obs_dict.''' 25 | raise NotImplementedError 26 | -------------------------------------------------------------------------------- /TowerFallAi/Patcher/Properties/AssemblyInfo.cs: -------------------------------------------------------------------------------- 1 | using System.Reflection; 2 | 3 | // Information about this assembly is defined by the following attributes. 4 | // Change them to the values specific to your project. 5 | 6 | [assembly: AssemblyTitle("Patcher")] 7 | [assembly: AssemblyDescription("")] 8 | [assembly: AssemblyConfiguration("")] 9 | [assembly: AssemblyCompany("")] 10 | [assembly: AssemblyProduct("")] 11 | [assembly: AssemblyCopyright("Nobody")] 12 | [assembly: AssemblyTrademark("")] 13 | [assembly: AssemblyCulture("")] 14 | 15 | // The assembly version has the format "{Major}.{Minor}.{Build}.{Revision}". 16 | // The form "{Major}.{Minor}.*" will automatically update the build and revision, 17 | // and "{Major}.{Minor}.{Build}.*" will update just the revision. 18 | 19 | [assembly: AssemblyVersion("0.1.1.0")] 20 | 21 | // The following attributes are used to specify the signing key for the assembly, 22 | // if desired. See the Mono documentation for more information about signing. 
23 | 24 | //[assembly: AssemblyDelaySign(false)] 25 | //[assembly: AssemblyKeyFile("")] 26 | 27 | -------------------------------------------------------------------------------- /TowerFallAi/TowerfallAiMod/Mod/Birdman.cs: -------------------------------------------------------------------------------- 1 | using System.Reflection; 2 | using Monocle; 3 | using Patcher; 4 | using TowerFall; 5 | using TowerfallAi.Api; 6 | using TowerfallAi.Common; 7 | 8 | namespace TowerfallAi.Mod { 9 | [Patch("TowerFall.Birdman")] 10 | public static class ModBirdman { 11 | public static StateEntity GetState(this Birdman ent) { 12 | var aiState = new StateEntity { type = "birdman" }; 13 | 14 | if ((Counter)Util.GetFieldValue("attackCooldown", typeof(Birdman), ent, BindingFlags.NonPublic | BindingFlags.Instance) 15 | && !(bool)Util.GetFieldValue("canArrowAttack", typeof(Birdman), ent, BindingFlags.NonPublic | BindingFlags.Instance)) { 16 | aiState.state = "resting"; 17 | } else { 18 | switch (ent.State) { 19 | case Birdman.ST_IDLE: 20 | aiState.state = "idle"; 21 | break; 22 | case Birdman.ST_ATTACK: 23 | aiState.state = "attack"; 24 | break; 25 | } 26 | } 27 | 28 | ExtEntity.SetAiState(ent, aiState); 29 | return aiState; 30 | } 31 | } 32 | } 33 | -------------------------------------------------------------------------------- /TowerFallAi/TowerfallAiMod/Properties/AssemblyInfo.cs: -------------------------------------------------------------------------------- 1 | using System.Reflection; 2 | using System.Runtime.CompilerServices; 3 | 4 | // Information about this assembly is defined by the following attributes. 5 | // Change them to the values specific to your project. 
6 | 7 | [assembly: AssemblyTitle("Mod")] 8 | [assembly: AssemblyDescription("")] 9 | [assembly: AssemblyConfiguration("")] 10 | [assembly: AssemblyCompany("")] 11 | [assembly: AssemblyProduct("")] 12 | [assembly: AssemblyCopyright("Nobody")] 13 | [assembly: AssemblyTrademark("")] 14 | [assembly: AssemblyCulture("")] 15 | 16 | // The assembly version has the format "{Major}.{Minor}.{Build}.{Revision}". 17 | // The form "{Major}.{Minor}.*" will automatically update the build and revision, 18 | // and "{Major}.{Minor}.{Build}.*" will update just the revision. 19 | 20 | [assembly: AssemblyVersion("1.0.*")] 21 | 22 | // The following attributes are used to specify the signing key for the assembly, 23 | // if desired. See the Mono documentation for more information about signing. 24 | 25 | //[assembly: AssemblyDelaySign(false)] 26 | //[assembly: AssemblyKeyFile("")] 27 | 28 | -------------------------------------------------------------------------------- /TowerFallAi/PatcherLib/Properties/AssemblyInfo.cs: -------------------------------------------------------------------------------- 1 | using System.Reflection; 2 | using System.Runtime.CompilerServices; 3 | 4 | // Information about this assembly is defined by the following attributes. 5 | // Change them to the values specific to your project. 6 | 7 | [assembly: AssemblyTitle ("PatcherLib")] 8 | [assembly: AssemblyDescription ("")] 9 | [assembly: AssemblyConfiguration ("")] 10 | [assembly: AssemblyCompany ("")] 11 | [assembly: AssemblyProduct ("")] 12 | [assembly: AssemblyCopyright ("Nobody")] 13 | [assembly: AssemblyTrademark ("")] 14 | [assembly: AssemblyCulture ("")] 15 | 16 | // The assembly version has the format "{Major}.{Minor}.{Build}.{Revision}". 17 | // The form "{Major}.{Minor}.*" will automatically update the build and revision, 18 | // and "{Major}.{Minor}.{Build}.*" will update just the revision. 
19 | 20 | [assembly: AssemblyVersion ("1.0.*")] 21 | 22 | // The following attributes are used to specify the signing key for the assembly, 23 | // if desired. See the Mono documentation for more information about signing. 24 | 25 | //[assembly: AssemblyDelaySign(false)] 26 | //[assembly: AssemblyKeyFile("")] 27 | 28 | -------------------------------------------------------------------------------- /python/run_simple_agent_quest.py: -------------------------------------------------------------------------------- 1 | from agents import SimpleAgent 2 | from common.logging_options import default_logging 3 | from towerfall import Towerfall 4 | 5 | default_logging() 6 | 7 | ''' 8 | Two agents playing a normal quest game. 9 | ''' 10 | 11 | def main(): 12 | # Creates or reuse a Towerfall game. 13 | towerfall = Towerfall( 14 | verbose = 1, 15 | config = dict( 16 | mode='quest', 17 | level='2', 18 | fps=60, 19 | agentTimeout='00:00:02', 20 | agents=[ 21 | dict(type='remote', archer='blue'), 22 | dict(type='remote', archer='green')], 23 | ) 24 | ) 25 | 26 | connections = [] 27 | agents = [] 28 | remote_agents = sum(1 for agent in towerfall.config['agents'] if agent['type'] != 'human') 29 | for i in range(remote_agents): 30 | connections.append(towerfall.join(timeout=10, verbose=1)) 31 | agents.append(SimpleAgent(connections[i])) 32 | 33 | while True: 34 | # Read the state of the game then replies with an action. 
35 | for connection, agent in zip(connections, agents): 36 | game_state = connection.read_json() 37 | agent.act(game_state) 38 | 39 | 40 | if __name__ == '__main__': 41 | main() -------------------------------------------------------------------------------- /python/run_simple_agent_versus.py: -------------------------------------------------------------------------------- 1 | from agents import SimpleAgent 2 | from common.logging_options import default_logging 3 | from towerfall import Towerfall 4 | 5 | default_logging() 6 | 7 | ''' 8 | Two agents playing a 1 versus 1 match. 9 | ''' 10 | 11 | def main(): 12 | # Creates or reuse a Towerfall game. 13 | towerfall = Towerfall( 14 | verbose = 1, 15 | config = dict( 16 | mode='versus', 17 | level='3', 18 | fps=60, 19 | agentTimeout='00:00:02', 20 | agents=[ 21 | dict(type='remote', archer='green-alt', team='blue'), 22 | dict(type='remote', archer='blue-alt', team='red')], 23 | ) 24 | ) 25 | 26 | connections = [] 27 | agents = [] 28 | remote_agents = sum(1 for agent in towerfall.config['agents'] if agent['type'] != 'human') 29 | for i in range(remote_agents): 30 | connections.append(towerfall.join(timeout=10, verbose=1)) 31 | agents.append(SimpleAgent(connections[i])) 32 | 33 | while True: 34 | # Read the state of the game then replies with an action. 
35 | for connection, agent in zip(connections, agents): 36 | game_state = connection.read_json() 37 | agent.act(game_state) 38 | 39 | 40 | if __name__ == '__main__': 41 | main() -------------------------------------------------------------------------------- /TowerFallAi/TowerfallAiMod/Mod/Level.cs: -------------------------------------------------------------------------------- 1 | using System; 2 | using System.Xml; 3 | using Microsoft.Xna.Framework; 4 | using Monocle; 5 | using Patcher; 6 | using TowerFall; 7 | using TowerfallAi.Core; 8 | 9 | namespace TowerfallAi.Mod { 10 | [Patch] 11 | public class ModLevel : Level { 12 | Action originalUpdate; 13 | 14 | Action originalHandlePausing; 15 | 16 | public ModLevel(Session session, XmlElement xml) : base(session, xml) { 17 | var ptr = typeof(Level).GetMethod("$original_Update").MethodHandle.GetFunctionPointer(); 18 | originalUpdate = (Action)Activator.CreateInstance(typeof(Action), this, ptr); 19 | 20 | ptr = typeof(Level).GetMethod("$original_HandlePausing").MethodHandle.GetFunctionPointer(); 21 | originalHandlePausing = (Action)Activator.CreateInstance(typeof(Action), this, ptr); 22 | } 23 | 24 | public override void HandlePausing() { 25 | // Avoid pausing when no human is playing and the screen goes out of focus. 26 | if (AiMod.Enabled && !AiMod.IsHumanPlaying()) { 27 | return; 28 | } 29 | 30 | originalHandlePausing(); 31 | } 32 | 33 | public override void Update() { 34 | if (AiMod.Enabled) { 35 | Agents.RefreshInputFromAgents(this); 36 | } 37 | 38 | originalUpdate(); 39 | } 40 | } 41 | } 42 | -------------------------------------------------------------------------------- /python/run_simple_agent_versus_2v2.py: -------------------------------------------------------------------------------- 1 | from agents import SimpleAgent 2 | from common.logging_options import default_logging 3 | from towerfall import Towerfall 4 | 5 | default_logging() 6 | 7 | ''' 8 | Four agents playing a 2 versus 2 match. 
9 | ''' 10 | 11 | def main(): 12 | # Creates or reuse a Towerfall game. 13 | towerfall = Towerfall( 14 | verbose = 1, 15 | config = dict( 16 | mode='versus', 17 | level='3', 18 | fps=60, 19 | agentTimeout='00:00:02', 20 | agents=[ 21 | dict(type='remote', archer='white', team='blue'), 22 | dict(type='remote', archer='green', team='blue'), 23 | dict(type='remote', archer='blue', team='red'), 24 | dict(type='remote', archer='purple', team='red')], 25 | ) 26 | ) 27 | 28 | connections = [] 29 | agents = [] 30 | remote_agents = sum(1 for agent in towerfall.config['agents'] if agent['type'] != 'human') 31 | for i in range(remote_agents): 32 | connections.append(towerfall.join(timeout=10, verbose=1)) 33 | agents.append(SimpleAgent(connections[i])) 34 | 35 | while True: 36 | # Read the state of the game then replies with an action. 37 | for i, (connection, agent) in enumerate(zip(connections, agents)): 38 | game_state = connection.read_json() 39 | agent.act(game_state) 40 | 41 | 42 | if __name__ == '__main__': 43 | main() -------------------------------------------------------------------------------- /TowerFallAi/TowerfallAiMod/Api/Vec2.cs: -------------------------------------------------------------------------------- 1 | using System; 2 | using TowerfallAi.Common; 3 | 4 | namespace TowerfallAi.Api { 5 | public class Vec2 { 6 | public float x; 7 | public float y; 8 | 9 | public float Dot(Vec2 v) { 10 | return x * v.x + y * v.y; 11 | } 12 | 13 | public Vec2 Add(Vec2 v2) { 14 | return new Vec2 { 15 | x = x + v2.x, 16 | y = y + v2.y 17 | }; 18 | } 19 | 20 | public Vec2 Sub(Vec2 v2) { 21 | return new Vec2 { 22 | x = x - v2.x, 23 | y = y - v2.y 24 | }; 25 | } 26 | 27 | public Vec2 Copy() { 28 | return new Vec2 { 29 | x = x, 30 | y = y 31 | }; 32 | } 33 | 34 | public Vec2 Mul(float a) { 35 | return new Vec2 { 36 | x = x * a, 37 | y = y * a 38 | }; 39 | } 40 | 41 | public Vec2 Div(float a) { 42 | return new Vec2 { 43 | x = x / a, 44 | y = y / a 45 | }; 46 | } 47 | 48 | public 
void Normalize() { 49 | float length = Length(); 50 | if (length == 0) return; 51 | 52 | x /= length; 53 | y /= length; 54 | } 55 | 56 | public float Length() { 57 | return (float)Math.Sqrt(x * x + y * y); 58 | } 59 | 60 | public override string ToString() { 61 | return "({0}, {1})".Format(x, y); 62 | } 63 | } 64 | } 65 | -------------------------------------------------------------------------------- /python/run_simple_agent_versus_human.py: -------------------------------------------------------------------------------- 1 | from agents import SimpleAgent 2 | from common.logging_options import default_logging 3 | from towerfall import Towerfall 4 | 5 | default_logging() 6 | 7 | ''' 8 | Three agents playing a 2 versus 2 match alongside a human. 9 | BUG: only the first slot can be human. 10 | ''' 11 | 12 | def main(): 13 | # Creates or reuse a Towerfall game. 14 | towerfall = Towerfall( 15 | verbose = 1, 16 | config = dict( 17 | mode='versus', 18 | level='3', 19 | fps=60, 20 | agentTimeout='00:00:02', 21 | agents=[ 22 | dict(type='human', archer='yellow', team='blue'), 23 | dict(type='remote', archer='purple', team='blue'), 24 | dict(type='remote', archer='orange', team='red'), 25 | dict(type='remote', archer='red', team='red')], 26 | ) 27 | ) 28 | 29 | connections = [] 30 | agents = [] 31 | remote_agents = sum(1 for agent in towerfall.config['agents'] if agent['type'] != 'human') 32 | for i in range(remote_agents): 33 | connections.append(towerfall.join(timeout=20, verbose=1)) 34 | agents.append(SimpleAgent(connections[i])) 35 | 36 | while True: 37 | # Read the state of the game then replies with an action. 
38 | for connection, agent in zip(connections, agents): 39 | game_state = connection.read_json() 40 | agent.act(game_state) 41 | 42 | 43 | if __name__ == '__main__': 44 | main() -------------------------------------------------------------------------------- /TowerFallAi/PatcherLib/PatcherLib.csproj: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | Debug 5 | x86 6 | PatcherLib 7 | PatcherLib 8 | {FAE08B42-BBE6-4177-9110-6DDE1032C415} 9 | 10 | 11 | Library 12 | 13 | 14 | true 15 | bin\x86\PatchWindows\ 16 | DEBUG; 17 | full 18 | x86 19 | 7.3 20 | prompt 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 | 32 | 33 | -------------------------------------------------------------------------------- /TowerFallAi/TowerfallAiMod/Common/KeySemaphore.cs: -------------------------------------------------------------------------------- 1 | using System; 2 | using System.Collections.Generic; 3 | using System.Linq; 4 | using System.Text; 5 | using System.Threading; 6 | using System.Threading.Tasks; 7 | 8 | namespace Common { 9 | public class KeySemaphore { 10 | object locker = new object(); 11 | Dictionary semaphores = new Dictionary(); 12 | 13 | public void Run(string key, Action action) { 14 | Semaphore semaphore = null; 15 | 16 | Exception exception = null; 17 | 18 | try { 19 | lock (locker) { 20 | if (!semaphores.TryGetValue(key, out semaphore)) { 21 | semaphore = new Semaphore(1, 1); 22 | semaphores[key] = semaphore; 23 | } 24 | } 25 | 26 | semaphore.WaitOne(); 27 | action(); 28 | } catch (Exception ex) { 29 | exception = ex; 30 | } 31 | 32 | lock (locker) { 33 | if (semaphore != null) { 34 | int c = semaphore.Release(1); 35 | if (c == 0 && semaphores.ContainsKey(key)) { 36 | semaphores.Remove(key); 37 | } 38 | } 39 | } 40 | 41 | if (exception != null) { 42 | throw exception; 43 | } 44 | } 45 | } 46 | } 47 | -------------------------------------------------------------------------------- 
/TowerFallAi/TowerfallAiMod/Mod/Player.cs: -------------------------------------------------------------------------------- 1 | using System.Collections.Generic; 2 | using Microsoft.Xna.Framework; 3 | using Monocle; 4 | using Patcher; 5 | using TowerFall; 6 | using TowerfallAi.Api; 7 | using TowerfallAi.Common; 8 | using TowerfallAi.Core; 9 | 10 | namespace TowerfallAi.Mod { 11 | [Patch("TowerFall.Player")] 12 | public static class ModPlayer{ 13 | public static StateEntity GetState(this Player ent) { 14 | var aiState = new StateArcher() { type = "archer" }; 15 | 16 | ExtEntity.SetAiState(ent, aiState); 17 | aiState.playerIndex = ent.PlayerIndex; 18 | aiState.shield = ent.HasShield; 19 | aiState.wing = ent.HasWings; 20 | aiState.state = ent.State.ToString().FirstLower(); 21 | aiState.arrows = new List(); 22 | List arrows = ent.Arrows.Arrows; 23 | for (int i = 0; i < arrows.Count; i++) { 24 | aiState.arrows.Add(arrows[i].ToString().FirstLower()); 25 | } 26 | aiState.canHurt = ent.CanHurt; 27 | aiState.dead = ent.Dead; 28 | aiState.facing = (int)ent.Facing; 29 | aiState.onGround = ent.OnGround; 30 | aiState.onWall = ent.CanWallJump(Facing.Left) || ent.CanWallJump(Facing.Right); 31 | Vector2 aim = Calc.AngleToVector(ent.AimDirection, 1); 32 | aiState.aimDirection = new Vec2 { 33 | x = aim.X, 34 | y = -aim.Y 35 | }; 36 | aiState.team = AgentConfigExtension.GetTeam(ent.TeamColor); 37 | 38 | return aiState; 39 | } 40 | } 41 | } 42 | -------------------------------------------------------------------------------- /TowerFallAi/TowerfallAiMod/Common/Properties/AssemblyInfo.cs: -------------------------------------------------------------------------------- 1 | using System.Reflection; 2 | using System.Runtime.CompilerServices; 3 | using System.Runtime.InteropServices; 4 | 5 | // General Information about an assembly is controlled through the following 6 | // set of attributes. Change these attribute values to modify the information 7 | // associated with an assembly. 
8 | [assembly: AssemblyTitle("Common")] 9 | [assembly: AssemblyDescription("")] 10 | [assembly: AssemblyConfiguration("")] 11 | [assembly: AssemblyCompany("")] 12 | [assembly: AssemblyProduct("Common")] 13 | [assembly: AssemblyCopyright("Copyright © 2016")] 14 | [assembly: AssemblyTrademark("")] 15 | [assembly: AssemblyCulture("")] 16 | 17 | // Setting ComVisible to false makes the types in this assembly not visible 18 | // to COM components. If you need to access a type in this assembly from 19 | // COM, set the ComVisible attribute to true on that type. 20 | [assembly: ComVisible(false)] 21 | 22 | // The following GUID is for the ID of the typelib if this project is exposed to COM 23 | [assembly: Guid("42382a63-69a3-4f90-bd64-bfc2d1ca4393")] 24 | 25 | // Version information for an assembly consists of the following four values: 26 | // 27 | // Major Version 28 | // Minor Version 29 | // Build Number 30 | // Revision 31 | // 32 | // You can specify all the values or you can default the Build and Revision Numbers 33 | // by using the '*' as shown below: 34 | // [assembly: AssemblyVersion("1.0.*")] 35 | [assembly: AssemblyVersion("1.0.0.0")] 36 | [assembly: AssemblyFileVersion("1.0.0.0")] 37 | -------------------------------------------------------------------------------- /TowerFallAi/TowerfallAiMod/Core/KeyboardConfigs.cs: -------------------------------------------------------------------------------- 1 | using Microsoft.Xna.Framework.Input; 2 | using TowerFall; 3 | 4 | namespace TowerfallAiMod.Core { 5 | public static class KeyboardConfigs { 6 | public static KeyboardConfig[] Configs = new KeyboardConfig[4]; 7 | 8 | static KeyboardConfigs() { 9 | for (int i = 0; i < Configs.Length; i++) { 10 | Configs[i] = KeyboardConfig.GetDefault(); 11 | } 12 | 13 | int j = 1; 14 | Configs[j].Down = new Keys[] { Keys.K }; 15 | Configs[j].Up = new Keys[] { Keys.I }; 16 | Configs[j].Left = new Keys[] { Keys.J }; 17 | Configs[j].Right = new Keys[] { Keys.L }; 18 | Configs[j].Jump 
= new Keys[] { Keys.U }; 19 | Configs[j].Shoot = new Keys[] { Keys.O }; 20 | Configs[j].Dodge = new Keys[] { Keys.NumPad9 }; 21 | 22 | j++; 23 | Configs[j].Down = new Keys[] { Keys.H }; 24 | Configs[j].Up = new Keys[] { Keys.Y }; 25 | Configs[j].Left = new Keys[] { Keys.G }; 26 | Configs[j].Right = new Keys[] { Keys.J }; 27 | Configs[j].Jump = new Keys[] { Keys.Y }; 28 | Configs[j].Shoot = new Keys[] { Keys.I }; 29 | Configs[j].Dodge = new Keys[] { Keys.NumPad6 }; 30 | 31 | j++; 32 | Configs[j].Down = new Keys[] { Keys.W }; 33 | Configs[j].Up = new Keys[] { Keys.S }; 34 | Configs[j].Left = new Keys[] { Keys.A }; 35 | Configs[j].Right = new Keys[] { Keys.F }; 36 | Configs[j].Jump = new Keys[] { Keys.Q }; 37 | Configs[j].Shoot = new Keys[] { Keys.E }; 38 | Configs[j].Dodge = new Keys[] { Keys.NumPad3 }; 39 | } 40 | } 41 | } 42 | -------------------------------------------------------------------------------- /TowerFallAi/TowerfallAiMod/Mod/Skeleton.cs: -------------------------------------------------------------------------------- 1 | using System.Collections.Generic; 2 | using Microsoft.Xna.Framework; 3 | using Monocle; 4 | using Patcher; 5 | using TowerFall; 6 | using TowerfallAi.Api; 7 | using TowerfallAi.Common; 8 | 9 | namespace TowerfallAi.Mod { 10 | [Patch("TowerFall.Skeleton")] 11 | public static class ModSkeleton{ 12 | public static StateEntity GetState(this Skeleton ent) { 13 | var aiState = new StateArcher() { type = "archer" }; 14 | 15 | ExtEntity.SetAiState(ent, aiState); 16 | aiState.playerIndex = -1; 17 | PlayerShield shield = ((PlayerShield)Util.GetPrivateFieldValue("shield", ent)); 18 | aiState.shield = shield != null && shield.Visible; 19 | PlayerWings wings = ((PlayerWings)Util.GetPrivateFieldValue("wings", ent)); 20 | aiState.wing = wings != null && wings.Visible; 21 | aiState.arrows = new List(); 22 | List arrows = ent.Arrows.Arrows; 23 | for (int i = 0; i < arrows.Count; i++) { 24 | aiState.arrows.Add(arrows[i].ToString().FirstLower()); 25 | } 
26 | aiState.canHurt = ent.CanHurt; 27 | aiState.dead = ent.IsDead(); 28 | aiState.facing = (int)ent.Facing; 29 | aiState.onGround = (bool)Util.GetPrivateFieldValue("onGround", ent); 30 | aiState.onWall = false; 31 | 32 | Vector2 aim = Calc.AngleToVector(ent.AimingAngle, 1); 33 | aiState.aimDirection = new Vec2 { 34 | x = aim.X, 35 | y = -aim.Y 36 | }; 37 | aiState.team = "neutral"; 38 | aiState.isEnemy = true; 39 | 40 | return aiState; 41 | } 42 | } 43 | } 44 | -------------------------------------------------------------------------------- /TowerFallAi/TowerfallAiMod/Mod/MenuInput.cs: -------------------------------------------------------------------------------- 1 | using System; 2 | using Monocle; 3 | using Patcher; 4 | using TowerFall; 5 | using TowerfallAi.Core; 6 | 7 | namespace TowerfallAi.Mod { 8 | [Patch("TowerFall.MenuInput")] 9 | public static class ModMenuInput { 10 | [Patch(".cctor")] 11 | static ModMenuInput() { 12 | if (!AiMod.Enabled) { 13 | // Avoid overriding menu inputs created by the mod. 14 | MenuInput.MenuInputs = new PlayerInput[0]; 15 | } 16 | MenuInput.repeatLeftCounter = new Counter(); 17 | MenuInput.repeatRightCounter = new Counter(); 18 | MenuInput.repeatUpCounter = new Counter(); 19 | MenuInput.repeatDownCounter = new Counter(); 20 | } 21 | 22 | public static bool Confirm { 23 | get { 24 | // Makes the bot automatically confirm all menus. 25 | if (AiMod.Enabled && !AiMod.IsHumanPlaying()) return true; 26 | var ptr = typeof(MenuInput).GetMethod("$original_get_Confirm").MethodHandle.GetFunctionPointer(); 27 | var orginalGetConfirm = (Func)Activator.CreateInstance(typeof(Func), null, ptr); 28 | return orginalGetConfirm(); 29 | } 30 | } 31 | 32 | public static bool ConfirmOrStart { 33 | get { 34 | // Makes the bot automatically confirm all menus. 
35 | if (AiMod.Enabled && !AiMod.IsHumanPlaying()) return true; 36 | var ptr = typeof(MenuInput).GetMethod("$original_get_ConfirmOrStart").MethodHandle.GetFunctionPointer(); 37 | var orginalGetConfirm = (Func)Activator.CreateInstance(typeof(Func), null, ptr); 38 | return orginalGetConfirm(); 39 | } 40 | } 41 | } 42 | } 43 | -------------------------------------------------------------------------------- /TowerFallAi/Patcher/Program.cs: -------------------------------------------------------------------------------- 1 | using System; 2 | using System.Reflection; 3 | using CommandLine; 4 | 5 | namespace Patcher { 6 | public class Program { 7 | [Verb("makebaseimage", HelpText = "Creates a base image to be referenced by the mod.")] 8 | public class MakeBaseImageOptions { 9 | [Option('t', "target", Required = true, HelpText = "The file to be patched.")] 10 | public string Target { get; set; } 11 | } 12 | 13 | [Verb("patch", HelpText = "Patches the binary with the mod changes.")] 14 | public class PatchOptions { 15 | [Option('t', "target", Required = false, HelpText = "The file to be patched.")] 16 | public string Target { get; set; } 17 | } 18 | 19 | public static int Main(string[] args) { 20 | Patcher patcher = new Patcher(); 21 | if (args.Length == 1 && args[0] == "--version") { 22 | Assembly assembly = Assembly.GetExecutingAssembly(); 23 | Version version = assembly.GetName().Version; 24 | Console.WriteLine(version); 25 | } 26 | 27 | if (args.Length == 0) { 28 | patcher.MakeBaseImage(); 29 | patcher.Patch(); 30 | return 0; 31 | } 32 | 33 | Parser.Default.ParseArguments(args) 34 | .MapResult( 35 | (MakeBaseImageOptions o) => { 36 | patcher.MakeBaseImage(o.Target); 37 | return 0; 38 | }, 39 | (PatchOptions o) => { 40 | patcher.Patch(o.Target); 41 | return 0; 42 | }, 43 | e => 1); ; 44 | if (args.Length == 0) { 45 | Console.WriteLine("Usage: Patcher.exe makeBaseImage | patch *"); 46 | return -1; 47 | } 48 | 49 | return 0; 50 | } 51 | } 52 | } 53 | 
-------------------------------------------------------------------------------- /python/gym_wrapper/actions.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Dict 2 | 3 | from common.constants import DASH, DOWN, JUMP, LEFT, RIGHT, SHOOT, UP 4 | from gym import spaces 5 | from numpy.typing import NDArray 6 | 7 | 8 | class Actions: 9 | ''' 10 | Does conversion between gym actions and Towerfall API actions. 11 | Allows enabling or disabling certain actions. 12 | ''' 13 | def __init__(self, 14 | can_jump=True, 15 | can_dash=True, 16 | can_shoot=True): 17 | 18 | self.can_jump = can_jump 19 | self.can_dash = can_dash 20 | self.can_shoot = can_shoot 21 | 22 | actions = [3,3] 23 | self.action_map: Dict[str, Any] = {} # maps key to index 24 | if can_jump: 25 | self.action_map[JUMP] = len(actions) 26 | actions.append(2) 27 | if can_dash: 28 | self.action_map[DASH] = len(actions) 29 | actions.append(2) 30 | if can_shoot: 31 | self.action_map[SHOOT] = len(actions) 32 | actions.append(2) 33 | self.action_space = spaces.MultiDiscrete(actions) 34 | 35 | def to_serialized_actions(self, actions: NDArray) -> str: 36 | ''' 37 | Converts a list of actions to a the serialized string of pressed buttons that the game expects. 
38 | ''' 39 | actions_str = '' 40 | if actions[0] == 0: 41 | actions_str += LEFT 42 | elif actions[0] == 2: 43 | actions_str += RIGHT 44 | if actions[1] == 0: 45 | actions_str += DOWN 46 | elif actions[1] == 2: 47 | actions_str += UP 48 | 49 | if self.can_jump and actions[self.action_map[JUMP]]: 50 | actions_str += JUMP 51 | if self.can_dash and actions[self.action_map[DASH]]: 52 | actions_str += DASH 53 | if self.can_shoot and actions[self.action_map[SHOOT]]: 54 | actions_str += SHOOT 55 | return actions_str -------------------------------------------------------------------------------- /python/gym_wrapper/blank_env.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from typing import List, Optional, Tuple 3 | 4 | from gym import spaces 5 | from towerfall import Towerfall 6 | 7 | from .actions import Actions 8 | from .base_env import TowerfallEnv 9 | from .objective import Objective 10 | from .observation import Observation 11 | 12 | 13 | class TowerfallBlankEnv(TowerfallEnv): 14 | ''' 15 | A modular implementation of TowerfallEnv that can be customized with the addition of observations and an objective. 
16 | ''' 17 | def __init__(self, 18 | towerfall: Towerfall, 19 | observations: List[Observation], 20 | objective: Objective, 21 | actions: Optional[Actions]=None, 22 | record_path: Optional[str]=None, 23 | verbose: int = 0): 24 | super().__init__(towerfall, actions, record_path, verbose) 25 | obs_space = {} 26 | self.components = list(observations) 27 | self.components.append(objective) 28 | self.objective = objective 29 | self.objective.env = self 30 | for obs in self.components: 31 | obs.extend_obs_space(obs_space) 32 | self.observation_space = spaces.Dict(obs_space) 33 | logging.info('Action space: %s', str(self.action_space)) 34 | logging.info('Observation space: %s', str(self.observation_space)) 35 | 36 | def _send_reset(self): 37 | reset_entities = self.objective.get_reset_entities() 38 | self.towerfall.send_reset(reset_entities) 39 | 40 | def _post_reset(self) -> dict: 41 | obs_dict = {} 42 | for obs in self.components: 43 | obs.post_reset(self.state_scenario, self.me, self.entities, obs_dict) 44 | return obs_dict 45 | 46 | def _post_step(self) -> Tuple[object, float, bool, object]: 47 | obs_dict = {} 48 | for obs in self.components: 49 | obs.post_step(self.me, self.entities, self.actions_str, obs_dict) 50 | return obs_dict, self.objective.reward, self.objective.done, {} 51 | -------------------------------------------------------------------------------- /python/run_simple_agent_sandbox.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Mapping 2 | 3 | from common.logging_options import default_logging 4 | from towerfall import Towerfall 5 | 6 | from agents import SimpleAgent 7 | 8 | default_logging() 9 | 10 | ''' 11 | One agent playing a sandbox game. This mode allows resetting the 12 | entities and the scenario in an initial state for quick experimentation. 13 | ''' 14 | 15 | def main(): 16 | # Creates or reuse a Towerfall game. 
17 | towerfall = Towerfall( 18 | verbose = 1, 19 | config = dict( 20 | mode='sandbox', 21 | level='2', 22 | fps=60, 23 | agentTimeout='00:00:02', 24 | solids=[[0] * 32]*14 + [[1]*32] + [[0] * 32]*9, 25 | agents=[ 26 | dict(type='remote', archer='white-alt')] 27 | ) 28 | ) 29 | 30 | connection = towerfall.join(timeout=10, verbose=1) 31 | 32 | send_reset(towerfall) 33 | agent = SimpleAgent(connection) 34 | while True: 35 | # Read the state of the game then replies with an action. 36 | game_state = connection.read_json() 37 | check_for_end_condition(game_state, towerfall) 38 | agent.act(game_state) 39 | 40 | 41 | def send_reset(towerfall: Towerfall): 42 | towerfall.send_reset([ 43 | dict(type='archer', pos=dict(x=160, y=110)), 44 | dict(type='slime', facing=-1, pos=dict(x=220, y=110)) 45 | ] 46 | ) 47 | 48 | 49 | def check_for_end_condition(game_state: Mapping[str, Any], towerfall: Towerfall): 50 | # The sandbox mode will not end the game automatically. We need to explicitly request a reset. 
51 | if game_state['type'] != 'update': 52 | return 53 | 54 | is_player_alive = False 55 | is_enemy_alive = False 56 | for state in game_state['entities']: 57 | if state['type'] == 'archer': 58 | is_player_alive = True 59 | if state['isEnemy']: 60 | is_enemy_alive = True 61 | if not is_player_alive or not is_enemy_alive: 62 | send_reset(towerfall) 63 | 64 | 65 | if __name__ == '__main__': 66 | main() -------------------------------------------------------------------------------- /TowerFallAi/TowerfallAiMod/Core/SandboxLevelSystem.cs: -------------------------------------------------------------------------------- 1 | using System.Text; 2 | using System.Xml; 3 | using Monocle; 4 | using TowerFall; 5 | 6 | namespace TowerfallAi.Core { 7 | public class SandboxLevelSystem : LevelSystem { 8 | private int[,] solids; 9 | public QuestLevelData QuestTowerData { 10 | get; 11 | private set; 12 | } 13 | 14 | public SandboxLevelSystem(QuestLevelData tower, int[,] solids) { 15 | QuestTowerData = tower; 16 | Theme = tower.Theme; 17 | ID = tower.ID; 18 | this.solids = solids; 19 | } 20 | 21 | public override XmlElement GetNextRoundLevel(MatchSettings matchSettings, int roundIndex, out int randomSeed) { 22 | randomSeed = this.QuestTowerData.ID.X; 23 | var xml = Calc.LoadXML(this.QuestTowerData.Path)["level"]; 24 | xml["Entities"].RemoveAll(); 25 | 26 | if (solids != null) { 27 | xml["BG"].InnerText = 
"00000000000000000000000000000000\n00000000000000000000000000000000\n00000000000000000000000000000000\n00000000000000000000000000000000\n00000000000000000000000000000000\n00000000000000000000000000000000\n00000000000000000000000000000000\n00000000000000000000000000000000\n00000000000000000000000000000000\n00000000000000000000000000000000\n00000000000000000000000000000000\n00000000000000000000000000000000\n00000000000000000000000000000000\n00000000000000000000000000000000\n00000000000000000000000000000000\n00000000000000000000000000000000\n00000000000000000000000000000000\n00000000000000000000000000000000\n00000000000000000000000000000000\n00000000000000000000000000000000\n00000000000000000000000000000000\n00000000000000000000000000000000\n00000000000000000000000000000000\n00000000000000000000000000000000"; 28 | xml["BGTiles"].InnerText = ""; 29 | xml.RemoveChild(xml["SolidTiles"]); 30 | 31 | StringBuilder sb = new StringBuilder(); 32 | for (int i = 0; i < solids.GetLength(0); i++) { 33 | for (int j = 0; j < solids.GetLength(1); j++) { 34 | sb.Append(solids[i, j]); 35 | } 36 | sb.Append('\n'); 37 | } 38 | xml["Solids"].InnerText = sb.ToString(); 39 | } 40 | 41 | return xml; 42 | } 43 | } 44 | } 45 | -------------------------------------------------------------------------------- /TowerFallAi/TowerfallAiMod/Common/CustomLogger.cs: -------------------------------------------------------------------------------- 1 | using System; 2 | using System.IO; 3 | using System.Text; 4 | using System.Threading; 5 | 6 | namespace TowerfallAi.Common { 7 | public class CustomLogger { 8 | static object l = new object(); 9 | string path; 10 | 11 | int batchSize; 12 | 13 | public CustomLogger(string path, int batchSize = 0) { 14 | this.path = path; 15 | Directory.CreateDirectory(Path.GetDirectoryName(path)); 16 | if (!File.Exists(path)) { 17 | var f = File.Create(path); 18 | f.Close(); 19 | } 20 | 21 | this.batchSize = batchSize; 22 | } 23 | 24 | StringBuilder buffer = new 
StringBuilder(); 25 | 26 | public void Flush() { 27 | lock (l) { 28 | if (buffer.Length > 0) { 29 | string m = buffer.ToString(); 30 | using (var f = File.Open(path, FileMode.Append)) { 31 | using (var w = new StreamWriter(f)) { 32 | w.Write(m); 33 | } 34 | } 35 | 36 | Console.Write(m); 37 | 38 | buffer.Clear(); 39 | } 40 | } 41 | } 42 | 43 | public void Log(string message, string level) { 44 | string m = "{0} {1} {2}".Format(level, Thread.CurrentThread.ManagedThreadId, message); 45 | WriteLine(m); 46 | } 47 | 48 | public void WriteLine(string message) { 49 | lock (l) { 50 | buffer.Append(DateTime.Now.ToString("hh:mm:ss.ffff")); 51 | buffer.Append(" "); 52 | buffer.AppendLine(message); 53 | 54 | if (buffer.Length > batchSize) { 55 | message = buffer.ToString(); 56 | using (var f = File.Open(path, FileMode.Append)) { 57 | using (var w = new StreamWriter(f)) { 58 | w.Write(message); 59 | } 60 | } 61 | 62 | Console.Write(message); 63 | 64 | buffer.Clear(); 65 | } 66 | } 67 | } 68 | 69 | public void Error(string message) { 70 | Log(message, "ERROR"); 71 | } 72 | 73 | public void Info(string message) { 74 | Log(message, "INFO "); 75 | } 76 | } 77 | } 78 | -------------------------------------------------------------------------------- /TowerFallAi/TowerfallAiMod/Data/Message.cs: -------------------------------------------------------------------------------- 1 | using System; 2 | using System.Collections.Generic; 3 | using System.Runtime.Serialization; 4 | using Microsoft.Xna.Framework; 5 | using Newtonsoft.Json.Linq; 6 | using TowerfallAi.Api; 7 | using TowerfallAi.Common; 8 | 9 | namespace TowerfallAi.Data { 10 | [DataContract] 11 | public class Message { 12 | public class Type { 13 | // Game actions sent by the bot. 14 | public const string Actions = "actions"; 15 | 16 | // Reset the game with a new config. 17 | public const string Config = "config"; 18 | 19 | // Reset the current level with defined entities. 
20 | public const string Reset = "reset"; 21 | 22 | public const string Join = "join"; 23 | 24 | public const string Result = "result"; 25 | } 26 | 27 | [DataMember(EmitDefaultValue = true)] 28 | public string type; 29 | 30 | [DataMember(EmitDefaultValue = false)] 31 | public string actions; 32 | 33 | [DataMember(EmitDefaultValue = false)] 34 | public MatchConfig config; 35 | 36 | [DataMember(EmitDefaultValue = false)] 37 | public State state; 38 | 39 | [DataMember(EmitDefaultValue = false)] 40 | public Vec2 pos; 41 | 42 | [DataMember(EmitDefaultValue = false)] 43 | public List draws; 44 | 45 | [DataMember(EmitDefaultValue = false)] 46 | public List entities; 47 | 48 | [DataMember(EmitDefaultValue = true)] 49 | public bool success; 50 | 51 | [DataMember(EmitDefaultValue = false)] 52 | public string message; 53 | 54 | [DataMember(EmitDefaultValue = false)] 55 | public int id; 56 | } 57 | 58 | public class DrawInstruction { 59 | public string type; 60 | public Vec2 start; 61 | public Vec2 end; 62 | public float[] color; 63 | public float thick; 64 | public void Draw() { 65 | if (type == "line") { 66 | float scale = 3; 67 | Monocle.Draw.Line(new Vector2(start.x * scale, (240 - start.y) * scale), new Vector2(end.x * scale, (240 - end.y) * scale), 68 | new Color(color[0], color[1], color[2]), thick); 69 | } else { 70 | throw new Exception("Type {0} not supported.".Format(type)); 71 | } 72 | } 73 | } 74 | } 75 | -------------------------------------------------------------------------------- /TowerFallAi/TowerfallAiMod/Core/Server.cs: -------------------------------------------------------------------------------- 1 | using System; 2 | using System.Diagnostics; 3 | using System.IO; 4 | using System.Net; 5 | using System.Net.Sockets; 6 | using System.Threading.Tasks; 7 | using TowerfallAi.Common; 8 | 9 | namespace TowerfallAi.Core { 10 | public class Server { 11 | Socket listener; 12 | Task serverTask; 13 | 14 | public Action onConnection; 15 | public Action onClose; 16 | 
17 | public Server(Action onConnection, Action onClose) { 18 | this.onConnection = onConnection; 19 | this.onClose = onClose; 20 | } 21 | 22 | public bool IsAlive() { 23 | return serverTask.IsAlive(); 24 | } 25 | 26 | public int Start(string ip = "127.0.0.1", int port = 0) { 27 | if (serverTask != null && serverTask.IsAlive()) { 28 | throw new InvalidOperationException("Server task is still alive"); 29 | } 30 | 31 | SocketPermission permission = new SocketPermission( 32 | NetworkAccess.Accept, 33 | TransportType.Tcp, 34 | "", 35 | SocketPermission.AllPorts); 36 | permission.Demand(); 37 | 38 | Logger.Info("Permissions created."); 39 | 40 | IPEndPoint ipEndPoint = GetIpEndpoint(ip, port); 41 | listener = new Socket( 42 | ipEndPoint.AddressFamily, 43 | SocketType.Stream, 44 | ProtocolType.Tcp); 45 | 46 | Logger.Info("Server socket created."); 47 | listener.Bind(ipEndPoint); 48 | Logger.Info("Listener binded."); 49 | port = ((IPEndPoint)listener.LocalEndPoint).Port; 50 | Logger.Info("Server started on port {0}".Format(port)); 51 | 52 | listener.Listen(10); 53 | Logger.Info("Starting server thread."); 54 | serverTask = TaskEx.Run(() => { 55 | try { 56 | Logger.Info("Server thread started."); 57 | while (true) { 58 | Logger.Info("Server Accepting."); 59 | onConnection(listener.Accept()); 60 | } 61 | } catch(Exception ex) { 62 | Logger.Error(ex.ToString()); 63 | throw; 64 | } 65 | }); 66 | 67 | return port; 68 | } 69 | 70 | private IPEndPoint GetIpEndpoint(string ip, int port) { 71 | return new IPEndPoint(IPAddress.Parse(ip), port); 72 | } 73 | } 74 | } 75 | -------------------------------------------------------------------------------- /TowerFallAi/TowerfallAiMod/Mod/PauseMenu.cs: -------------------------------------------------------------------------------- 1 | using Microsoft.Xna.Framework; 2 | using Patcher; 3 | using TowerFall; 4 | using TowerfallAi.Common; 5 | using TowerfallAi.Core; 6 | 7 | namespace TowerfallAi.Mod { 8 | [Patch] 9 | public class ModPauseMenu : 
PauseMenu { 10 | public ModPauseMenu(Level level, Vector2 position, MenuType menuType, int controllerDisconnected = -1) : base(level, position, menuType, controllerDisconnected) { } 11 | 12 | public override void VersusMatchSettingsAndSave() { 13 | if (AiMod.Enabled) { 14 | AiMod.EndSession(); 15 | return; 16 | } 17 | 18 | Util.GetAction("$original_VersusMatchSettingsAndSave", this)(); 19 | } 20 | 21 | public override void Quit() { 22 | if (AiMod.Enabled) { 23 | AiMod.EndSession(); 24 | return; 25 | } 26 | 27 | Util.GetAction("$original_Quit", this)(); 28 | } 29 | 30 | public override void VersusMatchSettings() { 31 | if (AiMod.Enabled) { 32 | AiMod.EndSession(); 33 | return; 34 | } 35 | 36 | Util.GetAction("$original_VersusMatchSettings", this)(); 37 | } 38 | 39 | public override void VersusArcherSelect() { 40 | if (AiMod.Enabled) { 41 | AiMod.EndSession(); 42 | return; 43 | } 44 | 45 | Util.GetAction("$original_VersusArcherSelect", this)(); 46 | } 47 | 48 | public override void QuestMap() { 49 | if (AiMod.Enabled) { 50 | AiMod.EndSession(); 51 | return; 52 | } 53 | 54 | Util.GetAction("$original_QuestMap", this)(); 55 | } 56 | 57 | public override void VersusRematch() { 58 | if (AiMod.Enabled) { 59 | AiMod.Rematch(); 60 | return; 61 | } 62 | 63 | Util.GetAction("$original_VersusRematch", this)(); 64 | } 65 | 66 | public override void QuestRestart() { 67 | if (AiMod.Enabled) { 68 | AiMod.Rematch(); 69 | return; 70 | } 71 | 72 | Util.GetAction("$original_QuestRestart", this)(); 73 | } 74 | 75 | public override void QuestMapAndSave() { 76 | if (AiMod.Enabled) { 77 | AiMod.EndSession(); 78 | return; 79 | } 80 | 81 | Util.GetAction("$original_QuestMapAndSave", this)(); 82 | } 83 | 84 | public override void QuitAndSave() { 85 | if (AiMod.Enabled) { 86 | AiMod.EndSession(); 87 | return; 88 | } 89 | 90 | Util.GetAction("$original_QuitAndSave", this)(); 91 | } 92 | } 93 | } 94 | -------------------------------------------------------------------------------- 
/TowerFallAi/TowerfallAiMod/Mod/TFGame.cs: -------------------------------------------------------------------------------- 1 | using System; 2 | using Microsoft.Xna.Framework; 3 | using Patcher; 4 | using TowerFall; 5 | using TowerfallAi.Core; 6 | 7 | namespace TowerfallAi.Mod { 8 | [Patch] 9 | public class ModTFGame : TFGame { 10 | // This allows identifying that TowerFall.exe is patched. 11 | public const string AiModVersion = AiMod.Version; 12 | // Delegates bound in the constructor to the unpatched TFGame methods ($original_Initialize / $original_Update). 13 | Action originalInitialize; 14 | Action originalUpdate; 15 | // Entry point: parses the AI mod arguments, then invokes the original Main; any exception is logged and the log file is opened. 16 | [STAThread] 17 | public static void Main(string[] args) { 18 | try { 19 | AiMod.ParseArgs(args); 20 | typeof(TFGame).GetMethod("$original_Main").Invoke(null, new object[] { args }); 21 | } catch (Exception exception) { 22 | TFGame.Log(exception, false); 23 | TFGame.OpenLog(); 24 | } 25 | } 26 | 27 | public ModTFGame(bool noIntro) : base(noIntro) { 28 | var ptr = typeof(TFGame).GetMethod("$original_Initialize").MethodHandle.GetFunctionPointer(); 29 | originalInitialize = (Action)Activator.CreateInstance(typeof(Action), this, ptr); 30 | 31 | ptr = typeof(TFGame).GetMethod("$original_Update").MethodHandle.GetFunctionPointer(); 32 | originalUpdate = (Action)Activator.CreateInstance(typeof(Action), this, ptr); 33 | 34 | if (AiMod.Enabled) { 35 | this.InactiveSleepTime = new TimeSpan(0); 36 | // Fast-run disables vsync and the fixed time step (presumably so frames are not throttled). 37 | if (AiMod.IsFastrun) { 38 | Monocle.Engine.Instance.Graphics.SynchronizeWithVerticalRetrace = false; 39 | this.IsFixedTimeStep = false; 40 | } else { 41 | this.IsFixedTimeStep = true; 42 | } 43 | } 44 | } 45 | 46 | public override void Initialize() { 47 | if (!AiMod.Enabled) { 48 | originalInitialize(); 49 | return; 50 | } 51 | 52 | AiMod.PreGameInitialize(); 53 | originalInitialize(); 54 | AiMod.PostGameInitialize(); 55 | } 56 | // With the mod enabled, AiMod.Update drives the frame and is handed the original update delegate. 57 | public override void Update(GameTime gameTime) { 58 | if (!AiMod.Enabled) { 59 | originalUpdate(gameTime); 60 | return; 61 | } 62 | 63 | AiMod.Update(originalUpdate); 64 | } 65 | // With the mod enabled: nothing is drawn (render target reset to null) when no match is running or graphics are off; otherwise the base draw runs, then Agents.Draw() inside its own SpriteBatch pass. 66 | public override void Draw(GameTime gameTime) { 67
| if (!AiMod.Enabled) { 68 | base.Draw(gameTime); 69 | return; 70 | } 71 | 72 | if (!AiMod.IsMatchRunning() || AiMod.NoGraphics) { 73 | Monocle.Engine.Instance.GraphicsDevice.SetRenderTarget(null); 74 | return; 75 | } 76 | base.Draw(gameTime); 77 | Monocle.Draw.SpriteBatch.Begin(); 78 | Agents.Draw(); 79 | Monocle.Draw.SpriteBatch.End(); 80 | } 81 | } 82 | } 83 | -------------------------------------------------------------------------------- /python/gym_wrapper/player_observation.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Dict, List, Mapping, Optional, Sequence 2 | 3 | import numpy as np 4 | from common.constants import DASH, JUMP, SHOOT 5 | from common.entity import Entity 6 | from gym import Space, spaces 7 | 8 | from .observation import Observation 9 | 10 | 11 | class PlayerObservation(Observation): 12 | ''' 13 | Observation for the player state. 14 | ''' 15 | def __init__(self, exclude: Optional[Sequence[str]] = None): 16 | self.exclude = exclude 17 | # Registers the player feature spaces in obs_space_dict; keys listed in self.exclude are skipped and an already-present key raises. 18 | def extend_obs_space(self, obs_space_dict: Dict[str, Space]): 19 | def try_add_obs(key, value): 20 | if self.exclude and key in self.exclude: 21 | return 22 | if key in obs_space_dict.keys(): 23 | raise Exception(f'Observation space already has {key}') 24 | obs_space_dict[key] = value 25 | 26 | try_add_obs('prev_jump', spaces.Discrete(2)) 27 | try_add_obs('prev_dash', spaces.Discrete(2)) 28 | try_add_obs('prev_shoot', spaces.Discrete(2)) 29 | 30 | try_add_obs('dodgeCooldown', spaces.Discrete(2)) 31 | try_add_obs('dodging', spaces.Discrete(2)) 32 | try_add_obs('facing', spaces.Discrete(2)) 33 | try_add_obs('onGround', spaces.Discrete(2)) 34 | try_add_obs('onWall', spaces.Discrete(2)) 35 | try_add_obs('vel', spaces.Box(low=-2, high=2, shape=(2,), dtype=np.float32)) 36 | 37 | def post_reset(self, state_scenario: Mapping[str, Any], player: Optional[Entity], entities: List[Entity], obs_dict: Dict[str, Any]): 38 | self._extend_obs(player, '',
obs_dict) 39 | 40 | def post_step(self, player: Optional[Entity], entities: List[Entity], actions: str, obs_dict: Dict[str, Any]): 41 | self._extend_obs(player, actions, obs_dict) 42 | 43 | def _extend_obs(self, player: Optional[Entity], actions: str, obs_dict: Dict[str, Any]): 44 | def try_add_obs(key, value): 45 | if self.exclude and key in self.exclude: 46 | return 47 | obs_dict[key] = value 48 | 49 | try_add_obs('prev_jump', int(JUMP in actions)) 50 | try_add_obs('prev_dash', int(DASH in actions)) 51 | try_add_obs('prev_shoot', int(SHOOT in actions)) 52 | if not player: 53 | try_add_obs('dodgeCooldown', 0) 54 | try_add_obs('dodging', 0) 55 | try_add_obs('facing', 0) 56 | try_add_obs('onGround', 0) 57 | try_add_obs('onWall', 0) 58 | try_add_obs('vel', np.zeros(2)) 59 | return 60 | 61 | try_add_obs('dodgeCooldown', int(player['dodgeCooldown'])) 62 | try_add_obs('dodging', int(player['state']=='dodging')) 63 | try_add_obs('facing', (player['facing'] + 1) // 2) # -1,1 -> 0,1 64 | try_add_obs('onGround', int(player['onGround'])) 65 | try_add_obs('onWall', int(player['onWall'])) 66 | try_add_obs('vel', np.clip(player.v.numpy() / 5, -2, 2)) 67 | -------------------------------------------------------------------------------- /TowerFallAi/TowerfallAiMod/Core/AgentConfigExtension.cs: -------------------------------------------------------------------------------- 1 | using System; 2 | using TowerFall; 3 | using TowerfallAi.Common; 4 | using TowerfallAi.Data; 5 | 6 | namespace TowerfallAi.Core { 7 | public static class AgentConfigExtension { 8 | private static string[] archerColors = new string[] { 9 | "green", 10 | "blue", 11 | "pink", 12 | "orange", 13 | "white", 14 | "yellow", 15 | "teal", 16 | "purple", 17 | "red", 18 | }; 19 | 20 | public static Allegiance GetTeam(this AgentConfig agentConfig) { 21 | if (agentConfig.team == null || agentConfig.team == "neutral") return Allegiance.Neutral; 22 | if (agentConfig.team == "red") return Allegiance.Red; 23 | if 
(agentConfig.team == "blue") return Allegiance.Blue; 24 | 25 | throw new Exception("Invalid team '{0}'".Format(agentConfig.team)); 26 | } 27 | // Inverse of the extension above: Allegiance -> team name; any other value throws. 28 | public static string GetTeam(Allegiance team) { 29 | if (team == Allegiance.Red) return "red"; 30 | if (team == Allegiance.Blue) return "blue"; 31 | if (team == Allegiance.Neutral) return "neutral"; ; 32 | 33 | throw new Exception("Invalid team '{0}'".Format(team)); 34 | } 35 | // Archer spec is "<index|color>[-suffix]"; an empty spec picks a random index. NOTE(review): Random.Next(0, 8) never yields index 8 ("red") although archerColors has 9 entries — confirm intended. 36 | public static int GetArcherIndex(this AgentConfig agentConfig) { 37 | if (string.IsNullOrWhiteSpace(agentConfig.archer)) { 38 | return AiMod.Random.Next(0, 8); 39 | } 40 | 41 | var parts = agentConfig.archer.Split('-'); 42 | if (parts.Length < 1 || parts.Length > 2) { 43 | throw new Exception("Invalid archer '{0}'".Format(agentConfig.archer)); 44 | } 45 | 46 | int archerIndex; 47 | if (int.TryParse(parts[0], out archerIndex)) { 48 | if (archerIndex < 0 || archerIndex >= archerColors.Length) { 49 | throw new Exception("Invalid archer '{0}'".Format(agentConfig.archer)); 50 | } 51 | 52 | return archerIndex; 53 | } 54 | 55 | for (int i = 0; i < archerColors.Length; i++) { 56 | if (archerColors[i] == parts[0]) { 57 | return i; 58 | } 59 | } 60 | 61 | throw new Exception("Invalid archer '{0}'".Format(agentConfig.archer)); 62 | } 63 | // Archer type from the spec suffix: none -> Normal, "alt" -> Alt, "secret" -> Secret; any other suffix throws. 64 | public static ArcherData.ArcherTypes GetArcherType(this AgentConfig agentConfig) { 65 | if (string.IsNullOrWhiteSpace(agentConfig.archer)) { 66 | return ArcherData.ArcherTypes.Normal; 67 | } 68 | 69 | var parts = agentConfig.archer.Split('-'); 70 | if (parts.Length != 2) { 71 | return ArcherData.ArcherTypes.Normal; 72 | } 73 | 74 | if (parts[1] == "alt") { 75 | return ArcherData.ArcherTypes.Alt; 76 | } 77 | 78 | if (parts[1] == "secret") { 79 | return ArcherData.ArcherTypes.Secret; 80 | } 81 | 82 | throw new Exception("Invalid archer '{0}'".Format(agentConfig.archer)); 83 | } 84 | } 85 | } 86 | -------------------------------------------------------------------------------- /TowerFallAi/Patcher/Patcher.csproj:
-------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | Debug 5 | x86 6 | {B1DF3728-5F4C-4987-8213-4DCE21255B0F} 7 | Exe 8 | Patcher 9 | Patcher 10 | 11 | 12 | ..\bin\PatchWindows 13 | true 14 | full 15 | false 16 | DEBUG; 17 | prompt 18 | 4 19 | true 20 | x86 21 | 22 | 23 | ..\bin\PatchWindows 24 | true 25 | full 26 | false 27 | DEBUG; 28 | prompt 29 | 4 30 | true 31 | x86 32 | 33 | 34 | 35 | ..\packages\CommandLineParser.2.9.1\lib\net40\CommandLine.dll 36 | 37 | 38 | ..\packages\Mono.Cecil.0.11.5\lib\net40\Mono.Cecil.dll 39 | 40 | 41 | ..\packages\Newtonsoft.Json.13.0.1\lib\net40\Newtonsoft.Json.dll 42 | 43 | 44 | 45 | 46 | 47 | 48 | 49 | 50 | 51 | 52 | 53 | 54 | 55 | 56 | 57 | 58 | 59 | {fae08b42-bbe6-4177-9110-6dde1032c415} 60 | PatcherLib 61 | 62 | 63 | 64 | -------------------------------------------------------------------------------- /python/towerfall/connection.py: -------------------------------------------------------------------------------- 1 | import json 2 | import logging 3 | import socket 4 | from typing import Any, Callable, Mapping 5 | 6 | _BYTE_ORDER = 'big' 7 | _ENCODING = 'ascii' 8 | _LOCALHOST = '127.0.0.1' 9 | 10 | class Connection: 11 | ''' 12 | Connection to a Towerfall server. It is used to send and receive messages. 13 | 14 | params port: Port of the server. 15 | params ip: Ip address of the server. 16 | params timeout: Timeout for the socket. 17 | params verbose: Verbosity level. 0: no logging, 1: much logging. 18 | params log_cap: Maximum number of characters to log. 19 | params record_path: Path to a file to record the messages sent and received. 
20 | ''' 21 | def __init__(self, port: int, ip: str = _LOCALHOST, timeout: float = 0, verbose=0, log_cap=100, record_path=None): 22 | self.verbose = verbose 23 | self.log_cap = log_cap 24 | self.record_path = record_path 25 | self._socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) 26 | self._socket.connect((ip, port)) 27 | if timeout: 28 | self._socket.settimeout(timeout) 29 | self.port = port 30 | self.on_close: Callable 31 | 32 | def __del__(self): 33 | self.close() 34 | 35 | def close(self): 36 | ''' 37 | Closes the socket. 38 | ''' 39 | if hasattr(self, '_socket'): 40 | if self.verbose > 0: 41 | logging.info('Closing socket') 42 | self._socket.close() 43 | del self._socket 44 | if hasattr(self, 'on_close'): 45 | self.on_close() 46 | 47 | def write(self, msg: str): 48 | ''' 49 | Writes a new message following the game's protocol. 50 | ''' 51 | size = len(msg) 52 | if self.verbose > 0: 53 | logging.info('Writing: %sB %s', size, self._cap(msg)) 54 | 55 | self._socket.sendall(size.to_bytes(2, byteorder=_BYTE_ORDER)) 56 | self._socket.sendall(msg.encode(_ENCODING)) 57 | if self.record_path: 58 | with open(self.record_path, 'a') as file: 59 | file.write(msg + '\n') 60 | 61 | def read(self) -> str: 62 | ''' 63 | Reads a message following the game's protocol. 64 | ''' 65 | try: 66 | header: bytes = self._socket.recv(2) 67 | size = int.from_bytes(header, _BYTE_ORDER) 68 | if size == 0: 69 | raise ConnectionError('Connection is closed') 70 | payload = self._socket.recv(size) 71 | resp = payload.decode(_ENCODING) 72 | if self.verbose > 0: 73 | logging.info('Read: %dB %s', size, self._cap(resp)) 74 | if self.record_path: 75 | with open(self.record_path, 'a') as file: 76 | file.write(resp + '\n') 77 | return resp 78 | except socket.timeout as ex: 79 | logging.error(f'Socket timeout {self._socket.getsockname()}') 80 | raise ex 81 | 82 | def read_json(self) -> Mapping[str, Any]: 83 | ''' 84 | Reads a message and parses it to json. 
85 | ''' 86 | return json.loads(self.read()) 87 | 88 | def send_json(self, obj: Mapping[str, Any]): 89 | ''' 90 | Convert the object to json and writes it. 91 | ''' 92 | self.write(json.dumps(obj)) 93 | 94 | def _cap(self, value: str) -> str: 95 | return value[:self.log_cap] + '...' if len(value) > self.log_cap else value 96 | -------------------------------------------------------------------------------- /python/common/entity.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import sys 4 | from math import sqrt 5 | from typing import Any, Dict, List, Tuple 6 | 7 | import numpy as np 8 | from numpy.typing import NDArray 9 | 10 | 11 | class Entity: 12 | def __init__(self, e: Dict[str, Any]): 13 | self.p: Vec2 = vec2_from_dict(e['pos']) 14 | self.v: Vec2 = vec2_from_dict(e['vel']) 15 | self.s: Vec2 = vec2_from_dict(e['size']) 16 | self.isEnemy: bool = e['isEnemy'] 17 | self.type: str = e['type'] 18 | self.e: Any = e 19 | 20 | def __getitem__(self, key): 21 | return self.e[key] 22 | 23 | def bot_left(self) -> Vec2: 24 | return Vec2(self.p.x - self.s.x / 2, self.p.y - self.s.y / 2) 25 | 26 | def top_right(self) -> Vec2: 27 | return Vec2(self.p.x + self.s.x / 2, self.p.y + self.s.y / 2) 28 | 29 | 30 | class Vec2: 31 | def __init__(self, x: float, y: float): 32 | self.x: float = x 33 | self.y: float = y 34 | 35 | def __str__(self): 36 | return 'Vec2({}, {})'.format(self.x, self.y) 37 | 38 | def __hash__(self): 39 | return hash((self.x, self.y)) 40 | 41 | def __add__(self, o): 42 | return Vec2(self.x + o.x, self.y + o.y) 43 | 44 | def __sub__(self, o): 45 | return Vec2(self.x - o.x, self.y - o.y) 46 | 47 | def __neg__(self): 48 | return Vec2(-self.x, -self.y) 49 | 50 | def __mul__(self, o): 51 | if isinstance(o, float) or isinstance(o, int): 52 | return Vec2(self.x * o, self.y * o) 53 | raise NotImplementedError() 54 | 55 | def __truediv__(self, o): 56 | if isinstance(o, float) or 
isinstance(o, int): 57 | return Vec2(self.x / o, self.y / o) 58 | raise NotImplementedError() 59 | 60 | def __eq__(self, other): 61 | if isinstance(other, Vec2): 62 | if self.x != other.x: 63 | return False 64 | if self.y != other.y: 65 | return False 66 | return True 67 | return NotImplemented 68 | 69 | def tupleint(self) -> Tuple[int, int]: 70 | return int(self.x), int(self.y) 71 | 72 | def numpy(self) -> NDArray[np.float32]: 73 | return np.array([self.x, self.y], dtype=np.float32) 74 | 75 | def dict(self) -> dict[str, float]: 76 | return dict(x=self.x, y=self.y) 77 | # Rescales this vector in place to length l; a zero-length vector raises ZeroDivisionError. 78 | def set_length(self, l: float): 79 | d = self.length() 80 | self.x *= l/d 81 | self.y *= l/d 82 | 83 | def length(self): 84 | return sqrt(self.x**2 + self.y**2) 85 | 86 | def copy(self): 87 | return Vec2(self.x, self.y) 88 | 89 | def add(self, v: Vec2): 90 | self.x += v.x 91 | self.y += v.y 92 | 93 | def sub(self, v: Vec2): 94 | self.x -= v.x 95 | self.y -= v.y 96 | 97 | def mul(self, f: float): 98 | self.x *= f 99 | self.y *= f 100 | 101 | def div(self, f: float): 102 | self.x /= f 103 | self.y /= f 104 | 105 | # Builds a Vec2 from a dict with 'x' and 'y' keys; on a missing key the dict is written to stderr and the KeyError is re-raised. 106 | def vec2_from_dict(p: Dict[str, Any]) -> Vec2: 107 | try: 108 | return Vec2(p['x'], p['y']) 109 | except KeyError: 110 | sys.stderr.write(str(p)) 111 | raise 112 | 113 | # Wraps each raw entity dict in an Entity. 114 | def to_entities(entities: List[Dict[str, Any]]) -> List[Entity]: 115 | result = [] 116 | for e in entities: 117 | result.append(Entity(e)) 118 | 119 | return result 120 | -------------------------------------------------------------------------------- /python/reset_test.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Dict, List, Mapping 2 | 3 | from common.logging_options import default_logging 4 | from common.test import assert_initial_state 5 | from towerfall import Towerfall 6 | 7 | from agents import SimpleAgent 8 | 9 | default_logging() 10 | 11 | ''' 12 | Test reset for all supported entities.
13 | ''' 14 | 15 | _TEST_ENTITIES = [ 16 | dict(type='bat', facing=-1, pos=dict(x=220, y=160)), 17 | dict(type='batBomb', facing=-1, pos=dict(x=220, y=160)), 18 | dict(type='batSuperBomb', facing=-1, pos=dict(x=220, y=160)), 19 | dict(type='bird', facing=-1, pos=dict(x=220, y=160)), 20 | dict(type='slime', subType="blue", facing=-1, pos=dict(x=220, y=110)), 21 | dict(type='slime', subType="green", facing=-1, pos=dict(x=220, y=110)), 22 | dict(type='slime', subType="red", facing=-1, pos=dict(x=220, y=110)), 23 | dict(type='birdman', facing=-1, pos=dict(x=220, y=160)), 24 | dict(type='cultist', subtype='boss', facing=-1, pos=dict(x=220, y=120)), 25 | dict(type='cultist', subtype='normal', facing=-1, pos=dict(x=220, y=120)), 26 | dict(type='cultist', subtype='scythe', facing=-1, pos=dict(x=220, y=120)), 27 | dict(type='evilCrystal', subtype='blue', facing=-1, pos=dict(x=220, y=180)), 28 | dict(type='evilCrystal', subtype='green', facing=-1, pos=dict(x=220, y=180)), 29 | dict(type='evilCrystal', subtype='pink', facing=-1, pos=dict(x=220, y=180)), 30 | dict(type='evilCrystal', subtype='red', facing=-1, pos=dict(x=220, y=180)), 31 | dict(type='exploder', facing=-1, pos=dict(x=220, y=180)), 32 | dict(type='flamingSkull', facing=-1, pos=dict(x=220, y=180)), 33 | dict(type='ghost', subType='blue', facing=-1, pos=dict(x=220, y=180)), 34 | dict(type='ghost', subType='fire', facing=-1, pos=dict(x=220, y=180)), 35 | dict(type='ghost', subType='green', facing=-1, pos=dict(x=220, y=180)), 36 | dict(type='ghost', subType='greenFire', facing=-1, pos=dict(x=220, y=180)), 37 | dict(type='mole', facing=-1, pos=dict(x=220, y=120)), 38 | dict(type='worm', pos=dict(x=220, y=120)), 39 | ] 40 | 41 | 42 | def main(): 43 | # Creates or reuse a Towerfall game. 
44 | towerfall = Towerfall( 45 | verbose = 1, 46 | config = dict( 47 | mode='sandbox', 48 | level='2', 49 | fps=999, 50 | agentTimeout='00:00:02', 51 | solids=[[0] * 32]*14 + [[1]*32] + [[0] * 32]*9, 52 | agents=[dict(type='remote')] 53 | ) 54 | ) 55 | 56 | connection = towerfall.join(timeout=10, verbose=1) 57 | agent = SimpleAgent(connection) 58 | 59 | n_frames = 180 60 | for i in range(0, len(_TEST_ENTITIES)): 61 | reset_entities = send_reset(towerfall, i) 62 | if i > 0: 63 | # We still need to ack the last frame, or the agent will get disconnected. 64 | game_state = connection.read_json() 65 | connection.send_json(dict(type='actions', actions='', id=game_state['id'])) 66 | asserted = False 67 | for _ in range(n_frames): 68 | # Read the state of the game then replies with an action. 69 | game_state = connection.read_json() 70 | if not asserted and game_state['type'] == 'update': 71 | assert_initial_state(reset_entities, game_state['entities']) 72 | asserted = True 73 | agent.act(game_state) 74 | 75 | 76 | def send_reset(towerfall: Towerfall, i: int) -> List[Dict[str, Any]]: 77 | reset_entities = get_reset_entities(i) 78 | towerfall.send_reset(reset_entities) 79 | return reset_entities 80 | 81 | 82 | def get_reset_entities(i: int) -> List[Dict[str, Any]]: 83 | return [dict(type='archer', pos=dict(x=120, y=110)), _TEST_ENTITIES[i]] 84 | 85 | 86 | if __name__ == '__main__': 87 | main() -------------------------------------------------------------------------------- /TowerFallAi/TowerfallAiMod/Common/AsyncQueue.cs: -------------------------------------------------------------------------------- 1 | using System; 2 | using System.Collections.Generic; 3 | using System.Threading; 4 | using System.Threading.Tasks; 5 | 6 | namespace TowerfallAi.Common { 7 | public class AsyncQueue { 8 | private Queue> completedTasks = new Queue>(); 9 | private Queue> waitingTasks = new Queue>(); 10 | private object queueLock = new object(); 11 | public bool IsClosed { get; private set; } 
12 | 13 | public Task DequeueAsync(TimeSpan timeout = default, CancellationToken cancellationToken = default) { 14 | lock (queueLock) { 15 | if (completedTasks.Count > 0) { 16 | return completedTasks.Dequeue().Task; 17 | } 18 | 19 | if (IsClosed) { 20 | throw new InvalidOperationException("Can't dequeue from a closed AsyncQueue."); 21 | } 22 | 23 | var tcs = new TaskCompletionSource(); 24 | if (cancellationToken != default) { 25 | cancellationToken.Register(() => HandleManualCancellation(tcs)); 26 | } 27 | 28 | if (timeout != default) { 29 | var timoutTask = FailAfterTimeout(timeout, tcs); 30 | } 31 | 32 | waitingTasks.Enqueue(tcs); 33 | return tcs.Task; 34 | } 35 | } 36 | 37 | public void EnqueueException(Exception ex) { 38 | while (true) { 39 | TaskCompletionSource tcs; 40 | lock (queueLock) { 41 | if (IsClosed) { 42 | throw new InvalidOperationException("Can't enqueue to a closed AsyncQueue."); 43 | } 44 | tcs = DequeueTillNextRunning(); 45 | if (tcs == null) { 46 | tcs = new TaskCompletionSource(); 47 | completedTasks.Enqueue(tcs); 48 | } 49 | } 50 | if (tcs.TrySetException(ex)) return; // Fails only if a waiter timed out or was cancelled after the lock was released; then retry with the next one. 51 | } 52 | } 53 | 54 | public void Enqueue(T result) { 55 | while (true) { 56 | TaskCompletionSource tcs; 57 | lock (queueLock) { 58 | if (IsClosed) { 59 | throw new InvalidOperationException("Can't enqueue to a closed AsyncQueue."); 60 | } 61 | tcs = DequeueTillNextRunning(); 62 | if (tcs == null) { 63 | tcs = new TaskCompletionSource(); 64 | completedTasks.Enqueue(tcs); 65 | } 66 | } 67 | if (tcs.TrySetResult(result)) return; // Fails only if a waiter timed out or was cancelled after the lock was released; then retry with the next one. 68 | } 69 | } 70 | 71 | public void Close() { 72 | lock (queueLock) { 73 | IsClosed = true; 74 | } 75 | } 76 | 77 | private TaskCompletionSource DequeueTillNextRunning() { 78 | // Look for existing requests for results in waitingTasks. Some might have been cancelled or timed out. 79 | // So select the ones that are still running.
80 | while (waitingTasks.Count > 0) { 81 | TaskCompletionSource tcs = waitingTasks.Dequeue(); 82 | if (tcs.Task.IsAlive()) { 83 | return tcs; 84 | } 85 | } 86 | // No live waiter remains (dead ones were dropped above); callers then park the value in completedTasks. 87 | return null; 88 | } 89 | 90 | private void HandleManualCancellation(TaskCompletionSource tcs) { 91 | lock (queueLock) { 92 | if (tcs.Task.IsAlive()) { 93 | tcs.SetCanceled(); 94 | } 95 | } 96 | } 97 | 98 | private async Task FailAfterTimeout(TimeSpan timeout, TaskCompletionSource tcs) { 99 | await TaskEx.Delay(timeout); 100 | Console.WriteLine($"Timeout: {timeout} {tcs.Task.Status}"); 101 | lock (queueLock) { 102 | if (tcs.Task.IsAlive()) { 103 | tcs.SetException(new TimeoutException($"Timeout: {timeout}")); 104 | Console.WriteLine("Task was set to time out"); 105 | } 106 | } 107 | } 108 | } 109 | } 110 | -------------------------------------------------------------------------------- /TowerFallAi/TowerfallAiMod/Core/AgentConnection.cs: -------------------------------------------------------------------------------- 1 | using System; 2 | using System.Collections.Generic; 3 | using System.Threading; 4 | using System.Threading.Tasks; 5 | using TowerFall; 6 | using TowerfallAi.Common; 7 | using TowerfallAi.Data; 8 | using TowerfallAiMod.Core; 9 | 10 | namespace TowerfallAi.Core { 11 | public abstract class AgentConnection : KeyboardInput, IDisposable { 12 | public int index; 13 | InputState prevInputState; 14 | public InputState InputState; 15 | 16 | public AgentConnection(int index) : base(KeyboardConfigs.Configs[index], index) { 17 | this.index = index; 18 | this.Config = KeyboardConfigs.Configs[index]; 19 | } 20 | 21 | protected InputState GetCopy(InputState inputState) { 22 | return new InputState { 23 | AimAxis = inputState.AimAxis, 24 | ArrowsPressed = inputState.ArrowsPressed, 25 | DodgeCheck = inputState.DodgeCheck, 26 | DodgePressed = inputState.DodgePressed, 27 | JumpCheck = inputState.JumpCheck, 28 | JumpPressed = inputState.JumpPressed, 29 | MoveX = inputState.MoveX, 30 | MoveY =
inputState.MoveY, 31 | ShootCheck = inputState.ShootCheck, 32 | ShootPressed = inputState.ShootPressed 33 | }; 34 | } 35 | 36 | public override InputState GetState() { 37 | return GetCopy(this.InputState); 38 | } 39 | 40 | protected void ProcessResponse(string response) { 41 | if (response == null || response.Length == 0) return; 42 | 43 | this.InputState.AimAxis.X = 0; 44 | this.InputState.MoveX = 0; 45 | this.InputState.AimAxis.Y = 0; 46 | this.InputState.MoveY = 0; 47 | 48 | var responseChars = new HashSet(); 49 | foreach (char c in response) { 50 | responseChars.Add(c); 51 | } 52 | // Each distinct action is applied once; axis actions accumulate so opposite directions (u+d, l+r) cancel regardless of set order. 53 | foreach(char c in responseChars) { 54 | ProcessAction(c); 55 | } 56 | } 57 | 58 | protected void ProcessAction(char c) { 59 | switch (c) { 60 | case 'j': 61 | this.InputState.JumpCheck = true; 62 | this.InputState.JumpPressed = !this.prevInputState.JumpCheck; 63 | break; 64 | case 's': 65 | this.InputState.ShootCheck = true; 66 | this.InputState.ShootPressed = !this.prevInputState.ShootCheck; 67 | break; 68 | case 'z': 69 | this.InputState.DodgeCheck = true; 70 | this.InputState.DodgePressed = !this.prevInputState.DodgeCheck; 71 | break; 72 | case 'a': 73 | this.InputState.ArrowsPressed = true; 74 | break; 75 | case 'u': 76 | this.InputState.AimAxis.Y -= 1; 77 | this.InputState.MoveY -= 1; 78 | break; 79 | case 'd': 80 | this.InputState.AimAxis.Y += 1; 81 | this.InputState.MoveY += 1; 82 | break; 83 | case 'l': 84 | this.InputState.AimAxis.X -= 1; 85 | this.InputState.MoveX -= 1; 86 | break; 87 | case 'r': 88 | this.InputState.MoveX += 1; 89 | this.InputState.AimAxis.X += 1; 90 | break; 91 | } 92 | } 93 | 94 | public void UpdateGameInput(string response) { 95 | this.InputState = new InputState(); 96 | ProcessResponse(response); 97 | this.prevInputState = GetCopy(this.InputState); 98 | } 99 | 100 | public abstract void Send(string message, int frame); 101 | 102 | public abstract Task ReceiveAsync(TimeSpan timeout = default, CancellationToken cancellationToken = default); 103 |
104 | public abstract void Dispose(); 105 | // Logs the message, disposes this connection, then throws to abort the caller. 106 | protected void Disconnect(string message) { 107 | Logger.Error(message); 108 | this.Dispose(); 109 | throw new Exception("Disconnected"); 110 | } 111 | } 112 | } 113 | -------------------------------------------------------------------------------- /TowerFallAi/TowerfallAiMod/Api/Types.cs: -------------------------------------------------------------------------------- 1 | using TowerFall; 2 | using TowerfallAi.Common; 3 | 4 | namespace TowerfallAi.Api { 5 | public class ConversionTypes { 6 | public static DoubleDictionary BatTypes = new DoubleDictionary(); 7 | public static DoubleDictionary SlimeTypes = new DoubleDictionary(); 8 | public static DoubleDictionary CultistTypes = new DoubleDictionary(); 9 | public static DoubleDictionary CrystalTypes = new DoubleDictionary(); 10 | public static DoubleDictionary GhostTypes = new DoubleDictionary(); 11 | // Game enum value <-> API string pairs, filled in by the static constructor (DoubleDictionary presumably supports lookup in both directions — confirm). 12 | static ConversionTypes() { 13 | BatTypes.Add(Bat.BatType.Eye, "bat"); 14 | BatTypes.Add(Bat.BatType.Bomb, "batBomb"); 15 | BatTypes.Add(Bat.BatType.SuperBomb, "batSuperBomb"); 16 | BatTypes.Add(Bat.BatType.Bird, "bird"); 17 | 18 | SlimeTypes.Add(Slime.SlimeColors.Blue, "blue"); 19 | SlimeTypes.Add(Slime.SlimeColors.Green, "green"); 20 | SlimeTypes.Add(Slime.SlimeColors.Red, "red"); 21 | 22 | CultistTypes.Add(Cultist.CultistTypes.Boss, "boss"); 23 | CultistTypes.Add(Cultist.CultistTypes.Normal, "normal"); 24 | CultistTypes.Add(Cultist.CultistTypes.Scythe, "scythe"); 25 | 26 | CrystalTypes.Add(EvilCrystal.CrystalColors.Blue, "blue"); 27 | CrystalTypes.Add(EvilCrystal.CrystalColors.Green, "green"); 28 | CrystalTypes.Add(EvilCrystal.CrystalColors.Pink, "pink"); 29 | CrystalTypes.Add(EvilCrystal.CrystalColors.Red, "red"); 30 | 31 | GhostTypes.Add(Ghost.GhostTypes.Blue, "blue"); 32 | GhostTypes.Add(Ghost.GhostTypes.Fire, "fire"); 33 | GhostTypes.Add(Ghost.GhostTypes.Green, "green"); 34 | GhostTypes.Add(Ghost.GhostTypes.GreenFire, "greenFire"); 35 | } 36 | } 37 | // String identifiers for entity types in the TowerfallAi.Api namespace. 38 |
public class Types { 39 | public const string Arrow = "arrow"; 40 | public const string Bramble = "bramble"; 41 | public const string CrackedPlatform = "crackedPlatform"; 42 | public const string CrackedWall = "crackedWall"; 43 | public const string Crown = "crown"; 44 | public const string Hat = "hat"; 45 | public const string EvilCrystal = "evilCrystal"; 46 | public const string ExpandingPlatform = "expandingPlatform"; 47 | public const string ExpandingTrigger = "expandingTrigger"; 48 | public const string Explosion = "explosion"; 49 | public const string FakeWall = "fakeWall"; 50 | public const string FloorMiasma = "floorMiasma"; 51 | public const string GraniteBlock = "graniteBlock"; 52 | public const string HotCoal = "hotCoal"; 53 | public const string Ice = "ice"; 54 | public const string Icicle = "icicle"; 55 | public const string JumpPad = "jumpPad"; 56 | public const string Lantern = "lantern"; 57 | public const string Lava = "lava"; 58 | public const string Miasma = "miasma"; 59 | public const string MoonGlassBlock = "moonGlassBlock"; 60 | public const string MovingPlatform = "movingPlatform"; 61 | public const string Orb = "orb"; 62 | public const string ProximityBlock = "proximityBlock"; 63 | public const string Item = "item"; 64 | public const string ShiftBlock = "shiftBlock"; 65 | public const string SpikeBall = "spikeBall"; 66 | public const string SwitchBlock = "switchBlock"; 67 | public const string Chest = "chest"; 68 | } 69 | // String identifiers for item kinds. 70 | public class TypesItems { 71 | public const string Bomb = "bomb"; 72 | public const string Mirror = "mirror"; 73 | public const string Shield = "shield"; 74 | public const string Wings = "wings"; 75 | // Orb item identifiers. 76 | public const string OrbDark = "orbDark"; 77 | public const string OrbTime = "orbTime"; 78 | public const string OrbLava = "orbLava"; 79 | public const string OrbSpace = "orbSpace"; 80 | public const string OrbChaos = "orbChaos"; 81 | } 82 | } 83 |
-------------------------------------------------------------------------------- /python/gym_wrapper/kill_enemy_objective.py: -------------------------------------------------------------------------------- 1 | import random 2 | from typing import Any, Dict, List, Optional 3 | 4 | import numpy as np 5 | from common.constants import HH, HW 6 | from common.entity import Entity, Vec2 7 | from gym import Space, spaces 8 | 9 | from .objective import Objective 10 | 11 | 12 | class KillEnemyObjective(Objective): 13 | ''' 14 | Specifies observation and rewards associated with killing slimes. 15 | ''' 16 | def __init__(self, 17 | enemy_type:str = 'slime', 18 | enemy_count: int = 1, 19 | min_distance: float = 50, 20 | max_distance: float = 100, 21 | bounty = 5, 22 | episode_max_len: int=60*2): 23 | super().__init__() 24 | self.enemy_type = enemy_type 25 | self.enemy_count = enemy_count 26 | self.target_ids = [] 27 | self.min_distance = min_distance 28 | self.max_distance = max_distance 29 | self.bounty = bounty 30 | self.episode_max_len = episode_max_len 31 | self.episode_len = 0 32 | self.obs_space = spaces.Box(low=-1, high = 1, shape=(3*self.enemy_count,), dtype=np.float32) 33 | 34 | def extend_obs_space(self, obs_space_dict: Dict[str, Space]): 35 | if 'targets' in obs_space_dict: 36 | raise Exception('Observation space already has \'target\'') 37 | obs_space_dict['targets'] = self.obs_space 38 | 39 | def get_reset_entities(self) -> Optional[List[Dict[str, Any]]]: 40 | p = Vec2(160, 110) 41 | entities: List[Dict[str, Any]] = [dict(type='archer', pos=p.dict())] 42 | for i in range(self.enemy_count): 43 | sign = random.randint(0, 1)*2 - 1 44 | d = random.uniform(self.min_distance, self.max_distance) * sign 45 | enemy = dict( 46 | type=self.enemy_type, 47 | pos=(p + Vec2(d, -5)).dict(), 48 | facing=-sign) 49 | entities.append(enemy) 50 | return entities 51 | 52 | def post_reset(self, state_scenario: Dict[str, Any], player: Optional[Entity], entities: List[Entity], obs_dict: 
Dict[str, Any]): 53 | assert player 54 | targets = list(e for e in entities if e['type'] == self.enemy_type) 55 | assert len(targets) > 0, 'No targets found' 56 | self.target_ids = [t['id'] for t in targets] 57 | self.n_targets_prev = len(self.target_ids) 58 | 59 | self.done = False 60 | self.episode_len = 0 61 | self._update_obs(player, targets, obs_dict) 62 | 63 | 64 | def post_step(self, player: Optional[Entity], entities: List[Entity], actions: str, obs_dict: Dict[str, Any]): 65 | targets = list(e for e in entities if e['id'] in self.target_ids) 66 | self._update_reward(player, targets) 67 | self.episode_len += 1 68 | self._update_obs(player, targets, obs_dict) 69 | 70 | def _update_reward(self, player: Optional[Entity], targets: List[Entity]): 71 | ''' 72 | Updates the reward and checks if the episode is done. 73 | ''' 74 | self.reward = 0 75 | if not player or self.episode_len >= self.episode_max_len: 76 | self.done = True 77 | self.reward -= self.bounty / 5 78 | 79 | if len(targets) < self.n_targets_prev: 80 | self.reward = self.bounty * (self.n_targets_prev - len(targets)) 81 | 82 | self.n_targets_prev = len(targets) 83 | if self.n_targets_prev == 0: 84 | self.done = True 85 | 86 | def limit(self, x: float, a: float, b: float) -> float: 87 | return x+b-a if x < a else x-b+a if x > b else x 88 | 89 | def _update_obs(self, player: Optional[Entity], targets: List[Entity], obs_dict: Dict[str, Any]): 90 | obs_target = np.zeros((3*self.enemy_count,), dtype=np.float32) 91 | if not player: 92 | obs_dict['targets'] = obs_target 93 | return 94 | 95 | target_by_dist = [] 96 | for i, target in enumerate(targets): 97 | target_by_dist.append(((target.p - player.p).length(), target)) 98 | 99 | target_by_dist.sort(key=lambda x: x[0]) 100 | for i, (_, target) in enumerate(target_by_dist): 101 | obs_target[i*3] = 1 102 | obs_target[i*3 + 1] = self.limit(target.p.x - player.p.x / HW, -1, 1) 103 | obs_target[i*3 + 2] = self.limit(target.p.y - player.p.y / HH, -1, 1) 104 | 
obs_dict['targets'] = obs_target 105 | -------------------------------------------------------------------------------- /TowerFallAi/TowerFallAi.sln: -------------------------------------------------------------------------------- 1 | 2 | Microsoft Visual Studio Solution File, Format Version 12.00 3 | # Visual Studio Version 16 4 | VisualStudioVersion = 16.0.31727.386 5 | MinimumVisualStudioVersion = 10.0.40219.1 6 | Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Patcher", "Patcher\Patcher.csproj", "{B1DF3728-5F4C-4987-8213-4DCE21255B0F}" 7 | ProjectSection(ProjectDependencies) = postProject 8 | {5C1AA190-DD28-41A1-8DD2-FEF3A9DEE28A} = {5C1AA190-DD28-41A1-8DD2-FEF3A9DEE28A} 9 | EndProjectSection 10 | EndProject 11 | Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "TowerfallAiMod", "TowerfallAiMod\TowerfallAiMod.csproj", "{5C1AA190-DD28-41A1-8DD2-FEF3A9DEE28A}" 12 | EndProject 13 | Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "PatcherLib", "PatcherLib\PatcherLib.csproj", "{FAE08B42-BBE6-4177-9110-6DDE1032C415}" 14 | EndProject 15 | Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "TowerfallAiModTests", "TowerfallAiModTests\TowerfallAiModTests.csproj", "{19CFA5FC-A793-46EF-9F34-88F6C9D8C19F}" 16 | EndProject 17 | Global 18 | GlobalSection(SolutionConfigurationPlatforms) = preSolution 19 | MakeBaseImageWindows|Any CPU = MakeBaseImageWindows|Any CPU 20 | MakeBaseImageWindows|x86 = MakeBaseImageWindows|x86 21 | PatchWindows|Any CPU = PatchWindows|Any CPU 22 | PatchWindows|x86 = PatchWindows|x86 23 | EndGlobalSection 24 | GlobalSection(ProjectConfigurationPlatforms) = postSolution 25 | {B1DF3728-5F4C-4987-8213-4DCE21255B0F}.MakeBaseImageWindows|Any CPU.ActiveCfg = MakeBaseImageWindows|x86 26 | {B1DF3728-5F4C-4987-8213-4DCE21255B0F}.MakeBaseImageWindows|Any CPU.Build.0 = MakeBaseImageWindows|x86 27 | {B1DF3728-5F4C-4987-8213-4DCE21255B0F}.MakeBaseImageWindows|x86.ActiveCfg = MakeBaseImage|x86 28 | 
{B1DF3728-5F4C-4987-8213-4DCE21255B0F}.MakeBaseImageWindows|x86.Build.0 = MakeBaseImage|x86 29 | {B1DF3728-5F4C-4987-8213-4DCE21255B0F}.PatchWindows|Any CPU.ActiveCfg = PatchWindows|x86 30 | {B1DF3728-5F4C-4987-8213-4DCE21255B0F}.PatchWindows|Any CPU.Build.0 = PatchWindows|x86 31 | {B1DF3728-5F4C-4987-8213-4DCE21255B0F}.PatchWindows|x86.ActiveCfg = Patch|x86 32 | {B1DF3728-5F4C-4987-8213-4DCE21255B0F}.PatchWindows|x86.Build.0 = Patch|x86 33 | {5C1AA190-DD28-41A1-8DD2-FEF3A9DEE28A}.MakeBaseImageWindows|Any CPU.ActiveCfg = PatchWindows|Any CPU 34 | {5C1AA190-DD28-41A1-8DD2-FEF3A9DEE28A}.MakeBaseImageWindows|x86.ActiveCfg = Debug|Any CPU 35 | {5C1AA190-DD28-41A1-8DD2-FEF3A9DEE28A}.MakeBaseImageWindows|x86.Build.0 = Debug|Any CPU 36 | {5C1AA190-DD28-41A1-8DD2-FEF3A9DEE28A}.PatchWindows|Any CPU.ActiveCfg = PatchWindows|Any CPU 37 | {5C1AA190-DD28-41A1-8DD2-FEF3A9DEE28A}.PatchWindows|Any CPU.Build.0 = PatchWindows|Any CPU 38 | {5C1AA190-DD28-41A1-8DD2-FEF3A9DEE28A}.PatchWindows|x86.ActiveCfg = Patch|Any CPU 39 | {5C1AA190-DD28-41A1-8DD2-FEF3A9DEE28A}.PatchWindows|x86.Build.0 = Patch|Any CPU 40 | {FAE08B42-BBE6-4177-9110-6DDE1032C415}.MakeBaseImageWindows|Any CPU.ActiveCfg = PatchWindows|x86 41 | {FAE08B42-BBE6-4177-9110-6DDE1032C415}.MakeBaseImageWindows|x86.ActiveCfg = Debug|x86 42 | {FAE08B42-BBE6-4177-9110-6DDE1032C415}.MakeBaseImageWindows|x86.Build.0 = Debug|x86 43 | {FAE08B42-BBE6-4177-9110-6DDE1032C415}.PatchWindows|Any CPU.ActiveCfg = PatchWindows|x86 44 | {FAE08B42-BBE6-4177-9110-6DDE1032C415}.PatchWindows|Any CPU.Build.0 = PatchWindows|x86 45 | {FAE08B42-BBE6-4177-9110-6DDE1032C415}.PatchWindows|x86.ActiveCfg = Patch|x86 46 | {FAE08B42-BBE6-4177-9110-6DDE1032C415}.PatchWindows|x86.Build.0 = Patch|x86 47 | {19CFA5FC-A793-46EF-9F34-88F6C9D8C19F}.MakeBaseImageWindows|Any CPU.ActiveCfg = Debug|Any CPU 48 | {19CFA5FC-A793-46EF-9F34-88F6C9D8C19F}.MakeBaseImageWindows|Any CPU.Build.0 = Debug|Any CPU 49 | 
{19CFA5FC-A793-46EF-9F34-88F6C9D8C19F}.MakeBaseImageWindows|x86.ActiveCfg = Debug|Any CPU 50 | {19CFA5FC-A793-46EF-9F34-88F6C9D8C19F}.MakeBaseImageWindows|x86.Build.0 = Debug|Any CPU 51 | {19CFA5FC-A793-46EF-9F34-88F6C9D8C19F}.PatchWindows|Any CPU.ActiveCfg = Debug|Any CPU 52 | {19CFA5FC-A793-46EF-9F34-88F6C9D8C19F}.PatchWindows|Any CPU.Build.0 = Debug|Any CPU 53 | {19CFA5FC-A793-46EF-9F34-88F6C9D8C19F}.PatchWindows|x86.ActiveCfg = Debug|Any CPU 54 | {19CFA5FC-A793-46EF-9F34-88F6C9D8C19F}.PatchWindows|x86.Build.0 = Debug|Any CPU 55 | EndGlobalSection 56 | GlobalSection(SolutionProperties) = preSolution 57 | HideSolutionNode = FALSE 58 | EndGlobalSection 59 | GlobalSection(ExtensibilityGlobals) = postSolution 60 | SolutionGuid = {12C54BD4-0FFF-4AB2-8A9D-D85C13A493FD} 61 | EndGlobalSection 62 | EndGlobal 63 | -------------------------------------------------------------------------------- /TowerFallAi/TowerfallAiMod/Common/Util.cs: -------------------------------------------------------------------------------- 1 | using System; 2 | using System.IO; 3 | using System.Reflection; 4 | using Microsoft.Xna.Framework.Graphics; 5 | 6 | namespace TowerfallAi.Common { 7 | public static class Util { 8 | public static string PathCombine(params string[] paths) { 9 | return string.Join("/", paths); 10 | } 11 | 12 | public static string[] GetPathSegments(string path) { 13 | string[] segments = path.Split(new char[] { '/', '\\' }, StringSplitOptions.RemoveEmptyEntries); 14 | return segments; 15 | } 16 | /// Returns the part of path before its last '/' or '\' separator, or an empty string when it has none. 17 | public static string GetDirectoryPath(string path) { 18 | int i = Math.Max(path.LastIndexOf('/'), path.LastIndexOf('\\')); 19 | if (i < 0) return string.Empty; // No separator: Substring(0, -1) would throw ArgumentOutOfRangeException. 20 | return path.Substring(0, i); 21 | } 22 | 23 | public static void CreateDirectory(string path) { 24 | string[] segments = path.Split(new string[] { "/", "\\" }, StringSplitOptions.None); 25 | path = null; 26 | 27 | foreach (string segment in segments) { 28 | if (path == null) { 29 | path = segment; 30 | } else { 31 | path = PathCombine(path,
segment); 32 | } 33 | 34 | if (!Directory.Exists(path)) { 35 | if (string.IsNullOrEmpty(path)) continue; 36 | Directory.CreateDirectory(path); 37 | } 38 | } 39 | } 40 | 41 | public static void DeleteFolderIfExists(string path) { 42 | if (Directory.Exists(path)) { 43 | Directory.Delete(path, true); 44 | } 45 | } 46 | 47 | public static void DeleteFileIfExists(string path) { 48 | if (File.Exists(path)) { 49 | File.Delete(path); 50 | } 51 | } 52 | 53 | public static void ParseExecute(string p, out string program, out string args) { 54 | var space = new char[] { ' ' }; 55 | p = p.Trim(space); 56 | 57 | bool closeQuotes = false; 58 | for (int i = 0; i < p.Length; i++) { 59 | if (p[i] == ' ') { 60 | if (!closeQuotes) { 61 | program = p.Substring(0, i); 62 | program = program.Trim(space); 63 | args = p.Substring(i + 1, p.Length - i - 1); 64 | args = args.Trim(space); 65 | return; 66 | } 67 | } else if (p[i] == '\"') { 68 | closeQuotes = !closeQuotes; 69 | } 70 | } 71 | 72 | program = p; 73 | args = null; 74 | } 75 | 76 | public static Func GetFunction(string name, object inst) { 77 | var ptr = inst.GetType().GetMethod(name).MethodHandle.GetFunctionPointer(); 78 | return (Func)Activator.CreateInstance(typeof(Func), inst, ptr); 79 | } 80 | 81 | public static Action GetAction(string name, object inst) { 82 | var ptr = inst.GetType().GetMethod(name).MethodHandle.GetFunctionPointer(); 83 | return (Action)Activator.CreateInstance(typeof(Action), inst, ptr); 84 | } 85 | 86 | public static Action GetAction(string name, Type type, object inst, BindingFlags bindingFlags) { 87 | MethodInfo method = type.GetMethod(name, bindingFlags); 88 | if (method == null) { 89 | throw new Exception("Type {0} doesnt have method {1}".Format(type, name)); 90 | } 91 | var ptr = method.MethodHandle.GetFunctionPointer(); 92 | return (Action)Activator.CreateInstance(typeof(Action), inst, ptr); 93 | } 94 | 95 | public static Action GetAction(string name, Type type, object inst) { 96 | MethodInfo method =
type.GetMethod(name); 97 | if (method == null) { 98 | throw new Exception("Type {0} doesnt have method {1}".Format(type, name)); 99 | } 100 | var ptr = method.MethodHandle.GetFunctionPointer(); 101 | return (Action)Activator.CreateInstance(typeof(Action), inst, ptr); 102 | } 103 | 104 | public static object GetFieldValue(string name, Type type, object inst, BindingFlags bindingFlags) { 105 | FieldInfo info = type.GetField(name, bindingFlags); 106 | if (info == null) { 107 | throw new Exception("Type {0} doesnt have field {1}".Format(type, name)); 108 | } 109 | 110 | return info.GetValue(inst); 111 | } 112 | 113 | public static object GetPrivateFieldValue(string name, object inst) { 114 | return GetFieldValue(name, inst.GetType(), inst, BindingFlags.NonPublic | BindingFlags.Instance); 115 | } 116 | 117 | public static void GetData(RenderTarget2D renderTarget, byte[] buffer) { 118 | renderTarget.GetData(buffer); 119 | } 120 | } 121 | } 122 | -------------------------------------------------------------------------------- /python/end_to_end_test.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import random 3 | import time 4 | from typing import Any, Dict, List 5 | 6 | from common.logging_options import default_logging 7 | from common.test import assert_initial_state 8 | from towerfall import Connection, Towerfall 9 | 10 | default_logging() 11 | 12 | _VERBOSE = 1 13 | _TIMEOUT = 10 14 | 15 | 16 | ''' 17 | Runs the game with different reset configurations.
18 | ''' 19 | 20 | def main(): 21 | Towerfall.close_all() 22 | try: 23 | agent_count = 2 24 | towerfall = get_process(agent_count) 25 | run_test(towerfall, agent_count) 26 | 27 | agent_count = 1 28 | towerfall.send_config(get_config(agent_count)) 29 | run_test(towerfall, agent_count) 30 | 31 | agent_count = 3 32 | towerfall.send_config(get_config(agent_count)) 33 | run_test(towerfall, agent_count) 34 | 35 | logging.info('Test closing towerfall ...') 36 | towerfall.close() 37 | logging.info('Test closing towerfall successful.') 38 | 39 | agent_count = 4 40 | towerfall = get_process(agent_count) 41 | run_test(towerfall, agent_count) 42 | finally: 43 | Towerfall.close_all() 44 | 45 | 46 | def get_random_actions() -> str: 47 | s = '' 48 | p = 0.1 49 | keys = ['u', 'd', 'l', 'r', 'j', 'z', 's'] 50 | for key in keys: 51 | if random.random() < p: 52 | s += key 53 | return s 54 | 55 | 56 | def get_config(agent_count: int) -> dict[str, Any]: 57 | return dict( 58 | mode='sandbox', 59 | level='2', 60 | fps=999, 61 | solids=[[0] * 32]*14 + [[1]*32] + [[0] * 32]*9, 62 | agents=[dict(type='remote', team='blue')]*agent_count) 63 | 64 | 65 | def get_process(agent_count: int) -> Towerfall: 66 | return Towerfall( 67 | verbose=_VERBOSE, 68 | config=get_config(agent_count)) 69 | 70 | 71 | def join(towerfall: Towerfall, agent_count: int) -> list[Connection]: 72 | connections = [] 73 | for _ in range(agent_count): 74 | conn = towerfall.join(timeout=_TIMEOUT) 75 | conn.log_cap = 100 76 | connections.append(conn) 77 | return connections 78 | 79 | 80 | def reset(towerfall: Towerfall, agent_count: int) -> list[dict]: 81 | y = 120 82 | entities: List[Dict[str, Any]] = [dict(type='archer', pos=dict(x=0, y=y)) for i in range(agent_count)] 83 | entities.append(dict(type='slime', facing=-1, pos=dict(x=0, y=y))) 84 | for i, entity in enumerate(entities): 85 | entity['pos']['x'] = 10 + i * (300 / len(entities)) 86 | towerfall.send_reset(entities) 87 | return entities 88 | 89 | 90 | def
receive_init(connections: List[Connection]): 91 | if _VERBOSE >= 1: 92 | logging.info('receive_init') 93 | for connection in connections: 94 | # init 95 | state_init = connection.read_json() 96 | assert state_init['type'] == 'init', state_init['type'] 97 | connection.send_json(dict(type='result', success=True)) 98 | 99 | for connection in connections: 100 | # scenario 101 | state_scenario = connection.read_json() 102 | assert state_scenario['type'] == 'scenario', state_scenario['type'] 103 | connection.send_json(dict(type='result', success=True)) 104 | 105 | 106 | def receive_update(connections: List[Connection], expected_entities: List[dict], n_frames: int): 107 | now = time.perf_counter()  # Monotonic clock, so wall-clock adjustments cannot skew the fps figure. 108 | for _ in range(n_frames): 109 | for i_con, connection in enumerate(connections): 110 | # update 111 | state_update = connection.read_json() 112 | assert state_update['type'] == 'update', state_update['type'] 113 | if state_update['id'] == 0 and i_con == 0: 114 | assert_initial_state(expected_entities, state_update['entities']) 115 | connection.send_json(dict(type='actions', actions=get_random_actions(), id=state_update['id'])) 116 | dt = time.perf_counter() - now 117 | logging.info(f'fps: {n_frames/dt:.2f}') 118 | 119 | 120 | def run_many_resets(towerfall: Towerfall, agent_count: int, reset_count: int): 121 | connections = join(towerfall, agent_count) 122 | entities = reset(towerfall, agent_count) 123 | 124 | receive_init(connections) 125 | receive_update(connections, entities, n_frames=200) 126 | 127 | for _ in range(reset_count): 128 | entities = reset(towerfall, agent_count) 129 | receive_update(connections, entities, n_frames=20) 130 | 131 | 132 | def run_test(towerfall, agent_count): 133 | reset_count = 5 134 | logging.info(f'Test with {agent_count} agents ...') 135 | run_many_resets(towerfall, agent_count, reset_count) 136 | logging.info(f'Test with {agent_count} agents successful.') 137 | 138 | 139 | if __name__ == '__main__': 140 | main() 141 |
-------------------------------------------------------------------------------- /TowerFallAi/TowerfallAiMod/Core/RemoteConnection.cs: -------------------------------------------------------------------------------- 1 | using System; 2 | using System.Net.Sockets; 3 | using System.Text; 4 | using System.Threading; 5 | using System.Threading.Tasks; 6 | using TowerfallAi.Common; 7 | 8 | namespace TowerfallAi.Core { 9 | /// 10 | /// Implements read/write over sockets. 11 | /// 12 | public class RemoteConnection : IDisposable { 13 | private const int maxMessageSize = 1 << 16; 14 | private const int headerSize = 2; 15 | 16 | private Socket socket; 17 | private byte[] headerBuffer = new byte[headerSize]; 18 | private byte[] bodyBuffer = new byte[maxMessageSize]; 19 | private int bytesToReceive = 0; 20 | private AsyncQueue messages = new AsyncQueue(); 21 | 22 | private TaskCompletionSource disconnectionTcs = new TaskCompletionSource(); 23 | 24 | public RemoteConnection(Socket socket) { 25 | 26 | this.socket = socket; 27 | TaskEx.Run(() => { 28 | try { 29 | while (true) { 30 | if (!ReceiveExactly(socket, headerBuffer, headerSize)) { 31 | messages.EnqueueException(new SocketException((int)SocketError.ConnectionReset)); 32 | return; 33 | } 34 | // The 2-byte header is the big-endian body length; a zero-length body is valid. 35 | bytesToReceive = headerBuffer[0] << 8 | headerBuffer[1]; 36 | if (!ReceiveExactly(socket, bodyBuffer, bytesToReceive)) { 37 | messages.EnqueueException(new SocketException((int)SocketError.ConnectionReset)); 38 | return; 39 | } 40 | string response = Encoding.ASCII.GetString(bodyBuffer, 0, bytesToReceive); 41 | messages.Enqueue(response); 42 | } 43 | } catch (SocketException ex) { 44 | // Hand receive failures to readers instead of leaving them unobserved on this task. 45 | messages.EnqueueException(ex); 46 | } finally { 47 | disconnectionTcs.SetResult(true); 48 | } 49 | }); 50 | } 51 | 52 | public Task WaitDisconnectAsync() { 53 | return disconnectionTcs.Task; 54 | } 55 | 56 | public bool IsAlive() { 57 | return disconnectionTcs.Task.IsAlive(); 58 | } 59 | /* Reads exactly count bytes into buffer, looping because a single Receive may return fewer; returns false if the peer closed the connection first. */ private static bool ReceiveExactly(Socket s, byte[] buffer, int count) { int offset = 0; while (offset < count) { int read = s.Receive(buffer, offset, count - offset, SocketFlags.None); if (read == 0) return false; offset += read; } return true; } 60 | private void
ReceiveHeaderCallback(IAsyncResult result) { 61 | var bytesRead = socket.EndReceive(result); 62 | if (bytesRead == 0) { 63 | messages.EnqueueException(new SocketException((int)SocketError.ConnectionReset)); 64 | return; 65 | } 66 | 67 | if (bytesRead != 2) { 68 | Logger.Error($"Received unexpected number of bytes. bytesRead: {bytesRead}"); 69 | Dispose(); 70 | return; 71 | } 72 | 73 | bytesToReceive = headerBuffer[0] << 8 | headerBuffer[1]; 74 | if (bytesToReceive > maxMessageSize) { 75 | Logger.Error($"Message exceeds limit: {maxMessageSize}."); 76 | Dispose(); 77 | return; 78 | } 79 | 80 | socket.BeginReceive(bodyBuffer, 0, bytesToReceive, SocketFlags.None, ReceiveBodyCallback, null); 81 | } 82 | 83 | private void ReceiveBodyCallback(IAsyncResult result) { 84 | var bytesRead = socket.EndReceive(result); 85 | if (bytesRead != bytesToReceive) { 86 | Logger.Error($"Received unexpected number of bytes. bytesRead: {bytesRead}. bytesToReceive: {bytesToReceive}"); 87 | Dispose(); 88 | return; 89 | } 90 | 91 | string response = Encoding.ASCII.GetString(bodyBuffer, 0, bytesRead); 92 | messages.Enqueue(response); 93 | bytesToReceive = 0; 94 | socket.BeginReceive(headerBuffer, 0, headerSize, SocketFlags.None, ReceiveHeaderCallback, null); 95 | } 96 | // NOTE(review): ReceiveHeaderCallback/ReceiveBodyCallback are only referenced by each other in this file; the constructor's blocking loop does the actual reading. 97 | 98 | public async Task ReadAsync(TimeSpan timeout = default, CancellationToken cancellationToken = default) { 99 | string message = await messages.DequeueAsync(timeout, cancellationToken); 100 | return message; 101 | } 102 | 103 | public void Write(string text) { 104 | byte[] payload = Encoding.ASCII.GetBytes(text); 105 | int size = payload.Length; 106 | if (size >= maxMessageSize) { // The 2-byte header can encode at most maxMessageSize - 1 bytes. 107 | throw new Exception("Message exceeds limit: {0}.".Format(maxMessageSize)); 108 | } 109 | byte[] header = new Byte[2]; 110 | header[0] = (byte)(size >> 8); 111 | header[1] = (byte)(size & 0x00FF); 112 | socket.Send(header); 113 | socket.Send(payload); 114 | } 115 | 116 | public void Dispose() { 117 | if (socket != null && socket.IsBound) {
118 | socket.Shutdown(SocketShutdown.Both); 119 | socket.Close(); 120 | socket.Dispose(); 121 | socket = null; 122 | } 123 | } 124 | } 125 | } 126 | -------------------------------------------------------------------------------- /TowerFallAi/TowerfallAiMod/Common/Common.csproj: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | Debug 6 | AnyCPU 7 | {3A23E492-74B3-459B-9EDE-3B6DA7F17E39} 8 | Library 9 | Properties 10 | Common 11 | Common 12 | v4.0 13 | 512 14 | 15 | 16 | 17 | AnyCPU 18 | true 19 | full 20 | false 21 | bin\Debug\ 22 | DEBUG;TRACE 23 | prompt 24 | 4 25 | false 26 | 27 | 28 | AnyCPU 29 | pdbonly 30 | true 31 | bin\Release\ 32 | TRACE 33 | prompt 34 | 4 35 | false 36 | 37 | 38 | 39 | 40 | 41 | 42 | ..\..\Builder\Builder\packages\Microsoft.Data.Edm.5.6.4\lib\net40\Microsoft.Data.Edm.dll 43 | True 44 | 45 | 46 | ..\..\Builder\Builder\packages\Microsoft.Data.OData.5.6.4\lib\net40\Microsoft.Data.OData.dll 47 | True 48 | 49 | 50 | ..\..\Builder\Builder\packages\Microsoft.Data.Services.Client.5.6.4\lib\net40\Microsoft.Data.Services.Client.dll 51 | True 52 | 53 | 54 | ..\..\packages\Newtonsoft.Json.13.0.1\lib\net40\Newtonsoft.Json.dll 55 | 56 | 57 | 58 | 59 | ..\..\Builder\Builder\packages\System.Spatial.5.6.4\lib\net40\System.Spatial.dll 60 | True 61 | 62 | 63 | 64 | 65 | 66 | 67 | 68 | 69 | 70 | 71 | 72 | 73 | 74 | 75 | 76 | 77 | 78 | 79 | 80 | 81 | 82 | 83 | 90 | -------------------------------------------------------------------------------- /TowerFallAi/TowerfallAiMod/Mod/Entity.cs: -------------------------------------------------------------------------------- 1 | using System; 2 | using System.Collections.Generic; 3 | using Monocle; 4 | using Patcher; 5 | using TowerFall; 6 | using TowerfallAi.Api; 7 | using TowerfallAi.Common; 8 | using System.Reflection; 9 | using Microsoft.Xna.Framework; 10 | 11 | namespace TowerfallAi.Mod { 12 | [Patch] 13 | public class ModEntity : Entity { 14 | public override
void Update() { 15 | try { 16 | base.Update(); 17 | } catch (Exception ex) { 18 | throw new Exception(GetType().ToString(), ex); 19 | } 20 | } 21 | } 22 | 23 | public static class ExtEntity { 24 | static Dictionary ids = new Dictionary(); 25 | 26 | static Dictionary prevPositions = new Dictionary(); 27 | 28 | public static void Reset() { 29 | ids.Clear(); 30 | prevPositions.Clear(); 31 | } 32 | 33 | public static bool TryGetSpeed(object obj, out Vector2 speed) { 34 | var type = obj.GetType(); 35 | speed = new Vector2(); 36 | var field = type.GetField("Speed"); 37 | if (field != null) { 38 | field.GetValue(obj); // NOTE(review): the value is discarded and this method always returns false, so SetAiState always derives vel from position deltas. Intentional? 39 | } 40 | return false; 41 | } 42 | 43 | static float GetDecimal(float x) { 44 | return x - (float)(Math.Round(x)); 45 | } 46 | 47 | public static void SetAiState(Entity ent, StateEntity state) { 48 | if (ids.ContainsKey(ent)) { 49 | state.id = ids[ent]; 50 | } else { 51 | state.id = ids.Count; 52 | ids.Add(ent, state.id); 53 | } 54 | 55 | float x; 56 | float y; 57 | 58 | if (ent.Collider == null) { 59 | x = ent.Position.X; 60 | y = 240 - ent.Position.Y; 61 | } else { 62 | x = ent.CenterX; 63 | y = 240 - ent.CenterY; 64 | } 65 | 66 | if (ent is Actor) { 67 | Actor actor = (Actor)ent; 68 | x += GetDecimal(actor.ActualPosition.X); 69 | y += GetDecimal(240 - actor.ActualPosition.Y); 70 | } 71 | 72 | state.pos = new Vec2 { x = x, y = y }; 73 | 74 | Vec2 prevPos; 75 | Vector2 speed; 76 | 77 | if (TryGetSpeed(ent, out speed)) { 78 | state.vel = new Vec2 { 79 | x = speed.X, 80 | y = -speed.Y 81 | }; 82 | } else { 83 | if (prevPositions.TryGetValue(ent, out prevPos)) { 84 | state.vel = new Vec2 { 85 | x = state.pos.x - prevPos.x, 86 | y = state.pos.y - prevPos.y, 87 | }; 88 | } else { 89 | state.vel = new Vec2 { 90 | x = 0, 91 | y = 0, 92 | }; 93 | } 94 | prevPositions[ent] = state.pos; 95 | } 96 | 97 | if (ent.Collider == null) { 98 | state.size = new Vec2(); 99 | } else { 100 | state.size = new Vec2 { x = ent.Width, y = ent.Height }; 101 | } 102 | 103 | if (ent
is Enemy) { 104 | Enemy enemy = (Enemy)ent; 105 | 106 | state.isEnemy = true; 107 | state.canHurt = enemy.CanHurt; 108 | state.canBounceOn = enemy.CanBounceOn; 109 | state.isDead = enemy.IsDead(); 110 | state.facing = (int)enemy.Facing; 111 | state.state = ((string[])Util.GetFieldValue("names", typeof(Enemy), enemy, BindingFlags.NonPublic | BindingFlags.Instance))[enemy.State].FirstLower(); 112 | } 113 | } 114 | 115 | public static StateEntity GetState(Entity e) { 116 | return GetState(e, null); 117 | } 118 | 119 | public static StateEntity GetState(Entity ent, string type) { 120 | var state = new StateEntity { 121 | type = (type == null ? ent.GetType().Name : type).FirstLower() 122 | }; 123 | SetAiState(ent, state); 124 | return state; 125 | } 126 | 127 | public static StateItem GetStateItem(Entity ent, string itemType) { 128 | var state = new StateItem { 129 | type = Types.Item, 130 | itemType = itemType 131 | }; 132 | SetAiState(ent, state); 133 | return state; 134 | } 135 | 136 | public static StateEntity GetStateArrow(Entity ent) { 137 | Arrow arrow = ent as Arrow; 138 | var state = new StateArrow { 139 | type = Types.Arrow, 140 | arrowType = arrow.ArrowType.ToString().FirstLower(), 141 | }; 142 | SetAiState(arrow, state); 143 | state.state = arrow.State.ToString().FirstLower(); 144 | return state; 145 | } 146 | 147 | public static StateItem GetStateBombPickup(BombPickup ent) { 148 | var state = new StateItem { 149 | type = Types.Item, 150 | itemType = TypesItems.Bomb 151 | }; 152 | SetAiState(ent, state); 153 | return state; 154 | } 155 | } 156 | } 157 | -------------------------------------------------------------------------------- /python/gym_wrapper/base_env.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | from typing import Any, Dict, List, Optional, Tuple 3 | 4 | from common.entity import Entity, to_entities 5 | from gym import Env 6 | from numpy.typing import NDArray 7 |
from towerfall import Towerfall 8 | 9 | from .actions import Actions 10 | 11 | 12 | class TowerfallEnv(Env, ABC): 13 | ''' 14 | Interacts with the Towerfall.exe process to create an interface with the agent that follows the gym API. 15 | Inherit from this class to choose the appropriate observations and reward functions. 16 | 17 | params towerfall: The Towerfall instance to connect to. 18 | params actions: The actions that the agent can take. If None, the default actions are used. 19 | params record_path: The path to record the game to. If None, no recording is done. 20 | params verbose: The verbosity level. 0: no logging, 1: much logging. 21 | ''' 22 | def __init__(self, 23 | towerfall: Towerfall, 24 | actions: Optional[Actions] = None, 25 | record_path: Optional[str] = None, 26 | verbose: int = 0): 27 | self.towerfall = towerfall 28 | self.verbose = verbose 29 | self.connection = self.towerfall.join(timeout=5) 30 | self.connection.record_path = record_path 31 | if actions: 32 | self.actions = actions 33 | else: 34 | self.actions = Actions() 35 | self.action_space = self.actions.action_space 36 | self._draw_elems = [] 37 | self.is_init_sent = False 38 | 39 | def _send_reset(self): 40 | ''' 41 | Sends the reset instruction to the game. Overwrite this to change the starting conditions. 42 | ''' 43 | self.towerfall.send_reset() 44 | 45 | @abstractmethod 46 | def _post_reset(self) -> Tuple[NDArray, dict]: 47 | ''' 48 | Hook for a gym reset call. Subclass should populate and return the same as a reset in gym API. 49 | 50 | Returns: 51 | A tuple of (observation, info) 52 | ''' 53 | raise NotImplementedError 54 | 55 | @abstractmethod 56 | def _post_step(self) -> Tuple[NDArray, float, bool, dict]: 57 | ''' 58 | Hook for a gym step call. Subclass should populate this to return the same as a step in gym API. 
59 | 60 | Returns: 61 | A tuple of (observation, reward, done, info) 62 | ''' 63 | raise NotImplementedError 64 | 65 | def draws(self, draw_elem): 66 | ''' 67 | Draws an element on the screen. This is useful for debugging. 68 | ''' 69 | self._draw_elems.append(draw_elem) 70 | 71 | def reset(self) -> Tuple[NDArray, dict]: 72 | ''' 73 | Gym reset. This is called by the agent to reset the environment. 74 | ''' 75 | 76 | self._send_reset() 77 | if not self.is_init_sent: 78 | state_init = self.connection.read_json() 79 | assert state_init['type'] == 'init', state_init['type'] 80 | self.index = state_init['index'] 81 | self.connection.send_json(dict(type='result', success=True)) 82 | 83 | self.state_scenario = self.connection.read_json() 84 | assert self.state_scenario['type'] == 'scenario', self.state_scenario['type'] 85 | self.connection.send_json(dict(type='result', success=True)) 86 | self.is_init_sent = True 87 | else: 88 | self.connection.send_json(dict(type='actions', actions="", id=self.state_update['id'])) 89 | 90 | self.frame = 0 91 | self.state_update = self.connection.read_json() 92 | assert self.state_update['type'] == 'update', self.state_update['type'] 93 | self.entities = to_entities(self.state_update['entities']) 94 | self.me = self._get_own_archer(self.entities) 95 | 96 | return self._post_reset() 97 | 98 | def step(self, actions: NDArray) -> Tuple[NDArray, float, bool, object]: 99 | ''' 100 | Gym step. This is called by the agent to take an action in the environment. 
101 | ''' 102 | actions_str = self.actions.to_serialized_actions(actions) 103 | 104 | resp: Dict[str, Any] = dict( 105 | type='actions', 106 | actions=actions_str, 107 | id=self.state_update['id'] 108 | ) 109 | if self._draw_elems: 110 | resp['draws'] = self._draw_elems 111 | self.connection.send_json(resp) 112 | self._draw_elems.clear() 113 | self.state_update = self.connection.read_json() 114 | self.actions_str = actions_str 115 | assert self.state_update['type'] == 'update' 116 | self.entities = to_entities(self.state_update['entities']) 117 | self.me = self._get_own_archer(self.entities) 118 | return self._post_step() 119 | 120 | def _get_own_archer(self, entities: List[Entity]) -> Optional[Entity]: 121 | ''' 122 | Iterates over all entities to find the archer that matches the index specified in init. 123 | ''' 124 | for e in entities: 125 | if e.type == 'archer': 126 | if e['playerIndex'] == self.index: 127 | return e 128 | return None 129 | 130 | def render(self, mode='human'): 131 | ''' 132 | This is a no-op since the game is rendered independenly by MonoGame/XNA. 
133 | ''' 134 | pass -------------------------------------------------------------------------------- /TowerFallAi/TowerfallAiMod/Core/EntityCreator.cs: -------------------------------------------------------------------------------- 1 | using System; 2 | using System.Collections.Generic; 3 | using Microsoft.Xna.Framework; 4 | using Monocle; 5 | using Newtonsoft.Json.Linq; 6 | using TowerFall; 7 | using TowerfallAi.Api; 8 | using TowerfallAi.Common; 9 | 10 | namespace TowerfallAi.Core { 11 | public static class EntityCreator { 12 | private static Dictionary> creatorFuncs = new Dictionary> { 13 | { "bat", CreateBat }, 14 | { "batBomb", CreateBat }, 15 | { "batSuperBomb", CreateBat }, 16 | { "bird", CreateBat }, 17 | { "slime", CreateSlime }, 18 | { "birdman", CreateBirdman }, 19 | { "cultist", CreateCultist }, 20 | { "evilCrystal", CreateEvilCrystal }, 21 | { "exploder", CreateExploder }, 22 | { "flamingSkull", CreateFlamingSkull }, 23 | { "ghost", CreateGhost }, 24 | { "mole", CreateMole }, 25 | { "worm", CreateWorm }, 26 | }; 27 | 28 | public static Player CreatePlayer(JObject e, int playerIndex, Allegiance team) { 29 | var state = e.ToObject(); 30 | Vector2 pos = GetPos(state); 31 | var player = new Player( 32 | playerIndex, 33 | pos, 34 | team, 35 | team, 36 | PlayerInventory.Default, 37 | Player.HatStates.Normal, 38 | false, 39 | false, 40 | false); 41 | Logger.Info($"Create archer {playerIndex} at pos {pos}"); 42 | return player; 43 | } 44 | 45 | public static Entity CreateEntity(JObject e) { 46 | JToken typeToken = e.GetValue("type"); 47 | if (typeToken == null) throw new ArgumentException("Missing 'type' from reset entity."); 48 | string type = typeToken.Value(); 49 | Func creatorFunc; 50 | if (creatorFuncs.TryGetValue(type, out creatorFunc)) { 51 | return creatorFunc(e); 52 | } 53 | 54 | throw new Exception($"Unsupported entity type to be created: {type}"); 55 | } 56 | 57 | public static Bat CreateBat(JObject e) { 58 | var state = e.ToObject(); 59 | return 
new Bat( 60 | GetPos(state) + new Vector2(0, 2), 61 | GetFacing(state), 62 | ConversionTypes.BatTypes.GetA(state.type)); 63 | } 64 | 65 | public static Slime CreateSlime(JObject e) { 66 | var state = e.ToObject(); 67 | return new Slime( 68 | GetPos(state) + new Vector2(0, 5), 69 | GetFacing(state), 70 | state.subType == null ? Slime.SlimeColors.Green : ConversionTypes.SlimeTypes.GetA(state.subType)); 71 | } 72 | 73 | public static Birdman CreateBirdman(JObject e) { 74 | var state = e.ToObject(); 75 | return new Birdman(GetPos(state), GetFacing(state), darkWorld: false); 76 | } 77 | 78 | public static Cultist CreateCultist(JObject e) { 79 | var state = e.ToObject(); 80 | return new Cultist( 81 | GetPos(state) + new Vector2(0, 8), 82 | GetFacing(state), 83 | state.subType == null ? Cultist.CultistTypes.Normal : ConversionTypes.CultistTypes.GetA(state.subType)); 84 | } 85 | 86 | public static EvilCrystal CreateEvilCrystal(JObject e) { 87 | var state = e.ToObject(); 88 | Vector2 pos = GetPos(state); 89 | return new EvilCrystal( 90 | pos, 91 | GetFacing(state), 92 | state.subType == null ? EvilCrystal.CrystalColors.Red : ConversionTypes.CrystalTypes.GetA(state.subType), 93 | new Vector2[] { pos + new Vector2(50 * state.facing, 0) }); 94 | } 95 | 96 | public static Exploder CreateExploder(JObject e) { 97 | var state = e.ToObject(); 98 | Vector2 pos = GetPos(state); 99 | return new Exploder( 100 | pos, 101 | GetFacing(state), 102 | new Vector2[] { pos + new Vector2(50 * state.facing, 0) }); 103 | } 104 | 105 | public static FlamingSkull CreateFlamingSkull(JObject e) { 106 | var state = e.ToObject(); 107 | return new FlamingSkull(GetPos(state), GetFacing(state)); 108 | } 109 | 110 | public static Ghost CreateGhost(JObject e) { 111 | var state = e.ToObject(); 112 | Vector2 pos = GetPos(state); 113 | return new Ghost( 114 | pos, 115 | GetFacing(state), 116 | new Vector2[] { pos + new Vector2(50 * state.facing, 0) }, 117 | state.subType == null ? 
Ghost.GhostTypes.Blue : ConversionTypes.GhostTypes.GetA(state.subType)); 118 | } 119 | 120 | public static Mole CreateMole(JObject e) { 121 | var state = e.ToObject(); 122 | return new Mole(GetPos(state), GetFacing(state)); 123 | } 124 | 125 | public static Worm CreateWorm(JObject e) { 126 | var state = e.ToObject(); 127 | return new Worm(GetPos(state)); 128 | } 129 | 130 | private static Facing GetFacing(StateEntity e) { 131 | return e.facing > 0 ? Facing.Right : Facing.Left; 132 | } 133 | 134 | private static Vector2 GetPos(StateEntity e) { 135 | return new Vector2(e.pos.x, 240 - e.pos.y); 136 | } 137 | } 138 | } 139 | -------------------------------------------------------------------------------- /python/agents/simple_agent.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import random 3 | from typing import Any, Mapping 4 | 5 | from towerfall import Connection 6 | 7 | class SimpleAgent: 8 | ''' 9 | A minimal implementation of an agent that shows how to communicate with the game. 10 | It can be used in modes: 11 | - quest 12 | - versus 13 | - sandbox 14 | 15 | params connection: A connection to a Towerfall game. 16 | params attack_archers: If True, the agent will attack other neutral archers. 17 | ''' 18 | def __init__(self, connection: Connection, attack_archers: bool = False): 19 | self.state_init: Mapping[str, Any] = {} 20 | self.state_scenario: Mapping[str, Any] = {} 21 | self.state_update: Mapping[str, Any] = {} 22 | self.pressed = set() 23 | self.connection = connection 24 | self.attack_archers = attack_archers 25 | 26 | def act(self, game_state: Mapping[str, Any]): 27 | ''' 28 | Handles a game message. 29 | ''' 30 | 31 | # There are three main types to handle, 'init', 'scenario' and 'update'. 32 | # Check 'type' to handle each accordingly. 33 | if game_state['type'] == 'init': 34 | # 'init' is sent every time a match series starts. It contains information about the players and teams. 
35 | # The seed is based on the bot index so each bots acts differently. 36 | self.state_init = game_state 37 | random.seed(self.state_init['index']) 38 | # Acknowledge the init message. 39 | self.connection.send_json(dict(type='result', success=True)) 40 | return True 41 | 42 | if game_state['type'] == 'scenario': 43 | # 'scenario' informs your bot about the current state of the ground. Store this information 44 | # to use in all subsequent loops. (This example bot doesn't use the shape of the scenario) 45 | self.state_scenario = game_state 46 | # Acknowledge the scenario message. 47 | self.connection.send_json(dict(type='result', success=True)) 48 | return 49 | 50 | if game_state['type'] == 'update': 51 | # 'update' informs the state of entities in the map (players, arrows, enemies, etc). 52 | self.state_update = game_state 53 | 54 | # After receiving an 'update', your bot is expected to output string with the pressed buttons. 55 | # Each button is represented by a character: 56 | # r = right 57 | # l = left 58 | # u = up 59 | # d = down 60 | # j = jump 61 | # z = dash 62 | # s = shoot 63 | # The order of the characters are irrelevant. Any other character is ignored. Repeated characters are ignored. 64 | 65 | # This bot acts based on the position of the other player only. It 66 | # has a very random playstyle: 67 | # - Runs to the enemy when they are below. 68 | # - Runs away from the enemy when they are above. 69 | # - Shoots when in the same horizontal line. 70 | # - Dashes randomly. 71 | # - Jumps randomly. 72 | 73 | my_state = None 74 | enemy_state = None 75 | 76 | players = [] 77 | 78 | for state in self.state_update['entities']: 79 | if state['type'] == 'archer': 80 | players.append(state) 81 | if state['playerIndex'] == self.state_init['index']: 82 | my_state = state 83 | 84 | # If the agent is not present, it means it is dead. 85 | if my_state == None: 86 | # You are required to reply with actions, or the agent will get disconnected. 
87 | self.send_actions() 88 | return 89 | 90 | # Try to find an enemy archer. 91 | for state in players: 92 | if state['playerIndex'] == my_state['playerIndex']: 93 | continue 94 | if (self.attack_archers and state['team'] == 'neutral') or state['team'] != my_state['team']: 95 | enemy_state = state 96 | break 97 | 98 | # If no enemy archer is found, try to find another enemy. 99 | if not enemy_state: 100 | for state in self.state_update['entities']: 101 | if state['isEnemy']: 102 | enemy_state = state 103 | 104 | # If no enemy is found, means all are dead. 105 | if enemy_state == None: 106 | self.send_actions() 107 | return 108 | 109 | my_pos = my_state['pos'] 110 | enemy_pos = enemy_state['pos'] 111 | if enemy_pos['y'] >= my_pos['y'] and enemy_pos['y'] <= my_pos['y'] + 50: 112 | # Runs away if enemy is right above 113 | if my_pos['x'] < enemy_pos['x']: 114 | self.press('l') 115 | else: 116 | self.press('r') 117 | else: 118 | # Runs to enemy if they are below 119 | if my_pos['x'] < enemy_pos['x']: 120 | self.press('r') 121 | else: 122 | self.press('l') 123 | 124 | # If in the same line shoots, 125 | if abs(my_pos['y'] - enemy_pos['y']) < enemy_state['size']['y']: 126 | if random.randint(0, 1) == 0: 127 | self.press('s') 128 | 129 | # Presses dash in 1/10 of the frames. 130 | if random.randint(0, 9) == 0: 131 | self.press('z') 132 | 133 | # Presses jump in 1/20 of the frames. 134 | if random.randint(0, 19) == 0: 135 | self.press('j') 136 | 137 | # Respond the update frame with actions from this agent. 
138 | self.send_actions() 139 | 140 | def press(self, b): 141 | self.pressed.add(b) 142 | 143 | def send_actions(self): 144 | assert self.state_update 145 | self.connection.send_json(dict( 146 | type = 'actions', 147 | actions = ''.join(self.pressed), 148 | id = self.state_update['id'] 149 | )) 150 | self.pressed.clear() 151 | -------------------------------------------------------------------------------- /TowerFallAi/TowerfallAiModTests/AsyncQueueTest.cs: -------------------------------------------------------------------------------- 1 | using System; 2 | using System.Collections.Generic; 3 | using System.Threading; 4 | using System.Threading.Tasks; 5 | using NUnit.Framework; 6 | using TowerfallAi.Common; 7 | 8 | namespace TowerfallAiTests { 9 | [TestFixture] 10 | public class AsyncQueueTest { 11 | class TestException : Exception { } 12 | 13 | [Test] 14 | public void MultiEnqueueDequeueTest() { 15 | // Spawns many threads to enqueue and dequeue from AsyncQueue concurrently. 16 | var asyncQueue = new AsyncQueue(); 17 | 18 | // This is used to block the threads so they are not re 19 | var enqueueAllEvent = new ManualResetEvent(false); 20 | 21 | int messageCount = 10; 22 | int workerThreads; 23 | int completionPortThreads; 24 | ThreadPool.GetMinThreads(out workerThreads, out completionPortThreads); 25 | ThreadPool.SetMinThreads(workerThreads + messageCount, completionPortThreads); 26 | 27 | TaskEx.Run(() => { 28 | for (int i = 0; i < messageCount; i++) { 29 | TaskEx.Run(() => { 30 | asyncQueue.Enqueue(Thread.CurrentThread.ManagedThreadId); 31 | if (i == messageCount - 1) { 32 | enqueueAllEvent.Set(); 33 | } else { 34 | enqueueAllEvent.WaitOne(); 35 | } 36 | }); 37 | } 38 | }); 39 | 40 | List> dequeueTasks = new List>(); 41 | var dequeueTask = TaskEx.Run(() => { 42 | for (int i = 0; i < messageCount; i++) { 43 | dequeueTasks.Add(asyncQueue.DequeueAsync()); 44 | } 45 | }); 46 | 47 | dequeueTask.Wait(); 48 | var threadIds = new HashSet(); 49 | foreach (Task task in 
dequeueTasks) { 50 | threadIds.Add(task.Result); 51 | } 52 | 53 | Assert.AreEqual(messageCount, threadIds.Count); 54 | } 55 | 56 | [Test] 57 | public void EnqueueExceptionTest() { 58 | var asyncQueue = new AsyncQueue(); 59 | var dequeueTask = asyncQueue.DequeueAsync(); 60 | 61 | asyncQueue.EnqueueException(new TestException()); 62 | 63 | AssertException(dequeueTask, typeof(TestException)); 64 | } 65 | 66 | [Test] 67 | public void CancellationTest() { 68 | var asyncQueue = new AsyncQueue(); 69 | CancellationTokenSource cancelTs = new CancellationTokenSource(); 70 | 71 | var dequeueTask = asyncQueue.DequeueAsync(cancellationToken: cancelTs.Token); 72 | 73 | AssertAlive(dequeueTask); 74 | 75 | Task dependentTask = TaskEx.Run(async () => { 76 | return await dequeueTask + 1; 77 | }); 78 | 79 | AssertAlive(dependentTask); 80 | 81 | cancelTs.Cancel(); 82 | 83 | AssertException(dependentTask, typeof(TaskCanceledException)); 84 | AssertException(dequeueTask, typeof(TaskCanceledException)); 85 | 86 | Assert.AreEqual(TaskStatus.Canceled, dependentTask.Status); 87 | Assert.AreEqual(TaskStatus.Canceled, dequeueTask.Status); 88 | } 89 | 90 | [Test] 91 | public void EnqueueFirstTest() { 92 | var asyncQueue = new AsyncQueue(); 93 | 94 | for (int i = 0; i < 10; i++) { 95 | asyncQueue.Enqueue(i); 96 | } 97 | 98 | for (int i = 0; i < 10; i++) { 99 | Assert.AreEqual(i, asyncQueue.DequeueAsync().Result); 100 | } 101 | } 102 | 103 | [Test] 104 | public void EnqueueLastCheckOrderTest() { 105 | var asyncQueue = new AsyncQueue(); 106 | 107 | var dequeueTasks = new List>(); 108 | for (int i = 0; i < 10; i++) { 109 | dequeueTasks.Add(asyncQueue.DequeueAsync()); 110 | } 111 | 112 | for (int i = 0; i < 10; i++) { 113 | asyncQueue.Enqueue(i); 114 | } 115 | 116 | for (int i = 0; i < 10; i++) { 117 | Assert.AreEqual(i, dequeueTasks[i].Result); 118 | } 119 | } 120 | 121 | [Test] 122 | public void CloseTest() { 123 | var asyncQueue = new AsyncQueue(); 124 | asyncQueue.Enqueue(42); 125 | 
asyncQueue.Close(); 126 | 127 | Assert.AreEqual(42, asyncQueue.DequeueAsync().Result); 128 | 129 | Assert.Throws(() => asyncQueue.Enqueue(10)); 130 | Assert.Throws(() => asyncQueue.DequeueAsync()); 131 | Assert.Throws(() => asyncQueue.EnqueueException(new TestException())); 132 | } 133 | 134 | [Test] 135 | public void TimeoutTest() { 136 | var asyncQueue = new AsyncQueue(); 137 | 138 | int delayTimeMs = 50; 139 | var dequeueTask = asyncQueue.DequeueAsync(new TimeSpan(0, 0, 0, 0, delayTimeMs)); 140 | asyncQueue.Enqueue(42); 141 | int result = dequeueTask.Result; 142 | Assert.AreEqual(42, result); 143 | 144 | // This is necessary to make sure nothing weird happens after timeout period. 145 | Thread.Sleep(delayTimeMs); 146 | 147 | delayTimeMs = 50; 148 | dequeueTask = asyncQueue.DequeueAsync(new TimeSpan(0, 0, 0, 0, delayTimeMs)); 149 | var dependentTask = TaskEx.Run(async () => { 150 | return await dequeueTask; 151 | }); 152 | 153 | AssertAlive(dequeueTask); 154 | 155 | TaskEx.WhenAny(dependentTask, TaskEx.Delay(2*delayTimeMs)).Wait(); 156 | 157 | Assert.AreEqual(TaskStatus.Faulted, dequeueTask.Status); 158 | Assert.AreEqual(TaskStatus.Faulted, dependentTask.Status); 159 | 160 | AssertException(dequeueTask, typeof(TimeoutException)); 161 | AssertException(dequeueTask, typeof(TimeoutException)); 162 | } 163 | 164 | private void AssertAlive(Task task) { 165 | Assert.AreNotEqual(TaskStatus.RanToCompletion, task.Status); 166 | Assert.AreNotEqual(TaskStatus.Faulted, task.Status); 167 | Assert.AreNotEqual(TaskStatus.Canceled, task.Status); 168 | } 169 | 170 | private void AssertException(Task task, Type expectedExceptionType) { 171 | try { 172 | task.Wait(); 173 | Assert.Fail("Exception was not thrown."); 174 | } catch (AggregateException aggregateException) { 175 | Assert.AreEqual(1, aggregateException.InnerExceptions.Count); 176 | Assert.AreEqual(expectedExceptionType, aggregateException.InnerException.GetType()); 177 | } 178 | } 179 | } 180 | } 181 | 
-------------------------------------------------------------------------------- /TowerFallAi/TowerfallAiModTests/TowerfallAiModTests.csproj: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | Debug 8 | AnyCPU 9 | {19CFA5FC-A793-46EF-9F34-88F6C9D8C19F} 10 | Library 11 | Properties 12 | TowerfallAiModTests 13 | TowerfallAiModTests 14 | v4.0 15 | 512 16 | {3AC096D0-A1C2-E12C-1390-A8335801FDAB};{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC} 17 | 15.0 18 | $(MSBuildExtensionsPath32)\Microsoft\VisualStudio\v$(VisualStudioVersion) 19 | $(ProgramFiles)\Common Files\microsoft shared\VSTT\$(VisualStudioVersion)\UITestExtensionPackages 20 | False 21 | UnitTest 22 | 23 | 24 | 25 | 26 | 27 | true 28 | full 29 | false 30 | bin\Debug\ 31 | DEBUG;TRACE 32 | prompt 33 | 4 34 | AnyCPU 35 | 36 | 37 | pdbonly 38 | true 39 | bin\Release\ 40 | TRACE 41 | prompt 42 | 4 43 | 44 | 45 | 46 | ..\packages\Microsoft.Bcl.Async.1.0.168\lib\net40\Microsoft.Threading.Tasks.dll 47 | 48 | 49 | ..\packages\Microsoft.Bcl.Async.1.0.168\lib\net40\Microsoft.Threading.Tasks.Extensions.dll 50 | 51 | 52 | ..\packages\Microsoft.Bcl.Async.1.0.168\lib\net40\Microsoft.Threading.Tasks.Extensions.Desktop.dll 53 | 54 | 55 | ..\packages\NUnit.3.13.3\lib\net40\nunit.framework.dll 56 | 57 | 58 | 59 | 60 | ..\packages\Microsoft.Bcl.1.1.8\lib\net40\System.IO.dll 61 | 62 | 63 | 64 | ..\packages\Microsoft.Bcl.1.1.8\lib\net40\System.Runtime.dll 65 | 66 | 67 | ..\packages\Microsoft.Bcl.1.1.8\lib\net40\System.Threading.Tasks.dll 68 | 69 | 70 | 71 | 72 | 73 | 74 | 75 | 76 | 77 | 78 | 79 | 80 | {5c1aa190-dd28-41a1-8dd2-fef3a9dee28a} 81 | TowerfallAiMod 82 | 83 | 84 | 85 | 86 | 87 | 88 | This project references NuGet package(s) that are missing on this computer. Use NuGet Package Restore to download them. For more information, see http://go.microsoft.com/fwlink/?LinkID=322105. The missing file is {0}. 
89 | 90 | 91 | 92 | 93 | 94 | 95 | 96 | 97 | 98 | -------------------------------------------------------------------------------- /TowerFallAi/TowerfallAiMod/Common/RemoteConnection.cs: -------------------------------------------------------------------------------- 1 | //using Newtonsoft.Json; 2 | //using System; 3 | //using System.Collections.Generic; 4 | //using System.Linq; 5 | //using System.Net; 6 | //using System.Net.Sockets; 7 | //using System.Text; 8 | //using System.Threading; 9 | //using System.Threading.Tasks; 10 | 11 | //namespace Common { 12 | // public class RemoteConnection : IDisposable { 13 | // Socket listener; 14 | // Socket senderSocket; 15 | // List inputLines = new List(); 16 | // Semaphore mutexInput = new Semaphore(0, int.MaxValue); 17 | 18 | // Exception receiveException; 19 | 20 | // public bool IsConnected { 21 | // get { 22 | // if (this.senderSocket == null) return false; 23 | 24 | // return this.senderSocket.Connected; 25 | // } 26 | // } 27 | 28 | // public void StartListening(string ip, int port) { 29 | // SocketPermission permission = new SocketPermission( 30 | // NetworkAccess.Accept, 31 | // TransportType.Tcp, 32 | // "", 33 | // SocketPermission.AllPorts); 34 | 35 | // permission.Demand(); 36 | 37 | // IPEndPoint ipEndPoint = GetIpEndpoint(ip, port); 38 | 39 | // listener = new Socket( 40 | // ipEndPoint.AddressFamily, 41 | // SocketType.Stream, 42 | // ProtocolType.Tcp); 43 | 44 | // listener.Bind(ipEndPoint); 45 | 46 | // listener.Listen(1); 47 | 48 | // AsyncCallback aCallback = new AsyncCallback(AcceptCallback); 49 | // listener.BeginAccept(aCallback, listener); 50 | // } 51 | 52 | // public void Disconnect() { 53 | // if (listener != null) { 54 | // listener.Close(); 55 | // listener.Dispose(); 56 | // listener = null; 57 | // } 58 | 59 | // if (senderSocket != null) { 60 | // senderSocket.Close(); 61 | // senderSocket.Dispose(); 62 | // senderSocket = null; 63 | // } 64 | // } 65 | 66 | // private IPEndPoint 
GetIpEndpoint(string ip, int port) { 67 | // return new IPEndPoint(IPAddress.Parse(ip), port); 68 | // } 69 | 70 | // public void Connect(string ip, int port) { 71 | // SocketPermission permission = new SocketPermission( 72 | // NetworkAccess.Connect, 73 | // TransportType.Tcp, 74 | // "", 75 | // SocketPermission.AllPorts); 76 | 77 | // permission.Demand(); 78 | 79 | // IPEndPoint ipEndPoint = GetIpEndpoint(ip, port); 80 | 81 | // senderSocket = new Socket( 82 | // ipEndPoint.AddressFamily, 83 | // SocketType.Stream, 84 | // ProtocolType.Tcp); 85 | 86 | // senderSocket.NoDelay = false; 87 | // senderSocket.Connect(ipEndPoint); 88 | 89 | // byte[] buffer = new byte[1024]; 90 | // object[] obj = new object[2]; 91 | // obj[0] = buffer; 92 | // obj[1] = senderSocket; 93 | 94 | // senderSocket.BeginReceive( 95 | // buffer, 96 | // 0, 97 | // buffer.Length, 98 | // SocketFlags.None, 99 | // new AsyncCallback(ReceiveCallback), 100 | // obj); 101 | // } 102 | 103 | // public void AcceptCallback(IAsyncResult ar) { 104 | // byte[] buffer = new byte[1024]; 105 | // listener = (Socket)ar.AsyncState; 106 | // senderSocket = listener.EndAccept(ar); 107 | // senderSocket.NoDelay = false; 108 | 109 | // object[] obj = new object[2]; 110 | // obj[0] = buffer; 111 | // obj[1] = senderSocket; 112 | 113 | // senderSocket.BeginReceive( 114 | // buffer, 115 | // 0, 116 | // buffer.Length, 117 | // SocketFlags.None, 118 | // new AsyncCallback(ReceiveCallback), 119 | // obj); 120 | // } 121 | 122 | // public Task ReadLine() { 123 | // //return Task.Run(() => { 124 | // // mutexInput.WaitOne(); 125 | 126 | // // if (receiveException != null) { 127 | // // throw new Exception("Failed to receive", receiveException); 128 | // // } 129 | 130 | // // lock (inputLines) { 131 | // // string content = inputLines.First(); 132 | // // inputLines.RemoveAt(0); 133 | // // return content; 134 | // // } 135 | // //}); 136 | // return null; 137 | // } 138 | 139 | // public async Task ReadAs() { 140 | // 
//string content = await this.ReadLine(); 141 | // //T obj = JsonConvert.DeserializeObject(content); 142 | // //return obj; 143 | // return default(T); 144 | // } 145 | 146 | // public bool HasData() { 147 | // return inputLines.Count > 0; 148 | // } 149 | 150 | // public void Write(string text) { 151 | // byte[] byteData = Encoding.Unicode.GetBytes(text); 152 | // senderSocket.Send(byteData); 153 | // } 154 | 155 | // public void WriteLine(string text) { 156 | // byte[] byteData = Encoding.Unicode.GetBytes(text + '\n'); 157 | // senderSocket.Send(byteData); 158 | // } 159 | 160 | // public void ReceiveCallback(IAsyncResult ar) { 161 | // try { 162 | // object[] obj = new object[2]; 163 | // obj = (object[])ar.AsyncState; 164 | 165 | // byte[] buffer = (byte[])obj[0]; 166 | 167 | // var handler = (Socket)obj[1]; 168 | 169 | // StringBuilder content = new StringBuilder(); 170 | 171 | // int bytesRead = handler.EndReceive(ar); 172 | 173 | // bool keepReading = bytesRead > 0; 174 | 175 | // while (keepReading) { 176 | // string partialMessage = Encoding.Unicode.GetString(buffer, 0, 177 | // bytesRead); 178 | // keepReading = partialMessage[partialMessage.Length - 1] != '\n'; 179 | 180 | // int start = 0; 181 | // int i = partialMessage.IndexOf('\n'); 182 | // int length = partialMessage.Length; 183 | // while (i < partialMessage.Length && i > -1) { 184 | // lock (inputLines) { 185 | // content.Append(partialMessage.Substring(start, i - start)); 186 | // inputLines.Add(content.ToString()); 187 | // content.Clear(); 188 | // mutexInput.Release(1); 189 | // //Console.WriteLine(mutexInput.Release(1)); 190 | // } 191 | 192 | // length -= i - start + 1; 193 | // start = i + 1; 194 | // i = partialMessage.IndexOf('\n', start); 195 | // } 196 | 197 | // if (partialMessage.Length > length) { 198 | // content.Append(partialMessage.Substring(start, length)); 199 | // } else { 200 | // content.Append(partialMessage); 201 | // } 202 | 203 | // if (keepReading) { 204 | // bytesRead 
= senderSocket.Receive( 205 | // buffer, 206 | // 0, 207 | // buffer.Length, 208 | // SocketFlags.None); 209 | // } 210 | // } 211 | 212 | // senderSocket.BeginReceive( 213 | // buffer, 214 | // 0, 215 | // buffer.Length, 216 | // SocketFlags.None, 217 | // new AsyncCallback(ReceiveCallback), 218 | // obj); 219 | // } catch (Exception ex) { 220 | // receiveException = ex; 221 | // mutexInput.Release(1); 222 | // } 223 | // } 224 | 225 | // public void Dispose() { 226 | // if (senderSocket != null) { 227 | // senderSocket.Close(); 228 | // senderSocket.Dispose(); 229 | // senderSocket = null; 230 | // } 231 | 232 | // if (listener != null) { 233 | // listener.Close(); 234 | // listener.Dispose(); 235 | // listener = null; 236 | // } 237 | // } 238 | // } 239 | //} 240 | -------------------------------------------------------------------------------- /python/towerfall/towerfall.py: -------------------------------------------------------------------------------- 1 | import json 2 | import logging 3 | import os 4 | import random 5 | import signal 6 | import time 7 | from io import TextIOWrapper 8 | from typing import Any, Callable, Dict, List, Mapping, Optional 9 | 10 | import psutil 11 | from psutil import Popen 12 | 13 | from .connection import Connection 14 | 15 | _DEFAULT_STEAM_PATH_WINDOWS = 'C:/Program Files (x86)/Steam/steamapps/common/TowerFall' 16 | _ENV_TOWERFALL_PATH = 'TOWERFALL_PATH' 17 | 18 | class TowerfallError(Exception): 19 | pass 20 | 21 | class Towerfall: 22 | ''' 23 | Creates or reuses a Towerfall game process. 24 | 25 | params towerfall_path: The parent path where Towerfall.exe is located. 26 | params timeout: The timeout for the management API (Config, Reset). 27 | params verbose: The verbosity level. 0: no logging, 1: much logging. 
28 | ''' 29 | def __init__(self, 30 | config: Mapping[str, Any] = {}, 31 | towerfall_path: str = _DEFAULT_STEAM_PATH_WINDOWS, 32 | timeout: float = 2, 33 | verbose: int = 0): 34 | if not os.path.exists(towerfall_path): 35 | env_towerfall_path = os.environ.get(_ENV_TOWERFALL_PATH) 36 | if env_towerfall_path: 37 | if not os.path.exists(env_towerfall_path): 38 | raise TowerfallError(f'\n\nInstallation path defined in env variable {_ENV_TOWERFALL_PATH} does not exist: {env_towerfall_path}. Make sure {_ENV_TOWERFALL_PATH} is set to the installation path of Towerfall.') 39 | towerfall_path = env_towerfall_path 40 | else: 41 | raise TowerfallError(f'\n\nThe default installation path does not exist: {towerfall_path}. You have 2 options:\n a) Set env variable {_ENV_TOWERFALL_PATH} with the installation path.\n b) Set the towerfall_path parameter in Towerfall constructor with the installation path.') 42 | self.config: Mapping[str, Any] = config 43 | self.towerfall_path = towerfall_path 44 | self.towerfall_path_exe = os.path.join(self.towerfall_path, 'TowerFall.exe') 45 | self.pool_name = 'default' 46 | self.pool_path = os.path.join(self.towerfall_path, 'aimod', 'pools', self.pool_name) 47 | self.timeout = timeout 48 | self.verbose = verbose 49 | tries = 0 50 | while True: 51 | self.port = self._attain_game_port() 52 | 53 | try: 54 | self.open_connection = Connection(self.port, timeout=timeout, verbose=verbose) 55 | self.send_config(config) 56 | break 57 | except TowerfallError: 58 | if tries > 3: 59 | raise TowerfallError('Could not config a Towerfall process.') 60 | tries += 1 61 | 62 | def join(self, timeout: float = 2, verbose: int = 0) -> Connection: 63 | ''' 64 | Joins a towerfall game. 65 | 66 | params timeout: Timeout in seconds to wait for a response. The same timeout will be used on calls to get the observations. 67 | 68 | returns: A connection to a Towerfall game. This should be used by the agent to interact with the game. 
69 | ''' 70 | connection = Connection(self.port, timeout=timeout, verbose=verbose) 71 | connection.send_json(dict(type='join')) 72 | response = connection.read_json() 73 | if response['type'] != 'result': 74 | raise TowerfallError(f'Unexpected response type: {response["type"]}') 75 | if not response['success']: 76 | raise TowerfallError(f'Failed to join the game. Port: {self.port}, Response: {response["message"]}') 77 | self._try_log(logging.info, f'Successfully joined the game. Port: {self.port}') 78 | return connection 79 | 80 | def send_reset(self, entities: Optional[List[Dict[str, Any]]] = None): 81 | ''' 82 | Sends a game reset. This will recreate the entities in the game in the same scenario. To change the scenario, use send_config. 83 | 84 | params entities: The entities to reset. If None, the entities specified in the last reset will be used. 85 | ''' 86 | 87 | response = self.send_request_json(dict(type='reset', entities=entities)) 88 | if response['type'] != 'result': 89 | raise TowerfallError(f'Unexpected response type: {response["type"]}') 90 | if not response['success']: 91 | raise TowerfallError(f'Failed to reset the game. Port: {self.port}, Response: {response["message"]}') 92 | self._try_log(logging.info, f'Successfully reset the game. Port: {self.port}') 93 | 94 | def send_config(self, config = None): 95 | ''' 96 | Sends a game configuration. This will restart the session of the game in the specified scenario and specified number of agents. 97 | 98 | params config: The configuration to send. If None, the configuration specified in the last config will be used. 99 | ''' 100 | if config: 101 | self.config = config 102 | else: 103 | config = self.config 104 | 105 | response = self.send_request_json(dict(type='config', config=config)) 106 | if response['type'] != 'result': 107 | raise TowerfallError(f'Unexpected response type: {response["type"]}') 108 | if not response['success']: 109 | raise TowerfallError(f'Failed to configure the game. 
Port: {self.port}, Response: {response["message"]}') 110 | self.config = config 111 | 112 | def send_request_json(self, obj: Mapping[str, Any]): 113 | self.open_connection.send_json(obj) 114 | return self.open_connection.read_json() 115 | 116 | @classmethod 117 | def close_all(cls): 118 | ''' 119 | Closes all Towerfall processes. 120 | ''' 121 | logging.info('Closing all TowerFall.exe processes...') 122 | for process in psutil.process_iter(attrs=['pid', 'name']): 123 | # logging.info(f'Checking process {process.pid} {process.name()}') 124 | if process.name() != 'TowerFall.exe': 125 | continue 126 | try: 127 | logging.info(f'Killing process {process.pid}...') 128 | os.kill(process.pid, signal.SIGTERM) 129 | except Exception as ex: 130 | logging.error(f'Failed to kill process {process.pid}: {ex}') 131 | continue 132 | 133 | def close(self): 134 | ''' 135 | Close the management connection. This will free the Towerfall process to be used by other clients. 136 | ''' 137 | self.open_connection.close() 138 | 139 | def _attain_game_port(self) -> int: 140 | metadata = self._find_compatible_metadata() 141 | 142 | if not metadata: 143 | self._try_log(logging.info, f'Starting new process from {self.towerfall_path_exe}.') 144 | pargs = [self.towerfall_path_exe] 145 | Popen(pargs, cwd=self.towerfall_path) 146 | 147 | tries = 0 148 | self._try_log(logging.info, f'Waiting for available process.') 149 | while not metadata and tries < 10: 150 | time.sleep(2) 151 | metadata = self._find_compatible_metadata() 152 | tries += 1 153 | if not metadata: 154 | raise TowerfallError('Could not find or create a Towerfall process.') 155 | 156 | return metadata['port'] 157 | 158 | def _find_compatible_metadata(self) -> Optional[Mapping[str, Any]]: 159 | if not os.path.exists(self.pool_path): 160 | return None 161 | dirs = list(os.listdir(self.pool_path)) 162 | random.shuffle(dirs) 163 | for file_name in os.listdir(self.pool_path): 164 | try: 165 | pid = int(file_name) 166 | psutil.Process(pid) 
167 | except (ValueError, psutil.NoSuchProcess): 168 | os.remove(os.path.join(self.pool_path, file_name)) 169 | continue 170 | try: 171 | with open(os.path.join(self.pool_path, file_name), 'r') as file: 172 | try: 173 | metadata = Towerfall._load_metadata(file) 174 | except (ValueError, json.JSONDecodeError, FileNotFoundError) as ex: 175 | self._try_log(logging.warning, f'Invalid metadata file {file_name}. Exception: {ex}') 176 | continue 177 | return metadata 178 | except PermissionError as ex: 179 | self._try_log(logging.warning, f'Failed to access {file_name}. Exception: {ex}') 180 | continue 181 | return None 182 | 183 | @staticmethod 184 | def _load_metadata(file: TextIOWrapper) -> Mapping[str, Any]: 185 | metadata = json.load(file) 186 | if 'port' not in metadata: 187 | raise ValueError('Port not found in metadata.') 188 | try: 189 | metadata['port'] = int(metadata['port']) 190 | except ValueError: 191 | raise ValueError(f'Port is not an integer. Port: {metadata["port"]}') 192 | 193 | if 'fastrun' not in metadata: 194 | metadata['fastrun'] = False 195 | if 'nographics' not in metadata: 196 | metadata['nographics'] = False 197 | return metadata 198 | 199 | def _try_log(self, log_fn: Callable[[str], None], message: str): 200 | if self.verbose > 0: 201 | log_fn(message) 202 | 203 | -------------------------------------------------------------------------------- /TowerFallAi/.gitignore: -------------------------------------------------------------------------------- 1 | ## Ignore Visual Studio temporary files, build results, and 2 | ## files generated by popular Visual Studio add-ons. 
3 | ## 4 | ## Get latest from https://github.com/github/gitignore/blob/main/VisualStudio.gitignore 5 | 6 | # User-specific files 7 | *.rsuser 8 | *.suo 9 | *.user 10 | *.userosscache 11 | *.sln.docstates 12 | 13 | # User-specific files (MonoDevelop/Xamarin Studio) 14 | *.userprefs 15 | 16 | # Mono auto generated files 17 | mono_crash.* 18 | 19 | # Build results 20 | [Dd]ebug/ 21 | [Dd]ebugPublic/ 22 | [Rr]elease/ 23 | [Rr]eleases/ 24 | x64/ 25 | x86/ 26 | [Ww][Ii][Nn]32/ 27 | [Aa][Rr][Mm]/ 28 | [Aa][Rr][Mm]64/ 29 | bld/ 30 | [Bb]in/ 31 | [Oo]bj/ 32 | [Ll]og/ 33 | [Ll]ogs/ 34 | 35 | # Visual Studio 2015/2017 cache/options directory 36 | .vs/ 37 | # Uncomment if you have tasks that create the project's static files in wwwroot 38 | #wwwroot/ 39 | 40 | # Visual Studio 2017 auto generated files 41 | Generated\ Files/ 42 | 43 | # MSTest test Results 44 | [Tt]est[Rr]esult*/ 45 | [Bb]uild[Ll]og.* 46 | 47 | # NUnit 48 | *.VisualState.xml 49 | TestResult.xml 50 | nunit-*.xml 51 | 52 | # Build Results of an ATL Project 53 | [Dd]ebugPS/ 54 | [Rr]eleasePS/ 55 | dlldata.c 56 | 57 | # Benchmark Results 58 | BenchmarkDotNet.Artifacts/ 59 | 60 | # .NET Core 61 | project.lock.json 62 | project.fragment.lock.json 63 | artifacts/ 64 | 65 | # ASP.NET Scaffolding 66 | ScaffoldingReadMe.txt 67 | 68 | # StyleCop 69 | StyleCopReport.xml 70 | 71 | # Files built by Visual Studio 72 | *_i.c 73 | *_p.c 74 | *_h.h 75 | *.ilk 76 | *.meta 77 | *.obj 78 | *.iobj 79 | *.pch 80 | *.pdb 81 | *.ipdb 82 | *.pgc 83 | *.pgd 84 | *.rsp 85 | *.sbr 86 | *.tlb 87 | *.tli 88 | *.tlh 89 | *.tmp 90 | *.tmp_proj 91 | *_wpftmp.csproj 92 | *.log 93 | *.tlog 94 | *.vspscc 95 | *.vssscc 96 | .builds 97 | *.pidb 98 | *.svclog 99 | *.scc 100 | *.exe 101 | 102 | # Chutzpah Test files 103 | _Chutzpah* 104 | 105 | # Visual C++ cache files 106 | ipch/ 107 | *.aps 108 | *.ncb 109 | *.opendb 110 | *.opensdf 111 | *.sdf 112 | *.cachefile 113 | *.VC.db 114 | *.VC.VC.opendb 115 | 116 | # Visual Studio profiler 117 | *.psess 
118 | *.vsp 119 | *.vspx 120 | *.sap 121 | 122 | # Visual Studio Trace Files 123 | *.e2e 124 | 125 | # TFS 2012 Local Workspace 126 | $tf/ 127 | 128 | # Guidance Automation Toolkit 129 | *.gpState 130 | 131 | # ReSharper is a .NET coding add-in 132 | _ReSharper*/ 133 | *.[Rr]e[Ss]harper 134 | *.DotSettings.user 135 | 136 | # TeamCity is a build add-in 137 | _TeamCity* 138 | 139 | # DotCover is a Code Coverage Tool 140 | *.dotCover 141 | 142 | # AxoCover is a Code Coverage Tool 143 | .axoCover/* 144 | !.axoCover/settings.json 145 | 146 | # Coverlet is a free, cross platform Code Coverage Tool 147 | coverage*.json 148 | coverage*.xml 149 | coverage*.info 150 | 151 | # Visual Studio code coverage results 152 | *.coverage 153 | *.coveragexml 154 | 155 | # NCrunch 156 | _NCrunch_* 157 | .*crunch*.local.xml 158 | nCrunchTemp_* 159 | 160 | # MightyMoose 161 | *.mm.* 162 | AutoTest.Net/ 163 | 164 | # Web workbench (sass) 165 | .sass-cache/ 166 | 167 | # Installshield output folder 168 | [Ee]xpress/ 169 | 170 | # DocProject is a documentation generator add-in 171 | DocProject/buildhelp/ 172 | DocProject/Help/*.HxT 173 | DocProject/Help/*.HxC 174 | DocProject/Help/*.hhc 175 | DocProject/Help/*.hhk 176 | DocProject/Help/*.hhp 177 | DocProject/Help/Html2 178 | DocProject/Help/html 179 | 180 | # Click-Once directory 181 | publish/ 182 | 183 | # Publish Web Output 184 | *.[Pp]ublish.xml 185 | *.azurePubxml 186 | # Note: Comment the next line if you want to checkin your web deploy settings, 187 | # but database connection strings (with potential passwords) will be unencrypted 188 | *.pubxml 189 | *.publishproj 190 | 191 | # Microsoft Azure Web App publish settings. 
Comment the next line if you want to 192 | # checkin your Azure Web App publish settings, but sensitive information contained 193 | # in these scripts will be unencrypted 194 | PublishScripts/ 195 | 196 | # NuGet Packages 197 | *.nupkg 198 | # NuGet Symbol Packages 199 | *.snupkg 200 | # The packages folder can be ignored because of Package Restore 201 | **/[Pp]ackages/* 202 | # except build/, which is used as an MSBuild target. 203 | !**/[Pp]ackages/build/ 204 | # Uncomment if necessary however generally it will be regenerated when needed 205 | #!**/[Pp]ackages/repositories.config 206 | # NuGet v3's project.json files produces more ignorable files 207 | *.nuget.props 208 | *.nuget.targets 209 | 210 | # Microsoft Azure Build Output 211 | csx/ 212 | *.build.csdef 213 | 214 | # Microsoft Azure Emulator 215 | ecf/ 216 | rcf/ 217 | 218 | # Windows Store app package directories and files 219 | AppPackages/ 220 | BundleArtifacts/ 221 | Package.StoreAssociation.xml 222 | _pkginfo.txt 223 | *.appx 224 | *.appxbundle 225 | *.appxupload 226 | 227 | # Visual Studio cache files 228 | # files ending in .cache can be ignored 229 | *.[Cc]ache 230 | # but keep track of directories ending in .cache 231 | !?*.[Cc]ache/ 232 | 233 | # Others 234 | ClientBin/ 235 | ~$* 236 | *~ 237 | *.dbmdl 238 | *.dbproj.schemaview 239 | *.jfm 240 | *.pfx 241 | *.publishsettings 242 | orleans.codegen.cs 243 | 244 | # Including strong name files can present a security risk 245 | # (https://github.com/github/gitignore/pull/2483#issue-259490424) 246 | #*.snk 247 | 248 | # Since there are multiple workflows, uncomment next line to ignore bower_components 249 | # (https://github.com/github/gitignore/pull/1529#issuecomment-104372622) 250 | #bower_components/ 251 | 252 | # RIA/Silverlight projects 253 | Generated_Code/ 254 | 255 | # Backup & report files from converting an old project file 256 | # to a newer Visual Studio version. 
Backup files are not needed, 257 | # because we have git ;-) 258 | _UpgradeReport_Files/ 259 | Backup*/ 260 | UpgradeLog*.XML 261 | UpgradeLog*.htm 262 | ServiceFabricBackup/ 263 | *.rptproj.bak 264 | 265 | # SQL Server files 266 | *.mdf 267 | *.ldf 268 | *.ndf 269 | 270 | # Business Intelligence projects 271 | *.rdl.data 272 | *.bim.layout 273 | *.bim_*.settings 274 | *.rptproj.rsuser 275 | *- [Bb]ackup.rdl 276 | *- [Bb]ackup ([0-9]).rdl 277 | *- [Bb]ackup ([0-9][0-9]).rdl 278 | 279 | # Microsoft Fakes 280 | FakesAssemblies/ 281 | 282 | # GhostDoc plugin setting file 283 | *.GhostDoc.xml 284 | 285 | # Node.js Tools for Visual Studio 286 | .ntvs_analysis.dat 287 | node_modules/ 288 | 289 | # Visual Studio 6 build log 290 | *.plg 291 | 292 | # Visual Studio 6 workspace options file 293 | *.opt 294 | 295 | # Visual Studio 6 auto-generated workspace file (contains which files were open etc.) 296 | *.vbw 297 | 298 | # Visual Studio 6 auto-generated project file (contains which files were open etc.) 
299 | *.vbp 300 | 301 | # Visual Studio 6 workspace and project file (working project files containing files to include in project) 302 | *.dsw 303 | *.dsp 304 | 305 | # Visual Studio 6 technical files 306 | *.ncb 307 | *.aps 308 | 309 | # Visual Studio LightSwitch build output 310 | **/*.HTMLClient/GeneratedArtifacts 311 | **/*.DesktopClient/GeneratedArtifacts 312 | **/*.DesktopClient/ModelManifest.xml 313 | **/*.Server/GeneratedArtifacts 314 | **/*.Server/ModelManifest.xml 315 | _Pvt_Extensions 316 | 317 | # Paket dependency manager 318 | .paket/paket.exe 319 | paket-files/ 320 | 321 | # FAKE - F# Make 322 | .fake/ 323 | 324 | # CodeRush personal settings 325 | .cr/personal 326 | 327 | # Python Tools for Visual Studio (PTVS) 328 | __pycache__/ 329 | *.pyc 330 | 331 | # Cake - Uncomment if you are using it 332 | # tools/** 333 | # !tools/packages.config 334 | 335 | # Tabs Studio 336 | *.tss 337 | 338 | # Telerik's JustMock configuration file 339 | *.jmconfig 340 | 341 | # BizTalk build output 342 | *.btp.cs 343 | *.btm.cs 344 | *.odx.cs 345 | *.xsd.cs 346 | 347 | # OpenCover UI analysis results 348 | OpenCover/ 349 | 350 | # Azure Stream Analytics local run output 351 | ASALocalRun/ 352 | 353 | # MSBuild Binary and Structured Log 354 | *.binlog 355 | 356 | # NVidia Nsight GPU debugger configuration file 357 | *.nvuser 358 | 359 | # MFractors (Xamarin productivity tool) working folder 360 | .mfractor/ 361 | 362 | # Local History for Visual Studio 363 | .localhistory/ 364 | 365 | # Visual Studio History (VSHistory) files 366 | .vshistory/ 367 | 368 | # BeatPulse healthcheck temp database 369 | healthchecksdb 370 | 371 | # Backup folder for Package Reference Convert tool in Visual Studio 2017 372 | MigrationBackup/ 373 | 374 | # Ionide (cross platform F# VS Code tools) working folder 375 | .ionide/ 376 | 377 | # Fody - auto-generated XML schema 378 | FodyWeavers.xsd 379 | 380 | # VS Code files for those working on multiple tools 381 | .vscode/* 382 | 
!.vscode/settings.json 383 | !.vscode/tasks.json 384 | !.vscode/launch.json 385 | !.vscode/extensions.json 386 | *.code-workspace 387 | 388 | # Local History for Visual Studio Code 389 | .history/ 390 | 391 | # Windows Installer files from build outputs 392 | *.cab 393 | *.msi 394 | *.msix 395 | *.msm 396 | *.msp 397 | 398 | -------------------------------------------------------------------------------- /TowerFallAi/TowerfallAiMod/Core/ConnectionDispatcher.cs: --------------------------------------------------------------------------------
using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.IO;
using System.Net.Sockets;
using System.Threading;
using System.Threading.Tasks;
using Newtonsoft.Json;
using TowerfallAi.Common;
using TowerfallAi.Data;

namespace TowerfallAi.Core {
  // Accepts remote connections on a local server and dispatches their messages: Join requests from agents,
  // Config requests that (re)configure the game, and Reset requests. While no client owns the game through a
  // config connection, this instance writes its port (and run flags) to a metadata file in a pool directory.
  public class ConnectionDispatcher {
    private Server server;
    // Agents that joined during the current WaitForConnectionsAsync call. Guarded by joinLock.
    private List<RemoteConnection> remoteConnections;

    // Guards remoteConnections, joinCount and remoteConnectionsTaskSource.
    private object joinLock = new object();
    // Guards isWaitingConnection, which ensures only one WaitForConnectionsAsync call runs at a time.
    private object waitConnectionsLock = new object();
    private bool isWaitingConnection = false;
    // Number of agents the current wait expects; 0 when no wait is active.
    private int joinCount;

    // Allows at most one pending reconfiguration at a time.
    private Semaphore semConfig = new Semaphore(1, 1);

    // Completed with the joined agents once joinCount of them joined; cancelled when a newer config arrives.
    private TaskCompletionSource<List<RemoteConnection>> remoteConnectionsTaskSource;

    // The client that currently owns the game through a config connection, if any.
    private RemoteConnection openConfigConnection;

    private const string pools_dir = "pools";
    private string poolPath;

    // Port the underlying server listens on.
    public int Port { get; private set; }

    // Starts the server and registers this instance in the pool named poolName.
    public ConnectionDispatcher(string poolName) {
      this.poolPath = Path.Combine(AiMod.BaseDirectory, pools_dir, poolName);
      server = new Server(OnConnection, () => { throw new Exception("Server closed"); });
      remoteConnections = new List<RemoteConnection>();
      Port = server.Start();
      RegisterInPool();
    }

    // True while the underlying server is alive.
    public bool IsRunning { get { return server.IsAlive(); } }

    // Waits until `count` agents have sent a Join message and returns their connections.
    // Completes immediately with an empty list when count is 0.
    // Throws TaskCanceledException when a newer config cancels the wait, and InvalidOperationException
    // when another wait is already in progress.
    private async Task<List<RemoteConnection>> WaitForConnectionsAsync(int count) {
      if (count < 0) throw new Exception($"Expected count >= 0. Got: {count}");

      lock (waitConnectionsLock) {
        if (isWaitingConnection) throw new InvalidOperationException("Cannot have concurrent WaitForConnections calls.");
        isWaitingConnection = true;
      }

      try {
        TaskCompletionSource<List<RemoteConnection>> taskSource;
        lock (joinLock) {
          remoteConnections.Clear();
          joinCount = count;
          taskSource = new TaskCompletionSource<List<RemoteConnection>>();
          remoteConnectionsTaskSource = taskSource;
          if (count == 0) {
            // Every join is rejected once joinCount is reached, so with 0 slots nothing would ever complete the wait.
            taskSource.SetResult(new List<RemoteConnection>());
          }
        }

        Logger.Info($"Accepting join requests: {count}");

        // ConfigureAwait(false) is required, otherwise the main thread is blocked.
        return await taskSource.Task.ConfigureAwait(false);
      } finally {
        // Reset even when the wait was cancelled or failed; otherwise every later call would throw
        // "Cannot have concurrent WaitForConnections calls.".
        lock (joinLock) {
          joinCount = 0;
        }
        lock (waitConnectionsLock) {
          isWaitingConnection = false;
        }
      }
    }

    // Called by the server for each new connection. The server hands over connections one at a time, but
    // messages on an owned config connection are handled from a background task (ListenToNewConfigConnection),
    // so state shared with those handlers is guarded by joinLock.
    private void OnConnection(Socket socket) {
      RemoteConnection connection = new RemoteConnection(socket);
      try {
        Message message = ReadMessage(connection);
        if (message == null) {
          Logger.Error("Received an empty message.");
          return;
        }

        if (message.type == Message.Type.Join) {
          HandleJoinMessage(connection);
        } else if (message.type == Message.Type.Config) {
          HandleNewConfigConnection(connection, message);
        } else if (message.type == Message.Type.Reset) {
          HandleResetMessage(connection, message);
        } else {
          Logger.Error("Message type not supported: {0}".Format(message.type));
        }
      } catch (SocketException ex) {
        Logger.Info($"Error handling connection: Exception:\n {ex}");
      } catch (AggregateException ex) {
        // Task.Result reports a failed read (e.g. a SocketException) wrapped in an AggregateException.
        Logger.Info($"Error handling connection: Exception:\n {ex}");
      } catch (JsonException ex) {
        Logger.Error($"Invalid message: {ex.Message}");
      }
    }

    // Blocks until the next message arrives on `connection` and deserializes it.
    // Returns null when the payload deserializes to null (e.g. the JSON literal null).
    private static Message ReadMessage(RemoteConnection connection) {
      return JsonConvert.DeserializeObject<Message>(connection.ReadAsync().Result);
    }

    // Sends a Result message. `text` is only assigned when provided, so replies without text serialize
    // exactly like a Message that never had its message field set.
    private static void WriteResult(RemoteConnection connection, bool success, string text = null) {
      Message reply = new Message { type = Message.Type.Result, success = success };
      if (text != null) {
        reply.message = text;
      }
      connection.Write(JsonConvert.SerializeObject(reply));
    }

    // Adds `connection` to the agents the current wait expects, or rejects it when no wait is active
    // or every slot is taken. The last agent to join completes the wait.
    private void HandleJoinMessage(RemoteConnection connection) {
      lock (joinLock) {
        if (!isWaitingConnection || remoteConnections.Count >= joinCount) {
          Logger.Info("No open slot to join.");
          WriteResult(connection, false, "No open slot to join.");
          return;
        }

        remoteConnections.Add(connection);
        Logger.Info($"Join accepted: {remoteConnections.Count}/{joinCount}");
        bool allJoined = remoteConnections.Count >= joinCount;
        try {
          // Reply before completing the wait: awaiting continuations may run inline and start the game
          // (AiMod.Reconfig) before this agent would otherwise receive its join result.
          WriteResult(connection, true, "Game will start once all agents join.");
        } finally {
          if (allJoined) {
            // Hand out a copy: the next WaitForConnectionsAsync call clears remoteConnections, which would
            // otherwise also clear the list already passed on to AiMod.Reconfig.
            remoteConnectionsTaskSource.TrySetResult(new List<RemoteConnection>(remoteConnections));
          }
        }
      }
    }

    // Replies with an error and returns true when a different, still-alive client owns the game.
    private bool CheckOpenConnectionAndReply(RemoteConnection connection) {
      if (openConfigConnection == null || !openConfigConnection.IsAlive()) {
        return false;
      }

      WriteResult(connection, false, "Game is currently owned by a different client.");
      return true;
    }

    // Handles a Config message arriving on a fresh connection: applies it and makes that connection the
    // owner of the game, unless another live client already owns it.
    private void HandleNewConfigConnection(RemoteConnection connection, Message message) {
      Logger.Info("Config received from new connection.");
      if (CheckOpenConnectionAndReply(connection)) {
        return;
      }

      HandleConfigMessage(connection, message);
      ListenToNewConfigConnection(connection);
    }

    // Makes `connection` the owner of the game. Removes this instance from the pool until the connection
    // disconnects, and handles its subsequent Config/Reset messages on a background task.
    private void ListenToNewConfigConnection(RemoteConnection connection) {
      openConfigConnection = connection;
      UnregisterFromPool();
      TaskEx.Run(async () => {
        await connection.WaitDisconnectAsync();
        RegisterInPool();
      });
      TaskEx.Run(() => {
        try {
          Logger.Info("Start listening to new config connection.");
          while (true) {
            Message message = ReadMessage(connection);
            if (message == null) {
              // Stop listening instead of failing with a NullReferenceException on message.type.
              Logger.Error("Received an empty message.");
              break;
            }

            if (message.type == Message.Type.Config) {
              HandleConfigMessage(connection, message);
            } else if (message.type == Message.Type.Reset) {
              HandleResetMessage(connection, message);
            } else {
              Logger.Error("Message type not supported: {0}".Format(message.type));
            }
          }
        } catch (SocketException ex) {
          Logger.Info($"Error handling connection: Exception:\n {ex}");
        } catch (AggregateException ex) {
          // Task.Result reports a failed read (e.g. a SocketException) wrapped in an AggregateException.
          Logger.Info($"Error handling connection: Exception:\n {ex}");
        } catch (JsonException ex) {
          Logger.Error($"Invalid message: {ex.Message}");
        } finally {
          Logger.Info("Stopped listening to config connection.");
        }
      });
    }

    // Validates the config in `message`, cancels any pending reconfiguration and, in the background, waits
    // for the required remote agents to join before applying it with AiMod.Reconfig. Replies immediately
    // with success, or with the validation error on ConfigException.
    private void HandleConfigMessage(RemoteConnection connection, Message message) {
      try {
        AiMod.ValidateConfig(message.config);
        int connectionsRequired = Agents.CountRemoteConnections(message.config.agents);
        lock (joinLock) {
          // TrySetCanceled is a no-op when the pending wait already completed, so it cannot throw
          // if the wait finished in the meantime.
          if (remoteConnectionsTaskSource != null) {
            remoteConnectionsTaskSource.TrySetCanceled();
          }
        }

        // A cancelled reconfiguration releases semConfig from its finally block below.
        if (!semConfig.WaitOne(1000)) {
          WriteResult(connection, false, "Unable to cancel existing reconfig request");
          return;
        }

        Logger.Info("Reconfiguration started.");
        TaskEx.Run(async () => {
          try {
            List<RemoteConnection> connections = await WaitForConnectionsAsync(connectionsRequired);
            AiMod.ReconfigOperation reconfigOperation = new AiMod.ReconfigOperation() {
              Config = message.config,
              Connections = connections
            };
            AiMod.Reconfig(reconfigOperation);
          } catch (TaskCanceledException) {
            Logger.Info("Config cancelled");
          } catch (Exception ex) {
            Logger.Error($"{ex}");
            throw;
          } finally {
            semConfig.Release();
          }
        });
        WriteResult(connection, true);
      } catch (ConfigException ex) {
        Logger.Error(ex.Message);
        WriteResult(connection, false, ex.Message);
      }
    }

    // Writes this instance's metadata (port and run flags) as JSON to a per-process file in the pool directory.
    private void RegisterInPool() {
      string metadataPath = GetMetadataPath();
      Util.CreateDirectory(poolPath);

      // FileMode.Create truncates an existing file; OpenOrCreate would leave trailing bytes from longer
      // previous content behind and corrupt the JSON.
      using (var f = File.Open(metadataPath, FileMode.Create))
      using (var w = new StreamWriter(f)) {
        Logger.Info($"Writing metadata to {metadataPath}");
        w.Write(JsonConvert.SerializeObject(new Metadata {
          port = Port,
          fastrun = AiMod.IsFastrun,
          nographics = AiMod.NoGraphics,
        }));
      }
    }

    // Deletes this process's metadata file from the pool directory.
    public void UnregisterFromPool() {
      string metadataPath = GetMetadataPath();
      Logger.Info($"Deleting {metadataPath}");
      File.Delete(metadataPath);
    }

    // One metadata file per process, named after the current process id.
    private string GetMetadataPath() {
      return Path.Combine(poolPath, Process.GetCurrentProcess().Id.ToString());
    }

    // Forwards the entities of a Reset message to AiMod.Reset and acknowledges with success.
    private void HandleResetMessage(RemoteConnection connection, Message message) {
      AiMod.Reset(new AiMod.ResetOperation { Entities = message.entities });
      WriteResult(connection, true);
    }
  }
}

--------------------------------------------------------------------------------