├── train ├── traces │ ├── 12mbps.trace │ ├── 0.12mbps.trace │ ├── 60mbps.trace │ ├── 108mbps.trace │ └── 10-every-200.trace ├── README.md ├── constants.py ├── common.py ├── state.py ├── models.py ├── experiments.yml └── utils.py ├── .flake8 ├── figures ├── rl_agent.png └── training_architecture.png ├── config ├── hydra │ ├── launcher │ │ ├── _submitit_local.yaml │ │ ├── _joblib.yaml │ │ └── _submitit_slurm.yaml │ └── sweeper │ │ └── _nevergrad.yaml └── config.yaml ├── .pre-commit-config.yaml ├── CODE_OF_CONDUCT.md ├── third-party ├── gala │ ├── __init__.py │ ├── gala_a2c.py │ ├── utils.py │ ├── distributions.py │ ├── arguments.py │ ├── gpu_gossip_buffer.py │ ├── envs.py │ ├── model.py │ └── graph_manager.py └── CMakeLists.txt ├── .gitmodules ├── traffic_gen ├── CMakeLists.txt ├── Utils.h ├── ExampleServer.h ├── ExampleHandler.h ├── ExampleClient.h └── main.cpp ├── congestion_control ├── CongestionControlRandomEnv.h ├── CongestionControlFixedCwndEnv.h ├── RLCongestionControllerFactory.h ├── CongestionControlLocalEnv.h ├── CMakeLists.txt ├── Utils.cpp ├── CongestionControlEnvConfig.cpp ├── CongestionControlEnvFactory.h ├── Utils.h ├── CongestionControlRPCEnv.h ├── CongestionControlFixedCwndEnv.cpp ├── NetworkState.h ├── RLCongestionController.h ├── RLBandwidthSampler.h ├── CongestionControlEnvConfig.h ├── NetworkState.cpp ├── CongestionControlLocalEnv.cpp ├── RLBandwidthSampler.cpp ├── CongestionControlEnv.h ├── CongestionControlRPCEnv.cpp ├── RLCongestionController.cpp └── CongestionControlEnv.cpp ├── scripts ├── clean_pantheon_logs.sh ├── plotting │ └── util.py └── get_tperf_args.py ├── CONTRIBUTING.md ├── .gitignore ├── Dockerfile ├── CMakeLists.txt └── README.md /train/traces/12mbps.trace: -------------------------------------------------------------------------------- 1 | 1 2 | -------------------------------------------------------------------------------- /train/traces/0.12mbps.trace: 
-------------------------------------------------------------------------------- 1 | 100 2 | -------------------------------------------------------------------------------- /.flake8: -------------------------------------------------------------------------------- 1 | [flake8] 2 | exclude = .git,*third-party* 3 | -------------------------------------------------------------------------------- /train/traces/60mbps.trace: -------------------------------------------------------------------------------- 1 | 1 2 | 1 3 | 1 4 | 1 5 | 1 6 | -------------------------------------------------------------------------------- /train/traces/108mbps.trace: -------------------------------------------------------------------------------- 1 | 1 2 | 1 3 | 1 4 | 1 5 | 1 6 | 1 7 | 1 8 | 1 9 | 1 10 | -------------------------------------------------------------------------------- /figures/rl_agent.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/mvfst-rl/HEAD/figures/rl_agent.png -------------------------------------------------------------------------------- /config/hydra/launcher/_submitit_local.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - submitit_local 3 | 4 | timeout_min: 180 5 | -------------------------------------------------------------------------------- /train/traces/10-every-200.trace: -------------------------------------------------------------------------------- 1 | 200 2 | 200 3 | 200 4 | 200 5 | 200 6 | 200 7 | 200 8 | 200 9 | 200 10 | 200 11 | -------------------------------------------------------------------------------- /figures/training_architecture.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/mvfst-rl/HEAD/figures/training_architecture.png -------------------------------------------------------------------------------- 
/config/hydra/launcher/_joblib.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - joblib 3 | 4 | #batch_size: 1 5 | n_jobs: 1 # ensure we run jobs sequentially 6 | #pre_dispatch: 0 7 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | repos: 2 | - repo: https://github.com/ambv/black 3 | rev: stable 4 | hooks: 5 | - id: black 6 | language_version: python3.8 7 | -------------------------------------------------------------------------------- /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | # Code of Conduct 2 | 3 | Facebook has adopted a Code of Conduct that we expect project participants to adhere to. 4 | Please read the [full text](https://code.fb.com/codeofconduct/) 5 | so that you can understand what actions will and will not be tolerated. 6 | -------------------------------------------------------------------------------- /third-party/gala/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | # 7 | from .gala_a2c import GALA_A2C 8 | -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "third-party/mvfst"] 2 | path = third-party/mvfst 3 | url = https://github.com/odelalleau/mvfst.git 4 | [submodule "third-party/torchbeast"] 5 | path = third-party/torchbeast 6 | url = https://github.com/facebookresearch/torchbeast.git 7 | branch = mvfst-rl 8 | -------------------------------------------------------------------------------- /config/hydra/launcher/_submitit_slurm.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - submitit_slurm 3 | 4 | cpus_per_task: ${get_cpus_per_task:${mode}, ${num_actors}, ${test_job_ids}, ${test_after_train}, ${max_jobs}} 5 | gpus_per_node: 2 6 | mem_gb: 64 7 | nodes: 1 8 | partition: learnfair 9 | tasks_per_node: 1 10 | timeout_min: 900 11 | constraint: ${get_slurm_constraint:${.partition},${.gpus_per_node}} 12 | -------------------------------------------------------------------------------- /train/README.md: -------------------------------------------------------------------------------- 1 | The train directory contains all the tools needed to train RL-based congestion control using 2 | IMPALA with [Pantheon](https://github.com/StanfordSNR/pantheon) as the network emulator. 3 | 4 | Each Pantheon environment instance can be configured to run with a different emulated network setting obtained from 5 | https://github.com/StanfordSNR/observatory/blob/master/src/scripts/experiments.yml. The relevant 6 | trace files in traces/ are copied from https://github.com/StanfordSNR/observatory/tree/master/traces. 7 | -------------------------------------------------------------------------------- /traffic_gen/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. 
and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # 7 | add_executable(traffic_gen main.cpp) 8 | 9 | target_compile_options( 10 | traffic_gen 11 | PRIVATE 12 | ${_QUIC_COMMON_COMPILE_OPTIONS} 13 | ) 14 | 15 | add_dependencies( 16 | traffic_gen 17 | mvfst 18 | rl_congestion_control 19 | ) 20 | 21 | target_include_directories( 22 | traffic_gen PUBLIC 23 | ${CMAKE_CURRENT_SOURCE_DIR} 24 | ) 25 | 26 | target_link_libraries( 27 | traffic_gen PUBLIC 28 | mvfst 29 | rl_congestion_control 30 | ) 31 | -------------------------------------------------------------------------------- /train/constants.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # 7 | import os 8 | from os import path 9 | 10 | 11 | # Paths. 12 | SRC_DIR = path.abspath(path.join(path.dirname(__file__), os.pardir)) 13 | PANTHEON_ROOT = path.join(SRC_DIR, "_build/deps/pantheon") 14 | CONF_ROOT = path.join(SRC_DIR, "config") 15 | EXPERIMENTS_CFG = path.join(SRC_DIR, "train/experiments.yml") 16 | THIRD_PARTY_ROOT = path.join(SRC_DIR, "third-party") 17 | TORCHBEAST_ROOT = path.join(SRC_DIR, "third-party/torchbeast") 18 | GALA_ROOT = path.join(SRC_DIR, "third-party/gala") 19 | 20 | 21 | # Numbers. 22 | UDP_SEND_PACKET_LEN = 1252 # should match QuicConstants.h 23 | -------------------------------------------------------------------------------- /congestion_control/CongestionControlRandomEnv.h: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) Facebook, Inc. and its affiliates. 3 | * All rights reserved. 
4 | * 5 | * This source code is licensed under the license found in the 6 | * LICENSE file in the root directory of this source tree. 7 | * 8 | */ 9 | #pragma once 10 | 11 | #include "CongestionControlEnv.h" 12 | 13 | #include 14 | 15 | namespace quic { 16 | 17 | class CongestionControlRandomEnv : public CongestionControlEnv { 18 | public: 19 | CongestionControlRandomEnv(const Config& cfg, Callback* cob, 20 | const QuicConnectionStateBase& conn) 21 | : CongestionControlEnv(cfg, cob, conn) {} 22 | 23 | private: 24 | // CongestionControlEnv impl 25 | void onObservation(Observation&& obs, float reward) override { 26 | // Random action 27 | Action action; 28 | action.cwndAction = std::rand() % cfg_.actions.size(); 29 | onAction(action); 30 | } 31 | }; 32 | 33 | } // namespace quic 34 | -------------------------------------------------------------------------------- /train/common.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # 7 | 8 | from dataclasses import dataclass 9 | from typing import Any, Dict 10 | 11 | from train.utils import StrEnum 12 | 13 | 14 | @dataclass 15 | class ConfigCommon(Dict[str, Any]): 16 | # Mode to run in. 17 | mode: StrEnum("Mode", "train, test, trace") = "train" 18 | # Number of parallel actors for training (ignored during testing). 19 | num_actors: int = 40 20 | # RL server address: either HOST:PORT or unix:PATH (the default uses a unix socket). 21 | server_address: str = "unix:/tmp/rl_server_path" 22 | # Output directory for Pantheon and TorchBeast logs. 23 | logdir: str = "/tmp/logs" 24 | # File to write torchscript traced model to (for training) or read from 25 | # (for local testing).
26 | traced_model: str = "traced_model.pt" 27 | -------------------------------------------------------------------------------- /config/config.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - base_config 3 | - override hydra/launcher: _submitit_slurm 4 | 5 | # batch_size: 8 # default=8 6 | # cc_env_min_rtt_window_length_us: 10_000_000 # default=10_000_000 7 | # cc_env_reward_delay_factor: 0.2 # default=0.2 8 | # discounting: 0.99 # default=0.99 9 | # end_of_episode_bootstrap: false # default=false 10 | # entropy_cost: 0.01 # default=0.01 11 | # grad_norm_clipping: 0 # default=0 12 | # inference_batch_size: 2 # default=2 13 | # num_actors: 40 # default=40 14 | # test_after_train: true # default=true 15 | # test_job_ids: [] # default=[] 16 | # test_schemes: mvfst_rl # default="" 17 | # test_runs_per_job: 3 # default=3 18 | # train_job_ids: [] # default=[] 19 | # total_steps: 1_000_000 # default=1_000_000 20 | 21 | hydra: 22 | run: 23 | dir: /checkpoint/${oc.env:USER}/mvfst-rl/run/${now:%Y-%m-%d_%H-%M-%S} 24 | sweep: 25 | dir: /checkpoint/${oc.env:USER}/mvfst-rl/multirun/${now:%Y-%m-%d_%H-%M-%S} 26 | -------------------------------------------------------------------------------- /scripts/clean_pantheon_logs.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash -eu 2 | 3 | # Copyright (c) Facebook, Inc. and its affiliates. 4 | # All rights reserved. 5 | # 6 | # This source code is licensed under the license found in the 7 | # LICENSE file in the root directory of this source tree. 8 | # 9 | 10 | set -eu 11 | 12 | # This script deletes all Pantheon logs found in the pantheon/tmp folder. 13 | # ONLY USE IT WHEN NO EXPERIMENTS ARE RUNNING! 
14 | 15 | SCRIPT_DIR=`dirname "$0"` 16 | PANTHEON_DIR=`realpath "$SCRIPT_DIR"/../_build/deps/pantheon` 17 | LOG_DIR="$PANTHEON_DIR"/tmp 18 | DEL_DIR="$LOG_DIR.to_delete" 19 | EMPTY_DIR="$LOG_DIR.empty" 20 | 21 | echo "Moving Pantheon logs to: $DEL_DIR" 22 | mv "$LOG_DIR" "$DEL_DIR" 23 | 24 | # Re-create the directory so that it exists for future experiments. 25 | mkdir "$LOG_DIR" 26 | 27 | echo "Deleting this folder -- this may take several hours, be patient" 28 | rm -rf "$EMPTY_DIR" 29 | mkdir "$EMPTY_DIR" 30 | # Using `rsync` instead of `rm` because it is faster. 31 | time rsync -a --delete "$EMPTY_DIR/" "$DEL_DIR" 32 | rmdir "$DEL_DIR" 33 | 34 | echo "Done!" 35 | -------------------------------------------------------------------------------- /congestion_control/CongestionControlFixedCwndEnv.h: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) Facebook, Inc. and its affiliates. 3 | * All rights reserved. 4 | * 5 | * This source code is licensed under the license found in the 6 | * LICENSE file in the root directory of this source tree. 7 | * 8 | */ 9 | #pragma once 10 | 11 | #include "CongestionControlEnv.h" 12 | 13 | namespace quic { 14 | 15 | // Basic controller aiming at reaching a specific cwnd target. 16 | // At each decision step, this controller greedily picks the action that brings 17 | // cwnd closest to its target value (NB: it may not be able to exactly reach 18 | // it). 19 | class CongestionControlFixedCwndEnv : public CongestionControlEnv { 20 | 21 | public: 22 | CongestionControlFixedCwndEnv(const Config &cfg, Callback *cob, 23 | const QuicConnectionStateBase &conn); 24 | 25 | private: 26 | // Target value for cwnd, in bytes. 27 | uint64_t cwndBytesTarget_; 28 | 29 | void onObservation(Observation &&obs, float reward) override; 30 | 31 | // Return how far we currently are from the target cwnd value. 
32 | uint64_t distToTarget(uint64_t cwndBytes) const; 33 | }; 34 | 35 | } // namespace quic 36 | -------------------------------------------------------------------------------- /scripts/plotting/util.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # 7 | 8 | import pandas as pd 9 | import matplotlib.pyplot as plt 10 | import holoviews as hv 11 | from holoviews import opts 12 | 13 | hv.extension("bokeh") 14 | 15 | 16 | def loggers2df(loggers): 17 | dfs = [pd.DataFrame(x.data) for x in loggers] 18 | return pd.concat(dfs, keys=[x.exp_id for x in loggers], names=["exp_id"]) 19 | 20 | 21 | DEF_STASH = ["case", "time_seconds"] 22 | 23 | 24 | def stash2df(stash, index=DEF_STASH): 25 | return pd.DataFrame(stash, columns=list(stash[0].keys())).set_index(index) 26 | 27 | 28 | def plot_legacy(df, x, y, z): 29 | fig, ax = plt.subplots(figsize=(8, 6)) 30 | df.groupby(z).plot(x=x, y=y, ax=ax) 31 | ax.legend(df.index.levels[0].tolist(), loc="lower right") 32 | return ax 33 | 34 | 35 | def plot(df, x, y, z, width=600, height=600): 36 | ds = hv.Dataset(df, [x, y, z]) 37 | grouped = ds.to(hv.Curve, x, y) 38 | ndoverlay = grouped.overlay(z) 39 | ndoverlay.opts(opts.Curve(width=width, height=height)) 40 | return ndoverlay 41 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing to mvfst-rl 2 | We want to make contributing to this project as easy and transparent as 3 | possible. 4 | 5 | ## Pull Requests 6 | We actively welcome your pull requests. 7 | 8 | 1. Fork the repo and create your branch from `main`. 9 | 2. If you've added code that should be tested, add tests. 10 | 3. 
If you've changed APIs, update the documentation. 11 | 4. Ensure the test suite passes. 12 | 5. Make sure your code lints. 13 | 6. If you haven't already, complete the Contributor License Agreement ("CLA"). 14 | 15 | ## Contributor License Agreement ("CLA") 16 | In order to accept your pull request, we need you to submit a CLA. You only need 17 | to do this once to work on any of Facebook's open source projects. 18 | 19 | Complete your CLA here: [https://code.facebook.com/cla](https://code.facebook.com/cla) 20 | 21 | ## Issues 22 | We use GitHub issues to track public bugs. Please ensure your description is 23 | clear and has sufficient instructions to be able to reproduce the issue. 24 | 25 | Facebook has a [bounty program](https://www.facebook.com/whitehat/) for the safe 26 | disclosure of security bugs. In those cases, please go through the process 27 | outlined on that page and do not file a public issue. 28 | 29 | ## License 30 | By contributing to mvfst-rl, you agree that your contributions will be licensed 31 | under the LICENSE file in the root directory of this source tree. 32 | -------------------------------------------------------------------------------- /congestion_control/RLCongestionControllerFactory.h: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) Facebook, Inc. and its affiliates. 3 | * All rights reserved. 4 | * 5 | * This source code is licensed under the license found in the 6 | * LICENSE file in the root directory of this source tree.
7 | * 8 | */ 9 | #pragma once 10 | 11 | #include 12 | #include 13 | #include 14 | 15 | #include 16 | 17 | #include "RLCongestionController.h" 18 | 19 | namespace quic { 20 | 21 | struct CongestionController; 22 | struct QuicConnectionStateBase; 23 | 24 | class RLCongestionControllerFactory : public CongestionControllerFactory { 25 | public: 26 | RLCongestionControllerFactory( 27 | std::shared_ptr envFactory) 28 | : envFactory_(envFactory) { 29 | CHECK_NOTNULL(envFactory.get()); 30 | } 31 | 32 | ~RLCongestionControllerFactory() override = default; 33 | // Overrides the base factory hook; the requested CongestionControlType is not used. 34 | std::unique_ptr makeCongestionController( 35 | QuicConnectionStateBase& conn, CongestionControlType /* type */) override { 36 | LOG(INFO) << "Creating RLCongestionController"; 37 | return std::make_unique(conn, envFactory_); 38 | } 39 | 40 | private: 41 | std::shared_ptr envFactory_; 42 | }; 43 | 44 | } // namespace quic 45 | -------------------------------------------------------------------------------- /congestion_control/CongestionControlLocalEnv.h: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) Facebook, Inc. and its affiliates. 3 | * All rights reserved. 4 | * 5 | * This source code is licensed under the license found in the 6 | * LICENSE file in the root directory of this source tree.
7 | * 8 | */ 9 | #pragma once 10 | 11 | #include 12 | #include 13 | #include 14 | 15 | #include "CongestionControlEnv.h" 16 | 17 | namespace quic { 18 | 19 | class CongestionControlLocalEnv : public CongestionControlEnv { 20 | public: 21 | CongestionControlLocalEnv(const Config& cfg, Callback* cob, 22 | const QuicConnectionStateBase& conn); 23 | ~CongestionControlLocalEnv() override; 24 | 25 | private: 26 | // CongestionControlEnv impl 27 | void onObservation(Observation&& obs, float reward) override; 28 | 29 | void loop(); 30 | 31 | std::unique_ptr thread_; // Thread for inference 32 | std::atomic shutdown_{false}; // Signals termination of env loop 33 | 34 | // Tensor for holding observations 35 | torch::Tensor tensor_{torch::empty({0}, torch::kFloat32)}; 36 | float reward_; 37 | bool observationReady_{false}; 38 | 39 | torch::jit::script::Module module_; 40 | 41 | // CV and mutex for co-ordination with the inference thread. 42 | std::condition_variable cv_; 43 | std::mutex mutex_; 44 | }; 45 | 46 | } // namespace quic 47 | -------------------------------------------------------------------------------- /config/hydra/sweeper/_nevergrad.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - nevergrad 3 | 4 | optim: 5 | optimizer: RandomSearch 6 | budget: ??? 7 | num_workers: 50 8 | seed: 1234 9 | parametrization: 10 | # Learner. 11 | hidden_size: 12 | - 128 13 | - 1024 14 | unroll_length: 15 | - 8 16 | - 80 17 | - 256 18 | seed: 19 | lower: 1 20 | upper: 9999 21 | integer: true 22 | end_of_episode_bootstrap: 23 | - true 24 | - false 25 | entropy_cost: 26 | - 1e-4 27 | - 1e-3 28 | - 1e-2 29 | baseline_cost: 30 | - 0.1 31 | - 0.5 32 | - 1. 33 | - 2. 
34 | discounting: 35 | - 0.95 36 | - 0.99 37 | reward_clipping: 38 | - soft_asymmetric 39 | - none 40 | # - abs_one 41 | reward_normalization_coeff: 42 | - 1e-5 43 | - 1e-4 44 | learning_rate: 45 | - 1e-4 46 | - 5e-4 47 | alpha: 48 | - 0.9 49 | - 0.99 50 | - 0.999 51 | momentum: 52 | - 0 53 | - 1e-3 54 | - 1e-1 55 | 56 | # Env. 57 | cc_env_history_size: 58 | - 1 59 | - 16 60 | - 64 61 | cc_env_reward_delay_factor: 62 | - 0.5 63 | - 0.75 64 | - 1. 65 | cc_env_reward_packet_loss_factor: 66 | - 0 67 | - 1e-1 68 | - 1 69 | -------------------------------------------------------------------------------- /congestion_control/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # 7 | 8 | add_library( 9 | rl_congestion_control STATIC 10 | CongestionControlEnv.cpp 11 | CongestionControlEnvConfig.cpp 12 | CongestionControlLocalEnv.cpp 13 | CongestionControlFixedCwndEnv.cpp 14 | NetworkState.cpp 15 | RLBandwidthSampler.cpp 16 | RLCongestionController.cpp 17 | Utils.cpp 18 | ) 19 | 20 | if(INFERENCE_ONLY) 21 | target_compile_definitions( 22 | rl_congestion_control PUBLIC 23 | MVFSTRL_INFERENCE_ONLY 24 | ) 25 | else() 26 | # Add one more cpp 27 | target_sources( 28 | rl_congestion_control PRIVATE 29 | CongestionControlRPCEnv.cpp 30 | ) 31 | add_dependencies( 32 | rl_congestion_control 33 | rpcenv_pb 34 | ) 35 | target_link_libraries( 36 | rl_congestion_control PUBLIC 37 | rpcenv_pb 38 | ) 39 | endif() 40 | 41 | target_compile_options( 42 | rl_congestion_control 43 | PRIVATE 44 | ${_QUIC_COMMON_COMPILE_OPTIONS} 45 | ) 46 | 47 | add_dependencies( 48 | rl_congestion_control 49 | mvfst 50 | ) 51 | 52 | target_include_directories( 53 | rl_congestion_control PUBLIC 54 | ${CMAKE_CURRENT_SOURCE_DIR} 55 | ) 56 | 57 | 
target_link_libraries( 58 | rl_congestion_control PUBLIC 59 | mvfst 60 | ${TORCH_LIBRARIES} 61 | ) 62 | 63 | target_compile_definitions( 64 | rl_congestion_control PUBLIC 65 | C10_USE_GLOG # Fix glog compilation issues related to PyTorch 66 | ) 67 | -------------------------------------------------------------------------------- /congestion_control/Utils.cpp: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) Facebook, Inc. and its affiliates. 3 | * All rights reserved. 4 | * 5 | * This source code is licensed under the license found in the 6 | * LICENSE file in the root directory of this source tree. 7 | * 8 | */ 9 | 10 | #include 11 | 12 | #include "Utils.h" 13 | 14 | namespace quic { 15 | namespace utils { 16 | 17 | int aten_to_numpy_dtype(const at::ScalarType scalar_type) { 18 | switch (scalar_type) { 19 | case at::kDouble: 20 | return NPY_DOUBLE; 21 | case at::kFloat: 22 | return NPY_FLOAT; 23 | case at::kHalf: 24 | return NPY_HALF; 25 | case at::kLong: 26 | return NPY_LONG; 27 | case at::kInt: 28 | return NPY_INT; 29 | case at::kShort: 30 | return NPY_SHORT; 31 | case at::kChar: 32 | return NPY_BYTE; 33 | case at::kByte: 34 | return NPY_UBYTE; 35 | case at::kBool: 36 | return NPY_BOOL; 37 | default: 38 | throw std::runtime_error( 39 | folly::sformat("Unsupported ScalarType: {}", toString(scalar_type))); 40 | } 41 | } 42 | 43 | at::ScalarType numpy_dtype_to_aten(int dtype) { 44 | switch (dtype) { 45 | case NPY_DOUBLE: 46 | return at::kDouble; 47 | case NPY_FLOAT: 48 | return at::kFloat; 49 | case NPY_HALF: 50 | return at::kHalf; 51 | case NPY_LONG: 52 | return at::kLong; 53 | case NPY_INT: 54 | return at::kInt; 55 | case NPY_SHORT: 56 | return at::kShort; 57 | case NPY_BYTE: 58 | return at::kChar; 59 | case NPY_UBYTE: 60 | return at::kByte; 61 | case NPY_BOOL: 62 | return at::kBool; 63 | default: 64 | throw std::runtime_error(folly::sformat("Unsupported dtype: {}", dtype)); 65 | } 66 | } 67 | } 68 | } // 
namespace quic::utils 69 | -------------------------------------------------------------------------------- /congestion_control/CongestionControlEnvConfig.cpp: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) Facebook, Inc. and its affiliates. 3 | * All rights reserved. 4 | * 5 | * This source code is licensed under the license found in the 6 | * LICENSE file in the root directory of this source tree. 7 | * 8 | */ 9 | #include "CongestionControlEnvConfig.h" 10 | 11 | #include 12 | #include 13 | 14 | namespace quic { 15 | 16 | void CongestionControlEnvConfig::parseActionsFromString( 17 | const std::string &actionsStr) { 18 | CHECK(!actionsStr.empty()) << "Actions cannot be empty."; 19 | 20 | quic::utils::vector v; 21 | folly::split(",", actionsStr, v); 22 | quic::utils::vector> actions_(v.size()); 23 | 24 | CHECK_EQ(v[0], "0") << "First action must be no-op (\"0\"), received " 25 | << actionsStr; 26 | actions_[0] = {ActionOp::NOOP, 0}; 27 | 28 | for (size_t i = 1; i < v.size(); ++i) { 29 | CHECK_GT(v[i].size(), 1) << "Invalid actions specified: " << actionsStr; 30 | const char op = v[i][0]; 31 | const auto &val = v[i].subpiece(1); 32 | actions_[i] = {charToActionOp(op), folly::to(val)}; 33 | } 34 | 35 | this->actions = actions_; 36 | } 37 | 38 | CongestionControlEnvConfig::ActionOp 39 | CongestionControlEnvConfig::charToActionOp(const char op) { 40 | switch (op) { 41 | case '0': 42 | return ActionOp::NOOP; 43 | case '+': 44 | return ActionOp::ADD; 45 | case '-': 46 | return ActionOp::SUB; 47 | case '*': 48 | return ActionOp::MUL; 49 | case '/': 50 | return ActionOp::DIV; 51 | default: 52 | LOG(FATAL) << "Unknown char for ActionOp: " << op; 53 | } 54 | __builtin_unreachable(); 55 | } 56 | 57 | } // namespace quic 58 | -------------------------------------------------------------------------------- /congestion_control/CongestionControlEnvFactory.h: 
-------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) Facebook, Inc. and its affiliates. 3 | * All rights reserved. 4 | * 5 | * This source code is licensed under the license found in the 6 | * LICENSE file in the root directory of this source tree. 7 | * 8 | */ 9 | #pragma once 10 | 11 | #include "CongestionControlFixedCwndEnv.h" 12 | #include "CongestionControlLocalEnv.h" 13 | #include "CongestionControlRandomEnv.h" 14 | 15 | #ifndef MVFSTRL_INFERENCE_ONLY 16 | #include "CongestionControlRPCEnv.h" 17 | #endif 18 | 19 | namespace quic { 20 | 21 | class CongestionControlEnvFactory { 22 | public: 23 | CongestionControlEnvFactory(const CongestionControlEnv::Config& cfg) 24 | : cfg_(cfg) {} 25 | 26 | std::unique_ptr make( 27 | CongestionControlEnv::Callback* cob, 28 | const QuicConnectionStateBase& conn) { 29 | switch (cfg_.mode) { 30 | case CongestionControlEnv::Config::Mode::LOCAL: 31 | return std::make_unique(cfg_, cob, conn); 32 | case CongestionControlEnv::Config::Mode::REMOTE: 33 | #ifdef MVFSTRL_INFERENCE_ONLY 34 | LOG(FATAL) << "REMOTE mode is not available as this is an inference " 35 | "only build."; 36 | return nullptr; 37 | #else 38 | return std::make_unique(cfg_, cob, conn); 39 | #endif 40 | case CongestionControlEnv::Config::Mode::RANDOM: 41 | return std::make_unique(cfg_, cob, conn); 42 | case CongestionControlEnv::Config::Mode::FIXED: 43 | return std::make_unique(cfg_, cob, conn); 44 | default: 45 | LOG(FATAL) << "Unknown mode"; 46 | return nullptr; 47 | } 48 | __builtin_unreachable(); 49 | } 50 | 51 | private: 52 | CongestionControlEnv::Config cfg_; 53 | }; 54 | 55 | } // namespace quic 56 | -------------------------------------------------------------------------------- /third-party/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 
3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # 7 | 8 | # Exports helper static libs from third-party deps 9 | 10 | if(NOT INFERENCE_ONLY) 11 | # Lib for generated torchbeast RPC protobuf files (ref torchbeast/setup.py). 12 | add_library( 13 | rpcenv_pb STATIC 14 | torchbeast/torchbeast/rpc.pb.cc 15 | torchbeast/torchbeast/rpc.grpc.pb.cc 16 | ) 17 | 18 | target_compile_options( 19 | rpcenv_pb 20 | PRIVATE 21 | ${_QUIC_COMMON_COMPILE_OPTIONS} 22 | -Wno-overloaded-virtual # gRPC is noisy 23 | ) 24 | 25 | target_link_libraries( 26 | rpcenv_pb PUBLIC 27 | grpc 28 | grpc++ 29 | gpr 30 | address_sorting 31 | protobuf 32 | ) 33 | endif() 34 | 35 | # A single interface lib that links together all mvfst libs and dependencies 36 | # for ease of use. 37 | add_library(mvfst INTERFACE) 38 | 39 | target_link_libraries( 40 | mvfst INTERFACE 41 | fizz::fizz 42 | fizz::fizz_test_support 43 | mvfst::mvfst_cc_algo 44 | mvfst::mvfst_client 45 | mvfst::mvfst_codec 46 | mvfst::mvfst_codec_decode 47 | mvfst::mvfst_codec_packet_number_cipher 48 | mvfst::mvfst_codec_pktbuilder 49 | mvfst::mvfst_codec_pktrebuilder 50 | mvfst::mvfst_codec_types 51 | mvfst::mvfst_constants 52 | mvfst::mvfst_exception 53 | mvfst::mvfst_fizz_client 54 | mvfst::mvfst_flowcontrol 55 | mvfst::mvfst_handshake 56 | mvfst::mvfst_happyeyeballs 57 | mvfst::mvfst_looper 58 | mvfst::mvfst_loss 59 | mvfst::mvfst_qlogger 60 | mvfst::mvfst_server 61 | mvfst::mvfst_state_ack_handler 62 | mvfst::mvfst_state_functions 63 | mvfst::mvfst_state_machine 64 | mvfst::mvfst_state_pacing_functions 65 | mvfst::mvfst_state_qpr_functions 66 | mvfst::mvfst_state_simple_frame_functions 67 | mvfst::mvfst_state_stream 68 | mvfst::mvfst_state_stream_functions 69 | mvfst::mvfst_transport 70 | ) 71 | -------------------------------------------------------------------------------- /traffic_gen/Utils.h: 
-------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) Facebook, Inc. and its affiliates. 3 | * All rights reserved. 4 | * 5 | * This source code is licensed under the license found in the 6 | * LICENSE file in the root directory of this source tree. 7 | * 8 | */ 9 | #pragma once 10 | 11 | #include 12 | #include 13 | #include 14 | #include 15 | #include 16 | 17 | namespace quic { 18 | namespace traffic_gen { 19 | 20 | class DummyCertificateVerifier : public fizz::CertificateVerifier { 21 | public: 22 | ~DummyCertificateVerifier() override = default; 23 | 24 | void verify(const std::vector>&) 25 | const override { 26 | return; 27 | } 28 | 29 | std::vector getCertificateRequestExtensions() 30 | const override { 31 | return std::vector(); 32 | } 33 | }; 34 | 35 | std::shared_ptr testCert() { 36 | auto certificate = fizz::test::getCert(fizz::test::kP256Certificate); 37 | auto privKey = fizz::test::getPrivateKey(fizz::test::kP256Key); 38 | std::vector certs; 39 | certs.emplace_back(std::move(certificate)); 40 | return std::make_shared>( 41 | std::move(privKey), std::move(certs)); 42 | } 43 | 44 | std::shared_ptr createTestServerCtx() { 45 | auto cert = testCert(); 46 | auto certManager = std::make_unique(); 47 | certManager->addCert(std::move(cert), true); 48 | auto serverCtx = std::make_shared(); 49 | serverCtx->setFactory(std::make_shared()); 50 | serverCtx->setCertManager(std::move(certManager)); 51 | serverCtx->setOmitEarlyRecordLayer(true); 52 | serverCtx->setClock(std::make_shared()); 53 | return serverCtx; 54 | } 55 | 56 | } // namespace traffic_gen 57 | } // namespace quic 58 | -------------------------------------------------------------------------------- /congestion_control/Utils.h: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) Facebook, Inc. and its affiliates. 3 | * All rights reserved. 
4 | * 5 | * This source code is licensed under the license found in the 6 | * LICENSE file in the root directory of this source tree. 7 | * 8 | */ 9 | #pragma once 10 | 11 | #if defined NDEBUG // release build 12 | #include 13 | #else // debug build: libstdc++ debug-mode containers (needed for __gnu_debug::vector below) 14 | #include 15 | #endif 16 | 17 | #include 18 | 19 | namespace quic { 20 | namespace utils { 21 | 22 | // Define `quic::utils::vector` to be either the regular `std::vector` or its 23 | // debug version (which includes in particular bound checks), depending on 24 | // whether or not we are running in debug mode. 25 | template 26 | #if defined NDEBUG 27 | using vector = std::vector; 28 | #else 29 | using vector = __gnu_debug::vector; 30 | #endif 31 | 32 | // Redefinitions of torch::aten_to_numpy_dtype and torch::numpy_dtype_to_aten 33 | // with hardcoded values for NPY_* macros which we can't use since we don't have 34 | // a Python interpreter. 35 | 36 | // Ref NPY_TYPES enum in 37 | // https://github.com/numpy/numpy/blob/464f79eb1d05bf938d16b49da1c39a4e02506fa3/numpy/core/include/numpy/ndarraytypes.h. 38 | enum NPY_TYPES { // values must stay identical to NumPy's NPY_TYPES (linked above): do not reorder or insert entries 39 | NPY_BOOL = 0, 40 | NPY_BYTE, 41 | NPY_UBYTE, 42 | NPY_SHORT, 43 | NPY_USHORT, 44 | NPY_INT, 45 | NPY_UINT, 46 | NPY_LONG, 47 | NPY_ULONG, 48 | NPY_LONGLONG, 49 | NPY_ULONGLONG, 50 | NPY_FLOAT, 51 | NPY_DOUBLE, 52 | NPY_LONGDOUBLE, 53 | NPY_CFLOAT, 54 | NPY_CDOUBLE, 55 | NPY_CLONGDOUBLE, 56 | NPY_OBJECT = 17, 57 | NPY_STRING, 58 | NPY_UNICODE, 59 | NPY_VOID, 60 | /* 61 | * New 1.6 types appended, may be integrated 62 | * into the above in 2.0.
63 | */ 64 | NPY_DATETIME, 65 | NPY_TIMEDELTA, 66 | NPY_HALF, 67 | 68 | NPY_NTYPES, 69 | NPY_NOTYPE, 70 | NPY_CHAR, 71 | NPY_USERDEF = 256, /* leave room for characters */ 72 | 73 | /* The number of types not including the new 1.6 types */ 74 | NPY_NTYPES_ABI_COMPATIBLE = 21 75 | }; 76 | 77 | int aten_to_numpy_dtype(const at::ScalarType scalar_type); 78 | 79 | at::ScalarType numpy_dtype_to_aten(int dtype); 80 | } 81 | } // namespace quic::utils 82 | -------------------------------------------------------------------------------- /congestion_control/CongestionControlRPCEnv.h: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) Facebook, Inc. and its affiliates. 3 | * All rights reserved. 4 | * 5 | * This source code is licensed under the license found in the 6 | * LICENSE file in the root directory of this source tree. 7 | * 8 | */ 9 | #pragma once 10 | 11 | #include 12 | #include 13 | #include 14 | #include 15 | 16 | #include 17 | #include 18 | #include 19 | 20 | #include "CongestionControlEnv.h" 21 | 22 | namespace quic { 23 | 24 | using TensorNest = nest::Nest; 25 | 26 | class CongestionControlRPCEnv : public CongestionControlEnv { 27 | public: 28 | CongestionControlRPCEnv(const Config& cfg, Callback* cob, 29 | const QuicConnectionStateBase& conn); 30 | ~CongestionControlRPCEnv() override; 31 | 32 | private: 33 | // CongestionControlEnv impl 34 | void onObservation(Observation&& obs, float reward) override; 35 | 36 | void loop(const std::string& address); 37 | 38 | static torchbeast::CallRequest makeCallRequest(int64_t actor_id, 39 | const torch::Tensor& obs, 40 | float reward, bool done); 41 | static uint32_t getActionFromCallResponse(torchbeast::CallResponse& resp); 42 | 43 | static void fillNDArrayPB(torchbeast::NDArray* ndarray, 44 | const torch::Tensor& tensor); 45 | static TensorNest arrayPBToNest(torchbeast::NDArray* ndarray); 46 | 47 | int64_t actorId_{0}; 48 | std::unique_ptr thread_; // Thread to 
run the gRPC client in 49 | bool connected_{false}; // Whether we are connected to gRPC server 50 | std::atomic shutdown_{false}; // Signals termination of env loop 51 | 52 | // Tensor for holding observations 53 | torch::Tensor tensor_{torch::empty({0}, torch::kFloat32)}; 54 | float reward_; 55 | bool observationReady_{false}; 56 | 57 | // CV and mutex for co-ordination with gRPC thread. 58 | std::condition_variable cv_; 59 | std::mutex mutex_; 60 | }; 61 | 62 | } // namespace quic 63 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Build artifacts 2 | _build/ 3 | 4 | # Log dirs 5 | logs/ 6 | train/logs/ 7 | scripts/logs/ 8 | 9 | # Byte-compiled / optimized / DLL files 10 | __pycache__/ 11 | *.py[cod] 12 | *$py.class 13 | 14 | # C extensions 15 | *.so 16 | 17 | # Distribution / packaging 18 | .Python 19 | build/ 20 | develop-eggs/ 21 | dist/ 22 | downloads/ 23 | eggs/ 24 | .eggs/ 25 | lib/ 26 | lib64/ 27 | parts/ 28 | sdist/ 29 | var/ 30 | wheels/ 31 | pip-wheel-metadata/ 32 | share/python-wheels/ 33 | *.egg-info/ 34 | .installed.cfg 35 | *.egg 36 | MANIFEST 37 | 38 | # PyInstaller 39 | # Usually these files are written by a python script from a template 40 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
41 | *.manifest 42 | *.spec 43 | 44 | # Installer logs 45 | pip-log.txt 46 | pip-delete-this-directory.txt 47 | 48 | # Unit test / coverage reports 49 | htmlcov/ 50 | .tox/ 51 | .nox/ 52 | .coverage 53 | .coverage.* 54 | .cache 55 | nosetests.xml 56 | coverage.xml 57 | *.cover 58 | .hypothesis/ 59 | .pytest_cache/ 60 | 61 | # Translations 62 | *.mo 63 | *.pot 64 | 65 | # Django stuff: 66 | *.log 67 | local_settings.py 68 | db.sqlite3 69 | 70 | # Flask stuff: 71 | instance/ 72 | .webassets-cache 73 | 74 | # Scrapy stuff: 75 | .scrapy 76 | 77 | # Sphinx documentation 78 | docs/_build/ 79 | 80 | # PyBuilder 81 | target/ 82 | 83 | # Jupyter Notebook 84 | .ipynb_checkpoints 85 | scripts/plotting/.ipynb_checkpoints/ 86 | 87 | # IPython 88 | profile_default/ 89 | ipython_config.py 90 | 91 | # pyenv 92 | .python-version 93 | 94 | # celery beat schedule file 95 | celerybeat-schedule 96 | 97 | # SageMath parsed files 98 | *.sage.py 99 | 100 | # Environments 101 | .env 102 | .venv 103 | env/ 104 | venv/ 105 | ENV/ 106 | env.bak/ 107 | venv.bak/ 108 | 109 | # Spyder project settings 110 | .spyderproject 111 | .spyproject 112 | 113 | # Rope project settings 114 | .ropeproject 115 | 116 | # mkdocs documentation 117 | /site 118 | 119 | # mypy 120 | .mypy_cache/ 121 | .dmypy.json 122 | dmypy.json 123 | 124 | # Pyre type checker 125 | .pyre/ 126 | 127 | # Backup files 128 | *~ 129 | 130 | # Output files 131 | *.tsv 132 | 133 | # PyTorch checkpoint files (also GNU tar files ...) 134 | *.tar 135 | 136 | # Compiled protobuf files 137 | *.pb.h 138 | *.pb.cc 139 | 140 | # PyCharm 141 | .idea 142 | 143 | # Mac 144 | .DS_Store 145 | 146 | # Swap files 147 | *.sw* 148 | 149 | # VSCode 150 | .vscode 151 | -------------------------------------------------------------------------------- /congestion_control/CongestionControlFixedCwndEnv.cpp: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) Facebook, Inc. and its affiliates. 
3 | * All rights reserved. 4 | * 5 | * This source code is licensed under the license found in the 6 | * LICENSE file in the root directory of this source tree. 7 | * 8 | */ 9 | #include "CongestionControlFixedCwndEnv.h" 10 | 11 | namespace quic { 12 | 13 | CongestionControlFixedCwndEnv::CongestionControlFixedCwndEnv( 14 | const Config &cfg, Callback *cob, const QuicConnectionStateBase &conn) 15 | : CongestionControlEnv(cfg, cob, conn) { 16 | cwndBytesTarget_ = cfg_.fixedCwnd * conn.udpSendPacketLen; 17 | } 18 | 19 | void CongestionControlFixedCwndEnv::onObservation(Observation &&obs, 20 | float reward) { 21 | // Obtain current cwnd value from observation. Compared to directly accessing 22 | // the `cwndBytes_` attribute, this is closer to how a "real" RL-based policy 23 | // would work, and avoids potential thread safety issues since `cwndBytes_` is 24 | // updated asynchronously. 25 | DCHECK(!obs.history.empty()); 26 | const float lastCwndObs = obs.history.back().cwnd; 27 | // Convert value to bytes (rounded). 28 | const uint64_t currentCwndBytes = 29 | static_cast(lastCwndObs * normBytes() + 0.5f); 30 | 31 | Action action; // the action we will take 32 | 33 | // How far are we from the target cwnd value? 34 | uint64_t currentDist = distToTarget(currentCwndBytes); 35 | if (currentDist == 0) { 36 | action.cwndAction = 0; // already at target => do nothing 37 | onAction(action); 38 | return; 39 | } 40 | 41 | // Find the action that brings `cwnd` closest to the desired target. 42 | uint32_t bestActionIdx = 0; // default = do nothing 43 | for (uint32_t actionIdx = 1; actionIdx < cfg_.actions.size(); ++actionIdx) { 44 | const uint64_t newCwndBytes = 45 | getUpdatedCwndBytes(currentCwndBytes, actionIdx); 46 | const uint64_t newDist = distToTarget(newCwndBytes); 47 | if (newDist < currentDist) { 48 | currentDist = newDist; 49 | bestActionIdx = actionIdx; 50 | } 51 | } 52 | 53 | // Apply the selected action. 
54 | action.cwndAction = bestActionIdx; 55 | onAction(action); 56 | } 57 | 58 | uint64_t CongestionControlFixedCwndEnv::distToTarget(uint64_t cwndBytes) const { 59 | // Compute abs(value - target) safely (they are unsigned integers). 60 | return cwndBytes > cwndBytesTarget_ ? cwndBytes - cwndBytesTarget_ 61 | : cwndBytesTarget_ - cwndBytes; 62 | } 63 | 64 | } // namespace quic 65 | -------------------------------------------------------------------------------- /third-party/gala/gala_a2c.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # 7 | """ GALA-A2C agent """ 8 | 9 | import torch.nn as nn 10 | import torch.optim as optim 11 | 12 | 13 | class GALA_A2C(): 14 | 15 | def __init__(self, actor_critic, value_loss_coef, entropy_coef, lr=None, 16 | eps=None, alpha=None, max_grad_norm=None, 17 | rank=0, gossip_buffer=None): 18 | """ GALA_A2C """ 19 | 20 | self.rank = rank 21 | self.gossip_buffer = gossip_buffer 22 | self.actor_critic = actor_critic 23 | 24 | self.value_loss_coef = value_loss_coef 25 | self.entropy_coef = entropy_coef 26 | 27 | self.max_grad_norm = max_grad_norm 28 | 29 | self.optimizer = optim.RMSprop( 30 | actor_critic.parameters(), lr, eps=eps, alpha=alpha) 31 | 32 | def update(self, rollouts): 33 | obs_shape = rollouts.obs.size()[2:] 34 | action_shape = rollouts.actions.size()[-1] 35 | num_steps, num_processes, _ = rollouts.rewards.size() 36 | 37 | values, action_log_probs, dist_entropy, _ = self.actor_critic.evaluate_actions( 38 | rollouts.obs[:-1].view(-1, *obs_shape), 39 | rollouts.recurrent_hidden_states[0].view( 40 | -1, self.actor_critic.recurrent_hidden_state_size), 41 | rollouts.masks[:-1].view(-1, 1), 42 | rollouts.actions.view(-1, action_shape)) 43 | 44 | values = values.view(num_steps, 
num_processes, 1) 45 | action_log_probs = action_log_probs.view(num_steps, num_processes, 1) 46 | 47 | advantages = rollouts.returns[:-1] - values 48 | value_loss = advantages.pow(2).mean() 49 | 50 | action_loss = -(advantages.detach() * action_log_probs).mean() 51 | 52 | self.optimizer.zero_grad() 53 | (value_loss * self.value_loss_coef + action_loss - 54 | dist_entropy * self.entropy_coef).backward() 55 | 56 | nn.utils.clip_grad_norm_(self.actor_critic.parameters(), 57 | self.max_grad_norm) 58 | 59 | self.optimizer.step() 60 | 61 | # Local-Gossip 62 | if self.gossip_buffer is not None: 63 | self.gossip_buffer.write_message(self.rank, self.actor_critic) 64 | self.gossip_buffer.aggregate_message(self.rank, self.actor_critic) 65 | 66 | return value_loss.item(), action_loss.item(), dist_entropy.item() 67 | -------------------------------------------------------------------------------- /congestion_control/NetworkState.h: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) Facebook, Inc. and its affiliates. 3 | * All rights reserved. 4 | * 5 | * This source code is licensed under the license found in the 6 | * LICENSE file in the root directory of this source tree. 
7 | * 8 | */ 9 | #pragma once 10 | 11 | #include 12 | #include 13 | #include 14 | 15 | #include "Utils.h" 16 | 17 | namespace quic { 18 | 19 | struct NetworkState { 20 | public: 21 | // NOTE: If fields are added, make sure to also update: 22 | // - fieldToString() in NetworkState.cpp 23 | // - the corresponding `Field` enum in state.py (where the NUM_FIELDS sentinel is spelled NUM_FIELD) 24 | enum class Field : uint16_t { 25 | // RTT related 26 | RTT_MIN = 0, 27 | RTT_STANDING, 28 | LRTT, 29 | SRTT, 30 | RTT_VAR, 31 | DELAY, 32 | 33 | // Bytes related 34 | CWND, 35 | IN_FLIGHT, 36 | WRITABLE, 37 | SENT, 38 | RECEIVED, 39 | RETRANSMITTED, 40 | 41 | // LossState 42 | PTO_COUNT, 43 | TOTAL_PTO_DELTA, // Derived from LossState::totalPTOCount 44 | RTX_COUNT, 45 | TIMEOUT_BASED_RTX_COUNT, 46 | 47 | // AckEvent 48 | ACKED, 49 | THROUGHPUT, 50 | 51 | // LossEvent 52 | LOST, 53 | PERSISTENT_CONGESTION, 54 | 55 | // Total number of fields 56 | NUM_FIELDS // sentinel, not a real field: its value is the number of fields (kNumFields) 57 | }; 58 | 59 | static constexpr uint16_t kNumFields = 60 | static_cast(Field::NUM_FIELDS); 61 | 62 | NetworkState() : data_(kNumFields, 0.0) {} // one slot per field, all initialized to zero 63 | 64 | inline const float *data() const { return data_.data(); } 65 | inline constexpr uint16_t size() const { return kNumFields; } 66 | 67 | inline float operator[](int idx) const { return data_[idx]; } // no explicit range check here; quic::utils::vector is bounds-checked in debug builds only 68 | inline float operator[](Field field) const { 69 | return data_[static_cast(field)]; 70 | } 71 | inline float &operator[](int idx) { return data_[idx]; } 72 | inline float &operator[](Field field) { 73 | return data_[static_cast(field)]; 74 | } 75 | 76 | inline void setField(const Field field, const float &value) { 77 | data_[static_cast(field)] = value; 78 | } 79 | 80 | torch::Tensor toTensor() const; 81 | void toTensor(torch::Tensor &tensor) const; 82 | static torch::Tensor 83 | toTensor(const quic::utils::vector &states); 84 | static void toTensor(const quic::utils::vector &states, 85 | torch::Tensor &tensor); 86 | 87 | static quic::utils::vector 88 | fromTensor(const torch::Tensor &tensor); 89 | 90 | static std::string
fieldToString(const uint16_t field); 91 | static std::string fieldToString(const Field field); 92 | 93 | private: 94 | quic::utils::vector data_; // indexed by the numeric value of Field 95 | }; 96 | 97 | std::ostream &operator<<(std::ostream &os, const NetworkState &state); 98 | 99 | } // namespace quic 100 | -------------------------------------------------------------------------------- /Dockerfile: -------------------------------------------------------------------------------- 1 | ## Demonstration of running an experiment using the model inside the Pantheon environment. 2 | 3 | # This Dockerfile can be used to run any of the schemes Pantheon has available 4 | # by specifying the SCHEMENAME argument, inside a Docker container with high 5 | # privileges. We are using it to demonstrate using a trained mvfst_rl model. 6 | 7 | # In particular, this container is not used for training the model. 8 | 9 | # 1) If using Docker Desktop on a Mac, it is a good idea to increase its memory limits 10 | # because the default 2GB is too small to build mvfst. 11 | 12 | # 2) Build the docker image using 13 | # docker build --tag mvfst_rl --build-arg SCHEMENAME=mvfst_rl - < Dockerfile 14 | # where Dockerfile is this file. 15 | 16 | # 3) Run the image using 17 | # CAPS='--cap-add=NET_ADMIN --cap-add=SYS_ADMIN' 18 | # sudo docker run --name c_mvfst_rl ${CAPS:?} --rm -t -i mvfst_rl 19 | # 20 | # Inside the container, you can run any of the mvfst schemes because they all depend 21 | # on the same setup. For example you can type 22 | # sudo -u runner -H src/experiments/test.py local --schemes mvfst_bbr --flows 1 23 | # for the bbr scheme. The mvfst_rl scheme (running the trained model) can also be run with 24 | # .
1 25 | 26 | FROM ubuntu:18.04 27 | # System packages (Pantheon/mahimahi/autotools deps) and an unprivileged 'runner' user whose home is created world-writable (mkdir -m 777). 28 | RUN echo Europe/London > /etc/timezone && \ 29 | apt-get update && \ 30 | DEBIAN_FRONTEND=noninteractive apt-get install -y \ 31 | git \ 32 | python-pip \ 33 | python-yaml \ 34 | python-matplotlib \ 35 | sudo \ 36 | ntp \ 37 | ntpdate \ 38 | mahimahi \ 39 | autogen \ 40 | debhelper autotools-dev dh-autoreconf iptables pkg-config iproute2 && \ 41 | pip install tabulate && \ 42 | useradd runner && \ 43 | mkdir -m 777 ~runner && \ 44 | chown runner: ~runner 45 | # Clone Pantheon as 'runner'; the remaining steps run inside this checkout (see WORKDIR). 46 | RUN sudo -u runner -H git clone https://github.com/StanfordSNR/pantheon.git ~runner/pantheon 47 | 48 | WORKDIR /home/runner/pantheon 49 | 50 | RUN sudo -u runner -H git submodule update --init --recursive 51 | # Build and install the pantheon-tunnel submodule. 52 | RUN cd ~runner/pantheon/third_party/pantheon-tunnel && ./autogen.sh && \ 53 | ./configure && make && make install 54 | # Scheme to install and set up; expected via --build-arg SCHEMENAME=... (see the header comment). 55 | ARG SCHEMENAME 56 | 57 | RUN src/experiments/setup.py --install-deps --schemes $SCHEMENAME 58 | 59 | RUN src/experiments/setup.py --setup --schemes $SCHEMENAME 60 | # prelim.sh (run by CMD's --init-file) creates /dev/net/tun and remounts /proc/sys; files '1' and '0' hold the test.py commands with --flows 1 / --flows 0. 61 | RUN echo 'mkdir -p /dev/net && mknod /dev/net/tun c 10 200' > prelim.sh && \ 62 | echo 'mount -o remount rw /proc/sys' >> prelim.sh && \ 63 | echo 'chmod o+w tmp' >> prelim.sh && \ 64 | echo 'echo Please run \". 0\" or \". 1\" to run a test.'
>> prelim.sh && \ 65 | echo "sudo -u runner -H src/experiments/test.py local --schemes $SCHEMENAME --flows 1" > 1 && \ 66 | echo "sudo -u runner -H src/experiments/test.py local --schemes $SCHEMENAME --flows 0" > 0 67 | # Start an interactive bash that first executes prelim.sh. 68 | CMD bash --init-file /home/runner/pantheon/prelim.sh 69 | 70 | -------------------------------------------------------------------------------- /third-party/gala/utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | MIT License 3 | 4 | Copyright (c) 2017 Ilya Kostrikov 5 | 6 | Permission is hereby granted, free of charge, to any person obtaining a copy 7 | of this software and associated documentation files (the "Software"), to deal 8 | in the Software without restriction, including without limitation the rights 9 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 10 | copies of the Software, and to permit persons to whom the Software is 11 | furnished to do so, subject to the following conditions: 12 | 13 | The above copyright notice and this permission notice shall be included in all 14 | copies or substantial portions of the Software. 15 | 16 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 17 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 18 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 19 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 20 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 21 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 22 | SOFTWARE. 23 | 24 | Taken from https://github.com/ikostrikov/pytorch-a2c-ppo-acktr-gail 25 | and slightly modified.
26 | """ 27 | import glob 28 | import logging 29 | import math 30 | import os 31 | import sys 32 | 33 | import torch 34 | import torch.nn as nn 35 | 36 | from gala.envs import VecNormalize 37 | 38 | 39 | # Get a render function 40 | def get_render_func(venv): 41 | if hasattr(venv, 'envs'): 42 | return venv.envs[0].render 43 | elif hasattr(venv, 'venv'): 44 | return get_render_func(venv.venv) 45 | elif hasattr(venv, 'env'): 46 | return get_render_func(venv.env) 47 | 48 | return None 49 | 50 | 51 | def get_vec_normalize(venv): 52 | if isinstance(venv, VecNormalize): 53 | return venv 54 | elif hasattr(venv, 'venv'): 55 | return get_vec_normalize(venv.venv) 56 | 57 | return None 58 | 59 | 60 | # Necessary for my KFAC implementation. 61 | class AddBias(nn.Module): 62 | def __init__(self, bias): 63 | super(AddBias, self).__init__() 64 | self._bias = nn.Parameter(bias.unsqueeze(1)) 65 | 66 | def forward(self, x): 67 | if x.dim() == 2: 68 | bias = self._bias.t().view(1, -1) 69 | else: 70 | bias = self._bias.t().view(1, -1, 1, 1) 71 | 72 | return x + bias 73 | 74 | 75 | def update_linear_schedule(optimizer, epoch, total_num_epochs, initial_lr): 76 | """Decreases the learning rate linearly""" 77 | lr = initial_lr - (initial_lr * (epoch / float(total_num_epochs))) 78 | for param_group in optimizer.param_groups: 79 | param_group['lr'] = lr 80 | 81 | 82 | def init(module, weight_init, bias_init, gain=1): 83 | weight_init(module.weight.data, gain=gain) 84 | bias_init(module.bias.data) 85 | return module 86 | 87 | 88 | def cleanup_log_dir(log_dir): 89 | try: 90 | os.makedirs(log_dir) 91 | except OSError: 92 | files = glob.glob(os.path.join(log_dir, '*.monitor.csv')) 93 | for f in files: 94 | os.remove(f) 95 | 96 | 97 | -------------------------------------------------------------------------------- /train/state.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 
3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | """ 8 | Utility functions to handle the input state. 9 | """ 10 | 11 | from enum import Enum, auto 12 | 13 | import torch 14 | 15 | 16 | # The `Field` enum must be a copy of the one found in `NetworkState.h`. 17 | class Field(Enum): 18 | 19 | # Overridden to guarantee that indices remain consistent with C++ even if 20 | # the default Python implementation changes (unlikely). 21 | def _generate_next_value_(name, start, count, last_values): 22 | return last_values[-1] + 1 23 | 24 | RTT_MIN = 0 25 | RTT_STANDING = auto() 26 | 27 | LRTT = auto() 28 | SRTT = auto() 29 | RTT_VAR = auto() 30 | DELAY = auto() 31 | 32 | CWND = auto() 33 | IN_FLIGHT = auto() 34 | WRITABLE = auto() 35 | SENT = auto() 36 | RECEIVED = auto() 37 | RETRANSMITTED = auto() 38 | 39 | PTO_COUNT = auto() 40 | TOTAL_PTO_DELTA = auto() 41 | RTX_COUNT = auto() 42 | TIMEOUT_BASED_RTX_COUNT = auto() 43 | 44 | ACKED = auto() 45 | THROUGHPUT = auto() 46 | 47 | LOST = auto() 48 | PERSISTENT_CONGESTION = auto() 49 | 50 | NUM_FIELD = auto() 51 | 52 | 53 | # These offsets should match the order of aggregate statistics found 54 | # in `CongestionControlEnv::stateSummary()`. 55 | N = Field.NUM_FIELD.value 56 | OFFSET_SUM = 0 57 | OFFSET_MEAN = N 58 | OFFSET_STD = N * 2 59 | OFFSET_MIN = N * 3 60 | OFFSET_MAX = N * 4 61 | 62 | 63 | def get_from_state(state, field, offset, dim=0): 64 | """ 65 | Fetch the `state` entry found at index `offset` + `field`. 66 | 67 | :param state: Input state (a PyTorch tensor). 68 | :param field: The field to fetch (a `Field` enum). 69 | :param offset: Offset to apply to the index. 70 | :param dim: The dimension along which we should index. 
71 | """ 72 | idx = offset + field.value 73 | if dim == 0: 74 | return state[idx] # straightforward indexing on first dimension 75 | else: 76 | idx_tensor = torch.tensor(idx).to(state.device) 77 | return state.index_select(dim, idx_tensor).squeeze(dim) 78 | 79 | 80 | def get_sum(state, field, dim=0): 81 | """Fetch the sum of `field` in `state`""" 82 | return get_from_state(state, field, offset=OFFSET_SUM, dim=dim) 83 | 84 | 85 | def get_mean(state, field, dim=0): 86 | """Fetch the mean of `field` in `state`""" 87 | return get_from_state(state, field, offset=OFFSET_MEAN, dim=dim) 88 | 89 | 90 | def get_std(state, field, dim=0): 91 | """Fetch the standard deviation of `field` in `state`""" 92 | return get_from_state(state, field, offset=OFFSET_STD, dim=dim) 93 | 94 | 95 | def get_min(state, field, dim=0): 96 | """Fetch the minimum of `field` in `state`""" 97 | return get_from_state(state, field, offset=OFFSET_MIN, dim=dim) 98 | 99 | 100 | def get_max(state, field, dim=0): 101 | """Fetch the maximum of `field` in `state`""" 102 | return get_from_state(state, field, offset=OFFSET_MAX, dim=dim) 103 | -------------------------------------------------------------------------------- /congestion_control/RLCongestionController.h: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) Facebook, Inc. and its affiliates. 3 | * All rights reserved. 4 | * 5 | * This source code is licensed under the license found in the 6 | * LICENSE file in the root directory of this source tree. 
7 | * 8 | */ 9 | #pragma once 10 | 11 | #include 12 | #include 13 | #include 14 | #include 15 | #include 16 | 17 | #include 18 | 19 | #include "CongestionControlEnvFactory.h" 20 | #include "RLBandwidthSampler.h" 21 | 22 | namespace quic { 23 | 24 | using namespace std::chrono_literals; 25 | 26 | class RLCongestionController : public CongestionController, 27 | public CongestionControlEnv::Callback { // presumably receives its cwnd decisions from env_ via onUpdate() — TODO confirm 28 | public: 29 | RLCongestionController( 30 | QuicConnectionStateBase &conn, 31 | std::shared_ptr envFactory); 32 | 33 | void onRemoveBytesFromInflight(uint64_t) override; 34 | void onPacketSent(const OutstandingPacket &packet) override; 35 | void onPacketAckOrLoss(folly::Optional, 36 | folly::Optional) override; 37 | 38 | uint64_t getWritableBytes() const noexcept override; 39 | uint64_t getCongestionWindow() const noexcept override; 40 | CongestionControlType type() const noexcept override; 41 | 42 | uint64_t getBytesInFlight() const noexcept; 43 | 44 | void setAppIdle(bool, TimePoint) noexcept override; 45 | void setAppLimited() override; 46 | 47 | bool isAppLimited() const noexcept override; 48 | 49 | void getStats(CongestionControllerStats& /*stats*/) const override {} 50 | 51 | private: 52 | void onPacketAcked(const AckEvent &); 53 | void onPacketLoss(const LossEvent &); 54 | 55 | // CongestionControlEnv::Callback 56 | void onUpdate(const uint64_t &cwndBytes) noexcept override; 57 | 58 | bool setNetworkState(const folly::Optional &ack, 59 | const folly::Optional &loss, 60 | NetworkState &obs); 61 | 62 | QuicConnectionStateBase &conn_; 63 | uint64_t bytesInFlight_{0}; 64 | uint64_t cwndBytes_; // no default initializer; presumably set in the constructor (not visible here) — TODO confirm 65 | 66 | std::unique_ptr env_; 67 | 68 | // Copa-style RTT filters to get more accurate min and standing RTT values.
69 | WindowedFilter, uint64_t, 71 | uint64_t> 72 | minRTTFilter_; // To get min RTT over 10 seconds 73 | 74 | WindowedFilter, uint64_t, 76 | uint64_t> 77 | standingRTTFilter_; // To get min RTT over srtt/2 78 | 79 | RLBandwidthSampler bandwidthSampler_; // bandwidth estimator 80 | 81 | // Variables to track conn_.lossState values from previous ack or loss 82 | // to compute state deltas for current ack or loss 83 | uint64_t prevTotalBytesSent_{0}; 84 | uint64_t prevTotalBytesRecvd_{0}; 85 | uint64_t prevTotalBytesRetransmitted_{0}; 86 | uint32_t prevTotalPTOCount_{0}; 87 | uint32_t prevRtxCount_{0}; 88 | uint32_t prevTimeoutBasedRtxCount_{0}; 89 | }; 90 | 91 | } // namespace quic 92 | -------------------------------------------------------------------------------- /congestion_control/RLBandwidthSampler.h: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) Facebook, Inc. and its affiliates. 3 | * All rights reserved. 4 | * 5 | * This source code is licensed under the license found in the 6 | * LICENSE file in the root directory of this source tree. 7 | * 8 | */ 9 | #pragma once 10 | 11 | #include 12 | 13 | #include 14 | 15 | #include "Utils.h" 16 | 17 | namespace quic { 18 | 19 | using namespace std::chrono_literals; 20 | 21 | // Bandwidth estimates must be computed over a window spanning at least this 22 | // duration (to make these estimates more stable). 23 | constexpr std::chrono::microseconds kBandwidthWindowMinDuration{100'000us}; 24 | 25 | /* 26 | Bandwidth estimator based on ACK packets. 27 | 28 | At a high level the logic is as follows: 29 | - keep a rolling window of the last K ACK packets 30 | - estimate bandwidth as total #bytes acknowledged in these packets, divided 31 | by the window duration 32 | 33 | This class re-uses the same API as the bandwidth sampler from BBR so that it 34 | can be used as a "plug'n play" replacement.
35 | */ 36 | class RLBandwidthSampler : public BbrCongestionController::BandwidthSampler { 37 | public: 38 | explicit RLBandwidthSampler(QuicConnectionStateBase &conn); 39 | 40 | Bandwidth getBandwidth() const noexcept override; 41 | 42 | // NB: this class actually ignores `rttCounter` as all computations are based 43 | // on timings included in `AckEvent`. 44 | void onPacketAcked(const CongestionController::AckEvent &, 45 | uint64_t rttCounter) override; 46 | 47 | // For now we ignore app-limited mode. 48 | void onAppLimited() noexcept override {} 49 | bool isAppLimited() const noexcept override { return false; } 50 | 51 | private: 52 | // Return the index in the rolling window corresponding to the last ACK 53 | // event that was received. 54 | uint64_t getPreviousIdx() const noexcept { // NOTE: assumes ackBytes_ is non-empty; with ackIdx_ == 0 and an empty window, size() - 1 wraps around 55 | return ackIdx_ > 0 ? ackIdx_ - 1 : ackBytes_.size() - 1; 56 | } 57 | 58 | QuicConnectionStateBase &conn_; 59 | 60 | // Rolling windows of (1) number of acked bytes, and (2) associated 61 | // timestamps (corresponding to the time each ACK was received). 62 | quic::utils::vector ackBytes_; 63 | quic::utils::vector ackTimes_; 64 | 65 | // We enforce a minimum window duration to avoid problematic situations where 66 | // several ACK events may be processed at (almost) the same time due to 67 | // network hiccups. Without a minimum duration, this could lead to an 68 | // unexpectedly high bandwidth estimate. 69 | std::chrono::microseconds minWindowDuration_{kBandwidthWindowMinDuration}; 70 | 71 | // The minimum window duration above is translated into a minimum interval 72 | // between events stored in the rolling window (events too close to each other 73 | // are combined so as to respect this constraint).
74 | std::chrono::microseconds minIntervalBetweenAcks_; // no default initializer; presumably derived from minWindowDuration_ in the constructor — TODO confirm 75 | 76 | TimePoint lastEntryInitialTime_; // initial timestamp of last entry in window 77 | uint64_t totalAckBytes_{0}; // sum of acked bytes within window 78 | uint64_t ackIdx_{0}; // current index in rolling window 79 | bool gotFirstAck_{false}; // whether we have received the first ACK 80 | }; 81 | 82 | } // namespace quic -------------------------------------------------------------------------------- /traffic_gen/ExampleServer.h: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) Facebook, Inc. and its affiliates. 3 | * All rights reserved. 4 | * 5 | * This source code is licensed under the license found in the 6 | * LICENSE file in the root directory of this source tree. 7 | * 8 | */ 9 | #pragma once 10 | 11 | #include 12 | 13 | #include 14 | #include 15 | #include 16 | 17 | #include 18 | #include 19 | 20 | namespace quic { 21 | namespace traffic_gen { 22 | 23 | class ExampleServerTransportFactory : public quic::QuicServerTransportFactory { // make() creates one ExampleHandler per transport and keeps it in exampleHandlers_ 24 | public: 25 | ~ExampleServerTransportFactory() override { 26 | while (!exampleHandlers_.empty()) { 27 | auto& handler = exampleHandlers_.back(); // each handler is popped (destroyed) on its own EventBase thread; `handler` is not used after the pop 28 | handler->getEventBase()->runImmediatelyOrRunInEventBaseThreadAndWait( 29 | [this] { 30 | // The evb should be performing a sequential consistency atomic 31 | // operation already, so we can bank on that to make sure the writes 32 | // propagate to all threads.
33 | exampleHandlers_.pop_back(); 34 | }); 35 | } 36 | } 37 | 38 | ExampleServerTransportFactory() {} 39 | 40 | quic::QuicServerTransport::Ptr make( 41 | folly::EventBase* evb, std::unique_ptr sock, 42 | const folly::SocketAddress&, 43 | std::shared_ptr 44 | ctx) noexcept override { 45 | CHECK_EQ(evb, sock->getEventBase()); // the socket must belong to the EventBase the transport is created on 46 | auto exampleHandler = std::make_unique(evb); 47 | auto transport = quic::QuicServerTransport::make(evb, std::move(sock), 48 | *exampleHandler, ctx); 49 | exampleHandler->setQuicSocket(transport); 50 | exampleHandlers_.push_back(std::move(exampleHandler)); 51 | return transport; 52 | } 53 | 54 | std::vector> exampleHandlers_; // owns the handlers created by make(); released in the destructor above 55 | 56 | private: 57 | }; 58 | 59 | class ExampleServer { 60 | public: 61 | explicit ExampleServer( 62 | const std::string& host = "::1", uint16_t port = 6666, 63 | CongestionControlType cc_algo = CongestionControlType::Cubic, 64 | std::shared_ptr ccFactory = 65 | std::make_shared()) 66 | : host_(host), port_(port), server_(QuicServer::createQuicServer()) { 67 | server_->setQuicServerTransportFactory( 68 | std::make_unique()); 69 | server_->setFizzContext(createTestServerCtx()); // fizz test certificate/key: for experiments only 70 | server_->setCongestionControllerFactory(ccFactory); 71 | TransportSettings settings; 72 | settings.defaultCongestionController = cc_algo; 73 | server_->setTransportSettings(settings); 74 | } 75 | 76 | void start() { 77 | // Create a SocketAddress and the default or passed in host.
78 | folly::SocketAddress addr(host_.c_str(), port_); 79 | server_->start(addr, 0); // NOTE(review): second argument presumably the worker count (0 = library default) — TODO confirm 80 | LOG(INFO) << "ExampleServer started at: " << addr.describe(); 81 | eventbase_.loopForever(); // blocks the calling thread 82 | } 83 | 84 | private: 85 | std::string host_; 86 | uint16_t port_; 87 | folly::EventBase eventbase_; 88 | std::shared_ptr server_; 89 | }; 90 | 91 | } // namespace traffic_gen 92 | } // namespace quic 93 | -------------------------------------------------------------------------------- /CMakeLists.txt: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | #
-Wno-unused-parameter 53 | -Wno-narrowing 54 | ) 55 | 56 | find_package(fmt REQUIRED) 57 | 58 | # Find GFlags 59 | SET(GFLAG_DEPENDENCIES "") 60 | find_package(gflags CONFIG QUIET) 61 | if (gflags_FOUND) 62 | message(STATUS "Found gflags from package config") 63 | if (TARGET gflags-shared) 64 | list(APPEND GFLAG_DEPENDENCIES gflags-shared) 65 | elseif (TARGET gflags) 66 | list(APPEND GFLAG_DEPENDENCIES gflags) 67 | else() 68 | message(FATAL_ERROR "Unable to determine the target name for the GFlags package.") 69 | endif() 70 | list(APPEND CMAKE_REQUIRED_LIBRARIES ${GFLAGS_LIBRARIES}) 71 | list(APPEND CMAKE_REQUIRED_INCLUDES ${GFLAGS_INCLUDE_DIR}) 72 | else() 73 | find_package(Gflags REQUIRED MODULE) 74 | list(APPEND CMAKE_REQUIRED_LIBRARIES ${LIBGFLAGS_LIBRARY}) 75 | list(APPEND CMAKE_REQUIRED_INCLUDES ${LIBGFLAGS_INCLUDE_DIR}) 76 | endif() 77 | 78 | # Find GMock and GTest. Required for linking some TestUtils. 79 | set(REQUIRED_LINK_DIRS "") 80 | if(BUILD_TESTS) 81 | enable_testing() 82 | list(APPEND CMAKE_REQUIRED_INCLUDES "${QUIC_ROOT}/_build/build/googletest/src/googletest/googlemock/include") 83 | list(APPEND CMAKE_REQUIRED_INCLUDES "${QUIC_ROOT}/_build/build/googletest/src/googletest/googletest/include") 84 | list(APPEND REQUIRED_LINK_DIRS "${QUIC_ROOT}/_build/deps/lib") 85 | list(APPEND CMAKE_REQUIRED_LIBRARIES "gmock") 86 | list(APPEND CMAKE_REQUIRED_LIBRARIES "gtest") 87 | endif() 88 | 89 | # Find PyTorch 90 | find_package(Torch REQUIRED) 91 | message(STATUS "Found PyTorch libs: ${TORCH_LIBRARIES}") 92 | 93 | include_directories( 94 | ${PROJECT_ROOT} 95 | ${CMAKE_REQUIRED_INCLUDES} 96 | ${PREFIX_PATH}/include 97 | ${THIRDPARTY_ROOT} 98 | ) 99 | 100 | link_directories( 101 | ${REQUIRED_LINK_DIRS} 102 | ${PREFIX_PATH}/lib 103 | ) 104 | 105 | link_libraries( 106 | ${CMAKE_REQUIRED_LIBRARIES} 107 | ) 108 | 109 | add_subdirectory(traffic_gen) 110 | add_subdirectory(congestion_control) 111 | add_subdirectory(third-party) 112 | 
-------------------------------------------------------------------------------- /congestion_control/CongestionControlEnvConfig.h: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) Facebook, Inc. and its affiliates. 3 | * All rights reserved. 4 | * 5 | * This source code is licensed under the license found in the 6 | * LICENSE file in the root directory of this source tree. 7 | * 8 | */ 9 | #pragma once 10 | 11 | #include 12 | #include 13 | #include 14 | 15 | #include 16 | 17 | #include "Utils.h" 18 | 19 | namespace quic { 20 | 21 | struct CongestionControlEnvConfig { 22 | /// Definitions 23 | 24 | enum class Mode : uint8_t { 25 | LOCAL = 0, // RL policy run locally 26 | REMOTE, // RL policy on a remote RL server 27 | RANDOM, // Simple env that takes random actions (for testing) 28 | FIXED, // Simple env that attempts to reach a fixed cwnd target (for 29 | // testing) 30 | }; 31 | 32 | // Type of aggregation to group state updates 33 | enum class Aggregation : uint8_t { 34 | TIME_WINDOW = 0, // Group state updates every X ms 35 | FIXED_WINDOW, // Group every Y state updates 36 | }; 37 | 38 | enum class ActionOp : uint8_t { 39 | NOOP = 0, 40 | ADD, 41 | SUB, 42 | MUL, 43 | DIV, 44 | }; 45 | 46 | /// Members 47 | 48 | Mode mode{Mode::LOCAL}; 49 | 50 | // PyTorch traced model file to load for local mode 51 | std::string modelFile{""}; 52 | 53 | // RL server address ("<host>:<port>" or "unix:<path>") for remote mode. 54 | std::string rpcAddress{"unix:/tmp/rl_server_path"}; 55 | 56 | // For use in training to uniquely identify an actor across episodic 57 | // connections to RL server. 58 | int64_t actorId{0}; 59 | 60 | // Index of the current job in the list of active jobs. -1 if undefined.
61 | int64_t jobId{-1}; 62 | 63 | Aggregation aggregation{Aggregation::TIME_WINDOW}; 64 | std::chrono::milliseconds windowDuration{100}; // Time window duration 65 | uint32_t windowSize{10}; // Fixed window size 66 | bool useStateSummary{true}; // Whether to use state summary instead of raw 67 | // states (auto-enabled for TIME_WINDOW). 68 | 69 | // Normalization factors for observation fields 70 | float normMs{100.0}; 71 | float normBytes{1000.0}; 72 | 73 | // Size of history (such as past actions) to include in observation 74 | uint32_t historySize{2}; 75 | 76 | // Default actions: [noop, cwnd / 2, cwnd - 10, cwnd + 10, cwnd * 2] 77 | quic::utils::vector> actions{ 78 | {ActionOp::NOOP, 0}, {ActionOp::DIV, 2}, {ActionOp::SUB, 10}, 79 | {ActionOp::ADD, 10}, {ActionOp::MUL, 2}, 80 | }; 81 | 82 | // Multipliers for reward components 83 | bool rewardLogRatio{false}; 84 | float throughputFactor{0.1}; 85 | float throughputLogOffset{1.0}; 86 | float delayFactor{0.01}; 87 | float delayLogOffset{1.0}; 88 | float packetLossFactor{0.0}; 89 | float packetLossLogOffset{1.0}; 90 | 91 | // Whether to use max delay within a window in reward (avg otherwise) 92 | bool maxDelayInReward{true}; 93 | 94 | // 'fixed' env mode only: the target cwnd value we want to reach 95 | uint32_t fixedCwnd{10}; 96 | 97 | /// RLCongestionController settings 98 | 99 | // Window duration used to compute the min RTT. 100 | std::chrono::microseconds minRTTWindowLength{kMinRTTWindowLength}; 101 | 102 | /// Helper functions 103 | 104 | /** 105 | * Actions should be specified as string of comma-separated items of the 106 | * format "<op><value>". <op> can be one of [+, -, *, /]. <value> can be any 107 | * float. The first item should be "0", which means NOOP action. 108 | * 109 | * Example: "0,/2,-10,+10,*2".
110 | */ 111 | void parseActionsFromString(const std::string &actionsStr); 112 | static ActionOp charToActionOp(const char op); 113 | }; 114 | 115 | } // namespace quic 116 | -------------------------------------------------------------------------------- /congestion_control/NetworkState.cpp: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) Facebook, Inc. and its affiliates. 3 | * All rights reserved. 4 | * 5 | * This source code is licensed under the license found in the 6 | * LICENSE file in the root directory of this source tree. 7 | * 8 | */ 9 | #include "NetworkState.h" 10 | 11 | #include 12 | 13 | namespace quic { 14 | 15 | torch::Tensor NetworkState::toTensor() const { 16 | torch::Tensor tensor = torch::empty({0}, torch::kFloat32); 17 | toTensor(tensor); 18 | return tensor; 19 | } 20 | 21 | void NetworkState::toTensor(torch::Tensor &tensor) const { 22 | toTensor({*this}, tensor); 23 | } 24 | 25 | torch::Tensor 26 | NetworkState::toTensor(const quic::utils::vector &states) { 27 | torch::Tensor tensor = torch::empty({0}, torch::kFloat32); 28 | toTensor(states, tensor); 29 | return tensor; 30 | } 31 | 32 | void NetworkState::toTensor(const quic::utils::vector &states, 33 | torch::Tensor &tensor) { 34 | if (states.empty()) { 35 | tensor.resize_({0}); 36 | return; 37 | } 38 | 39 | tensor.resize_({static_cast(states.size()), states[0].size()}); // One row per state; the column count is taken from states[0], so all states are assumed to have the same size. 40 | auto tensor_a = tensor.accessor(); 41 | for (int i = 0; i < tensor_a.size(0); ++i) { 42 | for (int j = 0; j < tensor_a.size(1); ++j) { 43 | tensor_a[i][j] = states[i][j]; 44 | } 45 | } 46 | } 47 | 48 | quic::utils::vector 49 | NetworkState::fromTensor(const torch::Tensor &tensor) { 50 | CHECK_EQ(tensor.dim(), 2); // Expects a 2-D tensor shaped (numStates, kNumFields). 51 | CHECK_EQ(tensor.sizes()[1], kNumFields); 52 | 53 | quic::utils::vector states; 54 | auto tensor_a = tensor.accessor(); 55 | for (int i = 0; i < tensor_a.size(0); ++i) { 56 | NetworkState state; 57 | for (int j = 0; j < tensor_a.size(1); ++j) { 58
| state[j] = tensor_a[i][j]; 59 | } 60 | states.push_back(std::move(state)); 61 | } 62 | return states; 63 | } 64 | 65 | std::string NetworkState::fieldToString(const uint16_t field) { 66 | return fieldToString(static_cast(field)); 67 | } 68 | 69 | std::string NetworkState::fieldToString(const Field field) { 70 | switch (field) { 71 | case Field::RTT_MIN: 72 | return "rtt_min"; 73 | case Field::RTT_STANDING: 74 | return "rtt_standing"; 75 | case Field::LRTT: 76 | return "lrtt"; 77 | case Field::SRTT: 78 | return "srtt"; 79 | case Field::RTT_VAR: 80 | return "rtt_var"; 81 | case Field::DELAY: 82 | return "delay"; 83 | case Field::CWND: 84 | return "cwnd"; 85 | case Field::IN_FLIGHT: 86 | return "in_flight"; 87 | case Field::WRITABLE: 88 | return "writable"; 89 | case Field::SENT: 90 | return "sent"; 91 | case Field::RECEIVED: 92 | return "received"; 93 | case Field::RETRANSMITTED: 94 | return "retransmitted"; 95 | case Field::PTO_COUNT: 96 | return "pto_count"; 97 | case Field::TOTAL_PTO_DELTA: 98 | return "total_pto_delta"; 99 | case Field::RTX_COUNT: 100 | return "rtx_count"; 101 | case Field::TIMEOUT_BASED_RTX_COUNT: 102 | return "timeout_based_rtx_count"; 103 | case Field::ACKED: 104 | return "acked"; 105 | case Field::THROUGHPUT: 106 | return "throughput"; 107 | case Field::LOST: 108 | return "lost"; 109 | case Field::PERSISTENT_CONGESTION: 110 | return "persistent_congestion"; 111 | case Field::NUM_FIELDS: 112 | return "num_fields"; 113 | default: 114 | LOG(FATAL) << "Unknown field"; 115 | break; 116 | } 117 | __builtin_unreachable(); // LOG(FATAL) above aborts, so control never falls through here. 118 | } 119 | 120 | std::ostream &operator<<(std::ostream &os, const NetworkState &state) { 121 | os << "NetworkState (" << state.size() << " fields): "; 122 | for (size_t i = 0; i < state.size(); ++i) { 123 | os << NetworkState::fieldToString(i) << "=" << state[i] << " "; 124 | } 125 | return os; 126 | } 127 | 128 | } // namespace quic 129 | --------------------------------------------------------------------------------
/congestion_control/CongestionControlLocalEnv.cpp: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) Facebook, Inc. and its affiliates. 3 | * All rights reserved. 4 | * 5 | * This source code is licensed under the license found in the 6 | * LICENSE file in the root directory of this source tree. 7 | * 8 | */ 9 | #include "CongestionControlLocalEnv.h" 10 | 11 | namespace quic { 12 | 13 | namespace { 14 | // This should be train/learner.py --hidden_size + 1 15 | const int kLSTMHiddenSize = 1024 + 1; 16 | } 17 | 18 | CongestionControlLocalEnv::CongestionControlLocalEnv( 19 | const Config &cfg, Callback *cob, const QuicConnectionStateBase &conn) 20 | : CongestionControlEnv(cfg, cob, conn) { 21 | LOG(INFO) << "Loading traced model from " << cfg.modelFile; 22 | module_ = torch::jit::load(cfg.modelFile, at::kCPU); 23 | 24 | thread_ = 25 | std::make_unique(&CongestionControlLocalEnv::loop, this); 26 | } 27 | 28 | CongestionControlLocalEnv::~CongestionControlLocalEnv() { 29 | shutdown_ = true; // NOTE(review): set and notified without holding mutex_; if loop() has just evaluated the wait predicate this notify can be missed and join() below would block. Confirm shutdown_ is atomic and consider a lock/unlock of mutex_ before notifying (taking the lock for longer risks deadlock with onAction()). 30 | cv_.notify_all(); 31 | thread_->join(); 32 | } 33 | 34 | void CongestionControlLocalEnv::onObservation(Observation &&obs, float reward) { 35 | std::unique_lock lock(mutex_, std::try_to_lock); 36 | if (!lock) { 37 | LOG(WARNING) << __func__ << ": Still waiting for an update from model, " 38 | "skipping observation"; 39 | return; 40 | } 41 | obs.toTensor(tensor_); 42 | reward_ = reward; 43 | observationReady_ = true; 44 | lock.unlock(); 45 | cv_.notify_one(); 46 | } 47 | 48 | void CongestionControlLocalEnv::loop() { 49 | Action action; 50 | bool done = true; 51 | uint32_t episode_step = 0; 52 | float episode_return = 0.0; 53 | std::unique_lock lock(mutex_); 54 | 55 | // Initialize LSTM core state with zeros 56 | auto core_state = at::ivalue::Tuple::create( 57 | {torch::zeros({1, kLSTMHiddenSize}, at::kFloat), 58 | torch::zeros({1, kLSTMHiddenSize}, at::kFloat)}); 59 | 60 | while (!shutdown_) { 61 | cv_.wait(lock, [&]() ->
bool { return (observationReady_ || shutdown_); }); 62 | if (shutdown_) { 63 | break; 64 | } 65 | 66 | done = (episode_step == 0); 67 | episode_return += reward_; 68 | VLOG(2) << "Episode step = " << episode_step 69 | << ", total return = " << episode_return; 70 | 71 | // env_inputs: (obs, reward, done) 72 | auto reward_tensor = torch::from_blob(&reward_, {1}, at::kFloat); 73 | auto done_tensor = torch::from_blob(&done, {1}, at::kBool); 74 | auto env_inputs = at::ivalue::Tuple::create({tensor_.reshape({1, -1}), 75 | std::move(reward_tensor), 76 | std::move(done_tensor)}); 77 | 78 | // inputs: (last_action, (obs, reward, done), core_state) 79 | int64_t last_action = action.cwndAction; // from_blob does not copy and a kLong tensor reads 8 bytes, but cwndAction is a uint32_t: widen into a local that outlives forward() below. 80 | auto last_action_tensor = torch::from_blob(&last_action, {1}, at::kLong); 81 | quic::utils::vector inputs{std::move(last_action_tensor), 82 | std::move(env_inputs), 83 | std::move(core_state)}; 84 | const auto &outputs = module_.forward(inputs).toTuple(); 85 | 86 | // output: (action, core_state) 87 | const auto &action_tensor = outputs->elements()[0].toTensor(); 88 | core_state = outputs->elements()[1].toTuple(); 89 | 90 | action.cwndAction = *action_tensor.data_ptr(); 91 | 92 | // If there is an ongoing shutdown, it is important not to trigger the action 93 | // because `onAction()` calls `runImmediatelyOrRunInEventBaseThreadAndWait()` 94 | // and this method will hang forever during shutdown, preventing the thread from 95 | // exiting cleanly.
96 | if (!shutdown_) { 97 | onAction(action); 98 | } else { 99 | LOG(INFO) << "Skipping action due to shutdown in progress"; 100 | } 101 | 102 | episode_step++; 103 | observationReady_ = false; // Back to waiting 104 | } 105 | 106 | LOG(INFO) << "Inference loop terminating after " << episode_step 107 | << " steps, total return = " << episode_return; 108 | } 109 | 110 | } // namespace quic 111 | -------------------------------------------------------------------------------- /congestion_control/RLBandwidthSampler.cpp: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) Facebook, Inc. and its affiliates. 3 | * All rights reserved. 4 | * 5 | * This source code is licensed under the license found in the 6 | * LICENSE file in the root directory of this source tree. 7 | * 8 | */ 9 | 10 | #include "RLBandwidthSampler.h" 11 | 12 | namespace quic { 13 | 14 | using namespace std::chrono; 15 | 16 | RLBandwidthSampler::RLBandwidthSampler(QuicConnectionStateBase &conn) 17 | : conn_(conn), ackBytes_(kBandwidthWindowLength, 0), 18 | ackTimes_(kBandwidthWindowLength, Clock::now()), 19 | lastEntryInitialTime_(Clock::now()) { 20 | // Compute the min interval required to respect the min window duration. 21 | const uint64_t minWin = minWindowDuration_.count(); 22 | const uint64_t winSize = ackBytes_.size(); 23 | // Ceiling of integer division, see e.g. 24 | // https://stackoverflow.com/questions/2745074/fast-ceiling-of-an-integer-division-in-c-c 25 | minIntervalBetweenAcks_ = 26 | std::chrono::microseconds(minWin / winSize + (minWin % winSize != 0)); 27 | VLOG(10) << __func__ << ": minWindowDuration = " << minWindowDuration_.count() 28 | << ", window size = " << ackBytes_.size() 29 | << ", minIntervalBetweenAcks = " << minIntervalBetweenAcks_.count(); 30 | } 31 | 32 | Bandwidth RLBandwidthSampler::getBandwidth() const noexcept { 33 | // Compute current window duration, lower bounded by its min allowed value.
34 | const uint64_t previousIdx = getPreviousIdx(); 35 | std::chrono::microseconds windowDuration = 36 | duration_cast(ackTimes_[previousIdx] - ackTimes_[ackIdx_]); 37 | DCHECK(windowDuration.count() >= 0); 38 | windowDuration = std::max(windowDuration, minWindowDuration_); 39 | 40 | uint64_t ackBytes = totalAckBytes_; 41 | 42 | if (gotFirstAck_) { 43 | // Check if we have not received any ACK packet for a while (= for a 44 | // duration greater than the current window duration). If that is the case 45 | // then we linearly decrease the bandwidth to zero over a period equal to 46 | // the current window duration. 47 | const std::chrono::microseconds timeSinceLastAck = 48 | duration_cast(Clock::now() - 49 | ackTimes_[previousIdx]); 50 | 51 | if (timeSinceLastAck > windowDuration) { 52 | // Linearly decrease the bandwidth to zero over `windowDuration`. 53 | const float scale = (timeSinceLastAck - windowDuration).count() / 54 | static_cast(windowDuration.count()); 55 | const uint64_t bytesToRemove = static_cast(ackBytes * scale); 56 | ackBytes = ackBytes > bytesToRemove ? ackBytes - bytesToRemove : 0; 57 | } 58 | } 59 | 60 | VLOG(10) << __func__ << ": Computing bandwidth based on " << ackBytes 61 | << " acknowledged bytes over " << (windowDuration.count() / 1000) 62 | << " ms"; 63 | 64 | return Bandwidth(ackBytes, windowDuration); 65 | } 66 | 67 | void RLBandwidthSampler::onPacketAcked( 68 | const CongestionController::AckEvent &ackEvent, uint64_t rttCounter) { 69 | 70 | if (!gotFirstAck_) { 71 | // First ACK: we use it to initialize timestamps but ignore acked bytes, as 72 | // it is difficult to obtain a meaningful bandwidth estimate from one ACK. 73 | std::fill(ackTimes_.begin(), ackTimes_.end(), ackEvent.ackTime); 74 | lastEntryInitialTime_ = ackEvent.ackTime; 75 | gotFirstAck_ = true; 76 | return; 77 | } 78 | 79 | // Update number of acked bytes.
80 | totalAckBytes_ += ackEvent.ackedBytes; 81 | 82 | // We will update rolling window based on how close we are to the previous 83 | // entry. 84 | const std::chrono::microseconds lastInterval = 85 | duration_cast(ackEvent.ackTime - 86 | lastEntryInitialTime_); 87 | DCHECK(lastInterval.count() >= 0); 88 | 89 | if (lastInterval >= minIntervalBetweenAcks_) { 90 | // Large enough interval between events: create a new entry. 91 | 92 | // Remove oldest acked bytes. 93 | totalAckBytes_ -= ackBytes_[ackIdx_]; 94 | // Update initial timestamp for this new entry. 95 | lastEntryInitialTime_ = ackEvent.ackTime; 96 | // Store new entry. 97 | ackBytes_[ackIdx_] = ackEvent.ackedBytes; 98 | ackTimes_[ackIdx_] = ackEvent.ackTime; 99 | // Update rolling window index. 100 | ackIdx_ = (ackIdx_ + 1) % ackBytes_.size(); 101 | } else { 102 | // New event is too close: combine it with the previous entry. 103 | const uint64_t previousIdx = getPreviousIdx(); 104 | ackBytes_[previousIdx] += ackEvent.ackedBytes; 105 | ackTimes_[previousIdx] = ackEvent.ackTime; 106 | } 107 | } 108 | 109 | } // namespace quic -------------------------------------------------------------------------------- /traffic_gen/ExampleHandler.h: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) Facebook, Inc. and its affiliates. 3 | * All rights reserved. 4 | * 5 | * This source code is licensed under the license found in the 6 | * LICENSE file in the root directory of this source tree.
7 | * 8 | */ 9 | #pragma once 10 | 11 | #include 12 | 13 | #include 14 | #include 15 | #include 16 | 17 | DECLARE_int32(chunk_size); 18 | 19 | namespace quic { 20 | namespace traffic_gen { 21 | 22 | class ExampleHandler : public quic::QuicSocket::ConnectionCallback, 23 | public quic::QuicSocket::ReadCallback, 24 | public quic::QuicSocket::WriteCallback { 25 | public: 26 | using StreamData = std::pair; 27 | 28 | explicit ExampleHandler(folly::EventBase* evbIn) : evb(evbIn) { 29 | // Create dummy data to send 30 | std::string data(FLAGS_chunk_size, 'x'); 31 | respBuf_ = folly::IOBuf::copyBuffer(data); 32 | } 33 | 34 | void setQuicSocket(std::shared_ptr socket) { 35 | sock = socket; 36 | } 37 | 38 | void onNewBidirectionalStream(quic::StreamId id) noexcept override { 39 | LOG(INFO) << "Got bidirectional stream id=" << id; 40 | sock->setReadCallback(id, this); 41 | } 42 | 43 | void onNewUnidirectionalStream(quic::StreamId id) noexcept override { 44 | LOG(INFO) << "Got unidirectional stream id=" << id; 45 | sock->setReadCallback(id, this); 46 | } 47 | 48 | void onStopSending(quic::StreamId id, 49 | quic::ApplicationErrorCode error) noexcept override { 50 | LOG(INFO) << "Got StopSending stream id=" << id << " error=" << error; 51 | } 52 | 53 | void onConnectionEnd() noexcept override { LOG(INFO) << "Socket closed"; } 54 | 55 | void onConnectionError( 56 | std::pair error) noexcept override { 57 | LOG(ERROR) << "Socket error=" << toString(error.first); 58 | } 59 | 60 | void readAvailable(quic::StreamId id) noexcept override { 61 | LOG(INFO) << "read available for stream id=" << id; 62 | 63 | auto res = sock->read(id, 0); 64 | if (res.hasError()) { 65 | LOG(ERROR) << "Got error=" << toString(res.error()); 66 | return; 67 | } 68 | if (input_.find(id) == input_.end()) { 69 | input_.emplace( 70 | id, 71 | std::make_pair( 72 | folly::IOBufQueue(folly::IOBufQueue::cacheChainLength()), false)); 73 | } 74 | quic::Buf data = std::move(res.value().first); 75 | bool eof =
res.value().second; 76 | auto dataLen = (data ? data->computeChainDataLength() : 0); 77 | LOG(INFO) << "Got len=" << dataLen << " eof=" << uint32_t(eof) 78 | << " total=" << input_[id].first.chainLength() + dataLen 79 | << " data=" << (data ? data->clone()->moveToFbString().toStdString() : std::string()); // data may be null (see dataLen above), so guard the dereference. 80 | input_[id].first.append(std::move(data)); 81 | input_[id].second = eof; 82 | if (eof) { 83 | response(id, input_[id]); 84 | } 85 | } 86 | 87 | void readError( 88 | quic::StreamId id, 89 | std::pair> 90 | error) noexcept override { 91 | LOG(ERROR) << "Got read error on stream=" << id 92 | << " error=" << toString(error); 93 | // A read error only terminates the ingress portion of the stream state. 94 | // Your application should probably terminate the egress portion via 95 | // resetStream 96 | } 97 | 98 | void response(quic::StreamId id, StreamData& data) { 99 | auto responseData = respBuf_->clone(); 100 | bool eof = false; 101 | auto res = 102 | sock->writeChain(id, std::move(responseData), eof, false, nullptr); 103 | if (res.hasError()) { 104 | LOG(ERROR) << "write error=" << toString(res.error()); 105 | } else { 106 | sock->notifyPendingWriteOnStream(id, this); 107 | } 108 | } 109 | 110 | void onStreamWriteReady(quic::StreamId id, 111 | uint64_t maxToSend) noexcept override { 112 | VLOG(2) << "Socket is write ready with maxToSend=" << maxToSend; 113 | response(id, input_[id]); 114 | } 115 | 116 | void onStreamWriteError( 117 | quic::StreamId id, 118 | std::pair> 119 | error) noexcept override { 120 | LOG(ERROR) << "write error with stream=" << id 121 | << " error=" << toString(error); 122 | } 123 | 124 | folly::EventBase* getEventBase() { return evb; } 125 | 126 | folly::EventBase* evb; 127 | std::shared_ptr sock; 128 | 129 | private: 130 | std::map input_; 131 | std::unique_ptr respBuf_; 132 | }; 133 | 134 | } // namespace traffic_gen 135 | } // namespace quic 136 | --------------------------------------------------------------------------------
/third-party/gala/distributions.py: -------------------------------------------------------------------------------- 1 | """ 2 | MIT License 3 | 4 | Copyright (c) 2017 Ilya Kostrikov 5 | 6 | Permission is hereby granted, free of charge, to any person obtaining a copy 7 | of this software and associated documentation files (the "Software"), to deal 8 | in the Software without restriction, including without limitation the rights 9 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 10 | copies of the Software, and to permit persons to whom the Software is 11 | furnished to do so, subject to the following conditions: 12 | 13 | The above copyright notice and this permission notice shall be included in all 14 | copies or substantial portions of the Software. 15 | 16 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 17 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 18 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 19 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 20 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 21 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 22 | SOFTWARE. 23 | 24 | Taken from https://github.com/ikostrikov/pytorch-a2c-ppo-acktr-gail 25 | and slightly modified. 26 | """ 27 | import torch 28 | import torch.nn as nn 29 | 30 | 31 | # Necessary for my KFAC implementation.
32 | class AddBias(nn.Module): 33 | def __init__(self, bias): 34 | super(AddBias, self).__init__() 35 | self._bias = nn.Parameter(bias.unsqueeze(1)) 36 | 37 | def forward(self, x): 38 | if x.dim() == 2: 39 | bias = self._bias.t().view(1, -1) 40 | else: 41 | bias = self._bias.t().view(1, -1, 1, 1) 42 | 43 | return x + bias 44 | 45 | 46 | def init(module, weight_init, bias_init, gain=1): 47 | weight_init(module.weight.data, gain=gain) 48 | bias_init(module.bias.data) 49 | return module 50 | 51 | 52 | """ 53 | Modify standard PyTorch distributions so they are compatible with this code. 54 | """ 55 | 56 | # 57 | # Standardize distribution interfaces 58 | # 59 | 60 | # Categorical 61 | FixedCategorical = torch.distributions.Categorical # NOTE(review): this is an alias, not a subclass, so the attribute assignments below (and the Normal/Bernoulli ones) patch the torch.distributions classes for the whole process, and re-executing this module would wrap the already-patched methods again. 62 | 63 | old_sample = FixedCategorical.sample 64 | FixedCategorical.sample = lambda self: old_sample(self).unsqueeze(-1) 65 | 66 | log_prob_cat = FixedCategorical.log_prob 67 | FixedCategorical.log_probs = lambda self, actions: log_prob_cat( 68 | self, actions.squeeze(-1)).view(actions.size(0), -1).sum(-1).unsqueeze(-1) 69 | 70 | FixedCategorical.mode = lambda self: self.probs.argmax(dim=-1, keepdim=True) 71 | 72 | # Normal 73 | FixedNormal = torch.distributions.Normal 74 | 75 | log_prob_normal = FixedNormal.log_prob 76 | FixedNormal.log_probs = lambda self, actions: log_prob_normal( 77 | self, actions).sum( 78 | -1, keepdim=True) 79 | 80 | normal_entropy = FixedNormal.entropy 81 | FixedNormal.entropy = lambda self: normal_entropy(self).sum(-1) 82 | 83 | FixedNormal.mode = lambda self: self.mean 84 | 85 | # Bernoulli 86 | FixedBernoulli = torch.distributions.Bernoulli 87 | 88 | log_prob_bernoulli = FixedBernoulli.log_prob 89 | FixedBernoulli.log_probs = lambda self, actions: log_prob_bernoulli( 90 | self, actions).view(actions.size(0), -1).sum(-1).unsqueeze(-1) 91 | 92 | bernoulli_entropy = FixedBernoulli.entropy 93 | FixedBernoulli.entropy = lambda self: bernoulli_entropy(self).sum(-1) 94 | FixedBernoulli.mode = lambda self:
torch.gt(self.probs, 0.5).float() 95 | 96 | 97 | class Categorical(nn.Module): 98 | def __init__(self, num_inputs, num_outputs): 99 | super(Categorical, self).__init__() 100 | 101 | init_ = lambda m: init( 102 | m, 103 | nn.init.orthogonal_, 104 | lambda x: nn.init.constant_(x, 0), 105 | gain=0.01) 106 | 107 | self.linear = init_(nn.Linear(num_inputs, num_outputs)) 108 | 109 | def forward(self, x): 110 | x = self.linear(x) 111 | return FixedCategorical(logits=x) 112 | 113 | 114 | class DiagGaussian(nn.Module): 115 | def __init__(self, num_inputs, num_outputs): 116 | super(DiagGaussian, self).__init__() 117 | 118 | init_ = lambda m: init(m, nn.init.orthogonal_, lambda x: nn.init. 119 | constant_(x, 0)) 120 | 121 | self.fc_mean = init_(nn.Linear(num_inputs, num_outputs)) 122 | self.logstd = AddBias(torch.zeros(num_outputs)) 123 | 124 | def forward(self, x): 125 | action_mean = self.fc_mean(x) 126 | 127 | # An ugly hack for my KFAC implementation. 128 | zeros = torch.zeros(action_mean.size()) # NOTE(review): built as a default-dtype CPU tensor and moved only when x is on CUDA; torch.zeros_like(action_mean) would also track dtype and any other device. 129 | if x.is_cuda: 130 | zeros = zeros.cuda() 131 | 132 | action_logstd = self.logstd(zeros) 133 | return FixedNormal(action_mean, action_logstd.exp()) 134 | 135 | 136 | class Bernoulli(nn.Module): 137 | def __init__(self, num_inputs, num_outputs): 138 | super(Bernoulli, self).__init__() 139 | 140 | init_ = lambda m: init(m, nn.init.orthogonal_, lambda x: nn.init. 141 | constant_(x, 0)) 142 | 143 | self.linear = init_(nn.Linear(num_inputs, num_outputs)) 144 | 145 | def forward(self, x): 146 | x = self.linear(x) 147 | return FixedBernoulli(logits=x) 148 | -------------------------------------------------------------------------------- /congestion_control/CongestionControlEnv.h: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) Facebook, Inc. and its affiliates. 3 | * All rights reserved. 4 | * 5 | * This source code is licensed under the license found in the 6 | * LICENSE file in the root directory of this source tree.
7 | * 8 | */ 9 | #pragma once 10 | 11 | #include 12 | #include 13 | #include 14 | #include 15 | #include 16 | 17 | #include 18 | #include 19 | #include 20 | 21 | #include "CongestionControlEnvConfig.h" 22 | #include "NetworkState.h" 23 | #include "Utils.h" 24 | 25 | namespace quic { 26 | 27 | class CongestionControlEnv { 28 | public: 29 | using Config = CongestionControlEnvConfig; 30 | 31 | struct Callback { 32 | virtual ~Callback() = default; 33 | virtual void onUpdate(const uint64_t &cwndBytes) noexcept = 0; 34 | }; 35 | 36 | struct Action { 37 | // This assumes that the policy has a no-op action at index 0 38 | uint32_t cwndAction{0}; // Action index (32-bit); presumably an index into Config::actions. 39 | }; 40 | 41 | struct History { 42 | Action action; // Past action taken 43 | float cwnd; // Normalized cwnd after applying the action 44 | 45 | History(const Action &a, const float c) : action(a), cwnd(c) {} 46 | }; 47 | 48 | struct Observation { 49 | public: 50 | Observation(const Config &cfg) : cfg_(cfg) {} 51 | 52 | torch::Tensor toTensor() const; 53 | void toTensor(torch::Tensor &tensor) const; 54 | 55 | quic::utils::vector states; 56 | quic::utils::vector history; 57 | 58 | private: 59 | const Config &cfg_; 60 | }; 61 | 62 | CongestionControlEnv(const Config &cfg, Callback *cob, 63 | const QuicConnectionStateBase &conn); 64 | virtual ~CongestionControlEnv() = default; 65 | 66 | /** 67 | * To be invoked by whoever owns CongestionControlEnv (such as 68 | * RLCongestionController) to share network state updates after every 69 | * Ack/Loss event. 70 | */ 71 | void onNetworkState(NetworkState &&state); 72 | 73 | inline const Config &config() const { return cfg_; } 74 | inline float normMs() const { return cfg_.normMs; } 75 | inline float normBytes() const { return cfg_.normBytes; } 76 | 77 | protected: 78 | /** 79 | * onObservation() will be triggered when there are enough state updates to 80 | * run the policy and predict an action.
Subclasses should implement this 81 | * and return the action via onAction() callback, either synchronously or 82 | * asynchronously. 83 | */ 84 | virtual void onObservation(Observation &&obs, float reward) = 0; 85 | 86 | /** 87 | * Callback to be invoked by subclasses when there is an update 88 | * following onObservation(). 89 | */ 90 | void onAction(const Action &action); 91 | 92 | /** 93 | * Return the updated value of cwnd after applying a specific action. 94 | */ 95 | uint64_t getUpdatedCwndBytes(uint64_t currentCwndBytes, 96 | uint32_t actionIdx) const; 97 | 98 | const Config &cfg_; 99 | 100 | private: 101 | class ObservationTimeout : public folly::HHWheelTimer::Callback { 102 | public: 103 | explicit ObservationTimeout(CongestionControlEnv *env, 104 | folly::EventBase *evb) 105 | : env_(CHECK_NOTNULL(env)), evb_(CHECK_NOTNULL(evb)) {} 106 | ~ObservationTimeout() override = default; 107 | 108 | void schedule(const std::chrono::milliseconds &timeoutMs) noexcept { 109 | evb_->timer().scheduleTimeout(this, timeoutMs); 110 | } 111 | 112 | void timeoutExpired() noexcept override { 113 | env_->observationTimeoutExpired(); 114 | } 115 | 116 | void callbackCanceled() noexcept override { return; } 117 | 118 | private: 119 | CongestionControlEnv *env_; 120 | folly::EventBase *evb_; 121 | }; 122 | 123 | void observationTimeoutExpired() noexcept; 124 | void handleStates(); 125 | float computeReward(const quic::utils::vector &states) const; 126 | void updateCwnd(const uint32_t actionIdx); 127 | 128 | inline bool useStateSummary() const { 129 | return cfg_.useStateSummary || 130 | (cfg_.aggregation == Config::Aggregation::TIME_WINDOW); 131 | } 132 | 133 | /** 134 | * Compute sum, mean, std, min, max for each field.
135 | */ 136 | quic::utils::vector 137 | stateSummary(const quic::utils::vector &states); 138 | 139 | Callback *cob_{nullptr}; 140 | const QuicConnectionStateBase &conn_; 141 | folly::EventBase *evb_{nullptr}; 142 | ObservationTimeout observationTimeout_; 143 | 144 | uint64_t cwndBytes_; // NOTE(review): no default member initializer; confirm the constructor sets it. 145 | quic::utils::vector states_; 146 | std::deque history_; 147 | 148 | // Keep track of running statistics on rewards. 149 | uint64_t rewardCount_; // NOTE(review): like rewardSum_ below, not default-initialized here; confirm the constructor zeroes both. 150 | float rewardSum_; 151 | 152 | // Intermediate tensor to compute state summary 153 | torch::Tensor summaryTensor_{torch::empty({0}, torch::kFloat32)}; 154 | 155 | std::chrono::time_point lastObservationTime_; 156 | }; 157 | 158 | std::ostream &operator<<(std::ostream &os, 159 | const CongestionControlEnv::Observation &observation); 160 | std::ostream &operator<<(std::ostream &os, 161 | const CongestionControlEnv::History &history); 162 | 163 | } // namespace quic 164 | -------------------------------------------------------------------------------- /third-party/gala/arguments.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree.
6 | # 7 | import argparse 8 | 9 | import torch 10 | 11 | 12 | def get_args(arg_dict=None): 13 | parser = argparse.ArgumentParser(description='RL') 14 | 15 | parser.add_argument( 16 | '--sync-freq', 17 | type=int, 18 | default=0, 19 | help='max amount of message staleness for local gossip') 20 | parser.add_argument( 21 | '--num-learners', 22 | type=int, 23 | default=1, 24 | help='number of learners to stack on device') 25 | parser.add_argument( 26 | '--num-peers', 27 | type=int, 28 | default=1, 29 | help='number of peers to communicate with in each iteration') 30 | parser.add_argument( 31 | '--lr', 32 | type=float, 33 | default=7e-4, 34 | help='learning rate (default: 7e-4)') 35 | parser.add_argument( 36 | '--eps', 37 | type=float, 38 | default=1e-5, 39 | help='RMSprop optimizer epsilon (default: 1e-5)') 40 | parser.add_argument( 41 | '--alpha', 42 | type=float, 43 | default=0.99, 44 | help='RMSprop optimizer apha (default: 0.99)') 45 | parser.add_argument( 46 | '--gamma', 47 | type=float, 48 | default=0.99, 49 | help='discount factor for rewards (default: 0.99)') 50 | parser.add_argument( 51 | '--use-gae', 52 | action='store_true', 53 | default=False, 54 | help='use generalized advantage estimation') 55 | parser.add_argument( 56 | '--gae-lambda', 57 | type=float, 58 | default=0.95, 59 | help='gae lambda parameter (default: 0.95)') 60 | parser.add_argument( 61 | '--entropy-coef', 62 | type=float, 63 | default=0.01, 64 | help='entropy term coefficient (default: 0.01)') 65 | parser.add_argument( 66 | '--value-loss-coef', 67 | type=float, 68 | default=0.5, 69 | help='value loss coefficient (default: 0.5)') 70 | parser.add_argument( 71 | '--max-grad-norm', 72 | type=float, 73 | default=0.5, 74 | help='max norm of gradients (default: 0.5)') 75 | parser.add_argument( 76 | '--seed', 77 | type=int, 78 | default=1, 79 | help='random seed (default: 1)') 80 | parser.add_argument( 81 | '--cuda-deterministic', 82 | action='store_true', 83 | default=False, 84 | help="sets flags
for determinism when using CUDA (potentially slow!)") 85 | parser.add_argument( 86 | '--num-procs-per-learner', 87 | type=int, 88 | default=16, 89 | help='num simulators per learner (default: 16)') 90 | parser.add_argument( 91 | '--max-steps', 92 | type=int, 93 | default=int(10e3), 94 | help='max episode length (default: 10,000)') 95 | parser.add_argument( 96 | '--num-steps-per-update', 97 | type=int, 98 | default=5, 99 | help='number of forward steps in A2C (default: 5)') 100 | parser.add_argument( 101 | '--clip-param', 102 | type=float, 103 | default=0.2, 104 | help='ppo clip parameter (default: 0.2)') 105 | parser.add_argument( 106 | '--log-interval', 107 | type=int, 108 | default=10, 109 | help='log interval, measured in environment steps (default: 10)') 110 | parser.add_argument( 111 | '--save-interval', 112 | type=int, 113 | default=100, 114 | help='save interval, measured in environment steps (default: 100)') 115 | parser.add_argument( 116 | '--num-env-steps', 117 | type=int, 118 | default=10e6, 119 | help='number of total environment steps to train (default: 10e6)') 120 | parser.add_argument( 121 | '--env-name', 122 | default='PongNoFrameskip-v4', 123 | help='environment to train on (default: PongNoFrameskip-v4)') 124 | parser.add_argument( 125 | '--eval-log-dir', 126 | default='/tmp/gym/eval/', 127 | help='directory to save agent eval-logs (default: /tmp/gym/eval/)') 128 | parser.add_argument( 129 | '--log-dir', 130 | default='/tmp/gym/', 131 | help='directory to save agent logs (default: /tmp/gym)') 132 | parser.add_argument( 133 | '--save-dir', 134 | default='./trained_models/', 135 | help='directory to save agent logs (default: ./trained_models/)') 136 | parser.add_argument( 137 | '--cuda-device', 138 | type=int, 139 | default=0, 140 | help='index of cuda device to use') 141 | parser.add_argument( 142 | '--no-cuda', 143 | action='store_true', 144 | default=False, 145 | help='disables CUDA training') 146 | parser.add_argument( 147 | 
'--use-proper-time-limits', 148 | action='store_true', 149 | default=False, 150 | help='compute returns taking into account time limits') 151 | parser.add_argument( 152 | '--recurrent-policy', 153 | action='store_true', 154 | default=False, 155 | help='use a recurrent policy') 156 | parser.add_argument( 157 | '--use-linear-lr-decay', 158 | action='store_true', 159 | default=False, 160 | help='use a linear schedule on the learning rate') 161 | 162 | args = parser.parse_args(arg_dict) 163 | args.cuda = not args.no_cuda and torch.cuda.is_available() 164 | 165 | return args 166 | -------------------------------------------------------------------------------- /scripts/get_tperf_args.py: -------------------------------------------------------------------------------- 1 | #! /usr/bin/env python3 2 | 3 | # Copyright (c) Facebook, Inc. and its affiliates. 4 | # All rights reserved. 5 | # 6 | # This source code is licensed under the license found in the 7 | # LICENSE file in the root directory of this source tree. 8 | 9 | """ 10 | Utility script to obtain the arguments to be provided to tperf to test a trained model. 11 | 12 | Usage: 13 | get_tperf_args.py 14 | where `exp_folder` is the base experiment folder. 15 | """ 16 | 17 | import argparse 18 | import json 19 | import logging 20 | import os 21 | import subprocess 22 | import sys 23 | 24 | from pathlib import Path 25 | from typing import Any, Dict, List, Optional 26 | 27 | 28 | logger = logging.getLogger(__name__) 29 | 30 | TPERF_SETTINGS = [ 31 | "cc_env_mode", 32 | "cc_env_model_file", 33 | "cc_env_job_id", 34 | ] 35 | 36 | # These arguments are to be read from the experiment training configuration, to be 37 | # forwarded to the `tperf` executable. 
38 | TPERF_FLAGS = [ 39 | "cc_env_agg", 40 | "cc_env_time_window_ms", 41 | "cc_env_fixed_window_size", 42 | "cc_env_use_state_summary", 43 | "cc_env_history_size", 44 | "cc_env_norm_ms", 45 | "cc_env_norm_bytes", 46 | "cc_env_actions", 47 | "cc_env_reward_log_ratio", 48 | "cc_env_reward_throughput_factor", 49 | "cc_env_reward_throughput_log_offset", 50 | "cc_env_reward_delay_factor", 51 | "cc_env_reward_delay_log_offset", 52 | "cc_env_reward_packet_loss_factor", 53 | "cc_env_reward_packet_loss_log_offset", 54 | "cc_env_reward_max_delay", 55 | "cc_env_fixed_cwnd", 56 | "cc_env_min_rtt_window_length_us", 57 | ] 58 | 59 | 60 | def get_cc_env_actions(actions: List[str]) -> str: 61 | """ 62 | Obtain list of actions in the format suitable to tperf. 63 | 64 | In 'meta.json', actions are given as a list of strings, but somehow the "+" 65 | sign in front of a positive integer is lost. We restore it here. 66 | """ 67 | tokens = [] 68 | for action in actions: 69 | if action == "0" or any(action.startswith(c) for c in ["+", "-", "*", "/"]): 70 | tokens.append(action) 71 | else: 72 | # Verify that this is indeed a positive integer. 73 | try: 74 | int(action) 75 | except ValueError: 76 | raise NotImplementedError(f"Unsupported action: {action}") 77 | tokens.append(f"+{action}") 78 | return ",".join(tokens) 79 | 80 | 81 | def init_logger() -> None: 82 | """Initialize logger""" 83 | logger.addHandler(logging.StreamHandler(stream=sys.stdout)) 84 | logger.setLevel(logging.INFO) 85 | 86 | 87 | def parse_args() -> Any: 88 | """Parse commnad line arguments""" 89 | parser = argparse.ArgumentParser() 90 | parser.add_argument("path") 91 | return parser.parse_args() 92 | 93 | 94 | def to_cmd_line( 95 | flags: Dict[str, Any], exclude: Optional[List[str]] = None 96 | ) -> List[str]: 97 | """ 98 | Convert experiment flags into command-line arguments. 
99 | """ 100 | exclude = set() if exclude is None else set(exclude) 101 | args: List[str] = [f"{k}={to_str(v)}" for k, v in flags.items() if k not in exclude] 102 | return args 103 | 104 | 105 | def to_str(val: Any) -> str: 106 | """ 107 | Convert a value to its string representation to be used on the command line. 108 | """ 109 | if val is None: 110 | return "null" 111 | elif isinstance(val, str): 112 | return val 113 | elif isinstance(val, (int, float)): 114 | return str(val) 115 | elif isinstance(val, list): 116 | return f"[{','.join(to_str(v) for v in val)}]" 117 | else: 118 | raise NotImplementedError( 119 | f"Unsupported type '{type(val).__name__}' with value: `{val}`" 120 | ) 121 | 122 | 123 | def trace(path: Path, flags: Dict[str, Any]) -> None: 124 | """ 125 | Trace the trained model. 126 | """ 127 | cwd = Path.cwd() 128 | script_dir = Path(__file__).absolute().parent 129 | os.chdir(script_dir / "..") 130 | try: 131 | cmd_line = to_cmd_line(flags, exclude=["mode"]) 132 | cmd = ["python", "-m", "train.train", "mode=trace"] + cmd_line 133 | logger.info("Tracing model, command:\n " + " ".join(cmd)) 134 | subprocess.check_call(cmd) 135 | finally: 136 | os.chdir(cwd) 137 | 138 | 139 | def main() -> int: 140 | """ 141 | Script entry point. 142 | """ 143 | init_logger() 144 | args = parse_args() 145 | 146 | # Load settings. 147 | path = Path(args.path) 148 | meta_path = path / "train" / "meta.json" 149 | assert meta_path.is_file(), "meta.json not found!" 150 | with meta_path.open() as f: 151 | meta = json.load(f) 152 | flags = meta["flags"] 153 | 154 | # Check that the config is valid. 155 | assert not flags[ 156 | "use_job_id_in_actor" 157 | ], "providing the job ID is not currently supported" 158 | 159 | # Trace model if needed. 160 | traced_path = path / "traced_model.pt" 161 | if traced_path.exists(): 162 | logger.info("Traced model already exists") 163 | else: 164 | trace(path, flags) 165 | 166 | # Extract relevant tperf settings. 
167 | tperf_args = {f: flags[f] for f in TPERF_FLAGS} 168 | 169 | # Post-process / add arguments. 170 | tperf_args.update( 171 | { 172 | "congestion": "rl", 173 | "cc_env_job_id": -1, 174 | "cc_env_mode": "local", 175 | "cc_env_model_file": str(traced_path), 176 | "cc_env_actions": get_cc_env_actions(flags["cc_env_actions"]), 177 | } 178 | ) 179 | 180 | # Output tperf args. 181 | tperf_cmd = " ".join(f"-{k}='{to_str(v)}'" for k, v in tperf_args.items()) 182 | logger.info("tperf command line arguments:\n%s", tperf_cmd) 183 | 184 | return 0 185 | 186 | 187 | if __name__ == "__main__": 188 | sys.exit(main()) 189 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # mvfst-rl 2 | 3 | `mvfst-rl` is a framework for network congestion control in the QUIC transport protocol 4 | that leverages state-of-the-art in asynchronous Reinforcement Learning training with 5 | off-policy correction. It is built upon the following components: 6 | 7 | 1. [mvfst](https://github.com/facebookincubator/mvfst), an implementation of the IETF QUIC transport protocol. 8 | 2. [torchbeast](https://github.com/facebookresearch/torchbeast), a PyTorch implementation of asynchronous distributed deep RL. 9 | 3. [Pantheon](https://github.com/StanfordSNR/pantheon), a set of calibrated network emulators. 10 | 11 | ### MTEnv API 12 | 13 | If your objective is to experiment with new RL algorithms on congestion control tasks, you are encouraged 14 | to switch to the [mtenv](https://github.com/facebookresearch/mvfst-rl/tree/mtenv) branch. 15 | 16 | That branch implements in particular an [MTEnv](https://github.com/facebookresearch/mtenv)-compatible API 17 | that makes it easy to define a multi-task environment and interact with it, independently of the more complex 18 | IMPALA-based learning framework this project is based on. 
19 | 20 | ### Asynchronous RL Agent 21 | 22 | ![alt text](figures/rl_agent.png "RL Agent") 23 | 24 | 25 | ### Training Architecture 26 | 27 | ![alt text](figures/training_architecture.png "Training Architecture") 28 | 29 | 30 | For more details, please refer to our [paper](https://arxiv.org/abs/1910.04054). 31 | 32 | ## Building mvfst-rl 33 | 34 | ### Ubuntu 20+ 35 | 36 | Pantheon requires Python 2 while `mvfst-rl` training requires Python 3.8+. The recommended setup is to explicitly use python2/python3 commands. 37 | 38 | For building with training support, it is recommended to have a conda environment first: 39 | ```shell 40 | conda create -n mvfst-rl python=3.8 -y && conda activate mvfst-rl 41 | ./setup.sh 42 | ``` 43 | 44 | If you have a previous installation and need to re-install from scratch after updating 45 | the code, run the following commands: 46 | ```shell 47 | conda activate base && conda env remove -n mvfst-rl 48 | conda create -n mvfst-rl python=3.8 -y && conda activate mvfst-rl 49 | ./setup.sh --clean 50 | ``` 51 | 52 | For building `mvfst-rl` in test-only or deployment mode, run the following script. 53 | This allows you to run a trained model exported via TorchScript purely in C++. 54 | ``` 55 | ./setup.sh --inference 56 | ``` 57 | 58 | ## Training 59 | 60 | Training can be run locally as follows: 61 | ```shell 62 | python3 -m train.train \ 63 | mode=train \ 64 | total_steps=1_000_000 \ 65 | num_actors=40 \ 66 | hydra.run.dir=/tmp/logs 67 | ``` 68 | 69 | The above starts 40 Pantheon instances in parallel that communicate with the torchbeast actors via RPC. 70 | To see the full list of training parameters, run `python3 -m train.train --help`. 71 | 72 | ## Hyper-parameter sweeps with Hydra 73 | 74 | `mvfst-rl` uses [Hydra](https://hydra.cc/), which in particular makes it easy to run 75 | hyper-parameter sweeps. 
Here is an example showing how to run three experiments with 76 | different learning rates on a [Slurm](https://slurm.schedmd.com/overview.html) cluster: 77 | ```shell 78 | python3 -m train.train \ 79 | mode=train \ 80 | test_after_train=false \ 81 | total_steps=1_000_000 \ 82 | num_actors=40 \ 83 | learning_rate=1e-5,1e-4,1e-3 \ 84 | hydra.sweep.dir='${oc.env:HOME}/tmp/logs_${now:%Y-%m-%d_%H-%M-%S}' \ 85 | hydra/launcher=_submitit_slurm -m 86 | ``` 87 | 88 | Note the following settings in the above example: 89 | * `test_after_train=false` skips running the test mode after training. This can be useful 90 | for instance when the machines on the cluster have not been setup with all the libraries 91 | required in test mode. 92 | * `learning_rate=1e-5,1e-4,1e-3`: this is the basic syntax to perform a parameter sweep. 93 | * `hydra.sweep.dir='${oc.env:HOME}/tmp/logs_${now:%Y-%m-%d_%H-%M-%S}'`: the base location for all logs 94 | (look into the `.submitit` subfolder inside that directory to access the jobs' stdout/stderr). 95 | * `hydra/launcher=_submitit_slurm`: the launcher used to run on Slurm. Hydra supports more 96 | launchers, see its [documentation](https://hydra.cc/docs/intro) for details (by default, 97 | the [joblib](https://hydra.cc/docs/plugins/joblib_launcher) launcher is also installed 98 | by `setup.sh` -- it allows running multiple jobs locally instead of on a cluster). 99 | Note that the launcher name must be prefixed with an underscore to match the config files 100 | under `config/hydra/launcher` (which you may edit to tweak launcher settings). 101 | * `-m`: to run Hydra in [multi-run mode](https://hydra.cc/docs/next/tutorials/basic/running_your_app/multi-run/). 102 | 103 | ## Monitoring training behavior 104 | 105 | The script `scripts/plotting/plot_sweep.py` can be used to plot training curves. 106 | Refer to comments in the script's header for instructions on how to execute it. 
107 | 108 | It is also possible to use [TensorBoard](https://www.tensorflow.org/tensorboard): 109 | the data can be found in the `train/tensorboard` subfolder of an experiment's logs directory. 110 | 111 | 112 | ## Evaluation 113 | 114 | To test a trained model on all emulated Pantheon environments, run with `mode=test` as follows: 115 | ``` 116 | python3 -m train.train \ 117 | mode=test \ 118 | base_logdir=/tmp/logs 119 | ``` 120 | 121 | The above takes the `checkpoint.tar` file in `/tmp/logs`, traces the model via TorchScript, 122 | and runs inference in C++ (without RPC). 123 | 124 | ## Pantheon logs cleanup 125 | 126 | Pantheon generates temporary logs (in `_build/deps/pantheon/tmp`) that may take up a lot of space. 127 | It is advised to regularly run `scripts/clean_pantheon_logs.sh` to delete them (when no experiment is running). 128 | Note that when running jobs on a SLURM cluster, where a temporary local folder is made available to 129 | each job in `/scratch/slurm_tmpdir/$SLURM_JOB_ID`, this folder is used instead to store the logs 130 | (thus alleviating the need for manual cleanup). 131 | 132 | ## Contributing 133 | We would love to have you contribute to `mvfst-rl` or use it for your research. 134 | See the [CONTRIBUTING](CONTRIBUTING.md) file for how to help out. 135 | 136 | ## License 137 | mvfst-rl is licensed under the CC-BY-NC 4.0 license, as found in the LICENSE file. 138 | 139 | ## BibTeX 140 | 141 | ``` 142 | @article{mvfstrl2019, 143 | title={MVFST-RL: An Asynchronous RL Framework for Congestion Control with Delayed Actions}, 144 | author={Viswanath Sivakumar and Olivier Delalleau and Tim Rockt\"{a}schel and Alexander H. 
Miller and Heinrich K\"{u}ttler and Nantas Nardelli and Mike Rabbat and Joelle Pineau and Sebastian Riedel}, 145 | year={2019}, 146 | eprint={1910.04054}, 147 | archivePrefix={arXiv}, 148 | primaryClass={cs.LG}, 149 | url={https://arxiv.org/abs/1910.04054}, 150 | journal={NeurIPS Workshop on Machine Learning for Systems}, 151 | } 152 | ``` 153 | -------------------------------------------------------------------------------- /third-party/gala/gpu_gossip_buffer.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # 7 | """ 8 | Gossip Buffer 9 | 10 | :author: Mido Assran 11 | :description: Class defines a shared-memory Gossip-Buffer, which allows 12 | multi-processed asynchronous agents on the same machine to communicate 13 | tensors to on one-another 14 | """ 15 | 16 | import copy 17 | import torch 18 | import logging 19 | 20 | logging.basicConfig(level=logging.INFO) 21 | 22 | 23 | class GossipBuffer: 24 | def __init__( 25 | self, 26 | topology, 27 | model, 28 | buffer_locks, 29 | read_events, 30 | write_events, 31 | sync_list, 32 | sync_freq=0, 33 | ): 34 | """ GossipBuffer """ 35 | 36 | self.topology = topology 37 | self.num_learners = len(topology) 38 | self.sync_list = sync_list 39 | self.sync_freq = sync_freq 40 | 41 | # Initialize message buffer (4-item object): 42 | # [0] -> Msg-Tensor 43 | # [1] -> Events recording peers that have read the message 44 | # [2] -> Events recording peer that has written the message 45 | # [3] -> Lock for safe access of Msg-Tensor 46 | self.msg_buffer = [] 47 | for rank in range(self.num_learners): 48 | msg = copy.deepcopy(model) 49 | msg.share_memory() 50 | r_events = read_events[rank] 51 | w_events = write_events[rank] 52 | lock = buffer_locks[rank] 53 | 
self.msg_buffer.append([msg, r_events, w_events, lock]) 54 | 55 | # Initialize each Read-Buffer as 'read' 56 | for msg_buffer in self.msg_buffer: 57 | read_event_list = msg_buffer[1] 58 | for event in read_event_list: 59 | event.set() 60 | 61 | def write_message(self, rank, model, rotate=False): 62 | """ 63 | Write agent 'rank's 'model' to a local 'boradcast buffer' that will be 64 | read by the out-neighbours defined in 'self.topology'. 65 | 66 | :param rank: Agent's rank in multi-agent graph topology 67 | :param model: Agent's torch neural network model 68 | :param rotate: Whether to alternate peers in graph topology 69 | 70 | Agents should only write to their own broadcast buffer: 71 | i.e., ensure 'model' belongs to agent 'rank' 72 | WARNING: setting rotate=True with sync_freq > 1 not currently supported 73 | """ 74 | with torch.no_grad(): 75 | 76 | # Get local broadcast-buffer 77 | msg_buffer = self.msg_buffer[rank] 78 | broadcast_buffer = msg_buffer[0] 79 | read_event_list = msg_buffer[1] 80 | write_event_list = msg_buffer[2] 81 | lock = msg_buffer[3] 82 | 83 | # Check if out-peers finished reading our last message 84 | out_peers, _ = self.topology[rank].get_peers() 85 | read_complete = True 86 | for peer in out_peers: 87 | if not read_event_list[peer].is_set(): 88 | read_complete = False 89 | break 90 | 91 | # If peers done reading our last message, wait and clear events 92 | if read_complete: 93 | for peer in out_peers: 94 | read_event = read_event_list[peer] 95 | read_event.wait() 96 | read_event.clear() 97 | # If not done reading, cannot write another message right now 98 | else: 99 | return 100 | 101 | # Update broadcast-buffer with new message 102 | # -- flatten params and multiply by mixing-weight 103 | num_peers = self.topology[rank].peers_per_itr 104 | with lock: 105 | for bp, p in zip(broadcast_buffer.parameters(), model.parameters()): 106 | bp.data.copy_(p) 107 | bp.data.div_(num_peers + 1) 108 | # -- mark message as 'written' 109 | out_peers, _ = 
self.topology[rank].get_peers(rotate) 110 | torch.cuda.current_stream().synchronize() 111 | for peer in out_peers: 112 | write_event_list[peer].set() 113 | 114 | def aggregate_message(self, rank, model): 115 | """ 116 | Average messages with local model: 117 | Average all in-neighbours' (defined in 'self.topology') parameters with 118 | agent 'rank's 'model' and copy the result into 'model'. 119 | 120 | Agents should only aggregate messages into their own model: 121 | i.e., ensure 'model belongs to agent 'rank' 122 | """ 123 | with torch.no_grad(): 124 | 125 | # Check if in-peers finished writing messages to broadcast buffers 126 | _, in_peers = self.topology[rank].get_peers() 127 | write_complete = True 128 | for peer in in_peers: 129 | peer_buffer = self.msg_buffer[peer] 130 | write_event = peer_buffer[2][rank] 131 | if not write_event.is_set(): 132 | write_complete = False 133 | break 134 | 135 | # Check if any messages are excessively stale 136 | stale_assert = self.sync_list[rank] >= self.sync_freq 137 | 138 | # If peers done writing or message too stale, wait and clear events 139 | if write_complete or stale_assert: 140 | for peer in in_peers: 141 | peer_buffer = self.msg_buffer[peer] 142 | write_event = peer_buffer[2][rank] 143 | write_event.wait() 144 | write_event.clear() 145 | self.sync_list[rank] = 0 146 | # Not done writing, but staleness is still tolerable 147 | else: 148 | self.sync_list[rank] += 1 149 | logging.info( 150 | "GALA agent %s: staleness %s" % (rank, self.sync_list[rank]) 151 | ) 152 | return 153 | 154 | # Lazy-mixing of local params 155 | num_peers = self.topology[rank].peers_per_itr 156 | for p in model.parameters(): 157 | p.data.div_(num_peers + 1) 158 | 159 | # Aggregate received messages 160 | for peer in in_peers: 161 | peer_buffer = self.msg_buffer[peer] 162 | lock = peer_buffer[3] 163 | with lock: 164 | # Read message and update 'params' 165 | peer_msg = peer_buffer[0] 166 | for p, bp in zip(model.parameters(), 
peer_msg.parameters()): 167 | p.data.add_(bp.to(p.device, non_blocking=True)) 168 | torch.cuda.current_stream().synchronize() 169 | # Mark message as 'read' 170 | peer_buffer[1][rank].set() 171 | -------------------------------------------------------------------------------- /congestion_control/CongestionControlRPCEnv.cpp: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) Facebook, Inc. and its affiliates. 3 | * All rights reserved. 4 | * 5 | * This source code is licensed under the license found in the 6 | * LICENSE file in the root directory of this source tree. 7 | * 8 | */ 9 | #include "CongestionControlRPCEnv.h" 10 | 11 | #include 12 | 13 | #include "Utils.h" 14 | 15 | using namespace grpc; 16 | using namespace torchbeast; 17 | 18 | namespace quic { 19 | 20 | namespace { 21 | constexpr std::chrono::seconds kConnectTimeout{5}; 22 | } 23 | 24 | CongestionControlRPCEnv::CongestionControlRPCEnv( 25 | const Config &cfg, Callback *cob, const QuicConnectionStateBase &conn) 26 | : CongestionControlEnv(cfg, cob, conn), actorId_(cfg.actorId) { 27 | tensor_ = torch::empty({0}, torch::kFloat32); 28 | thread_ = std::make_unique(&CongestionControlRPCEnv::loop, this, 29 | cfg.rpcAddress); 30 | 31 | // Wait until connected to gRPC server 32 | std::unique_lock lock(mutex_); 33 | cv_.wait(lock, [&]() -> bool { return connected_; }); 34 | } 35 | 36 | CongestionControlRPCEnv::~CongestionControlRPCEnv() { 37 | shutdown_ = true; 38 | thread_->join(); 39 | } 40 | 41 | void CongestionControlRPCEnv::onObservation(Observation &&obs, float reward) { 42 | std::unique_lock lock(mutex_, std::try_to_lock); 43 | if (!lock) { 44 | // If we can't acquire the mutex, then we haven't received the action 45 | // back for the previous observation. Although this should almost never 46 | // happen as model runtimes are sufficiently fast, we handle this safely 47 | // here by skipping this observation. 
48 | LOG(WARNING) << __func__ << ": Still waiting for an update from " 49 | "ActorPoolServer, skipping observation."; 50 | return; 51 | } 52 | obs.toTensor(tensor_); 53 | reward_ = reward; 54 | observationReady_ = true; 55 | lock.unlock(); 56 | cv_.notify_one(); 57 | } 58 | 59 | void CongestionControlRPCEnv::loop(const std::string &address) { 60 | std::shared_ptr channel = 61 | grpc::CreateChannel(address, grpc::InsecureChannelCredentials()); 62 | auto stub = RPC::NewStub(channel); 63 | 64 | LOG(INFO) << "Connecting to ActorPoolServer at " << address << " ..."; 65 | const auto &deadline = std::chrono::system_clock::now() + kConnectTimeout; 66 | if (!channel->WaitForConnected(deadline)) { 67 | LOG(FATAL) << "Timed out connecting to ActorPoolServer: " << address; 68 | } 69 | 70 | // Notify that we are connected 71 | { 72 | std::lock_guard g(mutex_); 73 | connected_ = true; 74 | } 75 | cv_.notify_one(); 76 | LOG(INFO) << "Connected to ActorPoolServer: " << address; 77 | 78 | grpc::ClientContext context; 79 | std::shared_ptr> stream( 80 | stub->Call(&context)); 81 | 82 | Action action; 83 | bool done = true; 84 | uint32_t episode_step = 0; 85 | float episode_return = 0.0; 86 | CallResponse resp; 87 | std::unique_lock lock(mutex_); 88 | 89 | while (!shutdown_) { 90 | cv_.wait(lock, [&]() -> bool { return (observationReady_ || shutdown_); }); 91 | if (shutdown_) { 92 | LOG(INFO) << "RPC env loop terminating"; 93 | const auto &status = stream->Finish(); 94 | if (!status.ok()) { 95 | LOG(ERROR) << "RPC env loop failed on finish."; 96 | } 97 | return; 98 | } 99 | 100 | // The lifetime of a connection is seen as a single episode, so 101 | // done is set to true only at the beginning of the episode (to mark 102 | // the end of the previous episode. Episodic training should be 103 | // implemented via resetting the entire connection. 
104 | done = (episode_step == 0); 105 | episode_return += reward_; 106 | VLOG(2) << "Episode step = " << episode_step 107 | << ", total return = " << episode_return; 108 | 109 | const auto &req = makeCallRequest(actorId_, tensor_, reward_, done); 110 | observationReady_ = false; // Back to waiting 111 | 112 | stream->Write(req); 113 | if (!stream->Read(&resp)) { 114 | LOG(FATAL) << "Read failed from gRPC server."; 115 | } 116 | if (resp.has_error()) { 117 | LOG(FATAL) << "Error in response from RL server: " 118 | << resp.error().message(); 119 | } 120 | action.cwndAction = getActionFromCallResponse(resp); 121 | onAction(action); 122 | 123 | episode_step++; 124 | } 125 | } 126 | 127 | CallRequest CongestionControlRPCEnv::makeCallRequest(int64_t actorId, 128 | const torch::Tensor &obs, 129 | float reward, bool done) { 130 | // We need the same run Id across episodes per actor to ensure reconnects to 131 | // RL server at the beginning of each episode fills the rollout buffer 132 | // correctly. 133 | int64_t runId = 0; 134 | 135 | TensorNest actorIdNest(torch::from_blob(&actorId, {}, at::kLong)); 136 | TensorNest runIdNest(torch::from_blob(&runId, {}, at::kLong)); 137 | TensorNest obsNest(obs); 138 | TensorNest rewardNest(torch::from_blob(&reward, {}, at::kFloat)); 139 | TensorNest doneNest(torch::from_blob(&done, {}, at::kBool)); 140 | 141 | // Input is an ArrayNest of (actor_id, run_id, (observation, reward, done)). 
142 | TensorNest stepNest( 143 | quic::utils::vector({obsNest, rewardNest, doneNest})); 144 | TensorNest inputs( 145 | quic::utils::vector({actorIdNest, runIdNest, stepNest})); 146 | 147 | CallRequest req; 148 | req.set_function("inference"); 149 | fill_nest_pb(req.mutable_inputs(), std::move(inputs), fillNDArrayPB); 150 | return req; 151 | } 152 | 153 | uint32_t CongestionControlRPCEnv::getActionFromCallResponse( 154 | torchbeast::CallResponse &resp) { 155 | TensorNest output = nest_pb_to_nest(resp.mutable_outputs(), arrayPBToNest); 156 | 157 | // Output should be a single tensor containing the action 158 | CHECK(output.is_leaf()); 159 | const torch::Tensor &actionTensor = output.front(); 160 | CHECK_EQ(actionTensor.numel(), 1); 161 | return *actionTensor.data_ptr(); 162 | } 163 | 164 | void CongestionControlRPCEnv::fillNDArrayPB(NDArray *ndarray, 165 | const torch::Tensor &tensor) { 166 | for (const auto &dim : tensor.sizes()) { 167 | ndarray->add_shape(dim); 168 | } 169 | ndarray->set_dtype(quic::utils::aten_to_numpy_dtype(tensor.scalar_type())); 170 | ndarray->set_data(tensor.contiguous().data_ptr(), tensor.nbytes()); 171 | } 172 | 173 | TensorNest 174 | CongestionControlRPCEnv::arrayPBToNest(torchbeast::NDArray *ndarray) { 175 | quic::utils::vector shape; 176 | for (int i = 0, length = ndarray->shape_size(); i < length; ++i) { 177 | shape.push_back(ndarray->shape(i)); 178 | } 179 | 180 | std::string *data = ndarray->release_data(); 181 | at::ScalarType dtype = quic::utils::numpy_dtype_to_aten(ndarray->dtype()); 182 | 183 | return TensorNest(torch::from_blob( 184 | data->data(), shape, [data](void *ptr) { delete data; }, dtype)); 185 | } 186 | 187 | } // namespace quic 188 | -------------------------------------------------------------------------------- /train/models.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 
3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # 7 | 8 | import functools 9 | import logging 10 | import operator 11 | 12 | from typing import Optional 13 | 14 | import torch 15 | from torch import nn 16 | from torch.nn import functional as F 17 | 18 | import nest 19 | 20 | 21 | class SimpleNet(nn.Module): 22 | def __init__( 23 | self, 24 | input_size: int, 25 | hidden_size: int, 26 | num_actions: int, 27 | use_lstm: bool = False, 28 | n_train_jobs: Optional[int] = None, 29 | use_job_id_in_actor: bool = False, 30 | use_job_id_in_critic: bool = False, 31 | ): 32 | super(SimpleNet, self).__init__() 33 | self.num_actions = num_actions 34 | self.use_lstm = use_lstm 35 | # If provided, `n_train_jobs` is the total number of jobs (= different environments) 36 | # this model is trained on. It is required to provide a one-hot vector of the 37 | # corresponding size as input to the actor and/or critic (depending on the flags 38 | # `use_job_id_in_{actor,critic}`). 39 | # Note that if there is only one training job, it is ignored since there is no 40 | # need to differentiate between multiple environments. 41 | self.use_job_id_in_network_input = False 42 | self.use_job_id_in_actor_head = False 43 | self.use_job_id_in_critic_head = False 44 | if use_job_id_in_actor or use_job_id_in_critic: 45 | assert n_train_jobs is not None and n_train_jobs >= 1, n_train_jobs 46 | self.n_train_jobs = 0 if n_train_jobs == 1 else n_train_jobs 47 | if self.n_train_jobs > 0: 48 | # If the job ID is provided as input to both actor and critic, then 49 | # its one-hot representation is given as input to the first layer. 50 | # Otherwise, it will be provided only to the actor or critic head. 
51 | if use_job_id_in_actor and use_job_id_in_critic: 52 | self.use_job_id_in_network_input = True 53 | else: 54 | self.use_job_id_in_actor_head = use_job_id_in_actor 55 | self.use_job_id_in_critic_head = use_job_id_in_critic 56 | else: 57 | self.n_train_jobs = 0 # not used 58 | 59 | # Feature extraction. 60 | # The first layer's input size is decreased by 1 because we remove the integer 61 | # representation of the job ID from the input. 62 | self.fc1 = nn.Linear( 63 | input_size - 1 + self.n_train_jobs * self.use_job_id_in_network_input, 64 | hidden_size, 65 | ) 66 | self.fc2 = nn.Linear(hidden_size, hidden_size) 67 | 68 | # FC output size + last reward. 69 | core_output_size = self.fc2.out_features + 1 70 | 71 | if use_lstm: 72 | self.core = nn.LSTMCell(core_output_size, core_output_size) 73 | 74 | self.policy = nn.Linear( 75 | core_output_size + self.n_train_jobs * self.use_job_id_in_actor_head, 76 | self.num_actions, 77 | ) 78 | self.baseline = nn.Linear( 79 | core_output_size + self.n_train_jobs * self.use_job_id_in_critic_head, 80 | 1, 81 | ) 82 | 83 | def concat_to_job_id(self, tensor, job_id): 84 | """Concatenate a 2D tensor with the one-hot encoding associated to `job_id`""" 85 | one_hot = ( 86 | F.one_hot(job_id.flatten().long(), self.n_train_jobs) 87 | .type(tensor.dtype) 88 | .to(tensor.device) 89 | ) 90 | return torch.cat((tensor, one_hot), dim=1) 91 | 92 | def initial_state(self, batch_size=1): 93 | # Always return a tuple of two tensors so torch script type-checking 94 | # passes. It's sufficient for core state to be 95 | # Tuple[Tensor, Tensor] - the shapes don't matter. 96 | if self.use_lstm: 97 | core_hidden_size = self.core.hidden_size 98 | else: 99 | core_hidden_size = 0 100 | 101 | return tuple(torch.zeros(batch_size, core_hidden_size) for _ in range(2)) 102 | 103 | def forward(self, last_actions, env_outputs, core_state, unroll=False): 104 | if not unroll: 105 | # [T=1, B, ...]. 
106 | env_outputs = nest.map(lambda t: t.unsqueeze(0), env_outputs) 107 | 108 | observation, reward, done = env_outputs 109 | 110 | T, B, *_ = observation.shape 111 | x = torch.flatten(observation, 0, 1) # Merge time and batch. 112 | x = x.view(T * B, -1) 113 | 114 | # Separate the job ID from the rest of the observation. 115 | job_id = x[:, -1] 116 | x = x[:, 0:-1] 117 | 118 | if self.use_job_id_in_network_input: 119 | x = self.concat_to_job_id(x, job_id) 120 | 121 | x = F.relu(self.fc1(x)) 122 | x = F.relu(self.fc2(x)) 123 | 124 | # reward = torch.clamp(reward, -1, 1).view(T * B, 1).float() 125 | reward = reward.view(T * B, 1).float() 126 | core_input = torch.cat([x, reward], dim=1) 127 | 128 | if self.use_lstm: 129 | core_input = core_input.view(T, B, -1) 130 | core_output_list = [] 131 | notdone = (~done).float() 132 | notdone.unsqueeze_(-1) # [T, B, H=1] for broadcasting. 133 | 134 | for input_t, notdone_t in zip(core_input.unbind(), notdone.unbind()): 135 | # When `done` is True it means this is the first step in a new 136 | # episode => reset the internal state to zero. 137 | core_state = nest.map(notdone_t.mul, core_state) 138 | output_t, core_state = self.core(input_t, core_state) 139 | core_state = (output_t, core_state) # nn.LSTMCell is a bit weird. 140 | core_output_list.append(output_t) # [[B, H], [B, H], ...]. 141 | core_output = torch.cat(core_output_list) # [T * B, H]. 
142 | else: 143 | core_output = core_input 144 | 145 | actor_input = ( 146 | self.concat_to_job_id(core_output, job_id) 147 | if self.use_job_id_in_actor_head 148 | else core_output 149 | ) 150 | policy_logits = self.policy(actor_input) 151 | 152 | if self.training: 153 | action = torch.multinomial(F.softmax(policy_logits, dim=1), num_samples=1) 154 | 155 | critic_input = ( 156 | self.concat_to_job_id(core_output, job_id) 157 | if self.use_job_id_in_critic_head 158 | else core_output 159 | ) 160 | baseline = self.baseline(critic_input) 161 | 162 | baseline = baseline.view(T, B) 163 | 164 | else: 165 | # Don't sample when testing. 166 | action = torch.argmax(policy_logits, dim=1) 167 | 168 | policy_logits = policy_logits.view(T, B, self.num_actions) 169 | action = action.view(T, B) 170 | 171 | if self.training: 172 | outputs = dict( 173 | action=action, policy_logits=policy_logits, baseline=baseline 174 | ) 175 | if not unroll: 176 | outputs = nest.map(lambda t: t.squeeze(0), outputs) 177 | return outputs, core_state 178 | else: 179 | # In eval mode, we just return (action, core_state). PyTorch doesn't 180 | # support jit tracing output dicts. 181 | return action, core_state 182 | -------------------------------------------------------------------------------- /traffic_gen/ExampleClient.h: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) Facebook, Inc. and its affiliates. 3 | * All rights reserved. 4 | * 5 | * This source code is licensed under the license found in the 6 | * LICENSE file in the root directory of this source tree. 
7 | * 8 | */ 9 | #pragma once 10 | 11 | #include 12 | #include 13 | #include 14 | 15 | #include 16 | #include 17 | #include 18 | #include 19 | #include 20 | #include 21 | 22 | #include 23 | 24 | namespace quic { 25 | namespace traffic_gen { 26 | 27 | class ExampleClient : public quic::QuicSocket::ConnectionCallback, 28 | public quic::QuicSocket::ReadCallback, 29 | public quic::QuicSocket::WriteCallback, 30 | public quic::QuicSocket::DataExpiredCallback { 31 | public: 32 | ExampleClient(const std::string& host, uint16_t port, 33 | CongestionControlType cc_algo = CongestionControlType::Cubic, 34 | std::shared_ptr ccFactory = 35 | std::make_shared()) 36 | : addr_(host.c_str(), port), cc_algo_(cc_algo), ccFactory_(ccFactory) {} 37 | 38 | void readAvailable(quic::StreamId streamId) noexcept override { 39 | auto readData = quicClient_->read(streamId, 0); 40 | if (readData.hasError()) { 41 | LOG(ERROR) << "ExampleClient failed read from stream=" << streamId 42 | << ", error=" << (uint32_t)readData.error(); 43 | } 44 | auto copy = readData->first->clone(); 45 | if (recvOffsets_.find(streamId) == recvOffsets_.end()) { 46 | recvOffsets_[streamId] = copy->length(); 47 | } else { 48 | recvOffsets_[streamId] += copy->length(); 49 | } 50 | VLOG_EVERY_N(2, 1000) << "Client received data=" 51 | << copy->computeChainDataLength() 52 | << " bytes on stream=" << streamId; 53 | } 54 | 55 | void readError( 56 | quic::StreamId streamId, 57 | std::pair> 58 | error) noexcept override { 59 | LOG(ERROR) << "ExampleClient failed read from stream=" << streamId 60 | << ", error=" << toString(error); 61 | // A read error only terminates the ingress portion of the stream state. 
62 | // Your application should probably terminate the egress portion via 63 | // resetStream 64 | } 65 | 66 | void onNewBidirectionalStream(quic::StreamId id) noexcept override { 67 | LOG(INFO) << "ExampleClient: new bidirectional stream=" << id; 68 | quicClient_->setReadCallback(id, this); 69 | } 70 | 71 | void onNewUnidirectionalStream(quic::StreamId id) noexcept override { 72 | LOG(INFO) << "ExampleClient: new unidirectional stream=" << id; 73 | quicClient_->setReadCallback(id, this); 74 | } 75 | 76 | void onStopSending(quic::StreamId id, 77 | quic::ApplicationErrorCode /*error*/) noexcept override { 78 | VLOG(10) << "ExampleClient got StopSending stream id=" << id; 79 | } 80 | 81 | void onTransportReady() noexcept override { 82 | LOG(INFO) << "ExampleClient connected to " << addr_.describe(); 83 | auto streamId = quicClient_->createBidirectionalStream().value(); 84 | quicClient_->setReadCallback(streamId, this); 85 | pendingOutput_[streamId].append(folly::IOBuf::copyBuffer("hello")); 86 | sendMessage(streamId, pendingOutput_[streamId]); 87 | } 88 | 89 | void onConnectionEnd() noexcept override { 90 | LOG(INFO) << "ExampleClient connection end"; 91 | } 92 | 93 | void onConnectionError( 94 | std::pair error) noexcept override { 95 | LOG_EVERY_N(ERROR, 100) << "ExampleClient error connecting to " 96 | << addr_.describe() << " - " 97 | << toString(error.first) << ". 
Trying again..."; 98 | connect(); 99 | } 100 | 101 | void onStreamWriteReady(quic::StreamId id, 102 | uint64_t maxToSend) noexcept override { 103 | LOG(INFO) << "ExampleClient socket is write ready with maxToSend=" 104 | << maxToSend; 105 | sendMessage(id, pendingOutput_[id]); 106 | } 107 | 108 | void onStreamWriteError( 109 | quic::StreamId id, 110 | std::pair> 111 | error) noexcept override { 112 | LOG(ERROR) << "ExampleClient write error with stream=" << id 113 | << " error=" << toString(error); 114 | } 115 | 116 | void onDataExpired(StreamId streamId, uint64_t newOffset) noexcept override { 117 | LOG(INFO) << "Client received skipData; " 118 | << newOffset - recvOffsets_[streamId] 119 | << " bytes skipped on stream=" << streamId; 120 | } 121 | 122 | void connect() { 123 | auto sock = std::make_unique(evb_); 124 | auto fizzClientContext = 125 | FizzClientQuicHandshakeContext::Builder() 126 | .setCertificateVerifier( 127 | std::make_unique()) 128 | .build(); 129 | quicClient_ = std::make_shared( 130 | evb_, std::move(sock), std::move(fizzClientContext)); 131 | quicClient_->setHostname("example.org"); 132 | quicClient_->addNewPeerAddress(addr_); 133 | quicClient_->setCongestionControllerFactory(ccFactory_); 134 | 135 | TransportSettings settings; 136 | settings.defaultCongestionController = cc_algo_; 137 | 138 | // Often times flow control becomes the bottleneck and prevents accurate 139 | // analysis of congestion control. Effectively disable it by setting a large 140 | // window size. 
141 | settings.advertisedInitialConnectionWindowSize = 1e8; 142 | settings.advertisedInitialBidiLocalStreamWindowSize = 1e8; 143 | settings.advertisedInitialBidiRemoteStreamWindowSize = 1e8; 144 | settings.advertisedInitialUniStreamWindowSize = 1e8; 145 | 146 | quicClient_->setTransportSettings(settings); 147 | quicClient_->start(this); 148 | } 149 | 150 | void start() { 151 | folly::ScopedEventBaseThread networkThread("ExampleClientThread"); 152 | evb_ = networkThread.getEventBase(); 153 | 154 | evb_->runInEventBaseThreadAndWait([&] { 155 | LOG(INFO) << "ExampleClient connecting to " << addr_.describe(); 156 | connect(); 157 | }); 158 | 159 | // Loop forever 160 | while (true) { 161 | } 162 | 163 | LOG(INFO) << "ExampleClient stopping"; 164 | } 165 | 166 | ~ExampleClient() override = default; 167 | 168 | private: 169 | void sendMessage(quic::StreamId id, folly::IOBufQueue& data) { 170 | auto message = data.move(); 171 | auto res = quicClient_->writeChain(id, message->clone(), true, false); 172 | if (res.hasError()) { 173 | LOG(ERROR) << "ExampleClient writeChain error=" << uint32_t(res.error()); 174 | } else { 175 | auto str = message->moveToFbString().toStdString(); 176 | LOG(INFO) << "ExampleClient wrote \"" << str << "\"" 177 | << ", len=" << str.size() << " on stream=" << id; 178 | // sent whole message 179 | pendingOutput_.erase(id); 180 | } 181 | } 182 | 183 | folly::SocketAddress addr_; 184 | CongestionControlType cc_algo_; 185 | std::shared_ptr ccFactory_; 186 | 187 | std::shared_ptr quicClient_; 188 | std::map pendingOutput_; 189 | std::map recvOffsets_; 190 | folly::EventBase* evb_{nullptr}; 191 | }; 192 | 193 | } // namespace traffic_gen 194 | } // namespace quic 195 | -------------------------------------------------------------------------------- /traffic_gen/main.cpp: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) Facebook, Inc. and its affiliates. 3 | * All rights reserved. 
4 | * 5 | * This source code is licensed under the license found in the 6 | * LICENSE file in the root directory of this source tree. 7 | * 8 | */ 9 | #include 10 | 11 | #include 12 | #include 13 | #include 14 | #include 15 | #include 16 | #include 17 | 18 | #include "congestion_control/RLCongestionControllerFactory.h" 19 | #include "traffic_gen/ExampleClient.h" 20 | #include "traffic_gen/ExampleServer.h" 21 | 22 | DEFINE_string(host, "::1", "Server hostname/IP"); 23 | DEFINE_int32(port, 6666, "Server port"); 24 | DEFINE_string(mode, "server", "Mode to run in: 'client' or 'server'"); 25 | DEFINE_int32(chunk_size, 1024 * 1024, "Chunk size to send at once"); 26 | DEFINE_string(cc_algo, "cubic", "Congestion Control algorithm to use"); 27 | DEFINE_string( 28 | cc_env_mode, "local", 29 | "CongestionControlEnv mode for RL cc_algo - [local|remote|random|fixed]"); 30 | DEFINE_int64(cc_env_actor_id, 0, 31 | "For use in training to uniquely identify an actor across " 32 | "episodic connections to RL server."); 33 | DEFINE_int64(cc_env_job_id, -1, 34 | "Index of the current job in the list of active jobs. -1 if undefined."); 35 | DEFINE_string(cc_env_model_file, "traced_model.pt", 36 | "PyTorch traced model file for local mode"); 37 | DEFINE_string( 38 | cc_env_rpc_address, "unix:/tmp/rl_server_path", 39 | "CongestionControlRPCEnv RL server address for training. 
Could " 40 | "be either : or Unix domain socket path unix:."); 41 | DEFINE_string(cc_env_agg, "time", "State aggregation type for RL cc_algo"); 42 | DEFINE_int32(cc_env_time_window_ms, 100, 43 | "Window duration (ms) for TIME_WINDOW aggregation"); 44 | DEFINE_int32(cc_env_fixed_window_size, 10, 45 | "Window size for FIXED_WINDOW aggregation"); 46 | DEFINE_bool(cc_env_use_state_summary, true, 47 | "Whether to use state summary instead of raw states in " 48 | "observation (auto-enabled for TIME_WINDOW)"); 49 | DEFINE_int32( 50 | cc_env_history_size, 2, 51 | "Length of history (such as past actions) to include in observation"); 52 | DEFINE_double( 53 | cc_env_norm_ms, 100.0, 54 | "Normalization factor for temporal (in ms) fields in observation"); 55 | DEFINE_double(cc_env_norm_bytes, 1000.0, 56 | "Normalization factor for byte fields in observation"); 57 | DEFINE_string(cc_env_actions, "0,/2,-10,+10,*2", 58 | "List of actions specifying how cwnd should be updated. The " 59 | "first action is required to be 0 (no-op action)."); 60 | DEFINE_bool( 61 | cc_env_reward_log_ratio, false, 62 | "If true, then instead of " 63 | " a * throughput - b * delay - c * loss " 64 | "we use as reward " 65 | " a * log(a' + throughput) - b * log(b' + delay) - c * log(c' + loss)"); 66 | DEFINE_double(cc_env_reward_throughput_factor, 0.1, 67 | "Throughput multiplier in reward (a)"); 68 | DEFINE_double(cc_env_reward_throughput_log_offset, 1.0, 69 | "Offset to add to throughput in log version (a')"); 70 | DEFINE_double(cc_env_reward_delay_factor, 0.01, 71 | "Delay multiplier in reward (b)"); 72 | DEFINE_double(cc_env_reward_delay_log_offset, 1.0, 73 | "Offset to add to delay in log version (b')"); 74 | DEFINE_double(cc_env_reward_packet_loss_factor, 0.0, 75 | "Packet loss multiplier in reward (c)"); 76 | DEFINE_double(cc_env_reward_packet_loss_log_offset, 1.0, 77 | "Offset to add to packet loss in log version (c')"); 78 | DEFINE_bool(cc_env_reward_max_delay, true, 79 | "Whether to take max 
delay over observations in reward." 80 | "Otherwise, avg delay is used."); 81 | DEFINE_uint32(cc_env_fixed_cwnd, 10, 82 | "Target fixed cwnd value (only used in 'fixed' env mode)"); 83 | DEFINE_uint64(cc_env_min_rtt_window_length_us, 84 | quic::kMinRTTWindowLength.count(), 85 | "Window length (in us) of min RTT filter used to estimate delay"); 86 | 87 | using namespace quic::traffic_gen; 88 | using Config = quic::CongestionControlEnv::Config; 89 | 90 | std::shared_ptr 91 | makeRLCongestionControllerFactory() { 92 | Config cfg; 93 | 94 | if (FLAGS_cc_env_mode == "local") { 95 | cfg.mode = Config::Mode::LOCAL; 96 | } else if (FLAGS_cc_env_mode == "remote") { 97 | cfg.mode = Config::Mode::REMOTE; 98 | } else if (FLAGS_cc_env_mode == "random") { 99 | cfg.mode = Config::Mode::RANDOM; 100 | } else if (FLAGS_cc_env_mode == "fixed") { 101 | cfg.mode = Config::Mode::FIXED; 102 | } else { 103 | LOG(FATAL) << "Unknown cc_env_mode: " << FLAGS_cc_env_mode; 104 | } 105 | 106 | cfg.modelFile = FLAGS_cc_env_model_file; 107 | cfg.rpcAddress = FLAGS_cc_env_rpc_address; 108 | cfg.actorId = FLAGS_cc_env_actor_id; 109 | cfg.jobId = FLAGS_cc_env_job_id; 110 | 111 | if (FLAGS_cc_env_agg == "time") { 112 | cfg.aggregation = Config::Aggregation::TIME_WINDOW; 113 | } else if (FLAGS_cc_env_agg == "fixed") { 114 | cfg.aggregation = Config::Aggregation::FIXED_WINDOW; 115 | } else { 116 | LOG(FATAL) << "Unknown cc_env_agg: " << FLAGS_cc_env_agg; 117 | } 118 | cfg.windowDuration = std::chrono::milliseconds(FLAGS_cc_env_time_window_ms); 119 | cfg.windowSize = FLAGS_cc_env_fixed_window_size; 120 | cfg.useStateSummary = FLAGS_cc_env_use_state_summary; 121 | 122 | cfg.historySize = FLAGS_cc_env_history_size; 123 | 124 | cfg.normMs = FLAGS_cc_env_norm_ms; 125 | cfg.normBytes = FLAGS_cc_env_norm_bytes; 126 | 127 | cfg.parseActionsFromString(FLAGS_cc_env_actions); 128 | 129 | cfg.rewardLogRatio = FLAGS_cc_env_reward_log_ratio; 130 | cfg.throughputFactor = FLAGS_cc_env_reward_throughput_factor; 131 | 
cfg.throughputLogOffset = FLAGS_cc_env_reward_throughput_log_offset; 132 | cfg.delayFactor = FLAGS_cc_env_reward_delay_factor; 133 | cfg.delayLogOffset = FLAGS_cc_env_reward_delay_log_offset; 134 | cfg.packetLossFactor = FLAGS_cc_env_reward_packet_loss_factor; 135 | cfg.packetLossLogOffset = FLAGS_cc_env_reward_packet_loss_log_offset; 136 | cfg.maxDelayInReward = FLAGS_cc_env_reward_max_delay; 137 | cfg.fixedCwnd = FLAGS_cc_env_fixed_cwnd; 138 | cfg.minRTTWindowLength = 139 | std::chrono::microseconds(FLAGS_cc_env_min_rtt_window_length_us); 140 | 141 | auto envFactory = std::make_shared(cfg); 142 | return std::make_shared(envFactory); 143 | } 144 | 145 | int main(int argc, char *argv[]) { 146 | #if FOLLY_HAVE_LIBGFLAGS 147 | // Enable glog logging to stderr by default. 148 | gflags::SetCommandLineOptionWithMode("logtostderr", "1", 149 | gflags::SET_FLAGS_DEFAULT); 150 | #endif 151 | gflags::ParseCommandLineFlags(&argc, &argv, false); 152 | folly::Init init(&argc, &argv); 153 | fizz::CryptoUtils::init(); 154 | 155 | quic::CongestionControlType cc_algo; 156 | std::shared_ptr ccFactory = 157 | std::make_shared(); 158 | if (FLAGS_cc_algo == "cubic") { 159 | cc_algo = quic::CongestionControlType::Cubic; 160 | } else if (FLAGS_cc_algo == "newreno") { 161 | cc_algo = quic::CongestionControlType::NewReno; 162 | } else if (FLAGS_cc_algo == "copa") { 163 | cc_algo = quic::CongestionControlType::Copa; 164 | } else if (FLAGS_cc_algo == "bbr") { 165 | cc_algo = quic::CongestionControlType::BBR; 166 | } else if (FLAGS_cc_algo == "rl") { 167 | cc_algo = quic::CongestionControlType::None; 168 | ccFactory = makeRLCongestionControllerFactory(); 169 | } else if (FLAGS_cc_algo == "none") { 170 | cc_algo = quic::CongestionControlType::None; 171 | } else { 172 | LOG(ERROR) << "Unknown cc_algo " << FLAGS_cc_algo; 173 | return -1; 174 | } 175 | 176 | if (FLAGS_mode == "server") { 177 | ExampleServer server(FLAGS_host, FLAGS_port, cc_algo, ccFactory); 178 | server.start(); 179 | } else if 
(FLAGS_mode == "client") { 180 | if (FLAGS_host.empty() || FLAGS_port == 0) { 181 | LOG(ERROR) << "ExampleClient expected --host and --port"; 182 | return -2; 183 | } 184 | ExampleClient client(FLAGS_host, FLAGS_port, cc_algo, ccFactory); 185 | client.start(); 186 | } else { 187 | LOG(ERROR) << "Unknown mode specified: " << FLAGS_mode; 188 | return -1; 189 | } 190 | return 0; 191 | } 192 | -------------------------------------------------------------------------------- /congestion_control/RLCongestionController.cpp: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) Facebook, Inc. and its affiliates. 3 | * All rights reserved. 4 | * 5 | * This source code is licensed under the license found in the 6 | * LICENSE file in the root directory of this source tree. 7 | * 8 | */ 9 | #include 10 | #include 11 | #include 12 | 13 | #include "NetworkState.h" 14 | #include "RLCongestionController.h" 15 | 16 | namespace quic { 17 | 18 | using namespace std::chrono; 19 | 20 | using Field = NetworkState::Field; 21 | 22 | RLCongestionController::RLCongestionController( 23 | QuicConnectionStateBase &conn, 24 | std::shared_ptr envFactory) 25 | : conn_(conn), 26 | cwndBytes_(conn.transportSettings.initCwndInMss * conn.udpSendPacketLen), 27 | env_(envFactory->make(this, conn)), 28 | minRTTFilter_(kMinRTTWindowLength.count(), 0us, 0), // length reset below 29 | standingRTTFilter_(100000, 0us, 0), // 100ms 30 | bandwidthSampler_(conn) { 31 | DCHECK(env_); 32 | const CongestionControlEnv::Config &cfg = env_->config(); 33 | minRTTFilter_.SetWindowLength(cfg.minRTTWindowLength.count()); 34 | 35 | VLOG(10) << __func__ << " writable=" << getWritableBytes() 36 | << " cwnd=" << cwndBytes_ << " inflight=" << bytesInFlight_ << " " 37 | << conn_; 38 | } 39 | 40 | void RLCongestionController::onRemoveBytesFromInflight(uint64_t bytes) { 41 | subtractAndCheckUnderflow(bytesInFlight_, bytes); 42 | VLOG(10) << __func__ << " writable=" << 
getWritableBytes() 43 | << " cwnd=" << cwndBytes_ << " inflight=" << bytesInFlight_ << " " 44 | << conn_; 45 | } 46 | 47 | void RLCongestionController::onPacketSent(const OutstandingPacket &packet) { 48 | addAndCheckOverflow(bytesInFlight_, packet.metadata.encodedSize); 49 | 50 | VLOG(10) << __func__ << " writable=" << getWritableBytes() 51 | << " cwnd=" << cwndBytes_ << " inflight=" << bytesInFlight_ 52 | << " bytesBuffered=" << conn_.flowControlState.sumCurStreamBufferLen 53 | << " packetNum=" << packet.packet.header.getPacketSequenceNum() 54 | << " " << conn_; 55 | } 56 | 57 | void RLCongestionController::onPacketAckOrLoss( 58 | folly::Optional ack, folly::Optional loss) { 59 | if (loss) { 60 | onPacketLoss(*loss); 61 | } 62 | if (ack && ack->largestAckedPacket.hasValue()) { 63 | onPacketAcked(*ack); 64 | } 65 | 66 | // State update to the env 67 | NetworkState obs; 68 | if (setNetworkState(ack, loss, obs)) { 69 | env_->onNetworkState(std::move(obs)); 70 | } 71 | } 72 | 73 | void RLCongestionController::onPacketAcked(const AckEvent &ack) { 74 | DCHECK(ack.largestAckedPacket.hasValue()); 75 | subtractAndCheckUnderflow(bytesInFlight_, ack.ackedBytes); 76 | minRTTFilter_.Update( 77 | conn_.lossState.lrtt, 78 | std::chrono::duration_cast(ack.ackTime.time_since_epoch()) 79 | .count()); 80 | standingRTTFilter_.SetWindowLength(conn_.lossState.srtt.count() / 2); 81 | standingRTTFilter_.Update( 82 | conn_.lossState.lrtt, 83 | std::chrono::duration_cast(ack.ackTime.time_since_epoch()) 84 | .count()); 85 | 86 | // The `rttCounter` argument is set to 0 because it is ignored in 87 | // `RLBandwidthSampler`. If one wanted to use a different bandwidth estimator 88 | // (e.g. `BbrBandwidthSampler`) then a proper counter should be implemeneted. 
89 | bandwidthSampler_.onPacketAcked(ack, 0); 90 | 91 | VLOG(10) << __func__ << "ack size=" << ack.ackedBytes 92 | << " num packets acked=" << ack.ackedBytes / conn_.udpSendPacketLen 93 | << " writable=" << getWritableBytes() << " cwnd=" << cwndBytes_ 94 | << " inflight=" << bytesInFlight_ 95 | << " sRTT=" << conn_.lossState.srtt.count() 96 | << " lRTT=" << conn_.lossState.lrtt.count() 97 | << " mRTT=" << conn_.lossState.mrtt.count() 98 | << " rttvar=" << conn_.lossState.rttvar.count() 99 | << " packetsBufferred=" 100 | << conn_.flowControlState.sumCurStreamBufferLen 101 | << " packetsRetransmitted=" << conn_.lossState.rtxCount << " " 102 | << conn_; 103 | } 104 | 105 | void RLCongestionController::onPacketLoss(const LossEvent &loss) { 106 | VLOG(10) << __func__ << " lostBytes=" << loss.lostBytes 107 | << " lostPackets=" << loss.lostPackets << " cwnd=" << cwndBytes_ 108 | << " inflight=" << bytesInFlight_ << " " << conn_; 109 | DCHECK(loss.largestLostPacketNum.hasValue()); 110 | subtractAndCheckUnderflow(bytesInFlight_, loss.lostBytes); 111 | if (loss.persistentCongestion) { 112 | VLOG(10) << __func__ << " writable=" << getWritableBytes() 113 | << " cwnd=" << cwndBytes_ << " inflight=" << bytesInFlight_ << " " 114 | << conn_; 115 | } 116 | } 117 | 118 | void RLCongestionController::onUpdate(const uint64_t &cwndBytes) noexcept { 119 | cwndBytes_ = cwndBytes; 120 | } 121 | 122 | bool RLCongestionController::setNetworkState( 123 | const folly::Optional &ack, 124 | const folly::Optional &loss, NetworkState &obs) { 125 | const auto &state = conn_.lossState; 126 | 127 | const auto &rttMin = minRTTFilter_.GetBest(); 128 | const auto &rttStanding = standingRTTFilter_.GetBest().count(); 129 | const auto &delay = 130 | duration_cast(conn_.lossState.lrtt - rttMin).count(); 131 | if (rttStanding == 0 || delay < 0) { 132 | LOG(ERROR) 133 | << "Invalid rttStanding or delay, skipping network state update: " 134 | << "rttStanding = " << rttStanding << ", delay = " << delay << " " 
135 | << conn_; 136 | return false; 137 | } 138 | 139 | const float normMs = env_->normMs(); 140 | const float normBytes = env_->normBytes(); 141 | 142 | obs[Field::RTT_MIN] = rttMin.count() / 1000.0 / normMs; 143 | obs[Field::RTT_STANDING] = rttStanding / 1000.0 / normMs; 144 | obs[Field::LRTT] = state.lrtt.count() / 1000.0 / normMs; 145 | obs[Field::SRTT] = state.srtt.count() / 1000.0 / normMs; 146 | obs[Field::RTT_VAR] = state.rttvar.count() / 1000.0 / normMs; 147 | obs[Field::DELAY] = delay / 1000.0 / normMs; 148 | 149 | obs[Field::CWND] = cwndBytes_ / normBytes; 150 | obs[Field::IN_FLIGHT] = bytesInFlight_ / normBytes; 151 | obs[Field::WRITABLE] = getWritableBytes() / normBytes; 152 | obs[Field::SENT] = (state.totalBytesSent - prevTotalBytesSent_) / normBytes; 153 | obs[Field::RECEIVED] = 154 | (state.totalBytesRecvd - prevTotalBytesRecvd_) / normBytes; 155 | obs[Field::RETRANSMITTED] = 156 | (state.totalBytesRetransmitted - prevTotalBytesRetransmitted_) / 157 | normBytes; 158 | 159 | // The throughput is in bytes / s => we normalize it with `normBytes`. 
160 | DCHECK(bandwidthSampler_.getBandwidth().unitType == 161 | Bandwidth::UnitType::BYTES); 162 | obs[Field::THROUGHPUT] = 163 | bandwidthSampler_.getBandwidth().normalize() / normBytes; 164 | 165 | obs[Field::PTO_COUNT] = state.ptoCount; 166 | obs[Field::TOTAL_PTO_DELTA] = state.totalPTOCount - prevTotalPTOCount_; 167 | obs[Field::RTX_COUNT] = state.rtxCount - prevRtxCount_; 168 | obs[Field::TIMEOUT_BASED_RTX_COUNT] = 169 | state.timeoutBasedRtxCount - prevTimeoutBasedRtxCount_; 170 | 171 | if (ack && ack->largestAckedPacket.hasValue()) { 172 | obs[Field::ACKED] = ack->ackedBytes / normBytes; 173 | } 174 | 175 | if (loss) { 176 | obs[Field::LOST] = loss->lostBytes / normBytes; 177 | obs[Field::PERSISTENT_CONGESTION] = loss->persistentCongestion; 178 | } 179 | 180 | // Update prev state values 181 | prevTotalBytesSent_ = state.totalBytesSent; 182 | prevTotalBytesRecvd_ = state.totalBytesRecvd; 183 | prevTotalBytesRetransmitted_ = state.totalBytesRetransmitted; 184 | prevTotalPTOCount_ = state.totalPTOCount; 185 | prevRtxCount_ = state.rtxCount; 186 | prevTimeoutBasedRtxCount_ = state.timeoutBasedRtxCount; 187 | 188 | return true; 189 | } 190 | 191 | uint64_t RLCongestionController::getWritableBytes() const noexcept { 192 | if (bytesInFlight_ > cwndBytes_) { 193 | return 0; 194 | } else { 195 | return cwndBytes_ - bytesInFlight_; 196 | } 197 | } 198 | 199 | uint64_t RLCongestionController::getCongestionWindow() const noexcept { 200 | return cwndBytes_; 201 | } 202 | 203 | CongestionControlType RLCongestionController::type() const noexcept { 204 | return CongestionControlType::None; 205 | } 206 | 207 | uint64_t RLCongestionController::getBytesInFlight() const noexcept { 208 | return bytesInFlight_; 209 | } 210 | 211 | void RLCongestionController::setAppIdle(bool, 212 | TimePoint) noexcept { /* unsupported */ 213 | } 214 | 215 | void RLCongestionController::setAppLimited() { /* unsupported */ 216 | } 217 | 218 | bool RLCongestionController::isAppLimited() const 
noexcept { 219 | return false; // not supported 220 | } 221 | 222 | } // namespace quic 223 | -------------------------------------------------------------------------------- /train/experiments.yml: -------------------------------------------------------------------------------- 1 | # Basically copied from 2 | # https://github.com/StanfordSNR/observatory/blob/master/src/scripts/experiments.yml. 3 | # Modified a bit for our purposes. 4 | 5 | meta: 6 | branch: main 7 | base_dir: {src_dir} 8 | test_path: {src_dir}/_build/deps/pantheon/src/experiments/test.py 9 | data_base_dir: {src_dir}/data 10 | tmp_dir: {src_dir}/tmp 11 | install_deps_path: {pantheon_root}/tools/install_deps.sh 12 | pkill_path: {pantheon_root}/tools/pkill.py 13 | setup_system_path: {pantheon_root}/src/experiments/setup_system.py 14 | setup_path: {pantheon_root}/src/experiments/setup.py 15 | analyze_path: {pantheon_root}/src/analysis/analyze.py 16 | traces_dir: {src_dir}/train/traces 17 | 18 | emu: 19 | matrix: 20 | flow_scenario: 21 | - -f 1 22 | macros: 23 | common_param_set: >- 24 | local --data-dir {data_dir} --pkill-cleanup 25 | jobs: 26 | # 1, 'Calibrated emulator (Nepal to AWS India)' 27 | - scenario: 1 28 | desc: >- 29 | Calibrated to the real path from Nepal to AWS India 30 | (https://pantheon.stanford.edu/result/188/) 31 | command: >- 32 | {test_path} {common_param_set} 33 | --uplink-trace {traces_dir}/0.57mbps-poisson.trace 34 | --downlink-trace {traces_dir}/0.57mbps-poisson.trace 35 | --prepend-mm-cmds "mm-delay 28 mm-loss uplink 0.0477" 36 | --extra-mm-link-args "--uplink-queue=droptail --uplink-queue-args=packets=14" 37 | # 2, 'Calibrated emulator (Mexico cellular to AWS California)' 38 | - scenario: 2 39 | desc: >- 40 | Calibrated to the real path from Mexico cellular to AWS California 41 | (https://pantheon.stanford.edu/result/196/) 42 | command: >- 43 | {test_path} {common_param_set} 44 | --uplink-trace {traces_dir}/2.64mbps-poisson.trace 45 | --downlink-trace 
{traces_dir}/2.64mbps-poisson.trace 46 | --prepend-mm-cmds "mm-delay 88" 47 | --extra-mm-link-args "--uplink-queue=droptail --uplink-queue-args=packets=130" 48 | # 3, 'Calibrated emulator (AWS Brazil to Colombia cellular)' 49 | - scenario: 3 50 | desc: >- 51 | Calibrated to the real path from AWS Brazil to Colombia cellular 52 | (https://pantheon.stanford.edu/result/339/) 53 | command: >- 54 | {test_path} {common_param_set} 55 | --uplink-trace {traces_dir}/3.04mbps-poisson.trace 56 | --downlink-trace {traces_dir}/3.04mbps-poisson.trace 57 | --prepend-mm-cmds "mm-delay 130" 58 | --extra-mm-link-args "--uplink-queue=droptail --uplink-queue-args=packets=426" 59 | # 4, 'Calibrated emulator (India to AWS India)' 60 | - scenario: 4 61 | desc: >- 62 | Calibrated to the real path from India to AWS India 63 | (https://pantheon.stanford.edu/result/251/) 64 | command: >- 65 | {test_path} {common_param_set} 66 | --uplink-trace {traces_dir}/100.42mbps.trace 67 | --downlink-trace {traces_dir}/100.42mbps.trace 68 | --prepend-mm-cmds "mm-delay 27" 69 | --extra-mm-link-args "--uplink-queue=droptail --uplink-queue-args=packets=173" 70 | # 5, 'Calibrated emulator (AWS Korea to China)' 71 | - scenario: 5 72 | desc: >- 73 | Calibrated to the real path from AWS Korea to China 74 | (https://pantheon.stanford.edu/result/361/) 75 | command: >- 76 | {test_path} {common_param_set} 77 | --uplink-trace {traces_dir}/77.72mbps.trace 78 | --downlink-trace {traces_dir}/77.72mbps.trace 79 | --prepend-mm-cmds "mm-delay 51 mm-loss uplink 0.0006" 80 | --extra-mm-link-args "--uplink-queue=droptail --uplink-queue-args=packets=94" 81 | # 6, 'Calibrated emulator (AWS California to Mexico)' 82 | - scenario: 6 83 | desc: >- 84 | Calibrated to the real path from AWS California to Mexico 85 | (https://pantheon.stanford.edu/result/353/) 86 | command: >- 87 | {test_path} {common_param_set} 88 | --uplink-trace {traces_dir}/114.68mbps.trace 89 | --downlink-trace {traces_dir}/114.68mbps.trace 90 | 
--prepend-mm-cmds "mm-delay 45" 91 | --extra-mm-link-args "--uplink-queue=droptail --uplink-queue-args=packets=450" 92 | # 7, 'Token-bucket based policer (bandwidth 12mbps, RTT 20ms)' 93 | - scenario: 7 94 | desc: Token-bucket based policer (bandwidth 12mbps, RTT 20ms) 95 | command: >- 96 | {test_path} {common_param_set} 97 | --uplink-trace {traces_dir}/12mbps.trace 98 | --downlink-trace {traces_dir}/12mbps.trace 99 | --prepend-mm-cmds "mm-delay 10" 100 | --extra-mm-link-args "--uplink-queue=droptail --uplink-queue-args=packets=1 --downlink-queue=droptail --downlink-queue-args=packets=1" 101 | # 8, 'Token-bucket based policer (bandwidth 60mbps, RTT 20ms)' 102 | - scenario: 8 103 | desc: Token-bucket based policer (bandwidth 60mbps, RTT 20ms) 104 | command: >- 105 | {test_path} {common_param_set} 106 | --uplink-trace {traces_dir}/60mbps.trace 107 | --downlink-trace {traces_dir}/60mbps.trace 108 | --prepend-mm-cmds "mm-delay 10" 109 | --extra-mm-link-args "--uplink-queue=droptail --uplink-queue-args=packets=1 --downlink-queue=droptail --downlink-queue-args=packets=1" 110 | # 9, 'Token-bucket based policer (bandwidth 108mbps, RTT 20ms)' 111 | - scenario: 9 112 | desc: Token-bucket based policer (bandwidth 108mbps, RTT 20ms) 113 | command: >- 114 | {test_path} {common_param_set} 115 | --uplink-trace {traces_dir}/108mbps.trace 116 | --downlink-trace {traces_dir}/108mbps.trace 117 | --prepend-mm-cmds "mm-delay 10" 118 | --extra-mm-link-args "--uplink-queue=droptail --uplink-queue-args=packets=1 --downlink-queue=droptail --downlink-queue-args=packets=1" 119 | # 10, 'Token-bucket based policer (bandwidth 12mbps, RTT 100ms)' 120 | - scenario: 10 121 | desc: Token-bucket based policer (bandwidth 12mbps, RTT 100ms) 122 | command: >- 123 | {test_path} {common_param_set} 124 | --uplink-trace {traces_dir}/12mbps.trace 125 | --downlink-trace {traces_dir}/12mbps.trace 126 | --prepend-mm-cmds "mm-delay 50" 127 | --extra-mm-link-args "--uplink-queue=droptail 
--uplink-queue-args=packets=1 --downlink-queue=droptail --downlink-queue-args=packets=1" 128 | # 11, 'Token-bucket based policer (bandwidth 60mbps, RTT 100ms)' 129 | - scenario: 11 130 | desc: Token-bucket based policer (bandwidth 60mbps, RTT 100ms) 131 | command: >- 132 | {test_path} {common_param_set} 133 | --uplink-trace {traces_dir}/60mbps.trace 134 | --downlink-trace {traces_dir}/60mbps.trace 135 | --prepend-mm-cmds "mm-delay 50" 136 | --extra-mm-link-args "--uplink-queue=droptail --uplink-queue-args=packets=1 --downlink-queue=droptail --downlink-queue-args=packets=1" 137 | # 12, 'Token-bucket based policer (bandwidth 108mbps, RTT 100ms)' 138 | - scenario: 12 139 | desc: Token-bucket based policer (bandwidth 108mbps, RTT 100ms) 140 | command: >- 141 | {test_path} {common_param_set} 142 | --uplink-trace {traces_dir}/108mbps.trace 143 | --downlink-trace {traces_dir}/108mbps.trace 144 | --prepend-mm-cmds "mm-delay 50" 145 | --extra-mm-link-args "--uplink-queue=droptail --uplink-queue-args=packets=1 --downlink-queue=droptail --downlink-queue-args=packets=1" 146 | # 13, 'Severe ACK aggregation (1 ACK every 100ms)' 147 | - scenario: 13 148 | desc: 'Severe ACK aggregation (1 ACK every 100ms)' 149 | command: >- 150 | {test_path} {common_param_set} 151 | --uplink-trace {traces_dir}/12mbps.trace 152 | --downlink-trace {traces_dir}/0.12mbps.trace 153 | --prepend-mm-cmds "mm-delay 10" 154 | # 14, 'Severe ACK aggregation (10 ACKs every 200ms)' 155 | - scenario: 14 156 | desc: 'Severe ACK aggregation (10 ACKs every 200ms)' 157 | command: >- 158 | {test_path} {common_param_set} 159 | --uplink-trace {traces_dir}/12mbps.trace 160 | --downlink-trace {traces_dir}/10-every-200.trace 161 | --prepend-mm-cmds "mm-delay 10" 162 | # 15, 'Bottleneck buffer = BDP/10' 163 | - scenario: 15 164 | desc: 'Bottleneck buffer = BDP/10' 165 | command: >- 166 | {test_path} {common_param_set} 167 | --uplink-trace {traces_dir}/12mbps.trace 168 | --downlink-trace {traces_dir}/12mbps.trace 169 | 
--prepend-mm-cmds "mm-delay 30" 170 | --extra-mm-link-args "--uplink-queue=droptail --uplink-queue-args=bytes=9000" 171 | # 16, 'Bottleneck buffer = BDP/3' 172 | - scenario: 16 173 | desc: 'Bottleneck buffer = BDP/3' 174 | command: >- 175 | {test_path} {common_param_set} 176 | --uplink-trace {traces_dir}/12mbps.trace 177 | --downlink-trace {traces_dir}/12mbps.trace 178 | --prepend-mm-cmds "mm-delay 30" 179 | --extra-mm-link-args "--uplink-queue=droptail --uplink-queue-args=bytes=30000" 180 | # 17, 'Bottleneck buffer = BDP/2' 181 | - scenario: 17 182 | desc: 'Bottleneck buffer = BDP/2' 183 | command: >- 184 | {test_path} {common_param_set} 185 | --uplink-trace {traces_dir}/12mbps.trace 186 | --downlink-trace {traces_dir}/12mbps.trace 187 | --prepend-mm-cmds "mm-delay 30" 188 | --extra-mm-link-args "--uplink-queue=droptail --uplink-queue-args=bytes=45000" 189 | # 18, 'Bottleneck buffer = BDP' 190 | - scenario: 18 191 | desc: 'Bottleneck buffer = BDP' 192 | command: >- 193 | {test_path} {common_param_set} 194 | --uplink-trace {traces_dir}/12mbps.trace 195 | --downlink-trace {traces_dir}/12mbps.trace 196 | --prepend-mm-cmds "mm-delay 30" 197 | --extra-mm-link-args "--uplink-queue=droptail --uplink-queue-args=bytes=90000" 198 | -------------------------------------------------------------------------------- /third-party/gala/envs.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | # 7 | """ 8 | MIT License 9 | 10 | Copyright (c) 2017 Ilya Kostrikov 11 | 12 | Permission is hereby granted, free of charge, to any person obtaining a copy 13 | of this software and associated documentation files (the "Software"), to deal 14 | in the Software without restriction, including without limitation the rights 15 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 16 | copies of the Software, and to permit persons to whom the Software is 17 | furnished to do so, subject to the following conditions: 18 | 19 | The above copyright notice and this permission notice shall be included in all 20 | copies or substantial portions of the Software. 21 | 22 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 23 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 24 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 25 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 26 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 27 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 28 | SOFTWARE. 
29 | 30 | Modified from https://github.com/ikostrikov/pytorch-a2c-ppo-acktr-gail 31 | """ 32 | 33 | import os 34 | 35 | import gym 36 | import numpy as np 37 | import torch 38 | from gym.spaces.box import Box 39 | 40 | from baselines import bench 41 | from baselines.common.atari_wrappers import make_atari, wrap_deepmind 42 | from baselines.common.vec_env import VecEnvWrapper 43 | from baselines.common.vec_env.dummy_vec_env import DummyVecEnv 44 | from baselines.common.vec_env.shmem_vec_env import ShmemVecEnv 45 | from baselines.common.vec_env.vec_normalize import \ 46 | VecNormalize as VecNormalize_ 47 | 48 | try: 49 | import dm_control2gym 50 | except ImportError: 51 | pass 52 | 53 | def make_env(env_id, seed, rank, log_dir, allow_early_resets, signature='', 54 | max_steps=None): 55 | def _thunk(): 56 | if env_id.startswith("dm"): 57 | _, domain, task = env_id.split('.') 58 | env = dm_control2gym.make(domain_name=domain, task_name=task) 59 | else: 60 | env = gym.make(env_id) 61 | 62 | is_atari = hasattr(gym.envs, 'atari') and isinstance( 63 | env.unwrapped, gym.envs.atari.atari_env.AtariEnv) 64 | if is_atari: 65 | env = make_atari(env_id, max_steps) 66 | 67 | env.seed(seed + rank) 68 | 69 | obs_shape = env.observation_space.shape 70 | 71 | if str(env.__class__.__name__).find('TimeLimit') >= 0: 72 | env = TimeLimitMask(env) 73 | 74 | if log_dir is not None: 75 | env = bench.Monitor( 76 | env, 77 | os.path.join(log_dir, str(rank) + signature), 78 | allow_early_resets=allow_early_resets) 79 | 80 | if is_atari: 81 | if len(env.observation_space.shape) == 3: 82 | env = wrap_deepmind(env) 83 | elif len(env.observation_space.shape) == 3: 84 | raise NotImplementedError( 85 | "CNN models work only for atari,\n" 86 | "please use a custom wrapper for a custom pixel input env.\n" 87 | "See wrap_deepmind for an example.") 88 | 89 | # If the input has shape (W,H,3), wrap for PyTorch convolutions 90 | obs_shape = env.observation_space.shape 91 | if len(obs_shape) == 3 and 
obs_shape[2] in [1, 3]: 92 | env = TransposeImage(env, op=[2, 0, 1]) 93 | 94 | return env 95 | 96 | return _thunk 97 | 98 | 99 | def make_vec_envs(env_name, seed, num_processes, gamma, log_dir, device, 100 | allow_early_resets, num_frame_stack=None, rank=0, 101 | signature='', max_steps=None): 102 | print('log-dir', log_dir) 103 | envs = [ 104 | make_env(env_name, seed, (rank * num_processes) + i, log_dir, 105 | allow_early_resets, signature, max_steps) 106 | for i in range(num_processes) 107 | ] 108 | 109 | if len(envs) > 1: 110 | envs = ShmemVecEnv(envs) 111 | else: 112 | envs = DummyVecEnv(envs) 113 | 114 | if len(envs.observation_space.shape) == 1: 115 | if gamma is None: 116 | envs = VecNormalize(envs, ret=False) 117 | else: 118 | envs = VecNormalize(envs, gamma=gamma) 119 | 120 | envs = VecPyTorch(envs, device) 121 | 122 | if num_frame_stack is not None: 123 | envs = VecPyTorchFrameStack(envs, num_frame_stack, device) 124 | elif len(envs.observation_space.shape) == 3: 125 | envs = VecPyTorchFrameStack(envs, 4, device) 126 | 127 | return envs 128 | 129 | 130 | # Checks whether done was caused my timit limits or not 131 | class TimeLimitMask(gym.Wrapper): 132 | def step(self, action): 133 | obs, rew, done, info = self.env.step(action) 134 | if done and self.env._max_episode_steps == self.env._elapsed_steps: 135 | info['bad_transition'] = True 136 | 137 | return obs, rew, done, info 138 | 139 | def reset(self, **kwargs): 140 | return self.env.reset(**kwargs) 141 | 142 | 143 | # Can be used to test recurrent policies for Reacher-v2 144 | class MaskGoal(gym.ObservationWrapper): 145 | def observation(self, observation): 146 | if self.env._elapsed_steps > 0: 147 | observation[-2:0] = 0 148 | return observation 149 | 150 | 151 | class TransposeObs(gym.ObservationWrapper): 152 | def __init__(self, env=None): 153 | """ 154 | Transpose observation space (base class) 155 | """ 156 | super(TransposeObs, self).__init__(env) 157 | 158 | 159 | class 
TransposeImage(TransposeObs): 160 | def __init__(self, env=None, op=[2, 0, 1]): 161 | """ 162 | Transpose observation space for images 163 | """ 164 | super(TransposeImage, self).__init__(env) 165 | assert len(op) == 3, f"Error: Operation, {str(op)}, must be dim3" 166 | self.op = op 167 | obs_shape = self.observation_space.shape 168 | self.observation_space = Box( 169 | self.observation_space.low[0, 0, 0], 170 | self.observation_space.high[0, 0, 0], [ 171 | obs_shape[self.op[0]], obs_shape[self.op[1]], 172 | obs_shape[self.op[2]] 173 | ], 174 | dtype=self.observation_space.dtype) 175 | 176 | def observation(self, ob): 177 | return ob.transpose(self.op[0], self.op[1], self.op[2]) 178 | 179 | 180 | class VecPyTorch(VecEnvWrapper): 181 | def __init__(self, venv, device): 182 | """Return only every `skip`-th frame""" 183 | super(VecPyTorch, self).__init__(venv) 184 | self.device = device 185 | # TODO: Fix data types 186 | 187 | def reset(self): 188 | obs = self.venv.reset() 189 | obs = torch.from_numpy(obs).float().to(self.device) 190 | return obs 191 | 192 | def step_async(self, actions): 193 | if isinstance(actions, torch.LongTensor): 194 | # Squeeze the dimension for discrete actions 195 | actions = actions.squeeze(1) 196 | actions = actions.cpu().numpy() 197 | self.venv.step_async(actions) 198 | 199 | def step_wait(self): 200 | obs, reward, done, info = self.venv.step_wait() 201 | obs = torch.from_numpy(obs).float().to(self.device) 202 | reward = torch.from_numpy(reward).unsqueeze(dim=1).float() 203 | return obs, reward, done, info 204 | 205 | 206 | class VecNormalize(VecNormalize_): 207 | def __init__(self, *args, **kwargs): 208 | super(VecNormalize, self).__init__(*args, **kwargs) 209 | self.training = True 210 | 211 | def _obfilt(self, obs, update=True): 212 | if self.ob_rms: 213 | if self.training and update: 214 | self.ob_rms.update(obs) 215 | obs = np.clip((obs - self.ob_rms.mean) / 216 | np.sqrt(self.ob_rms.var + self.epsilon), 217 | -self.clipob, 
self.clipob) 218 | return obs 219 | else: 220 | return obs 221 | 222 | def train(self): 223 | self.training = True 224 | 225 | def eval(self): 226 | self.training = False 227 | 228 | 229 | # Derived from 230 | # https://github.com/openai/baselines/blob/master/baselines/common/vec_env/vec_frame_stack.py 231 | class VecPyTorchFrameStack(VecEnvWrapper): 232 | def __init__(self, venv, nstack, device=None): 233 | self.venv = venv 234 | self.nstack = nstack 235 | 236 | wos = venv.observation_space # wrapped ob space 237 | self.shape_dim0 = wos.shape[0] 238 | 239 | low = np.repeat(wos.low, self.nstack, axis=0) 240 | high = np.repeat(wos.high, self.nstack, axis=0) 241 | 242 | if device is None: 243 | device = torch.device('cpu') 244 | self.stacked_obs = torch.zeros((venv.num_envs, ) + 245 | low.shape).to(device) 246 | 247 | observation_space = gym.spaces.Box( 248 | low=low, high=high, dtype=venv.observation_space.dtype) 249 | VecEnvWrapper.__init__(self, venv, observation_space=observation_space) 250 | 251 | def step_wait(self): 252 | obs, rews, news, infos = self.venv.step_wait() 253 | self.stacked_obs[:, :-self.shape_dim0] = \ 254 | self.stacked_obs[:, self.shape_dim0:] 255 | for (i, new) in enumerate(news): 256 | if new: 257 | self.stacked_obs[i] = 0 258 | self.stacked_obs[:, -self.shape_dim0:] = obs 259 | return self.stacked_obs, rews, news, infos 260 | 261 | def reset(self): 262 | obs = self.venv.reset() 263 | if torch.backends.cudnn.deterministic: 264 | self.stacked_obs = torch.zeros(self.stacked_obs.shape) 265 | else: 266 | self.stacked_obs.zero_() 267 | self.stacked_obs[:, -self.shape_dim0:] = obs 268 | return self.stacked_obs 269 | 270 | def close(self): 271 | self.venv.close() 272 | -------------------------------------------------------------------------------- /third-party/gala/model.py: -------------------------------------------------------------------------------- 1 | """ 2 | MIT License 3 | 4 | Copyright (c) 2017 Ilya Kostrikov 5 | 6 | Permission is hereby 
granted, free of charge, to any person obtaining a copy 7 | of this software and associated documentation files (the "Software"), to deal 8 | in the Software without restriction, including without limitation the rights 9 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 10 | copies of the Software, and to permit persons to whom the Software is 11 | furnished to do so, subject to the following conditions: 12 | 13 | The above copyright notice and this permission notice shall be included in all 14 | copies or substantial portions of the Software. 15 | 16 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 17 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 18 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 19 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 20 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 21 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 22 | SOFTWARE. 23 | 24 | Taken from https://github.com/ikostrikov/pytorch-a2c-ppo-acktr-gail 25 | and slightly modified. 
26 | """ 27 | 28 | import numpy as np 29 | import torch 30 | import torch.nn as nn 31 | 32 | from gala.distributions import Bernoulli, Categorical, DiagGaussian 33 | 34 | 35 | def init(module, weight_init, bias_init, gain=1): 36 | weight_init(module.weight.data, gain=gain) 37 | bias_init(module.bias.data) 38 | return module 39 | 40 | 41 | class Flatten(nn.Module): 42 | def forward(self, x): 43 | return x.view(x.size(0), -1) 44 | 45 | 46 | class Policy(nn.Module): 47 | def __init__(self, obs_shape, action_space=None, base=None, base_kwargs=None, env_name=None): 48 | super(Policy, self).__init__() 49 | if base_kwargs is None: 50 | base_kwargs = {} 51 | if base is None: 52 | if len(obs_shape) == 3: 53 | base = CNNBase 54 | elif len(obs_shape) == 1: 55 | base = MLPBase 56 | else: 57 | raise NotImplementedError 58 | 59 | self.base = base(obs_shape[0], **base_kwargs) 60 | 61 | if action_space is None: 62 | game = env_name[:env_name.find('NoFrameskip')] 63 | num_actions = { 64 | 'BeamRider': 9, 65 | 'Breakout': 4, 66 | 'Pong': 6, 67 | 'Qbert': 6, 68 | 'Seaquest': 18, 69 | 'SpaceInvaders': 6, 70 | } 71 | num_outputs = num_actions[game] 72 | self.dist = Categorical(self.base.output_size, num_outputs) 73 | elif action_space.__class__.__name__ == "Discrete": 74 | num_outputs = action_space.n 75 | self.dist = Categorical(self.base.output_size, num_outputs) 76 | elif action_space.__class__.__name__ == "Box": 77 | num_outputs = action_space.shape[0] 78 | self.dist = DiagGaussian(self.base.output_size, num_outputs) 79 | elif action_space.__class__.__name__ == "MultiBinary": 80 | num_outputs = action_space.shape[0] 81 | self.dist = Bernoulli(self.base.output_size, num_outputs) 82 | else: 83 | raise NotImplementedError 84 | 85 | @property 86 | def is_recurrent(self): 87 | return self.base.is_recurrent 88 | 89 | @property 90 | def recurrent_hidden_state_size(self): 91 | """Size of rnn_hx.""" 92 | return self.base.recurrent_hidden_state_size 93 | 94 | def forward(self, inputs, 
rnn_hxs, masks): 95 | raise NotImplementedError 96 | 97 | def act(self, inputs, rnn_hxs, masks, deterministic=False): 98 | value, actor_features, rnn_hxs = self.base(inputs, rnn_hxs, masks) 99 | dist = self.dist(actor_features) 100 | 101 | if deterministic: 102 | action = dist.mode() 103 | else: 104 | action = dist.sample() 105 | 106 | action_log_probs = dist.log_probs(action) 107 | dist_entropy = dist.entropy().mean() 108 | 109 | return value, action, action_log_probs, rnn_hxs 110 | 111 | def get_value(self, inputs, rnn_hxs, masks): 112 | value, _, _ = self.base(inputs, rnn_hxs, masks) 113 | return value 114 | 115 | def evaluate_actions(self, inputs, rnn_hxs, masks, action): 116 | value, actor_features, rnn_hxs = self.base(inputs, rnn_hxs, masks) 117 | dist = self.dist(actor_features) 118 | 119 | action_log_probs = dist.log_probs(action) 120 | dist_entropy = dist.entropy().mean() 121 | 122 | return value, action_log_probs, dist_entropy, rnn_hxs 123 | 124 | 125 | class NNBase(nn.Module): 126 | def __init__(self, recurrent, recurrent_input_size, hidden_size): 127 | super(NNBase, self).__init__() 128 | 129 | self._hidden_size = hidden_size 130 | self._recurrent = recurrent 131 | 132 | if recurrent: 133 | self.gru = nn.GRU(recurrent_input_size, hidden_size) 134 | for name, param in self.gru.named_parameters(): 135 | if 'bias' in name: 136 | nn.init.constant_(param, 0) 137 | elif 'weight' in name: 138 | nn.init.orthogonal_(param) 139 | 140 | @property 141 | def is_recurrent(self): 142 | return self._recurrent 143 | 144 | @property 145 | def recurrent_hidden_state_size(self): 146 | if self._recurrent: 147 | return self._hidden_size 148 | return 1 149 | 150 | @property 151 | def output_size(self): 152 | return self._hidden_size 153 | 154 | def _forward_gru(self, x, hxs, masks): 155 | if x.size(0) == hxs.size(0): 156 | x, hxs = self.gru(x.unsqueeze(0), (hxs * masks).unsqueeze(0)) 157 | x = x.squeeze(0) 158 | hxs = hxs.squeeze(0) 159 | else: 160 | # x is a (T, N, -1) 
tensor that has been flatten to (T * N, -1) 161 | N = hxs.size(0) 162 | T = int(x.size(0) / N) 163 | 164 | # unflatten 165 | x = x.view(T, N, x.size(1)) 166 | 167 | # Same deal with masks 168 | masks = masks.view(T, N) 169 | 170 | # Let's figure out which steps in the sequence have a zero for any agent 171 | # We will always assume t=0 has a zero in it as that makes the logic cleaner 172 | has_zeros = ((masks[1:] == 0.0) \ 173 | .any(dim=-1) 174 | .nonzero() 175 | .squeeze() 176 | .cpu()) 177 | 178 | # +1 to correct the masks[1:] 179 | if has_zeros.dim() == 0: 180 | # Deal with scalar 181 | has_zeros = [has_zeros.item() + 1] 182 | else: 183 | has_zeros = (has_zeros + 1).numpy().tolist() 184 | 185 | # add t=0 and t=T to the list 186 | has_zeros = [0] + has_zeros + [T] 187 | 188 | hxs = hxs.unsqueeze(0) 189 | outputs = [] 190 | for i in range(len(has_zeros) - 1): 191 | # We can now process steps that don't have any zeros in masks together! 192 | # This is much faster 193 | start_idx = has_zeros[i] 194 | end_idx = has_zeros[i + 1] 195 | 196 | rnn_scores, hxs = self.gru( 197 | x[start_idx:end_idx], 198 | hxs * masks[start_idx].view(1, -1, 1)) 199 | 200 | outputs.append(rnn_scores) 201 | 202 | # assert len(outputs) == T 203 | # x is a (T, N, -1) tensor 204 | x = torch.cat(outputs, dim=0) 205 | # flatten 206 | x = x.view(T * N, -1) 207 | hxs = hxs.squeeze(0) 208 | 209 | return x, hxs 210 | 211 | 212 | class CNNBase(NNBase): 213 | def __init__(self, num_inputs, recurrent=False, hidden_size=512): 214 | super(CNNBase, self).__init__(recurrent, hidden_size, hidden_size) 215 | 216 | init_ = lambda m: init(m, nn.init.orthogonal_, lambda x: nn.init. 
217 | constant_(x, 0), nn.init.calculate_gain('relu')) 218 | 219 | self.main = nn.Sequential( 220 | init_(nn.Conv2d(num_inputs, 32, 8, stride=4)), nn.ReLU(), 221 | init_(nn.Conv2d(32, 64, 4, stride=2)), nn.ReLU(), 222 | init_(nn.Conv2d(64, 32, 3, stride=1)), nn.ReLU(), Flatten(), 223 | init_(nn.Linear(32 * 7 * 7, hidden_size)), nn.ReLU()) 224 | 225 | init_ = lambda m: init(m, nn.init.orthogonal_, lambda x: nn.init. 226 | constant_(x, 0)) 227 | 228 | self.critic_linear = init_(nn.Linear(hidden_size, 1)) 229 | 230 | self.train() 231 | 232 | def forward(self, inputs, rnn_hxs, masks): 233 | x = self.main(inputs / 255.0) 234 | 235 | if self.is_recurrent: 236 | x, rnn_hxs = self._forward_gru(x, rnn_hxs, masks) 237 | 238 | return self.critic_linear(x), x, rnn_hxs 239 | 240 | 241 | class MLPBase(NNBase): 242 | def __init__(self, num_inputs, recurrent=False, hidden_size=64): 243 | super(MLPBase, self).__init__(recurrent, num_inputs, hidden_size) 244 | 245 | if recurrent: 246 | num_inputs = hidden_size 247 | 248 | init_ = lambda m: init(m, nn.init.orthogonal_, lambda x: nn.init. 
249 | constant_(x, 0), np.sqrt(2)) 250 | 251 | self.actor = nn.Sequential( 252 | init_(nn.Linear(num_inputs, hidden_size)), nn.Tanh(), 253 | init_(nn.Linear(hidden_size, hidden_size)), nn.Tanh()) 254 | 255 | self.critic = nn.Sequential( 256 | init_(nn.Linear(num_inputs, hidden_size)), nn.Tanh(), 257 | init_(nn.Linear(hidden_size, hidden_size)), nn.Tanh()) 258 | 259 | self.critic_linear = init_(nn.Linear(hidden_size, 1)) 260 | 261 | self.train() 262 | 263 | def forward(self, inputs, rnn_hxs, masks): 264 | x = inputs 265 | 266 | if self.is_recurrent: 267 | x, rnn_hxs = self._forward_gru(x, rnn_hxs, masks) 268 | 269 | hidden_critic = self.critic(x) 270 | hidden_actor = self.actor(x) 271 | 272 | return self.critic_linear(hidden_critic), hidden_actor, rnn_hxs 273 | -------------------------------------------------------------------------------- /third-party/gala/graph_manager.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # 7 | """ 8 | Graph Manager Class 9 | 10 | :author: Mido Assran 11 | :description: Class provides an API for loading different peer-to-peer 12 | communication topologies, and cycling through peers. 13 | """ 14 | 15 | from math import log as mlog 16 | 17 | 18 | class GraphManager(object): 19 | 20 | def __init__(self, rank, world_size, peers_per_itr=1): 21 | assert int(peers_per_itr) >= 1 22 | self.rank = rank 23 | self.world_size = world_size 24 | self.phone_book = self._make_graph() 25 | self._peers_per_itr = peers_per_itr 26 | self._group_indices = [i for i in range(peers_per_itr)] 27 | 28 | @property 29 | def peers_per_itr(self): 30 | return self._peers_per_itr 31 | 32 | @peers_per_itr.setter 33 | def peers_per_itr(self, v): 34 | self._peers_per_itr = v 35 | # set group-indices attr. 
--- point to out-peers in phone-book 36 | self._group_indices = [i for i in range(v)] 37 | 38 | def _make_graph(self): 39 | """ 40 | Returns a nested list of peers; the outer-list is indexed by rank, 41 | the inner list denotes the set of peers that 'rank' can send 42 | messages to at any point in time 43 | """ 44 | raise NotImplementedError 45 | 46 | def is_regular_graph(self): 47 | """ Whether each node has the same number of in-peers as out-peers """ 48 | raise NotImplementedError 49 | 50 | def is_bipartite_graph(self): 51 | """ Whether graph is bipartite or not """ 52 | raise NotImplementedError 53 | 54 | def is_passive(self, rank=None): 55 | """ Whether 'rank' is a passive node or not """ 56 | raise NotImplementedError 57 | 58 | def is_dynamic_graph(self, graph_type=None): 59 | """ Whether the graph-type is dynamic (as opposed to static) """ 60 | raise NotImplementedError 61 | 62 | def get_peers(self, rotate=False): 63 | """ Returns the out and in-peers corresponding to 'self.rank' """ 64 | # cycle through in- and out-peers by updating group-index 65 | if rotate: 66 | self._rotate_group_indices() 67 | 68 | # get out- and in-peers using new group-indices 69 | out_peers, in_peers = [], [] 70 | for group_index in self._group_indices: 71 | out_peers.append(self.phone_book[self.rank][group_index]) 72 | for rank, peers in enumerate(self.phone_book): 73 | if rank == self.rank: 74 | continue 75 | if self.rank == peers[group_index]: 76 | in_peers.append(rank) 77 | return out_peers, in_peers 78 | 79 | def _rotate_group_indices(self): 80 | """ Incerement group indices to point to the next out-peer """ 81 | increment = self.peers_per_itr 82 | for i, group_index in enumerate(self._group_indices): 83 | self._group_indices[i] = int((group_index + increment) 84 | % len(self.phone_book[self.rank])) 85 | 86 | def _rotate_forward(self, r, p): 87 | """ Helper function returns peer that is p hops ahead of r """ 88 | return (r + p) % self.world_size 89 | 90 | def 
_rotate_backward(self, r, p): 91 | """ Helper function returns peer that is p hops behind r """ 92 | temp = r 93 | for _ in range(p): 94 | temp -= 1 95 | if temp < 0: 96 | temp = self.world_size - 1 97 | return temp 98 | 99 | 100 | class DynamicDirectedExponentialGraph(GraphManager): 101 | 102 | def _make_graph(self): 103 | phone_book = [[] for _ in range(self.world_size)] 104 | for rank in range(self.world_size): 105 | group = phone_book[rank] 106 | for i in range(0, int(mlog(self.world_size - 1, 2)) + 1): 107 | f_peer = self._rotate_forward(rank, 2 ** i) 108 | if f_peer not in group: 109 | group.append(f_peer) 110 | b_peer = self._rotate_backward(rank, 2 ** i) 111 | if b_peer not in group: 112 | group.append(b_peer) 113 | return phone_book 114 | 115 | def is_regular_graph(self): return True 116 | 117 | def is_bipartite_graph(self): return False 118 | 119 | def is_passive(self, rank=None): return False 120 | 121 | def is_dynamic_graph(self, graph_type=None): return True 122 | 123 | 124 | class DynamicBipartiteExponentialGraph(GraphManager): 125 | 126 | def _make_graph(self): 127 | phone_book = [[] for _ in range(self.world_size)] 128 | for rank in range(self.world_size): 129 | group = phone_book[rank] 130 | for i in range(0, int(mlog(self.world_size - 1, 2)) + 1): 131 | if i == 0: 132 | f_peer = self._rotate_forward(rank, 1) 133 | b_peer = self._rotate_backward(rank, 1) 134 | else: 135 | f_peer = self._rotate_forward(rank, 1 + 2 ** i) 136 | b_peer = self._rotate_backward(rank, 1 + 2 ** i) 137 | # create directory for non-passive peers 138 | if not self.is_passive(rank) and ( 139 | self.is_passive(f_peer) and self.is_passive(b_peer)): 140 | if f_peer not in group: 141 | group.append(f_peer) # forward peer... 
142 | if b_peer not in group: 143 | group.append(b_peer) # then backward peer 144 | # create directory for passive peers 145 | elif self.is_passive(rank) and ( 146 | not (self.is_passive(f_peer) or self.is_passive(b_peer))): 147 | if b_peer not in group: 148 | group.append(b_peer) # backward peer... 149 | if f_peer not in group: 150 | group.append(f_peer) # then forward peer 151 | return phone_book 152 | 153 | def is_regular_graph(self): return True 154 | 155 | def is_bipartite_graph(self): return True 156 | 157 | def is_passive(self, rank=None): 158 | rank = self.rank if rank is None else rank 159 | return (rank % 2) == 0 160 | 161 | def is_dynamic_graph(self, graph_type=None): return True 162 | 163 | 164 | class DynamicDirectedLinearGraph(GraphManager): 165 | 166 | def _make_graph(self): 167 | phone_book = [[] for _ in range(self.world_size)] 168 | for rank in range(self.world_size): 169 | group = phone_book[rank] 170 | for i in range(1, self.world_size): 171 | if i % 2 == 0: 172 | continue 173 | f_peer = self._rotate_forward(rank, i) 174 | if f_peer not in group: 175 | group.append(f_peer) 176 | b_peer = self._rotate_backward(rank, i) 177 | if b_peer not in group: 178 | group.append(b_peer) 179 | return phone_book 180 | 181 | def is_regular_graph(self): return True 182 | 183 | def is_bipartite_graph(self): return False 184 | 185 | def is_passive(self, rank=None): return False 186 | 187 | def is_dynamic_graph(self, graph_type=None): return True 188 | 189 | 190 | class DynamicBipartiteLinearGraph(GraphManager): 191 | 192 | def _make_graph(self): 193 | phone_book = [[] for _ in range(self.world_size)] 194 | for rank in range(self.world_size): 195 | group = phone_book[rank] 196 | for i in range(1, self.world_size): 197 | f_peer = self._rotate_forward(rank, i) 198 | b_peer = self._rotate_backward(rank, i) 199 | # create directory for non-passive peers 200 | if not self.is_passive(rank) and ( 201 | self.is_passive(f_peer) and self.is_passive(b_peer)): 202 | if f_peer 
not in group: 203 | group.append(f_peer) # forward peer... 204 | if b_peer not in group: 205 | group.append(b_peer) # then backward peer 206 | # create directory for passive peers 207 | elif self.is_passive(rank) and ( 208 | not (self.is_passive(f_peer) or self.is_passive(b_peer))): 209 | if b_peer not in group: 210 | group.append(b_peer) # backward peer... 211 | if f_peer not in group: 212 | group.append(f_peer) # then forward peer 213 | return phone_book 214 | 215 | def is_regular_graph(self): return True 216 | 217 | def is_bipartite_graph(self): return True 218 | 219 | def is_passive(self, rank=None): 220 | rank = self.rank if rank is None else rank 221 | return (rank % 2) == 0 222 | 223 | def is_dynamic_graph(self, graph_type=None): return True 224 | 225 | 226 | class StaticDirectedLinearGraph(GraphManager): 227 | 228 | def _make_graph(self): 229 | phone_book = [[] for _ in range(self.world_size)] 230 | for rank in range(self.world_size): 231 | group = phone_book[rank] 232 | f_peer = self._rotate_forward(rank, 1) 233 | if f_peer not in group: 234 | group.append(f_peer) 235 | return phone_book 236 | 237 | def is_regular_graph(self): return True 238 | 239 | def is_bipartite_graph(self): return False 240 | 241 | def is_passive(self, rank=None): return False 242 | 243 | def is_dynamic_graph(self, graph_type=None): return False 244 | 245 | 246 | class FullyConnectedGraph(GraphManager): 247 | 248 | def _make_graph(self): 249 | phone_book = [[] for _ in range(self.world_size)] 250 | for rank in range(self.world_size): 251 | group = phone_book[rank] 252 | for i in range(1, self.world_size): 253 | f_peer = self._rotate_forward(rank, i) 254 | if f_peer not in group and f_peer != rank: 255 | group.append(f_peer) 256 | return phone_book 257 | 258 | def is_regular_graph(self): return True 259 | 260 | def is_bipartite_graph(self): return False 261 | 262 | def is_passive(self, rank=None): return False 263 | 264 | def is_dynamic_graph(self, graph_type=None): return False 265 | 
-------------------------------------------------------------------------------- /congestion_control/CongestionControlEnv.cpp: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) Facebook, Inc. and its affiliates. 3 | * All rights reserved. 4 | * 5 | * This source code is licensed under the license found in the 6 | * LICENSE file in the root directory of this source tree. 7 | * 8 | */ 9 | #include "CongestionControlEnv.h" 10 | 11 | #include 12 | #include 13 | 14 | namespace quic { 15 | 16 | using Field = NetworkState::Field; 17 | 18 | static const float kBytesToMB = 1e-6; 19 | 20 | /// CongestionControlEnv impl 21 | 22 | CongestionControlEnv::CongestionControlEnv(const Config &cfg, Callback *cob, 23 | const QuicConnectionStateBase &conn) 24 | : cfg_(cfg), cob_(CHECK_NOTNULL(cob)), conn_(conn), 25 | evb_(folly::EventBaseManager::get()->getEventBase()), 26 | observationTimeout_(this, evb_), 27 | cwndBytes_(conn.transportSettings.initCwndInMss * conn.udpSendPacketLen), 28 | rewardCount_(0), rewardSum_(0.f) { 29 | // Initialize history with no-op past actions 30 | History noopHistory(Action{0}, cwndBytes_ / normBytes()); 31 | history_.resize(cfg.historySize, noopHistory); 32 | 33 | if (cfg.aggregation == Config::Aggregation::TIME_WINDOW) { 34 | CHECK_GT(cfg.windowDuration.count(), 0); 35 | observationTimeout_.schedule(cfg.windowDuration); 36 | } 37 | } 38 | 39 | void CongestionControlEnv::onAction(const Action &action) { 40 | evb_->runImmediatelyOrRunInEventBaseThreadAndWait([this, action] { 41 | updateCwnd(action.cwndAction); 42 | cob_->onUpdate(cwndBytes_); 43 | 44 | // Update history 45 | history_.pop_front(); 46 | history_.emplace_back(action, cwndBytes_ / normBytes()); 47 | 48 | const auto &elapsed = std::chrono::duration( 49 | std::chrono::steady_clock::now() - lastObservationTime_); 50 | VLOG(1) << "Action updated (cwndAction=" << action.cwndAction 51 | << ", cwnd=" << cwndBytes_ / conn_.udpSendPacketLen 52 
| << "), policy elapsed time = " << elapsed.count() << " ms"; 53 | }); 54 | } 55 | 56 | uint64_t CongestionControlEnv::getUpdatedCwndBytes(uint64_t currentCwndBytes, 57 | uint32_t actionIdx) const { 58 | DCHECK_LT(actionIdx, cfg_.actions.size()); 59 | const auto &op = cfg_.actions[actionIdx].first; 60 | const auto &val = cfg_.actions[actionIdx].second; 61 | const auto &valBytes = val * conn_.udpSendPacketLen; 62 | 63 | switch (op) { 64 | case Config::ActionOp::NOOP: 65 | break; 66 | case Config::ActionOp::ADD: 67 | currentCwndBytes += valBytes; 68 | break; 69 | case Config::ActionOp::SUB: 70 | currentCwndBytes = 71 | (currentCwndBytes >= valBytes) ? (currentCwndBytes - valBytes) : 0; 72 | break; 73 | case Config::ActionOp::MUL: 74 | currentCwndBytes = std::round(currentCwndBytes * val); 75 | break; 76 | case Config::ActionOp::DIV: 77 | currentCwndBytes = std::round(currentCwndBytes * 1.0 / val); 78 | break; 79 | default: 80 | LOG(FATAL) << "Unknown ActionOp"; 81 | break; 82 | } 83 | 84 | return boundedCwnd(currentCwndBytes, conn_.udpSendPacketLen, 85 | conn_.transportSettings.maxCwndInMss, 86 | conn_.transportSettings.minCwndInMss); 87 | } 88 | 89 | void CongestionControlEnv::onNetworkState(NetworkState &&state) { 90 | VLOG(3) << __func__ << ": " << state; 91 | 92 | states_.push_back(std::move(state)); 93 | 94 | switch (cfg_.aggregation) { 95 | case Config::Aggregation::TIME_WINDOW: 96 | DCHECK(observationTimeout_.isScheduled()); 97 | break; 98 | case Config::Aggregation::FIXED_WINDOW: 99 | if (states_.size() == cfg_.windowSize) { 100 | handleStates(); 101 | } 102 | break; 103 | default: 104 | LOG(FATAL) << "Unknown aggregation type"; 105 | break; 106 | } 107 | } 108 | 109 | void CongestionControlEnv::observationTimeoutExpired() noexcept { 110 | handleStates(); 111 | observationTimeout_.schedule(cfg_.windowDuration); 112 | } 113 | 114 | void CongestionControlEnv::handleStates() { 115 | if (states_.empty()) { 116 | return; 117 | } 118 | 119 | // Compute reward based 
on original states 120 | const float reward = computeReward(states_); 121 | 122 | ++rewardCount_; 123 | rewardSum_ += reward; 124 | if (rewardCount_ % 10 == 0) { 125 | VLOG(1) << __func__ << ": for jobId= " << cfg_.jobId 126 | << ", after " << rewardCount_ 127 | << " steps, avg reward = " << (rewardSum_ / rewardCount_); 128 | } 129 | 130 | Observation obs(cfg_); 131 | obs.states = useStateSummary() ? stateSummary(states_) : std::move(states_); 132 | states_.clear(); 133 | std::copy(history_.begin(), history_.end(), std::back_inserter(obs.history)); 134 | 135 | VLOG(2) << __func__ << ' ' << obs; 136 | 137 | lastObservationTime_ = std::chrono::steady_clock::now(); 138 | onObservation(std::move(obs), reward); 139 | } 140 | 141 | quic::utils::vector CongestionControlEnv::stateSummary( 142 | const quic::utils::vector &states) { 143 | int dim = 0; 144 | bool keepdim = true; 145 | // Bassel's correction on stddev only when defined to avoid NaNs. 146 | bool unbiased = (states.size() > 1); 147 | 148 | NetworkState::toTensor(states, summaryTensor_); 149 | const auto &sum = torch::sum(summaryTensor_, dim, keepdim); 150 | const auto &std_mean = 151 | torch::std_mean(summaryTensor_, dim, unbiased, keepdim); 152 | const auto &min = torch::amin(summaryTensor_, dim, keepdim); 153 | const auto &max = torch::amax(summaryTensor_, dim, keepdim); 154 | // If these statistics are modified / re-ordered, make sure to also update 155 | // the corresponding `OFFSET_*` constants in state.py. 156 | const auto &summary = torch::cat( 157 | {sum, std::get<1>(std_mean), std::get<0>(std_mean), min, max}, dim); 158 | auto summaryStates = NetworkState::fromTensor(summary); 159 | 160 | // Certain stats for some fields don't make sense such as sum over 161 | // RTT from ACKs. Zero-out them. 
162 | static const quic::utils::vector invalidSumFields = { 163 | Field::RTT_MIN, Field::RTT_STANDING, Field::LRTT, 164 | Field::SRTT, Field::RTT_VAR, Field::DELAY, 165 | Field::CWND, Field::IN_FLIGHT, Field::WRITABLE, 166 | }; 167 | for (const Field field : invalidSumFields) { 168 | summaryStates[0][field] = 0.0; 169 | } 170 | 171 | static const quic::utils::vector keys = { 172 | "Sum", "Mean", "Std", "Min", "Max", 173 | }; 174 | VLOG(2) << "State summary: "; 175 | for (size_t i = 0; i < summaryStates.size(); ++i) { 176 | VLOG(2) << keys[i] << ": " << summaryStates[i]; 177 | } 178 | 179 | return summaryStates; 180 | } 181 | 182 | float CongestionControlEnv::computeReward( 183 | const quic::utils::vector &states) const { 184 | // Reward function is a combinaton of throughput, delay and lost bytes. 185 | // For throughput and delay, it makes sense to take the average, whereas 186 | // for loss, we compute the total bytes lost over these states. 187 | float avgThroughput = 0.0; 188 | float avgDelay = 0.0; 189 | float maxDelay = 0.0; 190 | float totalLost = 0.0; 191 | for (const auto &state : states) { 192 | avgThroughput += state[Field::THROUGHPUT]; 193 | avgDelay += state[Field::DELAY]; 194 | maxDelay = std::max(maxDelay, state[Field::DELAY]); 195 | totalLost += state[Field::LOST]; 196 | } 197 | avgThroughput /= states.size(); 198 | avgDelay /= states.size(); 199 | 200 | // Undo normalization and convert to MB/sec for throughput and ms for 201 | // delay. 202 | float throughputMBps = avgThroughput * normBytes() * kBytesToMB; 203 | float avgDelayMs = avgDelay * normMs(); 204 | float maxDelayMs = maxDelay * normMs(); 205 | float delayMs = (cfg_.maxDelayInReward ? 
maxDelayMs : avgDelayMs); 206 | float lostMbits = totalLost * normBytes() * kBytesToMB; 207 | 208 | float reward = 0.f; 209 | if (cfg_.rewardLogRatio) { 210 | reward = 211 | cfg_.throughputFactor * log(cfg_.throughputLogOffset + throughputMBps) - 212 | cfg_.delayFactor * log(cfg_.delayLogOffset + delayMs) - 213 | cfg_.packetLossFactor * log(cfg_.packetLossLogOffset + lostMbits); 214 | } else { 215 | reward = cfg_.throughputFactor * throughputMBps - 216 | cfg_.delayFactor * delayMs - cfg_.packetLossFactor * lostMbits; 217 | } 218 | VLOG(1) << "Num states = " << states.size() 219 | << " avg throughput = " << throughputMBps 220 | << " MB/sec, avg delay = " << avgDelayMs 221 | << " ms, max delay = " << maxDelayMs 222 | << " ms, total Mb lost = " << lostMbits << ", reward = " << reward; 223 | return reward; 224 | } 225 | 226 | void CongestionControlEnv::updateCwnd(const uint32_t actionIdx) { 227 | cwndBytes_ = getUpdatedCwndBytes(cwndBytes_, actionIdx); 228 | } 229 | 230 | /// CongestionControlEnv::Observation impl 231 | 232 | torch::Tensor CongestionControlEnv::Observation::toTensor() const { 233 | torch::Tensor tensor = torch::empty({0}, torch::kFloat32); 234 | toTensor(tensor); 235 | return tensor; 236 | } 237 | 238 | void CongestionControlEnv::Observation::toTensor(torch::Tensor &tensor) const { 239 | if (states.empty()) { 240 | tensor.resize_({0}); 241 | return; 242 | } 243 | 244 | CHECK_EQ(history.size(), cfg_.historySize); 245 | 246 | // Dim per history = len(one-hot actions) + 1 (cwnd). 
247 | // Total dim = flattened state dim + history dim + 1 (job ID) 248 | uint32_t historyDim = cfg_.actions.size() + 1; 249 | uint32_t dim = states.size() * states[0].size() + history.size() * historyDim + 1; 250 | 251 | tensor.resize_({dim}); 252 | auto tensor_a = tensor.accessor(); 253 | int x = 0; 254 | 255 | // Serialize states 256 | for (const auto &state : states) { 257 | for (size_t i = 0; i < state.size(); ++i) { 258 | tensor_a[x++] = state[i]; 259 | } 260 | } 261 | 262 | // Serialize history 263 | for (const auto &h : history) { 264 | for (size_t i = 0; i < cfg_.actions.size(); ++i) { 265 | tensor_a[x++] = (h.action.cwndAction == i); 266 | } 267 | tensor_a[x++] = h.cwnd; 268 | } 269 | 270 | // Append the job ID (IMPORTANT: it must remain at the end of the tensor) 271 | tensor_a[x++] = cfg_.jobId; 272 | 273 | CHECK_EQ(x, dim); 274 | } 275 | 276 | std::ostream &operator<<(std::ostream &os, 277 | const CongestionControlEnv::Observation &obs) { 278 | os << "Observation (" << obs.states.size() << " states, " 279 | << obs.history.size() << " history):" << std::endl; 280 | for (const auto &state : obs.states) { 281 | os << state << std::endl; 282 | } 283 | for (const auto &history : obs.history) { 284 | os << history << std::endl; 285 | } 286 | return os; 287 | } 288 | 289 | std::ostream &operator<<(std::ostream &os, 290 | const CongestionControlEnv::History &history) { 291 | os << "History: action=" << history.action.cwndAction 292 | << " cwnd=" << history.cwnd; 293 | return os; 294 | } 295 | 296 | } // namespace quic 297 | -------------------------------------------------------------------------------- /train/utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
#
import argparse
import inspect
import itertools
import logging
import os
import shutil
import signal
import socket
import string
import sys
import time
import yaml

from dataclasses import field
from datetime import datetime, timezone
from enum import Enum
from typing import Optional

import psutil

from train.constants import SRC_DIR, PANTHEON_ROOT, EXPERIMENTS_CFG


class _StrEnum(str, Enum):
    """
    "String" enum, such that one can use `if my_enum == "some_string"`.

    Do not instantiate this class directly: instead use the `StrEnum()` function.
    """

    def _generate_next_value_(name, start, count, last_values):
        # This makes it so that the enum's *value* is equal to its *name*.
        # As a result, string comparisons work as expected.
        return name


def StrEnum(name, values):
    """
    Helper function to create a "String" Enum.

    :param name: Name of the Enum being created. The module calling this function
        should not have any other member with the same name (because this function
        will add the created Enum as a member under that name).
    :param values: Values the Enum can take, either as a list of strings, or a
        single string with values separated by whitespaces or commas.
    :return: The corresponding Enum class.

    The motivation for using this helper function vs instantiating `_StrEnum`
    directly is that this function also registers the created Enum within the
    calling module. This makes it possible to pickle / unpickle this Enum as
    long as it is created at import time. The typical use case is within
    structured configs, as follows:

        @dataclass
        class MyConfig:
            color: StrEnum("Color", "red, green, blue") = "blue"

    Such a config (after being processed through Hydra / OmegaConf) may be used
    with string comparisons:

        if cfg.color == "red":
            ...
        elif cfg.color == "green":
            ...
        ...

    NB: a function is used instead of adding this logic to `_StrEnum.__init__()`
    because it turns out to be tricky to override an Enum constructor.
    """
    enum = _StrEnum(name, values)
    # Obtain the module this function was called from. Only the caller's frame is
    # needed: `inspect.stack()` would also load source context for every frame on
    # the stack, which is needlessly slow.
    caller = inspect.currentframe().f_back
    try:
        src_module = inspect.getmodule(caller)
    finally:
        # Frame references can create reference cycles (see `inspect` docs).
        del caller
    # Add the generated Enum as member of this module.
    assert not hasattr(
        src_module, name
    ), f"module {src_module} already has a member '{name}'"
    setattr(src_module, name, enum)
    # We also need to set the Enum's module accordingly. As a result the unpickling
    # sequence of an instance of this Enum will be as follows:
    #   1. Identify this Enum's module as `src_module`
    #   2. Import `src_module`
    #   3. During import, this function (= `StrEnum()`) is called
    #   4. During this call, the Enum class is created and registered into `src_module`
    #   5. Once `src_module` is imported, the Enum class is obtained from it and the
    #      Enum instance can be created
    enum.__module__ = src_module.__name__
    return enum


def add_to_path(path):
    """
    Add a path to `sys.path` / PYTHONPATH env variable.

    We also add it to PYTHONPATH so that when unpickling an object (e.g. with
    submitit) Python can find the required packages even if the module
    modifying `sys.path` has not been imported yet.
    """
    if path not in sys.path:
        sys.path.append(path)
    python_path = os.environ.get("PYTHONPATH", "")
    if not python_path:
        # Covers both "unset" and "set but empty": appending to an empty value
        # would create a leading empty entry, which Python interprets as the
        # current working directory.
        os.environ["PYTHONPATH"] = path
    elif path not in python_path.split(os.pathsep):
        os.environ["PYTHONPATH"] = python_path + os.pathsep + path


def default_empty_list():
    """
    Helper function to declare a dataclass field whose default value is an empty list.
    """
    return field(default_factory=list)


def default_list(lst):
    """
    Helper function to declare a dataclass field whose default value is the list `lst`.
    """
    # We make a copy of `lst` to be sure it is not accidentally shared.
    return field(default_factory=lambda: list(lst))


def get_actions(num_actions):
    """
    Return the list of cwnd action strings for a supported action-space size.

    Each string is "0" (no-op) or an operator ("+", "-", "*", "/") followed by
    its operand.
    """
    ACTIONS = {
        5: ["0", "/2", "-10", "+10", "*2"],
        7: ["0", "/2", "/1.5", "-10", "+10", "*1.5", "*2"],
        9: ["0", "/2", "/1.5", "/1.25", "-10", "+10", "*1.25", "*1.5", "*2"],
        11: [
            "0",
            "/5",
            "/2",
            "/1.5",
            "/1.25",
            "-10",
            "+10",
            "*1.25",
            "*1.5",
            "*2",
            "*5",
        ],
    }
    assert num_actions in ACTIONS, "Unsupported num_actions"
    return ACTIONS[num_actions]


def get_cpus_per_task(mode, num_actors, test_job_ids, test_after_train, max_jobs):
    """Return number of CPUs to reserve given the current settings"""
    from train import pantheon_env  # lazy import to avoid circular dependencies

    # Reserve 2 CPUs per Pantheon "thread" (at least 4 total). During training,
    # the number of threads is equal to the number of actors, while during
    # testing it is equal to the number of jobs.
    n_cpus_min = 4
    n_threads_train = n_threads_test = 0

    assert mode in ["train", "test"]
    if mode == "train":
        n_threads_train = num_actors
    if mode == "test" or test_after_train:
        jobs = pantheon_env.get_jobs_to_perform(test_job_ids, max_jobs)
        n_threads_test = len(jobs)

    return max(n_cpus_min, n_threads_train * 2, n_threads_test * 2)


def get_jobs(flags, mode=None):
    """Return the jobs to perform for `mode` (defaults to `flags.mode`)."""
    from train import pantheon_env  # lazy import to avoid circular dependencies

    mode = flags.mode if mode is None else mode
    if mode == "train":
        job_ids = flags.train_job_ids
    elif mode == "test":
        job_ids = flags.test_job_ids
    else:
        raise ValueError(mode)

    return pantheon_env.get_jobs_to_perform(job_ids, flags.max_jobs)


def get_n_jobs(flags, mode=None):
    """Return the number of jobs returned by `get_jobs()`."""
    return len(get_jobs(flags, mode=mode))


def get_observation_length(history_size, num_actions):
    # The observation contains:
    #   - state summary stats (5 * 20) (5 because sum / mean / std / min /max)
    #   - history_size * (one-hot actions + cwnd)
    #   - job ID
    return 100 + history_size * (num_actions + 1) + 1


def get_slurm_constraint(partition: str, gpus_per_node: int) -> Optional[str]:
    """Return the constraint to be used by the `submitit_slurm` launcher"""
    if partition in ["priority", "learnfair"] and gpus_per_node <= 2:
        # If we are on the right environment, use constraint "gpu2".
        host = socket.gethostname()
        if host.startswith("devfair") and len(host) == 11:  # H2?
            return "gpu2"
    return None


def str2bool(v):
    """
    Parse a boolean command-line value (argparse `type=` helper).

    Booleans are passed through unchanged (e.g. a `default=True` that argparse
    forwards as-is); strings are matched case-insensitively.
    """
    if isinstance(v, bool):
        return v
    lowered = v.lower()
    if lowered in ("yes", "true", "t", "y", "1"):
        return True
    elif lowered in ("no", "false", "f", "n", "0"):
        return False
    else:
        raise argparse.ArgumentTypeError("Boolean value expected.")


class SafeDict(dict):
    """Dict returning the placeholder itself ("{key}") for missing keys."""

    def __missing__(self, key):
        return "{" + key + "}"


# format 'format_string' but ignore keys that do not exist in 'key_dict'
def safe_format(format_string, key_dict):
    return string.Formatter().vformat(format_string, (), SafeDict(key_dict))


def parse_experiments():
    """Load the experiments YAML config, substituting `{src_dir}` / `{pantheon_root}`."""
    # NOTE: `full_load` is used on a local, trusted project file; do not point it
    # at untrusted input (prefer `yaml.safe_load` there).
    with open(EXPERIMENTS_CFG, encoding="utf-8") as cfg:
        return yaml.full_load(
            safe_format(
                cfg.read(), {"src_dir": SRC_DIR, "pantheon_root": PANTHEON_ROOT}
            )
        )


def delete_dir(dir_path, max_tries=1, sleep_time=1):
    """Delete a directory (with potential retry mechanism)"""
    if not os.path.exists(dir_path):
        return

    for i in range(max_tries):
        try:
            shutil.rmtree(dir_path)
        except Exception:
            if i == max_tries - 1:
                logging.warning("Failed to delete dir (giving up): %s", dir_path)
                break
            else:
                logging.info("Failed to delete dir (will try again): %s", dir_path)
                time.sleep(sleep_time)
        else:
            logging.info("Deleted dir: %s", dir_path)
            break


expt_cfg = parse_experiments()
meta = expt_cfg["meta"]


def expand_matrix(matrix_cfg):
    """
    Expand a `{variable: [values...]}` mapping into the list of all combinations.

    Example: {"a": [1, 2], "b": [3]} -> [{"a": 1, "b": 3}, {"a": 2, "b": 3}].
    An empty mapping yields a single empty combination: [{}].
    """
    variables = list(matrix_cfg)
    return [
        dict(zip(variables, combo))
        for combo in itertools.product(*matrix_cfg.values())
    ]


# Mostly copied from https://psutil.readthedocs.io/en/latest/#kill-process-tree
def kill_proc_tree(
    pid, sig=signal.SIGKILL, include_parent=True, timeout=None, on_terminate=None
):
    """Kill a process tree (including grandchildren) with signal
    "sig" and return a (gone, still_alive) tuple.
    "on_terminate", if specified, is a callback function which is
    called as soon as a child terminates.
    """
    assert pid != os.getpid(), "won't kill myself"
    parent = psutil.Process(pid)
    children = parent.children(recursive=True)
    if include_parent:
        children.append(parent)
    for p in children:
        try:
            p.send_signal(sig)
        except psutil.NoSuchProcess:
            # It seems possible (in rare cases) for the process to have
            # terminated already, triggering this exception.
            continue

    gone, alive = psutil.wait_procs(children, timeout=timeout, callback=on_terminate)
    return (gone, alive)


def utc_date():
    """Return the current UTC date/time formatted as "YYYY-MM-DDTHH-MM"."""
    # `datetime.utcnow()` is deprecated since Python 3.12; an aware UTC datetime
    # formats identically with this strftime pattern.
    return datetime.now(timezone.utc).strftime("%Y-%m-%dT%H-%M")