├── .gitignore
├── LICENSE
├── README.md
├── images
│   ├── DQN.png
│   ├── DQN2013.png
│   ├── DQN2015.png
│   ├── DQN2015_.png
│   ├── FrozenLake.png
│   ├── MCOnPolicy.png
│   ├── QLearning.png
│   ├── RL CrossEntroy Demo.png
│   ├── Reinforcement.png
│   └── Sarsa.png
├── resources
│   ├── Mindmap.emmx
│   ├── iris-test.txt
│   └── iris-train.txt
└── src
    ├── DeepSharp.Core
    │   ├── DeepSharp.Core.csproj
    │   └── Trainer.cs
    ├── DeepSharp.Dataset
    │   ├── DataLoaders
    │   │   ├── DataLoader.cs
    │   │   ├── DataLoaderConfig.cs
    │   │   └── InfiniteDataLoader.cs
    │   ├── Datasets
    │   │   ├── DataView.cs
    │   │   ├── DataViewPair.cs
    │   │   ├── Dataset.cs
    │   │   ├── StreamHeaderAttribute.cs
    │   │   └── StreamHeaderRange.cs
    │   ├── DeepSharp.Dataset.csproj
    │   ├── DeepSharp.Dataset.csproj.DotSettings
    │   └── Usings.cs
    ├── DeepSharp.ObjectDetection
    │   ├── DeepSharp.ObjectDetection.csproj
    │   └── Yolo.cs
    ├── DeepSharp.RL
    │   ├── ActionSelectors
    │   │   ├── ActionSelector.cs
    │   │   ├── ArgmaxActionSelector.cs
    │   │   ├── EpsilonActionSelector.cs
    │   │   └── ProbActionSelector.cs
    │   ├── Agents
    │   │   ├── Agent.cs
    │   │   ├── LearnOutCome.cs
    │   │   ├── Models
    │   │   │   ├── QTable.cs
    │   │   │   ├── RewardKey.cs
    │   │   │   ├── TrasitKey.cs
    │   │   │   └── VTable.cs
    │   │   ├── Nets
    │   │   │   ├── DQNNet.cs
    │   │   │   ├── Net.cs
    │   │   │   └── PGN.cs
    │   │   ├── Others
    │   │   │   ├── CrossEntropy.cs
    │   │   │   └── CrossEntropyExt.cs
    │   │   ├── PolicyAgengt.cs
    │   │   ├── PolicyBased
    │   │   │   ├── A2C.cs
    │   │   │   ├── A3C.cs
    │   │   │   ├── ActorCritic.cs
    │   │   │   ├── Reinforce.cs
    │   │   │   └── ReinforceOriginal.cs
    │   │   ├── ValueAgent.cs
    │   │   └── ValueBased
    │   │       ├── DeepQN
    │   │       │   ├── CategoricalDQN.cs
    │   │       │   ├── DQN.cs
    │   │       │   ├── DoubleDQN.cs
    │   │       │   ├── DuelingDQN.cs
    │   │       │   ├── NDQN.cs
    │   │       │   └── NoisyDQN.cs
    │   │       ├── DynamicPlan
    │   │       │   ├── PIDiscountR.cs
    │   │       │   ├── PITStepR.cs
    │   │       │   ├── PolicyIteration.cs
    │   │       │   ├── VIDiscountR.cs
    │   │       │   ├── VITStepR.cs
    │   │       │   └── ValueIterate.cs
    │   │       ├── MonteCarlo
    │   │       │   ├── MonteCarloOffPolicy.cs
    │   │       │   └── MonteCarloOnPolicy.cs
    │   │       ├── TemporalDifference
    │   │       │   ├── QLearning.cs
    │   │       │   └── SARSA.cs
    │   │       └── ValueIteration.cs
    │   ├── DeepSharp.RL.csproj
    │   ├── DeepSharp.RL.csproj.DotSettings
    │   ├── Enumerates
    │   │   └── PlayMode.cs
    │   ├── Environs
    │   │   ├── Act.cs
    │   │   ├── Environ.cs
    │   │   ├── Episode.cs
    │   │   ├── FrozenLake
    │   │   │   ├── FrozenLake.cs
    │   │   │   ├── LakeRole.cs
    │   │   │   └── LakeUnit.cs
    │   │   ├── KArmedBandit
    │   │   │   ├── Bandit.cs
    │   │   │   └── KArmedBandit.cs
    │   │   ├── Observation.cs
    │   │   ├── Reward.cs
    │   │   ├── Space.cs
    │   │   ├── Spaces
    │   │   │   ├── Binary.cs
    │   │   │   ├── Box.cs
    │   │   │   ├── DigitalSpace.cs
    │   │   │   ├── Disperse.cs
    │   │   │   ├── MultiBinary.cs
    │   │   │   └── MultiDisperse.cs
    │   │   ├── Step.cs
    │   │   └── Wappers
    │   │       ├── EnvironWarpper.cs
    │   │       └── MaxAndSkipEnv.cs
    │   ├── ExpReplays
    │   │   ├── EpisodeExpReplay.cs
    │   │   ├── ExpReplay.cs
    │   │   ├── ExperienceCase.cs
    │   │   ├── PrioritizedExpReplay.cs
    │   │   └── UniformExpReplay.cs
    │   ├── Trainers
    │   │   ├── RLTrainOption.cs
    │   │   ├── RLTrainer.cs
    │   │   └── TrainerCallBack.cs
    │   └── Usings.cs
    ├── DeepSharp.Utility
    │   ├── Converters
    │   │   └── Convert.cs
    │   ├── DeepSharp.Utility.csproj
    │   ├── Operations
    │   │   └── OpMat.cs
    │   ├── TensorEqualityCompare.cs
    │   └── Usings.cs
    ├── DeepSharp.sln
    ├── DeepSharp.sln.DotSettings
    ├── RLConsole
    │   ├── Program.cs
    │   ├── RLConsole.csproj
    │   └── Utility.cs
    ├── Reinforcement Learning.md
    └── TorchSharpTest
        ├── AbstractTest.cs
        ├── DemoTest
        │   ├── DemoNet.cs
        │   └── IrisTest.cs
        ├── LossTest
        │   └── LossTest.cs
        ├── RLTest
        │   ├── AgentTest.cs
        │   ├── EnvironTest
        │   │   ├── FrozenLakeTest.cs
        │   │   └── KArmedBanditTest.cs
        │   ├── ModelTest
        │   │   ├── ActionSelectorTest.cs
        │   │   ├── BasicTest.cs
        │   │   ├── PolicyNetTest.cs
        │   │   ├── QTableTest.cs
        │   │   ├── SpaceTest.cs
        │   │   └── VTableTest.cs
        │   ├── PolicyBasedTest
        │   │   ├── ActorCriticTest.cs
        │   │   └── ReinforceTest.cs
        │   ├── TrainerTest
        │   │   └── RLTrainTest.cs
        │   └── ValueBasedTest
        │       ├── DQNTest.cs
        │       ├── MonteCarloTest.cs
        │       ├── QLearningTest.cs
        │       └── SARSATest.cs
        ├── SampleDataset
        │   ├── Iris.cs
        │   └── IrisOneHot.cs
        ├── TorchSharpTest.csproj
        ├── TorchSharpTest.csproj.DotSettings
        ├── TorchTests
        │   ├── DataSetTest.cs
        │   ├── ModuleTest.cs
        │   ├── SaveLoadTest.cs
        │   └── TensorTest.cs
        └── Usings.cs

/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2023 Xin.Pu

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

--------------------------------------------------------------------------------
/images/DQN.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xin-pu/DeepSharp/aff6662bd8bb3a30af0340efc3848aa30da8c65a/images/DQN.png

--------------------------------------------------------------------------------
/images/DQN2013.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xin-pu/DeepSharp/aff6662bd8bb3a30af0340efc3848aa30da8c65a/images/DQN2013.png

--------------------------------------------------------------------------------
/images/DQN2015.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xin-pu/DeepSharp/aff6662bd8bb3a30af0340efc3848aa30da8c65a/images/DQN2015.png

--------------------------------------------------------------------------------
/images/DQN2015_.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xin-pu/DeepSharp/aff6662bd8bb3a30af0340efc3848aa30da8c65a/images/DQN2015_.png

--------------------------------------------------------------------------------
/images/FrozenLake.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xin-pu/DeepSharp/aff6662bd8bb3a30af0340efc3848aa30da8c65a/images/FrozenLake.png

--------------------------------------------------------------------------------
/images/MCOnPolicy.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xin-pu/DeepSharp/aff6662bd8bb3a30af0340efc3848aa30da8c65a/images/MCOnPolicy.png

--------------------------------------------------------------------------------
/images/QLearning.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xin-pu/DeepSharp/aff6662bd8bb3a30af0340efc3848aa30da8c65a/images/QLearning.png
--------------------------------------------------------------------------------
/images/RL CrossEntroy Demo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xin-pu/DeepSharp/aff6662bd8bb3a30af0340efc3848aa30da8c65a/images/RL CrossEntroy Demo.png

--------------------------------------------------------------------------------
/images/Reinforcement.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xin-pu/DeepSharp/aff6662bd8bb3a30af0340efc3848aa30da8c65a/images/Reinforcement.png

--------------------------------------------------------------------------------
/images/Sarsa.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xin-pu/DeepSharp/aff6662bd8bb3a30af0340efc3848aa30da8c65a/images/Sarsa.png

--------------------------------------------------------------------------------
/resources/Mindmap.emmx:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xin-pu/DeepSharp/aff6662bd8bb3a30af0340efc3848aa30da8c65a/resources/Mindmap.emmx

--------------------------------------------------------------------------------
/resources/iris-test.txt:
--------------------------------------------------------------------------------
#Label Sepal length Sepal width Petal length Petal width
0 5.1 3.5 1.4 0.2
0 4.9 3.0 1.4 0.2
0 4.7 3.2 1.3 0.2
0 4.6 3.1 1.5 0.2
0 5.0 3.6 1.4 0.2
0 5.4 3.9 1.7 0.4
0 4.6 3.4 1.4 0.3
0 5.0 3.4 1.5 0.2
0 4.4 2.9 1.4 0.2
0 4.9 3.1 1.5 0.1
1 7.0 3.2 4.7 1.4
1 6.4 3.2 4.5 1.5
1 6.9 3.1 4.9 1.5
1 5.5 2.3 4.0 1.3
1 6.5 2.8 4.6 1.5
1 5.7 2.8 4.5 1.3
1 6.3 3.3 4.7 1.6
1 4.9 2.4 3.3 1.0
1 6.6 2.9 4.6 1.3
1 5.2 2.7 3.9 1.4
2 6.3 3.3 6.0 2.5
2 5.8 2.7 5.1 1.9
2 7.1 3.0 5.9 2.1
2 6.3 2.9 5.6 1.8
2 6.5 3.0 5.8 2.2
2 7.6 3.0 6.6 2.1
2 4.9 2.5 4.5 1.7
2 7.3 2.9 6.3 1.8
2 6.7 2.5 5.8 1.8
2 7.2 3.6 6.1 2.5

--------------------------------------------------------------------------------
/resources/iris-train.txt:
--------------------------------------------------------------------------------
#Label Sepal length Sepal width Petal length Petal width
0 5.4 3.7 1.5 0.2
0 4.8 3.4 1.6 0.2
0 4.8 3.0 1.4 0.1
0 4.3 3.0 1.1 0.1
0 5.8 4.0 1.2 0.2
0 5.7 4.4 1.5 0.4
0 5.4 3.9 1.3 0.4
0 5.1 3.5 1.4 0.3
0 5.7 3.8 1.7 0.3
0 5.1 3.8 1.5 0.3
0 5.4 3.4 1.7 0.2
0 5.1 3.7 1.5 0.4
0 4.6 3.6 1.0 0.2
0 5.1 3.3 1.7 0.5
0 4.8 3.4 1.9 0.2
0 5.0 3.0 1.6 0.2
0 5.0 3.4 1.6 0.4
0 5.2 3.5 1.5 0.2
0 5.2 3.4 1.4 0.2
0 4.7 3.2 1.6 0.2
0 4.8 3.1 1.6 0.2
0 5.4 3.4 1.5 0.4
0 5.2 4.1 1.5 0.1
0 5.5 4.2 1.4 0.2
0 4.9 3.1 1.5 0.1
0 5.0 3.2 1.2 0.2
0 5.5 3.5 1.3 0.2
0 4.9 3.1 1.5 0.1
0 4.4 3.0 1.3 0.2
0 5.1 3.4 1.5 0.2
0 5.0 3.5 1.3 0.3
0 4.5 2.3 1.3 0.3
0 4.4 3.2 1.3 0.2
0 5.0 3.5 1.6 0.6
0 5.1 3.8 1.9 0.4
0 4.8 3.0 1.4 0.3
0 5.1 3.8 1.6 0.2
0 4.6 3.2 1.4 0.2
0 5.3 3.7 1.5 0.2
0 5.0 3.3 1.4 0.2
1 5.0 2.0 3.5 1.0
1 5.9 3.0 4.2 1.5
1 6.0 2.2 4.0 1.0
1 6.1 2.9 4.7 1.4
1 5.6 2.9 3.6 1.3
1 6.7 3.1 4.4 1.4
1 5.6 3.0 4.5 1.5
1 5.8 2.7 4.1 1.0
1 6.2 2.2 4.5 1.5
1 5.6 2.5 3.9 1.1
1 5.9 3.2 4.8 1.8
1 6.1 2.8 4.0 1.3
1 6.3 2.5 4.9 1.5
1 6.1 2.8 4.7 1.2
1 6.4 2.9 4.3 1.3
1 6.6 3.0 4.4 1.4
1 6.8 2.8 4.8 1.4
1 6.7 3.0 5.0 1.7
1 6.0 2.9 4.5 1.5
1 5.7 2.6 3.5 1.0
1 5.5 2.4 3.8 1.1
1 5.5 2.4 3.7 1.0
1 5.8 2.7 3.9 1.2
1 6.0 2.7 5.1 1.6
1 5.4 3.0 4.5 1.5
1 6.0 3.4 4.5 1.6
1 6.7 3.1 4.7 1.5
1 6.3 2.3 4.4 1.3
1 5.6 3.0 4.1 1.3
1 5.5 2.5 4.0 1.3
1 5.5 2.6 4.4 1.2
1 6.1 3.0 4.6 1.4
1 5.8 2.6 4.0 1.2
1 5.0 2.3 3.3 1.0
1 5.6 2.7 4.2 1.3
1 5.7 3.0 4.2 1.2
1 5.7 2.9 4.2 1.3
1 6.2 2.9 4.3 1.3
1 5.1 2.5 3.0 1.1
1 5.7 2.8 4.1 1.3
2 6.5 3.2 5.1 2.0
2 6.4 2.7 5.3 1.9
2 6.8 3.0 5.5 2.1
2 5.7 2.5 5.0 2.0
2 5.8 2.8 5.1 2.4
2 6.4 3.2 5.3 2.3
2 6.5 3.0 5.5 1.8
2 7.7 3.8 6.7 2.2
2 7.7 2.6 6.9 2.3
2 6.0 2.2 5.0 1.5
2 6.9 3.2 5.7 2.3
2 5.6 2.8 4.9 2.0
2 7.7 2.8 6.7 2.0
2 6.3 2.7 4.9 1.8
2 6.7 3.3 5.7 2.1
2 7.2 3.2 6.0 1.8
2 6.2 2.8 4.8 1.8
2 6.1 3.0 4.9 1.8
2 6.4 2.8 5.6 2.1
2 7.2 3.0 5.8 1.6
2 7.4 2.8 6.1 1.9
2 7.9 3.8 6.4 2.0
2 6.4 2.8 5.6 2.2
2 6.3 2.8 5.1 1.5
2 6.1 2.6 5.6 1.4
2 7.7 3.0 6.1 2.3
2 6.3 3.4 5.6 2.4
2 6.4 3.1 5.5 1.8
2 6.0 3.0 4.8 1.8
2 6.9 3.1 5.4 2.1
2 6.7 3.1 5.6 2.4
2 6.9 3.1 5.1 2.3
2 5.8 2.7 5.1 1.9
2 6.8 3.2 5.9 2.3
2 6.7 3.3 5.7 2.5
2 6.7 3.0 5.2 2.3
2 6.3 2.5 5.0 1.9
2 6.5 3.0 5.2 2.0
2 6.2 3.4 5.4 2.3
2 5.9 3.0 5.1 1.8

--------------------------------------------------------------------------------
/src/DeepSharp.Core/DeepSharp.Core.csproj:
--------------------------------------------------------------------------------
<Project Sdk="Microsoft.NET.Sdk">

  <PropertyGroup>
    <TargetFramework>net7.0</TargetFramework>
    <ImplicitUsings>enable</ImplicitUsings>
    <Nullable>enable</Nullable>
  </PropertyGroup>

</Project>

--------------------------------------------------------------------------------
/src/DeepSharp.Core/Trainer.cs:
--------------------------------------------------------------------------------
namespace DeepSharp.Core
{
    /// <summary>
    ///     Trainer to simplify deep learning training.
    /// </summary>
    public abstract class Trainer
    {
    }
}

--------------------------------------------------------------------------------
/src/DeepSharp.Dataset/DataLoaders/DataLoader.cs:
--------------------------------------------------------------------------------
namespace DeepSharp.Dataset
{
    /// <summary>
    ///     Basic DataLoader inheriting from torch.utils.data.DataLoader.
    /// </summary>
    /// <typeparam name="T">the DataView type to load</typeparam>
    public class DataLoader<T> : torch.utils.data.DataLoader<T, DataViewPair>
        where T : DataView
    {
        public DataLoader(Dataset<T> dataset, DataLoaderConfig config)
            : base(dataset, config.BatchSize, CollateFunc, config.Shuffle, config.Device, config.Seed, config.NumWorker,
                config.DropLast)
        {
        }

        public static DataViewPair CollateFunc(IEnumerable<T> dataViews, torch.Device device)
        {
            var views = dataViews.ToList();
            var features = views.Select(a => a.GetFeatures()).ToList();
            var labels = views.Select(a => a.GetLabels()).ToList();
            var result = new DataViewPair(labels, features).To(device);
            return result;
        }

        public async IAsyncEnumerable<DataViewPair> GetBatchSample()
        {
            using var enumerator = GetEnumerator();
            while (enumerator.MoveNext()) yield return enumerator.Current;
            await Task.CompletedTask;
        }
    }
}
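A minimal consumption sketch for the loader above. IrisData is a hypothetical DataView subclass (the repo's own sample lives in TorchSharpTest/SampleDataset); the file path and batch size are illustrative:

// Hypothetical usage sketch; IrisData is an assumed DataView implementation.
var dataset = new Dataset<IrisData>("resources/iris-train.txt");
var config = new DataLoaderConfig {BatchSize = 8, Device = new torch.Device(DeviceType.CPU)};
var loader = new DataLoader<IrisData>(dataset, config);

await foreach (var batch in loader.GetBatchSample())
{
    // With the iris files: batch.Features is [8, 4], batch.Labels is [8, 1].
    Console.WriteLine(batch);
}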
--------------------------------------------------------------------------------
/src/DeepSharp.Dataset/DataLoaders/DataLoaderConfig.cs:
--------------------------------------------------------------------------------
namespace DeepSharp.Dataset
{
    public struct DataLoaderConfig
    {
        public DataLoaderConfig()
        {
            Seed = null;
        }

        public int BatchSize { set; get; } = 4;
        public bool Shuffle { set; get; } = true;
        public bool DropLast { set; get; } = true;
        public int NumWorker { set; get; } = 1;
        public int? Seed { set; get; }
        public torch.Device Device { set; get; } = new(DeviceType.CUDA);
    }
}

--------------------------------------------------------------------------------
/src/DeepSharp.Dataset/DataLoaders/InfiniteDataLoader.cs:
--------------------------------------------------------------------------------
namespace DeepSharp.Dataset
{
    /// <summary>
    ///     Infinite DataLoader inheriting from DataLoader; cycles the dataset when exhausted.
    /// </summary>
    /// <typeparam name="T">the DataView type to load</typeparam>
    public class InfiniteDataLoader<T> : DataLoader<T>
        where T : DataView

    {
        public InfiniteDataLoader(Dataset<T> dataset, DataLoaderConfig config) : base(dataset, config)
        {
            IEnumerator = GetEnumerator();
        }

        protected IEnumerator<DataViewPair> IEnumerator { set; get; }

        public async IAsyncEnumerable<DataViewPair> GetBatchSample(int sample)
        {
            var i = 0;
            while (i++ < sample)
                if (IEnumerator.MoveNext())
                {
                    yield return IEnumerator.Current;
                }
                else
                {
                    // Restart from the beginning once the underlying loader is exhausted.
                    IEnumerator.Reset();
                    IEnumerator.MoveNext();
                    yield return IEnumerator.Current;
                }

            await Task.CompletedTask;
        }
    }
}
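InfiniteDataLoader differs only in that GetBatchSample(n) yields a fixed number of batches, wrapping epochs transparently. A sketch under the same assumptions as the previous example:

var infinite = new InfiniteDataLoader<IrisData>(dataset, config);
await foreach (var batch in infinite.GetBatchSample(1000))
{
    // Exactly 1000 batches regardless of dataset length; the enumerator
    // is reset whenever the dataset runs out.
}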
--------------------------------------------------------------------------------
/src/DeepSharp.Dataset/Datasets/DataView.cs:
--------------------------------------------------------------------------------
namespace DeepSharp.Dataset
{
    public abstract class DataView
    {
        /// <summary>
        ///     Get the features in tensor format.
        /// </summary>
        public abstract torch.Tensor GetFeatures();

        /// <summary>
        ///     Get the labels in tensor format.
        /// </summary>
        public abstract torch.Tensor GetLabels();


        /// <summary>
        ///     Convert a batch of DataViews into a single DataViewPair.
        /// </summary>
        /// <param name="datasetViews">the views to combine</param>
        /// <param name="device">cpu or cuda</param>
        public static DataViewPair FromDataViews(IEnumerable<DataView> datasetViews, torch.Device device)
        {
            var views = datasetViews.ToList();
            var features = views.Select(a => a.GetFeatures()).ToList();
            var labels = views.Select(a => a.GetLabels()).ToList();
            var result = new DataViewPair(labels, features).To(device);
            return result;
        }
    }
}

--------------------------------------------------------------------------------
/src/DeepSharp.Dataset/Datasets/DataViewPair.cs:
--------------------------------------------------------------------------------
namespace DeepSharp.Dataset
{
    public class DataViewPair
    {
        public DataViewPair(torch.Tensor labels, torch.Tensor features)
        {
            Labels = labels;
            Features = features;
        }

        public DataViewPair(IEnumerable<torch.Tensor> labels, IEnumerable<torch.Tensor> features)
        {
            var labelsArray = labels.ToArray();
            var featuresArray = features.ToArray();
            Labels = torch.vstack(labelsArray);
            Features = torch.vstack(featuresArray);
        }


        public torch.Tensor Labels { set; get; }
        public torch.Tensor Features { set; get; }


        internal DataViewPair To(torch.Device device)
        {
            var res = new DataViewPair(Labels.to(device), Features.to(device));
            return res;
        }

        /// <summary>
        ///     Send DataViewPair to CPU device
        /// </summary>
        public DataViewPair cpu()
        {
            return To(new torch.Device(DeviceType.CPU));
        }

        /// <summary>
        ///     Send DataViewPair to CUDA device
        /// </summary>
        public DataViewPair cuda()
        {
            return To(new torch.Device(DeviceType.CUDA));
        }

        public override string ToString()
        {
            var strbuild = new StringBuilder();
            strbuild.AppendLine($"{Labels}");
            strbuild.AppendLine($"{Features}");
            return strbuild.ToString();
        }
    }
}
--------------------------------------------------------------------------------
/src/DeepSharp.Dataset/Datasets/Dataset.cs:
--------------------------------------------------------------------------------
using System.Reflection;

namespace DeepSharp.Dataset
{
    /// <summary>
    ///     Object-oriented dataset loader inheriting from torch.utils.data.Dataset.
    /// </summary>
    /// <typeparam name="T">the DataView type to load</typeparam>
    public class Dataset<T> : torch.utils.data.Dataset<T>
        where T : DataView
    {
        protected T[] AllData { set; get; }

        public Dataset(string path, char splitChar = '\t', bool hasHeader = true)
        {
            // Step 1: precheck
            File.Exists(path).Should().BeTrue($"File {path} should exist.");

            // Step 2: read the stream into an array of split lines
            using var stream = new StreamReader(path);
            var allline = stream.ReadToEnd()
                .Split('\r', '\n')
                .Where(a => !string.IsNullOrEmpty(a))
                .ToList();
            if (hasHeader)
                allline.RemoveAt(0);
            var alldata = allline.Select(l => l.Split(splitChar).ToArray()).ToArray();

            var fieldDict = GetFieldDict(typeof(T));

            // Step 3: convert each line to a T according to its StreamHeader mappings
            AllData = alldata
                .Select(single => GetData(fieldDict, single))
                .ToArray();
            Count = AllData.Length;
        }

        public override long Count { get; }

        public override T GetTensor(long index)
        {
            return AllData[index];
        }


        #region protect function

        protected static Dictionary<PropertyInfo, StreamHeaderRange> GetFieldDict(Type type)
        {
            var fieldInfo = type.GetProperties()
                .Where(a => a.CustomAttributes
                    .Any(attributeData => attributeData.AttributeType == typeof(StreamHeaderAttribute)))
                .ToList();
            var dict = fieldInfo.ToDictionary(
                f => f,
                f => f.GetCustomAttribute<StreamHeaderAttribute>()!.StreamHeaderRange);
            return dict;
        }

        protected static T GetData(Dictionary<PropertyInfo, StreamHeaderRange> dict, string[] array)
        {
            var obj = (T) Activator.CreateInstance(typeof(T))!;
            dict.ToList().ForEach(p =>
            {
                var fieldInfo = p.Key;
                var range = p.Value;
                var type = fieldInfo.PropertyType;

                if (range.Min == range.Max)
                {
                    var field = Convert.ChangeType(array[range.Min], type);
                    fieldInfo.SetValue(obj, field);
                }
                else if (type.IsArray && range.Max >= range.Min)
                {
                    var len = range.Max - range.Min + 1;
                    var arr = Activator.CreateInstance(type, len);
                    Enumerable.Range(0, len).ToList().ForEach(i =>
                    {
                        var field = Convert.ChangeType(array[range.Min + i], type.GetElementType()!);
                        type.GetMethod("Set")?.Invoke(arr, new[] {i, field});
                    });
                    fieldInfo.SetValue(obj, arr);
                }
            });

            return obj;
        }

        #endregion
    }
}

--------------------------------------------------------------------------------
/src/DeepSharp.Dataset/Datasets/StreamHeaderAttribute.cs:
--------------------------------------------------------------------------------
namespace DeepSharp.Dataset
{
    [AttributeUsage(AttributeTargets.Property | AttributeTargets.Field)]
    public class StreamHeaderAttribute : Attribute
    {
        public StreamHeaderRange StreamHeaderRange;

        /// <summary>Maps a member to a specific field in a text file.</summary>
        /// <param name="fieldIndex">The index of the field in the text file.</param>
        public StreamHeaderAttribute(int fieldIndex)
        {
            StreamHeaderRange = new StreamHeaderRange(fieldIndex);
        }

        public StreamHeaderAttribute(int min, int max)
        {
            StreamHeaderRange = new StreamHeaderRange(min, max);
        }
    }
}
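How a DataView subclass binds text columns through StreamHeaderAttribute — a sketch matching the iris files above (column 0 is the label, columns 1–4 the features). The class name and member layout are assumptions; the repo ships its own version in TorchSharpTest/SampleDataset/Iris.cs:

public class IrisData : DataView
{
    // Scalar column: Min == Max == 0.
    [StreamHeader(0)] public float Label { set; get; }

    // Vector column: columns 1..4 are parsed element-wise into the array.
    [StreamHeader(1, 4)] public float[] Features { set; get; } = Array.Empty<float>();

    public override torch.Tensor GetFeatures() => torch.tensor(Features);

    public override torch.Tensor GetLabels() => torch.tensor(new[] {Label});
}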
--------------------------------------------------------------------------------
/src/DeepSharp.Dataset/Datasets/StreamHeaderRange.cs:
--------------------------------------------------------------------------------
namespace DeepSharp.Dataset
{
    /// <summary>
    ///     Specifies the range of indices of input columns that should be mapped to an output column.
    /// </summary>
    public sealed class StreamHeaderRange
    {
        /// <summary>
        ///     Whether this range includes only other indices not specified.
        /// </summary>
        public bool AllOther;

        /// <summary>
        ///     Whether this range extends to the end of the line, but should be a fixed number of items.
        ///     If Max is specified, the fields AutoEnd and VariableEnd are ignored.
        /// </summary>
        public bool AutoEnd;

        /// <summary>
        ///     Force scalar columns to be treated as vectors of length one.
        /// </summary>
        public bool ForceVector;

        /// <summary>
        ///     The maximum index of the column, inclusive. If unspecified, indicates that the
        ///     loader should auto-detect the length of the lines and read until the end.
        ///     If Max is specified, the fields AutoEnd and VariableEnd are ignored.
        /// </summary>
        public int Max;

        /// <summary>
        ///     The minimum index of the column, inclusive.
        /// </summary>
        public int Min;

        /// <summary>
        ///     Whether this range extends to the end of the line, which can vary from line to line.
        ///     If AutoEnd is specified, the fields Max and VariableEnd are ignored.
        ///     If VariableEnd is true, then AutoEnd is ignored.
        /// </summary>
        public bool VariableEnd;

        public StreamHeaderRange()
        {
        }

        /// <summary>
        ///     A range representing a single value. Will result in a scalar column.
        /// </summary>
        /// <param name="index">The index of the field of the text file to read.</param>
        public StreamHeaderRange(int index)
        {
            index.Should().BeGreaterThanOrEqualTo(0, "Must be non-negative");
            Min = index;
            Max = index;
        }

        /// <summary>
        ///     A range representing a set of values. Will result in a vector column.
        /// </summary>
        /// <param name="min">The minimum inclusive index of the column.</param>
        /// <param name="max">
        ///     The maximum inclusive index of the column. If unspecified, indicates that the loader
        ///     should auto-detect the length of the lines and read until the end.
        /// </param>
        public StreamHeaderRange(int min, int max)
        {
            min.Should().BeGreaterThanOrEqualTo(0, "Must be non-negative");
            max.Should().BeGreaterOrEqualTo(min, "If specified, must be greater than or equal to " + nameof(min));

            Min = min;
            Max = max;
            // Note that without the following being set, in the case where there is a single range
            // where Min == Max, the result will not be a vector valued but a scalar column.
            ForceVector = true;
        }

        internal static StreamHeaderRange Parse(string str)
        {
            str.Should().NotBeNullOrEmpty();
            var res = new StreamHeaderRange();
            if (res.TryParse(str)) return res;

            return null;
        }

        private bool TryParse(string str)
        {
            str.Should().NotBeNullOrEmpty();

            var ich = str.IndexOfAny(new[] {'-', '~'});
            if (ich < 0)
            {
                // No "-" or "~". Single integer.
                if (!int.TryParse(str, out Min)) return false;

                Max = Min;
                return true;
            }

            AllOther = str[ich] == '~';
            ForceVector = true;

            if (ich == 0)
            {
                if (!AllOther) return false;

                Min = 0;
            }
            else if (!int.TryParse(str.Substring(0, ich), out Min))
            {
                return false;
            }

            var rest = str.Substring(ich + 1);
            if (string.IsNullOrEmpty(rest) || rest == "*")
            {
                AutoEnd = true;
                return true;
            }

            if (rest == "**")
            {
                VariableEnd = true;
                return true;
            }

            int tmp;
            if (!int.TryParse(rest, out tmp)) return false;

            Max = tmp;
            return true;
        }

        internal bool TryUnparse(StringBuilder sb)
        {
            sb.Should().NotBeNull();
            var dash = AllOther ? '~' : '-';
            if (Min < 0) return false;

            sb.Append(Min);
            if (Max != null)
            {
                if (Max != Min || ForceVector || AllOther) sb.Append(dash).Append(Max);
            }
            else if (AutoEnd)
            {
                sb.Append(dash).Append("*");
            }
            else if (VariableEnd)
            {
                sb.Append(dash).Append("**");
            }

            return true;
        }
    }
}
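The Parse/TryParse string grammar, summarized with examples derived from the branches above (values illustrative):

// "3"    -> Min = 3, Max = 3                  (scalar column)
// "1-4"  -> Min = 1, Max = 4, ForceVector     (vector column)
// "2-*"  -> Min = 2, AutoEnd = true           (read to end, fixed length)
// "2-**" -> Min = 2, VariableEnd = true       (read to end, length varies per line)
// "0~5"  -> Min = 0, Max = 5, AllOther = true ('~' selects the other indices)
var range = StreamHeaderRange.Parse("1-4");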
--------------------------------------------------------------------------------
/src/DeepSharp.Dataset/DeepSharp.Dataset.csproj:
--------------------------------------------------------------------------------
<Project Sdk="Microsoft.NET.Sdk">

  <PropertyGroup>
    <TargetFramework>net7.0</TargetFramework>
    <ImplicitUsings>enable</ImplicitUsings>
    <Nullable>enable</Nullable>
  </PropertyGroup>

</Project>

--------------------------------------------------------------------------------
/src/DeepSharp.Dataset/DeepSharp.Dataset.csproj.DotSettings:
--------------------------------------------------------------------------------
True
True

--------------------------------------------------------------------------------
/src/DeepSharp.Dataset/Usings.cs:
--------------------------------------------------------------------------------
global using TorchSharp;
global using OpenCvSharp;
global using System.Text;
global using FluentAssertions;

--------------------------------------------------------------------------------
/src/DeepSharp.ObjectDetection/DeepSharp.ObjectDetection.csproj:
--------------------------------------------------------------------------------
<Project Sdk="Microsoft.NET.Sdk">

  <PropertyGroup>
    <TargetFramework>net7.0</TargetFramework>
    <ImplicitUsings>enable</ImplicitUsings>
    <Nullable>enable</Nullable>
  </PropertyGroup>

</Project>

--------------------------------------------------------------------------------
/src/DeepSharp.ObjectDetection/Yolo.cs:
--------------------------------------------------------------------------------
namespace DeepSharp.ObjectDetection
{
    public class Yolo
    {
    }
}

--------------------------------------------------------------------------------
/src/DeepSharp.RL/ActionSelectors/ActionSelector.cs:
--------------------------------------------------------------------------------
namespace DeepSharp.RL.ActionSelectors
{
    /// <summary>
    ///     Action selectors convert a net's predicted output into concrete action objects.
    /// </summary>
    public abstract class ActionSelector
    {
        protected ActionSelector(bool keepDims = false)
        {
            KeepDims = keepDims;
        }

        public bool KeepDims { set; get; }

        public abstract torch.Tensor Select(torch.Tensor probs);
    }
}

--------------------------------------------------------------------------------
/src/DeepSharp.RL/ActionSelectors/ArgmaxActionSelector.cs:
--------------------------------------------------------------------------------
using FluentAssertions;

namespace DeepSharp.RL.ActionSelectors
{
    /// <summary>
    ///     Applies argmax over the last dimension of the input tensor.
    /// </summary>
    public class ArgmaxActionSelector : ActionSelector
    {
        /// <param name="probs">input tensor of shape [batch, actions]</param>
        /// <returns>action tensor of long format</returns>
        public override torch.Tensor Select(torch.Tensor probs)
        {
            probs.dim().Should().Be(2, "ArgmaxActionSelector supports only tensors with 2 dims");
            return torch.argmax(probs, -1, KeepDims);
        }
    }
}
--------------------------------------------------------------------------------
/src/DeepSharp.RL/ActionSelectors/EpsilonActionSelector.cs:
--------------------------------------------------------------------------------
namespace DeepSharp.RL.ActionSelectors
{
    public class EpsilonActionSelector : ActionSelector
    {
        public EpsilonActionSelector(ActionSelector selector)
        {
            Selector = selector;
        }

        public ActionSelector Selector { protected set; get; }

        public override torch.Tensor Select(torch.Tensor probs)
        {
            throw new NotImplementedException();
        }
    }
}

--------------------------------------------------------------------------------
/src/DeepSharp.RL/ActionSelectors/ProbActionSelector.cs:
--------------------------------------------------------------------------------
namespace DeepSharp.RL.ActionSelectors
{
    /// <summary>
    ///     Takes normalized probabilities and returns samples drawn from that distribution.
    /// </summary>
    public class ProbActionSelector : ActionSelector
    {
        /// <param name="probs">normalized probabilities, such as [[0.8,0.1,0.1],[0.01,0.98,0.01]]</param>
        public override torch.Tensor Select(torch.Tensor probs)
        {
            var dims = probs.dim();

            return dims switch
            {
                1 => GetActionByDim1(probs),
                2 => GetActionByDim2(probs),
                _ => throw new NotSupportedException("Supports only dims of 1 & 2")
            };
        }

        private torch.Tensor GetActionByDim1(torch.Tensor probs)
        {
            return torch.multinomial(probs, 1);
        }

        private torch.Tensor GetActionByDim2(torch.Tensor probs)
        {
            var width = (int) probs.shape[0];

            var arr = Enumerable.Range(0, width).Select(i =>
            {
                var tensorIndices = new[] {torch.TensorIndex.Single(i)};
                var prob = probs[tensorIndices];

                return torch.multinomial(prob, 1);
            }).ToList();
            var final = torch.vstack(arr);

            return KeepDims ? final : final.squeeze(-1);
        }
    }
}
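A quick contrast of the two implemented selectors on a [2, 3] probability tensor (values are illustrative):

var probs = torch.from_array(new float[,] {{0.8f, 0.1f, 0.1f}, {0.01f, 0.98f, 0.01f}});
var greedy = new ArgmaxActionSelector().Select(probs);  // deterministic: [0, 1]
var sample = new ProbActionSelector().Select(probs);    // stochastic: usually [0, 1]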
--------------------------------------------------------------------------------
/src/DeepSharp.RL/Agents/Agent.cs:
--------------------------------------------------------------------------------
using DeepSharp.RL.Enumerates;
using DeepSharp.RL.Environs;
using MathNet.Numerics.Random;

namespace DeepSharp.RL.Agents
{
    /// <summary>
    ///     Agent
    /// </summary>
    public abstract class Agent
    {
        protected Agent(Environ<Space, Space> env, string name)
        {
            Name = name;
            Environ = env;
            Device = env.Device;
        }


        public string Name { protected set; get; }
        public torch.Device Device { protected set; get; }

        public long ObservationSize => Environ.ObservationSpace!.N;

        public long ActionSize => Environ.ActionSpace!.N;

        public Environ<Space, Space> Environ { protected set; get; }
        public float Epsilon { set; get; } = 0.2f;


        public abstract LearnOutcome Learn();

        public abstract void Save(string path);

        public abstract void Load(string path);


        /// <summary>
        ///     Get a random action.
        /// </summary>
        public Act GetSampleAct()
        {
            return Environ.SampleAct();
        }

        /// <summary>
        ///     Get an action from the policy: π(s)
        /// </summary>
        /// <param name="state">current state</param>
        /// <returns>an action provided by the agent's policy</returns>
        public abstract Act GetPolicyAct(torch.Tensor state);


        /// <summary>
        ///     Get an action by the ε-greedy method: π^ε(s)
        /// </summary>
        /// <param name="state">current state</param>
        public Act GetEpsilonAct(torch.Tensor state)
        {
            var d = new SystemRandomSource();
            var v = d.NextDouble();
            var act = v < Epsilon
                ? GetSampleAct()
                : GetPolicyAct(state);
            return act;
        }


        /// <summary>
        ///     Run one complete episode following the agent's policy.
        /// </summary>
        /// <returns>the episode, including its cumulative reward</returns>
        public virtual Episode RunEpisode(
            PlayMode playMode = PlayMode.Agent)
        {
            Environ.Reset();
            var episode = new Episode();
            var epoch = 0;
            while (Environ.IsComplete(epoch) == false)
            {
                epoch++;
                var act = playMode switch
                {
                    PlayMode.Sample => GetSampleAct(),
                    PlayMode.Agent => GetPolicyAct(Environ.Observation!.Value!),
                    PlayMode.EpsilonGreedy => GetEpsilonAct(Environ.Observation!.Value!),
                    _ => throw new ArgumentOutOfRangeException(nameof(playMode), playMode, null)
                };
                var step = Environ.Step(act, epoch);
                episode.Steps.Add(step);
                Environ.CallBack?.Invoke(step);
                Environ.Observation = step.PostState; // Important: keep the observation updated
            }

            var sumReward = Environ.GetReturn(episode);
            episode.SumReward = new Reward(sumReward);
            return episode;
        }


        /// <summary>
        ///     Run multiple complete episodes following the agent's policy.
        /// </summary>
        public virtual Episode[] RunEpisodes(int count,
            PlayMode playMode = PlayMode.Agent)
        {
            var episodes = new List<Episode>();
            foreach (var _ in Enumerable.Repeat(1, count))
                episodes.Add(RunEpisode(playMode));

            return episodes.ToArray();
        }

        /// <param name="testCount">test count</param>
        /// <returns>average reward</returns>
        public float TestEpisodes(int testCount)
        {
            var episode = RunEpisodes(testCount);
            var averageReward = episode.Average(a => a.SumReward.Value);
            return averageReward;
        }


        public override string ToString()
        {
            return $"Agent[{Name}]";
        }
    }
}

--------------------------------------------------------------------------------
/src/DeepSharp.RL/Agents/LearnOutCome.cs:
--------------------------------------------------------------------------------
using DeepSharp.RL.Environs;

namespace DeepSharp.RL.Agents
{
    public class LearnOutcome
    {
        public LearnOutcome()
        {
            Steps = new List<Step>();
            Evaluate = 0;
        }

        public LearnOutcome(Step[] steps, float evaluate)
        {
            Steps = steps.ToList();
            Evaluate = evaluate;
        }

        public LearnOutcome(Episode episode)
        {
            Steps = episode.Steps;
            Evaluate = episode.SumReward.Value;
        }

        public LearnOutcome(Episode[] episode)
        {
            Steps = episode.SelectMany(e => e.Steps).ToList();
            Evaluate = episode.Average(a => a.SumReward.Value);
        }

        public LearnOutcome(Episode[] episode, float loss)
        {
            Steps = episode.SelectMany(e => e.Steps).ToList();
            Evaluate = loss;
        }

        public List<Step> Steps { protected set; get; }
        public float Evaluate { set; get; }

        public void AppendStep(Step step)
        {
            Steps.Add(step);
        }

        public void AppendStep(IEnumerable<Step> steps)
        {
            Steps.AddRange(steps);
        }

        public void AppendStep(Episode episode)
        {
            Steps.AddRange(episode.Steps);
        }

        public void UpdateEvaluate(float evaluation)
        {
            Evaluate = evaluation;
        }

        public override string ToString()
        {
            var avrReward = Steps.Average(a => a.Reward.Value);
            var message = $"S:{Steps.Count}\tR:{avrReward:F4}\tE:{Evaluate:F4}";
            return message;
        }
    }
}
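The Agent/LearnOutcome pair supports a simple outer training loop. A hypothetical sketch — the environ construction and stopping threshold are placeholders, and the RLTrainer under Trainers/ wraps this same pattern:

var env = new KArmedBandit(4);          // placeholder environ from Environs/
Agent agent = new CrossEntropy(env, t: 16);
const float rewardThreshold = 18f;      // placeholder stopping criterion

for (var epoch = 0; epoch < 1000; epoch++)
{
    var outcome = agent.Learn();        // collect episodes and update the policy
    Console.WriteLine($"Epoch {epoch}\t{outcome}");
    if (agent.TestEpisodes(20) >= rewardThreshold) break;
}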
--------------------------------------------------------------------------------
/src/DeepSharp.RL/Agents/Models/QTable.cs:
--------------------------------------------------------------------------------
using System.Text;
using DeepSharp.RL.Environs;

namespace DeepSharp.RL.Agents
{
    /// <summary>
    ///     State-action value table: Q(s,a) = cumulative reward.
    /// </summary>
    public class QTable : IEquatable<QTable>
    {
        public QTable()
        {
            Return = new Dictionary<TransitKey, float>();
        }

        public Dictionary<TransitKey, float> Return { protected set; get; }
        protected List<TransitKey> TrasitKeys => Return.Keys.ToList();


        public float this[TransitKey transit]
        {
            get => GetValue(transit);
            set => SetValue(transit, value);
        }


        public float this[torch.Tensor state, torch.Tensor action]
        {
            get => GetValue(new TransitKey(state, action));
            set => SetValue(new TransitKey(state, action), value);
        }


        public bool Equals(QTable? other)
        {
            if (other == null) return false;
            if (other.TrasitKeys.Count != TrasitKeys.Count) return false;
            var res = TrasitKeys.All(key => !(Math.Abs(this[key] - other[key]) > 1E-2));
            return res;
        }


        private void SetValue(TransitKey transit, float value)
        {
            Return[transit] = value;
        }

        private float GetValue(TransitKey transit)
        {
            Return.TryAdd(transit, 0f);
            return Return[transit];
        }


        /// <summary>
        ///     argmax(a) Q(state, a); returns null when the state has no recorded
        ///     actions or all recorded values are zero.
        /// </summary>
        public Act? GetBestAct(torch.Tensor state)
        {
            var row = TrasitKeys
                .Where(a => a.State.Equals(state));

            var stateActions = Return
                .Where(a => row.Contains(a.Key)).ToList();

            if (!stateActions.Any())
                return null;

            if (stateActions.All(a => a.Value == 0))
                return null;

            var argMax = stateActions
                .MaxBy(a => a.Value);
            var act = argMax.Key.Act;
            return new Act(act);
        }

        /// <summary>
        ///     max(a) Q(state, a); returns 0 when the state has no recorded actions.
        /// </summary>
        public float GetBestValue(torch.Tensor state)
        {
            var row = TrasitKeys
                .Where(a => a.State.Equals(state));

            var stateActions = Return
                .Where(a => row.Contains(a.Key)).ToList();

            if (!stateActions.Any())
                return 0;

            var bestValue = stateActions
                .Max(a => a.Value);
            return bestValue;
        }

        public override string ToString()
        {
            var str = new StringBuilder();
            foreach (var keyValuePair in Return.Where(a => a.Value > 0))
                str.AppendLine($"{keyValuePair.Key}\t{keyValuePair.Value:F4}");
            return str.ToString();
        }
    }
}
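An indexing sketch: keys are (state, action) tensor pairs, and unseen entries default to 0 (the tensors here are illustrative):

var q = new QTable();
var state = torch.tensor(new long[] {0});
var act = torch.tensor(new long[] {1});

q[state, act] = 1.5f;                // upsert Q(s,a)
var best = q.GetBestAct(state);      // argmax over recorded actions; null if none, or all zero
var value = q.GetBestValue(state);   // max Q(s,·); 0 when the row is empty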
--------------------------------------------------------------------------------
/src/DeepSharp.RL/Agents/Models/RewardKey.cs:
--------------------------------------------------------------------------------
using DeepSharp.RL.Environs;

namespace DeepSharp.RL.Agents
{
    /// <summary>
    ///     Composite key for the reward table.
    /// </summary>
    public struct RewardKey
    {
        public RewardKey(torch.Tensor state, torch.Tensor act, torch.Tensor newState)
        {
            State = state;
            Act = act;
            NewState = newState;
        }

        public RewardKey(Observation state, Act act, Observation newState)
        {
            State = state.Value!;
            Act = act.Value!;
            NewState = newState.Value!;
        }

        public torch.Tensor State { set; get; }
        public torch.Tensor Act { set; get; }
        public torch.Tensor NewState { set; get; }


        public override string ToString()
        {
            var state = State.ToString(torch.numpy);
            var action = Act.ToString(torch.numpy);
            var newState = NewState.ToString(torch.numpy);

            return $"{state} \r\n {action} \r\n {newState}";
        }

        public static bool operator ==(RewardKey x, RewardKey y)
        {
            return x.Equals(y);
        }

        public static bool operator !=(RewardKey x, RewardKey y)
        {
            return !x.Equals(y);
        }

        public bool Equals(RewardKey other)
        {
            return State.Equals(other.State) && Act.Equals(other.Act) && NewState.Equals(other.NewState);
        }


        public override bool Equals(object? obj)
        {
            if (obj is RewardKey input)
                return Equals(input);
            return false;
        }

        public override int GetHashCode()
        {
            // Constant hash: forces dictionary lookups to fall through to Equals,
            // which compares the tensor contents.
            return -1;
        }
    }
}

--------------------------------------------------------------------------------
/src/DeepSharp.RL/Agents/Models/TrasitKey.cs:
--------------------------------------------------------------------------------
using DeepSharp.RL.Environs;

namespace DeepSharp.RL.Agents
{
    /// <summary>
    ///     Key of a transit, which combines state and action.
    /// </summary>
    public class TransitKey
    {
        public TransitKey(Observation state, Act act)
        {
            State = state.Value!;
            Act = act.Value!;
        }

        public TransitKey(torch.Tensor state, torch.Tensor act)
        {
            State = state;
            Act = act;
        }

        public torch.Tensor State { protected set; get; }
        public torch.Tensor Act { protected set; get; }


        public static bool operator ==(TransitKey x, TransitKey y)
        {
            return x.Equals(y);
        }

        public static bool operator !=(TransitKey x, TransitKey y)
        {
            return !x.Equals(y);
        }

        public bool Equals(TransitKey other)
        {
            return State.Equals(other.State) && Act.Equals(other.Act);
        }

        public override bool Equals(object? obj)
        {
            if (obj is TransitKey input)
                return Equals(input);
            return false;
        }

        public override int GetHashCode()
        {
            // Constant hash: forces dictionary lookups to fall through to Equals,
            // which compares the tensor contents.
            return -1;
        }

        public override string ToString()
        {
            return $"{State.ToString(torch.numpy)},{Act.ToString(torch.numpy)}";
        }
    }
}
--------------------------------------------------------------------------------
/src/DeepSharp.RL/Agents/Models/VTable.cs:
--------------------------------------------------------------------------------
using System.Text;

namespace DeepSharp.RL.Agents
{
    /// <summary>
    ///     State value function: V(s) = cumulative reward.
    /// </summary>
    public class VTable : IEquatable<VTable>
    {
        public VTable()
        {
            Return = new Dictionary<torch.Tensor, float>();
        }

        public Dictionary<torch.Tensor, float> Return { protected set; get; }

        protected List<torch.Tensor> StateKeys => Return.Keys.ToList();


        public float this[torch.Tensor state]
        {
            get => GetValue(state);
            set => SetValue(state, value);
        }

        private void SetValue(torch.Tensor state, float value)
        {
            Return[state] = value;
        }

        private float GetValue(torch.Tensor transit)
        {
            Return.TryAdd(transit, 0f);
            return Return[transit];
        }

        public static float operator -(VTable a, VTable b)
        {
            // Maximum absolute difference over a's states.
            var keys = a.StateKeys;
            return keys.Select(k => Math.Abs(a[k] - b[k])).Max();
        }

        public bool Equals(VTable? other)
        {
            if (other == null) return false;
            if (other.StateKeys.Count != StateKeys.Count) return false;
            var res = StateKeys.All(key => !(Math.Abs(this[key] - other[key]) > 1E-2));
            return res;
        }


        public override string ToString()
        {
            var str = new StringBuilder();
            foreach (var keyValuePair in Return.Where(a => a.Value > 0))
                str.AppendLine($"{keyValuePair.Key.ToString(torch.numpy)}\t{keyValuePair.Value:F4}");
            return str.ToString();
        }
    }
}

--------------------------------------------------------------------------------
/src/DeepSharp.RL/Agents/Nets/DQNNet.cs:
--------------------------------------------------------------------------------
namespace DeepSharp.RL.Agents
{
    /// <summary>
    ///     Deep model for DQN.
    /// </summary>
    public sealed class DQNNet : Module<torch.Tensor, torch.Tensor>
    {
        private readonly Module<torch.Tensor, torch.Tensor> conv;
        private readonly Module<torch.Tensor, torch.Tensor> fc;
        public torch.Device Device { get; }
        public torch.ScalarType ScalarType { get; }

        public DQNNet(long[] inputShape, int actions,
            torch.ScalarType scalar = torch.ScalarType.Float32,
            DeviceType deviceType = DeviceType.CUDA) :
            base("DQNN")
        {
            ScalarType = scalar;
            Device = new torch.Device(deviceType);
            var modules = new List<(string, Module<torch.Tensor, torch.Tensor>)>
            {
                ("Conv2d1", Conv2d(inputShape[0], 32, 8, 4)),
                ("Relu1", ReLU()),
                ("Conv2d2", Conv2d(32, 64, 4, 2)),
                ("Relu2", ReLU()),
                ("Conv2d3", Conv2d(64, 64, 3)),
                ("Relu3", ReLU())
            };
            conv = Sequential(modules);
            conv.to(Device);

            var convOutSize = GetConvOut(inputShape);
            var modules2 = new List<(string, Module<torch.Tensor, torch.Tensor>)>
            {
                ("Linear1", Linear(convOutSize, 512)),
                ("Relu4", ReLU()),
                ("Linear2", Linear(512, actions))
            };
            fc = Sequential(modules2);
            fc.to(Device);

            RegisterComponents();
        }


        public override torch.Tensor forward(torch.Tensor input)
        {
            var convOut = conv.forward(input).view(input.size(0), -1);
            var fcOut = fc.forward(convOut);
            return fcOut;
        }

        public int GetConvOut(long[] inputShape)
        {
            // Run a dummy batch of one through the conv stack to size the flattened output.
            var arr = new List<long> {1};
            arr.AddRange(inputShape);
            var input = torch.zeros(arr.ToArray(), ScalarType, Device);
            var o = conv.forward(input);
            var shapes = o.size();
            var outSize = shapes.Aggregate((a, b) => a * b);
            return (int) outSize;
        }

        protected override void Dispose(bool disposing)
        {
            if (disposing)
            {
                conv.Dispose();
                fc.Dispose();
                ClearModules();
            }

            base.Dispose(disposing);
        }
    }
}
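A shape walkthrough for the classic Atari-style input [4, 84, 84] (the canonical DQN setting; this repo's tests may use other sizes):

// Conv2d(4->32, k=8, s=4): (84-8)/4+1 = 20 -> [32, 20, 20]
// Conv2d(32->64, k=4, s=2): (20-4)/2+1 = 9 -> [64, 9, 9]
// Conv2d(64->64, k=3, s=1): 9-3+1 = 7      -> [64, 7, 7]
// GetConvOut flattens this to 64*7*7 = 3136, feeding Linear(3136, 512).
var net = new DQNNet(new long[] {4, 84, 84}, actions: 6);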
--------------------------------------------------------------------------------
/src/DeepSharp.RL/Agents/Nets/Net.cs:
--------------------------------------------------------------------------------
namespace DeepSharp.RL.Agents
{
    /// <summary>
    ///     A demo net showing how to create a new Module.
    /// </summary>
    public sealed class Net : Module<torch.Tensor, torch.Tensor>
    {
        private readonly Module<torch.Tensor, torch.Tensor> layers;

        public Net(long obsSize, long hiddenSize, long actionNum, DeviceType deviceType = DeviceType.CUDA) :
            base("Net")
        {
            var modules = new List<(string, Module<torch.Tensor, torch.Tensor>)>
            {
                ("line1", Linear(obsSize, hiddenSize)),
                ("relu", ReLU()),
                ("line2", Linear(hiddenSize, actionNum))
            };
            layers = Sequential(modules);
            layers.to(new torch.Device(deviceType));
            RegisterComponents();
        }

        public override torch.Tensor forward(torch.Tensor input)
        {
            return layers.forward(input.to_type(torch.ScalarType.Float32));
        }

        protected override void Dispose(bool disposing)
        {
            if (disposing)
            {
                layers.Dispose();
                ClearModules();
            }

            base.Dispose(disposing);
        }
    }
}

--------------------------------------------------------------------------------
/src/DeepSharp.RL/Agents/Nets/PGN.cs:
--------------------------------------------------------------------------------
namespace DeepSharp.RL.Agents
{
    public sealed class PGN : Module<torch.Tensor, torch.Tensor>
    {
        private readonly Module<torch.Tensor, torch.Tensor> layers;

        public PGN(long obsSize, long hiddenSize, long actionNum, DeviceType deviceType = DeviceType.CUDA) :
            base("PolicyNet")
        {
            var modules = new List<(string, Module<torch.Tensor, torch.Tensor>)>
            {
                ("line1", Linear(obsSize, hiddenSize)),
                ("relu", ReLU()),
                ("line2", Linear(hiddenSize, actionNum)),
                ("softmax", Softmax(-1))
            };
            layers = Sequential(modules);
            layers.to(new torch.Device(deviceType));
            RegisterComponents();
        }

        public override torch.Tensor forward(torch.Tensor input)
        {
            return layers.forward(input.to_type(torch.ScalarType.Float32));
        }

        protected override void Dispose(bool disposing)
        {
            if (disposing)
            {
                layers.Dispose();
                ClearModules();
            }

            base.Dispose(disposing);
        }
    }
}
--------------------------------------------------------------------------------
/src/DeepSharp.RL/Agents/Others/CrossEntropy.cs:
--------------------------------------------------------------------------------
using DeepSharp.RL.ActionSelectors;
using DeepSharp.RL.Enumerates;
using DeepSharp.RL.Environs;
using FluentAssertions;
using static TorchSharp.torch.optim;

namespace DeepSharp.RL.Agents
{
    /// <summary>
    ///     An agent based on the cross-entropy method.
    ///     http://people.smp.uq.edu.au/DirkKroese/ps/eormsCE.pdf
    /// </summary>
    public class CrossEntropy : Agent
    {
        public CrossEntropy(Environ<Space, Space> environ,
            int t,
            float percentElite = 0.7f,
            int hiddenSize = 100) : base(environ, "CrossEntropy")
        {
            T = t;
            PercentElite = percentElite;
            AgentNet = new Net((int) environ.ObservationSpace!.N, hiddenSize, (int) environ.ActionSpace!.N,
                Device.type);
            Optimizer = Adam(AgentNet.parameters(), 0.01);
            Loss = CrossEntropyLoss();
        }

        public int T { protected set; get; }

        public float PercentElite { protected set; get; }

        public Net AgentNet { protected set; get; }

        public Optimizer Optimizer { protected set; get; }

        public Loss<torch.Tensor, torch.Tensor, torch.Tensor> Loss { protected set; get; }

        /// <summary>
        ///     The agent produces an action probability distribution from the observation
        ///     and samples the next action from that distribution.
        /// </summary>
        public override Act GetPolicyAct(torch.Tensor state)
        {
            var input = state.unsqueeze(0);
            var sm = Softmax(1);
            var actionProbs = sm.forward(AgentNet.forward(input));
            var nextAction = new ProbActionSelector().Select(actionProbs);
            var action = new Act(nextAction);
            return action;
        }


        public override LearnOutcome Learn()
        {
            var episodes = RunEpisodes(T, PlayMode.Sample);
            var elite = GetElite(episodes);

            var oars = elite.SelectMany(a => a.Steps)
                .ToList();

            var observations = oars
                .Select(a => a.PostState.Value)
                .ToList();
            var actions = oars
                .Select(a => a.Action.Value)
                .ToList();

            var observation = torch.vstack(observations!);
            var action = torch.vstack(actions!);

            var loss = Learn(observation, action);

            return new LearnOutcome(episodes, loss);
        }

        public override void Save(string path)
        {
            throw new NotImplementedException();
        }

        public override void Load(string path)
        {
            throw new NotImplementedException();
        }

        /// <summary>
        ///     Replace the default optimizer.
        /// </summary>
        public void UpdateOptimizer(Optimizer optimizer)
        {
            Optimizer = optimizer;
        }

        /// <summary>
        ///     Keep the episodes whose total reward exceeds the PercentElite percentile.
        /// </summary>
        public virtual Episode[] GetElite(Episode[] episodes)
        {
            var reward = episodes
                .Select(a => a.SumReward.Value)
                .ToArray();
            var rewardP = reward.OrderByDescending(a => a)
                .Take((int) (reward.Length * PercentElite))
                .Min();

            var filterEpisodes = episodes
                .Where(e => e.SumReward.Value > rewardP)
                .ToArray();

            return filterEpisodes;
        }

        /// <summary>
        ///     Core function to update the net.
        /// </summary>
        /// <param name="observations">tensor stacked from observations, size: [batch, observation size]</param>
        /// <param name="actions">tensor stacked from actions, size: [batch, action size]</param>
        /// <returns>loss</returns>
        internal float Learn(torch.Tensor observations, torch.Tensor actions)
        {
            observations.shape.Last().Should()
                .Be(ObservationSize, $"Agent observations tensor should be [B,{ObservationSize}]");

            actions = actions.squeeze(-1);
            var pred = AgentNet.forward(observations);
            var output = Loss.call(pred, actions);

            Optimizer.zero_grad();
            output.backward();
            Optimizer.step();

            var loss = output.item<float>();
            return loss;
        }
    }
}

--------------------------------------------------------------------------------
/src/DeepSharp.RL/Agents/Others/CrossEntropyExt.cs:
--------------------------------------------------------------------------------
using DeepSharp.RL.Environs;
using static TorchSharp.torch.optim;

namespace DeepSharp.RL.Agents
{
    /// <summary>
    ///     A cross-entropy agent extended with memory of historical elite episodes.
    ///     http://people.smp.uq.edu.au/DirkKroese/ps/eormsCE.pdf
    /// </summary>
    public class CrossEntropyExt : CrossEntropy

    {
        public int MemsEliteLength = 30;
        public List<DateTime> Start = new();

        public CrossEntropyExt(Environ<Space, Space> environ,
            int t,
            float percentElite = 0.7f,
            int hiddenSize = 100)
            : base(environ, t, percentElite, hiddenSize)
        {
            Optimizer = Adam(AgentNet.parameters(), 0.01);
        }

        /// <summary>
        ///     Memory: keeps elite episodes from history.
        /// </summary>
        internal List<Episode> MemeSteps { set; get; } = new();

        /// <summary>
        ///     Get elite episodes from the current batch combined with remembered ones.
        /// </summary>
        public override Episode[] GetElite(Episode[] episodes)
        {
            var current = episodes.Select(a => a.DateTime).ToList();
            Start.Add(current.Min());
            if (Start.Count >= 10)
                Start.RemoveAt(0);

            var combine = episodes.Concat(MemeSteps).ToList();
            var reward = combine
                .Select(a => a.SumReward.Value)
                .ToArray();
            var rewardP = reward.OrderByDescending(a => a)
                .Take((int) (reward.Length * PercentElite))
                .Min();

            var filterEpisodes = combine
                .Where(e => e.SumReward.Value > rewardP)
                .ToArray();

            MemeSteps = filterEpisodes.Where(a => a.DateTime > Start.Min())
                .OrderByDescending(a => a.SumReward.Value)
                .Take(MemsEliteLength)
                .ToList();

            return filterEpisodes;
        }
    }
}
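A worked pass through GetElite with T = 8 and PercentElite = 0.7: with episode rewards {9, 8, 7, 7, 6, 5, 4, 3} sorted descending, it takes the top (int)(8 * 0.7) = 5 values {9, 8, 7, 7, 6}, so rewardP = 6; the strict filter > rewardP keeps the four episodes scoring {9, 8, 7, 7}. Note that the strict inequality can keep fewer than 70% of the episodes when rewards tie at the cutoff.

var elite = agent.GetElite(episodes);   // 4 of 8 episodes in this example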
--------------------------------------------------------------------------------
/src/DeepSharp.RL/Agents/PolicyAgengt.cs:
--------------------------------------------------------------------------------
using DeepSharp.RL.Environs;

namespace DeepSharp.RL.Agents
{
    public abstract class PolicyGradientAgengt : Agent
    {
        protected PolicyGradientAgengt(Environ<Space, Space> env, string name)
            : base(env, name)
        {
            PolicyNet = new PGN(ObservationSize, 128, ActionSize, DeviceType.CPU);
        }

        public Module<torch.Tensor, torch.Tensor> PolicyNet { protected set; get; }


        /// <summary>
        ///     Sample an action from the policy distribution π(a|s)
        ///     produced by the policy net for the given state.
        /// </summary>
        public override Act GetPolicyAct(torch.Tensor state)
        {
            var probs = PolicyNet.forward(state.unsqueeze(0)).squeeze(0);
            var actIndex = torch.multinomial(probs, 1, true).ToInt32();
            return new Act(torch.from_array(new[] {actIndex}));
        }
    }
}

--------------------------------------------------------------------------------
/src/DeepSharp.RL/Agents/PolicyBased/A2C.cs:
--------------------------------------------------------------------------------
using DeepSharp.RL.Environs;
using DeepSharp.RL.ExpReplays;
using TorchSharp.Modules;
using static TorchSharp.torch.optim;

namespace DeepSharp.RL.Agents
{
    public class A2C : PolicyGradientAgengt
    {
        public A2C(Environ<Space, Space> env,
            int batchsize,
            float alpha = 0.01f,
            float beta = 0.01f,
            float gamma = 0.99f)
            : base(env, "ActorCritic")
        {
            Batchsize = batchsize;
            Gamma = gamma;
            Alpha = alpha;
            Beta = beta;
            // Critic output is V with shape [batchsize, 1]
            Q = new Net(ObservationSize, 128, 1, DeviceType.CPU);
            ExpReplays = new EpisodeExpReplay(batchsize, gamma);

            var parameters = new[] {Q, PolicyNet}
                .SelectMany(a => a.parameters());
            Optimizer = Adam(parameters, Alpha);
        }

        /// <summary>
        ///     Episodes sent to train
        /// </summary>
        public int Batchsize { protected set; get; }

        public float Alpha { protected set; get; }
        public float Beta { protected set; get; }
        public float Gamma { protected set; get; }
        public Module<torch.Tensor, torch.Tensor> Q { protected set; get; }
        public EpisodeExpReplay ExpReplays { protected set; get; }
        public Optimizer Optimizer { protected set; get; }

        /// <summary>
        ///     Regress the critic V(s) toward the observed returns, then update the
        ///     actor with the advantage-weighted log-probability of its actions.
        /// </summary>
        public override LearnOutcome Learn()
        {
            var learnOutCome = new LearnOutcome();

            var episodes = RunEpisodes(Batchsize);

            episodes.ToList().ForEach(e =>
            {
                learnOutCome.AppendStep(e);
                ExpReplays.Enqueue(e);
            });

            var experienceCase = ExpReplays.All();
            var state = experienceCase.PreState;
            var action = experienceCase.Action;
            var valsRef = experienceCase.Reward;
            ExpReplays.Clear();

            Optimizer.zero_grad();

            var value = Q.forward(state);

            var lossValue = new MSELoss().forward(value, valsRef);
            lossValue.backward();

            var logProbV = torch.log(PolicyNet.forward(state)).gather(1, action);
            var logProbActionV = (valsRef - value.detach()) * logProbV;
            var lossPolicy = -logProbActionV.mean();

            lossPolicy.backward();
            Optimizer.step();

            learnOutCome.Evaluate = lossPolicy.item<float>();

            return learnOutCome;
        }


        public override void Save(string path)
        {
            if (File.Exists(path)) File.Delete(path);
            PolicyNet.save(path);
        }

        public override void Load(string path)
        {
            PolicyNet.load(path);
        }
    }
}
override void Load(string path)
        {
            PolicyNet.load(path);
        }
    }
}
--------------------------------------------------------------------------------
/src/DeepSharp.RL/Agents/PolicyBased/A3C.cs:
--------------------------------------------------------------------------------
namespace DeepSharp.RL.Agents
{
    internal class A3C
    {
    }
}
--------------------------------------------------------------------------------
/src/DeepSharp.RL/Agents/PolicyBased/ActorCritic.cs:
--------------------------------------------------------------------------------
using DeepSharp.RL.Environs;
using DeepSharp.RL.ExpReplays;
using OpenCvSharp.Dnn;
using TorchSharp.Modules;
using static TorchSharp.torch.optim;

namespace DeepSharp.RL.Agents
{
    public class ActorCritic : PolicyGradientAgengt
    {
        public ActorCritic(Environ<Space, Space> env,
            int batchsize,
            float alpha = 0.01f,
            float beta = 0.01f,
            float gamma = 0.99f)
            : base(env, "ActorCritic")
        {
            Batchsize = batchsize;
            Gamma = gamma;
            Alpha = alpha;
            Beta = beta;
            Q = new Net(ObservationSize, 128, ActionSize, DeviceType.CPU);
            ExpReplays = new EpisodeExpReplay(batchsize, gamma);
            ExpReplaysForPolicy = new EpisodeExpReplay(batchsize, gamma);
            var parameters = new[] {Q, PolicyNet}
                .SelectMany(a => a.parameters());
            Optimizer = Adam(parameters, Alpha);
        }

        /// <summary>
        ///     Episodes sent to train
        /// </summary>
        public int Batchsize { protected set; get; }

        public float Alpha { protected set; get; }
        public float Beta { protected set; get; }
        public float Gamma { protected set; get; }
        public Module<torch.Tensor, torch.Tensor> Q { protected set; get; }
        public EpisodeExpReplay ExpReplays { protected set; get; }
        public EpisodeExpReplay ExpReplaysForPolicy { protected set; get; }
        public Optimizer Optimizer { protected set; get; }

        /// <summary>
        ///     Q-learning style TD update for the Q net; the policy net is updated with a
        ///     policy gradient weighted by Q(s, a).
        /// </summary>
        /// <returns></returns>
        public override LearnOutcome Learn()
        {
            var learnOutCome = new LearnOutcome();

            var episodes = RunEpisodes(Batchsize);

            episodes.ToList().ForEach(e =>
            {
                learnOutCome.AppendStep(e);
                ExpReplays.Enqueue(e, false);
                ExpReplaysForPolicy.Enqueue(e);
            });

            var experienceCase = ExpReplays.All();
            var state = experienceCase.PreState;
            var action = experienceCase.Action;
            var reward = experienceCase.Reward;
            var poststate = experienceCase.PostState;
            ExpReplays.Clear();

            Optimizer.zero_grad();

            var stateActionValue = Q.forward(state).gather(1, action);
            var nextStateValue = Q.forward(poststate).max(1).values.detach();
            var expectedStateActionValue = reward + nextStateValue * Gamma;
            var lossValue = new MSELoss().forward(stateActionValue, expectedStateActionValue);
            lossValue.backward();

            var logProbV = torch.log(PolicyNet.forward(state)).gather(1, action);
            var logProbActionV = stateActionValue.detach() * logProbV;
            var lossPolicy = -logProbActionV.mean();

            lossPolicy.backward();
            Optimizer.step();

            learnOutCome.Evaluate = lossPolicy.item<float>();

            return learnOutCome;
        }

        public override void Save(string path)
        {
            if (File.Exists(path)) File.Delete(path);
            PolicyNet.save(path);
        }

        public override void Load(string path)
        {
            PolicyNet.load(path);
        }
    }
}
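The two critics above differ in what they regress: A2C fits V(s) against the discounted returns stored in the replay and scales log-probabilities by the advantage (return minus V(s)), while ActorCritic fits Q(s, a) against a bootstrapped TD target and scales log-probabilities by Q(s, a) itself. A condensed sketch of one A2C batch update, under the assumption that the lowercase local names stand in for the fields used above:

// Sketch of a single A2C update (assumed names; mirrors A2C.Learn above).
var c = expReplays.All();                        // experience case: PreState, Action, Reward (returns)
optimizer.zero_grad();
var value = q.forward(c.PreState);               // critic output V(s), shape [batch, 1]
var criticLoss = new MSELoss().forward(value, c.Reward);
criticLoss.backward();                           // fit V(s) to the observed returns
var logProb = torch.log(policyNet.forward(c.PreState)).gather(1, c.Action);
var advantage = c.Reward - value.detach();       // detach: the actor loss must not train the critic
var actorLoss = -(advantage * logProb).mean();   // policy gradient weighted by the advantage
actorLoss.backward();
optimizer.step();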
-------------------------------------------------------------------------------- /src/DeepSharp.RL/Agents/PolicyBased/Reinforce.cs: -------------------------------------------------------------------------------- 1 | using DeepSharp.RL.Environs; 2 | using DeepSharp.RL.ExpReplays; 3 | using static TorchSharp.torch.optim; 4 | 5 | namespace DeepSharp.RL.Agents 6 | { 7 | /// 8 | /// ReinforceOriginal: Learn by a batch episodes 9 | /// 10 | public class Reinforce : PolicyGradientAgengt 11 | { 12 | public Reinforce(Environ env, 13 | int batchsize = 4, 14 | float gamma = 0.99f, 15 | float alpha = 0.01f) 16 | : base(env, "Reinforce") 17 | { 18 | Batchsize = batchsize; 19 | Gamma = gamma; 20 | Alpha = alpha; 21 | 22 | ExpReplays = new EpisodeExpReplay(batchsize, gamma); 23 | Optimizer = Adam(PolicyNet.parameters(), Alpha); 24 | } 25 | 26 | /// 27 | /// Episodes send to train 28 | /// 29 | public int Batchsize { protected set; get; } 30 | 31 | public float Gamma { protected set; get; } 32 | public float Alpha { protected set; get; } 33 | 34 | 35 | public Optimizer Optimizer { protected set; get; } 36 | 37 | public EpisodeExpReplay ExpReplays { protected set; get; } 38 | 39 | 40 | public override LearnOutcome Learn() 41 | { 42 | var learnOutCome = new LearnOutcome(); 43 | 44 | var episodes = RunEpisodes(Batchsize); 45 | 46 | Optimizer.zero_grad(); 47 | 48 | episodes.ToList().ForEach(e => 49 | { 50 | learnOutCome.AppendStep(e); 51 | ExpReplays.Enqueue(e); 52 | }); 53 | 54 | var experienceCase = ExpReplays.All(); 55 | var state = experienceCase.PreState; 56 | var action = experienceCase.Action; 57 | var qValues = experienceCase.Reward; 58 | ExpReplays.Clear(); 59 | 60 | var logProbV = torch.log(PolicyNet.forward(state)).gather(1, action); 61 | var logProbActionV = qValues * logProbV; 62 | var loss = -logProbActionV.mean(); 63 | 64 | 65 | loss.backward(); 66 | Optimizer.step(); 67 | 68 | 69 | learnOutCome.Evaluate = loss.item(); 70 | 71 | return learnOutCome; 72 | } 73 | 74 | 75 | public override void Save(string path) 76 | { 77 | if (File.Exists(path)) File.Delete(path); 78 | PolicyNet.save(path); 79 | } 80 | 81 | public override void Load(string path) 82 | { 83 | PolicyNet.load(path); 84 | } 85 | } 86 | } -------------------------------------------------------------------------------- /src/DeepSharp.RL/Agents/PolicyBased/ReinforceOriginal.cs: -------------------------------------------------------------------------------- 1 | using DeepSharp.RL.Environs; 2 | using static TorchSharp.torch.optim; 3 | 4 | namespace DeepSharp.RL.Agents 5 | { 6 | /// 7 | /// ReinforceOriginal: Learn after each episode 8 | /// 9 | public class ReinforceOriginal : PolicyGradientAgengt 10 | { 11 | public ReinforceOriginal(Environ env, 12 | float gamma = 0.99f, 13 | float alpha = 0.01f) 14 | : base(env, "ReinforceOriginal") 15 | { 16 | Gamma = gamma; 17 | Alpha = alpha; 18 | Optimizer = Adam(PolicyNet.parameters(), Alpha); 19 | } 20 | 21 | /// 22 | /// Episodes send to train 23 | /// 24 | public int Batchsize { protected set; get; } 25 | 26 | public float Gamma { protected set; get; } 27 | public float Alpha { protected set; get; } 28 | 29 | 30 | public Optimizer Optimizer { protected set; get; } 31 | 32 | 33 | public override LearnOutcome Learn() 34 | { 35 | var learnOutCome = new LearnOutcome(); 36 | 37 | var episode = RunEpisode(); 38 | var steps = episode.Steps; 39 | learnOutCome.AppendStep(episode.Steps); 40 | 41 | steps.Reverse(); 42 | Optimizer.zero_grad(); 43 | 44 | var g = 0f; 45 | 46 | foreach (var s in steps) 47 | { 
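// Steps are traversed in reverse time order, so g accumulates the discounted
// return G_t = r_t + Gamma * G_{t+1}; each step then contributes the
// REINFORCE loss term -log pi(a_t | s_t) * G_t.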
48 | var reward = s.Reward.Value; 49 | var state = s.PreState.Value!.unsqueeze(0); 50 | var action = s.Action.Value!.view(-1, 1).to(torch.ScalarType.Int64); 51 | var logProb = torch.log(PolicyNet.forward(state)).gather(1, action); 52 | 53 | g = Gamma * g + reward; 54 | 55 | var loss = -logProb * g; 56 | loss.backward(); 57 | learnOutCome.Evaluate = loss.item(); 58 | } 59 | 60 | Optimizer.step(); 61 | 62 | return learnOutCome; 63 | } 64 | 65 | 66 | public override void Save(string path) 67 | { 68 | if (File.Exists(path)) File.Delete(path); 69 | PolicyNet.save(path); 70 | } 71 | 72 | public override void Load(string path) 73 | { 74 | PolicyNet.load(path); 75 | } 76 | } 77 | } -------------------------------------------------------------------------------- /src/DeepSharp.RL/Agents/ValueAgent.cs: -------------------------------------------------------------------------------- 1 | using DeepSharp.RL.Environs; 2 | 3 | namespace DeepSharp.RL.Agents 4 | { 5 | public abstract class ValueAgent : Agent 6 | { 7 | protected ValueAgent(Environ env, string name) 8 | : base(env, name) 9 | { 10 | QTable = new QTable(); 11 | } 12 | 13 | public QTable QTable { protected set; get; } 14 | 15 | 16 | /// 17 | /// argmax(a') Q(state,a') 18 | /// 价值表中获取该状态State下最高价值的action' 19 | /// 20 | /// 21 | /// 22 | public override Act GetPolicyAct(torch.Tensor state) 23 | { 24 | var action = QTable.GetBestAct(state); 25 | return action ?? GetSampleAct(); 26 | } 27 | 28 | public override void Save(string path) 29 | { 30 | /// Save Q Table 31 | } 32 | 33 | public override void Load(string path) 34 | { 35 | /// Save Q Table 36 | } 37 | } 38 | } -------------------------------------------------------------------------------- /src/DeepSharp.RL/Agents/ValueBased/DeepQN/CategoricalDQN.cs: -------------------------------------------------------------------------------- 1 | namespace DeepSharp.RL.Agents 2 | { 3 | internal class CategoricalDQN 4 | { 5 | } 6 | } -------------------------------------------------------------------------------- /src/DeepSharp.RL/Agents/ValueBased/DeepQN/DoubleDQN.cs: -------------------------------------------------------------------------------- 1 | namespace DeepSharp.RL.Agents 2 | { 3 | internal class DoubleDQN 4 | { 5 | } 6 | } 7 | -------------------------------------------------------------------------------- /src/DeepSharp.RL/Agents/ValueBased/DeepQN/DuelingDQN.cs: -------------------------------------------------------------------------------- 1 | namespace DeepSharp.RL.Agents 2 | { 3 | internal class DuelingDQN 4 | { 5 | } 6 | } -------------------------------------------------------------------------------- /src/DeepSharp.RL/Agents/ValueBased/DeepQN/NDQN.cs: -------------------------------------------------------------------------------- 1 | namespace DeepSharp.RL.Agents 2 | { 3 | internal class NDQN 4 | { 5 | } 6 | } 7 | -------------------------------------------------------------------------------- /src/DeepSharp.RL/Agents/ValueBased/DeepQN/NoisyDQN.cs: -------------------------------------------------------------------------------- 1 | namespace DeepSharp.RL.Agents 2 | { 3 | internal class NoisyDQN 4 | { 5 | } 6 | } -------------------------------------------------------------------------------- /src/DeepSharp.RL/Agents/ValueBased/DynamicPlan/PIDiscountR.cs: -------------------------------------------------------------------------------- 1 | using DeepSharp.RL.Environs; 2 | using DeepSharp.Utility; 3 | 4 | namespace DeepSharp.RL.Agents 5 | { 6 | public class PIDiscountR : PolicyIteration 7 
| { 8 | public PIDiscountR(Environ env, Dictionary p, 9 | Dictionary r, int t, float gamma = 0.9f) 10 | : base(env, p, r, t) 11 | { 12 | Gamma = gamma; 13 | } 14 | 15 | public float Gamma { protected set; get; } 16 | 17 | protected override VTable GetVTable(int t) 18 | { 19 | var vNext = new VTable(); 20 | 21 | foreach (var unused in Enumerable.Range(0, t)) 22 | foreach (var x in X) 23 | vNext[x] = RewardKeys 24 | .Where(a => a.State.Equals(x)) 25 | .Sum(r => P[r] * (R[r] + vNext[r.NewState] * Gamma)); 26 | return vNext; 27 | } 28 | 29 | protected override QTable GetQTable(VTable v, int t) 30 | { 31 | var q = new QTable(); 32 | 33 | var states = P.Keys 34 | .Select(a => a.State) 35 | .Distinct(new TensorEqualityCompare()) 36 | .ToArray(); 37 | 38 | var actions = P.Keys 39 | .Select(a => a.Act) 40 | .Distinct(new TensorEqualityCompare()) 41 | .ToArray(); 42 | 43 | foreach (var state in states) 44 | foreach (var action in actions) 45 | { 46 | var value = RewardKeys.Where(a => a.State.Equals(state) && a.Act.Equals(action)) 47 | .Sum(a => P[a] * (R[a] + v[a.NewState] * Gamma)); 48 | q[state, action] = value; 49 | } 50 | 51 | return q; 52 | } 53 | } 54 | } -------------------------------------------------------------------------------- /src/DeepSharp.RL/Agents/ValueBased/DynamicPlan/PITStepR.cs: -------------------------------------------------------------------------------- 1 | using DeepSharp.RL.Environs; 2 | using DeepSharp.Utility; 3 | 4 | namespace DeepSharp.RL.Agents 5 | { 6 | public class PITStepR : ValueIterate 7 | { 8 | public PITStepR(Environ env, Dictionary p, 9 | Dictionary r, int t = 100) : base(env, p, r, t) 10 | { 11 | } 12 | 13 | protected override VTable GetVTable(int t) 14 | { 15 | var vNext = new VTable(); 16 | 17 | foreach (var i in Enumerable.Range(1, t)) 18 | foreach (var x in X) 19 | vNext[x] = RewardKeys 20 | .Where(a => a.State.Equals(x)) 21 | .Sum(r => P[r] * (R[r] / i + vNext[r.NewState] * (i - 1) / i)); 22 | return vNext; 23 | } 24 | 25 | protected override QTable GetQTable(VTable v, int t) 26 | { 27 | var q = new QTable(); 28 | 29 | var states = P.Keys 30 | .Select(a => a.State) 31 | .Distinct(new TensorEqualityCompare()) 32 | .ToArray(); 33 | 34 | var actions = P.Keys 35 | .Select(a => a.Act) 36 | .Distinct(new TensorEqualityCompare()) 37 | .ToArray(); 38 | 39 | foreach (var state in states) 40 | foreach (var action in actions) 41 | { 42 | var value = RewardKeys.Where(a => a.State.Equals(state) && a.Act.Equals(action)) 43 | .Sum(a => P[a] * (R[a] / t + v[a.NewState] * (t - 1) / t)); 44 | q[state, action] = value; 45 | } 46 | 47 | return q; 48 | } 49 | } 50 | } -------------------------------------------------------------------------------- /src/DeepSharp.RL/Agents/ValueBased/DynamicPlan/PolicyIteration.cs: -------------------------------------------------------------------------------- 1 | using DeepSharp.RL.Environs; 2 | using DeepSharp.Utility; 3 | 4 | namespace DeepSharp.RL.Agents 5 | { 6 | public abstract class PolicyIteration : ValueAgent 7 | { 8 | /// 9 | /// 10 | /// 11 | /// 12 | /// 13 | /// 14 | /// 15 | protected PolicyIteration(Environ env, Dictionary p, 16 | Dictionary r, int t = 100) 17 | : base(env, "ValueIterate") 18 | { 19 | T = t; 20 | VTable = new VTable(); 21 | P = p; 22 | R = r; 23 | RewardKeys = p.Keys.ToArray(); 24 | X = P.Keys.Select(a => a.State) 25 | .Distinct(new TensorEqualityCompare()) 26 | .ToArray(); 27 | } 28 | 29 | public int T { protected set; get; } 30 | 31 | 32 | public VTable VTable { protected set; get; } 33 | 34 | protected 
Dictionary P { set; get; } 35 | protected Dictionary R { set; get; } 36 | protected torch.Tensor[] X { set; get; } 37 | protected RewardKey[] RewardKeys { set; get; } 38 | 39 | 40 | public override LearnOutcome Learn() 41 | { 42 | while (true) 43 | { 44 | /// Update VTable 45 | foreach (var t in Enumerable.Range(1, T)) 46 | { 47 | var vNext = GetVTable(t); 48 | VTable = vNext; 49 | } 50 | 51 | /// Policy Iterate 52 | /// Get Policy (argmax Q=> Update QTable) by Value 53 | var qTable = GetQTable(VTable, T); 54 | if (qTable == QTable) break; 55 | QTable = qTable; 56 | } 57 | 58 | return new LearnOutcome(); 59 | } 60 | 61 | protected abstract VTable GetVTable(int t); 62 | 63 | protected abstract QTable GetQTable(VTable vTable, int t); 64 | } 65 | } -------------------------------------------------------------------------------- /src/DeepSharp.RL/Agents/ValueBased/DynamicPlan/VIDiscountR.cs: -------------------------------------------------------------------------------- 1 | using DeepSharp.RL.Environs; 2 | using DeepSharp.Utility; 3 | 4 | namespace DeepSharp.RL.Agents 5 | { 6 | public class VIDiscountR : ValueIterate 7 | { 8 | public VIDiscountR(Environ env, Dictionary p, 9 | Dictionary r, int t, float gamma = 0.9f, float threshold = 0.1f) 10 | : base(env, p, r, t, threshold) 11 | { 12 | Gamma = gamma; 13 | } 14 | 15 | public float Gamma { protected set; get; } 16 | 17 | protected override VTable GetVTable(int t) 18 | { 19 | var vNext = new VTable(); 20 | 21 | foreach (var unused in Enumerable.Range(0, t)) 22 | foreach (var x in X) 23 | vNext[x] = RewardKeys 24 | .Where(a => a.State.Equals(x)) 25 | .Sum(r => P[r] * (R[r] + vNext[r.NewState] * Gamma)); 26 | return vNext; 27 | } 28 | 29 | protected override QTable GetQTable(VTable v, int t) 30 | { 31 | var q = new QTable(); 32 | 33 | var states = P.Keys 34 | .Select(a => a.State) 35 | .Distinct(new TensorEqualityCompare()) 36 | .ToArray(); 37 | 38 | var actions = P.Keys 39 | .Select(a => a.Act) 40 | .Distinct(new TensorEqualityCompare()) 41 | .ToArray(); 42 | 43 | foreach (var state in states) 44 | foreach (var action in actions) 45 | { 46 | var value = RewardKeys.Where(a => a.State.Equals(state) && a.Act.Equals(action)) 47 | .Sum(a => P[a] * (R[a] + v[a.NewState] * Gamma)); 48 | q[state, action] = value; 49 | } 50 | 51 | return q; 52 | } 53 | } 54 | } -------------------------------------------------------------------------------- /src/DeepSharp.RL/Agents/ValueBased/DynamicPlan/VITStepR.cs: -------------------------------------------------------------------------------- 1 | using DeepSharp.RL.Environs; 2 | using DeepSharp.Utility; 3 | 4 | namespace DeepSharp.RL.Agents 5 | { 6 | public class VITStepR : ValueIterate 7 | { 8 | public VITStepR(Environ env, Dictionary p, 9 | Dictionary r, int t = 100, float threshold = 0.1f) : base(env, p, r, t, threshold) 10 | { 11 | } 12 | 13 | protected override VTable GetVTable(int t) 14 | { 15 | var vNext = new VTable(); 16 | 17 | foreach (var i in Enumerable.Range(1, t)) 18 | foreach (var x in X) 19 | vNext[x] = RewardKeys 20 | .Where(a => a.State.Equals(x)) 21 | .Sum(r => P[r] * (R[r] / i + vNext[r.NewState] * (i - 1) / i)); 22 | return vNext; 23 | } 24 | 25 | protected override QTable GetQTable(VTable v, int t) 26 | { 27 | var q = new QTable(); 28 | 29 | var states = P.Keys 30 | .Select(a => a.State) 31 | .Distinct(new TensorEqualityCompare()) 32 | .ToArray(); 33 | 34 | var actions = P.Keys 35 | .Select(a => a.Act) 36 | .Distinct(new TensorEqualityCompare()) 37 | .ToArray(); 38 | 39 | foreach (var state in 
states) 40 | foreach (var action in actions) 41 | { 42 | var value = RewardKeys.Where(a => a.State.Equals(state) && a.Act.Equals(action)) 43 | .Sum(a => P[a] * (R[a] / t + v[a.NewState] * (t - 1) / t)); 44 | q[state, action] = value; 45 | } 46 | 47 | return q; 48 | } 49 | } 50 | } -------------------------------------------------------------------------------- /src/DeepSharp.RL/Agents/ValueBased/DynamicPlan/ValueIterate.cs: -------------------------------------------------------------------------------- 1 | using DeepSharp.RL.Environs; 2 | using DeepSharp.Utility; 3 | 4 | namespace DeepSharp.RL.Agents 5 | { 6 | public abstract class ValueIterate : ValueAgent 7 | { 8 | /// 9 | /// 10 | /// 11 | /// 12 | /// 13 | /// 14 | /// 15 | protected ValueIterate(Environ env, Dictionary p, 16 | Dictionary r, int t = 100, float threshold = 0.1f) 17 | : base(env, "ValueIterate") 18 | { 19 | T = t; 20 | Threshold = threshold; 21 | VTable = new VTable(); 22 | P = p; 23 | R = r; 24 | RewardKeys = p.Keys.ToArray(); 25 | X = P.Keys.Select(a => a.State) 26 | .Distinct(new TensorEqualityCompare()) 27 | .ToArray(); 28 | } 29 | 30 | public int T { protected set; get; } 31 | 32 | /// 33 | /// Convergence Threshold 34 | /// 35 | public float Threshold { protected set; get; } 36 | 37 | public VTable VTable { protected set; get; } 38 | 39 | protected Dictionary P { set; get; } 40 | protected Dictionary R { set; get; } 41 | protected torch.Tensor[] X { set; get; } 42 | protected RewardKey[] RewardKeys { set; get; } 43 | 44 | 45 | public override LearnOutcome Learn() 46 | { 47 | /// Value Iterate 48 | foreach (var t in Enumerable.Range(1, T)) 49 | { 50 | var vNext = GetVTable(t); 51 | if (vNext - VTable < Threshold) 52 | break; 53 | VTable = vNext; 54 | } 55 | 56 | /// Get Policy (argmax Q=> Update QTable) by Value 57 | var qTable = GetQTable(VTable, T); 58 | QTable = qTable; 59 | return new LearnOutcome(); 60 | } 61 | 62 | protected abstract VTable GetVTable(int t); 63 | 64 | protected abstract QTable GetQTable(VTable vTable, int t); 65 | } 66 | } -------------------------------------------------------------------------------- /src/DeepSharp.RL/Agents/ValueBased/MonteCarlo/MonteCarloOffPolicy.cs: -------------------------------------------------------------------------------- 1 | using DeepSharp.RL.Environs; 2 | 3 | namespace DeepSharp.RL.Agents 4 | { 5 | /// 6 | /// Todo has issue about GetTransitPer? 
2023/7/3
    /// </summary>
    public class MonteCarloOffPolicy : ValueAgent
    {
        public MonteCarloOffPolicy(Environ<Space, Space> env, float epsilon = 0.1f, int t = 10)
            : base(env, "MonteCarloOffPolicy")
        {
            Epsilon = epsilon;
            T = t;
            Count = new Dictionary<TransitKey, int>();
        }

        public int T { protected set; get; }

        public Dictionary<TransitKey, int> Count { protected set; get; }


        public override LearnOutcome Learn()
        {
            Environ.Reset();
            var episode = new Episode();
            var epoch = 0;
            var act = GetEpsilonAct(Environ.Observation!.Value!);
            while (Environ.IsComplete(epoch) == false && epoch < T)
            {
                epoch++;
                var step = Environ.Step(act, epoch);

                episode.Steps.Add(step);
                Environ.CallBack?.Invoke(step);
                Environ.Observation = step.PostState; /// It's important to update the Observation here
            }

            Update(episode);

            var sumReward = episode.Steps.Sum(a => a.Reward.Value);
            episode.SumReward = new Reward(sumReward);

            var learnOut = new LearnOutcome(episode);

            return learnOut;
        }


        public void Update(Episode episode)
        {
            var length = episode.Length;
            var steps = episode.Steps;
            foreach (var t in Enumerable.Range(0, length))
            {
                var step = steps[t];
                var key = new TransitKey(step.PreState, step.Action);
                var r = steps.Skip(t).Average(a => a.Reward.Value);
                var per = steps.Skip(t).Select(GetTransitPer).Aggregate(1f, (a, b) => a * b); /// Todo: suspected error here (see class summary)
                var finalR = r * per;
                var count = GetCount(key);
                QTable[key] = (QTable[key] * count + finalR) / (count + 1);
                SetCount(key, count + 1);
            }
        }


        private float GetTransitPer(Step step)
        {
            var actPolicy = GetPolicyAct(step.PreState.Value!).Value!;
            var actStep = step.Action.Value!;
            var actionSpace = Environ.ActionSpace!.N;
            var e = 1f;
            var per = actPolicy.Equals(actStep)
                ?
1 - Epsilon + Epsilon / actionSpace 76 | : Epsilon / actionSpace; 77 | return e / per; 78 | } 79 | 80 | private int GetCount(TransitKey transitKey) 81 | { 82 | Count.TryAdd(transitKey, 0); 83 | return Count[transitKey]; 84 | } 85 | 86 | private void SetCount(TransitKey transitKey, int value) 87 | { 88 | Count[transitKey] = value; 89 | } 90 | } 91 | } -------------------------------------------------------------------------------- /src/DeepSharp.RL/Agents/ValueBased/MonteCarlo/MonteCarloOnPolicy.cs: -------------------------------------------------------------------------------- 1 | using DeepSharp.RL.Environs; 2 | 3 | namespace DeepSharp.RL.Agents 4 | { 5 | /// 6 | /// Monte Carlo Method On Policy 7 | /// 8 | public class MonteCarloOnPolicy : ValueAgent 9 | { 10 | /// 11 | /// 12 | /// 13 | /// the leaning count of each epoch 14 | public MonteCarloOnPolicy(Environ env, float epsilon = 0.1f, int t = 10) 15 | : base(env, "MonteCarloOnPolicy") 16 | { 17 | Epsilon = epsilon; 18 | T = t; 19 | Count = new Dictionary(); 20 | } 21 | 22 | public int T { protected set; get; } 23 | 24 | public Dictionary Count { protected set; get; } 25 | 26 | 27 | public override LearnOutcome Learn() 28 | { 29 | Environ.Reset(); 30 | var episode = new Episode(); 31 | var epoch = 0; 32 | var act = GetEpsilonAct(Environ.Observation!.Value!); 33 | while (Environ.IsComplete(epoch) == false && epoch < T) 34 | { 35 | epoch++; 36 | var step = Environ.Step(act, epoch); 37 | 38 | episode.Steps.Add(step); 39 | 40 | Environ.CallBack?.Invoke(step); 41 | Environ.Observation = step.PostState; /// It's import for Update Observation 42 | } 43 | 44 | Update(episode); 45 | 46 | var sumReward = episode.Steps.Sum(a => a.Reward.Value); 47 | episode.SumReward = new Reward(sumReward); 48 | 49 | 50 | var learnOut = new LearnOutcome(episode); 51 | 52 | return learnOut; 53 | } 54 | 55 | 56 | public void Update(Episode episode) 57 | { 58 | var lenth = episode.Length; 59 | var steps = episode.Steps; 60 | foreach (var t in Enumerable.Range(0, lenth)) 61 | { 62 | var step = steps[t]; 63 | var key = new TransitKey(step.PreState, step.Action); 64 | var r = steps.Skip(t).Average(a => a.Reward.Value); 65 | var count = GetCount(key); 66 | QTable[key] = (QTable[key] * count + r) / (count + 1); 67 | SetCount(key, count + 1); 68 | } 69 | } 70 | 71 | private int GetCount(TransitKey transitKey) 72 | { 73 | Count.TryAdd(transitKey, 0); 74 | return Count[transitKey]; 75 | } 76 | 77 | private void SetCount(TransitKey transitKey, int value) 78 | { 79 | Count[transitKey] = value; 80 | } 81 | } 82 | } -------------------------------------------------------------------------------- /src/DeepSharp.RL/Agents/ValueBased/TemporalDifference/QLearning.cs: -------------------------------------------------------------------------------- 1 | using DeepSharp.RL.Environs; 2 | 3 | namespace DeepSharp.RL.Agents 4 | { 5 | public class QLearning : ValueAgent 6 | { 7 | /// 8 | /// 9 | /// 10 | /// epsilon of ε-greedy Policy 11 | /// learning rate 12 | /// rate of discount 13 | public QLearning(Environ env, 14 | float epsilon = 0.1f, 15 | float alpha = 0.2f, 16 | float gamma = 0.9f) : 17 | base(env, "QLearning") 18 | { 19 | Epsilon = epsilon; 20 | Alpha = alpha; 21 | Gamma = gamma; 22 | } 23 | 24 | 25 | public float Alpha { protected set; get; } 26 | public float Gamma { protected set; get; } 27 | 28 | 29 | public override Act GetPolicyAct(torch.Tensor state) 30 | { 31 | var action = QTable.GetBestAct(state); 32 | return action ?? 
GetSampleAct(); 33 | } 34 | 35 | 36 | public override LearnOutcome Learn() 37 | { 38 | Environ.Reset(); 39 | var episode = new Episode(); 40 | var epoch = 0; 41 | while (Environ.IsComplete(epoch) == false) 42 | { 43 | epoch++; 44 | var epsilonAct = GetEpsilonAct(Environ.Observation!.Value!); 45 | var step = Environ.Step(epsilonAct, epoch); 46 | 47 | Update(step); 48 | 49 | episode.Steps.Add(step); 50 | Environ.CallBack?.Invoke(step); 51 | 52 | Environ.Observation = step.PostState; /// It's import for Update Observation 53 | } 54 | 55 | var sumReward = episode.Steps.Sum(a => a.Reward.Value); 56 | episode.SumReward = new Reward(sumReward); 57 | 58 | return new LearnOutcome(episode); 59 | } 60 | 61 | public void Update(Step step) 62 | { 63 | var s = step.PreState.Value!; 64 | var a = step.Action.Value!; 65 | var r = step.Reward.Value; 66 | var sNext = step.PostState.Value!; 67 | var q = QTable[s, a]; 68 | 69 | var aNext = GetPolicyAct(sNext); 70 | var qNext = QTable[sNext, aNext.Value!]; 71 | 72 | QTable[s, a] = q + Alpha * (r + Gamma * qNext - q); 73 | } 74 | } 75 | } -------------------------------------------------------------------------------- /src/DeepSharp.RL/Agents/ValueBased/TemporalDifference/SARSA.cs: -------------------------------------------------------------------------------- 1 | using DeepSharp.RL.Environs; 2 | 3 | namespace DeepSharp.RL.Agents 4 | { 5 | public class SARSA : ValueAgent 6 | { 7 | /// 8 | /// 9 | /// 10 | /// epsilon of ε-greedy Policy 11 | /// learning rate 12 | /// rate of discount 13 | public SARSA(Environ env, 14 | float epsilon = 0.1f, 15 | float alpha = 0.2f, 16 | float gamma = 0.9f) : base(env, "SARSA") 17 | { 18 | Epsilon = epsilon; 19 | Alpha = alpha; 20 | Gamma = gamma; 21 | } 22 | 23 | public float Alpha { protected set; get; } 24 | public float Gamma { protected set; get; } 25 | 26 | public override LearnOutcome Learn() 27 | { 28 | Environ.Reset(); 29 | var episode = new Episode(); 30 | var epoch = 0; 31 | var act = GetEpsilonAct(Environ.Observation!.Value!); 32 | while (Environ.IsComplete(epoch) == false) 33 | { 34 | epoch++; 35 | var step = Environ.Step(act, epoch); 36 | 37 | var actNext = Update(step); /// 38 | 39 | episode.Steps.Add(step); 40 | Environ.CallBack?.Invoke(step); 41 | 42 | Environ.Observation = step.PostState; /// It's import for Update Observation 43 | act = actNext; 44 | } 45 | 46 | var sumReward = episode.Steps.Sum(a => a.Reward.Value); 47 | episode.SumReward = new Reward(sumReward); 48 | 49 | return new LearnOutcome(episode); 50 | } 51 | 52 | 53 | public Act Update(Step step) 54 | { 55 | var s = step.PreState.Value!; 56 | var a = step.Action.Value!; 57 | var r = step.Reward.Value; 58 | var sNext = step.PostState.Value!; 59 | var q = QTable[s, a]; 60 | 61 | var aNext = GetEpsilonAct(sNext); /// a' by ε-greedy policy 62 | var qNext = QTable[sNext, aNext.Value!]; 63 | 64 | 65 | QTable[s, a] = q + Alpha * (r + Gamma * qNext - q); 66 | return aNext; 67 | } 68 | } 69 | } -------------------------------------------------------------------------------- /src/DeepSharp.RL/DeepSharp.RL.csproj: -------------------------------------------------------------------------------- 1 |  2 | 3 | 4 | net7.0 5 | enable 6 | enable 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | -------------------------------------------------------------------------------- /src/DeepSharp.RL/DeepSharp.RL.csproj.DotSettings: 
-------------------------------------------------------------------------------- 1 |  2 | True 3 | True 4 | True 5 | True 6 | True 7 | True 8 | True 9 | True 10 | True 11 | True 12 | True 13 | True 14 | True -------------------------------------------------------------------------------- /src/DeepSharp.RL/Enumerates/PlayMode.cs: -------------------------------------------------------------------------------- 1 | namespace DeepSharp.RL.Enumerates 2 | { 3 | /// 4 | /// Play mode of Agent 5 | /// 6 | public enum PlayMode 7 | { 8 | /// 9 | /// 平均采样 10 | /// 11 | Sample, 12 | 13 | /// 14 | /// 根据智能体的策略 15 | /// 16 | Agent, 17 | 18 | /// 19 | /// Sample(ε) and Agent(1-ε) 20 | /// 21 | EpsilonGreedy 22 | } 23 | } -------------------------------------------------------------------------------- /src/DeepSharp.RL/Environs/Act.cs: -------------------------------------------------------------------------------- 1 | namespace DeepSharp.RL.Environs 2 | { 3 | /// 4 | /// 动作 5 | /// 6 | public class Act : IEqualityComparer 7 | { 8 | public Act(torch.Tensor? action) 9 | { 10 | Value = action; 11 | TimeStamp = DateTime.Now; 12 | } 13 | 14 | /// 15 | /// 奖励的张量格式 16 | /// 17 | public torch.Tensor? Value { set; get; } 18 | 19 | /// 20 | /// 奖励产生的时间戳 21 | /// 22 | public DateTime TimeStamp { set; get; } 23 | 24 | 25 | public bool Equals(Act? x, Act? y) 26 | { 27 | if (ReferenceEquals(x, y)) return true; 28 | if (ReferenceEquals(x, null)) return false; 29 | if (ReferenceEquals(y, null)) return false; 30 | return x.GetType() == y.GetType() && x.Value!.Equals(y.Value!); 31 | } 32 | 33 | public int GetHashCode(Act obj) 34 | { 35 | return HashCode.Combine(obj.TimeStamp, obj.Value); 36 | } 37 | 38 | public Act To(torch.Device device) 39 | { 40 | return new Act(Value!.to(device)); 41 | } 42 | 43 | public override string ToString() 44 | { 45 | return $"{TimeStamp}\t{Value!.ToString(torch.numpy)}"; 46 | } 47 | } 48 | } -------------------------------------------------------------------------------- /src/DeepSharp.RL/Environs/Environ.cs: -------------------------------------------------------------------------------- 1 | using System.Text; 2 | 3 | namespace DeepSharp.RL.Environs 4 | { 5 | /// 6 | /// 环境 7 | /// 提供观察 并给与奖励 8 | /// 9 | public abstract class Environ 10 | where T1 : Space 11 | where T2 : Space 12 | { 13 | public Action? CallBack; 14 | 15 | protected Environ(string name, DeviceType deviceType = DeviceType.CUDA) 16 | { 17 | Name = name; 18 | Device = new torch.Device(deviceType); 19 | Reward = new Reward(0); 20 | ObservationList = new List(); 21 | } 22 | 23 | protected Environ(string name) 24 | : this(name, DeviceType.CPU) 25 | { 26 | } 27 | 28 | 29 | public string Name { set; get; } 30 | 31 | public torch.Device Device { set; get; } 32 | public T1? ActionSpace { protected set; get; } 33 | public T2? ObservationSpace { protected set; get; } 34 | 35 | 36 | /// 37 | /// Observation Current 38 | /// 39 | public Observation? 
Observation { set; get; } 40 | 41 | /// 42 | /// Reward Current 43 | /// 44 | public Reward Reward { set; get; } 45 | 46 | /// 47 | /// Observation Temp List 48 | /// 49 | public List ObservationList { set; get; } 50 | 51 | public int Life => ObservationList.Count; 52 | 53 | 54 | /// 55 | /// 恢复初始 56 | /// 57 | public virtual Observation Reset() 58 | { 59 | Observation = new Observation(ObservationSpace!.Generate()); 60 | ObservationList = new List {Observation}; 61 | Reward = new Reward(0); 62 | return Observation; 63 | } 64 | 65 | /// 66 | /// 获取一个回合的最终奖励 67 | /// 68 | /// 69 | /// 70 | public abstract float GetReturn(Episode episode); 71 | 72 | public virtual Act SampleAct() 73 | { 74 | return new Act(ActionSpace!.Sample()); 75 | } 76 | 77 | /// 78 | /// Agent provide Act 79 | /// 80 | /// 81 | /// 82 | public virtual Step Step(Act act, int epoch) 83 | { 84 | var state = Observation!; 85 | var stateNew = Update(act); 86 | var reward = GetReward(stateNew); 87 | var complete = IsComplete(epoch); 88 | var step = new Step(state, act, stateNew, reward, complete); 89 | ObservationList.Add(stateNew); 90 | Observation = stateNew; 91 | return step; 92 | } 93 | 94 | 95 | /// 96 | /// Update Environ Observation according with one action from Agent 97 | /// 98 | /// Action from Policy 99 | /// new observation 100 | public abstract Observation Update(Act act); 101 | 102 | 103 | /// 104 | /// Cal Reward from Observation 105 | /// 从观察获取单步奖励的计算方法 106 | /// 107 | /// one observation 108 | /// one reward 109 | public abstract Reward GetReward(Observation observation); 110 | 111 | 112 | /// 113 | /// Check Environ is Complete 114 | /// 判断探索是否结束 115 | /// 116 | /// 117 | /// 118 | public abstract bool IsComplete(int epoch); 119 | 120 | 121 | public override string ToString() 122 | { 123 | var str = new StringBuilder(); 124 | str.AppendLine($"{Name}\tLife:{Life}"); 125 | str.AppendLine(new string('-', 30)); 126 | str.Append($"State:\t{Observation!.Value!.ToString(torch.numpy)}"); 127 | return str.ToString(); 128 | } 129 | } 130 | } -------------------------------------------------------------------------------- /src/DeepSharp.RL/Environs/Episode.cs: -------------------------------------------------------------------------------- 1 | using System.Text; 2 | 3 | namespace DeepSharp.RL.Environs 4 | { 5 | /// 6 | /// 片段 7 | /// 8 | public class Episode 9 | { 10 | public Episode() 11 | { 12 | Steps = new List(); 13 | SumReward = new Reward(0); 14 | DateTime = DateTime.Now; 15 | Evaluate = 0; 16 | } 17 | 18 | public Episode(List steps) 19 | { 20 | Steps = steps; 21 | SumReward = new Reward(0); 22 | DateTime = DateTime.Now; 23 | Evaluate = 0; 24 | } 25 | 26 | public List Steps { set; get; } 27 | public Step this[int i] => Steps[i]; 28 | 29 | public Reward SumReward { set; get; } 30 | public DateTime DateTime { set; get; } 31 | 32 | public bool IsComplete { set; get; } 33 | public int Length => Steps.Count; 34 | 35 | public float Evaluate { set; get; } 36 | 37 | public void Enqueue(Step step) 38 | { 39 | Steps.Add(step); 40 | } 41 | 42 | public int[] GetAction() 43 | { 44 | var actions = Steps 45 | .Select(a => a.Action.Value!.ToInt32()) 46 | .ToArray(); 47 | return actions; 48 | } 49 | 50 | public override string ToString() 51 | { 52 | var str = new StringBuilder(); 53 | str.AppendLine($"Test By Agent: Get Reward {SumReward}"); 54 | Steps.ForEach(s => 55 | { 56 | var state = s.PreState.Value!.ToString(torch.numpy); 57 | var action = s.Action.Value!.ToInt32(); 58 | var reward = s.Reward.Value; 59 | var line = 
$"{state},{action},{reward}"; 60 | str.AppendLine(line); 61 | }); 62 | return str.ToString(); 63 | } 64 | 65 | 66 | /// 67 | /// Get a Episode which each step's reward estimate to QValue (discount by Gamma) 68 | /// 69 | /// 70 | /// 71 | public Episode GetReturnEpisode(float gamma = 0.9f) 72 | { 73 | var stepsWithReturn = new List(); 74 | 75 | 76 | var sumR = 0f; 77 | var steps = Steps; 78 | steps.Reverse(); 79 | foreach (var s in steps) 80 | { 81 | sumR *= gamma; 82 | sumR += s.Reward.Value; 83 | 84 | var sNew = (Step) s.Clone(); 85 | sNew.Reward = new Reward(sumR); 86 | stepsWithReturn.Add(sNew); 87 | } 88 | 89 | steps.Reverse(); 90 | stepsWithReturn.Reverse(); 91 | var res = new Episode(stepsWithReturn); 92 | return res; 93 | } 94 | } 95 | } -------------------------------------------------------------------------------- /src/DeepSharp.RL/Environs/FrozenLake/LakeRole.cs: -------------------------------------------------------------------------------- 1 | namespace DeepSharp.RL.Environs 2 | { 3 | public enum LakeRole 4 | { 5 | Ice, 6 | Hole, 7 | Start, 8 | End 9 | } 10 | } -------------------------------------------------------------------------------- /src/DeepSharp.RL/Environs/FrozenLake/LakeUnit.cs: -------------------------------------------------------------------------------- 1 | namespace DeepSharp.RL.Environs 2 | { 3 | public class LakeUnit 4 | { 5 | /// 6 | /// 7 | /// 8 | /// 9 | public LakeUnit(int row, int column, int index) 10 | { 11 | Column = column; 12 | Row = row; 13 | Index = index; 14 | Role = LakeRole.Ice; 15 | } 16 | 17 | public int Index { set; get; } 18 | public int Row { set; get; } 19 | public int Column { set; get; } 20 | public LakeRole Role { set; get; } 21 | 22 | public override string ToString() 23 | { 24 | switch (Role) 25 | { 26 | case LakeRole.Ice: 27 | return "I"; 28 | case LakeRole.Hole: 29 | return "H"; 30 | case LakeRole.Start: 31 | return "S"; 32 | case LakeRole.End: 33 | return "E"; 34 | default: 35 | throw new ArgumentOutOfRangeException(); 36 | } 37 | } 38 | } 39 | } -------------------------------------------------------------------------------- /src/DeepSharp.RL/Environs/KArmedBandit/Bandit.cs: -------------------------------------------------------------------------------- 1 | using MathNet.Numerics.Random; 2 | 3 | namespace DeepSharp.RL.Environs 4 | { 5 | /// 6 | /// 简化的赌博机,以Prob 概率吐出一枚硬币 7 | /// 8 | public class Bandit 9 | { 10 | public Bandit(string name, double prob = 0.7) 11 | { 12 | Name = name; 13 | Prob = prob; 14 | RandomSource = new SystemRandomSource(); 15 | } 16 | 17 | protected SystemRandomSource RandomSource { set; get; } 18 | 19 | public double Prob { set; get; } 20 | 21 | public string Name { get; set; } 22 | 23 | 24 | public float Step() 25 | { 26 | var pro = RandomSource.NextDouble(); 27 | return pro <= Prob ? 
1 : 0;
        }

        public override string ToString()
        {
            return $"Bandit{Name}:\t{Prob:P}";
        }
    }
}
--------------------------------------------------------------------------------
/src/DeepSharp.RL/Environs/KArmedBandit/KArmedBandit.cs:
--------------------------------------------------------------------------------
using System.Text;
using DeepSharp.RL.Environs.Spaces;
using MathNet.Numerics.Random;

namespace DeepSharp.RL.Environs
{
    /// <summary>
    ///     A k-armed bandit; each bandit is created with a random payout probability in [0, 1].
    /// </summary>
    public sealed class KArmedBandit : Environ<Space, Space>
    {
        public KArmedBandit(int k, DeviceType deviceType = DeviceType.CPU)
            : base("KArmedBandit", deviceType)
        {
            bandits = new Bandit[k];
            ActionSpace = new Disperse(k, deviceType: deviceType);
            ObservationSpace = new Box(0, 1, new long[] {k}, deviceType);
            Create(k);
            Reset();
        }

        public KArmedBandit(double[] probs, DeviceType deviceType = DeviceType.CPU)
            : base("KArmedBandit", deviceType)
        {
            var k = probs.Length;
            bandits = new Bandit[k];
            ActionSpace = new Disperse(k, deviceType: deviceType);
            ObservationSpace = new Box(0, 1, new long[] {k}, deviceType);
            Create(probs);
            Reset();
        }

        private Bandit[] bandits { get; }
        public Bandit this[int k] => bandits[k];

        private void Create(int k)
        {
            var random = new SystemRandomSource();
            foreach (var i in Enumerable.Range(0, k))
                bandits[i] = new Bandit($"{i}", random.NextDouble());
        }

        private void Create(double[] probs)
        {
            foreach (var i in Enumerable.Range(0, probs.Length))
                bandits[i] = new Bandit($"{i}", probs[i]);
        }

        /// <summary>
        ///     In this environ the step reward is simply the number of coins the bandit pays out,
        ///     so no further conversion is needed.
        /// </summary>
        /// <param name="observation"></param>
        /// <returns></returns>
        public override Reward GetReward(Observation observation)
        {
            var sum = observation.Value!.to_type(torch.ScalarType.Float32)
                .sum()
                .item<float>();
            var reward = new Reward(sum);
            return reward;
        }

        /// <summary>
        ///     The cumulative reward received by the trajectory of an interaction process,
        ///     applied for evaluation.
        /// </summary>
        /// <param name="episode"></param>
        /// <returns></returns>
        public override float GetReturn(Episode episode)
        {
            return episode.Steps.Average(a => a.Reward.Value);
        }

        /// <summary>
        /// </summary>
        /// <param name="act">the action, here the index of the bandit the agent selected</param>
        /// <returns>an observation holding the number of coins (0 or 1) that bandit paid out this step</returns>
        public override Observation Update(Act act)
        {
            var obs = new float[ObservationSpace!.N];
            var index = act.Value!.ToInt64();
            var bandit = bandits[index];
            var value = bandit.Step();
            obs[index] = value;

            var obsTensor = torch.from_array(obs, torch.ScalarType.Float32).to(Device);
            return new Observation(obsTensor);
        }

        /// <summary>
        ///     The environ closes after 20 samples.
        /// </summary>
        /// <param name="epoch"></param>
        /// <returns></returns>
        public override bool IsComplete(int epoch)
        {
            return epoch >= 20;
        }

        public override string ToString()
        {
            var str = new StringBuilder();
            str.AppendLine(base.ToString());
            str.Append(string.Join("\r\n", bandits.Select(a => $"\t{a}")));
            return str.ToString();
        }
    }
}
--------------------------------------------------------------------------------
/src/DeepSharp.RL/Environs/Observation.cs:
--------------------------------------------------------------------------------
namespace DeepSharp.RL.Environs
{
    /// <summary>
    ///     Observation
    /// </summary>
    public class Observation
    {
        public Observation(torch.Tensor?
state) 9 | { 10 | Value = state; 11 | TimeStamp = DateTime.Now; 12 | } 13 | 14 | /// 15 | /// 观察的张量格式 16 | /// 17 | public torch.Tensor? Value { set; get; } 18 | 19 | /// 20 | /// 观察产生的时间戳 21 | /// 22 | public DateTime TimeStamp { set; get; } 23 | 24 | public Observation To(torch.Device device) 25 | { 26 | return new Observation(Value?.to(device)); 27 | } 28 | 29 | 30 | public object Clone() 31 | { 32 | return new Observation(Value) {TimeStamp = TimeStamp}; 33 | } 34 | 35 | public override string ToString() 36 | { 37 | return $"Observation\r\n{Value?.ToString(torch.numpy)}"; 38 | } 39 | } 40 | } -------------------------------------------------------------------------------- /src/DeepSharp.RL/Environs/Reward.cs: -------------------------------------------------------------------------------- 1 | namespace DeepSharp.RL.Environs 2 | { 3 | /// 4 | /// Reward 5 | /// 6 | public class Reward 7 | { 8 | public Reward(float value) 9 | { 10 | Value = value; 11 | TimeStamp = DateTime.Now; 12 | } 13 | 14 | /// 15 | /// reward 16 | /// 17 | public float Value { set; get; } 18 | 19 | /// 20 | /// TimeStamp of get reward 21 | /// 22 | public DateTime TimeStamp { set; get; } 23 | 24 | public override string ToString() 25 | { 26 | return $"Reward:\t{Value}"; 27 | } 28 | } 29 | } -------------------------------------------------------------------------------- /src/DeepSharp.RL/Environs/Space.cs: -------------------------------------------------------------------------------- 1 | using FluentAssertions; 2 | using MathNet.Numerics.Random; 3 | 4 | namespace DeepSharp.RL.Environs 5 | { 6 | /// 7 | /// 空间 动作空间和观察空间父类 8 | /// 9 | public abstract class Space : IDisposable 10 | { 11 | protected Space( 12 | long[] shape, 13 | torch.ScalarType type, 14 | DeviceType deviceType, 15 | long seed) 16 | { 17 | (Shape, Type, DeviceType) = (shape, type, deviceType); 18 | CheckInitParameter(shape, type); 19 | CheckType(); 20 | Generator = torch.random.manual_seed(new SystemRandomSource().NextInt64(0, 1000)); 21 | N = shape.Aggregate(1, (a, b) => (int) (a * b)); 22 | } 23 | 24 | public long N { get; init; } 25 | public long[] Shape { get; } 26 | public torch.ScalarType Type { get; } 27 | public DeviceType DeviceType { get; } 28 | internal torch.Generator Generator { get; } 29 | internal torch.Device Device => new(DeviceType); 30 | 31 | public void Dispose() 32 | { 33 | Generator.Dispose(); 34 | } 35 | 36 | /// 37 | /// Returns a sample from the space. 38 | /// 39 | /// 40 | public abstract torch.Tensor Sample(); 41 | 42 | public abstract void CheckType(); 43 | 44 | /// 45 | /// Generates a tensor whose shape and type are consistent with the space definition. 
46 | /// 47 | /// 48 | public virtual torch.Tensor Generate() 49 | { 50 | return torch.zeros(Shape, Type, Device); 51 | } 52 | 53 | 54 | public override string ToString() 55 | { 56 | return $"Space Type: {GetType().Name}\nShape: {Shape}\ndType: {Type} \nN:{N}"; 57 | } 58 | 59 | private static void CheckInitParameter(long[] shape, torch.ScalarType type) 60 | { 61 | shape.Should().NotBeNull(); 62 | shape.Length.Should().BeGreaterThan(0); 63 | } 64 | } 65 | } -------------------------------------------------------------------------------- /src/DeepSharp.RL/Environs/Spaces/Binary.cs: -------------------------------------------------------------------------------- 1 | namespace DeepSharp.RL.Environs.Spaces 2 | { 3 | /// 4 | /// It's Space only support [0,1] 5 | /// 6 | public class Binary : Disperse 7 | { 8 | public Binary(torch.ScalarType dtype = torch.ScalarType.Int32, 9 | DeviceType deviceType = DeviceType.CUDA, long seed = 1) 10 | : base(2, dtype, deviceType, seed) 11 | { 12 | N = 1; 13 | } 14 | } 15 | } -------------------------------------------------------------------------------- /src/DeepSharp.RL/Environs/Spaces/Box.cs: -------------------------------------------------------------------------------- 1 | using FluentAssertions; 2 | 3 | namespace DeepSharp.RL.Environs.Spaces 4 | { 5 | /// 6 | /// A T-dimensional box that contains every point in the action space. 7 | /// 8 | public class Box : DigitalSpace 9 | { 10 | public Box(float low, float high, long[] shape, DeviceType deviceType = DeviceType.CUDA, long seed = 1) : 11 | base(torch.full(shape, low), torch.full(shape, high), shape, torch.ScalarType.Float32, deviceType, seed) 12 | { 13 | CoculateBounded(); 14 | } 15 | 16 | public Box(double low, double high, long[] shape, DeviceType deviceType = DeviceType.CUDA, long seed = 1) : 17 | base(torch.full(shape, low), torch.full(shape, high), shape, torch.ScalarType.Float64, deviceType, seed) 18 | { 19 | CoculateBounded(); 20 | } 21 | 22 | public Box(long low, long high, long[] shape, DeviceType deviceType = DeviceType.CUDA, long seed = 1) : 23 | base(torch.full(shape, low), torch.full(shape, high), shape, torch.ScalarType.Int64, deviceType, seed) 24 | { 25 | CoculateBounded(); 26 | } 27 | 28 | public Box(int low, int high, long[] shape, DeviceType deviceType = DeviceType.CUDA, long seed = 1) : 29 | base(torch.full(shape, low, torch.ScalarType.Int32), torch.full(shape, high, torch.ScalarType.Int32), shape, 30 | torch.ScalarType.Int32, deviceType, seed) 31 | { 32 | CoculateBounded(); 33 | } 34 | 35 | public Box(short low, short high, long[] shape, DeviceType deviceType = DeviceType.CUDA, long seed = 1) : 36 | base(torch.full(shape, low), torch.full(shape, high), shape, torch.ScalarType.Int16, deviceType, seed) 37 | { 38 | CoculateBounded(); 39 | } 40 | 41 | public Box(byte low, byte high, long[] shape, DeviceType deviceType = DeviceType.CUDA, long seed = 1) : 42 | base(torch.full(shape, low), torch.full(shape, high), shape, torch.ScalarType.Byte, deviceType, seed) 43 | { 44 | CoculateBounded(); 45 | } 46 | 47 | public Box(torch.Tensor low, torch.Tensor high, long[] shape, torch.ScalarType type, 48 | DeviceType deviceType = DeviceType.CUDA, long seed = 1) : base(low, high, shape, type, deviceType, seed) 49 | { 50 | CoculateBounded(); 51 | } 52 | 53 | protected torch.Tensor BoundedBelow { get; private set; } = null!; 54 | protected torch.Tensor BoundedAbove { get; private set; } = null!; 55 | 56 | public override torch.Tensor Sample() 57 | { 58 | var unbounded = ~BoundedBelow & ~BoundedAbove; 
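// These four masks partition the box's coordinates by boundedness; each group is
// then sampled from a matching distribution below: Normal when unbounded on both
// sides, a shifted Exponential when bounded on one side only, and Uniform when
// bounded on both sides.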
59 | var uppBounded = ~BoundedBelow & BoundedAbove; 60 | var lowBounded = BoundedBelow & ~BoundedAbove; 61 | var bounded = BoundedBelow & BoundedAbove; 62 | 63 | 64 | var high = Type.ToString().StartsWith("F") ? High : High + 1; 65 | var sample = torch.empty(Shape, Type); 66 | 67 | sample[unbounded] = torch.distributions.Normal(torch.zeros(Shape, torch.ScalarType.Float32), 68 | torch.ones(Shape, torch.ScalarType.Float32)) 69 | .sample(1).reshape(Shape)[unbounded].to_type(Type); 70 | 71 | sample[lowBounded] = (Low + torch.distributions.Exponential(torch.ones(Shape, torch.ScalarType.Float32)) 72 | .sample(1) 73 | .reshape(Shape))[lowBounded].to_type(Type); 74 | 75 | sample[uppBounded] = 76 | (high - torch.distributions.Exponential(torch.ones(Shape, torch.ScalarType.Float32)).sample(1) 77 | .reshape(Shape))[uppBounded].to_type(Type); 78 | 79 | sample[bounded] = 80 | torch.distributions 81 | .Uniform(Low.to_type(torch.ScalarType.Float32), high.to_type(torch.ScalarType.Float32)).sample(1) 82 | .reshape(Shape)[bounded] 83 | .to_type(Type); 84 | 85 | return sample.to(Device); 86 | } 87 | 88 | public override void CheckType() 89 | { 90 | var acceptType = new[] 91 | { 92 | torch.ScalarType.Byte, 93 | torch.ScalarType.Int8, 94 | torch.ScalarType.Int16, 95 | torch.ScalarType.Int32, 96 | torch.ScalarType.Int64, 97 | torch.ScalarType.Float32, 98 | torch.ScalarType.Float64 99 | }; 100 | Type.Should().BeOneOf(acceptType, $"Disperse accept Type in {string.Join(",", acceptType)}"); 101 | } 102 | 103 | private void CoculateBounded() 104 | { 105 | Low.shape.Should().Equal(Shape); 106 | High.shape.Should().Equal(Shape); 107 | BoundedBelow = Low > torch.tensor(double.NegativeInfinity); 108 | BoundedAbove = High < torch.tensor(double.PositiveInfinity); 109 | } 110 | } 111 | } -------------------------------------------------------------------------------- /src/DeepSharp.RL/Environs/Spaces/DigitalSpace.cs: -------------------------------------------------------------------------------- 1 | using FluentAssertions; 2 | 3 | namespace DeepSharp.RL.Environs.Spaces 4 | { 5 | public abstract class DigitalSpace : Space 6 | { 7 | protected DigitalSpace( 8 | torch.Tensor low, 9 | torch.Tensor high, 10 | long[] shape, 11 | torch.ScalarType type, 12 | DeviceType deviceType = DeviceType.CUDA, 13 | long seed = 471) : base(shape, type, deviceType, seed) 14 | { 15 | CheckParameters(low, high); 16 | Low = low; 17 | High = high; 18 | } 19 | 20 | protected DigitalSpace( 21 | long low, 22 | long high, 23 | long[] shape, 24 | torch.ScalarType type, 25 | DeviceType deviceType = DeviceType.CUDA, 26 | long seed = 471) : base(shape, type, deviceType, seed) 27 | { 28 | CheckParameters(low, high); 29 | Low = torch.full(shape, low, type); 30 | High = torch.full(shape, high, type); 31 | } 32 | 33 | public torch.Tensor Low { get; } 34 | public torch.Tensor High { get; } 35 | 36 | 37 | /// 38 | /// Generates a tensor whose shape and type are consistent with the space definition. 
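///     Each element is initialised to Low, the space's lower bound.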
39 | /// 40 | /// 41 | public override torch.Tensor Generate() 42 | { 43 | return (torch.zeros(Shape, Type) + Low).to(Device); 44 | } 45 | 46 | private void CheckParameters(torch.Tensor low, torch.Tensor high) 47 | { 48 | low.Should().NotBeNull(); 49 | high.Should().NotBeNull(); 50 | torch.all(low < high).Equals(torch.tensor(true)).Should().Be(true); 51 | } 52 | 53 | 54 | public override void CheckType() 55 | { 56 | var acceptType = new[] 57 | { 58 | torch.ScalarType.Int8, 59 | torch.ScalarType.Int16, 60 | torch.ScalarType.Int32, 61 | torch.ScalarType.Int64 62 | }; 63 | Type.Should().BeOneOf(acceptType, $"Disperse accept Type in {string.Join(",", acceptType)}"); 64 | } 65 | } 66 | } -------------------------------------------------------------------------------- /src/DeepSharp.RL/Environs/Spaces/Disperse.cs: -------------------------------------------------------------------------------- 1 | namespace DeepSharp.RL.Environs.Spaces 2 | { 3 | /// 4 | /// 一维 离散的动作空间, 采样为编码的动作序号 5 | /// A list of possible actions, where each timestep only one of the actions can be used. 6 | /// 7 | public class Disperse : DigitalSpace 8 | { 9 | /// 10 | /// Disperse Start with 0 11 | /// Discrete(2) will sample {0, 1} 12 | /// Discrete(3) will sample from {0, 1, 2} 13 | /// 14 | /// 15 | /// 16 | /// 17 | /// 18 | public Disperse(long length, torch.ScalarType dtype = torch.ScalarType.Int64, 19 | DeviceType deviceType = DeviceType.CUDA, long seed = 1) 20 | : base(0, 0 + length - 1, new long[] {1}, dtype, deviceType, seed) 21 | { 22 | N = length; 23 | } 24 | 25 | public Disperse(long length, long start, torch.ScalarType dtype = torch.ScalarType.Int64, 26 | DeviceType deviceType = DeviceType.CUDA, long seed = 1) 27 | : base(start, start + length - 1, new long[] {1}, dtype, deviceType, seed) 28 | { 29 | N = length; 30 | } 31 | 32 | 33 | public override torch.Tensor Sample() 34 | { 35 | var device = new torch.Device(DeviceType); 36 | var low = Low.to_type(torch.ScalarType.Int64).item(); 37 | var high = (High + 1).to_type(torch.ScalarType.Int64).item(); 38 | 39 | var sample = torch.randint(low, high, Shape, device: device).to_type(Type); 40 | 41 | return sample; 42 | } 43 | } 44 | } -------------------------------------------------------------------------------- /src/DeepSharp.RL/Environs/Spaces/MultiBinary.cs: -------------------------------------------------------------------------------- 1 | namespace DeepSharp.RL.Environs.Spaces 2 | { 3 | /// 4 | /// A list of possible actions, where each timestep any of the actions can be used in any combination. 
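///     For example, a MultiBinary space of shape 3 can sample tensors such as [0, 1, 1].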
5 | /// 6 | public class MultiBinary : DigitalSpace 7 | { 8 | public MultiBinary(long[] shape, torch.ScalarType type = torch.ScalarType.Int32, 9 | DeviceType deviceType = DeviceType.CUDA, long seed = 471) : base(0, 1, shape, type, deviceType, seed) 10 | { 11 | } 12 | 13 | public MultiBinary(long shape, torch.ScalarType type = torch.ScalarType.Int32, 14 | DeviceType deviceType = DeviceType.CUDA, long seed = 471) : base(0, 1, new[] {shape}, type, deviceType, 15 | seed) 16 | { 17 | } 18 | 19 | 20 | public override torch.Tensor Sample() 21 | { 22 | var high = High + 1; 23 | 24 | var sample = torch.distributions.Uniform(Low.to_type(torch.ScalarType.Float32), 25 | high.to_type(torch.ScalarType.Float32), Generator) 26 | .sample(1) 27 | .reshape(Shape) 28 | .to_type(Type); 29 | 30 | return sample.to(Device); 31 | } 32 | } 33 | } -------------------------------------------------------------------------------- /src/DeepSharp.RL/Environs/Spaces/MultiDisperse.cs: -------------------------------------------------------------------------------- 1 | namespace DeepSharp.RL.Environs.Spaces 2 | { 3 | /// 4 | /// A list of possible actions, where each timestep only one action of each discrete set can be used. 5 | /// 6 | public class MultiDisperse : DigitalSpace 7 | { 8 | public MultiDisperse(torch.Tensor low, torch.Tensor high, long[] shape, torch.ScalarType type, 9 | DeviceType deviceType = DeviceType.CUDA, long seed = 1) 10 | : base(low, high, shape, type, deviceType, seed) 11 | { 12 | } 13 | 14 | public MultiDisperse(long low, long high, long[] shape, torch.ScalarType type, 15 | DeviceType deviceType = DeviceType.CUDA, long seed = 1) 16 | : base(low, high, shape, type, deviceType, seed) 17 | { 18 | } 19 | 20 | public override torch.Tensor Sample() 21 | { 22 | var high = High + 1; 23 | 24 | var sample = torch.distributions.Uniform(Low.to_type(torch.ScalarType.Float32), 25 | high.to_type(torch.ScalarType.Float32), Generator) 26 | .sample(1) 27 | .reshape(Shape) 28 | .to_type(Type); 29 | 30 | return sample.to(Device); 31 | } 32 | } 33 | } -------------------------------------------------------------------------------- /src/DeepSharp.RL/Environs/Step.cs: -------------------------------------------------------------------------------- 1 | namespace DeepSharp.RL.Environs 2 | { 3 | /// 4 | /// Step 5 | /// 6 | public class Step : ICloneable 7 | { 8 | public Step(Observation preState, 9 | Act action, 10 | Observation postState, 11 | Reward reward, 12 | bool isComplete = false, 13 | float priority = 1f) 14 | { 15 | PreState = preState; 16 | Action = action; 17 | Reward = reward; 18 | PostState = postState; 19 | IsComplete = isComplete; 20 | Priority = priority; 21 | } 22 | 23 | public Observation PreState { set; get; } 24 | 25 | /// 26 | /// 动作 27 | /// 28 | public Act Action { set; get; } 29 | 30 | /// 31 | /// 动作后的观察 32 | /// 33 | public Observation PostState { set; get; } 34 | 35 | /// 36 | /// 动作后的奖励 37 | /// 38 | public Reward Reward { set; get; } 39 | 40 | public bool IsComplete { set; get; } 41 | 42 | public float Priority { set; get; } 43 | 44 | 45 | public object Clone() 46 | { 47 | return new Step(PreState, Action, PostState, Reward, IsComplete); 48 | } 49 | } 50 | } -------------------------------------------------------------------------------- /src/DeepSharp.RL/Environs/Wappers/EnvironWarpper.cs: -------------------------------------------------------------------------------- 1 | namespace DeepSharp.RL.Environs.Wappers 2 | { 3 | public abstract class EnvironWarpper 4 | { 5 | protected 
EnvironWarpper(Environ environ) 6 | { 7 | Environ = environ; 8 | } 9 | 10 | public Environ Environ { set; get; } 11 | 12 | public abstract Step Step(Act act, int epoch); 13 | 14 | public abstract Observation Reset(); 15 | } 16 | } -------------------------------------------------------------------------------- /src/DeepSharp.RL/Environs/Wappers/MaxAndSkipEnv.cs: -------------------------------------------------------------------------------- 1 | namespace DeepSharp.RL.Environs.Wappers 2 | { 3 | public abstract class MaxAndSkipEnv : EnvironWarpper 4 | { 5 | public int Skip { protected set; get; } 6 | public Queue<Observation> Observations { protected set; get; } 7 | 8 | protected MaxAndSkipEnv(Environ environ, int skip) 9 | : base(environ) 10 | { 11 | Skip = skip; 12 | Observations = new Queue<Observation>(2); 13 | } 14 | 15 | 16 | public override Step Step(Act act, int epoch) 17 | { 18 | var totalReward = 0f; 19 | var isComplete = false; 20 | var oldobs = Environ.Observation!; 21 | foreach (var _ in Enumerable.Range(0, Skip)) 22 | { 23 | var step = Environ.Step(act, epoch); 24 | Observations.Enqueue(step.PostState); 25 | totalReward += step.Reward.Value; 26 | if (step.IsComplete) 27 | { 28 | isComplete = true; 29 | break; 30 | } 31 | } 32 | 33 | var obs = Observations.Select(a => a.Value!).ToList(); 34 | var max = new Observation(torch.max(torch.vstack(obs))); 35 | var reward = new Reward(totalReward); 36 | return new Step(oldobs, act, max, reward, isComplete); 37 | } 38 | 39 | public override Observation Reset() 40 | { 41 | Observations.Clear(); 42 | var obs = Environ.Reset(); 43 | Observations.Enqueue(obs); 44 | return obs; 45 | } 46 | } 47 | } 48 | -------------------------------------------------------------------------------- /src/DeepSharp.RL/ExpReplays/EpisodeExpReplay.cs: -------------------------------------------------------------------------------- 1 | using DeepSharp.RL.Environs; 2 | using DeepSharp.RL.ExperienceSources; 3 | 4 | namespace DeepSharp.RL.ExpReplays 5 | { 6 | /// 7 | /// Experience Replay buffer that stores Episodes 8 | /// 9 | public class EpisodeExpReplay 10 | { 11 | public EpisodeExpReplay(int capacity, float gamma) 12 | { 13 | Capacity = capacity; 14 | Gamma = gamma; 15 | Buffers = new Queue<Episode>(capacity); 16 | } 17 | 18 | /// 19 | /// Capacity of Experience Replay Buffer 20 | /// 21 | public int Capacity { protected set; get; } 22 | 23 | /// 24 | /// Discount factor used to compute episode returns 25 | /// 26 | public float Gamma { protected set; get; } 27 | 28 | /// 29 | /// Cache 30 | /// 31 | public Queue<Episode> Buffers { set; get; } 32 | 33 | public int Size => Buffers.Sum(a => a.Length); 34 | 35 | public void Enqueue(Episode episode, bool isU = true) 36 | { 37 | if (Buffers.Count == Capacity) Buffers.Dequeue(); 38 | if (isU) 39 | { 40 | var e = episode.GetReturnEpisode(Gamma); 41 | Buffers.Enqueue(e); 42 | } 43 | else 44 | { 45 | Buffers.Enqueue(episode); 46 | } 47 | } 48 | 49 | public virtual ExperienceCase All() 50 | { 51 | var episodes = Buffers; 52 | var batchStep = episodes.SelectMany(a => a.Steps).ToArray(); 53 | 54 | // Get arrays from the steps 55 | var stateArray = batchStep.Select(a => a.PreState.Value!.unsqueeze(0)).ToArray(); 56 | var actArray = batchStep.Select(a => a.Action.Value!.unsqueeze(0)).ToArray(); 57 | var rewardArray = batchStep.Select(a => a.Reward.Value).ToArray(); 58 | var stateNextArray = batchStep.Select(a => a.PostState.Value!.unsqueeze(0)).ToArray(); 59 | var doneArray = batchStep.Select(a => a.IsComplete).ToArray(); 60 | 61 | // Convert to vstack 62 | var state =
torch.vstack(stateArray); 63 | var actionV = torch.vstack(actArray).to(torch.ScalarType.Int64); 64 | var reward = torch.from_array(rewardArray).view(-1, 1); 65 | var stateNext = torch.vstack(stateNextArray); 66 | var done = torch.from_array(doneArray).reshape(Size); 67 | 68 | var excase = new ExperienceCase(state, actionV, reward, stateNext, done); 69 | return excase; 70 | } 71 | 72 | public void Clear() 73 | { 74 | Buffers.Clear(); 75 | } 76 | } 77 | } -------------------------------------------------------------------------------- /src/DeepSharp.RL/ExpReplays/ExpReplay.cs: -------------------------------------------------------------------------------- 1 | using DeepSharp.RL.Environs; 2 | using DeepSharp.RL.ExperienceSources; 3 | 4 | namespace DeepSharp.RL.ExpReplays 5 | { 6 | /// 7 | /// Experience Replay buffer that stores Steps 8 | /// 9 | public abstract class ExpReplay 10 | { 11 | protected ExpReplay(int capacity = 10000) 12 | { 13 | Capacity = capacity; 14 | Buffers = new Queue<Step>(capacity); 15 | } 16 | 17 | /// 18 | /// Capacity of Experience Replay Buffer 19 | /// 20 | public int Capacity { protected set; get; } 21 | 22 | /// 23 | /// Cache 24 | /// 25 | public Queue<Step> Buffers { set; get; } 26 | 27 | public int Size => Buffers.Count(); 28 | 29 | /// 30 | /// Record a step [State, Action, Reward, NewState] 31 | /// 32 | /// 33 | public virtual void Enqueue(Step step) 34 | { 35 | if (Buffers.Count == Capacity) Buffers.Dequeue(); 36 | Buffers.Enqueue(step); 37 | } 38 | 39 | /// 40 | /// Record steps {[State, Action, Reward, NewState],...,[State, Action, Reward, NewState]} 41 | /// 42 | public void Enqueue(IEnumerable<Step> steps) 43 | { 44 | steps.ToList().ForEach(Enqueue); 45 | } 46 | 47 | protected abstract Step[] SampleSteps(int batchsize); 48 | 49 | public virtual void Enqueue(Episode episode) 50 | { 51 | Enqueue(episode.Steps); 52 | } 53 | 54 | public virtual ExperienceCase Sample(int batchsize) 55 | { 56 | var batchStep = SampleSteps(batchsize); 57 | 58 | // Get arrays from the steps 59 | var stateArray = batchStep.Select(a => a.PreState.Value!.unsqueeze(0)).ToArray(); 60 | var actArray = batchStep.Select(a => a.Action.Value!.unsqueeze(0)).ToArray(); 61 | var rewardArray = batchStep.Select(a => a.Reward.Value).ToArray(); 62 | var stateNextArray = batchStep.Select(a => a.PostState.Value!.unsqueeze(0)).ToArray(); 63 | var doneArray = batchStep.Select(a => a.IsComplete).ToArray(); 64 | 65 | // Convert to vstack 66 | var state = torch.vstack(stateArray); 67 | var actionV = torch.vstack(actArray).to(torch.ScalarType.Int64); 68 | var reward = torch.from_array(rewardArray).reshape(batchsize); 69 | var stateNext = torch.vstack(stateNextArray); 70 | var done = torch.from_array(doneArray).reshape(batchsize); 71 | 72 | var excase = new ExperienceCase(state, actionV, reward, stateNext, done); 73 | return excase; 74 | } 75 | 76 | 77 | public void Clear() 78 | { 79 | Buffers.Clear(); 80 | } 81 | } 82 | } -------------------------------------------------------------------------------- /src/DeepSharp.RL/ExpReplays/ExperienceCase.cs: -------------------------------------------------------------------------------- 1 | namespace DeepSharp.RL.ExperienceSources 2 | { 3 | /// 4 | /// 5 | public struct ExperienceCase 6 | { 7 | public ExperienceCase(torch.Tensor preState, 8 | torch.Tensor action, 9 | torch.Tensor reward, 10 | torch.Tensor postState, 11 | torch.Tensor done) 12 | { 13 | PreState = preState; 14 | Action = action; 15 | Reward = reward; 16 | PostState = postState; 17 | Done = done; 18 | } 19 | 20 | /// 21 |
State before action 22 | /// 23 | public torch.Tensor PreState { get; set; } 24 | 25 | /// 26 | /// Action 27 | /// 28 | public torch.Tensor Action { set; get; } 29 | 30 | /// 31 | /// Reward 32 | /// 33 | public torch.Tensor Reward { set; get; } 34 | 35 | /// 36 | /// State after action 37 | /// 38 | public torch.Tensor PostState { set; get; } 39 | 40 | /// 41 | /// Episode is complete? 42 | /// 43 | public torch.Tensor Done { set; get; } 44 | } 45 | } -------------------------------------------------------------------------------- /src/DeepSharp.RL/ExpReplays/PrioritizedExpReplay.cs: -------------------------------------------------------------------------------- 1 | using DeepSharp.RL.Environs; 2 | 3 | namespace DeepSharp.RL.ExpReplays 4 | { 5 | /// 6 | /// Prioritized sample from Experience Source Cache 7 | /// 8 | public class PrioritizedExpReplay : ExpReplay 9 | { 10 | /// 11 | /// 12 | /// Capacity of Experience Replay Buffer, recommend 10^5 ~ 10^6 13 | public PrioritizedExpReplay(int capacity = 10000) 14 | : base(capacity) 15 | { 16 | } 17 | 18 | 19 | /// 20 | /// Sample batch size steps from Queue, weighted by each step's Priority 21 | /// 22 | /// batch size 23 | /// 24 | protected override Step[] SampleSteps(int batchsize) 25 | { 26 | var probs = torch.from_array(Buffers.Select(a => a.Priority).ToArray()); 27 | var randomIndex = torch.multinomial(probs, batchsize).data<long>().ToArray(); 28 | 29 | var steps = randomIndex 30 | .AsParallel() 31 | .Select(i => Buffers.ElementAt((int) i)) 32 | .ToArray(); 33 | 34 | return steps; 35 | } 36 | } 37 | } -------------------------------------------------------------------------------- /src/DeepSharp.RL/ExpReplays/UniformExpReplay.cs: -------------------------------------------------------------------------------- 1 | using DeepSharp.RL.Environs; 2 | 3 | namespace DeepSharp.RL.ExpReplays 4 | { 5 | /// 6 | /// Uniform sample from Experience Source Cache 7 | /// 8 | public class UniformExpReplay : ExpReplay 9 | { 10 | /// 11 | /// 12 | /// Capacity of Experience Replay Buffer, recommend 10^5 ~ 10^6 13 | public UniformExpReplay(int capacity = 10000) 14 | : base(capacity) 15 | { 16 | } 17 | 18 | 19 | /// 20 | /// Uniform sample batch size steps from Queue 21 | /// 22 | /// batch size 23 | /// 24 | protected override Step[] SampleSteps(int batchsize) 25 | { 26 | var randomIndex = torch.randint(0, Size, new[] {batchsize}).data<long>().ToArray(); 27 | 28 | var steps = randomIndex 29 | .Select(i => Buffers.ElementAt((int) i)) 30 | .ToArray(); 31 | 32 | return steps; 33 | } 34 | } 35 | } -------------------------------------------------------------------------------- /src/DeepSharp.RL/Trainers/RLTrainOption.cs: -------------------------------------------------------------------------------- 1 | namespace DeepSharp.RL.Trainers 2 | { 3 | public struct RLTrainOption 4 | { 5 | public RLTrainOption() 6 | { 7 | TrainEpoch = 0; 8 | StopReward = 0; 9 | ValEpisode = 0; 10 | ValInterval = 0; 11 | SaveFolder = string.Empty; 12 | OutTimeSpan = TimeSpan.FromHours(1); 13 | AutoSave = false; 14 | } 15 | 16 | public float StopReward { set; get; } 17 | public int TrainEpoch { set; get; } 18 | public int ValInterval { set; get; } 19 | public int ValEpisode { set; get; } 20 | public string SaveFolder { set; get; } 21 | public TimeSpan OutTimeSpan { set; get; } 22 | public bool AutoSave { set; get; } 23 | } 24 | } -------------------------------------------------------------------------------- /src/DeepSharp.RL/Trainers/RLTrainer.cs:
-------------------------------------------------------------------------------- 1 | using DeepSharp.RL.Agents; 2 | using DeepSharp.RL.Environs; 3 | 4 | namespace DeepSharp.RL.Trainers 5 | { 6 | public class RLTrainer 7 | { 8 | private TrainerCallBack? callback; 9 | 10 | public RLTrainer(Agent agent) 11 | { 12 | Agent = agent; 13 | } 14 | 15 | public RLTrainer(Agent agent, Action<object> print) 16 | { 17 | Agent = agent; 18 | Print = print; 19 | } 20 | 21 | public Agent Agent { set; get; } 22 | 23 | public TrainerCallBack? Callback 24 | { 25 | set 26 | { 27 | callback = value; 28 | if (callback != null) 29 | callback.RlTrainer = this; 30 | } 31 | get => callback; 32 | } 33 | 34 | public Action<object>? Print { set; get; } 35 | 36 | 37 | public virtual void Train( 38 | float preReward, 39 | int trainEpoch, 40 | string saveFolder = "", 41 | int testEpisodes = -1, 42 | int testInterval = 5, 43 | bool autoSave = false) 44 | { 45 | OnTrainStart(); 46 | 47 | var valEpoch = 0; 48 | foreach (var epoch in Enumerable.Range(1, trainEpoch)) 49 | { 50 | OnLearnStart(epoch); 51 | var outcome = Agent.Learn(); 52 | OnLearnEnd(epoch, outcome); 53 | 54 | 55 | if (testEpisodes <= 0) 56 | continue; 57 | 58 | if (epoch % testInterval == 0) 59 | { 60 | valEpoch++; 61 | OnValStart(valEpoch); 62 | var episodes = Agent.RunEpisodes(testEpisodes); 63 | OnValStop(valEpoch, episodes); 64 | 65 | var valReward = episodes.Average(e => e.SumReward.Value); 66 | 67 | if (valReward < preReward) 68 | continue; 69 | 70 | // val reward has reached pre reward: 71 | // save (when autoSave is set) and break from training 72 | if (autoSave) 73 | { 74 | OnSaveStart(); 75 | Agent.Save(Path.Combine(saveFolder, $"[{Agent}]_{epoch}_{valReward:F2}.st")); 76 | OnSaveEnd(); 77 | } 78 | 79 | break; 80 | } 81 | } 82 | 83 | OnTrainEnd(); 84 | } 85 | 86 | 87 | public virtual void Train(RLTrainOption tp) 88 | { 89 | Train(tp.StopReward, tp.TrainEpoch, tp.SaveFolder, tp.ValEpisode, tp.ValInterval, tp.AutoSave); 90 | } 91 | 92 | public virtual void Val(int valEpoch) 93 | { 94 | var episodes = Agent.RunEpisodes(valEpoch); 95 | var aveReward = episodes.Average(a => a.SumReward.Value); 96 | Print?.Invoke($"[Val] {valEpoch:D5}\tR:[{aveReward}]"); 97 | foreach (var episode in episodes) Print?.Invoke(episode); 98 | } 99 | 100 | 101 | 102 | 103 | #region Callbacks 104 | 105 | protected virtual void OnTrainStart() 106 | { 107 | Print?.Invoke($"[{Agent}] start training."); 108 | Callback?.OnTrainStart(); 109 | } 110 | 111 | protected virtual void OnTrainEnd() 112 | { 113 | Print?.Invoke($"[{Agent}] stop training."); 114 | Callback?.OnTrainEnd(); 115 | } 116 | 117 | 118 | protected virtual void OnLearnStart(int epoch) 119 | { 120 | Callback?.OnLearnStart(epoch); 121 | } 122 | 123 | 124 | protected virtual void OnLearnEnd(int epoch, LearnOutcome outcome) 125 | { 126 | Print?.Invoke($"[Tra]\t{epoch:D5}\t{outcome}"); 127 | Callback?.OnLearnEnd(epoch, outcome); 128 | } 129 | 130 | protected virtual void OnValStart(int epoch) 131 | { 132 | Callback?.OnValStart(epoch); 133 | } 134 | 135 | protected virtual void OnValStop(int epoch, Episode[] episodes) 136 | { 137 | var aveReward = episodes.Average(a => a.SumReward.Value); 138 | Print?.Invoke($"[Val]\t{epoch:D5}\tE:{episodes.Length}:\tR:{aveReward:F4}"); 139 | Callback?.OnValEnd(epoch, episodes); 140 | } 141 | 142 | protected virtual void OnSaveStart() 143 | { 144 | Callback?.OnSaveStart(); 145 | } 146 | 147 | protected virtual void OnSaveEnd() 148 | { 149 | Callback?.OnSaveEnd(); 150 | } 151 | 152 | #endregion 153 | } 154 | }
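A minimal usage sketch for the trainer above, assuming the QLearning agent and KArmedBandit environ that appear in this repository's tests (the reward threshold and epoch counts are illustrative, not prescriptive):

// Usage sketch: wire an agent into RLTrainer and drive it with RLTrainOption.
// QLearning and KArmedBandit are the repo's own types; the numbers are made up.
var kArmedBandit = new KArmedBandit(new[] {0.4, 0.85, 0.75, 0.75});
var agent = new QLearning(kArmedBandit);
var trainer = new RLTrainer(agent, Console.WriteLine);

var option = new RLTrainOption
{
    StopReward = 17f,  // stop once average validation reward reaches this value
    TrainEpoch = 500,  // upper bound on Agent.Learn() calls
    ValInterval = 2,   // validate every 2nd learn epoch
    ValEpisode = 20,   // episodes averaged per validation run
    SaveFolder = "",   // only used when AutoSave is true
    AutoSave = false
};
trainer.Train(option);

Note that Train(RLTrainOption) simply forwards the option fields to the positional overload; OutTimeSpan is declared on the option struct but is not consumed by this trainer.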
-------------------------------------------------------------------------------- /src/DeepSharp.RL/Trainers/TrainerCallBack.cs: -------------------------------------------------------------------------------- 1 | using DeepSharp.RL.Agents; 2 | using DeepSharp.RL.Environs; 3 | 4 | namespace DeepSharp.RL.Trainers 5 | { 6 | public abstract class TrainerCallBack 7 | { 8 | public RLTrainer RlTrainer { set; get; } = null!; 9 | 10 | public abstract void OnTrainStart(); 11 | public abstract void OnTrainEnd(); 12 | public abstract void OnLearnStart(int epoch); 13 | public abstract void OnLearnEnd(int epoch, LearnOutcome outcome); 14 | public abstract void OnValStart(int epoch); 15 | public abstract void OnValEnd(int epoch, Episode[] episodes); 16 | public abstract void OnSaveStart(); 17 | public abstract void OnSaveEnd(); 18 | } 19 | } -------------------------------------------------------------------------------- /src/DeepSharp.RL/Usings.cs: -------------------------------------------------------------------------------- 1 | global using TorchSharp; 2 | global using OpenCvSharp; 3 | global using static TorchSharp.torch.nn; -------------------------------------------------------------------------------- /src/DeepSharp.Utility/Converters/Convert.cs: -------------------------------------------------------------------------------- 1 | using DeepSharp.Utility.Operations; 2 | 3 | namespace DeepSharp.Utility.Converters 4 | { 5 | public class Convert 6 | { 7 | /// 8 | /// Convert Mat to Tensor 9 | /// 10 | /// 11 | /// 12 | public static torch.Tensor ToTensor(Mat mat) 13 | { 14 | var dims = OpMat.GetDims(mat); 15 | mat.GetArray(out float[] d); 16 | var original = torch.from_array(d); 17 | var final = original.reshape(dims); 18 | return final; 19 | } 20 | 21 | 22 | /// 23 | /// Convert a Tensor to an Array 24 | /// 25 | /// 26 | /// 27 | /// 28 | public static Array ToArray(torch.Tensor tensor) 29 | { 30 | switch (tensor.dtype) 31 | { 32 | case torch.ScalarType.Byte: 33 | return tensor.data<byte>().ToNDArray(); 34 | 35 | case torch.ScalarType.Int8: 36 | return tensor.data<sbyte>().ToNDArray(); 37 | 38 | case torch.ScalarType.Int16: 39 | return tensor.data<short>().ToNDArray(); 40 | 41 | case torch.ScalarType.Int32: 42 | return tensor.data<int>().ToNDArray(); 43 | 44 | case torch.ScalarType.Int64: 45 | return tensor.data<long>().ToNDArray(); 46 | 47 | case torch.ScalarType.Float16: 48 | case torch.ScalarType.BFloat16: 49 | case torch.ScalarType.Float32: 50 | return tensor.data<float>().ToNDArray(); 51 | case torch.ScalarType.Float64: 52 | return tensor.data<double>().ToNDArray(); 53 | 54 | case torch.ScalarType.Bool: 55 | return tensor.data<bool>().ToNDArray(); 56 | 57 | default: 58 | throw new ArgumentOutOfRangeException(); 59 | } 60 | } 61 | } 62 | } -------------------------------------------------------------------------------- /src/DeepSharp.Utility/DeepSharp.Utility.csproj: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | net7.0 5 | enable 6 | enable 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | -------------------------------------------------------------------------------- /src/DeepSharp.Utility/Operations/OpMat.cs: -------------------------------------------------------------------------------- 1 | namespace DeepSharp.Utility.Operations 2 | { 3 | public class OpMat 4 | { 5 | /// 6 | /// Get Dims of Mat 7 | /// 8 | /// 9 | /// 10 | public static long[] GetDims(Mat mat) 11 | { 12 | mat.GetArray(out float[] d); 13 | var dims = Enumerable.Range(0, mat.Dims) 14 | .Select(a => (long) mat.Size(a)) 15
| .ToArray(); 16 | return dims; 17 | } 18 | } 19 | } -------------------------------------------------------------------------------- /src/DeepSharp.Utility/TensorEqualityCompare.cs: -------------------------------------------------------------------------------- 1 | namespace DeepSharp.Utility 2 | { 3 | public class TensorEqualityCompare : IEqualityComparer 4 | { 5 | public bool Equals(torch.Tensor? x, torch.Tensor? y) 6 | { 7 | return x!.Equals(y!); 8 | } 9 | 10 | public int GetHashCode(torch.Tensor obj) 11 | { 12 | return -1; 13 | } 14 | } 15 | } 16 | -------------------------------------------------------------------------------- /src/DeepSharp.Utility/Usings.cs: -------------------------------------------------------------------------------- 1 | global using TorchSharp; 2 | global using OpenCvSharp; 3 | global using static TorchSharp.torch.nn; -------------------------------------------------------------------------------- /src/DeepSharp.sln: -------------------------------------------------------------------------------- 1 |  2 | Microsoft Visual Studio Solution File, Format Version 12.00 3 | # Visual Studio Version 17 4 | VisualStudioVersion = 17.5.33530.505 5 | MinimumVisualStudioVersion = 10.0.40219.1 6 | Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "TorchSharpTest", "TorchSharpTest\TorchSharpTest.csproj", "{77C48BA9-1604-4EE0-BCBE-DC9B6DC238A8}" 7 | EndProject 8 | Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "Documents", "Documents", "{8919C8FE-5A2A-477B-A1A3-832F4EAC36A5}" 9 | ProjectSection(SolutionItems) = preProject 10 | ..\.gitignore = ..\.gitignore 11 | ..\images\DQN2013.png = ..\images\DQN2013.png 12 | ..\images\DQN2015.png = ..\images\DQN2015.png 13 | ..\images\DQN2015_.png = ..\images\DQN2015_.png 14 | ..\images\QLearning.png = ..\images\QLearning.png 15 | ..\README.md = ..\README.md 16 | Reinforcement Learning.md = Reinforcement Learning.md 17 | ..\images\Reinforcement.png = ..\images\Reinforcement.png 18 | ..\images\RL CrossEntroy Demo.png = ..\images\RL CrossEntroy Demo.png 19 | ..\images\Sarsa.png = ..\images\Sarsa.png 20 | EndProjectSection 21 | EndProject 22 | Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "DeepSharp.Utility", "DeepSharp.Utility\DeepSharp.Utility.csproj", "{93F876FF-1723-404F-9AA3-CD1D0873A3AB}" 23 | EndProject 24 | Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "DeepSharp.RL", "DeepSharp.RL\DeepSharp.RL.csproj", "{8E1D7BEF-3F6F-4237-A48D-3DA57757EB46}" 25 | EndProject 26 | Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "RLConsole", "RLConsole\RLConsole.csproj", "{96E80CDB-2323-44BD-8D48-225313319E35}" 27 | EndProject 28 | Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "DeepSharp.Dataset", "DeepSharp.Dataset\DeepSharp.Dataset.csproj", "{03E2BC70-5E53-4170-BC46-936C6A6BF964}" 29 | EndProject 30 | Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "Resources", "Resources", "{44AFB17A-FF8C-4E9A-84EB-FE26F4D188A3}" 31 | ProjectSection(SolutionItems) = preProject 32 | ..\resources\iris-test.txt = ..\resources\iris-test.txt 33 | ..\resources\iris-train.txt = ..\resources\iris-train.txt 34 | EndProjectSection 35 | EndProject 36 | Global 37 | GlobalSection(SolutionConfigurationPlatforms) = preSolution 38 | Debug|Any CPU = Debug|Any CPU 39 | Release|Any CPU = Release|Any CPU 40 | EndGlobalSection 41 | GlobalSection(ProjectConfigurationPlatforms) = postSolution 42 | {77C48BA9-1604-4EE0-BCBE-DC9B6DC238A8}.Debug|Any CPU.ActiveCfg = Debug|Any CPU 43 | {77C48BA9-1604-4EE0-BCBE-DC9B6DC238A8}.Debug|Any CPU.Build.0 = Debug|Any CPU 
44 | {77C48BA9-1604-4EE0-BCBE-DC9B6DC238A8}.Release|Any CPU.ActiveCfg = Release|Any CPU 45 | {77C48BA9-1604-4EE0-BCBE-DC9B6DC238A8}.Release|Any CPU.Build.0 = Release|Any CPU 46 | {93F876FF-1723-404F-9AA3-CD1D0873A3AB}.Debug|Any CPU.ActiveCfg = Debug|Any CPU 47 | {93F876FF-1723-404F-9AA3-CD1D0873A3AB}.Debug|Any CPU.Build.0 = Debug|Any CPU 48 | {93F876FF-1723-404F-9AA3-CD1D0873A3AB}.Release|Any CPU.ActiveCfg = Release|Any CPU 49 | {93F876FF-1723-404F-9AA3-CD1D0873A3AB}.Release|Any CPU.Build.0 = Release|Any CPU 50 | {8E1D7BEF-3F6F-4237-A48D-3DA57757EB46}.Debug|Any CPU.ActiveCfg = Debug|Any CPU 51 | {8E1D7BEF-3F6F-4237-A48D-3DA57757EB46}.Debug|Any CPU.Build.0 = Debug|Any CPU 52 | {8E1D7BEF-3F6F-4237-A48D-3DA57757EB46}.Release|Any CPU.ActiveCfg = Release|Any CPU 53 | {8E1D7BEF-3F6F-4237-A48D-3DA57757EB46}.Release|Any CPU.Build.0 = Release|Any CPU 54 | {96E80CDB-2323-44BD-8D48-225313319E35}.Debug|Any CPU.ActiveCfg = Debug|Any CPU 55 | {96E80CDB-2323-44BD-8D48-225313319E35}.Debug|Any CPU.Build.0 = Debug|Any CPU 56 | {96E80CDB-2323-44BD-8D48-225313319E35}.Release|Any CPU.ActiveCfg = Release|Any CPU 57 | {96E80CDB-2323-44BD-8D48-225313319E35}.Release|Any CPU.Build.0 = Release|Any CPU 58 | {03E2BC70-5E53-4170-BC46-936C6A6BF964}.Debug|Any CPU.ActiveCfg = Debug|Any CPU 59 | {03E2BC70-5E53-4170-BC46-936C6A6BF964}.Debug|Any CPU.Build.0 = Debug|Any CPU 60 | {03E2BC70-5E53-4170-BC46-936C6A6BF964}.Release|Any CPU.ActiveCfg = Release|Any CPU 61 | {03E2BC70-5E53-4170-BC46-936C6A6BF964}.Release|Any CPU.Build.0 = Release|Any CPU 62 | EndGlobalSection 63 | GlobalSection(SolutionProperties) = preSolution 64 | HideSolutionNode = FALSE 65 | EndGlobalSection 66 | GlobalSection(ExtensibilityGlobals) = postSolution 67 | SolutionGuid = {09494C35-C628-45BE-B4FB-BE6B0D82A6D0} 68 | EndGlobalSection 69 | EndGlobal 70 | -------------------------------------------------------------------------------- /src/DeepSharp.sln.DotSettings: -------------------------------------------------------------------------------- 1 |  2 | True 3 | True 4 | True 5 | True -------------------------------------------------------------------------------- /src/RLConsole/Program.cs: -------------------------------------------------------------------------------- 1 | // See https://aka.ms/new-console-template for more information 2 | 3 | using DeepSharp.RL.Agents; 4 | using DeepSharp.RL.Environs; 5 | 6 | static void Print(object obj) 7 | { 8 | Console.WriteLine(obj.ToString()); 9 | } 10 | 11 | var frozenlake = new Frozenlake(new[] {0.8f, 0.1f, 0.1f}); 12 | var agent = new DQN(frozenlake, 100, 1000, 0.9f, batchSize: 16); 13 | Print(frozenlake); 14 | 15 | 16 | var i = 0; 17 | float reward; 18 | const int testEpisode = 20; 19 | const float predReward = 0.82f; 20 | do 21 | { 22 | i++; 23 | 24 | frozenlake.Reset(); 25 | agent.Learn(); 26 | //var e = 1f - (1 - 0.01f) / 1000 * i; 27 | //agent.Epsilon = e < 0.01f ? 
0.01f : e; 28 | 29 | reward = agent.TestEpisodes(testEpisode); 30 | Print($"{i}:\t{reward}"); 31 | } while (reward < predReward); 32 | 33 | Print($"Stop after Learn {i}"); 34 | frozenlake.ChangeToRough(); 35 | var episode = agent.RunEpisode(); 36 | Print(episode); -------------------------------------------------------------------------------- /src/RLConsole/RLConsole.csproj: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | Exe 5 | net7.0 6 | enable 7 | enable 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | -------------------------------------------------------------------------------- /src/RLConsole/Utility.cs: -------------------------------------------------------------------------------- 1 | namespace RLConsole 2 | { 3 | public class Utility 4 | { 5 | public static void Print(object obj) 6 | { 7 | Console.WriteLine(obj.ToString()); 8 | } 9 | } 10 | } 11 | -------------------------------------------------------------------------------- /src/Reinforcement Learning.md: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xin-pu/DeepSharp/aff6662bd8bb3a30af0340efc3848aa30da8c65a/src/Reinforcement Learning.md -------------------------------------------------------------------------------- /src/TorchSharpTest/AbstractTest.cs: -------------------------------------------------------------------------------- 1 | namespace TorchSharpTest 2 | { 3 | public class AbstractTest 4 | { 5 | private readonly ITestOutputHelper _testOutputHelper; 6 | 7 | public AbstractTest(ITestOutputHelper testOutputHelper) 8 | { 9 | _testOutputHelper = testOutputHelper; 10 | } 11 | 12 | protected void writeLine(object? obj) 13 | { 14 | _testOutputHelper.WriteLine(obj?.ToString()); 15 | } 16 | 17 | internal void Print(string[] objs) 18 | { 19 | foreach (var o in objs) Print(o); 20 | } 21 | 22 | internal void Print(object obj) 23 | { 24 | writeLine(obj); 25 | } 26 | 27 | 28 | /// 29 | /// TODO: optimize printing of Tensor 30 | /// 31 | /// 32 | internal void Print(torch.Tensor tensor) 33 | { 34 | writeLine(tensor.ToString(torch.numpy)); 35 | writeLine(tensor); 36 | writeLine(""); 37 | } 38 | 39 | /// 40 | /// TODO: optimize printing of Tensor 41 | /// 42 | /// 43 | internal void Print(float tensor) 44 | { 45 | writeLine(tensor); 46 | } 47 | } 48 | } -------------------------------------------------------------------------------- /src/TorchSharpTest/DemoTest/DemoNet.cs: -------------------------------------------------------------------------------- 1 | namespace TorchSharpTest.DemoTest 2 | { 3 | /// 4 | /// This is a demo net that shows how to create a new Module 5 | /// 6 | public sealed class DemoNet : Module<torch.Tensor, torch.Tensor> 7 | { 8 | private readonly Module<torch.Tensor, torch.Tensor> layers; 9 | 10 | public DemoNet(int obsSize, int actionNum) : base("Net") 11 | { 12 | var modules = new List<(string, Module<torch.Tensor, torch.Tensor>)> 13 | { 14 | ("line1", Linear(obsSize, 10)), 15 | ("line2", Linear(10, actionNum)) 16 | }; 17 | layers = Sequential(modules); 18 | RegisterComponents(); 19 | } 20 | 21 | 22 | public override torch.Tensor forward(torch.Tensor input) 23 | { 24 | return layers.forward(input); 25 | } 26 | 27 | protected override void Dispose(bool disposing) 28 | { 29 | if (disposing) 30 | { 31 | layers.Dispose(); 32 | ClearModules(); 33 | } 34 | 35 | base.Dispose(disposing); 36 | } 37 | } 38 | } -------------------------------------------------------------------------------- /src/TorchSharpTest/DemoTest/IrisTest.cs:
-------------------------------------------------------------------------------- 1 | using DeepSharp.Dataset; 2 | using TorchSharpTest.SampleDataset; 3 | 4 | namespace TorchSharpTest.DemoTest 5 | { 6 | public class IrisTest : AbstractTest 7 | { 8 | public IrisTest(ITestOutputHelper testOutputHelper) 9 | : base(testOutputHelper) 10 | { 11 | } 12 | 13 | 14 | private torch.Device device => new(DeviceType.CPU); 15 | public string SaveFile => "Iris.txt"; 16 | 17 | 18 | /// 19 | /// CrossEntropyLoss 20 | /// 21 | [Fact] 22 | public async Task Train() 23 | { 24 | var dataset = new Dataset(@"..\..\..\..\..\Resources\iris-train.txt"); 25 | var dataConfig = new DataLoaderConfig { BatchSize = 8, Device = device }; 26 | var dataloader = new DataLoader(dataset, dataConfig); 27 | 28 | 29 | var net = new DemoNet(4, 3).to(device); 30 | 31 | var optimizer = torch.optim.Adam(net.parameters()); 32 | var crossEntropyLoss = CrossEntropyLoss(); 33 | foreach (var epoch in Enumerable.Range(0, 500)) 34 | { 35 | var lossEpoch = new List<float>(); 36 | await foreach (var datapair in dataloader.GetBatchSample()) 37 | { 38 | var (x, y) = (datapair.Features, datapair.Labels); 39 | 40 | var eval = net.forward(x); 41 | var output = crossEntropyLoss.call(eval, y); 42 | 43 | optimizer.zero_grad(); 44 | output.backward(); 45 | optimizer.step(); 46 | 47 | var loss = output.item<float>(); 48 | lossEpoch.Add(loss); 49 | } 50 | 51 | var t = lossEpoch.Average(); 52 | Print($"epoch:\t{epoch:D5}\tLoss:\t{t:F4}"); 53 | } 54 | 55 | if (File.Exists(SaveFile)) File.Delete(SaveFile); 56 | net.save(SaveFile); 57 | } 58 | 59 | 60 | [Fact] 61 | public void Predict() 62 | { 63 | using var net = new DemoNet(4, 3); 64 | { 65 | net.load(SaveFile); 66 | 67 | var testdata = new IrisOneHot 68 | { 69 | SepalLength = 5.0f, 70 | SepalWidth = 3.6f, 71 | PetalLength = 1.4f, 72 | PetalWidth = 0.2f 73 | }; 74 | var y = net.forward(testdata.GetFeatures().unsqueeze(0)); 75 | var yy = Softmax(1).call(y); 76 | var res = yy.argmax().item<long>(); 77 | 78 | Print(string.Join(",", yy.data<float>().ToArray())); 79 | Print(res.ToString()); 80 | } 81 | } 82 | } 83 | } -------------------------------------------------------------------------------- /src/TorchSharpTest/LossTest/LossTest.cs: -------------------------------------------------------------------------------- 1 | namespace TorchSharpTest.LossTest 2 | { 3 | public class LossTest : AbstractTest 4 | { 5 | public LossTest(ITestOutputHelper testOutputHelper) 6 | : base(testOutputHelper) 7 | { 8 | } 9 | 10 | [Fact] 11 | public void Cross() 12 | { 13 | var input = torch.randn(3, 5, requires_grad: true); 14 | var target = torch.empty(3, torch.ScalarType.Int64).randint_like(0, 5); 15 | var loss = CrossEntropyLoss(); 16 | var c = loss.call(input, target); 17 | c.backward(); 18 | var array = c.data<float>().ToArray(); 19 | Print(string.Join(",", array)); 20 | } 21 | } 22 | } -------------------------------------------------------------------------------- /src/TorchSharpTest/RLTest/AgentTest.cs: -------------------------------------------------------------------------------- 1 | using DeepSharp.RL.Agents; 2 | 3 | namespace TorchSharpTest.RLTest 4 | { 5 | public class AgentTest 6 | { 7 | public AgentTest(Agent agent) 8 | { 9 | Agent = agent; 10 | } 11 | 12 | public Agent Agent { set; get; } 13 | 14 | 15 | public float TestEpisode(int testCount) 16 | { 17 | var episode = Agent.RunEpisodes(testCount); 18 | var averageReward = episode.Average(a => a.SumReward.Value); 19 | return averageReward; 20 | } 21 | } 22 | }
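A short usage sketch for the AgentTest helper above (a hypothetical snippet reusing the QLearning agent and KArmedBandit environ from the surrounding tests; the episode count is arbitrary):

// Sketch: evaluate an agent through the AgentTest wrapper instead of calling
// Agent.RunEpisodes directly. QLearning derives from Agent in this repo.
var kArmedBandit = new KArmedBandit(new[] {0.4, 0.85, 0.75, 0.75});
var agent = new QLearning(kArmedBandit);
agent.Learn();                                // one learning pass

var tester = new AgentTest(agent);
var averageReward = tester.TestEpisode(20);   // mean SumReward over 20 episodes
Console.WriteLine($"Average reward: {averageReward:F2}");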
-------------------------------------------------------------------------------- /src/TorchSharpTest/RLTest/EnvironTest/FrozenLakeTest.cs: -------------------------------------------------------------------------------- 1 | using DeepSharp.RL.Environs; 2 | using FluentAssertions; 3 | 4 | namespace TorchSharpTest.RLTest.EnvironTest 5 | { 6 | public class FrozenLakeTest : AbstractTest 7 | { 8 | public DeviceType DeviceType = DeviceType.CUDA; 9 | 10 | public FrozenLakeTest(ITestOutputHelper testOutputHelper) 11 | : base(testOutputHelper) 12 | { 13 | } 14 | 15 | [Fact] 16 | public void FrozenLakeCreateTest1() 17 | { 18 | var frozenlake = new Frozenlake(deviceType: DeviceType); 19 | Print(frozenlake); 20 | } 21 | 22 | [Fact] 23 | public void FrozenLakeCreate2Test() 24 | { 25 | var frozenlake = new Frozenlake(deviceType: DeviceType); 26 | frozenlake.SetPlayID(15); 27 | Print(frozenlake); 28 | frozenlake.IsComplete(1).Should().BeTrue(); 29 | } 30 | 31 | 32 | [Fact] 33 | public void FrozenLakeTestMove() 34 | { 35 | var frozenlake = new Frozenlake(); 36 | var testEpoch = 100; 37 | var count = 0; 38 | var countL = 0; 39 | var countR = 0; 40 | foreach (var i in Enumerable.Range(0, testEpoch)) 41 | { 42 | frozenlake.SetPlayID(1); 43 | frozenlake.Step(new Act(torch.from_array(new[] {1})), 1); 44 | if (frozenlake.PlayID == 5) count++; 45 | if (frozenlake.PlayID == 0) countL++; 46 | if (frozenlake.PlayID == 2) countR++; 47 | } 48 | 49 | var probTarget = count * 1f / testEpoch; 50 | var probLeft = countL * 1f / testEpoch; 51 | var probRight = countR * 1f / testEpoch; 52 | 53 | Print($"{probTarget:P2}\t{probLeft:P2}\t{probRight:P2}"); 54 | } 55 | } 56 | } -------------------------------------------------------------------------------- /src/TorchSharpTest/RLTest/EnvironTest/KArmedBanditTest.cs: -------------------------------------------------------------------------------- 1 | using DeepSharp.RL.Agents; 2 | using DeepSharp.RL.Enumerates; 3 | using DeepSharp.RL.Environs; 4 | 5 | namespace TorchSharpTest.RLTest.EnvironTest 6 | { 7 | public class KArmedBanditTest : AbstractTest 8 | { 9 | public KArmedBanditTest(ITestOutputHelper testOutputHelper) 10 | : base(testOutputHelper) 11 | { 12 | } 13 | 14 | [Fact] 15 | public void BanditCreateTest() 16 | { 17 | var count = 0; 18 | var bandit = new Bandit("A"); 19 | Print(bandit); 20 | var range = 100; 21 | foreach (var _ in Enumerable.Range(0, range)) 22 | { 23 | var res = bandit.Step(); 24 | if (res > 0) count++; 25 | } 26 | 27 | Print(count * 1f / range); 28 | } 29 | 30 | [Fact] 31 | public void KArmedBanditCreateTest() 32 | { 33 | var kArmedBandit = new KArmedBandit(5); 34 | Print(kArmedBandit); 35 | } 36 | 37 | [Fact] 38 | public void RandomPickup() 39 | { 40 | var probs = new[] {0.5f, 0.5f}; 41 | 42 | foreach (var _ in Enumerable.Repeat(0, 10)) 43 | { 44 | var res = torch.multinomial(torch.from_array(probs), 1, true); 45 | var index = res.item(); 46 | Print(index); 47 | } 48 | } 49 | 50 | [Fact] 51 | public void AgentCrossEntropy() 52 | { 53 | var epoch = 100; 54 | var episodesEachBatch = 20; 55 | 56 | /// Step 1 Create a 4-Armed Bandit 57 | var kArmedBandit = new KArmedBandit(2) 58 | { 59 | [0] = {Prob = 0.4}, 60 | [1] = {Prob = 0.75} 61 | }; 62 | Print(kArmedBandit); 63 | 64 | /// Step 2 Create AgentCrossEntropy with 0.7f percentElite as default 65 | var agent = new CrossEntropy(kArmedBandit, episodesEachBatch); 66 | 67 | /// Step 3 Learn and Optimize 68 | foreach (var i in Enumerable.Range(0, epoch)) 69 | { 70 | var loss = agent.Learn(); 71 | 72 | var test = 
agent.RunEpisodes(episodesEachBatch); 73 | 74 | var rewardMean = test.Select(a => a.SumReward.Value).Average(); 75 | 76 | Print($"Epoch:{i:D4}\tReward:{rewardMean:F4}\tLoss:{loss:F4}"); 77 | } 78 | } 79 | 80 | 81 | [Fact] 82 | public void QLearningRunRandom() 83 | { 84 | /// Step 1 Create a 4-Armed Bandit 85 | var kArmedBandit = new KArmedBandit(2); 86 | Print(kArmedBandit); 87 | 88 | /// Step 2 Create AgentCrossEntropy with 0.7f percentElite as default 89 | var agent = new ValueIteration(kArmedBandit, 20); 90 | agent.RunEpisodes(20, PlayMode.Sample); 91 | Print(kArmedBandit); 92 | } 93 | 94 | [Fact] 95 | public void QLearningMain() 96 | { 97 | /// Step 1 Create a 4-Armed Bandit 98 | var kArmedBandit = new KArmedBandit(4) 99 | { 100 | [0] = {Prob = 0.5}, 101 | [1] = {Prob = 0.2}, 102 | [2] = {Prob = 0.4}, 103 | [3] = {Prob = 0.8} 104 | }; 105 | /// Step 2 Create AgentCrossEntropy with 0.7f percentElite as default 106 | var agent = new ValueIteration(kArmedBandit, 100); 107 | Print(kArmedBandit); 108 | 109 | var i = 0; 110 | var bestReward = 0f; 111 | while (i < 100) 112 | { 113 | agent.Learn(); 114 | 115 | var episodes = agent.RunEpisodes(10); 116 | foreach (var episode in episodes) 117 | agent.Update(episode); 118 | 119 | bestReward = new[] {bestReward, episodes.Average(a => a.SumReward.Value)}.Max(); 120 | Print($"{agent} Play:{++i:D3}\t {bestReward}"); 121 | if (bestReward > 18) 122 | break; 123 | } 124 | } 125 | } 126 | } -------------------------------------------------------------------------------- /src/TorchSharpTest/RLTest/ModelTest/ActionSelectorTest.cs: -------------------------------------------------------------------------------- 1 | using DeepSharp.RL.ActionSelectors; 2 | using FluentAssertions; 3 | 4 | namespace TorchSharpTest.RLTest.ModelTest 5 | { 6 | public class ActionSelectorTest : AbstractTest 7 | { 8 | public ActionSelectorTest(ITestOutputHelper testOutputHelper) 9 | : base(testOutputHelper) 10 | { 11 | } 12 | 13 | [Fact] 14 | public void ArgmaxActionSelectorTest() 15 | { 16 | var input = torch.from_array(new double[,] {{1, 2, 3}, {1, -1, 0}}); 17 | var res = new ArgmaxActionSelector().Select(input); 18 | res.Equals(torch.tensor(new long[] {2, 0})).Should().BeTrue(); 19 | } 20 | 21 | [Fact] 22 | public void ProbabilityActionSelectorTest() 23 | { 24 | var input = torch.from_array(new[,] {{1f, 0, 0}, {0, 1f, 0}}); 25 | var res = new ProbActionSelector().Select(input); 26 | res.Equals(torch.from_array(new long[] {0, 1})).Should().BeTrue(); 27 | } 28 | } 29 | } -------------------------------------------------------------------------------- /src/TorchSharpTest/RLTest/ModelTest/BasicTest.cs: -------------------------------------------------------------------------------- 1 | namespace TorchSharpTest.RLTest.ModelTest 2 | { 3 | public class BasicTest : AbstractTest 4 | { 5 | public BasicTest(ITestOutputHelper testOutputHelper) 6 | : base(testOutputHelper) 7 | { 8 | } 9 | } 10 | } -------------------------------------------------------------------------------- /src/TorchSharpTest/RLTest/ModelTest/PolicyNetTest.cs: -------------------------------------------------------------------------------- 1 | using DeepSharp.RL.Agents; 2 | 3 | namespace TorchSharpTest.RLTest.ModelTest 4 | { 5 | public class PolicyNetTest : AbstractTest 6 | { 7 | public PolicyNetTest(ITestOutputHelper testOutputHelper) 8 | : base(testOutputHelper) 9 | { 10 | } 11 | 12 | [Fact] 13 | public void TestNet() 14 | { 15 | var net = new PGN(3, 128, 2, DeviceType.CPU); 16 | var res = 
net.forward(torch.from_array(new float[,] {{1, 2, 3}})); 17 | Print(res); 18 | } 19 | } 20 | } -------------------------------------------------------------------------------- /src/TorchSharpTest/RLTest/ModelTest/QTableTest.cs: -------------------------------------------------------------------------------- 1 | using DeepSharp.RL.Agents; 2 | using FluentAssertions; 3 | 4 | namespace TorchSharpTest.RLTest.ModelTest 5 | { 6 | public class QTableTest : AbstractTest 7 | { 8 | public QTableTest(ITestOutputHelper testOutputHelper) 9 | : base(testOutputHelper) 10 | { 11 | } 12 | 13 | [Fact] 14 | public void TestTransitKeyOperator() 15 | { 16 | var state1 = torch.tensor(new[] {0, 0, 1}); 17 | var action1 = torch.tensor(new[] {0, 1, 0}); 18 | var key1 = new TransitKey(state1, action1); 19 | 20 | 21 | var state2 = torch.tensor(new[] {0, 0, 1}); 22 | var action2 = torch.tensor(new[] {0, 1, 0}); 23 | var key2 = new TransitKey(state2, action2); 24 | 25 | (key2 == key1).Should().BeTrue(); 26 | } 27 | 28 | [Fact] 29 | public void TestTransitKeyDict() 30 | { 31 | var state = torch.tensor(new[] {0, 0, 1}); 32 | var action = torch.tensor(new[] {0, 1, 0}); 33 | var key = new TransitKey(state, action); 34 | var returnDict = new Dictionary {[key] = 2}; 35 | 36 | var stateTest = torch.tensor(new[] {0, 0, 1}); 37 | var actionTest = torch.tensor(new[] {0, 1, 0}); 38 | var keyTest = new TransitKey(stateTest, actionTest); 39 | var res = returnDict[keyTest]; 40 | res.Should().Be(2); 41 | } 42 | 43 | 44 | [Fact] 45 | public void CreateValueTableTest1() 46 | { 47 | var vt = new QTable(); 48 | var state = torch.tensor(new[] {0, 0, 1}); 49 | var action = torch.tensor(new[] {0, 1, 0}); 50 | var tr = new TransitKey(state, action); 51 | vt[tr] = 3f; 52 | Print(vt[tr]); 53 | var state2 = torch.tensor(new[] {0, 1, 1}); 54 | Print(vt[state2, action]); 55 | } 56 | 57 | [Fact] 58 | public void CreateValueTableTest2() 59 | { 60 | var vt = new QTable(); 61 | var state = torch.tensor(new[] {0, 0, 1}); 62 | var action = torch.tensor(new[] {0, 1, 0}); 63 | vt[state, action] = 3f; 64 | Print(vt[state, action]); 65 | var state2 = torch.tensor(new[] {0, 1, 1}); 66 | Print(vt[state2, action]); 67 | } 68 | } 69 | } -------------------------------------------------------------------------------- /src/TorchSharpTest/RLTest/ModelTest/SpaceTest.cs: -------------------------------------------------------------------------------- 1 | using DeepSharp.RL.Environs.Spaces; 2 | using FluentAssertions; 3 | 4 | namespace TorchSharpTest.RLTest.ModelTest 5 | { 6 | public class SpaceTest : AbstractTest 7 | { 8 | public SpaceTest(ITestOutputHelper testOutputHelper) 9 | : base(testOutputHelper) 10 | { 11 | } 12 | 13 | 14 | #region Disperse Test 15 | 16 | [Fact] 17 | public void DisperseCons() 18 | { 19 | var disperse = new Disperse(5); 20 | var s = disperse.Generate(); 21 | Print(s); 22 | 23 | var disperse2 = new Disperse(2); 24 | s = disperse2.Generate(); 25 | Print(s); 26 | } 27 | 28 | [Fact] 29 | public void DisperseGenerate() 30 | { 31 | var disperse = new Disperse(5, torch.ScalarType.Int8); 32 | var one = disperse.Generate(); 33 | Print(one); 34 | one.dtype.Should().Be(torch.ScalarType.Int8); 35 | 36 | disperse = new Disperse(5, torch.ScalarType.Int16); 37 | one = disperse.Generate(); 38 | Print(one); 39 | one.dtype.Should().Be(torch.ScalarType.Int16); 40 | 41 | disperse = new Disperse(5, torch.ScalarType.Int32); 42 | one = disperse.Generate(); 43 | Print(one); 44 | one.dtype.Should().Be(torch.ScalarType.Int32); 45 | 46 | disperse = new Disperse(5); 
47 | one = disperse.Generate(); 48 | Print(one); 49 | one.dtype.Should().Be(torch.ScalarType.Int64); 50 | } 51 | 52 | [Fact] 53 | public void DisperseDevice() 54 | { 55 | var disperse = new Disperse(5, deviceType: DeviceType.CUDA); 56 | var one = disperse.Generate(); 57 | Print(one); 58 | one.device_type.Should().Be(DeviceType.CUDA); 59 | 60 | disperse = new Disperse(5, deviceType: DeviceType.CPU); 61 | one = disperse.Generate(); 62 | Print(one); 63 | one.device_type.Should().Be(DeviceType.CPU); 64 | } 65 | 66 | 67 | [Fact] 68 | public void DisperseSample() 69 | { 70 | var a = new Disperse(5); 71 | foreach (var _ in Enumerable.Repeat(0, 10)) 72 | { 73 | var data = a.Sample(); 74 | Print(data); 75 | data.ToInt64().Should().BeInRange(0, 4); 76 | } 77 | } 78 | 79 | #endregion 80 | 81 | 82 | #region Box 83 | 84 | [Fact] 85 | public void CreateFloatBox() 86 | { 87 | var box = new Box(0f, 1f, new long[] {2, 2}); 88 | var r = box.Sample(); 89 | Print(r); 90 | } 91 | 92 | [Fact] 93 | public void CreateDoubleBox() 94 | { 95 | var box = new Box(0d, 1d, new long[] {2, 2}); 96 | var r = box.Sample(); 97 | Print(r); 98 | } 99 | 100 | 101 | [Fact] 102 | public void CreateInt32Box() 103 | { 104 | var box = new Box(1, 5, new long[] {10}); 105 | var r = box.Sample(); 106 | Print(r); 107 | } 108 | 109 | [Fact] 110 | public void CreateInt64Box() 111 | { 112 | var box = new Box(1L, 5L, new long[] {10}); 113 | var r = box.Sample(); 114 | Print(r); 115 | } 116 | 117 | [Fact] 118 | public void CreateByteBox() 119 | { 120 | var box = new Box((byte) 0, (byte) 1, new long[] {10}); 121 | var r = box.Sample(); 122 | Print(r); 123 | } 124 | 125 | [Fact] 126 | public void CreateInt16Box() 127 | { 128 | var box = new Box((short) 1, (short) 5, new long[] {10}); 129 | var r = box.Sample(); 130 | Print(r); 131 | } 132 | 133 | #endregion 134 | 135 | 136 | #region Other Space 137 | 138 | [Fact] 139 | public void CreateMultiDisperse1() 140 | { 141 | var low = torch.tensor(new long[] {0, 0}); 142 | var high = torch.tensor(new long[] {3, 4}); 143 | var shape = new long[] {2}; 144 | var multiDisperse = new MultiDisperse(low, high, shape, torch.ScalarType.Int32); 145 | var r = multiDisperse.Sample(); 146 | Print(r); 147 | Print(multiDisperse); 148 | } 149 | 150 | [Fact] 151 | public void CreateMultiDisperse2() 152 | { 153 | var multiDisperse = new MultiDisperse(0, 1, new long[] {2}, torch.ScalarType.Int64); 154 | var r = multiDisperse.Sample(); 155 | Print(r); 156 | Print(multiDisperse); 157 | } 158 | 159 | 160 | [Fact] 161 | public void CreateBinary() 162 | { 163 | var binary = new Binary(torch.ScalarType.Int64); 164 | var r = binary.Sample(); 165 | Print(r); 166 | Print(binary); 167 | } 168 | 169 | 170 | [Fact] 171 | public void CreateMultiBinary1() 172 | { 173 | var binary = new MultiBinary(2L); 174 | var r = binary.Sample(); 175 | Print(r); 176 | Print(binary); 177 | } 178 | 179 | [Fact] 180 | public void CreateMultiBinary2() 181 | { 182 | var binary = new MultiBinary(new long[] {2, 2}, torch.ScalarType.Int64); 183 | var r = binary.Sample(); 184 | Print(r); 185 | Print(binary); 186 | } 187 | 188 | #endregion 189 | } 190 | } -------------------------------------------------------------------------------- /src/TorchSharpTest/RLTest/ModelTest/VTableTest.cs: -------------------------------------------------------------------------------- 1 | using DeepSharp.RL.Agents; 2 | using DeepSharp.Utility; 3 | 4 | namespace TorchSharpTest.RLTest.ModelTest 5 | { 6 | public class VTableTest : AbstractTest 7 | { 8 | public 
VTableTest(ITestOutputHelper testOutputHelper) 9 | : base(testOutputHelper) 10 | { 11 | } 12 | 13 | [Fact] 14 | public void CreateValueTableTest1() 15 | { 16 | var vt = new VTable(); 17 | var state = torch.tensor(new[] {0, 0, 1}); 18 | vt[state] = 3f; 19 | Print(vt[state]); 20 | var state2 = torch.tensor(new[] {0, 1, 1}); 21 | Print(vt[state2]); 22 | } 23 | 24 | [Fact] 25 | public void CreateValueTableTest2() 26 | { 27 | var state1 = torch.tensor(new[] {0, 0, 1}); 28 | var state2 = torch.tensor(new[] {0, 0, 1}); 29 | var arr = new[] {state1, state2}; 30 | var p = arr.Distinct(new TensorEqualityCompare()).ToList(); 31 | Print(p.Count); 32 | } 33 | } 34 | } 35 | -------------------------------------------------------------------------------- /src/TorchSharpTest/RLTest/PolicyBasedTest/ActorCriticTest.cs: -------------------------------------------------------------------------------- 1 | using DeepSharp.RL.Agents; 2 | using DeepSharp.RL.Environs; 3 | using DeepSharp.RL.Trainers; 4 | 5 | namespace TorchSharpTest.RLTest.PolicyBasedTest 6 | { 7 | public class ActorCriticTest : AbstractTest 8 | { 9 | public ActorCriticTest(ITestOutputHelper testOutputHelper) 10 | : base(testOutputHelper) 11 | { 12 | } 13 | 14 | [Fact] 15 | public void ACKMTest() 16 | { 17 | var kArmedBandit = new KArmedBandit(new[] {0.4, 0.85, 0.7, 0.25}); 18 | var agent = new ActorCritic(kArmedBandit, 4, gamma: 0.99f); 19 | var trainer = new RLTrainer(agent, Print); 20 | trainer.Train(0.90f, 300, testEpisodes: 20, testInterval: 2, autoSave: false); 21 | agent.Save("ACKM.st"); 22 | } 23 | 24 | [Fact] 25 | public void ACKMCVal() 26 | { 27 | var kArmedBandit = new KArmedBandit(new[] {0.4, 0.85, 0.7, 0.25}); 28 | var agent = new ActorCritic(kArmedBandit, 16); 29 | agent.Load("ACKM.st"); 30 | var trainer = new RLTrainer(agent, Print); 31 | trainer.Val(20); 32 | } 33 | 34 | [Fact] 35 | public void ACFLTest() 36 | { 37 | var frozenlake = new Frozenlake(new[] {0.8f, 0.1f, 0.1f}); 38 | var agent = new A2C(frozenlake, 16); 39 | var trainer = new RLTrainer(agent, Print); 40 | trainer.Train(0.90f, 600, testEpisodes: 20, testInterval: 2, autoSave: false); 41 | agent.Save("A2CFL.st"); 42 | } 43 | 44 | [Fact] 45 | public void ACFLCVal() 46 | { 47 | var frozenlake = new Frozenlake(new[] {0.8f, 0.1f, 0.1f}); 48 | var agent = new A2C(frozenlake, 16); 49 | agent.Load("A2CFL.st"); 50 | var episode = agent.RunEpisodes(10); 51 | episode.ToList().ForEach(Print); 52 | } 53 | } 54 | } -------------------------------------------------------------------------------- /src/TorchSharpTest/RLTest/PolicyBasedTest/ReinforceTest.cs: -------------------------------------------------------------------------------- 1 | using DeepSharp.RL.Agents; 2 | using DeepSharp.RL.Environs; 3 | using DeepSharp.RL.Trainers; 4 | 5 | namespace TorchSharpTest.RLTest.PolicyBasedTest 6 | { 7 | public class ReinforceTest:AbstractTest 8 | { 9 | public ReinforceTest(ITestOutputHelper testOutputHelper) 10 | : base(testOutputHelper) 11 | { 12 | } 13 | 14 | [Fact] 15 | public void ReinforceKMTest() 16 | { 17 | var kArmedBandit = new KArmedBandit(new[] {0.4, 0.85, 0.7, 0.25}); 18 | var agent = new Reinforce(kArmedBandit, 16); 19 | var trainer = new RLTrainer(agent, Print); 20 | trainer.Train(0.90f, 500, testEpisodes: 20, testInterval: 2, autoSave: false); 21 | agent.Save("ReinKM.st"); 22 | } 23 | 24 | [Fact] 25 | public void ReinforceKMVal() 26 | { 27 | var kArmedBandit = new KArmedBandit(new[] {0.4, 0.85, 0.7, 0.25}); 28 | var agent = new Reinforce(kArmedBandit); 29 | 
agent.Load("ReinKM.st"); 30 | var trainer = new RLTrainer(agent, Print); 31 | trainer.Val(20); 32 | } 33 | 34 | [Fact] 35 | public void ReinforceFLTest() 36 | { 37 | var frozenlake = new Frozenlake(new[] {0.8f, 0.1f, 0.1f}); 38 | var agent = new Reinforce(frozenlake, 16); 39 | var trainer = new RLTrainer(agent, Print); 40 | trainer.Train(0.95f, 500, testEpisodes: 20, testInterval: 2, autoSave: false); 41 | agent.Save("ReinFrozen.st"); 42 | } 43 | 44 | 45 | [Fact] 46 | public void ReinforceFLVal() 47 | { 48 | var frozenlake = new Frozenlake(new[] {0.8f, 0.1f, 0.1f}); 49 | var agent = new Reinforce(frozenlake); 50 | agent.Load("ReinFrozen.st"); 51 | frozenlake.ChangeToRough(); 52 | var episode = agent.RunEpisode(); 53 | Print(episode); 54 | } 55 | } 56 | } 57 | -------------------------------------------------------------------------------- /src/TorchSharpTest/RLTest/TrainerTest/RLTrainTest.cs: -------------------------------------------------------------------------------- 1 | using DeepSharp.RL.Agents; 2 | using DeepSharp.RL.Environs; 3 | using DeepSharp.RL.Trainers; 4 | 5 | namespace TorchSharpTest.RLTest.TrainerTest 6 | { 7 | public class RLTrainTest : AbstractTest 8 | { 9 | public RLTrainTest(ITestOutputHelper testOutputHelper) 10 | : base(testOutputHelper) 11 | { 12 | } 13 | 14 | [Fact] 15 | public void TrainCreateTest() 16 | { 17 | var kArmedBandit = new KArmedBandit(new[] {0.4, 0.85, 0.75, 0.75}); 18 | var agent = new QLearning(kArmedBandit); 19 | var trainer = new RLTrainer(agent, Print); 20 | trainer.Train(0.9f, 1, "", 20, 2); 21 | } 22 | 23 | 24 | [Fact] 25 | public void TrainCallBackTest() 26 | { 27 | var kArmedBandit = new KArmedBandit(new[] {0.4, 0.85, 0.75, 0.75}); 28 | var agent = new DQN(kArmedBandit, 100, 1000); 29 | var trainer = new RLTrainer(agent, Print) 30 | { 31 | Callback = new TestCallBack() 32 | }; 33 | trainer.Train(0.9f, 500, "", 20); 34 | } 35 | 36 | 37 | private class TestCallBack : TrainerCallBack 38 | { 39 | public override void OnTrainStart() 40 | { 41 | RlTrainer.Print?.Invoke("Hello, this info comes from callback"); 42 | } 43 | 44 | public override void OnTrainEnd() 45 | { 46 | } 47 | 48 | public override void OnLearnStart(int epoch) 49 | { 50 | } 51 | 52 | public override void OnLearnEnd(int epoch, LearnOutcome outcome) 53 | { 54 | } 55 | 56 | public override void OnValStart(int epoch) 57 | { 58 | } 59 | 60 | public override void OnValEnd(int epoch, Episode[] episodes) 61 | { 62 | } 63 | 64 | public override void OnSaveStart() 65 | { 66 | } 67 | 68 | public override void OnSaveEnd() 69 | { 70 | } 71 | } 72 | } 73 | } -------------------------------------------------------------------------------- /src/TorchSharpTest/RLTest/ValueBasedTest/DQNTest.cs: -------------------------------------------------------------------------------- 1 | using DeepSharp.RL.Agents; 2 | using DeepSharp.RL.Environs; 3 | 4 | namespace TorchSharpTest.RLTest.ValueBasedTest 5 | { 6 | public class DQNTest : AbstractTest 7 | { 8 | public DQNTest(ITestOutputHelper testOutputHelper) 9 | : base(testOutputHelper) 10 | { 11 | } 12 | 13 | 14 | [Fact] 15 | public void TestDQN() 16 | { 17 | var frozenLake = new Frozenlake(); 18 | var dqn = new DQN(frozenLake); 19 | var act = dqn.GetPolicyAct(frozenLake.Observation!.Value!); 20 | Print(act); 21 | } 22 | 23 | [Fact] 24 | public void KArmedBanditMain() 25 | { 26 | var kArmedBandit = new KArmedBandit(new[] {0.4, 0.85, 0.75, 0.75, 0.88, 0.9, 0.75, 0.75}); 27 | var agent = new DQN(kArmedBandit, 100, 1000, batchSize: 16); 28 | Print(kArmedBandit); 29 
| 30 | var i = 0; 31 | float reward; 32 | const int testEpisode = 20; 33 | const float predReward = 18f; 34 | do 35 | { 36 | i++; 37 | kArmedBandit.Reset(); 38 | agent.Learn(); 39 | 40 | reward = agent.TestEpisodes(testEpisode); 41 | Print($"{i:D5}:\t{reward}"); 42 | } while (reward <= predReward); 43 | 44 | var episode = agent.RunEpisode(); 45 | Print(episode); 46 | } 47 | 48 | 49 | [Fact] 50 | public void FrozenlakeMain() 51 | { 52 | var frozenlake = new Frozenlake(new[] {0.8f, 0.1f, 0.1f}); 53 | var agent = new DQN(frozenlake, 100, 1000, 0.9f, batchSize: 16); 54 | Print(frozenlake); 55 | 56 | 57 | var i = 0; 58 | float reward; 59 | const int testEpisode = 20; 60 | const float predReward = 0.8f; 61 | do 62 | { 63 | i++; 64 | frozenlake.Reset(); 65 | agent.Learn(); 66 | 67 | reward = agent.TestEpisodes(testEpisode); 68 | Print($"{i:D5}:\t{reward}"); 69 | } while (reward < predReward); 70 | 71 | Print($"Stop after Learn {i}"); 72 | frozenlake.ChangeToRough(); 73 | var episode = agent.RunEpisode(); 74 | Print(episode); 75 | } 76 | } 77 | } -------------------------------------------------------------------------------- /src/TorchSharpTest/RLTest/ValueBasedTest/MonteCarloTest.cs: -------------------------------------------------------------------------------- 1 | using DeepSharp.RL.Agents; 2 | using DeepSharp.RL.Environs; 3 | 4 | namespace TorchSharpTest.RLTest.ValueBasedTest 5 | { 6 | public class MonteCarloTest : AbstractTest 7 | { 8 | public MonteCarloTest(ITestOutputHelper testOutputHelper) 9 | : base(testOutputHelper) 10 | { 11 | } 12 | 13 | [Fact] 14 | public void KABOnPolicyTest() 15 | { 16 | var kArmedBandit = new KArmedBandit(new[] {0.4, 0.85, 0.75, 0.75}); 17 | var agent = new MonteCarloOnPolicy(kArmedBandit); 18 | Print(kArmedBandit); 19 | 20 | var i = 0; 21 | float reward; 22 | const int testEpisode = 20; 23 | const float predReward = 17f; 24 | do 25 | { 26 | i++; 27 | kArmedBandit.Reset(); 28 | agent.Learn(); 29 | 30 | reward = agent.TestEpisodes(testEpisode); 31 | } while (reward < predReward); 32 | 33 | Print($"Stop after Learn {i}"); 34 | 35 | var episode = agent.RunEpisode(); 36 | Print(episode); 37 | } 38 | 39 | [Fact] 40 | public void KABOffPolicyTest() 41 | { 42 | var kArmedBandit = new KArmedBandit(new[] {0.4, 0.85, 0.75, 0.75}); 43 | var agent = new MonteCarloOffPolicy(kArmedBandit); 44 | Print(kArmedBandit); 45 | 46 | var i = 0; 47 | float reward; 48 | const int testEpisode = 20; 49 | const float predReward = 17f; 50 | do 51 | { 52 | i++; 53 | kArmedBandit.Reset(); 54 | agent.Learn(); 55 | 56 | reward = agent.TestEpisodes(testEpisode); 57 | } while (reward < predReward); 58 | 59 | Print($"Stop after Learn {i}"); 60 | 61 | var episode = agent.RunEpisode(); 62 | Print(episode); 63 | } 64 | 65 | [Fact] 66 | public void FLOnPolicyTest() 67 | { 68 | var frozenlake = new Frozenlake(new[] {0.8f, 0.1f, 0.1f}); 69 | var agent = new MonteCarloOnPolicy(frozenlake, 0.1f, 50); 70 | Print(frozenlake); 71 | 72 | 73 | var i = 0; 74 | float reward = 0; 75 | const int testEpisode = 20; 76 | const float predReward = 0.7f; 77 | do 78 | { 79 | i++; 80 | frozenlake.Reset(); 81 | agent.Learn(); 82 | 83 | if (i % 100 == 0) 84 | { 85 | reward = agent.TestEpisodes(testEpisode); 86 | Print($"{i:D5}:\t{reward}"); 87 | } 88 | } while (reward < predReward); 89 | 90 | Print($"Stop after Learn {i}"); 91 | frozenlake.ChangeToRough(); 92 | var episode = agent.RunEpisode(); 93 | Print(episode); 94 | } 95 | 96 | [Fact] 97 | public void FLOffPolictTest() 98 | { 99 | var frozenlake = new Frozenlake(new[] 
{0.8f, 0f, 0f}); 100 | var agent = new MonteCarloOffPolicy(frozenlake, 0.1f, 50); 101 | Print(frozenlake); 102 | 103 | 104 | var i = 0; 105 | float reward = 0; 106 | const int testEpisode = 20; 107 | const float predReward = 0.7f; 108 | do 109 | { 110 | i++; 111 | frozenlake.Reset(); 112 | agent.Learn(); 113 | 114 | if (i % 100 == 0) 115 | { 116 | reward = agent.TestEpisodes(testEpisode); 117 | Print($"{i:D5}:\t{reward}"); 118 | } 119 | } while (reward < predReward); 120 | 121 | Print($"Stop after Learn {i}"); 122 | frozenlake.ChangeToRough(); 123 | var episode = agent.RunEpisode(); 124 | Print(episode); 125 | } 126 | } 127 | } -------------------------------------------------------------------------------- /src/TorchSharpTest/RLTest/ValueBasedTest/QLearningTest.cs: -------------------------------------------------------------------------------- 1 | using DeepSharp.RL.Agents; 2 | using DeepSharp.RL.Environs; 3 | 4 | namespace TorchSharpTest.RLTest.ValueBasedTest 5 | { 6 | public class QLearningTest : AbstractTest 7 | { 8 | public QLearningTest(ITestOutputHelper testOutputHelper) 9 | : base(testOutputHelper) 10 | { 11 | } 12 | 13 | 14 | [Fact] 15 | public void KArmedBanditMain() 16 | { 17 | var kArmedBandit = new KArmedBandit(new[] {0.4, 0.85, 0.75, 0.75}); 18 | var agent = new QLearning(kArmedBandit); 19 | Print(kArmedBandit); 20 | 21 | var i = 0; 22 | float reward; 23 | const int testEpisode = 20; 24 | const float predReward = 17f; 25 | do 26 | { 27 | i++; 28 | kArmedBandit.Reset(); 29 | agent.Learn(); 30 | 31 | reward = agent.TestEpisodes(testEpisode); 32 | } while (reward < predReward); 33 | 34 | Print($"Stop after Learn {i}"); 35 | 36 | var episode = agent.RunEpisode(); 37 | Print(episode); 38 | } 39 | 40 | [Fact] 41 | public void KArmedBanditMainHighLevel() 42 | { 43 | var kArmedBandit = new KArmedBandit(new[] {0.4, 0.85, 0.75, 0.75}); 44 | var agent = new QLearning(kArmedBandit); 45 | Print(kArmedBandit); 46 | 47 | var i = 0; 48 | float reward; 49 | const int testEpisode = 20; 50 | const float predReward = 17f; 51 | do 52 | { 53 | i++; 54 | kArmedBandit.Reset(); 55 | agent.Learn(); 56 | 57 | reward = agent.TestEpisodes(testEpisode); 58 | } while (reward < predReward); 59 | 60 | Print($"Stop after Learn {i}"); 61 | 62 | var episode = agent.RunEpisode(); 63 | Print(episode); 64 | } 65 | 66 | 67 | [Fact] 68 | public void FrozenlakeMain() 69 | { 70 | var frozenlake = new Frozenlake(new[] {0.8f, 0.1f, 0.1f}); 71 | var agent = new QLearning(frozenlake); 72 | Print(frozenlake); 73 | 74 | 75 | var i = 0; 76 | float reward; 77 | const int testEpisode = 20; 78 | const float predReward = 0.7f; 79 | do 80 | { 81 | i++; 82 | frozenlake.Reset(); 83 | agent.Learn(); 84 | 85 | reward = agent.TestEpisodes(testEpisode); 86 | } while (reward < predReward); 87 | 88 | Print($"Stop after Learn {i}"); 89 | frozenlake.ChangeToRough(); 90 | var episode = agent.RunEpisode(); 91 | Print(episode); 92 | } 93 | } 94 | } -------------------------------------------------------------------------------- /src/TorchSharpTest/RLTest/ValueBasedTest/SARSATest.cs: -------------------------------------------------------------------------------- 1 | using DeepSharp.RL.Agents; 2 | using DeepSharp.RL.Environs; 3 | 4 | namespace TorchSharpTest.RLTest.ValueBasedTest 5 | { 6 | public class SARSATest : AbstractTest 7 | { 8 | public SARSATest(ITestOutputHelper testOutputHelper) 9 | : base(testOutputHelper) 10 | { 11 | } 12 | 13 | 14 | [Fact] 15 | public void KArmedBanditMain() 16 | { 17 | var kArmedBandit = new KArmedBandit(new[] 
{0.4, 0.85, 0.75, 0.75}); 18 | var agent = new SARSA(kArmedBandit); 19 | Print(kArmedBandit); 20 | 21 | var i = 0; 22 | float reward; 23 | const int testEpisode = 20; 24 | const float predReward = 17f; 25 | do 26 | { 27 | i++; 28 | kArmedBandit.Reset(); 29 | agent.Learn(); 30 | 31 | reward = agent.TestEpisodes(testEpisode); 32 | } while (reward < predReward); 33 | 34 | Print($"Stop after Learn {i}"); 35 | 36 | var episode = agent.RunEpisode(); 37 | Print(episode); 38 | } 39 | 40 | 41 | [Fact] 42 | public void FrozenlakeMain() 43 | { 44 | var frozenlake = new Frozenlake(new[] {0.8f, 0.1f, 0.1f}); 45 | var agent = new SARSA(frozenlake); 46 | Print(frozenlake); 47 | 48 | 49 | var i = 0; 50 | float reward; 51 | const int testEpisode = 20; 52 | const float predReward = 0.7f; 53 | do 54 | { 55 | i++; 56 | frozenlake.Reset(); 57 | agent.Learn(); 58 | 59 | reward = agent.TestEpisodes(testEpisode); 60 | } while (reward < predReward); 61 | 62 | Print($"Stop after Learn {i}"); 63 | frozenlake.ChangeToRough(); 64 | var episode = agent.RunEpisode(); 65 | Print(episode); 66 | } 67 | } 68 | } -------------------------------------------------------------------------------- /src/TorchSharpTest/SampleDataset/Iris.cs: -------------------------------------------------------------------------------- 1 | using DeepSharp.Dataset; 2 | 3 | namespace TorchSharpTest.SampleDataset 4 | { 5 | public class Iris : DataView 6 | { 7 | /// <summary> 8 | /// </summary> 9 | public Iris() 10 | { 11 | } 12 | 13 | [StreamHeader(0)] public long Label { set; get; } 14 | [StreamHeader(1)] public float SepalLength { set; get; } 15 | [StreamHeader(2)] public float SepalWidth { set; get; } 16 | [StreamHeader(3)] public float PetalLength { set; get; } 17 | [StreamHeader(4)] public float PetalWidth { set; get; } 18 | 19 | public override torch.Tensor GetFeatures() 20 | { 21 | return torch.tensor(new[] {SepalLength, SepalWidth, PetalLength, PetalWidth}); 22 | } 23 | 24 | public override torch.Tensor GetLabels() 25 | { 26 | return torch.tensor(new[] {Label}); 27 | } 28 | 29 | public override string ToString() 30 | { 31 | return $"Label:{Label}\t" + 32 | $"SepalLength:{SepalLength:F2}\tSepalWidth:{SepalWidth:F2}\t" + 33 | $"PetalLength:{PetalLength:F2}\tPetalWidth:{PetalWidth:F2}"; 34 | } 35 | 36 | /// <summary> 37 | /// return a random Iris 38 | /// </summary> 39 | /// <returns></returns> 40 | public static Iris RandomIris() 41 | { 42 | var randomSource = new Random(); 43 | return new Iris 44 | { 45 | Label = randomSource.Next(0, 3), 46 | PetalLength = randomSource.NextSingle() * 4, 47 | PetalWidth = randomSource.NextSingle() * 4, 48 | SepalLength = randomSource.NextSingle() * 4, 49 | SepalWidth = randomSource.NextSingle() * 4 50 | }; 51 | } 52 | } 53 | } -------------------------------------------------------------------------------- /src/TorchSharpTest/SampleDataset/IrisOneHot.cs: -------------------------------------------------------------------------------- 1 | namespace TorchSharpTest.SampleDataset 2 | { 3 | public class IrisOneHot : Iris 4 | { 5 | /// <summary> 6 | /// OneHot [0,0,1] represents class 3 7 | /// </summary> 8 | /// <returns></returns> 9 | public override torch.Tensor GetLabels() 10 | { 11 | var array = Enumerable.Repeat(0, 3).Select(a => (float) a).ToArray(); 12 | array[Label] = 1; 13 | return torch.tensor(array); 14 | } 15 | } 16 | } -------------------------------------------------------------------------------- /src/TorchSharpTest/TorchSharpTest.csproj: -------------------------------------------------------------------------------- 1 |  2 | 3 | 4 | net7.0 5 | enable 6 | enable 7 | 8 | false 9 | true 10 | 11 | 12 | 13 | 14 | 15
| 16 | 17 | 18 | runtime; build; native; contentfiles; analyzers; buildtransitive 19 | all 20 | 21 | 22 | runtime; build; native; contentfiles; analyzers; buildtransitive 23 | all 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 | 32 | 33 | -------------------------------------------------------------------------------- /src/TorchSharpTest/TorchSharpTest.csproj.DotSettings: -------------------------------------------------------------------------------- 1 |  2 | True -------------------------------------------------------------------------------- /src/TorchSharpTest/TorchTests/DataSetTest.cs: -------------------------------------------------------------------------------- 1 | using DeepSharp.Dataset; 2 | using TorchSharpTest.SampleDataset; 3 | 4 | namespace TorchSharpTest.TorchTests 5 | { 6 | public class DataSetTest : AbstractTest 7 | { 8 | public DataSetTest(ITestOutputHelper testOutputHelper) 9 | : base(testOutputHelper) 10 | { 11 | } 12 | 13 | [Fact] 14 | public void StreamDatasetTest() 15 | { 16 | var dataset = new Dataset<Iris>(@"F:\Iris\iris-train.txt"); 17 | var res = dataset.GetTensor(0); 18 | Print(res); 19 | } 20 | 21 | [Fact] 22 | public void OriginalDataloaderTest() 23 | { 24 | var dataset = new Dataset<Iris>(@"F:\Iris\iris-train.txt"); 25 | var device = new torch.Device(DeviceType.CUDA); 26 | var dataloader = 27 | new torch.utils.data.DataLoader<Iris, DataViewPair>(dataset, 4, 28 | DataView.FromDataViews, true, device); 29 | 30 | using var iterator = dataloader.GetEnumerator(); 31 | while (iterator.MoveNext()) 32 | { 33 | var current = iterator.Current; 34 | Print(current); 35 | } 36 | } 37 | 38 | [Fact] 39 | public void DataLoaderTest() 40 | { 41 | var dataset = new Dataset<Iris>(@"F:\Iris\iris-train.txt"); 42 | var dataConfig = new DataLoaderConfig 43 | { 44 | Device = new torch.Device(DeviceType.CUDA) 45 | }; 46 | var dataloader = new DataLoader<Iris>(dataset, dataConfig); 47 | 48 | using var iterator = dataloader.GetEnumerator(); 49 | while (iterator.MoveNext()) 50 | { 51 | var current = iterator.Current; 52 | Print(current); 53 | } 54 | } 55 | 56 | [Fact] 57 | public async Task InfiniteDataLoaderTest() 58 | { 59 | var dataset = new Dataset<Iris>(@"F:\Iris\iris-train.txt"); 60 | var dataConfig = new DataLoaderConfig(); 61 | var dataloader = new InfiniteDataLoader<Iris>(dataset, dataConfig); 62 | 63 | await foreach (var a in dataloader.GetBatchSample(100)) 64 | { 65 | var array = a.Labels.data<long>().ToArray(); 66 | Print($"{string.Join(";", array)}"); 67 | } 68 | } 69 | } 70 | } -------------------------------------------------------------------------------- /src/TorchSharpTest/TorchTests/ModuleTest.cs: -------------------------------------------------------------------------------- 1 | using TorchSharpTest.DemoTest; 2 | using static TorchSharp.torch; 3 | 4 | namespace TorchSharpTest.TorchTests 5 | { 6 | public class ModuleTest : AbstractTest 7 | { 8 | public ModuleTest(ITestOutputHelper testOutputHelper) 9 | : base(testOutputHelper) 10 | { 11 | } 12 | 13 | private Device device => new(DeviceType.CUDA); 14 | private string savePath => "test.txt"; 15 | 16 | [Fact] 17 | public void LinearTest() 18 | { 19 | var linear = Linear(4, 5, device: device); 20 | var x = randn(3, 5, 4, device: device); 21 | var y = linear.forward(x); 22 | Print(y); 23 | } 24 | 25 | 26 | [Fact] 27 | public void NetTest() 28 | { 29 | var x = zeros(3, 4).to(device); 30 | 31 | var net = new DemoNet(4, 3).to(device); 32 | var y = net.forward(x); 33 | 34 | Print(y); 35 | } 36 | 37 | [Fact] 38 | public void NetSaveTest() 39 | { 40 | if (File.Exists(savePath))
File.Delete(savePath); 41 | var net = new DemoNet(4, 3); 42 | net.save(savePath); 43 | 44 | var a = from_array(new float[] {1, 2, 3, 4}); 45 | var c = net.forward(a); 46 | var str = string.Join(",", c.data<float>().ToArray()); 47 | Print(str); 48 | 49 | c = net.forward(a); 50 | str = string.Join(",", c.data<float>().ToArray()); 51 | Print(str); 52 | } 53 | 54 | [Fact] 55 | public void NetLoadTest() 56 | { 57 | var net = new DemoNet(4, 3); 58 | net.load(savePath); 59 | var a = from_array(new float[] {1, 2, 3, 4}); 60 | var c = net.forward(a); 61 | var str = string.Join(",", c.data<float>().ToArray()); 62 | Print(str); 63 | } 64 | } 65 | } -------------------------------------------------------------------------------- /src/TorchSharpTest/TorchTests/SaveLoadTest.cs: -------------------------------------------------------------------------------- 1 | namespace TorchSharpTest.TorchTests 2 | { 3 | public class SaveLoadTest : AbstractTest 4 | { 5 | public SaveLoadTest(ITestOutputHelper testOutputHelper) 6 | : base(testOutputHelper) 7 | { 8 | } 9 | 10 | 11 | public string Location => "Test.ts"; 12 | 13 | [Fact] 14 | public void TestSave() 15 | { 16 | if (File.Exists(Location)) File.Delete(Location); 17 | using var conv = Sequential( 18 | Conv2d(100, 10, 5), 19 | Linear(100, 10)); 20 | conv.save(Location); 21 | } 22 | 23 | [Fact] 24 | public void TestLoad() 25 | { 26 | using var loaded = Sequential( 27 | Conv2d(100, 10, 5), 28 | Linear(100, 10)); 29 | loaded.load(Location); 30 | } 31 | } 32 | } -------------------------------------------------------------------------------- /src/TorchSharpTest/TorchTests/TensorTest.cs: -------------------------------------------------------------------------------- 1 | using FluentAssertions; 2 | 3 | namespace TorchSharpTest.TorchTests 4 | { 5 | public class TensorTest : AbstractTest 6 | { 7 | public TensorTest(ITestOutputHelper testOutputHelper) 8 | : base(testOutputHelper) 9 | { 10 | } 11 | 12 | [Fact] 13 | public void CreateRandTensor() 14 | { 15 | var device = new torch.Device(DeviceType.CUDA); 16 | var tensor = torch.randn(3, 5, 4, device: device); 17 | Print(tensor.ToString()); 18 | } 19 | 20 | [Fact] 21 | public void CreateArrayTensor() 22 | { 23 | var tensor = torch.from_array(new float[] {1, 2}).to(DeviceType.CUDA); 24 | Print(tensor); 25 | } 26 | 27 | [Fact] 28 | public void CreateOnesTensor() 29 | { 30 | var tensor = torch.ones(2, 3).to(DeviceType.CUDA); 31 | Print(tensor); 32 | } 33 | 34 | 35 | [Fact] 36 | public void TestAnyAndAll() 37 | { 38 | var a = torch.tensor(new long[] {1, 2, 3}); 39 | var b = torch.tensor(new long[] {0, 0, 0}); 40 | torch.all(b < a).Equals(torch.tensor(true)).Should().Be(true); 41 | } 42 | 43 | 44 | [Fact] 45 | public void TestM() 46 | { 47 | var probs = torch.tensor(new[] {1 / 3f, 1 / 3f, 1 / 3f}); 48 | var sample = torch.multinomial(probs, 1000, true); 49 | 50 | var arr = sample.data<long>(); 51 | var a1 = arr.Count(a => a == 0); 52 | var a2 = arr.Count(a => a == 1); 53 | var a3 = arr.Count(a => a == 2); 54 | Print($"{a1},{a2},{a3}"); 55 | } 56 | 57 | 58 | [Fact] 59 | public void TestGreatAndLess() 60 | { 61 | var input = torch.tensor(new[] {1.1f, 3f, 3.3f, 5.2f}); 62 | var targetLower = 2f; 63 | var targetUpper = 4f; 64 | 65 | var mean = (targetUpper + targetLower) / 2; 66 | var half = (targetUpper - targetLower) / 2; 67 | var dis = 1 - torch.abs(input - mean) / half; 68 | Print(dis); 69 | 70 | var condition = dis.greater(0); 71 | Print(condition); 72 | 73 | var final = dis.where(condition, 0); 74 | Print(final); 75 | } 76 | } 77 | }
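A quick worked reading of TestGreatAndLess above, since the arithmetic is easy to misread: with targetLower = 2f and targetUpper = 4f we get mean = 3 and half = 1, so dis = 1 - |input - 3| / 1 is a triangular score that is 1 at the midpoint of the interval and falls linearly to 0 at its edges. For the input {1.1f, 3f, 3.3f, 5.2f} this evaluates to {-0.9, 1, 0.7, -1.2}; dis.greater(0) masks the negative entries, and dis.where(condition, 0) zeroes them, leaving {0, 1, 0.7, 0}.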
-------------------------------------------------------------------------------- /src/TorchSharpTest/Usings.cs: -------------------------------------------------------------------------------- 1 | global using Xunit; 2 | global using Xunit.Abstractions; 3 | global using TorchSharp; 4 | global using static TorchSharp.torch.nn; --------------------------------------------------------------------------------
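The RL tests above all share one driver pattern: construct an environment, construct an agent over it, then alternate environment Reset with agent Learn until TestEpisodes clears a preset reward threshold, and finally replay a single episode with RunEpisode. A minimal stand-alone sketch of that loop, assuming a console project referencing DeepSharp.RL; it uses only the KArmedBandit and QLearning calls exercised in the tests, while the comments and WriteLine messages are illustrative, not taken from the repository:

using DeepSharp.RL.Agents;
using DeepSharp.RL.Environs;

// Four-armed bandit; the array gives each arm's assumed reward probability.
var bandit = new KArmedBandit(new[] {0.4, 0.85, 0.75, 0.75});
var agent = new QLearning(bandit);

const int testEpisode = 20;   // episodes averaged per evaluation pass
const float predReward = 17f; // stop once evaluation clears this threshold
var i = 0;
float reward;
do
{
    i++;
    bandit.Reset();  // fresh episode for the next learning pass
    agent.Learn();   // one learning iteration against the environment
    reward = agent.TestEpisodes(testEpisode);
} while (reward < predReward);

Console.WriteLine($"Converged after {i} learning iterations: {reward}");
Console.WriteLine(agent.RunEpisode()); // greedy rollout with the learned policy

With the best arm paying off 85% of the time, an average reward of 17 over 20 test episodes is exactly that arm's expected payout, so the loop ends once the agent has locked onto the 0.85 arm.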