├── minppo ├── py.typed ├── __init__.py ├── utils.py ├── configs │ └── stompy_pro.yaml ├── infer.py ├── cli.py ├── config.py ├── env.py └── train.py ├── MANIFEST.in ├── .vscode └── settings.json ├── requirements.txt ├── Makefile ├── .gitignore ├── LICENSE ├── .github └── workflows │ ├── publish.yml │ └── test.yml ├── README.md ├── setup.py ├── pyproject.toml └── CONTRIBUTING.md /minppo/py.typed: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /minppo/__init__.py: -------------------------------------------------------------------------------- 1 | __version__ = "0.0.1" 2 | -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | recursive-include minppo/configs *.yaml 2 | include minppo/py.typed 3 | -------------------------------------------------------------------------------- /minppo/utils.py: -------------------------------------------------------------------------------- 1 | """Defines some shared utility functions for the package.""" 2 | -------------------------------------------------------------------------------- /minppo/configs/stompy_pro.yaml: -------------------------------------------------------------------------------- 1 | kscale_id: 5eb3cb7f23232298 2 | visualization: 3 | camera_name: track 4 | training: 5 | num_minibatches: 32 6 | -------------------------------------------------------------------------------- /.vscode/settings.json: -------------------------------------------------------------------------------- 1 | { 2 | "[python]": { 3 | "editor.formatOnSave": true, 4 | "editor.defaultFormatter": "ms-python.black-formatter" 5 | } 6 | } 7 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | # requirements.txt 2 | 3 | # Training 4 | tqdm 5 | colorlogging 6 | 7 | # Environment 8 | brax 9 | distrax 10 | equinox 11 | jax 12 | optax 13 | kscale 14 | 15 | # Configuration management. 16 | omegaconf 17 | 18 | # Types 19 | types-tqdm 20 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | # Makefile 2 | 3 | py-files := $(shell find . 
-name '*.py') 4 | 5 | format: 6 | @black $(py-files) 7 | @ruff format $(py-files) 8 | .PHONY: format 9 | 10 | static-checks: 11 | @black --diff --check $(py-files) 12 | @ruff check $(py-files) 13 | @mypy --install-types --non-interactive $(py-files) 14 | .PHONY: lint 15 | 16 | test: 17 | python -m pytest 18 | .PHONY: test 19 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # .gitignore 2 | 3 | # Python 4 | *.py[oc] 5 | __pycache__/ 6 | *.egg-info 7 | .eggs/ 8 | .mypy_cache/* 9 | .pyre/ 10 | .pytest_cache/ 11 | .ruff_cache/ 12 | .dmypy.json 13 | 14 | # Build artifacts 15 | build/ 16 | dist/ 17 | *.so 18 | out*/ 19 | *.stl 20 | *.mp4 21 | *.notes 22 | 23 | # Directories 24 | environments/ 25 | screenshots/ 26 | scratch/ 27 | assets/ 28 | wandb/ 29 | models/ 30 | logs/ 31 | 32 | # Other artifacts 33 | MUJOCO_LOG.TXT 34 | -------------------------------------------------------------------------------- /minppo/infer.py: -------------------------------------------------------------------------------- 1 | """Runs inference for the trained model.""" 2 | 3 | import logging 4 | import os 5 | import pickle 6 | import sys 7 | from typing import Sequence 8 | 9 | os.environ["MUJOCO_GL"] = "egl" 10 | os.environ["DISPLAY"] = ":0" 11 | 12 | # Add logger configuration 13 | logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s") 14 | logger = logging.getLogger(__name__) 15 | 16 | 17 | def load_model(filename: str) -> dict: 18 | with open(filename, "rb") as f: 19 | return pickle.load(f) 20 | 21 | 22 | def main(args: Sequence[str] | None = None) -> None: 23 | """Runs inference with pretrained models.""" 24 | if args is None: 25 | args = sys.argv[1:] 26 | 27 | raise NotImplementedError("Not implemented yet") 28 | 29 | 30 | if __name__ == "__main__": 31 | main() 32 | -------------------------------------------------------------------------------- /minppo/cli.py: -------------------------------------------------------------------------------- 1 | """Defines a command line interface for the package.""" 2 | 3 | import argparse 4 | 5 | import colorlogging 6 | 7 | from minppo.env import main as env_main 8 | from minppo.infer import main as infer_main 9 | from minppo.train import main as train_main 10 | 11 | 12 | def main() -> None: 13 | colorlogging.configure() 14 | 15 | parser = argparse.ArgumentParser(description="MinPPO CLI") 16 | parser.add_argument("command", choices=["train", "env", "infer"], help="Command to run") 17 | args, other_args = parser.parse_known_args() 18 | 19 | if args.command == "train": 20 | train_main(other_args) 21 | elif args.command == "env": 22 | env_main(other_args) 23 | elif args.command == "infer": 24 | infer_main(other_args) 25 | else: 26 | raise ValueError(f"Invalid command: {args.command}") 27 | 28 | 29 | if __name__ == "__main__": 30 | # python -m minppo.cli 31 | main() 32 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 Nathan Zhao 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 
| copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /.github/workflows/publish.yml: -------------------------------------------------------------------------------- 1 | name: Publish Python Package 2 | 3 | on: 4 | release: 5 | types: [created] 6 | workflow_dispatch: 7 | 8 | permissions: 9 | contents: read 10 | id-token: write 11 | 12 | concurrency: 13 | group: "publish" 14 | cancel-in-progress: true 15 | 16 | jobs: 17 | publish: 18 | timeout-minutes: 10 19 | name: Build and publish 20 | 21 | # We don't need to run on all platforms since this package is 22 | # platform-agnostic. The output wheel is something like 23 | # "minppo--py3-none-any.whl". 24 | runs-on: ubuntu-latest 25 | 26 | steps: 27 | - name: Checkout code 28 | uses: actions/checkout@v3 29 | 30 | - name: Set up Python 31 | uses: actions/setup-python@v4 32 | with: 33 | python-version: "3.10" 34 | 35 | - name: Install dependencies 36 | run: | 37 | python -m pip install --upgrade pip 38 | pip install build wheel 39 | 40 | - name: Build package 41 | run: python -m build --sdist --wheel --outdir dist/ . 42 | 43 | - name: Publish package 44 | uses: pypa/gh-action-pypi-publish@release/v1 45 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |

2 | 3 | K-Scale Open Source Robotics 4 | 5 |

6 | 7 |
8 | 9 | [![License](https://img.shields.io/badge/license-MIT-green)](https://github.com/kscalelabs/minppo/blob/master/LICENSE) 10 | [![Discord](https://img.shields.io/discord/1224056091017478166)](https://discord.gg/k5mSvCkYQh) 11 | [![Wiki](https://img.shields.io/badge/wiki-humanoids-black)](https://humanoids.wiki)
13 | [![python](https://img.shields.io/badge/-Python_3.11-blue?logo=python&logoColor=white)](https://github.com/pre-commit/pre-commit) 14 | [![black](https://img.shields.io/badge/Code%20Style-Black-black.svg?labelColor=gray)](https://black.readthedocs.io/en/stable/) 15 | [![ruff](https://img.shields.io/badge/Linter-Ruff-red.svg?labelColor=gray)](https://github.com/charliermarsh/ruff) 16 |
17 | [![Python Checks](https://github.com/kscalelabs/minppo/actions/workflows/test.yml/badge.svg)](https://github.com/kscalelabs/minppo/actions/workflows/test.yml) 18 |
20 | 21 | # MinPPO 22 | 23 | This repository implements a minimal version of PPO using Jax. For more information, see the [documentation](https://docs.kscale.dev/software/simulation/minppo). 24 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | # mypy: disable-error-code="import-untyped" 2 | #!/usr/bin/env python 3 | """Setup script for the project.""" 4 | 5 | import re 6 | 7 | from setuptools import setup 8 | 9 | with open("README.md", "r", encoding="utf-8") as f: 10 | long_description: str = f.read() 11 | 12 | 13 | with open("requirements.txt", "r", encoding="utf-8") as f: 14 | requirements: list[str] = f.read().splitlines() 15 | 16 | 17 | requirements_dev = [ 18 | "black", 19 | "darglint", 20 | "mypy", 21 | "pytest", 22 | "ruff", 23 | ] 24 | 25 | 26 | with open("minppo/__init__.py", "r", encoding="utf-8") as fh: 27 | version_re = re.search(r"^__version__ = \"([^\"]*)\"", fh.read(), re.MULTILINE) 28 | assert version_re is not None, "Could not find version in minppo/__init__.py" 29 | version: str = version_re.group(1) 30 | 31 | 32 | setup( 33 | name="minppo", 34 | version=version, 35 | description="The minppo project", 36 | author="Benjamin Bolte", 37 | url="https://github.com/kscalelabs/minppo", 38 | long_description=long_description, 39 | long_description_content_type="text/markdown", 40 | python_requires=">=3.11", 41 | install_requires=requirements, 42 | tests_require=requirements_dev, 43 | extras_require={"dev": requirements_dev}, 44 | entry_points={ 45 | "console_scripts": [ 46 | "minppo=minppo.cli:main", 47 | ], 48 | }, 49 | ) 50 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.black] 2 | 3 | line-length = 120 4 | target-version = ["py310"] 5 | include = '\.pyi?$' 6 | 7 | [tool.mypy] 8 | 9 | pretty = true 10 | show_column_numbers = true 11 | show_error_context = true 12 | show_error_codes = true 13 | show_traceback = true 14 | disallow_untyped_defs = true 15 | strict_equality = true 16 | allow_redefinition = true 17 | 18 | warn_unused_ignores = true 19 | warn_redundant_casts = true 20 | 21 | incremental = true 22 | explicit_package_bases = true 23 | 24 | disable_error_code = ["attr-defined"] 25 | 26 | plugins = ["numpy.typing.mypy_plugin"] 27 | 28 | [[tool.mypy.overrides]] 29 | 30 | module = [ 31 | "brax.*", 32 | "optax.*", 33 | "equinox.*", 34 | "tensorflow.*", 35 | "mujoco.*", 36 | "torchaudio.*", 37 | "torchvision.*", 38 | "mediapy.*", 39 | "distrax.*", 40 | ] 41 | 42 | ignore_missing_imports = true 43 | 44 | [tool.isort] 45 | 46 | profile = "black" 47 | 48 | [tool.ruff] 49 | 50 | line-length = 120 51 | target-version = "py310" 52 | 53 | [tool.ruff.lint] 54 | 55 | select = ["ANN", "D", "E", "F", "I", "N", "PGH", "PLC", "PLE", "PLR", "PLW", "W"] 56 | 57 | ignore = [ 58 | "ANN101", "ANN102", 59 | "D101", "D102", "D103", "D104", "D105", "D106", "D107", 60 | "N812", "N817", 61 | "PLR0911", "PLR0912", "PLR0913", "PLR0915", "PLR2004", 62 | "PLW0603", "PLW2901", 63 | ] 64 | 65 | dummy-variable-rgx = "^(_+|(_+[a-zA-Z0-9_]*[a-zA-Z0-9]+?))$" 66 | 67 | [tool.ruff.lint.per-file-ignores] 68 | 69 | "__init__.py" = ["E402", "F401", "F403", "F811"] 70 | 71 | [tool.ruff.lint.pydocstyle] 72 | 73 | convention = "google" 74 | -------------------------------------------------------------------------------- /.github/workflows/test.yml: 
-------------------------------------------------------------------------------- 1 | name: Python Checks 2 | 3 | on: 4 | push: 5 | branches: 6 | - master 7 | pull_request: 8 | branches: 9 | - master 10 | types: 11 | - opened 12 | - reopened 13 | - synchronize 14 | - ready_for_review 15 | 16 | concurrency: 17 | group: tests-${{ github.head_ref || github.run_id }} 18 | cancel-in-progress: true 19 | 20 | jobs: 21 | run-base-tests: 22 | timeout-minutes: 10 23 | runs-on: ubuntu-latest 24 | steps: 25 | - name: Check out repository 26 | uses: actions/checkout@v3 27 | 28 | - name: Set up Python 29 | uses: actions/setup-python@v4 30 | with: 31 | python-version: "3.11" 32 | 33 | - name: Restore cache 34 | id: restore-cache 35 | uses: actions/cache/restore@v3 36 | with: 37 | path: | 38 | ${{ env.pythonLocation }} 39 | .mypy_cache/ 40 | key: python-requirements-${{ env.pythonLocation }}-${{ github.event.pull_request.base.sha || github.sha }} 41 | restore-keys: | 42 | python-requirements-${{ env.pythonLocation }} 43 | python-requirements- 44 | 45 | - name: Install package 46 | run: | 47 | pip install --upgrade --upgrade-strategy eager --extra-index-url https://download.pytorch.org/whl/cpu -e '.[dev]' 48 | 49 | - name: Run static checks 50 | run: | 51 | mkdir -p .mypy_cache 52 | make static-checks 53 | 54 | - name: Run unit tests 55 | run: | 56 | make test 57 | 58 | - name: Save cache 59 | uses: actions/cache/save@v3 60 | if: github.ref == 'refs/heads/master' 61 | with: 62 | path: | 63 | ${{ env.pythonLocation }} 64 | .mypy_cache/ 65 | key: ${{ steps.restore-cache.outputs.cache-primary-key }} 66 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing 2 | 3 | We love your input! We want to make contributing to this project as easy and transparent as possible, whether it's: 4 | 5 | - Reporting a bug 6 | - Discussing the current state of the code 7 | - Submitting a fix 8 | - Proposing new features 9 | - Becoming a maintainer 10 | 11 | ## Getting Started 12 | 13 | - Fork the repo and clone it on your machine. 14 | - Create a branch (`git checkout -b feature/myNewFeature`). 15 | 16 | ## Your First Contribution 17 | 18 | Unsure where to begin contributing to our code? You can start by looking through the "TODO" list in the README.md or looking at the issues 19 | 20 | ## Pull Request Process 21 | 22 | 1. Ensure any install or build dependencies are removed before the end of the layer when doing a build through `pip install -r requirements.txt`. 23 | 2. Update the README.md with details of changes, this includes new features, further documentation, or possible TODOs. 24 | 3. You may merge the Pull Request in once you have the sign-off of two other developers, or if you do not have permission to do that, you may request the second reviewer to merge it for you. 25 | 4. Follow the linting process shown below. 26 | 27 | To run linting, use: 28 | 29 | ```bash 30 | black *.py 31 | ruff check --fix *.py 32 | ``` 33 | 34 | ## Setting Up Development Environment 35 | 36 | The repository is currently (and purposefully) very lightweight. `requirements.txt` is all you need! If you come acros any errors regarding MuJoCo's rendering, you may have to `export DISPLAY=:0` on a headless environment utilizing `xvfb`. 
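For example, one way to provide that display on a headless machine is to start a virtual framebuffer first (e.g. `Xvfb :0 -screen 0 1400x900x24 &`) or to wrap the command with `xvfb-run`.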
If you are still getting errors, try `export MUJOCO_GL=egl` 37 | 38 | 39 | ## Community and Communication 40 | 41 | For major changes, please open an issue first to discuss what you would like to change. Please make sure to update tests as appropriate, and try to reduce verbosity and annotate your code as much as possible 42 | 43 | Thank you for contributing! 44 | -------------------------------------------------------------------------------- /minppo/config.py: -------------------------------------------------------------------------------- 1 | """Defines the model configuration options.""" 2 | 3 | import logging 4 | import sys 5 | from dataclasses import dataclass, field 6 | from pathlib import Path 7 | from typing import Sequence, cast 8 | 9 | from omegaconf import MISSING, OmegaConf 10 | 11 | logger = logging.getLogger(__name__) 12 | 13 | CONFIG_ROOT_DIR = Path(__file__).parent / "configs" 14 | 15 | 16 | @dataclass 17 | class EnvironmentConfig: 18 | n_frames: int = field(default=1) 19 | backend: str = field(default="mjx") 20 | include_c_vals: bool = field(default=True) 21 | 22 | 23 | @dataclass 24 | class VisualizationConfig: 25 | camera_name: str = field(default=MISSING) 26 | width: int = field(default=640) 27 | height: int = field(default=480) 28 | render_every: int = field(default=1) 29 | max_steps: int = field(default=1000) 30 | video_length: float = field(default=5.0) 31 | num_episodes: int = field(default=20) 32 | video_save_path: str = field(default="episode.mp4") 33 | 34 | 35 | @dataclass 36 | class RewardConfig: 37 | termination_height: float = field(default=-0.2) 38 | height_min_z: float = field(default=-0.2) 39 | height_max_z: float = field(default=2.0) 40 | is_healthy_reward: float = field(default=5) 41 | original_pos_reward_exp_coefficient: float = field(default=2) 42 | original_pos_reward_subtraction_factor: float = field(default=0.2) 43 | original_pos_reward_max_diff_norm: float = field(default=0.5) 44 | ctrl_cost_coefficient: float = field(default=0.1) 45 | weights_ctrl_cost: float = field(default=0.1) 46 | weights_original_pos_reward: float = field(default=4) 47 | weights_is_healthy: float = field(default=1) 48 | weights_velocity: float = field(default=1.25) 49 | 50 | 51 | @dataclass 52 | class ModelConfig: 53 | hidden_size: int = field(default=256) 54 | num_layers: int = field(default=2) 55 | use_tanh: bool = field(default=True) 56 | 57 | 58 | @dataclass 59 | class OptimizerConfig: 60 | lr: float = field(default=3e-4) 61 | max_grad_norm: float = field(default=0.5) 62 | 63 | 64 | @dataclass 65 | class ReinforcementLearningConfig: 66 | num_env_steps: int = field(default=10) 67 | gamma: float = field(default=0.99) 68 | gae_lambda: float = field(default=0.95) 69 | clip_eps: float = field(default=0.2) 70 | ent_coef: float = field(default=0.0) 71 | vf_coef: float = field(default=0.5) 72 | 73 | 74 | @dataclass 75 | class TrainingConfig: 76 | lr: float = field(default=3e-4) 77 | seed: int = field(default=1337) 78 | num_envs: int = field(default=2048) 79 | total_timesteps: int = field(default=1_000_000_000) 80 | num_minibatches: int = field(default=32) 81 | num_steps: int = field(default=10) 82 | update_epochs: int = field(default=4) 83 | anneal_lr: bool = field(default=True) 84 | model_save_path: str = field(default="trained_model.pkl") 85 | 86 | 87 | @dataclass 88 | class InferenceConfig: 89 | model_path: str = field(default=MISSING) 90 | 91 | 92 | @dataclass 93 | class Config: 94 | kscale_id: str = field(default=MISSING) 95 | environment: EnvironmentConfig = 
field(default_factory=EnvironmentConfig) 96 | visualization: VisualizationConfig = field(default_factory=VisualizationConfig) 97 | reward: RewardConfig = field(default_factory=RewardConfig) 98 | model: ModelConfig = field(default_factory=ModelConfig) 99 | opt: OptimizerConfig = field(default_factory=OptimizerConfig) 100 | rl: ReinforcementLearningConfig = field(default_factory=ReinforcementLearningConfig) 101 | training: TrainingConfig = field(default_factory=TrainingConfig) 102 | inference: InferenceConfig = field(default_factory=InferenceConfig) 103 | debug: bool = field(default=True) 104 | 105 | 106 | def load_config_from_cli(args: Sequence[str] | None = None) -> Config: 107 | if args is None: 108 | args = sys.argv[1:] 109 | if len(args) < 1: 110 | raise ValueError("Usage: ( ...)") 111 | path, *other_args = args 112 | 113 | if Path(path).exists(): 114 | raw_config = OmegaConf.load(path) 115 | elif (config_path := CONFIG_ROOT_DIR / f"{path}.yaml").exists(): 116 | raw_config = OmegaConf.load(config_path) 117 | else: 118 | raise ValueError(f"Config file not found: {path}") 119 | 120 | config = OmegaConf.structured(Config) 121 | config = OmegaConf.merge(config, raw_config) 122 | if other_args: 123 | config = OmegaConf.merge(config, OmegaConf.from_dotlist(other_args)) 124 | 125 | logger.info("Loaded config: %s", OmegaConf.to_yaml(config)) 126 | 127 | return cast(Config, config) 128 | -------------------------------------------------------------------------------- /minppo/env.py: -------------------------------------------------------------------------------- 1 | """Definition of base humanoids environment with reward system and termination conditions.""" 2 | 3 | import asyncio 4 | import logging 5 | import shutil 6 | import sys 7 | import tempfile 8 | import xml.etree.ElementTree as ET 9 | from functools import partial 10 | from pathlib import Path 11 | from typing import Any, NamedTuple, Sequence 12 | 13 | import jax 14 | import jax.numpy as jnp 15 | import mujoco 16 | from brax import base 17 | from brax.envs.base import PipelineEnv 18 | from brax.io import mjcf 19 | from brax.mjx.base import State as MjxState 20 | from kscale import KScale 21 | 22 | from minppo.config import Config, load_config_from_cli 23 | 24 | logger = logging.getLogger(__name__) 25 | 26 | 27 | def load_mjcf_model(kscale_id: str) -> mujoco.MjModel: 28 | api = KScale() 29 | mjcf_path = asyncio.run(api.mjcf_path(kscale_id)) 30 | 31 | # We need to fix up the MJCF model to allow it to work with Brax. 32 | # Specifically, we need to remove the frictionloss attribute from the 33 | # joints element, as Brax does not support it. 34 | with tempfile.TemporaryDirectory() as temp_dir: 35 | temp_mjcf_dir = Path(temp_dir) / mjcf_path.parent.name 36 | shutil.copytree(mjcf_path.parent, temp_mjcf_dir) 37 | temp_mjcf_path = temp_mjcf_dir / mjcf_path.name 38 | tree = ET.parse(temp_mjcf_path) 39 | root = tree.getroot() 40 | 41 | # Updates the element to remove frictionloss attrib. 
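        # For illustration only (the attribute values here are hypothetical): a defaults
        # entry such as <default><joint damping="0.1" frictionloss="0.1"/></default>
        # is rewritten to <default><joint damping="0.1"/></default> before the model is
        # handed to Brax's MJCF loader.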
42 | for joint in root.findall(".//default/joint"): 43 | if "frictionloss" in joint.attrib: 44 | del joint.attrib["frictionloss"] 45 | 46 | # Write the modified XML back to the file 47 | tree.write(temp_mjcf_path) 48 | 49 | model: mujoco.MjModel = mujoco.MjModel.from_xml_path(str(temp_mjcf_path)) 50 | return model 51 | 52 | 53 | class EnvMetrics(NamedTuple): 54 | episode_returns: jnp.ndarray 55 | episode_lengths: jnp.ndarray 56 | returned_episode_returns: jnp.ndarray 57 | returned_episode_lengths: jnp.ndarray 58 | timestep: jnp.ndarray 59 | returned_episode: jnp.ndarray 60 | 61 | 62 | class EnvState(NamedTuple): 63 | pipeline_state: Any # Use Any for MjxState as it's not a standard JAX type 64 | obs: jnp.ndarray 65 | reward: jnp.ndarray 66 | done: jnp.ndarray 67 | metrics: EnvMetrics 68 | 69 | 70 | class HumanoidEnv(PipelineEnv): 71 | """Defines the environment for controlling a humanoid robot. 72 | 73 | This environment uses Brax's `mjcf` module to load a MuJoCo model of a 74 | humanoid robot, which can then be controlled using the `PipelineEnv` API. 75 | 76 | Parameters: 77 | n_frames: The number of times to step the physics pipeline for each 78 | environment step. Setting this value to be greater than 1 means 79 | that the policy will run at a lower frequency than the physics 80 | simulation. 81 | backend: The backend to use for the physics simulation. 82 | kscale_id: The ID of the robot to load from K-Scale. 83 | """ 84 | 85 | initial_qpos: jnp.ndarray 86 | _action_size: int 87 | reset_noise_scale: float = 0.0 88 | 89 | def __init__(self, config: Config) -> None: 90 | self._include_c_vals = config.environment.include_c_vals 91 | self._kscale_id = config.kscale_id 92 | 93 | # Loads the MJCF model using the K-Scale API. 94 | mj_model: mujoco.MjModel = load_mjcf_model(self._kscale_id) 95 | mj_model.opt.solver = mujoco.mjtSolver.mjSOL_CG 96 | mj_model.opt.iterations = 6 97 | mj_model.opt.ls_iterations = 6 98 | 99 | self._action_size = mj_model.nu 100 | sys: base.System = mjcf.load_model(mj_model) 101 | 102 | super().__init__(sys, n_frames=config.environment.n_frames, backend=config.environment.backend) 103 | 104 | self.initial_qpos = jnp.array(sys.qpos0) 105 | self.reward_config = config.reward 106 | 107 | # Currently unused 108 | actuator_ctrlrange = [] 109 | for i in range(mj_model.nu): 110 | ctrlrange = mj_model.actuator_ctrlrange[i] 111 | actuator_ctrlrange.append(ctrlrange) 112 | 113 | self.actuator_ctrlrange = jnp.array(actuator_ctrlrange) 114 | 115 | def _get_reset_state(self, rng: jnp.ndarray) -> MjxState: 116 | rng1, rng2 = jax.random.split(rng, 2) 117 | low, hi = -self.reset_noise_scale, self.reset_noise_scale 118 | qpos = self.initial_qpos + jax.random.uniform(rng1, (self.sys.nq,), minval=low, maxval=hi) 119 | qvel = jax.random.uniform(rng2, (self.sys.nv,), minval=low, maxval=hi) 120 | state = self.pipeline_init(qpos, qvel) 121 | return state 122 | 123 | @partial(jax.jit, static_argnums=(0,)) 124 | def reset(self, rng: jnp.ndarray) -> EnvState: 125 | """Gets the initial state of the environment. 126 | 127 | Args: 128 | rng: A JAX random number generator key. 129 | 130 | Returns: 131 | The initial state of the environment. 
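            Concretely, this is an `EnvState` bundling the initial physics
            pipeline state, the first observation, a zero reward, a False done
            flag, and zeroed episode metrics.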
132 | """ 133 | state = self._get_reset_state(rng) 134 | obs = self.get_obs(state, jnp.zeros(self._action_size)) 135 | 136 | metrics = EnvMetrics( 137 | episode_returns=jnp.array(0.0), 138 | episode_lengths=jnp.array(0), 139 | returned_episode_returns=jnp.array(0.0), 140 | returned_episode_lengths=jnp.array(0), 141 | timestep=jnp.array(0), 142 | returned_episode=jnp.array(False), 143 | ) 144 | 145 | return EnvState(state, obs, jnp.array(0.0), jnp.array(False), metrics) 146 | 147 | @partial(jax.jit, static_argnums=(0,)) 148 | def step(self, env_state: EnvState, action: jnp.ndarray, rng: jnp.ndarray) -> EnvState: 149 | """Runs one timestep of the environment's dynamics. 150 | 151 | Args: 152 | env_state: The current state of the environment. 153 | action: The action to take. 154 | rng: A JAX random number generator key. 155 | 156 | Returns: 157 | The next state of the environment. 158 | """ 159 | state = env_state.pipeline_state 160 | metrics = env_state.metrics 161 | 162 | state_step = self.pipeline_step(state, action) 163 | obs_state = self.get_obs(state, action) 164 | 165 | # Resets finished environments to semi-random states. 166 | state_reset = self._get_reset_state(rng) 167 | obs_reset = self.get_obs(state_reset, jnp.zeros(self._action_size)) 168 | 169 | # Gets the rewards and "done" status flags for the current state. 170 | reward = self.compute_reward(state, state_step, action) 171 | done = self.is_done(state_step) 172 | 173 | # Checks if the state is NaN. 174 | is_nan = jax.tree_util.tree_map(lambda x: jnp.any(jnp.isnan(x)), state_step) 175 | any_nan = jax.tree_util.tree_reduce(jnp.logical_or, is_nan) 176 | done = jnp.logical_or(done, any_nan) 177 | 178 | # Reset to the start state if done. 179 | new_state = jax.tree.map(lambda x, y: jax.lax.select(done, x, y), state_reset, state_step) 180 | obs = jax.lax.select(done, obs_reset, obs_state) 181 | 182 | # Calculate new episode return and length. 183 | new_episode_return = metrics.episode_returns + reward 184 | new_episode_length = metrics.episode_lengths + 1 185 | 186 | # Update tracking metrics. 187 | new_metrics = EnvMetrics( 188 | episode_returns=new_episode_return * (1 - done), 189 | episode_lengths=new_episode_length * (1 - done), 190 | returned_episode_returns=metrics.returned_episode_returns * (1 - done) + new_episode_return * done, 191 | returned_episode_lengths=metrics.returned_episode_lengths * (1 - done) + new_episode_length * done, 192 | timestep=metrics.timestep + 1, 193 | returned_episode=done, 194 | ) 195 | 196 | return EnvState(new_state, obs, reward, done, new_metrics) 197 | 198 | @partial(jax.jit, static_argnums=(0,)) 199 | def compute_reward( 200 | self, 201 | state: MjxState, 202 | next_state: MjxState, 203 | action: jnp.ndarray, 204 | ) -> jnp.ndarray: 205 | min_z = self.reward_config.height_min_z 206 | max_z = self.reward_config.height_max_z 207 | 208 | # Reward for maintaining the original position. 209 | exp_coef = self.reward_config.original_pos_reward_exp_coefficient 210 | subtraction_factor = self.reward_config.original_pos_reward_subtraction_factor 211 | max_diff_norm = self.reward_config.original_pos_reward_max_diff_norm 212 | p0_pen = jnp.linalg.norm(self.initial_qpos - state.qpos) 213 | original_pos_reward = jnp.exp(-exp_coef * p0_pen) - subtraction_factor * jnp.clip(p0_pen, 0, max_diff_norm) 214 | 215 | # Reward for maintaining a "healthy" height. 
216 | is_healthy = jnp.where(state.q[2] < min_z, 0.0, 1.0) 217 | is_healthy = jnp.where(state.q[2] > max_z, 0.0, is_healthy) 218 | 219 | # Penalizes the total squared torque. 220 | ctrl_cost = -jnp.sum(jnp.square(action)) 221 | 222 | xpos = state.subtree_com[1][0] 223 | next_xpos = next_state.subtree_com[1][0] 224 | velocity = (next_xpos - xpos) / self.dt 225 | 226 | # Weights the rewards by the weighting terms. 227 | ctrl_cost_weighted = self.reward_config.weights_ctrl_cost * ctrl_cost 228 | original_pos_reward_weighted = self.reward_config.weights_original_pos_reward * original_pos_reward 229 | velocity_weighted = self.reward_config.weights_velocity * velocity 230 | is_healthy_weighted = self.reward_config.weights_is_healthy * is_healthy 231 | 232 | # Get a single reward. 233 | total_reward = ctrl_cost_weighted + original_pos_reward_weighted + velocity_weighted + is_healthy_weighted 234 | 235 | return total_reward 236 | 237 | @partial(jax.jit, static_argnums=(0,)) 238 | def is_done(self, state: MjxState) -> jnp.ndarray: 239 | com_height = state.q[2] 240 | min_z, max_z = self.reward_config.height_min_z, self.reward_config.height_max_z 241 | height_condition = jnp.logical_not(jnp.logical_and(min_z < com_height, com_height < max_z)) 242 | return height_condition 243 | 244 | @partial(jax.jit, static_argnums=(0,)) 245 | def get_obs(self, data: MjxState, action: jnp.ndarray) -> jnp.ndarray: 246 | if self._include_c_vals: 247 | obs_components = [ 248 | data.qpos, 249 | data.qvel, 250 | data.cinert[1:].ravel(), 251 | data.cvel[1:].ravel(), 252 | data.qfrc_actuator, 253 | ] 254 | else: 255 | obs_components = [ 256 | data.qpos, 257 | data.qvel, 258 | data.qfrc_actuator, 259 | ] 260 | 261 | return jnp.concatenate(obs_components) 262 | 263 | 264 | def main(args: Sequence[str] | None = None) -> None: 265 | """Runs the environment for a few steps with random actions, for debugging.""" 266 | if args is None: 267 | args = sys.argv[1:] 268 | 269 | try: 270 | import mediapy as media 271 | from tqdm import tqdm 272 | except ImportError: 273 | raise ImportError("Please install `mediapy` and `tqdm` to run this script") 274 | 275 | if not shutil.which("ffmpeg"): 276 | raise ImportError("`ffmpeg` command not found; please make sure the `ffmpeg` command is available") 277 | 278 | config = load_config_from_cli(args) 279 | 280 | env = HumanoidEnv(config) 281 | action_size = env.action_size 282 | logger.info("Initialized environment with action size %d", action_size) 283 | 284 | rng = jax.random.PRNGKey(config.training.seed) 285 | logger.info("Initialized random number generator with seed %d", config.training.seed) 286 | 287 | reset_fn = jax.jit(env.reset) 288 | step_fn = jax.jit(env.step) 289 | 290 | fps = int(1 / env.dt) 291 | max_frames = int(config.visualization.video_length * fps) 292 | rollout: list[MjxState] = [] 293 | logger.info("Starting episode loop") 294 | 295 | for episode in range(config.visualization.num_episodes): 296 | rng, _ = jax.random.split(rng) 297 | env_state: EnvState = reset_fn(rng) 298 | 299 | total_reward = 0 300 | 301 | for _ in tqdm(range(config.visualization.max_steps), desc=f"Episode {episode + 1} Steps", leave=False): 302 | if len(rollout) < config.visualization.video_length * fps: 303 | rollout.append(env_state.pipeline_state) 304 | 305 | rng, action_rng = jax.random.split(rng) 306 | action = jax.random.uniform(action_rng, (action_size,), minval=0, maxval=1.0) 307 | 308 | rng, step_rng = jax.random.split(rng) 309 | env_state: EnvState = step_fn(env_state, action, step_rng) 310 | 
total_reward += env_state.reward 311 | 312 | if env_state.done: 313 | break 314 | 315 | logger.info("Episode %d total reward: %f", episode + 1, total_reward) 316 | 317 | if len(rollout) >= max_frames: 318 | break 319 | 320 | logger.info("Rendering video with %d frames at %d fps", len(rollout), fps) 321 | 322 | images = jnp.array( 323 | env.render( 324 | rollout[:: config.visualization.render_every], 325 | camera=config.visualization.camera_name, 326 | width=config.visualization.width, 327 | height=config.visualization.height, 328 | ) 329 | ) 330 | 331 | logger.info("Video rendered") 332 | media.write_video(config.visualization.video_save_path, images, fps=fps) 333 | logger.info("Video saved to %s", config.visualization.video_save_path) 334 | 335 | 336 | if __name__ == "__main__": 337 | # python environment.py 338 | main() 339 | -------------------------------------------------------------------------------- /minppo/train.py: -------------------------------------------------------------------------------- 1 | """Train a model with a specified environment module.""" 2 | 3 | import logging 4 | import os 5 | import pickle 6 | import sys 7 | from typing import Any, Callable, NamedTuple, Sequence 8 | 9 | import distrax 10 | import flax.linen as nn 11 | import jax 12 | import jax.numpy as jnp 13 | import numpy as np 14 | import optax 15 | from brax.envs import State 16 | from flax.core import FrozenDict 17 | from flax.linen.initializers import constant, orthogonal 18 | from flax.training.train_state import TrainState 19 | 20 | from minppo.config import Config, load_config_from_cli 21 | from minppo.env import HumanoidEnv 22 | 23 | logger = logging.getLogger(__name__) 24 | 25 | 26 | class Memory(NamedTuple): 27 | done: jnp.ndarray 28 | action: jnp.ndarray 29 | value: jnp.ndarray 30 | reward: jnp.ndarray 31 | log_prob: jnp.ndarray 32 | obs: jnp.ndarray 33 | info: Any 34 | 35 | 36 | class RunnerState(NamedTuple): 37 | train_state: TrainState 38 | env_state: State 39 | last_obs: jnp.ndarray 40 | rng: jnp.ndarray 41 | 42 | 43 | class UpdateState(NamedTuple): 44 | train_state: TrainState 45 | mem_batch: "Memory" 46 | advantages: jnp.ndarray 47 | targets: jnp.ndarray 48 | rng: jnp.ndarray 49 | 50 | 51 | class TrainOutput(NamedTuple): 52 | runner_state: RunnerState 53 | metrics: Any 54 | 55 | 56 | class MLP(nn.Module): 57 | features: Sequence[int] 58 | use_tanh: bool = True 59 | 60 | @nn.compact 61 | def __call__(self, x: jnp.ndarray) -> jnp.ndarray: 62 | for feat in self.features[:-1]: 63 | x = nn.Dense(feat, kernel_init=orthogonal(np.sqrt(2)), bias_init=constant(0.0))(x) 64 | if self.use_tanh: 65 | x = nn.tanh(x) 66 | else: 67 | x = nn.relu(x) 68 | return nn.Dense(self.features[-1], kernel_init=orthogonal(0.01), bias_init=constant(0.0))(x) 69 | 70 | 71 | class ActorCritic(nn.Module): 72 | num_layers: int 73 | hidden_size: int 74 | action_dim: int 75 | use_tanh: bool = True 76 | 77 | @nn.compact 78 | def __call__(self, x: jnp.ndarray) -> tuple[distrax.Distribution, jnp.ndarray]: 79 | actor_mean = MLP([self.hidden_size] * self.num_layers + [self.action_dim], use_tanh=self.use_tanh)(x) 80 | actor_logtstd = self.param("log_std", nn.initializers.zeros, (self.action_dim,)) 81 | pi = distrax.MultivariateNormalDiag(actor_mean, jnp.exp(actor_logtstd)) 82 | critic = MLP([self.hidden_size] * self.num_layers + [1], use_tanh=False)(x) 83 | return pi, jnp.squeeze(critic, axis=-1) 84 | 85 | 86 | def save_model(params: FrozenDict, filename: str) -> None: 87 | os.makedirs(os.path.dirname(filename), exist_ok=True) 88 | 
with open(filename, "wb") as f: 89 | pickle.dump(params, f) 90 | 91 | 92 | def make_train(config: Config) -> Callable[[jnp.ndarray], TrainOutput]: 93 | num_updates = config.training.total_timesteps // config.training.num_steps // config.training.num_envs 94 | minibatch_size = config.training.num_envs * config.training.num_steps // config.training.num_minibatches 95 | 96 | env = HumanoidEnv(config) 97 | 98 | def linear_schedule(count: int) -> float: 99 | # Linear learning rate annealing 100 | frac = 1.0 - (count // (minibatch_size * config.training.update_epochs)) / num_updates 101 | return config.training.lr * frac 102 | 103 | def train(rng: jnp.ndarray) -> TrainOutput: 104 | network = ActorCritic( 105 | num_layers=config.model.num_layers, 106 | hidden_size=config.model.hidden_size, 107 | action_dim=env.action_size, 108 | use_tanh=config.model.use_tanh, 109 | ) 110 | rng, _rng = jax.random.split(rng) 111 | init_x = jnp.zeros(env.observation_size) 112 | network_params = network.init(_rng, init_x) 113 | 114 | # Set up optimizer with gradient clipping and optional learning rate annealing 115 | if config.training.anneal_lr: 116 | tx = optax.chain( 117 | optax.clip_by_global_norm(config.opt.max_grad_norm), 118 | optax.adam(learning_rate=linear_schedule, eps=1e-5), 119 | ) 120 | else: 121 | tx = optax.chain( 122 | optax.clip_by_global_norm(config.opt.max_grad_norm), 123 | optax.adam(config.opt.lr, eps=1e-5), 124 | ) 125 | 126 | train_state = TrainState.create( 127 | apply_fn=network.apply, 128 | params=network_params, 129 | tx=tx, 130 | ) 131 | 132 | # JIT-compile environment functions for performance 133 | @jax.jit 134 | def reset_fn(rng: jnp.ndarray) -> State: 135 | rngs = jax.random.split(rng, config.training.num_envs) 136 | return jax.vmap(env.reset)(rngs) 137 | 138 | @jax.jit 139 | def step_fn(states: State, actions: jnp.ndarray, rng: jnp.ndarray) -> State: 140 | return jax.vmap(env.step)(states, actions, rng) 141 | 142 | rng, reset_rng = jax.random.split(rng) 143 | env_state = reset_fn(jnp.array(reset_rng)) 144 | obs = env_state.obs 145 | 146 | def _update_step( 147 | runner_state: RunnerState, 148 | unused: Memory, 149 | ) -> tuple[RunnerState, Any]: 150 | def _env_step( 151 | runner_state: RunnerState, 152 | unused: Memory, 153 | ) -> tuple[RunnerState, Memory]: 154 | train_state, env_state, last_obs, rng = runner_state 155 | 156 | # Sample actions from the policy and evaluate the value function 157 | pi, value = network.apply(train_state.params, last_obs) 158 | rng, action_rng = jax.random.split(rng) 159 | action = pi.sample(seed=action_rng) 160 | log_prob = pi.log_prob(action) 161 | 162 | # Step the environment 163 | rng, step_rng = jax.random.split(rng) 164 | step_rngs = jax.random.split(step_rng, config.training.num_envs) 165 | env_state: State = step_fn(env_state, action, step_rngs) 166 | 167 | obs = env_state.obs 168 | reward = env_state.reward 169 | done = env_state.done 170 | info = env_state.metrics 171 | 172 | # Store experience for later use in PPO updates 173 | memory = Memory(done, action, value, reward, log_prob, last_obs, info) 174 | runner_state = RunnerState(train_state, env_state, obs, rng) 175 | 176 | return runner_state, memory 177 | 178 | # Collect experience for multiple steps 179 | runner_state, mem_batch = jax.lax.scan(_env_step, runner_state, None, config.rl.num_env_steps) 180 | 181 | # Calculate advantages using Generalized Advantage Estimation (GAE) 182 | _, last_val = network.apply(runner_state.train_state.params, runner_state.last_obs) 183 | last_val = 
jnp.array(last_val) 184 | 185 | def _calculate_gae(mem_batch: Memory, last_val: jnp.ndarray) -> tuple[jnp.ndarray, jnp.ndarray]: 186 | def _get_advantages( 187 | gae_and_next_value: tuple[jnp.ndarray, jnp.ndarray], memory: Memory 188 | ) -> tuple[tuple[jnp.ndarray, jnp.ndarray], jnp.ndarray]: 189 | gae, next_value = gae_and_next_value 190 | done, value, reward = memory.done, memory.value, memory.reward 191 | 192 | # Calculate TD error and GAE 193 | delta = reward + config.rl.gamma * next_value * (1 - done) - value 194 | gae = delta + config.rl.gamma * config.rl.gae_lambda * (1 - done) * gae 195 | return (gae, value), gae 196 | 197 | # Reverse-order scan to efficiently compute GAE 198 | _, advantages = jax.lax.scan( 199 | _get_advantages, 200 | (jnp.zeros_like(last_val), last_val), 201 | mem_batch, 202 | reverse=True, 203 | unroll=16, 204 | ) 205 | return advantages, advantages + mem_batch.value 206 | 207 | advantages, targets = _calculate_gae(mem_batch, last_val) 208 | 209 | def _update_epoch( 210 | update_state: UpdateState, 211 | unused: tuple[jnp.ndarray, jnp.ndarray], 212 | ) -> tuple[UpdateState, Any]: 213 | def _update_minibatch( 214 | train_state: TrainState, batch_info: tuple[Memory, jnp.ndarray, jnp.ndarray] 215 | ) -> tuple[TrainState, Any]: 216 | mem_batch, advantages, targets = batch_info 217 | 218 | def _loss_fn( 219 | params: FrozenDict, mem_batch: Memory, gae: jnp.ndarray, targets: jnp.ndarray 220 | ) -> tuple[jnp.ndarray, tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]]: 221 | # Recompute values to calculate losses 222 | pi, value = network.apply(params, mem_batch.obs) 223 | log_prob = pi.log_prob(mem_batch.action) 224 | 225 | # Compute value function loss 226 | value_pred_clipped = mem_batch.value + (value - mem_batch.value).clip( 227 | -config.rl.clip_eps, config.rl.clip_eps 228 | ) 229 | value_losses = jnp.square(value - targets) 230 | value_losses_clipped = jnp.square(value_pred_clipped - targets) 231 | value_loss = 0.5 * jnp.maximum(value_losses, value_losses_clipped).mean() 232 | 233 | # Compute policy loss using PPO clipped objective 234 | ratio = jnp.exp(log_prob - mem_batch.log_prob) 235 | gae = (gae - gae.mean()) / (gae.std() + 1e-8) 236 | loss_actor1 = ratio * gae 237 | loss_actor2 = jnp.clip(ratio, 1.0 - config.rl.clip_eps, 1.0 + config.rl.clip_eps) * gae 238 | loss_actor = -jnp.minimum(loss_actor1, loss_actor2) 239 | loss_actor = loss_actor.mean() 240 | entropy = pi.entropy().mean() 241 | 242 | total_loss = loss_actor + config.rl.vf_coef * value_loss - config.rl.ent_coef * entropy 243 | return total_loss, (value_loss, loss_actor, entropy) 244 | 245 | # Compute gradients and update model 246 | grad_fn = jax.value_and_grad(_loss_fn, has_aux=True) 247 | total_loss, grads = grad_fn(train_state.params, mem_batch, advantages, targets) 248 | train_state = train_state.apply_gradients(grads=grads) 249 | return train_state, total_loss 250 | 251 | train_state, mem_batch, advantages, targets, rng = update_state 252 | rng, _rng = jax.random.split(rng) 253 | batch_size = minibatch_size * config.training.num_minibatches 254 | if batch_size != config.training.num_steps * config.training.num_envs: 255 | raise ValueError("`batch_size` must be equal to `num_steps * num_envs`") 256 | 257 | # Shuffle and organize data into minibatches 258 | permutation = jax.random.permutation(_rng, batch_size) 259 | batch = (mem_batch, advantages, targets) 260 | batch = jax.tree_util.tree_map(lambda x: x.reshape((batch_size,) + x.shape[2:]), batch) 261 | shuffled_batch = jax.tree_util.tree_map(lambda 
x: jnp.take(x, permutation, axis=0), batch) 262 | minibatches = jax.tree_util.tree_map( 263 | lambda x: jnp.reshape(x, [config.training.num_minibatches, -1] + list(x.shape[1:])), 264 | shuffled_batch, 265 | ) 266 | 267 | # Update model for each minibatch 268 | train_state, total_loss = jax.lax.scan(_update_minibatch, train_state, minibatches) 269 | update_state = UpdateState(train_state, mem_batch, advantages, targets, rng) 270 | return update_state, total_loss 271 | 272 | # Perform multiple epochs of updates on collected data 273 | update_state = UpdateState(runner_state.train_state, mem_batch, advantages, targets, runner_state.rng) 274 | update_state, loss_info = jax.lax.scan(_update_epoch, update_state, None, config.training.update_epochs) 275 | 276 | runner_state = RunnerState( 277 | train_state=update_state.train_state, 278 | env_state=runner_state.env_state, 279 | last_obs=runner_state.last_obs, 280 | rng=update_state.rng, 281 | ) 282 | 283 | return runner_state, mem_batch.info 284 | 285 | rng, _rng = jax.random.split(rng) 286 | runner_state = RunnerState(train_state, env_state, obs, _rng) 287 | runner_state, metric = jax.lax.scan(_update_step, runner_state, None, num_updates) 288 | 289 | return TrainOutput(runner_state=runner_state, metrics=metric) 290 | 291 | return train 292 | 293 | 294 | def main(args: Sequence[str] | None = None) -> None: 295 | logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s") 296 | 297 | if args is None: 298 | args = sys.argv[1:] 299 | 300 | config = load_config_from_cli(args) 301 | logger.info("Configuration loaded") 302 | 303 | rng = jax.random.PRNGKey(config.training.seed) 304 | logger.info(f"Random seed set to {config.training.seed}") 305 | 306 | train_jit = jax.jit(make_train(config)) 307 | logger.info("Training function compiled with JAX") 308 | 309 | logger.info("Starting training...") 310 | out = train_jit(rng) 311 | logger.info("Training completed") 312 | 313 | logger.info(f"Saving model to {config.training.model_save_path}") 314 | save_model(out.runner_state.train_state.params, config.training.model_save_path) 315 | logger.info("Model saved successfully") 316 | 317 | 318 | if __name__ == "__main__": 319 | main() 320 | --------------------------------------------------------------------------------
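To make the advantage computation in `minppo/train.py` easier to follow, here is a minimal standalone sketch of the recursion that `_calculate_gae` runs with `jax.lax.scan`. The rollout numbers below are made up purely for illustration; `gamma` and `gae_lambda` mirror the defaults in `ReinforcementLearningConfig`:

```python
import jax.numpy as jnp

# Hypothetical 4-step rollout for a single environment (all numbers made up).
rewards = jnp.array([1.0, 0.5, 0.0, 1.5])
values = jnp.array([0.8, 0.7, 0.6, 0.5])  # critic values V(s_t)
dones = jnp.array([0.0, 0.0, 1.0, 0.0])   # episode boundary after the third step
last_val = 0.4                            # bootstrap value V(s_T) after the rollout

gamma, gae_lambda = 0.99, 0.95            # defaults from ReinforcementLearningConfig

# Same recursion as `_get_advantages`, written as an explicit reverse-time loop:
#   delta_t = r_t + gamma * V(s_{t+1}) * (1 - done_t) - V(s_t)
#   A_t     = delta_t + gamma * lambda * (1 - done_t) * A_{t+1}
gae, next_value = 0.0, last_val
advantages = []
for t in reversed(range(len(rewards))):
    delta = rewards[t] + gamma * next_value * (1 - dones[t]) - values[t]
    gae = delta + gamma * gae_lambda * (1 - dones[t]) * gae
    next_value = values[t]
    advantages.append(gae)
advantages = jnp.array(advantages[::-1])

# The critic regression targets used in the PPO loss are advantages + values.
targets = advantages + values
print(advantages)
print(targets)
```

The `jax.lax.scan(..., reverse=True)` call in `train.py` computes exactly this recursion, vectorized across all environments.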