├── .github └── FUNDING.yml ├── .gitignore ├── Awesome_Deep_Reinforcement_Learning_List.md ├── LICENSE ├── README.md ├── __init__.py ├── docs ├── .readthedocs.yaml ├── Makefile ├── build │ └── init.py ├── make.bat └── source │ ├── RLSolver │ ├── baseline │ ├── build_models │ ├── helloworld │ ├── overview.rst │ └── problems │ ├── about │ ├── cloud.rst │ ├── ensemble.rst │ ├── init.py │ ├── overview.rst │ └── parallel.rst │ ├── algorithms │ ├── a2c.rst │ ├── ddpg.rst │ ├── double_dqn.rst │ ├── dqn.rst │ ├── init.py │ ├── maddpg.rst │ ├── mappo.rst │ ├── matd3.rst │ ├── ppo.rst │ ├── qmix.rst │ ├── redq.rst │ ├── rode.rst │ ├── sac.rst │ ├── td3.rst │ └── vdn.rst │ ├── api │ ├── .DS_Store │ ├── config.rst │ ├── evaluator.rst │ ├── learner.rst │ ├── replay.rst │ ├── run.rst │ ├── utils.rst │ └── worker.rst │ ├── conf.py │ ├── faq-en.rst │ ├── faq-zh.rst │ ├── helloworld │ ├── agent.rst │ ├── env.rst │ ├── intro.rst │ ├── net.rst │ ├── quickstart.rst │ └── run.rst │ ├── images │ ├── BipedalWalker-v3_1.gif │ ├── BipedalWalker-v3_2.gif │ ├── File_structure.png │ ├── H-term.png │ ├── LunarLander.gif │ ├── LunarLanderTwinDelay3.gif │ ├── bellman.png │ ├── efficiency.png │ ├── envs.png │ ├── fin.png │ ├── framework.png │ ├── framework2.png │ ├── init.py │ ├── isaacgym.gif │ ├── learning_curve.png │ ├── logo.jpg │ ├── logo.png │ ├── overview.jpg │ ├── overview_1.png │ ├── overview_2.png │ ├── overview_3.png │ ├── overview_4.png │ ├── parallelism.png │ ├── performance1.png │ ├── performance2.png │ ├── pseudo.png │ ├── reacher_v2_1.gif │ ├── recursive.png │ ├── samples.png │ ├── tab.png │ ├── test1.png │ ├── test2.png │ └── time.png │ ├── index.rst │ ├── other │ └── faq.rst │ └── tutorial │ ├── BipedalWalker-v3.rst │ ├── Creating_VecEnv.rst │ ├── LunarLanderContinuous-v2.rst │ ├── elegantrl-podracer.rst │ ├── finrl-podracer.rst │ ├── hterm.rst │ ├── isaacgym.rst │ └── redq.rst ├── elegantrl ├── __init__.py ├── agents │ ├── AgentBase.py │ ├── AgentDQN.py │ ├── AgentEmbedDQN.py │ ├── AgentPPO.py │ ├── AgentSAC.py │ ├── AgentTD3.py │ ├── MAgentMADDPG.py │ ├── MAgentMAPPO.py │ ├── MAgentQMix.py │ ├── MAgentVDN.py │ └── __init__.py ├── envs │ ├── CustomGymEnv.py │ ├── PlanIsaacGymEnv.py │ ├── PointChasingEnv.py │ ├── StockTradingEnv.py │ └── __init__.py └── train │ ├── __init__.py │ ├── config.py │ ├── evaluator.py │ ├── replay_buffer.py │ └── run.py ├── examples ├── __init__.py ├── demo_A2C_PPO.py ├── demo_A2C_PPO_discrete.py ├── demo_DDPG_TD3_SAC.py ├── demo_DDPG_TD3_SAC_with_PER.py ├── demo_DQN_variants.py ├── demo_DQN_variants_embed.py ├── demo_FinRL_ElegantRL_China_A_shares.py ├── list_gym_envs.py ├── plan_BipedalWalker-v3.py ├── plan_DDPG_H.py ├── plan_Hopper-v2_H.py ├── plan_Isaac_Gym.py ├── plan_PPO_H.py ├── plan_PaperTradingEnv_PPO.py ├── plan_mujoco_draw_obj_h.py ├── plan_mujoco_render.py ├── tutorial_Hopper-v3.py └── tutorial_LunarLanderContinous-v2.py ├── figs ├── BipdealWalkerHardCore_313score.png ├── BipedalWalkerHardcore-v2-total-668kb.gif ├── ElegantRL.png ├── File_structure.png ├── LunarLanderTwinDelay3.gif ├── RL_survey_2020.pdf ├── RL_survey_2020.png ├── SB3_vs_ElegantRL.png ├── envs.png ├── icon.jpg ├── original.gif ├── performance.png ├── performance1.png └── performance2.png ├── helloworld ├── README.md ├── StockTradingVmapEnv.py ├── erl_agent.py ├── erl_config.py ├── erl_env.py ├── erl_run.py ├── erl_tutorial_DDPG.py ├── erl_tutorial_DQN.py ├── erl_tutorial_PPO.py ├── helloworld_DQN_single_file.py ├── helloworld_PPO_single_file.py ├── helloworld_SAC_TD3_DDPG_single_file.py ├── 
singlfile.rst └── unit_tests │ ├── check_agent.py │ ├── check_config.py │ ├── check_env.py │ ├── check_net.py │ └── check_run.py ├── requirements.txt ├── rlsolver ├── LICENSE ├── README.md ├── __init__.py ├── data │ ├── __init__.py │ ├── gset │ │ └── gset_14.txt │ ├── solomon-instances │ │ └── c101.txt │ ├── syn_BA │ │ └── BA_100_ID0.txt │ ├── syn_ER │ │ └── ER_100_ID0.txt │ ├── syn_PL │ │ └── PL_100_ID0.txt │ └── tsplib │ │ ├── a5.tsp │ │ └── berlin52.tsp ├── docs │ ├── .readthedocs.yaml │ ├── Makefile │ ├── __init__.py │ ├── build │ │ ├── __init__.py │ │ └── init.py │ ├── make.bat │ └── source │ │ ├── RLSolver │ │ ├── baseline │ │ ├── build_models │ │ ├── helloworld │ │ ├── overview.rst │ │ └── problems │ │ ├── __init__.py │ │ ├── about │ │ └── overview.rst │ │ ├── algorithms │ │ └── REINFORCE.rst │ │ ├── api │ │ ├── config.rst │ │ └── utils.rst │ │ ├── conf.py │ │ ├── helloworld │ │ └── quickstart.rst │ │ ├── images │ │ ├── bellman.png │ │ └── envs.png │ │ ├── index.rst │ │ ├── other │ │ └── faq.rst │ │ └── tutorial │ │ └── maxcut.rst ├── envs │ ├── __init__.py │ └── env_mcpg_maxcut.py ├── fig │ ├── RLSolver_framework.png │ ├── RLSolver_structure.png │ ├── objectives_epochs.png │ ├── parallel_sims_maxcut.png │ ├── parallel_sims_pattern.png │ ├── sampling_efficiency_maxcut.png │ ├── speed_up_maxcut1.png │ ├── speed_up_maxcut2.png │ └── work_flow.png ├── main.py ├── methods │ ├── VRPTW_algs │ │ ├── Customer.py │ │ ├── ESPPRC1.py │ │ ├── ESPPRC2.py │ │ ├── ESPPRC_demo.py │ │ ├── ESPPRC_demo2.py │ │ ├── ESPPRC_demo3.py │ │ ├── Label.py │ │ ├── Vehicle.py │ │ ├── column_generation.py │ │ ├── config.py │ │ ├── impact_heuristic.py │ │ ├── main.py │ │ └── util.py │ ├── __init__.py │ ├── config.py │ ├── genetic_algorithm.py │ ├── greedy.py │ ├── gurobi.py │ ├── quantum.py │ ├── random_walk.py │ ├── scip.py │ ├── sdp.py │ ├── simulated_annealing.py │ ├── tsp_algs │ │ ├── __init__.py │ │ ├── christofides.py │ │ ├── config.py │ │ ├── ga.py │ │ ├── gksp.py │ │ ├── ins_c.py │ │ ├── ins_f.py │ │ ├── ins_n.py │ │ ├── main.py │ │ ├── nn.py │ │ ├── opt_2.py │ │ ├── opt_3.py │ │ ├── s_tabu.py │ │ ├── sa.py │ │ └── util.py │ ├── util.py │ ├── util_evaluator.py │ ├── util_generate.py │ ├── util_generate_tsp.py │ ├── util_obj.py │ ├── util_read_data.py │ ├── util_result.py │ └── util_statistics.py ├── requirements.txt └── result │ ├── __init__.py │ ├── automatic_statistical_results.py │ └── c101-10-customers.txt ├── setup.py ├── tutorial_BipedalWalker_v3.ipynb ├── tutorial_Creating_ChasingVecEnv.ipynb ├── tutorial_LunarLanderContinuous_v2.ipynb ├── tutorial_Pendulum_v1.ipynb ├── tutorial_helloworld_DQN_DDPG_PPO.ipynb └── unit_tests ├── __init__.py ├── agents ├── test_agents.py └── test_net.py ├── envs ├── test_env.py ├── test_isaac_env.py └── test_isaac_environments.py └── train ├── test_config.py └── test_evaluator.py /.github/FUNDING.yml: -------------------------------------------------------------------------------- 1 | # These are supported funding model platforms 2 | 3 | github: [BruceYanghy] 4 | open_collective: # Replace with a single Open Collective username 5 | ko_fi: # Replace with a single Ko-fi username 6 | tidelift: # Replace with a single Tidelift platform-name/package-name e.g., npm/babel 7 | community_bridge: # Replace with a single Community Bridge project-name e.g., cloud-foundry 8 | liberapay: # Replace with a single Liberapay username 9 | issuehunt: # Replace with a single IssueHunt username 10 | otechie: # Replace with a single Otechie username 11 | lfx_crowdfunding: # Replace with a 
single LFX Crowdfunding project-name e.g., cloud-foundry 12 | custom: ['paypal.me/Hongyang'] 13 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Python-created cache files 2 | **/__pycache__ 3 | 4 | # Saved actor/critic networks from training 5 | *.pth 6 | 7 | # Plots, recorders, replay buffers 8 | *plot_*.jpg 9 | *recorder.npy 10 | *replay*.npz 11 | 12 | # Runs created by Isaac Gym 13 | runs/ 14 | 15 | # VS Code 16 | **/.vscode 17 | # JetBrains folder 18 | .idea/ 19 | .env-erl/ 20 | -------------------------------------------------------------------------------- /Awesome_Deep_Reinforcement_Learning_List.md: -------------------------------------------------------------------------------- 1 | ## 2 | 3 | ## Distributed Frameworks 4 | 5 | [1] Massively Parallel Methods for Deep Reinforcement Learning (SGD, first distributed architecture, Gorilla DQN). 6 | 7 | [2] Asynchronous Methods for Deep Reinforcement Learning (SGD, A3C). 8 | 9 | [3] Reinforcement Learning through Asynchronous Advantage Actor-Critic on a GPU (A3C on GPU). 10 | 11 | [4] Efficient Parallel Methods for Deep Reinforcement Learning (Batched A2C, GPU). 12 | 13 | [5] Evolution Strategies as a Scalable Alternative to Reinforcement Learning (ES). 14 | 15 | [6] Deep Neuroevolution: Genetic Algorithms Are a Competitive Alternative for Training Deep Neural Networks for 16 | Reinforcement Learning (ES). 17 | 18 | [7] RLlib: Abstractions for Distributed Reinforcement Learning (Library) 19 | 20 | [8] Distributed Deep Reinforcement Learning: Learn how to play Atari games in 21 minutes (Batched A3C). 21 | 22 | [9] Distributed Prioritized Experience Replay (Ape-X, distributed replay buffer). 23 | 24 | [10] IMPALA: Scalable Distributed Deep-RL with Importance Weighted Actor-Learner Architectures (CPU+GPU). 25 | 26 | [11] Accelerated Methods for Deep Reinforcement Learning (Simulation Acceleration). 27 | 28 | [12] GPU-Accelerated Robotic Simulation for Distributed Reinforcement Learning (Simulation Acceleration). 29 | 30 | ## 31 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Copyright 2024 AI4Finance Foundation Inc. 2 | All rights reserved. 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 6 | You may obtain a copy of the License at 7 | 8 | http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 
15 | -------------------------------------------------------------------------------- /__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AI4Finance-Foundation/ElegantRL/8ea76afc3e7f1564ae9f0e69e70254116d575fe9/__init__.py -------------------------------------------------------------------------------- /docs/.readthedocs.yaml: -------------------------------------------------------------------------------- 1 | python: 2 | setup_py_install: true 3 | -------------------------------------------------------------------------------- /docs/Makefile: -------------------------------------------------------------------------------- 1 | # Minimal makefile for Sphinx documentation 2 | # 3 | 4 | # You can set these variables from the command line. 5 | SPHINXOPTS = 6 | SPHINXBUILD = sphinx-build 7 | SOURCEDIR = source 8 | BUILDDIR = build 9 | 10 | # Put it first so that "make" without argument is like "make help". 11 | help: 12 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 13 | 14 | .PHONY: help Makefile 15 | 16 | # Catch-all target: route all unknown targets to Sphinx using the new 17 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). 18 | %: Makefile 19 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) -------------------------------------------------------------------------------- /docs/build/init.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /docs/make.bat: -------------------------------------------------------------------------------- 1 | @ECHO OFF 2 | 3 | pushd %~dp0 4 | 5 | REM Command file for Sphinx documentation 6 | 7 | if "%SPHINXBUILD%" == "" ( 8 | set SPHINXBUILD=sphinx-build 9 | ) 10 | set SOURCEDIR=source 11 | set BUILDDIR=build 12 | 13 | if "%1" == "" goto help 14 | 15 | %SPHINXBUILD% >NUL 2>NUL 16 | if errorlevel 9009 ( 17 | echo. 18 | echo.The 'sphinx-build' command was not found. Make sure you have Sphinx 19 | echo.installed, then set the SPHINXBUILD environment variable to point 20 | echo.to the full path of the 'sphinx-build' executable. Alternatively you 21 | echo.may add the Sphinx directory to PATH. 22 | echo. 
23 | echo.If you don't have Sphinx installed, grab it from 24 | echo.http://sphinx-doc.org/ 25 | exit /b 1 26 | ) 27 | 28 | %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% 29 | goto end 30 | 31 | :help 32 | %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% 33 | 34 | :end 35 | popd 36 | -------------------------------------------------------------------------------- /docs/source/RLSolver/baseline: -------------------------------------------------------------------------------- 1 | gurobi 2 | random_walk 3 | greedy 4 | simulated_annealing 5 | -------------------------------------------------------------------------------- /docs/source/RLSolver/build_models: -------------------------------------------------------------------------------- 1 | build_models 2 | 3 | 4 | **Graph max-cut** 5 | 6 | 7 | 8 | 9 | 10 | ## MIMO 11 | 12 | 13 | 14 | ## Compressive sensing 15 | 16 | -------------------------------------------------------------------------------- /docs/source/RLSolver/helloworld: -------------------------------------------------------------------------------- 1 | helloworld 2 | -------------------------------------------------------------------------------- /docs/source/RLSolver/overview.rst: -------------------------------------------------------------------------------- 1 | Overview 2 | ============= 3 | 4 | One sentence summary: RLSolver is a high-performance RL Solver. 5 | 6 | We aim to find high-quality optimum, or even (nearly) global optimum, for nonconvex/nonlinear optimizations (continuous variables) and combinatorial optimizations (discrete variables). We provide pretrained neural networks to perform real-time inference for nonconvex optimization problems, including combinatorial optimization problems. 7 | 8 | 9 | The following two key technologies are under active development: 10 | - Massively parallel simuations of gym-environments on GPU, using thousands of CUDA cores and tensor cores. 11 | - Podracer scheduling on a GPU cloud, e.g., DGX-2 SuperPod. 12 | 13 | Key references: 14 | - Mazyavkina, Nina, et al. "Reinforcement learning for combinatorial optimization: A survey." Computers & Operations Research 134 (2021): 105400. 15 | 16 | - Bengio, Yoshua, Andrea Lodi, and Antoine Prouvost. "Machine learning for combinatorial optimization: a methodological tour d’horizon." European Journal of Operational Research 290.2 (2021): 405-421. 17 | 18 | - Makoviychuk, Viktor, et al. "Isaac Gym: High performance GPU based physics simulation for robot learning." Thirty-fifth Conference on Neural Information Processing Systems Datasets and Benchmarks Track (Round 2). 2021. 19 | 20 | - Nair, Vinod, et al. "Solving mixed integer programs using neural networks." arXiv preprint arXiv:2012.13349 (2020). 21 | 22 | MCMC: 23 | - Maxcut 24 | - MIMO Beamforming in 5G/6G. 25 | - Classical NP-Hard problems. 26 | - Classical Simulation of Quantum Circuits. 27 | - Compressive Sensing. 28 | - Portfolio Management. 29 | - OR-Gym. 
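As a minimal illustration of the massively parallel idea applied to one of the problems above (graph max-cut), the sketch below scores thousands of candidate cuts in a single batched GPU computation. It is a toy example on a random graph, not RLSolver's actual API; all names in it are placeholders.

.. code-block:: python

    import torch

    # Toy sketch (not RLSolver's API): evaluate thousands of candidate max-cut
    # solutions in one batched GPU operation. A candidate x is a 0/1 vector;
    # an edge (i, j) is cut when x_i != x_j.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    num_nodes, num_sims = 100, 4096

    # Random symmetric adjacency matrix standing in for a Gset/BA/ER instance.
    adj = (torch.rand(num_nodes, num_nodes, device=device) < 0.1).float()
    adj = torch.triu(adj, diagonal=1)
    adj = adj + adj.t()

    xs = torch.randint(0, 2, (num_sims, num_nodes), device=device).float()
    s = 2.0 * xs - 1.0                    # map {0, 1} -> {-1, +1}
    sAs = ((s @ adj) * s).sum(dim=1)      # s^T A s for every candidate at once
    cut_values = (adj.sum() - sAs) / 4.0  # number of cut edges per candidate

    print(f"best of {num_sims} random cuts: {cut_values.max().item():.0f}")
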
30 | 31 | File Structure: 32 | ``` 33 | -RLSolver 34 | -├── opt_methods 35 | -| ├──branch-and-bound.py 36 | -| └──cutting_plane.py 37 | -├── helloworld 38 | -| ├──maxcut 39 | -| ├──data 40 | -| ├──result 41 | -| ├──mcmc.py 42 | -| ├──l2a.py 43 | -└── rlsolver (main folder) 44 | - ├── mcmc 45 | - | ├── _base 46 | - | └── maxcut 47 | - | └── tsp 48 | - | ├── portfolio_management 49 | - |── rlsolver_learn2opt 50 | - | ├── mimo 51 | - | ├── tensor_train 52 | - └── utils 53 | - └── maxcut.py 54 | - └── maxcut_gurobi.py 55 | - └── tsp.py 56 | - └── tsp_gurobi.py 57 | ``` 58 | 59 | 60 | **RLSolver features high-performance and stability:** 61 | 62 | **High-performance**: it can find high-quality optimum, or even (nearly) global optimum. 63 | 64 | **Stable**: it leverages computing resource to implement the Hamiltonian-term as an add-on regularization to DRL algorithms. Such an add-on H-term utilizes computing power (can be computed in parallel on GPU) to search for the "minimum-energy state", corresponding to the stable state of a system. 65 | 66 | 67 | 68 | 69 | 70 | -------------------------------------------------------------------------------- /docs/source/RLSolver/problems: -------------------------------------------------------------------------------- 1 | Classical nonconvex/nonlinear optimizations (continuous variables) and combinatorial optimizations (discrete variables) are listed here. 2 | - Maxcut 3 | - TSP 4 | - MILP 5 | - MIMO 6 | - Compressive sensing 7 | - TNCO 8 | 9 | -------------------------------------------------------------------------------- /docs/source/about/cloud.rst: -------------------------------------------------------------------------------- 1 | Cloud-native Paradigm 2 | ================================= 3 | To the best of our knowledge, ElegantRL is the first open-source cloud-native framework that supports millions of GPU cores to carry out massively parallel DRL training at multiple levels. 4 | 5 | In this article, we will discuss our motivation and cloud-native designs. 6 | 7 | Why cloud-native? 8 | --------------------------------- 9 | 10 | When you need more computing power and storage for your task, running on a cloud may be a more preferable choice than buying racks of machines. Due to its accessible and automated nature, the cloud has been a disruptive force in many deep learning tasks, such as natural langauge processing, image recognition, video synthesis, etc. 11 | 12 | Therefore, we embrace the cloud computing platforms to: 13 | 14 | - build a serverless application framework that performs the entire life-cycle (simulate-learn-deploy) of DRL applications on low-cost cloud computing power. 15 | - support for single-click training for sophisticated DRL problems (compute-intensive and time-consuming) with automatic hyper-parameter tuning. 16 | - provide off-the-shelf APIs to free users from full-stack development and machine learning implementations, e.g., DRL algorithms, ensemble methods, performance analysis. 17 | 18 | Our goal is to allow for wider DRL applications and faster development life cycles that can be created by smaller teams. One simple example of this is the following workflow. 19 | 20 | A user wants to train a trading agent using minute-level NASDAQ 100 constituent stock dataset, a compute-intensive task as the dimensions of the dataset increase, e.g., the number of stocks, the length of period, the number of features. Once the user finishes constructing the environment/simulator, she can directly submit the job to our framework. 
Say the user has no idea which DRL algorithms she should use and how to setup the hyper-parameters, the framework can automatically initialize agents with different algorithms and hyper-parameter to search the best combination. All data is stored in the cloud storage and the computing is parallized on cloud clusters. 21 | 22 | A cloud-native solution 23 | ----------------------------------------------------------------------- 24 | 25 | ElegantRL follows the cloud-native paradigm in the form of microservice, containerization, and orchestration. 26 | 27 | **Microservices**: ElegantRL organizes a DRL agent as a collection of microservices, including orchestrator, worker, learner, evaluator, etc. Each microservice has specialized functionality and connects to other microservices through clear-cut APIs. The microservice structure makes ElegantRL a highly modularized framework and allows practitioners to use and customize without understanding its every detail. 28 | 29 | **Containerization**: An agent is encapsulated into a pod (the basic deployable object in Kubernetes (K8s)), while each microservice within the agent is mapped to a container (a lightweight and portable package of software). On the cloud, microservice and containerization together offer significant benefits in asynchronous parallelism, fault isolation, and security. 30 | 31 | **Orchestration**: ElegantRL employs K8s to orchestrate pods and containers, which automates the deployment and management of the DRL application on the cloud. Our goal is to free developers and practitioners from sophisticated distributed machine learning. 32 | 33 | We provide two different scheduling mechanism on the cloud, namely generational evolution and tournament-based evolution. 34 | 35 | A tutorial on generational evolution is available `here `_. 36 | 37 | A tutorial on tournament-based evolution is available `here `_. 38 | -------------------------------------------------------------------------------- /docs/source/about/ensemble.rst: -------------------------------------------------------------------------------- 1 | Ensemble Methods 2 | =============================== 3 | 4 | -------------------------------------------------------------------------------- /docs/source/about/init.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /docs/source/about/overview.rst: -------------------------------------------------------------------------------- 1 | Key Concepts and Features 2 | ============= 3 | 4 | One sentence summary: in deep reinforcement learning (DRL), an agent learns by continuously interacting with an unknown environment, in a trial-and-error manner, making sequential decisions under uncertainty and achieving a balance between exploration (of uncharted territory) and exploitation (of current knowledge). 5 | 6 | The lifecycle of a DRL application consists of three stages: *simulation*, *learning*, and *deployment*. Our goal is to leverage massive computing power to address three major challenges existed in these three stages: 7 | - simulation speed bottleneck; 8 | - sensitivity to hyper-parameters; 9 | - unstable generalization ability. 
10 | 11 | ElegantRL is a massively parallel framework for cloud-native DRL applications implemented in PyTorch: 12 | - We embrace the accessibility of cloud computing platforms and follow a cloud-native paradigm in the form of containerization, microservices, and orchestration, to ensure fast and robust execution on a cloud. 13 | - We fully exploit the parallelism of DRL algorithms at multiple levels, namely the worker/learner parallelism within a container, the pipeline parallelism (asynchronous execution) over multiple microservices, and the inherent parallelism of the scheduling task at an orchestrator. 14 | - We take advantage of recent technology breakthroughs in massively parallel simulation, population-based training that implicitly searches for optimal hyperparameters, and ensemble methods for variance reduction. 15 | 16 | 17 | **ElegantRL features strong scalability, elasticity and stability and allows practitioners to conduct efficient training from one GPU to hundreds of GPUs on a cloud:** 18 | 19 | **Scalable**: the multi-level parallelism results in high scalability. One can train a population with hundreds of agents, where each agent employs thousands of workers and tens of learners. Therefore, ElegantRL can easily scale out to a cloud with hundreds or thousands of nodes. 20 | 21 | **Elastic**: ElegantRL features strong elasticity on the cloud. The resource allocation can be made according to the numbers of workers, learners, and agents and the unit resource assigned to each of them. We allow a flexible adaptation to meet the dynamic resource availability on the cloud or the demands of practitioners. 22 | 23 | **Stable**: With the massively computing power of a cloud, ensemble methods and population-based training will greatly improve the stability of DRL algorithms. Furthermore, ElegantRL leverages computing resource to implement the Hamiltonian-term as an add-on regularization to model-free DRL algorithms. Such an add-on H-term utilizes computing power (can be computed in parallel on GPU) to search for the "minimum-energy state", corresponding to the stable state of a system. Altogether, ElegantRL demonstrates a much more stable performance compared to Stable-Baseline3, a popular DRL library devote to stability. 24 | 25 | **Accessible**: ElegantRL is a highly modularized framework and maintains ElegantRL-HelloWorld for beginners to get started. We also help users overcome the learning curve by providing API documentations, Colab tutorials, frequently asked questions (FAQs), and demos, e.g., on OpenAI Gym, MuJoCo, Isaac Gym. 26 | 27 | 28 | 29 | 30 | -------------------------------------------------------------------------------- /docs/source/about/parallel.rst: -------------------------------------------------------------------------------- 1 | Muti-level Parallelism 2 | ============================================== 3 | 4 | ElegantRL is a massively parallel framework for DRL algorithms. In this article, we will explain how we map the multi-level parallelism of DRL algorithms to a cloud, namely the worker/learner parallelism within a container, the pipeline parallelism (asynchronous execution) over multiple microservices, and the inherent parallelism of the scheduling task at an orchestrator. 5 | 6 | Here, we follow a *bottom-up* approach to describe the parallelism at multiple levels. 7 | 8 | .. image:: ../images/parallelism.png 9 | :width: 80% 10 | :align: center 11 | 12 | An overview of the multi-level parallelism supported by ElegantRL. 
ElegantRL decomposes an agent into worker (a) and learner (b) and pipes their executions through the pipeline parallelism (c). Besides, ElegantRL emphasizes three types of inherent parallelism in DRL algorithms, including population-based training (PBT) (d1), ensemble methods (d2), and multi-agent DRL (d3). 13 | 14 | Worker/Learner parallelism 15 | ----------------------------------------------------------- 16 | 17 | ElegantRL adopts a worker-learner decomposition of a single agent, decoupling the data sampling process and model learning process. We exploit both the worker parallelism and learner parallelism. 18 | 19 | **Worker parallelism**: a worker generates transitions from interactions of an actor with an environment. As shown in the figure a, ElegantRL supports the recent breakthrough technology, *massively parallel simulation*, with a simulation speedup of 2 ~ 3 orders of magnitude. One GPU can simulate the interactions of one actor with thousands of environments, while existing libraries achieve parallel simulation on hundreds of CPUs. 20 | 21 | Advantage of massively parallel simulation: 22 | - Running thousands of parallel simulations, since the manycore GPU architecture is natually suited for parallel simulations. 23 | - Speeding up the matrix computations of each simulation using GPU tensor cores. 24 | - Reducing the communication overhead by bypassing the bottleneck between CPUs and GPUs. 25 | - Maximizing GPU utilization. 26 | 27 | To achieve massively parallel simulation, ElegantRL supports both user-customized and imported simulator, namely Issac Gym from NVIDIA. 28 | 29 | A tutorial on how to create a GPU-accelerated VecEnv is available `here `_. 30 | 31 | A tutorial on how to utilize Isaac Gym as an imported massively parallel simulator is available `here `_. 32 | 33 | .. note:: 34 | Besides massively parallel simulation on GPUs, we allow users to conduct worker parallelism on classic environments through multiprocessing, e.g., OpenAI Gym and MuJoCo. 35 | 36 | **Learner parallelism**: a learner fetches a batch of transitions to train neural networks, e.g., a critic net and an actor net in the figure b. Multiple critic nets and actor nets of an ensemble method can be trained simultaneously on one GPU. It is different from other libraries that achieve parallel training on multiple CPUs via distributed SGD. 37 | 38 | 39 | Pipeline parallelism 40 | ----------------------------------------------------------- 41 | 42 | We view the worker-learner interaction as a *producer-consumer* model: a worker produces transitions and a learner consumes. As shown in figure c, ElegantRL pipelines the execution of workers and learners, allowing them to run on one GPU asynchronously. We exploit pipeline parallelism in our implementations of off-policy model-free algorithms, including DDPG, TD3, SAC, etc. 43 | 44 | 45 | Inherent parallelism 46 | ----------------------------------------------------------- 47 | ElegantRL supports three types of inherent parallelism in DRL algorithms, including *population-based training*, *ensemble methods*, and *multi-agent DRL*. Each features strong independence and requires little or no communication. 48 | 49 | - Population-based training (PBT): it trains hundreds of agents and obtains a powerful agent, e.g., generational evolution and tournament-based evolution. As shown in figure d1, an agent is encapsulated into a pod on the cloud, whose training is orchestrated by the evaluator and selector of a PBT controller. 
Population-based training implicitly achieves massively parallel hyper-parameter tuning. 50 | - Ensemble methods: it combines the predictions of multiple models and obtains a better result than each individual result, as shown in figure d2. ElegantRL implements various ensemble methods that perform remarkably well in the following scenarios: 51 | 52 | 1. take an average of multiple critic nets to reduce the variance in the estimation of Q-value; 53 | 2. perform a minimization over multiple critic nets to reduce over-estimation bias; 54 | 3. optimize hyper-parameters by initializing agents in a population with different hyper-parameters. 55 | 56 | - Multi-agent DRL: in the cooperative, competitive, or mixed settings of MARL, multiple parallel agents interact with the same environment. During the training process, there is little communication among those parallel agents. 57 | -------------------------------------------------------------------------------- /docs/source/algorithms/a2c.rst: -------------------------------------------------------------------------------- 1 | .. _a2c: 2 | 3 | 4 | A2C 5 | ========== 6 | 7 | `Advantage Actor-Critic (A2C) `_ is a synchronous and deterministic version of Asynchronous Advantage Actor-Critic (A3C). It combines value optimization and policy optimization approaches. This implementation of the A2C algorithm is built on PPO algorithm for simplicity, and it supports the following extensions: 8 | 9 | - Target network: ✔️ 10 | - Gradient clipping: ✔️ 11 | - Reward clipping: ❌ 12 | - Generalized Advantage Estimation (GAE): ✔️ 13 | - Discrete version: ✔️ 14 | 15 | .. warning:: 16 | The implementation of A2C serves as a pedagogical goal. For practitioners, we recommend using the PPO algorithm for training agents. Without the trust-region and clipped ratio, hyper-parameters in A2C, e.g., ``repeat_times``, need to be fine-tuned to avoid performance collapse. 17 | 18 | 19 | Code Snippet 20 | ------------ 21 | 22 | .. code-block:: python 23 | 24 | import torch 25 | from elegantrl.run import train_and_evaluate 26 | from elegantrl.config import Arguments 27 | from elegantrl.train.config import build_env 28 | from elegantrl.agents.AgentA2C import AgentA2C 29 | 30 | # train and save 31 | args = Arguments(env=build_env('Pendulum-v0'), agent=AgentA2C()) 32 | args.cwd = 'demo_Pendulum_A2C' 33 | args.env.target_return = -200 34 | args.reward_scale = 2 ** -2 35 | train_and_evaluate(args) 36 | 37 | # test 38 | agent = AgentA2C() 39 | agent.init(args.net_dim, args.state_dim, args.action_dim) 40 | agent.save_or_load_agent(cwd=args.cwd, if_save=False) 41 | 42 | env = build_env('Pendulum-v0') 43 | state = env.reset() 44 | episode_reward = 0 45 | for i in range(2 ** 10): 46 | action = agent.select_action(state) 47 | next_state, reward, done, _ = env.step(action) 48 | 49 | episode_reward += reward 50 | if done: 51 | print(f'Step {i:>6}, Episode return {episode_reward:8.3f}') 52 | break 53 | else: 54 | state = next_state 55 | env.render() 56 | 57 | 58 | 59 | Parameters 60 | --------------------- 61 | 62 | .. autoclass:: elegantrl.agents.AgentA2C.AgentA2C 63 | :members: 64 | 65 | .. autoclass:: elegantrl.agents.AgentA2C.AgentDiscreteA2C 66 | :members: 67 | 68 | .. _a2c_networks: 69 | 70 | Networks 71 | ------------- 72 | 73 | .. autoclass:: elegantrl.agents.net.ActorPPO 74 | :members: 75 | 76 | .. autoclass:: elegantrl.agents.net.ActorDiscretePPO 77 | :members: 78 | 79 | .. 
autoclass:: elegantrl.agents.net.CriticPPO 80 | :members: 81 | -------------------------------------------------------------------------------- /docs/source/algorithms/ddpg.rst: -------------------------------------------------------------------------------- 1 | .. _ddpg: 2 | 3 | 4 | DDPG 5 | ========== 6 | 7 | `Deep Deterministic Policy Gradient (DDPG) `_ is an off-policy Actor-Critic algorithm for continuous action space. Since computing the maximum over actions in the target is a challenge in continuous action space, DDPG deals with this using a policy network to compute an action. This implementation provides DDPG and supports the following extensions: 8 | 9 | - Experience replay: ✔️ 10 | - Target network: ✔️ 11 | - Gradient clipping: ✔️ 12 | - Reward clipping: ❌ 13 | - Prioritized Experience Replay (PER): ✔️ 14 | - Ornstein–Uhlenbeck noise: ✔️ 15 | 16 | 17 | .. warning:: 18 | In the DDPG paper, the authors use time-correlated Ornstein-Uhlenbeck Process to add noise to the action output. However, as shown in the later works, the Ornstein-Uhlenbeck Process is an overcomplication that does not have a noticeable effect on performance when compared to uncorrelated Gaussian noise. 19 | 20 | Code Snippet 21 | ------------ 22 | 23 | .. code-block:: python 24 | 25 | import torch 26 | from elegantrl.run import train_and_evaluate 27 | from elegantrl.config import Arguments 28 | from elegantrl.train.config import build_env 29 | from elegantrl.agents.AgentDDPG import AgentDDPG 30 | 31 | # train and save 32 | args = Arguments(env=build_env('Pendulum-v0'), agent=AgentDDPG()) 33 | args.cwd = 'demo_Pendulum_DDPG' 34 | args.env.target_return = -200 35 | args.reward_scale = 2 ** -2 36 | train_and_evaluate(args) 37 | 38 | # test 39 | agent = AgentDDPG() 40 | agent.init(args.net_dim, args.state_dim, args.action_dim) 41 | agent.save_or_load_agent(cwd=args.cwd, if_save=False) 42 | 43 | env = build_env('Pendulum-v0') 44 | state = env.reset() 45 | episode_reward = 0 46 | for i in range(2 ** 10): 47 | action = agent.select_action(state) 48 | next_state, reward, done, _ = env.step(action) 49 | 50 | episode_reward += reward 51 | if done: 52 | print(f'Step {i:>6}, Episode return {episode_reward:8.3f}') 53 | break 54 | else: 55 | state = next_state 56 | env.render() 57 | 58 | 59 | 60 | Parameters 61 | --------------------- 62 | 63 | .. autoclass:: elegantrl.agents.AgentDDPG.AgentDDPG 64 | :members: 65 | 66 | .. _ddpg_networks: 67 | 68 | Networks 69 | ------------- 70 | 71 | .. autoclass:: elegantrl.agents.net.Actor 72 | :members: 73 | 74 | .. autoclass:: elegantrl.agents.net.Critic 75 | :members: 76 | -------------------------------------------------------------------------------- /docs/source/algorithms/double_dqn.rst: -------------------------------------------------------------------------------- 1 | .. _dqn: 2 | 3 | 4 | Double DQN 5 | ========== 6 | 7 | `Double Deep Q-Network (Double DQN) `_ is one of the most important extensions of vanilla DQN. It resolves the issue of overestimation via a simple trick: decoupling the max operation in the target into **action selection** and **action evaluation**. 8 | 9 | Without having to introduce additional networks, we use a Q-network to select the best among the available next actions and use the target network to evaluate its Q-value. 
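The sketch below spells out that target computation in PyTorch. It is illustrative only, not ElegantRL's exact implementation; the toy networks and tensor names are placeholders.

.. code-block:: python

    import torch
    import torch.nn as nn

    # Illustrative sketch of the Double DQN target (not ElegantRL's exact code).
    state_dim, action_dim, batch = 4, 2, 32
    q_net = nn.Linear(state_dim, action_dim)        # online Q-network (toy)
    q_target = nn.Linear(state_dim, action_dim)     # target Q-network (toy)

    next_state = torch.randn(batch, state_dim)
    reward = torch.randn(batch)
    done = torch.zeros(batch)
    gamma = 0.99

    with torch.no_grad():
        next_action = q_net(next_state).argmax(dim=1, keepdim=True)      # action selection
        next_q = q_target(next_state).gather(1, next_action).squeeze(1)  # action evaluation
        q_label = reward + gamma * (1.0 - done) * next_q                 # TD target
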
This implementation supports the following extensions: 10 | 11 | - Experience replay: ✔️ 12 | - Target network: ✔️ 13 | - Gradient clipping: ✔️ 14 | - Reward clipping: ❌ 15 | - Prioritized Experience Replay (PER): ✔️ 16 | - Dueling network architecture: ✔️ 17 | 18 | 19 | Code Snippet 20 | ------------ 21 | 22 | .. code-block:: python 23 | 24 | import torch 25 | from elegantrl.run import train_and_evaluate 26 | from elegantrl.config import Arguments 27 | from elegantrl.train.config import build_env 28 | from elegantrl.agents.AgentDoubleDQN import AgentDoubleDQN 29 | 30 | # train and save 31 | args = Arguments(env=build_env('CartPole-v0'), agent=AgentDoubleDQN()) 32 | args.cwd = 'demo_CartPole_DoubleDQN' 33 | args.target_return = 195 34 | train_and_evaluate(args) 35 | 36 | # test 37 | agent = AgentDoubleDQN() 38 | agent.init(args.net_dim, args.state_dim, args.action_dim) 39 | agent.save_or_load_agent(cwd=args.cwd, if_save=False) 40 | 41 | env = build_env('CartPole-v0') 42 | state = env.reset() 43 | episode_reward = 0 44 | for i in range(2 ** 10): 45 | action = agent.select_action(state) 46 | next_state, reward, done, _ = env.step(action) 47 | 48 | episode_reward += reward 49 | if done: 50 | print(f'Step {i:>6}, Episode return {episode_reward:8.3f}') 51 | break 52 | else: 53 | state = next_state 54 | env.render() 55 | 56 | 57 | 58 | Parameters 59 | --------------------- 60 | 61 | .. autoclass:: elegantrl.agents.AgentDoubleDQN.AgentDoubleDQN 62 | :members: 63 | 64 | .. _dqn_networks: 65 | 66 | Networks 67 | ------------- 68 | 69 | .. autoclass:: elegantrl.agents.net.QNetTwin 70 | :members: 71 | 72 | .. autoclass:: elegantrl.agents.net.QNetTwinDuel 73 | :members: 74 | -------------------------------------------------------------------------------- /docs/source/algorithms/dqn.rst: -------------------------------------------------------------------------------- 1 | .. _dqn: 2 | 3 | 4 | DQN 5 | ========== 6 | 7 | `Deep Q-Network (DQN) `_ is an off-policy value-based algorithm for discrete action space. It uses a deep neural network to approximate a Q function defined on state-action pairs. This implementation starts from a vanilla Deep Q-Learning and supports the following extensions: 8 | 9 | - Experience replay: ✔️ 10 | - Target network (soft update): ✔️ 11 | - Gradient clipping: ✔️ 12 | - Reward clipping: ❌ 13 | - Prioritized Experience Replay (PER): ✔️ 14 | - Dueling network architecture: ✔️ 15 | 16 | .. note:: 17 | This implementation has no support for reward clipping because we introduce the hyper-paramter ``reward_scale`` for reward scaling as an alternative. We believe that the clipping function may omit information since it cannot map the clipped reward back to the original reward; however, the reward scaling function is able to manipulate the reward back and forth. 18 | 19 | 20 | .. warning:: 21 | PER leads to a faster learning speed and is also critical for environments with sparse rewards. However, a replay buffer with small size may hurt the performance of PER. 22 | 23 | 24 | Code Snippet 25 | ------------ 26 | 27 | .. 
code-block:: python 28 | 29 | import torch 30 | from elegantrl.run import train_and_evaluate 31 | from elegantrl.config import Arguments 32 | from elegantrl.train.config import build_env 33 | from elegantrl.agents.AgentDQN import AgentDQN 34 | 35 | # train and save 36 | args = Arguments(env=build_env('CartPole-v0'), agent=AgentDQN()) 37 | args.cwd = 'demo_CartPole_DQN' 38 | args.target_return = 195 39 | args.agent.if_use_dueling = True 40 | train_and_evaluate(args) 41 | 42 | # test 43 | agent = AgentDQN() 44 | agent.init(args.net_dim, args.state_dim, args.action_dim) 45 | agent.save_or_load_agent(cwd=args.cwd, if_save=False) 46 | 47 | env = build_env('CartPole-v0') 48 | state = env.reset() 49 | episode_reward = 0 50 | for i in range(2 ** 10): 51 | action = agent.select_action(state) 52 | next_state, reward, done, _ = env.step(action) 53 | 54 | episode_reward += reward 55 | if done: 56 | print(f'Step {i:>6}, Episode return {episode_reward:8.3f}') 57 | break 58 | else: 59 | state = next_state 60 | env.render() 61 | 62 | 63 | 64 | Parameters 65 | --------------------- 66 | 67 | .. autoclass:: elegantrl.agents.AgentDQN.AgentDQN 68 | :members: 69 | 70 | .. _dqn_networks: 71 | 72 | Networks 73 | ------------- 74 | 75 | .. autoclass:: elegantrl.agents.net.QNet 76 | :members: 77 | 78 | .. autoclass:: elegantrl.agents.net.QNetDuel 79 | :members: 80 | -------------------------------------------------------------------------------- /docs/source/algorithms/init.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /docs/source/algorithms/maddpg.rst: -------------------------------------------------------------------------------- 1 | .. _maddpg: 2 | 3 | 4 | MADDPG 5 | ========== 6 | 7 | `Multi-Agent Deep Deterministic Policy Gradient (MADDPG) `_ is a multi-agent reinforcement learning algorithm for continuous action space: 8 | 9 | - Implementation is based on DDPG ✔️ 10 | - Initialize n DDPG agents in MADDPG ✔️ 11 | 12 | Code Snippet 13 | ------------ 14 | 15 | .. code-block:: python 16 | 17 | def update_net(self, buffer, batch_size, repeat_times, soft_update_tau): 18 | buffer.update_now_len() 19 | self.batch_size = batch_size 20 | self.update_tau = soft_update_tau 21 | rewards, dones, actions, observations, next_obs = buffer.sample_batch(self.batch_size) 22 | for index in range(self.n_agents): 23 | self.update_agent(rewards, dones, actions, observations, next_obs, index) 24 | 25 | for agent in self.agents: 26 | self.soft_update(agent.cri_target, agent.cri, self.update_tau) 27 | self.soft_update(agent.act_target, agent.act, self.update_tau) 28 | 29 | return 30 | 31 | Parameters 32 | --------------------- 33 | 34 | .. autoclass:: elegantrl.agents.AgentMADDPG.AgentMADDPG 35 | :members: 36 | .. _maddpg_networks: 37 | 38 | Networks 39 | ------------- 40 | 41 | .. autoclass:: elegantrl.agents.net.Actor 42 | :members: 43 | 44 | .. autoclass:: elegantrl.agents.net.Critic 45 | :members: 46 | 47 | -------------------------------------------------------------------------------- /docs/source/algorithms/mappo.rst: -------------------------------------------------------------------------------- 1 | .. _mappo: 2 | 3 | 4 | MAPPO 5 | ========== 6 | 7 | `Multi-Agent Proximal Policy Optimization (MAPPO) `_ is a variant of PPO which is specialized for multi-agent settings. 
MAPPO achieves surprisingly strong performance in two popular multi-agent testbeds: the particle-world environments and the Starcraft multi-agent challenge. 8 | 9 | - Shared network parameter for all agents ✔️ 10 | 11 | 12 | MAPPO achieves strong results while exhibiting comparable sample efficiency. 13 | 14 | Code Snippet 15 | ------------ 16 | 17 | .. code-block:: python 18 | 19 | def ppo_update(self, sample, update_actor=True): 20 | 21 | share_obs_batch, obs_batch, rnn_states_batch, rnn_states_critic_batch, actions_batch, \ 22 | value_preds_batch, return_batch, masks_batch, active_masks_batch, old_action_log_probs_batch, \ 23 | adv_targ, available_actions_batch = sample 24 | 25 | old_action_log_probs_batch = check(old_action_log_probs_batch).to(**self.tpdv) 26 | adv_targ = check(adv_targ).to(**self.tpdv) 27 | value_preds_batch = check(value_preds_batch).to(**self.tpdv) 28 | return_batch = check(return_batch).to(**self.tpdv) 29 | active_masks_batch = check(active_masks_batch).to(**self.tpdv) 30 | 31 | # Reshape to do in a single forward pass for all steps 32 | values, action_log_probs, dist_entropy = self.policy.evaluate_actions(share_obs_batch, 33 | obs_batch, 34 | rnn_states_batch, 35 | rnn_states_critic_batch, 36 | actions_batch, 37 | masks_batch, 38 | available_actions_batch, 39 | active_masks_batch) 40 | # actor update 41 | imp_weights = torch.exp(action_log_probs - old_action_log_probs_batch) 42 | 43 | surr1 = imp_weights * adv_targ 44 | surr2 = torch.clamp(imp_weights, 1.0 - self.clip_param, 1.0 + self.clip_param) * adv_targ 45 | 46 | 47 | Parameters 48 | --------------------- 49 | 50 | .. autoclass:: elegantrl.agents.AgentMAPPO.AgentMAPPO 51 | :members: 52 | .. _mappo_networks: 53 | 54 | Networks 55 | ------------- 56 | 57 | .. autoclass:: elegantrl.agents.net.ActorMAPPO 58 | :members: 59 | 60 | .. autoclass:: elegantrl.agents.net.CriticMAPPO 61 | :members: 62 | 63 | 64 | 65 | -------------------------------------------------------------------------------- /docs/source/algorithms/matd3.rst: -------------------------------------------------------------------------------- 1 | .. _matd3: 2 | 3 | 4 | MATD3 5 | ========== 6 | 7 | `Multi-Agent TD3 (MATD3) `_ uses double centralized critics to reduce overestimation bias in multi-agent environments. 8 | It combines the improvements of TD3 with MADDPG. 9 | 10 | Code Snippet 11 | ------------ 12 | .. code-block:: python 13 | 14 | def update_net(self, buffer, batch_size, repeat_times, soft_update_tau): 15 | """ 16 | Update the neural networks by sampling batch data from ``ReplayBuffer``. 17 | 18 | :param buffer: the ReplayBuffer instance that stores the trajectories. 19 | :param batch_size: the size of batch data for Stochastic Gradient Descent (SGD). 20 | :param repeat_times: the re-using times of each trajectory. 21 | :param soft_update_tau: the soft update parameter. 22 | :return Nonetype 23 | """ 24 | buffer.update_now_len() 25 | self.batch_size = batch_size 26 | self.update_tau = soft_update_tau 27 | rewards, dones, actions, observations, next_obs = buffer.sample_batch(self.batch_size) 28 | for index in range(self.n_agents): 29 | self.update_agent(rewards, dones, actions, observations, next_obs, index) 30 | 31 | for agent in self.agents: 32 | self.soft_update(agent.cri_target, agent.cri, self.update_tau) 33 | self.soft_update(agent.act_target, agent.act, self.update_tau) 34 | 35 | return 36 | 37 | Parameters 38 | --------------------- 39 | 40 | .. autoclass:: elegantrl.agents.AgentMATD3.AgentMATD3 41 | :members: 42 | .. 
_matd3_networks: 43 | 44 | Networks 45 | ------------- 46 | 47 | .. autoclass:: elegantrl.agents.net.Actor 48 | :members: 49 | 50 | .. autoclass:: elegantrl.agents.net.CriticTwin 51 | :members: 52 | 53 | -------------------------------------------------------------------------------- /docs/source/algorithms/ppo.rst: -------------------------------------------------------------------------------- 1 | .. _ppo: 2 | 3 | 4 | PPO 5 | ========== 6 | 7 | `Proximal Policy Optimization (PPO) `_ is an on-policy Actor-Critic algorithm for both discrete and continuous action spaces. It has two primary variants: **PPO-Penalty** and **PPO-Clip**, where both utilize surrogate objectives to avoid the new policy changing too far from the old policy. This implementation provides PPO-Clip and supports the following extensions: 8 | 9 | - Target network: ✔️ 10 | - Gradient clipping: ✔️ 11 | - Reward clipping: ❌ 12 | - Generalized Advantage Estimation (GAE): ✔️ 13 | - Discrete version: ✔️ 14 | 15 | .. note:: 16 | The surrogate objective is the key feature of PPO since it both regularizes the policy update and enables the reuse of training data. 17 | 18 | A clear explanation of PPO algorithm and implementation in ElegantRL is available `here `_. 19 | 20 | Code Snippet 21 | ------------ 22 | 23 | .. code-block:: python 24 | 25 | import torch 26 | from elegantrl.run import train_and_evaluate 27 | from elegantrl.config import Arguments 28 | from elegantrl.train.config import build_env 29 | from elegantrl.agents.AgentPPO import AgentPPO 30 | 31 | # train and save 32 | args = Arguments(env=build_env('BipedalWalker-v3'), agent=AgentPPO()) 33 | args.cwd = 'demo_BipedalWalker_PPO' 34 | args.env.target_return = 300 35 | args.reward_scale = 2 ** -2 36 | train_and_evaluate(args) 37 | 38 | # test 39 | agent = AgentPPO() 40 | agent.init(args.net_dim, args.state_dim, args.action_dim) 41 | agent.save_or_load_agent(cwd=args.cwd, if_save=False) 42 | 43 | env = build_env('BipedalWalker-v3') 44 | state = env.reset() 45 | episode_reward = 0 46 | for i in range(2 ** 10): 47 | action = agent.select_action(state) 48 | next_state, reward, done, _ = env.step(action) 49 | 50 | episode_reward += reward 51 | if done: 52 | print(f'Step {i:>6}, Episode return {episode_reward:8.3f}') 53 | break 54 | else: 55 | state = next_state 56 | env.render() 57 | 58 | 59 | 60 | Parameters 61 | --------------------- 62 | 63 | .. autoclass:: elegantrl.agents.AgentPPO.AgentPPO 64 | :members: 65 | 66 | .. autoclass:: elegantrl.agents.AgentPPO.AgentDiscretePPO 67 | :members: 68 | 69 | .. _ppo_networks: 70 | 71 | Networks 72 | ------------- 73 | 74 | .. autoclass:: elegantrl.agents.net.ActorPPO 75 | :members: 76 | 77 | .. autoclass:: elegantrl.agents.net.ActorDiscretePPO 78 | :members: 79 | 80 | .. autoclass:: elegantrl.agents.net.CriticPPO 81 | :members: 82 | -------------------------------------------------------------------------------- /docs/source/algorithms/qmix.rst: -------------------------------------------------------------------------------- 1 | .. _qmix: 2 | 3 | 4 | QMix 5 | ========== 6 | 7 | `QMIX: Monotonic Value Function Factorisation for Deep Multi-Agent Reinforcement Learning `_ is a value-based method that can train decentralized policies in a centralized end-to-end fashion. QMIX employs a network that estimates joint action-values as a complex non-linear combination of per-agent values that condition only on local observations. 
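A minimal sketch of such a monotonic mixing network is shown below; it is a simplified stand-in for the idea (hypernetworks producing non-negative weights), not the exact class used in ElegantRL.

.. code-block:: python

    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    class MonotonicMixer(nn.Module):
        """Simplified QMIX-style mixer (illustrative, not ElegantRL's class):
        hypernetworks conditioned on the global state produce non-negative
        weights, so the joint Q_tot is monotonic in every per-agent Q-value."""

        def __init__(self, n_agents: int, state_dim: int, embed_dim: int = 32):
            super().__init__()
            self.n_agents, self.embed_dim = n_agents, embed_dim
            self.hyper_w1 = nn.Linear(state_dim, n_agents * embed_dim)
            self.hyper_b1 = nn.Linear(state_dim, embed_dim)
            self.hyper_w2 = nn.Linear(state_dim, embed_dim)
            self.hyper_b2 = nn.Linear(state_dim, 1)

        def forward(self, agent_qs, state):
            # agent_qs: (batch, n_agents), state: (batch, state_dim)
            w1 = torch.abs(self.hyper_w1(state)).view(-1, self.n_agents, self.embed_dim)
            b1 = self.hyper_b1(state).view(-1, 1, self.embed_dim)
            hidden = F.elu(agent_qs.unsqueeze(1) @ w1 + b1)          # (batch, 1, embed_dim)
            w2 = torch.abs(self.hyper_w2(state)).view(-1, self.embed_dim, 1)
            b2 = self.hyper_b2(state).view(-1, 1, 1)
            return (hidden @ w2 + b2).view(-1, 1)                    # joint Q_tot

    mixer = MonotonicMixer(n_agents=3, state_dim=16)
    q_tot = mixer(torch.randn(8, 3), torch.randn(8, 16))             # shape (8, 1)
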
8 | 9 | - Experience replay: ✔️ 10 | - Target network: ✔️ 11 | - Gradient clipping: ❌ 12 | - Reward clipping: ❌ 13 | - Prioritized Experience Replay (PER): ✔️ 14 | - Ornstein–Uhlenbeck noise: ❌ 15 | 16 | 17 | 18 | Code Snippet 19 | ------------ 20 | 21 | .. code-block:: python 22 | 23 | def train(self, batch, t_env: int, episode_num: int, per_weight=None): 24 | rewards = batch["reward"][:, :-1] 25 | actions = batch["actions"][:, :-1] 26 | terminated = batch["terminated"][:, :-1].float() 27 | mask = batch["filled"][:, :-1].float() 28 | mask[:, 1:] = mask[:, 1:] * (1 - terminated[:, :-1]) 29 | avail_actions = batch["avail_actions"] 30 | 31 | self.mac.agent.train() 32 | mac_out = [] 33 | self.mac.init_hidden(batch.batch_size) 34 | for t in range(batch.max_seq_length): 35 | agent_outs = self.mac.forward(batch, t=t) 36 | mac_out.append(agent_outs) 37 | mac_out = th.stack(mac_out, dim=1) 38 | 39 | chosen_action_qvals = th.gather(mac_out[:, :-1], dim=3, index=actions).squeeze(3) # Remove the last dim 40 | chosen_action_qvals_ = chosen_action_qvals 41 | 42 | with th.no_grad(): 43 | self.target_mac.agent.train() 44 | target_mac_out = [] 45 | self.target_mac.init_hidden(batch.batch_size) 46 | for t in range(batch.max_seq_length): 47 | target_agent_outs = self.target_mac.forward(batch, t=t) 48 | target_mac_out.append(target_agent_outs) 49 | 50 | target_mac_out = th.stack(target_mac_out, dim=1) # Concat across time 51 | 52 | mac_out_detach = mac_out.clone().detach() 53 | mac_out_detach[avail_actions == 0] = -9999999 54 | cur_max_actions = mac_out_detach.max(dim=3, keepdim=True)[1] 55 | target_max_qvals = th.gather(target_mac_out, 3, cur_max_actions).squeeze(3) 56 | 57 | target_max_qvals = self.target_mixer(target_max_qvals, batch["state"]) 58 | 59 | if getattr(self.args, 'q_lambda', False): 60 | qvals = th.gather(target_mac_out, 3, batch["actions"]).squeeze(3) 61 | qvals = self.target_mixer(qvals, batch["state"]) 62 | 63 | targets = build_q_lambda_targets(rewards, terminated, mask, target_max_qvals, qvals, 64 | self.args.gamma, self.args.td_lambda) 65 | else: 66 | targets = build_td_lambda_targets(rewards, terminated, mask, target_max_qvals, 67 | self.args.n_agents, self.args.gamma, self.args.td_lambda) 68 | 69 | chosen_action_qvals = self.mixer(chosen_action_qvals, batch["state"][:, :-1]) 70 | 71 | td_error = (chosen_action_qvals - targets.detach()) 72 | td_error2 = 0.5 * td_error.pow(2) 73 | 74 | mask = mask.expand_as(td_error2) 75 | masked_td_error = td_error2 * mask 76 | 77 | 78 | if self.use_per: 79 | per_weight = th.from_numpy(per_weight).unsqueeze(-1).to(device=self.device) 80 | masked_td_error = masked_td_error.sum(1) * per_weight 81 | 82 | loss = L_td = masked_td_error.sum() / mask.sum() 83 | 84 | 85 | self.optimiser.zero_grad() 86 | loss.backward() 87 | grad_norm = th.nn.utils.clip_grad_norm_(self.params, self.args.grad_norm_clip) 88 | self.optimiser.step() 89 | 90 | 91 | Parameters 92 | --------------------- 93 | 94 | .. autoclass:: elegantrl.agents.AgentQMix.AgentQMix 95 | :members: 96 | 97 | .. _qmix_networks: 98 | 99 | Networks 100 | ------------- 101 | 102 | .. autoclass:: elegantrl.agents.net.QMix 103 | :members: 104 | 105 | .. autoclass:: elegantrl.agents.net.Critic 106 | :members: 107 | -------------------------------------------------------------------------------- /docs/source/algorithms/redq.rst: -------------------------------------------------------------------------------- 1 | .. 
_redq: 2 | 3 | 4 | REDQ 5 | ========== 6 | 7 | `Randomized Ensembled Double Q-Learning: Learning Fast Without a Model (REDQ) `_ has 8 | three carefully integrated ingredients to achieve its high performance: 9 | 10 | - update-to-data (UTD) ratio >> 1. 11 | - an ensemble of Q functions. 12 | - in-target minimization across a random subset of Q functions. 13 | 14 | This implementation is based on SAC. 15 | 16 | 17 | Code Snippet 18 | ------------ 19 | 20 | .. code-block:: python 21 | 22 | import torch 23 | from elegantrl.run import train_and_evaluate 24 | from elegantrl.config import Arguments 25 | from elegantrl.train.config import build_env 26 | from elegantrl.agents.AgentREDQ import AgentREDQ 27 | 28 | # train and save 29 | args = Arguments(env=build_env('Hopper-v2'), agent=AgentREDQ()) 30 | args.cwd = 'demo_Hopper_REDQ' 31 | train_and_evaluate(args) 32 | 33 | # test 34 | agent = AgentREDQ() 35 | agent.init(args.net_dim, args.state_dim, args.action_dim) 36 | agent.save_or_load_agent(cwd=args.cwd, if_save=False) 37 | 38 | env = build_env('Pendulum-v0') 39 | state = env.reset() 40 | episode_reward = 0 41 | for i in range(125000): 42 | action = agent.select_action(state) 43 | next_state, reward, done, _ = env.step(action) 44 | 45 | episode_reward += reward 46 | if done: 47 | print(f'Step {i:>6}, Episode return {episode_reward:8.3f}') 48 | break 49 | else: 50 | state = next_state 51 | env.render() 52 | 53 | Parameters 54 | --------------------- 55 | 56 | .. autoclass:: elegantrl.agents.AgentREDQ.AgentREDQ 57 | :members: 58 | 59 | .. _redq_networks: 60 | 61 | Networks 62 | ------------- 63 | 64 | .. autoclass:: elegantrl.agents.net.ActorSAC 65 | :members: 66 | 67 | .. autoclass:: elegantrl.agents.net.Critic 68 | :members: 69 | 70 | -------------------------------------------------------------------------------- /docs/source/algorithms/rode.rst: -------------------------------------------------------------------------------- 1 | .. _mappo: 2 | 3 | 4 | MAPPO 5 | ========== 6 | 7 | Multi-Agent Proximal Policy Optimization (MAPPO), a variant of PPO, is specialized for multi-agent settings. Using a 1-GPU desktop, we show that MAPPO achieves surprisingly strong performance in two popular multi-agent testbeds: the particle-world environments, and the Starcraft multi-agent challenge. 8 | 9 | - Shared network parameter for all agents ✔️ 10 | - This class is under test, we temporarily add all utils in AgentMAPPO ✔️ 11 | 12 | MAPPO achieves strong results while exhibiting comparable sample efficiency. 13 | 14 | 15 | Parameters 16 | --------------------- 17 | 18 | .. autoclass:: elegantrl.agents.AgentRODE.AgentREDQ 19 | :members: 20 | 21 | .. _redq_networks: 22 | 23 | Networks 24 | ------------- 25 | 26 | .. autoclass:: elegantrl.agents.net.ActorSAC 27 | :members: 28 | 29 | .. autoclass:: elegantrl.agents.net.Critic 30 | :members: 31 | 32 | 33 | -------------------------------------------------------------------------------- /docs/source/algorithms/sac.rst: -------------------------------------------------------------------------------- 1 | .. _sac: 2 | 3 | 4 | SAC 5 | ========== 6 | 7 | `Soft Actor-Critic (SAC) `_ is an off-policy Actor-Critic algorithm for continuous action space. In SAC, it introduces an entropy regularization to the loss function, which has a close connection with the trade-off of the exploration and exploitation. 
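For reference, the entropy-regularized objective that SAC maximizes can be written as follows (the standard formulation from the SAC paper, restated here rather than taken from the ElegantRL source):

.. math::

    J(\pi) = \sum_{t} \mathbb{E}_{(s_t, a_t) \sim \rho_\pi}
    \left[ r(s_t, a_t) + \alpha \, \mathcal{H}\big(\pi(\cdot \mid s_t)\big) \right],

where the temperature :math:`\alpha` controls how strongly the entropy bonus is weighted against the reward.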
In our implementation, we employ a **learnable entropy regularization coefficienct** to dynamic control the scale of the entropy, which makes it consistent with a pre-defined target entropy. SAC also utilizes **Clipped Double-Q Learning** (mentioned in TD3) to overcome the overestimation of Q-values. This implementation provides SAC and supports the following extensions: 8 | 9 | - Experience replay: ✔️ 10 | - Target network: ✔️ 11 | - Gradient clipping: ✔️ 12 | - Reward clipping: ❌ 13 | - Prioritized Experience Replay (PER): ✔️ 14 | - Leanable entropy regularization coefficient: ✔️ 15 | 16 | .. note:: 17 | Inspired by the delayed policy update from TD3, we implement a modified version of SAC ``AgentModSAC`` with a dynamic adjustment of the frequency of the policy update. The adjustment is based on the loss of critic networks: a small loss leads to a high update frequency and vise versa. 18 | 19 | Code Snippet 20 | ------------ 21 | 22 | .. code-block:: python 23 | 24 | import torch 25 | from elegantrl.run import train_and_evaluate 26 | from elegantrl.config import Arguments 27 | from elegantrl.train.config import build_env 28 | from elegantrl.agents.AgentSAC import AgentSAC 29 | 30 | # train and save 31 | args = Arguments(env=build_env('Pendulum-v0'), agent=AgentSAC()) 32 | args.cwd = 'demo_Pendulum_SAC' 33 | args.env.target_return = -200 34 | args.reward_scale = 2 ** -2 35 | train_and_evaluate(args) 36 | 37 | # test 38 | agent = AgentSAC() 39 | agent.init(args.net_dim, args.state_dim, args.action_dim) 40 | agent.save_or_load_agent(cwd=args.cwd, if_save=False) 41 | 42 | env = build_env('Pendulum-v0') 43 | state = env.reset() 44 | episode_reward = 0 45 | for i in range(2 ** 10): 46 | action = agent.select_action(state) 47 | next_state, reward, done, _ = env.step(action) 48 | 49 | episode_reward += reward 50 | if done: 51 | print(f'Step {i:>6}, Episode return {episode_reward:8.3f}') 52 | break 53 | else: 54 | state = next_state 55 | env.render() 56 | 57 | 58 | 59 | Parameters 60 | --------------------- 61 | 62 | .. autoclass:: elegantrl.agents.AgentSAC.AgentSAC 63 | :members: 64 | 65 | .. autoclass:: elegantrl.agents.AgentSAC.AgentModSAC 66 | :members: 67 | 68 | .. _sac_networks: 69 | 70 | Networks 71 | ------------- 72 | 73 | .. autoclass:: elegantrl.agents.net.ActorSAC 74 | :members: 75 | 76 | .. autoclass:: elegantrl.agents.net.CriticTwin 77 | :members: 78 | -------------------------------------------------------------------------------- /docs/source/algorithms/td3.rst: -------------------------------------------------------------------------------- 1 | .. _td3: 2 | 3 | 4 | TD3 5 | ========== 6 | 7 | `Twin Delayed DDPG (TD3) `_ is a successor of DDPG algorithm with the usage of three additional tricks. In TD3, the usage of **Clipped Double-Q Learning**, **Delayed Policy Updates**, and **Target Policy Smoothing** overcomes the overestimation of Q-values and smooths out Q-values along with changes in action, which shows improved performance over baseline DDPG. This implementation provides TD3 and supports the following extensions: 8 | 9 | - Experience replay: ✔️ 10 | - Target network: ✔️ 11 | - Gradient clipping: ✔️ 12 | - Reward clipping: ❌ 13 | - Prioritized Experience Replay (PER): ✔️ 14 | 15 | .. note:: 16 | With respect to the clipped Double-Q learning, we use two Q-networks with shared parameters under a single Class ``CriticTwin``. Such an implementation allows a lower computational and training time cost. 17 | 18 | .. 
warning:: 19 | In the TD3 implementation, it contains a number of highly sensitive hyper-parameters, which requires the user to carefully tune these hyper-parameters to obtain a satisfied result. 20 | 21 | Code Snippet 22 | ------------ 23 | 24 | .. code-block:: python 25 | 26 | import torch 27 | from elegantrl.run import train_and_evaluate 28 | from elegantrl.config import Arguments 29 | from elegantrl.train.config import build_env 30 | from elegantrl.agents.AgentTD3 import AgentTD3 31 | 32 | # train and save 33 | args = Arguments(env=build_env('Pendulum-v0'), agent=AgentTD3()) 34 | args.cwd = 'demo_Pendulum_TD3' 35 | args.env.target_return = -200 36 | args.reward_scale = 2 ** -2 37 | train_and_evaluate(args) 38 | 39 | # test 40 | agent = AgentTD3() 41 | agent.init(args.net_dim, args.state_dim, args.action_dim) 42 | agent.save_or_load_agent(cwd=args.cwd, if_save=False) 43 | 44 | env = build_env('Pendulum-v0') 45 | state = env.reset() 46 | episode_reward = 0 47 | for i in range(2 ** 10): 48 | action = agent.select_action(state) 49 | next_state, reward, done, _ = env.step(action) 50 | 51 | episode_reward += reward 52 | if done: 53 | print(f'Step {i:>6}, Episode return {episode_reward:8.3f}') 54 | break 55 | else: 56 | state = next_state 57 | env.render() 58 | 59 | 60 | 61 | Parameters 62 | --------------------- 63 | 64 | .. autoclass:: elegantrl.agents.AgentTD3.AgentTD3 65 | :members: 66 | 67 | .. _td3_networks: 68 | 69 | Networks 70 | ------------- 71 | 72 | .. autoclass:: elegantrl.agents.net.Actor 73 | :members: 74 | 75 | .. autoclass:: elegantrl.agents.net.CriticTwin 76 | :members: 77 | -------------------------------------------------------------------------------- /docs/source/algorithms/vdn.rst: -------------------------------------------------------------------------------- 1 | .. _vdn: 2 | 3 | 4 | VDN 5 | ========== 6 | 7 | `Value Decomposition Networks (VDN) `_ trains individual agents with a novel value decomposition network architecture, which learns to decompose the team value function into agent-wise value functions. 8 | 9 | Code Snippet 10 | ------------ 11 | .. code-block:: python 12 | 13 | def train(self, batch, t_env: int, episode_num: int): 14 | 15 | # Get the relevant quantities 16 | rewards = batch["reward"][:, :-1] 17 | actions = batch["actions"][:, :-1] 18 | terminated = batch["terminated"][:, :-1].float() 19 | mask = batch["filled"][:, :-1].float() 20 | mask[:, 1:] = mask[:, 1:] * (1 - terminated[:, :-1]) 21 | avail_actions = batch["avail_actions"] 22 | 23 | # Calculate estimated Q-Values 24 | mac_out = [] 25 | self.mac.init_hidden(batch.batch_size) 26 | for t in range(batch.max_seq_length): 27 | agent_outs = self.mac.forward(batch, t=t) 28 | mac_out.append(agent_outs) 29 | mac_out = th.stack(mac_out, dim=1) # Concat over time 30 | 31 | # Pick the Q-Values for the actions taken by each agent 32 | chosen_action_qvals = th.gather(mac_out[:, :-1], dim=3, index=actions).squeeze(3) # Remove the last dim 33 | 34 | # Calculate the Q-Values necessary for the target 35 | target_mac_out = [] 36 | self.target_mac.init_hidden(batch.batch_size) 37 | for t in range(batch.max_seq_length): 38 | target_agent_outs = self.target_mac.forward(batch, t=t) 39 | target_mac_out.append(target_agent_outs) 40 | 41 | 42 | 43 | Parameters 44 | --------------------- 45 | 46 | .. autoclass:: elegantrl.agents.AgentVDN.AgentVDN 47 | :members: 48 | 49 | .. _vdn_networks: 50 | Networks 51 | ------------- 52 | 53 | .. 
autoclass:: elegantrl.agents.net.VDN 54 | :members: 55 | -------------------------------------------------------------------------------- /docs/source/api/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AI4Finance-Foundation/ElegantRL/8ea76afc3e7f1564ae9f0e69e70254116d575fe9/docs/source/api/.DS_Store -------------------------------------------------------------------------------- /docs/source/api/config.rst: -------------------------------------------------------------------------------- 1 | Configuration: *config.py* 2 | ========================== 3 | 4 | 5 | ``Arguments`` 6 | --------------------- 7 | 8 | The ``Arguments`` class contains all parameters of the training process, including environment setup, model training, model evaluation, and resource allocation. It provides users a unified interface to customize the training process. 9 | 10 | The class should be initialized at the start of the training process. For example, 11 | 12 | .. code-block:: python 13 | 14 | from elegantrl.train.config import Arguments 15 | from elegantrl.agents.AgentPPO import AgentPPO 16 | from elegantrl.train.config import build_env 17 | import gym 18 | 19 | args = Arguments(build_env('Pendulum-v1'), AgentPPO()) 20 | 21 | The full list of parameters in ``Arguments``: 22 | 23 | .. autoclass:: elegantrl.train.config.Arguments 24 | :members: 25 | 26 | 27 | Environment registration 28 | --------------------- 29 | 30 | .. autofunction:: elegantrl.train.config.build_env 31 | 32 | .. autofunction:: elegantrl.train.config.check_env 33 | 34 | 35 | Utils 36 | --------------------- 37 | 38 | .. autofunction:: elegantrl.train.config.kwargs_filter 39 | -------------------------------------------------------------------------------- /docs/source/api/evaluator.rst: -------------------------------------------------------------------------------- 1 | Evaluator: *evaluator.py* 2 | =============================== 3 | 4 | During training, ElegantRL provides an ``evaluator`` to periodically evaluate the agent's performance and save models. 5 | 6 | For agent evaluation, the evaluator runs the agent's actor (policy) network on the testing environment and outputs the corresponding scores. Commonly used performance metrics are the mean and variance of episodic rewards. The score is useful in the following two cases: 7 | - Case 1: the score serves as a goal signal. When the score reaches the target score, the goal of the task is achieved. 8 | - Case 2: the score serves as a criterion to detect overfitting. When the score keeps dropping, we can terminate the training process early to mitigate the performance collapse and the waste of computing power caused by overfitting. 9 | 10 | .. note:: 11 | ElegantRL supports a tournament-based ensemble training scheme to empower population-based training (PBT). We maintain a leaderboard to keep track of agents with high scores and then perform a tournament-based evolution among these agents. In this case, the score from the evaluator serves as the metric for the leaderboard. 12 | 13 | For model saving, the evaluator saves the following three types of files: 14 | - actor.pth: the actor (policy) network of the agent. 15 | - plot_learning_curve.jpg: the learning curve of the agent. 16 | - recorder.npy: the log file, including total training steps, reward average, reward standard deviation, reward exp, actor loss, and critic loss (a loading sketch is shown below).
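The exact recorder layout is defined in *evaluator.py*; the snippet below is only a rough, hedged illustration for offline inspection. It assumes that ``recorder.npy`` is a 2-D array with one row per evaluation whose first three columns are the total training steps, the reward average, and the reward standard deviation, in that order, and that it sits in an example working directory such as ``./demo_Pendulum_SAC``.

.. code-block:: python

    import numpy as np
    import matplotlib.pyplot as plt

    # Assumed column layout (see the list above): one row per evaluation,
    # columns = (total steps, reward avg, reward std, reward exp, actor loss, critic loss).
    recorder = np.load('./demo_Pendulum_SAC/recorder.npy')

    steps = recorder[:, 0]
    reward_avg = recorder[:, 1]
    reward_std = recorder[:, 2]

    plt.plot(steps, reward_avg, label='episodic reward (avg)')
    plt.fill_between(steps, reward_avg - reward_std, reward_avg + reward_std, alpha=0.3)
    plt.xlabel('total training steps')
    plt.ylabel('episodic reward')
    plt.legend()
    plt.savefig('my_learning_curve.jpg')

Such an offline plot is handy for comparing several runs after training, independently of the ``plot_learning_curve.jpg`` that the evaluator already writes.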
17 | 18 | We implement the ``evaluator`` as a microservice, which can be run as an independent process. When an evaluator is running, it automatically monitors the parallel agents, provides an evaluation whenever an agent requests one, and communicates agent information with the leaderboard. 19 | 20 | Implementations 21 | --------------------- 22 | 23 | .. autoclass:: elegantrl.train.evaluator.Evaluator 24 | :members: 25 | 26 | 27 | Utils 28 | --------------------- 29 | 30 | .. autofunction:: elegantrl.train.evaluator.get_episode_return_and_step 31 | 32 | .. autofunction:: elegantrl.train.evaluator.save_learning_curve 33 | 34 | -------------------------------------------------------------------------------- /docs/source/api/learner.rst: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AI4Finance-Foundation/ElegantRL/8ea76afc3e7f1564ae9f0e69e70254116d575fe9/docs/source/api/learner.rst -------------------------------------------------------------------------------- /docs/source/api/replay.rst: -------------------------------------------------------------------------------- 1 | Replay Buffer: *replay_buffer.py* 2 | ================================= 3 | 4 | ElegantRL provides ``ReplayBuffer`` to store sampled transitions. 5 | 6 | In ElegantRL, we utilize ``Worker`` for exploration (data sampling) and ``Learner`` for exploitation (model learning), and we view this relationship as a "producer-consumer" model: a worker produces transitions, a learner consumes them, and the learner updates the actor network at the worker to produce new transitions. In this case, the ``ReplayBuffer`` is the storage buffer that connects the worker and the learner. 7 | 8 | Each transition is stored in the format ``(state, (reward, done, action))``. 9 | 10 | .. note:: 11 | We allocate the ``ReplayBuffer`` in contiguous memory for high-performance training. Since the collected transitions are packed in sequence, the addressing speed increases dramatically when a learner randomly samples a batch of transitions. 12 | 13 | Implementations 14 | --------------------- 15 | 16 | .. autoclass:: elegantrl.train.replay_buffer.ReplayBuffer 17 | :members: 18 | 19 | Multiprocessing 20 | --------------------- 21 | 22 | .. autoclass:: elegantrl.train.replay_buffer.ReplayBufferMP 23 | :members: 24 | 25 | Initialization 26 | --------------------- 27 | 28 | .. autofunction:: elegantrl.train.replay_buffer.init_replay_buffer 29 | 30 | Utils 31 | --------------------- 32 | 33 | .. autoclass:: elegantrl.train.replay_buffer.BinarySearchTree 34 | -------------------------------------------------------------------------------- /docs/source/api/run.rst: -------------------------------------------------------------------------------- 1 | Run: *run.py* 2 | ================================= 3 | 4 | In *run.py*, we provide functions that wrap the training (and evaluation) process. 5 | 6 | In ElegantRL, users follow a **two-step procedure** to train an agent in a lightweight and automatic way: 7 | 8 | 1. Initialize the agent and environment, and set up hyper-parameters in ``Arguments``. 9 | 2. Pass the ``Arguments`` to a training function, e.g., ``train_and_evaluate`` for single-process training or ``train_and_evaluate_mp`` for multi-process training. 10 | 11 | Let's look at a demo of this simple two-step procedure. 12 | 13 | .. 
code-block:: python 14 | 15 | from elegantrl.train.config import Arguments 16 | from elegantrl.train.run import train_and_evaluate, train_and_evaluate_mp 17 | from elegantrl.envs.Chasing import ChasingEnv 18 | from elegantrl.agents.AgentPPO import AgentPPO 19 | 20 | # Step 1 21 | args = Arguments(agent=AgentPPO(), env_func=ChasingEnv) 22 | 23 | # Step 2 24 | train_and_evaluate_mp(args) 25 | 26 | Single-process 27 | --------------------- 28 | 29 | .. autofunction:: elegantrl.train.run.train_and_evaluate 30 | 31 | Multi-process 32 | --------------------- 33 | 34 | .. autofunction:: elegantrl.train.run.train_and_evaluate_mp 35 | 36 | Utils 37 | --------------------- 38 | 39 | .. autoclass:: elegantrl.train.run.safely_terminate_process 40 | 41 | .. autoclass:: elegantrl.train.run.check_subprocess 42 | -------------------------------------------------------------------------------- /docs/source/api/utils.rst: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /docs/source/api/worker.rst: -------------------------------------------------------------------------------- 1 | Worker: *worker.py* 2 | ================================= 3 | 4 | Deep reinforcement learning (DRL) employs a trial-and-error manner to collect training data (transitions) from agent-environment interactions, along with the learning procedure. ElegantRL utilizes ``Worker`` to generate transitions and achieves worker parallelism, thus greatly speeding up the data collection. 5 | 6 | Implementations 7 | --------------------- 8 | 9 | .. autoclass:: elegantrl.train.worker.PipeWorker 10 | :members: 11 | 12 | -------------------------------------------------------------------------------- /docs/source/faq-zh.rst: -------------------------------------------------------------------------------- 1 | FAQ 2 | 3 | 4 | ^^^^^^^^ 5 | 问题 1:在强化学习代码中,对log值裁剪到 -20 到 +2 之间是在进行什么操作?为什么要裁剪到这两个值之间? 6 | ^^^^^^^^ 7 | 8 | 在强化学习中,我们举两类对log值进行裁剪的例子: 9 | 10 | - 对随机策略的动作的高斯分布的方差的log值 `action_std_log` 进行裁剪 11 | 12 | - 对正态分布中对应的概率的log值 (log probability) `logprob` 进行裁剪 13 | 14 | 简单说,就是相对于正态分布 N~(0, 1) 来说,一个高斯分布的方差的log值如果超过 (-20, +2) 这个区间,那么: 15 | 16 | - 如果log值小于 -20,那么这个高斯分布的方差特别小,相当于没有方差,接近于一个确定的数值。 17 | 18 | - 如果log值大于 +2,那么这个高斯分布的方差特别大,相当于在接近均值附近是均匀分布。 19 | 20 | 有空我就展开讲一讲。 21 | 22 | ----------------- 23 | 对随机策略的动作的高斯分布的方差的log值 `action_std_log` 进行裁剪 24 | ----------------- 25 | 对应代码是 `action_std = self.net_action_std(t_tmp).clip(-20, 2).exp()`, 可以在 `elegantrl/net.py` 里找到。 26 | 27 | SAC算法的 `alpha_log` 也能进行类似的裁剪 28 | 29 | 还可以讲一讲 强化学习里,把权重处理成 log 形式再进行梯度优化。 30 | 31 | 有空我就展开讲一讲。或者你们来补充(2022-06-08 18:01:54) 32 | 33 | ----------------- 34 | 对正态分布中对应的概率的log值 (log probability) `logprob` 进行裁剪 35 | ----------------- 36 | 对应代码是 `logprob = logprob.clip(-20, 2)`, 有可能在 `elegantrl/agent/` 里的随机策略梯度算法里找到,因为随机策略梯度算法会用到 `logprob`。 37 | 38 | 有空我就展开讲一讲。或者你们来补充(2022-06-08 18:01:54) 39 | 40 | 41 | ^^^^^^^^ 42 | 问题:On-policy 和 off-policy 的区别是什么? 
43 | ^^^^^^^^ 44 | 若行为策略和目标策略相同,则是on-policy,若不同则为off-policy 45 | 46 | 有空我就展开讲一讲。 47 | 48 | 49 | ^^^^^^^^ 50 | 问题: elegantrl RLlib SB3 对比 51 | ^^^^^^^^ 52 | 53 | RLlib 的优势: 54 | 55 | - 他们有ray 可以调度多卡之间的传输,多卡的时候选择 生产者-消费者 模式,保证把计算资源用满 56 | 57 | - RLlib的复杂代码把RL过程抽象出来了,他们可以选择 TensorFlow 或者 PyTorch 作为深度学习的后端 58 | 59 | RLlib的劣势: 60 | 61 | - 让 生产者worker 和 消费者 learner 异步的方案,数据不够新,虽然计算资源用尽了,但是计算效率降低了 62 | 63 | - 现在大家都 PyTorch ,RLlib 的代码太复杂了,用起来有门槛,反而不容易用 64 | 65 | ELegantRL 在单卡上和 RLlib比较: 66 | 67 | - 论文写了,我们让 一张GPU运行完worker,就让 learner直接用 worker收集到的data,数据不用挪动,因此快。 68 | 69 | - 我们的代码从 worker 到 learner都支持了 vectorized env,(我不清楚现在RLlib 的worker 是否支持 vectorized env ,但他们的 learner 支持不了) 70 | 71 | - 我们还开发了 vwap 的 vectorized env,而不只是 stable baselines3 或者 天授的EnvPool 那种 subprocessing vectorized env 72 | 73 | ELegantRL 在多卡上和RLlib比较: 74 | 只能说是各有优劣,不能说谁的方案更适合某个 DRL算法或者某个 任务。 75 | 我们在金融任务上,使用了 PPO+Podracer,而不是 RLlib 的 Producter&Comsumer 的模式,让PPO算法的数据利用效率更高,而且我们还套了一层 遗传算法在外面方便跳出局部最优,达到更好的次优。 76 | 77 | 78 | 比SB3更稳定,是因为,ELegantRL 和 sb3在以下两点差别明显: 79 | 80 | - 我们还开发了 vmap 的 vectorized env,而不只是 stable baselines3 或者 天授的EnvPool 那种 subprocessing vectorized env,用GPU做仿真环境的并行(StockTradingVecEnv),采集数据量多了2个数量级以上,数据多,所以训练稳定 81 | 82 | - 我们用了 H term,这个真的有用,可以让训练变稳定:在根据符合贝尔曼公式Q值的优化方向的基础上,再使用一个 H term 找出另一个优化方向,两个优化方向同时使用,更不容易掉进局部最优,所以稳定(可惜当前ELegantRL库 只有 半年前的代码支持了 H term,还需要人手把 Hterm 的代码升级到 2023年2月份的版本) 83 | 84 | 比SB3快: 85 | 86 | - 我们的 ReplayBuffer优化过,按顺序储存 state,所以不需要重复保存 state_t 和 state_t+1,再加上我们的ReplayBuffer都是 PyTorch tensor 格式 + 指针,抽取数据没有用PyTorch自带的 dataLoader,而是自己写的,因此快 87 | 88 | - 我们的 worker 和 learner 都有 针对 vectorized env 的优化,sb3没有 89 | 90 | - 我们给FinRL任务以及 RLSolver任务 开发了 GPU并行仿真环境,sb3 没有 91 | 92 | 93 | 94 | 95 | -------------------------------------------------------------------------------- /docs/source/helloworld/agent.rst: -------------------------------------------------------------------------------- 1 | Agents: *agent.py* 2 | ================== 3 | 4 | In this HelloWorld, we focus on DQN, SAC, and PPO, which are the most representative and commonly used DRL algorithms. 5 | 6 | Agents 7 | ------------------------------ 8 | 9 | .. autoclass:: elegantrl_helloworld.agent.AgentBase 10 | :members: 11 | 12 | .. autoclass:: elegantrl_helloworld.agent.AgentDQN 13 | :members: 14 | 15 | .. autoclass:: elegantrl_helloworld.agent.AgentSAC 16 | :members: 17 | 18 | .. autoclass:: elegantrl_helloworld.agent.AgentPPO 19 | :members: 20 | 21 | .. autoclass:: elegantrl_helloworld.agent.AgentDiscretePPO 22 | :members: 23 | 24 | 25 | Replay Buffer 26 | --------------------------------------------- 27 | 28 | .. autoclass:: elegantrl_helloworld.agent.ReplayBuffer 29 | :members: 30 | 31 | .. autoclass:: elegantrl_helloworld.agent.ReplayBufferList 32 | :members: 33 | -------------------------------------------------------------------------------- /docs/source/helloworld/env.rst: -------------------------------------------------------------------------------- 1 | Environment: *env.py* 2 | ===================== 3 | 4 | .. autofunction:: elegantrl_helloworld.env.get_gym_env_args 5 | 6 | .. autofunction:: elegantrl_helloworld.env.kwargs_filter 7 | 8 | .. autofunction:: elegantrl_helloworld.env.build_env 9 | -------------------------------------------------------------------------------- /docs/source/helloworld/net.rst: -------------------------------------------------------------------------------- 1 | Networks: *net.py* 2 | ================== 3 | 4 | In ElegantRL, there are three basic network classes: Q-net, Actor, and Critic. 
Here, we list several examples, which are the networks used by DQN, SAC, and PPO algorithms. 5 | 6 | The full list of networks are available `here `_ 7 | 8 | Q Net 9 | ----- 10 | 11 | .. autoclass:: elegantrl_helloworld.net.QNet 12 | :members: 13 | 14 | Actor Network 15 | ------------- 16 | 17 | .. autoclass:: elegantrl_helloworld.net.ActorSAC 18 | :members: 19 | 20 | .. autoclass:: elegantrl_helloworld.net.ActorPPO 21 | :members: 22 | 23 | .. autoclass:: elegantrl_helloworld.net.ActorDiscretePPO 24 | :members: 25 | 26 | Critic Network 27 | -------------- 28 | 29 | .. autoclass:: elegantrl_helloworld.net.CriticTwin 30 | :members: 31 | 32 | .. autoclass:: elegantrl_helloworld.net.CriticPPO 33 | :members: 34 | -------------------------------------------------------------------------------- /docs/source/helloworld/quickstart.rst: -------------------------------------------------------------------------------- 1 | Quickstart 2 | ============= 3 | 4 | As a quickstart, we select the Pendulum task from the demo.py to show how to train a DRL agent in ElegantRL. 5 | 6 | Step 1: Import packages 7 | ------------------------------- 8 | 9 | .. code-block:: python 10 | 11 | from elegantrl_helloworld.demo import * 12 | 13 | gym.logger.set_level(40) # Block warning 14 | 15 | Step 2: Specify Agent and Environment 16 | -------------------------------------- 17 | 18 | .. code-block:: python 19 | 20 | env = PendulumEnv('Pendulum-v0', target_return=-500) 21 | args = Arguments(AgentSAC, env) 22 | 23 | Part 3: Specify Hyper-parameters 24 | -------------------------------------- 25 | 26 | .. code-block:: python 27 | 28 | args.reward_scale = 2 ** -1 # RewardRange: -1800 < -200 < -50 < 0 29 | args.gamma = 0.97 30 | args.target_step = args.max_step * 2 31 | args.eval_times = 2 ** 3 32 | 33 | Step 4: Train and Evaluate the Agent 34 | -------------------------------------- 35 | 36 | .. code-block:: python 37 | 38 | train_and_evaluate(args) 39 | 40 | Try by yourself through this `Colab `_! 41 | 42 | .. tip:: 43 | - By default, it will train a stable-SAC agent in the Pendulum-v0 environment for 400 seconds. 44 | 45 | - It will choose to utilize CPUs or GPUs automatically. Don't worry, we never use ``.cuda()``. 46 | 47 | - It will save the log and model parameters file in ``'./{Environment}_{Agent}_{GPU_ID}'``. 48 | 49 | - It will print the total reward while training. (Maybe we should use TensorBoardX?) 50 | 51 | - The code is heavily commented. We believe these comments can answer some of your questions. 52 | -------------------------------------------------------------------------------- /docs/source/helloworld/run.rst: -------------------------------------------------------------------------------- 1 | Main: *run.py* 2 | ============== 3 | 4 | Hyper-parameters 5 | ------------------------------------------------------------ 6 | 7 | .. autoclass:: elegantrl_helloworld.run.Arguments 8 | :members: 9 | 10 | 11 | Train and Evaluate 12 | ----------------------------------------------------------- 13 | 14 | .. autofunction:: elegantrl_helloworld.run.train_and_evaluate 15 | 16 | .. autofunction:: elegantrl_helloworld.run.init_agent 17 | 18 | .. autofunction:: elegantrl_helloworld.run.init_evaluator 19 | 20 | .. autofunction:: elegantrl_helloworld.run.init_buffer 21 | 22 | 23 | Evaluator 24 | --------------------------------- 25 | 26 | .. autoclass:: elegantrl_helloworld.run.Evaluator 27 | :members: 28 | 29 | .. 
autofunction:: elegantrl_helloworld.run.get_episode_return_and_step -------------------------------------------------------------------------------- /docs/source/images/BipedalWalker-v3_1.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AI4Finance-Foundation/ElegantRL/8ea76afc3e7f1564ae9f0e69e70254116d575fe9/docs/source/images/BipedalWalker-v3_1.gif -------------------------------------------------------------------------------- /docs/source/images/BipedalWalker-v3_2.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AI4Finance-Foundation/ElegantRL/8ea76afc3e7f1564ae9f0e69e70254116d575fe9/docs/source/images/BipedalWalker-v3_2.gif -------------------------------------------------------------------------------- /docs/source/images/File_structure.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AI4Finance-Foundation/ElegantRL/8ea76afc3e7f1564ae9f0e69e70254116d575fe9/docs/source/images/File_structure.png -------------------------------------------------------------------------------- /docs/source/images/H-term.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AI4Finance-Foundation/ElegantRL/8ea76afc3e7f1564ae9f0e69e70254116d575fe9/docs/source/images/H-term.png -------------------------------------------------------------------------------- /docs/source/images/LunarLander.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AI4Finance-Foundation/ElegantRL/8ea76afc3e7f1564ae9f0e69e70254116d575fe9/docs/source/images/LunarLander.gif -------------------------------------------------------------------------------- /docs/source/images/LunarLanderTwinDelay3.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AI4Finance-Foundation/ElegantRL/8ea76afc3e7f1564ae9f0e69e70254116d575fe9/docs/source/images/LunarLanderTwinDelay3.gif -------------------------------------------------------------------------------- /docs/source/images/bellman.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AI4Finance-Foundation/ElegantRL/8ea76afc3e7f1564ae9f0e69e70254116d575fe9/docs/source/images/bellman.png -------------------------------------------------------------------------------- /docs/source/images/efficiency.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AI4Finance-Foundation/ElegantRL/8ea76afc3e7f1564ae9f0e69e70254116d575fe9/docs/source/images/efficiency.png -------------------------------------------------------------------------------- /docs/source/images/envs.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AI4Finance-Foundation/ElegantRL/8ea76afc3e7f1564ae9f0e69e70254116d575fe9/docs/source/images/envs.png -------------------------------------------------------------------------------- /docs/source/images/fin.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AI4Finance-Foundation/ElegantRL/8ea76afc3e7f1564ae9f0e69e70254116d575fe9/docs/source/images/fin.png 
-------------------------------------------------------------------------------- /docs/source/images/framework.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AI4Finance-Foundation/ElegantRL/8ea76afc3e7f1564ae9f0e69e70254116d575fe9/docs/source/images/framework.png -------------------------------------------------------------------------------- /docs/source/images/framework2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AI4Finance-Foundation/ElegantRL/8ea76afc3e7f1564ae9f0e69e70254116d575fe9/docs/source/images/framework2.png -------------------------------------------------------------------------------- /docs/source/images/init.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /docs/source/images/isaacgym.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AI4Finance-Foundation/ElegantRL/8ea76afc3e7f1564ae9f0e69e70254116d575fe9/docs/source/images/isaacgym.gif -------------------------------------------------------------------------------- /docs/source/images/learning_curve.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AI4Finance-Foundation/ElegantRL/8ea76afc3e7f1564ae9f0e69e70254116d575fe9/docs/source/images/learning_curve.png -------------------------------------------------------------------------------- /docs/source/images/logo.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AI4Finance-Foundation/ElegantRL/8ea76afc3e7f1564ae9f0e69e70254116d575fe9/docs/source/images/logo.jpg -------------------------------------------------------------------------------- /docs/source/images/logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AI4Finance-Foundation/ElegantRL/8ea76afc3e7f1564ae9f0e69e70254116d575fe9/docs/source/images/logo.png -------------------------------------------------------------------------------- /docs/source/images/overview.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AI4Finance-Foundation/ElegantRL/8ea76afc3e7f1564ae9f0e69e70254116d575fe9/docs/source/images/overview.jpg -------------------------------------------------------------------------------- /docs/source/images/overview_1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AI4Finance-Foundation/ElegantRL/8ea76afc3e7f1564ae9f0e69e70254116d575fe9/docs/source/images/overview_1.png -------------------------------------------------------------------------------- /docs/source/images/overview_2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AI4Finance-Foundation/ElegantRL/8ea76afc3e7f1564ae9f0e69e70254116d575fe9/docs/source/images/overview_2.png -------------------------------------------------------------------------------- /docs/source/images/overview_3.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/AI4Finance-Foundation/ElegantRL/8ea76afc3e7f1564ae9f0e69e70254116d575fe9/docs/source/images/overview_3.png -------------------------------------------------------------------------------- /docs/source/images/overview_4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AI4Finance-Foundation/ElegantRL/8ea76afc3e7f1564ae9f0e69e70254116d575fe9/docs/source/images/overview_4.png -------------------------------------------------------------------------------- /docs/source/images/parallelism.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AI4Finance-Foundation/ElegantRL/8ea76afc3e7f1564ae9f0e69e70254116d575fe9/docs/source/images/parallelism.png -------------------------------------------------------------------------------- /docs/source/images/performance1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AI4Finance-Foundation/ElegantRL/8ea76afc3e7f1564ae9f0e69e70254116d575fe9/docs/source/images/performance1.png -------------------------------------------------------------------------------- /docs/source/images/performance2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AI4Finance-Foundation/ElegantRL/8ea76afc3e7f1564ae9f0e69e70254116d575fe9/docs/source/images/performance2.png -------------------------------------------------------------------------------- /docs/source/images/pseudo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AI4Finance-Foundation/ElegantRL/8ea76afc3e7f1564ae9f0e69e70254116d575fe9/docs/source/images/pseudo.png -------------------------------------------------------------------------------- /docs/source/images/reacher_v2_1.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AI4Finance-Foundation/ElegantRL/8ea76afc3e7f1564ae9f0e69e70254116d575fe9/docs/source/images/reacher_v2_1.gif -------------------------------------------------------------------------------- /docs/source/images/recursive.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AI4Finance-Foundation/ElegantRL/8ea76afc3e7f1564ae9f0e69e70254116d575fe9/docs/source/images/recursive.png -------------------------------------------------------------------------------- /docs/source/images/samples.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AI4Finance-Foundation/ElegantRL/8ea76afc3e7f1564ae9f0e69e70254116d575fe9/docs/source/images/samples.png -------------------------------------------------------------------------------- /docs/source/images/tab.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AI4Finance-Foundation/ElegantRL/8ea76afc3e7f1564ae9f0e69e70254116d575fe9/docs/source/images/tab.png -------------------------------------------------------------------------------- /docs/source/images/test1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AI4Finance-Foundation/ElegantRL/8ea76afc3e7f1564ae9f0e69e70254116d575fe9/docs/source/images/test1.png 
-------------------------------------------------------------------------------- /docs/source/images/test2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AI4Finance-Foundation/ElegantRL/8ea76afc3e7f1564ae9f0e69e70254116d575fe9/docs/source/images/test2.png -------------------------------------------------------------------------------- /docs/source/images/time.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AI4Finance-Foundation/ElegantRL/8ea76afc3e7f1564ae9f0e69e70254116d575fe9/docs/source/images/time.png -------------------------------------------------------------------------------- /docs/source/index.rst: -------------------------------------------------------------------------------- 1 | .. ElegantRL documentation master file, created by 2 | 3 | Welcome to ElegantRL! 4 | ===================================== 5 | 6 | .. image:: ./images/logo.png 7 | :width: 50% 8 | :align: center 9 | :target: https://github.com/AI4Finance-Foundation/ElegantRL 10 | 11 | 12 | 13 | `ElegantRL `_ is an open-source massively parallel library for deep reinforcement learning (DRL) algorithms, implemented in PyTorch. We aim to provide a *next-generation* framework that leverage recent techniques, e.g., massively parallel simulations, ensemble methods, population-based training, and showcase exciting scientific discoveries. 14 | 15 | ElegantRL features strong **scalability**, **elasticity** and **lightweightness**, and allows users to conduct **efficient** training on either one GPU or hundreds of GPUs: 16 | 17 | - **Scalability**: ElegantRL fully exploits the parallelism of DRL algorithms at multiple levels, making it easily scale out to hundreds or thousands of computing nodes on a cloud platform, say, a SuperPOD platform with thousands of GPUs. 18 | 19 | - **Elasticity**: ElegantRL can elastically allocate computing resources on the cloud, which helps adapt to available resources and prevents over/under-provisioning/under-provisioning. 20 | 21 | - **Lightweightness**: The core codes <1,000 lines (check `elegantrl_helloworld `_). 22 | 23 | - **Efficient**: in many testing cases, it is more efficient than `Ray RLlib `_. 24 | 25 | ElegantRL implements the following DRL algorithms: 26 | 27 | - **DDPG, TD3, SAC, A2C, PPO, REDQ for continuous actions** 28 | 29 | - **DQN, DoubleDQN, D3QN, PPO-Discrete for discrete actions** 30 | 31 | - **QMIX, VDN; MADDPG, MAPPO, MATD3 for multi-agent RL** 32 | 33 | 34 | For beginners, we maintain `ElegantRL-HelloWorld `_ as a tutorial. It is a lightweight version of ElegantRL with <1,000 lines of core codes. More details are available `here `_. 35 | 36 | Installation 37 | --------------------------------------- 38 | 39 | ElegantRL generally requires: 40 | 41 | - Python>=3.6 42 | 43 | - PyTorch>=1.0.2 44 | 45 | - gym, matplotlib, numpy, pybullet, torch, opencv-python, box2d-py. 46 | 47 | You can simply install ElegantRL from PyPI with the following command: 48 | 49 | .. code-block:: bash 50 | :linenos: 51 | 52 | pip3 install erl --upgrade 53 | 54 | Or install with the newest version through GitHub: 55 | 56 | .. code-block:: bash 57 | :linenos: 58 | 59 | git clone https://github.com/AI4Finance-Foundation/ElegantRL.git 60 | cd ElegantRL 61 | pip3 install . 62 | 63 | 64 | .. toctree:: 65 | :maxdepth: 1 66 | :hidden: 67 | 68 | Home 69 | 70 | .. 
toctree:: 71 | :maxdepth: 1 72 | :caption: HelloWorld 73 | 74 | helloworld/intro 75 | helloworld/net 76 | helloworld/agent 77 | helloworld/env 78 | helloworld/run 79 | helloworld/quickstart 80 | 81 | .. toctree:: 82 | :maxdepth: 1 83 | :caption: Overview 84 | 85 | about/overview 86 | about/cloud 87 | about/parallel 88 | 89 | 90 | .. toctree:: 91 | :maxdepth: 1 92 | :caption: Tutorials 93 | 94 | tutorial/LunarLanderContinuous-v2 95 | tutorial/BipedalWalker-v3 96 | tutorial/Creating_VecEnv 97 | tutorial/isaacgym 98 | tutorial/redq 99 | tutorial/hterm 100 | tutorial/finrl-podracer 101 | tutorial/elegantrl-podracer 102 | 103 | .. toctree:: 104 | :maxdepth: 1 105 | :caption: Algorithms 106 | 107 | algorithms/dqn 108 | algorithms/double_dqn 109 | algorithms/ddpg 110 | algorithms/td3 111 | algorithms/sac 112 | algorithms/a2c 113 | algorithms/ppo 114 | algorithms/redq 115 | algorithms/maddpg 116 | algorithms/matd3 117 | algorithms/qmix 118 | algorithms/vdn 119 | algorithms/mappo 120 | 121 | 122 | 123 | .. toctree:: 124 | :maxdepth: 1 125 | :caption: RLSolver 126 | 127 | RLSolver/overview 128 | RLSolver/helloworld 129 | RLSolver/datasets 130 | RLSolver/environments 131 | RLSolver/benchmarks 132 | 133 | 134 | .. toctree:: 135 | :maxdepth: 1 136 | :caption: API Reference 137 | 138 | api/config 139 | api/run 140 | api/worker 141 | api/learner 142 | api/replay 143 | api/evaluator 144 | api/utils 145 | 146 | 147 | .. toctree:: 148 | :maxdepth: 1 149 | :caption: Other 150 | 151 | other/faq 152 | 153 | 154 | Indices and tables 155 | ================== 156 | 157 | * :ref:`genindex` 158 | * :ref:`modindex` 159 | * :ref:`search` 160 | -------------------------------------------------------------------------------- /docs/source/tutorial/BipedalWalker-v3.rst: -------------------------------------------------------------------------------- 1 | Example 2: BipedalWalker-v3 2 | =============================== 3 | 4 | BipedalWalker-v3 is a classic task in robotics that performs a fundamental skill: moving forward as fast as possible. The goal is to get a 2D biped walker to walk through rough terrain. BipedalWalker is considered to be a difficult task in the continuous action space, and there are only a few RL implementations that can reach the target reward. Our Python code is available `here `_. 5 | 6 | When a biped walker takes random actions: 7 | 8 | .. image:: ../images/BipedalWalker-v3_1.gif 9 | :width: 80% 10 | :align: center 11 | 12 | 13 | Step 1: Install ElegantRL 14 | ------------------------------ 15 | 16 | .. code-block:: python 17 | 18 | pip install git+https://github.com/AI4Finance-LLC/ElegantRL.git 19 | 20 | Step 2: Import packages 21 | ------------------------------- 22 | 23 | - ElegantRL 24 | 25 | - OpenAI Gym: a toolkit for developing and comparing reinforcement learning algorithms (collections of environments). 26 | 27 | .. code-block:: python 28 | 29 | from elegantrl.run import * 30 | 31 | gym.logger.set_level(40) # Block warning 32 | 33 | Step 3: Get environment information 34 | -------------------------------------------------- 35 | 36 | .. code-block:: python 37 | 38 | get_gym_env_args(gym.make('BipedalWalker-v3'), if_print=False) 39 | 40 | 41 | Output: 42 | 43 | .. 
code-block:: python 44 | 45 | env_args = { 46 | 'env_num': 1, 47 | 'env_name': 'BipedalWalker-v3', 48 | 'max_step': 1600, 49 | 'state_dim': 24, 50 | 'action_dim': 4, 51 | 'if_discrete': False, 52 | 'target_return': 300, 53 | } 54 | 55 | 56 | Step 4: Initialize agent and environment 57 | --------------------------------------------- 58 | 59 | - agent: chooses a agent (DRL algorithm) from a set of agents in the `directory `_. 60 | 61 | - env_func: the function to create an environment, in this case, we use ``gym.make`` to create BipedalWalker-v3. 62 | 63 | - env_args: the environment information. 64 | 65 | .. code-block:: python 66 | 67 | env_func = gym.make 68 | env_args = { 69 | 'env_num': 1, 70 | 'env_name': 'BipedalWalker-v3', 71 | 'max_step': 1600, 72 | 'state_dim': 24, 73 | 'action_dim': 4, 74 | 'if_discrete': False, 75 | 'target_return': 300, 76 | 'id': 'BipedalWalker-v3', 77 | } 78 | 79 | args = Arguments(AgentPPO, env_func=env_func, env_args=env_args) 80 | 81 | Step 5: Specify hyper-parameters 82 | ---------------------------------------- 83 | 84 | A list of hyper-parameters is available `here `_. 85 | 86 | .. code-block:: python 87 | 88 | args.target_step = args.max_step * 4 89 | args.gamma = 0.98 90 | args.eval_times = 2 ** 4 91 | 92 | 93 | Step 6: Train your agent 94 | ---------------------------------------- 95 | 96 | In this tutorial, we provide four different modes to train an agent: 97 | 98 | - **Single-process**: utilize one GPU for a single-process training. No parallelism. 99 | 100 | - **Multi-process**: utilize one GPU for a multi-process training. Support worker and learner parallelism. 101 | 102 | - **Multi-GPU**: utilize multi-GPUs to train an agent through model fusion. Specify the GPU ids you want to use. 103 | 104 | - **Tournament-based ensemble training**: utilize multi-GPUs to run tournament-based ensemble training. 105 | 106 | 107 | .. code-block:: python 108 | 109 | flag = 'SingleProcess' 110 | 111 | if flag == 'SingleProcess': 112 | args.learner_gpus = 0 113 | train_and_evaluate(args) 114 | 115 | elif flag == 'MultiProcess': 116 | args.learner_gpus = 0 117 | train_and_evaluate_mp(args) 118 | 119 | elif flag == 'MultiGPU': 120 | args.learner_gpus = [0, 1, 2, 3] 121 | train_and_evaluate_mp(args) 122 | 123 | elif flag == 'Tournament-based': 124 | args.learner_gpus = [[i, ] for i in range(4)] # [[0,], [1, ], [2, ]] or [[0, 1], [2, 3]] 125 | python_path = '.../bin/python3' 126 | train_and_evaluate_mp(args, python_path) 127 | 128 | else: 129 | raise ValueError(f"Unknown flag: {flag}") 130 | 131 | 132 | Try by yourself through this `Colab `_! 133 | 134 | Performance of a trained agent: 135 | 136 | .. image:: ../images/BipedalWalker-v3_2.gif 137 | :width: 80% 138 | :align: center 139 | 140 | 141 | Check out our **video** on bilibili: `Crack the BipedalWalkerHardcore-v2 with total reward 310 using IntelAC `_. 142 | -------------------------------------------------------------------------------- /docs/source/tutorial/LunarLanderContinuous-v2.rst: -------------------------------------------------------------------------------- 1 | Example 1: LunarLanderContinuous-v2 2 | ======================================== 3 | 4 | LunarLanderContinuous-v2 is a robotic control task. The goal is to get a Lander to rest on the landing pad. If lander moves away from landing pad it loses reward back. Episode finishes if the lander crashes or comes to rest, receiving additional -100 or +100 points. Detailed description of the task can be found at `OpenAI Gym `_. 
Our Python code is available `here `_. 5 | 6 | 7 | When a Lander takes random actions: 8 | 9 | .. image:: ../images/LunarLander.gif 10 | :width: 80% 11 | :align: center 12 | 13 | 14 | Step 1: Install ElegantRL 15 | ------------------------------ 16 | 17 | .. code-block:: python 18 | 19 | pip install git+https://github.com/AI4Finance-LLC/ElegantRL.git 20 | 21 | Step 2: Import packages 22 | ------------------------------- 23 | 24 | - ElegantRL 25 | 26 | - OpenAI Gym: a toolkit for developing and comparing reinforcement learning algorithms (collections of environments). 27 | 28 | .. code-block:: python 29 | 30 | from elegantrl.run import * 31 | 32 | gym.logger.set_level(40) # Block warning 33 | 34 | Step 3: Get environment information 35 | -------------------------------------------------- 36 | 37 | .. code-block:: python 38 | 39 | get_gym_env_args(gym.make('LunarLanderContinuous-v2'), if_print=True) 40 | 41 | 42 | Output: 43 | 44 | .. code-block:: python 45 | 46 | env_args = { 47 | 'env_num': 1, 48 | 'env_name': 'LunarLanderContinuous-v2', 49 | 'max_step': 1000, 50 | 'state_dim': 8, 51 | 'action_dim': 2, 52 | 'if_discrete': False, 53 | 'target_return': 200, 54 | 'id': 'LunarLanderContinuous-v2' 55 | } 56 | 57 | 58 | Step 4: Initialize agent and environment 59 | --------------------------------------------- 60 | 61 | - agent: choose an agent (DRL algorithm) from the set of agents in the `directory `_. 62 | 63 | - env_func: the function to create an environment; in this case, we use ``gym.make`` to create LunarLanderContinuous-v2. 64 | 65 | - env_args: the environment information. 66 | 67 | .. code-block:: python 68 | 69 | env_func = gym.make 70 | env_args = { 71 | 'env_num': 1, 72 | 'env_name': 'LunarLanderContinuous-v2', 73 | 'max_step': 1000, 74 | 'state_dim': 8, 75 | 'action_dim': 2, 76 | 'if_discrete': False, 77 | 'target_return': 200, 78 | 'id': 'LunarLanderContinuous-v2' 79 | } 80 | 81 | args = Arguments(AgentModSAC, env_func=env_func, env_args=env_args) 82 | 83 | Step 5: Specify hyper-parameters 84 | ---------------------------------------- 85 | 86 | A list of hyper-parameters is available `here `_. 87 | 88 | .. code-block:: python 89 | 90 | args.target_step = args.max_step 91 | args.gamma = 0.99 92 | args.eval_times = 2 ** 5 93 | 94 | 95 | Step 6: Train your agent 96 | ---------------------------------------- 97 | 98 | In this tutorial, we provide a single-process demo to train an agent: 99 | 100 | .. code-block:: python 101 | 102 | train_and_evaluate(args) 103 | 104 | 105 | Try it yourself through this `Colab `_! 106 | 107 | 108 | Performance of a trained agent: 109 | 110 | .. image:: ../images/LunarLanderTwinDelay3.gif 111 | :width: 80% 112 | :align: center 113 | -------------------------------------------------------------------------------- /docs/source/tutorial/hterm.rst: -------------------------------------------------------------------------------- 1 | How to learn stably: H-term 2 | ====================================================== 3 | 4 | Stability plays a key role in productizing DRL applications for real-world problems, making it a central concern of DRL researchers and practitioners. Recently, many algorithms and open-source libraries have been developed to address this challenge. A popular open-source library `Stable-Baselines3 `_ offers a set of reliable implementations of DRL algorithms that match prior results.
5 | 6 | In this article, we introduce a **Hamiltonian-term (H-term)**, a generic add-on in ElegantRL that can be applied to existing model-free DRL algorithms. The H-term essentially trades computing power for stability. 7 | 8 | Basic Idea 9 | ----------------------------------------------- 10 | In a standard RL problem, a decision-making process can be modeled as a Markov Decision Process (MDP). The Bellman equation gives the optimality condition for MDP problems: 11 | 12 | .. image:: ../images/bellman.png 13 | :width: 80% 14 | :align: center 15 | 16 | The above equation is inherently recursive, so we expand it as follows: 17 | 18 | .. image:: ../images/recursive.png 19 | :width: 80% 20 | :align: center 21 | 22 | In practice, we aim to find a policy that maximizes the Q-value. By taking a variational approach, we can rewrite the Bellman equation into a Hamiltonian equation. Our goal then is transformed to find a policy that minimizes the energy of a system. (Check our `paper `_ for details). 23 | 24 | .. image:: ../images/H-term.png 25 | :width: 80% 26 | :align: center 27 | 28 | A Simple Add-on 29 | ----------------------------------------------- 30 | The derivations and physical interpretations may be a little bit scary, however, the actual implementation of the H-term is super simple. Here, we present the pseudocode and make a comparison (marked in red) to the Actor-Critic algorithms: 31 | 32 | .. image:: ../images/pseudo.png 33 | :width: 80% 34 | :align: center 35 | 36 | As marked out in lines 19–20, we include an additional update of the policy network, in order to minimize the H-term. Different from most algorithms that optimize on a single step (batch of transitions), we emphasize the importance of the sequential information from a trajectory (batch of trajectories). 37 | 38 | It is a fact that optimizing the H-term is compute-intensive, controlled by the hyper-parameter L (the number of selected trajectories) and K (the length of each trajectory). Fortunately, ElegantRL fully supports parallel computing from a single GPU to hundreds of GPUs, which provides the opportunity to trade computing power for stability. 39 | 40 | Example: Hopper-v2 41 | ----------------------------------------------- 42 | Currently, we have implemented the H-term into several widely-used DRL algorithms, PPO, SAC, TD3, and DDPG. Here, we present the performance on a benchmark problem `Hopper-v2 `_ using PPO algorithm. 43 | 44 | The implementations of PPO+H in `here `_ 45 | 46 | .. image:: ../images/samples.png 47 | :width: 80% 48 | :align: center 49 | 50 | .. image:: ../images/time.png 51 | :width: 80% 52 | :align: center 53 | 54 | In terms of variance, it is obvious that ElegantRL substantially outperforms Stable-Baseline3. The variance over 8 runs is much smaller. Also, the PPO+H in ElegantRL completed the training process of 5M samples in about 6x faster than Stable-Baseline3. 
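To make the structural change concrete, here is a small, self-contained PyTorch sketch. It is **not** the ElegantRL implementation and does not reproduce the exact Hamiltonian objective from the paper; it only mirrors the extra update marked in the pseudocode above: after the usual actor update, sample L trajectories of length K and run one additional policy update that uses trajectory-level (sequential) information. The policy class, the function name ``h_term_style_update``, and the return-weighted log-probability loss are all illustrative stand-ins.

.. code-block:: python

    import torch
    import torch.nn as nn

    class GaussianPolicy(nn.Module):
        """Tiny stand-in actor that maps states to a diagonal Gaussian over actions."""
        def __init__(self, state_dim, action_dim):
            super().__init__()
            self.net = nn.Sequential(nn.Linear(state_dim, 64), nn.ReLU(), nn.Linear(64, action_dim))
            self.log_std = nn.Parameter(torch.zeros(action_dim))

        def forward(self, state):
            return torch.distributions.Normal(self.net(state), self.log_std.exp())

    def h_term_style_update(actor, optimizer, states, actions, rewards,
                            num_trajs_l=64, traj_len_k=16, gamma=0.99):
        """One extra policy update over L sampled K-step trajectories (illustrative placeholder)."""
        n, t = rewards.shape
        idx = torch.randint(n, size=(num_trajs_l,))                  # L: number of trajectories
        start = torch.randint(t - traj_len_k + 1, size=(1,)).item()  # K: length of each window

        s = states[idx, start:start + traj_len_k]    # (L, K, state_dim)
        a = actions[idx, start:start + traj_len_k]   # (L, K, action_dim)
        r = rewards[idx, start:start + traj_len_k]   # (L, K)

        # Discounted return-to-go keeps the sequential (trajectory-level) information.
        ret = torch.zeros_like(r)
        running = torch.zeros(num_trajs_l)
        for k in reversed(range(traj_len_k)):
            running = r[:, k] + gamma * running
            ret[:, k] = running

        # Placeholder objective: return-weighted log-probability of the taken actions,
        # standing in for the Hamiltonian term minimized in lines 19-20 of the pseudocode.
        logprob = actor(s).log_prob(a).sum(dim=-1)   # (L, K)
        loss = -(ret.detach() * logprob).mean()

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        return loss.item()

    # toy usage with random data: 128 stored trajectories of 64 steps each
    actor = GaussianPolicy(state_dim=11, action_dim=3)
    optim = torch.optim.Adam(actor.parameters(), lr=1e-4)
    states, actions = torch.randn(128, 64, 11), torch.randn(128, 64, 3)
    rewards = torch.randn(128, 64)
    print(h_term_style_update(actor, optim, states, actions, rewards))

The two knobs ``num_trajs_l`` and ``traj_len_k`` correspond to the hyper-parameters L and K described above, which is where the extra computing power is spent.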
55 | -------------------------------------------------------------------------------- /docs/source/tutorial/redq.rst: -------------------------------------------------------------------------------- 1 | How to run learner parallelism: REDQ 2 | ================================================== -------------------------------------------------------------------------------- /elegantrl/__init__.py: -------------------------------------------------------------------------------- 1 | from .train.run import train_agent 2 | from .train.run import train_agent_single_process 3 | from .train.run import train_agent_multiprocessing 4 | from .train.run import train_agent_multiprocessing_multi_gpu 5 | 6 | from .train.config import Config 7 | from .train.config import get_gym_env_args 8 | -------------------------------------------------------------------------------- /elegantrl/agents/__init__.py: -------------------------------------------------------------------------------- 1 | from .AgentBase import AgentBase 2 | 3 | # DQN (off-policy) 4 | from .AgentDQN import AgentDQN, AgentDuelingDQN 5 | from .AgentDQN import AgentDoubleDQN, AgentD3QN 6 | from .AgentEmbedDQN import AgentEmbedDQN, AgentEnsembleDQN 7 | 8 | # off-policy 9 | from .AgentTD3 import AgentTD3, AgentDDPG 10 | from .AgentSAC import AgentSAC, AgentModSAC 11 | 12 | # on-policy 13 | from .AgentPPO import AgentPPO, AgentDiscretePPO 14 | from .AgentPPO import AgentA2C, AgentDiscreteA2C 15 | -------------------------------------------------------------------------------- /elegantrl/envs/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AI4Finance-Foundation/ElegantRL/8ea76afc3e7f1564ae9f0e69e70254116d575fe9/elegantrl/envs/__init__.py -------------------------------------------------------------------------------- /elegantrl/train/__init__.py: -------------------------------------------------------------------------------- 1 | from .run import train_agent 2 | from .run import train_agent_single_process 3 | from .run import train_agent_multiprocessing 4 | from .run import train_agent_multiprocessing_multi_gpu 5 | 6 | from .config import build_env, get_gym_env_args 7 | from .config import Config 8 | from .evaluator import Evaluator 9 | from .replay_buffer import ReplayBuffer 10 | -------------------------------------------------------------------------------- /examples/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AI4Finance-Foundation/ElegantRL/8ea76afc3e7f1564ae9f0e69e70254116d575fe9/examples/__init__.py -------------------------------------------------------------------------------- /examples/list_gym_envs.py: -------------------------------------------------------------------------------- 1 | """ 2 | This script lists out all OpenAI gym environments that can be tested on. Some of them 3 | (Ant, Hopper, etc.) require additional external dependencies (mujoco_py, etc.). 
4 | """ 5 | 6 | from gym import envs 7 | 8 | all_envs = envs.registry.all() 9 | env_ids = [env_spec.id for env_spec in all_envs] 10 | 11 | for env_id in env_ids: 12 | print(env_id) 13 | -------------------------------------------------------------------------------- /examples/plan_BipedalWalker-v3.py: -------------------------------------------------------------------------------- 1 | import gym 2 | from elegantrl.agents import AgentPPO 3 | from elegantrl.train.config import get_gym_env_args, Arguments 4 | from elegantrl.train.run import * 5 | 6 | gym.logger.set_level(40) # Block warning 7 | 8 | get_gym_env_args(gym.make("BipedalWalker-v3"), if_print=True) 9 | 10 | env_func = gym.make 11 | env_args = { 12 | "num_envs": 1, 13 | "env_name": "BipedalWalker-v3", 14 | "max_step": 1600, 15 | "state_dim": 24, 16 | "action_dim": 4, 17 | "if_discrete": False, 18 | "target_return": 300, 19 | "id": "BipedalWalker-v3", 20 | } 21 | args = Arguments(AgentPPO, env_func=env_func, env_args=env_args) 22 | 23 | args.target_step = args.max_step * 4 24 | args.gamma = 0.98 25 | args.eval_times = 2 ** 4 26 | 27 | if __name__ == '__main__': 28 | flag = "SingleProcess" 29 | 30 | if flag == "SingleProcess": 31 | args.learner_gpu_ids = 0 32 | train_and_evaluate(args) 33 | elif flag == "MultiProcess": 34 | args.learner_gpu_ids = 0 35 | train_and_evaluate_mp(args) 36 | elif flag == "MultiGPU": 37 | args.learner_gpu_ids = [0, 1, 2, 3] 38 | train_and_evaluate_mp(args) 39 | elif flag == "Tournament-based": 40 | args.learner_gpu_ids = [ 41 | [i, ] for i in range(4) 42 | ] # [[0,], [1, ], [2, ]] or [[0, 1], [2, 3]] 43 | python_path = "../bin/python3" 44 | train_and_evaluate_mp(args, python_path) # multiple processing 45 | else: 46 | raise ValueError(f"Unknown flag: {flag}") 47 | -------------------------------------------------------------------------------- /examples/plan_Hopper-v2_H.py: -------------------------------------------------------------------------------- 1 | import sys 2 | 3 | from elegantrl.train.demo import * 4 | 5 | 6 | def demo_continuous_action_on_policy(): 7 | gpu_id = ( 8 | int(sys.argv[1]) if len(sys.argv) > 1 else 0 9 | ) # >=0 means GPU ID, -1 means CPU 10 | drl_id = 1 # int(sys.argv[2]) 11 | env_id = 4 # int(sys.argv[3]) 12 | 13 | env_name = "Hopper-v2" 14 | agent = AgentPPO_H 15 | 16 | print("agent", agent.__name__) 17 | print("gpu_id", gpu_id) 18 | print("env_name", env_name) 19 | 20 | env_func = gym.make 21 | env_args = { 22 | "num_envs": 1, 23 | "env_name": "Hopper-v2", 24 | "max_step": 1000, 25 | "state_dim": 11, 26 | "action_dim": 3, 27 | "if_discrete": False, 28 | "target_return": 3800.0, 29 | } 30 | args = Arguments(agent, env_func=env_func, env_args=env_args) 31 | args.eval_times = 2**1 32 | args.reward_scale = 2**-4 33 | 34 | args.target_step = args.max_step * 4 # 6 35 | args.worker_num = 2 36 | 37 | args.net_dim = 2**7 38 | args.layer_num = 3 39 | args.batch_size = int(args.net_dim * 2) 40 | args.repeat_times = 2**4 41 | args.ratio_clip = 0.25 42 | args.gamma = 0.993 43 | args.lambda_entropy = 0.02 44 | args.lambda_h_term = 2**-5 45 | 46 | args.if_allow_break = False 47 | args.break_step = int(8e6) 48 | 49 | args.learner_gpu_ids = gpu_id 50 | args.random_seed += gpu_id 51 | 52 | if_check = 0 53 | if if_check: 54 | train_and_evaluate(args) 55 | else: 56 | train_and_evaluate_mp(args) 57 | 58 | 59 | if __name__ == "__main__": 60 | demo_continuous_action_on_policy() 61 | -------------------------------------------------------------------------------- /examples/plan_Isaac_Gym.py: 
-------------------------------------------------------------------------------- 1 | import isaacgym 2 | import torch 3 | import sys 4 | import wandb 5 | 6 | from elegantrl.train.run import train_and_evaluate 7 | from elegantrl.train.config import Arguments, build_env 8 | from elegantrl.agents.AgentPPO import AgentPPO 9 | from elegantrl.envs.IsaacGym import IsaacVecEnv, IsaacOneEnv 10 | 11 | 12 | def demo(seed, config): 13 | agent_class = AgentPPO 14 | env_func = IsaacVecEnv 15 | gpu_id = 0 16 | 17 | env_args = { 18 | 'num_envs': config['num_envs'], 19 | 'env_name': config['env_name'], 20 | 'max_step': config['max_step'], 21 | 'state_dim': config['state_dim'], 22 | 'action_dim': config['action_dim'], 23 | 'if_discrete': False, 24 | 'target_return': 10000., 25 | 'sim_device_id': gpu_id, 26 | 'rl_device_id': gpu_id, 27 | } 28 | env = build_env(env_func=env_func, env_args=env_args) 29 | args = Arguments(agent_class, env=env) 30 | args.if_Isaac = True 31 | args.if_use_old_traj = True 32 | args.if_use_gae = True 33 | args.obs_norm = True 34 | args.value_norm = False 35 | 36 | args.reward_scale = config['reward_scale'] 37 | args.horizon_len = config['horizon_len'] 38 | args.batch_size = config['batch_size'] 39 | args.repeat_times = 5 40 | args.gamma = 0.99 41 | args.lambda_gae_adv = 0.95 42 | args.learning_rate = 5e-4 43 | args.lambda_entropy = 0.0 44 | 45 | args.eval_gap = 1e6 46 | args.learner_gpu_ids = gpu_id 47 | args.random_seed = seed 48 | args.cwd = f'./result/{args.env_name}_{args.agent_class.__name__[5:]}_{args.num_envs}envs/{args.random_seed}' 49 | 50 | train_and_evaluate(args) 51 | 52 | 53 | if __name__ == '__main__': 54 | seed = int(sys.argv[1]) if len(sys.argv) > 1 else 0 55 | config = { 56 | 'env_name': 'Ant', 57 | 'num_envs': 2048, 58 | 'state_dim': 60, 59 | 'action_dim': 8, 60 | 'max_step': 1000, 61 | 'reward_scale': 0.01, 62 | 'horizon_len': 32, 63 | 'batch_size': 16384, 64 | } 65 | # config = { 66 | # 'env_name': 'Humanoid', 67 | # 'num_envs': 2048, 68 | # 'state_dim': 108, 69 | # 'action_dim': 21, 70 | # 'max_step': 1000, 71 | # 'reward_scale': 0.01, 72 | # 'horizon_len': 32, 73 | # 'batch_size': 16384, 74 | # } 75 | # config = { 76 | # 'env_name': 'ShadowHand', 77 | # 'num_envs': 16384, 78 | # 'state_dim': 211, 79 | # 'action_dim': 20, 80 | # 'max_step': 600, 81 | # 'reward_scale': 0.01, 82 | # 'horizon_len': 8, 83 | # 'batch_size': 32768, 84 | # } 85 | # config = { 86 | # 'env_name': 'Anymal', 87 | # 'num_envs': 4096, 88 | # 'state_dim': 48, 89 | # 'action_dim': 12, 90 | # 'max_step': 2500, 91 | # 'reward_scale': 1, 92 | # 'horizon_len': 32, 93 | # 'batch_size': 16384, 94 | # } 95 | # config = { 96 | # 'env_name': 'Ingenuity', 97 | # 'num_envs': 4096, 98 | # 'state_dim': 13, 99 | # 'action_dim': 6, 100 | # 'max_step': 2000, 101 | # 'reward_scale': 1, 102 | # 'horizon_len': 16, 103 | # 'batch_size': 16384, 104 | # } 105 | cwd = config['env_name'] + '_PPO_' + str(seed) 106 | wandb.init( 107 | project=config['env_name'] + '_PPO_' + str(config['num_envs']), 108 | entity=None, 109 | sync_tensorboard=True, 110 | config=config, 111 | name=cwd, 112 | monitor_gym=True, 113 | save_code=True, 114 | ) 115 | config = wandb.config 116 | demo(seed, config) 117 | -------------------------------------------------------------------------------- /examples/plan_PaperTradingEnv_PPO.py: -------------------------------------------------------------------------------- 1 | """ 2 | https://github.com/AI4Finance-Foundation/FinRL-Meta/blob/master/examples/FinRL_PaperTrading_Demo.ipynb 3 | """ 4 | 5 | 
"""Part I""" 6 | 7 | API_KEY = "PKAVSDVA8AIK4YBOOL3S" 8 | API_SECRET = "U6TKEjt9C77Dw21ca8zVGUhsZxTUohaLYdmOrO3L" 9 | API_BASE_URL = 'https://paper-api.alpaca.markets' 10 | data_url = 'wss://data.alpaca.markets' 11 | 12 | from finrl.config_tickers import DOW_30_TICKER 13 | from finrl.config import INDICATORS 14 | from finrl.meta.env_stock_trading.env_stocktrading_np import StockTradingEnv 15 | from finrl.meta.env_stock_trading.env_stock_papertrading import AlpacaPaperTrading 16 | from finrl.meta.data_processor import DataProcessor 17 | from finrl.plot import backtest_stats, backtest_plot, get_daily_return, get_baseline 18 | 19 | import numpy as np 20 | import pandas as pd 21 | 22 | 23 | def train( 24 | start_date, 25 | end_date, 26 | ticker_list, 27 | data_source, 28 | time_interval, 29 | technical_indicator_list, 30 | drl_lib, 31 | env, 32 | model_name, 33 | if_vix=True, 34 | **kwargs, 35 | ): 36 | # download data 37 | dp = DataProcessor(data_source, **kwargs) 38 | data = dp.download_data(ticker_list, start_date, end_date, time_interval) 39 | data = dp.clean_data(data) 40 | data = dp.add_technical_indicator(data, technical_indicator_list) 41 | if if_vix: 42 | data = dp.add_vix(data) 43 | else: 44 | data = dp.add_turbulence(data) 45 | price_array, tech_array, turbulence_array = dp.df_to_array(data, if_vix) 46 | 47 | np.save('price_array.npy', price_array) 48 | np.save('tech_array.npy', tech_array) 49 | np.save('turbulence_array.npy', turbulence_array) 50 | print("| save in '.'") 51 | 52 | 53 | def run(): 54 | ticker_list = DOW_30_TICKER 55 | env = StockTradingEnv 56 | erl_params = {"learning_rate": 3e-6, "batch_size": 2048, "gamma": 0.985, 57 | "seed": 312, "net_dimension": [128, 64], "target_step": 5000, "eval_gap": 30, 58 | "eval_times": 1} 59 | 60 | train(start_date='2022-08-25', 61 | end_date='2022-08-31', 62 | ticker_list=ticker_list, 63 | data_source='alpaca', 64 | time_interval='1Min', 65 | technical_indicator_list=INDICATORS, 66 | drl_lib='elegantrl', 67 | env=env, 68 | model_name='ppo', 69 | if_vix=True, 70 | API_KEY=API_KEY, 71 | API_SECRET=API_SECRET, 72 | API_BASE_URL=API_BASE_URL, 73 | erl_params=erl_params, 74 | cwd='./papertrading_erl', # current_working_dir 75 | break_step=1e5) 76 | 77 | 78 | if __name__ == '__main__': 79 | run() 80 | 81 | 82 | """ 83 | (base) develop@rlsmartagent-dev:~/workspace/ElegantRL0101/examples$ pip install git+https://github.com/AI4Finance-Foundation/FinRL.git 84 | Defaulting to user installation because normal site-packages is not writeable 85 | Looking in indexes: https://pypi.tuna.tsinghua.edu.cn/simple 86 | Collecting git+https://github.com/AI4Finance-Foundation/FinRL.git 87 | Cloning https://github.com/AI4Finance-Foundation/FinRL.git to /tmp/pip-req-build-er0zbi_n 88 | Running command git clone -q https://github.com/AI4Finance-Foundation/FinRL.git /tmp/pip-req-build-er0zbi_n 89 | 90 | 91 | 92 | fatal: unable to access 'https://github.com/AI4Finance-Foundation/FinRL.git/': GnuTLS recv error (-110): The TLS connection was non-properly terminated. 93 | WARNING: Discarding git+https://github.com/AI4Finance-Foundation/FinRL.git. Command errored out with exit status 128: git clone -q https://github.com/AI4Finance-Foundation/FinRL.git /tmp/pip-req-build-er0zbi_n Check the logs for full command output. 94 | ERROR: Command errored out with exit status 128: git clone -q https://github.com/AI4Finance-Foundation/FinRL.git /tmp/pip-req-build-er0zbi_n Check the logs for full command output. 
95 | """ -------------------------------------------------------------------------------- /examples/tutorial_Hopper-v3.py: -------------------------------------------------------------------------------- 1 | import gym 2 | from elegantrl.agents import AgentPPO 3 | from elegantrl.train.config import get_gym_env_args, Arguments 4 | from elegantrl.train.run import * 5 | 6 | # set environment name here (e.g. 'Hopper-v3', 'LunarLanderContinuous-v2', 7 | # 'BipedalWalker-v3') 8 | env_name = "Hopper-v3" 9 | 10 | # retrieve appropriate training arguments for this environment 11 | env_args = get_gym_env_args(gym.make(env_name), if_print=False) 12 | args = Arguments(AgentPPO, env_func=gym.make, env_args=env_args) 13 | 14 | # set/modify any arguments you'd like to here 15 | args.batch_size = 2**16 16 | args.eval_times = 2**4 17 | args.max_memo = 2**16 18 | args.target_step = 2**16 19 | 20 | # print out arguments in an easy-to-read format to show you what you're about to 21 | # train... 22 | args.print() 23 | 24 | # ...and go! 25 | train_and_evaluate(args) 26 | -------------------------------------------------------------------------------- /examples/tutorial_LunarLanderContinous-v2.py: -------------------------------------------------------------------------------- 1 | import gym 2 | from elegantrl.agents import AgentModSAC 3 | from elegantrl.train.config import get_gym_env_args, Arguments 4 | from elegantrl.train.run import * 5 | 6 | gym.logger.set_level(40) # Block warning 7 | 8 | get_gym_env_args(gym.make("LunarLanderContinuous-v2"), if_print=False) 9 | 10 | env_func = gym.make 11 | env_args = { 12 | "num_envs": 1, 13 | "env_name": "LunarLanderContinuous-v2", 14 | "max_step": 1000, 15 | "state_dim": 8, 16 | "action_dim": 2, 17 | "if_discrete": False, 18 | "target_return": 200, 19 | "id": "LunarLanderContinuous-v2", 20 | } 21 | args = Arguments(AgentModSAC, env_func=env_func, env_args=env_args) 22 | 23 | args.target_step = args.max_step 24 | args.gamma = 0.99 25 | args.eval_times = 2**5 26 | args.random_seed = 2022 27 | 28 | train_and_evaluate(args) 29 | -------------------------------------------------------------------------------- /figs/BipdealWalkerHardCore_313score.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AI4Finance-Foundation/ElegantRL/8ea76afc3e7f1564ae9f0e69e70254116d575fe9/figs/BipdealWalkerHardCore_313score.png -------------------------------------------------------------------------------- /figs/BipedalWalkerHardcore-v2-total-668kb.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AI4Finance-Foundation/ElegantRL/8ea76afc3e7f1564ae9f0e69e70254116d575fe9/figs/BipedalWalkerHardcore-v2-total-668kb.gif -------------------------------------------------------------------------------- /figs/ElegantRL.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AI4Finance-Foundation/ElegantRL/8ea76afc3e7f1564ae9f0e69e70254116d575fe9/figs/ElegantRL.png -------------------------------------------------------------------------------- /figs/File_structure.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AI4Finance-Foundation/ElegantRL/8ea76afc3e7f1564ae9f0e69e70254116d575fe9/figs/File_structure.png -------------------------------------------------------------------------------- /figs/LunarLanderTwinDelay3.gif: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/AI4Finance-Foundation/ElegantRL/8ea76afc3e7f1564ae9f0e69e70254116d575fe9/figs/LunarLanderTwinDelay3.gif -------------------------------------------------------------------------------- /figs/RL_survey_2020.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AI4Finance-Foundation/ElegantRL/8ea76afc3e7f1564ae9f0e69e70254116d575fe9/figs/RL_survey_2020.pdf -------------------------------------------------------------------------------- /figs/RL_survey_2020.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AI4Finance-Foundation/ElegantRL/8ea76afc3e7f1564ae9f0e69e70254116d575fe9/figs/RL_survey_2020.png -------------------------------------------------------------------------------- /figs/SB3_vs_ElegantRL.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AI4Finance-Foundation/ElegantRL/8ea76afc3e7f1564ae9f0e69e70254116d575fe9/figs/SB3_vs_ElegantRL.png -------------------------------------------------------------------------------- /figs/envs.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AI4Finance-Foundation/ElegantRL/8ea76afc3e7f1564ae9f0e69e70254116d575fe9/figs/envs.png -------------------------------------------------------------------------------- /figs/icon.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AI4Finance-Foundation/ElegantRL/8ea76afc3e7f1564ae9f0e69e70254116d575fe9/figs/icon.jpg -------------------------------------------------------------------------------- /figs/original.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AI4Finance-Foundation/ElegantRL/8ea76afc3e7f1564ae9f0e69e70254116d575fe9/figs/original.gif -------------------------------------------------------------------------------- /figs/performance.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AI4Finance-Foundation/ElegantRL/8ea76afc3e7f1564ae9f0e69e70254116d575fe9/figs/performance.png -------------------------------------------------------------------------------- /figs/performance1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AI4Finance-Foundation/ElegantRL/8ea76afc3e7f1564ae9f0e69e70254116d575fe9/figs/performance1.png -------------------------------------------------------------------------------- /figs/performance2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AI4Finance-Foundation/ElegantRL/8ea76afc3e7f1564ae9f0e69e70254116d575fe9/figs/performance2.png -------------------------------------------------------------------------------- /helloworld/erl_env.py: -------------------------------------------------------------------------------- 1 | from typing import Tuple 2 | 3 | import numpy as np 4 | import gymnasium as gym 5 | 6 | ARY = np.ndarray 7 | 8 | 9 | class PendulumEnv(gym.Wrapper): # a demo of custom env 10 | def __init__(self): 11 | gym_env_name = 'Pendulum-v1' 12 | super().__init__(env=gym.make(gym_env_name)) 13 | 14 | '''the necessary env information when you 
design a custom env''' 15 | self.env_name = gym_env_name # the name of this env. 16 | self.state_dim = self.observation_space.shape[0] # feature number of state 17 | self.action_dim = self.action_space.shape[0] # feature number of action 18 | self.if_discrete = False # discrete action or continuous action 19 | 20 | def reset(self, **kwargs) -> Tuple[ARY, dict]: # reset the agent in env 21 | state, info_dict = self.env.reset() 22 | return state, info_dict 23 | 24 | def step(self, action: ARY) -> Tuple[ARY, float, bool, bool, dict]: # agent interacts in env 25 | # OpenAI Pendulum env set its action space as (-2, +2). It is bad. 26 | # We suggest that adjust action space to (-1, +1) when designing a custom env. 27 | state, reward, terminated, truncated, info_dict = self.env.step(action * 2) 28 | state = state.reshape(self.state_dim) 29 | return state, float(reward) * 0.5, terminated, truncated, info_dict 30 | -------------------------------------------------------------------------------- /helloworld/erl_tutorial_DDPG.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | 4 | from erl_config import Config, get_gym_env_args 5 | from erl_agent import AgentDDPG 6 | from erl_run import train_agent, valid_agent 7 | from erl_env import PendulumEnv 8 | 9 | 10 | def train_ddpg_for_pendulum(gpu_id=0): 11 | agent_class = AgentDDPG # DRL algorithm 12 | env_class = PendulumEnv # run a custom env: PendulumEnv, which based on OpenAI pendulum 13 | env_args = { 14 | 'env_name': 'Pendulum', # Apply torque on the free end to swing a pendulum into an upright position 15 | # Reward: r = -(theta + 0.1 * theta_dt + 0.001 * torque) 16 | 17 | 'state_dim': 3, # the x-y coordinates of the pendulum's free end and its angular velocity. 
18 | 'action_dim': 1, # the torque applied to free end of the pendulum 19 | 'if_discrete': False # continuous action space, symbols → direction, value → force 20 | } 21 | get_gym_env_args(env=PendulumEnv(), if_print=True) # return env_args 22 | 23 | args = Config(agent_class, env_class, env_args) # see `erl_config.py Arguments()` for hyperparameter explanation 24 | args.break_step = int(1e5) # break training if 'total_step > break_step' 25 | args.net_dims = [64, 32] # the middle layer dimension of MultiLayer Perceptron 26 | args.gamma = 0.97 # discount factor of future rewards 27 | 28 | args.gpu_id = gpu_id # the ID of single GPU, -1 means CPU 29 | train_agent(args) 30 | if input("| Press 'y' to load actor.pth and render:") == 'y': 31 | actor_name = sorted([s for s in os.listdir(args.cwd) if s[-4:] == '.pth'])[-1] 32 | actor_path = f"{args.cwd}/{actor_name}" 33 | valid_agent(env_class, env_args, args.net_dims, agent_class, actor_path) 34 | 35 | 36 | if __name__ == "__main__": 37 | GPU_ID = int(sys.argv[1]) if len(sys.argv) > 1 else 0 38 | train_ddpg_for_pendulum(gpu_id=GPU_ID) 39 | -------------------------------------------------------------------------------- /helloworld/erl_tutorial_DQN.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import gymnasium as gym 4 | 5 | from erl_config import Config, get_gym_env_args 6 | from erl_agent import AgentDQN 7 | from erl_run import train_agent, valid_agent 8 | 9 | gym.logger.set_level(40) # Block warning 10 | 11 | 12 | def train_dqn_for_cartpole(gpu_id=0): 13 | agent_class = AgentDQN # DRL algorithm 14 | env_class = gym.make 15 | env_args = { 16 | 'env_name': 'CartPole-v0', # A pole is attached by an un-actuated joint to a cart. 17 | # Reward: keep the pole upright, a reward of `+1` for every step taken 18 | 19 | 'state_dim': 4, # (CartPosition, CartVelocity, PoleAngle, PoleAngleVelocity) 20 | 'action_dim': 2, # (Push cart to the left, Push cart to the right) 21 | 'if_discrete': True, # discrete action space 22 | } 23 | get_gym_env_args(env=gym.make('CartPole-v0'), if_print=True) # return env_args 24 | 25 | args = Config(agent_class, env_class, env_args) # see `erl_config.py Arguments()` for hyperparameter explanation 26 | args.break_step = int(1e5) # break training if 'total_step > break_step' 27 | args.net_dims = [64, 32] # the middle layer dimension of MultiLayer Perceptron 28 | args.gamma = 0.95 # discount factor of future rewards 29 | 30 | args.gpu_id = gpu_id # the ID of single GPU, -1 means CPU 31 | train_agent(args) 32 | if input("| Press 'y' to load actor.pth and render:") == 'y': 33 | actor_name = sorted([s for s in os.listdir(args.cwd) if s[-4:] == '.pth'])[-1] 34 | actor_path = f"{args.cwd}/{actor_name}" 35 | valid_agent(env_class, env_args, args.net_dims, agent_class, actor_path) 36 | 37 | 38 | def train_dqn_for_lunar_lander(gpu_id=0): 39 | agent_class = AgentDQN # DRL algorithm 40 | env_class = gym.make 41 | env_args = { 42 | 'env_name': 'LunarLander-v2', # A lander learns to land on a landing pad and using as little fuel as possible 43 | # Reward: Lander moves to the landing pad and come rest +100; lander crashes -100. 44 | # Reward: Lander moves to landing pad get positive reward, move away gets negative reward. 45 | # Reward: Firing the main engine -0.3, side engine -0.03 each frame. 
46 | 47 | 'state_dim': 8, # coordinates xy, linear velocities xy, angle, angular velocity, two booleans 48 | 'action_dim': 4, # do nothing, fire left engine, fire main engine, fire right engine. 49 | 'if_discrete': True # discrete action space 50 | } 51 | get_gym_env_args(env=gym.make('LunarLander-v2'), if_print=True) # return env_args 52 | 53 | args = Config(agent_class, env_class, env_args) # see `erl_config.py Arguments()` for hyperparameter explanation 54 | args.break_step = int(4e5) # break training if 'total_step > break_step' 55 | args.explore_rate = 0.1 # the probability of choosing action randomly in epsilon-greedy 56 | args.net_dims = [128, 64] # the middle layer dimension of Fully Connected Network 57 | 58 | args.gpu_id = gpu_id # the ID of single GPU, -1 means CPU 59 | train_agent(args) 60 | if input("| Press 'y' to load actor.pth and render:") == 'y': 61 | actor_name = sorted([s for s in os.listdir(args.cwd) if s[-4:] == '.pth'])[-1] 62 | actor_path = f"{args.cwd}/{actor_name}" 63 | valid_agent(env_class, env_args, args.net_dims, agent_class, actor_path) 64 | 65 | 66 | if __name__ == "__main__": 67 | GPU_ID = int(sys.argv[1]) if len(sys.argv) > 1 else 0 68 | train_dqn_for_cartpole(gpu_id=GPU_ID) 69 | """ 70 | | Arguments Remove cwd: ./CartPole-v1_DQN_0 71 | | Evaluator: 72 | | `step`: Number of samples, or total training steps, or running times of `env.step()`. 73 | | `time`: Time spent from the start of training to this moment. 74 | | `avgR`: Average value of cumulative rewards, which is the sum of rewards in an episode. 75 | | `stdR`: Standard dev of cumulative rewards, which is the sum of rewards in an episode. 76 | | `avgS`: Average of steps in an episode. 77 | | `objC`: Objective of Critic network. Or call it loss function of critic network. 78 | | `objA`: Objective of Actor network. It is the average Q value of the critic network. 79 | | step time | avgR stdR avgS | objC objA 80 | | 1.02e+04 19 | 17.31 2.11 17 | 0.92 19.77 81 | | 2.05e+04 39 | 9.47 0.71 9 | 0.93 23.96 82 | | 3.07e+04 66 | 191.25 18.10 191 | 1.38 31.52 83 | | 4.10e+04 98 | 212.41 16.34 212 | 0.65 21.52 84 | | 5.12e+04 141 | 183.41 10.96 183 | 0.47 21.10 85 | | 6.14e+04 184 | 171.94 8.44 172 | 0.34 20.48 86 | | 7.17e+04 233 | 173.00 8.85 173 | 0.27 19.88 87 | | 8.19e+04 290 | 115.84 3.61 116 | 0.24 19.95 88 | | 9.22e+04 349 | 128.44 5.99 128 | 0.19 19.80 89 | """ 90 | 91 | train_dqn_for_lunar_lander(gpu_id=GPU_ID) 92 | -------------------------------------------------------------------------------- /helloworld/erl_tutorial_PPO.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import gymnasium as gym 4 | 5 | from erl_config import Config, get_gym_env_args 6 | from erl_agent import AgentPPO 7 | from erl_run import train_agent, valid_agent 8 | from erl_env import PendulumEnv 9 | 10 | 11 | def train_ppo_for_pendulum(gpu_id=0): 12 | agent_class = AgentPPO # DRL algorithm name 13 | env_class = PendulumEnv # run a custom env: PendulumEnv, which based on OpenAI pendulum 14 | env_args = { 15 | 'env_name': 'Pendulum', # Apply torque on the free end to swing a pendulum into an upright position 16 | # Reward: r = -(theta + 0.1 * theta_dt + 0.001 * torque) 17 | 18 | 'state_dim': 3, # the x-y coordinates of the pendulum's free end and its angular velocity. 
19 | 'action_dim': 1, # the torque applied to free end of the pendulum 20 | 'if_discrete': False # continuous action space, symbols → direction, value → force 21 | } 22 | get_gym_env_args(env=PendulumEnv(), if_print=True) # return env_args 23 | 24 | args = Config(agent_class, env_class, env_args) # see `erl_config.py Arguments()` for hyperparameter explanation 25 | args.break_step = int(2e5) # break training if 'total_step > break_step' 26 | args.net_dims = [64, 32] # the middle layer dimension of MultiLayer Perceptron 27 | args.gamma = 0.97 # discount factor of future rewards 28 | args.repeat_times = 16 # repeatedly update network using ReplayBuffer to keep critic's loss small 29 | 30 | args.gpu_id = gpu_id # the ID of single GPU, -1 means CPU 31 | train_agent(args) 32 | if input("| Press 'y' to load actor.pth and render:") == 'y': 33 | actor_name = sorted([s for s in os.listdir(args.cwd) if s[-4:] == '.pth'])[-1] 34 | actor_path = f"{args.cwd}/{actor_name}" 35 | valid_agent(env_class, env_args, args.net_dims, agent_class, actor_path) 36 | 37 | 38 | def train_ppo_for_lunar_lander(gpu_id=0): 39 | agent_class = AgentPPO # DRL algorithm name 40 | env_class = gym.make 41 | env_args = { 42 | 'env_name': 'LunarLanderContinuous-v2', # A lander learns to land on a landing pad 43 | # Reward: Lander moves to the landing pad and come rest +100; lander crashes -100. 44 | # Reward: Lander moves to landing pad get positive reward, move away gets negative reward. 45 | # Reward: Firing the main engine -0.3, side engine -0.03 each frame. 46 | 47 | 'state_dim': 8, # coordinates xy, linear velocities xy, angle, angular velocity, two booleans 48 | 'action_dim': 2, # fire main engine or side engine. 49 | 'if_discrete': False # continuous action space, symbols → direction, value → force 50 | } 51 | get_gym_env_args(env=gym.make('LunarLanderContinuous-v2'), if_print=True) # return env_args 52 | 53 | args = Config(agent_class, env_class, env_args) # see `erl_config.py Arguments()` for hyperparameter explanation 54 | args.break_step = int(4e5) # break training if 'total_step > break_step' 55 | args.net_dims = [64, 32] # the middle layer dimension of MultiLayer Perceptron 56 | args.repeat_times = 32 # repeatedly update network using ReplayBuffer to keep critic's loss small 57 | args.lambda_entropy = 0.04 # the lambda of the policy entropy term in PPO 58 | args.gamma = 0.98 59 | 60 | args.gpu_id = gpu_id # the ID of single GPU, -1 means CPU 61 | train_agent(args) 62 | if input("| Press 'y' to load actor.pth and render:") == 'y': 63 | actor_name = sorted([s for s in os.listdir(args.cwd) if s[-4:] == '.pth'])[-1] 64 | actor_path = f"{args.cwd}/{actor_name}" 65 | valid_agent(env_class, env_args, args.net_dims, agent_class, actor_path) 66 | 67 | 68 | if __name__ == "__main__": 69 | GPU_ID = int(sys.argv[1]) if len(sys.argv) > 1 else 0 70 | train_ppo_for_pendulum(gpu_id=GPU_ID) 71 | train_ppo_for_lunar_lander(gpu_id=GPU_ID) 72 | -------------------------------------------------------------------------------- /helloworld/unit_tests/check_env.py: -------------------------------------------------------------------------------- 1 | from env import * 2 | 3 | 4 | def check_pendulum_env(): 5 | env = PendulumEnv() 6 | assert isinstance(env.env_name, str) 7 | assert isinstance(env.state_dim, int) 8 | assert isinstance(env.action_dim, int) 9 | assert isinstance(env.if_discrete, bool) 10 | 11 | state = env.reset() 12 | assert state.shape == (env.state_dim,) 13 | 14 | action = np.random.uniform(-1, +1, size=env.action_dim) 15 | 
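    # Note: the gymnasium-based PendulumEnv wrapper in helloworld/erl_env.py
    # returns (state, info_dict) from reset() and a 5-tuple
    # (state, reward, terminated, truncated, info_dict) from step(). If this
    # check is run against that wrapper rather than an old-gym-style env, the
    # reset/step unpacking below would need to be adjusted, e.g.:
    #     state, reward, terminated, truncated, info_dict = env.step(action)
    #     done = terminated or truncated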
state, reward, done, info_dict = env.step(action) 16 | assert isinstance(state, np.ndarray) 17 | assert state.shape == (env.state_dim,) 18 | assert isinstance(state, np.ndarray) 19 | assert isinstance(reward, float) 20 | assert isinstance(done, bool) 21 | assert isinstance(info_dict, dict) or (info_dict is None) 22 | 23 | 24 | if __name__ == '__main__': 25 | check_pendulum_env() 26 | print('| Finish checking.') 27 | -------------------------------------------------------------------------------- /helloworld/unit_tests/check_run.py: -------------------------------------------------------------------------------- 1 | import shutil 2 | 3 | import numpy as np 4 | 5 | from run import * 6 | 7 | 8 | def check_get_rewards_and_steps(net_dims=(64, 32)): 9 | pass 10 | 11 | """discrete env""" 12 | from env import gym 13 | env_args = {'env_name': 'CartPole-v1', 'state_dim': 4, 'action_dim': 2, 'if_discrete': True} 14 | env_class = gym.make 15 | env = build_env(env_class=env_class, env_args=env_args) 16 | 17 | '''discrete env, on-policy''' 18 | from net import QNet 19 | actor = QNet(dims=net_dims, state_dim=env.state_dim, action_dim=env.action_dim) 20 | cumulative_returns, episode_steps = get_rewards_and_steps(env=env, actor=actor) 21 | assert isinstance(cumulative_returns, float) 22 | assert isinstance(episode_steps, int) 23 | assert episode_steps >= 1 24 | 25 | """continuous env""" 26 | from env import PendulumEnv 27 | env_args = {'env_name': 'Pendulum-v1', 'state_dim': 3, 'action_dim': 1, 'if_discrete': False} 28 | env_class = PendulumEnv 29 | env = build_env(env_class=env_class, env_args=env_args) 30 | 31 | '''continuous env, off-policy''' 32 | from net import Actor 33 | actor = Actor(dims=net_dims, state_dim=env.state_dim, action_dim=env.action_dim) 34 | cumulative_returns, episode_steps = get_rewards_and_steps(env=env, actor=actor) 35 | assert isinstance(cumulative_returns, float) 36 | assert isinstance(episode_steps, int) 37 | assert episode_steps >= 1 38 | 39 | '''continuous env, on-policy''' 40 | from net import ActorPPO 41 | actor = ActorPPO(dims=net_dims, state_dim=env.state_dim, action_dim=env.action_dim) 42 | cumulative_returns, episode_steps = get_rewards_and_steps(env=env, actor=actor) 43 | assert isinstance(cumulative_returns, float) 44 | assert isinstance(episode_steps, int) 45 | assert episode_steps >= 1 46 | 47 | 48 | def check_draw_learning_curve_using_recorder(cwd='./temp'): 49 | os.makedirs(cwd, exist_ok=True) 50 | recorder_path = f"{cwd}/recorder.npy" 51 | recorder_len = 8 52 | 53 | recorder = np.zeros((recorder_len, 3), dtype=np.float32) 54 | recorder[:, 0] = np.linspace(1, 100, num=recorder_len) # total_step 55 | recorder[:, 1] = np.linspace(1, 200, num=recorder_len) # used_time 56 | recorder[:, 2] = np.linspace(1, 300, num=recorder_len) # average of cumulative rewards 57 | np.save(recorder_path, recorder) 58 | draw_learning_curve_using_recorder(cwd) 59 | assert os.path.exists(f"{cwd}/LearningCurve.jpg") 60 | shutil.rmtree(cwd) 61 | 62 | 63 | def check_evaluator(net_dims=(64, 32), horizon_len=1024, eval_per_step=16, eval_times=2, cwd='./temp'): 64 | from env import PendulumEnv 65 | env_args = {'env_name': 'Pendulum-v1', 'state_dim': 3, 'action_dim': 1, 'if_discrete': False} 66 | env_class = PendulumEnv 67 | env = build_env(env_class, env_args) 68 | from net import Actor 69 | actor = Actor(dims=net_dims, state_dim=env.state_dim, action_dim=env.action_dim) 70 | 71 | os.makedirs(cwd, exist_ok=True) 72 | evaluator = Evaluator(eval_env=env, eval_per_step=eval_per_step, 
eval_times=eval_times, cwd=cwd) 73 | evaluator.evaluate_and_save(actor=actor, horizon_len=horizon_len, logging_tuple=(0.1, 0.2)) 74 | evaluator.evaluate_and_save(actor=actor, horizon_len=horizon_len, logging_tuple=(0.3, 0.4)) 75 | evaluator.close() 76 | assert os.path.exists(f"{evaluator.cwd}/recorder.npy") 77 | assert os.path.exists(f"{evaluator.cwd}/LearningCurve.jpg") 78 | shutil.rmtree(cwd) 79 | 80 | 81 | if __name__ == '__main__': 82 | check_draw_learning_curve_using_recorder() 83 | check_get_rewards_and_steps() 84 | check_evaluator() 85 | print('| Finish checking.') 86 | 87 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | # ML framework 2 | torch 3 | 4 | # data handling 5 | numpy 6 | 7 | # plot/simulation 8 | matplotlib 9 | gymnasium 10 | 11 | # profiling (no necessary) 12 | wandb 13 | -------------------------------------------------------------------------------- /rlsolver/LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 AI4Finance Foundation Inc. 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /rlsolver/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AI4Finance-Foundation/ElegantRL/8ea76afc3e7f1564ae9f0e69e70254116d575fe9/rlsolver/__init__.py -------------------------------------------------------------------------------- /rlsolver/data/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AI4Finance-Foundation/ElegantRL/8ea76afc3e7f1564ae9f0e69e70254116d575fe9/rlsolver/data/__init__.py -------------------------------------------------------------------------------- /rlsolver/data/syn_BA/BA_100_ID0.txt: -------------------------------------------------------------------------------- 1 | 100 384 2 | 1 2 1 3 | 1 3 1 4 | 1 4 1 5 | 1 5 1 6 | 1 6 1 7 | 1 7 1 8 | 1 8 1 9 | 1 9 1 10 | 1 10 1 11 | 1 12 1 12 | 1 14 1 13 | 1 15 1 14 | 1 17 1 15 | 1 20 1 16 | 1 21 1 17 | 1 22 1 18 | 1 24 1 19 | 1 35 1 20 | 1 37 1 21 | 1 39 1 22 | 1 40 1 23 | 1 42 1 24 | 1 52 1 25 | 1 54 1 26 | 1 64 1 27 | 1 70 1 28 | 1 84 1 29 | 1 89 1 30 | 1 93 1 31 | 2 6 1 32 | 2 8 1 33 | 2 20 1 34 | 2 32 1 35 | 2 50 1 36 | 2 76 1 37 | 2 86 1 38 | 3 6 1 39 | 3 7 1 40 | 3 8 1 41 | 3 12 1 42 | 3 14 1 43 | 3 18 1 44 | 3 25 1 45 | 3 27 1 46 | 3 34 1 47 | 3 43 1 48 | 3 44 1 49 | 3 56 1 50 | 3 61 1 51 | 3 65 1 52 | 3 67 1 53 | 3 82 1 54 | 3 85 1 55 | 3 88 1 56 | 4 6 1 57 | 4 7 1 58 | 4 9 1 59 | 4 10 1 60 | 4 17 1 61 | 4 30 1 62 | 4 31 1 63 | 4 36 1 64 | 4 47 1 65 | 4 51 1 66 | 4 56 1 67 | 4 58 1 68 | 4 72 1 69 | 4 74 1 70 | 5 10 1 71 | 5 12 1 72 | 5 16 1 73 | 5 18 1 74 | 5 21 1 75 | 5 22 1 76 | 5 33 1 77 | 5 38 1 78 | 5 42 1 79 | 5 95 1 80 | 5 98 1 81 | 6 7 1 82 | 6 9 1 83 | 6 11 1 84 | 6 13 1 85 | 6 14 1 86 | 6 15 1 87 | 6 16 1 88 | 6 18 1 89 | 6 24 1 90 | 6 26 1 91 | 6 40 1 92 | 6 41 1 93 | 6 45 1 94 | 6 53 1 95 | 6 58 1 96 | 6 67 1 97 | 7 8 1 98 | 7 10 1 99 | 7 11 1 100 | 7 12 1 101 | 7 13 1 102 | 7 14 1 103 | 7 15 1 104 | 7 16 1 105 | 7 17 1 106 | 7 19 1 107 | 7 23 1 108 | 7 24 1 109 | 7 28 1 110 | 7 29 1 111 | 7 33 1 112 | 7 34 1 113 | 7 40 1 114 | 7 42 1 115 | 7 44 1 116 | 7 48 1 117 | 7 51 1 118 | 7 61 1 119 | 7 64 1 120 | 7 68 1 121 | 7 80 1 122 | 7 85 1 123 | 7 99 1 124 | 7 100 1 125 | 8 9 1 126 | 8 19 1 127 | 8 22 1 128 | 8 29 1 129 | 8 54 1 130 | 8 56 1 131 | 8 77 1 132 | 8 89 1 133 | 8 94 1 134 | 8 99 1 135 | 9 11 1 136 | 9 19 1 137 | 9 21 1 138 | 9 23 1 139 | 9 33 1 140 | 9 36 1 141 | 9 38 1 142 | 9 43 1 143 | 9 46 1 144 | 9 54 1 145 | 9 60 1 146 | 9 69 1 147 | 9 71 1 148 | 9 72 1 149 | 9 75 1 150 | 9 83 1 151 | 9 84 1 152 | 9 91 1 153 | 9 92 1 154 | 9 95 1 155 | 10 11 1 156 | 10 13 1 157 | 10 22 1 158 | 10 23 1 159 | 10 27 1 160 | 10 48 1 161 | 10 58 1 162 | 10 59 1 163 | 10 60 1 164 | 10 80 1 165 | 11 26 1 166 | 11 30 1 167 | 11 49 1 168 | 11 53 1 169 | 11 67 1 170 | 11 70 1 171 | 11 94 1 172 | 12 13 1 173 | 12 15 1 174 | 12 25 1 175 | 12 30 1 176 | 12 33 1 177 | 12 40 1 178 | 12 52 1 179 | 12 53 1 180 | 12 63 1 181 | 12 64 1 182 | 12 67 1 183 | 12 73 1 184 | 12 81 1 185 | 12 87 1 186 | 13 16 1 187 | 13 17 1 188 | 13 21 1 189 | 13 23 1 190 | 13 25 1 191 | 13 28 1 192 | 13 41 1 193 | 13 47 1 194 | 13 48 1 195 | 13 74 1 196 | 14 20 1 197 | 14 29 1 198 | 14 55 1 199 | 14 92 1 200 | 15 18 1 201 | 15 20 1 202 | 15 26 1 203 | 15 30 1 204 | 15 37 1 205 | 15 38 1 206 | 15 57 1 207 | 15 66 1 208 | 15 85 1 209 | 15 87 1 210 | 15 91 1 211 | 15 97 1 212 | 16 31 1 213 | 16 59 1 214 | 17 19 
1 215 | 17 26 1 216 | 17 27 1 217 | 17 28 1 218 | 17 34 1 219 | 17 35 1 220 | 17 48 1 221 | 17 49 1 222 | 17 66 1 223 | 17 68 1 224 | 17 79 1 225 | 17 86 1 226 | 17 99 1 227 | 18 45 1 228 | 18 49 1 229 | 18 63 1 230 | 18 93 1 231 | 19 32 1 232 | 19 44 1 233 | 19 86 1 234 | 19 90 1 235 | 20 32 1 236 | 20 39 1 237 | 20 77 1 238 | 21 24 1 239 | 21 27 1 240 | 21 43 1 241 | 21 65 1 242 | 21 82 1 243 | 21 96 1 244 | 22 25 1 245 | 22 32 1 246 | 22 36 1 247 | 22 38 1 248 | 22 41 1 249 | 22 47 1 250 | 22 50 1 251 | 22 62 1 252 | 22 71 1 253 | 22 81 1 254 | 22 86 1 255 | 22 87 1 256 | 23 37 1 257 | 23 49 1 258 | 23 56 1 259 | 23 77 1 260 | 23 79 1 261 | 24 31 1 262 | 24 35 1 263 | 24 66 1 264 | 25 59 1 265 | 25 83 1 266 | 26 28 1 267 | 26 29 1 268 | 26 39 1 269 | 26 50 1 270 | 26 65 1 271 | 26 71 1 272 | 26 80 1 273 | 26 88 1 274 | 26 96 1 275 | 27 31 1 276 | 27 34 1 277 | 27 37 1 278 | 27 44 1 279 | 27 45 1 280 | 27 51 1 281 | 27 55 1 282 | 27 57 1 283 | 27 58 1 284 | 27 70 1 285 | 27 73 1 286 | 27 78 1 287 | 27 82 1 288 | 27 88 1 289 | 27 91 1 290 | 27 93 1 291 | 27 97 1 292 | 28 42 1 293 | 28 47 1 294 | 28 78 1 295 | 28 89 1 296 | 29 35 1 297 | 29 55 1 298 | 29 64 1 299 | 29 73 1 300 | 30 39 1 301 | 31 70 1 302 | 32 46 1 303 | 32 57 1 304 | 33 62 1 305 | 34 36 1 306 | 34 41 1 307 | 34 45 1 308 | 34 46 1 309 | 34 59 1 310 | 34 79 1 311 | 36 61 1 312 | 36 83 1 313 | 37 43 1 314 | 37 50 1 315 | 37 75 1 316 | 37 78 1 317 | 37 96 1 318 | 38 73 1 319 | 39 57 1 320 | 39 76 1 321 | 40 51 1 322 | 40 52 1 323 | 40 65 1 324 | 40 75 1 325 | 40 78 1 326 | 41 55 1 327 | 41 60 1 328 | 41 81 1 329 | 42 46 1 330 | 42 61 1 331 | 42 62 1 332 | 42 72 1 333 | 42 84 1 334 | 42 90 1 335 | 44 60 1 336 | 44 72 1 337 | 45 54 1 338 | 46 62 1 339 | 46 84 1 340 | 47 63 1 341 | 48 53 1 342 | 48 87 1 343 | 49 52 1 344 | 49 98 1 345 | 50 68 1 346 | 50 69 1 347 | 51 83 1 348 | 51 85 1 349 | 52 97 1 350 | 53 66 1 351 | 53 77 1 352 | 53 80 1 353 | 53 94 1 354 | 54 76 1 355 | 56 93 1 356 | 58 69 1 357 | 58 74 1 358 | 58 94 1 359 | 58 96 1 360 | 58 100 1 361 | 59 82 1 362 | 60 92 1 363 | 61 63 1 364 | 61 75 1 365 | 62 68 1 366 | 62 90 1 367 | 63 100 1 368 | 65 79 1 369 | 66 74 1 370 | 67 69 1 371 | 68 71 1 372 | 72 76 1 373 | 74 81 1 374 | 74 98 1 375 | 75 100 1 376 | 77 90 1 377 | 80 88 1 378 | 80 92 1 379 | 81 98 1 380 | 85 89 1 381 | 87 91 1 382 | 87 95 1 383 | 89 95 1 384 | 92 97 1 385 | 94 99 1 386 | -------------------------------------------------------------------------------- /rlsolver/data/syn_PL/PL_100_ID0.txt: -------------------------------------------------------------------------------- 1 | 100 384 2 | 1 5 1 3 | 1 6 1 4 | 1 7 1 5 | 1 8 1 6 | 1 20 1 7 | 1 23 1 8 | 1 49 1 9 | 1 54 1 10 | 1 58 1 11 | 1 64 1 12 | 1 99 1 13 | 2 5 1 14 | 2 11 1 15 | 2 22 1 16 | 2 29 1 17 | 2 37 1 18 | 2 43 1 19 | 2 58 1 20 | 2 62 1 21 | 2 66 1 22 | 2 75 1 23 | 2 76 1 24 | 2 82 1 25 | 2 92 1 26 | 2 98 1 27 | 3 5 1 28 | 3 6 1 29 | 3 7 1 30 | 3 8 1 31 | 3 9 1 32 | 3 10 1 33 | 3 11 1 34 | 3 12 1 35 | 3 13 1 36 | 3 15 1 37 | 3 16 1 38 | 3 21 1 39 | 3 24 1 40 | 3 25 1 41 | 3 27 1 42 | 3 28 1 43 | 3 30 1 44 | 3 32 1 45 | 3 33 1 46 | 3 39 1 47 | 3 42 1 48 | 3 47 1 49 | 3 49 1 50 | 3 50 1 51 | 3 66 1 52 | 3 74 1 53 | 3 83 1 54 | 3 84 1 55 | 3 90 1 56 | 4 5 1 57 | 4 6 1 58 | 4 7 1 59 | 4 8 1 60 | 4 12 1 61 | 4 15 1 62 | 4 17 1 63 | 4 24 1 64 | 4 27 1 65 | 4 28 1 66 | 4 29 1 67 | 4 30 1 68 | 4 35 1 69 | 4 43 1 70 | 4 48 1 71 | 4 70 1 72 | 4 81 1 73 | 4 83 1 74 | 5 6 1 75 | 5 7 1 76 | 5 10 1 77 | 5 12 1 78 | 5 13 1 79 | 5 28 1 80 | 5 30 1 81 | 5 32 1 
82 | 5 34 1 83 | 5 35 1 84 | 5 39 1 85 | 5 40 1 86 | 5 45 1 87 | 5 48 1 88 | 5 49 1 89 | 5 50 1 90 | 5 53 1 91 | 5 54 1 92 | 5 55 1 93 | 5 57 1 94 | 5 70 1 95 | 5 74 1 96 | 5 79 1 97 | 5 84 1 98 | 5 85 1 99 | 5 87 1 100 | 5 91 1 101 | 5 94 1 102 | 6 8 1 103 | 6 9 1 104 | 6 10 1 105 | 6 12 1 106 | 6 14 1 107 | 6 16 1 108 | 6 22 1 109 | 6 23 1 110 | 6 26 1 111 | 6 27 1 112 | 6 32 1 113 | 6 37 1 114 | 6 43 1 115 | 6 48 1 116 | 6 54 1 117 | 6 55 1 118 | 6 56 1 119 | 6 78 1 120 | 6 79 1 121 | 6 81 1 122 | 6 86 1 123 | 6 87 1 124 | 6 96 1 125 | 6 97 1 126 | 7 9 1 127 | 7 10 1 128 | 7 14 1 129 | 7 15 1 130 | 7 19 1 131 | 7 22 1 132 | 7 27 1 133 | 7 38 1 134 | 7 44 1 135 | 7 48 1 136 | 7 57 1 137 | 7 96 1 138 | 8 9 1 139 | 8 14 1 140 | 8 15 1 141 | 8 16 1 142 | 8 17 1 143 | 8 21 1 144 | 8 24 1 145 | 8 33 1 146 | 8 40 1 147 | 8 46 1 148 | 8 59 1 149 | 8 60 1 150 | 8 89 1 151 | 9 11 1 152 | 9 17 1 153 | 9 18 1 154 | 9 25 1 155 | 9 33 1 156 | 9 42 1 157 | 9 59 1 158 | 9 95 1 159 | 10 11 1 160 | 10 13 1 161 | 10 18 1 162 | 10 20 1 163 | 10 26 1 164 | 10 41 1 165 | 10 62 1 166 | 10 69 1 167 | 10 72 1 168 | 10 84 1 169 | 10 86 1 170 | 10 93 1 171 | 11 13 1 172 | 11 14 1 173 | 11 36 1 174 | 11 45 1 175 | 11 46 1 176 | 11 57 1 177 | 11 95 1 178 | 12 19 1 179 | 12 24 1 180 | 12 26 1 181 | 12 29 1 182 | 12 39 1 183 | 12 42 1 184 | 12 63 1 185 | 12 83 1 186 | 13 18 1 187 | 13 19 1 188 | 13 21 1 189 | 13 28 1 190 | 13 31 1 191 | 13 49 1 192 | 13 69 1 193 | 13 78 1 194 | 13 96 1 195 | 14 16 1 196 | 14 17 1 197 | 14 20 1 198 | 14 21 1 199 | 14 35 1 200 | 14 39 1 201 | 14 41 1 202 | 14 51 1 203 | 14 54 1 204 | 14 89 1 205 | 14 95 1 206 | 15 23 1 207 | 15 38 1 208 | 15 53 1 209 | 15 74 1 210 | 15 85 1 211 | 16 18 1 212 | 16 25 1 213 | 16 33 1 214 | 16 43 1 215 | 16 47 1 216 | 16 56 1 217 | 16 57 1 218 | 16 59 1 219 | 16 81 1 220 | 16 86 1 221 | 16 87 1 222 | 16 92 1 223 | 16 100 1 224 | 17 20 1 225 | 17 44 1 226 | 17 51 1 227 | 17 52 1 228 | 17 76 1 229 | 17 88 1 230 | 18 19 1 231 | 18 34 1 232 | 18 75 1 233 | 19 36 1 234 | 19 51 1 235 | 19 89 1 236 | 20 23 1 237 | 20 26 1 238 | 20 31 1 239 | 20 34 1 240 | 20 37 1 241 | 20 46 1 242 | 20 52 1 243 | 20 53 1 244 | 20 62 1 245 | 20 72 1 246 | 21 22 1 247 | 21 30 1 248 | 21 50 1 249 | 21 53 1 250 | 21 63 1 251 | 21 67 1 252 | 21 76 1 253 | 21 92 1 254 | 21 93 1 255 | 22 38 1 256 | 22 61 1 257 | 22 63 1 258 | 22 77 1 259 | 22 85 1 260 | 22 88 1 261 | 23 25 1 262 | 23 34 1 263 | 23 40 1 264 | 23 41 1 265 | 23 44 1 266 | 23 96 1 267 | 23 99 1 268 | 24 29 1 269 | 24 56 1 270 | 24 65 1 271 | 24 72 1 272 | 24 73 1 273 | 24 87 1 274 | 25 31 1 275 | 25 40 1 276 | 25 52 1 277 | 25 58 1 278 | 25 61 1 279 | 25 64 1 280 | 25 66 1 281 | 25 70 1 282 | 25 82 1 283 | 25 91 1 284 | 26 67 1 285 | 26 93 1 286 | 27 75 1 287 | 28 31 1 288 | 28 32 1 289 | 28 36 1 290 | 28 60 1 291 | 28 63 1 292 | 28 64 1 293 | 28 76 1 294 | 28 100 1 295 | 29 38 1 296 | 29 98 1 297 | 30 65 1 298 | 31 36 1 299 | 31 42 1 300 | 31 45 1 301 | 31 65 1 302 | 31 71 1 303 | 31 73 1 304 | 31 79 1 305 | 31 86 1 306 | 32 35 1 307 | 32 73 1 308 | 32 90 1 309 | 33 47 1 310 | 33 69 1 311 | 34 47 1 312 | 35 37 1 313 | 35 41 1 314 | 35 61 1 315 | 36 45 1 316 | 36 46 1 317 | 36 59 1 318 | 36 71 1 319 | 36 77 1 320 | 36 80 1 321 | 36 85 1 322 | 36 90 1 323 | 36 91 1 324 | 36 93 1 325 | 37 79 1 326 | 38 97 1 327 | 39 74 1 328 | 39 81 1 329 | 40 80 1 330 | 41 84 1 331 | 41 92 1 332 | 42 44 1 333 | 42 60 1 334 | 42 73 1 335 | 44 50 1 336 | 44 58 1 337 | 44 68 1 338 | 46 80 1 339 | 47 55 1 340 | 47 56 1 341 | 47 61 1 342 | 47 69 
1 343 | 47 94 1 344 | 48 55 1 345 | 48 65 1 346 | 48 80 1 347 | 49 51 1 348 | 49 52 1 349 | 49 68 1 350 | 50 98 1 351 | 52 70 1 352 | 52 75 1 353 | 53 64 1 354 | 53 99 1 355 | 54 62 1 356 | 54 67 1 357 | 57 60 1 358 | 57 71 1 359 | 58 68 1 360 | 58 83 1 361 | 62 67 1 362 | 62 77 1 363 | 62 95 1 364 | 63 97 1 365 | 65 66 1 366 | 66 68 1 367 | 66 71 1 368 | 66 72 1 369 | 66 82 1 370 | 66 89 1 371 | 66 99 1 372 | 71 82 1 373 | 72 90 1 374 | 73 94 1 375 | 74 77 1 376 | 74 91 1 377 | 75 100 1 378 | 76 78 1 379 | 77 78 1 380 | 80 88 1 381 | 80 98 1 382 | 81 88 1 383 | 81 97 1 384 | 83 94 1 385 | 94 100 1 386 | -------------------------------------------------------------------------------- /rlsolver/data/tsplib/a5.tsp: -------------------------------------------------------------------------------- 1 | NAME : a5 2 | COMMENT : 5 capitals of the US (Padberg/Rinaldi) 3 | TYPE : TSP 4 | DIMENSION : 5 5 | EDGE_WEIGHT_TYPE : A5 6 | NODE_COORD_SECTION 7 | 1 6734 1453 8 | 2 2233 10 9 | 3 5530 1424 10 | 4 401 841 11 | 5 3082 1644 12 | EOF 13 | -------------------------------------------------------------------------------- /rlsolver/data/tsplib/berlin52.tsp: -------------------------------------------------------------------------------- 1 | NAME: berlin52 2 | TYPE: TSP 3 | COMMENT: 52 locations in Berlin (Groetschel) 4 | DIMENSION: 52 5 | EDGE_WEIGHT_TYPE: EUC_2D 6 | NODE_COORD_SECTION 7 | 1 565.0 575.0 8 | 2 25.0 185.0 9 | 3 345.0 750.0 10 | 4 945.0 685.0 11 | 5 845.0 655.0 12 | 6 880.0 660.0 13 | 7 25.0 230.0 14 | 8 525.0 1000.0 15 | 9 580.0 1175.0 16 | 10 650.0 1130.0 17 | 11 1605.0 620.0 18 | 12 1220.0 580.0 19 | 13 1465.0 200.0 20 | 14 1530.0 5.0 21 | 15 845.0 680.0 22 | 16 725.0 370.0 23 | 17 145.0 665.0 24 | 18 415.0 635.0 25 | 19 510.0 875.0 26 | 20 560.0 365.0 27 | 21 300.0 465.0 28 | 22 520.0 585.0 29 | 23 480.0 415.0 30 | 24 835.0 625.0 31 | 25 975.0 580.0 32 | 26 1215.0 245.0 33 | 27 1320.0 315.0 34 | 28 1250.0 400.0 35 | 29 660.0 180.0 36 | 30 410.0 250.0 37 | 31 420.0 555.0 38 | 32 575.0 665.0 39 | 33 1150.0 1160.0 40 | 34 700.0 580.0 41 | 35 685.0 595.0 42 | 36 685.0 610.0 43 | 37 770.0 610.0 44 | 38 795.0 645.0 45 | 39 720.0 635.0 46 | 40 760.0 650.0 47 | 41 475.0 960.0 48 | 42 95.0 260.0 49 | 43 875.0 920.0 50 | 44 700.0 500.0 51 | 45 555.0 815.0 52 | 46 830.0 485.0 53 | 47 1170.0 65.0 54 | 48 830.0 610.0 55 | 49 605.0 625.0 56 | 50 595.0 360.0 57 | 51 1340.0 725.0 58 | 52 1740.0 245.0 59 | EOF 60 | 61 | -------------------------------------------------------------------------------- /rlsolver/docs/.readthedocs.yaml: -------------------------------------------------------------------------------- 1 | python: 2 | setup_py_install: true 3 | -------------------------------------------------------------------------------- /rlsolver/docs/Makefile: -------------------------------------------------------------------------------- 1 | # Minimal makefile for Sphinx documentation 2 | # 3 | 4 | # You can set these variables from the command line. 5 | SPHINXOPTS = 6 | SPHINXBUILD = sphinx-build 7 | SOURCEDIR = source 8 | BUILDDIR = build 9 | 10 | # Put it first so that "make" without argument is like "make help". 11 | help: 12 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 13 | 14 | .PHONY: help Makefile 15 | 16 | # Catch-all target: route all unknown targets to Sphinx using the new 17 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). 
18 | %: Makefile 19 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) -------------------------------------------------------------------------------- /rlsolver/docs/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AI4Finance-Foundation/ElegantRL/8ea76afc3e7f1564ae9f0e69e70254116d575fe9/rlsolver/docs/__init__.py -------------------------------------------------------------------------------- /rlsolver/docs/build/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AI4Finance-Foundation/ElegantRL/8ea76afc3e7f1564ae9f0e69e70254116d575fe9/rlsolver/docs/build/__init__.py -------------------------------------------------------------------------------- /rlsolver/docs/build/init.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /rlsolver/docs/make.bat: -------------------------------------------------------------------------------- 1 | @ECHO OFF 2 | 3 | pushd %~dp0 4 | 5 | REM Command file for Sphinx documentation 6 | 7 | if "%SPHINXBUILD%" == "" ( 8 | set SPHINXBUILD=sphinx-build 9 | ) 10 | set SOURCEDIR=source 11 | set BUILDDIR=build 12 | 13 | if "%1" == "" goto help 14 | 15 | %SPHINXBUILD% >NUL 2>NUL 16 | if errorlevel 9009 ( 17 | echo. 18 | echo.The 'sphinx-build' command was not found. Make sure you have Sphinx 19 | echo.installed, then set the SPHINXBUILD environment variable to point 20 | echo.to the full path of the 'sphinx-build' executable. Alternatively you 21 | echo.may add the Sphinx directory to PATH. 22 | echo. 23 | echo.If you don't have Sphinx installed, grab it from 24 | echo.http://sphinx-doc.org/ 25 | exit /b 1 26 | ) 27 | 28 | %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% 29 | goto end 30 | 31 | :help 32 | %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% 33 | 34 | :end 35 | popd 36 | -------------------------------------------------------------------------------- /rlsolver/docs/source/RLSolver/baseline: -------------------------------------------------------------------------------- 1 | gurobi 2 | random_walk 3 | greedy 4 | simulated_annealing 5 | -------------------------------------------------------------------------------- /rlsolver/docs/source/RLSolver/build_models: -------------------------------------------------------------------------------- 1 | build_models 2 | 3 | 4 | **Graph max-cut** 5 | 6 | 7 | 8 | 9 | 10 | ## MIMO 11 | 12 | 13 | 14 | ## Compressive sensing 15 | 16 | -------------------------------------------------------------------------------- /rlsolver/docs/source/RLSolver/helloworld: -------------------------------------------------------------------------------- 1 | helloworld 2 | -------------------------------------------------------------------------------- /rlsolver/docs/source/RLSolver/overview.rst: -------------------------------------------------------------------------------- 1 | Overview 2 | ============= 3 | 4 | One sentence summary: RLSolver is a high-performance RL Solver. 5 | 6 | We aim to find high-quality optimum, or even (nearly) global optimum, for nonconvex/nonlinear optimizations (continuous variables) and combinatorial optimizations (discrete variables). We provide pretrained neural networks to perform real-time inference for nonconvex optimization problems, including combinatorial optimization problems. 
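As an illustration of what real-time inference means here, the sketch below uses plain PyTorch with a hypothetical policy network and checkpoint path (it is not the RLSolver API): a pretrained policy maps a whole batch of problem instances to candidate binary solutions in a single forward pass.

.. code-block:: python

    import torch

    # hypothetical pretrained policy: instance features in, per-node probabilities out
    policy = torch.nn.Sequential(
        torch.nn.Linear(100, 256), torch.nn.ReLU(),
        torch.nn.Linear(256, 100), torch.nn.Sigmoid(),
    )
    # policy.load_state_dict(torch.load("policy.pth"))  # pretrained weights; path is an assumption
    policy.eval()

    with torch.no_grad():
        instances = torch.rand(4096, 100)       # a batch of problem instances
        node_probs = policy(instances)          # one forward pass for the whole batch
        solutions = (node_probs > 0.5).int()    # round to a binary assignment per instance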
7 | 8 | 9 | The following two key technologies are under active development: 10 | - Massively parallel simuations of gym-environments on GPU, using thousands of CUDA cores and tensor cores. 11 | - Podracer scheduling on a GPU cloud, e.g., DGX-2 SuperPod. 12 | 13 | Key references: 14 | - Mazyavkina, Nina, et al. "Reinforcement learning for combinatorial optimization: A survey." Computers & Operations Research 134 (2021): 105400. 15 | 16 | - Bengio, Yoshua, Andrea Lodi, and Antoine Prouvost. "Machine learning for combinatorial optimization: a methodological tour d’horizon." European Journal of Operational Research 290.2 (2021): 405-421. 17 | 18 | - Makoviychuk, Viktor, et al. "Isaac Gym: High performance GPU based physics simulation for robot learning." Thirty-fifth Conference on Neural Information Processing Systems Datasets and Benchmarks Track (Round 2). 2021. 19 | 20 | - Nair, Vinod, et al. "Solving mixed integer programs using neural networks." arXiv preprint arXiv:2012.13349 (2020). 21 | 22 | MCMC: 23 | - Maxcut 24 | - MIMO Beamforming in 5G/6G. 25 | - Classical NP-Hard problems. 26 | - Classical Simulation of Quantum Circuits. 27 | - Compressive Sensing. 28 | - Portfolio Management. 29 | - OR-Gym. 30 | 31 | File Structure: 32 | ``` 33 | -RLSolver 34 | -├── opt_methods 35 | -| ├──branch-and-bound.py 36 | -| └──cutting_plane.py 37 | -├── helloworld 38 | -| ├──maxcut 39 | -| ├──data 40 | -| ├──result 41 | -| ├──mcmc.py 42 | -| ├──l2a.py 43 | -└── rlsolver (main folder) 44 | - ├── mcmc 45 | - | ├── _base 46 | - | └── maxcut 47 | - | └── tsp 48 | - | ├── portfolio_management 49 | - |── rlsolver_learn2opt 50 | - | ├── mimo 51 | - | ├── tensor_train 52 | - └── utils 53 | - └── maxcut.py 54 | - └── maxcut_gurobi.py 55 | - └── tsp.py 56 | - └── tsp_gurobi.py 57 | ``` 58 | 59 | 60 | **RLSolver features high-performance and stability:** 61 | 62 | **High-performance**: it can find high-quality optimum, or even (nearly) global optimum. 63 | 64 | **Stable**: it leverages computing resource to implement the Hamiltonian-term as an add-on regularization to DRL algorithms. Such an add-on H-term utilizes computing power (can be computed in parallel on GPU) to search for the "minimum-energy state", corresponding to the stable state of a system. 65 | 66 | 67 | 68 | 69 | 70 | -------------------------------------------------------------------------------- /rlsolver/docs/source/RLSolver/problems: -------------------------------------------------------------------------------- 1 | Classical nonconvex/nonlinear optimizations (continuous variables) and combinatorial optimizations (discrete variables) are listed here. 2 | - Maxcut 3 | - TSP 4 | - MILP 5 | - MIMO 6 | - Compressive sensing 7 | - TNCO 8 | 9 | -------------------------------------------------------------------------------- /rlsolver/docs/source/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AI4Finance-Foundation/ElegantRL/8ea76afc3e7f1564ae9f0e69e70254116d575fe9/rlsolver/docs/source/__init__.py -------------------------------------------------------------------------------- /rlsolver/docs/source/about/overview.rst: -------------------------------------------------------------------------------- 1 | Key Concepts and Features 2 | ============= 3 | -------------------------------------------------------------------------------- /rlsolver/docs/source/algorithms/REINFORCE.rst: -------------------------------------------------------------------------------- 1 | .. 
_REINFORCE: 2 | 3 | 4 | REINFORCE 5 | ========== 6 | 7 | -------------------------------------------------------------------------------- /rlsolver/docs/source/api/config.rst: -------------------------------------------------------------------------------- 1 | Configuration: *config.py* 2 | ========================== 3 | 4 | 5 | ``Arguments`` 6 | --------------------- 7 | 8 | The ``Arguments`` class contains all parameters of the training process, including environment setup, model training, model evaluation, and resource allocation. It provides users an unified interface to customize the training process. 9 | 10 | The class should be initialized at the start of the training process. For example, 11 | -------------------------------------------------------------------------------- /rlsolver/docs/source/api/utils.rst: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /rlsolver/docs/source/helloworld/quickstart.rst: -------------------------------------------------------------------------------- 1 | Quickstart 2 | ============= 3 | 4 | As a quickstart, we select the maxcut task from the demo.py to show how to train a DRL agent. 5 | 6 | Step 1: Import packages 7 | ------------------------------- 8 | 9 | .. code-block:: python 10 | 11 | 12 | 13 | Step 2: Specify Agent and Environment 14 | -------------------------------------- 15 | 16 | .. code-block:: python 17 | 18 | 19 | 20 | -------------------------------------------------------------------------------- /rlsolver/docs/source/images/bellman.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AI4Finance-Foundation/ElegantRL/8ea76afc3e7f1564ae9f0e69e70254116d575fe9/rlsolver/docs/source/images/bellman.png -------------------------------------------------------------------------------- /rlsolver/docs/source/images/envs.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AI4Finance-Foundation/ElegantRL/8ea76afc3e7f1564ae9f0e69e70254116d575fe9/rlsolver/docs/source/images/envs.png -------------------------------------------------------------------------------- /rlsolver/docs/source/index.rst: -------------------------------------------------------------------------------- 1 | .. RLSolver documentation master file, created by 2 | 3 | Welcome to RLSolver! 4 | ===================================== 5 | 6 | 7 | :width: 50% 8 | :align: center 9 | :target: https://github.com/zhumingpassional/RLSolver 10 | 11 | 12 | 13 | `RLSolver `_ : GPU-based Massively Parallel Environments for Combinatorial Optimization (CO) Problems Using Reinforcement Learning 14 | 15 | We aim to showcase the effectiveness of massively parallel environments for combinatorial optimization (CO) problems using reinforcement learning (RL). RL with the help of GPU based parallel environments can significantly improve the sampling speed and can obtain high-quality solutions within short time. 16 | 17 | Overview 18 | 19 | RLSolver has three layers: 20 | 21 | -Environments: providing massively parallel environments using GPUs. 22 | -RL agents: providing RL algorithms, e.g., REINFORCE. 23 | -Problems: typical CO problems, e.g., graph maxcut and TNCO. 24 | 25 | Key Technologies 26 | -**GPU-based Massively parallel environments** of Markov chain Monte Carlo (MCMC) simulations on GPU using thousands of CUDA cores and tensor cores. 
27 | -**Distribution-wise** is **much much faster** than the instance-wise methods, such as MCPG and iSCO, since we can obtain the results directly by inference. 28 | 29 | Why Use GPU-based Massively Parallel Environments? 30 | 31 | The bottleneck of using RL for solving CO problems is the sampling speed since existing solver engines (a.k.a, gym-style environments) are implemented on CPUs. Training the policy network is essentially estimating the gradients via a Markov chain Monte Carlo (MCMC) simulation, which requires a large number of samples from environments. 32 | 33 | Existing CPU-based environments have two significant disadvantages: 1) The number of CPU cores is typically small, generally ranging from 16 to 256, resulting in a small number of parallel environments. 2) The communication link between CPUs and GPUs has limited bandwidth. The massively parallel environments can overcome these disadvantages, since we can build thounsands of environments and the communication bottleneck between CPUs and GPUs is bypassed, therefore the sampling speed is significantly improved. 34 | 35 | Installation 36 | --------------------------------------- 37 | 38 | RLSolver generally requires: 39 | 40 | - Python>=3.6 41 | 42 | - PyTorch>=1.0.2 43 | 44 | - gym, matplotlib, numpy, torch 45 | 46 | You can simply install ElegantRL from PyPI with the following command: 47 | 48 | .. code-block:: bash 49 | :linenos: 50 | 51 | pip3 install rlsolver --upgrade 52 | 53 | Or install with the newest version through GitHub: 54 | 55 | .. code-block:: bash 56 | :linenos: 57 | 58 | git clone https://github.com/zhumingpassional/RLSolver 59 | cd RLSolver 60 | pip3 install . 61 | 62 | 63 | .. toctree:: 64 | :maxdepth: 1 65 | :hidden: 66 | 67 | Home 68 | 69 | .. toctree:: 70 | :maxdepth: 1 71 | :caption: HelloWorld 72 | 73 | helloworld/intro 74 | 75 | .. toctree:: 76 | :maxdepth: 1 77 | :caption: Overview 78 | 79 | about/overview 80 | about/cloud 81 | about/parallel 82 | 83 | 84 | .. toctree:: 85 | :maxdepth: 1 86 | :caption: Tutorials 87 | 88 | tutorial/maxcut 89 | 90 | 91 | .. toctree:: 92 | :maxdepth: 1 93 | :caption: Algorithms 94 | 95 | algorithms/REINFORCE 96 | 97 | 98 | .. toctree:: 99 | :maxdepth: 1 100 | :caption: RLSolver 101 | 102 | RLSolver/overview 103 | RLSolver/helloworld 104 | RLSolver/datasets 105 | RLSolver/environments 106 | RLSolver/benchmarks 107 | 108 | 109 | .. toctree:: 110 | :maxdepth: 1 111 | :caption: API Reference 112 | 113 | api/config 114 | api/util 115 | 116 | 117 | .. toctree:: 118 | :maxdepth: 1 119 | :caption: Other 120 | 121 | other/faq 122 | 123 | 124 | Indices and tables 125 | ================== 126 | 127 | * :ref:`genindex` 128 | * :ref:`modindex` 129 | * :ref:`search` 130 | -------------------------------------------------------------------------------- /rlsolver/docs/source/other/faq.rst: -------------------------------------------------------------------------------- 1 | FAQ 2 | ============================= 3 | 4 | :Version: 0.0.1 5 | :Date: 7-11-2024 6 | :Contributors: Ming Zhu 7 | 8 | 9 | 10 | Description 11 | ---------------- 12 | 13 | This document contains the most frequently asked questions related to the RLSolver Library, based on questions posted on the slack channels and Github_ issues. 14 | 15 | .. 
_Github: https://github.com/zhumingpassional/RLSolver 16 | 17 | 18 | Outline 19 | ---------------- 20 | 21 | - :ref:`Section-1` 22 | 23 | -------------------------------------------------------------------------------- /rlsolver/docs/source/tutorial/maxcut.rst: -------------------------------------------------------------------------------- 1 | 2 | 3 | What is maxcut? 4 | ----------------------------------------------- 5 | 6 | -------------------------------------------------------------------------------- /rlsolver/envs/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AI4Finance-Foundation/ElegantRL/8ea76afc3e7f1564ae9f0e69e70254116d575fe9/rlsolver/envs/__init__.py -------------------------------------------------------------------------------- /rlsolver/fig/RLSolver_framework.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AI4Finance-Foundation/ElegantRL/8ea76afc3e7f1564ae9f0e69e70254116d575fe9/rlsolver/fig/RLSolver_framework.png -------------------------------------------------------------------------------- /rlsolver/fig/RLSolver_structure.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AI4Finance-Foundation/ElegantRL/8ea76afc3e7f1564ae9f0e69e70254116d575fe9/rlsolver/fig/RLSolver_structure.png -------------------------------------------------------------------------------- /rlsolver/fig/objectives_epochs.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AI4Finance-Foundation/ElegantRL/8ea76afc3e7f1564ae9f0e69e70254116d575fe9/rlsolver/fig/objectives_epochs.png -------------------------------------------------------------------------------- /rlsolver/fig/parallel_sims_maxcut.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AI4Finance-Foundation/ElegantRL/8ea76afc3e7f1564ae9f0e69e70254116d575fe9/rlsolver/fig/parallel_sims_maxcut.png -------------------------------------------------------------------------------- /rlsolver/fig/parallel_sims_pattern.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AI4Finance-Foundation/ElegantRL/8ea76afc3e7f1564ae9f0e69e70254116d575fe9/rlsolver/fig/parallel_sims_pattern.png -------------------------------------------------------------------------------- /rlsolver/fig/sampling_efficiency_maxcut.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AI4Finance-Foundation/ElegantRL/8ea76afc3e7f1564ae9f0e69e70254116d575fe9/rlsolver/fig/sampling_efficiency_maxcut.png -------------------------------------------------------------------------------- /rlsolver/fig/speed_up_maxcut1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AI4Finance-Foundation/ElegantRL/8ea76afc3e7f1564ae9f0e69e70254116d575fe9/rlsolver/fig/speed_up_maxcut1.png -------------------------------------------------------------------------------- /rlsolver/fig/speed_up_maxcut2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AI4Finance-Foundation/ElegantRL/8ea76afc3e7f1564ae9f0e69e70254116d575fe9/rlsolver/fig/speed_up_maxcut2.png 
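The maxcut tutorial stub above (rlsolver/docs/source/tutorial/maxcut.rst) does not yet define the problem, so a brief note: given a weighted graph, max-cut asks for a partition of the nodes into two sets that maximizes the total weight of edges crossing the partition. The minimal sketch below reads the edge-list format used by the files in rlsolver/data (first line "num_nodes num_edges", then one "u v w" line per edge with 1-indexed nodes; this reading of the format is inferred from the data files themselves) and scores a random partition.

import numpy as np

def read_edge_list(path: str):
    """Read a graph stored as `num_nodes num_edges` followed by `u v w` lines."""
    with open(path) as f:
        num_nodes, num_edges = map(int, f.readline().split())
        edges = [tuple(map(int, line.split())) for line in f if line.strip()]
    assert len(edges) == num_edges
    return num_nodes, edges

def cut_value(partition, edges) -> int:
    """Total weight of edges whose endpoints land in different sets."""
    return sum(w for u, v, w in edges if partition[u - 1] != partition[v - 1])

num_nodes, edges = read_edge_list("rlsolver/data/syn_BA/BA_100_ID0.txt")
partition = np.random.randint(0, 2, size=num_nodes)  # a random 0/1 node assignment
print("cut value of a random partition:", cut_value(partition, edges))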
-------------------------------------------------------------------------------- /rlsolver/fig/work_flow.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AI4Finance-Foundation/ElegantRL/8ea76afc3e7f1564ae9f0e69e70254116d575fe9/rlsolver/fig/work_flow.png -------------------------------------------------------------------------------- /rlsolver/main.py: -------------------------------------------------------------------------------- 1 | 2 | def main(): 3 | print() 4 | 5 | if __name__ == '__main__': 6 | main() -------------------------------------------------------------------------------- /rlsolver/methods/VRPTW_algs/ESPPRC_demo.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import networkx as nx 3 | from time import * 4 | import copy 5 | from typing import List, Tuple, Set, Dict 6 | 7 | class Label: 8 | path = [] 9 | travel_time = 0 10 | dist = 0 11 | 12 | # dominance rule 13 | def dominate(labels: List[Label], path_dict: Dict[int, Label]): 14 | labels_copy = copy.deepcopy(labels) 15 | path_dict_copy = copy.deepcopy(path_dict) 16 | 17 | # dominate Q 18 | for label in labels_copy: 19 | for another_label in labels: 20 | if (label.path[-1] == another_label.path[ 21 | -1] and label.travel_time < another_label.travel_time and label.dist < another_label.dist): 22 | labels.remove(another_label) 23 | print("dominated path (Q) : ", another_label.path) 24 | 25 | # dominate Paths 26 | for key_1 in path_dict_copy.keys(): 27 | for key_2 in path_dict_copy.keys(): 28 | if (path_dict_copy[key_1].path[-1] == path_dict_copy[key_2].path[-1] 29 | and path_dict_copy[key_1].travel_time < path_dict_copy[key_2].travel_time 30 | and path_dict_copy[key_1].dist < path_dict_copy[key_2].dist 31 | and (key_2 in path_dict.keys())): 32 | path_dict.pop(key_2) 33 | print("dominated path (P) : ", path_dict_copy[key_1].path) 34 | 35 | return labels, path_dict 36 | 37 | 38 | 39 | 40 | # labeling algorithm 41 | def labeling_SPPRC(graph, orig, dest): 42 | # initial Q 43 | labels: List[Label] = [] 44 | path_dict: Dict = {} 45 | 46 | # create initial label 47 | label = Label() 48 | label.path = [orig] 49 | label.travel_time = 0 50 | label.dist = 0 51 | labels.append(label) 52 | 53 | count = 0 54 | 55 | while (len(labels) > 0): 56 | count += 1 57 | cur_label = labels.pop() 58 | 59 | # extend the current label 60 | last_node = cur_label.path[-1] 61 | for child in graph.successors(last_node): 62 | extended_label = copy.deepcopy(cur_label) 63 | arc = (last_node, child) 64 | 65 | # check the feasibility 66 | arrive_time = cur_label.travel_time + graph.edges[arc]["travel_time"] 67 | time_window = graph.nodes[child]["time_window"] 68 | if (arrive_time >= time_window[0] and arrive_time <= time_window[1] and last_node != dest): 69 | extended_label.path.append(child) 70 | extended_label.travel_time += graph.edges[arc]["travel_time"] 71 | extended_label.dist += graph.edges[arc]["cost"] 72 | labels.append(extended_label) 73 | 74 | path_dict[count] = cur_label 75 | # apply the dominance rule 76 | labels, path_dict = dominate(labels, path_dict) 77 | 78 | # filtering Paths, only keep feasible solutions 79 | path_dict_copy = copy.deepcopy(path_dict) 80 | for key in path_dict_copy.keys(): 81 | if (path_dict[key].path[-1] != dest): 82 | path_dict.pop(key) 83 | 84 | # choose optimal solution 85 | opt_path = {} 86 | min_dist = 1e6 87 | for key in path_dict.keys(): 88 | if (path_dict[key].dist < min_dist): 89 | min_dist = path_dict[key].dist
90 | opt_path[1] = path_dict[key] 91 | 92 | return graph, labels, path_dict, opt_path 93 | 94 | 95 | def main(): 96 | # nodes contain a time-window attribute 97 | Nodes = {'s': (0, 0) 98 | , '1': (6, 14) 99 | , '2': (9, 12) 100 | , '3': (8, 12) 101 | , 't': (9, 15) 102 | } 103 | # arc attributes include travel_time and dist 104 | Arcs = {('s', '1'): (8, 3) 105 | , ('s', '2'): (5, 5) 106 | , ('s', '3'): (12, 2) 107 | , ('1', 't'): (4, 7) 108 | , ('2', 't'): (2, 6) 109 | , ('3', 't'): (4, 3) 110 | } 111 | 112 | # create the directed Graph 113 | graph = nx.DiGraph() 114 | cnt = 0 115 | # add nodes into the graph 116 | for name in Nodes.keys(): 117 | cnt += 1 118 | graph.add_node(name 119 | , time_window=(Nodes[name][0], Nodes[name][1]) 120 | , min_dist=0 121 | ) 122 | # add edges into the graph 123 | for key in Arcs.keys(): 124 | graph.add_edge(key[0], key[1] 125 | , travel_time=Arcs[key][0] 126 | , cost=Arcs[key][1] 127 | ) 128 | 129 | org = 's' 130 | des = 't' 131 | begin_time = time() 132 | graph, labels, path_dict, opt_path = labeling_SPPRC(graph, org, des) 133 | end_time = time() 134 | print("computation time: ", end_time - begin_time) 135 | print('optimal path : ', opt_path[1].path) 136 | print('optimal path (dist): ', opt_path[1].dist) 137 | 138 | if __name__ == '__main__': 139 | main() 140 | 141 | -------------------------------------------------------------------------------- /rlsolver/methods/VRPTW_algs/main.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | cur_path = os.path.dirname(os.path.abspath(__file__)) 4 | rlsolver_path = os.path.join(cur_path, '../../../rlsolver') 5 | sys.path.append(os.path.dirname(rlsolver_path)) 6 | 7 | from rlsolver.methods.VRPTW_algs.config import (Config, Alg) 8 | from rlsolver.methods.VRPTW_algs.impact_heuristic import run_impact_heuristic 9 | from rlsolver.methods.VRPTW_algs.column_generation import run_column_generation 10 | def main(): 11 | if Config.ALG == Alg.impact_heuristic: 12 | run_impact_heuristic() 13 | elif Config.ALG == Alg.column_generation: 14 | run_column_generation() 15 | 16 | pass 17 | 18 | 19 | if __name__ == '__main__': 20 | main() -------------------------------------------------------------------------------- /rlsolver/methods/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AI4Finance-Foundation/ElegantRL/8ea76afc3e7f1564ae9f0e69e70254116d575fe9/rlsolver/methods/__init__.py -------------------------------------------------------------------------------- /rlsolver/methods/config.py: -------------------------------------------------------------------------------- 1 | import torch as th 2 | from typing import List 3 | from enum import Enum, unique 4 | # from L2A.graph_utils import GraphList, obtain_num_nodes 5 | import os 6 | 7 | @unique 8 | class Problem(Enum): 9 | maxcut = "maxcut" 10 | graph_partitioning = "graph_partitioning" 11 | minimum_vertex_cover = "minimum_vertex_cover" 12 | number_partitioning = "number_partitioning" 13 | bilp = "bilp" 14 | maximum_independent_set = "maximum_independent_set" 15 | knapsack = "knapsack" 16 | set_cover = "set_cover" 17 | graph_coloring = "graph_coloring" 18 | tsp = "tsp" 19 | PROBLEM = Problem.maxcut 20 | 21 | @unique 22 | class GraphType(Enum): 23 | BA: str = "BA" # "barabasi_albert" 24 | ER: str = "ER" # "erdos_renyi" 25 | PL: str = "PL" # "powerlaw" 26 | 27 | def calc_device(gpu_id: int): 28 | return th.device(f"cuda:{gpu_id}" if th.cuda.is_available() and gpu_id >= 0 else "cpu") 29 | 30 | GPU_ID: int
= 0 # -1: cpu, >=0: gpu 31 | 32 | DATA_FILENAME = "../data/syn_BA/BA_100_ID0.txt" # one instance 33 | DIRECTORY_DATA = "../data/syn_BA" # used in multi instances 34 | PREFIXES = ["BA_100_ID0"] # used in multi instances 35 | 36 | DEVICE: th.device = calc_device(GPU_ID) 37 | 38 | GRAPH_TYPE = GraphType.PL 39 | GRAPH_TYPES: List[GraphType] = [GraphType.ER, GraphType.PL, GraphType.BA] 40 | # graph_types = ["erdos_renyi", "powerlaw", "barabasi_albert"] 41 | NUM_IDS = 30 # ID0, ..., ID29 42 | 43 | 44 | 45 | INF = 1e6 46 | 47 | # RUNNING_DURATIONS = [600, 1200, 1800, 2400, 3000, 3600] # store results 48 | RUNNING_DURATIONS = [300, 600, 900, 1200, 1500, 1800, 2100, 2400, 2700, 3000, 3300, 3600] # store results 49 | 50 | # None: write results when finished. 51 | # others: write results in mycallback. seconds, the interval of writing results to txt files 52 | GUROBI_INTERVAL = None # None: not using interval, i.e., not using mycallback function, write results when finished. If not None such as 100, write result every 100s 53 | GUROBI_TIME_LIMITS = [1 * 3600] # seconds 54 | # GUROBI_TIME_LIMITS = [600, 1200, 1800, 2400, 3000, 3600] # seconds 55 | # GUROBI_TIME_LIMITS2 = list(range(10 * 60, 1 * 3600 + 1, 10 * 60)) # seconds 56 | GUROBI_VAR_CONTINUOUS = False # True: relax it to LP, and return x values. False: sovle the primal MILP problem 57 | GUROBI_MILP_QUBO = 1 # 0: MILP, 1: QUBO. default: QUBO, since using QUBO is generally better than MILP. 58 | assert GUROBI_MILP_QUBO in [0, 1] 59 | 60 | 61 | 62 | -------------------------------------------------------------------------------- /rlsolver/methods/quantum.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | cur_path = os.path.dirname(os.path.abspath(__file__)) 4 | rlsolver_path = os.path.join(cur_path, '../../rlsolver') 5 | sys.path.append(os.path.dirname(rlsolver_path)) 6 | 7 | import networkx as nx 8 | import matplotlib.pyplot as plt 9 | import numpy as np 10 | from qiskit.circuit.library import TwoLocal 11 | from qiskit_optimization.applications import Maxcut, Tsp 12 | from qiskit.algorithms.minimum_eigensolvers import SamplingVQE, NumPyMinimumEigensolver 13 | from qiskit.algorithms.optimizers import SPSA 14 | from qiskit.utils import algorithm_globals 15 | from qiskit.primitives import Sampler 16 | from qiskit_optimization.algorithms import MinimumEigenOptimizer 17 | 18 | def draw_graph(G, colors, pos): 19 | default_axes = plt.axes(frameon=True) 20 | nx.draw_networkx(G, node_color=colors, node_size=400, alpha=0.8, ax=default_axes, pos=pos) 21 | edge_labels = nx.get_edge_attributes(G, "weight") 22 | nx.draw_networkx_edge_labels(G, pos=pos, edge_labels=edge_labels) 23 | plt.show() 24 | 25 | #CLASSICAL 26 | n = 5 27 | G = nx.Graph() 28 | G.add_nodes_from(np.arange(0,4,1)) 29 | 30 | 31 | edges = [(1,2),(1,3),(2,4),(3,4),(3,0),(4,0)] 32 | #edges = [(0,1),(1,2),(2,3),(3,4)]#[(1,2),(2,3),(3,4),(4,5)] 33 | G.add_edges_from(edges) 34 | 35 | #colors = ["g" for node in G.nodes()] 36 | pos = nx.spring_layout(G) 37 | #draw_graph(G, colors, pos) 38 | w = np.zeros([n, n]) 39 | for i in range(n): 40 | for j in range(n): 41 | temp = G.get_edge_data(i, j, default=0) 42 | if temp != 0: 43 | w[i, j] = 1 44 | 45 | 46 | #QUANTUM 47 | 48 | max_cut = Maxcut(w) 49 | qp = max_cut.to_quadratic_program() 50 | print(qp.prettyprint()) 51 | 52 | qubitOp, offset = qp.to_ising() 53 | print("Offset:", offset) 54 | print("Ising Hamiltonian:") 55 | print(str(qubitOp)) 56 | 57 | # solving Quadratic Program using 
exact classical eigensolver 58 | exact = MinimumEigenOptimizer(NumPyMinimumEigensolver()) 59 | result = exact.solve(qp) 60 | print("eigensolver:", result.prettyprint()) 61 | 62 | # Making the Hamiltonian in its full form and getting the lowest eigenvalue and eigenvector 63 | ee = NumPyMinimumEigensolver() 64 | result = ee.compute_minimum_eigenvalue(qubitOp) 65 | 66 | x = max_cut.sample_most_likely(result.eigenstate) 67 | print("energy:", result.eigenvalue.real) 68 | print("max-cut objective:", result.eigenvalue.real + offset) 69 | print("solution:", x) 70 | print("solution objective:", qp.objective.evaluate(x)) 71 | 72 | colors = ["r" if x[i] == 0 else "c" for i in range(n)] 73 | draw_graph(G, colors, pos) 74 | 75 | algorithm_globals.random_seed = 123 76 | seed = 10598 77 | 78 | # construct SamplingVQE 79 | optimizer = SPSA(maxiter=300) 80 | ry = TwoLocal(qubitOp.num_qubits, "ry", "cz", reps=5, entanglement="linear") 81 | vqe = SamplingVQE(sampler=Sampler(), ansatz=ry, optimizer=optimizer) 82 | 83 | # run SamplingVQE 84 | result = vqe.compute_minimum_eigenvalue(qubitOp) 85 | 86 | # print results 87 | x = max_cut.sample_most_likely(result.eigenstate) 88 | print("energy:", result.eigenvalue.real) 89 | print("time:", result.optimizer_time) 90 | print("max-cut objective:", result.eigenvalue.real + offset) 91 | print("solution:", x) 92 | print("solution objective:", qp.objective.evaluate(x)) 93 | 94 | # plot results 95 | colors = ["r" if x[i] == 0 else "c" for i in range(n)] 96 | draw_graph(G, colors, pos) 97 | 98 | # create minimum eigen optimizer based on SamplingVQE 99 | vqe_optimizer = MinimumEigenOptimizer(vqe) 100 | 101 | # solve quadratic program 102 | result = vqe_optimizer.solve(qp) 103 | print(result.prettyprint()) 104 | 105 | colors = ["r" if result.solution[i] == 0 else "c" for i in range(n)] 106 | draw_graph(G, colors, pos) 107 | -------------------------------------------------------------------------------- /rlsolver/methods/random_walk.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | cur_path = os.path.dirname(os.path.abspath(__file__)) 4 | rlsolver_path = os.path.join(cur_path, '../../rlsolver') 5 | sys.path.append(os.path.dirname(rlsolver_path)) 6 | 7 | import copy 8 | import time 9 | import networkx as nx 10 | import numpy as np 11 | from typing import List, Union 12 | import random 13 | from rlsolver.methods.util_read_data import read_nxgraph 14 | from rlsolver.methods.util_obj import obj_maxcut 15 | from rlsolver.methods.util_result import write_graph_result 16 | from rlsolver.methods.util import plot_fig 17 | 18 | import sys 19 | sys.path.append('../') 20 | 21 | def random_walk_maxcut(init_solution: Union[List[int], np.array], num_steps: int, graph: nx.Graph) -> (int, Union[List[int], np.array], List[int]): 22 | print('random_walk') 23 | start_time = time.time() 24 | curr_solution = copy.deepcopy(init_solution) 25 | init_score = obj_maxcut(init_solution, graph) 26 | num_nodes = len(curr_solution) 27 | scores = [] 28 | for i in range(num_steps): 29 | # select a node randomly 30 | node = random.randint(0, num_nodes - 1) 31 | curr_solution[node] = (curr_solution[node] + 1) % 2 32 | # calc the obj 33 | score = obj_maxcut(curr_solution, graph) 34 | scores.append(score) 35 | print("score, init_score of random_walk", score, init_score) 36 | print("scores: ", scores) 37 | print("solution: ", curr_solution) 38 | running_duration = time.time() - start_time 39 | print('running_duration: ', running_duration) 40 | 
return score, curr_solution, scores 41 | 42 | 43 | if __name__ == '__main__': 44 | # read data 45 | # graph1 = read_as_networkx_graph('data/gset_14.txt') 46 | start_time = time.time() 47 | filename = '../data/syn_BA/BA_100_ID0.txt' 48 | graph = read_nxgraph(filename) 49 | 50 | # run alg 51 | # init_solution = [1, 0, 1, 0, 1] 52 | init_solution = list(np.random.randint(0, 2, graph.number_of_nodes())) 53 | rw_score, rw_solution, rw_scores = random_walk_maxcut(init_solution=init_solution, num_steps=1000, graph=graph) 54 | running_duration = time.time() - start_time 55 | num_nodes = graph.number_of_nodes 56 | alg_name = "random_walk" 57 | # write result 58 | write_graph_result(rw_score, running_duration, num_nodes, alg_name, rw_solution, filename) 59 | # write_result(rw_solution, '../result/result.txt') 60 | obj = obj_maxcut(rw_solution, graph) 61 | print('obj: ', obj) 62 | alg_name = 'RW' 63 | 64 | # plot fig 65 | if_plot = False 66 | if if_plot: 67 | plot_fig(rw_scores, alg_name) 68 | 69 | 70 | -------------------------------------------------------------------------------- /rlsolver/methods/sdp.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | 4 | cur_path = os.path.dirname(os.path.abspath(__file__)) 5 | rlsolver_path = os.path.join(cur_path, '../../rlsolver') 6 | sys.path.append(os.path.dirname(rlsolver_path)) 7 | 8 | import cvxpy as cp 9 | import networkx as nx 10 | import matplotlib.pyplot as plt 11 | import numpy as np 12 | import scipy.linalg 13 | import os 14 | import time 15 | from typing import List 16 | 17 | from rlsolver.methods.util_obj import obj_maxcut 18 | from rlsolver.methods.util_read_data import read_nxgraph 19 | from rlsolver.methods.util import (calc_txt_files_with_prefixes, 20 | ) 21 | from rlsolver.methods.util_result import (write_graph_result, 22 | ) 23 | 24 | os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE" 25 | 26 | 27 | # approx ratio 1/0.87 28 | # goemans_williamson alg 29 | def sdp_maxcut(filename: str): 30 | graph = read_nxgraph(filename) 31 | n = graph.number_of_nodes() # num of nodes 32 | edges = graph.edges 33 | 34 | x = cp.Variable((n, n), symmetric=True) # construct n x n matrix 35 | 36 | # diagonals must be 1 (unit) and eigenvalues must be postivie 37 | # semidefinite 38 | constraints = [x >> 0] + [x[i, i] == 1 for i in range(n)] 39 | 40 | # this is function defing the cost of the cut. 
You want to maximize this function 41 | # to get heaviest cut 42 | objective = sum((0.5) * (1 - x[i, j]) for (i, j) in edges) 43 | 44 | # solves semidefinite program, optimizes linear cost function 45 | prob = cp.Problem(cp.Maximize(objective), constraints) 46 | prob.solve() 47 | 48 | # normalizes matrix, makes it applicable in unit sphere 49 | sqrtProb = scipy.linalg.sqrtm(x.value) 50 | 51 | # generates random hyperplane used to split set of points into two disjoint sets of nodes 52 | hyperplane = np.random.randn(n) 53 | 54 | # gives value -1 if on one side of plane and 1 if on other 55 | # returned as a array 56 | sqrtProb = np.sign(sqrtProb @ hyperplane) 57 | # print(sqrtProb) 58 | 59 | colors = ["r" if sqrtProb[i] == -1 else "c" for i in range(n)] 60 | solution = [0 if sqrtProb[i] == -1 else 1 for i in range(n)] 61 | 62 | pos = nx.spring_layout(graph) 63 | # draw_graph(graph, colors, pos) 64 | score = obj_maxcut(solution, graph) 65 | print("obj: ", score, ",solution = " + str(solution)) 66 | return score, solution 67 | 68 | 69 | def run_sdp_over_multiple_files(alg, alg_name, directory_data: str, prefixes: List[str]) -> List[List[float]]: 70 | scores = [] 71 | files = calc_txt_files_with_prefixes(directory_data, prefixes) 72 | for i in range(len(files)): 73 | start_time = time.time() 74 | filename = files[i] 75 | print(f'Start the {i}-th file: {filename}') 76 | score, solution = alg(filename) 77 | scores.append(score) 78 | print(f"score: {score}") 79 | running_duration = time.time() - start_time 80 | graph = read_nxgraph(filename) 81 | num_nodes = int(graph.number_of_nodes()) 82 | write_graph_result(score, running_duration, num_nodes, alg_name, solution, filename) 83 | return scores 84 | 85 | 86 | if __name__ == '__main__': 87 | # n = 5 88 | # graph = nx.Graph() 89 | # graph.add_nodes_from(np.arange(0, 4, 1)) 90 | # 91 | # edges = [(1, 2), (1, 3), (2, 4), (3, 4), (3, 0), (4, 0)] 92 | # # edges = [(0,1),(1,2),(2,3),(3,4)]#[(1,2),(2,3),(3,4),(4,5)] 93 | # graph.add_edges_from(edges) 94 | 95 | # filename = '../data/gset/gset_14.txt' 96 | run_single_file = False 97 | if run_single_file: 98 | filename = '../data/syn_BA/BA_100_ID0.txt' 99 | sdp_maxcut(filename) 100 | 101 | run_multi_files = True 102 | if run_multi_files: 103 | alg = sdp_maxcut 104 | alg_name = 'sdp' 105 | directory_data = '../data/syn_BA' 106 | prefixes = ['barabasi_albert_100'] 107 | scores = run_sdp_over_multiple_files(alg, alg_name, directory_data, prefixes) 108 | print(f"scores: {scores}") 109 | -------------------------------------------------------------------------------- /rlsolver/methods/tsp_algs/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AI4Finance-Foundation/ElegantRL/8ea76afc3e7f1564ae9f0e69e70254116d575fe9/rlsolver/methods/tsp_algs/__init__.py -------------------------------------------------------------------------------- /rlsolver/methods/tsp_algs/config.py: -------------------------------------------------------------------------------- 1 | from enum import Enum 2 | 3 | 4 | class Alg(Enum): 5 | local_search_2_opt = 'local_search_2_opt' 6 | local_search_3_opt = 'local_search_3_opt' 7 | cheapest_insertion = 'cheapest_insertion' 8 | christofides_algorithm = 'christofides_algorithm' 9 | farthest_insertion = 'farthest_insertion' 10 | genetic_algorithm = 'genetic_algorithm' 11 | nearest_insertion = 'nearest_insertion' 12 | nearest_neighbour = 'nearest_neighbour' 13 | simulated_annealing = 'simulated_annealing' 14 | tabu_search = 
'tabu_search' 15 | greedy_karp_steele_patching = 'greedy_karp_steele_patching' 16 | lkh = 'lkh' 17 | 18 | 19 | ALG = Alg.local_search_2_opt 20 | 21 | # Parameters 22 | if ALG == Alg.local_search_2_opt: 23 | PARAMETERS = { 24 | 'recursive_seeding': -1, # Total Number of Iterations. If This Value is Negative Then the Algorithm Only Stops When Convergence is Reached 25 | 'verbose': True 26 | } 27 | elif ALG == Alg.local_search_3_opt: 28 | PARAMETERS = { 29 | 'recursive_seeding': -1, # Total Number of Iterations. If This Value is Negative Then the Algorithm Only Stops When Convergence is Reached 30 | 'verbose': True 31 | } 32 | elif ALG == Alg.cheapest_insertion: 33 | PARAMETERS = { 34 | 'verbose': True 35 | } 36 | elif ALG == Alg.farthest_insertion: 37 | PARAMETERS = { 38 | 'initial_location': -1, # -1 = Try All Locations. 39 | 'verbose': True 40 | } 41 | elif ALG == Alg.genetic_algorithm: 42 | PARAMETERS = { 43 | 'population_size': 15, 44 | 'elite': 1, 45 | 'mutation_rate': 0.1, 46 | 'mutation_search': 8, 47 | 'generations': 1000, 48 | 'verbose': True 49 | } 50 | elif ALG == Alg.greedy_karp_steele_patching: 51 | PARAMETERS = { 52 | 'verbose': True 53 | } 54 | elif ALG == Alg.lkh: 55 | PARAMETERS = { 56 | 'max_trials': 10000, 57 | 'runs': 10 58 | } 59 | elif ALG == Alg.nearest_insertion: 60 | PARAMETERS = { 61 | 'initial_location': -1, # -1 = Try All Locations. 62 | 'verbose': True 63 | } 64 | elif ALG == Alg.nearest_neighbour: 65 | PARAMETERS = { 66 | 'initial_location': -1, # -1 = Try All Locations. 67 | 'local_search': True, 68 | 'verbose': True 69 | } 70 | elif ALG == Alg.simulated_annealing: 71 | PARAMETERS = { 72 | 'initial_temperature': 1.0, 73 | 'temperature_iterations': 10, 74 | 'final_temperature': 0.0001, 75 | 'alpha': 0.9, 76 | 'verbose': True 77 | } 78 | elif ALG == Alg.tabu_search: 79 | PARAMETERS = { 80 | 'iterations': 500, 81 | 'tabu_tenure': 75, 82 | 'verbose': True 83 | } 84 | 85 | 86 | -------------------------------------------------------------------------------- /rlsolver/methods/tsp_algs/ins_c.py: -------------------------------------------------------------------------------- 1 | ############################################################################ 2 | 3 | 4 | 5 | # Lesson: Cheapest Insertion 6 | 7 | 8 | 9 | ############################################################################ 10 | 11 | # Required Libraries 12 | import copy 13 | import numpy as np 14 | 15 | ############################################################################ 16 | 17 | # Function: Tour Distance 18 | def distance_calc(distance_matrix, city_tour): 19 | distance = 0 20 | for k in range(0, len(city_tour[0])-1): 21 | m = k + 1 22 | distance = distance + distance_matrix[city_tour[0][k]-1, city_tour[0][m]-1] 23 | return distance 24 | 25 | # Function: 2_opt 26 | def local_search_2_opt(distance_matrix, city_tour, recursive_seeding = -1, verbose = True): 27 | if (recursive_seeding < 0): 28 | count = -2 29 | else: 30 | count = 0 31 | city_list = copy.deepcopy(city_tour) 32 | distance = city_list[1]*2 33 | iteration = 0 34 | while (count < recursive_seeding): 35 | if (verbose == True): 36 | print('Iteration = ', iteration, 'Distance = ', round(city_list[1], 2)) 37 | best_route = copy.deepcopy(city_list) 38 | seed = copy.deepcopy(city_list) 39 | for i in range(0, len(city_list[0]) - 2): 40 | for j in range(i+1, len(city_list[0]) - 1): 41 | best_route[0][i:j+1] = list(reversed(best_route[0][i:j+1])) 42 | best_route[0][-1] = best_route[0][0] 43 | best_route[1] = distance_calc(distance_matrix, 
best_route) 44 | if (city_list[1] > best_route[1]): 45 | city_list = copy.deepcopy(best_route) 46 | best_route = copy.deepcopy(seed) 47 | count = count + 1 48 | iteration = iteration + 1 49 | if (distance > city_list[1] and recursive_seeding < 0): 50 | distance = city_list[1] 51 | count = -2 52 | recursive_seeding = -1 53 | elif(city_list[1] >= distance and recursive_seeding < 0): 54 | count = -1 55 | recursive_seeding = -2 56 | return city_list[0], city_list[1] 57 | 58 | ############################################################################ 59 | 60 | # Function: Cheapest Insertion 61 | def cheapest_insertion(distance_matrix, verbose = True): 62 | route = [] 63 | temp = [] 64 | i, idx = np.unravel_index(np.argmax(distance_matrix, axis = None), distance_matrix.shape) 65 | temp.append(i) 66 | temp.append(idx) 67 | count = 0 68 | while (len(temp) < distance_matrix.shape[0]): 69 | temp_ = [item+1 for item in temp] 70 | temp_ = temp_ + [temp_[0]] 71 | d = distance_calc(distance_matrix, [temp_, 1]) 72 | seed = [temp_, d] 73 | temp_, _ = local_search_2_opt(distance_matrix, seed, recursive_seeding = -1, verbose = False) 74 | temp = [item-1 for item in temp_[:-1]] 75 | idx = [i for i in range(0, distance_matrix.shape[0]) if i not in temp] 76 | best_d = [] 77 | best_r = [] 78 | for i in idx: 79 | temp_ = [item for item in temp] 80 | temp_.append(i) 81 | temp_ = [item+1 for item in temp_] 82 | temp_ = temp_ + [temp_[0]] 83 | d = distance_calc(distance_matrix, [temp_, 1]) 84 | seed = [temp_, d] 85 | temp_, d = local_search_2_opt(distance_matrix, seed, recursive_seeding = -1, verbose = False) 86 | temp_ = [item-1 for item in temp_[:-1]] 87 | best_d.append(d) 88 | best_r.append(temp_) 89 | temp = [item for item in best_r[best_d.index(min(best_d))]] 90 | if (verbose == True): 91 | print('Iteration = ', count) 92 | count = count + 1 93 | route = temp + [temp[0]] 94 | route = [item + 1 for item in route] 95 | distance = distance_calc(distance_matrix, [route, 1]) 96 | print("distance: ", distance) 97 | return route, distance 98 | 99 | ############################################################################ -------------------------------------------------------------------------------- /rlsolver/methods/tsp_algs/ins_f.py: -------------------------------------------------------------------------------- 1 | ############################################################################ 2 | 3 | 4 | 5 | # Lesson: Farthest Insertion 6 | 7 | 8 | 9 | ############################################################################ 10 | 11 | # Required Libraries 12 | import copy 13 | import numpy as np 14 | 15 | ############################################################################ 16 | 17 | # Function: Tour Distance 18 | def distance_calc(distance_matrix, city_tour): 19 | distance = 0 20 | for k in range(0, len(city_tour[0])-1): 21 | m = k + 1 22 | distance = distance + distance_matrix[city_tour[0][k]-1, city_tour[0][m]-1] 23 | return distance 24 | 25 | # Function: 2_opt 26 | def local_search_2_opt(distance_matrix, city_tour, recursive_seeding = -1, verbose = True): 27 | if (recursive_seeding < 0): 28 | count = -2 29 | else: 30 | count = 0 31 | city_list = copy.deepcopy(city_tour) 32 | distance = city_list[1]*2 33 | iteration = 0 34 | while (count < recursive_seeding): 35 | if (verbose == True): 36 | print('Iteration = ', iteration, 'Distance = ', round(city_list[1], 2)) 37 | best_route = copy.deepcopy(city_list) 38 | seed = copy.deepcopy(city_list) 39 | for i in range(0, len(city_list[0]) - 2): 40 | for j 
in range(i+1, len(city_list[0]) - 1): 41 | best_route[0][i:j+1] = list(reversed(best_route[0][i:j+1])) 42 | best_route[0][-1] = best_route[0][0] 43 | best_route[1] = distance_calc(distance_matrix, best_route) 44 | if (city_list[1] > best_route[1]): 45 | city_list = copy.deepcopy(best_route) 46 | best_route = copy.deepcopy(seed) 47 | count = count + 1 48 | iteration = iteration + 1 49 | if (distance > city_list[1] and recursive_seeding < 0): 50 | distance = city_list[1] 51 | count = -2 52 | recursive_seeding = -1 53 | elif(city_list[1] >= distance and recursive_seeding < 0): 54 | count = -1 55 | recursive_seeding = -2 56 | return city_list[0], city_list[1] 57 | 58 | ############################################################################ 59 | 60 | # Function: Best Insertion 61 | def best_insertion(distance_matrix, temp): 62 | temp_ = [item+1 for item in temp] 63 | temp_ = temp_ + [temp_[0]] 64 | d = distance_calc(distance_matrix, [temp_, 1]) 65 | seed = [temp_, d] 66 | temp_, _ = local_search_2_opt(distance_matrix, seed, recursive_seeding = -1, verbose = False) 67 | temp = [item-1 for item in temp_[:-1]] 68 | return temp 69 | 70 | ############################################################################ 71 | 72 | # Function: Farthest Insertion 73 | def farthest_insertion(distance_matrix, initial_location = -1, verbose = True): 74 | maximum = float('+inf') 75 | distance = float('+inf') 76 | route = [] 77 | for i1 in range(0, distance_matrix.shape[0]): 78 | if (initial_location != -1): 79 | i1 = initial_location-1 80 | temp = [] 81 | dist = np.copy(distance_matrix) 82 | dist = dist.astype(float) 83 | np.fill_diagonal(dist, float('-inf')) 84 | idx = dist[i1,:].argmax() 85 | dist[i1,:] = float('-inf') 86 | dist[:,i1] = float('-inf') 87 | temp.append(i1) 88 | temp.append(idx) 89 | for j in range(0, distance_matrix.shape[0]-2): 90 | i2 = idx 91 | idx = dist[i2,:].argmax() 92 | dist[i2,:] = float('-inf') 93 | dist[:,i2] = float('-inf') 94 | temp.append(idx) 95 | temp = best_insertion(distance_matrix, temp) 96 | temp = temp + [temp[0]] 97 | temp = [item + 1 for item in temp] 98 | val = distance_calc(distance_matrix, [temp, 1]) 99 | if (val < maximum): 100 | maximum = val 101 | distance = val 102 | route = [item for item in temp] 103 | if (verbose == True): 104 | print('Iteration = ', i1, 'Distance = ', round(distance, 2)) 105 | if (initial_location == -1): 106 | continue 107 | else: 108 | break 109 | return route, distance 110 | 111 | ############################################################################ 112 | -------------------------------------------------------------------------------- /rlsolver/methods/tsp_algs/ins_n.py: -------------------------------------------------------------------------------- 1 | ############################################################################ 2 | 3 | 4 | 5 | # Lesson: Nearest Insertion 6 | 7 | 8 | 9 | ############################################################################ 10 | 11 | # Required Libraries 12 | import copy 13 | import numpy as np 14 | 15 | ############################################################################ 16 | 17 | # Function: Tour Distance 18 | def distance_calc(distance_matrix, city_tour): 19 | distance = 0 20 | for k in range(0, len(city_tour[0])-1): 21 | m = k + 1 22 | distance = distance + distance_matrix[city_tour[0][k]-1, city_tour[0][m]-1] 23 | return distance 24 | 25 | # Function: 2_opt 26 | def local_search_2_opt(distance_matrix, city_tour, recursive_seeding = -1, verbose = True): 27 | if 
(recursive_seeding < 0): 28 | count = -2 29 | else: 30 | count = 0 31 | city_list = copy.deepcopy(city_tour) 32 | distance = city_list[1]*2 33 | iteration = 0 34 | while (count < recursive_seeding): 35 | if (verbose == True): 36 | print('Iteration = ', iteration, 'Distance = ', round(city_list[1], 2)) 37 | best_route = copy.deepcopy(city_list) 38 | seed = copy.deepcopy(city_list) 39 | for i in range(0, len(city_list[0]) - 2): 40 | for j in range(i+1, len(city_list[0]) - 1): 41 | best_route[0][i:j+1] = list(reversed(best_route[0][i:j+1])) 42 | best_route[0][-1] = best_route[0][0] 43 | best_route[1] = distance_calc(distance_matrix, best_route) 44 | if (city_list[1] > best_route[1]): 45 | city_list = copy.deepcopy(best_route) 46 | best_route = copy.deepcopy(seed) 47 | count = count + 1 48 | iteration = iteration + 1 49 | if (distance > city_list[1] and recursive_seeding < 0): 50 | distance = city_list[1] 51 | count = -2 52 | recursive_seeding = -1 53 | elif(city_list[1] >= distance and recursive_seeding < 0): 54 | count = -1 55 | recursive_seeding = -2 56 | return city_list[0], city_list[1] 57 | 58 | ############################################################################ 59 | 60 | # Function: Best Insertion 61 | def best_insertion(distance_matrix, temp): 62 | temp_ = [item+1 for item in temp] 63 | temp_ = temp_ + [temp_[0]] 64 | d = distance_calc(distance_matrix, [temp_, 1]) 65 | seed = [temp_, d] 66 | temp_, _ = local_search_2_opt(distance_matrix, seed, recursive_seeding = -1, verbose = False) 67 | temp = [item-1 for item in temp_[:-1]] 68 | return temp 69 | 70 | ############################################################################ 71 | 72 | # Function: Nearest Insertion 73 | def nearest_insertion(distance_matrix, initial_location = -1, verbose = True): 74 | minimum = float('+inf') 75 | distance = float('+inf') 76 | route = [] 77 | for i1 in range(0, distance_matrix.shape[0]): 78 | if (initial_location != -1): 79 | i1 = initial_location-1 80 | temp = [] 81 | dist = np.copy(distance_matrix) 82 | dist = dist.astype(float) 83 | np.fill_diagonal(dist, float('+inf')) 84 | idx = dist[i1,:].argmin() 85 | dist[i1,:] = float('+inf') 86 | dist[:,i1] = float('+inf') 87 | temp.append(i1) 88 | temp.append(idx) 89 | for j in range(0, distance_matrix.shape[0]-2): 90 | i2 = idx 91 | idx = dist[i2,:].argmin() 92 | dist[i2,:] = float('+inf') 93 | dist[:,i2] = float('+inf') 94 | temp.append(idx) 95 | temp = best_insertion(distance_matrix, temp) 96 | temp = temp + [temp[0]] 97 | temp = [item + 1 for item in temp] 98 | val = distance_calc(distance_matrix, [temp, 1]) 99 | if (val < minimum): 100 | minimum = val 101 | distance = val 102 | route = [item for item in temp] 103 | if (verbose == True): 104 | print('Iteration = ', i1, 'Distance = ', round(distance, 2)) 105 | if (initial_location == -1): 106 | continue 107 | else: 108 | break 109 | return route, distance 110 | 111 | ############################################################################ 112 | -------------------------------------------------------------------------------- /rlsolver/methods/tsp_algs/nn.py: -------------------------------------------------------------------------------- 1 | ############################################################################ 2 | 3 | 4 | 5 | # Lesson: Nearest Neighbour 6 | 7 | 8 | 9 | ############################################################################ 10 | 11 | # Required Libraries 12 | import copy 13 | import numpy as np 14 | 15 | 
############################################################################ 16 | 17 | # Function: Tour Distance 18 | def distance_calc(distance_matrix, city_tour): 19 | distance = 0 20 | for k in range(0, len(city_tour[0])-1): 21 | m = k + 1 22 | distance = distance + distance_matrix[city_tour[0][k]-1, city_tour[0][m]-1] 23 | return distance 24 | 25 | # Function: 2_opt 26 | def local_search_2_opt(distance_matrix, city_tour, recursive_seeding = -1, verbose = True): 27 | if (recursive_seeding < 0): 28 | count = -2 29 | else: 30 | count = 0 31 | city_list = copy.deepcopy(city_tour) 32 | distance = city_list[1]*2 33 | iteration = 0 34 | while (count < recursive_seeding): 35 | if (verbose == True): 36 | print('Iteration = ', iteration, 'Distance = ', round(city_list[1], 2)) 37 | best_route = copy.deepcopy(city_list) 38 | seed = copy.deepcopy(city_list) 39 | for i in range(0, len(city_list[0]) - 2): 40 | for j in range(i+1, len(city_list[0]) - 1): 41 | best_route[0][i:j+1] = list(reversed(best_route[0][i:j+1])) 42 | best_route[0][-1] = best_route[0][0] 43 | best_route[1] = distance_calc(distance_matrix, best_route) 44 | if (city_list[1] > best_route[1]): 45 | city_list = copy.deepcopy(best_route) 46 | best_route = copy.deepcopy(seed) 47 | count = count + 1 48 | iteration = iteration + 1 49 | if (distance > city_list[1] and recursive_seeding < 0): 50 | distance = city_list[1] 51 | count = -2 52 | recursive_seeding = -1 53 | elif(city_list[1] >= distance and recursive_seeding < 0): 54 | count = -1 55 | recursive_seeding = -2 56 | return city_list[0], city_list[1] 57 | 58 | ############################################################################ 59 | 60 | # Function: Nearest Neighbour 61 | def nearest_neighbour(distance_matrix, initial_location = -1, local_search = True, verbose = True): 62 | minimum = float('+inf') 63 | distance = float('+inf') 64 | route = [] 65 | for i1 in range(0, distance_matrix.shape[0]): 66 | if (initial_location != -1): 67 | i1 = initial_location-1 68 | temp = [] 69 | dist = np.copy(distance_matrix) 70 | dist = dist.astype(float) 71 | np.fill_diagonal(dist, float('+inf')) 72 | idx = dist[i1,:].argmin() 73 | dist[i1,:] = float('+inf') 74 | dist[:,i1] = float('+inf') 75 | temp.append(i1) 76 | temp.append(idx) 77 | for _ in range(0, distance_matrix.shape[0]-2): 78 | i2 = idx 79 | idx = dist[i2,:].argmin() 80 | dist[i2,:] = float('+inf') 81 | dist[:,i2] = float('+inf') 82 | temp.append(idx) 83 | temp = temp + [temp[0]] 84 | temp = [item + 1 for item in temp] 85 | val = distance_calc(distance_matrix, [temp, 1]) 86 | if (local_search == True): 87 | temp, val = local_search_2_opt(distance_matrix, [temp, val], recursive_seeding = -1, verbose = False) 88 | if (val < minimum): 89 | minimum = val 90 | distance = val 91 | route = [item for item in temp] 92 | if (verbose == True): 93 | print('Iteration = ', i1, 'Distance = ', round(distance, 2)) 94 | if (initial_location == -1): 95 | continue 96 | else: 97 | break 98 | return route, distance 99 | 100 | ############################################################################ 101 | -------------------------------------------------------------------------------- /rlsolver/methods/tsp_algs/opt_2.py: -------------------------------------------------------------------------------- 1 | ############################################################################ 2 | 3 | 4 | 5 | # Lesson: Local Search-2-opt 6 | 7 | 8 | 9 | ############################################################################ 10 | 11 | # Required Libraries 12 | 
import copy 13 | 14 | ############################################################################ 15 | 16 | # Function: Tour Distance 17 | def distance_calc(distance_matrix, city_tour): 18 | distance = 0 19 | for k in range(0, len(city_tour[0])-1): 20 | m = k + 1 21 | distance = distance + distance_matrix[city_tour[0][k]-1, city_tour[0][m]-1] 22 | return distance 23 | 24 | ############################################################################ 25 | 26 | # Function: 2_opt 27 | def local_search_2_opt(distance_matrix, city_tour, recursive_seeding = -1, verbose = True): 28 | if (recursive_seeding < 0): 29 | count = -2 30 | else: 31 | count = 0 32 | city_list = copy.deepcopy(city_tour) 33 | distance = city_list[1]*2 34 | iteration = 0 35 | while (count < recursive_seeding): 36 | if (verbose == True): 37 | print('Iteration = ', iteration, 'Distance = ', round(city_list[1], 2)) 38 | best_route = copy.deepcopy(city_list) 39 | seed = copy.deepcopy(city_list) 40 | for i in range(0, len(city_list[0]) - 2): 41 | for j in range(i+1, len(city_list[0]) - 1): 42 | best_route[0][i:j+1] = list(reversed(best_route[0][i:j+1])) 43 | best_route[0][-1] = best_route[0][0] 44 | best_route[1] = distance_calc(distance_matrix, best_route) 45 | if (city_list[1] > best_route[1]): 46 | city_list = copy.deepcopy(best_route) 47 | best_route = copy.deepcopy(seed) 48 | count = count + 1 49 | iteration = iteration + 1 50 | if (distance > city_list[1] and recursive_seeding < 0): 51 | distance = city_list[1] 52 | count = -2 53 | recursive_seeding = -1 54 | elif(city_list[1] >= distance and recursive_seeding < 0): 55 | count = -1 56 | recursive_seeding = -2 57 | return city_list[0], city_list[1] 58 | 59 | ############################################################################ 60 | -------------------------------------------------------------------------------- /rlsolver/methods/tsp_algs/opt_3.py: -------------------------------------------------------------------------------- 1 | ############################################################################ 2 | 3 | 4 | 5 | # Lesson: Local Search-3-opt 6 | 7 | 8 | 9 | ############################################################################ 10 | 11 | # Required Libraries 12 | import copy 13 | 14 | ############################################################################ 15 | 16 | # Function: Tour Distance 17 | def distance_calc(distance_matrix, city_tour): 18 | distance = 0 19 | for k in range(0, len(city_tour[0])-1): 20 | m = k + 1 21 | distance = distance + distance_matrix[city_tour[0][k]-1, city_tour[0][m]-1] 22 | return distance 23 | 24 | ############################################################################ 25 | 26 | # Function: Possible Segments 27 | def segments_3_opt(n): 28 | x = [] 29 | a, b, c = 0, 0, 0 30 | for i in range(0, n): 31 | a = i 32 | for j in range(i + 1, n): 33 | b = j 34 | for k in range(j + 1, n + (i > 0)): 35 | c = k 36 | x.append((a, b, c)) 37 | return x 38 | 39 | ############################################################################ 40 | 41 | # Function: 3_opt 42 | def local_search_3_opt(distance_matrix, city_tour, recursive_seeding = -1, verbose = True): 43 | if (recursive_seeding < 0): 44 | count = recursive_seeding - 1 45 | else: 46 | count = 0 47 | city_list = [city_tour[0][:-1], city_tour[1]] 48 | city_list_old = city_list[1]*2 49 | iteration = 0 50 | while (count < recursive_seeding): 51 | if (verbose == True): 52 | print('Iteration = ', iteration, 'Distance = ', round(city_list[1], 2)) 53 | best_route = 
copy.deepcopy(city_list) 54 | best_route_1 = [[], 1] 55 | seed = copy.deepcopy(city_list) 56 | x = segments_3_opt(len(city_list[0])) 57 | for item in x: 58 | i, j, k = item 59 | A = best_route[0][:i+1] + best_route[0][i+1:j+1] 60 | a = best_route[0][:i+1] + list(reversed(best_route[0][i+1:j+1])) 61 | B = best_route[0][j+1:k+1] 62 | b = list(reversed(B)) 63 | C = best_route[0][k+1:] 64 | c = list(reversed(C)) 65 | trial = [ 66 | # Original Tour 67 | #[A + B + C], 68 | 69 | # 1 70 | [a + B + C], 71 | [A + b + C], 72 | [A + B + c], 73 | 74 | 75 | # 2 76 | [A + b + c], 77 | [a + b + C], 78 | [a + B + c], 79 | 80 | 81 | # 3 82 | [a + b + c] 83 | 84 | ] 85 | # Possibly, there is a sequence of 2-opt moves that decreases the total distance but it begins 86 | # with a move that first increases it 87 | for item in trial: 88 | best_route_1[0] = item[0] 89 | best_route_1[1] = distance_calc(distance_matrix, [best_route_1[0] + [best_route_1[0][0]], 1]) 90 | if (best_route_1[1] < best_route[1]): 91 | best_route = [best_route_1[0], best_route_1[1]] 92 | if (best_route[1] < city_list[1]): 93 | city_list = [best_route[0], best_route[1]] 94 | best_route = copy.deepcopy(seed) 95 | count = count + 1 96 | iteration = iteration + 1 97 | if (city_list_old > city_list[1] and recursive_seeding < 0): 98 | city_list_old = city_list[1] 99 | count = -2 100 | recursive_seeding = -1 101 | elif(city_list[1] >= city_list_old and recursive_seeding < 0): 102 | count = -1 103 | recursive_seeding = -2 104 | city_list = [city_list[0] + [city_list[0][0]], city_list[1]] 105 | return city_list[0], city_list[1] 106 | 107 | ############################################################################ 108 | -------------------------------------------------------------------------------- /rlsolver/methods/util_generate_tsp.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | 4 | 5 | class ConfigTsp: 6 | batch_size = 1 7 | num_nodes = 10 8 | low = 0 9 | high = 100 10 | random_mode = 'uniform' # 'uniform','gaussian' 11 | assert random_mode in ['uniform', 'gaussian'] 12 | filename = f"tsp{num_nodes}_batch{batch_size}_{random_mode}" 13 | data_path = "../data/" + filename + '.tsp' 14 | 15 | 16 | def generate_tsp_data(batch=10, nodes_num=10, low=0, high=1, random_mode="uniform"): 17 | if random_mode == "uniform": 18 | node_coords = np.random.uniform(low, high, size=(batch, nodes_num, 2)) 19 | elif random_mode == "gaussian": 20 | node_coords = np.random.normal(loc=0, scale=1, size=(batch, nodes_num, 2)) 21 | max_value = np.max(node_coords) 22 | min_value = np.min(node_coords) 23 | node_coords = np.interp(node_coords, (min_value, max_value), (low, high)) 24 | else: 25 | raise ValueError(f"Unknown random_mode: {random_mode}") 26 | return node_coords 27 | 28 | 29 | def generate_tsp_file(node_coords: np.ndarray, filename): 30 | if node_coords.ndim == 3: 31 | shape = node_coords.shape 32 | if shape[0] == 1: 33 | node_coords = node_coords.squeeze(axis=0) 34 | _generate_tsp_file(node_coords, filename) 35 | else: 36 | for i in range(shape[0]): 37 | _filename = filename.replace('.tsp', '') + '_ID' + str(i) + '.tsp' 38 | _generate_tsp_file(node_coords[i], _filename) 39 | else: 40 | assert node_coords.ndim == 2 41 | _generate_tsp_file(node_coords, filename) 42 | 43 | 44 | def _generate_tsp_file(node_coords: np.ndarray, filename): 45 | num_points = node_coords.shape[0] 46 | file_basename = os.path.basename(filename) 47 | with open(filename, 'w') as f: 48 | f.write(f"NAME: 
{file_basename}\n") 49 | f.write("TYPE: TSP\n") 50 | f.write(f"DIMENSION: {num_points}\n") 51 | f.write("EDGE_WEIGHT_TYPE: EUC_2D\n") 52 | f.write("NODE_COORD_SECTION\n") 53 | for i in range(num_points): 54 | x, y = node_coords[i] 55 | f.write(f"{i + 1} {x} {y}\n") 56 | f.write("EOF\n") 57 | 58 | 59 | if __name__ == "__main__": 60 | # tab_printer(args) 61 | node_coords = generate_tsp_data(ConfigTsp.batch_size, ConfigTsp.num_nodes, ConfigTsp.low, ConfigTsp.high, ConfigTsp.random_mode) 62 | generate_tsp_file(node_coords, ConfigTsp.data_path) 63 | -------------------------------------------------------------------------------- /rlsolver/requirements.txt: -------------------------------------------------------------------------------- 1 | # ML framework 2 | torch 3 | 4 | # data handling 5 | numpy 6 | 7 | # plot/simulation 8 | matplotlib 9 | -------------------------------------------------------------------------------- /rlsolver/result/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AI4Finance-Foundation/ElegantRL/8ea76afc3e7f1564ae9f0e69e70254116d575fe9/rlsolver/result/__init__.py -------------------------------------------------------------------------------- /rlsolver/result/c101-10-customers.txt: -------------------------------------------------------------------------------- 1 | alg_name: column_generation 2 | running_duration: 3.254345417022705 3 | dist: 59 4 | dists: [59] 5 | demands: [150] 6 | durations: [1026] 7 | paths: 8 | ('0-orig', '5', '3', '7', '8', '10', '9', '6', '4', '2', '1', '11-dest') 9 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages 2 | 3 | setup( 4 | name="ElegantRL", 5 | version="0.3.10", 6 | author="AI4Finance Foundation", 7 | author_email="contact@ai4finance.org", 8 | url="https://github.com/AI4Finance-Foundation/ElegantRL", 9 | license="Apache 2.0", 10 | packages=find_packages(), 11 | install_requires=[ 12 | "torch", 13 | "numpy", 14 | "gymnasium", 15 | "matplotlib", 16 | ], 17 | description="Lightweight, Efficient and Stable DRL Implementation Using PyTorch", 18 | classifiers=[ 19 | # Trove classifiers 20 | # Full list: https://pypi.python.org/pypi?%3Aaction=list_classifiers 21 | "License :: OSI Approved :: Apache Software License", 22 | "Programming Language :: Python", 23 | "Programming Language :: Python :: 3", 24 | "Programming Language :: Python :: 3.6", 25 | "Programming Language :: Python :: 3.7", 26 | "Programming Language :: Python :: 3.8", 27 | "Programming Language :: Python :: 3.9", 28 | "Programming Language :: Python :: 3.10", 29 | "Programming Language :: Python :: 3.11", 30 | "Programming Language :: Python :: Implementation :: CPython", 31 | "Programming Language :: Python :: Implementation :: PyPy", 32 | ], 33 | keywords="Deep Reinforcement Learning", 34 | python_requires=">=3.6", 35 | ) 36 | -------------------------------------------------------------------------------- /unit_tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AI4Finance-Foundation/ElegantRL/8ea76afc3e7f1564ae9f0e69e70254116d575fe9/unit_tests/__init__.py -------------------------------------------------------------------------------- /unit_tests/envs/test_env.py: -------------------------------------------------------------------------------- 1 | 
import numpy as np 2 | from elegantrl.envs.CustomGymEnv import PendulumEnv 3 | 4 | 5 | def test_pendulum_env(): 6 | print("\n| test_pendulum_env()") 7 | env = PendulumEnv() 8 | assert isinstance(env.env_name, str) 9 | assert isinstance(env.state_dim, int) 10 | assert isinstance(env.action_dim, int) 11 | assert isinstance(env.if_discrete, bool) 12 | 13 | state = env.reset() 14 | assert state.shape == (env.state_dim,) 15 | 16 | action = np.random.uniform(-1, +1, size=env.action_dim) 17 | state, reward, done, info_dict = env.step(action) 18 | assert isinstance(state, np.ndarray) 19 | assert state.shape == (env.state_dim,) 20 | assert isinstance(state, np.ndarray) 21 | assert isinstance(reward, float) 22 | assert isinstance(done, bool) 23 | assert isinstance(info_dict, dict) or (info_dict is None) 24 | 25 | 26 | if __name__ == '__main__': 27 | print('\n| test_env.py') 28 | test_pendulum_env() 29 | -------------------------------------------------------------------------------- /unit_tests/envs/test_isaac_env.py: -------------------------------------------------------------------------------- 1 | import sys 2 | 3 | from elegantrl.envs.IsaacGym import * 4 | 5 | 6 | def create_isaac_vec_environment(env_name: str): 7 | isaac_env = IsaacVecEnv(env_name) 8 | del isaac_env 9 | 10 | 11 | if __name__ == "__main__": 12 | env_name = sys.argv[1] 13 | create_isaac_vec_environment(env_name) 14 | -------------------------------------------------------------------------------- /unit_tests/envs/test_isaac_environments.py: -------------------------------------------------------------------------------- 1 | """ 2 | This script tests whether or not each Isaac Gym environment can be effectively 3 | instantiated. 4 | """ 5 | 6 | import isaacgym 7 | import unittest 8 | from elegantrl.envs.IsaacGym import * 9 | from elegantrl.envs.isaac_tasks import isaacgym_task_map 10 | from subprocess import call 11 | 12 | 13 | class TestIsaacEnvironments(unittest.TestCase): 14 | def setUp(self): 15 | self.task_map = isaacgym_task_map 16 | 17 | def test_should_instantiate_all_Isaac_vector_environments(self): 18 | for env_name in self.task_map: 19 | return_code = call( 20 | ["python3", "unit_tests/isaac_env_test_helper.py", env_name] 21 | ) 22 | if return_code != 0: 23 | raise Exception( 24 | f"Instantiating {env_name} resulted in error code {return_code}" 25 | ) 26 | 27 | 28 | if __name__ == "__main__": 29 | unittest.main() 30 | -------------------------------------------------------------------------------- /unit_tests/train/test_evaluator.py: -------------------------------------------------------------------------------- 1 | import os 2 | from elegantrl.train.evaluator import Evaluator 3 | from elegantrl.envs.CustomGymEnv import PendulumEnv 4 | 5 | EnvArgsPendulum = {'env_name': 'Pendulum-v1', 'state_dim': 3, 'action_dim': 1, 'if_discrete': False} 6 | 7 | 8 | def test_get_rewards_and_steps(): 9 | print("\n| test_get_rewards_and_steps()") 10 | from elegantrl.train.evaluator import get_rewards_and_steps 11 | from elegantrl.agents.net import Actor 12 | 13 | env = PendulumEnv() 14 | 15 | state_dim = env.state_dim 16 | action_dim = env.action_dim 17 | if_discrete = env.if_discrete 18 | 19 | actor = Actor(dims=[8, 8], state_dim=state_dim, action_dim=action_dim) 20 | 21 | if_render = False 22 | rewards, steps = get_rewards_and_steps(env=env, actor=actor, if_render=if_render) 23 | assert isinstance(rewards, float) 24 | assert isinstance(steps, int) 25 | 26 | if os.name == 'nt': # if the operating system is Windows NT 27 | if_render = 
True 28 | print("\"libpng warning: iCCP: cHRM chunk does not match sRGB\" → It doesn't matter to see this warning.") 29 | rewards, steps = get_rewards_and_steps(env=env, actor=actor, if_render=if_render) 30 | assert isinstance(rewards, float) 31 | assert isinstance(steps, int) 32 | 33 | 34 | if __name__ == '__main__': 35 | print("\n| test_evaluator.py") 36 | test_get_rewards_and_steps() 37 | --------------------------------------------------------------------------------