├── .gitignore ├── LICENSE ├── README.md ├── finetuning ├── finetuning.py ├── other │ └── hf_readme_template.md ├── requirements.txt └── training_engine │ ├── dataset.py │ ├── engine.py │ └── utils.py ├── inference ├── README.md └── lorax_api_server.py ├── logo.png ├── main ├── .env.example ├── .eslintrc.cjs ├── components.json ├── docker-compose.yml ├── next.config.mjs ├── package-lock.json ├── package.json ├── postcss.config.cjs ├── prisma │ └── schema.prisma ├── public │ ├── favicon.ico │ ├── huggingface.png │ ├── logo-simple-dark.svg │ └── wandblogo.svg ├── src │ ├── app │ │ ├── api │ │ │ ├── auth │ │ │ │ └── [...nextauth] │ │ │ │ │ └── route.ts │ │ │ ├── hooks │ │ │ │ └── update │ │ │ │ │ └── route.ts │ │ │ ├── inference │ │ │ │ ├── route.ts │ │ │ │ └── schema.ts │ │ │ ├── search │ │ │ │ └── route.ts │ │ │ └── trpc │ │ │ │ └── [trpc] │ │ │ │ └── route_inactive.ts │ │ ├── billing │ │ │ ├── actions.tsx │ │ │ ├── add-balance.tsx │ │ │ ├── add-credit-card.tsx │ │ │ ├── billing.tsx │ │ │ └── page.tsx │ │ ├── chat │ │ │ └── [slug] │ │ │ │ ├── chat.tsx │ │ │ │ ├── page.tsx │ │ │ │ ├── process.tsx │ │ │ │ └── settings.tsx │ │ ├── datasets │ │ │ ├── new │ │ │ │ ├── actions.tsx │ │ │ │ ├── form.tsx │ │ │ │ └── modal.tsx │ │ │ ├── page.tsx │ │ │ └── table.tsx │ │ ├── layout.tsx │ │ ├── models │ │ │ ├── actions.tsx │ │ │ ├── buttons.tsx │ │ │ ├── export.tsx │ │ │ ├── list.tsx │ │ │ ├── new │ │ │ │ ├── actions.tsx │ │ │ │ ├── confirm-price.tsx │ │ │ │ ├── form.tsx │ │ │ │ ├── page.tsx │ │ │ │ └── search.tsx │ │ │ └── page.tsx │ │ ├── page.tsx │ │ ├── providers.tsx │ │ └── settings │ │ │ ├── actions.tsx │ │ │ ├── form.tsx │ │ │ └── page.tsx │ ├── components │ │ ├── form │ │ │ ├── button.tsx │ │ │ ├── dropdown.tsx │ │ │ ├── file-upload.tsx │ │ │ ├── label.tsx │ │ │ ├── submit-button.tsx │ │ │ └── textfield.tsx │ │ ├── modal.tsx │ │ ├── padding.tsx │ │ ├── page-heading.tsx │ │ ├── sidebar.tsx │ │ ├── tiles.tsx │ │ ├── tooltip.tsx │ │ ├── ui │ │ │ ├── hover-card.tsx │ │ │ ├── 
slider.tsx │ │ │ └── switch.tsx │ │ ├── user-avatar.tsx │ │ ├── utils │ │ │ └── utils.ts │ │ └── warning.tsx │ ├── constants │ │ ├── modal.ts │ │ └── models.ts │ ├── env.mjs │ ├── pages │ │ ├── _app.tsx │ │ ├── _document.tsx │ │ └── auth │ │ │ ├── login │ │ │ └── index.tsx │ │ │ ├── new-user │ │ │ └── index.tsx │ │ │ └── verify-request │ │ │ └── index.tsx │ ├── server │ │ ├── auth.ts │ │ ├── controller │ │ │ ├── new-dataset.ts │ │ │ ├── new-model.ts │ │ │ ├── process-dataset.ts │ │ │ └── stripe.ts │ │ ├── database │ │ │ ├── chat-request.ts │ │ │ ├── dataset.ts │ │ │ ├── index.ts │ │ │ ├── model.ts │ │ │ └── user.ts │ │ └── utils │ │ │ ├── mail.ts │ │ │ ├── modal.ts │ │ │ ├── observability │ │ │ ├── logtail.ts │ │ │ └── posthog.ts │ │ │ └── session.ts │ └── styles │ │ └── globals.css ├── tailwind.config.ts └── tsconfig.json ├── modal-proxy ├── gunicorn_config.py ├── requirements.txt └── run.py └── prettier.config.js /.gitignore: -------------------------------------------------------------------------------- 1 | # See https://help.github.com/articles/ignoring-files/ for more about ignoring files. 2 | 3 | # dependencies 4 | */node_modules 5 | /.pnp 6 | .pnp.js 7 | 8 | # testing 9 | /coverage 10 | 11 | # database 12 | */prisma/db.sqlite 13 | */prisma/db.sqlite-journal 14 | 15 | # next.js 16 | */.next/ 17 | /out/ 18 | next-env.d.ts 19 | 20 | # production 21 | /build 22 | 23 | # misc 24 | .DS_Store 25 | *.pem 26 | 27 | # debug 28 | npm-debug.log* 29 | yarn-debug.log* 30 | yarn-error.log* 31 | .pnpm-debug.log* 32 | 33 | # local env files 34 | # do not commit any .env files to git, except for the .env.example file. 
https://create.t3.gg/en/usage/env-variables#using-environment-variables 35 | .env 36 | .env*.local 37 | 38 | # vercel 39 | .vercel 40 | 41 | # typescript 42 | *.tsbuildinfo 43 | 44 | # modal 45 | .modal.toml 46 | 47 | # temporary files 48 | */tmp/**/* 49 | 50 | # python 51 | *.pyc -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 
34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. 
Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 
122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. 
In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. 
We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |
2 | 3 |

4 | 5 |

6 | 7 |

8 | Build AI models for specialized tasks. 9 | 10 |

