├── .github ├── PULL_REQUEST_TEMPLATE.md └── workflows │ └── ubuntu-ci.yml ├── .gitignore ├── CODE_OF_CONDUCT.md ├── CONTRIBUTING.md ├── LICENSE ├── Makefile ├── NOTICE ├── README.md ├── apps ├── README.md └── model_tune │ └── model_tune.py ├── codecov.yml ├── configs └── samples │ ├── gcv_models.yaml │ ├── mutate.yaml │ ├── rpc_client_mobile.yaml │ ├── rpc_client_server.yaml │ ├── toy_workloads.yaml │ ├── tune_batch.yaml │ ├── tune_local.yaml │ └── tune_rpc_master.yaml ├── docker ├── Dockerfile.base ├── Dockerfile.ci ├── Dockerfile.lorien ├── Dockerfile.tvm ├── bash.sh ├── build_ci.sh ├── build_lorien.sh ├── entrypoint ├── publish.sh ├── publish_docker_hub.sh ├── req_dev.txt └── requirements.txt ├── docs ├── .gitignore ├── Makefile ├── README.md └── source │ ├── _static │ └── img │ │ └── README │ ├── api │ ├── configs.rst │ ├── database │ │ ├── index.rst │ │ ├── table.rst │ │ └── util.rst │ ├── dialect │ │ ├── index.rst │ │ └── tvm_dial │ │ │ ├── auto_scheduler_dial │ │ │ ├── extract.rst │ │ │ ├── index.rst │ │ │ ├── job.rst │ │ │ ├── result.rst │ │ │ └── workload.rst │ │ │ ├── autotvm_dial │ │ │ ├── extract_from_model.rst │ │ │ ├── extract_from_record.rst │ │ │ ├── index.rst │ │ │ ├── job.rst │ │ │ ├── result.rst │ │ │ ├── util.rst │ │ │ └── workload.rst │ │ │ ├── frontend_parser.rst │ │ │ ├── index.rst │ │ │ ├── job.rst │ │ │ ├── result.rst │ │ │ └── util.rst │ ├── generate.rst │ ├── index.rst │ ├── logger.rst │ ├── main.rst │ ├── tune │ │ ├── index.rst │ │ ├── job.rst │ │ ├── manager.rst │ │ ├── master.rst │ │ ├── result.rst │ │ └── rpc │ │ │ ├── client.rst │ │ │ ├── index.rst │ │ │ ├── launch.rst │ │ │ └── server.rst │ ├── util.rst │ └── workload.rst │ ├── conf.py │ ├── contribute │ ├── index.rst │ └── pull_request.rst │ ├── genindex.rst │ ├── index.rst │ ├── setup │ ├── index.rst │ └── on_docker.rst │ └── tutorials │ ├── dialects.rst │ ├── extract_feature.rst │ ├── index.rst │ ├── train_model.rst │ ├── tune_on_aws_batch.rst │ ├── tune_on_local.rst │ └── 
tune_on_rpc.rst ├── lorien ├── __init__.py ├── __main__.py ├── configs.py ├── database │ ├── __init__.py │ ├── table.py │ └── util.py ├── dialect │ ├── __init__.py │ └── tvm_dial │ │ ├── __init__.py │ │ ├── auto_scheduler_dial │ │ ├── __init__.py │ │ ├── extract.py │ │ ├── job.py │ │ ├── result.py │ │ └── workload.py │ │ ├── autotvm_dial │ │ ├── __init__.py │ │ ├── extract_from_model.py │ │ ├── extract_from_record.py │ │ ├── job.py │ │ ├── result.py │ │ ├── util.py │ │ └── workload.py │ │ ├── frontend_parser.py │ │ ├── job.py │ │ ├── result.py │ │ └── util.py ├── generate.py ├── logger.py ├── main.py ├── tune │ ├── __init__.py │ ├── job.py │ ├── job_manager.py │ ├── master.py │ ├── result.py │ └── rpc │ │ ├── __init__.py │ │ ├── client.py │ │ ├── launch.py │ │ └── server.py ├── util.py └── workload.py ├── requirements.txt ├── scripts ├── aws │ ├── aws_batch_env.json │ ├── create_launch_template.sh │ └── gen_launch_template.py └── python │ ├── download_db.py │ ├── merge_workload.py │ └── sort_workloads.py ├── setup.cfg ├── setup.py └── tests ├── __init__.py ├── common.py ├── dialect ├── __init__.py └── tvm_dial │ ├── __init__.py │ ├── auto_scheduler_dial │ ├── __init__.py │ └── test_dialect.py │ ├── autotvm_dial │ ├── __init__.py │ └── test_dialect.py │ ├── test_parser.py │ └── test_util.py ├── lint ├── coveragerc ├── pylintrc └── yapf_style.cfg ├── test_config.py ├── test_database.py ├── test_generate.py ├── test_job_manager.py ├── test_result.py ├── test_tune_master.py └── test_util.py /.github/PULL_REQUEST_TEMPLATE.md: -------------------------------------------------------------------------------- 1 | Thanks for contributing to Lorien! 2 | By submitting this pull request, I confirm that you can use, modify, copy, and redistribute this contribution, under the terms of your choice. 
3 | 4 | ### Change Description 5 | 6 | 7 | -------------------------------------------------------------------------------- /.github/workflows/ubuntu-ci.yml: -------------------------------------------------------------------------------- 1 | name: Build 2 | 3 | on: 4 | push: 5 | branches: 6 | - main 7 | pull_request: 8 | branches: 9 | - main 10 | 11 | jobs: 12 | build: 13 | 14 | runs-on: ubuntu-18.04 15 | container: comaniac0422/lorien:ubuntu-18.04-v0.06 16 | steps: 17 | - name: Checkout PR 18 | uses: actions/checkout@v1 19 | - name: Check format with black 20 | run: make check_format 21 | - name: Lint with pylint 22 | run: make lint 23 | - name: Check type with mypy 24 | run: make type 25 | - name: Unit test and coverage report with pytest 26 | run: python3 -m pytest tests --cov-config=tests/lint/coveragerc --cov=lorien --cov-report "xml:cov.xml" 27 | - name: Upload coverage report to codecov 28 | uses: codecov/codecov-action@v1 29 | with: 30 | token: ${{ secrets.CODECOV_TOKEN }} 31 | file: ./cov.xml 32 | fail_ci_if_error: false 33 | - name: Build docs 34 | run: | 35 | sudo DEBIAN_FRONTEND=noninteractive apt-get update 36 | sudo DEBIAN_FRONTEND=noninteractive apt-get install -y python3-sphinx rsync 37 | make doc 38 | - name: Deploy docs 39 | uses: JamesIves/github-pages-deploy-action@releases/v3 40 | if: github.event_name == 'push' 41 | with: 42 | ACCESS_TOKEN: ${{ secrets.DEPLOY_ACCESS_TOKEN }} 43 | BRANCH: gh-pages 44 | FOLDER: docs/build/html 45 | 46 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | env/ 12 | build/ 13 | develop-eggs/ 14 | dist/ 15 | downloads/ 16 | eggs/ 17 | .eggs/ 18 | lib/ 19 | lib64/ 20 | parts/ 21 | sdist/ 22 | var/ 23 | *.egg-info/ 24 | 
.installed.cfg 25 | *.egg 26 | id_rsa 27 | cache_tvm_timestamp 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | .mypy* 41 | htmlcov/ 42 | .tox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | .pytest* 47 | nosetests.xml 48 | cov*.xml 49 | *,cover 50 | .hypothesis/ 51 | *.db 52 | local_db 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | 62 | # Flask stuff: 63 | instance/ 64 | .webassets-cache 65 | 66 | # Scrapy stuff: 67 | .scrapy 68 | 69 | # Sphinx documentation 70 | docs/_build/ 71 | 72 | # PyBuilder 73 | target/ 74 | 75 | # IPython Notebook 76 | .ipynb_checkpoints 77 | 78 | # pyenv 79 | .python-version 80 | 81 | # celery beat schedule file 82 | celerybeat-schedule 83 | 84 | # dotenv 85 | .env 86 | 87 | # virtualenv 88 | venv/ 89 | ENV/ 90 | 91 | # Spyder project settings 92 | .spyderproject 93 | 94 | # Rope project settings 95 | .ropeproject 96 | 97 | # IDE stuff 98 | .vscode 99 | 100 | # All customized configs 101 | configs/* 102 | !configs/samples 103 | -------------------------------------------------------------------------------- /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | ## Code of Conduct 2 | This project has adopted the [Amazon Open Source Code of Conduct](https://aws.github.io/code-of-conduct). 3 | For more information see the [Code of Conduct FAQ](https://aws.github.io/code-of-conduct-faq) or contact 4 | opensource-codeofconduct@amazon.com with any additional questions or comments. 
5 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing Guidelines 2 | 3 | Thank you for your interest in contributing to our project. Whether it's a bug report, new feature, correction, or additional 4 | documentation, we greatly value feedback and contributions from our community. 5 | 6 | Please read through this document before submitting any issues or pull requests to ensure we have all the necessary 7 | information to effectively respond to your bug report or contribution. 8 | 9 | 10 | ## Reporting Bugs/Feature Requests 11 | 12 | We welcome you to use the GitHub issue tracker to report bugs or suggest features. 13 | 14 | When filing an issue, please check existing open, or recently closed, issues to make sure somebody else hasn't already 15 | reported the issue. Please try to include as much information as you can. Details like these are incredibly useful: 16 | 17 | * A reproducible test case or series of steps 18 | * The version of our code being used 19 | * Any modifications you've made relevant to the bug 20 | * Anything unusual about your environment or deployment 21 | 22 | 23 | ## Contributing via Pull Requests 24 | Contributions via pull requests are much appreciated. Before sending us a pull request, please ensure that: 25 | 26 | 1. You are working against the latest source on the *main* branch. 27 | 2. You check existing open, and recently merged, pull requests to make sure someone else hasn't addressed the problem already. 28 | 3. You open an issue to discuss any significant work - we would hate for your time to be wasted. 29 | 30 | To send us a pull request, please: 31 | 32 | 1. Fork the repository. 33 | 2. Modify the source; please focus on the specific change you are contributing. If you also reformat all the code, it will be hard for us to focus on your change. 34 | 3. Ensure local tests pass. 35 | 4. 
Commit to your fork using clear commit messages. 36 | 5. Send us a pull request, answering any default questions in the pull request interface. 37 | 6. Pay attention to any automated CI failures reported in the pull request, and stay involved in the conversation. 38 | 39 | GitHub provides additional documentation on [forking a repository](https://help.github.com/articles/fork-a-repo/) and 40 | [creating a pull request](https://help.github.com/articles/creating-a-pull-request/). 41 | 42 | 43 | ## Finding contributions to work on 44 | Looking at the existing issues is a great way to find something to contribute to. As our projects, by default, use the default GitHub issue labels (enhancement/bug/duplicate/help wanted/invalid/question/wontfix), looking at any 'help wanted' issues is a great place to start. 45 | 46 | 47 | ## Code of Conduct 48 | This project has adopted the [Amazon Open Source Code of Conduct](https://aws.github.io/code-of-conduct). 49 | For more information see the [Code of Conduct FAQ](https://aws.github.io/code-of-conduct-faq) or contact 50 | opensource-codeofconduct@amazon.com with any additional questions or comments. 51 | 52 | 53 | ## Security issue notifications 54 | If you discover a potential security issue in this project we ask that you notify AWS/Amazon Security via our [vulnerability reporting page](http://aws.amazon.com/security/vulnerability-reporting/). Please do **not** create a public GitHub issue. 55 | 56 | 57 | ## Licensing 58 | 59 | See the [LICENSE](LICENSE) file for our project's licensing. We will ask you to confirm the licensing of your contribution.
60 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | PRJ_NAME=lorien 2 | PORT=18871 3 | 4 | env: 5 | virtualenv -p python3 venv --system-site-packages 6 | venv/bin/pip3 install -r requirements.txt 7 | 8 | lint: 9 | python3 -m pylint ${PRJ_NAME} --rcfile=tests/lint/pylintrc 10 | 11 | type: 12 | # intall-types is the new feature since mypy 0.900 that installs missing stubs. 13 | python3 -m mypy ${PRJ_NAME} --ignore-missing-imports --install-types --non-interactive 14 | 15 | format: 16 | black -l 100 `git diff --name-only --diff-filter=ACMRTUX origin/master -- "*.py" "*.pyi"` 17 | 18 | check_format: 19 | black -l 100 --check `git diff --name-only --diff-filter=ACMRTUX origin/master -- "*.py" "*.pyi"` 20 | 21 | local_db: 22 | mkdir $@ 23 | cd $@; curl -O https://s3-us-west-2.amazonaws.com/dynamodb-local/dynamodb_local_latest.zip 24 | cd $@; unzip dynamodb_local_latest.zip; rm dynamodb_local_latest.zip 25 | 26 | launch_local_db: local_db 27 | java -Djava.library.path=./local_db/DynamoDBLocal_lib \ 28 | -jar ./local_db/DynamoDBLocal.jar \ 29 | -sharedDb -port 10020 30 | 31 | launch_rpc_server: 32 | # OBJC_DISABLE_INITIALIZE_FORK_SAFETY is a workaround to a MacOS 10.13 (High Sierra) issue with Python. 33 | # See http://sealiesoftware.com/blog/archive/2017/6/5/Objective-C_and_fork_in_macOS_1013.html. 34 | OBJC_DISABLE_INITIALIZE_FORK_SAFETY=YES python3 -m lorien rpc-server --port ${PORT} 35 | 36 | unit_test: 37 | python3 -m pytest --lf 38 | 39 | cov: 40 | python3 -m pytest tests --cov-config=tests/lint/coveragerc --cov=${PRJ_NAME} --cov-report term 41 | 42 | doc: 43 | make -C docs html 44 | 45 | clean: 46 | rm -rf .coverage* *.xml *.log *.pyc *.egg-info tests/temp* test_* tests/*.pdf curr *.db 47 | find . -name "__pycache__" -type d -exec rm -r {} + 48 | find . -name ".pytest_cache" -type d -exec rm -r {} + 49 | find . 
-name ".pkl_memoize_py3" -type d -exec rm -r {} + 50 | 51 | -------------------------------------------------------------------------------- /NOTICE: -------------------------------------------------------------------------------- 1 | Lorien: Efficient deep learning workload delivery. 2 | Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 3 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | Lorien: A Unified Infrastructure for Efficient Deep Learning Workloads Delivery 2 | =============================================================================== 3 | [![Build Status](https://github.com/awslabs/lorien/actions/workflows/ubuntu-ci.yml/badge.svg?branch=main)](https://github.com/awslabs/lorien/actions/workflows/ubuntu-ci.yml) 4 | [![codecov.io](https://codecov.io/gh/awslabs/lorien/branch/main/graph/badge.svg?token=78Q29GBRHW)](https://codecov.io/gh/awslabs/lorien) 5 | 6 | Lorien is an infrastructure to massively explore/benchmark the best schedules of given deep learning models. 7 | Lorien is deep learning compiler (DLC) agnostic, so one can easily implement a Lorien dialect to support 8 | a new DLC. 9 | 10 | ## Motivation 11 | 12 | Although auto-tuning frameworks for deep learning compilers (e.g., TVM, Halide) are capable of 13 | delivering high-performance operators that match or even beat vendor kernel libraries, auto-tuning 14 | a deep learning model could take days or even weeks, especially for models with many workloads 15 | like ResNet-152 or Inception V3. 16 | 17 | With such a long tuning time, one key question to maintain the best user experience during deep model 18 | development and deployment is *How to promptly deliver schedules with reasonably good performance upon user requests?* 19 | Accordingly, we design and implement Lorien to remove the following obstacles: 20 | 21 | 1.
*Tuning Process Scalability and Stability.* Long tuning time affects not only the time-to-market but also the stability. 22 | To the best of our knowledge, none of the existing auto-tuning frameworks is designed for tuning on multiple machines, 23 | and none of them considers fault tolerance. The tuning process, hence, has to be manually started over if it was 24 | accidentally interrupted. This is especially crucial on edge devices, which are less reliable than cloud instances 25 | and may fail frequently due to overheating or other factors. 26 | 27 | 2. *Tuning Result Management.* Although almost all auto-tuning frameworks provide mechanisms to serialize tuning 28 | results for future applications, all of them use file-based mechanisms with different formats. As a result, 29 | engineers have additional work to orchestrate the data for efficient usage. 30 | 31 | 3. *Time to Deliver an Efficient Schedule.* Even if a database is constructed to serve most user requests, 32 | it is still possible that certain workloads are missing. However, modern auto-tuning frameworks usually 33 | leverage iterative search algorithms with on-device measurements, which usually take hours, 34 | to find an efficient schedule for an unseen workload. The unfavorably expensive querying/tuning overhead 35 | makes production deployment impractical. 36 | 37 | Lorien is a unified and extensible infrastructure for delivering efficient deep learning workloads upon request. 38 | Lorien allows auto-tuning deep learning frameworks to be easily plugged in as dialects, and supports large scale 39 | tuning on both cloud and edge platforms. The tuning results are managed in a NoSQL database with a unified data model 40 | that fits all auto-tuning frameworks.
41 | While the best schedules managed in the database can be used to compile deep learning models to achieve high performance, 42 | the tuning logs managed in a file system can also 1) enable more comprehensive performance analysis on different platforms, 43 | and 2) help train a performance cost model with an AutoML solution. 44 | 45 | Please visit the [official documentation](https://awslabs.github.io/lorien) for setup guidelines and tutorials. 46 | 47 | ## System Requirements 48 | 49 | * Python 3.6+ 50 | 51 | * **Amazon DynamoDB (local or AWS)**: DynamoDB is used for storing and maintaining the tuned schedules. 52 | You can choose either of the following: 53 | 54 | 1. Launch a [local version](https://s3-us-west-2.amazonaws.com/dynamodb-local/dynamodb_local_latest.zip) using JVM on your machine, and specify endpoint URL (e.g. `--db "endpoint_url: http://<host>:8000"`) when invoking a tuning process. 55 | 56 | 2. Configure AWS credentials on your machine to directly use AWS DynamoDB service. In this case, you do not have to specify any argument in tuning configurations. 57 | 58 | * **AWS S3 (optional)**: S3 is used to store the full tuning logs (JSON files generated by AutoTVM). If you specify `--commit-log-to bucket_name` and configure an AWS credential on your machine, then all complete tuning logs will be uploaded to the S3 bucket for debugging or research purposes. Note that this is an optional requirement, so you can ignore the `--commit-log-to` argument if you do not want to keep full tuning logs. 59 | 60 | * **AWS Batch (AWS ECR)**: You have to set up AWS batch computation environments, job queues, and job definitions in advance to use Lorien AWS batch worker for tuning. See [this blog post](https://fredhutch.github.io/aws-batch-at-hutch-docs/) for reference. You may also need to build and upload Lorien docker images to AWS ECR to serve as the container that runs AWS Batch jobs.
61 | 62 | ## Docker Images 63 | 64 | You can directly make use of pre-built Lorien docker images on [Docker Hub](https://hub.docker.com/repository/docker/comaniac0422/lorien/tags), which include two types of images for CPU and CPU+CUDA platforms. The docker images have TVM deployed so you can launch a tuning process in the container after cloning Lorien. The docker image is also used for Lorien CI purposes. 65 | 66 | ## Documentation 67 | 68 | [https://awslabs.github.io/lorien/](https://awslabs.github.io/lorien/) 69 | 70 | ### Citing Lorien 71 | 72 | If you use Lorien in a scientific publication, please cite the following paper: 73 | 74 | Cody Hao Yu, Xingjian Shi, Haichen Shen, Zhi Chen, Mu Li, Yida Wang, "Lorien: Efficient Deep Learning Workloads Delivery", Proceedings of the 12th ACM Symposium on Cloud Computing. 2021. 75 | 76 | ``` 77 | @inproceedings{yu2021lorien, 78 | title={Lorien: Efficient Deep Learning Workloads Delivery}, 79 | author={Yu, Cody Hao and Shi, Xingjian and Shen, Haichen and Chen, Zhi and Li, Mu and Wang, Yida}, 80 | booktitle={Proceedings of the 12th ACM Symposium on Cloud Computing}, 81 | year={2021} 82 | } 83 | ``` 84 | 85 | -------------------------------------------------------------------------------- /apps/README.md: -------------------------------------------------------------------------------- 1 | # Applications 2 | This folder contains applications that use the database.
-------------------------------------------------------------------------------- /codecov.yml: -------------------------------------------------------------------------------- 1 | codecov: 2 | require_ci_to_pass: yes 3 | 4 | coverage: 5 | precision: 2 6 | round: down 7 | range: "50...100" 8 | status: 9 | project: 10 | default: 11 | threshold: 0.5% 12 | patch: off 13 | 14 | parsers: 15 | gcov: 16 | branch_detection: 17 | conditional: yes 18 | loop: yes 19 | method: no 20 | macro: no 21 | 22 | comment: 23 | layout: "reach,diff,flags,tree" 24 | behavior: default 25 | require_changes: no 26 | -------------------------------------------------------------------------------- /configs/samples/gcv_models.yaml: -------------------------------------------------------------------------------- 1 | # Extract workloads from Gluon CV modelzoo. 2 | gcv: 3 | - alexnet 4 | - yolo3_darknet53_voc: 5 | data: [1, 3, 320, 320] 6 | output: 7 | - gcv_workloads.yaml 8 | -------------------------------------------------------------------------------- /configs/samples/mutate.yaml: -------------------------------------------------------------------------------- 1 | # Workload mutation rules 2 | rules: 3 | - task: 4 | - conv2d_NCHWc.x86 5 | - depthwise_conv2d_NCHWc.x86 6 | - conv2d_nchw.cuda 7 | - depthwise_conv2d_nchw.cuda 8 | - conv2d_nchw_winograd.cuda 9 | - dense_nopack.x86 10 | - dense_pack.x86 11 | desc: 12 | "[0, 1, 0]": "[1, 3, 4, 7, 8, 12, 16]" # Batch size 13 | "[0, 1, 1]": "[v, v * 2, v * 4]" # Channel size 14 | -------------------------------------------------------------------------------- /configs/samples/rpc_client_mobile.yaml: -------------------------------------------------------------------------------- 1 | server: 'localhost:18871' 2 | target: 'llvm -mcpu=core-avx2' 3 | device: 'my-device' 4 | runner-port: '9190' 5 | 6 | -------------------------------------------------------------------------------- /configs/samples/rpc_client_server.yaml: 
-------------------------------------------------------------------------------- 1 | server: 'localhost:18871' 2 | target: 'llvm -mcpu=skylake-avx512' 3 | 4 | -------------------------------------------------------------------------------- /configs/samples/toy_workloads.yaml: -------------------------------------------------------------------------------- 1 | workload: 2 | - '!!python/object:lorien.workload.Workload {args: [[TENSOR, [1, 64, 73, 73], float32], [TENSOR, [80, 64, 1, 1], float32], [1, 1], [0, 0, 0, 0], [1, 1], NCHW, NCHW, float32], lib: topi, target: llvm -keys=cpu, task_name: conv2d_NCHWc.x86}' 3 | - '!!python/object:lorien.workload.Workload {args: [[TENSOR, [1, 128, 56, 56], float32], [TENSOR, [128, 128, 1, 1], float32], [1, 1], [0, 0, 0, 0], [1, 1], NCHW, NCHW, float32], lib: topi, target: llvm -keys=cpu, task_name: conv2d_NCHWc.x86}' 4 | - '!!python/object:lorien.workload.Workload {args: [[TENSOR, [1, 256, 35, 35], float32], [TENSOR, [48, 256, 1, 1], float32], [1, 1], [0, 0, 0, 0], [1, 1], NCHW, NCHW, float32], lib: topi, target: llvm -keys=cpu, task_name: conv2d_NCHWc.x86}' 5 | -------------------------------------------------------------------------------- /configs/samples/tune_batch.yaml: -------------------------------------------------------------------------------- 1 | # Tuning options. 2 | batch: 3 | target: llvm -mcpu=skylake-avx512 -libs=cblas 4 | job_queue: lorien-c5-job-queue 5 | job_def: lorien-job-cpu:1 6 | job_bucket: saved-tuning-logs/job-queue 7 | tuner: random 8 | ntrial: 16 9 | 10 | # We enable clflush for x86 targets so we can have fewer tests. 11 | test: 1 12 | repeat: 10 13 | min: 1 14 | 15 | # Result committing options. 16 | commit-nbest: 20 17 | commit-table-name: lorien 18 | # Uncomment this line if you have configured AWS CLI and S3 bucket. 
19 | # commit-log-to: saved-tuning-logs/lorien-x86-skylake 20 | -------------------------------------------------------------------------------- /configs/samples/tune_local.yaml: -------------------------------------------------------------------------------- 1 | # Tuning options. 2 | local: llvm -mcpu=core-avx2 3 | db: 4 | - endpoint_url: http://localhost:10020 5 | tuner: random 6 | ntrial: 15 7 | 8 | # We enable clflush for x86 targets so we can have fewer tests. 9 | test: 1 10 | repeat: 10 11 | min: 1 12 | 13 | # Result committing options. 14 | commit-nbest: 20 15 | commit-table-name: lorien 16 | # Uncomment this line if you have configured AWS CLI and S3 bucket. 17 | #commit-log-to: saved-tuning-logs 18 | -------------------------------------------------------------------------------- /configs/samples/tune_rpc_master.yaml: -------------------------------------------------------------------------------- 1 | # Tuning options. 2 | rpc: 3 | target: llvm -mcpu=skylake-avx512 4 | port: 18871 5 | db: 6 | - endpoint_url: http://localhost:10020 7 | tuner: random 8 | ntrial: 16 9 | 10 | # We enable clflush for x86 targets so we can have fewer tests. 11 | test: 1 12 | repeat: 10 13 | min: 1 14 | 15 | # Result committing options. 16 | commit-nbest: 20 17 | commit-table-name: lorien 18 | # Uncomment this line if you have configured AWS CLI and S3 bucket. 
19 | #commit-log-to: saved-tuning-logs 20 | -------------------------------------------------------------------------------- /docker/Dockerfile.base: -------------------------------------------------------------------------------- 1 | ARG platform=x86 2 | 3 | # Base env setup for ARM/Intel CPUs 4 | FROM ubuntu:18.04 as cpu-base 5 | RUN echo "Creating environment for CPU" 6 | RUN apt-get update \ 7 | && apt-get install -y python3-pip python3-dev git curl sudo libxml2-dev libxslt-dev gfortran libopenblas-dev liblapack-dev wget \ 8 | && cd /usr/local/bin \ 9 | && ln -s /usr/bin/python3 python 10 | 11 | RUN python3 -m pip install --upgrade pip 12 | 13 | # Python dependencies 14 | COPY requirements.txt /requirements.txt 15 | RUN python3 -m pip install --ignore-installed -r /requirements.txt 16 | 17 | FROM cpu-base as x86-base 18 | 19 | 20 | # Base env setup for NVIDIA GPU 21 | FROM nvidia/cuda:10.2-cudnn7-devel-ubuntu18.04 as gpu-base 22 | RUN echo "Creating environment for GPU" 23 | RUN apt-get update \ 24 | && apt-get install -y python3-pip python3-dev git curl sudo \ 25 | && cd /usr/local/bin \ 26 | && ln -s /usr/bin/python3 python 27 | 28 | RUN python3 -m pip install --upgrade pip 29 | 30 | # Python dependencies 31 | COPY requirements.txt /requirements.txt 32 | RUN python3 -m pip install --ignore-installed -r /requirements.txt 33 | 34 | FROM ${platform}-base as final 35 | 36 | -------------------------------------------------------------------------------- /docker/Dockerfile.ci: -------------------------------------------------------------------------------- 1 | ARG platform=x86 2 | FROM lorien:${platform}-tvm-latest 3 | 4 | COPY req_dev.txt /req_dev.txt 5 | RUN python3 -m pip install --ignore-installed -r /req_dev.txt 6 | 7 | RUN python3 -m pip install xgboost>=1.1.0 8 | RUN python3 -m pip install codecov 9 | RUN python3 -m pip install mxnet==1.5.1 gluoncv==0.6.0 10 | RUN python3 -m pip install tensorflow-cpu==2.2.0 keras onnx 11 | RUN python3 -m pip install 
torch==1.6.0+cpu torchvision==0.7.0+cpu -f https://download.pytorch.org/whl/torch_stable.html 12 | RUN python3 -m pip install flatbuffers 13 | RUN python3 -m pip install https://github.com/FrozenGene/tflite/releases/download/v1.13.1/tflite-1.13.1-py3-none-any.whl 14 | 15 | 16 | COPY ./entrypoint /usr/local/bin 17 | RUN chmod +x /usr/local/bin/entrypoint 18 | ENTRYPOINT ["entrypoint"] 19 | 20 | -------------------------------------------------------------------------------- /docker/Dockerfile.lorien: -------------------------------------------------------------------------------- 1 | # Docker env with Lorien deployed 2 | ARG base=lorien:x86-tvm-latest 3 | FROM $base 4 | 5 | # Check repo head commit and clone Lorien 6 | ADD "https://www.random.org/cgi-bin/randbyte?nbytes=10&format=h" skipcache 7 | COPY id_rsa . 8 | RUN chmod 400 id_rsa && eval $(ssh-agent) && ssh-add ./id_rsa && \ 9 | ssh-keyscan -H github.com >> /etc/ssh/ssh_known_hosts && \ 10 | git clone git@github.com:comaniac/lorien.git 11 | ENV PYTHONPATH /lorien:${PYTHONPATH} 12 | 13 | COPY ./entrypoint /usr/local/bin 14 | RUN chmod +x /usr/local/bin/entrypoint 15 | ENTRYPOINT ["entrypoint"] 16 | 17 | -------------------------------------------------------------------------------- /docker/Dockerfile.tvm: -------------------------------------------------------------------------------- 1 | ARG platform=x86 2 | 3 | FROM lorien:x86-base-latest as x86-tvm-latest 4 | RUN echo "Installing TVM for x86 CPU" 5 | RUN python3 -m pip install tlcpack_nightly -f https://tlcpack.ai/wheels 6 | 7 | FROM lorien:gpu-bast-latest as gpu-tvm-latest 8 | RUN echo "Installing TVM for CUDA" 9 | RUN python3 -m pip install tlcpack_nightly_cu102 -f https://tlcpack.ai/wheels 10 | 11 | FROM ${platform}-tvm-latest as final 12 | -------------------------------------------------------------------------------- /docker/bash.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | # 3 | # 
Start a bash for debugging docker image. 4 | # 5 | # Usage: docker/bash.sh 6 | # Starts an interactive session 7 | # 8 | # Usage2: docker/bash.sh [COMMAND] 9 | # Execute command in the docker image, non-interactive 10 | # 11 | if [ "$#" -lt 1 ]; then 12 | echo "Usage: docker/bash.sh [COMMAND]" 13 | exit 0 14 | fi 15 | 16 | # Setup AWS credentials. 17 | # The default AWS access key can only be used to access 18 | # local DynamoDB but not any other services (e.g., S3). 19 | if [ ! ${AWS_REGION} ]; then 20 | AWS_REGION="us-west-2" 21 | fi 22 | if [ ! ${AWS_ACCESS_KEY_ID} ]; then 23 | AWS_ACCESS_KEY_ID="aaa" 24 | fi 25 | if [ ! ${AWS_SECRET_ACCESS_KEY} ]; then 26 | AWS_SECRET_ACCESS_KEY="bbb" 27 | fi 28 | 29 | 30 | DOCKER_IMAGE_NAME=("$1") 31 | 32 | if [ "$#" -eq 1 ]; then 33 | COMMAND="bash" 34 | if [[ $(uname) == "Darwin" ]]; then 35 | # Docker's host networking driver isn't supported on macOS. 36 | # Use default bridge network and expose port for jupyter notebook. 37 | DOCKER_EXTRA_PARAMS=("-it -p 8888:8888") 38 | else 39 | DOCKER_EXTRA_PARAMS=("-it --net=host") 40 | fi 41 | else 42 | shift 1 43 | COMMAND=("$@") 44 | fi 45 | 46 | # Use nvidia-docker if the container is GPU. 47 | if [[ ! -z $CUDA_VISIBLE_DEVICES ]]; then 48 | CUDA_ENV="-e CUDA_VISIBLE_DEVICES=${CUDA_VISIBLE_DEVICES}" 49 | else 50 | CUDA_ENV="" 51 | fi 52 | 53 | if [[ "${DOCKER_IMAGE_NAME}" == *"gpu"* ]]; then 54 | if ! type "nvidia-docker" 1> /dev/null 2> /dev/null 55 | then 56 | DOCKER_BINARY="docker" 57 | CUDA_ENV=" --gpus all "${CUDA_ENV} 58 | else 59 | DOCKER_BINARY="nvidia-docker" 60 | fi 61 | else 62 | DOCKER_BINARY="docker" 63 | fi 64 | 65 | # Print arguments. 66 | echo "Running '${COMMAND[@]}' inside ${DOCKER_IMAGE_NAME}..." 67 | 68 | # By default we cleanup - remove the container once it finish running (--rm) 69 | # and share the PID namespace (--pid=host) so the process inside does not have 70 | # pid 1 and SIGKILL is propagated to the process inside. 
71 | ${DOCKER_BINARY} run --rm --pid=host\ 72 | -e "AWS_REGION=${AWS_REGION}" \ 73 | -e "AWS_ACCESS_KEY_ID=${AWS_ACCESS_KEY_ID}" \ 74 | -e "AWS_SECRET_ACCESS_KEY=${AWS_SECRET_ACCESS_KEY}" \ 75 | ${CUDA_ENV}\ 76 | ${DOCKER_EXTRA_PARAMS[@]} \ 77 | ${DOCKER_IMAGE_NAME}\ 78 | ${COMMAND[@]} 79 | 80 | -------------------------------------------------------------------------------- /docker/build_ci.sh: -------------------------------------------------------------------------------- 1 | set -e 2 | 3 | if [ "$#" -lt 1 ]; then 4 | echo "Usage ./build.sh " 5 | exit 0 6 | fi 7 | 8 | PLATFORM=${1} 9 | 10 | # Build base 11 | DOCKER_BUILDKIT=1 docker build -f Dockerfile.base -t lorien:${PLATFORM}-base-latest \ 12 | --build-arg platform=${PLATFORM} . 13 | 14 | # Build TVM on base 15 | DOCKER_BUILDKIT=1 docker build -f Dockerfile.tvm -t lorien:${PLATFORM}-tvm-latest \ 16 | --build-arg platform=${PLATFORM} . 17 | 18 | # Build CI on TVM 19 | DOCKER_BUILDKIT=1 docker build -f Dockerfile.ci -t lorien:${PLATFORM}-ci-latest \ 20 | --build-arg platform=${PLATFORM} . 21 | 22 | -------------------------------------------------------------------------------- /docker/build_lorien.sh: -------------------------------------------------------------------------------- 1 | set -e 2 | 3 | if [ "$#" -lt 2 ]; then 4 | echo "Usage ./build.sh " 5 | exit 0 6 | fi 7 | 8 | PLATFORM=${1} 9 | BUILD_ON_CI=${2} 10 | 11 | # Build base 12 | DOCKER_BUILDKIT=1 docker build -f Dockerfile.base -t lorien:${PLATFORM}-base-latest \ 13 | --build-arg platform=${PLATFORM} . 14 | 15 | # Build TVM on base 16 | DOCKER_BUILDKIT=1 docker build -f Dockerfile.tvm -t lorien:${PLATFORM}-tvm-latest \ 17 | --build-arg platform=${PLATFORM} . 18 | 19 | # Determine the base image for Lorien to deploy on 20 | if [ $BUILD_ON_CI -eq 1 ]; then 21 | # Build CI on TVM 22 | DOCKER_BUILDKIT=1 docker build -f Dockerfile.ci -t lorien:${PLATFORM}-ci-latest \ 23 | --build-arg platform=${PLATFORM} . 
24 | 25 | LORIEN_BASE=lorien:${PLATFORM}-ci-latest 26 | else 27 | LORIEN_BASE=lorien:${PLATFORM}-tvm-latest 28 | fi 29 | 30 | # Deploy Lorien 31 | DOCKER_BUILDKIT=1 docker build -f Dockerfile.lorien -t lorien:${PLATFORM}-latest \ 32 | --build-arg base=${LORIEN_BASE} . 33 | 34 | -------------------------------------------------------------------------------- /docker/entrypoint: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | # Configure AWS credentials. 4 | aws configure set region ${AWS_REGION} 5 | aws configure set aws_access_key_id ${AWS_ACCESS_KEY_ID} 6 | aws configure set aws_secret_access_key ${AWS_SECRET_ACCESS_KEY} 7 | 8 | echo "Running command $@" 9 | exec "$@" 10 | 11 | -------------------------------------------------------------------------------- /docker/publish.sh: -------------------------------------------------------------------------------- 1 | 2 | # Repo path could be either AWS ECR repo or Docker hub repo. 3 | # Note that you have to `docker login` first to authorize the write permission. 4 | # Example: ./publish.sh lorien:cpu-latest .dkr.ecr..amazonaws.com//cpu-latest 5 | 6 | if [ "$#" -lt 2 ]; then 7 | echo "Usage ./build.sh " 8 | exit 0 9 | fi 10 | 11 | IMAGE_TAG=${1} 12 | TARGET_REPO=${2} 13 | 14 | echo "Uploading ${IMAGE_TAG} to ${TARGET_REPO}" 15 | docker tag ${IMAGE_TAG} ${TARGET_REPO} 16 | docker push ${TARGET_REPO} 17 | 18 | -------------------------------------------------------------------------------- /docker/publish_docker_hub.sh: -------------------------------------------------------------------------------- 1 | # Note that you have to `docker login` first to authorize the write permission. 
2 | 3 | if [ "$#" -lt 1 ]; then 4 | echo "Usage ./build.sh " 5 | exit 0 6 | fi 7 | 8 | IMAGE_VERSION="v0.06" 9 | PLATFORM=${1} 10 | 11 | if [ $PLATFORM = "x86" ]; then 12 | TARGET_REPO=comaniac0422/lorien:ubuntu-18.04-${IMAGE_VERSION} 13 | else 14 | TARGET_REPO=comaniac0422/lorien:ubuntu-18.04-cuda-${IMAGE_VERSION} 15 | fi 16 | 17 | echo "Uploading image to ${TARGET_REPO}" 18 | docker tag lorien:${PLATFORM}-ci-latest ${TARGET_REPO} 19 | docker push ${TARGET_REPO} 20 | 21 | -------------------------------------------------------------------------------- /docker/req_dev.txt: -------------------------------------------------------------------------------- 1 | # test 2 | xgboost >= 1.1.0 3 | mypy >= 0.900 4 | mock 5 | pylint == 2.4.4 6 | pytest 7 | pytest-cov 8 | pytest-mock 9 | pytest-remotedata == 0.3.2 10 | testfixtures 11 | black 12 | moto == 1.3.15.dev886 13 | 14 | # docs 15 | Sphinx 16 | sphinx_rtd_theme 17 | -------------------------------------------------------------------------------- /docker/requirements.txt: -------------------------------------------------------------------------------- 1 | setuptools 2 | future 3 | tqdm > 4.40 4 | argparse 5 | rpyc 6 | boto3 7 | filelock 8 | ruamel.yaml >= 0.16.12 9 | awscli >= 1.18.140 10 | 11 | -------------------------------------------------------------------------------- /docs/.gitignore: -------------------------------------------------------------------------------- 1 | build 2 | _build 3 | -------------------------------------------------------------------------------- /docs/Makefile: -------------------------------------------------------------------------------- 1 | # Makefile for Sphinx documentation 2 | 3 | # You can set these variables from the command line. 4 | SPHINXOPTS = 5 | SPHINXBUILD = sphinx-build 6 | SOURCEDIR = source 7 | BUILDDIR = build 8 | 9 | # Put it first so that "make" without argument is like "make help". 
10 | help: 11 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 12 | 13 | .PHONY: help Makefile 14 | 15 | # Catch-all target: route all unknown targets to Sphinx using the new 16 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). 17 | %: Makefile 18 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) -W 19 | -------------------------------------------------------------------------------- /docs/README.md: -------------------------------------------------------------------------------- 1 | Lorien Documentation 2 | ==================== 3 | 4 | This folder contains the source of the Lorien documentation. 5 | 6 | * Python package requirements 7 | * sphinx >= 1.5.5, sphinx-gallery, sphinx_rtd_theme 8 | 9 | -------------------------------------------------------------------------------- /docs/source/_static/img/README: -------------------------------------------------------------------------------- 1 | Static images should be here. 2 | -------------------------------------------------------------------------------- /docs/source/api/configs.rst: -------------------------------------------------------------------------------- 1 | .. Licensed to the Apache Software Foundation (ASF) under one 2 | or more contributor license agreements.
See the License for the 15 | specific language governing permissions and limitations 16 | under the License. 17 | 18 | ####### 19 | Configs 20 | ####### 21 | 22 | .. automodule:: lorien.configs 23 | :members: 24 | 25 | -------------------------------------------------------------------------------- /docs/source/api/database/index.rst: -------------------------------------------------------------------------------- 1 | .. Licensed to the Apache Software Foundation (ASF) under one 2 | or more contributor license agreements. See the NOTICE file 3 | distributed with this work for additional information 4 | regarding copyright ownership. The ASF licenses this file 5 | to you under the Apache License, Version 2.0 (the 6 | "License"); you may not use this file except in compliance 7 | with the License. You may obtain a copy of the License at 8 | 9 | .. http://www.apache.org/licenses/LICENSE-2.0 10 | 11 | .. Unless required by applicable law or agreed to in writing, 12 | software distributed under the License is distributed on an 13 | "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 14 | KIND, either express or implied. See the License for the 15 | specific language governing permissions and limitations 16 | under the License. 17 | 18 | ######## 19 | Database 20 | ######## 21 | 22 | .. toctree:: 23 | :maxdepth: 2 24 | 25 | table 26 | util 27 | 28 | -------------------------------------------------------------------------------- /docs/source/api/database/table.rst: -------------------------------------------------------------------------------- 1 | .. Licensed to the Apache Software Foundation (ASF) under one 2 | or more contributor license agreements. See the NOTICE file 3 | distributed with this work for additional information 4 | regarding copyright ownership. The ASF licenses this file 5 | to you under the Apache License, Version 2.0 (the 6 | "License"); you may not use this file except in compliance 7 | with the License. 
You may obtain a copy of the License at 8 | 9 | .. http://www.apache.org/licenses/LICENSE-2.0 10 | 11 | .. Unless required by applicable law or agreed to in writing, 12 | software distributed under the License is distributed on an 13 | "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 14 | KIND, either express or implied. See the License for the 15 | specific language governing permissions and limitations 16 | under the License. 17 | 18 | ##### 19 | Table 20 | ##### 21 | 22 | .. automodule:: lorien.database.table 23 | :members: 24 | 25 | -------------------------------------------------------------------------------- /docs/source/api/database/util.rst: -------------------------------------------------------------------------------- 1 | .. Licensed to the Apache Software Foundation (ASF) under one 2 | or more contributor license agreements. See the NOTICE file 3 | distributed with this work for additional information 4 | regarding copyright ownership. The ASF licenses this file 5 | to you under the Apache License, Version 2.0 (the 6 | "License"); you may not use this file except in compliance 7 | with the License. You may obtain a copy of the License at 8 | 9 | .. http://www.apache.org/licenses/LICENSE-2.0 10 | 11 | .. Unless required by applicable law or agreed to in writing, 12 | software distributed under the License is distributed on an 13 | "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 14 | KIND, either express or implied. See the License for the 15 | specific language governing permissions and limitations 16 | under the License. 17 | 18 | #### 19 | Util 20 | #### 21 | 22 | .. automodule:: lorien.database.util 23 | :members: 24 | 25 | -------------------------------------------------------------------------------- /docs/source/api/dialect/index.rst: -------------------------------------------------------------------------------- 1 | .. Licensed to the Apache Software Foundation (ASF) under one 2 | or more contributor license agreements. 
See the NOTICE file 3 | distributed with this work for additional information 4 | regarding copyright ownership. The ASF licenses this file 5 | to you under the Apache License, Version 2.0 (the 6 | "License"); you may not use this file except in compliance 7 | with the License. You may obtain a copy of the License at 8 | 9 | .. http://www.apache.org/licenses/LICENSE-2.0 10 | 11 | .. Unless required by applicable law or agreed to in writing, 12 | software distributed under the License is distributed on an 13 | "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 14 | KIND, either express or implied. See the License for the 15 | specific language governing permissions and limitations 16 | under the License. 17 | 18 | ####### 19 | Dialect 20 | ####### 21 | 22 | .. toctree:: 23 | :maxdepth: 2 24 | 25 | tvm_dial/index 26 | -------------------------------------------------------------------------------- /docs/source/api/dialect/tvm_dial/auto_scheduler_dial/extract.rst: -------------------------------------------------------------------------------- 1 | .. Licensed to the Apache Software Foundation (ASF) under one 2 | or more contributor license agreements. See the NOTICE file 3 | distributed with this work for additional information 4 | regarding copyright ownership. The ASF licenses this file 5 | to you under the Apache License, Version 2.0 (the 6 | "License"); you may not use this file except in compliance 7 | with the License. You may obtain a copy of the License at 8 | 9 | .. http://www.apache.org/licenses/LICENSE-2.0 10 | 11 | .. Unless required by applicable law or agreed to in writing, 12 | software distributed under the License is distributed on an 13 | "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 14 | KIND, either express or implied. See the License for the 15 | specific language governing permissions and limitations 16 | under the License. 17 | 18 | ####### 19 | Extract 20 | ####### 21 | 22 | .. 
automodule:: lorien.dialect.tvm_dial.auto_scheduler_dial.extract 23 | :members: 24 | 25 | -------------------------------------------------------------------------------- /docs/source/api/dialect/tvm_dial/auto_scheduler_dial/index.rst: -------------------------------------------------------------------------------- 1 | .. Licensed to the Apache Software Foundation (ASF) under one 2 | or more contributor license agreements. See the NOTICE file 3 | distributed with this work for additional information 4 | regarding copyright ownership. The ASF licenses this file 5 | to you under the Apache License, Version 2.0 (the 6 | "License"); you may not use this file except in compliance 7 | with the License. You may obtain a copy of the License at 8 | 9 | .. http://www.apache.org/licenses/LICENSE-2.0 10 | 11 | .. Unless required by applicable law or agreed to in writing, 12 | software distributed under the License is distributed on an 13 | "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 14 | KIND, either express or implied. See the License for the 15 | specific language governing permissions and limitations 16 | under the License. 17 | 18 | ##################### 19 | AutoScheduler Dialect 20 | ##################### 21 | 22 | .. toctree:: 23 | :maxdepth: 2 24 | 25 | extract 26 | job 27 | result 28 | workload 29 | -------------------------------------------------------------------------------- /docs/source/api/dialect/tvm_dial/auto_scheduler_dial/job.rst: -------------------------------------------------------------------------------- 1 | .. Licensed to the Apache Software Foundation (ASF) under one 2 | or more contributor license agreements. See the NOTICE file 3 | distributed with this work for additional information 4 | regarding copyright ownership. The ASF licenses this file 5 | to you under the Apache License, Version 2.0 (the 6 | "License"); you may not use this file except in compliance 7 | with the License. You may obtain a copy of the License at 8 | 9 | .. 
http://www.apache.org/licenses/LICENSE-2.0 10 | 11 | .. Unless required by applicable law or agreed to in writing, 12 | software distributed under the License is distributed on an 13 | "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 14 | KIND, either express or implied. See the License for the 15 | specific language governing permissions and limitations 16 | under the License. 17 | 18 | ### 19 | Job 20 | ### 21 | 22 | .. automodule:: lorien.dialect.tvm_dial.auto_scheduler_dial.job 23 | :members: 24 | -------------------------------------------------------------------------------- /docs/source/api/dialect/tvm_dial/auto_scheduler_dial/result.rst: -------------------------------------------------------------------------------- 1 | .. Licensed to the Apache Software Foundation (ASF) under one 2 | or more contributor license agreements. See the NOTICE file 3 | distributed with this work for additional information 4 | regarding copyright ownership. The ASF licenses this file 5 | to you under the Apache License, Version 2.0 (the 6 | "License"); you may not use this file except in compliance 7 | with the License. You may obtain a copy of the License at 8 | 9 | .. http://www.apache.org/licenses/LICENSE-2.0 10 | 11 | .. Unless required by applicable law or agreed to in writing, 12 | software distributed under the License is distributed on an 13 | "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 14 | KIND, either express or implied. See the License for the 15 | specific language governing permissions and limitations 16 | under the License. 17 | 18 | ###### 19 | Result 20 | ###### 21 | 22 | .. automodule:: lorien.dialect.tvm_dial.auto_scheduler_dial.result 23 | :members: 24 | -------------------------------------------------------------------------------- /docs/source/api/dialect/tvm_dial/auto_scheduler_dial/workload.rst: -------------------------------------------------------------------------------- 1 | .. 
Licensed to the Apache Software Foundation (ASF) under one 2 | or more contributor license agreements. See the NOTICE file 3 | distributed with this work for additional information 4 | regarding copyright ownership. The ASF licenses this file 5 | to you under the Apache License, Version 2.0 (the 6 | "License"); you may not use this file except in compliance 7 | with the License. You may obtain a copy of the License at 8 | 9 | .. http://www.apache.org/licenses/LICENSE-2.0 10 | 11 | .. Unless required by applicable law or agreed to in writing, 12 | software distributed under the License is distributed on an 13 | "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 14 | KIND, either express or implied. See the License for the 15 | specific language governing permissions and limitations 16 | under the License. 17 | 18 | ######## 19 | Workload 20 | ######## 21 | 22 | .. automodule:: lorien.dialect.tvm_dial.auto_scheduler_dial.workload 23 | :members: 24 | -------------------------------------------------------------------------------- /docs/source/api/dialect/tvm_dial/autotvm_dial/extract_from_model.rst: -------------------------------------------------------------------------------- 1 | .. Licensed to the Apache Software Foundation (ASF) under one 2 | or more contributor license agreements. See the NOTICE file 3 | distributed with this work for additional information 4 | regarding copyright ownership. The ASF licenses this file 5 | to you under the Apache License, Version 2.0 (the 6 | "License"); you may not use this file except in compliance 7 | with the License. You may obtain a copy of the License at 8 | 9 | .. http://www.apache.org/licenses/LICENSE-2.0 10 | 11 | .. Unless required by applicable law or agreed to in writing, 12 | software distributed under the License is distributed on an 13 | "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 14 | KIND, either express or implied. 
See the License for the 15 | specific language governing permissions and limitations 16 | under the License. 17 | 18 | ################## 19 | Extract From Model 20 | ################## 21 | 22 | .. automodule:: lorien.dialect.tvm_dial.autotvm_dial.extract_from_model 23 | :members: 24 | 25 | -------------------------------------------------------------------------------- /docs/source/api/dialect/tvm_dial/autotvm_dial/extract_from_record.rst: -------------------------------------------------------------------------------- 1 | .. Licensed to the Apache Software Foundation (ASF) under one 2 | or more contributor license agreements. See the NOTICE file 3 | distributed with this work for additional information 4 | regarding copyright ownership. The ASF licenses this file 5 | to you under the Apache License, Version 2.0 (the 6 | "License"); you may not use this file except in compliance 7 | with the License. You may obtain a copy of the License at 8 | 9 | .. http://www.apache.org/licenses/LICENSE-2.0 10 | 11 | .. Unless required by applicable law or agreed to in writing, 12 | software distributed under the License is distributed on an 13 | "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 14 | KIND, either express or implied. See the License for the 15 | specific language governing permissions and limitations 16 | under the License. 17 | 18 | ################### 19 | Extract From Record 20 | ################### 21 | 22 | .. automodule:: lorien.dialect.tvm_dial.autotvm_dial.extract_from_record 23 | :members: 24 | 25 | -------------------------------------------------------------------------------- /docs/source/api/dialect/tvm_dial/autotvm_dial/index.rst: -------------------------------------------------------------------------------- 1 | .. Licensed to the Apache Software Foundation (ASF) under one 2 | or more contributor license agreements. See the NOTICE file 3 | distributed with this work for additional information 4 | regarding copyright ownership. 
The ASF licenses this file 5 | to you under the Apache License, Version 2.0 (the 6 | "License"); you may not use this file except in compliance 7 | with the License. You may obtain a copy of the License at 8 | 9 | .. http://www.apache.org/licenses/LICENSE-2.0 10 | 11 | .. Unless required by applicable law or agreed to in writing, 12 | software distributed under the License is distributed on an 13 | "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 14 | KIND, either express or implied. See the License for the 15 | specific language governing permissions and limitations 16 | under the License. 17 | 18 | ############### 19 | AutoTVM Dialect 20 | ############### 21 | 22 | .. toctree:: 23 | :maxdepth: 2 24 | 25 | extract_from_model 26 | extract_from_record 27 | job 28 | result 29 | util 30 | workload 31 | -------------------------------------------------------------------------------- /docs/source/api/dialect/tvm_dial/autotvm_dial/job.rst: -------------------------------------------------------------------------------- 1 | .. Licensed to the Apache Software Foundation (ASF) under one 2 | or more contributor license agreements. See the NOTICE file 3 | distributed with this work for additional information 4 | regarding copyright ownership. The ASF licenses this file 5 | to you under the Apache License, Version 2.0 (the 6 | "License"); you may not use this file except in compliance 7 | with the License. You may obtain a copy of the License at 8 | 9 | .. http://www.apache.org/licenses/LICENSE-2.0 10 | 11 | .. Unless required by applicable law or agreed to in writing, 12 | software distributed under the License is distributed on an 13 | "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 14 | KIND, either express or implied. See the License for the 15 | specific language governing permissions and limitations 16 | under the License. 17 | 18 | ### 19 | Job 20 | ### 21 | 22 | .. 
automodule:: lorien.dialect.tvm_dial.autotvm_dial.job 23 | :members: 24 | -------------------------------------------------------------------------------- /docs/source/api/dialect/tvm_dial/autotvm_dial/result.rst: -------------------------------------------------------------------------------- 1 | .. Licensed to the Apache Software Foundation (ASF) under one 2 | or more contributor license agreements. See the NOTICE file 3 | distributed with this work for additional information 4 | regarding copyright ownership. The ASF licenses this file 5 | to you under the Apache License, Version 2.0 (the 6 | "License"); you may not use this file except in compliance 7 | with the License. You may obtain a copy of the License at 8 | 9 | .. http://www.apache.org/licenses/LICENSE-2.0 10 | 11 | .. Unless required by applicable law or agreed to in writing, 12 | software distributed under the License is distributed on an 13 | "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 14 | KIND, either express or implied. See the License for the 15 | specific language governing permissions and limitations 16 | under the License. 17 | 18 | ###### 19 | Result 20 | ###### 21 | 22 | .. automodule:: lorien.dialect.tvm_dial.autotvm_dial.result 23 | :members: 24 | -------------------------------------------------------------------------------- /docs/source/api/dialect/tvm_dial/autotvm_dial/util.rst: -------------------------------------------------------------------------------- 1 | .. Licensed to the Apache Software Foundation (ASF) under one 2 | or more contributor license agreements. See the NOTICE file 3 | distributed with this work for additional information 4 | regarding copyright ownership. The ASF licenses this file 5 | to you under the Apache License, Version 2.0 (the 6 | "License"); you may not use this file except in compliance 7 | with the License. You may obtain a copy of the License at 8 | 9 | .. http://www.apache.org/licenses/LICENSE-2.0 10 | 11 | .. 
Unless required by applicable law or agreed to in writing, 12 | software distributed under the License is distributed on an 13 | "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 14 | KIND, either express or implied. See the License for the 15 | specific language governing permissions and limitations 16 | under the License. 17 | 18 | #### 19 | Util 20 | #### 21 | 22 | .. automodule:: lorien.dialect.tvm_dial.autotvm_dial.util 23 | :members: 24 | -------------------------------------------------------------------------------- /docs/source/api/dialect/tvm_dial/autotvm_dial/workload.rst: -------------------------------------------------------------------------------- 1 | .. Licensed to the Apache Software Foundation (ASF) under one 2 | or more contributor license agreements. See the NOTICE file 3 | distributed with this work for additional information 4 | regarding copyright ownership. The ASF licenses this file 5 | to you under the Apache License, Version 2.0 (the 6 | "License"); you may not use this file except in compliance 7 | with the License. You may obtain a copy of the License at 8 | 9 | .. http://www.apache.org/licenses/LICENSE-2.0 10 | 11 | .. Unless required by applicable law or agreed to in writing, 12 | software distributed under the License is distributed on an 13 | "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 14 | KIND, either express or implied. See the License for the 15 | specific language governing permissions and limitations 16 | under the License. 17 | 18 | ######## 19 | Workload 20 | ######## 21 | 22 | .. automodule:: lorien.dialect.tvm_dial.autotvm_dial.workload 23 | :members: 24 | -------------------------------------------------------------------------------- /docs/source/api/dialect/tvm_dial/frontend_parser.rst: -------------------------------------------------------------------------------- 1 | .. Licensed to the Apache Software Foundation (ASF) under one 2 | or more contributor license agreements. 
See the NOTICE file 3 | distributed with this work for additional information 4 | regarding copyright ownership. The ASF licenses this file 5 | to you under the Apache License, Version 2.0 (the 6 | "License"); you may not use this file except in compliance 7 | with the License. You may obtain a copy of the License at 8 | 9 | .. http://www.apache.org/licenses/LICENSE-2.0 10 | 11 | .. Unless required by applicable law or agreed to in writing, 12 | software distributed under the License is distributed on an 13 | "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 14 | KIND, either express or implied. See the License for the 15 | specific language governing permissions and limitations 16 | under the License. 17 | 18 | ############### 19 | Frontend Parser 20 | ############### 21 | 22 | .. automodule:: lorien.dialect.tvm_dial.frontend_parser 23 | :members: 24 | -------------------------------------------------------------------------------- /docs/source/api/dialect/tvm_dial/index.rst: -------------------------------------------------------------------------------- 1 | .. Licensed to the Apache Software Foundation (ASF) under one 2 | or more contributor license agreements. See the NOTICE file 3 | distributed with this work for additional information 4 | regarding copyright ownership. The ASF licenses this file 5 | to you under the Apache License, Version 2.0 (the 6 | "License"); you may not use this file except in compliance 7 | with the License. You may obtain a copy of the License at 8 | 9 | .. http://www.apache.org/licenses/LICENSE-2.0 10 | 11 | .. Unless required by applicable law or agreed to in writing, 12 | software distributed under the License is distributed on an 13 | "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 14 | KIND, either express or implied. See the License for the 15 | specific language governing permissions and limitations 16 | under the License. 17 | 18 | ########### 19 | TVM Dialect 20 | ########### 21 | 22 | .. 
toctree:: 23 | :maxdepth: 2 24 | 25 | frontend_parser 26 | job 27 | result 28 | util 29 | auto_scheduler_dial/index 30 | autotvm_dial/index 31 | -------------------------------------------------------------------------------- /docs/source/api/dialect/tvm_dial/job.rst: -------------------------------------------------------------------------------- 1 | .. Licensed to the Apache Software Foundation (ASF) under one 2 | or more contributor license agreements. See the NOTICE file 3 | distributed with this work for additional information 4 | regarding copyright ownership. The ASF licenses this file 5 | to you under the Apache License, Version 2.0 (the 6 | "License"); you may not use this file except in compliance 7 | with the License. You may obtain a copy of the License at 8 | 9 | .. http://www.apache.org/licenses/LICENSE-2.0 10 | 11 | .. Unless required by applicable law or agreed to in writing, 12 | software distributed under the License is distributed on an 13 | "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 14 | KIND, either express or implied. See the License for the 15 | specific language governing permissions and limitations 16 | under the License. 17 | 18 | ### 19 | Job 20 | ### 21 | 22 | .. automodule:: lorien.dialect.tvm_dial.job 23 | :members: 24 | -------------------------------------------------------------------------------- /docs/source/api/dialect/tvm_dial/result.rst: -------------------------------------------------------------------------------- 1 | .. Licensed to the Apache Software Foundation (ASF) under one 2 | or more contributor license agreements. See the NOTICE file 3 | distributed with this work for additional information 4 | regarding copyright ownership. The ASF licenses this file 5 | to you under the Apache License, Version 2.0 (the 6 | "License"); you may not use this file except in compliance 7 | with the License. You may obtain a copy of the License at 8 | 9 | .. http://www.apache.org/licenses/LICENSE-2.0 10 | 11 | .. 
Unless required by applicable law or agreed to in writing, 12 | software distributed under the License is distributed on an 13 | "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 14 | KIND, either express or implied. See the License for the 15 | specific language governing permissions and limitations 16 | under the License. 17 | 18 | ###### 19 | Result 20 | ###### 21 | 22 | .. automodule:: lorien.dialect.tvm_dial.result 23 | :members: 24 | -------------------------------------------------------------------------------- /docs/source/api/dialect/tvm_dial/util.rst: -------------------------------------------------------------------------------- 1 | .. Licensed to the Apache Software Foundation (ASF) under one 2 | or more contributor license agreements. See the NOTICE file 3 | distributed with this work for additional information 4 | regarding copyright ownership. The ASF licenses this file 5 | to you under the Apache License, Version 2.0 (the 6 | "License"); you may not use this file except in compliance 7 | with the License. You may obtain a copy of the License at 8 | 9 | .. http://www.apache.org/licenses/LICENSE-2.0 10 | 11 | .. Unless required by applicable law or agreed to in writing, 12 | software distributed under the License is distributed on an 13 | "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 14 | KIND, either express or implied. See the License for the 15 | specific language governing permissions and limitations 16 | under the License. 17 | 18 | #### 19 | Util 20 | #### 21 | 22 | .. automodule:: lorien.dialect.tvm_dial.util 23 | :members: 24 | -------------------------------------------------------------------------------- /docs/source/api/generate.rst: -------------------------------------------------------------------------------- 1 | .. Licensed to the Apache Software Foundation (ASF) under one 2 | or more contributor license agreements. 
See the NOTICE file 3 | distributed with this work for additional information 4 | regarding copyright ownership. The ASF licenses this file 5 | to you under the Apache License, Version 2.0 (the 6 | "License"); you may not use this file except in compliance 7 | with the License. You may obtain a copy of the License at 8 | 9 | .. http://www.apache.org/licenses/LICENSE-2.0 10 | 11 | .. Unless required by applicable law or agreed to in writing, 12 | software distributed under the License is distributed on an 13 | "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 14 | KIND, either express or implied. See the License for the 15 | specific language governing permissions and limitations 16 | under the License. 17 | 18 | ######## 19 | Generate 20 | ######## 21 | 22 | .. automodule:: lorien.generate 23 | :members: 24 | -------------------------------------------------------------------------------- /docs/source/api/index.rst: -------------------------------------------------------------------------------- 1 | .. Licensed to the Apache Software Foundation (ASF) under one 2 | or more contributor license agreements. See the NOTICE file 3 | distributed with this work for additional information 4 | regarding copyright ownership. The ASF licenses this file 5 | to you under the Apache License, Version 2.0 (the 6 | "License"); you may not use this file except in compliance 7 | with the License. You may obtain a copy of the License at 8 | 9 | .. http://www.apache.org/licenses/LICENSE-2.0 10 | 11 | .. Unless required by applicable law or agreed to in writing, 12 | software distributed under the License is distributed on an 13 | "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 14 | KIND, either express or implied. See the License for the 15 | specific language governing permissions and limitations 16 | under the License. 17 | 18 | ########### 19 | Python APIs 20 | ########### 21 | 22 | .. 
toctree:: 23 | :maxdepth: 2 24 | 25 | main 26 | configs 27 | workload 28 | util 29 | logger 30 | generate 31 | tune/index 32 | database/index 33 | dialect/index 34 | 35 | -------------------------------------------------------------------------------- /docs/source/api/logger.rst: -------------------------------------------------------------------------------- 1 | .. Licensed to the Apache Software Foundation (ASF) under one 2 | or more contributor license agreements. See the NOTICE file 3 | distributed with this work for additional information 4 | regarding copyright ownership. The ASF licenses this file 5 | to you under the Apache License, Version 2.0 (the 6 | "License"); you may not use this file except in compliance 7 | with the License. You may obtain a copy of the License at 8 | 9 | .. http://www.apache.org/licenses/LICENSE-2.0 10 | 11 | .. Unless required by applicable law or agreed to in writing, 12 | software distributed under the License is distributed on an 13 | "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 14 | KIND, either express or implied. See the License for the 15 | specific language governing permissions and limitations 16 | under the License. 17 | 18 | ###### 19 | Logger 20 | ###### 21 | 22 | .. automodule:: lorien.logger 23 | :members: 24 | 25 | -------------------------------------------------------------------------------- /docs/source/api/main.rst: -------------------------------------------------------------------------------- 1 | .. Licensed to the Apache Software Foundation (ASF) under one 2 | or more contributor license agreements. See the NOTICE file 3 | distributed with this work for additional information 4 | regarding copyright ownership. The ASF licenses this file 5 | to you under the Apache License, Version 2.0 (the 6 | "License"); you may not use this file except in compliance 7 | with the License. You may obtain a copy of the License at 8 | 9 | .. http://www.apache.org/licenses/LICENSE-2.0 10 | 11 | .. 
Unless required by applicable law or agreed to in writing, 12 | software distributed under the License is distributed on an 13 | "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 14 | KIND, either express or implied. See the License for the 15 | specific language governing permissions and limitations 16 | under the License. 17 | 18 | #### 19 | Main 20 | #### 21 | 22 | .. automodule:: lorien.main 23 | :members: 24 | 25 | -------------------------------------------------------------------------------- /docs/source/api/tune/index.rst: -------------------------------------------------------------------------------- 1 | .. Licensed to the Apache Software Foundation (ASF) under one 2 | or more contributor license agreements. See the NOTICE file 3 | distributed with this work for additional information 4 | regarding copyright ownership. The ASF licenses this file 5 | to you under the Apache License, Version 2.0 (the 6 | "License"); you may not use this file except in compliance 7 | with the License. You may obtain a copy of the License at 8 | 9 | .. http://www.apache.org/licenses/LICENSE-2.0 10 | 11 | .. Unless required by applicable law or agreed to in writing, 12 | software distributed under the License is distributed on an 13 | "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 14 | KIND, either express or implied. See the License for the 15 | specific language governing permissions and limitations 16 | under the License. 17 | 18 | #### 19 | Tune 20 | #### 21 | 22 | .. toctree:: 23 | :maxdepth: 2 24 | 25 | master 26 | manager 27 | result 28 | job 29 | rpc/index 30 | 31 | -------------------------------------------------------------------------------- /docs/source/api/tune/job.rst: -------------------------------------------------------------------------------- 1 | .. Licensed to the Apache Software Foundation (ASF) under one 2 | or more contributor license agreements. 
See the NOTICE file 3 | distributed with this work for additional information 4 | regarding copyright ownership. The ASF licenses this file 5 | to you under the Apache License, Version 2.0 (the 6 | "License"); you may not use this file except in compliance 7 | with the License. You may obtain a copy of the License at 8 | 9 | .. http://www.apache.org/licenses/LICENSE-2.0 10 | 11 | .. Unless required by applicable law or agreed to in writing, 12 | software distributed under the License is distributed on an 13 | "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 14 | KIND, either express or implied. See the License for the 15 | specific language governing permissions and limitations 16 | under the License. 17 | 18 | ### 19 | Job 20 | ### 21 | 22 | .. automodule:: lorien.tune.job 23 | :members: 24 | 25 | -------------------------------------------------------------------------------- /docs/source/api/tune/manager.rst: -------------------------------------------------------------------------------- 1 | .. Licensed to the Apache Software Foundation (ASF) under one 2 | or more contributor license agreements. See the NOTICE file 3 | distributed with this work for additional information 4 | regarding copyright ownership. The ASF licenses this file 5 | to you under the Apache License, Version 2.0 (the 6 | "License"); you may not use this file except in compliance 7 | with the License. You may obtain a copy of the License at 8 | 9 | .. http://www.apache.org/licenses/LICENSE-2.0 10 | 11 | .. Unless required by applicable law or agreed to in writing, 12 | software distributed under the License is distributed on an 13 | "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 14 | KIND, either express or implied. See the License for the 15 | specific language governing permissions and limitations 16 | under the License. 17 | 18 | ########### 19 | Job Manager 20 | ########### 21 | 22 | .. 
automodule:: lorien.tune.job_manager 23 | :members: 24 | 25 | -------------------------------------------------------------------------------- /docs/source/api/tune/master.rst: -------------------------------------------------------------------------------- 1 | .. Licensed to the Apache Software Foundation (ASF) under one 2 | or more contributor license agreements. See the NOTICE file 3 | distributed with this work for additional information 4 | regarding copyright ownership. The ASF licenses this file 5 | to you under the Apache License, Version 2.0 (the 6 | "License"); you may not use this file except in compliance 7 | with the License. You may obtain a copy of the License at 8 | 9 | .. http://www.apache.org/licenses/LICENSE-2.0 10 | 11 | .. Unless required by applicable law or agreed to in writing, 12 | software distributed under the License is distributed on an 13 | "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 14 | KIND, either express or implied. See the License for the 15 | specific language governing permissions and limitations 16 | under the License. 17 | 18 | ###### 19 | Master 20 | ###### 21 | 22 | .. automodule:: lorien.tune.master 23 | :members: 24 | 25 | -------------------------------------------------------------------------------- /docs/source/api/tune/result.rst: -------------------------------------------------------------------------------- 1 | .. Licensed to the Apache Software Foundation (ASF) under one 2 | or more contributor license agreements. See the NOTICE file 3 | distributed with this work for additional information 4 | regarding copyright ownership. The ASF licenses this file 5 | to you under the Apache License, Version 2.0 (the 6 | "License"); you may not use this file except in compliance 7 | with the License. You may obtain a copy of the License at 8 | 9 | .. http://www.apache.org/licenses/LICENSE-2.0 10 | 11 | .. 
Unless required by applicable law or agreed to in writing, 12 | software distributed under the License is distributed on an 13 | "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 14 | KIND, either express or implied. See the License for the 15 | specific language governing permissions and limitations 16 | under the License. 17 | 18 | ###### 19 | Result 20 | ###### 21 | 22 | .. automodule:: lorien.tune.result 23 | :members: 24 | 25 | -------------------------------------------------------------------------------- /docs/source/api/tune/rpc/client.rst: -------------------------------------------------------------------------------- 1 | .. Licensed to the Apache Software Foundation (ASF) under one 2 | or more contributor license agreements. See the NOTICE file 3 | distributed with this work for additional information 4 | regarding copyright ownership. The ASF licenses this file 5 | to you under the Apache License, Version 2.0 (the 6 | "License"); you may not use this file except in compliance 7 | with the License. You may obtain a copy of the License at 8 | 9 | .. http://www.apache.org/licenses/LICENSE-2.0 10 | 11 | .. Unless required by applicable law or agreed to in writing, 12 | software distributed under the License is distributed on an 13 | "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 14 | KIND, either express or implied. See the License for the 15 | specific language governing permissions and limitations 16 | under the License. 17 | 18 | ###### 19 | Client 20 | ###### 21 | 22 | .. automodule:: lorien.tune.rpc.client 23 | :members: 24 | 25 | -------------------------------------------------------------------------------- /docs/source/api/tune/rpc/index.rst: -------------------------------------------------------------------------------- 1 | .. Licensed to the Apache Software Foundation (ASF) under one 2 | or more contributor license agreements. 
See the NOTICE file 3 | distributed with this work for additional information 4 | regarding copyright ownership. The ASF licenses this file 5 | to you under the Apache License, Version 2.0 (the 6 | "License"); you may not use this file except in compliance 7 | with the License. You may obtain a copy of the License at 8 | 9 | .. http://www.apache.org/licenses/LICENSE-2.0 10 | 11 | .. Unless required by applicable law or agreed to in writing, 12 | software distributed under the License is distributed on an 13 | "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 14 | KIND, either express or implied. See the License for the 15 | specific language governing permissions and limitations 16 | under the License. 17 | 18 | ### 19 | RPC 20 | ### 21 | 22 | .. toctree:: 23 | :maxdepth: 2 24 | 25 | server 26 | client 27 | launch 28 | 29 | -------------------------------------------------------------------------------- /docs/source/api/tune/rpc/launch.rst: -------------------------------------------------------------------------------- 1 | .. Licensed to the Apache Software Foundation (ASF) under one 2 | or more contributor license agreements. See the NOTICE file 3 | distributed with this work for additional information 4 | regarding copyright ownership. The ASF licenses this file 5 | to you under the Apache License, Version 2.0 (the 6 | "License"); you may not use this file except in compliance 7 | with the License. You may obtain a copy of the License at 8 | 9 | .. http://www.apache.org/licenses/LICENSE-2.0 10 | 11 | .. Unless required by applicable law or agreed to in writing, 12 | software distributed under the License is distributed on an 13 | "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 14 | KIND, either express or implied. See the License for the 15 | specific language governing permissions and limitations 16 | under the License. 17 | 18 | ###### 19 | Launch 20 | ###### 21 | 22 | .. 
automodule:: lorien.tune.rpc.launch 23 | :members: 24 | 25 | -------------------------------------------------------------------------------- /docs/source/api/tune/rpc/server.rst: -------------------------------------------------------------------------------- 1 | .. Licensed to the Apache Software Foundation (ASF) under one 2 | or more contributor license agreements. See the NOTICE file 3 | distributed with this work for additional information 4 | regarding copyright ownership. The ASF licenses this file 5 | to you under the Apache License, Version 2.0 (the 6 | "License"); you may not use this file except in compliance 7 | with the License. You may obtain a copy of the License at 8 | 9 | .. http://www.apache.org/licenses/LICENSE-2.0 10 | 11 | .. Unless required by applicable law or agreed to in writing, 12 | software distributed under the License is distributed on an 13 | "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 14 | KIND, either express or implied. See the License for the 15 | specific language governing permissions and limitations 16 | under the License. 17 | 18 | ###### 19 | Server 20 | ###### 21 | 22 | .. automodule:: lorien.tune.rpc.server 23 | :members: 24 | 25 | -------------------------------------------------------------------------------- /docs/source/api/util.rst: -------------------------------------------------------------------------------- 1 | .. Licensed to the Apache Software Foundation (ASF) under one 2 | or more contributor license agreements. See the NOTICE file 3 | distributed with this work for additional information 4 | regarding copyright ownership. The ASF licenses this file 5 | to you under the Apache License, Version 2.0 (the 6 | "License"); you may not use this file except in compliance 7 | with the License. You may obtain a copy of the License at 8 | 9 | .. http://www.apache.org/licenses/LICENSE-2.0 10 | 11 | .. 
Unless required by applicable law or agreed to in writing, 12 | software distributed under the License is distributed on an 13 | "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 14 | KIND, either express or implied. See the License for the 15 | specific language governing permissions and limitations 16 | under the License. 17 | 18 | #### 19 | Util 20 | #### 21 | 22 | .. automodule:: lorien.util 23 | :members: 24 | 25 | -------------------------------------------------------------------------------- /docs/source/api/workload.rst: -------------------------------------------------------------------------------- 1 | .. Licensed to the Apache Software Foundation (ASF) under one 2 | or more contributor license agreements. See the NOTICE file 3 | distributed with this work for additional information 4 | regarding copyright ownership. The ASF licenses this file 5 | to you under the Apache License, Version 2.0 (the 6 | "License"); you may not use this file except in compliance 7 | with the License. You may obtain a copy of the License at 8 | 9 | .. http://www.apache.org/licenses/LICENSE-2.0 10 | 11 | .. Unless required by applicable law or agreed to in writing, 12 | software distributed under the License is distributed on an 13 | "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 14 | KIND, either express or implied. See the License for the 15 | specific language governing permissions and limitations 16 | under the License. 17 | 18 | ######## 19 | Workload 20 | ######## 21 | 22 | .. automodule:: lorien.workload 23 | :members: 24 | -------------------------------------------------------------------------------- /docs/source/conf.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # 3 | # Configuration file for the Sphinx documentation builder. 4 | # 5 | # This file does only contain a selection of the most common options. 
For a 6 | # full list see the documentation: 7 | # http://www.sphinx-doc.org/en/master/config 8 | 9 | # -- Path setup -------------------------------------------------------------- 10 | 11 | # If extensions (or modules to document with autodoc) are in another directory, 12 | # add these directories to sys.path here. If the directory is relative to the 13 | # documentation root, use os.path.abspath to make it absolute, like shown here. 14 | # 15 | import os 16 | import sys 17 | 18 | sys.path.insert(0, os.path.abspath("../../")) 19 | 20 | 21 | # -- Project information ----------------------------------------------------- 22 | 23 | project = "lorien" 24 | copyright = "" 25 | author = "" 26 | 27 | # The short X.Y.Z version 28 | version = "0.0.1" 29 | # The full version, including alpha/beta/rc tags 30 | release = "alpha" 31 | 32 | 33 | # -- General configuration --------------------------------------------------- 34 | 35 | # If your documentation needs a minimal Sphinx version, state it here. 36 | # 37 | # needs_sphinx = '1.0' 38 | 39 | # Add any Sphinx extension module names here, as strings. They can be 40 | # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom 41 | # ones. 42 | extensions = [ 43 | "sphinx.ext.autodoc", 44 | "sphinx.ext.autosummary", 45 | "sphinx.ext.intersphinx", 46 | "sphinx.ext.napoleon", 47 | "sphinx.ext.mathjax", 48 | ] 49 | 50 | # Add any paths that contain templates here, relative to this directory. 51 | templates_path = ["_templates"] 52 | 53 | # The suffix(es) of source filenames. 54 | # You can specify multiple suffix as a list of string: 55 | # 56 | # source_suffix = ['.rst', '.md'] 57 | source_suffix = ".rst" 58 | 59 | # The master toctree document. 60 | master_doc = "index" 61 | 62 | # The language for content autogenerated by Sphinx. Refer to documentation 63 | # for a list of supported languages. 64 | # 65 | # This is also used if you do content translation via gettext catalogs. 
66 | # Usually you set "language" from the command line for these cases. 67 | language = "python3" 68 | 69 | # List of patterns, relative to source directory, that match files and 70 | # directories to ignore when looking for source files. 71 | # This pattern also affects html_static_path and html_extra_path. 72 | exclude_patterns = [] 73 | 74 | # The name of the Pygments (syntax highlighting) style to use. 75 | pygments_style = None 76 | 77 | # generate autosummary even if no references 78 | # autosummary_generate = True 79 | 80 | 81 | # -- Options for HTML output ------------------------------------------------- 82 | 83 | # The theme to use for HTML and HTML Help pages. See the documentation for 84 | # a list of builtin themes. 85 | # 86 | html_theme = "sphinx_rtd_theme" 87 | 88 | # Theme options are theme-specific and customize the look and feel of a theme 89 | # further. For a list of options available for each theme, see the 90 | # documentation. 91 | # 92 | # html_theme_options = {} 93 | 94 | # Add any paths that contain custom static files (such as style sheets) here, 95 | # relative to this directory. They are copied after the builtin static files, 96 | # so a file named "default.css" will overwrite the builtin "default.css". 97 | html_static_path = ["_static"] 98 | 99 | # Custom sidebar templates, must be a dictionary that maps document names 100 | # to template names. 101 | # 102 | # The default sidebars (for documents that don't match any pattern) are 103 | # defined by theme itself. Builtin themes are using these templates by 104 | # default: ``['localtoc.html', 'relations.html', 'sourcelink.html', 105 | # 'searchbox.html']``. 106 | # 107 | # html_sidebars = {} 108 | 109 | 110 | # -- Options for HTMLHelp output --------------------------------------------- 111 | 112 | # Output file base name for HTML help builder. 
113 | htmlhelp_basename = "loriendoc" 114 | 115 | # -- Extension configuration ------------------------------------------------- 116 | 117 | # -- Options for todo extension ---------------------------------------------- 118 | 119 | # If true, `todo` and `todoList` produce output, else they produce nothing. 120 | todo_include_todos = True 121 | -------------------------------------------------------------------------------- /docs/source/contribute/index.rst: -------------------------------------------------------------------------------- 1 | .. Licensed to the Apache Software Foundation (ASF) under one 2 | or more contributor license agreements. See the NOTICE file 3 | distributed with this work for additional information 4 | regarding copyright ownership. The ASF licenses this file 5 | to you under the Apache License, Version 2.0 (the 6 | "License"); you may not use this file except in compliance 7 | with the License. You may obtain a copy of the License at 8 | 9 | .. http://www.apache.org/licenses/LICENSE-2.0 10 | 11 | .. Unless required by applicable law or agreed to in writing, 12 | software distributed under the License is distributed on an 13 | "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 14 | KIND, either express or implied. See the License for the 15 | specific language governing permissions and limitations 16 | under the License. 17 | 18 | .. _contribute: 19 | 20 | #################### 21 | Contribute to Lorien 22 | #################### 23 | 24 | Lorien is a large project that cannot be built by a single person, so we welcome everyone to contribute. We value all forms of contributions, including but not limited to: 25 | 26 | - Code reviews of existing patches. 27 | - Documentation and usage examples. 28 | - Code readability and developer guides. 29 | 30 | - We welcome contributions that add code comments 31 | to improve readability. 32 | - We also welcome contributions to docs that explain the 33 | design choices of the internals.
34 | 35 | - Test cases to make the codebase more robust. 36 | - Tutorials, blog posts, and talks that promote the project. 37 | 38 | Here are some guidelines for contributing to various aspects of the project (TBA): 39 | 40 | .. toctree:: 41 | :maxdepth: 2 42 | 43 | pull_request 44 | 45 | -------------------------------------------------------------------------------- /docs/source/contribute/pull_request.rst: -------------------------------------------------------------------------------- 1 | .. Licensed to the Apache Software Foundation (ASF) under one 2 | or more contributor license agreements. See the NOTICE file 3 | distributed with this work for additional information 4 | regarding copyright ownership. The ASF licenses this file 5 | to you under the Apache License, Version 2.0 (the 6 | "License"); you may not use this file except in compliance 7 | with the License. You may obtain a copy of the License at 8 | 9 | .. http://www.apache.org/licenses/LICENSE-2.0 10 | 11 | .. Unless required by applicable law or agreed to in writing, 12 | software distributed under the License is distributed on an 13 | "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 14 | KIND, either express or implied. See the License for the 15 | specific language governing permissions and limitations 16 | under the License. 17 | 18 | ##################### 19 | Submit a Pull Request 20 | ##################### 21 | 22 | Assuming you have forked Lorien and implemented some nice features, here is a guide to submitting your changes as a pull request (PR). 23 | 24 | - Add test cases to cover the new features or bug fixes the patch introduces. 25 | - Document the code you wrote. 26 | - Run unit tests locally and make sure your changes pass all of them: 27 | 28 | .. code:: bash 29 | 30 | make unit_test 31 | 32 | If you only need to re-run a single test file: 33 | 34 | ..
code:: bash 35 | 36 | python3 -m pytest tests/test_util.py  # replace with the test file you want to re-run 37 | 38 | - Run the code coverage test and make sure your changes do not decrease the code coverage rate: 39 | 40 | .. code:: bash 41 | 42 | make cov 43 | 44 | Note that since the code coverage test runs the unit tests as well, you can skip the previous step if you plan to get a coverage report. 45 | 46 | 47 | - Rebase your branch on the head of master. Note that you may need to resolve conflicts during rebasing; follow the instructions git prints to resolve them. 48 | 49 | .. code:: bash 50 | 51 | git remote add upstream git@github.com:comaniac/lorien.git 52 | git fetch upstream 53 | git rebase upstream/master your_branch 54 | git push origin -f 55 | 56 | - Check the code format: 57 | 58 | .. code:: bash 59 | 60 | make check_format 61 | 62 | If you get any errors when checking the format, use the following command to auto-format the code: 63 | 64 | .. code:: bash 65 | 66 | make format 67 | 68 | - Check lint. We use ``pylint`` to check whether the coding style aligns with the PEP 8 standard. Run the following command for linting. 69 | 70 | .. code:: bash 71 | 72 | make lint 73 | 74 | You have to get a perfect score in order to pass the CI. If you believe that an ERROR/WARNING you got from linting does not make sense, you are welcome to open an issue for discussion. Do NOT simply add ``pylint: disable`` to work around it without a good reason. 75 | 76 | :: 77 | 78 | Your code has been rated at 10.00/10 (previous run: 10.00/10, +0.00) 79 | 80 | - Run the following command to check types. We use ``mypy`` to check whether the code has potential type errors. 81 | 82 | .. code:: bash 83 | 84 | make type 85 | 86 | You should see the following message (the number of source files may change over time). 87 | 88 | :: 89 | 90 | Success: no issues found in 29 source files 91 | 92 | - Commit all changes you made during the above steps. 93 | - Send the pull request and request code reviews from other contributors.
94 | 95 | - To get your code reviewed quickly, we encourage you to help review others' code so they can return the favor. 96 | - Code review is a shepherding process that helps to improve contributors' code quality. 97 | We should treat it proactively and improve the code as much as possible before the review. 98 | We highly value patches that can get in without extensive reviews. 99 | - Detailed code review guidelines that summarize useful lessons are to be added (TBA). 100 | 101 | - The pull request can be merged after it passes the CI and is approved by at least one reviewer. 102 | 103 | 104 | ************** 105 | CI Environment 106 | ************** 107 | We use GitHub Actions with the prebuilt docker image to run CI. You can find the prebuilt docker images at `Docker Hub <https://hub.docker.com/r/comaniac0422/lorien>`_. 108 | 109 | Since updating docker images may cause CI problems and need fixes to accommodate the new environment, here is the protocol to update the docker image for CI: 110 | 111 | - Send a PR to update the build script in the repo and ask one of the code owners to perform the following steps. 112 | - Build the new docker image: ``./build cpu``. 113 | - Tag and publish the image with a new version: ``./publish.sh cpu comaniac0422/lorien:ubuntu-18.04-v<version>``. 114 | - Update the version (most of the time, increase the minor version) in ``.github/workflows/ubuntu-ci.yml`` and send a PR. 115 | - The PR should now use the updated environment to run CI. Fix any issues related to the new image version. 116 | - Merge the PR; CI now runs on the new image version. 117 | 118 | ************************* 119 | Code Coverage Enforcement 120 | ************************* 121 | The CI uploads the code coverage report to codecov.io to keep track of code coverage changes. You can find a badge in the README showing the current code coverage of the master branch. We require the code coverage to be higher than 90%. 122 | 123 | After your PR has passed the CI, you should see a Codecov bot post a comment in your PR with a code coverage change report.
Your PR should guarantee that the code coverage does not drop before getting merged. 124 | 125 | -------------------------------------------------------------------------------- /docs/source/genindex.rst: -------------------------------------------------------------------------------- 1 | .. Licensed to the Apache Software Foundation (ASF) under one 2 | or more contributor license agreements. See the NOTICE file 3 | distributed with this work for additional information 4 | regarding copyright ownership. The ASF licenses this file 5 | to you under the Apache License, Version 2.0 (the 6 | "License"); you may not use this file except in compliance 7 | with the License. You may obtain a copy of the License at 8 | 9 | .. http://www.apache.org/licenses/LICENSE-2.0 10 | 11 | .. Unless required by applicable law or agreed to in writing, 12 | software distributed under the License is distributed on an 13 | "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 14 | KIND, either express or implied. See the License for the 15 | specific language governing permissions and limitations 16 | under the License. 17 | 18 | ##### 19 | Index 20 | ##### 21 | 22 | -------------------------------------------------------------------------------- /docs/source/index.rst: -------------------------------------------------------------------------------- 1 | .. Licensed to the Apache Software Foundation (ASF) under one 2 | or more contributor license agreements. See the NOTICE file 3 | distributed with this work for additional information 4 | regarding copyright ownership. The ASF licenses this file 5 | to you under the Apache License, Version 2.0 (the 6 | "License"); you may not use this file except in compliance 7 | with the License. You may obtain a copy of the License at 8 | 9 | .. http://www.apache.org/licenses/LICENSE-2.0 10 | 11 | .. 
Unless required by applicable law or agreed to in writing, 12 | software distributed under the License is distributed on an 13 | "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 14 | KIND, either express or implied. See the License for the 15 | specific language governing permissions and limitations 16 | under the License. 17 | 18 | #################### 19 | Lorien Documentation 20 | #################### 21 | 22 | *********** 23 | Get Started 24 | *********** 25 | .. toctree:: 26 | :maxdepth: 1 27 | 28 | setup/index 29 | tutorials/index 30 | contribute/index 31 | 32 | ************* 33 | API Reference 34 | ************* 35 | .. toctree:: 36 | :maxdepth: 2 37 | 38 | api/index 39 | 40 | ***** 41 | Index 42 | ***** 43 | .. toctree:: 44 | :maxdepth: 1 45 | 46 | genindex 47 | 48 | -------------------------------------------------------------------------------- /docs/source/setup/index.rst: -------------------------------------------------------------------------------- 1 | .. Licensed to the Apache Software Foundation (ASF) under one 2 | or more contributor license agreements. See the NOTICE file 3 | distributed with this work for additional information 4 | regarding copyright ownership. The ASF licenses this file 5 | to you under the Apache License, Version 2.0 (the 6 | "License"); you may not use this file except in compliance 7 | with the License. You may obtain a copy of the License at 8 | 9 | .. http://www.apache.org/licenses/LICENSE-2.0 10 | 11 | .. Unless required by applicable law or agreed to in writing, 12 | software distributed under the License is distributed on an 13 | "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 14 | KIND, either express or implied. See the License for the 15 | specific language governing permissions and limitations 16 | under the License. 17 | 18 | .. _setup: 19 | 20 | ############ 21 | Setup Lorien 22 | ############ 23 | Lorien is a Python package, so you can simply install it by running the following command. 24 | 25 | .. 
code-block:: bash 26 | 27 | python3 setup.py install 28 | 29 | Note that the Lorien installation does not include dialect-specific packages (such as TVM for the TVM dialects), because every dialect is optional. When launched, Lorien detects the environment it is running in and loads the supported dialects. In other words, if you want to use the TVM dialects in Lorien, for example, you have to manually install TVM before running Lorien. You can either build TVM from source, or install the nightly build by referring to `this website <https://tlcpack.ai/>`_: 30 | 31 | .. code-block:: bash 32 | 33 | pip3 install tlcpack-nightly-cu102 -f https://tlcpack.ai/wheels 34 | 35 | If you prefer not to set up an environment on your host machine, you can also consider running Lorien inside Docker with the prebuilt Lorien Docker images. See :ref:`on-docker` for detailed steps. 36 | 37 | 38 | .. toctree:: 39 | :maxdepth: 1 40 | 41 | on_docker 42 | -------------------------------------------------------------------------------- /docs/source/setup/on_docker.rst: -------------------------------------------------------------------------------- 1 | .. Licensed to the Apache Software Foundation (ASF) under one 2 | or more contributor license agreements. See the NOTICE file 3 | distributed with this work for additional information 4 | regarding copyright ownership. The ASF licenses this file 5 | to you under the Apache License, Version 2.0 (the 6 | "License"); you may not use this file except in compliance 7 | with the License. You may obtain a copy of the License at 8 | 9 | .. http://www.apache.org/licenses/LICENSE-2.0 10 | 11 | .. Unless required by applicable law or agreed to in writing, 12 | software distributed under the License is distributed on an 13 | "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 14 | KIND, either express or implied. See the License for the 15 | specific language governing permissions and limitations 16 | under the License. 17 | 18 | ..
_on-docker: 19 | 20 | ############# 21 | Docker Images 22 | ############# 23 | 24 | We provide prebuilt Docker images that let you 1) quickly try out Lorien, 2) refer to a working environment setup, and 3) use them as AWS Batch containers. You can find them on `Docker Hub `_. We currently only provide prebuilt images for Ubuntu 18.04 on CPU and GPU platforms. The Docker images include the dependent packages and a nightly build of TVM. If you need another platform or OS and are willing to build an image for it, refer to :ref:`contribute` to file a pull request. 25 | 26 | First of all, you need to set up `docker `_ (or `nvidia-docker `_ if you want to use CUDA). In the rest of this guide, we focus on the CPU platform, but you should be able to get the GPU platform working with exactly the same steps using ``nvidia-docker``. 27 | 28 | Let's first pull the Docker image from Docker Hub. Note that you may need ``sudo`` for all docker commands in this guide if you have not configured Docker permissions for your user. 29 | 30 | .. code-block:: bash 31 | 32 | docker pull comaniac0422/lorien:ubuntu-18.04- 33 | 34 | Now the Docker image is available on your machine. We then create a container and log into it using the provided script (run it in the Lorien root directory): 35 | 36 | .. code-block:: bash 37 | 38 | cd docker; ./bash.sh comaniac0422/lorien:ubuntu-18.04- 39 | 40 | You can get the script by cloning Lorien, or by copying it from `bash.sh `_. The ``bash.sh`` script creates a container from the given Docker image and executes a command in it. If no command is provided, as in the example above, it launches a bash shell for you to interact with. Note that your AWS credentials on the host machine are not available inside a Docker container, so you have to specify them in environment variables if you want Lorien to access AWS services (e.g., S3, DynamoDB, Batch). ``bash.sh`` sets up AWS credentials in the container from the following environment variables. 41 | 42 | ..
code-block:: bash 43 | 44 | export AWS_REGION="us-west-1" 45 | export AWS_ACCESS_KEY="OOOOXXX" 46 | export AWS_SECRET_ACCESS_KEY="XXXYYY" 47 | cd docker; ./bash.sh comaniac0422/lorien:ubuntu-18.04-v0.01 48 | 49 | 50 | After you log into the container, you can clone and install Lorien (the prebuilt Docker images do not include Lorien because they are also used for CI). 51 | 52 | .. code-block:: bash 53 | 54 | git clone https://github.com/comaniac/lorien.git 55 | cd lorien; python3 setup.py install 56 | 57 | 58 | That's all. You can now use Lorien! Try out some commands from the Lorien README and have fun. 59 | 60 | ############ 61 | Docker Files 62 | ############ 63 | Check out the `docker files `_ if you want to build your own Lorien Docker images for your platform or with newer dependencies. 64 | -------------------------------------------------------------------------------- /docs/source/tutorials/dialects.rst: -------------------------------------------------------------------------------- 1 | .. Licensed to the Apache Software Foundation (ASF) under one 2 | or more contributor license agreements. See the NOTICE file 3 | distributed with this work for additional information 4 | regarding copyright ownership. The ASF licenses this file 5 | to you under the Apache License, Version 2.0 (the 6 | "License"); you may not use this file except in compliance 7 | with the License. You may obtain a copy of the License at 8 | 9 | .. http://www.apache.org/licenses/LICENSE-2.0 10 | 11 | .. Unless required by applicable law or agreed to in writing, 12 | software distributed under the License is distributed on an 13 | "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 14 | KIND, either express or implied. See the License for the 15 | specific language governing permissions and limitations 16 | under the License. 17 | 18 | ..
_dialects: 19 | 20 | ################# 21 | Add a New Dialect 22 | ################# 23 | 24 | TBA 25 | -------------------------------------------------------------------------------- /docs/source/tutorials/extract_feature.rst: -------------------------------------------------------------------------------- 1 | .. Licensed to the Apache Software Foundation (ASF) under one 2 | or more contributor license agreements. See the NOTICE file 3 | distributed with this work for additional information 4 | regarding copyright ownership. The ASF licenses this file 5 | to you under the Apache License, Version 2.0 (the 6 | "License"); you may not use this file except in compliance 7 | with the License. You may obtain a copy of the License at 8 | 9 | .. http://www.apache.org/licenses/LICENSE-2.0 10 | 11 | .. Unless required by applicable law or agreed to in writing, 12 | software distributed under the License is distributed on an 13 | "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 14 | KIND, either express or implied. See the License for the 15 | specific language governing permissions and limitations 16 | under the License. 17 | 18 | .. _extract-feature: 19 | 20 | ##################################### 21 | Feature Extraction for Model Training 22 | ##################################### 23 | 24 | TBA 25 | -------------------------------------------------------------------------------- /docs/source/tutorials/index.rst: -------------------------------------------------------------------------------- 1 | .. Licensed to the Apache Software Foundation (ASF) under one 2 | or more contributor license agreements. See the NOTICE file 3 | distributed with this work for additional information 4 | regarding copyright ownership. The ASF licenses this file 5 | to you under the Apache License, Version 2.0 (the 6 | "License"); you may not use this file except in compliance 7 | with the License. You may obtain a copy of the License at 8 | 9 | .. 
http://www.apache.org/licenses/LICENSE-2.0 10 | 11 | .. Unless required by applicable law or agreed to in writing, 12 | software distributed under the License is distributed on an 13 | "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 14 | KIND, either express or implied. See the License for the 15 | specific language governing permissions and limitations 16 | under the License. 17 | 18 | ######### 19 | Tutorials 20 | ######### 21 | 22 | This gallery contains tutorials that help you quickly get familiar with Lorien. 23 | 24 | 25 | ****** 26 | Tuning 27 | ****** 28 | 29 | .. toctree:: 30 | :maxdepth: 1 31 | 32 | tune_on_local 33 | tune_on_aws_batch 34 | tune_on_rpc 35 | 36 | ********************** 37 | Performance Cost Model 38 | ********************** 39 | 40 | .. toctree:: 41 | :maxdepth: 1 42 | 43 | extract_feature 44 | train_model 45 | 46 | ************ 47 | New Dialects 48 | ************ 49 | 50 | .. toctree:: 51 | :maxdepth: 1 52 | 53 | dialects 54 | -------------------------------------------------------------------------------- /docs/source/tutorials/train_model.rst: -------------------------------------------------------------------------------- 1 | .. Licensed to the Apache Software Foundation (ASF) under one 2 | or more contributor license agreements. See the NOTICE file 3 | distributed with this work for additional information 4 | regarding copyright ownership. The ASF licenses this file 5 | to you under the Apache License, Version 2.0 (the 6 | "License"); you may not use this file except in compliance 7 | with the License. You may obtain a copy of the License at 8 | 9 | .. http://www.apache.org/licenses/LICENSE-2.0 10 | 11 | .. Unless required by applicable law or agreed to in writing, 12 | software distributed under the License is distributed on an 13 | "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 14 | KIND, either express or implied.
See the License for the 15 | specific language governing permissions and limitations 16 | under the License. 17 | 18 | .. _train-model: 19 | 20 | ########################## 21 | Performance Model Training 22 | ########################## 23 | 24 | TBA 25 | -------------------------------------------------------------------------------- /lorien/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"). 4 | # You may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # pylint: disable=redefined-builtin, wildcard-import 15 | """The top module""" 16 | 17 | from . import dialect, generate, tune 18 | 19 | # Current version. 20 | __version__ = "0.1.dev0" 21 | -------------------------------------------------------------------------------- /lorien/__main__.py: -------------------------------------------------------------------------------- 1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"). 4 | # You may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """The package and console entry.""" 15 | 16 | from lorien.main import Main 17 | 18 | if __name__ == "__main__": 19 | Main() 20 | -------------------------------------------------------------------------------- /lorien/database/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"). 4 | # You may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """Database Module.""" 15 | 16 | from .table import create_table, delete_table, list_tables 17 | -------------------------------------------------------------------------------- /lorien/database/table.py: -------------------------------------------------------------------------------- 1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"). 4 | # You may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """ 15 | The module to interact with DynamoDB. 16 | """ 17 | from typing import Any, Dict, Generator, Optional, List 18 | 19 | import boto3 20 | 21 | from ..logger import get_logger 22 | 23 | log = get_logger("Database") 24 | 25 | 26 | def create_table(table_name: str, **db_kwargs) -> str: 27 | """Create an empty table in the DynamoDB if the table does not exist. 28 | 29 | Parameters 30 | ---------- 31 | table_name: str 32 | The table name. 33 | 34 | **db_kwargs 35 | The kwargs of boto3 client. Commonly used: "region_name='us-west-1'" 36 | or "endpoint_url=http://localhost:8000". 37 | 38 | Returns 39 | ------- 40 | arn: str 41 | The table ARN (Amazon Resource Name). 42 | """ 43 | 44 | # Check if the table exists. 45 | try: 46 | client = boto3.client("dynamodb", **db_kwargs) 47 | resp = client.describe_table(TableName=table_name) 48 | log.info("Table %s exists", table_name) 49 | return resp["Table"]["TableArn"] 50 | except Exception as err: # pylint:disable=broad-except 51 | pass 52 | 53 | # Key attributes in the table. 
54 | attrs = [ 55 | {"AttributeName": "Target", "AttributeType": "S"}, 56 | {"AttributeName": "PrimaryRangeKey", "AttributeType": "S"}, 57 | {"AttributeName": "TargetIDKeys", "AttributeType": "S"}, 58 | ] 59 | 60 | key_schema = [ 61 | {"AttributeName": "Target", "KeyType": "HASH"}, 62 | {"AttributeName": "PrimaryRangeKey", "KeyType": "RANGE"}, 63 | ] 64 | 65 | global_secondary_indexes = [ 66 | { 67 | "IndexName": "TargetIDKeysIndex", 68 | "KeySchema": [ 69 | {"AttributeName": "TargetIDKeys", "KeyType": "HASH"}, 70 | {"AttributeName": "PrimaryRangeKey", "KeyType": "RANGE"}, 71 | ], 72 | "Projection": {"ProjectionType": "ALL"}, 73 | "ProvisionedThroughput": {"ReadCapacityUnits": 100, "WriteCapacityUnits": 10}, 74 | } 75 | ] 76 | 77 | try: 78 | client = boto3.client("dynamodb", **db_kwargs) 79 | resp = client.create_table( 80 | TableName=table_name, 81 | AttributeDefinitions=attrs, 82 | KeySchema=key_schema, 83 | GlobalSecondaryIndexes=global_secondary_indexes, 84 | ProvisionedThroughput={"ReadCapacityUnits": 100, "WriteCapacityUnits": 10}, 85 | ) 86 | log.info("Table %s created successfully", table_name) 87 | return resp["TableDescription"]["TableArn"] 88 | except Exception as err: # pylint:disable=broad-except 89 | raise RuntimeError("Error creating table %s: %s" % (table_name, str(err))) 90 | 91 | 92 | def delete_table(table_name: str, **db_kwargs) -> None: 93 | """Delete the given table in the database. 94 | 95 | Parameters 96 | ---------- 97 | table_name: str 98 | The table name in string. 99 | 100 | **db_kwargs 101 | The kwargs of boto3 client. Commonly used: "region_name='us-west-1'" 102 | or "endpoint_url=http://localhost:8000". 
103 | """ 104 | client = boto3.client("dynamodb", **db_kwargs) 105 | client.delete_table(TableName=table_name) 106 | log.info("Table %s has been deleted", table_name) 107 | 108 | 109 | def check_table(table_name: str, table_arn: str, **db_kwargs) -> bool: 110 | """Check if the DynamoDB table name and ARN match the one this worker can access to. 111 | 112 | Parameters 113 | ---------- 114 | table_name: str 115 | Table name. 116 | 117 | table_arn: str 118 | Table Amazon Resource Name. 119 | 120 | **db_kwargs 121 | The kwargs of boto3 client. Commonly used: "region_name='us-west-1'" 122 | or "endpoint_url=http://localhost:8000". 123 | 124 | Returns 125 | ------- 126 | success: bool 127 | False if the table and ARN does not exist in the DynamoDB. 128 | """ 129 | try: 130 | client = boto3.client("dynamodb", **db_kwargs) 131 | resp = client.describe_table(TableName=table_name) 132 | return table_arn == resp["Table"]["TableArn"] 133 | except Exception: # pylint:disable=broad-except 134 | return False 135 | 136 | 137 | def list_tables(**db_kwargs) -> List[str]: 138 | """List all table names in the database. 139 | 140 | Parameters 141 | ---------- 142 | **db_kwargs 143 | The kwargs of boto3 client. Commonly used: "region_name='us-west-1'" 144 | or "endpoint_url=http://localhost:8000". 145 | 146 | Returns 147 | ------- 148 | tables: List[str] 149 | A list of sorted table names. 150 | """ 151 | try: 152 | client = boto3.client("dynamodb", **db_kwargs) 153 | return sorted(client.list_tables()["TableNames"], reverse=True) 154 | except Exception as err: # pylint:disable=broad-except 155 | raise RuntimeError("Failed to fetch the table list: %s" % str(err)) 156 | 157 | 158 | def scan_table(table_name: str, limit: Optional[int] = None, **db_kwargs) -> Generator: 159 | """Scan a DynamoDB table for all items. Note that DynamoDB only transfers 160 | at most 1 MB data per query, so you may need to invoke this generator several times 161 | to get the entire table. 
162 | 163 | Parameters 164 | ---------- 165 | table_name: str 166 | The target table name to be scanned. 167 | 168 | **db_kwargs 169 | The kwargs of boto3 client. For example, use "endpoint_url=http://localhost:8000" 170 | for local DynamoDB. 171 | 172 | Returns 173 | ------- 174 | gen: Generator 175 | A generator that yields a scan query response (at most 1 MB). 176 | """ 177 | scan_kwargs: Dict[str, Any] = {"TableName": table_name} 178 | if limit is not None: 179 | scan_kwargs["Limit"] = limit 180 | 181 | try: 182 | client = boto3.client("dynamodb", **db_kwargs) 183 | resp = client.scan(**scan_kwargs) 184 | yield resp 185 | while "LastEvaluatedKey" in resp: 186 | resp = client.scan(ExclusiveStartKey=resp["LastEvaluatedKey"], **scan_kwargs) 187 | yield resp 188 | except Exception as err: # pylint:disable=board-except 189 | raise RuntimeError("Failed to scan table: %s" % str(err)) 190 | -------------------------------------------------------------------------------- /lorien/database/util.py: -------------------------------------------------------------------------------- 1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"). 4 | # You may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """ 15 | The utilities of database manipulations. 
16 | """ 17 | from typing import Any, Dict, List, Tuple, Union 18 | 19 | from ..logger import get_logger 20 | 21 | log = get_logger("DB-Util") 22 | 23 | 24 | def convert_to_db_list(orig_list: Union[Tuple[Any, ...], List[Any]]) -> Dict[str, Any]: 25 | """Convert a list to the DynamoDB list type. 26 | Note: There is no tuple type in DynamoDB so we will also convert tuples to "L" type, 27 | which is also a list. 28 | 29 | Parameters 30 | ---------- 31 | orig_list: List[Any] 32 | The native list. 33 | 34 | Returns 35 | ------- 36 | new_list: Dict[str, Any] 37 | The DynamoDB list: {'L': []}. 38 | """ 39 | new_list: List[Any] = [] 40 | for elt in orig_list: 41 | if isinstance(elt, str): 42 | new_list.append({"S": elt}) 43 | elif isinstance(elt, (int, float)): 44 | new_list.append({"N": str(elt)}) 45 | elif isinstance(elt, (list, tuple)): 46 | new_list.append(convert_to_db_list(elt)) 47 | elif isinstance(elt, dict): 48 | new_list.append(convert_to_db_dict(elt)) 49 | elif elt is None: 50 | new_list.append({"S": "None"}) 51 | else: 52 | raise RuntimeError("Cannot convert %s (%s)" % (str(elt), type(elt))) 53 | 54 | return {"L": new_list} 55 | 56 | 57 | def convert_to_db_dict(orig_dict: Dict[str, Any]) -> Dict[str, Any]: 58 | """Convert a dict to the DynamoDB dict form. 59 | 60 | Parameters 61 | ---------- 62 | orig_dict: Dict[str, Any] 63 | The native dict. 64 | 65 | Returns 66 | ------- 67 | new_dict: Dict[str, Any] 68 | The DynamoDB dict: {'M': {}}. 
69 | """ 70 | new_dict: Dict[str, Any] = {} 71 | for key, val in orig_dict.items(): 72 | if isinstance(val, str): 73 | new_dict[key] = {"S": val} 74 | elif isinstance(val, (int, float)): 75 | new_dict[key] = {"N": str(val)} 76 | elif isinstance(val, (list, tuple)): 77 | new_dict[key] = convert_to_db_list(val) 78 | elif isinstance(val, dict): 79 | new_dict[key] = convert_to_db_dict(val) 80 | elif val is None: 81 | new_dict[key] = {"S": "None"} 82 | else: 83 | raise RuntimeError("Cnanot convert %s (%s)" % (str(val), type(val))) 84 | 85 | return {"M": new_dict} 86 | 87 | 88 | def convert_to_list(db_list: Dict[str, Any]) -> List[Any]: 89 | """Convert a DynamoDB list to a native list. 90 | 91 | Parameters 92 | ---------- 93 | db_list: Dict[str, Any] 94 | A DynamoDB list: {'L': []}. 95 | 96 | Returns 97 | ------- 98 | new_list: List[Any] 99 | A native list. 100 | """ 101 | if "L" not in db_list: 102 | raise RuntimeError("Not a DynamoDB list: %s" % (str(db_list))) 103 | 104 | new_list: List[Any] = [] 105 | for elt in db_list["L"]: 106 | assert len(elt) == 1 107 | dtype = list(elt.keys())[0] 108 | 109 | if dtype == "S": 110 | new_list.append(str(elt[dtype]) if elt[dtype] != "None" else None) 111 | elif dtype == "N": 112 | new_list.append(float(elt[dtype])) 113 | elif dtype == "L": 114 | new_list.append(convert_to_list(elt)) 115 | elif dtype == "M": 116 | new_list.append(convert_to_dict(elt)) 117 | else: 118 | raise RuntimeError("Cannot convert %s (%s)" % (str(elt), dtype)) 119 | 120 | return new_list 121 | 122 | 123 | def convert_to_dict(db_dict: Dict[str, Any]) -> Dict[str, Any]: 124 | """Convert a DynamoDB dict to a native dict. 125 | 126 | Parameters 127 | ---------- 128 | db_dict: Dict[str, Any] 129 | A DynamoDB dict: {'M': {}}. 130 | 131 | Returns 132 | ------- 133 | new_dict: Dict[str, Any] 134 | A native dict. 
135 | """ 136 | if "M" not in db_dict: 137 | raise RuntimeError("Not a DynamoDB dict: %s" % str(db_dict)) 138 | 139 | new_dict: Dict[str, Any] = {} 140 | for key, elt in db_dict["M"].items(): 141 | dtype = list(elt.keys())[0] 142 | 143 | if dtype == "S": 144 | new_dict[key] = str(elt[dtype]) if elt[dtype] != "None" else None 145 | elif dtype == "N": 146 | new_dict[key] = float(elt[dtype]) 147 | elif dtype == "L": 148 | new_dict[key] = convert_to_list(elt) 149 | elif dtype == "M": 150 | new_dict[key] = convert_to_dict(elt) 151 | else: 152 | raise RuntimeError("Cannot convert %s (%s)" % (str(elt), dtype)) 153 | 154 | return new_dict 155 | -------------------------------------------------------------------------------- /lorien/dialect/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"). 4 | # You may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """ 15 | Load Dialect Modules 16 | """ 17 | from typing import Set 18 | import importlib 19 | 20 | AVAILABLE_DIALECTS: Set[str] = set() 21 | 22 | if importlib.util.find_spec("tvm") is not None: 23 | import tvm 24 | 25 | from . import tvm_dial 26 | from .tvm_dial.util import check_tvm_version 27 | 28 | # Minimum supported version of TVM. 29 | TVM_MIN_VERSION = "0.8.dev0" 30 | 31 | # Check if the TVM version is supported. 
32 | TVM_VERSION = tvm.__version__ 33 | if not check_tvm_version(TVM_VERSION, TVM_MIN_VERSION): 34 | raise RuntimeError("Unsatisfied TVM version (>= %s): %s" % (TVM_MIN_VERSION, TVM_VERSION)) 35 | 36 | AVAILABLE_DIALECTS.add("tvm") 37 | AVAILABLE_DIALECTS.add("autotvm") 38 | AVAILABLE_DIALECTS.add("auto_scheduler") 39 | -------------------------------------------------------------------------------- /lorien/dialect/tvm_dial/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"). 4 | # You may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """ 15 | TVM dialect Modules 16 | """ 17 | 18 | from . import auto_scheduler_dial, autotvm_dial 19 | -------------------------------------------------------------------------------- /lorien/dialect/tvm_dial/auto_scheduler_dial/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"). 4 | # You may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """ 15 | TVM auto_scheduler dialects. 16 | """ 17 | from . import extract 18 | from .job import AutoSchedulerJob 19 | from .result import AutoSchedulerTuneResult 20 | from .workload import AutoSchedulerWorkload 21 | -------------------------------------------------------------------------------- /lorien/dialect/tvm_dial/auto_scheduler_dial/extract.py: -------------------------------------------------------------------------------- 1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"). 4 | # You may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """ 15 | The module of TVM auto_scheduler workload extraction from DNN models. 
16 | """ 17 | import argparse 18 | from typing import Set 19 | 20 | import tqdm 21 | from tvm import auto_scheduler 22 | 23 | from ....configs import create_config_parser, register_config_parser 24 | from ....generate import gen 25 | from ....logger import get_logger 26 | from ..frontend_parser import EXTRACTOR_FUNC_TABLE 27 | from .workload import AutoSchedulerWorkload 28 | 29 | log = get_logger("Extract") 30 | 31 | 32 | def extract_from_models(configs: argparse.Namespace): 33 | """Extract op workloads from a given model. 34 | 35 | Parameters 36 | ---------- 37 | configs: argparse.Namespace 38 | The system configure of generate.extract-from-model. 39 | 40 | Returns 41 | ------- 42 | workloads: List[Workload] 43 | A list of collected workloads. 44 | """ 45 | 46 | # Extract operator worklaods from models. 47 | workloads: Set[AutoSchedulerWorkload] = set() 48 | for framework in ["gcv", "keras", "onnx", "torch", "tf", "tflite", "mxnet"]: 49 | if not getattr(configs, framework): 50 | continue 51 | mod_n_params = EXTRACTOR_FUNC_TABLE[framework](configs) 52 | 53 | # Extract workloads from models. 
54 | progress = tqdm.tqdm( 55 | total=len(mod_n_params), desc="", bar_format="{desc} {percentage:3.0f}%|{bar:50}{r_bar}" 56 | ) 57 | for name, (mod, params) in mod_n_params: 58 | for target in configs.target: 59 | progress.set_description_str(name, refresh=True) 60 | tasks, _ = auto_scheduler.extract_tasks( 61 | mod, 62 | target=target, 63 | params=params, 64 | include_simple_tasks=configs.include_simple_tasks, 65 | ) 66 | 67 | # Task to workload 68 | for task in tasks: 69 | try: 70 | workloads.add(AutoSchedulerWorkload.from_task(task)) 71 | except RuntimeError as err: 72 | log.warning( 73 | "Failed to create workload from task %s: %s", str(task), str(err) 74 | ) 75 | continue 76 | progress.update(1) 77 | 78 | log.info("%d operator workloads have been generated", len(workloads)) 79 | return list(workloads) 80 | 81 | 82 | @register_config_parser("top.generate.auto_scheduler") 83 | def define_config_extract() -> argparse.ArgumentParser: 84 | """Define the command line interface for workload generation by model extraction. 85 | 86 | Returns 87 | ------- 88 | parser: argparse.ArgumentParser 89 | The defined argument parser. 90 | """ 91 | parser = create_config_parser("AutoSchedulerWorkload Generation by Model Extraction") 92 | 93 | common_desc = ( 94 | "A {0} with input shape in YAML format: " 95 | '": {{: []}}". 
When shape is ignored, ' 96 | 'the default input {{"{1}": ({2})}} will be applied' 97 | ) 98 | 99 | parser.add_argument( 100 | "--gcv", 101 | action="append", 102 | default=[], 103 | required=False, 104 | help=common_desc.format("Gluon CV model name", "data", "1, 3, 224, 224"), 105 | ) 106 | parser.add_argument( 107 | "--keras", 108 | action="append", 109 | default=[], 110 | required=False, 111 | help=common_desc.format("Keras model file path", "input_1", "1, 3, 224, 224"), 112 | ) 113 | parser.add_argument( 114 | "--onnx", 115 | action="append", 116 | default=[], 117 | required=False, 118 | help=common_desc.format("ONNX model file path", "input", "1, 3, 224, 224"), 119 | ) 120 | parser.add_argument( 121 | "--torch", 122 | action="append", 123 | default=[], 124 | required=False, 125 | help=common_desc.format("PyTorch model file path", "input", "1, 3, 224, 224"), 126 | ) 127 | parser.add_argument( 128 | "--tf", 129 | action="append", 130 | default=[], 131 | required=False, 132 | help=common_desc.format("TensorFlow model file path", "Placeholder", "1, 224, 224, 3"), 133 | ) 134 | parser.add_argument( 135 | "--tflite", 136 | action="append", 137 | default=[], 138 | required=False, 139 | help=common_desc.format("TFLite model file path", "Placeholder", "1, 224, 224, 3"), 140 | ) 141 | parser.add_argument( 142 | "--mxnet", 143 | action="append", 144 | default=[], 145 | required=False, 146 | help=common_desc.format("MXNet model file path", "data", "1, 3, 224, 224"), 147 | ) 148 | parser.add_argument( 149 | "--target", 150 | action="append", 151 | default=[], 152 | required=True, 153 | help="A TVM target (e.g., llvm, cuda, etc). 
" 154 | "Note that the device tag (e.g., -model=v100) is not required.", 155 | ) 156 | parser.add_argument( 157 | "-o", "--output", default="auto_scheduler_workloads.yaml", help="The output file path" 158 | ) 159 | parser.add_argument( 160 | "--include-simple-tasks", 161 | default=False, 162 | action="store_true", 163 | help="Whether to extract simple tasks (without complex operators)", 164 | ) 165 | parser.set_defaults(entry=gen(extract_from_models), validate_task=False) 166 | return parser 167 | -------------------------------------------------------------------------------- /lorien/dialect/tvm_dial/auto_scheduler_dial/result.py: -------------------------------------------------------------------------------- 1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"). 4 | # You may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """ 15 | Tuning result of auto_scheduler dialect. 
16 | """ 17 | import heapq 18 | from typing import Any, Dict, List, Optional, Sequence, Tuple 19 | 20 | import numpy as np 21 | from tvm.auto_scheduler.measure import MeasureErrorNo, MeasureInput, MeasureResult 22 | from tvm.auto_scheduler.measure_record import dump_record_to_string, load_record_from_string 23 | 24 | from ....logger import get_logger 25 | from ....tune.result import TuneResult 26 | from ....workload import Workload 27 | from ..result import TVMRecords 28 | from ..util import gen_target_id_keys, get_canonical_tvm_target_str 29 | from .workload import AutoSchedulerWorkload 30 | 31 | log = get_logger("AutoScheduler-Result") 32 | 33 | 34 | class AutoSchedulerRecords(TVMRecords[Tuple[MeasureInput, MeasureResult]]): 35 | """The container to maintain the records of an auto_scheduler task.""" 36 | 37 | def __init__(self, target_key: str, workload_key: Optional[str] = None): 38 | """Initialize records. 39 | 40 | Parameters 41 | ---------- 42 | target_key: str 43 | The target key (partition key) of the workload. 44 | 45 | workload_key: Optional[str] 46 | The workload key (sort key) of the workload. If not presneted, 47 | then all records with the same target_key will be fetched when querying. 
48 | """ 49 | target_key = get_canonical_tvm_target_str(target_key, remove_libs=True) 50 | alter_key = gen_target_id_keys(target_key) 51 | super(AutoSchedulerRecords, self).__init__(target_key, alter_key, workload_key) 52 | self._data: List[Tuple[float, float, Tuple[MeasureInput, MeasureResult]]] = [] 53 | 54 | @staticmethod 55 | def encode(record: Tuple[MeasureInput, MeasureResult]) -> str: 56 | """Encode a record to a string.""" 57 | return dump_record_to_string(*record) 58 | 59 | @staticmethod 60 | def decode(record_str: str) -> Tuple[MeasureInput, MeasureResult]: 61 | """Decode a string to a record.""" 62 | return load_record_from_string(record_str) 63 | 64 | def gen_task_item(self) -> Dict[str, Any]: 65 | """No additional attribute is required for auto_scheduler.""" 66 | return {} 67 | 68 | @staticmethod 69 | def gen_record_item(record: Tuple[MeasureInput, MeasureResult]): 70 | """Generate an item for a record that can be appended to the task item.""" 71 | return {"latency": np.mean([v.value for v in record[1].costs])} 72 | 73 | def push(self, record: Tuple[MeasureInput, MeasureResult]): 74 | """Push a new record. 75 | 76 | Parameters 77 | ---------- 78 | record: Any 79 | The record to be pushed. 80 | """ 81 | # Push with -cost as heapq is min-heap as we want the worst record on the top. 82 | heapq.heappush( 83 | self._data, (-np.mean([v.value for v in record[1].costs]), record[1].timestamp, record) 84 | ) 85 | 86 | def pop(self) -> Tuple[MeasureInput, MeasureResult]: 87 | """Pop the worst record in the container and remove the cost.""" 88 | return heapq.heappop(self._data)[2] 89 | 90 | def peak(self) -> Tuple[MeasureInput, MeasureResult]: 91 | """Peak the first record.""" 92 | assert self._data 93 | return self._data[0][2] 94 | 95 | def to_list(self, nbest: int = -1) -> List[Tuple[MeasureInput, MeasureResult]]: 96 | """Sort the record (of any layout) to be a list and return the best N. 
97 | 98 | Parameters 99 | ---------- 100 | nbest: int 101 | The best N records to be returned. Default to return all. 102 | 103 | Returns 104 | ------- 105 | records: List[Tuple[MeasureInput, MeasureResult]] 106 | The sorted list of records. 107 | """ 108 | nbest = nbest if nbest != -1 else len(self._data) 109 | return [item[2] for item in heapq.nsmallest(nbest, self._data)] 110 | 111 | def __len__(self) -> int: 112 | return len(self._data) 113 | 114 | 115 | class AutoSchedulerTuneResult(TuneResult): 116 | """The result of a tuning job.""" 117 | 118 | @staticmethod 119 | def create_records_by_workloads( 120 | log_file_path: str, nbest: int, workload: Optional[Workload] = None 121 | ) -> Sequence[AutoSchedulerRecords]: 122 | """Parse records from the tuning log and group them by workloads. 123 | 124 | Parameters 125 | ---------- 126 | log_file_path: str 127 | The log file path. 128 | 129 | nbest: int 130 | The maximum number of best records to be kept. 131 | 132 | workload: Optional[Workload] 133 | The target workload. If presented, the returnede map will only have one entry. 134 | 135 | Returns 136 | ------- 137 | records: Sequence[AutoSchedulerRecords] 138 | A list of created records. 
139 | """ 140 | target_workload_key = workload.get_workload_key() if workload is not None else None 141 | 142 | best_records: Dict[str, AutoSchedulerRecords] = {} 143 | with open(log_file_path, "r") as filep: 144 | for line in filep: 145 | if line[0] == "#" or line[0] == " " or line[0] == "\n": 146 | continue 147 | 148 | inp, res = load_record_from_string(line) 149 | workload_key = AutoSchedulerWorkload.from_task(inp.task).get_workload_key() 150 | if target_workload_key is not None and workload_key != target_workload_key: 151 | continue 152 | 153 | if workload_key not in best_records: 154 | best_records[workload_key] = AutoSchedulerRecords(inp.task.target, workload_key) 155 | curr_records = best_records[workload_key] 156 | 157 | if res.error_no != MeasureErrorNo.NO_ERROR: 158 | continue 159 | 160 | curr_records.push((inp, res)) 161 | if len(curr_records) > nbest: 162 | curr_records.pop() 163 | return list(best_records.values()) 164 | 165 | @staticmethod 166 | def gen_features(log_file_path: str, out_path: str): 167 | """Featurize tuning logs to be input features of the performance cost model. 168 | 169 | Parameters 170 | ---------- 171 | log_file_path: str 172 | The log file path. It can be a file path for a single file, or a directory of 173 | several log files. 174 | 175 | out_path: str 176 | The path to write generated features and artifacts. 177 | """ 178 | raise RuntimeError("Feature extraction is not supported yet in AutoScheduler dialect") 179 | -------------------------------------------------------------------------------- /lorien/dialect/tvm_dial/auto_scheduler_dial/workload.py: -------------------------------------------------------------------------------- 1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"). 4 | # You may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """ 15 | auto_scheduler Workload Definition. 16 | """ 17 | import pickle 18 | from typing import Any, Sequence 19 | 20 | from ruamel.yaml import YAML, yaml_object 21 | from tvm.auto_scheduler.search_task import SearchTask 22 | 23 | from ....tune.job import Job 24 | from ....workload import Workload 25 | from ..util import get_canonical_tvm_target_str 26 | 27 | 28 | @yaml_object(YAML()) 29 | class AutoSchedulerWorkload(Workload): 30 | """The workload for an op. 31 | A workload can be used to created an AutoScheduler task for tuning. 32 | """ 33 | 34 | def __init__(self): 35 | super(AutoSchedulerWorkload, self).__init__() 36 | self.workload_key: str = "" 37 | self.task_pickle: bytes = b"" 38 | self.dag_repr: str = "" 39 | 40 | @classmethod 41 | def from_task(cls, task: SearchTask) -> "AutoSchedulerWorkload": 42 | """Create a workload from an AutoScheduler task. 43 | 44 | Parameters 45 | ---------- 46 | task: SearchTask 47 | The AutoScheduler task for the workload. 48 | 49 | Returns 50 | ------- 51 | workload: Workload 52 | The initialized workload. 53 | """ 54 | 55 | workload = cls() 56 | 57 | assert task.target is not None 58 | workload.workload_key = str(task.workload_key) 59 | workload.target = get_canonical_tvm_target_str(task.target, task) 60 | workload.task_pickle = pickle.dumps(task) 61 | workload.dag_repr = repr(task.compute_dag) 62 | return workload 63 | 64 | def to_task(self) -> SearchTask: 65 | """Create an AutoScheduler task from this workload. 
66 | 67 | Returns 68 | ------- 69 | task: SearchTask 70 | Return the created task, or raise RuntimeError if failed. 71 | """ 72 | # Try to create task. 73 | try: 74 | task = pickle.loads(self.task_pickle) 75 | except Exception as err: # pylint: disable=broad-except 76 | raise RuntimeError( 77 | "Failed to create the task for workload {0}: {1}".format(str(self), str(err)) 78 | ) 79 | 80 | return task 81 | 82 | def to_job(self) -> Job: 83 | """Create a job to tune this workload. 84 | 85 | Returns 86 | ------- 87 | job: Job 88 | The created job. 89 | """ 90 | from .job import AutoSchedulerJob # Avoid circular import dependency. 91 | 92 | return AutoSchedulerJob(self) 93 | 94 | def get_workload_key(self) -> str: 95 | """Get the primary key of this workload in DB. 96 | 97 | Returns 98 | ------- 99 | key: str 100 | The primary key of this workload to index the records in DB. 101 | """ 102 | return self.workload_key 103 | 104 | def mutate(self, rules: Any) -> Sequence["Workload"]: 105 | """auto_scheduler task mutation is not supported yet. 106 | 107 | Parameters 108 | ---------- 109 | workload: Workload 110 | The workload to be mutated. 111 | 112 | rules: Any 113 | The mutation rules that can be customized. 114 | 115 | Returns 116 | ------- 117 | workloads: Sequence[Workload] 118 | The mutated workloads. 
119 | """ 120 | raise NotImplementedError 121 | 122 | def __lt__(self, other: object) -> bool: 123 | assert isinstance(other, AutoSchedulerWorkload) 124 | for key in ["workload_key", "target"]: 125 | if getattr(self, key) != getattr(other, key): 126 | return getattr(self, key) < getattr(other, key) 127 | return False 128 | 129 | def __str__(self) -> str: 130 | return repr(self) 131 | 132 | def __repr__(self): 133 | return "%s(workload_key=%r,target=%r,dag=%s)" % ( 134 | self.__class__.__name__, 135 | self.workload_key, 136 | self.target, 137 | self.dag_repr, 138 | ) 139 | -------------------------------------------------------------------------------- /lorien/dialect/tvm_dial/autotvm_dial/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"). 4 | # You may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """ 15 | AutoTVM dialects. 16 | """ 17 | import argparse 18 | 19 | from ....configs import create_config_parser, register_config_parser 20 | 21 | from . import extract_from_model, extract_from_record 22 | from .job import AutoTVMJob 23 | from .result import AutoTVMTuneResult 24 | from .workload import AutoTVMWorkload 25 | 26 | 27 | @register_config_parser("top.generate.autotvm") 28 | def define_config() -> argparse.ArgumentParser: 29 | """Define the command line interface for AutoTVM workload generation. 
30 | 31 | Returns 32 | ------- 33 | parser: argparse.ArgumentParser 34 | The defined argument parser. 35 | """ 36 | parser = create_config_parser("AutoTVM Workload Generation") 37 | 38 | # Define generators 39 | subparsers = parser.add_subparsers( 40 | dest="mode", description="The mode to generate AutoTVM workloads" 41 | ) 42 | subparsers.required = True 43 | return parser 44 | -------------------------------------------------------------------------------- /lorien/dialect/tvm_dial/autotvm_dial/extract_from_model.py: -------------------------------------------------------------------------------- 1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"). 4 | # You may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """ 15 | The module of AutoTVM workload extraction from DNN models. 16 | """ 17 | import argparse 18 | from typing import Set 19 | 20 | import tqdm 21 | from tvm import autotvm 22 | 23 | from ....configs import create_config_parser, register_config_parser 24 | from ....generate import gen 25 | from ....logger import get_logger 26 | from ..frontend_parser import EXTRACTOR_FUNC_TABLE 27 | from .workload import AutoTVMWorkload 28 | 29 | log = get_logger("Extract") 30 | 31 | 32 | def extract_from_models(configs: argparse.Namespace): 33 | """Extract op workloads from a given model. 34 | 35 | Parameters 36 | ---------- 37 | configs: argparse.Namespace 38 | The system configure of generate.extract-from-model. 
39 | 40 | Returns 41 | ------- 42 | workloads: List[Workload] 43 | A list of collected workloads. 44 | """ 45 | 46 | # Extract operator worklaods from models. 47 | workloads: Set[AutoTVMWorkload] = set() 48 | for framework in ["gcv", "keras", "onnx", "torch", "tf", "tflite", "mxnet"]: 49 | if not getattr(configs, framework): 50 | continue 51 | mod_n_params = EXTRACTOR_FUNC_TABLE[framework](configs) 52 | 53 | # Extract workloads from models. 54 | progress = tqdm.tqdm( 55 | total=len(mod_n_params), desc="", bar_format="{desc} {percentage:3.0f}%|{bar:50}{r_bar}" 56 | ) 57 | for name, (mod, params) in mod_n_params: 58 | for target in configs.target: 59 | progress.set_description_str(name, refresh=True) 60 | tasks = autotvm.task.extract_from_program(mod, target=target, params=params) 61 | 62 | # Task to workload 63 | for task in tasks: 64 | try: 65 | workloads.add(AutoTVMWorkload.from_task(task)) 66 | except RuntimeError as err: 67 | log.warning( 68 | "Failed to create workload from task %s: %s", str(task), str(err) 69 | ) 70 | continue 71 | progress.update(1) 72 | 73 | log.info("%d operator workloads have been generated", len(workloads)) 74 | return list(workloads) 75 | 76 | 77 | @register_config_parser("top.generate.autotvm.extract-from-model") 78 | def define_config_extract() -> argparse.ArgumentParser: 79 | """Define the command line interface for workload generation by model extraction. 80 | 81 | Returns 82 | ------- 83 | parser: argparse.ArgumentParser 84 | The defined argument parser. 85 | """ 86 | parser = create_config_parser("Workload Generation by Model Extraction") 87 | 88 | common_desc = ( 89 | "A {0} with input shape in YAML format: " 90 | '": {{: []}}". 
When shape is ignored, ' 91 | 'the default input {{"{1}": ({2})}} will be applied' 92 | ) 93 | 94 | parser.add_argument( 95 | "--gcv", 96 | action="append", 97 | default=[], 98 | required=False, 99 | help=common_desc.format("Gluon CV model name", "data", "1, 3, 224, 224"), 100 | ) 101 | parser.add_argument( 102 | "--keras", 103 | action="append", 104 | default=[], 105 | required=False, 106 | help=common_desc.format("Keras model file path", "input_1", "1, 3, 224, 224"), 107 | ) 108 | parser.add_argument( 109 | "--onnx", 110 | action="append", 111 | default=[], 112 | required=False, 113 | help=common_desc.format("ONNX model file path", "input", "1, 3, 224, 224"), 114 | ) 115 | parser.add_argument( 116 | "--torch", 117 | action="append", 118 | default=[], 119 | required=False, 120 | help=common_desc.format("PyTorch model file path", "input", "1, 3, 224, 224"), 121 | ) 122 | parser.add_argument( 123 | "--tf", 124 | action="append", 125 | default=[], 126 | required=False, 127 | help=common_desc.format("TensorFlow model file path", "Placeholder", "1, 224, 224, 3"), 128 | ) 129 | parser.add_argument( 130 | "--tflite", 131 | action="append", 132 | default=[], 133 | required=False, 134 | help=common_desc.format("TFLite model file path", "Placeholder", "1, 224, 224, 3"), 135 | ) 136 | parser.add_argument( 137 | "--mxnet", 138 | action="append", 139 | default=[], 140 | required=False, 141 | help=common_desc.format("MXNet model file path", "data", "1, 3, 224, 224"), 142 | ) 143 | parser.add_argument( 144 | "--target", 145 | action="append", 146 | default=[], 147 | required=True, 148 | help="A TVM target (e.g., llvm, cuda, etc). 
" 149 | "Note that the device tag (e.g., -model=v100) is not required.", 150 | ) 151 | parser.add_argument( 152 | "-o", "--output", default="autotvm_workloads.yaml", help="The output file path" 153 | ) 154 | parser.set_defaults(entry=gen(extract_from_models), validate_task=False) 155 | return parser 156 | -------------------------------------------------------------------------------- /lorien/dialect/tvm_dial/autotvm_dial/extract_from_record.py: -------------------------------------------------------------------------------- 1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"). 4 | # You may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """ 15 | The module of AutoTVM workload extraction from tuning records. The graph-level optimization 16 | workloads such as data layout transform fit in this scenario. 
17 | """ 18 | import argparse 19 | from typing import Optional, Tuple 20 | 21 | from tvm import autotvm, te 22 | from tvm.autotvm import graph_tuner # pylint: disable=unused-import 23 | from tvm.autotvm.record import MeasureInput, MeasureResult 24 | from tvm.autotvm.task import Task 25 | 26 | from ....configs import create_config_parser, register_config_parser 27 | from ....generate import gen 28 | from ....logger import get_logger 29 | from ....util import load_from_yaml 30 | from .result import AutoTVMRecords 31 | from .util import infer_task_layout 32 | from .workload import AutoTVMWorkload 33 | 34 | log = get_logger("Extract") 35 | 36 | 37 | def create_layout_transform_task(record: Tuple[MeasureInput, MeasureResult]) -> Optional[Task]: 38 | """Create an AutoTVM task of layout transform. 39 | 40 | Parameters 41 | ---------- 42 | record: Tuple[MeasureInput, MeasureResult] 43 | The AutoTVM record pair to create the layout transform task. 44 | 45 | Returns 46 | ------- 47 | task: Optional[Task] 48 | The created layout_transform task, or None if the record has no layout transform support. 49 | """ 50 | layout = infer_task_layout(record) 51 | if layout is None: 52 | return None 53 | 54 | try: 55 | in_shape, in_layout = layout[0][0] 56 | _, out_layout = layout[1][0] 57 | if in_layout == out_layout: # No need to transform layout. 58 | return None 59 | data = te.placeholder(in_shape, name="data", dtype=record[0].task.args[0][2]) 60 | args = (data, in_layout, out_layout) 61 | task = autotvm.task.create("layout_transform", args=args, target=record[0].target) 62 | return task 63 | except Exception as err: # pylint:disable=broad-except 64 | log.warning("Failed to create layout transfrom task from %s: %s", args, str(err)) 65 | return None 66 | 67 | 68 | def extract_from_records(configs: argparse.Namespace): 69 | """Extract graph optimization workloads from a given DB table. 
70 | 71 | Parameters 72 | ---------- 73 | configs: argparse.Namespace 74 | The system configure of generate.extract-from-record. 75 | 76 | Returns 77 | ------- 78 | workloads: List[AutoTVMWorkload] 79 | A list of collected workloads. 80 | """ 81 | db_options = load_from_yaml(configs.db) 82 | 83 | tasks = [] 84 | for target in configs.target: 85 | records = AutoTVMRecords(target) 86 | records.query(configs.table_name, configs.ignore_target_attrs, **db_options) 87 | for record in records.to_list(): 88 | task = create_layout_transform_task(record) 89 | if task is not None: 90 | tasks.append(task) 91 | 92 | workloads = list({AutoTVMWorkload.from_task(t) for t in tasks}) 93 | log.info("%d layout transform workloads for %s have been generated", len(workloads), target) 94 | 95 | return workloads 96 | 97 | 98 | @register_config_parser("top.generate.autotvm.extract-from-record") 99 | def define_config_extract() -> argparse.ArgumentParser: 100 | """Define the command line interface for workload generation by record extraction. 101 | 102 | Returns 103 | ------- 104 | parser: argparse.ArgumentParser 105 | The defined argument parser. 106 | """ 107 | parser = create_config_parser("AutoTVMWorkload Generation by Tuning Record Extraction") 108 | parser.add_argument("--table-name", required=True, help="The DynamoDB table name") 109 | parser.add_argument("--db", default="{ }", help="DynamoDB client options in YAML format") 110 | parser.add_argument( 111 | "--ignore-target-attrs", 112 | default=False, 113 | action="store_true", 114 | help="Only use target ID (e.g., llvm) and keys (e.g., cpu) " 115 | "instead of the full target string when querying records.", 116 | ) 117 | parser.add_argument( 118 | "--target", 119 | action="append", 120 | default=[], 121 | required=True, 122 | help="A TVM target (e.g., llvm, cuda, etc). 
" 123 | "Note that the device tag (e.g., -model=v100) is not required.", 124 | ) 125 | parser.add_argument( 126 | "-o", "--output", default="autotvm_workloads.yaml", help="The output file path" 127 | ) 128 | parser.set_defaults(entry=gen(extract_from_records), validate_task=False) 129 | return parser 130 | -------------------------------------------------------------------------------- /lorien/dialect/tvm_dial/autotvm_dial/util.py: -------------------------------------------------------------------------------- 1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"). 4 | # You may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """ 15 | Utility functions for AutoTVM dialects 16 | """ 17 | from typing import Callable, Optional, Tuple 18 | 19 | from tvm.autotvm.graph_tuner.base_graph_tuner import get_infer_layout 20 | from tvm.autotvm.measure import MeasureInput, MeasureResult 21 | 22 | 23 | def infer_task_layout(record: Tuple[MeasureInput, MeasureResult]) -> Optional[Tuple]: 24 | """Infer the layout of the given task. Return None if the layout cannot be inferred. 25 | 26 | Parameters 27 | ---------- 28 | record: Tuple[MeasureInput, MeasureResult] 29 | The AutoTVM record pair. 30 | 31 | Return 32 | ------ 33 | layout: Optional[Tuple] 34 | A tuple of input and output layout, or None if not inferrable. 
35 | """ 36 | infer_layout_func: Optional[Callable] = None 37 | try: 38 | infer_layout_func = get_infer_layout(record[0].task.name) 39 | assert infer_layout_func is not None 40 | with record[0].target: 41 | return infer_layout_func(record[0].task.workload, record[0].config) 42 | except ValueError: 43 | pass 44 | return None 45 | -------------------------------------------------------------------------------- /lorien/dialect/tvm_dial/job.py: -------------------------------------------------------------------------------- 1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"). 4 | # You may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """ 15 | TVM Job definition module. 16 | """ 17 | import argparse 18 | from typing import Any, Dict, List, Optional, Tuple 19 | 20 | from ...configs import append_config_parser 21 | from ...logger import get_logger 22 | from ...tune.job import Job, JobConfigs 23 | from .util import get_canonical_tvm_target_str, get_tvm_build_config, is_cover 24 | 25 | log = get_logger("TVMTuneJob") 26 | 27 | 28 | class TuneMetadata: 29 | """Metadata for a tuning process.""" 30 | 31 | def __init__(self): 32 | self.max_thrpt = 0 # Maximum throughput (GFLOP/s). 33 | self.trial_count = 0 34 | self.failed_count = 0 # The number of trails with error number != 0. 
35 | 36 | 37 | class TVMJobConfigs(JobConfigs): 38 | """AutoTVM job configurations.""" 39 | 40 | def __init__(self, configs: argparse.Namespace): 41 | """Initialize a job configuration. 42 | 43 | Parameters 44 | ---------- 45 | configs: argparse.Namespace 46 | The system configuration of tuner. 47 | """ 48 | super(TVMJobConfigs, self).__init__(configs) 49 | self.tvm_build_config: Dict[str, str] = get_tvm_build_config() 50 | 51 | def check_tvm_build_config(self) -> bool: 52 | """Check if the TVM build config on this machine matches the expected one. 53 | 54 | Returns 55 | ------- 56 | match: bool 57 | Return True if matcheing. 58 | """ 59 | if not self.tvm_build_config: 60 | # Always match if job configs do not specify the expectation. 61 | return True 62 | 63 | this_config = get_tvm_build_config() 64 | for key, val in this_config.items(): 65 | if key not in self.tvm_build_config or self.tvm_build_config[key] != val: 66 | log.warning( 67 | "TVM build config mismatch: expected %s, but here is %s", 68 | str(self.tvm_build_config), 69 | str(this_config), 70 | ) 71 | return False 72 | return True 73 | 74 | def localize(self, target: str, **kwargs): 75 | """Localize options on worker. 76 | 77 | Parameters 78 | ---------- 79 | target: str 80 | The target string. 81 | 82 | **kwargs 83 | The kwargs of job configuration for updating. 84 | """ 85 | raise NotImplementedError 86 | 87 | 88 | class TVMJob(Job): 89 | """A tuning job including a workload as well as tuning related configurations.""" 90 | 91 | def is_target_compatible(self, target: str) -> bool: 92 | """Check if the taret is compatible to this job. 93 | 94 | Parameters 95 | ---------- 96 | target: str 97 | The target string 98 | 99 | Returns 100 | ------- 101 | compatible: bool 102 | Whether the target is compatible to this job. 
103 | """ 104 | this_target = get_canonical_tvm_target_str(self.workload.target) 105 | that_target = get_canonical_tvm_target_str(target) 106 | return is_cover(this_target, that_target) 107 | 108 | @staticmethod 109 | def create_job_configs(configs: argparse.Namespace) -> JobConfigs: 110 | """Create a JobConfigs. See `JobConfigs`. 111 | 112 | Parameters 113 | ---------- 114 | configs: argparse.Namespace 115 | The system configuration of tuner. 116 | 117 | Returns 118 | ------- 119 | job_configs: JobConfigs 120 | The job configurations. 121 | """ 122 | raise NotImplementedError 123 | 124 | def tune( 125 | self, 126 | tune_options: Dict[str, Any], 127 | measure_options: Dict[str, Any], 128 | commit_options: Optional[Dict[str, Any]] = None, 129 | ): 130 | """Tune the job with the given configuration and update the result. 131 | If the commit options are provided, then this function also in charge of 132 | committing the tuning results, or the job manager will commit the result, otherwise. 133 | """ 134 | raise NotImplementedError 135 | 136 | 137 | @append_config_parser("top.tune", "TVM tuning options") 138 | def append_tune_config() -> List[Tuple[List[str], Dict[str, Any]]]: 139 | """Define the command line interface for TVM tuning. 140 | 141 | Returns 142 | ------- 143 | actions: List[Tuple[List[str], Dict[str, Any]]] 144 | The AutoTVM tuning configs. 
145 | """ 146 | return [ 147 | ( 148 | ["-n", "--ntrial"], 149 | {"default": 3000, "type": int, "help": "Number of tuning trials for each workload"}, 150 | ), 151 | (["--test"], {"default": 5, "type": int, "help": "Number of tests in one measurement"}), 152 | ( 153 | ["--repeat"], 154 | {"default": 1, "type": int, "help": "Number of measurements for one config"}, 155 | ), 156 | (["--min"], {"default": 1000, "type": int, "help": "Minimum repeat time (ms)"}), 157 | ] 158 | 159 | 160 | @append_config_parser("top.rpc-client", "TVM options for RPC") 161 | def append_rpc_config() -> List[Tuple[List[str], Dict[str, Any]]]: 162 | """Define the command line interface for TVM RPC. 163 | 164 | Returns 165 | ------- 166 | actions: List[Tuple[List[str], Dict[str, Any]]] 167 | The AutoTVM tuning configs. 168 | """ 169 | return [ 170 | (["--device"], {"type": str, "help": "Device name owned by this host"}), 171 | (["--runner-port"], {"type": int, "help": "The port to the TVM runner"}), 172 | ] 173 | -------------------------------------------------------------------------------- /lorien/dialect/tvm_dial/result.py: -------------------------------------------------------------------------------- 1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"). 4 | # You may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """ 15 | Tuning records of TVM dialect. 
16 | """ 17 | from typing import Any, Dict, List, Optional, Tuple, TypeVar 18 | 19 | from ...tune.result import Records 20 | from .util import TVM_BUILD_CONFIG 21 | 22 | TVMRecordType = TypeVar("TVMRecordType", bound=Tuple) 23 | 24 | 25 | class TVMRecords(Records[TVMRecordType]): 26 | """The container to maintain the records of a tuning task.""" 27 | 28 | def get_framework_build_config(self) -> Optional[Dict[str, str]]: 29 | """Get the framework build configurations that generate these records. 30 | If None, then the committed records will not have this information. 31 | """ 32 | return TVM_BUILD_CONFIG 33 | 34 | @staticmethod 35 | def encode(record: TVMRecordType) -> str: 36 | """Encode a record to a string.""" 37 | raise NotImplementedError 38 | 39 | @staticmethod 40 | def decode(record_str: str) -> TVMRecordType: 41 | """Decode a string to a record.""" 42 | raise NotImplementedError 43 | 44 | def gen_task_item(self) -> Dict[str, Any]: 45 | """Generate an item that can be committed to the database. Note that since all records 46 | in this container should be for the same task, they should be in the same task item. 47 | """ 48 | raise NotImplementedError 49 | 50 | @staticmethod 51 | def gen_record_item(record: TVMRecordType): 52 | """Generate an item for a record that can be appended to the task item.""" 53 | raise NotImplementedError 54 | 55 | def push(self, record: TVMRecordType): 56 | """Push a new record. 57 | 58 | Parameters 59 | ---------- 60 | record: Any 61 | The record to be pushed. 62 | """ 63 | raise NotImplementedError 64 | 65 | def pop(self) -> TVMRecordType: 66 | """Pop the worst record in the container.""" 67 | raise NotImplementedError 68 | 69 | def peak(self) -> TVMRecordType: 70 | """Peak the first record.""" 71 | raise NotImplementedError 72 | 73 | def to_list(self, nbest: int = -1) -> List[TVMRecordType]: 74 | """Sort the record (of any layout) to be a list and return the best N. 
75 | 76 | Parameters 77 | ---------- 78 | nbest: int 79 | The best N records to be returned. Default to return all. 80 | 81 | Returns 82 | ------- 83 | records: List[RecordType] 84 | The sorted list of records. 85 | """ 86 | raise NotImplementedError 87 | 88 | def __len__(self) -> int: 89 | raise NotImplementedError 90 | -------------------------------------------------------------------------------- /lorien/generate.py: -------------------------------------------------------------------------------- 1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"). 4 | # You may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """ 15 | Workload generator. 16 | """ 17 | import argparse 18 | from typing import Callable, List 19 | 20 | from .configs import create_config_parser, register_config_parser 21 | from .logger import get_logger 22 | from .util import dump_to_yaml 23 | from .workload import Workload 24 | 25 | log = get_logger("Generator") 26 | 27 | 28 | def gen(gen_func: Callable[[argparse.Namespace], List[Workload]]) -> Callable: 29 | """Generate workloads based on configs. 30 | 31 | Parameters 32 | ---------- 33 | gen_func: Callable[[argparse.Namespace], List[WorkloadBase]] 34 | The generator function that accepts generator specific configs and returns 35 | a list of generated workloads. 36 | 37 | Returns 38 | ------- 39 | ret: Callable 40 | The entry function that uses the given generation function to generate workloads. 
41 | """ 42 | 43 | def _do_gen(configs: argparse.Namespace): 44 | """Invoke the generator function to get the workload list, and validate if the workload 45 | is a valid for AutoTVM. 46 | 47 | Parameters 48 | ---------- 49 | configs: argparse.Namespace 50 | The configuration of the generator. 51 | """ 52 | # Collect workloads 53 | log.info("Generating workloads...") 54 | workloads: List[Workload] = gen_func(configs) 55 | 56 | # Dump each workload to a string with the last "\n" removed, and 57 | # aggregate all dumped workloads to a single dict to match tuning config. 58 | with open(configs.output, "w") as workload_file: 59 | dumped = dump_to_yaml( 60 | {"workload": [dump_to_yaml(w)[:-1] for w in workloads]}, single_line=False 61 | ) 62 | assert dumped is not None 63 | workload_file.write(dumped) 64 | 65 | return _do_gen 66 | 67 | 68 | @register_config_parser("top.generate") 69 | def define_config() -> argparse.ArgumentParser: 70 | """Define the command line interface for workload generation. 71 | 72 | Returns 73 | ------- 74 | parser: argparse.ArgumentParser 75 | The defined argument parser. 76 | """ 77 | parser = create_config_parser("Workload Generation") 78 | 79 | # Define generators 80 | subparsers = parser.add_subparsers(dest="dialect", description="The generator dialect") 81 | subparsers.required = True 82 | return parser 83 | -------------------------------------------------------------------------------- /lorien/logger.py: -------------------------------------------------------------------------------- 1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"). 4 | # You may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """ 15 | The format and config of logging. 16 | """ 17 | 18 | import logging 19 | from typing import Callable, Dict 20 | 21 | from .util import get_time_str 22 | 23 | LOGGER_TABLE: Dict[str, logging.Logger] = {} 24 | 25 | FORMATTER = logging.Formatter( 26 | "[%(asctime)s] %(levelname)7s %(name)s: %(message)s", "%Y-%m-%d %H:%M:%S" 27 | ) 28 | STREAM_HANDLER = logging.StreamHandler() 29 | STREAM_HANDLER.setFormatter(FORMATTER) 30 | 31 | 32 | def get_logger(name: str) -> logging.Logger: 33 | """Attach to the default logger.""" 34 | 35 | if name in LOGGER_TABLE: 36 | return LOGGER_TABLE[name] 37 | 38 | logger = logging.getLogger(name) 39 | logger.setLevel(logging.INFO) 40 | logger.addHandler(STREAM_HANDLER) 41 | 42 | LOGGER_TABLE[name] = logger 43 | return logger 44 | 45 | 46 | def enable_log_file(): 47 | """Add file handler to all loggers.""" 48 | 49 | file_handler = logging.FileHandler("run-{}.log".format(get_time_str())) 50 | file_handler.setFormatter(FORMATTER) 51 | 52 | for logger in LOGGER_TABLE.values(): 53 | logger.addHandler(file_handler) 54 | 55 | 56 | def disable_stream_handler(func: Callable): 57 | """Disable stream (console) handler when running a function.""" 58 | 59 | def _wrapper(*args, **kwargs): 60 | for logger in LOGGER_TABLE.values(): 61 | logger.removeHandler(STREAM_HANDLER) 62 | ret = func(*args, **kwargs) 63 | for logger in LOGGER_TABLE.values(): 64 | logger.addHandler(STREAM_HANDLER) 65 | return ret 66 | 67 | return _wrapper 68 | 
-------------------------------------------------------------------------------- /lorien/main.py: -------------------------------------------------------------------------------- 1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"). 4 | # You may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """ 15 | The main flow that integrates all modules 16 | """ 17 | import sys 18 | from .configs import create_config_parser, make_config_parser, register_config_parser 19 | from .logger import enable_log_file 20 | 21 | 22 | @register_config_parser("top") 23 | def define_config(): 24 | """Define the command line interface for the main entry. 25 | 26 | Returns 27 | ------- 28 | parser: argparse.ArgumentParser 29 | The defined argument parser. 
30 | """ 31 | 32 | parser = create_config_parser("Lorien: TVM Optimized Schedule Database", prog="lorien") 33 | parser.add_argument( 34 | "--log-run", action="store_true", default=False, help="Log execution logs to a file" 35 | ) 36 | subparsers = parser.add_subparsers(dest="command", help="The command being executed") 37 | subparsers.required = True 38 | return parser 39 | 40 | 41 | class Main: 42 | """The main entry.""" 43 | 44 | def __init__(self): 45 | args = make_config_parser(sys.argv[1:]) 46 | if args.log_run: 47 | enable_log_file() 48 | args.entry(args) 49 | -------------------------------------------------------------------------------- /lorien/tune/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"). 4 | # You may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """Tune the given workload.""" 15 | 16 | from .master import run 17 | -------------------------------------------------------------------------------- /lorien/tune/rpc/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"). 4 | # You may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """The RPC for targets other than EC2.""" 15 | 16 | from .client import RPCClient 17 | from .launch import launch_client, launch_server 18 | -------------------------------------------------------------------------------- /lorien/tune/rpc/client.py: -------------------------------------------------------------------------------- 1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"). 4 | # You may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """ 15 | RPC client. 16 | """ 17 | import argparse 18 | from typing import List, Optional, Tuple 19 | 20 | import rpyc 21 | 22 | from ...database.table import check_table 23 | from ...logger import get_logger 24 | from ...util import load_from_yaml 25 | from ..job import JobConfigs 26 | 27 | log = get_logger("RPCClient") 28 | 29 | 30 | class RPCClient: 31 | """The RPC client.""" 32 | 33 | def __init__(self, configs: argparse.Namespace, silent=False): 34 | """Parse configs to initialize a client. 
35 | 36 | Parameters 37 | ---------- 38 | configs: argparse.Namespace 39 | The system configuration of RPC tuner. 40 | 41 | silent: bool 42 | If true, then all messages will be disabled. 43 | """ 44 | # Parse server string 45 | server_str = configs.server 46 | if server_str.find(":") == -1: 47 | raise RuntimeError("Missing port") 48 | 49 | try: 50 | port = int(server_str[server_str.find(":") + 1 :]) 51 | except ValueError: 52 | raise RuntimeError("Invalid port: %s" % server_str[server_str.find(":") + 1 :]) 53 | server_name = server_str[: server_str.find(":")] 54 | 55 | # Connect to the server 56 | try: 57 | conn = rpyc.connect( 58 | server_name, 59 | port, 60 | config={ 61 | "allow_public_attrs": True, 62 | "allow_pickle": True, 63 | "sync_request_timeout": None, 64 | }, 65 | ) 66 | if not silent: 67 | log.info("%s connected", server_str) 68 | except Exception as err: # pylint: disable=broad-except 69 | raise RuntimeError("Failed to connect to %s: %s" % (server_str, str(err))) 70 | 71 | self.target = configs.target 72 | self.job_configs: Optional[JobConfigs] = None 73 | self.socket_port = str(conn._channel.stream.sock.getsockname()[1]) 74 | self.conn = conn 75 | self.token = "" 76 | 77 | def init_worker(self, configs: argparse.Namespace): 78 | """Initialize the worker with tuning options. 79 | 80 | Parameters 81 | ---------- 82 | configs: argparse.Namespace 83 | The system configure for RPC server. 84 | """ 85 | 86 | job_configs_str = self.conn.root.get_job_configs_str(self.token) 87 | self.job_configs = load_from_yaml(job_configs_str) 88 | if self.job_configs is None: 89 | raise RuntimeError( 90 | "Job configuration has not been initialized on the server " 91 | "or this work is not registered yet" 92 | ) 93 | 94 | # Check if the AWS credential on this worker can access the DynamoDB table, and remove 95 | # commit options if failed. 
96 | assert self.job_configs.commit_options is not None 97 | if not check_table( 98 | self.job_configs.commit_options["table-name"], 99 | self.job_configs.commit_options["table-arn"], 100 | **self.job_configs.commit_options["db"] 101 | ): 102 | log.warning("AWS credential is invalid. Will let the master commit results") 103 | self.job_configs.commit_options = None 104 | 105 | self.job_configs.localize(self.target, configs=configs) 106 | 107 | def init_server(self, job_configs: JobConfigs): 108 | """Initialize the server options. The client that initializes the server becomes 109 | the root client, which is authorized to submit jobs and fetch results. 110 | 111 | Parameters 112 | ---------- 113 | configs: argparse.Namespace 114 | The system configurations including tune and measure options. 115 | 116 | job_configs: JobConfigs 117 | The job configurations. 118 | """ 119 | self.token = self.conn.root.init(self.socket_port, job_configs) 120 | 121 | def submit(self, job_str: str) -> bool: 122 | """Submit a serialized job to the server. Only the client with root permission 123 | can submit jobs. 124 | 125 | Parameters 126 | ---------- 127 | job_str: str 128 | The serialized job to be submitted. 129 | 130 | Returns 131 | ------- 132 | success: bool 133 | True if the job is submitted successfully; False otherwise. 134 | """ 135 | return self.conn.root.submit(self.token, job_str) 136 | 137 | def fetch_results(self) -> List[Tuple[str, str]]: 138 | """Fetch tuned results from the server. Only the client with root permission is allowed 139 | to fetch results. 140 | 141 | Returns 142 | ------- 143 | results: List[Tuple[str, str]] 144 | A list of serialized (job, result) pairs. 145 | """ 146 | return self.conn.root.fetch_results(self.token) 147 | 148 | def is_server_init(self) -> bool: 149 | """Check if the server is initialized. 150 | 151 | Returns 152 | ------- 153 | init: bool 154 | True if all options are ready; False otherwise. 
155 | """ 156 | return self.conn.root.is_init() 157 | 158 | def num_workers(self) -> int: 159 | """Get the number of live workers. 160 | 161 | Returns 162 | ------- 163 | n_workers: int 164 | The number of live workers. 165 | """ 166 | return self.conn.root.num_workers() 167 | 168 | def register_as_worker(self) -> Tuple[bool, str]: 169 | """Register client self as a tuning worker. 170 | 171 | Returns 172 | ------- 173 | token_or_msg: Tuple[bool, str] 174 | A tuple of (success, token or error message). 175 | """ 176 | ret = self.conn.root.register_worker(self.socket_port, self.target) 177 | self.token = ret[1] if ret[0] else None 178 | return ret 179 | 180 | def request_job(self) -> Optional[str]: 181 | """Request a job from the server. The result will be stored in cached_result 182 | so this function will not return anything. 183 | """ 184 | job_str = self.conn.root.request_job(self.token) 185 | if not job_str: 186 | return None 187 | 188 | return job_str 189 | 190 | def send_result(self, job_n_result: Tuple[str, str]) -> str: 191 | """Send the serailized job and result back to the server. 192 | 193 | Parameters 194 | ---------- 195 | job_n_result: Tuple[str, str] 196 | A string pair of job and result. 197 | 198 | Returns 199 | ------- 200 | msg: str 201 | The error message. 202 | """ 203 | return self.conn.root.send_result(self.token, job_n_result) 204 | -------------------------------------------------------------------------------- /lorien/tune/rpc/launch.py: -------------------------------------------------------------------------------- 1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"). 4 | # You may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """ 15 | The module to launch RPC client or server. 16 | """ 17 | import argparse 18 | import threading 19 | import signal 20 | import time 21 | from typing import Optional 22 | 23 | from rpyc.utils.server import ThreadedServer 24 | 25 | from ...configs import create_config_parser, register_config_parser 26 | from ...logger import get_logger 27 | from ...util import dump_to_yaml, load_from_yaml 28 | from ..job import Job 29 | from ..result import TuneErrorCode, TuneResult 30 | from .client import RPCClient 31 | from .server import RPCService 32 | 33 | log = get_logger("RPC") 34 | 35 | 36 | def launch_server(port: int, target: str) -> None: 37 | """Launch a RPC server. 38 | 39 | Parameters 40 | ---------- 41 | port: int 42 | The port for launching the server. 43 | 44 | target: str 45 | The target string for this server. 
46 | """ 47 | s = ThreadedServer( 48 | RPCService(target), 49 | port=port, 50 | protocol_config={"allow_public_attrs": True, "allow_pickle": True}, 51 | ) 52 | log.info("Launching RPC server at port %d", port) 53 | 54 | try: 55 | s.start() 56 | except Exception as err: # pylint: disable=broad-except 57 | log.info("RPC server at port %d throws exceptions: %s", port, str(err)) 58 | 59 | log.info("RPC server at port %d is shutdown", port) 60 | 61 | 62 | class ClientThread(threading.Thread): 63 | """The custom thread to run the client loop.""" 64 | 65 | def __init__(self, configs): 66 | super(ClientThread, self).__init__(name="ClientThread", daemon=True) 67 | self.configs = configs 68 | self._stop_event = threading.Event() 69 | self._sleep_period = 1.0 70 | 71 | def run(self): 72 | log.info("Connecting to server %s", self.configs.server) 73 | while True: 74 | try: 75 | client = RPCClient(self.configs) 76 | break 77 | except Exception as err: # pylint: disable=broad-except 78 | log.warning("Failed to connect: %s. 
Reconnectiong", str(err)) 79 | time.sleep(1) 80 | 81 | success, msg = client.register_as_worker() 82 | if not success: 83 | raise RuntimeError("Failed to register as a worker: %s" % msg) 84 | log.info("Register token %s", client.token) 85 | 86 | client.init_worker(self.configs) 87 | assert client.job_configs is not None 88 | 89 | while not self._stop_event.isSet(): 90 | log.info("Requesting a job for tuning") 91 | try: 92 | job_str = client.request_job() 93 | except EOFError: 94 | log.info("Lost server connection") 95 | break 96 | 97 | if not job_str: 98 | log.info("Server job queue empty") 99 | time.sleep(1) 100 | continue 101 | 102 | log.info("Start tuning") 103 | job: Optional[Job] = None 104 | result = TuneResult() 105 | try: 106 | job = load_from_yaml(job_str) 107 | except RuntimeError as err: 108 | msg = "Failed to create a job {0} from string: {1}".format(job_str, str(err)) 109 | log.warning(msg) 110 | result.error_code = TuneErrorCode.FAIL_TO_LOAD_WORKLOAD 111 | result.error_msgs.append(msg) 112 | 113 | if job is not None: 114 | job.tune( 115 | client.job_configs.tune_options, 116 | client.job_configs.measure_options, 117 | client.job_configs.commit_options, 118 | ) 119 | result = job.result 120 | 121 | # Send the result back to the server. 122 | log.info("Result: %s", str(result)) 123 | try: 124 | msg = client.send_result((job_str, dump_to_yaml(result))) 125 | if msg: 126 | log.error(msg) 127 | except EOFError: 128 | log.info("Lost server connection") 129 | break 130 | 131 | self._stop_event.wait(self._sleep_period) 132 | 133 | def join(self, timeout=None): 134 | self._stop_event.set() 135 | threading.Thread.join(self, timeout) 136 | 137 | 138 | def launch_client(configs: argparse.Namespace) -> None: 139 | """Launch a RPC client. 140 | 141 | Parameters 142 | ---------- 143 | configs: argparse.Namespace 144 | The system configure for RPC server. 
145 | """ 146 | 147 | client_thread = ClientThread(configs) 148 | client_thread.start() 149 | 150 | def signal_handler(sig, fname): # pylint: disable=unused-argument 151 | log.info("Ctrl+C pressed") 152 | client_thread.join() 153 | 154 | signal.signal(signal.SIGINT, signal_handler) 155 | running = threading.Event() 156 | running.wait() 157 | 158 | 159 | @register_config_parser("top.rpc-client") 160 | def define_config() -> argparse.ArgumentParser: 161 | """Define the command line interface for RPC client. 162 | 163 | Returns 164 | ------- 165 | parser: argparse.ArgumentParser 166 | The defined argument parser. 167 | """ 168 | parser = create_config_parser("Launch RPC client on this machine and connect to the server") 169 | parser.add_argument( 170 | "--server", type=str, required=True, help="RPC Server IP and port (e.g., 0.0.0.0:18871)" 171 | ) 172 | parser.add_argument( 173 | "--target", type=str, required=True, help="The target string for this client" 174 | ) 175 | parser.set_defaults(entry=launch_client) 176 | return parser 177 | -------------------------------------------------------------------------------- /lorien/workload.py: -------------------------------------------------------------------------------- 1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"). 4 | # You may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """ 15 | Workload Definition Module. 
16 | """ 17 | import hashlib 18 | import uuid 19 | from typing import Any, Sequence 20 | 21 | from ruamel.yaml import YAML, yaml_object 22 | 23 | 24 | @yaml_object(YAML()) 25 | class Workload: 26 | """The workload base class that can be used to create a tuning task.""" 27 | 28 | def __init__(self): 29 | self.target = "unknown" 30 | 31 | @classmethod 32 | def from_task(cls, task: Any) -> "Workload": 33 | """Create a workload from a tuning task. 34 | 35 | Parameters 36 | ---------- 37 | task: Any 38 | The tuning task for the workload. 39 | 40 | Returns 41 | ------- 42 | workload: Workload 43 | The initialized workload. 44 | """ 45 | raise NotImplementedError 46 | 47 | def to_task(self) -> Any: 48 | """Create a tuning task from this workload. 49 | 50 | Returns 51 | ------- 52 | task: Any 53 | Return the created task, or raise RuntimeError if failed. 54 | """ 55 | raise NotImplementedError 56 | 57 | def to_job(self): 58 | """Create a job to tune this workload. 59 | 60 | Returns 61 | ------- 62 | job: Job 63 | The created job. 64 | """ 65 | raise NotImplementedError 66 | 67 | def mutate(self, rules: Any) -> Sequence["Workload"]: 68 | """Mutate workload arguments with the given rules. 69 | 70 | Parameters 71 | ---------- 72 | workload: Workload 73 | The workload to be mutated. 74 | 75 | rules: Any 76 | The mutation rules that can be customized. 77 | 78 | Returns 79 | ------- 80 | workloads: Sequence[Workload] 81 | The mutated workloads. 82 | """ 83 | raise NotImplementedError 84 | 85 | def __lt__(self, other: object) -> bool: 86 | raise NotImplementedError 87 | 88 | def __repr__(self): 89 | raise NotImplementedError 90 | 91 | def hash_sha2(self) -> str: 92 | """Hash this workload with SHA256 algorithm to be a unique 64-byte string. 93 | 94 | Returns 95 | ------- 96 | code: str 97 | A 64-byte string. 
98 | """ 99 | sha_obj = hashlib.sha256(str(self).encode("utf-8")) 100 | return sha_obj.hexdigest() 101 | 102 | def get_log_file_name(self) -> str: 103 | """Log file name is encoded as -.json 104 | 105 | Parameters 106 | ---------- 107 | workload: Workload 108 | The target workload. 109 | 110 | Returns 111 | ------- 112 | log_file_name: str 113 | The generated log file name. 114 | """ 115 | return "{0}-{1}.json".format(self.hash_sha2(), str(uuid.uuid4())[:5]) 116 | 117 | def get_workload_key(self) -> str: 118 | """Get the primary key of this workload in DB. 119 | 120 | Returns 121 | ------- 122 | key: str 123 | The primary key of this workload to index the records in DB. 124 | """ 125 | return self.hash_sha2() 126 | 127 | def __hash__(self) -> int: 128 | return hash(str(self)) 129 | 130 | def __eq__(self, other: object) -> bool: 131 | return type(self) is type(other) and str(self) == str(other) 132 | 133 | def __str__(self) -> str: 134 | return repr(self) 135 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | setuptools 2 | future 3 | tqdm > 4.40 4 | argparse 5 | rpyc 6 | boto3 7 | filelock 8 | ruamel.yaml >= 0.16.12 9 | awscli >= 1.18.140 10 | 11 | -------------------------------------------------------------------------------- /scripts/aws/aws_batch_env.json: -------------------------------------------------------------------------------- 1 | { 2 | "computeEnvironmentName": "lorien-c5-env", 3 | "type": "MANAGED", 4 | "state": "ENABLED", 5 | "computeResources": { 6 | "type": "EC2", 7 | "minvCpus": 0, 8 | "maxvCpus": 32, 9 | "desiredvCpus": 0, 10 | "instanceTypes": [ 11 | "c5.2xlarge" 12 | ], 13 | "subnets": [ 14 | "subnet-xxxxxxxx", 15 | "subnet-yyyyyyyy" 16 | ], 17 | "securityGroupIds": [ 18 | "sg-zzzzzzz" 19 | ], 20 | "ec2KeyPair": "", 21 | "instanceRole": "ecsInstanceRole", 22 | "launchTemplate": { 23 | "launchTemplateName": 
"batch-template-for-lorien", 24 | "version": "$Default" 25 | } 26 | }, 27 | "serviceRole": "arn:aws:iam:::role/service-role/AWSBatchServiceRole" 28 | } 29 | -------------------------------------------------------------------------------- /scripts/aws/create_launch_template.sh: -------------------------------------------------------------------------------- 1 | # The purpose of creating an AWS batch launch template is to configure 2 | # the instance storage and docker image size limit (default 10G). 3 | # See https://aws.amazon.com/premiumsupport/knowledge-center/batch-job-failure-disk-space/ 4 | 5 | REGION=us-west-2 6 | 7 | python3 gen_launch_template.py 8 | aws ec2 --region ${REGION} create-launch-template --cli-input-json file://aws_batch_launch_template.json 9 | rm aws_batch_launch_template.json 10 | -------------------------------------------------------------------------------- /scripts/aws/gen_launch_template.py: -------------------------------------------------------------------------------- 1 | import base64 2 | import json 3 | 4 | template = { 5 | "LaunchTemplateName": "batch-template-for-lorien", 6 | "LaunchTemplateData": { 7 | "BlockDeviceMappings": [ 8 | { 9 | "Ebs": {"DeleteOnTermination": True, "VolumeSize": 100, "VolumeType": "gp2"}, 10 | "DeviceName": "/dev/xvda", 11 | }, 12 | { 13 | "Ebs": {"DeleteOnTermination": True, "VolumeSize": 100, "VolumeType": "gp2"}, 14 | "DeviceName": "/dev/xvdcz", 15 | }, 16 | ], 17 | }, 18 | } 19 | 20 | # Be careful about the first empty line. 
21 | user_data = """Content-Type: multipart/mixed; boundary="==BOUNDARY==" 22 | MIME-Version: 1.0 23 | 24 | --==BOUNDARY== 25 | Content-Type: text/cloud-boothook; charset="us-ascii" 26 | #cloud-boothook 27 | #!/bin/bash 28 | cloud-init-per once docker_options echo 'OPTIONS="${OPTIONS} --storage-opt dm.basesize=40G"' >> /etc/sysconfig/docker 29 | 30 | --==BOUNDARY== 31 | """ 32 | user_data_bytes = user_data.encode("ascii") 33 | 34 | template["LaunchTemplateData"]["UserData"] = base64.b64encode(user_data_bytes).decode("ascii") 35 | with open("aws_batch_launch_template.json", "w") as fp: 36 | json.dump(template, fp, indent=2) 37 | -------------------------------------------------------------------------------- /scripts/python/download_db.py: -------------------------------------------------------------------------------- 1 | """ 2 | A script to download the entire DynamoDB table 3 | """ 4 | # pylint: disable=invalid-name 5 | 6 | import argparse 7 | import boto3 8 | from lorien import database 9 | import sys, time 10 | 11 | 12 | def dump_items(fp, items): 13 | num_written = 0 14 | for item in items: 15 | config_list = database.query.convert_to_list(item["BestConfigs"]) 16 | best_result = max(config_list, key=lambda r: r["thrpt"]) 17 | if best_result["thrpt"] == -1: 18 | continue 19 | fp.write("{}\n".format(best_result["config"])) 20 | num_written += 1 21 | return num_written 22 | 23 | 24 | def list_table(args): 25 | print(database.list_tables("")) 26 | 27 | 28 | def download_table(args): 29 | db = boto3.client("dynamodb") 30 | table_name = args.table_name 31 | limit_of_items = args.limit_of_items 32 | try: 33 | table_info = db.describe_table(TableName=table_name) 34 | except Exception as err: 35 | raise RuntimeError("Fail to get information of table {}: {}".format(table_name, err)) 36 | 37 | if limit_of_items != 0 and table_info["Table"]["ItemCount"] > limit_of_items: 38 | prompt_txt = ( 39 | "The number of items ({}) in table {} " 40 | "exceeds the number of items 
limit {}\n" 41 | "Download it anyway? (Y/N) ".format( 42 | table_info["Table"]["ItemCount"], table_name, limit_of_items 43 | ) 44 | ) 45 | txt = input(prompt_txt) 46 | if txt.strip().lower() != "y" and txt.strip().lower() != "yes": 47 | return 48 | 49 | output_file_name = "{}.log".format(table_name) 50 | with open(output_file_name, "w") as fp: 51 | num_iter = 0 52 | num_written = 0 53 | start_key = None 54 | query_options = { 55 | "TableName": table_name, 56 | "ProjectionExpression": "BestConfigs", 57 | "Limit": 100, 58 | } 59 | 60 | while num_iter == 0 or start_key is not None: 61 | try: 62 | if start_key is None: 63 | ret = db.scan(**query_options) 64 | else: 65 | ret = db.scan(**query_options, ExclusiveStartKey=start_key) 66 | except Exception as err: 67 | raise RuntimeError("Fail to download table {}: {}".format(table_name, err)) 68 | 69 | items = ret["Items"] 70 | num_written += dump_items(fp, items) 71 | start_key = ret["LastEvaluatedKey"] if "LastEvaluatedKey" in ret else None 72 | num_iter += 1 73 | 74 | print( 75 | "Finished downloading table {}, {} records written to {}".format( 76 | table_name, num_written, output_file_name 77 | ) 78 | ) 79 | 80 | 81 | def main(): 82 | parser = argparse.ArgumentParser( 83 | description="Download the best configs from a Lorien-format DynamoDB table" 84 | ) 85 | subparser = parser.add_subparsers(dest="options") 86 | subparser.required = True 87 | 88 | list_parser = subparser.add_parser("list_table", help="List all DynamoDB tables") 89 | list_parser.set_defaults(entry_func=list_table) 90 | 91 | download_parser = subparser.add_parser( 92 | "download_table", 93 | help="Download the best configs from a Lorien-format DynamoDB table", 94 | formatter_class=argparse.ArgumentDefaultsHelpFormatter, 95 | ) 96 | download_parser.add_argument( 97 | "--table_name", "-t", help="The name of the Lorien-format DynamoDB table", required=True 98 | ) 99 | download_parser.add_argument( 100 | "--limit_of_items", 101 | "-l", 102 | type=int, 103 | 
help="The number of items limit of the DynamoDB table that can be downloaded (0 = unlimited)", 104 | default=10000, 105 | ) 106 | download_parser.set_defaults(entry_func=download_table) 107 | 108 | args = parser.parse_args() 109 | args.entry_func(args) 110 | 111 | 112 | if __name__ == "__main__": 113 | main() 114 | -------------------------------------------------------------------------------- /scripts/python/merge_workload.py: -------------------------------------------------------------------------------- 1 | """ 2 | This script merges workloads of multiple files. 3 | """ 4 | 5 | import argparse 6 | 7 | from lorien.util import dump_to_yaml, load_from_yaml 8 | from lorien.workload import Workload 9 | 10 | 11 | def create_config(): 12 | """Create the config parser.""" 13 | parser = argparse.ArgumentParser(description="Merge Workloads") 14 | parser.add_argument("-f", "--file", nargs="+", required=True, help="The workload files") 15 | parser.add_argument( 16 | "-o", "--output", default="merged_workloads.yaml", help="Output workload file" 17 | ) 18 | 19 | return parser.parse_args() 20 | 21 | 22 | def main(): 23 | """Main function.""" 24 | configs = create_config() 25 | workloads = set() 26 | 27 | for file_path in configs.file: 28 | with open(file_path, "r") as filep: 29 | row_workloads = load_from_yaml(filep.read())["workload"] 30 | workloads.update([load_from_yaml(row_workload, Workload) for row_workload in row_workloads]) 31 | 32 | out = {"workload": [dump_to_yaml(workload) for workload in workloads]} 33 | out_str = dump_to_yaml(out, single_line=False) 34 | with open(configs.output, "w") as filep: 35 | filep.write(out_str) 36 | 37 | 38 | if __name__ == "__main__": 39 | main() 40 | -------------------------------------------------------------------------------- /scripts/python/sort_workloads.py: -------------------------------------------------------------------------------- 1 | """ 2 | This script sorts workloads by their shapes and attributes. 
3 | """ 4 | 5 | import argparse 6 | 7 | from lorien.workload import Workload 8 | from lorien.util import dump_to_yaml, load_from_yaml 9 | 10 | 11 | def create_config(): 12 | """Create the config parser.""" 13 | parser = argparse.ArgumentParser(description="Sort Workloads") 14 | parser.add_argument("file", help="The workload file") 15 | parser.add_argument("-i", action="store_true", help="In-place sorting") 16 | 17 | return parser.parse_args() 18 | 19 | 20 | def main(): 21 | """Main function.""" 22 | configs = create_config() 23 | with open(configs.file, "r") as filep: 24 | row_workloads = load_from_yaml(filep.read())["workload"] 25 | 26 | workloads = [load_from_yaml(row_workload, Workload) for row_workload in row_workloads] 27 | sorted_workloads = sorted(workloads) 28 | out = {"workload": [dump_to_yaml(workload) for workload in sorted_workloads]} 29 | out_str = dump_to_yaml(out, single_line=False) 30 | 31 | if configs.i: 32 | with open(configs.file, "w") as filep: 33 | filep.write(out_str) 34 | filep.write("\n") 35 | else: 36 | print(out_str) 37 | 38 | 39 | if __name__ == "__main__": 40 | main() 41 | -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | [metadata] 2 | description-file = README.md -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | """Package Setup""" 2 | import os 3 | import re 4 | from distutils.core import setup 5 | 6 | from setuptools import find_packages 7 | 8 | CURRENT_DIR = os.path.dirname(__file__) 9 | 10 | 11 | def read(path): 12 | with open(path, "r") as filep: 13 | return filep.read() 14 | 15 | 16 | def get_version(package_name): 17 | with open(os.path.join(os.path.dirname(__file__), package_name, "__init__.py")) as fp: 18 | for line in fp: 19 | tokens = 
re.search(r'^\s*__version__\s*=\s*"(.+)"\s*$', line) 20 | if tokens: 21 | return tokens.group(1) 22 | raise RuntimeError("Unable to find own __version__ string") 23 | 24 | 25 | setup( 26 | name="lorien", 27 | version=get_version("lorien"), 28 | license="Apache-2.0", 29 | description="A Unified Infrastructure for Efficient Deep Learning Workloads Delivery", 30 | long_description=read(os.path.join(CURRENT_DIR, "README.md")), 31 | long_description_content_type="text/markdown", 32 | author="Lorien Community", 33 | url="https://github.com/awslabs/lorien", 34 | keywords=[], 35 | packages=find_packages(), 36 | install_requires=[p for p in read("requirements.txt").split("\n") if p], 37 | classifiers=[ 38 | "Development Status :: 4 - Beta", 39 | "Intended Audience :: Developers", 40 | "Topic :: Software Development :: Build Tools", 41 | "License :: OSI Approved :: Apache Software License", 42 | "Programming Language :: Python :: 3", 43 | "Programming Language :: Python :: 3.6", 44 | ], 45 | ) 46 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/awslabs/lorien/4ebaeb0d50fde66cb27b8dde538f962f60c8918e/tests/__init__.py -------------------------------------------------------------------------------- /tests/common.py: -------------------------------------------------------------------------------- 1 | """ 2 | The common utilities for unit tests. 
3 | """ 4 | # pylint:disable=missing-docstring, redefined-outer-name, invalid-name 5 | import socket 6 | 7 | import boto3 8 | import pytest 9 | from moto import mock_dynamodb2, mock_s3 10 | 11 | from lorien.database.table import create_table 12 | from lorien.tune.job import Job, JobConfigs 13 | from lorien.tune.result import TuneResult 14 | from lorien.workload import Workload 15 | 16 | 17 | class LorienTestJobConfig(JobConfigs): 18 | pass 19 | 20 | 21 | class LorienTestJob(Job): 22 | # pylint: disable=abstract-method 23 | 24 | @staticmethod 25 | def create_job_configs(configs): 26 | return LorienTestJobConfig(configs) 27 | 28 | def is_target_compatible(self, target): 29 | return self.workload.target == target 30 | 31 | def tune(self, tune_options, measure_options, commit_options=None): 32 | result = TuneResult() 33 | result.metadata["tune_logs"] = "tuning logs" 34 | result.commit = lambda options, workload, silent: None 35 | 36 | 37 | class LorienTestWorkload(Workload): 38 | # pylint: disable=abstract-method 39 | def __init__(self, target, idx): 40 | super(LorienTestWorkload, self).__init__() 41 | self.target = target 42 | self.idx = idx 43 | self.dummy_data = "" 44 | 45 | def to_job(self): 46 | return LorienTestJob(self) 47 | 48 | def __repr__(self): 49 | return "LorienTestWorkload(%s, %d)" % (self.target, self.idx) 50 | 51 | 52 | def gen_workloads(lower_idx, upper_idx, target="llvm"): 53 | """Generate a number of dummy workloads.""" 54 | return [LorienTestWorkload(target, idx) for idx in range(lower_idx, upper_idx)] 55 | 56 | 57 | def gen_jobs(lower_idx, upper_idx, target="llvm"): 58 | """Generate a number of dummy jobs.""" 59 | return [LorienTestWorkload(target, idx).to_job() for idx in range(lower_idx, upper_idx)] 60 | 61 | 62 | def find_first_available_port(): 63 | """Find an available port to perform RPC tests.""" 64 | skt = socket.socket(socket.AF_INET, socket.SOCK_STREAM) 65 | skt.bind(("0.0.0.0", 0)) 66 | _, port = skt.getsockname() 67 | skt.close() 68 | 
return port 69 | 70 | 71 | @pytest.fixture 72 | def mock_s3_client(): 73 | with mock_s3(): 74 | client = boto3.client("s3") 75 | client.create_bucket( 76 | Bucket="unit-test-bucket", CreateBucketConfiguration={"LocationConstraint": "us-west-2"} 77 | ) 78 | yield client 79 | 80 | 81 | @pytest.fixture 82 | def mock_db_table_arn(): 83 | table_name = "unit-test-lorien" 84 | with mock_dynamodb2(): 85 | arn = create_table(table_name) 86 | yield (table_name, arn) 87 | -------------------------------------------------------------------------------- /tests/dialect/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/awslabs/lorien/4ebaeb0d50fde66cb27b8dde538f962f60c8918e/tests/dialect/__init__.py -------------------------------------------------------------------------------- /tests/dialect/tvm_dial/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/awslabs/lorien/4ebaeb0d50fde66cb27b8dde538f962f60c8918e/tests/dialect/tvm_dial/__init__.py -------------------------------------------------------------------------------- /tests/dialect/tvm_dial/auto_scheduler_dial/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/awslabs/lorien/4ebaeb0d50fde66cb27b8dde538f962f60c8918e/tests/dialect/tvm_dial/auto_scheduler_dial/__init__.py -------------------------------------------------------------------------------- /tests/dialect/tvm_dial/autotvm_dial/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/awslabs/lorien/4ebaeb0d50fde66cb27b8dde538f962f60c8918e/tests/dialect/tvm_dial/autotvm_dial/__init__.py -------------------------------------------------------------------------------- /tests/dialect/tvm_dial/test_parser.py: 
-------------------------------------------------------------------------------- 1 | """ 2 | The unit test module for TVM dialect frontend parsers. 3 | """ 4 | # pylint:disable=missing-docstring, redefined-outer-name, invalid-name 5 | # pylint:disable=unused-argument, unused-import, wrong-import-position, ungrouped-imports 6 | import argparse 7 | import mock 8 | import pytest 9 | 10 | from lorien.util import is_dialect_enabled 11 | 12 | if not is_dialect_enabled("tvm"): 13 | pytest.skip("TVM dialect is not available", allow_module_level=True) 14 | 15 | from lorien.dialect.tvm_dial.frontend_parser import EXTRACTOR_FUNC_TABLE 16 | 17 | 18 | def test_parse_from_gcv(mocker): 19 | # Mock GCV frontend and assume it's always working. 20 | mocker.patch("gluoncv.model_zoo.get_model").return_value = "FakeNet" 21 | mocker.patch( 22 | "lorien.dialect.tvm_dial.frontend_parser.relay.frontend.from_mxnet" 23 | ).return_value = ("mod", "params") 24 | 25 | configs = argparse.Namespace(gcv=["alexnet"], target=["llvm"]) 26 | 27 | mod_n_params = EXTRACTOR_FUNC_TABLE["gcv"](configs) 28 | assert len(mod_n_params) == 1 29 | assert mod_n_params[0][0] == "alexnet" 30 | assert mod_n_params[0][1][0] == "mod" 31 | assert mod_n_params[0][1][1] == "params" 32 | 33 | mocker.patch("gluoncv.model_zoo.get_model").side_effect = Exception("Mocked Error") 34 | mod_n_params = EXTRACTOR_FUNC_TABLE["gcv"](configs) 35 | assert len(mod_n_params) == 0 36 | 37 | 38 | def test_parse_from_keras(mocker): 39 | # Mock Keras frontend and assume it's always working. 
40 | mocker.patch("keras.models.load_model").return_value = "FakeNet" 41 | mocker.patch( 42 | "lorien.dialect.tvm_dial.frontend_parser.relay.frontend.from_keras" 43 | ).return_value = ("mod", "params") 44 | 45 | configs = argparse.Namespace(keras=["alexnet: { data: [1, 3, 224, 224]}"], target="llvm") 46 | 47 | mod_n_params = EXTRACTOR_FUNC_TABLE["keras"](configs) 48 | assert len(mod_n_params) == 1 49 | assert mod_n_params[0][0] == "alexnet" 50 | assert mod_n_params[0][1][0] == "mod" 51 | assert mod_n_params[0][1][1] == "params" 52 | 53 | mocker.patch("keras.models.load_model").side_effect = Exception("Mocked Error") 54 | mod_n_params = EXTRACTOR_FUNC_TABLE["keras"](configs) 55 | assert len(mod_n_params) == 0 56 | 57 | 58 | def test_parse_from_onnx(mocker): 59 | # Mock ONNX frontend and assume it's always working. 60 | mocker.patch("onnx.load").return_value = "FakeNet" 61 | mocker.patch( 62 | "lorien.dialect.tvm_dial.frontend_parser.relay.frontend.from_onnx" 63 | ).return_value = ("mod", "params") 64 | 65 | configs = argparse.Namespace(onnx=["alexnet: { data: [1, 3, 224, 224]}"], target=["llvm"]) 66 | 67 | mod_n_params = EXTRACTOR_FUNC_TABLE["onnx"](configs) 68 | assert len(mod_n_params) == 1 69 | assert mod_n_params[0][0] == "alexnet" 70 | assert mod_n_params[0][1][0] == "mod" 71 | assert mod_n_params[0][1][1] == "params" 72 | 73 | mocker.patch("onnx.load").side_effect = Exception("Mocked Error") 74 | mod_n_params = EXTRACTOR_FUNC_TABLE["onnx"](configs) 75 | assert len(mod_n_params) == 0 76 | 77 | 78 | def test_extract_from_torch(): 79 | configs = argparse.Namespace( 80 | torch=["alexnet: { data: [1, 3, 224, 224]}"], 81 | target=["llvm"], 82 | ) 83 | mod_n_params = EXTRACTOR_FUNC_TABLE["torch"](configs) 84 | assert len(mod_n_params) == 1 85 | 86 | configs = argparse.Namespace( 87 | torch=["alexnet_wrong_name"], 88 | target=["llvm"], 89 | ) 90 | mod_n_params = EXTRACTOR_FUNC_TABLE["torch"](configs) 91 | assert len(mod_n_params) == 0 92 | 93 | 94 | def 
test_parse_from_tf(mocker): 95 | # Mock TensorFlow frontend and assume it's always working. 96 | mocker.patch( 97 | "tvm.relay.frontend.tensorflow_parser.TFParser" 98 | ).return_value.parse.return_value = "FakeNet" 99 | mocker.patch( 100 | "lorien.dialect.tvm_dial.frontend_parser.relay.frontend.from_tensorflow" 101 | ).return_value = ("mod", "params") 102 | 103 | configs = argparse.Namespace(tf=["alexnet: { data: [1, 224, 224, 3]}"], target=["llvm"]) 104 | 105 | mod_n_params = EXTRACTOR_FUNC_TABLE["tf"](configs) 106 | assert len(mod_n_params) == 1 107 | assert mod_n_params[0][0] == "alexnet" 108 | assert mod_n_params[0][1][0] == "mod" 109 | assert mod_n_params[0][1][1] == "params" 110 | 111 | mocker.patch("tvm.relay.frontend.tensorflow_parser.TFParser").side_effect = Exception( 112 | "Mocked Error" 113 | ) 114 | mod_n_params = EXTRACTOR_FUNC_TABLE["tf"](configs) 115 | assert len(mod_n_params) == 0 116 | 117 | 118 | def test_parse_from_tflite(mocker): 119 | # Mock tflite frontend and assume it's always working. 
120 | mocker.patch("lorien.dialect.tvm_dial.frontend_parser.open") 121 | mocker.patch("tflite.Model.Model.GetRootAsModel").return_value.parse.return_value = "FakeNet" 122 | mocker.patch( 123 | "lorien.dialect.tvm_dial.frontend_parser.relay.frontend.from_tflite" 124 | ).return_value = ("mod", "params") 125 | 126 | def dummy_func(a): 127 | return a 128 | 129 | mocker.patch( 130 | "lorien.dialect.tvm_dial.frontend_parser.relay.transform.RemoveUnusedFunctions" 131 | ).return_value = dummy_func 132 | mocker.patch( 133 | "lorien.dialect.tvm_dial.frontend_parser.relay.transform.InferType" 134 | ).return_value = dummy_func 135 | mocker.patch( 136 | "lorien.dialect.tvm_dial.frontend_parser.relay.transform.ConvertLayout" 137 | ).return_value = dummy_func 138 | 139 | configs = argparse.Namespace(tflite=["alexnet: { data: [1, 224, 224, 3]}"], target=["llvm"]) 140 | 141 | mod_n_params = EXTRACTOR_FUNC_TABLE["tflite"](configs) 142 | assert len(mod_n_params) == 1 143 | assert mod_n_params[0][0] == "alexnet" 144 | assert mod_n_params[0][1][0] == "mod" 145 | assert mod_n_params[0][1][1] == "params" 146 | 147 | mocker.patch("tflite.Model.Model.GetRootAsModel").side_effect = Exception("Mocked Error") 148 | mod_n_params = EXTRACTOR_FUNC_TABLE["tflite"](configs) 149 | assert len(mod_n_params) == 0 150 | 151 | 152 | def test_parse_from_mxnet(mocker): 153 | mocker.patch("mxnet.sym.load").return_value = None 154 | 155 | dummy_sym = mock.MagicMock() 156 | dummy_sym.hybridize.return_value = None 157 | dummy_sym.collect_params.return_value = mock.MagicMock() 158 | dummy_sym.collect_params.return_value.load.return_value = None 159 | 160 | mocker.patch("mxnet.gluon.SymbolBlock").return_value = dummy_sym 161 | mocker.patch( 162 | "lorien.dialect.tvm_dial.frontend_parser.relay.frontend.from_mxnet" 163 | ).return_value = ("mod", "params") 164 | 165 | configs = argparse.Namespace(target=["llvm"], mxnet=["alexnet: { data: [1, 3, 224, 224]}"]) 166 | 167 | mod_n_params = 
EXTRACTOR_FUNC_TABLE["mxnet"](configs) 168 | assert len(mod_n_params) == 1 169 | assert mod_n_params[0][0] == "alexnet" 170 | assert mod_n_params[0][1][0] == "mod" 171 | assert mod_n_params[0][1][1] == "params" 172 | 173 | mocker.patch("mxnet.gluon.SymbolBlock").side_effect = Exception("Mocked Error") 174 | mod_n_params = EXTRACTOR_FUNC_TABLE["mxnet"](configs) 175 | assert len(mod_n_params) == 0 176 | -------------------------------------------------------------------------------- /tests/lint/coveragerc: -------------------------------------------------------------------------------- 1 | [run] 2 | omit = 3 | */logger.py 4 | */__main__.py 5 | */__init__.py 6 | */launch.py 7 | */main.py 8 | 9 | [report] 10 | exclude_lines = 11 | pragma: no cover 12 | raise NotImplementedError 13 | if __name__ == .__main__.: 14 | -------------------------------------------------------------------------------- /tests/test_config.py: -------------------------------------------------------------------------------- 1 | """ 2 | The unit test module for config. 3 | """ 4 | # pylint:disable=unused-import, missing-docstring, redefined-outer-name 5 | import argparse 6 | 7 | import pytest 8 | from testfixtures import TempDirectory 9 | 10 | from lorien import main 11 | from lorien.configs import CONFIG_GROUP, make_config_parser, read_args_from_files 12 | from lorien.logger import get_logger 13 | from lorien.util import dump_to_yaml 14 | 15 | log = get_logger("Unit-Test") 16 | 17 | 18 | def test_definition(): 19 | # Check main entry. 20 | assert "top" in CONFIG_GROUP 21 | 22 | # Check subparsers and sub-module entries. 23 | for parser in CONFIG_GROUP.values(): 24 | subparsers = [p for p in parser._actions if isinstance(p, argparse._SubParsersAction)] 25 | if len(subparsers) == 1: 26 | # If the parser has a subparser, it has to be required. 27 | assert subparsers[0].required 28 | elif not subparsers: 29 | # If no sub-commands, then this parser has to define "entry" as the default function. 
30 | assert parser.get_default("entry") is not None 31 | else: 32 | # Not allowed to have more than one subparsers. 33 | assert False 34 | 35 | 36 | def test_config(): 37 | with pytest.raises(SystemExit): 38 | make_config_parser([]) 39 | 40 | # Test config cache 41 | with pytest.raises(SystemExit): 42 | make_config_parser([]) 43 | 44 | 45 | def test_read_args_from_files(): 46 | args = read_args_from_files(["a", "b", "c"]) 47 | assert len(args) == 3 48 | 49 | with TempDirectory() as temp_dir: 50 | config_file = "{}/cfg.yaml".format(temp_dir.path) 51 | with open(config_file, "w") as filep: 52 | filep.write(dump_to_yaml({"model": ["a", "b", "c"]})) 53 | args = read_args_from_files(["p", "@{}".format(config_file)]) 54 | assert args == ["p", "--model", "a", "--model", "b", "--model", "c"] 55 | -------------------------------------------------------------------------------- /tests/test_database.py: -------------------------------------------------------------------------------- 1 | """ 2 | The unit test for database. 
3 | """ 4 | # pylint:disable=missing-docstring, redefined-outer-name, invalid-name, unused-argument 5 | 6 | import boto3 7 | import pytest 8 | from moto import mock_dynamodb2 9 | 10 | from lorien.database.table import check_table, create_table, delete_table, list_tables, scan_table 11 | from lorien.database.util import ( 12 | convert_to_db_dict, 13 | convert_to_db_list, 14 | convert_to_dict, 15 | convert_to_list, 16 | ) 17 | 18 | 19 | @mock_dynamodb2 20 | def test_manipulate_table(): 21 | 22 | table_name = "lorien-test" 23 | 24 | # Test create table 25 | with pytest.raises(RuntimeError): 26 | create_table(table_name, region_name="invalid-region") 27 | arn = create_table(table_name, region_name="us-west-2") 28 | 29 | # Do not create if the table exists 30 | arn = create_table(table_name, region_name="us-west-2") 31 | 32 | # Test check table 33 | assert check_table(table_name, arn, region_name="us-west-2") 34 | assert not check_table(table_name, arn, region_name="us-west-1") 35 | 36 | # Test list table 37 | with pytest.raises(RuntimeError): 38 | list_tables(region_name="invalid-region") 39 | assert len(list_tables(region_name="us-west-2")) == 1 40 | assert not list_tables(region_name="us-west-1") 41 | 42 | # Put something in the table 43 | item = { 44 | "Target": {"S": "llvm"}, 45 | "TargetIDKeys": {"S": "llvm_cpu"}, 46 | "PrimaryRangeKey": {"S": "key"}, 47 | } 48 | client = boto3.client("dynamodb", region_name="us-west-2") 49 | client.put_item(TableName=table_name, Item=item) 50 | 51 | # Test scan table 52 | scanner = scan_table(table_name, limit=1, region_name="us-west-2") 53 | count = 0 54 | while True: 55 | try: 56 | next(scanner) 57 | count += 1 58 | except StopIteration: 59 | break 60 | assert count == 1 61 | 62 | # Remove the unit test table. 
63 | delete_table(table_name, region_name="us-west-2") 64 | 65 | 66 | def test_database_util(): 67 | orig_list = ["string", 123.34, 345, [1, 2, None], {"a": 2, "b": 8}] 68 | db_list = convert_to_db_list(orig_list) 69 | assert orig_list == convert_to_list(db_list) 70 | with pytest.raises(RuntimeError): 71 | convert_to_list(db_list["L"]) 72 | 73 | orig_dict = {"a": "string", "b": 123.45, "c": None, "d": [1, 2, 3], "e": {"p": 3, "q": 4}} 74 | db_dict = convert_to_db_dict(orig_dict) 75 | assert orig_dict == convert_to_dict(db_dict) 76 | with pytest.raises(RuntimeError): 77 | convert_to_dict(db_dict["M"]) 78 | -------------------------------------------------------------------------------- /tests/test_generate.py: -------------------------------------------------------------------------------- 1 | """ 2 | The unit test for generator. 3 | """ 4 | # pylint:disable=missing-docstring, redefined-outer-name, invalid-name, unused-argument 5 | import argparse 6 | import os 7 | import tempfile 8 | 9 | from lorien.generate import gen 10 | from lorien.util import load_from_yaml 11 | 12 | from .common import gen_workloads 13 | 14 | 15 | def test_gen(): 16 | def test_generator(configs): 17 | return gen_workloads(0, 5, configs.target) 18 | 19 | with tempfile.TemporaryDirectory(prefix="lorien_test_gen_") as temp_dir: 20 | wkl_file = os.path.join(temp_dir, "test_workloads.yaml") 21 | configs = argparse.Namespace(target="llvm", output=wkl_file) 22 | 23 | gen(test_generator)(configs) 24 | assert os.path.exists(wkl_file) 25 | with open(wkl_file, "r") as filep: 26 | workloads = load_from_yaml(filep.read()) 27 | assert len(workloads["workload"]) == 5 28 | assert all([load_from_yaml(wkl).target == "llvm" for wkl in workloads["workload"]]) 29 | -------------------------------------------------------------------------------- /tests/test_result.py: -------------------------------------------------------------------------------- 1 | """ 2 | The unit test for result. 
3 | """ 4 | # pylint:disable=missing-docstring, redefined-outer-name, invalid-name 5 | # pylint:disable=unused-argument, unused-import 6 | import os 7 | import tempfile 8 | 9 | import pytest 10 | 11 | from lorien.tune.result import Records, TuneErrorCode, TuneResult 12 | from lorien.util import dump_to_yaml, load_from_yaml 13 | 14 | from .common import mock_db_table_arn, mock_s3_client 15 | 16 | 17 | def test_result(mocker, mock_s3_client, mock_db_table_arn): 18 | table_name, _ = mock_db_table_arn 19 | 20 | class TestRecords(Records[int]): 21 | def __init__(self, build_config, target_key, alter_key=None, workload_key=None): 22 | super(TestRecords, self).__init__(target_key, alter_key, workload_key) 23 | self._data = [] 24 | self._build_config = build_config 25 | 26 | def get_framework_build_config(self): 27 | return self._build_config 28 | 29 | @staticmethod 30 | def encode(record): 31 | return str(record) 32 | 33 | @staticmethod 34 | def decode(record_str): 35 | return int(record_str) 36 | 37 | def gen_task_item(self): 38 | return {} 39 | 40 | @staticmethod 41 | def gen_record_item(record): 42 | return {"latency": record} 43 | 44 | def push(self, record): 45 | self._data.append(record) 46 | 47 | def pop(self): 48 | self._data, val = self._data[:-1], self._data[-1] 49 | return val 50 | 51 | def peak(self): 52 | return self._data[0] 53 | 54 | def to_list(self, nbest=-1): 55 | return self._data 56 | 57 | def __len__(self) -> int: 58 | return len(self._data) 59 | 60 | class TestResult1(TuneResult): 61 | @staticmethod 62 | def create_records_by_workloads(log_file_path, nbest, workload=None): 63 | return [] 64 | 65 | class TestResult2(TuneResult): 66 | @staticmethod 67 | def create_records_by_workloads(log_file_path, nbest, workload=None): 68 | records = TestRecords({"commit": "abc"}, "workload_key1", "target_ley", "alter_key") 69 | for idx in range(10): 70 | records.push(idx) 71 | return [ 72 | records, 73 | TestRecords({"commit": "abc"}, "workload_key2", "target_ley", 
"alter_key"), 74 | ] 75 | 76 | class TestResult3(TuneResult): 77 | @staticmethod 78 | def create_records_by_workloads(log_file_path, nbest, workload=None): 79 | records = TestRecords({"commit": "pqr"}, "workload_key1", "target_ley", "alter_key") 80 | for idx in range(10): 81 | records.push(idx) 82 | return [records] 83 | 84 | with tempfile.TemporaryDirectory(prefix="lorien_test_result_") as temp_dir: 85 | log_file = os.path.join(temp_dir, "tuning_log.json") 86 | with open(log_file, "w") as filep: 87 | filep.write("aaa\n") 88 | 89 | commit_options = { 90 | "commit-log": "s3://unit-test-bucket/tuning_log.json", 91 | "table-name": table_name, 92 | "nbest": 1, 93 | "db": {}, 94 | } 95 | 96 | result = TestResult1() 97 | assert str(result).find("error_code") != -1 98 | 99 | # Failed due to no log file. 100 | with pytest.raises(RuntimeError): 101 | result.commit(commit_options) 102 | 103 | # Failed due to no valid record. 104 | result.log_file = log_file 105 | result.commit(commit_options) 106 | assert result.error_code == TuneErrorCode.STORAGE_ERROR 107 | 108 | # Success committed twice. 109 | result = TestResult2() 110 | result.log_file = log_file 111 | result.commit(commit_options) 112 | assert result.error_code == TuneErrorCode.NORMAL 113 | assert not result.error_msgs 114 | result.commit(commit_options) 115 | assert result.error_code == TuneErrorCode.NORMAL 116 | assert not result.error_msgs 117 | 118 | # Success committed the same log with different build config. 
119 | result = TestResult3() 120 | result.log_file = log_file 121 | result.commit(commit_options) 122 | assert result.error_code == TuneErrorCode.NORMAL 123 | assert not result.error_msgs 124 | 125 | # Failed to upload to S3 126 | commit_options["commit-log"] = "s3://invalid-bucket/tuning_log.json" 127 | result = TestResult3() 128 | result.log_file = log_file 129 | result.commit(commit_options) 130 | assert result.error_code == TuneErrorCode.STORAGE_ERROR 131 | assert result.error_msgs 132 | 133 | # Failed to commit to DB 134 | mocker.patch.object(TestResult3, "commit_tuning_log").side_effect = RuntimeError 135 | commit_options["commit-log"] = None 136 | result = TestResult3() 137 | result.log_file = log_file 138 | result.commit(commit_options) 139 | assert result.error_code == TuneErrorCode.STORAGE_ERROR 140 | assert any([msg.find("Failed to commit result") != -1 for msg in result.error_msgs]) 141 | -------------------------------------------------------------------------------- /tests/test_tune_master.py: -------------------------------------------------------------------------------- 1 | """ 2 | The unit test module for tuning master. 
3 | """ 4 | # pylint:disable=missing-docstring, redefined-outer-name, invalid-name 5 | # pylint:disable=unused-argument, unused-import 6 | import argparse 7 | import tempfile 8 | 9 | import pytest 10 | from lorien import main # pylint: disable=unused-import 11 | from lorien.configs import make_config_parser 12 | from lorien.tune.master import run, run_job_manager 13 | from lorien.util import dump_to_yaml, upload_s3_file 14 | 15 | from .common import gen_jobs, gen_workloads, mock_s3_client 16 | 17 | 18 | def test_run(mocker, mock_s3_client): 19 | # Failed to load workloads 20 | with pytest.raises(RuntimeError): 21 | configs = argparse.Namespace(workload=["not/exist/path/a.json"]) 22 | run(configs) 23 | 24 | argv = ["tune"] 25 | for wkl in gen_workloads(0, 5): 26 | argv += ["--workload", dump_to_yaml(wkl)] 27 | 28 | for job in gen_jobs(5, 10): 29 | argv += ["--job", dump_to_yaml(job)] 30 | 31 | job = gen_jobs(10, 11)[0] 32 | with tempfile.NamedTemporaryFile(mode="w", prefix="lorien-test-tune-") as temp_file: 33 | temp_file.write(dump_to_yaml(job)) 34 | temp_file.flush() 35 | upload_s3_file(temp_file.name, "s3://unit-test-bucket/job.yaml") 36 | 37 | argv += ["--job", "s3://unit-test-bucket/job.yaml"] 38 | 39 | def mock_run_job_manager(_, packed_args): 40 | return packed_args 41 | 42 | mocker.patch("lorien.tune.master.run_job_manager").side_effect = mock_run_job_manager 43 | 44 | # No job manager config 45 | configs = make_config_parser(argv) 46 | assert not run(configs) 47 | 48 | batch_cfg_str = dump_to_yaml({"target": "llvm"}) 49 | for mgr_cfg in [["--batch", batch_cfg_str], ["--local", "llvm"]]: 50 | upload_s3_file(temp_file.name, "s3://unit-test-bucket/job.yaml") 51 | configs = make_config_parser(argv + mgr_cfg) 52 | packed_args = run(configs) 53 | 54 | assert packed_args["target"] == "llvm" 55 | assert len(packed_args["jobs"]) == 11 56 | 57 | # Conflict job managers 58 | with pytest.raises(RuntimeError): 59 | upload_s3_file(temp_file.name, 
"s3://unit-test-bucket/job.yaml") 60 | configs = make_config_parser(argv + ["--batch", batch_cfg_str, "--local", "llvm"]) 61 | run(configs) 62 | 63 | 64 | def test_run_job_manager(): 65 | class FakeManager1: 66 | def __init__(self, **args): 67 | raise RuntimeError 68 | 69 | class FakeManager2: 70 | def __init__(self, **args): 71 | pass 72 | 73 | def tune(self): 74 | return 1 75 | 76 | assert not run_job_manager(FakeManager1, {}) 77 | assert run_job_manager(FakeManager2, {}) == 1 78 | -------------------------------------------------------------------------------- /tests/test_util.py: -------------------------------------------------------------------------------- 1 | """ 2 | The unit tests for utility functions. 3 | """ 4 | # pylint:disable=missing-docstring, redefined-outer-name, invalid-name 5 | # pylint:disable=unused-argument, unused-import 6 | import os 7 | import tempfile 8 | 9 | import boto3 10 | import pytest 11 | from lorien.util import ( 12 | deep_tuple_to_list, 13 | delete_s3_file, 14 | download_s3_file, 15 | dump_to_yaml, 16 | get_time_str, 17 | load_from_yaml, 18 | serialize_framework_build_config, 19 | split_s3_path, 20 | upload_s3_file, 21 | ) 22 | 23 | from .common import mock_s3_client 24 | 25 | 26 | def test_get_time_str(): 27 | assert get_time_str() 28 | 29 | 30 | def test_split_s3_path(): 31 | bucket, folder = split_s3_path("s3://bucket_name") 32 | assert bucket == "bucket_name" 33 | assert not folder 34 | 35 | bucket, folder = split_s3_path("bucket_name/folder1") 36 | assert bucket == "bucket_name" 37 | assert folder == "folder1" 38 | 39 | bucket, folder = split_s3_path("bucket_name/folder1/folder2/") 40 | assert bucket == "bucket_name" 41 | assert folder == "folder1/folder2" 42 | 43 | bucket, folder = split_s3_path("bucket_name/folder1/folder2/file.log") 44 | assert bucket == "bucket_name" 45 | assert folder == "folder1/folder2/file.log" 46 | 47 | 48 | def test_manipulate_s3(mock_s3_client): 49 | with tempfile.NamedTemporaryFile(mode="w", 
prefix="lorien-test-s3-") as temp_file: 50 | temp_file.write("aaa\n") 51 | temp_file.flush() 52 | ret = upload_s3_file(temp_file.name, "s3://invalid-bucket") 53 | assert ret.startswith("Failed to upload the file") 54 | assert upload_s3_file(temp_file.name, "s3://unit-test-bucket/a/b/c/temp") == "" 55 | 56 | with tempfile.NamedTemporaryFile(mode="w", prefix="lorien-test-s3-") as temp_file: 57 | temp_file_path = temp_file.name 58 | 59 | ret = download_s3_file("s3://invalid-bucket", temp_file_path) 60 | assert ret.startswith("Failed to download the file") 61 | assert download_s3_file("s3://unit-test-bucket/a/b/c/temp", temp_file_path, delete=True) == "" 62 | with open(temp_file_path, "r") as filep: 63 | context = filep.read() 64 | os.remove(temp_file_path) 65 | assert context.find("aaa") != -1 66 | 67 | ret = delete_s3_file("s3://invalid-bucket") 68 | assert ret.startswith("Failed to delete") 69 | 70 | 71 | def test_deep_tuple_to_list(): 72 | assert deep_tuple_to_list((1, (2, 3, (4, 5), (6, 7)))) == [1, [2, 3, [4, 5], [6, 7]]] 73 | 74 | 75 | def test_manipulate_yaml(): 76 | class FakeClass: 77 | def __init__(self, val): 78 | self.val = val 79 | 80 | data = [1, {2: 3, 4: 5}] 81 | assert data == load_from_yaml(dump_to_yaml(data)) 82 | 83 | # Local class cannot be loaded 84 | with pytest.raises(RuntimeError): 85 | load_from_yaml(dump_to_yaml({"a": 4, "b": FakeClass(5)})) 86 | 87 | 88 | def test_serialize_framework_build_config(): 89 | assert serialize_framework_build_config({"a": "b", "c": "d"}) == (("a", "b"), ("c", "d")) 90 | assert serialize_framework_build_config("aa") == ("aa",) 91 | --------------------------------------------------------------------------------