├── .gitignore
├── LICENSE
├── README.md
├── images
│   ├── DQN.png
│   ├── DQN2013.png
│   ├── DQN2015.png
│   ├── DQN2015_.png
│   ├── FrozenLake.png
│   ├── MCOnPolicy.png
│   ├── QLearning.png
│   ├── RL CrossEntroy Demo.png
│   ├── Reinforcement.png
│   └── Sarsa.png
├── resources
│   ├── Mindmap.emmx
│   ├── iris-test.txt
│   └── iris-train.txt
└── src
    ├── DeepSharp.Core
    │   ├── DeepSharp.Core.csproj
    │   └── Trainer.cs
    ├── DeepSharp.Dataset
    │   ├── DataLoaders
    │   │   ├── DataLoader.cs
    │   │   ├── DataLoaderConfig.cs
    │   │   └── InfiniteDataLoader.cs
    │   ├── Datasets
    │   │   ├── DataView.cs
    │   │   ├── DataViewPair.cs
    │   │   ├── Dataset.cs
    │   │   ├── StreamHeaderAttribute.cs
    │   │   └── StreamHeaderRange.cs
    │   ├── DeepSharp.Dataset.csproj
    │   ├── DeepSharp.Dataset.csproj.DotSettings
    │   └── Usings.cs
    ├── DeepSharp.ObjectDetection
    │   ├── DeepSharp.ObjectDetection.csproj
    │   └── Yolo.cs
    ├── DeepSharp.RL
    │   ├── ActionSelectors
    │   │   ├── ActionSelector.cs
    │   │   ├── ArgmaxActionSelector.cs
    │   │   ├── EpsilonActionSelector.cs
    │   │   └── ProbActionSelector.cs
    │   ├── Agents
    │   │   ├── Agent.cs
    │   │   ├── LearnOutCome.cs
    │   │   ├── Models
    │   │   │   ├── QTable.cs
    │   │   │   ├── RewardKey.cs
    │   │   │   ├── TrasitKey.cs
    │   │   │   └── VTable.cs
    │   │   ├── Nets
    │   │   │   ├── DQNNet.cs
    │   │   │   ├── Net.cs
    │   │   │   └── PGN.cs
    │   │   ├── Others
    │   │   │   ├── CrossEntropy.cs
    │   │   │   └── CrossEntropyExt.cs
    │   │   ├── PolicyAgengt.cs
    │   │   ├── PolicyBased
    │   │   │   ├── A2C.cs
    │   │   │   ├── A3C.cs
    │   │   │   ├── ActorCritic.cs
    │   │   │   ├── Reinforce.cs
    │   │   │   └── ReinforceOriginal.cs
    │   │   ├── ValueAgent.cs
    │   │   └── ValueBased
    │   │       ├── DeepQN
    │   │       │   ├── CategoricalDQN.cs
    │   │       │   ├── DQN.cs
    │   │       │   ├── DoubleDQN.cs
    │   │       │   ├── DuelingDQN.cs
    │   │       │   ├── NDQN.cs
    │   │       │   └── NoisyDQN.cs
    │   │       ├── DynamicPlan
    │   │       │   ├── PIDiscountR.cs
    │   │       │   ├── PITStepR.cs
    │   │       │   ├── PolicyIteration.cs
    │   │       │   ├── VIDiscountR.cs
    │   │       │   ├── VITStepR.cs
    │   │       │   └── ValueIterate.cs
    │   │       ├── MonteCarlo
    │   │       │   ├── MonteCarloOffPolicy.cs
    │   │       │   └── MonteCarloOnPolicy.cs
    │   │       ├── TemporalDifference
    │   │       │   ├── QLearning.cs
    │   │       │   └── SARSA.cs
    │   │       └── ValueIteration.cs
    │   ├── DeepSharp.RL.csproj
    │   ├── DeepSharp.RL.csproj.DotSettings
    │   ├── Enumerates
    │   │   └── PlayMode.cs
    │   ├── Environs
    │   │   ├── Act.cs
    │   │   ├── Environ.cs
    │   │   ├── Episode.cs
    │   │   ├── FrozenLake
    │   │   │   ├── FrozenLake.cs
    │   │   │   ├── LakeRole.cs
    │   │   │   └── LakeUnit.cs
    │   │   ├── KArmedBandit
    │   │   │   ├── Bandit.cs
    │   │   │   └── KArmedBandit.cs
    │   │   ├── Observation.cs
    │   │   ├── Reward.cs
    │   │   ├── Space.cs
    │   │   ├── Spaces
    │   │   │   ├── Binary.cs
    │   │   │   ├── Box.cs
    │   │   │   ├── DigitalSpace.cs
    │   │   │   ├── Disperse.cs
    │   │   │   ├── MultiBinary.cs
    │   │   │   └── MultiDisperse.cs
    │   │   ├── Step.cs
    │   │   └── Wappers
    │   │       ├── EnvironWarpper.cs
    │   │       └── MaxAndSkipEnv.cs
    │   ├── ExpReplays
    │   │   ├── EpisodeExpReplay.cs
    │   │   ├── ExpReplay.cs
    │   │   ├── ExperienceCase.cs
    │   │   ├── PrioritizedExpReplay.cs
    │   │   └── UniformExpReplay.cs
    │   ├── Trainers
    │   │   ├── RLTrainOption.cs
    │   │   ├── RLTrainer.cs
    │   │   └── TrainerCallBack.cs
    │   └── Usings.cs
    ├── DeepSharp.Utility
    │   ├── Converters
    │   │   └── Convert.cs
    │   ├── DeepSharp.Utility.csproj
    │   ├── Operations
    │   │   └── OpMat.cs
    │   ├── TensorEqualityCompare.cs
    │   └── Usings.cs
    ├── DeepSharp.sln
    ├── DeepSharp.sln.DotSettings
    ├── RLConsole
    │   ├── Program.cs
    │   ├── RLConsole.csproj
    │   └── Utility.cs
    ├── Reinforcement Learning.md
    └── TorchSharpTest
        ├── AbstractTest.cs
        ├── DemoTest
        │   ├── DemoNet.cs
        │   └── IrisTest.cs
        ├── LossTest
        │   └── LossTest.cs
        ├── RLTest
        │   ├── AgentTest.cs
        │   ├── EnvironTest
        │   │   ├── FrozenLakeTest.cs
        │   │   └── KArmedBanditTest.cs
        │   ├── ModelTest
        │   │   ├── ActionSelectorTest.cs
        │   │   ├── BasicTest.cs
        │   │   ├── PolicyNetTest.cs
        │   │   ├── QTableTest.cs
        │   │   ├── SpaceTest.cs
        │   │   └── VTableTest.cs
        │   ├── PolicyBasedTest
        │   │   ├── ActorCriticTest.cs
        │   │   └── ReinforceTest.cs
        │   ├── TrainerTest
        │   │   └── RLTrainTest.cs
        │   └── ValueBasedTest
        │       ├── DQNTest.cs
        │       ├── MonteCarloTest.cs
        │       ├── QLearningTest.cs
        │       └── SARSATest.cs
        ├── SampleDataset
        │   ├── Iris.cs
        │   └── IrisOneHot.cs
        ├── TorchSharpTest.csproj
        ├── TorchSharpTest.csproj.DotSettings
        ├── TorchTests
        │   ├── DataSetTest.cs
        │   ├── ModuleTest.cs
        │   ├── SaveLoadTest.cs
        │   └── TensorTest.cs
        └── Usings.cs
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2023 Xin.Pu
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/images/DQN.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xin-pu/DeepSharp/aff6662bd8bb3a30af0340efc3848aa30da8c65a/images/DQN.png
--------------------------------------------------------------------------------
/images/DQN2013.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xin-pu/DeepSharp/aff6662bd8bb3a30af0340efc3848aa30da8c65a/images/DQN2013.png
--------------------------------------------------------------------------------
/images/DQN2015.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xin-pu/DeepSharp/aff6662bd8bb3a30af0340efc3848aa30da8c65a/images/DQN2015.png
--------------------------------------------------------------------------------
/images/DQN2015_.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xin-pu/DeepSharp/aff6662bd8bb3a30af0340efc3848aa30da8c65a/images/DQN2015_.png
--------------------------------------------------------------------------------
/images/FrozenLake.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xin-pu/DeepSharp/aff6662bd8bb3a30af0340efc3848aa30da8c65a/images/FrozenLake.png
--------------------------------------------------------------------------------
/images/MCOnPolicy.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xin-pu/DeepSharp/aff6662bd8bb3a30af0340efc3848aa30da8c65a/images/MCOnPolicy.png
--------------------------------------------------------------------------------
/images/QLearning.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xin-pu/DeepSharp/aff6662bd8bb3a30af0340efc3848aa30da8c65a/images/QLearning.png
--------------------------------------------------------------------------------
/images/RL CrossEntroy Demo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xin-pu/DeepSharp/aff6662bd8bb3a30af0340efc3848aa30da8c65a/images/RL CrossEntroy Demo.png
--------------------------------------------------------------------------------
/images/Reinforcement.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xin-pu/DeepSharp/aff6662bd8bb3a30af0340efc3848aa30da8c65a/images/Reinforcement.png
--------------------------------------------------------------------------------
/images/Sarsa.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xin-pu/DeepSharp/aff6662bd8bb3a30af0340efc3848aa30da8c65a/images/Sarsa.png
--------------------------------------------------------------------------------
/resources/Mindmap.emmx:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xin-pu/DeepSharp/aff6662bd8bb3a30af0340efc3848aa30da8c65a/resources/Mindmap.emmx
--------------------------------------------------------------------------------
/resources/iris-test.txt:
--------------------------------------------------------------------------------
1 | #Label Sepal length Sepal width Petal length Petal width
2 | 0 5.1 3.5 1.4 0.2
3 | 0 4.9 3.0 1.4 0.2
4 | 0 4.7 3.2 1.3 0.2
5 | 0 4.6 3.1 1.5 0.2
6 | 0 5.0 3.6 1.4 0.2
7 | 0 5.4 3.9 1.7 0.4
8 | 0 4.6 3.4 1.4 0.3
9 | 0 5.0 3.4 1.5 0.2
10 | 0 4.4 2.9 1.4 0.2
11 | 0 4.9 3.1 1.5 0.1
12 | 1 7.0 3.2 4.7 1.4
13 | 1 6.4 3.2 4.5 1.5
14 | 1 6.9 3.1 4.9 1.5
15 | 1 5.5 2.3 4.0 1.3
16 | 1 6.5 2.8 4.6 1.5
17 | 1 5.7 2.8 4.5 1.3
18 | 1 6.3 3.3 4.7 1.6
19 | 1 4.9 2.4 3.3 1.0
20 | 1 6.6 2.9 4.6 1.3
21 | 1 5.2 2.7 3.9 1.4
22 | 2 6.3 3.3 6.0 2.5
23 | 2 5.8 2.7 5.1 1.9
24 | 2 7.1 3.0 5.9 2.1
25 | 2 6.3 2.9 5.6 1.8
26 | 2 6.5 3.0 5.8 2.2
27 | 2 7.6 3.0 6.6 2.1
28 | 2 4.9 2.5 4.5 1.7
29 | 2 7.3 2.9 6.3 1.8
30 | 2 6.7 2.5 5.8 1.8
31 | 2 7.2 3.6 6.1 2.5
32 |
--------------------------------------------------------------------------------
/resources/iris-train.txt:
--------------------------------------------------------------------------------
1 | #Label Sepal length Sepal width Petal length Petal width
2 | 0 5.4 3.7 1.5 0.2
3 | 0 4.8 3.4 1.6 0.2
4 | 0 4.8 3.0 1.4 0.1
5 | 0 4.3 3.0 1.1 0.1
6 | 0 5.8 4.0 1.2 0.2
7 | 0 5.7 4.4 1.5 0.4
8 | 0 5.4 3.9 1.3 0.4
9 | 0 5.1 3.5 1.4 0.3
10 | 0 5.7 3.8 1.7 0.3
11 | 0 5.1 3.8 1.5 0.3
12 | 0 5.4 3.4 1.7 0.2
13 | 0 5.1 3.7 1.5 0.4
14 | 0 4.6 3.6 1.0 0.2
15 | 0 5.1 3.3 1.7 0.5
16 | 0 4.8 3.4 1.9 0.2
17 | 0 5.0 3.0 1.6 0.2
18 | 0 5.0 3.4 1.6 0.4
19 | 0 5.2 3.5 1.5 0.2
20 | 0 5.2 3.4 1.4 0.2
21 | 0 4.7 3.2 1.6 0.2
22 | 0 4.8 3.1 1.6 0.2
23 | 0 5.4 3.4 1.5 0.4
24 | 0 5.2 4.1 1.5 0.1
25 | 0 5.5 4.2 1.4 0.2
26 | 0 4.9 3.1 1.5 0.1
27 | 0 5.0 3.2 1.2 0.2
28 | 0 5.5 3.5 1.3 0.2
29 | 0 4.9 3.1 1.5 0.1
30 | 0 4.4 3.0 1.3 0.2
31 | 0 5.1 3.4 1.5 0.2
32 | 0 5.0 3.5 1.3 0.3
33 | 0 4.5 2.3 1.3 0.3
34 | 0 4.4 3.2 1.3 0.2
35 | 0 5.0 3.5 1.6 0.6
36 | 0 5.1 3.8 1.9 0.4
37 | 0 4.8 3.0 1.4 0.3
38 | 0 5.1 3.8 1.6 0.2
39 | 0 4.6 3.2 1.4 0.2
40 | 0 5.3 3.7 1.5 0.2
41 | 0 5.0 3.3 1.4 0.2
42 | 1 5.0 2.0 3.5 1.0
43 | 1 5.9 3.0 4.2 1.5
44 | 1 6.0 2.2 4.0 1.0
45 | 1 6.1 2.9 4.7 1.4
46 | 1 5.6 2.9 3.6 1.3
47 | 1 6.7 3.1 4.4 1.4
48 | 1 5.6 3.0 4.5 1.5
49 | 1 5.8 2.7 4.1 1.0
50 | 1 6.2 2.2 4.5 1.5
51 | 1 5.6 2.5 3.9 1.1
52 | 1 5.9 3.2 4.8 1.8
53 | 1 6.1 2.8 4.0 1.3
54 | 1 6.3 2.5 4.9 1.5
55 | 1 6.1 2.8 4.7 1.2
56 | 1 6.4 2.9 4.3 1.3
57 | 1 6.6 3.0 4.4 1.4
58 | 1 6.8 2.8 4.8 1.4
59 | 1 6.7 3.0 5.0 1.7
60 | 1 6.0 2.9 4.5 1.5
61 | 1 5.7 2.6 3.5 1.0
62 | 1 5.5 2.4 3.8 1.1
63 | 1 5.5 2.4 3.7 1.0
64 | 1 5.8 2.7 3.9 1.2
65 | 1 6.0 2.7 5.1 1.6
66 | 1 5.4 3.0 4.5 1.5
67 | 1 6.0 3.4 4.5 1.6
68 | 1 6.7 3.1 4.7 1.5
69 | 1 6.3 2.3 4.4 1.3
70 | 1 5.6 3.0 4.1 1.3
71 | 1 5.5 2.5 4.0 1.3
72 | 1 5.5 2.6 4.4 1.2
73 | 1 6.1 3.0 4.6 1.4
74 | 1 5.8 2.6 4.0 1.2
75 | 1 5.0 2.3 3.3 1.0
76 | 1 5.6 2.7 4.2 1.3
77 | 1 5.7 3.0 4.2 1.2
78 | 1 5.7 2.9 4.2 1.3
79 | 1 6.2 2.9 4.3 1.3
80 | 1 5.1 2.5 3.0 1.1
81 | 1 5.7 2.8 4.1 1.3
82 | 2 6.5 3.2 5.1 2.0
83 | 2 6.4 2.7 5.3 1.9
84 | 2 6.8 3.0 5.5 2.1
85 | 2 5.7 2.5 5.0 2.0
86 | 2 5.8 2.8 5.1 2.4
87 | 2 6.4 3.2 5.3 2.3
88 | 2 6.5 3.0 5.5 1.8
89 | 2 7.7 3.8 6.7 2.2
90 | 2 7.7 2.6 6.9 2.3
91 | 2 6.0 2.2 5.0 1.5
92 | 2 6.9 3.2 5.7 2.3
93 | 2 5.6 2.8 4.9 2.0
94 | 2 7.7 2.8 6.7 2.0
95 | 2 6.3 2.7 4.9 1.8
96 | 2 6.7 3.3 5.7 2.1
97 | 2 7.2 3.2 6.0 1.8
98 | 2 6.2 2.8 4.8 1.8
99 | 2 6.1 3.0 4.9 1.8
100 | 2 6.4 2.8 5.6 2.1
101 | 2 7.2 3.0 5.8 1.6
102 | 2 7.4 2.8 6.1 1.9
103 | 2 7.9 3.8 6.4 2.0
104 | 2 6.4 2.8 5.6 2.2
105 | 2 6.3 2.8 5.1 1.5
106 | 2 6.1 2.6 5.6 1.4
107 | 2 7.7 3.0 6.1 2.3
108 | 2 6.3 3.4 5.6 2.4
109 | 2 6.4 3.1 5.5 1.8
110 | 2 6.0 3.0 4.8 1.8
111 | 2 6.9 3.1 5.4 2.1
112 | 2 6.7 3.1 5.6 2.4
113 | 2 6.9 3.1 5.1 2.3
114 | 2 5.8 2.7 5.1 1.9
115 | 2 6.8 3.2 5.9 2.3
116 | 2 6.7 3.3 5.7 2.5
117 | 2 6.7 3.0 5.2 2.3
118 | 2 6.3 2.5 5.0 1.9
119 | 2 6.5 3.0 5.2 2.0
120 | 2 6.2 3.4 5.4 2.3
121 | 2 5.9 3.0 5.1 1.8
122 |
--------------------------------------------------------------------------------
/src/DeepSharp.Core/DeepSharp.Core.csproj:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 | net7.0
5 | enable
6 | enable
7 |
8 |
9 |
10 |
11 |
12 |
13 |
14 |
--------------------------------------------------------------------------------
/src/DeepSharp.Core/Trainer.cs:
--------------------------------------------------------------------------------
1 | namespace DeepSharp.Core
2 | {
3 | /// <summary>
4 | ///     Trainer to simplify deep learning
5 | /// </summary>
6 | public abstract class Trainer
7 | {
8 | }
9 | }
--------------------------------------------------------------------------------
/src/DeepSharp.Dataset/DataLoaders/DataLoader.cs:
--------------------------------------------------------------------------------
1 | namespace DeepSharp.Dataset
2 | {
3 | /// <summary>
4 | ///     Basic DataLoader inherited from torch.utils.data.DataLoader
5 | /// </summary>
6 | /// <typeparam name="T"></typeparam>
7 | public class DataLoader<T> : torch.utils.data.DataLoader<T, DataViewPair>
8 | where T : DataView
9 | {
10 | public DataLoader(Dataset<T> dataset, DataLoaderConfig config)
11 | : base(dataset, config.BatchSize, CollateFunc, config.Shuffle, config.Device, config.Seed, config.NumWorker,
12 | config.DropLast)
13 | {
14 | }
15 |
16 | public static DataViewPair CollateFunc(IEnumerable<T> dataViews, torch.Device device)
17 | {
18 | var views = dataViews.ToList();
19 | var features = views.Select(a => a.GetFeatures()).ToList();
20 | var labels = views.Select(a => a.GetLabels()).ToList();
21 | var result = new DataViewPair(labels, features).To(device);
22 | return result;
23 | }
24 |
25 | public async IAsyncEnumerable<DataViewPair> GetBatchSample()
26 | {
27 | using var enumerator = GetEnumerator();
28 | while (enumerator.MoveNext()) yield return enumerator.Current;
29 | await Task.CompletedTask;
30 | }
31 | }
32 | }
--------------------------------------------------------------------------------
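A minimal usage sketch for the loader above (IrisData is an illustrative DataView subclass and the file path is an assumption; the real classes live in the test project):

    var dataset = new Dataset<IrisData>("resources/iris-train.txt");
    var config = new DataLoaderConfig {BatchSize = 8, Device = new torch.Device(DeviceType.CPU)};
    var loader = new DataLoader<IrisData>(dataset, config);

    await foreach (var batch in loader.GetBatchSample())
    {
        var features = batch.Features; // stacked feature tensor [BatchSize, ...]
        var labels = batch.Labels;     // stacked label tensor
    }
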
/src/DeepSharp.Dataset/DataLoaders/DataLoaderConfig.cs:
--------------------------------------------------------------------------------
1 | namespace DeepSharp.Dataset
2 | {
3 | public struct DataLoaderConfig
4 | {
5 | public DataLoaderConfig()
6 | {
7 | Seed = null;
8 | }
9 |
10 | public int BatchSize { set; get; } = 4;
11 | public bool Shuffle { set; get; } = true;
12 | public bool DropLast { set; get; } = true;
13 | public int NumWorker { set; get; } = 1;
14 | public int? Seed { set; get; }
15 | public torch.Device Device { set; get; } = new(DeviceType.CUDA);
16 | }
17 | }
--------------------------------------------------------------------------------
/src/DeepSharp.Dataset/DataLoaders/InfiniteDataLoader.cs:
--------------------------------------------------------------------------------
1 | namespace DeepSharp.Dataset
2 | {
3 | /// <summary>
4 | ///     Infinite DataLoader inherited from DataLoader
5 | /// </summary>
6 | /// <typeparam name="T"></typeparam>
7 | public class InfiniteDataLoader<T> : DataLoader<T>
8 | where T : DataView
9 |
10 | {
11 | public InfiniteDataLoader(Dataset dataset, DataLoaderConfig config) : base(dataset, config)
12 | {
13 | IEnumerator = GetEnumerator();
14 | }
15 |
16 | protected IEnumerator<DataViewPair> IEnumerator { set; get; }
17 |
18 | public async IAsyncEnumerable<DataViewPair> GetBatchSample(int sample)
19 | {
20 | var i = 0;
21 | while (i++ < sample)
22 | if (IEnumerator.MoveNext())
23 | {
24 | yield return IEnumerator.Current;
25 | }
26 | else
27 | {
28 | IEnumerator.Reset();
29 | IEnumerator.MoveNext();
30 | yield return IEnumerator.Current;
31 | }
32 |
33 | await Task.CompletedTask;
34 | }
35 | }
36 | }
--------------------------------------------------------------------------------
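Unlike the base loader, the infinite variant resets its enumerator whenever the data is exhausted, so a fixed number of batches can be requested regardless of dataset size. A short sketch (same illustrative IrisData as above):

    var loader = new InfiniteDataLoader<IrisData>(dataset, config);
    await foreach (var batch in loader.GetBatchSample(1000))
    {
        // consume batch.Features / batch.Labels; after the last sample the enumerator restarts
    }
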
/src/DeepSharp.Dataset/Datasets/DataView.cs:
--------------------------------------------------------------------------------
1 | namespace DeepSharp.Dataset
2 | {
3 | public abstract class DataView
4 | {
5 | ///
6 | /// Get the features with tensor format.
7 | ///
8 | ///
9 | public abstract torch.Tensor GetFeatures();
10 |
11 | ///
12 | /// Get the labels with tensor format.
13 | ///
14 | ///
15 | public abstract torch.Tensor GetLabels();
16 |
17 |
18 | /// <summary>
19 | ///     Convert a batch of DataViews to a single DataViewPair
20 | /// </summary>
21 | /// <param name="datasetViews"></param>
22 | /// <param name="device">cpu or cuda</param>
23 | /// <returns></returns>
24 | public static DataViewPair FromDataViews(IEnumerable<DataView> datasetViews, torch.Device device)
25 | {
26 | var views = datasetViews.ToList();
27 | var features = views.Select(a => a.GetFeatures()).ToList();
28 | var labels = views.Select(a => a.GetLabels()).ToList();
29 | var result = new DataViewPair(labels, features).To(device);
30 | return result;
31 | }
32 | }
33 | }
--------------------------------------------------------------------------------
/src/DeepSharp.Dataset/Datasets/DataViewPair.cs:
--------------------------------------------------------------------------------
1 | namespace DeepSharp.Dataset
2 | {
3 | public class DataViewPair
4 | {
5 | public DataViewPair(torch.Tensor labels, torch.Tensor features)
6 | {
7 | Labels = labels;
8 | Features = features;
9 | }
10 |
11 | public DataViewPair(IEnumerable<torch.Tensor> labels, IEnumerable<torch.Tensor> features)
12 | {
13 | var labelsArray = labels.ToArray();
14 | var featuresArray = features.ToArray();
15 | Labels = torch.vstack(labelsArray);
16 | Features = torch.vstack(featuresArray);
17 | }
18 |
19 |
20 | public torch.Tensor Labels { set; get; }
21 | public torch.Tensor Features { set; get; }
22 |
23 |
24 | internal DataViewPair To(torch.Device device)
25 | {
26 | var res = new DataViewPair(Labels.to(device), Features.to(device));
27 | return res;
28 | }
29 |
30 | ///
31 | /// Send DataViewPair to CPU device
32 | ///
33 | ///
34 | public DataViewPair cpu()
35 | {
36 | return To(new torch.Device(DeviceType.CPU));
37 | }
38 |
39 | ///
40 | /// Send DataViewPair to CUDA device
41 | ///
42 | ///
43 | public DataViewPair cuda()
44 | {
45 | return To(new torch.Device(DeviceType.CUDA));
46 | }
47 |
48 | public override string ToString()
49 | {
50 | var strbuild = new StringBuilder();
51 | strbuild.AppendLine($"{Labels}");
52 | strbuild.AppendLine($"{Features}");
53 | return strbuild.ToString();
54 | }
55 | }
56 | }
--------------------------------------------------------------------------------
/src/DeepSharp.Dataset/Datasets/Dataset.cs:
--------------------------------------------------------------------------------
1 | using System.Reflection;
2 |
3 | namespace DeepSharp.Dataset
4 | {
5 | /// <summary>
6 | ///     Load a dataset in an object-oriented way, inherited from torch.utils.data.Dataset
7 | /// </summary>
8 | /// <typeparam name="T"></typeparam>
9 | public class Dataset<T> : torch.utils.data.Dataset<T>
10 | where T : DataView
11 | {
12 | protected T[] AllData { set; get; }
13 |
14 | public Dataset(string path, char splitChar = '\t', bool hasHeader = true)
15 | {
16 | // Step 1 Precheck
17 | File.Exists(path).Should().BeTrue($"File {path} should exist.");
18 |
19 | // Step 2 Read Stream to DataTable or Array
20 | using var stream = new StreamReader(path);
21 | var allline = stream.ReadToEnd()
22 | .Split('\r', '\n')
23 | .Where(a => !string.IsNullOrEmpty(a))
24 | .ToList();
25 | if (hasHeader)
26 | allline.RemoveAt(0);
27 | var alldata = allline.Select(l => l.Split(splitChar).ToArray()).ToArray();
28 |
29 | var fieldDict = GetFieldDict(typeof(T));
30 |
31 | // Step 3 Convert to data according to the StreamHeaderAttribute
32 | AllData = alldata
33 | .Select(single => GetData(fieldDict, single))
34 | .ToArray();
35 | Count = AllData.Length;
36 | }
37 |
38 | public override long Count { get; }
39 |
40 | public override T GetTensor(long index)
41 | {
42 | return AllData[index];
43 | }
44 |
45 |
46 | #region protect function
47 |
48 | protected static Dictionary<PropertyInfo, StreamHeaderRange> GetFieldDict(Type type)
49 | {
50 | var fieldInfo = type.GetProperties()
51 | .Where(a => a.CustomAttributes
52 | .Any(attributeData => attributeData.AttributeType == typeof(StreamHeaderAttribute)))
53 | .ToList();
54 | var dict = fieldInfo.ToDictionary(
55 | f => f,
56 | f => f.GetCustomAttribute<StreamHeaderAttribute>()!.StreamHeaderRange);
57 | return dict;
58 | }
59 |
60 | protected static T GetData(Dictionary<PropertyInfo, StreamHeaderRange> dict, string[] array)
61 | {
62 | var obj = (T) Activator.CreateInstance(typeof(T))!;
63 | dict.ToList().ForEach(p =>
64 | {
65 | var fieldInfo = p.Key;
66 | var range = p.Value;
67 | var type = fieldInfo.PropertyType;
68 |
69 | if (range.Min == range.Max)
70 | {
71 | var field = Convert.ChangeType(array[range.Min], type);
72 | fieldInfo.SetValue(obj, field);
73 | }
74 | else if (type.IsArray && range.Max >= range.Min)
75 | {
76 | var len = range.Max - range.Min + 1;
77 | var arr = Activator.CreateInstance(type, len);
78 | Enumerable.Range(0, len).ToList().ForEach(i =>
79 | {
80 | var field = Convert.ChangeType(array[range.Min + i], type.GetElementType()!);
81 | type.GetMethod("Set")?.Invoke(arr, new[] {i, field});
82 | });
83 | fieldInfo.SetValue(obj, arr);
84 | }
85 | });
86 |
87 | return obj;
88 | }
89 |
90 | #endregion
91 | }
92 | }
--------------------------------------------------------------------------------
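Dataset<T> maps text columns onto properties through StreamHeaderAttribute. A sketch of a DataView subclass matching the iris files above (illustrative only; the test project ships its own Iris classes):

    public class IrisData : DataView
    {
        [StreamHeader(0)] public float Label { set; get; }

        [StreamHeader(1, 4)] public float[] Features { set; get; } = new float[4];

        public override torch.Tensor GetFeatures()
        {
            return torch.from_array(Features);
        }

        public override torch.Tensor GetLabels()
        {
            return torch.from_array(new[] {Label});
        }
    }
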
/src/DeepSharp.Dataset/Datasets/StreamHeaderAttribute.cs:
--------------------------------------------------------------------------------
1 |
2 |
3 | namespace DeepSharp.Dataset
4 | {
5 | [AttributeUsage(AttributeTargets.Property | AttributeTargets.Field)]
6 | public class StreamHeaderAttribute : Attribute
7 | {
8 | public StreamHeaderRange StreamHeaderRange;
9 |
10 | /// <summary>Maps member to a specific field in the text file.</summary>
11 | /// <param name="fieldIndex">The index of the field in the text file.</param>
12 | public StreamHeaderAttribute(int fieldIndex)
13 | {
14 | StreamHeaderRange = new StreamHeaderRange(fieldIndex);
15 | }
16 |
17 | public StreamHeaderAttribute(int min, int max)
18 | {
19 | StreamHeaderRange = new StreamHeaderRange(min, max);
20 | }
21 | }
22 | }
--------------------------------------------------------------------------------
/src/DeepSharp.Dataset/Datasets/StreamHeaderRange.cs:
--------------------------------------------------------------------------------
1 | namespace DeepSharp.Dataset
2 | {
3 | ///
4 | /// Specifies the range of indices of input columns that should be mapped to an output column.
5 | ///
6 | public sealed class StreamHeaderRange
7 | {
8 | ///
9 | /// Whether this range includes only other indices not specified.
10 | ///
11 | public bool AllOther;
12 |
13 | ///
14 | /// Whether this range extends to the end of the line, but should be a fixed number of items.
15 | /// If AutoEnd is specified, the fields Max and VariableEnd are ignored.
16 | ///
17 | public bool AutoEnd;
18 |
19 | ///
20 | /// Force scalar columns to be treated as vectors of length one.
21 | ///
22 | public bool ForceVector;
23 |
24 | ///
25 | /// The maximum index of the column, inclusive. If omitted,
26 | /// it indicates that the loader should auto-detect the length
27 | /// of the lines, and read until the end.
28 | /// If Max is specified, the fields AutoEnd and VariableEnd are ignored.
29 | ///
30 | public int Max;
31 |
32 | ///
33 | /// The minimum index of the column, inclusive.
34 | ///
35 | public int Min;
36 |
37 | ///
38 | /// Whether this range extends to the end of the line, which can vary from line to line.
39 | /// If VariableEnd is specified, the fields Max and AutoEnd are ignored.
40 | /// If AutoEnd is specified, then VariableEnd is ignored.
41 | ///
42 | public bool VariableEnd;
43 |
44 | public StreamHeaderRange()
45 | {
46 | }
47 |
48 | ///
49 | /// A range representing a single value. Will result in a scalar column.
50 | ///
51 | /// The index of the field of the text file to read.
52 | public StreamHeaderRange(int index)
53 | {
54 | index.Should().BeGreaterThanOrEqualTo(0, "Must be non-negative");
55 | Min = index;
56 | Max = index;
57 | }
58 |
59 | ///
60 | /// A range representing a set of values. Will result in a vector column.
61 | ///
62 | /// The minimum inclusive index of the column.
63 | ///
64 | /// The maximum inclusive index of the column. If null,
65 | /// it indicates that the loader should auto-detect the length
66 | /// of the lines, and read until the end.
67 | ///
68 | public StreamHeaderRange(int min, int max)
69 | {
70 | min.Should().BeGreaterThanOrEqualTo(0, "Must be non-negative");
71 | max.Should().BeGreaterOrEqualTo(min, "If specified, must be greater than or equal to " + nameof(min));
72 |
73 | Min = min;
74 | Max = max;
75 | // Note that without the following being set, in the case where there is a single range
76 | // where Min == Max, the result will not be a vector valued but a scalar column.
77 | ForceVector = true;
78 | }
79 |
80 | internal static StreamHeaderRange Parse(string str)
81 | {
82 | str.Should().NotBeNullOrEmpty();
83 | var res = new StreamHeaderRange();
84 | if (res.TryParse(str)) return res;
85 |
86 | return null;
87 | }
88 |
89 | private bool TryParse(string str)
90 | {
91 | str.Should().NotBeNullOrEmpty();
92 |
93 | var ich = str.IndexOfAny(new[] {'-', '~'});
94 | if (ich < 0)
95 | {
96 | // No "-" or "~". Single integer.
97 | if (!int.TryParse(str, out Min)) return false;
98 |
99 | Max = Min;
100 | return true;
101 | }
102 |
103 | AllOther = str[ich] == '~';
104 | ForceVector = true;
105 |
106 | if (ich == 0)
107 | {
108 | if (!AllOther) return false;
109 |
110 | Min = 0;
111 | }
112 | else if (!int.TryParse(str.Substring(0, ich), out Min))
113 | {
114 | return false;
115 | }
116 |
117 | var rest = str.Substring(ich + 1);
118 | if (string.IsNullOrEmpty(rest) || rest == "*")
119 | {
120 | AutoEnd = true;
121 | return true;
122 | }
123 |
124 | if (rest == "**")
125 | {
126 | VariableEnd = true;
127 | return true;
128 | }
129 |
130 | int tmp;
131 | if (!int.TryParse(rest, out tmp)) return false;
132 |
133 | Max = tmp;
134 | return true;
135 | }
136 |
137 | internal bool TryUnparse(StringBuilder sb)
138 | {
139 | sb.Should().NotBeNull();
140 | var dash = AllOther ? '~' : '-';
141 | if (Min < 0) return false;
142 |
143 | sb.Append(Min);
144 | if (Max != null)
145 | {
146 | if (Max != Min || ForceVector || AllOther) sb.Append(dash).Append(Max);
147 | }
148 | else if (AutoEnd)
149 | {
150 | sb.Append(dash).Append("*");
151 | }
152 | else if (VariableEnd)
153 | {
154 | sb.Append(dash).Append("**");
155 | }
156 |
157 | return true;
158 | }
159 | }
160 | }
--------------------------------------------------------------------------------
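TryParse above accepts ML.NET-style column range strings. A few illustrative inputs and the fields they set (Parse is internal, so this sketch only applies inside the assembly):

    StreamHeaderRange.Parse("3");    // Min = 3, Max = 3 (scalar column)
    StreamHeaderRange.Parse("1-4");  // Min = 1, Max = 4, ForceVector = true
    StreamHeaderRange.Parse("2-*");  // Min = 2, AutoEnd = true
    StreamHeaderRange.Parse("2-**"); // Min = 2, VariableEnd = true
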
/src/DeepSharp.Dataset/DeepSharp.Dataset.csproj:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 | net7.0
5 | enable
6 | enable
7 |
8 |
9 |
10 |
11 |
12 |
13 |
14 |
15 |
16 |
--------------------------------------------------------------------------------
/src/DeepSharp.Dataset/DeepSharp.Dataset.csproj.DotSettings:
--------------------------------------------------------------------------------
1 |
2 | True
3 | True
--------------------------------------------------------------------------------
/src/DeepSharp.Dataset/Usings.cs:
--------------------------------------------------------------------------------
1 | global using TorchSharp;
2 | global using OpenCvSharp;
3 | global using System.Text;
4 | global using FluentAssertions;
--------------------------------------------------------------------------------
/src/DeepSharp.ObjectDetection/DeepSharp.ObjectDetection.csproj:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 | net7.0
5 | enable
6 | enable
7 |
8 |
9 |
10 |
11 |
12 |
13 |
14 |
--------------------------------------------------------------------------------
/src/DeepSharp.ObjectDetection/Yolo.cs:
--------------------------------------------------------------------------------
1 | namespace DeepSharp.ObjectDetection
2 | {
3 | public class Yolo
4 | {
5 | }
6 | }
--------------------------------------------------------------------------------
/src/DeepSharp.RL/ActionSelectors/ActionSelector.cs:
--------------------------------------------------------------------------------
1 | namespace DeepSharp.RL.ActionSelectors
2 | {
3 | /// <summary>
4 | ///     Action selector: helps convert the predicted output from the net into specific action objects
5 | ///     (i.e. maps the network output to a concrete action choice)
6 | /// </summary>
7 | public abstract class ActionSelector
8 | {
9 | protected ActionSelector(bool keepDims = false)
10 | {
11 | KeepDims = keepDims;
12 | }
13 |
14 | public bool KeepDims { set; get; }
15 |
16 | public abstract torch.Tensor Select(torch.Tensor probs);
17 | }
18 | }
--------------------------------------------------------------------------------
/src/DeepSharp.RL/ActionSelectors/ArgmaxActionSelector.cs:
--------------------------------------------------------------------------------
1 | using FluentAssertions;
2 |
3 | namespace DeepSharp.RL.ActionSelectors
4 | {
5 | /// <summary>
6 | ///     Takes a tensor and applies argmax along the last dimension
7 | /// </summary>
8 | public class ArgmaxActionSelector : ActionSelector
9 | {
10 | /// <summary>
11 | /// </summary>
12 | /// <param name="probs"></param>
13 | /// <returns>action of long format</returns>
14 | public override torch.Tensor Select(torch.Tensor probs)
15 | {
16 | probs.dim().Should().Be(2, "ArgmaxActionSelector supports only tensors with 2 dims");
17 | return torch.argmax(probs, -1, KeepDims);
18 | }
19 | }
20 | }
--------------------------------------------------------------------------------
/src/DeepSharp.RL/ActionSelectors/EpsilonActionSelector.cs:
--------------------------------------------------------------------------------
1 | namespace DeepSharp.RL.ActionSelectors
2 | {
3 | public class EpsilonActionSelector : ActionSelector
4 | {
5 | public EpsilonActionSelector(ActionSelector selector)
6 | {
7 | Selector = selector;
8 | }
9 |
10 | public ActionSelector Selector { protected set; get; }
11 |
12 | public override torch.Tensor Select(torch.Tensor probs)
13 | {
14 | throw new NotImplementedException();
15 | }
16 | }
17 | }
--------------------------------------------------------------------------------
/src/DeepSharp.RL/ActionSelectors/ProbActionSelector.cs:
--------------------------------------------------------------------------------
1 | namespace DeepSharp.RL.ActionSelectors
2 | {
3 | /// <summary>
4 | ///     Takes normalized probabilities and returns samples drawn from the distribution
5 | /// </summary>
6 | public class ProbActionSelector : ActionSelector
7 | {
8 | /// <summary>
9 | /// </summary>
10 | /// <param name="probs">such as [[0.8,0.1,0.1],[0.01,0.98,0.01]]</param>
11 | /// <returns></returns>
12 | public override torch.Tensor Select(torch.Tensor probs)
13 | {
14 | var dims = probs.dim();
15 |
16 | return dims switch
17 | {
18 | 1 => GetActionByDim1(probs),
19 | 2 => GetActionByDim2(probs),
20 | _ => throw new NotSupportedException("Only tensors with 1 or 2 dims are supported")
21 | };
22 | }
23 |
24 | private torch.Tensor GetActionByDim1(torch.Tensor probs)
25 | {
26 | return torch.multinomial(probs, 1);
27 | }
28 |
29 | private torch.Tensor GetActionByDim2(torch.Tensor probs)
30 | {
31 | var width = (int) probs.shape[0];
32 |
33 | var arr = Enumerable.Range(0, width).Select(i =>
34 | {
35 | var tensorIndices = new[] {torch.TensorIndex.Single(i)};
36 | var prob = probs[tensorIndices];
37 |
38 | return torch.multinomial(prob, 1);
39 | }).ToList();
40 | var final = torch.vstack(arr);
41 |
42 | return KeepDims ? final : final.squeeze(-1);
43 | }
44 | }
45 | }
--------------------------------------------------------------------------------
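A short sketch contrasting the two selectors on a batch of action probabilities (values are illustrative):

    var probs = torch.from_array(new float[,] {{0.8f, 0.1f, 0.1f}, {0.01f, 0.98f, 0.01f}});

    var greedy  = new ArgmaxActionSelector().Select(probs); // deterministic: [0, 1]
    var sampled = new ProbActionSelector().Select(probs);   // stochastic: usually [0, 1]
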
/src/DeepSharp.RL/Agents/Agent.cs:
--------------------------------------------------------------------------------
1 | using DeepSharp.RL.Enumerates;
2 | using DeepSharp.RL.Environs;
3 | using MathNet.Numerics.Random;
4 |
5 | namespace DeepSharp.RL.Agents
6 | {
7 | /// <summary>
8 | ///     Agent
9 | /// </summary>
10 | public abstract class Agent
11 | {
12 | protected Agent(Environ<Space, Space> env, string name)
13 | {
14 | Name = name;
15 | Environ = env;
16 | Device = env.Device;
17 | }
18 |
19 |
20 | public string Name { protected set; get; }
21 | public torch.Device Device { protected set; get; }
22 |
23 | public long ObservationSize => Environ.ObservationSpace!.N;
24 |
25 | public long ActionSize => Environ.ActionSpace!.N;
26 |
27 | public Environ<Space, Space> Environ { protected set; get; }
28 | public float Epsilon { set; get; } = 0.2f;
29 |
30 |
31 | public abstract LearnOutcome Learn();
32 |
33 | public abstract void Save(string path);
34 |
35 | public abstract void Load(string path);
36 |
37 |
38 | /// <summary>
39 | ///     Get a random Action
40 | /// </summary>
41 | /// <returns></returns>
42 | public Act GetSampleAct()
43 | {
44 | return Environ.SampleAct();
45 | }
46 |
47 | /// <summary>
48 | ///     Get an action from the policy
49 | ///     π(s)
50 | /// </summary>
51 | /// <param name="state">current state</param>
52 | /// <returns>an action provided by the agent's policy</returns>
53 | public abstract Act GetPolicyAct(torch.Tensor state);
54 |
55 |
56 | /// <summary>
57 | ///     Get an action by the ε-greedy method
58 | ///     π^ε(s)
59 | /// </summary>
60 | /// <param name="state">current state</param>
61 | /// <returns></returns>
62 | /// <remarks>With probability ε a random action is sampled; otherwise the policy action is returned.</remarks>
63 | public Act GetEpsilonAct(torch.Tensor state)
64 | {
65 | var d = new SystemRandomSource();
66 | var v = d.NextDouble();
67 | var act = v < Epsilon
68 | ? GetSampleAct()
69 | : GetPolicyAct(state);
70 | return act;
71 | }
72 |
73 |
74 | /// <summary>
75 | ///     Get an Episode by the Agent:
76 | ///     run mainly with the policy to obtain one complete episode
77 | /// </summary>
78 | /// <returns>the episode with its reward</returns>
79 | public virtual Episode RunEpisode(
80 | PlayMode playMode = PlayMode.Agent)
81 | {
82 | Environ.Reset();
83 | var episode = new Episode();
84 | var epoch = 0;
85 | while (Environ.IsComplete(epoch) == false)
86 | {
87 | epoch++;
88 | var act = playMode switch
89 | {
90 | PlayMode.Sample => GetSampleAct(),
91 | PlayMode.Agent => GetPolicyAct(Environ.Observation!.Value!),
92 | PlayMode.EpsilonGreedy => GetEpsilonAct(Environ.Observation!.Value!),
93 | _ => throw new ArgumentOutOfRangeException(nameof(playMode), playMode, null)
94 | };
95 | var step = Environ.Step(act, epoch);
96 | episode.Steps.Add(step);
97 | Environ.CallBack?.Invoke(step);
98 | Environ.Observation = step.PostState; // It's important to update the observation here
99 | }
100 |
101 | var sumReward = Environ.GetReturn(episode);
102 | episode.SumReward = new Reward(sumReward);
103 | return episode;
104 | }
105 |
106 |
107 | /// <summary>
108 | ///     Get Episodes by the Agent:
109 | ///     run mainly with the policy to obtain multiple complete episodes
110 | /// </summary>
111 | /// <returns>the episodes with their rewards</returns>
112 | public virtual Episode[] RunEpisodes(int count,
113 | PlayMode playMode = PlayMode.Agent)
114 | {
115 | var episodes = new List<Episode>();
116 | foreach (var _ in Enumerable.Repeat(1, count))
117 | episodes.Add(RunEpisode(playMode));
118 |
119 | return episodes.ToArray();
120 | }
121 |
122 | /// <summary>
123 | /// </summary>
124 | /// <param name="testCount">test count</param>
125 | /// <returns>Average Reward</returns>
126 | public float TestEpisodes(int testCount)
127 | {
128 | var episode = RunEpisodes(testCount);
129 | var averageReward = episode.Average(a => a.SumReward.Value);
130 | return averageReward;
131 | }
132 |
133 |
134 | public override string ToString()
135 | {
136 | return $"Agent[{Name}]";
137 | }
138 | }
139 | }
--------------------------------------------------------------------------------
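The abstract members above imply a simple outer loop that a trainer such as RLTrainer can wrap; a sketch with an illustrative stopping threshold:

    // agent is any concrete Agent (CrossEntropy, QLearning, DQN, ...)
    for (var epoch = 1; epoch <= 1000; epoch++)
    {
        var outcome = agent.Learn();            // improve the policy from fresh experience
        var avgReward = agent.TestEpisodes(20); // evaluate the current policy
        Console.WriteLine($"{epoch:D4}\t{outcome}\t{avgReward:F2}");
        if (avgReward > 17.5f) break;           // illustrative target reward
    }
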
/src/DeepSharp.RL/Agents/LearnOutCome.cs:
--------------------------------------------------------------------------------
1 | using DeepSharp.RL.Environs;
2 |
3 | namespace DeepSharp.RL.Agents
4 | {
5 | public class LearnOutcome
6 | {
7 | public LearnOutcome()
8 | {
9 | Steps = new List<Step>();
10 | Evaluate = 0;
11 | }
12 |
13 | public LearnOutcome(Step[] steps, float evaluate)
14 | {
15 | Steps = steps.ToList();
16 | Evaluate = evaluate;
17 | }
18 |
19 | public LearnOutcome(Episode episode)
20 | {
21 | Steps = episode.Steps;
22 | Evaluate = episode.SumReward.Value;
23 | }
24 |
25 | public LearnOutcome(Episode[] episode)
26 | {
27 | Steps = episode.SelectMany(e => e.Steps).ToList();
28 | Evaluate = episode.Average(a => a.SumReward.Value);
29 | }
30 |
31 | public LearnOutcome(Episode[] episode, float loss)
32 | {
33 | Steps = episode.SelectMany(e => e.Steps).ToList();
34 | Evaluate = loss;
35 | }
36 |
37 | public List<Step> Steps { protected set; get; }
38 | public float Evaluate { set; get; }
39 |
40 | public void AppendStep(Step step)
41 | {
42 | Steps.Add(step);
43 | }
44 |
45 | public void AppendStep(IEnumerable<Step> steps)
46 | {
47 | Steps.AddRange(steps);
48 | }
49 |
50 | public void AppendStep(Episode episode)
51 | {
52 | Steps.AddRange(episode.Steps);
53 | }
54 |
55 | public void UpdateEvaluate(float evaluation)
56 | {
57 | Evaluate = evaluation;
58 | }
59 |
60 | public override string ToString()
61 | {
62 | var avrReward = Steps.Average(a => a.Reward.Value);
63 | var message = $"S:{Steps.Count}\tR:{avrReward:F4}\tE:{Evaluate:F4}";
64 | return message;
65 | }
66 | }
67 | }
--------------------------------------------------------------------------------
/src/DeepSharp.RL/Agents/Models/QTable.cs:
--------------------------------------------------------------------------------
1 | using System.Text;
2 | using DeepSharp.RL.Environs;
3 |
4 | namespace DeepSharp.RL.Agents
5 | {
6 | ///
7 | /// State-Action Value Table
8 | /// Q(s,a) = Cumulative reward
9 | ///
10 | public class QTable : IEquatable<QTable>
11 | {
12 | public QTable()
13 | {
14 | Return = new Dictionary<TransitKey, float>();
15 | }
16 |
17 | public Dictionary<TransitKey, float> Return { protected set; get; }
18 | protected List<TransitKey> TrasitKeys => Return.Keys.ToList();
19 |
20 |
21 | public float this[TransitKey transit]
22 | {
23 | get => GetValue(transit);
24 | set => SetValue(transit, value);
25 | }
26 |
27 |
28 | public float this[torch.Tensor state, torch.Tensor action]
29 | {
30 | get => GetValue(new TransitKey(state, action));
31 | set => SetValue(new TransitKey(state, action), value);
32 | }
33 |
34 |
35 | public bool Equals(QTable? other)
36 | {
37 | if (other == null) return false;
38 | if (other.TrasitKeys.Count != TrasitKeys.Count) return false;
39 | var res = TrasitKeys.All(key => !(Math.Abs(this[key] - other[key]) > 1E-2));
40 | return res;
41 | }
42 |
43 |
44 | private void SetValue(TransitKey transit, float value)
45 | {
46 | Return[transit] = value;
47 | }
48 |
49 | private float GetValue(TransitKey transit)
50 | {
51 | Return.TryAdd(transit, 0f);
52 | return Return[transit];
53 | }
54 |
55 |
56 | ///
57 | /// argMax
58 | ///
59 | ///
60 | ///
61 | public Act? GetBestAct(torch.Tensor state)
62 | {
63 | var row = TrasitKeys
64 | .Where(a => a.State.Equals(state));
65 |
66 | var stateActions = Return
67 | .Where(a => row.Contains(a.Key)).ToList();
68 |
69 | if (!stateActions.Any())
70 | return null;
71 |
72 | if (stateActions.All(a => a.Value == 0))
73 | return null;
74 |
75 | var argMax = stateActions
76 | .MaxBy(a => a.Value);
77 | var act = argMax.Key.Act;
78 | return new Act(act);
79 | }
80 |
81 | ///
82 | /// argMax
83 | ///
84 | ///
85 | ///
86 | public float GetBestValue(torch.Tensor state)
87 | {
88 | var row = TrasitKeys
89 | .Where(a => a.State.Equals(state));
90 |
91 | var stateActions = Return
92 | .Where(a => row.Contains(a.Key)).ToList();
93 |
94 | if (!stateActions.Any())
95 | return 0;
96 |
97 | var bestValue = stateActions
98 | .Max(a => a.Value);
99 | return bestValue;
100 | }
101 |
102 | public override string ToString()
103 | {
104 | var str = new StringBuilder();
105 | foreach (var keyValuePair in Return.Where(a => a.Value > 0))
106 | str.AppendLine($"{keyValuePair.Key}\t{keyValuePair.Value:F4}");
107 | return str.ToString();
108 | }
109 | }
110 | }
--------------------------------------------------------------------------------
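A sketch of how the indexers and GetBestAct interact (the one-element state/action tensors are illustrative):

    var qTable = new QTable();
    var state = torch.from_array(new long[] {3});
    var left  = torch.from_array(new long[] {0});
    var right = torch.from_array(new long[] {1});

    qTable[state, left]  = 0.5f;  // Q(s, left)
    qTable[state, right] = 1.2f;  // Q(s, right)

    var best      = qTable.GetBestAct(state);   // Act wrapping the highest-valued action (right)
    var bestValue = qTable.GetBestValue(state); // 1.2f
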
/src/DeepSharp.RL/Agents/Models/RewardKey.cs:
--------------------------------------------------------------------------------
1 | using DeepSharp.RL.Environs;
2 |
3 | namespace DeepSharp.RL.Agents
4 | {
5 | /// <summary>
6 | ///     Composite key of the reward table
7 | /// </summary>
8 | public struct RewardKey
9 | {
10 | public RewardKey(torch.Tensor state, torch.Tensor act, torch.Tensor newState)
11 | {
12 | State = state;
13 | Act = act;
14 | NewState = newState;
15 | }
16 |
17 | public RewardKey(Observation state, Act act, Observation newState)
18 | {
19 | State = state.Value!;
20 | Act = act.Value!;
21 | NewState = newState.Value!;
22 | }
23 |
24 | public torch.Tensor State { set; get; }
25 | public torch.Tensor Act { set; get; }
26 | public torch.Tensor NewState { set; get; }
27 |
28 |
29 | public override string ToString()
30 | {
31 | var state = State.ToString(torch.numpy);
32 | var action = Act.ToString(torch.numpy);
33 | var newState = NewState.ToString(torch.numpy);
34 |
35 | return $"{state} \r\n {action} \r\n {newState}";
36 | }
37 |
38 | public static bool operator ==(RewardKey x, RewardKey y)
39 | {
40 | return x.Equals(y);
41 | }
42 |
43 | public static bool operator !=(RewardKey x, RewardKey y)
44 | {
45 | return !x.Equals(y);
46 | }
47 |
48 | public bool Equals(RewardKey other)
49 | {
50 | return State.Equals(other.State) && Act.Equals(other.Act) && NewState.Equals(other.NewState);
51 | }
52 |
53 |
54 | public override bool Equals(object? obj)
55 | {
56 | if (obj is RewardKey input)
57 | return Equals(input);
58 | return false;
59 | }
60 |
61 | public override int GetHashCode()
62 | {
63 | return -1;
64 | }
65 | }
66 | }
--------------------------------------------------------------------------------
/src/DeepSharp.RL/Agents/Models/TrasitKey.cs:
--------------------------------------------------------------------------------
1 | using DeepSharp.RL.Environs;
2 |
3 | namespace DeepSharp.RL.Agents
4 | {
5 | ///
6 | /// Key of Transit, which combine state and action.
7 | ///
8 | public class TransitKey
9 | {
10 | public TransitKey(Observation state, Act act)
11 | {
12 | State = state.Value!;
13 | Act = act.Value!;
14 | }
15 |
16 | public TransitKey(torch.Tensor state, torch.Tensor act)
17 | {
18 | State = state;
19 | Act = act;
20 | }
21 |
22 | public torch.Tensor State { protected set; get; }
23 | public torch.Tensor Act { protected set; get; }
24 |
25 |
26 | public static bool operator ==(TransitKey x, TransitKey y)
27 | {
28 | return x.Equals(y);
29 | }
30 |
31 | public static bool operator !=(TransitKey x, TransitKey y)
32 | {
33 | return !x.Equals(y);
34 | }
35 |
36 | public bool Equals(TransitKey other)
37 | {
38 | return State.Equals(other.State) && Act.Equals(other.Act);
39 | }
40 |
41 | public override bool Equals(object? obj)
42 | {
43 | if (obj is TransitKey input)
44 | return Equals(input);
45 | return false;
46 | }
47 |
48 | public override int GetHashCode()
49 | {
50 | return -1;
51 | }
52 |
53 | public override string ToString()
54 | {
55 | return $"{State.ToString(torch.numpy)},{Act.ToString(torch.numpy)}";
56 | }
57 | }
58 | }
--------------------------------------------------------------------------------
/src/DeepSharp.RL/Agents/Models/VTable.cs:
--------------------------------------------------------------------------------
1 | using System.Text;
2 |
3 | namespace DeepSharp.RL.Agents
4 | {
5 | ///
6 | /// State Value Function
7 | /// V(s) = Cumulative reward
8 | ///
9 | public class VTable : IEquatable<VTable>
10 | {
11 | public VTable()
12 | {
13 | Return = new Dictionary<torch.Tensor, float>();
14 | }
15 |
16 | public Dictionary<torch.Tensor, float> Return { protected set; get; }
17 |
18 | protected List<torch.Tensor> StateKeys => Return.Keys.ToList();
19 |
20 |
21 | public float this[torch.Tensor state]
22 | {
23 | get => GetValue(state);
24 | set => SetValue(state, value);
25 | }
26 |
27 | private void SetValue(torch.Tensor state, float value)
28 | {
29 | Return[state] = value;
30 | }
31 |
32 | private float GetValue(torch.Tensor transit)
33 | {
34 | Return.TryAdd(transit, 0f);
35 | return Return[transit];
36 | }
37 |
38 | public static float operator -(VTable a, VTable b)
39 | {
40 | var keys = a.StateKeys;
41 | return keys.Select(k => Math.Abs(a[k] - b[k])).Max();
42 | }
43 |
44 | public bool Equals(VTable? other)
45 | {
46 | if (other == null) return false;
47 | if (other.StateKeys.Count != StateKeys.Count) return false;
48 | var res = StateKeys.All(key => !(Math.Abs(this[key] - other[key]) > 1E-2));
49 | return res;
50 | }
51 |
52 |
53 | public override string ToString()
54 | {
55 | var str = new StringBuilder();
56 | foreach (var keyValuePair in Return.Where(a => a.Value > 0))
57 | str.AppendLine($"{keyValuePair.Key.ToString(torch.numpy)}\t{keyValuePair.Value:F4}");
58 | return str.ToString();
59 | }
60 | }
61 | }
--------------------------------------------------------------------------------
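The subtraction operator returns the largest absolute change between two tables, which the dynamic-programming agents can use as a convergence test. Sketch (threshold is illustrative):

    var before = new VTable();
    var after  = new VTable();
    var state  = torch.from_array(new long[] {0});

    before[state] = 0.0f;
    after[state]  = 0.8f;

    var delta = after - before;    // max |V'(s) - V(s)| over the states of after
    var converged = delta < 1e-3f; // illustrative threshold
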
/src/DeepSharp.RL/Agents/Nets/DQNNet.cs:
--------------------------------------------------------------------------------
1 | namespace DeepSharp.RL.Agents
2 | {
3 | /// <summary>
4 | ///     DQN Deep Model
5 | /// </summary>
6 | public sealed class DQNNet : Module<torch.Tensor, torch.Tensor>
7 | {
8 | private readonly Module<torch.Tensor, torch.Tensor> conv;
9 | private readonly Module<torch.Tensor, torch.Tensor> fc;
10 | public torch.Device Device { get; }
11 | public torch.ScalarType ScalarType { get; }
12 |
13 | public DQNNet(long[] inputShape, int actions,
14 | torch.ScalarType scalar = torch.ScalarType.Float32,
15 | DeviceType deviceType = DeviceType.CUDA) :
16 | base("DQNN")
17 | {
18 | ScalarType = scalar;
19 | Device = new torch.Device(deviceType);
20 | var modules = new List<(string, Module<torch.Tensor, torch.Tensor>)>
21 | {
22 | ("Conv2d1", Conv2d(inputShape[0], 32, 8, 4)),
23 | ("Relu1", ReLU()),
24 | ("Conv2d2", Conv2d(32, 64, 4, 2)),
25 | ("Relu2", ReLU()),
26 | ("Conv2d3", Conv2d(64, 64, 3)),
27 | ("Relu3", ReLU())
28 | };
29 | conv = Sequential(modules);
30 | conv.to(Device);
31 |
32 | var convOutSize = GetConvOut(inputShape);
33 | var modules2 = new List<(string, Module<torch.Tensor, torch.Tensor>)>
34 | {
35 | ("Linear1", Linear(convOutSize, 512)),
36 | ("Relu4", ReLU()),
37 | ("Linear2", Linear(512, actions))
38 | };
39 | fc = Sequential(modules2);
40 | fc.to(Device);
41 |
42 | RegisterComponents();
43 | }
44 |
45 |
46 | public override torch.Tensor forward(torch.Tensor input)
47 | {
48 | var convOut = conv.forward(input).view(input.size(0), -1);
49 | var fcOut = fc.forward(convOut);
50 | return fcOut;
51 | }
52 |
53 | public int GetConvOut(long[] inputShape)
54 | {
55 | var arr = new List<long> {1};
56 | arr.AddRange(inputShape);
57 | var input = torch.zeros(arr.ToArray(), ScalarType, Device);
58 | var o = conv.forward(input);
59 | var shapes = o.size();
60 | var outSize = shapes.Aggregate((a, b) => a * b);
61 | return (int) outSize;
62 | }
63 |
64 | protected override void Dispose(bool disposing)
65 | {
66 | if (disposing)
67 | {
68 | conv.Dispose();
69 | fc.Dispose();
70 | ClearModules();
71 | }
72 |
73 | base.Dispose(disposing);
74 | }
75 | }
76 | }
77 |
--------------------------------------------------------------------------------
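A construction sketch for the convolutional net above, using an Atari-style observation shape (4 stacked 84x84 frames; the shape and action count are illustrative):

    var inputShape = new long[] {4, 84, 84};                  // [channels, height, width]
    var net = new DQNNet(inputShape, 6, deviceType: DeviceType.CPU);

    var obs = torch.zeros(new long[] {1, 4, 84, 84});         // a single observation batch
    var qValues = net.forward(obs);                           // shape [1, 6], one Q value per action
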
/src/DeepSharp.RL/Agents/Nets/Net.cs:
--------------------------------------------------------------------------------
1 | namespace DeepSharp.RL.Agents
2 | {
3 | /// <summary>
4 | ///     This is a demo net to guide how to create a new Module
5 | /// </summary>
6 | public sealed class Net : Module<torch.Tensor, torch.Tensor>
7 | {
8 | private readonly Module<torch.Tensor, torch.Tensor> layers;
9 |
10 | public Net(long obsSize, long hiddenSize, long actionNum, DeviceType deviceType = DeviceType.CUDA) :
11 | base("Net")
12 | {
13 | var modules = new List<(string, Module<torch.Tensor, torch.Tensor>)>
14 | {
15 | ("line1", Linear(obsSize, hiddenSize)),
16 | ("relu", ReLU()),
17 | ("line2", Linear(hiddenSize, actionNum))
18 | };
19 | layers = Sequential(modules);
20 | layers.to(new torch.Device(deviceType));
21 | RegisterComponents();
22 | }
23 |
24 | public override torch.Tensor forward(torch.Tensor input)
25 | {
26 | return layers.forward(input.to_type(torch.ScalarType.Float32));
27 | }
28 |
29 | protected override void Dispose(bool disposing)
30 | {
31 | if (disposing)
32 | {
33 | layers.Dispose();
34 | ClearModules();
35 | }
36 |
37 | base.Dispose(disposing);
38 | }
39 | }
40 | }
--------------------------------------------------------------------------------
/src/DeepSharp.RL/Agents/Nets/PGN.cs:
--------------------------------------------------------------------------------
1 | namespace DeepSharp.RL.Agents
2 | {
3 | public sealed class PGN : Module<torch.Tensor, torch.Tensor>
4 | {
5 | private readonly Module<torch.Tensor, torch.Tensor> layers;
6 |
7 | public PGN(long obsSize, long hiddenSize, long actionNum, DeviceType deviceType = DeviceType.CUDA) :
8 | base("PolicyNet")
9 | {
10 | var modules = new List<(string, Module<torch.Tensor, torch.Tensor>)>
11 | {
12 | ("line1", Linear(obsSize, hiddenSize)),
13 | ("relu", ReLU()),
14 | ("line2", Linear(hiddenSize, actionNum)),
15 | ("softmax", Softmax(-1))
16 | };
17 | layers = Sequential(modules);
18 | layers.to(new torch.Device(deviceType));
19 | RegisterComponents();
20 | }
21 |
22 | public override torch.Tensor forward(torch.Tensor input)
23 | {
24 | return layers.forward(input.to_type(torch.ScalarType.Float32));
25 | }
26 |
27 | protected override void Dispose(bool disposing)
28 | {
29 | if (disposing)
30 | {
31 | layers.Dispose();
32 | ClearModules();
33 | }
34 |
35 | base.Dispose(disposing);
36 | }
37 | }
38 | }
--------------------------------------------------------------------------------
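Net and PGN differ only in the trailing Softmax, so PGN's output can be fed directly to ProbActionSelector. Sketch (sizes are illustrative):

    var pgn = new PGN(4, 128, 2, DeviceType.CPU);

    var state  = torch.from_array(new float[] {0.1f, 0.2f, 0.3f, 0.4f}).unsqueeze(0);
    var probs  = pgn.forward(state);                      // [1, 2], rows sum to 1
    var action = new ProbActionSelector().Select(probs);  // sampled action index
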
/src/DeepSharp.RL/Agents/Others/CrossEntropy.cs:
--------------------------------------------------------------------------------
1 | using DeepSharp.RL.ActionSelectors;
2 | using DeepSharp.RL.Enumerates;
3 | using DeepSharp.RL.Environs;
4 | using FluentAssertions;
5 | using static TorchSharp.torch.optim;
6 |
7 | namespace DeepSharp.RL.Agents
8 | {
9 | /// <summary>
10 | ///     An Agent based on the Cross-Entropy function
11 | ///     Cross-Entropy Method
12 | ///     http://people.smp.uq.edu.au/DirkKroese/ps/eormsCE.pdf
13 | /// </summary>
14 | public class CrossEntropy : Agent
15 | {
16 | public CrossEntropy(Environ<Space, Space> environ,
17 | int t,
18 | float percentElite = 0.7f,
19 | int hiddenSize = 100) : base(environ, "CrossEntropy")
20 | {
21 | T = t;
22 | PercentElite = percentElite;
23 | AgentNet = new Net((int) environ.ObservationSpace!.N, hiddenSize, (int) environ.ActionSpace!.N,
24 | Device.type);
25 | Optimizer = Adam(AgentNet.parameters(), 0.01);
26 | Loss = CrossEntropyLoss();
27 | }
28 |
29 | public int T { protected set; get; }
30 |
31 | public float PercentElite { protected set; get; }
32 |
33 | public Net AgentNet { protected set; get; }
34 |
35 | public Optimizer Optimizer { protected set; get; }
36 |
37 | public Loss<torch.Tensor, torch.Tensor, torch.Tensor> Loss { protected set; get; }
38 |
39 | /// <summary>
40 | ///     The agent builds an action probability distribution from the observation and samples the next action from that distribution
41 | /// </summary>
42 | /// <param name="state"></param>
43 | /// <returns></returns>
44 | public override Act GetPolicyAct(torch.Tensor state)
45 | {
46 | var input = state.unsqueeze(0);
47 | var sm = Softmax(1);
48 | var actionProbs = sm.forward(AgentNet.forward(input));
49 | var nextAction = new ProbActionSelector().Select(actionProbs);
50 | var action = new Act(nextAction);
51 | return action;
52 | }
53 |
54 |
55 | public override LearnOutcome Learn()
56 | {
57 | var episodes = RunEpisodes(T, PlayMode.Sample);
58 | var elite = GetElite(episodes);
59 |
60 | var oars = elite.SelectMany(a => a.Steps)
61 | .ToList();
62 |
63 | var observations = oars
64 | .Select(a => a.PostState.Value)
65 | .ToList();
66 | var actions = oars
67 | .Select(a => a.Action.Value)
68 | .ToList();
69 |
70 | var observation = torch.vstack(observations!);
71 | var action = torch.vstack(actions!);
72 |
73 | var loss = Learn(observation, action);
74 |
75 | return new LearnOutcome(episodes, loss);
76 | }
77 |
78 | public override void Save(string path)
79 | {
80 | throw new NotImplementedException();
81 | }
82 |
83 | public override void Load(string path)
84 | {
85 | throw new NotImplementedException();
86 | }
87 |
88 | ///
89 | /// Replace default Optimizer
90 | ///
91 | ///
92 | public void UpdateOptimizer(Optimizer optimizer)
93 | {
94 | Optimizer = optimizer;
95 | }
96 |
97 | ///
98 | /// Get Elite
99 | ///
100 | ///
101 | ///
102 | ///
103 | public virtual Episode[] GetElite(Episode[] episodes)
104 | {
105 | var reward = episodes
106 | .Select(a => a.SumReward.Value)
107 | .ToArray();
108 | var rewardP = reward.OrderByDescending(a => a)
109 | .Take((int) (reward.Length * PercentElite))
110 | .Min();
111 |
112 | var filterEpisodes = episodes
113 | .Where(e => e.SumReward.Value > rewardP)
114 | .ToArray();
115 |
116 | return filterEpisodes;
117 | }
118 |
119 | /// <summary>
120 | ///     Core function to update the net
121 | /// </summary>
122 | /// <param name="observations">tensor from multiple observations, size: [batch, observation size]</param>
123 | /// <param name="actions">tensor from multiple actions, size: [batch, action size]</param>
124 | /// <returns>loss</returns>
125 | internal float Learn(torch.Tensor observations, torch.Tensor actions)
126 | {
127 | observations.shape.Last().Should()
128 | .Be(ObservationSize, $"Agent observations tensor should be [B,{ObservationSize}]");
129 |
130 | actions = actions.squeeze(-1);
131 | var pred = AgentNet.forward(observations);
132 | var output = Loss.call(pred, actions);
133 |
134 | Optimizer.zero_grad();
135 | output.backward();
136 | Optimizer.step();
137 |
138 | var loss = output.item();
139 | return loss;
140 | }
141 | }
142 | }
--------------------------------------------------------------------------------
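A training sketch for the cross-entropy agent; KArmedBandit is one of the bundled environments, but its constructor argument, the epoch count and t are illustrative assumptions:

    var env = new KArmedBandit(4);              // assumed constructor: number of bandit arms
    var agent = new CrossEntropy(env, t: 100);  // sample 100 episodes per Learn() call

    for (var epoch = 1; epoch <= 500; epoch++)
    {
        var outcome = agent.Learn();            // sample episodes, keep the elite, fit the policy net
        Console.WriteLine($"{epoch:D3}\t{outcome}");
    }
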
/src/DeepSharp.RL/Agents/Others/CrossEntropyExt.cs:
--------------------------------------------------------------------------------
1 | using DeepSharp.RL.Environs;
2 | using static TorchSharp.torch.optim;
3 |
4 | namespace DeepSharp.RL.Agents
5 | {
6 | /// <summary>
7 | ///     An Agent based on the Cross-Entropy function
8 | ///     Cross-Entropy Method
9 | ///     http://people.smp.uq.edu.au/DirkKroese/ps/eormsCE.pdf
10 | /// </summary>
11 | public class CrossEntropyExt : CrossEntropy
12 |
13 | {
14 | public int MemsEliteLength = 30;
15 | public List<DateTime> Start = new();
16 |
17 | public CrossEntropyExt(Environ<Space, Space> environ,
18 | int t,
19 | float percentElite = 0.7f,
20 | int hiddenSize = 100)
21 | : base(environ, t, percentElite, hiddenSize)
22 | {
23 | Optimizer = Adam(AgentNet.parameters(), 0.01);
24 | }
25 |
26 | /// <summary>
27 | ///     Adds a memory: keeps the elite episodes from earlier iterations
28 | /// </summary>
29 | internal List<Episode> MemeSteps { set; get; } = new();
30 |
31 | ///
32 | /// Get Elite
33 | ///
34 | ///
35 | ///
36 | ///
37 | public override Episode[] GetElite(Episode[] episodes)
38 | {
39 | var current = episodes.Select(a => a.DateTime).ToList();
40 | Start.Add(current.Min());
41 | if (Start.Count >= 10)
42 | Start.RemoveAt(0);
43 |
44 | var combine = episodes.Concat(MemeSteps).ToList();
45 | var reward = combine
46 | .Select(a => a.SumReward.Value)
47 | .ToArray();
48 | var rewardP = reward.OrderByDescending(a => a)
49 | .Take((int) (reward.Length * PercentElite))
50 | .Min();
51 |
52 | var filterEpisodes = combine
53 | .Where(e => e.SumReward.Value > rewardP)
54 | .ToArray();
55 |
56 | MemeSteps = filterEpisodes.Where(a => a.DateTime > Start.Min())
57 | .OrderByDescending(a => a.SumReward.Value)
58 | .Take(MemsEliteLength)
59 | .ToList();
60 |
61 | return filterEpisodes;
62 | }
63 | }
64 | }
--------------------------------------------------------------------------------
/src/DeepSharp.RL/Agents/PolicyAgengt.cs:
--------------------------------------------------------------------------------
1 | using DeepSharp.RL.Environs;
2 |
3 | namespace DeepSharp.RL.Agents
4 | {
5 | public abstract class PolicyGradientAgengt : Agent
6 | {
7 | protected PolicyGradientAgengt(Environ<Space, Space> env, string name)
8 | : base(env, name)
9 | {
10 | PolicyNet = new PGN(ObservationSize, 128, ActionSize, DeviceType.CPU);
11 | }
12 |
13 | public Module<torch.Tensor, torch.Tensor> PolicyNet { protected set; get; }
14 |
15 |
16 | /// <summary>
17 | ///     argmax(a') Q(state,a')
18 | ///     Get the action a' with the highest value for this state from the value table
19 | /// </summary>
20 | /// <param name="state"></param>
21 | /// <returns></returns>
22 | public override Act GetPolicyAct(torch.Tensor state)
23 | {
24 | var probs = PolicyNet.forward(state.unsqueeze(0)).squeeze(0);
25 | var actIndex = torch.multinomial(probs, 1, true).ToInt32();
26 | return new Act(torch.from_array(new[] {actIndex}));
27 | }
28 | }
29 | }
--------------------------------------------------------------------------------
/src/DeepSharp.RL/Agents/PolicyBased/A2C.cs:
--------------------------------------------------------------------------------
1 | using DeepSharp.RL.Environs;
2 | using DeepSharp.RL.ExpReplays;
3 | using TorchSharp.Modules;
4 | using static TorchSharp.torch.optim;
5 |
6 | namespace DeepSharp.RL.Agents
7 | {
8 | public class A2C : PolicyGradientAgengt
9 | {
10 | public A2C(Environ<Space, Space> env,
11 | int batchsize,
12 | float alpha = 0.01f,
13 | float beta = 0.01f,
14 | float gamma = 0.99f)
15 | : base(env, "ActorCritic")
16 | {
17 | Batchsize = batchsize;
18 | Gamma = gamma;
19 | Alpha = alpha;
20 | Beta = beta;
21 | // Output is V[batchsize, 1]
22 | Q = new Net(ObservationSize, 128, 1, DeviceType.CPU);
23 | ExpReplays = new EpisodeExpReplay(batchsize, gamma);
24 |
25 | var parameters = new[] {Q, PolicyNet}
26 | .SelectMany(a => a.parameters());
27 | Optimizer = Adam(parameters, Alpha);
28 | }
29 |
30 | /// <summary>
31 | ///     Episodes sent to train
32 | /// </summary>
33 | public int Batchsize { protected set; get; }
34 |
35 | public float Alpha { protected set; get; }
36 | public float Beta { protected set; get; }
37 | public float Gamma { protected set; get; }
38 | public Module<torch.Tensor, torch.Tensor> Q { protected set; get; }
39 | public EpisodeExpReplay ExpReplays { protected set; get; }
40 | public Optimizer Optimizer { protected set; get; }
41 |
42 | ///
43 | /// QLearning for VNet
44 | ///
45 | ///
46 | public override LearnOutcome Learn()
47 | {
48 | var learnOutCome = new LearnOutcome();
49 |
50 | var episodes = RunEpisodes(Batchsize);
51 |
52 |
53 | episodes.ToList().ForEach(e =>
54 | {
55 | learnOutCome.AppendStep(e);
56 | ExpReplays.Enqueue(e);
57 | });
58 |
59 | var experienceCase = ExpReplays.All();
60 | var state = experienceCase.PreState;
61 | var action = experienceCase.Action;
62 | var valsRef = experienceCase.Reward;
63 | ExpReplays.Clear();
64 |
65 | Optimizer.zero_grad();
66 |
67 | var value = Q.forward(state);
68 |
69 | var lossValue = new MSELoss().forward(value, valsRef);
70 | lossValue.backward();
71 |
72 | var logProbV = torch.log(PolicyNet.forward(state)).gather(1, action);
73 | var logProbActionV = (valsRef - value.detach()) * logProbV;
74 | var lossPolicy = -logProbActionV.mean();
75 |
76 |
77 | lossPolicy.backward();
78 | Optimizer.step();
79 |
80 |
81 | learnOutCome.Evaluate = lossPolicy.item();
82 |
83 | return learnOutCome;
84 | }
85 |
86 |
87 | public override void Save(string path)
88 | {
89 | if (File.Exists(path)) File.Delete(path);
90 | PolicyNet.save(path);
91 | }
92 |
93 | public override void Load(string path)
94 | {
95 | PolicyNet.load(path);
96 | }
97 | }
98 | }
--------------------------------------------------------------------------------
/src/DeepSharp.RL/Agents/PolicyBased/A3C.cs:
--------------------------------------------------------------------------------
1 | namespace DeepSharp.RL.Agents
2 | {
3 | internal class A3C
4 | {
5 | }
6 | }
--------------------------------------------------------------------------------
/src/DeepSharp.RL/Agents/PolicyBased/ActorCritic.cs:
--------------------------------------------------------------------------------
1 | using DeepSharp.RL.Environs;
2 | using DeepSharp.RL.ExpReplays;
3 | using OpenCvSharp.Dnn;
4 | using TorchSharp.Modules;
5 | using static TorchSharp.torch.optim;
6 |
7 | namespace DeepSharp.RL.Agents
8 | {
9 | public class ActorCritic : PolicyGradientAgengt
10 | {
11 | public ActorCritic(Environ<Space, Space> env,
12 | int batchsize,
13 | float alpha = 0.01f,
14 | float beta = 0.01f,
15 | float gamma = 0.99f)
16 | : base(env, "ActorCritic")
17 | {
18 | Batchsize = batchsize;
19 | Gamma = gamma;
20 | Alpha = alpha;
21 | Beta = beta;
22 | Q = new Net(ObservationSize, 128, ActionSize, DeviceType.CPU);
23 | ExpReplays = new EpisodeExpReplay(batchsize, gamma);
24 | ExpReplaysForPolicy = new EpisodeExpReplay(batchsize, gamma);
25 | var parameters = new[] {Q, PolicyNet}
26 | .SelectMany(a => a.parameters());
27 | Optimizer = Adam(parameters, Alpha);
28 | }
29 |
30 | ///
31 | /// Episodes sent to train per batch
32 | ///
33 | public int Batchsize { protected set; get; }
34 |
35 | public float Alpha { protected set; get; }
36 | public float Beta { protected set; get; }
37 | public float Gamma { protected set; get; }
38 | public Module Q { protected set; get; }
39 | public EpisodeExpReplay ExpReplays { protected set; get; }
40 | public EpisodeExpReplay ExpReplaysForPolicy { protected set; get; }
41 | public Optimizer Optimizer { protected set; get; }
42 |
43 | ///
44 | /// Q-learning style update for the Q net, followed by a policy-gradient update
45 | ///
46 | ///
47 | public override LearnOutcome Learn()
48 | {
49 | var learnOutCome = new LearnOutcome();
50 |
51 | var episodes = RunEpisodes(Batchsize);
52 |
53 |
54 | episodes.ToList().ForEach(e =>
55 | {
56 | learnOutCome.AppendStep(e);
57 | ExpReplays.Enqueue(e,false);
58 | ExpReplaysForPolicy.Enqueue(e);
59 | });
60 |
61 | var experienceCase = ExpReplays.All();
62 | var state = experienceCase.PreState;
63 | var action = experienceCase.Action;
64 | var reward = experienceCase.Reward;
65 | var poststate = experienceCase.PostState;
66 | ExpReplays.Clear();
67 |
68 | Optimizer.zero_grad();
69 |
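   | // Critic target: r + γ·max_a' Q(s', a'), with the next-state value detached (Q-learning style)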
70 | var stateActionValue = Q.forward(state).gather(1, action);
71 | var nextStateValue = Q.forward(poststate).max(1).values.detach();
72 | var expectedStateActionValue = reward + nextStateValue * Gamma;
73 | var lossValue = new MSELoss().forward(stateActionValue, expectedStateActionValue);
74 | lossValue.backward();
75 |
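   | // Actor update: scale log π(a|s) by the (detached) state-action value estimated by the critic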
76 | var logProbV = torch.log(PolicyNet.forward(state)).gather(1, action);
77 | var logProbActionV = stateActionValue.detach() * logProbV;
78 | var lossPolicy = -logProbActionV.mean();
79 |
80 |
81 | lossPolicy.backward();
82 | Optimizer.step();
83 |
84 |
85 | learnOutCome.Evaluate = lossPolicy.item();
86 |
87 | return learnOutCome;
88 | }
89 |
90 |
91 | public override void Save(string path)
92 | {
93 | if (File.Exists(path)) File.Delete(path);
94 | PolicyNet.save(path);
95 | }
96 |
97 | public override void Load(string path)
98 | {
99 | PolicyNet.load(path);
100 | }
101 | }
102 | }
--------------------------------------------------------------------------------
/src/DeepSharp.RL/Agents/PolicyBased/Reinforce.cs:
--------------------------------------------------------------------------------
1 | using DeepSharp.RL.Environs;
2 | using DeepSharp.RL.ExpReplays;
3 | using static TorchSharp.torch.optim;
4 |
5 | namespace DeepSharp.RL.Agents
6 | {
7 | ///
8 | /// Reinforce: learn from a batch of episodes
9 | ///
10 | public class Reinforce : PolicyGradientAgengt
11 | {
12 | public Reinforce(Environ<Space, Space> env,
13 | int batchsize = 4,
14 | float gamma = 0.99f,
15 | float alpha = 0.01f)
16 | : base(env, "Reinforce")
17 | {
18 | Batchsize = batchsize;
19 | Gamma = gamma;
20 | Alpha = alpha;
21 |
22 | ExpReplays = new EpisodeExpReplay(batchsize, gamma);
23 | Optimizer = Adam(PolicyNet.parameters(), Alpha);
24 | }
25 |
26 | ///
27 | /// Episodes sent to train per batch
28 | ///
29 | public int Batchsize { protected set; get; }
30 |
31 | public float Gamma { protected set; get; }
32 | public float Alpha { protected set; get; }
33 |
34 |
35 | public Optimizer Optimizer { protected set; get; }
36 |
37 | public EpisodeExpReplay ExpReplays { protected set; get; }
38 |
39 |
40 | public override LearnOutcome Learn()
41 | {
42 | var learnOutCome = new LearnOutcome();
43 |
44 | var episodes = RunEpisodes(Batchsize);
45 |
46 | Optimizer.zero_grad();
47 |
48 | episodes.ToList().ForEach(e =>
49 | {
50 | learnOutCome.AppendStep(e);
51 | ExpReplays.Enqueue(e);
52 | });
53 |
54 | var experienceCase = ExpReplays.All();
55 | var state = experienceCase.PreState;
56 | var action = experienceCase.Action;
57 | var qValues = experienceCase.Reward;
58 | ExpReplays.Clear();
59 |
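   | // REINFORCE loss: -E[G_t · log π(a_t|s_t)], where G_t are the discounted returns computed by the episode replay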
60 | var logProbV = torch.log(PolicyNet.forward(state)).gather(1, action);
61 | var logProbActionV = qValues * logProbV;
62 | var loss = -logProbActionV.mean();
63 |
64 |
65 | loss.backward();
66 | Optimizer.step();
67 |
68 |
69 | learnOutCome.Evaluate = loss.item();
70 |
71 | return learnOutCome;
72 | }
73 |
74 |
75 | public override void Save(string path)
76 | {
77 | if (File.Exists(path)) File.Delete(path);
78 | PolicyNet.save(path);
79 | }
80 |
81 | public override void Load(string path)
82 | {
83 | PolicyNet.load(path);
84 | }
85 | }
86 | }
--------------------------------------------------------------------------------
/src/DeepSharp.RL/Agents/PolicyBased/ReinforceOriginal.cs:
--------------------------------------------------------------------------------
1 | using DeepSharp.RL.Environs;
2 | using static TorchSharp.torch.optim;
3 |
4 | namespace DeepSharp.RL.Agents
5 | {
6 | ///
7 | /// ReinforceOriginal: Learn after each episode
8 | ///
9 | public class ReinforceOriginal : PolicyGradientAgengt
10 | {
11 | public ReinforceOriginal(Environ<Space, Space> env,
12 | float gamma = 0.99f,
13 | float alpha = 0.01f)
14 | : base(env, "ReinforceOriginal")
15 | {
16 | Gamma = gamma;
17 | Alpha = alpha;
18 | Optimizer = Adam(PolicyNet.parameters(), Alpha);
19 | }
20 |
21 | ///
23 | /// Episodes sent to train per batch
23 | ///
24 | public int Batchsize { protected set; get; }
25 |
26 | public float Gamma { protected set; get; }
27 | public float Alpha { protected set; get; }
28 |
29 |
30 | public Optimizer Optimizer { protected set; get; }
31 |
32 |
33 | public override LearnOutcome Learn()
34 | {
35 | var learnOutCome = new LearnOutcome();
36 |
37 | var episode = RunEpisode();
38 | var steps = episode.Steps;
39 | learnOutCome.AppendStep(episode.Steps);
40 |
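   | // Walk the episode backwards so the return can be accumulated incrementally: G ← γ·G + r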
41 | steps.Reverse();
42 | Optimizer.zero_grad();
43 |
44 | var g = 0f;
45 |
46 | foreach (var s in steps)
47 | {
48 | var reward = s.Reward.Value;
49 | var state = s.PreState.Value!.unsqueeze(0);
50 | var action = s.Action.Value!.view(-1, 1).to(torch.ScalarType.Int64);
51 | var logProb = torch.log(PolicyNet.forward(state)).gather(1, action);
52 |
53 | g = Gamma * g + reward;
54 |
55 | var loss = -logProb * g;
56 | loss.backward();
57 | learnOutCome.Evaluate = loss.item();
58 | }
59 |
60 | Optimizer.step();
61 |
62 | return learnOutCome;
63 | }
64 |
65 |
66 | public override void Save(string path)
67 | {
68 | if (File.Exists(path)) File.Delete(path);
69 | PolicyNet.save(path);
70 | }
71 |
72 | public override void Load(string path)
73 | {
74 | PolicyNet.load(path);
75 | }
76 | }
77 | }
--------------------------------------------------------------------------------
/src/DeepSharp.RL/Agents/ValueAgent.cs:
--------------------------------------------------------------------------------
1 | using DeepSharp.RL.Environs;
2 |
3 | namespace DeepSharp.RL.Agents
4 | {
5 | public abstract class ValueAgent : Agent
6 | {
7 | protected ValueAgent(Environ<Space, Space> env, string name)
8 | : base(env, name)
9 | {
10 | QTable = new QTable();
11 | }
12 |
13 | public QTable QTable { protected set; get; }
14 |
15 |
16 | ///
17 | /// argmax(a') Q(state,a')
18 | /// Get the highest-value action for the given state from the Q table
19 | ///
20 | ///
21 | ///
22 | public override Act GetPolicyAct(torch.Tensor state)
23 | {
24 | var action = QTable.GetBestAct(state);
25 | return action ?? GetSampleAct();
26 | }
27 |
28 | public override void Save(string path)
29 | {
30 | /// Save Q Table
31 | }
32 |
33 | public override void Load(string path)
34 | {
35 | /// Load Q Table
36 | }
37 | }
38 | }
--------------------------------------------------------------------------------
/src/DeepSharp.RL/Agents/ValueBased/DeepQN/CategoricalDQN.cs:
--------------------------------------------------------------------------------
1 | namespace DeepSharp.RL.Agents
2 | {
3 | internal class CategoricalDQN
4 | {
5 | }
6 | }
--------------------------------------------------------------------------------
/src/DeepSharp.RL/Agents/ValueBased/DeepQN/DoubleDQN.cs:
--------------------------------------------------------------------------------
1 | namespace DeepSharp.RL.Agents
2 | {
3 | internal class DoubleDQN
4 | {
5 | }
6 | }
7 |
--------------------------------------------------------------------------------
/src/DeepSharp.RL/Agents/ValueBased/DeepQN/DuelingDQN.cs:
--------------------------------------------------------------------------------
1 | namespace DeepSharp.RL.Agents
2 | {
3 | internal class DuelingDQN
4 | {
5 | }
6 | }
--------------------------------------------------------------------------------
/src/DeepSharp.RL/Agents/ValueBased/DeepQN/NDQN.cs:
--------------------------------------------------------------------------------
1 | namespace DeepSharp.RL.Agents
2 | {
3 | internal class NDQN
4 | {
5 | }
6 | }
7 |
--------------------------------------------------------------------------------
/src/DeepSharp.RL/Agents/ValueBased/DeepQN/NoisyDQN.cs:
--------------------------------------------------------------------------------
1 | namespace DeepSharp.RL.Agents
2 | {
3 | internal class NoisyDQN
4 | {
5 | }
6 | }
--------------------------------------------------------------------------------
/src/DeepSharp.RL/Agents/ValueBased/DynamicPlan/PIDiscountR.cs:
--------------------------------------------------------------------------------
1 | using DeepSharp.RL.Environs;
2 | using DeepSharp.Utility;
3 |
4 | namespace DeepSharp.RL.Agents
5 | {
6 | public class PIDiscountR : PolicyIteration
7 | {
8 | public PIDiscountR(Environ<Space, Space> env, Dictionary<RewardKey, float> p,
9 | Dictionary<RewardKey, float> r, int t, float gamma = 0.9f)
10 | : base(env, p, r, t)
11 | {
12 | Gamma = gamma;
13 | }
14 |
15 | public float Gamma { protected set; get; }
16 |
17 | protected override VTable GetVTable(int t)
18 | {
19 | var vNext = new VTable();
20 |
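   | // t sweeps of the Bellman backup: V(x) = Σ_(x,a,x') P·(R + γ·V(x'))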
21 | foreach (var unused in Enumerable.Range(0, t))
22 | foreach (var x in X)
23 | vNext[x] = RewardKeys
24 | .Where(a => a.State.Equals(x))
25 | .Sum(r => P[r] * (R[r] + vNext[r.NewState] * Gamma));
26 | return vNext;
27 | }
28 |
29 | protected override QTable GetQTable(VTable v, int t)
30 | {
31 | var q = new QTable();
32 |
33 | var states = P.Keys
34 | .Select(a => a.State)
35 | .Distinct(new TensorEqualityCompare())
36 | .ToArray();
37 |
38 | var actions = P.Keys
39 | .Select(a => a.Act)
40 | .Distinct(new TensorEqualityCompare())
41 | .ToArray();
42 |
43 | foreach (var state in states)
44 | foreach (var action in actions)
45 | {
46 | var value = RewardKeys.Where(a => a.State.Equals(state) && a.Act.Equals(action))
47 | .Sum(a => P[a] * (R[a] + v[a.NewState] * Gamma));
48 | q[state, action] = value;
49 | }
50 |
51 | return q;
52 | }
53 | }
54 | }
--------------------------------------------------------------------------------
/src/DeepSharp.RL/Agents/ValueBased/DynamicPlan/PITStepR.cs:
--------------------------------------------------------------------------------
1 | using DeepSharp.RL.Environs;
2 | using DeepSharp.Utility;
3 |
4 | namespace DeepSharp.RL.Agents
5 | {
6 | public class PITStepR : ValueIterate
7 | {
8 | public PITStepR(Environ<Space, Space> env, Dictionary<RewardKey, float> p,
9 | Dictionary<RewardKey, float> r, int t = 100) : base(env, p, r, t)
10 | {
11 | }
12 |
13 | protected override VTable GetVTable(int t)
14 | {
15 | var vNext = new VTable();
16 |
17 | foreach (var i in Enumerable.Range(1, t))
18 | foreach (var x in X)
19 | vNext[x] = RewardKeys
20 | .Where(a => a.State.Equals(x))
21 | .Sum(r => P[r] * (R[r] / i + vNext[r.NewState] * (i - 1) / i));
22 | return vNext;
23 | }
24 |
25 | protected override QTable GetQTable(VTable v, int t)
26 | {
27 | var q = new QTable();
28 |
29 | var states = P.Keys
30 | .Select(a => a.State)
31 | .Distinct(new TensorEqualityCompare())
32 | .ToArray();
33 |
34 | var actions = P.Keys
35 | .Select(a => a.Act)
36 | .Distinct(new TensorEqualityCompare())
37 | .ToArray();
38 |
39 | foreach (var state in states)
40 | foreach (var action in actions)
41 | {
42 | var value = RewardKeys.Where(a => a.State.Equals(state) && a.Act.Equals(action))
43 | .Sum(a => P[a] * (R[a] / t + v[a.NewState] * (t - 1) / t));
44 | q[state, action] = value;
45 | }
46 |
47 | return q;
48 | }
49 | }
50 | }
--------------------------------------------------------------------------------
/src/DeepSharp.RL/Agents/ValueBased/DynamicPlan/PolicyIteration.cs:
--------------------------------------------------------------------------------
1 | using DeepSharp.RL.Environs;
2 | using DeepSharp.Utility;
3 |
4 | namespace DeepSharp.RL.Agents
5 | {
6 | public abstract class PolicyIteration : ValueAgent
7 | {
8 | ///
9 | ///
10 | ///
11 | ///
12 | ///
13 | ///
14 | ///
15 | protected PolicyIteration(Environ<Space, Space> env, Dictionary<RewardKey, float> p,
16 | Dictionary<RewardKey, float> r, int t = 100)
17 | : base(env, "PolicyIteration")
18 | {
19 | T = t;
20 | VTable = new VTable();
21 | P = p;
22 | R = r;
23 | RewardKeys = p.Keys.ToArray();
24 | X = P.Keys.Select(a => a.State)
25 | .Distinct(new TensorEqualityCompare())
26 | .ToArray();
27 | }
28 |
29 | public int T { protected set; get; }
30 |
31 |
32 | public VTable VTable { protected set; get; }
33 |
34 | protected Dictionary<RewardKey, float> P { set; get; }
35 | protected Dictionary<RewardKey, float> R { set; get; }
36 | protected torch.Tensor[] X { set; get; }
37 | protected RewardKey[] RewardKeys { set; get; }
38 |
39 |
40 | public override LearnOutcome Learn()
41 | {
42 | while (true)
43 | {
44 | /// Update VTable
45 | foreach (var t in Enumerable.Range(1, T))
46 | {
47 | var vNext = GetVTable(t);
48 | VTable = vNext;
49 | }
50 |
51 | /// Policy Iterate
52 | /// Get Policy (argmax Q=> Update QTable) by Value
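   | // Stop when the greedy policy derived from the new Q table no longer changes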
53 | var qTable = GetQTable(VTable, T);
54 | if (qTable == QTable) break;
55 | QTable = qTable;
56 | }
57 |
58 | return new LearnOutcome();
59 | }
60 |
61 | protected abstract VTable GetVTable(int t);
62 |
63 | protected abstract QTable GetQTable(VTable vTable, int t);
64 | }
65 | }
--------------------------------------------------------------------------------
/src/DeepSharp.RL/Agents/ValueBased/DynamicPlan/VIDiscountR.cs:
--------------------------------------------------------------------------------
1 | using DeepSharp.RL.Environs;
2 | using DeepSharp.Utility;
3 |
4 | namespace DeepSharp.RL.Agents
5 | {
6 | public class VIDiscountR : ValueIterate
7 | {
8 | public VIDiscountR(Environ<Space, Space> env, Dictionary<RewardKey, float> p,
9 | Dictionary<RewardKey, float> r, int t, float gamma = 0.9f, float threshold = 0.1f)
10 | : base(env, p, r, t, threshold)
11 | {
12 | Gamma = gamma;
13 | }
14 |
15 | public float Gamma { protected set; get; }
16 |
17 | protected override VTable GetVTable(int t)
18 | {
19 | var vNext = new VTable();
20 |
21 | foreach (var unused in Enumerable.Range(0, t))
22 | foreach (var x in X)
23 | vNext[x] = RewardKeys
24 | .Where(a => a.State.Equals(x))
25 | .Sum(r => P[r] * (R[r] + vNext[r.NewState] * Gamma));
26 | return vNext;
27 | }
28 |
29 | protected override QTable GetQTable(VTable v, int t)
30 | {
31 | var q = new QTable();
32 |
33 | var states = P.Keys
34 | .Select(a => a.State)
35 | .Distinct(new TensorEqualityCompare())
36 | .ToArray();
37 |
38 | var actions = P.Keys
39 | .Select(a => a.Act)
40 | .Distinct(new TensorEqualityCompare())
41 | .ToArray();
42 |
43 | foreach (var state in states)
44 | foreach (var action in actions)
45 | {
46 | var value = RewardKeys.Where(a => a.State.Equals(state) && a.Act.Equals(action))
47 | .Sum(a => P[a] * (R[a] + v[a.NewState] * Gamma));
48 | q[state, action] = value;
49 | }
50 |
51 | return q;
52 | }
53 | }
54 | }
--------------------------------------------------------------------------------
/src/DeepSharp.RL/Agents/ValueBased/DynamicPlan/VITStepR.cs:
--------------------------------------------------------------------------------
1 | using DeepSharp.RL.Environs;
2 | using DeepSharp.Utility;
3 |
4 | namespace DeepSharp.RL.Agents
5 | {
6 | public class VITStepR : ValueIterate
7 | {
8 | public VITStepR(Environ<Space, Space> env, Dictionary<RewardKey, float> p,
9 | Dictionary<RewardKey, float> r, int t = 100, float threshold = 0.1f) : base(env, p, r, t, threshold)
10 | {
11 | }
12 |
13 | protected override VTable GetVTable(int t)
14 | {
15 | var vNext = new VTable();
16 |
17 | foreach (var i in Enumerable.Range(1, t))
18 | foreach (var x in X)
19 | vNext[x] = RewardKeys
20 | .Where(a => a.State.Equals(x))
21 | .Sum(r => P[r] * (R[r] / i + vNext[r.NewState] * (i - 1) / i));
22 | return vNext;
23 | }
24 |
25 | protected override QTable GetQTable(VTable v, int t)
26 | {
27 | var q = new QTable();
28 |
29 | var states = P.Keys
30 | .Select(a => a.State)
31 | .Distinct(new TensorEqualityCompare())
32 | .ToArray();
33 |
34 | var actions = P.Keys
35 | .Select(a => a.Act)
36 | .Distinct(new TensorEqualityCompare())
37 | .ToArray();
38 |
39 | foreach (var state in states)
40 | foreach (var action in actions)
41 | {
42 | var value = RewardKeys.Where(a => a.State.Equals(state) && a.Act.Equals(action))
43 | .Sum(a => P[a] * (R[a] / t + v[a.NewState] * (t - 1) / t));
44 | q[state, action] = value;
45 | }
46 |
47 | return q;
48 | }
49 | }
50 | }
--------------------------------------------------------------------------------
/src/DeepSharp.RL/Agents/ValueBased/DynamicPlan/ValueIterate.cs:
--------------------------------------------------------------------------------
1 | using DeepSharp.RL.Environs;
2 | using DeepSharp.Utility;
3 |
4 | namespace DeepSharp.RL.Agents
5 | {
6 | public abstract class ValueIterate : ValueAgent
7 | {
8 | ///
9 | ///
10 | ///
11 | ///
12 | ///
13 | ///
14 | ///
15 | protected ValueIterate(Environ<Space, Space> env, Dictionary<RewardKey, float> p,
16 | Dictionary<RewardKey, float> r, int t = 100, float threshold = 0.1f)
17 | : base(env, "ValueIterate")
18 | {
19 | T = t;
20 | Threshold = threshold;
21 | VTable = new VTable();
22 | P = p;
23 | R = r;
24 | RewardKeys = p.Keys.ToArray();
25 | X = P.Keys.Select(a => a.State)
26 | .Distinct(new TensorEqualityCompare())
27 | .ToArray();
28 | }
29 |
30 | public int T { protected set; get; }
31 |
32 | ///
33 | /// Convergence Threshold
34 | ///
35 | public float Threshold { protected set; get; }
36 |
37 | public VTable VTable { protected set; get; }
38 |
39 | protected Dictionary<RewardKey, float> P { set; get; }
40 | protected Dictionary<RewardKey, float> R { set; get; }
41 | protected torch.Tensor[] X { set; get; }
42 | protected RewardKey[] RewardKeys { set; get; }
43 |
44 |
45 | public override LearnOutcome Learn()
46 | {
47 | /// Value Iterate
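   | // Sweep value updates until the change between successive value tables drops below Threshold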
48 | foreach (var t in Enumerable.Range(1, T))
49 | {
50 | var vNext = GetVTable(t);
51 | if (vNext - VTable < Threshold)
52 | break;
53 | VTable = vNext;
54 | }
55 |
56 | /// Get Policy (argmax Q=> Update QTable) by Value
57 | var qTable = GetQTable(VTable, T);
58 | QTable = qTable;
59 | return new LearnOutcome();
60 | }
61 |
62 | protected abstract VTable GetVTable(int t);
63 |
64 | protected abstract QTable GetQTable(VTable vTable, int t);
65 | }
66 | }
--------------------------------------------------------------------------------
/src/DeepSharp.RL/Agents/ValueBased/MonteCarlo/MonteCarloOffPolicy.cs:
--------------------------------------------------------------------------------
1 | using DeepSharp.RL.Environs;
2 |
3 | namespace DeepSharp.RL.Agents
4 | {
5 | ///
6 | /// Todo has issue about GetTransitPer? 2023/7/3
7 | ///
8 | public class MonteCarloOffPolicy : ValueAgent
9 | {
10 | public MonteCarloOffPolicy(Environ<Space, Space> env, float epsilon = 0.1f, int t = 10)
11 | : base(env, "MonteCarloOffPolicy")
12 | {
13 | Epsilon = epsilon;
14 | T = t;
15 | Count = new Dictionary<TransitKey, int>();
16 | }
17 |
18 | public int T { protected set; get; }
19 |
20 | public Dictionary<TransitKey, int> Count { protected set; get; }
21 |
22 |
23 | public override LearnOutcome Learn()
24 | {
25 | Environ.Reset();
26 | var episode = new Episode();
27 | var epoch = 0;
28 | var act = GetEpsilonAct(Environ.Observation!.Value!);
29 | while (Environ.IsComplete(epoch) == false && epoch < T)
30 | {
31 | epoch++;
32 | var step = Environ.Step(act, epoch);
33 |
34 | episode.Steps.Add(step);
35 | Environ.CallBack?.Invoke(step);
36 | Environ.Observation = step.PostState; /// It's important to update the Observation
37 | }
38 |
39 | Update(episode);
40 |
41 | var sumReward = episode.Steps.Sum(a => a.Reward.Value);
42 | episode.SumReward = new Reward(sumReward);
43 |
44 | var learnOut = new LearnOutcome(episode);
45 |
46 | return learnOut;
47 | }
48 |
49 |
50 | public void Update(Episode episode)
51 | {
52 | var length = episode.Length;
53 | var steps = episode.Steps;
54 | foreach (var t in Enumerable.Range(0, length))
55 | {
56 | var step = steps[t];
57 | var key = new TransitKey(step.PreState, step.Action);
58 | var r = steps.Skip(t).Average(a => a.Reward.Value);
59 | var per = steps.Skip(t).Select(GetTransitPer).Aggregate(1f, (a, b) => a * b); ///Error Here
60 | var finalR = r * per;
61 | var count = GetCount(key);
62 | QTable[key] = (QTable[key] * count + finalR) / (count + 1);
63 | SetCount(key, count + 1);
64 | }
65 | }
66 |
67 |
68 | private float GetTransitPer(Step step)
69 | {
70 | var actPolicy = GetPolicyAct(step.PreState.Value!).Value!;
71 | var actStep = step.Action.Value!;
72 | var actionSpace = Environ.ActionSpace!.N;
73 | var e = 1f;
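   | // Importance-sampling weight: ratio of the target-policy probability to the ε-greedy behaviour-policy probability for this step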
74 | var per = actPolicy.Equals(actStep)
75 | ? 1 - Epsilon + Epsilon / actionSpace
76 | : Epsilon / actionSpace;
77 | return e / per;
78 | }
79 |
80 | private int GetCount(TransitKey transitKey)
81 | {
82 | Count.TryAdd(transitKey, 0);
83 | return Count[transitKey];
84 | }
85 |
86 | private void SetCount(TransitKey transitKey, int value)
87 | {
88 | Count[transitKey] = value;
89 | }
90 | }
91 | }
--------------------------------------------------------------------------------
/src/DeepSharp.RL/Agents/ValueBased/MonteCarlo/MonteCarloOnPolicy.cs:
--------------------------------------------------------------------------------
1 | using DeepSharp.RL.Environs;
2 |
3 | namespace DeepSharp.RL.Agents
4 | {
5 | ///
6 | /// Monte Carlo Method On Policy
7 | ///
8 | public class MonteCarloOnPolicy : ValueAgent
9 | {
10 | ///
11 | ///
12 | ///
13 | /// the learning count of each epoch
14 | public MonteCarloOnPolicy(Environ<Space, Space> env, float epsilon = 0.1f, int t = 10)
15 | : base(env, "MonteCarloOnPolicy")
16 | {
17 | Epsilon = epsilon;
18 | T = t;
19 | Count = new Dictionary<TransitKey, int>();
20 | }
21 |
22 | public int T { protected set; get; }
23 |
24 | public Dictionary<TransitKey, int> Count { protected set; get; }
25 |
26 |
27 | public override LearnOutcome Learn()
28 | {
29 | Environ.Reset();
30 | var episode = new Episode();
31 | var epoch = 0;
32 | var act = GetEpsilonAct(Environ.Observation!.Value!);
33 | while (Environ.IsComplete(epoch) == false && epoch < T)
34 | {
35 | epoch++;
36 | var step = Environ.Step(act, epoch);
37 |
38 | episode.Steps.Add(step);
39 |
40 | Environ.CallBack?.Invoke(step);
41 | Environ.Observation = step.PostState; /// It's important to update the Observation
42 | }
43 |
44 | Update(episode);
45 |
46 | var sumReward = episode.Steps.Sum(a => a.Reward.Value);
47 | episode.SumReward = new Reward(sumReward);
48 |
49 |
50 | var learnOut = new LearnOutcome(episode);
51 |
52 | return learnOut;
53 | }
54 |
55 |
56 | public void Update(Episode episode)
57 | {
58 | var length = episode.Length;
59 | var steps = episode.Steps;
60 | foreach (var t in Enumerable.Range(0, length))
61 | {
62 | var step = steps[t];
63 | var key = new TransitKey(step.PreState, step.Action);
64 | var r = steps.Skip(t).Average(a => a.Reward.Value);
65 | var count = GetCount(key);
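   | // Incremental mean: fold the new return estimate into the running average of Q(s,a) over all visits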
66 | QTable[key] = (QTable[key] * count + r) / (count + 1);
67 | SetCount(key, count + 1);
68 | }
69 | }
70 |
71 | private int GetCount(TransitKey transitKey)
72 | {
73 | Count.TryAdd(transitKey, 0);
74 | return Count[transitKey];
75 | }
76 |
77 | private void SetCount(TransitKey transitKey, int value)
78 | {
79 | Count[transitKey] = value;
80 | }
81 | }
82 | }
--------------------------------------------------------------------------------
/src/DeepSharp.RL/Agents/ValueBased/TemporalDifference/QLearning.cs:
--------------------------------------------------------------------------------
1 | using DeepSharp.RL.Environs;
2 |
3 | namespace DeepSharp.RL.Agents
4 | {
5 | public class QLearning : ValueAgent
6 | {
7 | ///
8 | ///
9 | ///
10 | /// epsilon of ε-greedy Policy
11 | /// learning rate
12 | /// rate of discount
13 | public QLearning(Environ<Space, Space> env,
14 | float epsilon = 0.1f,
15 | float alpha = 0.2f,
16 | float gamma = 0.9f) :
17 | base(env, "QLearning")
18 | {
19 | Epsilon = epsilon;
20 | Alpha = alpha;
21 | Gamma = gamma;
22 | }
23 |
24 |
25 | public float Alpha { protected set; get; }
26 | public float Gamma { protected set; get; }
27 |
28 |
29 | public override Act GetPolicyAct(torch.Tensor state)
30 | {
31 | var action = QTable.GetBestAct(state);
32 | return action ?? GetSampleAct();
33 | }
34 |
35 |
36 | public override LearnOutcome Learn()
37 | {
38 | Environ.Reset();
39 | var episode = new Episode();
40 | var epoch = 0;
41 | while (Environ.IsComplete(epoch) == false)
42 | {
43 | epoch++;
44 | var epsilonAct = GetEpsilonAct(Environ.Observation!.Value!);
45 | var step = Environ.Step(epsilonAct, epoch);
46 |
47 | Update(step);
48 |
49 | episode.Steps.Add(step);
50 | Environ.CallBack?.Invoke(step);
51 |
52 | Environ.Observation = step.PostState; /// It's important to update the Observation
53 | }
54 |
55 | var sumReward = episode.Steps.Sum(a => a.Reward.Value);
56 | episode.SumReward = new Reward(sumReward);
57 |
58 | return new LearnOutcome(episode);
59 | }
60 |
61 | public void Update(Step step)
62 | {
63 | var s = step.PreState.Value!;
64 | var a = step.Action.Value!;
65 | var r = step.Reward.Value;
66 | var sNext = step.PostState.Value!;
67 | var q = QTable[s, a];
68 |
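   | // Off-policy TD update: Q(s,a) ← Q(s,a) + α·(r + γ·max_a' Q(s',a') − Q(s,a)); the next action is the greedy one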
69 | var aNext = GetPolicyAct(sNext);
70 | var qNext = QTable[sNext, aNext.Value!];
71 |
72 | QTable[s, a] = q + Alpha * (r + Gamma * qNext - q);
73 | }
74 | }
75 | }
--------------------------------------------------------------------------------
/src/DeepSharp.RL/Agents/ValueBased/TemporalDifference/SARSA.cs:
--------------------------------------------------------------------------------
1 | using DeepSharp.RL.Environs;
2 |
3 | namespace DeepSharp.RL.Agents
4 | {
5 | public class SARSA : ValueAgent
6 | {
7 | ///
8 | ///
9 | ///
10 | /// epsilon of ε-greedy Policy
11 | /// learning rate
12 | /// rate of discount
13 | public SARSA(Environ<Space, Space> env,
14 | float epsilon = 0.1f,
15 | float alpha = 0.2f,
16 | float gamma = 0.9f) : base(env, "SARSA")
17 | {
18 | Epsilon = epsilon;
19 | Alpha = alpha;
20 | Gamma = gamma;
21 | }
22 |
23 | public float Alpha { protected set; get; }
24 | public float Gamma { protected set; get; }
25 |
26 | public override LearnOutcome Learn()
27 | {
28 | Environ.Reset();
29 | var episode = new Episode();
30 | var epoch = 0;
31 | var act = GetEpsilonAct(Environ.Observation!.Value!);
32 | while (Environ.IsComplete(epoch) == false)
33 | {
34 | epoch++;
35 | var step = Environ.Step(act, epoch);
36 |
37 | var actNext = Update(step); ///
38 |
39 | episode.Steps.Add(step);
40 | Environ.CallBack?.Invoke(step);
41 |
42 | Environ.Observation = step.PostState; /// It's important to update the Observation
43 | act = actNext;
44 | }
45 |
46 | var sumReward = episode.Steps.Sum(a => a.Reward.Value);
47 | episode.SumReward = new Reward(sumReward);
48 |
49 | return new LearnOutcome(episode);
50 | }
51 |
52 |
53 | public Act Update(Step step)
54 | {
55 | var s = step.PreState.Value!;
56 | var a = step.Action.Value!;
57 | var r = step.Reward.Value;
58 | var sNext = step.PostState.Value!;
59 | var q = QTable[s, a];
60 |
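   | // On-policy TD update: a' is drawn from the same ε-greedy policy, so the target is r + γ·Q(s',a')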
61 | var aNext = GetEpsilonAct(sNext); /// a' by ε-greedy policy
62 | var qNext = QTable[sNext, aNext.Value!];
63 |
64 |
65 | QTable[s, a] = q + Alpha * (r + Gamma * qNext - q);
66 | return aNext;
67 | }
68 | }
69 | }
--------------------------------------------------------------------------------
/src/DeepSharp.RL/DeepSharp.RL.csproj:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 | net7.0
5 | enable
6 | enable
7 |
8 |
9 |
10 |
11 |
12 |
13 |
14 |
15 |
16 |
17 |
18 |
19 |
20 |
21 |
22 |
23 |
24 |
25 |
26 |
27 |
28 |
29 |
30 |
--------------------------------------------------------------------------------
/src/DeepSharp.RL/DeepSharp.RL.csproj.DotSettings:
--------------------------------------------------------------------------------
1 |
2 | True
3 | True
4 | True
5 | True
6 | True
7 | True
8 | True
9 | True
10 | True
11 | True
12 | True
13 | True
14 | True
--------------------------------------------------------------------------------
/src/DeepSharp.RL/Enumerates/PlayMode.cs:
--------------------------------------------------------------------------------
1 | namespace DeepSharp.RL.Enumerates
2 | {
3 | ///
4 | /// Play mode of Agent
5 | ///
6 | public enum PlayMode
7 | {
8 | ///
9 | /// Uniform random sampling
10 | ///
11 | Sample,
12 |
13 | ///
14 | /// Follow the agent's policy
15 | ///
16 | Agent,
17 |
18 | ///
19 | /// Sample(ε) and Agent(1-ε)
20 | ///
21 | EpsilonGreedy
22 | }
23 | }
--------------------------------------------------------------------------------
/src/DeepSharp.RL/Environs/Act.cs:
--------------------------------------------------------------------------------
1 | namespace DeepSharp.RL.Environs
2 | {
3 | ///
4 | /// Action
5 | ///
6 | public class Act : IEqualityComparer<Act>
7 | {
8 | public Act(torch.Tensor? action)
9 | {
10 | Value = action;
11 | TimeStamp = DateTime.Now;
12 | }
13 |
14 | ///
15 | /// The action in tensor format
16 | ///
17 | public torch.Tensor? Value { set; get; }
18 |
19 | ///
20 | /// Timestamp when the action was created
21 | ///
22 | public DateTime TimeStamp { set; get; }
23 |
24 |
25 | public bool Equals(Act? x, Act? y)
26 | {
27 | if (ReferenceEquals(x, y)) return true;
28 | if (ReferenceEquals(x, null)) return false;
29 | if (ReferenceEquals(y, null)) return false;
30 | return x.GetType() == y.GetType() && x.Value!.Equals(y.Value!);
31 | }
32 |
33 | public int GetHashCode(Act obj)
34 | {
35 | return HashCode.Combine(obj.TimeStamp, obj.Value);
36 | }
37 |
38 | public Act To(torch.Device device)
39 | {
40 | return new Act(Value!.to(device));
41 | }
42 |
43 | public override string ToString()
44 | {
45 | return $"{TimeStamp}\t{Value!.ToString(torch.numpy)}";
46 | }
47 | }
48 | }
--------------------------------------------------------------------------------
/src/DeepSharp.RL/Environs/Environ.cs:
--------------------------------------------------------------------------------
1 | using System.Text;
2 |
3 | namespace DeepSharp.RL.Environs
4 | {
5 | ///
6 | /// Environment
7 | /// Provides observations and gives rewards
8 | ///
9 | public abstract class Environ<T1, T2>
10 | where T1 : Space
11 | where T2 : Space
12 | {
13 | public Action<Step>? CallBack;
14 |
15 | protected Environ(string name, DeviceType deviceType = DeviceType.CUDA)
16 | {
17 | Name = name;
18 | Device = new torch.Device(deviceType);
19 | Reward = new Reward(0);
20 | ObservationList = new List<Observation>();
21 | }
22 |
23 | protected Environ(string name)
24 | : this(name, DeviceType.CPU)
25 | {
26 | }
27 |
28 |
29 | public string Name { set; get; }
30 |
31 | public torch.Device Device { set; get; }
32 | public T1? ActionSpace { protected set; get; }
33 | public T2? ObservationSpace { protected set; get; }
34 |
35 |
36 | ///
37 | /// Observation Current
38 | ///
39 | public Observation? Observation { set; get; }
40 |
41 | ///
42 | /// Reward Current
43 | ///
44 | public Reward Reward { set; get; }
45 |
46 | ///
47 | /// Observation Temp List
48 | ///
49 | public List<Observation> ObservationList { set; get; }
50 |
51 | public int Life => ObservationList.Count;
52 |
53 |
54 | ///
55 | /// Reset to the initial state
56 | ///
57 | public virtual Observation Reset()
58 | {
59 | Observation = new Observation(ObservationSpace!.Generate());
60 | ObservationList = new List<Observation> {Observation};
61 | Reward = new Reward(0);
62 | return Observation;
63 | }
64 |
65 | ///
66 | /// Get the final return of an episode
67 | ///
68 | ///
69 | ///
70 | public abstract float GetReturn(Episode episode);
71 |
72 | public virtual Act SampleAct()
73 | {
74 | return new Act(ActionSpace!.Sample());
75 | }
76 |
77 | ///
78 | /// Agent provide Act
79 | ///
80 | ///
81 | ///
82 | public virtual Step Step(Act act, int epoch)
83 | {
84 | var state = Observation!;
85 | var stateNew = Update(act);
86 | var reward = GetReward(stateNew);
87 | var complete = IsComplete(epoch);
88 | var step = new Step(state, act, stateNew, reward, complete);
89 | ObservationList.Add(stateNew);
90 | Observation = stateNew;
91 | return step;
92 | }
93 |
94 |
95 | ///
96 | /// Update Environ Observation according with one action from Agent
97 | ///
98 | /// Action from Policy
99 | /// new observation
100 | public abstract Observation Update(Act act);
101 |
102 |
103 | ///
104 | /// Calculate the reward from an observation
105 | /// How the single-step reward is computed from one observation
106 | ///
107 | /// one observation
108 | /// one reward
109 | public abstract Reward GetReward(Observation observation);
110 |
111 |
112 | ///
113 | /// Check whether the Environ is complete
114 | /// Determine whether the exploration has finished
115 | ///
116 | ///
117 | ///
118 | public abstract bool IsComplete(int epoch);
119 |
120 |
121 | public override string ToString()
122 | {
123 | var str = new StringBuilder();
124 | str.AppendLine($"{Name}\tLife:{Life}");
125 | str.AppendLine(new string('-', 30));
126 | str.Append($"State:\t{Observation!.Value!.ToString(torch.numpy)}");
127 | return str.ToString();
128 | }
129 | }
130 | }
--------------------------------------------------------------------------------
/src/DeepSharp.RL/Environs/Episode.cs:
--------------------------------------------------------------------------------
1 | using System.Text;
2 |
3 | namespace DeepSharp.RL.Environs
4 | {
5 | ///
6 | /// Episode
7 | ///
8 | public class Episode
9 | {
10 | public Episode()
11 | {
12 | Steps = new List<Step>();
13 | SumReward = new Reward(0);
14 | DateTime = DateTime.Now;
15 | Evaluate = 0;
16 | }
17 |
18 | public Episode(List<Step> steps)
19 | {
20 | Steps = steps;
21 | SumReward = new Reward(0);
22 | DateTime = DateTime.Now;
23 | Evaluate = 0;
24 | }
25 |
26 | public List<Step> Steps { set; get; }
27 | public Step this[int i] => Steps[i];
28 |
29 | public Reward SumReward { set; get; }
30 | public DateTime DateTime { set; get; }
31 |
32 | public bool IsComplete { set; get; }
33 | public int Length => Steps.Count;
34 |
35 | public float Evaluate { set; get; }
36 |
37 | public void Enqueue(Step step)
38 | {
39 | Steps.Add(step);
40 | }
41 |
42 | public int[] GetAction()
43 | {
44 | var actions = Steps
45 | .Select(a => a.Action.Value!.ToInt32())
46 | .ToArray();
47 | return actions;
48 | }
49 |
50 | public override string ToString()
51 | {
52 | var str = new StringBuilder();
53 | str.AppendLine($"Test By Agent: Get Reward {SumReward}");
54 | Steps.ForEach(s =>
55 | {
56 | var state = s.PreState.Value!.ToString(torch.numpy);
57 | var action = s.Action.Value!.ToInt32();
58 | var reward = s.Reward.Value;
59 | var line = $"{state},{action},{reward}";
60 | str.AppendLine(line);
61 | });
62 | return str.ToString();
63 | }
64 |
65 |
66 | ///
67 | /// Get an Episode in which each step's reward is replaced by its discounted return (discounted by Gamma)
68 | ///
69 | ///
70 | ///
71 | public Episode GetReturnEpisode(float gamma = 0.9f)
72 | {
73 | var stepsWithReturn = new List<Step>();
74 |
75 |
76 | var sumR = 0f;
77 | var steps = Steps;
78 | steps.Reverse();
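   | // Walk the steps backwards, accumulating G ← γ·G + r so each cloned step carries its return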
79 | foreach (var s in steps)
80 | {
81 | sumR *= gamma;
82 | sumR += s.Reward.Value;
83 |
84 | var sNew = (Step) s.Clone();
85 | sNew.Reward = new Reward(sumR);
86 | stepsWithReturn.Add(sNew);
87 | }
88 |
89 | steps.Reverse();
90 | stepsWithReturn.Reverse();
91 | var res = new Episode(stepsWithReturn);
92 | return res;
93 | }
94 | }
95 | }
--------------------------------------------------------------------------------
/src/DeepSharp.RL/Environs/FrozenLake/LakeRole.cs:
--------------------------------------------------------------------------------
1 | namespace DeepSharp.RL.Environs
2 | {
3 | public enum LakeRole
4 | {
5 | Ice,
6 | Hole,
7 | Start,
8 | End
9 | }
10 | }
--------------------------------------------------------------------------------
/src/DeepSharp.RL/Environs/FrozenLake/LakeUnit.cs:
--------------------------------------------------------------------------------
1 | namespace DeepSharp.RL.Environs
2 | {
3 | public class LakeUnit
4 | {
5 | ///
6 | ///
7 | ///
8 | ///
9 | public LakeUnit(int row, int column, int index)
10 | {
11 | Column = column;
12 | Row = row;
13 | Index = index;
14 | Role = LakeRole.Ice;
15 | }
16 |
17 | public int Index { set; get; }
18 | public int Row { set; get; }
19 | public int Column { set; get; }
20 | public LakeRole Role { set; get; }
21 |
22 | public override string ToString()
23 | {
24 | switch (Role)
25 | {
26 | case LakeRole.Ice:
27 | return "I";
28 | case LakeRole.Hole:
29 | return "H";
30 | case LakeRole.Start:
31 | return "S";
32 | case LakeRole.End:
33 | return "E";
34 | default:
35 | throw new ArgumentOutOfRangeException();
36 | }
37 | }
38 | }
39 | }
--------------------------------------------------------------------------------
/src/DeepSharp.RL/Environs/KArmedBandit/Bandit.cs:
--------------------------------------------------------------------------------
1 | using MathNet.Numerics.Random;
2 |
3 | namespace DeepSharp.RL.Environs
4 | {
5 | ///
6 | /// A simplified slot machine that pays out one coin with probability Prob
7 | ///
8 | public class Bandit
9 | {
10 | public Bandit(string name, double prob = 0.7)
11 | {
12 | Name = name;
13 | Prob = prob;
14 | RandomSource = new SystemRandomSource();
15 | }
16 |
17 | protected SystemRandomSource RandomSource { set; get; }
18 |
19 | public double Prob { set; get; }
20 |
21 | public string Name { get; set; }
22 |
23 |
24 | public float Step()
25 | {
26 | var pro = RandomSource.NextDouble();
27 | return pro <= Prob ? 1 : 0;
28 | }
29 |
30 | public override string ToString()
31 | {
32 | return $"Bandit{Name}:\t{Prob:P}";
33 | }
34 | }
35 | }
--------------------------------------------------------------------------------
/src/DeepSharp.RL/Environs/KArmedBandit/KArmedBandit.cs:
--------------------------------------------------------------------------------
1 | using System.Text;
2 | using DeepSharp.RL.Environs.Spaces;
3 | using MathNet.Numerics.Random;
4 |
5 | namespace DeepSharp.RL.Environs
6 | {
7 | ///
8 | /// Multi-armed bandit; each bandit's payout probability is generated randomly in [0, 1]
9 | ///
10 | public sealed class KArmedBandit : Environ<Space, Space>
11 | {
12 | public KArmedBandit(int k, DeviceType deviceType = DeviceType.CPU)
13 | : base("KArmedBandit", deviceType)
14 | {
15 | bandits = new Bandit[k];
16 | ActionSpace = new Disperse(k, deviceType: deviceType);
17 | ObservationSpace = new Box(0, 1, new long[] {k}, deviceType);
18 | Create(k);
19 | Reset();
20 | }
21 |
22 |
23 | public KArmedBandit(double[] probs, DeviceType deviceType = DeviceType.CPU)
24 | : base("KArmedBandit", deviceType)
25 | {
26 | var k = probs.Length;
27 | bandits = new Bandit[k];
28 | ActionSpace = new Disperse(k, deviceType: deviceType);
29 | ObservationSpace = new Box(0, 1, new long[] {k}, deviceType);
30 | Create(probs);
31 | Reset();
32 | }
33 |
34 | private Bandit[] bandits { get; }
35 | public Bandit this[int k] => bandits[k];
36 |
37 |
38 | private void Create(int k)
39 | {
40 | var random = new SystemRandomSource();
41 | foreach (var i in Enumerable.Range(0, k))
42 | bandits[i] = new Bandit($"{i}", random.NextDouble());
43 | }
44 |
45 | private void Create(double[] probs)
46 | {
47 | foreach (var i in Enumerable.Range(0, probs.Length))
48 | bandits[i] = new Bandit($"{i}", probs[i]);
49 | }
50 |
51 |
52 | ///
53 | /// In this environment the step reward is the number of coins obtained from the bandit; no conversion is needed
54 | ///
55 | ///
56 | ///
57 | public override Reward GetReward(Observation observation)
58 | {
59 | var sum = observation.Value!.to_type(torch.ScalarType.Float32)
60 | .sum()
61 | .item();
62 | var reward = new Reward(sum);
63 | return reward;
64 | }
65 |
66 | ///
67 | /// The cumulative reward received by the trajectory of an interaction process
68 | /// used for evaluation
69 | ///
70 | ///
71 | ///
72 | public override float GetReturn(Episode episode)
73 | {
74 | return episode.Steps.Average(a => a.Reward.Value);
75 | }
76 |
77 | ///
78 | ///
79 | /// The action; in this environment it contains the index of the bandit chosen by the agent
80 | /// Returns the number of coins (0 or 1) obtained from the chosen bandit for this step
81 | public override Observation Update(Act act)
82 | {
83 | var obs = new float[ObservationSpace!.N];
84 | var index = act.Value!.ToInt64();
85 | var bandit = bandits[index];
86 | var value = bandit.Step();
87 | obs[index] = value;
88 |
89 | var obsTensor = torch.from_array(obs, torch.ScalarType.Float32).to(Device);
90 | return new Observation(obsTensor);
91 | }
92 |
93 |
94 | ///
95 | /// The environment terminates once 20 steps have been sampled
96 | ///
97 | ///
98 | ///
99 | public override bool IsComplete(int epoch)
100 | {
101 | return epoch >= 20;
102 | }
103 |
104 |
105 | public override string ToString()
106 | {
107 | var str = new StringBuilder();
108 | str.AppendLine(base.ToString());
109 | str.Append(string.Join("\r\n", bandits.Select(a => $"\t{a}")));
110 | return str.ToString();
111 | }
112 | }
113 | }
--------------------------------------------------------------------------------
/src/DeepSharp.RL/Environs/Observation.cs:
--------------------------------------------------------------------------------
1 | namespace DeepSharp.RL.Environs
2 | {
3 | ///
4 | /// 观察
5 | ///
6 | public class Observation
7 | {
8 | public Observation(torch.Tensor? state)
9 | {
10 | Value = state;
11 | TimeStamp = DateTime.Now;
12 | }
13 |
14 | ///
15 | /// The observation in tensor format
16 | ///
17 | public torch.Tensor? Value { set; get; }
18 |
19 | ///
20 | /// Timestamp when the observation was generated
21 | ///
22 | public DateTime TimeStamp { set; get; }
23 |
24 | public Observation To(torch.Device device)
25 | {
26 | return new Observation(Value?.to(device));
27 | }
28 |
29 |
30 | public object Clone()
31 | {
32 | return new Observation(Value) {TimeStamp = TimeStamp};
33 | }
34 |
35 | public override string ToString()
36 | {
37 | return $"Observation\r\n{Value?.ToString(torch.numpy)}";
38 | }
39 | }
40 | }
--------------------------------------------------------------------------------
/src/DeepSharp.RL/Environs/Reward.cs:
--------------------------------------------------------------------------------
1 | namespace DeepSharp.RL.Environs
2 | {
3 | ///
4 | /// Reward
5 | ///
6 | public class Reward
7 | {
8 | public Reward(float value)
9 | {
10 | Value = value;
11 | TimeStamp = DateTime.Now;
12 | }
13 |
14 | ///
15 | /// reward
16 | ///
17 | public float Value { set; get; }
18 |
19 | ///
20 | /// TimeStamp of get reward
21 | ///
22 | public DateTime TimeStamp { set; get; }
23 |
24 | public override string ToString()
25 | {
26 | return $"Reward:\t{Value}";
27 | }
28 | }
29 | }
--------------------------------------------------------------------------------
/src/DeepSharp.RL/Environs/Space.cs:
--------------------------------------------------------------------------------
1 | using FluentAssertions;
2 | using MathNet.Numerics.Random;
3 |
4 | namespace DeepSharp.RL.Environs
5 | {
6 | ///
7 | /// Space: the parent class of action spaces and observation spaces
8 | ///
9 | public abstract class Space : IDisposable
10 | {
11 | protected Space(
12 | long[] shape,
13 | torch.ScalarType type,
14 | DeviceType deviceType,
15 | long seed)
16 | {
17 | (Shape, Type, DeviceType) = (shape, type, deviceType);
18 | CheckInitParameter(shape, type);
19 | CheckType();
20 | Generator = torch.random.manual_seed(new SystemRandomSource().NextInt64(0, 1000));
21 | N = shape.Aggregate(1, (a, b) => (int) (a * b));
22 | }
23 |
24 | public long N { get; init; }
25 | public long[] Shape { get; }
26 | public torch.ScalarType Type { get; }
27 | public DeviceType DeviceType { get; }
28 | internal torch.Generator Generator { get; }
29 | internal torch.Device Device => new(DeviceType);
30 |
31 | public void Dispose()
32 | {
33 | Generator.Dispose();
34 | }
35 |
36 | ///
37 | /// Returns a sample from the space.
38 | ///
39 | ///
40 | public abstract torch.Tensor Sample();
41 |
42 | public abstract void CheckType();
43 |
44 | ///
45 | /// Generates a tensor whose shape and type are consistent with the space definition.
46 | ///
47 | ///
48 | public virtual torch.Tensor Generate()
49 | {
50 | return torch.zeros(Shape, Type, Device);
51 | }
52 |
53 |
54 | public override string ToString()
55 | {
56 | return $"Space Type: {GetType().Name}\nShape: {Shape}\ndType: {Type} \nN:{N}";
57 | }
58 |
59 | private static void CheckInitParameter(long[] shape, torch.ScalarType type)
60 | {
61 | shape.Should().NotBeNull();
62 | shape.Length.Should().BeGreaterThan(0);
63 | }
64 | }
65 | }
--------------------------------------------------------------------------------
/src/DeepSharp.RL/Environs/Spaces/Binary.cs:
--------------------------------------------------------------------------------
1 | namespace DeepSharp.RL.Environs.Spaces
2 | {
3 | ///
4 | /// A space that only supports the values {0, 1}
5 | ///
6 | public class Binary : Disperse
7 | {
8 | public Binary(torch.ScalarType dtype = torch.ScalarType.Int32,
9 | DeviceType deviceType = DeviceType.CUDA, long seed = 1)
10 | : base(2, dtype, deviceType, seed)
11 | {
12 | N = 1;
13 | }
14 | }
15 | }
--------------------------------------------------------------------------------
/src/DeepSharp.RL/Environs/Spaces/Box.cs:
--------------------------------------------------------------------------------
1 | using FluentAssertions;
2 |
3 | namespace DeepSharp.RL.Environs.Spaces
4 | {
5 | ///
6 | /// A T-dimensional box that contains every point in the action space.
7 | ///
8 | public class Box : DigitalSpace
9 | {
10 | public Box(float low, float high, long[] shape, DeviceType deviceType = DeviceType.CUDA, long seed = 1) :
11 | base(torch.full(shape, low), torch.full(shape, high), shape, torch.ScalarType.Float32, deviceType, seed)
12 | {
13 | CoculateBounded();
14 | }
15 |
16 | public Box(double low, double high, long[] shape, DeviceType deviceType = DeviceType.CUDA, long seed = 1) :
17 | base(torch.full(shape, low), torch.full(shape, high), shape, torch.ScalarType.Float64, deviceType, seed)
18 | {
19 | CoculateBounded();
20 | }
21 |
22 | public Box(long low, long high, long[] shape, DeviceType deviceType = DeviceType.CUDA, long seed = 1) :
23 | base(torch.full(shape, low), torch.full(shape, high), shape, torch.ScalarType.Int64, deviceType, seed)
24 | {
25 | CoculateBounded();
26 | }
27 |
28 | public Box(int low, int high, long[] shape, DeviceType deviceType = DeviceType.CUDA, long seed = 1) :
29 | base(torch.full(shape, low, torch.ScalarType.Int32), torch.full(shape, high, torch.ScalarType.Int32), shape,
30 | torch.ScalarType.Int32, deviceType, seed)
31 | {
32 | CoculateBounded();
33 | }
34 |
35 | public Box(short low, short high, long[] shape, DeviceType deviceType = DeviceType.CUDA, long seed = 1) :
36 | base(torch.full(shape, low), torch.full(shape, high), shape, torch.ScalarType.Int16, deviceType, seed)
37 | {
38 | CoculateBounded();
39 | }
40 |
41 | public Box(byte low, byte high, long[] shape, DeviceType deviceType = DeviceType.CUDA, long seed = 1) :
42 | base(torch.full(shape, low), torch.full(shape, high), shape, torch.ScalarType.Byte, deviceType, seed)
43 | {
44 | CoculateBounded();
45 | }
46 |
47 | public Box(torch.Tensor low, torch.Tensor high, long[] shape, torch.ScalarType type,
48 | DeviceType deviceType = DeviceType.CUDA, long seed = 1) : base(low, high, shape, type, deviceType, seed)
49 | {
50 | CoculateBounded();
51 | }
52 |
53 | protected torch.Tensor BoundedBelow { get; private set; } = null!;
54 | protected torch.Tensor BoundedAbove { get; private set; } = null!;
55 |
56 | public override torch.Tensor Sample()
57 | {
58 | var unbounded = ~BoundedBelow & ~BoundedAbove;
59 | var uppBounded = ~BoundedBelow & BoundedAbove;
60 | var lowBounded = BoundedBelow & ~BoundedAbove;
61 | var bounded = BoundedBelow & BoundedAbove;
62 |
63 |
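   | // Draw each element from a distribution matching its bounds: Normal (unbounded), shifted Exponential (one-sided), Uniform (bounded)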
64 | var high = Type.ToString().StartsWith("F") ? High : High + 1;
65 | var sample = torch.empty(Shape, Type);
66 |
67 | sample[unbounded] = torch.distributions.Normal(torch.zeros(Shape, torch.ScalarType.Float32),
68 | torch.ones(Shape, torch.ScalarType.Float32))
69 | .sample(1).reshape(Shape)[unbounded].to_type(Type);
70 |
71 | sample[lowBounded] = (Low + torch.distributions.Exponential(torch.ones(Shape, torch.ScalarType.Float32))
72 | .sample(1)
73 | .reshape(Shape))[lowBounded].to_type(Type);
74 |
75 | sample[uppBounded] =
76 | (high - torch.distributions.Exponential(torch.ones(Shape, torch.ScalarType.Float32)).sample(1)
77 | .reshape(Shape))[uppBounded].to_type(Type);
78 |
79 | sample[bounded] =
80 | torch.distributions
81 | .Uniform(Low.to_type(torch.ScalarType.Float32), high.to_type(torch.ScalarType.Float32)).sample(1)
82 | .reshape(Shape)[bounded]
83 | .to_type(Type);
84 |
85 | return sample.to(Device);
86 | }
87 |
88 | public override void CheckType()
89 | {
90 | var acceptType = new[]
91 | {
92 | torch.ScalarType.Byte,
93 | torch.ScalarType.Int8,
94 | torch.ScalarType.Int16,
95 | torch.ScalarType.Int32,
96 | torch.ScalarType.Int64,
97 | torch.ScalarType.Float32,
98 | torch.ScalarType.Float64
99 | };
100 | Type.Should().BeOneOf(acceptType, $"Disperse accept Type in {string.Join(",", acceptType)}");
101 | }
102 |
103 | private void CoculateBounded()
104 | {
105 | Low.shape.Should().Equal(Shape);
106 | High.shape.Should().Equal(Shape);
107 | BoundedBelow = Low > torch.tensor(double.NegativeInfinity);
108 | BoundedAbove = High < torch.tensor(double.PositiveInfinity);
109 | }
110 | }
111 | }
--------------------------------------------------------------------------------
/src/DeepSharp.RL/Environs/Spaces/DigitalSpace.cs:
--------------------------------------------------------------------------------
1 | using FluentAssertions;
2 |
3 | namespace DeepSharp.RL.Environs.Spaces
4 | {
5 | public abstract class DigitalSpace : Space
6 | {
7 | protected DigitalSpace(
8 | torch.Tensor low,
9 | torch.Tensor high,
10 | long[] shape,
11 | torch.ScalarType type,
12 | DeviceType deviceType = DeviceType.CUDA,
13 | long seed = 471) : base(shape, type, deviceType, seed)
14 | {
15 | CheckParameters(low, high);
16 | Low = low;
17 | High = high;
18 | }
19 |
20 | protected DigitalSpace(
21 | long low,
22 | long high,
23 | long[] shape,
24 | torch.ScalarType type,
25 | DeviceType deviceType = DeviceType.CUDA,
26 | long seed = 471) : base(shape, type, deviceType, seed)
27 | {
28 | CheckParameters(low, high);
29 | Low = torch.full(shape, low, type);
30 | High = torch.full(shape, high, type);
31 | }
32 |
33 | public torch.Tensor Low { get; }
34 | public torch.Tensor High { get; }
35 |
36 |
37 | ///
38 | /// Generates a tensor whose shape and type are consistent with the space definition.
39 | ///
40 | ///
41 | public override torch.Tensor Generate()
42 | {
43 | return (torch.zeros(Shape, Type) + Low).to(Device);
44 | }
45 |
46 | private void CheckParameters(torch.Tensor low, torch.Tensor high)
47 | {
48 | low.Should().NotBeNull();
49 | high.Should().NotBeNull();
50 | torch.all(low < high).Equals(torch.tensor(true)).Should().Be(true);
51 | }
52 |
53 |
54 | public override void CheckType()
55 | {
56 | var acceptType = new[]
57 | {
58 | torch.ScalarType.Int8,
59 | torch.ScalarType.Int16,
60 | torch.ScalarType.Int32,
61 | torch.ScalarType.Int64
62 | };
63 | Type.Should().BeOneOf(acceptType, $"Disperse accept Type in {string.Join(",", acceptType)}");
64 | }
65 | }
66 | }
--------------------------------------------------------------------------------
/src/DeepSharp.RL/Environs/Spaces/Disperse.cs:
--------------------------------------------------------------------------------
1 | namespace DeepSharp.RL.Environs.Spaces
2 | {
3 | ///
4 | /// One-dimensional discrete action space; samples are encoded action indices
5 | /// A list of possible actions, where each timestep only one of the actions can be used.
6 | ///
7 | public class Disperse : DigitalSpace
8 | {
9 | ///
10 | /// Disperse Start with 0
11 | /// Discrete(2) will sample {0, 1}
12 | /// Discrete(3) will sample from {0, 1, 2}
13 | ///
14 | ///
15 | ///
16 | ///
17 | ///
18 | public Disperse(long length, torch.ScalarType dtype = torch.ScalarType.Int64,
19 | DeviceType deviceType = DeviceType.CUDA, long seed = 1)
20 | : base(0, 0 + length - 1, new long[] {1}, dtype, deviceType, seed)
21 | {
22 | N = length;
23 | }
24 |
25 | public Disperse(long length, long start, torch.ScalarType dtype = torch.ScalarType.Int64,
26 | DeviceType deviceType = DeviceType.CUDA, long seed = 1)
27 | : base(start, start + length - 1, new long[] {1}, dtype, deviceType, seed)
28 | {
29 | N = length;
30 | }
31 |
32 |
33 | public override torch.Tensor Sample()
34 | {
35 | var device = new torch.Device(DeviceType);
36 | var low = Low.to_type(torch.ScalarType.Int64).item();
37 | var high = (High + 1).to_type(torch.ScalarType.Int64).item();
38 |
39 | var sample = torch.randint(low, high, Shape, device: device).to_type(Type);
40 |
41 | return sample;
42 | }
43 | }
44 | }
--------------------------------------------------------------------------------
/src/DeepSharp.RL/Environs/Spaces/MultiBinary.cs:
--------------------------------------------------------------------------------
1 | namespace DeepSharp.RL.Environs.Spaces
2 | {
3 | ///
4 | /// A list of possible actions, where each timestep any of the actions can be used in any combination.
5 | ///
6 | public class MultiBinary : DigitalSpace
7 | {
8 | public MultiBinary(long[] shape, torch.ScalarType type = torch.ScalarType.Int32,
9 | DeviceType deviceType = DeviceType.CUDA, long seed = 471) : base(0, 1, shape, type, deviceType, seed)
10 | {
11 | }
12 |
13 | public MultiBinary(long shape, torch.ScalarType type = torch.ScalarType.Int32,
14 | DeviceType deviceType = DeviceType.CUDA, long seed = 471) : base(0, 1, new[] {shape}, type, deviceType,
15 | seed)
16 | {
17 | }
18 |
19 |
20 | public override torch.Tensor Sample()
21 | {
22 | var high = High + 1;
23 |
24 | var sample = torch.distributions.Uniform(Low.to_type(torch.ScalarType.Float32),
25 | high.to_type(torch.ScalarType.Float32), Generator)
26 | .sample(1)
27 | .reshape(Shape)
28 | .to_type(Type);
29 |
30 | return sample.to(Device);
31 | }
32 | }
33 | }
--------------------------------------------------------------------------------
/src/DeepSharp.RL/Environs/Spaces/MultiDisperse.cs:
--------------------------------------------------------------------------------
1 | namespace DeepSharp.RL.Environs.Spaces
2 | {
3 | ///
4 | /// A list of possible actions, where each timestep only one action of each discrete set can be used.
5 | ///
6 | public class MultiDisperse : DigitalSpace
7 | {
8 | public MultiDisperse(torch.Tensor low, torch.Tensor high, long[] shape, torch.ScalarType type,
9 | DeviceType deviceType = DeviceType.CUDA, long seed = 1)
10 | : base(low, high, shape, type, deviceType, seed)
11 | {
12 | }
13 |
14 | public MultiDisperse(long low, long high, long[] shape, torch.ScalarType type,
15 | DeviceType deviceType = DeviceType.CUDA, long seed = 1)
16 | : base(low, high, shape, type, deviceType, seed)
17 | {
18 | }
19 |
20 | public override torch.Tensor Sample()
21 | {
22 | var high = High + 1;
23 |
24 | var sample = torch.distributions.Uniform(Low.to_type(torch.ScalarType.Float32),
25 | high.to_type(torch.ScalarType.Float32), Generator)
26 | .sample(1)
27 | .reshape(Shape)
28 | .to_type(Type);
29 |
30 | return sample.to(Device);
31 | }
32 | }
33 | }
--------------------------------------------------------------------------------
/src/DeepSharp.RL/Environs/Step.cs:
--------------------------------------------------------------------------------
1 | namespace DeepSharp.RL.Environs
2 | {
3 | ///
4 | /// Step
5 | ///
6 | public class Step : ICloneable
7 | {
8 | public Step(Observation preState,
9 | Act action,
10 | Observation postState,
11 | Reward reward,
12 | bool isComplete = false,
13 | float priority = 1f)
14 | {
15 | PreState = preState;
16 | Action = action;
17 | Reward = reward;
18 | PostState = postState;
19 | IsComplete = isComplete;
20 | Priority = priority;
21 | }
22 |
23 | public Observation PreState { set; get; }
24 |
25 | ///
26 | /// Action
27 | ///
28 | public Act Action { set; get; }
29 |
30 | ///
31 | /// Observation after the action
32 | ///
33 | public Observation PostState { set; get; }
34 |
35 | ///
36 | /// 动作后的奖励
37 | ///
38 | public Reward Reward { set; get; }
39 |
40 | public bool IsComplete { set; get; }
41 |
42 | public float Priority { set; get; }
43 |
44 |
45 | public object Clone()
46 | {
47 | return new Step(PreState, Action, PostState, Reward, IsComplete, Priority);
48 | }
49 | }
50 | }
--------------------------------------------------------------------------------
/src/DeepSharp.RL/Environs/Wappers/EnvironWarpper.cs:
--------------------------------------------------------------------------------
1 | namespace DeepSharp.RL.Environs.Wappers
2 | {
3 | public abstract class EnvironWarpper
4 | {
5 | protected EnvironWarpper(Environ environ)
6 | {
7 | Environ = environ;
8 | }
9 |
10 | public Environ Environ { set; get; }
11 |
12 | public abstract Step Step(Act act, int epoch);
13 |
14 | public abstract Observation Reset();
15 | }
16 | }
--------------------------------------------------------------------------------
/src/DeepSharp.RL/Environs/Wappers/MaxAndSkipEnv.cs:
--------------------------------------------------------------------------------
1 | namespace DeepSharp.RL.Environs.Wappers
2 | {
3 | public abstract class MaxAndSkipEnv : EnvironWarpper
4 | {
5 | public int Skip { protected set; get; }
6 | public Queue<Observation> Observations { protected set; get; }
7 |
8 | protected MaxAndSkipEnv(Environ environ, int skip)
9 | : base(environ)
10 | {
11 | Skip = skip;
12 | Observations = new Queue<Observation>(2);
13 | }
14 |
15 |
16 | public override Step Step(Act act, int epoch)
17 | {
18 | var totalReward = 0f;
19 | var isComplete = false;
20 | var oldobs = Environ.Observation!;
21 | foreach (var _ in Enumerable.Range(0, Skip))
22 | {
23 | var step = Environ.Step(act, epoch);
24 | Observations.Enqueue(step.PostState);
25 | totalReward += step.Reward.Value;
26 | if (step.IsComplete)
27 | {
28 | isComplete = true;
29 | break;
30 | }
31 | }
32 |
33 | var obs = Observations.Select(a => a.Value!).ToList();
34 | var max = new Observation(torch.vstack(obs).max(0).values);
35 | var reward = new Reward(totalReward);
36 | return new Step(oldobs, act, max, reward, isComplete);
37 | }
38 |
39 | public override Observation Reset()
40 | {
41 | Observations.Clear();
42 | var obs = Environ.Reset();
43 | Observations.Enqueue(obs);
44 | return obs;
45 | }
46 | }
47 | }
48 |
--------------------------------------------------------------------------------
/src/DeepSharp.RL/ExpReplays/EpisodeExpReplay.cs:
--------------------------------------------------------------------------------
1 | using DeepSharp.RL.Environs;
2 | using DeepSharp.RL.ExperienceSources;
3 |
4 | namespace DeepSharp.RL.ExpReplays
5 | {
6 | /// <summary>
7 | /// Experience replay buffer that stores whole episodes
8 | /// </summary>
9 | public class EpisodeExpReplay
10 | {
11 | public EpisodeExpReplay(int capacity, float gamma)
12 | {
13 | Capacity = capacity;
14 | Gamma = gamma;
15 | Buffers = new Queue<Episode>(capacity);
16 | }
17 |
18 | /// <summary>
19 | /// Capacity of Experience Replay Buffer
20 | /// </summary>
21 | public int Capacity { protected set; get; }
22 |
23 | /// <summary>
24 | /// Discount factor used when converting an episode to its return form
25 | /// </summary>
26 | public float Gamma { protected set; get; }
27 |
28 | /// <summary>
29 | /// Cache
30 | /// </summary>
31 | public Queue<Episode> Buffers { set; get; }
32 |
33 | public int Size => Buffers.Sum(a => a.Length);
34 |
35 | public void Enqueue(Episode episode, bool isU = true)
36 | {
37 | if (Buffers.Count == Capacity) Buffers.Dequeue();
38 | if (isU)
39 | {
40 | var e = episode.GetReturnEpisode(Gamma);
41 | Buffers.Enqueue(e);
42 | }
43 | else
44 | {
45 | Buffers.Enqueue(episode);
46 | }
47 | }
48 |
49 | public virtual ExperienceCase All()
50 | {
51 | var episodes = Buffers;
52 | var batchStep = episodes.SelectMany(a => a.Steps).ToArray();
53 |
54 | // Get arrays from steps
55 | var stateArray = batchStep.Select(a => a.PreState.Value!.unsqueeze(0)).ToArray();
56 | var actArray = batchStep.Select(a => a.Action.Value!.unsqueeze(0)).ToArray();
57 | var rewardArray = batchStep.Select(a => a.Reward.Value).ToArray();
58 | var stateNextArray = batchStep.Select(a => a.PostState.Value!.unsqueeze(0)).ToArray();
59 | var doneArray = batchStep.Select(a => a.IsComplete).ToArray();
60 |
61 | // Convert to vstacked batch tensors
62 | var state = torch.vstack(stateArray);
63 | var actionV = torch.vstack(actArray).to(torch.ScalarType.Int64);
64 | var reward = torch.from_array(rewardArray).view(-1, 1);
65 | var stateNext = torch.vstack(stateNextArray);
66 | var done = torch.from_array(doneArray).reshape(Size);
67 |
68 | var excase = new ExperienceCase(state, actionV, reward, stateNext, done);
69 | return excase;
70 | }
71 |
72 | public void Clear()
73 | {
74 | Buffers.Clear();
75 | }
76 | }
77 | }
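
A usage sketch, assuming `episode` is an Episode collected by an agent rollout elsewhere in the library:

// Store episodes in their discounted-return form and read them back as one batch.
var replay = new EpisodeExpReplay(capacity: 100, gamma: 0.99f);
replay.Enqueue(episode);       // isU = true: stores episode.GetReturnEpisode(Gamma)
var batch = replay.All();      // ExperienceCase with vstacked state/action/reward tensors
replay.Clear();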
--------------------------------------------------------------------------------
/src/DeepSharp.RL/ExpReplays/ExpReplay.cs:
--------------------------------------------------------------------------------
1 | using DeepSharp.RL.Environs;
2 | using DeepSharp.RL.ExperienceSources;
3 |
4 | namespace DeepSharp.RL.ExpReplays
5 | {
6 | /// <summary>
7 | /// Experience replay buffer that stores individual steps
8 | /// </summary>
9 | public abstract class ExpReplay
10 | {
11 | protected ExpReplay(int capacity = 10000)
12 | {
13 | Capacity = capacity;
14 | Buffers = new Queue<Step>(capacity);
15 | }
16 |
17 | /// <summary>
18 | /// Capacity of Experience Replay Buffer
19 | /// </summary>
20 | public int Capacity { protected set; get; }
21 |
22 | /// <summary>
23 | /// Cache
24 | /// </summary>
25 | public Queue<Step> Buffers { set; get; }
26 |
27 | public int Size => Buffers.Count();
28 |
29 | /// <summary>
30 | /// Record a step [State, Action, Reward, NewState]
31 | /// </summary>
32 | /// <param name="step"></param>
33 | public virtual void Enqueue(Step step)
34 | {
35 | if (Buffers.Count == Capacity) Buffers.Dequeue();
36 | Buffers.Enqueue(step);
37 | }
38 |
39 | /// <summary>
40 | /// Record steps {[State, Action, Reward, NewState],...,[State, Action, Reward, NewState]}
41 | /// </summary>
42 | public void Enqueue(IEnumerable<Step> steps)
43 | {
44 | steps.ToList().ForEach(Enqueue);
45 | }
46 |
47 | protected abstract Step[] SampleSteps(int batchsize);
48 |
49 | public virtual void Enqueue(Episode episode)
50 | {
51 | Enqueue(episode.Steps);
52 | }
53 |
54 | public virtual ExperienceCase Sample(int batchsize)
55 | {
56 | var batchStep = SampleSteps(batchsize);
57 |
58 | // Get arrays from steps
59 | var stateArray = batchStep.Select(a => a.PreState.Value!.unsqueeze(0)).ToArray();
60 | var actArray = batchStep.Select(a => a.Action.Value!.unsqueeze(0)).ToArray();
61 | var rewardArray = batchStep.Select(a => a.Reward.Value).ToArray();
62 | var stateNextArray = batchStep.Select(a => a.PostState.Value!.unsqueeze(0)).ToArray();
63 | var doneArray = batchStep.Select(a => a.IsComplete).ToArray();
64 |
65 | // Convert to vstacked batch tensors
66 | var state = torch.vstack(stateArray);
67 | var actionV = torch.vstack(actArray).to(torch.ScalarType.Int64);
68 | var reward = torch.from_array(rewardArray).reshape(batchsize);
69 | var stateNext = torch.vstack(stateNextArray);
70 | var done = torch.from_array(doneArray).reshape(batchsize);
71 |
72 | var excase = new ExperienceCase(state, actionV, reward, stateNext, done);
73 | return excase;
74 | }
75 |
76 |
77 | public void Clear()
78 | {
79 | Buffers.Clear();
80 | }
81 | }
82 | }
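
A usage sketch of the batching above, assuming `step` comes from an environment interaction and using UniformExpReplay (a concrete subclass shown later in this listing):

// Fill the buffer step by step, then pull a training batch once enough steps exist.
ExpReplay replay = new UniformExpReplay(capacity: 10000);
replay.Enqueue(step);
if (replay.Size >= 32)
{
    var exCase = replay.Sample(32);
    var states = exCase.PreState;    // [32, stateDim]
    var actions = exCase.Action;     // [32, actionDim], cast to Int64
    var rewards = exCase.Reward;     // [32]
    var dones = exCase.Done;         // [32]
}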
--------------------------------------------------------------------------------
/src/DeepSharp.RL/ExpReplays/ExperienceCase.cs:
--------------------------------------------------------------------------------
1 | namespace DeepSharp.RL.ExperienceSources
2 | {
3 | /// <summary>A batch of transitions stored as stacked tensors.</summary>
4 | /// <remarks>Produced by ExpReplay.Sample and EpisodeExpReplay.All.</remarks>
5 | public struct ExperienceCase
6 | {
7 | public ExperienceCase(torch.Tensor preState,
8 | torch.Tensor action,
9 | torch.Tensor reward,
10 | torch.Tensor postState,
11 | torch.Tensor done)
12 | {
13 | PreState = preState;
14 | Action = action;
15 | Reward = reward;
16 | PostState = postState;
17 | Done = done;
18 | }
19 |
20 | /// <summary>
21 | /// State before action
22 | /// </summary>
23 | public torch.Tensor PreState { get; set; }
24 |
25 | /// <summary>
26 | /// Action
27 | /// </summary>
28 | public torch.Tensor Action { set; get; }
29 |
30 | /// <summary>
31 | /// Reward
32 | /// </summary>
33 | public torch.Tensor Reward { set; get; }
34 |
35 | /// <summary>
36 | /// State after action
37 | /// </summary>
38 | public torch.Tensor PostState { set; get; }
39 |
40 | /// <summary>
41 | /// Whether the episode is complete
42 | /// </summary>
43 | public torch.Tensor Done { set; get; }
44 | }
45 | }
--------------------------------------------------------------------------------
/src/DeepSharp.RL/ExpReplays/PrioritizedExpReplay.cs:
--------------------------------------------------------------------------------
1 | using DeepSharp.RL.Environs;
2 |
3 | namespace DeepSharp.RL.ExpReplays
4 | {
5 | /// <summary>
6 | /// Prioritized sampling from the experience replay cache
7 | /// </summary>
8 | public class PrioritizedExpReplay : ExpReplay
9 | {
10 | /// <summary>
11 | /// </summary>
12 | /// <param name="capacity">Capacity of Experience Replay Buffer, recommend 10^5 ~ 10^6</param>
13 | public PrioritizedExpReplay(int capacity = 10000)
14 | : base(capacity)
15 | {
16 | }
17 |
18 |
19 | /// <summary>
20 | /// Sample batchsize steps from the queue, weighted by each step's Priority
21 | /// </summary>
22 | /// <param name="batchsize">batch size</param>
23 | /// <returns></returns>
24 | protected override Step[] SampleSteps(int batchsize)
25 | {
26 | var probs = torch.from_array(Buffers.Select(a => a.Priority).ToArray());
27 | var randomIndex = torch.multinomial(probs, batchsize).data<long>().ToArray();
28 |
29 | var steps = randomIndex
30 | .AsParallel()
31 | .Select(i => Buffers.ElementAt((int) i))
32 | .ToArray();
33 |
34 | return steps;
35 | }
36 | }
37 | }
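
A short sketch of how Priority biases sampling; the two Step variables are assumed to be built elsewhere, e.g. with priorities proportional to TD error:

// Priorities act as unnormalized weights for torch.multinomial:
// a step with Priority 4f is favoured over one with the default 1f.
var replay = new PrioritizedExpReplay();
replay.Enqueue(highErrorStep);   // e.g. Priority = 4f
replay.Enqueue(ordinaryStep);    // Priority = 1f
var batch = replay.Sample(2);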
--------------------------------------------------------------------------------
/src/DeepSharp.RL/ExpReplays/UniformExpReplay.cs:
--------------------------------------------------------------------------------
1 | using DeepSharp.RL.Environs;
2 |
3 | namespace DeepSharp.RL.ExpReplays
4 | {
5 | /// <summary>
6 | /// Uniform sample from Experience Source Cache
7 | /// </summary>
8 | public class UniformExpReplay : ExpReplay
9 | {
10 | /// <summary>
11 | /// </summary>
12 | /// <param name="capacity">Capacity of Experience Replay Buffer, recommend 10^5 ~ 10^6</param>
13 | public UniformExpReplay(int capacity = 10000)
14 | : base(capacity)
15 | {
16 | }
17 |
18 |
19 | /// <summary>
20 | /// Uniform sample batch size steps from Queue
21 | /// </summary>
22 | /// <param name="batchsize">batch size</param>
23 | /// <returns></returns>
24 | protected override Step[] SampleSteps(int batchsize)
25 | {
26 | var randomIndex = torch.randint(0, Size, new[] {batchsize}).data<long>().ToArray();
27 |
28 | var steps = randomIndex
29 | .Select(i => Buffers.ElementAt((int) i))
30 | .ToArray();
31 |
32 | return steps;
33 | }
34 | }
35 | }
--------------------------------------------------------------------------------
/src/DeepSharp.RL/Trainers/RLTrainOption.cs:
--------------------------------------------------------------------------------
1 | namespace DeepSharp.RL.Trainers
2 | {
3 | public struct RLTrainOption
4 | {
5 | public RLTrainOption()
6 | {
7 | TrainEpoch = 0;
8 | StopReward = 0;
9 | ValEpisode = 0;
10 | ValInterval = 0;
11 | SaveFolder = string.Empty;
12 | OutTimeSpan = TimeSpan.FromHours(1);
13 | AutoSave = false;
14 | }
15 |
16 | public float StopReward { set; get; }
17 | public int TrainEpoch { set; get; }
18 | public int ValInterval { set; get; }
19 | public int ValEpisode { set; get; }
20 | public string SaveFolder { set; get; }
21 | public TimeSpan OutTimeSpan { set; get; }
22 | public bool AutoSave { set; get; }
23 | }
24 | }
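
A configuration sketch for this option struct (all values below are illustrative, including the save folder path):

var option = new RLTrainOption
{
    TrainEpoch = 1000,
    StopReward = 18f,
    ValInterval = 10,
    ValEpisode = 20,
    SaveFolder = "./checkpoints",            // hypothetical path
    OutTimeSpan = TimeSpan.FromMinutes(30),
    AutoSave = true
};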
--------------------------------------------------------------------------------
/src/DeepSharp.RL/Trainers/RLTrainer.cs:
--------------------------------------------------------------------------------
1 | using DeepSharp.RL.Agents;
2 | using DeepSharp.RL.Environs;
3 |
4 | namespace DeepSharp.RL.Trainers
5 | {
6 | public class RLTrainer
7 | {
8 | private TrainerCallBack? callback;
9 |
10 | public RLTrainer(Agent agent)
11 | {
12 | Agent = agent;
13 | }
14 |
15 | public RLTrainer(Agent agent, Action