├── .gitignore ├── rocket_learn ├── __init__.py ├── agent │ ├── __init__.py │ ├── pretrained_agents │ │ ├── nexto │ │ │ ├── nexto-model.pt │ │ │ ├── nexto_v2.py │ │ │ └── nexto_v2_obs.py │ │ ├── necto │ │ │ ├── necto-model-10Y.pt │ │ │ ├── necto-model-20Y.pt │ │ │ ├── necto-model-30Y.pt │ │ │ ├── necto_v1.py │ │ │ └── necto_v1_obs.py │ │ └── human_agent.py │ ├── actor_critic_agent.py │ ├── policy.py │ ├── pretrained_policy.py │ └── discrete_policy.py ├── utils │ ├── __init__.py │ ├── stat_trackers │ │ ├── __init__.py │ │ ├── stat_tracker.py │ │ └── common_trackers.py │ ├── truncated_condition.py │ ├── dynamic_gamemode_setter.py │ ├── batched_obs_builder.py │ ├── util.py │ ├── gamestate_encoding.py │ ├── scoreboard.py │ └── generate_episode.py ├── rollout_generator │ ├── __init__.py │ ├── redis │ │ ├── __init__.py │ │ ├── utils.py │ │ ├── redis_rollout_generator.py │ │ └── redis_rollout_worker.py │ ├── base_rollout_generator.py │ ├── simple_rollout_generator.py │ └── api_rollout_generator.py ├── learner.py ├── simple_agents.py ├── experience_buffer.py ├── agent.py └── ppo.py ├── setup.py ├── docs ├── sb3_to_rocketlearn_transition.txt ├── troubleshooting.txt └── network_setup_readme.txt ├── examples ├── default │ ├── sb3_to_rocketlearn_transition.txt │ ├── worker.py │ └── learner.py ├── human_trainer │ └── worker_with_human_trainer.py ├── pretrained_agent │ └── worker_with_pretrained_agent.py └── loading │ └── learner.py └── README.md /.gitignore: -------------------------------------------------------------------------------- 1 | *__pycache__* -------------------------------------------------------------------------------- /rocket_learn/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /rocket_learn/agent/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /rocket_learn/utils/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /rocket_learn/rollout_generator/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /rocket_learn/rollout_generator/redis/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /rocket_learn/utils/stat_trackers/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /rocket_learn/agent/pretrained_agents/nexto/nexto-model.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Rolv-Arild/rocket-learn/HEAD/rocket_learn/agent/pretrained_agents/nexto/nexto-model.pt -------------------------------------------------------------------------------- /rocket_learn/agent/pretrained_agents/necto/necto-model-10Y.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Rolv-Arild/rocket-learn/HEAD/rocket_learn/agent/pretrained_agents/necto/necto-model-10Y.pt 
-------------------------------------------------------------------------------- /rocket_learn/agent/pretrained_agents/necto/necto-model-20Y.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Rolv-Arild/rocket-learn/HEAD/rocket_learn/agent/pretrained_agents/necto/necto-model-20Y.pt -------------------------------------------------------------------------------- /rocket_learn/agent/pretrained_agents/necto/necto-model-30Y.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Rolv-Arild/rocket-learn/HEAD/rocket_learn/agent/pretrained_agents/necto/necto-model-30Y.pt -------------------------------------------------------------------------------- /rocket_learn/utils/stat_trackers/stat_tracker.py: -------------------------------------------------------------------------------- 1 | import abc 2 | 3 | import numpy as np 4 | 5 | 6 | class StatTracker(abc.ABC): 7 | def __init__(self, name): 8 | self.name = name 9 | 10 | def reset(self): # Called whenever 11 | raise NotImplementedError 12 | 13 | def update(self, gamestates: np.ndarray, masks: np.ndarray): 14 | raise NotImplementedError 15 | 16 | def get_stat(self): 17 | raise NotImplementedError 18 | -------------------------------------------------------------------------------- /rocket_learn/rollout_generator/base_rollout_generator.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | from typing import Iterator 3 | 4 | from rocket_learn.experience_buffer import ExperienceBuffer 5 | 6 | 7 | class BaseRolloutGenerator(ABC): 8 | @abstractmethod 9 | def generate_rollouts(self) -> Iterator[ExperienceBuffer]: 10 | raise NotImplementedError 11 | 12 | @abstractmethod 13 | def update_parameters(self, new_params): 14 | raise NotImplementedError 15 | -------------------------------------------------------------------------------- /rocket_learn/agent/actor_critic_agent.py: -------------------------------------------------------------------------------- 1 | import torch as th 2 | from torch import nn 3 | 4 | from rocket_learn.agent.policy import Policy 5 | 6 | 7 | class ActorCriticAgent(nn.Module): 8 | def __init__(self, actor: Policy, critic: nn.Module, optimizer: th.optim.Optimizer): 9 | super().__init__() 10 | self.actor = actor 11 | self.critic = critic 12 | self.optimizer = optimizer 13 | # self.algo = ? 
14 | # TODO self.shared = 15 | 16 | def forward(self, *args, **kwargs): 17 | return self.actor(*args, **kwargs), self.critic(*args, **kwargs) 18 | -------------------------------------------------------------------------------- /rocket_learn/learner.py: -------------------------------------------------------------------------------- 1 | from typing import Any 2 | 3 | import cloudpickle 4 | 5 | 6 | class CloudpickleWrapper: 7 | """ 8 | ** Copied from SB3 ** 9 | Uses cloudpickle to serialize contents (otherwise multiprocessing tries to use pickle) 10 | :param var: the variable you wish to wrap for pickling with cloudpickle 11 | """ 12 | 13 | def __init__(self, var: Any): 14 | self.var = var 15 | 16 | def __getstate__(self) -> Any: 17 | return cloudpickle.dumps(self.var) 18 | 19 | def __setstate__(self, var: Any) -> None: 20 | self.var = cloudpickle.loads(var) -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages 2 | 3 | long_description = """ 4 | 5 | # Rocket Learn 6 | 7 | Learning! 8 | 9 | """ 10 | 11 | setup( 12 | name='rocket_learn', 13 | version='0.2.8', 14 | description='Rocket Learn', 15 | author='Rolv-Arild Braaten, Daniel Downs', 16 | url='https://github.com/Rolv-Arild/rocket-learn', 17 | packages=[package for package in find_packages() if package.startswith("rocket_learn")], 18 | long_description=long_description, 19 | install_requires=['cloudpickle==1.6.0', 'gym', 'torch', 'tqdm', 'trueskill', 20 | 'msgpack_numpy', 'wandb', 'pygame', 'keyboard', 'tabulate'], 21 | ) 22 | -------------------------------------------------------------------------------- /rocket_learn/utils/truncated_condition.py: -------------------------------------------------------------------------------- 1 | from abc import ABC 2 | 3 | from rlgym.utils import TerminalCondition 4 | from rlgym.utils.gamestates import GameState 5 | 6 | 7 | class TruncatedCondition(TerminalCondition, ABC): 8 | def is_truncated(self, current_state: GameState): 9 | raise NotImplementedError 10 | 11 | 12 | class TerminalToTruncatedWrapper(TruncatedCondition): 13 | def __init__(self, condition: TerminalCondition): 14 | super().__init__() 15 | self.condition = condition 16 | 17 | def is_truncated(self, current_state: GameState): 18 | return self.condition.is_terminal(current_state) 19 | 20 | def reset(self, initial_state: GameState): 21 | self.condition.reset(initial_state) 22 | 23 | def is_terminal(self, current_state: GameState) -> bool: 24 | return False 25 | -------------------------------------------------------------------------------- /rocket_learn/utils/dynamic_gamemode_setter.py: -------------------------------------------------------------------------------- 1 | from rlgym.utils import StateSetter 2 | from rlgym.utils.state_setters import StateWrapper 3 | 4 | 5 | class DynamicGMSetter(StateSetter): 6 | def __init__(self, setter: StateSetter): 7 | self.setter = setter 8 | self.blue = 0 9 | self.orange = 0 10 | 11 | def set_team_size(self, blue=None, orange=None): 12 | if blue is not None: 13 | self.blue = blue 14 | if orange is not None: 15 | self.orange = orange 16 | 17 | def build_wrapper(self, max_team_size: int, spawn_opponents: bool) -> StateWrapper: 18 | assert self.blue <= max_team_size and self.orange <= max_team_size 19 | return StateWrapper(self.blue, self.orange) 20 | 21 | def reset(self, state_wrapper: StateWrapper): 22 | self.setter.reset(state_wrapper) 
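# Usage sketch (illustrative only, not part of rocket-learn's API): wrap an existing
# StateSetter and let the rollout code adjust team sizes between resets, e.g.
#
#     setter = DynamicGMSetter(DefaultState())
#     setter.set_team_size(blue=2, orange=2)
#     # the next build_wrapper() call now produces a 2v2 StateWrapper,
#     # provided max_team_size allows it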
23 | -------------------------------------------------------------------------------- /rocket_learn/agent/policy.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | 3 | from torch import nn 4 | 5 | class Policy(nn.Module, ABC): 6 | def __init__(self, deterministic=False): 7 | super().__init__() 8 | self.deterministic = deterministic 9 | 10 | @abstractmethod 11 | def forward(self, *args, **kwargs): raise NotImplementedError 12 | 13 | @abstractmethod 14 | def get_action_distribution(self, obs): raise NotImplementedError 15 | 16 | @staticmethod 17 | @abstractmethod 18 | def sample_action(distribution, deterministic=None): raise NotImplementedError 19 | 20 | @staticmethod 21 | @abstractmethod 22 | def log_prob(distribution, selected_action): raise NotImplementedError 23 | 24 | @staticmethod 25 | @abstractmethod 26 | def entropy(distribution, selected_action): raise NotImplementedError 27 | 28 | @abstractmethod 29 | def env_compatible(self, action): raise NotImplementedError 30 | -------------------------------------------------------------------------------- /rocket_learn/rollout_generator/simple_rollout_generator.py: -------------------------------------------------------------------------------- 1 | from typing import Iterator 2 | 3 | import rlgym 4 | from rocket_learn.agent.policy import Policy 5 | from rocket_learn.experience_buffer import ExperienceBuffer 6 | from rocket_learn.rollout_generator.base_rollout_generator import BaseRolloutGenerator 7 | from rocket_learn.utils.generate_episode import generate_episode 8 | 9 | 10 | class SimpleRolloutGenerator(BaseRolloutGenerator): 11 | def __init__(self, policy: Policy, **make_args): 12 | self.env = rlgym.make(**make_args) 13 | self.policy = policy 14 | self.n_agents = self.env._match.agents 15 | 16 | def generate_rollouts(self) -> Iterator[ExperienceBuffer]: 17 | while True: 18 | # TODO: need to add selfplay agent here? 19 | rollouts, result = generate_episode(self.env, [self.policy] * self.n_agents) 20 | 21 | yield from rollouts 22 | 23 | def update_parameters(self, new_params): 24 | self.policy = new_params 25 | -------------------------------------------------------------------------------- /docs/sb3_to_rocketlearn_transition.txt: -------------------------------------------------------------------------------- 1 | SWITCHING FROM STABLE-BASELINES3 TO ROCKET-LEARN 2 | 3 | When you install rocket-learn, verify that the version of rlgym is compatible with rocket-learn. 4 | Use only the release version of rlgym. Do not use beta of rlgym unless you are attempting to beta test. 5 | When setting up your environment, make sure you do not install rlgym from a cached version on your 6 | machine, and verify that the dll in the bakkesmod plugin folder is accurate. 7 | 8 | SB3 abstracts away several important parts of ML training that rocket-learn does not. 9 | 10 | -your rewards will not be normalized 11 | -your networks will not have orthogonal initialization by default (assuming you use PPO) 12 | 13 | This can drastically affect the results you get and it is not uncommon to not see the same results 14 | in rocket-learn as you did in SB3, at least until you make tweaks. In addition to the major 15 | differences listed above, differences in implementation in learning algorithms can cause large 16 | changes in results. Be prepared to do some extra tweaking as a part of the switch. 
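If you want to mirror SB3's defaults, one option is to apply orthogonal initialization to your
networks yourself before training. The helper below is a minimal sketch, assuming a plain torch
Sequential actor/critic like the ones in the examples; the function name and gain value are
illustrative choices, not something rocket-learn provides (the 107/21 sizes are taken from the
example learner):

from torch import nn

def init_orthogonal(module, gain=2 ** 0.5):
    # orthogonal weights and zero biases for every Linear layer,
    # similar in spirit to SB3's default PPO initialization
    if isinstance(module, nn.Linear):
        nn.init.orthogonal_(module.weight, gain=gain)
        if module.bias is not None:
            nn.init.constant_(module.bias, 0.0)

actor_net = nn.Sequential(nn.Linear(107, 256), nn.ReLU(), nn.Linear(256, 21))
actor_net.apply(init_orthogonal)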
-------------------------------------------------------------------------------- /rocket_learn/rollout_generator/api_rollout_generator.py: -------------------------------------------------------------------------------- 1 | from typing import Iterator 2 | from uuid import uuid4 3 | 4 | from fastapi import FastAPI # noqa 5 | from pydantic import Field 6 | from pydantic.main import Model # noqa 7 | from starlette.middleware.gzip import GZipMiddleware 8 | 9 | from rocket_learn.experience_buffer import ExperienceBuffer 10 | from rocket_learn.rollout_generator.base_rollout_generator import BaseRolloutGenerator 11 | 12 | 13 | app = FastAPI( 14 | title="rocket-learn-api", 15 | version="0.1.0" 16 | ) 17 | 18 | workers = {} 19 | 20 | 21 | @app.post("/worker") 22 | async def create_worker(name: str = "contributor"): 23 | uuid = uuid4() 24 | workers[uuid] = name 25 | return uuid 26 | 27 | @app.get("/matchup") 28 | async def get_matchup(mode: int): 29 | qualities = [...] 30 | 31 | @app.post("/rollout") 32 | async def rollout(obs_rew_probs: bytes, ): 33 | 34 | 35 | 36 | class ApiRolloutGenerator(BaseRolloutGenerator): 37 | def generate_rollouts(self) -> Iterator[ExperienceBuffer]: 38 | pass 39 | 40 | def update_parameters(self, new_params): 41 | pass -------------------------------------------------------------------------------- /examples/default/sb3_to_rocketlearn_transition.txt: -------------------------------------------------------------------------------- 1 | SWITCHING FROM STABLE-BASELINES3 TO ROCKET-LEARN 2 | 3 | When you install rocket-learn, verify that the version of rlgym is compatible with rocket-learn. 4 | Use only the release version of rlgym. Do not use beta of rlgym unless you are attempting to beta test. 5 | When setting up your environment, make sure you do not install rlgym from a cached version on your 6 | machine, and verify that the dll in the bakkesmod plugin folder is accurate. 7 | 8 | SB3 abstracts away several important parts of ML training that rocket-learn does not. 9 | 10 | -your rewards will not be normalized 11 | -your networks will not have orthogonal initialization by default (assuming you use PPO) 12 | 13 | This can drastically affect the results you get and it is not uncommon to not see the same results 14 | in rocket-learn as you did in SB3, at least until you make tweaks. In addition to the major 15 | differences listed above, differences in implementation in learning algorithms can cause large 16 | changes in results. Be prepared to do some extra tweaking as a part of the switch. 
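Similarly, since rewards are passed through unchanged, any normalization you relied on from SB3 has
to be added by hand. Below is a minimal sketch of a running reward normalizer you could call from a
custom reward function; the class name and eps value are illustrative and nothing here is provided
by rocket-learn:

import numpy as np

class RunningRewardNorm:
    # tracks a running mean/variance (Welford's algorithm) and rescales rewards by the running std
    def __init__(self, eps=1e-8):
        self.count, self.mean, self.m2, self.eps = 0, 0.0, 0.0, eps

    def __call__(self, reward):
        self.count += 1
        delta = reward - self.mean
        self.mean += delta / self.count
        self.m2 += delta * (reward - self.mean)
        std = np.sqrt(self.m2 / self.count) if self.count > 1 else 1.0
        return reward / (std + self.eps)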
17 | 18 | 19 | 20 | 21 | -------------------------------------------------------------------------------- /rocket_learn/agent/pretrained_policy.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | from typing import Tuple 3 | from torch import nn 4 | 5 | from rocket_learn.agent.discrete_policy import DiscretePolicy 6 | 7 | from rlgym.utils.gamestates import GameState 8 | 9 | 10 | class HardcodedAgent(ABC): 11 | """ 12 | An external bot prebuilt and imported to be trained against 13 | """ 14 | 15 | @abstractmethod 16 | def act(self, state: GameState, player_index: int): raise NotImplementedError 17 | 18 | 19 | class PretrainedDiscretePolicy(DiscretePolicy, HardcodedAgent): 20 | """ 21 | A rocket-learn discrete policy pretrained and imported to be trained against 22 | 23 | :param obs_builder_func: Function that will generate the correct observation from the gamestate 24 | :param net: policy net 25 | :param shape: action distribution shape 26 | """ 27 | 28 | def __init__(self, obs_builder_func, net: nn.Module, shape: Tuple[int, ...] = (3,) * 5 + (2,) * 3): 29 | super().__init__(net, shape) 30 | self.obs_builder_func = obs_builder_func 31 | 32 | def act(self, state: GameState, player_index): 33 | obs = self.obs_builder_func(state) 34 | dist = self.get_action_distribution(obs) 35 | action_indices = self.sample_action(dist, deterministic=False) 36 | actions = self.env_compatible(action_indices) 37 | 38 | return actions 39 | 40 | 41 | class DemoDriveAgent(HardcodedAgent): 42 | def act(self, state: GameState, player_index: int): 43 | return [2, 1, 1, 0, 0, 0, 0, 0] 44 | 45 | 46 | class DemoKBMDriveAgent(HardcodedAgent): 47 | def act(self, state: GameState, player_index: int): 48 | return [2, 1, 0, 0, 0] 49 | -------------------------------------------------------------------------------- /docs/troubleshooting.txt: -------------------------------------------------------------------------------- 1 | TROUBLESHOOTING COMMON PROBLEMS 2 | 3 | 4 | RuntimeError: mat1 and mat2 shapes cannot be multiplied (AxB and CxD) 5 | 6 | Compare the actor and critic input size to the observation size. They need to be identical. 7 | Remember that changing team size and selfplay can change the observation size. 8 | 9 | 10 | _______________________________________________________________________________________________________ 11 | Blue Screen of Death error (related to wandb) 12 | 13 | 1) Rollback your Nvidia driver to WHQL certified September 2021 14 | 15 | OR 2) Comment out the following lines of code in the wandb repo: 16 | 17 | try: 18 | pynvml.nvmlInit() 19 | self.gpu_count = pynvml.nvmlDeviceGetCount() 20 | except pynvml.NVMLError: 21 | 22 | 23 | _______________________________________________________________________________________________________ 24 | Can't get Redis to work 25 | 26 | -WSL2 is probably the easiest way to get things working. 27 | -Double check that you can ping the redis server locally 28 | 29 | 30 | _______________________________________________________________________________________________________ 31 | There are no errors but changes I'm making don't seem to be affecting anything 32 | 33 | -Double check that observations, rewards, and action parsers are the same on both the learner 34 | and workers. 
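One way to keep them in sync is to define the factory functions once in a shared module and import
it from both learner.py and worker.py. A sketch mirroring the examples (the file name
shared_config.py is just an example, not something rocket-learn provides):

    # shared_config.py
    import numpy
    from rlgym.utils.obs_builders.advanced_obs import AdvancedObs
    from rlgym.utils.reward_functions.default_reward import DefaultReward
    from rlgym.utils.action_parsers.discrete_act import DiscreteAction

    class ExpandAdvancedObs(AdvancedObs):
        # rocket-learn always expects a batch dimension in the built observation
        def build_obs(self, player, state, previous_action):
            return numpy.expand_dims(super().build_obs(player, state, previous_action), 0)

    def obs():
        return ExpandAdvancedObs()

    def rew():
        return DefaultReward()

    def act():
        return DiscreteAction()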
35 | 36 | 37 | 38 | _______________________________________________________________________________________________________ 39 | wandb is not working properly or giving you hint errors in your IDE 40 | 41 | -check that there isn’t a folder created called wandb in your project 42 | -this is created automatically by wandb, you can change the name in the init call 43 | so it doesn’t interfere. 44 | -------------------------------------------------------------------------------- /rocket_learn/utils/batched_obs_builder.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Union, Optional 2 | 3 | import numpy as np 4 | from rlgym.utils import ObsBuilder 5 | from rlgym.utils.gamestates import PlayerData, GameState 6 | 7 | from rocket_learn.utils.gamestate_encoding import encode_gamestate 8 | from rocket_learn.utils.scoreboard import Scoreboard 9 | 10 | 11 | class BatchedObsBuilder(ObsBuilder): 12 | def __init__(self, scoreboard: Optional[Scoreboard] = None): 13 | super().__init__() 14 | self.current_state = None 15 | self.current_obs = None 16 | self.scoreboard = scoreboard 17 | 18 | def batched_build_obs(self, encoded_states: np.ndarray) -> Any: 19 | raise NotImplementedError 20 | 21 | def add_actions(self, obs: Any, previous_actions: np.ndarray, player_index=None): 22 | # Modify current obs to include action 23 | # player_index=None means actions for all players should be provided 24 | raise NotImplementedError 25 | 26 | def _reset(self, initial_state: GameState): 27 | raise NotImplementedError 28 | 29 | def reset(self, initial_state: GameState): 30 | self.current_state = False 31 | self.current_obs = None 32 | if self.scoreboard is not None: 33 | self.scoreboard.reset(initial_state) 34 | self._reset(initial_state) 35 | 36 | def pre_step(self, state: GameState): 37 | if state != self.current_state: 38 | if self.scoreboard is not None: 39 | self.scoreboard.step(state) 40 | self.current_obs = self.batched_build_obs( 41 | np.expand_dims(encode_gamestate(state), axis=0) 42 | ) 43 | self.current_state = state 44 | 45 | def build_obs(self, player: PlayerData, state: GameState, previous_action: np.ndarray) -> Any: 46 | for i, p in enumerate(state.players): 47 | if p == player: 48 | self.add_actions(self.current_obs, previous_action, i) 49 | return self.current_obs[i] 50 | -------------------------------------------------------------------------------- /rocket_learn/simple_agents.py: -------------------------------------------------------------------------------- 1 | from typing import Tuple 2 | 3 | import numpy as np 4 | import torch as th 5 | from torch.distributions import Categorical 6 | 7 | from rocket_learn.agent import BaseAgent 8 | 9 | 10 | class RandomAgent(BaseAgent): 11 | """Does softmax using biases alone""" 12 | 13 | def __init__(self, throttle=(0, 1, 2), steer=(0, 1, 0), pitch=(0, 0, 0), yaw=(0, 0, 0), 14 | roll=(0, 0, 0), jump=(3, 0), boost=(0, 0), handbrake=(3, 0)): 15 | super().__init__() 16 | self.distributions = [ 17 | Categorical(logits=th.as_tensor(logits).float()) 18 | for logits in (throttle, steer, pitch, yaw, roll, jump, boost, handbrake) 19 | ] 20 | 21 | def get_actions(self, observation, deterministic=False) -> np.ndarray: 22 | actions = np.stack([dist.sample() for dist in self.distributions]) 23 | return actions 24 | 25 | def get_action_with_log_prob(self, observation) -> Tuple[np.ndarray, float]: 26 | pass 27 | 28 | def set_model_params(self, params) -> None: 29 | pass 30 | 31 | # def set_model_params(self, params): 
32 | # self.distributions = [ 33 | # Categorical(logits=th.as_tensor(logits).float()) 34 | # for logits in params 35 | # ] 36 | # 37 | # def get_actions(self, observation, deterministic=False): 38 | # actions = np.stack([dist.sample() for dist in self.distributions]) 39 | # return actions 40 | # 41 | # def get_log_prob(self, actions): 42 | # return th.stack( 43 | # [dist.log_prob(action) for dist, action in zip(self.distributions, th.unbind(actions, dim=1))], dim=1 44 | # ).sum(dim=1) 45 | 46 | 47 | # ** This should be in its own file or packaged with PPO ** 48 | 49 | 50 | class NoOpAgent(BaseAgent): 51 | def get_actions(self, observation, deterministic=False): 52 | return th.zeros((8,)) 53 | 54 | def get_log_prob(self, actions): 55 | return 0 56 | 57 | def set_model_params(self, params): 58 | pass 59 | -------------------------------------------------------------------------------- /rocket_learn/experience_buffer.py: -------------------------------------------------------------------------------- 1 | class ExperienceBuffer: 2 | def __init__(self, observations=None, actions=None, rewards=None, dones=None, log_probs=None, infos=None): 3 | self.result = 0 4 | self.observations = [] 5 | self.actions = [] 6 | self.rewards = [] 7 | self.dones = [] 8 | self.log_probs = [] 9 | self.infos = [] 10 | 11 | if observations is not None: 12 | self.observations = observations 13 | 14 | if actions is not None: 15 | self.actions = actions 16 | 17 | if rewards is not None: 18 | self.rewards = rewards 19 | 20 | if dones is not None: 21 | self.dones = dones # TODO Done probably doesn't need to be a list, will always just be false until last? 22 | 23 | if log_probs is not None: 24 | self.log_probs = log_probs 25 | 26 | if infos is not None: 27 | self.infos = infos 28 | 29 | def size(self): 30 | return len(self.rewards) 31 | 32 | def add_step(self, observation, action, reward, done, log_prob, info): 33 | self.observations.append(observation) 34 | self.actions.append(action) 35 | self.rewards.append(reward) 36 | self.dones.append(done) 37 | self.log_probs.append(log_prob) 38 | self.infos.append(info) 39 | 40 | def clear(self): 41 | self.observations = [] 42 | self.actions = [] 43 | self.rewards = [] 44 | self.dones = [] 45 | self.log_probs = [] 46 | self.infos = [] 47 | 48 | def generate_slices(self, batch_size): 49 | for i in range(0, len(self.observations), batch_size): 50 | yield ExperienceBuffer(self.observations[i:i + batch_size], 51 | self.actions[i:i + batch_size], 52 | self.rewards[i:i + batch_size], 53 | self.dones[i:i + batch_size], 54 | self.log_probs[i:i + batch_size], 55 | self.infos[i:i + batch_size]) 56 | -------------------------------------------------------------------------------- /rocket_learn/utils/util.py: -------------------------------------------------------------------------------- 1 | from typing import Any 2 | 3 | import numpy as np 4 | import torch 5 | import torch.distributions 6 | from rlgym.utils.gamestates import GameState, PlayerData 7 | from rlgym.utils.obs_builders import AdvancedObs 8 | from torch import nn 9 | 10 | 11 | def softmax(x): 12 | """Compute softmax values for each sets of scores in x.""" 13 | e_x = np.exp(x - np.max(x)) 14 | return e_x / e_x.sum() 15 | 16 | 17 | class SplitLayer(nn.Module): 18 | def __init__(self, splits=None): 19 | super().__init__() 20 | if splits is not None: 21 | self.splits = splits 22 | else: 23 | self.splits = (3,) * 5 + (2,) * 3 24 | 25 | def forward(self, x): 26 | return torch.split(x, self.splits, dim=-1) 27 | 28 | 29 | # TODO 
AdvancedObs should be supported by default, use stack instead of cat 30 | class ExpandAdvancedObs(AdvancedObs): 31 | def build_obs(self, player: PlayerData, state: GameState, previous_action: np.ndarray) -> Any: 32 | return np.reshape( 33 | super(ExpandAdvancedObs, self).build_obs(player, state, previous_action), 34 | (1, -1) 35 | ) 36 | 37 | 38 | def probability_NvsM(team1_ratings, team2_ratings, env=None): 39 | from trueskill import global_env 40 | # Trueskill extension, source: https://github.com/sublee/trueskill/pull/17 41 | """Calculates the win probability of the first team over the second team. 42 | :param team1_ratings: ratings of the first team participants. 43 | :param team2_ratings: ratings of another team participants. 44 | :param env: the :class:`TrueSkill` object. Defaults to the global 45 | environment. 46 | """ 47 | if env is None: 48 | env = global_env() 49 | 50 | team1_mu = sum(r.mu for r in team1_ratings) 51 | team1_sigma = sum((env.beta ** 2 + r.sigma ** 2) for r in team1_ratings) 52 | team2_mu = sum(r.mu for r in team2_ratings) 53 | team2_sigma = sum((env.beta ** 2 + r.sigma ** 2) for r in team2_ratings) 54 | 55 | x = (team1_mu - team2_mu) / np.sqrt(team1_sigma + team2_sigma) 56 | probability_win_team1 = env.cdf(x) 57 | return probability_win_team1 58 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # rocket-learn 2 | 3 | ## What is rocket-learn? 4 | 5 | rocket-learn is a machine learning framework specifically designed for Rocket League Reinforcement Learning. 6 | It works in conjunction with Rocket League, RLGym, and Bakkesmod. 7 | 8 | ## What features does rocket-learn have? 9 | 10 | 24 | 25 | 26 | ## Should I use rocket-learn? 27 | 28 | You should use Stable Baselines3 (SB3) to make your bot at first. The hardest parts of building a 29 | machine learning bot are 30 | 31 | - understanding how to program 32 | - understanding how machine learning works 33 | - choosing good hyperparameters 34 | - choosing good reward functions 35 | - choosing an action parser 36 | - making a statesetter that puts the bot in the best situations 37 | 38 | SB3 is a great way to figure out those essential parts. Once you have all of those aspects down, rocket-learn 39 | may be a good next step to a better machine learning bot. 40 | 41 | If you *don't* yet have these, rocket-learn will add a large amount of complexity for no added benefit. It's 42 | important to remember that high compute and a tough opponent are less important than good fundamentals of ML. 43 | 44 | ## How do I setup rocket-learn? 45 | 46 | 1) Get [Redis](https://docs.servicestack.net/install-redis-windows) running 47 | 48 | *__Improper Redis setup can leave your computer extremely vulnerable to Bad Guys. 49 | We are not responsible for your computer's safety. We assume you know what you are doing.__* 50 | 51 | 2) Clone the repo 52 | 53 | ``` 54 | git clone https://github.com/Rolv-Arild/rocket-learn.git 55 | ``` 56 | 57 | 3) Start up, in order: 58 | 59 | - the Redis server 60 | - the Learner 61 | - the Workers 62 | 63 | Look at the examples to get up and running 64 | -------------------------------------------------------------------------------- /docs/network_setup_readme.txt: -------------------------------------------------------------------------------- 1 | NETWORK INPUT 2 | 3 | rocket-learn expects both actor and critic networks to have an input dimension equal to observation length. 
4 | If the observation outputs an array of size (1, 150) (note the batch dimension of the observation output), 5 | then the network input should be 150. As an example: 6 | 7 | actor = DiscretePolicy(Sequential( 8 | Linear(150, 256), 9 | ReLU(), 10 | Linear(256, 256), 11 | ReLU(), 12 | Linear(256, total_output), 13 | SplitLayer(splits=split) 14 | ), split) 15 | 16 | 17 | 18 | __________________________________________________________________________________________________ 19 | NETWORK OUTPUT 20 | 21 | rocket-learn expects actor networks to output a set of probablities for each possible action. For example, 22 | the default Discrete Action allows 8 actions, 5 of which are discrete control choices and 3 23 | of which are boolean choices. Because the Discrete control choices can each be -1, 0, or 1 and each 24 | boolean can be True or False, the network must output ((5 * 3) + (3 * 2)) aka 21 total actions. The actions 25 | must then be split into properly sized groups for each actions. 26 | 27 | split = (3, 3, 3, 3, 3, 2, 2, 2) 28 | total_output = sum(split) 29 | 30 | class SplitLayer(nn.Module): 31 | def __init__(self, splits=(3, 3, 3, 3, 3, 2, 2, 2)): 32 | super().__init__() 33 | self.splits = splits 34 | 35 | def forward(self, x): 36 | return torch.split(x, self.splits, dim=-1) 37 | 38 | actor = DiscretePolicy(nn.Sequential( 39 | nn.Linear(INPUT_SIZE, 256), 40 | nn.ReLU(), 41 | nn.Linear(256, total_output), 42 | SplitLayer(split) 43 | ), split) 44 | 45 | As another example, KBM actions allow 2 Discrete controls and 3 boolean controls so the network must 46 | output ((2 * 3) + (3 * 2)) aka 12 total actions 47 | 48 | split = (3, 3, 2, 2, 2) 49 | total_output = sum(split) 50 | 51 | class SplitLayer(nn.Module): 52 | def __init__(self, splits=(3, 3, 3, 3, 3, 2, 2, 2)): 53 | super().__init__() 54 | self.splits = splits 55 | 56 | def forward(self, x): 57 | return torch.split(x, self.splits, dim=-1) 58 | 59 | actor = DiscretePolicy(nn.Sequential( 60 | nn.Linear(INPUT_SIZE, 256), 61 | nn.ReLU(), 62 | nn.Linear(256, total_output), 63 | SplitLayer(split) 64 | ), split) -------------------------------------------------------------------------------- /rocket_learn/agent/discrete_policy.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, List, Tuple 2 | 3 | import numpy as np 4 | import torch as th 5 | from torch import nn 6 | from torch.distributions import Categorical 7 | import torch.nn.functional as F 8 | 9 | from rocket_learn.agent.policy import Policy 10 | 11 | 12 | class DiscretePolicy(Policy): 13 | def __init__(self, net: nn.Module, shape: Tuple[int, ...] 
= (3,) * 5 + (2,) * 3, deterministic=False): 14 | super().__init__(deterministic) 15 | self.net = net 16 | self.shape = shape 17 | 18 | def forward(self, obs): 19 | logits = self.net(obs) 20 | return logits 21 | 22 | def get_action_distribution(self, obs): 23 | if isinstance(obs, np.ndarray): 24 | obs = th.from_numpy(obs).float() 25 | elif isinstance(obs, tuple): 26 | obs = tuple(o if isinstance(o, th.Tensor) else th.from_numpy(o).float() for o in obs) 27 | 28 | logits = self(obs) 29 | 30 | if isinstance(logits, th.Tensor): 31 | logits = (logits,) 32 | 33 | max_shape = max(self.shape) 34 | logits = th.stack( 35 | [ 36 | l 37 | if l.shape[-1] == max_shape 38 | else F.pad(l, pad=(0, max_shape - l.shape[-1]), value=float("-inf")) 39 | for l in logits 40 | ], 41 | dim=1 42 | ) 43 | 44 | return Categorical(logits=logits) 45 | 46 | def sample_action( 47 | self, 48 | distribution: Categorical, 49 | deterministic=None 50 | ): 51 | if deterministic is None: 52 | deterministic = self.deterministic 53 | if deterministic: 54 | action_indices = th.argmax(distribution.logits, dim=-1) 55 | else: 56 | action_indices = distribution.sample() 57 | 58 | return action_indices 59 | 60 | def log_prob(self, distribution: Categorical, selected_action): 61 | log_prob = distribution.log_prob(selected_action).sum(dim=-1) 62 | return log_prob 63 | 64 | def entropy(self, distribution: Categorical, selected_action): 65 | entropy = distribution.entropy().sum(dim=-1) 66 | return entropy 67 | 68 | def env_compatible(self, action): 69 | if isinstance(action, th.Tensor): 70 | action = action.numpy() 71 | return action 72 | -------------------------------------------------------------------------------- /rocket_learn/agent/pretrained_agents/necto/necto_v1.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import numpy as np 4 | 5 | import torch 6 | import torch.nn.functional as F 7 | 8 | from rocket_learn.agent.pretrained_policy import HardcodedAgent 9 | from pretrained_agents.necto.necto_v1_obs import NectoV1Obs 10 | 11 | from rlgym.utils.gamestates import GameState 12 | 13 | import copy 14 | 15 | 16 | class NectoV1(HardcodedAgent): 17 | def __init__(self, model_string, n_players): 18 | cur_dir = os.path.dirname(os.path.realpath(__file__)) 19 | self.actor = torch.jit.load(os.path.join(cur_dir, model_string)) 20 | self.obs_builder = NectoV1Obs(n_players=n_players) 21 | self.previous_action = np.array([0, 0, 0, 0, 0, 0, 0, 0]) 22 | 23 | def act(self, state: GameState, player_index: int): 24 | player = state.players[player_index] 25 | teammates = [p for p in state.players if p.team_num == player.team_num and p != player] 26 | opponents = [p for p in state.players if p.team_num != player.team_num] 27 | 28 | necto_gamestate: GameState = copy.deepcopy(state) 29 | necto_gamestate.players = [player] + teammates + opponents 30 | 31 | self.obs_builder.reset(necto_gamestate) 32 | obs = self.obs_builder.build_obs(player, necto_gamestate, self.previous_action) 33 | 34 | obs = tuple(torch.from_numpy(s).float() for s in obs) 35 | with torch.no_grad(): 36 | out, _ = self.actor(obs) 37 | 38 | max_shape = max(o.shape[-1] for o in out) 39 | logits = torch.stack( 40 | [ 41 | l 42 | if l.shape[-1] == max_shape 43 | else F.pad(l, pad=(0, max_shape - l.shape[-1]), value=float("-inf")) 44 | for l in out 45 | ] 46 | ).swapdims(0, 1).squeeze() 47 | 48 | actions = np.argmax(logits, axis=-1) 49 | 50 | actions = actions.reshape((-1, 5)) 51 | actions[:, 0] = actions[:, 0] - 1 52 | actions[:, 1] = 
actions[:, 1] - 1 53 | 54 | parsed = np.zeros((actions.shape[0], 8)) 55 | parsed[:, 0] = actions[:, 0] # throttle 56 | parsed[:, 1] = actions[:, 1] # steer 57 | parsed[:, 2] = actions[:, 0] # pitch 58 | parsed[:, 3] = actions[:, 1] * (1 - actions[:, 4]) # yaw 59 | parsed[:, 4] = actions[:, 1] * actions[:, 4] # roll 60 | parsed[:, 5] = actions[:, 2] # jump 61 | parsed[:, 6] = actions[:, 3] # boost 62 | parsed[:, 7] = actions[:, 4] # handbrake 63 | 64 | self.previous_action = parsed[0] 65 | return parsed[0] 66 | -------------------------------------------------------------------------------- /examples/human_trainer/worker_with_human_trainer.py: -------------------------------------------------------------------------------- 1 | from typing import Any 2 | import numpy 3 | 4 | from redis import Redis 5 | 6 | from rlgym.envs import Match 7 | from rlgym.utils.gamestates import PlayerData, GameState 8 | from rlgym.utils.terminal_conditions.common_conditions import GoalScoredCondition, TimeoutCondition 9 | from rlgym.utils.reward_functions.default_reward import DefaultReward 10 | from rlgym.utils.state_setters.default_state import DefaultState 11 | from rlgym.utils.obs_builders.advanced_obs import AdvancedObs 12 | from rlgym.utils.action_parsers.discrete_act import DiscreteAction 13 | 14 | from rocket_learn.rollout_generator.redis.redis_rollout_worker import RedisRolloutWorker 15 | from rocket_learn.agent.pretrained_agents.human_agent import HumanAgent 16 | 17 | 18 | # ROCKET-LEARN ALWAYS EXPECTS A BATCH DIMENSION IN THE BUILT OBSERVATION 19 | class ExpandAdvancedObs(AdvancedObs): 20 | def build_obs(self, player: PlayerData, state: GameState, previous_action: numpy.ndarray) -> Any: 21 | obs = super(ExpandAdvancedObs, self).build_obs(player, state, previous_action) 22 | return numpy.expand_dims(obs, 0) 23 | 24 | """ 25 | 26 | Allows the worker to run a human player, letting the AI play against and learn from human interation. 27 | 28 | Important things to note: 29 | 30 | -The human will always be blue due to RLGym camera constraints 31 | -Attempting to run a human trainer and pretrained agents will cause the pretrained agents to be ignored. 32 | They will never show up. 33 | 34 | """ 35 | 36 | if __name__ == "__main__": 37 | match = Match( 38 | game_speed=1, 39 | self_play=True, 40 | team_size=1, 41 | state_setter=DefaultState(), 42 | obs_builder=ExpandAdvancedObs(), 43 | action_parser=DiscreteAction(), 44 | terminal_conditions=[TimeoutCondition(round(2000)), 45 | GoalScoredCondition()], 46 | reward_function=DefaultReward() 47 | ) 48 | 49 | # ALLOW HUMAN CONTROL THROUGH MOUSE AND KEYBOARD OR A CONTROLLER IF ONE IS PLUGGED IN 50 | # -CONTROL BINDINGS ARE CURRENTLY NOT CHANGEABLE 51 | # -CONTROLLER SETUP CURRENTLY EXPECTS AN XBOX 360 CONTROLLER. 
OTHERS WILL WORK BUT PROBABLY NOT WELL 52 | human = HumanAgent() 53 | 54 | r = Redis(host="127.0.0.1", password="you_better_use_a_password") 55 | 56 | # LAUNCH ROCKET LEAGUE AND BEGIN TRAINING 57 | # -human_agent TELLS RLGYM THAT THE FIRST AGENT IS ALWAYS TO BE HUMAN CONTROLLED 58 | # -past_version_prob SPECIFIES HOW OFTEN OLD VERSIONS WILL BE RANDOMLY SELECTED AND TRAINED AGAINST 59 | RedisRolloutWorker(r, "exampleHumanWorker", match, human_agent=human, past_version_prob=.05).run() 60 | -------------------------------------------------------------------------------- /rocket_learn/agent/pretrained_agents/human_agent.py: -------------------------------------------------------------------------------- 1 | import pygame 2 | import keyboard 3 | 4 | from rlgym.utils.gamestates import GameState 5 | 6 | from rocket_learn.agent.pretrained_policy import HardcodedAgent 7 | 8 | class HumanAgent(HardcodedAgent): 9 | def __init__(self): 10 | pygame.init() 11 | self.controller_map = {} 12 | 13 | self.joystick = None 14 | if pygame.joystick.get_count() > 0: 15 | self.joystick = pygame.joystick.Joystick(0) 16 | self.joystick.init() 17 | print("Controller found") 18 | 19 | def controller_actions(self, state): 20 | player = [p for p in state.players if p.team_num == 0][0] 21 | # allow controller to activate 22 | pygame.event.pump() 23 | 24 | jump = self.joystick.get_button(0) 25 | boost = self.joystick.get_button(1) 26 | handbrake = self.joystick.get_button(2) 27 | 28 | throttle = self.joystick.get_axis(5) 29 | throttle = max(0, throttle) 30 | 31 | reverse_throttle = self.joystick.get_axis(4) 32 | reverse_throttle = max(0, reverse_throttle) 33 | 34 | throttle = throttle - reverse_throttle 35 | 36 | steer = self.joystick.get_axis(0) 37 | if abs(steer) < .2: 38 | steer = 0 39 | 40 | pitch = self.joystick.get_axis(1) 41 | if abs(pitch) < .2: 42 | pitch = 0 43 | 44 | yaw = steer 45 | 46 | roll = 0 47 | roll_button = self.joystick.get_button(4) 48 | if roll_button or jump: 49 | roll = steer 50 | 51 | return [throttle, steer, pitch, yaw, roll, jump, boost, handbrake] 52 | 53 | 54 | def kbm_actions(self, state): 55 | player = [p for p in state.players if p.team_num == 0][0] 56 | 57 | throttle = 0 58 | if keyboard.is_pressed('w'): 59 | throttle = 1 60 | if keyboard.is_pressed('s'): 61 | throttle = -1 62 | 63 | steer = 0 64 | if keyboard.is_pressed('d'): 65 | steer = 1 66 | if keyboard.is_pressed('a'): 67 | steer = -1 68 | 69 | pitch = -throttle 70 | 71 | yaw = steer 72 | 73 | roll = 0 74 | if keyboard.is_pressed('e'): 75 | roll = 1 76 | if keyboard.is_pressed('q'): 77 | roll = -1 78 | 79 | jump = 0 80 | if keyboard.is_pressed('f'): 81 | jump = 1 82 | 83 | boost = 0 84 | handbrake = 0 85 | 86 | return [throttle, steer, pitch, yaw, roll, jump, boost, handbrake] 87 | 88 | def act(self, state: GameState, player_index: int): 89 | if self.joystick: 90 | actions = self.controller_actions(state) 91 | else: 92 | actions = self.kbm_actions(state) 93 | 94 | return actions 95 | -------------------------------------------------------------------------------- /examples/default/worker.py: -------------------------------------------------------------------------------- 1 | from typing import Any 2 | import numpy 3 | 4 | from redis import Redis 5 | 6 | from rlgym.envs import Match 7 | from rlgym.utils.gamestates import PlayerData, GameState 8 | from rlgym.utils.terminal_conditions.common_conditions import GoalScoredCondition, TimeoutCondition 9 | from rlgym.utils.reward_functions.default_reward import DefaultReward 10 | from 
rlgym.utils.state_setters.default_state import DefaultState 11 | from rlgym.utils.obs_builders.advanced_obs import AdvancedObs 12 | from rlgym.utils.action_parsers.discrete_act import DiscreteAction 13 | 14 | from rocket_learn.rollout_generator.redis.redis_rollout_worker import RedisRolloutWorker 15 | 16 | 17 | # ROCKET-LEARN ALWAYS EXPECTS A BATCH DIMENSION IN THE BUILT OBSERVATION 18 | class ExpandAdvancedObs(AdvancedObs): 19 | def build_obs(self, player: PlayerData, state: GameState, previous_action: numpy.ndarray) -> Any: 20 | obs = super(ExpandAdvancedObs, self).build_obs(player, state, previous_action) 21 | return numpy.expand_dims(obs, 0) 22 | 23 | 24 | if __name__ == "__main__": 25 | """ 26 | 27 | Starts up a rocket-learn worker process, which plays out a game, sends back game data to the 28 | learner, and receives updated model parameters when available 29 | 30 | """ 31 | 32 | # OPTIONAL ADDITION: 33 | # LIMIT TORCH THREADS TO 1 ON THE WORKERS TO LIMIT TOTAL RESOURCE USAGE 34 | # TRY WITH AND WITHOUT FOR YOUR SPECIFIC HARDWARE 35 | import torch 36 | 37 | torch.set_num_threads(1) 38 | 39 | # BUILD THE ROCKET LEAGUE MATCH THAT WILL USED FOR TRAINING 40 | # -ENSURE OBSERVATION, REWARD, AND ACTION CHOICES ARE THE SAME IN THE WORKER 41 | match = Match( 42 | game_speed=100, 43 | spawn_opponents=True, 44 | team_size=1, 45 | state_setter=DefaultState(), 46 | obs_builder=ExpandAdvancedObs(), 47 | action_parser=DiscreteAction(), 48 | terminal_conditions=[TimeoutCondition(round(2000)), 49 | GoalScoredCondition()], 50 | reward_function=DefaultReward() 51 | ) 52 | 53 | # LINK TO THE REDIS SERVER YOU SHOULD HAVE RUNNING (USE THE SAME PASSWORD YOU SET IN THE REDIS 54 | # CONFIG) 55 | r = Redis(host="127.0.0.1", password="you_better_use_a_password") 56 | 57 | # LAUNCH ROCKET LEAGUE AND BEGIN TRAINING 58 | # -past_version_prob SPECIFIES HOW OFTEN OLD VERSIONS WILL BE RANDOMLY SELECTED AND TRAINED AGAINST 59 | RedisRolloutWorker(r, "example", match, 60 | past_version_prob=.2, 61 | evaluation_prob=0.01, 62 | sigma_target=2, 63 | dynamic_gm=False, 64 | send_obs=True, 65 | streamer_mode=False, 66 | send_gamestates=False, 67 | force_paging=False, 68 | auto_minimize=True, 69 | local_cache_name="example_model_database").run() 70 | -------------------------------------------------------------------------------- /rocket_learn/agent.py: -------------------------------------------------------------------------------- 1 | from abc import abstractmethod, ABC 2 | from typing import List, Optional 3 | 4 | import numpy as np 5 | import torch as th 6 | from torch.distributions import Categorical 7 | 8 | 9 | # class BaseAgent(ABC): 10 | # def __init__(self, index_action_map: Optional[np.ndarray] = None): 11 | # if index_action_map is None: 12 | # self.index_action_map = np.array([ 13 | # [-1., 0., 1.], 14 | # [-1., 0., 1.], 15 | # [-1., 0., 1.], 16 | # [-1., 0., 1.], 17 | # [-1., 0., 1.], 18 | # [0., 1., np.nan], 19 | # [0., 1., np.nan], 20 | # [0., 1., np.nan] 21 | # ]) 22 | # else: 23 | # self.index_action_map = index_action_map 24 | # 25 | # def forward_actor_critic(self, obs): 26 | # raise NotImplementedError 27 | # 28 | # def forward_actor(self, obs): 29 | # return self.forward_actor_critic(obs)[0] 30 | # 31 | # def forward_critic(self, obs): 32 | # return self.forward_actor_critic(obs)[1] 33 | # 34 | # def get_action_distribution(self, obs) -> List[Categorical]: 35 | # if isinstance(obs, np.ndarray): 36 | # obs = th.from_numpy(obs).float() 37 | # elif isinstance(obs, tuple): 38 | # obs = tuple(o if 
isinstance(o, th.Tensor) else th.from_numpy(o).float() for o in obs) 39 | # logits = self.forward_actor(obs) 40 | # 41 | # return [Categorical(logits=logit) for logit in logits] 42 | # 43 | # def get_action_indices(self, distribution: List[Categorical], deterministic=False, include_log_prob=False, 44 | # include_entropy=False): 45 | # if deterministic: 46 | # action_indices = th.stack([th.argmax(dist.logits) for dist in distribution]) 47 | # else: 48 | # action_indices = th.stack([dist.sample() for dist in distribution]) 49 | # 50 | # returns = [action_indices.numpy()] 51 | # if include_log_prob: 52 | # # SOREN NOTE: 53 | # # adding dim=1 is causing it to crash 54 | # 55 | # log_prob = th.stack( 56 | # [dist.log_prob(action) for dist, action in zip(distribution, th.unbind(action_indices, dim=-1))], dim=-1 57 | # ).sum(dim=-1) 58 | # returns.append(log_prob.numpy()) 59 | # if include_entropy: 60 | # entropy = th.stack([dist.entropy() for dist in distribution], dim=1).sum(dim=1) 61 | # returns.append(entropy.numpy()) 62 | # return tuple(returns) 63 | # 64 | # def get_action(self, action_indices) -> np.ndarray: 65 | # return self.index_action_map[np.arange(len(self.index_action_map)), action_indices] 66 | # 67 | # @abstractmethod 68 | # def get_model_params(self): 69 | # raise NotImplementedError 70 | # 71 | # @abstractmethod 72 | # def set_model_params(self, params) -> None: 73 | # raise NotImplementedError 74 | -------------------------------------------------------------------------------- /examples/pretrained_agent/worker_with_pretrained_agent.py: -------------------------------------------------------------------------------- 1 | from typing import Any 2 | import numpy 3 | 4 | from redis import Redis 5 | 6 | from rlgym.envs import Match 7 | from rlgym.utils.gamestates import PlayerData, GameState 8 | from rlgym.utils.terminal_conditions.common_conditions import GoalScoredCondition, TimeoutCondition 9 | from rlgym.utils.reward_functions.default_reward import DefaultReward 10 | from rlgym.utils.state_setters.default_state import DefaultState 11 | from rlgym.utils.obs_builders.advanced_obs import AdvancedObs 12 | from rlgym.utils.action_parsers.discrete_act import DiscreteAction 13 | 14 | from rocket_learn.rollout_generator.redis.redis_rollout_worker import RedisRolloutWorker 15 | from rocket_learn.agent.pretrained_policy import HardcodedAgent 16 | 17 | from rocket_learn.agent.pretrained_agents.necto.necto_v1 import NectoV1 18 | 19 | # ROCKET-LEARN ALWAYS EXPECTS A BATCH DIMENSION IN THE BUILT OBSERVATION 20 | class ExpandAdvancedObs(AdvancedObs): 21 | def build_obs(self, player: PlayerData, state: GameState, previous_action: numpy.ndarray) -> Any: 22 | obs = super(ExpandAdvancedObs, self).build_obs(player, state, previous_action) 23 | return numpy.expand_dims(obs, 0) 24 | 25 | class TotallyDifferentObs(AdvancedObs): 26 | def build_obs(self, player: PlayerData, state: GameState, previous_action: numpy.ndarray) -> Any: 27 | obs = super(ExpandAdvancedObs, self).build_obs(player, state, previous_action) 28 | return numpy.expand_dims(obs, 0) 29 | 30 | 31 | # NEW HARDCODED AGENTS MUST IMPLEMENT THE act() METHOD, WHICH RETURNS AN ARRAY OF ACTIONS USED TO 32 | # CONTROL THE AGENT 33 | class DemoCustomAgent(HardcodedAgent): 34 | def act(self, state: GameState, player_index: int): 35 | player_data = state.players[player_index] 36 | 37 | return [2, 1, 1, 0, 0, 0, 0, 0] 38 | 39 | 40 | 41 | 42 | if __name__ == "__main__": 43 | """ 44 | 45 | Allows the worker to add in already built agents into training 
46 | 47 | Important things to note: 48 | 49 | -RLGym only accepts 1 action parser. All agents will need to have the same parser or use a combination parser 50 | 51 | """ 52 | team_size = 1 53 | 54 | # ENSURE OBSERVATION, REWARD, AND ACTION CHOICES ARE THE SAME IN THE LEARNER 55 | match = Match( 56 | game_speed=100, 57 | self_play=True, 58 | team_size=team_size, 59 | state_setter=DefaultState(), 60 | obs_builder=ExpandAdvancedObs(), 61 | action_parser=DiscreteAction(), 62 | terminal_conditions=[TimeoutCondition(round(2000)), 63 | GoalScoredCondition()], 64 | reward_function=DefaultReward() 65 | ) 66 | 67 | 68 | # AT THE MOMENT, THE AVAILABLE NECTO MODELS ARE: 69 | # necto-model-10Y.pt 70 | # necto-model-20Y.pt 71 | # necto-model-30Y.pt 72 | model_name = "necto-model-30Y.pt" 73 | nectov1 = NectoV1(model_string=model_name, n_players=team_size*2) 74 | 75 | demo_hardcoded_agent = DemoCustomAgent() 76 | 77 | #EACH AGENT AND THEIR PROBABILITY OF OCCURRENCE 78 | pretrained_agents = {demo_hardcoded_agent: .05, nectov1: .05} 79 | 80 | r = Redis(host="127.0.0.1", password="you_better_use_a_password") 81 | 82 | 83 | RedisRolloutWorker(r, "examplePretrainedWorker", match, pretrained_agents=pretrained_agents, past_version_prob=.05).run() 84 | -------------------------------------------------------------------------------- /rocket_learn/agent/pretrained_agents/nexto/nexto_v2.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import numpy as np 4 | 5 | import torch 6 | import torch.nn.functional as F 7 | 8 | from rocket_learn.agent.pretrained_policy import HardcodedAgent 9 | from pretrained_agents.nexto.nexto_v2_obs import Nexto_V2_ObsBuilder 10 | 11 | from rlgym.utils.gamestates import GameState 12 | 13 | import copy 14 | 15 | class NextoV2(HardcodedAgent): 16 | def __init__(self, model_string, n_players): 17 | cur_dir = os.path.dirname(os.path.realpath(__file__)) 18 | self.actor = torch.jit.load(os.path.join(cur_dir, model_string)) 19 | self.obs_builder = Nexto_V2_ObsBuilder(n_players=n_players) 20 | self.previous_action = np.array([0, 0, 0, 0, 0, 0, 0, 0]) 21 | 22 | self._lookup_table = self.make_lookup_table() 23 | 24 | @staticmethod 25 | def make_lookup_table(): 26 | actions = [] 27 | # Ground 28 | for throttle in (-1, 0, 1): 29 | for steer in (-1, 0, 1): 30 | for boost in (0, 1): 31 | for handbrake in (0, 1): 32 | if boost == 1 and throttle != 1: 33 | continue 34 | actions.append([throttle or boost, steer, 0, steer, 0, 0, boost, handbrake]) 35 | # Aerial 36 | for pitch in (-1, 0, 1): 37 | for yaw in (-1, 0, 1): 38 | for roll in (-1, 0, 1): 39 | for jump in (0, 1): 40 | for boost in (0, 1): 41 | if jump == 1 and yaw != 0: # Only need roll for sideflip 42 | continue 43 | if pitch == roll == jump == 0: # Duplicate with ground 44 | continue 45 | # Enable handbrake for potential wavedashes 46 | handbrake = jump == 1 and (pitch != 0 or yaw != 0 or roll != 0) 47 | actions.append([boost, yaw, pitch, yaw, roll, jump, boost, handbrake]) 48 | actions = np.array(actions) 49 | return actions 50 | 51 | def act(self, state: GameState, player_index: int): 52 | player = state.players[player_index] 53 | teammates = [p for p in state.players if p.team_num == player.team_num and p != player] 54 | opponents = [p for p in state.players if p.team_num != player.team_num] 55 | necto_gamestate: GameState = copy.deepcopy(state) 56 | necto_gamestate.players = [player] + teammates + opponents 57 | 58 | self.obs_builder.reset(necto_gamestate) 59 | obs = 
self.obs_builder.build_obs(player, necto_gamestate, self.previous_action) 60 | 61 | obs = tuple(torch.from_numpy(s).float() for s in obs) 62 | with torch.no_grad(): 63 | out, _ = self.actor(obs) 64 | 65 | out = (out,) 66 | max_shape = max(o.shape[-1] for o in out) 67 | logits = torch.stack( 68 | [ 69 | l 70 | if l.shape[-1] == max_shape 71 | else F.pad(l, pad=(0, max_shape - l.shape[-1]), value=float("-inf")) 72 | for l in out 73 | ], 74 | dim=1 75 | ) 76 | 77 | actions = np.argmax(logits, axis=-1) 78 | parsed = self._lookup_table[actions.numpy().item()] 79 | 80 | self.previous_action = parsed 81 | return parsed 82 | -------------------------------------------------------------------------------- /examples/loading/learner.py: -------------------------------------------------------------------------------- 1 | import os 2 | import wandb 3 | import numpy 4 | from typing import Any 5 | 6 | import torch.jit 7 | from torch.nn import Linear, Sequential 8 | 9 | from redis import Redis 10 | 11 | from rlgym.utils.obs_builders.advanced_obs import AdvancedObs 12 | from rlgym.utils.gamestates import PlayerData, GameState 13 | from rlgym.utils.reward_functions.default_reward import DefaultReward 14 | from rlgym.utils.action_parsers.discrete_act import DiscreteAction 15 | 16 | from rocket_learn.agent.actor_critic_agent import ActorCriticAgent 17 | from rocket_learn.agent.discrete_policy import DiscretePolicy 18 | from rocket_learn.ppo import PPO 19 | from rocket_learn.rollout_generator.redis.redis_rollout_generator import RedisRolloutGenerator 20 | from rocket_learn.utils.util import SplitLayer 21 | 22 | 23 | # rocket-learn always expects a batch dimension in the built observation 24 | class ExpandAdvancedObs(AdvancedObs): 25 | def build_obs(self, player: PlayerData, state: GameState, previous_action: numpy.ndarray) -> Any: 26 | obs = super(ExpandAdvancedObs, self).build_obs(player, state, previous_action) 27 | return numpy.expand_dims(obs, 0) 28 | 29 | 30 | if __name__ == "__main__": 31 | wandb.login(key=os.environ["WANDB_KEY"]) 32 | logger = wandb.init(project="demo", entity="wandb_username") 33 | logger.name = "LOADING_RUN_EXAMPLE" 34 | 35 | redis = Redis(password="you_better_use_a_password") 36 | 37 | 38 | def obs(): 39 | return ExpandAdvancedObs() 40 | 41 | def rew(): 42 | return DefaultReward() 43 | 44 | def act(): 45 | return DiscreteAction() 46 | 47 | 48 | # -clear DELETE REDIS ENTRIES WHEN STARTING UP (SET TO FALSE TO CONTINUE WITH OLD AGENTS) 49 | rollout_gen = RedisRolloutGenerator(redis, obs, rew, act, 50 | logger=logger, 51 | save_every=100, 52 | clear=False) 53 | 54 | critic = Sequential(Linear(107, 128), Linear(128, 64), Linear(64, 32), Linear(32, 1)) 55 | actor = DiscretePolicy( 56 | Sequential(Linear(107, 128), Linear(128, 64), Linear(64, 32), Linear(32, 21), SplitLayer())) 57 | 58 | optim = torch.optim.Adam([ 59 | {"params": actor.parameters(), "lr": 5e-5}, 60 | {"params": critic.parameters(), "lr": 5e-5} 61 | ]) 62 | 63 | agent = ActorCriticAgent(actor=actor, critic=critic, optimizer=optim) 64 | 65 | alg = PPO( 66 | rollout_gen, 67 | agent, 68 | ent_coef=0.01, 69 | n_steps=1_000_000, 70 | batch_size=20_000, 71 | minibatch_size=10_000, 72 | epochs=10, 73 | gamma=599 / 600, 74 | logger=logger, 75 | ) 76 | 77 | 78 | # LOAD A CHECKPOINT THAT WAS PREVIOUSLY SAVED AND CONTINUE TRAINING. 
OPTIONAL PARAMETER ALLOWS YOU 79 | # TO RESTART THE STEP COUNT INSTEAD OF CONTINUING 80 | alg.load("path\\from\\below\\checkpoint.pt") 81 | 82 | 83 | # OPTIONAL: FOR A PRETRAINED NETWORK, FREEZE THE POLICY NETWORK TO ALLOW THE CRITIC TO SETTLE 84 | # commented out here to keep you from accidentally adding it via copy/paste 85 | # alg.freeze_policy(frozen_iterations=200) 86 | 87 | 88 | # BEGIN TRAINING. IT WILL CONTINUE UNTIL MANUALLY STOPPED 89 | # -iterations_per_save SPECIFIES HOW OFTEN CHECKPOINTS ARE SAVED 90 | # -save_dir SPECIFIES WHERE 91 | alg.run(iterations_per_save=100, save_dir="checkpoint_save_directory") 92 | -------------------------------------------------------------------------------- /rocket_learn/utils/gamestate_encoding.py: -------------------------------------------------------------------------------- 1 | from enum import Enum 2 | 3 | from rlgym.utils.gamestates import GameState 4 | 5 | 6 | def encode_gamestate(state: GameState): 7 | state_vals = [0, state.blue_score, state.orange_score] 8 | state_vals += state.boost_pads.tolist() 9 | 10 | for bd in (state.ball, state.inverted_ball): 11 | state_vals += bd.position.tolist() 12 | state_vals += bd.linear_velocity.tolist() 13 | state_vals += bd.angular_velocity.tolist() 14 | 15 | for p in state.players: 16 | state_vals += [p.car_id, p.team_num] 17 | for cd in (p.car_data, p.inverted_car_data): 18 | state_vals += cd.position.tolist() 19 | state_vals += cd.quaternion.tolist() 20 | state_vals += cd.linear_velocity.tolist() 21 | state_vals += cd.angular_velocity.tolist() 22 | state_vals += [ 23 | p.match_goals, 24 | p.match_saves, 25 | p.match_shots, 26 | p.match_demolishes, 27 | p.boost_pickups, 28 | p.is_demoed, 29 | p.on_ground, 30 | p.ball_touched, 31 | p.has_jump, 32 | p.has_flip, 33 | p.boost_amount 34 | ] 35 | return state_vals 36 | 37 | 38 | # Now some constants for easy and consistent querying of gamestate values 39 | class StateConstants: 40 | DUMMY = 0 41 | BLUE_SCORE = 1 42 | ORANGE_SCORE = 2 43 | BOOST_PADS = slice(3, 3 + GameState.BOOST_PADS_LENGTH) 44 | BALL_POSITION = slice(BOOST_PADS.stop, BOOST_PADS.stop + 3) 45 | BALL_LINEAR_VELOCITY = slice(BALL_POSITION.stop, BALL_POSITION.stop + 3) 46 | BALL_ANGULAR_VELOCITY = slice(BALL_LINEAR_VELOCITY.stop, BALL_LINEAR_VELOCITY.stop + 3) 47 | 48 | PLAYERS = slice(BALL_ANGULAR_VELOCITY.stop + 9, None) # Skip inverted data 49 | CAR_IDS = slice(0, None, GameState.PLAYER_INFO_LENGTH) 50 | TEAM_NUMS = slice(CAR_IDS.start + 1, None, GameState.PLAYER_INFO_LENGTH) 51 | CAR_POS_X = slice(TEAM_NUMS.start + 1, None, GameState.PLAYER_INFO_LENGTH) 52 | CAR_POS_Y = slice(CAR_POS_X.start + 1, None, GameState.PLAYER_INFO_LENGTH) 53 | CAR_POS_Z = slice(CAR_POS_Y.start + 1, None, GameState.PLAYER_INFO_LENGTH) 54 | CAR_QUAT_W = slice(CAR_POS_Z.start + 1, None, GameState.PLAYER_INFO_LENGTH) 55 | CAR_QUAT_X = slice(CAR_QUAT_W.start + 1, None, GameState.PLAYER_INFO_LENGTH) 56 | CAR_QUAT_Y = slice(CAR_QUAT_X.start + 1, None, GameState.PLAYER_INFO_LENGTH) 57 | CAR_QUAT_Z = slice(CAR_QUAT_Y.start + 1, None, GameState.PLAYER_INFO_LENGTH) 58 | CAR_LINEAR_VEL_X = slice(CAR_QUAT_Z.start + 1, None, GameState.PLAYER_INFO_LENGTH) 59 | CAR_LINEAR_VEL_Y = slice(CAR_LINEAR_VEL_X.start + 1, None, GameState.PLAYER_INFO_LENGTH) 60 | CAR_LINEAR_VEL_Z = slice(CAR_LINEAR_VEL_Y.start + 1, None, GameState.PLAYER_INFO_LENGTH) 61 | CAR_ANGULAR_VEL_X = slice(CAR_LINEAR_VEL_Z.start + 1, None, GameState.PLAYER_INFO_LENGTH) 62 | CAR_ANGULAR_VEL_Y = slice(CAR_ANGULAR_VEL_X.start + 1, None, 
GameState.PLAYER_INFO_LENGTH) 63 | CAR_ANGULAR_VEL_Z = slice(CAR_ANGULAR_VEL_Y.start + 1, None, GameState.PLAYER_INFO_LENGTH) 64 | MATCH_GOALS = slice(CAR_ANGULAR_VEL_Z.start + 1 + 13, None, GameState.PLAYER_INFO_LENGTH) # Skip inverted data 65 | MATCH_SAVES = slice(MATCH_GOALS.start + 1, None, GameState.PLAYER_INFO_LENGTH) 66 | MATCH_SHOTS = slice(MATCH_SAVES.start + 1, None, GameState.PLAYER_INFO_LENGTH) 67 | MATCH_DEMOLISHES = slice(MATCH_SHOTS.start + 1, None, GameState.PLAYER_INFO_LENGTH) 68 | BOOST_PICKUPS = slice(MATCH_DEMOLISHES.start + 1, None, GameState.PLAYER_INFO_LENGTH) 69 | IS_DEMOED = slice(BOOST_PICKUPS.start + 1, None, GameState.PLAYER_INFO_LENGTH) 70 | ON_GROUND = slice(IS_DEMOED.start + 1, None, GameState.PLAYER_INFO_LENGTH) 71 | BALL_TOUCHED = slice(ON_GROUND.start + 1, None, GameState.PLAYER_INFO_LENGTH) 72 | HAS_JUMP = slice(BALL_TOUCHED.start + 1, None, GameState.PLAYER_INFO_LENGTH) 73 | HAS_FLIP = slice(HAS_JUMP.start + 1, None, GameState.PLAYER_INFO_LENGTH) 74 | BOOST_AMOUNT = slice(HAS_FLIP.start + 1, None, GameState.PLAYER_INFO_LENGTH) 75 | -------------------------------------------------------------------------------- /examples/default/learner.py: -------------------------------------------------------------------------------- 1 | import os 2 | import wandb 3 | import numpy 4 | from typing import Any 5 | 6 | import torch.jit 7 | from torch.nn import Linear, Sequential, ReLU 8 | 9 | from redis import Redis 10 | 11 | from rlgym.utils.obs_builders.advanced_obs import AdvancedObs 12 | from rlgym.utils.gamestates import PlayerData, GameState 13 | from rlgym.utils.reward_functions.default_reward import DefaultReward 14 | from rlgym.utils.action_parsers.discrete_act import DiscreteAction 15 | 16 | from rocket_learn.agent.actor_critic_agent import ActorCriticAgent 17 | from rocket_learn.agent.discrete_policy import DiscretePolicy 18 | from rocket_learn.ppo import PPO 19 | from rocket_learn.rollout_generator.redis.redis_rollout_generator import RedisRolloutGenerator 20 | from rocket_learn.utils.util import SplitLayer 21 | 22 | 23 | # ROCKET-LEARN ALWAYS EXPECTS A BATCH DIMENSION IN THE BUILT OBSERVATION 24 | class ExpandAdvancedObs(AdvancedObs): 25 | def build_obs(self, player: PlayerData, state: GameState, previous_action: numpy.ndarray) -> Any: 26 | obs = super(ExpandAdvancedObs, self).build_obs(player, state, previous_action) 27 | return numpy.expand_dims(obs, 0) 28 | 29 | 30 | if __name__ == "__main__": 31 | """ 32 | 33 | Starts up a rocket-learn learner process, which ingests incoming data, updates parameters 34 | based on results, and sends updated model parameters out to the workers 35 | 36 | """ 37 | 38 | # ROCKET-LEARN USES WANDB WHICH REQUIRES A LOGIN TO USE. YOU CAN SET AN ENVIRONMENTAL VARIABLE 39 | # OR HARDCODE IT IF YOU ARE NOT SHARING YOUR SOURCE FILES 40 | wandb.login(key=os.environ["WANDB_KEY"]) 41 | logger = wandb.init(project="demo", entity="wandb_username") 42 | logger.name = "DEFAULT_LEARNER_EXAMPLE" 43 | 44 | # LINK TO THE REDIS SERVER YOU SHOULD HAVE RUNNING (USE THE SAME PASSWORD YOU SET IN THE REDIS 45 | # CONFIG) 46 | redis = Redis(password="you_better_use_a_password") 47 | 48 | # ** ENSURE OBSERVATION, REWARD, AND ACTION CHOICES ARE THE SAME IN THE WORKER ** 49 | def obs(): 50 | return ExpandAdvancedObs() 51 | 52 | def rew(): 53 | return DefaultReward() 54 | 55 | def act(): 56 | return DiscreteAction() 57 | 58 | 59 | # THE ROLLOUT GENERATOR CAPTURES INCOMING DATA THROUGH REDIS AND PASSES IT TO THE LEARNER. 
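# -THE FIRST ARGUMENT ("demo-bot" BELOW) IS SIMPLY A NAME IDENTIFYING THIS TRAINING RUN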
60 | # -save_every SPECIFIES HOW OFTEN REDIS DATABASE IS BACKED UP TO DISK 61 | # -model_every SPECIFIES HOW OFTEN OLD VERSIONS ARE SAVED TO REDIS. THESE ARE USED FOR TRUESKILL 62 | # COMPARISON AND TRAINING AGAINST PREVIOUS VERSIONS 63 | # -clear DELETE REDIS ENTRIES WHEN STARTING UP (SET TO FALSE TO CONTINUE WITH OLD AGENTS) 64 | rollout_gen = RedisRolloutGenerator("demo-bot", redis, obs, rew, act, 65 | logger=logger, 66 | save_every=100, 67 | model_every=100, 68 | clear=False) 69 | 70 | # ROCKET-LEARN EXPECTS A SET OF DISTRIBUTIONS FOR EACH ACTION FROM THE NETWORK, NOT 71 | # THE ACTIONS THEMSELVES. SEE network_setup.readme.txt FOR MORE INFORMATION 72 | split = (3, 3, 3, 3, 3, 2, 2, 2) 73 | total_output = sum(split) 74 | 75 | # TOTAL SIZE OF THE INPUT DATA 76 | state_dim = 107 77 | 78 | critic = Sequential( 79 | Linear(state_dim, 256), 80 | ReLU(), 81 | Linear(256, 256), 82 | ReLU(), 83 | Linear(256, 1) 84 | ) 85 | 86 | actor = DiscretePolicy(Sequential( 87 | Linear(state_dim, 256), 88 | ReLU(), 89 | Linear(256, 256), 90 | ReLU(), 91 | Linear(256, total_output), 92 | SplitLayer(splits=split) 93 | ), split) 94 | 95 | # CREATE THE OPTIMIZER 96 | optim = torch.optim.Adam([ 97 | {"params": actor.parameters(), "lr": 5e-5}, 98 | {"params": critic.parameters(), "lr": 5e-5} 99 | ]) 100 | 101 | # PPO REQUIRES AN ACTOR/CRITIC AGENT 102 | agent = ActorCriticAgent(actor=actor, critic=critic, optimizer=optim) 103 | 104 | # INSTANTIATE THE PPO TRAINING ALGORITHM 105 | alg = PPO( 106 | rollout_gen, 107 | agent, 108 | ent_coef=0.01, 109 | n_steps=1_000_000, 110 | batch_size=20_000, 111 | minibatch_size=10_000, 112 | epochs=10, 113 | gamma=599 / 600, 114 | clip_range=0.2, 115 | gae_lambda=0.95, 116 | vf_coef=1, 117 | max_grad_norm=0.5, 118 | logger=logger, 119 | device="cuda", 120 | ) 121 | 122 | # BEGIN TRAINING. IT WILL CONTINUE UNTIL MANUALLY STOPPED 123 | # -iterations_per_save SPECIFIES HOW OFTEN CHECKPOINTS ARE SAVED 124 | # -save_dir SPECIFIES WHERE 125 | alg.run(iterations_per_save=100, save_dir="checkpoint_save_directory") 126 | -------------------------------------------------------------------------------- /rocket_learn/agent/pretrained_agents/necto/necto_v1_obs.py: -------------------------------------------------------------------------------- 1 | from typing import Any 2 | 3 | import numpy as np 4 | 5 | from rlgym.utils import ObsBuilder 6 | from rlgym.utils.common_values import BOOST_LOCATIONS, BLUE_TEAM, ORANGE_TEAM 7 | from rlgym.utils.gamestates import GameState, PlayerData 8 | 9 | 10 | class NectoV1Obs(ObsBuilder): 11 | _boost_locations = np.array(BOOST_LOCATIONS) 12 | _invert = np.array([1] * 5 + [-1, -1, 1] * 5 + [1] * 4) 13 | _norm = np.array([1.] 
* 5 + [2300] * 6 + [1] * 6 + [5.5] * 3 + [1] * 4) 14 | 15 | def __init__(self, n_players=6, tick_skip=8): 16 | super().__init__() 17 | self.n_players = n_players 18 | self.demo_timers = None 19 | self.boost_timers = None 20 | self.current_state = None 21 | self.current_qkv = None 22 | self.current_mask = None 23 | self.tick_skip = tick_skip 24 | 25 | def reset(self, initial_state: GameState): 26 | self.demo_timers = np.zeros(self.n_players) 27 | self.boost_timers = np.zeros(len(initial_state.boost_pads)) 28 | # self.current_state = initial_state 29 | 30 | def _maybe_update_obs(self, state: GameState): 31 | if state == self.current_state: # No need to update 32 | return 33 | 34 | if self.boost_timers is None: 35 | self.reset(state) 36 | else: 37 | self.current_state = state 38 | 39 | qkv = np.zeros((1, 1 + self.n_players + len(state.boost_pads), 24)) # Ball, players, boosts 40 | 41 | # Add ball 42 | n = 0 43 | ball = state.ball 44 | qkv[0, 0, 3] = 1 # is_ball 45 | qkv[0, 0, 5:8] = ball.position 46 | qkv[0, 0, 8:11] = ball.linear_velocity 47 | qkv[0, 0, 17:20] = ball.angular_velocity 48 | 49 | # Add players 50 | n += 1 51 | demos = np.zeros(self.n_players) # Which players are currently demoed 52 | for player in state.players: 53 | if player.team_num == BLUE_TEAM: 54 | qkv[0, n, 1] = 1 # is_teammate 55 | else: 56 | qkv[0, n, 2] = 1 # is_opponent 57 | car_data = player.car_data 58 | qkv[0, n, 5:8] = car_data.position 59 | qkv[0, n, 8:11] = car_data.linear_velocity 60 | qkv[0, n, 11:14] = car_data.forward() 61 | qkv[0, n, 14:17] = car_data.up() 62 | qkv[0, n, 17:20] = car_data.angular_velocity 63 | qkv[0, n, 20] = player.boost_amount 64 | # qkv[0, n, 21] = player.is_demoed 65 | demos[n - 1] = player.is_demoed # Keep track for demo timer 66 | qkv[0, n, 22] = player.on_ground 67 | qkv[0, n, 23] = player.has_flip 68 | n += 1 69 | 70 | # Add boost pads 71 | n = 1 + self.n_players 72 | boost_pads = state.boost_pads 73 | qkv[0, n:, 4] = 1 # is_boost 74 | qkv[0, n:, 5:8] = self._boost_locations 75 | qkv[0, n:, 20] = 0.12 + 0.88 * (self._boost_locations[:, 2] > 72) # Boost amount 76 | # qkv[0, n:, 21] = boost_pads 77 | 78 | # Boost and demo timers 79 | new_boost_grabs = (boost_pads == 1) & (self.boost_timers == 0) # New boost grabs since last frame 80 | self.boost_timers[new_boost_grabs] = 0.4 + 0.6 * (self._boost_locations[new_boost_grabs, 2] > 72) 81 | self.boost_timers *= boost_pads # Make sure we have zeros right 82 | qkv[0, 1 + self.n_players:, 21] = self.boost_timers 83 | self.boost_timers -= self.tick_skip / 1200 # Pre-normalized, 120 fps for 10 seconds 84 | self.boost_timers[self.boost_timers < 0] = 0 85 | 86 | new_demos = (demos == 1) & (self.demo_timers == 0) 87 | self.demo_timers[new_demos] = 0.3 88 | self.demo_timers *= demos 89 | qkv[0, 1: 1 + self.n_players, 21] = self.demo_timers 90 | self.demo_timers -= self.tick_skip / 1200 91 | self.demo_timers[self.demo_timers < 0] = 0 92 | 93 | # Store results 94 | self.current_qkv = qkv / self._norm 95 | mask = np.zeros((1, qkv.shape[1])) 96 | mask[0, 1 + len(state.players):1 + self.n_players] = 1 97 | self.current_mask = mask 98 | 99 | def build_obs(self, player: PlayerData, state: GameState, previous_action: np.ndarray) -> Any: 100 | if self.boost_timers is None: 101 | return np.zeros(0) # Obs space autodetect, make Aech happy 102 | self._maybe_update_obs(state) 103 | invert = player.team_num == ORANGE_TEAM 104 | 105 | qkv = self.current_qkv.copy() 106 | mask = self.current_mask.copy() 107 | 108 | main_n = state.players.index(player) + 1 109 | 
qkv[0, main_n, 0] = 1 # is_main 110 | if invert: 111 | qkv[0, :, (1, 2)] = qkv[0, :, (2, 1)] # Swap blue/orange 112 | qkv *= self._invert # Negate x and y values 113 | 114 | # TODO left-right normalization (always pick one side) 115 | 116 | q = qkv[0, main_n, :] 117 | 118 | #print(q) 119 | #print("vs") 120 | #print(previous_action) 121 | q = np.expand_dims(np.concatenate((q, previous_action), axis=0), axis=(0, 1)) 122 | # kv = np.delete(qkv, main_n, axis=0) # Delete main? Watch masking 123 | kv = qkv 124 | 125 | # With EARLPerceiver we can use relative coords+vel(+more?) for key/value tensor, might be smart 126 | kv[0, :, 5:11] -= q[0, 0, 5:11] 127 | return q, kv, mask -------------------------------------------------------------------------------- /rocket_learn/utils/scoreboard.py: -------------------------------------------------------------------------------- 1 | import math 2 | import random 3 | 4 | import numpy as np 5 | from numpy.random import poisson 6 | from rlgym.utils.common_values import BACK_WALL_Y, SIDE_WALL_X, GOAL_HEIGHT 7 | 8 | from rlgym.utils.gamestates import GameState 9 | 10 | TICKS_PER_SECOND = 120 11 | SECONDS_PER_MINUTE = 60 12 | GOALS_PER_MIN = (1, 0.6, 0.45) # Stats from ballchasing, S14 GC (before SSL) 13 | 14 | 15 | class Scoreboard: 16 | def __init__(self, random_resets=True, tick_skip=8, max_time_seconds=300, skip_warning=False): 17 | super().__init__() 18 | self.random_resets = random_resets 19 | self.tick_skip = tick_skip 20 | self.max_time_seconds = max_time_seconds 21 | self.ticks_left = None 22 | self.scoreline = None 23 | self.state = None 24 | if not skip_warning: 25 | print("WARNING: The Scoreboard object overwrites the inverted ball ang.vel. to include scoreboard, " 26 | "make sure you're not using that and instead inverting on your own. 
" 27 | "Call it in your obs builder's pre-step method to use.") 28 | 29 | def reset(self, initial_state: GameState): 30 | self.state = initial_state 31 | players_per_team = len(initial_state.players) // 2 32 | if self.random_resets: 33 | gpm = GOALS_PER_MIN[players_per_team - 1] 34 | mu_full = gpm * self.max_time_seconds / SECONDS_PER_MINUTE 35 | full_game = poisson(mu_full, size=2) 36 | if full_game[1] == full_game[0]: 37 | self.ticks_left = float("inf") # Overtime 38 | b = o = full_game[0].item() 39 | else: 40 | max_ticks = self.max_time_seconds * TICKS_PER_SECOND 41 | self.ticks_left = random.randrange(0, max_ticks) 42 | seconds_spent = self.max_time_seconds - self.ticks_left / TICKS_PER_SECOND 43 | mu_spent = gpm * seconds_spent / SECONDS_PER_MINUTE 44 | b, o = poisson(mu_spent, size=2).tolist() 45 | self.scoreline = b, o 46 | else: 47 | self.scoreline = 0, 0 48 | self.ticks_left = self.max_time_seconds * TICKS_PER_SECOND 49 | self.modify_gamestate(initial_state) 50 | 51 | def step(self, state: GameState, update_scores=True): 52 | if state != self.state: 53 | if state.ball.position[1] != 0: # Don't count during kickoffs 54 | self.ticks_left = max(0, self.ticks_left - self.tick_skip) 55 | 56 | if update_scores: 57 | b, o = self.scoreline 58 | changed = False 59 | if state.blue_score > self.state.blue_score: # Check in case of crash 60 | b += state.blue_score - self.state.blue_score 61 | changed = True 62 | if state.orange_score > self.state.orange_score: 63 | o += state.orange_score - self.state.orange_score 64 | changed = True 65 | tied = b == o 66 | if self.is_overtime(): 67 | if not tied: 68 | self.ticks_left = float("-inf") # Finished 69 | if self.ticks_left <= 0 and (state.ball.position[2] <= 110 or changed): 70 | if tied: 71 | self.ticks_left = float("inf") # Overtime 72 | else: 73 | self.ticks_left = float("-inf") # Finished 74 | self.scoreline = b, o 75 | 76 | self.state = state 77 | self.modify_gamestate(state) 78 | 79 | def modify_gamestate(self, state: GameState): 80 | state.inverted_ball.angular_velocity[:] = *self.scoreline, self.ticks_left 81 | 82 | def is_overtime(self): 83 | return self.ticks_left > 0 and math.isinf(self.ticks_left) 84 | 85 | def is_finished(self): 86 | return self.ticks_left < 0 and math.isinf(self.ticks_left) 87 | 88 | def win_prob(self): 89 | return win_prob(self.state.players // 2, 90 | self.ticks_left * TICKS_PER_SECOND, 91 | self.scoreline[0] - self.scoreline[1]).item() 92 | 93 | 94 | FLOOR_AREA = 4 * BACK_WALL_Y * SIDE_WALL_X - 1152 * 1152 # Subtract corners 95 | GOAL_AREA = GOAL_HEIGHT * 880 96 | 97 | 98 | def win_prob(players_per_team, time_left_seconds, differential): 99 | # Utility function, calculates probability of blue team winning the full game 100 | from scipy.stats import skellam 101 | 102 | players_per_team = np.asarray(players_per_team) 103 | time_left_seconds = np.asarray(time_left_seconds) 104 | differential = np.asarray(differential) 105 | 106 | goal_floor_ratio = GOAL_AREA / (2 * GOAL_AREA + FLOOR_AREA) 107 | 108 | # inverted = np.random.random(differential.shape) > 0.5 109 | # differential[inverted] *= -1 110 | p = np.zeros(differential.shape) 111 | 112 | gpm = np.array(GOALS_PER_MIN)[players_per_team - 1] 113 | zero_seconds = (time_left_seconds <= 0) | np.isinf(time_left_seconds) 114 | 115 | mu_left = gpm * time_left_seconds / SECONDS_PER_MINUTE 116 | mu_left = mu_left[~zero_seconds] 117 | dist_left = skellam(mu_left, mu_left) 118 | 119 | diff_regulation = differential[~zero_seconds] 120 | 121 | # Probability of leading by two or 
more goals at end of regulation 122 | p[zero_seconds & (differential >= 2)] += 1 123 | p[~zero_seconds] += dist_left.cdf(diff_regulation - 2) 124 | 125 | # Probability of being tied at zero seconds and winning in overtime 126 | w = 0.5 127 | p[zero_seconds & (differential == 0)] += w 128 | p[~zero_seconds] += dist_left.pmf(diff_regulation) * w 129 | 130 | # Probability of leading by one at zero seconds, and surviving or getting scored on and winning OT 131 | w = (1 - goal_floor_ratio * 0.5) 132 | p[zero_seconds & (differential == 1)] += w 133 | p[~zero_seconds] += dist_left.pmf(diff_regulation - 1) * w 134 | 135 | # Probability of losing by one at zero seconds, and scoring and winning overtime 136 | w = goal_floor_ratio * 0.5 137 | p[zero_seconds & (differential == -1)] += w 138 | p[~zero_seconds] += dist_left.pmf(diff_regulation + 1) * w 139 | 140 | # p[inverted] = 1 - p[inverted] 141 | 142 | return p 143 | -------------------------------------------------------------------------------- /rocket_learn/rollout_generator/redis/utils.py: -------------------------------------------------------------------------------- 1 | # Constants for consistent key lookup 2 | import pickle 3 | import zlib 4 | from typing import List, Optional, Union, Dict 5 | 6 | import numpy as np 7 | from redis import Redis 8 | from rlgym.utils.gamestates import GameState 9 | from trueskill import Rating 10 | 11 | from rocket_learn.experience_buffer import ExperienceBuffer 12 | from rocket_learn.utils.batched_obs_builder import BatchedObsBuilder 13 | from rocket_learn.utils.gamestate_encoding import encode_gamestate 14 | import msgpack 15 | import msgpack_numpy as m 16 | 17 | QUALITIES = "qualities-{}" 18 | N_UPDATES = "num-updates" 19 | # SAVE_FREQ = "save-freq" 20 | # MODEL_FREQ = "model-freq" 21 | 22 | MODEL_LATEST = "model-latest" 23 | VERSION_LATEST = "model-version" 24 | 25 | ROLLOUTS = "rollout" 26 | OPPONENT_MODELS = "opponent-models" 27 | WORKER_IDS = "worker-ids" 28 | CONTRIBUTORS = "contributors" 29 | LATEST_RATING_ID = "latest-rating-id" 30 | EXPERIENCE_PER_MODE = "experience-per-mode" 31 | _ALL = ( 32 | N_UPDATES, MODEL_LATEST, VERSION_LATEST, ROLLOUTS, OPPONENT_MODELS, 33 | WORKER_IDS, CONTRIBUTORS, LATEST_RATING_ID, EXPERIENCE_PER_MODE) 34 | 35 | m.patch() 36 | 37 | 38 | # Helper methods for easier changing of byte conversion 39 | def _serialize(obj): 40 | return zlib.compress(msgpack.packb(obj)) 41 | 42 | 43 | def _unserialize(obj): 44 | return msgpack.unpackb(zlib.decompress(obj)) 45 | 46 | 47 | def _serialize_model(mdl): 48 | device = next(mdl.parameters()).device # Must be a better way right? 49 | mdl_bytes = pickle.dumps(mdl.cpu()) 50 | mdl.to(device) 51 | return mdl_bytes 52 | 53 | 54 | def _unserialize_model(buf): 55 | agent = pickle.loads(buf) 56 | return agent 57 | 58 | 59 | def get_rating(gamemode: str, model_id: Optional[str], redis: Redis) -> Union[Rating, Dict[str, Rating]]: 60 | """ 61 | Get the rating of a player. 62 | :param gamemode: The game mode to get the rating for. 63 | :param model_id: The id of the model. 64 | :param redis: The redis client. 65 | :return: The rating of the player. 
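    Example (illustrative sketch; the model id below is hypothetical and assumes ratings were
    already stored under the corresponding "qualities-{gamemode}" hash by the rollout generator):

        redis = Redis(password="you_better_use_a_password")
        single = get_rating("1v1", "demo-bot-v100-stochastic", redis)  # one trueskill.Rating
        everything = get_rating("1v1", None, redis)                    # dict of model id -> Rating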
66 | """ 67 | quality_key = QUALITIES.format(gamemode) 68 | if model_id is None: # Return all ratings 69 | return { 70 | k.decode("utf-8"): Rating(*_unserialize(v)) 71 | for k, v in redis.hgetall(quality_key).items() 72 | } 73 | return Rating(*_unserialize(redis.hget(quality_key, model_id))) 74 | 75 | 76 | def encode_buffers(buffers: List[ExperienceBuffer], return_obs=True, return_states=True, return_rewards=True): 77 | res = [] 78 | 79 | if return_states: 80 | states = np.asarray([encode_gamestate(info["state"]) for info in buffers[0].infos] if len(buffers) > 0 else []) 81 | res.append(states) 82 | 83 | if return_obs: 84 | observations = [buffer.observations for buffer in buffers] 85 | res.append(observations) 86 | 87 | if return_rewards: 88 | rewards = np.asarray([buffer.rewards for buffer in buffers]) 89 | res.append(rewards) 90 | 91 | actions = np.asarray([buffer.actions for buffer in buffers]) 92 | log_probs = np.asarray([buffer.log_probs for buffer in buffers]) 93 | dones = np.asarray([buffer.dones for buffer in buffers]) 94 | res.append(actions) 95 | res.append(log_probs) 96 | res.append(dones) 97 | 98 | return res 99 | 100 | 101 | def decode_buffers(enc_buffers, versions, has_obs, has_states, has_rewards, 102 | obs_build_factory=None, rew_func_factory=None, act_parse_factory=None): 103 | assert has_states or has_obs, "Must have at least one of obs or states" 104 | assert has_states or has_rewards, "Must have at least one of rewards or states" 105 | assert not has_obs or has_obs and has_rewards, "Must have both obs and rewards" # TODO obs+no reward? 106 | 107 | i = 0 108 | if has_states: 109 | game_states = enc_buffers[i] 110 | if len(game_states) == 0: 111 | raise RuntimeError 112 | i += 1 113 | else: 114 | game_states = None 115 | if has_obs: 116 | obs = enc_buffers[i] 117 | i += 1 118 | else: 119 | obs = None 120 | if has_rewards: 121 | rewards = enc_buffers[i] 122 | i += 1 123 | # dones = np.zeros_like(rewards, dtype=bool) # TODO: Support for dones? 
124 | # if len(dones) > 0: 125 | # dones[:, -1] = True 126 | else: 127 | rewards = None 128 | # dones = None 129 | actions = enc_buffers[i] 130 | i += 1 131 | log_probs = enc_buffers[i] 132 | i += 1 133 | dones = enc_buffers[i] 134 | i += 1 135 | 136 | if obs is None: 137 | # Reconstruct observations 138 | obs_builder = obs_build_factory() 139 | act_parser = act_parse_factory() 140 | if isinstance(obs_builder, BatchedObsBuilder): 141 | # TODO support states+no rewards 142 | assert game_states is not None and rewards is not None, "Must have both game states and rewards" 143 | obs = obs_builder.batched_build_obs(game_states[:-1]) 144 | prev_actions = act_parser.parse_actions(actions.reshape((-1,) + actions.shape[2:]).copy(), None).reshape( 145 | actions.shape[:2] + (8,)) 146 | prev_actions = np.concatenate((np.zeros((actions.shape[0], 1, 8)), prev_actions[:, :-1]), axis=1) 147 | obs_builder.add_actions(obs, prev_actions) 148 | buffers = [ 149 | ExperienceBuffer(observations=[obs[i]], actions=actions[i], rewards=rewards[i], dones=dones[i], 150 | log_probs=log_probs[i]) 151 | for i in range(len(obs)) 152 | ] 153 | return buffers, game_states 154 | else: # Slow reconstruction, but works for any ObsBuilder 155 | gs_arrays = game_states 156 | game_states = [GameState(gs.tolist()) for gs in game_states] 157 | rew_func = rew_func_factory() 158 | obs_builder.reset(game_states[0]) 159 | rew_func.reset(game_states[0]) 160 | buffers = [ 161 | ExperienceBuffer(infos=[{"state": game_states[0]}]) 162 | for _ in range(len(game_states[0].players)) 163 | ] 164 | 165 | env_actions = [ 166 | act_parser.parse_actions(actions[:, s, :].copy(), game_states[s]) 167 | for s in range(actions.shape[1]) 168 | ] 169 | 170 | obss = [obs_builder.build_obs(p, game_states[0], np.zeros(8)) 171 | for i, p in enumerate(game_states[0].players)] 172 | for s, gs in enumerate(game_states[1:]): 173 | assert len(gs.players) == len(versions) 174 | final = s == len(game_states) - 2 175 | old_obs = obss 176 | obss = [] 177 | i = 0 178 | for version in versions: 179 | if version == 'na': 180 | continue # don't want to rebuild or use prebuilt agents 181 | player = gs.players[i] 182 | 183 | # IF ONLY 1 buffer is returned, need a way to say to discard bad version 184 | 185 | obs = obs_builder.build_obs(player, gs, env_actions[s][i]) 186 | if rewards is None: 187 | if final: 188 | rew = rew_func.get_final_reward(player, gs, env_actions[s][i]) 189 | else: 190 | rew = rew_func.get_reward(player, gs, env_actions[s][i]) 191 | else: 192 | rew = rewards[i][s] 193 | buffers[i].add_step(old_obs[i], actions[i][s], rew, final, log_probs[i][s], {"state": gs}) 194 | obss.append(obs) 195 | i += 1 196 | 197 | return buffers, gs_arrays 198 | else: # We have everything we need 199 | buffers = [] 200 | for i in range(len(obs)): 201 | buffers.append( 202 | ExperienceBuffer(observations=obs[i], 203 | actions=actions[i], 204 | rewards=rewards[i], 205 | dones=dones[i], 206 | log_probs=log_probs[i]) 207 | ) 208 | return buffers, game_states 209 | -------------------------------------------------------------------------------- /rocket_learn/utils/generate_episode.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | 3 | import numpy as np 4 | import torch 5 | from rlgym.gym import Gym 6 | from rlgym.utils.reward_functions.common_rewards import ConstantReward 7 | from rlgym.utils.state_setters import DefaultState 8 | from tqdm import tqdm 9 | 10 | from rocket_learn.agent.policy import Policy 11 | from 
rocket_learn.agent.pretrained_policy import HardcodedAgent 12 | from rocket_learn.experience_buffer import ExperienceBuffer 13 | from rocket_learn.utils.dynamic_gamemode_setter import DynamicGMSetter 14 | from rocket_learn.utils.truncated_condition import TruncatedCondition 15 | 16 | 17 | def generate_episode(env: Gym, policies, evaluate=False, scoreboard=None, progress=False) -> ( 18 | List[ExperienceBuffer], int): 19 | """ 20 | create experience buffer data by interacting with the environment(s) 21 | """ 22 | if progress: 23 | progress = tqdm(unit=" steps") 24 | else: 25 | progress = None 26 | 27 | if evaluate: # Change setup temporarily to play a normal game (approximately) 28 | from rlgym_tools.extra_terminals.game_condition import GameCondition # tools is an optional dependency 29 | terminals = env._match._terminal_conditions # noqa 30 | reward = env._match._reward_fn # noqa 31 | game_condition = GameCondition(tick_skip=env._match._tick_skip, # noqa 32 | seconds_per_goal_forfeit=10 * env._match._team_size, 33 | max_overtime_seconds=300, 34 | max_no_touch_seconds=60) 35 | env._match._terminal_conditions = [game_condition] # noqa 36 | if isinstance(env._match._state_setter, DynamicGMSetter): # noqa 37 | state_setter = env._match._state_setter.setter # noqa 38 | env._match._state_setter.setter = DefaultState() # noqa 39 | else: 40 | state_setter = env._match._state_setter # noqa 41 | env._match._state_setter = DefaultState() # noqa 42 | 43 | env._match._reward_fn = ConstantReward() # noqa Save some cpu cycles 44 | 45 | if scoreboard is not None: 46 | random_resets = scoreboard.random_resets 47 | scoreboard.random_resets = not evaluate 48 | observations, info = env.reset(return_info=True) 49 | result = 0 50 | 51 | last_state = info['state'] # game_state for obs_building of other agents 52 | 53 | latest_policy_indices = [0 if isinstance(p, HardcodedAgent) else 1 for p in policies] 54 | # rollouts for all latest_policies 55 | rollouts = [ 56 | ExperienceBuffer(infos=[info]) 57 | for _ in range(sum(latest_policy_indices)) 58 | ] 59 | 60 | b = o = 0 61 | with torch.no_grad(): 62 | while True: 63 | all_indices = [] 64 | all_actions = [] 65 | all_log_probs = [] 66 | 67 | # if observation isn't a list, make it one so we don't iterate over the observation directly 68 | if not isinstance(observations, list): 69 | observations = [observations] 70 | 71 | if not isinstance(policies[0], HardcodedAgent) and all(policy == policies[0] for policy in policies): 72 | policy = policies[0] 73 | if isinstance(observations[0], tuple): 74 | obs = tuple(np.concatenate([obs[i] for obs in observations], axis=0) 75 | for i in range(len(observations[0]))) 76 | else: 77 | obs = np.concatenate(observations, axis=0) 78 | dist = policy.get_action_distribution(obs) 79 | action_indices = policy.sample_action(dist) 80 | log_probs = policy.log_prob(dist, action_indices) 81 | actions = policy.env_compatible(action_indices) 82 | 83 | all_indices.extend(list(action_indices.numpy())) 84 | all_actions.extend(list(actions)) 85 | all_log_probs.extend(list(log_probs.numpy())) 86 | else: 87 | index = 0 88 | for policy, obs in zip(policies, observations): 89 | if isinstance(policy, HardcodedAgent): 90 | actions = policy.act(last_state, index) 91 | 92 | # make sure output is in correct format 93 | if not isinstance(observations, np.ndarray): 94 | actions = np.array(actions) 95 | 96 | # TODO: add converter that takes normal 8 actions into action space 97 | # actions = env._match._action_parser.convert_to_action_space(actions) 98 | 99 | 
all_indices.append(None) 100 | all_actions.append(actions) 101 | all_log_probs.append(None) 102 | 103 | elif isinstance(policy, Policy): 104 | dist = policy.get_action_distribution(obs) 105 | action_indices = policy.sample_action(dist)[0] 106 | log_probs = policy.log_prob(dist, action_indices).item() 107 | actions = policy.env_compatible(action_indices) 108 | 109 | all_indices.append(action_indices.numpy()) 110 | all_actions.append(actions) 111 | all_log_probs.append(log_probs) 112 | 113 | else: 114 | print(str(type(policy)) + " type use not defined") 115 | assert False 116 | 117 | index += 1 118 | 119 | # to allow different action spaces, pad out short ones to longest length (assume later unpadding in parser) 120 | # length = max([a.shape[0] for a in all_actions]) 121 | # padded_actions = [] 122 | # for a in all_actions: 123 | # action = np.pad(a.astype('float64'), (0, length - a.size), 'constant', constant_values=np.NAN) 124 | # padded_actions.append(action) 125 | # 126 | # all_actions = padded_actions 127 | # TEST OUT ABOVE TO DEAL WITH VARIABLE LENGTH 128 | 129 | all_actions = np.vstack(all_actions) 130 | old_obs = observations 131 | observations, rewards, done, info = env.step(all_actions) 132 | 133 | truncated = False 134 | for terminal in env._match._terminal_conditions: # noqa 135 | if isinstance(terminal, TruncatedCondition): 136 | truncated |= terminal.is_truncated(info["state"]) 137 | 138 | if len(policies) <= 1: 139 | observations, rewards = [observations], [rewards] 140 | 141 | # prune data that belongs to old agents 142 | old_obs = [a for i, a in enumerate(old_obs) if latest_policy_indices[i] == 1] 143 | all_indices = [d for i, d in enumerate(all_indices) if latest_policy_indices[i] == 1] 144 | rewards = [r for i, r in enumerate(rewards) if latest_policy_indices[i] == 1] 145 | all_log_probs = [r for i, r in enumerate(all_log_probs) if latest_policy_indices[i] == 1] 146 | 147 | assert len(old_obs) == len(all_indices), str(len(old_obs)) + " obs, " + str(len(all_indices)) + " ind" 148 | assert len(old_obs) == len(rewards), str(len(old_obs)) + " obs, " + str(len(rewards)) + " ind" 149 | assert len(old_obs) == len(all_log_probs), str(len(old_obs)) + " obs, " + str(len(all_log_probs)) + " ind" 150 | assert len(old_obs) == len(rollouts), str(len(old_obs)) + " obs, " + str(len(rollouts)) + " ind" 151 | 152 | # Might be different if only one agent? 
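            # Note on the flag passed to add_step just below: done + 2 * truncated packs both
            # termination signals into a single value (0 = still running, 1 = terminal,
            # 2 or 3 = truncated), presumably so the learner can bootstrap the value estimate on
            # truncation rather than treating it as a true terminal state.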
153 | if not evaluate: # Evaluation matches can be long, no reason to keep them in memory 154 | for exp_buf, obs, act, rew, log_prob in zip(rollouts, old_obs, all_indices, rewards, all_log_probs): 155 | exp_buf.add_step(obs, act, rew, done + 2 * truncated, log_prob, info) 156 | 157 | if progress is not None: 158 | progress.update() 159 | igt = progress.n * env._match._tick_skip / 120 # noqa 160 | prog_str = f"{igt // 60:02.0f}:{igt % 60:02.0f} IGT" 161 | if evaluate: 162 | prog_str += f", BLUE {b} - {o} ORANGE" 163 | progress.set_postfix_str(prog_str) 164 | 165 | if done or truncated: 166 | result += info["result"] 167 | if info["result"] > 0: 168 | b += 1 169 | elif info["result"] < 0: 170 | o += 1 171 | 172 | if not evaluate: 173 | break 174 | elif game_condition.done: # noqa 175 | break 176 | else: 177 | observations, info = env.reset(return_info=True) 178 | 179 | last_state = info['state'] 180 | 181 | if scoreboard is not None: 182 | scoreboard.random_resets = random_resets # noqa Checked above 183 | 184 | if evaluate: 185 | if isinstance(env._match._state_setter, DynamicGMSetter): # noqa 186 | env._match._state_setter.setter = state_setter # noqa 187 | else: 188 | env._match._state_setter = state_setter # noqa 189 | env._match._terminal_conditions = terminals # noqa 190 | env._match._reward_fn = reward # noqa 191 | return result 192 | 193 | if progress is not None: 194 | progress.close() 195 | 196 | return rollouts, result 197 | -------------------------------------------------------------------------------- /rocket_learn/agent/pretrained_agents/nexto/nexto_v2_obs.py: -------------------------------------------------------------------------------- 1 | from typing import Any 2 | 3 | import numpy as np 4 | from rlgym.utils.gamestates import GameState, PlayerData 5 | 6 | BOOST_LOCATIONS = ( 7 | (0.0, -4240.0, 70.0), 8 | (-1792.0, -4184.0, 70.0), 9 | (1792.0, -4184.0, 70.0), 10 | (-3072.0, -4096.0, 73.0), 11 | (3072.0, -4096.0, 73.0), 12 | (- 940.0, -3308.0, 70.0), 13 | (940.0, -3308.0, 70.0), 14 | (0.0, -2816.0, 70.0), 15 | (-3584.0, -2484.0, 70.0), 16 | (3584.0, -2484.0, 70.0), 17 | (-1788.0, -2300.0, 70.0), 18 | (1788.0, -2300.0, 70.0), 19 | (-2048.0, -1036.0, 70.0), 20 | (0.0, -1024.0, 70.0), 21 | (2048.0, -1036.0, 70.0), 22 | (-3584.0, 0.0, 73.0), 23 | (-1024.0, 0.0, 70.0), 24 | (1024.0, 0.0, 70.0), 25 | (3584.0, 0.0, 73.0), 26 | (-2048.0, 1036.0, 70.0), 27 | (0.0, 1024.0, 70.0), 28 | (2048.0, 1036.0, 70.0), 29 | (-1788.0, 2300.0, 70.0), 30 | (1788.0, 2300.0, 70.0), 31 | (-3584.0, 2484.0, 70.0), 32 | (3584.0, 2484.0, 70.0), 33 | (0.0, 2816.0, 70.0), 34 | (- 940.0, 3310.0, 70.0), 35 | (940.0, 3308.0, 70.0), 36 | (-3072.0, 4096.0, 73.0), 37 | (3072.0, 4096.0, 73.0), 38 | (-1792.0, 4184.0, 70.0), 39 | (1792.0, 4184.0, 70.0), 40 | (0.0, 4240.0, 70.0), 41 | ) 42 | 43 | 44 | def rotation_to_quaternion(m: np.ndarray) -> np.ndarray: 45 | trace = np.trace(m) 46 | q = np.zeros(4) 47 | 48 | if trace > 0: 49 | s = (trace + 1) ** 0.5 50 | q[0] = s * 0.5 51 | s = 0.5 / s 52 | q[1] = (m[2, 1] - m[1, 2]) * s 53 | q[2] = (m[0, 2] - m[2, 0]) * s 54 | q[3] = (m[1, 0] - m[0, 1]) * s 55 | else: 56 | if m[0, 0] >= m[1, 1] and m[0, 0] >= m[2, 2]: 57 | s = (1 + m[0, 0] - m[1, 1] - m[2, 2]) ** 0.5 58 | inv_s = 0.5 / s 59 | q[1] = 0.5 * s 60 | q[2] = (m[1, 0] + m[0, 1]) * inv_s 61 | q[3] = (m[2, 0] + m[0, 2]) * inv_s 62 | q[0] = (m[2, 1] - m[1, 2]) * inv_s 63 | elif m[1, 1] > m[2, 2]: 64 | s = (1 + m[1, 1] - m[0, 0] - m[2, 2]) ** 0.5 65 | inv_s = 0.5 / s 66 | q[1] = (m[0, 1] + m[1, 0]) * inv_s 67 | q[2] = 0.5 
* s 68 | q[3] = (m[1, 2] + m[2, 1]) * inv_s 69 | q[0] = (m[0, 2] - m[2, 0]) * inv_s 70 | else: 71 | s = (1 + m[2, 2] - m[0, 0] - m[1, 1]) ** 0.5 72 | inv_s = 0.5 / s 73 | q[1] = (m[0, 2] + m[2, 0]) * inv_s 74 | q[2] = (m[1, 2] + m[2, 1]) * inv_s 75 | q[3] = 0.5 * s 76 | q[0] = (m[1, 0] - m[0, 1]) * inv_s 77 | 78 | # q[[0, 1, 2, 3]] = q[[3, 0, 1, 2]] 79 | 80 | return -q 81 | 82 | 83 | def encode_gamestate(state: GameState): 84 | state_vals = [0, state.blue_score, state.orange_score] 85 | state_vals += state.boost_pads.tolist() 86 | 87 | for bd in (state.ball, state.inverted_ball): 88 | state_vals += bd.position.tolist() 89 | state_vals += bd.linear_velocity.tolist() 90 | state_vals += bd.angular_velocity.tolist() 91 | 92 | for p in state.players: 93 | state_vals += [p.car_id, p.team_num] 94 | for cd in (p.car_data, p.inverted_car_data): 95 | state_vals += cd.position.tolist() 96 | state_vals += rotation_to_quaternion(cd.rotation_mtx()).tolist() 97 | state_vals += cd.linear_velocity.tolist() 98 | state_vals += cd.angular_velocity.tolist() 99 | state_vals += [ 100 | 0, 101 | 0, 102 | 0, 103 | 0, 104 | 0, 105 | p.is_demoed, 106 | p.on_ground, 107 | p.ball_touched, 108 | p.has_flip, 109 | p.boost_amount 110 | ] 111 | return state_vals 112 | 113 | 114 | class BatchedObsBuilder: 115 | def __init__(self): 116 | super().__init__() 117 | self.current_state = None 118 | self.current_obs = None 119 | 120 | def batched_build_obs(self, encoded_states: np.ndarray) -> Any: 121 | raise NotImplementedError 122 | 123 | def add_actions(self, obs: Any, previous_actions: np.ndarray, player_index=None): 124 | # Modify current obs to include action 125 | # player_index=None means actions for all players should be provided 126 | raise NotImplementedError 127 | 128 | def _reset(self, initial_state: GameState): 129 | raise NotImplementedError 130 | 131 | def reset(self, initial_state: GameState): 132 | self.current_state = False 133 | self.current_obs = None 134 | self._reset(initial_state) 135 | 136 | def build_obs(self, player: PlayerData, state: GameState, previous_action: np.ndarray) -> Any: 137 | # if state != self.current_state: 138 | self.current_obs = self.batched_build_obs( 139 | np.expand_dims(encode_gamestate(state), axis=0) 140 | ) 141 | self.current_state = state 142 | 143 | for i, p in enumerate(state.players): 144 | if p == player: 145 | self.add_actions(self.current_obs, previous_action, i) 146 | return self.current_obs[i] 147 | 148 | 149 | IS_SELF, IS_MATE, IS_OPP, IS_BALL, IS_BOOST = range(5) 150 | POS = slice(5, 8) 151 | LIN_VEL = slice(8, 11) 152 | FW = slice(11, 14) 153 | UP = slice(14, 17) 154 | ANG_VEL = slice(17, 20) 155 | BOOST, DEMO, ON_GROUND, HAS_FLIP = range(20, 24) 156 | ACTIONS = range(24, 32) 157 | 158 | BALL_STATE_LENGTH = 18 159 | PLAYER_CAR_STATE_LENGTH = 13 160 | PLAYER_TERTIARY_INFO_LENGTH = 10 161 | PLAYER_INFO_LENGTH = 2 + 2 * PLAYER_CAR_STATE_LENGTH + PLAYER_TERTIARY_INFO_LENGTH 162 | 163 | 164 | class Nexto_V2_ObsBuilder(BatchedObsBuilder): 165 | _invert = np.array([1] * 5 + [-1, -1, 1] * 5 + [1] * 4) 166 | _norm = np.array([1.] 
* 5 + [2300] * 6 + [1] * 6 + [5.5] * 3 + [1] * 4) 167 | 168 | def __init__(self, field_info=None, n_players=None, tick_skip=8): 169 | super().__init__() 170 | self.n_players = n_players 171 | self.demo_timers = None 172 | self.boost_timers = None 173 | self.tick_skip = tick_skip 174 | if field_info is None: 175 | self._boost_locations = np.array(BOOST_LOCATIONS) 176 | self._boost_types = self._boost_locations[:, 2] > 72 177 | else: 178 | self._boost_locations = np.array([[bp.location.x, bp.location.y, bp.location.z] 179 | for bp in field_info.boost_pads[:field_info.num_boosts]]) 180 | self._boost_types = np.array([bp.is_full_boost for bp in field_info.boost_pads[:field_info.num_boosts]]) 181 | 182 | def _reset(self, initial_state: GameState): 183 | self.demo_timers = np.zeros(len(initial_state.players)) 184 | self.boost_timers = np.zeros(len(initial_state.boost_pads)) 185 | 186 | @staticmethod 187 | def _quats_to_rot_mtx(quats: np.ndarray) -> np.ndarray: 188 | # From rlgym.utils.math.quat_to_rot_mtx 189 | w = -quats[:, 0] 190 | x = -quats[:, 1] 191 | y = -quats[:, 2] 192 | z = -quats[:, 3] 193 | 194 | theta = np.zeros((quats.shape[0], 3, 3)) 195 | 196 | norm = np.einsum("fq,fq->f", quats, quats) 197 | 198 | sel = norm != 0 199 | 200 | w = w[sel] 201 | x = x[sel] 202 | y = y[sel] 203 | z = z[sel] 204 | 205 | s = 1.0 / norm[sel] 206 | 207 | # front direction 208 | theta[sel, 0, 0] = 1.0 - 2.0 * s * (y * y + z * z) 209 | theta[sel, 1, 0] = 2.0 * s * (x * y + z * w) 210 | theta[sel, 2, 0] = 2.0 * s * (x * z - y * w) 211 | 212 | # left direction 213 | theta[sel, 0, 1] = 2.0 * s * (x * y - z * w) 214 | theta[sel, 1, 1] = 1.0 - 2.0 * s * (x * x + z * z) 215 | theta[sel, 2, 1] = 2.0 * s * (y * z + x * w) 216 | 217 | # up direction 218 | theta[sel, 0, 2] = 2.0 * s * (x * z + y * w) 219 | theta[sel, 1, 2] = 2.0 * s * (y * z - x * w) 220 | theta[sel, 2, 2] = 1.0 - 2.0 * s * (x * x + y * y) 221 | 222 | return theta 223 | 224 | @staticmethod 225 | def convert_to_relative(q, kv): 226 | # kv[..., POS.start:LIN_VEL.stop] -= q[..., POS.start:LIN_VEL.stop] 227 | kv[..., POS] -= q[..., POS] 228 | forward = q[..., FW] 229 | theta = np.arctan2(forward[..., 0], forward[..., 1]) 230 | theta = np.expand_dims(theta, axis=-1) 231 | ct = np.cos(theta) 232 | st = np.sin(theta) 233 | xs = kv[..., POS.start:ANG_VEL.stop:3] 234 | ys = kv[..., POS.start + 1:ANG_VEL.stop:3] 235 | # Use temp variables to prevent modifying original array 236 | nx = ct * xs - st * ys 237 | ny = st * xs + ct * ys 238 | kv[..., POS.start:ANG_VEL.stop:3] = nx # x-components 239 | kv[..., POS.start + 1:ANG_VEL.stop:3] = ny # y-components 240 | 241 | def batched_build_obs(self, encoded_states: np.ndarray): 242 | ball_start_index = 3 + len(self._boost_locations) 243 | players_start_index = ball_start_index + BALL_STATE_LENGTH 244 | player_length = PLAYER_INFO_LENGTH 245 | 246 | n_players = (encoded_states.shape[1] - players_start_index) // player_length 247 | lim_players = n_players if self.n_players is None else self.n_players 248 | n_entities = lim_players + 1 + 34 249 | 250 | # SELECTORS 251 | sel_players = slice(0, lim_players) 252 | sel_ball = sel_players.stop 253 | sel_boosts = slice(sel_ball + 1, None) 254 | 255 | # MAIN ARRAYS 256 | q = np.zeros((n_players, encoded_states.shape[0], 1, 32)) 257 | kv = np.zeros((n_players, encoded_states.shape[0], n_entities, 24)) # Keys and values are (mostly) shared 258 | m = np.zeros((n_players, encoded_states.shape[0], n_entities)) # Mask is shared 259 | 260 | # BALL 261 | kv[:, :, sel_ball, 3] = 1 262 | 
kv[:, :, sel_ball, np.r_[POS, LIN_VEL, ANG_VEL]] = encoded_states[:, ball_start_index: ball_start_index + 9] 263 | 264 | # BOOSTS 265 | kv[:, :, sel_boosts, IS_BOOST] = 1 266 | kv[:, :, sel_boosts, POS] = self._boost_locations 267 | kv[:, :, sel_boosts, BOOST] = 0.12 + 0.88 * (self._boost_locations[:, 2] > 72) 268 | kv[:, :, sel_boosts, DEMO] = encoded_states[:, 3:3 + 34] # FIXME boost timer 269 | 270 | # PLAYERS 271 | teams = encoded_states[0, players_start_index + 1::player_length] 272 | kv[:, :, :n_players, IS_MATE] = 1 - teams # Default team is blue 273 | kv[:, :, :n_players, IS_OPP] = teams 274 | for i in range(n_players): 275 | encoded_player = encoded_states[:, 276 | players_start_index + i * player_length: players_start_index + (i + 1) * player_length] 277 | 278 | kv[i, :, i, IS_SELF] = 1 279 | kv[:, :, i, POS] = encoded_player[:, 2: 5] # TODO constants for these indices 280 | kv[:, :, i, LIN_VEL] = encoded_player[:, 9: 12] 281 | quats = encoded_player[:, 5: 9] 282 | rot_mtx = self._quats_to_rot_mtx(quats) 283 | kv[:, :, i, FW] = rot_mtx[:, :, 0] 284 | kv[:, :, i, UP] = rot_mtx[:, :, 2] 285 | kv[:, :, i, ANG_VEL] = encoded_player[:, 12: 15] 286 | kv[:, :, i, BOOST] = encoded_player[:, 37] 287 | kv[:, :, i, DEMO] = encoded_player[:, 33] # FIXME demo timer 288 | kv[:, :, i, ON_GROUND] = encoded_player[:, 34] 289 | kv[:, :, i, HAS_FLIP] = encoded_player[:, 36] 290 | 291 | kv[teams == 1] *= self._invert 292 | kv[np.argwhere(teams == 1), ..., (IS_MATE, IS_OPP)] = kv[ 293 | np.argwhere(teams == 1), ..., (IS_OPP, IS_MATE)] # Swap teams 294 | 295 | kv /= self._norm 296 | 297 | for i in range(n_players): 298 | q[i, :, 0, :kv.shape[-1]] = kv[i, :, i, :] 299 | 300 | self.convert_to_relative(q, kv) 301 | # kv[:, :, :, 5:11] -= q[:, :, :, 5:11] 302 | 303 | # MASK 304 | m[:, :, n_players: lim_players] = 1 305 | 306 | return [(q[i], kv[i], m[i]) for i in range(n_players)] 307 | 308 | def add_actions(self, obs: Any, previous_actions: np.ndarray, player_index=None): 309 | if player_index is None: 310 | for (q, kv, m), act in zip(obs, previous_actions): 311 | q[:, 0, ACTIONS] = act 312 | else: 313 | q, kv, m = obs[player_index] 314 | q[:, 0, ACTIONS] = previous_actions 315 | -------------------------------------------------------------------------------- /rocket_learn/utils/stat_trackers/common_trackers.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from rocket_learn.utils.gamestate_encoding import StateConstants 4 | from rocket_learn.utils.stat_trackers.stat_tracker import StatTracker 5 | 6 | 7 | class Speed(StatTracker): 8 | def __init__(self): 9 | super().__init__("average_speed") 10 | self.count = 0 11 | self.total_speed = 0.0 12 | 13 | def reset(self): 14 | self.count = 0 15 | self.total_speed = 0.0 16 | 17 | def update(self, gamestates: np.ndarray, mask: np.ndarray): 18 | players = gamestates[:, StateConstants.PLAYERS] 19 | xs = players[:, StateConstants.CAR_LINEAR_VEL_X] 20 | ys = players[:, StateConstants.CAR_LINEAR_VEL_Y] 21 | zs = players[:, StateConstants.CAR_LINEAR_VEL_Z] 22 | 23 | speeds = np.sqrt(xs ** 2 + ys ** 2 + zs ** 2) 24 | self.total_speed += np.sum(speeds) 25 | self.count += speeds.size 26 | 27 | def get_stat(self): 28 | return self.total_speed / (self.count or 1) 29 | 30 | 31 | class Demos(StatTracker): 32 | def __init__(self): 33 | super().__init__("average_demos") 34 | self.count = 0 35 | self.total_demos = 0 36 | 37 | def reset(self): 38 | self.count = 0 39 | self.total_demos = 0 40 | 41 | def update(self, 
gamestates: np.ndarray, mask: np.ndarray): 42 | players = gamestates[:, StateConstants.PLAYERS] 43 | 44 | demos = np.clip(players[-1, StateConstants.MATCH_DEMOLISHES] - players[0, StateConstants.MATCH_DEMOLISHES], 45 | 0, None) 46 | self.total_demos += np.sum(demos) 47 | self.count += demos.size 48 | 49 | def get_stat(self): 50 | return self.total_demos / (self.count or 1) 51 | 52 | 53 | class TimeoutRate(StatTracker): 54 | def __init__(self): 55 | super().__init__("timeout_rate") 56 | self.count = 0 57 | self.total_timeouts = 0 58 | 59 | def reset(self): 60 | self.count = 0 61 | self.total_timeouts = 0 62 | 63 | def update(self, gamestates: np.ndarray, mask: np.ndarray): 64 | orange_diff = gamestates[-1, StateConstants.ORANGE_SCORE] - gamestates[0, StateConstants.ORANGE_SCORE] 65 | blue_diff = gamestates[-1, StateConstants.BLUE_SCORE] - gamestates[0, StateConstants.BLUE_SCORE] 66 | 67 | self.total_timeouts += ((orange_diff == 0) & (blue_diff == 0)).item() 68 | self.count += 1 69 | 70 | def get_stat(self): 71 | return self.total_timeouts / (self.count or 1) 72 | 73 | 74 | class Touch(StatTracker): 75 | def __init__(self): 76 | super().__init__("touch_rate") 77 | self.count = 0 78 | self.total_touches = 0 79 | 80 | def reset(self): 81 | self.count = 0 82 | self.total_touches = 0 83 | 84 | def update(self, gamestates: np.ndarray, mask: np.ndarray): 85 | players = gamestates[:, StateConstants.PLAYERS] 86 | is_touch = players[:, StateConstants.BALL_TOUCHED] 87 | 88 | self.total_touches += np.sum(is_touch) 89 | self.count += is_touch.size 90 | 91 | def get_stat(self): 92 | return self.total_touches / (self.count or 1) 93 | 94 | 95 | class EpisodeLength(StatTracker): 96 | def __init__(self): 97 | super().__init__("episode_length") 98 | self.count = 0 99 | self.total_length = 0 100 | 101 | def reset(self): 102 | self.count = 0 103 | self.total_length = 0 104 | 105 | def update(self, gamestates: np.ndarray, mask: np.ndarray): 106 | self.total_length += gamestates.shape[0] 107 | self.count += 1 108 | 109 | def get_stat(self): 110 | return self.total_length / (self.count or 1) 111 | 112 | 113 | class Boost(StatTracker): 114 | def __init__(self): 115 | super().__init__("average_boost") 116 | self.count = 0 117 | self.total_boost = 0 118 | 119 | def reset(self): 120 | self.count = 0 121 | self.total_boost = 0 122 | 123 | def update(self, gamestates: np.ndarray, mask: np.ndarray): 124 | players = gamestates[:, StateConstants.PLAYERS] 125 | boost = players[:, StateConstants.BOOST_AMOUNT] 126 | is_limited = (0 <= boost) & (boost <= 1) 127 | boost = boost[is_limited] 128 | self.total_boost += np.sum(boost) 129 | self.count += boost.size 130 | 131 | def get_stat(self): 132 | return self.total_boost / (self.count or 1) 133 | 134 | 135 | class BehindBall(StatTracker): 136 | def __init__(self): 137 | super().__init__("behind_ball_rate") 138 | self.count = 0 139 | self.total_behind = 0 140 | 141 | def reset(self): 142 | self.count = 0 143 | self.total_behind = 0 144 | 145 | def update(self, gamestates: np.ndarray, mask: np.ndarray): 146 | players = gamestates[:, StateConstants.PLAYERS] 147 | ball_y = gamestates[:, StateConstants.BALL_POSITION.start + 1] 148 | player_y = players[:, StateConstants.CAR_POS_Y] 149 | is_orange = players[:, StateConstants.TEAM_NUMS] 150 | behind = (2 * is_orange - 1) * (ball_y.reshape(-1, 1) - player_y) < 0 151 | 152 | self.total_behind += np.sum(behind) 153 | self.count += behind.size 154 | 155 | def get_stat(self): 156 | return self.total_behind / (self.count or 1) 157 | 158 | 159 
| class TouchHeight(StatTracker): 160 | def __init__(self): 161 | super().__init__("touch_height") 162 | self.count = 0 163 | self.total_height = 0 164 | 165 | def reset(self): 166 | self.count = 0 167 | self.total_height = 0 168 | 169 | def update(self, gamestates: np.ndarray, mask: np.ndarray): 170 | players = gamestates[:, StateConstants.PLAYERS] 171 | ball_z = gamestates[:, StateConstants.BALL_POSITION.start + 2] 172 | touch_heights = ball_z[players[:, StateConstants.BALL_TOUCHED].any(axis=1)] 173 | 174 | self.total_height += np.sum(touch_heights) 175 | self.count += touch_heights.size 176 | 177 | def get_stat(self): 178 | return self.total_height / (self.count or 1) 179 | 180 | 181 | class DistToBall(StatTracker): 182 | def __init__(self): 183 | super().__init__("distance_to_ball") 184 | self.count = 0 185 | self.total_dist = 0.0 186 | 187 | def reset(self): 188 | self.count = 0 189 | self.total_dist = 0.0 190 | 191 | def update(self, gamestates: np.ndarray, mask: np.ndarray): 192 | players = gamestates[:, StateConstants.PLAYERS] 193 | ball = gamestates[:, StateConstants.BALL_POSITION] 194 | ball_x = ball[:, 0].reshape((-1, 1)) 195 | ball_y = ball[:, 1].reshape((-1, 1)) 196 | ball_z = ball[:, 2].reshape((-1, 1)) 197 | xs = players[:, StateConstants.CAR_POS_X] 198 | ys = players[:, StateConstants.CAR_POS_Y] 199 | zs = players[:, StateConstants.CAR_POS_Z] 200 | 201 | dists = np.sqrt((ball_x - xs) ** 2 + (ball_y - ys) ** 2 + (ball_z - zs) ** 2) 202 | self.total_dist += np.sum(dists) 203 | self.count += dists.size 204 | 205 | def get_stat(self): 206 | return self.total_dist / (self.count or 1) 207 | 208 | 209 | class AirTouch(StatTracker): 210 | def __init__(self): 211 | super().__init__("air_touch_rate") 212 | self.count = 0 213 | self.total_touches = 0 214 | 215 | def reset(self): 216 | self.count = 0 217 | self.total_touches = 0 218 | 219 | def update(self, gamestates: np.ndarray, mask: np.ndarray): 220 | players = gamestates[:, StateConstants.PLAYERS] 221 | is_touch = np.asarray([a * b for a, b in 222 | zip(players[:, StateConstants.BALL_TOUCHED], 223 | np.invert(players[:, StateConstants.ON_GROUND].astype(bool)))]) 224 | 225 | self.total_touches += np.sum(is_touch) 226 | self.count += is_touch.size 227 | 228 | def get_stat(self): 229 | return self.total_touches / (self.count or 1) 230 | 231 | 232 | class AirTouchHeight(StatTracker): 233 | def __init__(self): 234 | super().__init__("air_touch_height") 235 | self.count = 0 236 | self.total_height = 0 237 | 238 | def reset(self): 239 | self.count = 0 240 | self.total_height = 0 241 | 242 | def update(self, gamestates: np.ndarray, mask: np.ndarray): 243 | players = gamestates[:, StateConstants.PLAYERS] 244 | ball_z = gamestates[:, StateConstants.BALL_POSITION.start + 2] 245 | touch_heights = ball_z[players[:, StateConstants.BALL_TOUCHED].any(axis=1)] 246 | touch_heights = touch_heights[touch_heights >= 175] # remove dribble touches and below 247 | 248 | self.total_height += np.sum(touch_heights) 249 | self.count += touch_heights.size 250 | 251 | def get_stat(self): 252 | return self.total_height / (self.count or 1) 253 | 254 | 255 | class BallSpeed(StatTracker): 256 | def __init__(self): 257 | super().__init__("average_ball_speed") 258 | self.count = 0 259 | self.total_speed = 0.0 260 | 261 | def reset(self): 262 | self.count = 0 263 | self.total_speed = 0.0 264 | 265 | def update(self, gamestates: np.ndarray, mask: np.ndarray): 266 | ball_speeds = gamestates[:, StateConstants.BALL_LINEAR_VELOCITY] 267 | xs = ball_speeds[:, 0] 268 | ys = 
ball_speeds[:, 1] 269 | zs = ball_speeds[:, 2] 270 | speeds = np.sqrt(xs ** 2 + ys ** 2 + zs ** 2) 271 | self.total_speed += np.sum(speeds) 272 | self.count += speeds.size 273 | 274 | def get_stat(self): 275 | return self.total_speed / (self.count or 1) 276 | 277 | 278 | class BallHeight(StatTracker): 279 | def __init__(self): 280 | super().__init__("ball_height") 281 | self.count = 0 282 | self.total_height = 0 283 | 284 | def reset(self): 285 | self.count = 0 286 | self.total_height = 0 287 | 288 | def update(self, gamestates: np.ndarray, mask: np.ndarray): 289 | ball_z = gamestates[:, StateConstants.BALL_POSITION.start + 2] 290 | 291 | self.total_height += np.sum(ball_z) 292 | self.count += ball_z.size 293 | 294 | def get_stat(self): 295 | return self.total_height / (self.count or 1) 296 | 297 | 298 | class GoalSpeed(StatTracker): 299 | def __init__(self): 300 | super().__init__("avg_goal_speed") 301 | self.count = 0 302 | self.total_speed = 0 303 | 304 | def reset(self): 305 | self.count = 0 306 | self.total_speed = 0 307 | 308 | def update(self, gamestates: np.ndarray, mask: np.ndarray): 309 | orange_diff = np.diff(gamestates[:, StateConstants.ORANGE_SCORE], append=np.nan) 310 | blue_diff = np.diff(gamestates[:, StateConstants.BLUE_SCORE], append=np.nan) 311 | 312 | goal_frames = (orange_diff > 0) | (blue_diff > 0) 313 | 314 | goal_speed = gamestates[goal_frames, StateConstants.BALL_LINEAR_VELOCITY] 315 | goal_speed = np.linalg.norm(goal_speed, axis=-1).sum() 316 | 317 | self.total_speed += goal_speed / 27.78 # convert to km/h 318 | self.count += goal_speed.size 319 | 320 | def get_stat(self): 321 | return self.total_speed / (self.count or 1) 322 | 323 | 324 | class MaxGoalSpeed(StatTracker): 325 | def __init__(self): 326 | super().__init__("max_goal_speed") 327 | self.max_speed = 0 328 | 329 | def reset(self): 330 | self.max_speed = 0 331 | 332 | def update(self, gamestates: np.ndarray, mask: np.ndarray): 333 | if gamestates.ndim > 1 and len(gamestates) > 1: 334 | end = gamestates[-2] 335 | ball_speeds = end[StateConstants.BALL_LINEAR_VELOCITY] 336 | goal_speed = float(np.linalg.norm(ball_speeds)) / 27.78 # convert to km/h 337 | 338 | self.max_speed = max(float(self.max_speed), goal_speed) 339 | 340 | def get_stat(self): 341 | return self.max_speed 342 | 343 | 344 | class CarOnGround(StatTracker): 345 | def __init__(self): 346 | super().__init__("pct_car_on_ground") 347 | self.count = 0 348 | self.total_ground = 0.0 349 | 350 | def reset(self): 351 | self.count = 0 352 | self.total_ground = 0.0 353 | 354 | def update(self, gamestates: np.ndarray, mask: np.ndarray): 355 | on_ground = gamestates[:, StateConstants.ON_GROUND] 356 | 357 | self.total_ground += np.sum(on_ground) 358 | self.count += on_ground.size 359 | 360 | def get_stat(self): 361 | return 100 * self.total_ground / (self.count or 1) 362 | 363 | 364 | class Saves(StatTracker): 365 | def __init__(self): 366 | super().__init__("average_saves") 367 | self.count = 0 368 | self.total_saves = 0 369 | 370 | def reset(self): 371 | self.count = 0 372 | self.total_saves = 0 373 | 374 | def update(self, gamestates: np.ndarray, mask: np.ndarray): 375 | players = gamestates[:, StateConstants.PLAYERS] 376 | 377 | saves = np.clip(players[-1, StateConstants.MATCH_SAVES] - players[0, StateConstants.MATCH_SAVES], 378 | 0, None) 379 | self.total_saves += np.sum(saves) 380 | self.count += saves.size 381 | 382 | def get_stat(self): 383 | return self.total_saves / (self.count or 1) 384 | 385 | 386 | class Shots(StatTracker): 387 | def 
__init__(self): 388 | super().__init__("average_shots") 389 | self.count = 0 390 | self.total_shots = 0 391 | 392 | def reset(self): 393 | self.count = 0 394 | self.total_shots = 0 395 | 396 | def update(self, gamestates: np.ndarray, mask: np.ndarray): 397 | players = gamestates[:, StateConstants.PLAYERS] 398 | 399 | shots = np.clip(players[-1, StateConstants.MATCH_SHOTS] - players[0, StateConstants.MATCH_SHOTS], 400 | 0, None) 401 | self.total_shots += np.sum(shots) 402 | self.count += shots.size 403 | 404 | def get_stat(self): 405 | return self.total_shots / (self.count or 1) 406 | -------------------------------------------------------------------------------- /rocket_learn/rollout_generator/redis/redis_rollout_generator.py: -------------------------------------------------------------------------------- 1 | import itertools 2 | from collections import Counter 3 | from typing import Iterator, Callable, Optional, List 4 | 5 | import numpy as np 6 | import plotly.graph_objs as go 7 | # import matplotlib.pyplot # noqa 8 | import wandb 9 | # from matplotlib.axes import Axes 10 | # from matplotlib.figure import Figure 11 | from redis import Redis 12 | from redis.exceptions import ResponseError 13 | from rlgym.utils import ObsBuilder, RewardFunction 14 | from rlgym.utils.action_parsers import ActionParser 15 | from trueskill import Rating, rate, SIGMA 16 | 17 | from rocket_learn.experience_buffer import ExperienceBuffer 18 | from rocket_learn.rollout_generator.base_rollout_generator import BaseRolloutGenerator 19 | from rocket_learn.rollout_generator.redis.utils import decode_buffers, _unserialize, QUALITIES, _serialize, ROLLOUTS, \ 20 | VERSION_LATEST, OPPONENT_MODELS, CONTRIBUTORS, N_UPDATES, MODEL_LATEST, _serialize_model, get_rating, \ 21 | _ALL, LATEST_RATING_ID, EXPERIENCE_PER_MODE 22 | from rocket_learn.utils.stat_trackers.stat_tracker import StatTracker 23 | 24 | 25 | class RedisRolloutGenerator(BaseRolloutGenerator): 26 | """ 27 | Rollout generator in charge of sending commands to workers via redis 28 | """ 29 | 30 | def __init__( 31 | self, 32 | name: str, 33 | redis: Redis, 34 | obs_build_factory: Callable[[], ObsBuilder], 35 | rew_func_factory: Callable[[], RewardFunction], 36 | act_parse_factory: Callable[[], ActionParser], 37 | save_every=10, 38 | model_every=100, 39 | logger=None, 40 | clear=True, 41 | max_age=0, 42 | default_sigma=SIGMA, 43 | min_sigma=1, 44 | gamemodes=("1v1", "2v2", "3v3"), 45 | stat_trackers: Optional[List[StatTracker]] = None, 46 | ): 47 | self.lastsave_ts = None 48 | self.name = name 49 | self.tot_bytes = 0 50 | self.redis = redis 51 | self.logger = logger 52 | 53 | # TODO saving/loading 54 | if clear: 55 | self.redis.delete(*(_ALL + tuple(QUALITIES.format(gm) for gm in gamemodes))) 56 | self.redis.set(N_UPDATES, 0) 57 | self.redis.hset(EXPERIENCE_PER_MODE, mapping={m: 0 for m in gamemodes}) 58 | else: 59 | if self.redis.exists(ROLLOUTS) > 0: 60 | self.redis.delete(ROLLOUTS) 61 | self.redis.decr(VERSION_LATEST, 62 | max_age + 1) # In case of reload from old version, don't let current seep in 63 | 64 | # self.redis.set(SAVE_FREQ, save_every) 65 | # self.redis.set(MODEL_FREQ, model_every) 66 | self.save_freq = save_every 67 | self.model_freq = model_every 68 | self.contributors = Counter() # No need to save, clears every iteration 69 | self.obs_build_func = obs_build_factory 70 | self.rew_func_factory = rew_func_factory 71 | self.act_parse_factory = act_parse_factory 72 | self.max_age = max_age 73 | self.default_sigma = default_sigma 74 | self.min_sigma = 
min_sigma 75 | self.gamemodes = gamemodes 76 | self.stat_trackers = stat_trackers or [] 77 | self._reset_stats() 78 | 79 | @staticmethod 80 | def _process_rollout(rollout_bytes, latest_version, obs_build_func, rew_build_func, act_build_func, max_age): 81 | rollout_data, versions, uuid, name, result, has_obs, has_states, has_rewards = _unserialize(rollout_bytes) 82 | 83 | v_check = [v for v in versions if isinstance(v, int) or v.startswith("-")] 84 | 85 | if any(version < 0 and abs(version - latest_version) > max_age for version in v_check): 86 | return 87 | 88 | if any(version < 0 for version in v_check): 89 | buffers, states = decode_buffers(rollout_data, versions, has_obs, has_states, has_rewards, 90 | obs_build_func, rew_build_func, act_build_func) 91 | else: 92 | buffers = states = [None] * len(v_check) 93 | 94 | return buffers, states, versions, uuid, name, result 95 | 96 | def _update_ratings(self, name, versions, buffers, latest_version, result): 97 | ratings = [] 98 | relevant_buffers = [] 99 | gamemode = f"{len(versions) // 2}v{len(versions) // 2}" # TODO: support unfair games 100 | 101 | versions = [v for v in versions if v != "na"] 102 | for version, buffer in itertools.zip_longest(versions, buffers): 103 | if isinstance(version, int) and version < 0: 104 | if abs(version - latest_version) <= self.max_age: 105 | relevant_buffers.append(buffer) 106 | self.contributors[name] += buffer.size() 107 | else: 108 | return [] 109 | else: 110 | rating = get_rating(gamemode, version, self.redis) 111 | ratings.append(rating) 112 | 113 | # Only old versions, calculate MMR 114 | if len(ratings) == len(versions) and len(buffers) == 0: 115 | blue_players = sum(divmod(len(ratings), 2)) 116 | blue = tuple(ratings[:blue_players]) # Tuple is important 117 | orange = tuple(ratings[blue_players:]) 118 | 119 | # In ranks lowest number is best, result=-1 is orange win, 0 tie, 1 blue 120 | r1, r2 = rate((blue, orange), ranks=(0, result)) 121 | 122 | # Some trickery to handle same rating appearing multiple times, we just average their new mus and sigmas 123 | ratings_versions = {} 124 | for rating, version in zip(r1 + r2, versions): 125 | ratings_versions.setdefault(version, []).append(rating) 126 | 127 | mapping = {} 128 | for version, ratings in ratings_versions.items(): 129 | # In case of duplicates, average ratings together (not strictly necessary with default setup) 130 | # Also limit sigma to its lower bound 131 | avg_rating = Rating(sum(r.mu for r in ratings) / len(ratings), 132 | max(sum(r.sigma ** 2 for r in ratings) ** 0.5 / len(ratings), self.min_sigma)) 133 | mapping[version] = _serialize(tuple(avg_rating)) 134 | gamemode = f"{len(versions) // 2}v{len(versions) // 2}" # TODO: support unfair games 135 | self.redis.hset(QUALITIES.format(gamemode), mapping=mapping) 136 | 137 | if len(relevant_buffers) > 0: 138 | self.redis.hincrby(EXPERIENCE_PER_MODE, gamemode, len(relevant_buffers) * relevant_buffers[0].size()) 139 | 140 | return relevant_buffers 141 | 142 | def _reset_stats(self): 143 | for stat_tracker in self.stat_trackers: 144 | stat_tracker.reset() 145 | 146 | def _update_stats(self, states, mask): 147 | if states is None: 148 | return 149 | for stat_tracker in self.stat_trackers: 150 | stat_tracker.update(states, mask) 151 | 152 | def _get_stats(self): 153 | stats = {} 154 | for stat_tracker in self.stat_trackers: 155 | stats[stat_tracker.name] = stat_tracker.get_stat() 156 | return stats 157 | 158 | def generate_rollouts(self) -> Iterator[ExperienceBuffer]: 159 | while True: 160 | 
latest_version = int(self.redis.get(VERSION_LATEST)) 161 | data = self.redis.blpop(ROLLOUTS)[1] 162 | self.tot_bytes += len(data) 163 | res = self._process_rollout( 164 | data, latest_version, 165 | self.obs_build_func, self.rew_func_factory, self.act_parse_factory, 166 | self.max_age 167 | ) 168 | if res is not None: 169 | buffers, states, versions, uuid, name, result = res 170 | # versions = [version for version in versions if version != 'na'] # don't track humans or hardcoded 171 | 172 | relevant_buffers = self._update_ratings(name, versions, buffers, latest_version, result) 173 | if len(relevant_buffers) > 0: 174 | self._update_stats(states, [b in relevant_buffers for b in buffers]) 175 | yield from relevant_buffers 176 | 177 | def _plot_ratings(self): 178 | fig_data = [] 179 | i = 0 180 | means = {} 181 | mean_key = "mean" 182 | gamemodes = list(self.gamemodes) 183 | if len(gamemodes) > 1: 184 | gamemodes.append(mean_key) 185 | for gamemode in gamemodes: 186 | if gamemode != mean_key: 187 | ratings = get_rating(gamemode, None, self.redis) 188 | if len(ratings) <= 0: 189 | return 190 | baseline = None 191 | for mode in "stochastic", "deterministic": 192 | if gamemode != "mean": 193 | x = [] 194 | mus = [] 195 | sigmas = [] 196 | for k, r in ratings.items(): # noqa 197 | if k.endswith(mode): 198 | v = int(k.rsplit("-", 2)[1][1:]) 199 | # v = int(k.split("-")[1][1:]) 200 | x.append(v) 201 | mus.append(r.mu) 202 | sigmas.append(r.sigma) 203 | mean = means.setdefault(mode, {}).get(v, (0, 0)) 204 | means[mode][v] = (mean[0] + r.mu, mean[1] + r.sigma ** 2) 205 | # *Smoothly* transition from red, to green, to blue depending on gamemode 206 | mid = (len(self.gamemodes) - 1) / 2 207 | # avoid divide by 0 issues if there's only one gamemode, this moves it halfway into the colors 208 | if mid == 0: 209 | mid = 0.5 210 | if i < mid: 211 | r = 1 - i / mid 212 | g = i / mid 213 | b = 0 214 | else: 215 | r = 0 216 | g = 1 - (i - mid) / mid 217 | b = (i - mid) / mid 218 | else: 219 | means_mode = means.get(mode, {}) 220 | x = list(means_mode.keys()) 221 | mus = [mean[0] / len(self.gamemodes) for mean in means_mode.values()] 222 | sigmas = [(mean[1] / len(self.gamemodes)) ** 0.5 for mean in means_mode.values()] 223 | r = g = b = 1 / 3 224 | 225 | indices = np.argsort(x) 226 | x = np.array(x)[indices] 227 | mus = np.array(mus)[indices] 228 | sigmas = np.array(sigmas)[indices] 229 | 230 | if baseline is None: 231 | baseline = mus[0] # Stochastic initialization is defined as the baseline (0 mu) 232 | mus = mus - baseline 233 | y = mus 234 | y_upper = mus + 2 * sigmas # 95% confidence 235 | y_lower = mus - 2 * sigmas 236 | 237 | scale = 255 if mode == "stochastic" else 128 238 | color = f"{int(r * scale)},{int(g * scale)},{int(b * scale)}" 239 | 240 | fig_data += [ 241 | go.Scatter( 242 | x=x, 243 | y=y, 244 | line=dict(color=f'rgb({color})'), 245 | mode='lines', 246 | name=f"{gamemode}-{mode}", 247 | legendgroup=f"{gamemode}-{mode}", 248 | showlegend=True, 249 | visible=None if gamemode == gamemodes[-1] else "legendonly", 250 | ), 251 | go.Scatter( 252 | x=np.concatenate((x, x[::-1])), # x, then x reversed 253 | y=np.concatenate((y_upper, y_lower[::-1])), # upper, then lower reversed 254 | fill='toself', 255 | fillcolor=f'rgba({color},0.2)', 256 | line=dict(color='rgba(255,255,255,0)'), 257 | hoverinfo="skip", 258 | name="sigma", 259 | legendgroup=f"{gamemode}-{mode}", 260 | showlegend=False, 261 | visible=None if gamemode == gamemodes[-1] else "legendonly", 262 | ), 263 | ] 264 | i += 1 265 | 266 | if 
len(fig_data) <= 0: 267 | return 268 | 269 | fig = go.Figure(fig_data) 270 | fig.update_layout(title="Rating", xaxis_title="Iteration", yaxis_title="TrueSkill") 271 | 272 | self.logger.log({ 273 | "qualities": fig, 274 | }, commit=False) 275 | 276 | def _add_opponent(self, agent): 277 | latest_id = self.redis.get(LATEST_RATING_ID) 278 | prefix = f"{self.name}-v" 279 | if latest_id is None: 280 | version = 0 281 | else: 282 | latest_id = latest_id.decode("utf-8") 283 | version = int(latest_id.replace(prefix, "")) + 1 284 | key = f"{prefix}{version}" 285 | 286 | # Add to list 287 | self.redis.hset(OPPONENT_MODELS, key, agent) 288 | 289 | # Set quality 290 | for gamemode in self.gamemodes: 291 | ratings = get_rating(gamemode, None, self.redis) 292 | 293 | for mode in "stochastic", "deterministic": 294 | if latest_id is not None: 295 | latest_key = f"{latest_id}-{mode}" 296 | quality = Rating(ratings[latest_key].mu, self.default_sigma) 297 | else: 298 | quality = Rating(0, self.min_sigma) # First (basically random) agent is initialized at 0 299 | 300 | self.redis.hset(QUALITIES.format(gamemode), f"{key}-{mode}", _serialize(tuple(quality))) 301 | 302 | # Inform that new opponent is ready 303 | self.redis.set(LATEST_RATING_ID, key) 304 | 305 | def update_parameters(self, new_params): 306 | """ 307 | update redis (and thus workers) with new model data and save data as future opponent 308 | :param new_params: new model parameters 309 | """ 310 | model_bytes = _serialize_model(new_params) 311 | self.redis.set(MODEL_LATEST, model_bytes) 312 | self.redis.decr(VERSION_LATEST) 313 | 314 | print("Top contributors:\n" + "\n".join(f"\t{c}: {n}" for c, n in self.contributors.most_common(5))) 315 | self.logger.log({ 316 | "redis/contributors": wandb.Table(columns=["name", "steps"], data=self.contributors.most_common())}, 317 | commit=False 318 | ) 319 | self._plot_ratings() 320 | tot_contributors = self.redis.hgetall(CONTRIBUTORS) 321 | tot_contributors = Counter({name: int(count) for name, count in tot_contributors.items()}) 322 | tot_contributors += self.contributors 323 | if tot_contributors: 324 | self.redis.hset(CONTRIBUTORS, mapping=tot_contributors) 325 | self.contributors.clear() 326 | 327 | self.logger.log({"redis/rollout_bytes": self.tot_bytes}, commit=False) 328 | self.tot_bytes = 0 329 | 330 | n_updates = self.redis.incr(N_UPDATES) - 1 331 | # save_freq = int(self.redis.get(SAVE_FREQ)) 332 | 333 | if n_updates > 0: 334 | self.logger.log({f"stat/{name}": value for name, value in self._get_stats().items()}, commit=False) 335 | self._reset_stats() 336 | 337 | if n_updates % self.model_freq == 0: 338 | print("Adding model to pool...") 339 | self._add_opponent(model_bytes) 340 | 341 | if n_updates % self.save_freq == 0: 342 | # self.redis.set(MODEL_N.format(self.n_updates // self.save_every), model_bytes) 343 | print("Saving model...") 344 | if self.lastsave_ts == self.redis.lastsave(): 345 | print("redis save error, previous bgsave failed") 346 | self.lastsave_ts = self.redis.lastsave() 347 | try: 348 | self.redis.bgsave() 349 | except ResponseError: 350 | print("redis bgsave failed, auto save already in progress") 351 | -------------------------------------------------------------------------------- /rocket_learn/ppo.py: -------------------------------------------------------------------------------- 1 | import cProfile 2 | import io 3 | import os 4 | import pstats 5 | import time 6 | import sys 7 | from typing import Iterator, List, Tuple, Union 8 | 9 | import numba 10 | import numpy as np 11 | 
import torch 12 | import torch as th 13 | from torch.distributions import kl_divergence 14 | from torch.nn import functional as F 15 | from torch.nn.utils import clip_grad_norm_ 16 | 17 | from rocket_learn.agent.actor_critic_agent import ActorCriticAgent 18 | from rocket_learn.agent.policy import Policy 19 | from rocket_learn.experience_buffer import ExperienceBuffer 20 | from rocket_learn.rollout_generator.base_rollout_generator import BaseRolloutGenerator 21 | 22 | 23 | class PPO: 24 | """ 25 | Proximal Policy Optimization algorithm (PPO) 26 | 27 | :param rollout_generator: Function that will generate the rollouts 28 | :param agent: An ActorCriticAgent 29 | :param n_steps: The number of steps to run per update 30 | :param gamma: Discount factor 31 | :param batch_size: batch size to break experience data into for training 32 | :param epochs: Number of epochs when optimizing the loss 33 | :param minibatch_size: size to break batch sets into (helps combat VRAM issues) 34 | :param clip_range: PPO clipping parameter for the policy ratio in the surrogate objective 35 | :param ent_coef: Entropy coefficient for the loss calculation 36 | :param gae_lambda: Factor for trade-off of bias vs variance for Generalized Advantage Estimator 37 | :param vf_coef: Value function coefficient for the loss calculation 38 | :param max_grad_norm: optional clip_grad_norm value 39 | :param logger: wandb logger to store run results 40 | :param device: torch device 41 | :param zero_grads_with_none: zero gradients by setting them to None instead of 0 42 | 43 | Look here for info on zero_grads_with_none 44 | https://pytorch.org/docs/master/generated/torch.optim.Optimizer.zero_grad.html#torch.optim.Optimizer.zero_grad 45 | """ 46 | 47 | def __init__( 48 | self, 49 | rollout_generator: BaseRolloutGenerator, 50 | agent: ActorCriticAgent, 51 | n_steps=4096, 52 | gamma=0.99, 53 | batch_size=512, 54 | epochs=10, 55 | # reuse=2, 56 | minibatch_size=None, 57 | clip_range=0.2, 58 | ent_coef=0.01, 59 | gae_lambda=0.95, 60 | vf_coef=1, 61 | max_grad_norm=0.5, 62 | logger=None, 63 | device="cuda", 64 | zero_grads_with_none=False, 65 | kl_models_weights: List[Union[Tuple[Policy, float], Tuple[Policy, float, float]]] = None 66 | ): 67 | self.rollout_generator = rollout_generator 68 | 69 | # TODO let users choose their own agent 70 | # TODO move agent to rollout generator 71 | self.agent = agent.to(device) 72 | self.device = device 73 | self.zero_grads_with_none = zero_grads_with_none 74 | self.frozen_iterations = 0 75 | self._saved_lr = None 76 | 77 | self.starting_iteration = 0 78 | 79 | # hyperparameters 80 | self.epochs = epochs 81 | self.gamma = gamma 82 | # assert n_steps % batch_size == 0 83 | # self.reuse = reuse 84 | self.n_steps = n_steps 85 | self.gae_lambda = gae_lambda 86 | self.batch_size = batch_size 87 | self.minibatch_size = minibatch_size or batch_size 88 | assert self.batch_size % self.minibatch_size == 0 89 | self.clip_range = clip_range 90 | self.ent_coef = ent_coef 91 | self.vf_coef = vf_coef 92 | self.max_grad_norm = max_grad_norm 93 | 94 | self.running_rew_mean = 0 95 | self.running_rew_var = 1 96 | self.running_rew_count = 1e-4 97 | 98 | self.total_steps = 0 99 | self.logger = logger 100 | self.logger.watch((self.agent.actor, self.agent.critic)) 101 | self.timer = time.time_ns() // 1_000_000 102 | self.jit_tracer = None 103 | 104 | if kl_models_weights is not None: 105 | for i in range(len(kl_models_weights)): 106 | assert len(kl_models_weights[i]) in (2, 3) 107 | if len(kl_models_weights[i]) == 2: 108 | kl_models_weights[i] = kl_models_weights[i] +
(None,) 109 | self.kl_models_weights = kl_models_weights 110 | 111 | def update_reward_norm(self, rewards: np.ndarray) -> np.ndarray: 112 | batch_mean = np.mean(rewards) 113 | batch_var = np.var(rewards) 114 | batch_count = rewards.shape[0] 115 | 116 | delta = batch_mean - self.running_rew_mean 117 | tot_count = self.running_rew_count + batch_count 118 | 119 | new_mean = self.running_rew_mean + delta * batch_count / tot_count 120 | m_a = self.running_rew_var * self.running_rew_count 121 | m_b = batch_var * batch_count 122 | m_2 = m_a + m_b + np.square(delta) * self.running_rew_count * batch_count / ( 123 | self.running_rew_count + batch_count) 124 | new_var = m_2 / (self.running_rew_count + batch_count) 125 | 126 | new_count = batch_count + self.running_rew_count 127 | 128 | self.running_rew_mean = new_mean 129 | self.running_rew_var = new_var 130 | self.running_rew_count = new_count 131 | 132 | return (rewards - self.running_rew_mean) / np.sqrt(self.running_rew_var + 1e-8) # TODO normalize before update? 133 | 134 | def run(self, iterations_per_save=10, save_dir=None, save_jit=False): 135 | """ 136 | Generate rollout data and train 137 | :param iterations_per_save: number of iterations between checkpoint saves 138 | :param save_dir: where to save 139 | """ 140 | if save_dir: 141 | current_run_dir = os.path.join(save_dir, self.logger.project + "_" + str(time.time())) 142 | os.makedirs(current_run_dir) 143 | elif iterations_per_save and not save_dir: 144 | print("Warning: no save directory specified.") 145 | print("Checkpoints will not be saved.") 146 | 147 | iteration = self.starting_iteration 148 | rollout_gen = self.rollout_generator.generate_rollouts() 149 | 150 | self.rollout_generator.update_parameters(self.agent.actor) 151 | 152 | while True: 153 | # pr = cProfile.Profile() 154 | # pr.enable() 155 | t0 = time.time() 156 | 157 | def _iter(): 158 | size = 0 159 | print(f"Collecting rollouts ({iteration})...") 160 | while size < self.n_steps: 161 | try: 162 | rollout = next(rollout_gen) 163 | if rollout.size() > 0: 164 | size += rollout.size() 165 | # progress.update(rollout.size()) 166 | yield rollout 167 | except StopIteration: 168 | return 169 | 170 | self.calculate(_iter(), iteration) 171 | iteration += 1 172 | 173 | if save_dir: 174 | self.save(os.path.join(save_dir, self.logger.project + "_" + "latest"), -1, save_jit) 175 | if iteration % iterations_per_save == 0: 176 | self.save(current_run_dir, iteration, save_jit) # noqa 177 | 178 | if self.frozen_iterations > 0: 179 | if self.frozen_iterations == 1: 180 | print(" ** Unfreezing policy network **") 181 | 182 | assert self._saved_lr is not None 183 | self.agent.optimizer.param_groups[0]["lr"] = self._saved_lr 184 | self._saved_lr = None 185 | 186 | self.frozen_iterations -= 1 187 | 188 | self.rollout_generator.update_parameters(self.agent.actor) 189 | 190 | self.total_steps += self.n_steps # size 191 | t1 = time.time() 192 | self.logger.log({"ppo/steps_per_second": self.n_steps / (t1 - t0), "ppo/total_timesteps": self.total_steps}) 193 | 194 | # pr.disable() 195 | # s = io.StringIO() 196 | # sortby = pstats.SortKey.CUMULATIVE 197 | # ps = pstats.Stats(pr, stream=s).sort_stats(sortby) 198 | # ps.dump_stats(f"profile_{self.total_steps}") 199 | 200 | def set_logger(self, logger): 201 | self.logger = logger 202 | 203 | def evaluate_actions(self, observations, actions): 204 | """ 205 | Calculate Log Probability and Entropy of actions 206 | """ 207 | dist = self.agent.actor.get_action_distribution(observations) 208 | # indices = 
self.agent.get_action_indices(dists) 209 | 210 | log_prob = self.agent.actor.log_prob(dist, actions) 211 | entropy = self.agent.actor.entropy(dist, actions) 212 | 213 | entropy = -torch.mean(entropy) 214 | return log_prob, entropy, dist 215 | 216 | @staticmethod 217 | @numba.njit 218 | def _calculate_advantages_numba(rewards, values, gamma, gae_lambda, truncated): 219 | advantages = np.zeros_like(rewards) 220 | # v_targets = np.zeros_like(rewards) 221 | dones = np.zeros_like(rewards) 222 | dones[-1] = 1. if not truncated else 0. 223 | episode_starts = np.zeros_like(rewards) 224 | episode_starts[0] = 1. 225 | last_values = values[-1] 226 | last_gae_lam = 0 227 | size = len(advantages) 228 | for step in range(size - 1, -1, -1): 229 | if step == size - 1: 230 | next_non_terminal = 1.0 - dones[-1].item() 231 | next_values = last_values 232 | else: 233 | next_non_terminal = 1.0 - episode_starts[step + 1].item() 234 | next_values = values[step + 1] 235 | v_target = rewards[step] + gamma * next_values * next_non_terminal 236 | delta = v_target - values[step] 237 | last_gae_lam = delta + gamma * gae_lambda * next_non_terminal * last_gae_lam 238 | advantages[step] = last_gae_lam 239 | # v_targets[step] = v_target 240 | return advantages # , v_targets 241 | 242 | def calculate(self, buffers: Iterator[ExperienceBuffer], iteration): 243 | """ 244 | Calculate loss and update network 245 | """ 246 | obs_tensors = [] 247 | act_tensors = [] 248 | # value_tensors = [] 249 | log_prob_tensors = [] 250 | # advantage_tensors = [] 251 | returns_tensors = [] 252 | 253 | rewards_tensors = [] 254 | 255 | ep_rewards = [] 256 | ep_steps = [] 257 | n = 0 258 | 259 | for buffer in buffers: # Do discounts for each ExperienceBuffer individually 260 | if isinstance(buffer.observations[0], (tuple, list)): 261 | transposed = tuple(zip(*buffer.observations)) 262 | obs_tensor = tuple(torch.from_numpy(np.vstack(t)).float() for t in transposed) 263 | else: 264 | obs_tensor = th.from_numpy(np.vstack(buffer.observations)).float() 265 | 266 | with th.no_grad(): 267 | if isinstance(obs_tensor, tuple): 268 | x = tuple(o.to(self.device) for o in obs_tensor) 269 | else: 270 | x = obs_tensor.to(self.device) 271 | values = self.agent.critic(x).detach().cpu().numpy().flatten() # No batching? 
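# The "No batching?" note above could be addressed with chunked critic evaluation to bound VRAM use.
# A minimal, hedged sketch only (the 4096 chunk size is an illustrative choice, not a repo default):
#     parts = []
#     n_rows = len(obs_tensor[0]) if isinstance(obs_tensor, tuple) else len(obs_tensor)
#     for j in range(0, n_rows, 4096):
#         if isinstance(obs_tensor, tuple):
#             chunk = tuple(o[j:j + 4096].to(self.device) for o in obs_tensor)
#         else:
#             chunk = obs_tensor[j:j + 4096].to(self.device)
#         parts.append(self.agent.critic(chunk).cpu().numpy().flatten())
#     values = np.concatenate(parts)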
272 | 273 | actions = np.stack(buffer.actions) 274 | log_probs = np.stack(buffer.log_probs) 275 | rewards = np.stack(buffer.rewards) 276 | dones = np.stack(buffer.dones) 277 | 278 | size = rewards.shape[0] 279 | 280 | advantages = self._calculate_advantages_numba(rewards, values, self.gamma, self.gae_lambda, dones[-1] == 2) 281 | 282 | returns = advantages + values 283 | 284 | obs_tensors.append(obs_tensor) 285 | act_tensors.append(th.from_numpy(actions)) 286 | log_prob_tensors.append(th.from_numpy(log_probs)) 287 | returns_tensors.append(th.from_numpy(returns)) 288 | rewards_tensors.append(th.from_numpy(rewards)) 289 | 290 | ep_rewards.append(rewards.sum()) 291 | ep_steps.append(size) 292 | n += 1 293 | ep_rewards = np.array(ep_rewards) 294 | ep_steps = np.array(ep_steps) 295 | 296 | self.logger.log({ 297 | "ppo/ep_reward_mean": ep_rewards.mean(), 298 | "ppo/ep_reward_std": ep_rewards.std(), 299 | "ppo/ep_len_mean": ep_steps.mean(), 300 | }, step=iteration, commit=False) 301 | 302 | if isinstance(obs_tensors[0], tuple): 303 | transposed = zip(*obs_tensors) 304 | obs_tensor = tuple(th.cat(t).float() for t in transposed) 305 | else: 306 | obs_tensor = th.cat(obs_tensors).float() 307 | act_tensor = th.cat(act_tensors) 308 | log_prob_tensor = th.cat(log_prob_tensors).float() 309 | # advantages_tensor = th.cat(advantage_tensors) 310 | returns_tensor = th.cat(returns_tensors).float() 311 | 312 | tot_loss = 0 313 | tot_policy_loss = 0 314 | tot_entropy_loss = 0 315 | tot_value_loss = 0 316 | total_kl_div = 0 317 | tot_clipped = 0 318 | 319 | if self.kl_models_weights is not None: 320 | tot_kl_other_models = np.zeros(len(self.kl_models_weights)) 321 | tot_kl_coeffs = np.zeros(len(self.kl_models_weights)) 322 | 323 | n = 0 324 | 325 | if self.jit_tracer is None: 326 | self.jit_tracer = obs_tensor[0].to(self.device) 327 | 328 | print("Training network...") 329 | 330 | if self.frozen_iterations > 0: 331 | print("Policy network frozen, only updating value network...") 332 | 333 | precompute = torch.cat([param.view(-1) for param in self.agent.actor.parameters()]) 334 | t0 = time.perf_counter_ns() 335 | self.agent.optimizer.zero_grad(set_to_none=self.zero_grads_with_none) 336 | for e in range(self.epochs): 337 | # this is mostly pulled from sb3 338 | indices = torch.randperm(returns_tensor.shape[0])[:self.batch_size] 339 | if isinstance(obs_tensor, tuple): 340 | obs_batch = tuple(o[indices] for o in obs_tensor) 341 | else: 342 | obs_batch = obs_tensor[indices] 343 | act_batch = act_tensor[indices] 344 | log_prob_batch = log_prob_tensor[indices] 345 | # advantages_batch = advantages_tensor[indices] 346 | returns_batch = returns_tensor[indices] 347 | 348 | for i in range(0, self.batch_size, self.minibatch_size): 349 | # Note: Will cut off final few samples 350 | 351 | if isinstance(obs_tensor, tuple): 352 | obs = tuple(o[i: i + self.minibatch_size].to(self.device) for o in obs_batch) 353 | else: 354 | obs = obs_batch[i: i + self.minibatch_size].to(self.device) 355 | 356 | act = act_batch[i: i + self.minibatch_size].to(self.device) 357 | # adv = advantages_batch[i:i + self.minibatch_size].to(self.device) 358 | ret = returns_batch[i: i + self.minibatch_size].to(self.device) 359 | 360 | old_log_prob = log_prob_batch[i: i + self.minibatch_size].to(self.device) 361 | 362 | # TODO optimization: use forward_actor_critic instead of separate in case shared, also use GPU 363 | try: 364 | log_prob, entropy, dist = self.evaluate_actions(obs, act) # Assuming obs and actions as input 365 | except ValueError as e: 366 
| print("ValueError in evaluate_actions", e) 367 | continue 368 | 369 | ratio = torch.exp(log_prob - old_log_prob) 370 | 371 | values_pred = self.agent.critic(obs) 372 | 373 | values_pred = th.squeeze(values_pred) 374 | adv = ret - values_pred.detach() 375 | adv = (adv - th.mean(adv)) / (th.std(adv) + 1e-8) 376 | 377 | # clipped surrogate loss 378 | policy_loss_1 = adv * ratio 379 | policy_loss_2 = adv * th.clamp(ratio, 1 - self.clip_range, 1 + self.clip_range) 380 | policy_loss = -torch.min(policy_loss_1, policy_loss_2).mean() 381 | 382 | # **If we want value clipping, add it here** 383 | value_loss = F.mse_loss(ret, values_pred) 384 | 385 | if entropy is None: 386 | # Approximate entropy when no analytical form 387 | entropy_loss = -th.mean(-log_prob) 388 | else: 389 | entropy_loss = entropy 390 | 391 | kl_loss = 0 392 | if self.kl_models_weights is not None: 393 | for k, (model, kl_coef, half_life) in enumerate(self.kl_models_weights): 394 | if half_life is not None: 395 | kl_coef *= 0.5 ** (self.total_steps / half_life) 396 | with torch.no_grad(): 397 | dist_other = model.get_action_distribution(obs) 398 | div = kl_divergence(dist_other, dist).mean() 399 | tot_kl_other_models[k] += div 400 | tot_kl_coeffs[k] = kl_coef 401 | kl_loss += kl_coef * div 402 | 403 | loss = ((policy_loss + self.ent_coef * entropy_loss + self.vf_coef * value_loss + kl_loss) 404 | / (self.batch_size / self.minibatch_size)) 405 | 406 | if not torch.isfinite(loss).all(): 407 | print("Non-finite loss, skipping", n) 408 | print("\tPolicy loss:", policy_loss) 409 | print("\tEntropy loss:", entropy_loss) 410 | print("\tValue loss:", value_loss) 411 | print("\tTotal loss:", loss) 412 | print("\tRatio:", ratio) 413 | print("\tAdv:", adv) 414 | print("\tLog prob:", log_prob) 415 | print("\tOld log prob:", old_log_prob) 416 | print("\tEntropy:", entropy) 417 | print("\tActor has inf:", any(not p.isfinite().all() for p in self.agent.actor.parameters())) 418 | print("\tCritic has inf:", any(not p.isfinite().all() for p in self.agent.critic.parameters())) 419 | print("\tReward as inf:", not np.isfinite(ep_rewards).all()) 420 | if isinstance(obs, tuple): 421 | for j in range(len(obs)): 422 | print(f"\tObs[{j}] has inf:", not obs[j].isfinite().all()) 423 | else: 424 | print("\tObs has inf:", not obs.isfinite().all()) 425 | continue 426 | 427 | loss.backward() 428 | 429 | # Unbiased low variance KL div estimator from http://joschu.net/blog/kl-approx.html 430 | total_kl_div += th.mean((ratio - 1) - (log_prob - old_log_prob)).item() 431 | tot_loss += loss.item() 432 | tot_policy_loss += policy_loss.item() 433 | tot_entropy_loss += entropy_loss.item() 434 | tot_value_loss += value_loss.item() 435 | tot_clipped += th.mean((th.abs(ratio - 1) > self.clip_range).float()).item() 436 | n += 1 437 | # pb.update(self.minibatch_size) 438 | 439 | # Clip grad norm 440 | if self.max_grad_norm is not None: 441 | clip_grad_norm_(self.agent.actor.parameters(), self.max_grad_norm) 442 | 443 | self.agent.optimizer.step() 444 | self.agent.optimizer.zero_grad(set_to_none=self.zero_grads_with_none) 445 | 446 | t1 = time.perf_counter_ns() 447 | 448 | assert n > 0 449 | 450 | postcompute = torch.cat([param.view(-1) for param in self.agent.actor.parameters()]) 451 | 452 | log_dict = { 453 | "ppo/loss": tot_loss / n, 454 | "ppo/policy_loss": tot_policy_loss / n, 455 | "ppo/entropy_loss": tot_entropy_loss / n, 456 | "ppo/value_loss": tot_value_loss / n, 457 | "ppo/mean_kl": total_kl_div / n, 458 | "ppo/clip_fraction": tot_clipped / n, 459 | 
"ppo/epoch_time": (t1 - t0) / (1e6 * self.epochs), 460 | "ppo/update_magnitude": th.dist(precompute, postcompute, p=2), 461 | } 462 | 463 | if self.kl_models_weights is not None and len(self.kl_models_weights) > 0: 464 | log_dict.update({f"ppo/kl_div_model_{i}": tot_kl_other_models[i] / n 465 | for i in range(len(self.kl_models_weights))}) 466 | log_dict.update({f"ppo/kl_coeff_model_{i}": tot_kl_coeffs[i] 467 | for i in range(len(self.kl_models_weights))}) 468 | 469 | self.logger.log(log_dict, step=iteration, commit=False) # Is committed after when calculating fps 470 | 471 | def load(self, load_location, continue_iterations=True): 472 | """ 473 | load the model weights, optimizer values, and metadata 474 | :param load_location: checkpoint folder to read 475 | :param continue_iterations: keep the same training steps 476 | """ 477 | 478 | checkpoint = torch.load(load_location) 479 | self.agent.actor.load_state_dict(checkpoint['actor_state_dict']) 480 | self.agent.critic.load_state_dict(checkpoint['critic_state_dict']) 481 | self.agent.optimizer.load_state_dict(checkpoint['optimizer_state_dict']) 482 | 483 | if continue_iterations: 484 | self.starting_iteration = checkpoint['epoch'] 485 | self.total_steps = checkpoint["total_steps"] 486 | print("Continuing training at iteration " + str(self.starting_iteration)) 487 | 488 | def save(self, save_location, current_step, save_actor_jit=False): 489 | """ 490 | Save the model weights, optimizer values, and metadata 491 | :param save_location: where to save 492 | :param current_step: the current iteration when saved. Use to later continue training 493 | :param save_actor_jit: save the policy network as a torch jit file for rlbot use 494 | """ 495 | 496 | version_str = str(self.logger.project) + "_" + str(current_step) 497 | version_dir = save_location + "\\" + version_str 498 | 499 | os.makedirs(version_dir, exist_ok=current_step == -1) 500 | 501 | torch.save({ 502 | 'epoch': current_step, 503 | "total_steps": self.total_steps, 504 | 'actor_state_dict': self.agent.actor.state_dict(), 505 | 'critic_state_dict': self.agent.critic.state_dict(), 506 | 'optimizer_state_dict': self.agent.optimizer.state_dict(), 507 | # TODO save/load reward normalization mean, std, count 508 | }, version_dir + "\\checkpoint.pt") 509 | 510 | if save_actor_jit: 511 | traced_actor = th.jit.trace(self.agent.actor, self.jit_tracer) 512 | torch.jit.save(traced_actor, version_dir + "\\jit_policy.jit") 513 | 514 | def freeze_policy(self, frozen_iterations=100): 515 | """ 516 | Freeze policy network to allow value network to settle. Useful with pretrained policy networks. 517 | 518 | Note that network weights will not be transmitted when frozen. 
519 | 520 | :param frozen_iterations: how many iterations the policy update will remain unchanged 521 | """ 522 | 523 | print("-------------------------------------------------------------") 524 | print("Policy Weights frozen for " + str(frozen_iterations) + " iterations") 525 | print("-------------------------------------------------------------") 526 | 527 | self.frozen_iterations = frozen_iterations 528 | 529 | self._saved_lr = self.agent.optimizer.param_groups[0]["lr"] 530 | self.agent.optimizer.param_groups[0]["lr"] = 0 531 | -------------------------------------------------------------------------------- /rocket_learn/rollout_generator/redis/redis_rollout_worker.py: -------------------------------------------------------------------------------- 1 | import functools 2 | import itertools 3 | import os 4 | import time 5 | from threading import Thread 6 | from uuid import uuid4 7 | 8 | import sqlite3 as sql 9 | 10 | import numpy as np 11 | from redis import Redis 12 | from rlgym.envs import Match 13 | from rlgym.gamelaunch import LaunchPreference 14 | from rlgym.gym import Gym 15 | from tabulate import tabulate 16 | 17 | import rocket_learn.agent.policy 18 | import rocket_learn.utils.generate_episode 19 | from rocket_learn.rollout_generator.redis.utils import _unserialize_model, MODEL_LATEST, WORKER_IDS, OPPONENT_MODELS, \ 20 | VERSION_LATEST, _serialize, ROLLOUTS, encode_buffers, decode_buffers, get_rating, LATEST_RATING_ID, \ 21 | EXPERIENCE_PER_MODE 22 | from rocket_learn.utils.util import probability_NvsM 23 | from rocket_learn.utils.dynamic_gamemode_setter import DynamicGMSetter 24 | 25 | 26 | class RedisRolloutWorker: 27 | """ 28 | Provides RedisRolloutGenerator with rollouts via a Redis server 29 | 30 | :param redis: redis object 31 | :param name: rollout worker name 32 | :param match: match object 33 | :param past_version_prob: Odds of playing against previous checkpoints 34 | :param evaluation_prob: Odds of running an evaluation match 35 | :param sigma_target: Trueskill sigma target 36 | :param dynamic_gm: Pick game mode dynamically. If True, Match.team_size should be 3 37 | :param streamer_mode: Should run in streamer mode (less data printed to screen) 38 | :param send_gamestates: Should gamestate data be sent back (increases data sent) - must send obs or gamestates 39 | :param send_obs: Should observations be send back (increases data sent) - must send obs or gamestates 40 | :param scoreboard: Scoreboard object 41 | :param pretrained_agents: Dict{} of pretrained agents and their appearance probability 42 | :param human_agent: human agent object. Sets a human match if not None 43 | :param force_paging: Should paging be forced 44 | :param auto_minimize: automatically minimize the launched rocket league instance 45 | :param local_cache_name: name of local database used for model caching. If None, caching is not used 46 | :param gamemode_weights: dict of dynamic gamemode choice weights. If None, default equal experience 47 | """ 48 | 49 | def __init__(self, redis: Redis, name: str, match: Match, 50 | past_version_prob=.2, evaluation_prob=0.01, sigma_target=1, 51 | dynamic_gm=True, streamer_mode=False, send_gamestates=True, 52 | send_obs=True, scoreboard=None, pretrained_agents=None, 53 | human_agent=None, force_paging=False, auto_minimize=True, 54 | local_cache_name=None, gamemode_weights=None, full_team_evaluations=False, 55 | live_progress=True): 56 | # TODO model or config+params so workers can recreate just from redis connection? 
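# A hedged usage sketch for this class (the host, password and Match instance below are
# placeholders, not values shipped with this repo):
#     redis = Redis(host="localhost", password="rollout-password")
#     worker = RedisRolloutWorker(redis, "contributor-name", match,
#                                 past_version_prob=0.2, evaluation_prob=0.01,
#                                 send_gamestates=True, send_obs=True)
#     worker.run()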
57 | self.redis = redis 58 | self.name = name 59 | 60 | assert send_gamestates or send_obs, "Must have at least one of obs or states" 61 | 62 | self.pretrained_agents = {} 63 | self.pretrained_total_prob = 0 64 | if pretrained_agents is not None: 65 | self.pretrained_agents = pretrained_agents 66 | self.pretrained_total_prob = sum([self.pretrained_agents[key] for key in self.pretrained_agents]) 67 | 68 | self.human_agent = human_agent 69 | 70 | if human_agent and pretrained_agents: 71 | print("** WARNING - Human Player and Pretrained Agents are in conflict. **") 72 | print("** Pretrained Agents will be ignored. **") 73 | 74 | self.streamer_mode = streamer_mode 75 | 76 | self.current_agent = _unserialize_model(self.redis.get(MODEL_LATEST)) 77 | 78 | self.past_version_prob = past_version_prob 79 | self.evaluation_prob = evaluation_prob 80 | self.sigma_target = sigma_target 81 | self.send_gamestates = send_gamestates 82 | self.send_obs = send_obs 83 | self.dynamic_gm = dynamic_gm 84 | self.gamemode_weights = gamemode_weights 85 | if self.gamemode_weights is not None: 86 | assert np.isclose(sum(self.gamemode_weights.values()), 1), "gamemode_weights must sum to 1" 87 | self.gamemode_exp_per_episode_ema = {} 88 | self.local_cache_name = local_cache_name 89 | 90 | self.full_team_evaluations = full_team_evaluations 91 | 92 | self.live_progress = live_progress 93 | 94 | self.uuid = str(uuid4()) 95 | self.redis.rpush(WORKER_IDS, self.uuid) 96 | 97 | # currently doesn't rebuild, if the old is there, reuse it. 98 | if self.local_cache_name: 99 | self.sql = sql.connect('redis-model-cache-' + local_cache_name + '.db') 100 | # if the table doesn't exist in the database, make it 101 | self.sql.execute(""" 102 | CREATE TABLE if not exists MODELS ( 103 | id TEXT PRIMARY KEY, 104 | parameters BLOB NOT NULL 105 | ); 106 | """) 107 | 108 | if not self.streamer_mode: 109 | print("Started worker", self.uuid, "on host", self.redis.connection_pool.connection_kwargs.get("host"), 110 | "under name", name) # TODO log instead 111 | else: 112 | print("Streaming mode set. 
Running silent.") 113 | 114 | self.scoreboard = scoreboard 115 | state_setter = DynamicGMSetter(match._state_setter) # noqa Rangler made me do it 116 | self.set_team_size = state_setter.set_team_size 117 | match._state_setter = state_setter 118 | self.match = match 119 | self.env = Gym(match=self.match, pipe_id=os.getpid(), launch_preference=LaunchPreference.EPIC, 120 | use_injector=True, force_paging=force_paging, raise_on_crash=True, auto_minimize=auto_minimize) 121 | self.total_steps_generated = 0 122 | 123 | def _get_opponent_ids(self, n_new, n_old, pretrained_choice): 124 | # Get qualities 125 | assert (n_new + n_old) % 2 == 0 126 | per_team = (n_new + n_old) // 2 127 | gamemode = f"{per_team}v{per_team}" 128 | latest_id = self.redis.get(LATEST_RATING_ID).decode("utf-8") 129 | latest_key = f"{latest_id}-stochastic" 130 | if n_old == 0: 131 | rating = get_rating(gamemode, latest_key, self.redis) 132 | return [-1] * n_new, [rating] * n_new 133 | 134 | ratings = get_rating(gamemode, None, self.redis) 135 | latest_rating = ratings[latest_key] 136 | keys, values = zip(*ratings.items()) 137 | 138 | is_eval = (n_new == 0 and len(values) >= n_old) 139 | if is_eval: # Evaluation game, try to find agents with high sigma 140 | sigmas = np.array([r.sigma for r in values]) 141 | probs = np.clip(sigmas - self.sigma_target, a_min=0, a_max=None) 142 | s = probs.sum() 143 | if s == 0: # No versions with high sigma available 144 | if np.random.normal(0, self.sigma_target) > 1: 145 | # Some chance of doing a match with random versions, so they might correct themselves 146 | probs = np.ones_like(probs) / len(probs) 147 | else: 148 | return [-1] * n_old, [latest_rating] * n_old 149 | else: 150 | probs /= s 151 | versions = [np.random.choice(len(keys), p=probs)] 152 | if self.full_team_evaluations: 153 | versions = versions * per_team 154 | target_rating = values[versions[0]] 155 | elif pretrained_choice is not None: # pretrained agent chosen, just need index generation 156 | matchups = np.full((n_new + n_old), -1).tolist() 157 | for i in range(n_old): 158 | index = np.random.randint(0, n_new + n_old) 159 | matchups[index] = 'na' 160 | return matchups, ratings.values() 161 | else: 162 | if n_new == 0: # Would-be evaluation game, but not enough agents 163 | n_new = n_old 164 | n_old = 0 165 | versions = [-1] * n_new 166 | target_rating = latest_rating 167 | 168 | # Calculate 1v1 win prob against target 169 | # All the agents included should hold their own (at least approximately) 170 | # This is to prevent unrealistic scenarios, 171 | # like for instance ratings of [100, 0] vs [100, 0], which is technically fair but not useful 172 | probs = np.zeros(len(keys)) 173 | if n_new == 0 and self.full_team_evaluations: 174 | for i, rating in enumerate(values): 175 | if i == versions[0]: 176 | p = 0 # Don't add more of the same agent in evaluation matches 177 | else: 178 | p = probability_NvsM([rating] * per_team, [target_rating] * per_team) 179 | probs[i] = p * (1 - p) 180 | probs /= probs.sum() 181 | opponent = np.random.choice(len(probs), p=probs) 182 | if np.random.random() < 0.5: # Randomly do blue/orange 183 | versions = versions + [opponent] * per_team 184 | else: 185 | versions = [opponent] * per_team + versions 186 | return [keys[i] for i in versions], [values[i] for i in versions] 187 | else: 188 | for i, rating in enumerate(values): 189 | if n_new == 0 and i == versions[0]: 190 | continue # Don't add more of the same agent in evaluation matches 191 | p = probability_NvsM([rating], [target_rating]) 192 | 
probs[i] = (p * (1 - p)) ** ((n_new + n_old) // 2) # Be less lenient the more players there are 193 | probs /= probs.sum() 194 | 195 | old_versions = np.random.choice(len(probs), size=n_old - is_eval, p=probs, replace=True).tolist() 196 | versions += old_versions 197 | 198 | # Then calculate the full matchup, with just permutations of the selected versions (weighted by fairness) 199 | matchups = [] 200 | qualities = [] 201 | for perm in itertools.permutations(versions): 202 | it_ratings = [latest_rating if v == -1 else values[v] for v in perm] 203 | mid = len(it_ratings) // 2 204 | p = probability_NvsM(it_ratings[:mid], it_ratings[mid:]) 205 | if n_new == 0 and set(perm[:mid]) == set(perm[mid:]): # Don't want team against team 206 | p = 0 207 | matchups.append(perm) 208 | qualities.append(p * (1 - p)) # From AlphaStar 209 | qualities = np.array(qualities) 210 | s = qualities.sum() 211 | if s == 0: 212 | return [-1] * (n_new + n_old), [latest_rating] * (n_new + n_old) 213 | k = np.random.choice(len(matchups), p=qualities / s) 214 | return [-1 if i == -1 else keys[i] for i in matchups[k]], \ 215 | [latest_rating if i == -1 else values[i] for i in matchups[k]] 216 | 217 | @functools.lru_cache(maxsize=8) 218 | def _get_past_model(self, version): 219 | # if version in local database, query from database 220 | # if not, pull from REDIS and store in disk cache 221 | 222 | if self.local_cache_name: 223 | models = self.sql.execute("SELECT parameters FROM MODELS WHERE id == ?", (version,)).fetchall() 224 | if len(models) == 0: 225 | bytestream = self.redis.hget(OPPONENT_MODELS, version) 226 | model = _unserialize_model(bytestream) 227 | 228 | self.sql.execute('INSERT INTO MODELS (id, parameters) VALUES (?, ?)', (version, bytestream)) 229 | self.sql.commit() 230 | else: 231 | # should only ever be 1 version of parameters 232 | assert len(models) <= 1 233 | # stored as tuple due to sqlite, 234 | assert len(models[0]) == 1 235 | 236 | bytestream = models[0][0] 237 | model = _unserialize_model(bytestream) 238 | else: 239 | model = _unserialize_model(self.redis.hget(OPPONENT_MODELS, version)) 240 | 241 | return model 242 | 243 | def select_gamemode(self, equal_likelihood): 244 | mode_exp = {m.decode("utf-8"): int(v) for m, v in self.redis.hgetall(EXPERIENCE_PER_MODE).items()} 245 | modes = list(mode_exp.keys()) 246 | if equal_likelihood: 247 | mode = np.random.choice(modes) 248 | else: 249 | dist = np.array(list(mode_exp.values())) + 1 250 | dist = dist / dist.sum() 251 | if self.gamemode_weights is None: 252 | target_dist = np.ones(len(modes)) 253 | else: 254 | target_dist = np.array([self.gamemode_weights[k] for k in modes]) 255 | mode_steps_per_episode = np.array(list(self.gamemode_exp_per_episode_ema.get(m, None) or 1 for m in modes)) 256 | 257 | target_dist = target_dist / mode_steps_per_episode 258 | target_dist = target_dist / target_dist.sum() 259 | inv_dist = 1 - dist 260 | inv_dist = inv_dist / inv_dist.sum() 261 | 262 | dist = target_dist * inv_dist 263 | dist = dist / dist.sum() 264 | 265 | mode = np.random.choice(modes, p=dist) 266 | 267 | b, o = mode.split("v") 268 | return int(b), int(o) 269 | 270 | @staticmethod 271 | def make_table(versions, ratings, blue, orange, pretrained_choice): 272 | version_info = [] 273 | for v, r in zip(versions, ratings): 274 | if pretrained_choice is not None and v == 'na': # print name but don't send it back 275 | version_info.append([str(type(pretrained_choice).__name__), "N/A"]) 276 | elif v == 'na': 277 | version_info.append(['Human', "N/A"]) 278 | else: 
279 | if isinstance(v, int) and v < 0: 280 | v = f"Latest ({-v})" 281 | version_info.append([v, f"{r.mu:.2f}±{2 * r.sigma:.2f}"]) 282 | 283 | blue_versions, blue_ratings = list(zip(*version_info[:blue])) 284 | orange_versions, orange_ratings = list(zip(*version_info[blue:])) 285 | 286 | if blue < orange: 287 | blue_versions += [""] * (orange - blue) 288 | blue_ratings += [""] * (orange - blue) 289 | elif orange < blue: 290 | orange_versions += [""] * (blue - orange) 291 | orange_ratings += [""] * (blue - orange) 292 | 293 | table_str = tabulate(list(zip(blue_versions, blue_ratings, orange_versions, orange_ratings)), 294 | headers=["Blue", "rating", "Orange", "rating"], tablefmt="rounded_outline") 295 | 296 | return table_str 297 | 298 | def run(self): # Mimics Thread 299 | """ 300 | begin processing in already launched match and push to redis 301 | """ 302 | n = 0 303 | latest_version = None 304 | # t = Thread() 305 | # t.start() 306 | while True: 307 | # Get the most recent version available 308 | available_version = self.redis.get(VERSION_LATEST) 309 | if available_version is None: 310 | time.sleep(1) 311 | continue # Wait for version to be published (not sure if this is necessary?) 312 | available_version = int(available_version) 313 | 314 | # Only try to download latest version when new 315 | if latest_version != available_version: 316 | model_bytes = self.redis.get(MODEL_LATEST) 317 | if model_bytes is None: 318 | time.sleep(1) 319 | continue # This is maybe not necessary? Can't hurt to leave it in. 320 | latest_version = available_version 321 | updated_agent = _unserialize_model(model_bytes) 322 | self.current_agent = updated_agent 323 | 324 | n += 1 325 | pretrained_choice = None 326 | 327 | evaluate = np.random.random() < self.evaluation_prob 328 | 329 | if self.dynamic_gm: 330 | blue, orange = self.select_gamemode(equal_likelihood=evaluate or self.streamer_mode) 331 | elif self.match._spawn_opponents is False: 332 | blue = self.match.agents 333 | orange = 0 334 | else: 335 | blue = orange = self.match.agents // 2 336 | self.set_team_size(blue, orange) 337 | 338 | if self.human_agent: 339 | n_new = blue + orange - 1 340 | versions = ['na'] 341 | 342 | agents = [self.human_agent] 343 | for n in range(n_new): 344 | agents.append(self.current_agent) 345 | versions.append(-1) 346 | 347 | versions = [v if v != -1 else latest_version for v in versions] 348 | ratings = ["na"] * len(versions) 349 | else: 350 | # TODO customizable past agent selection, should team only be same agent? 
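# Worked example of the selection wheel in _generate_matchup below, with illustrative
# (non-default) numbers: past_version_prob=0.2 and pretrained_agents={nexto: 0.1, necto: 0.05},
# in that insertion order, give pretrained_total_prob=0.15, so
#     rand_choice < 0.20          -> some seats use past checkpoints,
#     0.20 <= rand_choice < 0.30  -> nexto fills the 'na' seats,
#     0.30 <= rand_choice < 0.35  -> necto fills the 'na' seats,
#     rand_choice >= 0.35         -> every seat runs the latest model.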
351 | agents, pretrained_choice, versions, ratings = self._generate_matchup(blue + orange, 352 | latest_version, 353 | pretrained_choice, 354 | evaluate) 355 | 356 | evaluate = not any(isinstance(v, int) and v < 0 for v in versions) # Might be changed in matchup code 357 | 358 | table_str = self.make_table(versions, ratings, blue, orange, pretrained_choice) 359 | 360 | if evaluate and not self.streamer_mode and self.human_agent is None: 361 | print("EVALUATION GAME\n" + table_str) 362 | result = rocket_learn.utils.generate_episode.generate_episode(self.env, agents, evaluate=True, 363 | scoreboard=self.scoreboard, 364 | progress=self.live_progress) 365 | rollouts = [] 366 | print("Evaluation finished, goal differential:", result) 367 | print() 368 | else: 369 | if not self.streamer_mode: 370 | print("ROLLOUT\n" + table_str) 371 | 372 | try: 373 | rollouts, result = rocket_learn.utils.generate_episode.generate_episode(self.env, agents, 374 | evaluate=False, 375 | scoreboard=self.scoreboard, 376 | progress=self.live_progress) 377 | 378 | if len(rollouts[0].observations) <= 1: # Happens sometimes, unknown reason 379 | print(" ** Rollout Generation Error: Restarting Generation ** ") 380 | print() 381 | continue 382 | except EnvironmentError: 383 | self.env.attempt_recovery() 384 | continue 385 | 386 | state = rollouts[0].infos[-2]["state"] 387 | goal_speed = np.linalg.norm(state.ball.linear_velocity) * 0.036 # kph 388 | str_result = ('+' if result > 0 else "") + str(result) 389 | episode_exp = len(rollouts[0].observations) * len(rollouts) 390 | self.total_steps_generated += episode_exp 391 | 392 | if self.dynamic_gm and not evaluate: 393 | mode = f"{blue}v{orange}" 394 | if mode in self.gamemode_exp_per_episode_ema: 395 | current_mean = self.gamemode_exp_per_episode_ema[mode] 396 | self.gamemode_exp_per_episode_ema[mode] = 0.98 * current_mean + 0.02 * episode_exp 397 | else: 398 | self.gamemode_exp_per_episode_ema[mode] = episode_exp 399 | 400 | post_stats = f"Rollout finished after {len(rollouts[0].observations)} steps ({self.total_steps_generated} total steps), result was {str_result}" 401 | if result != 0: 402 | post_stats += f", goal speed: {goal_speed:.2f} kph" 403 | 404 | if not self.streamer_mode: 405 | print(post_stats) 406 | print() 407 | 408 | if not self.streamer_mode: 409 | rollout_data = encode_buffers(rollouts, 410 | return_obs=self.send_obs, 411 | return_states=self.send_gamestates, 412 | return_rewards=True) 413 | # sanity_check = decode_buffers(rollout_data, versions, 414 | # has_obs=False, has_states=True, has_rewards=True, 415 | # obs_build_factory=lambda: self.match._obs_builder, 416 | # rew_func_factory=lambda: self.match._reward_fn, 417 | # act_parse_factory=lambda: self.match._action_parser) 418 | rollout_bytes = _serialize((rollout_data, versions, self.uuid, self.name, result, 419 | self.send_obs, self.send_gamestates, True)) 420 | 421 | # while True: 422 | # t.join() 423 | 424 | def send(): 425 | n_items = self.redis.rpush(ROLLOUTS, rollout_bytes) 426 | if n_items >= 1000: 427 | print("Had to limit rollouts. 
Learner may have crashed, or is overloaded") 428 | self.redis.ltrim(ROLLOUTS, -100, -1) 429 | 430 | send() 431 | # t = Thread(target=send) 432 | # t.start() 433 | # time.sleep(0.01) 434 | 435 | def _generate_matchup(self, n_agents, latest_version, pretrained_choice, evaluate): 436 | if evaluate: 437 | n_old = n_agents 438 | else: 439 | n_old = 0 440 | rand_choice = np.random.random() 441 | if rand_choice < self.past_version_prob: 442 | n_old = np.random.randint(low=1, high=n_agents) 443 | elif rand_choice < (self.past_version_prob + self.pretrained_total_prob): 444 | wheel_prob = self.past_version_prob 445 | for agent in self.pretrained_agents: 446 | wheel_prob += self.pretrained_agents[agent] 447 | if rand_choice < wheel_prob: 448 | pretrained_choice = agent 449 | n_old = np.random.randint(low=1, high=n_agents) 450 | break 451 | n_new = n_agents - n_old 452 | versions, ratings = self._get_opponent_ids(n_new, n_old, pretrained_choice) 453 | agents = [] 454 | for version in versions: 455 | if version == -1: 456 | agents.append(self.current_agent) 457 | elif pretrained_choice is not None and version == 'na': 458 | agents.append(pretrained_choice) 459 | else: 460 | selected_agent = self._get_past_model("-".join(version.split("-")[:-1])) 461 | if version.endswith("deterministic"): 462 | selected_agent.deterministic = True 463 | elif version.endswith("stochastic"): 464 | selected_agent.deterministic = False 465 | else: 466 | raise ValueError("Unknown version type") 467 | agents.append(selected_agent) 468 | if self.streamer_mode > 1: 469 | agents[-1].deterministic = True 470 | versions = [v if v != -1 else latest_version for v in versions] 471 | return agents, pretrained_choice, versions, ratings 472 | --------------------------------------------------------------------------------
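# Hedged appendix sketch, not one of the repository files above: _add_opponent in
# redis_rollout_generator.py and _generate_matchup/_get_past_model in redis_rollout_worker.py
# share a key convention, illustrated here with made-up values.
name, version, mode = "example-run", 7, "stochastic"        # illustrative values only
model_key = f"{name}-v{version}"                            # hash field in OPPONENT_MODELS
rating_key = f"{model_key}-{mode}"                          # hash field in QUALITIES per gamemode
assert "-".join(rating_key.split("-")[:-1]) == model_key    # the lookup _get_past_model performs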