├── doc ├── .gitignore └── model.png ├── .vscode ├── .gitignore ├── settings.json └── extensions.json ├── .gitignore ├── .flake8 ├── .pre-commit-config.yaml ├── .github └── workflows │ ├── lint.yml │ └── train-example.yml ├── pyproject.toml ├── LICENSE ├── CITATION.cff ├── src ├── layernormlstm.py ├── buffer.py ├── util.py ├── env │ ├── environment.py │ ├── wrapper.py │ ├── simple_environment.py │ ├── network.py │ ├── constants.py │ └── routing.py ├── policy.py ├── replaybuffer.py ├── eval.py ├── sl.py └── model.py ├── scripts ├── start_sl_runs.sh ├── start_routing_netmon_runs.sh └── start_routing_runs.sh └── README.md /doc/.gitignore: -------------------------------------------------------------------------------- 1 | *.bkp -------------------------------------------------------------------------------- /.vscode/.gitignore: -------------------------------------------------------------------------------- 1 | launch.json -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__ 2 | runs 3 | mirror 4 | *.egg-info -------------------------------------------------------------------------------- /doc/model.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jw3il/graph-marl/HEAD/doc/model.png -------------------------------------------------------------------------------- /.flake8: -------------------------------------------------------------------------------- 1 | [flake8] 2 | # flake8 example config based on 3 | # https://black.readthedocs.io/en/stable/the_black_code_style/current_style.html 4 | max-line-length = 90 5 | select = C,E,F,W,B,B950 6 | extend-ignore = E203, E501, W503 7 | exclude = ./.git,./archive -------------------------------------------------------------------------------- /.pre-commit-config.yaml: 
-------------------------------------------------------------------------------- 1 | repos: 2 | - repo: https://github.com/psf/black 3 | rev: "22.12.0" 4 | hooks: 5 | - id: black 6 | args: [./src] 7 | - repo: https://github.com/PyCQA/flake8 8 | rev: "6.0.0" 9 | hooks: 10 | - id: flake8 11 | -------------------------------------------------------------------------------- /.vscode/settings.json: -------------------------------------------------------------------------------- 1 | { 2 | "autoDocstring.docstringFormat": "sphinx-notypes", 3 | "editor.detectIndentation": true, 4 | "editor.formatOnSaveMode": "file", 5 | "editor.formatOnSave": true, 6 | "[python]": { 7 | "editor.formatOnSave": true, 8 | }, 9 | "python.analysis.typeCheckingMode": "off", 10 | "editor.rulers": [ 11 | 90 12 | ], 13 | "editor.defaultFormatter": "ms-python.black-formatter" 14 | } -------------------------------------------------------------------------------- /.github/workflows/lint.yml: -------------------------------------------------------------------------------- 1 | name: Lint 2 | on: 3 | workflow_dispatch: 4 | push: 5 | paths: 6 | - '**.py' 7 | pull_request: 8 | paths: 9 | - '**.py' 10 | 11 | jobs: 12 | lint: 13 | runs-on: ubuntu-latest 14 | steps: 15 | - uses: actions/checkout@v4 16 | - uses: actions/setup-python@v5 17 | with: 18 | python-version: '3.9' 19 | cache: 'pip' 20 | - name: Install flake8 and black 21 | run: pip install flake8==6.0.0 black==22.12.0 flake8-black 22 | - run: flake8 . 23 | -------------------------------------------------------------------------------- /.vscode/extensions.json: -------------------------------------------------------------------------------- 1 | { 2 | // See https://go.microsoft.com/fwlink/?LinkId=827846 to learn about workspace recommendations. 3 | // Extension identifier format: ${publisher}.${name}. Example: vscode.csharp 4 | 5 | // List of extensions which should be recommended for users of this workspace. 
6 | "recommendations": [ 7 | "ms-python.python", 8 | "ms-python.black-formatter", 9 | "ms-python.flake8", 10 | "njpwerner.autodocstring", 11 | "streetsidesoftware.code-spell-checker" 12 | ], 13 | // List of extensions recommended by VS Code that should not be recommended for users of this workspace. 14 | "unwantedRecommendations": [ 15 | 16 | ] 17 | } -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [project] 2 | name = "graph-marl" 3 | description = "Multi-Agent Reinforcement Learning in Graphs" 4 | version = "0.1.0" 5 | authors = [ 6 | {name = "Jannis Weil"}, 7 | {name = "Zhenghua Bao"}, 8 | {name = "Amirkasra Amini"} 9 | ] 10 | maintainers = [ 11 | {name = "Jannis Weil"} 12 | ] 13 | dependencies = [ 14 | "torch>=2.0", 15 | "torchvision", 16 | "torchaudio", 17 | "gym", 18 | "numpy", 19 | "pettingzoo", 20 | "matplotlib!=3.7.2", 21 | "tqdm", 22 | "networkx>=3.1", 23 | "tensorboard", 24 | "tables", 25 | "gymnasium", 26 | "torch_geometric", 27 | "tabulate" 28 | ] 29 | 30 | [project.optional-dependencies] 31 | dev = [ 32 | "simple-gpu-scheduler", 33 | "flake8==6.0.0", 34 | "black==22.12.0", 35 | "pre-commit" 36 | ] 37 | temporal = ["torch-geometric-temporal"] 38 | 39 | [build-system] 40 | requires = ["setuptools>=61.0"] 41 | build-backend = "setuptools.build_meta" -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 Jannis Weil, Zhenghua Bao, Amirkasra Amini 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 
9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /.github/workflows/train-example.yml: -------------------------------------------------------------------------------- 1 | name: Train Example 2 | on: 3 | workflow_dispatch: 4 | push: 5 | paths: 6 | - '**.py' 7 | pull_request: 8 | paths: 9 | - '**.py' 10 | 11 | jobs: 12 | train-example: 13 | runs-on: ubuntu-latest 14 | steps: 15 | - uses: actions/checkout@v4 16 | - uses: actions/setup-python@v5 17 | with: 18 | python-version: '3.9' 19 | cache: 'pip' 20 | - name: Install dependencies 21 | run: | 22 | pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu 23 | pip install -e . 
24 | - name: Train agent and check result 25 | run: | 26 | python src/main.py --model=dqn --hidden-dim=8 --random-topology=1 --mini-batch-size=32 --device=cpu --episode-steps=1 --eval-episode-steps=1 --lr=0.001 --tau=0.01 --netmon --netmon-encoder-dim=4 --hidden-dim=4 --netmon-dim=2 --netmon-iterations=1 --sequence-length=1 --step-before-train=1_000 --capacity=10_000 --eval-episodes=100 --total-steps=5_000 --env-type=simple --epsilon=0.1 --epsilon-decay=1.0 --seed=0 --disable-progress | tee -a train.txt 27 | cat train.txt | grep "\"reward_mean\": 1.0" 28 | -------------------------------------------------------------------------------- /CITATION.cff: -------------------------------------------------------------------------------- 1 | # This CITATION.cff file was generated with cffinit. 2 | # Visit https://bit.ly/cffinit to generate yours today! 3 | 4 | cff-version: 1.2.0 5 | title: Autonomous Communication in Cooperative Computer Networks 6 | message: >- 7 | If you use this software, please cite it using the 8 | metadata from this file. If you use it in your research, 9 | please cite the listed conference paper. 
10 | type: software 11 | authors: 12 | - given-names: Jannis 13 | family-names: Weil 14 | - given-names: Zhenghua 15 | family-names: Bao 16 | - given-names: Amirkasra 17 | family-names: Amini 18 | license: MIT 19 | preferred-citation: 20 | type: conference-paper 21 | title: >- 22 | Towards Generalizability of Multi-Agent Reinforcement 23 | Learning in Graphs with Recurrent Message Passing 24 | authors: 25 | - family-names: Weil 26 | given-names: Jannis 27 | - family-names: Bao 28 | given-names: Zhenghua 29 | - family-names: Abboud 30 | given-names: Osama 31 | - family-names: Meuser 32 | given-names: Tobias 33 | collection-title: >- 34 | Proceedings of the 23rd International Conference on 35 | Autonomous Agents and Multiagent Systems 36 | notes: accepted, to appear 37 | year: 2024 38 | -------------------------------------------------------------------------------- /src/layernormlstm.py: -------------------------------------------------------------------------------- 1 | from torch.nn import RNNCellBase 2 | from torch.nn import LayerNorm 3 | from torch import Tensor 4 | import torch 5 | from typing import Tuple 6 | 7 | 8 | class LayerNormLSTMCell(RNNCellBase): 9 | """ 10 | LayerNorm LSTM Cell from https://arxiv.org/abs/1607.06450 11 | based on implementation from pytorch fastrnn benchmark 12 | https://github.com/pytorch/pytorch/blob/cbcb2b5ad767622cf5ec04263018609bde3c974a/benchmarks/fastrnns/custom_lstms.py#L149 13 | """ 14 | 15 | def __init__(self, input_size, hidden_size, bias=True): 16 | super().__init__(input_size, hidden_size, bias, num_chunks=4) 17 | # we only use a single bias, double bias from RNNCellBase would only be necessary for 18 | # cuDNN compatibility, see https://pytorch.org/docs/2.0/_modules/torch/nn/modules/rnn.html#RNNBase 19 | del self.bias_hh 20 | self.ln_input = LayerNorm(4 * hidden_size) 21 | self.ln_hidden = LayerNorm(4 * hidden_size) 22 | self.ln_cell = LayerNorm(hidden_size) 23 | 24 | def forward( 25 | self, input: Tensor, state: 
Tuple[Tensor, Tensor] 26 | ) -> Tuple[Tensor, Tensor]: 27 | hx, cx = state 28 | i_gates = self.ln_input(torch.mm(input, self.weight_ih.t())) 29 | h_gates = self.ln_hidden(torch.mm(hx, self.weight_hh.t())) 30 | # add bias after layer norm 31 | gates = i_gates + h_gates + self.bias_ih 32 | in_gate, forget_gate, cell_gate, out_gate = gates.chunk(4, 1) 33 | 34 | in_gate = torch.sigmoid(in_gate) 35 | forget_gate = torch.sigmoid(forget_gate) 36 | cell_gate = torch.tanh(cell_gate) 37 | out_gate = torch.sigmoid(out_gate) 38 | 39 | cy = self.ln_cell((forget_gate * cx) + (in_gate * cell_gate)) 40 | hy = out_gate * torch.tanh(cy) 41 | 42 | return hy, cy 43 | -------------------------------------------------------------------------------- /src/buffer.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | class Buffer: 5 | """ 6 | Simple buffer that stores data in a numpy array and expands automatically. 7 | """ 8 | 9 | def __init__(self, capacity: int, shape, dtype) -> None: 10 | """ 11 | Create a buffer with initial capacity. 12 | 13 | :param capacity: initial capacity of the buffer 14 | :param shape: shape of individual elements 15 | :param dtype: data type 16 | """ 17 | self._data = np.empty((capacity, *shape), dtype=dtype) 18 | self._capacity = capacity 19 | self._count = 0 20 | 21 | def insert(self, a): 22 | """ 23 | Insert an element into the buffer and automatically expand the buffer 24 | if the max capacity has been reached. 
25 | 26 | :param a: the element that is inserted 27 | """ 28 | if isinstance(a, list): 29 | for e in a: 30 | self._insert_element(e) 31 | else: 32 | self._insert_element(a) 33 | 34 | def _insert_element(self, elem): 35 | if self._count == self._capacity: 36 | self._data = np.concatenate((self._data, self._data), axis=0) 37 | self._capacity = self._data.shape[0] 38 | 39 | self._data[self._count] = elem 40 | self._count += 1 41 | 42 | def get(self) -> np.ndarray: 43 | """ 44 | Get the content of the buffer. 45 | 46 | :return: current content 47 | """ 48 | return self._data[: self._count] 49 | 50 | def clear(self): 51 | """ 52 | Clears the buffer. 53 | """ 54 | self._count = 0 55 | 56 | def mean(self, default=0): 57 | """ 58 | Get the mean of all values. 59 | 60 | :param default: returned if there are no elements, defaults to 0 61 | :return: mean of all values or default value 62 | """ 63 | return self.get().mean() if self._count > 0 else default 64 | -------------------------------------------------------------------------------- /scripts/start_sl_runs.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Allows using multiple GPUs (e.g. "0 1 2 3") and also assigning multiple 4 | # jobs to the same GPU (e.g. "0 0"). 
5 | GPUS="0" 6 | 7 | OFF="echo" 8 | 9 | DATE=$(date +%Y%m%d_%H%M%S) 10 | DIR_NAME="${DATE}_logs_netmon_sl" 11 | 12 | mkdir -p $DIR_NAME 13 | 14 | # create backup of the code 15 | git log -n 1 >> $DIR_NAME/git.txt 16 | git status >> $DIR_NAME/git.txt 17 | git diff >> $DIR_NAME/git.txt 18 | cp -r src $DIR_NAME/src 19 | # copy this script 20 | cp $0 $DIR_NAME/$(basename "$0") 21 | 22 | run() { 23 | RUN_NAME="$1" 24 | shift 25 | RUN_ARGS="$@" 26 | echo "(set -x; time python -u src/sl.py $RUN_ARGS --filename=${DIR_NAME}/${RUN_NAME}.h5) > ${DIR_NAME}/${RUN_NAME}.log 2>&1" 27 | } 28 | 29 | REST_ARGS="--iterations=50_000 --num-samples-train=99_000 --validate-after=500 --disable-progressbar" 30 | time ( 31 | for i in 0 1 2 32 | do 33 | run "netmon-1it-8seq-$i" --seed=$i --netmon-iterations=1 --sequence-length=8 $REST_ARGS 34 | run "netmon-1it-16seq-$i" --seed=$i --netmon-iterations=1 --sequence-length=16 $REST_ARGS 35 | run "netmon-2it-8seq-$i" --seed=$i --netmon-iterations=2 --sequence-length=8 $REST_ARGS 36 | run "netmon-4it-8seq-$i" --seed=$i --netmon-iterations=4 --sequence-length=8 $REST_ARGS 37 | 38 | run "gconvlstm-1it-8seq-$i" --seed=$i --netmon-iterations=1 --sequence-length=8 --netmon-agg-type=gconvlstm $REST_ARGS 39 | run "gconvlstm-1it-16seq-$i" --seed=$i --netmon-iterations=1 --sequence-length=16 --netmon-agg-type=gconvlstm $REST_ARGS 40 | run "gconvlstm-2it-8seq-$i" --seed=$i --netmon-iterations=2 --sequence-length=8 --netmon-agg-type=gconvlstm $REST_ARGS 41 | run "gconvlstm-4it-8seq-$i" --seed=$i --netmon-iterations=4 --sequence-length=8 --netmon-agg-type=gconvlstm $REST_ARGS 42 | 43 | run "graphsage-8it-1seq-$i" --seed=$i --netmon-iterations=8 --sequence-length=1 --netmon-agg-type=graphsage --netmon-rnn-type=none $REST_ARGS 44 | run "graphsage-16it-1seq-$i" --seed=$i --netmon-iterations=16 --sequence-length=1 --netmon-agg-type=graphsage --netmon-rnn-type=none $REST_ARGS 45 | 46 | run "antisymgcn-8it-1seq-$i" --seed=$i --netmon-iterations=8 
--sequence-length=1 --netmon-agg-type=antisymgcn --netmon-rnn-type=none $REST_ARGS 47 | run "antisymgcn-16it-1seq-$i" --seed=$i --netmon-iterations=16 --sequence-length=1 --netmon-agg-type=antisymgcn --netmon-rnn-type=none $REST_ARGS 48 | done 49 | ) | simple_gpu_scheduler --gpus $GPUS 50 | -------------------------------------------------------------------------------- /scripts/start_routing_netmon_runs.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Allows using multiple GPUs (e.g. "0 1 2 3") and also assigning multiple 4 | # jobs to the same GPU (e.g. "0 0"). 5 | GPUS="0" 6 | 7 | BASE_PARAMS="--step-between-train=10 --total-steps=2_500_000 --netmon --model=dqn --random-topology=1 --gamma=0.9 --epsilon=1.0 --epsilon-decay=0.999 --hidden-dim=512,256 --netmon-encoder-dim=512,256 --netmon-dim=128 --mini-batch-size=32 --device=cuda --lr=0.001 --tau=0.01 --step-before-train=100_000 --capacity=200_000 --eval-episodes=1000 --eval-episode-steps=300 --episode-steps=50 --disable-progressbar" 8 | 9 | OFF="echo" 10 | 11 | DATE=$(date +%Y%m%d_%H%M%S) 12 | DIR_NAME="${DATE}_logs_dynamic" 13 | 14 | mkdir -p $DIR_NAME 15 | 16 | # create backup of the code 17 | git log -n 1 >> $DIR_NAME/git.txt 18 | git status >> $DIR_NAME/git.txt 19 | git diff >> $DIR_NAME/git.txt 20 | cp -r src $DIR_NAME/src 21 | # copy this script 22 | cp $0 $DIR_NAME/$(basename "$0") 23 | 24 | run() { 25 | RUN_NAME="$1" 26 | shift 27 | RUN_ARGS="$@" 28 | echo "(set -x; time python -u src/main.py $RUN_ARGS --comment=${RUN_NAME}) > ${DIR_NAME}/${RUN_NAME}.log 2>&1" 29 | } 30 | 31 | time ( 32 | for i in 0 1 2 33 | do 34 | for nocong in 0 1 35 | do 36 | if [ "$nocong" -eq "0" ]; then 37 | NOCONG_RUNNAME="" 38 | NOCONG_ARG="" 39 | else 40 | NOCONG_RUNNAME="-nocong" 41 | NOCONG_ARG="--no-congestion" 42 | fi 43 | run "dynamic${NOCONG_RUNNAME}-shortest-paths-eval-$i" --policy=heuristic --eval --random-topology=1 --disable-progressbar 
--eval-output-dir=${DIR_NAME}/dynamic${NOCONG_RUNNAME}-shortest-paths-eval-$i/eval --seed=$i $NOCONG_ARG 44 | run "dynamic${NOCONG_RUNNAME}-netmon-1it-8seq-$i" --netmon-agg-type=sum --netmon-rnn-type=lstm --netmon-iterations=1 --sequence-length=8 --seed=$i $NOCONG_ARG $BASE_PARAMS 45 | run "dynamic${NOCONG_RUNNAME}-netmon-gconvlstm-1it-8seq-$i" --netmon-agg-type=gconvlstm --netmon-rnn-type=gconvlstm --netmon-iterations=1 --sequence-length=8 --seed=$i $NOCONG_ARG $BASE_PARAMS 46 | run "dynamic${NOCONG_RUNNAME}-netmon-graphsage-8it-1seq-$i" --netmon-agg-type=graphsage --netmon-rnn-type=none --netmon-iterations=8 --sequence-length=1 --seed=$i $NOCONG_ARG $BASE_PARAMS 47 | run "dynamic${NOCONG_RUNNAME}-netmon-antisymgcn-8it-1seq-$i" --netmon-agg-type=antisymgcn --netmon-rnn-type=none --netmon-iterations=8 --sequence-length=1 --seed=$i $NOCONG_ARG $BASE_PARAMS 48 | done 49 | done 50 | ) | simple_gpu_scheduler --gpus $GPUS 51 | -------------------------------------------------------------------------------- /scripts/start_routing_runs.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Allows using multiple GPUs (e.g. "0 1 2 3") and also assigning multiple 4 | # jobs to the same GPU (e.g. "0 0"). 
5 | GPUS="0" 6 | 7 | BASE_PARAMS="--gamma=0.9 --epsilon=1.0 --hidden-dim=512,256 --mini-batch-size=32 --device=cuda --lr=0.001 --tau=0.01 --step-before-train=10_000 --capacity=200_000 --eval-episodes=1000 --eval-episode-steps=300 --disable-progressbar" 8 | LIMITS="--step-between-train=10 --total-steps=250_000" 9 | RECURRENT="--sequence-length=8" 10 | 11 | OFF="echo" 12 | 13 | DATE=$(date +%Y%m%d_%H%M%S) 14 | DIR_NAME="${DATE}_logs_fixed_baseline" 15 | 16 | mkdir -p $DIR_NAME 17 | 18 | # create backup of the code 19 | git log -n 1 >> $DIR_NAME/git.txt 20 | git status >> $DIR_NAME/git.txt 21 | git diff >> $DIR_NAME/git.txt 22 | cp -r src $DIR_NAME/src 23 | # copy this script 24 | cp $0 $DIR_NAME/$(basename "$0") 25 | 26 | run() { 27 | RUN_NAME="$1" 28 | shift 29 | RUN_ARGS="$@" 30 | echo "(set -x; time python -u src/main.py $RUN_ARGS --comment=${RUN_NAME}) > ${DIR_NAME}/${RUN_NAME}.log 2>&1" 31 | } 32 | 33 | time ( 34 | # Graphs G_A, G_B, G_C, selected 35 | for seed in 971182936 923430603 1704443687 324821133 36 | do 37 | for i in 0 1 2 38 | do 39 | run "fixed-nocong-shortest-paths-eval-t$seed-$i" --policy=heuristic --eval --seed=$i --no-congestion --topology-init-seed=$seed --train-topology-allow-eval-seed --episode-steps=300 --random-topology=0 --disable-progressbar --eval-output-dir=${DIR_NAME}/fixed-nocong-shortest-paths-eval-t$seed-$i/eval 40 | run "fixed-nocong-dqn-t$seed-$i" --seed=$i --no-congestion --topology-init-seed=$seed --train-topology-allow-eval-seed --episode-steps=300 --model=dqn --random-topology=0 $BASE_PARAMS $LIMITS 41 | run "fixed-nocong-dqnr-t$seed-$i" --seed=$i --no-congestion --topology-init-seed=$seed --train-topology-allow-eval-seed --episode-steps=300 --model=dqnr --random-topology=0 $BASE_PARAMS $RECURRENT $LIMITS 42 | run "fixed-nocong-commnet-t$seed-$i" --seed=$i --no-congestion --topology-init-seed=$seed --train-topology-allow-eval-seed --episode-steps=300 --model=commnet --random-topology=0 $BASE_PARAMS $RECURRENT $LIMITS 43 | run 
"fixed-nocong-dgn-t$seed-$i" --seed=$i --no-congestion --topology-init-seed=$seed --train-topology-allow-eval-seed --episode-steps=300 --model=dgn --random-topology=0 $BASE_PARAMS $LIMITS 44 | 45 | run "fixed-shortest-paths-eval-t$seed-$i" --policy=heuristic --eval --seed=$i --topology-init-seed=$seed --train-topology-allow-eval-seed --episode-steps=300 --random-topology=0 --disable-progressbar --eval-output-dir=${DIR_NAME}/fixed-shortest-paths-eval-t$seed-$i/eval 46 | run "fixed-dqn-t$seed-$i" --seed=$i --topology-init-seed=$seed --train-topology-allow-eval-seed --episode-steps=300 --model=dqn --random-topology=0 $BASE_PARAMS $LIMITS 47 | run "fixed-dqnr-t$seed-$i" --seed=$i --topology-init-seed=$seed --train-topology-allow-eval-seed --episode-steps=300 --model=dqnr --random-topology=0 $BASE_PARAMS $RECURRENT $LIMITS 48 | run "fixed-commnet-t$seed-$i" --seed=$i --topology-init-seed=$seed --train-topology-allow-eval-seed --episode-steps=300 --model=commnet --random-topology=0 $BASE_PARAMS $RECURRENT $LIMITS 49 | run "fixed-dgn-t$seed-$i" --seed=$i --topology-init-seed=$seed --train-topology-allow-eval-seed --episode-steps=300 --model=dgn --random-topology=0 $BASE_PARAMS $LIMITS 50 | done 51 | done 52 | ) | simple_gpu_scheduler --gpus $GPUS 53 | -------------------------------------------------------------------------------- /src/util.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | import torch.nn as nn 4 | import numpy as np 5 | import torch 6 | 7 | 8 | def interpolate_model(a: nn.Module, b: nn.Module, a_weight: float, target: nn.Module): 9 | """ 10 | Interpolates model parameters from a and b, saves a_weight * a + (1-a_weight) * b in target. 
11 | 12 | :param a: First input module 13 | :param b: Second input module 14 | :param a_weight: Weight of the first input module 15 | :param target: The output module 16 | """ 17 | a_dict = a.state_dict() 18 | b_dict = b.state_dict() 19 | for key in a_dict: 20 | # store interpolation results in a_dict 21 | a_dict[key] = a_weight * a_dict[key] + (1 - a_weight) * b_dict[key] 22 | 23 | target.load_state_dict(a_dict) 24 | 25 | 26 | def get_state_dict(model, netmon, args): 27 | state_dict = { 28 | "type": type(model).__name__, 29 | "state_dict": model.state_dict(), 30 | "args": args, 31 | } 32 | if netmon is not None: 33 | state_dict["netmon_state_dict"] = netmon.state_dict() 34 | 35 | return state_dict 36 | 37 | 38 | def load_state_dict(state_dict, model, netmon): 39 | if state_dict["type"] != type(model).__name__: 40 | print( 41 | f"Warning: Loader expected {type(model).__name__} " 42 | f"but found {state_dict['type']}" 43 | ) 44 | if "netmon_state_dict" in state_dict: 45 | if netmon is None: 46 | raise ValueError("Model uses NetMon which has not been initialized.") 47 | else: 48 | netmon.load_state_dict(state_dict["netmon_state_dict"]) 49 | elif netmon is not None: 50 | raise ValueError("NetMon state could not be found.") 51 | 52 | model.load_state_dict(state_dict["state_dict"]) 53 | 54 | 55 | def set_attributes(obj, key_value_dict, verbose=False): 56 | changes_str = "" 57 | for key in key_value_dict: 58 | if verbose: 59 | if hasattr(obj, key) and getattr(obj, key) != key_value_dict[key]: 60 | changes_str += f"> Updated: {key} = {key_value_dict[key]}" + os.linesep 61 | if not hasattr(obj, key): 62 | changes_str += f"> Added: {key} = {key_value_dict[key]}" + os.linesep 63 | 64 | setattr(obj, key, key_value_dict[key]) 65 | 66 | if verbose: 67 | print(changes_str, end="") 68 | 69 | 70 | def filter_dict(dict, keys): 71 | return {key: dict[key] for key in keys} 72 | 73 | 74 | def one_hot_list(i, max_indices): 75 | a = [0] * max_indices 76 | if i >= 0: 77 | a[i] = 1 78 | 
return a 79 | 80 | 81 | def set_seed(seed): 82 | """ 83 | Sets seeds for better reproducibility. 84 | 85 | Disclaimer: 86 | Note that this method alone does NOT guarantee deterministic 87 | execution. When using CUDA, there are multiple potential sources of randomness, 88 | including the execution of RNNs. 89 | Also see https://pytorch.org/docs/2.0/notes/randomness.html 90 | and https://pytorch.org/docs/2.0/generated/torch.nn.LSTM.html#torch.nn.LSTM. 91 | 92 | We tested our implementation with 93 | torch.backends.cudnn.deterministic = True 94 | and the environment variables 95 | export CUBLAS_WORKSPACE_CONFIG=:4096:2 96 | export CUDA_LAUNCH_BLOCKING=1 97 | but found that some models (e.g. NetMon with GConvLSTM) still show nondeterministic 98 | behavior, while these settings almost double the training time. Because of this, 99 | we chose not to include these settings here. 100 | 101 | :param seed: the seed 102 | """ 103 | torch.manual_seed(seed) 104 | random.seed(seed) 105 | np.random.seed(seed) 106 | 107 | 108 | def dim_str_to_list(dims: str): 109 | if len(dims) == 0: 110 | return [] 111 | return [int(item) for item in dims.split(",")] 112 | -------------------------------------------------------------------------------- /src/env/environment.py: -------------------------------------------------------------------------------- 1 | import abc 2 | from enum import Enum 3 | from typing import Any, Dict 4 | import numpy as np 5 | 6 | 7 | class EnvironmentVariant(Enum): 8 | # Without neighbor info in obs 9 | INDEPENDENT = 1 10 | # With neighbor info of k neighbors in obs 11 | WITH_K_NEIGHBORS = 2 12 | # With (global) network topology and all node observations in obs 13 | GLOBAL = 3 14 | 15 | 16 | def reset_and_get_sizes(env) -> tuple: 17 | """ 18 | Resets the environment to dynamically get environment sizes: 19 | * number of agents 20 | * agent observation dimensions 21 | * number of nodes 22 | * node observation dimensions 23 | 24 | :return: tuple (n_agents, obs_dim, n_nodes,
node_obs_dim) 25 | """ 26 | agent_observation, _ = env.reset() 27 | node_obs = env.get_node_observation() 28 | return ( 29 | agent_observation.shape[0], 30 | agent_observation.shape[1], 31 | node_obs.shape[0], 32 | node_obs.shape[1], 33 | ) 34 | 35 | 36 | class NetworkEnv(abc.ABC): 37 | """ 38 | Abstract graph/network environment. 39 | """ 40 | 41 | @abc.abstractmethod 42 | def reset(self): 43 | """ 44 | Resets the environment. 45 | 46 | :return: tuple (agent observation, agent adjacency matrix) 47 | """ 48 | ... 49 | 50 | @abc.abstractmethod 51 | def step(self, act): 52 | """ 53 | Step function that advances the environment. 54 | 55 | :param act: action of all agents 56 | :returns: tuple (agent obs, agent adjacency, agent reward, agent done, info) 57 | """ 58 | ... 59 | 60 | def get(self): 61 | """ 62 | Get underlying environment without wrappers. 63 | 64 | :return: the environment 65 | """ 66 | return self 67 | 68 | def get_final_info(self, info: Dict[str, Any]): 69 | """ 70 | Get additional info at the end of an episode that's not included in the step 71 | info. 72 | 73 | :param info: current info dict that will be extended in-place 74 | :returns: updated info dict 75 | """ 76 | return info 77 | 78 | def get_node_aux(self): 79 | """ 80 | Optional auxiliary targets for each node in the network. 81 | 82 | :return: None (default) or auxiliary targets of shape (num_nodes, node_aux_target_size) 83 | """ 84 | return None 85 | 86 | @abc.abstractmethod 87 | def get_node_agent_matrix(self) -> np.ndarray: 88 | """ 89 | Get a matrix that indicates where agents are located, 90 | matrix[n, a] = 1 iff agent a is on node n and 0 otherwise. 91 | 92 | :return: the node agent matrix of shape (n_nodes, n_agents) 93 | """ 94 | ... 95 | 96 | @abc.abstractmethod 97 | def get_nodes_adjacency(self) -> np.ndarray: 98 | """ 99 | Get a matrix of shape (n_nodes, n_nodes) that indicates node adjacency 100 | 101 | :return: node adjacency matrix 102 | """ 103 | ... 
104 | 105 | @abc.abstractmethod 106 | def get_node_observation(self) -> np.ndarray: 107 | """ 108 | Get node observations of shape (n_nodes, node_obs_dim) with dynamic but 109 | consistent node_obs_dim. 110 | 111 | :return: node observation for all nodes 112 | """ 113 | ... 114 | 115 | @abc.abstractmethod 116 | def get_num_agents(self): 117 | """ 118 | Get number of agents in the environment. 119 | 120 | :return: number of agents 121 | """ 122 | ... 123 | 124 | @abc.abstractmethod 125 | def get_num_nodes(self): 126 | """ 127 | Get number of nodes in the environment. 128 | 129 | :return: number of nodes 130 | """ 131 | ... 132 | -------------------------------------------------------------------------------- /src/env/wrapper.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import numpy as np 4 | from model import NetMon 5 | 6 | 7 | class NetMonWrapper: 8 | """ 9 | Wraps a given environment with a netmon instance. Creates new observations by 10 | concatenating the agent's observations and the respective graph observations. 
11 | """ 12 | 13 | def __init__(self, env, netmon, startup_iterations) -> None: 14 | self.env = env 15 | self.netmon = netmon 16 | self.device = next(netmon.parameters()).device 17 | 18 | self.node_obs = None 19 | self.node_adj = None 20 | self.node_agent_matrix = None 21 | self.last_netmon_state = None 22 | self.current_netmon_state = None 23 | assert startup_iterations >= 1, "Number of startup iterations must be >= 1" 24 | self.startup_iterations = startup_iterations 25 | self.frozen = False 26 | 27 | def __getattr__(self, name): 28 | # allow to access attributes of the environment 29 | return getattr(self.env, name) 30 | 31 | def __str__(self) -> str: 32 | return ( 33 | self.env.__str__() 34 | + os.linesep 35 | + "▲ environment is wrapped with NetMon (graph obs)" 36 | ) 37 | 38 | def reset(self): 39 | self.frozen = False 40 | self.current_netmon_state = None 41 | self.last_netmon_state = None 42 | obs, adj = self.env.reset() 43 | for _ in range(self.startup_iterations): 44 | network_obs = self._netmon_step() 45 | return np.concatenate((obs, network_obs), axis=-1), adj 46 | 47 | def step(self, actions): 48 | next_obs, next_adj, reward, done, info = self.env.step(actions) 49 | next_network_obs = self._netmon_step() 50 | next_joint_obs = np.concatenate((next_obs, next_network_obs), axis=-1) 51 | return next_joint_obs, next_adj, reward, done, info 52 | 53 | def freeze(self): 54 | """ 55 | Disable message-passing for the rest of this episode. Agents still receive 56 | graph observations depending on their position in the graph. 
57 | """ 58 | self.frozen = True 59 | 60 | def get_netmon_info(self): 61 | return (self.node_obs, self.node_adj, self.node_agent_matrix) 62 | 63 | def get(self): 64 | return self.env 65 | 66 | def _netmon_step(self): 67 | if self.frozen: 68 | self.node_agent_matrix = self.env.get_node_agent_matrix() 69 | node_agent_matrix_in = torch.tensor( 70 | self.node_agent_matrix, dtype=torch.float32 71 | ).unsqueeze(0) 72 | network_obs = torch.bmm( 73 | self.netmon_out.transpose(1, 2).cpu().detach(), node_agent_matrix_in 74 | ).transpose(1, 2) 75 | return network_obs.squeeze(0).numpy() 76 | 77 | self.node_obs = self.env.get_node_observation() 78 | self.node_adj = self.env.get_nodes_adjacency() 79 | self.node_agent_matrix = self.env.get_node_agent_matrix() 80 | with torch.no_grad(): 81 | # prepare inputs 82 | node_obs_in = ( 83 | torch.tensor(self.node_obs, dtype=torch.float32) 84 | .unsqueeze(0) 85 | .to(self.device, non_blocking=True) 86 | ) 87 | node_adj_in = ( 88 | torch.tensor(self.node_adj, dtype=torch.float32) 89 | .unsqueeze(0) 90 | .to(self.device, non_blocking=True) 91 | ) 92 | node_agent_matrix_in = ( 93 | torch.tensor(self.node_agent_matrix, dtype=torch.float32) 94 | .unsqueeze(0) 95 | .to(self.device, non_blocking=True) 96 | ) 97 | 98 | # perform netmon step with correct state 99 | self.last_netmon_state = self.current_netmon_state 100 | self.netmon.state = self.current_netmon_state 101 | self.netmon_out = self.netmon( 102 | node_obs_in, node_adj_in, node_agent_matrix_in, no_agent_mapping=True 103 | ) 104 | self.current_netmon_state = self.netmon.state 105 | 106 | network_obs = NetMon.output_to_network_obs( 107 | self.netmon_out, node_agent_matrix_in 108 | ) 109 | return network_obs.squeeze(0).cpu().detach().numpy() 110 | -------------------------------------------------------------------------------- /src/policy.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from gymnasium.spaces 
import Discrete 4 | from env.routing import Routing 5 | 6 | 7 | class EpsilonGreedy: 8 | def __init__(self, env, model, action_space, args) -> None: 9 | self._env = env 10 | self._enable_action_mask = ( 11 | hasattr(self._env, "enable_action_mask") and self._env.enable_action_mask 12 | ) 13 | self._model = model 14 | self._action_space = action_space 15 | self._args = args 16 | self._epsilon = args.epsilon 17 | self._step = 0 18 | self._epsilon_tmp = None 19 | 20 | def __call__(self, obs, adj): 21 | self._step += 1 22 | # first dimension is number of agents 23 | actions = np.zeros(obs.shape[0], dtype=np.int32) 24 | 25 | with torch.no_grad(): 26 | device = next(self._model.parameters()).device 27 | obs = ( 28 | torch.tensor(obs, dtype=torch.float32) 29 | .unsqueeze(0) 30 | .to(device, non_blocking=True) 31 | ) 32 | adj = ( 33 | torch.tensor(adj, dtype=torch.float32) 34 | .unsqueeze(0) 35 | .to(device, non_blocking=True) 36 | ) 37 | # run our model 38 | q_values = self._model(obs, adj) 39 | # squeeze batch dimension, our batch size is 1 40 | q_values = q_values.cpu().squeeze(0).detach().numpy() 41 | 42 | if self._enable_action_mask: 43 | q_values[self._env.action_mask.nonzero()] = float("-inf") 44 | 45 | # epsilon-greedy action selection 46 | random_actions = np.random.randint(self._action_space, size=actions.shape[0]) 47 | random_filter = np.random.rand(actions.shape[0]) < self._epsilon 48 | actions = ( 49 | np.argmax(q_values, axis=-1) * ~random_filter 50 | + random_filter * random_actions 51 | ) 52 | 53 | # instead of having step as parameter we should log the steps to track 54 | # decay from https://github.com/PKU-RL/DGN/blob/master/Routing/routers_regularization.py 55 | if ( 56 | self._epsilon > 0 57 | and self._step > self._args.step_before_train 58 | and self._step % self._args.epsilon_update_freq == 0 59 | ): 60 | self._epsilon *= self._args.epsilon_decay 61 | if self._epsilon < 0.01: 62 | self._epsilon = 0.01 63 | 64 | return actions 65 | 66 | def 
eval(self): 67 | # remember epsilon and switch to greedy policy 68 | self._eps_tmp = self._epsilon 69 | self._epsilon = 0 70 | 71 | def reset(self, agents_to_reset): 72 | """ 73 | Resets agents according to the given boolean tensor. 74 | 75 | :param agents_to_reset: agents to reset of shape (batch_size, n_agents). A value 76 | of 1 indicates that the agent's state should be reset. 77 | """ 78 | if hasattr(self._model, "state") and self._model.state is not None: 79 | self._model.state = self._model.state * ~torch.tensor( 80 | agents_to_reset, dtype=bool, device=self._model.state.device 81 | ).unsqueeze(-1) 82 | 83 | def train(self): 84 | # switch back to old epsilon 85 | if self._epsilon_tmp is not None: 86 | self._epsilon_tmp = None 87 | self._epsilon = self._epsilon_tmp 88 | 89 | 90 | class ShortestPath: 91 | def __init__(self, env, model, action_space, args) -> None: 92 | self._env = env 93 | assert isinstance(env.get(), Routing) 94 | self._n_agents = env.get_num_agents() 95 | self._model = model 96 | self._action_space = action_space 97 | self._args = args 98 | self.static_shortest_paths = False 99 | self.network = None 100 | 101 | def reset_episode(self): 102 | self.network = None 103 | 104 | def __call__(self, obs, adj): 105 | act = np.zeros(self._env.n_data, dtype=np.int32) 106 | 107 | if self.static_shortest_paths: 108 | # create shortest paths at the very beginning, then use them 109 | if self.network is None: 110 | import copy 111 | 112 | self.network = copy.deepcopy(self._env.network) 113 | 114 | network = self.network 115 | else: 116 | # always use latest shortest paths 117 | network = self._env.network 118 | 119 | for i in range(self._env.n_data): 120 | packet = self._env.data[i] 121 | current_node = packet.now 122 | target_node = packet.target 123 | 124 | if current_node == target_node: 125 | act[i] = 0 126 | continue 127 | 128 | # first index is the source and last index is the target 129 | next_node = 
network.shortest_paths[current_node][target_node][1] 130 | 131 | for index, j in enumerate(network.nodes[current_node].edges): 132 | current_edge = network.edges[j] 133 | # if current edge is the desired one we choose this edge 134 | # case 1: edge.start is the current node itself and edge.end is the target node 135 | if current_edge.get_other_node(packet.now) == next_node: 136 | act[i] = index + 1 137 | break 138 | 139 | return act 140 | 141 | 142 | class RandomPolicy: 143 | def __init__(self, env, model, action_space, args) -> None: 144 | self._env = env 145 | self._model = model 146 | self._action_space = action_space 147 | self._args = args 148 | self._step = 0 149 | 150 | def __call__(self, obs, adj): 151 | self._step += 1 152 | self.action_space = Discrete(self._env.action_space.n, start=0) 153 | act = np.zeros(self._args.n_data, dtype=np.int32) 154 | # random action selection 155 | for i in range(len(act)): 156 | act[i] = self.action_space.sample() 157 | return act 158 | 159 | 160 | class SimplePolicy: 161 | def __init__(self, env, model, action_space, args) -> None: 162 | self._env = env 163 | self._model = model 164 | self._action_space = action_space 165 | self._args = args 166 | self._step = 0 167 | 168 | def __call__(self, obs, adj): 169 | self._step += 1 170 | act = np.zeros(self._args.n_data, dtype=np.int32) 171 | edges = self._env.edges 172 | router = self._env.router 173 | 174 | for i in range(self._args.n_data): 175 | packet = self._env.data[i] 176 | current_node = packet.now 177 | target_node = router[0] 178 | 179 | if current_node == target_node: 180 | act[i] = 0 181 | continue 182 | else: 183 | for index, j in enumerate(router[current_node].edge): 184 | current_edge = edges[j] 185 | if current_edge.get_other_node(packet.now) == target_node: 186 | act[i] = index + 1 187 | break 188 | 189 | return act 190 | -------------------------------------------------------------------------------- /README.md: 
-------------------------------------------------------------------------------- 1 | # Multi-Agent Reinforcement Learning in Graphs 2 | [![Lint](https://github.com/jw3il/graph-marl/actions/workflows/lint.yml/badge.svg?branch=main)](https://github.com/jw3il/graph-marl/actions/workflows/lint.yml) [![Train Example](https://github.com/jw3il/graph-marl/actions/workflows/train-example.yml/badge.svg?branch=main)](https://github.com/jw3il/graph-marl/actions/workflows/train-example.yml) [![Code style: black](https://img.shields.io/badge/code%20style-black-000000.svg)](https://github.com/psf/black) [![arXiv](https://img.shields.io/badge/arXiv-2402.05027-b31b1b.svg)](https://arxiv.org/abs/2402.05027) 3 | 4 | 5 | This repository provides prototypical implementations of reinforcement learning algorithms and graph-based (multi-agent) environments. 6 | 7 | We introduce a new environment definition in which not only the agents receive (partial) *observations*, but also all nodes in the graph receive (partial) *node observations*. 8 | The core idea is that this allows us to *decouple* learning graph representations from solving control tasks within the graph. Our approach leverages these local node observations and message passing to learn *graph observations*. These graph observations are then used by local agents to solve tasks that may require a broader view of the graph. 9 | 10 | 11 | 12 | ## Citation 13 | 14 | This is the official implementation used in the paper *Towards Generalizability of Multi-Agent Reinforcement Learning in Graphs with Recurrent Message Passing* ([arXiv](https://arxiv.org/abs/2402.05027)), which has been accepted for publication at AAMAS 2024.
If you use parts of this repository, please consider citing 15 | 16 | ``` 17 | @InProceedings{weil2024graphMARL, 18 | author = {Weil, Jannis and Bao, Zhenghua and Abboud, Osama and Meuser, Tobias}, 19 | title = {Towards Generalizability of Multi-Agent Reinforcement Learning in Graphs with Recurrent Message Passing}, 20 | booktitle = {Proceedings of the 23rd International Conference on Autonomous Agents and Multiagent Systems}, 21 | year = {2024}, 22 | note = {accepted, to appear} 23 | } 24 | ``` 25 | 26 | ## Getting started 27 | 28 | We use [conda](https://docs.conda.io/en/latest/miniconda.html) for this guide. 29 | 30 | 1. Create a python environment and activate it 31 | ``` 32 | $ conda create -n graph-marl python=3.9 33 | $ conda activate graph-marl 34 | ``` 35 | 36 | 2. Optional: Install pytorch with GPU support (see https://pytorch.org/get-started/locally/). Use the CUDA version that is compatible with your GPU, e.g. for CUDA 11.8: 37 | ``` 38 | (graph-marl) $ pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118 39 | ``` 40 | 3. Install this package with remaining requirements (installs pytorch without GPU support if you skipped step 2). The suffix `[dev,temporal]` is optional, `dev` includes development dependencies and `temporal` includes `torch-geometric-temporal` for temporal GNN baselines. Installing the dependencies of `temporal` may take a while. 41 | ``` 42 | (graph-marl) $ pip install -e .[dev,temporal] 43 | ``` 44 | 4. In a git repository, you can install the pre-commit hook to check formatting automatically 45 | ``` 46 | (graph-marl) $ pre-commit install 47 | ``` 48 | 5. You can start developing and training! 49 | 50 | ### Troubleshooting 51 | 52 | #### Windows: TLS/SSL and pre-commit 53 | 54 | On windows, you might encounter issues with `pre-commit` not being able to set up the environments due to a missing ssl Python module. 
55 | 56 | Solution: Install `pre-commit` in the `base` environment using the `conda-forge` channel (`conda install -c conda-forge pre-commit`). Then upgrade your packages in `base` via `conda upgrade --all -c conda-forge`. 57 | 58 | ## Training 59 | 60 | The main entry point of this project is the file `src/main.py`. 61 | You can get an overview of the available arguments with 62 | 63 | ``` 64 | (graph-marl) $ python src/main.py --help 65 | ``` 66 | 67 | Note that not all combinations of the available arguments have been thoroughly tested. 68 | 69 | ### Minimal Example Environment 70 | 71 | We provide a minimal example environment (`--env-type=simple`) that converges very quickly and allows you to test the implementation. 72 | 73 | You can run regular DQN on the environment with 74 | ``` 75 | (graph-marl) $ python src/main.py --device=cuda --step-before-train=1000 --total-steps=2000 --eval-episodes=10 --env-type=simple --model=dqn 76 | ``` 77 | 78 | and enable learned graph observations by adding `--netmon`, resulting in 79 | ``` 80 | (graph-marl) $ python src/main.py --device=cuda --step-before-train=1000 --total-steps=2000 --eval-episodes=10 --env-type=simple --model=dqn --netmon 81 | ``` 82 | 83 | The latter will take longer to train, but the evaluation after training should yield the optimal mean reward of 84 | ``` 85 | { 86 | "reward_mean": 1.0 87 | } 88 | ``` 89 | 90 | If your system does not support CUDA, you can replace `--device=cuda` with `--device=cpu`. 91 | 92 | ### Routing in Single Graphs 93 | 94 | All routing experiments are initiated by calling `src/main.py`.
95 | An exemplary configuration to train a DQN agent for routing in single graphs with seed 923430603 (from the test graphs) is 96 | 97 | ``` 98 | (graph-marl) $ python src/main.py --seed=0 --topology-init-seed=923430603 --train-topology-allow-eval-seed --episode-steps=300 --model=dqn --random-topology=0 --gamma=0.9 --epsilon=1.0 --hidden-dim=512,256 --mini-batch-size=32 --device=cuda --lr=0.001 --tau=0.01 --step-before-train=10_000 --capacity=200_000 --eval-episodes=1000 --eval-episode-steps=300 --step-between-train=10 --total-steps=250_000 --comment=fixed-dqn-t923430603 99 | ``` 100 | 101 | The environment variant without bandwidth limitations can be enabled with `--no-congestion`. 102 | 103 | The experiments from our paper are provided with `scripts/start_routing_runs.sh`. 104 | 105 | ### Supervised Learning on Shortest Path Lengths 106 | 107 | The script `src/sl.py` can be used to train different graph observation architectures on a supervised routing task. 108 | In the first run, the script will generate a dataset. 109 | The dataset is saved locally; subsequent runs load it automatically. 110 | An exemplary configuration with our architecture is 111 | 112 | ``` 113 | (graph-marl) $ python src/sl.py --seed=0 --netmon-iterations=1 --sequence-length=8 --iterations=50_000 --num-samples-train=99_000 --validate-after=500 --filename=netmon-1it-8seq-0.h5 114 | ``` 115 | 116 | Note that generating data for 100_000 topologies will take some time. Dataset generation is not parallelized at the moment and will take around 5-15 minutes depending on the CPU. 117 | 118 | The argument `--netmon-iterations` stands for the number of message passing iterations. 119 | The argument `--sequence-length` determines the unroll length during training. 120 | To use other architectures, one can set the `--netmon-agg-type` argument, e.g. to `gconvlstm` for GCRN-LSTM. 
121 | Optionally, the argument `--clear-cache` can be used to force the generation of new datasets.
122 | 123 | The experiments from our paper are provided with `scripts/start_sl_runs.sh` (note that the dataset has to be generated first). If you want to perform parallel training runs, make sure that the dataset is generated 124 | *before* launching the runs. 125 | 126 | ### Generalized Routing 127 | 128 | An exemplary configuration to train our approach is 129 | 130 | ``` 131 | (graph-marl) $ python src/main.py --netmon-agg-type=sum --netmon-rnn-type=lstm --netmon-iterations=1 --sequence-length=8 --seed=0 --step-between-train=10 --total-steps=2_500_000 --netmon --model=dqn --random-topology=1 --gamma=0.9 --epsilon=1.0 --epsilon-decay=0.999 --hidden-dim=512,256 --netmon-encoder-dim=512,256 --netmon-dim=128 --mini-batch-size=32 --device=cuda --lr=0.001 --tau=0.01 --step-before-train=100_000 --capacity=200_000 --eval-episodes=1000 --eval-episode-steps=300 --episode-steps=50 --comment=dynamic-netmon-1it-8seq 132 | ``` 133 | 134 | The argument `--netmon` enables graph observations. 135 | Note that a replay memory with capacity `200_000` requires around 30 GB of RAM. 136 | To reduce the memory requirements, you can reduce the precision of the replay memory (not of the model) with `--replay-half-precision`. 137 | The evaluation after training is performed over the 1000 test graphs. 138 | 139 | The experiments from our paper are provided with `scripts/start_routing_netmon_runs.sh`. 140 | 141 | ## Development with Visual Studio Code 142 | 143 | This repository contains settings and recommended extensions for [Visual Studio Code](https://code.visualstudio.com/) in `.vscode/`. 144 | When opening the project with vscode, you should get a prompt to install the recommended packages. The settings should be applied automatically. 145 | 146 | ### Select Python Interpreter 147 | 148 | To get syntax highlighting to work properly, you have to select the correct Python interpreter. 
149 | This can be done by opening the vscode command palette (usually Control+Shift+P or F1) and typing `Python: Select Interpreter`. Select the previously created `graph-marl` conda environment and you are done. 150 | 151 | ### Viewing Tensorboard Logs 152 | 153 | We automatically save tensorboard logs and models in new subdirectories inside `runs/`. 154 | 155 | You can view them by running `tensorboard --logdir=runs`. 156 | 157 | Alternatively, you can run the command `Python: Launch Tensorboard` in vscode. 158 | 159 | ## Acknowledgement 160 | 161 | This work has received funding from the Federal Ministry of Education and Research of Germany ([BMBF](https://www.bmbf.de/)) through Software Campus Grant 01IS17050 ([AC3Net](https://softwarecampus.de/en/projekt/ac3net-autonomous-communication-in-cooperative-computer-networks/)). 162 | 163 | AC3Net project logo by Alisha Hirsch. 164 | -------------------------------------------------------------------------------- /src/replaybuffer.py: -------------------------------------------------------------------------------- 1 | from collections import namedtuple 2 | from typing import Iterator 3 | import numpy as np 4 | import matplotlib.pyplot as plt 5 | 6 | import torch 7 | 8 | # named tuple for easier access of batches 9 | TransitionBatch = namedtuple( 10 | "TransitionBatch", 11 | [ 12 | "idx", 13 | "obs", 14 | "action", 15 | "reward", 16 | "next_obs", 17 | "adj", 18 | "next_adj", 19 | "done", 20 | "episode_done", 21 | "agent_state", 22 | "node_obs", 23 | "node_adj", 24 | "node_state", 25 | "node_aux", 26 | "node_agent_matrix", 27 | "next_node_obs", 28 | "next_node_adj", 29 | "next_node_agent_matrix", 30 | ], 31 | ) 32 | 33 | 34 | class ReplayBuffer(object): # fixed-capacity ring buffer of multi-agent transitions with optional per-node data 35 | def __init__( 36 | self, 37 | seed, 38 | buffer_size, 39 | n_agents, 40 | observation_size, 41 | agent_state_size, 42 | n_nodes=0, 43 | node_observation_size=0, 44 | node_state_size=0, 45 | 
node_aux_size=0, 46 | half_precision=False, 47 | ): 48 | self.buffer_size =
buffer_size 49 | self.count = 0 # number of filled slots, saturates at buffer_size (see add) 50 | self.index = 0 # next slot to write, wraps modulo buffer_size (see add) 51 | 52 | # create buffer 53 | 54 | def create(shape, dtype, default=0): 55 | act = np.empty(shape, dtype=dtype) 56 | act.fill(default) 57 | return act 58 | 59 | if half_precision: 60 | float_type = np.float16 61 | else: 62 | float_type = np.float32 63 | 64 | self.obs = create((buffer_size, n_agents, observation_size), dtype=float_type) 65 | self.action = create((buffer_size, n_agents), dtype=np.int8) # NOTE(review): int8 limits stored action ids to <= 127 66 | self.reward = create((buffer_size, n_agents), dtype=float_type) 67 | self.next_obs = create( 68 | (buffer_size, n_agents, observation_size), dtype=float_type 69 | ) 70 | self.adj = create((buffer_size, n_agents, n_agents), dtype=np.bool_) 71 | self.next_adj = create((buffer_size, n_agents, n_agents), dtype=np.bool_) 72 | self.done = create((buffer_size, n_agents), dtype=np.bool_) 73 | self.episode_done = create(buffer_size, dtype=np.bool_) 74 | self.agent_state = create( 75 | (buffer_size, n_agents, agent_state_size), dtype=float_type 76 | ) 77 | 78 | # optional node information; NOTE(review): these arrays only exist when n_nodes > 0, but add() and _get_transition_batch() access them unconditionally 79 | if n_nodes > 0: 80 | assert node_observation_size >= 0 and node_state_size >= 0 81 | 82 | self.node_state = create( 83 | (buffer_size, n_nodes, node_state_size), dtype=float_type 84 | ) 85 | self.node_aux = create((buffer_size, n_nodes, node_aux_size), dtype=float_type) 86 | self.node_obs = create( 87 | (buffer_size, n_nodes, node_observation_size), dtype=float_type 88 | ) 89 | self.next_node_obs = create( 90 | (buffer_size, n_nodes, node_observation_size), dtype=float_type 91 | ) 92 | self.node_adj = create((buffer_size, n_nodes, n_nodes), dtype=np.bool_) 93 | self.next_node_adj = create((buffer_size, n_nodes, n_nodes), dtype=np.bool_) 94 | self.node_agent_matrix = 
create( 95 | (buffer_size, n_nodes, n_agents), dtype=np.bool_ 96 | ) 97 | self.next_node_agent_matrix = create( 98 | (buffer_size, n_nodes, n_agents), dtype=np.bool_ 99 | ) 100 | 101 | self._random_generator = np.random.default_rng(seed) 102 | 103 | def get_batch( 104 |
self, batch_size, device, sequence_length=0 105 | ) -> Iterator[TransitionBatch]: # yields one batch, or sequence_length time-consecutive batches when sequence_length > 1 106 | # simple case: just get random indices 107 | if sequence_length <= 1: 108 | indices = self._random_generator.choice( 109 | self.count, batch_size, replace=True, p=None 110 | ) 111 | yield self._get_transition_batch(indices, device) 112 | return 113 | 114 | # we sample a sequence with length > 1 115 | 116 | # first get the beginning of the buffer (oldest element) 117 | buffer_start = self.index % self.count # raises ZeroDivisionError while the buffer is empty 118 | batch_sequence_start = self._random_generator.choice( 119 | self.count - sequence_length, # NOTE(review): sequences never end at the newest transition and count == sequence_length raises; count - sequence_length + 1 may be intended 120 | batch_size, 121 | replace=True, 122 | p=None, 123 | ) 124 | 125 | # add buffer start offset and wrap around 126 | batch_sequence_start = (buffer_start + batch_sequence_start) % self.count 127 | 128 | for offset in range(sequence_length): 129 | indices = (batch_sequence_start + offset) % self.count 130 | yield self._get_transition_batch(indices, device) 131 | 132 | def _get_transition_batch(self, indices, device) -> TransitionBatch: 133 | # convert to tensor and push to training device 134 | return TransitionBatch( 135 | indices, 136 | torch.tensor(self.obs[indices], dtype=torch.float32).to( 137 | device, non_blocking=True 138 | ), 139 | torch.tensor(self.action[indices], dtype=torch.int64).to( 140 | device, non_blocking=True 141 | ), 142 | torch.tensor(self.reward[indices], dtype=torch.float32).to( 143 | device, non_blocking=True 144 | ), 145 | torch.tensor(self.next_obs[indices], dtype=torch.float32).to( 146 | device, non_blocking=True 147 | ), 148 | torch.tensor(self.adj[indices], dtype=torch.float32).to( 149 | device, non_blocking=True 150 | ), 151 | torch.tensor(self.next_adj[indices], 
dtype=torch.float32).to( 152 | device, non_blocking=True 153 | ), 154 | torch.tensor(self.done[indices], dtype=torch.bool).to( 155 | device, non_blocking=True 156 | ), 157 | torch.tensor(self.episode_done[indices], dtype=torch.bool).to( 158 | device, non_blocking=True 159 | ), 160 |
torch.tensor(self.agent_state[indices], dtype=torch.float32).to( 161 | device, non_blocking=True 162 | ), 163 | torch.tensor(self.node_obs[indices], dtype=torch.float32).to( 164 | device, non_blocking=True 165 | ), 166 | torch.tensor(self.node_adj[indices], dtype=torch.float32).to( 167 | device, non_blocking=True 168 | ), 169 | torch.tensor(self.node_state[indices], dtype=torch.float32).to( 170 | device, non_blocking=True 171 | ), 172 | torch.tensor(self.node_aux[indices], dtype=torch.float32).to( 173 | device, non_blocking=True 174 | ), 175 | torch.tensor(self.node_agent_matrix[indices], dtype=torch.float32).to( 176 | device, non_blocking=True 177 | ), 178 | torch.tensor(self.next_node_obs[indices], dtype=torch.float32).to( 179 | device, non_blocking=True 180 | ), 181 | torch.tensor(self.next_node_adj[indices], dtype=torch.float32).to( 182 | device, non_blocking=True 183 | ), 184 | torch.tensor(self.next_node_agent_matrix[indices], dtype=torch.float32).to( 185 | device, non_blocking=True 186 | ), 187 | ) 188 | 189 | def get_recent_indices(self, last_n): # returns (unwrapped positions, buffer indices = positions % count) of the last_n entries 190 | if last_n is None: 191 | return np.arange(self.count), np.arange(self.count) 192 | 193 | m = min(self.count, last_n) 194 | x = self.index - m + np.arange(m) 195 | return x, x % self.count 196 | 197 | def plot_episode_vlines(self, idx, ymin, ymax): 198 | episode_done = self.episode_done[idx].nonzero()[0] 199 | plt.vlines( 200 | idx[0] + episode_done, 201 | colors="k", 202 | linestyles="dotted", 203 | ymin=ymin, 204 | ymax=ymax, 205 | ) 206 | 207 | def save_node_state_diff_plot(self, filename, last_n): 208 | # only for debugging 209 | x, idx = self.get_recent_indices(last_n) 210 | node_state_diff = (self.node_state[idx[1:]] - self.node_state[idx[:-1]]).mean( 211 | axis=(-1, -2) 212 | ) 213 | self.plot_episode_vlines(idx[1:], node_state_diff.min(), node_state_diff.max()) # NOTE(review): vlines use wrapped idx while the curve uses x; they 
misalign once the buffer wraps 214 | plt.plot(x[1:], node_state_diff) 215 | plt.xlabel("Buffer steps") 216 | plt.ylabel("Node state difference") 217 | plt.tight_layout() 218 |
plt.savefig( 219 | filename, 220 | bbox_inches="tight", 221 | ) 222 | plt.clf() 223 | 224 | def save_node_state_std_plot(self, filename, last_n): 225 | # std over node dim => std per feature 226 | x, idx = self.get_recent_indices(last_n) 227 | node_states_feature_std = self.node_state[idx].std(axis=-2) 228 | mean = node_states_feature_std.mean(axis=-1) 229 | std = node_states_feature_std.std(axis=-1) 230 | self.plot_episode_vlines(x, (mean - std).min(), (mean + std).max()) 231 | plt.fill_between(x, mean - std, mean + std, alpha=0.2) 232 | plt.plot(x, mean) 233 | plt.xlabel("Buffer steps") 234 | plt.ylabel("Node feature std") 235 | 236 | plt.tight_layout() 237 | plt.savefig( 238 | filename, 239 | bbox_inches="tight", 240 | ) 241 | plt.clf() 242 | 243 | def add( # writes one transition into slot self.index, overwriting the oldest entry once the buffer is full 244 | self, 245 | obs, 246 | action, 247 | reward, 248 | next_obs, 249 | adj, 250 | next_adj, 251 | done, 252 | episode_done, 253 | agent_state, 254 | node_state, 255 | node_aux, 256 | node_obs, 257 | node_adj, 258 | node_agent_matrix, 259 | next_node_obs, 260 | next_node_adj, 261 | next_node_agent_matrix, 262 | ): 263 | # put values into buffer 264 | self.obs[self.index] = obs 265 | self.action[self.index] = action 266 | self.reward[self.index] = reward 267 | self.next_obs[self.index] = next_obs 268 | self.adj[self.index] = adj 269 | self.next_adj[self.index] = next_adj 270 | self.done[self.index] = done 271 | self.episode_done[self.index] = episode_done 272 | if isinstance(agent_state, np.ndarray): # drop a leading axis of size 1 (np.squeeze raises if its size is not 1) 273 | agent_state = np.squeeze(agent_state, axis=0) 274 | self.agent_state[self.index] = agent_state 275 | self.node_state[self.index] = node_state 276 | self.node_aux[self.index] = node_aux 277 | self.node_obs[self.index] = node_obs 278 | self.node_adj[self.index] = node_adj 279 | self.node_agent_matrix[self.index] = node_agent_matrix 280 | 
self.next_node_obs[self.index] = next_node_obs 281 | self.next_node_adj[self.index] = next_node_adj 282 | self.next_node_agent_matrix[self.index] = next_node_agent_matrix
283 | 284 | # increase counters 285 | if self.count < self.buffer_size: 286 | self.count += 1 287 | self.index = (self.index + 1) % self.buffer_size 288 | -------------------------------------------------------------------------------- /src/env/simple_environment.py: -------------------------------------------------------------------------------- 1 | import textwrap 2 | from typing import List 3 | 4 | import matplotlib.pyplot as plt 5 | import networkx as nx 6 | import numpy as np 7 | 8 | from env.environment import EnvironmentVariant, NetworkEnv 9 | from gymnasium.spaces import Discrete 10 | 11 | 12 | class Router(object): # a node with position (x, y), neighbor/edge index lists and a score 13 | def __init__(self, x, y, s): 14 | self.x = x 15 | self.y = y 16 | self.neighbor = [] 17 | self.edge = [] 18 | self.score = s 19 | 20 | 21 | class Edge(object): 22 | def __init__(self, x, y, length): 23 | self.start = x 24 | self.end = y 25 | self.len = int(int(length * 10) / 2 + 1) # discretized length, e.g. length=1 gives 6 26 | self.load = 0 27 | 28 | def get_other_node(self, node): 29 | if self.start == node: 30 | return self.end 31 | elif self.end == node: 32 | return self.start 33 | else: 34 | raise ValueError( 35 | f"Is neither start nor end of edge {self.start}-{self.end}: {node}" 36 | ) 37 | 38 | 39 | class Data(object): 40 | def __init__(self, x, size): 41 | self.now = x 42 | self.size = size 43 | 44 | 45 | class SimpleEnvironment(NetworkEnv): 46 | """ 47 | Simple test environment with 3 nodes connected in a line topology 48 | and one agent located at the node in the middle. 49 | 50 | o -- o -- o 51 | ^ 52 | agent 53 | 54 | The agent makes a binary decision to either move to the left or right neighbor. 55 | These neighbor nodes have scores in {-1, 1} randomly chosen without replacement 56 | and the agent receives the score that has been assigned to the node. 57 | 58 | The goal is to select the action that leads to the node with score 1. 
59 | """ 60 | 61 | def __init__(self, env_var: EnvironmentVariant, random_topology): 62 | """ 63 | Initializes the environment.
64 | 65 | :param env_var: Independent means that the agent only 66 | observes its own position, with_k_neighbors and global allow for a global view. 67 | :param random_topology: Whether to randomize the node ids and edge order. 68 | """ 69 | self.router: List[Router] = [] 70 | self.edges: List[Edge] = [] 71 | self.G = nx.Graph() 72 | self.adj_matrix = None 73 | self.n_router = 3 74 | self.n_data = 1 75 | self.record_distance_map = False 76 | self.random_topology = random_topology 77 | self.sort_edges = True 78 | self.env_var = EnvironmentVariant(env_var) 79 | self.start_node = -1 80 | 81 | # using gym action space 82 | self.action_space = Discrete(2, start=0) # {0, 1} 83 | 84 | @staticmethod 85 | def one_hot_list(i, max_indices): # one-hot list of length max_indices; a negative i yields all zeros 86 | a = [0] * max_indices 87 | if i >= 0: 88 | a[i] = 1 89 | return a 90 | 91 | def get_num_agents(self): 92 | return self.n_data 93 | 94 | def get_num_nodes(self): 95 | return self.n_router 96 | 97 | def __str__(self) -> str: 98 | return textwrap.dedent( 99 | f"""\ 100 | SimpleEnvironment with parameters 101 | > Environment variant: {self.env_var} 102 | > Random topology: {self.random_topology}\ 103 | """ 104 | ) 105 | 106 | def _build_network(self): # rebuilds the 3-node graph: start node (score 0) linked to two border nodes scored -1 and 1 107 | self.G = nx.Graph() 108 | self.router = [] 109 | self.edges = [] 110 | self.data = [] 111 | 112 | border_scores = np.array([-1, 1]) 113 | # border scores are always shuffled 114 | np.random.shuffle(border_scores) 115 | 116 | scores = np.array([border_scores[0], 0, border_scores[1]]) 117 | if self.random_topology: 118 | # shuffle all scores 119 | np.random.shuffle(scores) 120 | 121 | # identify node ids, n0 is the start node with score 0 122 | n0 = np.where(scores == 0)[0][0] 123 | n1 = (n0 + 1) % 3 124 | n2 = (n1 + 1) % 3 125 | self.start_node = n0 126 | 127 | for i in range(3): 128 | # add routers at random locations 129 | new_router = 
Router(np.random.random(), np.random.random(), scores[i]) 130 | self.router.append(new_router) 131 | self.G.add_node(i, pos=(new_router.x,
new_router.y)) 132 | 133 | self.router[n0].neighbor.append(n1) 134 | self.router[n0].neighbor.append(n2) 135 | self.router[n1].neighbor.append(n0) 136 | self.router[n2].neighbor.append(n0) 137 | 138 | edge_destinations = [n1, n2] 139 | if self.random_topology: 140 | np.random.shuffle(edge_destinations) 141 | 142 | edge_nodes = [n0, edge_destinations[0]] 143 | if self.random_topology: 144 | np.random.shuffle(edge_nodes) 145 | 146 | new_edge_0 = Edge(edge_nodes[0], edge_nodes[1], 1) 147 | self.edges.append(new_edge_0) 148 | 149 | self.G.add_edge( 150 | new_edge_0.start, 151 | new_edge_0.end, 152 | color="lightblue", 153 | weight=new_edge_0.len, 154 | ) 155 | 156 | edge_nodes = [n0, edge_destinations[1]] 157 | if self.random_topology: 158 | np.random.shuffle(edge_nodes) 159 | 160 | new_edge_2 = Edge(edge_nodes[0], edge_nodes[1], 1) 161 | self.edges.append(new_edge_2) 162 | 163 | self.G.add_edge( 164 | new_edge_2.start, 165 | new_edge_2.end, 166 | color="lightblue", 167 | weight=new_edge_2.len, 168 | ) 169 | 170 | edge_order = [0, 1] 171 | if self.random_topology: 172 | if self.sort_edges: 173 | edge_order = np.argsort(edge_destinations) # order n0's edges by destination node id 174 | else: 175 | np.random.shuffle(edge_order) 176 | 177 | self.router[n0].edge.append(edge_order[0]) 178 | self.router[n0].edge.append(edge_order[1]) 179 | 180 | self.router[edge_destinations[0]].edge.append(0) 181 | self.router[edge_destinations[1]].edge.append(1) 182 | 183 | # generate data packet 184 | self.data = [] 185 | self.data.append(Data(self.start_node, 1)) 186 | 187 | self._update_nodes_adjacency() 188 | 189 | def reset(self): 190 | self._build_network() 191 | # self.render() 192 | # self._network_exists = True 193 | return self._get_observation(), self._get_data_adjacency() 194 | 195 | def render(self): 196 | labels = {} 197 | for r in range(3): 198 | labels[r] = f"{r}:{self.router[r].score}" 199 | nx.draw_networkx(self.G, labels=labels, node_color="pink") 200 | plt.show() 201 | 202 | # adj matrix of 
routers(nodes) # 203 |
def get_nodes_adjacency(self): 204 | """ 205 | Get the adjacency matrix for all routers (nodes) in the network. 206 | 207 | :return: int8 adjacency matrix of size (n_router, n_router), including self-loops 208 | """ 209 | return self.adj_matrix 210 | 211 | def _update_nodes_adjacency(self): 212 | self.adj_matrix = np.eye(self.n_router, self.n_router, dtype=np.int8) 213 | for i in range(self.n_router): 214 | for neighbor in self.router[i].neighbor: 215 | self.adj_matrix[i][neighbor] = 1 216 | 217 | def get_node_observation(self): 218 | """ 219 | Get the monitoring information for each router in the network. 220 | 221 | :return: float32 array of shape (n_router, 1) holding each router's score 222 | """ 223 | obs = [] 224 | for j in range(self.n_router): 225 | ob = [] 226 | # necessary: node score 227 | ob.append(self.router[j].score) 228 | 229 | # optional: node index 230 | # ob.append(j) 231 | 232 | # optional: edge indices 233 | # for edge_idx in range(2): 234 | # if edge_idx >= len(self.router[j].edge): 235 | # ob.append(-1) 236 | # else: 237 | # ob.append(self.router[j].edge[edge_idx]) 238 | 239 | obs.append(ob) 240 | 241 | return np.array(obs, dtype=np.float32) 242 | 243 | def get_node_agent_matrix(self): 244 | """ 245 | Gets a matrix that indicates where agents are located, 246 | matrix[n, a] = 1 iff agent a is on node n and 0 otherwise.
247 | 248 | :return: the node agent matrix of shape (n_nodes, n_agents) 249 | """ 250 | node_agent = np.zeros((self.n_router, self.n_data), dtype=np.int8) 251 | for a in range(self.n_data): 252 | node_agent[self.data[a].now, a] = 1 253 | 254 | return node_agent 255 | 256 | def _get_observation(self): # per-agent obs: current node id, plus flattened node adjacency and node scores unless INDEPENDENT 257 | obs = [] 258 | nodes_adjacency = self.get_nodes_adjacency().flatten() 259 | node_observation = self.get_node_observation().flatten() 260 | global_obs = np.concatenate((nodes_adjacency, node_observation)) 261 | 262 | for i in range(self.n_data): 263 | ob = [] 264 | # packet information 265 | ob.append(self.data[i].now) 266 | 267 | # other data 268 | self.data[i].neigh = [] # side effect: neighbor list consumed by _get_data_adjacency() 269 | self.data[i].neigh.append(i) 270 | for j in range(self.n_data): 271 | if j == i: 272 | continue 273 | if (self.data[j].now in self.router[self.data[i].now].neighbor) | ( 274 | self.data[j].now == self.data[i].now 275 | ): 276 | self.data[i].neigh.append(j) 277 | 278 | ob_numpy = np.array(ob) 279 | 280 | # add global information 281 | if self.env_var != EnvironmentVariant.INDEPENDENT: 282 | # add global node observations 283 | ob_numpy = np.concatenate((ob_numpy, global_obs)) 284 | 285 | obs.append(ob_numpy) 286 | 287 | return np.array(obs, dtype=np.float32) 288 | 289 | def step(self, action): 290 | act = action[0] # NOTE(review): only action[0] is used and reward/done have length 1, so this assumes n_data == 1 291 | reward = [-1] 292 | done = [False] 293 | 294 | for i in range(self.n_data): 295 | # agent i controls data packet i 296 | packet = self.data[i] 297 | t = self.router[packet.now].edge[act] 298 | 299 | if self.edges[t].start == packet.now: 300 | packet.now = self.edges[t].end 301 | else: 302 | packet.now = self.edges[t].start 303 | 304 | reward[0] = self.router[packet.now].score 305 | done[0] = True 306 | 307 | # reset packet (middle 
router) 308 | packet.now = self.start_node 309 | 310 | obs = self._get_observation() 311 | adj = self._get_data_adjacency() 312 | info = {} 313 | 314 | # print(action, reward) 315 | return obs, adj, np.array(reward), done, info 316 | 317 | def
_get_data_adjacency(self): 318 | """ 319 | Get an adjacency matrix for data packets (agents) of shape (n_agents, n_agents) 320 | where the second dimension contains the neighbors of the agents in the first 321 | dimension, i.e. the matrix is of form (agent, neighbors). 322 | 323 | Reads the neighbor lists self.data[i].neigh that _get_observation() sets, 324 | so _get_observation() must have run for the current data packets. 325 | :return: int8 adjacency matrix of shape (n_data, n_data) with ones on the diagonal 326 | """ 327 | # eye because self is also part of the neighborhood 328 | adj = np.eye(self.n_data, self.n_data, dtype=np.int8) 329 | for i in range(self.n_data): 330 | for n in self.data[i].neigh: 331 | if n != -1: 332 | # n is (currently) a neighbor of i 333 | adj[i, n] = 1 334 | return adj 335 | -------------------------------------------------------------------------------- /src/env/network.py: -------------------------------------------------------------------------------- 1 | from typing import List, Optional 2 | import numpy as np 3 | import matplotlib.pyplot as plt 4 | import networkx as nx 5 | from networkx.classes.function import path_weight 6 | from collections import defaultdict 7 | 8 | 9 | class Node: 10 | """ 11 | A node in a network. 12 | """ 13 | 14 | def __init__(self, x, y): 15 | self.x = x 16 | self.y = y 17 | self.neighbors = [] 18 | self.edges = [] 19 | 20 | 21 | class Edge: 22 | """ 23 | An edge in a network. 24 | """ 25 | 26 | def __init__(self, start, end, length): 27 | self.start = start 28 | self.end = end 29 | self.length = length 30 | 31 | def get_other_node(self, node): # return the opposite endpoint; ValueError if node is not on this edge 32 | if self.start == node: 33 | return self.end 34 | elif self.end == node: 35 | return self.start 36 | else: 37 | raise ValueError( 38 | f"Is neither start nor end of edge {self.start}-{self.end}: {node}" 39 | ) 40 | 41 | 42 | class Network: 43 | """ 44 | Network class that manages the creation of graphs.
45 | """ 46 | 47 | def __init__( 48 | self, 49 | n_nodes=20, 50 | random_topology=False, 51 | n_random_seeds=None, 52 | sequential_topology_seeds=False, 53 | topology_init_seed=476, 54 | excluded_seeds: Optional[List[int]] = None, 55 | provided_seeds: Optional[List[int]] = None, 56 | ): 57 | """ 58 | Initializes the network and optionally creates list of valid random topology seeds. 59 | 60 | :param n_nodes: Number of nodes in the network, defaults to 20 61 | :param random_topology: Create a random topology on each reset, defaults to False 62 | :param n_random_seeds: Number of random topologies, defaults to None 63 | :param sequential_topology_seeds: Sample topologies sequentially, defaults to False 64 | :param topology_init_seed: Seed for topology creation, defaults to 476 65 | :param excluded_seeds: Seeds that must not be used for topology creation, defaults to None 66 | :param provided_seeds: Use provided seeds and generate no new topologies, defaults to None 67 | """ 68 | self.n_nodes = n_nodes 69 | 70 | self.nodes = [] 71 | self.edges = [] 72 | 73 | self.G = nx.Graph() 74 | self.G_weight_key = "weight" # None for hops or "weight" for lengths 75 | self.shortest_paths = None 76 | self.shortest_paths_weights = None 77 | self.adj_matrix = None 78 | 79 | # only needed for evaluate the efficiency of creating a correct random topology 80 | self.repetitions = 0 81 | 82 | self.random_topology = random_topology 83 | self.current_topology_seed = None 84 | self.sampled_topology_seeds = [] 85 | self.sequential_topology_seeds = sequential_topology_seeds 86 | self.sequential_topology_seeds_frozen = False 87 | self.sequential_topology_index = 0 88 | self.topology_init_seed = topology_init_seed 89 | self.exclude_seeds = None if excluded_seeds is None else set(excluded_seeds) 90 | self.provide_seeds = provided_seeds 91 | if provided_seeds is not None and len(provided_seeds) > 0: 92 | self.seeds = provided_seeds 93 | else: 94 | self.seeds = self.build_seed_list( 95 | 
random_topology, n_random_seeds, self.exclude_seeds 96 | ) 97 | if self.exclude_seeds is not None: 98 | assert all([s not in self.exclude_seeds for s in self.seeds]) 99 | 100 | def build_seed_list(self, random_topology, n_random_seeds, exclude_seeds=None): 101 | if not random_topology: 102 | return [self.topology_init_seed] 103 | 104 | if n_random_seeds is None or n_random_seeds <= 0: 105 | return [] 106 | 107 | old_rand_state = np.random.get_state() 108 | np.random.seed(self.topology_init_seed) 109 | 110 | seed_list = [] 111 | # build list of unique num_random_seeds > 0 seeds 112 | while len(seed_list) < n_random_seeds: 113 | # build one random network and add topology seed 114 | new_seed = self._create_valid_network(seeds_exclude=exclude_seeds) 115 | if new_seed not in seed_list: 116 | seed_list.append(new_seed) 117 | 118 | np.random.set_state(old_rand_state) 119 | 120 | return seed_list 121 | 122 | def _create_random_topology(self): 123 | """ 124 | Creates a new random topology. This is based on the implementation 125 | by Jiang et al. https://github.com/PKU-RL/DGN/blob/master/Routing/routers.py 126 | used for their DGN paper https://arxiv.org/abs/1810.09202. 
127 | """ 128 | # otherwise create a new topology 129 | self.G = nx.Graph() 130 | self.nodes = [] 131 | self.edges = [] 132 | t_edge = 0 133 | 134 | for i in range(self.n_nodes): 135 | # add routers at random locations 136 | new_router = Node(np.random.random(), np.random.random()) 137 | self.nodes.append(new_router) 138 | self.G.add_node(i, pos=(new_router.x, new_router.y)) 139 | 140 | for i in range(self.n_nodes): 141 | # calculate (squared) distances to all other routers 142 | self.dis = [] 143 | for j in range(self.n_nodes): 144 | self.dis.append( 145 | [ 146 | (self.nodes[j].x - self.nodes[i].x) ** 2 147 | + (self.nodes[j].y - self.nodes[i].y) ** 2, 148 | j, 149 | ] 150 | ) 151 | 152 | # sort by distance 153 | self.dis.sort(key=lambda x: x[0], reverse=False) 154 | 155 | # find new neighbors 156 | # exclude index 0 as we always have distance 0 to ourselves 157 | for j in range(1, self.n_nodes): 158 | # we have found enough neighbors => break 159 | if len(self.nodes[i].neighbors) == 3: 160 | break 161 | 162 | # check for neighbor candidates 163 | candidate_sq_dist, candidate_idx = self.dis[j] 164 | if ( 165 | len(self.nodes[candidate_idx].neighbors) < 3 166 | and i not in self.nodes[candidate_idx].neighbors 167 | ): 168 | # append new neighbor 169 | self.nodes[i].neighbors.append(candidate_idx) 170 | self.nodes[candidate_idx].neighbors.append(i) 171 | 172 | # create edges, always sorted by index 173 | edge_distance = int(int(np.sqrt(candidate_sq_dist) * 10) / 2 + 1) 174 | if i < candidate_idx: 175 | new_edge = Edge(i, candidate_idx, edge_distance) 176 | else: 177 | new_edge = Edge(candidate_idx, i, edge_distance) 178 | 179 | self.edges.append(new_edge) 180 | self.nodes[candidate_idx].edges.append(t_edge) 181 | self.nodes[i].edges.append(t_edge) 182 | self.G.add_edge( 183 | new_edge.start, 184 | new_edge.end, 185 | weight=new_edge.length, 186 | ) 187 | 188 | t_edge += 1 189 | 190 | # order router edges by neighbor node id to remove symmetries 191 | for i in 
range(self.n_nodes): 192 | self.nodes[i].edges = sorted( 193 | self.nodes[i].edges, 194 | key=lambda edge_index: self.edges[edge_index].get_other_node(i), 195 | ) 196 | 197 | def _check_topology_constraints(self): 198 | """ 199 | Check if the current network topology fulfills the constraints, meaning it is 200 | connected and all nodes have three neighbors. 201 | 202 | :return: whether the topology is valid. 203 | """ 204 | # for the case that there is no isolated island but nodes with less than k edges 205 | for i in range(self.n_nodes): 206 | if len(self.nodes[i].neighbors) < 3: 207 | return False 208 | 209 | # this means not every nodes is reachable, we have got isolated islands 210 | if not nx.is_connected(self.G): 211 | return False 212 | 213 | return True 214 | 215 | def _create_valid_network( 216 | self, seed_list=None, seed_index=None, seeds_exclude=None 217 | ): 218 | """ 219 | Generates a network based on a list of seeds. 220 | 221 | :param seed_list: List of seeds, can be None to create new valid topology 222 | :param seed_index: Index in seed list, chooses random index if None 223 | :param seeds_exclude: List of seeds (for random generation) are excluded 224 | :returns: the seed used to generate the network topology 225 | """ 226 | 227 | # set seed for topology generation 228 | no_seed_provided = seed_list is None or len(seed_list) == 0 229 | if no_seed_provided: 230 | topology_seed = np.random.randint(2**31 - 1) 231 | while seeds_exclude is not None and topology_seed in seeds_exclude: 232 | topology_seed = np.random.randint(2**31 - 1) 233 | 234 | elif seed_index is not None: 235 | topology_seed = seed_list[seed_index] 236 | else: 237 | # choose one of the seeds from the list 238 | topology_seed = np.random.choice(seed_list) 239 | 240 | # remember random state for packet generation 241 | old_rand_state = np.random.get_state() 242 | np.random.seed(topology_seed) 243 | 244 | self.repetitions = 0 245 | while True: 246 | self._create_random_topology() 247 | 
self.repetitions += 1 248 | if self._check_topology_constraints(): 249 | break 250 | 251 | assert no_seed_provided, f"Provided seed {topology_seed} is invalid." 252 | topology_seed = np.random.randint(2**31 - 1) 253 | while seeds_exclude is not None and topology_seed in seeds_exclude: 254 | topology_seed = np.random.randint(2**31 - 1) 255 | np.random.seed(topology_seed) 256 | 257 | # restore old random state 258 | np.random.set_state(old_rand_state) 259 | 260 | # calculate shortest paths with corresponding distances/weights 261 | self._update_shortest_paths() 262 | 263 | # e_lens = np.array([e.len for e in self.edges]) 264 | # print( 265 | # f"Max: {e_lens.max()}, min: {e_lens.min()}, mean {e_lens.mean()}, std {e_lens.std()}" 266 | # ) 267 | 268 | self._update_nodes_adjacency() 269 | self.current_topology_seed = topology_seed 270 | 271 | # return the seed that was used to create this topology 272 | return topology_seed 273 | 274 | def _update_shortest_paths(self): 275 | """ 276 | Calculates shortest paths and stores them in self.shortest_paths. The 277 | corresponding weights (distances) are stored in self.shortest_paths_weights 278 | """ 279 | self.shortest_paths = dict(nx.shortest_path(self.G, weight=self.G_weight_key)) 280 | self.shortest_paths_weights = defaultdict(dict) 281 | for start in self.shortest_paths: 282 | for end in self.shortest_paths[start]: 283 | if self.G_weight_key is None: 284 | self.shortest_paths_weights[start][end] = ( 285 | len(self.shortest_paths[start][end]) - 1 286 | ) 287 | else: 288 | self.shortest_paths_weights[start][end] = path_weight( 289 | self.G, self.shortest_paths[start][end], self.G_weight_key 290 | ) 291 | 292 | def randomize_edge_weights(self, mode: str, **kwargs): 293 | """ 294 | Randomizes edge weights in the graph (at runtime). 
295 | 296 | :param mode: `shuffle` to shuffle existing weights, `randint` with additional 297 | kwargs `low` and `high` to create new random weights 298 | :returns: tuple of (proportion of changed first hops on shortest paths, proportion 299 | of changed shortest paths, proportion of changed shortest path lengths) 300 | """ 301 | if mode == "shuffle": 302 | edge_lengths = np.array([e.length for e in self.edges]) 303 | np.random.shuffle(edge_lengths) 304 | for i, e in enumerate(self.edges): 305 | e.length = edge_lengths[i] 306 | elif mode == "randint": 307 | for e in self.edges: 308 | e.length = np.random.randint(kwargs["low"], kwargs["high"]) 309 | elif mode == "bottleneck-971182936": 310 | edge_update_list = [ 311 | (2, 7), 312 | ] 313 | for e in self.edges: 314 | for (start, end) in edge_update_list: 315 | if e.start == start and e.end == end: 316 | e.length = 10 317 | break 318 | if self.current_topology_seed != 971182936: 319 | print("Warning: mode only meant to be used in graph 971182936.") 320 | else: 321 | raise ValueError(f"Unknown mode {mode}") 322 | 323 | old_shortest_paths = self.shortest_paths.copy() 324 | old_shortest_path_weights = self.shortest_paths_weights.copy() 325 | 326 | for e in self.edges: 327 | self.G[e.start][e.end]["weight"] = e.length 328 | 329 | self._update_shortest_paths() 330 | 331 | # check how much has changed 332 | n_paths = self.n_nodes * (self.n_nodes - 1) 333 | n_paths_changed_first_hop = 0 334 | n_paths_changed = 0 335 | n_path_weights_changed = 0 336 | for a in range(self.n_nodes): 337 | for b in range(self.n_nodes): 338 | if a == b: 339 | continue 340 | if self.shortest_paths[a][b][1] != old_shortest_paths[a][b][1]: 341 | n_paths_changed_first_hop += 1 342 | if self.shortest_paths[a][b] != old_shortest_paths[a][b]: 343 | n_paths_changed += 1 344 | if self.shortest_paths_weights[a][b] != old_shortest_path_weights[a][b]: 345 | n_path_weights_changed += 1 346 | 347 | return ( 348 | n_paths_changed_first_hop / n_paths, 349 | 
n_paths_changed / n_paths, 350 | n_path_weights_changed / n_paths, 351 | ) 352 | 353 | def freeze_sequential_topology_seeds(self): 354 | self.sequential_topology_seeds_frozen = True 355 | 356 | def next_topology_seed_index(self, advance_index=True): 357 | seed_index = ( 358 | self.sequential_topology_index 359 | if len(self.seeds) > 1 and self.sequential_topology_seeds 360 | else None 361 | ) 362 | if seed_index is not None and advance_index: 363 | self.sequential_topology_index = (seed_index + 1) % len(self.seeds) 364 | return seed_index 365 | 366 | def reset(self): 367 | seed_index = self.next_topology_seed_index( 368 | advance_index=not self.sequential_topology_seeds_frozen 369 | ) 370 | self._create_valid_network(self.seeds, seed_index, self.exclude_seeds) 371 | self.sampled_topology_seeds.append(self.current_topology_seed) 372 | 373 | def render(self): 374 | nx.draw_networkx(self.G, with_labels=True, node_color="pink") 375 | plt.show() 376 | 377 | def get_nodes_adjacency(self): 378 | """ 379 | Get the adjacency matrix for all routers (nodes) in the network. 
380 | 381 | return: adjacency matrix of size (n_router, n_router) 382 | """ 383 | return self.adj_matrix 384 | 385 | def _update_nodes_adjacency(self): 386 | self.adj_matrix = np.eye(self.n_nodes, self.n_nodes, dtype=np.int8) 387 | for i in range(self.n_nodes): 388 | for neighbor in self.nodes[i].neighbors: 389 | self.adj_matrix[i][neighbor] = 1 390 | -------------------------------------------------------------------------------- /src/eval.py: -------------------------------------------------------------------------------- 1 | from collections import defaultdict 2 | import json 3 | import lzma 4 | from pathlib import Path 5 | import pickle 6 | from typing import Any, Dict, List, NamedTuple, Optional, Union 7 | from matplotlib import pyplot as plt 8 | import networkx as nx 9 | 10 | import numpy as np 11 | from tqdm import tqdm 12 | 13 | from env.routing import Routing 14 | 15 | 16 | class StepStats(NamedTuple): 17 | n: int 18 | obs: np.ndarray 19 | adj: np.ndarray 20 | act: np.ndarray 21 | reward: np.ndarray 22 | done: np.ndarray 23 | info: dict 24 | node_state: Optional[np.ndarray] 25 | node_aux: Optional[np.ndarray] 26 | 27 | 28 | class EpisodeStats(NamedTuple): 29 | steps: List[StepStats] 30 | aux: Dict[str, Any] 31 | 32 | 33 | def evaluate( 34 | env, 35 | policy, 36 | episodes, 37 | steps_per_episode, 38 | disable_progressbar=False, 39 | output_dir: Optional[Union[Path, str]] = None, 40 | output_detailed=False, 41 | output_node_state_aux=False, 42 | ): 43 | if output_dir is not None: 44 | if isinstance(output_dir, str): 45 | output_dir = Path(output_dir) 46 | output_dir.mkdir(exist_ok=True, parents=True) 47 | 48 | episode_stats = [] 49 | 50 | if hasattr(policy, "eval"): 51 | policy.eval() 52 | 53 | if hasattr(env, "set_eval_info"): 54 | env.set_eval_info(True) 55 | 56 | # perform evaluation 57 | # print("Performing Evaluation") 58 | for ep in tqdm(range(episodes), disable=disable_progressbar): 59 | step_stats = [] 60 | aux_stats = {} 61 | obs, adj = env.reset() 
62 | 63 | # print(f"Episode {ep} seed was {env.current_topology_seed}") 64 | 65 | # reset all agents 66 | if hasattr(policy, "reset"): 67 | policy.reset(1) 68 | 69 | if hasattr(policy, "reset_episode"): 70 | policy.reset_episode() 71 | 72 | for step in range(steps_per_episode): 73 | if hasattr(env, "netmon"): 74 | node_state = env.netmon.state.detach().cpu().squeeze(0).numpy() 75 | if node_state.size == 0: 76 | node_state = np.zeros((obs.shape[0], 1)) 77 | else: 78 | node_state = np.zeros((obs.shape[0], 1)) 79 | 80 | if output_node_state_aux: 81 | node_aux = env.get_node_aux() 82 | else: 83 | node_aux = None 84 | 85 | actions = policy(obs, adj) 86 | next_obs, next_adj, reward, done, info = env.step(actions) 87 | 88 | if step + 1 == steps_per_episode: 89 | # also add delays of agents that did not arrive 90 | info = env.get_final_info(info) 91 | 92 | # manual eval experiments 93 | 94 | # experiment from the paper with selected bottleneck link 95 | # if step == 50: 96 | # print(env.network.randomize_edge_weights("bottleneck-971182936")) 97 | 98 | # stop adapting netmon after some steps 99 | # if step == 25: 100 | # env.freeze() 101 | 102 | # if (step + 1) % 100 == 0: 103 | # env.current_netmon_state = None 104 | 105 | step_stats.append( 106 | StepStats( 107 | step, obs, adj, actions, reward, done, info, node_state, node_aux 108 | ) 109 | ) 110 | 111 | # reset done agents 112 | if hasattr(policy, "reset"): 113 | policy.reset(done) 114 | 115 | obs = next_obs 116 | adj = next_adj 117 | 118 | if isinstance(env.get(), Routing): 119 | aux_stats["distance_map"] = env.distance_map.copy() 120 | aux_stats["sum_packets_per_node"] = env.sum_packets_per_node 121 | aux_stats["sum_packets_per_edge"] = env.sum_packets_per_edge 122 | aux_stats["G"] = env.network.G.copy() 123 | env.distance_map.clear() 124 | 125 | episode_stats.append(EpisodeStats(step_stats, aux_stats)) 126 | 127 | if hasattr(env, "set_eval_info"): 128 | env.set_eval_info(False) 129 | 130 | eval_metrics = 
get_eval_metrics(episode_stats) 131 | 132 | if output_dir is not None: 133 | output_dir.mkdir(exist_ok=True, parents=True) 134 | with open(output_dir / "metrics.json", "w+") as f: 135 | json.dump(eval_metrics, f, indent=4, sort_keys=True, default=str) 136 | 137 | if isinstance(env.get(), Routing): 138 | create_routing_plots( 139 | episode_stats, output_dir, output_detailed, output_node_state_aux 140 | ) 141 | 142 | return eval_metrics 143 | 144 | 145 | def get_eval_metrics(episode_stats: List[EpisodeStats]): 146 | stats_lists = defaultdict(list) 147 | # join stats for each step in each episode 148 | for episode in episode_stats: 149 | for step in episode.steps: 150 | for k, v in step.info.items(): 151 | if isinstance(v, list): 152 | stats_lists[k] += v 153 | else: 154 | stats_lists[k].append(v) 155 | 156 | for r in step.reward: 157 | stats_lists["reward"].append(r) 158 | 159 | # calculate mean 160 | metrics = dict() 161 | for k, v in stats_lists.items(): 162 | if len(stats_lists[k]) > 0: 163 | v_arr = np.array(v) 164 | metrics[k + "_mean"] = v_arr.mean() 165 | else: 166 | metrics[k + "_mean"] = float("inf") 167 | 168 | return metrics 169 | 170 | 171 | def save_distance_map_plot(distance_map, filename): 172 | if len(distance_map) == 0: 173 | return 174 | 175 | X = np.sort(list(distance_map.keys())) 176 | Y = np.zeros_like(X, dtype=float) 177 | Y_err = np.zeros_like(X) 178 | for i, x in enumerate(X): 179 | Y_arr = np.array(distance_map[x]) 180 | Y[i] = Y_arr.mean() 181 | Y_err[i] = Y_arr.std() 182 | 183 | plt.clf() 184 | plt.plot(X, Y, label="Agent") 185 | plt.plot(X, X, label="Lower Bound") 186 | plt.xlabel("Shortest path [steps]") 187 | plt.ylabel("Agent path [steps]") 188 | plt.legend() 189 | plt.tight_layout() 190 | plt.savefig(filename, bbox_inches="tight") 191 | 192 | 193 | def save_packet_location_graph( 194 | G, 195 | sum_packets_per_node, 196 | sum_packets_per_edge, 197 | num_steps, 198 | filename, 199 | ): 200 | plt.clf() 201 | pos = 
nx.drawing.spring_layout(G, seed=1337) 202 | # pos = nx.get_node_attributes(G, "pos") 203 | edge_weight = np.array([data["weight"] for n1, n2, data in G.edges(data=True)]) 204 | nx_edges = nx.draw_networkx_edges( 205 | G, 206 | pos=pos, 207 | width=4, 208 | edge_color=sum_packets_per_edge / (np.sum(sum_packets_per_edge) * edge_weight), 209 | edge_cmap=plt.get_cmap("viridis"), 210 | ) 211 | plt.colorbar(nx_edges, label="Normalized edge utilization") 212 | nx_nodes = nx.draw_networkx_nodes( 213 | G, 214 | pos=pos, 215 | node_color=sum_packets_per_node / np.sum(sum_packets_per_node), 216 | cmap=plt.get_cmap("viridis"), 217 | ) 218 | nx.draw_networkx_labels( 219 | G, 220 | pos, 221 | labels=dict([(i, i) for i in range(G.order())]), 222 | ) 223 | plt.colorbar(nx_nodes, label="Normalized node utilization") 224 | nx.draw_networkx_edge_labels( 225 | G, 226 | pos, 227 | edge_labels=nx.get_edge_attributes(G, "weight"), 228 | bbox=dict(boxstyle="round", facecolor="white", alpha=0.0, edgecolor="white"), 229 | ) 230 | # remove border around network 231 | plt.gca().axis("off") 232 | plt.tight_layout() 233 | plt.savefig(filename, bbox_inches="tight") 234 | 235 | 236 | def ddlist(): 237 | """ 238 | Default dict of lists, required for pickle. 239 | 240 | :return: a defaultdict of lists. 
241 | """ 242 | return defaultdict(list) 243 | 244 | 245 | def create_routing_plots( 246 | episode_stats: List[EpisodeStats], 247 | output_dir: Path, 248 | output_detailed, 249 | output_node_state_aux, 250 | ): 251 | d = defaultdict(ddlist) 252 | 253 | # join stats for each step in each episode over all nodes 254 | for episode_i, episode in enumerate(episode_stats): 255 | last_node_state = 0 256 | last_step = None 257 | 258 | if output_detailed: 259 | save_distance_map_plot( 260 | episode.aux["distance_map"], 261 | output_dir / f"ep_{episode_i}_distance_map.png", 262 | ) 263 | save_packet_location_graph( 264 | episode.aux["G"], 265 | episode.aux["sum_packets_per_node"], 266 | episode.aux["sum_packets_per_edge"], 267 | len(episode.steps), 268 | output_dir / f"ep_{episode_i}_packet_heatmap.png", 269 | ) 270 | 271 | for k, v in episode.aux["distance_map"].items(): 272 | d["distance_map"][k] += v 273 | 274 | if output_detailed: 275 | d["episode_throughput"][episode_i] = defaultdict(list) 276 | for step in episode.steps: 277 | d["episode_done_ids"][episode_i] += step.done.nonzero()[0].tolist() 278 | if output_detailed: 279 | # episode-wise stats 280 | d["episode_throughput"][episode_i][step.n].append( 281 | step.info["throughput"] 282 | ) 283 | 284 | # step wise stats 285 | d["feature_diffs"][step.n].append( 286 | np.mean(np.abs(step.node_state - last_node_state), axis=0) 287 | ) 288 | d["feature_diffs_mean"][step.n].append( 289 | np.mean(np.abs(step.node_state - last_node_state)) 290 | ) 291 | d["feature_diffs_max"][step.n].append( 292 | np.max(np.abs(step.node_state - last_node_state)) 293 | ) 294 | d["reward"][step.n].append(np.mean(step.reward)) 295 | d["feature_std"][step.n].append(np.std(step.node_state, axis=0)) 296 | d["feature_mean"][step.n].append(np.mean(step.node_state, axis=0)) 297 | d["dropped"][step.n].append(step.info["dropped"]) 298 | d["throughput"][step.n].append(step.info["throughput"]) 299 | d["blocked"][step.n].append(step.info["blocked"]) 300 | 
d["total_edge_load"][step.n].append(step.info["total_edge_load"]) 301 | 302 | if len(step.info["delays_arrived"]) > 0: 303 | d["delays_arrived_mean"][step.n].append( 304 | np.mean(step.info["delays_arrived"]) 305 | ) 306 | 307 | if len(step.info["spr"]) > 0: 308 | d["spr_mean"][step.n].append(np.mean(step.info["spr"])) 309 | d["spr_min"][step.n].append(np.min(step.info["spr"])) 310 | 311 | d["looped"][step.n].append(step.info["looped"]) 312 | 313 | if last_step is not None: 314 | for a in range(step.done.shape[0]): 315 | if not step.done[a]: 316 | diff = np.abs( 317 | step.info["packet_distances"][a] 318 | - last_step.info["packet_distances"][a] 319 | ) 320 | d["packet_distance_delta"][step.n].append(diff) 321 | 322 | d["occupied_edges"][step.n].append(step.info["occupied_edges"]) 323 | d["packets_on_edges"][step.n].append(step.info["packets_on_edges"]) 324 | d["total_packet_size"][step.n].append(step.info["total_packet_size"]) 325 | d["packet_distance_mean"][step.n].append( 326 | np.mean(step.info["packet_distances"]) 327 | ) 328 | d["packet_distance_max"][step.n].append( 329 | np.max(step.info["packet_distances"]) 330 | ) 331 | for a in range(step.act.shape[0]): 332 | packet_size = step.info["packet_sizes"][a] 333 | action = step.act[a] 334 | d["action_to_packet_size"][action].append(packet_size) 335 | 336 | last_node_state = step.node_state 337 | last_step = step 338 | 339 | # aggregate done stats 340 | n_agents = episode_stats[0].steps[0].act.shape[0] 341 | episode_done_agents = np.zeros((len(episode_stats), n_agents)) 342 | for i in range(len(episode_stats)): 343 | done_agents = list(np.unique(d["episode_done_ids"][i])) 344 | episode_done_agents[i, done_agents] = 1 345 | 346 | d["episode_done_agents"] = episode_done_agents 347 | 348 | def plot_attribute(step_value_dict, ylabel, filename, start=0, end=None): 349 | # automatically try to resolve dict if it is a key 350 | if isinstance(step_value_dict, str): 351 | step_value_dict = d[step_value_dict] 352 | 353 | 
# plot diffs 354 | x = list(step_value_dict.keys()) 355 | x = np.array(sorted(x)) 356 | mean = np.zeros(len(x)) 357 | std = np.zeros(len(x)) 358 | 359 | for i, (step_i, val) in enumerate(step_value_dict.items()): 360 | val_np = np.array(val) 361 | mean[i] = val_np.mean() 362 | std[i] = val_np.std() 363 | 364 | plt.fill_between( 365 | x[start:end], (mean - std)[start:end], (mean + std)[start:end], alpha=0.2 366 | ) 367 | plt.plot(x[start:end], mean[start:end]) 368 | plt.xlabel("Steps") 369 | plt.ylabel(ylabel) 370 | plt.savefig(filename, bbox_inches="tight") 371 | plt.clf() 372 | 373 | plot_attribute("reward", "Mean reward", output_dir / "reward.png") 374 | plot_attribute( 375 | "feature_diffs", 376 | "Mean node state difference", 377 | output_dir / "node_diff.png", 378 | start=1, 379 | ) 380 | plot_attribute( 381 | "feature_diffs_mean", 382 | "Mean node state difference", 383 | output_dir / "node_diff_mean.png", 384 | start=1, 385 | ) 386 | plot_attribute( 387 | "feature_diffs_max", 388 | "Mean node state difference", 389 | output_dir / "node_diff_max.png", 390 | start=1, 391 | ) 392 | plot_attribute("throughput", "Throughput", output_dir / "throughput.png") 393 | plot_attribute("blocked", "Mean blocked", output_dir / "blocked.png") 394 | plot_attribute("total_edge_load", "Total edge load", output_dir / "edge_load.png") 395 | plot_attribute( 396 | "occupied_edges", "Occupied edges", output_dir / "occupied_edges.png" 397 | ) 398 | plot_attribute( 399 | "packets_on_edges", "Packets on edges", output_dir / "packets_on_edges.png" 400 | ) 401 | plot_attribute( 402 | "total_packet_size", "Total packet size", output_dir / "packet_size.png" 403 | ) 404 | plot_attribute( 405 | "delays_arrived_mean", 406 | "Delays of arrived packets", 407 | output_dir / "delays_arrived.png", 408 | ) 409 | plot_attribute("spr_mean", "Mean spr", output_dir / "spr_mean.png") 410 | plot_attribute("spr_min", "Min spr", output_dir / "spr_min.png") 411 | plot_attribute( 412 | 
"packet_distance_mean", "Mean distance", output_dir / "distance_mean.png" 413 | ) 414 | plot_attribute( 415 | "packet_distance_max", "Max distance", output_dir / "distance_max.png" 416 | ) 417 | plot_attribute( 418 | "packet_distance_delta", "Distance delta", output_dir / "distance_delta.png" 419 | ) 420 | plot_attribute("looped", "Looped", output_dir / "looped.png") 421 | plot_attribute("dropped", "Dropped", output_dir / "dropped.png") 422 | 423 | # plot packet size to action..plot 424 | for action in d["action_to_packet_size"]: 425 | plt.hist(d["action_to_packet_size"][action], bins=100) 426 | plt.xlabel(f"Packet sizes for action {action}") 427 | plt.ylabel("Counts") 428 | plt.savefig(output_dir / f"packet_size_act_{action}.png", bbox_inches="tight") 429 | plt.clf() 430 | 431 | def plot_done_hists(x, episode_id): 432 | plt.hist(x, bins=20) 433 | plt.xlabel("Packet id") 434 | plt.ylabel("Done count") 435 | plt.savefig(output_dir / f"ep_{episode_id}_done_hist.png", bbox_inches="tight") 436 | plt.clf() 437 | 438 | if output_detailed: 439 | for i in range(len(episode_stats)): 440 | plot_done_hists(d["episode_done_ids"][i], i) 441 | plot_attribute( 442 | d["episode_throughput"][i], 443 | "Throughput", 444 | output_dir / f"ep_{i}_throughput.png", 445 | ) 446 | 447 | save_distance_map_plot(d["distance_map"], output_dir / "distance_map.png") 448 | 449 | # save selected stats for combined plotting (paper) 450 | with lzma.open(output_dir / "lzma_d.pk", "wb") as f: 451 | pickle.dump( 452 | { 453 | k: d[k] 454 | for k in [ 455 | "feature_diffs_mean", 456 | "feature_diffs_max", 457 | "throughput", 458 | "delays_arrived_mean", 459 | "looped", 460 | "reward", 461 | "dropped", 462 | "episode_done_agents", 463 | ] 464 | }, 465 | f, 466 | ) 467 | 468 | if output_node_state_aux: 469 | all_node_states = np.stack( 470 | [ 471 | np.stack([step.node_state for step in episode.steps]) 472 | for episode in episode_stats 473 | ] 474 | ) 475 | all_node_aux = np.stack( 476 | [ 477 | 
np.stack([step.node_aux for step in episode.steps]) 478 | for episode in episode_stats 479 | ] 480 | ) 481 | 482 | np.savez_compressed( 483 | output_dir / "node_state_aux", 484 | node_state=all_node_states, 485 | node_aux=all_node_aux, 486 | ) 487 | 488 | # plot node state content 489 | # max_episode_steps = len(episode_step_stats[0]) 490 | # node_std_img = np.zeros((max_episode_steps, len(feature_std[0][0]))) 491 | # node_mean_img = np.zeros((max_episode_steps, len(feature_std[0][0]))) 492 | 493 | # for step in feature_std.keys(): 494 | # node_std_img[step] = np.array(feature_std[step]).mean(axis=0) 495 | # node_mean_img[step] = np.array(feature_mean[step]).mean(axis=0) 496 | 497 | # fig, axs = plt.subplots(2, 1, sharex=True) 498 | # axs[0].imshow(node_std_img) 499 | # axs[1].imshow(node_mean_img) 500 | # plt.show() 501 | -------------------------------------------------------------------------------- /src/env/constants.py: -------------------------------------------------------------------------------- 1 | # first 1000 valid topologies generated with default init seed --topology-init-seed=476 2 | # hard-coded to use use the same topologies for testing, irrespective of the train seed 3 | EVAL_SEEDS = [ 4 | 2138774885, 5 | 1802057910, 6 | 1462139390, 7 | 131710925, 8 | 237865846, 9 | 1510037652, 10 | 1595811251, 11 | 1681391277, 12 | 628748993, 13 | 1477609498, 14 | 1538604031, 15 | 718354550, 16 | 1379440459, 17 | 1885302144, 18 | 1604851542, 19 | 971512681, 20 | 1618767442, 21 | 1293700628, 22 | 821234700, 23 | 791456496, 24 | 2108660450, 25 | 1851776804, 26 | 1232514822, 27 | 920431730, 28 | 1492372024, 29 | 1662026582, 30 | 1882845857, 31 | 1744297538, 32 | 611417999, 33 | 334419018, 34 | 1074031085, 35 | 249468174, 36 | 11602613, 37 | 1189619244, 38 | 1136528045, 39 | 1368622316, 40 | 412461771, 41 | 1648823123, 42 | 427360781, 43 | 723790196, 44 | 1862329693, 45 | 1448058228, 46 | 1904112242, 47 | 385714911, 48 | 1234292532, 49 | 648971536, 50 | 648649592, 51 
| 947597750, 52 | 1791116744, 53 | 861798368, 54 | 1485520444, 55 | 892667998, 56 | 1806121462, 57 | 57154220, 58 | 1462323865, 59 | 209658998, 60 | 1138985440, 61 | 940158033, 62 | 1904761510, 63 | 405011735, 64 | 2059638010, 65 | 1994823182, 66 | 1576394177, 67 | 93520195, 68 | 583511077, 69 | 1818584892, 70 | 1246972224, 71 | 292610438, 72 | 2135753609, 73 | 466839424, 74 | 46190961, 75 | 1211008616, 76 | 659380028, 77 | 1966292323, 78 | 1724122341, 79 | 476988447, 80 | 506251940, 81 | 1336831022, 82 | 188216578, 83 | 727018480, 84 | 1209638696, 85 | 1936149392, 86 | 375136146, 87 | 836751739, 88 | 1215106835, 89 | 1564873376, 90 | 565658269, 91 | 1022265636, 92 | 694748128, 93 | 86252569, 94 | 791124196, 95 | 63934777, 96 | 1542991501, 97 | 1869290369, 98 | 816409505, 99 | 1393131267, 100 | 94259669, 101 | 1633257115, 102 | 1456794106, 103 | 19399640, 104 | 718852713, 105 | 213009216, 106 | 1142123422, 107 | 48847165, 108 | 1039613902, 109 | 136298308, 110 | 500636818, 111 | 752007200, 112 | 1695791055, 113 | 1624646434, 114 | 105138361, 115 | 1875810224, 116 | 1314663030, 117 | 1138146363, 118 | 301732049, 119 | 369866175, 120 | 326826959, 121 | 423448415, 122 | 1981043339, 123 | 354475974, 124 | 1716016325, 125 | 439283282, 126 | 1938159865, 127 | 2071235559, 128 | 1462437244, 129 | 1334542488, 130 | 581992610, 131 | 92670783, 132 | 889016083, 133 | 653502062, 134 | 398165042, 135 | 1867641698, 136 | 1301530639, 137 | 1841017512, 138 | 412534520, 139 | 1319263700, 140 | 1014471347, 141 | 141021217, 142 | 872667392, 143 | 207702942, 144 | 1216137449, 145 | 2099051087, 146 | 662573662, 147 | 1560922990, 148 | 78415505, 149 | 1383599352, 150 | 1086719028, 151 | 1723110152, 152 | 1504708566, 153 | 854091357, 154 | 1074092318, 155 | 427413928, 156 | 1011358926, 157 | 968368716, 158 | 1387107320, 159 | 2117428220, 160 | 181637955, 161 | 608422009, 162 | 1093464142, 163 | 167446689, 164 | 1119937975, 165 | 1695128826, 166 | 1388812304, 167 | 586930137, 168 | 
382430086, 169 | 324821133, 170 | 1385664729, 171 | 166633377, 172 | 161357245, 173 | 1695790171, 174 | 909165233, 175 | 1470316681, 176 | 763149588, 177 | 19973759, 178 | 1006277766, 179 | 1269937330, 180 | 1199272002, 181 | 676908685, 182 | 1712403175, 183 | 1668335234, 184 | 975779314, 185 | 1578863569, 186 | 1799118892, 187 | 494596527, 188 | 1384271815, 189 | 1998529640, 190 | 255450474, 191 | 1180730322, 192 | 1167098504, 193 | 1686111784, 194 | 1988885304, 195 | 2000211921, 196 | 648897271, 197 | 568860464, 198 | 163721982, 199 | 1931058049, 200 | 1410064767, 201 | 2104761057, 202 | 892186498, 203 | 1811082772, 204 | 1515679945, 205 | 1719061896, 206 | 511563846, 207 | 2066618332, 208 | 1033716940, 209 | 1885255953, 210 | 1198574103, 211 | 292761050, 212 | 489960213, 213 | 503275656, 214 | 1636046062, 215 | 823631881, 216 | 1540375005, 217 | 1350854024, 218 | 125978836, 219 | 221985855, 220 | 96457523, 221 | 463963718, 222 | 1418829212, 223 | 1894702209, 224 | 1845256749, 225 | 2052240772, 226 | 1740546692, 227 | 698342137, 228 | 1526212900, 229 | 529507658, 230 | 849719367, 231 | 1452481881, 232 | 130094862, 233 | 177684417, 234 | 1262489981, 235 | 1605113823, 236 | 581803023, 237 | 1897418457, 238 | 573975554, 239 | 1692639319, 240 | 443915493, 241 | 885540673, 242 | 1163559751, 243 | 956143257, 244 | 828381560, 245 | 901483007, 246 | 1192253857, 247 | 1449959669, 248 | 1897822574, 249 | 1336067652, 250 | 2122393396, 251 | 179918628, 252 | 1902924948, 253 | 391115389, 254 | 220724228, 255 | 1445828275, 256 | 620167002, 257 | 1618995961, 258 | 705560869, 259 | 429416190, 260 | 926147294, 261 | 753909044, 262 | 1602738457, 263 | 1338919001, 264 | 1403848543, 265 | 1598963603, 266 | 634823219, 267 | 526813974, 268 | 1556746757, 269 | 1834164754, 270 | 1345420349, 271 | 1749308964, 272 | 2022322261, 273 | 2086958891, 274 | 1215978306, 275 | 233713781, 276 | 1877251282, 277 | 1686318743, 278 | 552445767, 279 | 214110606, 280 | 77123318, 281 | 1561531498, 282 | 
101328923, 283 | 1352730451, 284 | 1367602223, 285 | 913251015, 286 | 1220723061, 287 | 900985363, 288 | 1835507367, 289 | 1853516888, 290 | 1438632345, 291 | 781252138, 292 | 527389814, 293 | 1087768793, 294 | 1080153425, 295 | 828755564, 296 | 571792340, 297 | 1144521817, 298 | 1207042762, 299 | 399166300, 300 | 1042532860, 301 | 552229682, 302 | 1412782342, 303 | 403827280, 304 | 662128546, 305 | 1853684256, 306 | 1294002298, 307 | 2041566694, 308 | 1659238606, 309 | 1277654792, 310 | 1730744469, 311 | 1331016792, 312 | 293575933, 313 | 2135948498, 314 | 1434468563, 315 | 459768125, 316 | 903596263, 317 | 1715889147, 318 | 69254672, 319 | 360869478, 320 | 1056900256, 321 | 3744665, 322 | 1046953748, 323 | 1897842360, 324 | 1415319170, 325 | 1983091349, 326 | 1930485231, 327 | 822902683, 328 | 1610019577, 329 | 1923270533, 330 | 1188753409, 331 | 284076257, 332 | 663833375, 333 | 947754233, 334 | 1217864999, 335 | 1675970665, 336 | 106432033, 337 | 769515486, 338 | 1513148572, 339 | 1226543221, 340 | 2010814751, 341 | 1125284535, 342 | 1959005469, 343 | 1349729838, 344 | 714740639, 345 | 775153467, 346 | 1457501836, 347 | 527400640, 348 | 637444002, 349 | 1338703567, 350 | 135353153, 351 | 204603111, 352 | 2097120693, 353 | 1874731722, 354 | 923430603, 355 | 876779933, 356 | 285657543, 357 | 761716517, 358 | 224903117, 359 | 2146831603, 360 | 183237257, 361 | 1477070528, 362 | 1089079416, 363 | 42262289, 364 | 2139773943, 365 | 1652861119, 366 | 1194380101, 367 | 1008947702, 368 | 1150666119, 369 | 1446603103, 370 | 384404884, 371 | 1400678718, 372 | 910189803, 373 | 2096114186, 374 | 16963923, 375 | 105356225, 376 | 823341795, 377 | 1846530174, 378 | 1433019784, 379 | 1651185611, 380 | 1244351578, 381 | 2008145494, 382 | 149013009, 383 | 1143473050, 384 | 1125452617, 385 | 1840801401, 386 | 2137760171, 387 | 1780561802, 388 | 906565827, 389 | 74271094, 390 | 353265263, 391 | 99987523, 392 | 592475239, 393 | 1358104198, 394 | 1186073325, 395 | 1151209574, 396 | 
2051337283, 397 | 297251918, 398 | 108288610, 399 | 944328881, 400 | 199828117, 401 | 1980810309, 402 | 1828361517, 403 | 174451486, 404 | 140594438, 405 | 729406367, 406 | 1989023760, 407 | 221573786, 408 | 64818540, 409 | 1108150877, 410 | 916916781, 411 | 1252292516, 412 | 1850849570, 413 | 740790881, 414 | 1450434155, 415 | 32669285, 416 | 1431969111, 417 | 1795490043, 418 | 736731214, 419 | 444629195, 420 | 1623489826, 421 | 1178378640, 422 | 1043230578, 423 | 813941745, 424 | 1449026504, 425 | 11681283, 426 | 9726029, 427 | 1278841639, 428 | 377718175, 429 | 1129419052, 430 | 1612166837, 431 | 1759018360, 432 | 1158068648, 433 | 1743279196, 434 | 1733614998, 435 | 578961716, 436 | 1511907808, 437 | 1632916070, 438 | 683832574, 439 | 1015836548, 440 | 786227169, 441 | 189768713, 442 | 814554905, 443 | 1183700103, 444 | 954300546, 445 | 1039475414, 446 | 1656837350, 447 | 395668095, 448 | 1886321972, 449 | 638146272, 450 | 1488340776, 451 | 1251279135, 452 | 520739682, 453 | 1571641922, 454 | 82222799, 455 | 1777905110, 456 | 1271626760, 457 | 1859747931, 458 | 105784940, 459 | 398172294, 460 | 1557263386, 461 | 1868462267, 462 | 846010559, 463 | 1420178983, 464 | 600910112, 465 | 865749537, 466 | 670525289, 467 | 1190307107, 468 | 930227587, 469 | 1731233449, 470 | 2051900839, 471 | 251683625, 472 | 555596954, 473 | 701702490, 474 | 1811195221, 475 | 394999752, 476 | 1954485675, 477 | 563793661, 478 | 245291935, 479 | 1118327033, 480 | 271487261, 481 | 1026049176, 482 | 846408041, 483 | 2011105496, 484 | 394938489, 485 | 1126324030, 486 | 971182936, 487 | 127888765, 488 | 388695932, 489 | 192985197, 490 | 389611354, 491 | 1023976525, 492 | 958091795, 493 | 1743792798, 494 | 411086329, 495 | 2129027809, 496 | 1130043654, 497 | 89954270, 498 | 1068566443, 499 | 2091938526, 500 | 1683299080, 501 | 1107532284, 502 | 2081728602, 503 | 1364523218, 504 | 1728193210, 505 | 730137799, 506 | 1746476844, 507 | 288791655, 508 | 594236773, 509 | 1169295767, 510 | 
1568792599, 511 | 1305568797, 512 | 1177350554, 513 | 335491771, 514 | 221040126, 515 | 1472704927, 516 | 279588594, 517 | 1681709554, 518 | 1791610268, 519 | 1204905813, 520 | 110281290, 521 | 986157179, 522 | 1470977629, 523 | 1191759237, 524 | 743243638, 525 | 105757200, 526 | 1228065424, 527 | 305720634, 528 | 615955509, 529 | 1835185255, 530 | 1202205611, 531 | 426863263, 532 | 1120623605, 533 | 862934033, 534 | 7170846, 535 | 452080836, 536 | 1931901662, 537 | 425240001, 538 | 615815805, 539 | 549494669, 540 | 1366461951, 541 | 286355409, 542 | 1706633260, 543 | 1559919332, 544 | 1014841846, 545 | 2041520273, 546 | 1833201034, 547 | 1403672434, 548 | 1765687069, 549 | 681569355, 550 | 1176639776, 551 | 350186779, 552 | 264252844, 553 | 1437553485, 554 | 1398715485, 555 | 888342058, 556 | 557249375, 557 | 1043860618, 558 | 877069403, 559 | 1490753370, 560 | 647989259, 561 | 134427484, 562 | 1631185040, 563 | 974425967, 564 | 1611800148, 565 | 1120106428, 566 | 1411978113, 567 | 1048162263, 568 | 1843367844, 569 | 1713959785, 570 | 576915890, 571 | 1727285072, 572 | 1303384685, 573 | 2024873530, 574 | 2034684012, 575 | 957902988, 576 | 457427344, 577 | 309587609, 578 | 1312033428, 579 | 566422202, 580 | 98473376, 581 | 754317156, 582 | 608432207, 583 | 1195417297, 584 | 738730116, 585 | 1317510567, 586 | 1171974841, 587 | 2058556219, 588 | 1654032124, 589 | 771131839, 590 | 1902772945, 591 | 1434911098, 592 | 1418180630, 593 | 149325386, 594 | 681024359, 595 | 1653976383, 596 | 329347505, 597 | 38943316, 598 | 825616265, 599 | 117368963, 600 | 273254898, 601 | 999395139, 602 | 358850125, 603 | 15674333, 604 | 193964982, 605 | 1845080990, 606 | 1201118350, 607 | 139608511, 608 | 1616940231, 609 | 2022768652, 610 | 650502891, 611 | 1563359102, 612 | 959212916, 613 | 1089867813, 614 | 1083313017, 615 | 8061318, 616 | 1103489726, 617 | 417283080, 618 | 891223922, 619 | 1409818467, 620 | 1705778279, 621 | 1726638967, 622 | 976102401, 623 | 1233118449, 624 | 
1168769893, 625 | 917263497, 626 | 714149218, 627 | 606974545, 628 | 1145391863, 629 | 1251606536, 630 | 373519147, 631 | 1970532272, 632 | 1811528956, 633 | 587494270, 634 | 1484505001, 635 | 955751667, 636 | 2027414786, 637 | 925174548, 638 | 544213819, 639 | 451468197, 640 | 2106757469, 641 | 871331914, 642 | 744179519, 643 | 51707587, 644 | 1098124866, 645 | 1012157917, 646 | 192580, 647 | 14283732, 648 | 2141314747, 649 | 835950990, 650 | 590919807, 651 | 2027358987, 652 | 1051916, 653 | 1385430063, 654 | 1107937578, 655 | 1641625182, 656 | 1226466379, 657 | 676969206, 658 | 578327269, 659 | 568837333, 660 | 441641058, 661 | 415185069, 662 | 974957341, 663 | 604900356, 664 | 1433672383, 665 | 1090597699, 666 | 1364824196, 667 | 1540496691, 668 | 1456988889, 669 | 660340140, 670 | 1853826053, 671 | 1661911535, 672 | 1049064198, 673 | 1692975715, 674 | 1403813457, 675 | 1468642930, 676 | 1937989989, 677 | 1735370913, 678 | 983859063, 679 | 1751385295, 680 | 1429573017, 681 | 1604777388, 682 | 1663717210, 683 | 1402965810, 684 | 436162440, 685 | 1010127374, 686 | 1202077659, 687 | 1862294692, 688 | 820660045, 689 | 1627864947, 690 | 756357701, 691 | 1322277675, 692 | 1211800114, 693 | 476088948, 694 | 1388491636, 695 | 1950830918, 696 | 1294484815, 697 | 1351319481, 698 | 458840024, 699 | 1640507928, 700 | 32483790, 701 | 448182511, 702 | 1630429721, 703 | 488092091, 704 | 406859765, 705 | 835470228, 706 | 348227548, 707 | 1661043492, 708 | 53098947, 709 | 1541011976, 710 | 2054285665, 711 | 1228645928, 712 | 861881046, 713 | 1360147907, 714 | 1704443687, 715 | 1195435328, 716 | 1620809389, 717 | 310746026, 718 | 1741158115, 719 | 12208007, 720 | 585920303, 721 | 1945780841, 722 | 1589433480, 723 | 1155078380, 724 | 363163445, 725 | 2058715823, 726 | 1728882296, 727 | 2055377873, 728 | 283989607, 729 | 1854288957, 730 | 165696449, 731 | 851506015, 732 | 1856749041, 733 | 2144772409, 734 | 907542948, 735 | 1762927083, 736 | 1136691439, 737 | 1214922984, 738 | 
779852122, 739 | 1347122462, 740 | 1801097519, 741 | 1905994686, 742 | 1369078864, 743 | 1210512976, 744 | 1354916752, 745 | 806332442, 746 | 1308534954, 747 | 1022616967, 748 | 869606923, 749 | 100888543, 750 | 895625910, 751 | 1706266011, 752 | 1585519011, 753 | 1366947010, 754 | 648950566, 755 | 809644768, 756 | 485956962, 757 | 950349108, 758 | 1316500400, 759 | 1470150067, 760 | 402644912, 761 | 1623005697, 762 | 1547390765, 763 | 1085941515, 764 | 1129941196, 765 | 1099670549, 766 | 1987059390, 767 | 401794238, 768 | 1916912981, 769 | 1416029460, 770 | 913402330, 771 | 1734524721, 772 | 993641815, 773 | 1928856243, 774 | 1406765862, 775 | 2081039609, 776 | 1327275431, 777 | 64362140, 778 | 1867847352, 779 | 131697234, 780 | 1395517972, 781 | 736952332, 782 | 1025700916, 783 | 1015908867, 784 | 1439564941, 785 | 1316636059, 786 | 1611952240, 787 | 1142690317, 788 | 1108464280, 789 | 271940206, 790 | 1588444509, 791 | 925042325, 792 | 1599902262, 793 | 1281050110, 794 | 683390094, 795 | 885103195, 796 | 1677537770, 797 | 2052829077, 798 | 1929106273, 799 | 379127053, 800 | 367598791, 801 | 665498540, 802 | 214491634, 803 | 1986928393, 804 | 555116323, 805 | 1023167041, 806 | 677848000, 807 | 1247511220, 808 | 1611127783, 809 | 19264764, 810 | 1407905500, 811 | 764334122, 812 | 1184123539, 813 | 1926232975, 814 | 2088017510, 815 | 578654454, 816 | 1781791319, 817 | 206395498, 818 | 63044953, 819 | 771559150, 820 | 253209699, 821 | 893311664, 822 | 1626818173, 823 | 626122574, 824 | 784581472, 825 | 977669669, 826 | 1970407004, 827 | 961694139, 828 | 1520493529, 829 | 625158633, 830 | 1705021615, 831 | 1758255878, 832 | 2114111775, 833 | 1360616949, 834 | 1561852852, 835 | 527985858, 836 | 1091123955, 837 | 1752590106, 838 | 1862765330, 839 | 834940058, 840 | 1418099500, 841 | 1950064467, 842 | 533508030, 843 | 291045681, 844 | 902825468, 845 | 904536212, 846 | 1134332670, 847 | 246266101, 848 | 895510008, 849 | 185414082, 850 | 762381885, 851 | 1051886067, 852 | 
1757330813, 853 | 854286245, 854 | 629525812, 855 | 757571707, 856 | 362160319, 857 | 710263281, 858 | 1891685362, 859 | 1395229907, 860 | 774376908, 861 | 1734247772, 862 | 1281691351, 863 | 185075766, 864 | 2131885097, 865 | 2141190380, 866 | 603283737, 867 | 771362311, 868 | 92707063, 869 | 519732350, 870 | 1399159911, 871 | 2013357495, 872 | 1053352668, 873 | 1517263587, 874 | 1541870372, 875 | 1321648140, 876 | 1297995383, 877 | 1470710588, 878 | 2112731097, 879 | 1001597594, 880 | 2104305443, 881 | 912912438, 882 | 754366614, 883 | 799003120, 884 | 2009117290, 885 | 1279091208, 886 | 504076768, 887 | 1190536177, 888 | 179229856, 889 | 547065360, 890 | 1524604687, 891 | 314142659, 892 | 222124109, 893 | 1683150190, 894 | 231652147, 895 | 480036226, 896 | 350967455, 897 | 1642927560, 898 | 831776473, 899 | 1217156517, 900 | 43335056, 901 | 1469900474, 902 | 345323111, 903 | 1377724525, 904 | 597286999, 905 | 49196902, 906 | 119098777, 907 | 2087953957, 908 | 1962315979, 909 | 829552729, 910 | 1480162203, 911 | 710805557, 912 | 1378510623, 913 | 695284844, 914 | 2107262039, 915 | 2071687286, 916 | 832604210, 917 | 1470066089, 918 | 323461140, 919 | 476542370, 920 | 717098465, 921 | 679161854, 922 | 633965999, 923 | 600195364, 924 | 1251834271, 925 | 1760600377, 926 | 1110465555, 927 | 695448001, 928 | 1812204416, 929 | 1801071242, 930 | 1721240357, 931 | 609746187, 932 | 1796517282, 933 | 554528868, 934 | 1312733158, 935 | 1328412336, 936 | 1355507490, 937 | 175287066, 938 | 24274813, 939 | 160295734, 940 | 234749775, 941 | 2059489899, 942 | 1120236907, 943 | 16460123, 944 | 720658877, 945 | 1827949995, 946 | 814332939, 947 | 1702307144, 948 | 376952204, 949 | 1768463330, 950 | 216310342, 951 | 2129081932, 952 | 604799341, 953 | 1275225950, 954 | 1628301099, 955 | 89795510, 956 | 1182015419, 957 | 1268390512, 958 | 111199042, 959 | 1159878970, 960 | 1103075838, 961 | 415415220, 962 | 1149618041, 963 | 292847407, 964 | 820730331, 965 | 410797894, 966 | 424741631, 
967 | 1646881796, 968 | 959334364, 969 | 209385518, 970 | 1813831521, 971 | 741258471, 972 | 476660187, 973 | 562398066, 974 | 2126364141, 975 | 820854751, 976 | 398390169, 977 | 1065454627, 978 | 1149092077, 979 | 2024255214, 980 | 1829703143, 981 | 528142739, 982 | 1650632136, 983 | 531412320, 984 | 1792057330, 985 | 104206149, 986 | 1904754626, 987 | 189835977, 988 | 965348213, 989 | 1613434394, 990 | 337780178, 991 | 1588814767, 992 | 2115572448, 993 | 161027100, 994 | 275648867, 995 | 1686961877, 996 | 474346010, 997 | 193991935, 998 | 88089057, 999 | 1458036039, 1000 | 1225692874, 1001 | 1081079058, 1002 | 1596647484, 1003 | 1928239280, 1004 | ] 1005 | -------------------------------------------------------------------------------- /src/env/routing.py: -------------------------------------------------------------------------------- 1 | import textwrap 2 | import numpy as np 3 | from collections import defaultdict 4 | 5 | from env.environment import EnvironmentVariant, NetworkEnv 6 | from gymnasium.spaces import Discrete 7 | 8 | from env.network import Network 9 | from util import one_hot_list 10 | 11 | 12 | class Data: 13 | """ 14 | A data packet. 15 | """ 16 | 17 | def __init__(self, id): 18 | self.id = id 19 | self.now = None 20 | self.target = None 21 | self.size = None 22 | self.start = None 23 | self.time = 0 24 | self.edge = -1 25 | self.neigh = None 26 | self.ttl = None 27 | self.shortest_path_weight = None 28 | self.visited_nodes = None 29 | 30 | def reset(self, start, target, size, ttl, shortest_path_weight): 31 | self.now = start 32 | self.target = target 33 | self.size = size 34 | self.start = start 35 | self.time = 0 36 | self.edge = -1 37 | self.neigh = [self.id] 38 | self.ttl = ttl 39 | self.shortest_path_weight = shortest_path_weight 40 | self.visited_nodes = set([start]) 41 | 42 | 43 | class Routing(NetworkEnv): 44 | """ " 45 | Routing environment based on the environment by 46 | Jiang et al. 
https://github.com/PKU-RL/DGN/blob/master/Routing/routers.py 47 | used for their DGN paper https://arxiv.org/abs/1810.09202. 48 | 49 | The task is to route packets from random source to random destination nodes in a 50 | given network. Each agent controls a single packet. When a packet reaches its 51 | destination, a new packet is instantly created at a random location with a new 52 | random target. 53 | """ 54 | 55 | def __init__( 56 | self, 57 | network: Network, 58 | n_data, 59 | env_var: EnvironmentVariant, 60 | k=3, 61 | enable_congestion=True, 62 | enable_action_mask=False, 63 | ttl=0, 64 | ): 65 | """ 66 | Initialize the environment. 67 | 68 | :param network: a network 69 | :param n_data: the number of data packets 70 | :param env_var: the environment variant 71 | :param k: include k neighbors in local observation (only for environment variant WITH_K_NEIGHBORS), defaults to 3 72 | :param enable_congestion: whether to respect link capacities, defaults to True 73 | :param enable_action_mask: whether to generate an action mask for agents that does not allow visiting nodes twice, defaults to False 74 | :param ttl: time to live before packets are discarded, defaults to 0 75 | """ 76 | super(Routing, self).__init__() 77 | 78 | self.network = network 79 | assert isinstance(self.network, Network) 80 | 81 | self.n_data = n_data 82 | self.data = [] 83 | 84 | # make sure env_var is casted 85 | self.env_var = EnvironmentVariant(env_var) 86 | 87 | # optionally include k neighbors in local observation 88 | self.k = k 89 | 90 | # log information 91 | self.agent_steps = np.zeros(self.n_data) 92 | 93 | # whether to use random targets or target == 0 for all packets 94 | self.num_random_targets = self.network.n_nodes 95 | assert self.num_random_targets >= 0 96 | 97 | # map from shortest path to actual agent steps 98 | self.distance_map = defaultdict(list) 99 | self.enable_ttl = ttl > 0 100 | self.enable_congestion = enable_congestion 101 | self.ttl = ttl 102 | 
self.sum_packets_per_node = None 103 | self.sum_packets_per_edge = None 104 | 105 | self.enable_action_mask = enable_action_mask 106 | self.action_mask = np.zeros((n_data, 4), dtype=bool) 107 | 108 | self.action_space = Discrete(4, start=0) # {0, 1, 2, 3} using gym action space 109 | self.eval_info_enabled = False 110 | 111 | def set_eval_info(self, val): 112 | """ 113 | Whether the step function should return additional info for evaluation. 114 | 115 | :param val: the step function returns additional info if true 116 | """ 117 | self.eval_info_enabled = val 118 | 119 | def reset_packet(self, packet: Data): 120 | """ 121 | Resets the given data packet using the settings of this environment. 122 | 123 | :param packet: a data packet that will be reset *in-place* 124 | """ 125 | # free resources on used edge 126 | if packet.edge != -1: 127 | self.network.edges[packet.edge].load -= packet.size 128 | 129 | # reset packet in place 130 | start = np.random.randint(self.network.n_nodes) 131 | target = np.random.randint(self.num_random_targets) 132 | packet.reset( 133 | start=start, 134 | target=target, 135 | size=np.random.random(), 136 | ttl=self.ttl, 137 | shortest_path_weight=self.network.shortest_paths_weights[start][target], 138 | ) 139 | 140 | if self.enable_action_mask: 141 | # all links are allowed 142 | self.action_mask[packet.id] = 0 143 | # idling is allowed if a packet spawns at the destination 144 | self.action_mask[packet.id, 0] = packet.now != packet.target 145 | 146 | def __str__(self) -> str: 147 | return textwrap.dedent( 148 | f"""\ 149 | Routing environment with parameters 150 | > Network: {self.network.n_nodes} nodes 151 | > Number of packets: {self.n_data} 152 | > Environment variant: {self.env_var.name} 153 | > Number of considered neighbors (k): {self.k if self.env_var == EnvironmentVariant.WITH_K_NEIGHBORS else "disabled"} 154 | > Congestion: {self.enable_congestion} 155 | > Action mask: {self.enable_action_mask} 156 | > TTL: {self.ttl if 
self.enable_ttl else "disabled"}\ 157 | """ 158 | ) 159 | 160 | def reset(self): 161 | self.agent_steps = np.zeros(self.n_data) 162 | self.network.reset() 163 | for edge in self.network.edges: 164 | # add new load attribute to edges 165 | edge.load = 0 166 | 167 | if self.eval_info_enabled: 168 | self.sum_packets_per_node = np.zeros(self.network.n_nodes) 169 | self.sum_packets_per_edge = np.zeros(len(self.network.edges)) 170 | 171 | # generate random data packets 172 | self.data = [] 173 | for i in range(self.n_data): 174 | new_data = Data(i) 175 | self.reset_packet(new_data) 176 | self.data.append(new_data) 177 | 178 | return self._get_observation(), self._get_data_adjacency() 179 | 180 | def render(self): 181 | # TODO: also render packets 182 | self.network.render() 183 | 184 | def get_nodes_adjacency(self): 185 | return self.network.adj_matrix 186 | 187 | def get_node_observation(self): 188 | """ 189 | Get the node observation for each node in the network. 190 | 191 | :return: node observations of shape (num_nodes, node_observation_size) 192 | """ 193 | obs = [] 194 | for j in range(self.network.n_nodes): 195 | ob = [] 196 | 197 | # router info 198 | # ob.append(j) 199 | ob += one_hot_list(j, self.network.n_nodes) 200 | num_packets = 0 201 | total_load = 0 202 | for i in range(self.n_data): 203 | if self.data[i].now == j and self.data[i].edge == -1: 204 | num_packets += 1 205 | total_load += self.data[i].size 206 | 207 | # for dest in range(self.n_router): 208 | # ob.append(self.shortest_paths_weights[j][dest]) 209 | 210 | ob.append(num_packets) 211 | ob.append(total_load) 212 | 213 | # #position obs 214 | # ob.append(self.router[j].y) 215 | # ob.append(self.router[j].x) 216 | 217 | # my_path_to_zero = self.shortest_paths[j][0] 218 | # next_node = my_path_to_zero[1] if len(my_path_to_zero) > 1 else -1 219 | 220 | # edge info 221 | for k in self.network.nodes[j].edges: 222 | other_node = self.network.edges[k].get_other_node(j) 223 | # ob.append(other_node) 224 | 
ob += one_hot_list(other_node, self.network.n_nodes) 225 | ob.append(self.network.edges[k].length) 226 | ob.append(self.network.edges[k].load) 227 | 228 | # cheating: add observation that tells the node how to get to 0 229 | # if self.edges[k].get_other_node(j) == next_node: 230 | # ob.append(1) 231 | # else: 232 | # ob.append(0) 233 | 234 | obs.append(ob) 235 | return np.array(obs, dtype=np.float32) 236 | 237 | def get_node_aux(self): 238 | """ 239 | Auxiliary targets for each node in the network. 240 | 241 | :return: Auxiliary targets of shape (num_nodes, node_aux_target_size) 242 | """ 243 | aux = [] 244 | for j in range(self.network.n_nodes): 245 | aux_j = [] 246 | 247 | # for routing, it is essential for a node to estimate the distance to 248 | # other nodes -> auxiliary target is length of shortest paths to all nodes 249 | for k in range(self.network.n_nodes): 250 | aux_j.append(self.network.shortest_paths_weights[j][k]) 251 | 252 | aux.append(aux_j) 253 | 254 | return np.array(aux, dtype=np.float32) 255 | 256 | def get_node_agent_matrix(self): 257 | """ 258 | Gets a matrix that indicates where agents are located, 259 | matrix[n, a] = 1 iff agent a is on node n and 0 otherwise. 
260 | 261 | :return: the node agent matrix of shape (n_nodes, n_agents) 262 | """ 263 | node_agent = np.zeros((self.network.n_nodes, self.n_data), dtype=np.int8) 264 | for a in range(self.n_data): 265 | node_agent[self.data[a].now, a] = 1 266 | 267 | return node_agent 268 | 269 | def _get_observation(self): 270 | obs = [] 271 | if self.env_var == EnvironmentVariant.GLOBAL: 272 | # for the global observation 273 | nodes_adjacency = self.get_nodes_adjacency().flatten() 274 | node_observation = self.get_node_observation().flatten() 275 | global_obs = np.concatenate((nodes_adjacency, node_observation)) 276 | 277 | for i in range(self.n_data): 278 | ob = [] 279 | # packet information 280 | # ob.append(self.data[i].now) 281 | ob += one_hot_list(self.data[i].now, self.network.n_nodes) 282 | # ob.append(self.data[i].target) 283 | ob += one_hot_list(self.data[i].target, self.network.n_nodes) 284 | 285 | # packets should know where they are coming from when traveling on an edge 286 | ob.append(int(self.data[i].edge != -1)) 287 | if self.data[i].edge != -1: 288 | other_node = self.network.edges[self.data[i].edge].get_other_node( 289 | self.data[i].now 290 | ) 291 | else: 292 | other_node = -1 293 | ob += one_hot_list(other_node, self.network.n_nodes) 294 | 295 | ob.append(self.data[i].time) 296 | ob.append(self.data[i].size) 297 | ob.append(self.data[i].id) 298 | 299 | # edge information 300 | for j in self.network.nodes[self.data[i].now].edges: 301 | other_node = self.network.edges[j].get_other_node(self.data[i].now) 302 | # ob.append(other_node) 303 | ob += one_hot_list(other_node, self.network.n_nodes) 304 | ob.append(self.network.edges[j].length) 305 | ob.append(self.network.edges[j].load) 306 | 307 | # ob.append(self.shortest_paths_weights[other_node][self.data[i].target]) 308 | # for dest in range(self.n_router): 309 | # ob.append(dest == self.data[i].target) 310 | # ob.append( 311 | # 1 312 | # * (dest == self.data[i].target) 313 | # * 
self.shortest_paths_weights[other_node][dest] 314 | # ) 315 | 316 | # other data 317 | count = 0 318 | self.data[i].neigh = [] 319 | self.data[i].neigh.append(i) 320 | for j in range(self.n_data): 321 | if j == i: 322 | continue 323 | if ( 324 | self.data[j].now in self.network.nodes[self.data[i].now].neighbors 325 | ) | (self.data[j].now == self.data[i].now): 326 | self.data[i].neigh.append(j) 327 | 328 | # with neighbor information in observation (until k neighbors) 329 | if ( 330 | self.env_var == EnvironmentVariant.WITH_K_NEIGHBORS 331 | and count < self.k 332 | ): 333 | count += 1 334 | ob.append(self.data[j].now) 335 | ob.append(self.data[j].target) 336 | ob.append(self.data[j].edge) 337 | ob.append(self.data[j].size) 338 | ob.append(self.data[i].id) 339 | 340 | if self.env_var == EnvironmentVariant.WITH_K_NEIGHBORS: 341 | for j in range(self.k - count): 342 | for _ in range(5): 343 | ob.append(-1) # invalid placeholder 344 | 345 | # for j in range(self.n_router): 346 | # # cooridnates info 347 | # ob.append(self.router[j].y) 348 | # ob.append(self.router[j].x) 349 | 350 | ob_numpy = np.array(ob) 351 | 352 | # add global information 353 | if self.env_var == EnvironmentVariant.GLOBAL: 354 | ob_numpy = np.concatenate((ob_numpy, global_obs)) 355 | 356 | obs.append(ob_numpy) 357 | 358 | return np.array(obs, dtype=np.float32) 359 | 360 | def step(self, act): 361 | reward = np.zeros(self.n_data, dtype=np.float32) 362 | looped = np.zeros(self.n_data, dtype=np.float32) 363 | done = np.zeros(self.n_data, dtype=bool) 364 | drop_packet = np.zeros(self.n_data, dtype=bool) 365 | success = np.zeros(self.n_data, dtype=bool) 366 | blocked = 0 367 | 368 | delays = [] 369 | delays_arrived = [] 370 | spr = [] 371 | self.agent_steps += 1 372 | 373 | # optionally shuffle packet order so that lower packet ids 374 | # are not prioritized anymore 375 | # random_packet_order = np.arange(self.n_data) 376 | # np.random.shuffle(random_packet_order) 377 | 378 | # handle actions 379 | # 
for i in random_packet_order: 380 | for i in range(self.n_data): 381 | # agent i controls data packet i 382 | packet = self.data[i] 383 | 384 | if self.eval_info_enabled: 385 | if packet.edge == -1: 386 | self.sum_packets_per_node[packet.now] += 1 387 | 388 | # select outgoing edge (act == 0 is idle) 389 | if packet.edge == -1 and act[i] != 0: 390 | t = self.network.nodes[packet.now].edges[act[i] - 1] 391 | # note that packets that are handled earlier in this loop 392 | # (i.e. with lower ids) are prioritized here. 393 | if ( 394 | self.enable_congestion 395 | and self.network.edges[t].load + packet.size > 1 396 | ): 397 | # not possible to take this edge => collision 398 | reward[i] -= 0.2 399 | blocked += 1 400 | else: 401 | # take this edge 402 | packet.edge = t 403 | packet.time = self.network.edges[t].length 404 | # assign load to the selected edge 405 | self.network.edges[t].load += packet.size 406 | 407 | # already set the next position 408 | packet.now = self.network.edges[t].get_other_node(packet.now) 409 | if packet.now in packet.visited_nodes: 410 | looped[i] = 1 411 | else: 412 | packet.visited_nodes.add(packet.now) 413 | 414 | if self.eval_info_enabled: 415 | total_edge_load = 0 416 | occupied_edges = 0 417 | packets_on_edges = 0 418 | total_packet_size = 0 419 | packet_sizes = [] 420 | 421 | for edge in self.network.edges: 422 | total_edge_load += edge.load 423 | if edge.load > 0: 424 | occupied_edges += 1 425 | 426 | for i in range(self.n_data): 427 | packet = self.data[i] 428 | if packet.edge != -1: 429 | self.sum_packets_per_edge[packet.edge] += 1 430 | 431 | total_packet_size += packet.size 432 | packet_sizes.append(self.data[i].size) 433 | if packet.edge != -1: 434 | packets_on_edges += 1 435 | 436 | packet_distances = list( 437 | map( 438 | lambda p: self.network.shortest_paths_weights[p.now][p.target], 439 | self.data, 440 | ) 441 | ) 442 | 443 | # then simulate in-flight packets (=> effect of actions) 444 | for i in range(self.n_data): 445 | 
packet = self.data[i] 446 | packet.ttl -= 1 447 | 448 | if packet.edge != -1: 449 | packet.time -= 1 450 | # the packet arrived at the destination, reduce load from edge 451 | if packet.time <= 0: 452 | self.network.edges[packet.edge].load -= packet.size 453 | packet.edge = -1 454 | 455 | drop_packet[i] = drop_packet[i] or (self.enable_ttl and packet.ttl <= 0) 456 | if self.enable_action_mask: 457 | if packet.edge != -1: 458 | self.action_mask[i] = 0 459 | else: 460 | self.action_mask[i, 0] = 1 461 | for edge_i, e in enumerate(self.network.nodes[packet.now].edges): 462 | self.action_mask[i, 1 + edge_i] = ( 463 | self.network.edges[e].get_other_node(packet.now) 464 | in packet.visited_nodes 465 | ) 466 | 467 | # packets that can't do anything are dropped 468 | if self.action_mask[i].sum() == 4: 469 | drop_packet[i] = True 470 | 471 | # the packet has reached the target 472 | has_reached_target = packet.edge == -1 and packet.now == packet.target 473 | if has_reached_target or drop_packet[i]: 474 | reward[i] += 10 if has_reached_target else -10 475 | done[i] = True 476 | success[i] = has_reached_target 477 | 478 | # we need at least 1 step (idle) if we spawn at the target 479 | opt_distance = max(packet.shortest_path_weight, 1) 480 | 481 | # insert delays before resetting packets 482 | if success[i]: 483 | delays_arrived.append(self.agent_steps[i]) 484 | spr.append(self.agent_steps[i] / opt_distance) 485 | if self.eval_info_enabled: 486 | self.distance_map[opt_distance].append(self.agent_steps[i]) 487 | 488 | delays.append(self.agent_steps[i]) 489 | 490 | self.agent_steps[i] = 0 491 | self.reset_packet(packet) 492 | # else: 493 | # # negative reward for distance in hops 494 | # distance = len(self.shortest_paths[packet.now][packet.target]) 495 | # reward[i] -= distance * 0.01 496 | 497 | obs = self._get_observation() 498 | adj = self._get_data_adjacency() 499 | info = { 500 | "delays": delays, 501 | "delays_arrived": delays_arrived, 502 | # shortest path ratio in [1, 
inf) where 1 is optimal 503 | "spr": spr, 504 | "looped": looped.sum(), 505 | "throughput": success.sum(), 506 | "dropped": (done & ~success).sum(), 507 | "blocked": blocked, 508 | } 509 | if self.eval_info_enabled: 510 | info.update( 511 | { 512 | "total_edge_load": total_edge_load, 513 | "occupied_edges": occupied_edges, 514 | "packets_on_edges": packets_on_edges, 515 | "total_packet_size": total_packet_size, 516 | "packet_sizes": packet_sizes, 517 | "packet_distances": packet_distances, 518 | } 519 | ) 520 | return obs, adj, reward, done, info 521 | 522 | def _get_data_adjacency(self): 523 | """ 524 | Get an adjacency matrix for data packets (agents) of shape (n_agents, n_agents) 525 | where the second dimension contains the neighbors of the agents in the first 526 | dimension, i.e. the matrix is of form (agent, neighbors). 527 | 528 | :param data: current data list 529 | :param n_data: number of data packets 530 | :return: adjacency matrix 531 | """ 532 | # eye because self is also part of the neighborhood 533 | adj = np.eye(self.n_data, self.n_data, dtype=np.int8) 534 | for i in range(self.n_data): 535 | for n in self.data[i].neigh: 536 | if n != -1: 537 | # n is (currently) a neighbor of i 538 | adj[i, n] = 1 539 | return adj 540 | 541 | def get_final_info(self, info: dict): 542 | agent_steps = self.agent_steps 543 | for agent_step in agent_steps: 544 | if agent_step != 0: 545 | info["delays"].append(agent_step) 546 | return info 547 | 548 | def get_num_agents(self): 549 | return self.n_data 550 | 551 | def get_num_nodes(self): 552 | return self.network.n_nodes 553 | -------------------------------------------------------------------------------- /src/sl.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | from pathlib import Path 4 | import pickle 5 | from tqdm import tqdm 6 | from env.constants import EVAL_SEEDS 7 | from env.network import Network 8 | from model import NetMon 9 | from 
env.environment import EnvironmentVariant 10 | from env.routing import Routing 11 | import torch.nn as nn 12 | import torch 13 | import numpy as np 14 | import networkx as nx 15 | import pandas as pd 16 | import torch.nn.functional as F 17 | 18 | from util import dim_str_to_list, set_seed 19 | 20 | parser = argparse.ArgumentParser( 21 | description="Train and test graph observation models on a supervised routing task." 22 | ) 23 | 24 | parser.add_argument( 25 | "--num-targets", 26 | type=int, 27 | help="Number of targets included in the loss for regression with all destinations", 28 | default=None, 29 | ) 30 | parser.add_argument( 31 | "--num-samples-train", 32 | type=int, 33 | help="Number of generated training graphs (ignored when loading a dataset)", 34 | default=10_000, 35 | ) 36 | parser.add_argument("--seed", type=int, help="Seed for the experiment", default=42) 37 | parser.add_argument( 38 | "--iterations", type=int, help="Number of training iterations", default=4_000 39 | ) 40 | parser.add_argument( 41 | "--validate-after", 42 | type=int, 43 | help="Validate model after the given number of training steps", 44 | default=1_000, 45 | ) 46 | parser.add_argument( 47 | "--sequence-length", 48 | type=int, 49 | help="Unroll depth of the model for each sample", 50 | default=4, 51 | ) 52 | parser.add_argument( 53 | "--filename", type=str, help="Where to save the results", default=None 54 | ) 55 | parser.add_argument( 56 | "--test-sequence-lengths", 57 | type=str, 58 | help="Sequence lengths used during testing", 59 | default="1,2,4,8,16,32,64,128,256", 60 | ) 61 | parser.add_argument( 62 | "--netmon-dim", type=int, help="Size of NetMon state and observations", default=128 63 | ) 64 | parser.add_argument( 65 | "--netmon-encoder-dim", 66 | type=str, 67 | help="NetMon encoder dimensions. 
Examples: '128', '512,128'..", 68 | default="512,256", 69 | ) 70 | parser.add_argument( 71 | "--netmon-iterations", 72 | type=int, 73 | help="Number of NetMon iterations between environment steps", 74 | default=3, 75 | ) 76 | parser.add_argument( 77 | "--netmon-rnn-type", 78 | type=str, 79 | help="NetMon RNN type", 80 | default="lstm", 81 | ) 82 | parser.add_argument( 83 | "--netmon-rnn-carryover", 84 | type=int, 85 | help="Carry over RNN state between RNN modules", 86 | # 0: False, 1: True 87 | choices=[False, True], 88 | default=True, 89 | ) 90 | parser.add_argument( 91 | "--netmon-agg-type", 92 | type=str, 93 | help="NetMon aggregation function", 94 | default="sum", 95 | ) 96 | parser.add_argument( 97 | "--netmon-global", 98 | help="Enables global pooling of graph observations (only allowed in centralized case)", 99 | dest="netmon_global", 100 | action="store_true", 101 | ) 102 | parser.set_defaults(netmon_global=False) 103 | parser.add_argument( 104 | "--netmon-last-neighbors", 105 | type=int, 106 | help="Append last node state received by neighbors to graph observation", 107 | choices=[False, True], 108 | default=True, 109 | ) 110 | 111 | parser.add_argument( 112 | "--disable-progressbar", 113 | help="Disables the progress bar and iteration-wise status prints", 114 | dest="disable_progressbar", 115 | action="store_true", 116 | ) 117 | parser.set_defaults(disable_progressbar=False) 118 | 119 | parser.add_argument( 120 | "--clear-cache", 121 | help="Forces generation of new datasets (clears cache if it exists)", 122 | dest="clear_cache", 123 | action="store_true", 124 | ) 125 | parser.set_defaults(clear_cache=False) 126 | 127 | args = parser.parse_args() 128 | args.test_sequence_lengths = dim_str_to_list(args.test_sequence_lengths) 129 | set_seed(args.seed) 130 | 131 | 132 | class NetMonSL(nn.Module): 133 | def __init__(self, node_obs_dim, nb_classes, nb_nodes) -> None: 134 | super().__init__() 135 | # rnn_type specifies if we are using LSTM or GRU 136 | 
self.netmon = NetMon( 137 | node_obs_dim, 138 | args.netmon_dim, 139 | dim_str_to_list(args.netmon_encoder_dim), 140 | iterations=args.netmon_iterations, 141 | activation_fn=F.leaky_relu, 142 | rnn_type=args.netmon_rnn_type, 143 | rnn_carryover=args.netmon_rnn_carryover, 144 | agg_type=args.netmon_agg_type, 145 | output_neighbor_hidden=args.netmon_last_neighbors, 146 | output_global_hidden=args.netmon_global, 147 | ) 148 | self.linear = nn.Linear(self.netmon.get_out_features(), nb_classes) 149 | self.linear_reg = nn.Linear(self.netmon.get_out_features(), 1) 150 | self.linear_reg_all = nn.Linear(self.netmon.get_out_features(), nb_nodes) 151 | self.class_logits = None 152 | 153 | def forward(self, node_obs, node_adj): 154 | batches, nodes, features = node_obs.shape 155 | eye = torch.eye(nodes).repeat(batches, 1, 1) 156 | 157 | node_features = self.netmon(node_obs, node_adj, eye) 158 | class_logits = self.linear(node_features) 159 | pred = self.linear_reg(node_features) 160 | pred_all = self.linear_reg_all(node_features) 161 | self.class_logits = class_logits.detach() 162 | return class_logits, pred, pred_all 163 | 164 | def get_class_probabilities(self): 165 | return torch.softmax(self.class_logits, dim=-1) 166 | 167 | def get_prediction(self): 168 | return torch.argmax(self.get_class_probabilities(), axis=-1) 169 | 170 | 171 | NUM_CLASSES = 4 172 | 173 | 174 | def get_sl_sample(env: Routing): 175 | single_node_obs = env.get_node_observation() 176 | single_node_adj = env.get_nodes_adjacency() 177 | single_node_labels = np.zeros(single_node_obs.shape[0]) 178 | single_node_targets = np.zeros(single_node_obs.shape[0]) 179 | single_node_targets_all = np.zeros( 180 | (single_node_obs.shape[0], single_node_obs.shape[0]) 181 | ) 182 | 183 | # get labels from path to zero 184 | for n in range(env.get_num_nodes()): 185 | # distance to node 0 186 | single_node_targets[n] = env.network.shortest_paths_weights[n][0] 187 | for n_other in range(env.get_num_nodes()): 188 | 
single_node_targets_all[n, n_other] = env.network.shortest_paths_weights[n][ 189 | n_other 190 | ] 191 | 192 | # which link corresponds to the shortest path to node 0? 193 | n_to_zero = env.network.shortest_paths[n][0] 194 | if len(n_to_zero) == 1: 195 | # we are already there 196 | single_node_labels[n] = 0 197 | # print(f"{n}: is the target") 198 | else: 199 | # look for the link we need to take 200 | next_node = n_to_zero[1] 201 | found_edge = False 202 | for e_idx, e in enumerate(env.network.nodes[n].edges): 203 | if env.network.edges[e].get_other_node(n) == next_node: 204 | single_node_labels[n] = e_idx + 1 205 | # print(f"{n}: found node {next_node} at edge idx {e_idx + 1}") 206 | found_edge = True 207 | break 208 | 209 | assert found_edge 210 | 211 | assert (single_node_adj.sum(axis=-1) == 4).all() 212 | return ( 213 | single_node_obs, 214 | single_node_adj, 215 | single_node_labels, 216 | single_node_targets, 217 | single_node_targets_all, 218 | ) 219 | 220 | 221 | # we are getting the average steps from one node to node 0 222 | def get_mean_num_shortest_paths_to_zero(env): 223 | num_paths = [] 224 | for i in range(env.get_num_nodes()): 225 | num_paths.append( 226 | len( 227 | list( 228 | nx.all_shortest_paths( 229 | env.network.G, i, 0, weight=env.network.G_weight_key 230 | ) 231 | ) 232 | ) 233 | ) 234 | return np.mean(num_paths) 235 | 236 | 237 | def build_dataset(env: Routing, num_samples): 238 | assert num_samples >= 1 239 | ( 240 | init_node_obs, 241 | init_node_adj, 242 | init_node_labels, 243 | init_node_targets, 244 | init_node_targets_all, 245 | ) = get_sl_sample(env) 246 | 247 | node_obs = np.zeros((num_samples, *init_node_obs.shape)) 248 | node_adj = np.zeros((num_samples, *init_node_adj.shape)) 249 | node_labels = np.zeros((num_samples, *init_node_labels.shape)) 250 | node_targets = np.zeros((num_samples, *init_node_targets.shape)) 251 | node_targets_all = np.zeros((num_samples, *init_node_targets_all.shape)) 252 | 253 | node_obs[0] = 
init_node_obs 254 | node_adj[0] = init_node_adj 255 | node_labels[0] = init_node_labels 256 | node_targets[0] = init_node_targets 257 | node_targets_all[0] = init_node_targets_all 258 | 259 | mean_paths = np.zeros((num_samples)) 260 | mean_paths[0] = get_mean_num_shortest_paths_to_zero(env) 261 | 262 | for s in tqdm( 263 | range(1, num_samples), 264 | initial=1, 265 | total=num_samples, 266 | disable=args.disable_progressbar, 267 | ): 268 | env.reset() 269 | ( 270 | node_obs[s], 271 | node_adj[s], 272 | node_labels[s], 273 | node_targets[s], 274 | node_targets_all[s], 275 | ) = get_sl_sample(env) 276 | mean_paths[s] = get_mean_num_shortest_paths_to_zero(env) 277 | 278 | print( 279 | "Network stats: \n" 280 | f"Mean neighbors: {node_adj.sum(axis=-1).mean()} \n" 281 | f"Mean distance to 0: {node_targets.mean()} \n" 282 | f"Max distance to 0: {node_targets.max()} \n" 283 | f"Mean shortest paths to 0: {mean_paths.mean()}" 284 | ) 285 | 286 | return node_obs, node_adj, node_labels, node_targets, node_targets_all 287 | 288 | 289 | def build_or_load_dataset(env: Routing, num_samples, filename, clear_cache, cache): 290 | path = Path(filename) 291 | 292 | if path.exists(): 293 | if clear_cache: 294 | path.unlink() 295 | elif cache: 296 | with open(path, "rb") as f: 297 | dataset = pickle.load(f) 298 | return dataset 299 | 300 | print("Creating new dataset..") 301 | dataset = build_dataset(env, num_samples) 302 | 303 | if cache: 304 | with open(path, "wb") as f: 305 | pickle.dump(dataset, f) 306 | print(f"Saved dataset as {path}") 307 | 308 | return dataset 309 | 310 | 311 | class RandomLossReduction: 312 | def __init__(self, loss) -> None: 313 | self.loss = loss 314 | 315 | def __call__(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: 316 | loss_out = self.loss(input, target) 317 | loss_weights = torch.ones_like(loss_out) 318 | loss_weights *= torch.rand(size=loss_out.shape) >= 0.5 319 | return torch.sum(loss_out * loss_weights) / torch.sum(loss_weights) 
320 | 321 | 322 | def train( 323 | model: nn.Module, 324 | train_node_obs, 325 | train_node_adj, 326 | train_node_labels, 327 | train_node_targets, 328 | train_node_targets_all, 329 | iterations, 330 | batch_size, 331 | with_classification, 332 | with_regression, 333 | with_regression_all, 334 | sequence_length, 335 | random_loss_weights=False, 336 | validation_callback=None, 337 | ): 338 | assert len(train_node_obs) == len(train_node_adj) == len(train_node_labels) 339 | num_samples = len(train_node_obs) 340 | optim = torch.optim.AdamW(model.parameters()) 341 | 342 | reduction = "none" if random_loss_weights else "mean" 343 | cross_entropy = torch.nn.CrossEntropyLoss(reduction=reduction) 344 | mse = torch.nn.MSELoss(reduction=reduction) 345 | 346 | if random_loss_weights: 347 | cross_entropy = RandomLossReduction(cross_entropy) 348 | mse = RandomLossReduction(mse) 349 | 350 | model.train() 351 | class_loss_list = [] 352 | reg_loss_list = [] 353 | reg_all_loss_list = [] 354 | total_loss_list = [] 355 | validation_list = [] 356 | 357 | if validation_callback is not None: 358 | tqdm.write("Validation..") 359 | netmon.eval() 360 | validation_list.append((0, *validation_callback())) 361 | netmon.train() 362 | 363 | for it in tqdm( 364 | range(iterations), total=iterations, disable=args.disable_progressbar 365 | ): 366 | sequence_loss_list = [] 367 | model.netmon.state = None 368 | batch_idx = np.random.choice(num_samples, batch_size, replace=True) 369 | batch_node_obs = torch.Tensor(train_node_obs[batch_idx]) 370 | batch_node_adj = torch.Tensor(train_node_adj[batch_idx]) 371 | batch_node_labels = torch.Tensor(train_node_labels[batch_idx]).long() 372 | batch_node_targets = torch.Tensor(train_node_targets[batch_idx]) 373 | batch_node_targets_all = torch.Tensor(train_node_targets_all[batch_idx]) 374 | 375 | for _ in range(0, max(sequence_length, 1)): 376 | log_probs, pred, pred_all = model(batch_node_obs, batch_node_adj) 377 | 378 | loss = 0 379 | if with_classification: 
380 | # combine batch and node dimensions for loss calculation 381 | log_probs_loss = log_probs.reshape(-1, NUM_CLASSES) 382 | batch_node_labels_loss = batch_node_labels.reshape(-1, 1) 383 | 384 | class_loss = cross_entropy( 385 | log_probs_loss, batch_node_labels_loss.squeeze(-1) 386 | ) 387 | loss += class_loss 388 | 389 | if with_regression: 390 | pred = pred.reshape(-1, 1) 391 | batch_node_targets = batch_node_targets.reshape(-1, 1) 392 | regression_loss = mse(pred, batch_node_targets) 393 | loss += regression_loss 394 | 395 | if with_regression_all: 396 | regression_loss_all = mse( 397 | pred_all[..., : args.num_targets], 398 | batch_node_targets_all[..., : args.num_targets], 399 | ) 400 | loss += regression_loss_all 401 | 402 | # remember total loss for each element in the sequence 403 | sequence_loss_list.append(loss) 404 | 405 | # log loss at end of sequence 406 | iteration_str = f"Iteration {it}" 407 | if with_classification: 408 | class_loss_list.append(class_loss.detach().item()) 409 | iteration_str += f" | class loss = {class_loss.detach().item():.2f}" 410 | if with_regression: 411 | reg_loss_list.append(regression_loss.detach().item()) 412 | iteration_str += f" | reg loss = {regression_loss.detach().item():.2f}" 413 | if with_regression_all: 414 | reg_all_loss_list.append(regression_loss_all.detach().item()) 415 | iteration_str += ( 416 | f" | reg_all loss = {regression_loss_all.detach().item():.2f}" 417 | ) 418 | 419 | # and total mean loss for the sequence 420 | total_loss = torch.mean(torch.stack(sequence_loss_list)) 421 | total_loss_list.append(total_loss.detach().item()) 422 | iteration_str += f" | total = {total_loss.detach().item():.2f}" 423 | if sequence_length > 1: 424 | iteration_str += f" (seq_len={sequence_length})" 425 | 426 | optim.zero_grad() 427 | total_loss.backward() 428 | optim.step() 429 | 430 | if not args.disable_progressbar: 431 | tqdm.write(iteration_str) 432 | 433 | if validation_callback is not None and (it + 1) % 
args.validate_after == 0: 434 | tqdm.write(f"Iteration {it + 1}: validation") 435 | netmon.eval() 436 | validation_list.append((it + 1, *validation_callback())) 437 | netmon.train() 438 | 439 | return ( 440 | total_loss_list, 441 | class_loss_list, 442 | reg_loss_list, 443 | reg_all_loss_list, 444 | validation_list, 445 | ) 446 | 447 | 448 | def test( 449 | model: nn.Module, 450 | test_node_obs, 451 | test_node_adj, 452 | test_node_labels, 453 | test_node_targets, 454 | test_node_targets_all, 455 | batch_size, 456 | with_classification, 457 | with_regression, 458 | with_regression_all, 459 | sequence_length, 460 | disable_progress_force=False, 461 | ): 462 | assert len(test_node_obs) == len(test_node_adj) == len(test_node_labels) 463 | num_samples = len(test_node_obs) 464 | 465 | cross_entropy = torch.nn.CrossEntropyLoss(reduction="sum") 466 | mse = torch.nn.MSELoss(reduction="sum") 467 | 468 | total_class_loss = 0 469 | total_reg_loss = 0 470 | total_reg_all_loss = 0 471 | total_correct = 0 472 | total_count = np.prod(test_node_obs.shape[0:2]) 473 | 474 | model.eval() 475 | idx = 0 476 | pbar = tqdm( 477 | total=num_samples, disable=args.disable_progressbar or disable_progress_force 478 | ) 479 | with torch.no_grad(): 480 | while idx < num_samples: 481 | model.netmon.state = None 482 | next_idx = min(num_samples, idx + batch_size) 483 | batch_idx = np.arange(idx, next_idx) 484 | batch_node_obs = torch.Tensor(test_node_obs[batch_idx]) 485 | batch_node_adj = torch.Tensor(test_node_adj[batch_idx]) 486 | batch_node_labels = torch.Tensor(test_node_labels[batch_idx]).long() 487 | batch_node_targets = torch.Tensor(test_node_targets[batch_idx]) 488 | batch_node_targets_all = torch.Tensor(test_node_targets_all[batch_idx]) 489 | 490 | for _ in range(0, max(sequence_length, 1)): 491 | log_probs, pred, pred_all = model(batch_node_obs, batch_node_adj) 492 | 493 | if with_classification: 494 | # combine batch and node dimensions for loss calculation 495 | log_probs_loss = 
log_probs.reshape(-1, NUM_CLASSES) 496 | batch_node_labels_loss = batch_node_labels.reshape(-1, 1) 497 | 498 | total_class_loss += cross_entropy( 499 | log_probs_loss, batch_node_labels_loss.squeeze(-1) 500 | ).item() 501 | total_correct += ( 502 | (model.get_prediction() == batch_node_labels).sum().item() 503 | ) 504 | 505 | if with_regression: 506 | pred = pred.reshape(-1, 1) 507 | batch_node_targets = batch_node_targets.reshape(-1, 1) 508 | total_reg_loss += mse(pred, batch_node_targets).item() 509 | 510 | if with_regression_all: 511 | total_reg_all_loss += mse( 512 | pred_all[..., : args.num_targets], 513 | batch_node_targets_all[..., : args.num_targets], 514 | ).item() 515 | 516 | pbar.update(next_idx - idx) 517 | idx = next_idx 518 | 519 | pbar.close() 520 | 521 | if with_classification: 522 | print( 523 | f"{total_correct / total_count:.2f} acc, {total_class_loss / total_count} loss" 524 | ) 525 | if with_regression: 526 | print(f"Pred loss {total_reg_loss / total_count}") 527 | 528 | if with_regression_all: 529 | print(f"Pred_all loss {total_reg_all_loss / (total_count * args.num_targets)}") 530 | 531 | return ( 532 | total_correct / total_count, 533 | total_class_loss / total_count, 534 | total_reg_loss / total_count, 535 | total_reg_all_loss / (total_count * args.num_targets), 536 | ) 537 | 538 | 539 | if torch.cuda.is_available(): 540 | torch.set_default_tensor_type("torch.cuda.FloatTensor") 541 | print("Training with GPU enabled") 542 | 543 | batch_size = 32 544 | clear_cache = args.clear_cache 545 | # whether to cache/save the generated dataset 546 | cache = True 547 | 548 | # 10_000 1_000 4_000 549 | num_samples_train = args.num_samples_train 550 | num_samples_test = 1_000 551 | assert num_samples_test <= len(EVAL_SEEDS) 552 | train_iterations = args.iterations 553 | with_classification = False 554 | with_regression = False 555 | with_regression_all = True 556 | 557 | sequence_length = args.sequence_length 558 | 559 | save_results_filename = 
args.filename 560 | if save_results_filename is not None and Path(save_results_filename).exists(): 561 | Path(save_results_filename).unlink() 562 | 563 | num_nodes = 20 564 | num_packets = 20 565 | if args.num_targets is None: 566 | args.num_targets = num_nodes 567 | 568 | assert 1 <= args.num_targets <= num_nodes 569 | 570 | networks_train = Network( 571 | num_nodes, 572 | random_topology=True, 573 | excluded_seeds=EVAL_SEEDS, 574 | ) 575 | networks_val = Network( 576 | num_nodes, 577 | random_topology=True, 578 | n_random_seeds=num_samples_test, 579 | sequential_topology_seeds=True, 580 | excluded_seeds=EVAL_SEEDS, 581 | ) 582 | networks_test = Network( 583 | num_nodes, 584 | random_topology=True, 585 | sequential_topology_seeds=True, 586 | provided_seeds=EVAL_SEEDS, 587 | ) 588 | env = Routing(networks_train, num_packets, EnvironmentVariant.INDEPENDENT) 589 | env_val = Routing(networks_val, num_packets, EnvironmentVariant.INDEPENDENT) 590 | env_test = Routing(networks_test, num_packets, EnvironmentVariant.INDEPENDENT) 591 | env.network.G_weight_key = ( 592 | # number of hops 593 | # None 594 | # actual edge lengths (needs more training iterations!) 
595 | "weight" 596 | ) 597 | env_val.network.G_weight_key = env.network.G_weight_key 598 | env_test.network.G_weight_key = env.network.G_weight_key 599 | env.reset() 600 | env_val.reset() 601 | env_test.reset() 602 | node_observation_dim = len(env.get_node_observation()[0]) 603 | 604 | netmon = NetMonSL(node_observation_dim, NUM_CLASSES, num_nodes) 605 | summary_node_obs = torch.tensor( 606 | env.get_node_observation(), dtype=torch.float32 607 | ).unsqueeze(0) 608 | summary_node_adj = torch.tensor( 609 | env.get_nodes_adjacency(), dtype=torch.float32 610 | ).unsqueeze(0) 611 | summary_node_agent = torch.tensor( 612 | env.get_node_agent_matrix(), dtype=torch.float32 613 | ).unsqueeze(0) 614 | print(netmon.netmon.summarize(summary_node_obs, summary_node_adj, summary_node_agent)) 615 | 616 | print("Loading train dataset..") 617 | train_data = build_or_load_dataset( 618 | env, num_samples_train, "train.pk", clear_cache, cache 619 | ) 620 | print(f"loaded {train_data[0].shape[0]} samples") 621 | 622 | print("Loading validation dataset..") 623 | val_data = build_or_load_dataset( 624 | env_val, num_samples_test, "val.pk", clear_cache, cache 625 | ) 626 | print(f"loaded {val_data[0].shape[0]} samples") 627 | 628 | print("Loading test dataset..") 629 | test_data = build_or_load_dataset( 630 | env_test, num_samples_test, "test.pk", clear_cache, cache 631 | ) 632 | print(f"loaded {test_data[0].shape[0]} samples") 633 | 634 | 635 | def validation_callback(): 636 | return test( 637 | netmon, 638 | *val_data, 639 | batch_size, 640 | with_classification, 641 | with_regression, 642 | with_regression_all, 643 | sequence_length, 644 | disable_progress_force=True, 645 | ) 646 | 647 | 648 | # training 649 | total_loss, class_loss, reg_loss, reg_loss_all, validation_results = train( 650 | netmon, 651 | *train_data, 652 | train_iterations, 653 | batch_size, 654 | with_classification, 655 | with_regression, 656 | with_regression_all, 657 | sequence_length, 658 | 
validation_callback=validation_callback, 659 | ) 660 | 661 | # save losses 662 | if save_results_filename is not None: 663 | class_loss = class_loss if with_classification else np.zeros(train_iterations) 664 | reg_loss = reg_loss if with_regression else np.zeros(train_iterations) 665 | reg_loss_all = reg_loss_all if with_regression_all else np.zeros(train_iterations) 666 | df_loss = pd.DataFrame( 667 | data=list( 668 | zip( 669 | np.arange(train_iterations), 670 | class_loss, 671 | reg_loss, 672 | reg_loss_all, 673 | total_loss, 674 | ) 675 | ), 676 | columns=[ 677 | "Iteration", 678 | "Classification Loss", 679 | "Regression Loss", 680 | "Regression Loss All", 681 | "Total Loss", 682 | ], 683 | ) 684 | df_loss.to_hdf(save_results_filename, "loss", mode="a") 685 | 686 | def validation_results_idx(idx): 687 | return list(map(lambda x: x[idx], validation_results)) 688 | 689 | df_validation = pd.DataFrame( 690 | data=list( 691 | zip( 692 | validation_results_idx(0), 693 | validation_results_idx(1), 694 | validation_results_idx(2), 695 | validation_results_idx(3), 696 | validation_results_idx(4), 697 | ) 698 | ), 699 | columns=[ 700 | "Iteration", 701 | "Accuracy", 702 | "Classification Loss", 703 | "Regression Loss", 704 | "Regression Loss All", 705 | ], 706 | ) 707 | df_validation.to_hdf(save_results_filename, "validation", mode="a") 708 | 709 | 710 | print("Train data eval: ") 711 | train_acc, train_class_loss, train_reg_loss, train_reg_loss_all = test( 712 | netmon, 713 | *train_data, 714 | batch_size, 715 | with_classification, 716 | with_regression, 717 | with_regression_all, 718 | sequence_length, 719 | ) 720 | 721 | # testing 722 | 723 | print(f"Test data eval: (seq_len={sequence_length})") 724 | test_acc, test_class_loss, test_reg_loss, test_reg_loss_all = test( 725 | netmon, 726 | *test_data, 727 | batch_size, 728 | with_classification, 729 | with_regression, 730 | with_regression_all, 731 | sequence_length, 732 | ) 733 | 734 | # for sequence length > 1, also 
start eval out of train sequence length 735 | # removed this condition 736 | test_sequence_results = [] 737 | 738 | for seq_len in args.test_sequence_lengths: 739 | print(f"Extended test data eval (seq_len={seq_len})") 740 | test_sequence_results.append( 741 | # append new tuple with sequence length and test results 742 | ( 743 | seq_len, 744 | *test( 745 | netmon, 746 | *test_data, 747 | batch_size, 748 | with_classification, 749 | with_regression, 750 | with_regression_all, 751 | seq_len, 752 | ), 753 | ) 754 | ) 755 | 756 | # save results 757 | if save_results_filename is not None: 758 | n = len(test_sequence_results) 759 | 760 | def test_sequence_results_idx(idx): 761 | return list(map(lambda x: x[idx], test_sequence_results)) 762 | 763 | df_loss = pd.DataFrame( 764 | data=list( 765 | zip( 766 | ["train"] + ["test"] * (1 + n), 767 | [train_data[0].shape[0]] + [test_data[0].shape[0]] * (1 + n), 768 | [sequence_length, sequence_length] + test_sequence_results_idx(0), 769 | [args.netmon_iterations] * (2 + n), 770 | [train_acc, test_acc] + test_sequence_results_idx(1), 771 | [train_class_loss, test_class_loss] + test_sequence_results_idx(2), 772 | [train_reg_loss, test_reg_loss] + test_sequence_results_idx(3), 773 | [train_reg_loss_all, test_reg_loss_all] + test_sequence_results_idx(4), 774 | ) 775 | ), 776 | columns=[ 777 | "Type", 778 | "Samples", 779 | "Sequence Length", 780 | "Netmon Iterations", 781 | "Accuracy", 782 | "Classification Loss", 783 | "Regression Loss", 784 | "Regression Loss All", 785 | ], 786 | ) 787 | df_loss.to_hdf(save_results_filename, "results", mode="a") 788 | -------------------------------------------------------------------------------- /src/model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torch_geometric.nn import GCNConv, SAGEConv, AntiSymmetricConv, GraphSAGE 5 | from torch_geometric.utils import 
dense_to_sparse 6 | from torch_geometric.nn.summary import summary 7 | 8 | # from torch_geometric_temporal import DyGrEncoder 9 | 10 | from layernormlstm import LayerNormLSTMCell 11 | 12 | 13 | class MLP(nn.Module): 14 | def __init__( 15 | self, in_features, mlp_units, activation_fn, activation_on_output=True 16 | ): 17 | super(MLP, self).__init__() 18 | self.activation_fn = activation_fn 19 | 20 | self.linear_layers = nn.ModuleList() 21 | previous_units = in_features 22 | if isinstance(mlp_units, int): 23 | mlp_units = [mlp_units] 24 | 25 | for units in mlp_units: 26 | self.linear_layers.append(nn.Linear(previous_units, units)) 27 | previous_units = units 28 | 29 | self.out_features = previous_units 30 | self.activation_on_output = activation_on_output 31 | 32 | def forward(self, x): 33 | # intermediate layers 34 | for module in self.linear_layers[:-1]: 35 | x = self.activation_fn(module(x)) 36 | 37 | # last layer 38 | x = self.linear_layers[-1](x) 39 | if self.activation_on_output: 40 | x = self.activation_fn(x) 41 | 42 | return x 43 | 44 | 45 | class AttModel(nn.Module): 46 | """ 47 | Multi-headed attention model based on.. 
48 | 49 | a) the following implementations of the paper "Graph Convolutional Reinforcement Learning" 50 | (https://arxiv.org/abs/1810.09202) 51 | a.1) ..in TensorFlow: https://github.com/PKU-RL/DGN 52 | a.2) ..in PyTorch: https://github.com/jiechuanjiang/pytorch_DGN/ 53 | 54 | b) a PyTorch implementation of the Transformer model in "Attention is All You Need" (https://arxiv.org/abs/1706.03762) 55 | https://github.com/jadore801120/attention-is-all-you-need-pytorch 56 | """ 57 | 58 | def __init__( 59 | self, 60 | in_features, 61 | k_features, 62 | v_features, 63 | out_features, 64 | num_heads, 65 | activation_fn, 66 | vkq_activation_fn, 67 | ): 68 | super(AttModel, self).__init__() 69 | self.k_features = k_features 70 | self.v_features = v_features 71 | self.num_heads = num_heads 72 | self.fc_v = nn.Linear(in_features, v_features * num_heads) 73 | self.fc_k = nn.Linear(in_features, k_features * num_heads) 74 | self.fc_q = nn.Linear(in_features, k_features * num_heads) 75 | self.fc_out = nn.Linear(v_features * num_heads, out_features) 76 | self.activation_fn = activation_fn 77 | self.vkq_activation_fn = vkq_activation_fn 78 | 79 | # attention scaling factor 1 / sqrt(d_k) from "Attention is All You Need" 80 | self.attention_scale = 1 / (k_features**0.5) 81 | 82 | def forward(self, x, mask): 83 | batch_size, num_agents = x.shape[0], x.shape[1] 84 | 85 | # get values, queries and keys and view according to heads 86 | # difference to DQN: we use a linear mapping here, as in the Transformer paper 87 | v = self.fc_v(x).view(batch_size, num_agents, self.num_heads, self.v_features) 88 | q = self.fc_q(x).view(batch_size, num_agents, self.num_heads, self.k_features) 89 | k = self.fc_k(x).view(batch_size, num_agents, self.num_heads, self.k_features) 90 | 91 | if self.vkq_activation_fn is not None: 92 | v = self.vkq_activation_fn(v) 93 | q = self.vkq_activation_fn(q) 94 | k = self.vkq_activation_fn(k) 95 | 96 | # permute for batch multiplication over batch size and heads 97 | q, 
k, v = q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2) 98 | 99 | # add head axis (mask is the same for all heads) 100 | mask = mask.unsqueeze(1) 101 | 102 | # calculate attention as dot product of all queries with all keys 103 | # according to mask and softmax over last dimension 104 | att_weights = torch.matmul(q, k.transpose(2, 3)) * self.attention_scale 105 | att = att_weights.masked_fill(mask == 0, -1e9) 106 | att = F.softmax(att, dim=-1) 107 | 108 | # combine values according to attention 109 | out = torch.matmul(att, v) 110 | # skip connection 111 | out = torch.add(out, v) 112 | # undo transpose and concatenate all heads 113 | out = out.transpose(1, 2).contiguous().view(batch_size, num_agents, -1) 114 | 115 | out = self.activation_fn(self.fc_out(out)) 116 | return out, att_weights 117 | 118 | 119 | class Q_Net(nn.Module): 120 | def __init__(self, in_features, actions): 121 | super(Q_Net, self).__init__() 122 | self.fc = nn.Linear(in_features, actions) 123 | 124 | def forward(self, x): 125 | return self.fc(x) 126 | 127 | 128 | class DGN(nn.Module): 129 | def __init__( 130 | self, 131 | in_features, 132 | mlp_units, 133 | num_actions, 134 | num_heads, 135 | num_attention_layers, 136 | activation_fn, 137 | ): 138 | """ 139 | Implementation of "Graph Convolutional Reinforcement Learning" 140 | (https://arxiv.org/abs/1810.09202) based on https://github.com/PKU-RL/DGN. 
141 | 142 | :param in_features: Number of input features 143 | :param mlp_units: MLP units (either int or list/tuple) 144 | :param num_actions: Number of actions 145 | :param num_heads: Number of attention heads 146 | :param num_attention_layers: Number of attention layers 147 | :param activation_fn: activation function, defaults to F.relu 148 | """ 149 | super(DGN, self).__init__() 150 | self.encoder = MLP(in_features, mlp_units, activation_fn) 151 | self.att_layers = nn.ModuleList() 152 | hidden_features = self.encoder.out_features 153 | for _ in range(num_attention_layers): 154 | self.att_layers.append( 155 | AttModel( 156 | hidden_features, 157 | # dv=16 from official implementation 158 | # https://github.com/PKU-RL/DGN/blob/92b926888e82880afa68fcd967c6e6527f7773fa/Routing/routers.py#L196 159 | 16, 160 | 16, 161 | hidden_features, 162 | num_heads, 163 | activation_fn, 164 | # official implementation uses act function for key/query/value 165 | activation_fn, 166 | ) 167 | ) 168 | 169 | self.q_net = Q_Net(hidden_features * (num_attention_layers + 1), num_actions) 170 | self.att_weights = [] 171 | 172 | def forward(self, x, mask): 173 | h = self.encoder(x) 174 | 175 | q_input = h 176 | self.att_weights.clear() 177 | for attention_layer in self.att_layers: 178 | h, att_weights = attention_layer(h, mask) 179 | self.att_weights.append(att_weights) 180 | # concatenate outputs like described in the paper & official implementation 181 | q_input = torch.cat((q_input, h), dim=-1) 182 | 183 | q = self.q_net(q_input) 184 | return q 185 | 186 | 187 | class DQN(nn.Module): 188 | """ 189 | Minimal implementation of a DQN model (see https://www.nature.com/articles/nature14236) 190 | with vector-based input. 
191 | """ 192 | 193 | def __init__(self, in_features, mlp_units, num_actions, activation_fn): 194 | super(DQN, self).__init__() 195 | self.encoder = MLP(in_features, mlp_units, activation_fn) 196 | self.q_net = Q_Net(self.encoder.out_features, num_actions) 197 | self.activation_fn = activation_fn 198 | 199 | def forward(self, x, mask): 200 | batch, agent, features = x.shape 201 | h = self.encoder(x) 202 | q = self.q_net(h) 203 | return q 204 | 205 | 206 | class SimpleAggregation(nn.Module): 207 | def __init__(self, agg: str, mask_eye: bool) -> None: 208 | super().__init__() 209 | self.agg = agg 210 | assert self.agg == "mean" or self.agg == "sum" 211 | self.mask_eye = mask_eye 212 | 213 | def forward(self, node_features, node_adjacency): 214 | if self.mask_eye: 215 | node_adjacency = node_adjacency * ~( 216 | torch.eye( 217 | node_adjacency.shape[1], 218 | node_adjacency.shape[1], 219 | device=node_adjacency.device, 220 | ) 221 | .repeat(node_adjacency.shape[0], 1, 1) 222 | .bool() 223 | ) 224 | feature_sum = torch.bmm(node_adjacency, node_features) 225 | if self.agg == "sum": 226 | return feature_sum 227 | if self.agg == "mean": 228 | num_neighbors = torch.clamp(node_adjacency.sum(dim=-1), min=1).unsqueeze(-1) 229 | return feature_sum / num_neighbors 230 | 231 | 232 | class JumpingKnowledgeADGN(nn.Module): 233 | """ 234 | Stacks multiple iterations of AntiSymmetricConv (with weight sharing) 235 | and uses the provided jumping knowledge function to aggregate intermediate 236 | node states. 
237 | """ 238 | 239 | def __init__(self, hidden_features, num_iters, jk) -> None: 240 | super().__init__() 241 | self.aggregate = AntiSymmetricConv(hidden_features, num_iters=1) 242 | self.num_iters = num_iters 243 | assert self.num_iters >= 1 244 | self.jk = jk 245 | assert self.jk is not None 246 | 247 | def forward(self, x, mask_sparse): 248 | xs = [] 249 | for _ in range(self.num_iters): 250 | x = self.aggregate(x, mask_sparse) 251 | xs.append(x) 252 | 253 | return self.jk(xs) 254 | 255 | 256 | class NetMon(nn.Module): 257 | def __init__( 258 | self, 259 | in_features, 260 | hidden_features: int, 261 | encoder_units, 262 | iterations, 263 | activation_fn, 264 | rnn_type="lstm", 265 | rnn_carryover=True, 266 | agg_type="sum", 267 | output_neighbor_hidden=False, 268 | output_global_hidden=False, 269 | ) -> None: 270 | super().__init__() 271 | assert isinstance(hidden_features, int) 272 | self.encode = MLP( 273 | in_features, 274 | (*encoder_units, hidden_features), 275 | activation_fn, 276 | ) 277 | self.state = None 278 | self.iterations = iterations 279 | self.output_neighbor_hidden = output_neighbor_hidden 280 | self.output_global_hidden = output_global_hidden 281 | self.rnn_carryover = rnn_carryover 282 | 283 | # 0 = dense input 284 | # 1 = sparse input 285 | # 2 = gconvlstm (sparse input) 286 | # 3 = GraphSAGE (sparse input, directly outputs neighbor info) 287 | self.aggregation_def_type = None 288 | 289 | # aggregation 290 | self.agg_type_str = agg_type 291 | # first resolve jumping knowledge functions for GraphSAGE and A-DGN 292 | self.jk = None 293 | if "jk-cat" in agg_type: 294 | self.jk_out = nn.Linear(hidden_features * iterations, hidden_features) 295 | self.jk_neighbors = nn.Linear( 296 | hidden_features * (iterations - 1), hidden_features 297 | ) 298 | 299 | def jk_cat(xs): 300 | return ( 301 | self.jk_out(torch.cat(xs, dim=-1)), 302 | self.jk_neighbors(torch.cat(xs[:-1], dim=-1)), 303 | ) 304 | 305 | self.jk = jk_cat 306 | elif "jk-max" in agg_type: 
307 | 308 | def jk_max(xs): 309 | return ( 310 | torch.max(torch.stack(xs), dim=0)[0], 311 | torch.max(torch.stack(xs[:-1]), dim=0)[0], 312 | ) 313 | 314 | self.jk = jk_max 315 | elif agg_type == "graphsage" or agg_type == "adgn": 316 | 317 | def jk(xs): 318 | return (xs[-1], xs[-2]) 319 | 320 | self.jk = jk 321 | 322 | # now resolve the actual aggregation 323 | if agg_type == "sum" or agg_type == "mean": 324 | self.aggregate = SimpleAggregation(agg=agg_type, mask_eye=False) 325 | self.aggregation_def_type = 0 326 | elif agg_type == "gcn": 327 | self.aggregate = GCNConv(hidden_features, hidden_features, improved=True) 328 | self.aggregation_def_type = 1 329 | elif agg_type == "sage": 330 | self.aggregate = SAGEConv(hidden_features, hidden_features) 331 | self.aggregation_def_type = 1 332 | elif "graphsage" in agg_type: 333 | self.aggregate = GraphSAGE( 334 | hidden_features, hidden_features, num_layers=iterations 335 | ) 336 | self.agg_type_str = agg_type + f" ({iterations} layer)" 337 | assert self.jk is not None 338 | self.aggregate.jk = self.jk 339 | self.aggregate.jk_mode = "custom" 340 | self.aggregation_def_type = 3 341 | self.iterations = 1 342 | if rnn_type != "none": 343 | print(f"WARNING: Overwritten given rnn type {rnn_type} with 'none'") 344 | rnn_type = "none" 345 | elif "adgn" in agg_type: 346 | self.aggregate = JumpingKnowledgeADGN( 347 | hidden_features, num_iters=iterations, jk=self.jk 348 | ) 349 | self.agg_type_str = agg_type + f" ({iterations} layer)" 350 | self.aggregation_def_type = 3 351 | self.iterations = 1 352 | if rnn_type != "none": 353 | print(f"WARNING: Overwritten given rnn type {rnn_type} with 'none'") 354 | rnn_type = "none" 355 | elif agg_type == "antisymgcn": 356 | # use single iteration so that we still get last hidden node states 357 | self.aggregate = AntiSymmetricConv(hidden_features, num_iters=1) 358 | self.aggregation_def_type = 1 359 | elif agg_type == "gconvlstm": 360 | # filter size 1 => only from neighbors 361 | from 
torch_geometric_temporal.nn.recurrent.gconv_lstm import GConvLSTM 362 | 363 | self.agg_type_str = agg_type + f" (filter size {iterations + 1})" 364 | self.aggregate = GConvLSTM( 365 | hidden_features, hidden_features, K=(self.iterations + 1) 366 | ) 367 | self.iterations = 1 368 | self.aggregation_def_type = 2 369 | if rnn_type != "gconvlstm": 370 | print( 371 | f"WARNING: Overwritten given rnn type {rnn_type} with 'gconvlstm'" 372 | ) 373 | rnn_type = "gconvlstm" 374 | else: 375 | raise ValueError(f"Unknown aggregation type {agg_type}") 376 | 377 | # update and observation encoding 378 | self.rnn_type = rnn_type 379 | if self.rnn_type == "lstm": 380 | self.rnn_obs = nn.LSTMCell(hidden_features, hidden_features) 381 | self.rnn_update = nn.LSTMCell(hidden_features, hidden_features) 382 | self.num_states = 2 if rnn_carryover else 4 383 | elif self.rnn_type == "lnlstm": 384 | self.rnn_obs = LayerNormLSTMCell(hidden_features, hidden_features) 385 | self.rnn_update = LayerNormLSTMCell(hidden_features, hidden_features) 386 | self.num_states = 2 if rnn_carryover else 4 387 | elif self.rnn_type == "gru": 388 | self.rnn_obs = nn.GRUCell(hidden_features, hidden_features) 389 | self.rnn_update = nn.GRUCell(hidden_features, hidden_features) 390 | self.num_states = 1 if rnn_carryover else 2 391 | elif self.rnn_type == "gconvlstm": 392 | # rnn is part of aggregate function 393 | self.num_states = 2 394 | elif self.rnn_type == "none": 395 | # empty state / stateless => simply store h for debugging 396 | self.num_states = 1 397 | else: 398 | raise ValueError(f"Unknown rnn type {self.rnn_type}") 399 | 400 | self.hidden_features = hidden_features 401 | self.state_size = hidden_features * self.num_states 402 | 403 | def get_out_features(self): 404 | out_features = self.hidden_features 405 | 406 | if self.output_neighbor_hidden: 407 | out_features += self.hidden_features * 3 408 | 409 | if self.output_global_hidden: 410 | out_features += self.hidden_features 411 | 412 | return 
out_features 413 | 414 | def get_state_size(self): 415 | return self.state_size 416 | 417 | def _state_reshape_in(self, batch_size, n_agents): 418 | """ 419 | Reshapes the state of shape 420 | (batch_size, n_agents, self.get_state_len()) 421 | to shape 422 | (2, batch_size * n_agents, hidden_size) 423 | 424 | :param batch_size: the batch size 425 | :param n_agents: the number of agents 426 | """ 427 | if self.state.numel() == 0: 428 | return 429 | 430 | self.state = self.state.reshape( 431 | batch_size * n_agents, 432 | self.num_states, 433 | -1, 434 | ).transpose(0, 1) 435 | 436 | def _state_reshape_out(self, batch_size, n_agents): 437 | """ 438 | Reshapes the state of shape 439 | (2, batch_size * n_agents, hidden_size) 440 | to shape 441 | (batch_size, n_agents, self.get_state_len()). 442 | 443 | :param batch_size: the batch size 444 | :param n_agents: the number of agents 445 | """ 446 | if self.state.numel() == 0: 447 | return 448 | 449 | self.state = self.state.transpose(0, 1).reshape(batch_size, n_agents, -1) 450 | 451 | def forward( 452 | self, x, mask, node_agent_matrix, max_degree=None, no_agent_mapping=False 453 | ): 454 | # steps (1), (2) and (3) 455 | h, last_neighbor_h = self._update_node_states(x, mask) 456 | 457 | # step (4) 458 | if self.output_neighbor_hidden or self.output_global_hidden: 459 | extended_h = [h] 460 | 461 | if self.output_global_hidden: 462 | extended_h.append(self._get_global_h(h)) 463 | 464 | if self.output_neighbor_hidden: 465 | extended_h.append( 466 | self._get_neighbor_h(last_neighbor_h, mask, max_degree) 467 | ) 468 | 469 | h = torch.cat(extended_h, dim=-1) 470 | 471 | if no_agent_mapping: 472 | return h 473 | 474 | return NetMon.output_to_network_obs(h, node_agent_matrix) 475 | 476 | def _update_node_states(self, x, mask): 477 | batch_size, n_nodes, feature_dim = x.shape 478 | x = x.reshape(batch_size * n_nodes, -1) 479 | 480 | if self.state is None: 481 | # initialize state 482 | self.state = torch.zeros( 483 | (batch_size, 
n_nodes, self.state_size), device=x.device 484 | ) 485 | 486 | self._state_reshape_in(batch_size, n_nodes) 487 | 488 | # step (1): encode observation to get h^0_v and combine with state 489 | h = self.encode(x) 490 | if self.rnn_type == "lstm" or self.rnn_type == "lnlstm": 491 | h0, cx0 = self.rnn_obs(h, (self.state[0], self.state[1])) 492 | h, cx = h0, cx0 493 | elif self.rnn_type == "gru": 494 | h0 = self.rnn_obs(h, self.state[0]) 495 | h = h0 496 | 497 | # message passing iterations 498 | if self.iterations <= 0 and self.output_neighbor_hidden: 499 | last_neighbor_h = torch.zeros_like(h, device=h.device) 500 | else: 501 | last_neighbor_h = None 502 | 503 | if self.aggregation_def_type != 0: 504 | mask_sparse, mask_weights = dense_to_sparse(mask) 505 | 506 | if self.aggregation_def_type == 2: 507 | H, C = self.state[0], self.state[1] 508 | 509 | for it in range(self.iterations): 510 | if self.output_neighbor_hidden and it == self.iterations - 1: 511 | if self.aggregation_def_type == 2: 512 | # we know that the aggregation step will exchange the hidden states 513 | # (and much more..) so we can just use them for the skip connection 514 | # instead of the other nodes' input. 515 | # This is only relevant for a single iteration per step. 
516 | last_neighbor_h = H 517 | else: 518 | # use the last received hidden state 519 | last_neighbor_h = h 520 | 521 | # step (2): aggregate 522 | if self.aggregation_def_type == 0: 523 | M = self.aggregate(h.view(batch_size, n_nodes, -1), mask).view( 524 | batch_size * n_nodes, -1 525 | ) 526 | elif self.aggregation_def_type == 1: 527 | M = self.aggregate(h, mask_sparse) 528 | elif self.aggregation_def_type == 2: 529 | H, C = self.aggregate(h, mask_sparse, H=H, C=C) 530 | M = H 531 | elif self.aggregation_def_type == 3: 532 | # overwrite last_neighbor_h with jumping knowledge output 533 | M, last_neighbor_h = self.aggregate(h, mask_sparse) 534 | 535 | # step (3): update 536 | # 23.03.23 GRU significantly better at regression task than simple update 537 | if self.rnn_type == "lstm" or self.rnn_type == "lnlstm": 538 | if not self.rnn_carryover and it == 0: 539 | rnn_input = (self.state[2], self.state[3]) 540 | else: 541 | rnn_input = (h, cx) 542 | 543 | h1, cx1 = self.rnn_update(M, rnn_input) 544 | h, cx = h1, cx1 545 | elif self.rnn_type == "gru": 546 | if not self.rnn_carryover and it == 0: 547 | rnn_input = self.state[1] 548 | else: 549 | rnn_input = h 550 | 551 | h1 = self.rnn_update(M, rnn_input) 552 | h = h1 553 | else: 554 | h = M 555 | 556 | # reshape 557 | if last_neighbor_h is not None: 558 | last_neighbor_h = last_neighbor_h.reshape(batch_size, n_nodes, -1) 559 | h = h.reshape(batch_size, n_nodes, -1) 560 | 561 | # update internal state 562 | if self.rnn_type == "lstm" or self.rnn_type == "lnlstm": 563 | if self.rnn_carryover: 564 | self.state = torch.stack((h1, cx1)) 565 | else: 566 | self.state = torch.stack((h0, cx0, h1, cx1)) 567 | elif self.rnn_type == "gru": 568 | if self.rnn_carryover: 569 | self.state = h1.unsqueeze(0) 570 | else: 571 | self.state = torch.stack((h0.unsqueeze(0), h1.unsqueeze(0))) 572 | elif self.rnn_type == "gconvlstm": 573 | self.state = torch.stack((H, C)) 574 | elif self.rnn_type == "none": 575 | # store last node state for 
debugging and aux loss 576 | self.state = h.unsqueeze(0) 577 | 578 | self._state_reshape_out(batch_size, n_nodes) 579 | 580 | return h, last_neighbor_h 581 | 582 | def _get_neighbor_h(self, neighbor_h, mask, max_degree): 583 | batch_size, n_nodes, _ = neighbor_h.shape 584 | # return own hidden state + last received neighbor hidden states ordered 585 | # by node ids 586 | 587 | # get max node id for dense observation tensor (excluding self) 588 | if max_degree is None: 589 | max_degree = torch.sum(mask, dim=-1).max().long().item() - 1 590 | 591 | # placeholder for observations for each neighbor 592 | h_neighbors = torch.zeros( 593 | (batch_size, n_nodes, max_degree, neighbor_h.shape[-1]), 594 | device=neighbor_h.device, 595 | ) 596 | 597 | # get mask without self (only containing neighbors) 598 | neighbor_mask = mask * ~( 599 | torch.eye(n_nodes, n_nodes, device=mask.device) 600 | .unsqueeze(0) 601 | .repeat(mask.shape[0], 1, 1) 602 | .bool() 603 | ) 604 | 605 | # we want to collect features from neighbors and put them into h_neighbors 606 | # 1) get the neighbor node indices (batch, node, neighbor) 607 | h_index = neighbor_mask.nonzero() 608 | 609 | # 2) get the relative neighbor id for the insertion in h_neighbors 610 | # first neighbor (with lowest node id) is neighbor 0, then the ids increase 611 | cumulative_neighbor_index = neighbor_mask.cumsum(dim=-1).long() - 1 612 | h_neighbors_index = cumulative_neighbor_index[ 613 | h_index[:, 0], h_index[:, 1], h_index[:, 2] 614 | ] 615 | 616 | # 3) copy the last hidden states of all neighbors into the h_neighbors tensor 617 | h_neighbors[h_index[:, 0], h_index[:, 1], h_neighbors_index] = neighbor_h[ 618 | h_index[:, 0], h_index[:, 2] 619 | ] 620 | 621 | # concatenate info for each node 622 | return h_neighbors.reshape(batch_size, n_nodes, -1) 623 | 624 | def _get_global_h(self, h): 625 | _, n_nodes, _ = h.shape 626 | global_h = h.mean(dim=1).repeat((n_nodes, 1, 1)).transpose(0, 1) 627 | return global_h 628 | 629 | 
@staticmethod 630 | def output_to_network_obs(netmon_out, node_agent_matrix): 631 | return torch.bmm(netmon_out.transpose(1, 2), node_agent_matrix).transpose(1, 2) 632 | 633 | def summarize(self, *args): 634 | str_out = [] 635 | str_out.append("NetMon Module") 636 | str_out.append(summary(self, *args, max_depth=10)) 637 | self.state = None 638 | str_out.append(f"> Aggregation Type: {self.agg_type_str}") 639 | str_out.append(f"> RNN Type: {self.rnn_type}") 640 | str_out.append(f"> Carryover: {self.rnn_carryover}") 641 | str_out.append(f"> Iterations: {self.iterations}") 642 | readout_str = "> Readout: local" 643 | if self.output_neighbor_hidden: 644 | readout_str += " + last neighbors" 645 | if self.output_global_hidden: 646 | readout_str += " + global agg" 647 | str_out.append(readout_str) 648 | import os 649 | 650 | return os.linesep.join(str_out) 651 | 652 | 653 | class DQNR(nn.Module): 654 | """ 655 | Recurrent DQN with an lstm cell. 656 | """ 657 | 658 | def __init__(self, in_features, mlp_units, num_actions, activation_fn): 659 | super(DQNR, self).__init__() 660 | self.encoder = MLP(in_features, mlp_units, activation_fn) 661 | self.lstm = nn.LSTMCell( 662 | input_size=self.encoder.out_features, hidden_size=self.encoder.out_features 663 | ) 664 | self.state = None 665 | self.q_net = Q_Net(self.encoder.out_features, num_actions) 666 | 667 | def get_state_len(self): 668 | return 2 * self.lstm.hidden_size 669 | 670 | def _state_reshape_in(self, batch_size, n_agents): 671 | """ 672 | Reshapes the state of shape 673 | (batch_size, n_agents, self.get_state_len()) 674 | to shape 675 | (2, batch_size * n_agents, hidden_size). 
676 | 677 | :param batch_size: the batch size 678 | :param n_agents: the number of agents 679 | """ 680 | self.state = ( 681 | self.state.reshape( 682 | batch_size * n_agents, 683 | 2, 684 | self.lstm.hidden_size, 685 | ) 686 | .transpose(0, 1) 687 | .contiguous() 688 | ) 689 | 690 | def _state_reshape_out(self, batch_size, n_agents): 691 | """ 692 | Reshapes the state of shape 693 | (2, batch_size * n_agents, hidden_size) 694 | to shape 695 | (batch_size, n_agents, self.get_state_len()). 696 | 697 | :param batch_size: the batch size 698 | :param n_agents: the number of agents 699 | """ 700 | self.state = self.state.transpose(0, 1).reshape(batch_size, n_agents, -1) 701 | 702 | def _lstm_forward(self, x, reshape_state=True): 703 | """ 704 | A single lstm forward pass 705 | 706 | :param x: Cell input 707 | :param reshape_state: reshape the state to and from (batch_size, n_agents, -1) 708 | """ 709 | batch_size, n_agents, feature_dim = x.shape 710 | # combine agent and batch dimension 711 | x = x.view(batch_size * n_agents, -1) 712 | 713 | if self.state is None: 714 | lstm_hidden_state, lstm_cell_state = self.lstm(x) 715 | else: 716 | if reshape_state: 717 | self._state_reshape_in(batch_size, n_agents) 718 | lstm_hidden_state, lstm_cell_state = self.lstm( 719 | x, (self.state[0], self.state[1]) 720 | ) 721 | 722 | self.state = torch.stack((lstm_hidden_state, lstm_cell_state)) 723 | x = lstm_hidden_state 724 | 725 | # undo combine 726 | x = x.view(batch_size, n_agents, -1) 727 | if reshape_state: 728 | self._state_reshape_out(batch_size, n_agents) 729 | 730 | return x 731 | 732 | def forward(self, x, mask): 733 | h = self.encoder(x) 734 | h = self._lstm_forward(h) 735 | return self.q_net(h) 736 | 737 | 738 | class CommNet(DQNR): 739 | """ 740 | Implementation of CommNet https://arxiv.org/abs/1605.07736 with masked communication 741 | between agents. 
742 | 743 | While the hidden state is aggregated over the neighbors during communication, the 744 | individual cell states stay the same. This is how IC3Net implemented CommNet 745 | https://github.com/IC3Net/IC3Net. The CommNet paper does not elaborate on if and how 746 | the cell states are combined. 747 | """ 748 | 749 | def __init__( 750 | self, 751 | in_features, 752 | mlp_units, 753 | num_actions, 754 | comm_rounds, 755 | activation_fn, 756 | ): 757 | super().__init__(in_features, mlp_units, num_actions, activation_fn) 758 | assert comm_rounds >= 0 759 | self.comm_rounds = comm_rounds 760 | 761 | def forward(self, x, mask): 762 | batch_size, n_agents, feature_dim = x.shape 763 | h = self.encoder(x) 764 | 765 | # manually reshape state 766 | if self.state is not None: 767 | self._state_reshape_in(batch_size, n_agents) 768 | 769 | h = self._lstm_forward(h, reshape_state=False) 770 | 771 | # explicitly exclude self-communication from mask 772 | mask = mask * ~torch.eye(n_agents, dtype=bool, device=x.device).unsqueeze(0) 773 | 774 | for _ in range(self.comm_rounds): 775 | # combine hidden state h according to mask 776 | # first add up hidden states according to mask 777 | # h has dimensions (batch, agents, features) 778 | # and mask has dimensions (batch, agents, neighbors) 779 | # => we have to transpose the mask to aggregate over all neighbors 780 | c = torch.bmm(h.transpose(1, 2), mask.transpose(1, 2)).transpose(1, 2) 781 | # then normalize according to number of neighbors per agent 782 | c = c / torch.clamp(mask.sum(dim=-1).unsqueeze(-1), min=1) 783 | 784 | # skip connection for hidden state and communication 785 | h = h + c 786 | # use new hidden state 787 | self.state[0] = h.view(batch_size * n_agents, -1) 788 | 789 | # pass through forward module 790 | h = self._lstm_forward(h, reshape_state=False) 791 | 792 | # manually reshape state in the end 793 | self._state_reshape_out(batch_size, n_agents) 794 | return self.q_net(h) 795 | 
--------------------------------------------------------------------------------