├── environment.yml ├── environment_gpu.yml ├── agent ├── src │ ├── main │ │ └── java │ │ │ └── com │ │ │ └── alphaStS │ │ │ ├── State.java │ │ │ ├── utils │ │ │ ├── Tuple.java │ │ │ ├── Tuple3.java │ │ │ ├── DrawOrder.java │ │ │ ├── DrawOrderReadOnly.java │ │ │ ├── CircularArray.java │ │ │ ├── BigRational.java │ │ │ ├── CounterStat.java │ │ │ ├── Utils.java │ │ │ └── FuzzyMatch.java │ │ │ ├── enums │ │ │ ├── Stance.java │ │ │ ├── CharacterEnum.java │ │ │ ├── DebuffType.java │ │ │ └── OrbType.java │ │ │ ├── GameAction.java │ │ │ ├── model │ │ │ ├── NNOutput.java │ │ │ ├── Model.java │ │ │ ├── NNInputHash.java │ │ │ ├── ModelBatchProducer.java │ │ │ ├── ModelExecutor.java │ │ │ └── ModelPlain.java │ │ │ ├── OnDamageHandler.java │ │ │ ├── action │ │ │ ├── GameEnvironmentAction.java │ │ │ ├── EndOfGameAction.java │ │ │ ├── CardDrawAction.java │ │ │ ├── GuardianGainBlockAction.java │ │ │ ├── HavocExhaustAction.java │ │ │ └── HavocAction.java │ │ │ ├── TrainingTarget.java │ │ │ ├── ServerRequest.java │ │ │ ├── GameActionCtx.java │ │ │ ├── GameActionType.java │ │ │ ├── GameEventHandler.java │ │ │ ├── GameEventEnemyHandler.java │ │ │ ├── RandomGenCtx.java │ │ │ ├── GameEventCardHandler.java │ │ │ ├── PlayerBuff.java │ │ │ ├── GameStep.java │ │ │ ├── CrashException.java │ │ │ ├── LineOfPlay.java │ │ │ ├── enemy │ │ │ ├── EnemyListReadOnly.java │ │ │ ├── EnemyList.java │ │ │ ├── EnemyReadOnly.java │ │ │ └── Enemy.java │ │ │ ├── SearchFrontier.java │ │ │ ├── player │ │ │ ├── PlayerReadOnly.java │ │ │ └── Player.java │ │ │ ├── Configuration.java │ │ │ ├── RandomGen.java │ │ │ └── GameStateBuilder.java │ ├── resources │ │ ├── mallet.jar │ │ └── mallet-deps.jar │ └── test │ │ └── java │ │ └── com │ │ └── alphaStS │ │ └── AppTest.java └── pom.xml ├── .gitignore ├── stats.py ├── misc.py ├── remoteSetup.txt ├── interactive.py ├── plot.py ├── commands.txt ├── test_dist.py ├── test.py ├── upload_server.py ├── ml_conv.py └── README.md /environment.yml: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MANGO1234/AlphaStS/HEAD/environment.yml -------------------------------------------------------------------------------- /environment_gpu.yml: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MANGO1234/AlphaStS/HEAD/environment_gpu.yml -------------------------------------------------------------------------------- /agent/src/main/java/com/alphaStS/State.java: -------------------------------------------------------------------------------- 1 | package com.alphaStS; 2 | 3 | public interface State { 4 | } 5 | -------------------------------------------------------------------------------- /agent/src/resources/mallet.jar: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MANGO1234/AlphaStS/HEAD/agent/src/resources/mallet.jar -------------------------------------------------------------------------------- /agent/src/resources/mallet-deps.jar: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MANGO1234/AlphaStS/HEAD/agent/src/resources/mallet-deps.jar -------------------------------------------------------------------------------- /agent/src/main/java/com/alphaStS/utils/Tuple.java: -------------------------------------------------------------------------------- 1 | package com.alphaStS.utils; 2 | 3 | public record Tuple(T1 v1, T2 v2) {} 4 | -------------------------------------------------------------------------------- /agent/src/main/java/com/alphaStS/utils/Tuple3.java: -------------------------------------------------------------------------------- 1 | package com.alphaStS.utils; 2 | 3 | public record Tuple3(T1 v1, T2 v2, T3 v3) {} 4 | -------------------------------------------------------------------------------- /agent/src/main/java/com/alphaStS/enums/Stance.java: -------------------------------------------------------------------------------- 1 | package com.alphaStS.enums; 2 | 3 | public enum Stance { 4 | NEUTRAL, WRATH, CALM, DIVINITY 5 | } -------------------------------------------------------------------------------- /agent/src/main/java/com/alphaStS/enums/CharacterEnum.java: -------------------------------------------------------------------------------- 1 | package com.alphaStS.enums; 2 | 3 | public enum CharacterEnum { 4 | IRONCLAD, SILENT, DEFECT, WATCHER 5 | } 6 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | saves*/ 2 | saves* 3 | tmp/ 4 | __pycache__ 5 | *.c 6 | *.iml 7 | agent/.idea/ 8 | *.pyd 9 | build/ 10 | kaggle/ 11 | target 12 | agent/hs_err_p* 13 | *.java~ 14 | .claude/ -------------------------------------------------------------------------------- /agent/src/main/java/com/alphaStS/GameAction.java: -------------------------------------------------------------------------------- 1 | package com.alphaStS; 2 | 3 | public record GameAction(GameActionType type, int idx) { // idx is either cardIdx, enemyIdx, potionIdx, etc. 4 | } 5 | -------------------------------------------------------------------------------- /agent/src/main/java/com/alphaStS/model/NNOutput.java: -------------------------------------------------------------------------------- 1 | package com.alphaStS.model; 2 | 3 | public record NNOutput(float v_health, float v_win, float[] v_other, float[] policy, float[] rawPolicy, int[] legalActions) { 4 | } 5 | -------------------------------------------------------------------------------- /agent/src/main/java/com/alphaStS/OnDamageHandler.java: -------------------------------------------------------------------------------- 1 | package com.alphaStS; 2 | 3 | public abstract class OnDamageHandler { 4 | public abstract void handle(GameState state, Object source, boolean isAttack, int damageDealt); 5 | } 6 | -------------------------------------------------------------------------------- /agent/src/main/java/com/alphaStS/model/Model.java: -------------------------------------------------------------------------------- 1 | package com.alphaStS.model; 2 | 3 | import com.alphaStS.GameState; 4 | 5 | public interface Model { 6 | NNOutput eval(GameState state); 7 | void close(); 8 | void startRecordCalls(); 9 | int endRecordCalls(); 10 | } 11 | -------------------------------------------------------------------------------- /agent/src/main/java/com/alphaStS/action/GameEnvironmentAction.java: -------------------------------------------------------------------------------- 1 | package com.alphaStS.action; 2 | 3 | import com.alphaStS.GameState; 4 | 5 | public interface GameEnvironmentAction { 6 | void doAction(GameState state); 7 | default boolean canHappenInsideCardPlay() { return false; }; 8 | } 9 | -------------------------------------------------------------------------------- /agent/src/main/java/com/alphaStS/TrainingTarget.java: -------------------------------------------------------------------------------- 1 | package com.alphaStS; 2 | 3 | public interface TrainingTarget { 4 | void fillVArray(GameState state, VArray v, int isTerminal); 5 | void updateQValues(GameState state, VArray v); 6 | default int getNumberOfTargets() { 7 | return 1; 8 | } 9 | } 10 | -------------------------------------------------------------------------------- /agent/src/main/java/com/alphaStS/ServerRequest.java: -------------------------------------------------------------------------------- 1 | package com.alphaStS; 2 | 3 | public class ServerRequest { 4 | public ServerRequestType type; 5 | public int remainingGames; 6 | public int nodeCount; 7 | public boolean zTraining; 8 | public boolean curriculumTraining; 9 | public byte[] bytes; 10 | } 11 | -------------------------------------------------------------------------------- /agent/src/main/java/com/alphaStS/GameActionCtx.java: -------------------------------------------------------------------------------- 1 | package com.alphaStS; 2 | 3 | public enum GameActionCtx { 4 | BEGIN_PRE_BATTLE, 5 | BEGIN_BATTLE, 6 | PLAY_CARD, 7 | SELECT_ENEMY, 8 | SELECT_CARD_DISCARD, 9 | SELECT_CARD_HAND, 10 | SELECT_CARD_EXHAUST, 11 | SELECT_CARD_DECK, 12 | SELECT_CARD_1_OUT_OF_3, 13 | SCRYING, 14 | SELECT_SCENARIO, 15 | BEGIN_TURN, 16 | AFTER_RANDOMIZATION, 17 | START_OF_BATTLE, 18 | } 19 | -------------------------------------------------------------------------------- /agent/src/main/java/com/alphaStS/GameActionType.java: -------------------------------------------------------------------------------- 1 | package com.alphaStS; 2 | 3 | public enum GameActionType { 4 | BEGIN_BATTLE, 5 | PLAY_CARD, 6 | SELECT_ENEMY, 7 | SELECT_CARD_DISCARD, 8 | SELECT_CARD_HAND, 9 | SELECT_CARD_EXHAUST, 10 | SELECT_CARD_DECK, 11 | SELECT_CARD_1_OUT_OF_3, 12 | SCRY_KEEP_CARD, 13 | END_TURN, 14 | USE_POTION, 15 | BEGIN_PRE_BATTLE, 16 | SELECT_SCENARIO, 17 | BEGIN_TURN, 18 | END_SELECT_CARD_HAND, 19 | END_SELECT_CARD_1_OUT_OF_3, 20 | AFTER_RANDOMIZATION, 21 | } 22 | -------------------------------------------------------------------------------- /stats.py: -------------------------------------------------------------------------------- 1 | # format 2 | 3 | import math 4 | 5 | # https://www.itl.nist.gov/div898/handbook/prc/section3/prc33.htm 6 | n_win1 = 6771 7 | n1 = 10000 8 | 9 | n_win2 = 6655 10 | n2 = 10000 11 | 12 | p1 = n_win1 / n1 13 | p2 = n_win2 / n2 14 | p = (n_win1 + n_win2) / (n1 + n2) 15 | 16 | z = (p1 - p2) / math.sqrt(p * (1 - p) * (1 / n1 + 1 / n2)) 17 | print(z) 18 | 19 | wins = 140 20 | losses = 124 21 | 22 | print((1 + math.erf((wins - losses) / math.sqrt((2 * wins + 2 * losses)))) / 2) 23 | print(math.erf((wins - losses) / math.sqrt((2 * wins + 2 * losses)))) 24 | -------------------------------------------------------------------------------- /agent/src/main/java/com/alphaStS/GameEventHandler.java: -------------------------------------------------------------------------------- 1 | package com.alphaStS; 2 | 3 | public abstract class GameEventHandler implements Comparable { 4 | private int priority; 5 | 6 | protected GameEventHandler(int priority) { 7 | this.priority = priority; 8 | } 9 | 10 | protected GameEventHandler() { 11 | } 12 | 13 | public abstract void handle(GameState state); 14 | 15 | @Override public int compareTo(GameEventHandler other) { 16 | return Integer.compare(priority, other.priority); 17 | } 18 | } 19 | -------------------------------------------------------------------------------- /misc.py: -------------------------------------------------------------------------------- 1 | # format 2 | 3 | import sys 4 | 5 | args = [a for a in sys.argv] 6 | 7 | 8 | def getFlag(flag, default=False): 9 | new_args = [] 10 | for arg in args: 11 | if arg == flag: 12 | return True 13 | new_args.append(arg) 14 | return default 15 | 16 | 17 | def getFlagValue(flag, default=None): 18 | new_args = [] 19 | v = default 20 | for i, arg in enumerate(args): 21 | if arg == flag: 22 | v = args[i + 1] 23 | i += 1 24 | else: 25 | new_args.append(arg) 26 | return v 27 | -------------------------------------------------------------------------------- /remoteSetup.txt: -------------------------------------------------------------------------------- 1 | Setup ec2 to run server instance: 2 | 3 | Java: 4 | sudo yum install sudo yum install java-17-amazon-corretto-devel 5 | 6 | Maven: 7 | sudo wget https://www-eu.apache.org/dist/maven/maven-3/3.8.6/binaries/apache-maven-3.8.6-bin.tar.gz 8 | sudo tar xf ./apache-maven-*.tar.gz -C /opt 9 | sudo ln -s /opt/apache-maven-3.8.6 /opt/maven 10 | rm apache-maven-3.8.6-bin.tar.gz 11 | 12 | Add to .bashrc: 13 | 14 | 15 | Conda (not needed for now): 16 | wget https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh 17 | bash Miniconda3-latest-Linux-x86_64.sh 18 | rm Miniconda3-latest-Linux-x86_64.sh -------------------------------------------------------------------------------- /agent/src/main/java/com/alphaStS/GameEventEnemyHandler.java: -------------------------------------------------------------------------------- 1 | package com.alphaStS; 2 | 3 | import com.alphaStS.enemy.EnemyReadOnly; 4 | 5 | public abstract class GameEventEnemyHandler implements Comparable { 6 | private int priority; 7 | 8 | public GameEventEnemyHandler(int priority) { 9 | this.priority = priority; 10 | } 11 | 12 | public GameEventEnemyHandler() { 13 | } 14 | 15 | public abstract void handle(GameState state, EnemyReadOnly enemy); 16 | 17 | @Override public int compareTo(GameEventEnemyHandler other) { 18 | return Integer.compare(other.priority, priority); 19 | } 20 | } -------------------------------------------------------------------------------- /agent/src/main/java/com/alphaStS/enums/DebuffType.java: -------------------------------------------------------------------------------- 1 | package com.alphaStS.enums; 2 | 3 | public enum DebuffType { 4 | VULNERABLE, 5 | WEAK, 6 | FRAIL, 7 | NO_MORE_CARD_DRAW, 8 | LOSE_STRENGTH, 9 | LOSE_DEXTERITY, 10 | LOSE_DEXTERITY_PER_TURN, 11 | LOSE_STRENGTH_EOT, 12 | LOSE_DEXTERITY_EOT, 13 | POISON, 14 | CORPSE_EXPLOSION, 15 | ENTANGLED, 16 | HEX, 17 | LOSE_FOCUS, 18 | LOSE_FOCUS_PER_TURN, 19 | CONSTRICTED, 20 | DRAW_REDUCTION, 21 | SNECKO, 22 | LOCK_ON, 23 | NO_BLOCK_FROM_CARDS, 24 | CHOKE, 25 | LOSE_ENERGY_PER_TURN, 26 | TALK_TO_THE_HAND, 27 | MARK, 28 | } 29 | -------------------------------------------------------------------------------- /agent/src/main/java/com/alphaStS/RandomGenCtx.java: -------------------------------------------------------------------------------- 1 | package com.alphaStS; 2 | 3 | public enum RandomGenCtx { 4 | CardDraw, 5 | EnemyChooseMove, 6 | RandomEnemySwordBoomerang, 7 | RandomEnemyJuggernaut, 8 | RandomEnemyGeneral, 9 | RandomEnemyLightningOrb, 10 | RandomEnemyRipAndTear, 11 | RandomCardHand, 12 | RandomCardHandMummifiedHand, 13 | RandomCardHandWarpedTongs, 14 | ShieldGremlin, 15 | GremlinLeader, 16 | BronzeOrb, 17 | Chaos, 18 | ShieldAndSpear, 19 | Snecko, 20 | EntropicBrew, 21 | RandomCardGen, 22 | SelectCard1OutOf3, 23 | RandomEnemyHealth, 24 | Misc, 25 | CommonNumberVR, 26 | BeginningOfGameRandomization, 27 | Other, 28 | } 29 | -------------------------------------------------------------------------------- /agent/src/main/java/com/alphaStS/action/EndOfGameAction.java: -------------------------------------------------------------------------------- 1 | package com.alphaStS.action; 2 | 3 | import com.alphaStS.GameState; 4 | 5 | public class EndOfGameAction implements GameEnvironmentAction { 6 | public static final EndOfGameAction instance = new EndOfGameAction(); 7 | 8 | @Override public void doAction(GameState state) { 9 | for (var handler : state.properties.endOfBattleHandlers) { 10 | handler.handle(state); 11 | } 12 | } 13 | 14 | @Override public boolean equals(Object o) { 15 | if (this == o) 16 | return true; 17 | return o != null && getClass() == o.getClass(); 18 | } 19 | 20 | @Override public String toString() { 21 | return "EndOfGameAction"; 22 | } 23 | } 24 | -------------------------------------------------------------------------------- /agent/src/main/java/com/alphaStS/action/CardDrawAction.java: -------------------------------------------------------------------------------- 1 | package com.alphaStS.action; 2 | 3 | import com.alphaStS.GameState; 4 | 5 | public class CardDrawAction implements GameEnvironmentAction { 6 | int n; 7 | 8 | public CardDrawAction(int n) { 9 | this.n = n; 10 | } 11 | 12 | @Override public void doAction(GameState state) { 13 | state.draw(n); 14 | } 15 | 16 | @Override public boolean equals(Object o) { 17 | if (this == o) 18 | return true; 19 | if (o == null || getClass() != o.getClass()) 20 | return false; 21 | CardDrawAction that = (CardDrawAction) o; 22 | return n == that.n; 23 | } 24 | 25 | @Override public int hashCode() { 26 | return n; 27 | } 28 | } 29 | -------------------------------------------------------------------------------- /agent/src/main/java/com/alphaStS/model/NNInputHash.java: -------------------------------------------------------------------------------- 1 | package com.alphaStS.model; 2 | 3 | import java.util.Arrays; 4 | 5 | class NNInputHash { 6 | float[] input; 7 | int hashCode; 8 | 9 | public NNInputHash(float[] input) { 10 | this.input = input; 11 | hashCode = Arrays.hashCode(input); 12 | } 13 | 14 | @Override public boolean equals(Object o) { 15 | if (this == o) 16 | return true; 17 | if (o == null || getClass() != o.getClass()) 18 | return false; 19 | NNInputHash inputHash = (NNInputHash) o; 20 | return hashCode == inputHash.hashCode && Arrays.equals(input, inputHash.input); 21 | } 22 | 23 | @Override public int hashCode() { 24 | return hashCode; 25 | } 26 | } 27 | -------------------------------------------------------------------------------- /agent/src/main/java/com/alphaStS/enums/OrbType.java: -------------------------------------------------------------------------------- 1 | package com.alphaStS.enums; 2 | 3 | import com.alphaStS.GameState; 4 | import com.alphaStS.RandomGenCtx; 5 | 6 | public enum OrbType { 7 | EMPTY("E", "Empty"), LIGHTNING("L", "Lightning"), FROST("F", "Frost"), DARK("D", "Dark"), PLASMA("P", "Plasma"); 8 | 9 | public final String abbrev; 10 | public final String displayName; 11 | OrbType(String abbrev, String displayName) { 12 | this.abbrev = abbrev; 13 | this.displayName = displayName; 14 | } 15 | 16 | public static OrbType getRandom(GameState state) { 17 | state.setIsStochastic(); 18 | var values = OrbType.values(); 19 | int r = state.getSearchRandomGen().nextInt(values.length - 1, RandomGenCtx.Chaos, state) + 1; 20 | return values()[r]; 21 | } 22 | } 23 | -------------------------------------------------------------------------------- /agent/src/main/java/com/alphaStS/GameEventCardHandler.java: -------------------------------------------------------------------------------- 1 | package com.alphaStS; 2 | 3 | public abstract class GameEventCardHandler implements Comparable { 4 | // todo: figure out how the game handle priority properly... 5 | public static int CLONE_CARD_PRIORITY = -100; 6 | public static int HEARTBEAT_PRIORITY = -10; 7 | 8 | private int priority; 9 | 10 | public GameEventCardHandler(int priority) { 11 | this.priority = priority; 12 | } 13 | 14 | public GameEventCardHandler() { 15 | } 16 | 17 | public abstract void handle(GameState state, int cardIdx, int lastIdx, int energyUsed, Class cloneSource, int cloneParentLocation); 18 | 19 | @Override public int compareTo(GameEventCardHandler other) { 20 | return Integer.compare(other.priority, priority); 21 | } 22 | } 23 | -------------------------------------------------------------------------------- /agent/src/main/java/com/alphaStS/action/GuardianGainBlockAction.java: -------------------------------------------------------------------------------- 1 | package com.alphaStS.action; 2 | 3 | import com.alphaStS.GameState; 4 | import com.alphaStS.enemy.EnemyExordium; 5 | import com.alphaStS.enemy.EnemyReadOnly; 6 | 7 | public class GuardianGainBlockAction implements GameEnvironmentAction { 8 | public static GuardianGainBlockAction singleton = new GuardianGainBlockAction(); 9 | 10 | @Override public boolean equals(Object o) { 11 | if (this == o) 12 | return true; 13 | if (o == null || getClass() != o.getClass()) 14 | return false; 15 | return true; 16 | } 17 | 18 | @Override public int hashCode() { 19 | return 0; 20 | } 21 | 22 | @Override public void doAction(GameState state) { 23 | for (EnemyReadOnly enemy : state.getEnemiesForWrite().iterateOverAlive()) { 24 | if (enemy instanceof EnemyExordium.TheGuardian e) { 25 | e.gainBlock(20); 26 | } 27 | } 28 | } 29 | } 30 | -------------------------------------------------------------------------------- /agent/src/main/java/com/alphaStS/action/HavocExhaustAction.java: -------------------------------------------------------------------------------- 1 | package com.alphaStS.action; 2 | 3 | import com.alphaStS.GameState; 4 | import com.alphaStS.card.Card; 5 | 6 | public class HavocExhaustAction implements GameEnvironmentAction { 7 | int cardIdx; 8 | 9 | public HavocExhaustAction(int cardIdx) { 10 | this.cardIdx = cardIdx; 11 | } 12 | 13 | @Override public void doAction(GameState state) { 14 | state.getCounterForWrite()[state.properties.havocCounterIdx]--; 15 | if (state.properties.cardDict[cardIdx].cardType != Card.POWER) { 16 | state.exhaustedCardHandle(cardIdx, true); 17 | } 18 | } 19 | 20 | @Override public boolean equals(Object o) { 21 | if (this == o) 22 | return true; 23 | if (o == null || getClass() != o.getClass()) 24 | return false; 25 | HavocExhaustAction that = (HavocExhaustAction) o; 26 | return cardIdx == that.cardIdx; 27 | } 28 | 29 | @Override public int hashCode() { 30 | return cardIdx; 31 | } 32 | 33 | @Override public String toString() { 34 | return "HavocExhaustAction(" + cardIdx + ")"; 35 | } 36 | } -------------------------------------------------------------------------------- /agent/src/main/java/com/alphaStS/utils/DrawOrder.java: -------------------------------------------------------------------------------- 1 | package com.alphaStS.utils; 2 | 3 | import java.util.Arrays; 4 | 5 | public class DrawOrder extends DrawOrderReadOnly { 6 | public DrawOrder(int startLen) { 7 | super(startLen); 8 | } 9 | 10 | public DrawOrder(DrawOrder other) { 11 | super(other); 12 | } 13 | 14 | public void pushOnTop(int cardIdx) { 15 | if (order.length <= len) { 16 | order = Arrays.copyOf(order, order.length + (order.length + 2) / 2); 17 | } 18 | order[len++] = cardIdx; 19 | } 20 | 21 | public int drawTop() { 22 | if (len == 0) { 23 | return -1; 24 | } 25 | return order[--len]; 26 | } 27 | 28 | public int peekTop() { 29 | if (len == 0) { 30 | return -1; 31 | } 32 | return order[len - 1]; 33 | } 34 | 35 | public void clear() { 36 | len = 0; 37 | } 38 | 39 | public void transformTopMost(int cardIdx, int newCardIdx) { 40 | for (int i = len - 1; i >= 0 ; i--) { 41 | if (order[i] == cardIdx) { 42 | order[i] = newCardIdx; 43 | return; 44 | } 45 | } 46 | } 47 | } 48 | -------------------------------------------------------------------------------- /interactive.py: -------------------------------------------------------------------------------- 1 | # format 2 | import subprocess 3 | import platform 4 | import os 5 | 6 | sep = ':' 7 | if platform.system() == 'Windows': 8 | sep = ';' 9 | 10 | 11 | CLASS_PATH = f'./target/classes{sep}{os.getenv("M2_HOME")}/repository/com/microsoft/onnxruntime/onnxruntime/1.10.0/onnxruntime-1.10.0.jar{sep}./src/resources/mallet.jar{sep}./src/resources/mallet-deps.jar{sep}{os.getenv("M2_HOME")}/repository/org/jdom/jdom/1.1/jdom-1.1.jar{sep}{os.getenv("M2_HOME")}/repository/com/fasterxml/jackson/core/jackson-databind/2.12.4/jackson-databind-2.12.4.jar{sep}{os.getenv("M2_HOME")}/repository/com/fasterxml/jackson/core/jackson-annotations/2.12.4/jackson-annotations-2.12.4.jar{sep}{os.getenv("M2_HOME")}/repository/com/fasterxml/jackson/core/jackson-core/2.12.4/jackson-core-2.12.4.jar{sep}{os.getenv("M2_HOME")}/repository/org/apache/commons/commons-compress/1.21/commons-compress-1.21.jar{sep}{os.getenv("M2_HOME")}/repository/org/apache/commons/commons-math3/3.6.1/commons-math3-3.6.1.jar{sep}{os.getenv("M2_HOME")}/repository/one/util/streamex/0.8.3/streamex-0.8.3.jar' 12 | 13 | os.chdir('agent') 14 | agent_args = ['java', '--add-opens', 'java.base/java.util=ALL-UNNAMED', '-classpath', CLASS_PATH, 'com.alphaStS.Main', '-i'] 15 | subprocess.run(agent_args) 16 | -------------------------------------------------------------------------------- /agent/src/main/java/com/alphaStS/PlayerBuff.java: -------------------------------------------------------------------------------- 1 | package com.alphaStS; 2 | 3 | public record PlayerBuff(long mask, String name) { 4 | public final static PlayerBuff BARRICADE = new PlayerBuff(1, "Barricade"); 5 | public final static PlayerBuff CORRUPTION = new PlayerBuff(1L << 1, "Corruption"); 6 | 7 | public final static PlayerBuff AKABEKO = new PlayerBuff(1L << 32, "Akabeko"); 8 | public final static PlayerBuff ART_OF_WAR = new PlayerBuff(1L << 33, "Art of War"); 9 | public final static PlayerBuff CENTENNIAL_PUZZLE = new PlayerBuff(1L << 34, "Centennial Puzzle"); 10 | public final static PlayerBuff NECRONOMICON = new PlayerBuff(1L << 35, "Necronomicon"); 11 | public final static PlayerBuff BLASPHEMY = new PlayerBuff(1L << 36, "Blasphemy"); 12 | public final static PlayerBuff END_TURN_IMMEDIATELY = new PlayerBuff(1L << 37, "End Turn Immediately"); 13 | public final static PlayerBuff USED_VAULT = new PlayerBuff(1L << 38, "Used Vault"); 14 | public final static PlayerBuff WRIST_BLADE = new PlayerBuff(1L << 39, "Wrist Blade"); 15 | 16 | public final static PlayerBuff[] BUFFS = new PlayerBuff[] { 17 | BARRICADE, 18 | CORRUPTION, 19 | AKABEKO, 20 | ART_OF_WAR, 21 | CENTENNIAL_PUZZLE, 22 | NECRONOMICON, 23 | BLASPHEMY, 24 | END_TURN_IMMEDIATELY, 25 | USED_VAULT, 26 | WRIST_BLADE, 27 | }; 28 | } 29 | -------------------------------------------------------------------------------- /agent/src/main/java/com/alphaStS/action/HavocAction.java: -------------------------------------------------------------------------------- 1 | package com.alphaStS.action; 2 | 3 | import com.alphaStS.GameActionCtx; 4 | import com.alphaStS.GameState; 5 | import com.alphaStS.GameStateUtils; 6 | import com.alphaStS.RandomGenCtx; 7 | 8 | public class HavocAction implements GameEnvironmentAction { 9 | int cardIdx; 10 | 11 | public HavocAction(int cardIdx) { 12 | this.cardIdx = cardIdx; 13 | } 14 | 15 | @Override public void doAction(GameState state) { 16 | state.getCounterForWrite()[state.properties.havocCounterIdx]++; 17 | var action = state.properties.actionsByCtx[GameActionCtx.PLAY_CARD.ordinal()][cardIdx]; 18 | state.playCard(action, -1, true, null, false, true, -1, -1); 19 | while (state.actionCtx == GameActionCtx.SELECT_ENEMY) { 20 | int enemyIdx = GameStateUtils.getRandomEnemyIdx(state, RandomGenCtx.RandomEnemyGeneral); 21 | if (state.properties.makingRealMove || state.properties.stateDescOn) { 22 | state.getStateDesc().append(" -> ").append(state.getEnemiesForRead().get(enemyIdx).getName()).append(" (").append(enemyIdx).append(")"); 23 | } 24 | state.playCard(action, enemyIdx, true, null, false, true, -1, -1); 25 | } 26 | } 27 | 28 | @Override public boolean equals(Object o) { 29 | if (this == o) 30 | return true; 31 | if (o == null || getClass() != o.getClass()) 32 | return false; 33 | HavocAction that = (HavocAction) o; 34 | return cardIdx == that.cardIdx; 35 | } 36 | 37 | @Override public int hashCode() { 38 | return cardIdx; 39 | } 40 | 41 | @Override public String toString() { 42 | return "HavocAction(" + cardIdx + ")"; 43 | } 44 | } 45 | -------------------------------------------------------------------------------- /plot.py: -------------------------------------------------------------------------------- 1 | # format 2 | 3 | import numpy as np 4 | import matplotlib.pyplot as plt 5 | import sys 6 | import ast 7 | import json 8 | 9 | plt.style.use('seaborn-whitegrid') 10 | 11 | if len(sys.argv) > 1: 12 | if sys.argv[1] == '-a': 13 | num_plots = (len(sys.argv) - 2) // 2 14 | fig, axs = plt.subplots(num_plots) 15 | for i in range(2, len(sys.argv), 2): 16 | array_name = sys.argv[i].strip() 17 | array_values = json.loads(sys.argv[i + 1]) 18 | array_values = np.nan_to_num(array_values) 19 | ax = axs[(i - 2) // 2] 20 | ax.plot(array_values) 21 | ax.set_title(array_name) 22 | else: 23 | fig, ax = plt.subplots(4) 24 | for path in sys.argv[1:]: 25 | with open(f'{path}/training.json', 'r') as f: 26 | training_info = json.load(f) 27 | cur_iteration = int(training_info['iteration']) 28 | info = [] 29 | for i in range(1, cur_iteration): 30 | info.append(training_info['iteration_info'][str(i)]) 31 | x0 = np.linspace(0, cur_iteration - 1, cur_iteration - 1) 32 | y0 = [x['death_rate'] for x in info if 'death_rate' in x] 33 | ax[0].plot(x0[:len(y0)], y0, label=path) 34 | y1 = [x['avg_dmg_no_death'] for x in info if 'avg_dmg_no_death' in x] 35 | ax[1].plot(x0[:len(y1)], y1, label=path) 36 | y2 = [x['avg_final_q'] for x in info if 'avg_final_q' in x] 37 | ax[2].plot(x0[:len(y2)], y2, label=path) 38 | x1 = [x['accumulated_time'] for x in info if 'accumulated_time' in x] 39 | ax[3].plot(x1[:len(y0)], y2, label=path) 40 | 41 | ax[0].legend() 42 | ax[1].legend() 43 | ax[2].legend() 44 | 45 | plt.show() 46 | -------------------------------------------------------------------------------- /agent/src/main/java/com/alphaStS/model/ModelBatchProducer.java: -------------------------------------------------------------------------------- 1 | package com.alphaStS.model; 2 | 3 | import com.alphaStS.GameState; 4 | import com.alphaStS.utils.Tuple; 5 | 6 | import java.util.concurrent.TimeUnit; 7 | 8 | public class ModelBatchProducer implements Model { 9 | private final ModelExecutor pool; 10 | private final int clientIdx; 11 | private int calls; 12 | private int cacheHits; 13 | 14 | public ModelBatchProducer(ModelExecutor modelExecutor, int clientIdx) { 15 | this.pool = modelExecutor; 16 | this.clientIdx = clientIdx; 17 | } 18 | 19 | @Override public NNOutput eval(GameState state) { 20 | if (state != null) { 21 | calls++; 22 | } 23 | pool.queue.offer(new Tuple<>(clientIdx, state)); 24 | pool.locks.get(clientIdx).lock(); 25 | while (pool.resultsStatus.get(clientIdx) == ModelExecutor.EXEC_WAITING) { 26 | try { 27 | pool.conditions.get(clientIdx).await(1, TimeUnit.SECONDS); 28 | } catch (InterruptedException e) { 29 | pool.locks.get(clientIdx).unlock(); 30 | throw new RuntimeException(e); 31 | } 32 | } 33 | pool.locks.get(clientIdx).unlock(); 34 | if (pool.resultsStatus.get(clientIdx) == ModelExecutor.EXEC_SUCCESS_CACHE_HIT) { 35 | cacheHits++; 36 | } 37 | pool.resultsStatus.set(clientIdx, ModelExecutor.EXEC_WAITING); 38 | return pool.results.set(clientIdx, null); 39 | } 40 | 41 | @Override public void close() { 42 | } 43 | 44 | int prev; 45 | @Override public void startRecordCalls() { 46 | prev = calls - cacheHits; 47 | } 48 | 49 | @Override public int endRecordCalls() { 50 | return calls - cacheHits - prev; 51 | } 52 | } 53 | -------------------------------------------------------------------------------- /agent/src/main/java/com/alphaStS/utils/DrawOrderReadOnly.java: -------------------------------------------------------------------------------- 1 | package com.alphaStS.utils; 2 | 3 | import java.util.Arrays; 4 | import java.util.Objects; 5 | 6 | public class DrawOrderReadOnly { 7 | protected int[] order; 8 | protected int len; 9 | 10 | public DrawOrderReadOnly(int startLen) { 11 | order = new int[startLen]; 12 | len = 0; 13 | } 14 | 15 | public DrawOrderReadOnly(DrawOrder other) { 16 | order = Arrays.copyOf(other.order, other.order.length); 17 | len = other.len; 18 | } 19 | 20 | @Override public boolean equals(Object o) { 21 | if (this == o) 22 | return true; 23 | if (o == null || getClass() != o.getClass()) 24 | return false; 25 | DrawOrder drawOrder = (DrawOrder) o; 26 | return len == drawOrder.len && Utils.arrayEqualLen(order, drawOrder.order, len); 27 | } 28 | 29 | @Override public int hashCode() { 30 | int result = Objects.hash(len); 31 | result = 31 * result + Utils.arrayHashCodeLen(order, len); 32 | return result; 33 | } 34 | 35 | public int size() { 36 | return len; 37 | } 38 | 39 | public int ithCardFromTop(int i) { 40 | return order[len - 1 - i]; 41 | } 42 | 43 | public boolean contains(int cardIndex) { 44 | for (int i = 0; i < len; i++) { 45 | if (order[i] == cardIndex) { 46 | return true; 47 | } 48 | } 49 | return false; 50 | } 51 | 52 | public void remove(int cardIndex) { 53 | for (int i = len - 1; i >= 0; i--) { 54 | if (order[i] == cardIndex) { 55 | for (int j = i; j < len; j++) { 56 | order[j] = order[j + 1]; 57 | } 58 | len--; 59 | return; 60 | } 61 | } 62 | } 63 | } 64 | -------------------------------------------------------------------------------- /commands.txt: -------------------------------------------------------------------------------- 1 | Some documentation of common commands (incomplete). 2 | 3 | This controls the draw order: 4 | do ("draw order") 5 | clear ("clear draw order") 6 | p ("pop last card in draw order") 7 | e ("save draw order, end changes and go back") 8 | b ("cancel changes and go back") 9 | 10 | eho ("set original enemy health"): set max health and modify health to match, used at beginning of game 11 | eh ("set enemy health"): set health without changing max health, used at any time if health doesn't match 12 | 13 | em ("set enemy move"): set current enemy move 14 | 15 | "rng off": use this to control rng roll as they happens (e.g. when defect lightning orb hits multiple enemy, you will be prompted on which enemy is hit, or when Begin Battle rolls one of the enemy encounter, you can select the enemy encounter) 16 | "rng on": turn off the above feature and re-enable rng 17 | 18 | "n ": do a MCTS search for the given node count 19 | "nn ": do a MCT search for the given node count until a action with random outcome occurs 20 | "reset": reset the current tree from the result of a "n" or "nn" command 21 | "nn exec": execute the result of the nn call above 22 | "nn execv": same as "nn exec" but doesn't exectue the final action with random outcome 23 | "states": see all the actions taken until the current state 24 | 25 | "save": save current play session to session.txt (saved seed + every command used) 26 | "load": load a saved play session from session.txt 27 | "save ": save session to session_2.txt 28 | "load ": load session from session_2.txt 29 | "save play": special command that give something like 30 | List.of("", ) 31 | 32 | Can be passed to builder to set the initial state. 33 | 34 | builder.set 35 | 36 | Which can be used to specfically train a neural net on some specific state. 37 | Note: need to remove commands before "Pre-Battle" action is chosen -------------------------------------------------------------------------------- /agent/src/main/java/com/alphaStS/GameStep.java: -------------------------------------------------------------------------------- 1 | package com.alphaStS; 2 | 3 | import java.util.List; 4 | 5 | public final class GameStep { 6 | private GameState state; 7 | private final int action; 8 | public boolean takenFromChanceStateCache; 9 | public double trainingWriteCount; 10 | public boolean trainingSkipOpening; 11 | public VArray v; 12 | public List lines; 13 | public boolean isExplorationMove; 14 | public RandomGen searchRandomGenMCTS; 15 | public StringBuilder actionDesc; 16 | public String stateStr; 17 | public int[] nWhenSearchHalfDone; 18 | 19 | GameStep(GameState state, int action) { 20 | this.state = state; 21 | this.action = action; 22 | } 23 | 24 | public GameState state() { 25 | return state; 26 | } 27 | 28 | public void setState(GameState state) { 29 | this.state = state; 30 | } 31 | 32 | public int action() { 33 | return action; 34 | } 35 | 36 | @Override public String toString() { 37 | return "GameStep{" + 38 | "state=" + state + 39 | ", action=" + getActionString() + 40 | '}'; 41 | } 42 | 43 | public GameAction getAction() { 44 | if (action < 0) { 45 | return null; 46 | } 47 | return state().getAction(action()); 48 | } 49 | 50 | public String getActionString() { 51 | if (action < 0) { 52 | return null; 53 | } 54 | return state().getActionString(action()) + (actionDesc == null ? "" : (" (" + actionDesc + ")")); 55 | } 56 | 57 | @Override public boolean equals(Object o) { 58 | if (this == o) 59 | return true; 60 | if (o == null || getClass() != o.getClass()) 61 | return false; 62 | 63 | GameStep gameStep = (GameStep) o; 64 | 65 | if (action != gameStep.action) 66 | return false; 67 | return state != null ? state.equals(gameStep.state) : gameStep.state == null; 68 | } 69 | 70 | @Override public int hashCode() { 71 | int result = state != null ? state.hashCode() : 0; 72 | result = 31 * result + action; 73 | return result; 74 | } 75 | } 76 | -------------------------------------------------------------------------------- /agent/src/test/java/com/alphaStS/AppTest.java: -------------------------------------------------------------------------------- 1 | package com.alphaStS; 2 | 3 | import static org.junit.Assert.assertEquals; 4 | import static org.junit.Assert.assertNotEquals; 5 | 6 | import com.alphaStS.utils.CircularArray; 7 | import org.junit.Test; 8 | 9 | import java.util.Arrays; 10 | 11 | public class AppTest { 12 | @Test 13 | public void testCircularArray() { 14 | var arr = new CircularArray(); 15 | arr.addFirst(3); 16 | arr.addFirst(2); 17 | arr.addFirst(1); 18 | assertEquals(arr.pollFirst().intValue(), 1); 19 | arr.addFirst(0); 20 | arr.addFirst(-1); 21 | assertEquals(arr.pollFirst().intValue(), -1); 22 | assertEquals(arr.pollFirst().intValue(), 0); 23 | assertEquals(arr.pollFirst().intValue(), 2); 24 | assertEquals(arr.pollFirst().intValue(), 3); 25 | 26 | assertEquals(arr.size(), 0); 27 | arr.addLast(1); 28 | arr.addLast(2); 29 | arr.addLast(3); 30 | arr.addLast(4); 31 | assertEquals(arr.size(), 4); 32 | assertEquals(arr.pollFirst().intValue(), 1); 33 | assertEquals(arr.pollFirst().intValue(), 2); 34 | assertEquals(arr.size(), 2); 35 | arr.addLast(5); 36 | arr.addLast(6); 37 | arr.addFirst(0); 38 | assertEquals(arr.size(), 5); 39 | assertEquals(arr.pollFirst().intValue(), 0); 40 | assertEquals(arr.pollFirst().intValue(), 3); 41 | assertEquals(arr.pollFirst().intValue(), 4); 42 | assertEquals(arr.pollFirst().intValue(), 5); 43 | assertEquals(arr.pollFirst().intValue(), 6); 44 | 45 | arr.addLast(1); 46 | arr.addLast(2); 47 | arr.addLast(3); 48 | arr.addLast(4); 49 | assertEquals(arr.size(), 4); 50 | arr.clear(); 51 | assertEquals(arr.size(), 0); 52 | 53 | var arr2 = new CircularArray(); 54 | assertEquals(arr, arr2); 55 | arr.addLast(1); 56 | arr.addLast(2); 57 | arr.addLast(3); 58 | arr.addLast(4); 59 | assertNotEquals(arr, arr2); 60 | arr2.addLast(1); 61 | arr2.addLast(2); 62 | arr2.addLast(3); 63 | arr2.addLast(4); 64 | assertEquals(arr, arr2); 65 | arr2.pollFirst(); 66 | assertNotEquals(arr, arr2); 67 | arr.pollFirst(); 68 | assertEquals(arr, arr2); 69 | } 70 | } 71 | -------------------------------------------------------------------------------- /agent/src/main/java/com/alphaStS/CrashException.java: -------------------------------------------------------------------------------- 1 | package com.alphaStS; 2 | 3 | public class CrashException extends RuntimeException { 4 | private String message; 5 | 6 | private CrashException() { 7 | super(); 8 | } 9 | 10 | @Override 11 | public String getMessage() { 12 | return message; 13 | } 14 | 15 | public static Builder builder() { 16 | return new Builder(); 17 | } 18 | 19 | public static class Builder { 20 | private String message; 21 | private GameState state; 22 | private GameState childState; 23 | private Integer action; 24 | 25 | public Builder withMessage(String message) { 26 | this.message = message; 27 | return this; 28 | } 29 | 30 | public Builder withState(GameState state) { 31 | this.state = state; 32 | return this; 33 | } 34 | 35 | public Builder withChildState(GameState childState) { 36 | this.childState = childState; 37 | return this; 38 | } 39 | 40 | public Builder withAction(int action) { 41 | this.action = action; 42 | return this; 43 | } 44 | 45 | public CrashException build() { 46 | CrashException ex = new CrashException(); 47 | ex.message = buildMessage(); 48 | return ex; 49 | } 50 | 51 | private String buildMessage() { 52 | StringBuilder sb = new StringBuilder(message == null ? "Crash" : message); 53 | if (action != null) { 54 | appendAction(sb); 55 | } 56 | appendState(sb, "State", state); 57 | if (childState != null) { 58 | appendState(sb, "Child state", childState); 59 | } 60 | return sb.toString(); 61 | } 62 | 63 | private void appendAction(StringBuilder sb) { 64 | sb.append("\nAction: ").append(action); 65 | if (state != null) { 66 | sb.append(" ctx=").append(state.getActionCtx()); 67 | try { 68 | sb.append(" desc=").append(state.getActionString(action)); 69 | } catch (Exception ignored) { 70 | sb.append(" desc="); 71 | } 72 | } 73 | } 74 | 75 | private void appendState(StringBuilder sb, String label, GameState value) { 76 | sb.append("\n").append(label).append(": "); 77 | if (value == null) { 78 | sb.append(""); 79 | } else { 80 | sb.append(value); 81 | } 82 | } 83 | } 84 | } 85 | -------------------------------------------------------------------------------- /test_dist.py: -------------------------------------------------------------------------------- 1 | # format 2 | 3 | import platform 4 | import subprocess 5 | import os 6 | import time 7 | from misc import getFlag, getFlagValue 8 | 9 | RUN_SERVER = getFlag('-server') 10 | THREAD_COUNT = int(getFlagValue('-thread', 1)) 11 | BATCH_PER_THREAD = int(getFlagValue('-b', 1)) 12 | PLAY_MATCHES = getFlag('-m') 13 | 14 | sep = ':' 15 | if platform.system() == 'Windows': 16 | sep = ';' 17 | 18 | CLASS_PATH_AGENT = f'./target/classes{sep}{os.getenv("M2_HOME")}/repository/com/microsoft/onnxruntime/onnxruntime/1.10.0/onnxruntime-1.10.0.jar{sep}./src/resources/mallet.jar{sep}./src/resources/mallet-deps.jar{sep}{os.getenv("M2_HOME")}/repository/org/jdom/jdom/1.1/jdom-1.1.jar{sep}{os.getenv("M2_HOME")}/repository/com/fasterxml/jackson/core/jackson-databind/2.12.4/jackson-databind-2.12.4.jar{sep}{os.getenv("M2_HOME")}/repository/com/fasterxml/jackson/core/jackson-annotations/2.12.4/jackson-annotations-2.12.4.jar{sep}{os.getenv("M2_HOME")}/repository/com/fasterxml/jackson/core/jackson-core/2.12.4/jackson-core-2.12.4.jar{sep}{os.getenv("M2_HOME")}/repository/org/apache/commons/commons-compress/1.21/commons-compress-1.21.jar{sep}{os.getenv("M2_HOME")}/repository/org/apache/commons/commons-math3/3.6.1/commons-math3-3.6.1.jar' 19 | 20 | start = time.time() 21 | 22 | if RUN_SERVER: 23 | os.chdir('./agent') 24 | agent_args = ['java', '--add-opens', 'java.base/java.util=ALL-UNNAMED', '-classpath', CLASS_PATH_AGENT, 'com.alphaStS.Main', '--server'] 25 | if THREAD_COUNT > 1: 26 | agent_args += ['-t', str(THREAD_COUNT)] 27 | if BATCH_PER_THREAD > 1: 28 | agent_args += ['-b', str(BATCH_PER_THREAD)] 29 | if platform.system() != 'Windows': 30 | agent_args = agent_args[:1] + ['-Xmx700m'] + agent_args[1:] 31 | p = subprocess.Popen(agent_args, stdout=subprocess.PIPE, stderr=subprocess.PIPE) 32 | while p.poll() is None: 33 | print(p.stdout.readline().decode('ascii'), end='', flush=True) 34 | [print(line.decode('ascii'), end='', flush=True) for line in p.stderr.readlines()] 35 | p.terminate() 36 | print(time.time() - start) 37 | 38 | if PLAY_MATCHES: 39 | os.chdir('./agent') 40 | agent_args = ['java', '--add-opens', 'java.base/java.util=ALL-UNNAMED', '-classpath', CLASS_PATH_AGENT, 'com.alphaStS.Main', '-p', '-g', '-c', '200', '-n', '100'] 41 | if platform.system() != 'Windows': 42 | agent_args = agent_args[:1] + ['-Xmx700m'] + agent_args[1:] 43 | p = subprocess.Popen(agent_args, stdout=subprocess.PIPE, stderr=subprocess.PIPE) 44 | while p.poll() is None: 45 | print(p.stdout.readline().decode('ascii'), end='', flush=True) 46 | [print(line.decode('ascii'), end='', flush=True) for line in p.stderr.readlines()] 47 | print(time.time() - start) 48 | -------------------------------------------------------------------------------- /agent/src/main/java/com/alphaStS/LineOfPlay.java: -------------------------------------------------------------------------------- 1 | package com.alphaStS; 2 | 3 | import java.util.ArrayList; 4 | import java.util.Collections; 5 | import java.util.List; 6 | 7 | import static com.alphaStS.utils.Utils.formatFloat; 8 | 9 | public class LineOfPlay { 10 | State state; 11 | int n; 12 | int depth; 13 | double p_total; 14 | double p_cur; 15 | double q_win; 16 | double q_health; 17 | double q_comb; 18 | boolean internal; 19 | int numberOfActions; 20 | List parentLines; 21 | 22 | public LineOfPlay(State state, double p, int depth, LineOfPlay parentLine, int action) { 23 | this.state = state; 24 | this.p_cur = p; 25 | this.p_total = p; 26 | this.depth = depth; 27 | if (state instanceof GameState s) { 28 | numberOfActions = s.getLegalActions().length; 29 | } else { 30 | numberOfActions = 1; 31 | } 32 | if (parentLine != null) { 33 | parentLines = new ArrayList<>(); 34 | parentLines.add(new Edge(parentLine, action)); 35 | } 36 | } 37 | 38 | public List getActions(GameState s) { 39 | var rand = s.properties.random; 40 | var actions = new ArrayList(); 41 | var line = this; 42 | if (line.parentLines == null) { 43 | var state = (GameState) line.state; 44 | var maxP = 0.0; 45 | var maxAction = -1; 46 | for (int i = 0; i < state.policy.length; i++) { 47 | if (state.policy[i] > maxP) { 48 | maxAction = i; 49 | maxP = state.policy[i]; 50 | } 51 | } 52 | actions.add(maxAction); 53 | return actions; 54 | } 55 | while (line.parentLines != null) { 56 | var r = line.parentLines.size() == 1 ? 0 : rand.nextInt(line.parentLines.size(), RandomGenCtx.Other); 57 | if (((GameState) line.parentLines.get(r).line.state).getAction(line.parentLines.get(r).action).type() != GameActionType.BEGIN_TURN) { 58 | actions.add(line.parentLines.get(r).action); 59 | } 60 | line = line.parentLines.get(r).line; 61 | } 62 | Collections.reverse(actions); 63 | return actions; 64 | } 65 | 66 | @Override public String toString() { 67 | return "LineOfPlay{" + 68 | "n=" + n + 69 | ", p=" + formatFloat(p_cur) + "/" + formatFloat(p_total) + 70 | ", q_comb=" + formatFloat(q_comb / n) + 71 | ", " + state + 72 | ", " + numberOfActions + 73 | '}'; 74 | } 75 | 76 | static record Edge(LineOfPlay line, int action) {} 77 | } 78 | -------------------------------------------------------------------------------- /agent/src/main/java/com/alphaStS/enemy/EnemyListReadOnly.java: -------------------------------------------------------------------------------- 1 | package com.alphaStS.enemy; 2 | 3 | import com.alphaStS.GameState; 4 | 5 | import java.util.Arrays; 6 | import java.util.Collection; 7 | import java.util.Iterator; 8 | 9 | public class EnemyListReadOnly implements Iterable { 10 | protected Enemy[] enemies; 11 | 12 | public EnemyListReadOnly(Enemy... enemies) { 13 | this.enemies = enemies; 14 | } 15 | 16 | public EnemyListReadOnly(Collection enemies) { 17 | this.enemies = enemies.toArray(new Enemy[0]); 18 | } 19 | 20 | public EnemyListReadOnly(EnemyListReadOnly enemyList) { 21 | this.enemies = Arrays.copyOf(enemyList.enemies, enemyList.enemies.length); 22 | } 23 | 24 | public EnemyReadOnly get(int i) { 25 | return enemies[i]; 26 | } 27 | 28 | @Override public boolean equals(Object o) { 29 | if (this == o) 30 | return true; 31 | if (o == null || getClass() != o.getClass()) 32 | return false; 33 | EnemyListReadOnly that = (EnemyListReadOnly) o; 34 | return Arrays.equals(enemies, that.enemies); 35 | } 36 | 37 | @Override public int hashCode() { 38 | return Arrays.hashCode(enemies); 39 | } 40 | 41 | public int size() { 42 | return enemies.length; 43 | } 44 | 45 | public int find(EnemyReadOnly enemy) { 46 | for (int i = 0; i < enemies.length; i++) { 47 | if (enemies[i] instanceof Enemy.MergedEnemy m) { 48 | if (m.currentEnemy.equals(enemy)) { 49 | return i; 50 | } 51 | // for (int j = 0; j < m.possibleEnemies.size(); j++) { 52 | // if (m.possibleEnemies.get(j).equals(enemy)) { 53 | // return i; 54 | // } 55 | // } 56 | } else if (enemies[i].equals(enemy)) { 57 | return i; 58 | } 59 | } 60 | return -1; 61 | } 62 | 63 | @Override public String toString() { 64 | return Arrays.toString(enemies); 65 | } 66 | 67 | private static class EnemyListIterator implements Iterator { 68 | Enemy[] enemies; 69 | int i = 0; 70 | 71 | protected EnemyListIterator(Enemy[] enemies) { 72 | this.enemies = enemies; 73 | } 74 | 75 | @Override public boolean hasNext() { 76 | return i < enemies.length; 77 | } 78 | 79 | @Override public EnemyReadOnly next() { 80 | return enemies[i++]; 81 | } 82 | } 83 | 84 | @Override public Iterator iterator() { 85 | return new EnemyListReadOnly.EnemyListIterator(enemies); 86 | } 87 | } 88 | 89 | -------------------------------------------------------------------------------- /agent/src/main/java/com/alphaStS/enemy/EnemyList.java: -------------------------------------------------------------------------------- 1 | package com.alphaStS.enemy; 2 | 3 | import java.util.Collection; 4 | import java.util.Iterator; 5 | 6 | public class EnemyList extends EnemyListReadOnly { 7 | protected long cloned; 8 | 9 | public EnemyList(EnemyList enemies) { 10 | super(enemies); 11 | } 12 | 13 | public EnemyList(Collection enemies) { 14 | super(enemies); 15 | } 16 | 17 | public void replace(int i, Enemy enemy) { 18 | cloned |= (1L << i); 19 | enemies[i] = enemy.copy(); 20 | } 21 | 22 | public Enemy getForWrite(int i) { 23 | if ((cloned & (1L << i)) == 0) { 24 | enemies[i] = enemies[i].copy(); 25 | cloned |= (1L << i); 26 | } 27 | return enemies[i]; 28 | } 29 | 30 | public EnemyListAliveIterable iterateOverAlive() { 31 | return new EnemyListAliveIterable(this); 32 | } 33 | 34 | private static class EnemyListAliveIterator implements Iterator { 35 | EnemyList enemies; 36 | int i = 0; 37 | 38 | protected EnemyListAliveIterator(EnemyList enemies) { 39 | this.enemies = enemies; 40 | } 41 | 42 | @Override public boolean hasNext() { 43 | while (i < enemies.enemies.length) { 44 | if (enemies.enemies[i].isAlive()) { 45 | return true; 46 | } 47 | i++; 48 | } 49 | return false; 50 | } 51 | 52 | @Override public Enemy next() { 53 | return enemies.getForWrite(i++); 54 | } 55 | } 56 | 57 | public static class EnemyListAliveIterable implements Iterable { 58 | EnemyList enemies; 59 | 60 | public EnemyListAliveIterable(EnemyList enemies) { 61 | this.enemies = enemies; 62 | } 63 | 64 | @Override public Iterator iterator() { 65 | return new EnemyListAliveIterator(enemies); 66 | } 67 | } 68 | 69 | public EnemyListAllIterable iterateOverAll() { 70 | return new EnemyListAllIterable(this); 71 | } 72 | 73 | private static class EnemyListAllIterator implements Iterator { 74 | EnemyList enemies; 75 | int i = 0; 76 | 77 | protected EnemyListAllIterator(EnemyList enemies) { 78 | this.enemies = enemies; 79 | } 80 | 81 | @Override public boolean hasNext() { 82 | return i < enemies.enemies.length; 83 | } 84 | 85 | @Override public Enemy next() { 86 | return enemies.getForWrite(i++); 87 | } 88 | } 89 | 90 | public static class EnemyListAllIterable implements Iterable { 91 | EnemyList enemies; 92 | 93 | public EnemyListAllIterable(EnemyList enemies) { 94 | this.enemies = enemies; 95 | } 96 | 97 | @Override public Iterator iterator() { 98 | return new EnemyListAllIterator(enemies); 99 | } 100 | } 101 | } 102 | -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | import subprocess 2 | import os 3 | import json 4 | import numpy 5 | import struct 6 | import sys 7 | print(sys.byteorder) 8 | 9 | lens_raw = subprocess.run(['java', '-classpath', f'F:/git/alphaStS/agent/target/classes;{os.getenv("M2_HOME")}/repository/com/microsoft/onnxruntime/onnxruntime/1.10.0/onnxruntime-1.10.0.jar', 10 | 'com.alphaStS.Main', '--get-lengths'], capture_output=True).stdout.decode('utf-8').strip() 11 | lens_data = json.loads(lens_raw) 12 | input_len = int(lens_data['inputLength']) 13 | num_of_actions = int(lens_data['policyLength']) 14 | agent_output = subprocess.run(['java', '-classpath', f'F:/git/alphaStS/agent/target/classes;{os.getenv("M2_HOME")}/repository/com/microsoft/onnxruntime/onnxruntime/1.10.0/onnxruntime-1.10.0.jar', 15 | 'com.alphaStS.Main', '-tm', '-c', '1', '-t', '-dir', './tmp'], capture_output=True).stdout 16 | split1 = agent_output.find(b'--------------------') 17 | print(agent_output[0:split1+20].decode('ascii')) 18 | split2 = agent_output.find(b'--------------------', split1 + 20) 19 | print(agent_output[split1+20:split2+20].decode('ascii')) 20 | agent_output = agent_output[split2+20:] 21 | print(agent_output[split2+20:]) 22 | print(len(agent_output)) 23 | 24 | agent_output = subprocess.run(['java', '-classpath', f'./agent/target/classes;{os.getenv("M2_HOME")}/repository/com/microsoft/onnxruntime/onnxruntime/1.10.0/onnxruntime-1.10.0.jar;./agent/src/resources/mallet.jar;./agent/src/resources/mallet-deps.jar', 'com.alphaStS.Main', '-t', '-dir', './tmp'], capture_output=True) 25 | if len(agent_output.stderr) > 0: 26 | print(agent_output.stderr.decode('ascii')) 27 | raise "agent error" 28 | agent_output = agent_output.stdout 29 | 30 | split = agent_output.find(b'--------------------') 31 | # print(agent_output[0:split + 20].decode('ascii')) 32 | agent_output = agent_output[split + 20:] 33 | 34 | training_pool = [] 35 | offset = 0 36 | # technically N^2 37 | while len(agent_output) > 0: 38 | x_fmt = '>' + ('f' * input_len) 39 | x = struct.unpack(x_fmt, agent_output[0:4 * input_len]) 40 | [v_health, v_win] = struct.unpack('>ff', agent_output[4 * input_len:4 * (input_len + 2)]) 41 | p_fmt = '>' + ('f' * num_of_actions) 42 | p = struct.unpack(p_fmt, agent_output[4 * (input_len + 2):4 * (input_len + 2 + num_of_actions)]) 43 | agent_output = agent_output[4 * (input_len + 2 + num_of_actions):] 44 | training_pool.append([list(x), [v_health], [v_win], list(p)]) 45 | if len(agent_output) != 0: 46 | print(f'{len(agent_output)} bytes remaining for decoding') 47 | raise "agent error" 48 | 49 | 50 | agent_output = subprocess.run(['java', '-classpath', f'./agent/target/classes;{os.getenv("M2_HOME")}/repository/com/microsoft/onnxruntime/onnxruntime/1.10.0/onnxruntime-1.10.0.jar;./agent/src/resources/mallet.jar;./agent/src/resources/mallet-deps.jar', 'com.alphaStS.Main', '-i', '-dir', './tmp'], capture_output=True) 51 | 52 | # print(training_pool) 53 | # x_fmt = '>' + ('f' * input_len) 54 | # x = struct.unpack(x_fmt, agent_output[0:4*input_len]) 55 | # [v_health, v_win] = struct.unpack('>ff', agent_output[4*input_len:4*(input_len+2)]) 56 | # p_fmt = '>' + ('f' * num_of_actions) 57 | # p = struct.unpack(p_fmt, agent_output[4*(input_len+2):4*(input_len+2+num_of_actions)]) 58 | # i = struct.unpack('>i', agent_output[4*(input_len+2+num_of_actions):]) 59 | # print(list(x), v_health, v_win, list(p)) 60 | # print(struct.pack('f', 0.1)) 61 | # print(struct.pack('f', 0.2)) 62 | # print(struct.pack('f', 0.3)) 63 | # print(struct.pack('f', 0)) 64 | # t = numpy.asarray(2, dtype='f') / numpy.asarray(10, dtype='f') 65 | # print(f'{t[0]:.10f}) 66 | -------------------------------------------------------------------------------- /agent/src/main/java/com/alphaStS/SearchFrontier.java: -------------------------------------------------------------------------------- 1 | package com.alphaStS; 2 | 3 | import java.util.ArrayList; 4 | import java.util.HashMap; 5 | import java.util.List; 6 | import java.util.Map; 7 | 8 | import static com.alphaStS.utils.Utils.formatFloat; 9 | 10 | public class SearchFrontier { 11 | Map lines = new HashMap<>(); 12 | int total_n; 13 | 14 | public void addLine(LineOfPlay line) { 15 | lines.put(line.state, line); 16 | } 17 | 18 | public void removeLine(LineOfPlay line) { 19 | lines.remove(line.state); 20 | } 21 | 22 | public LineOfPlay getLine(State state) { 23 | return lines.get(state); 24 | } 25 | 26 | public LineOfPlay getBestLine() { 27 | var max_n = 0; 28 | LineOfPlay maxLine = null; 29 | for (var line : lines.values()) { 30 | if (line.internal) { 31 | continue; 32 | } 33 | if (line.state instanceof GameState) { 34 | GameState state = (GameState) line.state; 35 | if (state.terminalAction == -1234) { 36 | maxLine = line; 37 | break; 38 | } 39 | } 40 | if (line.n > max_n) { 41 | maxLine = line; 42 | max_n = maxLine.n; 43 | } else if (line.n == max_n) { 44 | if (line.q_comb / line.n > maxLine.q_comb / maxLine.n) { 45 | maxLine = line; 46 | } 47 | } 48 | } 49 | return maxLine; 50 | } 51 | 52 | @Override public String toString() { 53 | var str = "SearchFrontier{total_n=" + total_n + ", lines=[\n "; 54 | str += String.join("\n ", lines.values().stream().sorted((a, b) -> { 55 | if (a.internal == b.internal) { 56 | return -Integer.compare(a.n, b.n); 57 | } else if (a.internal) { 58 | return 1; 59 | } else { 60 | return -1; 61 | } 62 | }).map(LineOfPlay::toString).toList()); 63 | str += "\n]}"; 64 | return str; 65 | } 66 | 67 | public boolean isOneOfBestLine(LineOfPlay line) { 68 | var bestLine = getBestLine(); 69 | // int i = actions.size() - 1; 70 | // while (line.parentLines != null) { 71 | // if (i < 0) { 72 | // return false; 73 | // } 74 | // boolean found = false; 75 | // for (LineOfPlay.Edge parentLine : line.parentLines) { 76 | // if (parentLine.action() == actions) { 77 | // 78 | // } 79 | // } 80 | // i--; 81 | // } 82 | // if (i >= 0) { 83 | // return false; 84 | // } 85 | return line.state.equals(bestLine.state); 86 | } 87 | 88 | public List getSortedLinesAsStrings(GameState state) { 89 | return lines.values().stream().sorted((a, b) -> { 90 | if (a.internal == b.internal) { 91 | return -Integer.compare(a.n, b.n); 92 | } else if (a.internal) { 93 | return 1; 94 | } else { 95 | return -1; 96 | } 97 | }).map((x) -> { 98 | var tmpS = state.clone(false); 99 | var actions = x.getActions(tmpS); 100 | var strings = new ArrayList(); 101 | for (var _action : actions) { 102 | if (tmpS.getActionString(_action).equals("Begin Turn")) { 103 | continue; 104 | } 105 | strings.add(tmpS.getActionString(_action)); 106 | tmpS = tmpS.doAction(_action); 107 | } 108 | return String.join(", ", strings) + ": n=" + x.n + ", p=" + formatFloat(x.p_cur) + ", q=" + formatFloat(x.q_comb / x.n) + 109 | ", q_win=" + formatFloat(x.q_win / x.n) + ", q_health=" + formatFloat(x.q_health / x.n) + 110 | " (" + formatFloat(x.q_health / x.n * state.getPlayeForRead().getMaxHealth()) + ")"; 111 | }).toList(); 112 | } 113 | } 114 | -------------------------------------------------------------------------------- /agent/src/main/java/com/alphaStS/utils/CircularArray.java: -------------------------------------------------------------------------------- 1 | package com.alphaStS.utils; 2 | 3 | import java.util.Arrays; 4 | 5 | // ArrayDeque doesn't implement equals 6 | public class CircularArray { 7 | public Object[] arr = new Object[3]; 8 | public int start = 0; 9 | public int end = 0; 10 | public int size = 0; 11 | 12 | public CircularArray() {} 13 | 14 | public CircularArray(CircularArray other) { 15 | arr = Arrays.copyOf(other.arr, other.arr.length); 16 | start = other.start; 17 | end = other.end; 18 | size = other.size; 19 | } 20 | 21 | public int size() { 22 | return size; 23 | } 24 | 25 | public void addFirst(T ob) { 26 | int next; 27 | if (size == arr.length) { 28 | var newArr = new Object[arr.length * 2]; 29 | if (start == 0) { 30 | System.arraycopy(arr, 0, newArr, 0, arr.length); 31 | } else { 32 | System.arraycopy(arr, start, newArr, 0, arr.length - start); 33 | System.arraycopy(arr, 0, newArr, arr.length - start, end); 34 | } 35 | arr = newArr; 36 | next = arr.length - 1; 37 | end = arr.length / 2; 38 | } else { 39 | next = (start - 1) < 0 ? arr.length - 1: start - 1; 40 | } 41 | arr[next] = ob; 42 | start = next; 43 | size++; 44 | } 45 | 46 | public void addLast(T ob) { 47 | int next = end % arr.length; 48 | if (size == arr.length) { 49 | var newArr = new Object[arr.length * 2]; 50 | if (start == 0) { 51 | System.arraycopy(arr, 0, newArr, 0, arr.length); 52 | } else { 53 | System.arraycopy(arr, start, newArr, 0, arr.length - start); 54 | System.arraycopy(arr, 0, newArr, arr.length - start, end); 55 | } 56 | start = 0; 57 | next = arr.length; 58 | arr = newArr; 59 | } 60 | arr[next] = ob; 61 | end = next + 1; 62 | size++; 63 | } 64 | 65 | @SuppressWarnings("unchecked") 66 | public T pollFirst() { 67 | if (size == 0) { 68 | return null; 69 | } 70 | Object o = arr[start]; 71 | arr[start] = null; 72 | start = (start + 1) % arr.length; 73 | size--; 74 | return (T) o; 75 | } 76 | 77 | @SuppressWarnings("unchecked") 78 | public T peekFirst() { 79 | if (size == 0) { 80 | return null; 81 | } 82 | return (T) arr[start]; 83 | } 84 | 85 | public void clear() { 86 | for (int i = end > start ? start : 0; i < end; i++) { 87 | arr[i] = null; 88 | } 89 | if (end < start) { 90 | for (int i = start; i < arr.length; i++) { 91 | arr[i] = null; 92 | } 93 | } 94 | start = 0; 95 | end = 0; 96 | size = 0; 97 | } 98 | 99 | @Override public String toString() { 100 | return "CircularArray{" + 101 | "arr=" + Arrays.toString(arr) + 102 | ", start=" + start + 103 | ", end=" + end + 104 | ", size=" + size + 105 | '}'; 106 | } 107 | 108 | @Override public boolean equals(Object o) { 109 | if (this == o) 110 | return true; 111 | if (o == null || getClass() != o.getClass()) 112 | return false; 113 | CircularArray that = (CircularArray) o; 114 | if (size() != that.size()) { 115 | return false; 116 | } 117 | for (int i = start, j = that.start; i < start + size; i++, j++) { 118 | if (!that.arr[j % that.arr.length].equals(arr[i % arr.length])) { 119 | return false; 120 | } 121 | } 122 | return true; 123 | } 124 | 125 | @Override public int hashCode() { 126 | int result = 1237; 127 | int end2 = end > start ? end : end + arr.length; 128 | for (int i = start; i < end2; i++) { 129 | result = 31 * result + arr[i % arr.length].hashCode(); 130 | } 131 | return result; 132 | } 133 | } 134 | -------------------------------------------------------------------------------- /agent/src/main/java/com/alphaStS/utils/BigRational.java: -------------------------------------------------------------------------------- 1 | package com.alphaStS.utils; 2 | 3 | import java.math.BigInteger; 4 | 5 | /** 6 | * Immutable class representing the rationals. Backed by BigInteger i.e. absolutely no overflow. Support very basic arithmetic. 7 | * The rational will always be in lowest reduced form. 8 | */ 9 | public final class BigRational implements Comparable { 10 | final public static BigRational ZERO = BigRational.valueOf(0); 11 | final public static BigRational ONE = BigRational.valueOf(1); 12 | 13 | final private BigInteger num; 14 | final private BigInteger den; 15 | 16 | /** 17 | * Constructs a rational number with the given numerator and denominator. 18 | * 19 | * @param numerator 20 | * @param denominator 21 | */ 22 | public BigRational(long numerator, long denominator) { 23 | this(BigInteger.valueOf(numerator), BigInteger.valueOf(denominator)); 24 | } 25 | 26 | /** 27 | * Constructs a rational number with the given numerator and denominator. 28 | * @param numerator 29 | * @param denominator 30 | */ 31 | public BigRational(BigInteger numerator, BigInteger denominator) { 32 | int c = denominator.compareTo(BigInteger.ZERO); 33 | if (c == 0) throw new ArithmeticException("BigRational: denominator cannot be 0"); 34 | if (c < 0) { 35 | denominator = denominator.negate(); 36 | numerator = denominator.negate(); 37 | } 38 | BigInteger a = numerator.gcd(denominator); 39 | den = denominator.divide(a); 40 | num = numerator.divide(a); 41 | } 42 | 43 | public static BigRational valueOf(BigInteger n) { 44 | return new BigRational(n, BigInteger.ONE); 45 | } 46 | 47 | public static BigRational valueOf(long n) { 48 | return new BigRational(BigInteger.valueOf(n), BigInteger.ONE); 49 | } 50 | 51 | public BigInteger getNumerator() { 52 | return num; 53 | } 54 | 55 | public BigInteger getDenominator() { 56 | return den; 57 | } 58 | 59 | public BigRational inverse() { 60 | return new BigRational(den, num); 61 | } 62 | 63 | public BigRational negate() { 64 | return new BigRational(num.negate(), den); 65 | } 66 | 67 | public BigRational add(BigRational rat) { 68 | return new BigRational(this.num.multiply(rat.den).add(rat.num.multiply(this.den)), this.den.multiply(rat.den)); 69 | } 70 | 71 | public BigRational subtract(BigRational rat) { 72 | return new BigRational(this.num.multiply(rat.den).subtract(rat.num.multiply(this.den)), this.den.multiply(rat.den)); 73 | } 74 | 75 | public BigRational multiply(BigRational rat) { 76 | return new BigRational(rat.num.multiply(this.num), rat.den.multiply(this.den)); 77 | } 78 | 79 | public BigRational divide(BigRational rat) { 80 | return new BigRational(this.num.multiply(rat.den), this.den.multiply(rat.num)); 81 | } 82 | 83 | 84 | @Override 85 | public int compareTo(BigRational o) { 86 | int c1 = num.compareTo(o.num); 87 | int c2 = den.compareTo(o.den); 88 | if (c1 == 0) { 89 | if (c2 > 0) return -1; 90 | else if (c2 < 0) return 1; 91 | return 0; 92 | } 93 | if (c2 == 0) { 94 | return c1 > 0 ? 1 : -1; 95 | } 96 | return num.multiply(o.den).compareTo(o.num.multiply(den)) > 0 ? 1 : -1; 97 | } 98 | 99 | public double toDouble() { 100 | return num.longValueExact() / (double) den.longValueExact(); 101 | } 102 | 103 | /** 104 | * Return a copy of the current rational number. 105 | * @return a copy of current rational 106 | */ 107 | public BigRational copy() { 108 | return new BigRational(num, den); 109 | } 110 | 111 | @Override 112 | public boolean equals(Object o) { 113 | if (this == o) return true; 114 | if (o == null || getClass() != o.getClass()) return false; 115 | 116 | BigRational rational = (BigRational) o; 117 | 118 | if (!den.equals(rational.den)) return false; 119 | return !num.equals(rational.num); 120 | } 121 | 122 | @Override 123 | public int hashCode() { 124 | int result = num.hashCode(); 125 | result = 31 * result + den.hashCode(); 126 | return result; 127 | } 128 | 129 | @Override 130 | public String toString() { 131 | return "BigRational{" + 132 | "num=" + num + 133 | ", den=" + den + 134 | '}'; 135 | } 136 | } 137 | -------------------------------------------------------------------------------- /agent/pom.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 5 | 4.0.0 6 | 7 | com.alphaStS 8 | alphaStS 9 | 1.0-SNAPSHOT 10 | 11 | alphaStS 12 | 13 | http://www.example.com 14 | 15 | 16 | 17 | data-local 18 | data 19 | file://${project.basedir}/src/resources 20 | 21 | 22 | 23 | 24 | UTF-8 25 | 17 26 | 17 27 | 17 28 | 29 | 30 | 31 | 32 | junit 33 | junit 34 | 4.13.1 35 | test 36 | 37 | 38 | com.microsoft.onnxruntime 39 | onnxruntime 40 | 1.10.0 41 | 42 | 43 | cc.mallet 44 | mallet 45 | 2.0.8 46 | 47 | 48 | com.fasterxml.jackson.core 49 | jackson-databind 50 | 2.18.4 51 | 52 | 53 | org.apache.commons 54 | commons-compress 55 | 1.26.0 56 | 57 | 58 | commons-codec 59 | commons-codec 60 | 1.18.0 61 | 62 | 63 | org.apache.commons 64 | commons-math3 65 | 3.6.1 66 | 67 | 68 | one.util 69 | streamex 70 | 0.8.3 71 | 72 | 73 | 74 | 75 | 76 | 77 | 78 | 79 | maven-clean-plugin 80 | 3.1.0 81 | 82 | 83 | 84 | maven-resources-plugin 85 | 3.0.2 86 | 87 | 88 | maven-compiler-plugin 89 | 3.8.0 90 | 91 | 92 | maven-surefire-plugin 93 | 2.22.1 94 | 95 | 96 | maven-jar-plugin 97 | 3.0.2 98 | 99 | 100 | maven-install-plugin 101 | 2.5.2 102 | 103 | 104 | maven-deploy-plugin 105 | 2.8.2 106 | 107 | 108 | 109 | maven-site-plugin 110 | 3.7.1 111 | 112 | 113 | maven-project-info-reports-plugin 114 | 3.0.0 115 | 116 | 117 | 118 | 119 | 120 | -------------------------------------------------------------------------------- /agent/src/main/java/com/alphaStS/utils/CounterStat.java: -------------------------------------------------------------------------------- 1 | package com.alphaStS.utils; 2 | 3 | import com.alphaStS.GameState; 4 | import com.fasterxml.jackson.annotation.JsonIgnoreProperties; 5 | 6 | import java.util.ArrayList; 7 | import java.util.List; 8 | import java.util.StringJoiner; 9 | 10 | @JsonIgnoreProperties(ignoreUnknown = true) 11 | public class CounterStat { 12 | int counterIdx; 13 | String counterName; 14 | List counterFrequency = new ArrayList<>(); 15 | int cmpWin; 16 | int cmpWinAmt; 17 | int cmpLoss; 18 | int cmpLossAmt; 19 | int counterLen = 1; 20 | boolean showFrequency = false; 21 | 22 | public CounterStat() {} // for jackson 23 | 24 | public CounterStat(int idx, String name) { 25 | counterIdx = idx; 26 | counterName = name; 27 | } 28 | 29 | public CounterStat(int idx, int len, String name) { 30 | counterIdx = idx; 31 | counterName = name; 32 | counterLen = len; 33 | } 34 | 35 | public CounterStat copy() { 36 | return new CounterStat(counterIdx, counterName); 37 | } 38 | 39 | public void add(CounterStat counterStat) { 40 | for (int i = counterFrequency.size(); i <= counterStat.counterFrequency.size(); i++) { 41 | counterFrequency.add(0); 42 | } 43 | for (int i = 0; i < counterStat.counterFrequency.size(); i++) { 44 | counterFrequency.set(i, counterFrequency.get(i) + counterStat.counterFrequency.get(i)); 45 | } 46 | cmpWin += counterStat.cmpWin; 47 | cmpWinAmt += counterStat.cmpWinAmt; 48 | cmpLoss += counterStat.cmpLoss; 49 | cmpLossAmt += counterStat.cmpLossAmt; 50 | } 51 | 52 | private int getCounter(GameState state) { 53 | if (counterLen == 1) { 54 | return state.getCounterForRead()[counterIdx]; 55 | } else { 56 | for (int i = 0; i < counterLen; i++) { 57 | if (state.getCounterForRead()[i] > 0) { 58 | return i; 59 | } 60 | } 61 | return 0; 62 | } 63 | } 64 | 65 | public void add(GameState state) { 66 | int counter = state.getCounterForRead()[counterIdx]; 67 | for (int i = counterFrequency.size(); i <= counter; i++) { 68 | counterFrequency.add(0); 69 | } 70 | counterFrequency.set(counter, counterFrequency.get(counter) + 1); 71 | } 72 | 73 | public void addComparison(GameState state, GameState state2) { 74 | var counter1 = state.getCounterForRead()[counterIdx]; 75 | var counter2 = state2.getCounterForRead()[counterIdx]; 76 | if ((state.isTerminal() == 1 && state2.isTerminal() == 1) && counter1 != counter2) { 77 | if (counter1 > counter2) { 78 | cmpWin++; 79 | cmpWinAmt += counter1 - counter2; 80 | } else { 81 | cmpLoss++; 82 | cmpLossAmt += counter2 - counter1; 83 | } 84 | } 85 | } 86 | 87 | public void printStat(String indent, int n) { 88 | long totalCounterAmt = 0; 89 | for (int i = 0; i < counterFrequency.size(); i++) { 90 | totalCounterAmt += counterFrequency.get(i) * i; 91 | } 92 | if (counterLen > 1 || showFrequency) { 93 | StringJoiner sj = new StringJoiner(", "); 94 | for (int i = 0; i < counterFrequency.size(); i++) { 95 | if (counterFrequency.get(i) > 0) { 96 | sj.add(i + ": " + counterFrequency.get(i) + " (" + String.format("%.3f", 100.0 * counterFrequency.get(i) / n) + "%)"); 97 | } 98 | } 99 | System.out.println(indent + "Average " + counterName + " Counter: " + String.format("%.5f", ((double) totalCounterAmt) / n) + " (" + sj + ")"); 100 | } else { 101 | System.out.println(indent + "Average " + counterName + " Counter: " + String.format("%.5f", ((double) totalCounterAmt) / n)); 102 | } 103 | } 104 | 105 | public CounterStat setShowFrequency(boolean showFrequency) { 106 | this.showFrequency = showFrequency; 107 | return this; 108 | } 109 | 110 | public void printCmpStat(String indent) { 111 | System.out.println(indent + "Win/Loss " + counterName + ": " + cmpWin + "/" + cmpLoss + " (" + cmpWinAmt / (double) cmpWin + "/" + cmpLossAmt / (double) cmpLoss + "/" + (cmpWinAmt - cmpLossAmt) / (double) (cmpWin + cmpLoss) + ")"); 112 | } 113 | } 114 | -------------------------------------------------------------------------------- /upload_server.py: -------------------------------------------------------------------------------- 1 | # format 2 | 3 | import subprocess 4 | import os 5 | from misc import getFlag, getFlagValue 6 | from os.path import expanduser 7 | home = expanduser("~") 8 | 9 | # This Python script updates and deploys alphaStS to multiple linux servers. It performs the following steps: 10 | # - Adds all changes to the git repository 11 | # - Commits the changes with a temporary commit message 12 | # - For each specified IP address, it creates a directory "alphaStS" if it doesn't exist already 13 | # - Copies the .git directory to the remote server using rsync 14 | # - Executes a sequence of commands on the remote server: 15 | # - Kills any running Java processes 16 | # - Resets the git repository to the latest commit 17 | # - Builds and installs a Java project using Maven 18 | # - Runs a Python script and redirects output to a log file 19 | # - If any exception occurs during the deployment, the script rolls back the last commit and raises the exception again. 20 | 21 | ips = getFlagValue('-ip', '').split(',') 22 | print(ips) 23 | 24 | f = open(f'{home}\\alphaStSServers.txt', 'w') 25 | 26 | p = subprocess.Popen(f'git add -A', stdout=subprocess.PIPE, stderr=subprocess.PIPE) 27 | while p.poll() is None: 28 | print(p.stdout.readline().decode('ascii'), end='') 29 | [print(line.decode('ascii'), end='') for line in p.stderr.readlines()] 30 | 31 | p = subprocess.Popen(f'git commit -m "tmp"', stdout=subprocess.PIPE, stderr=subprocess.PIPE) 32 | while p.poll() is None: 33 | print(p.stdout.readline().decode('ascii'), end='') 34 | [print(line.decode('ascii'), end='') for line in p.stderr.readlines()] 35 | 36 | git_command = ["git", "rev-parse", "--abbrev-ref", "HEAD"] 37 | process = subprocess.Popen(git_command, stdout=subprocess.PIPE) 38 | output, _ = process.communicate() 39 | branch_name = output.decode("utf-8").strip() 40 | 41 | try: 42 | for ip in ips: 43 | print(f'*********** {ip} ************************************************') 44 | if ip.find('@192.168.1.') >= 0: 45 | username = ip[:ip.find('@192.168.1.')] 46 | ip = ip[ip.find('@192.168.1.') + 1:] 47 | f.write(f'{ip}:4000\n') 48 | p = subprocess.Popen(f'ssh -o StrictHostKeyChecking=no {username}@{ip} "F: & cd git\\alphaStS & git fetch --all & git reset --hard HEAD & git checkout origin/{branch_name} & git reset --hard origin/{branch_name} & cd agent & mvn install & cd .. & taskkill /F /IM java.exe"', stdout=subprocess.PIPE, stderr=subprocess.PIPE) 49 | while p.poll() is None: 50 | print(p.stdout.readline().decode('ascii'), end='') 51 | [print(line.decode('ascii'), end='') for line in p.stderr.readlines()] 52 | else: 53 | f.write(f'{ip}:4000\n') 54 | p = subprocess.Popen(f'ssh -o StrictHostKeyChecking=no -i {home}\\Downloads\\Test.pem ec2-user@{ip} "mkdir -p ./alphaStS"', stdout=subprocess.PIPE, stderr=subprocess.PIPE) 55 | while p.poll() is None: 56 | print(p.stdout.readline().decode('ascii'), end='') 57 | [print(line.decode('ascii'), end='') for line in p.stderr.readlines()] 58 | 59 | p = subprocess.Popen(f'bash -c "rsync -r -P --delete .git/ -e \'ssh -o StrictHostKeyChecking=no -i /root/.ssh/svr1.pem\' ec2-user@{ip}:/home/ec2-user/alphaStS/.git/"', stdout=subprocess.PIPE, stderr=subprocess.PIPE) 60 | while p.poll() is None: 61 | print(p.stdout.readline().decode('ascii'), end='') 62 | [print(line.decode('ascii'), end='') for line in p.stderr.readlines()] 63 | 64 | p = subprocess.Popen(f'ssh -o StrictHostKeyChecking=no -i {home}\\Downloads\\Test.pem ec2-user@{ip} "bash -c \'killall -9 java; cd ~/alphaStS; git reset --hard; cd agent; mvn install; cd ..; python3 test_dist.py -server > /tmp/alphaStS.log 2>&1 & disown -a; exit\'"', stdout=subprocess.PIPE, stderr=subprocess.PIPE) 65 | while p.poll() is None: 66 | print(p.stdout.readline().decode('ascii'), end='') 67 | [print(line.decode('ascii'), end='') for line in p.stderr.readlines()] 68 | except e: 69 | p = subprocess.Popen(f'git reset HEAD^', stdout=subprocess.PIPE, stderr=subprocess.PIPE) 70 | while p.poll() is None: 71 | print(p.stdout.readline().decode('ascii'), end='') 72 | [print(line.decode('ascii'), end='') for line in p.stderr.readlines()] 73 | raise e 74 | 75 | p = subprocess.Popen(f'git reset HEAD^', stdout=subprocess.PIPE, stderr=subprocess.PIPE) 76 | while p.poll() is None: 77 | print(p.stdout.readline().decode('ascii'), end='') 78 | [print(line.decode('ascii'), end='') for line in p.stderr.readlines()] 79 | 80 | f.close() 81 | -------------------------------------------------------------------------------- /ml_conv.py: -------------------------------------------------------------------------------- 1 | # format 2 | 3 | import random as rand 4 | import platform 5 | import time 6 | import math 7 | import json 8 | import re 9 | import os 10 | import subprocess 11 | import struct 12 | import shutil 13 | import tensorflow as tf 14 | import numpy as np 15 | import lz4.frame 16 | from tensorflow import keras 17 | from tensorflow.keras import layers 18 | import tf2onnx 19 | from misc import getFlag, getFlagValue 20 | 21 | # rand.seed(5) 22 | # np.random.seed(5) 23 | # tf.random.set_seed(5) 24 | 25 | DO_TRAINING = getFlag('-training') 26 | SKIP_TRAINING_MATCHES = getFlag('-s') 27 | PLAY_A_GAME = getFlag('-p') 28 | PLAY_MATCHES = getFlag('-m') 29 | NUMBER_OF_THREADS = int(getFlagValue('-t', 1)) 30 | NUMBER_OF_THREADS_TRAINING = int(getFlagValue('-tt', 0)) 31 | ITERATION_COUNT = int(getFlagValue('-c', 5)) 32 | Z_TRAIN_WINDOW_END = int(getFlagValue('-z', -1)) 33 | NODE_COUNT = int(getFlagValue('-n', 1000)) 34 | SAVES_DIR = getFlagValue('-dir', './saves') 35 | SKIP_FIRST = getFlag('-skip_first') 36 | USE_KAGGLE = getFlagValue('-kaggle') 37 | KAGGLE_USER_NAME = getFlagValue('-kaggle_user', None) 38 | KAGGLE_DATASET_NAME = 'dataset' 39 | 40 | if NUMBER_OF_THREADS_TRAINING > 0: 41 | tf.config.threading.set_intra_op_parallelism_threads(1) 42 | tf.config.threading.set_inter_op_parallelism_threads(1) 43 | os.environ["OMP_NUM_THREADS"] = f"1" 44 | os.environ['TF_NUM_INTEROP_THREADS'] = f"1" 45 | os.environ['TF_NUM_INTRAOP_THREADS'] = f"1" 46 | 47 | 48 | def convertToOnnx(model, input_len, output_dir): 49 | spec = (tf.TensorSpec((None, input_len), tf.float32, name="input"),) 50 | output_path = output_dir + "/model.onnx" 51 | model_proto, _ = tf2onnx.convert.from_keras(model, input_signature=spec, opset=13, output_path=output_path) 52 | 53 | sep = ':' 54 | if platform.system() == 'Windows': 55 | sep = ';' 56 | 57 | CLASS_PATH = f'./agent/target/classes{sep}{os.getenv("M2_HOME")}/repository/com/microsoft/onnxruntime/onnxruntime/1.10.0/onnxruntime-1.10.0.jar{sep}./agent/src/resources/mallet.jar{sep}./agent/src/resources/mallet-deps.jar{sep}{os.getenv("M2_HOME")}/repository/org/jdom/jdom/1.1/jdom-1.1.jar{sep}{os.getenv("M2_HOME")}/repository/com/fasterxml/jackson/core/jackson-databind/2.12.4/jackson-databind-2.12.4.jar{sep}{os.getenv("M2_HOME")}/repository/com/fasterxml/jackson/core/jackson-annotations/2.12.4/jackson-annotations-2.12.4.jar{sep}{os.getenv("M2_HOME")}/repository/com/fasterxml/jackson/core/jackson-core/2.12.4/jackson-core-2.12.4.jar{sep}{os.getenv("M2_HOME")}/repository/org/apache/commons/commons-compress/1.21/commons-compress-1.21.jar{sep}{os.getenv("M2_HOME")}/repository/org/apache/commons/commons-math3/3.6.1/commons-math3-3.6.1.jar' 58 | 59 | 60 | def snake_case(s): 61 | return ''.join(['_' + i.lower() if i.isupper()else i for i in s]).lstrip('_') 62 | 63 | p = subprocess.run(['java', '--add-opens', 'java.base/java.util=ALL-UNNAMED', '-classpath', CLASS_PATH, 'com.alphaStS.Main', '--get-lengths'], capture_output=True) 64 | lens_data = json.loads(p.stdout.decode('utf-8').strip()) 65 | input_len = int(lens_data['inputLength']) 66 | num_of_actions = int(lens_data['policyLength']) 67 | v_other_lens = [] 68 | v_other_label = [] 69 | for target in lens_data.get('extraTargets', []): 70 | v_other_label.append(snake_case(target['label'])) 71 | v_other_lens.append(int(target['length'])) 72 | v_other_len = sum(v_other_lens) 73 | print(f'input_len={input_len}, policy_len={num_of_actions}, v_other_len={v_other_len}') 74 | 75 | 76 | def softmax_cross_entropy_with_logits(y_true, y_pred): 77 | p = y_pred 78 | pi = y_true 79 | zero = tf.zeros(shape=tf.shape(pi), dtype=tf.float32) 80 | where = tf.less(pi, zero) 81 | negatives = tf.fill(tf.shape(pi), -1000.0) 82 | p = tf.where(where, negatives, p) 83 | pi = tf.where(where, zero, pi) 84 | loss = tf.nn.softmax_cross_entropy_with_logits(labels=pi, logits=p) 85 | return loss 86 | 87 | 88 | def softmax_cross_entropy_with_logits_simple(y_true, y_pred): 89 | return tf.nn.softmax_cross_entropy_with_logits(labels=y_true, logits=y_pred) 90 | 91 | 92 | def mse_ignoring_out_of_bound(y_true, y_pred): 93 | # allow agent to output negative values to ignore training data for losses when training 94 | p = y_pred 95 | pi = y_true 96 | zero = tf.zeros(shape=tf.shape(pi), dtype=tf.float32) 97 | where = tf.less(pi, zero - 2) 98 | pi = tf.where(where, p, pi) 99 | return tf.keras.losses.MeanSquaredError()(p, pi) 100 | 101 | 102 | if os.path.exists(f'{SAVES_DIR}/training.json'): 103 | with open(f'{SAVES_DIR}/training.json', 'r') as f: 104 | training_info = json.load(f) 105 | else: 106 | try: 107 | os.mkdir(SAVES_DIR) 108 | except: 109 | pass 110 | training_info = {'iteration': 1, 'iteration_info': {}} 111 | with open(f'{SAVES_DIR}/training.json', 'w') as f: 112 | json.dump(training_info, f) 113 | 114 | custom_objects = {"softmax_cross_entropy_with_logits": softmax_cross_entropy_with_logits, 115 | "softmax_cross_entropy_with_logits_simple": softmax_cross_entropy_with_logits_simple, "mse_ignoring_out_of_bound": mse_ignoring_out_of_bound} 116 | if os.path.exists(f'{SAVES_DIR}/iteration{training_info["iteration"] - 1}'): 117 | with keras.utils.custom_object_scope(custom_objects): 118 | model = tf.keras.models.load_model(f'{SAVES_DIR}/iteration{training_info["iteration"] - 1}') 119 | # model.optimizer.lr.assign(0.01) 120 | convertToOnnx(model, input_len, f'{SAVES_DIR}/iteration{training_info["iteration"] - 1}') 121 | -------------------------------------------------------------------------------- /agent/src/main/java/com/alphaStS/player/PlayerReadOnly.java: -------------------------------------------------------------------------------- 1 | package com.alphaStS.player; 2 | 3 | import java.util.Objects; 4 | 5 | public class PlayerReadOnly { 6 | protected int origHealth; 7 | protected int inBattleMaxHealth; 8 | protected int maxHealth; 9 | protected int health; 10 | protected int block; 11 | protected int strength; 12 | protected int dexterity; 13 | protected int vulnerable; 14 | protected int weak; 15 | protected int frail; 16 | protected int artifact; 17 | protected boolean cannotDrawCard; 18 | protected boolean hexed; 19 | protected int entangled; 20 | protected int loseStrengthEot; 21 | protected int loseDexterityEot; 22 | protected int platedArmor; 23 | protected short noMoreBlockFromCards; 24 | protected int accumulatedDamage; 25 | 26 | public PlayerReadOnly(int health, int maxHealth) { 27 | this(health, maxHealth, maxHealth); 28 | } 29 | 30 | public PlayerReadOnly(int health, int maxHealth, int inBattleMaxHealth) { 31 | this.maxHealth = maxHealth; 32 | this.inBattleMaxHealth = inBattleMaxHealth; 33 | this.health = health; 34 | origHealth = health; 35 | } 36 | 37 | public PlayerReadOnly(PlayerReadOnly other) { 38 | origHealth = other.origHealth; 39 | maxHealth = other.maxHealth; 40 | inBattleMaxHealth = other.inBattleMaxHealth; 41 | health = other.health; 42 | block = other.block; 43 | strength = other.strength; 44 | dexterity = other.dexterity; 45 | vulnerable = other.vulnerable; 46 | weak = other.weak; 47 | frail = other.frail; 48 | artifact = other.artifact; 49 | cannotDrawCard = other.cannotDrawCard; 50 | entangled = other.entangled; 51 | hexed = other.hexed; 52 | loseStrengthEot = other.loseStrengthEot; 53 | loseDexterityEot = other.loseDexterityEot; 54 | platedArmor = other.platedArmor; 55 | noMoreBlockFromCards = other.noMoreBlockFromCards; 56 | accumulatedDamage = other.accumulatedDamage; 57 | } 58 | 59 | public int getMaxHealth() { 60 | return maxHealth; 61 | } 62 | 63 | public int getInBattleMaxHealth() { 64 | return inBattleMaxHealth; 65 | } 66 | 67 | public int getOrigHealth() { 68 | return origHealth; 69 | } 70 | 71 | public int getHealth() { 72 | return health; 73 | } 74 | 75 | public int getBlock() { 76 | return block; 77 | } 78 | 79 | public int getStrength() { 80 | return strength; 81 | } 82 | 83 | public int getDexterity() { 84 | return dexterity; 85 | } 86 | 87 | public int getVulnerable() { 88 | return vulnerable; 89 | } 90 | 91 | public int getWeak() { 92 | return weak; 93 | } 94 | 95 | public int getFrail() { 96 | return frail; 97 | } 98 | 99 | public int getArtifact() { 100 | return artifact; 101 | } 102 | 103 | public boolean cannotDrawCard() { 104 | return cannotDrawCard; 105 | } 106 | 107 | public int getLoseStrengthEot() { 108 | return loseStrengthEot; 109 | } 110 | 111 | public int getLoseDexterityEot() { 112 | return loseDexterityEot; 113 | } 114 | 115 | public int getPlatedArmor() { 116 | return platedArmor; 117 | } 118 | 119 | public boolean isEntangled() { 120 | return entangled > 0; 121 | } 122 | 123 | public boolean isHexed() { 124 | return hexed; 125 | } 126 | 127 | public int getAccumulatedDamage() { 128 | return accumulatedDamage; 129 | } 130 | 131 | @Override public String toString() { 132 | String str = "Player{health=" + health; 133 | if (block > 0) { 134 | str += ", block=" + block; 135 | } 136 | if (strength != 0) { 137 | str += ", str=" + strength; 138 | } 139 | if (dexterity != 0) { 140 | str += ", dex=" + dexterity; 141 | } 142 | if (vulnerable > 0) { 143 | str += ", vuln=" + vulnerable; 144 | } 145 | if (weak > 0) { 146 | str += ", weak=" + weak; 147 | } 148 | if (frail > 0) { 149 | str += ", frail=" + frail; 150 | } 151 | if (artifact > 0) { 152 | str += ", art=" + artifact; 153 | } 154 | if (cannotDrawCard) { 155 | str += ", cannotDraw=true"; 156 | } 157 | if (isEntangled()) { 158 | str += ", entangled"; 159 | } 160 | if (hexed) { 161 | str += ", hexed"; 162 | } 163 | if (loseStrengthEot > 0) { 164 | str += ", loseStrEot=" + loseStrengthEot; 165 | } 166 | if (loseDexterityEot > 0) { 167 | str += ", loseDexEot=" + loseDexterityEot; 168 | } 169 | if (noMoreBlockFromCards > 0) { 170 | str += ", noMoreBlockFromCards=" + noMoreBlockFromCards; 171 | } 172 | return str + '}'; 173 | } 174 | 175 | @Override public boolean equals(Object o) { 176 | if (this == o) 177 | return true; 178 | if (o == null || getClass() != o.getClass()) 179 | return false; 180 | PlayerReadOnly that = (PlayerReadOnly) o; 181 | return origHealth == that.origHealth && maxHealth == that.maxHealth && inBattleMaxHealth == that.inBattleMaxHealth && health == that.health && block == that.block && strength == that.strength && dexterity == that.dexterity && vulnerable == that.vulnerable && weak == that.weak && frail == that.frail && artifact == that.artifact && cannotDrawCard == that.cannotDrawCard && entangled == that.entangled && hexed == that.hexed && loseStrengthEot == that.loseStrengthEot && loseDexterityEot == that.loseDexterityEot && noMoreBlockFromCards == that.noMoreBlockFromCards && accumulatedDamage == that.accumulatedDamage; 182 | } 183 | 184 | @Override public int hashCode() { 185 | return Objects.hash(origHealth, maxHealth, health, block, strength, dexterity, vulnerable, weak, frail, artifact, cannotDrawCard, entangled, hexed, loseStrengthEot, loseDexterityEot); 186 | } 187 | } 188 | -------------------------------------------------------------------------------- /agent/src/main/java/com/alphaStS/utils/Utils.java: -------------------------------------------------------------------------------- 1 | package com.alphaStS.utils; 2 | 3 | import com.alphaStS.GameState; 4 | import com.alphaStS.RandomGen; 5 | import com.alphaStS.RandomGenCtx; 6 | import org.apache.commons.math3.distribution.ChiSquaredDistribution; 7 | 8 | public class Utils { 9 | public static int max(int[] arr) { 10 | int m = Integer.MIN_VALUE; 11 | for (int i = 0; i < arr.length; i++) { 12 | m = Math.max(m, arr[i]); 13 | } 14 | return m; 15 | } 16 | 17 | public static boolean arrayEqualLen(int[] a1, int[] a2, int len) { 18 | if (len > a1.length || len > a2.length) { 19 | return false; 20 | } 21 | for (int i = 0; i < len; i++) { 22 | if (a1[i] != a2[i]) { 23 | return false; 24 | } 25 | } 26 | return true; 27 | } 28 | 29 | public static int arrayHashCodeLen(int[] a, int len) { 30 | int result = 1; 31 | for (int i = 0; i < len; i++) { 32 | result = 31 * result + a[i]; 33 | } 34 | return result; 35 | } 36 | 37 | public static float[] arrayCopy(float[] from, int start, int len) { 38 | var to = new float[len]; 39 | for (int i = 0; i < len; i++) { 40 | to[i] = from[start + i]; 41 | } 42 | return to; 43 | } 44 | 45 | public static String formatFloat(double f) { 46 | if (f == 0) { 47 | return "0"; 48 | } else if (f < 0.001) { 49 | return String.format("%6.3e", f).trim(); 50 | } else { 51 | return String.format("%6.3f", f).trim(); 52 | } 53 | } 54 | 55 | public static void sleep(long ms) { 56 | try { 57 | if (ms > 0) Thread.sleep(ms); 58 | } catch (InterruptedException e) { 59 | e.printStackTrace(); 60 | } 61 | } 62 | 63 | public static int[] shortToIntArray(short[] arr) { 64 | int[] result = new int[arr.length]; 65 | for (int i = 0; i < arr.length; i++) { 66 | result[i] = arr[i]; 67 | } 68 | return result; 69 | } 70 | 71 | public static boolean equals(short[] arr1, short[] arr2, int len) { 72 | for (int i = 0; i < len; i++) { 73 | if (arr1[i] != arr2[i]) { 74 | return false; 75 | } 76 | } 77 | return true; 78 | } 79 | 80 | public static boolean equals(boolean[] arr1, boolean[] arr2, int len) { 81 | for (int i = 0; i < len; i++) { 82 | if (arr1[i] != arr2[i]) { 83 | return false; 84 | } 85 | } 86 | return true; 87 | } 88 | 89 | public static void shuffle(GameState state, short[] arr, int len, RandomGen rand) { 90 | for (int i = len - 1; i >= 0; i--) { 91 | int j = rand.nextInt(i + 1, RandomGenCtx.CardDraw, new Tuple3<>(state, 0, 0)); 92 | short temp = arr[i]; 93 | arr[i] = arr[j]; 94 | arr[j] = temp; 95 | } 96 | } 97 | 98 | public static void shuffle(GameState state, short[] arr, int len, int start, RandomGen rand) { 99 | for (int i = len - 1; i >= start; i--) { 100 | int j = rand.nextInt(i + 1, RandomGenCtx.CardDraw, new Tuple3<>(state, 2, arr)); 101 | short temp = arr[i]; 102 | arr[i] = arr[j]; 103 | arr[j] = temp; 104 | } 105 | } 106 | 107 | public static float[] nArrayToPolicy(int[] n) { 108 | float[] policy = new float[n.length]; 109 | var total_n = 0; 110 | for (int i = 0; i < n.length; i++) { 111 | total_n += n[i]; 112 | } 113 | for (int i = 0; i < n.length; i++) { 114 | policy[i] = n[i] / (float) total_n; 115 | } 116 | return policy; 117 | } 118 | 119 | public static double calcKLDivergence(float[] p, float[] q) { 120 | if (p.length != q.length) { 121 | throw new IllegalArgumentException(); 122 | } 123 | var kld = 0.0; 124 | for (int j = 0; j < q.length; j++) { 125 | var t = q[j]; 126 | if (q[j] == 0) { 127 | t = 0.001f; 128 | } 129 | if (p[j] > 0) { 130 | kld += p[j] * Math.log(p[j] / t); 131 | } 132 | } 133 | return kld; 134 | } 135 | 136 | public static double[] chiSquareInverseProb = new double[5000]; 137 | public static void initChiSquareInverseProb() { 138 | for (int i = 0; i < 5000; i++) { 139 | chiSquareInverseProb[i] = (i + 1) / new ChiSquaredDistribution(i + 1).inverseCumulativeProbability(0.01); 140 | } 141 | } 142 | 143 | public static double getChiSquareInverseProb(int degree) { 144 | if (degree <= chiSquareInverseProb.length) { 145 | return chiSquareInverseProb[degree - 1]; 146 | } 147 | if (degree <= 5698) { 148 | return 1.045; 149 | } 150 | if (degree <= 7161) { 151 | return 1.04; 152 | } 153 | if (degree <= 9289) { 154 | return 1.035; 155 | } 156 | if (degree <= 12554) { 157 | return 1.03; 158 | } 159 | if (degree <= 17951) { 160 | return 1.025; 161 | } 162 | if (degree <= 27850) { 163 | return 1.02; 164 | } 165 | if (degree <= 49159) { 166 | return 1.015; 167 | } 168 | return 1; 169 | } 170 | 171 | /** 172 | * Converts a camelCase string to a display string with spaces. 173 | * Examples: "FearPotion" -> "Fear Potion", "PotionOfCapacity" -> "Potion Of Capacity" 174 | */ 175 | public static String camelCaseToDisplayString(String camelCase) { 176 | if (camelCase == null || camelCase.isEmpty()) { 177 | return camelCase; 178 | } 179 | 180 | StringBuilder result = new StringBuilder(); 181 | boolean isFirstChar = true; 182 | 183 | for (int i = 0; i < camelCase.length(); i++) { 184 | char c = camelCase.charAt(i); 185 | 186 | if (Character.isUpperCase(c)) { 187 | if (!isFirstChar) { 188 | result.append(' '); 189 | } 190 | result.append(c); 191 | } else { 192 | result.append(c); 193 | } 194 | isFirstChar = false; 195 | } 196 | 197 | return result.toString(); 198 | } 199 | } -------------------------------------------------------------------------------- /agent/src/main/java/com/alphaStS/utils/FuzzyMatch.java: -------------------------------------------------------------------------------- 1 | package com.alphaStS.utils; 2 | 3 | import java.util.ArrayList; 4 | import java.util.List; 5 | 6 | public class FuzzyMatch { 7 | public static record Result(boolean sequential, int score) {} 8 | 9 | public static Result fuzzyMatch(String pattern, String string) { 10 | // Score consts 11 | int adjacency_bonus = 5; // bonus for adjacent matches 12 | int separator_bonus = 10; // bonus if match occurs after a separator 13 | int camel_bonus = 10; // bonus if match is uppercase and prev is lower 14 | // penalty applied for every letter in string before the first match 15 | int leading_letter_penalty = -3; 16 | int max_leading_letter_penalty = -9; // maximum penalty for leading letters 17 | int unmatched_letter_penalty = -1; // penalty for every letter that doesn't matter 18 | 19 | // Loop variables 20 | int score = 0; 21 | int patternIdx = 0; 22 | int patternLength = pattern.length(); 23 | int strIdx = 0; 24 | int strLength = string.length(); 25 | boolean prevMatched = false; 26 | boolean prevLower = false; 27 | boolean prevSeparator = true; // True so if first letter match gets separator bonus 28 | 29 | // Use "best" matched letter if multiple string letters match the pattern 30 | Character bestLetter = null; 31 | Character bestLower = null; 32 | Integer bestLetterIdx = null; 33 | int bestLetterScore = 0; 34 | 35 | var matchedIndices = new ArrayList(); 36 | 37 | // Loop over strings 38 | while (strIdx != strLength) { 39 | var patternChar = patternIdx != patternLength ? pattern.charAt(patternIdx) : null; 40 | var strChar = string.charAt(strIdx); 41 | 42 | var patternLower = patternChar != null ? Character.toLowerCase(patternChar) : null; 43 | var strLower = Character.toLowerCase(strChar); 44 | var strUpper = Character.toUpperCase(strChar); 45 | 46 | var nextMatch = patternChar != null && patternLower == strLower; 47 | var rematch = bestLetter != null && bestLower == strLower; 48 | 49 | var advanced = nextMatch && bestLetter != null; 50 | var patternRepeat = bestLetter != null && patternChar != null && bestLower == patternLower; 51 | if (advanced || patternRepeat) { 52 | score += bestLetterScore; 53 | matchedIndices.add(bestLetterIdx); 54 | bestLetter = null; 55 | bestLower = null; 56 | bestLetterIdx = null; 57 | bestLetterScore = 0; 58 | } 59 | 60 | if (nextMatch || rematch) { 61 | int newScore = 0; 62 | 63 | // Apply penalty for each letter before the first pattern match 64 | // Note: std::max because penalties are negative values. So max is 65 | // smallest penalty. 66 | if (patternIdx == 0) { 67 | int penalty = Math.max(strIdx * leading_letter_penalty, max_leading_letter_penalty); 68 | score += penalty; 69 | } 70 | 71 | // Apply bonus for consecutive bonuses 72 | if (prevMatched) { 73 | newScore += adjacency_bonus; 74 | } 75 | 76 | // Apply bonus for matches after a separator 77 | if (prevSeparator) { 78 | newScore += separator_bonus; 79 | } 80 | 81 | // Apply bonus across camel case boundaries. Includes "clever" 82 | // isLetter check. 83 | if (prevLower && strChar == strUpper && strLower != strUpper) { 84 | newScore += camel_bonus; 85 | } 86 | 87 | // Update patter index IFF the next pattern letter was matched 88 | if (nextMatch) { 89 | patternIdx = patternIdx + 1; 90 | } 91 | 92 | // Update best letter in string which may be for a "next" letter or a 93 | // "rematch" 94 | if (newScore >= bestLetterScore) { 95 | // Apply penalty for now skipped letter 96 | if (bestLetter != null) { 97 | score += unmatched_letter_penalty; 98 | } 99 | bestLetter = strChar; 100 | bestLower = Character.toLowerCase(bestLetter); 101 | bestLetterIdx = strIdx; 102 | bestLetterScore = newScore; 103 | } 104 | 105 | prevMatched = true; 106 | } else { 107 | // Append unmatch characters 108 | // formattedStr += strChar 109 | score += unmatched_letter_penalty; 110 | prevMatched = false; 111 | } 112 | 113 | // Includes "clever" isLetter check. 114 | prevLower = strChar == strLower && strLower != strUpper; 115 | prevSeparator = strChar == '_' || strChar == ' '; 116 | 117 | strIdx = strIdx + 1; 118 | } 119 | 120 | // Apply score for last match 121 | if (bestLetter != null) { 122 | score += bestLetterScore; 123 | matchedIndices.add(bestLetterIdx); 124 | } 125 | boolean matched = patternIdx == patternLength; 126 | return new Result(matched, score); 127 | } 128 | 129 | public static String getBestFuzzyMatch(String query, List strings) { 130 | String maxString = null; 131 | int score = Integer.MIN_VALUE; 132 | for (String string : strings) { 133 | var r = fuzzyMatch(query, string); 134 | if (r.score > score && stringContainsLetters(string, query)) { 135 | maxString = string; 136 | score = r.score; 137 | } 138 | } 139 | return maxString; 140 | } 141 | 142 | private static boolean stringContainsLetters(String str, String query) { 143 | int[] count = new int[256]; 144 | for (int i = 0; i < str.length(); i++) { 145 | count[str.charAt(i)]++; 146 | } 147 | for (int i = 0; i < query.length(); i++) { 148 | count[query.charAt(i)]--; 149 | if (count[query.charAt(i)] < 0) { 150 | return false; 151 | } 152 | } 153 | return true; 154 | } 155 | } 156 | -------------------------------------------------------------------------------- /agent/src/main/java/com/alphaStS/Configuration.java: -------------------------------------------------------------------------------- 1 | package com.alphaStS; 2 | 3 | public class Configuration { 4 | public static boolean DO_NOT_USE_CACHED_STATE_WHEN_MAKING_REAL_MOVE = false; 5 | 6 | // During training, percentage of games played with no temperature 7 | public static float TRAINING_PERCENTAGE_NO_TEMPERATURE = 0.2f; 8 | // During training, enable KataGo's forced playout 9 | public static boolean TRAINING_USE_FORCED_PLAYOUT = true; 10 | public static boolean TRAINING_EXPERIMENT_USE_UNCERTAINTY_FOR_EXPLORATION = false; 11 | public static boolean CARD_IN_HAND_IN_NN_INPUT = true; 12 | public static boolean CARD_IN_DECK_IN_NN_INPUT = true; 13 | 14 | public static float PRE_BATTLE_SCENARIO_TEMP = 0f; 15 | 16 | public static boolean PRINT_MODEL_COMPARE_DIFF = false; 17 | public static int CMP_DEVIATION_NUM_RERUN = 0; 18 | public static int SLEEP_PER_GAME = 0; 19 | public static int SLEEP_PER_GAME_TRAINING = 0; 20 | public static boolean USE_REMOTE_SERVERS = true; 21 | 22 | public static boolean COMMON_RANDOM_NUMBER_VARIANCE_REDUCTION = true; 23 | public static boolean CPUCT_SCALING = false; 24 | public static boolean TEST_CPUCT_SCALING = false; 25 | public static boolean USE_PROGRESSIVE_WIDENING = false; 26 | public static boolean TEST_PROGRESSIVE_WIDENING = false; 27 | public static boolean PROGRESSIVE_WIDENING_IMPROVEMENTS = true; 28 | public static boolean PROGRESSIVE_WIDENING_IMPROVEMENTS2 = false; 29 | public static boolean TRANSPOSITION_ACROSS_CHANCE_NODE = true; 30 | public static boolean TEST_TRANSPOSITION_ACROSS_CHANCE_NODE = false; 31 | public static boolean NEW_COMMON_RANOM_NUMBER_VARIANCE_REDUCTION = false; 32 | public static boolean TEST_NEW_COMMON_RANOM_NUMBER_VARIANCE_REDUCTION = true; 33 | public static boolean UPDATE_TRANSPOSITIONS_ON_ALL_PATH = false; 34 | public static boolean TEST_UPDATE_TRANSPOSITIONS_ON_ALL_PATH = false; 35 | public static boolean TRAINING_RESCORE_SEARCH_FOR_BEST_LINE = true; 36 | public static boolean TRAINING_SKIP_OPENING_TURNS = false; 37 | public static boolean COMBINE_END_AND_BEGIN_TURN_FOR_STOCHASTIC_BEGIN = true; 38 | public static int TRAINING_SKIP_OPENING_TURNS_UPTO = 3; 39 | public static double TRAINING_SKIP_OPENING_GAMES_INCREASE_RATIO = 1.25; 40 | 41 | public static boolean ADD_BEGIN_TURN_CTX_TO_NN_INPUT = true; 42 | public static boolean TRAINING_POLICY_SURPRISE_WEIGHTING = false; 43 | public static boolean USE_FIGHT_PROGRESS_WHEN_LOSING = true; 44 | public static boolean USE_Z_TRAINING = false; 45 | public static double DISCOUNT_REWARD_ON_RANDOM_NODE = 0.2f; 46 | 47 | public static boolean TRANSPOSITION_ALWAYS_EXPAND_NEW_NODE = false; 48 | public static boolean TEST_TRANSPOSITION_ALWAYS_EXPAND_NEW_NODE = false; 49 | 50 | public static boolean isTranspositionAlwaysExpandNewNodeOn(GameState state) { 51 | return TRANSPOSITION_ALWAYS_EXPAND_NEW_NODE && (!TEST_TRANSPOSITION_ALWAYS_EXPAND_NEW_NODE || state.properties.testNewFeature); 52 | } 53 | 54 | public static boolean isTranspositionAcrossChanceNodeOn(GameState state) { 55 | return TRANSPOSITION_ACROSS_CHANCE_NODE && (!TEST_TRANSPOSITION_ACROSS_CHANCE_NODE || state.properties.testNewFeature); 56 | } 57 | 58 | public static boolean USE_UTILITY_STD_ERR_FOR_PUCT = false; 59 | public static boolean TEST_USE_UTILITY_STD_ERR_FOR_PUCT = false; 60 | 61 | public static boolean isUseUtilityStdErrForPuctOn(GameState state) { 62 | return isTranspositionAlwaysExpandNewNodeOn(state) && USE_UTILITY_STD_ERR_FOR_PUCT && (!TEST_USE_UTILITY_STD_ERR_FOR_PUCT || state.properties.testNewFeature); 63 | } 64 | 65 | public static boolean BAN_TRANSPOSITION_IN_TREE = false; 66 | public static boolean TEST_BAN_TRANSPOSITION_IN_TREE = false; 67 | 68 | public static boolean isBanTranspositionInTreeOn(GameState state) { 69 | return BAN_TRANSPOSITION_IN_TREE && (!TEST_BAN_TRANSPOSITION_IN_TREE || state.properties.testNewFeature); 70 | } 71 | 72 | public static boolean PRIORITIZE_CHANCE_NODES_BEFORE_DETERMINISTIC_IN_TREE = true; 73 | public static boolean TEST_PRIORITIZE_CHANCE_NODES_BEFORE_DETERMINISTIC_IN_TREE = false; 74 | 75 | public static boolean isPrioritizeChanceNodesBeforeDeterministicInTreeOn(GameState state) { 76 | return PRIORITIZE_CHANCE_NODES_BEFORE_DETERMINISTIC_IN_TREE && (!TEST_PRIORITIZE_CHANCE_NODES_BEFORE_DETERMINISTIC_IN_TREE || state.properties.testNewFeature); 77 | } 78 | 79 | public static boolean FLATTEN_POLICY_AS_NODES_INCREASE = false; 80 | public static boolean TEST_FLATTEN_POLICY_AS_NODES_INCREASE = false; 81 | 82 | public static boolean isFlattenPolicyAsNodesIncreaseOn(GameState state) { 83 | return FLATTEN_POLICY_AS_NODES_INCREASE && (!TEST_FLATTEN_POLICY_AS_NODES_INCREASE || state.properties.testNewFeature); 84 | } 85 | 86 | public static boolean HEART_GAUNTLET_POTION_REWARD = false; 87 | public static boolean HEART_GAUNTLET_CARD_REWARD = false; 88 | 89 | // having this help network know when it's almost losing due to 50 turns losing rule 90 | // basic testing show it helps a bit in preventing losing to 50 turns 91 | // combine with USE_TURNS_LEFT_HEAD below for maximum effect 92 | public static boolean ADD_CURRENT_TURN_NUM_TO_NN_INPUT = true; 93 | 94 | // predict how many turns remains in the battle (regardless of win or loss), use it to select the move that 95 | // ends the battle faster among moves with similar evaluation 96 | // basic testing show it significantly helps in preventing losing to 50 turns and dramatically shorten number of turns 97 | // (usually in defect battles with lots of focus where the network stop progressing due to every move being the same eval) 98 | // disable unless needed for now, some testing shows it will cause the network to lose fights it wouldn't have lost by picking 99 | // more "aggressive" moves to do damage 100 | public static boolean USE_TURNS_LEFT_HEAD = true; 101 | // Experimental feature to test the problem mentioned above (Worked on the single instance I tested it on) 102 | // Since the main problem is when defensive output far exceeds enemy damage, add an addition head for probablity 103 | // of taking any more damage from current position and only use turns left head when it exeeds a threshold 104 | public static boolean USE_TURNS_LEFT_HEAD_ONLY_WHEN_NO_DMG = true; 105 | 106 | // print model prediction error compare to actual result 107 | public static boolean STATS_PRINT_PREDICTION_ERRORS = false; 108 | public static boolean STATS_PRINT_CARD_USAGE_COUNT = false; 109 | 110 | public static boolean USE_NEW_ACTION_SELECTION = false; 111 | 112 | public static final String ONNX_LIB_PATH = "F:/git/lib"; 113 | public static final boolean ONNX_USE_CUDA_FOR_INFERENCE = false; 114 | } 115 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | Given the description of a fight in Slay the Spire (i.e. player health, deck, enemies, relics, and potions), uses the AlphaZero algorithm to train a neural network to play it. 2 | 3 | # Setup 4 | 5 | ## Windows: 6 | 1. Download and install JDK 17 (I have used IntelliJ's built in JDKs and Amazon Correto JDKs and both worked) 7 | 2. Download and install Maven (I use 3.8.4, but any version should work) 8 | 3. Download and install Miniconda (https://docs.conda.io/en/latest/miniconda.html) 9 | 4. Use conda to create a python environment from environment.yml and activate it 10 | 1. conda env create -f environment.ml 11 | 2. conda activate alphasts 12 | 5. Change working directory to agents folder and run mvn install to build the java agent 13 | 6. Change working directory back to project root and run python ml.py -t -c 20. It should start generating games in the saves/ directory. The default fight is a gremlin nob fight with the ironclad starter deck with bash upgraded and without ascender's bane (it will run the alphazero agorithm for 20 iterations) 14 | 7. Once finished, use python interactive.py to play the fight in the simulator (see Interactive Mode). 15 | 16 | ## Linux: 17 | Note: I only set it up on AWS once. All I remember is it works but there were error messages at step 4.iii even though everything run fine. 18 | 1. Download and install JDK 17 from your distribution 19 | 2. Download and install Maven from your distribution 20 | 3. Download and install Miniconda (https://docs.conda.io/en/latest/miniconda.html) 21 | 4. Use conda to create a python environment from environment.yml and activate it 22 | 1. conda env create -n alphasts python=3.8 23 | 2. conda activate alphasts 24 | 3. run python ml.py -t -c 20 and just download any missing dependencies from the error messages given via 'conda install' and hope it works in the end 25 | - note: use 'pip install' instead of 'conda install' instead for tf2onnx 26 | 5. Change working directory to agents folder and run mvn install to build the java agent 27 | 6. Change working directory back to project root and run python ml.py -t -c 20. It should start generating games in the saves/ directory. The default fight is a gremlin nob fight with the ironclad starter deck with bash upgraded and without ascender's bane (it will run the alphazero agorithm for 20 iterations) 28 | 7. Once finished, use python interactive.py to play the fight in the simulator (see Interactive Mode). 29 | 30 | # Interactive Mode 31 | Some common commands: 32 | 1. Enter the number/letter beside a card to play it. For example, 0. Bash -> type <0> to play bash. You can also use the card name (i.e. type Bash). Also there's fuzzy search so even 'ba' would work if there are no other possible matches. 33 | 2. n \: run monte carlo tree search on the current state for n simulations and display the state 34 | 3. nn \: give the best line of play with n simulations searched for each move 35 | 4. reset: clear all the nodes searched so far 36 | 37 | # Architecture 38 | It's suggested for the reader to have knowledge of the AlphaZero algorithm first before reading the Architecture section. 39 | 40 | ## Input 41 | All information that can change in a fight are passed to the neural network. Using the fight of A20 base Ironclad deck versus Gremlin Nob as an example: the count of Bash, Strike, Defend, Ascender's Bane in deck+hand+discard, player hp and vulnerable, Gremlin Nob hp and vulnerable, and the current move of Gremlin Nob are all passed to the network. Information such as energy per turn, strength, or weakness are not passed in since they are static values. Those are consider as part of the environment. All inputs are standardized to around [-1, 1]. 42 | 43 | Multi-steps actions are split up as individual action while playing the fight. For example, to play True Grit+, we first select True Grit+ as the card to play, then we select what card to exhaust from our hand. As such, additional inputs such as the current state context (e.g. selecting a card from hand) and the card that triggered the current context (e.g. True Grit+) are also passed in as inputs as needed. 44 | 45 | ## Neural Network Architecture 46 | Currently just a simple 2 layer fully connected feedforward network. 47 | 48 | ## Output 49 | **policy**: probability distribution over each possible action. This includes for each card an action to select the card to play, for each enemy an action to select the enemy, for each potion an action to use the potion, and ending the turn. Currently the action for selecting a card to play are also reused for selecting a card from hand, selecting a card from discard, and selecting a card from exhaust. 50 | 51 | **v_win**: a value between -1 and 1 represting win rate, with -1 mapping to 0% chance of winning and 1 mapping 100% chance of meaning. 52 | 53 | **v_health**: a value between -1 and 1 represting expected health at the end of fight, with -1 mapping to 0 health and 1 mapping to max player health. 54 | 55 | Note the agent will shift v_win and v_health to [0, 1] when using it. We will refer to the shifted value as v_win_norm and v_health_norm. 56 | 57 | There may also be extra outputs for meta rewards such as using Feed to kill an enemy depending on the fight. 58 | 59 | ## Loss 60 | The loss function is 0.45 * MSE(v_health) + 0.45 * MSE(v_win) + 0.1 * crossEntropyLoss(policy) + 0.25 * MSE(meta rewards). 61 | 62 | Note I haven't added regularization to the loss yet (planning to test it). 63 | 64 | Note the high value to policy loss ratio empirically speeds up training for the few cases I have tested it on. Based on [Accelerating and Improving AlphaZero Using Population Based Training 65 | ](https://arxiv.org/abs/2003.06212) and [Alternative Loss Functions in AlphaZero-like Self-play 66 | ](https://ieeexplore.ieee.org/document/9002814), high value to loss ratio are also preferred in other games. Intuitively speaking, it should be more important that the network can distinguish between a 60% win rate position and a 63% win rate position than it is to have a slightly bit more accurate policy because MCTS would override the policy as more simulations are run. 67 | 68 | Another thing to note is currently only legal actions contribute to the cross entropy loss of the policy. i.e. non-legal actions are masked out when calculating the loss. 69 | 70 | ## MCTS 71 | 72 | ## Reward (or Q value to maximize for MCTS) 73 | Currently the reward is v_base=v_win_norm / 2 + (v_win_norm2 * v_health_norm) / 2. The intention is that win rate is maximized first, and if win rate are about equal then expected health at the end of the fight is maximized. 74 | 75 | Playing a potion will penalize v_base by a percentage (e.g. 20% generally means a potion won't be played unless the situation is dire, while 0% means use it at the possible moment to maximize win rate and expected health). The percentage is passed in as part of the fight description as a measure of how valuable the potion is. This penalty is only applied during MCTS search, but the neural network will learn about this penalty in its policy and value outputs. Note this was a very adhoc choice. It seem to work alright so far but there should be better choices out there. 76 | 77 | Meta rewards currently are added to v_health for purpose of reward calculation. So e.g. using Feed to kill an enemy adds 6 hp to v_health_norm for the purpose of calculating v_base. The amount of health added is also passed in as part of the fight description as a measure of how valuable the meta reward is. 78 | 79 | ### Chance Nodes 80 | Todo 81 | 82 | ### Transpositions 83 | Todo 84 | 85 | ### Terminal Propagation 86 | Todo 87 | 88 | ## Training 89 | Todo 90 | 91 | ## Domain Knowledge 92 | Todo -------------------------------------------------------------------------------- /agent/src/main/java/com/alphaStS/player/Player.java: -------------------------------------------------------------------------------- 1 | package com.alphaStS.player; 2 | 3 | import com.alphaStS.*; 4 | import com.alphaStS.enums.DebuffType; 5 | 6 | public class Player extends PlayerReadOnly { 7 | public Player(int health, int maxHealth, int inBattleMaxHealth) { 8 | super(health, maxHealth, inBattleMaxHealth); 9 | } 10 | 11 | public Player(int health, int maxHealth) { 12 | super(health, maxHealth); 13 | } 14 | 15 | public Player(Player other) { 16 | super(other); 17 | } 18 | 19 | public int damage(GameState state, int n) { 20 | int startHealth = health; 21 | int dmg = Math.max(0, n - block); 22 | if (dmg <= 5 && dmg >= 2 && state.properties.torii != null && state.properties.torii.isRelicEnabledInScenario(state)) { 23 | dmg = 1; 24 | } 25 | if (dmg > 0 && state.properties.tungstenRod != null && state.properties.tungstenRod.isRelicEnabledInScenario(state)) { 26 | dmg -= 1; 27 | } 28 | if (n > block && state.properties.bufferCounterIdx >= 0 && state.getCounterForRead()[state.properties.bufferCounterIdx] > 0) { 29 | dmg = 0; 30 | state.getCounterForWrite()[state.properties.bufferCounterIdx]--; 31 | } 32 | health -= dmg; 33 | block = Math.max(0, block - n); 34 | if (platedArmor > 0 && dmg > 0) { 35 | platedArmor -= 1; 36 | } 37 | if (health < 0) { 38 | health = 0; 39 | } 40 | int dmgDealt = startHealth - health; 41 | if (health == 0) { 42 | tryReviveWithFairyInABottle(state); 43 | } 44 | if (dmgDealt > 0) { 45 | accumulatedDamage += dmgDealt; 46 | } 47 | return dmgDealt; 48 | } 49 | 50 | public int nonAttackDamage(GameState state, int n, boolean blockable) { 51 | int startHealth = health; 52 | if (blockable) { 53 | int dmg = Math.max(0, n - block); 54 | if (dmg > 0 && state.properties.tungstenRod != null && state.properties.tungstenRod.isRelicEnabledInScenario(state)) { 55 | dmg -= 1; 56 | } 57 | if (n > block && state.properties.bufferCounterIdx >= 0 && state.getCounterForRead()[state.properties.bufferCounterIdx] > 0) { 58 | dmg = 0; 59 | state.getCounterForWrite()[state.properties.bufferCounterIdx]--; 60 | } 61 | health -= dmg; 62 | block = Math.max(0, block - n); 63 | } else { 64 | int dmg = n; 65 | if (dmg > 0 && state.properties.tungstenRod != null && state.properties.tungstenRod.isRelicEnabledInScenario(state)) { 66 | dmg -= 1; 67 | } 68 | if (dmg >= 0 && state.properties.bufferCounterIdx >= 0 && state.getCounterForRead()[state.properties.bufferCounterIdx] > 0) { 69 | dmg = 0; 70 | state.getCounterForWrite()[state.properties.bufferCounterIdx]--; 71 | } 72 | health -= dmg; 73 | } 74 | if (health < 0) { 75 | health = 0; 76 | } 77 | int dmgDealt = startHealth - health; 78 | if (health == 0) { 79 | tryReviveWithFairyInABottle(state); 80 | } 81 | if (dmgDealt > 0) { 82 | accumulatedDamage += dmgDealt; 83 | } 84 | return dmgDealt; 85 | } 86 | 87 | private void tryReviveWithFairyInABottle(GameState state) { 88 | for (int i = 0; i < state.properties.potions.size(); i++) { 89 | if (state.potionUsable(i) && state.properties.potions.get(i) instanceof Potion.FairyInABottle) { 90 | state.getPotionsStateForWrite()[i * 3] = 0; 91 | state.properties.potions.get(i).use(state, -1); 92 | break; 93 | } 94 | } 95 | } 96 | 97 | public int heal(int n) { 98 | int healed = Math.min(n, getInBattleMaxHealth() - health); 99 | health += healed; 100 | return healed; 101 | } 102 | 103 | public int gainBlock(int n) { 104 | if (noMoreBlockFromCards > 0) { 105 | return 0; 106 | } 107 | n += dexterity; 108 | n = frail > 0? (n - (n + 3) / 4) : n; 109 | if (n < 0) { 110 | n = 0; 111 | } 112 | block += n; 113 | if (block > 999) { 114 | block = 999; 115 | } 116 | return n; 117 | } 118 | 119 | public void gainBlockNotFromCardPlay(int n) { 120 | block += n; 121 | if (block > 999) { 122 | block = 999; 123 | } 124 | } 125 | 126 | public void gainStrength(int n) { 127 | strength += n; 128 | strength = Math.min(999, Math.max(-999, strength)); 129 | } 130 | 131 | public void gainDexterity(int n) { 132 | dexterity += n; 133 | dexterity = Math.min(999, Math.max(-999, dexterity)); 134 | } 135 | 136 | public void applyDebuff(GameState state, DebuffType type, int n) { 137 | if (n == 0) { 138 | return; 139 | } 140 | if (state.properties.ginger != null && state.properties.ginger.isRelicEnabledInScenario(state) && type == DebuffType.WEAK) { 141 | return; 142 | } else if (state.properties.turnip != null && state.properties.turnip.isRelicEnabledInScenario(state) && type == DebuffType.FRAIL) { 143 | return; 144 | } 145 | if (artifact > 0) { 146 | artifact--; 147 | return; 148 | } 149 | // add 1 to some debuff so they don't get cleared after enemies do their move 150 | switch (type) { 151 | case VULNERABLE -> this.vulnerable += state.getActionCtx() == GameActionCtx.BEGIN_TURN && this.vulnerable == 0 ? n + 1 : n; 152 | case WEAK -> this.weak += state.getActionCtx() == GameActionCtx.BEGIN_TURN && this.weak == 0 ? n + 1 : n; 153 | case FRAIL -> this.frail += state.getActionCtx() == GameActionCtx.BEGIN_TURN && this.frail == 0 ? n + 1 : n; 154 | case LOSE_STRENGTH -> this.gainStrength(-n); 155 | case LOSE_DEXTERITY -> this.gainDexterity(-n); 156 | case LOSE_STRENGTH_EOT -> this.loseStrengthEot += n; 157 | case LOSE_DEXTERITY_EOT -> this.loseDexterityEot += n; 158 | case LOSE_DEXTERITY_PER_TURN -> state.getCounterForWrite()[state.properties.loseDexterityPerTurnCounterIdx] += n; 159 | case NO_MORE_CARD_DRAW -> this.cannotDrawCard = true; 160 | case ENTANGLED -> this.entangled = state.getActionCtx() == GameActionCtx.BEGIN_TURN && this.entangled == 0 ? n + 1 : n; 161 | case HEX -> this.hexed = true; 162 | case LOSE_FOCUS -> state.gainFocus(-n); 163 | case LOSE_FOCUS_PER_TURN -> state.getCounterForWrite()[state.properties.loseFocusPerTurnCounterIdx] += n; 164 | case LOSE_ENERGY_PER_TURN -> state.getCounterForWrite()[state.properties.loseEnergyPerTurnCounterIdx] += n; 165 | case CONSTRICTED -> state.getCounterForWrite()[state.properties.constrictedCounterIdx] += n; 166 | case DRAW_REDUCTION -> state.getCounterForWrite()[state.properties.drawReductionCounterIdx] += n; 167 | case NO_BLOCK_FROM_CARDS -> this.noMoreBlockFromCards += n; 168 | case SNECKO -> state.getCounterForWrite()[state.properties.sneckoDebuffCounterIdx] = 1; 169 | } 170 | } 171 | 172 | public void preEndTurn(GameState state) { 173 | cannotDrawCard = false; 174 | if (platedArmor > 0) { 175 | gainBlockNotFromCardPlay(platedArmor); 176 | } 177 | } 178 | 179 | public void endTurn(GameState state) { 180 | if (vulnerable > 0) { 181 | vulnerable -= 1; 182 | } 183 | if (weak > 0) { 184 | weak -= 1; 185 | } 186 | if (frail > 0) { 187 | frail -= 1; 188 | } 189 | if (entangled > 0) { 190 | entangled -= 1; 191 | } 192 | if (noMoreBlockFromCards > 0) { 193 | noMoreBlockFromCards -= 1; 194 | } 195 | if (loseStrengthEot > 0) { 196 | applyDebuff(state, DebuffType.LOSE_STRENGTH, loseStrengthEot); 197 | } 198 | if (loseDexterityEot > 0) { 199 | applyDebuff(state, DebuffType.LOSE_DEXTERITY, loseDexterityEot); 200 | } 201 | loseStrengthEot = 0; 202 | loseDexterityEot = 0; 203 | if ((state.buffs & PlayerBuff.BARRICADE.mask()) != 0) { 204 | } else if (state.properties.blurCounterIdx >= 0 && state.getCounterForRead()[state.properties.blurCounterIdx] > 0) { 205 | state.getCounterForWrite()[state.properties.blurCounterIdx]--; 206 | } else if (state.properties.calipers != null && state.properties.calipers.isRelicEnabledInScenario(state)) { 207 | block = Math.max(block - 15, 0); 208 | } else { 209 | block = 0; 210 | } 211 | } 212 | 213 | public void gainArtifact(int n) { 214 | artifact += n; 215 | } 216 | 217 | public void gainPlatedArmor(int n) { 218 | platedArmor += n; 219 | } 220 | 221 | public void setHealth(int hp) { 222 | health = hp; 223 | } 224 | 225 | public void setOrigHealth(int hp) { 226 | origHealth = hp; 227 | health = hp; 228 | } 229 | 230 | public void setInBattleMaxHealth(int maxHealth) { 231 | this.inBattleMaxHealth = maxHealth; 232 | this.health = Math.min(maxHealth, this.health); 233 | } 234 | 235 | public void setAccumulatedDamage(int health) { 236 | this.accumulatedDamage = health; 237 | } 238 | 239 | public void removeAllDebuffs(GameState state) { 240 | vulnerable = 0; 241 | weak = 0; 242 | frail = 0; 243 | cannotDrawCard = false; 244 | hexed = false; 245 | entangled = 0; 246 | loseStrengthEot = 0; 247 | loseDexterityEot = 0; 248 | if (strength < 0) { 249 | strength = 0; 250 | } 251 | if (dexterity < 0) { 252 | dexterity = 0; 253 | } 254 | noMoreBlockFromCards = 0; 255 | if (state.properties.loseDexterityPerTurnCounterIdx >= 0) { 256 | state.getCounterForWrite()[state.properties.loseDexterityPerTurnCounterIdx] = 0; 257 | } 258 | if (state.properties.loseFocusPerTurnCounterIdx >= 0) { 259 | state.getCounterForWrite()[state.properties.loseFocusPerTurnCounterIdx] = 0; 260 | } 261 | if (state.properties.constrictedCounterIdx >= 0) { 262 | state.getCounterForWrite()[state.properties.constrictedCounterIdx] = 0; 263 | } 264 | if (state.properties.drawReductionCounterIdx >= 0) { 265 | state.getCounterForWrite()[state.properties.drawReductionCounterIdx] = 0; 266 | } 267 | if (state.properties.sneckoDebuffCounterIdx >= 0) { 268 | state.getCounterForWrite()[state.properties.sneckoDebuffCounterIdx] = 0; 269 | } 270 | } 271 | } 272 | -------------------------------------------------------------------------------- /agent/src/main/java/com/alphaStS/model/ModelExecutor.java: -------------------------------------------------------------------------------- 1 | package com.alphaStS.model; 2 | 3 | import com.alphaStS.GameState; 4 | import com.alphaStS.utils.Tuple; 5 | 6 | import java.util.ArrayList; 7 | import java.util.List; 8 | import java.util.concurrent.ArrayBlockingQueue; 9 | import java.util.concurrent.BlockingQueue; 10 | import java.util.concurrent.locks.Condition; 11 | import java.util.concurrent.locks.Lock; 12 | import java.util.concurrent.locks.ReentrantLock; 13 | import java.util.stream.Collectors; 14 | 15 | // likely due to JNI, batching nn evaluations even on CPU is faster than executing 1 by 1 by 10x (small networks) to 3x (larger networks) 16 | // currently, use producers threads to produce nn evaluations requests and executors to execute them 17 | // todo: switch to continuations for better performance by removing context switch overheads from large amount of threads + locking/communication cost 18 | // consideration: if I use C++ instead I don't have to do any of those, although the same batching logic can be used for GPU evaluation 19 | public class ModelExecutor { 20 | public final static Integer EXEC_WAITING = 0; 21 | public final static Integer EXEC_SUCCESS = 1; 22 | public final static Integer EXEC_SUCCESS_CACHE_HIT = 2; 23 | 24 | private final List executorThreads = new ArrayList<>(); 25 | private final List executorModels = new ArrayList<>(); 26 | private final List producerThreads = new ArrayList<>(); 27 | private final List producerModels = new ArrayList<>(); 28 | private final String modelDir; 29 | 30 | // ModelBatchProducer uses those for communication 31 | BlockingQueue> queue = new ArrayBlockingQueue<>(1000); 32 | List locks = new ArrayList<>(); 33 | List conditions = new ArrayList<>(); 34 | List results = new ArrayList<>(); 35 | List resultsStatus = new ArrayList<>(); 36 | private int currentBatchSize; 37 | private boolean running = true; 38 | 39 | public ModelExecutor(String modelDir) { 40 | this.modelDir = modelDir; 41 | } 42 | 43 | public void start(int executorCount, int batchSize) { 44 | if (executorCount >= executorModels.size()) { 45 | for (int i = executorModels.size(); i < executorCount; i++) { 46 | executorModels.add(new ModelPlain(modelDir)); 47 | } 48 | } 49 | if (executorCount >= executorThreads.size()) { 50 | for (int i = executorThreads.size(); i < executorCount; i++) { 51 | executorThreads.add(null); 52 | } 53 | } 54 | 55 | currentBatchSize = batchSize; 56 | if (batchSize == 1) { 57 | return; 58 | } 59 | 60 | running = true; 61 | var _pool = this; 62 | for (int i = 0; i < executorCount; i++) { 63 | int _i = i; 64 | executorThreads.set(i, new Thread(() -> { 65 | List> reqs = new ArrayList<>(); 66 | List hashes = new ArrayList<>(); 67 | int reqCount = 0; 68 | while (_pool.running) { 69 | try { 70 | var req = queue.poll(); 71 | if (req != null) { 72 | if (req.v2() != null) { 73 | var hash = new NNInputHash(req.v2().getNNInput()); 74 | var output = executorModels.get(_i).tryUseCache(hash, req.v2()); 75 | if (output != null) { 76 | locks.get(req.v1()).lock(); 77 | results.set(req.v1(), output); 78 | resultsStatus.set(req.v1(), EXEC_SUCCESS_CACHE_HIT); 79 | conditions.get(req.v1()).signal(); 80 | locks.get(req.v1()).unlock(); 81 | continue; 82 | } 83 | hashes.add(hash); 84 | reqCount++; // toward the end, fake requests will come in to keep the executor running without delay for the real requests 85 | } 86 | reqs.add(req); 87 | if (reqs.size() >= batchSize) { 88 | if (reqCount > 0) { 89 | var outputs = executorModels.get(_i).eval(reqs.stream().map(Tuple::v2).collect(Collectors.toList()), hashes, reqCount); 90 | for (int resultsIdx = 0; resultsIdx < reqs.size(); resultsIdx++) { 91 | locks.get(reqs.get(resultsIdx).v1()).lock(); 92 | results.set(reqs.get(resultsIdx).v1(), outputs.get(resultsIdx)); 93 | resultsStatus.set(reqs.get(resultsIdx).v1(), EXEC_SUCCESS); 94 | conditions.get(reqs.get(resultsIdx).v1()).signal(); 95 | locks.get(reqs.get(resultsIdx).v1()).unlock(); 96 | } 97 | } else { 98 | for (int resultsIdx = 0; resultsIdx < reqs.size(); resultsIdx++) { 99 | locks.get(reqs.get(resultsIdx).v1()).lock(); 100 | resultsStatus.set(reqs.get(resultsIdx).v1(), EXEC_SUCCESS); 101 | conditions.get(reqs.get(resultsIdx).v1()).signal(); 102 | locks.get(reqs.get(resultsIdx).v1()).unlock(); 103 | } 104 | } 105 | reqs.clear(); 106 | hashes.clear(); 107 | reqCount = 0; 108 | } 109 | } else if (reqs.size() > 0) { 110 | if (reqCount > 0) { 111 | var outputs = executorModels.get(_i).eval(reqs.stream().map(Tuple::v2).collect(Collectors.toList()), hashes, reqCount); 112 | for (int resultsIdx = 0; resultsIdx < reqs.size(); resultsIdx++) { 113 | locks.get(reqs.get(resultsIdx).v1()).lock(); 114 | results.set(reqs.get(resultsIdx).v1(), outputs.get(resultsIdx)); 115 | resultsStatus.set(reqs.get(resultsIdx).v1(), EXEC_SUCCESS); 116 | conditions.get(reqs.get(resultsIdx).v1()).signal(); 117 | locks.get(reqs.get(resultsIdx).v1()).unlock(); 118 | } 119 | } else { 120 | for (int resultsIdx = 0; resultsIdx < reqs.size(); resultsIdx++) { 121 | locks.get(reqs.get(resultsIdx).v1()).lock(); 122 | resultsStatus.set(reqs.get(resultsIdx).v1(), EXEC_SUCCESS); 123 | conditions.get(reqs.get(resultsIdx).v1()).signal(); 124 | locks.get(reqs.get(resultsIdx).v1()).unlock(); 125 | } 126 | } 127 | reqs.clear(); 128 | hashes.clear(); 129 | reqCount = 0; 130 | } 131 | } catch (Exception e) { 132 | _pool.running = false; 133 | throw e; 134 | } 135 | } 136 | })); 137 | executorThreads.get(i).start(); 138 | } 139 | } 140 | 141 | public void stop() { 142 | running = false; 143 | for (Thread thread : producerThreads) { 144 | thread.interrupt(); 145 | } 146 | for (Thread thread : producerThreads) { 147 | try { 148 | thread.join(); 149 | } catch (InterruptedException e) { 150 | throw new RuntimeException(e); 151 | } 152 | } 153 | producerThreads.clear(); 154 | if (currentBatchSize == 1) { 155 | return; 156 | } 157 | 158 | for (Thread thread : executorThreads) { 159 | thread.interrupt(); 160 | } 161 | for (Thread thread : executorThreads) { 162 | try { 163 | thread.join(); 164 | } catch (InterruptedException e) { 165 | throw new RuntimeException(e); 166 | } 167 | } 168 | executorThreads.clear(); 169 | for (int i = 0; i < producerModels.size(); i++) { 170 | resultsStatus.set(i, EXEC_WAITING); 171 | } 172 | queue.clear(); 173 | } 174 | 175 | public Model getModelForProducer(int producerIdx) { 176 | if (currentBatchSize == 1) { 177 | return executorModels.get(producerIdx); 178 | } 179 | for (int i = producerModels.size(); i <= producerIdx; i++) { 180 | producerModels.add(null); 181 | locks.add(null); 182 | conditions.add(null); 183 | results.add(null); 184 | resultsStatus.add(EXEC_WAITING); 185 | } 186 | producerModels.set(producerIdx, new ModelBatchProducer(this, producerIdx)); 187 | locks.set(producerIdx, new ReentrantLock()); 188 | conditions.set(producerIdx, locks.get(producerIdx).newCondition()); 189 | results.set(producerIdx, null); 190 | resultsStatus.set(producerIdx, EXEC_WAITING); 191 | return producerModels.get(producerIdx); 192 | } 193 | 194 | public Thread addAndStartProducerThread(Runnable t) { 195 | var _pool = this; 196 | producerThreads.add(new Thread(() -> { 197 | try { 198 | t.run(); 199 | } catch (RuntimeException e) { 200 | if (!(e.getCause() instanceof InterruptedException)) { 201 | _pool.running = false; 202 | throw e; 203 | } 204 | } 205 | })); 206 | producerThreads.get(producerThreads.size() - 1).start(); 207 | return producerThreads.get(producerThreads.size() - 1); 208 | } 209 | 210 | public boolean producerWaitForClose(int producerIdx) { 211 | try { 212 | if (currentBatchSize == 1 || !running) { 213 | return true; 214 | } 215 | producerModels.get(producerIdx).eval(null); 216 | } catch (Exception e) { 217 | if (e.getCause() instanceof InterruptedException) { 218 | return !running; 219 | } 220 | throw e; 221 | } 222 | return !running; 223 | } 224 | 225 | public List getExecutorModels() { 226 | return new ArrayList<>(this.executorModels); 227 | } 228 | 229 | public void close() { 230 | for (int i = 0; i < executorModels.size(); i++) { 231 | executorModels.get(i).close(); 232 | } 233 | } 234 | 235 | public static int getNumberOfProducers(int numberOfThreads, int batch) { 236 | return batch == 1 ? numberOfThreads : (numberOfThreads * batch + 2 * batch); 237 | } 238 | 239 | public void sleep(int ms) { 240 | try { 241 | if (ms > 0) Thread.sleep(ms); 242 | } catch (InterruptedException e) { 243 | throw new RuntimeException(e); 244 | } 245 | } 246 | 247 | public boolean isRunning() { 248 | return running; 249 | } 250 | 251 | public void setToNotRunning() { 252 | running = false; 253 | } 254 | } 255 | -------------------------------------------------------------------------------- /agent/src/main/java/com/alphaStS/enemy/EnemyReadOnly.java: -------------------------------------------------------------------------------- 1 | package com.alphaStS.enemy; 2 | 3 | import com.alphaStS.card.Card; 4 | import com.alphaStS.GameProperties; 5 | import com.alphaStS.GameState; 6 | import com.alphaStS.GameStateUtils; 7 | 8 | import java.util.List; 9 | import java.util.Objects; 10 | 11 | public abstract class EnemyReadOnly implements GameProperties.TrainingTargetRegistrant { 12 | public static class EnemyProperties implements Cloneable { 13 | public int numOfMoves; 14 | public int maxHealth; 15 | public int origMaxHealth; // during randomization, property cloning can be set to change origHealth and hasBurningHealthBuff for that battle 16 | public int actNumber; 17 | protected boolean hasBurningHealthBuff = false; 18 | public boolean isElite = false; 19 | public boolean isBoss = false; 20 | public boolean isMinion = false; 21 | public boolean canVulnerable = false; 22 | public boolean canEntangle = false; 23 | public boolean canWeaken = false; 24 | public boolean canFrail = false; 25 | public boolean canGainStrength = false; 26 | public boolean canGainRegeneration = false; 27 | public boolean canHeal = false; 28 | public boolean canGainMetallicize = false; 29 | public boolean canGainPlatedArmor = false; 30 | public boolean canGainBlock = false; 31 | public boolean changePlayerStrength = false; 32 | public boolean changePlayerFocus = false; 33 | public boolean changePlayerDexterity = false; 34 | public boolean hasBurningEliteBuff = false; 35 | public boolean hasArtifact = false; 36 | public boolean useLast2MovesForMoveSelection; 37 | public boolean canSelfRevive; 38 | public boolean isAct3; 39 | public int generatedCardIdx = -1; // when getPossibleGeneratedCards return 1 card, this is the card index for it 40 | public int[] generatedCardIdxes; // when getPossibleGeneratedCards returns non-empty list, this is the card indexes for each card in the order of the list 41 | public int[] generatedCardReverseIdxes; // given a cardIdx, return the index of it in generatedCardIdxes (-1 otherwise) 42 | 43 | public EnemyProperties(int numOfMoves, boolean useLast2MovesForMoveSelection) { 44 | this.numOfMoves = numOfMoves; 45 | this.useLast2MovesForMoveSelection = useLast2MovesForMoveSelection; 46 | } 47 | 48 | public void applyBurningEliteBuff() { 49 | hasBurningEliteBuff= true; 50 | canGainMetallicize = true; 51 | canGainRegeneration = true; 52 | maxHealth = (int) (maxHealth * 1.25); 53 | } 54 | 55 | public boolean hasBurningEliteBuff() { 56 | return hasBurningEliteBuff; 57 | } 58 | 59 | public EnemyProperties clone() { 60 | try { 61 | return (EnemyProperties) super.clone(); 62 | } catch (CloneNotSupportedException e) { 63 | throw new RuntimeException(e); 64 | } 65 | } 66 | } 67 | 68 | public EnemyProperties properties; 69 | protected int health; 70 | protected int maxHealthInBattle; 71 | protected int block; 72 | protected int strength; 73 | protected int vulnerable; 74 | protected int weak; 75 | protected int artifact; 76 | protected int poison; 77 | protected int regeneration; 78 | protected int metallicize; 79 | protected int platedArmor; 80 | protected int loseStrengthEot; 81 | protected int corpseExplosion; 82 | protected int choke; 83 | protected int lockOn; 84 | protected int talkToTheHand; 85 | protected int mark; 86 | protected int move = -1; 87 | protected int lastMove = -1; 88 | int vExtraIdx = -1; 89 | 90 | public EnemyReadOnly(int health, int numOfMoves, boolean useLast2MovesForMoveSelection) { 91 | this.health = health; 92 | this.maxHealthInBattle = health; 93 | properties = new EnemyProperties(numOfMoves, useLast2MovesForMoveSelection); 94 | properties.maxHealth = health; 95 | properties.origMaxHealth = health; 96 | } 97 | 98 | public EnemyReadOnly(EnemyReadOnly other) { 99 | this.properties = other.properties; 100 | } 101 | 102 | public abstract void doMove(GameState state, EnemyReadOnly self); 103 | public abstract Enemy copy(); 104 | public abstract String getMoveString(GameState state, int move); 105 | 106 | public void setupGeneratedCardIndexes(GameProperties gameProperties) { 107 | List possibleCards = getPossibleGeneratedCards(gameProperties, List.of(gameProperties.cardDict)); 108 | var result = GameStateUtils.setupGeneratedCardIndexes(possibleCards, gameProperties); 109 | properties.generatedCardIdx = result.v1(); 110 | properties.generatedCardIdxes = result.v2(); 111 | properties.generatedCardReverseIdxes = result.v3(); 112 | } 113 | 114 | public void setVExtraIdx(GameProperties gameProperties, int idx) { 115 | vExtraIdx = idx; 116 | } 117 | 118 | public void gamePropertiesSetup(GameState state) {} 119 | public List getPossibleGeneratedCards(GameProperties prop, List cards) { return List.of(); } 120 | public int getNNInputLen(GameProperties prop) { return 0; } 121 | public String getNNInputDesc(GameProperties prop) { return null; } 122 | public int writeNNInput(GameProperties prop, float[] input, int idx) { return 0; } 123 | 124 | public String getMoveString(GameState state) { 125 | return getMoveString(state, this.move); 126 | } 127 | 128 | public String getLastMoveString(GameState state) { 129 | return getMoveString(state, this.lastMove); 130 | } 131 | 132 | public abstract String getName(); 133 | 134 | protected void copyFieldsFrom(Enemy other) { 135 | properties = other.properties; 136 | health = other.health; 137 | maxHealthInBattle = other.maxHealthInBattle; 138 | block = other.block; 139 | strength = other.strength; 140 | vulnerable = other.vulnerable; 141 | weak = other.weak; 142 | artifact = other.artifact; 143 | poison = other.poison; 144 | loseStrengthEot = other.loseStrengthEot; 145 | regeneration = other.regeneration; 146 | metallicize = other.metallicize; 147 | platedArmor = other.platedArmor; 148 | corpseExplosion = other.corpseExplosion; 149 | choke = other.choke; 150 | lockOn = other.lockOn; 151 | talkToTheHand = other.talkToTheHand; 152 | mark = other.mark; 153 | move = other.move; 154 | lastMove = other.lastMove; 155 | } 156 | 157 | public int getHealth() { 158 | return health; 159 | } 160 | 161 | public int getMaxHealthInBattle() { 162 | return maxHealthInBattle; 163 | } 164 | 165 | public int getBlock() { 166 | return block; 167 | } 168 | 169 | public int getStrength() { 170 | return strength; 171 | } 172 | 173 | public int getLoseStrengthEot() { 174 | return loseStrengthEot; 175 | } 176 | 177 | public int getVulnerable() { 178 | return vulnerable; 179 | } 180 | 181 | public int getWeak() { 182 | return weak; 183 | } 184 | 185 | public int getRegeneration() { 186 | return regeneration; 187 | } 188 | 189 | public int getMetallicize() { 190 | return metallicize; 191 | } 192 | 193 | public int getPlatedArmor() { 194 | return platedArmor; 195 | } 196 | 197 | public int getArtifact() { 198 | return artifact; 199 | } 200 | 201 | public int getPoison() { 202 | return poison; 203 | } 204 | 205 | public int getCorpseExplosion() { 206 | return corpseExplosion; 207 | } 208 | 209 | public int getChoke() { 210 | return choke; 211 | } 212 | 213 | public int getLockOn() { 214 | return lockOn; 215 | } 216 | 217 | public int getTalkToTheHand() { 218 | return talkToTheHand; 219 | } 220 | 221 | public int getMark() { 222 | return mark; 223 | } 224 | 225 | public boolean hasBurningHealthBuff() { 226 | return properties.hasBurningHealthBuff; 227 | } 228 | 229 | public int getMove() { 230 | return move; 231 | } 232 | 233 | public int getLastMove() { 234 | return lastMove; 235 | } 236 | 237 | public boolean isAlive() { 238 | return health > 0; 239 | } 240 | 241 | public boolean isTargetable() { 242 | return health > 0; 243 | } 244 | 245 | public String toString(GameState state) { 246 | String str = this.getName() + "{hp=" + health; 247 | if (strength != 0) { 248 | str += ", str=" + strength; 249 | } 250 | if (loseStrengthEot != 0) { 251 | str += ", gainStrEot=" + -loseStrengthEot; 252 | } 253 | if (move >= 0) { 254 | if (state.properties.isRunicDomeEnabled(state)) { 255 | str += ", lastMove=" + getMoveString(state); 256 | } else { 257 | str += ", move=" + getMoveString(state); 258 | } 259 | if (properties.useLast2MovesForMoveSelection) { 260 | if (lastMove >= 0) { 261 | str += " [prev=" + getLastMoveString(state) + "]"; 262 | } 263 | } 264 | } 265 | if (block > 0) { 266 | str += ", block=" + block; 267 | } 268 | if (vulnerable > 0) { 269 | str += ", vuln=" + vulnerable; 270 | } 271 | if (weak > 0) { 272 | str += ", weak=" + weak; 273 | } 274 | if (artifact > 0) { 275 | str += ", art=" + artifact; 276 | } 277 | if (poison > 0) { 278 | str += ", poison=" + poison; 279 | } 280 | if (regeneration > 0) { 281 | str += ", regen=" + regeneration; 282 | } 283 | if (metallicize > 0) { 284 | str += ", metallicize=" + metallicize; 285 | } 286 | if (platedArmor > 0) { 287 | str += ", platedArmor=" + platedArmor; 288 | } 289 | if (corpseExplosion > 0) { 290 | str += ", corpseExplosion=" + corpseExplosion; 291 | } 292 | if (choke > 0) { 293 | str += ", choke=" + choke; 294 | } 295 | if (lockOn > 0) { 296 | str += ", lockOn=" + lockOn; 297 | } 298 | if (talkToTheHand > 0) { 299 | str += ", talkToTheHand=" + talkToTheHand; 300 | } 301 | if (mark > 0) { 302 | str += ", mark=" + mark; 303 | } 304 | return str + '}'; 305 | } 306 | 307 | public String toString() { 308 | String str = this.getName() + "{hp=" + health; 309 | if (strength != 0) { 310 | str += ", str=" + strength; 311 | } 312 | if (loseStrengthEot != 0) { 313 | str += ", gainStrEot=" + -loseStrengthEot; 314 | } 315 | if (move >= 0) { 316 | str += ", move=" + move; 317 | } 318 | if (block > 0) { 319 | str += ", block=" + block; 320 | } 321 | if (vulnerable > 0) { 322 | str += ", vuln=" + vulnerable; 323 | } 324 | if (weak > 0) { 325 | str += ", weak=" + weak; 326 | } 327 | if (artifact > 0) { 328 | str += ", art=" + artifact; 329 | } 330 | if (poison > 0) { 331 | str += ", poison=" + poison; 332 | } 333 | if (regeneration > 0) { 334 | str += ", regen=" + regeneration; 335 | } 336 | if (metallicize > 0) { 337 | str += ", metallicize=" + metallicize; 338 | } 339 | if (platedArmor > 0) { 340 | str += ", platedArmor=" + platedArmor; 341 | } 342 | if (corpseExplosion > 0) { 343 | str += ", corpseExplosion=" + corpseExplosion; 344 | } 345 | if (choke > 0) { 346 | str += ", choke=" + choke; 347 | } 348 | if (lockOn > 0) { 349 | str += ", lockOn=" + lockOn; 350 | } 351 | if (talkToTheHand > 0) { 352 | str += ", talkToTheHand=" + talkToTheHand; 353 | } 354 | if (mark > 0) { 355 | str += ", mark=" + mark; 356 | } 357 | return str + '}'; 358 | } 359 | 360 | @Override public boolean equals(Object o) { 361 | if (this == o) 362 | return true; 363 | if (o == null || getClass() != o.getClass()) 364 | return false; 365 | Enemy enemy = (Enemy) o; 366 | if (health == 0 && health == enemy.health) { 367 | return true; 368 | } 369 | return health == enemy.health && maxHealthInBattle == enemy.maxHealthInBattle && move == enemy.move && lastMove == enemy.lastMove && block == enemy.block && 370 | strength == enemy.strength && vulnerable == enemy.vulnerable && weak == enemy.weak && artifact == enemy.artifact && 371 | poison == enemy.poison && loseStrengthEot == enemy.loseStrengthEot && corpseExplosion == enemy.corpseExplosion && 372 | choke == enemy.choke && lockOn == enemy.lockOn && talkToTheHand == enemy.talkToTheHand && mark == enemy.mark; 373 | } 374 | 375 | @Override public int hashCode() { 376 | return Objects.hash(health, block, strength, vulnerable, weak, artifact, poison, move, lastMove); 377 | } 378 | } 379 | -------------------------------------------------------------------------------- /agent/src/main/java/com/alphaStS/RandomGen.java: -------------------------------------------------------------------------------- 1 | package com.alphaStS; 2 | 3 | import java.lang.reflect.Field; 4 | import java.util.ArrayList; 5 | import java.util.Collections; 6 | import java.util.List; 7 | import java.util.Random; 8 | import java.util.concurrent.atomic.AtomicLong; 9 | 10 | public abstract class RandomGen { 11 | public abstract boolean nextBoolean(RandomGenCtx ctx); 12 | public abstract float nextFloat(RandomGenCtx ctx); 13 | public abstract double nextDouble(RandomGenCtx ctx); 14 | public abstract int nextInt(int bound, RandomGenCtx ctx); 15 | public abstract int nextInt(int bound, RandomGenCtx ctx, Object arg); 16 | public abstract long nextLong(RandomGenCtx ctx); 17 | 18 | public RandomGen createWithSeed(long seed) { 19 | throw new UnsupportedOperationException(); 20 | } 21 | 22 | public RandomGen getCopy() { 23 | throw new UnsupportedOperationException(); 24 | } 25 | 26 | public long getSeed(RandomGenCtx ctx) { 27 | throw new UnsupportedOperationException(); 28 | } 29 | 30 | public long getStartingSeed() { 31 | throw new UnsupportedOperationException(); 32 | } 33 | 34 | public void timeTravelToBeginning() { 35 | throw new UnsupportedOperationException(); 36 | } 37 | 38 | public static class RandomGenPlain extends RandomGen { 39 | Random random; 40 | long startingSeed; 41 | 42 | public RandomGenPlain() { 43 | random = new Random(); 44 | startingSeed = getSeed(null); 45 | } 46 | 47 | public RandomGenPlain(Random random) { 48 | this.random = random; 49 | startingSeed = getSeed(null); 50 | } 51 | 52 | public RandomGenPlain(long seed) { 53 | this.random = new Random(seed); 54 | startingSeed = seed; 55 | } 56 | 57 | protected Random getRandomClone() { 58 | long seed; 59 | try { 60 | Field field = Random.class.getDeclaredField("seed"); 61 | field.setAccessible(true); 62 | AtomicLong scrambledSeed = (AtomicLong) field.get(random); 63 | seed = scrambledSeed.get(); 64 | } catch (Exception e) { 65 | throw new RuntimeException(e); 66 | } 67 | return new Random(seed ^ 0x5DEECE66DL); 68 | } 69 | 70 | public boolean nextBoolean(RandomGenCtx ctx) { 71 | return random.nextBoolean(); 72 | } 73 | 74 | public float nextFloat(RandomGenCtx ctx) { 75 | return random.nextFloat(); 76 | } 77 | 78 | public double nextDouble(RandomGenCtx ctx) { 79 | return random.nextDouble(); 80 | } 81 | 82 | public int nextInt(int bound, RandomGenCtx ctx) { 83 | return random.nextInt(bound); 84 | } 85 | 86 | public int nextInt(int bound, RandomGenCtx ctx, Object arg) { 87 | return random.nextInt(bound); 88 | } 89 | 90 | public long nextLong(RandomGenCtx ctx) { 91 | return random.nextLong(); 92 | } 93 | 94 | public RandomGen getCopy() { 95 | return new RandomGenPlain(getRandomClone()); 96 | } 97 | 98 | public void timeTravelToBeginning() { 99 | throw new UnsupportedOperationException(); 100 | } 101 | 102 | public RandomGen createWithSeed(long seed) { 103 | var random = new Random(); 104 | random.setSeed(seed); 105 | return new RandomGenPlain(random); 106 | } 107 | 108 | public long getSeed(RandomGenCtx ctx) { 109 | return RandomGen.getRandomSeed(random); 110 | } 111 | 112 | public long getStartingSeed() { 113 | return startingSeed; 114 | } 115 | } 116 | 117 | private static long getRandomSeed(Random r) { 118 | long seed; 119 | try { 120 | Field field = Random.class.getDeclaredField("seed"); 121 | field.setAccessible(true); 122 | AtomicLong scrambledSeed = (AtomicLong) field.get(r); 123 | seed = scrambledSeed.get(); 124 | seed ^= 0x5DEECE66DL; 125 | } catch (Exception e) { 126 | throw new RuntimeException(e); 127 | } 128 | return seed; 129 | } 130 | 131 | private static Random getRandomClone(Random r) { 132 | return new Random(getRandomSeed(r)); 133 | } 134 | 135 | public static class RandomGenByCtx extends RandomGen { 136 | public List randoms = new ArrayList<>(); 137 | long startingSeed; 138 | public boolean useNewCommonNumberVR; 139 | 140 | public RandomGenByCtx(long seed) { 141 | startingSeed = seed; 142 | Random random = new Random(seed); 143 | for (int i = 0; i < RandomGenCtx.values().length; i++) { 144 | randoms.add(new RandomGenMemory(random.nextLong())); 145 | } 146 | } 147 | 148 | public RandomGenByCtx(RandomGenByCtx other) { 149 | startingSeed = other.startingSeed; 150 | for (int i = 0; i < RandomGenCtx.values().length; i++) { 151 | randoms.add(other.randoms.get(i).getCopy()); 152 | } 153 | } 154 | 155 | public boolean nextBoolean(RandomGenCtx ctx) { 156 | return randoms.get(ctx.ordinal()).nextBoolean(ctx); 157 | } 158 | 159 | public float nextFloat(RandomGenCtx ctx) { 160 | return randoms.get(ctx.ordinal()).nextFloat(ctx); 161 | } 162 | 163 | public double nextDouble(RandomGenCtx ctx) { 164 | return randoms.get(ctx.ordinal()).nextDouble(ctx); 165 | } 166 | 167 | public int nextInt(int bound, RandomGenCtx ctx) { 168 | return randoms.get(ctx.ordinal()).nextInt(bound, ctx); 169 | } 170 | 171 | public int nextInt(int bound, RandomGenCtx ctx, Object arg) { 172 | return randoms.get(ctx.ordinal()).nextInt(bound, ctx, arg); 173 | } 174 | 175 | public long nextLong(RandomGenCtx ctx) { 176 | return randoms.get(ctx.ordinal()).nextLong(ctx); 177 | } 178 | 179 | public RandomGen getCopy() { 180 | return new RandomGenByCtx(this); 181 | } 182 | 183 | public RandomGen createWithSeed(long seed) { 184 | if (Configuration.NEW_COMMON_RANOM_NUMBER_VARIANCE_REDUCTION && (!Configuration.TEST_NEW_COMMON_RANOM_NUMBER_VARIANCE_REDUCTION || useNewCommonNumberVR)) { 185 | return new RandomGenByCtxCOW(seed); 186 | } else { 187 | var random = new Random(); 188 | random.setSeed(seed); 189 | return new RandomGenPlain(random); 190 | } 191 | } 192 | 193 | public void timeTravelToBeginning() { 194 | for (int i = 0; i < RandomGenCtx.values().length; i++) { 195 | randoms.get(i).timeTravelToBeginning(); 196 | } 197 | } 198 | 199 | public long getSeed(RandomGenCtx ctx) { 200 | return randoms.get(ctx.ordinal()).getSeed(ctx); 201 | } 202 | 203 | public long getStartingSeed() { 204 | return startingSeed; 205 | } 206 | } 207 | 208 | public static record Arg(int type, int bound, Object result) { 209 | public final static int NEXT_BOOLEAN = 0; 210 | public final static int NEXT_FLOAT = 1; 211 | public final static int NEXT_INT = 2; 212 | public final static int NEXT_LONG = 3; 213 | public final static int NEXT_DOUBLE = 1; 214 | } 215 | 216 | // after calling timeTravelToStart, will return the same result for the same sequence of next* calls as the original 217 | public static class RandomGenMemory extends RandomGen { 218 | Random random; 219 | private final List memory; 220 | private int currentIdx = -1; 221 | 222 | public RandomGenMemory(long seed) { 223 | random = new Random(seed); 224 | memory = new ArrayList<>(); 225 | } 226 | 227 | public RandomGenMemory(RandomGenMemory other) { 228 | random = getRandomClone(other.random); 229 | currentIdx = other.currentIdx; 230 | memory = new ArrayList<>(other.memory); 231 | } 232 | 233 | public boolean nextBoolean(RandomGenCtx ctx) { 234 | currentIdx++; 235 | if (currentIdx < memory.size()) { 236 | if (memory.get(currentIdx).type == Arg.NEXT_BOOLEAN) { 237 | return (Boolean) memory.get(currentIdx).result; 238 | } else { 239 | memory.subList(currentIdx, memory.size()).clear(); 240 | } 241 | } 242 | boolean result = random.nextBoolean(); 243 | memory.add(new Arg(Arg.NEXT_BOOLEAN, 0, result)); 244 | return result; 245 | } 246 | 247 | public float nextFloat(RandomGenCtx ctx) { 248 | currentIdx++; 249 | if (currentIdx < memory.size()) { 250 | if (memory.get(currentIdx).type == Arg.NEXT_FLOAT) { 251 | return (Float) memory.get(currentIdx).result; 252 | } else { 253 | memory.subList(currentIdx, memory.size()).clear(); 254 | } 255 | } 256 | float result = random.nextFloat(); 257 | memory.add(new Arg(Arg.NEXT_FLOAT, 0, result)); 258 | return result; 259 | } 260 | 261 | public double nextDouble(RandomGenCtx ctx) { 262 | currentIdx++; 263 | if (currentIdx < memory.size()) { 264 | if (memory.get(currentIdx).type == Arg.NEXT_DOUBLE) { 265 | return (Double) memory.get(currentIdx).result; 266 | } else { 267 | memory.subList(currentIdx, memory.size()).clear(); 268 | } 269 | } 270 | double result = random.nextDouble(); 271 | memory.add(new Arg(Arg.NEXT_DOUBLE, 0, result)); 272 | return result; 273 | } 274 | 275 | public int nextInt(int bound, RandomGenCtx ctx) { 276 | currentIdx++; 277 | if (currentIdx < memory.size()) { 278 | if (memory.get(currentIdx).type == Arg.NEXT_INT && memory.get(currentIdx).bound == bound) { 279 | return (Integer) memory.get(currentIdx).result; 280 | } else { 281 | memory.subList(currentIdx, memory.size()).clear(); 282 | } 283 | } 284 | int result = random.nextInt(bound); 285 | memory.add(new Arg(Arg.NEXT_INT, bound, result)); 286 | return result; 287 | } 288 | 289 | public int nextInt(int bound, RandomGenCtx ctx, Object arg) { 290 | return nextInt(bound, ctx); 291 | } 292 | 293 | public long nextLong(RandomGenCtx ctx) { 294 | currentIdx++; 295 | if (currentIdx < memory.size()) { 296 | if (memory.get(currentIdx).type == Arg.NEXT_LONG) { 297 | return (Long) memory.get(currentIdx).result; 298 | } else { 299 | memory.subList(currentIdx, memory.size()).clear(); 300 | } 301 | } 302 | long result = random.nextLong(); 303 | memory.add(new Arg(Arg.NEXT_LONG, 0, result)); 304 | return result; 305 | } 306 | 307 | public void timeTravelToBeginning() { 308 | currentIdx = -1; 309 | } 310 | 311 | public RandomGen getCopy() { 312 | return new RandomGenMemory(this); 313 | } 314 | 315 | public RandomGen createWithSeed(long seed) { 316 | throw new UnsupportedOperationException(); 317 | } 318 | 319 | public long getSeed(RandomGenCtx ctx) { 320 | long seed; 321 | try { 322 | Field field = Random.class.getDeclaredField("seed"); 323 | field.setAccessible(true); 324 | AtomicLong scrambledSeed = (AtomicLong) field.get(random); 325 | seed = scrambledSeed.get(); 326 | } catch (Exception e) { 327 | throw new RuntimeException(e); 328 | } 329 | return seed ^ 0x5DEECE66DL; 330 | } 331 | } 332 | 333 | public static class RandomGenByCtxCOW extends RandomGen { 334 | public List randoms = new ArrayList<>(); 335 | long startingSeed; 336 | 337 | public RandomGenByCtxCOW(long seed) { 338 | startingSeed = seed; 339 | Random random = new Random(seed); 340 | for (int i = 0; i < RandomGenCtx.values().length; i++) { 341 | randoms.add(new RandomGenPlain(random.nextLong())); 342 | } 343 | } 344 | 345 | public RandomGenByCtxCOW(RandomGenByCtxCOW other) { 346 | startingSeed = other.startingSeed; 347 | for (int i = 0; i < RandomGenCtx.values().length; i++) { 348 | randoms.add(other.randoms.get(i).getCopy()); 349 | } 350 | } 351 | 352 | public boolean nextBoolean(RandomGenCtx ctx) { 353 | return randoms.get(ctx.ordinal()).nextBoolean(ctx); 354 | } 355 | 356 | public float nextFloat(RandomGenCtx ctx) { 357 | return randoms.get(ctx.ordinal()).nextFloat(ctx); 358 | } 359 | 360 | public double nextDouble(RandomGenCtx ctx) { 361 | return randoms.get(ctx.ordinal()).nextDouble(ctx); 362 | } 363 | 364 | public int nextInt(int bound, RandomGenCtx ctx) { 365 | return randoms.get(ctx.ordinal()).nextInt(bound, ctx); 366 | } 367 | 368 | public int nextInt(int bound, RandomGenCtx ctx, Object arg) { 369 | return randoms.get(ctx.ordinal()).nextInt(bound, ctx, arg); 370 | } 371 | 372 | public long nextLong(RandomGenCtx ctx) { 373 | return randoms.get(ctx.ordinal()).nextLong(ctx); 374 | } 375 | 376 | public long getSeed(RandomGenCtx ctx) { 377 | return randoms.get(ctx.ordinal()).getSeed(ctx); 378 | } 379 | 380 | public RandomGen getCopy() { 381 | return new RandomGenByCtxCOW(this); 382 | } 383 | 384 | public long getStartingSeed() { 385 | return startingSeed; 386 | } 387 | } 388 | } 389 | -------------------------------------------------------------------------------- /agent/src/main/java/com/alphaStS/GameStateBuilder.java: -------------------------------------------------------------------------------- 1 | package com.alphaStS; 2 | 3 | import com.alphaStS.card.Card; 4 | import com.alphaStS.card.CardSilent; 5 | import com.alphaStS.enemy.Enemy; 6 | import com.alphaStS.enemy.EnemyCity; 7 | import com.alphaStS.enemy.EnemyEncounter; 8 | import com.alphaStS.enemy.EnemyExordium; 9 | import com.alphaStS.enums.CharacterEnum; 10 | import com.alphaStS.player.Player; 11 | import com.alphaStS.utils.Tuple; 12 | 13 | import java.io.OutputStream; 14 | import java.io.PrintStream; 15 | import java.util.ArrayList; 16 | import java.util.List; 17 | import java.util.function.BiConsumer; 18 | import java.util.function.Consumer; 19 | import java.util.function.Function; 20 | 21 | public class GameStateBuilder { 22 | private Player player = null; 23 | private CharacterEnum character = CharacterEnum.IRONCLAD; 24 | private List cards = new ArrayList<>(); 25 | private List enemies = new ArrayList<>(); 26 | private List enemiesEncounters = new ArrayList<>(); 27 | private List gremlinEncounterFightIndexes = new ArrayList<>(); 28 | private List gremlinGangEncounterIndexes = new ArrayList<>(); 29 | private List relics = new ArrayList<>(); 30 | private List potions = new ArrayList<>(); 31 | private Potion.PotionGenerator potionsGenerator; 32 | private List> enemyReorderings = new ArrayList<>(); 33 | private GameStateRandomization randomization = null; 34 | private GameStateRandomization preBattleRandomization = null; 35 | private GameStateRandomization preBattleGameScenarios = null; 36 | private GameEventHandler endOfPreBattleSetupHandler = null; 37 | private int[] potionsScenarios; 38 | private boolean isBurningElite; 39 | private Function switchBattleHandler; 40 | 41 | public void setSwitchBattleHandler(Function switchBattleHandler) { 42 | this.switchBattleHandler = switchBattleHandler; 43 | } 44 | 45 | public Function getSwitchBattleHandler() { 46 | return switchBattleHandler; 47 | } 48 | 49 | public GameStateBuilder() { 50 | } 51 | 52 | public void setPlayer(Player player) { 53 | this.player = player; 54 | } 55 | 56 | public Player getPlayer() { 57 | return player; 58 | } 59 | 60 | public void addCard(Card card, int count) { 61 | for (int i = 0; i < count; i++) { 62 | cards.add(card); 63 | } 64 | } 65 | 66 | public void addCard(Card... cards) { 67 | for (Card card : cards) { 68 | this.cards.add(card); 69 | } 70 | } 71 | 72 | public List getStartingCards() { 73 | return new ArrayList<>(cards); 74 | } 75 | 76 | public List getStartingCards(Card... extraCards) { 77 | var cards = new ArrayList<>(List.of(extraCards)); 78 | cards.addAll(this.cards); 79 | return cards; 80 | } 81 | 82 | public List getCards() { 83 | return cards; 84 | } 85 | 86 | public List getEnemies() { 87 | return enemies; 88 | } 89 | 90 | public void addEnemyEncounter(EnemyEncounter.EncounterEnum encounterEnum, Enemy... enemies) { 91 | addEnemyEncounter(-1, encounterEnum, enemies); 92 | } 93 | 94 | public void addEnemyEncounter(int startingHealth, EnemyEncounter.EncounterEnum encounterEnum, Enemy... enemies) { 95 | var indexes = new ArrayList>(); 96 | boolean isGremlinLeaderFight = false; 97 | boolean isGremlinGangFight = false; 98 | for (Enemy enemy : enemies) { 99 | if (enemy instanceof EnemyExordium.FatGremlin) { 100 | isGremlinGangFight = true; 101 | break; 102 | } 103 | if (enemy instanceof EnemyCity.GremlinLeader) { 104 | isGremlinLeaderFight = true; 105 | break; 106 | } 107 | } 108 | if (isGremlinLeaderFight) { 109 | indexes.add(new Tuple<>(this.enemies.size(), 0)); 110 | indexes.add(new Tuple<>(this.enemies.size() + 1, 0)); 111 | indexes.add(new Tuple<>(this.enemies.size() + 2, 0)); 112 | indexes.add(new Tuple<>(this.enemies.size() + 3, -1)); 113 | gremlinEncounterFightIndexes.add(enemiesEncounters.size()); 114 | } else if (isGremlinGangFight) { 115 | for (int i = 0; i < enemies.length; i++) { 116 | indexes.add(new Tuple<>(this.enemies.size() + i, -1)); 117 | } 118 | gremlinGangEncounterIndexes.add(enemiesEncounters.size()); 119 | } else { 120 | for (int i = 0; i < enemies.length; i++) { 121 | indexes.add(new Tuple<>(this.enemies.size() + i, -1)); 122 | } 123 | } 124 | enemiesEncounters.add(new EnemyEncounter(startingHealth, encounterEnum, indexes)); 125 | this.enemies.addAll(List.of(enemies)); 126 | } 127 | 128 | public void addEnemyEncounter(Enemy... enemies) { 129 | addEnemyEncounter(EnemyEncounter.EncounterEnum.UNKNOWN, enemies); 130 | } 131 | 132 | public void addEnemyEncounter(int startingHealth, Enemy... enemies) { 133 | addEnemyEncounter(startingHealth, EnemyEncounter.EncounterEnum.UNKNOWN, enemies); 134 | } 135 | 136 | public List getEnemiesEncounters() { 137 | return enemiesEncounters; 138 | } 139 | 140 | public void setBurningElite() { 141 | isBurningElite = true; 142 | } 143 | 144 | public void build(GameState state) { 145 | if (enemiesEncounters.size() > 1) { 146 | GameStateRandomization r = new GameStateRandomization.EnemyEncounterRandomization(enemies, enemiesEncounters); 147 | for (var gremlinLeaderFightIndex : gremlinEncounterFightIndexes) { 148 | var encounter = enemiesEncounters.get(gremlinLeaderFightIndex); 149 | r = r.followByIf(gremlinLeaderFightIndex, new EnemyEncounter.GremlinLeaderRandomization2(enemies, encounter.idxes.get(0).v1()).collapse("Random Gremlins")); 150 | } 151 | for (var gremlinGangEncounterIndexes : gremlinGangEncounterIndexes) { 152 | var encounter = enemiesEncounters.get(gremlinGangEncounterIndexes); 153 | r = r.followByIf(gremlinGangEncounterIndexes, new EnemyEncounter.GremlinGangRandomization(enemies, encounter.idxes.get(0).v1())); 154 | } 155 | randomization = randomization == null ? r : randomization.doAfter(r); 156 | } else if (enemiesEncounters.size() == 1) { 157 | if (gremlinEncounterFightIndexes.size() > 0) { 158 | var encounter = enemiesEncounters.get(gremlinEncounterFightIndexes.get(0)); 159 | var r = new EnemyEncounter.GremlinLeaderRandomization2(enemies, encounter.idxes.get(0).v1()).collapse("Random Gremlins"); 160 | randomization = randomization == null ? r : randomization.doAfter(r); 161 | } else if (gremlinGangEncounterIndexes.size() > 0) { 162 | var encounter = enemiesEncounters.get(gremlinGangEncounterIndexes.get(0)); 163 | var r = new EnemyEncounter.GremlinGangRandomization(enemies, encounter.idxes.get(0).v1()); 164 | randomization = randomization == null ? r : randomization.doAfter(r); 165 | } 166 | state.currentEncounter = enemiesEncounters.get(0).encounterEnum; 167 | } 168 | if (isBurningElite) { 169 | for (Enemy enemy : enemies) { 170 | if (enemy.properties.isElite) { 171 | enemy.markAsBurningElite(); 172 | } 173 | } 174 | randomization = new GameStateRandomization.BurningEliteRandomization().doAfter(randomization); 175 | } 176 | } 177 | 178 | public void addRelic(Relic relic) { 179 | relics.add(relic); 180 | } 181 | 182 | public List getRelics() { 183 | return relics; 184 | } 185 | 186 | public void addPotion(Potion potion) { 187 | potions.add(potion); 188 | } 189 | 190 | public List getPotions() { 191 | int maxPotionSlot = relics.stream().anyMatch((x) -> x instanceof Relic.PotionBelt) ? 4 : 2; 192 | for (Relic relic : relics) { 193 | if (relic instanceof Relic.LizardTail) { 194 | potions.add(new Potion.LizardTail()); 195 | } 196 | } 197 | boolean canGeneratePotion = false; 198 | int basePenaltyRatio = 0; 199 | int possibleGeneratedPotions = 0; 200 | for (int i = 0; i < potions.size(); i++) { 201 | if (potions.get(i) instanceof Potion.EntropicBrew pot) { 202 | canGeneratePotion = true; 203 | basePenaltyRatio = Math.max(basePenaltyRatio, pot.basePenaltyRatio); 204 | possibleGeneratedPotions |= pot.possibleGeneratedPotions; 205 | break; 206 | } 207 | } 208 | for (Card card : cards) { 209 | if (card.cardName.startsWith("Alchemize") && card instanceof CardSilent._AlchemizeT alchemize) { 210 | if (alchemize.basePenaltyRatio > 0) { 211 | canGeneratePotion = true; 212 | basePenaltyRatio = Math.max(basePenaltyRatio, alchemize.basePenaltyRatio); 213 | possibleGeneratedPotions |= alchemize.possibleGeneratedPotions; 214 | break; 215 | } 216 | } 217 | } 218 | if (canGeneratePotion) { 219 | potionsGenerator = new Potion.PotionGenerator(possibleGeneratedPotions); 220 | potionsGenerator.initPossibleGeneratedPotions(character, getPlayer().getMaxHealth(), basePenaltyRatio); 221 | for (int potionIdx = 0; potionIdx < potionsGenerator.commonPotions.size(); potionIdx++) { 222 | for (int j = 0; j < maxPotionSlot; j++) { 223 | potions.add(potionsGenerator.commonPotions.get(potionIdx).get(j)); 224 | } 225 | } 226 | for (int potionIdx = 0; potionIdx < potionsGenerator.uncommonPotions.size(); potionIdx++) { 227 | for (int j = 0; j < maxPotionSlot; j++) { 228 | potions.add(potionsGenerator.uncommonPotions.get(potionIdx).get(j)); 229 | } 230 | } 231 | for (int potionIdx = 0; potionIdx < potionsGenerator.rarePotions.size(); potionIdx++) { 232 | for (int j = 0; j < maxPotionSlot; j++) { 233 | potions.add(potionsGenerator.rarePotions.get(potionIdx).get(j)); 234 | } 235 | } 236 | } 237 | return potions; 238 | } 239 | 240 | public Potion.PotionGenerator getPotionsGenerator() { 241 | return potionsGenerator; 242 | } 243 | 244 | public void setPotionsScenarios(int... scenarios) { 245 | potionsScenarios = scenarios; 246 | } 247 | 248 | public int[] getPotionsScenarios() { 249 | return potionsScenarios; 250 | } 251 | 252 | public void setRandomization(GameStateRandomization randomization) { 253 | this.randomization = randomization; 254 | } 255 | 256 | public GameStateRandomization getRandomization() { 257 | return randomization; 258 | } 259 | 260 | public void setPreBattleRandomization(GameStateRandomization randomization) { 261 | this.preBattleRandomization = randomization; 262 | } 263 | 264 | public GameStateRandomization getPreBattleRandomization() { 265 | return preBattleRandomization; 266 | } 267 | 268 | public void setPreBattleScenarios(GameStateRandomization randomization) { 269 | List> setters = new ArrayList<>(); 270 | var info = randomization.listRandomizations(); 271 | var desc = new String[info.size()]; 272 | for (int i = 0; i < info.size(); i++) { 273 | var _i = i; 274 | setters.add((state) -> state.preBattleScenariosChosenIdx = _i); 275 | desc[i] = info.get(i).desc(); 276 | } 277 | preBattleGameScenarios = randomization.join(new GameStateRandomization.SimpleCustomRandomization(setters)).setDescriptions(desc); 278 | } 279 | 280 | public GameStateRandomization getPreBattleScenarios() { 281 | return preBattleGameScenarios; 282 | } 283 | 284 | public void addEnemyReordering(BiConsumer reordering) { 285 | enemyReorderings.add(reordering); 286 | } 287 | 288 | public List> getEnemyReordering() { 289 | return enemyReorderings; 290 | } 291 | 292 | public void setCharacter(CharacterEnum character) { 293 | this.character = character; 294 | } 295 | 296 | public CharacterEnum getCharacter() { 297 | return character; 298 | } 299 | 300 | public GameEventHandler getEndOfPreBattleSetupHandler() { 301 | return endOfPreBattleSetupHandler; 302 | } 303 | 304 | public void setEndOfPreBattleSetupHandler(GameEventHandler endOfPreBattleSetupHandler) { 305 | this.endOfPreBattleSetupHandler = endOfPreBattleSetupHandler; 306 | } 307 | 308 | public void setGameStateViaInteractiveMode(List commands, boolean trainWholeBattle) { 309 | setEndOfPreBattleSetupHandler(new GameEventHandler() { 310 | @Override public void handle(GameState state) { 311 | // if (state.properties.isTraining && state.skipInteractiveModeSetup) { 312 | if (state.skipInteractiveModeSetup) { 313 | return; 314 | } 315 | state.clearAllSearchInfo(); 316 | new InteractiveMode(new PrintStream(OutputStream.nullOutputStream())).interactiveApplyHistory(state, commands); 317 | } 318 | }); 319 | if (trainWholeBattle) { 320 | preBattleRandomization = new GameStateRandomization.SimpleCustomRandomization(List.of( 321 | (state) -> state.skipInteractiveModeSetup = true 322 | )).union(0.2, new GameStateRandomization.SimpleCustomRandomization(List.of( 323 | (state) -> {} 324 | ))).setDescriptions("Skip interactive mode setup", "Interactive mode setup").doAfter(preBattleRandomization); 325 | } 326 | } 327 | 328 | public void setGameStateViaInteractiveMode(List commands) { 329 | setGameStateViaInteractiveMode(commands, false); 330 | } 331 | 332 | 333 | public List> perScenarioCommands; 334 | public void setPerScenarioCommands(List> perScenarioCommands) { 335 | this.perScenarioCommands = perScenarioCommands; 336 | } 337 | 338 | public List> getPerScenarioCommands() { 339 | return perScenarioCommands; 340 | } 341 | } 342 | -------------------------------------------------------------------------------- /agent/src/main/java/com/alphaStS/model/ModelPlain.java: -------------------------------------------------------------------------------- 1 | package com.alphaStS.model; 2 | 3 | import ai.onnxruntime.OnnxTensor; 4 | import ai.onnxruntime.OrtEnvironment; 5 | import ai.onnxruntime.OrtException; 6 | import ai.onnxruntime.OrtSession; 7 | import ai.onnxruntime.OrtSession.Result; 8 | import ai.onnxruntime.OrtSession.SessionOptions; 9 | import ai.onnxruntime.OrtSession.SessionOptions.OptLevel; 10 | import com.alphaStS.Configuration; 11 | import com.alphaStS.CrashException; 12 | import com.alphaStS.GameActionCtx; 13 | import com.alphaStS.GameState; 14 | import com.alphaStS.utils.Utils; 15 | 16 | import java.util.*; 17 | 18 | public class ModelPlain implements Model { 19 | public static class LRUCache extends LinkedHashMap { 20 | private static final int MAX_ENTRIES; 21 | 22 | static { 23 | if (System.getProperty("os.name").startsWith("Windows")) { 24 | MAX_ENTRIES = 250000; 25 | } else { 26 | MAX_ENTRIES = 10000; 27 | } 28 | } 29 | 30 | protected boolean removeEldestEntry(Map.Entry eldest) { 31 | return size() > MAX_ENTRIES; 32 | } 33 | } 34 | 35 | OrtEnvironment env; 36 | OrtSession session; 37 | String inputName; 38 | public LRUCache cache; 39 | public int calls; 40 | public int cache_hits; 41 | public long time_taken; 42 | 43 | /** 44 | * Naively takes the softmax of the input. 45 | * 46 | * @param input The input array. 47 | * @return The softmax of the input. 48 | */ 49 | public static float[] softmax(float[] input) { 50 | double[] tmp = new double[input.length]; 51 | float max = input[0]; 52 | for (int i = 1; i < input.length; i++) { 53 | max = Math.max(max, input[i]); 54 | } 55 | double sum = 0.0; 56 | for (int i = 0; i < input.length; i++) { 57 | double val = Math.exp(input[i] - max); 58 | sum += val; 59 | tmp[i] = val; 60 | } 61 | for (int i = 0; i < input.length; i++) { 62 | input[i] = (float) (tmp[i] / sum); 63 | } 64 | return input; 65 | } 66 | 67 | public static void softmax(float[] input, int start, int len) { 68 | double[] tmp = new double[len]; 69 | double max = input[0]; 70 | for (int i = 1; i < len; i++) { 71 | max = Math.max(max, input[start + i]); 72 | } 73 | double sum = 0.0; 74 | for (int i = 0; i < len; i++) { 75 | double val = Math.exp(input[start + i] - max); 76 | sum += val; 77 | tmp[i] = val; 78 | } 79 | for (int i = 0; i < len; i++) { 80 | input[start + i] = (float) (tmp[i] / sum); 81 | } 82 | } 83 | 84 | public static float[] softmax(float[] input, float temp) { 85 | double[] tmp = new double[input.length]; 86 | float max = Float.MIN_VALUE; 87 | for (int i = 0; i < input.length; i++) { 88 | max = Math.max(max, input[i] / temp); 89 | } 90 | double sum = 0.0; 91 | for (int i = 0; i < input.length; i++) { 92 | double val = Math.exp(input[i] / temp - max); 93 | sum += val; 94 | tmp[i] = val; 95 | } 96 | float[] output = new float[input.length]; 97 | for (int i = 0; i < output.length; i++) { 98 | output[i] = (float) (tmp[i] / sum); 99 | } 100 | return output; 101 | } 102 | 103 | public ModelPlain(String modelDir) { 104 | try { 105 | env = OrtEnvironment.getEnvironment(); 106 | if (Configuration.ONNX_USE_CUDA_FOR_INFERENCE) { 107 | System.setProperty("onnxruntime.native.path", Configuration.ONNX_LIB_PATH); 108 | } 109 | OrtSession.SessionOptions opts = new SessionOptions(); 110 | opts.setOptimizationLevel(OptLevel.ALL_OPT); 111 | opts.setCPUArenaAllocator(true); 112 | opts.setMemoryPatternOptimization(true); 113 | opts.setExecutionMode(SessionOptions.ExecutionMode.SEQUENTIAL); 114 | opts.setInterOpNumThreads(1); 115 | opts.setIntraOpNumThreads(1); 116 | if (Configuration.ONNX_USE_CUDA_FOR_INFERENCE) { 117 | opts.addCUDA(0); 118 | } 119 | session = env.createSession(modelDir + "/model.onnx", opts); 120 | inputName = session.getInputNames().iterator().next(); 121 | cache = new LRUCache<>(); 122 | } catch (Exception e) { 123 | throw new RuntimeException(e); 124 | } 125 | } 126 | 127 | public void close() { 128 | try { 129 | session.close(); 130 | } catch (OrtException e) { 131 | throw new RuntimeException(e); 132 | } 133 | } 134 | 135 | int prev; 136 | @Override public void startRecordCalls() { 137 | prev = calls - cache_hits; 138 | } 139 | 140 | @Override public int endRecordCalls() { 141 | return calls - cache_hits - prev; 142 | } 143 | 144 | public void resetStats() { 145 | calls = 0; 146 | cache_hits = 0; 147 | time_taken = 0; 148 | } 149 | 150 | public NNOutput eval(GameState state) { 151 | calls += 1; 152 | NNInputHash hash; 153 | try { 154 | hash = new NNInputHash(state.getNNInput()); 155 | } catch (Exception e) { 156 | System.out.println(state); 157 | throw e; 158 | } 159 | NNOutput o = cache.get(hash); 160 | if (o != null) { 161 | if (!Arrays.equals(o.legalActions(), state.getLegalActions()) && 162 | (state.getActionCtx() != GameActionCtx.SELECT_ENEMY || state.properties.enemiesReordering == null)) { 163 | System.err.println(Arrays.toString(state.getNNInput())); 164 | System.err.println(state); 165 | System.err.println(Arrays.toString(o.legalActions())); 166 | System.err.println(Arrays.toString(state.getLegalActions())); 167 | for (int i = 0; i < state.getLegalActions().length; i++) { 168 | System.err.println(state.getActionString(i)); 169 | } 170 | throw CrashException.builder().withState(state).build(); 171 | } 172 | } 173 | 174 | var x = new float[1][]; 175 | x[0] = hash.input; 176 | long start = System.currentTimeMillis(); 177 | try (OnnxTensor test = OnnxTensor.createTensor(env, x); 178 | Result output = session.run(Collections.singletonMap(inputName, test))) { 179 | float v_health = ((float[][]) output.get(0).getValue())[0][0]; 180 | float v_win = ((float[][]) output.get(1).getValue())[0][0]; 181 | 182 | float[] policy = ((float[][]) output.get(2).getValue())[0]; 183 | if (state.properties.maxNumOfActions != state.properties.totalNumOfActions) { 184 | if (state.getActionCtx() == GameActionCtx.PLAY_CARD) { 185 | policy = Arrays.copyOf(policy, state.properties.actionsByCtx[GameActionCtx.PLAY_CARD.ordinal()].length); 186 | } else if (state.getActionCtx() == GameActionCtx.SELECT_ENEMY) { 187 | int startIdx = state.properties.actionsByCtx[GameActionCtx.PLAY_CARD.ordinal()].length; 188 | int lenToCopy = state.properties.actionsByCtx[GameActionCtx.SELECT_ENEMY.ordinal()].length; 189 | policy = Utils.arrayCopy(policy, startIdx, lenToCopy); 190 | } else if (state.getActionCtx() == GameActionCtx.SELECT_SCENARIO) { 191 | int startIdx = state.properties.actionsByCtx[GameActionCtx.PLAY_CARD.ordinal()].length; 192 | if (state.properties.actionsByCtx[GameActionCtx.SELECT_ENEMY.ordinal()] != null) { 193 | startIdx += state.properties.actionsByCtx[GameActionCtx.SELECT_ENEMY.ordinal()].length; 194 | } 195 | int lenToCopy = state.properties.actionsByCtx[GameActionCtx.SELECT_SCENARIO.ordinal()].length; 196 | policy = Utils.arrayCopy(policy, startIdx, lenToCopy); 197 | } 198 | } 199 | if (state.getActionCtx() == GameActionCtx.SELECT_ENEMY) { 200 | var order = state.getEnemyOrder(); 201 | if (order != null) { 202 | var newPolicy = new float[policy.length]; 203 | for (int i = 0; i < newPolicy.length; i++) { 204 | newPolicy[order[i]] = policy[i]; 205 | } 206 | policy = newPolicy; 207 | } 208 | } 209 | 210 | float[] v_other = null; 211 | if (output.size() > 3 && state.properties.extraOutputLen > 0) { 212 | v_other = new float[state.properties.extraOutputLen]; 213 | int idx = 0; 214 | for (int i = 0; i < state.properties.extraTrainingTargets.size(); i++) { 215 | int n = state.properties.extraTrainingTargets.get(i).getNumberOfTargets(); 216 | if (n == 1) { 217 | float value = ((float[][]) output.get(3 + i).getValue())[0][0]; 218 | if (state.properties.extraTrainingTargetsLabel.get(i).startsWith("Z")) { 219 | v_other[idx] = value; 220 | v_other[idx] = (value + 1) / 2; 221 | } else { 222 | v_other[idx] = (value + 1) / 2; 223 | } 224 | } else { 225 | float[] v = ((float[][]) output.get(3 + i).getValue())[0]; 226 | for (int j = 0; j < n; j++) { 227 | v_other[idx + j] = v[j]; 228 | } 229 | softmax(v_other, idx, n); 230 | } 231 | idx += n; 232 | } 233 | } 234 | 235 | v_health = (v_health + 1) / 2; 236 | v_win = (v_win + 1) / 2; 237 | o = new NNOutput(v_health, v_win, v_other, softmax(policyCompressToLegalActions(state, policy)), hasEnemyReordering(state) ? policy : null, state.getLegalActions()); 238 | cache.put(hash, o); 239 | time_taken += System.currentTimeMillis() - start; 240 | return o; 241 | } catch (OrtException exception) { 242 | throw new RuntimeException(exception); 243 | } 244 | } 245 | 246 | public NNOutput tryUseCache(NNInputHash input, GameState state) { 247 | var output = cache.get(input); 248 | if (output != null) { 249 | cache_hits++; 250 | calls++; 251 | output = updatePolicyWithReordering(state, output); 252 | } 253 | return output; 254 | } 255 | 256 | private static boolean hasEnemyReordering(GameState state) { 257 | return state.getActionCtx() == GameActionCtx.SELECT_ENEMY && state.properties.enemiesReordering != null; 258 | } 259 | 260 | private static NNOutput updatePolicyWithReordering(GameState state, NNOutput output) { 261 | if (state.getActionCtx() == GameActionCtx.SELECT_ENEMY) { 262 | var order = state.getEnemyOrder(); 263 | if (order != null) { 264 | var newPolicy = new float[output.rawPolicy().length]; 265 | for (int i = 0; i < newPolicy.length; i++) { 266 | newPolicy[order[i]] = output.rawPolicy()[i]; 267 | } 268 | output = new NNOutput(output.v_health(), output.v_win(), output.v_other(), softmax(policyCompressToLegalActions(state, newPolicy)), output.rawPolicy(), output.legalActions()); 269 | } 270 | } 271 | return output; 272 | } 273 | 274 | private static float[] policyCompressToLegalActions(GameState state, float[] policy) { 275 | var legalActions = state.getLegalActions(); 276 | float[] policyCompressed = new float[legalActions.length]; 277 | for (int i = 0; i < legalActions.length; i++) { 278 | policyCompressed[i] = policy[legalActions[i]]; 279 | } 280 | return policyCompressed; 281 | } 282 | 283 | public List eval(List states, List hashes, int reqCount) { 284 | calls += reqCount; 285 | var x = new float[reqCount][]; 286 | var outputs = new ArrayList(); 287 | var outputsIdx = new int[reqCount]; 288 | var k = 0; 289 | for (int i = 0; i < states.size(); i++) { 290 | outputs.add(null); 291 | if (states.get(i) != null) { 292 | outputsIdx[k] = i; 293 | x[k++] = states.get(i).getNNInput(); 294 | } 295 | } 296 | 297 | long start = System.currentTimeMillis(); 298 | try (OnnxTensor test = OnnxTensor.createTensor(env, x); 299 | Result output = session.run(Collections.singletonMap(inputName, test))) { 300 | for (int row = 0; row < x.length; row++) { 301 | var state = states.get(outputsIdx[row]); 302 | float v_health = ((float[][]) output.get(0).getValue())[row][0]; 303 | float v_win = ((float[][]) output.get(1).getValue())[row][0]; 304 | 305 | float[] policy = ((float[][]) output.get(2).getValue())[row]; 306 | if (state.properties.maxNumOfActions != state.properties.totalNumOfActions) { 307 | if (state.getActionCtx() == GameActionCtx.PLAY_CARD) { 308 | policy = Arrays.copyOf(policy, state.properties.actionsByCtx[GameActionCtx.PLAY_CARD.ordinal()].length); 309 | } else if (state.getActionCtx() == GameActionCtx.SELECT_ENEMY) { 310 | int startIdx = state.properties.actionsByCtx[GameActionCtx.PLAY_CARD.ordinal()].length; 311 | int lenToCopy = state.properties.actionsByCtx[GameActionCtx.SELECT_ENEMY.ordinal()].length; 312 | policy = Utils.arrayCopy(policy, startIdx, lenToCopy); 313 | } else if (state.getActionCtx() == GameActionCtx.SELECT_SCENARIO) { 314 | int startIdx = state.properties.actionsByCtx[GameActionCtx.PLAY_CARD.ordinal()].length; 315 | if (state.properties.actionsByCtx[GameActionCtx.SELECT_ENEMY.ordinal()] != null) { 316 | startIdx += state.properties.actionsByCtx[GameActionCtx.SELECT_ENEMY.ordinal()].length; 317 | } 318 | int lenToCopy = state.properties.actionsByCtx[GameActionCtx.SELECT_SCENARIO.ordinal()].length; 319 | policy = Utils.arrayCopy(policy, startIdx, lenToCopy); 320 | } 321 | } 322 | 323 | float[] v_other = null; 324 | if (output.size() > 3 && state.properties.extraOutputLen > 0) { 325 | v_other = new float[state.properties.extraOutputLen]; 326 | int idx = 0; 327 | for (int i = 0; i < state.properties.extraTrainingTargets.size(); i++) { 328 | int n = state.properties.extraTrainingTargets.get(i).getNumberOfTargets(); 329 | if (n == 1) { 330 | float value = ((float[][]) output.get(3 + i).getValue())[row][0]; 331 | if (state.properties.extraTrainingTargetsLabel.get(i).startsWith("Z")) { 332 | v_other[idx] = value; 333 | v_other[idx] = (value + 1) / 2; 334 | } else { 335 | v_other[idx] = (value + 1) / 2; 336 | } 337 | } else { 338 | float[] v = ((float[][]) output.get(3 + i).getValue())[row]; 339 | for (int j = 0; j < n; j++) { 340 | v_other[idx + j] = v[j]; 341 | } 342 | softmax(v_other, idx, n); 343 | } 344 | idx += n; 345 | } 346 | } 347 | 348 | v_health = (v_health + 1) / 2; 349 | v_win = (v_win + 1) / 2; 350 | var o = new NNOutput(v_health, v_win, v_other, softmax(policyCompressToLegalActions(state, policy)), hasEnemyReordering(state) ? policy : null, state.getLegalActions()); 351 | cache.put(hashes.get(row), o); 352 | outputs.set(outputsIdx[row], o); 353 | } 354 | time_taken += System.currentTimeMillis() - start; 355 | return outputs; 356 | } catch (OrtException exception) { 357 | throw new RuntimeException(exception); 358 | } 359 | } 360 | } 361 | -------------------------------------------------------------------------------- /agent/src/main/java/com/alphaStS/enemy/Enemy.java: -------------------------------------------------------------------------------- 1 | package com.alphaStS.enemy; 2 | 3 | import com.alphaStS.*; 4 | import com.alphaStS.card.Card; 5 | import com.alphaStS.enums.DebuffType; 6 | 7 | import java.util.ArrayList; 8 | import java.util.List; 9 | 10 | public abstract class Enemy extends EnemyReadOnly { 11 | public Enemy(int health, int numOfMoves, boolean useLast2MovesForMoveSelection) { 12 | super(health, numOfMoves, useLast2MovesForMoveSelection); 13 | } 14 | 15 | public Enemy(Enemy other) { 16 | super(other); 17 | copyFieldsFrom(other); 18 | } 19 | 20 | public abstract void nextMove(GameState state, RandomGen random); 21 | public void saveStateForNextMove(GameState state) {} 22 | 23 | public int damage(double n, GameState state) { 24 | if (health <= 0) { 25 | return 0; 26 | } 27 | int dmg = ((int) n) - block; 28 | if (state.properties.boot != null && dmg > 0 && dmg < 5 && state.properties.boot.isRelicEnabledInScenario(state)) { 29 | dmg = 5; 30 | } 31 | int dmgDone = Math.max(0, dmg); 32 | health -= dmgDone; 33 | block = Math.max(0, block - ((int) n)); 34 | if (dmg > 0 && platedArmor > 0) { 35 | platedArmor--; 36 | } 37 | if (health <= 0) { 38 | health = 0; 39 | } 40 | return dmgDone; 41 | } 42 | 43 | public void nonAttackDamage(int n, boolean blockable, GameState state) { 44 | if (health <= 0) { 45 | return; 46 | } 47 | if (blockable) { 48 | health -= Math.max(0, n - block); 49 | block = Math.max(0, block - n); 50 | } else { 51 | health -= n; 52 | } 53 | if (health <= 0) { 54 | health = 0; 55 | } 56 | } 57 | 58 | public void gainBlock(int n) { 59 | block += n; 60 | if (block > 999) { 61 | block = 999; 62 | } 63 | } 64 | 65 | public void gainStrength(int n) { 66 | strength += n; 67 | } 68 | 69 | public void setHealth(int hp) { 70 | health = hp; 71 | } 72 | 73 | public void setMaxHealthInBattle(int hp) { 74 | maxHealthInBattle = hp; 75 | } 76 | 77 | public void setRegeneration(int regen) { 78 | regeneration = regen; 79 | } 80 | 81 | public void setMetallicize(int n) { 82 | metallicize = n; 83 | } 84 | 85 | public void setBurningHealthBuff(boolean b) { 86 | if (properties.hasBurningHealthBuff != b) { 87 | properties = properties.clone(); 88 | properties.hasBurningHealthBuff = b; 89 | } 90 | } 91 | 92 | public void gainMetallicize(int n) { 93 | metallicize += n; 94 | } 95 | 96 | public void gainPlatedArmor(int n) { 97 | platedArmor += n; 98 | } 99 | 100 | public void removeAllDebuffs() { 101 | vulnerable = 0; 102 | weak = 0; 103 | loseStrengthEot = 0; 104 | if (strength < 0) { 105 | strength = 0; 106 | } 107 | poison = 0; 108 | lockOn = 0; 109 | mark = 0; 110 | } 111 | 112 | public void setMove(int move) { 113 | this.move = move; 114 | } 115 | 116 | public void startTurn(GameState state) { 117 | block = 0; 118 | if (poison > 0) { 119 | state.playerDoNonAttackDamageToEnemy(this, poison, false); 120 | poison--; 121 | } 122 | } 123 | 124 | public void endTurn(int turnNum) { 125 | if (turnNum > 0 && vulnerable > 0) { 126 | vulnerable -= 1; 127 | } 128 | if (turnNum > 0 && weak > 0) { 129 | weak -= 1; 130 | } 131 | if (turnNum > 0 && lockOn > 0) { 132 | lockOn -= 1; 133 | } 134 | if (loseStrengthEot != 0) { 135 | strength += loseStrengthEot; 136 | loseStrengthEot = 0; 137 | } 138 | if (choke != 0) { 139 | choke = 0; 140 | } 141 | if (!properties.isElite || (turnNum > 0 && metallicize > 0)) { // todo: burning elite 142 | gainBlock(metallicize); 143 | } 144 | if (platedArmor > 0) { 145 | gainBlock(platedArmor); 146 | } 147 | if (regeneration > 0) { 148 | heal(regeneration); 149 | } 150 | } 151 | 152 | public void reviveReset() { 153 | block = 0; 154 | strength = 0; 155 | vulnerable = 0; 156 | lockOn = 0; 157 | weak = 0; 158 | artifact = 0; 159 | poison = 0; 160 | regeneration = 0; 161 | metallicize = 0; 162 | platedArmor = 0; 163 | loseStrengthEot = 0; 164 | move = -1; 165 | lastMove = -1; 166 | } 167 | 168 | protected void heal(int hp) { 169 | if (health > 0) { 170 | health += Math.min(hp, Math.max(0, getMaxHealthInBattle() - health)); 171 | } 172 | } 173 | 174 | public void react(GameState state, Card card) { 175 | } 176 | 177 | public boolean applyDebuff(GameState state, DebuffType type, int n) { 178 | if (health <= 0) { 179 | return false; 180 | } 181 | if (artifact > 0) { 182 | artifact--; 183 | return false; 184 | } 185 | switch (type) { 186 | case VULNERABLE -> { 187 | this.vulnerable += n; 188 | if (state.properties.championBelt != null && state.properties.championBelt.isRelicEnabledInScenario(state)) { 189 | this.weak += 1; 190 | } 191 | } 192 | case WEAK -> this.weak += n; 193 | case LOSE_STRENGTH -> this.gainStrength(-n); 194 | case LOSE_STRENGTH_EOT -> { 195 | this.loseStrengthEot += n; 196 | this.gainStrength(-n); 197 | } 198 | case POISON -> this.poison += n + (state.properties.sneckoSkull != null && state.properties.sneckoSkull.isRelicEnabledInScenario(state) ? 1 : 0); 199 | case CORPSE_EXPLOSION -> this.corpseExplosion += n; 200 | case CHOKE -> this.choke += n; 201 | case LOCK_ON -> this.lockOn += n; 202 | case TALK_TO_THE_HAND -> this.talkToTheHand += n; 203 | case MARK -> this.mark += n; 204 | } 205 | 206 | if (state.properties.sadisticNatureCounterIdx >= 0 && state.getCounterForRead()[state.properties.sadisticNatureCounterIdx] > 0) { 207 | int damage = state.getCounterForRead()[state.properties.sadisticNatureCounterIdx]; 208 | state.playerDoNonAttackDamageToEnemy(this, damage, false); 209 | } 210 | 211 | return true; 212 | } 213 | 214 | public void setBlock(int n) { 215 | block = n; 216 | } 217 | 218 | public Enemy markAsBurningElite() { 219 | properties.applyBurningEliteBuff(); 220 | return this; 221 | } 222 | 223 | public abstract void randomize(RandomGen random, boolean training, int difficulty); 224 | public int getMaxRandomizeDifficulty() { 225 | return 0; 226 | } 227 | 228 | // used to implement an enemy that can "morph" to one of several enemies to reduce network size for fights 229 | // involving gremlin. Can and will be buggy for other combination of enemies. 230 | public static class MergedEnemy extends Enemy { 231 | public List possibleEnemies; 232 | public Enemy currentEnemy; 233 | public int currentEnemyIdx; 234 | 235 | public MergedEnemy(List possibleEnemies) { 236 | super(0, possibleEnemies.stream().mapToInt((enemy) -> enemy.properties.numOfMoves).reduce(0, Integer::sum), false); 237 | this.possibleEnemies = possibleEnemies; 238 | properties.maxHealth = possibleEnemies.stream().mapToInt((enemy) -> enemy.properties.maxHealth).reduce(0, Math::max); 239 | properties.origMaxHealth = properties.maxHealth; 240 | properties.isElite = possibleEnemies.stream().anyMatch((e) -> e.properties.isElite); 241 | properties.isMinion = possibleEnemies.stream().anyMatch((e) -> e.properties.isMinion); 242 | properties.canVulnerable = possibleEnemies.stream().anyMatch((e) -> e.properties.canVulnerable); 243 | properties.canEntangle = possibleEnemies.stream().anyMatch((e) -> e.properties.canEntangle); 244 | properties.canWeaken = possibleEnemies.stream().anyMatch((e) -> e.properties.canWeaken); 245 | properties.canFrail = possibleEnemies.stream().anyMatch((e) -> e.properties.canFrail); 246 | properties.canGainStrength = possibleEnemies.stream().anyMatch((e) -> e.properties.canGainStrength); 247 | properties.canGainRegeneration = possibleEnemies.stream().anyMatch((e) -> e.properties.canGainRegeneration); 248 | properties.canHeal = possibleEnemies.stream().anyMatch((e) -> e.properties.canHeal); 249 | properties.canGainMetallicize = possibleEnemies.stream().anyMatch((e) -> e.properties.canGainMetallicize); 250 | properties.canGainPlatedArmor = possibleEnemies.stream().anyMatch((e) -> e.properties.canGainPlatedArmor); 251 | properties.canGainBlock = possibleEnemies.stream().anyMatch((e) -> e.properties.canGainBlock); 252 | properties.changePlayerStrength = possibleEnemies.stream().anyMatch((e) -> e.properties.changePlayerStrength); 253 | properties.changePlayerDexterity = possibleEnemies.stream().anyMatch((e) -> e.properties.changePlayerDexterity); 254 | properties.changePlayerFocus = possibleEnemies.stream().anyMatch((e) -> e.properties.changePlayerFocus); 255 | properties.hasBurningEliteBuff = possibleEnemies.stream().anyMatch((e) -> e.properties.hasBurningEliteBuff); 256 | properties.hasArtifact = possibleEnemies.stream().anyMatch((e) -> e.properties.hasArtifact); 257 | properties.actNumber = possibleEnemies.stream().mapToInt((enemy) -> enemy.properties.actNumber).reduce(0, Math::max); 258 | setEnemy(0); 259 | } 260 | 261 | public MergedEnemy(MergedEnemy other) { 262 | super(other); 263 | possibleEnemies = other.possibleEnemies; 264 | currentEnemy = other.currentEnemy.copy(); 265 | currentEnemyIdx = other.currentEnemyIdx; 266 | } 267 | 268 | public void setEnemy(int idx) { 269 | currentEnemy = possibleEnemies.get(idx).copy(); 270 | currentEnemyIdx = idx; 271 | } 272 | 273 | public EnemyProperties getEnemyProperty(int idx) { 274 | return possibleEnemies.get(idx).properties; 275 | } 276 | 277 | public String getDescName() { 278 | return "Merged Enemy (" + String.join(", ", possibleEnemies.stream().map(Enemy::getName).toList()) + ")"; 279 | } 280 | 281 | @Override public void nextMove(GameState state, RandomGen random) { 282 | currentEnemy.nextMove(state, random); 283 | } 284 | 285 | @Override public void doMove(GameState state, EnemyReadOnly self) { 286 | currentEnemy.doMove(state, self); 287 | } 288 | 289 | @Override public Enemy copy() { 290 | return new MergedEnemy(this); 291 | } 292 | 293 | @Override public String getMoveString(GameState state, int move) { 294 | return currentEnemy.getMoveString(state, move); 295 | } 296 | 297 | @Override public String getName() { 298 | return currentEnemy.getName(); 299 | } 300 | 301 | @Override public void gamePropertiesSetup(GameState state) { 302 | for (var e : possibleEnemies) { 303 | e.gamePropertiesSetup(state); 304 | } 305 | } 306 | 307 | @Override public List getPossibleGeneratedCards(GameProperties prop, List cards) { 308 | var newCards = new ArrayList(); 309 | for (var e : possibleEnemies) { 310 | newCards.addAll(e.getPossibleGeneratedCards(prop, cards)); 311 | } 312 | return newCards; 313 | } 314 | 315 | // ******************* generated delegate methods follow **************** 316 | 317 | @Override public int damage(double n, GameState state) { 318 | return currentEnemy.damage(n, state); 319 | } 320 | 321 | @Override public void nonAttackDamage(int n, boolean blockable, GameState state) { 322 | currentEnemy.nonAttackDamage(n, blockable, state); 323 | } 324 | 325 | @Override public void gainBlock(int n) { 326 | currentEnemy.gainBlock(n); 327 | } 328 | 329 | @Override public void gainStrength(int n) { 330 | currentEnemy.gainStrength(n); 331 | } 332 | 333 | @Override public void setHealth(int hp) { 334 | currentEnemy.setHealth(hp); 335 | } 336 | 337 | @Override public void setMaxHealthInBattle(int hp) { 338 | currentEnemy.setMaxHealthInBattle(hp); 339 | } 340 | 341 | @Override public void setRegeneration(int regen) { 342 | currentEnemy.setRegeneration(regen); 343 | } 344 | 345 | @Override public void setMetallicize(int n) { 346 | currentEnemy.setMetallicize(n); 347 | } 348 | 349 | @Override public void setBurningHealthBuff(boolean b) { 350 | currentEnemy.setBurningHealthBuff(b); 351 | } 352 | 353 | @Override public void gainMetallicize(int n) { 354 | currentEnemy.gainMetallicize(n); 355 | } 356 | 357 | @Override public void removeAllDebuffs() { 358 | currentEnemy.removeAllDebuffs(); 359 | } 360 | 361 | @Override public void setMove(int move) { 362 | currentEnemy.setMove(move); 363 | } 364 | 365 | @Override public void startTurn(GameState state) { 366 | currentEnemy.startTurn(state); 367 | } 368 | 369 | @Override public void endTurn(int turnNum) { 370 | currentEnemy.endTurn(turnNum); 371 | } 372 | 373 | @Override public void heal(int hp) { 374 | currentEnemy.heal(hp); 375 | } 376 | 377 | @Override public void react(GameState state, Card card) { 378 | currentEnemy.react(state, card); 379 | } 380 | 381 | @Override public boolean applyDebuff(GameState state, DebuffType type, int n) { 382 | return currentEnemy.applyDebuff(state, type, n); 383 | } 384 | 385 | @Override public Enemy markAsBurningElite() { 386 | super.markAsBurningElite(); 387 | for (int i = 0; i < possibleEnemies.size(); i++) { 388 | possibleEnemies.get(i).markAsBurningElite(); 389 | } 390 | return this; 391 | } 392 | 393 | @Override public void randomize(RandomGen random, boolean training, int difficulty) { 394 | currentEnemy.randomize(random, training, difficulty); 395 | } 396 | 397 | @Override public int getMaxRandomizeDifficulty() { 398 | return currentEnemy.getMaxRandomizeDifficulty(); 399 | } 400 | 401 | @Override public int getNNInputLen(GameProperties prop) { 402 | return currentEnemy.getNNInputLen(prop); 403 | } 404 | 405 | @Override public String getNNInputDesc(GameProperties prop) { 406 | return currentEnemy.getNNInputDesc(prop); 407 | } 408 | 409 | @Override public int writeNNInput(GameProperties prop, float[] input, int idx) { 410 | return currentEnemy.writeNNInput(prop, input, idx); 411 | } 412 | 413 | @Override public String getMoveString(GameState state) { 414 | return currentEnemy.getMoveString(state); 415 | } 416 | 417 | @Override public int getHealth() { 418 | return currentEnemy.getHealth(); 419 | } 420 | 421 | @Override public int getMaxHealthInBattle() { 422 | return currentEnemy.getMaxHealthInBattle(); 423 | } 424 | 425 | @Override public int getBlock() { 426 | return currentEnemy.getBlock(); 427 | } 428 | 429 | @Override public int getStrength() { 430 | return currentEnemy.getStrength(); 431 | } 432 | 433 | @Override public int getLoseStrengthEot() { 434 | return currentEnemy.getLoseStrengthEot(); 435 | } 436 | 437 | @Override public int getVulnerable() { 438 | return currentEnemy.getVulnerable(); 439 | } 440 | 441 | @Override public int getWeak() { 442 | return currentEnemy.getWeak(); 443 | } 444 | 445 | @Override public int getRegeneration() { 446 | return currentEnemy.getRegeneration(); 447 | } 448 | 449 | @Override public int getMetallicize() { 450 | return currentEnemy.getMetallicize(); 451 | } 452 | 453 | @Override public int getArtifact() { 454 | return currentEnemy.getArtifact(); 455 | } 456 | 457 | @Override public boolean hasBurningHealthBuff() { 458 | return currentEnemy.hasBurningHealthBuff(); 459 | } 460 | 461 | @Override public int getMove() { 462 | return currentEnemy.getMove(); 463 | } 464 | 465 | @Override public int getLastMove() { 466 | return currentEnemy.getLastMove(); 467 | } 468 | 469 | @Override public boolean isAlive() { 470 | return currentEnemy.isAlive(); 471 | } 472 | 473 | @Override public boolean isTargetable() { 474 | return currentEnemy.isTargetable(); 475 | } 476 | 477 | @Override public String toString(GameState state) { 478 | return currentEnemy.toString(state); 479 | } 480 | 481 | @Override public String toString() { 482 | return currentEnemy.toString(); 483 | } 484 | 485 | @Override public boolean equals(Object o) { 486 | if (this == o) 487 | return true; 488 | if (o == null || getClass() != o.getClass()) 489 | return false; 490 | MergedEnemy enemy = (MergedEnemy) o; 491 | return currentEnemyIdx == ((MergedEnemy) o).currentEnemyIdx && currentEnemy.equals(enemy.currentEnemy); 492 | } 493 | 494 | @Override public int hashCode() { 495 | return currentEnemy.hashCode(); 496 | } 497 | } 498 | } 499 | --------------------------------------------------------------------------------