├── .editorconfig ├── .github ├── dependabot.yml └── workflows │ ├── codeql-analysis.yml │ └── traceml.yml ├── .gitignore ├── CODE_OF_CONDUCT.md ├── CONTRIBUTING.md ├── LICENSE ├── README.md └── traceml ├── MANIFEST.in ├── requirements ├── dev.txt ├── master.txt └── test.txt ├── ruff.toml ├── setup.cfg ├── setup.py ├── tests ├── __init__.py ├── fixtures │ └── events │ │ ├── artifact │ │ └── artifact_events.plx │ │ ├── audio │ │ └── audio_events.plx │ │ ├── chart │ │ └── chart_events.plx │ │ ├── confusion │ │ └── confusion_events.plx │ │ ├── curve │ │ └── curve_events.plx │ │ ├── dataframe │ │ └── dataframe_events.plx │ │ ├── histogram │ │ └── histogram_events.plx │ │ ├── html │ │ └── html_events.plx │ │ ├── image │ │ └── image_events.plx │ │ ├── metric │ │ └── metric_events.plx │ │ ├── model │ │ ├── model_events.plx │ │ └── model_events_without_step.plx │ │ ├── span │ │ ├── span_events.plx │ │ └── span_events_without_step.plx │ │ └── video │ │ └── video_events.plx ├── test_events │ ├── __init__.py │ └── test_schemas.py ├── test_events_processing │ ├── __init__.py │ ├── test_df_processor.py │ ├── test_event_resources.py │ ├── test_event_values.py │ └── test_importance_processors.py ├── test_logging │ ├── __init__.py │ ├── test_logging_parser.py │ └── test_logs.py ├── test_serialization │ ├── __init__.py │ └── test_event_recorder.py ├── test_summary │ ├── __init__.py │ └── test_dfsummary.py └── test_tracking │ ├── __init__.py │ ├── test_run_tracking.py │ └── test_summaries.py └── traceml ├── __init__.py ├── artifacts ├── __init__.py ├── enums.py └── schemas.py ├── events ├── __init__.py ├── paths.py └── schemas.py ├── exceptions.py ├── integrations ├── __init__.py ├── fastai.py ├── fastai_v1.py ├── hugging_face.py ├── ignite.py ├── keras.py ├── langchain.py ├── lightgbm.py ├── pytorch_lightning.py ├── scikit.py ├── tensorboard.py ├── tensorflow.py └── xgboost.py ├── logger.py ├── logging ├── __init__.py ├── handler.py ├── parser.py ├── schemas.py └── streamer.py ├── 
pkg.py ├── processors ├── __init__.py ├── df_processors.py ├── errors.py ├── events_processors │ ├── __init__.py │ ├── events_artifacts_processors.py │ ├── events_audio_processors.py │ ├── events_charts_processors.py │ ├── events_image_processors.py │ ├── events_metrics_processors.py │ ├── events_models_processors.py │ ├── events_tables_processors.py │ └── events_video_processors.py ├── gpu_processor.py ├── importance_processors.py ├── logs_processor.py └── psutil_processor.py ├── py.typed ├── serialization ├── __init__.py ├── base.py └── writer.py ├── summary ├── __init__.py └── df.py ├── tracking ├── __init__.py └── run.py └── vendor ├── __init__.py ├── matplotlylib ├── LICENSE.txt ├── __init__.py ├── mplexporter │ ├── __init__.py │ ├── _py3k_compat.py │ ├── exporter.py │ ├── renderers │ │ ├── __init__.py │ │ ├── base.py │ │ ├── fake_renderer.py │ │ ├── vega_renderer.py │ │ └── vincent_renderer.py │ ├── tools.py │ └── utils.py ├── mpltools.py ├── renderer.py └── tools.py └── pynvml.py /.editorconfig: -------------------------------------------------------------------------------- 1 | # EditorConfig: http://EditorConfig.org 2 | 3 | # top-most EditorConfig file 4 | root = true 5 | 6 | # Unix-style newlines with a newline ending every file 7 | [*] 8 | end_of_line = lf 9 | insert_final_newline = true 10 | trim_trailing_whitespace = true 11 | max_line_length = 100 12 | 13 | # Set default charset 14 | [*.{ts,js,py}] 15 | charset = utf-8 16 | 17 | # 4 space indentation 18 | [*.py] 19 | indent_style = space 20 | indent_size = 4 21 | 22 | # isort configuration 23 | multi_line_output = 1 24 | skip = migrations 25 | 26 | # Indentation override for all JS under lib directory 27 | [*.{ts,tsx,js,css,yml,yaml}] 28 | indent_style = space 29 | indent_size = 2 30 | -------------------------------------------------------------------------------- /.github/dependabot.yml: -------------------------------------------------------------------------------- 1 | version: 2 2 | updates: 3 | 
- package-ecosystem: "github-actions" 4 | directory: "/" 5 | schedule: 6 | interval: "daily" 7 | - package-ecosystem: "pip" 8 | directory: "/" 9 | schedule: 10 | interval: "daily" 11 | - package-ecosystem: "gomod" 12 | directory: "/" 13 | schedule: 14 | interval: "daily" 15 | - package-ecosystem: "npm" 16 | directory: "/" 17 | schedule: 18 | interval: "daily" 19 | - package-ecosystem: "composer" 20 | directory: "/" 21 | schedule: 22 | interval: "daily" 23 | -------------------------------------------------------------------------------- /.github/workflows/codeql-analysis.yml: -------------------------------------------------------------------------------- 1 | # For most projects, this workflow file will not need changing; you simply need 2 | # to commit it to your repository. 3 | # 4 | # You may wish to alter this file to override the set of languages analyzed, 5 | # or to provide custom queries or build logic. 6 | # 7 | # ******** NOTE ******** 8 | # We have attempted to detect the languages in your repository. Please check 9 | # the `language` matrix defined below to confirm you have the correct set of 10 | # supported CodeQL languages. 11 | # 12 | name: "CodeQL" 13 | 14 | on: 15 | push: 16 | branches: [ master ] 17 | pull_request: 18 | # The branches below must be a subset of the branches above 19 | branches: [ master ] 20 | schedule: 21 | - cron: '26 6 * * 6' 22 | 23 | jobs: 24 | analyze: 25 | name: Analyze 26 | runs-on: ubuntu-latest 27 | permissions: 28 | actions: read 29 | contents: read 30 | security-events: write 31 | 32 | strategy: 33 | fail-fast: false 34 | matrix: 35 | language: [ 'python' ] 36 | # CodeQL supports [ 'cpp', 'csharp', 'go', 'java', 'javascript', 'python', 'ruby' ] 37 | # Learn more about CodeQL language support at https://git.io/codeql-language-support 38 | 39 | steps: 40 | - name: Checkout repository 41 | uses: actions/checkout@v4 42 | 43 | # Initializes the CodeQL tools for scanning. 
44 | - name: Initialize CodeQL 45 | uses: github/codeql-action/init@v2 46 | with: 47 | languages: ${{ matrix.language }} 48 | # If you wish to specify custom queries, you can do so here or in a config file. 49 | # By default, queries listed here will override any specified in a config file. 50 | # Prefix the list here with "+" to use these queries and those in the config file. 51 | # queries: ./path/to/local/query, your-org/your-repo/queries@main 52 | 53 | # Autobuild attempts to build any compiled languages (C/C++, C#, or Java). 54 | # If this step fails, then you should remove it and run the build manually (see below) 55 | - name: Autobuild 56 | uses: github/codeql-action/autobuild@v2 57 | 58 | # ℹ️ Command-line programs to run using the OS shell. 59 | # 📚 https://git.io/JvXDl 60 | 61 | # ✏️ If the Autobuild fails above, remove it and uncomment the following three lines 62 | # and modify them (or add more) to build your code if your project 63 | # uses a compiled language 64 | 65 | #- run: | 66 | # make bootstrap 67 | # make release 68 | 69 | - name: Perform CodeQL Analysis 70 | uses: github/codeql-action/analyze@v2 71 | -------------------------------------------------------------------------------- /.github/workflows/traceml.yml: -------------------------------------------------------------------------------- 1 | name: TraceML 2 | 3 | on: 4 | push: 5 | branches: 6 | - master 7 | - testing 8 | jobs: 9 | library: 10 | # if: github.event.comment.body == 'test core' 11 | runs-on: ubuntu-latest 12 | strategy: 13 | matrix: 14 | python-version: [3.8, 3.9, '3.10', '3.11', '3.12', '3.13'] 15 | steps: 16 | - run: | 17 | pip install -U traceml 18 | tests: 19 | # if: github.event.comment.body == 'test core' 20 | runs-on: ubuntu-latest 21 | strategy: 22 | matrix: 23 | python-version: [3.8, 3.9, '3.10', '3.11', '3.12', '3.13'] 24 | 25 | steps: 26 | - uses: actions/checkout@v4 27 | - name: Set up Python ${{ matrix.python-version }} 28 | uses: actions/setup-python@v5 29 | 
with: 30 | python-version: ${{ matrix.python-version }} 31 | - name: Upgrade pip 32 | run: | 33 | which python 34 | python -m pip install --upgrade pip 35 | - name: Install test dependencies 36 | run: pip install -r traceml/requirements/test.txt 37 | - name: Install master dependencies 38 | run: pip install -r traceml/requirements/master.txt 39 | - name: Install dev libraries 40 | run: export USE_LOCAL_PACKAGES="true" && pip install --upgrade --editable "traceml[dev]" 41 | - name: Test with pytest 42 | run: | 43 | cd traceml 44 | pytest -vv 45 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .pytest_cache/ 2 | # Byte-compiled / optimized / DLL files 3 | __pycache__/ 4 | *.py[cod] 5 | *$py.class 6 | .mypy_cache/ 7 | .vscode/ 8 | 9 | # temp files 10 | *~ 11 | 12 | # C extensions 13 | *.so 14 | 15 | # Distribution / packaging 16 | .Python 17 | docker/environment/ 18 | build/ 19 | develop-eggs/ 20 | tsconfig.tsbuildinfo 21 | dist/ 22 | pydist/ 23 | downloads/ 24 | eggs/ 25 | .eggs/ 26 | lib/ 27 | lib64/ 28 | parts/ 29 | sdist/ 30 | var/ 31 | public/ 32 | *.egg-info/ 33 | .installed.cfg 34 | *.egg 35 | 36 | # PyInstaller 37 | # Usually these files are written by a python script from a template 38 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
39 | *.manifest 40 | *.spec 41 | 42 | # Installer logs 43 | pip-log.txt 44 | pip-delete-this-directory.txt 45 | 46 | # Unit test / coverage reports 47 | htmlcov/ 48 | .tox/ 49 | .coverage 50 | .coverage.* 51 | .cache 52 | nosetests.xml 53 | coverage.xml 54 | *,cover 55 | .hypothesis/ 56 | 57 | # Translations 58 | *.mo 59 | 60 | # Mr Developer 61 | .mr.developer.cfg 62 | .pydevproject 63 | .project 64 | .settings/ 65 | .idea/ 66 | .DS_Store 67 | 68 | # fab files 69 | fabsettings.py 70 | fabfile.py 71 | fab_templates/ 72 | 73 | # graphviz files 74 | *_graphviz.png 75 | *_graphviz.dot 76 | 77 | # Sphinx documentation 78 | docs/_build/ 79 | 80 | # PyBuilder 81 | target/ 82 | 83 | # IPython Notebook 84 | .ipynb_checkpoints 85 | 86 | # pyenv 87 | .python-version 88 | 89 | # locals 90 | local.py 91 | 92 | # celery beat schedule file 93 | celerybeat-schedule* 94 | celeryev.pid 95 | celeryd*pid 96 | celeryd*log 97 | 98 | # dotenv 99 | .env 100 | 101 | # virtualenv 102 | venv/ 103 | ENV/ 104 | 105 | # data 106 | big_data/ 107 | examples/data/ 108 | data/ 109 | 110 | # environment vars 111 | local.json 112 | 113 | # project 114 | logs/ 115 | /wheels/* 116 | setup.log 117 | reports/ 118 | media/ 119 | /**/static/debug_toolbar 120 | /**/static/rest_framework 121 | /**/static/css 122 | 123 | # npm modules and transpiled typescript files 124 | client/dist/ 125 | client/node_modules/ 126 | client/packages/sdk/dist/ 127 | client/packages/sdk/src/ 128 | client/packages/ui/dist/ 129 | client/packages/ui/lib/ 130 | client/packages/ui/es/ 131 | client/packages/ui/test-env/ 132 | client/packages/ui/types/ 133 | client/packages/ui/ui.d.ts 134 | client/**/node_modules/ 135 | client/npm-debug.log 136 | deploy/dist/ 137 | deploy/node_modules/ 138 | deploy/npm-debug.log 139 | /examples/polyaxon-logs/ 140 | wip/ 141 | 142 | # don't ignore static dist 143 | !static/v1/dist 144 | lastfailed 145 | 146 | # Ignore local.env.ts 147 | dev.env.ts 148 | 149 | # Binaries for programs and plugins 150 | 
*.exe 151 | *.exe~ 152 | *.dll 153 | *.so 154 | *.dylib 155 | bin 156 | 157 | # Test binary, build with `go test -c` 158 | *.test 159 | 160 | # Output of the go coverage tool, specifically when used with LiteIDE 161 | *.out 162 | 163 | # Kubernetes Generated files - skip generated files, except for vendored files 164 | 165 | !vendor/**/zz_generated.* 166 | client/.awcache/ 167 | 168 | 169 | # ignore dev/test values 170 | values.*.dev.yaml 171 | values.*.test.yaml 172 | values.*.test.json 173 | values.*.test.sh 174 | 175 | # Chart dependencies 176 | **/charts/*.tgz 177 | -------------------------------------------------------------------------------- /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | # Contributor Covenant Code of Conduct 2 | 3 | ## Our Pledge 4 | 5 | In the interest of fostering an open and welcoming environment, we as contributors and maintainers pledge to making participation in our project and our community a harassment-free experience for everyone, regardless of age, body size, disability, ethnicity, gender identity and expression, level of experience, nationality, personal appearance, race, religion, or sexual identity and orientation. 
6 | 7 | ## Our Standards 8 | 9 | Examples of behavior that contributes to creating a positive environment include: 10 | 11 | * Using welcoming and inclusive language 12 | * Being respectful of differing viewpoints and experiences 13 | * Gracefully accepting constructive criticism 14 | * Focusing on what is best for the community 15 | * Showing empathy towards other community members 16 | 17 | Examples of unacceptable behavior by participants include: 18 | 19 | * The use of sexualized language or imagery and unwelcome sexual attention or advances 20 | * Trolling, insulting/derogatory comments, and personal or political attacks 21 | * Public or private harassment 22 | * Publishing others' private information, such as a physical or electronic address, without explicit permission 23 | * Other conduct which could reasonably be considered inappropriate in a professional setting 24 | 25 | ## Our Responsibilities 26 | 27 | Project maintainers are responsible for clarifying the standards of acceptable behavior and are expected to take appropriate and fair corrective action in response to any instances of unacceptable behavior. 28 | 29 | Project maintainers have the right and responsibility to remove, edit, or reject comments, commits, code, wiki edits, issues, and other contributions that are not aligned to this Code of Conduct, or to ban temporarily or permanently any contributor for other behaviors that they deem inappropriate, threatening, offensive, or harmful. 30 | 31 | ## Scope 32 | 33 | This Code of Conduct applies both within project spaces and in public spaces when an individual is representing the project or its community. Examples of representing a project or community include using an official project e-mail address, posting via an official social media account, or acting as an appointed representative at an online or offline event. Representation of a project may be further defined and clarified by project maintainers. 
34 | 35 | ## Enforcement 36 | 37 | Instances of abusive, harassing, or otherwise unacceptable behavior may be reported by contacting the project team. The project team will review and investigate all complaints, and will respond in a way that it deems appropriate to the circumstances. The project team is obligated to maintain confidentiality with regard to the reporter of an incident. Further details of specific enforcement policies may be posted separately. 38 | 39 | Project maintainers who do not follow or enforce the Code of Conduct in good faith may face temporary or permanent repercussions as determined by other members of the project's leadership. 40 | 41 | ## Attribution 42 | 43 | This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4, available at [http://contributor-covenant.org/version/1/4][version] 44 | 45 | [homepage]: http://contributor-covenant.org 46 | [version]: http://contributor-covenant.org/version/1/4/ 47 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # How to contribute 2 | 3 | Polyaxon project would love to welcome your contributions. There are several ways to help out: 4 | 5 | * Create an [issue](https://github.com/polyaxon/polyaxon/issues) on GitHub, if you have found a bug 6 | * Write test cases for open bug issues 7 | * Write patches for open bug/feature issues, preferably with test cases included 8 | * Contribute to the documentation 9 | * Blog about different ways you are using Polyaxon 10 | 11 | There are a few guidelines that we need contributors to follow so that we have a chance of keeping on top of things. 12 | 13 | ## Reporting issues 14 | 15 | Polyaxon has probably many issues and bugs, a great way to contribute to the project is to send a detailed report when you encounter an issue. 
We always appreciate a well-written, thorough bug report, and will thank you for it! 16 | 17 | Sometimes Polyaxon is missing a feature you need, and we encourage our users to create and contribute such features. 18 | 19 | Check the current [issues](https://github.com/polyaxon/polyaxon/issues) to see if they don't already include that problem or suggestion before submitting an issue. 20 | If you find a match, add a quick "+1". Doing this helps prioritize the most common problems and requests. 21 | 22 | When reporting issues, please include your host OS (Ubuntu 14.04, CentOS 7, etc.), and the version of the libraries you are using. 23 | 24 | Please also include the steps required to reproduce the problem if possible and applicable. This information will help us review and fix your issue faster. 25 | 26 | ## Contributing 27 | 28 | Before you contribute to Polyaxon, there are a few things that you'll need to do: 29 | 30 | * Make sure you have a [GitHub account](https://github.com/signup/free). 31 | * Submit an [issue](https://github.com/polyaxon/polyaxon/issues), assuming one does not already exist. 32 | * Clearly describe the issue including steps to reproduce when it is a bug. 33 | * Make sure you fill in the earliest version that you know has the issue. 34 | * Fork the repository on GitHub. 35 | 36 | ### Making Changes 37 | 38 | * Create a topic branch from where you want to base your work. 39 | * This is usually the master branch. 40 | * Only target an existing branch if you are certain your fix must be on that branch. 41 | * To quickly create a topic branch based on master: `git checkout -b my_contribution origin/master`. 42 | It is best to avoid working directly on the `master` branch. Doing so will help avoid conflicts if you pull in updates from origin. 43 | * Make commits of logical units. Implementing a new function and calling it in 44 | another file constitute a single logical unit of work. 
45 | * A majority of submissions should have a single commit, so if in doubt, squash your commits down to one commit. 46 | * Use descriptive commit messages and reference the #issue number. 47 | * Core test cases should continue to pass. (Test are in progress) 48 | * Pull requests must be cleanly rebased on top of master without multiple branches mixed into the PR. 49 | 50 | ### Which branch to base the work 51 | 52 | All changes should be based on the latest master commit. 53 | 54 | ## Questions 55 | 56 | If you need help with how to use this library, please check the list of examples, if it is still unclear you can also open an issue. 57 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 
25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. 
For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. 
If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. 
You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. 
(Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright 2018-2023 Polyaxon, Inc. 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /traceml/MANIFEST.in: -------------------------------------------------------------------------------- 1 | global-exclude *.py[cod] 2 | global-exclude __pycache__ 3 | prune tests 4 | prune __pycache__ 5 | include traceml/py.typed 6 | include ../README.md 7 | include ../LICENSE 8 | include ../CONTRIBUTING.md 9 | -------------------------------------------------------------------------------- /traceml/requirements/dev.txt: -------------------------------------------------------------------------------- 1 | pandas 2 | pyarrow 3 | ipython 4 | scikit-learn 5 | altair 6 | matplotlib 7 | moviepy==1.0.3 8 | plotly==4.6.0 9 | Pillow 10 | bokeh 11 | aiofiles==24.1.0 12 | imageio<2.28.0 # 2.28.0 has a bug that breaks moviepy 13 | -------------------------------------------------------------------------------- /traceml/requirements/master.txt: -------------------------------------------------------------------------------- 1 | -e 
git+https://github.com/polyaxon/cli.git@master#egg=polyaxon&subdirectory=cli 2 | -------------------------------------------------------------------------------- /traceml/requirements/test.txt: -------------------------------------------------------------------------------- 1 | coverage<7.5 2 | faker<24.0.0 3 | flaky<3.8.0 4 | mock<5.2.0 5 | pytest<8.1.0 6 | pytest-asyncio<0.23.0 7 | ruff 8 | 9 | mypy<1.9 10 | types-aiofiles 11 | types-certifi 12 | types-protobuf 13 | types-python-dateutil 14 | types-pytz 15 | types-PyYAML 16 | types-requests 17 | types-setuptools 18 | types-six 19 | types-orjson 20 | -------------------------------------------------------------------------------- /traceml/ruff.toml: -------------------------------------------------------------------------------- 1 | line-length = 88 2 | exclude = ["traceml/vendor"] 3 | 4 | [lint.isort] 5 | known-first-party = ["polyaxon", "traceml", "hypertune", "vents"] 6 | known-third-party = ["rest_framework", "scipy", "sklearn", "datadog", "docker", "corsheaders", "celery", "picklefield", "sentry_sdk", "orjson", "pydantic", "clipped"] 7 | extra-standard-library = ["typing", "typing_extensions", "mock", "pytest", "factory", "faker", "flaky", "numpy", "pandas", "requests", "websocket", "jinja2", "yaml", "pytz"] 8 | force-single-line = false 9 | force-sort-within-sections = true 10 | combine-as-imports = true 11 | lines-after-imports = 2 12 | section-order = ["future", "standard-library", "third-party", "first-party", "local-folder"] 13 | 14 | [lint.per-file-ignores] 15 | # Do not enforce usage and import order rules in init files 16 | "__init__.py" = ["E402", "F401", "F403", "I"] 17 | -------------------------------------------------------------------------------- /traceml/setup.cfg: -------------------------------------------------------------------------------- 1 | [metadata] 2 | description = file: README.md 3 | 4 | [tool:pytest] 5 | addopts = --doctest-glob='*.rst' 6 | markers = 7 | events_mark 8 | 
processors_mark 9 | logging_mark 10 | serialization_mark 11 | tracking_mark 12 | 13 | [mypy] 14 | python_version = 3.9 15 | namespace_packages = true 16 | ignore_missing_imports = True 17 | show_error_codes = True 18 | allow_redefinition = True 19 | exclude = (setup.py$)|(tests/)|(vendor/) 20 | -------------------------------------------------------------------------------- /traceml/setup.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | 4 | from setuptools import find_packages, setup 5 | from setuptools.command.test import test as TestCommand 6 | 7 | 8 | def read_readme(): 9 | if not os.path.exists("./README.md"): 10 | return "" 11 | with open("./README.md") as f: 12 | return f.read() 13 | 14 | 15 | class PyTest(TestCommand): 16 | def finalize_options(self): 17 | TestCommand.finalize_options(self) 18 | self.test_args = [] 19 | self.test_suite = True 20 | 21 | def run_tests(self): 22 | import pytest 23 | 24 | errcode = pytest.main(self.test_args) 25 | sys.exit(errcode) 26 | 27 | 28 | with open(os.path.join("./traceml/pkg.py"), encoding="utf8") as f: 29 | pkg = {} 30 | exec(f.read(), pkg) 31 | 32 | 33 | with open("requirements/dev.txt") as requirements_file: 34 | dev_requirements = requirements_file.read().splitlines() 35 | 36 | extra = { 37 | "polyaxon": ["polyaxon"], 38 | "dev": dev_requirements, 39 | "all": [ 40 | "scikit-learn", 41 | "Pillow", 42 | "matplotlib", 43 | "moviepy", 44 | "plotly", 45 | "bokeh", 46 | "pandas", 47 | "altair", 48 | ], 49 | } 50 | 51 | setup( 52 | name=pkg["NAME"], 53 | version=pkg["VERSION"], 54 | description=pkg["DESC"], 55 | long_description=read_readme(), 56 | long_description_content_type="text/markdown", 57 | maintainer=pkg["AUTHOR"], 58 | maintainer_email=pkg["EMAIL"], 59 | author=pkg["AUTHOR"], 60 | author_email=pkg["EMAIL"], 61 | url=pkg["URL"], 62 | license=pkg["LICENSE"], 63 | platforms="any", 64 | packages=find_packages(exclude=["*.tests", "*.tests.*", 
"tests.*", "tests"]), 65 | keywords=[ 66 | "polyaxon", 67 | "aws", 68 | "s3", 69 | "microsoft", 70 | "azure", 71 | "google cloud storage", 72 | "gcs", 73 | "deep-learning", 74 | "machine-learning", 75 | "data-science", 76 | "neural-networks", 77 | "artificial-intelligence", 78 | "ai", 79 | "reinforcement-learning", 80 | "kubernetes", 81 | "aws", 82 | "microsoft", 83 | "azure", 84 | "google cloud", 85 | "tensorFlow", 86 | "pytorch", 87 | "matplotlib", 88 | "plotly", 89 | "visualization", 90 | "analytics", 91 | ], 92 | install_requires=[], 93 | extras_require=extra, 94 | python_requires=">=3.8", 95 | classifiers=[ 96 | "Intended Audience :: Information Technology", 97 | "Intended Audience :: System Administrators", 98 | "Intended Audience :: Developers", 99 | "Intended Audience :: Science/Research", 100 | "Operating System :: OS Independent", 101 | "Programming Language :: Python :: 3", 102 | "Programming Language :: Python", 103 | "Topic :: Software Development :: Libraries :: Application Frameworks", 104 | "Topic :: Software Development :: Libraries :: Python Modules", 105 | "Topic :: Software Development :: Libraries", 106 | "Topic :: Software Development", 107 | "Topic :: Scientific/Engineering :: Artificial Intelligence", 108 | "Typing :: Typed", 109 | "Programming Language :: Python :: 3 :: Only", 110 | "Programming Language :: Python :: 3.6", 111 | "Programming Language :: Python :: 3.7", 112 | "Programming Language :: Python :: 3.8", 113 | "Programming Language :: Python :: 3.9", 114 | "Programming Language :: Python :: 3.10", 115 | "Programming Language :: Python :: 3.11", 116 | ], 117 | tests_require=["pytest"], 118 | cmdclass={"test": PyTest}, 119 | ) 120 | -------------------------------------------------------------------------------- /traceml/tests/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- 
/traceml/tests/fixtures/events/artifact/artifact_events.plx: -------------------------------------------------------------------------------- 1 | timestamp|step|artifact 2 | 2018-12-11T10:24:57|12|{"kind":"dataframe","path":"path1"} 3 | 2018-12-11T10:25:57|13|{"kind":"tsv","path":"path2"} 4 | 2018-12-11T10:26:57|14|{"kind":"csv","path":"path3"} 5 | -------------------------------------------------------------------------------- /traceml/tests/fixtures/events/audio/audio_events.plx: -------------------------------------------------------------------------------- 1 | timestamp|step|audio 2 | 2018-12-11T10:24:57|12|{"sample_rate": 1.1, "num_channels": 2, "length_frames": 2, "path": "test"} 3 | 2018-12-11T10:25:57|13|{"sample_rate": 1.11, "num_channels": 22, "length_frames": 22, "path": "test", "content_type": "wav"} 4 | 2018-12-11T10:26:57|14|{"path": "testwave", "content_type": "wav"} 5 | -------------------------------------------------------------------------------- /traceml/tests/fixtures/events/chart/chart_events.plx: -------------------------------------------------------------------------------- 1 | timestamp|step|chart 2 | 2018-12-11T10:24:57|12|{"kind":"plotly","figure":{"foo":"bar"}} 3 | 2018-12-11T10:25:57|13|{"kind":"vega","figure":{"foo2": "bar2"}} 4 | 2018-12-11T10:26:57|14|{"kind":"bokeh","figure":{"foo3": "bar3"}} 5 | -------------------------------------------------------------------------------- /traceml/tests/fixtures/events/confusion/confusion_events.plx: -------------------------------------------------------------------------------- 1 | timestamp|step|confusion 2 | 2018-12-11T10:24:57|12|{"x":["foo", "bar", "moo"], "y": [0.1, 0.3, 0.4], "z": [[0.1, 0.3, 0.5], [1.0, 0.8, 0.6], [0.1, 0.3, 0.6], [0.4, 0.2, 0.2]]} 3 | 2018-12-11T10:25:57|13|{"x":["foo", "bar", "moo"], "y": [0.1, 0.3, 0.4], "z": [[0.1, 0.3, 0.5], [1.0, 0.8, 0.6], [0.1, 0.3, 0.6], [0.4, 0.2, 0.2]]} 4 | 2018-12-11T10:26:57|14|{"x":["foo", "bar", "moo"], "y": [0.1, 0.3, 0.4], "z": [[0.1, 
0.3, 0.5], [1.0, 0.8, 0.6], [0.1, 0.3, 0.6], [0.4, 0.2, 0.2]]} 5 | -------------------------------------------------------------------------------- /traceml/tests/fixtures/events/curve/curve_events.plx: -------------------------------------------------------------------------------- 1 | timestamp|step|curve 2 | 2018-12-11T10:24:57|12|{"kind":"roc","x":[1.1, 3.1, 5.1], "y": [0.1, 0.3, 0.4], "annotation": "0.1"} 3 | 2018-12-11T10:25:57|13|{"kind":"pr","x":[1.1, 3.1, 5.1], "y": [0.1, 0.3, 0.4], "annotation": "0.21"} 4 | 2018-12-11T10:26:57|14|{"kind":"custom","x":[1.1, 3.1, 5.1], "y": [0.1, 0.3, 0.4], "annotation": "0.1"} 5 | -------------------------------------------------------------------------------- /traceml/tests/fixtures/events/dataframe/dataframe_events.plx: -------------------------------------------------------------------------------- 1 | timestamp|step|dataframe 2 | 2018-12-11T10:24:57|12|{"path": "path1", "content_type": "parquet"} 3 | 2018-12-11T10:25:57|13|{"path": "path2", "content_type": "pickle"} 4 | 2018-12-11T10:26:57|14|{"path": "path3"} 5 | -------------------------------------------------------------------------------- /traceml/tests/fixtures/events/histogram/histogram_events.plx: -------------------------------------------------------------------------------- 1 | timestamp|step|histogram 2 | 2018-12-11T10:24:57|12|{"values": [10], "counts": [1]} 3 | 2018-12-11T10:25:57|13|{"values": [10, 1, 1], "counts": [1, 1, 1]} 4 | 2018-12-11T10:26:57|14|{"values": [10, 112, 12, 1], "counts": [12, 1, 1, 1]} 5 | -------------------------------------------------------------------------------- /traceml/tests/fixtures/events/html/html_events.plx: -------------------------------------------------------------------------------- 1 | timestamp|step|html 2 | 2018-12-11T10:24:57|12|"
1
" 3 | 2018-12-11T10:25:57|13|"
2
" 4 | 2018-12-11T10:26:57|14|"
3
" 5 | -------------------------------------------------------------------------------- /traceml/tests/fixtures/events/image/image_events.plx: -------------------------------------------------------------------------------- 1 | timestamp|step|image 2 | 2018-12-11T10:24:57|12|{"path": "test"} 3 | 2018-12-11T10:25:57|13|{"height": 1, "width": 1} 4 | 2018-12-11T10:26:57|14|{"height": 10, "width": 10, "colorspace": 2} 5 | -------------------------------------------------------------------------------- /traceml/tests/fixtures/events/metric/metric_events.plx: -------------------------------------------------------------------------------- 1 | timestamp|step|metric 2 | 2018-12-11T10:24:57|12|0.1 3 | 2018-12-11T10:25:57|13|0.2 4 | 2018-12-11T10:26:57|14|0.3 5 | -------------------------------------------------------------------------------- /traceml/tests/fixtures/events/model/model_events.plx: -------------------------------------------------------------------------------- 1 | timestamp|step|model 2 | 2018-12-11T10:24:57|12|{"framework": "tensorflow", "path": "path1"} 3 | 2018-12-11T10:25:57|13|{"framework": "pytorch", "path": "path2"} 4 | 2018-12-11T10:26:57|14|{"framework": "onnx", "path": "path3"} 5 | -------------------------------------------------------------------------------- /traceml/tests/fixtures/events/model/model_events_without_step.plx: -------------------------------------------------------------------------------- 1 | step|timestamp|model 2 | |2018-12-11T10:24:57|{"framework": "tensorflow", "path": "path1"} 3 | -------------------------------------------------------------------------------- /traceml/tests/fixtures/events/span/span_events.plx: -------------------------------------------------------------------------------- 1 | 2 | {"timestamp": "2018-12-11T10:24:57", "step": 12, "span": "{\"uuid\":\"ceb21ee781254719b18664ad1fde57a2\",\"name\":\"span1\",\"tags\":[\"tag1\",\"tag2\"]}"} 3 | {"timestamp": "2018-12-11T10:25:57", "step": 13, "span": 
"{\"uuid\":\"ceb21ee781254719b18664ad1fde57a2\",\"name\":\"span2\",\"tags\":[\"tag1\",\"tag2\"],\"inputs\":{\"key\":\"value\"}}"} 4 | {"timestamp": "2018-12-11T10:26:57", "step": 14, "span": "{\"uuid\":\"ceb21ee781254719b18664ad1fde57a2\",\"name\":\"span3\",\"tags\":[\"tag1\",\"tag2\"]}"} 5 | -------------------------------------------------------------------------------- /traceml/tests/fixtures/events/span/span_events_without_step.plx: -------------------------------------------------------------------------------- 1 | 2 | {"timestamp": "2018-12-11T10:24:57", "span": "{\"uuid\":\"ceb21ee781254719b18664ad1fde57a2\",\"name\":\"span1\",\"tags\":[\"tag1\",\"tag2\"]}"} 3 | {"timestamp": "2018-12-11T10:25:57", "span": "{\"uuid\":\"ceb21ee781254719b18664ad1fde57a2\",\"name\":\"span2\",\"tags\":[\"tag1\",\"tag2\"],\"inputs\":{\"key\":\"value\"}}"} 4 | {"timestamp": "2018-12-11T10:26:57", "span": "{\"uuid\":\"ceb21ee781254719b18664ad1fde57a2\",\"name\":\"span3\",\"tags\":[\"tag1\",\"tag2\"]}"} 5 | -------------------------------------------------------------------------------- /traceml/tests/fixtures/events/video/video_events.plx: -------------------------------------------------------------------------------- 1 | timestamp|step|video 2 | 2018-12-11T10:24:57|12|{"path": "test", "content_type": "mp4"} 3 | 2018-12-11T10:25:57|13|{"height": 1, "width": 1} 4 | 2018-12-11T10:26:57|14|{"height": 10, "width": 10, "colorspace": 2} 5 | -------------------------------------------------------------------------------- /traceml/tests/test_events/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /traceml/tests/test_events_processing/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- 
/traceml/tests/test_events_processing/test_event_resources.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from polyaxon._utils.test_utils import BaseTestCase 4 | from traceml.processors.events_processors import metrics_dict_to_list 5 | 6 | 7 | @pytest.mark.processors_mark 8 | class TestEventWriter(BaseTestCase): 9 | def test_gpu_resources_to_metrics(self): 10 | resources = { 11 | "gpu_0_memory_free": 1000, 12 | "gpu_0_memory_used": 8388608000, 13 | "gpu_0_utilization": 76, 14 | } 15 | 16 | events = metrics_dict_to_list(resources) 17 | assert len(events) == 3 18 | assert [e.event.metric for e in events] == [1000, 8388608000, 76] 19 | 20 | def test_psutil_resources_to_metrics(self): 21 | resources = { 22 | "cpu_percent_avg": 1000, 23 | "cpu_percent_1": 0.3, 24 | "cpu_percent_2": 0.5, 25 | "getloadavg": 76, 26 | "memory_total": 12883853312, 27 | "memory_used": 8388608000, 28 | } 29 | 30 | events = metrics_dict_to_list(resources) 31 | assert len(events) == 6 32 | assert [e.event.metric for e in events] == [ 33 | 1000, 34 | 0.3, 35 | 0.5, 36 | 76, 37 | 12883853312, 38 | 8388608000, 39 | ] 40 | -------------------------------------------------------------------------------- /traceml/tests/test_events_processing/test_event_values.py: -------------------------------------------------------------------------------- 1 | import io 2 | import numpy as np 3 | import os 4 | import pytest 5 | import tempfile 6 | 7 | from bokeh.plotting import figure 8 | from PIL import Image 9 | from plotly import figure_factory 10 | 11 | from polyaxon._utils.test_utils import BaseTestCase, tensor_np 12 | from traceml.processors.events_processors import ( 13 | audio, 14 | bokeh_chart, 15 | convert_to_HWC, 16 | histogram, 17 | image, 18 | image_boxes, 19 | plotly_chart, 20 | prepare_video, 21 | video, 22 | ) 23 | 24 | 25 | @pytest.mark.processors_mark 26 | class TestEventValues(BaseTestCase): 27 | def setUp(self): 28 | self.run_path = 
tempfile.mkdtemp() 29 | self.asset_path = self.run_path + "/asset" 30 | 31 | def test_uint8_image(self): 32 | """Tests that uint8 image (pixel values in [0, 255]) is not changed""" 33 | assert os.path.exists(self.asset_path) is False 34 | event = image( 35 | asset_path=self.asset_path, 36 | data=tensor_np(shape=(3, 32, 32), dtype=np.uint8), 37 | ) 38 | assert event.path == self.asset_path 39 | assert os.path.exists(self.asset_path) is True 40 | 41 | def test_float32_image(self): 42 | """Tests that float32 image (pixel values in [0, 1]) are scaled correctly to [0, 255]""" 43 | assert os.path.exists(self.asset_path) is False 44 | event = image(asset_path=self.asset_path, data=tensor_np(shape=(3, 32, 32))) 45 | assert event.path == self.asset_path 46 | assert os.path.exists(self.asset_path) is True 47 | 48 | def test_float_1_converts_to_uint8_255(self): 49 | assert os.path.exists(self.asset_path) is False 50 | green_uint8 = np.array([[[0, 255, 0]]], dtype="uint8") 51 | green_float32 = np.array([[[0, 1, 0]]], dtype="float32") 52 | 53 | a = image(asset_path=self.run_path + "/asset1", data=green_uint8) 54 | b = image(asset_path=self.run_path + "/asset2", data=green_float32) 55 | self.assertEqual( 56 | Image.open(io.BytesIO(open(a.path, "br").read())), 57 | Image.open(io.BytesIO(open(b.path, "br").read())), 58 | ) 59 | 60 | def test_list_input(self): 61 | with pytest.raises(Exception): 62 | histogram("dummy", [1, 3, 4, 5, 6], "tensorflow") 63 | 64 | def test_empty_input(self): 65 | print("expect error here:") 66 | with pytest.raises(Exception): 67 | histogram("dummy", np.ndarray(0), "tensorflow") 68 | 69 | def test_image_with_boxes(self): 70 | event = image_boxes( 71 | asset_path=self.asset_path, 72 | tensor_image=tensor_np(shape=(3, 32, 32)), 73 | tensor_boxes=np.array([[10, 10, 40, 40]]), 74 | ) 75 | assert event.path == self.asset_path 76 | assert os.path.exists(self.asset_path) is True 77 | 78 | def test_image_with_one_channel(self): 79 | event = image( 80 | 
asset_path=self.asset_path, 81 | data=tensor_np(shape=(1, 8, 8)), 82 | dataformats="CHW", 83 | ) 84 | assert event.path == self.asset_path 85 | assert os.path.exists(self.asset_path) is True 86 | 87 | def test_image_with_four_channel(self): 88 | event = image( 89 | asset_path=self.asset_path, 90 | data=tensor_np(shape=(4, 8, 8)), 91 | dataformats="CHW", 92 | ) 93 | assert event.path == self.asset_path 94 | assert os.path.exists(self.asset_path) is True 95 | 96 | def test_image_with_one_channel_batched(self): 97 | event = image( 98 | asset_path=self.asset_path, 99 | data=tensor_np(shape=(2, 1, 8, 8)), 100 | dataformats="NCHW", 101 | ) 102 | assert event.path == self.asset_path 103 | assert os.path.exists(self.asset_path) is True 104 | 105 | def test_image_with_3_channel_batched(self): 106 | event = image( 107 | asset_path=self.asset_path, 108 | data=tensor_np(shape=(2, 3, 8, 8)), 109 | dataformats="NCHW", 110 | ) 111 | assert event.path == self.asset_path 112 | assert os.path.exists(self.asset_path) is True 113 | 114 | def test_image_with_four_channel_batched(self): 115 | event = image( 116 | asset_path=self.asset_path, 117 | data=tensor_np(shape=(2, 4, 8, 8)), 118 | dataformats="NCHW", 119 | ) 120 | assert event.path == self.asset_path 121 | assert os.path.exists(self.asset_path) is True 122 | 123 | def test_image_without_channel(self): 124 | event = image( 125 | asset_path=self.asset_path, data=tensor_np(shape=(8, 8)), dataformats="HW" 126 | ) 127 | assert event.path == self.asset_path 128 | assert os.path.exists(self.asset_path) is True 129 | 130 | def test_video(self): 131 | asset_path = self.asset_path + ".gif" 132 | event = video(asset_path=asset_path, tensor=tensor_np(shape=(4, 3, 1, 8, 8))) 133 | assert event.path == asset_path 134 | assert os.path.exists(asset_path) is True 135 | event = video( 136 | asset_path=asset_path, tensor=tensor_np(shape=(16, 48, 1, 28, 28)) 137 | ) 138 | assert event.path == asset_path 139 | assert os.path.exists(asset_path) is 
True 140 | event = video(asset_path=asset_path, tensor=tensor_np(shape=(20, 7, 1, 8, 8))) 141 | assert event.path == asset_path 142 | assert os.path.exists(asset_path) is True 143 | 144 | def test_audio(self): 145 | event = audio(asset_path=self.asset_path, tensor=tensor_np(shape=(42,))) 146 | assert event.path == self.asset_path 147 | assert os.path.exists(self.asset_path) is True 148 | 149 | def test_histogram_auto(self): 150 | with self.assertRaises(ValueError): 151 | histogram(values=tensor_np(shape=(1024,)), bins="auto", max_bins=5) 152 | event = histogram(values=tensor_np(shape=(1024,)), bins="auto") 153 | assert event.values is not None 154 | assert event.counts is not None 155 | 156 | def test_histogram(self): 157 | with self.assertRaises(ValueError): 158 | histogram(values=tensor_np(shape=(1024,)), bins="fd", max_bins=5) 159 | event = histogram(values=tensor_np(shape=(1024,)), bins="fd") 160 | assert event.values is not None 161 | assert event.counts is not None 162 | 163 | def test_histogram_doane(self): 164 | with self.assertRaises(ValueError): 165 | histogram(tensor_np(shape=(1024,)), bins="doane", max_bins=5) 166 | event = histogram(tensor_np(shape=(1024,)), bins="doane") 167 | assert event.values is not None 168 | assert event.counts is not None 169 | 170 | def test_to_HWC(self): # noqa 171 | np.random.seed(1) 172 | test_image = np.random.randint(0, 256, size=(3, 32, 32), dtype=np.uint8) 173 | converted = convert_to_HWC(test_image, "chw") 174 | assert converted.shape == (32, 32, 3) 175 | test_image = np.random.randint(0, 256, size=(16, 3, 32, 32), dtype=np.uint8) 176 | converted = convert_to_HWC(test_image, "nchw") 177 | assert converted.shape == (64, 256, 3) 178 | test_image = np.random.randint(0, 256, size=(32, 32), dtype=np.uint8) 179 | converted = convert_to_HWC(test_image, "hw") 180 | assert converted.shape == (32, 32, 3) 181 | 182 | def test_prepare_video(self): 183 | # at each timestep the sum over all other dimensions of the video should stay 
the same 184 | np.random.seed(1) 185 | video_before = np.random.random((4, 10, 3, 20, 20)) 186 | video_after = prepare_video(np.copy(video_before)) 187 | video_before = np.swapaxes(video_before, 0, 1) 188 | video_before = np.reshape(video_before, newshape=(10, -1)) 189 | video_after = np.reshape(video_after, newshape=(10, -1)) 190 | np.testing.assert_array_almost_equal( 191 | np.sum(video_before, axis=1), np.sum(video_after, axis=1) 192 | ) 193 | 194 | def test_bokeh_chart(self): 195 | # prepare some data 196 | x = [1, 2, 3, 4, 5] 197 | y = [6, 7, 2, 4, 5] 198 | 199 | # create a new plot with a title and axis labels 200 | p = figure(title="simple line example", x_axis_label="x", y_axis_label="y") 201 | 202 | # add a line renderer with legend and line thickness 203 | p.line(x, y, line_width=2) 204 | 205 | # show the results 206 | event = bokeh_chart(p) 207 | assert isinstance(event.figure, dict) 208 | 209 | def test_plotly_chart(self): 210 | x1 = np.random.randn(200) - 2 211 | x2 = np.random.randn(200) 212 | x3 = np.random.randn(200) + 2 213 | hist_data = [x1, x2, x3] 214 | group_labels = ["Group 1", "Group 2", "Group 3"] 215 | p = figure_factory.create_distplot( 216 | hist_data, group_labels, bin_size=[0.1, 0.25, 0.5] 217 | ) 218 | 219 | # show the results 220 | event = plotly_chart(p) 221 | assert isinstance(event.figure, dict) 222 | -------------------------------------------------------------------------------- /traceml/tests/test_events_processing/test_importance_processors.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pytest 3 | 4 | from unittest import TestCase 5 | 6 | from traceml.processors.importance_processors import calculate_importance_correlation 7 | 8 | 9 | @pytest.mark.processors_mark 10 | class TestFeatureImportance(TestCase): 11 | def test_empty_value(self): 12 | assert calculate_importance_correlation(None, None) is None 13 | 14 | def test_single_value(self): 15 | res = 
calculate_importance_correlation([{"param1": 3}], [4]) 16 | exp = {"param1": {"correlation": None, "importance": 0.0}} 17 | assert res == exp 18 | 19 | def test_correct_values(self): 20 | res = calculate_importance_correlation( 21 | [{"param1": 3}, {"param1": 4}, {"param1": 5}], 22 | [3, 4, 5], 23 | ) 24 | exp = {"param1": {"correlation": 1.0, "importance": 1.0}} 25 | assert res == exp 26 | 27 | def test_multiple_params(self): 28 | res = calculate_importance_correlation( 29 | [ 30 | { 31 | "param1": 1, 32 | "param2": 3, 33 | }, 34 | { 35 | "param1": 2, 36 | "param2": 2, 37 | }, 38 | { 39 | "param1": 3, 40 | "param2": 1, 41 | }, 42 | ], 43 | [1, 2, 3], 44 | ) 45 | exp = { 46 | "param1": {"correlation": 1.0, "importance": 0.464}, 47 | "param2": {"correlation": -1.0, "importance": 0.536}, 48 | } 49 | assert res == exp 50 | 51 | def test_wrong_string_params(self): 52 | assert calculate_importance_correlation(["foo", "bar"], []) is None 53 | 54 | def test_complex_params(self): 55 | res = calculate_importance_correlation( 56 | [{"param1": "str1", "param2": 1}, {"param1": 2, "param2": 2}], [1, 2] 57 | ) 58 | exp = { 59 | "param1_2": {"correlation": 1.0, "importance": 0.348}, 60 | "param1_str1": {"correlation": -1.0, "importance": 0.308}, 61 | "param2": {"correlation": 1.0, "importance": 0.344}, 62 | } 63 | assert res == exp 64 | 65 | def test_nan_value(self): 66 | assert ( 67 | calculate_importance_correlation( 68 | [{"param1": 3, "param2": 1}, {"param1": 2, "param2": 2}], [np.nan, 2] 69 | ) 70 | is None 71 | ) 72 | 73 | def test_empty_metrics(self): 74 | assert calculate_importance_correlation([{"foo": 2, "bar": 4}], []) is None 75 | -------------------------------------------------------------------------------- /traceml/tests/test_logging/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /traceml/tests/test_logging/test_logging_parser.py: 
-------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from clipped.utils.dates import parse_datetime 4 | 5 | from polyaxon._utils.test_utils import BaseTestCase 6 | from traceml.logging.parser import ( 7 | DATETIME_REGEX, 8 | ISO_DATETIME_REGEX, 9 | timestamp_search_regex, 10 | ) 11 | 12 | 13 | @pytest.mark.logging_mark 14 | class TestLoggingUtils(BaseTestCase): 15 | def test_has_timestamp(self): 16 | log_line = "2018-12-11 10:24:57 UTC" 17 | log_value, ts = timestamp_search_regex(DATETIME_REGEX, log_line) 18 | assert ts == parse_datetime("2018-12-11 10:24:57 UTC") 19 | assert log_value == "" 20 | 21 | def test_log_line_has_datetime(self): 22 | log_line = "2018-12-11 10:24:57 UTC foo" 23 | log_value, ts = timestamp_search_regex(DATETIME_REGEX, log_line) 24 | 25 | assert ts == parse_datetime("2018-12-11 10:24:57 UTC") 26 | assert log_value == "foo" 27 | 28 | def test_log_line_has_iso_datetime(self): 29 | log_line = "2018-12-11T08:49:07.163495183Z foo" 30 | 31 | log_value, ts = timestamp_search_regex(ISO_DATETIME_REGEX, log_line) 32 | 33 | assert ts == parse_datetime("2018-12-11T08:49:07.163495183Z") 34 | assert log_value == "foo" 35 | -------------------------------------------------------------------------------- /traceml/tests/test_logging/test_logs.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from clipped.utils.dates import parse_datetime 4 | from clipped.utils.tz import now 5 | 6 | from polyaxon._utils.test_utils import BaseTestCase 7 | from traceml.logging.schemas import V1Log, V1Logs 8 | 9 | 10 | @pytest.mark.logging_mark 11 | class TestLogV1(BaseTestCase): 12 | def test_has_timestamp(self): 13 | parsed = V1Log.process_log_line( 14 | value="foo", 15 | timestamp="2018-12-11 10:24:57 UTC", 16 | node="node1", 17 | pod="pod1", 18 | container="container1", 19 | ) 20 | expected = V1Log( 21 | value="foo", 22 | timestamp=parse_datetime("2018-12-11 
10:24:57 UTC"), 23 | node="node1", 24 | pod="pod1", 25 | container="container1", 26 | ) 27 | assert parsed == expected 28 | 29 | def test_has_no_timestamp(self): 30 | log_result = V1Log.process_log_line( 31 | value="foo", node="node1", pod="pod1", container="container1" 32 | ) 33 | assert log_result.timestamp.date() == now().date() 34 | 35 | def test_has_datetime_timestamp(self): 36 | log_result = V1Log.process_log_line( 37 | timestamp=now(), 38 | value="foo", 39 | node="node1", 40 | pod="pod1", 41 | container="container1", 42 | ) 43 | assert log_result.timestamp.date() == now().date() 44 | 45 | def test_log_line_has_datetime(self): 46 | parsed = V1Log.process_log_line( 47 | value="2018-12-11 10:24:57 UTC foo", 48 | node="node1", 49 | pod="pod1", 50 | container="container1", 51 | ) 52 | expected = V1Log( 53 | value="foo", 54 | timestamp=parse_datetime("2018-12-11 10:24:57 UTC"), 55 | node="node1", 56 | pod="pod1", 57 | container="container1", 58 | ) 59 | assert parsed == expected 60 | 61 | def test_log_line_has_iso_datetime(self): 62 | parsed = V1Log.process_log_line( 63 | value="2018-12-11T08:49:07.163495183Z foo", 64 | node="node1", 65 | pod="pod1", 66 | container="container1", 67 | ) 68 | expected = V1Log( 69 | value="foo", 70 | timestamp=parse_datetime("2018-12-11T08:49:07.163495+00:00"), 71 | node="node1", 72 | pod="pod1", 73 | container="container1", 74 | ) 75 | assert parsed == expected 76 | 77 | def test_to_csv(self): 78 | log_line = V1Log( 79 | value="foo", 80 | timestamp=parse_datetime("2018-12-11 10:24:57 UTC"), 81 | node="node1", 82 | pod="pod1", 83 | container="container1", 84 | ) 85 | 86 | assert log_line.to_csv() == '{}|node1|pod1|container1|{{"_":"foo"}}'.format( 87 | log_line.timestamp 88 | ) 89 | 90 | 91 | class TestLogsV1(BaseTestCase): 92 | def test_logs(self): 93 | logs = V1Logs( 94 | last_file="1679665234.498643", 95 | logs=[ 96 | V1Log( 97 | value="foo", 98 | timestamp=parse_datetime("2018-12-11 10:24:57 UTC"), 99 | node="node1", 100 | 
pod="pod1", 101 | container="container1", 102 | ), 103 | V1Log( 104 | value="foo", 105 | timestamp=parse_datetime("2018-12-11 10:24:57 UTC"), 106 | node="node1", 107 | pod="pod1", 108 | container="container1", 109 | ), 110 | V1Log( 111 | value="foo", 112 | timestamp=parse_datetime("2018-12-11 10:24:57 UTC"), 113 | node="node1", 114 | pod="pod1", 115 | container="container1", 116 | ), 117 | ], 118 | ) 119 | logs_dict = logs.to_light_dict() 120 | assert logs_dict == logs.from_dict(logs_dict).to_light_dict() 121 | assert logs_dict == logs.read(logs.to_json()).to_light_dict() 122 | 123 | def test_logs_with_files(self): 124 | logs = V1Logs( 125 | last_file="1679665234.498643", 126 | last_time=now(), 127 | files=["file1", "file2"], 128 | logs=[ 129 | V1Log( 130 | value="foo", 131 | timestamp=parse_datetime("2018-12-11 10:24:57 UTC"), 132 | node="node1", 133 | pod="pod1", 134 | container="container1", 135 | ), 136 | V1Log( 137 | value="foo", 138 | timestamp=parse_datetime("2018-12-11 10:24:57 UTC"), 139 | node="node1", 140 | pod="pod1", 141 | container="container1", 142 | ), 143 | V1Log( 144 | value="foo", 145 | timestamp=parse_datetime("2018-12-11 10:24:57 UTC"), 146 | node="node1", 147 | pod="pod1", 148 | container="container1", 149 | ), 150 | ], 151 | ) 152 | logs_dict = logs.to_light_dict() 153 | assert logs_dict == logs.from_dict(logs_dict).to_light_dict() 154 | assert logs_dict == logs.read(logs.to_json()).to_light_dict() 155 | 156 | def test_chunk_logs(self): 157 | logs = [ 158 | V1Log( 159 | value="foo1", 160 | timestamp=parse_datetime("2018-12-11 10:24:57 UTC"), 161 | node="node1", 162 | pod="pod1", 163 | container="container1", 164 | ), 165 | V1Log( 166 | value="foo2", 167 | timestamp=parse_datetime("2018-12-11 10:24:57 UTC"), 168 | node="node1", 169 | pod="pod1", 170 | container="container1", 171 | ), 172 | V1Log( 173 | value="foo3", 174 | timestamp=parse_datetime("2018-12-11 10:24:57 UTC"), 175 | node="node1", 176 | pod="pod1", 177 | container="container1", 
178 | ), 179 | ] 180 | 181 | V1Logs._CHUNK_SIZE = 2 182 | chunks = [c for c in V1Logs.chunk_logs(logs)] 183 | # 1 chunk 184 | assert [i.value for i in chunks[0].logs] == ["foo1", "foo2"] 185 | 186 | # 2 chunk 187 | assert [i.value for i in chunks[1].logs] == ["foo3"] 188 | 189 | def test_logs_csv_header(self): 190 | assert V1Logs.get_csv_header() == "timestamp|node|pod|container|value" 191 | 192 | def test_logs_to_csv(self): 193 | logs = V1Logs( 194 | last_file="1679665234.498643", 195 | logs=[ 196 | V1Log( 197 | value="foo", 198 | timestamp=parse_datetime("2018-12-11 10:24:57 UTC"), 199 | node="node1", 200 | pod="pod1", 201 | container="container1", 202 | ), 203 | V1Log( 204 | value="foo", 205 | timestamp=parse_datetime("2018-12-11 10:24:57 UTC"), 206 | node="node1", 207 | pod="pod1", 208 | container="container1", 209 | ), 210 | V1Log( 211 | value="foo", 212 | timestamp=parse_datetime("2018-12-11 10:24:57 UTC"), 213 | node="node1", 214 | pod="pod1", 215 | container="container1", 216 | ), 217 | ], 218 | ) 219 | assert logs.to_csv() == "".join( 220 | [ 221 | "\n{}".format(logs.logs[0].to_csv()), 222 | "\n{}".format(logs.logs[1].to_csv()), 223 | "\n{}".format(logs.logs[2].to_csv()), 224 | ] 225 | ) 226 | -------------------------------------------------------------------------------- /traceml/tests/test_serialization/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /traceml/tests/test_summary/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /traceml/tests/test_summary/test_dfsummary.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pandas as pd 3 | 4 | from pandas.testing import assert_series_equal 5 | from random import shuffle 6 | 
from unittest import TestCase

from clipped.utils.units import to_percentage

from traceml.processors import df_processors
from traceml.summary.df import DataFrameSummary


class DataFrameSummaryTest(TestCase):
    """Exercises DataFrameSummary against a synthetic 1000-row frame that
    contains one column per type the summarizer distinguishes (numeric,
    bool, categorical, constant, unique, date) plus a numeric column with
    10% missing values.
    """

    def setUp(self):
        self.size = 1000
        # 10% NaNs mixed into repeated 0-9 values, then shuffled so the
        # missing entries are spread through the column.
        missing = [np.nan] * (self.size // 10) + list(range(10)) * (
            (self.size - self.size // 10) // 10
        )
        shuffle(missing)

        # Index of all type labels, used to align expected/actual series.
        self.types = pd.Index(
            [
                DataFrameSummary.TYPE_NUMERIC,
                DataFrameSummary.TYPE_BOOL,
                DataFrameSummary.TYPE_CATEGORICAL,
                DataFrameSummary.TYPE_CONSTANT,
                DataFrameSummary.TYPE_UNIQUE,
                DataFrameSummary.TYPE_DATE,
            ],
            name="types",
        )

        self.columns = [
            "dbool1",
            "dbool2",
            "duniques",
            "dcategoricals",
            "dnumerics1",
            "dnumerics2",
            "dnumerics3",
            "dmissing",
            "dconstant",
            "ddates",
        ]

        self.df = pd.DataFrame(
            dict(
                dbool1=np.random.choice([0, 1], size=self.size),
                dbool2=np.random.choice(["a", "b"], size=self.size),
                duniques=["x{}".format(i) for i in range(self.size)],
                # ~500 "a", the rest split between "b" and "c" -> 3 uniques.
                dcategoricals=[
                    "a" if i % 2 == 0 else "b" if i % 3 == 0 else "c"
                    for i in range(self.size)
                ],
                dnumerics1=range(self.size),
                dnumerics2=range(self.size, 2 * self.size),
                dnumerics3=list(range(self.size - self.size // 10))
                + list(range(-self.size // 10, 0)),
                dmissing=missing,
                dconstant=["a"] * self.size,
                # NOTE(review): the "1M" frequency alias is deprecated in
                # pandas >= 2.2 in favor of "ME" -- confirm pinned pandas.
                ddates=pd.date_range("2010-01-01", periods=self.size, freq="1M"),
            )
        )

        self.dfs = DataFrameSummary(self.df)

    def test_get_columns_works_as_expected(self):
        """get_columns honors ALL / INCLUDE / EXCLUDE selection modes."""
        assert len(self.dfs.get_columns(self.df, DataFrameSummary.ALL)) == 10

        assert (
            len(
                self.dfs.get_columns(
                    self.df,
                    DataFrameSummary.INCLUDE,
                    ["dnumerics1", "dnumerics2", "dnumerics3"],
                )
            )
            == 3
        )

        assert (
            len(
                self.dfs.get_columns(
                    self.df,
                    DataFrameSummary.EXCLUDE,
                    ["dnumerics1", "dnumerics2", "dnumerics3"],
                )
            )
            == 7
        )

    def test_column_types_works_as_expected(self):
        """columns_types counts columns per detected type
        (4 numeric, 2 bool, 1 of each remaining type)."""
        result = self.dfs.columns_types[self.types]
        expected = pd.Series(
            index=self.types, data=[4, 2, 1, 1, 1, 1], name=result.name
        )[self.types]
        assert_series_equal(result, expected)

    def test_column_stats_works_as_expected(self):
        """columns_stats exposes 5 rows (counts, uniques, missing,
        missing_perc, types) across the 10 columns."""
        column_stats = self.dfs.columns_stats
        self.assertTupleEqual(column_stats.shape, (5, 10))

        # counts: all columns full except dmissing (100 NaNs)
        expected = pd.Series(
            index=self.columns, data=self.size, name="counts", dtype="object"
        )
        expected["dmissing"] -= 100
        assert_series_equal(
            column_stats[self.columns].loc["counts"], expected[self.columns]
        )

        # uniques: default to size, then override the non-unique columns
        expected = pd.Series(
            index=self.columns, data=self.size, name="uniques", dtype="object"
        )
        expected[["dbool1", "dbool2"]] = 2
        expected[["dcategoricals"]] = 3
        expected[["dconstant"]] = 1
        expected[["dmissing"]] = 10
        assert_series_equal(
            column_stats[self.columns].loc["uniques"].sort_index(),
            expected[self.columns].sort_index(),
            check_dtype=False,
        )

        # missing: only dmissing has NaNs
        expected = pd.Series(index=self.columns, data=0, name="missing", dtype="object")
        expected[["dmissing"]] = 100
        assert_series_equal(
            column_stats[self.columns].loc["missing"],
            expected[self.columns],
            check_dtype=False,
        )

        # missing_perc: rendered as strings, e.g. "10%"
        expected = pd.Series(
            index=self.columns, data=["0%"] * 10, name="missing_perc", dtype="object"
        )

        expected[["dmissing"]] = "10%"
        assert_series_equal(
            column_stats[self.columns].loc["missing_perc"], expected[self.columns]
        )

        # types: one expected label per column
        expected = pd.Series(
            index=self.columns, data=[np.nan] * 10, name="types", dtype="object"
        )

        expected[["dbool1", "dbool2"]] = DataFrameSummary.TYPE_BOOL
        expected[["dcategoricals"]] = DataFrameSummary.TYPE_CATEGORICAL
        expected[["dconstant"]] = DataFrameSummary.TYPE_CONSTANT
        expected[["ddates"]] = DataFrameSummary.TYPE_DATE
        expected[["duniques"]] = DataFrameSummary.TYPE_UNIQUE
        expected[
            ["dnumerics1", "dnumerics2", "dnumerics3", "dmissing"]
        ] = DataFrameSummary.TYPE_NUMERIC  # fmt: skip
        assert_series_equal(
            column_stats[self.columns].loc["types"], expected[self.columns]
        )

    def test_uniques_summary(self):
        """Per-column summary for a fully-unique column."""
        expected = pd.Series(
            index=["counts", "uniques", "missing", "missing_perc", "types"],
            data=[self.size, self.size, 0, "0%", DataFrameSummary.TYPE_UNIQUE],
            name="duniques",
            dtype=object,
        )
        assert_series_equal(self.dfs["duniques"], expected)

    def test_constant_summary(self):
        """Constant columns collapse to a descriptive string."""
        self.assertEqual(self.dfs["dconstant"], "This is a constant value: a")

    def test_bool1_summary(self):
        """Boolean (0/1) column: per-value counts and percentages."""
        count_values = self.df["dbool1"].value_counts()
        total_count = self.df["dbool1"].count()
        count0 = count_values[0]
        count1 = count_values[1]
        perc0 = to_percentage(count0 / total_count)
        perc1 = to_percentage(count1 / total_count)
        expected = pd.Series(
            index=[
                '"0" count',
                '"0" perc',
                '"1" count',
                '"1" perc',
                "counts",
                "uniques",
                "missing",
                "missing_perc",
                "types",
            ],
            data=[
                str(count0),
                perc0,
                str(count1),
                perc1,
                self.size,
                2,
                0,
                "0%",
                DataFrameSummary.TYPE_BOOL,
            ],
            name="dbool1",
            dtype=object,
        )

        assert_series_equal(self.dfs["dbool1"], expected)

    def test_bool2_summary(self):
        """Boolean-like ("a"/"b") column: same shape as test_bool1_summary."""
        count_values = self.df["dbool2"].value_counts()
        total_count = self.df["dbool2"].count()
        count0 = count_values["a"]
        count1 = count_values["b"]
        perc0 = to_percentage(count0 / total_count)
        perc1 = to_percentage(count1 / total_count)
        expected = pd.Series(
            index=[
                '"a" count',
                '"a" perc',
                '"b" count',
                '"b" perc',
                "counts",
                "uniques",
                "missing",
                "missing_perc",
                "types",
            ],
            data=[
                str(count0),
                perc0,
                str(count1),
                perc1,
                self.size,
                2,
                0,
                "0%",
                DataFrameSummary.TYPE_BOOL,
            ],
            name="dbool2",
            dtype=object,
        )

        assert_series_equal(self.dfs["dbool2"], expected)

    def test_categorical_summary(self):
        """Categorical column reports the top value with its count."""
        expected = pd.Series(
            index=["top", "counts", "uniques", "missing", "missing_perc", "types"],
            data=["a: 500", self.size, 3, 0, "0%", DataFrameSummary.TYPE_CATEGORICAL],
            name="dcategoricals",
            dtype=object,
        )

        assert_series_equal(self.dfs["dcategoricals"], expected)

    def test_dates_summary(self):
        """Date column reports min/max and their range (a Timedelta)."""
        dmin = self.df["ddates"].min()
        dmax = self.df["ddates"].max()
        expected = pd.Series(
            index=[
                "max",
                "min",
                "range",
                "counts",
                "uniques",
                "missing",
                "missing_perc",
                "types",
            ],
            data=[
                dmax,
                dmin,
                dmax - dmin,
                self.size,
                self.size,
                0,
                "0%",
                DataFrameSummary.TYPE_DATE,
            ],
            name="ddates",
            dtype=object,
        ).sort_index()

        # sort_index on both sides: summary row order is not guaranteed.
        tmp = self.dfs["ddates"].sort_index()
        assert_series_equal(tmp, expected)

    def test_numerics_summary(self):
        """Numeric column: full descriptive-statistics block, with the
        deviation stats delegated to df_processors."""
        num1 = self.df["dnumerics1"]
        dm, dmp = df_processors.get_deviation_of_mean(num1)
        dam, damp = df_processors.get_median_absolute_deviation(num1)
        expected = pd.Series(
            index=[
                "mean",
                "std",
                "variance",
                "min",
                "max",
                "mode",
                "5%",
                "25%",
                "50%",
                "75%",
                "95%",
                "iqr",
                "kurtosis",
                "skewness",
                "sum",
                "mad",
                "cv",
                "zeros_num",
                "zeros_perc",
                "deviating_of_mean",
                "deviating_of_mean_perc",
                "deviating_of_median",
                "deviating_of_median_perc",
                "counts",
                "uniques",
                "missing",
                "missing_perc",
                "types",
            ],
            data=[
                num1.mean(),
                num1.std(),
                num1.var(),
                num1.min(),
                num1.max(),
                num1.mode()[0],
                num1.quantile(0.05),
                num1.quantile(0.25),
                num1.quantile(0.5),
                num1.quantile(0.75),
                num1.quantile(0.95),
                num1.quantile(0.75) - num1.quantile(0.25),
                num1.kurt(),
                num1.skew(),
                num1.sum(),
                df_processors.mad(num1),
                # coefficient of variation; guarded against a zero mean
                num1.std() / num1.mean() if num1.mean() else np.nan,
                self.size - np.count_nonzero(num1),
                to_percentage((self.size - np.count_nonzero(num1)) / self.size),
                dm,
                dmp,
                dam,
                damp,
                self.size,
                self.size,
                0,
                "0%",
                DataFrameSummary.TYPE_NUMERIC,
            ],
            name="dnumerics1",
            dtype=object,
        )

        assert_series_equal(self.dfs["dnumerics1"], expected)
from typing import Optional, Union

from clipped.utils.enums import PEnum


class V1ArtifactKind(str, PEnum):
    """Enumerates every artifact kind the tracking layer can record.

    The str mixin means members compare equal to their raw string values,
    so the classmethods below accept either a member or a plain string.
    """

    MODEL = "model"
    AUDIO = "audio"
    VIDEO = "video"
    HISTOGRAM = "histogram"
    IMAGE = "image"
    TENSOR = "tensor"
    DATAFRAME = "dataframe"
    CHART = "chart"
    CSV = "csv"
    TSV = "tsv"
    PSV = "psv"
    SSV = "ssv"
    METRIC = "metric"
    ENV = "env"
    HTML = "html"
    TEXT = "text"
    FILE = "file"
    DIR = "dir"
    DOCKERFILE = "dockerfile"
    DOCKER_IMAGE = "docker_image"
    DATA = "data"
    CODEREF = "coderef"
    TABLE = "table"
    TENSORBOARD = "tensorboard"
    CURVE = "curve"
    CONFUSION = "confusion"
    ANALYSIS = "analysis"
    ITERATION = "iteration"
    MARKDOWN = "markdown"
    SYSTEM = "system"
    ARTIFACT = "artifact"
    SPAN = "span"

    @classmethod
    def is_jsonl_file_event(cls, kind: Optional[Union["V1ArtifactKind", str]]) -> bool:
        """True if events of this kind are stored as JSON-lines files."""
        return kind in {
            V1ArtifactKind.HTML,
            V1ArtifactKind.TEXT,
            V1ArtifactKind.CHART,
            V1ArtifactKind.SPAN,
        }

    @classmethod
    def is_single_file_event(cls, kind: Optional[Union["V1ArtifactKind", str]]) -> bool:
        """True if events of this kind always live in exactly one file.

        Superset of the JSONL kinds plus scalar-style events (metric,
        histogram, curve, confusion, system).
        """
        return kind in {
            V1ArtifactKind.HTML,
            V1ArtifactKind.TEXT,
            V1ArtifactKind.HISTOGRAM,
            V1ArtifactKind.CHART,
            V1ArtifactKind.CONFUSION,
            V1ArtifactKind.CURVE,
            V1ArtifactKind.METRIC,
            V1ArtifactKind.SYSTEM,
            V1ArtifactKind.SPAN,
        }

    @classmethod
    def is_single_or_multi_file_event(
        cls, kind: Optional[Union["V1ArtifactKind", str]]
    ) -> bool:
        """True for asset-backed kinds that may span one or several files
        (media, tabular dumps, models, dataframes)."""
        return kind in {
            V1ArtifactKind.MODEL,
            V1ArtifactKind.DATAFRAME,
            V1ArtifactKind.AUDIO,
            V1ArtifactKind.VIDEO,
            V1ArtifactKind.IMAGE,
            V1ArtifactKind.CSV,
            V1ArtifactKind.TSV,
            V1ArtifactKind.PSV,
            V1ArtifactKind.SSV,
        }

    @classmethod
    def is_dir(cls, kind: Optional[Union["V1ArtifactKind", str]]) -> bool:
        """True if the artifact is a directory tree."""
        return kind in {
            V1ArtifactKind.TENSORBOARD,
            V1ArtifactKind.DIR,
        }

    @classmethod
    def is_file(cls, kind: Optional[Union["V1ArtifactKind", str]]) -> bool:
        """True if the artifact is a single plain file."""
        return kind in {
            V1ArtifactKind.DOCKERFILE,
            V1ArtifactKind.FILE,
            V1ArtifactKind.ENV,
        }

    @classmethod
    def is_file_or_dir(cls, kind: Optional[Union["V1ArtifactKind", str]]) -> bool:
        """True for kinds that may be either a file or a directory
        (data dumps and models)."""
        return kind in {
            V1ArtifactKind.DATA,
            V1ArtifactKind.MODEL,
        }
def _expand_kind_path(base_path: str, kind: Optional[str], name: Optional[str]) -> str:
    """Append optional ``<kind>`` and ``<name>.plx`` segments to a base path.

    Shared by get_resource_path/get_event_path, which previously duplicated
    this logic verbatim.
    """
    if kind:
        base_path = "{}/{}".format(base_path, get_enum_value(kind))
    if name:
        base_path = "{}/{}.plx".format(base_path, name)
    return base_path


def get_resource_path(
    run_path: str, kind: Optional[str] = None, name: Optional[str] = None
) -> str:
    """Return the path of a resource event file under ``run_path``.

    Layout: ``<run_path>/resources[/<kind>[/<name>.plx]]``.
    """
    return _expand_kind_path("{}/resources".format(run_path), kind, name)


def get_event_path(
    run_path: str, kind: Optional[str] = None, name: Optional[str] = None
) -> str:
    """Return the path of an event file under ``run_path``.

    Layout: ``<run_path>/events[/<kind>[/<name>.plx]]``.
    """
    return _expand_kind_path("{}/events".format(run_path), kind, name)


def get_event_assets_path(run_path: str, kind: Optional[str] = None) -> str:
    """Return the asset directory for a run, optionally per kind.

    Layout: ``<run_path>/assets[/<kind>]``.
    """
    _path = "{}/assets".format(run_path)
    if kind:
        _path = "{}/{}".format(_path, get_enum_value(kind))
    return _path


def get_asset_path(
    run_path: Optional[str],
    kind: Optional[str] = None,
    name: Optional[str] = None,
    step: Optional[int] = None,
    ext: Optional[str] = None,
) -> str:
    """Return the path of a single asset file.

    Layout: ``<run_path>/assets[/<kind>][/<name>[_<step>][.<ext>]]``.

    Raises:
        ValueError: if ``run_path`` is falsy.
    """
    if not run_path:
        raise ValueError("run_path must be provided to get asset path.")
    _path = get_event_assets_path(run_path, kind)
    if name:
        _path = "{}/{}".format(_path, name)
    # `is not None` so that step 0 is still appended.
    if step is not None:
        _path = "{}_{}".format(_path, step)
    if ext:
        _path = "{}.{}".format(_path, ext)

    return _path
    @client_handler(check_no_op=True)
    def before_fit(self):
        """Log run inputs and a model summary file before training starts.

        Each logging step is individually wrapped in try/except so a
        failure to log never interrupts training (best-effort logging).
        """
        if not self.plx_run:
            return
        try:
            self.plx_run.log_inputs(
                n_epoch=str(self.learn.n_epoch),
                # NOTE(review): `type(self.learn.model.__name__)` takes the
                # type of the name string (always <class 'str'>); presumably
                # `type(self.learn.model).__name__` was intended -- confirm.
                model_class=str(type(self.learn.model.__name__)),
            )
        except Exception:  # noqa
            print("Did not log all properties to Polyaxon.")

        try:
            # Persist repr(model) as an output artifact and reference it.
            model_summary_path = self.plx_run.get_outputs_path("model_summary.txt")
            with open(model_summary_path, "w") as g:
                g.write(repr(self.learn.model))
            self.plx_run.log_file_ref(
                path=model_summary_path, name="model_summary", is_input=False
            )
        except Exception:  # noqa
            print(
                "Did not log model summary. "
                "Check if your model is PyTorch model and that Polyaxon has correctly initialized "
                "the artifacts/outputs path."
            )

        # Model logging relies on fastai's SaveModelCallback writing
        # checkpoints; warn early if it is not attached.
        if self.log_model and not hasattr(self.learn, "save_model"):
            print(
                "Unable to log model to Polyaxon.\n",
                'Use "SaveModelCallback" to save model checkpoints '
                "that will be logged to Polyaxon.",
            )

    @client_handler(check_no_op=True)
    def after_batch(self):
        """Log per-batch loss values and optimizer hyper-parameters.

        Only runs for training batches; maintains a monotonically
        increasing step counter (`_plx_step`) across the whole fit.
        """
        # log loss and opt.hypers
        if self.training:
            self._plx_step += 1
            metrics = {}
            # to_detach copies tensors off the graph before logging.
            if hasattr(self, "smooth_loss"):
                metrics["smooth_loss"] = to_detach(self.smooth_loss.clone())
            if hasattr(self, "loss"):
                metrics["raw_loss"] = to_detach(self.loss.clone())
            if hasattr(self, "train_iter"):
                metrics["train_iter"] = self.train_iter
            # NOTE(review): all param groups write the same `hypers_{k}` key,
            # so only the last group's values survive -- confirm intended.
            for i, h in enumerate(self.learn.opt.hypers):
                for k, v in h.items():
                    metrics[f"hypers_{k}"] = v
            self.plx_run.log_metrics(step=self._plx_step, **metrics)

    @client_handler(check_no_op=True)
    def after_epoch(self):
        """Log epoch-level recorder metrics and, optionally, model weights."""
        # log metrics; skip values already logged per-batch or not metrics
        self.plx_run.log_metrics(
            step=self._plx_step,
            **{
                n: v
                for n, v in zip(self.recorder.metric_names, self.recorder.log)
                if n not in ["train_loss", "epoch", "time"]
            },
        )

        # log model weights saved by fastai's SaveModelCallback
        if self.log_model and hasattr(self.learn, "save_model"):
            if self.learn.save_model.every_epoch:
                # per-epoch checkpoints are suffixed with the epoch number
                _file = join_path_file(
                    f"{self.learn.save_model.fname}_{self.learn.save_model.epoch}",
                    self.learn.path / self.learn.model_dir,
                    ext=".pth",
                )
                self.plx_run.log_model(
                    _file, framework="fastai", step=self.learn.save_model.epoch
                )
            else:
                # single rolling checkpoint, logged unversioned
                _file = join_path_file(
                    self.learn.save_model.fname,
                    self.learn.path / self.learn.model_dir,
                    ext=".pth",
                )
                self.plx_run.log_model(_file, framework="fastai", versioned=False)
class Callback(TrainerCallback):
    """Transformers TrainerCallback that forwards args/metrics to Polyaxon.

    Params:
        run: optional pre-existing tracking run; when omitted, one is
            resolved lazily by `setup()` on the first trainer event.
    """

    def __init__(
        self,
        run=None,
    ):
        super().__init__()
        self.run = run

    def _log_model_summary(self, model):
        """Persist a textual model summary as an output artifact."""
        summary, filetype = events_processors.model_to_str(model)
        if not summary:
            return
        rel_path = self.run.get_outputs_path("model_summary.{}".format(filetype))
        with open(rel_path, "w") as f:
            f.write(summary)
        self.run.log_file_ref(path=rel_path, name="model_summary", is_input=False)

    @client_handler(check_no_op=True)
    def setup(self, args, state, model, **kwargs):
        """Resolve the tracking run and log model summary + trainer args.

        Only the world-zero process logs, to avoid duplicates in
        distributed training.
        """
        # Fix: fall back to the run passed at construction time. The
        # original always used kwargs.get("run"), which is None when setup
        # is invoked from on_train_begin/on_log, silently discarding a run
        # supplied to __init__.
        self.run = tracking.get_or_create_run(kwargs.get("run") or self.run)
        if state.is_world_process_zero:
            self._log_model_summary(model)
            self.run.log_inputs(**args.to_sanitized_dict())

    @client_handler(check_no_op=True)
    def on_train_begin(self, args, state, control, model=None, **kwargs):
        # Lazy init: setup may not have run if the callback was
        # constructed without a run.
        if not self.run:
            self.setup(args, state, model)

    @client_handler(check_no_op=True)
    def on_log(self, args, state, control, logs, model=None, **kwargs):
        """Forward the trainer's scalar logs as Polyaxon metrics."""
        if not self.run:
            self.setup(args, state, model)
        if state.is_world_process_zero:
            metrics = {}
            for k, v in logs.items():
                # Only numeric values are valid metrics; warn and drop the rest.
                if isinstance(v, (int, float)):
                    metrics[k] = v
                else:
                    logger.warning(
                        f"Trainer is attempting to log a value of "
                        f'"{v}" of type {type(v)} for key "{k}" as a metric. '
                        f"Polyaxon's log_metrics() only accepts float and "
                        f"int types so we dropped this attribute."
                    )
            self.run.log_metrics(**metrics, step=state.global_step)
    @client_handler(check_no_op=True)
    def on_train_begin(self, logs=None):  # pylint: disable=unused-argument
        """Log model/optimizer hyper-parameters and a text model summary.

        Every probe is wrapped in try/except: logging must never break
        training, and not all models expose all attributes.
        """
        if not self.run:
            return

        params = {}

        try:
            params["num_layers"] = len(self.model.layers)
        except Exception:  # noqa
            pass

        try:
            params["optimizer_name"] = type(self.model.optimizer).__name__
        except Exception:  # noqa
            pass

        try:
            # lr may be a plain float or a backend tensor; eval the latter.
            if hasattr(self.model.optimizer, "lr"):
                params["optimizer_lr"] = sanitize_np_types(
                    self.model.optimizer.lr
                    if type(self.model.optimizer.lr) is float
                    else keras.backend.eval(self.model.optimizer.lr)
                )
        except Exception:  # noqa
            pass

        try:
            if hasattr(self.model.optimizer, "epsilon"):
                params["optimizer_epsilon"] = sanitize_np_types(
                    self.model.optimizer.epsilon
                    if type(self.model.optimizer.epsilon) is float
                    else keras.backend.eval(self.model.optimizer.epsilon)
                )
        except Exception:  # noqa
            pass

        if params:
            self.run.log_inputs(**params)

        try:
            # Capture keras' own summary printer output into a file artifact.
            sum_list = []
            self.model.summary(print_fn=sum_list.append)
            summary = "\n".join(sum_list)
            rel_path = self.run.get_outputs_path("model_summary.txt")
            with open(rel_path, "w") as f:
                f.write(summary)
            self.run.log_file_ref(path=rel_path, name="model_summary", is_input=False)
        except Exception:  # noqa
            pass

    @client_handler(check_no_op=True)
    def on_epoch_end(self, epoch, logs=None):
        """Log epoch metrics; track and persist the best model so far.

        The "best" bookkeeping mirrors keras' ModelCheckpoint: compare the
        monitored metric with `monitor_op`, record `best_*` metrics, and
        optionally save the model when it improves.
        """
        if not logs or not self.run:
            return

        if self.metrics:
            # restrict to the user-selected metric names
            metrics = {
                metric: logs[metric] for metric in self.metrics if metric in logs
            }
        else:
            metrics = logs  # Log all metrics

        self.current = logs.get(self.monitor)
        if self.current and self.monitor_op(self.current, self.best):
            if self.log_best_prefix:
                metrics[f"{self.log_best_prefix}_{self.monitor}"] = self.current
                metrics[f"{self.log_best_prefix}_epoch"] = epoch
            if self.log_model:
                self._log_model()
            self.best = self.current

        self.run.log_metrics(step=epoch, **metrics)

    @client_handler(check_no_op=True)
    def on_train_end(self, logs=None):  # pylint: disable=unused-argument
        """Save the final model unless a best model was already logged."""
        if not self.log_model:
            return

        if self.run._has_meta_key("has_model"):  # noqa
            # Best model was already saved
            return

        self._log_model()

    @client_handler(check_no_op=True)
    def _log_model(self):
        """Save weights or the full model to `self.filepath` and register
        it as a model artifact (once). Disables model logging on known
        save failures instead of raising."""
        try:
            if self.save_weights_only:
                self.model.save_weights(self.filepath, overwrite=True)
            else:
                self.model.save(self.filepath, overwrite=True)
            if not self.run._has_meta_key("has_model"):  # noqa
                self.run.log_model_ref(self.filepath, name="model", framework="keras")
        # `RuntimeError: Unable to create link` in TF 1.13.1
        # also saw `TypeError: can't pickle _thread.RLock objects`
        except (ImportError, RuntimeError, TypeError) as e:
            logger.warning("Can't save model, h5py returned error: %s" % e)
            self.log_model = False
    @classmethod
    def _convert_run_to_span(cls, run: Run) -> "V1EventSpan":
        """Base utility to create a span from a run.

        Maps run timing/error state onto span status; a run error becomes
        a FAILED status condition carrying the error message.

        Params:
            run: The run to convert.
        Returns:
            The converted V1EventSpan.
        """
        metadata = {**run.extra} if run.extra else {}
        metadata["execution_order"] = run.execution_order

        status_conditions = (
            [
                V1StatusCondition.construct(
                    type=V1Statuses.FAILED,
                    status=True,
                    reason="SpanFailed",
                    message=run.error,
                    last_transition_time=run.end_time,
                    last_update_time=run.end_time,
                )
            ]
            if run.error
            else None
        )
        return V1EventSpan(
            uuid=str(run.id) if run.id is not None else None,
            name=run.name,
            started_at=run.start_time,
            finished_at=run.end_time,
            status=V1Statuses.SUCCEEDED if run.error is None else V1Statuses.FAILED,
            status_conditions=status_conditions,
            metadata=metadata,
        )

    @classmethod
    def _convert_llm_run_to_span(cls, run: Run) -> "V1EventSpan":
        """Converts a LangChain LLM Run into a V1EventSpan.

        Raw inputs/outputs are copied as-is; the LLM-specific output
        payload is preserved under metadata["llm_output"].

        Params:
            run: The LangChain LLM Run to convert.
        Returns:
            The converted V1EventSpan.
        """
        base_span = cls._convert_run_to_span(run)
        if base_span.metadata is None:
            base_span.metadata = {}
        base_span.inputs = run.inputs
        base_span.outputs = run.outputs
        base_span.metadata["llm_output"] = run.outputs.get("llm_output", {})
        base_span.kind = V1EventSpanKind.LLM
        return base_span

    @classmethod
    def _convert_chain_run_to_span(cls, run: Run) -> "V1EventSpan":
        """Converts a LangChain Chain Run into a V1EventSpan.

        Recurses into child runs; a chain whose name contains "agent" is
        classified as an AGENT span rather than a CHAIN span.

        Params:
            run: The LangChain Chain Run to convert.
        Returns:
            The converted V1EventSpan.
        """
        base_span = cls._convert_run_to_span(run)

        base_span.inputs = _serialize_io(run.inputs)
        base_span.outputs = _serialize_io(run.outputs)
        base_span.children = [
            cls._convert_lc_run_to_span(child_run) for child_run in run.child_runs
        ]
        base_span.kind = (
            V1EventSpanKind.AGENT
            if "agent" in run.name.lower()
            else V1EventSpanKind.CHAIN
        )

        return base_span

    @classmethod
    def _convert_tool_run_to_span(cls, run: Run) -> "V1EventSpan":
        """Converts a LangChain Tool Run into a V1EventSpan.

        Like chains, tools serialize their I/O and recurse into children.

        Params:
            run: The LangChain Tool Run to convert.
        Returns:
            The converted V1EventSpan.
        """
        base_span = cls._convert_run_to_span(run)
        base_span.inputs = _serialize_io(run.inputs)
        base_span.outputs = _serialize_io(run.outputs)
        base_span.children = [
            cls._convert_lc_run_to_span(child_run) for child_run in run.child_runs
        ]
        base_span.kind = V1EventSpanKind.TOOL

        return base_span

    @classmethod
    def _convert_lc_run_to_span(cls, run: Run) -> "V1EventSpan":
        """Utility to convert any generic LangChain Run into a V1EventSpan.

        Dispatches on run.run_type; unknown types fall back to the base
        converter (no I/O serialization, no children).

        Params:
            run: The LangChain Run to convert.
        Returns:
            The converted V1EventSpan.
        """
        if run.run_type == V1EventSpanKind.LLM:
            return cls._convert_llm_run_to_span(run)
        elif run.run_type == V1EventSpanKind.CHAIN:
            return cls._convert_chain_run_to_span(run)
        elif run.run_type == V1EventSpanKind.TOOL:
            return cls._convert_tool_run_to_span(run)
        else:
            return cls._convert_run_to_span(run)
def callback(run=None):
    """Create a LightGBM callback that logs evaluation metrics on a run.

    Params:
        run: An optional traceml run; resolved via ``get_or_create_run``.
    Returns:
        A callable suitable for LightGBM's ``callbacks=[...]`` argument.
    """
    run = tracking.get_or_create_run(run)

    def _callback(env):
        # Each entry is (data_name, eval_name, value, is_higher_better);
        # metrics are keyed as "<data_name>-<eval_name>".
        metrics = {
            "{}-{}".format(data_name, eval_name): value
            for data_name, eval_name, value, _ in env.evaluation_result_list
        }
        run.log_metrics(step=env.iteration, **metrics)

    return _callback
packaging.version.parse("1.9"): 26 | from pytorch_lightning.utilities.logger import ( 27 | _add_prefix, 28 | _convert_params, 29 | _flatten_dict, 30 | _sanitize_callable_params, 31 | ) 32 | else: 33 | from lightning_fabric.utilities.logger import ( 34 | _add_prefix, 35 | _convert_params, 36 | _flatten_dict, 37 | _sanitize_callable_params, 38 | ) 39 | from pytorch_lightning.utilities.model_summary import ModelSummary 40 | from pytorch_lightning.utilities.rank_zero import rank_zero_only, rank_zero_warn 41 | except ImportError: 42 | raise TracemlException("PytorchLightning is required to use the tracking Callback") 43 | 44 | 45 | class Callback(Logger): 46 | LOGGER_JOIN_CHAR = "_" 47 | 48 | def __init__( 49 | self, 50 | owner: Optional[str] = None, 51 | project: Optional[str] = None, 52 | run_uuid: Optional[str] = None, 53 | client: RunClient = None, 54 | track_code: bool = True, 55 | track_env: bool = True, 56 | refresh_data: bool = False, 57 | artifacts_path: Optional[str] = None, 58 | collect_artifacts: Optional[str] = None, 59 | collect_resources: Optional[str] = None, 60 | is_offline: Optional[bool] = None, 61 | is_new: Optional[bool] = None, 62 | name: Optional[str] = None, 63 | description: Optional[str] = None, 64 | tags: Optional[List[str]] = None, 65 | end_on_finalize: bool = False, 66 | prefix: str = "", 67 | ): 68 | super().__init__() 69 | self._owner = owner 70 | self._project = project 71 | self._run_uuid = run_uuid 72 | self._client = client 73 | self._track_code = track_code 74 | self._track_env = track_env 75 | self._refresh_data = refresh_data 76 | self._artifacts_path = artifacts_path 77 | self._collect_artifacts = collect_artifacts 78 | self._collect_resources = collect_resources 79 | self._is_offline = is_offline 80 | self._is_new = is_new 81 | self._name = name 82 | self._description = description 83 | self._tags = tags 84 | self._end_on_finalize = end_on_finalize 85 | self._prefix = prefix 86 | self._experiment = None 87 | 88 | @property 89 | 
@rank_zero_experiment 90 | def experiment(self) -> tracking.Run: 91 | if self._experiment: 92 | return self._experiment 93 | tracking.init( 94 | owner=self._owner, 95 | project=self._project, 96 | run_uuid=self._run_uuid, 97 | client=self._client, 98 | track_code=self._track_code, 99 | track_env=self._track_env, 100 | refresh_data=self._refresh_data, 101 | artifacts_path=self._artifacts_path, 102 | collect_artifacts=self._collect_artifacts, 103 | collect_resources=self._collect_resources, 104 | is_offline=self._is_offline, 105 | is_new=self._is_new, 106 | name=self._name, 107 | description=self._description, 108 | tags=self._tags, 109 | ) 110 | self._experiment = tracking.TRACKING_RUN 111 | return self._experiment 112 | 113 | @rank_zero_only 114 | def log_hyperparams(self, params: Union[Dict[str, Any], Namespace]): 115 | params = _convert_params(params) 116 | params = _flatten_dict(params) 117 | params = _sanitize_callable_params(params) 118 | self.experiment.log_inputs(**params) 119 | 120 | @rank_zero_only 121 | def log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None): 122 | assert rank_zero_only.rank == 0, "experiment tried to log from global_rank != 0" 123 | metrics = _add_prefix(metrics, self._prefix, self.LOGGER_JOIN_CHAR) 124 | self.experiment.log_metrics(**metrics, step=step) 125 | 126 | @rank_zero_only 127 | def log_model_summary(self, model: "pl.LightningModule", max_depth: int = -1): 128 | summary = str(ModelSummary(model=model, max_depth=max_depth)) 129 | rel_path = self.experiment.get_outputs_path("model_summary.txt") 130 | with open(rel_path, "w") as f: 131 | f.write(summary) 132 | self.experiment.log_file_ref( 133 | path=rel_path, name="model_summary", is_input=False 134 | ) 135 | 136 | @property 137 | def save_dir(self) -> Optional[str]: 138 | return self.experiment.get_outputs_path() 139 | 140 | @rank_zero_only 141 | def finalize(self, status: str): 142 | if self._end_on_finalize: 143 | self.experiment.end() 144 | 
self._experiment = None 145 | 146 | def _set_run_instance_from_env_vars(self, force: bool = False): 147 | """Tries to extract run info from canonical env vars""" 148 | run_instance = os.getenv(ENV_KEYS_RUN_INSTANCE) 149 | if not run_instance: 150 | return 151 | 152 | parts = run_instance.split(".") 153 | if len(parts) != 4: 154 | return 155 | 156 | if not self._name or force: 157 | self._name = parts[2] 158 | if not self._run_uuid or force: 159 | self._run_uuid = parts[-1] 160 | 161 | @property 162 | def name(self) -> str: 163 | if self._experiment is not None and self._experiment.run_data.name is not None: 164 | return self.experiment.run_data.name 165 | 166 | if not self._name: 167 | self._set_run_instance_from_env_vars() 168 | 169 | if self._name: 170 | return self._name 171 | 172 | return "default" 173 | 174 | @property 175 | def version(self) -> str: 176 | if self._experiment is not None and self._experiment.run_data.uuid is not None: 177 | return self.experiment.run_data.uuid 178 | 179 | if not self._run_uuid: 180 | self._set_run_instance_from_env_vars() 181 | 182 | if self._run_uuid: 183 | return self._run_uuid 184 | 185 | return uuid.uuid4().hex 186 | -------------------------------------------------------------------------------- /traceml/traceml/integrations/scikit.py: -------------------------------------------------------------------------------- 1 | from clipped.utils.np import sanitize_dict 2 | 3 | from traceml import tracking 4 | from traceml.exceptions import TracemlException 5 | 6 | try: 7 | from sklearn.base import is_classifier, is_regressor 8 | from sklearn.metrics import ( 9 | accuracy_score, 10 | explained_variance_score, 11 | f1_score, 12 | max_error, 13 | mean_absolute_error, 14 | precision_recall_fscore_support, 15 | r2_score, 16 | ) 17 | except ImportError: 18 | raise TracemlException("sklearn is required to use scikit polyaxon's loggers") 19 | 20 | 21 | def _log_test_predictions(run, y_test, y_pred=None, nrows=1000): 22 | try: 23 | import 
pandas as pd 24 | except ImportError: 25 | return 26 | 27 | # single output 28 | if len(y_pred.shape) == 1: 29 | df = pd.DataFrame(data={"y_true": y_test, "y_pred": y_pred}) 30 | run.log_dataframe(df=df.head(nrows), name="test_predictions") 31 | 32 | # multi output 33 | if len(y_pred.shape) == 2: 34 | df = pd.DataFrame() 35 | for j in range(y_pred.shape[1]): 36 | df["y_test_output_{}".format(j)] = y_test[:, j] 37 | df["y_pred_output_{}".format(j)] = y_pred[:, j] 38 | run.log_dataframe(df=df.head(nrows), name="test_predictions") 39 | 40 | 41 | def _log_test_preds_proba(run, classifier, X_test, nrows=1000): 42 | try: 43 | import pandas as pd 44 | except ImportError: 45 | return 46 | 47 | try: 48 | y_pred_proba = classifier.predict_proba(X_test) 49 | except Exception as e: 50 | print( 51 | "This classifier does not provide predictions probabilities. Error: {}".format( 52 | e 53 | ) 54 | ) 55 | return 56 | 57 | df = pd.DataFrame(data=y_pred_proba, columns=classifier.classes_) 58 | run.log_dataframe(df=df.head(nrows), name="test_proba_predictions") 59 | 60 | 61 | def log_regressor(regressor, X_test, y_test, nrows=1000, run=None): 62 | assert is_regressor(regressor), "regressor should be sklearn regressor." 
def log_classifier(classifier, X_test, y_test, nrows=1000, run=None):
    """Log params, test metrics, probabilities, and predictions for a fitted
    sklearn classifier.

    Params:
        classifier: A fitted sklearn classifier.
        X_test: Test features.
        y_test: Ground-truth test labels.
        nrows: Max number of prediction rows to log.
        run: An optional traceml run; resolved via ``get_or_create_run``.
    """
    assert is_classifier(classifier), "classifier should be sklearn classifier."

    run = tracking.get_or_create_run(run)

    run.log_inputs(**sanitize_dict(classifier.get_params()))

    _log_test_preds_proba(run, classifier, X_test, nrows=nrows)

    y_pred = classifier.predict(X_test)

    results = {}
    for metric_name, values in zip(
        ["precision", "recall", "fbeta_score", "support"],
        precision_recall_fscore_support(y_test, y_pred),
    ):
        for i, value in enumerate(values):
            results["{}_class_{}_test".format(metric_name, i)] = value
    results["accuracy"] = accuracy_score(y_test, y_pred)
    # Fix: f1 must compare ground truth against predictions; the previous
    # `f1_score(y_pred, y_pred, ...)` compared predictions to themselves and
    # always reported a perfect score.
    results["f1"] = f1_score(y_test, y_pred, average="weighted")
    run.log_metrics(**results)
    _log_test_predictions(run, y_test, y_pred=y_pred, nrows=nrows)
11 | try: 12 | from tensorflow.core.framework import summary_pb2 # noqa 13 | except ImportError: 14 | pass 15 | try: 16 | from tensorboardX.proto import summary_pb2 # noqa 17 | except ImportError: 18 | pass 19 | 20 | try: 21 | from tensorboard.compat.proto import summary_pb2 # noqa 22 | except ImportError: 23 | pass 24 | 25 | if not summary_pb2: 26 | raise TracemlException( 27 | "tensorflow/tensorboard/tensorboardx is required to use the tracking Logger" 28 | ) 29 | 30 | 31 | if TYPE_CHECKING: 32 | from traceml.tracking import Run 33 | 34 | 35 | class Logger: 36 | @classmethod 37 | def process_summary( 38 | cls, 39 | summary: Any, 40 | global_step: Optional[int] = None, 41 | run: "Run" = None, 42 | log_image: bool = False, 43 | log_histo: bool = False, 44 | log_tensor: bool = False, 45 | ): 46 | run = tracking.get_or_create_run(run) 47 | if not run: 48 | return 49 | 50 | if isinstance(summary, bytes): 51 | summary_proto = summary_pb2.Summary() 52 | summary_proto.ParseFromString(summary) 53 | summary = summary_proto 54 | 55 | step = cls._process_step(global_step) 56 | for value in summary.value: 57 | try: 58 | cls.add_value( 59 | run=run, 60 | step=step, 61 | value=value, 62 | log_image=log_image, 63 | log_histo=log_histo, 64 | log_tensor=log_tensor, 65 | ) 66 | except TracemlException("Polyaxon failed processing tensorboard summary."): 67 | pass 68 | 69 | @classmethod 70 | def add_value( 71 | cls, 72 | run, 73 | step, 74 | value, 75 | log_image: bool = False, 76 | log_histo: bool = False, 77 | log_tensor: bool = False, 78 | ): 79 | field = value.WhichOneof("value") 80 | 81 | if field == "simple_value": 82 | run.log_metric(name=value.tag, step=step, value=value.simple_value) 83 | return 84 | 85 | if field == "image" and log_image: 86 | run.log_image(name=value.tag, step=step, data=value.image) 87 | return 88 | 89 | if ( 90 | field == "tensor" 91 | and log_tensor 92 | and value.tensor.string_val 93 | and len(value.tensor.string_val) 94 | ): 95 | string_values = [] 96 
| for _ in range(0, len(value.tensor.string_val)): 97 | string_value = value.tensor.string_val.pop() 98 | string_values.append(string_value.decode("utf-8")) 99 | 100 | run.log_text(name=value.tag, step=step, text=", ".join(string_values)) 101 | return 102 | 103 | elif field == "histo" and log_histo: 104 | if len(value.histo.bucket_limit) >= 3: 105 | first = ( 106 | value.histo.bucket_limit[0] 107 | + value.histo.bucket_limit[0] 108 | - value.histo.bucket_limit[1] 109 | ) 110 | last = ( 111 | value.histo.bucket_limit[-2] 112 | + value.histo.bucket_limit[-2] 113 | - value.histo.bucket_limit[-3] 114 | ) 115 | values, counts = ( 116 | list(value.histo.bucket), 117 | [first] + value.histo.bucket_limit[:-1] + [last], 118 | ) 119 | try: 120 | run.log_np_histogram( 121 | name=value.tag, values=values, counts=counts, step=step 122 | ) 123 | return 124 | except ValueError: 125 | logger.warning( 126 | "Ignoring histogram for tag `{}`, " 127 | "Histograms must have few bins".format(value.tag) 128 | ) 129 | else: 130 | logger.warning( 131 | "Ignoring histogram for tag `{}`, " 132 | "Found a histogram with only 2 bins.".format(value.tag) 133 | ) 134 | 135 | @staticmethod 136 | def get_writer_name(log_dir): 137 | return os.path.basename(os.path.normpath(log_dir)) 138 | 139 | @staticmethod 140 | def _process_step(global_step): 141 | return int(global_step) if global_step is not None else None 142 | -------------------------------------------------------------------------------- /traceml/traceml/integrations/tensorflow.py: -------------------------------------------------------------------------------- 1 | from typing import TYPE_CHECKING, Any 2 | 3 | from traceml import tracking 4 | from traceml.exceptions import TracemlException 5 | from traceml.integrations.tensorboard import Logger 6 | 7 | try: 8 | import tensorflow as tf 9 | except ImportError: 10 | raise TracemlException("tensorflow is required to use the tracking Callback") 11 | 12 | try: 13 | from tensorflow.train import 
class Callback(SessionRunHook):
    """SessionRunHook that forwards tensorboard summaries to a traceml run."""

    def __init__(
        self,
        summary_op: Any = None,
        steps_per_log: int = 1000,
        run: "Run" = None,
        log_image: bool = False,
        log_histo: bool = False,
        log_tensor: bool = False,
    ):
        self.run = tracking.get_or_create_run(run)
        self._summary_op = summary_op
        self._steps_per_log = steps_per_log
        self._log_image = log_image
        self._log_histo = log_histo
        self._log_tensor = log_tensor

    def begin(self):
        # Default to all registered summaries when none was provided.
        if self._summary_op is None:
            self._summary_op = tf.summary.merge_all()
        self._step = -1

    def before_run(self, run_context):
        self._step += 1
        return tf.train.SessionRunArgs({"summary": self._summary_op})

    def after_run(self, run_context, run_values):
        # Only log every `steps_per_log` steps.
        if self._step % self._steps_per_log != 0:
            return
        Logger.process_summary(
            run_values.results["summary"],
            run=self.run,
            log_image=self._log_image,
            log_histo=self._log_histo,
            log_tensor=self._log_tensor,
        )
"cvfolds", False) 23 | 24 | 25 | def _log_importance(run, model, model_folds, max_num_features, **kwargs): 26 | try: 27 | import matplotlib.pyplot as plt 28 | except ImportError: 29 | raise ImportError("Please install matplotlib to log importance") 30 | 31 | if model_folds: 32 | for i, fold in enumerate(model_folds): 33 | importance = xgb.plot_importance( 34 | fold.bst, max_num_features=max_num_features, **kwargs 35 | ) 36 | run.log_mpl_plotly_chart( 37 | name="feature_importance", figure=importance.figure, step=i 38 | ) 39 | else: 40 | importance = xgb.plot_importance( 41 | model, max_num_features=max_num_features, **kwargs 42 | ) 43 | run.log_mpl_plotly_chart(name="feature_importance", figure=importance.figure) 44 | plt.close("all") 45 | 46 | 47 | def _log_model(run, model, model_folds): 48 | def _save(file_model, file_name): 49 | asset_path = run.get_outputs_path(file_name) 50 | file_model.save_model(asset_path) 51 | run.log_model_ref(asset_path, framework="xgboost") 52 | 53 | if model_folds: 54 | for i, cvpack in enumerate(model_folds): 55 | _save(cvpack.bst, "model-{}".format(i)) 56 | else: # train case 57 | _save(model, "model") 58 | 59 | 60 | def callback( 61 | log_model: bool = True, 62 | log_importance: bool = True, 63 | max_num_features: Optional[int] = None, 64 | run: Optional["Run"] = None, 65 | ): 66 | run = tracking.get_or_create_run(run) 67 | 68 | def callback(env): 69 | # Log metrics after iteration 70 | metrics = {} 71 | for item in env.evaluation_result_list: 72 | if len(item) == 2: # train case 73 | metrics[item[0]] = item[1] 74 | if len(item) == 3: # cv case 75 | metrics["{}-mean".format(item[0])] = item[1] 76 | metrics["{}-std".format(item[0])] = item[2] 77 | run.log_metrics( 78 | **metrics, 79 | step=env.iteration, 80 | ) 81 | 82 | model = getattr(env, "model") 83 | model_folds = _get_cv(env) 84 | 85 | # Log booster, end of training 86 | if log_model: 87 | _log_model(run=run, model=model, model_folds=model_folds) 88 | 89 | # Log feature 
importance, end of training 90 | if env.iteration + 1 == env.end_iteration and log_importance: 91 | try: 92 | _log_importance( 93 | run, 94 | model=model, 95 | model_folds=model_folds, 96 | max_num_features=max_num_features, 97 | ) 98 | except Exception as e: 99 | logger.info("Failed logging feature importance %s", e) 100 | 101 | return callback 102 | 103 | 104 | class Callback(xgb.callback.TrainingCallback): 105 | def __init__( 106 | self, 107 | run: "Run" = None, 108 | log_model: bool = True, 109 | log_importance: bool = True, 110 | importance_type: str = "gain", 111 | max_num_features: Optional[int] = None, 112 | ): 113 | self.log_model: bool = log_model 114 | self.log_importance: bool = log_importance 115 | self.importance_type: str = importance_type 116 | self.max_num_features: int = max_num_features 117 | self.run = tracking.get_or_create_run(run) 118 | 119 | def after_training(self, model: Booster) -> Booster: 120 | model_folds = _get_cv(model) 121 | 122 | if self.log_model: 123 | _log_model(run=self.run, model=model, model_folds=model_folds) 124 | 125 | if self.log_importance: 126 | _log_importance( 127 | self.run, 128 | model=model, 129 | model_folds=model_folds, 130 | max_num_features=self.max_num_features, 131 | ) 132 | 133 | if model_folds: 134 | config = {} 135 | for i, fold in enumerate(model_folds): 136 | config["fold_{}_config".format(i)] = orjson_loads( 137 | fold.bst.save_config() 138 | ) 139 | if config: 140 | self.run.log_inputs(**config) 141 | else: 142 | self.run.log_inputs(config=orjson_loads(model.save_config())) 143 | outputs = {} 144 | if "best_score" in model.attributes().keys(): 145 | outputs["best_score"] = model.attributes()["best_score"] 146 | if "best_iteration" in model.attributes().keys(): 147 | outputs["best_iteration"] = model.attributes()["best_iteration"] 148 | self.run.log_outputs(**outputs) 149 | 150 | return model 151 | 152 | def after_iteration(self, model: Booster, epoch: int, evals_log: dict) -> bool: 153 | metrics = {} 
154 | for stage, metrics_dict in evals_log.items(): 155 | for metric_name, metric_values in evals_log[stage].items(): 156 | if _get_cv(model): 157 | mean, std = metric_values[-1] 158 | metrics["{}-{}-mean".format(stage, metric_name)] = mean 159 | metrics["{}-{}-std".format(stage, metric_name)] = std 160 | else: 161 | metrics["{}-{}".format(stage, metric_name)] = metric_values[-1] 162 | 163 | if metrics: 164 | self.run.log_metrics(step=epoch, **metrics) 165 | 166 | return False 167 | -------------------------------------------------------------------------------- /traceml/traceml/logger.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | logger = logging.getLogger("traceml") 4 | -------------------------------------------------------------------------------- /traceml/traceml/logging/__init__.py: -------------------------------------------------------------------------------- 1 | from traceml.logging.schemas import V1Log, V1Logs 2 | -------------------------------------------------------------------------------- /traceml/traceml/logging/handler.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | import socket 4 | 5 | from clipped.utils.dates import to_datetime 6 | from clipped.utils.env import get_user 7 | 8 | from polyaxon import settings 9 | from polyaxon._env_vars.keys import ENV_KEYS_K8S_NODE_NAME, ENV_KEYS_K8S_POD_ID 10 | from traceml.logging.schemas import V1Log 11 | 12 | 13 | class LogStreamHandler(logging.Handler): 14 | def __init__(self, add_logs, **kwargs): 15 | self._add_logs = add_logs 16 | self._container = socket.gethostname() 17 | self._node = os.environ.get(ENV_KEYS_K8S_NODE_NAME, "local") 18 | self._pod = os.environ.get(ENV_KEYS_K8S_POD_ID, get_user()) 19 | log_level = settings.CLIENT_CONFIG.log_level 20 | if log_level and isinstance(log_level, str): 21 | log_level = log_level.upper() 22 | super().__init__( 23 | 
class LogStreamWriter:
    """File-like writer that forwards written lines to a logger.

    Intended to replace a stream (e.g. stdout/stderr): each line written is
    emitted through ``logger.log`` at the configured level, while ``flush``
    is delegated to the wrapped channel.
    """

    def __init__(self, logger, log_level, channel):
        self._logger = logger
        self._log_level = log_level
        self._channel = channel

    def write(self, buf):
        # Drop trailing whitespace, then emit one log record per line.
        for raw_line in buf.rstrip().splitlines():
            if raw_line == "\n":
                continue
            self._logger.log(self._log_level, raw_line.rstrip())

    def flush(self):
        self._channel.flush()
def timestamp_search_regex(regex, log_line):
    """Split a timestamp off a log line.

    Params:
        regex: Compiled pattern matching the timestamp format.
        log_line: The raw log line.
    Returns:
        Tuple of (log line with the timestamp removed, parsed timestamp),
        or (original line, None) when no timestamp is found.
    """
    match = regex.search(log_line)
    if match is None:
        return log_line, None

    parsed_ts = parse_datetime(match.group())
    stripped_line = re.sub(regex, "", log_line)
    return stripped_line, parsed_ts
50 | if isinstance(timestamp, str): 51 | try: 52 | timestamp = parse_datetime(timestamp) 53 | except Exception as e: 54 | raise ValueError("Received an invalid timestamp") from e 55 | 56 | return cls.construct( 57 | timestamp=timestamp if timestamp else now(tzinfo=True), 58 | node=node, 59 | pod=pod, 60 | container=container, 61 | value=value, 62 | ) 63 | 64 | def to_csv(self) -> str: 65 | values = [ 66 | str(self.timestamp) if self.timestamp is not None else "", 67 | str(self.node) if self.node is not None else "", 68 | str(self.pod) if self.pod is not None else "", 69 | str(self.container) if self.container is not None else "", 70 | orjson_dumps({"_": self.value}) if self.value is not None else "", 71 | ] 72 | 73 | return self._SEPARATOR.join(values) 74 | 75 | 76 | class V1Logs(BaseSchemaModel): 77 | _CHUNK_SIZE: ClassVar = 6000 78 | _IDENTIFIER = "logs" 79 | 80 | logs: Optional[List[V1Log]] = None 81 | last_time: Optional[datetime.datetime] = None 82 | last_file: Optional[StrictStr] = None 83 | files: Optional[List[StrictStr]] = None 84 | 85 | @classmethod 86 | def get_csv_header(cls) -> str: 87 | return V1Log._SEPARATOR.join(V1Log.model_fields.keys()) 88 | 89 | def to_csv(self): 90 | _logs = ["\n{}".format(e.to_csv()) for e in self.logs if e.value] 91 | return "".join(_logs) 92 | 93 | def get_jsonl_events(self) -> str: 94 | events = ["\n{}".format(e.to_json()) for e in self.logs if e.value] 95 | return "".join(events) 96 | 97 | @classmethod 98 | def should_chunk(cls, logs: List[V1Log]): 99 | return len(logs) >= cls._CHUNK_SIZE 100 | 101 | @classmethod 102 | def chunk_logs(cls, logs: List[V1Log]): 103 | total_size = len(logs) 104 | for i in range(0, total_size, cls._CHUNK_SIZE): 105 | yield cls(logs=logs[i : i + cls._CHUNK_SIZE]) 106 | 107 | @classmethod 108 | def read_csv(cls, data: str, parse_dates: bool = True) -> "V1Logs": 109 | import numpy as np 110 | import pandas as pd 111 | 112 | csv = validate_file_or_buffer(data) 113 | if parse_dates: 114 | df = 
pd.read_csv( 115 | csv, 116 | sep=V1Log._SEPARATOR, 117 | parse_dates=["timestamp"], 118 | error_bad_lines=False, 119 | engine="pyarrow", 120 | ) 121 | else: 122 | df = pd.read_csv( 123 | csv, 124 | sep=V1Log._SEPARATOR, 125 | engine="pyarrow", 126 | ) 127 | 128 | return cls.construct( 129 | logs=[ 130 | V1Log.construct( 131 | timestamp=i.get("timestamp"), 132 | node=i.get("node"), 133 | pod=i.get("pod"), 134 | container=i.get("container"), 135 | value=orjson_loads(i.get("value")).get("_"), 136 | ) 137 | for i in df.replace({np.nan: None}).to_dict(orient="records") 138 | ] 139 | ) 140 | 141 | @classmethod 142 | def read_jsonl(cls, data: str, to_structured: bool = False) -> "pandas.DataFrame": 143 | import numpy as np 144 | import pandas as pd 145 | 146 | data = validate_file_or_buffer(data) 147 | df = pd.read_json( 148 | data, 149 | lines=True, 150 | ) 151 | if "timestamp" in df.columns: 152 | df["timestamp"] = df["timestamp"].astype(str) 153 | 154 | df = df.replace({np.nan: None}).to_dict(orient="records") 155 | if not to_structured: 156 | return df 157 | return cls.construct( 158 | logs=[ 159 | V1Log.construct( 160 | timestamp=parse_datetime(i.get("timestamp")), 161 | node=i.get("node"), 162 | pod=i.get("pod"), 163 | container=i.get("container"), 164 | value=i.get("value"), 165 | ) 166 | for i in df 167 | ] 168 | ) 169 | -------------------------------------------------------------------------------- /traceml/traceml/logging/streamer.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from collections import deque 4 | from typing import Callable 5 | 6 | from clipped.formatting import Printer 7 | from clipped.utils.json import orjson_loads 8 | from clipped.utils.tz import local_datetime 9 | 10 | from polyaxon import settings 11 | from polyaxon._containers.names import MAIN_CONTAINER_NAMES 12 | from traceml.logging.schemas import V1Log, V1Logs 13 | 14 | 15 | def get_logs_streamer( 16 | show_timestamp: bool = 
True, all_containers: bool = False, all_info: bool = False 17 | ) -> Callable: 18 | colors = deque(Printer.COLORS) 19 | job_to_color = {} 20 | if all_info: 21 | all_containers = True 22 | 23 | def handle_log_line(log: V1Log): 24 | log_dict = log.to_dict() 25 | log_line = "" 26 | if log.timestamp and show_timestamp: 27 | date_value = local_datetime( 28 | log_dict.get("timestamp"), tz=settings.CLIENT_CONFIG.timezone 29 | ) 30 | log_line = Printer.add_log_color(date_value, "white") + " | " 31 | 32 | def get_container_info(): 33 | if container_info in job_to_color: 34 | color = job_to_color[container_info] 35 | else: 36 | color = colors[0] 37 | colors.rotate(-1) 38 | job_to_color[container_info] = color 39 | return Printer.add_log_color(container_info, color) + " | " 40 | 41 | if not all_containers and log.container not in MAIN_CONTAINER_NAMES: 42 | return log_line 43 | 44 | if all_info: 45 | container_info = "" 46 | if log.node: 47 | log_line += Printer.add_log_color(log_dict.get("node"), "white") + " | " 48 | if log.pod: 49 | log_line += Printer.add_log_color(log_dict.get("pod"), "white") + " | " 50 | if log.container: 51 | container_info = log_dict.get("container") 52 | 53 | log_line += get_container_info() 54 | 55 | log_line += log_dict.get("value") 56 | Printer.log(log_line, nl=True) 57 | 58 | def handle_log_lines(logs: V1Logs): 59 | for log in logs.logs: 60 | if log: 61 | handle_log_line(log=log) 62 | 63 | return handle_log_lines 64 | 65 | 66 | def load_logs_from_path( 67 | logs_path: str, 68 | hide_time: bool = False, 69 | all_containers: bool = True, 70 | all_info: bool = True, 71 | ): 72 | for file_logs in sorted(os.listdir(logs_path)): 73 | with open(os.path.join(logs_path, file_logs)) as f: 74 | logs_data = orjson_loads(f.read()).get("logs", []) 75 | logs_stream = V1Logs(logs=logs_data) 76 | get_logs_streamer( 77 | show_timestamp=not hide_time, 78 | all_containers=all_containers, 79 | all_info=all_info, 80 | )(logs_stream) 81 | 
-------------------------------------------------------------------------------- /traceml/traceml/pkg.py: --------------------------------------------------------------------------------
# Package metadata consumed by setup.py and surfaced in CLI/version checks.
NAME = "traceml"
VERSION = "1.2.1"
DESC = (
    "Engine for ML/Data tracking, visualization, dashboards, and model UI for Polyaxon."
)
URL = "https://github.com/polyaxon/traceml"
AUTHOR = "Polyaxon, Inc."
EMAIL = "contact@polyaxon.com"
LICENSE = "Apache 2.0"
-------------------------------------------------------------------------------- /traceml/traceml/processors/__init__.py: --------------------------------------------------------------------------------

-------------------------------------------------------------------------------- /traceml/traceml/processors/errors.py: --------------------------------------------------------------------------------
# User-facing warning messages emitted when an optional dependency needed by a
# tracking operation is not installed (the operation then returns UNKNOWN).
NUMPY_ERROR_MESSAGE = "numpy is required for this tracking operation."
PANDAS_ERROR_MESSAGE = "pandas is required for this tracking operation."
PIL_ERROR_MESSAGE = "PIL/Pillow is required for this tracking operation."
MOVIEPY_ERROR_MESSAGE = "moviepy is required for this tracking operation."
MATPLOTLIB_ERROR_MESSAGE = "matplotlib is required for this tracking operation."
PLOTLY_ERROR_MESSAGE = "plotly is required for this tracking operation."
BOKEH_ERROR_MESSAGE = "bokeh is required for this tracking operation."
SKLEARN_ERROR_MESSAGE = "sklearn is required for this tracking operation."
9 | -------------------------------------------------------------------------------- /traceml/traceml/processors/events_processors/__init__.py: -------------------------------------------------------------------------------- 1 | from traceml.processors.events_processors.events_artifacts_processors import ( 2 | artifact_path, 3 | ) 4 | from traceml.processors.events_processors.events_audio_processors import ( 5 | audio, 6 | audio_path, 7 | ) 8 | from traceml.processors.events_processors.events_charts_processors import ( 9 | altair_chart, 10 | bokeh_chart, 11 | mpl_plotly_chart, 12 | plotly_chart, 13 | ) 14 | from traceml.processors.events_processors.events_image_processors import ( 15 | convert_to_HWC, 16 | draw_boxes, 17 | encoded_image, 18 | ensure_matplotlib_figure, 19 | figure_to_image, 20 | figures_to_images, 21 | image, 22 | image_boxes, 23 | image_path, 24 | make_grid, 25 | make_image, 26 | save_image, 27 | ) 28 | from traceml.processors.events_processors.events_metrics_processors import ( 29 | confusion_matrix, 30 | curve, 31 | histogram, 32 | metric, 33 | metrics_dict_to_list, 34 | np_histogram, 35 | pr_curve, 36 | roc_auc_curve, 37 | sklearn_pr_curve, 38 | sklearn_roc_auc_curve, 39 | ) 40 | from traceml.processors.events_processors.events_models_processors import ( 41 | model_path, 42 | model_to_str, 43 | ) 44 | from traceml.processors.events_processors.events_tables_processors import dataframe_path 45 | from traceml.processors.events_processors.events_video_processors import ( 46 | make_video, 47 | prepare_video, 48 | video, 49 | video_path, 50 | ) 51 | -------------------------------------------------------------------------------- /traceml/traceml/processors/events_processors/events_artifacts_processors.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | 3 | from clipped.utils.paths import copy_file_or_dir_path 4 | 5 | from traceml.events import V1EventArtifact 6 | 7 | try: 8 | import 
numpy as np 9 | except ImportError: 10 | np = None 11 | 12 | 13 | def artifact_path( 14 | from_path: str, asset_path: str, kind: str, asset_rel_path: Optional[str] = None 15 | ) -> V1EventArtifact: 16 | copy_file_or_dir_path(from_path, asset_path) 17 | return V1EventArtifact(kind=kind, path=asset_rel_path or asset_path) 18 | -------------------------------------------------------------------------------- /traceml/traceml/processors/events_processors/events_audio_processors.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | 3 | from clipped.utils.np import to_np 4 | from clipped.utils.paths import check_or_create_path, copy_file_path 5 | 6 | from polyaxon._constants.globals import UNKNOWN 7 | from traceml.events import V1EventAudio 8 | from traceml.logger import logger 9 | from traceml.processors.errors import NUMPY_ERROR_MESSAGE 10 | 11 | try: 12 | import numpy as np 13 | except ImportError: 14 | np = None 15 | 16 | 17 | def audio_path( 18 | from_path: str, 19 | asset_path: str, 20 | content_type=None, 21 | asset_rel_path: Optional[str] = None, 22 | ) -> V1EventAudio: 23 | copy_file_path(from_path, asset_path) 24 | return V1EventAudio(path=asset_rel_path or asset_path, content_type=content_type) 25 | 26 | 27 | def audio( 28 | asset_path: str, tensor, sample_rate=44100, asset_rel_path: Optional[str] = None 29 | ): 30 | if not np: 31 | logger.warning(NUMPY_ERROR_MESSAGE) 32 | return UNKNOWN 33 | 34 | tensor = to_np(tensor) 35 | tensor = tensor.squeeze() 36 | if abs(tensor).max() > 1: 37 | print("warning: audio amplitude out of range, auto clipped.") 38 | tensor = tensor.clip(-1, 1) 39 | assert tensor.ndim == 1, "input tensor should be 1 dimensional." 
40 | 41 | tensor_list = [int(32767.0 * x) for x in tensor] 42 | 43 | import struct 44 | import wave 45 | 46 | check_or_create_path(asset_path, is_dir=False) 47 | 48 | wave_write = wave.open(asset_path, "wb") 49 | wave_write.setnchannels(1) 50 | wave_write.setsampwidth(2) 51 | wave_write.setframerate(sample_rate) 52 | tensor_enc = b"" 53 | for v in tensor_list: 54 | tensor_enc += struct.pack(" V1EventChart: 19 | try: 20 | from bokeh.embed import json_item 21 | except ImportError: 22 | logger.warning(BOKEH_ERROR_MESSAGE) 23 | return UNKNOWN 24 | return V1EventChart(kind=V1EventChartKind.BOKEH, figure=json_item(figure)) 25 | 26 | 27 | def altair_chart(figure) -> V1EventChart: 28 | return V1EventChart(kind=V1EventChartKind.VEGA, figure=figure.to_dict()) 29 | 30 | 31 | def plotly_chart(figure) -> V1EventChart: 32 | if module_type(figure, "matplotlib.figure.Figure"): 33 | try: 34 | from traceml.vendor.matplotlylib import mpl_to_plotly 35 | except ImportError: 36 | logger.error(MATPLOTLIB_ERROR_MESSAGE) 37 | logger.error(PLOTLY_ERROR_MESSAGE) 38 | return UNKNOWN 39 | 40 | figure = mpl_to_plotly(figure) 41 | else: 42 | try: 43 | import plotly.tools 44 | except ImportError: 45 | logger.error(PLOTLY_ERROR_MESSAGE) 46 | return UNKNOWN 47 | 48 | figure = plotly.tools.return_figure_from_figure_or_data( 49 | figure, validate_figure=True 50 | ) 51 | return V1EventChart(kind=V1EventChartKind.PLOTLY, figure=figure) 52 | 53 | 54 | def mpl_plotly_chart(figure, close: bool = True) -> V1EventChart: 55 | try: 56 | import plotly.tools 57 | 58 | from plotly import optional_imports 59 | except ImportError: 60 | logger.warning(PLOTLY_ERROR_MESSAGE) 61 | return UNKNOWN 62 | 63 | try: 64 | import matplotlib 65 | import matplotlib.pyplot as plt 66 | 67 | from matplotlib.figure import Figure 68 | except ImportError: 69 | logger.warning(MATPLOTLIB_ERROR_MESSAGE) 70 | 71 | if module_type(figure, "matplotlib.figure.Figure"): 72 | pass 73 | else: 74 | if figure == matplotlib.pyplot: 75 | figure = 
figure.gcf() 76 | elif not isinstance(figure, Figure): 77 | if hasattr(figure, "figure"): 78 | figure = figure.figure 79 | # Some matplotlib objects have a figure function 80 | if not isinstance(figure, Figure): 81 | raise ValueError( 82 | "Only matplotlib.pyplot or matplotlib.pyplot.Figure objects are accepted." 83 | ) 84 | 85 | from traceml.vendor.matplotlylib import mpl_to_plotly 86 | 87 | plotly_figure = mpl_to_plotly(figure) 88 | result = plotly_chart(figure=plotly_figure) 89 | if close: 90 | try: 91 | plt.close(figure.number) 92 | except Exception: # noqa 93 | plt.close(figure) 94 | 95 | return result 96 | -------------------------------------------------------------------------------- /traceml/traceml/processors/events_processors/events_image_processors.py: -------------------------------------------------------------------------------- 1 | import io 2 | 3 | from typing import Optional 4 | 5 | from clipped.utils.np import calculate_scale_factor, to_np 6 | from clipped.utils.paths import check_or_create_path, copy_file_path 7 | 8 | from polyaxon._constants.globals import UNKNOWN 9 | from traceml.events import V1EventImage 10 | from traceml.logger import logger 11 | from traceml.processors.errors import ( 12 | MATPLOTLIB_ERROR_MESSAGE, 13 | NUMPY_ERROR_MESSAGE, 14 | PIL_ERROR_MESSAGE, 15 | ) 16 | 17 | try: 18 | import numpy as np 19 | except ImportError: 20 | np = None 21 | 22 | 23 | def image_path( 24 | from_path: str, asset_path: str, asset_rel_path: Optional[str] = None 25 | ) -> V1EventImage: 26 | copy_file_path(from_path, asset_path) 27 | return V1EventImage(path=asset_rel_path or asset_path) 28 | 29 | 30 | def _draw_single_box( 31 | image, 32 | xmin, 33 | ymin, 34 | xmax, 35 | ymax, 36 | display_str, 37 | color="black", 38 | color_text="black", 39 | thickness=2, 40 | ): 41 | if not np: 42 | logger.warning(NUMPY_ERROR_MESSAGE) 43 | return UNKNOWN 44 | 45 | try: 46 | from PIL import ImageDraw, ImageFont 47 | except ImportError: 48 | 
logger.warning(PIL_ERROR_MESSAGE) 49 | return UNKNOWN 50 | 51 | font = ImageFont.load_default() 52 | draw = ImageDraw.Draw(image) 53 | (left, right, top, bottom) = (xmin, xmax, ymin, ymax) 54 | draw.line( 55 | [(left, top), (left, bottom), (right, bottom), (right, top), (left, top)], 56 | width=thickness, 57 | fill=color, 58 | ) 59 | if display_str: 60 | text_bottom = bottom 61 | # Reverse list and print from bottom to top. 62 | text_width, text_height = font.getsize(display_str) 63 | margin = np.ceil(0.05 * text_height) 64 | draw.rectangle( 65 | [ 66 | (left, text_bottom - text_height - 2 * margin), 67 | (left + text_width, text_bottom), 68 | ], 69 | fill=color, 70 | ) 71 | draw.text( 72 | (left + margin, text_bottom - text_height - margin), 73 | display_str, 74 | fill=color_text, 75 | font=font, 76 | ) 77 | return image 78 | 79 | 80 | def encoded_image(asset_path: str, data, asset_rel_path: Optional[str] = None): 81 | try: 82 | from PIL import Image 83 | except ImportError: 84 | logger.warning(PIL_ERROR_MESSAGE) 85 | return UNKNOWN 86 | 87 | image_data = Image.open(io.BytesIO(data.encoded_image_string)) 88 | return save_image( 89 | asset_path=asset_rel_path or asset_path, 90 | image_data=image_data, 91 | height=data.height, 92 | width=data.width, 93 | colorspace=data.colorspace, 94 | ) 95 | 96 | 97 | def image( 98 | asset_path: str, 99 | data, 100 | rescale=1, 101 | dataformats="CHW", 102 | asset_rel_path: Optional[str] = None, 103 | ): 104 | if not np: 105 | logger.warning(NUMPY_ERROR_MESSAGE) 106 | return UNKNOWN 107 | 108 | tensor = to_np(data) 109 | tensor = convert_to_HWC(tensor, dataformats) 110 | # Do not assume that user passes in values in [0, 255], use data type to detect 111 | scale_factor = calculate_scale_factor(tensor) 112 | tensor = tensor.astype(np.float32) 113 | tensor = (tensor * scale_factor).astype(np.uint8) 114 | return make_image( 115 | asset_path, tensor, rescale=rescale, asset_rel_path=asset_rel_path 116 | ) 117 | 118 | 119 | def 
image_boxes( 120 | asset_path: str, 121 | tensor_image, 122 | tensor_boxes, 123 | rescale=1, 124 | dataformats="CHW", 125 | asset_rel_path: Optional[str] = None, 126 | ): 127 | if not np: 128 | logger.warning(NUMPY_ERROR_MESSAGE) 129 | return UNKNOWN 130 | 131 | tensor_image = to_np(tensor_image) 132 | tensor_image = convert_to_HWC(tensor_image, dataformats) 133 | tensor_boxes = to_np(tensor_boxes) 134 | tensor_image = tensor_image.astype(np.float32) * calculate_scale_factor( 135 | tensor_image 136 | ) 137 | return make_image( 138 | asset_path, 139 | tensor_image.astype(np.uint8), 140 | rescale=rescale, 141 | rois=tensor_boxes, 142 | asset_rel_path=asset_rel_path, 143 | ) 144 | 145 | 146 | def draw_boxes(disp_image, boxes): 147 | # xyxy format 148 | num_boxes = boxes.shape[0] 149 | list_gt = range(num_boxes) 150 | for i in list_gt: 151 | disp_image = _draw_single_box( 152 | disp_image, 153 | boxes[i, 0], 154 | boxes[i, 1], 155 | boxes[i, 2], 156 | boxes[i, 3], 157 | display_str=None, 158 | color="Red", 159 | ) 160 | return disp_image 161 | 162 | 163 | def make_image( 164 | asset_path: str, tensor, rescale=1, rois=None, asset_rel_path: Optional[str] = None 165 | ): 166 | try: 167 | from PIL import Image 168 | except ImportError: 169 | logger.warning(PIL_ERROR_MESSAGE) 170 | return UNKNOWN 171 | 172 | height, width, colorspace = tensor.shape 173 | scaled_height = int(height * rescale) 174 | scaled_width = int(width * rescale) 175 | image_data = Image.fromarray(tensor) 176 | if rois is not None: 177 | image_data = draw_boxes(image_data, rois) 178 | image_data = image_data.resize((scaled_width, scaled_height), Image.LANCZOS) 179 | 180 | return save_image( 181 | asset_path=asset_path, 182 | image_data=image_data, 183 | height=height, 184 | width=width, 185 | colorspace=colorspace, 186 | asset_rel_path=asset_rel_path, 187 | ) 188 | 189 | 190 | def save_image( 191 | asset_path: str, 192 | image_data, 193 | height, 194 | width, 195 | colorspace, 196 | asset_rel_path: 
Optional[str] = None, 197 | ): 198 | check_or_create_path(asset_path, is_dir=False) 199 | image_data.save(asset_path, format="PNG") 200 | return V1EventImage( 201 | height=height, 202 | width=width, 203 | colorspace=colorspace, 204 | path=asset_rel_path or asset_path, 205 | ) 206 | 207 | 208 | def figure_to_image(figure, close: bool = True): 209 | """Render matplotlib figure to numpy format. 210 | 211 | Returns: 212 | numpy.array: image in [CHW] order 213 | """ 214 | if not np: 215 | logger.warning(NUMPY_ERROR_MESSAGE) 216 | 217 | try: 218 | import matplotlib.backends.backend_agg as plt_backend_agg 219 | import matplotlib.pyplot as plt 220 | except ImportError: 221 | logger.warning(MATPLOTLIB_ERROR_MESSAGE) 222 | 223 | canvas = plt_backend_agg.FigureCanvasAgg(figure) 224 | canvas.draw() 225 | data = np.frombuffer(canvas.buffer_rgba(), dtype=np.uint8) 226 | w, h = figure.canvas.get_width_height() 227 | image_hwc = data.reshape([h, w, 4])[:, :, 0:3] 228 | image_chw = np.moveaxis(image_hwc, source=2, destination=0) 229 | if close: 230 | try: 231 | plt.close(figure.number) 232 | except Exception: # noqa 233 | plt.close(figure) 234 | return image_chw 235 | 236 | 237 | def figures_to_images(figures, close=True): 238 | """Render matplotlib figure to numpy format. 239 | 240 | Returns: 241 | numpy.array: image in [CHW] order 242 | """ 243 | if not np: 244 | logger.warning(NUMPY_ERROR_MESSAGE) 245 | return UNKNOWN 246 | 247 | images = [figure_to_image(figure, close=close) for figure in figures] 248 | return np.stack(images) 249 | 250 | 251 | def ensure_matplotlib_figure(figure): 252 | """Extract the current figure from a matplotlib object or return the object if it's a figure. 253 | raises ValueError if the object can't be converted. 
254 | """ 255 | try: 256 | import matplotlib 257 | 258 | from matplotlib.figure import Figure 259 | except ImportError: 260 | logger.warning(MATPLOTLIB_ERROR_MESSAGE) 261 | 262 | if figure == matplotlib.pyplot: 263 | figure = figure.gcf() 264 | elif not isinstance(figure, Figure): 265 | if hasattr(figure, "figure"): 266 | figure = figure.figure 267 | # Some matplotlib objects have a figure function 268 | if not isinstance(figure, Figure): 269 | raise ValueError( 270 | "Only matplotlib.pyplot or matplotlib.pyplot.Figure objects are accepted." 271 | ) 272 | if not figure.gca().has_data(): 273 | raise ValueError( 274 | "You attempted to log an empty plot, " 275 | "pass a figure directly or ensure the global plot isn't closed." 276 | ) 277 | return figure 278 | 279 | 280 | def make_grid(data, ncols=8): 281 | # I: N1HW or N3HW 282 | if not np: 283 | logger.warning(NUMPY_ERROR_MESSAGE) 284 | return UNKNOWN 285 | 286 | assert isinstance(data, np.ndarray), "plugin error, should pass numpy array here" 287 | if data.shape[1] == 1: 288 | data = np.concatenate([data, data, data], 1) 289 | assert data.ndim == 4 and data.shape[1] == 3 or data.shape[1] == 4 290 | nimg = data.shape[0] 291 | H = data.shape[2] # noqa 292 | W = data.shape[3] # noqa 293 | ncols = min(nimg, ncols) 294 | nrows = int(np.ceil(float(nimg) / ncols)) 295 | canvas = np.zeros((data.shape[1], H * nrows, W * ncols)) 296 | i = 0 297 | for y in range(nrows): 298 | for x in range(ncols): 299 | if i >= nimg: 300 | break 301 | canvas[:, y * H : (y + 1) * H, x * W : (x + 1) * W] = data[i] # noqa 302 | i = i + 1 303 | return canvas 304 | 305 | 306 | def convert_to_HWC(tensor, input_format): # noqa 307 | if not np: 308 | logger.warning(NUMPY_ERROR_MESSAGE) 309 | return UNKNOWN 310 | 311 | assert len(set(input_format)) == len( 312 | input_format 313 | ), "You can not use the same dimension shorthand twice. 
\ 314 | input_format: {}".format(input_format) 315 | assert len(tensor.shape) == len( 316 | input_format 317 | ), "size of input tensor and input format are different. \ 318 | tensor shape: {}, input_format: {}".format(tensor.shape, input_format) 319 | input_format = input_format.upper() 320 | 321 | if len(input_format) == 4: 322 | index = [input_format.find(c) for c in "NCHW"] 323 | tensor_NCHW = tensor.transpose(index) # noqa 324 | tensor_CHW = make_grid(tensor_NCHW) # noqa 325 | return tensor_CHW.transpose(1, 2, 0) 326 | 327 | if len(input_format) == 3: 328 | index = [input_format.find(c) for c in "HWC"] 329 | tensor_HWC = tensor.transpose(index) # noqa 330 | if tensor_HWC.shape[2] == 1: 331 | tensor_HWC = np.concatenate([tensor_HWC, tensor_HWC, tensor_HWC], 2) # noqa 332 | return tensor_HWC 333 | 334 | if len(input_format) == 2: 335 | index = [input_format.find(c) for c in "HW"] 336 | tensor = tensor.transpose(index) 337 | tensor = np.stack([tensor, tensor, tensor], 2) 338 | return tensor 339 | -------------------------------------------------------------------------------- /traceml/traceml/processors/events_processors/events_metrics_processors.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, List 2 | 3 | from clipped.utils.np import to_np 4 | 5 | from polyaxon._constants.globals import UNKNOWN 6 | from traceml.artifacts import V1ArtifactKind 7 | from traceml.events import ( 8 | LoggedEventSpec, 9 | V1Event, 10 | V1EventConfusionMatrix, 11 | V1EventCurve, 12 | V1EventCurveKind, 13 | V1EventHistogram, 14 | ) 15 | from traceml.logger import logger 16 | from traceml.processors.errors import NUMPY_ERROR_MESSAGE, SKLEARN_ERROR_MESSAGE 17 | 18 | try: 19 | import numpy as np 20 | except ImportError: 21 | np = None 22 | 23 | 24 | def metric(value): 25 | if isinstance(value, float): 26 | return value 27 | 28 | if not np: 29 | logger.warning(NUMPY_ERROR_MESSAGE) 30 | return UNKNOWN 31 | 32 | value = 
to_np(value) 33 | assert value.squeeze().ndim == 0, "scalar should be 0D" 34 | return float(value) 35 | 36 | 37 | def histogram(values, bins, max_bins=None): 38 | if not np: 39 | logger.warning(NUMPY_ERROR_MESSAGE) 40 | return UNKNOWN 41 | 42 | values = to_np(values).astype(float) 43 | 44 | if values.size == 0: 45 | raise ValueError("The input has no element.") 46 | values = values.reshape(-1) 47 | values, counts = np.histogram(values, bins=bins) 48 | 49 | if counts.size == 0: 50 | logger.warning("Tracking an empty histogram") 51 | return UNKNOWN 52 | 53 | return np_histogram(values=values, counts=counts, max_bins=max_bins) 54 | 55 | 56 | def np_histogram(values, counts, max_bins=None): 57 | try: 58 | values = values.tolist() 59 | counts = counts.tolist() 60 | except: # noqa 61 | pass 62 | max_bins = max_bins or 512 63 | values_len = len(values) 64 | counts_len = len(counts) 65 | if values_len > max_bins: 66 | raise ValueError( 67 | "The maximum bins for a histogram is {}, received {}".format( 68 | max_bins, values_len 69 | ) 70 | ) 71 | if values_len + 1 != counts_len: 72 | raise ValueError("len(hist.values) must be len(hist.counts) + 1") 73 | return V1EventHistogram(values=values, counts=counts) 74 | 75 | 76 | def roc_auc_curve(fpr, tpr, auc=None): 77 | return V1EventCurve( 78 | kind=V1EventCurveKind.ROC, 79 | x=fpr, 80 | y=tpr, 81 | annotation=str(auc) if auc else None, 82 | ) 83 | 84 | 85 | def sklearn_roc_auc_curve(y_preds, y_targets, pos_label=None): 86 | try: 87 | from sklearn.metrics import auc, roc_curve 88 | except ImportError: 89 | logger.warning(SKLEARN_ERROR_MESSAGE) 90 | 91 | try: 92 | y_true = y_targets.numpy() 93 | except AttributeError: 94 | y_true = y_targets 95 | try: 96 | y_pred = y_preds.numpy() 97 | except AttributeError: 98 | y_pred = y_preds 99 | fpr, tpr, _ = roc_curve(y_true, y_pred, pos_label=pos_label) 100 | auc_score = auc(fpr, tpr) 101 | return V1EventCurve( 102 | kind=V1EventCurveKind.ROC, 103 | x=fpr, 104 | y=tpr, 105 | 
annotation=str(auc_score), 106 | ) 107 | 108 | 109 | def pr_curve(precision, recall, average_precision=None): 110 | return V1EventCurve( 111 | kind=V1EventCurveKind.PR, 112 | x=precision, 113 | y=recall, 114 | annotation=str(average_precision) if average_precision else None, 115 | ) 116 | 117 | 118 | def sklearn_pr_curve(y_preds, y_targets, pos_label=None): 119 | try: 120 | from sklearn.metrics import average_precision_score, precision_recall_curve 121 | except ImportError: 122 | logger.warning(SKLEARN_ERROR_MESSAGE) 123 | 124 | try: 125 | y_true = y_targets.numpy() 126 | except AttributeError: 127 | y_true = y_targets 128 | try: 129 | y_pred = y_preds.numpy() 130 | except AttributeError: 131 | y_pred = y_preds 132 | 133 | precision, recall, _ = precision_recall_curve(y_true, y_pred, pos_label=pos_label) 134 | ap = average_precision_score(y_true, y_pred) 135 | return V1EventCurve( 136 | kind=V1EventCurveKind.PR, 137 | x=precision, 138 | y=recall, 139 | annotation=str(ap), 140 | ) 141 | 142 | 143 | def curve(x, y, annotation=None): 144 | return V1EventCurve( 145 | kind=V1EventCurveKind.CUSTOM, 146 | x=x, 147 | y=y, 148 | annotation=str(annotation) if annotation else None, 149 | ) 150 | 151 | 152 | def confusion_matrix(x, y, z): 153 | if hasattr(x, "tolist"): 154 | x = x.tolist() 155 | if hasattr(x, "tolist"): 156 | y = y.tolist() 157 | if hasattr(x, "tolist"): 158 | z = z.tolist() 159 | try: 160 | x_len = len(x) 161 | y_len = len(y) 162 | z_len = len(z) 163 | if x_len != y_len or x_len != z_len: 164 | raise ValueError( 165 | "Received invalid data for confusion matrix. " 166 | "All arrays must have the same structure: " 167 | "[len(x): {}, len(y): {}, len(z): {}]".format( 168 | x_len, 169 | y_len, 170 | z_len, 171 | ) 172 | ) 173 | zi_len = [len(zi) for zi in z] 174 | if len(set(zi_len)) != 1 or zi_len[0] != z_len: 175 | raise ValueError( 176 | "Received invalid data for confusion matrix. " 177 | "Current structure: [len(x): {}, len(y): {}, len(z): {}]. 
" 178 | "The z array has different nested structures: {}".format( 179 | x_len, y_len, z_len, zi_len 180 | ) 181 | ) 182 | except Exception as e: # noqa 183 | raise ValueError( 184 | "Received invalid data for confusion matrix. Error {}".format(e) 185 | ) 186 | return V1EventConfusionMatrix( 187 | x=x, 188 | y=y, 189 | z=z, 190 | ) 191 | 192 | 193 | def metrics_dict_to_list(metrics: Dict) -> List: 194 | results = [] 195 | for k, v in metrics.items(): 196 | results.append( 197 | LoggedEventSpec( 198 | name=k, 199 | kind=V1ArtifactKind.METRIC, 200 | event=V1Event.make(metric=v), 201 | ) 202 | ) 203 | return results 204 | -------------------------------------------------------------------------------- /traceml/traceml/processors/events_processors/events_models_processors.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, Optional 2 | 3 | from clipped.utils.paths import copy_file_or_dir_path 4 | 5 | from traceml.events import V1EventModel 6 | from traceml.logger import logger 7 | 8 | try: 9 | import numpy as np 10 | except ImportError: 11 | np = None 12 | 13 | 14 | def model_path( 15 | from_path: str, 16 | asset_path: str, 17 | framework: Optional[str] = None, 18 | spec: Optional[Dict] = None, 19 | asset_rel_path: Optional[str] = None, 20 | ) -> V1EventModel: 21 | copy_file_or_dir_path(from_path, asset_path) 22 | return V1EventModel( 23 | path=asset_rel_path or asset_path, framework=framework, spec=spec 24 | ) 25 | 26 | 27 | def _model_to_str(model): 28 | filetype = "txt" 29 | if hasattr(model, "to_json"): 30 | model = model.model.to_json() 31 | filetype = "json" 32 | elif hasattr(model, "to_yaml"): 33 | model = model.to_yaml() 34 | filetype = "yaml" 35 | 36 | try: 37 | return str(model), filetype 38 | except Exception as e: 39 | logger.warning("Could not convert model to a string. 
Error: %s" % e) 40 | 41 | 42 | def model_to_str(model): 43 | # Tensorflow Graph Definition 44 | if type(model).__name__ == "Graph": 45 | try: 46 | from google.protobuf import json_format 47 | 48 | graph_def = model.as_graph_def() 49 | model = json_format.MessageToJson(graph_def, sort_keys=True) 50 | except Exception as e: # noqa 51 | logger.warning( 52 | "Could not convert Tensorflow graph to JSON %s", e, exc_info=True 53 | ) 54 | 55 | return _model_to_str(model) 56 | -------------------------------------------------------------------------------- /traceml/traceml/processors/events_processors/events_tables_processors.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | 3 | from clipped.utils.paths import copy_file_path 4 | 5 | from traceml.events import V1EventDataframe 6 | 7 | try: 8 | import numpy as np 9 | except ImportError: 10 | np = None 11 | 12 | 13 | def dataframe_path( 14 | from_path: str, 15 | asset_path: str, 16 | content_type: Optional[str] = None, 17 | asset_rel_path: Optional[str] = None, 18 | ) -> V1EventDataframe: 19 | copy_file_path(from_path, asset_path) 20 | return V1EventDataframe( 21 | path=asset_rel_path or asset_path, content_type=content_type 22 | ) 23 | -------------------------------------------------------------------------------- /traceml/traceml/processors/events_processors/events_video_processors.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | 3 | from clipped.utils.np import calculate_scale_factor, to_np 4 | from clipped.utils.paths import check_or_create_path, copy_file_path 5 | 6 | from polyaxon._constants.globals import UNKNOWN 7 | from traceml.events import V1EventVideo 8 | from traceml.logger import logger 9 | from traceml.processors.errors import MOVIEPY_ERROR_MESSAGE, NUMPY_ERROR_MESSAGE 10 | 11 | try: 12 | import numpy as np 13 | except ImportError: 14 | np = None 15 | 16 | 17 | def 
video_path( 18 | from_path: str, 19 | asset_path: str, 20 | content_type=None, 21 | asset_rel_path: Optional[str] = None, 22 | ) -> V1EventVideo: 23 | copy_file_path(from_path, asset_path) 24 | return V1EventVideo(path=asset_rel_path or asset_path, content_type=content_type) 25 | 26 | 27 | def video( 28 | asset_path: str, 29 | tensor, 30 | fps=4, 31 | content_type="gif", 32 | asset_rel_path: Optional[str] = None, 33 | ): 34 | if not np: 35 | logger.warning(NUMPY_ERROR_MESSAGE) 36 | return UNKNOWN 37 | 38 | tensor = to_np(tensor) 39 | tensor = prepare_video(tensor) 40 | # If user passes in uint8, then we don't need to rescale by 255 41 | scale_factor = calculate_scale_factor(tensor) 42 | tensor = tensor.astype(np.float32) 43 | tensor = (tensor * scale_factor).astype(np.uint8) 44 | return make_video( 45 | asset_path, tensor, fps, content_type, asset_rel_path=asset_rel_path 46 | ) 47 | 48 | 49 | def make_video( 50 | asset_path: str, 51 | tensor, 52 | fps, 53 | content_type="gif", 54 | asset_rel_path: Optional[str] = None, 55 | ): 56 | try: 57 | import moviepy # noqa: F401 58 | except ImportError: 59 | logger.warning(MOVIEPY_ERROR_MESSAGE) 60 | return UNKNOWN 61 | try: 62 | from moviepy import editor as mpy 63 | except ImportError: 64 | logger.warning( 65 | "moviepy is installed, but can't import moviepy.editor.", 66 | "Some packages could be missing [imageio, requests]", 67 | ) 68 | return 69 | 70 | t, h, w, c = tensor.shape 71 | 72 | # encode sequence of images into gif string 73 | clip = mpy.ImageSequenceClip(list(tensor), fps=fps) 74 | 75 | check_or_create_path(asset_path, is_dir=False) 76 | 77 | try: # older version of moviepy 78 | if content_type == "gif": 79 | clip.write_gif(asset_path, verbose=False, logger=None) 80 | else: 81 | clip.write_videofile(asset_path, verbose=False, logger=None) 82 | except TypeError: 83 | if content_type == "gif": 84 | clip.write_gif(asset_path, verbose=False) 85 | else: 86 | clip.write_videofile(asset_path, verbose=False) 87 | 88 | 
return V1EventVideo( 89 | height=h, 90 | width=w, 91 | colorspace=c, 92 | path=asset_rel_path or asset_path, 93 | content_type=content_type, 94 | ) 95 | 96 | 97 | def prepare_video(data): 98 | """ 99 | Converts a 5D tensor [batchsize, time(frame), channel(color), height, width] 100 | into 4D tensor with dimension [time(frame), new_width, new_height, channel]. 101 | A batch of images are spreaded to a grid, which forms a frame. 102 | e.g. Video with batchsize 16 will have a 4x4 grid. 103 | """ 104 | if not np: 105 | logger.warning(NUMPY_ERROR_MESSAGE) 106 | return UNKNOWN 107 | 108 | b, t, c, h, w = data.shape 109 | 110 | if data.dtype == np.uint8: 111 | data = np.float32(data) / 255.0 112 | 113 | def is_power2(num): 114 | return num != 0 and ((num & (num - 1)) == 0) 115 | 116 | # pad to nearest power of 2, all at once 117 | if not is_power2(data.shape[0]): 118 | len_addition = int(2 ** data.shape[0].bit_length() - data.shape[0]) 119 | data = np.concatenate( 120 | (data, np.zeros(shape=(len_addition, t, c, h, w))), axis=0 121 | ) 122 | 123 | n_rows = 2 ** ((b.bit_length() - 1) // 2) 124 | n_cols = data.shape[0] // n_rows 125 | 126 | data = np.reshape(data, newshape=(n_rows, n_cols, t, c, h, w)) 127 | data = np.transpose(data, axes=(2, 0, 4, 1, 5, 3)) 128 | return np.reshape(data, newshape=(t, n_rows * h, n_cols * w, c)) 129 | -------------------------------------------------------------------------------- /traceml/traceml/processors/gpu_processor.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, List 2 | 3 | from traceml.logger import logger 4 | from traceml.processors.events_processors import metrics_dict_to_list 5 | from traceml.vendor import pynvml 6 | 7 | try: 8 | import psutil 9 | except ImportError: 10 | psutil = None 11 | 12 | 13 | def can_log_gpu_resources(): 14 | if pynvml is None: 15 | return False 16 | 17 | try: 18 | pynvml.nvmlInit() 19 | return True 20 | except pynvml.NVMLError: 21 | return 
False 22 | 23 | 24 | def query_gpu(handle_idx: int, handle: any) -> Dict: 25 | memory = pynvml.nvmlDeviceGetMemoryInfo(handle) # in Bytes 26 | utilization = pynvml.nvmlDeviceGetUtilizationRates(handle) 27 | 28 | return { 29 | "gpu_{}_memory_free".format(handle_idx): int(memory.free), 30 | "gpu_{}_memory_used".format(handle_idx): int(memory.used), 31 | "gpu_{}_utilization".format(handle_idx): utilization.gpu, 32 | } 33 | 34 | 35 | def get_gpu_metrics() -> List: 36 | try: 37 | pynvml.nvmlInit() 38 | device_count = pynvml.nvmlDeviceGetCount() 39 | results = [] 40 | 41 | for i in range(device_count): 42 | handle = pynvml.nvmlDeviceGetHandleByIndex(i) 43 | results += metrics_dict_to_list(query_gpu(i, handle)) 44 | return results 45 | except pynvml.NVMLError: 46 | logger.debug("Failed to collect gpu resources", exc_info=True) 47 | return [] 48 | -------------------------------------------------------------------------------- /traceml/traceml/processors/importance_processors.py: -------------------------------------------------------------------------------- 1 | import math 2 | import numpy as np 3 | import pandas as pd 4 | 5 | from typing import Dict, List, Optional, Tuple, Union 6 | 7 | from clipped.utils.np import sanitize_np_types 8 | 9 | 10 | def clean_duplicates( 11 | params: pd.DataFrame, metrics: pd.DataFrame 12 | ) -> Optional[Tuple[pd.DataFrame, pd.DataFrame]]: 13 | duplicate_ids = metrics.duplicated() 14 | params_df = params[~duplicate_ids] 15 | metrics_df = metrics[~duplicate_ids] 16 | if params.empty or metrics.empty: 17 | return None 18 | 19 | params_df = pd.get_dummies(params_df) 20 | params_df = params_df.loc[:, ~params_df.columns.duplicated()] 21 | return params_df, metrics_df 22 | 23 | 24 | def clean_values( 25 | params: List[Dict], metrics: List[Union[int, float]] 26 | ) -> Optional[Tuple[pd.DataFrame, pd.DataFrame]]: 27 | if not metrics or not params: 28 | return None 29 | 30 | for m in metrics: 31 | if not isinstance(m, (int, float)): 32 | return None 
33 | 34 | metrics_df = pd.DataFrame(metrics) 35 | if metrics_df.isnull().values.any(): 36 | return None 37 | 38 | params_df = pd.DataFrame.from_records(params).replace(r"^\s*$", np.nan, regex=True) 39 | for col in params_df.columns: 40 | if not params_df[col].isnull().sum() == len(params_df[col]): 41 | if params_df[col].dtype == "object": 42 | params_df[col].fillna("NAN", inplace=True) 43 | params_df[col].fillna("NAN", inplace=True) 44 | params_df[col] = params_df[col].astype("category") 45 | elif params_df[col].dtype == "float64" or params_df[col].dtype == "int64": 46 | params_df[col].fillna(params_df[col].mean(), inplace=True) 47 | else: 48 | print("Unexpected Column type: {}".format(params_df[col].dtype)) 49 | else: 50 | if params_df[col].dtype == "object": 51 | params_df[col] = "NAN" 52 | elif params_df[col].dtype == "float64" or params_df[col].dtype == "int64": 53 | params_df[col] = 0 54 | 55 | return clean_duplicates(params_df, metrics_df) 56 | 57 | 58 | def _get_value(x): 59 | if x is None or math.isnan(x): 60 | return None 61 | return round(sanitize_np_types(x), 3) 62 | 63 | 64 | def calculate_importance_correlation( 65 | params: List[Dict], metrics: List[Union[int, float]] 66 | ): 67 | values = clean_values(params, metrics) 68 | if not values: 69 | return None 70 | params_df, metrics_df = values 71 | 72 | corr_list = params_df.corrwith(metrics_df[0]) 73 | 74 | from sklearn.ensemble import ExtraTreesRegressor 75 | 76 | forest = ExtraTreesRegressor(n_estimators=250, random_state=0) 77 | forest.fit(params_df, metrics_df[0]) 78 | feature_importances = forest.feature_importances_ 79 | 80 | results = {} 81 | for i, name in enumerate(params_df.columns): 82 | results[name] = { 83 | "importance": _get_value(feature_importances[i]), 84 | "correlation": _get_value(corr_list[name]), 85 | } 86 | return results 87 | -------------------------------------------------------------------------------- /traceml/traceml/processors/logs_processor.py: 
-------------------------------------------------------------------------------- 1 | import logging 2 | import sys 3 | 4 | from traceml.logging.handler import LogStreamHandler, LogStreamWriter 5 | 6 | EXCLUDE_DEFAULT_LOGGERS = ("polyaxon.client", "polyaxon.cli", "traceml") 7 | 8 | 9 | def start_log_processor(add_logs, exclude=EXCLUDE_DEFAULT_LOGGERS): 10 | plx_logger = logging.getLogger("__plx__") 11 | plx_logger.setLevel(logging.INFO) 12 | if LogStreamHandler in map(type, plx_logger.handlers): 13 | for handler in plx_logger.handlers: 14 | if isinstance(handler, LogStreamHandler): 15 | handler.set_add_logs(add_logs=add_logs) 16 | else: 17 | handler = LogStreamHandler(add_logs=add_logs) 18 | plx_logger.addHandler(handler) 19 | 20 | exclude = ("__plx__",) + (exclude or ()) 21 | for logger_name in exclude: 22 | logger = logging.getLogger(logger_name) 23 | if logging.StreamHandler not in map(type, logger.handlers): 24 | logger.addHandler(logging.StreamHandler()) 25 | logger.propagate = False 26 | 27 | sys.stdout = LogStreamWriter(plx_logger, logging.INFO, sys.__stdout__) 28 | sys.stderr = LogStreamWriter(plx_logger, logging.ERROR, sys.__stderr__) 29 | 30 | 31 | def end_log_processor(): 32 | sys.stdout = sys.__stdout__ 33 | sys.stderr = sys.__stderr__ 34 | -------------------------------------------------------------------------------- /traceml/traceml/processors/psutil_processor.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, List 2 | 3 | from traceml.processors.events_processors import metrics_dict_to_list 4 | 5 | try: 6 | import psutil 7 | except ImportError: 8 | psutil = None 9 | 10 | 11 | def can_log_psutil_resources(): 12 | return psutil is not None 13 | 14 | 15 | def query_psutil() -> Dict: 16 | results = {} 17 | try: 18 | # psutil <= 5.6.2 did not have getloadavg: 19 | if hasattr(psutil, "getloadavg"): 20 | results["load"] = psutil.getloadavg()[0] 21 | else: 22 | # Do not log an empty metric 23 | 
pass 24 | except OSError: 25 | pass 26 | vm = psutil.virtual_memory() 27 | results["cpu"] = psutil.cpu_percent(interval=None) 28 | results["memory"] = vm.percent 29 | return results 30 | 31 | 32 | def get_psutils_metrics() -> List: 33 | return metrics_dict_to_list(query_psutil()) 34 | -------------------------------------------------------------------------------- /traceml/traceml/py.typed: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/polyaxon/traceml/3876c98841ada8d1b13ceb22e4a1f3d33060dfb1/traceml/traceml/py.typed -------------------------------------------------------------------------------- /traceml/traceml/serialization/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /traceml/traceml/serialization/base.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from typing import List 4 | 5 | from clipped.utils.enums import get_enum_value 6 | from clipped.utils.paths import check_or_create_path, set_permissions 7 | 8 | from traceml.artifacts import V1ArtifactKind 9 | from traceml.events import ( 10 | LoggedEventListSpec, 11 | LoggedEventSpec, 12 | get_event_path, 13 | get_resource_path, 14 | ) 15 | 16 | 17 | class EventWriter: 18 | EVENTS_BACKEND = "events" 19 | RESOURCES_BACKEND = "resources" 20 | 21 | def __init__(self, run_path: str, backend: str): 22 | self._events_backend = backend 23 | self._run_path = run_path 24 | self._files = {} # type: dict[str, LoggedEventListSpec] 25 | self._closed = False 26 | 27 | def _get_event_path(self, kind: str, name: str) -> str: 28 | if self._events_backend == self.EVENTS_BACKEND: 29 | return get_event_path( 30 | run_path=self._run_path, 31 | kind=kind, 32 | name=name, 33 | ) 34 | if self._events_backend == self.RESOURCES_BACKEND: 35 | return get_resource_path( 36 | 
run_path=self._run_path, 37 | kind=kind, 38 | name=name, 39 | ) 40 | raise ValueError( 41 | "Unrecognized backend {}".format(get_enum_value(self._events_backend)) 42 | ) 43 | 44 | def _init_events(self, events_spec: LoggedEventListSpec): 45 | event_path = self._get_event_path(kind=events_spec.kind, name=events_spec.name) 46 | # Check if the file exists otherwise initialize 47 | if not os.path.exists(event_path): 48 | check_or_create_path(event_path, is_dir=False) 49 | with open(event_path, "w") as event_file: 50 | if V1ArtifactKind.is_jsonl_file_event(events_spec.kind): 51 | event_file.write("") 52 | else: 53 | event_file.write(events_spec.get_csv_header()) 54 | set_permissions(event_path) 55 | 56 | def _append_events(self, events_spec: LoggedEventListSpec): 57 | event_path = self._get_event_path(kind=events_spec.kind, name=events_spec.name) 58 | with open(event_path, "a") as event_file: 59 | if V1ArtifactKind.is_jsonl_file_event(events_spec.kind): 60 | event_file.write(events_spec.get_jsonl_events()) 61 | else: 62 | event_file.write(events_spec.get_csv_events()) 63 | 64 | def _events_to_files(self, events: List[LoggedEventSpec]): 65 | for event in events: 66 | file_name = "{}.{}".format(event.kind, event.name) 67 | if file_name in self._files: 68 | self._files[file_name].events.append(event.event) 69 | else: 70 | self._files[file_name] = LoggedEventListSpec( 71 | kind=event.kind, name=event.name, events=[event.event] 72 | ) 73 | self._init_events(self._files[file_name]) 74 | 75 | def write(self, events: List[LoggedEventSpec]): 76 | if not events: 77 | return 78 | if isinstance(events, LoggedEventSpec): 79 | events = [events] 80 | self._events_to_files(events) 81 | 82 | def flush(self): 83 | for file_name in self._files: 84 | events_spec = self._files[file_name] 85 | if events_spec.events: 86 | self._append_events(events_spec) 87 | self._files[file_name].empty_events() 88 | 89 | def close(self): 90 | self.flush() 91 | self._closed = True 92 | 93 | @property 94 | 
def closed(self): 95 | return self._closed 96 | 97 | 98 | class BaseFileWriter: 99 | """Writes `LoggedEventSpec` to event files. 100 | 101 | The `EventFileWriter` class creates a event files in the run path, 102 | and asynchronously writes Events to the files. 103 | """ 104 | 105 | def __init__(self, run_path: str): 106 | self._run_path = run_path 107 | check_or_create_path(run_path, is_dir=True) 108 | 109 | @property 110 | def run_path(self): 111 | return self._run_path 112 | 113 | def add_event(self, event: LoggedEventSpec): 114 | if not isinstance(event, LoggedEventSpec): 115 | raise TypeError("Expected an LoggedEventSpec, " " but got %s" % type(event)) 116 | self._async_writer.write(event) 117 | 118 | def add_events(self, events: List[LoggedEventSpec]): 119 | for e in events: 120 | if not isinstance(e, LoggedEventSpec): 121 | raise TypeError("Expected an LoggedEventSpec, " " but got %s" % type(e)) 122 | self._async_writer.write(events) 123 | 124 | def flush(self): 125 | """Flushes the event files to disk. 126 | 127 | Call this method to make sure that all pending events have been 128 | written to disk. 129 | """ 130 | self._async_writer.flush() 131 | 132 | def close(self): 133 | """Performs a final flush of the event files to disk, stops the 134 | write/flush worker and closes the files. 135 | 136 | Call this method when you do not need the writer anymore. 
137 | """ 138 | self._async_writer.close() 139 | -------------------------------------------------------------------------------- /traceml/traceml/serialization/writer.py: -------------------------------------------------------------------------------- 1 | import queue 2 | import threading 3 | import time 4 | 5 | from typing import List, Union 6 | 7 | from clipped.utils.paths import check_or_create_path 8 | 9 | from traceml.events import LoggedEventSpec, get_asset_path, get_event_path 10 | from traceml.events.paths import get_resource_path 11 | from traceml.processors.gpu_processor import can_log_gpu_resources, get_gpu_metrics 12 | from traceml.processors.psutil_processor import ( 13 | can_log_psutil_resources, 14 | get_psutils_metrics, 15 | ) 16 | from traceml.serialization.base import BaseFileWriter, EventWriter 17 | 18 | 19 | class EventFileWriter(BaseFileWriter): 20 | def __init__(self, run_path: str, max_queue_size: int = 20, flush_secs: int = 10): 21 | """Creates a `EventFileWriter`. 22 | 23 | Args: 24 | run_path: A string. Directory where events files will be written. 25 | max_queue_size: Integer. Size of the queue for pending events and summaries. 26 | flush_secs: Number. How often, in seconds, to flush the 27 | pending events and summaries to disk. 28 | """ 29 | super().__init__(run_path=run_path) 30 | 31 | check_or_create_path(get_event_path(run_path), is_dir=True) 32 | check_or_create_path(get_asset_path(run_path), is_dir=True) 33 | 34 | self._async_writer = EventAsyncManager( 35 | EventWriter(self._run_path, backend=EventWriter.EVENTS_BACKEND), 36 | max_queue_size, 37 | flush_secs, 38 | ) 39 | 40 | 41 | class ResourceFileWriter(BaseFileWriter): 42 | def __init__(self, run_path: str, max_queue_size: int = 20, flush_secs: int = 10): 43 | """Creates a `ResourceFileWriter`. 44 | 45 | Args: 46 | run_path: A string. Directory where events files will be written. 47 | max_queue_size: Integer. Size of the queue for pending events and summaries. 
48 | flush_secs: Number. How often, in seconds, to flush the 49 | pending events and summaries to disk. 50 | """ 51 | super().__init__(run_path=run_path) 52 | 53 | check_or_create_path(get_resource_path(run_path), is_dir=True) 54 | 55 | self._async_writer = ResourceAsyncManager( 56 | EventWriter(self._run_path, backend=EventWriter.RESOURCES_BACKEND), 57 | max_queue_size, 58 | flush_secs, 59 | ) 60 | 61 | 62 | class BaseAsyncManager: 63 | """Base manager for writing events to files by name by event kind.""" 64 | 65 | def __init__(self, event_writer: EventWriter, max_queue_size: int = 20): 66 | """Writes events json spec to files asynchronously. An instance of this class 67 | holds a queue to keep the incoming data temporarily. Data passed to the 68 | `write` function will be put to the queue and the function returns 69 | immediately. This class also maintains a thread to write data in the 70 | queue to disk. 71 | 72 | Args: 73 | event_writer: A EventWriter instance 74 | max_queue_size: Integer. Size of the queue for pending bytestrings. 75 | flush_secs: Number. How often, in seconds, to flush the 76 | pending bytestrings to disk. 77 | """ 78 | self._event_writer = event_writer 79 | self._closed = False 80 | self._event_queue = queue.Queue(max_queue_size) 81 | self._lock = threading.Lock() 82 | self._worker = None 83 | 84 | def write(self, event: Union[LoggedEventSpec, List[LoggedEventSpec]]): 85 | """Enqueue the given event to be written asynchronously.""" 86 | with self._lock: 87 | if self._closed: 88 | raise IOError("Writer is closed") 89 | self._event_queue.put(event) 90 | 91 | def flush(self): 92 | """Write all the enqueued events before this flush call to disk. 93 | 94 | Block until all the above events are written. 
95 | """ 96 | with self._lock: 97 | if self._closed: 98 | raise IOError("Writer is closed") 99 | self._event_queue.join() 100 | self._event_writer.flush() 101 | 102 | def close(self): 103 | """Closes the underlying writer, flushing any pending writes first.""" 104 | if not self._closed: 105 | with self._lock: 106 | if not self._closed: 107 | self._closed = True 108 | self._worker.stop() 109 | self._event_writer.flush() 110 | self._event_writer.close() 111 | 112 | 113 | class EventAsyncManager(BaseAsyncManager): 114 | """Writes events to files by name by event kind.""" 115 | 116 | def __init__( 117 | self, event_writer: EventWriter, max_queue_size: int = 20, flush_secs: int = 10 118 | ): 119 | super().__init__(event_writer=event_writer, max_queue_size=max_queue_size) 120 | self._worker = EventWriterThread( 121 | self._event_queue, self._event_writer, flush_secs 122 | ) 123 | self._worker.start() 124 | 125 | 126 | class EventWriterThread(threading.Thread): 127 | """Thread that processes asynchronous writes for EventWriter.""" 128 | 129 | def __init__(self, event_queue, event_writer: EventWriter, flush_secs: int): 130 | """Creates an EventWriterThread. 131 | 132 | Args: 133 | event_queue: A Queue from which to dequeue data. 134 | event_writer: An instance of EventWriter. 135 | flush_secs: How often, in seconds, to flush the 136 | pending file to disk. 137 | """ 138 | threading.Thread.__init__(self) 139 | self.daemon = True 140 | self._event_queue = event_queue 141 | self._event_writer = event_writer 142 | self._flush_secs = flush_secs 143 | # The first data will be flushed immediately. 144 | self._next_flush_time = 0 145 | self._has_pending_data = False 146 | self._shutdown_signal = object() 147 | 148 | def stop(self): 149 | self._event_queue.put(self._shutdown_signal) 150 | self.join() 151 | 152 | def run(self): 153 | # Wait for the queue until data appears, or until the next 154 | # time to flush the writer. 155 | # Invoke write If we have data. 
156 | # If not, an empty queue exception will be raised and invoke writer flush. 157 | while True: 158 | now = time.time() 159 | queue_wait_duration = self._next_flush_time - now 160 | data = None 161 | try: 162 | if queue_wait_duration > 0: 163 | data = self._event_queue.get(True, queue_wait_duration) 164 | else: 165 | data = self._event_queue.get(False) 166 | 167 | if data is self._shutdown_signal: 168 | return 169 | self._event_writer.write(data) 170 | self._has_pending_data = True 171 | except queue.Empty: 172 | pass 173 | finally: 174 | if data: 175 | self._event_queue.task_done() 176 | 177 | now = time.time() 178 | if now > self._next_flush_time: 179 | if self._has_pending_data: 180 | # Small optimization - if there are no pending data, 181 | # there's no need to flush. 182 | self._event_writer.flush() 183 | self._has_pending_data = False 184 | # Do it again in flush_secs. 185 | self._next_flush_time = now + self._flush_secs 186 | 187 | 188 | class ResourceAsyncManager(BaseAsyncManager): 189 | """Writes resource events to files by name by event kind.""" 190 | 191 | def __init__( 192 | self, event_writer: EventWriter, max_queue_size: int = 20, flush_secs: int = 10 193 | ): 194 | super().__init__(event_writer=event_writer, max_queue_size=max_queue_size) 195 | self._worker = ResourceWriterThread( 196 | self._event_queue, 197 | self._event_writer, 198 | flush_secs, 199 | ) 200 | self._worker.start() 201 | 202 | 203 | class ResourceWriterThread(EventWriterThread): 204 | """Thread that processes periodic resources (cpu, gpu, memory) writes for EventWriter.""" 205 | 206 | def __init__(self, event_queue, event_writer: EventWriter, flush_secs: int): 207 | super().__init__( 208 | event_queue=event_queue, event_writer=event_writer, flush_secs=flush_secs 209 | ) 210 | self._log_psutil_resources = can_log_psutil_resources() 211 | self._log_gpu_resources = can_log_gpu_resources() 212 | 213 | def run(self): 214 | # Wait for flush time to invoke the writer. 
215 | while True: 216 | now = time.time() 217 | queue_wait_duration = self._next_flush_time - now 218 | data = None 219 | try: 220 | if queue_wait_duration > 0: 221 | data = self._event_queue.get(True, queue_wait_duration) 222 | else: 223 | data = self._event_queue.get(False) 224 | 225 | if data is self._shutdown_signal: 226 | return 227 | self._event_writer.write(data) 228 | self._has_pending_data = True 229 | except queue.Empty: 230 | pass 231 | finally: 232 | if data: 233 | self._event_queue.task_done() 234 | 235 | now = time.time() 236 | if now > self._next_flush_time: 237 | data = [] 238 | if self._log_psutil_resources: 239 | try: 240 | data += get_psutils_metrics() 241 | except Exception: 242 | pass 243 | try: 244 | data += get_gpu_metrics() 245 | except Exception: 246 | pass 247 | if data: 248 | self._event_writer.write(data) 249 | self._event_writer.flush() 250 | self._next_flush_time = now + self._flush_secs 251 | -------------------------------------------------------------------------------- /traceml/traceml/summary/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /traceml/traceml/summary/df.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pandas as pd 3 | 4 | from traceml.processors import df_processors 5 | 6 | 7 | class DataFrameSummary: 8 | ALL = "ALL" 9 | INCLUDE = "INCLUDE" 10 | EXCLUDE = "EXCLUDE" 11 | 12 | TYPE_BOOL = "bool" 13 | TYPE_NUMERIC = "numeric" 14 | TYPE_DATE = "date" 15 | TYPE_CATEGORICAL = "categorical" 16 | TYPE_CONSTANT = "constant" 17 | TYPE_UNIQUE = "unique" 18 | 19 | def __init__(self, df, plot=False): 20 | self.df = df 21 | self.length = len(df) 22 | self._columns_stats = None 23 | self.plot = plot 24 | 25 | def __getitem__(self, column): 26 | if isinstance(column, str) and df_processors.df_has_column( 27 | df=self.df, 
column=column 28 | ): 29 | return df_processors.get_df_column_summary( 30 | df=self.df, 31 | column=column, 32 | columns_stats=self.columns_stats, 33 | df_length=self.length, 34 | plot=self.plot, 35 | ) 36 | 37 | if isinstance(column, int) and column < self.df.shape[1]: 38 | return df_processors.get_df_column_summary( 39 | df=self.df, 40 | column=self.df.columns[column], 41 | columns_stats=self.columns_stats, 42 | df_length=self.length, 43 | plot=self.plot, 44 | ) 45 | 46 | if isinstance(column, (tuple, list)): 47 | error_keys = [ 48 | k 49 | for k in column 50 | if not df_processors.df_has_column(df=self.df, column=k) 51 | ] 52 | if len(error_keys) > 0: 53 | raise KeyError(", ".join(error_keys)) 54 | return self.df[list(column)].values 55 | 56 | if isinstance(column, pd.Index): 57 | error_keys = [ 58 | k 59 | for k in column.values 60 | if not df_processors.df_has_column(df=self.df, column=k) 61 | ] 62 | if len(error_keys) > 0: 63 | raise KeyError(", ".join(error_keys)) 64 | return self.df[column].values 65 | 66 | if isinstance(column, np.ndarray): 67 | error_keys = [ 68 | k 69 | for k in column 70 | if not df_processors.df_has_column(df=self.df, column=k) 71 | ] 72 | if len(error_keys) > 0: 73 | raise KeyError(", ".join(error_keys)) 74 | return self.df[column].values 75 | 76 | raise KeyError(column) 77 | 78 | @property 79 | def columns_stats(self): 80 | if self._columns_stats: 81 | return self._columns_stats 82 | self._columns_stats = df_processors.get_df_column_stats(self.df) 83 | return self._columns_stats 84 | 85 | @property 86 | def columns_types(self): 87 | return df_processors.get_df_columns_types(self.columns_stats) 88 | 89 | def summary(self): 90 | return pd.concat([self.df.describe(), self.columns_stats], sort=True)[ 91 | self.df.columns 92 | ] 93 | 94 | """ Column summaries """ 95 | 96 | @property 97 | def constants(self): 98 | return self.df.columns[self.columns_stats.loc["types"] == "constant"] 99 | 100 | @property 101 | def categoricals(self): 102 | 
return self.df.columns[self.columns_stats.loc["types"] == "categorical"] 103 | 104 | @property 105 | def numerics(self): 106 | return self.df.columns[self.columns_stats.loc["types"] == "numeric"] 107 | 108 | @property 109 | def uniques(self): 110 | return self.df.columns[self.columns_stats.loc["types"] == "unique"] 111 | 112 | @property 113 | def bools(self): 114 | return self.df.columns[self.columns_stats.loc["types"] == "bool"] 115 | 116 | @property 117 | def missing_frac(self): 118 | return self.columns_stats.loc["missing"].apply(lambda x: float(x) / self.length) 119 | 120 | def get_columns(self, df, usage, columns=None): 121 | """ 122 | Returns a `data_frame.columns`. 123 | :param df: dataframe to select columns from 124 | :param usage: should be a value from [ALL, INCLUDE, EXCLUDE]. 125 | this value only makes sense if attr `columns` is also set. 126 | otherwise, should be used with default value ALL. 127 | :param columns: * if `usage` is all, this value is not used. 128 | * if `usage` is INCLUDE, the `df` is restricted to the intersection 129 | between `columns` and the `df.columns` 130 | * if usage is EXCLUDE, returns the `df.columns` excluding these `columns` 131 | :return: `data_frame` columns, excluding `target_column` and `id_column` if given. 132 | `data_frame` columns, including/excluding the `columns` depending on `usage`. 
133 | """ 134 | columns_excluded = pd.Index([]) 135 | columns_included = df.columns 136 | 137 | if usage == self.INCLUDE: 138 | try: 139 | columns_included = columns_included.intersection(pd.Index(columns)) 140 | except TypeError: 141 | pass 142 | elif usage == self.EXCLUDE: 143 | try: 144 | columns_excluded = columns_excluded.union(pd.Index(columns)) 145 | except TypeError: 146 | pass 147 | 148 | columns_included = columns_included.difference(columns_excluded) 149 | return columns_included.intersection(df.columns) 150 | -------------------------------------------------------------------------------- /traceml/traceml/vendor/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /traceml/traceml/vendor/matplotlylib/LICENSE.txt: -------------------------------------------------------------------------------- 1 | The MIT License (MIT) 2 | 3 | Copyright (c) 2016-2018 Plotly, Inc 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in 13 | all copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN 21 | THE SOFTWARE. 22 | -------------------------------------------------------------------------------- /traceml/traceml/vendor/matplotlylib/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | matplotlylib 3 | ============ 4 | 5 | This module converts matplotlib figure objects into JSON structures which can 6 | be understood and visualized by Plotly. 7 | 8 | Most of the functionality should be accessed through the parent directory's 9 | 'tools' module or 'plotly' package. 10 | 11 | """ 12 | from __future__ import absolute_import 13 | 14 | from .renderer import PlotlyRenderer 15 | from .mplexporter import Exporter 16 | from .tools import mpl_to_plotly 17 | -------------------------------------------------------------------------------- /traceml/traceml/vendor/matplotlylib/mplexporter/__init__.py: -------------------------------------------------------------------------------- 1 | from .renderers import Renderer 2 | from .exporter import Exporter 3 | -------------------------------------------------------------------------------- /traceml/traceml/vendor/matplotlylib/mplexporter/_py3k_compat.py: -------------------------------------------------------------------------------- 1 | """ 2 | Simple fixes for Python 2/3 compatibility 3 | """ 4 | import sys 5 | 6 | PY3K = sys.version_info[0] >= 3 7 | 8 | 9 | if PY3K: 10 | import builtins 11 | import functools 12 | 13 | reduce = functools.reduce 14 | zip = builtins.zip 15 | xrange = builtins.range 16 | map = builtins.map 17 | else: 18 | import __builtin__ 19 | import itertools 20 | 21 | builtins = __builtin__ 22 | reduce = __builtin__.reduce 23 | zip = itertools.izip 24 | xrange = __builtin__.xrange 25 | map = 
itertools.imap 26 | -------------------------------------------------------------------------------- /traceml/traceml/vendor/matplotlylib/mplexporter/renderers/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Matplotlib Renderers 3 | ==================== 4 | This submodule contains renderer objects which define renderer behavior used 5 | within the Exporter class. The base renderer class is :class:`Renderer`, an 6 | abstract base class 7 | """ 8 | 9 | from .base import Renderer 10 | from .vega_renderer import VegaRenderer, fig_to_vega 11 | from .vincent_renderer import VincentRenderer, fig_to_vincent 12 | from .fake_renderer import FakeRenderer, FullFakeRenderer 13 | -------------------------------------------------------------------------------- /traceml/traceml/vendor/matplotlylib/mplexporter/renderers/fake_renderer.py: -------------------------------------------------------------------------------- 1 | from .base import Renderer 2 | 3 | 4 | class FakeRenderer(Renderer): 5 | """ 6 | Fake Renderer 7 | 8 | This is a fake renderer which simply outputs a text tree representing the 9 | elements found in the plot(s). This is used in the unit tests for the 10 | package. 11 | 12 | Below are the methods your renderer must implement. You are free to do 13 | anything you wish within the renderer (i.e. build an XML or JSON 14 | representation, call an external API, etc.) Here the renderer just 15 | builds a simple string representation for testing purposes. 
16 | """ 17 | 18 | def __init__(self): 19 | self.output = "" 20 | 21 | def open_figure(self, fig, props): 22 | self.output += "opening figure\n" 23 | 24 | def close_figure(self, fig): 25 | self.output += "closing figure\n" 26 | 27 | def open_axes(self, ax, props): 28 | self.output += " opening axes\n" 29 | 30 | def close_axes(self, ax): 31 | self.output += " closing axes\n" 32 | 33 | def open_legend(self, legend, props): 34 | self.output += " opening legend\n" 35 | 36 | def close_legend(self, legend): 37 | self.output += " closing legend\n" 38 | 39 | def draw_text( 40 | self, text, position, coordinates, style, text_type=None, mplobj=None 41 | ): 42 | self.output += " draw text '{0}' {1}\n".format(text, text_type) 43 | 44 | def draw_path( 45 | self, 46 | data, 47 | coordinates, 48 | pathcodes, 49 | style, 50 | offset=None, 51 | offset_coordinates="data", 52 | mplobj=None, 53 | ): 54 | self.output += " draw path with {0} vertices\n".format(data.shape[0]) 55 | 56 | def draw_image(self, imdata, extent, coordinates, style, mplobj=None): 57 | self.output += " draw image of size {0}\n".format(len(imdata)) 58 | 59 | 60 | class FullFakeRenderer(FakeRenderer): 61 | """ 62 | Renderer with the full complement of methods. 63 | 64 | When the following are left undefined, they will be implemented via 65 | other methods in the class. They can be defined explicitly for 66 | more efficient or specialized use within the renderer implementation. 
67 | """ 68 | 69 | def draw_line(self, data, coordinates, style, label, mplobj=None): 70 | self.output += " draw line with {0} points\n".format(data.shape[0]) 71 | 72 | def draw_markers(self, data, coordinates, style, label, mplobj=None): 73 | self.output += " draw {0} markers\n".format(data.shape[0]) 74 | 75 | def draw_path_collection( 76 | self, 77 | paths, 78 | path_coordinates, 79 | path_transforms, 80 | offsets, 81 | offset_coordinates, 82 | offset_order, 83 | styles, 84 | mplobj=None, 85 | ): 86 | self.output += " draw path collection " "with {0} offsets\n".format( 87 | offsets.shape[0] 88 | ) 89 | -------------------------------------------------------------------------------- /traceml/traceml/vendor/matplotlylib/mplexporter/renderers/vega_renderer.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | import random 3 | 4 | from clipped.utils.json import orjson_dumps 5 | from .base import Renderer 6 | from ..exporter import Exporter 7 | 8 | 9 | class VegaRenderer(Renderer): 10 | def open_figure(self, fig, props): 11 | self.props = props 12 | self.figwidth = int(props["figwidth"] * props["dpi"]) 13 | self.figheight = int(props["figheight"] * props["dpi"]) 14 | self.data = [] 15 | self.scales = [] 16 | self.axes = [] 17 | self.marks = [] 18 | 19 | def open_axes(self, ax, props): 20 | if len(self.axes) > 0: 21 | warnings.warn("multiple axes not yet supported") 22 | self.axes = [ 23 | dict(type="x", scale="x", ticks=10), 24 | dict(type="y", scale="y", ticks=10), 25 | ] 26 | self.scales = [ 27 | dict(name="x", domain=props["xlim"], type="linear", range="width",), 28 | dict(name="y", domain=props["ylim"], type="linear", range="height",), 29 | ] 30 | 31 | def draw_line(self, data, coordinates, style, label, mplobj=None): 32 | if coordinates != "data": 33 | warnings.warn("Only data coordinates supported. 
import warnings
import random

from clipped.utils.json import orjson_dumps
from .base import Renderer
from ..exporter import Exporter


class VegaRenderer(Renderer):
    """Renderer that replays mplexporter draw calls into a Vega spec.

    Accumulates ``data``, ``scales``, ``axes`` and ``marks`` lists that a
    consumer (e.g. ``VegaHTML``) can assemble into a full specification.
    """

    def open_figure(self, fig, props):
        # Figure size in pixels; props carries size in inches plus dpi.
        self.props = props
        self.figwidth = int(props["figwidth"] * props["dpi"])
        self.figheight = int(props["figheight"] * props["dpi"])
        self.data = []
        self.scales = []
        self.axes = []
        self.marks = []

    def open_axes(self, ax, props):
        if self.axes:
            warnings.warn("multiple axes not yet supported")
        self.axes = [
            {"type": "x", "scale": "x", "ticks": 10},
            {"type": "y", "scale": "y", "ticks": 10},
        ]
        self.scales = [
            {"name": "x", "domain": props["xlim"], "type": "linear", "range": "width"},
            {"name": "y", "domain": props["ylim"], "type": "linear", "range": "height"},
        ]

    def _register_table(self, data):
        # Store the vertex array as a named data table and return its name.
        table_name = f"table{len(self.data) + 1:03d}"
        self.data.append(
            {"name": table_name, "values": [{"x": pt[0], "y": pt[1]} for pt in data]}
        )
        return table_name

    def draw_line(self, data, coordinates, style, label, mplobj=None):
        if coordinates != "data":
            # NOTE(review): despite the message, the original code does not
            # skip -- it warns and renders anyway; behavior preserved here.
            warnings.warn("Only data coordinates supported. Skipping this")
        table_name = self._register_table(data)

        # TODO: respect the other style settings
        enter = {
            "interpolate": {"value": "monotone"},
            "x": {"scale": "x", "field": "data.x"},
            "y": {"scale": "y", "field": "data.y"},
            "stroke": {"value": style["color"]},
            "strokeOpacity": {"value": style["alpha"]},
            "strokeWidth": {"value": style["linewidth"]},
        }
        self.marks.append(
            {
                "type": "line",
                "from": {"data": table_name},
                "properties": {"enter": enter},
            }
        )

    def draw_markers(self, data, coordinates, style, label, mplobj=None):
        if coordinates != "data":
            # NOTE(review): warns but renders anyway, same as draw_line.
            warnings.warn("Only data coordinates supported. Skipping this")
        table_name = self._register_table(data)

        # TODO: respect the other style settings
        enter = {
            "interpolate": {"value": "monotone"},
            "x": {"scale": "x", "field": "data.x"},
            "y": {"scale": "y", "field": "data.y"},
            "fill": {"value": style["facecolor"]},
            "fillOpacity": {"value": style["alpha"]},
            "stroke": {"value": style["edgecolor"]},
            "strokeOpacity": {"value": style["alpha"]},
            "strokeWidth": {"value": style["edgewidth"]},
        }
        self.marks.append(
            {
                "type": "symbol",
                "from": {"data": table_name},
                "properties": {"enter": enter},
            }
        )

    def draw_text(
        self, text, position, coordinates, style, text_type=None, mplobj=None
    ):
        # Only axis labels are mapped into the spec; other text is ignored.
        axis_index = {"xlabel": 0, "ylabel": 1}.get(text_type)
        if axis_index is not None:
            self.axes[axis_index]["title"] = text
class VegaHTML(object):
    """Wrap a VegaRenderer's accumulated output as an IPython display object."""

    def __init__(self, renderer):
        # Snapshot the renderer's state into a single Vega specification dict.
        self.specification = dict(
            width=renderer.figwidth,
            height=renderer.figheight,
            data=renderer.data,
            scales=renderer.scales,
            axes=renderer.axes,
            marks=renderer.marks,
        )

    def html(self):
        """Build the HTML representation for IPython."""
        # Random element id keeps multiple embedded plots on a page distinct.
        # Renamed from ``id`` to avoid shadowing the builtin.
        vis_id = random.randint(0, 2 ** 16)
        # NOTE(review): the markup literals below were stripped from this
        # vendored dump; reconstructed from upstream mplexporter's
        # vega_renderer (with the file's orjson_dumps serializer) -- verify
        # against the original file.
        html = '<div id="vis%d"></div>' % vis_id
        html += "<script>\n"
        html += VEGA_TEMPLATE % (orjson_dumps(self.specification), vis_id)
        html += "</script>\n"
        return html

    def _repr_html_(self):
        # IPython rich-display hook.
        return self.html()


def fig_to_vega(fig, notebook=False):
    """Convert a matplotlib figure to vega dictionary

    if notebook=True, then return an object which will display in a notebook
    otherwise, return an HTML string.
    """
    renderer = VegaRenderer()
    Exporter(renderer).run(fig)
    vega_html = VegaHTML(renderer)
    if notebook:
        return vega_html
    return vega_html.html()


# JS bootstrap: waits for vega (vg) to load inside IPython, then parses the
# interpolated spec (%s) and renders it into the #vis<id> div (%d).
VEGA_TEMPLATE = """
( function() {
  var _do_plot = function() {
    if ( (typeof vg == 'undefined') && (typeof IPython != 'undefined')) {
      $([IPython.events]).on("vega_loaded.vincent", _do_plot);
      return;
    }
    vg.parse.spec(%s, function(chart) {
      chart({el: "#vis%d"}).update();
    });
  };
  _do_plot();
})();
"""
import warnings

from .base import Renderer
from ..exporter import Exporter


class VincentRenderer(Renderer):
    """Renderer that maps mplexporter draw calls onto vincent charts.

    Only one line or scatter element per figure is supported; additional
    elements trigger a warning and are dropped.
    """

    def open_figure(self, fig, props):
        # No chart yet; remember the figure size in pixels.
        self.chart = None
        self.figwidth = int(props["figwidth"] * props["dpi"])
        self.figheight = int(props["figheight"] * props["dpi"])

    def draw_line(self, data, coordinates, style, label, mplobj=None):
        import vincent  # only import if VincentRenderer is used

        if coordinates != "data":
            # NOTE(review): warns but renders anyway; behavior preserved.
            warnings.warn("Only data coordinates supported. Skipping this")
        line = vincent.Line(
            {"x": data[:, 0], "y": data[:, 1]},
            iter_idx="x",
            width=self.figwidth,
            height=self.figheight,
        )

        # TODO: respect the other style settings
        line.scales["color"].range = [style["color"]]

        if self.chart is not None:
            warnings.warn("Multiple plot elements not yet supported")
        else:
            self.chart = line

    def draw_markers(self, data, coordinates, style, label, mplobj=None):
        import vincent  # only import if VincentRenderer is used

        if coordinates != "data":
            # NOTE(review): warns but renders anyway; behavior preserved.
            warnings.warn("Only data coordinates supported. Skipping this")
        scatter = vincent.Scatter(
            {"x": data[:, 0], "y": data[:, 1]},
            iter_idx="x",
            width=self.figwidth,
            height=self.figheight,
        )

        # TODO: respect the other style settings
        scatter.scales["color"].range = [style["facecolor"]]

        if self.chart is not None:
            warnings.warn("Multiple plot elements not yet supported")
        else:
            self.chart = scatter


def fig_to_vincent(fig):
    """Convert a matplotlib figure to a vincent object"""
    renderer = VincentRenderer()
    Exporter(renderer).run(fig)
    return renderer.chart
def ipynb_vega_init():
    """Initialize the IPython notebook display elements

    Injects require.js loaders for d3, topojson and the vega stack into the
    notebook, and triggers the "vega_loaded.vincent" event once everything
    is in place.

    This function borrows heavily from the excellent vincent package:
    http://github.com/wrobstory/vincent
    """
    try:
        from IPython.core.display import display, HTML
    except ImportError:
        print("IPython Notebook could not be loaded.")
        # Fix: the original fell through after printing and crashed below
        # with a NameError on `display`; bail out instead.
        return

    # Doubled braces survive str.format; {0} is replaced with dep_libs.
    require_js = """
    if (window['d3'] === undefined) {{
        require.config({{ paths: {{d3: "http://d3js.org/d3.v3.min"}} }});
        require(["d3"], function(d3) {{
            window.d3 = d3;
            {0}
        }});
    }};
    if (window['topojson'] === undefined) {{
        require.config(
            {{ paths: {{topojson: "http://d3js.org/topojson.v1.min"}} }}
        );
        require(["topojson"], function(topojson) {{
            window.topojson = topojson;
        }});
    }};
    """
    d3_geo_projection_js_url = "http://d3js.org/d3.geo.projection.v0.min.js"
    d3_layout_cloud_js_url = "http://wrobstory.github.io/d3-cloud/" "d3.layout.cloud.js"
    topojson_js_url = "http://d3js.org/topojson.v1.min.js"
    vega_js_url = "http://trifacta.github.com/vega/vega.js"

    # Chain of $.getScript callbacks so each library loads after the last.
    dep_libs = """$.getScript("%s", function() {
        $.getScript("%s", function() {
            $.getScript("%s", function() {
                $.getScript("%s", function() {
                    $([IPython.events]).trigger("vega_loaded.vincent");
                })
            })
        })
    });""" % (
        d3_geo_projection_js_url,
        d3_layout_cloud_js_url,
        topojson_js_url,
        vega_js_url,
    )
    load_js = require_js.format(dep_libs)
    # NOTE(review): the script wrapper was stripped from this vendored dump;
    # reconstructed from upstream mplexporter tools.py -- verify against the
    # original file.
    html = "<script>" + load_js + "</script>"
    display(HTML(html))
from .renderer import PlotlyRenderer
from .mplexporter import Exporter


def mpl_to_plotly(fig, resize=False, strip_style=False, verbose=False):
    """Convert a matplotlib figure into a plotly figure structure.

    mplexporter crawls the ``matplotlib.figure.Figure`` object and replays
    its contents into a ``PlotlyRenderer``, which builds the JSON structure
    plotly uses to make visualizations.

    Parameters
    ----------
    fig : matplotlib.figure.Figure
        The figure to convert.
    resize : bool
        If True, let the renderer adjust the figure size.
    strip_style : bool
        If True, strip matplotlib styling information from the result.
    verbose : bool
        If True, print the renderer's accumulated messages.

    To troubleshoot manually instead of calling this function::

        from plotly.matplotlylib import mplexporter, PlotlyRenderer
        # create an mpl figure and store it under a variable 'fig'
        renderer = PlotlyRenderer()
        exporter = mplexporter.Exporter(renderer)
        exporter.run(fig)

    then inspect ``renderer.layout`` (a plotly layout dictionary) and
    ``renderer.data`` (a list of plotly data dictionaries).
    """

    # Vendored shim: Spine.is_frame_like was removed from matplotlib in
    # https://github.com/matplotlib/matplotlib/pull/16772 (code taken from
    # that PR's diff), but mplexporter still relies on it -- re-attach the
    # old implementation here.
    from matplotlib.spines import Spine

    def is_frame_like(self):
        """Return True if the spine sits directly on the axes frame.

        Useful for detecting the edge of an old-style MPL plot.
        """
        self._ensure_position_is_set()
        position = self._position
        if isinstance(position, str):
            if position == "center":
                position = ("axes", 0.5)
            elif position == "zero":
                position = ("data", 0)
        if len(position) != 2:
            raise ValueError("position should be 2-tuple")
        position_type, amount = position
        return position_type == "outward" and amount == 0

    Spine.is_frame_like = is_frame_like

    renderer = PlotlyRenderer()
    Exporter(renderer).run(fig)
    if resize:
        renderer.resize()
    if strip_style:
        renderer.strip_style()
    if verbose:
        print(renderer.msg)
    return renderer.plotly_fig