├── _config.yml ├── src ├── reimprove │ ├── misc │ │ ├── data_source.ts │ │ ├── typed_window.ts │ │ └── learning_data_logger.ts │ ├── algorithms │ │ ├── q │ │ │ ├── qaction.ts │ │ │ ├── qtransition.ts │ │ │ ├── qstate.ts │ │ │ ├── qagent.ts │ │ │ └── qmatrix.ts │ │ ├── agent_config.ts │ │ ├── abstract_agent.ts │ │ └── deepq │ │ │ └── dqagent.ts │ ├── memory.ts │ ├── networks.ts │ ├── teacher.ts │ ├── model.ts │ └── academy.ts └── reimprove.ts ├── .gitignore ├── .npmignore ├── .travis.yml ├── docs ├── classes │ ├── datasource.md │ ├── qaction.md │ ├── result.md │ ├── typedwindow.md │ ├── learningdatalogger.md │ ├── memory.md │ ├── neuralnetwork.md │ ├── abstractagent.md │ ├── qtransition.md │ ├── qstate.md │ ├── model.md │ ├── convolutionalneuralnetwork.md │ └── teacher.md ├── interfaces │ ├── qstatedata.md │ ├── qactiondata.md │ ├── totflayerconfig.md │ ├── memoryconfig.md │ ├── agentconfig.md │ ├── mementotensor.md │ ├── academystepinput.md │ ├── graphnode.md │ ├── buildagentconfig.md │ ├── agenttrackinginformation.md │ ├── memento.md │ ├── dqagentconfig.md │ ├── academyconfig.md │ ├── layer.md │ ├── layerconfig.md │ ├── graphedge.md │ ├── convolutionalnetworklayer.md │ ├── teachertrackinginformation.md │ ├── neuralnetworklayer.md │ ├── denselayer.md │ ├── teachingconfig.md │ ├── maxpooling2dlayer.md │ ├── qagentconfig.md │ ├── dropoutlayer.md │ ├── convolutionallayer.md │ └── flattenlayer.md ├── enums │ ├── layertype.md │ └── teachingstate.md └── README.md ├── scripts └── build-npm.sh ├── test ├── academy.spec.ts ├── agent.spec.ts ├── network.spec.ts ├── model.spec.ts ├── teacher.spec.ts ├── reimprove.spec.ts └── q.spec.ts ├── LICENSE ├── tsconfig.json ├── CHANGELOG.md ├── package.json └── README.md /_config.yml: -------------------------------------------------------------------------------- 1 | theme: jekyll-theme-cayman -------------------------------------------------------------------------------- /src/reimprove/misc/data_source.ts: 
-------------------------------------------------------------------------------- 1 | export class DataSource { 2 | 3 | } -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .idea/ 2 | node_modules/ 3 | .cache/ 4 | dist/ 5 | reimprovejs-*.tgz -------------------------------------------------------------------------------- /.npmignore: -------------------------------------------------------------------------------- 1 | tsconfig.json 2 | .npmignore 3 | reimprovejs-*.tgz 4 | docs/ 5 | .travis.yml 6 | _config.yml -------------------------------------------------------------------------------- /.travis.yml: -------------------------------------------------------------------------------- 1 | language: nodejs 2 | 3 | before_script: 4 | - npm run setup 5 | 6 | script: 7 | - npm run test 8 | 9 | after_success: 10 | - npm run build -------------------------------------------------------------------------------- /docs/classes/datasource.md: -------------------------------------------------------------------------------- 1 | > ## [ReImproveJS](../README.md) 2 | 3 | [Globals](../globals.md) / [DataSource](datasource.md) / 4 | 5 | # Class: DataSource 6 | 7 | ## Hierarchy 8 | 9 | * **DataSource** -------------------------------------------------------------------------------- /scripts/build-npm.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | set -e 4 | 5 | rm -rf dist/ 6 | tsc -p . 
/**
 * Free-form payload that can be attached to a {@link QAction}.
 * Keys are arbitrary strings; values are not constrained.
 */
export interface QActionData {
    [key: string]: any;
}

/**
 * A named action carrying an optional, replaceable data payload.
 *
 * The name is fixed at construction time; the payload can be swapped
 * later through the `Data` setter.
 */
export class QAction {
    private data: QActionData;

    /**
     * @param name Identifier of the action; never reassigned afterwards.
     * @param data Optional payload; left `undefined` when omitted.
     */
    constructor(private readonly name: string, data?: QActionData) {
        this.data = data;
    }

    /** Payload currently attached to this action (may be `undefined`). */
    get Data(): QActionData {
        return this.data;
    }

    /** Replaces the payload attached to this action. */
    set Data(data: QActionData) {
        this.data = data;
    }

    /** Name given to the action at construction. */
    get Name(): string {
        return this.name;
    }
}
/**
 * Bounded FIFO window of values with a sentinel "null" value that is
 * never stored.
 *
 * @typeParam T Type of the stored values. `mean()` hands the window to
 *              lodash `mean`, so it is only meaningful for numeric `T`.
 */
export class TypedWindow<T> {
    private window: Array<T>;

    /**
     * @param size      Maximum number of values kept; the oldest value is dropped first.
     * @param minSize   Minimum number of stored values required before `mean()` computes an average.
     * @param nullValue Sentinel value; `add()` ignores any value equal to it.
     */
    constructor(private readonly size: number,
                private readonly minSize: number,
                private readonly nullValue: T) {
        this.window = [];
    }

    /**
     * Appends `value` unless it equals the sentinel, then evicts the oldest
     * entry if the window grew past `size`.
     */
    add(value: T): void {
        // Loose equality is kept on purpose to preserve existing behavior:
        // with a `null` sentinel it also rejects `undefined`.
        if (value == this.nullValue) return;

        this.window.push(value);
        if (this.window.length > this.size) {
            this.window.shift();
        }
    }

    /**
     * @returns `-1` while fewer than `minSize` values are stored, otherwise
     *          the lodash `mean` of the stored values.
     */
    mean(): number {
        if (this.window.length < this.minSize) {
            return -1;
        }
        return mean(this.window);
    }

    /** The live backing array (not a copy). */
    get Window(): Array<T> {
        return this.window;
    }
}
-------------------------------------------------------------------------------- 1 | > ## [ReImproveJS](../README.md) 2 | 3 | [Globals](../globals.md) / [MementoTensor](mementotensor.md) / 4 | 5 | # Interface: MementoTensor 6 | 7 | ## Hierarchy 8 | 9 | * **MementoTensor** 10 | 11 | ### Index 12 | 13 | #### Properties 14 | 15 | * [references](mementotensor.md#references) 16 | * [tensor](mementotensor.md#tensor) 17 | 18 | ## Properties 19 | 20 | ### references 21 | 22 | ● **references**: *number* 23 | 24 | *Defined in [reimprove/memory.ts:17](https://github.com/DevSide/ReImproveJS/blob/2368b25/src/reimprove/memory.ts#L17)* 25 | 26 | ___ 27 | 28 | ### tensor 29 | 30 | ● **tensor**: *`Tensor`* 31 | 32 | *Defined in [reimprove/memory.ts:16](https://github.com/DevSide/ReImproveJS/blob/2368b25/src/reimprove/memory.ts#L16)* 33 | 34 | ___ -------------------------------------------------------------------------------- /test/academy.spec.ts: -------------------------------------------------------------------------------- 1 | import {expect} from "chai"; 2 | import {Academy} from "../src/reimprove/academy"; 3 | 4 | const academy = new Academy(); 5 | 6 | describe('Academy', () => { 7 | beforeEach(() => { 8 | academy.reset(); 9 | }); 10 | 11 | it("should generate new agent name", () => { 12 | expect(academy.addAgent({model: null})).to.not.be.null; 13 | }); 14 | 15 | it("should generate new teacher name", () => { 16 | expect(academy.addTeacher()).to.not.be.null; 17 | }); 18 | 19 | it("should let the actual agent name", () => { 20 | expect(academy.addAgent({model: null}, "test")).to.be.equal("test"); 21 | }); 22 | 23 | it("should let the actual teacher name", () => { 24 | expect(academy.addTeacher(null, "test")).to.be.equal("test"); 25 | }) 26 | }); -------------------------------------------------------------------------------- /src/reimprove/algorithms/q/qtransition.ts: -------------------------------------------------------------------------------- 1 | import {QAction} from 
/**
 * Directed link between two states, taken through an action and carrying
 * a Q-value.
 *
 * Every instance receives a unique numeric id from a class-wide counter
 * that is incremented on each construction.
 */
export class QTransition {
    /** Next id to hand out; shared by all transitions. */
    private static transitionId: number = 0;

    /** Q-value of the transition; starts at 0. */
    private QValue: number = 0;
    /** Id assigned at construction, never changed afterwards. */
    private readonly id: number = QTransition.transitionId++;

    /**
     * @param from   State the transition starts from.
     * @param to     State the transition leads to.
     * @param action Action associated with the transition.
     */
    constructor(private from: QState, private to: QState, private action: QAction) {
    }

    get Id(): number {
        return this.id;
    }

    get Action() {
        return this.action;
    }

    get Q() {
        return this.QValue;
    }

    set Q(qvalue: number) {
        this.QValue = qvalue;
    }

    get From() {
        return this.from;
    }

    set From(state: QState) {
        this.from = state;
    }

    get To() {
        return this.to;
    }

