├── .gitignore
├── rocket_learn
├── __init__.py
├── agent
│ ├── __init__.py
│ ├── pretrained_agents
│ │ ├── nexto
│ │ │ ├── nexto-model.pt
│ │ │ ├── nexto_v2.py
│ │ │ └── nexto_v2_obs.py
│ │ ├── necto
│ │ │ ├── necto-model-10Y.pt
│ │ │ ├── necto-model-20Y.pt
│ │ │ ├── necto-model-30Y.pt
│ │ │ ├── necto_v1.py
│ │ │ └── necto_v1_obs.py
│ │ └── human_agent.py
│ ├── actor_critic_agent.py
│ ├── policy.py
│ ├── pretrained_policy.py
│ └── discrete_policy.py
├── utils
│ ├── __init__.py
│ ├── stat_trackers
│ │ ├── __init__.py
│ │ ├── stat_tracker.py
│ │ └── common_trackers.py
│ ├── truncated_condition.py
│ ├── dynamic_gamemode_setter.py
│ ├── batched_obs_builder.py
│ ├── util.py
│ ├── gamestate_encoding.py
│ ├── scoreboard.py
│ └── generate_episode.py
├── rollout_generator
│ ├── __init__.py
│ ├── redis
│ │ ├── __init__.py
│ │ ├── utils.py
│ │ ├── redis_rollout_generator.py
│ │ └── redis_rollout_worker.py
│ ├── base_rollout_generator.py
│ ├── simple_rollout_generator.py
│ └── api_rollout_generator.py
├── learner.py
├── simple_agents.py
├── experience_buffer.py
├── agent.py
└── ppo.py
├── setup.py
├── docs
├── sb3_to_rocketlearn_transition.txt
├── troubleshooting.txt
└── network_setup_readme.txt
├── examples
├── default
│ ├── sb3_to_rocketlearn_transition.txt
│ ├── worker.py
│ └── learner.py
├── human_trainer
│ └── worker_with_human_trainer.py
├── pretrained_agent
│ └── worker_with_pretrained_agent.py
└── loading
│ └── learner.py
└── README.md
/.gitignore:
--------------------------------------------------------------------------------
1 | *__pycache__*
--------------------------------------------------------------------------------
/rocket_learn/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/rocket_learn/agent/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/rocket_learn/utils/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/rocket_learn/rollout_generator/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/rocket_learn/rollout_generator/redis/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/rocket_learn/utils/stat_trackers/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/rocket_learn/agent/pretrained_agents/nexto/nexto-model.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Rolv-Arild/rocket-learn/HEAD/rocket_learn/agent/pretrained_agents/nexto/nexto-model.pt
--------------------------------------------------------------------------------
/rocket_learn/agent/pretrained_agents/necto/necto-model-10Y.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Rolv-Arild/rocket-learn/HEAD/rocket_learn/agent/pretrained_agents/necto/necto-model-10Y.pt
--------------------------------------------------------------------------------
/rocket_learn/agent/pretrained_agents/necto/necto-model-20Y.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Rolv-Arild/rocket-learn/HEAD/rocket_learn/agent/pretrained_agents/necto/necto-model-20Y.pt
--------------------------------------------------------------------------------
/rocket_learn/agent/pretrained_agents/necto/necto-model-30Y.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Rolv-Arild/rocket-learn/HEAD/rocket_learn/agent/pretrained_agents/necto/necto-model-30Y.pt
--------------------------------------------------------------------------------
/rocket_learn/utils/stat_trackers/stat_tracker.py:
--------------------------------------------------------------------------------
1 | import abc
2 |
3 | import numpy as np
4 |
5 |
6 | class StatTracker(abc.ABC):
7 | def __init__(self, name):
8 | self.name = name
9 |
10 | def reset(self): # Called whenever
11 | raise NotImplementedError
12 |
13 | def update(self, gamestates: np.ndarray, masks: np.ndarray):
14 | raise NotImplementedError
15 |
16 | def get_stat(self):
17 | raise NotImplementedError
18 |
--------------------------------------------------------------------------------
/rocket_learn/rollout_generator/base_rollout_generator.py:
--------------------------------------------------------------------------------
1 | from abc import ABC, abstractmethod
2 | from typing import Iterator
3 |
4 | from rocket_learn.experience_buffer import ExperienceBuffer
5 |
6 |
7 | class BaseRolloutGenerator(ABC):
8 | @abstractmethod
9 | def generate_rollouts(self) -> Iterator[ExperienceBuffer]:
10 | raise NotImplementedError
11 |
12 | @abstractmethod
13 | def update_parameters(self, new_params):
14 | raise NotImplementedError
15 |
--------------------------------------------------------------------------------
/rocket_learn/agent/actor_critic_agent.py:
--------------------------------------------------------------------------------
1 | import torch as th
2 | from torch import nn
3 |
4 | from rocket_learn.agent.policy import Policy
5 |
6 |
7 | class ActorCriticAgent(nn.Module):
8 | def __init__(self, actor: Policy, critic: nn.Module, optimizer: th.optim.Optimizer):
9 | super().__init__()
10 | self.actor = actor
11 | self.critic = critic
12 | self.optimizer = optimizer
13 | # self.algo = ?
14 | # TODO self.shared =
15 |
16 | def forward(self, *args, **kwargs):
17 | return self.actor(*args, **kwargs), self.critic(*args, **kwargs)
18 |
--------------------------------------------------------------------------------
/rocket_learn/learner.py:
--------------------------------------------------------------------------------
1 | from typing import Any
2 |
3 | import cloudpickle
4 |
5 |
6 | class CloudpickleWrapper:
7 | """
8 | ** Copied from SB3 **
9 | Uses cloudpickle to serialize contents (otherwise multiprocessing tries to use pickle)
10 | :param var: the variable you wish to wrap for pickling with cloudpickle
11 | """
12 |
13 | def __init__(self, var: Any):
14 | self.var = var
15 |
16 | def __getstate__(self) -> Any:
17 | return cloudpickle.dumps(self.var)
18 |
19 | def __setstate__(self, var: Any) -> None:
20 | self.var = cloudpickle.loads(var)
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import setup, find_packages
2 |
3 | long_description = """
4 |
5 | # Rocket Learn
6 |
7 | Learning!
8 |
9 | """
10 |
11 | setup(
12 | name='rocket_learn',
13 | version='0.2.8',
14 | description='Rocket Learn',
15 | author='Rolv-Arild Braaten, Daniel Downs',
16 | url='https://github.com/Rolv-Arild/rocket-learn',
17 | packages=[package for package in find_packages() if package.startswith("rocket_learn")],
18 | long_description=long_description,
19 | install_requires=['cloudpickle==1.6.0', 'gym', 'torch', 'tqdm', 'trueskill',
20 | 'msgpack_numpy', 'wandb', 'pygame', 'keyboard', 'tabulate'],
21 | )
22 |
--------------------------------------------------------------------------------
/rocket_learn/utils/truncated_condition.py:
--------------------------------------------------------------------------------
1 | from abc import ABC
2 |
3 | from rlgym.utils import TerminalCondition
4 | from rlgym.utils.gamestates import GameState
5 |
6 |
7 | class TruncatedCondition(TerminalCondition, ABC):
8 | def is_truncated(self, current_state: GameState):
9 | raise NotImplementedError
10 |
11 |
12 | class TerminalToTruncatedWrapper(TruncatedCondition):
13 | def __init__(self, condition: TerminalCondition):
14 | super().__init__()
15 | self.condition = condition
16 |
17 | def is_truncated(self, current_state: GameState):
18 | return self.condition.is_terminal(current_state)
19 |
20 | def reset(self, initial_state: GameState):
21 | self.condition.reset(initial_state)
22 |
23 | def is_terminal(self, current_state: GameState) -> bool:
24 | return False
25 |
--------------------------------------------------------------------------------
/rocket_learn/utils/dynamic_gamemode_setter.py:
--------------------------------------------------------------------------------
1 | from rlgym.utils import StateSetter
2 | from rlgym.utils.state_setters import StateWrapper
3 |
4 |
5 | class DynamicGMSetter(StateSetter):
6 | def __init__(self, setter: StateSetter):
7 | self.setter = setter
8 | self.blue = 0
9 | self.orange = 0
10 |
11 | def set_team_size(self, blue=None, orange=None):
12 | if blue is not None:
13 | self.blue = blue
14 | if orange is not None:
15 | self.orange = orange
16 |
17 | def build_wrapper(self, max_team_size: int, spawn_opponents: bool) -> StateWrapper:
18 | assert self.blue <= max_team_size and self.orange <= max_team_size
19 | return StateWrapper(self.blue, self.orange)
20 |
21 | def reset(self, state_wrapper: StateWrapper):
22 | self.setter.reset(state_wrapper)
23 |
--------------------------------------------------------------------------------
/rocket_learn/agent/policy.py:
--------------------------------------------------------------------------------
1 | from abc import ABC, abstractmethod
2 |
3 | from torch import nn
4 |
5 | class Policy(nn.Module, ABC):
6 | def __init__(self, deterministic=False):
7 | super().__init__()
8 | self.deterministic = deterministic
9 |
10 | @abstractmethod
11 | def forward(self, *args, **kwargs): raise NotImplementedError
12 |
13 | @abstractmethod
14 | def get_action_distribution(self, obs): raise NotImplementedError
15 |
16 | @staticmethod
17 | @abstractmethod
18 | def sample_action(distribution, deterministic=None): raise NotImplementedError
19 |
20 | @staticmethod
21 | @abstractmethod
22 | def log_prob(distribution, selected_action): raise NotImplementedError
23 |
24 | @staticmethod
25 | @abstractmethod
26 | def entropy(distribution, selected_action): raise NotImplementedError
27 |
28 | @abstractmethod
29 | def env_compatible(self, action): raise NotImplementedError
30 |
--------------------------------------------------------------------------------
/rocket_learn/rollout_generator/simple_rollout_generator.py:
--------------------------------------------------------------------------------
1 | from typing import Iterator
2 |
3 | import rlgym
4 | from rocket_learn.agent.policy import Policy
5 | from rocket_learn.experience_buffer import ExperienceBuffer
6 | from rocket_learn.rollout_generator.base_rollout_generator import BaseRolloutGenerator
7 | from rocket_learn.utils.generate_episode import generate_episode
8 |
9 |
10 | class SimpleRolloutGenerator(BaseRolloutGenerator):
11 | def __init__(self, policy: Policy, **make_args):
12 | self.env = rlgym.make(**make_args)
13 | self.policy = policy
14 | self.n_agents = self.env._match.agents
15 |
16 | def generate_rollouts(self) -> Iterator[ExperienceBuffer]:
17 | while True:
18 | # TODO: need to add selfplay agent here?
19 | rollouts, result = generate_episode(self.env, [self.policy] * self.n_agents)
20 |
21 | yield from rollouts
22 |
23 | def update_parameters(self, new_params):
24 | self.policy = new_params
25 |
--------------------------------------------------------------------------------
/docs/sb3_to_rocketlearn_transition.txt:
--------------------------------------------------------------------------------
1 | SWITCHING FROM STABLE-BASELINES3 TO ROCKET-LEARN
2 |
3 | When you install rocket-learn, verify that the version of rlgym is compatible with rocket-learn.
4 | Use only the release version of rlgym. Do not use beta of rlgym unless you are attempting to beta test.
5 | When setting up your environment, make sure you do not install rlgym from a cached version on your
6 | machine, and verify that the dll in the bakkesmod plugin folder is accurate.
7 |
8 | SB3 abstracts away several important parts of ML training that rocket-learn does not.
9 |
10 | -your rewards will not be normalized
11 | -your networks will not have orthogonal initialization by default (assuming you use PPO)
12 |
13 | This can drastically affect the results you get and it is not uncommon to not see the same results
14 | in rocket-learn as you did in SB3, at least until you make tweaks. In addition to the major
15 | differences listed above, differences in implementation in learning algorithms can cause large
16 | changes in results. Be prepared to do some extra tweaking as a part of the switch.
--------------------------------------------------------------------------------
/rocket_learn/rollout_generator/api_rollout_generator.py:
--------------------------------------------------------------------------------
1 | from typing import Iterator
2 | from uuid import uuid4
3 |
4 | from fastapi import FastAPI # noqa
5 | from pydantic import Field
6 | from pydantic.main import Model # noqa
7 | from starlette.middleware.gzip import GZipMiddleware
8 |
9 | from rocket_learn.experience_buffer import ExperienceBuffer
10 | from rocket_learn.rollout_generator.base_rollout_generator import BaseRolloutGenerator
11 |
12 |
13 | app = FastAPI(
14 | title="rocket-learn-api",
15 | version="0.1.0"
16 | )
17 |
18 | workers = {}
19 |
20 |
21 | @app.post("/worker")
22 | async def create_worker(name: str = "contributor"):
23 | uuid = uuid4()
24 | workers[uuid] = name
25 | return uuid
26 |
27 | @app.get("/matchup")
28 | async def get_matchup(mode: int):
29 | qualities = [...]
30 |
31 | @app.post("/rollout")
32 | async def rollout(obs_rew_probs: bytes, ):
33 |
34 |
35 |
36 | class ApiRolloutGenerator(BaseRolloutGenerator):
37 | def generate_rollouts(self) -> Iterator[ExperienceBuffer]:
38 | pass
39 |
40 | def update_parameters(self, new_params):
41 | pass
--------------------------------------------------------------------------------
/examples/default/sb3_to_rocketlearn_transition.txt:
--------------------------------------------------------------------------------
1 | SWITCHING FROM STABLE-BASELINES3 TO ROCKET-LEARN
2 |
3 | When you install rocket-learn, verify that the version of rlgym is compatible with rocket-learn.
4 | Use only the release version of rlgym. Do not use beta of rlgym unless you are attempting to beta test.
5 | When setting up your environment, make sure you do not install rlgym from a cached version on your
6 | machine, and verify that the dll in the bakkesmod plugin folder is accurate.
7 |
8 | SB3 abstracts away several important parts of ML training that rocket-learn does not.
9 |
10 | -your rewards will not be normalized
11 | -your networks will not have orthogonal initialization by default (assuming you use PPO)
12 |
13 | This can drastically affect the results you get and it is not uncommon to not see the same results
14 | in rocket-learn as you did in SB3, at least until you make tweaks. In addition to the major
15 | differences listed above, differences in implementation in learning algorithms can cause large
16 | changes in results. Be prepared to do some extra tweaking as a part of the switch.
17 |
18 |
19 |
20 |
21 |
--------------------------------------------------------------------------------
/rocket_learn/agent/pretrained_policy.py:
--------------------------------------------------------------------------------
1 | from abc import ABC, abstractmethod
2 | from typing import Tuple
3 | from torch import nn
4 |
5 | from rocket_learn.agent.discrete_policy import DiscretePolicy
6 |
7 | from rlgym.utils.gamestates import GameState
8 |
9 |
10 | class HardcodedAgent(ABC):
11 | """
12 | An external bot prebuilt and imported to be trained against
13 | """
14 |
15 | @abstractmethod
16 | def act(self, state: GameState, player_index: int): raise NotImplementedError
17 |
18 |
19 | class PretrainedDiscretePolicy(DiscretePolicy, HardcodedAgent):
20 | """
21 | A rocket-learn discrete policy pretrained and imported to be trained against
22 |
23 | :param obs_builder_func: Function that will generate the correct observation from the gamestate
24 | :param net: policy net
25 | :param shape: action distribution shape
26 | """
27 |
28 | def __init__(self, obs_builder_func, net: nn.Module, shape: Tuple[int, ...] = (3,) * 5 + (2,) * 3):
29 | super().__init__(net, shape)
30 | self.obs_builder_func = obs_builder_func
31 |
32 | def act(self, state: GameState, player_index):
33 | obs = self.obs_builder_func(state)
34 | dist = self.get_action_distribution(obs)
35 | action_indices = self.sample_action(dist, deterministic=False)
36 | actions = self.env_compatible(action_indices)
37 |
38 | return actions
39 |
40 |
41 | class DemoDriveAgent(HardcodedAgent):
42 | def act(self, state: GameState, player_index: int):
43 | return [2, 1, 1, 0, 0, 0, 0, 0]
44 |
45 |
46 | class DemoKBMDriveAgent(HardcodedAgent):
47 | def act(self, state: GameState, player_index: int):
48 | return [2, 1, 0, 0, 0]
49 |
--------------------------------------------------------------------------------
/docs/troubleshooting.txt:
--------------------------------------------------------------------------------
1 | TROUBLESHOOTING COMMON PROBLEMS
2 |
3 |
4 | RuntimeError: mat1 and mat2 shapes cannot be multiplied (AxB and CxD)
5 |
6 | Compare the actor and critic input size to the observation size. They need to be identical.
7 | Remember that changing team size and selfplay can change the observation size.
8 |
9 |
10 | _______________________________________________________________________________________________________
11 | Blue Screen of Death error (related to wandb)
12 |
13 | 1) Rollback your Nvidia driver to WHQL certified September 2021
14 |
15 | OR 2) Comment out the following lines of code in the wandb repo:
16 |
17 | try:
18 | pynvml.nvmlInit()
19 | self.gpu_count = pynvml.nvmlDeviceGetCount()
20 | except pynvml.NVMLError:
21 |
22 |
23 | _______________________________________________________________________________________________________
24 | Can't get Redis to work
25 |
26 | -WSL2 is probably the easiest way to get things working.
27 | -Double check that you can ping the redis server locally
28 |
29 |
30 | _______________________________________________________________________________________________________
31 | There are no errors but changes I'm making don't seem to be affecting anything
32 |
33 | -Double check that observations, rewards, and action parsers are the same on both the learner
34 | and workers.
35 |
36 |
37 |
38 | _______________________________________________________________________________________________________
39 | wandb is not working properly or giving you hint errors in your IDE
40 |
41 | -check that there isn’t a folder created called wandb in your project
42 | -this is created automatically by wandb, you can change the name in the init call
43 | so it doesn’t interfere.
44 |
--------------------------------------------------------------------------------
/rocket_learn/utils/batched_obs_builder.py:
--------------------------------------------------------------------------------
1 | from typing import Any, Union, Optional
2 |
3 | import numpy as np
4 | from rlgym.utils import ObsBuilder
5 | from rlgym.utils.gamestates import PlayerData, GameState
6 |
7 | from rocket_learn.utils.gamestate_encoding import encode_gamestate
8 | from rocket_learn.utils.scoreboard import Scoreboard
9 |
10 |
11 | class BatchedObsBuilder(ObsBuilder):
12 | def __init__(self, scoreboard: Optional[Scoreboard] = None):
13 | super().__init__()
14 | self.current_state = None
15 | self.current_obs = None
16 | self.scoreboard = scoreboard
17 |
18 | def batched_build_obs(self, encoded_states: np.ndarray) -> Any:
19 | raise NotImplementedError
20 |
21 | def add_actions(self, obs: Any, previous_actions: np.ndarray, player_index=None):
22 | # Modify current obs to include action
23 | # player_index=None means actions for all players should be provided
24 | raise NotImplementedError
25 |
26 | def _reset(self, initial_state: GameState):
27 | raise NotImplementedError
28 |
29 | def reset(self, initial_state: GameState):
30 | self.current_state = False
31 | self.current_obs = None
32 | if self.scoreboard is not None:
33 | self.scoreboard.reset(initial_state)
34 | self._reset(initial_state)
35 |
36 | def pre_step(self, state: GameState):
37 | if state != self.current_state:
38 | if self.scoreboard is not None:
39 | self.scoreboard.step(state)
40 | self.current_obs = self.batched_build_obs(
41 | np.expand_dims(encode_gamestate(state), axis=0)
42 | )
43 | self.current_state = state
44 |
45 | def build_obs(self, player: PlayerData, state: GameState, previous_action: np.ndarray) -> Any:
46 | for i, p in enumerate(state.players):
47 | if p == player:
48 | self.add_actions(self.current_obs, previous_action, i)
49 | return self.current_obs[i]
50 |
--------------------------------------------------------------------------------
/rocket_learn/simple_agents.py:
--------------------------------------------------------------------------------
1 | from typing import Tuple
2 |
3 | import numpy as np
4 | import torch as th
5 | from torch.distributions import Categorical
6 |
7 | from rocket_learn.agent import BaseAgent
8 |
9 |
10 | class RandomAgent(BaseAgent):
11 | """Does softmax using biases alone"""
12 |
13 | def __init__(self, throttle=(0, 1, 2), steer=(0, 1, 0), pitch=(0, 0, 0), yaw=(0, 0, 0),
14 | roll=(0, 0, 0), jump=(3, 0), boost=(0, 0), handbrake=(3, 0)):
15 | super().__init__()
16 | self.distributions = [
17 | Categorical(logits=th.as_tensor(logits).float())
18 | for logits in (throttle, steer, pitch, yaw, roll, jump, boost, handbrake)
19 | ]
20 |
21 | def get_actions(self, observation, deterministic=False) -> np.ndarray:
22 | actions = np.stack([dist.sample() for dist in self.distributions])
23 | return actions
24 |
25 | def get_action_with_log_prob(self, observation) -> Tuple[np.ndarray, float]:
26 | pass
27 |
28 | def set_model_params(self, params) -> None:
29 | pass
30 |
31 | # def set_model_params(self, params):
32 | # self.distributions = [
33 | # Categorical(logits=th.as_tensor(logits).float())
34 | # for logits in params
35 | # ]
36 | #
37 | # def get_actions(self, observation, deterministic=False):
38 | # actions = np.stack([dist.sample() for dist in self.distributions])
39 | # return actions
40 | #
41 | # def get_log_prob(self, actions):
42 | # return th.stack(
43 | # [dist.log_prob(action) for dist, action in zip(self.distributions, th.unbind(actions, dim=1))], dim=1
44 | # ).sum(dim=1)
45 |
46 |
47 | # ** This should be in its own file or packaged with PPO **
48 |
49 |
50 | class NoOpAgent(BaseAgent):
51 | def get_actions(self, observation, deterministic=False):
52 | return th.zeros((8,))
53 |
54 | def get_log_prob(self, actions):
55 | return 0
56 |
57 | def set_model_params(self, params):
58 | pass
59 |
--------------------------------------------------------------------------------
/rocket_learn/experience_buffer.py:
--------------------------------------------------------------------------------
1 | class ExperienceBuffer:
2 | def __init__(self, observations=None, actions=None, rewards=None, dones=None, log_probs=None, infos=None):
3 | self.result = 0
4 | self.observations = []
5 | self.actions = []
6 | self.rewards = []
7 | self.dones = []
8 | self.log_probs = []
9 | self.infos = []
10 |
11 | if observations is not None:
12 | self.observations = observations
13 |
14 | if actions is not None:
15 | self.actions = actions
16 |
17 | if rewards is not None:
18 | self.rewards = rewards
19 |
20 | if dones is not None:
21 | self.dones = dones # TODO Done probably doesn't need to be a list, will always just be false until last?
22 |
23 | if log_probs is not None:
24 | self.log_probs = log_probs
25 |
26 | if infos is not None:
27 | self.infos = infos
28 |
29 | def size(self):
30 | return len(self.rewards)
31 |
32 | def add_step(self, observation, action, reward, done, log_prob, info):
33 | self.observations.append(observation)
34 | self.actions.append(action)
35 | self.rewards.append(reward)
36 | self.dones.append(done)
37 | self.log_probs.append(log_prob)
38 | self.infos.append(info)
39 |
40 | def clear(self):
41 | self.observations = []
42 | self.actions = []
43 | self.rewards = []
44 | self.dones = []
45 | self.log_probs = []
46 | self.infos = []
47 |
48 | def generate_slices(self, batch_size):
49 | for i in range(0, len(self.observations), batch_size):
50 | yield ExperienceBuffer(self.observations[i:i + batch_size],
51 | self.actions[i:i + batch_size],
52 | self.rewards[i:i + batch_size],
53 | self.dones[i:i + batch_size],
54 | self.log_probs[i:i + batch_size],
55 | self.infos[i:i + batch_size])
56 |
--------------------------------------------------------------------------------
/rocket_learn/utils/util.py:
--------------------------------------------------------------------------------
1 | from typing import Any
2 |
3 | import numpy as np
4 | import torch
5 | import torch.distributions
6 | from rlgym.utils.gamestates import GameState, PlayerData
7 | from rlgym.utils.obs_builders import AdvancedObs
8 | from torch import nn
9 |
10 |
11 | def softmax(x):
12 | """Compute softmax values for each sets of scores in x."""
13 | e_x = np.exp(x - np.max(x))
14 | return e_x / e_x.sum()
15 |
16 |
17 | class SplitLayer(nn.Module):
18 | def __init__(self, splits=None):
19 | super().__init__()
20 | if splits is not None:
21 | self.splits = splits
22 | else:
23 | self.splits = (3,) * 5 + (2,) * 3
24 |
25 | def forward(self, x):
26 | return torch.split(x, self.splits, dim=-1)
27 |
28 |
29 | # TODO AdvancedObs should be supported by default, use stack instead of cat
30 | class ExpandAdvancedObs(AdvancedObs):
31 | def build_obs(self, player: PlayerData, state: GameState, previous_action: np.ndarray) -> Any:
32 | return np.reshape(
33 | super(ExpandAdvancedObs, self).build_obs(player, state, previous_action),
34 | (1, -1)
35 | )
36 |
37 |
38 | def probability_NvsM(team1_ratings, team2_ratings, env=None):
39 | from trueskill import global_env
40 | # Trueskill extension, source: https://github.com/sublee/trueskill/pull/17
41 | """Calculates the win probability of the first team over the second team.
42 | :param team1_ratings: ratings of the first team participants.
43 | :param team2_ratings: ratings of another team participants.
44 | :param env: the :class:`TrueSkill` object. Defaults to the global
45 | environment.
46 | """
47 | if env is None:
48 | env = global_env()
49 |
50 | team1_mu = sum(r.mu for r in team1_ratings)
51 | team1_sigma = sum((env.beta ** 2 + r.sigma ** 2) for r in team1_ratings)
52 | team2_mu = sum(r.mu for r in team2_ratings)
53 | team2_sigma = sum((env.beta ** 2 + r.sigma ** 2) for r in team2_ratings)
54 |
55 | x = (team1_mu - team2_mu) / np.sqrt(team1_sigma + team2_sigma)
56 | probability_win_team1 = env.cdf(x)
57 | return probability_win_team1
58 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # rocket-learn
2 |
3 | ## What is rocket-learn?
4 |
5 | rocket-learn is a machine learning framework specifically designed for Rocket League Reinforcement Learning.
6 | It works in conjunction with Rocket League, RLGym, and Bakkesmod.
7 |
8 | ## What features does rocket-learn have?
9 |
10 |
11 | - Reinforcement learning algorithm available out of the box
12 |
13 | - Proximal Policy Optimization (PPO)
14 | - extensible format allows new algorithms to be added
15 |
16 | - Distributed compute from multiple computers
17 | - Automatic saving of and training against previous agent versions
18 | - Trueskill progress tracking
19 | - Training against Hardcoded/Pretrained Agents
20 | - Training against Humans
21 | - Saving and loading models
22 | - wandb logging
23 |
24 |
25 |
26 | ## Should I use rocket-learn?
27 |
28 | You should use Stable Baselines3 (SB3) to make your bot at first. The hardest parts of building a
29 | machine learning bot are
30 |
31 | - understanding how to program
32 | - understanding how machine learning works
33 | - choosing good hyperparameters
34 | - choosing good reward functions
35 | - choosing an action parser
36 | - making a statesetter that puts the bot in the best situations
37 |
38 | SB3 is a great way to figure out those essential parts. Once you have all of those aspects down, rocket-learn
39 | may be a good next step to a better machine learning bot.
40 |
41 | If you *don't* yet have these, rocket-learn will add a large amount of complexity for no added benefit. It's
42 | important to remember that high compute and a tough opponent are less important than good fundamentals of ML.
43 |
44 | ## How do I setup rocket-learn?
45 |
46 | 1) Get [Redis](https://docs.servicestack.net/install-redis-windows) running
47 |
48 | *__Improper Redis setup can leave your computer extremely vulnerable to Bad Guys.
49 | We are not responsible for your computer's safety. We assume you know what you are doing.__*
50 |
51 | 2) Clone the repo
52 |
53 | ```
54 | git clone https://github.com/Rolv-Arild/rocket-learn.git
55 | ```
56 |
57 | 3) Start up, in order:
58 |
59 | - the Redis server
60 | - the Learner
61 | - the Workers
62 |
63 | Look at the examples to get up and running
64 |
--------------------------------------------------------------------------------
/docs/network_setup_readme.txt:
--------------------------------------------------------------------------------
1 | NETWORK INPUT
2 |
3 | rocket-learn expects both actor and critic networks to have an input dimension equal to observation length.
4 | If the observation outputs an array of size (1, 150) (note the batch dimension of the observation output),
5 | then the network input should be 150. As an example:
6 |
7 | actor = DiscretePolicy(Sequential(
8 | Linear(150, 256),
9 | ReLU(),
10 | Linear(256, 256),
11 | ReLU(),
12 | Linear(256, total_output),
13 | SplitLayer(splits=split)
14 | ), split)
15 |
16 |
17 |
18 | __________________________________________________________________________________________________
19 | NETWORK OUTPUT
20 |
21 | rocket-learn expects actor networks to output a set of probablities for each possible action. For example,
22 | the default Discrete Action allows 8 actions, 5 of which are discrete control choices and 3
23 | of which are boolean choices. Because the Discrete control choices can each be -1, 0, or 1 and each
24 | boolean can be True or False, the network must output ((5 * 3) + (3 * 2)) aka 21 total actions. The actions
25 | must then be split into properly sized groups for each actions.
26 |
27 | split = (3, 3, 3, 3, 3, 2, 2, 2)
28 | total_output = sum(split)
29 |
30 | class SplitLayer(nn.Module):
31 | def __init__(self, splits=(3, 3, 3, 3, 3, 2, 2, 2)):
32 | super().__init__()
33 | self.splits = splits
34 |
35 | def forward(self, x):
36 | return torch.split(x, self.splits, dim=-1)
37 |
38 | actor = DiscretePolicy(nn.Sequential(
39 | nn.Linear(INPUT_SIZE, 256),
40 | nn.ReLU(),
41 | nn.Linear(256, total_output),
42 | SplitLayer(split)
43 | ), split)
44 |
45 | As another example, KBM actions allow 2 Discrete controls and 3 boolean controls so the network must
46 | output ((2 * 3) + (3 * 2)) aka 12 total actions
47 |
48 | split = (3, 3, 2, 2, 2)
49 | total_output = sum(split)
50 |
51 | class SplitLayer(nn.Module):
52 | def __init__(self, splits=(3, 3, 3, 3, 3, 2, 2, 2)):
53 | super().__init__()
54 | self.splits = splits
55 |
56 | def forward(self, x):
57 | return torch.split(x, self.splits, dim=-1)
58 |
59 | actor = DiscretePolicy(nn.Sequential(
60 | nn.Linear(INPUT_SIZE, 256),
61 | nn.ReLU(),
62 | nn.Linear(256, total_output),
63 | SplitLayer(split)
64 | ), split)
--------------------------------------------------------------------------------
/rocket_learn/agent/discrete_policy.py:
--------------------------------------------------------------------------------
1 | from typing import Optional, List, Tuple
2 |
3 | import numpy as np
4 | import torch as th
5 | from torch import nn
6 | from torch.distributions import Categorical
7 | import torch.nn.functional as F
8 |
9 | from rocket_learn.agent.policy import Policy
10 |
11 |
12 | class DiscretePolicy(Policy):
13 | def __init__(self, net: nn.Module, shape: Tuple[int, ...] = (3,) * 5 + (2,) * 3, deterministic=False):
14 | super().__init__(deterministic)
15 | self.net = net
16 | self.shape = shape
17 |
18 | def forward(self, obs):
19 | logits = self.net(obs)
20 | return logits
21 |
22 | def get_action_distribution(self, obs):
23 | if isinstance(obs, np.ndarray):
24 | obs = th.from_numpy(obs).float()
25 | elif isinstance(obs, tuple):
26 | obs = tuple(o if isinstance(o, th.Tensor) else th.from_numpy(o).float() for o in obs)
27 |
28 | logits = self(obs)
29 |
30 | if isinstance(logits, th.Tensor):
31 | logits = (logits,)
32 |
33 | max_shape = max(self.shape)
34 | logits = th.stack(
35 | [
36 | l
37 | if l.shape[-1] == max_shape
38 | else F.pad(l, pad=(0, max_shape - l.shape[-1]), value=float("-inf"))
39 | for l in logits
40 | ],
41 | dim=1
42 | )
43 |
44 | return Categorical(logits=logits)
45 |
46 | def sample_action(
47 | self,
48 | distribution: Categorical,
49 | deterministic=None
50 | ):
51 | if deterministic is None:
52 | deterministic = self.deterministic
53 | if deterministic:
54 | action_indices = th.argmax(distribution.logits, dim=-1)
55 | else:
56 | action_indices = distribution.sample()
57 |
58 | return action_indices
59 |
60 | def log_prob(self, distribution: Categorical, selected_action):
61 | log_prob = distribution.log_prob(selected_action).sum(dim=-1)
62 | return log_prob
63 |
64 | def entropy(self, distribution: Categorical, selected_action):
65 | entropy = distribution.entropy().sum(dim=-1)
66 | return entropy
67 |
68 | def env_compatible(self, action):
69 | if isinstance(action, th.Tensor):
70 | action = action.numpy()
71 | return action
72 |
--------------------------------------------------------------------------------
/rocket_learn/agent/pretrained_agents/necto/necto_v1.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import numpy as np
4 |
5 | import torch
6 | import torch.nn.functional as F
7 |
8 | from rocket_learn.agent.pretrained_policy import HardcodedAgent
9 | from pretrained_agents.necto.necto_v1_obs import NectoV1Obs
10 |
11 | from rlgym.utils.gamestates import GameState
12 |
13 | import copy
14 |
15 |
16 | class NectoV1(HardcodedAgent):
17 | def __init__(self, model_string, n_players):
18 | cur_dir = os.path.dirname(os.path.realpath(__file__))
19 | self.actor = torch.jit.load(os.path.join(cur_dir, model_string))
20 | self.obs_builder = NectoV1Obs(n_players=n_players)
21 | self.previous_action = np.array([0, 0, 0, 0, 0, 0, 0, 0])
22 |
23 | def act(self, state: GameState, player_index: int):
24 | player = state.players[player_index]
25 | teammates = [p for p in state.players if p.team_num == player.team_num and p != player]
26 | opponents = [p for p in state.players if p.team_num != player.team_num]
27 |
28 | necto_gamestate: GameState = copy.deepcopy(state)
29 | necto_gamestate.players = [player] + teammates + opponents
30 |
31 | self.obs_builder.reset(necto_gamestate)
32 | obs = self.obs_builder.build_obs(player, necto_gamestate, self.previous_action)
33 |
34 | obs = tuple(torch.from_numpy(s).float() for s in obs)
35 | with torch.no_grad():
36 | out, _ = self.actor(obs)
37 |
38 | max_shape = max(o.shape[-1] for o in out)
39 | logits = torch.stack(
40 | [
41 | l
42 | if l.shape[-1] == max_shape
43 | else F.pad(l, pad=(0, max_shape - l.shape[-1]), value=float("-inf"))
44 | for l in out
45 | ]
46 | ).swapdims(0, 1).squeeze()
47 |
48 | actions = np.argmax(logits, axis=-1)
49 |
50 | actions = actions.reshape((-1, 5))
51 | actions[:, 0] = actions[:, 0] - 1
52 | actions[:, 1] = actions[:, 1] - 1
53 |
54 | parsed = np.zeros((actions.shape[0], 8))
55 | parsed[:, 0] = actions[:, 0] # throttle
56 | parsed[:, 1] = actions[:, 1] # steer
57 | parsed[:, 2] = actions[:, 0] # pitch
58 | parsed[:, 3] = actions[:, 1] * (1 - actions[:, 4]) # yaw
59 | parsed[:, 4] = actions[:, 1] * actions[:, 4] # roll
60 | parsed[:, 5] = actions[:, 2] # jump
61 | parsed[:, 6] = actions[:, 3] # boost
62 | parsed[:, 7] = actions[:, 4] # handbrake
63 |
64 | self.previous_action = parsed[0]
65 | return parsed[0]
66 |
--------------------------------------------------------------------------------
/examples/human_trainer/worker_with_human_trainer.py:
--------------------------------------------------------------------------------
1 | from typing import Any
2 | import numpy
3 |
4 | from redis import Redis
5 |
6 | from rlgym.envs import Match
7 | from rlgym.utils.gamestates import PlayerData, GameState
8 | from rlgym.utils.terminal_conditions.common_conditions import GoalScoredCondition, TimeoutCondition
9 | from rlgym.utils.reward_functions.default_reward import DefaultReward
10 | from rlgym.utils.state_setters.default_state import DefaultState
11 | from rlgym.utils.obs_builders.advanced_obs import AdvancedObs
12 | from rlgym.utils.action_parsers.discrete_act import DiscreteAction
13 |
14 | from rocket_learn.rollout_generator.redis.redis_rollout_worker import RedisRolloutWorker
15 | from rocket_learn.agent.pretrained_agents.human_agent import HumanAgent
16 |
17 |
18 | # ROCKET-LEARN ALWAYS EXPECTS A BATCH DIMENSION IN THE BUILT OBSERVATION
19 | class ExpandAdvancedObs(AdvancedObs):
20 | def build_obs(self, player: PlayerData, state: GameState, previous_action: numpy.ndarray) -> Any:
21 | obs = super(ExpandAdvancedObs, self).build_obs(player, state, previous_action)
22 | return numpy.expand_dims(obs, 0)
23 |
24 | """
25 |
26 | Allows the worker to run a human player, letting the AI play against and learn from human interation.
27 |
28 | Important things to note:
29 |
30 | -The human will always be blue due to RLGym camera constraints
31 | -Attempting to run a human trainer and pretrained agents will cause the pretrained agents to be ignored.
32 | They will never show up.
33 |
34 | """
35 |
36 | if __name__ == "__main__":
37 | match = Match(
38 | game_speed=1,
39 | self_play=True,
40 | team_size=1,
41 | state_setter=DefaultState(),
42 | obs_builder=ExpandAdvancedObs(),
43 | action_parser=DiscreteAction(),
44 | terminal_conditions=[TimeoutCondition(round(2000)),
45 | GoalScoredCondition()],
46 | reward_function=DefaultReward()
47 | )
48 |
49 | # ALLOW HUMAN CONTROL THROUGH MOUSE AND KEYBOARD OR A CONTROLLER IF ONE IS PLUGGED IN
50 | # -CONTROL BINDINGS ARE CURRENTLY NOT CHANGEABLE
51 | # -CONTROLLER SETUP CURRENTLY EXPECTS AN XBOX 360 CONTROLLER. OTHERS WILL WORK BUT PROBABLY NOT WELL
52 | human = HumanAgent()
53 |
54 | r = Redis(host="127.0.0.1", password="you_better_use_a_password")
55 |
56 | # LAUNCH ROCKET LEAGUE AND BEGIN TRAINING
57 | # -human_agent TELLS RLGYM THAT THE FIRST AGENT IS ALWAYS TO BE HUMAN CONTROLLED
58 | # -past_version_prob SPECIFIES HOW OFTEN OLD VERSIONS WILL BE RANDOMLY SELECTED AND TRAINED AGAINST
59 | RedisRolloutWorker(r, "exampleHumanWorker", match, human_agent=human, past_version_prob=.05).run()
60 |
--------------------------------------------------------------------------------
/rocket_learn/agent/pretrained_agents/human_agent.py:
--------------------------------------------------------------------------------
1 | import pygame
2 | import keyboard
3 |
4 | from rlgym.utils.gamestates import GameState
5 |
6 | from rocket_learn.agent.pretrained_policy import HardcodedAgent
7 |
8 | class HumanAgent(HardcodedAgent):
9 | def __init__(self):
10 | pygame.init()
11 | self.controller_map = {}
12 |
13 | self.joystick = None
14 | if pygame.joystick.get_count() > 0:
15 | self.joystick = pygame.joystick.Joystick(0)
16 | self.joystick.init()
17 | print("Controller found")
18 |
19 | def controller_actions(self, state):
20 | player = [p for p in state.players if p.team_num == 0][0]
21 | # allow controller to activate
22 | pygame.event.pump()
23 |
24 | jump = self.joystick.get_button(0)
25 | boost = self.joystick.get_button(1)
26 | handbrake = self.joystick.get_button(2)
27 |
28 | throttle = self.joystick.get_axis(5)
29 | throttle = max(0, throttle)
30 |
31 | reverse_throttle = self.joystick.get_axis(4)
32 | reverse_throttle = max(0, reverse_throttle)
33 |
34 | throttle = throttle - reverse_throttle
35 |
36 | steer = self.joystick.get_axis(0)
37 | if abs(steer) < .2:
38 | steer = 0
39 |
40 | pitch = self.joystick.get_axis(1)
41 | if abs(pitch) < .2:
42 | pitch = 0
43 |
44 | yaw = steer
45 |
46 | roll = 0
47 | roll_button = self.joystick.get_button(4)
48 | if roll_button or jump:
49 | roll = steer
50 |
51 | return [throttle, steer, pitch, yaw, roll, jump, boost, handbrake]
52 |
53 |
54 | def kbm_actions(self, state):
55 | player = [p for p in state.players if p.team_num == 0][0]
56 |
57 | throttle = 0
58 | if keyboard.is_pressed('w'):
59 | throttle = 1
60 | if keyboard.is_pressed('s'):
61 | throttle = -1
62 |
63 | steer = 0
64 | if keyboard.is_pressed('d'):
65 | steer = 1
66 | if keyboard.is_pressed('a'):
67 | steer = -1
68 |
69 | pitch = -throttle
70 |
71 | yaw = steer
72 |
73 | roll = 0
74 | if keyboard.is_pressed('e'):
75 | roll = 1
76 | if keyboard.is_pressed('q'):
77 | roll = -1
78 |
79 | jump = 0
80 | if keyboard.is_pressed('f'):
81 | jump = 1
82 |
83 | boost = 0
84 | handbrake = 0
85 |
86 | return [throttle, steer, pitch, yaw, roll, jump, boost, handbrake]
87 |
88 | def act(self, state: GameState, player_index: int):
89 | if self.joystick:
90 | actions = self.controller_actions(state)
91 | else:
92 | actions = self.kbm_actions(state)
93 |
94 | return actions
95 |
--------------------------------------------------------------------------------
/examples/default/worker.py:
--------------------------------------------------------------------------------
1 | from typing import Any
2 | import numpy
3 |
4 | from redis import Redis
5 |
6 | from rlgym.envs import Match
7 | from rlgym.utils.gamestates import PlayerData, GameState
8 | from rlgym.utils.terminal_conditions.common_conditions import GoalScoredCondition, TimeoutCondition
9 | from rlgym.utils.reward_functions.default_reward import DefaultReward
10 | from rlgym.utils.state_setters.default_state import DefaultState
11 | from rlgym.utils.obs_builders.advanced_obs import AdvancedObs
12 | from rlgym.utils.action_parsers.discrete_act import DiscreteAction
13 |
14 | from rocket_learn.rollout_generator.redis.redis_rollout_worker import RedisRolloutWorker
15 |
16 |
17 | # ROCKET-LEARN ALWAYS EXPECTS A BATCH DIMENSION IN THE BUILT OBSERVATION
18 | class ExpandAdvancedObs(AdvancedObs):
19 | def build_obs(self, player: PlayerData, state: GameState, previous_action: numpy.ndarray) -> Any:
20 | obs = super(ExpandAdvancedObs, self).build_obs(player, state, previous_action)
21 | return numpy.expand_dims(obs, 0)
22 |
23 |
24 | if __name__ == "__main__":
25 | """
26 |
27 | Starts up a rocket-learn worker process, which plays out a game, sends back game data to the
28 | learner, and receives updated model parameters when available
29 |
30 | """
31 |
32 | # OPTIONAL ADDITION:
33 | # LIMIT TORCH THREADS TO 1 ON THE WORKERS TO LIMIT TOTAL RESOURCE USAGE
34 | # TRY WITH AND WITHOUT FOR YOUR SPECIFIC HARDWARE
35 | import torch
36 |
37 | torch.set_num_threads(1)
38 |
39 | # BUILD THE ROCKET LEAGUE MATCH THAT WILL USED FOR TRAINING
40 | # -ENSURE OBSERVATION, REWARD, AND ACTION CHOICES ARE THE SAME IN THE WORKER
41 | match = Match(
42 | game_speed=100,
43 | spawn_opponents=True,
44 | team_size=1,
45 | state_setter=DefaultState(),
46 | obs_builder=ExpandAdvancedObs(),
47 | action_parser=DiscreteAction(),
48 | terminal_conditions=[TimeoutCondition(round(2000)),
49 | GoalScoredCondition()],
50 | reward_function=DefaultReward()
51 | )
52 |
53 | # LINK TO THE REDIS SERVER YOU SHOULD HAVE RUNNING (USE THE SAME PASSWORD YOU SET IN THE REDIS
54 | # CONFIG)
55 | r = Redis(host="127.0.0.1", password="you_better_use_a_password")
56 |
57 | # LAUNCH ROCKET LEAGUE AND BEGIN TRAINING
58 | # -past_version_prob SPECIFIES HOW OFTEN OLD VERSIONS WILL BE RANDOMLY SELECTED AND TRAINED AGAINST
59 | RedisRolloutWorker(r, "example", match,
60 | past_version_prob=.2,
61 | evaluation_prob=0.01,
62 | sigma_target=2,
63 | dynamic_gm=False,
64 | send_obs=True,
65 | streamer_mode=False,
66 | send_gamestates=False,
67 | force_paging=False,
68 | auto_minimize=True,
69 | local_cache_name="example_model_database").run()
70 |
--------------------------------------------------------------------------------
/rocket_learn/agent.py:
--------------------------------------------------------------------------------
1 | from abc import abstractmethod, ABC
2 | from typing import List, Optional
3 |
4 | import numpy as np
5 | import torch as th
6 | from torch.distributions import Categorical
7 |
8 |
9 | # class BaseAgent(ABC):
10 | # def __init__(self, index_action_map: Optional[np.ndarray] = None):
11 | # if index_action_map is None:
12 | # self.index_action_map = np.array([
13 | # [-1., 0., 1.],
14 | # [-1., 0., 1.],
15 | # [-1., 0., 1.],
16 | # [-1., 0., 1.],
17 | # [-1., 0., 1.],
18 | # [0., 1., np.nan],
19 | # [0., 1., np.nan],
20 | # [0., 1., np.nan]
21 | # ])
22 | # else:
23 | # self.index_action_map = index_action_map
24 | #
25 | # def forward_actor_critic(self, obs):
26 | # raise NotImplementedError
27 | #
28 | # def forward_actor(self, obs):
29 | # return self.forward_actor_critic(obs)[0]
30 | #
31 | # def forward_critic(self, obs):
32 | # return self.forward_actor_critic(obs)[1]
33 | #
34 | # def get_action_distribution(self, obs) -> List[Categorical]:
35 | # if isinstance(obs, np.ndarray):
36 | # obs = th.from_numpy(obs).float()
37 | # elif isinstance(obs, tuple):
38 | # obs = tuple(o if isinstance(o, th.Tensor) else th.from_numpy(o).float() for o in obs)
39 | # logits = self.forward_actor(obs)
40 | #
41 | # return [Categorical(logits=logit) for logit in logits]
42 | #
43 | # def get_action_indices(self, distribution: List[Categorical], deterministic=False, include_log_prob=False,
44 | # include_entropy=False):
45 | # if deterministic:
46 | # action_indices = th.stack([th.argmax(dist.logits) for dist in distribution])
47 | # else:
48 | # action_indices = th.stack([dist.sample() for dist in distribution])
49 | #
50 | # returns = [action_indices.numpy()]
51 | # if include_log_prob:
52 | # # SOREN NOTE:
53 | # # adding dim=1 is causing it to crash
54 | #
55 | # log_prob = th.stack(
56 | # [dist.log_prob(action) for dist, action in zip(distribution, th.unbind(action_indices, dim=-1))], dim=-1
57 | # ).sum(dim=-1)
58 | # returns.append(log_prob.numpy())
59 | # if include_entropy:
60 | # entropy = th.stack([dist.entropy() for dist in distribution], dim=1).sum(dim=1)
61 | # returns.append(entropy.numpy())
62 | # return tuple(returns)
63 | #
64 | # def get_action(self, action_indices) -> np.ndarray:
65 | # return self.index_action_map[np.arange(len(self.index_action_map)), action_indices]
66 | #
67 | # @abstractmethod
68 | # def get_model_params(self):
69 | # raise NotImplementedError
70 | #
71 | # @abstractmethod
72 | # def set_model_params(self, params) -> None:
73 | # raise NotImplementedError
74 |
--------------------------------------------------------------------------------
/examples/pretrained_agent/worker_with_pretrained_agent.py:
--------------------------------------------------------------------------------
1 | from typing import Any
2 | import numpy
3 |
4 | from redis import Redis
5 |
6 | from rlgym.envs import Match
7 | from rlgym.utils.gamestates import PlayerData, GameState
8 | from rlgym.utils.terminal_conditions.common_conditions import GoalScoredCondition, TimeoutCondition
9 | from rlgym.utils.reward_functions.default_reward import DefaultReward
10 | from rlgym.utils.state_setters.default_state import DefaultState
11 | from rlgym.utils.obs_builders.advanced_obs import AdvancedObs
12 | from rlgym.utils.action_parsers.discrete_act import DiscreteAction
13 |
14 | from rocket_learn.rollout_generator.redis.redis_rollout_worker import RedisRolloutWorker
15 | from rocket_learn.agent.pretrained_policy import HardcodedAgent
16 |
17 | from rocket_learn.agent.pretrained_agents.necto.necto_v1 import NectoV1
18 |
19 | # ROCKET-LEARN ALWAYS EXPECTS A BATCH DIMENSION IN THE BUILT OBSERVATION
20 | class ExpandAdvancedObs(AdvancedObs):
21 | def build_obs(self, player: PlayerData, state: GameState, previous_action: numpy.ndarray) -> Any:
22 | obs = super(ExpandAdvancedObs, self).build_obs(player, state, previous_action)
23 | return numpy.expand_dims(obs, 0)
24 |
25 | class TotallyDifferentObs(AdvancedObs):
26 | def build_obs(self, player: PlayerData, state: GameState, previous_action: numpy.ndarray) -> Any:
27 | obs = super(ExpandAdvancedObs, self).build_obs(player, state, previous_action)
28 | return numpy.expand_dims(obs, 0)
29 |
30 |
31 | # NEW HARDCODED AGENTS MUST IMPLEMENT THE act() METHOD, WHICH RETURNS AN ARRAY OF ACTIONS USED TO
32 | # CONTROL THE AGENT
33 | class DemoCustomAgent(HardcodedAgent):
34 | def act(self, state: GameState, player_index: int):
35 | player_data = state.players[player_index]
36 |
37 | return [2, 1, 1, 0, 0, 0, 0, 0]
38 |
39 |
40 |
41 |
42 | if __name__ == "__main__":
43 | """
44 |
45 | Allows the worker to add in already built agents into training
46 |
47 | Important things to note:
48 |
49 | -RLGym only accepts 1 action parser. All agents will need to have the same parser or use a combination parser
50 |
51 | """
52 | team_size = 1
53 |
54 | # ENSURE OBSERVATION, REWARD, AND ACTION CHOICES ARE THE SAME IN THE LEARNER
55 | match = Match(
56 | game_speed=100,
57 | self_play=True,
58 | team_size=team_size,
59 | state_setter=DefaultState(),
60 | obs_builder=ExpandAdvancedObs(),
61 | action_parser=DiscreteAction(),
62 | terminal_conditions=[TimeoutCondition(round(2000)),
63 | GoalScoredCondition()],
64 | reward_function=DefaultReward()
65 | )
66 |
67 |
68 | # AT THE MOMENT, THE AVAILABLE NECTO MODELS ARE:
69 | # necto-model-10Y.pt
70 | # necto-model-20Y.pt
71 | # necto-model-30Y.pt
72 | model_name = "necto-model-30Y.pt"
73 | nectov1 = NectoV1(model_string=model_name, n_players=team_size*2)
74 |
75 | demo_hardcoded_agent = DemoCustomAgent()
76 |
77 | #EACH AGENT AND THEIR PROBABILITY OF OCCURRENCE
78 | pretrained_agents = {demo_hardcoded_agent: .05, nectov1: .05}
79 |
80 | r = Redis(host="127.0.0.1", password="you_better_use_a_password")
81 |
82 |
83 | RedisRolloutWorker(r, "examplePretrainedWorker", match, pretrained_agents=pretrained_agents, past_version_prob=.05).run()
84 |
--------------------------------------------------------------------------------
/rocket_learn/agent/pretrained_agents/nexto/nexto_v2.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import numpy as np
4 |
5 | import torch
6 | import torch.nn.functional as F
7 |
8 | from rocket_learn.agent.pretrained_policy import HardcodedAgent
9 | from pretrained_agents.nexto.nexto_v2_obs import Nexto_V2_ObsBuilder
10 |
11 | from rlgym.utils.gamestates import GameState
12 |
13 | import copy
14 |
15 | class NextoV2(HardcodedAgent):
16 | def __init__(self, model_string, n_players):
17 | cur_dir = os.path.dirname(os.path.realpath(__file__))
18 | self.actor = torch.jit.load(os.path.join(cur_dir, model_string))
19 | self.obs_builder = Nexto_V2_ObsBuilder(n_players=n_players)
20 | self.previous_action = np.array([0, 0, 0, 0, 0, 0, 0, 0])
21 |
22 | self._lookup_table = self.make_lookup_table()
23 |
24 | @staticmethod
25 | def make_lookup_table():
26 | actions = []
27 | # Ground
28 | for throttle in (-1, 0, 1):
29 | for steer in (-1, 0, 1):
30 | for boost in (0, 1):
31 | for handbrake in (0, 1):
32 | if boost == 1 and throttle != 1:
33 | continue
34 | actions.append([throttle or boost, steer, 0, steer, 0, 0, boost, handbrake])
35 | # Aerial
36 | for pitch in (-1, 0, 1):
37 | for yaw in (-1, 0, 1):
38 | for roll in (-1, 0, 1):
39 | for jump in (0, 1):
40 | for boost in (0, 1):
41 | if jump == 1 and yaw != 0: # Only need roll for sideflip
42 | continue
43 | if pitch == roll == jump == 0: # Duplicate with ground
44 | continue
45 | # Enable handbrake for potential wavedashes
46 | handbrake = jump == 1 and (pitch != 0 or yaw != 0 or roll != 0)
47 | actions.append([boost, yaw, pitch, yaw, roll, jump, boost, handbrake])
48 | actions = np.array(actions)
49 | return actions
50 |
51 | def act(self, state: GameState, player_index: int):
52 | player = state.players[player_index]
53 | teammates = [p for p in state.players if p.team_num == player.team_num and p != player]
54 | opponents = [p for p in state.players if p.team_num != player.team_num]
55 | necto_gamestate: GameState = copy.deepcopy(state)
56 | necto_gamestate.players = [player] + teammates + opponents
57 |
58 | self.obs_builder.reset(necto_gamestate)
59 | obs = self.obs_builder.build_obs(player, necto_gamestate, self.previous_action)
60 |
61 | obs = tuple(torch.from_numpy(s).float() for s in obs)
62 | with torch.no_grad():
63 | out, _ = self.actor(obs)
64 |
65 | out = (out,)
66 | max_shape = max(o.shape[-1] for o in out)
67 | logits = torch.stack(
68 | [
69 | l
70 | if l.shape[-1] == max_shape
71 | else F.pad(l, pad=(0, max_shape - l.shape[-1]), value=float("-inf"))
72 | for l in out
73 | ],
74 | dim=1
75 | )
76 |
77 | actions = np.argmax(logits, axis=-1)
78 | parsed = self._lookup_table[actions.numpy().item()]
79 |
80 | self.previous_action = parsed
81 | return parsed
82 |
--------------------------------------------------------------------------------
/examples/loading/learner.py:
--------------------------------------------------------------------------------
1 | import os
2 | import wandb
3 | import numpy
4 | from typing import Any
5 |
6 | import torch.jit
7 | from torch.nn import Linear, Sequential
8 |
9 | from redis import Redis
10 |
11 | from rlgym.utils.obs_builders.advanced_obs import AdvancedObs
12 | from rlgym.utils.gamestates import PlayerData, GameState
13 | from rlgym.utils.reward_functions.default_reward import DefaultReward
14 | from rlgym.utils.action_parsers.discrete_act import DiscreteAction
15 |
16 | from rocket_learn.agent.actor_critic_agent import ActorCriticAgent
17 | from rocket_learn.agent.discrete_policy import DiscretePolicy
18 | from rocket_learn.ppo import PPO
19 | from rocket_learn.rollout_generator.redis.redis_rollout_generator import RedisRolloutGenerator
20 | from rocket_learn.utils.util import SplitLayer
21 |
22 |
23 | # rocket-learn always expects a batch dimension in the built observation
24 | class ExpandAdvancedObs(AdvancedObs):
25 | def build_obs(self, player: PlayerData, state: GameState, previous_action: numpy.ndarray) -> Any:
26 | obs = super(ExpandAdvancedObs, self).build_obs(player, state, previous_action)
27 | return numpy.expand_dims(obs, 0)
28 |
29 |
30 | if __name__ == "__main__":
31 | wandb.login(key=os.environ["WANDB_KEY"])
32 | logger = wandb.init(project="demo", entity="wandb_username")
33 | logger.name = "LOADING_RUN_EXAMPLE"
34 |
35 | redis = Redis(password="you_better_use_a_password")
36 |
37 |
38 | def obs():
39 | return ExpandAdvancedObs()
40 |
41 | def rew():
42 | return DefaultReward()
43 |
44 | def act():
45 | return DiscreteAction()
46 |
47 |
48 | # -clear DELETE REDIS ENTRIES WHEN STARTING UP (SET TO FALSE TO CONTINUE WITH OLD AGENTS)
49 | rollout_gen = RedisRolloutGenerator(redis, obs, rew, act,
50 | logger=logger,
51 | save_every=100,
52 | clear=False)
53 |
54 | critic = Sequential(Linear(107, 128), Linear(128, 64), Linear(64, 32), Linear(32, 1))
55 | actor = DiscretePolicy(
56 | Sequential(Linear(107, 128), Linear(128, 64), Linear(64, 32), Linear(32, 21), SplitLayer()))
57 |
58 | optim = torch.optim.Adam([
59 | {"params": actor.parameters(), "lr": 5e-5},
60 | {"params": critic.parameters(), "lr": 5e-5}
61 | ])
62 |
63 | agent = ActorCriticAgent(actor=actor, critic=critic, optimizer=optim)
64 |
65 | alg = PPO(
66 | rollout_gen,
67 | agent,
68 | ent_coef=0.01,
69 | n_steps=1_000_000,
70 | batch_size=20_000,
71 | minibatch_size=10_000,
72 | epochs=10,
73 | gamma=599 / 600,
74 | logger=logger,
75 | )
76 |
77 |
78 | # LOAD A CHECKPOINT THAT WAS PREVIOUSLY SAVED AND CONTINUE TRAINING. OPTIONAL PARAMETER ALLOWS YOU
79 | # TO RESTART THE STEP COUNT INSTEAD OF CONTINUING
80 | alg.load("path\\from\\below\\checkpoint.pt")
81 |
82 |
83 | # OPTIONAL: FOR A PRETRAINED NETWORK, FREEZE THE POLICY NETWORK TO ALLOW THE CRITIC TO SETTLE
84 | # commented out here to keep you from accidentally adding it via copy/paste
85 | # alg.freeze_policy(frozen_iterations=200)
86 |
87 |
88 | # BEGIN TRAINING. IT WILL CONTINUE UNTIL MANUALLY STOPPED
89 | # -iterations_per_save SPECIFIES HOW OFTEN CHECKPOINTS ARE SAVED
90 | # -save_dir SPECIFIES WHERE
91 | alg.run(iterations_per_save=100, save_dir="checkpoint_save_directory")
92 |
--------------------------------------------------------------------------------
/rocket_learn/utils/gamestate_encoding.py:
--------------------------------------------------------------------------------
1 | from enum import Enum
2 |
3 | from rlgym.utils.gamestates import GameState
4 |
5 |
6 | def encode_gamestate(state: GameState):
7 | state_vals = [0, state.blue_score, state.orange_score]
8 | state_vals += state.boost_pads.tolist()
9 |
10 | for bd in (state.ball, state.inverted_ball):
11 | state_vals += bd.position.tolist()
12 | state_vals += bd.linear_velocity.tolist()
13 | state_vals += bd.angular_velocity.tolist()
14 |
15 | for p in state.players:
16 | state_vals += [p.car_id, p.team_num]
17 | for cd in (p.car_data, p.inverted_car_data):
18 | state_vals += cd.position.tolist()
19 | state_vals += cd.quaternion.tolist()
20 | state_vals += cd.linear_velocity.tolist()
21 | state_vals += cd.angular_velocity.tolist()
22 | state_vals += [
23 | p.match_goals,
24 | p.match_saves,
25 | p.match_shots,
26 | p.match_demolishes,
27 | p.boost_pickups,
28 | p.is_demoed,
29 | p.on_ground,
30 | p.ball_touched,
31 | p.has_jump,
32 | p.has_flip,
33 | p.boost_amount
34 | ]
35 | return state_vals
36 |
37 |
38 | # Now some constants for easy and consistent querying of gamestate values
39 | class StateConstants:
40 | DUMMY = 0
41 | BLUE_SCORE = 1
42 | ORANGE_SCORE = 2
43 | BOOST_PADS = slice(3, 3 + GameState.BOOST_PADS_LENGTH)
44 | BALL_POSITION = slice(BOOST_PADS.stop, BOOST_PADS.stop + 3)
45 | BALL_LINEAR_VELOCITY = slice(BALL_POSITION.stop, BALL_POSITION.stop + 3)
46 | BALL_ANGULAR_VELOCITY = slice(BALL_LINEAR_VELOCITY.stop, BALL_LINEAR_VELOCITY.stop + 3)
47 |
48 | PLAYERS = slice(BALL_ANGULAR_VELOCITY.stop + 9, None) # Skip inverted data
49 | CAR_IDS = slice(0, None, GameState.PLAYER_INFO_LENGTH)
50 | TEAM_NUMS = slice(CAR_IDS.start + 1, None, GameState.PLAYER_INFO_LENGTH)
51 | CAR_POS_X = slice(TEAM_NUMS.start + 1, None, GameState.PLAYER_INFO_LENGTH)
52 | CAR_POS_Y = slice(CAR_POS_X.start + 1, None, GameState.PLAYER_INFO_LENGTH)
53 | CAR_POS_Z = slice(CAR_POS_Y.start + 1, None, GameState.PLAYER_INFO_LENGTH)
54 | CAR_QUAT_W = slice(CAR_POS_Z.start + 1, None, GameState.PLAYER_INFO_LENGTH)
55 | CAR_QUAT_X = slice(CAR_QUAT_W.start + 1, None, GameState.PLAYER_INFO_LENGTH)
56 | CAR_QUAT_Y = slice(CAR_QUAT_X.start + 1, None, GameState.PLAYER_INFO_LENGTH)
57 | CAR_QUAT_Z = slice(CAR_QUAT_Y.start + 1, None, GameState.PLAYER_INFO_LENGTH)
58 | CAR_LINEAR_VEL_X = slice(CAR_QUAT_Z.start + 1, None, GameState.PLAYER_INFO_LENGTH)
59 | CAR_LINEAR_VEL_Y = slice(CAR_LINEAR_VEL_X.start + 1, None, GameState.PLAYER_INFO_LENGTH)
60 | CAR_LINEAR_VEL_Z = slice(CAR_LINEAR_VEL_Y.start + 1, None, GameState.PLAYER_INFO_LENGTH)
61 | CAR_ANGULAR_VEL_X = slice(CAR_LINEAR_VEL_Z.start + 1, None, GameState.PLAYER_INFO_LENGTH)
62 | CAR_ANGULAR_VEL_Y = slice(CAR_ANGULAR_VEL_X.start + 1, None, GameState.PLAYER_INFO_LENGTH)
63 | CAR_ANGULAR_VEL_Z = slice(CAR_ANGULAR_VEL_Y.start + 1, None, GameState.PLAYER_INFO_LENGTH)
64 | MATCH_GOALS = slice(CAR_ANGULAR_VEL_Z.start + 1 + 13, None, GameState.PLAYER_INFO_LENGTH) # Skip inverted data
65 | MATCH_SAVES = slice(MATCH_GOALS.start + 1, None, GameState.PLAYER_INFO_LENGTH)
66 | MATCH_SHOTS = slice(MATCH_SAVES.start + 1, None, GameState.PLAYER_INFO_LENGTH)
67 | MATCH_DEMOLISHES = slice(MATCH_SHOTS.start + 1, None, GameState.PLAYER_INFO_LENGTH)
68 | BOOST_PICKUPS = slice(MATCH_DEMOLISHES.start + 1, None, GameState.PLAYER_INFO_LENGTH)
69 | IS_DEMOED = slice(BOOST_PICKUPS.start + 1, None, GameState.PLAYER_INFO_LENGTH)
70 | ON_GROUND = slice(IS_DEMOED.start + 1, None, GameState.PLAYER_INFO_LENGTH)
71 | BALL_TOUCHED = slice(ON_GROUND.start + 1, None, GameState.PLAYER_INFO_LENGTH)
72 | HAS_JUMP = slice(BALL_TOUCHED.start + 1, None, GameState.PLAYER_INFO_LENGTH)
73 | HAS_FLIP = slice(HAS_JUMP.start + 1, None, GameState.PLAYER_INFO_LENGTH)
74 | BOOST_AMOUNT = slice(HAS_FLIP.start + 1, None, GameState.PLAYER_INFO_LENGTH)
75 |
--------------------------------------------------------------------------------
/examples/default/learner.py:
--------------------------------------------------------------------------------
1 | import os
2 | import wandb
3 | import numpy
4 | from typing import Any
5 |
6 | import torch.jit
7 | from torch.nn import Linear, Sequential, ReLU
8 |
9 | from redis import Redis
10 |
11 | from rlgym.utils.obs_builders.advanced_obs import AdvancedObs
12 | from rlgym.utils.gamestates import PlayerData, GameState
13 | from rlgym.utils.reward_functions.default_reward import DefaultReward
14 | from rlgym.utils.action_parsers.discrete_act import DiscreteAction
15 |
16 | from rocket_learn.agent.actor_critic_agent import ActorCriticAgent
17 | from rocket_learn.agent.discrete_policy import DiscretePolicy
18 | from rocket_learn.ppo import PPO
19 | from rocket_learn.rollout_generator.redis.redis_rollout_generator import RedisRolloutGenerator
20 | from rocket_learn.utils.util import SplitLayer
21 |
22 |
23 | # ROCKET-LEARN ALWAYS EXPECTS A BATCH DIMENSION IN THE BUILT OBSERVATION
24 | class ExpandAdvancedObs(AdvancedObs):
25 | def build_obs(self, player: PlayerData, state: GameState, previous_action: numpy.ndarray) -> Any:
26 | obs = super(ExpandAdvancedObs, self).build_obs(player, state, previous_action)
27 | return numpy.expand_dims(obs, 0)
28 |
29 |
30 | if __name__ == "__main__":
31 | """
32 |
33 | Starts up a rocket-learn learner process, which ingests incoming data, updates parameters
34 | based on results, and sends updated model parameters out to the workers
35 |
36 | """
37 |
38 | # ROCKET-LEARN USES WANDB WHICH REQUIRES A LOGIN TO USE. YOU CAN SET AN ENVIRONMENTAL VARIABLE
39 | # OR HARDCODE IT IF YOU ARE NOT SHARING YOUR SOURCE FILES
40 | wandb.login(key=os.environ["WANDB_KEY"])
41 | logger = wandb.init(project="demo", entity="wandb_username")
42 | logger.name = "DEFAULT_LEARNER_EXAMPLE"
43 |
44 | # LINK TO THE REDIS SERVER YOU SHOULD HAVE RUNNING (USE THE SAME PASSWORD YOU SET IN THE REDIS
45 | # CONFIG)
46 | redis = Redis(password="you_better_use_a_password")
47 |
48 | # ** ENSURE OBSERVATION, REWARD, AND ACTION CHOICES ARE THE SAME IN THE WORKER **
49 | def obs():
50 | return ExpandAdvancedObs()
51 |
52 | def rew():
53 | return DefaultReward()
54 |
55 | def act():
56 | return DiscreteAction()
57 |
58 |
59 | # THE ROLLOUT GENERATOR CAPTURES INCOMING DATA THROUGH REDIS AND PASSES IT TO THE LEARNER.
60 | # -save_every SPECIFIES HOW OFTEN REDIS DATABASE IS BACKED UP TO DISK
61 | # -model_every SPECIFIES HOW OFTEN OLD VERSIONS ARE SAVED TO REDIS. THESE ARE USED FOR TRUESKILL
62 | # COMPARISON AND TRAINING AGAINST PREVIOUS VERSIONS
63 | # -clear DELETE REDIS ENTRIES WHEN STARTING UP (SET TO FALSE TO CONTINUE WITH OLD AGENTS)
64 | rollout_gen = RedisRolloutGenerator("demo-bot", redis, obs, rew, act,
65 | logger=logger,
66 | save_every=100,
67 | model_every=100,
68 | clear=False)
69 |
70 | # ROCKET-LEARN EXPECTS A SET OF DISTRIBUTIONS FOR EACH ACTION FROM THE NETWORK, NOT
71 | # THE ACTIONS THEMSELVES. SEE network_setup.readme.txt FOR MORE INFORMATION
72 | split = (3, 3, 3, 3, 3, 2, 2, 2)
73 | total_output = sum(split)
74 |
75 | # TOTAL SIZE OF THE INPUT DATA
76 | state_dim = 107
77 |
78 | critic = Sequential(
79 | Linear(state_dim, 256),
80 | ReLU(),
81 | Linear(256, 256),
82 | ReLU(),
83 | Linear(256, 1)
84 | )
85 |
86 | actor = DiscretePolicy(Sequential(
87 | Linear(state_dim, 256),
88 | ReLU(),
89 | Linear(256, 256),
90 | ReLU(),
91 | Linear(256, total_output),
92 | SplitLayer(splits=split)
93 | ), split)
94 |
95 | # CREATE THE OPTIMIZER
96 | optim = torch.optim.Adam([
97 | {"params": actor.parameters(), "lr": 5e-5},
98 | {"params": critic.parameters(), "lr": 5e-5}
99 | ])
100 |
101 | # PPO REQUIRES AN ACTOR/CRITIC AGENT
102 | agent = ActorCriticAgent(actor=actor, critic=critic, optimizer=optim)
103 |
104 | # INSTANTIATE THE PPO TRAINING ALGORITHM
105 | alg = PPO(
106 | rollout_gen,
107 | agent,
108 | ent_coef=0.01,
109 | n_steps=1_000_000,
110 | batch_size=20_000,
111 | minibatch_size=10_000,
112 | epochs=10,
113 | gamma=599 / 600,
114 | clip_range=0.2,
115 | gae_lambda=0.95,
116 | vf_coef=1,
117 | max_grad_norm=0.5,
118 | logger=logger,
119 | device="cuda",
120 | )
121 |
122 | # BEGIN TRAINING. IT WILL CONTINUE UNTIL MANUALLY STOPPED
123 | # -iterations_per_save SPECIFIES HOW OFTEN CHECKPOINTS ARE SAVED
124 | # -save_dir SPECIFIES WHERE
125 | alg.run(iterations_per_save=100, save_dir="checkpoint_save_directory")
126 |
--------------------------------------------------------------------------------
/rocket_learn/agent/pretrained_agents/necto/necto_v1_obs.py:
--------------------------------------------------------------------------------
1 | from typing import Any
2 |
3 | import numpy as np
4 |
5 | from rlgym.utils import ObsBuilder
6 | from rlgym.utils.common_values import BOOST_LOCATIONS, BLUE_TEAM, ORANGE_TEAM
7 | from rlgym.utils.gamestates import GameState, PlayerData
8 |
9 |
10 | class NectoV1Obs(ObsBuilder):
11 | _boost_locations = np.array(BOOST_LOCATIONS)
12 | _invert = np.array([1] * 5 + [-1, -1, 1] * 5 + [1] * 4)
13 | _norm = np.array([1.] * 5 + [2300] * 6 + [1] * 6 + [5.5] * 3 + [1] * 4)
14 |
15 | def __init__(self, n_players=6, tick_skip=8):
16 | super().__init__()
17 | self.n_players = n_players
18 | self.demo_timers = None
19 | self.boost_timers = None
20 | self.current_state = None
21 | self.current_qkv = None
22 | self.current_mask = None
23 | self.tick_skip = tick_skip
24 |
25 | def reset(self, initial_state: GameState):
26 | self.demo_timers = np.zeros(self.n_players)
27 | self.boost_timers = np.zeros(len(initial_state.boost_pads))
28 | # self.current_state = initial_state
29 |
30 | def _maybe_update_obs(self, state: GameState):
31 | if state == self.current_state: # No need to update
32 | return
33 |
34 | if self.boost_timers is None:
35 | self.reset(state)
36 | else:
37 | self.current_state = state
38 |
39 | qkv = np.zeros((1, 1 + self.n_players + len(state.boost_pads), 24)) # Ball, players, boosts
40 |
41 | # Add ball
42 | n = 0
43 | ball = state.ball
44 | qkv[0, 0, 3] = 1 # is_ball
45 | qkv[0, 0, 5:8] = ball.position
46 | qkv[0, 0, 8:11] = ball.linear_velocity
47 | qkv[0, 0, 17:20] = ball.angular_velocity
48 |
49 | # Add players
50 | n += 1
51 | demos = np.zeros(self.n_players) # Which players are currently demoed
52 | for player in state.players:
53 | if player.team_num == BLUE_TEAM:
54 | qkv[0, n, 1] = 1 # is_teammate
55 | else:
56 | qkv[0, n, 2] = 1 # is_opponent
57 | car_data = player.car_data
58 | qkv[0, n, 5:8] = car_data.position
59 | qkv[0, n, 8:11] = car_data.linear_velocity
60 | qkv[0, n, 11:14] = car_data.forward()
61 | qkv[0, n, 14:17] = car_data.up()
62 | qkv[0, n, 17:20] = car_data.angular_velocity
63 | qkv[0, n, 20] = player.boost_amount
64 | # qkv[0, n, 21] = player.is_demoed
65 | demos[n - 1] = player.is_demoed # Keep track for demo timer
66 | qkv[0, n, 22] = player.on_ground
67 | qkv[0, n, 23] = player.has_flip
68 | n += 1
69 |
70 | # Add boost pads
71 | n = 1 + self.n_players
72 | boost_pads = state.boost_pads
73 | qkv[0, n:, 4] = 1 # is_boost
74 | qkv[0, n:, 5:8] = self._boost_locations
75 | qkv[0, n:, 20] = 0.12 + 0.88 * (self._boost_locations[:, 2] > 72) # Boost amount
76 | # qkv[0, n:, 21] = boost_pads
77 |
78 | # Boost and demo timers
79 | new_boost_grabs = (boost_pads == 1) & (self.boost_timers == 0) # New boost grabs since last frame
80 | self.boost_timers[new_boost_grabs] = 0.4 + 0.6 * (self._boost_locations[new_boost_grabs, 2] > 72)
81 | self.boost_timers *= boost_pads # Make sure we have zeros right
82 | qkv[0, 1 + self.n_players:, 21] = self.boost_timers
83 | self.boost_timers -= self.tick_skip / 1200 # Pre-normalized, 120 fps for 10 seconds
84 | self.boost_timers[self.boost_timers < 0] = 0
85 |
86 | new_demos = (demos == 1) & (self.demo_timers == 0)
87 | self.demo_timers[new_demos] = 0.3
88 | self.demo_timers *= demos
89 | qkv[0, 1: 1 + self.n_players, 21] = self.demo_timers
90 | self.demo_timers -= self.tick_skip / 1200
91 | self.demo_timers[self.demo_timers < 0] = 0
92 |
93 | # Store results
94 | self.current_qkv = qkv / self._norm
95 | mask = np.zeros((1, qkv.shape[1]))
96 | mask[0, 1 + len(state.players):1 + self.n_players] = 1
97 | self.current_mask = mask
98 |
99 | def build_obs(self, player: PlayerData, state: GameState, previous_action: np.ndarray) -> Any:
100 | if self.boost_timers is None:
101 | return np.zeros(0) # Obs space autodetect, make Aech happy
102 | self._maybe_update_obs(state)
103 | invert = player.team_num == ORANGE_TEAM
104 |
105 | qkv = self.current_qkv.copy()
106 | mask = self.current_mask.copy()
107 |
108 | main_n = state.players.index(player) + 1
109 | qkv[0, main_n, 0] = 1 # is_main
110 | if invert:
111 | qkv[0, :, (1, 2)] = qkv[0, :, (2, 1)] # Swap blue/orange
112 | qkv *= self._invert # Negate x and y values
113 |
114 | # TODO left-right normalization (always pick one side)
115 |
116 | q = qkv[0, main_n, :]
117 |
118 | #print(q)
119 | #print("vs")
120 | #print(previous_action)
121 | q = np.expand_dims(np.concatenate((q, previous_action), axis=0), axis=(0, 1))
122 | # kv = np.delete(qkv, main_n, axis=0) # Delete main? Watch masking
123 | kv = qkv
124 |
125 | # With EARLPerceiver we can use relative coords+vel(+more?) for key/value tensor, might be smart
126 | kv[0, :, 5:11] -= q[0, 0, 5:11]
127 | return q, kv, mask
--------------------------------------------------------------------------------
/rocket_learn/utils/scoreboard.py:
--------------------------------------------------------------------------------
1 | import math
2 | import random
3 |
4 | import numpy as np
5 | from numpy.random import poisson
6 | from rlgym.utils.common_values import BACK_WALL_Y, SIDE_WALL_X, GOAL_HEIGHT
7 |
8 | from rlgym.utils.gamestates import GameState
9 |
10 | TICKS_PER_SECOND = 120
11 | SECONDS_PER_MINUTE = 60
12 | GOALS_PER_MIN = (1, 0.6, 0.45) # Stats from ballchasing, S14 GC (before SSL)
13 |
14 |
15 | class Scoreboard:
16 | def __init__(self, random_resets=True, tick_skip=8, max_time_seconds=300, skip_warning=False):
17 | super().__init__()
18 | self.random_resets = random_resets
19 | self.tick_skip = tick_skip
20 | self.max_time_seconds = max_time_seconds
21 | self.ticks_left = None
22 | self.scoreline = None
23 | self.state = None
24 | if not skip_warning:
25 | print("WARNING: The Scoreboard object overwrites the inverted ball ang.vel. to include scoreboard, "
26 | "make sure you're not using that and instead inverting on your own. "
27 | "Call it in your obs builder's pre-step method to use.")
28 |
29 | def reset(self, initial_state: GameState):
30 | self.state = initial_state
31 | players_per_team = len(initial_state.players) // 2
32 | if self.random_resets:
33 | gpm = GOALS_PER_MIN[players_per_team - 1]
34 | mu_full = gpm * self.max_time_seconds / SECONDS_PER_MINUTE
35 | full_game = poisson(mu_full, size=2)
36 | if full_game[1] == full_game[0]:
37 | self.ticks_left = float("inf") # Overtime
38 | b = o = full_game[0].item()
39 | else:
40 | max_ticks = self.max_time_seconds * TICKS_PER_SECOND
41 | self.ticks_left = random.randrange(0, max_ticks)
42 | seconds_spent = self.max_time_seconds - self.ticks_left / TICKS_PER_SECOND
43 | mu_spent = gpm * seconds_spent / SECONDS_PER_MINUTE
44 | b, o = poisson(mu_spent, size=2).tolist()
45 | self.scoreline = b, o
46 | else:
47 | self.scoreline = 0, 0
48 | self.ticks_left = self.max_time_seconds * TICKS_PER_SECOND
49 | self.modify_gamestate(initial_state)
50 |
51 | def step(self, state: GameState, update_scores=True):
52 | if state != self.state:
53 | if state.ball.position[1] != 0: # Don't count during kickoffs
54 | self.ticks_left = max(0, self.ticks_left - self.tick_skip)
55 |
56 | if update_scores:
57 | b, o = self.scoreline
58 | changed = False
59 | if state.blue_score > self.state.blue_score: # Check in case of crash
60 | b += state.blue_score - self.state.blue_score
61 | changed = True
62 | if state.orange_score > self.state.orange_score:
63 | o += state.orange_score - self.state.orange_score
64 | changed = True
65 | tied = b == o
66 | if self.is_overtime():
67 | if not tied:
68 | self.ticks_left = float("-inf") # Finished
69 | if self.ticks_left <= 0 and (state.ball.position[2] <= 110 or changed):
70 | if tied:
71 | self.ticks_left = float("inf") # Overtime
72 | else:
73 | self.ticks_left = float("-inf") # Finished
74 | self.scoreline = b, o
75 |
76 | self.state = state
77 | self.modify_gamestate(state)
78 |
79 | def modify_gamestate(self, state: GameState):
80 | state.inverted_ball.angular_velocity[:] = *self.scoreline, self.ticks_left
81 |
82 | def is_overtime(self):
83 | return self.ticks_left > 0 and math.isinf(self.ticks_left)
84 |
85 | def is_finished(self):
86 | return self.ticks_left < 0 and math.isinf(self.ticks_left)
87 |
88 | def win_prob(self):
89 | return win_prob(self.state.players // 2,
90 | self.ticks_left * TICKS_PER_SECOND,
91 | self.scoreline[0] - self.scoreline[1]).item()
92 |
93 |
94 | FLOOR_AREA = 4 * BACK_WALL_Y * SIDE_WALL_X - 1152 * 1152 # Subtract corners
95 | GOAL_AREA = GOAL_HEIGHT * 880
96 |
97 |
98 | def win_prob(players_per_team, time_left_seconds, differential):
99 | # Utility function, calculates probability of blue team winning the full game
100 | from scipy.stats import skellam
101 |
102 | players_per_team = np.asarray(players_per_team)
103 | time_left_seconds = np.asarray(time_left_seconds)
104 | differential = np.asarray(differential)
105 |
106 | goal_floor_ratio = GOAL_AREA / (2 * GOAL_AREA + FLOOR_AREA)
107 |
108 | # inverted = np.random.random(differential.shape) > 0.5
109 | # differential[inverted] *= -1
110 | p = np.zeros(differential.shape)
111 |
112 | gpm = np.array(GOALS_PER_MIN)[players_per_team - 1]
113 | zero_seconds = (time_left_seconds <= 0) | np.isinf(time_left_seconds)
114 |
115 | mu_left = gpm * time_left_seconds / SECONDS_PER_MINUTE
116 | mu_left = mu_left[~zero_seconds]
117 | dist_left = skellam(mu_left, mu_left)
118 |
119 | diff_regulation = differential[~zero_seconds]
120 |
121 | # Probability of leading by two or more goals at end of regulation
122 | p[zero_seconds & (differential >= 2)] += 1
123 | p[~zero_seconds] += dist_left.cdf(diff_regulation - 2)
124 |
125 | # Probability of being tied at zero seconds and winning in overtime
126 | w = 0.5
127 | p[zero_seconds & (differential == 0)] += w
128 | p[~zero_seconds] += dist_left.pmf(diff_regulation) * w
129 |
130 | # Probability of leading by one at zero seconds, and surviving or getting scored on and winning OT
131 | w = (1 - goal_floor_ratio * 0.5)
132 | p[zero_seconds & (differential == 1)] += w
133 | p[~zero_seconds] += dist_left.pmf(diff_regulation - 1) * w
134 |
135 | # Probability of losing by one at zero seconds, and scoring and winning overtime
136 | w = goal_floor_ratio * 0.5
137 | p[zero_seconds & (differential == -1)] += w
138 | p[~zero_seconds] += dist_left.pmf(diff_regulation + 1) * w
139 |
140 | # p[inverted] = 1 - p[inverted]
141 |
142 | return p
143 |
--------------------------------------------------------------------------------
/rocket_learn/rollout_generator/redis/utils.py:
--------------------------------------------------------------------------------
1 | # Constants for consistent key lookup
2 | import pickle
3 | import zlib
4 | from typing import List, Optional, Union, Dict
5 |
6 | import numpy as np
7 | from redis import Redis
8 | from rlgym.utils.gamestates import GameState
9 | from trueskill import Rating
10 |
11 | from rocket_learn.experience_buffer import ExperienceBuffer
12 | from rocket_learn.utils.batched_obs_builder import BatchedObsBuilder
13 | from rocket_learn.utils.gamestate_encoding import encode_gamestate
14 | import msgpack
15 | import msgpack_numpy as m
16 |
17 | QUALITIES = "qualities-{}"
18 | N_UPDATES = "num-updates"
19 | # SAVE_FREQ = "save-freq"
20 | # MODEL_FREQ = "model-freq"
21 |
22 | MODEL_LATEST = "model-latest"
23 | VERSION_LATEST = "model-version"
24 |
25 | ROLLOUTS = "rollout"
26 | OPPONENT_MODELS = "opponent-models"
27 | WORKER_IDS = "worker-ids"
28 | CONTRIBUTORS = "contributors"
29 | LATEST_RATING_ID = "latest-rating-id"
30 | EXPERIENCE_PER_MODE = "experience-per-mode"
31 | _ALL = (
32 | N_UPDATES, MODEL_LATEST, VERSION_LATEST, ROLLOUTS, OPPONENT_MODELS,
33 | WORKER_IDS, CONTRIBUTORS, LATEST_RATING_ID, EXPERIENCE_PER_MODE)
34 |
35 | m.patch()
36 |
37 |
38 | # Helper methods for easier changing of byte conversion
39 | def _serialize(obj):
40 | return zlib.compress(msgpack.packb(obj))
41 |
42 |
43 | def _unserialize(obj):
44 | return msgpack.unpackb(zlib.decompress(obj))
45 |
46 |
47 | def _serialize_model(mdl):
48 | device = next(mdl.parameters()).device # Must be a better way right?
49 | mdl_bytes = pickle.dumps(mdl.cpu())
50 | mdl.to(device)
51 | return mdl_bytes
52 |
53 |
54 | def _unserialize_model(buf):
55 | agent = pickle.loads(buf)
56 | return agent
57 |
58 |
59 | def get_rating(gamemode: str, model_id: Optional[str], redis: Redis) -> Union[Rating, Dict[str, Rating]]:
60 | """
61 | Get the rating of a player.
62 | :param gamemode: The game mode to get the rating for.
63 | :param model_id: The id of the model.
64 | :param redis: The redis client.
65 | :return: The rating of the player.
66 | """
67 | quality_key = QUALITIES.format(gamemode)
68 | if model_id is None: # Return all ratings
69 | return {
70 | k.decode("utf-8"): Rating(*_unserialize(v))
71 | for k, v in redis.hgetall(quality_key).items()
72 | }
73 | return Rating(*_unserialize(redis.hget(quality_key, model_id)))
74 |
75 |
76 | def encode_buffers(buffers: List[ExperienceBuffer], return_obs=True, return_states=True, return_rewards=True):
77 | res = []
78 |
79 | if return_states:
80 | states = np.asarray([encode_gamestate(info["state"]) for info in buffers[0].infos] if len(buffers) > 0 else [])
81 | res.append(states)
82 |
83 | if return_obs:
84 | observations = [buffer.observations for buffer in buffers]
85 | res.append(observations)
86 |
87 | if return_rewards:
88 | rewards = np.asarray([buffer.rewards for buffer in buffers])
89 | res.append(rewards)
90 |
91 | actions = np.asarray([buffer.actions for buffer in buffers])
92 | log_probs = np.asarray([buffer.log_probs for buffer in buffers])
93 | dones = np.asarray([buffer.dones for buffer in buffers])
94 | res.append(actions)
95 | res.append(log_probs)
96 | res.append(dones)
97 |
98 | return res
99 |
100 |
101 | def decode_buffers(enc_buffers, versions, has_obs, has_states, has_rewards,
102 | obs_build_factory=None, rew_func_factory=None, act_parse_factory=None):
103 | assert has_states or has_obs, "Must have at least one of obs or states"
104 | assert has_states or has_rewards, "Must have at least one of rewards or states"
105 | assert not has_obs or has_obs and has_rewards, "Must have both obs and rewards" # TODO obs+no reward?
106 |
107 | i = 0
108 | if has_states:
109 | game_states = enc_buffers[i]
110 | if len(game_states) == 0:
111 | raise RuntimeError
112 | i += 1
113 | else:
114 | game_states = None
115 | if has_obs:
116 | obs = enc_buffers[i]
117 | i += 1
118 | else:
119 | obs = None
120 | if has_rewards:
121 | rewards = enc_buffers[i]
122 | i += 1
123 | # dones = np.zeros_like(rewards, dtype=bool) # TODO: Support for dones?
124 | # if len(dones) > 0:
125 | # dones[:, -1] = True
126 | else:
127 | rewards = None
128 | # dones = None
129 | actions = enc_buffers[i]
130 | i += 1
131 | log_probs = enc_buffers[i]
132 | i += 1
133 | dones = enc_buffers[i]
134 | i += 1
135 |
136 | if obs is None:
137 | # Reconstruct observations
138 | obs_builder = obs_build_factory()
139 | act_parser = act_parse_factory()
140 | if isinstance(obs_builder, BatchedObsBuilder):
141 | # TODO support states+no rewards
142 | assert game_states is not None and rewards is not None, "Must have both game states and rewards"
143 | obs = obs_builder.batched_build_obs(game_states[:-1])
144 | prev_actions = act_parser.parse_actions(actions.reshape((-1,) + actions.shape[2:]).copy(), None).reshape(
145 | actions.shape[:2] + (8,))
146 | prev_actions = np.concatenate((np.zeros((actions.shape[0], 1, 8)), prev_actions[:, :-1]), axis=1)
147 | obs_builder.add_actions(obs, prev_actions)
148 | buffers = [
149 | ExperienceBuffer(observations=[obs[i]], actions=actions[i], rewards=rewards[i], dones=dones[i],
150 | log_probs=log_probs[i])
151 | for i in range(len(obs))
152 | ]
153 | return buffers, game_states
154 | else: # Slow reconstruction, but works for any ObsBuilder
155 | gs_arrays = game_states
156 | game_states = [GameState(gs.tolist()) for gs in game_states]
157 | rew_func = rew_func_factory()
158 | obs_builder.reset(game_states[0])
159 | rew_func.reset(game_states[0])
160 | buffers = [
161 | ExperienceBuffer(infos=[{"state": game_states[0]}])
162 | for _ in range(len(game_states[0].players))
163 | ]
164 |
165 | env_actions = [
166 | act_parser.parse_actions(actions[:, s, :].copy(), game_states[s])
167 | for s in range(actions.shape[1])
168 | ]
169 |
170 | obss = [obs_builder.build_obs(p, game_states[0], np.zeros(8))
171 | for i, p in enumerate(game_states[0].players)]
172 | for s, gs in enumerate(game_states[1:]):
173 | assert len(gs.players) == len(versions)
174 | final = s == len(game_states) - 2
175 | old_obs = obss
176 | obss = []
177 | i = 0
178 | for version in versions:
179 | if version == 'na':
180 | continue # don't want to rebuild or use prebuilt agents
181 | player = gs.players[i]
182 |
183 | # IF ONLY 1 buffer is returned, need a way to say to discard bad version
184 |
185 | obs = obs_builder.build_obs(player, gs, env_actions[s][i])
186 | if rewards is None:
187 | if final:
188 | rew = rew_func.get_final_reward(player, gs, env_actions[s][i])
189 | else:
190 | rew = rew_func.get_reward(player, gs, env_actions[s][i])
191 | else:
192 | rew = rewards[i][s]
193 | buffers[i].add_step(old_obs[i], actions[i][s], rew, final, log_probs[i][s], {"state": gs})
194 | obss.append(obs)
195 | i += 1
196 |
197 | return buffers, gs_arrays
198 | else: # We have everything we need
199 | buffers = []
200 | for i in range(len(obs)):
201 | buffers.append(
202 | ExperienceBuffer(observations=obs[i],
203 | actions=actions[i],
204 | rewards=rewards[i],
205 | dones=dones[i],
206 | log_probs=log_probs[i])
207 | )
208 | return buffers, game_states
209 |
--------------------------------------------------------------------------------
/rocket_learn/utils/generate_episode.py:
--------------------------------------------------------------------------------
1 | from typing import List
2 |
3 | import numpy as np
4 | import torch
5 | from rlgym.gym import Gym
6 | from rlgym.utils.reward_functions.common_rewards import ConstantReward
7 | from rlgym.utils.state_setters import DefaultState
8 | from tqdm import tqdm
9 |
10 | from rocket_learn.agent.policy import Policy
11 | from rocket_learn.agent.pretrained_policy import HardcodedAgent
12 | from rocket_learn.experience_buffer import ExperienceBuffer
13 | from rocket_learn.utils.dynamic_gamemode_setter import DynamicGMSetter
14 | from rocket_learn.utils.truncated_condition import TruncatedCondition
15 |
16 |
17 | def generate_episode(env: Gym, policies, evaluate=False, scoreboard=None, progress=False) -> (
18 | List[ExperienceBuffer], int):
19 | """
20 | create experience buffer data by interacting with the environment(s)
21 | """
22 | if progress:
23 | progress = tqdm(unit=" steps")
24 | else:
25 | progress = None
26 |
27 | if evaluate: # Change setup temporarily to play a normal game (approximately)
28 | from rlgym_tools.extra_terminals.game_condition import GameCondition # tools is an optional dependency
29 | terminals = env._match._terminal_conditions # noqa
30 | reward = env._match._reward_fn # noqa
31 | game_condition = GameCondition(tick_skip=env._match._tick_skip, # noqa
32 | seconds_per_goal_forfeit=10 * env._match._team_size,
33 | max_overtime_seconds=300,
34 | max_no_touch_seconds=60)
35 | env._match._terminal_conditions = [game_condition] # noqa
36 | if isinstance(env._match._state_setter, DynamicGMSetter): # noqa
37 | state_setter = env._match._state_setter.setter # noqa
38 | env._match._state_setter.setter = DefaultState() # noqa
39 | else:
40 | state_setter = env._match._state_setter # noqa
41 | env._match._state_setter = DefaultState() # noqa
42 |
43 | env._match._reward_fn = ConstantReward() # noqa Save some cpu cycles
44 |
45 | if scoreboard is not None:
46 | random_resets = scoreboard.random_resets
47 | scoreboard.random_resets = not evaluate
48 | observations, info = env.reset(return_info=True)
49 | result = 0
50 |
51 | last_state = info['state'] # game_state for obs_building of other agents
52 |
53 | latest_policy_indices = [0 if isinstance(p, HardcodedAgent) else 1 for p in policies]
54 | # rollouts for all latest_policies
55 | rollouts = [
56 | ExperienceBuffer(infos=[info])
57 | for _ in range(sum(latest_policy_indices))
58 | ]
59 |
60 | b = o = 0
61 | with torch.no_grad():
62 | while True:
63 | all_indices = []
64 | all_actions = []
65 | all_log_probs = []
66 |
67 | # if observation isn't a list, make it one so we don't iterate over the observation directly
68 | if not isinstance(observations, list):
69 | observations = [observations]
70 |
71 | if not isinstance(policies[0], HardcodedAgent) and all(policy == policies[0] for policy in policies):
72 | policy = policies[0]
73 | if isinstance(observations[0], tuple):
74 | obs = tuple(np.concatenate([obs[i] for obs in observations], axis=0)
75 | for i in range(len(observations[0])))
76 | else:
77 | obs = np.concatenate(observations, axis=0)
78 | dist = policy.get_action_distribution(obs)
79 | action_indices = policy.sample_action(dist)
80 | log_probs = policy.log_prob(dist, action_indices)
81 | actions = policy.env_compatible(action_indices)
82 |
83 | all_indices.extend(list(action_indices.numpy()))
84 | all_actions.extend(list(actions))
85 | all_log_probs.extend(list(log_probs.numpy()))
86 | else:
87 | index = 0
88 | for policy, obs in zip(policies, observations):
89 | if isinstance(policy, HardcodedAgent):
90 | actions = policy.act(last_state, index)
91 |
92 | # make sure output is in correct format
93 | if not isinstance(observations, np.ndarray):
94 | actions = np.array(actions)
95 |
96 | # TODO: add converter that takes normal 8 actions into action space
97 | # actions = env._match._action_parser.convert_to_action_space(actions)
98 |
99 | all_indices.append(None)
100 | all_actions.append(actions)
101 | all_log_probs.append(None)
102 |
103 | elif isinstance(policy, Policy):
104 | dist = policy.get_action_distribution(obs)
105 | action_indices = policy.sample_action(dist)[0]
106 | log_probs = policy.log_prob(dist, action_indices).item()
107 | actions = policy.env_compatible(action_indices)
108 |
109 | all_indices.append(action_indices.numpy())
110 | all_actions.append(actions)
111 | all_log_probs.append(log_probs)
112 |
113 | else:
114 | print(str(type(policy)) + " type use not defined")
115 | assert False
116 |
117 | index += 1
118 |
119 | # to allow different action spaces, pad out short ones to longest length (assume later unpadding in parser)
120 | # length = max([a.shape[0] for a in all_actions])
121 | # padded_actions = []
122 | # for a in all_actions:
123 | # action = np.pad(a.astype('float64'), (0, length - a.size), 'constant', constant_values=np.NAN)
124 | # padded_actions.append(action)
125 | #
126 | # all_actions = padded_actions
127 | # TEST OUT ABOVE TO DEAL WITH VARIABLE LENGTH
128 |
129 | all_actions = np.vstack(all_actions)
130 | old_obs = observations
131 | observations, rewards, done, info = env.step(all_actions)
132 |
133 | truncated = False
134 | for terminal in env._match._terminal_conditions: # noqa
135 | if isinstance(terminal, TruncatedCondition):
136 | truncated |= terminal.is_truncated(info["state"])
137 |
138 | if len(policies) <= 1:
139 | observations, rewards = [observations], [rewards]
140 |
141 | # prune data that belongs to old agents
142 | old_obs = [a for i, a in enumerate(old_obs) if latest_policy_indices[i] == 1]
143 | all_indices = [d for i, d in enumerate(all_indices) if latest_policy_indices[i] == 1]
144 | rewards = [r for i, r in enumerate(rewards) if latest_policy_indices[i] == 1]
145 | all_log_probs = [r for i, r in enumerate(all_log_probs) if latest_policy_indices[i] == 1]
146 |
147 | assert len(old_obs) == len(all_indices), str(len(old_obs)) + " obs, " + str(len(all_indices)) + " ind"
148 | assert len(old_obs) == len(rewards), str(len(old_obs)) + " obs, " + str(len(rewards)) + " ind"
149 | assert len(old_obs) == len(all_log_probs), str(len(old_obs)) + " obs, " + str(len(all_log_probs)) + " ind"
150 | assert len(old_obs) == len(rollouts), str(len(old_obs)) + " obs, " + str(len(rollouts)) + " ind"
151 |
152 | # Might be different if only one agent?
153 | if not evaluate: # Evaluation matches can be long, no reason to keep them in memory
154 | for exp_buf, obs, act, rew, log_prob in zip(rollouts, old_obs, all_indices, rewards, all_log_probs):
155 | exp_buf.add_step(obs, act, rew, done + 2 * truncated, log_prob, info)
156 |
157 | if progress is not None:
158 | progress.update()
159 | igt = progress.n * env._match._tick_skip / 120 # noqa
160 | prog_str = f"{igt // 60:02.0f}:{igt % 60:02.0f} IGT"
161 | if evaluate:
162 | prog_str += f", BLUE {b} - {o} ORANGE"
163 | progress.set_postfix_str(prog_str)
164 |
165 | if done or truncated:
166 | result += info["result"]
167 | if info["result"] > 0:
168 | b += 1
169 | elif info["result"] < 0:
170 | o += 1
171 |
172 | if not evaluate:
173 | break
174 | elif game_condition.done: # noqa
175 | break
176 | else:
177 | observations, info = env.reset(return_info=True)
178 |
179 | last_state = info['state']
180 |
181 | if scoreboard is not None:
182 | scoreboard.random_resets = random_resets # noqa Checked above
183 |
184 | if evaluate:
185 | if isinstance(env._match._state_setter, DynamicGMSetter): # noqa
186 | env._match._state_setter.setter = state_setter # noqa
187 | else:
188 | env._match._state_setter = state_setter # noqa
189 | env._match._terminal_conditions = terminals # noqa
190 | env._match._reward_fn = reward # noqa
191 | return result
192 |
193 | if progress is not None:
194 | progress.close()
195 |
196 | return rollouts, result
197 |
--------------------------------------------------------------------------------
/rocket_learn/agent/pretrained_agents/nexto/nexto_v2_obs.py:
--------------------------------------------------------------------------------
1 | from typing import Any
2 |
3 | import numpy as np
4 | from rlgym.utils.gamestates import GameState, PlayerData
5 |
6 | BOOST_LOCATIONS = (
7 | (0.0, -4240.0, 70.0),
8 | (-1792.0, -4184.0, 70.0),
9 | (1792.0, -4184.0, 70.0),
10 | (-3072.0, -4096.0, 73.0),
11 | (3072.0, -4096.0, 73.0),
12 | (- 940.0, -3308.0, 70.0),
13 | (940.0, -3308.0, 70.0),
14 | (0.0, -2816.0, 70.0),
15 | (-3584.0, -2484.0, 70.0),
16 | (3584.0, -2484.0, 70.0),
17 | (-1788.0, -2300.0, 70.0),
18 | (1788.0, -2300.0, 70.0),
19 | (-2048.0, -1036.0, 70.0),
20 | (0.0, -1024.0, 70.0),
21 | (2048.0, -1036.0, 70.0),
22 | (-3584.0, 0.0, 73.0),
23 | (-1024.0, 0.0, 70.0),
24 | (1024.0, 0.0, 70.0),
25 | (3584.0, 0.0, 73.0),
26 | (-2048.0, 1036.0, 70.0),
27 | (0.0, 1024.0, 70.0),
28 | (2048.0, 1036.0, 70.0),
29 | (-1788.0, 2300.0, 70.0),
30 | (1788.0, 2300.0, 70.0),
31 | (-3584.0, 2484.0, 70.0),
32 | (3584.0, 2484.0, 70.0),
33 | (0.0, 2816.0, 70.0),
34 | (- 940.0, 3310.0, 70.0),
35 | (940.0, 3308.0, 70.0),
36 | (-3072.0, 4096.0, 73.0),
37 | (3072.0, 4096.0, 73.0),
38 | (-1792.0, 4184.0, 70.0),
39 | (1792.0, 4184.0, 70.0),
40 | (0.0, 4240.0, 70.0),
41 | )
42 |
43 |
44 | def rotation_to_quaternion(m: np.ndarray) -> np.ndarray:
45 | trace = np.trace(m)
46 | q = np.zeros(4)
47 |
48 | if trace > 0:
49 | s = (trace + 1) ** 0.5
50 | q[0] = s * 0.5
51 | s = 0.5 / s
52 | q[1] = (m[2, 1] - m[1, 2]) * s
53 | q[2] = (m[0, 2] - m[2, 0]) * s
54 | q[3] = (m[1, 0] - m[0, 1]) * s
55 | else:
56 | if m[0, 0] >= m[1, 1] and m[0, 0] >= m[2, 2]:
57 | s = (1 + m[0, 0] - m[1, 1] - m[2, 2]) ** 0.5
58 | inv_s = 0.5 / s
59 | q[1] = 0.5 * s
60 | q[2] = (m[1, 0] + m[0, 1]) * inv_s
61 | q[3] = (m[2, 0] + m[0, 2]) * inv_s
62 | q[0] = (m[2, 1] - m[1, 2]) * inv_s
63 | elif m[1, 1] > m[2, 2]:
64 | s = (1 + m[1, 1] - m[0, 0] - m[2, 2]) ** 0.5
65 | inv_s = 0.5 / s
66 | q[1] = (m[0, 1] + m[1, 0]) * inv_s
67 | q[2] = 0.5 * s
68 | q[3] = (m[1, 2] + m[2, 1]) * inv_s
69 | q[0] = (m[0, 2] - m[2, 0]) * inv_s
70 | else:
71 | s = (1 + m[2, 2] - m[0, 0] - m[1, 1]) ** 0.5
72 | inv_s = 0.5 / s
73 | q[1] = (m[0, 2] + m[2, 0]) * inv_s
74 | q[2] = (m[1, 2] + m[2, 1]) * inv_s
75 | q[3] = 0.5 * s
76 | q[0] = (m[1, 0] - m[0, 1]) * inv_s
77 |
78 | # q[[0, 1, 2, 3]] = q[[3, 0, 1, 2]]
79 |
80 | return -q
81 |
82 |
83 | def encode_gamestate(state: GameState):
84 | state_vals = [0, state.blue_score, state.orange_score]
85 | state_vals += state.boost_pads.tolist()
86 |
87 | for bd in (state.ball, state.inverted_ball):
88 | state_vals += bd.position.tolist()
89 | state_vals += bd.linear_velocity.tolist()
90 | state_vals += bd.angular_velocity.tolist()
91 |
92 | for p in state.players:
93 | state_vals += [p.car_id, p.team_num]
94 | for cd in (p.car_data, p.inverted_car_data):
95 | state_vals += cd.position.tolist()
96 | state_vals += rotation_to_quaternion(cd.rotation_mtx()).tolist()
97 | state_vals += cd.linear_velocity.tolist()
98 | state_vals += cd.angular_velocity.tolist()
99 | state_vals += [
100 | 0,
101 | 0,
102 | 0,
103 | 0,
104 | 0,
105 | p.is_demoed,
106 | p.on_ground,
107 | p.ball_touched,
108 | p.has_flip,
109 | p.boost_amount
110 | ]
111 | return state_vals
112 |
113 |
114 | class BatchedObsBuilder:
115 | def __init__(self):
116 | super().__init__()
117 | self.current_state = None
118 | self.current_obs = None
119 |
120 | def batched_build_obs(self, encoded_states: np.ndarray) -> Any:
121 | raise NotImplementedError
122 |
123 | def add_actions(self, obs: Any, previous_actions: np.ndarray, player_index=None):
124 | # Modify current obs to include action
125 | # player_index=None means actions for all players should be provided
126 | raise NotImplementedError
127 |
128 | def _reset(self, initial_state: GameState):
129 | raise NotImplementedError
130 |
131 | def reset(self, initial_state: GameState):
132 | self.current_state = False
133 | self.current_obs = None
134 | self._reset(initial_state)
135 |
136 | def build_obs(self, player: PlayerData, state: GameState, previous_action: np.ndarray) -> Any:
137 | # if state != self.current_state:
138 | self.current_obs = self.batched_build_obs(
139 | np.expand_dims(encode_gamestate(state), axis=0)
140 | )
141 | self.current_state = state
142 |
143 | for i, p in enumerate(state.players):
144 | if p == player:
145 | self.add_actions(self.current_obs, previous_action, i)
146 | return self.current_obs[i]
147 |
148 |
149 | IS_SELF, IS_MATE, IS_OPP, IS_BALL, IS_BOOST = range(5)
150 | POS = slice(5, 8)
151 | LIN_VEL = slice(8, 11)
152 | FW = slice(11, 14)
153 | UP = slice(14, 17)
154 | ANG_VEL = slice(17, 20)
155 | BOOST, DEMO, ON_GROUND, HAS_FLIP = range(20, 24)
156 | ACTIONS = range(24, 32)
157 |
158 | BALL_STATE_LENGTH = 18
159 | PLAYER_CAR_STATE_LENGTH = 13
160 | PLAYER_TERTIARY_INFO_LENGTH = 10
161 | PLAYER_INFO_LENGTH = 2 + 2 * PLAYER_CAR_STATE_LENGTH + PLAYER_TERTIARY_INFO_LENGTH
162 |
163 |
164 | class Nexto_V2_ObsBuilder(BatchedObsBuilder):
165 | _invert = np.array([1] * 5 + [-1, -1, 1] * 5 + [1] * 4)
166 | _norm = np.array([1.] * 5 + [2300] * 6 + [1] * 6 + [5.5] * 3 + [1] * 4)
167 |
168 | def __init__(self, field_info=None, n_players=None, tick_skip=8):
169 | super().__init__()
170 | self.n_players = n_players
171 | self.demo_timers = None
172 | self.boost_timers = None
173 | self.tick_skip = tick_skip
174 | if field_info is None:
175 | self._boost_locations = np.array(BOOST_LOCATIONS)
176 | self._boost_types = self._boost_locations[:, 2] > 72
177 | else:
178 | self._boost_locations = np.array([[bp.location.x, bp.location.y, bp.location.z]
179 | for bp in field_info.boost_pads[:field_info.num_boosts]])
180 | self._boost_types = np.array([bp.is_full_boost for bp in field_info.boost_pads[:field_info.num_boosts]])
181 |
182 | def _reset(self, initial_state: GameState):
183 | self.demo_timers = np.zeros(len(initial_state.players))
184 | self.boost_timers = np.zeros(len(initial_state.boost_pads))
185 |
186 | @staticmethod
187 | def _quats_to_rot_mtx(quats: np.ndarray) -> np.ndarray:
188 | # From rlgym.utils.math.quat_to_rot_mtx
189 | w = -quats[:, 0]
190 | x = -quats[:, 1]
191 | y = -quats[:, 2]
192 | z = -quats[:, 3]
193 |
194 | theta = np.zeros((quats.shape[0], 3, 3))
195 |
196 | norm = np.einsum("fq,fq->f", quats, quats)
197 |
198 | sel = norm != 0
199 |
200 | w = w[sel]
201 | x = x[sel]
202 | y = y[sel]
203 | z = z[sel]
204 |
205 | s = 1.0 / norm[sel]
206 |
207 | # front direction
208 | theta[sel, 0, 0] = 1.0 - 2.0 * s * (y * y + z * z)
209 | theta[sel, 1, 0] = 2.0 * s * (x * y + z * w)
210 | theta[sel, 2, 0] = 2.0 * s * (x * z - y * w)
211 |
212 | # left direction
213 | theta[sel, 0, 1] = 2.0 * s * (x * y - z * w)
214 | theta[sel, 1, 1] = 1.0 - 2.0 * s * (x * x + z * z)
215 | theta[sel, 2, 1] = 2.0 * s * (y * z + x * w)
216 |
217 | # up direction
218 | theta[sel, 0, 2] = 2.0 * s * (x * z + y * w)
219 | theta[sel, 1, 2] = 2.0 * s * (y * z - x * w)
220 | theta[sel, 2, 2] = 1.0 - 2.0 * s * (x * x + y * y)
221 |
222 | return theta
223 |
224 | @staticmethod
225 | def convert_to_relative(q, kv):
226 | # kv[..., POS.start:LIN_VEL.stop] -= q[..., POS.start:LIN_VEL.stop]
227 | kv[..., POS] -= q[..., POS]
228 | forward = q[..., FW]
229 | theta = np.arctan2(forward[..., 0], forward[..., 1])
230 | theta = np.expand_dims(theta, axis=-1)
231 | ct = np.cos(theta)
232 | st = np.sin(theta)
233 | xs = kv[..., POS.start:ANG_VEL.stop:3]
234 | ys = kv[..., POS.start + 1:ANG_VEL.stop:3]
235 | # Use temp variables to prevent modifying original array
236 | nx = ct * xs - st * ys
237 | ny = st * xs + ct * ys
238 | kv[..., POS.start:ANG_VEL.stop:3] = nx # x-components
239 | kv[..., POS.start + 1:ANG_VEL.stop:3] = ny # y-components
240 |
241 | def batched_build_obs(self, encoded_states: np.ndarray):
242 | ball_start_index = 3 + len(self._boost_locations)
243 | players_start_index = ball_start_index + BALL_STATE_LENGTH
244 | player_length = PLAYER_INFO_LENGTH
245 |
246 | n_players = (encoded_states.shape[1] - players_start_index) // player_length
247 | lim_players = n_players if self.n_players is None else self.n_players
248 | n_entities = lim_players + 1 + 34
249 |
250 | # SELECTORS
251 | sel_players = slice(0, lim_players)
252 | sel_ball = sel_players.stop
253 | sel_boosts = slice(sel_ball + 1, None)
254 |
255 | # MAIN ARRAYS
256 | q = np.zeros((n_players, encoded_states.shape[0], 1, 32))
257 | kv = np.zeros((n_players, encoded_states.shape[0], n_entities, 24)) # Keys and values are (mostly) shared
258 | m = np.zeros((n_players, encoded_states.shape[0], n_entities)) # Mask is shared
259 |
260 | # BALL
261 | kv[:, :, sel_ball, 3] = 1
262 | kv[:, :, sel_ball, np.r_[POS, LIN_VEL, ANG_VEL]] = encoded_states[:, ball_start_index: ball_start_index + 9]
263 |
264 | # BOOSTS
265 | kv[:, :, sel_boosts, IS_BOOST] = 1
266 | kv[:, :, sel_boosts, POS] = self._boost_locations
267 | kv[:, :, sel_boosts, BOOST] = 0.12 + 0.88 * (self._boost_locations[:, 2] > 72)
268 | kv[:, :, sel_boosts, DEMO] = encoded_states[:, 3:3 + 34] # FIXME boost timer
269 |
270 | # PLAYERS
271 | teams = encoded_states[0, players_start_index + 1::player_length]
272 | kv[:, :, :n_players, IS_MATE] = 1 - teams # Default team is blue
273 | kv[:, :, :n_players, IS_OPP] = teams
274 | for i in range(n_players):
275 | encoded_player = encoded_states[:,
276 | players_start_index + i * player_length: players_start_index + (i + 1) * player_length]
277 |
278 | kv[i, :, i, IS_SELF] = 1
279 | kv[:, :, i, POS] = encoded_player[:, 2: 5] # TODO constants for these indices
280 | kv[:, :, i, LIN_VEL] = encoded_player[:, 9: 12]
281 | quats = encoded_player[:, 5: 9]
282 | rot_mtx = self._quats_to_rot_mtx(quats)
283 | kv[:, :, i, FW] = rot_mtx[:, :, 0]
284 | kv[:, :, i, UP] = rot_mtx[:, :, 2]
285 | kv[:, :, i, ANG_VEL] = encoded_player[:, 12: 15]
286 | kv[:, :, i, BOOST] = encoded_player[:, 37]
287 | kv[:, :, i, DEMO] = encoded_player[:, 33] # FIXME demo timer
288 | kv[:, :, i, ON_GROUND] = encoded_player[:, 34]
289 | kv[:, :, i, HAS_FLIP] = encoded_player[:, 36]
290 |
291 | kv[teams == 1] *= self._invert
292 | kv[np.argwhere(teams == 1), ..., (IS_MATE, IS_OPP)] = kv[
293 | np.argwhere(teams == 1), ..., (IS_OPP, IS_MATE)] # Swap teams
294 |
295 | kv /= self._norm
296 |
297 | for i in range(n_players):
298 | q[i, :, 0, :kv.shape[-1]] = kv[i, :, i, :]
299 |
300 | self.convert_to_relative(q, kv)
301 | # kv[:, :, :, 5:11] -= q[:, :, :, 5:11]
302 |
303 | # MASK
304 | m[:, :, n_players: lim_players] = 1
305 |
306 | return [(q[i], kv[i], m[i]) for i in range(n_players)]
307 |
308 | def add_actions(self, obs: Any, previous_actions: np.ndarray, player_index=None):
309 | if player_index is None:
310 | for (q, kv, m), act in zip(obs, previous_actions):
311 | q[:, 0, ACTIONS] = act
312 | else:
313 | q, kv, m = obs[player_index]
314 | q[:, 0, ACTIONS] = previous_actions
315 |
--------------------------------------------------------------------------------
/rocket_learn/utils/stat_trackers/common_trackers.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | from rocket_learn.utils.gamestate_encoding import StateConstants
4 | from rocket_learn.utils.stat_trackers.stat_tracker import StatTracker
5 |
6 |
7 | class Speed(StatTracker):
8 | def __init__(self):
9 | super().__init__("average_speed")
10 | self.count = 0
11 | self.total_speed = 0.0
12 |
13 | def reset(self):
14 | self.count = 0
15 | self.total_speed = 0.0
16 |
17 | def update(self, gamestates: np.ndarray, mask: np.ndarray):
18 | players = gamestates[:, StateConstants.PLAYERS]
19 | xs = players[:, StateConstants.CAR_LINEAR_VEL_X]
20 | ys = players[:, StateConstants.CAR_LINEAR_VEL_Y]
21 | zs = players[:, StateConstants.CAR_LINEAR_VEL_Z]
22 |
23 | speeds = np.sqrt(xs ** 2 + ys ** 2 + zs ** 2)
24 | self.total_speed += np.sum(speeds)
25 | self.count += speeds.size
26 |
27 | def get_stat(self):
28 | return self.total_speed / (self.count or 1)
29 |
30 |
31 | class Demos(StatTracker):
32 | def __init__(self):
33 | super().__init__("average_demos")
34 | self.count = 0
35 | self.total_demos = 0
36 |
37 | def reset(self):
38 | self.count = 0
39 | self.total_demos = 0
40 |
41 | def update(self, gamestates: np.ndarray, mask: np.ndarray):
42 | players = gamestates[:, StateConstants.PLAYERS]
43 |
44 | demos = np.clip(players[-1, StateConstants.MATCH_DEMOLISHES] - players[0, StateConstants.MATCH_DEMOLISHES],
45 | 0, None)
46 | self.total_demos += np.sum(demos)
47 | self.count += demos.size
48 |
49 | def get_stat(self):
50 | return self.total_demos / (self.count or 1)
51 |
52 |
53 | class TimeoutRate(StatTracker):
54 | def __init__(self):
55 | super().__init__("timeout_rate")
56 | self.count = 0
57 | self.total_timeouts = 0
58 |
59 | def reset(self):
60 | self.count = 0
61 | self.total_timeouts = 0
62 |
63 | def update(self, gamestates: np.ndarray, mask: np.ndarray):
64 | orange_diff = gamestates[-1, StateConstants.ORANGE_SCORE] - gamestates[0, StateConstants.ORANGE_SCORE]
65 | blue_diff = gamestates[-1, StateConstants.BLUE_SCORE] - gamestates[0, StateConstants.BLUE_SCORE]
66 |
67 | self.total_timeouts += ((orange_diff == 0) & (blue_diff == 0)).item()
68 | self.count += 1
69 |
70 | def get_stat(self):
71 | return self.total_timeouts / (self.count or 1)
72 |
73 |
74 | class Touch(StatTracker):
75 | def __init__(self):
76 | super().__init__("touch_rate")
77 | self.count = 0
78 | self.total_touches = 0
79 |
80 | def reset(self):
81 | self.count = 0
82 | self.total_touches = 0
83 |
84 | def update(self, gamestates: np.ndarray, mask: np.ndarray):
85 | players = gamestates[:, StateConstants.PLAYERS]
86 | is_touch = players[:, StateConstants.BALL_TOUCHED]
87 |
88 | self.total_touches += np.sum(is_touch)
89 | self.count += is_touch.size
90 |
91 | def get_stat(self):
92 | return self.total_touches / (self.count or 1)
93 |
94 |
95 | class EpisodeLength(StatTracker):
96 | def __init__(self):
97 | super().__init__("episode_length")
98 | self.count = 0
99 | self.total_length = 0
100 |
101 | def reset(self):
102 | self.count = 0
103 | self.total_length = 0
104 |
105 | def update(self, gamestates: np.ndarray, mask: np.ndarray):
106 | self.total_length += gamestates.shape[0]
107 | self.count += 1
108 |
109 | def get_stat(self):
110 | return self.total_length / (self.count or 1)
111 |
112 |
113 | class Boost(StatTracker):
114 | def __init__(self):
115 | super().__init__("average_boost")
116 | self.count = 0
117 | self.total_boost = 0
118 |
119 | def reset(self):
120 | self.count = 0
121 | self.total_boost = 0
122 |
123 | def update(self, gamestates: np.ndarray, mask: np.ndarray):
124 | players = gamestates[:, StateConstants.PLAYERS]
125 | boost = players[:, StateConstants.BOOST_AMOUNT]
126 | is_limited = (0 <= boost) & (boost <= 1)
127 | boost = boost[is_limited]
128 | self.total_boost += np.sum(boost)
129 | self.count += boost.size
130 |
131 | def get_stat(self):
132 | return self.total_boost / (self.count or 1)
133 |
134 |
135 | class BehindBall(StatTracker):
136 | def __init__(self):
137 | super().__init__("behind_ball_rate")
138 | self.count = 0
139 | self.total_behind = 0
140 |
141 | def reset(self):
142 | self.count = 0
143 | self.total_behind = 0
144 |
145 | def update(self, gamestates: np.ndarray, mask: np.ndarray):
146 | players = gamestates[:, StateConstants.PLAYERS]
147 | ball_y = gamestates[:, StateConstants.BALL_POSITION.start + 1]
148 | player_y = players[:, StateConstants.CAR_POS_Y]
149 | is_orange = players[:, StateConstants.TEAM_NUMS]
150 | behind = (2 * is_orange - 1) * (ball_y.reshape(-1, 1) - player_y) < 0
151 |
152 | self.total_behind += np.sum(behind)
153 | self.count += behind.size
154 |
155 | def get_stat(self):
156 | return self.total_behind / (self.count or 1)
157 |
158 |
159 | class TouchHeight(StatTracker):
160 | def __init__(self):
161 | super().__init__("touch_height")
162 | self.count = 0
163 | self.total_height = 0
164 |
165 | def reset(self):
166 | self.count = 0
167 | self.total_height = 0
168 |
169 | def update(self, gamestates: np.ndarray, mask: np.ndarray):
170 | players = gamestates[:, StateConstants.PLAYERS]
171 | ball_z = gamestates[:, StateConstants.BALL_POSITION.start + 2]
172 | touch_heights = ball_z[players[:, StateConstants.BALL_TOUCHED].any(axis=1)]
173 |
174 | self.total_height += np.sum(touch_heights)
175 | self.count += touch_heights.size
176 |
177 | def get_stat(self):
178 | return self.total_height / (self.count or 1)
179 |
180 |
181 | class DistToBall(StatTracker):
182 | def __init__(self):
183 | super().__init__("distance_to_ball")
184 | self.count = 0
185 | self.total_dist = 0.0
186 |
187 | def reset(self):
188 | self.count = 0
189 | self.total_dist = 0.0
190 |
191 | def update(self, gamestates: np.ndarray, mask: np.ndarray):
192 | players = gamestates[:, StateConstants.PLAYERS]
193 | ball = gamestates[:, StateConstants.BALL_POSITION]
194 | ball_x = ball[:, 0].reshape((-1, 1))
195 | ball_y = ball[:, 1].reshape((-1, 1))
196 | ball_z = ball[:, 2].reshape((-1, 1))
197 | xs = players[:, StateConstants.CAR_POS_X]
198 | ys = players[:, StateConstants.CAR_POS_Y]
199 | zs = players[:, StateConstants.CAR_POS_Z]
200 |
201 | dists = np.sqrt((ball_x - xs) ** 2 + (ball_y - ys) ** 2 + (ball_z - zs) ** 2)
202 | self.total_dist += np.sum(dists)
203 | self.count += dists.size
204 |
205 | def get_stat(self):
206 | return self.total_dist / (self.count or 1)
207 |
208 |
209 | class AirTouch(StatTracker):
210 | def __init__(self):
211 | super().__init__("air_touch_rate")
212 | self.count = 0
213 | self.total_touches = 0
214 |
215 | def reset(self):
216 | self.count = 0
217 | self.total_touches = 0
218 |
219 | def update(self, gamestates: np.ndarray, mask: np.ndarray):
220 | players = gamestates[:, StateConstants.PLAYERS]
221 | is_touch = np.asarray([a * b for a, b in
222 | zip(players[:, StateConstants.BALL_TOUCHED],
223 | np.invert(players[:, StateConstants.ON_GROUND].astype(bool)))])
224 |
225 | self.total_touches += np.sum(is_touch)
226 | self.count += is_touch.size
227 |
228 | def get_stat(self):
229 | return self.total_touches / (self.count or 1)
230 |
231 |
232 | class AirTouchHeight(StatTracker):
233 | def __init__(self):
234 | super().__init__("air_touch_height")
235 | self.count = 0
236 | self.total_height = 0
237 |
238 | def reset(self):
239 | self.count = 0
240 | self.total_height = 0
241 |
242 | def update(self, gamestates: np.ndarray, mask: np.ndarray):
243 | players = gamestates[:, StateConstants.PLAYERS]
244 | ball_z = gamestates[:, StateConstants.BALL_POSITION.start + 2]
245 | touch_heights = ball_z[players[:, StateConstants.BALL_TOUCHED].any(axis=1)]
246 | touch_heights = touch_heights[touch_heights >= 175] # remove dribble touches and below
247 |
248 | self.total_height += np.sum(touch_heights)
249 | self.count += touch_heights.size
250 |
251 | def get_stat(self):
252 | return self.total_height / (self.count or 1)
253 |
254 |
255 | class BallSpeed(StatTracker):
256 | def __init__(self):
257 | super().__init__("average_ball_speed")
258 | self.count = 0
259 | self.total_speed = 0.0
260 |
261 | def reset(self):
262 | self.count = 0
263 | self.total_speed = 0.0
264 |
265 | def update(self, gamestates: np.ndarray, mask: np.ndarray):
266 | ball_speeds = gamestates[:, StateConstants.BALL_LINEAR_VELOCITY]
267 | xs = ball_speeds[:, 0]
268 | ys = ball_speeds[:, 1]
269 | zs = ball_speeds[:, 2]
270 | speeds = np.sqrt(xs ** 2 + ys ** 2 + zs ** 2)
271 | self.total_speed += np.sum(speeds)
272 | self.count += speeds.size
273 |
274 | def get_stat(self):
275 | return self.total_speed / (self.count or 1)
276 |
277 |
278 | class BallHeight(StatTracker):
279 | def __init__(self):
280 | super().__init__("ball_height")
281 | self.count = 0
282 | self.total_height = 0
283 |
284 | def reset(self):
285 | self.count = 0
286 | self.total_height = 0
287 |
288 | def update(self, gamestates: np.ndarray, mask: np.ndarray):
289 | ball_z = gamestates[:, StateConstants.BALL_POSITION.start + 2]
290 |
291 | self.total_height += np.sum(ball_z)
292 | self.count += ball_z.size
293 |
294 | def get_stat(self):
295 | return self.total_height / (self.count or 1)
296 |
297 |
298 | class GoalSpeed(StatTracker):
299 | def __init__(self):
300 | super().__init__("avg_goal_speed")
301 | self.count = 0
302 | self.total_speed = 0
303 |
304 | def reset(self):
305 | self.count = 0
306 | self.total_speed = 0
307 |
308 | def update(self, gamestates: np.ndarray, mask: np.ndarray):
309 | orange_diff = np.diff(gamestates[:, StateConstants.ORANGE_SCORE], append=np.nan)
310 | blue_diff = np.diff(gamestates[:, StateConstants.BLUE_SCORE], append=np.nan)
311 |
312 | goal_frames = (orange_diff > 0) | (blue_diff > 0)
313 |
314 | goal_speed = gamestates[goal_frames, StateConstants.BALL_LINEAR_VELOCITY]
315 | goal_speed = np.linalg.norm(goal_speed, axis=-1).sum()
316 |
317 | self.total_speed += goal_speed / 27.78 # convert to km/h
318 | self.count += goal_speed.size
319 |
320 | def get_stat(self):
321 | return self.total_speed / (self.count or 1)
322 |
323 |
324 | class MaxGoalSpeed(StatTracker):
325 | def __init__(self):
326 | super().__init__("max_goal_speed")
327 | self.max_speed = 0
328 |
329 | def reset(self):
330 | self.max_speed = 0
331 |
332 | def update(self, gamestates: np.ndarray, mask: np.ndarray):
333 | if gamestates.ndim > 1 and len(gamestates) > 1:
334 | end = gamestates[-2]
335 | ball_speeds = end[StateConstants.BALL_LINEAR_VELOCITY]
336 | goal_speed = float(np.linalg.norm(ball_speeds)) / 27.78 # convert to km/h
337 |
338 | self.max_speed = max(float(self.max_speed), goal_speed)
339 |
340 | def get_stat(self):
341 | return self.max_speed
342 |
343 |
344 | class CarOnGround(StatTracker):
345 | def __init__(self):
346 | super().__init__("pct_car_on_ground")
347 | self.count = 0
348 | self.total_ground = 0.0
349 |
350 | def reset(self):
351 | self.count = 0
352 | self.total_ground = 0.0
353 |
354 | def update(self, gamestates: np.ndarray, mask: np.ndarray):
355 | on_ground = gamestates[:, StateConstants.ON_GROUND]
356 |
357 | self.total_ground += np.sum(on_ground)
358 | self.count += on_ground.size
359 |
360 | def get_stat(self):
361 | return 100 * self.total_ground / (self.count or 1)
362 |
363 |
364 | class Saves(StatTracker):
365 | def __init__(self):
366 | super().__init__("average_saves")
367 | self.count = 0
368 | self.total_saves = 0
369 |
370 | def reset(self):
371 | self.count = 0
372 | self.total_saves = 0
373 |
374 | def update(self, gamestates: np.ndarray, mask: np.ndarray):
375 | players = gamestates[:, StateConstants.PLAYERS]
376 |
377 | saves = np.clip(players[-1, StateConstants.MATCH_SAVES] - players[0, StateConstants.MATCH_SAVES],
378 | 0, None)
379 | self.total_saves += np.sum(saves)
380 | self.count += saves.size
381 |
382 | def get_stat(self):
383 | return self.total_saves / (self.count or 1)
384 |
385 |
386 | class Shots(StatTracker):
387 | def __init__(self):
388 | super().__init__("average_shots")
389 | self.count = 0
390 | self.total_shots = 0
391 |
392 | def reset(self):
393 | self.count = 0
394 | self.total_shots = 0
395 |
396 | def update(self, gamestates: np.ndarray, mask: np.ndarray):
397 | players = gamestates[:, StateConstants.PLAYERS]
398 |
399 | shots = np.clip(players[-1, StateConstants.MATCH_SHOTS] - players[0, StateConstants.MATCH_SHOTS],
400 | 0, None)
401 | self.total_shots += np.sum(shots)
402 | self.count += shots.size
403 |
404 | def get_stat(self):
405 | return self.total_shots / (self.count or 1)
406 |
--------------------------------------------------------------------------------
/rocket_learn/rollout_generator/redis/redis_rollout_generator.py:
--------------------------------------------------------------------------------
1 | import itertools
2 | from collections import Counter
3 | from typing import Iterator, Callable, Optional, List
4 |
5 | import numpy as np
6 | import plotly.graph_objs as go
7 | # import matplotlib.pyplot # noqa
8 | import wandb
9 | # from matplotlib.axes import Axes
10 | # from matplotlib.figure import Figure
11 | from redis import Redis
12 | from redis.exceptions import ResponseError
13 | from rlgym.utils import ObsBuilder, RewardFunction
14 | from rlgym.utils.action_parsers import ActionParser
15 | from trueskill import Rating, rate, SIGMA
16 |
17 | from rocket_learn.experience_buffer import ExperienceBuffer
18 | from rocket_learn.rollout_generator.base_rollout_generator import BaseRolloutGenerator
19 | from rocket_learn.rollout_generator.redis.utils import decode_buffers, _unserialize, QUALITIES, _serialize, ROLLOUTS, \
20 | VERSION_LATEST, OPPONENT_MODELS, CONTRIBUTORS, N_UPDATES, MODEL_LATEST, _serialize_model, get_rating, \
21 | _ALL, LATEST_RATING_ID, EXPERIENCE_PER_MODE
22 | from rocket_learn.utils.stat_trackers.stat_tracker import StatTracker
23 |
24 |
25 | class RedisRolloutGenerator(BaseRolloutGenerator):
26 | """
27 | Rollout generator in charge of sending commands to workers via redis
28 | """
29 |
30 | def __init__(
31 | self,
32 | name: str,
33 | redis: Redis,
34 | obs_build_factory: Callable[[], ObsBuilder],
35 | rew_func_factory: Callable[[], RewardFunction],
36 | act_parse_factory: Callable[[], ActionParser],
37 | save_every=10,
38 | model_every=100,
39 | logger=None,
40 | clear=True,
41 | max_age=0,
42 | default_sigma=SIGMA,
43 | min_sigma=1,
44 | gamemodes=("1v1", "2v2", "3v3"),
45 | stat_trackers: Optional[List[StatTracker]] = None,
46 | ):
47 | self.lastsave_ts = None
48 | self.name = name
49 | self.tot_bytes = 0
50 | self.redis = redis
51 | self.logger = logger
52 |
53 | # TODO saving/loading
54 | if clear:
55 | self.redis.delete(*(_ALL + tuple(QUALITIES.format(gm) for gm in gamemodes)))
56 | self.redis.set(N_UPDATES, 0)
57 | self.redis.hset(EXPERIENCE_PER_MODE, mapping={m: 0 for m in gamemodes})
58 | else:
59 | if self.redis.exists(ROLLOUTS) > 0:
60 | self.redis.delete(ROLLOUTS)
61 | self.redis.decr(VERSION_LATEST,
62 | max_age + 1) # In case of reload from old version, don't let current seep in
63 |
64 | # self.redis.set(SAVE_FREQ, save_every)
65 | # self.redis.set(MODEL_FREQ, model_every)
66 | self.save_freq = save_every
67 | self.model_freq = model_every
68 | self.contributors = Counter() # No need to save, clears every iteration
69 | self.obs_build_func = obs_build_factory
70 | self.rew_func_factory = rew_func_factory
71 | self.act_parse_factory = act_parse_factory
72 | self.max_age = max_age
73 | self.default_sigma = default_sigma
74 | self.min_sigma = min_sigma
75 | self.gamemodes = gamemodes
76 | self.stat_trackers = stat_trackers or []
77 | self._reset_stats()
78 |
79 | @staticmethod
80 | def _process_rollout(rollout_bytes, latest_version, obs_build_func, rew_build_func, act_build_func, max_age):
81 | rollout_data, versions, uuid, name, result, has_obs, has_states, has_rewards = _unserialize(rollout_bytes)
82 |
83 | v_check = [v for v in versions if isinstance(v, int) or v.startswith("-")]
84 |
85 | if any(version < 0 and abs(version - latest_version) > max_age for version in v_check):
86 | return
87 |
88 | if any(version < 0 for version in v_check):
89 | buffers, states = decode_buffers(rollout_data, versions, has_obs, has_states, has_rewards,
90 | obs_build_func, rew_build_func, act_build_func)
91 | else:
92 | buffers = states = [None] * len(v_check)
93 |
94 | return buffers, states, versions, uuid, name, result
95 |
96 | def _update_ratings(self, name, versions, buffers, latest_version, result):
97 | ratings = []
98 | relevant_buffers = []
99 | gamemode = f"{len(versions) // 2}v{len(versions) // 2}" # TODO: support unfair games
100 |
101 | versions = [v for v in versions if v != "na"]
102 | for version, buffer in itertools.zip_longest(versions, buffers):
103 | if isinstance(version, int) and version < 0:
104 | if abs(version - latest_version) <= self.max_age:
105 | relevant_buffers.append(buffer)
106 | self.contributors[name] += buffer.size()
107 | else:
108 | return []
109 | else:
110 | rating = get_rating(gamemode, version, self.redis)
111 | ratings.append(rating)
112 |
113 | # Only old versions, calculate MMR
114 | if len(ratings) == len(versions) and len(buffers) == 0:
115 | blue_players = sum(divmod(len(ratings), 2))
116 | blue = tuple(ratings[:blue_players]) # Tuple is important
117 | orange = tuple(ratings[blue_players:])
118 |
119 | # In ranks lowest number is best, result=-1 is orange win, 0 tie, 1 blue
120 | r1, r2 = rate((blue, orange), ranks=(0, result))
121 |
122 | # Some trickery to handle same rating appearing multiple times, we just average their new mus and sigmas
123 | ratings_versions = {}
124 | for rating, version in zip(r1 + r2, versions):
125 | ratings_versions.setdefault(version, []).append(rating)
126 |
127 | mapping = {}
128 | for version, ratings in ratings_versions.items():
129 | # In case of duplicates, average ratings together (not strictly necessary with default setup)
130 | # Also limit sigma to its lower bound
131 | avg_rating = Rating(sum(r.mu for r in ratings) / len(ratings),
132 | max(sum(r.sigma ** 2 for r in ratings) ** 0.5 / len(ratings), self.min_sigma))
133 | mapping[version] = _serialize(tuple(avg_rating))
134 | gamemode = f"{len(versions) // 2}v{len(versions) // 2}" # TODO: support unfair games
135 | self.redis.hset(QUALITIES.format(gamemode), mapping=mapping)
136 |
137 | if len(relevant_buffers) > 0:
138 | self.redis.hincrby(EXPERIENCE_PER_MODE, gamemode, len(relevant_buffers) * relevant_buffers[0].size())
139 |
140 | return relevant_buffers
141 |
142 | def _reset_stats(self):
143 | for stat_tracker in self.stat_trackers:
144 | stat_tracker.reset()
145 |
146 | def _update_stats(self, states, mask):
147 | if states is None:
148 | return
149 | for stat_tracker in self.stat_trackers:
150 | stat_tracker.update(states, mask)
151 |
152 | def _get_stats(self):
153 | stats = {}
154 | for stat_tracker in self.stat_trackers:
155 | stats[stat_tracker.name] = stat_tracker.get_stat()
156 | return stats
157 |
158 | def generate_rollouts(self) -> Iterator[ExperienceBuffer]:
159 | while True:
160 | latest_version = int(self.redis.get(VERSION_LATEST))
161 | data = self.redis.blpop(ROLLOUTS)[1]
162 | self.tot_bytes += len(data)
163 | res = self._process_rollout(
164 | data, latest_version,
165 | self.obs_build_func, self.rew_func_factory, self.act_parse_factory,
166 | self.max_age
167 | )
168 | if res is not None:
169 | buffers, states, versions, uuid, name, result = res
170 | # versions = [version for version in versions if version != 'na'] # don't track humans or hardcoded
171 |
172 | relevant_buffers = self._update_ratings(name, versions, buffers, latest_version, result)
173 | if len(relevant_buffers) > 0:
174 | self._update_stats(states, [b in relevant_buffers for b in buffers])
175 | yield from relevant_buffers
176 |
177 | def _plot_ratings(self):
178 | fig_data = []
179 | i = 0
180 | means = {}
181 | mean_key = "mean"
182 | gamemodes = list(self.gamemodes)
183 | if len(gamemodes) > 1:
184 | gamemodes.append(mean_key)
185 | for gamemode in gamemodes:
186 | if gamemode != mean_key:
187 | ratings = get_rating(gamemode, None, self.redis)
188 | if len(ratings) <= 0:
189 | return
190 | baseline = None
191 | for mode in "stochastic", "deterministic":
192 | if gamemode != "mean":
193 | x = []
194 | mus = []
195 | sigmas = []
196 | for k, r in ratings.items(): # noqa
197 | if k.endswith(mode):
198 | v = int(k.rsplit("-", 2)[1][1:])
199 | # v = int(k.split("-")[1][1:])
200 | x.append(v)
201 | mus.append(r.mu)
202 | sigmas.append(r.sigma)
203 | mean = means.setdefault(mode, {}).get(v, (0, 0))
204 | means[mode][v] = (mean[0] + r.mu, mean[1] + r.sigma ** 2)
205 | # *Smoothly* transition from red, to green, to blue depending on gamemode
206 | mid = (len(self.gamemodes) - 1) / 2
207 | # avoid divide by 0 issues if there's only one gamemode, this moves it halfway into the colors
208 | if mid == 0:
209 | mid = 0.5
210 | if i < mid:
211 | r = 1 - i / mid
212 | g = i / mid
213 | b = 0
214 | else:
215 | r = 0
216 | g = 1 - (i - mid) / mid
217 | b = (i - mid) / mid
218 | else:
219 | means_mode = means.get(mode, {})
220 | x = list(means_mode.keys())
221 | mus = [mean[0] / len(self.gamemodes) for mean in means_mode.values()]
222 | sigmas = [(mean[1] / len(self.gamemodes)) ** 0.5 for mean in means_mode.values()]
223 | r = g = b = 1 / 3
224 |
225 | indices = np.argsort(x)
226 | x = np.array(x)[indices]
227 | mus = np.array(mus)[indices]
228 | sigmas = np.array(sigmas)[indices]
229 |
230 | if baseline is None:
231 | baseline = mus[0] # Stochastic initialization is defined as the baseline (0 mu)
232 | mus = mus - baseline
233 | y = mus
234 | y_upper = mus + 2 * sigmas # 95% confidence
235 | y_lower = mus - 2 * sigmas
236 |
237 | scale = 255 if mode == "stochastic" else 128
238 | color = f"{int(r * scale)},{int(g * scale)},{int(b * scale)}"
239 |
240 | fig_data += [
241 | go.Scatter(
242 | x=x,
243 | y=y,
244 | line=dict(color=f'rgb({color})'),
245 | mode='lines',
246 | name=f"{gamemode}-{mode}",
247 | legendgroup=f"{gamemode}-{mode}",
248 | showlegend=True,
249 | visible=None if gamemode == gamemodes[-1] else "legendonly",
250 | ),
251 | go.Scatter(
252 | x=np.concatenate((x, x[::-1])), # x, then x reversed
253 | y=np.concatenate((y_upper, y_lower[::-1])), # upper, then lower reversed
254 | fill='toself',
255 | fillcolor=f'rgba({color},0.2)',
256 | line=dict(color='rgba(255,255,255,0)'),
257 | hoverinfo="skip",
258 | name="sigma",
259 | legendgroup=f"{gamemode}-{mode}",
260 | showlegend=False,
261 | visible=None if gamemode == gamemodes[-1] else "legendonly",
262 | ),
263 | ]
264 | i += 1
265 |
266 | if len(fig_data) <= 0:
267 | return
268 |
269 | fig = go.Figure(fig_data)
270 | fig.update_layout(title="Rating", xaxis_title="Iteration", yaxis_title="TrueSkill")
271 |
272 | self.logger.log({
273 | "qualities": fig,
274 | }, commit=False)
275 |
276 | def _add_opponent(self, agent):
277 | latest_id = self.redis.get(LATEST_RATING_ID)
278 | prefix = f"{self.name}-v"
279 | if latest_id is None:
280 | version = 0
281 | else:
282 | latest_id = latest_id.decode("utf-8")
283 | version = int(latest_id.replace(prefix, "")) + 1
284 | key = f"{prefix}{version}"
285 |
286 | # Add to list
287 | self.redis.hset(OPPONENT_MODELS, key, agent)
288 |
289 | # Set quality
290 | for gamemode in self.gamemodes:
291 | ratings = get_rating(gamemode, None, self.redis)
292 |
293 | for mode in "stochastic", "deterministic":
294 | if latest_id is not None:
295 | latest_key = f"{latest_id}-{mode}"
296 | quality = Rating(ratings[latest_key].mu, self.default_sigma)
297 | else:
298 | quality = Rating(0, self.min_sigma) # First (basically random) agent is initialized at 0
299 |
300 | self.redis.hset(QUALITIES.format(gamemode), f"{key}-{mode}", _serialize(tuple(quality)))
301 |
302 | # Inform that new opponent is ready
303 | self.redis.set(LATEST_RATING_ID, key)
304 |
305 | def update_parameters(self, new_params):
306 | """
307 | update redis (and thus workers) with new model data and save data as future opponent
308 | :param new_params: new model parameters
309 | """
310 | model_bytes = _serialize_model(new_params)
311 | self.redis.set(MODEL_LATEST, model_bytes)
312 | self.redis.decr(VERSION_LATEST)
313 |
314 | print("Top contributors:\n" + "\n".join(f"\t{c}: {n}" for c, n in self.contributors.most_common(5)))
315 | self.logger.log({
316 | "redis/contributors": wandb.Table(columns=["name", "steps"], data=self.contributors.most_common())},
317 | commit=False
318 | )
319 | self._plot_ratings()
320 | tot_contributors = self.redis.hgetall(CONTRIBUTORS)
321 | tot_contributors = Counter({name: int(count) for name, count in tot_contributors.items()})
322 | tot_contributors += self.contributors
323 | if tot_contributors:
324 | self.redis.hset(CONTRIBUTORS, mapping=tot_contributors)
325 | self.contributors.clear()
326 |
327 | self.logger.log({"redis/rollout_bytes": self.tot_bytes}, commit=False)
328 | self.tot_bytes = 0
329 |
330 | n_updates = self.redis.incr(N_UPDATES) - 1
331 | # save_freq = int(self.redis.get(SAVE_FREQ))
332 |
333 | if n_updates > 0:
334 | self.logger.log({f"stat/{name}": value for name, value in self._get_stats().items()}, commit=False)
335 | self._reset_stats()
336 |
337 | if n_updates % self.model_freq == 0:
338 | print("Adding model to pool...")
339 | self._add_opponent(model_bytes)
340 |
341 | if n_updates % self.save_freq == 0:
342 | # self.redis.set(MODEL_N.format(self.n_updates // self.save_every), model_bytes)
343 | print("Saving model...")
344 | if self.lastsave_ts == self.redis.lastsave():
345 | print("redis save error, previous bgsave failed")
346 | self.lastsave_ts = self.redis.lastsave()
347 | try:
348 | self.redis.bgsave()
349 | except ResponseError:
350 | print("redis bgsave failed, auto save already in progress")
351 |
--------------------------------------------------------------------------------
/rocket_learn/ppo.py:
--------------------------------------------------------------------------------
1 | import cProfile
2 | import io
3 | import os
4 | import pstats
5 | import time
6 | import sys
7 | from typing import Iterator, List, Tuple, Union
8 |
9 | import numba
10 | import numpy as np
11 | import torch
12 | import torch as th
13 | from torch.distributions import kl_divergence
14 | from torch.nn import functional as F
15 | from torch.nn.utils import clip_grad_norm_
16 |
17 | from rocket_learn.agent.actor_critic_agent import ActorCriticAgent
18 | from rocket_learn.agent.policy import Policy
19 | from rocket_learn.experience_buffer import ExperienceBuffer
20 | from rocket_learn.rollout_generator.base_rollout_generator import BaseRolloutGenerator
21 |
22 |
23 | class PPO:
24 | """
25 | Proximal Policy Optimization algorithm (PPO)
26 |
27 | :param rollout_generator: Function that will generate the rollouts
28 | :param agent: An ActorCriticAgent
29 | :param n_steps: The number of steps to run per update
30 | :param gamma: Discount factor
31 | :param batch_size: batch size to break experience data into for training
32 | :param epochs: Number of epoch when optimizing the loss
33 | :param minibatch_size: size to break batch sets into (helps combat VRAM issues)
34 | :param clip_range: PPO Clipping parameter for the value function
35 | :param ent_coef: Entropy coefficient for the loss calculation
36 | :param gae_lambda: Factor for trade-off of bias vs variance for Generalized Advantage Estimator
37 | :param vf_coef: Value function coefficient for the loss calculation
38 | :param max_grad_norm: optional clip_grad_norm value
39 | :param logger: wandb logger to store run results
40 | :param device: torch device
41 | :param zero_grads_with_none: 0 gradient with None instead of 0
42 |
43 | Look here for info on zero_grads_with_none
44 | https://pytorch.org/docs/master/generated/torch.optim.Optimizer.zero_grad.html#torch.optim.Optimizer.zero_grad
45 | """
46 |
47 | def __init__(
48 | self,
49 | rollout_generator: BaseRolloutGenerator,
50 | agent: ActorCriticAgent,
51 | n_steps=4096,
52 | gamma=0.99,
53 | batch_size=512,
54 | epochs=10,
55 | # reuse=2,
56 | minibatch_size=None,
57 | clip_range=0.2,
58 | ent_coef=0.01,
59 | gae_lambda=0.95,
60 | vf_coef=1,
61 | max_grad_norm=0.5,
62 | logger=None,
63 | device="cuda",
64 | zero_grads_with_none=False,
65 | kl_models_weights: List[Union[Tuple[Policy, float], Tuple[Policy, float, float]]] = None
66 | ):
67 | self.rollout_generator = rollout_generator
68 |
69 | # TODO let users choose their own agent
70 | # TODO move agent to rollout generator
71 | self.agent = agent.to(device)
72 | self.device = device
73 | self.zero_grads_with_none = zero_grads_with_none
74 | self.frozen_iterations = 0
75 | self._saved_lr = None
76 |
77 | self.starting_iteration = 0
78 |
79 | # hyperparameters
80 | self.epochs = epochs
81 | self.gamma = gamma
82 | # assert n_steps % batch_size == 0
83 | # self.reuse = reuse
84 | self.n_steps = n_steps
85 | self.gae_lambda = gae_lambda
86 | self.batch_size = batch_size
87 | self.minibatch_size = minibatch_size or batch_size
88 | assert self.batch_size % self.minibatch_size == 0
89 | self.clip_range = clip_range
90 | self.ent_coef = ent_coef
91 | self.vf_coef = vf_coef
92 | self.max_grad_norm = max_grad_norm
93 |
94 | self.running_rew_mean = 0
95 | self.running_rew_var = 1
96 | self.running_rew_count = 1e-4
97 |
98 | self.total_steps = 0
99 | self.logger = logger
100 | self.logger.watch((self.agent.actor, self.agent.critic))
101 | self.timer = time.time_ns() // 1_000_000
102 | self.jit_tracer = None
103 |
104 | if kl_models_weights is not None:
105 | for i in range(len(kl_models_weights)):
106 | assert len(kl_models_weights[i]) in (2, 3)
107 | if len(kl_models_weights[i]) == 2:
108 | kl_models_weights[i] = kl_models_weights[i] + (None,)
109 | self.kl_models_weights = kl_models_weights
110 |
111 | def update_reward_norm(self, rewards: np.ndarray) -> np.ndarray:
112 | batch_mean = np.mean(rewards)
113 | batch_var = np.var(rewards)
114 | batch_count = rewards.shape[0]
115 |
116 | delta = batch_mean - self.running_rew_mean
117 | tot_count = self.running_rew_count + batch_count
118 |
119 | new_mean = self.running_rew_mean + delta * batch_count / tot_count
120 | m_a = self.running_rew_var * self.running_rew_count
121 | m_b = batch_var * batch_count
122 | m_2 = m_a + m_b + np.square(delta) * self.running_rew_count * batch_count / (
123 | self.running_rew_count + batch_count)
124 | new_var = m_2 / (self.running_rew_count + batch_count)
125 |
126 | new_count = batch_count + self.running_rew_count
127 |
128 | self.running_rew_mean = new_mean
129 | self.running_rew_var = new_var
130 | self.running_rew_count = new_count
131 |
132 | return (rewards - self.running_rew_mean) / np.sqrt(self.running_rew_var + 1e-8) # TODO normalize before update?
133 |
134 | def run(self, iterations_per_save=10, save_dir=None, save_jit=False):
135 | """
136 | Generate rollout data and train
137 | :param iterations_per_save: number of iterations between checkpoint saves
138 | :param save_dir: where to save
139 | """
140 | if save_dir:
141 | current_run_dir = os.path.join(save_dir, self.logger.project + "_" + str(time.time()))
142 | os.makedirs(current_run_dir)
143 | elif iterations_per_save and not save_dir:
144 | print("Warning: no save directory specified.")
145 | print("Checkpoints will not be saved.")
146 |
147 | iteration = self.starting_iteration
148 | rollout_gen = self.rollout_generator.generate_rollouts()
149 |
150 | self.rollout_generator.update_parameters(self.agent.actor)
151 |
152 | while True:
153 | # pr = cProfile.Profile()
154 | # pr.enable()
155 | t0 = time.time()
156 |
157 | def _iter():
158 | size = 0
159 | print(f"Collecting rollouts ({iteration})...")
160 | while size < self.n_steps:
161 | try:
162 | rollout = next(rollout_gen)
163 | if rollout.size() > 0:
164 | size += rollout.size()
165 | # progress.update(rollout.size())
166 | yield rollout
167 | except StopIteration:
168 | return
169 |
170 | self.calculate(_iter(), iteration)
171 | iteration += 1
172 |
173 | if save_dir:
174 | self.save(os.path.join(save_dir, self.logger.project + "_" + "latest"), -1, save_jit)
175 | if iteration % iterations_per_save == 0:
176 | self.save(current_run_dir, iteration, save_jit) # noqa
177 |
178 | if self.frozen_iterations > 0:
179 | if self.frozen_iterations == 1:
180 | print(" ** Unfreezing policy network **")
181 |
182 | assert self._saved_lr is not None
183 | self.agent.optimizer.param_groups[0]["lr"] = self._saved_lr
184 | self._saved_lr = None
185 |
186 | self.frozen_iterations -= 1
187 |
188 | self.rollout_generator.update_parameters(self.agent.actor)
189 |
190 | self.total_steps += self.n_steps # size
191 | t1 = time.time()
192 | self.logger.log({"ppo/steps_per_second": self.n_steps / (t1 - t0), "ppo/total_timesteps": self.total_steps})
193 |
194 | # pr.disable()
195 | # s = io.StringIO()
196 | # sortby = pstats.SortKey.CUMULATIVE
197 | # ps = pstats.Stats(pr, stream=s).sort_stats(sortby)
198 | # ps.dump_stats(f"profile_{self.total_steps}")
199 |
200 | def set_logger(self, logger):
201 | self.logger = logger
202 |
203 | def evaluate_actions(self, observations, actions):
204 | """
205 | Calculate Log Probability and Entropy of actions
206 | """
207 | dist = self.agent.actor.get_action_distribution(observations)
208 | # indices = self.agent.get_action_indices(dists)
209 |
210 | log_prob = self.agent.actor.log_prob(dist, actions)
211 | entropy = self.agent.actor.entropy(dist, actions)
212 |
213 | entropy = -torch.mean(entropy)
214 | return log_prob, entropy, dist
215 |
216 | @staticmethod
217 | @numba.njit
218 | def _calculate_advantages_numba(rewards, values, gamma, gae_lambda, truncated):
219 | advantages = np.zeros_like(rewards)
220 | # v_targets = np.zeros_like(rewards)
221 | dones = np.zeros_like(rewards)
222 | dones[-1] = 1. if not truncated else 0.
223 | episode_starts = np.zeros_like(rewards)
224 | episode_starts[0] = 1.
225 | last_values = values[-1]
226 | last_gae_lam = 0
227 | size = len(advantages)
228 | for step in range(size - 1, -1, -1):
229 | if step == size - 1:
230 | next_non_terminal = 1.0 - dones[-1].item()
231 | next_values = last_values
232 | else:
233 | next_non_terminal = 1.0 - episode_starts[step + 1].item()
234 | next_values = values[step + 1]
235 | v_target = rewards[step] + gamma * next_values * next_non_terminal
236 | delta = v_target - values[step]
237 | last_gae_lam = delta + gamma * gae_lambda * next_non_terminal * last_gae_lam
238 | advantages[step] = last_gae_lam
239 | # v_targets[step] = v_target
240 | return advantages # , v_targets
241 |
242 | def calculate(self, buffers: Iterator[ExperienceBuffer], iteration):
243 | """
244 | Calculate loss and update network
245 | """
246 | obs_tensors = []
247 | act_tensors = []
248 | # value_tensors = []
249 | log_prob_tensors = []
250 | # advantage_tensors = []
251 | returns_tensors = []
252 |
253 | rewards_tensors = []
254 |
255 | ep_rewards = []
256 | ep_steps = []
257 | n = 0
258 |
259 | for buffer in buffers: # Do discounts for each ExperienceBuffer individually
260 | if isinstance(buffer.observations[0], (tuple, list)):
261 | transposed = tuple(zip(*buffer.observations))
262 | obs_tensor = tuple(torch.from_numpy(np.vstack(t)).float() for t in transposed)
263 | else:
264 | obs_tensor = th.from_numpy(np.vstack(buffer.observations)).float()
265 |
266 | with th.no_grad():
267 | if isinstance(obs_tensor, tuple):
268 | x = tuple(o.to(self.device) for o in obs_tensor)
269 | else:
270 | x = obs_tensor.to(self.device)
271 | values = self.agent.critic(x).detach().cpu().numpy().flatten() # No batching?
272 |
273 | actions = np.stack(buffer.actions)
274 | log_probs = np.stack(buffer.log_probs)
275 | rewards = np.stack(buffer.rewards)
276 | dones = np.stack(buffer.dones)
277 |
278 | size = rewards.shape[0]
279 |
280 | advantages = self._calculate_advantages_numba(rewards, values, self.gamma, self.gae_lambda, dones[-1] == 2)
281 |
282 | returns = advantages + values
283 |
284 | obs_tensors.append(obs_tensor)
285 | act_tensors.append(th.from_numpy(actions))
286 | log_prob_tensors.append(th.from_numpy(log_probs))
287 | returns_tensors.append(th.from_numpy(returns))
288 | rewards_tensors.append(th.from_numpy(rewards))
289 |
290 | ep_rewards.append(rewards.sum())
291 | ep_steps.append(size)
292 | n += 1
293 | ep_rewards = np.array(ep_rewards)
294 | ep_steps = np.array(ep_steps)
295 |
296 | self.logger.log({
297 | "ppo/ep_reward_mean": ep_rewards.mean(),
298 | "ppo/ep_reward_std": ep_rewards.std(),
299 | "ppo/ep_len_mean": ep_steps.mean(),
300 | }, step=iteration, commit=False)
301 |
302 | if isinstance(obs_tensors[0], tuple):
303 | transposed = zip(*obs_tensors)
304 | obs_tensor = tuple(th.cat(t).float() for t in transposed)
305 | else:
306 | obs_tensor = th.cat(obs_tensors).float()
307 | act_tensor = th.cat(act_tensors)
308 | log_prob_tensor = th.cat(log_prob_tensors).float()
309 | # advantages_tensor = th.cat(advantage_tensors)
310 | returns_tensor = th.cat(returns_tensors).float()
311 |
312 | tot_loss = 0
313 | tot_policy_loss = 0
314 | tot_entropy_loss = 0
315 | tot_value_loss = 0
316 | total_kl_div = 0
317 | tot_clipped = 0
318 |
319 | if self.kl_models_weights is not None:
320 | tot_kl_other_models = np.zeros(len(self.kl_models_weights))
321 | tot_kl_coeffs = np.zeros(len(self.kl_models_weights))
322 |
323 | n = 0
324 |
325 | if self.jit_tracer is None:
326 | self.jit_tracer = obs_tensor[0].to(self.device)
327 |
328 | print("Training network...")
329 |
330 | if self.frozen_iterations > 0:
331 | print("Policy network frozen, only updating value network...")
332 |
333 | precompute = torch.cat([param.view(-1) for param in self.agent.actor.parameters()])
334 | t0 = time.perf_counter_ns()
335 | self.agent.optimizer.zero_grad(set_to_none=self.zero_grads_with_none)
336 | for e in range(self.epochs):
337 | # this is mostly pulled from sb3
338 | indices = torch.randperm(returns_tensor.shape[0])[:self.batch_size]
339 | if isinstance(obs_tensor, tuple):
340 | obs_batch = tuple(o[indices] for o in obs_tensor)
341 | else:
342 | obs_batch = obs_tensor[indices]
343 | act_batch = act_tensor[indices]
344 | log_prob_batch = log_prob_tensor[indices]
345 | # advantages_batch = advantages_tensor[indices]
346 | returns_batch = returns_tensor[indices]
347 |
348 | for i in range(0, self.batch_size, self.minibatch_size):
349 | # Note: Will cut off final few samples
350 |
351 | if isinstance(obs_tensor, tuple):
352 | obs = tuple(o[i: i + self.minibatch_size].to(self.device) for o in obs_batch)
353 | else:
354 | obs = obs_batch[i: i + self.minibatch_size].to(self.device)
355 |
356 | act = act_batch[i: i + self.minibatch_size].to(self.device)
357 | # adv = advantages_batch[i:i + self.minibatch_size].to(self.device)
358 | ret = returns_batch[i: i + self.minibatch_size].to(self.device)
359 |
360 | old_log_prob = log_prob_batch[i: i + self.minibatch_size].to(self.device)
361 |
362 | # TODO optimization: use forward_actor_critic instead of separate in case shared, also use GPU
363 | try:
364 | log_prob, entropy, dist = self.evaluate_actions(obs, act) # Assuming obs and actions as input
365 | except ValueError as e:
366 | print("ValueError in evaluate_actions", e)
367 | continue
368 |
369 | ratio = torch.exp(log_prob - old_log_prob)
370 |
371 | values_pred = self.agent.critic(obs)
372 |
373 | values_pred = th.squeeze(values_pred)
374 | adv = ret - values_pred.detach()
375 | adv = (adv - th.mean(adv)) / (th.std(adv) + 1e-8)
376 |
377 | # clipped surrogate loss
378 | policy_loss_1 = adv * ratio
379 | policy_loss_2 = adv * th.clamp(ratio, 1 - self.clip_range, 1 + self.clip_range)
380 | policy_loss = -torch.min(policy_loss_1, policy_loss_2).mean()
381 |
382 | # **If we want value clipping, add it here**
383 | value_loss = F.mse_loss(ret, values_pred)
384 |
385 | if entropy is None:
386 | # Approximate entropy when no analytical form
387 | entropy_loss = -th.mean(-log_prob)
388 | else:
389 | entropy_loss = entropy
390 |
391 | kl_loss = 0
392 | if self.kl_models_weights is not None:
393 | for k, (model, kl_coef, half_life) in enumerate(self.kl_models_weights):
394 | if half_life is not None:
395 | kl_coef *= 0.5 ** (self.total_steps / half_life)
396 | with torch.no_grad():
397 | dist_other = model.get_action_distribution(obs)
398 | div = kl_divergence(dist_other, dist).mean()
399 | tot_kl_other_models[k] += div
400 | tot_kl_coeffs[k] = kl_coef
401 | kl_loss += kl_coef * div
402 |
403 | loss = ((policy_loss + self.ent_coef * entropy_loss + self.vf_coef * value_loss + kl_loss)
404 | / (self.batch_size / self.minibatch_size))
405 |
406 | if not torch.isfinite(loss).all():
407 | print("Non-finite loss, skipping", n)
408 | print("\tPolicy loss:", policy_loss)
409 | print("\tEntropy loss:", entropy_loss)
410 | print("\tValue loss:", value_loss)
411 | print("\tTotal loss:", loss)
412 | print("\tRatio:", ratio)
413 | print("\tAdv:", adv)
414 | print("\tLog prob:", log_prob)
415 | print("\tOld log prob:", old_log_prob)
416 | print("\tEntropy:", entropy)
417 | print("\tActor has inf:", any(not p.isfinite().all() for p in self.agent.actor.parameters()))
418 | print("\tCritic has inf:", any(not p.isfinite().all() for p in self.agent.critic.parameters()))
419 | print("\tReward as inf:", not np.isfinite(ep_rewards).all())
420 | if isinstance(obs, tuple):
421 | for j in range(len(obs)):
422 | print(f"\tObs[{j}] has inf:", not obs[j].isfinite().all())
423 | else:
424 | print("\tObs has inf:", not obs.isfinite().all())
425 | continue
426 |
427 | loss.backward()
428 |
429 | # Unbiased low variance KL div estimator from http://joschu.net/blog/kl-approx.html
430 | total_kl_div += th.mean((ratio - 1) - (log_prob - old_log_prob)).item()
431 | tot_loss += loss.item()
432 | tot_policy_loss += policy_loss.item()
433 | tot_entropy_loss += entropy_loss.item()
434 | tot_value_loss += value_loss.item()
435 | tot_clipped += th.mean((th.abs(ratio - 1) > self.clip_range).float()).item()
436 | n += 1
437 | # pb.update(self.minibatch_size)
438 |
439 | # Clip grad norm
440 | if self.max_grad_norm is not None:
441 | clip_grad_norm_(self.agent.actor.parameters(), self.max_grad_norm)
442 |
443 | self.agent.optimizer.step()
444 | self.agent.optimizer.zero_grad(set_to_none=self.zero_grads_with_none)
445 |
446 | t1 = time.perf_counter_ns()
447 |
448 | assert n > 0
449 |
450 | postcompute = torch.cat([param.view(-1) for param in self.agent.actor.parameters()])
451 |
452 | log_dict = {
453 | "ppo/loss": tot_loss / n,
454 | "ppo/policy_loss": tot_policy_loss / n,
455 | "ppo/entropy_loss": tot_entropy_loss / n,
456 | "ppo/value_loss": tot_value_loss / n,
457 | "ppo/mean_kl": total_kl_div / n,
458 | "ppo/clip_fraction": tot_clipped / n,
459 | "ppo/epoch_time": (t1 - t0) / (1e6 * self.epochs),
460 | "ppo/update_magnitude": th.dist(precompute, postcompute, p=2),
461 | }
462 |
463 | if self.kl_models_weights is not None and len(self.kl_models_weights) > 0:
464 | log_dict.update({f"ppo/kl_div_model_{i}": tot_kl_other_models[i] / n
465 | for i in range(len(self.kl_models_weights))})
466 | log_dict.update({f"ppo/kl_coeff_model_{i}": tot_kl_coeffs[i]
467 | for i in range(len(self.kl_models_weights))})
468 |
469 | self.logger.log(log_dict, step=iteration, commit=False) # Is committed after when calculating fps
470 |
471 | def load(self, load_location, continue_iterations=True):
472 | """
473 | load the model weights, optimizer values, and metadata
474 | :param load_location: checkpoint folder to read
475 | :param continue_iterations: keep the same training steps
476 | """
477 |
478 | checkpoint = torch.load(load_location)
479 | self.agent.actor.load_state_dict(checkpoint['actor_state_dict'])
480 | self.agent.critic.load_state_dict(checkpoint['critic_state_dict'])
481 | self.agent.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
482 |
483 | if continue_iterations:
484 | self.starting_iteration = checkpoint['epoch']
485 | self.total_steps = checkpoint["total_steps"]
486 | print("Continuing training at iteration " + str(self.starting_iteration))
487 |
488 | def save(self, save_location, current_step, save_actor_jit=False):
489 | """
490 | Save the model weights, optimizer values, and metadata
491 | :param save_location: where to save
492 | :param current_step: the current iteration when saved. Use to later continue training
493 | :param save_actor_jit: save the policy network as a torch jit file for rlbot use
494 | """
495 |
496 | version_str = str(self.logger.project) + "_" + str(current_step)
497 | version_dir = save_location + "\\" + version_str
498 |
499 | os.makedirs(version_dir, exist_ok=current_step == -1)
500 |
501 | torch.save({
502 | 'epoch': current_step,
503 | "total_steps": self.total_steps,
504 | 'actor_state_dict': self.agent.actor.state_dict(),
505 | 'critic_state_dict': self.agent.critic.state_dict(),
506 | 'optimizer_state_dict': self.agent.optimizer.state_dict(),
507 | # TODO save/load reward normalization mean, std, count
508 | }, version_dir + "\\checkpoint.pt")
509 |
510 | if save_actor_jit:
511 | traced_actor = th.jit.trace(self.agent.actor, self.jit_tracer)
512 | torch.jit.save(traced_actor, version_dir + "\\jit_policy.jit")
513 |
514 | def freeze_policy(self, frozen_iterations=100):
515 | """
516 | Freeze policy network to allow value network to settle. Useful with pretrained policy networks.
517 |
518 | Note that network weights will not be transmitted when frozen.
519 |
520 | :param frozen_iterations: how many iterations the policy update will remain unchanged
521 | """
522 |
523 | print("-------------------------------------------------------------")
524 | print("Policy Weights frozen for " + str(frozen_iterations) + " iterations")
525 | print("-------------------------------------------------------------")
526 |
527 | self.frozen_iterations = frozen_iterations
528 |
529 | self._saved_lr = self.agent.optimizer.param_groups[0]["lr"]
530 | self.agent.optimizer.param_groups[0]["lr"] = 0
531 |
--------------------------------------------------------------------------------
/rocket_learn/rollout_generator/redis/redis_rollout_worker.py:
--------------------------------------------------------------------------------
1 | import functools
2 | import itertools
3 | import os
4 | import time
5 | from threading import Thread
6 | from uuid import uuid4
7 |
8 | import sqlite3 as sql
9 |
10 | import numpy as np
11 | from redis import Redis
12 | from rlgym.envs import Match
13 | from rlgym.gamelaunch import LaunchPreference
14 | from rlgym.gym import Gym
15 | from tabulate import tabulate
16 |
17 | import rocket_learn.agent.policy
18 | import rocket_learn.utils.generate_episode
19 | from rocket_learn.rollout_generator.redis.utils import _unserialize_model, MODEL_LATEST, WORKER_IDS, OPPONENT_MODELS, \
20 | VERSION_LATEST, _serialize, ROLLOUTS, encode_buffers, decode_buffers, get_rating, LATEST_RATING_ID, \
21 | EXPERIENCE_PER_MODE
22 | from rocket_learn.utils.util import probability_NvsM
23 | from rocket_learn.utils.dynamic_gamemode_setter import DynamicGMSetter
24 |
25 |
26 | class RedisRolloutWorker:
27 | """
28 | Provides RedisRolloutGenerator with rollouts via a Redis server
29 |
30 | :param redis: redis object
31 | :param name: rollout worker name
32 | :param match: match object
33 | :param past_version_prob: Odds of playing against previous checkpoints
34 | :param evaluation_prob: Odds of running an evaluation match
35 | :param sigma_target: Trueskill sigma target
36 | :param dynamic_gm: Pick game mode dynamically. If True, Match.team_size should be 3
37 | :param streamer_mode: Should run in streamer mode (less data printed to screen)
38 | :param send_gamestates: Should gamestate data be sent back (increases data sent) - must send obs or gamestates
39 | :param send_obs: Should observations be send back (increases data sent) - must send obs or gamestates
40 | :param scoreboard: Scoreboard object
41 | :param pretrained_agents: Dict{} of pretrained agents and their appearance probability
42 | :param human_agent: human agent object. Sets a human match if not None
43 | :param force_paging: Should paging be forced
44 | :param auto_minimize: automatically minimize the launched rocket league instance
45 | :param local_cache_name: name of local database used for model caching. If None, caching is not used
46 | :param gamemode_weights: dict of dynamic gamemode choice weights. If None, default equal experience
47 | """
48 |
49 | def __init__(self, redis: Redis, name: str, match: Match,
50 | past_version_prob=.2, evaluation_prob=0.01, sigma_target=1,
51 | dynamic_gm=True, streamer_mode=False, send_gamestates=True,
52 | send_obs=True, scoreboard=None, pretrained_agents=None,
53 | human_agent=None, force_paging=False, auto_minimize=True,
54 | local_cache_name=None, gamemode_weights=None, full_team_evaluations=False,
55 | live_progress=True):
56 | # TODO model or config+params so workers can recreate just from redis connection?
57 | self.redis = redis
58 | self.name = name
59 |
60 | assert send_gamestates or send_obs, "Must have at least one of obs or states"
61 |
62 | self.pretrained_agents = {}
63 | self.pretrained_total_prob = 0
64 | if pretrained_agents is not None:
65 | self.pretrained_agents = pretrained_agents
66 | self.pretrained_total_prob = sum([self.pretrained_agents[key] for key in self.pretrained_agents])
67 |
68 | self.human_agent = human_agent
69 |
70 | if human_agent and pretrained_agents:
71 | print("** WARNING - Human Player and Pretrained Agents are in conflict. **")
72 | print("** Pretrained Agents will be ignored. **")
73 |
74 | self.streamer_mode = streamer_mode
75 |
76 | self.current_agent = _unserialize_model(self.redis.get(MODEL_LATEST))
77 |
78 | self.past_version_prob = past_version_prob
79 | self.evaluation_prob = evaluation_prob
80 | self.sigma_target = sigma_target
81 | self.send_gamestates = send_gamestates
82 | self.send_obs = send_obs
83 | self.dynamic_gm = dynamic_gm
84 | self.gamemode_weights = gamemode_weights
85 | if self.gamemode_weights is not None:
86 | assert np.isclose(sum(self.gamemode_weights.values()), 1), "gamemode_weights must sum to 1"
87 | self.gamemode_exp_per_episode_ema = {}
88 | self.local_cache_name = local_cache_name
89 |
90 | self.full_team_evaluations = full_team_evaluations
91 |
92 | self.live_progress = live_progress
93 |
94 | self.uuid = str(uuid4())
95 | self.redis.rpush(WORKER_IDS, self.uuid)
96 |
97 | # currently doesn't rebuild, if the old is there, reuse it.
98 | if self.local_cache_name:
99 | self.sql = sql.connect('redis-model-cache-' + local_cache_name + '.db')
100 | # if the table doesn't exist in the database, make it
101 | self.sql.execute("""
102 | CREATE TABLE if not exists MODELS (
103 | id TEXT PRIMARY KEY,
104 | parameters BLOB NOT NULL
105 | );
106 | """)
107 |
108 | if not self.streamer_mode:
109 | print("Started worker", self.uuid, "on host", self.redis.connection_pool.connection_kwargs.get("host"),
110 | "under name", name) # TODO log instead
111 | else:
112 | print("Streaming mode set. Running silent.")
113 |
114 | self.scoreboard = scoreboard
115 | state_setter = DynamicGMSetter(match._state_setter) # noqa Rangler made me do it
116 | self.set_team_size = state_setter.set_team_size
117 | match._state_setter = state_setter
118 | self.match = match
119 | self.env = Gym(match=self.match, pipe_id=os.getpid(), launch_preference=LaunchPreference.EPIC,
120 | use_injector=True, force_paging=force_paging, raise_on_crash=True, auto_minimize=auto_minimize)
121 | self.total_steps_generated = 0
122 |
123 | def _get_opponent_ids(self, n_new, n_old, pretrained_choice):
124 | # Get qualities
125 | assert (n_new + n_old) % 2 == 0
126 | per_team = (n_new + n_old) // 2
127 | gamemode = f"{per_team}v{per_team}"
128 | latest_id = self.redis.get(LATEST_RATING_ID).decode("utf-8")
129 | latest_key = f"{latest_id}-stochastic"
130 | if n_old == 0:
131 | rating = get_rating(gamemode, latest_key, self.redis)
132 | return [-1] * n_new, [rating] * n_new
133 |
134 | ratings = get_rating(gamemode, None, self.redis)
135 | latest_rating = ratings[latest_key]
136 | keys, values = zip(*ratings.items())
137 |
138 | is_eval = (n_new == 0 and len(values) >= n_old)
139 | if is_eval: # Evaluation game, try to find agents with high sigma
140 | sigmas = np.array([r.sigma for r in values])
141 | probs = np.clip(sigmas - self.sigma_target, a_min=0, a_max=None)
142 | s = probs.sum()
143 | if s == 0: # No versions with high sigma available
144 | if np.random.normal(0, self.sigma_target) > 1:
145 | # Some chance of doing a match with random versions, so they might correct themselves
146 | probs = np.ones_like(probs) / len(probs)
147 | else:
148 | return [-1] * n_old, [latest_rating] * n_old
149 | else:
150 | probs /= s
151 | versions = [np.random.choice(len(keys), p=probs)]
152 | if self.full_team_evaluations:
153 | versions = versions * per_team
154 | target_rating = values[versions[0]]
155 | elif pretrained_choice is not None: # pretrained agent chosen, just need index generation
156 | matchups = np.full((n_new + n_old), -1).tolist()
157 | for i in range(n_old):
158 | index = np.random.randint(0, n_new + n_old)
159 | matchups[index] = 'na'
160 | return matchups, ratings.values()
161 | else:
162 | if n_new == 0: # Would-be evaluation game, but not enough agents
163 | n_new = n_old
164 | n_old = 0
165 | versions = [-1] * n_new
166 | target_rating = latest_rating
167 |
168 | # Calculate 1v1 win prob against target
169 | # All the agents included should hold their own (at least approximately)
170 | # This is to prevent unrealistic scenarios,
171 | # like for instance ratings of [100, 0] vs [100, 0], which is technically fair but not useful
172 | probs = np.zeros(len(keys))
173 | if n_new == 0 and self.full_team_evaluations:
174 | for i, rating in enumerate(values):
175 | if i == versions[0]:
176 | p = 0 # Don't add more of the same agent in evaluation matches
177 | else:
178 | p = probability_NvsM([rating] * per_team, [target_rating] * per_team)
179 | probs[i] = p * (1 - p)
180 | probs /= probs.sum()
181 | opponent = np.random.choice(len(probs), p=probs)
182 | if np.random.random() < 0.5: # Randomly do blue/orange
183 | versions = versions + [opponent] * per_team
184 | else:
185 | versions = [opponent] * per_team + versions
186 | return [keys[i] for i in versions], [values[i] for i in versions]
187 | else:
188 | for i, rating in enumerate(values):
189 | if n_new == 0 and i == versions[0]:
190 | continue # Don't add more of the same agent in evaluation matches
191 | p = probability_NvsM([rating], [target_rating])
192 | probs[i] = (p * (1 - p)) ** ((n_new + n_old) // 2) # Be less lenient the more players there are
193 | probs /= probs.sum()
194 |
195 | old_versions = np.random.choice(len(probs), size=n_old - is_eval, p=probs, replace=True).tolist()
196 | versions += old_versions
197 |
198 | # Then calculate the full matchup, with just permutations of the selected versions (weighted by fairness)
199 | matchups = []
200 | qualities = []
201 | for perm in itertools.permutations(versions):
202 | it_ratings = [latest_rating if v == -1 else values[v] for v in perm]
203 | mid = len(it_ratings) // 2
204 | p = probability_NvsM(it_ratings[:mid], it_ratings[mid:])
205 | if n_new == 0 and set(perm[:mid]) == set(perm[mid:]): # Don't want team against team
206 | p = 0
207 | matchups.append(perm)
208 | qualities.append(p * (1 - p)) # From AlphaStar
209 | qualities = np.array(qualities)
210 | s = qualities.sum()
211 | if s == 0:
212 | return [-1] * (n_new + n_old), [latest_rating] * (n_new + n_old)
213 | k = np.random.choice(len(matchups), p=qualities / s)
214 | return [-1 if i == -1 else keys[i] for i in matchups[k]], \
215 | [latest_rating if i == -1 else values[i] for i in matchups[k]]
216 |
217 | @functools.lru_cache(maxsize=8)
218 | def _get_past_model(self, version):
219 | # if version in local database, query from database
220 | # if not, pull from REDIS and store in disk cache
221 |
222 | if self.local_cache_name:
223 | models = self.sql.execute("SELECT parameters FROM MODELS WHERE id == ?", (version,)).fetchall()
224 | if len(models) == 0:
225 | bytestream = self.redis.hget(OPPONENT_MODELS, version)
226 | model = _unserialize_model(bytestream)
227 |
228 | self.sql.execute('INSERT INTO MODELS (id, parameters) VALUES (?, ?)', (version, bytestream))
229 | self.sql.commit()
230 | else:
231 | # should only ever be 1 version of parameters
232 | assert len(models) <= 1
233 | # stored as tuple due to sqlite,
234 | assert len(models[0]) == 1
235 |
236 | bytestream = models[0][0]
237 | model = _unserialize_model(bytestream)
238 | else:
239 | model = _unserialize_model(self.redis.hget(OPPONENT_MODELS, version))
240 |
241 | return model
242 |
243 | def select_gamemode(self, equal_likelihood):
244 | mode_exp = {m.decode("utf-8"): int(v) for m, v in self.redis.hgetall(EXPERIENCE_PER_MODE).items()}
245 | modes = list(mode_exp.keys())
246 | if equal_likelihood:
247 | mode = np.random.choice(modes)
248 | else:
249 | dist = np.array(list(mode_exp.values())) + 1
250 | dist = dist / dist.sum()
251 | if self.gamemode_weights is None:
252 | target_dist = np.ones(len(modes))
253 | else:
254 | target_dist = np.array([self.gamemode_weights[k] for k in modes])
255 | mode_steps_per_episode = np.array(list(self.gamemode_exp_per_episode_ema.get(m, None) or 1 for m in modes))
256 |
257 | target_dist = target_dist / mode_steps_per_episode
258 | target_dist = target_dist / target_dist.sum()
259 | inv_dist = 1 - dist
260 | inv_dist = inv_dist / inv_dist.sum()
261 |
262 | dist = target_dist * inv_dist
263 | dist = dist / dist.sum()
264 |
265 | mode = np.random.choice(modes, p=dist)
266 |
267 | b, o = mode.split("v")
268 | return int(b), int(o)
269 |
270 | @staticmethod
271 | def make_table(versions, ratings, blue, orange, pretrained_choice):
272 | version_info = []
273 | for v, r in zip(versions, ratings):
274 | if pretrained_choice is not None and v == 'na': # print name but don't send it back
275 | version_info.append([str(type(pretrained_choice).__name__), "N/A"])
276 | elif v == 'na':
277 | version_info.append(['Human', "N/A"])
278 | else:
279 | if isinstance(v, int) and v < 0:
280 | v = f"Latest ({-v})"
281 | version_info.append([v, f"{r.mu:.2f}±{2 * r.sigma:.2f}"])
282 |
283 | blue_versions, blue_ratings = list(zip(*version_info[:blue]))
284 | orange_versions, orange_ratings = list(zip(*version_info[blue:]))
285 |
286 | if blue < orange:
287 | blue_versions += [""] * (orange - blue)
288 | blue_ratings += [""] * (orange - blue)
289 | elif orange < blue:
290 | orange_versions += [""] * (blue - orange)
291 | orange_ratings += [""] * (blue - orange)
292 |
293 | table_str = tabulate(list(zip(blue_versions, blue_ratings, orange_versions, orange_ratings)),
294 | headers=["Blue", "rating", "Orange", "rating"], tablefmt="rounded_outline")
295 |
296 | return table_str
297 |
298 | def run(self): # Mimics Thread
299 | """
300 | begin processing in already launched match and push to redis
301 | """
302 | n = 0
303 | latest_version = None
304 | # t = Thread()
305 | # t.start()
306 | while True:
307 | # Get the most recent version available
308 | available_version = self.redis.get(VERSION_LATEST)
309 | if available_version is None:
310 | time.sleep(1)
311 | continue # Wait for version to be published (not sure if this is necessary?)
312 | available_version = int(available_version)
313 |
314 | # Only try to download latest version when new
315 | if latest_version != available_version:
316 | model_bytes = self.redis.get(MODEL_LATEST)
317 | if model_bytes is None:
318 | time.sleep(1)
319 | continue # This is maybe not necessary? Can't hurt to leave it in.
320 | latest_version = available_version
321 | updated_agent = _unserialize_model(model_bytes)
322 | self.current_agent = updated_agent
323 |
324 | n += 1
325 | pretrained_choice = None
326 |
327 | evaluate = np.random.random() < self.evaluation_prob
328 |
329 | if self.dynamic_gm:
330 | blue, orange = self.select_gamemode(equal_likelihood=evaluate or self.streamer_mode)
331 | elif self.match._spawn_opponents is False:
332 | blue = self.match.agents
333 | orange = 0
334 | else:
335 | blue = orange = self.match.agents // 2
336 | self.set_team_size(blue, orange)
337 |
338 | if self.human_agent:
339 | n_new = blue + orange - 1
340 | versions = ['na']
341 |
342 | agents = [self.human_agent]
343 | for n in range(n_new):
344 | agents.append(self.current_agent)
345 | versions.append(-1)
346 |
347 | versions = [v if v != -1 else latest_version for v in versions]
348 | ratings = ["na"] * len(versions)
349 | else:
350 | # TODO customizable past agent selection, should team only be same agent?
351 | agents, pretrained_choice, versions, ratings = self._generate_matchup(blue + orange,
352 | latest_version,
353 | pretrained_choice,
354 | evaluate)
355 |
356 | evaluate = not any(isinstance(v, int) and v < 0 for v in versions) # Might be changed in matchup code
357 |
358 | table_str = self.make_table(versions, ratings, blue, orange, pretrained_choice)
359 |
360 | if evaluate and not self.streamer_mode and self.human_agent is None:
361 | print("EVALUATION GAME\n" + table_str)
362 | result = rocket_learn.utils.generate_episode.generate_episode(self.env, agents, evaluate=True,
363 | scoreboard=self.scoreboard,
364 | progress=self.live_progress)
365 | rollouts = []
366 | print("Evaluation finished, goal differential:", result)
367 | print()
368 | else:
369 | if not self.streamer_mode:
370 | print("ROLLOUT\n" + table_str)
371 |
372 | try:
373 | rollouts, result = rocket_learn.utils.generate_episode.generate_episode(self.env, agents,
374 | evaluate=False,
375 | scoreboard=self.scoreboard,
376 | progress=self.live_progress)
377 |
378 | if len(rollouts[0].observations) <= 1: # Happens sometimes, unknown reason
379 | print(" ** Rollout Generation Error: Restarting Generation ** ")
380 | print()
381 | continue
382 | except EnvironmentError:
383 | self.env.attempt_recovery()
384 | continue
385 |
386 | state = rollouts[0].infos[-2]["state"]
387 | goal_speed = np.linalg.norm(state.ball.linear_velocity) * 0.036 # kph
388 | str_result = ('+' if result > 0 else "") + str(result)
389 | episode_exp = len(rollouts[0].observations) * len(rollouts)
390 | self.total_steps_generated += episode_exp
391 |
392 | if self.dynamic_gm and not evaluate:
393 | mode = f"{blue}v{orange}"
394 | if mode in self.gamemode_exp_per_episode_ema:
395 | current_mean = self.gamemode_exp_per_episode_ema[mode]
396 | self.gamemode_exp_per_episode_ema[mode] = 0.98 * current_mean + 0.02 * episode_exp
397 | else:
398 | self.gamemode_exp_per_episode_ema[mode] = episode_exp
399 |
400 | post_stats = f"Rollout finished after {len(rollouts[0].observations)} steps ({self.total_steps_generated} total steps), result was {str_result}"
401 | if result != 0:
402 | post_stats += f", goal speed: {goal_speed:.2f} kph"
403 |
404 | if not self.streamer_mode:
405 | print(post_stats)
406 | print()
407 |
408 | if not self.streamer_mode:
409 | rollout_data = encode_buffers(rollouts,
410 | return_obs=self.send_obs,
411 | return_states=self.send_gamestates,
412 | return_rewards=True)
413 | # sanity_check = decode_buffers(rollout_data, versions,
414 | # has_obs=False, has_states=True, has_rewards=True,
415 | # obs_build_factory=lambda: self.match._obs_builder,
416 | # rew_func_factory=lambda: self.match._reward_fn,
417 | # act_parse_factory=lambda: self.match._action_parser)
418 | rollout_bytes = _serialize((rollout_data, versions, self.uuid, self.name, result,
419 | self.send_obs, self.send_gamestates, True))
420 |
421 | # while True:
422 | # t.join()
423 |
424 | def send():
425 | n_items = self.redis.rpush(ROLLOUTS, rollout_bytes)
426 | if n_items >= 1000:
427 | print("Had to limit rollouts. Learner may have have crashed, or is overloaded")
428 | self.redis.ltrim(ROLLOUTS, -100, -1)
429 |
430 | send()
431 | # t = Thread(target=send)
432 | # t.start()
433 | # time.sleep(0.01)
434 |
435 | def _generate_matchup(self, n_agents, latest_version, pretrained_choice, evaluate):
436 | if evaluate:
437 | n_old = n_agents
438 | else:
439 | n_old = 0
440 | rand_choice = np.random.random()
441 | if rand_choice < self.past_version_prob:
442 | n_old = np.random.randint(low=1, high=n_agents)
443 | elif rand_choice < (self.past_version_prob + self.pretrained_total_prob):
444 | wheel_prob = self.past_version_prob
445 | for agent in self.pretrained_agents:
446 | wheel_prob += self.pretrained_agents[agent]
447 | if rand_choice < wheel_prob:
448 | pretrained_choice = agent
449 | n_old = np.random.randint(low=1, high=n_agents)
450 | break
451 | n_new = n_agents - n_old
452 | versions, ratings = self._get_opponent_ids(n_new, n_old, pretrained_choice)
453 | agents = []
454 | for version in versions:
455 | if version == -1:
456 | agents.append(self.current_agent)
457 | elif pretrained_choice is not None and version == 'na':
458 | agents.append(pretrained_choice)
459 | else:
460 | selected_agent = self._get_past_model("-".join(version.split("-")[:-1]))
461 | if version.endswith("deterministic"):
462 | selected_agent.deterministic = True
463 | elif version.endswith("stochastic"):
464 | selected_agent.deterministic = False
465 | else:
466 | raise ValueError("Unknown version type")
467 | agents.append(selected_agent)
468 | if self.streamer_mode > 1:
469 | agents[-1].deterministic = True
470 | versions = [v if v != -1 else latest_version for v in versions]
471 | return agents, pretrained_choice, versions, ratings
472 |
--------------------------------------------------------------------------------