├── Pretrained Networks ├── Navigation_q1 └── Navigation_q2 ├── Dota2 DQN ├── .settings │ ├── org.eclipse.m2e.core.prefs │ ├── org.eclipse.core.resources.prefs │ └── org.eclipse.jdt.core.prefs ├── target │ ├── classes │ │ └── hmm │ │ │ └── dota2dqn │ │ │ ├── App.class │ │ │ ├── bots │ │ │ └── Lina.class │ │ │ ├── common │ │ │ ├── Memory.class │ │ │ ├── State.class │ │ │ ├── Configs.class │ │ │ └── MemoryEntry.class │ │ │ └── nn │ │ │ ├── DeepQLearning.class │ │ │ └── NeuralNetwork.class │ └── test-classes │ │ └── hmm │ │ └── dota2dqn │ │ └── AppTest.class ├── src │ ├── main │ │ └── java │ │ │ └── hmm │ │ │ └── dota2dqn │ │ │ ├── App.java │ │ │ ├── common │ │ │ ├── State.java │ │ │ ├── MemoryEntry.java │ │ │ ├── Memory.java │ │ │ └── Configs.java │ │ │ ├── nn │ │ │ ├── DeepQLearning.java │ │ │ └── NeuralNetwork.java │ │ │ └── bots │ │ │ └── Lina.java │ └── test │ │ └── java │ │ └── hmm │ │ └── dota2dqn │ │ └── AppTest.java ├── .project ├── .classpath └── pom.xml ├── Dota2 Addon ├── addon_game_mode.lua ├── hero_tools.lua └── json_helpers.lua └── README.md /Pretrained Networks/Navigation_q1: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HamidMoghaddam/dota2dqn/HEAD/Pretrained Networks/Navigation_q1 -------------------------------------------------------------------------------- /Pretrained Networks/Navigation_q2: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HamidMoghaddam/dota2dqn/HEAD/Pretrained Networks/Navigation_q2 -------------------------------------------------------------------------------- /Dota2 DQN/.settings/org.eclipse.m2e.core.prefs: -------------------------------------------------------------------------------- 1 | activeProfiles= 2 | eclipse.preferences.version=1 3 | resolveWorkspaceProjects=true 4 | version=1 5 | -------------------------------------------------------------------------------- /Dota2 
DQN/target/classes/hmm/dota2dqn/App.class: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HamidMoghaddam/dota2dqn/HEAD/Dota2 DQN/target/classes/hmm/dota2dqn/App.class -------------------------------------------------------------------------------- /Dota2 DQN/target/classes/hmm/dota2dqn/bots/Lina.class: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HamidMoghaddam/dota2dqn/HEAD/Dota2 DQN/target/classes/hmm/dota2dqn/bots/Lina.class -------------------------------------------------------------------------------- /Dota2 DQN/target/classes/hmm/dota2dqn/common/Memory.class: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HamidMoghaddam/dota2dqn/HEAD/Dota2 DQN/target/classes/hmm/dota2dqn/common/Memory.class -------------------------------------------------------------------------------- /Dota2 DQN/target/classes/hmm/dota2dqn/common/State.class: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HamidMoghaddam/dota2dqn/HEAD/Dota2 DQN/target/classes/hmm/dota2dqn/common/State.class -------------------------------------------------------------------------------- /Dota2 DQN/target/test-classes/hmm/dota2dqn/AppTest.class: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HamidMoghaddam/dota2dqn/HEAD/Dota2 DQN/target/test-classes/hmm/dota2dqn/AppTest.class -------------------------------------------------------------------------------- /Dota2 DQN/target/classes/hmm/dota2dqn/common/Configs.class: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HamidMoghaddam/dota2dqn/HEAD/Dota2 DQN/target/classes/hmm/dota2dqn/common/Configs.class 
-------------------------------------------------------------------------------- /Dota2 DQN/.settings/org.eclipse.core.resources.prefs: -------------------------------------------------------------------------------- 1 | eclipse.preferences.version=1 2 | encoding//src/main/java=UTF-8 3 | encoding//src/test/java=UTF-8 4 | encoding/=UTF-8 5 | -------------------------------------------------------------------------------- /Dota2 DQN/target/classes/hmm/dota2dqn/common/MemoryEntry.class: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HamidMoghaddam/dota2dqn/HEAD/Dota2 DQN/target/classes/hmm/dota2dqn/common/MemoryEntry.class -------------------------------------------------------------------------------- /Dota2 DQN/target/classes/hmm/dota2dqn/nn/DeepQLearning.class: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HamidMoghaddam/dota2dqn/HEAD/Dota2 DQN/target/classes/hmm/dota2dqn/nn/DeepQLearning.class -------------------------------------------------------------------------------- /Dota2 DQN/target/classes/hmm/dota2dqn/nn/NeuralNetwork.class: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HamidMoghaddam/dota2dqn/HEAD/Dota2 DQN/target/classes/hmm/dota2dqn/nn/NeuralNetwork.class -------------------------------------------------------------------------------- /Dota2 DQN/src/main/java/hmm/dota2dqn/App.java: -------------------------------------------------------------------------------- 1 | package hmm.dota2dqn; 2 | 3 | 4 | import java.io.IOException; 5 | 6 | import hmm.dota2dqn.bots.Lina; 7 | import hmm.dota2dqn.common.Configs; 8 | import se.lu.lucs.dota2.service.Dota2AIService; 9 | 10 | public class App 11 | { 12 | public static void main( String[] args ) throws IOException 13 | { 14 | Configs conf= new Configs(); 15 | Lina lina = new Lina(conf); 16 | 
@SuppressWarnings("unused") 17 | Dota2AIService ss= new Dota2AIService(lina); 18 | 19 | } 20 | } 21 | -------------------------------------------------------------------------------- /Dota2 DQN/.project: -------------------------------------------------------------------------------- 1 | 2 | 3 | dota2dqn 4 | 5 | 6 | 7 | 8 | 9 | org.eclipse.jdt.core.javabuilder 10 | 11 | 12 | 13 | 14 | org.eclipse.m2e.core.maven2Builder 15 | 16 | 17 | 18 | 19 | 20 | org.eclipse.jdt.core.javanature 21 | org.eclipse.m2e.core.maven2Nature 22 | 23 | 24 | -------------------------------------------------------------------------------- /Dota2 DQN/src/main/java/hmm/dota2dqn/common/State.java: -------------------------------------------------------------------------------- 1 | package hmm.dota2dqn.common; 2 | 3 | import org.nd4j.linalg.api.ndarray.INDArray; 4 | import org.nd4j.linalg.factory.Nd4j; 5 | 6 | public class State { 7 | private INDArray state; 8 | private int stateSize; 9 | public State(int statesize){ 10 | this.stateSize=statesize; 11 | this.state=Nd4j.zeros(stateSize, stateSize); 12 | } 13 | public State(State original){ 14 | this.state= original.state.dup(); 15 | } 16 | public void addToState(int[] indexes, double val){ 17 | state.putScalar(indexes, val); 18 | 19 | } 20 | public INDArray getFlattenState(){ 21 | return Nd4j.toFlattened(this.state); 22 | } 23 | } 24 | -------------------------------------------------------------------------------- /Dota2 DQN/.settings/org.eclipse.jdt.core.prefs: -------------------------------------------------------------------------------- 1 | eclipse.preferences.version=1 2 | org.eclipse.jdt.core.compiler.codegen.inlineJsrBytecode=enabled 3 | org.eclipse.jdt.core.compiler.codegen.targetPlatform=1.8 4 | org.eclipse.jdt.core.compiler.codegen.unusedLocal=preserve 5 | org.eclipse.jdt.core.compiler.compliance=1.8 6 | org.eclipse.jdt.core.compiler.debug.lineNumber=generate 7 | org.eclipse.jdt.core.compiler.debug.localVariable=generate 8 | 
org.eclipse.jdt.core.compiler.debug.sourceFile=generate 9 | org.eclipse.jdt.core.compiler.problem.assertIdentifier=error 10 | org.eclipse.jdt.core.compiler.problem.enumIdentifier=error 11 | org.eclipse.jdt.core.compiler.problem.forbiddenReference=warning 12 | org.eclipse.jdt.core.compiler.source=1.8 13 | -------------------------------------------------------------------------------- /Dota2 DQN/src/test/java/hmm/dota2dqn/AppTest.java: -------------------------------------------------------------------------------- 1 | package hmm.dota2dqn; 2 | 3 | import junit.framework.Test; 4 | import junit.framework.TestCase; 5 | import junit.framework.TestSuite; 6 | 7 | /** 8 | * Unit test for simple App. 9 | */ 10 | public class AppTest 11 | extends TestCase 12 | { 13 | /** 14 | * Create the test case 15 | * 16 | * @param testName name of the test case 17 | */ 18 | public AppTest( String testName ) 19 | { 20 | super( testName ); 21 | } 22 | 23 | /** 24 | * @return the suite of tests being tested 25 | */ 26 | public static Test suite() 27 | { 28 | return new TestSuite( AppTest.class ); 29 | } 30 | 31 | /** 32 | * Rigourous Test :-) 33 | */ 34 | public void testApp() 35 | { 36 | assertTrue( true ); 37 | } 38 | } 39 | -------------------------------------------------------------------------------- /Dota2 DQN/src/main/java/hmm/dota2dqn/common/MemoryEntry.java: -------------------------------------------------------------------------------- 1 | package hmm.dota2dqn.common; 2 | 3 | import java.io.Serializable; 4 | 5 | 6 | public class MemoryEntry implements Serializable{ 7 | /** 8 | * 9 | */ 10 | private static final long serialVersionUID = 1L; 11 | private transient State state; 12 | private transient int action; 13 | private transient double reward; 14 | private transient State observedState; 15 | private transient boolean isFinal; 16 | 17 | public MemoryEntry(State state, int action, State observation,double reward,boolean isfinal) { 18 | this.state = new State(state); 19 | 
this.action = action; 20 | this.reward = reward; 21 | this.observedState = observation; 22 | this.isFinal = isfinal; 23 | } 24 | 25 | public int getAction() { 26 | return action; 27 | } 28 | 29 | public State getNewState() { 30 | return observedState; 31 | } 32 | 33 | public double getReward() { 34 | return reward; 35 | } 36 | 37 | public State getState() { 38 | return state; 39 | } 40 | public boolean isFinal() { 41 | return isFinal; 42 | } 43 | } 44 | -------------------------------------------------------------------------------- /Dota2 DQN/.classpath: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | -------------------------------------------------------------------------------- /Dota2 DQN/src/main/java/hmm/dota2dqn/common/Memory.java: -------------------------------------------------------------------------------- 1 | package hmm.dota2dqn.common; 2 | 3 | import java.util.ArrayList; 4 | import java.util.Collections; 5 | import java.util.List; 6 | 7 | 8 | 9 | public class Memory { 10 | private List memories; 11 | public Memory(){ 12 | /* 13 | The state is generated by Dota 2 every 0.33 second (Please check addon_game_mode.lua) so if a game takes 4 hours, 14 | the number of states will be almost 4*60*60*3=43200 which is not a big size for saving in memory. Therefore, there is no 15 | need to limit memory size. 
16 | */ 17 | this.memories=new ArrayList(); 18 | } 19 | public void addToMemory(State observation, State lastState, int action,double reward,boolean isfinal) { 20 | MemoryEntry newMemory = new MemoryEntry(lastState,action,observation,reward,isfinal); 21 | memories.add(newMemory); 22 | } 23 | public List getMiniBatch(){ 24 | Configs conf= new Configs(); 25 | //Randomize memory list 26 | Collections.shuffle(memories); 27 | return memories.subList(0,conf.getMiniBatchSize() ); 28 | 29 | } 30 | public int getMemorySize(){ 31 | return memories.size(); 32 | } 33 | } 34 | -------------------------------------------------------------------------------- /Dota2 Addon/addon_game_mode.lua: -------------------------------------------------------------------------------- 1 | if CAddonTemplateGameMode == nil then 2 | CAddonTemplateGameMode = class({}) 3 | end 4 | require("json_helpers") 5 | require("hero_tools") 6 | function Precache( context ) 7 | --[[ 8 | Precache things we know we'll use. Possible file types include (but not limited to): 9 | PrecacheResource( "model", "*.vmdl", context ) 10 | PrecacheResource( "soundfile", "*.vsndevts", context ) 11 | PrecacheResource( "particle", "*.vpcf", context ) 12 | PrecacheResource( "particle_folder", "particles/folder", context ) 13 | ]] 14 | end 15 | 16 | -- Create the game mode when we activate 17 | function Activate() 18 | GameRules.AddonTemplate = CAddonTemplateGameMode() 19 | GameRules.AddonTemplate:InitGameMode() 20 | end 21 | 22 | function CAddonTemplateGameMode:InitGameMode() 23 | print( "Template addon is loaded." 
) 24 | counter=0; 25 | GameRules:GetGameModeEntity():SetThink( "OnThink", self, "GlobalThink", 0.33 ) 26 | end 27 | 28 | -- Evaluate the state of the game 29 | function CAddonTemplateGameMode:OnThink() 30 | 31 | if(counter==0) then 32 | Tutorial:AddBot( "npc_dota_hero_lone_druid", "mid", "hard", false ); 33 | counter = counter+1 34 | end 35 | if GameRules:State_Get() == DOTA_GAMERULES_STATE_GAME_IN_PROGRESS then 36 | print( "Game started" ) 37 | local myHero = Entities:FindByClassname(nil, "npc_dota_hero_lina") 38 | local heroEntity = EntIndexToHScript(myHero:GetEntityIndex()) 39 | self:BotUpdating(heroEntity) 40 | --local myhero= Entities:FindAllByClassname( "npc_dota_hero_lina") 41 | --print(myhero:GetName()) 42 | elseif GameRules:State_Get() == DOTA_GAMERULES_STATE_PRE_GAME then 43 | 44 | 45 | elseif GameRules:State_Get() >= DOTA_GAMERULES_STATE_POST_GAME then 46 | return nil 47 | end 48 | return 1 49 | end -------------------------------------------------------------------------------- /Dota2 DQN/pom.xml: -------------------------------------------------------------------------------- 1 | 3 | 4.0.0 4 | 5 | hmm 6 | dota2dqn 7 | 0.0.1-SNAPSHOT 8 | jar 9 | 10 | dota2dqn 11 | http://maven.apache.org 12 | 13 | 14 | UTF-8 15 | 16 | 17 | 18 | 19 | junit 20 | junit 21 | 3.8.1 22 | test 23 | 24 | 25 | org.deeplearning4j 26 | deeplearning4j-core 27 | 0.7.2 28 | 29 | 30 | 31 | org.nd4j 32 | nd4j-native-platform 33 | 0.7.2 34 | 35 | 36 | org.nd4j 37 | nd4j-api 38 | 0.7.2 39 | 40 | 41 | 42 | org.datavec 43 | datavec-api 44 | 0.7.2 45 | 46 | 47 | se.lu.lucs 48 | Dota2BotFramework 49 | 0.0.2-SNAPSHOT 50 | 51 | 52 | org.deeplearning4j 53 | rl4j-core 54 | 0.7.2 55 | 56 | 57 | 58 | -------------------------------------------------------------------------------- /Dota2 DQN/src/main/java/hmm/dota2dqn/common/Configs.java: -------------------------------------------------------------------------------- 1 | package hmm.dota2dqn.common; 2 | 3 | public class Configs { 4 | private int 
miniBatchSize; 5 | private double heroRange; 6 | private double gameDimansion; 7 | private int oneDStateSize; 8 | private int stateSize; 9 | private int actionSize; 10 | private int[] hiddenLayers; 11 | private double learningRate; 12 | private double gamma; 13 | private String heroName; 14 | private double explorationRate; 15 | private double explorationDecay; 16 | public Configs(){ 17 | this.setMiniBatchSize(32); 18 | //The attacking range of Lina is 670 19 | this.setHeroRange(670); 20 | //The x and y dimansion of game's world is almost 7500 21 | this.setGameDimansion(7500); 22 | /* 23 | * Size of navigation's state (Please check the Readme file for more information) 24 | */ 25 | this.setOneDStateSize((int) Math.ceil(gameDimansion/heroRange)); 26 | this.setStateSize((int) Math.pow(getOneDStateSize(),2)); 27 | /* 28 | * Number of navigation direction in game world is 8 29 | */ 30 | this.setActionSize(8); 31 | 32 | this.setHiddenLayers(new int[]{70,35,15}); 33 | 34 | this.setLearningRate(0.01); 35 | 36 | this.setGamma(0.8); 37 | 38 | this.setHeroName("npc_dota_hero_lina"); 39 | this.setExplorationRate(0.3); 40 | this.setExplorationDecay(0.95); 41 | } 42 | 43 | public int getMiniBatchSize() { 44 | return miniBatchSize; 45 | } 46 | public void setMiniBatchSize(int miniBatchSize) { 47 | this.miniBatchSize = miniBatchSize; 48 | } 49 | public double getHeroRange() { 50 | return heroRange; 51 | } 52 | public void setHeroRange(double heroRange) { 53 | this.heroRange = heroRange; 54 | } 55 | public double getGameDimansion() { 56 | return gameDimansion; 57 | } 58 | public void setGameDimansion(double gameDimansion) { 59 | this.gameDimansion = gameDimansion; 60 | } 61 | public int getOneDStateSize() { 62 | return oneDStateSize; 63 | } 64 | public void setOneDStateSize(int oneDStateSize) { 65 | this.oneDStateSize = oneDStateSize; 66 | } 67 | 68 | public int getStateSize() { 69 | return stateSize; 70 | } 71 | 72 | public void setStateSize(int stateSize) { 73 | this.stateSize = 
stateSize; 74 | } 75 | 76 | public int getActionSize() { 77 | return actionSize; 78 | } 79 | 80 | public void setActionSize(int actionSize) { 81 | this.actionSize = actionSize; 82 | } 83 | 84 | public int[] getHiddenLayers() { 85 | return hiddenLayers; 86 | } 87 | 88 | public void setHiddenLayers(int[] hiddenLayers) { 89 | this.hiddenLayers = hiddenLayers; 90 | } 91 | 92 | public double getLearningRate() { 93 | return learningRate; 94 | } 95 | 96 | public void setLearningRate(double learningRate) { 97 | this.learningRate = learningRate; 98 | } 99 | 100 | public double getGamma() { 101 | return gamma; 102 | } 103 | 104 | public void setGamma(double gamma) { 105 | this.gamma = gamma; 106 | } 107 | 108 | public String getHeroName() { 109 | return heroName; 110 | } 111 | 112 | public void setHeroName(String heroName) { 113 | this.heroName = heroName; 114 | } 115 | 116 | public double getExplorationRate() { 117 | return explorationRate; 118 | } 119 | 120 | public void setExplorationRate(double explorationRate) { 121 | this.explorationRate = explorationRate; 122 | } 123 | 124 | public double getExplorationDecay() { 125 | return explorationDecay; 126 | } 127 | 128 | public void setExplorationDecay(double explorationDecay) { 129 | this.explorationDecay = explorationDecay; 130 | } 131 | } 132 | -------------------------------------------------------------------------------- /Dota2 DQN/src/main/java/hmm/dota2dqn/nn/DeepQLearning.java: -------------------------------------------------------------------------------- 1 | package hmm.dota2dqn.nn; 2 | 3 | import java.io.IOException; 4 | import java.util.ArrayList; 5 | import java.util.List; 6 | import java.util.Random; 7 | 8 | 9 | import org.deeplearning4j.berkeley.Pair; 10 | import org.deeplearning4j.datasets.iterator.INDArrayDataSetIterator; 11 | import org.nd4j.linalg.api.ndarray.INDArray; 12 | import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; 13 | import org.nd4j.linalg.factory.Nd4j; 14 | import 
hmm.dota2dqn.common.Configs; 15 | import hmm.dota2dqn.common.Memory; 16 | import hmm.dota2dqn.common.MemoryEntry; 17 | import hmm.dota2dqn.common.State; 18 | 19 | public class DeepQLearning { 20 | private Random randomfunc = new Random(); 21 | private NeuralNetwork nnet1; 22 | private NeuralNetwork nnet2; 23 | Configs conf; 24 | public DeepQLearning(Configs confInput){ 25 | this.conf=confInput; 26 | this.nnet1= new NeuralNetwork(); 27 | this.nnet2=new NeuralNetwork(); 28 | } 29 | public void generateNNs(){ 30 | nnet1.createNN(conf); 31 | nnet2.createNN(conf); 32 | } 33 | public void loadNNs(String filename) throws IOException{ 34 | nnet1.loadModel(conf, filename+"_q1"); 35 | nnet2.loadModel(conf, filename+"_q2"); 36 | } 37 | public void saveNN(String filename) throws IOException{ 38 | nnet1.saveModel(filename+"_q1"); 39 | nnet2.saveModel(filename+"_q2"); 40 | } 41 | public int selectAction(State state, double explorationRate) { 42 | int action=0; 43 | INDArray qValues1 = nnet1.predictAction(state.getFlattenState()); 44 | INDArray qValues2 = nnet2.predictAction(state.getFlattenState()); 45 | INDArray qValues = qValues1.add(qValues2); 46 | double random = randomfunc.nextDouble(); 47 | if (random < explorationRate) { 48 | action = randomfunc.nextInt(conf.getActionSize()); 49 | } else { 50 | action=findMaxIndex(qValues); 51 | } 52 | return action; 53 | } 54 | private int findMaxIndex(INDArray inputqValues){ 55 | int returnq=0; 56 | double max = (double) inputqValues.maxNumber(); 57 | for(int i=0;i miniBatch = memory.getMiniBatch(); 68 | List> pairedList= new ArrayList>(); 69 | double rand = randomfunc.nextDouble(); 70 | for (MemoryEntry entry : miniBatch) { 71 | INDArray qValues = Nd4j.zeros(conf.getActionSize(),1); 72 | INDArray qValuesObserved = Nd4j.zeros(conf.getActionSize(),1); 73 | INDArray lastState= entry.getState().getFlattenState(); 74 | INDArray observedState =entry.getNewState().getFlattenState(); 75 | if (rand < 0.5) { 76 | qValues = 
nnet1.predictAction(lastState); 77 | qValuesObserved = nnet2.predictAction(observedState); 78 | } else { 79 | qValues = nnet2.predictAction(lastState); 80 | qValuesObserved = nnet1.predictAction(observedState); 81 | } 82 | double targetValue = calculateOptimalQ(qValuesObserved, entry.getReward(), entry.isFinal()); 83 | 84 | double oldvalue = qValues.getDouble(entry.getAction(), 1); 85 | double newValue=(1-conf.getLearningRate())*oldvalue+ targetValue; 86 | qValues.putScalar(entry.getAction(),1, newValue); 87 | Pair pairedInputOutput= new Pair(lastState, qValues); 88 | pairedList.add(pairedInputOutput); 89 | } 90 | DataSetIterator datasetIterator = new INDArrayDataSetIterator(pairedList, conf.getMiniBatchSize()); 91 | if (rand < 0.5) { 92 | nnet1.trainNetwork(datasetIterator); 93 | } else { 94 | nnet2.trainNetwork(datasetIterator); 95 | } 96 | } 97 | private double calculateOptimalQ(INDArray qValuesObservedState, double reward, boolean isDone) { 98 | double maxQ= (double) qValuesObservedState.maxNumber(); 99 | if (isDone) 100 | return reward; 101 | else 102 | return reward + conf.getGamma() * maxQ; 103 | } 104 | 105 | } 106 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Deep Q-Network for Dota2 2 | I developed this project as a part of my PhD research so this is an academic research project and not affiliated with Valve. 3 | # What is the project? 4 | In 2015, DeepMind (Owned by Google) company developed a reinforcement learning agent, termed a deep Q-network (DQN), for Atari 2600 games (https://deepmind.com/research/dqn/). The AI gets game screen pixels then use its algorithm to find optimal action for each state of game. In 2016, the DeepMind claimed that they are developing an AI for a game, StarCraft, with high dimensional actions (https://deepmind.com/blog/deepmind-and-blizzard-release-starcraft-ii-ai-research-environment/). 
Unfortunately, they have not released their code yet (9 January 2017), so I developed my algorithm for "Dota2", which is also a game with high-dimensional actions. 5 | # Programming language 6 | The game uses a built-in Lua engine for scripting; however, that engine has some limitations, e.g. it is not possible to load files from disk. Therefore, I used a framework developed by Tobias Mahlmann (https://github.com/lightbringer/dota2ai) to transfer data between Java and the game. 7 | # Algorithm 8 | A simple DQN is ineffective in a game with high-dimensional actions, so hierarchical deep reinforcement learning is essential (https://arxiv.org/pdf/1604.06057.pdf). In a hierarchical deep Q-network (HDQN), a separate DQN is developed for each sub-goal. For this project, a DQN should be developed for each of these goals: 9 | - Navigation 10 | - Attack 11 | - Buy and Sell 12 | - Retreat. 13 | 14 | 15 | At the moment, only the navigation DQN has been developed; the others will be added to the project as soon as possible. 16 | # Navigation DQN 17 | A double DQN is implemented using the [Deeplearning4j](https://deeplearning4j.org/) libraries, with 3 fully connected hidden layers (70, 35 and 15 neurons respectively). The exploration rate is 0.3 ([Please check this](Dota2%20DQN/src/main/java/hmm/dota2dqn/common/Configs.java)) so that the agent moves quickly to the lane at the beginning of the game; if you want the agent to choose a different lane, you should increase it. 18 | # State 19 | A DQN usually uses some convolutional layers to convert the game screen (state) into a matrix of numbers, then connects the matrix to some fully connected hidden layers and an output layer (actions). Although it is possible to convert the Dota2 screen into such a matrix using convolutional layers, the matrix can be generated directly by the game engine. The game world dimension is almost 7500*7500, so it is possible to create a 2D matrix[7500][7500] and put each object in a cell of this matrix based on the object's location.
However, this matrix's size (7500*7500) is very large and not suitable for use as a state. In Dota2, a hero can interact with a target when the target is within the hero's attack range, e.g. Lina can attack enemies located within 670 units of her. Therefore, the world matrix size can be reduced by dividing the world dimension by the hero's attack range; for example, the world matrix size for Lina would be 12*12 (7500/670=11.19, whose ceiling is 12). It is worth noting that the minimum attack range of heroes is 140 (http://dota2.gamepedia.com/Attack_range), so the final world matrix (at most 54*54, since 7500/140=53.57) has a reasonable size for use as a state for every hero. 20 | # Goal 21 | The agent achieves its goal when the distance between it and an enemy object (hero, tower, NPC, building) becomes less than its attack range (670 for Lina). 22 | # Actions 23 | There are 8 directions in Dota2 (four cardinal and four intercardinal), so the navigation DQN has 8 outputs. 24 | # Rewards 25 | The time between two updates is 0.33 seconds ([Please check this](Dota2%20Addon/addon_game_mode.lua)), and the agent gets a reward of -1 on each update unless it achieves the goal. The reward for achieving the goal is 100. 26 | # How to run? 27 | - Open the Dota 2 Workshop Tools 28 | - Create an empty addon and name it Dota2DQN 29 | - Copy all files of [this folder](Dota2%20Addon/) into the addon's vscripts folder 30 | - Launch the custom game 31 | - Open VConsole and enter "dota_launch_custom_game Dota2DQN dota" into the Command field. 32 | - Open Eclipse and import the project in the "Dota2 DQN" folder as a Maven project 33 | - Run App.java 34 | 35 | 36 | # Use pre-trained networks 37 | I trained the agent on my laptop to start from the fountain and go to the middle lane, and provided the trained networks in the [Pretrained Networks](Pretrained%20Networks/) folder. It is worth noting that training the agent to navigate in a full game needs a more powerful computer than my laptop. To use these pre-trained networks, just copy them into the root folder of your project in Eclipse.
38 | 39 | -------------------------------------------------------------------------------- /Dota2 DQN/src/main/java/hmm/dota2dqn/nn/NeuralNetwork.java: -------------------------------------------------------------------------------- 1 | package hmm.dota2dqn.nn; 2 | 3 | import java.io.File; 4 | import java.io.IOException; 5 | 6 | 7 | import org.deeplearning4j.nn.api.OptimizationAlgorithm; 8 | import org.deeplearning4j.nn.conf.MultiLayerConfiguration; 9 | import org.deeplearning4j.nn.conf.NeuralNetConfiguration; 10 | import org.deeplearning4j.nn.conf.Updater; 11 | import org.deeplearning4j.nn.conf.layers.DenseLayer; 12 | import org.deeplearning4j.nn.conf.layers.OutputLayer; 13 | import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; 14 | import org.deeplearning4j.nn.weights.WeightInit; 15 | import org.deeplearning4j.optimize.listeners.ScoreIterationListener; 16 | import org.deeplearning4j.rl4j.util.Constants; 17 | import org.deeplearning4j.util.ModelSerializer; 18 | import org.nd4j.linalg.activations.Activation; 19 | import org.nd4j.linalg.api.ndarray.INDArray; 20 | import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; 21 | import org.nd4j.linalg.lossfunctions.LossFunctions.LossFunction; 22 | 23 | 24 | import hmm.dota2dqn.common.Configs; 25 | 26 | public class NeuralNetwork { 27 | private MultiLayerNetwork mln; 28 | private int inputs, outputs; 29 | private int[] hiddenLayers; 30 | public NeuralNetwork(){}; 31 | public void createNN(Configs generalConf){ 32 | 33 | double learningRate=generalConf.getLearningRate(); 34 | this.inputs=generalConf.getStateSize(); 35 | this.outputs=generalConf.getActionSize(); 36 | this.hiddenLayers=generalConf.getHiddenLayers(); 37 | MultiLayerConfiguration conf = null; 38 | if (hiddenLayers.length == 0) { 39 | conf = new NeuralNetConfiguration.Builder() 40 | .iterations(1) 41 | .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT) 42 | .learningRate(learningRate) 43 | .biasLearningRate(learningRate) 44 | 
.biasInit(0) 45 | .updater(Updater.RMSPROP) 46 | .list() 47 | .layer(0, new OutputLayer.Builder(LossFunction.MSE) 48 | .weightInit(WeightInit.XAVIER) 49 | .activation(Activation.IDENTITY).weightInit(WeightInit.XAVIER) 50 | .nIn(inputs).nOut(outputs).build()) 51 | .pretrain(false).backprop(true).build(); 52 | } else { 53 | NeuralNetConfiguration.ListBuilder builder = new NeuralNetConfiguration.Builder() 54 | .iterations(1) 55 | .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT) 56 | .learningRate(learningRate) 57 | .biasLearningRate(learningRate) 58 | .biasInit(0) 59 | .updater(Updater.RMSPROP) 60 | .list() 61 | .layer(0, new DenseLayer.Builder().nIn(inputs).nOut(hiddenLayers[0]) 62 | .weightInit(WeightInit.XAVIER) 63 | .activation(Activation.RELU) 64 | .build()); 65 | for (int i = 0 ; i < hiddenLayers.length - 1 ; i++) { 66 | builder.layer(i+1, new DenseLayer.Builder().nIn(hiddenLayers[i]).nOut(hiddenLayers[i+1]) 67 | .weightInit(WeightInit.XAVIER) 68 | .activation(Activation.RELU) 69 | .build()); 70 | } 71 | conf = builder.layer(hiddenLayers.length, new OutputLayer.Builder(LossFunction.MSE) 72 | .weightInit(WeightInit.XAVIER) 73 | .activation(Activation.IDENTITY).weightInit(WeightInit.XAVIER) 74 | .nIn(hiddenLayers[hiddenLayers.length - 1]).nOut(outputs).build()) 75 | .pretrain(false).backprop(true).build(); 76 | } 77 | MultiLayerNetwork model = new MultiLayerNetwork(conf); 78 | model.init(); 79 | 80 | model.setListeners(new ScoreIterationListener(Constants.NEURAL_NET_ITERATION_LISTENER)); 81 | this.mln = model; 82 | } 83 | 84 | public void loadModel(Configs conf,String fileName) throws IOException { 85 | this.mln = ModelSerializer.restoreMultiLayerNetwork(fileName); 86 | this.inputs=conf.getStateSize(); 87 | this.outputs=conf.getActionSize(); 88 | } 89 | public void saveModel(String path) throws IOException { 90 | //Save the model 91 | File locationToSave = new File(path); 92 | boolean saveUpdater = true; 93 | 
ModelSerializer.writeModel(this.mln, locationToSave, saveUpdater); 94 | } 95 | public synchronized INDArray predictAction(INDArray input) { 96 | INDArray resultVector = mln.output(input); 97 | return resultVector; 98 | } 99 | public void trainNetwork(DataSetIterator datasetIterator) { 100 | 101 | mln.fit(datasetIterator); 102 | } 103 | 104 | 105 | } 106 | -------------------------------------------------------------------------------- /Dota2 Addon/hero_tools.lua: -------------------------------------------------------------------------------- 1 | ---Hero Commands 2 | local function BitAND(a,b)--Bitwise and 3 | local p,c=1,0 4 | while a>0 and b>0 do 5 | local ra,rb=a%2,b%2 6 | if ra+rb>1 then c=c+p end 7 | a,b,p=(a-ra)/2,(b-rb)/2,p*2 8 | end 9 | return c 10 | end 11 | 12 | function CAddonTemplateGameMode:ParseHeroLevelUp(eHero, reply) 13 | local result = package.loaded['game/dkjson'].decode(reply) 14 | local abilityPoints = eHero:GetAbilityPoints() 15 | if abilityPoints <= 0 then 16 | Warning(eHero:GetName() .. " has no ability points. Why am I levelling up?") 17 | return 18 | end 19 | 20 | local abilityIndex = result.abilityIndex 21 | 22 | if abilityIndex > -1 then --a bot may send -1 to delay the level up 23 | local ability = eHero:GetAbilityByIndex(abilityIndex) 24 | if ability:GetLevel() == ability:GetMaxLevel() then 25 | Warning(eHero:GetName() .. ": " .. ability:GetName() .. " is maxed out") 26 | return 27 | end 28 | ability:UpgradeAbility(false) 29 | eHero:SetAbilityPoints(abilityPoints - 1) --UpgradeAbility doesn't decrease the ability points 30 | Say(nil, eHero:GetName() .. " levelled up ability " .. 
abilityIndex, false) 31 | end 32 | end 33 | -- Main entry function -- 34 | function CAddonTemplateGameMode:ParseHeroCommand(eHero, reply) 35 | local result = package.loaded['game/dkjson'].decode(reply) 36 | local command = result.command 37 | 38 | --TODO deal with abilities that the hero can interrupt 39 | local eAbility = eHero:GetCurrentActiveAbility() 40 | if eAbility then 41 | self:Noop(eHero, result) 42 | Warning("active not null") 43 | return 44 | end 45 | 46 | if command == "MOVE" then 47 | self:MoveTo(eHero, result) 48 | elseif command == "ATTACK" then 49 | self:Attack(eHero, result) 50 | elseif command == "CAST" then 51 | self:Cast(eHero, result) 52 | elseif command == "BUY" then 53 | self:Buy(eHero, result) 54 | elseif command == "SELL" then 55 | self:Sell(eHero, result) 56 | elseif command == "USE_ITEM" then 57 | self:UseItem(eHero, result) 58 | elseif command == "NOOP" then 59 | self:Noop(eHero, result) 60 | else 61 | --self._Error = true 62 | Warning(eHero:GetName() .. " sent invalid command " .. reply) 63 | end 64 | end 65 | -- Buying items -- 66 | function CAddonTemplateGameMode:Buy(eHero, result) 67 | local itemName = result.item 68 | local itemCost = GetItemCost(itemName) 69 | 70 | if eHero:GetGold() < itemCost then 71 | Warning(eHero:GetName() .. " tried to buy " .. itemName .. " but couldn't afford it") 72 | return 73 | end 74 | 75 | 76 | local targetSlot = -1 77 | for i=DOTA_ITEM_SLOT_1 ,DOTA_ITEM_SLOT_6 ,1 do 78 | if not eHero:GetItemInSlot(i) then 79 | targetSlot = i 80 | break 81 | end 82 | end 83 | 84 | --TODO item availability is missing 85 | 86 | if targetSlot < 0 then 87 | Warning(eHero:GetName() .. " tried to buy " .. itemName .. 
" but has no space left") 88 | --TODO buy into stash 89 | return 90 | end 91 | 92 | local eShop = Entities:FindByClassnameWithin(nil, "trigger_shop", eHero:GetOrigin(), 0.01) 93 | if eShop then 94 | local shopType = GetShopType(eShop:GetModelName()) 95 | 96 | if IsAvailableInShop(itemName, shopType) then 97 | 98 | local eItem = CreateItem(itemName, eHero, eHero) 99 | EmitSoundOn("General.Buy", eHero) 100 | eHero:AddItem(eItem) 101 | eHero:SpendGold(itemCost, DOTA_ModifyGold_PurchaseItem) --should the reason take DOTA_ModifyGold_PurchaseConsumable into account? 102 | else 103 | --TODO ping actually the right shop 104 | PingNearestShop(eHero) 105 | Warning("Shop is of type " .. shopType) 106 | return 107 | end 108 | else 109 | PingNearestShop(eHero) 110 | return 111 | end 112 | 113 | Say(nil, eHero:GetName() .." bought " .. itemName, false) 114 | end 115 | 116 | function CAddonTemplateGameMode:Sell(eHero, result) 117 | local slot = result.slot 118 | 119 | 120 | if eHero:CanSellItems() then 121 | local eItem = eHero:GetItemInSlot(slot) 122 | if eItem then 123 | --TODO GetCost does not return the value altered, i.e. halved 124 | EmitSoundOn("General.Sell", eHero) 125 | eHero:ModifyGold(eItem:GetCost(), true, DOTA_ModifyGold_SellItem ) 126 | eHero:RemoveItem(eItem) 127 | else 128 | Warning("No item in slot " .. slot) 129 | end 130 | else 131 | PingNearestShop(eHero) 132 | Warning("Bot tried to sell item outside shop") 133 | end 134 | end 135 | 136 | function CAddonTemplateGameMode:UseItem(eHero, result) 137 | local slot = result.slot 138 | local eItem = eHero:GetItemInSlot(slot) 139 | if eItem then 140 | self:UseAbility(eHero, eItem) 141 | else 142 | Warning("Bot tried to use item in empty slot") 143 | end 144 | end 145 | 146 | function CAddonTemplateGameMode:Noop(eHero, result) 147 | --Noop 148 | end 149 | 150 | function CAddonTemplateGameMode:MoveTo(eHero, result) 151 | eHero:MoveToPosition(Vector(result.x, result.y, result.z)) 152 | Say(nil, eHero:GetName() .. 
" moving to " .. result.x ..", " .. result.y ..", " .. result.z, false) 153 | end 154 | 155 | function CAddonTemplateGameMode:Attack(eHero, result) 156 | --Might want to check attack range 157 | --eHero:PerformAttack(EntIndexToHScript(result.target), true, true, false, true) 158 | eHero:MoveToTargetToAttack(EntIndexToHScript(result.target)) 159 | Say(nil, eHero:GetName() .. " attacking " .. result.target, false) 160 | end 161 | 162 | function CAddonTemplateGameMode:Cast(eHero, result) 163 | local eAbility = eHero:GetAbilityByIndex(result.ability) 164 | self:UseAbility(eHero, eAbility) 165 | end 166 | 167 | function CAddonTemplateGameMode:UseAbility(eHero, eAbility) 168 | local level = eAbility:GetLevel() 169 | local manaCost = eAbility:GetManaCost(level) 170 | local player = eHero:GetPlayerOwnerID() 171 | 172 | if eHero:GetMana() < manaCost then 173 | Warning("Bot tried to use ability without mana") 174 | elseif eAbility:GetCooldownTimeRemaining() > 0 then 175 | Warning("Bot tried to use ability still on cooldown") 176 | else 177 | eAbility:StartCooldown(eAbility:GetCooldown(level)) 178 | eAbility:PayManaCost() 179 | eAbility:OnSpellStart() 180 | local behaviour = eAbility:GetBehavior() 181 | --There is some logic missing here to check for range and make the hero face the right direction 182 | if (BitAND(behaviour, DOTA_ABILITY_BEHAVIOR_NO_TARGET)) then 183 | Say(nil ,eHero:GetName() .. " casting " .. eAbility:GetName(), false) 184 | eHero:CastAbilityNoTarget(eAbility, player) 185 | elseif(BitAND(behaviour, DOTA_ABILITY_BEHAVIOR_UNIT_TARGET )) then 186 | local target = EntIndexToHScript(result.target) 187 | if target:IsAlive() then 188 | Say(nil ,eHero:GetName() .. " casting " .. eAbility:GetName() .. " on unit " .. target:GetName(), false) 189 | eHero:CastAbilityOnTarget(target, eAbility, player) 190 | end 191 | elseif(BitAND(behaviour, DOTA_ABILITY_BEHAVIOR_POINT )) then 192 | Say(nil ,eHero:GetName() .. " casting " .. eAbility:GetName() .. " on " .. result.x .. 
", " .. result.y .. ", " .. result.z, false) 193 | eHero:CastAbilityOnPosition(Vector(result.x, result.y, result.z), eAbility, player) 194 | else 195 | Warning(eHero:GetName() .. " sent invalid cast command " .. behaviour) 196 | 197 | end 198 | 199 | 200 | end 201 | end -------------------------------------------------------------------------------- /Dota2 Addon/json_helpers.lua: -------------------------------------------------------------------------------- 1 | function CAddonTemplateGameMode:JSONWorld(eHero) 2 | local world = {} 3 | world.entities = {} 4 | 5 | 6 | --TODO there are apparently around 2300 trees on the map. Sending those that are NOT standing might be more efficient 7 | local tree = Entities:FindByClassname(nil, "ent_dota_tree") 8 | while tree ~= nil do 9 | if eHero:CanEntityBeSeenByMyTeam(tree) and tree:IsStanding() then 10 | world.entities[tree:entindex()]=self:JSONtree(tree) 11 | end 12 | tree = Entities:FindByClassname(tree, "ent_dota_tree") 13 | end 14 | 15 | 16 | local allUnits = FindUnitsInRadius(eHero:GetTeamNumber(), 17 | eHero:GetOrigin(), 18 | nil, 19 | FIND_UNITS_EVERYWHERE, 20 | DOTA_UNIT_TARGET_TEAM_BOTH, 21 | DOTA_UNIT_TARGET_ALL, 22 | DOTA_UNIT_TARGET_FLAG_FOW_VISIBLE, 23 | FIND_ANY_ORDER, 24 | true) 25 | 26 | for _,unit in pairs(allUnits) do 27 | world.entities[unit:entindex()]=self:JSONunit(unit) 28 | end 29 | 30 | --so FindUnitsInRadius somehow ignores all the buildings 31 | local buildings = {} 32 | buildings[0] = Entities:FindByName(nil, "dota_goodguys_tower1_bot") 33 | buildings[1] = Entities:FindByName(nil, "dota_goodguys_tower2_bot") 34 | buildings[2] = Entities:FindByName(nil, "dota_goodguys_tower3_bot") 35 | 36 | buildings[3] = Entities:FindByName(nil, "dota_goodguys_tower1_mid") 37 | buildings[4] = Entities:FindByName(nil, "dota_goodguys_tower2_mid") 38 | buildings[5] = Entities:FindByName(nil, "dota_goodguys_tower3_mid") 39 | 40 | buildings[6] = Entities:FindByName(nil, "dota_goodguys_tower1_top") 41 | buildings[7] = 
Entities:FindByName(nil, "dota_goodguys_tower2_top") 42 | buildings[8] = Entities:FindByName(nil, "dota_goodguys_tower3_top") 43 | 44 | buildings[9] = Entities:FindByName(nil, "dota_goodguys_tower4_top") 45 | buildings[10] = Entities:FindByName(nil, "dota_goodguys_tower4_bot") 46 | 47 | buildings[11] = Entities:FindByName(nil, "good_rax_melee_bot") 48 | buildings[12] = Entities:FindByName(nil, "good_rax_range_bot") 49 | buildings[13] = Entities:FindByName(nil, "good_rax_melee_mid") 50 | buildings[14] = Entities:FindByName(nil, "good_rax_range_mid") 51 | buildings[15] = Entities:FindByName(nil, "good_rax_melee_top") 52 | buildings[16] = Entities:FindByName(nil, "good_rax_range_top") 53 | 54 | buildings[17] = Entities:FindByName(nil, "ent_dota_fountain_good") 55 | 56 | --dire 57 | buildings[18] = Entities:FindByName(nil, "dota_badguys_tower1_bot") 58 | buildings[19] = Entities:FindByName(nil, "dota_badguys_tower2_bot") 59 | buildings[20] = Entities:FindByName(nil, "dota_badguys_tower3_bot") 60 | 61 | buildings[21] = Entities:FindByName(nil, "dota_badguys_tower1_mid") 62 | buildings[22] = Entities:FindByName(nil, "dota_badguys_tower2_mid") 63 | buildings[23] = Entities:FindByName(nil, "dota_badguys_tower3_mid") 64 | 65 | buildings[24] = Entities:FindByName(nil, "dota_badguys_tower1_top") 66 | buildings[25] = Entities:FindByName(nil, "dota_badguys_tower2_top") 67 | buildings[26] = Entities:FindByName(nil, "dota_badguys_tower3_top") 68 | 69 | buildings[27] = Entities:FindByName(nil, "dota_badguys_tower4_top") 70 | buildings[28] = Entities:FindByName(nil, "dota_badguys_tower4_bot") 71 | 72 | buildings[29] = Entities:FindByName(nil, "bad_rax_melee_bot") 73 | buildings[30] = Entities:FindByName(nil, "bad_rax_range_bot") 74 | buildings[31] = Entities:FindByName(nil, "bad_rax_melee_mid") 75 | buildings[32] = Entities:FindByName(nil, "bad_rax_range_mid") 76 | buildings[33] = Entities:FindByName(nil, "bad_rax_melee_top") 77 | buildings[34] = Entities:FindByName(nil, 
"bad_rax_range_top") 78 | 79 | buildings[35] = Entities:FindByName(nil, "ent_dota_fountain_bad") 80 | 81 | 82 | for i,unit in ipairs(buildings) do 83 | world.entities[unit:entindex()]=self:JSONunit(unit) 84 | end 85 | 86 | return package.loaded['game/dkjson'].encode(world) 87 | end 88 | function CAddonTemplateGameMode:JSONtree(eTree) 89 | local tree = {} 90 | tree.origin = VectorToArray(eTree:GetOrigin()) 91 | tree.type = "Tree" 92 | return tree 93 | end 94 | 95 | function CAddonTemplateGameMode:JSONunit(eUnit) 96 | local unit = {} 97 | unit.level = eUnit:GetLevel() 98 | unit.origin = VectorToArray(eUnit:GetOrigin()) 99 | --unit.absOrigin = VectorToArray(eUnit:GetAbsOrigin()) 100 | --unit.center = VectorToArray(eUnit:GetCenter()) 101 | --unit.velocity = VectorToArray(eUnit:GetVelocity()) 102 | --unit.localVelocity = VectorToArray(eUnit:GetForwardVector()) 103 | unit.health = eUnit:GetHealth() 104 | unit.maxHealth = eUnit:GetMaxHealth() 105 | unit.mana = eUnit:GetMana() 106 | unit.maxMana = eUnit:GetMaxMana() 107 | unit.alive = eUnit:IsAlive() 108 | unit.blind = eUnit:IsBlind() 109 | unit.dominated = eUnit:IsDominated() 110 | unit.deniable = eUnit:IsDeniable() 111 | unit.disarmed = eUnit:IsDisarmed() 112 | unit.rooted = eUnit:IsRooted() 113 | unit.name = eUnit:GetName() 114 | unit.team = eUnit:GetTeamNumber() 115 | unit.attackRange = eUnit:GetAttackRange() 116 | 117 | if eUnit:IsHero() then 118 | unit.gold = eUnit:GetGold() 119 | unit.kills=eUnit:GetKills() 120 | unit.type = "Hero" 121 | -- Abilities are actually in CBaseNPC, but we'll just send them for Heros to avoid cluttering the JSON-- 122 | unit.abilities = {} 123 | local abilityCount = eUnit:GetAbilityCount() - 1 --minus 1 because lua for loops are upper boundary inclusive 124 | 125 | for index=0,abilityCount,1 do 126 | local eAbility = eUnit:GetAbilityByIndex(index) 127 | -- abilityCount returned 16 for me even though the hero had only 5 slots (maybe it's actually max slots?). 
We fix that by checking for null pointer 128 | if eAbility then 129 | unit.abilities[index] = {} 130 | unit.abilities[index].type = "Ability" 131 | unit.abilities[index].name = eAbility:GetAbilityName() 132 | unit.abilities[index].targetFlags = eAbility:GetAbilityTargetFlags() 133 | unit.abilities[index].targetTeam = eAbility:GetAbilityTargetTeam() 134 | unit.abilities[index].targetType = eAbility:GetAbilityTargetType() 135 | unit.abilities[index].abilityType = eAbility:GetAbilityType() 136 | unit.abilities[index].abilityIndex = eAbility:GetAbilityIndex() 137 | unit.abilities[index].level = eAbility:GetLevel() 138 | unit.abilities[index].maxLevel = eAbility:GetMaxLevel() 139 | unit.abilities[index].abilityDamage = eAbility:GetAbilityDamage() 140 | unit.abilities[index].abilityDamageType = eAbility:GetAbilityDamage() 141 | unit.abilities[index].cooldownTime = eAbility:GetCooldownTime() 142 | unit.abilities[index].cooldownTimeRemaining = eAbility:GetCooldownTimeRemaining() 143 | end 144 | end 145 | elseif eUnit:IsBuilding() then 146 | if eUnit:IsTower() then 147 | unit.type = "Tower" 148 | else 149 | unit.type = "Building" 150 | end 151 | else 152 | unit.type = "BaseNPC" 153 | end 154 | 155 | 156 | local attackTarget = eUnit:GetAttackTarget() 157 | if attackTarget then 158 | unit.attackTarget = attackTarget:entindex() 159 | end 160 | 161 | return unit 162 | end 163 | function VectorToArray(v) 164 | return {v.x, v.y, v.z} 165 | end 166 | -------------------------------------------------------------------------------- /Dota2 DQN/src/main/java/hmm/dota2dqn/bots/Lina.java: -------------------------------------------------------------------------------- 1 | package hmm.dota2dqn.bots; 2 | 3 | import java.io.File; 4 | import java.io.IOException; 5 | import java.util.Arrays; 6 | import java.util.Set; 7 | import java.util.stream.Collectors; 8 | 9 | import hmm.dota2dqn.common.Configs; 10 | import hmm.dota2dqn.common.Memory; 11 | import hmm.dota2dqn.common.State; 12 | import 
hmm.dota2dqn.nn.DeepQLearning; 13 | import se.lu.lucs.dota2.framework.bot.BaseBot; 14 | import se.lu.lucs.dota2.framework.bot.BotCommands.LevelUp; 15 | import se.lu.lucs.dota2.framework.bot.BotCommands.Select; 16 | import se.lu.lucs.dota2.framework.game.BaseEntity; 17 | import se.lu.lucs.dota2.framework.game.BaseNPC; 18 | import se.lu.lucs.dota2.framework.game.Hero; 19 | import se.lu.lucs.dota2.framework.game.Tower; 20 | import se.lu.lucs.dota2.framework.game.World; 21 | 22 | public class Lina extends BaseBot{ 23 | 24 | Configs conf; 25 | Memory memory; 26 | DeepQLearning dqn; 27 | boolean isStarting,isFinished; 28 | float[] goalPosition; 29 | float[] previousPosition; 30 | State lastState; 31 | World world; 32 | Hero lina; 33 | int steps,totalReward; 34 | double explorationRate,explorationDecay; 35 | public Lina(Configs inputConfig) throws IOException{ 36 | this.conf=inputConfig; 37 | this.steps=0; 38 | this.totalReward=0; 39 | this.goalPosition= new float[2]; 40 | this.previousPosition=new float[2]; 41 | this.isStarting=true; 42 | this.explorationDecay=conf.getExplorationDecay(); 43 | this.explorationRate=conf.getExplorationRate(); 44 | 45 | //Create memory 46 | memory = new Memory(); 47 | 48 | dqn = new DeepQLearning(conf); 49 | //Check whether the NNets are saved before 50 | File f1 = new File("Navigation_q1"); 51 | File f2 = new File("Navigation_q2"); 52 | if(f1.exists()&&f2.exists()) 53 | dqn.loadNNs("Navigation"); 54 | else 55 | dqn.generateNNs(); 56 | 57 | } 58 | @Override 59 | public LevelUp levelUp() { 60 | // TODO Auto-generated method stub 61 | return null; 62 | } 63 | 64 | @Override 65 | public Select select() { 66 | // TODO Auto-generated method stub 67 | return null; 68 | } 69 | 70 | @Override 71 | public Command update(World inputWorld) { 72 | // TODO Auto-generated method stub 73 | //Steps shows the number of update function called to achieve the goal 74 | this.world=inputWorld; 75 | steps+=1; 76 | final int myIndex = 
world.searchIndexByName(conf.getHeroName()); 77 | this.lina = (Hero) world.getEntities().get( myIndex ); 78 | if (myIndex < 0) { 79 | System.out.println( "Lina is dead!" ); 80 | return NOOP; 81 | } 82 | float[] position= new float[2]; 83 | position[0]=lina.getOrigin()[0]; 84 | position[1]=lina.getOrigin()[1]; 85 | 86 | /* 87 | * Check whether the game is just started. It might be better to do this part in addon_game_mode.lua file but for now it 88 | * works good enough. 89 | */ 90 | if(isStarting==true){ 91 | lastState= new State(conf.getOneDStateSize()); 92 | goalPosition=position.clone(); 93 | isStarting=false; 94 | } 95 | /* 96 | * The time between two update is 0.33 (Please check addon_game_mode.lua) and it is not enough for lina to complete an action. 97 | * Although it is possible to increase the time, it might not good for other activities like attack. Thus, checking whether lina 98 | * is complete its task before applying a new one is essential. In addition, sometimes lina get stuck behind world obstacles so 99 | * it is necessary to check that situation in order to avoiding any infinite loop (It might be better to give some big negative 100 | * reward to those action that make her getting stuck). 101 | */ 102 | if(Arrays.equals(goalPosition, position)||Arrays.equals(position, previousPosition)){ 103 | previousPosition= position.clone(); 104 | goalPosition=calcDestination(position).clone(); 105 | if(isFinished==true){ 106 | System.out.println(steps); 107 | return NOOP; 108 | } 109 | MOVE.setX(goalPosition[0]); 110 | MOVE.setY(goalPosition[1]); 111 | return MOVE; 112 | }else{ 113 | previousPosition= position.clone(); 114 | //Lina has not complete her last act so no need for new act and just do some training. 
115 | if(memory.getMemorySize()>32) 116 | dqn.learnOnMemory(memory); 117 | return NOOP; 118 | } 119 | } 120 | private int choseAnAction() { 121 | int act=0; 122 | State observation= new State(conf.getOneDStateSize()); 123 | observation=generateNavigationState(); 124 | int range=(int) conf.getHeroRange(); 125 | int dimansion =(int) conf.getGameDimansion(); 126 | final Set e = findEntitiesInRange( world, lina,range).stream().filter( p -> p instanceof BaseNPC ) 127 | .filter( p -> ((BaseNPC) p).getTeam() == 3 ).collect( Collectors.toSet() ); 128 | //Check whether lina reach her goal 129 | if (!e.isEmpty()&&isFinished==false) { 130 | memory.addToMemory(observation,lastState,act,100,true); 131 | totalReward+=100; 132 | System.out.println(totalReward); 133 | try { 134 | dqn.saveNN("Navigation"); 135 | } catch (IOException e1) { 136 | // TODO Auto-generated catch block 137 | e1.printStackTrace(); 138 | } 139 | isFinished=true; 140 | }else{ 141 | if(Math.abs(goalPosition[0])==dimansion|| Math.abs(goalPosition[1])==dimansion){ 142 | //AI should not choose a coordination out of game world 143 | memory.addToMemory(observation,lastState,act,-1000,false); 144 | totalReward-=1000; 145 | }else{ 146 | memory.addToMemory(observation,lastState,act,-1,false); 147 | totalReward-=1; 148 | } 149 | lastState= new State(observation); 150 | } 151 | act=dqn.selectAction(observation,explorationRate); 152 | explorationRate *= explorationDecay; 153 | return act; 154 | } 155 | private float[] calcDestination(float[] position){ 156 | int act=choseAnAction(); 157 | float[] destination=new float[2]; 158 | int range= (int) conf.getHeroRange(); 159 | int euclideanDistance= (int) (range/Math.sqrt(2)); 160 | int dimansion = (int) conf.getGameDimansion(); 161 | switch (act) { 162 | case 0: 163 | destination[0]=(position[0]+range)>dimansion ? 
dimansion:position[0]+range; 164 | destination[1]=position[1]; 165 | break; 166 | case 1: 167 | destination[0]=position[0]; 168 | destination[1]=position[1]+range>dimansion ? dimansion:position[1]+range; 169 | break; 170 | case 2: 171 | destination[0]=position[0]-range<-dimansion ? -dimansion:position[0]-range; 172 | destination[1]=position[1]; 173 | break; 174 | case 3: 175 | destination[0]=position[0]; 176 | destination[1]=position[1]-range<-dimansion ? -dimansion:position[1]-range; 177 | break; 178 | case 4: 179 | destination[0]=position[0]+euclideanDistance>dimansion ? dimansion:position[0]+euclideanDistance; 180 | destination[1]=position[1]+euclideanDistance>dimansion ? dimansion:position[1]+euclideanDistance; 181 | break; 182 | case 5: 183 | destination[0]=position[0]+euclideanDistance>dimansion ? dimansion:position[0]+euclideanDistance; 184 | destination[1]=position[1]-euclideanDistance<-dimansion ? -dimansion:position[1]-euclideanDistance; 185 | break; 186 | case 6: 187 | destination[0]=position[0]-euclideanDistance<-dimansion ? -dimansion:position[0]-euclideanDistance; 188 | destination[1]=position[1]-euclideanDistance<-dimansion ? -dimansion:position[1]-euclideanDistance; 189 | break; 190 | case 7: 191 | destination[0]=position[0]-euclideanDistance<-dimansion ? -dimansion:position[0]-euclideanDistance; 192 | destination[1]=position[1]+euclideanDistance>dimansion ? 
dimansion:position[1]+euclideanDistance; 193 | break; 194 | default: 195 | break; 196 | } 197 | 198 | return destination; 199 | } 200 | private State generateNavigationState(){ 201 | State returnState = new State(conf.getOneDStateSize()); 202 | int[] xAndy=new int[2]; 203 | float[] position= new float[3]; 204 | position=lina.getOrigin(); 205 | xAndy=convertToMiniSize(position[0],position[1]); 206 | returnState.addToState(xAndy,1); 207 | final Set en = findEntitiesInRange( world, lina, Float.POSITIVE_INFINITY ).stream().filter( p -> p instanceof BaseNPC ) 208 | .filter( p -> ((BaseNPC) p).getTeam() == 3 ).collect( Collectors.toSet() ); 209 | for (BaseEntity targetenitity : en) { 210 | position=targetenitity.getOrigin(); 211 | xAndy=convertToMiniSize(position[0],position[1]); 212 | switch (targetenitity.getClass().getName()) { 213 | case "se.lu.lucs.dota2.framework.game.Hero": 214 | returnState.addToState(xAndy,2); 215 | break; 216 | case "se.lu.lucs.dota2.framework.game.BaseNPC": 217 | returnState.addToState(xAndy,3); 218 | break; 219 | case "se.lu.lucs.dota2.framework.game.Tower": 220 | returnState.addToState(xAndy,4); 221 | break; 222 | case "se.lu.lucs.dota2.framework.game.Building": 223 | returnState.addToState(xAndy,5); 224 | break; 225 | default: 226 | break; 227 | } 228 | } 229 | final Set altower = findEntitiesInRange( world, lina, Float.POSITIVE_INFINITY ).stream().filter( p -> p instanceof BaseNPC ) 230 | .filter( p -> ((BaseNPC) p).getTeam() == lina.getTeam() ).filter( p -> ((BaseNPC) p).getClass() == Tower.class ).collect( Collectors.toSet() ); 231 | for (BaseEntity targetenitity : altower) { 232 | position=targetenitity.getOrigin(); 233 | xAndy=convertToMiniSize(position[0],position[1]); 234 | returnState.addToState(xAndy,6); 235 | } 236 | return returnState; 237 | } 238 | private int[] convertToMiniSize(float x, float y){ 239 | int[] totalarray= new int[2]; 240 | double XandYofWorld= conf.getGameDimansion(); 241 | int minisize= conf.getOneDStateSize(); 
242 | totalarray[0]=(int)(((x+XandYofWorld)*minisize)/(XandYofWorld*2)); 243 | totalarray[1]=(int)(((XandYofWorld-y)*minisize)/(XandYofWorld*2)); 244 | return totalarray; 245 | } 246 | private static Set findEntitiesInRange( World world, BaseEntity center, float range ) { 247 | final Set result = world.getEntities().values().stream().filter( e -> distance( center, e ) < range ).collect( Collectors.toSet() ); 248 | result.remove( center ); 249 | return result; 250 | } 251 | private static float distance( BaseEntity a, BaseEntity b ) { 252 | final float[] posA = a.getOrigin(); 253 | final float[] posB = b.getOrigin(); 254 | return distance( posA, posB ); 255 | } 256 | private static float distance( float[] posA, float[] posB ) { 257 | return (float) Math.hypot( posB[0] - posA[0], posB[1] - posA[1] ); 258 | } 259 | } 260 | --------------------------------------------------------------------------------