    set To(state: QState) {
        this.to = state;
    }
}
[Globals](../globals.md) / [LayerType](layertype.md) / 4 | 5 | # Enumeration: LayerType 6 | 7 | ### Index 8 | 9 | #### Enumeration members 10 | 11 | * [CONV2D](layertype.md#conv2d) 12 | * [DENSE](layertype.md#dense) 13 | * [FLATTEN](layertype.md#flatten) 14 | 15 | ## Enumeration members 16 | 17 | ### CONV2D 18 | 19 | ● **CONV2D**: = "CONV2D" 20 | 21 | *Defined in [reimprove/model.ts:14](https://github.com/DevSide/ReImproveJS/blob/2368b25/src/reimprove/model.ts#L14)* 22 | 23 | ___ 24 | 25 | ### DENSE 26 | 27 | ● **DENSE**: = "DENSE" 28 | 29 | *Defined in [reimprove/model.ts:13](https://github.com/DevSide/ReImproveJS/blob/2368b25/src/reimprove/model.ts#L13)* 30 | 31 | ___ 32 | 33 | ### FLATTEN 34 | 35 | ● **FLATTEN**: = "FLATTEN" 36 | 37 | *Defined in [reimprove/model.ts:15](https://github.com/DevSide/ReImproveJS/blob/2368b25/src/reimprove/model.ts#L15)* 38 | 39 | ___ -------------------------------------------------------------------------------- /docs/interfaces/academystepinput.md: -------------------------------------------------------------------------------- 1 | > ## [ReImproveJS](../README.md) 2 | 3 | [Globals](../globals.md) / [AcademyStepInput](academystepinput.md) / 4 | 5 | # Interface: AcademyStepInput 6 | 7 | Input to give at each step of the Academy, where you specify the target teacher and its inputs. 
8 | 9 | ## Hierarchy 10 | 11 | * **AcademyStepInput** 12 | 13 | ### Index 14 | 15 | #### Properties 16 | 17 | * [agentsInput](academystepinput.md#agentsinput) 18 | * [teacherName](academystepinput.md#teachername) 19 | 20 | ## Properties 21 | 22 | ### agentsInput 23 | 24 | ● **agentsInput**: *number[]* 25 | 26 | *Defined in [reimprove/academy.ts:32](https://github.com/DevSide/ReImproveJS/blob/2368b25/src/reimprove/academy.ts#L32)* 27 | 28 | ___ 29 | 30 | ### teacherName 31 | 32 | ● **teacherName**: *string* 33 | 34 | *Defined in [reimprove/academy.ts:31](https://github.com/DevSide/ReImproveJS/blob/2368b25/src/reimprove/academy.ts#L31)* 35 | 36 | ___ -------------------------------------------------------------------------------- /docs/interfaces/graphnode.md: -------------------------------------------------------------------------------- 1 | > ## [ReImproveJS](../README.md) 2 | 3 | [Globals](../globals.md) / [GraphNode](graphnode.md) / 4 | 5 | # Interface: GraphNode 6 | 7 | ## Hierarchy 8 | 9 | * **GraphNode** 10 | 11 | ### Index 12 | 13 | #### Properties 14 | 15 | * [color](graphnode.md#optional-color) 16 | * [id](graphnode.md#id) 17 | * [label](graphnode.md#label) 18 | 19 | ## Properties 20 | 21 | ### `Optional` color 22 | 23 | ● **color**? 
: *string* 24 | 25 | *Defined in [reimprove/algorithms/q/qmatrix.ts:8](https://github.com/DevSide/ReImproveJS/blob/2368b25/src/reimprove/algorithms/q/qmatrix.ts#L8)* 26 | 27 | ___ 28 | 29 | ### id 30 | 31 | ● **id**: *number* 32 | 33 | *Defined in [reimprove/algorithms/q/qmatrix.ts:6](https://github.com/DevSide/ReImproveJS/blob/2368b25/src/reimprove/algorithms/q/qmatrix.ts#L6)* 34 | 35 | ___ 36 | 37 | ### label 38 | 39 | ● **label**: *string* 40 | 41 | *Defined in [reimprove/algorithms/q/qmatrix.ts:7](https://github.com/DevSide/ReImproveJS/blob/2368b25/src/reimprove/algorithms/q/qmatrix.ts#L7)* 42 | 43 | ___ -------------------------------------------------------------------------------- /docs/interfaces/buildagentconfig.md: -------------------------------------------------------------------------------- 1 | > ## [ReImproveJS](../README.md) 2 | 3 | [Globals](../globals.md) / [BuildAgentConfig](buildagentconfig.md) / 4 | 5 | # Interface: BuildAgentConfig 6 | 7 | Configuration to build an agent 8 | 9 | ## Hierarchy 10 | 11 | * **BuildAgentConfig** 12 | 13 | ### Index 14 | 15 | #### Properties 16 | 17 | * [agentConfig](buildagentconfig.md#optional-agentconfig) 18 | * [model](buildagentconfig.md#model) 19 | 20 | ## Properties 21 | 22 | ### `Optional` agentConfig 23 | 24 | ● **agentConfig**? : *[DQAgentConfig](dqagentconfig.md)* 25 | 26 | *Defined in [reimprove/academy.ts:40](https://github.com/DevSide/ReImproveJS/blob/2368b25/src/reimprove/academy.ts#L40)* 27 | 28 | The agent configuration, defaulted if not present 29 | 30 | ___ 31 | 32 | ### model 33 | 34 | ● **model**: *[Model](../classes/model.md)* 35 | 36 | *Defined in [reimprove/academy.ts:38](https://github.com/DevSide/ReImproveJS/blob/2368b25/src/reimprove/academy.ts#L38)* 37 | 38 | The agent cannot have no model. 
/** Values used for every configuration field the caller does not provide. */
const DEFAULT_AGENT_CONFIG: AgentConfig = {
    memorySize: 30000
};

/**
 * Common base of all agents: holds a configuration (merged over the
 * defaults at construction) and an optional name, and declares the
 * operations concrete agents must implement.
 */
export abstract class AbstractAgent {
    private name?: string;
    protected agentConfig: AgentConfig;

    /**
     * @param agentConfig Optional overrides layered on top of the default configuration.
     * @param name        Optional agent name.
     */
    protected constructor(agentConfig?: AgentConfig, name?: string) {
        this.name = name;
        // Later sources win, so caller-provided fields override the defaults.
        this.agentConfig = Object.assign({}, DEFAULT_AGENT_CONFIG, agentConfig);
    }

    /** Configuration exposed by the concrete agent. */
    abstract get AgentConfig(): AgentConfig;

    /** Replaces the stored configuration as-is (no merge with the defaults). */
    protected setAgentConfig(config: AgentConfig) {
        this.agentConfig = config;
    }

    get Name() {
        return this.name;
    }

    set Name(name: string) {
        this.name = name;
    }

    abstract reset(): void;

    abstract getTrackingInformation(): AgentTrackingInformation;

    // abstract learn(gamma?: number, alpha?: number, data?: QStateData): void;
    abstract infer(input: number[] | number[][] | QAction, epsilon?: number, keepTensors?: boolean): number | QTransition;
}
above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /docs/interfaces/agenttrackinginformation.md: -------------------------------------------------------------------------------- 1 | > ## [ReImproveJS](../README.md) 2 | 3 | [Globals](../globals.md) / [AgentTrackingInformation](agenttrackinginformation.md) / 4 | 5 | # Interface: AgentTrackingInformation 6 | 7 | ## Hierarchy 8 | 9 | * **AgentTrackingInformation** 10 | 11 | ### Index 12 | 13 | #### Properties 14 | 15 | * [averageLoss](agenttrackinginformation.md#averageloss) 16 | * [averageReward](agenttrackinginformation.md#averagereward) 17 | * [name](agenttrackinginformation.md#name) 18 | 19 | ## Properties 20 | 21 | ### averageLoss 22 | 23 | ● **averageLoss**: *number* 24 | 25 | *Defined in [reimprove/algorithms/agent_config.ts:23](https://github.com/DevSide/ReImproveJS/blob/2368b25/src/reimprove/algorithms/agent_config.ts#L23)* 26 | 27 | ___ 28 | 29 | ### averageReward 30 | 31 | ● **averageReward**: *number* 32 | 33 | *Defined in [reimprove/algorithms/agent_config.ts:24](https://github.com/DevSide/ReImproveJS/blob/2368b25/src/reimprove/algorithms/agent_config.ts#L24)* 34 | 35 | ___ 36 | 37 | ### name 38 | 39 | ● **name**: *string* 40 | 41 | *Defined in 
[reimprove/algorithms/agent_config.ts:25](https://github.com/DevSide/ReImproveJS/blob/2368b25/src/reimprove/algorithms/agent_config.ts#L25)* 42 | 43 | ___ -------------------------------------------------------------------------------- /docs/interfaces/memento.md: -------------------------------------------------------------------------------- 1 | > ## [ReImproveJS](../README.md) 2 | 3 | [Globals](../globals.md) / [Memento](memento.md) / 4 | 5 | # Interface: Memento 6 | 7 | ## Hierarchy 8 | 9 | * **Memento** 10 | 11 | ### Index 12 | 13 | #### Properties 14 | 15 | * [action](memento.md#action) 16 | * [nextState](memento.md#nextstate) 17 | * [reward](memento.md#reward) 18 | * [state](memento.md#state) 19 | 20 | ## Properties 21 | 22 | ### action 23 | 24 | ● **action**: *number* 25 | 26 | *Defined in [reimprove/memory.ts:10](https://github.com/DevSide/ReImproveJS/blob/2368b25/src/reimprove/memory.ts#L10)* 27 | 28 | ___ 29 | 30 | ### nextState 31 | 32 | ● **nextState**: *[MementoTensor](mementotensor.md)* 33 | 34 | *Defined in [reimprove/memory.ts:12](https://github.com/DevSide/ReImproveJS/blob/2368b25/src/reimprove/memory.ts#L12)* 35 | 36 | ___ 37 | 38 | ### reward 39 | 40 | ● **reward**: *number* 41 | 42 | *Defined in [reimprove/memory.ts:11](https://github.com/DevSide/ReImproveJS/blob/2368b25/src/reimprove/memory.ts#L11)* 43 | 44 | ___ 45 | 46 | ### state 47 | 48 | ● **state**: *[MementoTensor](mementotensor.md)* 49 | 50 | *Defined in [reimprove/memory.ts:9](https://github.com/DevSide/ReImproveJS/blob/2368b25/src/reimprove/memory.ts#L9)* 51 | 52 | ___ -------------------------------------------------------------------------------- /tsconfig.json: -------------------------------------------------------------------------------- 1 | { 2 | "compileOnSave": true, 3 | "compilerOptions": { 4 | "allowJs": false, 5 | "allowUnreachableCode": false, 6 | "allowUnusedLabels": false, 7 | "alwaysStrict": true, 8 | "checkJs": false, 9 | "declaration": true, 10 | "diagnostics": 
true, 11 | "esModuleInterop": true, 12 | "forceConsistentCasingInFileNames": true, 13 | "lib": [ 14 | "es2015", 15 | "dom", 16 | "dom.iterable" 17 | ], 18 | "module": "commonjs", 19 | "moduleResolution": "node", 20 | "newLine": "LF", 21 | "noEmitOnError": true, 22 | "noFallthroughCasesInSwitch": true, 23 | "noImplicitAny": true, 24 | "noImplicitReturns": true, 25 | "noImplicitThis": true, 26 | "noUnusedLocals": true, 27 | "noUnusedParameters": false, 28 | "outDir": "dist", 29 | "pretty": true, 30 | "removeComments": true, 31 | "skipLibCheck": true, 32 | "strictFunctionTypes": true, 33 | "strictNullChecks": false, 34 | "strictPropertyInitialization": false, 35 | "target": "es5" 36 | }, 37 | "directive-selector": [ 38 | true, 39 | "attribute", 40 | "app", 41 | "camelCase" 42 | ], 43 | "include": [ 44 | "src/**/*.ts" 45 | ], 46 | "exclude": [ 47 | "node_modules" 48 | ] 49 | } 50 | -------------------------------------------------------------------------------- /docs/interfaces/dqagentconfig.md: -------------------------------------------------------------------------------- 1 | > ## [ReImproveJS](../README.md) 2 | 3 | [Globals](../globals.md) / [DQAgentConfig](dqagentconfig.md) / 4 | 5 | # Interface: DQAgentConfig 6 | 7 | ## Hierarchy 8 | 9 | * [AgentConfig](agentconfig.md) 10 | 11 | * **DQAgentConfig** 12 | 13 | ### Index 14 | 15 | #### Properties 16 | 17 | * [batchSize](dqagentconfig.md#optional-batchsize) 18 | * [memorySize](dqagentconfig.md#optional-memorysize) 19 | * [temporalWindow](dqagentconfig.md#optional-temporalwindow) 20 | 21 | ## Properties 22 | 23 | ### `Optional` batchSize 24 | 25 | ● **batchSize**? : *number* 26 | 27 | *Defined in [reimprove/algorithms/agent_config.ts:9](https://github.com/DevSide/ReImproveJS/blob/2368b25/src/reimprove/algorithms/agent_config.ts#L9)* 28 | 29 | ___ 30 | 31 | ### `Optional` memorySize 32 | 33 | ● **memorySize**? 
: *number* 34 | 35 | *Inherited from [AgentConfig](agentconfig.md).[memorySize](agentconfig.md#optional-memorysize)* 36 | 37 | *Defined in [reimprove/algorithms/agent_config.ts:5](https://github.com/DevSide/ReImproveJS/blob/2368b25/src/reimprove/algorithms/agent_config.ts#L5)* 38 | 39 | ___ 40 | 41 | ### `Optional` temporalWindow 42 | 43 | ● **temporalWindow**? : *number* 44 | 45 | *Defined in [reimprove/algorithms/agent_config.ts:10](https://github.com/DevSide/ReImproveJS/blob/2368b25/src/reimprove/algorithms/agent_config.ts#L10)* 46 | 47 | ___ -------------------------------------------------------------------------------- /docs/interfaces/academyconfig.md: -------------------------------------------------------------------------------- 1 | > ## [ReImproveJS](../README.md) 2 | 3 | [Globals](../globals.md) / [AcademyConfig](academyconfig.md) / 4 | 5 | # Interface: AcademyConfig 6 | 7 | Academy configuration, used for logs. You need to say if you want to log agents and memory, and 8 | give the parent
element 9 | 10 | ## Hierarchy 11 | 12 | * **AcademyConfig** 13 | 14 | ### Index 15 | 16 | #### Properties 17 | 18 | * [agentsLogs](academyconfig.md#agentslogs) 19 | * [memoryLogs](academyconfig.md#memorylogs) 20 | * [parentLogsElement](academyconfig.md#parentlogselement) 21 | 22 | ## Properties 23 | 24 | ### agentsLogs 25 | 26 | ● **agentsLogs**: *boolean* 27 | 28 | *Defined in [reimprove/academy.ts:22](https://github.com/DevSide/ReImproveJS/blob/2368b25/src/reimprove/academy.ts#L22)* 29 | 30 | If agents logs should be displayed, default to `false` 31 | 32 | ___ 33 | 34 | ### memoryLogs 35 | 36 | ● **memoryLogs**: *boolean* 37 | 38 | *Defined in [reimprove/academy.ts:24](https://github.com/DevSide/ReImproveJS/blob/2368b25/src/reimprove/academy.ts#L24)* 39 | 40 | If memory logs should be displayed, default to `false` 41 | 42 | ___ 43 | 44 | ### parentLogsElement 45 | 46 | ● **parentLogsElement**: *`HTMLElement`* 47 | 48 | *Defined in [reimprove/academy.ts:20](https://github.com/DevSide/ReImproveJS/blob/2368b25/src/reimprove/academy.ts#L20)* 49 | 50 | Parent
/**
 * Free-form data describing a state.
 * Keys are arbitrary strings; values are not constrained.
 */
export interface QStateData {
    [key: string]: any;
}

/**
 * A state holding its data, a reward, a "final" flag and the outgoing
 * transitions indexed by the action that triggers them.
 *
 * Every instance receives a unique numeric id from a class-wide counter
 * that is incremented on each construction.
 */
export class QState {
    /** Outgoing transitions, keyed by action identity. */
    private readonly transitions: Map<QAction, QTransition>;
    private final: boolean;
    private readonly id: number;

    /** Next id to hand out; shared by all states. */
    private static stateId: number = 0;

    /**
     * @param data   Data describing the state; fixed for the lifetime of the state.
     * @param reward Reward associated with the state.
     */
    constructor(private readonly data: QStateData, private reward: number) {
        this.transitions = new Map<QAction, QTransition>();
        this.final = false;
        this.id = QState.stateId++;
    }

    /**
     * Registers `transition` for `action` unless a non-null transition is
     * already registered for it.
     *
     * @returns The transition that was stored, or `null` when the action
     *          already had a transition and nothing was changed.
     */
    setTransition(action: QAction, transition: QTransition): QTransition {
        const slotIsFree = !this.transitions.has(action) || this.transitions.get(action) === null;
        if (!slotIsFree) {
            return null;
        }
        this.transitions.set(action, transition);
        return transition;
    }

    /**
     * @returns The transition registered for `action`, or `undefined` when
     *          there is none.
     */
    takeAction(action: QAction): QTransition {
        return this.transitions.get(action);
    }

    get Data(): QStateData { return this.data; }

    get Reward(): number { return this.reward; }
    set Reward(reward: number) { this.reward = reward; }

    /** Snapshot array of the registered transitions, in insertion order. */
    get Transitions(): QTransition[] { return Array.from(this.transitions.values()); }

    /** Marks the state as final and returns it, allowing chained calls. */
    setFinal(): QState { this.final = true; return this; }
    set Final(final: boolean) { this.final = final; }
    get Final(): boolean { return this.final; }

    get Id(): number { return this.id; }
}
activation 29 | 30 | ● **activation**? : *string | any* 31 | 32 | *Defined in [reimprove/networks.ts:4](https://github.com/DevSide/ReImproveJS/blob/2368b25/src/reimprove/networks.ts#L4)* 33 | 34 | ___ 35 | 36 | ### `Optional` inputShape 37 | 38 | ● **inputShape**? : *number[]* 39 | 40 | *Defined in [reimprove/networks.ts:8](https://github.com/DevSide/ReImproveJS/blob/2368b25/src/reimprove/networks.ts#L8)* 41 | 42 | Do not use this field 43 | 44 | ___ 45 | 46 | ### `Optional` name 47 | 48 | ● **name**? : *string* 49 | 50 | *Defined in [reimprove/networks.ts:6](https://github.com/DevSide/ReImproveJS/blob/2368b25/src/reimprove/networks.ts#L6)* 51 | 52 | ___ 53 | 54 | ### `Optional` useBias 55 | 56 | ● **useBias**? : *boolean* 57 | 58 | *Defined in [reimprove/networks.ts:5](https://github.com/DevSide/ReImproveJS/blob/2368b25/src/reimprove/networks.ts#L5)* 59 | 60 | ___ -------------------------------------------------------------------------------- /CHANGELOG.md: -------------------------------------------------------------------------------- 1 | ReImproveJS Changelog 2 | ======================== 3 | 4 | Version 0.0.2 5 | ---------------------- 6 | Improved the way models are created, by implementing a solution between TensorFlowJS and the user, giving 7 | interfaces to allow easier creation. Also, added support for Convolutional Neural Networks and for model loading 8 | from Keras or TensorFlow models. 9 | * Added Neural Networks, a class which permits to create a neural network. This class corresponds to 10 | the pre-creation of the model, only giving each layer its configuration. 11 | * Add layers to your neural network (dense, dropout, flatten) 12 | * Added Convolutional Neural Network, an extension of the Neural Networks which permits to add 13 | convolutional layers (conv2d, maxpool2d). 14 | * In the Convolutional Neural Network, the structure is managed automatically (conv layers => dense layers), you 15 | just have to set the content of each part. 
16 | * __Convolutional Networks are not ready for use yet__ 17 | 18 | 19 | Version 0.0.1 20 | ---------------------- 21 | This version is the minimum needed to have a working library. 22 | * Create an academy, agents, teachers 23 | * Associate teachers and agents 24 | * Give an agent a neural network model 25 | * Neural Network is managed by TensorFlowJS 26 | * Implemented the Q-Learning algorithm (here, DQN) 27 | * Reward your agents 28 | * Ability to dynamically manage the learning sequence 29 | * Parameters to create learning sessions 30 | * Ability to visualize training sequence parameters by providing a debug output -------------------------------------------------------------------------------- /docs/interfaces/layerconfig.md: -------------------------------------------------------------------------------- 1 | > ## [ReImproveJS](../README.md) 2 | 3 | [Globals](../globals.md) / [LayerConfig](layerconfig.md) / 4 | 5 | # Interface: LayerConfig 6 | 7 | Simplified layer configuration where you only give your layer, your activation function and the number of units. 8 | 9 | ## Hierarchy 10 | 11 | * **LayerConfig** 12 | 13 | ### Index 14 | 15 | #### Properties 16 | 17 | * [activation](layerconfig.md#activation) 18 | * [inputShape](layerconfig.md#optional-inputshape) 19 | * [units](layerconfig.md#units) 20 | * [useBias](layerconfig.md#optional-usebias) 21 | 22 | ## Properties 23 | 24 | ### activation 25 | 26 | ● **activation**: *string* 27 | 28 | *Defined in [reimprove/model.ts:27](https://github.com/DevSide/ReImproveJS/blob/2368b25/src/reimprove/model.ts#L27)* 29 | 30 | The activation function ('relu', 'sigmoid', ...) 31 | 32 | ___ 33 | 34 | ### `Optional` inputShape 35 | 36 | ● **inputShape**? 
: *`Array`* 37 | 38 | *Defined in [reimprove/model.ts:25](https://github.com/DevSide/ReImproveJS/blob/2368b25/src/reimprove/model.ts#L25)* 39 | 40 | If it is an input layer, the size of the input 41 | 42 | ___ 43 | 44 | ### units 45 | 46 | ● **units**: *number* 47 | 48 | *Defined in [reimprove/model.ts:23](https://github.com/DevSide/ReImproveJS/blob/2368b25/src/reimprove/model.ts#L23)* 49 | 50 | Number of neurons of this layer 51 | 52 | ___ 53 | 54 | ### `Optional` useBias 55 | 56 | ● **useBias**? : *boolean* 57 | 58 | *Defined in [reimprove/model.ts:28](https://github.com/DevSide/ReImproveJS/blob/2368b25/src/reimprove/model.ts#L28)* 59 | 60 | ___ -------------------------------------------------------------------------------- /docs/interfaces/graphedge.md: -------------------------------------------------------------------------------- 1 | > ## [ReImproveJS](../README.md) 2 | 3 | [Globals](../globals.md) / [GraphEdge](graphedge.md) / 4 | 5 | # Interface: GraphEdge 6 | 7 | ## Hierarchy 8 | 9 | * **GraphEdge** 10 | 11 | ### Index 12 | 13 | #### Properties 14 | 15 | * [arrows](graphedge.md#arrows) 16 | * [color](graphedge.md#color) 17 | * [from](graphedge.md#from) 18 | * [id](graphedge.md#id) 19 | * [label](graphedge.md#label) 20 | * [to](graphedge.md#to) 21 | 22 | ## Properties 23 | 24 | ### arrows 25 | 26 | ● **arrows**: *string* 27 | 28 | *Defined in [reimprove/algorithms/q/qmatrix.ts:16](https://github.com/DevSide/ReImproveJS/blob/2368b25/src/reimprove/algorithms/q/qmatrix.ts#L16)* 29 | 30 | ___ 31 | 32 | ### color 33 | 34 | ● **color**: *string* 35 | 36 | *Defined in [reimprove/algorithms/q/qmatrix.ts:17](https://github.com/DevSide/ReImproveJS/blob/2368b25/src/reimprove/algorithms/q/qmatrix.ts#L17)* 37 | 38 | ___ 39 | 40 | ### from 41 | 42 | ● **from**: *number* 43 | 44 | *Defined in [reimprove/algorithms/q/qmatrix.ts:12](https://github.com/DevSide/ReImproveJS/blob/2368b25/src/reimprove/algorithms/q/qmatrix.ts#L12)* 45 | 46 | ___ 47 | 48 | ### id 49 | 50 | ● 
**id**: *number* 51 | 52 | *Defined in [reimprove/algorithms/q/qmatrix.ts:14](https://github.com/DevSide/ReImproveJS/blob/2368b25/src/reimprove/algorithms/q/qmatrix.ts#L14)* 53 | 54 | ___ 55 | 56 | ### label 57 | 58 | ● **label**: *string* 59 | 60 | *Defined in [reimprove/algorithms/q/qmatrix.ts:15](https://github.com/DevSide/ReImproveJS/blob/2368b25/src/reimprove/algorithms/q/qmatrix.ts#L15)* 61 | 62 | ___ 63 | 64 | ### to 65 | 66 | ● **to**: *number* 67 | 68 | *Defined in [reimprove/algorithms/q/qmatrix.ts:13](https://github.com/DevSide/ReImproveJS/blob/2368b25/src/reimprove/algorithms/q/qmatrix.ts#L13)* 69 | 70 | ___ -------------------------------------------------------------------------------- /test/agent.spec.ts: -------------------------------------------------------------------------------- 1 | import {expect} from "chai"; 2 | import {DQAgent} from "../src/reimprove/algorithms/deepq/dqagent"; 3 | import {Model} from "../src/reimprove"; 4 | import {ConvolutionalNeuralNetwork} from "../src/reimprove/networks"; 5 | import * as tf from "@tensorflow/tfjs"; 6 | 7 | const batchSize = 18; 8 | const agent = new DQAgent(null, {batchSize: batchSize}); 9 | 10 | describe('Dqagent', () => { 11 | beforeEach(() => agent.reset()); 12 | 13 | it('Should have the right configuration', () => { 14 | expect(agent.AgentConfig).to.deep.equal({ 15 | memorySize: 30000, 16 | batchSize: batchSize, 17 | temporalWindow: 1 18 | }); 19 | }); 20 | 21 | it('Should be trainable with conv', async () => { 22 | const network = new ConvolutionalNeuralNetwork(); 23 | network.InputShape = [40, 40, 4]; 24 | network.addConvolutionalLayer(32); 25 | network.addMaxPooling2DLayer(); 26 | network.addConvolutionalLayer(64); 27 | network.addMaxPooling2DLayer(); 28 | network.addNeuralNetworkLayers([128, {type: 'dense', activation:'softmax', units:2}]); 29 | const nmodel = Model.FromNetwork(network, {stepsPerEpoch:10, epochs:1}); 30 | nmodel.compile({loss: tf.losses.softmaxCrossEntropy, optimizer: 'adam'}); 
31 | 32 | const convagent = new DQAgent(nmodel, {temporalWindow:0}); 33 | 34 | const input = new Array(40); 35 | for(let i =0;i < 40; ++i) { 36 | input[i] = new Array(40); 37 | for(let j = 0; j < 40; ++j) { 38 | input[i][j] = [100, 100, 100, 100]; 39 | } 40 | } 41 | 42 | expect(() => convagent.listen([input], 0)).to.not.throw(); 43 | expect(() => convagent.listen([input], 0)).to.not.throw(); 44 | await expect(async () => await convagent.learn(0.9, 1)).to.not.throw(); 45 | }); 46 | }); -------------------------------------------------------------------------------- /package.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "reimprovejs", 3 | "version": "0.0.3", 4 | "private": false, 5 | "description": "A library using TensorFlow.js for Deep Reinforcement Learning", 6 | "main": "./dist/index.js", 7 | "types": "./dist/index.d.ts", 8 | "repository": { 9 | "type": "git", 10 | "url": "git+https://github.com/Pravez/ReImproveJS.git" 11 | }, 12 | "author": "Paul Breton { 6 | it('should create layers configurations with correct default args', () => { 7 | const network = new ConvolutionalNeuralNetwork(); 8 | network.InputShape = [5, 5, 1]; 9 | network.addConvolutionalLayers([32, 64]); 10 | network.addMaxPooling2DLayer({type: "maxpooling", strides: [5, 5]}); 11 | network.addNeuralNetworkLayers([{type: 'dense', units: 256, name: 'test'}, 128, 2]); 12 | 13 | const layers = network.getLayers(); 14 | 15 | expect(layers[0]).to.be.deep.equal({ 16 | type: 'convolutional', 17 | filters: 32, 18 | activation: 'relu', 19 | kernelSize: 3 20 | }); 21 | 22 | expect(layers[2]).to.be.deep.equal({ 23 | type: 'maxpooling', 24 | poolSize: 2, 25 | strides: [5, 5] 26 | }); 27 | 28 | expect(layers[3]).to.be.deep.equal({ 29 | type: 'flatten' 30 | }); 31 | 32 | expect(layers[4]).to.be.deep.equal({ 33 | type: 'dense', 34 | units: 256, 35 | activation: 'relu', 36 | name:'test' 37 | }); 38 | 39 | }); 40 | 41 | it('should create correct layers', () 
=> { 42 | const network = new ConvolutionalNeuralNetwork(); 43 | network.InputShape = [5, 5, 1]; 44 | network.addConvolutionalLayers([32, 64]); 45 | network.addMaxPooling2DLayer(); 46 | network.addNeuralNetworkLayers([128, 128, 2]); 47 | 48 | const layers = network.createLayers(); 49 | expect(layers.length).to.be.equal(7); 50 | for (let i = 0; i < 6; ++i) { 51 | if (i < 2) 52 | expect(layers[i].name).to.contain('conv'); 53 | else if (i == 2) 54 | expect(layers[i].name).to.contain('pool'); 55 | else if (i == 3) 56 | expect(layers[i].name).to.contain('flatten'); 57 | else 58 | expect(layers[i].name).to.contain('dense'); 59 | } 60 | }); 61 | }); -------------------------------------------------------------------------------- /docs/interfaces/convolutionalnetworklayer.md: -------------------------------------------------------------------------------- 1 | > ## [ReImproveJS](../README.md) 2 | 3 | [Globals](../globals.md) / [ConvolutionalNetworkLayer](convolutionalnetworklayer.md) / 4 | 5 | # Interface: ConvolutionalNetworkLayer 6 | 7 | ## Hierarchy 8 | 9 | * [Layer](layer.md) 10 | 11 | * **ConvolutionalNetworkLayer** 12 | 13 | * [ConvolutionalLayer](convolutionallayer.md) 14 | 15 | * [MaxPooling2DLayer](maxpooling2dlayer.md) 16 | 17 | ### Index 18 | 19 | #### Properties 20 | 21 | * [activation](convolutionalnetworklayer.md#optional-activation) 22 | * [inputShape](convolutionalnetworklayer.md#optional-inputshape) 23 | * [name](convolutionalnetworklayer.md#optional-name) 24 | * [type](convolutionalnetworklayer.md#type) 25 | * [useBias](convolutionalnetworklayer.md#optional-usebias) 26 | 27 | ## Properties 28 | 29 | ### `Optional` activation 30 | 31 | ● **activation**? 
: *string | any* 32 | 33 | *Inherited from [Layer](layer.md).[activation](layer.md#optional-activation)* 34 | 35 | *Defined in [reimprove/networks.ts:4](https://github.com/DevSide/ReImproveJS/blob/2368b25/src/reimprove/networks.ts#L4)* 36 | 37 | ___ 38 | 39 | ### `Optional` inputShape 40 | 41 | ● **inputShape**? : *number[]* 42 | 43 | *Inherited from [Layer](layer.md).[inputShape](layer.md#optional-inputshape)* 44 | 45 | *Defined in [reimprove/networks.ts:8](https://github.com/DevSide/ReImproveJS/blob/2368b25/src/reimprove/networks.ts#L8)* 46 | 47 | Do not use this field 48 | 49 | ___ 50 | 51 | ### `Optional` name 52 | 53 | ● **name**? : *string* 54 | 55 | *Inherited from [Layer](layer.md).[name](layer.md#optional-name)* 56 | 57 | *Defined in [reimprove/networks.ts:6](https://github.com/DevSide/ReImproveJS/blob/2368b25/src/reimprove/networks.ts#L6)* 58 | 59 | ___ 60 | 61 | ### type 62 | 63 | ● **type**: *"convolutional" | "maxpooling"* 64 | 65 | *Defined in [reimprove/networks.ts:12](https://github.com/DevSide/ReImproveJS/blob/2368b25/src/reimprove/networks.ts#L12)* 66 | 67 | ___ 68 | 69 | ### `Optional` useBias 70 | 71 | ● **useBias**? 
: *boolean* 72 | 73 | *Inherited from [Layer](layer.md).[useBias](layer.md#optional-usebias)* 74 | 75 | *Defined in [reimprove/networks.ts:5](https://github.com/DevSide/ReImproveJS/blob/2368b25/src/reimprove/networks.ts#L5)* 76 | 77 | ___ -------------------------------------------------------------------------------- /docs/interfaces/teachertrackinginformation.md: -------------------------------------------------------------------------------- 1 | > ## [ReImproveJS](../README.md) 2 | 3 | [Globals](../globals.md) / [TeacherTrackingInformation](teachertrackinginformation.md) / 4 | 5 | # Interface: TeacherTrackingInformation 6 | 7 | ## Hierarchy 8 | 9 | * **TeacherTrackingInformation** 10 | 11 | ### Index 12 | 13 | #### Properties 14 | 15 | * [currentLessonLength](teachertrackinginformation.md#currentlessonlength) 16 | * [epsilon](teachertrackinginformation.md#epsilon) 17 | * [gamma](teachertrackinginformation.md#gamma) 18 | * [lessonNumber](teachertrackinginformation.md#lessonnumber) 19 | * [maxLessons](teachertrackinginformation.md#maxlessons) 20 | * [name](teachertrackinginformation.md#name) 21 | * [students](teachertrackinginformation.md#students) 22 | 23 | ## Properties 24 | 25 | ### currentLessonLength 26 | 27 | ● **currentLessonLength**: *number* 28 | 29 | *Defined in [reimprove/teacher.ts:39](https://github.com/DevSide/ReImproveJS/blob/2368b25/src/reimprove/teacher.ts#L39)* 30 | 31 | ___ 32 | 33 | ### epsilon 34 | 35 | ● **epsilon**: *number* 36 | 37 | *Defined in [reimprove/teacher.ts:38](https://github.com/DevSide/ReImproveJS/blob/2368b25/src/reimprove/teacher.ts#L38)* 38 | 39 | ___ 40 | 41 | ### gamma 42 | 43 | ● **gamma**: *number* 44 | 45 | *Defined in [reimprove/teacher.ts:37](https://github.com/DevSide/ReImproveJS/blob/2368b25/src/reimprove/teacher.ts#L37)* 46 | 47 | ___ 48 | 49 | ### lessonNumber 50 | 51 | ● **lessonNumber**: *number* 52 | 53 | *Defined in 
[reimprove/teacher.ts:40](https://github.com/DevSide/ReImproveJS/blob/2368b25/src/reimprove/teacher.ts#L40)* 54 | 55 | ___ 56 | 57 | ### maxLessons 58 | 59 | ● **maxLessons**: *number* 60 | 61 | *Defined in [reimprove/teacher.ts:41](https://github.com/DevSide/ReImproveJS/blob/2368b25/src/reimprove/teacher.ts#L41)* 62 | 63 | ___ 64 | 65 | ### name 66 | 67 | ● **name**: *string* 68 | 69 | *Defined in [reimprove/teacher.ts:36](https://github.com/DevSide/ReImproveJS/blob/2368b25/src/reimprove/teacher.ts#L36)* 70 | 71 | ___ 72 | 73 | ### students 74 | 75 | ● **students**: *[AgentTrackingInformation](agenttrackinginformation.md)[]* 76 | 77 | *Defined in [reimprove/teacher.ts:42](https://github.com/DevSide/ReImproveJS/blob/2368b25/src/reimprove/teacher.ts#L42)* 78 | 79 | ___ -------------------------------------------------------------------------------- /test/model.spec.ts: -------------------------------------------------------------------------------- 1 | import * as tf from "@tensorflow/tfjs"; 2 | import {expect} from "chai"; 3 | import {LayerType, Model} from "../src/reimprove"; 4 | import {ConvolutionalNeuralNetwork} from "../src/reimprove/networks"; 5 | 6 | const screenInputSize = 20 * 20; 7 | const numActions = 3; 8 | const inputSize = screenInputSize * 1 + numActions * 1 + screenInputSize; 9 | const model = new Model(null, {stepsPerEpoch: 1, epochs: 1}); 10 | model.addLayer(LayerType.DENSE, {units: 128, activation: 'relu', inputShape: [inputSize]}); 11 | model.addLayer(LayerType.DENSE, {units: 128, activation: 'relu'}); 12 | model.addLayer(LayerType.DENSE, {units: numActions, activation: 'relu'}); 13 | model.compile({loss: 'meanSquaredError', optimizer: 'sgd'}); 14 | 15 | describe('Old model', () => { 16 | it('should give a 0 result', () => { 17 | expect(model.predict(tf.randomNormal([1, inputSize])).getHighestValue()).to.be.within(0, numActions); 18 | }); 19 | 20 | it('should give a random according to the size of the output tensor', () => { 21 | for (let i = 0; 
i < 10; ++i) 22 | expect(model.randomOutput()).to.be.within(0, numActions); 23 | }); 24 | }); 25 | 26 | const network = new ConvolutionalNeuralNetwork(); 27 | network.InputShape = [40, 40, 3]; 28 | network.addConvolutionalLayer(32); 29 | network.addMaxPooling2DLayer(); 30 | network.addConvolutionalLayer(64); 31 | network.addMaxPooling2DLayer(); 32 | network.addNeuralNetworkLayers([128, {type: 'dense', activation:'softmax', units:2}]); 33 | const nmodel = Model.FromNetwork(network, {stepsPerEpoch:10, epochs:1}); 34 | nmodel.compile({loss: tf.losses.softmaxCrossEntropy, optimizer: 'adam'}); 35 | 36 | describe('New model', () => { 37 | it('should have the right output size', () => { 38 | for (let i = 0; i < 10; ++i) 39 | expect(nmodel.randomOutput()).to.be.within(0, numActions); 40 | }); 41 | 42 | it('can be trained', async () => { 43 | const x = tf.randomNormal([1, 40, 40, 3]); 44 | const y = tf.tensor([[0, 1]]); 45 | 46 | for(let i = 0;i < 5; ++i) { 47 | await nmodel.fit(x, y); 48 | } 49 | 50 | let results = []; 51 | for(let i = 0;i < 10; ++i) 52 | results.push(nmodel.predict(x).getAction()); 53 | 54 | expect(results.reduce((p, c) => p + c)).to.be.greaterThan(7); 55 | }); 56 | }); 57 | 58 | -------------------------------------------------------------------------------- /docs/interfaces/neuralnetworklayer.md: -------------------------------------------------------------------------------- 1 | > ## [ReImproveJS](../README.md) 2 | 3 | [Globals](../globals.md) / [NeuralNetworkLayer](neuralnetworklayer.md) / 4 | 5 | # Interface: NeuralNetworkLayer 6 | 7 | ## Hierarchy 8 | 9 | * [Layer](layer.md) 10 | 11 | * **NeuralNetworkLayer** 12 | 13 | * [DenseLayer](denselayer.md) 14 | 15 | * [DropoutLayer](dropoutlayer.md) 16 | 17 | ### Index 18 | 19 | #### Properties 20 | 21 | * [activation](neuralnetworklayer.md#optional-activation) 22 | * [inputShape](neuralnetworklayer.md#optional-inputshape) 23 | * [name](neuralnetworklayer.md#optional-name) 24 | * 
[type](neuralnetworklayer.md#type) 25 | * [units](neuralnetworklayer.md#units) 26 | * [useBias](neuralnetworklayer.md#optional-usebias) 27 | 28 | ## Properties 29 | 30 | ### `Optional` activation 31 | 32 | ● **activation**? : *string | any* 33 | 34 | *Inherited from [Layer](layer.md).[activation](layer.md#optional-activation)* 35 | 36 | *Defined in [reimprove/networks.ts:4](https://github.com/DevSide/ReImproveJS/blob/2368b25/src/reimprove/networks.ts#L4)* 37 | 38 | ___ 39 | 40 | ### `Optional` inputShape 41 | 42 | ● **inputShape**? : *number[]* 43 | 44 | *Inherited from [Layer](layer.md).[inputShape](layer.md#optional-inputshape)* 45 | 46 | *Defined in [reimprove/networks.ts:8](https://github.com/DevSide/ReImproveJS/blob/2368b25/src/reimprove/networks.ts#L8)* 47 | 48 | Do not use this field 49 | 50 | ___ 51 | 52 | ### `Optional` name 53 | 54 | ● **name**? : *string* 55 | 56 | *Inherited from [Layer](layer.md).[name](layer.md#optional-name)* 57 | 58 | *Defined in [reimprove/networks.ts:6](https://github.com/DevSide/ReImproveJS/blob/2368b25/src/reimprove/networks.ts#L6)* 59 | 60 | ___ 61 | 62 | ### type 63 | 64 | ● **type**: *"dense" | "dropout" | "flatten"* 65 | 66 | *Defined in [reimprove/networks.ts:30](https://github.com/DevSide/ReImproveJS/blob/2368b25/src/reimprove/networks.ts#L30)* 67 | 68 | ___ 69 | 70 | ### units 71 | 72 | ● **units**: *number* 73 | 74 | *Defined in [reimprove/networks.ts:29](https://github.com/DevSide/ReImproveJS/blob/2368b25/src/reimprove/networks.ts#L29)* 75 | 76 | ___ 77 | 78 | ### `Optional` useBias 79 | 80 | ● **useBias**? 
: *boolean* 81 | 82 | *Inherited from [Layer](layer.md).[useBias](layer.md#optional-usebias)* 83 | 84 | *Defined in [reimprove/networks.ts:5](https://github.com/DevSide/ReImproveJS/blob/2368b25/src/reimprove/networks.ts#L5)* 85 | 86 | ___ -------------------------------------------------------------------------------- /docs/interfaces/denselayer.md: -------------------------------------------------------------------------------- 1 | > ## [ReImproveJS](../README.md) 2 | 3 | [Globals](../globals.md) / [DenseLayer](denselayer.md) / 4 | 5 | # Interface: DenseLayer 6 | 7 | ## Hierarchy 8 | 9 | * [NeuralNetworkLayer](neuralnetworklayer.md) 10 | 11 | * **DenseLayer** 12 | 13 | ### Index 14 | 15 | #### Properties 16 | 17 | * [activation](denselayer.md#optional-activation) 18 | * [inputShape](denselayer.md#optional-inputshape) 19 | * [name](denselayer.md#optional-name) 20 | * [type](denselayer.md#type) 21 | * [units](denselayer.md#units) 22 | * [useBias](denselayer.md#optional-usebias) 23 | 24 | ## Properties 25 | 26 | ### `Optional` activation 27 | 28 | ● **activation**? : *string | any* 29 | 30 | *Inherited from [Layer](layer.md).[activation](layer.md#optional-activation)* 31 | 32 | *Defined in [reimprove/networks.ts:4](https://github.com/DevSide/ReImproveJS/blob/2368b25/src/reimprove/networks.ts#L4)* 33 | 34 | ___ 35 | 36 | ### `Optional` inputShape 37 | 38 | ● **inputShape**? : *number[]* 39 | 40 | *Inherited from [Layer](layer.md).[inputShape](layer.md#optional-inputshape)* 41 | 42 | *Defined in [reimprove/networks.ts:8](https://github.com/DevSide/ReImproveJS/blob/2368b25/src/reimprove/networks.ts#L8)* 43 | 44 | Do not use this field 45 | 46 | ___ 47 | 48 | ### `Optional` name 49 | 50 | ● **name**? 
: *string* 51 | 52 | *Inherited from [Layer](layer.md).[name](layer.md#optional-name)* 53 | 54 | *Defined in [reimprove/networks.ts:6](https://github.com/DevSide/ReImproveJS/blob/2368b25/src/reimprove/networks.ts#L6)* 55 | 56 | ___ 57 | 58 | ### type 59 | 60 | ● **type**: *"dense"* 61 | 62 | *Overrides [NeuralNetworkLayer](neuralnetworklayer.md).[type](neuralnetworklayer.md#type)* 63 | 64 | *Defined in [reimprove/networks.ts:34](https://github.com/DevSide/ReImproveJS/blob/2368b25/src/reimprove/networks.ts#L34)* 65 | 66 | ___ 67 | 68 | ### units 69 | 70 | ● **units**: *number* 71 | 72 | *Inherited from [NeuralNetworkLayer](neuralnetworklayer.md).[units](neuralnetworklayer.md#units)* 73 | 74 | *Defined in [reimprove/networks.ts:29](https://github.com/DevSide/ReImproveJS/blob/2368b25/src/reimprove/networks.ts#L29)* 75 | 76 | ___ 77 | 78 | ### `Optional` useBias 79 | 80 | ● **useBias**? : *boolean* 81 | 82 | *Inherited from [Layer](layer.md).[useBias](layer.md#optional-usebias)* 83 | 84 | *Defined in [reimprove/networks.ts:5](https://github.com/DevSide/ReImproveJS/blob/2368b25/src/reimprove/networks.ts#L5)* 85 | 86 | ___ -------------------------------------------------------------------------------- /docs/interfaces/teachingconfig.md: -------------------------------------------------------------------------------- 1 | > ## [ReImproveJS](../README.md) 2 | 3 | [Globals](../globals.md) / [TeachingConfig](teachingconfig.md) / 4 | 5 | # Interface: TeachingConfig 6 | 7 | ## Hierarchy 8 | 9 | * **TeachingConfig** 10 | 11 | ### Index 12 | 13 | #### Properties 14 | 15 | * [alpha](teachingconfig.md#optional-alpha) 16 | * [epsilon](teachingconfig.md#optional-epsilon) 17 | * [epsilonDecay](teachingconfig.md#optional-epsilondecay) 18 | * [epsilonMin](teachingconfig.md#optional-epsilonmin) 19 | * [gamma](teachingconfig.md#optional-gamma) 20 | * [lessonLength](teachingconfig.md#optional-lessonlength) 21 | * [lessonsQuantity](teachingconfig.md#optional-lessonsquantity) 22 | * 
[lessonsWithRandom](teachingconfig.md#optional-lessonswithrandom) 23 | 24 | ## Properties 25 | 26 | ### `Optional` alpha 27 | 28 | ● **alpha**? : *number* 29 | 30 | *Defined in [reimprove/teacher.ts:24](https://github.com/DevSide/ReImproveJS/blob/2368b25/src/reimprove/teacher.ts#L24)* 31 | 32 | ___ 33 | 34 | ### `Optional` epsilon 35 | 36 | ● **epsilon**? : *number* 37 | 38 | *Defined in [reimprove/teacher.ts:21](https://github.com/DevSide/ReImproveJS/blob/2368b25/src/reimprove/teacher.ts#L21)* 39 | 40 | ___ 41 | 42 | ### `Optional` epsilonDecay 43 | 44 | ● **epsilonDecay**? : *number* 45 | 46 | *Defined in [reimprove/teacher.ts:22](https://github.com/DevSide/ReImproveJS/blob/2368b25/src/reimprove/teacher.ts#L22)* 47 | 48 | ___ 49 | 50 | ### `Optional` epsilonMin 51 | 52 | ● **epsilonMin**? : *number* 53 | 54 | *Defined in [reimprove/teacher.ts:23](https://github.com/DevSide/ReImproveJS/blob/2368b25/src/reimprove/teacher.ts#L23)* 55 | 56 | ___ 57 | 58 | ### `Optional` gamma 59 | 60 | ● **gamma**? : *number* 61 | 62 | *Defined in [reimprove/teacher.ts:20](https://github.com/DevSide/ReImproveJS/blob/2368b25/src/reimprove/teacher.ts#L20)* 63 | 64 | ___ 65 | 66 | ### `Optional` lessonLength 67 | 68 | ● **lessonLength**? : *number* 69 | 70 | *Defined in [reimprove/teacher.ts:17](https://github.com/DevSide/ReImproveJS/blob/2368b25/src/reimprove/teacher.ts#L17)* 71 | 72 | ___ 73 | 74 | ### `Optional` lessonsQuantity 75 | 76 | ● **lessonsQuantity**? : *number* 77 | 78 | *Defined in [reimprove/teacher.ts:18](https://github.com/DevSide/ReImproveJS/blob/2368b25/src/reimprove/teacher.ts#L18)* 79 | 80 | ___ 81 | 82 | ### `Optional` lessonsWithRandom 83 | 84 | ● **lessonsWithRandom**? 
: *number* 85 | 86 | *Defined in [reimprove/teacher.ts:19](https://github.com/DevSide/ReImproveJS/blob/2368b25/src/reimprove/teacher.ts#L19)* 87 | 88 | ___ -------------------------------------------------------------------------------- /src/reimprove/memory.ts: -------------------------------------------------------------------------------- 1 | import {Tensor} from "@tensorflow/tfjs-core"; 2 | import {sampleSize, random, range, sample} from "lodash"; 3 | 4 | export interface MemoryConfig { 5 | size: number; 6 | } 7 | 8 | export interface Memento { 9 | state: MementoTensor; 10 | action: number; 11 | reward: number; 12 | nextState: MementoTensor; 13 | } 14 | 15 | export interface MementoTensor { 16 | tensor: Tensor; 17 | references: number; 18 | } 19 | 20 | export class Memory { 21 | config: MemoryConfig; 22 | 23 | memory: Array; 24 | currentSize: number; 25 | 26 | constructor(config: MemoryConfig) { 27 | this.config = config; 28 | 29 | this.memory = new Array(this.config.size); 30 | this.currentSize = 0; 31 | } 32 | 33 | remember(memento: Memento, replaceIfFull: boolean = true) { 34 | memento.state.references += 1; 35 | memento.nextState.references += 1; 36 | 37 | if (this.currentSize < this.config.size) { 38 | this.memory[this.currentSize++] = memento; 39 | } else if (replaceIfFull) { 40 | let randPos = random(0, this.memory.length - 1); 41 | Memory.freeMemento(this.memory[randPos]); 42 | this.memory[randPos] = memento; 43 | } 44 | } 45 | 46 | sample(batchSize: number, unique = true) { 47 | let memslice = this.memory.slice(0, this.currentSize); 48 | if (unique) 49 | return sampleSize(memslice, batchSize); 50 | else 51 | return range(batchSize).map(() => sample(memslice)); 52 | } 53 | 54 | get CurrentSize() { 55 | return this.currentSize; 56 | } 57 | 58 | get Size() { 59 | return this.memory.length; 60 | } 61 | 62 | private static freeMemento(memento: Memento) { 63 | memento.nextState.references -= 1; 64 | memento.state.references -= 1; 65 | if 
(memento.nextState.references <= 0) 66 | memento.nextState.tensor.dispose(); 67 | if (memento.state.references <= 0) 68 | memento.state.tensor.dispose(); 69 | } 70 | 71 | reset(): void { 72 | this.memory.forEach(memento => { 73 | memento.state.tensor.dispose(); 74 | memento.nextState.tensor.dispose(); 75 | }); 76 | this.memory = new Array(this.config.size); 77 | this.currentSize = 0; 78 | } 79 | 80 | merge(other: Memory): void { 81 | other.memory.forEach(memento => this.remember(memento)); 82 | } 83 | } -------------------------------------------------------------------------------- /docs/classes/qaction.md: -------------------------------------------------------------------------------- 1 | > ## [ReImproveJS](../README.md) 2 | 3 | [Globals](../globals.md) / [QAction](qaction.md) / 4 | 5 | # Class: QAction 6 | 7 | ## Hierarchy 8 | 9 | * **QAction** 10 | 11 | ### Index 12 | 13 | #### Constructors 14 | 15 | * [constructor](qaction.md#constructor) 16 | 17 | #### Properties 18 | 19 | * [data](qaction.md#private-data) 20 | * [name](qaction.md#private-name) 21 | 22 | #### Accessors 23 | 24 | * [Data](qaction.md#data) 25 | * [Name](qaction.md#name) 26 | 27 | ## Constructors 28 | 29 | ### constructor 30 | 31 | \+ **new QAction**(`name`: string, `data?`: [QActionData](../interfaces/qactiondata.md)): *[QAction](qaction.md)* 32 | 33 | *Defined in [reimprove/algorithms/q/qaction.ts:6](https://github.com/DevSide/ReImproveJS/blob/2368b25/src/reimprove/algorithms/q/qaction.ts#L6)* 34 | 35 | **Parameters:** 36 | 37 | Name | Type | 38 | ------ | ------ | 39 | `name` | string | 40 | `data?` | [QActionData](../interfaces/qactiondata.md) | 41 | 42 | **Returns:** *[QAction](qaction.md)* 43 | 44 | ___ 45 | 46 | ## Properties 47 | 48 | ### `Private` data 49 | 50 | ● **data**: *[QActionData](../interfaces/qactiondata.md)* 51 | 52 | *Defined in [reimprove/algorithms/q/qaction.ts:6](https://github.com/DevSide/ReImproveJS/blob/2368b25/src/reimprove/algorithms/q/qaction.ts#L6)* 53 | 54 | ___ 
55 | 56 | ### `Private` name 57 | 58 | ● **name**: *string* 59 | 60 | *Defined in [reimprove/algorithms/q/qaction.ts:8](https://github.com/DevSide/ReImproveJS/blob/2368b25/src/reimprove/algorithms/q/qaction.ts#L8)* 61 | 62 | ___ 63 | 64 | ## Accessors 65 | 66 | ### Data 67 | 68 | ● **get Data**(): *[QActionData](../interfaces/qactiondata.md)* 69 | 70 | *Defined in [reimprove/algorithms/q/qaction.ts:12](https://github.com/DevSide/ReImproveJS/blob/2368b25/src/reimprove/algorithms/q/qaction.ts#L12)* 71 | 72 | **Returns:** *[QActionData](../interfaces/qactiondata.md)* 73 | 74 | ● **set Data**(`data`: [QActionData](../interfaces/qactiondata.md)): *void* 75 | 76 | *Defined in [reimprove/algorithms/q/qaction.ts:16](https://github.com/DevSide/ReImproveJS/blob/2368b25/src/reimprove/algorithms/q/qaction.ts#L16)* 77 | 78 | **Parameters:** 79 | 80 | Name | Type | 81 | ------ | ------ | 82 | `data` | [QActionData](../interfaces/qactiondata.md) | 83 | 84 | **Returns:** *void* 85 | 86 | ___ 87 | 88 | ### Name 89 | 90 | ● **get Name**(): *string* 91 | 92 | *Defined in [reimprove/algorithms/q/qaction.ts:20](https://github.com/DevSide/ReImproveJS/blob/2368b25/src/reimprove/algorithms/q/qaction.ts#L20)* 93 | 94 | **Returns:** *string* 95 | 96 | ___ -------------------------------------------------------------------------------- /docs/interfaces/maxpooling2dlayer.md: -------------------------------------------------------------------------------- 1 | > ## [ReImproveJS](../README.md) 2 | 3 | [Globals](../globals.md) / [MaxPooling2DLayer](maxpooling2dlayer.md) / 4 | 5 | # Interface: MaxPooling2DLayer 6 | 7 | ## Hierarchy 8 | 9 | * [ConvolutionalNetworkLayer](convolutionalnetworklayer.md) 10 | 11 | * **MaxPooling2DLayer** 12 | 13 | ### Index 14 | 15 | #### Properties 16 | 17 | * [activation](maxpooling2dlayer.md#optional-activation) 18 | * [inputShape](maxpooling2dlayer.md#optional-inputshape) 19 | * [name](maxpooling2dlayer.md#optional-name) 20 | * 
[poolSize](maxpooling2dlayer.md#optional-poolsize) 21 | * [strides](maxpooling2dlayer.md#optional-strides) 22 | * [type](maxpooling2dlayer.md#type) 23 | * [useBias](maxpooling2dlayer.md#optional-usebias) 24 | 25 | ## Properties 26 | 27 | ### `Optional` activation 28 | 29 | ● **activation**? : *string | any* 30 | 31 | *Inherited from [Layer](layer.md).[activation](layer.md#optional-activation)* 32 | 33 | *Defined in [reimprove/networks.ts:4](https://github.com/DevSide/ReImproveJS/blob/2368b25/src/reimprove/networks.ts#L4)* 34 | 35 | ___ 36 | 37 | ### `Optional` inputShape 38 | 39 | ● **inputShape**? : *number[]* 40 | 41 | *Inherited from [Layer](layer.md).[inputShape](layer.md#optional-inputshape)* 42 | 43 | *Defined in [reimprove/networks.ts:8](https://github.com/DevSide/ReImproveJS/blob/2368b25/src/reimprove/networks.ts#L8)* 44 | 45 | Do not use this field 46 | 47 | ___ 48 | 49 | ### `Optional` name 50 | 51 | ● **name**? : *string* 52 | 53 | *Inherited from [Layer](layer.md).[name](layer.md#optional-name)* 54 | 55 | *Defined in [reimprove/networks.ts:6](https://github.com/DevSide/ReImproveJS/blob/2368b25/src/reimprove/networks.ts#L6)* 56 | 57 | ___ 58 | 59 | ### `Optional` poolSize 60 | 61 | ● **poolSize**? : *number* 62 | 63 | *Defined in [reimprove/networks.ts:23](https://github.com/DevSide/ReImproveJS/blob/2368b25/src/reimprove/networks.ts#L23)* 64 | 65 | ___ 66 | 67 | ### `Optional` strides 68 | 69 | ● **strides**? : *[number, number]* 70 | 71 | *Defined in [reimprove/networks.ts:24](https://github.com/DevSide/ReImproveJS/blob/2368b25/src/reimprove/networks.ts#L24)* 72 | 73 | ___ 74 | 75 | ### type 76 | 77 | ● **type**: *"maxpooling"* 78 | 79 | *Overrides [ConvolutionalNetworkLayer](convolutionalnetworklayer.md).[type](convolutionalnetworklayer.md#type)* 80 | 81 | *Defined in [reimprove/networks.ts:25](https://github.com/DevSide/ReImproveJS/blob/2368b25/src/reimprove/networks.ts#L25)* 82 | 83 | ___ 84 | 85 | ### `Optional` useBias 86 | 87 | ● **useBias**? 
: *boolean* 88 | 89 | *Inherited from [Layer](layer.md).[useBias](layer.md#optional-usebias)* 90 | 91 | *Defined in [reimprove/networks.ts:5](https://github.com/DevSide/ReImproveJS/blob/2368b25/src/reimprove/networks.ts#L5)* 92 | 93 | ___ -------------------------------------------------------------------------------- /docs/interfaces/qagentconfig.md: -------------------------------------------------------------------------------- 1 | > ## [ReImproveJS](../README.md) 2 | 3 | [Globals](../globals.md) / [QAgentConfig](qagentconfig.md) / 4 | 5 | # Interface: QAgentConfig 6 | 7 | ## Hierarchy 8 | 9 | * [AgentConfig](agentconfig.md) 10 | 11 | * **QAgentConfig** 12 | 13 | ### Index 14 | 15 | #### Properties 16 | 17 | * [actions](qagentconfig.md#optional-actions) 18 | * [createMatrixDynamically](qagentconfig.md#optional-creatematrixdynamically) 19 | * [dataHash](qagentconfig.md#datahash) 20 | * [gamma](qagentconfig.md#optional-gamma) 21 | * [initialState](qagentconfig.md#optional-initialstate) 22 | * [memorySize](qagentconfig.md#optional-memorysize) 23 | * [startingData](qagentconfig.md#optional-startingdata) 24 | 25 | ## Properties 26 | 27 | ### `Optional` actions 28 | 29 | ● **actions**? : *`Array`* 30 | 31 | *Defined in [reimprove/algorithms/agent_config.ts:15](https://github.com/DevSide/ReImproveJS/blob/2368b25/src/reimprove/algorithms/agent_config.ts#L15)* 32 | 33 | ___ 34 | 35 | ### `Optional` createMatrixDynamically 36 | 37 | ● **createMatrixDynamically**? 
: *boolean* 38 | 39 | *Defined in [reimprove/algorithms/agent_config.ts:14](https://github.com/DevSide/ReImproveJS/blob/2368b25/src/reimprove/algorithms/agent_config.ts#L14)* 40 | 41 | ___ 42 | 43 | ### dataHash 44 | 45 | ● **dataHash**: *function* 46 | 47 | *Defined in [reimprove/algorithms/agent_config.ts:17](https://github.com/DevSide/ReImproveJS/blob/2368b25/src/reimprove/algorithms/agent_config.ts#L17)* 48 | 49 | #### Type declaration: 50 | 51 | ▸ (`data`: [QStateData](qstatedata.md)): *string* 52 | 53 | **Parameters:** 54 | 55 | Name | Type | 56 | ------ | ------ | 57 | `data` | [QStateData](qstatedata.md) | 58 | 59 | ___ 60 | 61 | ### `Optional` gamma 62 | 63 | ● **gamma**? : *number* 64 | 65 | *Defined in [reimprove/algorithms/agent_config.ts:19](https://github.com/DevSide/ReImproveJS/blob/2368b25/src/reimprove/algorithms/agent_config.ts#L19)* 66 | 67 | ___ 68 | 69 | ### `Optional` initialState 70 | 71 | ● **initialState**? : *[QStateData](qstatedata.md)* 72 | 73 | *Defined in [reimprove/algorithms/agent_config.ts:18](https://github.com/DevSide/ReImproveJS/blob/2368b25/src/reimprove/algorithms/agent_config.ts#L18)* 74 | 75 | ___ 76 | 77 | ### `Optional` memorySize 78 | 79 | ● **memorySize**? : *number* 80 | 81 | *Inherited from [AgentConfig](agentconfig.md).[memorySize](agentconfig.md#optional-memorysize)* 82 | 83 | *Defined in [reimprove/algorithms/agent_config.ts:5](https://github.com/DevSide/ReImproveJS/blob/2368b25/src/reimprove/algorithms/agent_config.ts#L5)* 84 | 85 | ___ 86 | 87 | ### `Optional` startingData 88 | 89 | ● **startingData**? 
: *[QStateData](qstatedata.md)* 90 | 91 | *Defined in [reimprove/algorithms/agent_config.ts:16](https://github.com/DevSide/ReImproveJS/blob/2368b25/src/reimprove/algorithms/agent_config.ts#L16)* 92 | 93 | ___ -------------------------------------------------------------------------------- /docs/interfaces/dropoutlayer.md: -------------------------------------------------------------------------------- 1 | > ## [ReImproveJS](../README.md) 2 | 3 | [Globals](../globals.md) / [DropoutLayer](dropoutlayer.md) / 4 | 5 | # Interface: DropoutLayer 6 | 7 | ## Hierarchy 8 | 9 | * [NeuralNetworkLayer](neuralnetworklayer.md) 10 | 11 | * **DropoutLayer** 12 | 13 | ### Index 14 | 15 | #### Properties 16 | 17 | * [activation](dropoutlayer.md#optional-activation) 18 | * [inputShape](dropoutlayer.md#optional-inputshape) 19 | * [name](dropoutlayer.md#optional-name) 20 | * [rate](dropoutlayer.md#rate) 21 | * [seed](dropoutlayer.md#optional-seed) 22 | * [type](dropoutlayer.md#type) 23 | * [units](dropoutlayer.md#units) 24 | * [useBias](dropoutlayer.md#optional-usebias) 25 | 26 | ## Properties 27 | 28 | ### `Optional` activation 29 | 30 | ● **activation**? : *string | any* 31 | 32 | *Inherited from [Layer](layer.md).[activation](layer.md#optional-activation)* 33 | 34 | *Defined in [reimprove/networks.ts:4](https://github.com/DevSide/ReImproveJS/blob/2368b25/src/reimprove/networks.ts#L4)* 35 | 36 | ___ 37 | 38 | ### `Optional` inputShape 39 | 40 | ● **inputShape**? : *number[]* 41 | 42 | *Inherited from [Layer](layer.md).[inputShape](layer.md#optional-inputshape)* 43 | 44 | *Defined in [reimprove/networks.ts:8](https://github.com/DevSide/ReImproveJS/blob/2368b25/src/reimprove/networks.ts#L8)* 45 | 46 | Do not use this field 47 | 48 | ___ 49 | 50 | ### `Optional` name 51 | 52 | ● **name**? 
: *string* 53 | 54 | *Inherited from [Layer](layer.md).[name](layer.md#optional-name)* 55 | 56 | *Defined in [reimprove/networks.ts:6](https://github.com/DevSide/ReImproveJS/blob/2368b25/src/reimprove/networks.ts#L6)* 57 | 58 | ___ 59 | 60 | ### rate 61 | 62 | ● **rate**: *number* 63 | 64 | *Defined in [reimprove/networks.ts:39](https://github.com/DevSide/ReImproveJS/blob/2368b25/src/reimprove/networks.ts#L39)* 65 | 66 | ___ 67 | 68 | ### `Optional` seed 69 | 70 | ● **seed**? : *number* 71 | 72 | *Defined in [reimprove/networks.ts:40](https://github.com/DevSide/ReImproveJS/blob/2368b25/src/reimprove/networks.ts#L40)* 73 | 74 | ___ 75 | 76 | ### type 77 | 78 | ● **type**: *"dropout"* 79 | 80 | *Overrides [NeuralNetworkLayer](neuralnetworklayer.md).[type](neuralnetworklayer.md#type)* 81 | 82 | *Defined in [reimprove/networks.ts:38](https://github.com/DevSide/ReImproveJS/blob/2368b25/src/reimprove/networks.ts#L38)* 83 | 84 | ___ 85 | 86 | ### units 87 | 88 | ● **units**: *number* 89 | 90 | *Inherited from [NeuralNetworkLayer](neuralnetworklayer.md).[units](neuralnetworklayer.md#units)* 91 | 92 | *Defined in [reimprove/networks.ts:29](https://github.com/DevSide/ReImproveJS/blob/2368b25/src/reimprove/networks.ts#L29)* 93 | 94 | ___ 95 | 96 | ### `Optional` useBias 97 | 98 | ● **useBias**? : *boolean* 99 | 100 | *Inherited from [Layer](layer.md).[useBias](layer.md#optional-usebias)* 101 | 102 | *Defined in [reimprove/networks.ts:5](https://github.com/DevSide/ReImproveJS/blob/2368b25/src/reimprove/networks.ts#L5)* 103 | 104 | ___ -------------------------------------------------------------------------------- /docs/classes/result.md: -------------------------------------------------------------------------------- 1 | > ## [ReImproveJS](../README.md) 2 | 3 | [Globals](../globals.md) / [Result](result.md) / 4 | 5 | # Class: Result 6 | 7 | Just a little wrapper around the result of a request to TensorflowJS. 
Because every single result is made through WebGL, 8 | we need to create async tasks. So we remove the async side by using the dataSync() method to get the result immediately, 9 | instead of returning a Promise. 10 | 11 | ## Hierarchy 12 | 13 | * **Result** 14 | 15 | ### Index 16 | 17 | #### Constructors 18 | 19 | * [constructor](result.md#constructor) 20 | 21 | #### Properties 22 | 23 | * [result](result.md#private-result) 24 | 25 | #### Methods 26 | 27 | * [getAction](result.md#getaction) 28 | * [getHighestValue](result.md#gethighestvalue) 29 | * [getResultAndDispose](result.md#private-getresultanddispose) 30 | * [getValue](result.md#getvalue) 31 | 32 | ## Constructors 33 | 34 | ### constructor 35 | 36 | \+ **new Result**(`result`: `Tensor`): *[Result](result.md)* 37 | 38 | *Defined in [reimprove/model.ts:174](https://github.com/DevSide/ReImproveJS/blob/2368b25/src/reimprove/model.ts#L174)* 39 | 40 | **Parameters:** 41 | 42 | Name | Type | 43 | ------ | ------ | 44 | `result` | `Tensor` | 45 | 46 | **Returns:** *[Result](result.md)* 47 | 48 | ___ 49 | 50 | ## Properties 51 | 52 | ### `Private` result 53 | 54 | ● **result**: *`Tensor`* 55 | 56 | *Defined in [reimprove/model.ts:176](https://github.com/DevSide/ReImproveJS/blob/2368b25/src/reimprove/model.ts#L176)* 57 | 58 | ___ 59 | 60 | ## Methods 61 | 62 | ### getAction 63 | 64 | ▸ **getAction**(): *number* 65 | 66 | *Defined in [reimprove/model.ts:196](https://github.com/DevSide/ReImproveJS/blob/2368b25/src/reimprove/model.ts#L196)* 67 | 68 | Returns the index of the highest value of a 1D tensor 69 | 70 | **Returns:** *number* 71 | 72 | ___ 73 | 74 | ### getHighestValue 75 | 76 | ▸ **getHighestValue**(): *number* 77 | 78 | *Defined in [reimprove/model.ts:188](https://github.com/DevSide/ReImproveJS/blob/2368b25/src/reimprove/model.ts#L188)* 79 | 80 | Returns the highest value of a 1D tensor 81 | 82 | **Returns:** *number* 83 | 84 | ___ 85 | 86 | ### `Private` getResultAndDispose 87 | 88 | ▸
**getResultAndDispose**(`t`: `Tensor`): *`Float32Array` | `Int32Array` | `Uint8Array`* 89 | 90 | *Defined in [reimprove/model.ts:179](https://github.com/DevSide/ReImproveJS/blob/2368b25/src/reimprove/model.ts#L179)* 91 | 92 | **Parameters:** 93 | 94 | Name | Type | 95 | ------ | ------ | 96 | `t` | `Tensor` | 97 | 98 | **Returns:** *`Float32Array` | `Int32Array` | `Uint8Array`* 99 | 100 | ___ 101 | 102 | ### getValue 103 | 104 | ▸ **getValue**(): *`Int32Array` | `Float32Array` | `Uint8Array`* 105 | 106 | *Defined in [reimprove/model.ts:204](https://github.com/DevSide/ReImproveJS/blob/2368b25/src/reimprove/model.ts#L204)* 107 | 108 | Returns an array reflecting the initial result tensor 109 | 110 | **Returns:** *`Int32Array` | `Float32Array` | `Uint8Array`* 111 | 112 | ___ -------------------------------------------------------------------------------- /docs/interfaces/convolutionallayer.md: -------------------------------------------------------------------------------- 1 | > ## [ReImproveJS](../README.md) 2 | 3 | [Globals](../globals.md) / [ConvolutionalLayer](convolutionallayer.md) / 4 | 5 | # Interface: ConvolutionalLayer 6 | 7 | ## Hierarchy 8 | 9 | * [ConvolutionalNetworkLayer](convolutionalnetworklayer.md) 10 | 11 | * **ConvolutionalLayer** 12 | 13 | ### Index 14 | 15 | #### Properties 16 | 17 | * [activation](convolutionallayer.md#optional-activation) 18 | * [filters](convolutionallayer.md#filters) 19 | * [inputShape](convolutionallayer.md#optional-inputshape) 20 | * [kernelSize](convolutionallayer.md#kernelsize) 21 | * [name](convolutionallayer.md#optional-name) 22 | * [strides](convolutionallayer.md#optional-strides) 23 | * [type](convolutionallayer.md#type) 24 | * [useBias](convolutionallayer.md#optional-usebias) 25 | 26 | ## Properties 27 | 28 | ### `Optional` activation 29 | 30 | ● **activation**? 
: *string | any* 31 | 32 | *Inherited from [Layer](layer.md).[activation](layer.md#optional-activation)* 33 | 34 | *Defined in [reimprove/networks.ts:4](https://github.com/DevSide/ReImproveJS/blob/2368b25/src/reimprove/networks.ts#L4)* 35 | 36 | ___ 37 | 38 | ### filters 39 | 40 | ● **filters**: *number* 41 | 42 | *Defined in [reimprove/networks.ts:17](https://github.com/DevSide/ReImproveJS/blob/2368b25/src/reimprove/networks.ts#L17)* 43 | 44 | ___ 45 | 46 | ### `Optional` inputShape 47 | 48 | ● **inputShape**? : *number[]* 49 | 50 | *Inherited from [Layer](layer.md).[inputShape](layer.md#optional-inputshape)* 51 | 52 | *Defined in [reimprove/networks.ts:8](https://github.com/DevSide/ReImproveJS/blob/2368b25/src/reimprove/networks.ts#L8)* 53 | 54 | Do not use this field 55 | 56 | ___ 57 | 58 | ### kernelSize 59 | 60 | ● **kernelSize**: *number* 61 | 62 | *Defined in [reimprove/networks.ts:19](https://github.com/DevSide/ReImproveJS/blob/2368b25/src/reimprove/networks.ts#L19)* 63 | 64 | ___ 65 | 66 | ### `Optional` name 67 | 68 | ● **name**? : *string* 69 | 70 | *Inherited from [Layer](layer.md).[name](layer.md#optional-name)* 71 | 72 | *Defined in [reimprove/networks.ts:6](https://github.com/DevSide/ReImproveJS/blob/2368b25/src/reimprove/networks.ts#L6)* 73 | 74 | ___ 75 | 76 | ### `Optional` strides 77 | 78 | ● **strides**? : *number | number[]* 79 | 80 | *Defined in [reimprove/networks.ts:18](https://github.com/DevSide/ReImproveJS/blob/2368b25/src/reimprove/networks.ts#L18)* 81 | 82 | ___ 83 | 84 | ### type 85 | 86 | ● **type**: *"convolutional"* 87 | 88 | *Overrides [ConvolutionalNetworkLayer](convolutionalnetworklayer.md).[type](convolutionalnetworklayer.md#type)* 89 | 90 | *Defined in [reimprove/networks.ts:16](https://github.com/DevSide/ReImproveJS/blob/2368b25/src/reimprove/networks.ts#L16)* 91 | 92 | ___ 93 | 94 | ### `Optional` useBias 95 | 96 | ● **useBias**? 
: *boolean* 97 | 98 | *Inherited from [Layer](layer.md).[useBias](layer.md#optional-usebias)* 99 | 100 | *Defined in [reimprove/networks.ts:5](https://github.com/DevSide/ReImproveJS/blob/2368b25/src/reimprove/networks.ts#L5)* 101 | 102 | ___ -------------------------------------------------------------------------------- /docs/classes/typedwindow.md: -------------------------------------------------------------------------------- 1 | > ## [ReImproveJS](../README.md) 2 | 3 | [Globals](../globals.md) / [TypedWindow](typedwindow.md) / 4 | 5 | # Class: TypedWindow <**T**> 6 | 7 | ## Type parameters 8 | 9 | ■` T` 10 | 11 | ## Hierarchy 12 | 13 | * **TypedWindow** 14 | 15 | ### Index 16 | 17 | #### Constructors 18 | 19 | * [constructor](typedwindow.md#constructor) 20 | 21 | #### Properties 22 | 23 | * [minSize](typedwindow.md#private-minsize) 24 | * [nullValue](typedwindow.md#private-nullvalue) 25 | * [size](typedwindow.md#private-size) 26 | * [window](typedwindow.md#private-window) 27 | 28 | #### Accessors 29 | 30 | * [Window](typedwindow.md#window) 31 | 32 | #### Methods 33 | 34 | * [add](typedwindow.md#add) 35 | * [mean](typedwindow.md#mean) 36 | 37 | ## Constructors 38 | 39 | ### constructor 40 | 41 | \+ **new TypedWindow**(`size`: number, `minSize`: number, `nullValue`: `T`): *[TypedWindow](typedwindow.md)* 42 | 43 | *Defined in [reimprove/misc/typed_window.ts:4](https://github.com/DevSide/ReImproveJS/blob/2368b25/src/reimprove/misc/typed_window.ts#L4)* 44 | 45 | **Parameters:** 46 | 47 | Name | Type | 48 | ------ | ------ | 49 | `size` | number | 50 | `minSize` | number | 51 | `nullValue` | `T` | 52 | 53 | **Returns:** *[TypedWindow](typedwindow.md)* 54 | 55 | ___ 56 | 57 | ## Properties 58 | 59 | ### `Private` minSize 60 | 61 | ● **minSize**: *number* 62 | 63 | *Defined in [reimprove/misc/typed_window.ts:6](https://github.com/DevSide/ReImproveJS/blob/2368b25/src/reimprove/misc/typed_window.ts#L6)* 64 | 65 | ___ 66 | 67 | ### `Private` nullValue 68 | 69 | ● 
**nullValue**: *`T`* 70 | 71 | *Defined in [reimprove/misc/typed_window.ts:6](https://github.com/DevSide/ReImproveJS/blob/2368b25/src/reimprove/misc/typed_window.ts#L6)* 72 | 73 | ___ 74 | 75 | ### `Private` size 76 | 77 | ● **size**: *number* 78 | 79 | *Defined in [reimprove/misc/typed_window.ts:6](https://github.com/DevSide/ReImproveJS/blob/2368b25/src/reimprove/misc/typed_window.ts#L6)* 80 | 81 | ___ 82 | 83 | ### `Private` window 84 | 85 | ● **window**: *`Array`* 86 | 87 | *Defined in [reimprove/misc/typed_window.ts:4](https://github.com/DevSide/ReImproveJS/blob/2368b25/src/reimprove/misc/typed_window.ts#L4)* 88 | 89 | ___ 90 | 91 | ## Accessors 92 | 93 | ### Window 94 | 95 | ● **get Window**(): *`T`[]* 96 | 97 | *Defined in [reimprove/misc/typed_window.ts:25](https://github.com/DevSide/ReImproveJS/blob/2368b25/src/reimprove/misc/typed_window.ts#L25)* 98 | 99 | **Returns:** *`T`[]* 100 | 101 | ___ 102 | 103 | ## Methods 104 | 105 | ### add 106 | 107 | ▸ **add**(`value`: `T`): *void* 108 | 109 | *Defined in [reimprove/misc/typed_window.ts:10](https://github.com/DevSide/ReImproveJS/blob/2368b25/src/reimprove/misc/typed_window.ts#L10)* 110 | 111 | **Parameters:** 112 | 113 | Name | Type | 114 | ------ | ------ | 115 | `value` | `T` | 116 | 117 | **Returns:** *void* 118 | 119 | ___ 120 | 121 | ### mean 122 | 123 | ▸ **mean**(): *number* 124 | 125 | *Defined in [reimprove/misc/typed_window.ts:17](https://github.com/DevSide/ReImproveJS/blob/2368b25/src/reimprove/misc/typed_window.ts#L17)* 126 | 127 | **Returns:** *number* 128 | 129 | ___ -------------------------------------------------------------------------------- /test/teacher.spec.ts: -------------------------------------------------------------------------------- 1 | import {spy} from "sinon"; 2 | import {expect, use} from "chai"; 3 | import {range} from 'lodash'; 4 | 5 | import generated from "sinon-chai"; 6 | import {LayerType, Model} from "../src/reimprove"; 7 | import {Teacher, TeachingState} from 
"../src/reimprove/teacher"; 8 | import {DQAgent} from "../src/reimprove/algorithms/deepq/dqagent"; 9 | 10 | use(generated); 11 | 12 | const lessonLength = 50; 13 | const lessons = 5; 14 | const screenInputSize = 20 * 20; 15 | const numActions = 3; 16 | const inputSize = screenInputSize * 1 + numActions * 1 + screenInputSize; 17 | const model = new Model(null, {stepsPerEpoch: 1, epochs: 1}); 18 | model.addLayer(LayerType.DENSE, {units: 128, activation: 'relu', inputShape: [inputSize]}); 19 | model.addLayer(LayerType.DENSE, {units: 128, activation: 'relu'}); 20 | model.addLayer(LayerType.DENSE, {units: numActions, activation: 'relu'}); 21 | model.compile({loss: 'meanSquaredError', optimizer: 'sgd'}); 22 | 23 | const teacher = new Teacher({lessonsQuantity: lessons, lessonLength: lessonLength, gamma: 0.8, epsilonDecay: 0.990}); 24 | const agent = new DQAgent(model); 25 | 26 | 27 | 28 | 29 | describe('Teacher', () => { 30 | before(() => { 31 | teacher.affectStudent(agent); 32 | }); 33 | 34 | beforeEach(() => { 35 | teacher.reset(); 36 | agent.reset(); 37 | }); 38 | 39 | it('should have the right configuration', () => { 40 | expect(teacher.config).to.be.deep.equal({ 41 | lessonLength: lessonLength, 42 | lessonsQuantity: lessons, 43 | lessonsWithRandom: 2, 44 | gamma: 0.8, 45 | epsilon: 1, 46 | epsilonDecay: 0.990, 47 | epsilonMin: 0.05, 48 | alpha: 1 49 | }); 50 | }); 51 | 52 | it('should have started learning', () => { 53 | teacher.teach(range(screenInputSize)); 54 | 55 | expect(teacher.State).to.be.equal(TeachingState.EXPERIENCING); 56 | expect(teacher.currentLessonLength).to.be.equal(1); 57 | }); 58 | 59 | it('Should fire events', async () => { 60 | let lessonEnded = spy(); 61 | let lessonLearningEnded = spy(); 62 | let teachingEnded = spy(); 63 | teacher.OnLessonEnded = lessonEnded; 64 | teacher.OnLearningLessonEnded = lessonLearningEnded; 65 | teacher.OnTeachingEnded = teachingEnded; 66 | 67 | for(let i = 0;i < lessonLength*lessons+1; ++i) { 68 | await 
teacher.teach(range(screenInputSize)); 69 | } 70 | 71 | expect(lessonEnded).to.have.been.calledWith(teacher.Name); 72 | expect(lessonLearningEnded).to.have.been.calledWith(teacher.Name); 73 | expect(teachingEnded).to.have.been.calledWith(teacher.Name); 74 | 75 | expect(lessonEnded).to.be.calledBefore(lessonLearningEnded); 76 | expect(lessonLearningEnded).to.be.calledBefore(teachingEnded); 77 | }); 78 | 79 | it('Should end in the testing state', async () => { 80 | expect(teacher.state).to.be.equal(TeachingState.NONE); 81 | 82 | for(let i = 0;i < lessonLength*lessons+1; ++i) { 83 | await teacher.teach(range(screenInputSize)); 84 | } 85 | 86 | expect(teacher.state).to.be.equal(TeachingState.TESTING); 87 | }); 88 | 89 | it('should have decreasing epsilon', async () => { 90 | for(let i = 0;i < lessonLength*lessons+1; ++i) { 91 | await teacher.teach(range(screenInputSize)); 92 | } 93 | }) 94 | }); -------------------------------------------------------------------------------- /docs/interfaces/flattenlayer.md: -------------------------------------------------------------------------------- 1 | > ## [ReImproveJS](../README.md) 2 | 3 | [Globals](../globals.md) / [FlattenLayer](flattenlayer.md) / 4 | 5 | # Interface: FlattenLayer 6 | 7 | ## Hierarchy 8 | 9 | * [Layer](layer.md) 10 | 11 | * **FlattenLayer** 12 | 13 | ### Index 14 | 15 | #### Properties 16 | 17 | * [activation](flattenlayer.md#optional-activation) 18 | * [batchInputShape](flattenlayer.md#optional-batchinputshape) 19 | * [batchSize](flattenlayer.md#optional-batchsize) 20 | * [inputShape](flattenlayer.md#optional-inputshape) 21 | * [name](flattenlayer.md#optional-name) 22 | * [trainable](flattenlayer.md#optional-trainable) 23 | * [type](flattenlayer.md#type) 24 | * [updatable](flattenlayer.md#optional-updatable) 25 | * [useBias](flattenlayer.md#optional-usebias) 26 | * [weights](flattenlayer.md#optional-weights) 27 | 28 | ## Properties 29 | 30 | ### `Optional` activation 31 | 32 | ● **activation**? 
: *string | any* 33 | 34 | *Inherited from [Layer](layer.md).[activation](layer.md#optional-activation)* 35 | 36 | *Defined in [reimprove/networks.ts:4](https://github.com/DevSide/ReImproveJS/blob/2368b25/src/reimprove/networks.ts#L4)* 37 | 38 | ___ 39 | 40 | ### `Optional` batchInputShape 41 | 42 | ● **batchInputShape**? : *number[]* 43 | 44 | *Defined in [reimprove/networks.ts:44](https://github.com/DevSide/ReImproveJS/blob/2368b25/src/reimprove/networks.ts#L44)* 45 | 46 | ___ 47 | 48 | ### `Optional` batchSize 49 | 50 | ● **batchSize**? : *number* 51 | 52 | *Defined in [reimprove/networks.ts:45](https://github.com/DevSide/ReImproveJS/blob/2368b25/src/reimprove/networks.ts#L45)* 53 | 54 | ___ 55 | 56 | ### `Optional` inputShape 57 | 58 | ● **inputShape**? : *number[]* 59 | 60 | *Inherited from [Layer](layer.md).[inputShape](layer.md#optional-inputshape)* 61 | 62 | *Defined in [reimprove/networks.ts:8](https://github.com/DevSide/ReImproveJS/blob/2368b25/src/reimprove/networks.ts#L8)* 63 | 64 | Do not use this field 65 | 66 | ___ 67 | 68 | ### `Optional` name 69 | 70 | ● **name**? : *string* 71 | 72 | *Inherited from [Layer](layer.md).[name](layer.md#optional-name)* 73 | 74 | *Defined in [reimprove/networks.ts:6](https://github.com/DevSide/ReImproveJS/blob/2368b25/src/reimprove/networks.ts#L6)* 75 | 76 | ___ 77 | 78 | ### `Optional` trainable 79 | 80 | ● **trainable**? : *boolean* 81 | 82 | *Defined in [reimprove/networks.ts:46](https://github.com/DevSide/ReImproveJS/blob/2368b25/src/reimprove/networks.ts#L46)* 83 | 84 | ___ 85 | 86 | ### type 87 | 88 | ● **type**: *"flatten"* 89 | 90 | *Defined in [reimprove/networks.ts:49](https://github.com/DevSide/ReImproveJS/blob/2368b25/src/reimprove/networks.ts#L49)* 91 | 92 | ___ 93 | 94 | ### `Optional` updatable 95 | 96 | ● **updatable**? 
: *boolean* 97 | 98 | *Defined in [reimprove/networks.ts:47](https://github.com/DevSide/ReImproveJS/blob/2368b25/src/reimprove/networks.ts#L47)* 99 | 100 | ___ 101 | 102 | ### `Optional` useBias 103 | 104 | ● **useBias**? : *boolean* 105 | 106 | *Inherited from [Layer](layer.md).[useBias](layer.md#optional-usebias)* 107 | 108 | *Defined in [reimprove/networks.ts:5](https://github.com/DevSide/ReImproveJS/blob/2368b25/src/reimprove/networks.ts#L5)* 109 | 110 | ___ 111 | 112 | ### `Optional` weights 113 | 114 | ● **weights**? : *`Tensor`[]* 115 | 116 | *Defined in [reimprove/networks.ts:48](https://github.com/DevSide/ReImproveJS/blob/2368b25/src/reimprove/networks.ts#L48)* 117 | 118 | ___ -------------------------------------------------------------------------------- /test/reimprove.spec.ts: -------------------------------------------------------------------------------- 1 | import {expect, use} from 'chai'; 2 | import {range, shuffle} from 'lodash'; 3 | import {memory} from '@tensorflow/tfjs-core'; 4 | 5 | import generated from "sinon-chai"; 6 | import {Academy, LayerType, Model} from "../src/reimprove"; 7 | 8 | use(generated); 9 | 10 | const initialInputSize = 100; 11 | const numActions = 2; 12 | const inputSize = initialInputSize + numActions + initialInputSize; 13 | 14 | const model = new Model(null, {stepsPerEpoch: 1, epochs: 1}); 15 | model.addLayer(LayerType.DENSE, {units: 128, activation: 'relu', inputShape: [inputSize]}); 16 | model.addLayer(LayerType.DENSE, {units: 128, activation: 'relu'}); 17 | model.addLayer(LayerType.DENSE, {units: numActions, activation: 'relu'}); 18 | model.compile({loss: 'meanSquaredError', optimizer: 'adam'}); 19 | 20 | const lessonLength = 10; 21 | const lessons = 10; 22 | const randomSteps = 0; 23 | const batchSize = 32; 24 | const memorySize = 100; 25 | 26 | const academy = new Academy(); 27 | const agent = academy.addAgent({model: model, agentConfig: {batchSize: batchSize, memorySize: memorySize}}); 28 | const teacher = 
academy.addTeacher({ 29 | lessonLength: lessonLength, 30 | lessonsQuantity: lessons, 31 | lessonsWithRandom: randomSteps 32 | }); 33 | academy.assignTeacherToAgent(agent, teacher); 34 | 35 | describe("ReImprove - Real", () => { 36 | beforeEach(() => { 37 | academy.resetTeachersAndAgents(); 38 | }); 39 | 40 | it('should have no tensor memory overflow', async () => { 41 | let input = shuffle(range(0, initialInputSize)).map(v => v / initialInputSize); 42 | 43 | let results; 44 | for (let i = 0; i < lessonLength * lessons; ++i) { 45 | results = await academy.step([ 46 | { 47 | teacherName: teacher, 48 | agentsInput: input 49 | } 50 | ]); 51 | academy.addRewardToAgent(agent, results.get(agent) == 1 ? 1.0 : -1.0); 52 | } 53 | 54 | expect(memory().numTensors).to.be.approximately(memorySize*2, memorySize*0.5); 55 | 56 | for (let i = 0; i < lessonLength * lessons; ++i) { 57 | results = await academy.step([ 58 | { 59 | teacherName: teacher, 60 | agentsInput: input 61 | } 62 | ]); 63 | academy.addRewardToAgent(agent, results.get(agent) == 1 ? 1.0 : -1.0); 64 | } 65 | 66 | expect(memory().numTensors).to.be.approximately(memorySize*2, memorySize*0.5); 67 | }); 68 | 69 | /*it('should have decreasing loss', async () => { 70 | let input = shuffle(range(0, initialInputSize)).map(v => v / initialInputSize); 71 | let losses: number[] = []; 72 | let rewards: number[] = []; 73 | 74 | academy.OnLearningLessonEnded(teacher, (t) => { 75 | losses.push(t.getData().students[0].averageLoss); 76 | rewards.push(t.getData().students[0].averageReward); 77 | }); 78 | 79 | let results; 80 | for (let i = 0; i < lessonLength * lessons; ++i) { 81 | results = await academy.step([ 82 | { 83 | teacherName: teacher, 84 | agentsInput: input 85 | } 86 | ]); 87 | academy.addRewardToAgent(agent, results.get(agent) == 1 ? 
1.0 : -1.0); 88 | } 89 | 90 | expect(losses[0]).to.be.greaterThan(losses[losses.length - 1]); 91 | expect(rewards[0]).to.be.lessThan(rewards[rewards.length - 1]); 92 | });*/ 93 | }); -------------------------------------------------------------------------------- /docs/classes/learningdatalogger.md: -------------------------------------------------------------------------------- 1 | > ## [ReImproveJS](../README.md) 2 | 3 | [Globals](../globals.md) / [LearningDataLogger](learningdatalogger.md) / 4 | 5 | # Class: LearningDataLogger 6 | 7 | ## Hierarchy 8 | 9 | * **LearningDataLogger** 10 | 11 | ### Index 12 | 13 | #### Constructors 14 | 15 | * [constructor](learningdatalogger.md#constructor) 16 | 17 | #### Properties 18 | 19 | * [academy](learningdatalogger.md#private-academy) 20 | * [memory](learningdatalogger.md#private-memory) 21 | * [parent](learningdatalogger.md#private-parent) 22 | * [tables](learningdatalogger.md#private-tables) 23 | 24 | #### Methods 25 | 26 | * [createMemoryTable](learningdatalogger.md#creatememorytable) 27 | * [createTeacherTable](learningdatalogger.md#createteachertable) 28 | * [dispose](learningdatalogger.md#dispose) 29 | * [updateTables](learningdatalogger.md#updatetables) 30 | * [tableStyle](learningdatalogger.md#static-tablestyle) 31 | 32 | ## Constructors 33 | 34 | ### constructor 35 | 36 | \+ **new LearningDataLogger**(`element`: string | `HTMLElement`, `academy`: [Academy](academy.md)): *[LearningDataLogger](learningdatalogger.md)* 37 | 38 | *Defined in [reimprove/misc/learning_data_logger.ts:9](https://github.com/DevSide/ReImproveJS/blob/2368b25/src/reimprove/misc/learning_data_logger.ts#L9)* 39 | 40 | **Parameters:** 41 | 42 | Name | Type | 43 | ------ | ------ | 44 | `element` | string \| `HTMLElement` | 45 | `academy` | [Academy](academy.md) | 46 | 47 | **Returns:** *[LearningDataLogger](learningdatalogger.md)* 48 | 49 | ___ 50 | 51 | ## Properties 52 | 53 | ### `Private` academy 54 | 55 | ● **academy**: *[Academy](academy.md)* 
56 | 57 | *Defined in [reimprove/misc/learning_data_logger.ts:11](https://github.com/DevSide/ReImproveJS/blob/2368b25/src/reimprove/misc/learning_data_logger.ts#L11)* 58 | 59 | ___ 60 | 61 | ### `Private` memory 62 | 63 | ● **memory**: *`HTMLTableElement`* 64 | 65 | *Defined in [reimprove/misc/learning_data_logger.ts:9](https://github.com/DevSide/ReImproveJS/blob/2368b25/src/reimprove/misc/learning_data_logger.ts#L9)* 66 | 67 | ___ 68 | 69 | ### `Private` parent 70 | 71 | ● **parent**: *`HTMLElement`* 72 | 73 | *Defined in [reimprove/misc/learning_data_logger.ts:7](https://github.com/DevSide/ReImproveJS/blob/2368b25/src/reimprove/misc/learning_data_logger.ts#L7)* 74 | 75 | ___ 76 | 77 | ### `Private` tables 78 | 79 | ● **tables**: *object[]* 80 | 81 | *Defined in [reimprove/misc/learning_data_logger.ts:8](https://github.com/DevSide/ReImproveJS/blob/2368b25/src/reimprove/misc/learning_data_logger.ts#L8)* 82 | 83 | ___ 84 | 85 | ## Methods 86 | 87 | ### createMemoryTable 88 | 89 | ▸ **createMemoryTable**(): *void* 90 | 91 | *Defined in [reimprove/misc/learning_data_logger.ts:25](https://github.com/DevSide/ReImproveJS/blob/2368b25/src/reimprove/misc/learning_data_logger.ts#L25)* 92 | 93 | **Returns:** *void* 94 | 95 | ___ 96 | 97 | ### createTeacherTable 98 | 99 | ▸ **createTeacherTable**(`teacherName`: string): *void* 100 | 101 | *Defined in [reimprove/misc/learning_data_logger.ts:45](https://github.com/DevSide/ReImproveJS/blob/2368b25/src/reimprove/misc/learning_data_logger.ts#L45)* 102 | 103 | **Parameters:** 104 | 105 | Name | Type | 106 | ------ | ------ | 107 | `teacherName` | string | 108 | 109 | **Returns:** *void* 110 | 111 | ___ 112 | 113 | ### dispose 114 | 115 | ▸ **dispose**(): *void* 116 | 117 | *Defined in [reimprove/misc/learning_data_logger.ts:94](https://github.com/DevSide/ReImproveJS/blob/2368b25/src/reimprove/misc/learning_data_logger.ts#L94)* 118 | 119 | **Returns:** *void* 120 | 121 | ___ 122 | 123 | ### updateTables 124 | 125 | ▸ 
**updateTables**(`showMemory`: boolean): *void* 126 | 127 | *Defined in [reimprove/misc/learning_data_logger.ts:73](https://github.com/DevSide/ReImproveJS/blob/2368b25/src/reimprove/misc/learning_data_logger.ts#L73)* 128 | 129 | **Parameters:** 130 | 131 | Name | Type | Default | 132 | ------ | ------ | ------ | 133 | `showMemory` | boolean | false | 134 | 135 | **Returns:** *void* 136 | 137 | ___ 138 | 139 | ### `Static` tableStyle 140 | 141 | ▸ **tableStyle**(`table`: `HTMLTableElement`): *void* 142 | 143 | *Defined in [reimprove/misc/learning_data_logger.ts:100](https://github.com/DevSide/ReImproveJS/blob/2368b25/src/reimprove/misc/learning_data_logger.ts#L100)* 144 | 145 | **Parameters:** 146 | 147 | Name | Type | 148 | ------ | ------ | 149 | `table` | `HTMLTableElement` | 150 | 151 | **Returns:** *void* 152 | 153 | ___ -------------------------------------------------------------------------------- /docs/classes/memory.md: -------------------------------------------------------------------------------- 1 | > ## [ReImproveJS](../README.md) 2 | 3 | [Globals](../globals.md) / [Memory](memory.md) / 4 | 5 | # Class: Memory 6 | 7 | ## Hierarchy 8 | 9 | * **Memory** 10 | 11 | ### Index 12 | 13 | #### Constructors 14 | 15 | * [constructor](memory.md#constructor) 16 | 17 | #### Properties 18 | 19 | * [config](memory.md#config) 20 | * [currentSize](memory.md#currentsize) 21 | * [memory](memory.md#memory) 22 | 23 | #### Accessors 24 | 25 | * [CurrentSize](memory.md#currentsize) 26 | * [Size](memory.md#size) 27 | 28 | #### Methods 29 | 30 | * [merge](memory.md#merge) 31 | * [remember](memory.md#remember) 32 | * [reset](memory.md#reset) 33 | * [sample](memory.md#sample) 34 | * [freeMemento](memory.md#static-private-freememento) 35 | 36 | ## Constructors 37 | 38 | ### constructor 39 | 40 | \+ **new Memory**(`config`: [MemoryConfig](../interfaces/memoryconfig.md)): *[Memory](memory.md)* 41 | 42 | *Defined in 
[reimprove/memory.ts:24](https://github.com/DevSide/ReImproveJS/blob/2368b25/src/reimprove/memory.ts#L24)* 43 | 44 | **Parameters:** 45 | 46 | Name | Type | 47 | ------ | ------ | 48 | `config` | [MemoryConfig](../interfaces/memoryconfig.md) | 49 | 50 | **Returns:** *[Memory](memory.md)* 51 | 52 | ___ 53 | 54 | ## Properties 55 | 56 | ### config 57 | 58 | ● **config**: *[MemoryConfig](../interfaces/memoryconfig.md)* 59 | 60 | *Defined in [reimprove/memory.ts:21](https://github.com/DevSide/ReImproveJS/blob/2368b25/src/reimprove/memory.ts#L21)* 61 | 62 | ___ 63 | 64 | ### currentSize 65 | 66 | ● **currentSize**: *number* 67 | 68 | *Defined in [reimprove/memory.ts:24](https://github.com/DevSide/ReImproveJS/blob/2368b25/src/reimprove/memory.ts#L24)* 69 | 70 | ___ 71 | 72 | ### memory 73 | 74 | ● **memory**: *`Array`* 75 | 76 | *Defined in [reimprove/memory.ts:23](https://github.com/DevSide/ReImproveJS/blob/2368b25/src/reimprove/memory.ts#L23)* 77 | 78 | ___ 79 | 80 | ## Accessors 81 | 82 | ### CurrentSize 83 | 84 | ● **get CurrentSize**(): *number* 85 | 86 | *Defined in [reimprove/memory.ts:54](https://github.com/DevSide/ReImproveJS/blob/2368b25/src/reimprove/memory.ts#L54)* 87 | 88 | **Returns:** *number* 89 | 90 | ___ 91 | 92 | ### Size 93 | 94 | ● **get Size**(): *number* 95 | 96 | *Defined in [reimprove/memory.ts:58](https://github.com/DevSide/ReImproveJS/blob/2368b25/src/reimprove/memory.ts#L58)* 97 | 98 | **Returns:** *number* 99 | 100 | ___ 101 | 102 | ## Methods 103 | 104 | ### merge 105 | 106 | ▸ **merge**(`other`: [Memory](memory.md)): *void* 107 | 108 | *Defined in [reimprove/memory.ts:80](https://github.com/DevSide/ReImproveJS/blob/2368b25/src/reimprove/memory.ts#L80)* 109 | 110 | **Parameters:** 111 | 112 | Name | Type | 113 | ------ | ------ | 114 | `other` | [Memory](memory.md) | 115 | 116 | **Returns:** *void* 117 | 118 | ___ 119 | 120 | ### remember 121 | 122 | ▸ **remember**(`memento`: [Memento](../interfaces/memento.md), `replaceIfFull`: boolean): 
*void* 123 | 124 | *Defined in [reimprove/memory.ts:33](https://github.com/DevSide/ReImproveJS/blob/2368b25/src/reimprove/memory.ts#L33)* 125 | 126 | **Parameters:** 127 | 128 | Name | Type | Default | 129 | ------ | ------ | ------ | 130 | `memento` | [Memento](../interfaces/memento.md) | - | 131 | `replaceIfFull` | boolean | true | 132 | 133 | **Returns:** *void* 134 | 135 | ___ 136 | 137 | ### reset 138 | 139 | ▸ **reset**(): *void* 140 | 141 | *Defined in [reimprove/memory.ts:71](https://github.com/DevSide/ReImproveJS/blob/2368b25/src/reimprove/memory.ts#L71)* 142 | 143 | **Returns:** *void* 144 | 145 | ___ 146 | 147 | ### sample 148 | 149 | ▸ **sample**(`batchSize`: number, `unique`: boolean): *[Memento](../interfaces/memento.md)[]* 150 | 151 | *Defined in [reimprove/memory.ts:46](https://github.com/DevSide/ReImproveJS/blob/2368b25/src/reimprove/memory.ts#L46)* 152 | 153 | **Parameters:** 154 | 155 | Name | Type | Default | 156 | ------ | ------ | ------ | 157 | `batchSize` | number | - | 158 | `unique` | boolean | true | 159 | 160 | **Returns:** *[Memento](../interfaces/memento.md)[]* 161 | 162 | ___ 163 | 164 | ### `Static` `Private` freeMemento 165 | 166 | ▸ **freeMemento**(`memento`: [Memento](../interfaces/memento.md)): *void* 167 | 168 | *Defined in [reimprove/memory.ts:62](https://github.com/DevSide/ReImproveJS/blob/2368b25/src/reimprove/memory.ts#L62)* 169 | 170 | **Parameters:** 171 | 172 | Name | Type | 173 | ------ | ------ | 174 | `memento` | [Memento](../interfaces/memento.md) | 175 | 176 | **Returns:** *void* 177 | 178 | ___ -------------------------------------------------------------------------------- /docs/classes/neuralnetwork.md: -------------------------------------------------------------------------------- 1 | > ## [ReImproveJS](../README.md) 2 | 3 | [Globals](../globals.md) / [NeuralNetwork](neuralnetwork.md) / 4 | 5 | # Class: NeuralNetwork 6 | 7 | ## Hierarchy 8 | 9 | * **NeuralNetwork** 10 | 11 | * 
[ConvolutionalNeuralNetwork](convolutionalneuralnetwork.md) 12 | 13 | ### Index 14 | 15 | #### Constructors 16 | 17 | * [constructor](neuralnetwork.md#constructor) 18 | 19 | #### Properties 20 | 21 | * [inputShape](neuralnetwork.md#protected-inputshape) 22 | * [neuralNetworkLayers](neuralnetwork.md#private-neuralnetworklayers) 23 | 24 | #### Accessors 25 | 26 | * [InputShape](neuralnetwork.md#inputshape) 27 | 28 | #### Methods 29 | 30 | * [addNeuralNetworkLayer](neuralnetwork.md#addneuralnetworklayer) 31 | * [addNeuralNetworkLayers](neuralnetwork.md#addneuralnetworklayers) 32 | * [createLayers](neuralnetwork.md#createlayers) 33 | * [getLayers](neuralnetwork.md#getlayers) 34 | 35 | #### Object literals 36 | 37 | * [DEFAULT_LAYER](neuralnetwork.md#static-private-default_layer) 38 | 39 | ## Constructors 40 | 41 | ### constructor 42 | 43 | \+ **new NeuralNetwork**(): *[NeuralNetwork](neuralnetwork.md)* 44 | 45 | *Defined in [reimprove/networks.ts:62](https://github.com/DevSide/ReImproveJS/blob/2368b25/src/reimprove/networks.ts#L62)* 46 | 47 | **Returns:** *[NeuralNetwork](neuralnetwork.md)* 48 | 49 | ___ 50 | 51 | ## Properties 52 | 53 | ### `Protected` inputShape 54 | 55 | ● **inputShape**: *number[]* 56 | 57 | *Defined in [reimprove/networks.ts:55](https://github.com/DevSide/ReImproveJS/blob/2368b25/src/reimprove/networks.ts#L55)* 58 | 59 | ___ 60 | 61 | ### `Private` neuralNetworkLayers 62 | 63 | ● **neuralNetworkLayers**: *[NeuralNetworkLayer](../interfaces/neuralnetworklayer.md)[]* 64 | 65 | *Defined in [reimprove/networks.ts:56](https://github.com/DevSide/ReImproveJS/blob/2368b25/src/reimprove/networks.ts#L56)* 66 | 67 | ___ 68 | 69 | ## Accessors 70 | 71 | ### InputShape 72 | 73 | ● **set InputShape**(`shape`: number[]): *void* 74 | 75 | *Defined in [reimprove/networks.ts:85](https://github.com/DevSide/ReImproveJS/blob/2368b25/src/reimprove/networks.ts#L85)* 76 | 77 | **Parameters:** 78 | 79 | Name | Type | 80 | ------ | ------ | 81 | `shape` | number[] | 82 | 
83 | **Returns:** *void* 84 | 85 | ___ 86 | 87 | ## Methods 88 | 89 | ### addNeuralNetworkLayer 90 | 91 | ▸ **addNeuralNetworkLayer**(`layer`: number | [NeuralNetworkLayer](../interfaces/neuralnetworklayer.md)): *void* 92 | 93 | *Defined in [reimprove/networks.ts:69](https://github.com/DevSide/ReImproveJS/blob/2368b25/src/reimprove/networks.ts#L69)* 94 | 95 | **Parameters:** 96 | 97 | Name | Type | 98 | ------ | ------ | 99 | `layer` | number \| [NeuralNetworkLayer](../interfaces/neuralnetworklayer.md) | 100 | 101 | **Returns:** *void* 102 | 103 | ___ 104 | 105 | ### addNeuralNetworkLayers 106 | 107 | ▸ **addNeuralNetworkLayers**(`layers`: `Array`): *void* 108 | 109 | *Defined in [reimprove/networks.ts:81](https://github.com/DevSide/ReImproveJS/blob/2368b25/src/reimprove/networks.ts#L81)* 110 | 111 | **Parameters:** 112 | 113 | Name | Type | 114 | ------ | ------ | 115 | `layers` | `Array` | 116 | 117 | **Returns:** *void* 118 | 119 | ___ 120 | 121 | ### createLayers 122 | 123 | ▸ **createLayers**(`includeInputShape`: boolean): *`Array`* 124 | 125 | *Defined in [reimprove/networks.ts:89](https://github.com/DevSide/ReImproveJS/blob/2368b25/src/reimprove/networks.ts#L89)* 126 | 127 | **Parameters:** 128 | 129 | Name | Type | Default | 130 | ------ | ------ | ------ | 131 | `includeInputShape` | boolean | true | 132 | 133 | **Returns:** *`Array`* 134 | 135 | ___ 136 | 137 | ### getLayers 138 | 139 | ▸ **getLayers**(): *[Layer](../interfaces/layer.md)[]* 140 | 141 | *Defined in [reimprove/networks.ts:99](https://github.com/DevSide/ReImproveJS/blob/2368b25/src/reimprove/networks.ts#L99)* 142 | 143 | **Returns:** *[Layer](../interfaces/layer.md)[]* 144 | 145 | ___ 146 | 147 | ## Object literals 148 | 149 | ### `Static` `Private` DEFAULT_LAYER 150 | 151 | ### ■ **DEFAULT_LAYER**: *object* 152 | 153 | *Defined in [reimprove/networks.ts:58](https://github.com/DevSide/ReImproveJS/blob/2368b25/src/reimprove/networks.ts#L58)* 154 | 155 | ### activation 156 | 157 | ● 
**activation**: *string* = "relu" 158 | 159 | *Defined in [reimprove/networks.ts:60](https://github.com/DevSide/ReImproveJS/blob/2368b25/src/reimprove/networks.ts#L60)* 160 | 161 | ### type 162 | 163 | ● **type**: *"dense"* = "dense" 164 | 165 | *Defined in [reimprove/networks.ts:61](https://github.com/DevSide/ReImproveJS/blob/2368b25/src/reimprove/networks.ts#L61)* 166 | 167 | ### units 168 | 169 | ● **units**: *number* = 32 170 | 171 | *Defined in [reimprove/networks.ts:59](https://github.com/DevSide/ReImproveJS/blob/2368b25/src/reimprove/networks.ts#L59)* 172 | 173 | ___ -------------------------------------------------------------------------------- /docs/classes/abstractagent.md: -------------------------------------------------------------------------------- 1 | > ## [ReImproveJS](../README.md) 2 | 3 | [Globals](../globals.md) / [AbstractAgent](abstractagent.md) / 4 | 5 | # Class: AbstractAgent 6 | 7 | ## Hierarchy 8 | 9 | * **AbstractAgent** 10 | 11 | * [DQAgent](dqagent.md) 12 | 13 | * [QAgent](qagent.md) 14 | 15 | ### Index 16 | 17 | #### Constructors 18 | 19 | * [constructor](abstractagent.md#protected-constructor) 20 | 21 | #### Properties 22 | 23 | * [agentConfig](abstractagent.md#protected-agentconfig) 24 | * [name](abstractagent.md#private-optional-name) 25 | 26 | #### Accessors 27 | 28 | * [AgentConfig](abstractagent.md#agentconfig) 29 | * [Name](abstractagent.md#name) 30 | 31 | #### Methods 32 | 33 | * [getTrackingInformation](abstractagent.md#abstract-gettrackinginformation) 34 | * [infer](abstractagent.md#abstract-infer) 35 | * [reset](abstractagent.md#abstract-reset) 36 | * [setAgentConfig](abstractagent.md#protected-setagentconfig) 37 | 38 | ## Constructors 39 | 40 | ### `Protected` constructor 41 | 42 | \+ **new AbstractAgent**(`agentConfig?`: [AgentConfig](../interfaces/agentconfig.md), `name?`: string): *[AbstractAgent](abstractagent.md)* 43 | 44 | *Defined in 
[reimprove/algorithms/abstract_agent.ts:10](https://github.com/DevSide/ReImproveJS/blob/2368b25/src/reimprove/algorithms/abstract_agent.ts#L10)* 45 | 46 | **Parameters:** 47 | 48 | Name | Type | 49 | ------ | ------ | 50 | `agentConfig?` | [AgentConfig](../interfaces/agentconfig.md) | 51 | `name?` | string | 52 | 53 | **Returns:** *[AbstractAgent](abstractagent.md)* 54 | 55 | ___ 56 | 57 | ## Properties 58 | 59 | ### `Protected` agentConfig 60 | 61 | ● **agentConfig**: *[AgentConfig](../interfaces/agentconfig.md)* 62 | 63 | *Defined in [reimprove/algorithms/abstract_agent.ts:10](https://github.com/DevSide/ReImproveJS/blob/2368b25/src/reimprove/algorithms/abstract_agent.ts#L10)* 64 | 65 | ___ 66 | 67 | ### `Private` `Optional` name 68 | 69 | ● **name**? : *string* 70 | 71 | *Defined in [reimprove/algorithms/abstract_agent.ts:12](https://github.com/DevSide/ReImproveJS/blob/2368b25/src/reimprove/algorithms/abstract_agent.ts#L12)* 72 | 73 | ___ 74 | 75 | ## Accessors 76 | 77 | ### AgentConfig 78 | 79 | ● **get AgentConfig**(): *[AgentConfig](../interfaces/agentconfig.md)* 80 | 81 | *Defined in [reimprove/algorithms/abstract_agent.ts:16](https://github.com/DevSide/ReImproveJS/blob/2368b25/src/reimprove/algorithms/abstract_agent.ts#L16)* 82 | 83 | **Returns:** *[AgentConfig](../interfaces/agentconfig.md)* 84 | 85 | ___ 86 | 87 | ### Name 88 | 89 | ● **get Name**(): *string* 90 | 91 | *Defined in [reimprove/algorithms/abstract_agent.ts:19](https://github.com/DevSide/ReImproveJS/blob/2368b25/src/reimprove/algorithms/abstract_agent.ts#L19)* 92 | 93 | **Returns:** *string* 94 | 95 | ● **set Name**(`name`: string): *void* 96 | 97 | *Defined in [reimprove/algorithms/abstract_agent.ts:20](https://github.com/DevSide/ReImproveJS/blob/2368b25/src/reimprove/algorithms/abstract_agent.ts#L20)* 98 | 99 | **Parameters:** 100 | 101 | Name | Type | 102 | ------ | ------ | 103 | `name` | string | 104 | 105 | **Returns:** *void* 106 | 107 | ___ 108 | 109 | ## Methods 110 | 111 | ### 
`Abstract` getTrackingInformation 112 | 113 | ▸ **getTrackingInformation**(): *[AgentTrackingInformation](../interfaces/agenttrackinginformation.md)* 114 | 115 | *Defined in [reimprove/algorithms/abstract_agent.ts:22](https://github.com/DevSide/ReImproveJS/blob/2368b25/src/reimprove/algorithms/abstract_agent.ts#L22)* 116 | 117 | **Returns:** *[AgentTrackingInformation](../interfaces/agenttrackinginformation.md)* 118 | 119 | ___ 120 | 121 | ### `Abstract` infer 122 | 123 | ▸ **infer**(`input`: number[] | number[][] | [QAction](qaction.md), `epsilon?`: number, `keepTensors?`: boolean): *number | [QTransition](qtransition.md)* 124 | 125 | *Defined in [reimprove/algorithms/abstract_agent.ts:26](https://github.com/DevSide/ReImproveJS/blob/2368b25/src/reimprove/algorithms/abstract_agent.ts#L26)* 126 | 127 | **Parameters:** 128 | 129 | Name | Type | 130 | ------ | ------ | 131 | `input` | number[] \| number[][] \| [QAction](qaction.md) | 132 | `epsilon?` | number | 133 | `keepTensors?` | boolean | 134 | 135 | **Returns:** *number | [QTransition](qtransition.md)* 136 | 137 | ___ 138 | 139 | ### `Abstract` reset 140 | 141 | ▸ **reset**(): *void* 142 | 143 | *Defined in [reimprove/algorithms/abstract_agent.ts:23](https://github.com/DevSide/ReImproveJS/blob/2368b25/src/reimprove/algorithms/abstract_agent.ts#L23)* 144 | 145 | **Returns:** *void* 146 | 147 | ___ 148 | 149 | ### `Protected` setAgentConfig 150 | 151 | ▸ **setAgentConfig**(`config`: [AgentConfig](../interfaces/agentconfig.md)): *void* 152 | 153 | *Defined in [reimprove/algorithms/abstract_agent.ts:17](https://github.com/DevSide/ReImproveJS/blob/2368b25/src/reimprove/algorithms/abstract_agent.ts#L17)* 154 | 155 | **Parameters:** 156 | 157 | Name | Type | 158 | ------ | ------ | 159 | `config` | [AgentConfig](../interfaces/agentconfig.md) | 160 | 161 | **Returns:** *void* 162 | 163 | ___ -------------------------------------------------------------------------------- /docs/classes/qtransition.md: 
-------------------------------------------------------------------------------- 1 | > ## [ReImproveJS](../README.md) 2 | 3 | [Globals](../globals.md) / [QTransition](qtransition.md) / 4 | 5 | # Class: QTransition 6 | 7 | ## Hierarchy 8 | 9 | * **QTransition** 10 | 11 | ### Index 12 | 13 | #### Constructors 14 | 15 | * [constructor](qtransition.md#constructor) 16 | 17 | #### Properties 18 | 19 | * [QValue](qtransition.md#private-qvalue) 20 | * [action](qtransition.md#private-action) 21 | * [from](qtransition.md#private-from) 22 | * [id](qtransition.md#private-id) 23 | * [to](qtransition.md#private-to) 24 | * [transitionId](qtransition.md#static-private-transitionid) 25 | 26 | #### Accessors 27 | 28 | * [Action](qtransition.md#action) 29 | * [From](qtransition.md#from) 30 | * [Id](qtransition.md#id) 31 | * [Q](qtransition.md#q) 32 | * [To](qtransition.md#to) 33 | 34 | ## Constructors 35 | 36 | ### constructor 37 | 38 | \+ **new QTransition**(`from`: [QState](qstate.md), `to`: [QState](qstate.md), `action`: [QAction](qaction.md)): *[QTransition](qtransition.md)* 39 | 40 | *Defined in [reimprove/algorithms/q/qtransition.ts:9](https://github.com/DevSide/ReImproveJS/blob/2368b25/src/reimprove/algorithms/q/qtransition.ts#L9)* 41 | 42 | **Parameters:** 43 | 44 | Name | Type | 45 | ------ | ------ | 46 | `from` | [QState](qstate.md) | 47 | `to` | [QState](qstate.md) | 48 | `action` | [QAction](qaction.md) | 49 | 50 | **Returns:** *[QTransition](qtransition.md)* 51 | 52 | ___ 53 | 54 | ## Properties 55 | 56 | ### `Private` QValue 57 | 58 | ● **QValue**: *number* 59 | 60 | *Defined in [reimprove/algorithms/q/qtransition.ts:6](https://github.com/DevSide/ReImproveJS/blob/2368b25/src/reimprove/algorithms/q/qtransition.ts#L6)* 61 | 62 | ___ 63 | 64 | ### `Private` action 65 | 66 | ● **action**: *[QAction](qaction.md)* 67 | 68 | *Defined in 
[reimprove/algorithms/q/qtransition.ts:11](https://github.com/DevSide/ReImproveJS/blob/2368b25/src/reimprove/algorithms/q/qtransition.ts#L11)* 69 | 70 | ___ 71 | 72 | ### `Private` from 73 | 74 | ● **from**: *[QState](qstate.md)* 75 | 76 | *Defined in [reimprove/algorithms/q/qtransition.ts:11](https://github.com/DevSide/ReImproveJS/blob/2368b25/src/reimprove/algorithms/q/qtransition.ts#L11)* 77 | 78 | ___ 79 | 80 | ### `Private` id 81 | 82 | ● **id**: *number* 83 | 84 | *Defined in [reimprove/algorithms/q/qtransition.ts:7](https://github.com/DevSide/ReImproveJS/blob/2368b25/src/reimprove/algorithms/q/qtransition.ts#L7)* 85 | 86 | ___ 87 | 88 | ### `Private` to 89 | 90 | ● **to**: *[QState](qstate.md)* 91 | 92 | *Defined in [reimprove/algorithms/q/qtransition.ts:11](https://github.com/DevSide/ReImproveJS/blob/2368b25/src/reimprove/algorithms/q/qtransition.ts#L11)* 93 | 94 | ___ 95 | 96 | ### `Static` `Private` transitionId 97 | 98 | ■ **transitionId**: *number* = 0 99 | 100 | *Defined in [reimprove/algorithms/q/qtransition.ts:9](https://github.com/DevSide/ReImproveJS/blob/2368b25/src/reimprove/algorithms/q/qtransition.ts#L9)* 101 | 102 | ___ 103 | 104 | ## Accessors 105 | 106 | ### Action 107 | 108 | ● **get Action**(): *[QAction](qaction.md)* 109 | 110 | *Defined in [reimprove/algorithms/q/qtransition.ts:21](https://github.com/DevSide/ReImproveJS/blob/2368b25/src/reimprove/algorithms/q/qtransition.ts#L21)* 111 | 112 | **Returns:** *[QAction](qaction.md)* 113 | 114 | ___ 115 | 116 | ### From 117 | 118 | ● **get From**(): *[QState](qstate.md)* 119 | 120 | *Defined in [reimprove/algorithms/q/qtransition.ts:19](https://github.com/DevSide/ReImproveJS/blob/2368b25/src/reimprove/algorithms/q/qtransition.ts#L19)* 121 | 122 | **Returns:** *[QState](qstate.md)* 123 | 124 | ● **set From**(`state`: [QState](qstate.md)): *void* 125 | 126 | *Defined in 
[reimprove/algorithms/q/qtransition.ts:24](https://github.com/DevSide/ReImproveJS/blob/2368b25/src/reimprove/algorithms/q/qtransition.ts#L24)* 127 | 128 | **Parameters:** 129 | 130 | Name | Type | 131 | ------ | ------ | 132 | `state` | [QState](qstate.md) | 133 | 134 | **Returns:** *void* 135 | 136 | ___ 137 | 138 | ### Id 139 | 140 | ● **get Id**(): *number* 141 | 142 | *Defined in [reimprove/algorithms/q/qtransition.ts:26](https://github.com/DevSide/ReImproveJS/blob/2368b25/src/reimprove/algorithms/q/qtransition.ts#L26)* 143 | 144 | **Returns:** *number* 145 | 146 | ___ 147 | 148 | ### Q 149 | 150 | ● **get Q**(): *number* 151 | 152 | *Defined in [reimprove/algorithms/q/qtransition.ts:16](https://github.com/DevSide/ReImproveJS/blob/2368b25/src/reimprove/algorithms/q/qtransition.ts#L16)* 153 | 154 | **Returns:** *number* 155 | 156 | ● **set Q**(`qvalue`: number): *void* 157 | 158 | *Defined in [reimprove/algorithms/q/qtransition.ts:17](https://github.com/DevSide/ReImproveJS/blob/2368b25/src/reimprove/algorithms/q/qtransition.ts#L17)* 159 | 160 | **Parameters:** 161 | 162 | Name | Type | 163 | ------ | ------ | 164 | `qvalue` | number | 165 | 166 | **Returns:** *void* 167 | 168 | ___ 169 | 170 | ### To 171 | 172 | ● **get To**(): *[QState](qstate.md)* 173 | 174 | *Defined in [reimprove/algorithms/q/qtransition.ts:20](https://github.com/DevSide/ReImproveJS/blob/2368b25/src/reimprove/algorithms/q/qtransition.ts#L20)* 175 | 176 | **Returns:** *[QState](qstate.md)* 177 | 178 | ● **set To**(`state`: [QState](qstate.md)): *void* 179 | 180 | *Defined in [reimprove/algorithms/q/qtransition.ts:23](https://github.com/DevSide/ReImproveJS/blob/2368b25/src/reimprove/algorithms/q/qtransition.ts#L23)* 181 | 182 | **Parameters:** 183 | 184 | Name | Type | 185 | ------ | ------ | 186 | `state` | [QState](qstate.md) | 187 | 188 | **Returns:** *void* 189 | 190 | ___ -------------------------------------------------------------------------------- 
/**
 * Displays the learning statistics of an academy as HTML tables.
 *
 * One table is created per teacher name exposed by `academy.Teachers`, with one
 * body row per student reported by `academy.getTeacherData(name).students`.
 * An additional single-row table shows the figures returned by the
 * TensorFlow.js `memory()` call. Every table is appended to the parent element
 * given to the constructor.
 */
export class LearningDataLogger {
    private parent: HTMLElement;
    private tables: {teacherName: string, table: HTMLTableElement}[];
    private memory: HTMLTableElement;

    /**
     * @param element Parent element, or the id of the parent element, that will host the tables.
     * @param academy Academy whose teachers and students are displayed.
     * @throws Error when no parent element can be resolved.
     */
    constructor(element: string | HTMLElement, private academy: Academy) {
        const parent = typeof element === "string" ? document.getElementById(element) : element;
        if (!parent) {
            // Fail here with a readable message instead of a TypeError on the first appendChild.
            throw new Error(typeof element === "string"
                ? `LearningDataLogger: no DOM element with id "${element}"`
                : "LearningDataLogger: a parent element is required");
        }
        this.parent = parent;

        this.tables = [];
        this.academy.Teachers.forEach(name => this.createTeacherTable(name));
        this.tables.forEach(val => this.parent.appendChild(val.table));
        this.createMemoryTable();
        this.parent.appendChild(this.memory);
    }

    /** Builds the (initially empty) memory usage table and stores it in `this.memory`. */
    createMemoryTable(): void {
        this.memory = document.createElement('table');
        const thead = this.memory.createTHead();
        const tbody = this.memory.createTBody();

        LearningDataLogger.fillRow(thead.insertRow(0), [
            "Bytes allocated (undisposed)",
            "Unique tensors allocated",
            "Data buffers allocated",
            "Unreliable"
        ]);
        LearningDataLogger.fillRow(tbody.insertRow(0), ["", "", "", ""]);

        LearningDataLogger.tableStyle(this.memory);
    }

    /**
     * Builds the table of one teacher, with one empty row per current student,
     * and registers it in `this.tables`. The table is not attached to the DOM here.
     *
     * @param teacherName Name used to query the academy for the teacher data.
     */
    createTeacherTable(teacherName: string): void {
        const table = document.createElement('table');
        const thead = table.createTHead();
        const tbody = table.createTBody();

        LearningDataLogger.fillRow(thead.insertRow(0), [
            "Name",
            "Q loss average",
            "Average reward",
            "Epsilon",
            "Lesson number"
        ]);

        const studentsQuantity = this.academy.getTeacherData(teacherName).students.length;
        for (let i = 0; i < studentsQuantity; ++i) {
            LearningDataLogger.fillRow(tbody.insertRow(i), ["", "", "", "", ""]);
        }

        LearningDataLogger.tableStyle(table);
        this.tables.push({teacherName: teacherName, table: table});
    }

    /**
     * Refreshes every teacher table with the current academy data.
     *
     * @param showMemory When true, the memory table is refreshed as well.
     */
    updateTables(showMemory: boolean = false): void {
        this.tables.forEach(entry => {
            const tData = this.academy.getTeacherData(entry.teacherName);
            const body = entry.table.tBodies.item(0);
            tData.students.forEach((data, index) => {
                // Rows are created once at construction time; add the missing ones
                // if more students are reported now than there are rows.
                while (body.rows.length <= index) {
                    LearningDataLogger.fillRow(body.insertRow(-1), ["", "", "", "", ""]);
                }
                const cells = body.rows.item(index).cells;
                cells.item(0).innerHTML = data.name;
                // Values are truncated to their first five characters for display.
                cells.item(1).innerHTML = data.averageLoss.toString().slice(0, 5);
                cells.item(2).innerHTML = data.averageReward.toString().slice(0, 5);
                cells.item(3).innerHTML = tData.epsilon.toString().slice(0, 5);
                cells.item(4).innerHTML = tData.lessonNumber.toString();
            });
        });

        if (showMemory) {
            const tfMemory = memory();
            const cells = this.memory.tBodies.item(0).rows.item(0).cells;
            cells.item(0).innerHTML = tfMemory.numBytes.toString();
            cells.item(1).innerHTML = tfMemory.numTensors.toString();
            cells.item(2).innerHTML = tfMemory.numDataBuffers.toString();
            // `unreliable` may be missing from the report; treat a missing flag as false
            // rather than throwing on `undefined.toString()`.
            cells.item(3).innerHTML = String(Boolean(tfMemory.unreliable));
        }
    }

    /**
     * Detaches every table created by this logger (teacher tables and the memory
     * table) from the parent element. Calling it more than once is harmless.
     */
    dispose(): void {
        this.tables.forEach(entry => this.detach(entry.table));
        this.detach(this.memory);
    }

    /** Applies the logger's inline styling to a table, its header and all its bodies. */
    static tableStyle(table: HTMLTableElement): void {
        table.style.border = "medium solid #6495ed";
        table.style.borderCollapse = "collapse";

        table.tHead.style.fontFamily = "monospace";
        table.tHead.style.border = "thin solid #6495ed";
        table.tHead.style.padding = "5px";
        table.tHead.style.backgroundColor = "#d0e3fa";
        table.tHead.style.textAlign = "center";
        table.tHead.style.margin = "8px";

        for (let i = 0; i < table.tBodies.length; ++i) {
            const item = table.tBodies.item(i);
            item.style.fontFamily = "sans-serif";
            item.style.border = "thin solid #6495ed";
            item.style.padding = "5px";
            item.style.textAlign = "center";
            item.style.backgroundColor = "#ffffff";
            item.style.margin = "3px";
        }
    }

    /** Inserts one cell per entry of `contents` into `row`, in order. */
    private static fillRow(row: HTMLTableRowElement, contents: string[]): void {
        contents.forEach((content, index) => {
            row.insertCell(index).innerHTML = content;
        });
    }

    /** Removes `table` from the parent element, only if it is currently attached to it. */
    private detach(table: HTMLTableElement): void {
        if (table && table.parentNode === this.parent) {
            this.parent.removeChild(table);
        }
    }
}
QAgent({dataHash: hash, gamma: 0.9}, qmatrix); 31 | }); 32 | 33 | it("should do the right transition and end in the right state", () => { 34 | let trans = qagent.infer(); 35 | 36 | expect(trans.Action.Name).to.be.eq("RIGHT"); 37 | qagent.learn(); 38 | expect(trans.Q).to.be.eq(0.); 39 | expect(qagent.CurrentState.Data).to.be.deep.eq({x: 1, y: 0}); 40 | }); 41 | 42 | it("Should have the right history", () => { 43 | for (let i = 0; i < 10; ++i) { 44 | qagent.infer(); 45 | qagent.learn(); 46 | if (!qagent.isPerforming()) 47 | qagent.restart(); 48 | } 49 | 50 | qagent.restart(); 51 | while (qagent.isPerforming()) { 52 | qagent.infer(); 53 | } 54 | 55 | expect(qagent.History[0].Q).to.be.greaterThan(0.); 56 | expect(qagent.History[1].Q).to.be.greaterThan(0.); 57 | expect(qagent.History[1].Q).to.be.greaterThan(qagent.History[0].Q); 58 | }); 59 | 60 | it("Should create the right path", () => { 61 | qmatrix = new QMatrix(hash); 62 | qmatrix.registerAction(new QAction("LEFT")); 63 | qmatrix.registerAction(new QAction("RIGHT")); 64 | qmatrix.registerAction(new QAction("JUMPRIGHT")); 65 | let middle = qmatrix.registerState({x: 1, y: 0}); 66 | let s3 = qmatrix.registerState({x: 2, y: 0}, 0.); 67 | let final = qmatrix.registerState({x: 3, y: 0}, -1.0).setFinal(); 68 | let s4 = qmatrix.registerState({x: 2, y: 1}, 0); 69 | let finalgood = qmatrix.registerState({x: 2, y: 2}, 1.0).setFinal(); 70 | qmatrix.registerTransition("RIGHT", qmatrix.registerState({x: 0, y: 0}), middle, "LEFT"); 71 | qmatrix.registerTransition("RIGHT", middle, s3, "LEFT"); 72 | qmatrix.registerTransition("RIGHT", s3, final, "LEFT"); 73 | qmatrix.registerTransition("JUMPRIGHT", s3, s4, "LEFT"); 74 | qmatrix.registerTransition("RIGHT", s4, finalgood, "LEFT"); 75 | qmatrix.setStateAsInitial({x: 0, y: 0}); 76 | 77 | qagent = new QAgent({dataHash: hash, gamma: 0.9}, qmatrix); 78 | qagent.CurrentState = qmatrix.InitialState; 79 | 80 | for (let i = 0; i < 5; ++i) { 81 | while (qagent.isPerforming()) { 82 | 
qagent.infer(); 83 | qagent.learn(); 84 | } 85 | 86 | qagent.restart(); 87 | } 88 | 89 | while (qagent.isPerforming()) { 90 | qagent.infer(); 91 | } 92 | 93 | expect(qagent.History.length).to.be.eq(4); 94 | expect(qagent.History.map(h => h.Action.Name)).to.be.deep.eq(["RIGHT", "RIGHT", "JUMPRIGHT", "RIGHT"]); 95 | expect(s3.takeAction(qmatrix.action("RIGHT")).Q).to.be.at.most(0.); 96 | expect(qagent.CurrentState).to.be.deep.eq(finalgood); 97 | }); 98 | 99 | it("should create dynamically the qmatrix", () => { 100 | const data = {x: 0, y: 0}; 101 | const gamma = 0.9; 102 | 103 | qagent = new QAgent({ 104 | dataHash: hash, 105 | initialState: data, 106 | gamma: gamma, 107 | createMatrixDynamically: true, 108 | actions: ["LEFT", "RIGHT"] 109 | }); 110 | 111 | while (qagent.isPerforming()) { 112 | switch (qagent.infer().Action.Name) { 113 | case "LEFT": 114 | data.x -= data.x > 0 ? 1 : 0; 115 | break; 116 | case "RIGHT": 117 | data.x += data.x < 4 ? 1 : 0; 118 | break; 119 | } 120 | 121 | 122 | qagent.learn(data); 123 | console.log(`State : ${qagent.CurrentState.Data.x}`); 124 | 125 | if (data.x === 3 && data.y === 0) 126 | qagent.finalState(1.0); 127 | } 128 | 129 | expect(data).to.be.deep.equal({x: 3, y: 0}); 130 | }); 131 | 132 | it("should produce a good graph output", () => { 133 | const data = {x: 0, y: 0}; 134 | const gamma = 0.9; 135 | 136 | qagent = new QAgent({ 137 | dataHash: hash, 138 | initialState: data, 139 | gamma: gamma, 140 | createMatrixDynamically: true, 141 | actions: ["LEFT", "RIGHT"] 142 | }); 143 | 144 | let graph = qagent.getStatesGraph(); 145 | 146 | expect(graph.nodes.length).to.be.equal(1); 147 | expect(graph.edges.length).to.be.equal(0); 148 | 149 | qagent.infer(); 150 | qagent.learn({x:1, y:0}); 151 | 152 | graph = qagent.getStatesGraph(); 153 | expect(graph.nodes.length).to.be.equal(2); 154 | expect(graph.edges.length).to.be.equal(1); 155 | }); 156 | }); -------------------------------------------------------------------------------- 
/** Options shared by every layer description handled by the network builders. */
export interface Layer {
    activation?: string | any;
    useBias?: boolean;
    name?: string;
    /** Do not use this field */
    inputShape?: number[];
}

/** Common shape of the layers stored by {@link ConvolutionalNeuralNetwork}. */
export interface ConvolutionalNetworkLayer extends Layer {
    type: 'convolutional' | 'maxpooling';
}

/** Description of a 2D convolution layer. */
export interface ConvolutionalLayer extends ConvolutionalNetworkLayer {
    type: 'convolutional';
    filters: number;
    strides?: number | number[]
    kernelSize: number;
}

/** Description of a 2D max-pooling layer. */
export interface MaxPooling2DLayer extends ConvolutionalNetworkLayer {
    poolSize?: number;
    strides?: [number, number];
    type: 'maxpooling'
}

/** Common shape of the layers stored by {@link NeuralNetwork}. */
export interface NeuralNetworkLayer extends Layer {
    units: number;
    type: 'dense' | 'dropout' | 'flatten'
}

/** Description of a fully connected layer. */
export interface DenseLayer extends NeuralNetworkLayer {
    type: 'dense';
}

/** Description of a dropout layer. */
export interface DropoutLayer extends NeuralNetworkLayer {
    type: 'dropout';
    rate: number;
    seed?: number;
}

/** Description of the flatten layer placed between convolutional and dense layers. */
export interface FlattenLayer extends Layer {
    batchInputShape?: number[];
    batchSize?: number;
    trainable?: boolean;
    updatable?: boolean;
    weights?: Tensor[];
    type: "flatten";
}

/**
 * Ordered list of layer descriptions that can be turned into TensorFlow.js layers.
 */
export class NeuralNetwork {
    protected inputShape: number[];
    private readonly neuralNetworkLayers: NeuralNetworkLayer[];

    private static DEFAULT_LAYER: NeuralNetworkLayer = {
        units: 32,
        activation: "relu",
        type: 'dense'
    };

    constructor() {
        this.neuralNetworkLayers = [];
        this.inputShape = [0];
    }

    /**
     * Appends a layer description.
     *
     * @param layer Either a number of units (a dense layer with the default activation
     * is created) or a layer description, whose missing fields are taken from the default layer.
     */
    addNeuralNetworkLayer(layer: number | NeuralNetworkLayer): void {
        if (typeof layer === 'number') {
            this.neuralNetworkLayers.push({
                units: layer,
                activation: NeuralNetwork.DEFAULT_LAYER.activation,
                type: 'dense'
            });
        } else {
            this.neuralNetworkLayers.push({...NeuralNetwork.DEFAULT_LAYER, ...layer});
        }
    }

    /** Appends several layers, in order; see {@link NeuralNetwork.addNeuralNetworkLayer}. */
    addNeuralNetworkLayers(layers: Array<number | NeuralNetworkLayer>): void {
        layers.forEach(l => this.addNeuralNetworkLayer(l));
    }

    set InputShape(shape: number[]) {
        this.inputShape = shape;
    }

    /**
     * Instantiates one TensorFlow.js layer per stored description.
     *
     * @param includeInputShape When true, the input shape is written on the first stored
     * description before the layers are built. Nothing is written when no layer was added.
     * @returns The created layers, in insertion order (empty when no layer was added).
     * NOTE(review): the element type is not recoverable from this listing, hence `any`.
     */
    createLayers(includeInputShape: boolean = true): Array<any> {
        const genLayers = [];
        // Guard: without it an empty network fails with a TypeError on `[0].inputShape`.
        if (includeInputShape && this.neuralNetworkLayers.length > 0)
            this.neuralNetworkLayers[0].inputShape = this.inputShape;
        for (const layer of this.neuralNetworkLayers) {
            genLayers.push(NeuralNetwork.buildLayer(layer));
        }
        return genLayers;
    }

    /** @returns The stored layer descriptions (the internal array, not a copy). */
    getLayers(): Layer[] { return this.neuralNetworkLayers; }

    /** Maps one description to the TensorFlow.js factory matching its `type`. */
    private static buildLayer(layer: NeuralNetworkLayer): any {
        switch (layer.type) {
            case 'dense':
                return layers.dense(layer);
            case 'flatten':
                // 'flatten' is an allowed type; it must not be built as a dropout layer.
                return layers.flatten(layer);
            default:
                return layers.dropout(layer as DropoutLayer);
        }
    }
}

/**
 * @deprecated Do not use convolutional networks with ReImproveJS for now, they are not fully implemented and tested in
 * the library.
 */
export class ConvolutionalNeuralNetwork extends NeuralNetwork {
    private readonly convolutionalLayers: ConvolutionalNetworkLayer[];
    private flattenLayer: FlattenLayer;

    private static DEFAULT_CONV_LAYER: ConvolutionalLayer = {
        filters: 32,
        kernelSize: 3,
        activation: 'relu',
        type: 'convolutional'
    };

    private static DEFAULT_POOLING_LAYER: MaxPooling2DLayer = {
        poolSize: 2,
        strides: null,
        type: "maxpooling"
    };

    constructor() {
        super();
        this.convolutionalLayers = [];
        this.flattenLayer = {type: 'flatten'};
    }

    /** Appends a max-pooling layer; missing fields are taken from the default pooling layer. */
    addMaxPooling2DLayer(layer?: MaxPooling2DLayer): void {
        this.convolutionalLayers.push({...ConvolutionalNeuralNetwork.DEFAULT_POOLING_LAYER, ...layer});
    }

    /**
     * Appends a convolutional layer.
     *
     * @param layer Either a number of filters (default activation and kernel size are used)
     * or a layer description, whose missing fields are taken from the default convolutional layer.
     */
    addConvolutionalLayer(layer: number | ConvolutionalNetworkLayer): void {
        if (typeof layer === 'number') {
            // Typed as ConvolutionalLayer so `filters`/`kernelSize` are checked, then stored
            // in the wider ConvolutionalNetworkLayer list.
            const convolution: ConvolutionalLayer = {
                filters: layer,
                activation: ConvolutionalNeuralNetwork.DEFAULT_CONV_LAYER.activation,
                type: 'convolutional',
                kernelSize: ConvolutionalNeuralNetwork.DEFAULT_CONV_LAYER.kernelSize
            };
            this.convolutionalLayers.push(convolution);
        } else {
            this.convolutionalLayers.push({...ConvolutionalNeuralNetwork.DEFAULT_CONV_LAYER, ...layer});
        }
    }

    /** Appends several convolutional layers, in order. */
    addConvolutionalLayers(layers: Array<number | ConvolutionalNetworkLayer>): void {
        layers.forEach(l => this.addConvolutionalLayer(l));
    }

    /**
     * Instantiates the convolutional/pooling layers, then the flatten layer, then the
     * layers of the parent network (built without input shape).
     *
     * NOTE(review): as in the original code, `includeInputShape` is not consulted here;
     * the input shape is always written on the first convolutional description.
     */
    createLayers(includeInputShape: boolean = true): Array<any> {
        const genLayers = [];
        // Guard: without it a network with no convolutional layer fails with a TypeError.
        if (this.convolutionalLayers.length > 0)
            this.convolutionalLayers[0].inputShape = this.inputShape;
        for (const layer of this.convolutionalLayers) {
            genLayers.push(layer.type === "convolutional"
                ? layers.conv2d(layer as ConvolutionalLayer)
                : layers.maxPooling2d(layer as MaxPooling2DLayer));
        }

        genLayers.push(layers.flatten(this.flattenLayer));

        return genLayers.concat(super.createLayers(false));
    }

    set FlattenLayer(layer: FlattenLayer) {
        this.flattenLayer = layer;
    }

    /** @returns Convolutional descriptions, the flatten description, then the parent's descriptions. */
    getLayers(): Layer[] {
        return (this.convolutionalLayers as Layer[]).concat(this.flattenLayer, super.getLayers());
    }
}
constructor(config: QAgentConfig, qmatrix?: QMatrix, name?: string) { 23 | if ((!qmatrix || (qmatrix && !qmatrix.HashFunction)) && !config.dataHash) 24 | throw new Error("A hash function MUST be provided in the config parameter in order to hash correctly QStateData."); 25 | 26 | if (qmatrix && !qmatrix.InitialState) 27 | throw new Error("Please provide an initial state for your QMatrix (qmatrix.setInitialState())"); 28 | 29 | if (!qmatrix && !config.initialState) 30 | throw new Error("Please provide initial state data for your agent !"); 31 | 32 | super(config, name); 33 | this.AgentConfig = {...this.AgentConfig, ...{...DEFAULT_QAGENT_CONFIG, ...config}}; 34 | this.history = []; 35 | 36 | this.previousTransition = null; 37 | this.qmatrix = qmatrix ? qmatrix : new QMatrix(config.dataHash); 38 | if (!this.qmatrix.HashFunction) { 39 | this.qmatrix.HashFunction = config.dataHash; 40 | } 41 | this.currentState = qmatrix ? qmatrix.InitialState : this.qmatrix.registerState(config.initialState); 42 | this.lossOnAlreadyVisited = false; 43 | 44 | if (!qmatrix) { 45 | if (!config.createMatrixDynamically || !config.actions) 46 | throw new Error("Need actions and flag to create matrix dynamically. 
You should provide them or provide precreated QMatrix."); 47 | 48 | config.actions.forEach(a => this.qmatrix.registerAction(a)); 49 | this.qmatrix.Actions.forEach(a => this.qmatrix.registerTransition(a.Name, this.currentState, null)); 50 | this.qmatrix.setStateAsInitial(this.currentState); 51 | } 52 | } 53 | 54 | getTrackingInformation(): AgentTrackingInformation { 55 | return undefined; 56 | } 57 | 58 | restart(): void { 59 | this.history = []; 60 | this.currentState = this.qmatrix.InitialState; 61 | } 62 | 63 | infer(): QTransition { 64 | const action = QAgent.bestAction(...this.currentState.Transitions); 65 | this.previousTransition = this.currentState.takeAction(action.Action); 66 | 67 | this.history.push(this.previousTransition); 68 | 69 | return this.previousTransition; 70 | } 71 | 72 | isPerforming(): boolean { 73 | return !this.currentState.Final; 74 | } 75 | 76 | learn(data?: QStateData): void { 77 | if (this.previousTransition) { 78 | this.updateMatrix(data); 79 | const reward = this.previousTransition.To.Reward - (this.lossOnAlreadyVisited && this.history.indexOf(this.previousTransition) !== this.history.length - 1 ? 1 : 0); 80 | this.previousTransition.Q = reward + this.AgentConfig.gamma * QAgent.bestAction(...this.previousTransition.To.Transitions).Q; 81 | } 82 | } 83 | 84 | updateMatrix(data: QStateData) { 85 | if(!this.previousTransition.To) { 86 | let state: QState; 87 | if(this.qmatrix.exists(data)) { 88 | state = this.qmatrix.getStateFromData(data); 89 | } else { 90 | state = this.qmatrix.registerState(data); 91 | this.qmatrix.Actions.forEach(a => this.qmatrix.registerTransition(a.Name, state, null)); 92 | } 93 | 94 | this.qmatrix.updateTransition(this.previousTransition.Id, state); 95 | this.currentState = state; 96 | } else { 97 | this.currentState = this.previousTransition.To; 98 | } 99 | } 100 | 101 | finalState(reward: number, state?: QState): void { 102 | this.qmatrix.setStateAsFinal(state ? 
state : this.currentState); 103 | this.currentState.Reward = reward; 104 | } 105 | 106 | /** Returns one of the transitions with the highest Q value; ties are broken uniformly at random. */ private static bestAction(...values: QTransition[]): QTransition { 107 | let bests = [values[0]]; 108 | for (let i = 1; i < values.length; ++i) { 109 | if (values[i].Q > bests[0].Q) 110 | bests = [values[i]]; 111 | else if (values[i].Q === bests[0].Q) /* `else`: a freshly selected maximum must not be pushed a second time, which would bias the tie-break */ 112 | bests.push(values[i]); 113 | } 114 | 115 | return bests[Math.floor(Math.random() * bests.length)]; 116 | } 117 | 118 | get QMatrix() { 119 | return this.qmatrix; 120 | } 121 | 122 | set QMatrix(qmatrix: QMatrix) { 123 | this.qmatrix = qmatrix; 124 | } 125 | 126 | get History() { 127 | return this.history; 128 | } 129 | 130 | set CurrentState(state: QState) { 131 | this.currentState = state; 132 | } 133 | 134 | get CurrentState(): QState { 135 | return this.currentState; 136 | } 137 | 138 | get AgentConfig(): QAgentConfig { 139 | return this.agentConfig as QAgentConfig; 140 | } 141 | 142 | set AgentConfig(config: QAgentConfig) { 143 | this.setAgentConfig(config); 144 | } 145 | 146 | getStatesGraph(): { nodes: GraphNode[]; edges: GraphEdge[] } { 147 | return this.qmatrix.getGraphData(); 148 | } 149 | 150 | reset(): void { 151 | 152 | } 153 | 154 | setLossOnAlreadyVisitedState(toggle: boolean): void { 155 | this.lossOnAlreadyVisited = toggle; 156 | } 157 | } -------------------------------------------------------------------------------- /docs/README.md: -------------------------------------------------------------------------------- 1 | > ## [ReImproveJS](README.md) 2 | 3 | [Globals](globals.md) / 4 | 5 | # ReImproveJS 6 | 7 | > A framework using TensorFlow.js for Deep Reinforcement Learning 8 | 9 | [Documentation](docs/README.md) | [NPM](https://www.npmjs.com/package/reimprovejs) | [Wiki](https://github.com/Pravez/ReImproveJS/wiki) | [Changelog](CHANGELOG.md) 10 | 11 | [![npm version](https://badge.fury.io/js/reimprovejs.svg)](https://badge.fury.io/js/reimprovejs) 12 | [![Build 
Status](https://travis-ci.org/Pravez/ReImproveJS.svg?branch=master)](https://travis-ci.org/Pravez/ReImproveJS) 13 | 14 | `ReImproveJS` is a little library to create Reinforcement Learning environments with Javascript. 15 | It currently implements DQN algorithm, but aims to allow users to change easily algorithms, like for instance A3C or Sarsa. 16 | 17 | The library is using TensorFlow.js as a computing background, enabling the use of WebGL to empower computations. 18 | 19 | Getting started 20 | ------------------ 21 | 22 | Installation 23 | ------------ 24 | 25 | ReImproveJS is available as a standalone or as a NPM package. 26 | As usual, you can use the CDN 27 | ```html 28 | 29 | ``` 30 | 31 | or if you have your local version 32 | 33 | ```html 34 | 35 | ``` 36 | You can also install it through NPM. 37 | 38 | ```bash 39 | $ npm install reimprovejs 40 | ``` 41 | 42 | Usage 43 | ----------- 44 | 45 | With ReImproveJS, you have an environment organized as if your agents were part of a "school". The idea is that you are managing 46 | an `Academy`, possessing `Teachers` and `Agents` (Students). You add `Teachers` and assign `Agents` to them. At each step of 47 | your world, you just need to give the `Academy` each `Teacher`'s input, which will handle everything concerning learning. 48 | 49 | Because you are in Reinforcement Learning, you need a Neural Network model in order for your agents to learn. TFJS's `Model` is 50 | embedded into a wrapper, and you just need to precise what type of layers you need, and that's all ! 51 | For instance : 52 | 53 | ```javascript 54 | 55 | const modelFitConfig = { // Exactly the same idea here by using tfjs's model's 56 | epochs: 1, // fit config. 
57 | stepsPerEpoch: 16 58 | }; 59 | 60 | const numActions = 2; // The number of actions your agent can choose to do 61 | const inputSize = 100; // Inputs size (10x10 image for instance) 62 | const temporalWindow = 1; // The window of data which will be sent to your agent 63 | // For instance the x previous inputs, and what actions the agent took 64 | 65 | const totalInputSize = inputSize * temporalWindow + numActions * temporalWindow + inputSize; 66 | 67 | const network = new ReImprove.NeuralNetwork(); 68 | network.InputShape = [totalInputSize]; 69 | network.addNeuralNetworkLayers([ 70 | {type: 'dense', units: 32, activation: 'relu'}, 71 | {type: 'dense', units: numActions, activation: 'softmax'} 72 | ]); 73 | // Now we initialize our model, and start adding layers 74 | const model = new ReImprove.Model.FromNetwork(network, modelFitConfig); 75 | 76 | // Finally compile the model, we also exactly use tfjs's optimizers and loss functions 77 | // (So feel free to choose one among tfjs's) 78 | model.compile({loss: 'meanSquaredError', optimizer: 'sgd'}) 79 | 80 | ``` 81 | 82 | Now that our model is ready, let's create an agent... 83 | 84 | ```javascript 85 | 86 | // Every single field here is optional, and has a default value. Be careful, it may not 87 | // fit your needs ... 88 | 89 | const teacherConfig = { 90 | lessonsQuantity: 10, // Number of training lessons before only testing agent 91 | lessonLength: 100, // The length of each lesson (in quantity of updates) 92 | lessonsWithRandom: 2, // How many random lessons before updating epsilon's value 93 | epsilon: 1, // Q-Learning values and so on ... 
94 | epsilonDecay: 0.995, // (Random factor epsilon, decaying over time) 95 | epsilonMin: 0.05, 96 | gamma: 0.8 // (Gamma = 1 : agent cares very much about future rewards) 97 | }; 98 | 99 | const agentConfig = { 100 | model: model, // Our model corresponding to the agent 101 | agentConfig: { 102 | memorySize: 5000, // The size of the agent's memory (Q-Learning) 103 | batchSize: 128, // How many tensors will be given to the network when fit 104 | temporalWindow: temporalWindow // The temporal window giving previous inputs & actions 105 | } 106 | }; 107 | 108 | const academy = new ReImprove.Academy(); // First we need an academy to host everything 109 | const teacher = academy.addTeacher(teacherConfig); 110 | const agent = academy.addAgent(agentConfig); 111 | 112 | academy.assignTeacherToAgent(agent, teacher); 113 | 114 | ``` 115 | 116 | And that's it ! Now you just need to update during your world emulation if the agent gets rewards, and 117 | feed inputs to it. 118 | 119 | ```javascript 120 | // Nice event occurring during world emulation 121 | function OnSpecialGoodEvent() { 122 | academy.addRewardToAgent(agent, 1.0) // Give a nice reward if the agent did something nice ! 123 | } 124 | 125 | // Bad event 126 | function OnSpecialBadEvent() { 127 | academy.addRewardToAgent(agent, -1.0) // Give a bad reward to the agent if he did something wrong 128 | } 129 | 130 | // Animation loop, update loop, whatever loop you want 131 | async function step(time) { 132 | 133 | let inputs = getInputs(); // Need to give a number[] of your inputs for one teacher. 134 | await academy.step([ // Let the magic operate ... 135 | {teacherName: teacher, agentsInput: inputs} 136 | ]); 137 | 138 | } 139 | 140 | // Start your loop (/!\ for your environment, not specific to ReImproveJS). 141 | requestAnimationFrame(step); 142 | ``` 143 | 144 | Rewards are reset to 0 at each new step. 
145 | 146 | __Please be careful__ : Convolutional networks are implemented and operational as models, but currently not 147 | fully implemented and tested for Reinforcement Learning, so please __do not use them__ for now. 148 | 149 | Examples 150 | ----------------- 151 | 152 | Here is an example made by [@RGBKnights](https://github.com/RGBKnights) : https://gist.github.com/RGBKnights/756b5f51465cc22d0ca39205979ad2a1 -------------------------------------------------------------------------------- /docs/classes/qstate.md: -------------------------------------------------------------------------------- 1 | > ## [ReImproveJS](../README.md) 2 | 3 | [Globals](../globals.md) / [QState](qstate.md) / 4 | 5 | # Class: QState 6 | 7 | ## Hierarchy 8 | 9 | * **QState** 10 | 11 | ### Index 12 | 13 | #### Constructors 14 | 15 | * [constructor](qstate.md#constructor) 16 | 17 | #### Properties 18 | 19 | * [data](qstate.md#private-data) 20 | * [final](qstate.md#private-final) 21 | * [id](qstate.md#private-id) 22 | * [reward](qstate.md#private-reward) 23 | * [transitions](qstate.md#private-transitions) 24 | * [stateId](qstate.md#static-private-stateid) 25 | 26 | #### Accessors 27 | 28 | * [Data](qstate.md#data) 29 | * [Final](qstate.md#final) 30 | * [Id](qstate.md#id) 31 | * [Reward](qstate.md#reward) 32 | * [Transitions](qstate.md#transitions) 33 | 34 | #### Methods 35 | 36 | * [setFinal](qstate.md#setfinal) 37 | * [setTransition](qstate.md#settransition) 38 | * [takeAction](qstate.md#takeaction) 39 | 40 | ## Constructors 41 | 42 | ### constructor 43 | 44 | \+ **new QState**(`data`: [QStateData](../interfaces/qstatedata.md), `reward`: number): *[QState](qstate.md)* 45 | 46 | *Defined in [reimprove/algorithms/q/qstate.ts:13](https://github.com/DevSide/ReImproveJS/blob/2368b25/src/reimprove/algorithms/q/qstate.ts#L13)* 47 | 48 | **Parameters:** 49 | 50 | Name | Type | 51 | ------ | ------ | 52 | `data` | [QStateData](../interfaces/qstatedata.md) | 53 | `reward` | number | 54 | 55 | 
**Returns:** *[QState](qstate.md)* 56 | 57 | ___ 58 | 59 | ## Properties 60 | 61 | ### `Private` data 62 | 63 | ● **data**: *[QStateData](../interfaces/qstatedata.md)* 64 | 65 | *Defined in [reimprove/algorithms/q/qstate.ts:15](https://github.com/DevSide/ReImproveJS/blob/2368b25/src/reimprove/algorithms/q/qstate.ts#L15)* 66 | 67 | ___ 68 | 69 | ### `Private` final 70 | 71 | ● **final**: *boolean* 72 | 73 | *Defined in [reimprove/algorithms/q/qstate.ts:10](https://github.com/DevSide/ReImproveJS/blob/2368b25/src/reimprove/algorithms/q/qstate.ts#L10)* 74 | 75 | ___ 76 | 77 | ### `Private` id 78 | 79 | ● **id**: *number* 80 | 81 | *Defined in [reimprove/algorithms/q/qstate.ts:11](https://github.com/DevSide/ReImproveJS/blob/2368b25/src/reimprove/algorithms/q/qstate.ts#L11)* 82 | 83 | ___ 84 | 85 | ### `Private` reward 86 | 87 | ● **reward**: *number* 88 | 89 | *Defined in [reimprove/algorithms/q/qstate.ts:15](https://github.com/DevSide/ReImproveJS/blob/2368b25/src/reimprove/algorithms/q/qstate.ts#L15)* 90 | 91 | ___ 92 | 93 | ### `Private` transitions 94 | 95 | ● **transitions**: *`Map`* 96 | 97 | *Defined in [reimprove/algorithms/q/qstate.ts:9](https://github.com/DevSide/ReImproveJS/blob/2368b25/src/reimprove/algorithms/q/qstate.ts#L9)* 98 | 99 | ___ 100 | 101 | ### `Static` `Private` stateId 102 | 103 | ■ **stateId**: *number* = 0 104 | 105 | *Defined in [reimprove/algorithms/q/qstate.ts:13](https://github.com/DevSide/ReImproveJS/blob/2368b25/src/reimprove/algorithms/q/qstate.ts#L13)* 106 | 107 | ___ 108 | 109 | ## Accessors 110 | 111 | ### Data 112 | 113 | ● **get Data**(): *[QStateData](../interfaces/qstatedata.md)* 114 | 115 | *Defined in [reimprove/algorithms/q/qstate.ts:31](https://github.com/DevSide/ReImproveJS/blob/2368b25/src/reimprove/algorithms/q/qstate.ts#L31)* 116 | 117 | **Returns:** *[QStateData](../interfaces/qstatedata.md)* 118 | 119 | ___ 120 | 121 | ### Final 122 | 123 | ● **get Final**(): *boolean* 124 | 125 | *Defined in 
[reimprove/algorithms/q/qstate.ts:37](https://github.com/DevSide/ReImproveJS/blob/2368b25/src/reimprove/algorithms/q/qstate.ts#L37)* 126 | 127 | **Returns:** *boolean* 128 | 129 | ● **set Final**(`final`: boolean): *void* 130 | 131 | *Defined in [reimprove/algorithms/q/qstate.ts:36](https://github.com/DevSide/ReImproveJS/blob/2368b25/src/reimprove/algorithms/q/qstate.ts#L36)* 132 | 133 | **Parameters:** 134 | 135 | Name | Type | 136 | ------ | ------ | 137 | `final` | boolean | 138 | 139 | **Returns:** *void* 140 | 141 | ___ 142 | 143 | ### Id 144 | 145 | ● **get Id**(): *number* 146 | 147 | *Defined in [reimprove/algorithms/q/qstate.ts:38](https://github.com/DevSide/ReImproveJS/blob/2368b25/src/reimprove/algorithms/q/qstate.ts#L38)* 148 | 149 | **Returns:** *number* 150 | 151 | ___ 152 | 153 | ### Reward 154 | 155 | ● **get Reward**(): *number* 156 | 157 | *Defined in [reimprove/algorithms/q/qstate.ts:32](https://github.com/DevSide/ReImproveJS/blob/2368b25/src/reimprove/algorithms/q/qstate.ts#L32)* 158 | 159 | **Returns:** *number* 160 | 161 | ● **set Reward**(`reward`: number): *void* 162 | 163 | *Defined in [reimprove/algorithms/q/qstate.ts:33](https://github.com/DevSide/ReImproveJS/blob/2368b25/src/reimprove/algorithms/q/qstate.ts#L33)* 164 | 165 | **Parameters:** 166 | 167 | Name | Type | 168 | ------ | ------ | 169 | `reward` | number | 170 | 171 | **Returns:** *void* 172 | 173 | ___ 174 | 175 | ### Transitions 176 | 177 | ● **get Transitions**(): *[QTransition](qtransition.md)[]* 178 | 179 | *Defined in [reimprove/algorithms/q/qstate.ts:34](https://github.com/DevSide/ReImproveJS/blob/2368b25/src/reimprove/algorithms/q/qstate.ts#L34)* 180 | 181 | **Returns:** *[QTransition](qtransition.md)[]* 182 | 183 | ___ 184 | 185 | ## Methods 186 | 187 | ### setFinal 188 | 189 | ▸ **setFinal**(): *[QState](qstate.md)* 190 | 191 | *Defined in 
[reimprove/algorithms/q/qstate.ts:35](https://github.com/DevSide/ReImproveJS/blob/2368b25/src/reimprove/algorithms/q/qstate.ts#L35)* 192 | 193 | **Returns:** *[QState](qstate.md)* 194 | 195 | ___ 196 | 197 | ### setTransition 198 | 199 | ▸ **setTransition**(`action`: [QAction](qaction.md), `transition`: [QTransition](qtransition.md)): *[QTransition](qtransition.md)* 200 | 201 | *Defined in [reimprove/algorithms/q/qstate.ts:21](https://github.com/DevSide/ReImproveJS/blob/2368b25/src/reimprove/algorithms/q/qstate.ts#L21)* 202 | 203 | **Parameters:** 204 | 205 | Name | Type | 206 | ------ | ------ | 207 | `action` | [QAction](qaction.md) | 208 | `transition` | [QTransition](qtransition.md) | 209 | 210 | **Returns:** *[QTransition](qtransition.md)* 211 | 212 | ___ 213 | 214 | ### takeAction 215 | 216 | ▸ **takeAction**(`action`: [QAction](qaction.md)): *[QTransition](qtransition.md)* 217 | 218 | *Defined in [reimprove/algorithms/q/qstate.ts:27](https://github.com/DevSide/ReImproveJS/blob/2368b25/src/reimprove/algorithms/q/qstate.ts#L27)* 219 | 220 | **Parameters:** 221 | 222 | Name | Type | 223 | ------ | ------ | 224 | `action` | [QAction](qaction.md) | 225 | 226 | **Returns:** *[QTransition](qtransition.md)* 227 | 228 | ___ -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # This repository is deprecated. I cannot maintain it anymore, and cannot update issues. 
2 | 3 | # ReImproveJS 4 | 5 | > A framework using TensorFlow.js for Deep Reinforcement Learning 6 | 7 | [Documentation](docs/README.md) | [NPM](https://www.npmjs.com/package/reimprovejs) | [Wiki](https://github.com/Pravez/ReImproveJS/wiki) | [Changelog](CHANGELOG.md) 8 | 9 | [![npm version](https://badge.fury.io/js/reimprovejs.svg)](https://badge.fury.io/js/reimprovejs) 10 | [![Build Status](https://travis-ci.org/Pravez/ReImproveJS.svg?branch=master)](https://travis-ci.org/Pravez/ReImproveJS) 11 | 12 | `ReImproveJS` is a little library to create Reinforcement Learning environments with Javascript. 13 | It currently implements DQN algorithm, but aims to allow users to change easily algorithms, like for instance A3C or Sarsa. 14 | 15 | The library is using TensorFlow.js as a computing background, enabling the use of WebGL to empower computations. 16 | 17 | Getting started 18 | ------------------ 19 | 20 | Installation 21 | ------------ 22 | 23 | ReImproveJS is available as a standalone or as a NPM package. 24 | As usual, you can use the CDN 25 | ```html 26 | 27 | ``` 28 | 29 | or if you have your local version 30 | 31 | ```html 32 | 33 | ``` 34 | You can also install it through NPM. 35 | 36 | ```bash 37 | $ npm install reimprovejs 38 | ``` 39 | 40 | Usage 41 | ----------- 42 | 43 | With ReImproveJS, you have an environment organized as if your agents were part of a "school". The idea is that you are managing 44 | an `Academy`, possessing `Teachers` and `Agents` (Students). You add `Teachers` and assign `Agents` to them. At each step of 45 | your world, you just need to give the `Academy` each `Teacher`'s input, which will handle everything concerning learning. 46 | 47 | Because you are in Reinforcement Learning, you need a Neural Network model in order for your agents to learn. TFJS's `Model` is 48 | embedded into a wrapper, and you just need to precise what type of layers you need, and that's all ! 
49 | For instance : 50 | 51 | ```javascript 52 | 53 | const modelFitConfig = { // Exactly the same idea here by using tfjs's model's 54 | epochs: 1, // fit config. 55 | stepsPerEpoch: 16 56 | }; 57 | 58 | const numActions = 2; // The number of actions your agent can choose to do 59 | const inputSize = 100; // Inputs size (10x10 image for instance) 60 | const temporalWindow = 1; // The window of data which will be sent to your agent 61 | // For instance the x previous inputs, and what actions the agent took 62 | 63 | const totalInputSize = inputSize * temporalWindow + numActions * temporalWindow + inputSize; 64 | 65 | const network = new ReImprove.NeuralNetwork(); 66 | network.InputShape = [totalInputSize]; 67 | network.addNeuralNetworkLayers([ 68 | {type: 'dense', units: 32, activation: 'relu'}, 69 | {type: 'dense', units: numActions, activation: 'softmax'} 70 | ]); 71 | // Now we initialize our model, and start adding layers 72 | const model = new ReImprove.Model.FromNetwork(network, modelFitConfig); 73 | 74 | // Finally compile the model, we also exactly use tfjs's optimizers and loss functions 75 | // (So feel free to choose one among tfjs's) 76 | model.compile({loss: 'meanSquaredError', optimizer: 'sgd'}) 77 | 78 | ``` 79 | 80 | Now that our model is ready, let's create an agent... 81 | 82 | ```javascript 83 | 84 | // Every single field here is optional, and has a default value. Be careful, it may not 85 | // fit your needs ... 86 | 87 | const teacherConfig = { 88 | lessonsQuantity: 10, // Number of training lessons before only testing agent 89 | lessonLength: 100, // The length of each lesson (in quantity of updates) 90 | lessonsWithRandom: 2, // How many random lessons before updating epsilon's value 91 | epsilon: 1, // Q-Learning values and so on ... 
92 | epsilonDecay: 0.995, // (Random factor epsilon, decaying over time) 93 | epsilonMin: 0.05, 94 | gamma: 0.8 // (Gamma = 1 : agent cares very much about future rewards) 95 | }; 96 | 97 | const agentConfig = { 98 | model: model, // Our model corresponding to the agent 99 | agentConfig: { 100 | memorySize: 5000, // The size of the agent's memory (Q-Learning) 101 | batchSize: 128, // How many tensors will be given to the network when fit 102 | temporalWindow: temporalWindow // The temporal window giving previous inputs & actions 103 | } 104 | }; 105 | 106 | const academy = new ReImprove.Academy(); // First we need an academy to host everything 107 | const teacher = academy.addTeacher(teacherConfig); 108 | const agent = academy.addAgent(agentConfig); 109 | 110 | academy.assignTeacherToAgent(agent, teacher); 111 | 112 | ``` 113 | 114 | And that's it ! Now you just need to update during your world emulation if the agent gets rewards, and 115 | feed inputs to it. 116 | 117 | ```javascript 118 | // Nice event occurring during world emulation 119 | function OnSpecialGoodEvent() { 120 | academy.addRewardToAgent(agent, 1.0) // Give a nice reward if the agent did something nice ! 121 | } 122 | 123 | // Bad event 124 | function OnSpecialBadEvent() { 125 | academy.addRewardToAgent(agent, -1.0) // Give a bad reward to the agent if he did something wrong 126 | } 127 | 128 | // Animation loop, update loop, whatever loop you want 129 | async function step(time) { 130 | 131 | let inputs = getInputs(); // Need to give a number[] of your inputs for one teacher. 132 | await academy.step([ // Let the magic operate ... 133 | {teacherName: teacher, agentsInput: inputs} 134 | ]); 135 | 136 | } 137 | 138 | // Start your loop (/!\ for your environment, not specific to ReImproveJS). 139 | requestAnimationFrame(step); 140 | ``` 141 | 142 | Rewards are reset to 0 at each new step. 
143 | 144 | __Please be careful__ : Convolutional networks are implemented and operational as models, but currently not 145 | fully implemented and tested in the Reinforcement Learning, so please __do not use them__ for now. 146 | 147 | Exemples 148 | ----------------- 149 | 150 | Here an exemple made by [@RGBKnights](https://github.com/RGBKnights) : https://gist.github.com/RGBKnights/756b5f51465cc22d0ca39205979ad2a1 151 | 152 | 153 | -------------------------------------------------------------------------------- /src/reimprove/teacher.ts: -------------------------------------------------------------------------------- 1 | import {DQAgent} from "./algorithms/deepq/dqagent"; 2 | import {AgentTrackingInformation} from "./algorithms/agent_config"; 3 | 4 | 5 | const DEFAULT_TEACHING_CONFIG: TeachingConfig = { 6 | lessonLength: 1000, 7 | lessonsQuantity: 30, 8 | lessonsWithRandom: 2, 9 | epsilon: 1, 10 | epsilonMin: 0.05, 11 | epsilonDecay: 0.95, 12 | gamma: 0.9, 13 | alpha: 1 14 | }; 15 | 16 | export interface TeachingConfig { 17 | lessonLength?: number; 18 | lessonsQuantity?: number; 19 | lessonsWithRandom?: number; 20 | gamma?: number; 21 | epsilon?: number; 22 | epsilonDecay?: number; 23 | epsilonMin?: number; 24 | alpha?: number; 25 | } 26 | 27 | export enum TeachingState { 28 | EXPERIENCING = 0, 29 | LEARNING = 1, 30 | TESTING = 2, 31 | NONE = -1, 32 | STOPPED = -2 33 | } 34 | 35 | export interface TeacherTrackingInformation { 36 | name: string; 37 | gamma: number; 38 | epsilon: number; 39 | currentLessonLength: number; 40 | lessonNumber: number; 41 | maxLessons: number; 42 | students: AgentTrackingInformation[]; 43 | } 44 | 45 | export class Teacher { 46 | 47 | name: string; 48 | config: TeachingConfig; 49 | state: TeachingState; 50 | 51 | agents: Set; 52 | 53 | currentLessonLength: number; 54 | lessonsTaught: number; 55 | 56 | onLearningLessonEnded: (teacher: string) => void; 57 | onLessonEnded: (teacher: string, lessonNumber: number) => void; 58 | 
onTeachingEnded: (teacher: string) => void; 59 | 60 | currentEpsilon: number; 61 | 62 | constructor(config?: TeachingConfig, name?: string) { 63 | this.config = {...DEFAULT_TEACHING_CONFIG, ...config}; 64 | this.agents = new Set(); 65 | this.currentLessonLength = 0; 66 | this.lessonsTaught = 0; 67 | this.state = TeachingState.NONE; 68 | this.currentEpsilon = this.config.epsilon; 69 | 70 | this.onLessonEnded = null; 71 | this.onLearningLessonEnded = null; 72 | this.onTeachingEnded = null; 73 | this.name = name; 74 | } 75 | 76 | affectStudent(agent: DQAgent) { 77 | this.agents.add(agent); 78 | } 79 | 80 | removeStudent(agent: DQAgent): boolean { 81 | return this.agents.delete(agent); 82 | } 83 | 84 | start() { 85 | this.lessonsTaught = 0; 86 | this.currentLessonLength = 0; 87 | this.state = TeachingState.EXPERIENCING; 88 | } 89 | 90 | async teach(inputs: number[]): Promise> { 91 | if (this.state == TeachingState.STOPPED) return null; 92 | 93 | if (this.state == TeachingState.NONE) { 94 | this.start(); 95 | } 96 | 97 | let actions = new Map(); 98 | // If learning is ended, we only test : we only do infer prop through network 99 | if (this.state == TeachingState.TESTING) { 100 | this.agents.forEach(a => actions.set(a.Name, a.infer(inputs, this.config.epsilonMin, false))); 101 | } else { 102 | 103 | //Update lesson 104 | this.currentLessonLength += 1; 105 | 106 | if (this.currentLessonLength >= this.config.lessonLength) 107 | this.state = TeachingState.LEARNING; 108 | 109 | 110 | if (this.state == TeachingState.EXPERIENCING) { 111 | this.agents.forEach(a => actions.set(a.Name, a.listen(inputs, this.currentEpsilon))); 112 | } else if (this.state == TeachingState.LEARNING) { 113 | if (this.onLessonEnded) 114 | this.onLessonEnded(this.name, this.lessonsTaught); 115 | 116 | for (let agent of Array.from(this.agents.keys())) { 117 | await agent.learn(this.config.gamma, this.config.alpha); 118 | } 119 | 120 | this.updateParameters(); 121 | 122 | this.lessonsTaught += 1; 123 | 
this.currentLessonLength = 0; 124 | 125 | if (this.lessonsTaught >= this.config.lessonsQuantity) { 126 | this.state = TeachingState.TESTING; 127 | if (this.onTeachingEnded) 128 | this.onTeachingEnded(this.name); 129 | } else { 130 | this.state = TeachingState.EXPERIENCING; 131 | } 132 | 133 | this.agents.forEach(a => actions.set(a.Name, a.listen(inputs, this.currentEpsilon))); 134 | 135 | if (this.onLearningLessonEnded) 136 | this.onLearningLessonEnded(this.name); 137 | 138 | } 139 | } 140 | 141 | // reset reward for everyone 142 | this.agents.forEach(a => a.setReward(0.)); 143 | 144 | return actions; 145 | } 146 | 147 | stopTeaching() { 148 | this.state = TeachingState.TESTING; 149 | } 150 | 151 | startTeaching() { 152 | if(this.lessonsTaught < this.config.lessonsQuantity) 153 | this.state = TeachingState.EXPERIENCING; 154 | } 155 | 156 | updateParameters() { 157 | if (this.lessonsTaught > this.config.lessonsWithRandom && this.currentEpsilon > this.config.epsilonMin) { 158 | this.currentEpsilon *= this.config.epsilonDecay; 159 | 160 | if(this.currentEpsilon < this.config.epsilonMin) 161 | this.currentEpsilon = this.config.epsilonMin; 162 | } 163 | } 164 | 165 | getData(): TeacherTrackingInformation { 166 | let data: AgentTrackingInformation[] = []; 167 | this.agents.forEach(agent => data.push(agent.getTrackingInformation())); 168 | return { 169 | epsilon: this.currentEpsilon, 170 | gamma: this.config.gamma, 171 | lessonNumber: this.lessonsTaught, 172 | currentLessonLength: this.currentLessonLength, 173 | maxLessons: this.config.lessonsQuantity, 174 | name: this.name, 175 | students: data 176 | }; 177 | } 178 | 179 | resetLesson() { 180 | this.currentLessonLength = 0; 181 | this.state = TeachingState.EXPERIENCING; 182 | } 183 | 184 | reset() { 185 | this.lessonsTaught = 0; 186 | this.currentLessonLength = 0; 187 | this.state = TeachingState.NONE; 188 | } 189 | 190 | stop() { 191 | this.state = TeachingState.STOPPED; 192 | } 193 | 194 | set 
OnLearningLessonEnded(callback: (teacher: string) => void) { 195 | this.onLearningLessonEnded = callback; 196 | } 197 | 198 | set OnLessonEnded(callback: (teacher: string, lessonNumber: number) => void) { 199 | this.onLessonEnded = callback; 200 | } 201 | 202 | set OnTeachingEnded(callback: (teacher: string) => void) { 203 | this.onTeachingEnded = callback; 204 | } 205 | 206 | set Name(name: string) { 207 | this.name = name; 208 | } 209 | 210 | get Name() { 211 | return this.name; 212 | } 213 | 214 | get State() { 215 | return this.state; 216 | } 217 | } -------------------------------------------------------------------------------- /docs/classes/model.md: -------------------------------------------------------------------------------- 1 | > ## [ReImproveJS](../README.md) 2 | 3 | [Globals](../globals.md) / [Model](model.md) / 4 | 5 | # Class: Model 6 | 7 | The Model class is handling everything concerning the neural network 8 | 9 | ## Hierarchy 10 | 11 | * **Model** 12 | 13 | ### Index 14 | 15 | #### Constructors 16 | 17 | * [constructor](model.md#constructor) 18 | 19 | #### Properties 20 | 21 | * [fitConfig](model.md#fitconfig) 22 | * [model](model.md#model) 23 | 24 | #### Accessors 25 | 26 | * [FitConfig](model.md#fitconfig) 27 | * [InputSize](model.md#inputsize) 28 | * [OutputSize](model.md#outputsize) 29 | 30 | #### Methods 31 | 32 | * [addLayer](model.md#addlayer) 33 | * [compile](model.md#compile) 34 | * [export](model.md#export) 35 | * [fit](model.md#fit) 36 | * [predict](model.md#predict) 37 | * [randomOutput](model.md#randomoutput) 38 | * [FromNetwork](model.md#static-fromnetwork) 39 | * [loadFromFile](model.md#static-loadfromfile) 40 | 41 | ## Constructors 42 | 43 | ### constructor 44 | 45 | \+ **new Model**(`config?`: `SequentialArgs`, `fitConfig?`: `ModelFitArgs`): *[Model](model.md)* 46 | 47 | *Defined in [reimprove/model.ts:48](https://github.com/DevSide/ReImproveJS/blob/2368b25/src/reimprove/model.ts#L48)* 48 | 49 | The sequential config is truly 
optional and is to use only if you want to provide a complete tf.layers implementation 50 | of your model. Currently only dense layers are supported but convolutions etc will be implemented quickly. The [[ModelFitConfig]] 51 | is concerning the steps, steps per epoch etc ... which is how is the model going to train itself, which is handled by TensorFlowJS. 52 | 53 | **Parameters:** 54 | 55 | Name | Type | Description | 56 | ------ | ------ | ------ | 57 | `config?` | `SequentialArgs` | - | 58 | `fitConfig?` | `ModelFitArgs` | | 59 | 60 | **Returns:** *[Model](model.md)* 61 | 62 | ___ 63 | 64 | ## Properties 65 | 66 | ### fitConfig 67 | 68 | ● **fitConfig**: *`ModelFitArgs`* 69 | 70 | *Defined in [reimprove/model.ts:48](https://github.com/DevSide/ReImproveJS/blob/2368b25/src/reimprove/model.ts#L48)* 71 | 72 | ___ 73 | 74 | ### model 75 | 76 | ● **model**: *`LayersModel`* 77 | 78 | *Defined in [reimprove/model.ts:47](https://github.com/DevSide/ReImproveJS/blob/2368b25/src/reimprove/model.ts#L47)* 79 | 80 | ___ 81 | 82 | ## Accessors 83 | 84 | ### FitConfig 85 | 86 | ● **set FitConfig**(`fitConfig`: `ModelFitArgs`): *void* 87 | 88 | *Defined in [reimprove/model.ts:147](https://github.com/DevSide/ReImproveJS/blob/2368b25/src/reimprove/model.ts#L147)* 89 | 90 | **Parameters:** 91 | 92 | Name | Type | 93 | ------ | ------ | 94 | `fitConfig` | `ModelFitArgs` | 95 | 96 | **Returns:** *void* 97 | 98 | ___ 99 | 100 | ### InputSize 101 | 102 | ● **get InputSize**(): *number* 103 | 104 | *Defined in [reimprove/model.ts:143](https://github.com/DevSide/ReImproveJS/blob/2368b25/src/reimprove/model.ts#L143)* 105 | 106 | **Returns:** *number* 107 | 108 | ___ 109 | 110 | ### OutputSize 111 | 112 | ● **get OutputSize**(): *number* 113 | 114 | *Defined in [reimprove/model.ts:139](https://github.com/DevSide/ReImproveJS/blob/2368b25/src/reimprove/model.ts#L139)* 115 | 116 | **Returns:** *number* 117 | 118 | ___ 119 | 120 | ## Methods 121 | 122 | ### addLayer 123 | 124 | ▸ 
**addLayer**(`type`: [LayerType](../enums/layertype.md), `config`: [LayerConfig](../interfaces/layerconfig.md)): *void* 125 | 126 | *Defined in [reimprove/model.ts:87](https://github.com/DevSide/ReImproveJS/blob/2368b25/src/reimprove/model.ts#L87)* 127 | 128 | Method to just add a layer to the model, concatenating it with the previous ones. 129 | 130 | **`deprecated`** Please now use [NeuralNetwork](neuralnetwork.md) 131 | 132 | **Parameters:** 133 | 134 | Name | Type | Description | 135 | ------ | ------ | ------ | 136 | `type` | [LayerType](../enums/layertype.md) | a type among DENSE, FLATTEN or CONV2D | 137 | `config` | [LayerConfig](../interfaces/layerconfig.md) | - | 138 | 139 | **Returns:** *void* 140 | 141 | ___ 142 | 143 | ### compile 144 | 145 | ▸ **compile**(`config`: `ModelCompileArgs`): *[Model](model.md)* 146 | 147 | *Defined in [reimprove/model.ts:121](https://github.com/DevSide/ReImproveJS/blob/2368b25/src/reimprove/model.ts#L121)* 148 | 149 | To compile the model, refer to [[ModelCompileConfig]] to know exactly what to use, but essentially, give the optimizer ('sgd', 'crossEntropy' , ...) 150 | and the loss function ('meanSquaredError', ...), see TFJS's documentation for the exhaustive list. 151 | 152 | **Parameters:** 153 | 154 | Name | Type | 155 | ------ | ------ | 156 | `config` | `ModelCompileArgs` | 157 | 158 | **Returns:** *[Model](model.md)* 159 | 160 | ___ 161 | 162 | ### export 163 | 164 | ▸ **export**(`destination`: string, `place`: string): *`Promise`* 165 | 166 | *Defined in [reimprove/model.ts:77](https://github.com/DevSide/ReImproveJS/blob/2368b25/src/reimprove/model.ts#L77)* 167 | 168 | Export model to as destination. 169 | 170 | **Parameters:** 171 | 172 | Name | Type | Default | Description | 173 | ------ | ------ | ------ | ------ | 174 | `destination` | string | - | Can be one of 'downloads' (triggers browser download) [default], 'localstorage', 'indexeddb' or in http request 'http', 'https'. 
| 175 | `place` | string | "downloads" | - | 176 | 177 | **Returns:** *`Promise`* 178 | 179 | ___ 180 | 181 | ### fit 182 | 183 | ▸ **fit**(`x`: `Tensor`, `y`: `Tensor`): *`Promise`* 184 | 185 | *Defined in [reimprove/model.ts:130](https://github.com/DevSide/ReImproveJS/blob/2368b25/src/reimprove/model.ts#L130)* 186 | 187 | **Parameters:** 188 | 189 | Name | Type | 190 | ------ | ------ | 191 | `x` | `Tensor` | 192 | `y` | `Tensor` | 193 | 194 | **Returns:** *`Promise`* 195 | 196 | ___ 197 | 198 | ### predict 199 | 200 | ▸ **predict**(`x`: `Tensor`, `config?`: `ModelPredictConfig`): *[Result](result.md)* 201 | 202 | *Defined in [reimprove/model.ts:126](https://github.com/DevSide/ReImproveJS/blob/2368b25/src/reimprove/model.ts#L126)* 203 | 204 | **Parameters:** 205 | 206 | Name | Type | 207 | ------ | ------ | 208 | `x` | `Tensor` | 209 | `config?` | `ModelPredictConfig` | 210 | 211 | **Returns:** *[Result](result.md)* 212 | 213 | ___ 214 | 215 | ### randomOutput 216 | 217 | ▸ **randomOutput**(): *number* 218 | 219 | *Defined in [reimprove/model.ts:134](https://github.com/DevSide/ReImproveJS/blob/2368b25/src/reimprove/model.ts#L134)* 220 | 221 | **Returns:** *number* 222 | 223 | ___ 224 | 225 | ### `Static` FromNetwork 226 | 227 | ▸ **FromNetwork**(`network`: [NeuralNetwork](neuralnetwork.md), `fitConfig?`: `ModelFitArgs`, `name`: string): *[Model](model.md)* 228 | 229 | *Defined in [reimprove/model.ts:160](https://github.com/DevSide/ReImproveJS/blob/2368b25/src/reimprove/model.ts#L160)* 230 | 231 | Static method to create a [Model](model.md) from a [NeuralNetwork](neuralnetwork.md). The fit config is optional as well as the name. It 232 | returns a prepared model, but not compiled. 
/**
 * Graph of states, actions and transitions used by the tabular Q-learning agent.
 *
 * States are indexed by the string produced by the hash function, actions by their name,
 * and transitions are kept in registration order.
 */
export class QMatrix {
    private actions: Map<string, QAction>;
    private states: Map<string, QState>;
    private transitions: Array<QTransition>;


    private initialState?: QState;

    /**
     * @param hashFunction Function turning state data into the key under which the state is stored.
     * It is required by [[registerState]]; [[hash]] falls back on [[defaultHash]] when calling it throws.
     */
    constructor(private hashFunction?: (data: QStateData) => string) {
        this.actions = new Map<string, QAction>();
        this.states = new Map<string, QState>();
        this.transitions = [];
    }

    /**
     * Registers an action under its name. A string creates a new `QAction` from the name and the
     * optional data; an existing `QAction` is stored as is. Registering the same name again replaces it.
     */
    registerAction(action: QAction | string, data?: QActionData) {
        if (typeof action === "string")
            this.actions.set(action, new QAction(action, data));
        else
            this.actions.set(action.Name, action);
    }

    /**
     * Registers a state for the given data, or returns the state already stored under the same hash.
     * @param data Data identifying the state
     * @param reward Reward given to a newly created state (ignored when the state already exists)
     * @throws Error when no hash function was provided
     */
    registerState(data: QStateData, reward: number = 0.): QState {
        if (!this.hashFunction)
            throw new Error("Unable to register a state without a hash function.");

        // Single lookup instead of has() + get(), which hashed the same data twice.
        const existing = this.states.get(this.hash(data));
        if (existing !== undefined)
            return existing;

        const state = new QState(data, reward);
        this.states.set(this.hash(state.Data), state);
        return state;
    }

    /**
     * Creates the transition `from -> to` for the named action and, when `oppositeActionName` is
     * given, the reverse transition `to -> from` carrying the opposite action.
     * @returns The last transition created: the reverse one when `oppositeActionName` is given,
     * the forward one otherwise.
     */
    registerTransition(action: string, from: QState, to: QState, oppositeActionName?: string): QTransition {
        // NOTE(review): an unregistered action name yields `undefined` here, as before — confirm callers register first.
        const qaction = this.action(action);

        let transition = new QTransition(from, to, qaction);
        from.setTransition(qaction, transition);
        this.transitions.push(transition);

        if (oppositeActionName) {
            // The reverse transition must carry the opposite action: it used to be built with the
            // forward action while being stored on `to` under the opposite one.
            const opposite = this.action(oppositeActionName);
            transition = new QTransition(to, from, opposite);
            to.setTransition(opposite, transition);
            this.transitions.push(transition);
        }

        return transition;
    }

    /**
     * Changes the destination of the transition with the given id.
     * @returns The updated transition, or undefined when no transition has this id.
     */
    updateTransition(id: number, to: QState): QTransition | undefined {
        const trans = this.transitions.find(t => t.Id === id);
        if (trans) {
            trans.To = to;
            return trans;
        }
        return undefined;
    }

    /** Returns the action registered under `name` (undefined at runtime when there is none). */
    action(name: string): QAction {
        return this.actions.get(name);
    }

    /**
     * Hashes state data with the configured hash function. When that call throws (including when
     * no function is set), the error is logged and [[defaultHash]] is used instead.
     */
    hash(data: QStateData): string {
        try {
            return this.hashFunction(data);
        } catch (exception) {
            console.error("Unable to hash the object, are you sure you gave a defined QStateData ? : " + exception);
            console.error("Falling on default hash func ... [ PLEASE PROVIDE A HASH FUNCTION ]");
            return QMatrix.defaultHash(data);
        }
    }

    /** Fallback hash: the JSON serialization of the data. */
    static defaultHash(data: QStateData): string {
        return JSON.stringify(data);
    }

    /** Returns the state stored under the hash of `data`, if any. */
    getStateFromData(data: QStateData): QState | undefined {
        return this.states.get(this.hash(data));
    }

    /** Tells whether a state is stored under the hash of `data`. */
    exists(data: QStateData): boolean {
        return this.states.has(this.hash(data));
    }

    /**
     * Resolves a state given as an instance (returned as is), as a hash string (looked up directly)
     * or as raw data (hashed, then looked up).
     */
    private checkAndGetState(state: QState | QStateData | string): QState | undefined {
        if (typeof state === "string")
            return this.states.get(state);
        if (state instanceof QState)
            return state;
        return this.states.get(this.hash(state));
    }

    /**
     * Sets a state as initial state. Be careful there can be only one !
     * @param {QState | string | QStateData} state
     * @returns {boolean} True if the state could be resolved, else false (the initial state is then cleared).
     */
    setStateAsInitial(state: QState | QStateData | string): boolean {
        this.initialState = this.checkAndGetState(state);
        return this.initialState !== undefined;
    }

    /**
     * Sets a state as final, which means that stops the emulation. There can be many.
     * Can be also done through QState.Final = true
     * @param {QState | string | QStateData} state
     * @returns {boolean} True if the state was successfully modified, false if it does not exists or wasn't modified.
     */
    setStateAsFinal(state: QState | QStateData | string): boolean {
        const final = this.checkAndGetState(state);
        if (final !== undefined && !final.Final) {
            final.Final = true;
            return true;
        }

        return false;
    }

    /**
     * Remove the final flag from a state. Can be also done through QState.Final = false
     * @param {QState | string | QStateData} state
     * @returns {boolean} True if the state exists and was successfully modified, else false.
     */
    removeStateFromFinals(state: QState | string | QStateData): boolean {
        const temps = this.checkAndGetState(state);
        if (temps !== undefined && temps.Final) {
            temps.Final = false;
            return true;
        } else {
            return false;
        }
    }

    /** Removes every transition, state and action, and forgets the initial state. */
    reset() {
        this.transitions = [];
        this.states.clear();
        this.actions.clear();
        // The initial state belonged to the cleared states; keeping it would expose a stale state.
        this.initialState = undefined;
    }

    /** Sets the Q value of every transition back to 0, keeping the graph itself. */
    resetTransitions() {
        this.transitions.forEach(t => t.Q = 0.0);
    }

    get InitialState() {
        return this.initialState;
    }

    get FinalStates() {
        return Array.from(this.states.values()).filter(state => state.Final)
    }

    get HashFunction() {
        return this.hashFunction;
    }

    set HashFunction(func: (data: QStateData) => string) {
        this.hashFunction = func;
    }

    get States(): Array<QState> {
        return Array.from(this.states.values());
    }

    get Actions(): Array<QAction> {
        return Array.from(this.actions.values());
    }

    /**
     * Builds a node/edge description of the matrix: one node per state (labelled with its
     * JSON-serialized data, coloured from its reward) and one edge per transition having both ends
     * (labelled `Q-actionName`, coloured from its Q value).
     */
    getGraphData(): { nodes: GraphNode[], edges: GraphEdge[] } {
        const nodes: GraphNode[] = this.States.map(s => ({
            id: s.Id,
            label: JSON.stringify(s.Data),
            color: getColor(s.Reward)
        }));
        const edges: GraphEdge[] = this.transitions
            .filter(t => t.To && t.From)
            .map(t => ({
                id: t.Id,
                to: t.To.Id,
                from: t.From.Id,
                label: `${t.Q}-${t.Action.Name}`,
                color: getColor(t.Q),
                arrows: 'to'
            }));

        return {nodes: nodes, edges: edges};
    }
}
/**
 * Maps a value to a CSS `rgb(r,g,b)` colour on a green (0) -> yellow (0.5) -> red (1) gradient.
 *
 * The value is clamped to [0, 1] first (NaN counts as 0): rewards and Q values are not bounded,
 * and an out-of-range hue used to produce negative or unrelated channels.
 * Channels are rounded to integers.
 *
 * @param value Expected between 0 and 1
 * @returns The colour as an `rgb(r,g,b)` string
 */
function getColor(value: number): string {
    const clamped = Number.isNaN(value) ? 0 : Math.min(1, Math.max(0, value));

    // Hue goes from 120 (green) down to 0 (red). Math.floor replaces parseInt(x.toString()),
    // which misreads exponent notation ("1.2e-8" parsed as 1).
    const hue = Math.floor((1 - clamped) * 120);
    const s = 1;
    const l = 0.5;

    // Standard HSL -> RGB conversion, restricted to the two sextants a hue in [0, 120] can reach.
    const c = (1 - Math.abs(2 * l - 1)) * s;
    const x = c * (1 - Math.abs(hue / 60 % 2 - 1));
    const m = l - c / 2;

    const values = hue < 60 ? [c, x, 0] : [x, c, 0];
    const [r, g, b] = values.map(channel => Math.round((channel + m) * 255));
    return `rgb(${r},${g},${b})`;
}
/**
 * The Model class is handling everything concerning the neural network
 */
export class Model {

    model: tflayers.LayersModel;
    fitConfig: tflayers.ModelFitArgs;

    /**
     * The sequential config is optional and should only be used if you want to provide a complete tf.layers
     * implementation of your model. The fit config (epochs, steps per epoch, ...) describes how the model trains
     * itself, which is handled by TensorFlowJS; it is merged over the default fit configuration.
     * @param {SequentialArgs} config
     * @param {ModelFitArgs} fitConfig
     */
    constructor(config?: tflayers.SequentialArgs, fitConfig?: tflayers.ModelFitArgs) {
        this.model = new tflayers.Sequential(config);
        this.fitConfig = {...DEFAULT_MODEL_FIT_CONFIG, ...fitConfig};
    }

    /**
     * Loads a saved model, either from a location string handed to `loadLayersModel` or from a pair of
     * browser files (the JSON topology and the weights).
     * @returns {Promise<Model>} A Model wrapping the loaded network, with the default fit configuration.
     */
    static async loadFromFile(file: string | {json: File, weights: File}): Promise<Model> {
        const model = new Model();
        if (typeof file === "string")
            model.model = await tflayers.loadLayersModel(file);
        else
            model.model = await tflayers.loadLayersModel(io.browserFiles([file.json, file.weights]));
        return model;
    }

    /**
     * Exports the model by saving it to `<place>://<destination>`.
     * @param {string} destination Name or path of the export, placed after the scheme.
     * @param {string} place Scheme to save to: 'downloads' (triggers browser download) [default], 'localstorage',
     * 'indexeddb', or an http request with 'http' / 'https'.
     * @returns {Promise<void>}
     */
    async export(destination: string, place = 'downloads'): Promise<void> {
        await this.model.save(`${place}://${destination}`);
    }

    /**
     * Method to just add a layer to the model, concatenating it with the previous ones.
     * @param type a type among DENSE, FLATTEN or CONV2D
     * @param {LayerConfig} config
     * @throws Error when the underlying model is not a Sequential model (e.g. it was loaded from a file)
     * @deprecated Please now use [[NeuralNetwork]]
     */
    addLayer(type: LayerType, config: LayerConfig) {
        if (!(this.model instanceof tflayers.Sequential))
            throw new Error("Unable to add a layer to an already created model managed by tensorflowjs");

        // Work on a copy: the shared defaults used to be mutated, so the inputShape (or filters) of one
        // call leaked into every later call.
        const conf: ToTfLayerConfig = {...DEFAULT_LAYER_CONFIG};
        if (config.inputShape)
            conf.inputShape = config.inputShape;

        switch (type) {
            case LayerType.DENSE:
                conf.units = config.units;
                conf.activation = config.activation;
                this.model.add(tflayers.layers.dense(conf as Parameters<typeof tflayers.layers.dense>[0]));
                break;
            case LayerType.CONV2D:
                conf.filters = config.units;
                conf.activation = config.activation;
                conf.useBias = config.useBias;
                this.model.add(tflayers.layers.conv2d(conf as Parameters<typeof tflayers.layers.conv2d>[0]));
                break;
            case LayerType.FLATTEN:
                // A flatten layer takes none of the simplified options.
                this.model.add(tflayers.layers.flatten({}));
                break;
        }
    }

    /**
     * To compile the model, give essentially the optimizer ('sgd', 'adam', ...) and the loss function
     * ('meanSquaredError', ...), see TFJS's documentation for the exhaustive list.
     * @param {ModelCompileArgs} config
     * @returns {Model} This model, to allow chaining
     */
    compile(config: tflayers.ModelCompileArgs): Model {
        this.model.compile(config);
        return this;
    }

    /** Runs the network on `x` and wraps the output tensor in a [[Result]]. */
    predict(x: Tensor, config?: ModelPredictConfig): Result {
        return new Result(this.model.predict(x, config) as Tensor);
    }

    /** Trains the network on the given inputs and targets using the current fit configuration. */
    fit(x: Tensor, y: Tensor): Promise<tflayers.History> {
        return this.model.fit(x, y, this.fitConfig);
    }

    /**
     * Picks a random output (action) index between 0 and OutputSize - 1.
     */
    randomOutput(): number {
        // TODO create a distribution of all taken actions, in order later to choose in what way we want the random to behave
        // lodash's random() includes its upper bound, so the last valid index is OutputSize - 1.
        return random(0, this.OutputSize - 1);
    }

    /** Second dimension of the output shape of the network. */
    get OutputSize(): number {
        return (this.model.getOutputAt(0) as tflayers.SymbolicTensor).shape[1];
    }

    /** Second dimension of the batch input shape of the first layer. */
    get InputSize(): number {
        return this.model.layers[0].batchInputShape[1];
    }

    /** Replaces the fit configuration; missing fields take their default value. */
    set FitConfig(fitConfig: tflayers.ModelFitArgs) {
        this.fitConfig = {...DEFAULT_MODEL_FIT_CONFIG, ...fitConfig};
    }

    /**
     * Static method to create a [[Model]] from a [[NeuralNetwork]]. The fit config is optional as well as the name. It
     * returns a prepared model, but not compiled.
     * @param {NeuralNetwork} network
     * @param {ModelFitArgs} fitConfig
     * @param {string} name Name given to the sequential model, a random UUID by default
     * @returns {Model}
     * @constructor
     */
    static FromNetwork(network: NeuralNetwork, fitConfig?: tflayers.ModelFitArgs, name: string = v4()): Model {
        return new Model({
            name: name,
            layers: network.createLayers()
        }, fitConfig);

    }
}
/**
 * Thin wrapper around the tensor returned by a TensorflowJS prediction. Values are read synchronously
 * with dataSync(), so callers get plain numbers instead of a Promise. Every accessor disposes the wrapped
 * tensor, so a Result is meant to be read once.
 */
export class Result {

    constructor(private result: Tensor) {
    }

    /**
     * Disposes the wrapped tensor, then reads the values of `t` synchronously.
     */
    private getResultAndDispose(t: Tensor): Float32Array | Int32Array | Uint8Array {
        this.result.dispose();
        const values = t.dataSync();
        return values;
    }

    /**
     * Returns the highest value of the result, seen as a 1D tensor
     * @returns {number}
     */
    getHighestValue(): number {
        return tidy(() => {
            const highest = this.result.as1D().max();
            return this.getResultAndDispose(highest)[0];
        });
    }

    /**
     * Returns the index of the highest value of the result, seen as a 1D tensor
     * @returns {number}
     */
    getAction(): number {
        return tidy(() => {
            const bestIndex = this.result.as1D().argMax();
            return this.getResultAndDispose(bestIndex)[0];
        });
    }

    /**
     * Returns the values of the result as a flat typed array
     * @returns {Int32Array | Float32Array | Uint8Array}
     */
    getValue(): Int32Array | Float32Array | Uint8Array {
        const flattened = this.result.as1D();
        const values = flattened.dataSync();

        // The values are copied out; release the flattened view first, then the wrapped tensor.
        flattened.dispose();
        this.result.dispose();

        return values;
    }
}
/**
 * Deep Q-learning agent: picks actions with a [[Model]] (epsilon-greedy), stores its experiences in a
 * [[Memory]] and replays batches of them to train the model.
 */
export class DQAgent extends AbstractAgent {
    private done: boolean;
    private currentReward: number;

    // Sliding windows over the last `netInputWindowSize` steps.
    private readonly actionsBuffer: Array<number>;
    private readonly statesBuffer: Array<Tensor>;
    private readonly inputsBuffer: Array<MementoTensor>;

    private lossesHistory: TypedWindow<number>;
    private rewardsHistory: TypedWindow<number>;
    private readonly netInputWindowSize: number;

    private memory: Memory;

    // Number of calls to infer() since construction or the last reset().
    private forwardPasses: number;


    /**
     * @param model Network used to evaluate the actions
     * @param agentConfig Merged over the default configuration (memory size, batch size, temporal window)
     * @param name Name handed to the base agent
     */
    constructor(private model: Model, agentConfig?: DQAgentConfig, name?: string) {
        super(agentConfig, name);
        this.AgentConfig = {...DEFAULT_AGENT_CONFIG, ...agentConfig} as DQAgentConfig;
        this.done = false;
        this.currentReward = 0;

        this.lossesHistory = new TypedWindow<number>(HIST_WINDOW_SIZE, HIST_WINDOW_MIN_SIZE, -1);
        this.rewardsHistory = new TypedWindow<number>(HIST_WINDOW_SIZE, HIST_WINDOW_MIN_SIZE, null);

        this.memory = new Memory({size: this.AgentConfig.memorySize});

        this.netInputWindowSize = Math.max(this.AgentConfig.temporalWindow, MEM_WINDOW_MIN_SIZE);
        this.actionsBuffer = new Array<number>(this.netInputWindowSize);
        this.inputsBuffer = new Array<MementoTensor>(this.netInputWindowSize);
        this.statesBuffer = new Array<Tensor>(this.netInputWindowSize);

        this.forwardPasses = 0;
    }

    /**
     * Builds the network input: the current input followed, for each of the last `temporalWindow` steps,
     * by the state seen at that step and a one-hot encoding of the action taken, all concatenated on axis 1.
     */
    private createNeuralNetInput(input: Tensor): Tensor {
        return tidy(() => {
            let finalInput = input.clone();

            for (let i = 0; i < this.AgentConfig.temporalWindow; ++i) {
                const step = this.netInputWindowSize - 1 - i;

                // Here we concatenate input with previous input
                finalInput = finalInput.concat(this.statesBuffer[step], 1);

                // And we add to previous input previous action
                // (range from 0 to actions, and give a 1 or a 0 if we took this action or not)
                const oneHotAction = tensor([
                    range(0, this.model.OutputSize)
                        .map((val) => val === this.actionsBuffer[step] ? 1.0 : 0.0)
                ]);
                finalInput = finalInput.concat(oneHotAction, 1);
            }

            return finalInput;
        });
    }

    /** Greedy policy: index of the highest output of the model for this input. */
    private policy(input: Tensor): number {
        return this.model.predict(input).getAction();
    }

    /**
     * Chooses an action for the given input.
     *
     * During the first `temporalWindow` passes the action is random, since there is not enough history to
     * build a network input. Afterwards it is random with probability `epsilon`, else given by the policy.
     *
     * @param input A flat array of values (turned into a [1, n] tensor) or an array of arrays
     * @param epsilon Probability of choosing a random action
     * @param keepTensors When true, the action and network input are pushed into the buffers used by
     * [[memorize]]; when false the network input is disposed right away.
     * @returns The index of the chosen action
     * @throws Error when the input is not an array
     */
    infer(input: number[] | number[][], epsilon: number, keepTensors: boolean = true): number {
        this.forwardPasses += 1;

        let tensorInput: Tensor;
        if (Array.isArray(input) && Array.isArray(input[0]))
            tensorInput = tensor(input);
        else if (Array.isArray(input))
            tensorInput = tensor2d([input as number[]], [1, input.length]);
        else
            throw new Error("Unable to create convenient tensor for training.");

        let action: number;
        let netInput: Tensor;
        if (this.forwardPasses > this.AgentConfig.temporalWindow) {
            netInput = this.createNeuralNetInput(tensorInput);

            if (random(0, 1, true) < epsilon) {
                // Select a random action according to epsilon probability
                action = this.model.randomOutput();
            } else {
                // Or just use our policy
                action = this.policy(netInput);
            }
        } else {
            // Case in the beginnings
            action = this.model.randomOutput();
            netInput = tensor([]);
        }

        // The oldest state leaves the window and is released.
        const stateShifted = this.statesBuffer.shift();
        if (stateShifted)
            stateShifted.dispose();
        this.statesBuffer.push(tensorInput);

        if (keepTensors) {
            this.actionsBuffer.shift();
            this.inputsBuffer.shift();

            this.actionsBuffer.push(action);
            this.inputsBuffer.push({tensor: netInput, references: 0});
        } else {
            netInput.dispose();
        }

        return action;
    }

    /**
     * Records the current reward in the rewards history and, once more than `temporalWindow + 1` passes
     * were made, stores the experience (previous input, previous action, current reward, latest input).
     */
    memorize(): void {

        this.rewardsHistory.add(this.currentReward);

        if (this.forwardPasses <= this.AgentConfig.temporalWindow + 1) return;

        // Save experience
        this.memory.remember({
            action: this.actionsBuffer[this.netInputWindowSize - MEM_WINDOW_MIN_SIZE],
            reward: this.currentReward,
            state: this.inputsBuffer[this.netInputWindowSize - MEM_WINDOW_MIN_SIZE],
            nextState: this.inputsBuffer[this.netInputWindowSize - 1]
        });
    }

    /**
     * Turns a memento into one training sample.
     *
     * The target is `alpha * (reward + gamma * max(prediction(nextState)))` while the agent is not done,
     * else the raw reward. It is added to the predicted value of the memento's action; the other predicted
     * values are kept.
     *
     * @returns `x`, a clone of the memento's state, and `y`, a [1, OutputSize] tensor of targets.
     * The caller is responsible for disposing both.
     */
    createTrainingDataFromMemento(memento: Memento, gamma: number, alpha: number): { x: Tensor, y: Tensor } {
        return tidy(() => {
            let target = memento.reward;
            if (!this.done) {
                target = alpha * (memento.reward + gamma * (this.model.predict(memento.nextState.tensor).getHighestValue()));
            }

            const futureTarget = this.model.predict(memento.state.tensor).getValue();
            futureTarget[memento.action] += target;
            return {x: memento.state.tensor.clone(), y: tensor2d(futureTarget, [1, this.model.OutputSize])};
        });
    }

    /** Infers an action (keeping the tensors) and memorizes the step. */
    listen(input: number[] | number[][], epsilon: number): number {
        const action = this.infer(input, epsilon, true);
        this.memorize();

        return action;
    }

    /**
     * Trains the model on a batch sampled from the memory and records the first loss value of the fit.
     * Does nothing while the memory has nothing to sample.
     */
    async learn(gamma: number, alpha: number): Promise<void> {
        const mementos = this.memory.sample(this.AgentConfig.batchSize);
        // reduce() without an initial value throws on an empty array: nothing to learn from yet.
        if (mementos.length === 0) return;

        const trainData = mementos
            .map(memento => this.createTrainingDataFromMemento(memento, gamma, alpha))
            .reduce((previousValue, currentValue) => {
                    const res = {
                        x: previousValue.x.concat(currentValue.x),
                        y: previousValue.y.concat(currentValue.y)
                    };

                    // The partial batches are merged into `res` and no longer needed.
                    previousValue.x.dispose();
                    previousValue.y.dispose();
                    currentValue.x.dispose();
                    currentValue.y.dispose();

                    return res;
                }
            );

        try {
            const history = await this.model.fit(trainData.x, trainData.y);
            const loss = history.history.loss[0] as number;
            this.lossesHistory.add(loss);
        } finally {
            // Release the batch even when the fit fails.
            trainData.x.dispose();
            trainData.y.dispose();
        }
    }

    /** Adds to the reward of the current step. */
    addReward(value: number): void {
        this.currentReward += value;
    }

    /** Replaces the reward of the current step. */
    setReward(value: number): void {
        this.currentReward = value;
    }

    /**
     * Resets the memory, disposes the buffered tensors and restarts the forward passes count.
     * The current reward and the histories are left untouched.
     */
    reset(): void {
        this.memory.reset();
        this.inputsBuffer.forEach(i => i.tensor.dispose());
        this.statesBuffer.forEach(s => s.dispose());
        this.forwardPasses = 0;
    }

    get AgentConfig(): DQAgentConfig {
        return this.agentConfig as DQAgentConfig;
    }

    set AgentConfig(config: DQAgentConfig) {
        this.setAgentConfig(config);
    }

    /** Mean reward and mean loss over the history windows, with the agent's name. */
    getTrackingInformation(): AgentTrackingInformation {
        return {
            averageReward: this.rewardsHistory.mean(),
            averageLoss: this.lossesHistory.mean(),
            name: this.Name
        }
    }
}
9 | 10 | ## Hierarchy 11 | 12 | * [NeuralNetwork](neuralnetwork.md) 13 | 14 | * **ConvolutionalNeuralNetwork** 15 | 16 | ### Index 17 | 18 | #### Constructors 19 | 20 | * [constructor](convolutionalneuralnetwork.md#constructor) 21 | 22 | #### Properties 23 | 24 | * [convolutionalLayers](convolutionalneuralnetwork.md#private-convolutionallayers) 25 | * [flattenLayer](convolutionalneuralnetwork.md#private-flattenlayer) 26 | * [inputShape](convolutionalneuralnetwork.md#protected-inputshape) 27 | 28 | #### Accessors 29 | 30 | * [FlattenLayer](convolutionalneuralnetwork.md#flattenlayer) 31 | * [InputShape](convolutionalneuralnetwork.md#inputshape) 32 | 33 | #### Methods 34 | 35 | * [addConvolutionalLayer](convolutionalneuralnetwork.md#addconvolutionallayer) 36 | * [addConvolutionalLayers](convolutionalneuralnetwork.md#addconvolutionallayers) 37 | * [addMaxPooling2DLayer](convolutionalneuralnetwork.md#addmaxpooling2dlayer) 38 | * [addNeuralNetworkLayer](convolutionalneuralnetwork.md#addneuralnetworklayer) 39 | * [addNeuralNetworkLayers](convolutionalneuralnetwork.md#addneuralnetworklayers) 40 | * [createLayers](convolutionalneuralnetwork.md#createlayers) 41 | * [getLayers](convolutionalneuralnetwork.md#getlayers) 42 | 43 | #### Object literals 44 | 45 | * [DEFAULT_CONV_LAYER](convolutionalneuralnetwork.md#static-private-default_conv_layer) 46 | * [DEFAULT_POOLING_LAYER](convolutionalneuralnetwork.md#static-private-default_pooling_layer) 47 | 48 | ## Constructors 49 | 50 | ### constructor 51 | 52 | \+ **new ConvolutionalNeuralNetwork**(): *[ConvolutionalNeuralNetwork](convolutionalneuralnetwork.md)* 53 | 54 | *Overrides [NeuralNetwork](neuralnetwork.md).[constructor](neuralnetwork.md#constructor)* 55 | 56 | *Defined in [reimprove/networks.ts:121](https://github.com/DevSide/ReImproveJS/blob/2368b25/src/reimprove/networks.ts#L121)* 57 | 58 | **Returns:** *[ConvolutionalNeuralNetwork](convolutionalneuralnetwork.md)* 59 | 60 | ___ 61 | 62 | ## Properties 63 | 64 | ### 
`Private` convolutionalLayers 65 | 66 | ● **convolutionalLayers**: *[ConvolutionalNetworkLayer](../interfaces/convolutionalnetworklayer.md)[]* 67 | 68 | *Defined in [reimprove/networks.ts:107](https://github.com/DevSide/ReImproveJS/blob/2368b25/src/reimprove/networks.ts#L107)* 69 | 70 | ___ 71 | 72 | ### `Private` flattenLayer 73 | 74 | ● **flattenLayer**: *[FlattenLayer](../interfaces/flattenlayer.md)* 75 | 76 | *Defined in [reimprove/networks.ts:108](https://github.com/DevSide/ReImproveJS/blob/2368b25/src/reimprove/networks.ts#L108)* 77 | 78 | ___ 79 | 80 | ### `Protected` inputShape 81 | 82 | ● **inputShape**: *number[]* 83 | 84 | *Inherited from [NeuralNetwork](neuralnetwork.md).[inputShape](neuralnetwork.md#protected-inputshape)* 85 | 86 | *Defined in [reimprove/networks.ts:55](https://github.com/DevSide/ReImproveJS/blob/2368b25/src/reimprove/networks.ts#L55)* 87 | 88 | ___ 89 | 90 | ## Accessors 91 | 92 | ### FlattenLayer 93 | 94 | ● **set FlattenLayer**(`layer`: [FlattenLayer](../interfaces/flattenlayer.md)): *void* 95 | 96 | *Defined in [reimprove/networks.ts:162](https://github.com/DevSide/ReImproveJS/blob/2368b25/src/reimprove/networks.ts#L162)* 97 | 98 | **Parameters:** 99 | 100 | Name | Type | 101 | ------ | ------ | 102 | `layer` | [FlattenLayer](../interfaces/flattenlayer.md) | 103 | 104 | **Returns:** *void* 105 | 106 | ___ 107 | 108 | ### InputShape 109 | 110 | ● **set InputShape**(`shape`: number[]): *void* 111 | 112 | *Inherited from [NeuralNetwork](neuralnetwork.md).[InputShape](neuralnetwork.md#inputshape)* 113 | 114 | *Defined in [reimprove/networks.ts:85](https://github.com/DevSide/ReImproveJS/blob/2368b25/src/reimprove/networks.ts#L85)* 115 | 116 | **Parameters:** 117 | 118 | Name | Type | 119 | ------ | ------ | 120 | `shape` | number[] | 121 | 122 | **Returns:** *void* 123 | 124 | ___ 125 | 126 | ## Methods 127 | 128 | ### addConvolutionalLayer 129 | 130 | ▸ **addConvolutionalLayer**(`layer`: number | 
[ConvolutionalNetworkLayer](../interfaces/convolutionalnetworklayer.md)): *void* 131 | 132 | *Defined in [reimprove/networks.ts:133](https://github.com/DevSide/ReImproveJS/blob/2368b25/src/reimprove/networks.ts#L133)* 133 | 134 | **Parameters:** 135 | 136 | Name | Type | 137 | ------ | ------ | 138 | `layer` | number \| [ConvolutionalNetworkLayer](../interfaces/convolutionalnetworklayer.md) | 139 | 140 | **Returns:** *void* 141 | 142 | ___ 143 | 144 | ### addConvolutionalLayers 145 | 146 | ▸ **addConvolutionalLayers**(`layers`: `Array`): *void* 147 | 148 | *Defined in [reimprove/networks.ts:146](https://github.com/DevSide/ReImproveJS/blob/2368b25/src/reimprove/networks.ts#L146)* 149 | 150 | **Parameters:** 151 | 152 | Name | Type | 153 | ------ | ------ | 154 | `layers` | `Array` | 155 | 156 | **Returns:** *void* 157 | 158 | ___ 159 | 160 | ### addMaxPooling2DLayer 161 | 162 | ▸ **addMaxPooling2DLayer**(`layer?`: [MaxPooling2DLayer](../interfaces/maxpooling2dlayer.md)): *void* 163 | 164 | *Defined in [reimprove/networks.ts:129](https://github.com/DevSide/ReImproveJS/blob/2368b25/src/reimprove/networks.ts#L129)* 165 | 166 | **Parameters:** 167 | 168 | Name | Type | 169 | ------ | ------ | 170 | `layer?` | [MaxPooling2DLayer](../interfaces/maxpooling2dlayer.md) | 171 | 172 | **Returns:** *void* 173 | 174 | ___ 175 | 176 | ### addNeuralNetworkLayer 177 | 178 | ▸ **addNeuralNetworkLayer**(`layer`: number | [NeuralNetworkLayer](../interfaces/neuralnetworklayer.md)): *void* 179 | 180 | *Inherited from [NeuralNetwork](neuralnetwork.md).[addNeuralNetworkLayer](neuralnetwork.md#addneuralnetworklayer)* 181 | 182 | *Defined in [reimprove/networks.ts:69](https://github.com/DevSide/ReImproveJS/blob/2368b25/src/reimprove/networks.ts#L69)* 183 | 184 | **Parameters:** 185 | 186 | Name | Type | 187 | ------ | ------ | 188 | `layer` | number \| [NeuralNetworkLayer](../interfaces/neuralnetworklayer.md) | 189 | 190 | **Returns:** *void* 191 | 192 | ___ 193 | 194 | ### 
addNeuralNetworkLayers 195 | 196 | ▸ **addNeuralNetworkLayers**(`layers`: `Array`): *void* 197 | 198 | *Inherited from [NeuralNetwork](neuralnetwork.md).[addNeuralNetworkLayers](neuralnetwork.md#addneuralnetworklayers)* 199 | 200 | *Defined in [reimprove/networks.ts:81](https://github.com/DevSide/ReImproveJS/blob/2368b25/src/reimprove/networks.ts#L81)* 201 | 202 | **Parameters:** 203 | 204 | Name | Type | 205 | ------ | ------ | 206 | `layers` | `Array` | 207 | 208 | **Returns:** *void* 209 | 210 | ___ 211 | 212 | ### createLayers 213 | 214 | ▸ **createLayers**(`includeInputShape`: boolean): *`Array`* 215 | 216 | *Overrides [NeuralNetwork](neuralnetwork.md).[createLayers](neuralnetwork.md#createlayers)* 217 | 218 | *Defined in [reimprove/networks.ts:150](https://github.com/DevSide/ReImproveJS/blob/2368b25/src/reimprove/networks.ts#L150)* 219 | 220 | **Parameters:** 221 | 222 | Name | Type | Default | 223 | ------ | ------ | ------ | 224 | `includeInputShape` | boolean | true | 225 | 226 | **Returns:** *`Array`* 227 | 228 | ___ 229 | 230 | ### getLayers 231 | 232 | ▸ **getLayers**(): *[Layer](../interfaces/layer.md)[]* 233 | 234 | *Overrides [NeuralNetwork](neuralnetwork.md).[getLayers](neuralnetwork.md#getlayers)* 235 | 236 | *Defined in [reimprove/networks.ts:166](https://github.com/DevSide/ReImproveJS/blob/2368b25/src/reimprove/networks.ts#L166)* 237 | 238 | **Returns:** *[Layer](../interfaces/layer.md)[]* 239 | 240 | ___ 241 | 242 | ## Object literals 243 | 244 | ### `Static` `Private` DEFAULT_CONV_LAYER 245 | 246 | ### ■ **DEFAULT_CONV_LAYER**: *object* 247 | 248 | *Defined in [reimprove/networks.ts:110](https://github.com/DevSide/ReImproveJS/blob/2368b25/src/reimprove/networks.ts#L110)* 249 | 250 | ### activation 251 | 252 | ● **activation**: *string* = "relu" 253 | 254 | *Defined in [reimprove/networks.ts:113](https://github.com/DevSide/ReImproveJS/blob/2368b25/src/reimprove/networks.ts#L113)* 255 | 256 | ### filters 257 | 258 | ● **filters**: *number* = 32 
259 | 260 | *Defined in [reimprove/networks.ts:111](https://github.com/DevSide/ReImproveJS/blob/2368b25/src/reimprove/networks.ts#L111)* 261 | 262 | ### kernelSize 263 | 264 | ● **kernelSize**: *number* = 3 265 | 266 | *Defined in [reimprove/networks.ts:112](https://github.com/DevSide/ReImproveJS/blob/2368b25/src/reimprove/networks.ts#L112)* 267 | 268 | ### type 269 | 270 | ● **type**: *"convolutional"* = "convolutional" 271 | 272 | *Defined in [reimprove/networks.ts:114](https://github.com/DevSide/ReImproveJS/blob/2368b25/src/reimprove/networks.ts#L114)* 273 | 274 | ___ 275 | 276 | ### `Static` `Private` DEFAULT_POOLING_LAYER 277 | 278 | ### ■ **DEFAULT_POOLING_LAYER**: *object* 279 | 280 | *Defined in [reimprove/networks.ts:117](https://github.com/DevSide/ReImproveJS/blob/2368b25/src/reimprove/networks.ts#L117)* 281 | 282 | ### poolSize 283 | 284 | ● **poolSize**: *number* = 2 285 | 286 | *Defined in [reimprove/networks.ts:118](https://github.com/DevSide/ReImproveJS/blob/2368b25/src/reimprove/networks.ts#L118)* 287 | 288 | ### strides 289 | 290 | ● **strides**: *null* = null 291 | 292 | *Defined in [reimprove/networks.ts:119](https://github.com/DevSide/ReImproveJS/blob/2368b25/src/reimprove/networks.ts#L119)* 293 | 294 | ### type 295 | 296 | ● **type**: *"maxpooling"* = "maxpooling" 297 | 298 | *Defined in [reimprove/networks.ts:120](https://github.com/DevSide/ReImproveJS/blob/2368b25/src/reimprove/networks.ts#L120)* 299 | 300 | ___ -------------------------------------------------------------------------------- /src/reimprove/academy.ts: -------------------------------------------------------------------------------- 1 | import {DQAgent} from "./algorithms/deepq/dqagent"; 2 | import {Teacher, TeacherTrackingInformation, TeachingConfig} from "./teacher"; 3 | import {v4} from 'uuid'; 4 | import {Model} from "./model"; 5 | import {LearningDataLogger} from "./misc/learning_data_logger"; 6 | import {DQAgentConfig} from "./algorithms/agent_config"; 7 | 8 | const 
DEFAULT_ACADEMY_CONFIG: AcademyConfig = { 9 | parentLogsElement: null, 10 | agentsLogs: false, 11 | memoryLogs: false 12 | }; 13 | 14 | /** 15 | * Academy configuration, used for logs. You need to say if you want to log agents and memory, and 16 | * give the parent
`<div>` element 17 | */ 18 | export interface AcademyConfig { 19 | /** Parent `<div>`
element, default to `null` */ 20 | parentLogsElement: HTMLElement; 21 | /** If agents logs should be displayed, default to `false` */ 22 | agentsLogs: boolean; 23 | /** If memory logs should be displayed, default to `false` */ 24 | memoryLogs: boolean; 25 | } 26 | 27 | /** 28 | * Input to give at each step of the Academy, where you specify the target teacher and its inputs. 29 | */ 30 | export interface AcademyStepInput { 31 | teacherName: string; 32 | agentsInput: number[]; 33 | } 34 | 35 | /** Configuration to build an agent */ 36 | export interface BuildAgentConfig { 37 | /** The agent cannot have no model. But multiple agents can share the same one */ 38 | model: Model; 39 | /** The agent configuration, defaulted if not present */ 40 | agentConfig?: DQAgentConfig; 41 | } 42 | 43 | /** 44 | * Class to interact with when creating the environment and updating it. 45 | */ 46 | export class Academy { 47 | 48 | private agents: Map; 49 | private teachers: Map; 50 | private assigments: Map; 51 | 52 | private logger: LearningDataLogger; 53 | private config: AcademyConfig; 54 | 55 | constructor(config?: AcademyConfig) { 56 | this.config = {...DEFAULT_ACADEMY_CONFIG, ...config}; 57 | this.agents = new Map(); 58 | this.teachers = new Map(); 59 | this.assigments = new Map(); 60 | 61 | if(this.config.parentLogsElement) { 62 | this.createLogger(this.config.parentLogsElement); 63 | } 64 | } 65 | 66 | addAgent(config: BuildAgentConfig, name?: string): string { 67 | let agent = new DQAgent(config.model, config.agentConfig, name); 68 | if (!agent.Name) 69 | agent.Name = v4(); 70 | 71 | this.agents.set(agent.Name, agent); 72 | 73 | return agent.Name; 74 | } 75 | 76 | addTeacher(config?: TeachingConfig, name?: string): string { 77 | let teacher = new Teacher(config, name); 78 | if (!teacher.Name) 79 | teacher.Name = v4(); 80 | 81 | this.teachers.set(teacher.Name, teacher); 82 | 83 | return teacher.Name; 84 | } 85 | 86 | assignTeacherToAgent(agentName: string, teacherName: string) { 
87 | if (!this.agents.has(agentName)) 88 | throw new Error("No such agent has been registered"); 89 | if (!this.teachers.has(teacherName)) 90 | throw new Error("No such teacher has been registered"); 91 | 92 | this.assigments.set(agentName, teacherName); 93 | this.teachers.get(teacherName).affectStudent(this.agents.get(agentName)); 94 | } 95 | 96 | /** 97 | * A step in the academy, giving the teachers their inputs, and propagating it to agents. Returns a [[Map]] where you 98 | * just have to pick for each agent's name its decision. At each step all the rewards are reset to 0. 99 | * @param {AcademyStepInput[] | AcademyStepInput} inputs You can give only one input as well as an array of inputs. 100 | * @returns {Promise>} 101 | */ 102 | async step(inputs: AcademyStepInput[] | AcademyStepInput): Promise> { 103 | let actions = new Map(); 104 | let finalInput = inputs instanceof Array ? inputs : [inputs]; 105 | for(let input of finalInput) { 106 | if (!this.teachers.has(input.teacherName)) { 107 | throw new Error("No teacher has name " + input.teacherName); 108 | } 109 | 110 | const agentsActions = await this.teachers.get(input.teacherName).teach(input.agentsInput); 111 | agentsActions.forEach((value, key) => { 112 | if (actions.has(key)) 113 | throw new Error("Dqagent " + key + " has already registered an action."); 114 | 115 | actions.set(key, value); 116 | }); 117 | } 118 | 119 | if (this.logger) 120 | this.logger.updateTables(true); 121 | 122 | return actions; 123 | } 124 | 125 | /** 126 | * Add a reward to an agent, given its name. Be careful to give a reward normalized between -1.0 and 1.0 for an optimal 127 | * learn. 128 | * @param {string} name 129 | * @param {number} reward 130 | */ 131 | addRewardToAgent(name: string, reward: number) { 132 | if (this.agents.has(name)) 133 | this.agents.get(name).addReward(reward); 134 | } 135 | 136 | /** 137 | * In case where you just want to clearly set the agent's current reward for this step. 
138 | * @param {string} name 139 | * @param {number} reward 140 | */ 141 | setRewardOfAgent(name: string, reward: number) { 142 | if (this.agents.has(name)) 143 | this.agents.get(name).setReward(reward); 144 | } 145 | 146 | /** 147 | * Callback which will be called each time the model's fit ends after the end of the lesson. 148 | * @param {string} teacherName The target teacher which will call the callback 149 | * @param {(teacher: string) => void} callback The callback, giving the teacher name 150 | * @constructor 151 | */ 152 | OnLearningLessonEnded(teacherName: string, callback: (teacher: string) => void) { 153 | if (this.teachers.has(teacherName)) 154 | this.teachers.get(teacherName).onLearningLessonEnded = callback; 155 | } 156 | 157 | /** 158 | * Callback called when a lesson is ended 159 | * @param {string} teacherName The target teacher which will call the callback 160 | * @param {(teacher: string, lessonNumber: number) => void} callback The callback, giving the teacher name and the index of the just ended lesson. 161 | * @constructor 162 | */ 163 | OnLessonEnded(teacherName: string, callback: (teacher: string, lessonNumber: number) => void) { 164 | if (this.teachers.has(teacherName)) 165 | this.teachers.get(teacherName).onLessonEnded = callback; 166 | } 167 | 168 | /** 169 | * Callback called when the teaching is ended 170 | * @param {string} teacherName The target teacher which will call the callback 171 | * @param {(teacher: string) => void} callback The callback, giving the teacher name 172 | * @constructor 173 | */ 174 | OnTeachingEnded(teacherName: string, callback: (teacher: string) => void) { 175 | if (this.teachers.has(teacherName)) 176 | this.teachers.get(teacherName).onTeachingEnded = callback; 177 | } 178 | 179 | /** 180 | * Function to reset everything from teachers and agents (resetting parameters of teachers, and resetting memory and parameters of agents).
181 | */ 182 | resetTeachersAndAgents() { 183 | this.teachers.forEach(t => t.reset()); 184 | this.agents.forEach(a => a.reset()); 185 | } 186 | 187 | /** 188 | * Function resetting everything in the academy, calling first [[resetTeachersAndAgents]], then cleaning everything concerning teachers and agents. 189 | */ 190 | reset() { 191 | this.resetTeachersAndAgents(); 192 | this.teachers.clear(); 193 | this.agents.clear(); 194 | } 195 | 196 | /** 197 | * Resets to 0 the current state of the lesson. It cannot forget 198 | * @param {string} teacherName 199 | */ 200 | resetTeacherLesson(teacherName: string) { 201 | this.teachers.get(teacherName).resetLesson(); 202 | } 203 | 204 | /** 205 | * Gives the list of teachers 206 | * @returns {string[]} 207 | * @constructor 208 | */ 209 | get Teachers() { 210 | return Array.from(this.teachers.keys()); 211 | } 212 | 213 | /** 214 | * Used for logs, returning the tracking informations of a teacher, see [[TeacherTrackingInformation]] 215 | * @param {string} name 216 | * @returns {TeacherTrackingInformation} 217 | */ 218 | getTeacherData(name: string): TeacherTrackingInformation { 219 | return this.teachers.get(name).getData(); 220 | } 221 | 222 | /** 223 | * If not given in the configuration options in the constructor, you can choose to create the logger here 224 | * @param {HTMLElement} parent 225 | */ 226 | createLogger(parent: HTMLElement): void { 227 | if (this.logger) this.logger.dispose(); 228 | this.config.parentLogsElement = parent; 229 | this.logger = new LearningDataLogger(parent, this); 230 | } 231 | 232 | /** 233 | * Method to toggle logs, taking an argument to toggle memory logs. 
234 | * @param {boolean} memory 235 | */ 236 | toggleLogs(memory = false): void { 237 | const status = this.config.agentsLogs; 238 | this.config.agentsLogs = !status; 239 | if(status) 240 | this.config.memoryLogs = memory; 241 | } 242 | 243 | toggleTeaching(teacher: string, toggle: boolean): void { 244 | if(toggle === true) 245 | this.teachers.get(teacher).startTeaching(); 246 | else 247 | this.teachers.get(teacher).stopTeaching(); 248 | } 249 | } -------------------------------------------------------------------------------- /docs/classes/teacher.md: -------------------------------------------------------------------------------- 1 | > ## [ReImproveJS](../README.md) 2 | 3 | [Globals](../globals.md) / [Teacher](teacher.md) / 4 | 5 | # Class: Teacher 6 | 7 | ## Hierarchy 8 | 9 | * **Teacher** 10 | 11 | ### Index 12 | 13 | #### Constructors 14 | 15 | * [constructor](teacher.md#constructor) 16 | 17 | #### Properties 18 | 19 | * [agents](teacher.md#agents) 20 | * [config](teacher.md#config) 21 | * [currentEpsilon](teacher.md#currentepsilon) 22 | * [currentLessonLength](teacher.md#currentlessonlength) 23 | * [lessonsTaught](teacher.md#lessonstaught) 24 | * [name](teacher.md#name) 25 | * [onLearningLessonEnded](teacher.md#onlearninglessonended) 26 | * [onLessonEnded](teacher.md#onlessonended) 27 | * [onTeachingEnded](teacher.md#onteachingended) 28 | * [state](teacher.md#state) 29 | 30 | #### Accessors 31 | 32 | * [Name](teacher.md#name) 33 | * [OnLearningLessonEnded](teacher.md#onlearninglessonended) 34 | * [OnLessonEnded](teacher.md#onlessonended) 35 | * [OnTeachingEnded](teacher.md#onteachingended) 36 | * [State](teacher.md#state) 37 | 38 | #### Methods 39 | 40 | * [affectStudent](teacher.md#affectstudent) 41 | * [getData](teacher.md#getdata) 42 | * [removeStudent](teacher.md#removestudent) 43 | * [reset](teacher.md#reset) 44 | * [resetLesson](teacher.md#resetlesson) 45 | * [start](teacher.md#start) 46 | * [startTeaching](teacher.md#startteaching) 47 | * 
[stop](teacher.md#stop) 48 | * [stopTeaching](teacher.md#stopteaching) 49 | * [teach](teacher.md#teach) 50 | * [updateParameters](teacher.md#updateparameters) 51 | 52 | ## Constructors 53 | 54 | ### constructor 55 | 56 | \+ **new Teacher**(`config?`: [TeachingConfig](../interfaces/teachingconfig.md), `name?`: string): *[Teacher](teacher.md)* 57 | 58 | *Defined in [reimprove/teacher.ts:60](https://github.com/DevSide/ReImproveJS/blob/2368b25/src/reimprove/teacher.ts#L60)* 59 | 60 | **Parameters:** 61 | 62 | Name | Type | 63 | ------ | ------ | 64 | `config?` | [TeachingConfig](../interfaces/teachingconfig.md) | 65 | `name?` | string | 66 | 67 | **Returns:** *[Teacher](teacher.md)* 68 | 69 | ___ 70 | 71 | ## Properties 72 | 73 | ### agents 74 | 75 | ● **agents**: *`Set`* 76 | 77 | *Defined in [reimprove/teacher.ts:51](https://github.com/DevSide/ReImproveJS/blob/2368b25/src/reimprove/teacher.ts#L51)* 78 | 79 | ___ 80 | 81 | ### config 82 | 83 | ● **config**: *[TeachingConfig](../interfaces/teachingconfig.md)* 84 | 85 | *Defined in [reimprove/teacher.ts:48](https://github.com/DevSide/ReImproveJS/blob/2368b25/src/reimprove/teacher.ts#L48)* 86 | 87 | ___ 88 | 89 | ### currentEpsilon 90 | 91 | ● **currentEpsilon**: *number* 92 | 93 | *Defined in [reimprove/teacher.ts:60](https://github.com/DevSide/ReImproveJS/blob/2368b25/src/reimprove/teacher.ts#L60)* 94 | 95 | ___ 96 | 97 | ### currentLessonLength 98 | 99 | ● **currentLessonLength**: *number* 100 | 101 | *Defined in [reimprove/teacher.ts:53](https://github.com/DevSide/ReImproveJS/blob/2368b25/src/reimprove/teacher.ts#L53)* 102 | 103 | ___ 104 | 105 | ### lessonsTaught 106 | 107 | ● **lessonsTaught**: *number* 108 | 109 | *Defined in [reimprove/teacher.ts:54](https://github.com/DevSide/ReImproveJS/blob/2368b25/src/reimprove/teacher.ts#L54)* 110 | 111 | ___ 112 | 113 | ### name 114 | 115 | ● **name**: *string* 116 | 117 | *Defined in 
[reimprove/teacher.ts:47](https://github.com/DevSide/ReImproveJS/blob/2368b25/src/reimprove/teacher.ts#L47)* 118 | 119 | ___ 120 | 121 | ### onLearningLessonEnded 122 | 123 | ● **onLearningLessonEnded**: *function* 124 | 125 | *Defined in [reimprove/teacher.ts:56](https://github.com/DevSide/ReImproveJS/blob/2368b25/src/reimprove/teacher.ts#L56)* 126 | 127 | #### Type declaration: 128 | 129 | ▸ (`teacher`: string): *void* 130 | 131 | **Parameters:** 132 | 133 | Name | Type | 134 | ------ | ------ | 135 | `teacher` | string | 136 | 137 | ___ 138 | 139 | ### onLessonEnded 140 | 141 | ● **onLessonEnded**: *function* 142 | 143 | *Defined in [reimprove/teacher.ts:57](https://github.com/DevSide/ReImproveJS/blob/2368b25/src/reimprove/teacher.ts#L57)* 144 | 145 | #### Type declaration: 146 | 147 | ▸ (`teacher`: string, `lessonNumber`: number): *void* 148 | 149 | **Parameters:** 150 | 151 | Name | Type | 152 | ------ | ------ | 153 | `teacher` | string | 154 | `lessonNumber` | number | 155 | 156 | ___ 157 | 158 | ### onTeachingEnded 159 | 160 | ● **onTeachingEnded**: *function* 161 | 162 | *Defined in [reimprove/teacher.ts:58](https://github.com/DevSide/ReImproveJS/blob/2368b25/src/reimprove/teacher.ts#L58)* 163 | 164 | #### Type declaration: 165 | 166 | ▸ (`teacher`: string): *void* 167 | 168 | **Parameters:** 169 | 170 | Name | Type | 171 | ------ | ------ | 172 | `teacher` | string | 173 | 174 | ___ 175 | 176 | ### state 177 | 178 | ● **state**: *[TeachingState](../enums/teachingstate.md)* 179 | 180 | *Defined in [reimprove/teacher.ts:49](https://github.com/DevSide/ReImproveJS/blob/2368b25/src/reimprove/teacher.ts#L49)* 181 | 182 | ___ 183 | 184 | ## Accessors 185 | 186 | ### Name 187 | 188 | ● **get Name**(): *string* 189 | 190 | *Defined in [reimprove/teacher.ts:210](https://github.com/DevSide/ReImproveJS/blob/2368b25/src/reimprove/teacher.ts#L210)* 191 | 192 | **Returns:** *string* 193 | 194 | ● **set Name**(`name`: string): *void* 195 | 196 | *Defined in 
[reimprove/teacher.ts:206](https://github.com/DevSide/ReImproveJS/blob/2368b25/src/reimprove/teacher.ts#L206)* 197 | 198 | **Parameters:** 199 | 200 | Name | Type | 201 | ------ | ------ | 202 | `name` | string | 203 | 204 | **Returns:** *void* 205 | 206 | ___ 207 | 208 | ### OnLearningLessonEnded 209 | 210 | ● **set OnLearningLessonEnded**(`callback`: function): *void* 211 | 212 | *Defined in [reimprove/teacher.ts:194](https://github.com/DevSide/ReImproveJS/blob/2368b25/src/reimprove/teacher.ts#L194)* 213 | 214 | **Parameters:** 215 | 216 | ■` callback`: *function* 217 | 218 | ▸ (`teacher`: string): *void* 219 | 220 | **Parameters:** 221 | 222 | Name | Type | 223 | ------ | ------ | 224 | `teacher` | string | 225 | 226 | **Returns:** *void* 227 | 228 | ___ 229 | 230 | ### OnLessonEnded 231 | 232 | ● **set OnLessonEnded**(`callback`: function): *void* 233 | 234 | *Defined in [reimprove/teacher.ts:198](https://github.com/DevSide/ReImproveJS/blob/2368b25/src/reimprove/teacher.ts#L198)* 235 | 236 | **Parameters:** 237 | 238 | ■` callback`: *function* 239 | 240 | ▸ (`teacher`: string, `lessonNumber`: number): *void* 241 | 242 | **Parameters:** 243 | 244 | Name | Type | 245 | ------ | ------ | 246 | `teacher` | string | 247 | `lessonNumber` | number | 248 | 249 | **Returns:** *void* 250 | 251 | ___ 252 | 253 | ### OnTeachingEnded 254 | 255 | ● **set OnTeachingEnded**(`callback`: function): *void* 256 | 257 | *Defined in [reimprove/teacher.ts:202](https://github.com/DevSide/ReImproveJS/blob/2368b25/src/reimprove/teacher.ts#L202)* 258 | 259 | **Parameters:** 260 | 261 | ■` callback`: *function* 262 | 263 | ▸ (`teacher`: string): *void* 264 | 265 | **Parameters:** 266 | 267 | Name | Type | 268 | ------ | ------ | 269 | `teacher` | string | 270 | 271 | **Returns:** *void* 272 | 273 | ___ 274 | 275 | ### State 276 | 277 | ● **get State**(): *[TeachingState](../enums/teachingstate.md)* 278 | 279 | *Defined in 
[reimprove/teacher.ts:214](https://github.com/DevSide/ReImproveJS/blob/2368b25/src/reimprove/teacher.ts#L214)* 280 | 281 | **Returns:** *[TeachingState](../enums/teachingstate.md)* 282 | 283 | ___ 284 | 285 | ## Methods 286 | 287 | ### affectStudent 288 | 289 | ▸ **affectStudent**(`agent`: [DQAgent](dqagent.md)): *void* 290 | 291 | *Defined in [reimprove/teacher.ts:76](https://github.com/DevSide/ReImproveJS/blob/2368b25/src/reimprove/teacher.ts#L76)* 292 | 293 | **Parameters:** 294 | 295 | Name | Type | 296 | ------ | ------ | 297 | `agent` | [DQAgent](dqagent.md) | 298 | 299 | **Returns:** *void* 300 | 301 | ___ 302 | 303 | ### getData 304 | 305 | ▸ **getData**(): *[TeacherTrackingInformation](../interfaces/teachertrackinginformation.md)* 306 | 307 | *Defined in [reimprove/teacher.ts:165](https://github.com/DevSide/ReImproveJS/blob/2368b25/src/reimprove/teacher.ts#L165)* 308 | 309 | **Returns:** *[TeacherTrackingInformation](../interfaces/teachertrackinginformation.md)* 310 | 311 | ___ 312 | 313 | ### removeStudent 314 | 315 | ▸ **removeStudent**(`agent`: [DQAgent](dqagent.md)): *boolean* 316 | 317 | *Defined in [reimprove/teacher.ts:80](https://github.com/DevSide/ReImproveJS/blob/2368b25/src/reimprove/teacher.ts#L80)* 318 | 319 | **Parameters:** 320 | 321 | Name | Type | 322 | ------ | ------ | 323 | `agent` | [DQAgent](dqagent.md) | 324 | 325 | **Returns:** *boolean* 326 | 327 | ___ 328 | 329 | ### reset 330 | 331 | ▸ **reset**(): *void* 332 | 333 | *Defined in [reimprove/teacher.ts:184](https://github.com/DevSide/ReImproveJS/blob/2368b25/src/reimprove/teacher.ts#L184)* 334 | 335 | **Returns:** *void* 336 | 337 | ___ 338 | 339 | ### resetLesson 340 | 341 | ▸ **resetLesson**(): *void* 342 | 343 | *Defined in [reimprove/teacher.ts:179](https://github.com/DevSide/ReImproveJS/blob/2368b25/src/reimprove/teacher.ts#L179)* 344 | 345 | **Returns:** *void* 346 | 347 | ___ 348 | 349 | ### start 350 | 351 | ▸ **start**(): *void* 352 | 353 | *Defined in 
[reimprove/teacher.ts:84](https://github.com/DevSide/ReImproveJS/blob/2368b25/src/reimprove/teacher.ts#L84)* 354 | 355 | **Returns:** *void* 356 | 357 | ___ 358 | 359 | ### startTeaching 360 | 361 | ▸ **startTeaching**(): *void* 362 | 363 | *Defined in [reimprove/teacher.ts:151](https://github.com/DevSide/ReImproveJS/blob/2368b25/src/reimprove/teacher.ts#L151)* 364 | 365 | **Returns:** *void* 366 | 367 | ___ 368 | 369 | ### stop 370 | 371 | ▸ **stop**(): *void* 372 | 373 | *Defined in [reimprove/teacher.ts:190](https://github.com/DevSide/ReImproveJS/blob/2368b25/src/reimprove/teacher.ts#L190)* 374 | 375 | **Returns:** *void* 376 | 377 | ___ 378 | 379 | ### stopTeaching 380 | 381 | ▸ **stopTeaching**(): *void* 382 | 383 | *Defined in [reimprove/teacher.ts:147](https://github.com/DevSide/ReImproveJS/blob/2368b25/src/reimprove/teacher.ts#L147)* 384 | 385 | **Returns:** *void* 386 | 387 | ___ 388 | 389 | ### teach 390 | 391 | ▸ **teach**(`inputs`: number[]): *`Promise>`* 392 | 393 | *Defined in [reimprove/teacher.ts:90](https://github.com/DevSide/ReImproveJS/blob/2368b25/src/reimprove/teacher.ts#L90)* 394 | 395 | **Parameters:** 396 | 397 | Name | Type | 398 | ------ | ------ | 399 | `inputs` | number[] | 400 | 401 | **Returns:** *`Promise>`* 402 | 403 | ___ 404 | 405 | ### updateParameters 406 | 407 | ▸ **updateParameters**(): *void* 408 | 409 | *Defined in [reimprove/teacher.ts:156](https://github.com/DevSide/ReImproveJS/blob/2368b25/src/reimprove/teacher.ts#L156)* 410 | 411 | **Returns:** *void* 412 | 413 | ___ --------------------------------------------------------------------------------