├── .ci
│   └── scripts
│       ├── check_gibberish
│       ├── convert_checkpoint.sh
│       ├── download_llama.sh
│       ├── extract-sequence.py
│       ├── gather_test_models.py
│       ├── run-docs
│       ├── validate.sh
│       └── wget_checkpoint.sh
├── .github
│   ├── ISSUE_TEMPLATE
│   │   ├── bug-report.yml
│   │   ├── config.yml
│   │   └── feature-request.yml
│   ├── pytorch-probot.yml
│   └── workflows
│       ├── more-tests.yml
│       ├── periodic.yml
│       ├── pull.yml
│       ├── run-readme-periodic.yml
│       ├── run-readme-pr-linuxaarch64.yml
│       ├── run-readme-pr-macos.yml
│       ├── run-readme-pr-mps.yml
│       ├── run-readme-pr.yml
│       └── runner-cuda-dtype.yml
├── .gitignore
├── .gitmodules
├── .lintrunner.toml
├── CMakeLists.txt
├── CODE_OF_CONDUCT.md
├── CONTRIBUTING.md
├── LICENSE
├── README.md
├── assets
│   ├── dog.jpg
│   └── view.jpg
├── dist_run.py
├── docs
│   ├── ADVANCED-USERS.md
│   ├── CONTRIBUTING.md
│   ├── GGUF.md
│   ├── README.md
│   ├── distributed.md
│   ├── local-model.md
│   ├── model_customization.md
│   ├── multimodal.md
│   ├── native-execution.md
│   └── quantization.md
├── install
│   ├── .clang-format
│   ├── .flake8
│   ├── .pins
│   │   ├── et-pin.txt
│   │   └── torchao-pin.txt
│   ├── install_requirements.sh
│   ├── install_torch.sh
│   ├── install_torchao.sh
│   ├── requirements-lintrunner.txt
│   └── requirements.txt
├── pyproject.toml
├── pytest.ini
├── runner
│   ├── LICENSE
│   ├── Utils.cmake
│   ├── aoti.cmake
│   ├── build_android.sh
│   ├── et.cmake
│   └── run.cpp
├── tests
│   ├── conftest.py
│   └── test_chat_formatters.py
├── tokenizer
│   ├── __init__.py
│   ├── base.py
│   ├── hf_tokenizer.py
│   └── tiktoken.py
├── torchchat.py
└── torchchat
    ├── README.md
    ├── __init__.py
    ├── cli
    │   ├── __init__.py
    │   ├── builder.py
    │   ├── cli.py
    │   ├── convert_hf_checkpoint.py
    │   └── download.py
    ├── distributed
    │   ├── README.md
    │   ├── __init__.py
    │   ├── checkpoint.py
    │   ├── checkpoint_utils.py
    │   ├── config_manager.py
    │   ├── dtensor_utils.py
    │   ├── force_download.py
    │   ├── inference_configs
    │   │   └── llama3_8B.toml
    │   ├── logging_utils.py
    │   ├── parallel_config.py
    │   ├── parallelize_llama.py
    │   ├── run_dist_inference.sh
    │   ├── utils.py
    │   ├── verification_utils.py
    │   ├── version.txt
    │   └── world_maker.py
    ├── edge
    │   ├── README.md
    │   ├── android
    │   │   ├── README.md
    │   │   └── torchchat
    │   │       ├── .gitignore
    │   │       ├── .idea
    │   │       │   └── .gitignore
    │   │       ├── app
    │   │       │   ├── .gitignore
    │   │       │   ├── build.gradle.kts
    │   │       │   ├── proguard-rules.pro
    │   │       │   └── src
    │   │       │       ├── androidTest
    │   │       │       │   └── java/org/pytorch/torchchat
    │   │       │       │       └── PerfTest.java
    │   │       │       └── main
    │   │       │           ├── AndroidManifest.xml
    │   │       │           ├── java/org/pytorch/torchchat
    │   │       │           │   ├── AppLog.java
    │   │       │           │   ├── DemoSharedPreferences.java
    │   │       │           │   ├── ETImage.java
    │   │       │           │   ├── ETLogging.java
    │   │       │           │   ├── LlmBenchmarkRunner.java
    │   │       │           │   ├── LogsActivity.java
    │   │       │           │   ├── LogsAdapter.java
    │   │       │           │   ├── MainActivity.java
    │   │       │           │   ├── Message.java
    │   │       │           │   ├── MessageAdapter.java
    │   │       │           │   ├── MessageType.java
    │   │       │           │   ├── ModelRunner.java
    │   │       │           │   ├── ModelRunnerCallback.java
    │   │       │           │   ├── ModelType.java
    │   │       │           │   ├── ModelUtils.java
    │   │       │           │   ├── PromptFormat.java
    │   │       │           │   ├── SettingsActivity.java
    │   │       │           │   └── SettingsFields.java
    │   │       │           └── res
    │   │       │               ├── drawable
    │   │       │               │   ├── banner_shape.xml
    │   │       │               │   ├── baseline_add_24.xml
    │   │       │               │   ├── baseline_add_photo_alternate_24.xml
    │   │       │               │   ├── baseline_article_24.xml
    │   │       │               │   ├── baseline_close_24.xml
    │   │       │               │   ├── baseline_delete_forever_24.xml
    │   │       │               │   ├── baseline_restart_alt_24.xml
    │   │       │               │   ├── baseline_send_24.xml
    │   │       │               │   ├── baseline_settings_24.xml
    │   │       │               │   ├── baseline_stop_24.xml
    │   │       │               │   ├── btn.xml
    │   │       │               │   ├── chat_background.xml
    │   │       │               │   ├── custom_button_round.xml
    │   │       │               │   ├── expand_circle_down.xml
    │   │       │               │   ├── ic_launcher_background.xml
    │   │       │               │   ├── ic_launcher_foreground.xml
    │   │       │               │   ├── input_text_shape.xml
    │   │       │               │   ├── logo.png
    │   │       │               │   ├── outline_add_box_48.xml
    │   │       │               │   ├── outline_camera_alt_48.xml
    │   │       │               │   ├── outline_image_48.xml
    │   │       │               │   ├── prompt_shape.xml
    │   │       │               │   ├── received_message.xml
    │   │       │               │   ├── sent_message.xml
    │   │       │               │   └── three_dots.xml
    │   │       │               ├── layout
    │   │       │               │   ├── activity_benchmarking.xml
    │   │       │               │   ├── activity_logs.xml
    │   │       │               │   ├── activity_main.xml
    │   │       │               │   ├── activity_settings.xml
    │   │       │               │   ├── logs_message.xml
    │   │       │               │   ├── received_message.xml
    │   │       │               │   ├── sent_message.xml
    │   │       │               │   └── system_message.xml
    │   │       │               ├── mipmap-anydpi-v26
    │   │       │               │   ├── ic_launcher.xml
    │   │       │               │   └── ic_launcher_round.xml
    │   │       │               ├── mipmap-hdpi
    │   │       │               │   ├── ic_launcher.webp
    │   │       │               │   └── ic_launcher_round.webp
    │   │       │               ├── mipmap-mdpi
    │   │       │               │   ├── ic_launcher.webp
    │   │       │               │   └── ic_launcher_round.webp
    │   │       │               ├── mipmap-xhdpi
    │   │       │               │   ├── ic_launcher.webp
    │   │       │               │   └── ic_launcher_round.webp
    │   │       │               ├── mipmap-xxhdpi
    │   │       │               │   ├── ic_launcher.webp
    │   │       │               │   └── ic_launcher_round.webp
    │   │       │               ├── mipmap-xxxhdpi
    │   │       │               │   ├── ic_launcher.webp
    │   │       │               │   └── ic_launcher_round.webp
    │   │       │               ├── values
    │   │       │               │   ├── colors.xml
    │   │       │               │   ├── strings.xml
    │   │       │               │   ├── styles.xml
    │   │       │               │   └── themes.xml
    │   │       │               └── xml
    │   │       │                   ├── backup_rules.xml
    │   │       │                   └── data_extraction_rules.xml
    │   │       ├── build.gradle.kts
    │   │       ├── gradle.properties
    │   │       ├── gradle/wrapper
    │   │       │   ├── gradle-wrapper.jar
    │   │       │   └── gradle-wrapper.properties
    │   │       ├── gradlew
    │   │       ├── gradlew.bat
    │   │       └── settings.gradle.kts
    │   └── docs
    │       ├── Android.md
    │       ├── executorch_setup.md
    │       └── iOS.md
    ├── export.py
    ├── generate.py
    ├── model.py
    ├── model_config
    │   ├── model_config.py
    │   ├── models.json
    │   └── tests
    │       └── test_model_config.py
    ├── model_params
    │   ├── 13B.json
    │   ├── 30B.json
    │   ├── 34B.json
    │   ├── 70B.json
    │   ├── 7B.json
    │   ├── CodeLlama-7b-Python-hf.json
    │   ├── DeepSeek-R1-Distill-Llama-8B.json
    │   ├── Granite-3.0-2B-Instruct.json
    │   ├── Granite-3.0-8B-Instruct.json
    │   ├── Granite-3.1-2B-Instruct.json
    │   ├── Granite-3.1-8B-Instruct.json
    │   ├── Granite-3B-Code.json
    │   ├── Granite-8B-Code.json
    │   ├── Llama-3.2-11B-Vision.json
    │   ├── Llama-Guard-3-1B-INT4.json
    │   ├── Llama-Guard-3-1B.json
    │   ├── Meta-Llama-3-70B.json
    │   ├── Meta-Llama-3-8B.json
    │   ├── Meta-Llama-3.1-70B-Tune.json
    │   ├── Meta-Llama-3.1-70B.json
    │   ├── Meta-Llama-3.1-8B-Tune.json
    │   ├── Meta-Llama-3.1-8B.json
    │   ├── Meta-Llama-3.2-1B.json
    │   ├── Meta-Llama-3.2-3B.json
    │   ├── Mistral-7B.json
    │   ├── llava-1.5.json
    │   ├── stories110M.json
    │   ├── stories15M.json
    │   └── stories42M.json
    ├── quant_config
    │   ├── README.md
    │   ├── cuda-32.json
    │   ├── cuda.json
    │   ├── desktop.json
    │   ├── mobile-32.json
    │   ├── mobile.json
    │   └── pi5.json
    ├── usages
    │   ├── README.md
    │   ├── __init__.py
    │   ├── browser.py
    │   ├── eval.py
    │   ├── openai_api.py
    │   └── server.py
    └── utils
        ├── build_utils.py
        ├── device_info.py
        ├── docs
        │   └── evaluation.md
        ├── gguf_loader.py
        ├── measure_time.py
        ├── quantize.py
        └── scripts
            ├── android_example.sh
            ├── build_native.sh
            ├── build_torchao_ops.sh
            ├── clone_torchao.sh
            ├── install_et.sh
            ├── install_utils.sh
            ├── patch_triton.py
            ├── prepare.sh
            ├── test_flow.sh
            ├── updown.py
            └── workflow.sh

--------------------------------------------------------------------------------
/.ci/scripts/check_gibberish:
--------------------------------------------------------------------------------
#!/bin/bash

# Check the spelling of the specified file as an indicator of
# whether the model produced gibberish - this doesn't catch low-quality
# output, just utter garbage, as a basic backstop

cat "$1"

########################################################################
#
# extract sequence from other output

TMPFILE=/tmp/`basename "$1"`-sequence

if [ "X$2" == "X--no-extract" ]; then
  cp "$1" $TMPFILE
else
  # We extract only the sequence output and don't spell check status and performance stats
  python3 .ci/scripts/extract-sequence.py "$1" >$TMPFILE

  if [ $? -ne 0 ]; then
    echo "Sequence extraction failed. Exiting."
    exit 1
  fi
fi

#######################################################################
#
# check whether the aspell spell checker is available

if command -v aspell &> /dev/null; then
  echo "Checking $TMPFILE for gibberish"
else
  echo "Aspell is not installed or not in PATH."
  echo "Gibberish unchecked in $TMPFILE"
  exit 0
fi

#######################################################################
#
# run spell check on the extracted sequence

cat ${TMPFILE} | aspell -a -c | grep '^[\&#]' >/tmp/out.$$
# Exit with a non-zero status code if there were any spelling errors:
# finding one or more lines starting with & or # means aspell flagged a
# misspelling, which might be gibberish
if [ $? -ne 0 ]; then
  echo "No spelling errors found; likely correct operation. Success."
  exit 0
fi
cat /tmp/out.$$
echo "Spelling errors found; might indicate garbage output. Failing."
exit 1

--------------------------------------------------------------------------------
/.ci/scripts/convert_checkpoint.sh:
--------------------------------------------------------------------------------
#!/bin/bash
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.


set -eu

function convert_checkpoint() {
  local MODEL_REPO="$1"
  local CHECKPOINT_NAME="${MODEL_REPO##*/}"

  if [[ $CHECKPOINT_NAME == *"stories15M"* || $CHECKPOINT_NAME == *"stories42M"* || $CHECKPOINT_NAME == *"stories110M"* ]]; then
    # We need this to make the workflow unique for all models because convert_hf_checkpoint will always convert the checkpoint to model.pth
    pushd "checkpoints/${MODEL_REPO}"
    if [ ! -f "model.pth" ]; then
      mv "$CHECKPOINT_NAME.pt" "model.pth"
    fi
    popd
    return 0
  fi

  [ -f "torchchat/cli/convert_hf_checkpoint.py" ] || exit 1

  if [ -f "checkpoints/$MODEL_REPO/model.pth" ]; then
    echo "Converted checkpoint already exists. Skipping conversion for $MODEL_REPO."
    return 0
  fi
  echo "Converting Hugging Face checkpoint for $MODEL_REPO"
  python3 torchchat/cli/convert_hf_checkpoint.py --checkpoint-dir "checkpoints/$MODEL_REPO"
}


convert_checkpoint $1

--------------------------------------------------------------------------------
/.ci/scripts/download_llama.sh:
--------------------------------------------------------------------------------
#!/usr/bin/env bash

set -xeou pipefail

shopt -s globstar

install_huggingface_cli() {
  pip install -U "huggingface_hub[cli]"
}

download_checkpoint() {
  # This function is "technically re-usable but ymmv"
  # includes the org name, like <org>/<repo>
  local repo_name=$1
  local include=$2
  # the local checkpoint directory mirrors the <org>/<repo> name
  local local_dir="checkpoints/${repo_name}"

  mkdir -p "${local_dir}"
  huggingface-cli download \
    "${repo_name}" \
    --quiet \
    --include "${include}" \
    --local-dir "${local_dir}"
}

# install huggingface-cli if not already installed
if ! command -v huggingface-cli; then
  install_huggingface_cli
fi

# TODO: Eventually you could extend this to download different models
# taking in some arguments similar to .ci/scripts/wget_checkpoint.sh
download_checkpoint "meta-llama/Meta-Llama-3-8B" "original/*"

--------------------------------------------------------------------------------
/.ci/scripts/extract-sequence.py:
--------------------------------------------------------------------------------
import sys


def print_until_equals(filename):
    """Print the generated sequence: the lines between a line starting with
    eight dashes and the next line starting with eight equals signs."""
    output = False
    past_output = False
    with open(filename, "r") as f:
        for line in f:
            if line.startswith("-" * 8):
                output = True
            if output and line.startswith("=" * 8):
                if past_output:
                    print("Double end-of-sequence line")
                    exit(1)
                past_output = True
                output = False
            if output:
                print(line)

    if not past_output:
        print("Did not find sequence to output")
        exit(1)


if __name__ == "__main__":
    if len(sys.argv) < 2:
        print(f"Usage:\n  {sys.executable} {sys.argv[0]} filename")
        sys.exit(1)
    filename = sys.argv[1]
    print_until_equals(filename)

--------------------------------------------------------------------------------
/.ci/scripts/run-docs:
--------------------------------------------------------------------------------
#!/bin/bash -x

# Check if an argument was provided
if [ -z "$1" ]; then
  echo "Must specify document to run"
  exit 1
fi

# Pre-initialize variables
filepath=""
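# A note on the updown.py flags used below (inferred from their usage in this
# repo, so treat this as a best-effort description): --replace takes
# comma-separated old:new rules that are substituted into the commands
# extracted from the target document (e.g. swapping the gated llama3 model
# for the tiny stories15M model), and --suppress skips commands that mention
# the listed strings (here, huggingface-cli login steps needing an HF_TOKEN).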
# cuda supports padding, so no need to replace quantization for now.
# otherwise add: 'cuda.json:cuda-32.json' to the replace rules
parameters="--replace llama3:stories15M,-l3:-l2,mobile.json:mobile-32.json --suppress huggingface-cli,HF_TOKEN"
script_name="./run-${1}.sh" # Dynamically initialize script name

# Use a case statement to handle the $1 argument
case "$1" in
  "readme")
    filepath="README.md"
    parameters="--replace llama3.1:stories15M,-l3:-l2,mobile.json:mobile-32.json --suppress huggingface-cli,HF_TOKEN"
    ;;
  "quantization")
    filepath="docs/quantization.md"
    ;;
  "gguf")
    filepath="docs/GGUF.md"
    ;;
  "advanced")
    filepath="docs/ADVANCED-USERS.md"
    ;;
  "evaluation")
    filepath="torchchat/utils/docs/evaluation.md"
    ;;
  "multimodal")
    filepath="docs/multimodal.md"
    parameters="" # Clear parameters
    ;;
  "native")
    filepath="docs/native-execution.md"
    parameters="" # Clear parameters
    ;;
  "distributed")
    filepath="docs/distributed.md"
    parameters="--replace llama3.1:stories110M,-l3:-l2 --suppress huggingface-cli,HF_TOKEN" # Use stories110M to avoid need for authentication
    ;;
  "local")
    filepath="docs/local-model.md"
    parameters="" # Clear parameters
    ;;

  *)
    echo "Unknown option: $1"
    exit 1
    ;;
esac

# Generate the script
echo "::group::Create script to run $1"
python3 torchchat/utils/scripts/updown.py --file "$filepath" $parameters > "$script_name"
# If something happened to the updown processor and it did not error out, fail with an exit 1
echo "exit 1" >> "$script_name"
echo "::endgroup::"

# Run the script
echo "::group::Run $1"
echo "*******************************************"
cat "$script_name"
echo "*******************************************"
set -x
. "$script_name"
echo "::endgroup::"

--------------------------------------------------------------------------------
/.ci/scripts/wget_checkpoint.sh:
--------------------------------------------------------------------------------
#!/bin/bash
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

set -exu

MODEL_REPO="$1"
RESOURCES_STRING="$2"
CHECKPOINT_NAME="${MODEL_REPO##*/}"

# Create the directory for the checkpoint
mkdir -p "checkpoints/${MODEL_REPO}"
pushd "checkpoints/${MODEL_REPO}" || exit

# Download all resources
IFS=',' # Set the field separator to comma
for resource in $RESOURCES_STRING; do
  echo "Downloading: $resource"
  if ! wget "$resource" 2>&1; then
    echo "Error: Failed to download $resource" >&2
    exit 1
  fi
done

popd || exit

--------------------------------------------------------------------------------
/.github/ISSUE_TEMPLATE/bug-report.yml:
--------------------------------------------------------------------------------
name: 🐛 Bug Report
description: Create a report to help us reproduce and fix the bug

body:
- type: markdown
  attributes:
    value: >
      #### Before submitting a bug, please make sure the issue hasn't been already addressed by searching through [the existing and past issues](https://github.com/pytorch/torchchat/issues?q=is%3Aissue+sort%3Acreated-desc+).
- type: textarea
  attributes:
    label: 🐛 Describe the bug
    description: |
      Please provide a clear and concise description of what the bug is.

      If relevant, add a minimal example so that we can reproduce the error by running the code.

      If the code is too long (hopefully, it isn't), feel free to put it in a public gist and link it in the issue: https://gist.github.com.

      Please also paste or describe the results you observe instead of the expected results. If you observe an error, please paste the error message including the **full** traceback of the exception. It may be relevant to wrap error messages in ```` ```triple quotes blocks``` ````.
    placeholder: |
      A clear and concise description of what the bug is.

      ```python
      # Sample code to reproduce the problem
      ```

      ```
      The error message you got, with the full traceback.
      ```
  validations:
    required: true
- type: textarea
  attributes:
    label: Versions
    description: |
      Please run the following and paste the output below.
      ```sh
      wget https://raw.githubusercontent.com/pytorch/pytorch/main/torch/utils/collect_env.py
      # For security purposes, please check the contents of collect_env.py before running it.
      python collect_env.py
      ```
  validations:
    required: true
- type: markdown
  attributes:
    value: >
      Thanks for contributing 🎉!

--------------------------------------------------------------------------------
/.github/ISSUE_TEMPLATE/config.yml:
--------------------------------------------------------------------------------
.github/ISSUE_TEMPLATE/config.yml

--------------------------------------------------------------------------------
/.github/ISSUE_TEMPLATE/feature-request.yml:
--------------------------------------------------------------------------------
name: 🚀 Feature request
description: Submit a proposal/request for a new torchchat feature

body:
- type: textarea
  attributes:
    label: 🚀 The feature, motivation and pitch
    description: >
      A clear and concise description of the feature proposal. Please outline the motivation for the proposal. Is your feature request related to a specific problem? e.g., *"I'm working on X and would like Y to be possible"*. If this is related to another GitHub issue, please link to it here too.
  validations:
    required: true
- type: textarea
  attributes:
    label: Alternatives
    description: >
      A description of alternative solutions or features you've considered, if any.
- type: textarea
  attributes:
    label: Additional context
    description: >
      Add any other context or screenshots about the feature request.
- type: textarea
  attributes:
    label: RFC (Optional)
    description: >
      Explain the design in enough detail for others to understand the problem, scope, and proposed solution.
- type: markdown
  attributes:
    value: >
      Thanks for contributing 🎉!
--------------------------------------------------------------------------------
/.github/pytorch-probot.yml:
--------------------------------------------------------------------------------
ciflow_push_tags:
- ciflow/periodic

--------------------------------------------------------------------------------
/.github/workflows/run-readme-periodic.yml:
--------------------------------------------------------------------------------
name: Run the README instructions periodically to ensure they work

on:
  schedule:
    - cron: '0 0 * * *' # Runs daily at midnight UTC
  push:
    tags:
      - ciflow/periodic/*
  workflow_dispatch:

jobs:
  test-readme:
    permissions:
      id-token: write
      contents: read
    uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main
    secrets: inherit
    with:
      runner: linux.g5.4xlarge.nvidia.gpu
      secrets-env: "HF_TOKEN_PERIODIC"
      gpu-arch-type: cuda
      gpu-arch-version: "12.4"
      timeout: 60
      script: |
        echo "::group::Print machine info"
        uname -a
        echo "::endgroup::"

        echo "::group::Create script to run README"
        python3 torchchat/utils/scripts/updown.py --create-sections --file README.md > ./run-readme.sh
        # for good measure, if something happened to the updown processor
        # and it did not error out, fail with an exit 1
        echo "exit 1" >> ./run-readme.sh
        echo "::endgroup::"

        echo "::group::Run README"
        echo "*******************************************"
        cat ./run-readme.sh
        echo "*******************************************"
        bash -x ./run-readme.sh
        echo "::endgroup::"


  test-quantization-any:
    permissions:
      id-token: write
      contents: read
    uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main
    secrets: inherit
    with:
      runner: linux.g5.4xlarge.nvidia.gpu
      gpu-arch-type: cuda
      gpu-arch-version: "12.4"
      timeout: 60
      script: |
        echo "::group::Print machine info"
        uname -a
        echo "::endgroup::"

        echo "::group::Create script to run quantization"
        python3 torchchat/utils/scripts/updown.py --create-sections --file docs/quantization.md > ./run-quantization.sh
        # for good measure, if something happened to the updown processor
        # and it did not error out, fail with an exit 1
        echo "exit 1" >> ./run-quantization.sh
        echo "::endgroup::"

        echo "::group::Run quantization"
        echo "*******************************************"
        cat ./run-quantization.sh
        echo "*******************************************"
        bash -x ./run-quantization.sh
        echo "::endgroup::"

  test-gguf-any:
    permissions:
      id-token: write
      contents: read
    uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main
    secrets: inherit
    with:
      runner: linux.g5.4xlarge.nvidia.gpu
      secrets-env: "HF_TOKEN_PERIODIC"
      gpu-arch-type: cuda
      gpu-arch-version: "12.4"
      timeout: 60
      script: |
        echo "::group::Print machine info"
        uname -a
        echo "::endgroup::"

        echo "::group::Create script to run gguf"
        python3 torchchat/utils/scripts/updown.py --file docs/GGUF.md > ./run-gguf.sh
        # for good measure, if something happened to the updown processor
        # and it did not error out, fail with an exit 1
        echo "exit 1" >> ./run-gguf.sh
        echo "::endgroup::"

        echo "::group::Run gguf"
        echo "*******************************************"
        cat ./run-gguf.sh
        echo "*******************************************"
        bash -x ./run-gguf.sh
        echo "::endgroup::"

        echo "::group::Completion"
        echo "tests complete"
        echo "*******************************************"
        echo "::endgroup::"

--------------------------------------------------------------------------------
/.github/workflows/run-readme-pr-linuxaarch64.yml:
--------------------------------------------------------------------------------
name: Run the README instructions - with stories - on Linux aarch64

on:
  pull_request:
  push:
    branches:
      - main
  workflow_dispatch:

jobs:
  test-readme-cpu:
    uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main
    permissions:
      id-token: write
      contents: read
    with:
      runner: linux.arm64.2xlarge
      docker-image: "pytorch/manylinux2_28_aarch64-builder:cpu-aarch64-main"
      gpu-arch-type: cpu-aarch64
      timeout: 60
      script: |
        echo "::group::Print machine info"
        uname -a
        echo "::endgroup::"

        which pip || true
        which pip3 || true
        which conda || true
        # TORCHCHAT_DEVICE=cpu .ci/scripts/run-docs readme

        echo "::group::Completion"
        echo "tests complete"
        echo "*******************************************"
        echo "::endgroup::"

  test-quantization-cpu:
    uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main
    permissions:
      id-token: write
      contents: read
    with:
      runner: linux.arm64.2xlarge
      docker-image: "pytorch/manylinux2_28_aarch64-builder:cpu-aarch64-main"
      gpu-arch-type: cpu-aarch64
      timeout: 60
      script: |
        echo "::group::Print machine info"
        uname -a
        echo "::endgroup::"

        which pip || true
        which pip3 || true
        which conda || true

        # TORCHCHAT_DEVICE=cpu .ci/scripts/run-docs quantization

  test-gguf-cpu:
    uses: pytorch/test-infra/.github/workflows/linux_job.yml@main
    permissions:
      id-token: write
      contents: read
    with:
      runner: linux.arm64.2xlarge
      docker-image: "pytorch/manylinux2_28_aarch64-builder:cpu-aarch64-main"
      gpu-arch-type: cpu-aarch64
      timeout: 60
      script: |
        echo "::group::Print machine info"
        uname -a
        echo "::endgroup::"

        which pip || true
        which pip3 || true
        which conda || true

        # TORCHCHAT_DEVICE=cpu .ci/scripts/run-docs gguf

        echo "::group::Completion"
        echo "tests complete"
        echo "*******************************************"
        echo "::endgroup::"

  test-advanced-cpu:
    uses: pytorch/test-infra/.github/workflows/linux_job.yml@main
    permissions:
      id-token: write
      contents: read
    with:
      runner: linux.arm64.2xlarge
      docker-image: "pytorch/manylinux2_28_aarch64-builder:cpu-aarch64-main"
      gpu-arch-type: cpu-aarch64
      timeout: 60
      script: |
        echo "::group::Print machine info"
        uname -a
        echo "::endgroup::"

        which pip || true
        which pip3 || true
        which conda || true

        # TORCHCHAT_DEVICE=cpu .ci/scripts/run-docs advanced

        echo "::group::Completion"
        echo "tests complete"
        echo "*******************************************"
        echo "::endgroup::"

  test-evaluation-cpu:
    uses: pytorch/test-infra/.github/workflows/linux_job.yml@main
    permissions:
      id-token: write
      contents: read
    with:
      runner: linux.arm64.2xlarge
      docker-image: "pytorch/manylinux2_28_aarch64-builder:cpu-aarch64-main"
      gpu-arch-type: cpu-aarch64
      timeout: 60
      script: |
        echo "::group::Print machine info"
        uname -a
        echo "::endgroup::"

        which pip || true
        which pip3 || true
        which conda || true

        # TORCHCHAT_DEVICE=cpu .ci/scripts/run-docs evaluation

        echo "::group::Completion"
        echo "tests complete"
        echo "*******************************************"
        echo "::endgroup::"

--------------------------------------------------------------------------------
/.github/workflows/runner-cuda-dtype.yml:
--------------------------------------------------------------------------------
name: Run the aoti runner with CUDA using stories

on:
  pull_request:
  push:
    branches:
      - main
  workflow_dispatch:

jobs:
  test-runner-aot-cuda:
    permissions:
      id-token: write
      contents: read
    uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main
    with:
      runner: linux.g5.4xlarge.nvidia.gpu
      secrets-env: "HF_TOKEN_PERIODIC"
      gpu-arch-type: cuda
      gpu-arch-version: "12.4"
      timeout: 60
      script: |
        echo "::group::Print machine info"
        uname -a
        echo "::endgroup::"

        echo "::group::Install requirements"
        ./install/install_requirements.sh cuda
        bash torchchat/utils/scripts/build_native.sh aoti
        pip3 list
        python3 -c 'import torch;print(f"torch: {torch.__version__, torch.version.git_version}")'
        echo "::endgroup::"

        echo "::group::Download checkpoints"
        mkdir -p checkpoints/stories15M
        pushd checkpoints/stories15M
        wget https://huggingface.co/karpathy/tinyllamas/resolve/main/stories15M.pt
        wget https://github.com/karpathy/llama2.c/raw/master/tokenizer.model
        popd
        echo "::endgroup::"

        echo "::group::Run inference"
        export MODEL_PATH=checkpoints/stories15M/stories15M.pt
        export MODEL_NAME=stories15M
        export MODEL_DIR=/tmp

        set -eou pipefail
        export MODEL_DIR=${PWD}/checkpoints/stories15M
        export PROMPT="Once upon a time in a land far away"

        for DTYPE in bfloat16; do
          python torchchat.py generate --dtype ${DTYPE} --checkpoint-path ${MODEL_DIR}/stories15M.pt --temperature 0 --prompt "${PROMPT}" --device cuda

          python torchchat.py export --checkpoint-path ${MODEL_DIR}/stories15M.pt --output-aoti-package-path /tmp/model.pt2

          ./cmake-out/aoti_run /tmp/model.pt2 -z ${MODEL_DIR}/tokenizer.model -i "${PROMPT}"

        done

        echo "tests complete"
        echo "*******************************************"
        echo "::endgroup::"

--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

.vscode
.model-artifacts/
.venv
.torchchat

# Build directories
build/android/*
et-build/*
torchao-build/*
runner-et/cmake-out/*
runner-aoti/cmake-out/*
cmake-out/

# Example project Android Studio ignore
torchchat/edge/android/torchchat/.idea/*


# pte files
*.pte

# debug / logging files
system_info.txt

# intermediate system file
.DS_Store

# build artifacts
checkpoints/
exportedModels/

# test script
_torchchat_test_script.py

--------------------------------------------------------------------------------
/.gitmodules:
--------------------------------------------------------------------------------
[submodule "runner/third-party/tokenizers"]
	path = runner/third-party/tokenizers
	url = https://github.com/pytorch-labs/tokenizers

--------------------------------------------------------------------------------
/.lintrunner.toml:
--------------------------------------------------------------------------------
merge_base_with = "origin/main"

[[linter]]
code = 'FLAKE8'
include_patterns = ['**/*.py']
command = [
    'python3',
    '-m',
    'lintrunner_adapters',
    'run',
    'flake8_linter',
    '--',
    '@{{PATHSFILE}}'
]
init_command = [
    'python3',
    '-m',
    'lintrunner_adapters',
    'run',
    'pip_init',
    '--dry-run={{DRYRUN}}',
    '--requirement=install/requirements-lintrunner.txt',
]

# Black + usort
[[linter]]
code = 'UFMT'
include_patterns = [
    '**/*.py',
    '**/*.pyi',
]
command = [
    'python3',
    '-m',
    'lintrunner_adapters',
    'run',
    'ufmt_linter',
    '--',
    '@{{PATHSFILE}}'
]
init_command = [
    'python3',
    '-m',
    'lintrunner_adapters',
    'run',
    'pip_init',
    '--dry-run={{DRYRUN}}',
    '--no-black-binary',
    '--requirement=install/requirements-lintrunner.txt',
]
is_formatter = true

# CLANGFORMAT
[[linter]]
code = 'CLANGFORMAT'
include_patterns = [
    '**/*.h',
    '**/*.cpp',
]
command = [
    'python3',
    '-m',
    'lintrunner_adapters',
    'run',
    'clangformat_linter',
    '--binary=clang-format',
    '--fallback',
    '--',
    '@{{PATHSFILE}}'
]
init_command = [
    'python3',
    '-m',
    'lintrunner_adapters',
    'run',
    'pip_init',
    '--dry-run={{DRYRUN}}',
    '--requirement=install/requirements-lintrunner.txt',
]
is_formatter = true

--------------------------------------------------------------------------------
/CMakeLists.txt:
--------------------------------------------------------------------------------
cmake_minimum_required(VERSION 3.24)
set(CMAKE_CXX_STANDARD 17)
IF(DEFINED ENV{TORCHCHAT_ROOT})
  set(TORCHCHAT_ROOT $ENV{TORCHCHAT_ROOT})
ELSE()
  set(TORCHCHAT_ROOT ${CMAKE_CURRENT_SOURCE_DIR}/..)
ENDIF()

project(Torchchat)
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-attributes")

# include tokenizer
add_subdirectory(runner/third-party/tokenizers)

# include et_run executable
include(runner/et.cmake)
if(TARGET et_run)
  target_link_libraries(et_run PUBLIC tokenizers microkernels-prod)
  target_include_directories(et_run PUBLIC runner/third-party/tokenizers/include)
endif()

# include aoti_run executable
include(runner/aoti.cmake)
if(TARGET aoti_run)
  target_link_libraries(aoti_run tokenizers)
  target_include_directories(aoti_run PUBLIC runner/third-party/tokenizers/include)
endif()

--------------------------------------------------------------------------------
/CODE_OF_CONDUCT.md:
--------------------------------------------------------------------------------
# Code of Conduct

## Our Pledge

In the interest of fostering an open and welcoming environment, we as
contributors and maintainers pledge to make participation in our project and
our community a harassment-free experience for everyone, regardless of age, body
size, disability, ethnicity, sex characteristics, gender identity and expression,
level of experience, education, socio-economic status, nationality, personal
appearance, race, religion, or sexual identity and orientation.

## Our Standards

Examples of behavior that contributes to creating a positive environment
include:

* Using welcoming and inclusive language
* Being respectful of differing viewpoints and experiences
* Gracefully accepting constructive criticism
* Focusing on what is best for the community
* Showing empathy towards other community members

Examples of unacceptable behavior by participants include:

* The use of sexualized language or imagery and unwelcome sexual attention or
  advances
* Trolling, insulting/derogatory comments, and personal or political attacks
* Public or private harassment
* Publishing others' private information, such as a physical or electronic
  address, without explicit permission
* Other conduct which could reasonably be considered inappropriate in a
  professional setting

## Our Responsibilities

Project maintainers are responsible for clarifying the standards of acceptable
behavior and are expected to take appropriate and fair corrective action in
response to any instances of unacceptable behavior.

Project maintainers have the right and responsibility to remove, edit, or
reject comments, commits, code, wiki edits, issues, and other contributions
that are not aligned to this Code of Conduct, or to ban temporarily or
permanently any contributor for other behaviors that they deem inappropriate,
threatening, offensive, or harmful.

## Scope

This Code of Conduct applies within all project spaces, and it also applies when
an individual is representing the project or its community in public spaces.
Examples of representing a project or community include using an official
project e-mail address, posting via an official social media account, or acting
as an appointed representative at an online or offline event. Representation of
a project may be further defined and clarified by project maintainers.

## Enforcement

Instances of abusive, harassing, or otherwise unacceptable behavior may be
reported by contacting the project team at . All
complaints will be reviewed and investigated and will result in a response that
is deemed necessary and appropriate to the circumstances. The project team is
obligated to maintain confidentiality with regard to the reporter of an incident.
Further details of specific enforcement policies may be posted separately.

Project maintainers who do not follow or enforce the Code of Conduct in good
faith may face temporary or permanent repercussions as determined by other
members of the project's leadership.

## Attribution

This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4,
available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html

[homepage]: https://www.contributor-covenant.org

For answers to common questions about this code of conduct, see
https://www.contributor-covenant.org/faq

--------------------------------------------------------------------------------
/CONTRIBUTING.md:
--------------------------------------------------------------------------------
# Contributing to torchchat
We want to make contributing to this project as easy and transparent as
possible.


## Pull Requests
We actively welcome your pull requests.

1. Fork the repo and create your branch from `main`.
2. If you've added code that should be tested, add tests.
3. If you've changed APIs, update the documentation.
4. Ensure the test suite passes.
5. Make sure your code is well-formatted using the repo linter. See "Linting" for details.
6. If you haven't already, complete the Contributor License Agreement ("CLA").


### Linting
Install the lintrunner dependencies from the requirements file.
```
pip3 install -r install/requirements-lintrunner.txt
```

After making your changes locally, run the lintrunner and apply all suggestions to your changes.
You can do this from the top-level torchchat directory - it will apply suggestions only to files that
you have touched.
```
lintrunner -a
```

## Contributor License Agreement ("CLA")
In order to accept your pull request, we need you to submit a CLA. You only need
to do this once to work on any of Meta's open source projects.

Complete your CLA here: 

## Issues
We use GitHub issues to track public bugs. Please ensure your description is
clear and has sufficient instructions to be able to reproduce the issue.

Meta has a [bounty program](https://www.facebook.com/whitehat/) for the safe
disclosure of security bugs. In those cases, please go through the process
outlined on that page and do not file a public issue.

## License
By contributing to `torchchat`, you agree that your contributions will be licensed
under the LICENSE file in the root directory of this source tree.

--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
BSD 3-Clause License

Copyright (c) 2024 Meta

Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:

1. Redistributions of source code must retain the above copyright notice, this
   list of conditions and the following disclaimer.

2. Redistributions in binary form must reproduce the above copyright notice,
   this list of conditions and the following disclaimer in the documentation
   and/or other materials provided with the distribution.

3. Neither the name of the copyright holder nor the names of its
   contributors may be used to endorse or promote products derived from
   this software without specific prior written permission.

THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

--------------------------------------------------------------------------------
/assets/dog.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pytorch/torchchat/71552fdea3361aaecc37ca3f2d8d29b7ce3e8901/assets/dog.jpg

--------------------------------------------------------------------------------
/assets/view.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pytorch/torchchat/71552fdea3361aaecc37ca3f2d8d29b7ce3e8901/assets/view.jpg

--------------------------------------------------------------------------------
/docs/GGUF.md:
--------------------------------------------------------------------------------
> [!WARNING]
> Files in this directory may be outdated, incomplete, scratch notes, or a WIP. torchchat provides no guarantees on these files as references. Please refer to the root README for stable features and documentation.

# Using GGUF Models

We support parsing [GGUF](https://github.com/ggerganov/ggml/blob/master/docs/gguf.md) files with
the following tensor types:
- F16
- F32
- Q4_0
- Q6_K

If an unsupported type is encountered while parsing a GGUF file, an
exception is raised.

We now go over an example of using GGUF files in the torchchat flow.

### Download resources

First download a GGUF model and tokenizer. In this example, we use a
Q4_0 GGUF file. (Note that Q4_0 is only the dominant tensor type in
the file; the file also contains GGUF tensors of types Q6_K, F16,
and F32.)
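
If you want to double-check which tensor types a given GGUF file actually contains before feeding it to torchchat, you can inspect it with the `gguf` Python package that torchchat already lists in its requirements. The snippet below is a minimal sketch, not part of torchchat itself; the reader API used here (`GGUFReader`, `.tensors`, `.tensor_type`) reflects the gguf-py package and may vary between versions. Run it after downloading the file in the next step.

```python
# Sketch: count tensors per GGML quantization type in a GGUF file.
from collections import Counter

from gguf import GGUFReader  # `gguf` is in torchchat's requirements

reader = GGUFReader("ggufs/open_orca/open_orca.Q4_0.gguf")
counts = Counter(t.tensor_type.name for t in reader.tensors)
print(counts)

# torchchat parses F16, F32, Q4_0, and Q6_K; anything else raises an exception.
unsupported = set(counts) - {"F16", "F32", "Q4_0", "Q6_K"}
if unsupported:
    print(f"Tensor types torchchat cannot parse: {unsupported}")
```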

```
# Download resources
mkdir -p ggufs/open_orca
pushd ggufs/open_orca

curl -o open_orca.Q4_0.gguf "https://huggingface.co/TheBloke/TinyLlama-1.1B-1T-OpenOrca-GGUF/resolve/main/tinyllama-1.1b-1t-openorca.Q4_0.gguf?download=true"
curl -o ./tokenizer.model https://github.com/karpathy/llama2.c/raw/master/tokenizer.model

popd

export GGUF_MODEL_PATH=ggufs/open_orca/open_orca.Q4_0.gguf
export GGUF_TOKENIZER_PATH=ggufs/open_orca/tokenizer.model

# Define export paths for examples below
export GGUF_SO_PATH=/tmp/gguf_model.so
export GGUF_PTE_PATH=/tmp/gguf_model.pte
```

### Eager generate
We can generate text in eager mode as we did before, but we now pass a GGUF file path.
```
python3 torchchat.py generate --gguf-path ${GGUF_MODEL_PATH} --tokenizer-path ${GGUF_TOKENIZER_PATH} --temperature 0 --prompt "Once upon a time" --max-new-tokens 15
```

### AOTI export + generate
```
# Convert the model for use
python3 torchchat.py export --gguf-path ${GGUF_MODEL_PATH} --output-dso-path ${GGUF_SO_PATH}

# Generate using the DSO model that was created by the export command
python3 torchchat.py generate --gguf-path ${GGUF_MODEL_PATH} --dso-path ${GGUF_SO_PATH} --tokenizer-path ${GGUF_TOKENIZER_PATH} --temperature 0 --prompt "Once upon a time" --max-new-tokens 15

```

### ExecuTorch export + generate
Before running this example, you must first [set up ExecuTorch](torchchat/edge/docs/executorch_setup.md).
```
# Convert the model for use
python3 torchchat.py export --gguf-path ${GGUF_MODEL_PATH} --output-pte-path ${GGUF_PTE_PATH}

# Generate using the PTE model that was created by the export command
python3 torchchat.py generate --gguf-path ${GGUF_MODEL_PATH} --pte-path ${GGUF_PTE_PATH} --tokenizer-path ${GGUF_TOKENIZER_PATH} --temperature 0 --prompt "Once upon a time" --max-new-tokens 15
```

### Advanced: loading unsupported GGUF formats in torchchat
GGUF formats not presently supported natively in torchchat can be
converted to one of the supported formats with GGUF's
[quantize](https://github.com/ggerganov/llama.cpp/tree/master/examples/quantize) utility.
If you convert to the FP16 or FP32 formats with GGUF's quantize utility, you can
then requantize these models with torchchat's native quantization workflow.

**Please note that quantizing and dequantizing is a lossy process, and
you will get the best results by starting with the original
unquantized model, not a previously quantized and then
dequantized model.**

As an example, suppose you have [llama.cpp cloned and installed](https://github.com/ggerganov/llama.cpp) at ${GGUF}.
You can then convert a model to FP16 with the following command:

[skip default]: begin
```
${GGUF}/quantize --allow-requantize path_of_model_you_are_converting_from.gguf path_for_model_you_are_converting_to.gguf fp16
```
[skip default]: end

For example, to convert the quantized model you downloaded above to an FP16 model, you would execute:
```
${GGUF}/quantize --allow-requantize ${GGUF_MODEL_PATH} ./open_orca_fp16.gguf fp16
```

After the model is converted to a supported format like FP16, you can proceed using the instructions above.

[end default]: end

--------------------------------------------------------------------------------
/docs/README.md:
--------------------------------------------------------------------------------
# Most Docs in this directory are unstable

Explicitly calling out that the docs in this directory may be outdated, incomplete, scratch notes, or a WIP.
torchchat provides no guarantees on these files as references.

Please refer to the root README for stable features and documentation.

---

Docs that are updated and used as **Source of Truth**:
- [Model Customization](model_customization.md)
- [Quantization](quantization.md)

--------------------------------------------------------------------------------
/docs/model_customization.md:
--------------------------------------------------------------------------------
# Model Customization

By default, torchchat (and PyTorch) uses unquantized [eager execution](https://pytorch.org/blog/optimizing-production-pytorch-performance-with-graph-transformations/).

This page goes over the different options torchchat provides for customizing model execution for inference:
- Device
- Compilation
- Model Precision
- Quantization


## Device

```
python3 (chat | generate | browser | server | export | eval) --device [cpu | cuda | mps] ...
```

To leverage a specific accelerator, the target device can be set.

By default, torchchat uses the fastest executor available on the system, chosen in this
order: cuda, mps, and cpu.


## Compilation: JIT-compiled execution
```
python3 (chat | generate | browser | server | eval) [--compile] [--compile_prefill] ...
```

To improve performance, you can compile the model with `--compile`,
trading off longer time to first token for faster time per token.

To improve performance further, at the cost of increased compile time, you may also compile the
prefill with `--compile_prefill`.

To learn more about compilation, check out: https://pytorch.org/get-started/pytorch-2.0/

For CPU, you can use `--max-autotune` to further improve performance together with `--compile` and `--compile_prefill`.

See the [`max-autotune` on CPU tutorial](https://pytorch.org/tutorials/prototype/max_autotune_on_CPU_tutorial.html).

## Model Precision

```
python3 (chat | generate | browser | server | export | eval) --dtype [fast | fast16 | bf16 | fp16 | fp32] ...
```

To reduce the memory bandwidth requirement and to take advantage of higher-density compute,
the model can use lower-precision floating point representations.
For example, many GPUs and some CPUs have good support for bfloat16 and float16.

Unlike gpt-fast, which uses bfloat16 as its default, torchchat defaults to the dtype
"fast16". This picks the best-performing 16-bit floating point type
available (for execution with ExecuTorch, and on macOS/ARM and Linux/x86 platforms).
For example, on macOS support depends on the OS version: versions 14.0 and later
support bfloat16, while earlier versions use float16, based on system support
for these data types.

The "fast" data type is also provided as a virtual data type that defaults
to the best floating point data type available on the selected device.
Currently, this behaves the same as "fast16", except that it resolves to "fp32" when
exporting to ExecuTorch.
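
To make the dtype selection concrete, here is a minimal sketch of how a "fast16"-style choice could be made. This is an illustration under the assumptions noted in the comments, not torchchat's actual selection code:

```python
# Illustrative sketch only; torchchat's real dtype resolution lives in its
# builder utilities and may differ in detail.
import platform

import torch


def pick_fast16(device: str) -> torch.dtype:
    if device == "cuda":
        # Ampere and newer NVIDIA GPUs support bfloat16; older ones fall back to float16.
        return torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
    if device == "mps":
        # Per the paragraph above: macOS 14.0+ supports bfloat16, earlier versions use float16.
        macos_major = int((platform.mac_ver()[0] or "0").split(".")[0])
        return torch.bfloat16 if macos_major >= 14 else torch.float16
    # CPU: bfloat16 is generally well supported on modern x86 and ARM.
    return torch.bfloat16


print(pick_fast16("cuda" if torch.cuda.is_available() else "cpu"))
```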


## Quantization

```
python3 (chat | generate | browser | server | export | eval) [--quantize] ...
```

To further minimize memory requirements, accelerate inference speeds, and
decrease power consumption, the model can also be quantized.
Torchchat leverages [torchao](https://github.com/pytorch/ao) for quantization.

See the [quantization guide](quantization.md) for examples and more details.

--------------------------------------------------------------------------------
/install/.flake8:
--------------------------------------------------------------------------------
[flake8]
select = B,C,E,F,P,W,B9,TOR0,TOR1,TOR2
max-line-length = 120
ignore =
    # Black conflicts and overlaps.
    B950,
    E111,
    E115,
    E117,
    E121,
    E122,
    E123,
    E124,
    E125,
    E126,
    E127,
    E128,
    E129,
    E131,
    E201,
    E202,
    E203,
    E221,
    E222,
    E225,
    E226,
    E227,
    E231,
    E241,
    E251,
    E252,
    E261,
    E262,
    E265,
    E271,
    E272,
    E301,
    E302,
    E303,
    E305,
    E306,
    E501,
    E502,
    E701,
    E702,
    E703,
    E704,
    W291,
    W292,
    W293,
    W391,
    W504,

    # Too opinionated.
    E265,
    E266,
    E402,
    E722,
    B001,
    P207,
    B003,
    P208,
    C403,
    W503,

    # Bugbear has opinions: https://github.com/PyCQA/flake8-bugbear#opinionated-warnings
    B904,
    B905,
    B906,
    B907,
exclude =
    ./.git,
    *.pyi

max-complexity = 12

--------------------------------------------------------------------------------
/install/.pins/et-pin.txt:
--------------------------------------------------------------------------------
b173722085b3f555d6ba4533d6bbaddfd7c71144

--------------------------------------------------------------------------------
/install/.pins/torchao-pin.txt:
--------------------------------------------------------------------------------
b95cf189e4aca1a44886258c40e2c834ca0d1045

--------------------------------------------------------------------------------
/install/install_requirements.sh:
--------------------------------------------------------------------------------
#!/bin/bash
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

set -eou pipefail

# Install required python dependencies for developing
# Dependencies are defined in pyproject.toml
if [ -z "${PYTHON_EXECUTABLE:-}" ];
then
  if [[ -z ${CONDA_DEFAULT_ENV:-} ]] || [[ ${CONDA_DEFAULT_ENV:-} == "base" ]] || [[ ! -x "$(command -v python)" ]];
  then
    PYTHON_EXECUTABLE=python3
  else
    PYTHON_EXECUTABLE=python
  fi
fi
echo "Using python executable: $PYTHON_EXECUTABLE"

PYTHON_SYS_VERSION="$($PYTHON_EXECUTABLE -c "import sys; print(f'{sys.version_info.major}.{sys.version_info.minor}')")"
# Check python version. Expect at least 3.10.x
if ! $PYTHON_EXECUTABLE -c "
import sys
if sys.version_info < (3, 10):
    sys.exit(1)
";
then
  echo "Python version must be at least 3.10.x. Detected version: $PYTHON_SYS_VERSION"
  exit 1
fi

if [[ "$PYTHON_EXECUTABLE" == "python" ]];
then
  PIP_EXECUTABLE=pip
elif [[ "$PYTHON_EXECUTABLE" == "python3" ]];
then
  PIP_EXECUTABLE=pip3
else
  PIP_EXECUTABLE=pip${PYTHON_SYS_VERSION}
fi

echo "Using pip executable: $PIP_EXECUTABLE"

(
  set -x
  $PIP_EXECUTABLE install -r install/requirements.txt
)

bash install/install_torch.sh

--------------------------------------------------------------------------------
/install/install_torch.sh:
--------------------------------------------------------------------------------
#!/bin/bash
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

if [ -z "${PYTHON_EXECUTABLE:-}" ];
then
  if [[ -z ${CONDA_DEFAULT_ENV:-} ]] || [[ ${CONDA_DEFAULT_ENV:-} == "base" ]] || [[ ! -x "$(command -v python)" ]];
  then
    PYTHON_EXECUTABLE=python3
  else
    PYTHON_EXECUTABLE=python
  fi
fi
echo "Using python executable: $PYTHON_EXECUTABLE"

if [[ "$PYTHON_EXECUTABLE" == "python" ]];
then
  PIP_EXECUTABLE=pip
elif [[ "$PYTHON_EXECUTABLE" == "python3" ]];
then
  PIP_EXECUTABLE=pip3
else
  PIP_EXECUTABLE=pip${PYTHON_SYS_VERSION}
fi
echo "Using pip executable: $PIP_EXECUTABLE"

# Since torchchat often uses main-branch features of pytorch, only the nightly
# pip versions will have the required features. The PYTORCH_NIGHTLY_VERSION value should
# agree with the third-party/pytorch pinned submodule commit.
#
# NOTE: If a newly-fetched version of the executorch repo changes the value of
# PYTORCH_NIGHTLY_VERSION, you should re-run this script to install the necessary
# package versions.
PYTORCH_NIGHTLY_VERSION=dev20250418

# Nightly version for torchvision
VISION_NIGHTLY_VERSION=dev20250418

# Nightly version for torchtune
TUNE_NIGHTLY_VERSION=dev20250418

# The pip repository that hosts nightly torch packages. cpu by default.
# If cuda is available, based on presence of nvidia-smi, install the pytorch nightly
# with cuda for faster execution on cuda GPUs.
if [[ -x "$(command -v nvidia-smi)" ]];
then
  TORCH_NIGHTLY_URL="https://download.pytorch.org/whl/nightly/cu126"
elif [[ -x "$(command -v rocminfo)" ]];
then
  TORCH_NIGHTLY_URL="https://download.pytorch.org/whl/nightly/rocm6.2"
elif [[ -x "$(command -v xpu-smi)" ]];
then
  TORCH_NIGHTLY_URL="https://download.pytorch.org/whl/nightly/xpu"
else
  TORCH_NIGHTLY_URL="https://download.pytorch.org/whl/nightly/cpu"
fi

# pip packages needed by exir.
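# ("exir" refers to the ExecuTorch IR stack; the nightlies selected below are
# the torch/torchvision/torchtune builds that the rest of torchchat relies on.)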
if [[ -x "$(command -v xpu-smi)" ]];
then
  REQUIREMENTS_TO_INSTALL=(
    torch=="2.8.0.${PYTORCH_NIGHTLY_VERSION}"
    torchvision=="0.22.0.${VISION_NIGHTLY_VERSION}"
    #torchtune=="0.7.0" # no 0.6.0 on xpu nightly
  )
elif [[ -x "$(command -v npu-smi)" ]];
then
  REQUIREMENTS_TO_INSTALL=(
    torch=="2.7.0.dev20250310+cpu"
    torchvision=="0.22.0.dev20250310"
    torchtune=="0.6.0"
  )
else
  REQUIREMENTS_TO_INSTALL=(
    torch=="2.8.0.${PYTORCH_NIGHTLY_VERSION}"
    torchvision=="0.22.0.${VISION_NIGHTLY_VERSION}"
    torchtune=="0.7.0.${TUNE_NIGHTLY_VERSION}"
  )
fi

# Uninstall triton, as the nightly will depend on pytorch-triton, which is one and the same
(
  set -x
  $PIP_EXECUTABLE uninstall -y triton
)

# Install the requirements. --extra-index-url tells pip to look for package
# versions on the provided URL if they aren't available on the default URL.
(
  set -x
  $PIP_EXECUTABLE install --extra-index-url "${TORCH_NIGHTLY_URL}" \
    "${REQUIREMENTS_TO_INSTALL[@]}"
)

# Temporarily install the torchtune nightly from the cpu nightly link, since there is no torchtune nightly for xpu yet
# TODO: Change to install torchtune from the xpu nightly link, once the torchtune xpu nightly is ready
if [[ -x "$(command -v xpu-smi)" ]];
then
  (
    set -x
    $PIP_EXECUTABLE install --extra-index-url "https://download.pytorch.org/whl/nightly/cpu" \
      torchtune=="0.6.0.${TUNE_NIGHTLY_VERSION}"
  )
fi

bash install/install_torchao.sh

# Delete since already patched in PT main
if [[ -x "$(command -v nvidia-smi)" ]]; then
  (
    set -x
    $PYTHON_EXECUTABLE torchchat/utils/scripts/patch_triton.py
  )
fi

--------------------------------------------------------------------------------
/install/install_torchao.sh:
--------------------------------------------------------------------------------
#!/bin/bash
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.


# USE_CPP=1 indicates that the torchao experimental aten kernels will be built and loaded
# if on Mac with Apple Silicon

if [ -z "${PYTHON_EXECUTABLE:-}" ];
then
-x "$(command -v python)" ]]; 15 | then 16 | PYTHON_EXECUTABLE=python3 17 | else 18 | PYTHON_EXECUTABLE=python 19 | fi 20 | fi 21 | echo "Using python executable: $PYTHON_EXECUTABLE" 22 | 23 | if [[ "$PYTHON_EXECUTABLE" == "python" ]]; 24 | then 25 | PIP_EXECUTABLE=pip 26 | elif [[ "$PYTHON_EXECUTABLE" == "python3" ]]; 27 | then 28 | PIP_EXECUTABLE=pip3 29 | else 30 | PIP_EXECUTABLE=pip${PYTHON_SYS_VERSION} 31 | fi 32 | echo "Using pip executable: $PIP_EXECUTABLE" 33 | 34 | 35 | export TORCHAO_PIN=$(cat install/.pins/torchao-pin.txt) 36 | ( 37 | set -x 38 | USE_CPP=1 $PIP_EXECUTABLE install git+https://github.com/pytorch/ao.git@${TORCHAO_PIN} 39 | ) 40 | -------------------------------------------------------------------------------- /install/requirements-lintrunner.txt: -------------------------------------------------------------------------------- 1 | # Lintrunner itself 2 | lintrunner==0.11.0 3 | lintrunner-adapters==0.11.0 4 | 5 | # Flake 8 and its dependencies 6 | flake8==6.0.0 7 | flake8-breakpoint==1.1.0 8 | flake8-bugbear==23.6.5 9 | flake8-comprehensions==3.12.0 10 | flake8-pyi==23.5.0 11 | mccabe==0.7.0 12 | pycodestyle==2.10.0 13 | torchfix==0.1.1 14 | 15 | # UFMT 16 | black==24.3.0 17 | ufmt==2.5.1 18 | usort==1.0.5 19 | 20 | # Clang 21 | clang-format==18.1.3 22 | -------------------------------------------------------------------------------- /install/requirements.txt: -------------------------------------------------------------------------------- 1 | # Requires python >=3.10 2 | 3 | # Hugging Face download 4 | huggingface_hub 5 | 6 | # GGUF import 7 | gguf 8 | 9 | # Tiktoken tokenizer for Llama 3 and other advanced models 10 | tiktoken 11 | 12 | # Tokenizers and jinja2 for other non-llama models that use HF tokenizers 13 | tokenizers 14 | jinja2 15 | 16 | # Miscellaneous 17 | snakeviz 18 | sentencepiece 19 | numpy >= 1.17 20 | blobfile 21 | tomli >= 1.1.0 ; python_version < "3.11" 22 | openai 23 | 24 | # Build tools 25 | wheel 26 | cmake>=3.24, < 4.0.0 # 4.0 is BC breaking 27 | ninja 28 | zstd 29 | 30 | # Test tools 31 | pytest 32 | 33 | # Browser mode 34 | streamlit 35 | 36 | # Server mode 37 | flask 38 | 39 | # eval 40 | lm_eval==0.4.7 41 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["setuptools", "wheel"] 3 | build-backend = "setuptools.build_meta" 4 | 5 | [project] 6 | name = "torchchat" 7 | version = "0.1.0" 8 | description = "PyTorch showcase for running LLMs on local devices" 9 | authors = [ 10 | {name="PyTorch Team", email="packages@pytorch.org"}, 11 | ] 12 | license = { file = "LICENSE" } 13 | keywords = ["pytorch", "machine learning", "llm"] 14 | readme = "README.md" 15 | 16 | requires-python = ">=3.10" 17 | dependencies=[ 18 | # Hugging Face downloads 19 | "huggingface_hub", 20 | 21 | # GGUF import 22 | "gguf", 23 | 24 | # Tiktoken tokenizer for Llama 3 and other advanced models 25 | "tiktoken", 26 | 27 | # Tokenizers and jinja2 for other non-llama models that use HF tokenizers 28 | "tokenizers", 29 | "jinja2", 30 | 31 | # Miscellaneous 32 | "snakeviz", 33 | "sentencepiece", 34 | "numpy>=1.17", 35 | "blobfile", 36 | "tomli>=1.1.0; python_version<'3.11'", 37 | "openai", 38 | 39 | # Build tools 40 | "wheel", 41 | "cmake>=3.24,<4.0.0", # 4.0 is BC breaking 42 | "ninja", 43 | "zstd", 44 | 45 | # Test tools 46 | "pytest", 47 | 48 | # Browser mode 49 | "streamlit", 50 | 51 | # Server mode 52 | "flask", 
53 | 54 | # eval 55 | "lm-eval==0.4.7", 56 | ] 57 | 58 | [tool.setuptools] 59 | packages = ["torchchat"] 60 | -------------------------------------------------------------------------------- /pytest.ini: -------------------------------------------------------------------------------- 1 | [pytest] 2 | markers = 3 | model_config: Tests related to model config 4 | -------------------------------------------------------------------------------- /runner/LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 Andrej Karpathy 4 | Copyright (c) 2024 Meta Platforms 5 | 6 | Permission is hereby granted, free of charge, to any person obtaining a copy 7 | of this software and associated documentation files (the "Software"), to deal 8 | in the Software without restriction, including without limitation the rights 9 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 10 | copies of the Software, and to permit persons to whom the Software is 11 | furnished to do so, subject to the following conditions: 12 | 13 | The above copyright notice and this permission notice shall be included in all 14 | copies or substantial portions of the Software. 15 | 16 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 17 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 18 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 19 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 20 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 21 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 22 | SOFTWARE. 23 | -------------------------------------------------------------------------------- /runner/Utils.cmake: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | 8 | # This is the function that uses -Wl,--whole-archive to link a static library. NB: 9 | # target_link_options is broken for this case; it only appends the interface link 10 | # options of the first library. 11 | function(kernel_link_options target_name) 12 | # target_link_options(${target_name} INTERFACE 13 | # "$<LINK_LIBRARY:WHOLE_ARCHIVE,${target_name}>") 14 | target_link_options( 15 | ${target_name} INTERFACE "SHELL:LINKER:--whole-archive \ 16 | $<TARGET_FILE:${target_name}> \ 17 | LINKER:--no-whole-archive") 18 | endfunction() 19 | 20 | # Same as kernel_link_options but for the macOS linker 21 | function(macos_kernel_link_options target_name) 22 | target_link_options(${target_name} INTERFACE 23 | "SHELL:LINKER:-force_load,$<TARGET_FILE:${target_name}>") 24 | endfunction() 25 | 26 | # Ensure that the load-time constructor functions run. By default, the linker 27 | # would remove them since there are no other references to them. 28 | function(target_link_options_shared_lib target_name) 29 | if(APPLE) 30 | macos_kernel_link_options(${target_name}) 31 | else() 32 | kernel_link_options(${target_name}) 33 | endif() 34 | endfunction() 35 | -------------------------------------------------------------------------------- /runner/aoti.cmake: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # All rights reserved.
4 | # 5 | # This source code is licensed under the BSD-style license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | 8 | cmake_minimum_required(VERSION 3.24) 9 | set(CMAKE_CXX_STANDARD 17) 10 | IF(DEFINED ENV{TORCHCHAT_ROOT}) 11 | set(TORCHCHAT_ROOT $ENV{TORCHCHAT_ROOT}) 12 | ELSE() 13 | set(TORCHCHAT_ROOT ${CMAKE_CURRENT_SOURCE_DIR}) 14 | ENDIF() 15 | 16 | find_package(CUDA) 17 | 18 | find_package(Torch 2.4.0) 19 | if(Torch_FOUND) 20 | set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -g ${TORCH_CXX_FLAGS} -fpermissive") 21 | 22 | add_executable(aoti_run runner/run.cpp) 23 | target_compile_options(aoti_run PUBLIC -D__AOTI_MODEL__) 24 | if(DEFINED TORCH_CUDA_LIBRARIES) 25 | target_compile_options(aoti_run PUBLIC -DUSE_CUDA) 26 | endif() 27 | target_include_directories(aoti_run PRIVATE ${TORCHCHAT_ROOT}/runner) 28 | target_link_libraries(aoti_run "${TORCH_LIBRARIES}" m) 29 | set_property(TARGET aoti_run PROPERTY CXX_STANDARD 17) 30 | endif() 31 | 32 | if (LINK_TORCHAO_OPS) 33 | target_link_libraries(aoti_run "${TORCHCHAT_ROOT}/torchao-build/cmake-out/lib/libtorchao_ops_aten${CMAKE_SHARED_LIBRARY_SUFFIX}") 34 | endif() 35 | -------------------------------------------------------------------------------- /runner/build_android.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the BSD-style license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | 8 | set -ex 9 | 10 | source "$(dirname "${BASH_SOURCE[0]}")/../torchchat/utils/scripts/install_utils.sh" 11 | 12 | if [ "${ANDROID_NDK}" == "" ]; then 13 | echo "Please set the ANDROID_NDK environment variable." 14 | echo "For example it can be /Users/guest/Desktop/android-ndk-r26." 15 | echo "You can use the setup_android_ndk function in torchchat/utils/scripts/android_example.sh" 16 | echo "to set it up, or you can download it from the Android NDK website." 17 | exit 1 18 | else 19 | echo "ANDROID_NDK set to ${ANDROID_NDK}" 20 | fi 21 | 22 | export ET_BUILD_DIR="et-build-android" 23 | export CMAKE_OUT_DIR="cmake-out-android" 24 | export EXECUTORCH_BUILD_KERNELS_CUSTOM_AOT="OFF" 25 | export EXECUTORCH_BUILD_KERNELS_CUSTOM="ON" 26 | export CMAKE_OUT_DIR="cmake-out-android" 27 | 28 | build_runner_et() { 29 | rm -rf cmake-out-android 30 | echo "ET BUILD DIR IS ${ET_BUILD_DIR}" 31 | cmake -DET_USE_ADAPTIVE_THREADS=ON -DCMAKE_BUILD_TYPE=Release -DCMAKE_TOOLCHAIN_FILE=$ANDROID_NDK/build/cmake/android.toolchain.cmake -DANDROID_ABI=arm64-v8a -DANDROID_PLATFORM=android-23 -S . -B cmake-out-android -G Ninja 32 | cmake --build cmake-out-android/ -j16 --config Release --target et_run 33 | } 34 | 35 | find_cmake_prefix_path 36 | install_pip_dependencies 37 | clone_executorch 38 | export ENABLE_ET_PYBIND=false 39 | install_executorch_python_libs $ENABLE_ET_PYBIND 40 | 41 | export CMAKE_TOOLCHAIN_FILE=$ANDROID_NDK/build/cmake/android.toolchain.cmake 42 | export ANDROID_ABI=arm64-v8a 43 | export ANDROID_PLATFORM=android-23 44 | install_executorch_cpp_libs 45 | build_runner_et 46 | -------------------------------------------------------------------------------- /runner/et.cmake: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # All rights reserved.
4 | # 5 | # This source code is licensed under the BSD-style license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | 8 | cmake_minimum_required(VERSION 3.24) 9 | set(CMAKE_CXX_STANDARD 17) 10 | 11 | IF(DEFINED ENV{ET_BUILD_DIR}) 12 | set(ET_BUILD_DIR $ENV{ET_BUILD_DIR}) 13 | ELSE() 14 | set(ET_BUILD_DIR "et-build") 15 | ENDIF() 16 | 17 | MESSAGE(STATUS "Using ET BUILD DIR: --[${ET_BUILD_DIR}]--") 18 | 19 | IF(DEFINED ENV{CMAKE_OUT_DIR}) 20 | set(CMAKE_OUT_DIR $ENV{CMAKE_OUT_DIR}) 21 | ELSE() 22 | set(CMAKE_OUT_DIR "cmake-out") 23 | ENDIF() 24 | 25 | IF(DEFINED ENV{TORCHCHAT_ROOT}) 26 | set(TORCHCHAT_ROOT $ENV{TORCHCHAT_ROOT}) 27 | ELSE() 28 | set(TORCHCHAT_ROOT ${CMAKE_CURRENT_SOURCE_DIR}) 29 | ENDIF() 30 | 31 | project(Torchchat) 32 | 33 | IF(CMAKE_INSTALL_PREFIX_INITIALIZED_TO_DEFAULT) 34 | SET(CMAKE_INSTALL_PREFIX ${TORCHCHAT_ROOT}/${ET_BUILD_DIR}/install CACHE PATH "Setting it to a default value" FORCE) 35 | ENDIF(CMAKE_INSTALL_PREFIX_INITIALIZED_TO_DEFAULT) 36 | 37 | # Building for Android. Since Android overwrites CMAKE_FIND_ROOT_PATH, the normal 38 | # CMAKE_INSTALL_PREFIX won't work. Redirect CMAKE_FIND_ROOT_PATH to it. 39 | # This should cover any cross compilation, but let's handle Android for now. 40 | if(ANDROID) 41 | set(CMAKE_FIND_ROOT_PATH "${CMAKE_INSTALL_PREFIX}") 42 | endif() 43 | 44 | include(CMakePrintHelpers) 45 | include(runner/Utils.cmake) 46 | 47 | cmake_print_variables(TORCHCHAT_ROOT) 48 | 49 | MESSAGE(STATUS "Looking for executorch in ${CMAKE_INSTALL_PREFIX}") 50 | 51 | find_package(executorch CONFIG HINTS ${CMAKE_INSTALL_PREFIX}) 52 | 53 | if(executorch_FOUND) 54 | set(_common_include_directories ${TORCHCHAT_ROOT}/${ET_BUILD_DIR}/src) 55 | 56 | cmake_print_variables(_common_include_directories) 57 | 58 | set(_srcs runner/run.cpp) 59 | set(_common_compile_options -D__ET__MODEL -D_GLIBCXX_USE_CXX11_ABI=1) 60 | if(ET_USE_ADAPTIVE_THREADS) 61 | list(APPEND _common_compile_options -DET_USE_ADAPTIVE_THREADS) 62 | 63 | set(EXECUTORCH_SRC_ROOT ${TORCHCHAT_ROOT}/${ET_BUILD_DIR}/src/executorch) 64 | set(XNNPACK_ROOT ${EXECUTORCH_SRC_ROOT}/backends/xnnpack) 65 | list(APPEND _srcs ${EXECUTORCH_SRC_ROOT}/extension/threadpool/cpuinfo_utils.cpp) 66 | list(APPEND _common_include_directories 67 | ${XNNPACK_ROOT}/third-party/cpuinfo/include) 68 | 69 | list(APPEND _common_include_directories 70 | ${XNNPACK_ROOT}/third-party/pthreadpool/include) 71 | endif() 72 | add_library(custom_ops STATIC IMPORTED) 73 | set_property(TARGET custom_ops PROPERTY IMPORTED_LOCATION ${TORCHCHAT_ROOT}/${ET_BUILD_DIR}/install/lib/libcustom_ops.a) 74 | 75 | target_include_directories(executorch INTERFACE ${_common_include_directories}) # Ideally the ExecuTorch installation process would do this 76 | add_executable(et_run ${_srcs}) 77 | 78 | target_compile_options(et_run PUBLIC ${_common_compile_options}) 79 | 80 | # Link ET runtime + extensions 81 | target_link_libraries( 82 | et_run PRIVATE 83 | executorch 84 | extension_module 85 | extension_tensor 86 | extension_data_loader 87 | extension_threadpool 88 | optimized_kernels 89 | quantized_kernels 90 | portable_kernels 91 | cpublas 92 | eigen_blas 93 | # The libraries below need to be whole-archive linked 94 | optimized_native_cpu_ops_lib 95 | quantized_ops_lib 96 | xnnpack_backend 97 | microkernels-prod 98 | XNNPACK 99 | pthreadpool 100 | cpuinfo 101 | custom_ops 102 | ) 103 | target_link_options_shared_lib(optimized_native_cpu_ops_lib) 104 | target_link_options_shared_lib(quantized_ops_lib) 105 |
target_link_options_shared_lib(xnnpack_backend) 106 | target_link_options_shared_lib(custom_ops) 107 | 108 | # It is not clear why linking executorch as whole-archive outside Android/Apple leads 109 | # to double registration; most likely because of linkage issues. 110 | # Will figure this out later. Until then, use this. 111 | if(ANDROID OR APPLE) 112 | target_link_options_shared_lib(executorch) 113 | endif() 114 | 115 | # This one is needed for cpuinfo where it uses the Android-specific log lib 116 | if(ANDROID) 117 | target_link_libraries(et_run PRIVATE log) 118 | endif() 119 | 120 | if(LINK_TORCHAO_OPS) 121 | target_link_libraries(et_run PRIVATE "$<LINK_LIBRARY:WHOLE_ARCHIVE,${TORCHCHAT_ROOT}/torchao-build/cmake-out/lib/libtorchao_ops_executorch.a>") 122 | target_link_libraries(et_run PRIVATE 123 | "${TORCHCHAT_ROOT}/torchao-build/cmake-out/lib/libtorchao_kernels_aarch64.a" 124 | ) 125 | endif() 126 | 127 | else() 128 | MESSAGE(WARNING "ExecuTorch package not found") 129 | endif() 130 | -------------------------------------------------------------------------------- /tests/conftest.py: -------------------------------------------------------------------------------- 1 | """ 2 | Global pytest config, fixtures, and helpers go here! 3 | """ 4 | 5 | # Standard 6 | import os 7 | import sys 8 | 9 | # Make sure tests can import torchchat 10 | sys.path.append( 11 | os.path.realpath(os.path.join(os.path.dirname(__file__), "..")) 12 | ) 13 | -------------------------------------------------------------------------------- /tokenizer/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pytorch/torchchat/71552fdea3361aaecc37ca3f2d8d29b7ce3e8901/tokenizer/__init__.py -------------------------------------------------------------------------------- /tokenizer/base.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | """ 7 | Abstract base class for all tokenizer classes in Python, matching the C++ interface. 8 | """ 9 | 10 | # Standard 11 | from abc import ABC, abstractmethod 12 | from typing import List 13 | 14 | 15 | class TokenizerBase(ABC): 16 | __doc__ = __doc__ 17 | 18 | @abstractmethod 19 | def encode(self, s: str, *, bos: bool = False, eos: bool = False) -> List[int]: 20 | """Encode the given string and optionally include bos/eos tokens""" 21 | 22 | @abstractmethod 23 | def decode(self, ids: List[int]) -> str: 24 | """Decode the given token ids into a string""" 25 | 26 | @abstractmethod 27 | def bos_id(self) -> int: 28 | """The id of the begin-of-string token""" 29 | 30 | @abstractmethod 31 | def eos_id(self) -> int: 32 | """The id of the end-of-string token""" 33 | -------------------------------------------------------------------------------- /torchchat.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree.
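# A minimal sketch (editor's illustration, not part of this file) of how the
# TokenizerBase interface from tokenizer/base.py above can be implemented;
# the toy vocabulary and the bos/eos id choices here are hypothetical:
#
#   from tokenizer.base import TokenizerBase
#
#   class ToyTokenizer(TokenizerBase):
#       VOCAB = {"hello": 3, "world": 4}        # hypothetical two-word vocab
#       INV = {v: k for k, v in VOCAB.items()}  # reverse map for decode
#
#       def encode(self, s, *, bos=False, eos=False):
#           ids = [self.VOCAB[w] for w in s.split()]
#           return ([self.bos_id()] if bos else []) + ids + ([self.eos_id()] if eos else [])
#
#       def decode(self, ids):
#           return " ".join(self.INV[i] for i in ids if i in self.INV)
#
#       def bos_id(self):
#           return 1
#
#       def eos_id(self):
#           return 2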
6 | 7 | import argparse 8 | import logging 9 | import signal 10 | import sys 11 | 12 | # MPS ops missing with Multimodal torchtune 13 | # https://github.com/pytorch/torchtune/issues/1723 14 | import os 15 | os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1" 16 | 17 | from torchchat.cli.cli import ( 18 | add_arguments_for_verb, 19 | arg_init, 20 | check_args, 21 | INVENTORY_VERBS, 22 | KNOWN_VERBS, 23 | ) 24 | 25 | default_device = "cpu" 26 | 27 | 28 | def signal_handler(sig, frame): 29 | print("\nInterrupted by user. Bye!\n") 30 | sys.exit(0) 31 | 32 | 33 | if __name__ == "__main__": 34 | # Set the signal handler for SIGINT 35 | signal.signal(signal.SIGINT, signal_handler) 36 | 37 | # Initialize the top-level parser 38 | parser = argparse.ArgumentParser( 39 | prog="torchchat", 40 | add_help=True, 41 | ) 42 | 43 | subparsers = parser.add_subparsers( 44 | dest="command", 45 | help="The specific command to run", 46 | ) 47 | subparsers.required = True 48 | 49 | VERB_HELP = { 50 | "chat": "Chat interactively with a model via the CLI", 51 | "generate": "Generate responses from a model given a prompt", 52 | "browser": "Chat interactively with a model in a locally hosted browser", 53 | "export": "Export a model artifact to AOT Inductor or ExecuTorch", 54 | "download": "Download model artifacts", 55 | "list": "List all supported models", 56 | "remove": "Remove downloaded model artifacts", 57 | "where": "Return directory containing downloaded model artifacts", 58 | "server": "[WIP] Starts a locally hosted REST server for model interaction", 59 | "eval": "Evaluate a model via lm-eval", 60 | } 61 | for verb, description in VERB_HELP.items(): 62 | subparser = subparsers.add_parser(verb, help=description) 63 | add_arguments_for_verb(subparser, verb) 64 | 65 | # Now parse the arguments 66 | args = parser.parse_args() 67 | 68 | # Don't initialize for Inventory management subcommands 69 | # TODO: Remove when arg_init is refactored 70 | if args.command not in INVENTORY_VERBS: 71 | args = arg_init(args) 72 | logging.basicConfig( 73 | format="%(message)s", level=logging.DEBUG if args.verbose else logging.INFO 74 | ) 75 | 76 | if args.command == "chat": 77 | # enable "chat" 78 | args.chat = True 79 | check_args(args, "chat") 80 | from generate import main as generate_main 81 | 82 | generate_main(args) 83 | elif args.command == "browser": 84 | print( 85 | "\nTo test out the browser please use: streamlit run torchchat/usages/browser.py \n" 86 | ) 87 | elif args.command == "server": 88 | check_args(args, "server") 89 | from torchchat.usages.server import main as server_main 90 | 91 | server_main(args) 92 | elif args.command == "generate": 93 | check_args(args, "generate") 94 | from torchchat.generate import main as generate_main 95 | 96 | generate_main(args) 97 | elif args.command == "eval": 98 | from torchchat.usages.eval import main as eval_main 99 | 100 | eval_main(args) 101 | elif args.command == "export": 102 | check_args(args, "export") 103 | from torchchat.export import main as export_main 104 | 105 | export_main(args) 106 | elif args.command == "download": 107 | check_args(args, "download") 108 | from torchchat.cli.download import download_main 109 | 110 | download_main(args) 111 | elif args.command == "list": 112 | check_args(args, "list") 113 | from torchchat.cli.download import list_main 114 | 115 | list_main(args) 116 | elif args.command == "where": 117 | check_args(args, "where") 118 | from torchchat.cli.download import where_main 119 | 120 | where_main(args) 121 | elif args.command == "remove": 122 | 
check_args(args, "remove") 123 | from torchchat.cli.download import remove_main 124 | 125 | remove_main(args) 126 | else: 127 | parser.print_help() 128 | -------------------------------------------------------------------------------- /torchchat/README.md: -------------------------------------------------------------------------------- 1 | # Chat with LLMs Everywhere 2 | 3 | This directory is a WIP path that will host most of the files currently living in root 4 | -------------------------------------------------------------------------------- /torchchat/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pytorch/torchchat/71552fdea3361aaecc37ca3f2d8d29b7ce3e8901/torchchat/__init__.py -------------------------------------------------------------------------------- /torchchat/cli/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pytorch/torchchat/71552fdea3361aaecc37ca3f2d8d29b7ce3e8901/torchchat/cli/__init__.py -------------------------------------------------------------------------------- /torchchat/distributed/README.md: -------------------------------------------------------------------------------- 1 | # The Work in this directory is Experimental 2 | 3 | Please refer to the root README for stable features and documentation. 4 | 5 | ## What is this directory for? 6 | 7 | 8 | This directory is home to the upcoming support for multi-node, very large scale inference - think Llama 405B and Mistral 123B. 9 | 10 | More updates shortly! 11 | -------------------------------------------------------------------------------- /torchchat/distributed/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | from torchchat.distributed.checkpoint import load_checkpoints_to_model 8 | from torchchat.distributed.logging_utils import SingletonLogger 9 | from torchchat.distributed.parallel_config import ParallelDims 10 | from torchchat.distributed.parallelize_llama import parallelize_llama 11 | from torchchat.distributed.utils import init_distributed 12 | from torchchat.distributed.world_maker import launch_distributed 13 | -------------------------------------------------------------------------------- /torchchat/distributed/dtensor_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.distributed.device_mesh import DeviceMesh, _mesh_resources 3 | from torch.distributed._tensor import DTensor, Shard, Replicate, Placement 4 | from torch.distributed.tensor._utils import compute_local_shape_and_global_offset 5 | 6 | from collections import defaultdict 7 | from typing import Optional, Sequence 8 | 9 | from torchchat.distributed.logging_utils import SingletonLogger 10 | logger = SingletonLogger.get_logger() 11 | 12 | 13 | def convert_to_dtensor( 14 | full_tensor: torch.Tensor, 15 | dtensor_template: DTensor, 16 | ) -> DTensor: 17 | """ 18 | Converts a full tensor to a DTensor with the same placements as the given 19 | DTensor template.
20 | """ 21 | if full_tensor.shape != dtensor_template.shape: 22 | raise ValueError( 23 | f"Shape mismatch: weight tensor shape {full_tensor.shape} " 24 | f"doesn't match DTensor shape {dtensor_template.shape}" 25 | ) 26 | 27 | new_dtensor = shard( 28 | full_tensor, 29 | dtensor_template.placements, 30 | dtensor_template.device_mesh 31 | ) 32 | return new_dtensor 33 | 34 | 35 | def shard( 36 | full_tensor: torch.Tensor, 37 | placements: Sequence[Placement], 38 | device_mesh: Optional[DeviceMesh] = None, 39 | ) -> DTensor: 40 | """ 41 | Shards a full tensor based on indicated placements, and returns a 42 | DTensor containing the shard. 43 | Args: 44 | full_tensor (torch.Tensor): the full tensor to be sharded. 45 | placements (Sequence[:class:`Placement`]): the placements that 46 | describes how to place the local tensor on DeviceMesh. 47 | device_mesh (:class:`DeviceMesh`, optional): DeviceMesh to place the 48 | DTensor. Must have same dimension as the number of placements. 49 | If not specified, would be retrieve from current context. 50 | Returns: 51 | A :class:`DTensor` object with the shard as its local tensor. 52 | Examples: 53 | >>> # xdoctest: +SKIP("need world_size and rank") 54 | >>> device_mesh = dist.init_device_mesh("cuda", (world_size,)) 55 | >>> full_tensor = torch.arange(world_size, device=f"cuda:{rank}") 56 | >>> placements = [Shard(1)] 57 | >>> dtensor = shard(full_tensor, placements, device_mesh) 58 | """ 59 | device_mesh = device_mesh or _mesh_resources.get_current_mesh() 60 | 61 | shape, offset = compute_local_shape_and_global_offset( 62 | full_tensor.shape, device_mesh, placements 63 | ) 64 | slices = [ 65 | slice(cur_offset, cur_offset + cur_shape) 66 | for cur_shape, cur_offset in zip(shape, offset) 67 | ] 68 | local_tensor = full_tensor[slices] 69 | return DTensor.from_local(local_tensor, device_mesh, placements) 70 | -------------------------------------------------------------------------------- /torchchat/distributed/force_download.py: -------------------------------------------------------------------------------- 1 | from transformers import AutoModelForCausalLM, AutoTokenizer 2 | 3 | # tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B-Instruct") 4 | model = AutoModelForCausalLM.from_pretrained("meta-llama/Meta-Llama-3-70B-Instruct") 5 | print("Model weights and tokenizer downloaded") 6 | -------------------------------------------------------------------------------- /torchchat/distributed/inference_configs/llama3_8B.toml: -------------------------------------------------------------------------------- 1 | # torchchat Distributed Config.toml 2 | 3 | [job] 4 | dump_folder = "./outputs" 5 | description = "Llama 3 distributed inference" 6 | use_for_integration_test = true 7 | 8 | [profiling] 9 | enable_profiling = false 10 | save_traces_folder = "profile_trace" 11 | profile_freq = 10 12 | enable_memory_snapshot = false 13 | save_memory_snapshot_folder = "memory_snapshot" 14 | 15 | [metrics] 16 | enable_color_printing = true 17 | enable_tensorboard = true 18 | save_tb_folder = "tb" 19 | 20 | [model] 21 | name = "llama3" 22 | flavor = "8B" 23 | tokenizer_path = "./test/assets/test_tiktoken.model" 24 | dtype = "bfloat16" 25 | 26 | [parallel] 27 | pipeline_parallel_degree = 1 28 | tensor_parallel_degree = 2 29 | enable_async_tensor_parallel=false 30 | 31 | [inference] 32 | batch_size = 8 33 | seq_len = 2048 34 | reps=1 # for profiling inference runs, can run repeatedly 35 | fp8_linear = "" 36 | compile = false 37 | 38 | [pipelining] 39 | 
pipeline_parallel_split_points= "layers.4" # string list of placements 40 | pipeline_parallel_schedule="gpipe" # TODO - what is the best inference schedule for continuous batching 41 | pipeline_parallel_split_mode = "manual" 42 | pipeline_parallel_microbatches=1 # TODO - continuous batching 43 | -------------------------------------------------------------------------------- /torchchat/distributed/logging_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import logging 8 | import os 9 | from datetime import datetime 10 | from typing import Optional 11 | 12 | 13 | def millisecond_timestamp(include_year: bool = False) -> str: 14 | format_string = "%Y-%m-%d %H:%M:%S.%f" if include_year else "%m-%d %H:%M:%S.%f" 15 | return datetime.now().strftime(format_string)[:-3] 16 | 17 | 18 | class CompactFormatter(logging.Formatter): 19 | def __init__( 20 | self, 21 | fmt: Optional[str] = None, 22 | datefmt: Optional[str] = None, 23 | style: str = "%", 24 | validate: bool = True, 25 | *, 26 | defaults: Optional[dict] = None, 27 | show_lower_levels: bool = True, 28 | ): 29 | super().__init__(fmt, datefmt, style, validate, defaults=defaults) 30 | self.show_lower_levels = show_lower_levels 31 | self.original_fmt = fmt 32 | 33 | def format(self, record: logging.LogRecord) -> str: 34 | # Remove the .py extension from the filename 35 | record.filename = os.path.splitext(record.filename)[0] 36 | 37 | if self.show_lower_levels or record.levelno > logging.INFO: 38 | return super().format(record) 39 | else: 40 | # Create a copy of the record and modify it 41 | new_record = logging.makeLogRecord(record.__dict__) 42 | new_record.levelname = "" 43 | # Temporarily change the format string 44 | temp_fmt = self.original_fmt.replace(" - %(levelname)s", "") 45 | self._style._fmt = temp_fmt 46 | formatted_message = super().format(new_record) 47 | # Restore the original format string 48 | self._style._fmt = self.original_fmt 49 | return formatted_message 50 | 51 | 52 | class SingletonLogger: 53 | """Singleton (global) logger to avoid logging duplication""" 54 | 55 | _instance = None 56 | 57 | @classmethod 58 | def get_logger( 59 | cls, 60 | name: str = "global_logger", 61 | level: int = logging.INFO, 62 | include_year: bool = False, 63 | show_lower_levels: bool = False, 64 | ) -> logging.Logger: 65 | """ 66 | Get or create a singleton logger instance.
67 | 68 | :param name: Name of the logger 69 | :param level: Logging level 70 | :param include_year: Whether to include the year in timestamps 71 | :param show_lower_levels: Whether to show level names for INFO and DEBUG messages 72 | :return: Logger instance 73 | """ 74 | if cls._instance is None: 75 | cls._instance = cls._setup_logger( 76 | name, level, include_year, show_lower_levels 77 | ) 78 | return cls._instance 79 | 80 | @staticmethod 81 | def _setup_logger( 82 | name: str, 83 | level: int, 84 | include_year: bool = False, 85 | show_lower_levels: bool = False, 86 | ) -> logging.Logger: 87 | logger = logging.getLogger(name) 88 | 89 | if not logger.handlers: 90 | logger.setLevel(level) 91 | 92 | console_handler = logging.StreamHandler() 93 | console_handler.setLevel(level) 94 | 95 | formatter = CompactFormatter( 96 | "%(asctime)s - %(filename)s:%(lineno)d - %(levelname)s - %(message)s", 97 | show_lower_levels=show_lower_levels, 98 | ) 99 | formatter.formatTime = lambda record, datefmt=None: millisecond_timestamp( 100 | include_year 101 | ) 102 | console_handler.setFormatter(formatter) 103 | logger.addHandler(console_handler) 104 | 105 | # Suppress verbose torch.profiler logging 106 | os.environ["KINETO_LOG_LEVEL"] = "5" 107 | 108 | logger.propagate = False 109 | return logger 110 | -------------------------------------------------------------------------------- /torchchat/distributed/parallel_config.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | from dataclasses import dataclass 8 | 9 | from torch.distributed.device_mesh import init_device_mesh 10 | 11 | from torchchat.distributed.logging_utils import SingletonLogger 12 | logger = SingletonLogger.get_logger() 13 | 14 | @dataclass 15 | class ParallelDims: 16 | tp: int 17 | pp: int 18 | world_size: int 19 | 20 | def __post_init__(self): 21 | self._validate() 22 | 23 | def _validate(self): 24 | tp, pp = self.tp, self.pp 25 | assert tp >= 1, tp 26 | assert pp >= 1, pp 27 | assert ( 28 | tp * pp == self.world_size 29 | ), f"Invalid parallel dims: tp({tp}) * pp({pp}) != WORLD_SIZE({self.world_size})" 30 | 31 | def build_mesh(self, device_type): 32 | dims = [] 33 | names = [] 34 | for d, name in zip( 35 | [self.pp, self.tp], ["pp", "tp"], strict=True 36 | ): 37 | if d > 1: 38 | dims.append(d) 39 | names.append(name) 40 | logger.info(f"Building {len(dims)}-D device mesh with {names}, {dims}") 41 | names = tuple(names) 42 | return init_device_mesh(device_type, dims, mesh_dim_names=names) 43 | 44 | @property 45 | def tp_enabled(self): 46 | return self.tp > 1 47 | 48 | @property 49 | def pp_enabled(self): 50 | return self.pp > 1 51 | -------------------------------------------------------------------------------- /torchchat/distributed/run_dist_inference.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/bash 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # All rights reserved. 4 | 5 | # This source code is licensed under the BSD-style license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | 8 | set -ex 9 | 10 | # libUV is a scalable backend for TCPStore which is used in processGroup 11 | # rendezvous. This is the recommended backend for distributed training. 
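# For example (assumption: a PyTorch build where the TCPStore honors this
# switch), the same variable can be flipped off when debugging rendezvous:
#   USE_LIBUV=0 NGPU=2 ./run_dist_inference.sh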
12 | export USE_LIBUV=1 13 | 14 | # use envs as local overrides for convenience 15 | # e.g. 16 | # LOG_RANK=0,1 NGPU=4 ./run_dist_inference.sh 17 | 18 | NGPU=${NGPU:-"2"} 19 | 20 | # TODO: We need to decide how to log for inference. 21 | # By default, log just rank 0 output. 22 | LOG_RANK=${LOG_RANK:-0} 23 | 24 | overrides="" 25 | if [ $# -ne 0 ]; then 26 | overrides="$*" 27 | fi 28 | 29 | torchrun --nproc_per_node=${NGPU} --rdzv_backend c10d --rdzv_endpoint="localhost:0" \ 30 | --local-ranks-filter ${LOG_RANK} --role rank --tee 3 \ 31 | ../torchchat.py chat llama3 --distributed $overrides --dcp-dir ~/.torchchat/model-cache/meta-llama/Meta-Llama-3-70B-Instruct/original 32 | -------------------------------------------------------------------------------- /torchchat/distributed/version.txt: -------------------------------------------------------------------------------- 1 | 0.0.1 2 | -------------------------------------------------------------------------------- /torchchat/distributed/world_maker.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import os 8 | from typing import Optional, Tuple 9 | 10 | from torch.distributed.device_mesh import DeviceMesh 11 | 12 | from torchchat.distributed.parallel_config import ParallelDims 13 | from torchchat.distributed.utils import init_distributed 14 | from torchchat.distributed.logging_utils import SingletonLogger 15 | 16 | from .config_manager import InferenceConfig 17 | 18 | 19 | logger = SingletonLogger.get_logger() 20 | 21 | 22 | def launch_distributed( 23 | toml_config: str, 24 | ) -> Tuple[Optional[DeviceMesh], Optional[ParallelDims]]: 25 | """ 26 | Initialize distributed-related setup if the user specified 27 | distributed inference. If not, this is a no-op. 28 | 29 | Args: 30 | toml_config: str: 31 | toml file for the inference config. 32 | Returns: 33 | Tuple[Optional[DeviceMesh], Optional[ParallelDims]]: 34 | - The first element is an optional DeviceMesh object, 35 | which describes the mesh topology of devices for the DTensor. 36 | - The second element is an optional ParallelDims object, 37 | which represents the parallel dimensions configuration. 38 | """ 39 | #init_logger() TODO - do we want formatted logging? 40 | world_size = int(os.environ["WORLD_SIZE"]) 41 | config = InferenceConfig() 42 | config.parse_args(toml_config) 43 | 44 | 45 | logger.info(f"toml parsing completed.
Launching with {world_size} GPUs") 46 | # review parallel config 47 | tp = config.parallel.tensor_parallel_degree 48 | pp = config.parallel.pipeline_parallel_degree 49 | 50 | parallel_dims = ParallelDims( 51 | tp=tp, 52 | pp=pp, 53 | world_size=world_size, 54 | ) 55 | init_distributed() 56 | world_mesh = parallel_dims.build_mesh(device_type="cuda") 57 | logger.info(f"world_mesh created: {world_mesh}") 58 | return world_mesh, parallel_dims 59 | -------------------------------------------------------------------------------- /torchchat/edge/README.md: -------------------------------------------------------------------------------- 1 | # Chat with LLMs Everywhere: Edge 2 | 3 | This directory is a WIP path that will host files related to inference on edge devices 4 | -------------------------------------------------------------------------------- /torchchat/edge/android/README.md: -------------------------------------------------------------------------------- 1 | # torchchat on Android 2 | 3 | This is the app for deploying torchchat mobile inference on Android. Please see [this page](https://github.com/pytorch/torchchat/blob/main/README.md#deploy-and-run-on-android) for the tutorial. 4 | -------------------------------------------------------------------------------- /torchchat/edge/android/torchchat/.gitignore: -------------------------------------------------------------------------------- 1 | *.iml 2 | .gradle 3 | /local.properties 4 | /.idea/caches 5 | /.idea/libraries 6 | /.idea/modules.xml 7 | /.idea/workspace.xml 8 | /.idea/navEditor.xml 9 | /.idea/assetWizardSettings.xml 10 | .DS_Store 11 | /build 12 | /captures 13 | .externalNativeBuild 14 | .cxx 15 | local.properties 16 | *.aar 17 | -------------------------------------------------------------------------------- /torchchat/edge/android/torchchat/.idea/.gitignore: -------------------------------------------------------------------------------- 1 | # Default ignored files 2 | /shelf/ 3 | /workspace.xml 4 | -------------------------------------------------------------------------------- /torchchat/edge/android/torchchat/app/.gitignore: -------------------------------------------------------------------------------- 1 | /build 2 | -------------------------------------------------------------------------------- /torchchat/edge/android/torchchat/app/build.gradle.kts: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) Meta Platforms, Inc. and affiliates. 3 | * All rights reserved. 4 | * 5 | * This source code is licensed under the BSD-style license found in the 6 | * LICENSE file in the root directory of this source tree. 
7 | */ 8 | 9 | plugins { 10 | id("com.android.application") 11 | id("org.jetbrains.kotlin.android") 12 | } 13 | 14 | android { 15 | namespace = "org.pytorch.torchchat" 16 | compileSdk = 34 17 | 18 | defaultConfig { 19 | applicationId = "org.pytorch.torchchat" 20 | minSdk = 28 21 | targetSdk = 33 22 | versionCode = 1 23 | versionName = "1.0" 24 | 25 | testInstrumentationRunner = "androidx.test.runner.AndroidJUnitRunner" 26 | vectorDrawables { useSupportLibrary = true } 27 | externalNativeBuild { cmake { cppFlags += "" } } 28 | } 29 | 30 | buildTypes { 31 | release { 32 | isMinifyEnabled = false 33 | proguardFiles(getDefaultProguardFile("proguard-android-optimize.txt"), "proguard-rules.pro") 34 | } 35 | } 36 | compileOptions { 37 | sourceCompatibility = JavaVersion.VERSION_1_8 38 | targetCompatibility = JavaVersion.VERSION_1_8 39 | } 40 | kotlinOptions { jvmTarget = "1.8" } 41 | buildFeatures { compose = true } 42 | composeOptions { kotlinCompilerExtensionVersion = "1.4.3" } 43 | packaging { resources { excludes += "/META-INF/{AL2.0,LGPL2.1}" } } 44 | } 45 | 46 | dependencies { 47 | implementation("androidx.core:core-ktx:1.9.0") 48 | implementation("androidx.lifecycle:lifecycle-runtime-ktx:2.6.1") 49 | implementation("androidx.activity:activity-compose:1.7.0") 50 | implementation(platform("androidx.compose:compose-bom:2023.03.00")) 51 | implementation("androidx.compose.ui:ui") 52 | implementation("androidx.compose.ui:ui-graphics") 53 | implementation("androidx.compose.ui:ui-tooling-preview") 54 | implementation("androidx.compose.material3:material3") 55 | implementation("androidx.appcompat:appcompat:1.6.1") 56 | implementation("androidx.camera:camera-core:1.3.0-rc02") 57 | implementation("androidx.constraintlayout:constraintlayout:2.2.0-alpha12") 58 | implementation("com.facebook.fbjni:fbjni:0.5.1") 59 | implementation("com.google.code.gson:gson:2.8.6") 60 | implementation(files("libs/executorch.aar")) 61 | implementation("com.google.android.material:material:1.12.0") 62 | implementation("androidx.activity:activity:1.9.0") 63 | testImplementation("junit:junit:4.13.2") 64 | androidTestImplementation("androidx.test.ext:junit:1.1.5") 65 | androidTestImplementation("androidx.test.espresso:espresso-core:3.5.1") 66 | androidTestImplementation(platform("androidx.compose:compose-bom:2023.03.00")) 67 | androidTestImplementation("androidx.compose.ui:ui-test-junit4") 68 | debugImplementation("androidx.compose.ui:ui-tooling") 69 | debugImplementation("androidx.compose.ui:ui-test-manifest") 70 | } 71 | 72 | tasks.register("setup") { 73 | doFirst { 74 | exec { 75 | commandLine("sh", "examples/demo-apps/android/LlamaDemo/setup.sh") 76 | workingDir("../../../../../") 77 | } 78 | } 79 | } 80 | 81 | tasks.register("setupQnn") { 82 | doFirst { 83 | exec { 84 | commandLine("sh", "examples/demo-apps/android/LlamaDemo/setup-with-qnn.sh") 85 | workingDir("../../../../../") 86 | } 87 | } 88 | } 89 | 90 | tasks.register("download_prebuilt_lib") { 91 | doFirst { 92 | exec { 93 | commandLine("sh", "examples/demo-apps/android/LlamaDemo/download_prebuilt_lib.sh") 94 | workingDir("../../../../../") 95 | } 96 | } 97 | } 98 | -------------------------------------------------------------------------------- /torchchat/edge/android/torchchat/app/proguard-rules.pro: -------------------------------------------------------------------------------- 1 | # Add project specific ProGuard rules here. 2 | # You can control the set of applied configuration files using the 3 | # proguardFiles setting in build.gradle. 
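# Editor's sketch (hypothetical rule, not currently needed since minification
# is disabled above in build.gradle.kts): if reflection- or JNI-reached classes
# from the ExecuTorch runtime were ever minified, a keep rule along these lines
# would preserve them:
#-keep class org.pytorch.executorch.** { *; }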
4 | # 5 | # For more details, see 6 | # http://developer.android.com/guide/developing/tools/proguard.html 7 | 8 | # If your project uses WebView with JS, uncomment the following 9 | # and specify the fully qualified class name to the JavaScript interface 10 | # class: 11 | #-keepclassmembers class fqcn.of.javascript.interface.for.webview { 12 | # public *; 13 | #} 14 | 15 | # Uncomment this to preserve the line number information for 16 | # debugging stack traces. 17 | #-keepattributes SourceFile,LineNumberTable 18 | 19 | # If you keep the line number information, uncomment this to 20 | # hide the original source file name. 21 | #-renamesourcefileattribute SourceFile -------------------------------------------------------------------------------- /torchchat/edge/android/torchchat/app/src/androidTest/java/org/pytorch/torchchat/PerfTest.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) Meta Platforms, Inc. and affiliates. 3 | * All rights reserved. 4 | * 5 | * This source code is licensed under the BSD-style license found in the 6 | * LICENSE file in the root directory of this source tree. 7 | */ 8 | 9 | package org.pytorch.torchchat; 10 | 11 | import static org.junit.Assert.assertEquals; 12 | import static org.junit.Assert.assertFalse; 13 | 14 | import android.os.Bundle; 15 | import androidx.test.ext.junit.runners.AndroidJUnit4; 16 | import androidx.test.platform.app.InstrumentationRegistry; 17 | import java.io.File; 18 | import java.util.ArrayList; 19 | import java.util.Arrays; 20 | import java.util.List; 21 | import org.junit.Test; 22 | import org.junit.runner.RunWith; 23 | import org.pytorch.executorch.LlamaCallback; 24 | import org.pytorch.executorch.LlamaModule; 25 | 26 | @RunWith(AndroidJUnit4.class) 27 | public class PerfTest implements LlamaCallback { 28 | 29 | private static final String RESOURCE_PATH = "/data/local/tmp/llama/"; 30 | private static final String TOKENIZER_BIN = "tokenizer.bin"; 31 | 32 | private final List<String> results = new ArrayList<>(); 33 | private final List<Float> tokensPerSecond = new ArrayList<>(); 34 | 35 | @Test 36 | public void testTokensPerSecond() { 37 | String tokenizerPath = RESOURCE_PATH + TOKENIZER_BIN; 38 | // Find out the model name 39 | File directory = new File(RESOURCE_PATH); 40 | Arrays.stream(directory.listFiles()) 41 | .filter(file -> file.getName().endsWith(".pte")) 42 | .forEach( 43 | model -> { 44 | LlamaModule mModule = new LlamaModule(model.getPath(), tokenizerPath, 0.8f); 45 | // Print the model name because there might be more than one of them 46 | report("ModelName", model.getName()); 47 | 48 | int loadResult = mModule.load(); 49 | // Check that the model can be loaded successfully 50 | assertEquals(0, loadResult); 51 | 52 | // Run a testing prompt 53 | mModule.generate("How do you do!
I'm testing llama2 on mobile device", PerfTest.this); 54 | assertFalse(tokensPerSecond.isEmpty()); 55 | 56 | final Float tps = tokensPerSecond.get(tokensPerSecond.size() - 1); 57 | report("TPS", tps); 58 | }); 59 | } 60 | 61 | @Override 62 | public void onResult(String result) { 63 | results.add(result); 64 | } 65 | 66 | @Override 67 | public void onStats(float tps) { 68 | tokensPerSecond.add(tps); 69 | } 70 | 71 | private void report(final String metric, final Float value) { 72 | Bundle bundle = new Bundle(); 73 | bundle.putFloat(metric, value); 74 | InstrumentationRegistry.getInstrumentation().sendStatus(0, bundle); 75 | } 76 | 77 | private void report(final String key, final String value) { 78 | Bundle bundle = new Bundle(); 79 | bundle.putString(key, value); 80 | InstrumentationRegistry.getInstrumentation().sendStatus(0, bundle); 81 | } 82 | } 83 | -------------------------------------------------------------------------------- /torchchat/edge/android/torchchat/app/src/main/AndroidManifest.xml: -------------------------------------------------------------------------------- /torchchat/edge/android/torchchat/app/src/main/java/org/pytorch/torchchat/AppLog.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) Meta Platforms, Inc. and affiliates. 3 | * All rights reserved. 4 | * 5 | * This source code is licensed under the BSD-style license found in the 6 | * LICENSE file in the root directory of this source tree. 7 | */ 8 | 9 | package org.pytorch.torchchat; 10 | 11 | import java.text.SimpleDateFormat; 12 | import java.util.Date; 13 | import java.util.Locale; 14 | 15 | public class AppLog { 16 | private final Long timestamp; 17 | private final String message; 18 | 19 | public AppLog(String message) { 20 | this.timestamp = getCurrentTimeStamp(); 21 | this.message = message; 22 | } 23 | 24 | public Long getTimestamp() { 25 | return timestamp; 26 | } 27 | 28 | public String getMessage() { 29 | return message; 30 | } 31 | 32 | public String getFormattedLog() { 33 | return "[" + getFormattedTimeStamp() + "] " + message; 34 | } 35 | 36 | private Long getCurrentTimeStamp() { 37 | return System.currentTimeMillis(); 38 | } 39 | 40 | private String getFormattedTimeStamp() { 41 | return formatDate(timestamp); 42 | } 43 | 44 | private String formatDate(long milliseconds) { 45 | SimpleDateFormat formatter = new SimpleDateFormat("yyyy-MM-dd HH:mm:ss", Locale.getDefault()); 46 | Date date = new Date(milliseconds); 47 | return formatter.format(date); 48 | } 49 | } 50 | -------------------------------------------------------------------------------- /torchchat/edge/android/torchchat/app/src/main/java/org/pytorch/torchchat/DemoSharedPreferences.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) Meta Platforms, Inc. and affiliates. 3 | * All rights reserved. 4 | * 5 | * This source code is licensed under the BSD-style license found in the 6 | * LICENSE file in the root directory of this source tree.
7 | */ 8 | 9 | package org.pytorch.torchchat; 10 | 11 | import android.content.Context; 12 | import android.content.SharedPreferences; 13 | import com.google.gson.Gson; 14 | import com.google.gson.reflect.TypeToken; 15 | import java.lang.reflect.Type; 16 | import java.util.ArrayList; 17 | 18 | public class DemoSharedPreferences { 19 | Context context; 20 | SharedPreferences sharedPreferences; 21 | 22 | public DemoSharedPreferences(Context context) { 23 | this.context = context; 24 | this.sharedPreferences = getSharedPrefs(); 25 | } 26 | 27 | private SharedPreferences getSharedPrefs() { 28 | return context.getSharedPreferences( 29 | context.getString(R.string.demo_pref_file_key), Context.MODE_PRIVATE); 30 | } 31 | 32 | public String getSavedMessages() { 33 | return sharedPreferences.getString(context.getString(R.string.saved_messages_json_key), ""); 34 | } 35 | 36 | public void addMessages(MessageAdapter messageAdapter) { 37 | SharedPreferences.Editor editor = sharedPreferences.edit(); 38 | Gson gson = new Gson(); 39 | String msgJSON = gson.toJson(messageAdapter.getSavedMessages()); 40 | editor.putString(context.getString(R.string.saved_messages_json_key), msgJSON); 41 | editor.apply(); 42 | } 43 | 44 | public void removeExistingMessages() { 45 | SharedPreferences.Editor editor = sharedPreferences.edit(); 46 | editor.remove(context.getString(R.string.saved_messages_json_key)); 47 | editor.apply(); 48 | } 49 | 50 | public void addSettings(SettingsFields settingsFields) { 51 | SharedPreferences.Editor editor = sharedPreferences.edit(); 52 | Gson gson = new Gson(); 53 | String settingsJSON = gson.toJson(settingsFields); 54 | editor.putString(context.getString(R.string.settings_json_key), settingsJSON); 55 | editor.apply(); 56 | } 57 | 58 | public String getSettings() { 59 | return sharedPreferences.getString(context.getString(R.string.settings_json_key), ""); 60 | } 61 | 62 | public void saveLogs() { 63 | SharedPreferences.Editor editor = sharedPreferences.edit(); 64 | Gson gson = new Gson(); 65 | String msgJSON = gson.toJson(ETLogging.getInstance().getLogs()); 66 | editor.putString(context.getString(R.string.logs_json_key), msgJSON); 67 | editor.apply(); 68 | } 69 | 70 | public void removeExistingLogs() { 71 | SharedPreferences.Editor editor = sharedPreferences.edit(); 72 | editor.remove(context.getString(R.string.logs_json_key)); 73 | editor.apply(); 74 | } 75 | 76 | public ArrayList<AppLog> getSavedLogs() { 77 | String logsJSONString = 78 | sharedPreferences.getString(context.getString(R.string.logs_json_key), null); 79 | if (logsJSONString == null || logsJSONString.isEmpty()) { 80 | return new ArrayList<>(); 81 | } 82 | Gson gson = new Gson(); 83 | Type type = new TypeToken<ArrayList<AppLog>>() {}.getType(); 84 | ArrayList<AppLog> appLogs = gson.fromJson(logsJSONString, type); 85 | if (appLogs == null) { 86 | return new ArrayList<>(); 87 | } 88 | return appLogs; 89 | } 90 | } 91 | -------------------------------------------------------------------------------- /torchchat/edge/android/torchchat/app/src/main/java/org/pytorch/torchchat/ETImage.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) Meta Platforms, Inc. and affiliates. 3 | * All rights reserved. 4 | * 5 | * This source code is licensed under the BSD-style license found in the 6 | * LICENSE file in the root directory of this source tree.
7 | */ 8 | 9 | package org.pytorch.torchchat; 10 | 11 | import android.content.ContentResolver; 12 | import android.graphics.Bitmap; 13 | import android.graphics.BitmapFactory; 14 | import android.graphics.Color; 15 | import android.net.Uri; 16 | import androidx.annotation.Nullable; 17 | import java.io.FileNotFoundException; 18 | import java.io.InputStream; 19 | 20 | public class ETImage { 21 | private int width; 22 | private int height; 23 | private final byte[] bytes; 24 | private final Uri uri; 25 | private final ContentResolver contentResolver; 26 | 27 | ETImage(ContentResolver contentResolver, Uri uri) { 28 | this.contentResolver = contentResolver; 29 | this.uri = uri; 30 | bytes = getBytesFromImageURI(uri); 31 | } 32 | 33 | public int getWidth() { 34 | return width; 35 | } 36 | 37 | public int getHeight() { 38 | return height; 39 | } 40 | 41 | public Uri getUri() { 42 | return uri; 43 | } 44 | 45 | public byte[] getBytes() { 46 | return bytes; 47 | } 48 | 49 | public int[] getInts() { 50 | // We need to convert the byte array to an int array because 51 | // the runner expects an int array as input. 52 | int[] intArray = new int[bytes.length]; 53 | for (int i = 0; i < bytes.length; i++) { 54 | intArray[i] = (bytes[i] & 0xFF); 55 | } 56 | return intArray; 57 | } 58 | 59 | private byte[] getBytesFromImageURI(Uri uri) { 60 | try { 61 | int RESIZED_IMAGE_WIDTH = 336; 62 | Bitmap bitmap = resizeImage(uri, RESIZED_IMAGE_WIDTH); 63 | 64 | if (bitmap == null) { 65 | ETLogging.getInstance().log("Unable to get bytes from Image URI. Bitmap is null"); 66 | return new byte[0]; 67 | } 68 | 69 | width = bitmap.getWidth(); 70 | height = bitmap.getHeight(); 71 | 72 | byte[] rgbValues = new byte[width * height * 3]; 73 | 74 | for (int y = 0; y < height; y++) { 75 | for (int x = 0; x < width; x++) { 76 | // Get the color of the current pixel 77 | int color = bitmap.getPixel(x, y); 78 | 79 | // Extract the RGB values from the color 80 | int red = Color.red(color); 81 | int green = Color.green(color); 82 | int blue = Color.blue(color); 83 | 84 | // Store the RGB values in the byte array (planar CHW layout) 85 | rgbValues[y * width + x] = (byte) red; 86 | rgbValues[(y * width + x) + height * width] = (byte) green; 87 | rgbValues[(y * width + x) + 2 * height * width] = (byte) blue; 88 | } 89 | } 90 | return rgbValues; 91 | } catch (FileNotFoundException e) { 92 | throw new RuntimeException(e); 93 | } 94 | } 95 | 96 | @Nullable 97 | private Bitmap resizeImage(Uri uri, int maxLength) throws FileNotFoundException { 98 | InputStream inputStream = contentResolver.openInputStream(uri); 99 | if (inputStream == null) { 100 | ETLogging.getInstance().log("Unable to resize image, input stream is null"); 101 | return null; 102 | } 103 | Bitmap bitmap = BitmapFactory.decodeStream(inputStream); 104 | if (bitmap == null) { 105 | ETLogging.getInstance().log("Unable to resize image, bitmap during decode stream is null"); 106 | return null; 107 | } 108 | 109 | float aspectRatio; 110 | int finalWidth, finalHeight; 111 | 112 | if (bitmap.getWidth() > bitmap.getHeight()) { 113 | // width > height --> width = maxLength, height scales with aspect ratio 114 | aspectRatio = bitmap.getWidth() / (float) bitmap.getHeight(); 115 | finalWidth = maxLength; 116 | finalHeight = Math.round(maxLength / aspectRatio); 117 | } else { 118 | // height >= width --> height = maxLength, width scales with aspect ratio 119 | aspectRatio = bitmap.getHeight() / (float) bitmap.getWidth(); 120 | finalHeight = maxLength; 121 | finalWidth = Math.round(maxLength /
aspectRatio); 122 | } 123 | 124 | return Bitmap.createScaledBitmap(bitmap, finalWidth, finalHeight, false); 125 | } 126 | } 127 | -------------------------------------------------------------------------------- /torchchat/edge/android/torchchat/app/src/main/java/org/pytorch/torchchat/ETLogging.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) Meta Platforms, Inc. and affiliates. 3 | * All rights reserved. 4 | * 5 | * This source code is licensed under the BSD-style license found in the 6 | * LICENSE file in the root directory of this source tree. 7 | */ 8 | 9 | package org.pytorch.torchchat; 10 | 11 | import android.app.Application; 12 | import android.util.Log; 13 | import java.util.ArrayList; 14 | 15 | public class ETLogging extends Application { 16 | private static ETLogging singleton; 17 | 18 | private ArrayList<AppLog> logs; 19 | private DemoSharedPreferences mDemoSharedPreferences; 20 | 21 | @Override 22 | public void onCreate() { 23 | super.onCreate(); 24 | singleton = this; 25 | mDemoSharedPreferences = new DemoSharedPreferences(this.getApplicationContext()); 26 | logs = mDemoSharedPreferences.getSavedLogs(); 27 | if (logs == null) { // No existing sharedPreferences stored 28 | logs = new ArrayList<>(); 29 | } 30 | } 31 | 32 | public static ETLogging getInstance() { 33 | return singleton; 34 | } 35 | 36 | public void log(String message) { 37 | AppLog appLog = new AppLog(message); 38 | logs.add(appLog); 39 | Log.d("ETLogging", appLog.getMessage()); 40 | } 41 | 42 | public ArrayList<AppLog> getLogs() { 43 | return logs; 44 | } 45 | 46 | public void clearLogs() { 47 | logs.clear(); 48 | mDemoSharedPreferences.removeExistingLogs(); 49 | } 50 | 51 | public void saveLogs() { 52 | mDemoSharedPreferences.saveLogs(); 53 | } 54 | } 55 | -------------------------------------------------------------------------------- /torchchat/edge/android/torchchat/app/src/main/java/org/pytorch/torchchat/LogsActivity.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) Meta Platforms, Inc. and affiliates. 3 | * All rights reserved. 4 | * 5 | * This source code is licensed under the BSD-style license found in the 6 | * LICENSE file in the root directory of this source tree.
--------------------------------------------------------------------------------
/torchchat/edge/android/torchchat/app/src/main/java/org/pytorch/torchchat/LogsActivity.java:
--------------------------------------------------------------------------------
/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

package org.pytorch.torchchat;

import android.app.AlertDialog;
import android.os.Build;
import android.os.Bundle;
import android.widget.ImageButton;
import android.widget.ListView;
import androidx.appcompat.app.AppCompatActivity;
import androidx.core.content.ContextCompat;
import androidx.core.graphics.Insets;
import androidx.core.view.ViewCompat;
import androidx.core.view.WindowInsetsCompat;

public class LogsActivity extends AppCompatActivity {

  private LogsAdapter mLogsAdapter;

  @Override
  protected void onCreate(Bundle savedInstanceState) {
    super.onCreate(savedInstanceState);
    setContentView(R.layout.activity_logs);
    if (Build.VERSION.SDK_INT >= Build.VERSION_CODES.LOLLIPOP) {
      getWindow().setStatusBarColor(ContextCompat.getColor(this, R.color.status_bar));
      getWindow().setNavigationBarColor(ContextCompat.getColor(this, R.color.nav_bar));
    }
    ViewCompat.setOnApplyWindowInsetsListener(
        requireViewById(R.id.main),
        (v, insets) -> {
          Insets systemBars = insets.getInsets(WindowInsetsCompat.Type.systemBars());
          v.setPadding(systemBars.left, systemBars.top, systemBars.right, systemBars.bottom);
          return insets;
        });

    setupLogs();
    setupClearLogsButton();
  }

  @Override
  public void onResume() {
    super.onResume();
    // Re-sync the adapter with the shared log store in case new entries were
    // appended while this activity was in the background.
    mLogsAdapter.clear();
    mLogsAdapter.addAll(ETLogging.getInstance().getLogs());
    mLogsAdapter.notifyDataSetChanged();
  }

  private void setupLogs() {
    ListView logsListView = requireViewById(R.id.logsListView);
    mLogsAdapter = new LogsAdapter(this, R.layout.logs_message);

    logsListView.setAdapter(mLogsAdapter);
    mLogsAdapter.addAll(ETLogging.getInstance().getLogs());
    mLogsAdapter.notifyDataSetChanged();
  }

  private void setupClearLogsButton() {
    ImageButton clearLogsButton = requireViewById(R.id.clearLogsButton);
    clearLogsButton.setOnClickListener(
        view ->
            new AlertDialog.Builder(this)
                .setTitle("Delete Logs History")
                .setMessage("Do you really want to delete the logs history?")
                .setIcon(android.R.drawable.ic_dialog_alert)
                .setPositiveButton(
                    android.R.string.yes,
                    (dialog, whichButton) -> {
                      // Clear both the adapter and the backing sharedPreference store
                      ETLogging.getInstance().clearLogs();
                      mLogsAdapter.clear();
                      mLogsAdapter.notifyDataSetChanged();
                    })
                .setNegativeButton(android.R.string.no, null)
                .show());
  }

  @Override
  protected void onDestroy() {
    super.onDestroy();
    ETLogging.getInstance().saveLogs();
  }
}
--------------------------------------------------------------------------------
/torchchat/edge/android/torchchat/app/src/main/java/org/pytorch/torchchat/LogsAdapter.java:
--------------------------------------------------------------------------------
/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

package org.pytorch.torchchat;

import android.view.LayoutInflater;
import android.view.View;
import android.view.ViewGroup;
import android.widget.ArrayAdapter;
import android.widget.TextView;
import androidx.annotation.NonNull;
import java.util.Objects;

public class LogsAdapter extends ArrayAdapter<AppLog> {
  public LogsAdapter(android.content.Context context, int resource) {
    super(context, resource);
  }

  static class ViewHolder {
    private TextView logTextView;
  }

  @NonNull
  @Override
  public View getView(int position, View convertView, @NonNull ViewGroup parent) {
    ViewHolder mViewHolder;

    String logMessage = Objects.requireNonNull(getItem(position)).getFormattedLog();

    if (convertView == null || convertView.getTag() == null) {
      mViewHolder = new ViewHolder();
      convertView = LayoutInflater.from(getContext()).inflate(R.layout.logs_message, parent, false);
      mViewHolder.logTextView = convertView.requireViewById(R.id.logsTextView);
      // Attach the holder to the freshly inflated view so recycled rows can
      // skip the lookup; without setTag(), getTag() is always null and a new
      // holder is rebuilt for every row, defeating the view-holder pattern.
      convertView.setTag(mViewHolder);
    } else {
      mViewHolder = (ViewHolder) convertView.getTag();
    }
    mViewHolder.logTextView.setText(logMessage);
    return convertView;
  }
}
--------------------------------------------------------------------------------
/torchchat/edge/android/torchchat/app/src/main/java/org/pytorch/torchchat/Message.java:
--------------------------------------------------------------------------------
/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

package org.pytorch.torchchat;

import java.text.SimpleDateFormat;
import java.util.Date;
import java.util.Locale;

public class Message {
  private String text;
  private final boolean isSent;
  private float tokensPerSecond;
  private long totalGenerationTime;
  private final long timestamp;
  private final MessageType messageType;
  private String imagePath;
  private final int promptID;

  private static final String TIMESTAMP_FORMAT = "hh:mm a"; // example: 2:23 PM

  public Message(String text, boolean isSent, MessageType messageType, int promptID) {
    this.isSent = isSent;
    this.messageType = messageType;
    this.promptID = promptID;

    if (messageType == MessageType.IMAGE) {
      this.imagePath = text;
    } else {
      this.text = text;
    }

    if (messageType != MessageType.SYSTEM) {
      this.timestamp = System.currentTimeMillis();
    } else {
      // System messages carry no timestamp in the UI
      this.timestamp = 0L;
    }
  }

  public int getPromptID() {
    return promptID;
  }

  public MessageType getMessageType() {
    return messageType;
  }

  public String getImagePath() {
    return imagePath;
  }

  public String getText() {
    return text;
  }

  public void appendText(String text) {
    this.text += text;
  }

  public boolean getIsSent() {
    return isSent;
  }

  public void setTokensPerSecond(float tokensPerSecond) {
    this.tokensPerSecond = tokensPerSecond;
  }

  public void setTotalGenerationTime(long totalGenerationTime) {
    this.totalGenerationTime = totalGenerationTime;
  }

  public float getTokensPerSecond() {
    return tokensPerSecond;
  }

  public long getTotalGenerationTime() {
    return totalGenerationTime;
  }

  public long getTimestamp() {
    return timestamp;
  }

  public String getFormattedTimestamp() {
    SimpleDateFormat formatter = new SimpleDateFormat(TIMESTAMP_FORMAT, Locale.getDefault());
    Date date = new Date(timestamp);
    return formatter.format(date);
  }
}
--------------------------------------------------------------------------------
/torchchat/edge/android/torchchat/app/src/main/java/org/pytorch/torchchat/MessageType.java:
--------------------------------------------------------------------------------
/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

package org.pytorch.torchchat;

public enum MessageType {
  TEXT,
  IMAGE,
  SYSTEM
}
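Message.appendText() is what lets the chat UI grow an assistant reply one token at a time while the runner streams results in. A condensed sketch of that flow from inside an Activity; mMessageAdapter and the method names are assumed for illustration, not the app's exact fields:

// Condensed, illustrative sketch: stream decoded tokens into one chat bubble.
private Message mResultMessage;

void beginReply(int promptID) {
  mResultMessage = new Message("", false, MessageType.TEXT, promptID);
  mMessageAdapter.add(mResultMessage); // mMessageAdapter: hypothetical field
}

// Invoked once per decoded token (see ModelRunnerCallback below):
public void onTokenGenerated(String token) {
  mResultMessage.appendText(token); // accumulate text on the same Message
  runOnUiThread(() -> mMessageAdapter.notifyDataSetChanged()); // repaint on UI thread
}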
--------------------------------------------------------------------------------
/torchchat/edge/android/torchchat/app/src/main/java/org/pytorch/torchchat/ModelRunner.java:
--------------------------------------------------------------------------------
/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

package org.pytorch.torchchat;

import android.os.Handler;
import android.os.HandlerThread;
import android.os.Looper;
import android.os.Message;
import androidx.annotation.NonNull;
import org.pytorch.executorch.LlamaCallback;
import org.pytorch.executorch.LlamaModule;

/** A helper class that encapsulates all model-running logic. */
public class ModelRunner implements LlamaCallback {
  LlamaModule mModule = null;

  String mModelFilePath = "";
  String mTokenizerFilePath = "";

  ModelRunnerCallback mCallback = null;

  HandlerThread mHandlerThread = null;
  Handler mHandler = null;

  /**
   * Separates UI logic from model-runner logic. Load and generate() requests are serviced
   * automatically on a background worker thread.
   *
   * @param modelFilePath path to the exported model file
   * @param tokenizerFilePath path to the tokenizer file
   * @param temperature sampling temperature passed to the native module
   * @param callback receiver for load and generation events
   */
  ModelRunner(
      String modelFilePath,
      String tokenizerFilePath,
      float temperature,
      ModelRunnerCallback callback) {
    mModelFilePath = modelFilePath;
    mTokenizerFilePath = tokenizerFilePath;
    mCallback = callback;

    // Use the caller-supplied temperature rather than a hard-coded constant,
    // so the parameter actually takes effect.
    mModule = new LlamaModule(mModelFilePath, mTokenizerFilePath, temperature);
    mHandlerThread = new HandlerThread("ModelRunner");
    mHandlerThread.start();
    mHandler = new ModelRunnerHandler(mHandlerThread.getLooper(), this);

    mHandler.sendEmptyMessage(ModelRunnerHandler.MESSAGE_LOAD_MODEL);
  }

  int generate(String prompt) {
    Message msg = Message.obtain(mHandler, ModelRunnerHandler.MESSAGE_GENERATE, prompt);
    msg.sendToTarget();
    return 0;
  }

  void stop() {
    mModule.stop();
  }

  @Override
  public void onResult(String result) {
    mCallback.onTokenGenerated(result);
  }

  @Override
  public void onStats(float tps) {
    mCallback.onStats("tokens/second: " + tps);
  }
}

class ModelRunnerHandler extends Handler {
  public static final int MESSAGE_LOAD_MODEL = 1;
  public static final int MESSAGE_GENERATE = 2;

  private final ModelRunner mModelRunner;

  public ModelRunnerHandler(Looper looper, ModelRunner modelRunner) {
    super(looper);
    mModelRunner = modelRunner;
  }

  @Override
  public void handleMessage(@NonNull android.os.Message msg) {
    if (msg.what == MESSAGE_LOAD_MODEL) {
      int status = mModelRunner.mModule.load();
      mModelRunner.mCallback.onModelLoaded(status);
    } else if (msg.what == MESSAGE_GENERATE) {
      // Blocks this worker thread until generation finishes; tokens are
      // streamed back through the LlamaCallback methods above.
      mModelRunner.mModule.generate((String) msg.obj, mModelRunner);
      mModelRunner.mCallback.onGenerationStopped();
    }
  }
}
--------------------------------------------------------------------------------
/torchchat/edge/android/torchchat/app/src/main/java/org/pytorch/torchchat/ModelRunnerCallback.java:
--------------------------------------------------------------------------------
/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

package org.pytorch.torchchat;

/**
 * A helper interface within the app for MainActivity and Benchmarking to handle callbacks from
 * ModelRunner.
 */
public interface ModelRunnerCallback {

  void onModelLoaded(int status);

  void onTokenGenerated(String token);

  void onStats(String stats);

  void onGenerationStopped();
}
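Everything the worker thread reports arrives through this interface, so a caller only needs to implement these four methods and marshal results back to the UI thread itself. A condensed sketch under assumed names — the class name, file paths, and status convention here are assumptions, and the app's real MainActivity and LlmBenchmarkRunner are more involved:

// Illustrative only: names and paths are placeholders, not the app's code.
public class GenerationSketch implements ModelRunnerCallback {
  private ModelRunner mRunner;

  void start() {
    mRunner =
        new ModelRunner(
            "/data/local/tmp/llama3.pte", // assumed model path
            "/data/local/tmp/tokenizer.bin", // assumed tokenizer path
            0.8f, // sampling temperature
            this);
  }

  @Override
  public void onModelLoaded(int status) {
    if (status == 0) { // assumption: 0 signals a successful load
      mRunner.generate("Tell me a story.");
    }
  }

  @Override
  public void onTokenGenerated(String token) {
    // Append the token to the visible reply, on the UI thread.
  }

  @Override
  public void onStats(String stats) {
    // e.g. "tokens/second: 12.3"
  }

  @Override
  public void onGenerationStopped() {
    // Re-enable input controls.
  }
}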
--------------------------------------------------------------------------------
/torchchat/edge/android/torchchat/app/src/main/java/org/pytorch/torchchat/ModelType.java:
--------------------------------------------------------------------------------
/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

package org.pytorch.torchchat;

public enum ModelType {
  LLAMA_3,
  LLAMA_3_1,
  LLAMA_3_2,
  LLAVA_1_5,
  LLAMA_GUARD_3,
}
--------------------------------------------------------------------------------
/torchchat/edge/android/torchchat/app/src/main/java/org/pytorch/torchchat/ModelUtils.java:
--------------------------------------------------------------------------------
/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

package org.pytorch.torchchat;

public class ModelUtils {
  static final int TEXT_MODEL = 1;
  static final int VISION_MODEL = 2;
  static final int VISION_MODEL_IMAGE_CHANNELS = 3;
  static final int VISION_MODEL_SEQ_LEN = 768;
  static final int TEXT_MODEL_SEQ_LEN = 256;

  public static int getModelCategory(ModelType modelType) {
    switch (modelType) {
      case LLAVA_1_5:
        return VISION_MODEL;
      case LLAMA_3:
      case LLAMA_3_1:
      case LLAMA_3_2:
      default:
        return TEXT_MODEL;
    }
  }
}
--------------------------------------------------------------------------------
/torchchat/edge/android/torchchat/app/src/main/java/org/pytorch/torchchat/PromptFormat.java:
--------------------------------------------------------------------------------
/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

package org.pytorch.torchchat;

public class PromptFormat {

  public static final String SYSTEM_PLACEHOLDER = "{{ system_prompt }}";
  public static final String USER_PLACEHOLDER = "{{ user_prompt }}";
  public static final String ASSISTANT_PLACEHOLDER = "{{ assistant_response }}";
  public static final String DEFAULT_SYSTEM_PROMPT = "Answer the questions in a few sentences";

  public static String getSystemPromptTemplate(ModelType modelType) {
    switch (modelType) {
      case LLAMA_3:
      case LLAMA_3_1:
      case LLAMA_3_2:
        return "<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n"
            + SYSTEM_PLACEHOLDER
            + "<|eot_id|>";
      case LLAVA_1_5:
        return "USER: ";
      default:
        return SYSTEM_PLACEHOLDER;
    }
  }

  public static String getUserPromptTemplate(ModelType modelType) {
    switch (modelType) {
      case LLAMA_3:
      case LLAMA_3_1:
      case LLAMA_3_2:
      case LLAMA_GUARD_3:
        return "<|start_header_id|>user<|end_header_id|>\n"
            + USER_PLACEHOLDER
            + "<|eot_id|>"
            + "<|start_header_id|>assistant<|end_header_id|>";

      case LLAVA_1_5:
      default:
        return USER_PLACEHOLDER;
    }
  }

  public static String getConversationFormat(ModelType modelType) {
    switch (modelType) {
      case LLAMA_3:
      case LLAMA_3_1:
      case LLAMA_3_2:
        return getUserPromptTemplate(modelType) + "\n" + ASSISTANT_PLACEHOLDER + "<|eot_id|>";
      case LLAVA_1_5:
        return USER_PLACEHOLDER + " ASSISTANT:";
      default:
        return USER_PLACEHOLDER;
    }
  }

  public static String getStopToken(ModelType modelType) {
    switch (modelType) {
      case LLAMA_3:
      case LLAMA_3_1:
      case LLAMA_3_2:
      case LLAMA_GUARD_3:
"<|eot_id|>"; 70 | case LLAVA_1_5: 71 | return ""; 72 | default: 73 | return ""; 74 | } 75 | } 76 | 77 | public static String getLlavaPresetPrompt() { 78 | return "A chat between a curious human and an artificial intelligence assistant. The assistant" 79 | + " gives helpful, detailed, and polite answers to the human's questions. USER: "; 80 | } 81 | 82 | public static String getFormattedLlamaGuardPrompt(String userPrompt) { 83 | return getUserPromptTemplate(ModelType.LLAMA_GUARD_3) 84 | .replace( 85 | USER_PLACEHOLDER, getLlamaGuardPresetPrompt().replace(USER_PLACEHOLDER, userPrompt)); 86 | } 87 | 88 | public static String getLlamaGuardPresetPrompt() { 89 | return "Task: Check if there is unsafe content in 'User' messages in conversations" 90 | + " according our safety policy with the below categories.\n" 91 | + "\n" 92 | + "\n" 93 | + "S1: Violent Crimes. \n" 94 | + "S2: Non-Violent Crimes. \n" 95 | + "S3: Sex Crimes. \n" 96 | + "S4: Child Exploitation. \n" 97 | + "S5: Defamation. \n" 98 | + "S6: Specialized Advice. \n" 99 | + "S7: Privacy. \n" 100 | + "S8: Intellectual Property. \n" 101 | + "S9: Indiscriminate Weapons. \n" 102 | + "S10: Hate. \n" 103 | + "S11: Self-Harm. \n" 104 | + "S12: Sexual Content. \n" 105 | + "S13: Elections. \n" 106 | + "S14: Code Interpreter Abuse.\n" 107 | + "\n" 108 | + "\n" 109 | + "\n" 110 | + "\n" 111 | + "User: " 112 | + USER_PLACEHOLDER 113 | + "\n" 114 | + "\n" 115 | + "Provide your safety assessment for ONLY THE LAST User message in the above" 116 | + " conversation:\n" 117 | + " - First line must read 'safe' or 'unsafe'.\n" 118 | + " - If unsafe, a second line must include a comma-separated list of violated" 119 | + " categories."; 120 | } 121 | } 122 | -------------------------------------------------------------------------------- /torchchat/edge/android/torchchat/app/src/main/java/org/pytorch/torchchat/SettingsFields.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) Meta Platforms, Inc. and affiliates. 3 | * All rights reserved. 4 | * 5 | * This source code is licensed under the BSD-style license found in the 6 | * LICENSE file in the root directory of this source tree. 
--------------------------------------------------------------------------------
/torchchat/edge/android/torchchat/app/src/main/java/org/pytorch/torchchat/SettingsFields.java:
--------------------------------------------------------------------------------
/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

package org.pytorch.torchchat;

public class SettingsFields {

  public String getModelFilePath() {
    return modelFilePath;
  }

  public String getTokenizerFilePath() {
    return tokenizerFilePath;
  }

  public double getTemperature() {
    return temperature;
  }

  public String getSystemPrompt() {
    return systemPrompt;
  }

  public ModelType getModelType() {
    return modelType;
  }

  public String getUserPrompt() {
    return userPrompt;
  }

  public String getFormattedSystemAndUserPrompt(String prompt) {
    return getFormattedSystemPrompt() + getFormattedUserPrompt(prompt);
  }

  public String getFormattedSystemPrompt() {
    return PromptFormat.getSystemPromptTemplate(modelType)
        .replace(PromptFormat.SYSTEM_PLACEHOLDER, systemPrompt);
  }

  public String getFormattedUserPrompt(String prompt) {
    return userPrompt.replace(PromptFormat.USER_PLACEHOLDER, prompt);
  }

  public boolean getIsClearChatHistory() {
    return isClearChatHistory;
  }

  public boolean getIsLoadModel() {
    return isLoadModel;
  }

  private String modelFilePath;
  private String tokenizerFilePath;
  private double temperature;
  private String systemPrompt;
  private String userPrompt;
  private boolean isClearChatHistory;
  private boolean isLoadModel;
  private ModelType modelType;

  public SettingsFields() {
    ModelType DEFAULT_MODEL = ModelType.LLAMA_3;

    modelFilePath = "";
    tokenizerFilePath = "";
    temperature = SettingsActivity.TEMPERATURE_MIN_VALUE;
    systemPrompt = "";
    userPrompt = PromptFormat.getUserPromptTemplate(DEFAULT_MODEL);
    isClearChatHistory = false;
    isLoadModel = false;
    modelType = DEFAULT_MODEL;
  }

  public SettingsFields(SettingsFields settingsFields) {
    this.modelFilePath = settingsFields.modelFilePath;
    this.tokenizerFilePath = settingsFields.tokenizerFilePath;
    this.temperature = settingsFields.temperature;
    this.systemPrompt = settingsFields.getSystemPrompt();
    this.userPrompt = settingsFields.getUserPrompt();
    this.isClearChatHistory = settingsFields.getIsClearChatHistory();
    this.isLoadModel = settingsFields.getIsLoadModel();
    this.modelType = settingsFields.modelType;
  }

  public void saveModelPath(String modelFilePath) {
    this.modelFilePath = modelFilePath;
  }

  public void saveTokenizerPath(String tokenizerFilePath) {
    this.tokenizerFilePath = tokenizerFilePath;
  }

  public void saveModelType(ModelType modelType) {
    this.modelType = modelType;
  }

  public void saveParameters(Double temperature) {
    this.temperature = temperature;
  }

  public void savePrompts(String systemPrompt, String userPrompt) {
    this.systemPrompt = systemPrompt;
    this.userPrompt = userPrompt;
  }

  public void saveIsClearChatHistory(boolean needToClear) {
    this.isClearChatHistory = needToClear;
  }

  public void saveLoadModelAction(boolean shouldLoadModel) {
    this.isLoadModel = shouldLoadModel;
  }

  public boolean equals(SettingsFields anotherSettingsFields) {
    if (this == anotherSettingsFields) return true;
    return modelFilePath.equals(anotherSettingsFields.modelFilePath)
        && tokenizerFilePath.equals(anotherSettingsFields.tokenizerFilePath)
        // Double.compare avoids the exact '==' comparison pitfall for doubles
        && Double.compare(temperature, anotherSettingsFields.temperature) == 0
        && systemPrompt.equals(anotherSettingsFields.systemPrompt)
        && userPrompt.equals(anotherSettingsFields.userPrompt)
        && isClearChatHistory == anotherSettingsFields.isClearChatHistory
        && isLoadModel == anotherSettingsFields.isLoadModel
        && modelType == anotherSettingsFields.modelType;
  }
}
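Note that equals(SettingsFields) overloads Object.equals rather than overriding it, so it is only selected when the static type is SettingsFields; the app calls it directly, which is sufficient here. For reference, a conventional override would look roughly like the sketch below, including the matching hashCode the class currently omits — an illustration, not what the app ships:

// Sketch only: conventional equals/hashCode pair delegating to the
// field-by-field comparison defined above.
@Override
public boolean equals(Object o) {
  if (this == o) return true;
  if (!(o instanceof SettingsFields)) return false;
  return equals((SettingsFields) o);
}

@Override
public int hashCode() {
  return java.util.Objects.hash(
      modelFilePath,
      tokenizerFilePath,
      temperature,
      systemPrompt,
      userPrompt,
      isClearChatHistory,
      isLoadModel,
      modelType);
}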
--------------------------------------------------------------------------------
/torchchat/edge/android/torchchat/app/src/main/res/drawable/banner_shape.xml:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/torchchat/edge/android/torchchat/app/src/main/res/drawable/baseline_add_24.xml:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/torchchat/edge/android/torchchat/app/src/main/res/drawable/baseline_add_photo_alternate_24.xml:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/torchchat/edge/android/torchchat/app/src/main/res/drawable/baseline_article_24.xml:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/torchchat/edge/android/torchchat/app/src/main/res/drawable/baseline_close_24.xml:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/torchchat/edge/android/torchchat/app/src/main/res/drawable/baseline_delete_forever_24.xml:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/torchchat/edge/android/torchchat/app/src/main/res/drawable/baseline_restart_alt_24.xml:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/torchchat/edge/android/torchchat/app/src/main/res/drawable/baseline_send_24.xml:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/torchchat/edge/android/torchchat/app/src/main/res/drawable/baseline_settings_24.xml:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/torchchat/edge/android/torchchat/app/src/main/res/drawable/baseline_stop_24.xml:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/torchchat/edge/android/torchchat/app/src/main/res/drawable/btn.xml:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/torchchat/edge/android/torchchat/app/src/main/res/drawable/chat_background.xml:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/torchchat/edge/android/torchchat/app/src/main/res/drawable/custom_button_round.xml:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/torchchat/edge/android/torchchat/app/src/main/res/drawable/expand_circle_down.xml:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/torchchat/edge/android/torchchat/app/src/main/res/drawable/ic_launcher_foreground.xml:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/torchchat/edge/android/torchchat/app/src/main/res/drawable/input_text_shape.xml:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/torchchat/edge/android/torchchat/app/src/main/res/drawable/logo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pytorch/torchchat/71552fdea3361aaecc37ca3f2d8d29b7ce3e8901/torchchat/edge/android/torchchat/app/src/main/res/drawable/logo.png
--------------------------------------------------------------------------------
/torchchat/edge/android/torchchat/app/src/main/res/drawable/outline_add_box_48.xml:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/torchchat/edge/android/torchchat/app/src/main/res/drawable/outline_camera_alt_48.xml:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/torchchat/edge/android/torchchat/app/src/main/res/drawable/outline_image_48.xml:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/torchchat/edge/android/torchchat/app/src/main/res/drawable/prompt_shape.xml:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/torchchat/edge/android/torchchat/app/src/main/res/drawable/received_message.xml:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/torchchat/edge/android/torchchat/app/src/main/res/drawable/sent_message.xml:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/torchchat/edge/android/torchchat/app/src/main/res/drawable/three_dots.xml:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/torchchat/edge/android/torchchat/app/src/main/res/layout/activity_benchmarking.xml:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/torchchat/edge/android/torchchat/app/src/main/res/layout/activity_logs.xml:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/torchchat/edge/android/torchchat/app/src/main/res/layout/logs_message.xml:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/torchchat/edge/android/torchchat/app/src/main/res/layout/received_message.xml:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/torchchat/edge/android/torchchat/app/src/main/res/layout/sent_message.xml:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/torchchat/edge/android/torchchat/app/src/main/res/layout/system_message.xml:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/torchchat/edge/android/torchchat/app/src/main/res/mipmap-anydpi-v26/ic_launcher.xml:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/torchchat/edge/android/torchchat/app/src/main/res/mipmap-anydpi-v26/ic_launcher_round.xml:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/torchchat/edge/android/torchchat/app/src/main/res/mipmap-hdpi/ic_launcher.webp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pytorch/torchchat/71552fdea3361aaecc37ca3f2d8d29b7ce3e8901/torchchat/edge/android/torchchat/app/src/main/res/mipmap-hdpi/ic_launcher.webp
--------------------------------------------------------------------------------
/torchchat/edge/android/torchchat/app/src/main/res/mipmap-hdpi/ic_launcher_round.webp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pytorch/torchchat/71552fdea3361aaecc37ca3f2d8d29b7ce3e8901/torchchat/edge/android/torchchat/app/src/main/res/mipmap-hdpi/ic_launcher_round.webp
--------------------------------------------------------------------------------
/torchchat/edge/android/torchchat/app/src/main/res/mipmap-mdpi/ic_launcher.webp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pytorch/torchchat/71552fdea3361aaecc37ca3f2d8d29b7ce3e8901/torchchat/edge/android/torchchat/app/src/main/res/mipmap-mdpi/ic_launcher.webp
--------------------------------------------------------------------------------
/torchchat/edge/android/torchchat/app/src/main/res/mipmap-mdpi/ic_launcher_round.webp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pytorch/torchchat/71552fdea3361aaecc37ca3f2d8d29b7ce3e8901/torchchat/edge/android/torchchat/app/src/main/res/mipmap-mdpi/ic_launcher_round.webp
--------------------------------------------------------------------------------
/torchchat/edge/android/torchchat/app/src/main/res/mipmap-xhdpi/ic_launcher.webp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pytorch/torchchat/71552fdea3361aaecc37ca3f2d8d29b7ce3e8901/torchchat/edge/android/torchchat/app/src/main/res/mipmap-xhdpi/ic_launcher.webp
--------------------------------------------------------------------------------
/torchchat/edge/android/torchchat/app/src/main/res/mipmap-xhdpi/ic_launcher_round.webp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pytorch/torchchat/71552fdea3361aaecc37ca3f2d8d29b7ce3e8901/torchchat/edge/android/torchchat/app/src/main/res/mipmap-xhdpi/ic_launcher_round.webp
--------------------------------------------------------------------------------
/torchchat/edge/android/torchchat/app/src/main/res/mipmap-xxhdpi/ic_launcher.webp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pytorch/torchchat/71552fdea3361aaecc37ca3f2d8d29b7ce3e8901/torchchat/edge/android/torchchat/app/src/main/res/mipmap-xxhdpi/ic_launcher.webp
--------------------------------------------------------------------------------
/torchchat/edge/android/torchchat/app/src/main/res/mipmap-xxhdpi/ic_launcher_round.webp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pytorch/torchchat/71552fdea3361aaecc37ca3f2d8d29b7ce3e8901/torchchat/edge/android/torchchat/app/src/main/res/mipmap-xxhdpi/ic_launcher_round.webp
--------------------------------------------------------------------------------
/torchchat/edge/android/torchchat/app/src/main/res/mipmap-xxxhdpi/ic_launcher.webp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pytorch/torchchat/71552fdea3361aaecc37ca3f2d8d29b7ce3e8901/torchchat/edge/android/torchchat/app/src/main/res/mipmap-xxxhdpi/ic_launcher.webp
--------------------------------------------------------------------------------
/torchchat/edge/android/torchchat/app/src/main/res/mipmap-xxxhdpi/ic_launcher_round.webp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pytorch/torchchat/71552fdea3361aaecc37ca3f2d8d29b7ce3e8901/torchchat/edge/android/torchchat/app/src/main/res/mipmap-xxxhdpi/ic_launcher_round.webp
--------------------------------------------------------------------------------
/torchchat/edge/android/torchchat/app/src/main/res/values/colors.xml:
--------------------------------------------------------------------------------
#4294F0
#3700B3
#03DAC5
#007CBA
#A2A4B6
#16293D
#16293D
--------------------------------------------------------------------------------
/torchchat/edge/android/torchchat/app/src/main/res/values/strings.xml:
--------------------------------------------------------------------------------
torchchat
DemoPrefFileKey
SavedMessagesJsonKey
SettingsJsonKey
LogsJsonKey
--------------------------------------------------------------------------------
/torchchat/edge/android/torchchat/app/src/main/res/values/styles.xml:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/torchchat/edge/android/torchchat/app/src/main/res/values/themes.xml:
--------------------------------------------------------------------------------