11 | 12 |
13 | 14 | [💻 Website](https://docs.haven.run/) 15 |   •   16 | [📄 Docs](https://docs.haven.run/) 17 |   •   18 | [☁️ Hosted App](https://app.haven.run/models) 19 |   •   20 | [💬 Discord](https://discord.gg/JDjbfp6q2G) 21 |
22 | 23 |

24 | Haven gives you the tools needed to build specialized large language models.
Our platform lets you fine-tune LLMs through a simple UI and evaluate them based on a wide range of criteria. 25 | 26 |

27 |
28 | 29 |
30 | 31 |
32 | 33 | ## Welcome 🥳 34 | 35 | Welcome to the Haven repository! We initially started out building functionality for LLM fine-tuning, but noticed that most people actually struggle with evaluation and data collection rather than model training itself. To address this, we're planning to add the following features soon: 36 | 37 | - Evaluation: Use prebuilt evaluation metrics or define custom ones, and run experiments comparing models across them 38 | - Visualization: Visualize and search through your experiments to get insights on model performance 39 | - Data Collection: Collecting and formatting training datasets is annoying. We don't really know how to solve this yet, but we're trying to figure it out! 40 | 41 | If you have feedback or suggestions on these problems, **please reach out!** You can join our [Discord](https://discord.com/invite/JDjbfp6q2G), write us an [email](mailto:hello@haven.run), or [schedule a call](https://cal.com/justus-mattern-xfnomx/30-min-chat). 42 | 43 | 44 |
45 | 46 | ## Getting Started :rocket: 47 | 48 | Instructions to self-host as well as support for AWS, GCP and Azure will follow soon. In the meantime, you can try our hosted app [here](https://app.haven.run/). 49 | 50 | -------------------------------------------------------------------------------- /finetuning/finetuning.py: -------------------------------------------------------------------------------- 1 | import time 2 | import torch 3 | from datasets import load_dataset 4 | from transformers import AutoModelForCausalLM, BitsAndBytesConfig, AutoTokenizer, TrainingArguments, TrainerCallback 5 | from peft import LoraConfig 6 | from trl import SFTTrainer, DataCollatorForCompletionOnlyLM 7 | from huggingface_hub import HfApi, create_repo 8 | import wandb 9 | import os 10 | from training_engine.engine import ModelEngine 11 | import random 12 | import string 13 | import argparse 14 | import modal 15 | from typing import Dict 16 | from training_engine.utils import send_update 17 | 18 | from fastapi import HTTPException 19 | from modal import gpu, Mount, Stub, Image, Volume, Secret, web_endpoint 20 | from modal.cli.volume import put 21 | 22 | 23 | image = ( 24 | Image.debian_slim(python_version="3.10").apt_install("git") 25 | # Pinned to 10/16/23 26 | .pip_install("torch==2.0.1", "transformers==4.35.0", "peft==0.6.0", "accelerate==0.24.1", "bitsandbytes==0.41.1", "einops==0.7.0", "evaluate==0.4.1", "scikit-learn==1.2.2", "sentencepiece==0.1.99", "wandb==0.15.3", "trl==0.7.2", "huggingface-hub" 27 | ).pip_install("hf-transfer").pip_install("requests").pip_install("modal-client") 28 | .env(dict(HUGGINGFACE_HUB_CACHE="/pretrained_models", HF_HUB_ENABLE_HF_TRANSFER="0", MODAL_CONFIG_PATH="/utils/modal.toml")) 29 | ) 30 | 31 | stub = Stub("llama-finetuning", image=image) 32 | stub.model_volume = modal.NetworkFileSystem.persisted("adapters") 33 | 34 | 35 | @stub.function( 36 | volumes={ 37 | "/datasets": modal.Volume.from_name("datasets"), 38 | "/pretrained_models": 
modal.Volume.from_name("pretrained_models"), 39 | "/utils": modal.Volume.from_name("utils"), 40 | }, 41 | network_file_systems={ 42 | "/trained_adapters": stub.model_volume 43 | }, 44 | mounts=[modal.Mount.from_local_dir("./other", remote_path="/other")], 45 | gpu=gpu.A100(count=1, memory=80), 46 | timeout=3600 * 12, 47 | allow_cross_region_volumes=True, 48 | secret=Secret.from_name("finetuning-auth-token") 49 | ) 50 | @web_endpoint(method="POST", wait_for_response=False) 51 | def train(inputs: Dict): 52 | 53 | if inputs["auth_token"] != os.environ["AUTH_TOKEN"]: 54 | raise HTTPException( 55 | status_code=status.HTTP_401_UNAUTHORIZED, 56 | detail="Incorrect auth token", 57 | ) 58 | 59 | 60 | if torch.cuda.is_available(): 61 | # Get the current GPU device 62 | device = torch.cuda.current_device() 63 | 64 | # Get the available GPU memory 65 | gpu_memory = torch.cuda.get_device_properties(device).total_memory / (1024 ** 2) # in megabytes 66 | 67 | print("GPU memory", gpu_memory) 68 | 69 | model_folder_name = f"/trained_adapters/{inputs['model_id']}" 70 | os.environ["WANDB_PROJECT"]="haven" 71 | 72 | try: 73 | wandb.login(anonymous="must", key=inputs["wandb_token"]) 74 | log_wandb = True 75 | except Exception as e: 76 | print("exception: not logged in to wandb") 77 | print(e) 78 | log_wandb = False 79 | 80 | try: 81 | model_engine = ModelEngine(model_name=inputs["model_name"], 82 | dataset_name=inputs["dataset_name"], 83 | training_args=get_training_args(inputs, output_dir=model_folder_name, log_wandb=log_wandb), 84 | engine_args=get_engine_args(inputs) 85 | ) 86 | 87 | model_engine.train() 88 | if log_wandb: 89 | wandb.finish() 90 | 91 | model_engine.tokenizer.save_pretrained(model_folder_name) 92 | save_readme_file(hf_repo_name=inputs["hf_repo"], base_model_name=inputs["model_name"], model_directory=model_folder_name) 93 | 94 | upload_model(repo=inputs["hf_repo"], folder_path=model_folder_name, hf_token=inputs["hf_token"]) 95 | 
send_update(model_id=inputs["model_id"], key="status", value="finished") 96 | 97 | except Exception as e: 98 | print("exception:", e) 99 | send_update(model_id=inputs["model_id"], key="status", value="error") 100 | 101 | 102 | 103 | def get_training_args(inputs, output_dir, log_wandb): 104 | learning_rate_map = dict(Low=5e-5, Medium=1e-4, High=3e-4) 105 | if log_wandb: 106 | report_to = "wandb" 107 | else: 108 | report_to = "none" 109 | 110 | return dict( 111 | learning_rate=learning_rate_map[inputs["learning_rate"]], 112 | num_train_epochs=inputs["num_epochs"], 113 | output_dir=output_dir, 114 | gradient_accumulation_steps=inputs["gradient_accumulation_steps"], 115 | per_device_train_batch_size=inputs["per_device_train_batch_size"], 116 | report_to=report_to 117 | ) 118 | 119 | def get_engine_args(inputs): 120 | 121 | class TrainingStartedCallback(TrainerCallback): 122 | def on_train_begin(self, args, state, control, **kwargs): 123 | try: 124 | print("sending url", wandb.run.url) 125 | except Exception as e: 126 | print("wandb url not found exception", e) 127 | return 128 | 129 | send_update(model_id=inputs["model_id"], key="wandb", value=wandb.run.url) 130 | 131 | 132 | return dict( 133 | wandb_token=inputs["wandb_token"], 134 | hf_token="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", 135 | max_tokens=inputs["max_tokens"], 136 | callbacks=[TrainingStartedCallback] 137 | ) 138 | 139 | def generate_random_string(): 140 | characters = string.ascii_letters + string.digits 141 | return ''.join(random.choice(characters) for _ in range(8)) 142 | 143 | 144 | def upload_model(repo, folder_path, hf_token): 145 | api = HfApi() 146 | 147 | create_repo(repo, use_auth_token=hf_token, private=True) 148 | api.upload_folder( 149 | folder_path=folder_path, 150 | repo_id=repo, 151 | repo_type="model", 152 | use_auth_token=hf_token, 153 | ) 154 | 155 | 156 | def save_readme_file(hf_repo_name: str, base_model_name: str, model_directory: str): 157 | with open("/other/hf_readme_template.md", 
"r") as f: 158 | data = f.read() 159 | 160 | full_readme = data.replace("{{base_model_name}}", base_model_name).replace("{{model_name}}", hf_repo_name) 161 | 162 | with open(f"{model_directory}/README.md", "w") as f: 163 | f.write(full_readme) 164 | 165 | 166 | if __name__=="__main__": 167 | parser = argparse.ArgumentParser() 168 | parser.add_argument("--learning_rate", type=str, choices=["Low", "Medium", "High"], default="Medium") 169 | parser.add_argument("--num_epochs", type=int, default=3) 170 | parser.add_argument("--model_name", default="meta-llama/Llama-2-7b-chat-hf") 171 | parser.add_argument("--max-tokens", type=int, default=2600) 172 | parser.add_argument("--gradient_accumulation_steps", type=int, default=16) 173 | parser.add_argument("--per_device_train_batch_size", type=int, default=1) 174 | parser.add_argument("--dataset_name", type=str) 175 | parser.add_argument("--hf_repo", type=str) 176 | args = parser.parse_args() 177 | 178 | 179 | 180 | train.local((dict( 181 | wandb_token="d519a364bc5fcc5970e3eb1e9cc54225c444e511", 182 | hf_token="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", 183 | learning_rate=args.learning_rate, 184 | num_epochs=args.num_epochs, 185 | model_name=args.model_name, 186 | dataset_name=args.dataset_name, 187 | hf_repo=args.hf_repo, 188 | max_tokens=args.max_tokens, 189 | gradient_accumulation_steps=args.gradient_accumulation_steps, 190 | per_device_train_batch_size=args.per_device_train_batch_size, 191 | auth_token="", 192 | model_id=generate_random_string() 193 | ) 194 | )) 195 | 196 | -------------------------------------------------------------------------------- /finetuning/other/hf_readme_template.md: -------------------------------------------------------------------------------- 1 | --- 2 | library_name: peft 3 | base_model: meta-llama/Llama-2-7b-chat-hf 4 | --- 5 | 6 | # Model Card for {{model_name}} 7 | 8 | This is a lora adapter for {{base_model_name}} that was trained and automatically uploaded with 
[Haven](https://haven.run/). 9 | 10 | 11 | ## Testing the Model 12 | 13 | To quickly test the model, you can run it on a GPU with the transformers / peft library: 14 | 15 | ```python 16 | from peft import AutoPeftModelForCausalLM 17 | from transformers import AutoTokenizer 18 | 19 | tokenizer = AutoTokenizer.from_pretrained("{{model_name}}") 20 | model = AutoPeftModelForCausalLM.from_pretrained("{{model_name}}").to("cuda") # if you get a CUDA out of memory error, try load_in_8bit=True 21 | 22 | messages = [ 23 | {"role": "system", "content": "You are a helpful assistant"}, 24 | {"role": "user", "content": "Hi, can you please explain machine learning to me?"} 25 | ] 26 | 27 | encodeds = tokenizer.apply_chat_template(messages, return_tensors="pt").to("cuda") 28 | generated_ids = model.generate(input_ids=model_inputs, min_new_tokens=10, max_new_tokens=300, do_sample=True, temperature=0.9, top_p=0.8) 29 | decoded = tokenizer.batch_decode(generated_ids) 30 | 31 | print(decoded[0]) 32 | ``` -------------------------------------------------------------------------------- /finetuning/requirements.txt: -------------------------------------------------------------------------------- 1 | transformers 2 | peft 3 | accelerate 4 | bitsandbytes 5 | einops 6 | evaluate 7 | scikit-learn==1.2.2 8 | sentencepiece==0.1.99 9 | wandb==0.15.3 10 | trl 11 | huggingface-hub -------------------------------------------------------------------------------- /finetuning/training_engine/dataset.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import transformers 3 | import io 4 | import json 5 | import logging 6 | 7 | from dataclasses import dataclass 8 | from typing import Dict, Sequence 9 | from torch.utils.data import Dataset 10 | 11 | 12 | class ChatDataset(Dataset): 13 | def __init__(self, data_path: str, tokenizer: transformers.AutoTokenizer, conversation_template: str, max_tokens: int): 14 | super(ChatDataset, self).__init__() 15 | 
data = [] 16 | with open(data_path, "r") as file: 17 | for line in file: 18 | try: 19 | data.append(json.loads(line)) 20 | except Exception as e: 21 | print("json processing exception", e) 22 | continue 23 | 24 | 25 | data_dict = preprocess(data, tokenizer, conversation_template, max_tokens) 26 | 27 | self.input_ids = data_dict["input_ids"] 28 | self.labels = data_dict["labels"] 29 | 30 | def __len__(self): 31 | return len(self.input_ids) 32 | 33 | def __getitem__(self, i) -> Dict[str, torch.Tensor]: 34 | return dict(input_ids=self.input_ids[i], labels=self.labels[i]) 35 | 36 | 37 | @dataclass 38 | class DataCollatorForChatDataset(object): 39 | """ 40 | Collate examples for supervised fine-tuning. 41 | """ 42 | 43 | tokenizer: transformers.PreTrainedTokenizer 44 | 45 | def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]: 46 | input_ids, labels = tuple([instance[key] for instance in instances] for key in ("input_ids", "labels")) 47 | input_ids = torch.nn.utils.rnn.pad_sequence(input_ids, batch_first=True, padding_value=self.tokenizer.pad_token_id) 48 | labels = torch.nn.utils.rnn.pad_sequence(labels, batch_first=True, padding_value=-100) 49 | 50 | return dict( 51 | input_ids=input_ids, 52 | labels=labels, 53 | attention_mask=input_ids.ne(self.tokenizer.pad_token_id), 54 | ) 55 | 56 | 57 | class ChatDataModule(): 58 | def __init__(self, tokenizer: transformers.PreTrainedTokenizer, data_path: str, conversation_template, max_tokens: int): 59 | 60 | self.dataset = ChatDataset(tokenizer=tokenizer, data_path=data_path, conversation_template=conversation_template, max_tokens=max_tokens) 61 | self.data_collator = DataCollatorForChatDataset(tokenizer=tokenizer) 62 | 63 | 64 | def preprocess(conversations: Sequence[Sequence[dict]], tokenizer: transformers.PreTrainedTokenizer, conversation_template: str, max_tokens: int) -> Dict: 65 | """ 66 | Preprocess the data by tokenizing. 
67 | """ 68 | all_input_ids = [] 69 | all_label_ids = [] 70 | tokenizer.use_default_system_prompt = False 71 | 72 | for conv in conversations: 73 | current_conv = conv["messages"] 74 | tokenized_responses = [] 75 | for msg in current_conv: 76 | if msg["role"] == "assistant": 77 | tokenized_responses.append(tokenizer.encode(msg["content"], add_special_tokens=False)) 78 | 79 | tokenized_conv = tokenizer.apply_chat_template(current_conv, chat_template=conversation_template, max_length=max_tokens, truncation=True) 80 | nontokenized_conv = tokenizer.apply_chat_template(current_conv, chat_template=conversation_template, max_length=max_tokens, truncation=True, tokenize=False) 81 | 82 | print("nontokenized", nontokenized_conv) 83 | 84 | loss_aware_tokenized_conv = [-100] * len(tokenized_conv) 85 | 86 | for sublist in tokenized_responses: 87 | for index, value in enumerate(tokenized_conv): 88 | if value in sublist: 89 | loss_aware_tokenized_conv[index] = value 90 | if index < len(tokenized_conv) - 1 and tokenized_conv[index + 1] not in sublist: 91 | loss_aware_tokenized_conv[index + 1] = tokenized_conv[index + 1] 92 | 93 | all_input_ids.append(torch.LongTensor(tokenized_conv)) 94 | all_label_ids.append(torch.LongTensor(loss_aware_tokenized_conv)) 95 | 96 | return dict(input_ids=all_input_ids, labels=all_label_ids) 97 | 98 | -------------------------------------------------------------------------------- /finetuning/training_engine/engine.py: -------------------------------------------------------------------------------- 1 | import time 2 | import torch 3 | from datasets import load_dataset 4 | from .dataset import ChatDataModule 5 | from transformers import AutoModelForCausalLM, BitsAndBytesConfig, AutoTokenizer, TrainingArguments, Trainer, TrainerCallback 6 | from peft import LoraConfig, get_peft_model 7 | from trl import SFTTrainer, DataCollatorForCompletionOnlyLM 8 | import wandb 9 | import os 10 | from dataclasses import dataclass, field 11 | from typing import List 12 
| 13 | 14 | 15 | @dataclass 16 | class EngineTrainArguments(TrainingArguments): 17 | output_dir: str = field( 18 | default='./output', 19 | metadata={"help": 'The output dir for logs and checkpoints'} 20 | ) 21 | per_device_train_batch_size: int = field( 22 | default=1, 23 | metadata={"help": 'batch size'} 24 | ) 25 | gradient_accumulation_steps: int = field( 26 | default=16, 27 | metadata={"help": 'gradient accumulation steps'} 28 | ) 29 | logging_steps: int = field( 30 | default=3, 31 | metadata={"help": 'gradient accumulation steps'} 32 | ) 33 | learning_rate: float = field( 34 | default=1e-4, 35 | metadata={"help": 'learning rate'} 36 | ) 37 | num_train_epochs: int = field( 38 | default=3, 39 | metadata={"help": 'number of training epochs'} 40 | ) 41 | report_to: str = field( 42 | default="wandb", 43 | metadata={"help": 'where to report training logs to'} 44 | ) 45 | save_strategy: str = field( 46 | default="no", 47 | metadata={"help": 'model saving strategy'} 48 | ) 49 | 50 | @dataclass 51 | class EngineConfig: 52 | hf_token: str = field( 53 | default="none", 54 | metadata={"help": 'huggingface token'} 55 | ) 56 | wandb_token: str = field( 57 | default="none", 58 | metadata={"help": 'wandb token'} 59 | ) 60 | default_conversation_template_id: str = field( 61 | default="meta-llama/Llama-2-7b-chat-hf", 62 | metadata={"help": 'which template to use if model does not have one'} 63 | ) 64 | max_tokens: int = field( 65 | default=2700, 66 | metadata={"help": 'max sample length'} 67 | ) 68 | callbacks: List = field( 69 | default=None 70 | ) 71 | 72 | 73 | 74 | class ModelEngine: 75 | def __init__(self, model_name, dataset_name, training_args, engine_args): 76 | self.training_args = EngineTrainArguments(**training_args) 77 | self.engine_config = EngineConfig(**engine_args) 78 | self.model, self.tokenizer = self.load_model(model_name) 79 | self.tokenizer.chat_template = self.get_conversation_template() 80 | self.data_module = ChatDataModule(self.tokenizer, dataset_name, 
self.tokenizer.chat_template, self.engine_config.max_tokens) 81 | 82 | def load_model(self, model_name): 83 | bnb_config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_quant_type="nf4", bnb_4bit_compute_dtype=torch.float16) 84 | 85 | peft_config = LoraConfig( 86 | lora_alpha=32, 87 | lora_dropout=0.05, 88 | r=8, 89 | bias="none", 90 | task_type="CAUSAL_LM", 91 | target_modules=["q_proj", "v_proj", "o_proj", "k_proj"] 92 | ) 93 | 94 | model = AutoModelForCausalLM.from_pretrained( 95 | model_name, 96 | quantization_config=bnb_config, 97 | device_map="auto", 98 | trust_remote_code=True, 99 | use_auth_token=self.engine_config.hf_token 100 | ) 101 | 102 | 103 | model = get_peft_model(model, peft_config) 104 | 105 | print("device map", model.hf_device_map) 106 | print("trainable", model.print_trainable_parameters()) 107 | 108 | tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True, use_auth_token=self.engine_config.hf_token) 109 | tokenizer.pad_token = tokenizer.eos_token 110 | 111 | return model, tokenizer 112 | 113 | def get_conversation_template(self): 114 | print("chat template jajaaaa", self.tokenizer.chat_template) 115 | if self.tokenizer.chat_template is not None: 116 | if not "Only user and assistant roles are supported" in self.tokenizer.chat_template: 117 | return self.tokenizer.chat_template 118 | 119 | if self.tokenizer.default_chat_template is not None: 120 | if not "Only user and assistant roles are supported" in self.tokenizer.default_chat_template: 121 | return self.tokenizer.default_chat_template 122 | 123 | alt_tokenizer = AutoTokenizer.from_pretrained(self.engine_config.default_conversation_template_id, trust_remote_code=True, use_auth_token=self.engine_config.hf_token) 124 | return alt_tokenizer.chat_template 125 | 126 | 127 | def train(self): 128 | 129 | trainer = Trainer( 130 | model=self.model, 131 | train_dataset=self.data_module.dataset, 132 | tokenizer=self.tokenizer, 133 | args=self.training_args, 134 | 
data_collator=self.data_module.data_collator, 135 | callbacks=self.engine_config.callbacks 136 | ) 137 | 138 | trainer.train() 139 | trainer.save_model() 140 | -------------------------------------------------------------------------------- /finetuning/training_engine/utils.py: -------------------------------------------------------------------------------- 1 | import requests 2 | 3 | def send_update(model_id, key, value): 4 | url = 'https://app-haven-run.onrender.com/api/hooks/update' 5 | headers = { 6 | 'x-secret': 'SECRET', 7 | 'Content-Type': 'application/json' 8 | } 9 | data = { 10 | 'id': model_id, 11 | 'key': key, 12 | 'value': value 13 | } 14 | 15 | print("sending request", dict( 16 | url=url, 17 | headers=headers, 18 | data=data 19 | )) 20 | 21 | response = requests.post(url, headers=headers, json=data) 22 | 23 | print("response code", response.status_code) 24 | 25 | return response 26 | 27 | 28 | 29 | 30 | -------------------------------------------------------------------------------- /inference/README.md: -------------------------------------------------------------------------------- 1 | # LoraX API Server 2 | 3 | To deploy a multi lora server for a certain model, you need to first specify the 'MODEL_ID' in 'lorax_api_server.py' (the model id should be the huggingface name of the desired base model). 
Afterwards, just run: 4 | 5 | ``` 6 | modal deploy lorax_api_server.py 7 | ``` -------------------------------------------------------------------------------- /inference/lorax_api_server.py: -------------------------------------------------------------------------------- 1 | from modal import Image, Mount, Secret, Stub, asgi_app, gpu, method, Volume, NetworkFileSystem 2 | import lorax 3 | from typing import List, Dict 4 | from transformers import AutoTokenizer 5 | import requests 6 | import json 7 | 8 | from fastapi import BackgroundTasks, FastAPI, Request 9 | from fastapi.responses import Response, StreamingResponse, JSONResponse 10 | 11 | 12 | GPU_CONFIG = gpu.A10G(count=1) 13 | MODEL_ID = "meta-llama/Llama-2-7b-chat-hf" 14 | 15 | LAUNCH_FLAGS = [ 16 | "--model-id", 17 | MODEL_ID, 18 | "--port", 19 | "8000", 20 | "--cuda-memory-fraction", 21 | "0.95", 22 | "--max-input-length", 23 | "2048", 24 | "--max-total-tokens", 25 | "4096" 26 | ] 27 | 28 | 29 | image = ( 30 | Image.from_registry("ghcr.io/predibase/lorax:latest") 31 | .dockerfile_commands("ENTRYPOINT []") 32 | .pip_install("lorax-client").pip_install("transformers==4.35.0").env(dict(HUGGINGFACE_HUB_CACHE="/pretrained_models")) 33 | ) 34 | 35 | stub = Stub("multi-lora-server-" + MODEL_ID.split("/")[0]+"-"+MODEL_ID.split("/")[1], image=image) 36 | 37 | 38 | @stub.cls( 39 | secret=Secret.from_name("huggingface-token"), 40 | gpu=GPU_CONFIG, 41 | allow_concurrent_inputs=100, 42 | container_idle_timeout=60 * 10, 43 | timeout=60 * 60, 44 | keep_warm=1, 45 | allow_cross_region_volumes=True, 46 | volumes={ 47 | "/pretrained_models": Volume.from_name("pretrained_models") 48 | }, 49 | network_file_systems={ 50 | "/trained_adapters": NetworkFileSystem.persisted("adapters") 51 | } 52 | ) 53 | class Model: 54 | def __enter__(self): 55 | import socket 56 | import subprocess 57 | import time 58 | 59 | from lorax import AsyncClient 60 | 61 | self.launcher = subprocess.Popen( 62 | ["lorax-launcher"] + LAUNCH_FLAGS 63 | ) 64 
| self.client = AsyncClient("http://127.0.0.1:8000", timeout=60) 65 | self.tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True) 66 | 67 | 68 | def webserver_ready(): 69 | try: 70 | socket.create_connection(("127.0.0.1", 8000), timeout=1).close() 71 | return True 72 | except (socket.timeout, ConnectionRefusedError): 73 | 74 | retcode = self.launcher.poll() 75 | if retcode is not None: 76 | raise RuntimeError( 77 | f"launcher exited unexpectedly with code {retcode}" 78 | ) 79 | return False 80 | 81 | while not webserver_ready(): 82 | time.sleep(1.0) 83 | 84 | print("Webserver ready!") 85 | 86 | def __exit__(self, _exc_type, _exc_value, _traceback): 87 | self.launcher.terminate() 88 | 89 | 90 | @asgi_app() 91 | def fastapi_app(self): 92 | app = FastAPI() 93 | 94 | 95 | @app.post("/generate") 96 | async def generate(request: Request): 97 | request = await request.json() 98 | chat = request["chat"] 99 | parameters = request["parameters"] 100 | 101 | prompt = self.tokenizer.apply_chat_template(chat, tokenize=False, add_generation_prompt=True) 102 | 103 | payload = { 104 | "inputs": prompt, 105 | "parameters": parameters 106 | } 107 | url = "http://127.0.0.1:8000/generate" 108 | headers = {"Content-Type": "application/json"} 109 | response = requests.post(url, headers=headers, json=payload) 110 | 111 | return Response(content=json.dumps(response.json(), ensure_ascii=False).encode("utf-8")) 112 | 113 | @app.post("/generate_stream") 114 | async def generate_stream(request: Request): 115 | request = await request.json() 116 | chat = request["chat"] 117 | parameters = request["parameters"] 118 | 119 | prompt = self.tokenizer.apply_chat_template(chat, tokenize=False, add_generation_prompt=True) 120 | 121 | print("prompt", prompt) 122 | 123 | payload = { 124 | "inputs": prompt, 125 | "parameters": parameters 126 | } 127 | url = "http://127.0.0.1:8000/generate_stream" 128 | headers = {"Content-Type": "application/json"} 129 | response = requests.post(url, 
headers=headers, json=payload, stream=True) 130 | 131 | print("res", response) 132 | 133 | def stream(): 134 | for line in response: 135 | print(line) 136 | yield line 137 | 138 | 139 | return StreamingResponse(stream(), media_type="text/event-stream") 140 | 141 | return app 142 | 143 | -------------------------------------------------------------------------------- /logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/redotvideo/haven/a12a84de83d1763f9b543d027ae4d49b59defe0a/logo.png -------------------------------------------------------------------------------- /main/.env.example: -------------------------------------------------------------------------------- 1 | # Since the ".env" file is gitignored, you can use the ".env.example" file to 2 | # build a new ".env" file when you clone the repo. Keep this file up-to-date 3 | # when you add new variables to `.env`. 4 | 5 | # This file will be committed to version control, so make sure not to have any 6 | # secrets in it. If you are cloning this repo, create a copy of this file named 7 | # ".env" and populate it with your secrets. 8 | 9 | # When adding additional environment variables, the schema in "/src/env.mjs" 10 | # should be updated accordingly. 
11 | 12 | # Prisma 13 | # https://www.prisma.io/docs/reference/database-reference/connection-urls#env 14 | DATABASE_URL="file:./db.sqlite" 15 | 16 | # Next Auth 17 | # You can generate a new secret on the command line with: 18 | # openssl rand -base64 32 19 | # https://next-auth.js.org/configuration/options#secret 20 | # NEXTAUTH_SECRET="" 21 | NEXTAUTH_URL="http://localhost:3000" 22 | 23 | # Next Auth Discord Provider 24 | DISCORD_CLIENT_ID="" 25 | DISCORD_CLIENT_SECRET="" 26 | -------------------------------------------------------------------------------- /main/.eslintrc.cjs: -------------------------------------------------------------------------------- 1 | /** @type {import("eslint").Linter.Config} */ 2 | const config = { 3 | parser: "@typescript-eslint/parser", 4 | parserOptions: { 5 | project: true, 6 | }, 7 | plugins: ["@typescript-eslint"], 8 | extends: [ 9 | "next/core-web-vitals", 10 | "plugin:@typescript-eslint/recommended-type-checked", 11 | "plugin:@typescript-eslint/stylistic-type-checked", 12 | ], 13 | rules: { 14 | // These opinionated rules are enabled in stylistic-type-checked above. 15 | // Feel free to reconfigure them to your own preference. 
16 | "@typescript-eslint/array-type": "off", 17 | "@typescript-eslint/consistent-type-definitions": "off", 18 | 19 | "@typescript-eslint/consistent-type-imports": [ 20 | "warn", 21 | { 22 | prefer: "type-imports", 23 | fixStyle: "inline-type-imports", 24 | }, 25 | ], 26 | "@typescript-eslint/no-unused-vars": ["warn", {argsIgnorePattern: "^_"}], 27 | "@typescript-eslint/no-misused-promises": [ 28 | 2, 29 | { 30 | checksVoidReturn: {attributes: false}, 31 | }, 32 | ], 33 | 34 | // TODO: maybe we can turn these back on at some point 35 | "@typescript-eslint/no-unsafe-assignment": "off", 36 | "@typescript-eslint/no-explicit-any": "off", 37 | "@typescript-eslint/prefer-nullish-coalescing": "off", 38 | "@typescript-eslint/no-empty-function": "off", 39 | }, 40 | }; 41 | 42 | module.exports = config; 43 | -------------------------------------------------------------------------------- /main/components.json: -------------------------------------------------------------------------------- 1 | { 2 | "$schema": "https://ui.shadcn.com/schema.json", 3 | "style": "default", 4 | "rsc": true, 5 | "tsx": true, 6 | "tailwind": { 7 | "config": "tailwind.config.ts", 8 | "css": "src/styles/globals.css", 9 | "baseColor": "gray", 10 | "cssVariables": false 11 | }, 12 | "aliases": { 13 | "components": "src/components", 14 | "utils": "~/components/utils/utils" 15 | } 16 | } -------------------------------------------------------------------------------- /main/docker-compose.yml: -------------------------------------------------------------------------------- 1 | version: "3.3" 2 | services: 3 | mysql: 4 | image: mysql:8.1.0 5 | restart: always 6 | ports: 7 | - 3306:3306 8 | environment: 9 | MYSQL_ROOT_PASSWORD: example 10 | MYSQL_DATABASE: example 11 | MYSQL_USER: example 12 | MYSQL_PASSWORD: example 13 | -------------------------------------------------------------------------------- /main/next.config.mjs: -------------------------------------------------------------------------------- 1 | 
/** 2 | * Run `build` or `dev` with `SKIP_ENV_VALIDATION` to skip env validation. This is especially useful 3 | * for Docker builds. 4 | */ 5 | await import("./src/env.mjs"); 6 | 7 | /** @type {import("next").NextConfig} */ 8 | const config = { 9 | reactStrictMode: true, 10 | 11 | /** 12 | * If you are using `appDir` then you must comment the below `i18n` config out. 13 | * 14 | * @see https://github.com/vercel/next.js/issues/41980 15 | */ 16 | i18n: { 17 | locales: ["en"], 18 | defaultLocale: "en", 19 | }, 20 | 21 | // Indicate that these packages should not be bundled by webpack 22 | experimental: { 23 | serverComponentsExternalPackages: ["sharp", "onnxruntime-node"], 24 | }, 25 | }; 26 | 27 | export default config; 28 | -------------------------------------------------------------------------------- /main/package.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "t3", 3 | "version": "0.1.0", 4 | "private": true, 5 | "scripts": { 6 | "build": "npm install && next build", 7 | "db:push": "prisma db push", 8 | "db:studio": "prisma studio", 9 | "dev": "next dev", 10 | "postinstall": "prisma generate", 11 | "lint": "next lint", 12 | "start": "next start" 13 | }, 14 | "dependencies": { 15 | "@headlessui/react": "^1.7.17", 16 | "@headlessui/tailwindcss": "^0.2.0", 17 | "@heroicons/react": "^2.0.18", 18 | "@huggingface/hub": "^0.12.1", 19 | "@logtail/node": "^0.4.17", 20 | "@next-auth/prisma-adapter": "^1.0.7", 21 | "@prisma/client": "^5.1.1", 22 | "@radix-ui/react-hover-card": "^1.0.7", 23 | "@radix-ui/react-slider": "^1.1.2", 24 | "@radix-ui/react-switch": "^1.0.3", 25 | "@stripe/react-stripe-js": "^2.3.1", 26 | "@t3-oss/env-nextjs": "^0.7.0", 27 | "@tailwindcss/forms": "^0.5.6", 28 | "@tanstack/react-query": "^4.32.6", 29 | "@xenova/transformers": "^2.8.0", 30 | "class-variance-authority": "^0.7.0", 31 | "clsx": "^2.0.0", 32 | "lucide-react": "^0.290.0", 33 | "next": "^14.0.1", 34 | "next-auth": "^4.23.1", 35 | 
"nodemailer": "^6.9.7", 36 | "posthog-node": "^3.1.3", 37 | "react": "18.2.0", 38 | "react-dom": "18.2.0", 39 | "resend": "^2.0.0", 40 | "stripe": "^14.3.0", 41 | "superjson": "^1.13.1", 42 | "tailwind-merge": "^2.0.0", 43 | "tailwindcss-animate": "^1.0.7", 44 | "uuid": "^8.0.0", 45 | "yup": "^1.3.2" 46 | }, 47 | "devDependencies": { 48 | "@types/eslint": "^8.44.2", 49 | "@types/node": "^18.16.0", 50 | "@types/react": "^18.2.20", 51 | "@types/react-dom": "^18.2.7", 52 | "@types/uuid": "^9.0.6", 53 | "@typescript-eslint/eslint-plugin": "^6.3.0", 54 | "@typescript-eslint/parser": "^6.3.0", 55 | "autoprefixer": "^10.4.14", 56 | "eslint": "^8.47.0", 57 | "eslint-config-next": "^13.5.4", 58 | "postcss": "^8.4.27", 59 | "prettier": "^3.0.0", 60 | "prettier-plugin-tailwindcss": "^0.5.1", 61 | "prisma": "^5.1.1", 62 | "tailwindcss": "^3.3.5", 63 | "typescript": "^5.1.6" 64 | }, 65 | "ct3aMetadata": { 66 | "initVersion": "7.22.0" 67 | }, 68 | "packageManager": "npm@9.6.7", 69 | "engines": { 70 | "node": "18.17.1" 71 | } 72 | } -------------------------------------------------------------------------------- /main/postcss.config.cjs: -------------------------------------------------------------------------------- 1 | const config = { 2 | plugins: { 3 | tailwindcss: {}, 4 | autoprefixer: {}, 5 | }, 6 | }; 7 | 8 | module.exports = config; 9 | -------------------------------------------------------------------------------- /main/prisma/schema.prisma: -------------------------------------------------------------------------------- 1 | // This is your Prisma schema file, 2 | // learn more about it in the docs: https://pris.ly/d/prisma-schema 3 | 4 | generator client { 5 | provider = "prisma-client-js" 6 | } 7 | 8 | datasource db { 9 | provider = "mysql" 10 | // NOTE: When using mysql or sqlserver, uncomment the @db.Text annotations in model Account below 11 | // Further reading: 12 | // https://next-auth.js.org/adapters/prisma#create-the-prisma-schema 13 | // 
https://www.prisma.io/docs/reference/api-reference/prisma-schema-reference#string 14 | url = env("DATABASE_URL") 15 | relationMode = "prisma" 16 | } 17 | 18 | model Post { 19 | id Int @id @default(autoincrement()) 20 | name String 21 | createdAt DateTime @default(now()) 22 | updatedAt DateTime @updatedAt 23 | 24 | createdBy User @relation(fields: [createdById], references: [id]) 25 | createdById String 26 | 27 | @@index([name]) 28 | @@index([createdById]) 29 | } 30 | 31 | // Necessary for Next auth 32 | model Account { 33 | id String @id @default(cuid()) 34 | userId String 35 | type String 36 | provider String 37 | providerAccountId String 38 | refresh_token String? @db.Text 39 | access_token String? @db.Text 40 | expires_at Int? 41 | token_type String? 42 | scope String? 43 | id_token String? @db.Text 44 | session_state String? 45 | user User @relation(fields: [userId], references: [id]) 46 | 47 | createdAt DateTime @default(now()) 48 | updatedAt DateTime @default(now()) @updatedAt 49 | 50 | @@unique([provider, providerAccountId]) 51 | @@index([userId]) 52 | } 53 | 54 | model Session { 55 | id String @id @default(cuid()) 56 | sessionToken String @unique 57 | userId String 58 | expires DateTime 59 | user User @relation(fields: [userId], references: [id]) 60 | 61 | createdAt DateTime @default(now()) 62 | updatedAt DateTime @default(now()) @updatedAt 63 | 64 | @@index([userId]) 65 | } 66 | 67 | model User { 68 | id String @id @default(cuid()) 69 | name String? 70 | email String? @unique 71 | emailVerified DateTime? 72 | image String? 73 | 74 | centsBalance Float @default(500) 75 | 76 | stripeCustomerId String? 77 | stripePaymentMethodId String? 78 | 79 | apiKey String? 80 | 81 | hfToken String? 
82 | 83 | createdAt DateTime @default(now()) 84 | updatedAt DateTime @default(now()) @updatedAt 85 | 86 | accounts Account[] 87 | sessions Session[] 88 | posts Post[] 89 | Model Model[] 90 | Transaction Transaction[] 91 | ChatRequest ChatRequest[] 92 | Dataset Dataset[] 93 | } 94 | 95 | model Model { 96 | id String @id @default(cuid()) 97 | userId String 98 | name String 99 | costInCents Int 100 | 101 | state String @default("training") // "training" | "finished" | "error" 102 | 103 | datasetPath String? // TODO: Delete 104 | learningRate String @default("Medium") // "Low" | "Medium" | "High" 105 | epochs Int 106 | baseModel String 107 | 108 | wandbUrl String? 109 | 110 | createdAt DateTime @default(now()) 111 | updatedAt DateTime @default(now()) @updatedAt 112 | 113 | datasetId String? // TODO: Make required 114 | dataset Dataset? @relation(fields: [datasetId], references: [id]) // TODO: make required 115 | 116 | user User @relation(fields: [userId], references: [id]) 117 | 118 | @@index([name]) 119 | @@index([userId]) 120 | @@index([datasetId]) 121 | } 122 | 123 | model Dataset { 124 | id String @id @default(cuid()) 125 | userId String 126 | name String 127 | fileName String 128 | 129 | rows Int 130 | createdAt DateTime @default(now()) 131 | updatedAt DateTime @default(now()) @updatedAt 132 | 133 | user User @relation(fields: [userId], references: [id]) 134 | Model Model[] 135 | 136 | @@index([name]) 137 | @@index([userId]) 138 | } 139 | 140 | model VerificationToken { 141 | identifier String 142 | token String @unique 143 | expires DateTime 144 | 145 | @@unique([identifier, token]) 146 | } 147 | 148 | model Transaction { 149 | id String @id @default(cuid()) 150 | userId String 151 | amount Float 152 | createdAt DateTime @default(now()) 153 | 154 | reason String 155 | 156 | user User @relation(fields: [userId], references: [id]) 157 | 158 | @@index([userId]) 159 | } 160 | 161 | model ChatRequest { 162 | id String @id @default(cuid()) 163 | 164 | userId String 165 
| modelId String 166 | 167 | createdAt DateTime @default(now()) 168 | 169 | user User @relation(fields: [userId], references: [id]) 170 | 171 | @@index([userId]) 172 | @@index([createdAt]) 173 | } 174 | -------------------------------------------------------------------------------- /main/public/favicon.ico: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/redotvideo/haven/a12a84de83d1763f9b543d027ae4d49b59defe0a/main/public/favicon.ico -------------------------------------------------------------------------------- /main/public/huggingface.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/redotvideo/haven/a12a84de83d1763f9b543d027ae4d49b59defe0a/main/public/huggingface.png -------------------------------------------------------------------------------- /main/public/logo-simple-dark.svg: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | -------------------------------------------------------------------------------- /main/public/wandblogo.svg: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | -------------------------------------------------------------------------------- /main/src/app/api/auth/[...nextauth]/route.ts: -------------------------------------------------------------------------------- 1 | import NextAuth from "next-auth"; 2 | 3 | import {authOptions} from "~/server/auth"; 4 | 5 | const handler: any = NextAuth(authOptions); 6 | export {handler as GET, handler as POST}; 7 | -------------------------------------------------------------------------------- /main/src/app/api/hooks/update/route.ts: -------------------------------------------------------------------------------- 1 | import * as y from "yup"; 2 | import {addWandBUrl, getJobCostInCents, getModelFromId, 
updateState} from "~/server/database/model"; 3 | import {getUserFromId, increaseBalance} from "~/server/database/user"; 4 | import {logger} from "~/server/utils/observability/logtail"; 5 | import {EventName, sendEvent} from "~/server/utils/observability/posthog"; 6 | 7 | const WEBHOOK_SECRET = process.env.WEBHOOK_SECRET || ""; 8 | 9 | const bodySchema = y.object({ 10 | id: y.string().required(), 11 | key: y.string().oneOf(["wandb", "status"]).required(), 12 | value: y.string().required(), 13 | }); 14 | 15 | export async function POST(request: Request) { 16 | if (request.headers.get("x-secret") !== WEBHOOK_SECRET) { 17 | return new Response("Unauthorized", { 18 | status: 401, 19 | }); 20 | } 21 | 22 | const body = await request.json(); 23 | const validatedBody = await bodySchema.validate(body); 24 | 25 | const model = await getModelFromId(validatedBody.id); 26 | if (model == null) { 27 | throw new Error("Could not find model"); 28 | } 29 | 30 | if (validatedBody.key === "wandb" && validatedBody.id !== "") { 31 | sendEvent("system", EventName.FINE_TUNE_WANDB_ADDED, {wandbUrl: validatedBody.value, modelId: validatedBody.id}); 32 | await addWandBUrl(validatedBody.id, validatedBody.value); 33 | } 34 | 35 | if (validatedBody.key === "status") { 36 | const status = await y.string().oneOf(["finished", "error"]).required().validate(validatedBody.value); 37 | 38 | // Refund 39 | if (status === "error") { 40 | const jobPriceInCents = await getJobCostInCents(validatedBody.id); 41 | const user = await getUserFromId(model.userId); 42 | 43 | if (!user) { 44 | logger.error("Could not find user", {id: validatedBody.id, userId: model.userId}); 45 | throw new Error("Could not find user"); 46 | } 47 | 48 | if (jobPriceInCents == null) { 49 | logger.error("Could not find job price", {id: validatedBody.id, userId: model.userId}); 50 | throw new Error("Could not find job price"); 51 | } 52 | 53 | logger.error("Training job failed, refunding cost", {id: validatedBody.id, jobPriceInCents, 
userId: user.id}); 54 | await increaseBalance(model.userId, jobPriceInCents, `Refund for failed model: ${validatedBody.id}`).catch( 55 | (e) => { 56 | logger.error("Could not refund", {id: validatedBody.id, jobPriceInCents, userId: user.id, error: e}); 57 | }, 58 | ); 59 | } else { 60 | sendEvent(model.userId, status === "finished" ? EventName.FINE_TUNE_FINISHED : EventName.FINE_TUNE_FAILED); 61 | } 62 | 63 | await updateState(validatedBody.id, status); 64 | } 65 | 66 | return new Response("OK", { 67 | status: 200, 68 | }); 69 | } 70 | -------------------------------------------------------------------------------- /main/src/app/api/inference/route.ts: -------------------------------------------------------------------------------- 1 | import * as y from "yup"; 2 | import {defaultModelLoopup, modelsToFinetune} from "~/constants/models"; 3 | import {inferenceEndpoints} from "~/constants/modal"; 4 | import {createChatRequest, getNumberOfChatRequestsInLast24Hours} from "~/server/database/chat-request"; 5 | import {getModelFromId} from "~/server/database/model"; 6 | import {logger} from "~/server/utils/observability/logtail"; 7 | import {checkSessionAction} from "~/server/utils/session"; 8 | import {bodySchema} from "./schema"; 9 | import {EventName, sendEvent} from "~/server/utils/observability/posthog"; 10 | 11 | async function getResponse(host: string, modelId: string | undefined, validatedBody: y.InferType) { 12 | const adapter_id = modelId ? 
`/trained_adapters/${modelId}` : undefined; 13 | 14 | const body = JSON.stringify({ 15 | parameters: { 16 | temperature: validatedBody.parameters.temperature, 17 | top_p: validatedBody.parameters.topP, 18 | do_sample: validatedBody.parameters.doSample, 19 | max_new_tokens: validatedBody.parameters.maxTokens, 20 | repetition_penalty: validatedBody.parameters.repetitionPenalty, 21 | adapter_id, 22 | }, 23 | chat: validatedBody.history, 24 | auth_token: process.env.MODAL_AUTH_TOKEN, 25 | }); 26 | 27 | console.log("sending request", body, host); 28 | 29 | return fetch(host, { 30 | method: "POST", 31 | headers: { 32 | "Content-Type": "application/json", 33 | }, 34 | body, 35 | }); 36 | } 37 | 38 | async function retryInference(makeCall: () => Promise, logInfo: object) { 39 | for (let i = 0; i < 3; i++) { 40 | try { 41 | return makeCall(); 42 | } catch (e) { 43 | logger.error("Inference error, try no. " + i, {...logInfo, error: e}); 44 | } 45 | } 46 | 47 | logger.error("Inference failed after 3 attempts.", logInfo); 48 | throw new Error("Inference failed"); 49 | } 50 | 51 | export async function POST(request: Request) { 52 | const body = await request.json(); 53 | const validatedBody = await bodySchema.validate(body); 54 | 55 | const [session, model] = await Promise.all([checkSessionAction(), getModelFromId(validatedBody.modelId)]); 56 | 57 | sendEvent(session.user.id, EventName.INFERENCE_REQUEST, { 58 | modelId: validatedBody.modelId, 59 | }); 60 | 61 | logger.info("Inference request", { 62 | modelId: validatedBody.modelId, 63 | userId: session.user.id, 64 | body: JSON.stringify(body), 65 | }); 66 | 67 | // Check if user has reached their daily limit 68 | const last24h = await getNumberOfChatRequestsInLast24Hours(session.user.id); 69 | if (last24h > 100) { 70 | logger.info("Daily inference rate limit reached", {modelId: validatedBody.modelId}); 71 | return new Response("Daily limit reached for free tier", { 72 | status: 402, 73 | }); 74 | } 75 | 76 | // Both of these 
are available to all users 77 | const isDefaultModel = Object.keys(defaultModelLoopup).includes(validatedBody.modelId); 78 | 79 | // Adapter not found and not internal model 80 | if (!isDefaultModel && !model) { 81 | logger.error("Model not found", {modelId: validatedBody.modelId, userId: session.user.id}); 82 | throw new Error("Model not found"); 83 | } 84 | 85 | const baseModel = model?.baseModel || defaultModelLoopup[validatedBody.modelId as keyof typeof defaultModelLoopup]; 86 | const baseModelValidated = y.string().oneOf(modelsToFinetune).required().validateSync(baseModel); 87 | const host = inferenceEndpoints[baseModelValidated]; 88 | 89 | return retryInference( 90 | async () => { 91 | const response = await getResponse(host, model?.id, validatedBody); 92 | if (response.ok) { 93 | await createChatRequest(session.user.id, validatedBody.modelId); 94 | } 95 | 96 | return new Response(response.body); 97 | }, 98 | {modelId: validatedBody.modelId}, 99 | ); 100 | } 101 | -------------------------------------------------------------------------------- /main/src/app/api/inference/schema.ts: -------------------------------------------------------------------------------- 1 | import * as y from "yup"; 2 | 3 | export const bodySchema = y.object({ 4 | modelId: y.string().required(), 5 | history: y 6 | .array( 7 | y.object({ 8 | role: y.string().oneOf(["user", "system", "assistant"]).required(), 9 | content: y.string().required(), 10 | }), 11 | ) 12 | .required(), 13 | parameters: y.object({ 14 | temperature: y.number().min(0).max(1.5).required(), 15 | topP: y.number().min(0).max(1).required(), 16 | maxTokens: y.number().min(10).max(2048).required(), 17 | repetitionPenalty: y.number().min(1).max(1.8).required(), 18 | doSample: y.boolean().required(), 19 | }), 20 | }); 21 | -------------------------------------------------------------------------------- /main/src/app/api/search/route.ts: -------------------------------------------------------------------------------- 1 | 
import * as y from "yup"; 2 | import {getDatasetsByNameForUser} from "~/server/database/dataset"; 3 | import {checkSessionAction} from "~/server/utils/session"; 4 | 5 | import type {Dataset} from "@prisma/client"; 6 | 7 | const bodySchema = y 8 | .object({ 9 | query: y.string(), 10 | }) 11 | .required(); 12 | 13 | function filterDatasetProperties(datasets: Dataset[]) { 14 | return datasets.map((dataset) => ({ 15 | id: dataset.id, 16 | name: dataset.name, 17 | rows: dataset.rows, 18 | })); 19 | } 20 | 21 | export type SearchResponse = ReturnType; 22 | 23 | export async function POST(request: Request) { 24 | const session = await checkSessionAction(); 25 | 26 | const body = await request.json(); 27 | const validatedBody = await bodySchema.validate(body); 28 | 29 | const results = await getDatasetsByNameForUser(session.user.id, validatedBody.query || ""); 30 | 31 | return new Response(JSON.stringify(filterDatasetProperties(results)), { 32 | status: 200, 33 | headers: { 34 | "Content-Type": "application/json", 35 | }, 36 | }); 37 | } 38 | -------------------------------------------------------------------------------- /main/src/app/api/trpc/[trpc]/route_inactive.ts: -------------------------------------------------------------------------------- 1 | // TODO 2 | /* 3 | import {createNextApiHandler} from "@trpc/server/adapters/next"; 4 | import {NextApiRequest} from "next"; 5 | 6 | import {env} from "~/env.mjs"; 7 | import {appRouter} from "~/server/api/root"; 8 | import {createTRPCContext} from "~/server/api/trpc"; 9 | 10 | // export API handler 11 | const handler = createNextApiHandler({ 12 | router: appRouter, 13 | createContext: createTRPCContext, 14 | onError: 15 | env.NODE_ENV === "development" 16 | ? ({path, error}) => { 17 | console.error(`❌ tRPC failed on ${path ?? 
""}: ${error.message}`); 18 | } 19 | : undefined, 20 | }) as any as NextApiRequest; 21 | */ 22 | -------------------------------------------------------------------------------- /main/src/app/billing/actions.tsx: -------------------------------------------------------------------------------- 1 | "use server"; 2 | 3 | import * as y from "yup"; 4 | 5 | import {revalidatePath} from "next/cache"; 6 | import { 7 | createPaymentIntent, 8 | createSetupIntent, 9 | createStripeCustomerAndAttachPaymentMethod, 10 | getStripeCreditCardInfo, 11 | validateSetupIntent, 12 | } from "~/server/controller/stripe"; 13 | import {addStripeCustomerIdAndPaymentMethodToUser, increaseBalance, updateName} from "~/server/database/user"; 14 | import {EventName, sendEvent} from "~/server/utils/observability/posthog"; 15 | import {checkSessionAction} from "~/server/utils/session"; 16 | import {logger} from "~/server/utils/observability/logtail"; 17 | 18 | export async function revalidate() { 19 | return Promise.resolve().then(() => revalidatePath("/billing")); 20 | } 21 | 22 | export async function updateAccountName(name: string) { 23 | y.string().required().validateSync(name); 24 | 25 | const session = await checkSessionAction(); 26 | await updateName(session.user.id, name).catch((e) => { 27 | logger.error("Error updating name", {error: e}); 28 | throw new Error("Internal server error"); 29 | }); 30 | } 31 | 32 | export async function getStripePublishableKey() { 33 | if (process.env.STRIPE_PUBLISHABLE_KEY === undefined) { 34 | logger.error("Stripe publishable key is not set."); 35 | throw new Error("Stripe publishable key is not set."); 36 | } 37 | return Promise.resolve(process.env.STRIPE_PUBLISHABLE_KEY); 38 | } 39 | 40 | export async function getCreditCardInformation() { 41 | const session = await checkSessionAction(); 42 | 43 | if (session.user.stripeCustomerId) { 44 | return getStripeCreditCardInfo(session.user.stripeCustomerId).catch((e) => { 45 | logger.error("Error getting stripe 
credit card info", {error: e}); 46 | throw new Error("Internal server error"); 47 | }); 48 | } 49 | } 50 | // Create a Stripe setup intent for the given payment method id and return its client secret. 51 | export async function createSetupIntentAction(stripeCreditCardId: string) { 52 | y.string().required().validateSync(stripeCreditCardId); await checkSessionAction(); // "use server" exports are callable as endpoints, so require a signed-in user like the other actions 53 | const intent = await createSetupIntent(stripeCreditCardId); 54 | return intent.client_secret; 55 | } 56 | 57 | export async function finalizeCreditCard(stripeCreditCardId: string, stripeSetupIntentClientSecret: string) { 58 | y.string().required().validateSync(stripeCreditCardId); y.string().required().validateSync(stripeSetupIntentClientSecret); 59 | const session = await checkSessionAction(); 60 | if (!session.user.email) { 61 | sendEvent(session.user.id, EventName.EMERGENCY, { 62 | message: `User email is not set. User id: ${session.user.id}`, 63 | }); 64 | logger.error("User email is not set."); 65 | throw new Error("User email is not set."); 66 | } 67 | 68 | // Validating setup intent 69 | const setupIntentValid = await validateSetupIntent(stripeSetupIntentClientSecret); 70 | if (!setupIntentValid) { 71 | logger.error("Setup intent is invalid", { 72 | stripeSetupIntentClientSecret, 73 | stripeCreditCardId, 74 | }); 75 | sendEvent(session.user.id, EventName.EMERGENCY, { 76 | message: `Setup intent is invalid. User id: ${session.user.id}, stripeSetupIntentClientSecret: ${stripeSetupIntentClientSecret}, stripeCreditCardId: ${stripeCreditCardId}`, 77 | }); 78 | throw new Error("Error validating setup intend."); 79 | } 80 | 81 | // Creating a stripe customer and attaching the payment method 82 | const stripeCustomer = await createStripeCustomerAndAttachPaymentMethod( 83 | session.user.name ?? "", 84 | session.user.email, 85 | stripeCreditCardId, 86 | session.user.stripeCustomerId ?? undefined, 87 | ).catch((e) => { 88 | logger.error("Error creating stripe customer", { 89 | stripeSetupIntentClientSecret, 90 | stripeCreditCardId, 91 | error: e, 92 | }); 93 | sendEvent(session.user.id, EventName.EMERGENCY, { 94 | message: `Error creating stripe customer.
User id: ${session.user.id}, stripeSetupIntentClientSecret: ${stripeSetupIntentClientSecret}, stripeCreditCardId: ${stripeCreditCardId}`, 95 | error: (e as Error).message, 96 | }); 97 | throw new Error("Error creating stripe customer."); 98 | }); 99 | 100 | // Attaching stripe customer id to the database user 101 | await addStripeCustomerIdAndPaymentMethodToUser(session.user.id, stripeCustomer.id, stripeCreditCardId).catch((e) => { 102 | logger.error("Error adding stripe customer id to user", { 103 | stripeSetupIntentClientSecret, 104 | stripeCreditCardId, 105 | error: e, 106 | }); 107 | sendEvent(session.user.id, EventName.EMERGENCY, { 108 | message: `Error adding stripe customer id to user. User id: ${session.user.id}, stripeSetupIntentClientSecret: ${stripeSetupIntentClientSecret}, stripeCreditCardId: ${stripeCreditCardId}`, 109 | error: (e as Error).message, 110 | }); 111 | throw new Error("Internal server error."); 112 | }); 113 | 114 | sendEvent(session.user.id, EventName.CREDIT_CARD_ADDED); 115 | } 116 | // Charge the user's saved card for a whole-dollar top-up (at most $50) and credit the amount to their balance. 117 | export async function addBalanceAction(amountInDollars: number) { 118 | const session = await checkSessionAction(); 119 | if (!session.user.stripeCustomerId || !session.user.stripePaymentMethodId) { 120 | logger.error("User has no stripe customer id"); 121 | throw new Error("Internal error."); 122 | } 123 | y.number().required().integer().min(1).validateSync(amountInDollars); // reject NaN, fractional, zero and negative amounts before charging the card 124 | if (amountInDollars > 50) { 125 | logger.error("User is trying to add more than 50 dollars", {amountInDollars}); 126 | throw new Error("Can't add more than $50 at once."); 127 | } 128 | 129 | if (session.user.centsBalance > 10000) { 130 | logger.error("User already has more than 100 dollars"); 131 | throw new Error("You can't add more funds when your balance currently exceeds $100."); 132 | } 133 | 134 | const amountInCents = amountInDollars * 100; 135 | 136 | await createPaymentIntent( 137 | session.user.stripeCustomerId, 138 | session.user.stripePaymentMethodId, 139 | amountInCents, 140 | "Haven Account Top Up", 141 | ).catch((e) => { 142 |
logger.error("Error creating off session payment intent", {error: e}); 143 | sendEvent(session.user.id, EventName.EMERGENCY, {message: "User couldn't add balance", amount: amountInCents}); 144 | throw new Error("Error charging credit card. If this persists, send us an email at hello@haven.run."); 145 | }); 146 | 147 | await increaseBalance(session.user.id, amountInCents).catch((e) => { 148 | logger.error("Error increasing balance", {error: e}); 149 | sendEvent(session.user.id, EventName.EMERGENCY, { 150 | message: "User charged but balance not increased.", 151 | amount: amountInCents, 152 | }); 153 | throw new Error("Internal server error."); 154 | }); 155 | 156 | sendEvent(session.user.id, EventName.MONEY_ADDED, {amount: amountInCents}); 157 | } 158 | -------------------------------------------------------------------------------- /main/src/app/billing/add-balance.tsx: -------------------------------------------------------------------------------- 1 | import {Dialog} from "@headlessui/react"; 2 | 3 | import Modal from "../../components/modal"; 4 | import {XMarkIcon} from "@heroicons/react/24/outline"; 5 | import TextField from "../../components/form/textfield"; 6 | import Button from "../../components/form/button"; 7 | import {useState} from "react"; 8 | import {addBalanceAction} from "./actions"; 9 | 10 | interface Props { 11 | open: boolean; 12 | setOpen: (open: boolean) => void; 13 | } 14 | 15 | export default function AddBalance({open, setOpen}: Props) { 16 | const [value, setValue] = useState("$ 5"); 17 | const [loading, setLoading] = useState(false); 18 | const [error, setError] = useState(""); 19 | 20 | /** 21 | * Make sure only only integers are allowed and they are prefixed with a dollar sign 22 | */ 23 | function onChange(str: string) { 24 | const re = /^\$ [0-9\b]*$/; 25 | if (re.test(str)) { 26 | setValue(str); 27 | } 28 | } 29 | 30 | async function addBalance() { 31 | const amountInDollars = parseInt(value.replace("$ ", "")); 32 | 33 | 
setLoading(true); 34 | try { 35 | await addBalanceAction(amountInDollars); 36 | setOpen(false); 37 | } catch (e) { 38 | setError((e as Error).message); 39 | } 40 | setLoading(false); 41 | } 42 | 43 | return ( 44 | 45 |
46 |
47 | 48 | Add balance 49 | 50 | 53 |
54 |
55 | Add balance to your account. The selected amount will be charged to your credit card. 56 |
57 | 58 |
{error}
59 | 62 |
63 |
64 | ); 65 | } 66 | -------------------------------------------------------------------------------- /main/src/app/billing/add-credit-card.tsx: -------------------------------------------------------------------------------- 1 | import {useState} from "react"; 2 | import {Dialog} from "@headlessui/react"; 3 | import Button from "../../components/form/button"; 4 | import {CardElement, useElements, useStripe} from "@stripe/react-stripe-js"; 5 | import {XMarkIcon} from "@heroicons/react/20/solid"; 6 | import Modal from "../../components/modal"; 7 | import {createSetupIntentAction, finalizeCreditCard, updateAccountName} from "./actions"; 8 | import TextField from "~/components/form/textfield"; 9 | import Label from "~/components/form/label"; 10 | 11 | interface Props { 12 | open: boolean; 13 | setOpen: (open: boolean) => void; 14 | } 15 | 16 | export default function AddCreditCard({open, setOpen}: Props) { 17 | const [loading, setLoading] = useState(false); 18 | const [error, setError] = useState(""); 19 | 20 | const [name, setName] = useState(""); 21 | 22 | const elements = useElements()!; 23 | const stripe = useStripe()!; 24 | 25 | async function addCreditCard() { 26 | setLoading(true); 27 | 28 | const cardElement = elements.getElement(CardElement); 29 | 30 | await updateAccountName(name); 31 | 32 | const {error, paymentMethod} = await stripe.createPaymentMethod({ 33 | type: "card", 34 | card: cardElement!, 35 | }); 36 | 37 | if (error) { 38 | setError("Failed to add credit card. Please try a different card or try again later."); 39 | setLoading(false); 40 | return; 41 | } 42 | 43 | const clientSecret = await createSetupIntentAction(paymentMethod.id); 44 | if (!clientSecret) { 45 | setError("Failed to add credit card. 
Please try a different card or try again later."); 46 | setLoading(false); 47 | return; 48 | } 49 | 50 | const confirmationResult = await stripe.confirmCardSetup(clientSecret); 51 | if (confirmationResult.error) { 52 | setError("Failed to add credit card. Please try a different card or try again later."); 53 | setLoading(false); 54 | return; 55 | } 56 | 57 | await finalizeCreditCard(paymentMethod.id, clientSecret); 58 | 59 | setError(""); 60 | setLoading(false); 61 | setOpen(false); 62 | } 63 | 64 | return ( 65 | 66 |
67 |
68 | 69 | Add a credit card 70 | 71 | 74 |
75 |
76 | {"We won't charge your card until you add balance to your account manually."} 77 |
78 |
79 | 80 |
81 |
82 | 83 |
84 | 85 |
86 |
87 |
{error}
88 |
89 |
90 | 93 |
94 |
95 | Powered by{" "} 96 | 97 | Stripe 98 | 99 |
100 |
101 | ); 102 | } 103 | -------------------------------------------------------------------------------- /main/src/app/billing/billing.tsx: -------------------------------------------------------------------------------- 1 | "use client"; 2 | import Padding from "../../components/padding"; 3 | import PageHeading from "../../components/page-heading"; 4 | import Sidebar from "../../components/sidebar"; 5 | import Warning from "../../components/warning"; 6 | import {useState} from "react"; 7 | import Button from "../../components/form/button"; 8 | import AddCreditCard from "./add-credit-card"; 9 | import AddBalance from "./add-balance"; 10 | import {Elements} from "@stripe/react-stripe-js"; 11 | import {revalidate} from "./actions"; 12 | import {loadStripe} from "@stripe/stripe-js"; 13 | 14 | interface Props { 15 | balance: string; 16 | creditCardInformation?: { 17 | last4: string; 18 | expiry: string; 19 | }; 20 | stripePublishableKey: string; 21 | } 22 | 23 | export default function Billing({balance, creditCardInformation, stripePublishableKey}: Props) { 24 | const [openCreditCardModal, setOpenCreditCardModal] = useState(false); 25 | const [openAddBalanceModal, setOpenAddBalanceModal] = useState(false); 26 | 27 | // Make sure we refresh the page when the modals close 28 | function refreshWrapper(setOpenCloseFunction: (state: boolean) => void) { 29 | return (state: boolean) => { 30 | setOpenCloseFunction(state); 31 | void revalidate(); 32 | }; 33 | } 34 | 35 | return ( 36 | <> 37 | 38 | 39 | 40 | 41 | 42 | 43 | Billing 44 | 45 |
46 | {!creditCardInformation && ( 47 | 48 | )} 49 | 50 | 51 |
52 |
53 |
Balance
54 |
{balance}
55 | 66 |
67 | {creditCardInformation && ( 68 |
69 |
70 |
71 |
72 |
**** {creditCardInformation.last4}
73 |
Expires {creditCardInformation.expiry}
74 |
75 |
76 |
77 |
Using
78 | 79 | {/* 80 | //TODO: Enable delete button 81 | 84 | */} 85 |
86 |
87 |
88 | )} 89 |
90 |
91 | 92 | 93 | ); 94 | } 95 | -------------------------------------------------------------------------------- /main/src/app/billing/page.tsx: -------------------------------------------------------------------------------- 1 | import {checkSession} from "~/server/utils/session"; 2 | import Billing from "./billing"; 3 | import {getCreditCardInformation, getStripePublishableKey} from "./actions"; 4 | 5 | export default async function Page() { 6 | const [session, creditCardInformation, stripePublishableKey] = [ 7 | await checkSession(), 8 | await getCreditCardInformation(), 9 | await getStripePublishableKey(), 10 | ]; 11 | const balance = "$" + (session.user.centsBalance / 100).toFixed(2); 12 | 13 | return ( 14 | <> 15 | 20 | 21 | ); 22 | } 23 | -------------------------------------------------------------------------------- /main/src/app/chat/[slug]/chat.tsx: -------------------------------------------------------------------------------- 1 | "use client"; 2 | import {SendIcon} from "lucide-react"; 3 | import {useEffect, useRef, useState} from "react"; 4 | import Button from "~/components/form/button"; 5 | import Label from "~/components/form/label"; 6 | import Settings from "./settings"; 7 | import {AdjustmentsHorizontalIcon} from "@heroicons/react/24/outline"; 8 | import {parseStream} from "./process"; 9 | 10 | import type {ChatMessage} from "./process"; 11 | import type {ChatSettings} from "./settings"; 12 | import UserAvatar from "~/components/user-avatar"; 13 | import {CpuChipIcon} from "@heroicons/react/20/solid"; 14 | 15 | export default function ChatComponent({modelId, email}: {modelId: string; email?: string}) { 16 | const [loading, setLoading] = useState(false); 17 | 18 | const [systemPrompt, setSystemPrompt] = useState("You're a useful assistant."); 19 | const [message, setMessage] = useState(""); 20 | 21 | const lastMessageRef = useRef(null); 22 | const [chat, setChat] = useState([]); 23 | 24 | const [settingsOpen, setSettingsOpen] = useState(false); 25 
| const [chatSettings, setChatSettings] = useState({ 26 | temperature: 0.9, 27 | topP: 0.8, 28 | maxTokens: 256, 29 | repetitionPenalty: 1.1, 30 | doSample: true, 31 | }); 32 | 33 | function AssistantAvatar() { 34 | return ( 35 |
36 |
37 | 38 |
39 |
40 | ); 41 | } 42 | // Send the system prompt, chat history and new user message to /api/inference, then stream the reply into a new assistant message. 43 | const handleFormSubmit = async (event: React.FormEvent) => { 44 | event.preventDefault(); 45 | 46 | setLoading(true); 47 | setMessage(""); 48 | try { // the finally below re-enables the input on every exit path (402 early return, fetch or stream errors) 49 | const newChat: ChatMessage[] = [...chat, {role: "user", content: message}]; 50 | setChat(newChat); 51 | 52 | const res = await fetch("/api/inference", { 53 | method: "POST", 54 | headers: { 55 | "content-type": "application/json", 56 | }, 57 | body: JSON.stringify({ 58 | modelId, 59 | history: [{role: "system", content: systemPrompt}, ...newChat], 60 | parameters: chatSettings, 61 | }), 62 | }); 63 | 64 | if (res.status === 402) { 65 | alert(`You have reached the rate limit of 100 messages in 24h :(`); 66 | return; 67 | } 68 | 69 | // Res is a stream, read it line by line 70 | const reader = res.body!.getReader(); 71 | 72 | newChat.push({role: "assistant", content: ""}); 73 | await parseStream(reader, newChat, setChat); 74 | } finally { 75 | setLoading(false); 76 | } }; 77 | 78 | useEffect(() => { 79 | if (lastMessageRef.current) { 80 | lastMessageRef.current.scrollIntoView({behavior: "smooth"}); 81 | } 82 | }, [chat]); 83 | 84 | return ( 85 | <> 86 | 92 | 96 |
97 | 98 | setSystemPrompt(e.target.value)} 103 | placeholder="You're a useful assistant." 104 | /> 105 |
106 | {chat.map((message, i) => ( 107 |
112 | {message.role === "user" ? : } 113 |
{message.content}
114 |
115 | ))} 116 |
117 |
118 |
119 |
120 |
121 | setMessage(e.target.value)} 125 | className="w-full px-3 py-1.5 border rounded-md bg-gray-900 border-gray-700 disabled:opacity-50" 126 | placeholder="Type your message here..." 127 | disabled={loading} 128 | /> 129 | 133 |
134 |
135 |
136 |
137 | 138 | ); 139 | } 140 | -------------------------------------------------------------------------------- /main/src/app/chat/[slug]/page.tsx: -------------------------------------------------------------------------------- 1 | import {ChevronLeftIcon} from "@heroicons/react/20/solid"; 2 | import {redirect} from "next/navigation"; 3 | import Button from "~/components/form/button"; 4 | import Sidebar from "~/components/sidebar"; 5 | import {getModelFromId} from "~/server/database/model"; 6 | import {checkSession} from "~/server/utils/session"; 7 | import ChatComponent from "./chat"; 8 | import Link from "next/link"; 9 | import {defaultModelLoopup} from "~/constants/models"; 10 | 11 | export default async function Chat({params}: {params: {slug: string}}) { 12 | const {slug} = params; 13 | 14 | const [session, model] = await Promise.all([checkSession(), getModelFromId(slug)]); 15 | 16 | const isInternal = Object.keys(defaultModelLoopup).includes(slug); 17 | 18 | if (!isInternal && model?.userId !== session.user.id) { 19 | redirect("/models"); 20 | } 21 | 22 | return ( 23 | 24 | 25 | 29 | 30 | 31 |
32 | 33 | 34 | 35 | ); 36 | } 37 | -------------------------------------------------------------------------------- /main/src/app/chat/[slug]/process.tsx: -------------------------------------------------------------------------------- 1 | import * as y from "yup"; 2 | 3 | export type ChatMessage = { 4 | role: "user" | "assistant" | "system"; 5 | content: string; 6 | }; 7 | 8 | export async function processChunk(current: string, newChat: ChatMessage[], setChat: (chat: ChatMessage[]) => void) { 9 | // Split the stream into lines 10 | const lines = current.split("\n\n"); 11 | 12 | if (lines.length === 0) { 13 | return ""; 14 | } 15 | 16 | // Check if the last line is incomplete 17 | const lastLineIsIncomplete = await Promise.resolve(lines[lines.length - 1]!.slice(5)) 18 | .then((line) => { 19 | JSON.parse(line); 20 | return false; 21 | }) 22 | .catch(() => true); 23 | 24 | const linesToProcess = lastLineIsIncomplete ? lines.slice(0, -1) : lines; 25 | 26 | // JSON parse all but the last line 27 | for (const line of linesToProcess) { 28 | if (!line) { 29 | continue; 30 | } 31 | 32 | const schema = y.object({ 33 | // Our multi lora server provides this 34 | token: y 35 | .object({ 36 | text: y.string(), 37 | special: y.boolean(), 38 | }) 39 | .required(), 40 | // Fireworks AI provides this 41 | choices: y.array( 42 | y.object({ 43 | index: y.number(), 44 | delta: y.object({ 45 | content: y.string(), 46 | }), 47 | }), 48 | ), 49 | }); 50 | 51 | // Line starts with "data:", skip it 52 | const validated = await Promise.resolve() 53 | .then(() => schema.validate(JSON.parse(line.slice(5)))) 54 | .catch(() => undefined); 55 | 56 | // Multi lora 57 | if (validated && Object.keys(validated.token).length > 0) { 58 | // Ignore special tokens 59 | if (validated.token.special) { 60 | continue; 61 | } 62 | 63 | // Add the message to the chat 64 | const current = newChat[newChat.length - 1]!; 65 | current.content += validated.token.text; 66 | setChat([...newChat.slice(0, -1), 
current]); 67 | } 68 | 69 | // Fireworks AI 70 | if (validated?.choices && Object.keys(validated.choices).length > 0) { 71 | const current = newChat[newChat.length - 1]!; 72 | 73 | if (!validated.choices[0]?.delta.content) { 74 | continue; 75 | } 76 | 77 | current.content += validated.choices[0].delta.content; 78 | setChat([...newChat.slice(0, -1), current]); 79 | } 80 | } 81 | 82 | if (lastLineIsIncomplete) { 83 | // Carry over 84 | return lines[lines.length - 1]!; 85 | } 86 | 87 | return ""; 88 | } 89 | // Read the response stream to completion, passing decoded text through processChunk and carrying over any incomplete trailing event. 90 | export async function parseStream( 91 | reader: ReadableStreamDefaultReader, 92 | newChat: ChatMessage[], 93 | setChat: (chat: ChatMessage[]) => void, 94 | ) { 95 | let current = ""; const decoder = new TextDecoder("utf-8"); // one decoder for the whole stream keeps multi-byte characters split across chunks intact 96 | 97 | // Iterate over the stream 98 | while (true) { 99 | const {done, value} = await reader.read(); 100 | 101 | const decoded = decoder.decode(value, {stream: !done}); // stream mode holds back a partial trailing character; the final (done) call flushes it 102 | 103 | current += decoded; 104 | current = await processChunk(current, newChat, setChat); 105 | 106 | if (done) { 107 | break; 108 | } 109 | } 110 | } 111 | -------------------------------------------------------------------------------- /main/src/app/chat/[slug]/settings.tsx: -------------------------------------------------------------------------------- 1 | import Modal from "~/components/modal"; 2 | import {Slider} from "~/components/ui/slider"; 3 | import Label from "~/components/form/label"; 4 | import {useState} from "react"; 5 | import Button from "~/components/form/button"; 6 | import {Dialog} from "@headlessui/react"; 7 | import {XMarkIcon} from "@heroicons/react/20/solid"; 8 | import {Switch} from "~/components/ui/switch"; 9 | 10 | export interface ChatSettings { 11 | temperature: number; 12 | topP: number; 13 | maxTokens: number; 14 | repetitionPenalty: number; 15 | doSample: boolean; 16 | } 17 | 18 | function SettingsSlider({ 19 | label, 20 | value, 21 | setValue, 22 | min = 0.1, 23 | max = 1, 24 | step = 0.01, 25 | }: { 26 | label: string; 27 | value: number; 28 | setValue: (value: number)
=> void; 29 | min?: number; 30 | max?: number; 31 | step?: number; 32 | }) { 33 | return ( 34 |
35 |
36 | 37 |
{value}
38 |
39 | setValue(value[0]!)} 42 | defaultValue={[value]} 43 | value={[value]} 44 | min={min} 45 | max={max} 46 | step={step} 47 | /> 48 |
49 | ); 50 | } 51 | 52 | export default function Settings({ 53 | open, 54 | setOpen, 55 | chatSettings, 56 | setChatSettings, 57 | }: { 58 | open: boolean; 59 | setOpen: (open: boolean) => void; 60 | chatSettings: ChatSettings; 61 | setChatSettings: (settings: ChatSettings) => void; 62 | }) { 63 | // NOTE: we don't need the extra state here 64 | const [temperature, setTemperature] = useState(chatSettings.temperature); 65 | const [topP, setTopP] = useState(chatSettings.topP); 66 | const [maxTokens, setMaxTokens] = useState(chatSettings.maxTokens); 67 | const [repetitionPenalty, setRepetitionPenalty] = useState(chatSettings.repetitionPenalty); 68 | const [sample, setSample] = useState(chatSettings.doSample); 69 | 70 | function save() { 71 | setChatSettings({ 72 | temperature, 73 | topP, 74 | maxTokens, 75 | repetitionPenalty, 76 | doSample: sample, 77 | }); 78 | setOpen(false); 79 | } 80 | 81 | return ( 82 | <> 83 | 84 | 87 | 88 | Chat Parameters 89 | 90 | 91 | 92 | 93 | 100 |
101 | 102 | 103 |
104 | 107 |
108 | 109 | ); 110 | } 111 | -------------------------------------------------------------------------------- /main/src/app/datasets/new/actions.tsx: -------------------------------------------------------------------------------- 1 | "use server"; 2 | 3 | import * as y from "yup"; 4 | import {checkSessionAction} from "~/server/utils/session"; 5 | import {revalidatePath} from "next/cache"; 6 | import {uploadDataset} from "~/server/controller/new-dataset"; 7 | 8 | import type {State} from "./form"; 9 | import {logger} from "~/server/utils/observability/logtail"; 10 | 11 | const FormDataSchema = y.object({ 12 | name: y.string().required(), 13 | dropzoneFile: y.mixed().required(), 14 | }); 15 | 16 | export type FormDataType = y.InferType; 17 | 18 | function buildState(datasetId?: string, datasetName?: string, message?: string) { 19 | return { 20 | datasetId, 21 | datasetName, 22 | message, 23 | }; 24 | } 25 | 26 | export async function uploadDatasetAction(_: State, formData: FormData) { 27 | const session = await checkSessionAction(); 28 | 29 | let id: string | undefined; 30 | let name: string | undefined; 31 | try { 32 | const validatedForm = await FormDataSchema.validate(Object.fromEntries(formData)); 33 | 34 | // TODO: Make sure validatedForm.dropzoneFile is a File 35 | 36 | id = await uploadDataset(session.user.id, validatedForm); 37 | name = validatedForm.name; 38 | 39 | revalidatePath("/datasets"); 40 | } catch (e) { 41 | logger.warn("Failed to upload dataset", {error: e}); 42 | return buildState(undefined, undefined, (e as Error).message); 43 | } 44 | 45 | return buildState(id, name, undefined); 46 | } 47 | -------------------------------------------------------------------------------- /main/src/app/datasets/new/form.tsx: -------------------------------------------------------------------------------- 1 | "use client"; 2 | import {useFormState} from "react-dom"; 3 | import Label from "~/components/form/label"; 4 | import TextField from 
"~/components/form/textfield"; 5 | import SubmitButton from "~/components/form/submit-button"; 6 | import Tooltip from "~/components/tooltip"; 7 | import {UploadIcon} from "lucide-react"; 8 | import {uploadDatasetAction} from "./actions"; 9 | 10 | const initialState: { 11 | datasetId?: string; 12 | datasetName?: string; 13 | message?: string; 14 | } = { 15 | datasetId: undefined, 16 | datasetName: undefined, 17 | message: undefined, 18 | }; 19 | 20 | export type State = typeof initialState; 21 | 22 | const tooltips = { 23 | datasetName: ( 24 | <>What do you want to call your dataset? The name should be unique and contain only letters, numbers, and dashes. 25 | ), 26 | dataset: ( 27 | <> 28 | The dataset you want to use for fine-tuning. The dataset should be in{" "} 29 | 30 | this format 31 | 32 | . We recommend using a dataset with at least 100 conversations. 33 | 34 | ), 35 | }; 36 | 37 | export default function NewDatasetForm({ 38 | setModalOpen, 39 | }: { 40 | setModalOpen: (open: boolean, datasetId?: string, datasetName?: string) => void; 41 | }) { 42 | const [state, formAction] = useFormState(uploadDatasetAction, initialState); 43 | 44 | if (state.datasetId) { 45 | setModalOpen(false, state.datasetId, state.datasetName || undefined); 46 | } 47 | 48 | return ( 49 | <> 50 |
51 |
52 |

53 | Learn more about the dataset format{" "} 54 | 55 | here. 56 | 57 |

58 | 59 |
60 |
61 | 62 | {tooltips.datasetName} 63 |
64 | 65 |
66 | 67 |
68 |
69 | 70 | {tooltips.dataset} 71 |
72 | {/* TODO: re-enable dropzone */} 73 | 78 |
79 |
80 | 81 |
82 | {
{state.message}
} 83 | 84 | 85 | Upload 86 | 87 |
88 |
89 | 90 | ); 91 | } 92 | -------------------------------------------------------------------------------- /main/src/app/datasets/new/modal.tsx: -------------------------------------------------------------------------------- 1 | "use client"; 2 | import {XMarkIcon} from "@heroicons/react/20/solid"; 3 | import NewDatasetForm from "./form"; 4 | import Modal from "~/components/modal"; 5 | import {Dialog} from "@headlessui/react"; 6 | 7 | export default function NewDatasetModal({ 8 | open, 9 | setOpen, 10 | }: { 11 | open: boolean; 12 | setOpen: (open: boolean, datasetId?: string, datasetName?: string) => void; 13 | }) { 14 | return ( 15 | 16 | 19 | 20 | Upload dataset 21 | 22 | 23 | 24 | ); 25 | } 26 | -------------------------------------------------------------------------------- /main/src/app/datasets/page.tsx: -------------------------------------------------------------------------------- 1 | import Padding from "~/components/padding"; 2 | import PageHeading from "~/components/page-heading"; 3 | import Sidebar from "~/components/sidebar"; 4 | import {checkSession} from "~/server/utils/session"; 5 | import DatasetTable from "./table"; 6 | import {getDatasets} from "~/server/database/dataset"; 7 | 8 | import type {Dataset} from "@prisma/client"; 9 | 10 | function updatedAtToPrettyString(updatedAt: Date) { 11 | const now = new Date(); 12 | const diff = now.getTime() - updatedAt.getTime(); 13 | const diffInDays = diff / (1000 * 3600 * 24); 14 | const diffInWeeks = diffInDays / 7; 15 | 16 | if (diffInWeeks >= 3) { 17 | return `${Math.floor(diffInWeeks)} weeks ago`; 18 | } 19 | 20 | if (Math.floor(diffInDays) == 1) { 21 | return "1 day ago"; 22 | } 23 | 24 | if (diffInDays > 1) { 25 | return `${Math.floor(diffInDays)} days ago`; 26 | } 27 | 28 | const diffInHours = diff / (1000 * 3600); 29 | if (Math.floor(diffInHours) == 1) { 30 | return "1 hour ago"; 31 | } 32 | 33 | if (diffInHours > 1) { 34 | return `${Math.floor(diffInHours)} hours ago`; 35 | } 36 | 37 | const 
diffInMinutes = diff / (1000 * 60); 38 | if (Math.floor(diffInMinutes) == 1) { 39 | return "1 minute ago"; 40 | } 41 | 42 | if (diffInMinutes >= 1) { 43 | return `${Math.floor(diffInMinutes)} minutes ago`; 44 | } 45 | 46 | return "Just now"; 47 | } 48 | 49 | function filterPropsForTable(datasets: Dataset[]) { 50 | return datasets.map((dataset) => ({ 51 | id: dataset.id, 52 | name: dataset.name, 53 | rows: dataset.rows, 54 | created: updatedAtToPrettyString(dataset.createdAt), 55 | })); 56 | } 57 | 58 | export type DatasetTableProps = ReturnType; 59 | 60 | export default async function Page() { 61 | const session = await checkSession(); 62 | const datasets = await getDatasets(session.user.id); 63 | const filtered = filterPropsForTable(datasets); 64 | 65 | return ( 66 | 67 | 68 | Datasets 69 | 70 |
71 | 72 | 73 | ); 74 | } 75 | -------------------------------------------------------------------------------- /main/src/app/datasets/table.tsx: -------------------------------------------------------------------------------- 1 | "use client"; 2 | import {PlusIcon} from "@heroicons/react/20/solid"; 3 | import Button from "~/components/form/button"; 4 | import {useState} from "react"; 5 | import NewDatasetModal from "./new/modal"; 6 | 7 | import type {DatasetTableProps} from "./page"; 8 | 9 | export default function DatasetTable({datasets}: {datasets: DatasetTableProps}) { 10 | const [newDatasetModalOpen, setNewDatasetModalOpen] = useState(false); 11 | 12 | function NewModelButton() { 13 | return ( 14 | 18 | ); 19 | } 20 | 21 | return ( 22 | <> 23 | 24 |
25 |
26 |
27 |

28 | You can create and manage your datasets here. Learn more about the dataset format{" "} 29 | 30 | in our docs. 31 | 32 |

33 |
34 |
35 | 36 |
37 |
38 |
39 |
40 |
41 | 42 | 43 | 44 | 47 | 50 | 53 | 54 | 55 | 56 | {datasets.map((dataset) => ( 57 | 58 | 61 | {/**/} 62 | 63 | 64 | {/* 65 | 70 | */} 71 | 72 | ))} 73 | 74 |
45 | Name 46 | 48 | Rows 49 | 51 | Created 52 |
59 | {dataset.name} 60 | {dataset.description}{dataset.rows}{dataset.created} 66 | 67 | View, {dataset.name} 68 | 69 |
75 |
76 |
77 |
78 |
79 | 80 | ); 81 | } 82 | -------------------------------------------------------------------------------- /main/src/app/layout.tsx: -------------------------------------------------------------------------------- 1 | import type {Metadata} from "next"; 2 | import "../styles/globals.css"; 3 | 4 | import {NextAuthProvider} from "./providers"; 5 | import Script from "next/script"; 6 | 7 | export const metadata: Metadata = { 8 | title: "Haven", 9 | description: "Fine-tune open source language models.", 10 | }; 11 | 12 | export default function RootLayout({children}: {children: React.ReactNode}) { 13 | return ( 14 | 15 | 19 | 20 | {children} 21 | 22 | 23 | ); 24 | } 25 | -------------------------------------------------------------------------------- /main/src/app/models/actions.tsx: -------------------------------------------------------------------------------- 1 | "use server"; 2 | import * as y from "yup"; 3 | 4 | import type {State} from "./export"; 5 | import {checkSessionAction} from "~/server/utils/session"; 6 | import {getModelFromId} from "~/server/database/model"; 7 | import {logger} from "~/server/utils/observability/logtail"; 8 | import {exportEndpoint} from "~/constants/modal"; 9 | 10 | const FormDataSchema = y.object({ 11 | modelId: y.string().required(), 12 | hfToken: y.string().required(), 13 | namespace: y.string().required(), 14 | name: y.string().required(), 15 | }); 16 | 17 | export async function exportModel(state: State, payload: FormData) { 18 | const session = await checkSessionAction(); 19 | 20 | // validate form data 21 | let validatedForm: y.InferType; 22 | try { 23 | validatedForm = await FormDataSchema.validate(Object.fromEntries(payload)); 24 | } catch (e: unknown) { 25 | const error = e as Error; 26 | logger.error("[exportModel] Invalid form data", {error: error.message}); 27 | return { 28 | success: false, 29 | error: error.message, 30 | }; 31 | } 32 | 33 | // check if model belongs to user 34 | const model = await 
getModelFromId(validatedForm.modelId); 35 | if (!model) { 36 | logger.error("[exportModel] Model not found", {modelId: validatedForm.modelId}); 37 | return { 38 | success: false, 39 | error: "Unexpected error", 40 | }; 41 | } 42 | 43 | if (model.userId !== session.user.id) { 44 | logger.error("[exportModel] User does not own model", {modelId: validatedForm.modelId}); 45 | return { 46 | success: false, 47 | error: "Unexpected error", 48 | }; 49 | } 50 | 51 | const body = { 52 | hf_name: validatedForm.namespace, 53 | model_name: validatedForm.name, 54 | model_id: validatedForm.modelId, 55 | base_model_name: model.baseModel, 56 | hf_token: validatedForm.hfToken, 57 | }; 58 | 59 | logger.info("[exportModel] Starting export", {body}); 60 | 61 | const res = await fetch(exportEndpoint, { 62 | method: "POST", 63 | body: JSON.stringify(body), 64 | headers: { 65 | "Content-Type": "application/json", 66 | }, 67 | }); 68 | 69 | if (res.status === 200) { 70 | logger.info("[exportModel] Exported model to huggingface", {body}); 71 | return { 72 | success: true, 73 | error: "Success!", 74 | }; 75 | } 76 | 77 | logger.error("[exportModel] Something went wrong when exporting", { 78 | status: res.status, 79 | request: body, 80 | report: await res.text(), 81 | }); 82 | return { 83 | success: false, 84 | error: "Something went wrong and the team has been notified.", 85 | }; 86 | } 87 | -------------------------------------------------------------------------------- /main/src/app/models/buttons.tsx: -------------------------------------------------------------------------------- 1 | "use client"; 2 | import {ChatBubbleBottomCenterIcon, EllipsisVerticalIcon} from "@heroicons/react/20/solid"; 3 | import {ExclamationTriangleIcon} from "@heroicons/react/24/outline"; 4 | import {Loader2Icon} from "lucide-react"; 5 | import Image from "next/image"; 6 | import Link from "next/link"; 7 | import {useState} from "react"; 8 | import Button from "~/components/form/button"; 9 | import Dropdown from 
"~/components/form/dropdown"; 10 | import ExportModal from "./export"; 11 | 12 | export default function Buttons({ 13 | modelId, 14 | state, 15 | wandbUrl, 16 | hfToken, 17 | }: { 18 | modelId: string; 19 | state: "training" | "finished" | "error" | "online"; 20 | wandbUrl?: string; 21 | hfToken?: string; 22 | }) { 23 | const [isExportModalOpen, setIsExportModalOpen] = useState(false); 24 | 25 | /** 26 | * Process dropdown item selection 27 | */ 28 | function selectItem(item: string, modelId: string, wandbUrl?: string) { 29 | if (item === "Logs" && wandbUrl) { 30 | window.open(wandbUrl); 31 | } 32 | 33 | if (item === "Export") { 34 | setIsExportModalOpen(true); 35 | } 36 | } 37 | 38 | const wandbButton = ( 39 | 40 | 44 | 45 | ); 46 | 47 | const chatButton = ( 48 | 49 | 53 | 54 | ); 55 | 56 | if (state === "training" && wandbUrl) { 57 | return wandbButton; 58 | } 59 | 60 | if (state === "training") { 61 | return ; 62 | } 63 | 64 | if (state === "finished") { 65 | const options = wandbUrl ? ["Logs", "Export"] : ["Export"]; 66 | 67 | return ( 68 | <> 69 | 70 |
71 | {chatButton} 72 | selectItem(item, modelId, wandbUrl)} 75 | dropdownDirection="downLeft" 76 | > 77 |
78 | 79 |
80 |
81 |
82 | 83 | ); 84 | } 85 | 86 | // If model is hardcoded we just show the chat button. 87 | if (state === "online") { 88 | return chatButton; 89 | } 90 | 91 | // Error state 92 | return ; 93 | } 94 | -------------------------------------------------------------------------------- /main/src/app/models/export.tsx: -------------------------------------------------------------------------------- 1 | import {Dialog} from "@headlessui/react"; 2 | import {useFormState} from "react-dom"; 3 | import Label from "~/components/form/label"; 4 | import TextField from "~/components/form/textfield"; 5 | import Modal from "~/components/modal"; 6 | import Tooltip from "~/components/tooltip"; 7 | import {exportModel} from "./actions"; 8 | import {XMarkIcon} from "@heroicons/react/20/solid"; 9 | import SubmitButton from "~/components/form/submit-button"; 10 | 11 | export interface State { 12 | success: boolean; 13 | error: string; 14 | } 15 | 16 | const initialState: State = { 17 | success: false, 18 | error: "", 19 | }; 20 | 21 | const tooltipContent = ( 22 | <> 23 | Your Huggingface token is used to upload the model to your account. You can find your token{" "} 24 | 25 | here 26 | 27 | . 28 | 29 | ); 30 | 31 | export default function ExportModal({ 32 | open, 33 | setOpen, 34 | modelId, 35 | hfToken, 36 | }: { 37 | open: boolean; 38 | setOpen: (open: boolean) => void; 39 | modelId: string; 40 | hfToken?: string; 41 | }) { 42 | const [state, formAction] = useFormState(exportModel, initialState); 43 | 44 | return ( 45 | 46 |
47 | 50 | 51 | Export to Huggingface 52 | 53 |
54 | Export the selected model to Huggingface. The repository will include instructions on how to use the model. 55 |
56 |
57 | 58 | {tooltipContent} 59 |
60 | 61 | 67 | 68 | 69 | {/* Hidden field for model id */} 70 | 71 | 72 | {state.error && ( 73 |
74 | {state.error} 75 |
76 | )} 77 | {!state.success && Export} 78 | 79 |
80 | ); 81 | } 82 | -------------------------------------------------------------------------------- /main/src/app/models/list.tsx: -------------------------------------------------------------------------------- 1 | import {ArrowUpLeftIcon, ArrowUpRightIcon} from "@heroicons/react/24/outline"; 2 | import Buttons from "./buttons"; 3 | import type {ModelProps} from "./page"; 4 | 5 | const environments = { 6 | training: "text-yellow-400 bg-yellow-400/10 ring-yellow-400/20", 7 | finished: "text-green-500 bg-green-500/10 ring-green-500/30", 8 | error: "text-red-500 bg-red-500/10 ring-red-500/30", 9 | }; 10 | 11 | function classNames(...classes: string[]) { 12 | return classes.filter(Boolean).join(" "); 13 | } 14 | 15 | interface ModelListProps { 16 | header?: string; 17 | // typeof return type of getModels 18 | models: ModelProps[]; 19 | hfToken?: string; 20 | } 21 | 22 | function Empty() { 23 | return ( 24 |
25 |
26 | 27 |
28 |
29 | 30 |
31 |

32 | {"You haven't trained any models yet. "} 33 |
34 | 35 | Refer to the docs. 36 | 37 |

38 |
39 | ); 40 | } 41 | 42 | export default function List({header, models, hfToken}: ModelListProps) { 43 | return ( 44 | <> 45 | {header && ( 46 | <> 47 |
48 |
{header}
49 |
50 |
51 | 52 | )} 53 | 54 | {models.length ? ( 55 |
    56 | {models.map((model) => ( 57 |
  • 58 |
    59 |
    60 |

    61 |
    62 | {model.modelName} 63 |
    64 |

    65 |
    71 | {model.status.charAt(0).toUpperCase() + model.status.slice(1)} 72 |
    73 |
    74 |
    75 |

    {model.baseModel}

    76 | {model.datasetName && ( 77 | <> 78 |
    79 |

    {model.datasetName}

    80 | 81 | )} 82 |
    83 |
    84 | 90 |
  • 91 | ))} 92 |
93 | ) : ( 94 | 95 | )} 96 | 97 | ); 98 | } 99 | -------------------------------------------------------------------------------- /main/src/app/models/new/actions.tsx: -------------------------------------------------------------------------------- 1 | "use server"; 2 | import {calculatePrice, createNewModelTraining} from "~/server/controller/new-model"; 3 | import * as y from "yup"; 4 | import {checkSessionAction} from "~/server/utils/session"; 5 | import {revalidatePath} from "next/cache"; 6 | import {modelsToFinetune} from "~/constants/models"; 7 | 8 | import type {State} from "./form"; 9 | 10 | const FormDataSchema = y.object({ 11 | name: y.string().required(), 12 | datasetId: y.string().required(), 13 | learningRate: y.string().oneOf(["Low", "Medium", "High"]).required(), 14 | numberOfEpochs: y.number().min(1).required(), 15 | baseModel: y.string().oneOf(modelsToFinetune).required(), 16 | confirmedPrice: y.number(), 17 | }); 18 | 19 | export type FormDataType = y.InferType; 20 | 21 | export async function revalidate() { 22 | return Promise.resolve().then(() => revalidatePath("/models")); 23 | } 24 | 25 | function buildState( 26 | success: boolean, 27 | message: string | null, 28 | priceInCents: number | null, 29 | userHasEnoughCredits: boolean | null, 30 | formData: FormData | null, 31 | ): State { 32 | return { 33 | success, 34 | message, 35 | priceInCents, 36 | userHasEnoughCredits, 37 | formData, 38 | }; 39 | } 40 | 41 | export async function validateAndCalculatePrice(_: State, formData: FormData) { 42 | const session = await checkSessionAction(); 43 | 44 | try { 45 | const validatedForm = await FormDataSchema.validate(Object.fromEntries(formData)); 46 | const priceInCents = await calculatePrice(session, validatedForm); 47 | const userHasEnoughCredits = session.user.centsBalance >= priceInCents; 48 | return buildState(false, null, priceInCents, userHasEnoughCredits, formData); 49 | } catch (e) { 50 | return buildState(false, (e as Error).message, null, 
null, null); 51 | } 52 | } 53 | 54 | export async function startNewTraining(_: State, formData: FormData) { 55 | const session = await checkSessionAction(); 56 | 57 | try { 58 | const validatedForm = await FormDataSchema.validate(Object.fromEntries(formData)); 59 | await createNewModelTraining(session, validatedForm, validatedForm.confirmedPrice); 60 | 61 | return buildState(true, null, null, null, null); 62 | } catch (e: unknown) { 63 | return buildState(false, (e as Error).message, null, null, null); 64 | } 65 | } 66 | -------------------------------------------------------------------------------- /main/src/app/models/new/confirm-price.tsx: -------------------------------------------------------------------------------- 1 | import Modal from "~/components/modal"; 2 | import {useFormState} from "react-dom"; 3 | import {revalidate, startNewTraining} from "./actions"; 4 | import SubmitButton from "~/components/form/submit-button"; 5 | import {useRouter} from "next/navigation"; 6 | 7 | import type {State} from "./form"; 8 | 9 | export default function ConfirmModal({ 10 | state, 11 | open, 12 | setOpen, 13 | }: { 14 | state: State; 15 | open: boolean; 16 | setOpen: (state: boolean) => void; 17 | }) { 18 | const [newState, formAction] = useFormState(startNewTraining, state); 19 | 20 | const router = useRouter(); 21 | 22 | if (newState.success) { 23 | void revalidate().then(() => router.push("/models")); 24 | } 25 | 26 | function formActionWrapper(_: FormData) { 27 | if (!state.priceInCents || !state.formData) { 28 | // TODO: handle 29 | return; 30 | } 31 | 32 | (state.formData as any as string[][]).push(["confirmedPrice", state.priceInCents.toString()]); 33 | formAction(state.formData); 34 | } 35 | 36 | return ( 37 | 38 |
Cost for this run
39 |
40 | Refer{" "} 41 | 42 | to our docs 43 | {" "} 44 | for our pricing calculation. 45 |
46 |
{`$${((state.priceInCents || 0) / 100).toFixed(2)}`}
49 | {state.userHasEnoughCredits ? ( 50 | <> 51 |
52 | The fine-tuning process will start and this amount will be deducted from your balance once you click 53 | "Confirm and start". 54 |
55 |
56 | If the training process stops due to a an internal error, your credits will be reimbursed. 57 |
58 | 59 | ) : ( 60 |
61 | {"You don't have enough balance left. Please first add more balance to your account."} 62 |
63 | )} 64 | 65 | {newState.message &&
{newState.message}
} 66 | {state.userHasEnoughCredits && ( 67 |
68 | Confirm and start 69 |
70 | )} 71 |
72 | ); 73 | } 74 | -------------------------------------------------------------------------------- /main/src/app/models/new/form.tsx: -------------------------------------------------------------------------------- 1 | "use client"; 2 | import {ChevronDownIcon} from "@heroicons/react/20/solid"; 3 | import {useState} from "react"; 4 | import {useFormState} from "react-dom"; 5 | import Dropdown from "~/components/form/dropdown"; 6 | import Label from "~/components/form/label"; 7 | import TextField from "~/components/form/textfield"; 8 | import {validateAndCalculatePrice} from "./actions"; 9 | import ConfirmModal from "./confirm-price"; 10 | import SubmitButton from "~/components/form/submit-button"; 11 | import Tooltip from "~/components/tooltip"; 12 | import Search from "./search"; 13 | import {modelsToFinetune} from "~/constants/models"; 14 | 15 | const initialState: { 16 | success: boolean; 17 | priceInCents: number | null; 18 | userHasEnoughCredits: boolean | null; 19 | formData: FormData | null; 20 | message: string | null; 21 | } = { 22 | success: false, 23 | priceInCents: null, 24 | userHasEnoughCredits: null, 25 | formData: null, 26 | message: null, 27 | }; 28 | 29 | export type State = typeof initialState; 30 | 31 | const tooltips = { 32 | modelName: ( 33 | <>What do you want to call your model? The name should be unique and contain only letters, numbers, and dashes. 34 | ), 35 | dataset: ( 36 | <> 37 | The dataset you want to use for fine-tuning. The dataset should be in{" "} 38 | 39 | this format 40 | 41 | . We recommend using a dataset with at least 100 conversations. 42 | 43 | ), 44 | learningRate: ( 45 | <> 46 | Learning rate controls how much the model gets updated after each batch. In general, the larger the dataset, the 47 | lower the learning rate should be. 48 | 49 | ), 50 | numberOfEpochs: ( 51 | <> 52 | An epoch is one full pass through the dataset. The number of epochs controls how many times the model will see the 53 | dataset. 
Usually, a number between 1 and 5 is sufficient. 54 | 55 | ), 56 | baseModel: ( 57 | <> 58 | The base model is the model you want to fine-tune. Zephyr is a fine-tuned version of Mistral which we found to be 59 | very effective for further fine-tuning. 60 | 61 | ), 62 | }; 63 | 64 | export default function NewModelForm() { 65 | const [open, setOpen] = useState(false); 66 | const [state, formAction] = useFormState(validateAndCalculatePrice, initialState); 67 | 68 | function formActionSetOpen(formData: FormData) { 69 | setOpen(true); 70 | formAction(formData); 71 | } 72 | 73 | const [learningRate, setLearningRate] = useState(""); 74 | const [baseModel, setBaseModel] = useState(""); 75 | 76 | return ( 77 | <> 78 | 79 |
80 |
81 |
82 |

Create a new fine-tuned LLM

83 |

84 | If you need help, check out our fine-tuning guide which explains the different parameters. Fine-tuning is 85 | a very volatile process and usually requires a lot of experimentation. 86 |

87 | 88 |
89 |
90 | 91 | {tooltips.modelName} 92 |
93 | 94 |
95 | 96 |
97 |
98 | 99 | {tooltips.dataset} 100 |
101 | 102 |
103 | 104 |
105 |
106 | 107 | {tooltips.learningRate} 108 |
109 | setLearningRate(choice)}> 110 |
111 | {learningRate ? learningRate : "Choose"} 112 | 113 |
114 |
115 |
116 | 117 | {/* Dropown result as hidden field */} 118 | 119 | 120 |
121 |
122 | 123 | {tooltips.numberOfEpochs} 124 |
125 | 131 |
132 | 133 |
134 |
135 | 136 | {tooltips.baseModel} 137 |
138 | setBaseModel(choice)}> 139 |
140 | {baseModel ? baseModel : "Choose"} 141 | 142 |
143 |
144 |
145 | 146 | {/* Dropown result as hidden field */} 147 | 148 |
149 |
150 | 151 |
152 | Calculate price 153 | {
{state.message}
} 154 |
155 |
156 | 157 | ); 158 | } 159 | -------------------------------------------------------------------------------- /main/src/app/models/new/page.tsx: -------------------------------------------------------------------------------- 1 | import Padding from "~/components/padding"; 2 | import PageHeading from "~/components/page-heading"; 3 | import Sidebar from "~/components/sidebar"; 4 | import NewModelForm from "./form"; 5 | import {checkSession} from "~/server/utils/session"; 6 | 7 | export default async function Page() { 8 | await checkSession(); 9 | 10 | return ( 11 | 12 | 13 | Start model training 14 | 15 |
16 | 17 | 18 | 19 | 20 | ); 21 | } 22 | -------------------------------------------------------------------------------- /main/src/app/models/new/search.tsx: -------------------------------------------------------------------------------- 1 | import * as y from "yup"; 2 | import {Transition} from "@headlessui/react"; 3 | import {MagnifyingGlassIcon, PlusIcon} from "@heroicons/react/20/solid"; 4 | import {ArrowPathIcon, CheckIcon} from "@heroicons/react/24/outline"; 5 | import {Fragment, useState} from "react"; 6 | 7 | import type {SearchResponse} from "~/app/api/search/route"; 8 | import NewDatasetModal from "~/app/datasets/new/modal"; 9 | 10 | export default function Search() { 11 | const [loading, setLoading] = useState(false); 12 | const [open, setOpen] = useState(false); 13 | const [value, setValue] = useState(""); 14 | const [searchResponse, setSerachResponse] = useState([]); 15 | const [selected, setSelected] = useState(null); 16 | 17 | const [uploadDatasetModelOpen, setUploadDatasetModelOpen] = useState(false); 18 | function setUploadDatasetModelOpenWrapper(open: boolean, datasetId?: string, datasetName?: string) { 19 | setUploadDatasetModelOpen(open); 20 | if (datasetId && datasetName) { 21 | setSelected(datasetId); 22 | setValue(datasetName); 23 | setOpen(false); 24 | } 25 | } 26 | 27 | async function fetchSearchResults() { 28 | setLoading(true); 29 | const results = await fetch("/api/search", { 30 | method: "POST", 31 | body: JSON.stringify({query: value}), 32 | }); 33 | 34 | const schema = y 35 | .array( 36 | y 37 | .object({ 38 | id: y.string().required(), 39 | name: y.string().required(), 40 | rows: y.number().required(), 41 | }) 42 | .required(), 43 | ) 44 | .required(); 45 | 46 | const json = await results.json(); 47 | const validated = await schema.validate(json); 48 | 49 | setSerachResponse(validated); 50 | setLoading(false); 51 | } 52 | 53 | async function onOpenChange(open: boolean) { 54 | setOpen(open); 55 | 56 | if (open && 
searchResponse.length === 0) { 57 | await fetchSearchResults(); 58 | } 59 | } 60 | 61 | async function onValueChange(value: string) { 62 | setValue(value); 63 | await fetchSearchResults(); 64 | } 65 | 66 | function onSelectedChange(id: string, name: string) { 67 | setSelected(id); 68 | setValue(name); 69 | setOpen(false); 70 | } 71 | 72 | return ( 73 | <> 74 | 75 |
76 |
77 |
78 | 79 | onValueChange(e.target.value)} 85 | onFocus={() => onOpenChange(true)} 86 | onBlur={() => onOpenChange(false)} 87 | /> 88 |
89 | {loading && ( 90 |
91 | 92 |
93 | )} 94 | {selected && } 95 |
96 | 106 |
107 |
108 |
setUploadDatasetModelOpen(true)} 110 | className={`cursor-pointer flex items-center w-full px-3 py-2 hover:bg-gray-100 hover:dark:bg-gray-900 ${ 111 | searchResponse.length > 0 && "border-b-2" 112 | }`} 113 | > 114 | 115 |
upload new dataset
116 |
117 | {searchResponse.map((option, index) => ( 118 |
onSelectedChange(option.id, option.name)} 120 | key={index} 121 | className="cursor-pointer flex items-center hover:bg-gray-100 hover:dark:bg-gray-900" 122 | > 123 |
{option.name}
124 |
{`${option.rows} rows`}
125 |
126 | ))} 127 |
128 |
129 |
130 |
131 | {/* Hidden input containing the datasetId */} 132 | 133 | 134 | ); 135 | } 136 | -------------------------------------------------------------------------------- /main/src/app/models/page.tsx: -------------------------------------------------------------------------------- 1 | import Padding from "~/components/padding"; 2 | import PageHeading from "~/components/page-heading"; 3 | import Sidebar from "~/components/sidebar"; 4 | import List from "./list"; 5 | import Button from "~/components/form/button"; 6 | import {PlusIcon} from "@heroicons/react/20/solid"; 7 | import Link from "next/link"; 8 | import {getModels} from "~/server/database/model"; 9 | import {checkSession} from "~/server/utils/session"; 10 | 11 | import type {Dataset, Model} from "@prisma/client"; 12 | import {getDatasetById} from "~/server/database/dataset"; 13 | 14 | function newModelButton() { 15 | return ( 16 | 17 | 21 | 22 | ); 23 | } 24 | 25 | // db is a prisma client. we use the types from the model table for the models array 26 | function filterPropsForTable(models: Model[], datasetNames: Map) { 27 | return models.map((model) => ({ 28 | id: model.id, 29 | modelName: model.name, 30 | datasetName: datasetNames.get(model.datasetId || ""), 31 | status: model.state, 32 | baseModel: model.baseModel, 33 | wandbUrl: model.wandbUrl || undefined, 34 | })); 35 | } 36 | 37 | export type ModelProps = ReturnType[number]; 38 | 39 | const defaultModels: ModelProps[] = [ 40 | { 41 | id: "mixtral-8x7b", 42 | modelName: "Mixtral 8x7b Chat", 43 | status: "online", 44 | baseModel: "Chat with Mistral's new model - Limited availability", 45 | datasetName: undefined, 46 | wandbUrl: undefined, 47 | }, 48 | { 49 | id: "zephyr", 50 | modelName: "Zephyr 7b", 51 | status: "online", 52 | baseModel: "A chat fine-tune of Mistral's 7b base-model", 53 | datasetName: undefined, 54 | wandbUrl: undefined, 55 | }, 56 | { 57 | id: "llama2-7b", 58 | modelName: "Llama 2 7b Chat", 59 | status: "online", 60 | baseModel: "The 
chat version of Meta's popular Llama 2 model", 61 | datasetName: undefined, 62 | wandbUrl: undefined, 63 | }, 64 | ]; 65 | 66 | export default async function Page() { 67 | const session = await checkSession(); 68 | const models = await getModels(session.user.id); 69 | 70 | // get unique dataset Ids 71 | const datasetIds = Array.from(new Set(models.map((model) => model.datasetId))); 72 | 73 | // filter out null values 74 | const filteredDatasetIds = datasetIds.filter((id) => id !== null) as string[]; 75 | 76 | // for each dataset, get the dataset name 77 | const datasets = await Promise.all(filteredDatasetIds.map((id) => getDatasetById(id, session.user.id))); 78 | 79 | // create a map of dataset id to dataset name 80 | const datasetMap = new Map(); 81 | datasets.forEach((dataset) => { 82 | if (!dataset) { 83 | return; 84 | } 85 | 86 | datasetMap.set(dataset.id, dataset.name); 87 | }); 88 | 89 | return ( 90 | 91 | 92 | Models 93 | 94 |
95 | 96 |
97 | 98 | 99 | ); 100 | } 101 | -------------------------------------------------------------------------------- /main/src/app/page.tsx: -------------------------------------------------------------------------------- 1 | "use client"; 2 | import {signIn} from "next-auth/react"; 3 | import {useEffect} from "react"; 4 | 5 | export default function Page() { 6 | useEffect(() => { 7 | void signIn(); 8 | }, []); 9 | 10 | return null; 11 | } 12 | -------------------------------------------------------------------------------- /main/src/app/providers.tsx: -------------------------------------------------------------------------------- 1 | "use client"; 2 | 3 | import { SessionProvider } from "next-auth/react"; 4 | 5 | type Props = { 6 | children?: React.ReactNode; 7 | }; 8 | 9 | export const NextAuthProvider = ({ children }: Props) => { 10 | return {children}; 11 | }; 12 | -------------------------------------------------------------------------------- /main/src/app/settings/actions.tsx: -------------------------------------------------------------------------------- 1 | "use server"; 2 | import {getServerSession} from "next-auth"; 3 | import {revalidatePath} from "next/cache"; 4 | import {authOptions} from "~/server/auth"; 5 | import {updateHfToken, updateName} from "~/server/database/user"; 6 | import type {initialState} from "./form"; 7 | import {logger} from "~/server/utils/observability/logtail"; 8 | 9 | export async function submitForm(_: typeof initialState, formData: FormData) { 10 | const session = await getServerSession(authOptions); 11 | if (!session) { 12 | throw new Error("Not authenticated"); 13 | } 14 | 15 | // TODO: validate 16 | const name = formData.get("name"); 17 | const hfToken = formData.get("hf_token"); 18 | 19 | if (hfToken && typeof hfToken === "string") { 20 | logger.info("Updating hfToken"); 21 | await updateHfToken(session.user.id, hfToken.toString()); 22 | } 23 | 24 | if (name && typeof name === "string") { 25 | logger.info("Updating name", 
{oldName: session.user.name, name: name.toString()}); 26 | await updateName(session.user.id, name.toString()); 27 | } 28 | 29 | revalidatePath("/settings"); 30 | return { 31 | message: "Success!", 32 | color: "text-green-600", 33 | }; 34 | } 35 | -------------------------------------------------------------------------------- /main/src/app/settings/form.tsx: -------------------------------------------------------------------------------- 1 | "use client"; 2 | import {useFormState, useFormStatus} from "react-dom"; 3 | import Button from "~/components/form/button"; 4 | import TextField from "~/components/form/textfield"; 5 | import {submitForm} from "./actions"; 6 | 7 | export const initialState: { 8 | message: string | null; 9 | color: string | null; 10 | } = { 11 | message: null, 12 | color: null, 13 | }; 14 | 15 | function SubmitButton() { 16 | const {pending} = useFormStatus(); 17 | 18 | return ( 19 | 22 | ); 23 | } 24 | 25 | export default function SettingsForm({name, hfToken}: {name?: string; hfToken?: string}) { 26 | const [state, formAction] = useFormState(submitForm, initialState); 27 | 28 | return ( 29 |
30 |
31 |
32 |

Account Information

33 |

34 | Update your account information. This information stays between you and us. 35 |

36 | 37 |
38 | 39 |
40 | 41 |
42 | 43 |
44 |
45 |
46 | 47 |
48 |
{state.message}
49 | 50 |
51 |
52 | ); 53 | } 54 | -------------------------------------------------------------------------------- /main/src/app/settings/page.tsx: -------------------------------------------------------------------------------- 1 | import Padding from "~/components/padding"; 2 | import PageHeading from "~/components/page-heading"; 3 | import Sidebar from "~/components/sidebar"; 4 | import {checkSession} from "~/server/utils/session"; 5 | import SettingsForm from "./form"; 6 | 7 | export default async function Settings() { 8 | const session = await checkSession(); 9 | 10 | const name = session.user.name ?? undefined; 11 | const hfToken = session.user.hfToken; 12 | 13 | return ( 14 | 15 | 16 | Settings 17 | 18 |
19 | 20 | 21 | 22 | 23 | 24 | ); 25 | } 26 | -------------------------------------------------------------------------------- /main/src/components/form/button.tsx: -------------------------------------------------------------------------------- 1 | "use client"; 2 | import {Loader2} from "lucide-react"; 3 | 4 | interface Props { 5 | id?: string; 6 | onClick?: () => void; 7 | loading?: boolean; 8 | className?: string; 9 | type?: "button" | "submit"; // set to submit for forms 10 | children?: React.ReactNode; 11 | } 12 | 13 | export default function Button({ 14 | id, 15 | onClick = undefined, 16 | loading = false, 17 | className = "", 18 | type = "button", 19 | children, 20 | }: Props) { 21 | return ( 22 | 32 | ); 33 | } 34 | -------------------------------------------------------------------------------- /main/src/components/form/dropdown.tsx: -------------------------------------------------------------------------------- 1 | "use client"; 2 | import {Fragment} from "react"; 3 | import {Menu, Transition} from "@headlessui/react"; 4 | import Label from "./label"; 5 | 6 | interface DropdownProps { 7 | children: React.ReactNode; 8 | options: string[]; 9 | onSelect?: (selectedOption: string) => void; 10 | dropdownDirection?: "up" | "downRight" | "downLeft"; 11 | label?: string; 12 | } 13 | 14 | function classNames(...classes: string[]) { 15 | return classes.filter(Boolean).join(" "); 16 | } 17 | 18 | // TODO: we can simplify this, most of these styles are the same 19 | const styles = Object.freeze({ 20 | up: "absolute left-0 bottom-0 z-10 ml-3 mb-10 w-56 rounded-md bg-white dark:bg-black shadow-lg ring-1 ring-black ring-opacity-5 focus:outline-none", 21 | downLeft: 22 | "absolute right-0 z-10 mt-2 w-56 origin-top-right rounded-md bg-white dark:bg-black shadow-lg ring-1 ring-black ring-opacity-5 focus:outline-none", 23 | downRight: 24 | "absolute left-0 z-10 mt-2 w-56 origin-top-right rounded-md bg-white dark:bg-black shadow-lg ring-1 ring-black ring-opacity-5 
focus:outline-none", 25 | }); 26 | 27 | export default function Dropdown({children, options, onSelect, label, dropdownDirection = "downRight"}: DropdownProps) { 28 | return ( 29 | <> 30 | {label !== "" && } 31 | 32 | {children} 33 | 34 | 43 | 44 | 58 | 59 | 60 | 61 | 62 | ); 63 | } 64 | -------------------------------------------------------------------------------- /main/src/components/form/file-upload.tsx: -------------------------------------------------------------------------------- 1 | import {ArrowUpTrayIcon} from "@heroicons/react/20/solid"; 2 | import React, {useState} from "react"; 3 | 4 | interface Props { 5 | onFile: (file: File) => void; 6 | } 7 | 8 | export default function FileUpload({onFile}: Props) { 9 | const [file, setFile] = useState(); 10 | const [dragging, setDragging] = useState(false); 11 | 12 | function setDraggingWithoutDefault(e: React.DragEvent, value: boolean) { 13 | e.preventDefault(); 14 | e.stopPropagation(); 15 | setDragging(value); 16 | } 17 | 18 | function fileAdded(newFile: File) { 19 | onFile(newFile); 20 | setFile(newFile); 21 | } 22 | 23 | function onDrop(e: React.DragEvent) { 24 | e.preventDefault(); 25 | e.stopPropagation(); 26 | 27 | const files = e.dataTransfer.files; 28 | 29 | if (files.length > 0) { 30 | fileAdded(files[0]!); 31 | } 32 | } 33 | 34 | function onFileChangeInternal(e: React.ChangeEvent) { 35 | const files = e.target.files; 36 | 37 | if (files?.[0]) { 38 | fileAdded(files[0]); 39 | } 40 | } 41 | 42 | const draggingType = "bg-gray-950"; 43 | const droppedType = "bg-gray-950"; 44 | const styles = dragging ? (file ? droppedType : draggingType) : ""; 45 | 46 | return ( 47 |
48 | 75 |
76 | ); 77 | } 78 | -------------------------------------------------------------------------------- /main/src/components/form/label.tsx: -------------------------------------------------------------------------------- 1 | export default function Label({className = "", children}: {className?: string; children: React.ReactNode}) { 2 | return ; 3 | } 4 | -------------------------------------------------------------------------------- /main/src/components/form/submit-button.tsx: -------------------------------------------------------------------------------- 1 | import {useFormStatus} from "react-dom"; 2 | import Button from "./button"; 3 | 4 | export default function SubmitButton({className, children}: {className?: string; children: React.ReactNode}) { 5 | const {pending} = useFormStatus(); 6 | 7 | return ( 8 | 11 | ); 12 | } 13 | -------------------------------------------------------------------------------- /main/src/components/form/textfield.tsx: -------------------------------------------------------------------------------- 1 | "use client"; 2 | interface TextFieldProps { 3 | value?: string; 4 | defaultValue?: string; 5 | onChange?: (value: any) => void; 6 | 7 | type?: "text" | "number" | "password" | "email"; 8 | name?: string; 9 | id?: string; 10 | label?: string; 11 | placeholder?: string; 12 | disabled?: boolean; 13 | required?: boolean; 14 | 15 | className?: string; 16 | } 17 | 18 | export default function TextField({ 19 | value = undefined, 20 | defaultValue = undefined, 21 | onChange = undefined, 22 | type = "text", 23 | name = "", 24 | id = "", 25 | label = "", 26 | placeholder = "", 27 | disabled = false, 28 | required = false, 29 | className = "", 30 | }: TextFieldProps) { 31 | return ( 32 |
33 | {label !== "" && ( 34 | 37 | )} 38 |
39 | onChange && onChange(e.target.value)} 50 | /> 51 |
52 |
53 | ); 54 | } 55 | -------------------------------------------------------------------------------- /main/src/components/modal.tsx: -------------------------------------------------------------------------------- 1 | import {Fragment} from "react"; 2 | import {Dialog, Transition} from "@headlessui/react"; 3 | 4 | interface Props { 5 | open: boolean; 6 | setOpen: (open: boolean) => void; 7 | size?: "sm" | "xl"; 8 | children?: React.ReactNode; 9 | } 10 | 11 | export default function Modal({open, children, setOpen, size = "sm"}: Props) { 12 | const sizeClasses = { 13 | sm: "sm:max-w-sm", 14 | xl: "sm:max-w-xl", 15 | }; 16 | 17 | return ( 18 | 19 | 20 | 29 |
30 | 31 | 32 |
33 |
34 | 43 | 46 | {children} 47 | 48 | 49 |
50 |
51 |
52 |
53 | ); 54 | } 55 | -------------------------------------------------------------------------------- /main/src/components/padding.tsx: -------------------------------------------------------------------------------- 1 | export default function Padding(props: {children: React.ReactNode}) { 2 | return
{props.children}
; 3 | } 4 | -------------------------------------------------------------------------------- /main/src/components/page-heading.tsx: -------------------------------------------------------------------------------- 1 | interface Props { 2 | children: React.ReactNode; 3 | primary?: React.ReactNode; 4 | } 5 | 6 | export default function PageHeading({children, primary}: Props) { 7 | return ( 8 |
9 |
10 |

{children}

11 |
12 |
{primary}
13 |
14 | ); 15 | } 16 | -------------------------------------------------------------------------------- /main/src/components/sidebar.tsx: -------------------------------------------------------------------------------- 1 | "use client"; 2 | import { 3 | Bars3Icon, 4 | Cog8ToothIcon, 5 | ChevronUpDownIcon, 6 | XMarkIcon, 7 | CreditCardIcon, 8 | ChatBubbleLeftRightIcon, 9 | TableCellsIcon, 10 | } from "@heroicons/react/24/outline"; 11 | 12 | import Dropdown from "./form/dropdown"; 13 | import {Dialog, Transition} from "@headlessui/react"; 14 | import {Fragment, useState} from "react"; 15 | import Link from "next/link"; 16 | import {signOut, useSession} from "next-auth/react"; 17 | import Image from "next/image"; 18 | import UserAvatar from "./user-avatar"; 19 | 20 | type SidebarItem = "Models" | "Datasets" | "Billing" | "Settings" | "None"; 21 | 22 | // TODO: these are copied form tailwindui so there are multiple in the codebase. 23 | // we should only use one 24 | function classNames(...classes: string[]) { 25 | return classes.filter(Boolean).join(" "); 26 | } 27 | 28 | interface Props { 29 | children: React.ReactNode; 30 | current: SidebarItem; 31 | title?: string; 32 | } 33 | 34 | export default function Sidebar(props: Props) { 35 | // TODO: pass in through props 36 | const email = useSession().data?.user.email || ""; 37 | 38 | const [sidebarOpen, setSidebarOpen] = useState(false); 39 | 40 | let navigation = [ 41 | { 42 | name: "Models", 43 | page: "/models", 44 | icon: ChatBubbleLeftRightIcon, 45 | current: false, 46 | }, 47 | { 48 | name: "Datasets", 49 | page: "/datasets", 50 | icon: TableCellsIcon, 51 | current: false, 52 | }, 53 | {name: "Billing", page: "/billing", icon: CreditCardIcon, current: false, headline: "Account"}, 54 | { 55 | name: "Settings", 56 | page: "/settings", 57 | icon: Cog8ToothIcon, 58 | current: false, 59 | }, 60 | ]; 61 | 62 | // Set current navigation item 63 | navigation.map((item) => { 64 | if (item.name === props.current) { 65 | 
item.current = true; 66 | } 67 | }); 68 | 69 | if (props.current === "None") { 70 | navigation = []; 71 | } 72 | 73 | return ( 74 | <> 75 |
76 | 77 | 78 | 87 |
88 | 89 | 90 |
91 | 100 | 101 | 110 |
111 | 115 |
116 |
117 |
118 |
119 | 120 | Haven 127 | 128 |
129 | 159 |
160 |
161 |
162 |
163 |
164 |
165 | 166 | {/* Static sidebar for desktop */} 167 |
168 | {/* Sidebar component, swap this element with another sidebar if you like */} 169 |
170 |
171 | 172 | Haven 173 | 174 |
175 | 222 |
223 |
224 | 225 |
226 |
setSidebarOpen(true)}> 227 | Open sidebar 228 |
230 |
{props.title}
231 | signOut({callbackUrl: "/"})}> 232 | Your profile 233 | 234 | 235 |
236 | 237 |
{props.children}
238 |
239 | 240 | ); 241 | } 242 | -------------------------------------------------------------------------------- /main/src/components/tiles.tsx: -------------------------------------------------------------------------------- 1 | import Image from "next/image"; 2 | 3 | export interface Tile { 4 | id: number; 5 | name: string; 6 | imageUrl: string; 7 | selected: boolean; 8 | onClick: () => void; 9 | } 10 | 11 | interface TileProps { 12 | tiles: Tile[]; 13 | } 14 | 15 | export default function Tiles({tiles}: TileProps) { 16 | return ( 17 |
    18 | {tiles.map((tile) => ( 19 | 36 | ))} 37 |
38 | ); 39 | } 40 | -------------------------------------------------------------------------------- /main/src/components/tooltip.tsx: -------------------------------------------------------------------------------- 1 | import {HoverCard, HoverCardContent, HoverCardTrigger} from "~/components/ui/hover-card"; 2 | import {QuestionMarkCircleIcon} from "@heroicons/react/24/outline"; 3 | 4 | export default function Tooltip({children, className}: {children: React.ReactElement; className?: string}) { 5 | return ( 6 |
7 | 8 | 9 | 10 | 11 | 15 | {children} 16 | 17 | 18 |
19 | ); 20 | } 21 | -------------------------------------------------------------------------------- /main/src/components/ui/hover-card.tsx: -------------------------------------------------------------------------------- 1 | "use client"; 2 | 3 | import * as React from "react"; 4 | import * as HoverCardPrimitive from "@radix-ui/react-hover-card"; 5 | 6 | import {cn} from "~/components/utils/utils"; 7 | 8 | const HoverCard = HoverCardPrimitive.Root; 9 | 10 | const HoverCardTrigger = HoverCardPrimitive.Trigger; 11 | 12 | const HoverCardContent = React.forwardRef< 13 | React.ElementRef, 14 | React.ComponentPropsWithoutRef 15 | >(({className, align = "center", sideOffset = 4, ...props}, ref) => ( 16 | 26 | )); 27 | HoverCardContent.displayName = HoverCardPrimitive.Content.displayName; 28 | 29 | export {HoverCard, HoverCardTrigger, HoverCardContent}; 30 | -------------------------------------------------------------------------------- /main/src/components/ui/slider.tsx: -------------------------------------------------------------------------------- 1 | "use client"; 2 | 3 | import * as React from "react"; 4 | import * as SliderPrimitive from "@radix-ui/react-slider"; 5 | 6 | import {cn} from "~/components/utils/utils"; 7 | 8 | const Slider = React.forwardRef< 9 | React.ElementRef, 10 | React.ComponentPropsWithoutRef 11 | >(({className, ...props}, ref) => ( 12 | 17 | 18 | 19 | 20 | 21 | 22 | )); 23 | Slider.displayName = SliderPrimitive.Root.displayName; 24 | 25 | export {Slider}; 26 | -------------------------------------------------------------------------------- /main/src/components/ui/switch.tsx: -------------------------------------------------------------------------------- 1 | "use client"; 2 | 3 | import * as React from "react"; 4 | import * as SwitchPrimitives from "@radix-ui/react-switch"; 5 | 6 | import {cn} from "~/components/utils/utils"; 7 | 8 | const Switch = React.forwardRef< 9 | React.ElementRef, 10 | React.ComponentPropsWithoutRef 11 | >(({className, 
...props}, ref) => ( 12 | 20 | 25 | 26 | )); 27 | Switch.displayName = SwitchPrimitives.Root.displayName; 28 | 29 | export {Switch}; 30 | -------------------------------------------------------------------------------- /main/src/components/user-avatar.tsx: -------------------------------------------------------------------------------- 1 | export default function UserAvatar({name}: {name?: string}) { 2 | const nameWithDefault = name || ""; 3 | 4 | return ( 5 |
6 |
7 | {nameWithDefault.charAt(0).toUpperCase()} 8 |
9 |
10 | ); 11 | } 12 | -------------------------------------------------------------------------------- /main/src/components/utils/utils.ts: -------------------------------------------------------------------------------- 1 | import clsx from "clsx"; 2 | import {twMerge} from "tailwind-merge"; 3 | 4 | import type {ClassValue} from "clsx"; 5 | 6 | export function cn(...inputs: ClassValue[]) { 7 | return twMerge(clsx(inputs)); 8 | } 9 | -------------------------------------------------------------------------------- /main/src/components/warning.tsx: -------------------------------------------------------------------------------- 1 | import {ExclamationTriangleIcon} from "@heroicons/react/24/outline"; 2 | 3 | interface Props { 4 | message: string; 5 | } 6 | 7 | export default function Warning(props: Props) { 8 | return ( 9 |
10 |
11 |
13 |
{props.message}
14 |
15 | ); 16 | } 17 | -------------------------------------------------------------------------------- /main/src/constants/modal.ts: -------------------------------------------------------------------------------- 1 | import type {modelsToFinetune} from "./models"; 2 | 3 | export const inferenceEndpoints: Readonly> = Object.freeze({ 4 | "HuggingFaceH4/zephyr-7b-beta": 5 | "https://havenhq--lora-server-huggingfaceh4-zephyr-7b-beta-model--b80943.modal.run/generate_stream", 6 | "meta-llama/Llama-2-7b-chat-hf": 7 | "https://havenhq--lora-server-meta-llama-llama-2-7b-chat-hf-model-4eee5b.modal.run/generate_stream", 8 | "mistralai/Mixtral-8x7b-Instruct-v0.1": 9 | "https://havenhq--lora-server-mistralai-mixtral-8x7b-instruct-v0--9fe7b9.modal.run/generate_stream", 10 | }); 11 | 12 | export const exportEndpoint = "https://havenhq--model-export-export.modal.run"; 13 | 14 | export const trainEndpoint: Readonly> = { 15 | "HuggingFaceH4/zephyr-7b-beta": "https://havenhq--finetuning-service-train.modal.run", 16 | "meta-llama/Llama-2-7b-chat-hf": "https://havenhq--finetuning-service-train.modal.run", 17 | "mistralai/Mixtral-8x7b-Instruct-v0.1": "https://havenhq--mixtral-finetuning-train.modal.run", 18 | }; 19 | -------------------------------------------------------------------------------- /main/src/constants/models.ts: -------------------------------------------------------------------------------- 1 | // Exporting the selection of models both to the client side and the server side code. 
2 | export const defaultModelLoopup = Object.freeze({ 3 | "mixtral-8x7b": "mistralai/Mixtral-8x7b-Instruct-v0.1", 4 | zephyr: "HuggingFaceH4/zephyr-7b-beta", 5 | "llama2-7b": "meta-llama/Llama-2-7b-chat-hf", 6 | }); 7 | 8 | export const modelsToFinetune = Object.values(defaultModelLoopup); 9 | export const modelIds = Object.keys(defaultModelLoopup); 10 | 11 | export type Models = (typeof modelsToFinetune)[number]; 12 | -------------------------------------------------------------------------------- /main/src/env.mjs: -------------------------------------------------------------------------------- 1 | import {createEnv} from "@t3-oss/env-nextjs"; 2 | import {z} from "zod"; 3 | 4 | export const env = createEnv({ 5 | /** 6 | * Specify your server-side environment variables schema here. This way you can ensure the app 7 | * isn't built with invalid env vars. 8 | */ 9 | server: { 10 | DATABASE_URL: z 11 | .string() 12 | .url() 13 | .refine((str) => !str.includes("YOUR_MYSQL_URL_HERE"), "You forgot to change the default URL"), 14 | NODE_ENV: z.enum(["development", "test", "production"]).default("development"), 15 | NEXTAUTH_SECRET: process.env.NODE_ENV === "production" ? z.string() : z.string().optional(), 16 | NEXTAUTH_URL: z.preprocess( 17 | // This makes Vercel deployments not fail if you don't set NEXTAUTH_URL 18 | // Since NextAuth.js automatically uses the VERCEL_URL if present. 19 | (str) => process.env.VERCEL_URL ?? str, 20 | // VERCEL_URL doesn't include `https` so it cant be validated as a URL 21 | process.env.VERCEL ? 
z.string() : z.string().url(), 22 | ), 23 | // Add ` on ID and SECRET if you want to make sure they're not empty 24 | GOOGLE_CLIENT_ID: z.string(), 25 | GOOGLE_CLIENT_SECRET: z.string(), 26 | 27 | STRIPE_SECRET_KEY: z.string(), 28 | STRIPE_PUBLISHABLE_KEY: z.string(), 29 | 30 | WEBHOOK_SECRET: z.string(), 31 | MODAL_HOSTNAME: z.string(), 32 | MODAL_AUTH_TOKEN: z.string(), 33 | RESEND_KEY: z.string(), 34 | LOGTAIL_KEY: z.string(), 35 | LOGTAIL_ENV: z.string(), 36 | }, 37 | 38 | /** 39 | * Specify your client-side environment variables schema here. This way you can ensure the app 40 | * isn't built with invalid env vars. To expose them to the client, prefix them with 41 | * `NEXT_PUBLIC_`. 42 | */ 43 | client: { 44 | // NEXT_PUBLIC_CLIENTVAR: z.string(), 45 | }, 46 | 47 | /** 48 | * You can't destruct `process.env` as a regular object in the Next.js edge runtimes (e.g. 49 | * middlewares) or client-side so we need to destruct manually. 50 | */ 51 | runtimeEnv: { 52 | DATABASE_URL: process.env.DATABASE_URL, 53 | NODE_ENV: process.env.NODE_ENV, 54 | NEXTAUTH_SECRET: process.env.NEXTAUTH_SECRET, 55 | NEXTAUTH_URL: process.env.NEXTAUTH_URL, 56 | GOOGLE_CLIENT_ID: process.env.GOOGLE_CLIENT_ID, 57 | GOOGLE_CLIENT_SECRET: process.env.GOOGLE_CLIENT_SECRET, 58 | STRIPE_SECRET_KEY: process.env.STRIPE_SECRET_KEY, 59 | STRIPE_PUBLISHABLE_KEY: process.env.STRIPE_PUBLISHABLE_KEY, 60 | WEBHOOK_SECRET: process.env.WEBHOOK_SECRET, 61 | MODAL_HOSTNAME: process.env.MODAL_HOSTNAME, 62 | MODAL_AUTH_TOKEN: process.env.MODAL_AUTH_TOKEN, 63 | RESEND_KEY: process.env.RESEND_KEY, 64 | LOGTAIL_KEY: process.env.LOGTAIL_KEY, 65 | LOGTAIL_ENV: process.env.LOGTAIL_ENV, 66 | }, 67 | /** 68 | * Run `build` or `dev` with `SKIP_ENV_VALIDATION` to skip env validation. This is especially 69 | * useful for Docker builds. 70 | */ 71 | skipValidation: !!process.env.SKIP_ENV_VALIDATION, 72 | /** 73 | * Makes it so that empty strings are treated as undefined. 
74 | * `SOME_VAR: z.string()` and `SOME_VAR=''` will throw an error. 75 | */ 76 | emptyStringAsUndefined: true, 77 | }); 78 | -------------------------------------------------------------------------------- /main/src/pages/_app.tsx: -------------------------------------------------------------------------------- 1 | // pages/_app.tsx 2 | import {SessionProvider} from "next-auth/react"; 3 | import "../styles/globals.css"; 4 | 5 | function MyApp({Component, pageProps}: any) { 6 | return ( 7 |
8 |
9 | 10 | 11 | 12 |
13 |
14 | ); 15 | } 16 | 17 | export default MyApp; 18 | -------------------------------------------------------------------------------- /main/src/pages/_document.tsx: -------------------------------------------------------------------------------- 1 | import Document, {Html, Head, Main, NextScript} from "next/document"; 2 | 3 | class MyDocument extends Document { 4 | render() { 5 | return ( 6 | 7 | 8 | 9 |
10 | 11 | 12 | 13 | ); 14 | } 15 | } 16 | 17 | export default MyDocument; 18 | -------------------------------------------------------------------------------- /main/src/pages/auth/login/index.tsx: -------------------------------------------------------------------------------- 1 | import Button from "~/components/form/button"; 2 | import TextField from "~/components/form/textfield"; 3 | import {getCsrfToken, signIn} from "next-auth/react"; 4 | import type {GetServerSidePropsContext, InferGetServerSidePropsType} from "next"; 5 | import {useRouter} from "next/router"; 6 | import Image from "next/image"; 7 | import {getServerAuthSession} from "~/server/auth"; 8 | import {useEffect} from "react"; 9 | 10 | export default function SignIn({csrfToken, session}: InferGetServerSidePropsType) { 11 | const router = useRouter(); 12 | 13 | useEffect(() => { 14 | if (session.expires) { 15 | void router.push("/models"); 16 | } 17 | }); 18 | 19 | return ( 20 | <> 21 |
22 |
23 |
24 | Your Company 25 |

26 | Log in or create an account 27 |

28 |
29 | 30 |
31 |
32 |
33 | 34 | 35 | 36 | 37 |
38 | 41 |
42 | 43 | 44 |
45 |
46 | 53 |
54 | 72 |
73 |
74 | {"By signing in, you agree to Haven's "} 75 | 76 | terms of service 77 | 78 | , and{" "} 79 | 80 | privacy policy 81 | 82 | . 83 |
84 |
85 |
86 |
87 |
88 |
89 | 90 | ); 91 | } 92 | 93 | export async function getServerSideProps(context: GetServerSidePropsContext) { 94 | const csrfToken = await getCsrfToken(context); 95 | const session = await getServerAuthSession(context); 96 | 97 | return { 98 | props: { 99 | csrfToken, 100 | session: { 101 | expires: session?.expires || "", 102 | }, 103 | }, 104 | }; 105 | } 106 | -------------------------------------------------------------------------------- /main/src/pages/auth/new-user/index.tsx: -------------------------------------------------------------------------------- 1 | import {useRouter} from "next/router"; 2 | import {useEffect} from "react"; 3 | 4 | export default function Redirect() { 5 | const router = useRouter(); 6 | 7 | useEffect(() => { 8 | // Redirect to /models 9 | void router.push("/models"); 10 | }); 11 | 12 | return null; 13 | } 14 | -------------------------------------------------------------------------------- /main/src/pages/auth/verify-request/index.tsx: -------------------------------------------------------------------------------- 1 | import Image from "next/image"; 2 | 3 | export default function Login() { 4 | return ( 5 | <> 6 |
7 |
8 |
9 | Your Company 16 |
17 | 18 |
19 |
20 | Please open the link we just sent to your email address to be logged in. 21 |
22 |
23 |
24 |
25 | 26 | ); 27 | } 28 | -------------------------------------------------------------------------------- /main/src/server/auth.ts: -------------------------------------------------------------------------------- 1 | import {PrismaAdapter} from "@next-auth/prisma-adapter"; 2 | import {type GetServerSidePropsContext} from "next"; 3 | import {getServerSession, type DefaultSession, type NextAuthOptions} from "next-auth"; 4 | import EmailProvider from "next-auth/providers/email"; 5 | import GoogleProvider from "next-auth/providers/google"; 6 | 7 | import {db} from "~/server/database"; 8 | import {addApiKeyToUser} from "./database/user"; 9 | import type {User} from "@prisma/client"; 10 | import {sendVerifyEmail} from "./utils/mail"; 11 | import {EventName, sendEvent} from "./utils/observability/posthog"; 12 | 13 | /** 14 | * Module augmentation for `next-auth` types. Allows us to add custom properties to the `session` 15 | * object and keep type safety. 16 | * 17 | * @see https://next-auth.js.org/getting-started/typescript#module-augmentation 18 | */ 19 | declare module "next-auth" { 20 | interface Session extends DefaultSession { 21 | user: DefaultSession["user"] & { 22 | id: string; 23 | centsBalance: number; 24 | stripeCustomerId?: string; 25 | stripePaymentMethodId?: string; 26 | apiKey: string; 27 | hfToken?: string; 28 | image?: string; 29 | }; 30 | } 31 | 32 | // interface User { 33 | // // ...other properties 34 | // // role: UserRole; 35 | // } 36 | } 37 | 38 | /** 39 | * Options for NextAuth.js used to configure adapters, providers, callbacks, etc. 
40 | * 41 | * @see https://next-auth.js.org/configuration/options 42 | */ 43 | export const authOptions: NextAuthOptions = { 44 | secret: process.env.NEXTAUTH_SECRET, 45 | callbacks: { 46 | session: ({session, user}) => ({ 47 | ...session, 48 | user: { 49 | ...session.user, 50 | id: user.id, 51 | centsBalance: (user as User).centsBalance, 52 | stripeCustomerId: (user as User).stripeCustomerId, 53 | stripePaymentMethodId: (user as User).stripePaymentMethodId, 54 | apiKey: (user as User).apiKey, 55 | hfToken: (user as User).hfToken, 56 | image: (user as User).image, 57 | }, 58 | }), 59 | }, 60 | adapter: PrismaAdapter(db), 61 | providers: [ 62 | EmailProvider({ 63 | sendVerificationRequest: async ({identifier, url}) => { 64 | await sendVerifyEmail(identifier, url); 65 | }, 66 | }), 67 | GoogleProvider({ 68 | clientId: process.env.GOOGLE_CLIENT_ID ?? "", 69 | clientSecret: process.env.GOOGLE_CLIENT_SECRET ?? "", 70 | }), 71 | /** 72 | * ...add more providers here. 73 | * 74 | * Most other providers require a bit more work than the Discord provider. For example, the 75 | * GitHub provider requires you to add the `refresh_token_expires_in` field to the Account 76 | * model. Refer to the NextAuth.js docs for the provider you want to use. Example: 77 | * 78 | * @see https://next-auth.js.org/providers/github 79 | */ 80 | ], 81 | session: {strategy: "database"}, 82 | pages: { 83 | signIn: "/auth/login", 84 | verifyRequest: "/auth/verify-request", 85 | newUser: "/auth/new-user", 86 | }, 87 | events: { 88 | createUser: async (message) => { 89 | sendEvent(message.user.id, EventName.NEW_USER, {email: message.user.email}); 90 | 91 | // Add an API key to the user when they sign up. 92 | await addApiKeyToUser(message.user.id); 93 | }, 94 | }, 95 | }; 96 | 97 | /** 98 | * Wrapper for `getServerSession` so that you don't need to import the `authOptions` in every file. 
99 | * 100 | * @see https://next-auth.js.org/configuration/nextjs 101 | */ 102 | export const getServerAuthSession = (ctx: { 103 | req: GetServerSidePropsContext["req"]; 104 | res: GetServerSidePropsContext["res"]; 105 | }) => { 106 | return getServerSession(ctx.req, ctx.res, authOptions); 107 | }; 108 | -------------------------------------------------------------------------------- /main/src/server/controller/new-dataset.ts: -------------------------------------------------------------------------------- 1 | import {promises as fs} from "fs"; 2 | import {v4 as uuid} from "uuid"; 3 | import {checkFileValidity} from "./process-dataset"; 4 | import {createDataset} from "../database/dataset"; 5 | import {uploadFile} from "../utils/modal"; 6 | 7 | import type {FormDataType} from "~/app/datasets/new/actions"; 8 | import {logger} from "../utils/observability/logtail"; 9 | 10 | export async function uploadDataset(userId: string, validatedForm: FormDataType) { 11 | const datasetText = await (validatedForm.dropzoneFile as File).text(); 12 | const parsed = checkFileValidity(datasetText); 13 | 14 | if (parsed.length < 20) { 15 | logger.error("[uploadDataset] Dataset has less than 20 rows."); 16 | throw new Error("A dataset should have at least 20 rows."); 17 | } 18 | 19 | // Write dataset to temporary path 20 | const fileName = `${uuid()}.json`; 21 | const localPath = `./tmp/${fileName}`; 22 | 23 | await fs.mkdir("./tmp/").catch(() => {}); 24 | await fs.writeFile(localPath, datasetText); 25 | 26 | // Upload dataset 27 | await uploadFile(datasetText, fileName).catch((e) => { 28 | logger.error("[uploadDataset] Could not upload file.", {error: e}); 29 | throw new Error("Could not upload file."); 30 | }); 31 | 32 | // Create new database entry 33 | const dataset = await createDataset(userId, validatedForm.name, fileName, parsed.length); 34 | return dataset.id; 35 | } 36 | -------------------------------------------------------------------------------- 
/main/src/server/controller/new-model.ts: -------------------------------------------------------------------------------- 1 | import {createModel} from "~/server/database/model"; 2 | 3 | import {downloadFile} from "~/server/utils/modal"; 4 | import {decreaseBalance} from "~/server/database/user"; 5 | import {EventName, sendEvent} from "~/server/utils/observability/posthog"; 6 | import {checkFileValidity, processDataset} from "./process-dataset"; 7 | import {getDatasetById} from "../database/dataset"; 8 | import {logger} from "../utils/observability/logtail"; 9 | import {trainEndpoint} from "~/constants/modal"; 10 | 11 | import type {Session} from "next-auth"; 12 | import type {FormDataType} from "~/app/models/new/actions"; 13 | import type {Models} from "~/constants/models"; 14 | 15 | /** 16 | * Validate the file and calculate the price. 17 | */ 18 | async function validate(userId: string, datasetId: string, baseModel: Models, numberOfEpochs: number) { 19 | const datasetDb = await getDatasetById(datasetId, userId); 20 | if (!datasetDb) { 21 | logger.error("Dataset not found."); 22 | throw new Error("Dataset not found."); 23 | } 24 | 25 | console.log(`dataset exists, now downloading file ${datasetDb.fileName}`); 26 | 27 | const dataset = await downloadFile(datasetDb.fileName); 28 | 29 | console.log(`file downloaded, now checking validity`); 30 | 31 | const processedDataset = checkFileValidity(dataset); 32 | 33 | console.log(`file valid, now processing`); 34 | 35 | // Extract relevant metadata 36 | const metaData = await processDataset(processedDataset, baseModel, numberOfEpochs); 37 | 38 | console.log(`file processed, now returning`); 39 | 40 | return { 41 | ...metaData, 42 | datasetFileName: datasetDb.fileName, 43 | }; 44 | } 45 | 46 | /** 47 | * Calculate the price of the training job. 
48 | * @param session 49 | * @param validatedForm 50 | */ 51 | export async function calculatePrice(session: Session, validatedForm: FormDataType) { 52 | console.log(`calculating price for ${validatedForm.baseModel}`); 53 | const {priceInCents} = await validate( 54 | session.user.id, 55 | validatedForm.datasetId, 56 | validatedForm.baseModel, 57 | validatedForm.numberOfEpochs, 58 | ); 59 | sendEvent(session.user.id, EventName.FINE_TUNE_PRICE_CALCULATED, {priceInCents, baseModel: validatedForm.baseModel}); 60 | return priceInCents; 61 | } 62 | 63 | /** 64 | * Create a new model training job. 65 | * @param session 66 | * @param validatedForm 67 | * @param confirmedPrice 68 | */ 69 | export async function createNewModelTraining(session: Session, validatedForm: FormDataType, confirmedPrice?: number) { 70 | const {maxTokens, gradientAccumulationSteps, perDeviceTrainBatchSize, priceInCents, datasetFileName} = await validate( 71 | session.user.id, 72 | validatedForm.datasetId, 73 | validatedForm.baseModel, 74 | validatedForm.numberOfEpochs, 75 | ); 76 | 77 | // Check that the price is correct 78 | if (priceInCents !== confirmedPrice) { 79 | throw new Error("Confirmed price does not match calculated price. 
Please try again."); 80 | } 81 | 82 | // Check that the user has enough money 83 | if (session.user.centsBalance < priceInCents) { 84 | throw new Error("You do not have enough money to train this model."); 85 | } 86 | 87 | // Update user balance 88 | await decreaseBalance(session.user.id, priceInCents, `Model ${validatedForm.name}`); 89 | 90 | // Update stuff in the database 91 | const model = await createModel( 92 | session.user.id, 93 | validatedForm.name, 94 | priceInCents, 95 | validatedForm.datasetId, 96 | validatedForm.learningRate, 97 | validatedForm.numberOfEpochs, 98 | validatedForm.baseModel, 99 | ); 100 | 101 | const pathToSend = `/datasets/${datasetFileName}`; 102 | 103 | // TODO: fix on Modal's side 104 | const model_name = 105 | validatedForm.baseModel === "mistralai/Mixtral-8x7b-Instruct-v0.1" 106 | ? "/pretrained_models/models--mistralai--Mixtral-8x7B-Instruct-v0.1/snapshots/f1ca00645f0b1565c7f9a1c863d2be6ebf896b04" 107 | : validatedForm.baseModel; 108 | 109 | // Send request to create new job 110 | await fetch(trainEndpoint[validatedForm.baseModel], { 111 | method: "POST", 112 | headers: { 113 | "Content-Type": "application/json", 114 | }, 115 | body: JSON.stringify({ 116 | wandb_token: "", // TODO: let user provide this 117 | learning_rate: validatedForm.learningRate, 118 | num_epochs: validatedForm.numberOfEpochs, 119 | model_name, 120 | model_id: model.id, 121 | dataset_name: pathToSend, 122 | hf_repo: validatedForm.name, 123 | max_tokens: maxTokens, 124 | gradient_accumulation_steps: gradientAccumulationSteps, 125 | per_device_train_batch_size: perDeviceTrainBatchSize, 126 | auth_token: process.env.MODAL_AUTH_TOKEN, 127 | }), 128 | }); 129 | 130 | sendEvent(session.user.id, EventName.FINE_TUNE_STARTED, {priceInCents, baseModel: validatedForm.baseModel}); 131 | } 132 | -------------------------------------------------------------------------------- /main/src/server/controller/process-dataset.ts: 
-------------------------------------------------------------------------------- 1 | import * as y from "yup"; 2 | import {AutoTokenizer} from "@xenova/transformers"; 3 | import {createRepo, deleteRepo} from "@huggingface/hub"; 4 | import {logger} from "~/server/utils/observability/logtail"; 5 | 6 | import type {PreTrainedTokenizer} from "@xenova/transformers"; 7 | import type {Models} from "~/constants/models"; 8 | 9 | const fileSchema = y.object({ 10 | messages: y 11 | .array( 12 | y 13 | .object({ 14 | role: y.string().oneOf(["system", "user", "assistant"]).required(), 15 | content: y.string().required(), 16 | }) 17 | .required(), 18 | ) 19 | .required(), 20 | }); 21 | 22 | function validateMessages(messages: y.InferType["messages"]): void { 23 | if (!messages[0]) { 24 | logger.info("[validateMessages] No messages found"); 25 | throw new Error("There must be at least two messages."); 26 | } 27 | 28 | if (messages[0].role !== "system") { 29 | logger.info("[validateMessages] First message is not from system"); 30 | throw new Error("The first message has to have the role 'system'."); 31 | } 32 | 33 | let lastRole = "system"; 34 | 35 | for (const message of messages.slice(1)) { 36 | if (message.role !== "user" && message.role !== "assistant") { 37 | logger.info("[validateMessages] Invalid role", {role: message.role}); 38 | throw new Error('Invalid role. 
Role must be either "user" or "assistant".');
		}

		if (message.role === lastRole) {
			logger.info("[validateMessages] Consecutive messages must have different roles");
			throw new Error("Consecutive messages must have different roles.");
		}

		lastRole = message.role;
	}

	if (lastRole !== "assistant") {
		logger.info("[validateMessages] Last message is not from assistant");
		throw new Error('The last message must be from "assistant".');
	}
}

/**
 * Counts the tokens of every sample in `dataset` with the tokenizer of `model`.
 * @param dataset - Validated dataset rows.
 * @param model - Model id handed to `AutoTokenizer.from_pretrained`.
 * @returns One token count per row, in the same order as `dataset`.
 */
async function getTextLengths(dataset: y.InferType<typeof fileSchema>[], model: string) {
	const tokenizer = await AutoTokenizer.from_pretrained(model);
	const sampleLengths = await Promise.all(dataset.map((data) => getSampleLength(data, tokenizer)));

	return sampleLengths;
}

/**
 * Returns the token count of one sample: the contents of all its messages joined with a single space.
 * Roles are not part of the tokenized text.
 */
async function getSampleLength(
	messageObject: y.InferType<typeof fileSchema>,
	tokenizer: PreTrainedTokenizer,
): Promise<number> {
	const combinedContent = messageObject.messages.map((message) => message.content).join(" ");
	const tokenizedContent: {input_ids: {size: number}} = await tokenizer(combinedContent);

	return tokenizedContent.input_ids.size;
}

/**
 * Fine-tuning price in cents.
 * The token count is rounded up to whole thousands and charged at 0.4 cents per thousand tokens per epoch
 * (0.8 for Mixtral), plus a base fee that is currently 0.
 * NOTE(review): the result can be fractional (e.g. 1 thousand tokens x 1 epoch = 0.4); confirm callers round
 * before charging or storing it.
 */
function calculatePrice(model: Models, totalTokens: number, epochs: number) {
	const baseFeeCents = 0;

	let centsPerThousandTokens = 0.4;
	if (model === "mistralai/Mixtral-8x7b-Instruct-v0.1") {
		centsPerThousandTokens = 0.8;
	}

	const thousands = Math.ceil(totalTokens / 1000);
	return baseFeeCents + thousands * centsPerThousandTokens * epochs;
}

/**
 * Picks training batch settings from the longest sample length (in tokens).
 * Every branch keeps gradientAccumulationSteps * perDeviceTrainBatchSize = 16, so longer samples trade
 * per-device batch size for more accumulation steps. Mixtral always uses (16, 1) with a 2000-token cap;
 * all other models use a 2600-token cap.
 */
function calculateBatchSize(model: Models, longestTextLength: number) {
	function settings(gradientAccumulationSteps: number, perDeviceTrainBatchSize: number, maxTokens: number) {
		return {
			gradientAccumulationSteps,
			perDeviceTrainBatchSize,
			maxTokens,
		};
	}

	const isMixtral = model === "mistralai/Mixtral-8x7b-Instruct-v0.1";

	if (isMixtral) {
		return settings(16, 1, 2000);
	}

	if (longestTextLength < 500) {
		return settings(2, 8, 2600);
	}

	if (longestTextLength < 900) {
		return settings(4, 4, 2600);
	}

	if (longestTextLength < 1400) {
		return settings(8, 2, 2600);
	}

	return settings(16, 1, 2600);
}

/**
 * Makes sure a given dataset is valid.
 *
 * The content is JSONL: each non-blank line is parsed as JSON, validated against `fileSchema`, and its
 * messages are checked by `validateMessages`. Blank or whitespace-only lines (including the lone "\r" that
 * CRLF line endings leave behind) are skipped.
 * @param fileContent - Raw file content.
 * @returns - The validated and parsed dataset, in file order.
 * @throws The first parse/validation error, after logging it together with the offending line.
 */
export function checkFileValidity(fileContent: string) {
	const validatedDataset: y.InferType<typeof fileSchema>[] = [];
	for (const line of fileContent.split("\n")) {
		if (!line.trim()) {
			continue;
		}

		try {
			const jsonData = JSON.parse(line);

			const validated = fileSchema.validateSync(jsonData);
			validateMessages(validated.messages);

			validatedDataset.push(validated);
		} catch (e) {
			logger.error("Invalid JSON structure", {line, error: e});
			throw e;
		}
	}

	return validatedDataset;
}

/**
 * Calculates gradient accumulation steps, batch size, token cap and price for a given dataset.
 * @param fileContent - Validated dataset rows (the shape `checkFileValidity` returns).
 * @param modelName - Base model; its tokenizer counts the tokens and it selects pricing and batch settings.
 * @param epochs - Number of training epochs; the price scales linearly with it.
 * @returns `gradientAccumulationSteps`, `perDeviceTrainBatchSize`, `maxTokens` and `priceInCents`.
 */
export async function processDataset(
	fileContent: y.InferType<typeof fileSchema>[],
	modelName: Models,
	epochs: number,
) {
	const sampleLengths = await getTextLengths(fileContent, modelName);
	const totalTokens = sampleLengths.reduce((a, b) => a + b, 0);
	const priceInCents = calculatePrice(modelName, totalTokens, epochs);

	// `Math.max(...sampleLengths)` passes every sample as a separate argument and can throw a RangeError
	// for large datasets; a reduce has no argument-count limit. An empty dataset now yields 0 instead of
	// -Infinity, and both select the same batch settings.
	const maxLength = sampleLengths.reduce((longest, length) => Math.max(longest, length), 0);
	const {maxTokens, gradientAccumulationSteps, perDeviceTrainBatchSize} = calculateBatchSize(modelName, maxLength);

	return {
		gradientAccumulationSteps,
		perDeviceTrainBatchSize,
		maxTokens,
		priceInCents,
	};
}

/**
 * Checks if a huggingface repo name is available by trying to create and delete it.
 * NOTE(review): any failure, including a failed delete after a successful create, is reported as
 * "already taken".
 * @param hfToken - Hugging Face access token used for both calls.
 * @param name - shape: "username/repo-name"
 * @throws Error("This repo name is already taken.") when either call fails.
 */
export async function checkRepoNameAvailability(hfToken: string, name: string) {
	try {
		await createRepo({
			repo: {
				name,
				type: "model",
			},
			credentials: {
				accessToken: hfToken,
			},
			private: true,
		});
		await deleteRepo({
			repo: {
				name,
				type: "model",
			},
			credentials: {
				accessToken: hfToken,
			},
		});
	} catch (e) {
		logger.error("Repo name already taken", {name, error: e});
		throw new Error("This repo name is already taken.");
	}
}

--------------------------------------------------------------------------------
/main/src/server/controller/stripe.ts:
--------------------------------------------------------------------------------
import Stripe from "stripe";

const stripeSecretKey = process.env.STRIPE_SECRET_KEY ?? "";

/**
 * How setting up a payment method with Stripe works:
 *
 * 1. Client sends credit card info to Stripe
 * 2.
Stripe returns a payment method ID
 * 3. Client sends payment method ID to server
 * 4. Server creates a Setup Intent with the payment method ID
 * 5. Server returns the client secret of the setup intent to the client
 * 6. Client uses the client secret to authenticate the payment method, sometimes with 3D secure
 * 7. Client sends the payment method ID and intent id back to the server, confirming that the payment method is authenticated
 * 8. Server validates that the payment method is authenticated
 * 9. Server creates a Stripe customer with the payment method ID
 */

const stripe = new Stripe(stripeSecretKey, {
	apiVersion: "2023-10-16",
	appInfo: {
		name: "Haven NextJS",
		version: "0.0.1",
	},
});

/**
 * Step 4 of setting up a payment method with Stripe: creates a SetupIntent for `paymentMethodId`.
 */
export async function createSetupIntent(paymentMethodId: string) {
	return stripe.setupIntents.create({
		payment_method: paymentMethodId,
	});
}

/**
 * Step 8. Returns whether the SetupIntent behind `setupIntentClientSecret` has status "succeeded".
 * The SetupIntent id is taken as the part of the client secret before "_secret".
 * @throws Error("Invalid setup intent secret") when that part is empty.
 */
export async function validateSetupIntent(setupIntentClientSecret: string) {
	const setupIntentId = setupIntentClientSecret.split("_secret")[0];
	if (!setupIntentId) {
		throw new Error("Invalid setup intent secret");
	}

	const setupIntent = await stripe.setupIntents.retrieve(setupIntentId);
	return setupIntent && setupIntent.status === "succeeded";
}

/**
 * Step 9. Attaches `paymentMethodId` to a Stripe customer and returns that customer.
 * Reuses `customerId` when given (throws if Stripe reports it deleted); otherwise creates a new customer
 * from `name` and `email`.
 */
export async function createStripeCustomerAndAttachPaymentMethod(
	name: string,
	email: string,
	paymentMethodId: string,
	customerId?: string,
) {
	let customer: Stripe.Response<Stripe.Customer>;

	if (customerId) {
		const existingCustomer = await stripe.customers.retrieve(customerId);

		// TODO: We should automatically recover from this at some point.
		if (!existingCustomer || existingCustomer.deleted) {
			throw new Error("Customer used to exist but was deleted.");
		}

		customer = existingCustomer;
	} else {
		customer = await stripe.customers.create({
			name,
			email,
		});
	}

	await stripe.paymentMethods.attach(paymentMethodId, {customer: customer.id});

	return customer;
}

/**
 * Returns the last four digits and the expiry ("month/year") of the customer's first listed card.
 * @throws If the customer is missing or deleted, if no card payment method exists, or if Stripe returns
 *   the payment method without card details.
 */
export async function getStripeCreditCardInfo(customerId: string) {
	const customer = await stripe.customers.retrieve(customerId);
	if (!customer || customer.deleted) {
		throw new Error("Stripe customer not found");
	}

	// The list is filtered to type "card", so every entry is a card payment method.
	const paymentMethods = await stripe.paymentMethods.list({
		customer: customerId,
		type: "card",
	});

	const paymentMethod = paymentMethods.data[0];
	if (!paymentMethod) {
		throw new Error("No payment methods found");
	}

	// `card` is optional in the typings; fail with a clear message instead of a TypeError on `card.last4`.
	const card = paymentMethod.card;
	if (!card) {
		throw new Error("Payment method has no card details");
	}

	return {
		last4: card.last4,
		expiry: `${card.exp_month}/${card.exp_year}`,
	};
}

/**
 * Creates and immediately confirms an off-session USD PaymentIntent for the given customer and payment method.
 * `amount` is forwarded unchanged (presumably cents, like the rest of the billing code — confirm at call sites).
 */
export function createPaymentIntent(customerId: string, paymentMethodId: string, amount: number, description: string) {
	return stripe.paymentIntents.create({
		customer: customerId,
		amount,
		currency: "usd",
		description,
		off_session: true,
		confirm: true,
		payment_method: paymentMethodId,
	});
}

--------------------------------------------------------------------------------
/main/src/server/database/chat-request.ts:
--------------------------------------------------------------------------------
import {db} from ".";

/**
 * Returns the number of chat requests the user has made in the last 24 hours.
*/
export async function getNumberOfChatRequestsInLast24Hours(userId: string) {
	const oneDayInMs = 24 * 60 * 60 * 1000;
	const windowStart = new Date(Date.now() - oneDayInMs);

	return db.chatRequest.count({
		where: {userId, createdAt: {gte: windowStart}},
	});
}

/**
 * Inserts a chat request row linking `userId` and `modelId`.
 */
export async function createChatRequest(userId: string, modelId: string) {
	const data = {userId, modelId};
	return db.chatRequest.create({data});
}

--------------------------------------------------------------------------------
/main/src/server/database/dataset.ts:
--------------------------------------------------------------------------------
import {db} from ".";

/**
 * Inserts a dataset record owned by `userId`.
 */
export async function createDataset(userId: string, name: string, fileName: string, rows: number) {
	const data = {userId, name, fileName, rows};
	return db.dataset.create({data});
}

/**
 * Lists all datasets of `userId`, newest first.
 */
export async function getDatasets(userId: string) {
	return db.dataset.findMany({
		where: {userId},
		orderBy: {createdAt: "desc"},
	});
}

/**
 * Returns up to five of the user's datasets whose name contains `name`, newest first.
 * Note: substring (`contains`) search; this might become slow as the table grows.
 */
export async function getDatasetsByNameForUser(userId: string, name: string) {
	const maxResults = 5;
	return db.dataset.findMany({
		where: {userId, name: {contains: name}},
		orderBy: {createdAt: "desc"},
		take: maxResults,
	});
}

/**
 * Fetches the dataset with `id`, matched only if it also belongs to `userId`.
 */
export async function getDatasetById(id: string, userId: string) {
	return db.dataset.findUnique({where: {id, userId}});
}

--------------------------------------------------------------------------------
/main/src/server/database/index.ts:
--------------------------------------------------------------------------------
import { PrismaClient } from "@prisma/client";

import { env } from "~/env.mjs";

const globalForPrisma = globalThis as unknown as {
	prisma: PrismaClient | undefined;
};

// Reuse a client cached on globalThis when one exists; otherwise create one. Query logging is enabled
// only in development.
export const db =
	globalForPrisma.prisma ??
	new PrismaClient({
		log: env.NODE_ENV === "development" ? ["query", "error", "warn"] : ["error"],
	});

// Outside production the client is cached on globalThis (presumably so dev-server module reloads reuse it).
if (env.NODE_ENV !== "production") {
	globalForPrisma.prisma = db;
}

--------------------------------------------------------------------------------
/main/src/server/database/model.ts:
--------------------------------------------------------------------------------
import {db} from ".";

/** Lists all models owned by `userId`. */
export async function getModels(userId: string) {
	return db.model.findMany({where: {userId}});
}

/** Fetches a model by its id. */
export async function getModelFromId(id: string) {
	return db.model.findUnique({where: {id}});
}

/** Inserts a model record with the given training parameters. */
export async function createModel(
	userId: string,
	name: string,
	costInCents: number,
	datasetId: string,
	learningRate: string,
	epochs: number,
	baseModel: string,
) {
	const data = {userId, name, costInCents, datasetId, learningRate, epochs, baseModel};
	return db.model.create({data});
}

/** Stores the Weights & Biases run URL on a model. */
export async function addWandBUrl(modelId: string, wandbUrl: string) {
	return db.model.update({
		where: {id: modelId},
		data: {wandbUrl},
	});
}

/** Sets the training state of a model. */
export async function updateState(modelId: string, state: "training" | "finished" | "error") {
	return db.model.update({
		where: {id: modelId},
		data: {state},
	});
}

/** Returns the stored cost of a model's job, or undefined if the model does not exist. */
export async function getJobCostInCents(modelId: string) {
	const model = await db.model.findUnique({
		where: {id: modelId},
		select: {costInCents: true},
	});
	return model?.costInCents;
}

--------------------------------------------------------------------------------
/main/src/server/database/user.ts:
--------------------------------------------------------------------------------
import {createHash} from "crypto";
import {db} from ".";

import {v4 as uuid} from "uuid";

/**
 * Returns the hex-encoded SHA-256 digest of `value`.
 */
export function hash(value: string) {
	const hasher = createHash("sha256");
	hasher.update(value);
	return hasher.digest("hex");
}

/** Fetches the user with the given id. */
export async function getUserFromId(id: string) {
	return db.user.findUnique({where: {id}});
}

/**
 * Generates a new API key (a v4 UUID), stores it on the user and returns it.
 * Any previously stored key is overwritten.
 */
export async function addApiKeyToUser(userId: string) {
	const apiKey = uuid();
	await db.user.update({where: {id: userId}, data: {apiKey}});
	return apiKey;
}

/** Stores the user's Hugging Face token. */
export async function updateHfToken(userId: string, hfToken: string) {
	await db.user.update({where: {id: userId}, data: {hfToken}});
}

/** Not implemented yet; always throws. */
export function updateEmail(_: string, __: string) {
	// Warning for future Konsti: we might need to update the email on the Stripe customer too
	throw new Error("Not implemented");
}

/** Stores the user's display name. */
export async function updateName(userId: string, name: string) {
	await db.user.update({where: {id: userId}, data: {name}});
}

/**
 * Add stripe customer id and payment method id to the user; returns the updated user.
 */
export async function addStripeCustomerIdAndPaymentMethodToUser(
	id: string,
	stripeCustomerId: string,
	paymentMethodId: string,
) {
	return db.user.update({
		where: {id},
		data: {stripeCustomerId, stripePaymentMethodId: paymentMethodId},
	});
}

/**
 * Increments the user's cents balance by the given amount.
*
 * The ledger row and the balance change are committed together in one Prisma `$transaction`, so a
 * failure cannot leave one written without the other.
 * @param id - User id.
 * @param cents - Amount added to `centsBalance` and recorded on the ledger row.
 * @param message - Ledger reason; "Account top-up" when omitted or empty.
 * @returns The updated user record.
 */
export async function increaseBalance(id: string, cents: number, message?: string) {
	// TODO: move to controller
	const [, user] = await db.$transaction([
		db.transaction.create({
			data: {
				userId: id,
				amount: cents,
				reason: message || "Account top-up",
			},
		}),
		db.user.update({
			where: {
				id: id,
			},
			data: {
				centsBalance: {
					increment: cents,
				},
			},
		}),
	]);

	return user;
}

/**
 * Decrements the user's cents balance by the given amount.
 *
 * The ledger row (amount `-cents`) and the balance change are committed together in one Prisma
 * `$transaction`, so a failure cannot leave one written without the other.
 * @param id - User id.
 * @param cents - Amount subtracted from `centsBalance`.
 * @param reason - Ledger reason.
 * @returns The updated user record.
 */
export async function decreaseBalance(id: string, cents: number, reason: string) {
	// TODO: move to controller
	const [, user] = await db.$transaction([
		db.transaction.create({
			data: {
				userId: id,
				amount: -cents,
				reason,
			},
		}),
		db.user.update({
			where: {
				id: id,
			},
			data: {
				centsBalance: {
					decrement: cents,
				},
			},
		}),
	]);

	return user;
}

--------------------------------------------------------------------------------
/main/src/server/utils/mail.ts:
--------------------------------------------------------------------------------
import {Resend} from "resend";

const resend = new Resend(process.env.RESEND_KEY);

/**
 * Sends an HTML email via Resend from the fixed sender address.
 */
async function sendMail(to: string, subject: string, html: string) {
	// NOTE(review): "do-reply" may be meant as "no-reply" — confirm the intended sender address.
	return resend.emails.send({
		from: "do-reply@haven.run",
		to,
		subject,
		html,
	});
}

/**
 * Emails the sign-in link `url` to `email`.
 */
export async function sendVerifyEmail(email: string, url: string) {
	// The link must be in the message: previously `url` was never used, so the "click here" text had no
	// target and the recipient could not complete the login.
	const message = `Someone is trying to log into Haven using your email address. If it's you, <a href="${url}">click here</a>. If not, you can safely ignore this email or respond to let us know.`;
	await sendMail(email, "Log in to Haven", message);
}

--------------------------------------------------------------------------------
/main/src/server/utils/modal.ts:
--------------------------------------------------------------------------------
const HOSTNAME = process.env.MODAL_HOSTNAME || "";

/**
 * POSTs `payload` as JSON to `${HOSTNAME}${path}` and returns the response body as text.
 * @throws Error(errorMessage) when the response is not ok (status outside 200-299).
 */
async function postJson(path: string, payload: object, errorMessage: string) {
	const response = await fetch(`${HOSTNAME}${path}`, {
		method: "POST",
		body: JSON.stringify(payload),
		headers: {
			"Content-Type": "application/json",
		},
	});

	if (!response.ok) {
		throw new Error(errorMessage);
	}

	return response.text();
}

/**
 * Uploads `file` under `filename` through the modal proxy; resolves to the proxy's response text.
 */
export async function uploadFile(file: string, filename: string) {
	return postJson("/upload", {file, filename}, "Could not upload file.");
}

/**
 * Downloads `fileName` through the modal proxy; resolves to the file content.
 */
export async function downloadFile(fileName: string) {
	return postJson("/download", {fileName}, "Could not download file.");
}

--------------------------------------------------------------------------------
/main/src/server/utils/observability/logtail.ts:
--------------------------------------------------------------------------------
import {Logtail} from "@logtail/node";
import {getServerSession} from "next-auth";
import {authOptions} from "../../auth";

export const logtail = new Logtail(process.env.LOGTAIL_KEY || "");

// Don't send logs to Logtail in development
function ifProd(fn: () => void) {
	if (process.env.LOGTAIL_ENV === "production") {
		fn();
	}
}

interface Data {
	userId?: string;
	[key: string]: any;
}

/**
 * Builds a fire-and-forget logging function for `level`.
 * Each message goes to the console and, when LOGTAIL_ENV is "production", to Logtail. `userId` defaults
 * to the current session's user id when `data` does not supply one.
 */
function createLogger(level: (typeof levels)[number]) {
	return (message: string, data: Data = {}) => {
		// The session is only used to fill in `userId`. If the lookup rejects, log without it instead of
		// dropping the message and leaving an unhandled promise rejection.
		void getServerSession(authOptions)
			.catch(() => null)
			.then((session) => {
				const payload = {
					...data,
					userId: data.userId ? data.userId : session?.user?.id,
				};

				console[level](message, payload);
				ifProd(() => void logtail[level](message, payload));
			});
	};
}

const levels: ["error", "warn", "info", "debug"] = ["error", "warn", "info", "debug"];

export const logger = {
	error: createLogger("error"),
	warn: createLogger("warn"),
	info: createLogger("info"),
	debug: createLogger("debug"),
};

--------------------------------------------------------------------------------
/main/src/server/utils/observability/posthog.ts:
--------------------------------------------------------------------------------
import {PostHog} from "posthog-node";

// NOTE(review): the project API key is hard-coded; consider reading it from the environment.
const client = new PostHog("phc_YpKoFD7smPe4SXRtVyMW766uP9AjUwnuRJ8hh2EJcVv", {host: "https://eu.posthog.com"});

export enum EventName {
	NEW_USER = "new-user",

	CREDIT_CARD_ADDED = "credit-card-added",
	MONEY_ADDED = "money-added",

	FINE_TUNE_PRICE_CALCULATED = "fine-tune-price-calculated",
	FINE_TUNE_STARTED = "fine-tune-started",
	FINE_TUNE_FINISHED = "fine-tune-finished",
	FINE_TUNE_FAILED = "fine-tune-failed",
	FINE_TUNE_WANDB_ADDED = "fine-tune-wandb-added",

	INFERENCE_REQUEST = "inference-request",

	EMERGENCY = "emergency",
}

/**
 * Captures `eventName` for `userId` in PostHog.
 * A no-op unless LOGTAIL_ENV is "production"; errors thrown by `capture` are logged, not rethrown.
 */
export function sendEvent(userId: string, eventName: EventName, eventProperties: object = {}) {
	if (process.env.LOGTAIL_ENV !== "production") {
		return;
	}

	try {
		client.capture({
			distinctId: userId,
			event: eventName,
			properties: eventProperties,
		});
	} catch (e) {
		console.error(e);
	}
}

--------------------------------------------------------------------------------
/main/src/server/utils/session.ts:
--------------------------------------------------------------------------------
import {authOptions} from "../auth";
import {getServerSession} from "next-auth";
import {redirect} from "next/navigation";

/**
 * Retrieves the session; redirects to "/" when there is none (for pages).
 */
export async function checkSession() {
	const session = await getServerSession(authOptions);
	if (!session) {
		redirect("/");
	}

	return session;
}

/**
 * Retrieves the session; throws Error("Unauthorized") when there is none (for server actions).
 */
export async function checkSessionAction() {
	const session = await getServerSession(authOptions);
	if (!session) {
		throw new Error("Unauthorized");
	}
	return session;
}

--------------------------------------------------------------------------------
/main/src/styles/globals.css:
--------------------------------------------------------------------------------
@tailwind base;
@tailwind components;
@tailwind utilities;

@layer base {
	:root {
		--background: 0 0% 100%;
		--foreground: 0 0% 3.9%;
		--card: 0 0% 100%;
		--card-foreground: 0 0% 3.9%;
		--popover: 0 0% 100%;
		--popover-foreground: 0 0% 3.9%;
		--primary: 0 0% 9%;
		--primary-foreground: 0 0% 98%;
		--secondary: 0 0% 96.1%;
		--secondary-foreground: 0 0% 9%;
		--muted: 0 0% 96.1%;
		--muted-foreground: 0 0% 45.1%;
		--accent: 0 0% 96.1%;
		--accent-foreground: 0 0% 9%;
		--destructive: 0 84.2% 60.2%;
		--destructive-foreground: 0 0% 98%;
		--border: 0 0% 89.8%;
		--input: 0 0% 89.8%;
		--ring: 0 0% 3.9%;
		--radius: 0.5rem;
	}

	.dark {
		--background: 0 0% 3.9%;
		--foreground: 0 0% 98%;
		--card: 0 0% 3.9%;
		--card-foreground: 0 0% 98%;
		--popover: 0 0% 3.9%;
		--popover-foreground: 0 0% 98%;
		--primary: 0 0% 98%;
		--primary-foreground: 0 0% 9%;
		--secondary: 0 0% 14.9%;
		--secondary-foreground: 0 0% 98%;
		--muted: 0 0% 14.9%;
		--muted-foreground: 0 0% 63.9%;
		--accent: 0 0% 14.9%;
		--accent-foreground: 0 0% 98%;
		--destructive: 0 62.8% 30.6%;
		--destructive-foreground: 0 0% 98%;
		--border: 0 0% 14.9%;
		--input: 0 0% 14.9%;
		--ring: 0 0% 83.1%;
	}
}

@layer base {
	* {
		@apply border-border;
	}
	body {
		@apply bg-background text-foreground;
	}
}
--------------------------------------------------------------------------------
/main/tailwind.config.ts:
--------------------------------------------------------------------------------
export default {
	darkMode: "class",
	content: ["./index.html", "./src/**/*.{js,ts,jsx,tsx}"],
	theme: {
		transparent: "transparent",
		current: "currentColor",
		extend: {
			colors: {
				gray: {
					50: "#FAFAFA",
					100: "#F3F3F3",
					200: "#E5E5E5",
					300: "#C7C7C7",
					400: "#ACACAC",
					500: "#808080",
					600: "#5F5F5F",
					700: "#4E4E4E",
					800: "#373737",
					900: "#2A2A2A",
					950: "#151515",
				},
				// SHADCN related config starts here
				border: "hsl(var(--border))",
				input: "hsl(var(--input))",
				ring: "hsl(var(--ring))",
				background: "hsl(var(--background))",
				foreground: "hsl(var(--foreground))",
				primary: {
					DEFAULT: "hsl(var(--primary))",
					foreground: "hsl(var(--primary-foreground))",
				},
				secondary: {
					DEFAULT: "hsl(var(--secondary))",
					foreground: "hsl(var(--secondary-foreground))",
				},
				destructive: {
					DEFAULT: "hsl(var(--destructive))",
					foreground: "hsl(var(--destructive-foreground))",
				},
				muted: {
					DEFAULT: "hsl(var(--muted))",
					foreground: "hsl(var(--muted-foreground))",
				},
				accent: {
					DEFAULT: "hsl(var(--accent))",
					foreground: "hsl(var(--accent-foreground))",
				},
				popover: {
					DEFAULT: "hsl(var(--popover))",
					foreground: "hsl(var(--popover-foreground))",
				},
				card: {
					DEFAULT: "hsl(var(--card))",
					foreground: "hsl(var(--card-foreground))",
				},
			},
			// borderRadius, keyframes and animation belong directly under `extend`. They were previously nested
			// inside `colors`, where Tailwind treats every key as a color name: the radius scale was never
			// applied and the `animate-accordion-*` utilities were never generated.
			borderRadius: {
				lg: "var(--radius)",
				md: "calc(var(--radius) - 2px)",
				sm: "calc(var(--radius) - 4px)",
			},
			keyframes: {
				"accordion-down": {
					from: {height: 0},
					to: {height: "var(--radix-accordion-content-height)"},
				},
				"accordion-up": {
					from: {height: "var(--radix-accordion-content-height)"},
					to: {height: 0},
				},
			},
			animation: {
				"accordion-down": "accordion-down 0.2s ease-out",
				"accordion-up": "accordion-up 0.2s ease-out",
			},
		},
		container: {
			center: true,
			padding: "2rem",
			screens: {
				"2xl": "1400px",
			},
		},
		// SHADCN related config ends here
	},
	safelist: [
		{
			pattern:
				/^(bg-(?:slate|gray|zinc|neutral|stone|red|orange|amber|yellow|lime|green|emerald|teal|cyan|sky|blue|indigo|violet|purple|fuchsia|pink|rose)-(?:50|100|200|300|400|500|600|700|800|900|950))$/,
			variants: ["hover", "ui-selected"],
		},
		{
			pattern:
				/^(text-(?:slate|gray|zinc|neutral|stone|red|orange|amber|yellow|lime|green|emerald|teal|cyan|sky|blue|indigo|violet|purple|fuchsia|pink|rose)-(?:50|100|200|300|400|500|600|700|800|900|950))$/,
			variants: ["hover", "ui-selected"],
		},
		{
			pattern:
				/^(border-(?:slate|gray|zinc|neutral|stone|red|orange|amber|yellow|lime|green|emerald|teal|cyan|sky|blue|indigo|violet|purple|fuchsia|pink|rose)-(?:50|100|200|300|400|500|600|700|800|900|950))$/,
			variants: ["hover", "ui-selected"],
		},
		{
			pattern:
				/^(ring-(?:slate|gray|zinc|neutral|stone|red|orange|amber|yellow|lime|green|emerald|teal|cyan|sky|blue|indigo|violet|purple|fuchsia|pink|rose)-(?:50|100|200|300|400|500|600|700|800|900|950))$/,
		},
		{
			pattern:
				/^(stroke-(?:slate|gray|zinc|neutral|stone|red|orange|amber|yellow|lime|green|emerald|teal|cyan|sky|blue|indigo|violet|purple|fuchsia|pink|rose)-(?:50|100|200|300|400|500|600|700|800|900|950))$/,
		},
		{
			pattern:
				/^(fill-(?:slate|gray|zinc|neutral|stone|red|orange|amber|yellow|lime|green|emerald|teal|cyan|sky|blue|indigo|violet|purple|fuchsia|pink|rose)-(?:50|100|200|300|400|500|600|700|800|900|950))$/,
		},
	],
	plugins: [require("@headlessui/tailwindcss"), require("@tailwindcss/forms"), require("tailwindcss-animate")], // tailwindcss-animate is from SHADCN
};

--------------------------------------------------------------------------------
/main/tsconfig.json:
--------------------------------------------------------------------------------
{
	"compilerOptions": {
		"target": "es2017",
		"lib": ["dom", "dom.iterable", "esnext"],
		"allowJs": true,
		"checkJs": true,
		"skipLibCheck": true,
		"strict": true,
		"forceConsistentCasingInFileNames": true,
		"noEmit": true,
		"esModuleInterop": true,
		"module": "esnext",
		"moduleResolution": "node",
		"resolveJsonModule": true,
		"isolatedModules": true,
		"jsx": "preserve",
		"incremental": true,
		"noUncheckedIndexedAccess": true,
		"baseUrl": ".",
		"paths": {
			"~/*": ["./src/*"]
		},
		"plugins": [{ "name": "next" }]
	},
	"include": [
		".eslintrc.cjs",
		"next-env.d.ts",
		"**/*.ts",
		"**/*.tsx",
		"**/*.cjs",
		"**/*.mjs",
		".next/types/**/*.ts"
	],
	"exclude": ["node_modules"]
}

--------------------------------------------------------------------------------
/modal-proxy/gunicorn_config.py:
--------------------------------------------------------------------------------
bind = '0.0.0.0:8000'
workers = 4
--------------------------------------------------------------------------------
/modal-proxy/requirements.txt:
-------------------------------------------------------------------------------- 1 | modal==0.55.4091 2 | flask==3.0.0 3 | gunicorn==21.2.0 -------------------------------------------------------------------------------- /modal-proxy/run.py: -------------------------------------------------------------------------------- 1 | from modal.cli.volume import put, get 2 | 3 | import os 4 | from flask import Flask, request 5 | 6 | app = Flask(__name__) 7 | 8 | @app.route('/upload', methods=['POST']) 9 | def upload(): 10 | json = request.get_json() 11 | 12 | file: str = json['file'] 13 | filename = json['filename'] 14 | 15 | with open(filename, 'w') as f: 16 | f.write(file) 17 | 18 | put("datasets", filename, "/", env=None) 19 | 20 | os.remove(filename) 21 | 22 | return filename 23 | 24 | @app.route('/download', methods=['POST']) 25 | def download(): 26 | json = request.get_json() 27 | 28 | filename = json['fileName'] 29 | print("filename: "+filename) 30 | 31 | get("datasets", filename, ".", env=None) 32 | 33 | with open(filename, 'r') as f: 34 | file = f.read() 35 | 36 | os.remove(filename) 37 | 38 | return file 39 | 40 | if __name__ == '__main__': 41 | # MODAL_CONFIG_PATH=./.modal.toml python run.py 42 | app.run(debug=False) -------------------------------------------------------------------------------- /prettier.config.js: -------------------------------------------------------------------------------- 1 | module.exports = { 2 | useTabs: true, 3 | semi: true, 4 | singleQuote: false, 5 | quoteProps: "as-needed", 6 | trailingComma: "all", 7 | bracketSpacing: false, 8 | arrowParantheses: "always", 9 | printWidth: 120, 10 | }; 11 | --------------------------------------------------------------------------------