├── .github ├── FUNDING.yml ├── ISSUE_TEMPLATE │ ├── bug_report.md │ └── feature_request.md ├── PULL_REQUEST_TEMPLATE.md ├── dependabot.yml ├── init.sh ├── release_message.sh ├── rename_project.sh └── workflows │ ├── main.yml │ ├── release.yml │ └── rename_project.yml ├── .gitignore ├── CONTRIBUTING.md ├── Containerfile ├── HISTORY.md ├── LICENSE ├── MANIFEST.in ├── Makefile ├── README.md ├── docs └── index.md ├── experiment ├── __init__.py ├── batch_run.sh ├── bologna_clean │ ├── clean.net.xml │ ├── e1_output.xml │ ├── joined.rou.xml │ ├── joined_detectors.add.xml │ ├── joined_tls.add.xml │ ├── joined_vtypes.add.xml │ ├── run.sumocfg │ ├── sumo_log.txt │ ├── tripinfo.xml │ └── tripinfos.xml ├── gen_data.sh ├── sumo_env.py ├── sumo_exp.py └── traci_tls │ ├── cross.con.xml │ ├── cross.det.xml │ ├── cross.edg.xml │ ├── cross.net.xml │ ├── cross.nod.xml │ ├── cross.out │ ├── cross.rou.xml │ ├── cross.rou1.xml │ ├── cross.rou2.xml │ ├── cross.rou3.xml │ ├── cross.rou4.xml │ ├── cross.rou5.xml │ ├── run.sumocfg │ ├── run.sumocfg1 │ ├── run.sumocfg2 │ ├── run.sumocfg3 │ ├── run.sumocfg4 │ ├── run.sumocfg5 │ └── tripinfo.xml ├── mkdocs.yml ├── requirements.txt ├── setup.py ├── tests ├── __init__.py ├── conftest.py ├── test_base.py ├── test_graph_process.py ├── test_sumo_env.py └── test_tsim_rule.py └── transworld ├── VERSION ├── __init__.py ├── base.py ├── finetune.py ├── game ├── VERSION ├── __init__.py ├── __main__.py ├── base.py ├── cli.py ├── core │ ├── __init__.py │ ├── config.py │ ├── controller.py │ ├── init.py │ ├── node.py │ └── typing.py ├── data │ ├── __init__.py │ ├── dataloader.py │ └── dataset.py ├── graph │ ├── __init__.py │ └── graph.py ├── model │ ├── HGT.py │ ├── __init__.py │ ├── generator.py │ ├── linear.py │ └── loss.py └── operator │ ├── __init__.py │ └── transform.py ├── graph ├── __init__.py ├── load.py └── process.py ├── rules ├── __init__.py ├── post_process.py ├── post_process_old.py ├── pre_process.py └── pre_process_old.py ├── run.sh └── transworld_exp.py /.github/FUNDING.yml: -------------------------------------------------------------------------------- 1 | # These are supported funding model platforms 2 | 3 | github: [rochacbruno] 4 | patreon: # Replace with a single Patreon username 5 | open_collective: # Replace with a single Open Collective username 6 | ko_fi: # Replace with a single Ko-fi username 7 | tidelift: # Replace with a single Tidelift platform-name/package-name e.g., npm/babel 8 | community_bridge: # Replace with a single Community Bridge project-name e.g., cloud-foundry 9 | liberapay: # Replace with a single Liberapay username 10 | issuehunt: # Replace with a single IssueHunt username 11 | otechie: # Replace with a single Otechie username 12 | custom: # Replace with up to 4 custom sponsorship URLs e.g., ['link1', 'link2'] 13 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/bug_report.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Bug report 3 | about: Create a report to help us improve 4 | title: '' 5 | labels: bug, help wanted 6 | assignees: '' 7 | 8 | --- 9 | 10 | **Describe the bug** 11 | A clear and concise description of what the bug is. 12 | 13 | **To Reproduce** 14 | Steps to reproduce the behavior: 15 | 1. Go to '...' 16 | 2. Click on '....' 17 | 3. Scroll down to '....' 18 | 4. See error 19 | 20 | **Expected behavior** 21 | A clear and concise description of what you expected to happen. 
22 | 23 | **Screenshots** 24 | If applicable, add screenshots to help explain your problem. 25 | 26 | **Desktop (please complete the following information):** 27 | - OS: [e.g. iOS] 28 | - Version [e.g. 22] 29 | 30 | **Additional context** 31 | Add any other context about the problem here. 32 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/feature_request.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Feature request 3 | about: Suggest an idea for this project 4 | title: '' 5 | labels: enhancement, question 6 | assignees: '' 7 | 8 | --- 9 | 10 | **Is your feature request related to a problem? Please describe.** 11 | A clear and concise description of what the problem is. Ex. I'm always frustrated when [...] 12 | 13 | **Describe the solution you'd like** 14 | A clear and concise description of what you want to happen. 15 | 16 | **Describe alternatives you've considered** 17 | A clear and concise description of any alternative solutions or features you've considered. 18 | 19 | **Additional context** 20 | Add any other context or screenshots about the feature request here. 21 | -------------------------------------------------------------------------------- /.github/PULL_REQUEST_TEMPLATE.md: -------------------------------------------------------------------------------- 1 | ### Summary :memo: 2 | _Write an overview about it._ 3 | 4 | ### Details 5 | _Describe more what you did on changes._ 6 | 1. (...) 7 | 2. (...) 8 | 9 | ### Bugfixes :bug: (delete if dind't have any) 10 | - 11 | 12 | ### Checks 13 | - [ ] Closed #798 14 | - [ ] Tested Changes 15 | - [ ] Stakeholder Approval 16 | -------------------------------------------------------------------------------- /.github/dependabot.yml: -------------------------------------------------------------------------------- 1 | version: 2 2 | updates: 3 | - package-ecosystem: "github-actions" 4 | directory: "/" 5 | schedule: 6 | interval: "weekly" -------------------------------------------------------------------------------- /.github/init.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | overwrite_template_dir=0 3 | 4 | while getopts t:o flag 5 | do 6 | case "${flag}" in 7 | t) template=${OPTARG};; 8 | o) overwrite_template_dir=1;; 9 | esac 10 | done 11 | 12 | if [ -z "${template}" ]; then 13 | echo "Available templates: flask" 14 | read -p "Enter template name: " template 15 | fi 16 | 17 | repo_urlname=$(basename -s .git `git config --get remote.origin.url`) 18 | repo_name=$(basename -s .git `git config --get remote.origin.url` | tr '-' '_' | tr '[:upper:]' '[:lower:]') 19 | repo_owner=$(git config --get remote.origin.url | awk -F ':' '{print $2}' | awk -F '/' '{print $1}') 20 | echo "Repo name: ${repo_name}" 21 | echo "Repo owner: ${repo_owner}" 22 | echo "Repo urlname: ${repo_urlname}" 23 | 24 | if [ -f ".github/workflows/rename_project.yml" ]; then 25 | .github/rename_project.sh -a "${repo_owner}" -n "${repo_name}" -u "${repo_urlname}" -d "Awesome ${repo_name} created by ${repo_owner}" 26 | fi 27 | 28 | function download_template { 29 | rm -rf "${template_dir}" 30 | mkdir -p .github/templates 31 | git clone "${template_url}" "${template_dir}" 32 | } 33 | 34 | echo "Using template:${template}" 35 | template_url="https://github.com/rochacbruno/${template}-project-template" 36 | template_dir=".github/templates/${template}" 37 | if [ -d "${template_dir}" ]; then 38 | # 
Template directory already exists 39 | if [ "${overwrite_template_dir}" -eq 1 ]; then 40 | # user passed -o flag, delete and re-download 41 | echo "Overwriting ${template_dir}" 42 | download_template 43 | else 44 | # Ask user if they want to overwrite 45 | echo "Directory ${template_dir} already exists." 46 | read -p "Do you want to overwrite it? [y/N] " -n 1 -r 47 | echo 48 | if [[ $REPLY =~ ^[Yy]$ ]]; then 49 | echo "Overwriting ${template_dir}" 50 | download_template 51 | else 52 | # User decided not to overwrite 53 | echo "Using existing ${template_dir}" 54 | fi 55 | fi 56 | else 57 | # Template directory does not exist, download it 58 | echo "Downloading ${template_url}" 59 | download_template 60 | fi 61 | 62 | echo "Applying ${template} template to this project" 63 | ./.github/templates/${template}/apply.sh -a "${repo_owner}" -n "${repo_name}" -u "${repo_urlname}" -d "Awesome ${repo_name} created by ${repo_owner}" 64 | 65 | # echo "Removing temporary template files" 66 | # rm -rf .github/templates/${template} 67 | 68 | echo "Done! Review, commit and push the changes." 69 | -------------------------------------------------------------------------------- /.github/release_message.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | previous_tag=$(git tag --sort=-creatordate | sed -n 2p) 3 | git shortlog "${previous_tag}.." | sed 's/^./ &/' 4 | -------------------------------------------------------------------------------- /.github/rename_project.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | while getopts a:n:u:d: flag 3 | do 4 | case "${flag}" in 5 | a) author=${OPTARG};; 6 | n) name=${OPTARG};; 7 | u) urlname=${OPTARG};; 8 | d) description=${OPTARG};; 9 | esac 10 | done 11 | 12 | echo "Author: $author"; 13 | echo "Project Name: $name"; 14 | echo "Project URL name: $urlname"; 15 | echo "Description: $description"; 16 | 17 | echo "Renaming project..." 18 | 19 | original_author="PJSAC" 20 | original_name="transworld" 21 | original_urlname="TransWorld" 22 | original_description="Awesome transworld created by PJSAC" 23 | # for filename in $(find . -name "*.*") 24 | for filename in $(git ls-files) 25 | do 26 | sed -i "s/$original_author/$author/g" $filename 27 | sed -i "s/$original_name/$name/g" $filename 28 | sed -i "s/$original_urlname/$urlname/g" $filename 29 | sed -i "s/$original_description/$description/g" $filename 30 | echo "Renamed $filename" 31 | done 32 | 33 | mv transworld $name 34 | 35 | # This command runs only once on GHA!
36 | rm -rf .github/template.yml 37 | -------------------------------------------------------------------------------- /.github/workflows/main.yml: -------------------------------------------------------------------------------- 1 | # This is a basic workflow to help you get started with Actions 2 | 3 | name: CI 4 | 5 | # Controls when the workflow will run 6 | on: 7 | # Triggers the workflow on push or pull request events but only for the main branch 8 | push: 9 | branches: [ main ] 10 | pull_request: 11 | branches: [ main ] 12 | 13 | # Allows you to run this workflow manually from the Actions tab 14 | workflow_dispatch: 15 | 16 | jobs: 17 | linter: 18 | strategy: 19 | fail-fast: false 20 | matrix: 21 | python-version: [3.9] 22 | os: [ubuntu-latest] 23 | runs-on: ${{ matrix.os }} 24 | steps: 25 | - uses: actions/checkout@v3 26 | - uses: actions/setup-python@v4 27 | with: 28 | python-version: ${{ matrix.python-version }} 29 | - name: Install project 30 | run: make install 31 | - name: Run linter 32 | run: make lint 33 | 34 | tests_linux: 35 | needs: linter 36 | strategy: 37 | fail-fast: false 38 | matrix: 39 | python-version: [3.9] 40 | os: [ubuntu-latest] 41 | runs-on: ${{ matrix.os }} 42 | steps: 43 | - uses: actions/checkout@v3 44 | - uses: actions/setup-python@v4 45 | with: 46 | python-version: ${{ matrix.python-version }} 47 | - name: Install project 48 | run: make install 49 | - name: Run tests 50 | run: make test 51 | - name: "Upload coverage to Codecov" 52 | uses: codecov/codecov-action@v3 53 | # with: 54 | # fail_ci_if_error: true 55 | 56 | tests_mac: 57 | needs: linter 58 | strategy: 59 | fail-fast: false 60 | matrix: 61 | python-version: [3.9] 62 | os: [macos-latest] 63 | runs-on: ${{ matrix.os }} 64 | steps: 65 | - uses: actions/checkout@v3 66 | - uses: actions/setup-python@v4 67 | with: 68 | python-version: ${{ matrix.python-version }} 69 | - name: Install project 70 | run: make install 71 | - name: Run tests 72 | run: make test 73 | 74 | tests_win: 75 | needs: linter 76 | strategy: 77 | fail-fast: false 78 | matrix: 79 | python-version: [3.9] 80 | os: [windows-latest] 81 | runs-on: ${{ matrix.os }} 82 | steps: 83 | - uses: actions/checkout@v3 84 | - uses: actions/setup-python@v4 85 | with: 86 | python-version: ${{ matrix.python-version }} 87 | - name: Install Pip 88 | run: pip install --user --upgrade pip 89 | - name: Install project 90 | run: pip install -e .[test] 91 | - name: run tests 92 | run: pytest -s -vvvv -l --tb=long tests 93 | -------------------------------------------------------------------------------- /.github/workflows/release.yml: -------------------------------------------------------------------------------- 1 | name: Upload Python Package 2 | 3 | on: 4 | push: 5 | # Sequence of patterns matched against refs/tags 6 | tags: 7 | - '*' # Push events to matching v*, i.e. 
v1.0, v20.15.10 8 | 9 | # Allows you to run this workflow manually from the Actions tab 10 | workflow_dispatch: 11 | 12 | jobs: 13 | release: 14 | name: Create Release 15 | runs-on: ubuntu-latest 16 | permissions: 17 | contents: write 18 | steps: 19 | - uses: actions/checkout@v3 20 | with: 21 | # by default, it uses a depth of 1 22 | # this fetches all history so that we can read each commit 23 | fetch-depth: 0 24 | - name: Generate Changelog 25 | run: .github/release_message.sh > release_message.md 26 | - name: Release 27 | uses: softprops/action-gh-release@v1 28 | with: 29 | body_path: release_message.md 30 | 31 | deploy: 32 | needs: release 33 | runs-on: ubuntu-latest 34 | steps: 35 | - uses: actions/checkout@v3 36 | - name: Set up Python 37 | uses: actions/setup-python@v4 38 | with: 39 | python-version: '3.x' 40 | - name: Install dependencies 41 | run: | 42 | python -m pip install --upgrade pip 43 | pip install setuptools wheel twine 44 | - name: Build and publish 45 | env: 46 | TWINE_USERNAME: __token__ 47 | TWINE_PASSWORD: ${{ secrets.PYPI_API_TOKEN }} 48 | run: | 49 | python setup.py sdist bdist_wheel 50 | twine upload dist/* 51 | -------------------------------------------------------------------------------- /.github/workflows/rename_project.yml: -------------------------------------------------------------------------------- 1 | name: Rename the project from template 2 | 3 | on: [push] 4 | 5 | permissions: write-all 6 | 7 | jobs: 8 | rename-project: 9 | if: ${{ !contains (github.repository, '/python-project-template') }} 10 | runs-on: ubuntu-latest 11 | steps: 12 | - uses: actions/checkout@v3 13 | with: 14 | # by default, it uses a depth of 1 15 | # this fetches all history so that we can read each commit 16 | fetch-depth: 0 17 | ref: ${{ github.head_ref }} 18 | 19 | - run: echo "REPOSITORY_NAME=$(echo '${{ github.repository }}' | awk -F '/' '{print $2}' | tr '-' '_' | tr '[:upper:]' '[:lower:]')" >> $GITHUB_ENV 20 | shell: bash 21 | 22 | - run: echo "REPOSITORY_URLNAME=$(echo '${{ github.repository }}' | awk -F '/' '{print $2}')" >> $GITHUB_ENV 23 | shell: bash 24 | 25 | - run: echo "REPOSITORY_OWNER=$(echo '${{ github.repository }}' | awk -F '/' '{print $1}')" >> $GITHUB_ENV 26 | shell: bash 27 | 28 | - name: Is this still a template 29 | id: is_template 30 | run: echo "::set-output name=is_template::$(ls .github/template.yml &> /dev/null && echo true || echo false)" 31 | 32 | - name: Rename the project 33 | if: steps.is_template.outputs.is_template == 'true' 34 | run: | 35 | echo "Renaming the project with -a(author) ${{ env.REPOSITORY_OWNER }} -n(name) ${{ env.REPOSITORY_NAME }} -u(urlname) ${{ env.REPOSITORY_URLNAME }}" 36 | .github/rename_project.sh -a ${{ env.REPOSITORY_OWNER }} -n ${{ env.REPOSITORY_NAME }} -u ${{ env.REPOSITORY_URLNAME }} -d "Awesome ${{ env.REPOSITORY_NAME }} created by ${{ env.REPOSITORY_OWNER }}" 37 | 38 | - uses: stefanzweifel/git-auto-commit-action@v4 39 | with: 40 | commit_message: "✅ Ready to clone and code." 
41 | # commit_options: '--amend --no-edit' 42 | push_options: --force 43 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | 131 | # templates 132 | .github/templates/* 133 | 134 | /experiment/traci_tls/data 135 | /experiment/bologna_clean/data -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # How to develop on this project 2 | 3 | transworld welcomes contributions from the community. 4 | 5 | **You need PYTHON3!** 6 | 7 | This instructions are for linux base systems. (Linux, MacOS, BSD, etc.) 8 | ## Setting up your own fork of this repo. 9 | 10 | - On github interface click on `Fork` button. 11 | - Clone your fork of this repo. 
`git clone git@github.com:YOUR_GIT_USERNAME/TransWorld.git` 12 | - Enter the directory `cd TransWorld` 13 | - Add the upstream repo `git remote add upstream https://github.com/PJSAC/TransWorld` 14 | 15 | ## Setting up your own virtual environment 16 | 17 | Run `make virtualenv` to create a virtual environment. 18 | Then activate it with `source .venv/bin/activate`. 19 | 20 | ## Install the project in develop mode 21 | 22 | Run `make install` to install the project in develop mode. 23 | 24 | ## Run the tests to ensure everything is working 25 | 26 | Run `make test` to run the tests. 27 | 28 | ## Create a new branch to work on your contribution 29 | 30 | Run `git checkout -b my_contribution` 31 | 32 | ## Make your changes 33 | 34 | Edit the files using your preferred editor. (We recommend Vim or VSCode.) 35 | 36 | ## Format the code 37 | 38 | Run `make fmt` to format the code. 39 | 40 | ## Run the linter 41 | 42 | Run `make lint` to run the linter. 43 | 44 | ## Test your changes 45 | 46 | Run `make test` to run the tests. 47 | 48 | Ensure the code coverage report shows `100%` coverage; add tests to your PR if needed. 49 | 50 | ## Build the docs locally 51 | 52 | Run `make docs` to build the docs. 53 | 54 | Ensure your new changes are documented. 55 | 56 | ## Commit your changes 57 | 58 | This project uses [conventional git commit messages](https://www.conventionalcommits.org/en/v1.0.0/). 59 | 60 | Example: `fix(package): update setup.py arguments 🎉` (emojis are fine too) 61 | 62 | ## Push your changes to your fork 63 | 64 | Run `git push origin my_contribution` 65 | 66 | ## Submit a pull request 67 | 68 | On the GitHub interface, click on the `Pull Request` button. 69 | 70 | Wait for CI to run, and one of the developers will review your PR. 71 | ## Makefile utilities 72 | 73 | This project comes with a `Makefile` that contains a number of useful utilities. 74 | 75 | ```bash 76 | ❯ make 77 | Usage: make <target> 78 | 79 | Targets: 80 | help: ## Show the help. 81 | install: ## Install the project in dev mode. 82 | fmt: ## Format code using black & isort. 83 | lint: ## Run pep8, black, mypy linters. 84 | test: lint ## Run tests and generate coverage report. 85 | watch: ## Run tests on every change. 86 | clean: ## Clean unused files. 87 | virtualenv: ## Create a virtual environment. 88 | release: ## Create a new tag for release. 89 | docs: ## Build the documentation. 90 | switch-to-poetry: ## Switch to poetry package manager. 91 | init: ## Initialize the project based on an application template. 92 | ``` 93 | 94 | ## Making a new release 95 | 96 | This project uses [semantic versioning](https://semver.org/) and tags releases with `X.Y.Z`. 97 | Every time a new tag is created and pushed to the remote repo, GitHub Actions will 98 | automatically create a new release on GitHub and trigger a release on PyPI. 99 | 100 | For this to work you need to set up a secret called `PYPI_API_TOKEN` under the project settings > secrets; 101 | this token can be generated on [pypi.org](https://pypi.org/account/). 102 | 103 | To trigger a new release, all you need to do is: 104 | 105 | 1. If you have changes to add to the repo: 106 | * Make your changes following the steps described above. 107 | * Commit your changes following the [conventional git commit messages](https://www.conventionalcommits.org/en/v1.0.0/). 108 | 2. Run the tests to ensure everything is working. 109 | 3. Run `make release` to create a new tag and push it to the remote repo. 110 | 111 | `make release` will ask you for the version number to use for the tag, e.g. type `0.1.1` when you are asked.
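For reference, this is roughly what the `release` target runs under the hood (a sketch based on the `release` recipe in this repository's `Makefile`; prefer running `make release` itself rather than typing these commands by hand):

```bash
# Approximate steps performed by `make release` (see the Makefile's release target).
# Assumes you are on the branch you want to release and gitchangelog is installed.
read -p "Version? (provide the next x.y.z semver) : " TAG
echo "${TAG}" > transworld/VERSION          # bump the version file
gitchangelog > HISTORY.md                   # regenerate the changelog
git add transworld/VERSION HISTORY.md
git commit -m "release: version ${TAG} 🚀"
git tag "${TAG}"
git push -u origin HEAD --tags              # the pushed tag triggers the release workflow on GitHub Actions
```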
112 | 113 | > **CAUTION**: The make release will change local changelog files and commit all the unstaged changes you have. 114 | -------------------------------------------------------------------------------- /Containerfile: -------------------------------------------------------------------------------- 1 | FROM python:3.7-alpine 2 | COPY . /app 3 | WORKDIR /app 4 | RUN pip install . 5 | CMD ["transworld"] 6 | -------------------------------------------------------------------------------- /HISTORY.md: -------------------------------------------------------------------------------- 1 | Changelog 2 | ========= 3 | 4 | 5 | 0.1.2 (2021-08-14) 6 | ------------------ 7 | - Fix release, README and windows CI. [Bruno Rocha] 8 | - Release: version 0.1.0. [Bruno Rocha] 9 | 10 | 11 | 0.1.0 (2021-08-14) 12 | ------------------ 13 | - Add release command. [Bruno Rocha] 14 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | This is free and unencumbered software released into the public domain. 2 | 3 | Anyone is free to copy, modify, publish, use, compile, sell, or 4 | distribute this software, either in source code form or as a compiled 5 | binary, for any purpose, commercial or non-commercial, and by any 6 | means. 7 | 8 | In jurisdictions that recognize copyright laws, the author or authors 9 | of this software dedicate any and all copyright interest in the 10 | software to the public domain. We make this dedication for the benefit 11 | of the public at large and to the detriment of our heirs and 12 | successors. We intend this dedication to be an overt act of 13 | relinquishment in perpetuity of all present and future rights to this 14 | software under copyright law. 15 | 16 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, 17 | EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF 18 | MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 19 | IN NO EVENT SHALL THE AUTHORS BE LIABLE FOR ANY CLAIM, DAMAGES OR 20 | OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, 21 | ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR 22 | OTHER DEALINGS IN THE SOFTWARE. 23 | 24 | For more information, please refer to 25 | -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | include LICENSE 2 | include HISTORY.md 3 | include Containerfile 4 | graft tests 5 | graft transworld 6 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | .ONESHELL: 2 | ENV_PREFIX=$(shell python -c "if __import__('pathlib').Path('.venv/bin/pip').exists(): print('.venv/bin/')") 3 | USING_POETRY=$(shell grep "tool.poetry" pyproject.toml && echo "yes") 4 | 5 | .PHONY: help 6 | help: ## Show the help. 7 | @echo "Usage: make " 8 | @echo "" 9 | @echo "Targets:" 10 | @fgrep "##" Makefile | fgrep -v fgrep 11 | 12 | 13 | .PHONY: show 14 | show: ## Show the current environment. 15 | @echo "Current environment:" 16 | @if [ "$(USING_POETRY)" ]; then poetry env info && exit; fi 17 | @echo "Running using $(ENV_PREFIX)" 18 | @$(ENV_PREFIX)python -V 19 | @$(ENV_PREFIX)python -m site 20 | 21 | .PHONY: install 22 | install: ## Install the project in dev mode. 
23 | @if [ "$(USING_POETRY)" ]; then poetry install && exit; fi 24 | @echo "Don't forget to run 'make virtualenv' if you got errors." 25 | $(ENV_PREFIX)pip install -e .[test] 26 | 27 | .PHONY: fmt 28 | fmt: ## Format code using black & isort. 29 | $(ENV_PREFIX)isort transworld/ 30 | $(ENV_PREFIX)black -l 79 transworld/ 31 | $(ENV_PREFIX)black -l 79 tests/ 32 | 33 | .PHONY: lint 34 | lint: ## Run pep8, black, mypy linters. 35 | $(ENV_PREFIX)flake8 transworld/ 36 | $(ENV_PREFIX)black -l 79 --check transworld/ 37 | $(ENV_PREFIX)black -l 79 --check tests/ 38 | $(ENV_PREFIX)mypy --ignore-missing-imports transworld/ 39 | 40 | .PHONY: test 41 | test: lint ## Run tests and generate coverage report. 42 | $(ENV_PREFIX)pytest -v --cov-config .coveragerc --cov=transworld -l --tb=short --maxfail=1 tests/ 43 | $(ENV_PREFIX)coverage xml 44 | $(ENV_PREFIX)coverage html 45 | 46 | .PHONY: watch 47 | watch: ## Run tests on every change. 48 | ls **/**.py | entr $(ENV_PREFIX)pytest -s -vvv -l --tb=long --maxfail=1 tests/ 49 | 50 | .PHONY: clean 51 | clean: ## Clean unused files. 52 | @find ./ -name '*.pyc' -exec rm -f {} \; 53 | @find ./ -name '__pycache__' -exec rm -rf {} \; 54 | @find ./ -name 'Thumbs.db' -exec rm -f {} \; 55 | @find ./ -name '*~' -exec rm -f {} \; 56 | @rm -rf .cache 57 | @rm -rf .pytest_cache 58 | @rm -rf .mypy_cache 59 | @rm -rf build 60 | @rm -rf dist 61 | @rm -rf *.egg-info 62 | @rm -rf htmlcov 63 | @rm -rf .tox/ 64 | @rm -rf docs/_build 65 | 66 | .PHONY: virtualenv 67 | virtualenv: ## Create a virtual environment. 68 | @if [ "$(USING_POETRY)" ]; then poetry install && exit; fi 69 | @echo "creating virtualenv ..." 70 | @rm -rf .venv 71 | @python3 -m venv .venv 72 | @./.venv/bin/pip install -U pip 73 | @./.venv/bin/pip install -e .[test] 74 | @echo 75 | @echo "!!! Please run 'source .venv/bin/activate' to enable the environment !!!" 76 | 77 | .PHONY: release 78 | release: ## Create a new tag for release. 79 | @echo "WARNING: This operation will create s version tag and push to github" 80 | @read -p "Version? (provide the next x.y.z semver) : " TAG 81 | @echo "$${TAG}" > transworld/VERSION 82 | @$(ENV_PREFIX)gitchangelog > HISTORY.md 83 | @git add transworld/VERSION HISTORY.md 84 | @git commit -m "release: version $${TAG} 🚀" 85 | @echo "creating git tag : $${TAG}" 86 | @git tag $${TAG} 87 | @git push -u origin HEAD --tags 88 | @echo "Github Actions will detect the new tag and release the new version." 89 | 90 | .PHONY: docs 91 | docs: ## Build the documentation. 92 | @echo "building documentation ..." 93 | @$(ENV_PREFIX)mkdocs build 94 | URL="site/index.html"; xdg-open $$URL || sensible-browser $$URL || x-www-browser $$URL || gnome-open $$URL 95 | 96 | .PHONY: switch-to-poetry 97 | switch-to-poetry: ## Switch to poetry package manager. 98 | @echo "Switching to poetry ..." 99 | @if ! 
poetry --version > /dev/null; then echo 'poetry is required, install from https://python-poetry.org/'; exit 1; fi 100 | @rm -rf .venv 101 | @poetry init --no-interaction --name=a_flask_test --author=rochacbruno 102 | @echo "" >> pyproject.toml 103 | @echo "[tool.poetry.scripts]" >> pyproject.toml 104 | @echo "transworld = 'transworld.__main__:main'" >> pyproject.toml 105 | @cat requirements.txt | while read in; do poetry add --no-interaction "$${in}"; done 106 | @cat requirements-test.txt | while read in; do poetry add --no-interaction "$${in}" --dev; done 107 | @poetry install --no-interaction 108 | @mkdir -p .github/backup 109 | @mv requirements* .github/backup 110 | @mv setup.py .github/backup 111 | @echo "You have switched to https://python-poetry.org/ package manager." 112 | @echo "Please run 'poetry shell' or 'poetry run transworld'" 113 | 114 | .PHONY: init 115 | init: ## Initialize the project based on an application template. 116 | @./.github/init.sh 117 | 118 | 119 | # This project has been generated from rochacbruno/python-project-template 120 | # __author__ = 'rochacbruno' 121 | # __repo__ = https://github.com/rochacbruno/python-project-template 122 | # __sponsor__ = https://github.com/sponsors/rochacbruno/ 123 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | 2 | # TransWorldNG 3 | 4 | TransWorldNG: Empower Traffic Simulation via Foundation Model 5 | 6 | TransWorldNG is a cutting-edge traffic simulation model that utilizes a foundation model to accurately model complex traffic behavior and relationships from real-world data. 7 | 8 | Key features include: 9 | 10 | 1. A unified and adaptable data structure for modeling diverse agents and relationships within complex traffic systems. 11 | 2. A heterogeneous graph learning framework that automatically generates behavior models by learning from complex traffic data. 12 | 13 | 14 | # Citing TransWorldNG 15 | 16 | If you use TransWorldNG in your research, please cite the paper. 17 | 18 | Wang, D., Wang, X., Chen, L., Yao, S., Jing, M., Li, H., Li, L., Bao, S., Wang, F.Y. and Lin, Y., 2023. TransWorldNG: Traffic Simulation via Foundation Model. arXiv preprint arXiv:2305.15743. 19 | 20 | In BibTeX format: 21 | 22 | ``` 23 | @article{wang2023transworldng, 24 | title={TransWorldNG: Traffic Simulation via Foundation Model}, 25 | author={Wang, Ding and Wang, Xuhong and Chen, Liang and Yao, Shengyue and Jing, Ming and Li, Honghai and Li, Li and Bao, Shiqiang and Wang, Fei-Yue and Lin, Yilun}, 26 | journal={arXiv preprint arXiv:2305.15743}, 27 | year={2023} 28 | } 29 | ``` 30 | 31 | # Environment 32 | Make sure you have all the necessary dependencies installed before running the above commands. You can install the dependencies by running `pip install -r requirements.txt`. 33 | 34 | 35 | # Getting Started 36 | 37 | ## Generating example data with SUMO 38 | 1. Navigate to transworldNG/experiment/gen_data.sh. 39 | 2. Modify the gen_data.sh file to specify scenarios and parameters. 40 | 3. Run gen_data.sh to generate data: 41 | 42 | ``` 43 | ./gen_data.sh 44 | ``` 45 | 46 | ## Running TransWorldNG simulation 47 | 1. Navigate to transWorldNG/transworld/run.sh 48 | 2. Modify the run.sh file to specify scenarios and parameters. 49 | 3. Run run.sh file. The results will be saved in the original data folder. Alternatively, you can modify the settings in transworld_exp.py. 
50 | ``` 51 | ./run.sh 52 | ``` 53 | 54 |
-------------------------------------------------------------------------------- /docs/index.md: --------------------------------------------------------------------------------
1 | # Welcome to MkDocs 2 | 3 | For full documentation visit [mkdocs.org](https://www.mkdocs.org). 4 | 5 | ## Commands 6 | 7 | * `mkdocs new [dir-name]` - Create a new project. 8 | * `mkdocs serve` - Start the live-reloading docs server. 9 | * `mkdocs build` - Build the documentation site. 10 | * `mkdocs -h` - Print help message and exit. 11 | 12 | ## Project layout 13 | 14 | mkdocs.yml # The configuration file. 15 | docs/ 16 | index.md # The documentation homepage. 17 | ... # Other markdown pages, images and other files. 18 |
-------------------------------------------------------------------------------- /experiment/__init__.py: --------------------------------------------------------------------------------
https://raw.githubusercontent.com/SACLabs/TransWorldNG/906c9283ed8e8121b0650869a4cdef4ab36c3848/experiment/__init__.py
-------------------------------------------------------------------------------- /experiment/batch_run.sh: --------------------------------------------------------------------------------
1 | nohup python -m sumo_exp --scenario_name 'traci_tls' --target_step '100' --rou_file 'cross.rou1.xml' --run_file 'run.sumocfg1' & 2 | nohup python -m sumo_exp --scenario_name 'traci_tls' --target_step '100' --rou_file 'cross.rou2.xml' --run_file 'run.sumocfg2' & 3 | nohup python -m sumo_exp --scenario_name 'traci_tls' --target_step '100' --rou_file 'cross.rou3.xml' --run_file 'run.sumocfg3' & 4 | nohup python -m sumo_exp --scenario_name 'traci_tls' --target_step '100' --rou_file 'cross.rou4.xml' --run_file 'run.sumocfg4' & 5 | nohup python -m sumo_exp --scenario_name 'traci_tls' --target_step '100' --rou_file 'cross.rou5.xml' --run_file 'run.sumocfg5' & 6 | 7 | 8 | 9 | nohup python -m sumo_exp --scenario_name 'traci_tls' --target_step '100' --rou_file 'cross.rou5.xml' --run_file 'run.sumocfg5' & 10 | nohup python -m sumo_exp --scenario_name 'highd' --target_step '500' --rou_file 'route.xml' --run_file 'freeway.sumo.cfg' &
-------------------------------------------------------------------------------- /experiment/bologna_clean/joined_detectors.add.xml: --------------------------------------------------------------------------------
[XML markup not preserved in this export.]
-------------------------------------------------------------------------------- /experiment/bologna_clean/joined_vtypes.add.xml: --------------------------------------------------------------------------------
[XML markup not preserved in this export.]
-------------------------------------------------------------------------------- /experiment/bologna_clean/run.sumocfg: --------------------------------------------------------------------------------
[XML markup not preserved in this export.]
-------------------------------------------------------------------------------- /experiment/bologna_clean/sumo_log.txt: --------------------------------------------------------------------------------
1 | ***Starting server on port 33223 *** 2 | Loading net-file from '/mnt/workspace/wangding/Desktop/tsim/experiment/bologna_clean/clean.net.xml' ... done (12ms). 3 | Loading additional-files from '/mnt/workspace/wangding/Desktop/tsim/experiment/bologna_clean/joined_detectors.add.xml' ... done (7ms). 4 | Loading additional-files from '/mnt/workspace/wangding/Desktop/tsim/experiment/bologna_clean/joined_vtypes.add.xml' ... done (1ms). 5 | Loading additional-files from '/mnt/workspace/wangding/Desktop/tsim/experiment/bologna_clean/joined_tls.add.xml' ... done (1ms). 6 | Loading done. 7 | Simulation version 1.17.0 started with time: 0.00. 8 | Warning: Vehicle 'Costa_200_548' performs emergency braking on lane 'a77cd_0' with decel=9.00, wished=4.50, severity=1.00, time=1480.00. 9 | Simulation ended at time: 4633.00 10 | Reason: TraCI requested termination. 11 | Performance: 12 | Duration: 3578.15s 13 | TraCI-Duration: 3145.59s 14 | Real time factor: 1.2948 15 | UPS: 907.840033 16 | Vehicles: 17 | Inserted: 11000 18 | Running: 0 19 | Waiting: 0 20 | 21 |
-------------------------------------------------------------------------------- /experiment/bologna_clean/tripinfos.xml: --------------------------------------------------------------------------------
[XML markup not preserved in this export.]
-------------------------------------------------------------------------------- /experiment/traci_tls/cross.nod.xml: --------------------------------------------------------------------------------
[XML markup not preserved in this export.]
-------------------------------------------------------------------------------- /experiment/traci_tls/run.sumocfg: --------------------------------------------------------------------------------
[XML markup not preserved in this export.]
-------------------------------------------------------------------------------- /experiment/traci_tls/run.sumocfg1: --------------------------------------------------------------------------------
[XML markup not preserved in this export.]
-------------------------------------------------------------------------------- /experiment/traci_tls/run.sumocfg2: --------------------------------------------------------------------------------
[XML markup not preserved in this export.]
-------------------------------------------------------------------------------- /experiment/traci_tls/run.sumocfg3: --------------------------------------------------------------------------------
[XML markup not preserved in this export.]
-------------------------------------------------------------------------------- /experiment/traci_tls/run.sumocfg4: --------------------------------------------------------------------------------
[XML markup not preserved in this export.]
-------------------------------------------------------------------------------- /experiment/traci_tls/run.sumocfg5: --------------------------------------------------------------------------------
[XML markup not preserved in this export.]
-------------------------------------------------------------------------------- /mkdocs.yml: --------------------------------------------------------------------------------
1 | site_name: transworld 2 | theme: readthedocs
-------------------------------------------------------------------------------- /requirements.txt: --------------------------------------------------------------------------------
1 | dgl==1.0.1+cu116 2 | matplotlib==3.7.1 3 | numpy==1.23.5 4 | pandas==1.5.3 5 | pytest==7.3.1 6 | scikit_learn==1.2.0 7 | setuptools==65.6.3 8 | sumolib==1.17.0 9 | torch==1.12.1 10 | tqdm==4.64.1 11 | traci==1.17.0 12 |
-------------------------------------------------------------------------------- /setup.py: --------------------------------------------------------------------------------
1 | """Python setup.py for transworld package""" 2 | import io 3 | import os 4 | from setuptools import find_packages, setup 5 | 6 | 7 | def read(*paths, **kwargs): 8 | """Read the contents of a text file safely. 9 | >>> read("transworld", "VERSION") 10 | '0.1.0' 11 | >>> read("README.md") 12 | ...
13 | """ 14 | 15 | content = "" 16 | with io.open( 17 | os.path.join(os.path.dirname(__file__), *paths), 18 | encoding=kwargs.get("encoding", "utf8"), 19 | ) as open_file: 20 | content = open_file.read().strip() 21 | return content 22 | 23 | 24 | def read_requirements(path): 25 | return [ 26 | line.strip() 27 | for line in read(path).split("\n") 28 | if not line.startswith(('"', "#", "-", "git+")) 29 | ] 30 | 31 | 32 | setup( 33 | name="transworld", 34 | version=read("transworld", "VERSION"), 35 | description="Awesome transworld created by PJSAC", 36 | url="https://github.com/PJSAC/TransWorld/", 37 | long_description=read("README.md"), 38 | long_description_content_type="text/markdown", 39 | author="PJSAC", 40 | packages=find_packages(exclude=["tests", ".github"]), 41 | install_requires=read_requirements("requirements.txt"), 42 | entry_points={ 43 | "console_scripts": ["transworld = transworld.__main__:main"] 44 | }, 45 | extras_require={"test": read_requirements("requirements-test.txt")}, 46 | ) 47 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SACLabs/TransWorldNG/906c9283ed8e8121b0650869a4cdef4ab36c3848/tests/__init__.py -------------------------------------------------------------------------------- /tests/conftest.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import pytest 3 | 4 | 5 | # each test runs on cwd to its temp dir 6 | @pytest.fixture(autouse=True) 7 | def go_to_tmpdir(request): 8 | # Get the fixture dynamically by its name. 9 | tmpdir = request.getfixturevalue("tmpdir") 10 | # ensure local test created packages can be imported 11 | sys.path.insert(0, str(tmpdir)) 12 | # Chdir only for the duration of the test. 
13 | with tmpdir.as_cwd(): 14 | yield 15 | -------------------------------------------------------------------------------- /tests/test_base.py: -------------------------------------------------------------------------------- 1 | from transworld.base import NAME 2 | 3 | 4 | def test_base(): 5 | assert NAME == "tsim" 6 | -------------------------------------------------------------------------------- /tests/test_graph_process.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | import torch 3 | import pandas as pd 4 | 5 | from transworld.graph.process import ( 6 | generate_unique_node_id, 7 | generate_graph_dict, 8 | generate_feat_dict, 9 | ) 10 | 11 | 12 | class TestGraphProcess(unittest.TestCase): 13 | def setUp(self) -> None: 14 | return super().setUp() 15 | 16 | def test_generate_unique_node_id(self): 17 | node_data = pd.DataFrame({"step": [0], "name": ["node1"], "type": ["test"]}) 18 | target_node_id = {"node1": 0} 19 | node_id = generate_unique_node_id(node_data) 20 | self.assertEqual(node_id, target_node_id) 21 | 22 | def test_generate_graph_dict(self): 23 | edge_data = pd.DataFrame( 24 | { 25 | "step": [0, 1], 26 | "from_id": [0, 1], # veh0, veh1 27 | "to_id": [1, 2], # veh1, lane2 28 | "relation": ["veh_follow_veh", "veh_on_lane"], 29 | } 30 | ) 31 | target_graph_dict = { 32 | ("veh", "follow", "veh"): ( 33 | torch.tensor([0]), 34 | torch.tensor([1]), 35 | torch.tensor([0]), 36 | ), 37 | ("veh", "on", "lane"): ( 38 | torch.tensor([1]), 39 | torch.tensor([2]), 40 | torch.tensor([1]), 41 | ), 42 | } 43 | g_dict = generate_graph_dict(edge_data) 44 | self.assertEqual(g_dict, target_graph_dict) 45 | 46 | def test_generate_feat_dict(self): 47 | feat_data = pd.DataFrame( 48 | { 49 | "step": [0, 1], 50 | "name": ["veh0", "veh1"], 51 | "node_id": [0, 1], 52 | "feat_length": [5, 6], 53 | } 54 | ) 55 | type = "veh" 56 | target_feat_dict = g_feat_dict = { 57 | "veh": { 58 | 0: {"feat_length": torch.tensor([5])}, 59 | 1: {"feat_length": torch.tensor([6])}, 60 | } 61 | } 62 | g_feat_dict = generate_feat_dict(type, feat_data) 63 | self.assertEqual(g_feat_dict, target_feat_dict) 64 | 65 | def tearDown(self) -> None: 66 | return super().tearDown() 67 | -------------------------------------------------------------------------------- /tests/test_sumo_env.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | from experiment.sumo_env import get_node_data, get_edge_data 4 | 5 | 6 | class TestSumoEnv(unittest.TestCase): 7 | def setUp(self) -> None: 8 | return super().setUp() 9 | 10 | def test_get_node_data(self): 11 | step = 0 12 | type = "test" 13 | source_list = ["node1", "node2"] 14 | 15 | operation_list = [str] 16 | # step, node_name, type 17 | target_node_list = [] 18 | target_node_list.append([0, "node1", "test"]) 19 | target_node_list.append([0, "node2", "test"]) 20 | # step,node_name,operation_results 21 | target_feat_list = [] 22 | target_feat_list.append([0, "node1", "node1"]) 23 | target_feat_list.append([0, "node2", "node2"]) 24 | node, feat = get_node_data(step, type, source_list, operation_list) 25 | 26 | self.assertEqual(node, target_node_list) 27 | self.assertEqual(feat, target_feat_list) 28 | 29 | def test_graph_edge_data(self): 30 | step = 0 31 | relation = "veh_follow_veh" 32 | source_list = ["node1", "node2"] 33 | operation_list = [lambda node_id: f"to_{node_id}"] 34 | # step, from, to, relation 35 | target_edge_list = [] 36 | target_edge_list.append([0, "node1", 
"to_node1", "veh_follow_veh"]) 37 | target_edge_list.append([0, "node2", "to_node2", "veh_follow_veh"]) 38 | edge = get_edge_data(step, relation, source_list, operation_list) 39 | self.assertEqual(edge, target_edge_list) 40 | 41 | def tearDown(self) -> None: 42 | return super().tearDown() 43 | -------------------------------------------------------------------------------- /tests/test_tsim_rule.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | import torch 3 | import pandas as pd 4 | 5 | from transworld.graph.process import ( 6 | generate_unique_node_id, 7 | generate_graph_dict, 8 | generate_feat_dict, 9 | ) 10 | 11 | 12 | class TestGraphProcess(unittest.TestCase): 13 | def setUp(self) -> None: 14 | return super().setUp() 15 | 16 | def test_generate_unique_node_id(self): 17 | node_data = pd.DataFrame({"step": [0], "name": ["node1"], "type": ["test"]}) 18 | target_node_id = {"node1": 0} 19 | node_id = generate_unique_node_id(node_data) 20 | self.assertEqual(node_id, target_node_id) 21 | 22 | def tearDown(self) -> None: 23 | return super().tearDown() 24 | -------------------------------------------------------------------------------- /transworld/VERSION: -------------------------------------------------------------------------------- 1 | 0.1.0 2 | -------------------------------------------------------------------------------- /transworld/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SACLabs/TransWorldNG/906c9283ed8e8121b0650869a4cdef4ab36c3848/transworld/__init__.py -------------------------------------------------------------------------------- /transworld/base.py: -------------------------------------------------------------------------------- 1 | """ 2 | tsim base module. 3 | 4 | This is the principal module of the tsim project. 5 | here you put your main classes and objects. 6 | 7 | Be creative! do whatever you want! 8 | 9 | If you want to replace this with a Flask application run: 10 | 11 | $ make init 12 | 13 | and then choose `flask` as template. 
14 | """ 15 | 16 | # example constant variable 17 | NAME = "tsim" 18 | -------------------------------------------------------------------------------- /transworld/finetune.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | #from graph.load import load_graph 3 | from graph.load_hd import load_graph 4 | from graph.process import generate_unique_node_id 5 | from rules.pre_process_old import load_veh_depart, pre_actions 6 | from rules.post_process_old import load_veh_route, post_actions 7 | import random 8 | import torch.nn as nn 9 | from game.model import HGT, RuleBasedGenerator, GraphStateLoss 10 | from game.data import DataLoader, Dataset 11 | from game.graph import Graph 12 | import matplotlib.pyplot as plt 13 | import torch 14 | from game.operator.transform import dgl_graph_to_graph_dict 15 | from tqdm import tqdm 16 | from datetime import datetime 17 | #from eval.eval import struc_dict_to_eval, raw_edge_to_eval, compare_feat, compare_struc 18 | import pickle 19 | import os 20 | import csv 21 | import argparse 22 | import logging 23 | import shutil 24 | import sys 25 | import pandas as pd 26 | from collections import defaultdict 27 | random.seed(3407) 28 | 29 | def print_loss(epoch, i, loss): 30 | with open('loss.csv','a') as loss_log: 31 | train_writer = csv.writer(loss_log) 32 | train_writer.writerow([str(epoch), str(i), str(round(loss,4))]) 33 | 34 | def train(timestamps, graph, batch_size, num_workers, encoder, generator, veh_route, loss_fcn, optimizer, logger, device): 35 | logger.info("========= start generate dataset =======") 36 | train_dataset = Dataset(timestamps, device, train_mode=True) 37 | train_loader = DataLoader(train_dataset, graph.operate, batch_size=batch_size, num_workers=num_workers, drop_last =True) 38 | logger.info("========== finish generate dataset =======") 39 | # graph_dicts = {} 40 | logger.info("========= start training =======") 41 | loss_list = [] 42 | for i, (cur_graphs, next_graphs) in enumerate(train_loader): # 这里for的是时间戳 43 | # 下面这个for循环后面会改成并发操作 44 | if i == 12: 45 | print() 46 | #logger.info(f"========= step{i} =======") 47 | loss = 0. 48 | for ((_, cur_graph), (seed_node_n_time, next_graph)) in zip(cur_graphs.items(), next_graphs.items()): 49 | cur_graph, next_graph = cur_graph.to(device), next_graph.to(device) 50 | assert (_.split("@")[0]) == seed_node_n_time.split("@")[0], ValueError("Dataloader Error! 
node_name not equal") 51 | node_type = seed_node_n_time.split("/")[0] 52 | # if node_type == 'veh': 53 | # continue 54 | time = float(seed_node_n_time.split("@")[1]) 55 | node_repr = encoder(cur_graph) 56 | actions, pred_graph = generator([seed_node_n_time], cur_graph, node_repr, veh_route) 57 | loss = loss + loss_fcn(pred_graph, next_graph) 58 | loss_list.append((loss.item()) / batch_size) 59 | #print_loss(i, loss.item()) 60 | optimizer.zero_grad() 61 | loss.backward() 62 | optimizer.step() 63 | logger.info(f"------------ loss is {sum(loss_list) / len(loss_list)} ---------") 64 | logger.info("========== finished training ==========") 65 | torch.cuda.empty_cache() 66 | return loss_list 67 | 68 | @torch.no_grad() 69 | def eval(graph, batch_size, num_workers, encoder, generator, veh_depart, veh_route, changable_feature_names, hetero_feat_dim, logger, device, training_step, perd_step): 70 | val_timestamp = [float(i) for i in range(int(training_step),int(training_step)+perd_step+1)] 71 | #print(val_timestamp) 72 | val_dataset = Dataset(val_timestamp, device, train_mode=False) 73 | val_loader = DataLoader(val_dataset, graph.operate, batch_size=batch_size, num_workers=num_workers, drop_last =False) 74 | logger.info(f"========= start eval ======= batch_size{batch_size}=======") 75 | for i, cur_graphs in tqdm(enumerate(val_loader)): 76 | val_result = {} 77 | for seed_node_n_time, cur_graph in cur_graphs.items(): 78 | cur_graph = cur_graph.to(device) 79 | node_type = seed_node_n_time.split("/")[0] 80 | if node_type == 'veh': 81 | continue 82 | time = float(seed_node_n_time.split("@")[1]) 83 | actions_pre = pre_actions(veh_depart, time, cur_graph) 84 | node_repr = encoder(cur_graph) 85 | actions, pred_graph = generator( 86 | [seed_node_n_time], cur_graph, node_repr, veh_route 87 | ) 88 | graph.states_to_feature_event(time, changable_feature_names, cur_graph, pred_graph) 89 | graph.actions_to_game_operations(actions_pre) 90 | if actions != {}: 91 | graph.actions_to_game_operations(actions) 92 | print(actions) 93 | 94 | total_graphs = val_loader.collate_tool([max(val_timestamp)+1.], max_step=perd_step) 95 | for seed_node_n_time, total_graph in total_graphs[0].items(): 96 | struc_dict, feat_dict = dgl_graph_to_graph_dict(total_graph, hetero_feat_dim) 97 | val_result[seed_node_n_time] = (struc_dict, feat_dict) 98 | logger.info("========== finished eval ==========") 99 | return val_result 100 | 101 | 102 | def create_folder(folder_path, delete_origin=False): 103 | # 这个函数的作用是,当文件夹存在就删除,然后重新创建一个新的 104 | if not os.path.exists(folder_path): 105 | # shutil.rmtree(folder_path) 106 | os.makedirs(folder_path, exist_ok=True) 107 | else: 108 | if delete_origin: 109 | shutil.rmtree(folder_path) 110 | os.makedirs(folder_path, exist_ok=True) 111 | 112 | def setup_logger(name, log_folder_path, level=logging.DEBUG): 113 | create_folder(log_folder_path) 114 | log_file = log_folder_path /f"{name}_log" 115 | handler = logging.FileHandler(log_file,encoding="utf-8",mode="a") 116 | formatter = logging.Formatter("%(asctime)s,%(msecs)d,%(levelname)s,%(name)s::%(message)s") 117 | handler.setFormatter(formatter) 118 | logger = logging.getLogger(name) 119 | logger.setLevel(level) 120 | logger.addHandler(handler) 121 | stream_handler = logging.StreamHandler(sys.stdout) 122 | stream_handler.setFormatter(formatter) 123 | logger.addHandler(stream_handler) 124 | return logger 125 | 126 | def run(scenario, test_data, training_step, pred_step, hid_dim, n_heads, n_layer, device): 127 | time_diff = [] 128 | #for test in [1,2,3,4,5]: 129 | 
exp_dir = Path(__file__).parent.parent / "HighD" 130 | 131 | exp_setting = exp_dir / "highway02" 132 | preTrainModel = exp_setting / 'preTrainModel' 133 | FineTuneModel = exp_setting / 'FineTuneModel' 134 | 135 | #exp_setting = exp_dir / "bologna_clean" 136 | #exp_setting = exp_dir / scenario 137 | data_dir = exp_setting / "data" 138 | train_data_dir = data_dir 139 | #test_data_dir = data_dir / "test_data" 140 | out_dir = preTrainModel / f"out_dim_{hid_dim}_n_heads_{n_heads}_n_layer_{n_layer}_pred_step_{pred_step}" 141 | 142 | 143 | 144 | name = f"scenario_{scenario}test_data_{test_data}_dim_{hid_dim}_n_heads_{n_heads}_n_layer_{n_layer}" 145 | log_folder_path = out_dir / "Log" 146 | logger = setup_logger(name, log_folder_path) 147 | logger.info(f"========== process {scenario}_{test_data}_{hid_dim}_{n_heads}_{n_layer}_pred_step_{pred_step} is running! ===========" ) 148 | isExist = os.path.exists(out_dir) 149 | if not isExist: 150 | os.makedirs(out_dir) 151 | 152 | node_all = pd.read_csv(train_data_dir / "node_all.csv") 153 | node_id_dict = generate_unique_node_id(node_all) 154 | 155 | veh_depart = load_veh_depart("veh_depart", train_data_dir, training_step) 156 | veh_route = load_veh_route("veh_route", train_data_dir) 157 | logger.info(f"========== finish load route and depart ========") 158 | # init struc_dict, feat_dict, node_id_dict 159 | 160 | struc_dict, feat_dict, node_id_dict, scalers = load_graph(train_data_dir, 0, training_step-1, node_id_dict) 161 | #test_struc, test_feat, node_id_dict, scalers = load_graph(test_data_dir) 162 | logger.info(f"========= finish load graph =========") 163 | #model parameters 164 | n_epochs = 100 #200 165 | batch_size = 50 #100 166 | num_workers = 1 #10 167 | batch_size = max(1, batch_size * num_workers) 168 | lr = 5e-4 169 | hid_dim = hid_dim 170 | n_heads = n_heads 171 | changable_feature_names = ['xAcceleration','yAcceleration'] 172 | graph = Graph(struc_dict, feat_dict) 173 | hetero_feat_dim = graph.hetero_feat_dim 174 | timestamps = graph.timestamps.float().tolist() 175 | 176 | logger.info(f"========= {n_epochs}_{batch_size}_{num_workers} =========") 177 | 178 | encoder = HGT( 179 | in_dim={ 180 | ntype: int(sum(hetero_feat_dim[ntype].values())) 181 | for ntype in hetero_feat_dim.keys() 182 | }, 183 | n_ntypes=graph.num_ntypes, 184 | n_etypes=graph.num_etypes, 185 | hid_dim=hid_dim, 186 | n_layers=n_layer, 187 | n_heads=n_heads, 188 | activation = nn.ReLU() 189 | ).to(device) 190 | 191 | 192 | generator = RuleBasedGenerator( 193 | hetero_feat_dim, 194 | n_heads * hid_dim, 195 | { 196 | ntype: int(sum(hetero_feat_dim[ntype].values())) 197 | for ntype in hetero_feat_dim.keys() 198 | }, 199 | activation = nn.ReLU(), 200 | scalers= scalers, 201 | output_activation = nn.Sigmoid() 202 | ).to(device) 203 | 204 | logger.info("========== finish generate generator rule ==========") 205 | generator.register_rule(post_actions) 206 | 207 | 208 | # Load pre-trained model 209 | encoder_path = preTrainModel / 'encorder.pth' 210 | generator_path = preTrainModel / 'generator.pth' 211 | 212 | encoder.load_state_dict(torch.load(encoder_path)) 213 | generator.load_state_dict(torch.load(generator_path)) 214 | 215 | print("========== Finish load preTrain model ==========") 216 | 217 | # loss_fcn = GraphStateLoss().to(device) 218 | # optimizer = torch.optim.Adam(list(encoder.parameters())+list(generator.parameters()), lr=lr) 219 | 220 | 221 | # # Fine-tune the pre-trained model 222 | # loss_avg = [] 223 | # for ep in tqdm(range(n_epochs)): 224 | # 
logger.info(f"--------- current ep is {ep} --------") 225 | # loss_lst = train(timestamps, graph, batch_size, num_workers, encoder, generator, veh_route, loss_fcn, optimizer, logger, device) 226 | # #loss_dict[f'train_{ep}_loss'] = loss_lst 227 | # #loss_avg.append(sum(loss_lst) / len(loss_lst)) 228 | 229 | 230 | # #loss_df = pd.DataFrame(loss_avg) 231 | # #loss_df.to_csv(out_dir / 'loss.csv', index=False) 232 | 233 | 234 | # # Save the fine-tuned model 235 | # torch.save(encoder.state_dict(), out_dir / 'encorder_cali.pth') 236 | # torch.save(generator.state_dict(), out_dir / 'generator_cali.pth') 237 | 238 | 239 | 240 | for i in range(10): 241 | logger.info(f"--------- current is {0+pred_step*(i+1), training_step+pred_step*(i+1)} --------") 242 | sim_graph = eval(graph, batch_size//num_workers, num_workers, encoder, generator, veh_depart, veh_route, changable_feature_names, hetero_feat_dim, logger, device, training_step+pred_step*(i+1), pred_step) 243 | #print(sim_graph['veh/0@198.0'][0][('veh','phy/to','lane')][2]) 244 | with open(out_dir / f"predicted_graph_{scenario}_{test_data}_{n_layer}_{n_heads}_{hid_dim}_{i}.p", "wb") as f: 245 | pickle.dump(sim_graph, f) 246 | graph.reset() 247 | #struc_dict, feat_dict, node_id_dict, scalers = load_graph(train_data_dir, 0+pred_step*(i+1), training_step+pred_step*(i+1), node_id_dict) 248 | struc_dict, feat_dict, node_id_dict, scalers = load_graph(train_data_dir, 0, training_step+pred_step*(i+1), node_id_dict) 249 | #veh_depart = load_veh_depart("veh_depart", train_data_dir, training_step+pred_step*(i+1)) 250 | #logger.info(f"--------- current is {0+pred_step*(i+1), training_step+pred_step*(i+1)} --------") 251 | graph = Graph(struc_dict, feat_dict) 252 | #print(0+pred_step*(i+1), training_step+pred_step*(i+1)) 253 | 254 | # with open(out_dir / f"node_id_dict_{n_layer}_{n_heads}_{hid_dim}.p", "wb") as f: 255 | # pickle.dump(node_id_dict, f) 256 | 257 | #after = datetime.now() 258 | #time_diff.append((after - before).total_seconds()) 259 | 260 | #logger.info(f"========== time_diff is : {(after - before).total_seconds()} ==========") 261 | logger.info("========== Exp has finished! 
==========") 262 | 263 | 264 | if __name__ =="__main__": 265 | parser = argparse.ArgumentParser() 266 | parser.add_argument("--scenario", type=str, default='traci_tls') 267 | parser.add_argument("--test_data", type=str, default='test100') 268 | parser.add_argument("--training_step", type=int, default=1100) 269 | parser.add_argument("--pred_step", type=int, default=10) 270 | parser.add_argument("--hid_dim", type=int, default=100) 271 | parser.add_argument("--n_head", type=int, default=4) 272 | parser.add_argument("--n_layer", type=int, default=4) 273 | parser.add_argument("--gpu", type=int, default=0) 274 | args = parser.parse_args() 275 | if (not torch.cuda.is_available()) or (args.gpu == -1): 276 | device = torch.device("cpu") 277 | else: 278 | device = torch.device(f"cuda:{args.gpu}") 279 | run(args.scenario,args.test_data, args.training_step, args.pred_step, args.hid_dim, args.n_head, args.n_layer, device) -------------------------------------------------------------------------------- /transworld/game/VERSION: -------------------------------------------------------------------------------- 1 | 0.1.0 -------------------------------------------------------------------------------- /transworld/game/__init__.py: -------------------------------------------------------------------------------- 1 | from .core.init import init 2 | 3 | 4 | __all__ = ["init"] 5 | -------------------------------------------------------------------------------- /transworld/game/__main__.py: -------------------------------------------------------------------------------- 1 | """Entry point for game.""" 2 | 3 | from .cli import main # pragma: no cover 4 | 5 | if __name__ == "__main__": # pragma: no cover 6 | main() 7 | -------------------------------------------------------------------------------- /transworld/game/base.py: -------------------------------------------------------------------------------- 1 | """ 2 | game base module. 3 | 4 | This is the principal module of the opas project. 5 | here you put your main classes and objects. 6 | 7 | Be creative! do whatever you want! 8 | 9 | If you want to replace this with a Flask application run: 10 | 11 | $ make init 12 | 13 | and then choose `flask` as template. 14 | """ 15 | 16 | # example constant variable 17 | NAME = "game" 18 | -------------------------------------------------------------------------------- /transworld/game/cli.py: -------------------------------------------------------------------------------- 1 | """CLI interface for game project. 2 | 3 | Be creative! do whatever you want! 4 | 5 | - Install click or typer and create a CLI app 6 | - Use builtin argparse 7 | - Start a web application 8 | - Import things from your .base module 9 | """ 10 | 11 | 12 | def main(): # pragma: no cover 13 | """ 14 | The main function executes on commands: 15 | `python -m game` and `$ game `. 16 | 17 | This is your program's entry point. 18 | 19 | You can change this function to do whatever you want. 20 | Examples: 21 | * Run a test suite 22 | * Run a server 23 | * Do some other stuff 24 | * Run a command line application (Click, Typer, ArgParse) 25 | * List all available tasks 26 | * Run an application (Flask, FastAPI, Django, etc.) 
27 | """ 28 | print("This will do something") 29 | -------------------------------------------------------------------------------- /transworld/game/core/__init__.py: -------------------------------------------------------------------------------- 1 | from game.core.controller import Controller 2 | from game.core.typing import ( 3 | FeatureCreateFormat, 4 | FeatureRetrieveFormat, 5 | StructureCreateFormat, 6 | StructureRetrieveFormat, 7 | StructureDeleteFormat, 8 | ) 9 | 10 | __all__ = [ 11 | "Controller", 12 | "FeatureCreateFormat", 13 | "FeatureRetrieveFormat", 14 | "StructureCreateFormat", 15 | "StructureRetrieveFormat", 16 | "StructureDeleteFormat", 17 | ] 18 | -------------------------------------------------------------------------------- /transworld/game/core/config.py: -------------------------------------------------------------------------------- 1 | # protected file! 2 | # Actor Related 3 | TIMELENGTH = 1000000 4 | -------------------------------------------------------------------------------- /transworld/game/core/controller.py: -------------------------------------------------------------------------------- 1 | from abc import ABCMeta, abstractmethod 2 | import numpy as np 3 | from typing import Any, Dict, List, Union, Optional, Tuple 4 | from functools import singledispatchmethod 5 | 6 | from game.core.node import Node 7 | from game.core.typing import * 8 | 9 | 10 | class Controller: 11 | instance = None 12 | first_init = False 13 | 14 | def __new__(cls, *args, **kwargs): 15 | if not cls.instance: 16 | cls.instance = super().__new__(cls) 17 | return cls.instance 18 | 19 | def __init__(self): 20 | if self.first_init: 21 | pass 22 | else: 23 | # real initialization 24 | self.first_init = True 25 | self.process_table: Dict[str, "Node"] = dict() 26 | self.dead_process_table: Dict[str, Tuple] = dict() 27 | 28 | def create_single_node(self, process_name: str): 29 | if process_name not in self.process_table: 30 | self.process_table[process_name] = Node(process_name) 31 | 32 | def distributed_storage(self, process_name_list: List) -> None: 33 | for process_name in process_name_list: 34 | self.create_single_node(process_name) 35 | 36 | def create_node(self, process_name: str): 37 | self.create_single_node(process_name) 38 | 39 | def retrieve_node(self, process_name: str, timestamp: float) -> "Node": 40 | if process_name in self.process_table: 41 | return self.process_table[process_name] 42 | elif process_name in self.dead_process_table: 43 | if self.dead_process_table[process_name][-1] > timestamp: 44 | raise ValueError(f"node is died under currnt retrieve time") 45 | else: 46 | return self.dead_process_table[process_name][0] 47 | else: 48 | raise KeyError(f"process {process_name} has not beed created!") 49 | 50 | def retrieve_node_structure( 51 | self, retrieve_message: StructureRetrieveFormat 52 | ) -> Dict: 53 | check_structure_retrieve_format(retrieve_message) 54 | process_ = self.retrieve_node(retrieve_message.name, retrieve_message.timestamp) 55 | return process_.retrieve_structure( 56 | timestamp=retrieve_message.timestamp, 57 | rank_order=retrieve_message.rank_order, 58 | ) 59 | 60 | @singledispatchmethod 61 | def check_process_exist(self, process_: Union[Tuple, List[Tuple]]) -> bool: 62 | raise NotImplementedError 63 | 64 | @check_process_exist.register 65 | def _(self, process_: tuple) -> bool: 66 | relation_, process_name = process_ 67 | if process_name in self.process_table: 68 | return True 69 | else: 70 | return False 71 | 72 | @check_process_exist.register(list) 73 | 
def _(self, process_list: list) -> bool: 74 | not_created_process_list = [] 75 | for process_ in process_list: 76 | if not self.check_process_exist(process_): 77 | not_created_process_list.append(process_) 78 | if len(not_created_process_list) == 0: 79 | return True 80 | else: 81 | Warning(f"process {not_created_process_list} has not been created!") 82 | return False 83 | 84 | def set_first_appear_time(self, data: Union[Tuple, List], timestamp: float): 85 | if isinstance(data, List): 86 | for tuple_ in data: 87 | _, connected_process = tuple_ 88 | self.process_table[connected_process].set_node_appear_time(timestamp) 89 | else: 90 | _, connected_process = data 91 | self.process_table[connected_process].set_node_appear_time(timestamp) 92 | 93 | def update_node(self, update_message: StructureCreateFormat): 94 | check_structure_create_format(update_message) 95 | process_ = self.retrieve_node(update_message.name, update_message.timestamp) 96 | if not self.check_process_exist(update_message.data): 97 | raise NotImplementedError(f"you must created the process first!") 98 | 99 | if update_message.operator == "add": 100 | self.set_first_appear_time(update_message.data, update_message.timestamp) 101 | process_.create_address( 102 | timestamp=update_message.timestamp, address_=update_message.data 103 | ) 104 | elif update_message.operator == "delete": 105 | process_.delete_address( 106 | timestamp=update_message.timestamp, address_=update_message.data 107 | ) 108 | else: 109 | raise KeyError(f"{update_message.operator} has not be implemented!") 110 | 111 | def delete_node(self, delete_message: StructureDeleteFormat): 112 | check_structure_delete_format(delete_message) 113 | self.dead_process_table[delete_message.name] = ( 114 | self.process_table.pop(delete_message.name), 115 | delete_message.timestamp, 116 | ) 117 | for process_name in self.process_table.keys(): 118 | self.process_table[process_name].delete_node( 119 | timestamp=delete_message.timestamp, node_name=delete_message.name 120 | ) 121 | 122 | def create_feature(self, create_feature_message: FeatureCreateFormat) -> None: 123 | check_feature_create_format(create_feature_message) 124 | process_ = self.retrieve_node( 125 | create_feature_message.name, create_feature_message.timestamp 126 | ) 127 | process_.create_feature( 128 | timestamp=create_feature_message.timestamp, 129 | feature=create_feature_message.data, 130 | ) 131 | 132 | def retrieve_feature( 133 | self, retrieve_feature_message: FeatureRetrieveFormat 134 | ) -> Dict[str, List[np.ndarray]]: 135 | process_ = self.retrieve_node( 136 | retrieve_feature_message.name, retrieve_feature_message.timestamp 137 | ) 138 | return { 139 | retrieve_feature_message.name: process_.retrieve_feature( 140 | timestamp=retrieve_feature_message.timestamp, 141 | retrieve_name=retrieve_feature_message.retrieve_name, 142 | look_back_step=retrieve_feature_message.look_back_step, 143 | ) 144 | } 145 | 146 | def crud_structure(self, operator: str, operation_list: List): 147 | return_list = [] 148 | for data in operation_list: 149 | if operator == "create": 150 | self.create_node(data) 151 | elif operator == "retrieve": 152 | return_list.append(self.retrieve_node_structure(data)) 153 | elif operator == "update": 154 | self.update_node(data) 155 | elif operator == "delete": 156 | self.delete_node(data) 157 | else: 158 | raise NotImplementedError 159 | 160 | if operator == "retrieve": 161 | return return_list 162 | 163 | def crud_feature(self, operator: str, operation_list: List): 164 | return_list = [] 165 | 
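# Usage sketch (illustrative names, assuming "veh/0" was already created via
# crud_structure):
#   controller.crud_feature(
#       operator="create",
#       operation_list=[FeatureCreateFormat(name="veh/0", timestamp=1.0,
#                                           data={"speed": np.array([3.2])})],
#   )
#   feats = controller.crud_feature(
#       operator="retrieve",
#       operation_list=[FeatureRetrieveFormat(name="veh/0", timestamp=2.0)],
#   )   # -> [{"veh/0": [(1.0, {"speed": array([3.2])})]}]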
for data in operation_list: 166 | if operator == "create": 167 | self.create_feature(data) 168 | elif operator == "retrieve": 169 | return_list.append(self.retrieve_feature(data)) 170 | else: 171 | raise NotImplementedError 172 | return return_list 173 | 174 | @singledispatchmethod 175 | def retrieve_alive_node(self, timestamp: Union[float, List[float]]) -> List: 176 | raise NotImplementedError 177 | 178 | @retrieve_alive_node.register 179 | def _(self, timestamp: float) -> List: 180 | assert type(timestamp) == float, "timestamp must by float type" 181 | node_list = [] 182 | for alive_node_name, node in self.process_table.items(): 183 | if node.start_appear_time <= timestamp: 184 | node_list.append(alive_node_name) 185 | 186 | for death_node_name, tuple_ in self.dead_process_table.items(): 187 | # when appear time is latter than retrieve time and death time is before than retrieve time 188 | if tuple_[0].start_appear_time <= timestamp and timestamp <= tuple_[1]: 189 | node_list.append(death_node_name) 190 | 191 | return node_list 192 | 193 | @retrieve_alive_node.register 194 | def _(self, timestamp_list: list) -> List: 195 | return [self.retrieve_alive_node(timestamp) for timestamp in timestamp_list] 196 | 197 | def run(self): 198 | return super().run() 199 | 200 | @classmethod 201 | def reset(cls): 202 | cls.instance = None 203 | cls.first_init = False 204 | -------------------------------------------------------------------------------- /transworld/game/core/init.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, Callable 2 | 3 | from game.core.controller import Controller 4 | 5 | 6 | def init(): 7 | pass 8 | -------------------------------------------------------------------------------- /transworld/game/core/node.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, Set, List, Tuple, Any, Union, Callable, Optional 2 | import copy 3 | import numpy as np 4 | import bisect 5 | 6 | 7 | from functools import singledispatchmethod 8 | from collections import defaultdict 9 | from queue import Queue 10 | 11 | 12 | from game.core.config import TIMELENGTH 13 | 14 | 15 | class Node: 16 | instantces: Dict[str, "Node"] = dict() 17 | first_init: Set[str] = set() 18 | 19 | def __new__(cls, name: str, *args, **kwargs): 20 | if name not in cls.instantces: 21 | cls.instantces[name] = super().__new__(cls) 22 | return cls.instantces[name] 23 | 24 | def __init__(self, name: str): 25 | if name not in self.first_init: 26 | self.first_init.add(name) 27 | ###### real initialization ##### 28 | self.name = name 29 | self.address_timestamp_queue: Queue[float] = Queue(maxsize=TIMELENGTH + 1) 30 | self.feature_timestamp_queue: Queue[float] = Queue(maxsize=TIMELENGTH + 1) 31 | self.feature_with_timestamp: Dict[float, Dict[str, np.ndarray]] = dict() 32 | self.address_book_with_timestamp: Dict[float, Dict] = dict() 33 | self.start_appear_time = -1 34 | self.address_book: Dict[str, List] = defaultdict(list) 35 | self.last_structure_modified_timestamp = -1.0 36 | self.last_feature_modified_timestamp = -1.0 37 | 38 | def chenck_and_copy_address(self, operation_timestamp): 39 | if operation_timestamp != self.last_structure_modified_timestamp: 40 | 41 | if self.address_timestamp_queue.full(): 42 | destroy_key = self.address_timestamp_queue.get() 43 | self.address_book_with_timestamp.pop(destroy_key) 44 | 45 | self.address_timestamp_queue.put(operation_timestamp) 46 | 
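# Illustrative sketch of this versioning step (not part of the class): every structural
# edit snapshots the whole address book under its timestamp, so a later query is served
# from the nearest earlier snapshot, e.g.
#   n = Node("lane/1")
#   n.create_address(1.0, ("phy/to", "lane/2"))
#   n.retrieve_structure(2.0)   # -> {"lane/1": {"lane/2": [("phy/to", 1.0)]}}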
self.address_book_with_timestamp[operation_timestamp] = copy.deepcopy( 47 | self.address_book 48 | ) 49 | self.last_structure_modified_timestamp = operation_timestamp 50 | 51 | else: 52 | self.address_book_with_timestamp[operation_timestamp] = copy.deepcopy( 53 | self.address_book 54 | ) 55 | 56 | def set_node_appear_time(self, timestamp): 57 | if self.start_appear_time < 0 or self.start_appear_time >= timestamp: 58 | self.start_appear_time = timestamp 59 | 60 | def create_address(self, timestamp, address_: Union[List, Tuple]): 61 | assert ( 62 | timestamp >= self.last_structure_modified_timestamp 63 | ), "timestamp must bigger than last structure modifed time" 64 | if self.start_appear_time < 0: 65 | self.start_appear_time = timestamp 66 | Node.add_remove_operation( 67 | address_, timestamp, self.address_book, Node.add_address 68 | ) 69 | self.chenck_and_copy_address(timestamp) 70 | 71 | def retrieve_structure( 72 | self, timestamp: float, rank_order: int = 1, node_name: Optional[str] = None 73 | ) -> Dict: 74 | # 获取end_time位置的graph结构, rank_order表示的是几阶子图 75 | structure_dict: Dict[str, Dict] = dict() 76 | search_node_name = self.name if node_name is None else node_name 77 | structure_history_timestamp_list = list( 78 | Node.instantces[search_node_name].address_book_with_timestamp.keys() 79 | ) 80 | if len(structure_history_timestamp_list) == 0: 81 | return {search_node_name: {}} 82 | 83 | sorted_timestamp_list, index = Node.search_nearest_key( 84 | structure_history_timestamp_list, timestamp 85 | ) 86 | 87 | if index == 0 and timestamp < sorted_timestamp_list[0]: 88 | Warning("Current Node is an asolate node!") 89 | structure_dict.update({search_node_name: {}}) 90 | return structure_dict 91 | 92 | nearest_timestamp = structure_history_timestamp_list[max(0, index - 1)] 93 | connected_address_book = Node.instantces[ 94 | search_node_name 95 | ].address_book_with_timestamp[nearest_timestamp] 96 | if rank_order == 1: 97 | structure_dict.update({search_node_name: connected_address_book}) 98 | return structure_dict 99 | else: 100 | for connected_node_name in connected_address_book.keys(): 101 | structure_dict.update( 102 | self.retrieve_structure( 103 | timestamp, rank_order - 1, connected_node_name 104 | ) 105 | ) 106 | structure_dict.update({search_node_name: connected_address_book}) 107 | return structure_dict 108 | 109 | def delete_address(self, timestamp: float, address_: Union[List, Tuple]) -> None: 110 | # remove算子没有这么简单,需要在一个列表中移除掉一个address 111 | Node.add_remove_operation( 112 | address_, 113 | timestamp, 114 | self.address_book, 115 | Node.remove_address, 116 | ) 117 | self.chenck_and_copy_address(timestamp) 118 | 119 | def delete_node(self, timestamp: float, node_name: str) -> None: 120 | # 这个是移除当前码本中,所有与这个node相链接的边 121 | if node_name in self.address_book: 122 | self.address_book.pop(node_name) 123 | self.chenck_and_copy_address(timestamp) 124 | 125 | def bak_feature(self, timestamp: float) -> None: 126 | if self.feature_timestamp_queue.full(): 127 | destroy_key = self.feature_timestamp_queue.get() 128 | self.feature_with_timestamp.pop(destroy_key) 129 | 130 | self.feature_timestamp_queue.put(timestamp) 131 | self.feature_with_timestamp[timestamp] = copy.deepcopy(self.feature) 132 | self.last_structure_modified_timestamp = timestamp 133 | 134 | def create_feature(self, timestamp: float, feature: Dict[str, np.ndarray]) -> None: 135 | # 先更新,后设置副本,这样就不会把None放进去了 136 | assert ( 137 | timestamp >= self.last_feature_modified_timestamp 138 | ), "timestamp must bigger than last feature 
modified timestamp!" 139 | self.feature = feature 140 | self.bak_feature(timestamp) 141 | 142 | def retrieve_feature( 143 | self, 144 | timestamp: float, 145 | retrieve_name: Optional[List[str]], 146 | look_back_step: Optional[int] = None, 147 | ) -> List: 148 | feature_history_timestamp_list = list(self.feature_with_timestamp.keys()) 149 | if len(feature_history_timestamp_list) == 0: 150 | return [] 151 | else: 152 | sorted_list, start_index = Node.search_nearest_key( 153 | feature_history_timestamp_list, timestamp 154 | ) 155 | if start_index == 0: 156 | if sorted_list[0] == timestamp: 157 | return [(timestamp, self.feature_with_timestamp[timestamp])] 158 | else: 159 | return [] 160 | else: 161 | if look_back_step is not None: 162 | legal_timestamp_list = sorted_list[ 163 | max(0, start_index - look_back_step) : start_index 164 | ] 165 | else: 166 | legal_timestamp_list = sorted_list[0:start_index] 167 | 168 | if retrieve_name is None: 169 | return [ 170 | (_timestamp, self.feature_with_timestamp[_timestamp]) 171 | for _timestamp in legal_timestamp_list 172 | ] 173 | else: 174 | return [ 175 | ( 176 | _timestamp, 177 | { 178 | feat_name: self.feature_with_timestamp[_timestamp][ 179 | feat_name 180 | ] 181 | for feat_name in retrieve_name 182 | }, 183 | ) 184 | for _timestamp in legal_timestamp_list 185 | ] 186 | 187 | @staticmethod 188 | def search_nearest_key(src_list, target_value): 189 | # 对src_list 进行排序,然后采用二分查找,O(nlogn + logn) 190 | sorted_list = sorted(src_list) 191 | find_index = bisect.bisect_left(src_list, target_value) 192 | return sorted_list, find_index 193 | 194 | @staticmethod 195 | def remove_address(src_list, remove_tuple): 196 | relation, _ = remove_tuple 197 | rm_relations = [tuple_ for tuple_ in src_list if relation in tuple_] 198 | if len(rm_relations) > 1: 199 | raise ValueError("Removed adddress should less then one.") 200 | elif len(rm_relations) == 0: 201 | Warning(f"{remove_tuple} not exist") 202 | else: 203 | src_list.remove(rm_relations[0]) 204 | 205 | @staticmethod 206 | def add_address(src_list, add_tuple): 207 | if add_tuple not in src_list: 208 | list.append(src_list, add_tuple) 209 | 210 | @singledispatchmethod 211 | @staticmethod 212 | def add_remove_operation( 213 | address_: Union[List, Tuple], 214 | timestamp, 215 | source_object: Dict, 216 | Operator: Callable, 217 | ) -> None: 218 | raise NotImplementedError 219 | 220 | @add_remove_operation.register 221 | @staticmethod 222 | def _(address_: list, timestamp, source_object: Dict, Operator: Callable): 223 | for tuple_ in address_: 224 | Node.add_remove_operation(tuple_, timestamp, source_object, Operator) 225 | 226 | @add_remove_operation.register 227 | @staticmethod 228 | def _(address_: tuple, timestamp, source_object: Dict, Operator: Callable): 229 | relation, connected_process = address_ 230 | Operator(source_object[connected_process], (relation, timestamp)) 231 | if len(source_object[connected_process]) == 0: 232 | source_object.pop(connected_process) 233 | 234 | @classmethod 235 | def reset(cls): 236 | cls.first_init = set() 237 | cls.instantces = dict() 238 | -------------------------------------------------------------------------------- /transworld/game/core/typing.py: -------------------------------------------------------------------------------- 1 | from typing import List, Tuple, Union, Dict, Optional 2 | from dataclasses import dataclass 3 | import numpy as np 4 | 5 | 6 | @dataclass 7 | class FeatureCreateFormat: 8 | name: str 9 | timestamp: float 10 | data: Dict[str, np.ndarray] 11 | 12 
| 13 | def check_feature_create_format(message: FeatureCreateFormat): 14 | assert ( 15 | type(message.name) == str 16 | ), "the type of FeatureCreateFormat.name must be string" 17 | assert ( 18 | type(message.timestamp) == float 19 | ), "the type of FeatureCreateFormat.timestamp must be float" 20 | # for key, value in message.data.items(): 21 | # assert ( 22 | # type(key) == str 23 | # ), "the type of FeatureCreateFormat.data.keys() must be string" 24 | # assert ( 25 | # type(value) == np.ndarray 26 | # ), "the type of FeatureCreateFormat.data.values() must be np.ndarray" 27 | 28 | 29 | @dataclass 30 | class FeatureRetrieveFormat: 31 | name: str 32 | timestamp: float 33 | retrieve_name: Optional[List[str]] = None 34 | look_back_step: Optional[int] = None 35 | 36 | 37 | def check_feature_retrieve_format(message: FeatureRetrieveFormat): 38 | assert ( 39 | type(message.name) == str 40 | ), "the type of FeatureRetrieveFormat.name must be string" 41 | assert ( 42 | type(message.timestamp) == float 43 | ), "the type of FeatureRetrieveFormat.timestamp must be float" 44 | assert ( 45 | type(message.retrieve_name) == list or message.retrieve_name is None 46 | ), "the type of FeatureRetrieveFormat.retrieve_name must be Optional[List[str]]" 47 | assert ( 48 | type(message.look_back_step) == int or message.look_back_step is None 49 | ), "the type of FeatureRetrieveFormat.look_back_step must be Optional[int]" 50 | 51 | 52 | @dataclass 53 | class StructureCreateFormat: 54 | name: str 55 | timestamp: float 56 | operator: str 57 | data: Union[Tuple, List[Tuple]] 58 | 59 | 60 | def check_structure_create_format(message: StructureCreateFormat): 61 | assert ( 62 | type(message.name) == str 63 | ), "the type of StructureCreateFormat.name must be string" 64 | assert ( 65 | type(message.timestamp) == float 66 | ), "the type of StructureCreateFormat.name must be float" 67 | assert ( 68 | type(message.operator) == str 69 | ), "the type of StructureCreateFormat.operation must be string" 70 | assert ( 71 | type(message.data) == tuple or type(message.data) == list 72 | ), "the type of StructureCreateFormat.data must be Union[Tuple, List[Tuple]]" 73 | 74 | 75 | @dataclass 76 | class StructureRetrieveFormat: 77 | name: str 78 | timestamp: float 79 | rank_order: int = 1 80 | 81 | 82 | def check_structure_retrieve_format(message: StructureRetrieveFormat): 83 | assert ( 84 | type(message.name) == str 85 | ), "the type of StructureRetrieveFormat.name must be string" 86 | assert ( 87 | type(message.timestamp) == float 88 | ), "the type of StructureRetrieveFormat.name must be float" 89 | assert ( 90 | type(message.rank_order) == int 91 | ), "the type of StructureCreateFormat.rank_order must int" 92 | 93 | 94 | @dataclass 95 | class StructureDeleteFormat: 96 | name: str 97 | timestamp: float 98 | 99 | 100 | def check_structure_delete_format(message: StructureDeleteFormat): 101 | assert ( 102 | type(message.name) == str 103 | ), "the type of StructureDeleteFormat.name must be string" 104 | assert ( 105 | type(message.timestamp) == float 106 | ), "the type of StructureDeleteFormat.name must be float" 107 | -------------------------------------------------------------------------------- /transworld/game/data/__init__.py: -------------------------------------------------------------------------------- 1 | from .dataloader import DataLoader 2 | from .dataset import Dataset 3 | 4 | __all__ = ["DataLoader", "Dataset"] 5 | -------------------------------------------------------------------------------- 
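The DataLoader defined next drives the retrieval path by emitting the message formats
from game.core.typing above. A minimal sketch of the two retrieve messages it builds
per node and timestamp (the node name "veh/0" is illustrative):

from game.core.typing import StructureRetrieveFormat, FeatureRetrieveFormat

struct_query = StructureRetrieveFormat(name="veh/0", timestamp=5.0, rank_order=1)
feat_query = FeatureRetrieveFormat(name="veh/0", timestamp=5.0, look_back_step=1)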
/transworld/game/data/dataloader.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import math 3 | from collections import defaultdict, ChainMap 4 | from copy import deepcopy 5 | from torch import Tensor 6 | from typing import ( 7 | Dict, 8 | List, 9 | Tuple, 10 | ) 11 | from game.operator.transform import game_to_dgl 12 | from game.core.typing import ( 13 | FeatureRetrieveFormat, 14 | StructureRetrieveFormat, 15 | ) 16 | 17 | 18 | class DataLoader(torch.utils.data.DataLoader): 19 | def __init__( 20 | self, 21 | dataset, 22 | operation_fn, 23 | full_graph: bool = True, 24 | need_negative_sampling: bool = False, 25 | **kwargs 26 | ) -> None: 27 | self.operation_fn = operation_fn 28 | self.dataset = dataset 29 | self.device = dataset.device 30 | self.full_graph = full_graph 31 | self.train_mode = dataset.train_mode 32 | self.need_negative_sampling = need_negative_sampling 33 | super().__init__(self.dataset, collate_fn=self.collate, **kwargs) 34 | 35 | def collate(self, batched_timestamp: List): 36 | worker_info = torch.utils.data.get_worker_info() 37 | if worker_info is not None: 38 | per_worker = int( 39 | math.ceil(len(batched_timestamp) / float(worker_info.num_workers)) 40 | ) 41 | worker_id = worker_info.id 42 | iter_start = worker_id * per_worker 43 | iter_end = min(iter_start + per_worker, len(batched_timestamp)) 44 | batched_timestamp = batched_timestamp[iter_start:iter_end] 45 | if self.train_mode is True: 46 | batched_cur_graph, batched_next_graph = [], [] 47 | for pair_timestamp in batched_timestamp: 48 | cur_graph, next_graph = self.collate_tool(list(pair_timestamp)) 49 | batched_cur_graph.append(cur_graph) 50 | batched_next_graph.append(next_graph) 51 | return ChainMap(*batched_cur_graph), ChainMap(*batched_next_graph) 52 | else: 53 | batched_cur_graph = self.collate_tool(batched_timestamp) 54 | return ChainMap(*batched_cur_graph) 55 | 56 | def collate_tool(self, batched_timestamp: List[float], max_step: int = 1): 57 | # TODO 后期加入测试 58 | nodes_pool_at_time = self.operation_fn( 59 | operation_type="alive_node", 60 | operator="retrieve", 61 | operation_list=batched_timestamp, 62 | ) 63 | if self.train_mode is True: 64 | union_nodes = set.intersection(*map(set, nodes_pool_at_time)) 65 | nodes_pool_at_time = [union_nodes for _ in range(len(batched_timestamp))] 66 | batched_structures_ = [ 67 | self.operation_fn( 68 | operation_type="structure", 69 | operator="retrieve", 70 | operation_list=[ 71 | StructureRetrieveFormat(name=node, timestamp=ts, rank_order=1) 72 | for node in nodes_pool_at_time[i] 73 | ], 74 | ) 75 | for i, ts in enumerate(batched_timestamp) 76 | ] 77 | batched_structures = [delete_empty_structure(bs) for bs in batched_structures_] 78 | 79 | batched_features_ = [ 80 | self.operation_fn( 81 | operation_type="feature", 82 | operator="retrieve", 83 | operation_list=[ 84 | FeatureRetrieveFormat( 85 | name=node, timestamp=ts, look_back_step=max_step 86 | ) 87 | for node in nodes_pool_at_time[i] 88 | ], 89 | ) 90 | for i, ts in enumerate(batched_timestamp) 91 | ] 92 | batched_features = [dictify_feature_list(bf) for bf in batched_features_] 93 | 94 | batched_structures, batched_features = self.preprocess( 95 | batched_structures, batched_features 96 | ) 97 | # if self.need_negative_sampling is True: 98 | # TODO define negative_sampling func 99 | 100 | batched_dgl_graph = [ 101 | game_to_dgl( 102 | batched_structure, 103 | batched_feature, 104 | str(batched_timestamp[i]), 105 | self.full_graph, 106 | self.device, 107 | 
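# Descriptive note (inferred from how the result is consumed, e.g. keys such as
# "veh/0@198.0" in finetune.py): game_to_dgl packs one timestamp's retrieved
# structures and features into a mapping from "<ntype>/<node_id>@<timestamp>" to a
# DGLGraph, and collate() then ChainMaps these per-timestamp mappings into one batch.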
) 108 | for i, (batched_structure, batched_feature) in enumerate( 109 | zip(batched_structures, batched_features) 110 | ) 111 | ] 112 | return batched_dgl_graph 113 | 114 | # def worker_init_fn_(self, worker_id): 115 | # TODO: Xuhong 实现多进程加载数据,正在考虑是使用pytroch自带的多进程还是MPI实现 116 | # worker_info = torch.utils.data.get_worker_info() 117 | # dataset = worker_info.dataset # the dataset copy in this worker process 118 | # # configure the dataset to only process the split workload 119 | # per_worker = int(len(dataset.dataset) // float(worker_info.num_workers)) 120 | # worker_id = worker_info.id 121 | # dataset.dataset = dataset.dataset[worker_id*per_worker:(worker_id+1)*per_worker] 122 | # print('dataset',dataset.dataset) 123 | 124 | def preprocess( 125 | self, batched_structure: List, batched_feature: List 126 | ) -> Tuple[List, List]: 127 | # 这个函数可以让用户放入一些前处理函数,比如check前处理规则等等 128 | return batched_structure, batched_feature 129 | 130 | 131 | def dictify_feature_list(batched_features_list: List) -> Dict[str, List]: 132 | # TODO 后期加入测试 133 | # 将feature不同时刻查询的列表合并为字典 134 | batched_features: Dict = defaultdict(list) 135 | for features in deepcopy(batched_features_list): 136 | for node_name, feats in features.items(): 137 | for feat in feats: 138 | if (batched_features[node_name] == []) or ( 139 | feat[0] not in list(zip(*batched_features[node_name]))[0] 140 | ): 141 | feat = ( 142 | feat[0], 143 | {key: tensor_unsqueeze(val) for key, val in feat[1].items()}, 144 | ) 145 | batched_features[node_name].append(feat) 146 | return batched_features 147 | 148 | 149 | def tensor_unsqueeze(tensor: Tensor, target_shape: int = 3): 150 | while len(tensor.shape) < target_shape: 151 | tensor.unsqueeze_(0) 152 | assert len(tensor.shape) == 3, ValueError( 153 | "Feature shape must equals to 3. 
But got {}".format(tensor.shape) 154 | ) 155 | return tensor 156 | 157 | 158 | def delete_empty_structure(batched_structures: List[Dict]) -> List[Dict]: 159 | batched_structures = [ 160 | structure 161 | for structure in batched_structures 162 | if list(structure.values())[0] != {} 163 | ] 164 | return batched_structures 165 | -------------------------------------------------------------------------------- /transworld/game/data/dataset.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import Tensor 3 | from typing import List 4 | import random 5 | 6 | 7 | class Dataset(torch.utils.data.IterableDataset): 8 | def __init__(self, timestamps: List, device: torch.device, train_mode: bool): 9 | self.train_mode = train_mode 10 | self.device = device 11 | if self.train_mode is True: 12 | curr_timestamps = timestamps[:-1] 13 | next_timestamps = timestamps[1:] 14 | self.dataset = [ 15 | (cur_t, next_t) 16 | for (cur_t, next_t) in zip(curr_timestamps, next_timestamps) 17 | ] 18 | self.shuffle_() 19 | else: 20 | self.dataset = timestamps 21 | 22 | def shuffle_(self): 23 | """当需要进行数据集shuffle时,必须手动调用这个方法,pytorch DataLoader自带的shuffle参数失效""" 24 | if self.train_mode is not True: 25 | NotImplementedError( 26 | "Please do not shuffle dataset when tesing, otherwise the game.core module will not perform as expect" 27 | ) 28 | random.shuffle(self.dataset) 29 | 30 | def __iter__(self): 31 | for ts in self.dataset: 32 | yield ts 33 | -------------------------------------------------------------------------------- /transworld/game/graph/__init__.py: -------------------------------------------------------------------------------- 1 | from .graph import Graph 2 | 3 | __all__ = [ 4 | "Graph", 5 | ] 6 | -------------------------------------------------------------------------------- /transworld/game/graph/graph.py: -------------------------------------------------------------------------------- 1 | from collections import defaultdict, namedtuple, OrderedDict 2 | from functools import partial 3 | from itertools import chain 4 | from typing import Any, Dict, List, Optional, Tuple, Union 5 | 6 | import re 7 | import dgl 8 | import torch 9 | from dgl import DGLGraph 10 | from torch import Tensor 11 | 12 | from game.core.controller import Controller 13 | from game.core.typing import ( 14 | FeatureCreateFormat, 15 | StructureCreateFormat, 16 | StructureDeleteFormat, 17 | ) 18 | from game.operator.transform import ( 19 | dgl_graph_to_graph_dict, 20 | extract_feat, 21 | convert_type_and_id_to_name, 22 | ) 23 | 24 | from game.core.node import Node 25 | SubGraph = namedtuple("SubGraph", ["seed_node_id", "seed_node_type", "dgl_graph"]) 26 | 27 | 28 | class Graph(object): 29 | """_summary_ 30 | 31 | Args: 32 | object (_type_): game.Graph Object 33 | """ 34 | 35 | def __init__( 36 | self, 37 | edge_dict: Dict, 38 | feat_dict: Dict, 39 | build_directed: bool = True, 40 | ): 41 | 42 | """Internal constructor for creating a game.Graph. 43 | 44 | Args: 45 | edge_dict (Dict): Edge dictionary object. 46 | feat_dict (Dict): Node Feature dictionary object. 47 | build_directed (bool, optional): If Ture, build a directed graph; if False, build a undirected graph. Defaults to True. 
48 | 49 | Examples: 50 | TODO:后续将输入数据修改为dataclass的形式 51 | 边字典:第一个tensor为源节点列表u,第二个为目标节点v,第三个为时间戳t,第四个tensor(可选)为边上的特征Fe 52 | edge_dict = { 53 | ('Car', 'follows', 'Car'): (torch.tensor([0, 1]), torch.tensor([1, 2]), torch.tensor([2321.0, 2323.1]), torch.tensor([0.06, 1.1])), 54 | ('Car', 'locates', 'Lane'): (torch.tensor([1, 1]), torch.tensor([1, 2]), torch.tensor([231.0, 233.2]), torch.tensor([1.77, 2.1])), 55 | ('Lane', 'connects', 'Lane'): (torch.tensor([0, 3]), torch.tensor([3, 4]), torch.tensor([1321.0, 2333.5]), torch.tensor([1.88, -0.16])) 56 | } 57 | 节点特征字典: 58 | feat_dict = { 59 | 'Car':{ 60 | 0:{'time':[torch.tensor([0.0]), torch.tensor([1.1]), torch.tensor([2.7]), torch.tensor([5.9])], 61 | 'feature':[torch.tensor([0, 3]), torch.tensor([0, 3]), torch.tensor([0, 3]), torch.tensor([0, 3])] 62 | }, 63 | 1:{'time':[torch.tensor([0.0]), torch.tensor([1.1]), torch.tensor([2.7]), torch.tensor([5.9])], 64 | 'feature':[torch.tensor([0, 3]), torch.tensor([0, 3]), torch.tensor([0, 3]), torch.tensor([0, 3])] 65 | }, 66 | }, 67 | 'Lane':{ 68 | 0:{'time':[torch.tensor([0.0]), torch.tensor([1.1]), torch.tensor([2.7]), torch.tensor([5.9])], 69 | 'feature':[torch.tensor([0, 3]), torch.tensor([0, 3]), torch.tensor([0, 3]), torch.tensor([0, 3])] 70 | }, 71 | 1:{'time':[torch.tensor([0.0]), torch.tensor([1.1]), torch.tensor([2.7]), torch.tensor([5.9])], 72 | 'feature':[torch.tensor([0, 3]), torch.tensor([0, 3]), torch.tensor([0, 3]), torch.tensor([0, 3])] 73 | }, 74 | } 75 | } 76 | graph = Graph(graph_dict, feat_dict, build_directed=True) 77 | """ 78 | if (edge_dict is not None) and (feat_dict is not None): 79 | self.is_directed = build_directed 80 | self.is_attr_graph = False if feat_dict is None else True 81 | self.dgl_graph = edge_dict_to_dgl(edge_dict) 82 | self.controller = Controller() 83 | self.node_list = self.store_data(self.dgl_graph, feat_dict) 84 | self.timestamps = self.collect_timestamps(self.dgl_graph) 85 | self.hetero_feat_dim = heterogenous_feature_parsing(feat_dict) 86 | 87 | @property 88 | def is_attr(self) -> bool: 89 | return self.is_attr_graph 90 | 91 | @property 92 | def num_ntypes(self) -> int: 93 | return len(self.dgl_graph.ntypes) 94 | 95 | @property 96 | def num_etypes(self) -> int: 97 | return len(self.dgl_graph.etypes) 98 | 99 | @property 100 | def num_timestamps(self) -> int: 101 | return len(self.timestamps) 102 | 103 | def operate(self, operation_type: str, operator: str, operation_list: List) -> List: 104 | results = [] 105 | if operation_type == "feature": 106 | results = self.controller.crud_feature(operator, operation_list) 107 | elif operation_type == "structure": 108 | results = self.controller.crud_structure(operator, operation_list) 109 | elif operation_type == "alive_node": 110 | results = self.controller.retrieve_alive_node(operation_list) 111 | else: 112 | ValueError("Please use correct operation_type name") 113 | return results 114 | 115 | def store_data(self, dgl_graph_: DGLGraph, feat_dict: Dict): 116 | # call controller: create isolated nodes 117 | node_list = [] 118 | for ntype in dgl_graph_.ntypes: 119 | node_list.extend( 120 | [f"{ntype}/{node_id}" for node_id in dgl_graph_.nodes(ntype).tolist()] 121 | ) 122 | 123 | self.controller.crud_structure(operator="create", operation_list=node_list) 124 | 125 | # call controller: add subgraph to each node 126 | _update_edge_events = partial(self._update_edge_events, dgl_graph_) 127 | 128 | edge_events = list( 129 | map( 130 | _update_edge_events, 131 | [ 132 | (int(node_id), ntype) 133 | for ntype in 
dgl_graph_.ntypes 134 | for node_id in dgl_graph_.nodes(ntype) 135 | ], 136 | ) 137 | ) 138 | self.controller.crud_structure( 139 | operator="update", operation_list=list(chain.from_iterable(edge_events)) 140 | ) 141 | 142 | # call controller: add feature to each node 143 | feature_events = list( 144 | map( 145 | _create_feature_events, 146 | [ 147 | (int(node_id), node_type, node_feats) 148 | for node_type in feat_dict.keys() 149 | for node_id, node_feats in feat_dict[node_type].items() 150 | ], 151 | ) 152 | ) 153 | self.controller.crud_feature( 154 | operator="create", operation_list=list(chain.from_iterable(feature_events)) 155 | ) 156 | 157 | return node_list 158 | 159 | def collect_timestamps(self, dgl_graph): 160 | unique_timestamps = torch.unique( 161 | torch.cat(list(dgl_graph.edata["time"].values())) 162 | ) 163 | unique_timestamps, _ = torch.sort(unique_timestamps) 164 | timestamps = unique_timestamps[1:] # 去除最早的第一个timestamp,因为没有意义 165 | return timestamps 166 | 167 | def _update_edge_events(self, dgl_graph_, args: Tuple): 168 | node_id, ntype = args 169 | dgl_SG = extract_subgraph(dgl_graph_, node_id, ntype) 170 | edge_events = dgl_to_structure_event(dgl_SG) 171 | edge_events = sorted(edge_events, key=lambda x: x.timestamp) 172 | return edge_events 173 | 174 | def actions_to_game_operations(self, actions: Dict[str, List[str]]) -> None: 175 | for node_name_n_time, action_list in actions.items(): 176 | for action in action_list: 177 | if re.search(".*edge", action) is not None: 178 | structure_operation = self.action_to_structure_event( 179 | node_name_n_time, action 180 | ) 181 | self.controller.crud_structure( 182 | operator="update", operation_list=[structure_operation] 183 | ) 184 | elif re.search(".*node", action) is not None: 185 | node_operation = self.action_to_node_event(node_name_n_time, action) 186 | self.controller.crud_structure( 187 | operator="delete", operation_list=[node_operation] 188 | ) 189 | else: 190 | NotImplementedError 191 | 192 | def action_to_structure_event(self, node_name_n_time: str, action: str): 193 | edge_str = re.findall(r"[(](.*?)[)]", action)[0] 194 | src_name, etype, dst_name = edge_str.split(",") 195 | operation = StructureCreateFormat( 196 | name=src_name, 197 | timestamp=round(float(node_name_n_time.split("@")[-1]), 6), 198 | operator=action.split("_")[0], 199 | data=(etype, dst_name), 200 | ) 201 | return operation 202 | 203 | def action_to_node_event(self, node_name_n_time: str, action: str): 204 | node_name = re.findall(r"[(](.*?)[)]", action)[0] 205 | operation = StructureDeleteFormat( 206 | name=node_name, 207 | timestamp=round(float(node_name_n_time.split("@")[-1]), 6), 208 | ) 209 | return operation 210 | 211 | def states_to_feature_event( 212 | self, 213 | time: float, 214 | changable_feature_names: List[str], 215 | cur_graph: DGLGraph, 216 | pred_graph: DGLGraph, 217 | ): 218 | cur_feat = extract_feat(cur_graph.cpu(), self.hetero_feat_dim) 219 | pred_feat = extract_feat(pred_graph.cpu(), self.hetero_feat_dim) 220 | 221 | for ntype in pred_feat.keys(): 222 | for node_id in pred_feat[ntype].keys(): 223 | feat_data = { 224 | feat_name: pred_feat[ntype][node_id][feat_name].squeeze() 225 | if feat_name in changable_feature_names 226 | else cur_feat[ntype][node_id][feat_name].squeeze() 227 | for feat_name in pred_feat[ntype][node_id].keys() 228 | } 229 | feat_event = FeatureCreateFormat( 230 | name=convert_type_and_id_to_name(ntype, node_id), 231 | timestamp=time, 232 | data=feat_data, 233 | ) 234 | self.operate("feature", "create", 
[feat_event]) 235 | 236 | 237 | @staticmethod 238 | def reset(): 239 | Controller.reset() 240 | Node.reset() 241 | 242 | def heterogenous_feature_parsing(feat_dict): 243 | feat_parsing_dict: Dict[str, Dict] = defaultdict(dict) 244 | for ntype, feat in feat_dict.items(): 245 | for name, tensor in list(feat.values())[0].items(): 246 | dim = tensor.shape[-1] if len(tensor.shape) != 1 else 1 247 | feat_parsing_dict[ntype][name] = dim 248 | return feat_parsing_dict 249 | 250 | 251 | def _create_feature_events(node_item: Tuple): 252 | feature_events = [] 253 | node_id, node_type, node_feats = node_item 254 | feat_name_list = list(node_feats.keys()) 255 | assert "time" in feat_name_list, "Feature Dictionary shall include time" 256 | last_timestamp = 0. #round(float(node_feats['time'][0].item() - 1.), 6) 257 | for feats in zip(*node_feats.values()): 258 | cur_timestamp = round(float(feats[feat_name_list.index("time")].item()), 6) 259 | data = { 260 | feat_name: feats[i] 261 | if feat_name != "time" 262 | else feats[i].float() - last_timestamp 263 | for i, feat_name in enumerate(feat_name_list) 264 | } 265 | feature_events.append( 266 | FeatureCreateFormat( 267 | name=f"{node_type}/{node_id}", 268 | timestamp=cur_timestamp, 269 | data=data, 270 | ) 271 | ) 272 | last_timestamp = cur_timestamp 273 | feature_events = sorted(feature_events, key=lambda x: x.timestamp) 274 | return feature_events 275 | 276 | 277 | def dgl_to_structure_event(sub_graph: SubGraph) -> List[StructureCreateFormat]: 278 | # 这个方法只支持转换子图 279 | def _id_tensors_to_edge_events(edge_item: Tuple): 280 | src_id_, dst_id_, e_id = edge_item 281 | src_id = dgl_graph_.ndata["ID"][src_type][src_id_] 282 | dst_id = dgl_graph_.ndata["ID"][dst_type][dst_id_] 283 | event = None 284 | if src_id == seed_node_id: 285 | t = dgl_graph_.edata["time"][(src_type, etype, dst_type)][e_id].item() 286 | event = StructureCreateFormat( 287 | name=f"{src_type}/{src_id}", 288 | timestamp=round(float(t), 6), 289 | operator="add", 290 | data=(etype, f"{dst_type}/{dst_id}"), 291 | ) 292 | return event 293 | 294 | edge_events = [] 295 | dgl_graph_ = sub_graph.dgl_graph 296 | seed_node_type = sub_graph.seed_node_type 297 | seed_node_id = sub_graph.seed_node_id 298 | for (src_type, etype, dst_type) in dgl_graph_.canonical_etypes: 299 | if src_type == seed_node_type: 300 | src_id_tensor, dst_id_tensor, e_id_tensor = dgl_graph_.edges( 301 | form="all", etype=(src_type, etype, dst_type) 302 | ) 303 | edge_events.extend( 304 | list( 305 | map( 306 | _id_tensors_to_edge_events, 307 | [ 308 | (src_id.item(), dst_id.item(), e_id) 309 | for (src_id, dst_id, e_id) in zip( 310 | src_id_tensor, dst_id_tensor, e_id_tensor 311 | ) 312 | ], 313 | ) 314 | ) 315 | ) 316 | edge_events = list(filter(None, edge_events)) 317 | return edge_events 318 | 319 | 320 | def extract_subgraph(dgl_graph_: DGLGraph, node_id: int, node_type: str): 321 | 322 | hop_nodes = defaultdict(list) 323 | hop_nodes[node_type].extend([node_id]) 324 | for canonical_etype in dgl_graph_.canonical_etypes: 325 | src_type, _, dst_type = canonical_etype 326 | if src_type == node_type: 327 | hop_nodes[dst_type].extend( 328 | dgl_graph_.successors(node_id, etype=canonical_etype).tolist() 329 | ) 330 | for k, v in hop_nodes.items(): 331 | hop_nodes[k] = list(set(v)) 332 | dgl_subgraph = dgl.node_subgraph(dgl_graph_, hop_nodes) 333 | dgl_subgraph.ndata["ID"] = dgl_subgraph.ndata[dgl.NID] 334 | 335 | sub_graph = SubGraph._make([node_id, node_type, dgl_subgraph]) 336 | return sub_graph 337 | 338 | 339 | def 
edge_dict_to_dgl(edge_dict: Dict): 340 | # 仅在用户输入数据的时候使用一次 341 | # 由于DGLGraph不认识时间tensor和特征tensor,先将时间、拓扑逻辑、边上的特征分离,在组合到DGLGraph中 342 | dgl_dict = {} 343 | time_dict = {} 344 | edge_feat_dict = {} 345 | for metapath, uvtf in edge_dict.items(): 346 | uv = (uvtf[0], uvtf[1]) 347 | t = uvtf[2] 348 | if len(uvtf) > 3: 349 | f = uvtf[3] 350 | edge_feat_dict[metapath] = f 351 | dgl_dict[metapath] = uv 352 | time_dict[metapath] = t 353 | dgl_graph_ = dgl.heterograph(dgl_dict) 354 | dgl_graph_.edata["time"] = time_dict 355 | if len(edge_feat_dict) > 0: 356 | dgl_graph_.edata["feat"] = edge_feat_dict 357 | return dgl_graph_ 358 | 359 | -------------------------------------------------------------------------------- /transworld/game/model/HGT.py: -------------------------------------------------------------------------------- 1 | import dgl 2 | import torch 3 | import torch.nn as nn 4 | from dgl import DGLGraph 5 | import dgl.nn.pytorch as dglnn 6 | from .linear import HeteroToHomoLinear 7 | from typing import Any, Dict, List, Union, Optional, Tuple 8 | from torch import Tensor 9 | 10 | 11 | class HGT(nn.Module): 12 | def __init__( 13 | self, 14 | in_dim: Dict[str, int], 15 | n_ntypes: int, 16 | n_etypes: int, 17 | hid_dim: int, 18 | n_layers: int, 19 | n_heads: int, 20 | dropout: float = 0.2, 21 | use_norm=True, 22 | activation: Optional[nn.Module] = None, 23 | ): 24 | """A simplest heterogenous graph neural network model. 25 | 26 | Args: 27 | in_dim (Dict[str, int]): The dimensions of different node features 28 | out_dim (Dict[str, int]): The dimensions of different node represemtation learned by HGT 29 | n_ntypes (int): Num of node types 30 | n_etypes (int): Num of edge types 31 | hid_dim (int): Dimension of universal feature sapce 32 | n_layers (int): Num of GNN layers 33 | n_heads (int): Num of attention heads 34 | dropout (float, optional): dropout. Defaults to 0.2. 35 | use_norm (bool, optional): normalization. Defaults to True. 36 | """ 37 | super(HGT, self).__init__() 38 | self.gnns = nn.ModuleList() 39 | self.linears = nn.ModuleList() 40 | self.bns = nn.ModuleList() 41 | self.hid_dim = hid_dim # 定义统一特征空间维数,将不同类型的特征映射到统一特征空间 42 | self.n_layers = n_layers 43 | self.hetero_input_projector = HeteroToHomoLinear( 44 | in_dim, hid_dim, activation=activation 45 | ) 46 | in_size_list = [hid_dim] + [hid_dim * n_heads for _ in range(n_layers - 1)] 47 | for in_size in in_size_list: 48 | self.gnns.append( 49 | dglnn.HGTConv( 50 | in_size=in_size, 51 | head_size=hid_dim, 52 | num_heads=n_heads, 53 | num_ntypes=n_ntypes, 54 | num_etypes=n_etypes, 55 | dropout=dropout, 56 | use_norm=use_norm, 57 | ) 58 | ) 59 | self.linears.append( 60 | nn.Linear(hid_dim * n_heads, hid_dim * n_heads, bias=True) 61 | ) 62 | self.bns.append(nn.BatchNorm1d(hid_dim * n_heads)) 63 | self.activation = activation 64 | 65 | def forward( 66 | self, subgraph: DGLGraph, *, hetero_output: bool = True 67 | ) -> Dict[str, Tensor]: 68 | """_summary_ 69 | 70 | Args: 71 | subgraph (DGLGraph): a suggraph that can be seen by a agent 72 | feat_name (str): user-defined feature name 73 | hetero_output (bool, optional): user can use heterogenous output or homogenous one. Defaults to True. Note that this parameter is not ready in this version. 74 | 75 | Returns: 76 | _type_: _description_ 77 | """ 78 | with subgraph.local_scope(): 79 | # 先将不同类型的节点特征映射到同一空间 80 | assert len(subgraph.ntypes) > 1, ValueError( 81 | "HGT only support heterogenous graph, but the input graph is homogenoous one." 
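# The forward pass below: (1) hetero_input_projector maps each node type's "state"
# features into the shared hid_dim space, (2) the graph is flattened to a homogeneous
# one so HGTConv can run over node/edge type indicators, (3) the final
# (N, hid_dim * n_heads) representation is split back out per node type when
# hetero_output is True, e.g. node_repr = encoder(subgraph) -> {"veh": Tensor, ...}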
82 | ) 83 | subgraph.ndata["repr"] = self.hetero_input_projector( 84 | subgraph.ndata["state"] 85 | ) 86 | # 然后转成同质图 87 | g = dgl.to_homogeneous(subgraph, ndata=["repr"]) 88 | ntype_indicator = g.ndata[dgl.NTYPE] 89 | etype_indicator = g.edata[dgl.ETYPE] 90 | h = g.ndata["repr"].squeeze() 91 | for i in range(self.n_layers): 92 | h = self.gnns[i](g, h, ntype_indicator, etype_indicator) 93 | if self.activation is not None: 94 | h = self.activation(h) 95 | h = self.bns[i](h) 96 | h = self.linears[i](h) 97 | if hetero_output is True: 98 | h = { 99 | ntype: h[ntype_indicator == i] 100 | for i, ntype in enumerate(subgraph.ntypes) 101 | } 102 | return h 103 | -------------------------------------------------------------------------------- /transworld/game/model/__init__.py: -------------------------------------------------------------------------------- 1 | from .HGT import HGT 2 | from .linear import HomoToHeteroLinear, HeteroToHomoLinear 3 | from .generator import OneShotGenerator, SequentialGenerator, RuleBasedGenerator 4 | from .loss import GraphStateLoss 5 | -------------------------------------------------------------------------------- /transworld/game/model/generator.py: -------------------------------------------------------------------------------- 1 | from copy import deepcopy 2 | import torch 3 | import torch.nn as nn 4 | from dgl import DGLGraph 5 | import dgl.nn.pytorch as dglnn 6 | from .linear import HomoToHeteroLinear 7 | from sklearn.preprocessing import StandardScaler 8 | from typing import Any, Dict, List, Union, Optional, Tuple, Callable, Generic 9 | from torch import Tensor 10 | from abc import ABC, abstractmethod 11 | from game.operator.transform import dgl_graph_to_graph_dict 12 | 13 | 14 | class OneShotGenerator(nn.Module): 15 | """Directly generate new graph from a graph 16 | 注意这种模型无法使用图补全来预测,而必须使用图生成的方法,因为图补全无法产生delete操作 17 | 18 | Args: 19 | nn (_type_): _description_ 20 | """ 21 | 22 | def __init__(self): 23 | super().__init__() 24 | 25 | def forward( 26 | self, 27 | node_name_at_time: List[str], 28 | subgraph: DGLGraph, 29 | node_repr: Dict[str, Tensor], 30 | ) -> Tuple[Dict[str, Tensor], Dict[str, List[str]], DGLGraph]: 31 | new_graph = self.generate_graph(node_name_at_time, subgraph, node_repr) 32 | node_states = self.generate_states(node_name_at_time, subgraph, node_repr) 33 | actions = self.graph_to_actions(new_graph) 34 | return node_states, actions, new_graph 35 | 36 | @abstractmethod 37 | def generate_graph( 38 | self, 39 | node_name_at_time: List[str], 40 | subgraph: DGLGraph, 41 | node_repr: Dict[str, Tensor], 42 | ) -> DGLGraph: 43 | # NN based graph generator 44 | pass 45 | 46 | @abstractmethod 47 | def generate_states( 48 | self, 49 | node_name_at_time: List[str], 50 | subgraph: DGLGraph, 51 | node_repr: Dict[str, Tensor], 52 | ) -> Dict[str, Tensor]: 53 | # NN based graph generator 54 | pass 55 | 56 | def graph_to_actions(self, subgraph: DGLGraph): 57 | # TODO 58 | # new graph to actions 59 | NotImplementedError 60 | 61 | 62 | class SequentialGenerator(nn.Module): 63 | """Generate actions step by step from a graph 64 | 65 | Args: 66 | nn (_type_): _description_ 67 | """ 68 | 69 | def __init__(self): 70 | super().__init__() 71 | 72 | def forward( 73 | self, 74 | node_name_at_time: List[str], 75 | subgraph: DGLGraph, 76 | node_repr: Dict[str, Tensor], 77 | ) -> Tuple[Dict[str, Tensor], Dict[str, List[str]], DGLGraph]: 78 | actions, node_states = self.generate_actions_and_state( 79 | node_name_at_time, subgraph, node_repr 80 | ) 81 | delta_graph = 
self.actions_to_graph(actions) 82 | return node_states, actions, delta_graph 83 | 84 | @abstractmethod 85 | def generate_actions_and_state( 86 | self, 87 | node_name_at_time: List[str], 88 | subgraph: DGLGraph, 89 | node_repr: Dict[str, Tensor], 90 | ) -> Tuple[Dict[str, List[str]], Dict[str, Tensor]]: 91 | # TODO NN based actions generator 92 | pass 93 | 94 | def actions_to_graph(self, actions: Dict[str, List[str]]): 95 | # TODO 96 | # actions to graph 97 | NotImplementedError 98 | 99 | 100 | class RuleBasedGenerator(nn.Module): 101 | def __init__( 102 | self, 103 | hetero_feat_dim: Dict, 104 | repr_dim: int, 105 | pred_dim: Dict[str, int], 106 | scalers: Optional[Dict] = None, 107 | activation: Optional[nn.Module] = None, 108 | output_activation: Optional[nn.Module] = None, 109 | ): 110 | super().__init__() 111 | self.hetero_feat_dim = hetero_feat_dim 112 | self.state_projector = HomoToHeteroLinear( 113 | repr_dim, 114 | pred_dim, 115 | activation=activation, 116 | output_activation=output_activation, 117 | ) 118 | self.scalers = scalers 119 | 120 | def register_rule(self, rule_func) -> None: 121 | self.rule_func = rule_func 122 | 123 | def generate_actions_and_state( 124 | self, 125 | operate_function, 126 | node_name_at_time: List[str], 127 | subgraph: DGLGraph, 128 | node_repr: Dict[str, Tensor], 129 | *args, 130 | **kwargs, 131 | ) -> Tuple[Dict[str, Tensor], Dict[str, List[str]]]: 132 | node_states, node_state_scaled = self.generate_state(node_repr) 133 | actions = operate_function( 134 | self.rule_func, 135 | node_name_at_time, 136 | subgraph, 137 | node_state_scaled, 138 | *args, 139 | **kwargs, 140 | ) 141 | return node_states, actions 142 | 143 | def generate_state( 144 | self, node_repr: Dict[str, Tensor] 145 | ) -> Tuple[Dict[str, Tensor], Dict[str, Tensor]]: 146 | node_state = self.state_projector(node_repr) 147 | if self.scalers is not None: 148 | node_state_scaled = { 149 | ntype: sklearn_scaling(self.scalers[ntype], nstate) 150 | for ntype, nstate in node_state.items() 151 | } 152 | return node_state, node_state_scaled 153 | 154 | def generate_actions( 155 | self, 156 | hook_fn, 157 | node_name_at_time: List[str], 158 | subgraph: DGLGraph, 159 | node_states: Dict[str, Tensor], 160 | *args, 161 | **kwargs, 162 | ) -> Dict[str, List[str]]: 163 | subgraph.ndata["state"] = node_states 164 | struc_dict, feat_dict = dgl_graph_to_graph_dict(subgraph, self.hetero_feat_dim) 165 | actions = hook_fn(node_name_at_time, struc_dict, feat_dict, *args, **kwargs) 166 | return actions 167 | 168 | def forward( 169 | self, 170 | node_name_at_time: List[str], 171 | subgraph: DGLGraph, 172 | node_repr: Dict[str, Tensor], 173 | *args, 174 | **kwargs, 175 | ) -> Tuple[Dict[str, List[str]], DGLGraph]: 176 | outgraph = deepcopy(subgraph) 177 | node_states, actions = self.generate_actions_and_state( 178 | self.generate_actions, 179 | node_name_at_time, 180 | outgraph, 181 | node_repr, 182 | *args, 183 | **kwargs, 184 | ) 185 | outgraph.ndata["state"] = node_states 186 | return actions, outgraph 187 | 188 | def output_graph_dict( 189 | self, 190 | subgraph: DGLGraph, 191 | node_repr: Dict[str, Tensor], 192 | *args, 193 | **kwargs, 194 | ) -> Tuple[Dict, Dict]: 195 | node_states = self.generate_state(node_repr) 196 | struc_dict, feat_dict = dgl_graph_to_graph_dict(subgraph, self.hetero_feat_dim) 197 | return struc_dict, feat_dict 198 | 199 | 200 | def sklearn_scaling(scaler, X: Tensor) -> Tensor: 201 | X_need_scale = X[:, :-1] 202 | X_no_scale = X[:, [-1]] 203 | X_scaled = ( 204 | X_need_scale - 
torch.tensor(scaler.min_).float().to(X_need_scale.device) 205 | ) / torch.tensor(scaler.scale_).float().to(X_need_scale.device) 206 | X = torch.cat([X_scaled, X_no_scale], dim=1) 207 | return X 208 | -------------------------------------------------------------------------------- /transworld/game/model/linear.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from typing import Any, Dict, List, Union, Optional, Tuple 4 | 5 | 6 | class HomoToHeteroLinear(nn.Module): 7 | def __init__( 8 | self, 9 | in_size: int, 10 | out_size: Dict[str, int], 11 | activation: Optional[nn.Module] = None, 12 | output_activation: Optional[nn.Module] = None, 13 | bias: bool = True, 14 | ): 15 | """Apply linear transformations from homogeneous inputs to heterogeneous inputs. 16 | 17 | Args: 18 | in_size (int): Input feature size. 19 | out_size (Dict[str, int]): Output feature size for heterogeneous inputs. A key can be a string or a tuple of strings. 20 | activation (torch.nn.Module, optional): activative function. Defaults to None. 21 | bias (bool, optional): bias of network parameters. Defaults to True. 22 | """ 23 | super(HomoToHeteroLinear, self).__init__() 24 | self.activation = activation 25 | self.linears = nn.ModuleDict() 26 | self.projector = nn.Linear(in_size, in_size // 2, bias=bias) 27 | assert isinstance(in_size, int), "input size should be int" 28 | assert isinstance(out_size, dict), "output size should be dict" 29 | self.output_activation = output_activation 30 | for typ, dim in out_size.items(): 31 | self.linears[str(typ)] = nn.Linear(in_size // 2, dim, bias=bias) 32 | 33 | def forward(self, feat: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: 34 | """Forward Function 35 | 36 | Args: 37 | feat (Dict[str, torch.Tensor]): Heterogeneous input features. It maps keys to features. 38 | 39 | Returns: 40 | Dict[str, torch.Tensor]: Transformed features. 41 | """ 42 | out_feat = dict() 43 | 44 | for typ, typ_feat in feat.items(): 45 | out = typ_feat 46 | out = self.projector(out) 47 | if self.activation: 48 | out = self.activation(out) 49 | out = self.linears[str(typ)](out) 50 | if self.output_activation: 51 | out = self.output_activation(out) 52 | out_feat[typ] = out 53 | return out_feat 54 | 55 | 56 | class HeteroToHomoLinear(nn.Module): 57 | def __init__( 58 | self, 59 | in_size: Dict[str, int], 60 | out_size: int, 61 | activation: Optional[nn.Module] = None, 62 | bias: bool = True, 63 | ): 64 | """Apply linear transformations on heterogeneous inputs. 65 | 66 | Args: 67 | in_size (Dict[str, int]): Input feature size for heterogeneous inputs. A key can be a string or a tuple of strings. 68 | out_size (int): Output feature size. 69 | activation (nn.Module, optional): activative function.. Defaults to None. 70 | bias (bool, optional): bias of network parameters. Defaults to True. 
71 | 72 | Examples: 73 | >>> import torch 74 | >>> from game.model.linear import HeteroToHomoLinear 75 | 76 | >>> layer = HeteroToHomoLinear({'user': 1, ('user', 'follows', 'user'): 2}, 3) 77 | >>> in_feats = {'user': torch.randn(2, 1), ('user', 'follows', 'user'): torch.randn(3, 2)} 78 | >>> out_feats = layer(in_feats) 79 | >>> print(out_feats['user'].shape) 80 | torch.Size([2, 3]) 81 | >>> print(out_feats[('user', 'follows', 'user')].shape) 82 | torch.Size([3, 3]) 83 | """ 84 | super(HeteroToHomoLinear, self).__init__() 85 | self.activation = activation 86 | self.linears = nn.ModuleDict() 87 | for typ, typ_in_size in in_size.items(): 88 | self.linears[str(typ)] = nn.Linear(typ_in_size, out_size // 2, bias=bias) 89 | self.projector = nn.Linear(out_size // 2, out_size, bias=bias) 90 | 91 | def forward(self, feat: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: 92 | """Forward function 93 | Args: 94 | feat (Dict[str, torch.Tensor]): Heterogeneous input features. It maps keys to features. 95 | 96 | Returns: 97 | Dict[str, torch.Tensor]: Transformed features. 98 | """ 99 | out_feat = dict() 100 | for typ, typ_feat in feat.items(): 101 | if len(typ_feat.shape) > 2: 102 | typ_feat = typ_feat[ 103 | :, -1, : 104 | ] # this module does not support temporal modeling, so if the input carries multiple time steps we simply take the last step 105 | out = self.linears[str(typ)](typ_feat) 106 | if self.activation is not None: 107 | out = self.activation(out) 108 | out = self.projector(out) 109 | out_feat[typ] = out 110 | return out_feat 111 | -------------------------------------------------------------------------------- /transworld/game/model/loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from dgl import DGLGraph 4 | from typing import List 5 | import torch.nn.functional as F 6 | from torch import Tensor 7 | 8 | 9 | class GraphStateLoss(nn.Module): 10 | def __init__(self, weight=None, size_average=True): 11 | super(GraphStateLoss, self).__init__() 12 | 13 | def forward(self, pred_graph: DGLGraph, tar_graph: DGLGraph): 14 | loss = torch.Tensor([0.0]).to(pred_graph.device) 15 | total_node_type = get_intersection(pred_graph.ntypes, tar_graph.ntypes) 16 | for node_type in total_node_type: 17 | if tensor_is_equal( 18 | pred_graph.ndata["ID"][node_type], tar_graph.ndata["ID"][node_type] 19 | ): 20 | pred_state = pred_graph.ndata["state"][node_type] 21 | tar_state = ( 22 | tar_graph.ndata["state"][node_type][:, -1, :] 23 | if len(tar_graph.ndata["state"][node_type].shape) > 2 24 | else tar_graph.ndata["state"][node_type] 25 | ) 26 | else: 27 | node_id_of_pred = pred_graph.ndata["ID"][node_type].tolist() 28 | node_id_of_tar = tar_graph.ndata["ID"][node_type].tolist() 29 | comm_node_id = get_intersection(node_id_of_pred, node_id_of_tar) 30 | pred_state = torch.cat( 31 | [ 32 | pred_graph.ndata["state"][node_type][idx].unsqueeze(0) 33 | for idx, nid in enumerate(node_id_of_pred) 34 | if nid in comm_node_id 35 | ] 36 | ) 37 | tar_state = torch.cat( 38 | [ 39 | tar_graph.ndata["state"][node_type][idx, -1, :].unsqueeze(0) 40 | for idx, nid in enumerate(node_id_of_tar) 41 | if nid in comm_node_id 42 | ] 43 | ) 44 | loss_ = F.mse_loss(pred_state, tar_state) # only the state of the last step is compared; historical steps are not included in the loss 45 | loss = loss + loss_ 46 | return loss 47 | 48 | 49 | def get_intersection(ListA: List, ListB: List) -> List: 50 | return list(set(ListA).intersection(set(ListB))) 51 | 52 | 53 | def tensor_is_equal(TensorA: Tensor, TensorB: Tensor) -> bool: 54 | if len(TensorA) == len(TensorB) and (TensorA == TensorB).all().item(): 55 | return True 56 
| else: 57 | return False 58 | -------------------------------------------------------------------------------- /transworld/game/operator/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /transworld/game/operator/transform.py: -------------------------------------------------------------------------------- 1 | from typing import ( 2 | Any, 3 | Dict, 4 | Callable, 5 | Iterable, 6 | TypeVar, 7 | Generic, 8 | Sequence, 9 | List, 10 | Optional, 11 | Union, 12 | Optional, 13 | Tuple, 14 | ) 15 | from functools import partial, reduce 16 | from copy import deepcopy 17 | from collections import defaultdict, OrderedDict, ChainMap 18 | import dgl 19 | from dgl import DGLGraph 20 | import torch 21 | from torch.nn import functional as F 22 | from torch import Tensor 23 | 24 | 25 | def game_to_dgl( 26 | batched_structures: List, 27 | batched_features: Dict, 28 | timestamp: str, 29 | full_graph: bool = True, 30 | device: torch.device = torch.device("cpu"), 31 | ) -> Dict[str, List[DGLGraph]]: 32 | batched_dgl_graph = OrderedDict() 33 | if not full_graph: 34 | for structure in batched_structures: 35 | if list(structure.values())[0] == {}: 36 | # 如果找不到这个节点,就不返回 37 | continue 38 | seed_node_name = list(structure.keys())[0] 39 | dgl_graph = structure_to_dgl_graph(structure) 40 | dgl_graph = dgl.compact_graphs(dgl_graph) 41 | dgl_graph.ndata["ID"] = dgl_graph.ndata["_ID"] 42 | del dgl_graph.ndata["_ID"] 43 | node_feat_dict = feature_to_dgl_feature(dgl_graph, batched_features) 44 | 45 | dgl_graph = dgl_graph_join_feature(dgl_graph, node_feat_dict) 46 | dgl_graph = attach_node_state(dgl_graph) 47 | batched_dgl_graph[seed_node_name + "@" + str(timestamp)] = dgl_graph 48 | else: 49 | structure = ChainMap(*batched_structures) 50 | seed_node_name = "full/0" 51 | dgl_graph = structure_to_dgl_graph(structure) 52 | dgl_graph = dgl.compact_graphs(dgl_graph) 53 | dgl_graph.ndata["ID"] = dgl_graph.ndata["_ID"] 54 | del dgl_graph.ndata["_ID"] 55 | node_feat_dict = feature_to_dgl_feature(dgl_graph, batched_features) 56 | 57 | dgl_graph = dgl_graph_join_feature(dgl_graph, node_feat_dict) 58 | dgl_graph = attach_node_state(dgl_graph) 59 | batched_dgl_graph[seed_node_name + "@" + str(timestamp)] = dgl_graph 60 | return batched_dgl_graph 61 | 62 | 63 | def attach_node_state(subgraph: DGLGraph): 64 | for ntype in subgraph.ntypes: 65 | feat_names = subgraph.nodes[ntype].data.keys() 66 | subgraph.nodes[ntype].data["state"] = torch.cat( 67 | [ 68 | subgraph.nodes[ntype].data[feat_name] 69 | for feat_name in feat_names 70 | if "ID" not in feat_name 71 | ], 72 | dim=2, 73 | ) 74 | return subgraph 75 | 76 | 77 | def update_edge_and_time_tensor_( 78 | dst_list: List, 79 | dgl_structure_dict: Dict, 80 | dgl_edge_time_dict: Dict, 81 | src_type: str, 82 | src_id: int, 83 | ): 84 | """push edge information to dgl_structure_dict and dgl_edge_time_dict, from dst_list 85 | 86 | Args: 87 | dgl_structure_dict (_type_): _description_ 88 | dgl_edge_time_dict (_type_): _description_ 89 | src_type (_type_): _description_ 90 | src_id (_type_): _description_ 91 | dst_list (_type_): _description_ 92 | """ 93 | dst_name, edges = dst_list 94 | dst_type, dst_id = convert_name_to_type_and_id(dst_name) 95 | dst_id = torch.tensor(int(dst_id)) 96 | for etype, edge_time in edges: 97 | dgl_structure_dict[(src_type, etype, dst_type)] = torch.cat( 98 | [ 99 | dgl_structure_dict[(src_type, etype, dst_type)], 100 | 
torch.tensor((src_id, dst_id)).unsqueeze(1), 101 | ], 102 | dim=1, 103 | ).long() 104 | dgl_edge_time_dict[(src_type, etype, dst_type)] = torch.cat( 105 | [ 106 | dgl_edge_time_dict[(src_type, etype, dst_type)], 107 | torch.tensor((edge_time,)), 108 | ], 109 | dim=0, 110 | ) 111 | 112 | 113 | def structure_to_dgl_graph(structure: Dict) -> DGLGraph: 114 | dgl_structure_dict: Dict = defaultdict(Tensor) 115 | dgl_edge_time_dict: Dict = defaultdict(Tensor) 116 | for seed_node in list(structure.keys()): 117 | src_type, src_id = convert_name_to_type_and_id(seed_node) 118 | src_id = torch.tensor(int(src_id)) 119 | update_dicts_ = partial( 120 | update_edge_and_time_tensor_, 121 | dgl_structure_dict=dgl_structure_dict, 122 | dgl_edge_time_dict=dgl_edge_time_dict, 123 | src_type=src_type, 124 | src_id=src_id, 125 | ) 126 | list(map(update_dicts_, list(structure[seed_node].items()))) 127 | for key, values in dgl_structure_dict.items(): 128 | dgl_structure_dict[key] = tuple(values) 129 | dgl_graph = dgl.heterograph(dgl_structure_dict) 130 | dgl_graph.edata["time"] = ( 131 | dgl_edge_time_dict 132 | if len(dgl_graph.etypes) > 1 133 | else list(dgl_edge_time_dict.values())[0] 134 | ) 135 | return dgl_graph 136 | 137 | 138 | def feature_to_dgl_feature(dgl_graph: DGLGraph, batched_features: Dict) -> Dict: 139 | node_feat_dict: Dict = defaultdict(lambda: defaultdict(list)) 140 | for node_type in dgl_graph.ntypes: 141 | selected_feature = [ 142 | batched_features[convert_type_and_id_to_name(node_type, node_id)] 143 | for node_id in dgl_graph.nodes[node_type].data["ID"] 144 | ] 145 | 146 | time_merged_feature = [] 147 | for node_feat in deepcopy(selected_feature): 148 | time_merged_feature.append(time_merged_tensor(node_feat, dim=1)) 149 | node_merged_node_feat = reduce( 150 | lambda a, b: node_merged_tensor(a, b, 0), time_merged_feature 151 | ) 152 | if dgl_graph.num_nodes(node_type) == 1: # when there is only one node of a given type, an extra special-case step is needed 153 | node_merged_node_feat = { 154 | feat_name: feat_tensor 155 | for feat_name, feat_tensor in time_merged_feature[0].items() 156 | } 157 | node_feat_dict[node_type] = node_merged_node_feat 158 | return node_feat_dict 159 | 160 | 161 | def time_merged_tensor(merged_list: List, dim: int = 1): 162 | keys = list(merged_list[0][-1].keys()) 163 | return { 164 | key: torch.cat( 165 | [merged_list[index][-1][key] for index in range(len(merged_list))], dim 166 | ) 167 | for key in keys 168 | } 169 | 170 | 171 | def dgl_graph_join_feature(sg: DGLGraph, subfeature: Dict) -> DGLGraph: 172 | for ntype, feat_dict in subfeature.items(): 173 | for feat_name, feat in feat_dict.items(): 174 | sg.nodes[ntype].data[feat_name] = feat.to(sg.device) 175 | return sg 176 | 177 | 178 | def convert_name_to_type_and_id(node_name: str): 179 | return node_name.split("/") 180 | 181 | 182 | def convert_type_and_id_to_name(node_type: str, node_id): 183 | return node_type + "/" + str(int(node_id)) 184 | 185 | 186 | def node_merged_tensor(a, b, dim): 187 | # TODO Python's reduce performs poorly here; this function should be replaced in the next version 188 | for key in a: 189 | # zero-pad nodes that do not have enough time steps 190 | if a[key].shape[1] > b[key].shape[1]: 191 | b[key] = F.pad( 192 | b[key], (0, 0, 0, a[key].shape[1] - b[key].shape[1]), "constant", 0 193 | ) 194 | elif a[key].shape[1] < b[key].shape[1]: 195 | a[key] = F.pad( 196 | a[key], (0, 0, 0, b[key].shape[1] - a[key].shape[1]), "constant", 0 197 | ) 198 | a[key] = torch.cat([a[key], b[key]], dim=dim) 199 | return a 200 | 201 | 202 | def dgl_graph_to_graph_dict( 203 | dgl_graph: DGLGraph, hetero_feat_dim: Dict 204 | ) -> Tuple[Dict, Dict]: 205 | 
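    # Roughly the inverse of game_to_dgl: recover the game-native dicts from a DGLGraph.
    # extract_struc (below) rebuilds {(src_type, e_type, dst_type): (src_id, dst_id, edge_time)} using the original node IDs stored in ndata["ID"],
    # and extract_feat slices ndata["state"] back into named per-node feature tensors according to hetero_feat_dim.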
struc_dict = extract_struc(dgl_graph) 206 | feat_dict = extract_feat(dgl_graph, hetero_feat_dim) 207 | return struc_dict, feat_dict 208 | 209 | 210 | def extract_struc(dgl_graph: DGLGraph) -> Dict: 211 | struc_dict = {} 212 | for (src_type, e_type, dst_type) in dgl_graph.canonical_etypes: 213 | src_id_, dst_id_ = dgl_graph.edges(etype=(src_type, e_type, dst_type)) 214 | src_id = dgl_graph.ndata["ID"][src_type][src_id_].long() 215 | dst_id = dgl_graph.ndata["ID"][dst_type][dst_id_].long() 216 | edge_time = ( 217 | dgl_graph.edata["time"][(src_type, e_type, dst_type)] 218 | if len(dgl_graph.canonical_etypes) > 1 219 | else dgl_graph.edata["time"] 220 | ) 221 | struc_dict[(src_type, e_type, dst_type)] = (src_id, dst_id, edge_time) 222 | return struc_dict 223 | 224 | 225 | def extract_feat(dgl_graph: DGLGraph, hetero_feat_dim: Dict) -> Dict: 226 | feat_dict = {} 227 | dgl_feat = dgl_graph.ndata["state"] 228 | for ntype in dgl_feat.keys(): 229 | feat_dim_dict = hetero_feat_dim[ntype] 230 | feats = extract_feat_tool(ntype, dgl_graph, dgl_feat, feat_dim_dict) 231 | feat_dict[ntype] = feats 232 | return feat_dict 233 | 234 | 235 | # TODO 这个函数要单独加入测试 236 | def extract_feat_tool( 237 | ntype, dgl_graph: DGLGraph, dgl_feat: Dict, feat_dim_dict: Dict 238 | ) -> Dict: 239 | feats: Dict = defaultdict(dict) 240 | for i, node_id in enumerate(dgl_graph.ndata["ID"][ntype]): 241 | end_idx = 0 242 | node_id = int(node_id) 243 | for feat_name, feat_dim in feat_dim_dict.items(): 244 | feats[node_id][feat_name] = ( 245 | dgl_feat[ntype][i, :, end_idx : end_idx + feat_dim] 246 | if len(dgl_feat[ntype].shape) == 3 247 | else dgl_feat[ntype][i, end_idx : end_idx + feat_dim] 248 | ) 249 | end_idx += feat_dim 250 | return feats 251 | -------------------------------------------------------------------------------- /transworld/graph/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SACLabs/TransWorldNG/906c9283ed8e8121b0650869a4cdef4ab36c3848/transworld/graph/__init__.py -------------------------------------------------------------------------------- /transworld/graph/load.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | from pathlib import Path 3 | from graph.process import ( 4 | generate_feat_dict, 5 | generate_graph_dict, 6 | generate_unique_node_id, 7 | ) 8 | from typing import Dict, Tuple 9 | from collections import defaultdict 10 | from sklearn.preprocessing import MinMaxScaler 11 | 12 | def check_file_exist(file_path: Path): 13 | if not Path(file_path).exists(): 14 | raise FileNotFoundError("File path does not exist") 15 | 16 | def unique_id(data_path: Path): 17 | check_file_exist(data_path) 18 | node_all = pd.read_csv(data_path / "node_all.csv") 19 | node_id_dict = generate_unique_node_id(node_all) 20 | return node_id_dict 21 | 22 | def load_graph(data_path: Path, start_step, end_step,node_id_dict) -> Tuple[Dict, Dict]: 23 | """ 24 | Load graph data from csv files. 
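Only csv files with stems ``node``, ``edge``, ``feat_veh``, ``feat_lane``, ``feat_road`` and ``feat_tlc`` are consumed; each feature table is min-max scaled per node type and the fitted scalers are returned together with the graph and feature dicts.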
25 | - read node,edge,feature files, assign unique node id to each node 26 | - generate graph dict, feature dict 27 | return: dicts containing the graph data, feature data 28 | """ 29 | 30 | check_file_exist(data_path) 31 | all_files = list(data_path.glob("*.csv")) 32 | #node_all = pd.read_csv(data_path / "node_all.csv") 33 | #node_id_dict = unique_id(node_all) 34 | 35 | 36 | """Read files and get unique node id""" 37 | dfs: Dict = defaultdict(dict) 38 | file_names = ["node", "edge", "feat_veh", "feat_lane", "feat_road", "feat_tlc"] 39 | feat_file_names = ["feat_veh", "feat_lane", "feat_road", "feat_tlc"] 40 | #except_columns = ['name', 'coordinate',"shape"] 41 | scalers = {'veh': MinMaxScaler(), 'lane': MinMaxScaler(), 'road': MinMaxScaler(), 'tlc': MinMaxScaler()} 42 | 43 | for data_file in all_files: 44 | df_name = data_file.stem 45 | df = pd.read_csv(data_file) 46 | 47 | if df_name in feat_file_names: 48 | #df = df.loc[(df['step'] >= start_step) & (df['step'] <= end_step)] 49 | if df_name == 'feat_veh': 50 | df['coor_x'] = [float(coor.replace('(','').replace(")", '').split(',')[0]) for coor in df['coordinate']] 51 | df['coor_y'] = [float(coor.replace('(','').replace(")", '').split(',')[1]) for coor in df['coordinate']] 52 | select_columns=df.columns.difference(['step', 'name','coordinate']) 53 | scaled = pd.DataFrame(scalers['veh'].fit_transform(df[select_columns]), columns=select_columns, index=df.index) 54 | for col in select_columns: 55 | df[col] = scaled[col] 56 | df = df.drop(['coordinate'], axis=1) 57 | elif df_name == 'feat_lane': 58 | df['shape_a'] = [float(coor.replace('(','').replace(")", '').split(',')[0]) for coor in df['shape']] 59 | df['shape_b'] = [float(coor.replace('(','').replace(")", '').split(',')[1]) for coor in df['shape']] 60 | df['shape_c'] = [float(coor.replace('(','').replace(")", '').split(',')[2]) for coor in df['shape']] 61 | df['shape_d'] = [float(coor.replace('(','').replace(")", '').split(',')[3]) for coor in df['shape']] 62 | select_columns=df.columns.difference(['step', 'name','shape']) 63 | scaled = pd.DataFrame(scalers['lane'].fit_transform(df[select_columns]), columns=select_columns, index=df.index) 64 | for col in select_columns: 65 | df[col] = scaled[col] 66 | df = df.drop(['shape'], axis=1) 67 | elif df_name == 'feat_road': 68 | select_columns=df.columns.difference(['step', 'name']) 69 | scaled = pd.DataFrame(scalers['road'].fit_transform(df[select_columns]), columns=select_columns, index=df.index) 70 | for col in select_columns: 71 | df[col] = scaled[col] 72 | elif df_name == 'feat_tlc': 73 | select_columns=df.columns.difference(['step', 'name']) 74 | scaled = pd.DataFrame(scalers['tlc'].fit_transform(df[select_columns]), columns=select_columns, index=df.index) 75 | for col in select_columns: 76 | df[col] = scaled[col] 77 | else: 78 | pass 79 | 80 | dfs[df_name] = df 81 | if df_name == "edge": 82 | dfs[df_name]["to_id"] = [node_id_dict[i] for i in dfs[df_name]["to"]] 83 | dfs[df_name]["from_id"] = [node_id_dict[i] for i in dfs[df_name]["from"]] 84 | elif df_name in file_names: 85 | dfs[df_name]["node_id"] = [ 86 | node_id_dict[str(i)] for i in dfs[df_name]["name"] 87 | ] 88 | else: 89 | pass 90 | 91 | """Generate graph dict""" 92 | graph_dict = generate_graph_dict(dfs["edge"], start_step, end_step) 93 | #print(graph_dict) 94 | """Generate feature dict""" 95 | feat_dict = generate_feat_dict("lane", dfs["feat_lane"], start_step, end_step) 96 | feat_dict.update(generate_feat_dict("veh", dfs["feat_veh"], start_step, end_step)) 97 | 
feat_dict.update(generate_feat_dict("road", dfs["feat_road"], start_step, end_step)) 98 | feat_dict.update(generate_feat_dict("tlc", dfs["feat_tlc"], start_step, end_step)) 99 | #print(feat_dict) 100 | return graph_dict, feat_dict, node_id_dict, scalers 101 | -------------------------------------------------------------------------------- /transworld/graph/process.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | from collections import defaultdict 3 | import torch 4 | from typing import Dict 5 | import ast 6 | 7 | def generate_unique_node_id(node_data: pd.DataFrame) -> Dict[str, int]: 8 | node_id_dict = {} 9 | # for node_type in set(node_data["type"]): 10 | for node_type in node_data["type"].unique(): 11 | # unique_names = list(set(node_data[node_data["type"] == node_type].name)) 12 | unique_names = list((node_data[node_data["type"] == node_type].name).unique()) 13 | node_id_dict_ = {name: unique_names.index(name) for name in unique_names} 14 | node_id_dict.update(node_id_dict_) 15 | return node_id_dict 16 | 17 | 18 | def generate_graph_dict(edge_Data: pd.DataFrame, start_step, end_step) -> Dict: 19 | g_dict: Dict = defaultdict(list) 20 | edge_Data = edge_Data.loc[(edge_Data['step'] >= start_step) & (edge_Data['step'] <= end_step)] 21 | for rel in edge_Data.relation.unique(): 22 | sub_edge = edge_Data[edge_Data["relation"] == rel] 23 | orig_type, rel_type, dest_type = rel.split("_") 24 | g_dict[(orig_type, rel_type, dest_type)] = [ 25 | torch.tensor( 26 | [ 27 | sub_edge.iloc[i]["from_id"], 28 | sub_edge.iloc[i]["to_id"], 29 | sub_edge.iloc[i]["step"], 30 | ] 31 | ) 32 | for i in range(sub_edge.shape[0]) 33 | ] 34 | 35 | for k, v in g_dict.items(): 36 | tensor_ = torch.cat([t.unsqueeze(1) for t in v], dim=1) 37 | g_dict[k] = (tensor_[0].long(), tensor_[1].long(), tensor_[2].long()) 38 | return g_dict 39 | 40 | 41 | def generate_feat_dict(node_type: str, feat_data: pd.DataFrame, start_step, end_step) -> Dict: 42 | feat_data = feat_data.loc[(feat_data['step'] >= start_step) & (feat_data['step'] <= end_step)] 43 | g_feat_dict: Dict = defaultdict(dict) 44 | node_id_list = feat_data.node_id.unique() 45 | sub_feat_data = feat_data.drop(["name", "node_id"], axis=1) 46 | for node_i in node_id_list: 47 | sub_node = feat_data[feat_data.node_id == node_i] 48 | 49 | feat_dict = { 50 | feat_i: torch.tensor(list([ast.literal_eval(str(x)) for x in sub_node[feat_i]])) 51 | for feat_i in sub_feat_data.columns 52 | } 53 | # feat_dict["step_keep"] = feat_dict["step"] 54 | feat_dict["time"] = feat_dict.pop("step") 55 | g_feat_dict[node_i].update(feat_dict) 56 | 57 | return {node_type: g_feat_dict} 58 | -------------------------------------------------------------------------------- /transworld/rules/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SACLabs/TransWorldNG/906c9283ed8e8121b0650869a4cdef4ab36c3848/transworld/rules/__init__.py -------------------------------------------------------------------------------- /transworld/rules/post_process.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, Union, List 2 | from collections import defaultdict 3 | from graph.process import generate_unique_node_id 4 | import pandas as pd 5 | from pathlib import Path 6 | 7 | 8 | def load_veh_route(filename, data_path: Path) -> Dict: 9 | node_all = pd.read_csv(data_path / "node_all.csv") 10 | node_id_dict = 
generate_unique_node_id(node_all) 11 | data_file = pd.read_csv(data_path / (filename + ".csv")) 12 | data_file["veh_id"] = [node_id_dict[str(i)] for i in data_file["name"]] 13 | rou_id = [] 14 | for rou in data_file["route"]: 15 | rou = [x for x in rou if x not in ["[", "'", "]", " "]] 16 | rou = "".join(rou).split(",") 17 | rou_id.append([node_id_dict[str(i)] for i in rou]) 18 | data_file["route"] = rou_id 19 | veh_route = data_file.set_index("veh_id").T.to_dict() 20 | return veh_route 21 | 22 | 23 | def get_veh_current_lane(struc_dict: Dict) -> int: 24 | # if ("veh", "phy/to", "lane") in struc_dict.keys(): 25 | lane_id = struc_dict[("veh", "phy/to", "lane")][1] 26 | return int(lane_id[-1]) 27 | 28 | 29 | def get_veh_next_lane( 30 | veh_id: int, veh_route: dict, cur_lane_id: int 31 | ) -> Union[str, None]: 32 | route_lst = veh_route[veh_id]["route"] 33 | cur_lane_idx = route_lst.index(cur_lane_id) 34 | if cur_lane_idx < len(route_lst) - 1: 35 | next_lane = route_lst[cur_lane_idx + 1] 36 | return next_lane 37 | else: 38 | return None 39 | 40 | def post_actions( 41 | node_names: List[str], struc_dict: Dict, feat_dict: Dict, veh_route: Dict 42 | ) -> Dict: # return action: ["node_name","add_edge(veh1,on,lane1)"] 43 | struc_actions = defaultdict(list) 44 | 45 | veh_list = struc_dict[("veh", "phy/to", "lane")][0] 46 | lane_list = struc_dict[("veh", "phy/to", "lane")][1] 47 | time_lst = struc_dict[("veh", "phy/to", "lane")][2] 48 | action_time = node_names[0].split('@')[-1] # TODO this code only support full graph inference 49 | action = [] 50 | min_dis = 0.95 51 | aggr_edge_list = defaultdict(float) 52 | 53 | for veh_id, lane_id, time in zip(veh_list, lane_list, time_lst): 54 | veh_id, lane_id, time = int(veh_id), int(lane_id), int(time) 55 | if aggr_edge_list.get((veh_id), None) is None: 56 | aggr_edge_list[veh_id] = {'time':time, 'lane_node': lane_id} 57 | else: 58 | new_value_dict = aggr_edge_list[veh_id] if aggr_edge_list[veh_id]['time'] > time else {'time':time, 'lane_node': lane_id} 59 | aggr_edge_list[veh_id] = new_value_dict 60 | 61 | 62 | for veh_id in list(aggr_edge_list.keys()): 63 | lane_id = aggr_edge_list[veh_id]['lane_node'] 64 | current_lane_len = abs(feat_dict["lane"][lane_id]["length"]) 65 | pos_on_lane = abs(feat_dict["veh"][veh_id]["pos_on_lane"]) 66 | 67 | if pos_on_lane / current_lane_len > min_dis: 68 | next_lane = get_veh_next_lane(veh_id, veh_route, lane_id) 69 | tlc_state = feat_dict["veh"][veh_id]["tlc_state"] 70 | if next_lane is None: # This vehicle has reached the destination 71 | action.append("delete_node(veh/" + str(veh_id)+ ")") 72 | # elif (next_lane is not None) and ( 73 | # tlc_state >= 0 74 | # ): # This vehicle will move to it's next route if it's upcoming tlc state is either green(1) or yellow(0) 75 | elif next_lane is not None: 76 | action.append( 77 | "add_edge(veh/" 78 | + str(veh_id) 79 | + ",phy/to," 80 | + "lane/" 81 | + str(next_lane) 82 | + ")" 83 | ) 84 | # action.append( 85 | # "delete_edge(veh/" 86 | # + str(veh_id) 87 | # + ",phy/to," 88 | # + "lane/" 89 | # + str(lane_id) 90 | # + ")" 91 | #) 92 | if action != []: 93 | struc_actions.update({'veh/'+str(veh_id)+'@'+str(action_time): action}) 94 | return struc_actions 95 | 96 | 97 | 98 | # for node_name in node_names: 99 | # action = [] 100 | # min_dis = 10 101 | # node_type, id_step = node_name.split("/") 102 | # node_id = int(id_step.split("@")[0]) 103 | # if node_type == "veh" : 104 | # """ 105 | # Change lane action when approaching the end of lane. 
106 | # return: move to next lane if availiable, wait if tlc is red, remove node if reached destination 107 | # """ 108 | # min_dis = 10 # minimum distance for decision when approaching the end of lane 109 | 110 | 111 | 112 | # current_lane = get_veh_current_lane(struc_dict) 113 | # current_lane_len = abs(feat_dict["lane"][current_lane]["length"]) 114 | # pos_on_lane = abs(feat_dict["veh"][node_id]["pos_on_lane"]) 115 | # if current_lane_len - pos_on_lane < min_dis: 116 | # next_lane = get_veh_next_lane(node_id, veh_route, current_lane) 117 | # tlc_state = feat_dict["veh"][node_id]["tlc_state"] 118 | # if next_lane is None: # This vehicle has reached the destination 119 | # action.append("delete_node(veh/" + str(node_id)+ ")") 120 | # elif (next_lane is not None) and ( 121 | # tlc_state >= 0 122 | # ): # This vehicle will move to it's next route if it's upcoming tlc state is either green(1) or yellow(0) 123 | # action.append( 124 | # "add_edge(veh/" 125 | # + str(node_id) 126 | # + ",phy/to," 127 | # + "lane/" 128 | # + str(next_lane) 129 | # + ")" 130 | # ) 131 | # action.append( 132 | # "delete_edge(veh/" 133 | # + str(node_id) 134 | # + ",phy/to," 135 | # + "lane/" 136 | # + str(current_lane) 137 | # + ")" 138 | # ) 139 | # if action != []: 140 | # struc_actions.update({node_name: action}) 141 | # return struc_actions 142 | 143 | 144 | def get_feat_actions( 145 | node_names: List[str], struc_dict: Dict, feat_dict: Dict, veh_od: Dict 146 | ) -> Dict: 147 | pass 148 | -------------------------------------------------------------------------------- /transworld/rules/post_process_old.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, Union, List 2 | from collections import defaultdict 3 | from graph.process import generate_unique_node_id 4 | import pandas as pd 5 | from pathlib import Path 6 | 7 | 8 | def load_veh_route(filename, data_path: Path) -> Dict: 9 | pass 10 | 11 | 12 | def get_veh_current_lane(struc_dict: Dict) -> int: 13 | # if ("veh", "phy/to", "lane") in struc_dict.keys(): 14 | lane_id = struc_dict[("veh", "phy/to", "lane")][1] 15 | return int(lane_id[-1]) 16 | 17 | 18 | def get_veh_next_lane( 19 | veh_id: int, veh_route: dict, cur_lane_id: int 20 | ) -> Union[str, None]: 21 | route_lst = veh_route[veh_id]["route"] 22 | cur_lane_idx = route_lst.index(cur_lane_id) 23 | if cur_lane_idx < len(route_lst) - 1: 24 | next_lane = route_lst[cur_lane_idx + 1] 25 | return next_lane 26 | else: 27 | return None 28 | 29 | 30 | def post_actions( 31 | node_names: List[str], struc_dict: Dict, feat_dict: Dict, veh_route: Dict 32 | ) -> Dict: # return action: ["node_name","add_edge(veh1,on,lane1)"] 33 | struc_actions = defaultdict(list) 34 | return struc_actions 35 | 36 | 37 | 38 | # for node_name in node_names: 39 | # action = [] 40 | # min_dis = 10 41 | # node_type, id_step = node_name.split("/") 42 | # node_id = int(id_step.split("@")[0]) 43 | # if node_type == "veh" : 44 | # """ 45 | # Change lane action when approaching the end of lane. 
46 | # return: move to next lane if availiable, wait if tlc is red, remove node if reached destination 47 | # """ 48 | # min_dis = 10 # minimum distance for decision when approaching the end of lane 49 | 50 | 51 | 52 | # current_lane = get_veh_current_lane(struc_dict) 53 | # current_lane_len = abs(feat_dict["lane"][current_lane]["length"]) 54 | # pos_on_lane = abs(feat_dict["veh"][node_id]["pos_on_lane"]) 55 | # if current_lane_len - pos_on_lane < min_dis: 56 | # next_lane = get_veh_next_lane(node_id, veh_route, current_lane) 57 | # tlc_state = feat_dict["veh"][node_id]["tlc_state"] 58 | # if next_lane is None: # This vehicle has reached the destination 59 | # action.append("delete_node(veh/" + str(node_id)+ ")") 60 | # elif (next_lane is not None) and ( 61 | # tlc_state >= 0 62 | # ): # This vehicle will move to it's next route if it's upcoming tlc state is either green(1) or yellow(0) 63 | # action.append( 64 | # "add_edge(veh/" 65 | # + str(node_id) 66 | # + ",phy/to," 67 | # + "lane/" 68 | # + str(next_lane) 69 | # + ")" 70 | # ) 71 | # action.append( 72 | # "delete_edge(veh/" 73 | # + str(node_id) 74 | # + ",phy/to," 75 | # + "lane/" 76 | # + str(current_lane) 77 | # + ")" 78 | # ) 79 | # if action != []: 80 | # struc_actions.update({node_name: action}) 81 | # return struc_actions 82 | 83 | 84 | def get_feat_actions( 85 | node_names: List[str], struc_dict: Dict, feat_dict: Dict, veh_od: Dict 86 | ) -> Dict: 87 | pass 88 | -------------------------------------------------------------------------------- /transworld/rules/pre_process.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, Union, List 2 | from collections import defaultdict 3 | from graph.process import generate_unique_node_id 4 | import pandas as pd 5 | from pathlib import Path 6 | 7 | 8 | def load_veh_depart(filename, data_path: Path, training_step: int) -> Dict: 9 | node_all = pd.read_csv(data_path / "node_all.csv") 10 | node_id_dict = generate_unique_node_id(node_all) 11 | data_file = pd.read_csv(data_path / (filename + ".csv")) 12 | #print(training_step) 13 | data_file = data_file[data_file['depart'] Dict: 9 | pass 10 | 11 | 12 | def pre_actions(veh_depart, sys_time, subgraph): 13 | """ 14 | Add new vehicles to the system if it is ready to departure. 
15 | return: add node/edge action, for example "add_node(v-h/1)" and "add_edge(veh/1, phy/to, lane/1)" 16 | """ 17 | actions = {} 18 | return actions 19 | -------------------------------------------------------------------------------- /transworld/run.sh: -------------------------------------------------------------------------------- 1 | 2 | python -m transworld_exp --scenario 'traci_tls' --train_data "run1" --training_step 50 --pred_step 10 --hid_dim 50 --n_head 4 --n_layer 4 3 | #python -m transworld_exp --scenario 'bologna_clean' --train_data "run1" --training_step 400 --pred_step 10 --hid_dim 50 --n_head 4 --n_layer 4 4 | -------------------------------------------------------------------------------- /transworld/transworld_exp.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | from graph.load import load_graph 3 | from graph.process import generate_unique_node_id 4 | from rules.pre_process import load_veh_depart, pre_actions 5 | from rules.post_process import load_veh_route, post_actions 6 | import random 7 | import torch.nn as nn 8 | from game.model import HGT, RuleBasedGenerator, GraphStateLoss 9 | from game.data import DataLoader, Dataset 10 | from game.graph import Graph 11 | import matplotlib.pyplot as plt 12 | import torch 13 | from game.operator.transform import dgl_graph_to_graph_dict 14 | from tqdm import tqdm 15 | from datetime import datetime 16 | import pickle 17 | import os 18 | import csv 19 | import argparse 20 | import logging 21 | import shutil 22 | import sys 23 | import pandas as pd 24 | from collections import defaultdict 25 | random.seed(3407) 26 | 27 | def print_loss(epoch, i, loss): 28 | with open('loss.csv','a') as loss_log: 29 | train_writer = csv.writer(loss_log) 30 | train_writer.writerow([str(epoch), str(i), str(round(loss,4))]) 31 | 32 | def train(timestamps, graph, batch_size, num_workers, encoder, generator, veh_route, loss_fcn, optimizer, logger, device): 33 | logger.info("========= start generate dataset =======") 34 | train_dataset = Dataset(timestamps, device, train_mode=True) 35 | train_loader = DataLoader(train_dataset, graph.operate, batch_size=batch_size, num_workers=num_workers, drop_last =True) 36 | logger.info("========== finish generate dataset =======") 37 | # graph_dicts = {} 38 | logger.info("========= start training =======") 39 | loss_list = [] 40 | for i, (cur_graphs, next_graphs) in enumerate(train_loader): 41 | loss = 0. 42 | for ((_, cur_graph), (seed_node_n_time, next_graph)) in zip(cur_graphs.items(), next_graphs.items()): 43 | cur_graph, next_graph = cur_graph.to(device), next_graph.to(device) 44 | assert (_.split("@")[0]) == seed_node_n_time.split("@")[0], ValueError("Dataloader Error! 
node_name not equal") 45 | node_type = seed_node_n_time.split("/")[0] 46 | # if node_type == 'veh': 47 | # continue 48 | time = float(seed_node_n_time.split("@")[1]) 49 | node_repr = encoder(cur_graph) 50 | actions, pred_graph = generator([seed_node_n_time], cur_graph, node_repr, veh_route) 51 | loss = loss + loss_fcn(pred_graph, next_graph) 52 | loss_list.append((loss.item()) / batch_size) 53 | #print_loss(i, loss.item()) 54 | optimizer.zero_grad() 55 | loss.backward() 56 | optimizer.step() 57 | logger.info(f"------------ loss is {sum(loss_list) / len(loss_list)} ---------") 58 | logger.info("========== finished training ==========") 59 | torch.cuda.empty_cache() 60 | return loss_list 61 | 62 | @torch.no_grad() 63 | def eval(graph, batch_size, num_workers, encoder, generator, veh_depart, veh_route, changable_feature_names, hetero_feat_dim, logger, device, training_step, perd_step): 64 | val_timestamp = [float(i) for i in range(int(training_step),int(training_step)+perd_step+1)] 65 | #print(val_timestamp) 66 | val_dataset = Dataset(val_timestamp, device, train_mode=False) 67 | val_loader = DataLoader(val_dataset, graph.operate, batch_size=batch_size, num_workers=num_workers, drop_last =False) 68 | logger.info(f"========= start eval ======= batch_size{batch_size}=======") 69 | for i, cur_graphs in tqdm(enumerate(val_loader)): 70 | val_result = {} 71 | for seed_node_n_time, cur_graph in cur_graphs.items(): 72 | cur_graph = cur_graph.to(device) 73 | node_type = seed_node_n_time.split("/")[0] 74 | if node_type == 'veh': 75 | continue 76 | time = float(seed_node_n_time.split("@")[1]) 77 | actions_pre = pre_actions(veh_depart, time, cur_graph) 78 | node_repr = encoder(cur_graph) 79 | actions, pred_graph = generator( 80 | [seed_node_n_time], cur_graph, node_repr, veh_route 81 | ) 82 | graph.states_to_feature_event(time, changable_feature_names, cur_graph, pred_graph) 83 | graph.actions_to_game_operations(actions_pre) 84 | if actions != {}: 85 | graph.actions_to_game_operations(actions) 86 | print(actions) 87 | 88 | total_graphs = val_loader.collate_tool([max(val_timestamp)+1.], max_step=perd_step) 89 | for seed_node_n_time, total_graph in total_graphs[0].items(): 90 | struc_dict, feat_dict = dgl_graph_to_graph_dict(total_graph, hetero_feat_dim) 91 | val_result[seed_node_n_time] = (struc_dict, feat_dict) 92 | logger.info("========== finished eval ==========") 93 | return val_result 94 | 95 | 96 | def create_folder(folder_path, delete_origin=False): 97 | if not os.path.exists(folder_path): 98 | # shutil.rmtree(folder_path) 99 | os.makedirs(folder_path, exist_ok=True) 100 | else: 101 | if delete_origin: 102 | shutil.rmtree(folder_path) 103 | os.makedirs(folder_path, exist_ok=True) 104 | 105 | def setup_logger(name, log_folder_path, level=logging.DEBUG): 106 | create_folder(log_folder_path) 107 | log_file = log_folder_path /f"{name}_log" 108 | handler = logging.FileHandler(log_file,encoding="utf-8",mode="a") 109 | formatter = logging.Formatter("%(asctime)s,%(msecs)d,%(levelname)s,%(name)s::%(message)s") 110 | handler.setFormatter(formatter) 111 | logger = logging.getLogger(name) 112 | logger.setLevel(level) 113 | logger.addHandler(handler) 114 | stream_handler = logging.StreamHandler(sys.stdout) 115 | stream_handler.setFormatter(formatter) 116 | logger.addHandler(stream_handler) 117 | return logger 118 | 119 | def run(scenario, test_data, training_step, pred_step, hid_dim, n_heads, n_layer, device): 120 | time_diff = [] 121 | 122 | exp_dir = Path(__file__).parent.parent / "experiment" 123 | 124 | 
exp_setting = exp_dir / scenario 125 | data_dir = exp_setting / "data" / test_data 126 | train_data_dir = data_dir / "train_data" 127 | test_data_dir = data_dir / "test_data" 128 | out_dir = data_dir / f"out_dim_{hid_dim}_n_heads_{n_heads}_n_layer_{n_layer}_pred_step_{pred_step}" 129 | name = f"scenario_{scenario}test_data_{test_data}_dim_{hid_dim}_n_heads_{n_heads}_n_layer_{n_layer}" 130 | log_folder_path = out_dir / "Log" 131 | logger = setup_logger(name, log_folder_path) 132 | logger.info(f"========== process {scenario}_{test_data}_{hid_dim}_{n_heads}_{n_layer}_pred_step_{pred_step} is running! ===========" ) 133 | isExist = os.path.exists(out_dir) 134 | if not isExist: 135 | os.makedirs(out_dir) 136 | 137 | node_all = pd.read_csv(train_data_dir / "node_all.csv") 138 | node_id_dict = generate_unique_node_id(node_all) 139 | 140 | veh_depart = load_veh_depart("veh_depart", train_data_dir, training_step) 141 | veh_route = load_veh_route("veh_route", train_data_dir) 142 | logger.info(f"========== finish load route and depart ========") 143 | # init struc_dict, feat_dict, node_id_dict 144 | 145 | struc_dict, feat_dict, node_id_dict, scalers = load_graph(train_data_dir, 0, training_step-1, node_id_dict) 146 | #test_struc, test_feat, node_id_dict, scalers = load_graph(test_data_dir) 147 | logger.info(f"========= finish load graph =========") 148 | #model parameters 149 | n_epochs = 1 #200 150 | batch_size = max(4,training_step//10 - 10) #100 151 | num_workers = 1 152 | batch_size = max(1, batch_size * num_workers) 153 | lr = 5e-4 154 | hid_dim = hid_dim 155 | n_heads = n_heads 156 | changable_feature_names = ['speed','pos_on_lane','occupancy','acceleration'] 157 | graph = Graph(struc_dict, feat_dict) 158 | hetero_feat_dim = graph.hetero_feat_dim 159 | timestamps = graph.timestamps.float().tolist() 160 | 161 | logger.info(f"========= {n_epochs}_{batch_size}_{num_workers} =========") 162 | 163 | encoder = HGT( 164 | in_dim={ 165 | ntype: int(sum(hetero_feat_dim[ntype].values())) 166 | for ntype in hetero_feat_dim.keys() 167 | }, 168 | n_ntypes=graph.num_ntypes, 169 | n_etypes=graph.num_etypes, 170 | hid_dim=hid_dim, 171 | n_layers=n_layer, 172 | n_heads=n_heads, 173 | activation = nn.ReLU() 174 | ).to(device) 175 | 176 | 177 | generator = RuleBasedGenerator( 178 | hetero_feat_dim, 179 | n_heads * hid_dim, 180 | { 181 | ntype: int(sum(hetero_feat_dim[ntype].values())) 182 | for ntype in hetero_feat_dim.keys() 183 | }, 184 | activation = nn.ReLU(), 185 | scalers= scalers, 186 | output_activation = nn.Sigmoid() 187 | ).to(device) 188 | 189 | logger.info("========== finish generate generator rule ==========") 190 | generator.register_rule(post_actions) 191 | 192 | loss_fcn = GraphStateLoss().to(device) 193 | 194 | optimizer = torch.optim.Adam(list(encoder.parameters())+list(generator.parameters()), lr=lr) 195 | 196 | 197 | 198 | loss_avg = [] 199 | for ep in tqdm(range(n_epochs)): 200 | logger.info(f"--------- current ep is {ep} --------") 201 | loss_lst = train(timestamps, graph, batch_size, num_workers, encoder, generator, veh_route, loss_fcn, optimizer, logger, device) 202 | #loss_dict[f'train_{ep}_loss'] = loss_lst 203 | loss_avg.append(sum(loss_lst) / len(loss_lst)) 204 | 205 | 206 | loss_df = pd.DataFrame(loss_avg) 207 | # loss_df = pd.DataFrame.from_dict(dict(loss_dict)) 208 | loss_df.to_csv(out_dir / 'loss.csv', index=False) 209 | 210 | torch.save(encoder.state_dict(), out_dir / 'encorder.pth') 211 | torch.save(generator.state_dict(), out_dir / 'generator.pth') 212 | 213 | before = 
datetime.now() 214 | 215 | for i in range(10): 216 | logger.info(f"--------- current is {0+pred_step*(i+1), training_step+pred_step*(i+1)} --------") 217 | sim_graph = eval(graph, batch_size//num_workers, num_workers, encoder, generator, veh_depart, veh_route, changable_feature_names, hetero_feat_dim, logger, device, training_step+pred_step*(i+1), pred_step) 218 | #print(sim_graph['veh/0@198.0'][0][('veh','phy/to','lane')][2]) 219 | with open(out_dir / f"predicted_graph_{scenario}_{test_data}_{n_layer}_{n_heads}_{hid_dim}_{i}.p", "wb") as f: 220 | pickle.dump(sim_graph, f) 221 | graph.reset() 222 | #struc_dict, feat_dict, node_id_dict, scalers = load_graph(train_data_dir, 0+pred_step*(i+1), training_step+pred_step*(i+1), node_id_dict) 223 | struc_dict, feat_dict, node_id_dict, scalers = load_graph(train_data_dir, training_step, training_step+pred_step*(i+1), node_id_dict) 224 | veh_depart = load_veh_depart("veh_depart", train_data_dir, training_step+pred_step*(i+1)) 225 | #logger.info(f"--------- current is {0+pred_step*(i+1), training_step+pred_step*(i+1)} --------") 226 | graph = Graph(struc_dict, feat_dict) 227 | #print(0+pred_step*(i+1), training_step+pred_step*(i+1)) 228 | 229 | 230 | after = datetime.now() 231 | time_diff.append((after - before).total_seconds()) 232 | 233 | logger.info(f"========== Eval time_diff is : {(after - before).total_seconds()} ==========") 234 | logger.info("========== Exp has finished! ==========") 235 | 236 | 237 | if __name__ =="__main__": 238 | parser = argparse.ArgumentParser() 239 | parser.add_argument("--scenario", type=str, default='traci_tls') 240 | parser.add_argument("--train_data", type=str, default='run1') 241 | parser.add_argument("--training_step", type=int, default=80) 242 | parser.add_argument("--pred_step", type=int, default=10) 243 | parser.add_argument("--hid_dim", type=int, default=100) 244 | parser.add_argument("--n_head", type=int, default=4) 245 | parser.add_argument("--n_layer", type=int, default=4) 246 | parser.add_argument("--gpu", type=int, default=0) 247 | args = parser.parse_args() 248 | if (not torch.cuda.is_available()) or (args.gpu == -1): 249 | device = torch.device("cpu") 250 | else: 251 | device = torch.device("cuda",args.gpu) 252 | run(args.scenario,args.train_data, args.training_step, args.pred_step, args.hid_dim, args.n_head, args.n_layer, device) --------------------------------------------------------------------------------