├── .gitignore
├── LICENSE
├── README.md
├── finetuning
├── finetuning.py
├── other
│ └── hf_readme_template.md
├── requirements.txt
└── training_engine
│ ├── dataset.py
│ ├── engine.py
│ └── utils.py
├── inference
├── README.md
└── lorax_api_server.py
├── logo.png
├── main
├── .env.example
├── .eslintrc.cjs
├── components.json
├── docker-compose.yml
├── next.config.mjs
├── package-lock.json
├── package.json
├── postcss.config.cjs
├── prisma
│ └── schema.prisma
├── public
│ ├── favicon.ico
│ ├── huggingface.png
│ ├── logo-simple-dark.svg
│ └── wandblogo.svg
├── src
│ ├── app
│ │ ├── api
│ │ │ ├── auth
│ │ │ │ └── [...nextauth]
│ │ │ │ │ └── route.ts
│ │ │ ├── hooks
│ │ │ │ └── update
│ │ │ │ │ └── route.ts
│ │ │ ├── inference
│ │ │ │ ├── route.ts
│ │ │ │ └── schema.ts
│ │ │ ├── search
│ │ │ │ └── route.ts
│ │ │ └── trpc
│ │ │ │ └── [trpc]
│ │ │ │ └── route_inactive.ts
│ │ ├── billing
│ │ │ ├── actions.tsx
│ │ │ ├── add-balance.tsx
│ │ │ ├── add-credit-card.tsx
│ │ │ ├── billing.tsx
│ │ │ └── page.tsx
│ │ ├── chat
│ │ │ └── [slug]
│ │ │ │ ├── chat.tsx
│ │ │ │ ├── page.tsx
│ │ │ │ ├── process.tsx
│ │ │ │ └── settings.tsx
│ │ ├── datasets
│ │ │ ├── new
│ │ │ │ ├── actions.tsx
│ │ │ │ ├── form.tsx
│ │ │ │ └── modal.tsx
│ │ │ ├── page.tsx
│ │ │ └── table.tsx
│ │ ├── layout.tsx
│ │ ├── models
│ │ │ ├── actions.tsx
│ │ │ ├── buttons.tsx
│ │ │ ├── export.tsx
│ │ │ ├── list.tsx
│ │ │ ├── new
│ │ │ │ ├── actions.tsx
│ │ │ │ ├── confirm-price.tsx
│ │ │ │ ├── form.tsx
│ │ │ │ ├── page.tsx
│ │ │ │ └── search.tsx
│ │ │ └── page.tsx
│ │ ├── page.tsx
│ │ ├── providers.tsx
│ │ └── settings
│ │ │ ├── actions.tsx
│ │ │ ├── form.tsx
│ │ │ └── page.tsx
│ ├── components
│ │ ├── form
│ │ │ ├── button.tsx
│ │ │ ├── dropdown.tsx
│ │ │ ├── file-upload.tsx
│ │ │ ├── label.tsx
│ │ │ ├── submit-button.tsx
│ │ │ └── textfield.tsx
│ │ ├── modal.tsx
│ │ ├── padding.tsx
│ │ ├── page-heading.tsx
│ │ ├── sidebar.tsx
│ │ ├── tiles.tsx
│ │ ├── tooltip.tsx
│ │ ├── ui
│ │ │ ├── hover-card.tsx
│ │ │ ├── slider.tsx
│ │ │ └── switch.tsx
│ │ ├── user-avatar.tsx
│ │ ├── utils
│ │ │ └── utils.ts
│ │ └── warning.tsx
│ ├── constants
│ │ ├── modal.ts
│ │ └── models.ts
│ ├── env.mjs
│ ├── pages
│ │ ├── _app.tsx
│ │ ├── _document.tsx
│ │ └── auth
│ │ │ ├── login
│ │ │ └── index.tsx
│ │ │ ├── new-user
│ │ │ └── index.tsx
│ │ │ └── verify-request
│ │ │ └── index.tsx
│ ├── server
│ │ ├── auth.ts
│ │ ├── controller
│ │ │ ├── new-dataset.ts
│ │ │ ├── new-model.ts
│ │ │ ├── process-dataset.ts
│ │ │ └── stripe.ts
│ │ ├── database
│ │ │ ├── chat-request.ts
│ │ │ ├── dataset.ts
│ │ │ ├── index.ts
│ │ │ ├── model.ts
│ │ │ └── user.ts
│ │ └── utils
│ │ │ ├── mail.ts
│ │ │ ├── modal.ts
│ │ │ ├── observability
│ │ │ ├── logtail.ts
│ │ │ └── posthog.ts
│ │ │ └── session.ts
│ └── styles
│ │ └── globals.css
├── tailwind.config.ts
└── tsconfig.json
├── modal-proxy
├── gunicorn_config.py
├── requirements.txt
└── run.py
└── prettier.config.js
/.gitignore:
--------------------------------------------------------------------------------
1 | # See https://help.github.com/articles/ignoring-files/ for more about ignoring files.
2 |
3 | # dependencies
4 | */node_modules
5 | /.pnp
6 | .pnp.js
7 |
8 | # testing
9 | /coverage
10 |
11 | # database
12 | */prisma/db.sqlite
13 | */prisma/db.sqlite-journal
14 |
15 | # next.js
16 | */.next/
17 | /out/
18 | next-env.d.ts
19 |
20 | # production
21 | /build
22 |
23 | # misc
24 | .DS_Store
25 | *.pem
26 |
27 | # debug
28 | npm-debug.log*
29 | yarn-debug.log*
30 | yarn-error.log*
31 | .pnpm-debug.log*
32 |
33 | # local env files
34 | # do not commit any .env files to git, except for the .env.example file. https://create.t3.gg/en/usage/env-variables#using-environment-variables
35 | .env
36 | .env*.local
37 |
38 | # vercel
39 | .vercel
40 |
41 | # typescript
42 | *.tsbuildinfo
43 |
44 | # modal
45 | .modal.toml
46 |
47 | # temporary files
48 | */tmp/**/*
49 |
50 | # python
51 | *.pyc
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright [yyyy] [name of copyright owner]
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 | Build AI models for specialized tasks.
9 |
10 |
11 |
12 |
13 |
14 | [💻 Website](https://docs.haven.run/)
15 |
•
16 | [📄 Docs](https://docs.haven.run/)
17 |
•
18 | [☁️ Hosted App](https://app.haven.run/models)
19 |
•
20 | [💬 Discord](https://discord.gg/JDjbfp6q2G)
21 |
22 |
23 |
24 | Haven gives you tools needed to build specialized large language models. Our platform lets you to fine-tune LLMs through a simple UI and evaluate them based on a wide range of criteria.
25 |
26 |
27 |
28 |
29 |
30 |
31 |
32 |
33 | ## Welcome 🥳
34 |
35 | Welcome to the Haven repository! We initially started out building functionality for LLM fine-tuning, but noticed that most people actually struggle with evaluation and data collection rather than model training itself. To address this, we're planning to add the following features soon:
36 |
37 | - Evaluation: Use prebuilt or define custom evaluation metrics and run experiments comparing models across them
38 | - Visualization: Visualize and search through your experiments to get insights on model performance
39 | - Data Collection: Collecting and formatting training datasets is annyoing. We don't really know how to solve this yet, but we're trying to figure it out!
40 |
41 | If you have feedback or suggestions on these problems, **please reach out!** You can join our [Discord](https://discord.com/invite/JDjbfp6q2G), write us an [email](mailto:hello@haven.run), or [schedule a call](https://cal.com/justus-mattern-xfnomx/30-min-chat).
42 |
43 |
44 |
45 |
46 | ## Getting Started :rocket:
47 |
48 | Instructions to self-host as well as support for AWS, GCP and Azure will follow soon. In the meantime, you can try our hosted app [here](https://app.haven.run/).
49 |
50 |
--------------------------------------------------------------------------------
/finetuning/finetuning.py:
--------------------------------------------------------------------------------
1 | import time
2 | import torch
3 | from datasets import load_dataset
4 | from transformers import AutoModelForCausalLM, BitsAndBytesConfig, AutoTokenizer, TrainingArguments, TrainerCallback
5 | from peft import LoraConfig
6 | from trl import SFTTrainer, DataCollatorForCompletionOnlyLM
7 | from huggingface_hub import HfApi, create_repo
8 | import wandb
9 | import os
10 | from training_engine.engine import ModelEngine
11 | import random
12 | import string
13 | import argparse
14 | import modal
15 | from typing import Dict
16 | from training_engine.utils import send_update
17 |
18 | from fastapi import HTTPException
19 | from modal import gpu, Mount, Stub, Image, Volume, Secret, web_endpoint
20 | from modal.cli.volume import put
21 |
22 |
23 | image = (
24 | Image.debian_slim(python_version="3.10").apt_install("git")
25 | # Pinned to 10/16/23
26 | .pip_install("torch==2.0.1", "transformers==4.35.0", "peft==0.6.0", "accelerate==0.24.1", "bitsandbytes==0.41.1", "einops==0.7.0", "evaluate==0.4.1", "scikit-learn==1.2.2", "sentencepiece==0.1.99", "wandb==0.15.3", "trl==0.7.2", "huggingface-hub"
27 | ).pip_install("hf-transfer").pip_install("requests").pip_install("modal-client")
28 | .env(dict(HUGGINGFACE_HUB_CACHE="/pretrained_models", HF_HUB_ENABLE_HF_TRANSFER="0", MODAL_CONFIG_PATH="/utils/modal.toml"))
29 | )
30 |
31 | stub = Stub("llama-finetuning", image=image)
32 | stub.model_volume = modal.NetworkFileSystem.persisted("adapters")
33 |
34 |
35 | @stub.function(
36 | volumes={
37 | "/datasets": modal.Volume.from_name("datasets"),
38 | "/pretrained_models": modal.Volume.from_name("pretrained_models"),
39 | "/utils": modal.Volume.from_name("utils"),
40 | },
41 | network_file_systems={
42 | "/trained_adapters": stub.model_volume
43 | },
44 | mounts=[modal.Mount.from_local_dir("./other", remote_path="/other")],
45 | gpu=gpu.A100(count=1, memory=80),
46 | timeout=3600 * 12,
47 | allow_cross_region_volumes=True,
48 | secret=Secret.from_name("finetuning-auth-token")
49 | )
50 | @web_endpoint(method="POST", wait_for_response=False)
51 | def train(inputs: Dict):
52 |
53 | if inputs["auth_token"] != os.environ["AUTH_TOKEN"]:
54 | raise HTTPException(
55 | status_code=status.HTTP_401_UNAUTHORIZED,
56 | detail="Incorrect auth token",
57 | )
58 |
59 |
60 | if torch.cuda.is_available():
61 | # Get the current GPU device
62 | device = torch.cuda.current_device()
63 |
64 | # Get the available GPU memory
65 | gpu_memory = torch.cuda.get_device_properties(device).total_memory / (1024 ** 2) # in megabytes
66 |
67 | print("GPU memory", gpu_memory)
68 |
69 | model_folder_name = f"/trained_adapters/{inputs['model_id']}"
70 | os.environ["WANDB_PROJECT"]="haven"
71 |
72 | try:
73 | wandb.login(anonymous="must", key=inputs["wandb_token"])
74 | log_wandb = True
75 | except Exception as e:
76 | print("exception: not logged in to wandb")
77 | print(e)
78 | log_wandb = False
79 |
80 | try:
81 | model_engine = ModelEngine(model_name=inputs["model_name"],
82 | dataset_name=inputs["dataset_name"],
83 | training_args=get_training_args(inputs, output_dir=model_folder_name, log_wandb=log_wandb),
84 | engine_args=get_engine_args(inputs)
85 | )
86 |
87 | model_engine.train()
88 | if log_wandb:
89 | wandb.finish()
90 |
91 | model_engine.tokenizer.save_pretrained(model_folder_name)
92 | save_readme_file(hf_repo_name=inputs["hf_repo"], base_model_name=inputs["model_name"], model_directory=model_folder_name)
93 |
94 | upload_model(repo=inputs["hf_repo"], folder_path=model_folder_name, hf_token=inputs["hf_token"])
95 | send_update(model_id=inputs["model_id"], key="status", value="finished")
96 |
97 | except Exception as e:
98 | print("exception:", e)
99 | send_update(model_id=inputs["model_id"], key="status", value="error")
100 |
101 |
102 |
103 | def get_training_args(inputs, output_dir, log_wandb):
104 | learning_rate_map = dict(Low=5e-5, Medium=1e-4, High=3e-4)
105 | if log_wandb:
106 | report_to = "wandb"
107 | else:
108 | report_to = "none"
109 |
110 | return dict(
111 | learning_rate=learning_rate_map[inputs["learning_rate"]],
112 | num_train_epochs=inputs["num_epochs"],
113 | output_dir=output_dir,
114 | gradient_accumulation_steps=inputs["gradient_accumulation_steps"],
115 | per_device_train_batch_size=inputs["per_device_train_batch_size"],
116 | report_to=report_to
117 | )
118 |
119 | def get_engine_args(inputs):
120 |
121 | class TrainingStartedCallback(TrainerCallback):
122 | def on_train_begin(self, args, state, control, **kwargs):
123 | try:
124 | print("sending url", wandb.run.url)
125 | except Exception as e:
126 | print("wandb url not found exception", e)
127 | return
128 |
129 | send_update(model_id=inputs["model_id"], key="wandb", value=wandb.run.url)
130 |
131 |
132 | return dict(
133 | wandb_token=inputs["wandb_token"],
134 | hf_token="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",
135 | max_tokens=inputs["max_tokens"],
136 | callbacks=[TrainingStartedCallback]
137 | )
138 |
139 | def generate_random_string():
140 | characters = string.ascii_letters + string.digits
141 | return ''.join(random.choice(characters) for _ in range(8))
142 |
143 |
144 | def upload_model(repo, folder_path, hf_token):
145 | api = HfApi()
146 |
147 | create_repo(repo, use_auth_token=hf_token, private=True)
148 | api.upload_folder(
149 | folder_path=folder_path,
150 | repo_id=repo,
151 | repo_type="model",
152 | use_auth_token=hf_token,
153 | )
154 |
155 |
156 | def save_readme_file(hf_repo_name: str, base_model_name: str, model_directory: str):
157 | with open("/other/hf_readme_template.md", "r") as f:
158 | data = f.read()
159 |
160 | full_readme = data.replace("{{base_model_name}}", base_model_name).replace("{{model_name}}", hf_repo_name)
161 |
162 | with open(f"{model_directory}/README.md", "w") as f:
163 | f.write(full_readme)
164 |
165 |
166 | if __name__=="__main__":
167 | parser = argparse.ArgumentParser()
168 | parser.add_argument("--learning_rate", type=str, choices=["Low", "Medium", "High"], default="Medium")
169 | parser.add_argument("--num_epochs", type=int, default=3)
170 | parser.add_argument("--model_name", default="meta-llama/Llama-2-7b-chat-hf")
171 | parser.add_argument("--max-tokens", type=int, default=2600)
172 | parser.add_argument("--gradient_accumulation_steps", type=int, default=16)
173 | parser.add_argument("--per_device_train_batch_size", type=int, default=1)
174 | parser.add_argument("--dataset_name", type=str)
175 | parser.add_argument("--hf_repo", type=str)
176 | args = parser.parse_args()
177 |
178 |
179 |
180 | train.local((dict(
181 | wandb_token="d519a364bc5fcc5970e3eb1e9cc54225c444e511",
182 | hf_token="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",
183 | learning_rate=args.learning_rate,
184 | num_epochs=args.num_epochs,
185 | model_name=args.model_name,
186 | dataset_name=args.dataset_name,
187 | hf_repo=args.hf_repo,
188 | max_tokens=args.max_tokens,
189 | gradient_accumulation_steps=args.gradient_accumulation_steps,
190 | per_device_train_batch_size=args.per_device_train_batch_size,
191 | auth_token="",
192 | model_id=generate_random_string()
193 | )
194 | ))
195 |
196 |
--------------------------------------------------------------------------------
/finetuning/other/hf_readme_template.md:
--------------------------------------------------------------------------------
1 | ---
2 | library_name: peft
3 | base_model: meta-llama/Llama-2-7b-chat-hf
4 | ---
5 |
6 | # Model Card for {{model_name}}
7 |
8 | This is a lora adapter for {{base_model_name}} that was trained and automatically uploaded with [Haven](https://haven.run/).
9 |
10 |
11 | ## Testing the Model
12 |
13 | To quickly test the model, you can run it on a GPU with the transformers / peft library:
14 |
15 | ```python
16 | from peft import AutoPeftModelForCausalLM
17 | from transformers import AutoTokenizer
18 |
19 | tokenizer = AutoTokenizer.from_pretrained("{{model_name}}")
20 | model = AutoPeftModelForCausalLM.from_pretrained("{{model_name}}").to("cuda") # if you get a CUDA out of memory error, try load_in_8bit=True
21 |
22 | messages = [
23 | {"role": "system", "content": "You are a helpful assistant"},
24 | {"role": "user", "content": "Hi, can you please explain machine learning to me?"}
25 | ]
26 |
27 | encodeds = tokenizer.apply_chat_template(messages, return_tensors="pt").to("cuda")
28 | generated_ids = model.generate(input_ids=model_inputs, min_new_tokens=10, max_new_tokens=300, do_sample=True, temperature=0.9, top_p=0.8)
29 | decoded = tokenizer.batch_decode(generated_ids)
30 |
31 | print(decoded[0])
32 | ```
--------------------------------------------------------------------------------
/finetuning/requirements.txt:
--------------------------------------------------------------------------------
1 | transformers
2 | peft
3 | accelerate
4 | bitsandbytes
5 | einops
6 | evaluate
7 | scikit-learn==1.2.2
8 | sentencepiece==0.1.99
9 | wandb==0.15.3
10 | trl
11 | huggingface-hub
--------------------------------------------------------------------------------
/finetuning/training_engine/dataset.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import transformers
3 | import io
4 | import json
5 | import logging
6 |
7 | from dataclasses import dataclass
8 | from typing import Dict, Sequence
9 | from torch.utils.data import Dataset
10 |
11 |
12 | class ChatDataset(Dataset):
13 | def __init__(self, data_path: str, tokenizer: transformers.AutoTokenizer, conversation_template: str, max_tokens: int):
14 | super(ChatDataset, self).__init__()
15 | data = []
16 | with open(data_path, "r") as file:
17 | for line in file:
18 | try:
19 | data.append(json.loads(line))
20 | except Exception as e:
21 | print("json processing exception", e)
22 | continue
23 |
24 |
25 | data_dict = preprocess(data, tokenizer, conversation_template, max_tokens)
26 |
27 | self.input_ids = data_dict["input_ids"]
28 | self.labels = data_dict["labels"]
29 |
30 | def __len__(self):
31 | return len(self.input_ids)
32 |
33 | def __getitem__(self, i) -> Dict[str, torch.Tensor]:
34 | return dict(input_ids=self.input_ids[i], labels=self.labels[i])
35 |
36 |
37 | @dataclass
38 | class DataCollatorForChatDataset(object):
39 | """
40 | Collate examples for supervised fine-tuning.
41 | """
42 |
43 | tokenizer: transformers.PreTrainedTokenizer
44 |
45 | def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
46 | input_ids, labels = tuple([instance[key] for instance in instances] for key in ("input_ids", "labels"))
47 | input_ids = torch.nn.utils.rnn.pad_sequence(input_ids, batch_first=True, padding_value=self.tokenizer.pad_token_id)
48 | labels = torch.nn.utils.rnn.pad_sequence(labels, batch_first=True, padding_value=-100)
49 |
50 | return dict(
51 | input_ids=input_ids,
52 | labels=labels,
53 | attention_mask=input_ids.ne(self.tokenizer.pad_token_id),
54 | )
55 |
56 |
57 | class ChatDataModule():
58 | def __init__(self, tokenizer: transformers.PreTrainedTokenizer, data_path: str, conversation_template, max_tokens: int):
59 |
60 | self.dataset = ChatDataset(tokenizer=tokenizer, data_path=data_path, conversation_template=conversation_template, max_tokens=max_tokens)
61 | self.data_collator = DataCollatorForChatDataset(tokenizer=tokenizer)
62 |
63 |
64 | def preprocess(conversations: Sequence[Sequence[dict]], tokenizer: transformers.PreTrainedTokenizer, conversation_template: str, max_tokens: int) -> Dict:
65 | """
66 | Preprocess the data by tokenizing.
67 | """
68 | all_input_ids = []
69 | all_label_ids = []
70 | tokenizer.use_default_system_prompt = False
71 |
72 | for conv in conversations:
73 | current_conv = conv["messages"]
74 | tokenized_responses = []
75 | for msg in current_conv:
76 | if msg["role"] == "assistant":
77 | tokenized_responses.append(tokenizer.encode(msg["content"], add_special_tokens=False))
78 |
79 | tokenized_conv = tokenizer.apply_chat_template(current_conv, chat_template=conversation_template, max_length=max_tokens, truncation=True)
80 | nontokenized_conv = tokenizer.apply_chat_template(current_conv, chat_template=conversation_template, max_length=max_tokens, truncation=True, tokenize=False)
81 |
82 | print("nontokenized", nontokenized_conv)
83 |
84 | loss_aware_tokenized_conv = [-100] * len(tokenized_conv)
85 |
86 | for sublist in tokenized_responses:
87 | for index, value in enumerate(tokenized_conv):
88 | if value in sublist:
89 | loss_aware_tokenized_conv[index] = value
90 | if index < len(tokenized_conv) - 1 and tokenized_conv[index + 1] not in sublist:
91 | loss_aware_tokenized_conv[index + 1] = tokenized_conv[index + 1]
92 |
93 | all_input_ids.append(torch.LongTensor(tokenized_conv))
94 | all_label_ids.append(torch.LongTensor(loss_aware_tokenized_conv))
95 |
96 | return dict(input_ids=all_input_ids, labels=all_label_ids)
97 |
98 |
--------------------------------------------------------------------------------
/finetuning/training_engine/engine.py:
--------------------------------------------------------------------------------
1 | import time
2 | import torch
3 | from datasets import load_dataset
4 | from .dataset import ChatDataModule
5 | from transformers import AutoModelForCausalLM, BitsAndBytesConfig, AutoTokenizer, TrainingArguments, Trainer, TrainerCallback
6 | from peft import LoraConfig, get_peft_model
7 | from trl import SFTTrainer, DataCollatorForCompletionOnlyLM
8 | import wandb
9 | import os
10 | from dataclasses import dataclass, field
11 | from typing import List
12 |
13 |
14 |
15 | @dataclass
16 | class EngineTrainArguments(TrainingArguments):
17 | output_dir: str = field(
18 | default='./output',
19 | metadata={"help": 'The output dir for logs and checkpoints'}
20 | )
21 | per_device_train_batch_size: int = field(
22 | default=1,
23 | metadata={"help": 'batch size'}
24 | )
25 | gradient_accumulation_steps: int = field(
26 | default=16,
27 | metadata={"help": 'gradient accumulation steps'}
28 | )
29 | logging_steps: int = field(
30 | default=3,
31 | metadata={"help": 'gradient accumulation steps'}
32 | )
33 | learning_rate: float = field(
34 | default=1e-4,
35 | metadata={"help": 'learning rate'}
36 | )
37 | num_train_epochs: int = field(
38 | default=3,
39 | metadata={"help": 'number of training epochs'}
40 | )
41 | report_to: str = field(
42 | default="wandb",
43 | metadata={"help": 'where to report training logs to'}
44 | )
45 | save_strategy: str = field(
46 | default="no",
47 | metadata={"help": 'model saving strategy'}
48 | )
49 |
@dataclass
class EngineConfig:
    """Engine-level (non-`TrainingArguments`) settings for a fine-tuning run.

    Holds access tokens, the fallback chat-template source, the sample-length
    cap and optional trainer callbacks.
    """

    # Hugging Face access token; the string "none" is the placeholder default.
    hf_token: str = field(default="none", metadata={"help": 'huggingface token'})

    # Weights & Biases token; the string "none" is the placeholder default.
    wandb_token: str = field(default="none", metadata={"help": 'wandb token'})

    # Model id whose tokenizer provides the chat template when the trained
    # model does not ship a usable one.
    default_conversation_template_id: str = field(
        default="meta-llama/Llama-2-7b-chat-hf",
        metadata={"help": 'which template to use if model does not have one'},
    )

    # Maximum length of a single training sample.
    max_tokens: int = field(default=2700, metadata={"help": 'max sample length'})

    # Extra trainer callbacks; None means none are registered.
    callbacks: List = field(default=None)
71 |
72 |
73 |
class ModelEngine:
    """Load a 4-bit quantized base model with a LoRA adapter and fine-tune it on chat data.

    Args:
        model_name: Hugging Face id (or local path) of the base model.
        dataset_name: dataset identifier forwarded to ``ChatDataModule``.
        training_args: keyword arguments for ``EngineTrainArguments``.
        engine_args: keyword arguments for ``EngineConfig``.
    """

    # Templates containing this text reject roles other than user/assistant
    # (e.g. "system"), so they are skipped in favour of a more permissive one.
    _RESTRICTED_TEMPLATE_MARKER = "Only user and assistant roles are supported"

    def __init__(self, model_name, dataset_name, training_args, engine_args):
        self.training_args = EngineTrainArguments(**training_args)
        self.engine_config = EngineConfig(**engine_args)
        self.model, self.tokenizer = self.load_model(model_name)
        self.tokenizer.chat_template = self.get_conversation_template()
        self.data_module = ChatDataModule(self.tokenizer, dataset_name, self.tokenizer.chat_template, self.engine_config.max_tokens)

    def _auth_token(self):
        """Return the configured HF token, or None when it is empty or the "none" placeholder."""
        token = self.engine_config.hf_token
        # "none" is EngineConfig's placeholder default. Passing it through would
        # send it to the Hub as a real (invalid) credential.
        if token in (None, "", "none"):
            return None
        return token

    def load_model(self, model_name):
        """Load ``model_name`` in 4-bit NF4, wrap it with LoRA, and load its tokenizer.

        Returns:
            ``(model, tokenizer)``: the PEFT-wrapped model, and a tokenizer whose
            pad token is set to its EOS token.
        """
        bnb_config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_quant_type="nf4", bnb_4bit_compute_dtype=torch.float16)

        # LoRA adapters on the q/k/v/o attention projections.
        peft_config = LoraConfig(
            lora_alpha=32,
            lora_dropout=0.05,
            r=8,
            bias="none",
            task_type="CAUSAL_LM",
            target_modules=["q_proj", "v_proj", "o_proj", "k_proj"]
        )

        token = self._auth_token()

        model = AutoModelForCausalLM.from_pretrained(
            model_name,
            quantization_config=bnb_config,
            device_map="auto",
            trust_remote_code=True,
            use_auth_token=token
        )

        model = get_peft_model(model, peft_config)

        print("device map", model.hf_device_map)
        # print_trainable_parameters() prints its own summary and returns None,
        # so it must be called directly rather than passed to print().
        print("trainable parameters:")
        model.print_trainable_parameters()

        tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True, use_auth_token=token)
        # Reuse EOS for padding; the tokenizer may not define a pad token.
        tokenizer.pad_token = tokenizer.eos_token

        return model, tokenizer

    @classmethod
    def _is_usable_template(cls, template):
        """True if ``template`` exists and does not restrict roles to user/assistant."""
        return template is not None and cls._RESTRICTED_TEMPLATE_MARKER not in template

    def get_conversation_template(self):
        """Choose the chat template used to format training conversations.

        Preference order: the tokenizer's own ``chat_template``, then its
        ``default_chat_template`` (only on transformers versions that still
        have it), then the template of
        ``engine_config.default_conversation_template_id``.
        """
        if self._is_usable_template(self.tokenizer.chat_template):
            return self.tokenizer.chat_template

        # getattr: newer transformers releases removed default_chat_template.
        default_template = getattr(self.tokenizer, "default_chat_template", None)
        if self._is_usable_template(default_template):
            return default_template

        alt_tokenizer = AutoTokenizer.from_pretrained(self.engine_config.default_conversation_template_id, trust_remote_code=True, use_auth_token=self._auth_token())
        # Some transformers versions only expose the fallback model's template
        # through default_chat_template.
        return alt_tokenizer.chat_template or getattr(alt_tokenizer, "default_chat_template", None)

    def train(self):
        """Run training with the configured arguments and save the final model to ``output_dir``."""
        trainer = Trainer(
            model=self.model,
            train_dataset=self.data_module.dataset,
            tokenizer=self.tokenizer,
            args=self.training_args,
            data_collator=self.data_module.data_collator,
            callbacks=self.engine_config.callbacks
        )

        trainer.train()
        trainer.save_model()
140 |
--------------------------------------------------------------------------------
/finetuning/training_engine/utils.py:
--------------------------------------------------------------------------------
1 | import requests
2 |
def send_update(model_id, key, value):
    """POST a training update for ``model_id`` to the web app's update webhook.

    Args:
        model_id: id of the model the update refers to.
        key: field being updated (the receiving route accepts "wandb" or "status").
        value: new value for ``key``.

    Returns:
        The ``requests.Response`` returned by the webhook.

    Raises:
        requests.exceptions.RequestException: on connection failure or timeout.
    """
    import os

    url = 'https://app-haven-run.onrender.com/api/hooks/update'
    headers = {
        # Allow the shared secret to come from the environment; the previous
        # literal stays as the fallback so existing deployments keep working.
        'x-secret': os.environ.get('WEBHOOK_SECRET', 'SECRET'),
        'Content-Type': 'application/json'
    }
    data = {
        'id': model_id,
        'key': key,
        'value': value
    }

    # Never write the shared secret to logs.
    print("sending request", dict(
        url=url,
        headers={**headers, 'x-secret': '<redacted>'},
        data=data
    ))

    # Bounded timeout so an unresponsive webhook cannot hang the training job.
    response = requests.post(url, headers=headers, json=data, timeout=30)

    print("response code", response.status_code)

    return response
26 |
27 |
28 |
29 |
30 |
--------------------------------------------------------------------------------
/inference/README.md:
--------------------------------------------------------------------------------
1 | # LoraX API Server
2 |
3 | To deploy a multi-LoRA server for a specific base model, first set `MODEL_ID` in `lorax_api_server.py` to the Hugging Face id of the desired base model (e.g. `meta-llama/Llama-2-7b-chat-hf`). Then run:
4 |
5 | ```
6 | modal deploy lorax_api_server.py
7 | ```
--------------------------------------------------------------------------------
/inference/lorax_api_server.py:
--------------------------------------------------------------------------------
1 | from modal import Image, Mount, Secret, Stub, asgi_app, gpu, method, Volume, NetworkFileSystem
2 | import lorax
3 | from typing import List, Dict
4 | from transformers import AutoTokenizer
5 | import requests
6 | import json
7 |
8 | from fastapi import BackgroundTasks, FastAPI, Request
9 | from fastapi.responses import Response, StreamingResponse, JSONResponse
10 |
11 |
# Single A10G GPU for the LoRAX server.
GPU_CONFIG = gpu.A10G(count=1)
# Hugging Face id of the base model this server hosts.
MODEL_ID = "meta-llama/Llama-2-7b-chat-hf"

# `lorax-launcher` options as (flag, value) pairs, flattened below into argv form.
_LAUNCH_OPTIONS = (
    ("--model-id", MODEL_ID),
    ("--port", "8000"),
    ("--cuda-memory-fraction", "0.95"),
    ("--max-input-length", "2048"),
    ("--max-total-tokens", "4096"),
)
LAUNCH_FLAGS = [token for option in _LAUNCH_OPTIONS for token in option]


# Container image: the LoRAX server image plus the Python client and a pinned
# transformers (used for chat templating). Downloaded models are cached under
# /pretrained_models.
_hub_cache_env = {"HUGGINGFACE_HUB_CACHE": "/pretrained_models"}
image = (
    Image.from_registry("ghcr.io/predibase/lorax:latest")
    # Clears the base image's entrypoint (presumably so Modal controls the container command).
    .dockerfile_commands("ENTRYPOINT []")
    .pip_install("lorax-client")
    .pip_install("transformers==4.35.0")
    .env(_hub_cache_env)
)

# App name derived from the model id, e.g. "multi-lora-server-meta-llama-Llama-2-7b-chat-hf".
_model_owner, _model_repo = MODEL_ID.split("/")[:2]
stub = Stub(f"multi-lora-server-{_model_owner}-{_model_repo}", image=image)
36 |
37 |
@stub.cls(
    secret=Secret.from_name("huggingface-token"),
    gpu=GPU_CONFIG,
    allow_concurrent_inputs=100,
    container_idle_timeout=60 * 10,
    timeout=60 * 60,
    keep_warm=1,
    allow_cross_region_volumes=True,
    volumes={
        "/pretrained_models": Volume.from_name("pretrained_models")
    },
    network_file_systems={
        "/trained_adapters": NetworkFileSystem.persisted("adapters")
    }
)
class Model:
    """Modal container running a LoRAX server behind a small chat-oriented FastAPI app.

    On entry it starts ``lorax-launcher`` on 127.0.0.1:8000 and waits for the port to open.
    ``/generate`` and ``/generate_stream`` accept a JSON body
    ``{"chat": [...], "parameters": {...}}``. They render ``chat`` with the base
    model's chat template and forward the prompt plus ``parameters`` to LoRAX unchanged.

    NOTE(review): callers also send an ``auth_token`` field that is never checked
    here, so these endpoints are effectively unauthenticated — confirm intent.
    """

    def __enter__(self):
        """Start the LoRAX launcher and block until its HTTP port accepts connections.

        Raises:
            RuntimeError: if the launcher process exits before the port opens.
        """
        import socket
        import subprocess
        import time

        from lorax import AsyncClient

        self.launcher = subprocess.Popen(
            ["lorax-launcher"] + LAUNCH_FLAGS
        )
        self.client = AsyncClient("http://127.0.0.1:8000", timeout=60)
        # Used only to turn chat histories into prompts via the base model's chat template.
        self.tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)

        def webserver_ready():
            try:
                socket.create_connection(("127.0.0.1", 8000), timeout=1).close()
                return True
            except (socket.timeout, ConnectionRefusedError):
                # Not listening yet: fail fast if the launcher died instead of polling forever.
                retcode = self.launcher.poll()
                if retcode is not None:
                    raise RuntimeError(
                        f"launcher exited unexpectedly with code {retcode}"
                    )
                return False

        while not webserver_ready():
            time.sleep(1.0)

        print("Webserver ready!")

    def __exit__(self, _exc_type, _exc_value, _traceback):
        """Stop the LoRAX launcher process."""
        self.launcher.terminate()

    @asgi_app()
    def fastapi_app(self):
        """Build the FastAPI app that proxies chat requests to the local LoRAX server."""
        from fastapi.concurrency import run_in_threadpool

        app = FastAPI()
        lorax_url = "http://127.0.0.1:8000"
        json_headers = {"Content-Type": "application/json"}

        def build_payload(body):
            """Render body["chat"] into a prompt and pair it with body["parameters"]."""
            prompt = self.tokenizer.apply_chat_template(body["chat"], tokenize=False, add_generation_prompt=True)
            return {"inputs": prompt, "parameters": body["parameters"]}

        @app.post("/generate")
        async def generate(request: Request):
            payload = build_payload(await request.json())
            # requests is blocking: run it in a worker thread so one long generation
            # does not stall every other concurrent request on this event loop.
            upstream = await run_in_threadpool(
                requests.post, f"{lorax_url}/generate", headers=json_headers, json=payload
            )
            # Propagate LoRAX's status so callers can distinguish failures from successes.
            return Response(
                content=json.dumps(upstream.json(), ensure_ascii=False).encode("utf-8"),
                status_code=upstream.status_code,
                media_type="application/json",
            )

        @app.post("/generate_stream")
        async def generate_stream(request: Request):
            payload = build_payload(await request.json())

            print("prompt", payload["inputs"])

            upstream = await run_in_threadpool(
                requests.post, f"{lorax_url}/generate_stream", headers=json_headers, json=payload, stream=True
            )

            print("res", upstream)

            def stream():
                # chunk_size=None yields data as it arrives rather than blocking on
                # fixed 128-byte reads (the default when iterating a Response).
                try:
                    for chunk in upstream.iter_content(chunk_size=None):
                        print(chunk)
                        yield chunk
                finally:
                    # Release the upstream connection when relaying stops.
                    upstream.close()

            return StreamingResponse(stream(), media_type="text/event-stream", status_code=upstream.status_code)

        return app
142 |
143 |
--------------------------------------------------------------------------------
/logo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/redotvideo/haven/a12a84de83d1763f9b543d027ae4d49b59defe0a/logo.png
--------------------------------------------------------------------------------
/main/.env.example:
--------------------------------------------------------------------------------
1 | # Since the ".env" file is gitignored, you can use the ".env.example" file to
2 | # build a new ".env" file when you clone the repo. Keep this file up-to-date
3 | # when you add new variables to `.env`.
4 |
5 | # This file will be committed to version control, so make sure not to have any
6 | # secrets in it. If you are cloning this repo, create a copy of this file named
7 | # ".env" and populate it with your secrets.
8 |
9 | # When adding additional environment variables, the schema in "/src/env.mjs"
10 | # should be updated accordingly.
11 |
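12 | # Prisma (schema.prisma uses the "mysql" provider; this URL matches docker-compose.yml)
13 | # https://www.prisma.io/docs/reference/database-reference/connection-urls#env
14 | DATABASE_URL="mysql://example:example@localhost:3306/example"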
15 |
16 | # Next Auth
17 | # You can generate a new secret on the command line with:
18 | # openssl rand -base64 32
19 | # https://next-auth.js.org/configuration/options#secret
20 | # NEXTAUTH_SECRET=""
21 | NEXTAUTH_URL="http://localhost:3000"
22 |
23 | # Next Auth Discord Provider
24 | DISCORD_CLIENT_ID=""
25 | DISCORD_CLIENT_SECRET=""
26 |
--------------------------------------------------------------------------------
/main/.eslintrc.cjs:
--------------------------------------------------------------------------------
/** Rules enabled by the stylistic-type-checked preset that this project opts out of; reconfigure to taste. */
const disabledStylisticRules = {
	"@typescript-eslint/array-type": "off",
	"@typescript-eslint/consistent-type-definitions": "off",
};

/** Type-safety rules switched off for now. TODO: maybe we can turn these back on at some point. */
const temporarilyDisabledRules = {
	"@typescript-eslint/no-unsafe-assignment": "off",
	"@typescript-eslint/no-explicit-any": "off",
	"@typescript-eslint/prefer-nullish-coalescing": "off",
	"@typescript-eslint/no-empty-function": "off",
};

/** @type {import("eslint").Linter.Config} */
const config = {
	parser: "@typescript-eslint/parser",
	parserOptions: {project: true},
	plugins: ["@typescript-eslint"],
	extends: [
		"next/core-web-vitals",
		"plugin:@typescript-eslint/recommended-type-checked",
		"plugin:@typescript-eslint/stylistic-type-checked",
	],
	rules: {
		...disabledStylisticRules,
		// Prefer inline `import {type X}` style for type-only imports.
		"@typescript-eslint/consistent-type-imports": ["warn", {prefer: "type-imports", fixStyle: "inline-type-imports"}],
		// Parameters prefixed with `_` may be intentionally unused.
		"@typescript-eslint/no-unused-vars": ["warn", {argsIgnorePattern: "^_"}],
		// Do not flag promise-returning functions passed to JSX attributes.
		"@typescript-eslint/no-misused-promises": [2, {checksVoidReturn: {attributes: false}}],
		...temporarilyDisabledRules,
	},
};

module.exports = config;
43 |
--------------------------------------------------------------------------------
/main/components.json:
--------------------------------------------------------------------------------
1 | {
2 | "$schema": "https://ui.shadcn.com/schema.json",
3 | "style": "default",
4 | "rsc": true,
5 | "tsx": true,
6 | "tailwind": {
7 | "config": "tailwind.config.ts",
8 | "css": "src/styles/globals.css",
9 | "baseColor": "gray",
10 | "cssVariables": false
11 | },
12 | "aliases": {
13 | "components": "src/components",
14 | "utils": "~/components/utils/utils"
15 | }
16 | }
--------------------------------------------------------------------------------
/main/docker-compose.yml:
--------------------------------------------------------------------------------
1 | version: "3.3"
2 | services:
3 | mysql:
4 | image: mysql:8.1.0
5 | restart: always
6 | ports:
7 | - 3306:3306
8 | environment:
9 | MYSQL_ROOT_PASSWORD: example
10 | MYSQL_DATABASE: example
11 | MYSQL_USER: example
12 | MYSQL_PASSWORD: example
13 |
--------------------------------------------------------------------------------
/main/next.config.mjs:
--------------------------------------------------------------------------------
/**
 * Run `build` or `dev` with `SKIP_ENV_VALIDATION` to skip env validation. This is especially useful
 * for Docker builds.
 */
await import("./src/env.mjs");

/**
 * Next.js configuration.
 *
 * The `i18n` block was removed: this project uses the App Router (`src/app`), and, as the
 * previous comment here noted, `i18n` must not be set when using the app directory.
 * @see https://github.com/vercel/next.js/issues/41980
 *
 * @type {import("next").NextConfig}
 */
const config = {
	reactStrictMode: true,

	// Indicate that these packages should not be bundled by webpack
	experimental: {
		serverComponentsExternalPackages: ["sharp", "onnxruntime-node"],
	},
};

export default config;
28 |
--------------------------------------------------------------------------------
/main/package.json:
--------------------------------------------------------------------------------
1 | {
2 | "name": "t3",
3 | "version": "0.1.0",
4 | "private": true,
5 | "scripts": {
6 | "build": "npm install && next build",
7 | "db:push": "prisma db push",
8 | "db:studio": "prisma studio",
9 | "dev": "next dev",
10 | "postinstall": "prisma generate",
11 | "lint": "next lint",
12 | "start": "next start"
13 | },
14 | "dependencies": {
15 | "@headlessui/react": "^1.7.17",
16 | "@headlessui/tailwindcss": "^0.2.0",
17 | "@heroicons/react": "^2.0.18",
18 | "@huggingface/hub": "^0.12.1",
19 | "@logtail/node": "^0.4.17",
20 | "@next-auth/prisma-adapter": "^1.0.7",
21 | "@prisma/client": "^5.1.1",
22 | "@radix-ui/react-hover-card": "^1.0.7",
23 | "@radix-ui/react-slider": "^1.1.2",
24 | "@radix-ui/react-switch": "^1.0.3",
25 | "@stripe/react-stripe-js": "^2.3.1",
26 | "@t3-oss/env-nextjs": "^0.7.0",
27 | "@tailwindcss/forms": "^0.5.6",
28 | "@tanstack/react-query": "^4.32.6",
29 | "@xenova/transformers": "^2.8.0",
30 | "class-variance-authority": "^0.7.0",
31 | "clsx": "^2.0.0",
32 | "lucide-react": "^0.290.0",
33 | "next": "^14.0.1",
34 | "next-auth": "^4.23.1",
35 | "nodemailer": "^6.9.7",
36 | "posthog-node": "^3.1.3",
37 | "react": "18.2.0",
38 | "react-dom": "18.2.0",
39 | "resend": "^2.0.0",
40 | "stripe": "^14.3.0",
41 | "superjson": "^1.13.1",
42 | "tailwind-merge": "^2.0.0",
43 | "tailwindcss-animate": "^1.0.7",
44 | "uuid": "^8.0.0",
45 | "yup": "^1.3.2"
46 | },
47 | "devDependencies": {
48 | "@types/eslint": "^8.44.2",
49 | "@types/node": "^18.16.0",
50 | "@types/react": "^18.2.20",
51 | "@types/react-dom": "^18.2.7",
52 | "@types/uuid": "^9.0.6",
53 | "@typescript-eslint/eslint-plugin": "^6.3.0",
54 | "@typescript-eslint/parser": "^6.3.0",
55 | "autoprefixer": "^10.4.14",
56 | "eslint": "^8.47.0",
57 | "eslint-config-next": "^13.5.4",
58 | "postcss": "^8.4.27",
59 | "prettier": "^3.0.0",
60 | "prettier-plugin-tailwindcss": "^0.5.1",
61 | "prisma": "^5.1.1",
62 | "tailwindcss": "^3.3.5",
63 | "typescript": "^5.1.6"
64 | },
65 | "ct3aMetadata": {
66 | "initVersion": "7.22.0"
67 | },
68 | "packageManager": "npm@9.6.7",
69 | "engines": {
70 | "node": "18.17.1"
71 | }
72 | }
--------------------------------------------------------------------------------
/main/postcss.config.cjs:
--------------------------------------------------------------------------------
// PostCSS pipeline: Tailwind first, then vendor prefixes via Autoprefixer (both with default options).
const plugins = {tailwindcss: {}, autoprefixer: {}};

const config = {plugins};

module.exports = config;
9 |
--------------------------------------------------------------------------------
/main/prisma/schema.prisma:
--------------------------------------------------------------------------------
1 | // This is your Prisma schema file,
2 | // learn more about it in the docs: https://pris.ly/d/prisma-schema
3 |
4 | generator client {
5 | provider = "prisma-client-js"
6 | }
7 |
8 | datasource db {
9 | provider = "mysql"
10 | // NOTE: The @db.Text annotations in model Account below are enabled because the provider is mysql; remove them if you switch to a provider that does not support @db.Text (e.g. sqlite).
11 | // Further reading:
12 | // https://next-auth.js.org/adapters/prisma#create-the-prisma-schema
13 | // https://www.prisma.io/docs/reference/api-reference/prisma-schema-reference#string
14 | url = env("DATABASE_URL")
15 | relationMode = "prisma"
16 | }
17 |
18 | model Post {
19 | id Int @id @default(autoincrement())
20 | name String
21 | createdAt DateTime @default(now())
22 | updatedAt DateTime @updatedAt
23 |
24 | createdBy User @relation(fields: [createdById], references: [id])
25 | createdById String
26 |
27 | @@index([name])
28 | @@index([createdById])
29 | }
30 |
31 | // Necessary for Next auth
32 | model Account {
33 | id String @id @default(cuid())
34 | userId String
35 | type String
36 | provider String
37 | providerAccountId String
38 | refresh_token String? @db.Text
39 | access_token String? @db.Text
40 | expires_at Int?
41 | token_type String?
42 | scope String?
43 | id_token String? @db.Text
44 | session_state String?
45 | user User @relation(fields: [userId], references: [id])
46 |
47 | createdAt DateTime @default(now())
48 | updatedAt DateTime @default(now()) @updatedAt
49 |
50 | @@unique([provider, providerAccountId])
51 | @@index([userId])
52 | }
53 |
54 | model Session {
55 | id String @id @default(cuid())
56 | sessionToken String @unique
57 | userId String
58 | expires DateTime
59 | user User @relation(fields: [userId], references: [id])
60 |
61 | createdAt DateTime @default(now())
62 | updatedAt DateTime @default(now()) @updatedAt
63 |
64 | @@index([userId])
65 | }
66 |
67 | model User {
68 | id String @id @default(cuid())
69 | name String?
70 | email String? @unique
71 | emailVerified DateTime?
72 | image String?
73 |
74 | centsBalance Float @default(500)
75 |
76 | stripeCustomerId String?
77 | stripePaymentMethodId String?
78 |
79 | apiKey String?
80 |
81 | hfToken String?
82 |
83 | createdAt DateTime @default(now())
84 | updatedAt DateTime @default(now()) @updatedAt
85 |
86 | accounts Account[]
87 | sessions Session[]
88 | posts Post[]
89 | Model Model[]
90 | Transaction Transaction[]
91 | ChatRequest ChatRequest[]
92 | Dataset Dataset[]
93 | }
94 |
95 | model Model {
96 | id String @id @default(cuid())
97 | userId String
98 | name String
99 | costInCents Int
100 |
101 | state String @default("training") // "training" | "finished" | "error"
102 |
103 | datasetPath String? // TODO: Delete
104 | learningRate String @default("Medium") // "Low" | "Medium" | "High"
105 | epochs Int
106 | baseModel String
107 |
108 | wandbUrl String?
109 |
110 | createdAt DateTime @default(now())
111 | updatedAt DateTime @default(now()) @updatedAt
112 |
113 | datasetId String? // TODO: Make required
114 | dataset Dataset? @relation(fields: [datasetId], references: [id]) // TODO: make required
115 |
116 | user User @relation(fields: [userId], references: [id])
117 |
118 | @@index([name])
119 | @@index([userId])
120 | @@index([datasetId])
121 | }
122 |
123 | model Dataset {
124 | id String @id @default(cuid())
125 | userId String
126 | name String
127 | fileName String
128 |
129 | rows Int
130 | createdAt DateTime @default(now())
131 | updatedAt DateTime @default(now()) @updatedAt
132 |
133 | user User @relation(fields: [userId], references: [id])
134 | Model Model[]
135 |
136 | @@index([name])
137 | @@index([userId])
138 | }
139 |
140 | model VerificationToken {
141 | identifier String
142 | token String @unique
143 | expires DateTime
144 |
145 | @@unique([identifier, token])
146 | }
147 |
148 | model Transaction {
149 | id String @id @default(cuid())
150 | userId String
151 | amount Float
152 | createdAt DateTime @default(now())
153 |
154 | reason String
155 |
156 | user User @relation(fields: [userId], references: [id])
157 |
158 | @@index([userId])
159 | }
160 |
161 | model ChatRequest {
162 | id String @id @default(cuid())
163 |
164 | userId String
165 | modelId String
166 |
167 | createdAt DateTime @default(now())
168 |
169 | user User @relation(fields: [userId], references: [id])
170 |
171 | @@index([userId])
172 | @@index([createdAt])
173 | }
174 |
--------------------------------------------------------------------------------
/main/public/favicon.ico:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/redotvideo/haven/a12a84de83d1763f9b543d027ae4d49b59defe0a/main/public/favicon.ico
--------------------------------------------------------------------------------
/main/public/huggingface.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/redotvideo/haven/a12a84de83d1763f9b543d027ae4d49b59defe0a/main/public/huggingface.png
--------------------------------------------------------------------------------
/main/public/logo-simple-dark.svg:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
--------------------------------------------------------------------------------
/main/public/wandblogo.svg:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 |
10 |
11 |
12 |
13 |
14 |
15 |
--------------------------------------------------------------------------------
/main/src/app/api/auth/[...nextauth]/route.ts:
--------------------------------------------------------------------------------
1 | import NextAuth from "next-auth";
2 |
3 | import {authOptions} from "~/server/auth";
4 |
// A single NextAuth request handler serves both HTTP verbs of the auth route.
const authHandler: any = NextAuth(authOptions);

export {authHandler as GET, authHandler as POST};
7 |
--------------------------------------------------------------------------------
/main/src/app/api/hooks/update/route.ts:
--------------------------------------------------------------------------------
1 | import * as y from "yup";
2 | import {addWandBUrl, getJobCostInCents, getModelFromId, updateState} from "~/server/database/model";
3 | import {getUserFromId, increaseBalance} from "~/server/database/user";
4 | import {logger} from "~/server/utils/observability/logtail";
5 | import {EventName, sendEvent} from "~/server/utils/observability/posthog";
6 |
// Shared secret the training job must send in the `x-secret` header. Empty means "not configured".
const WEBHOOK_SECRET = process.env.WEBHOOK_SECRET || "";

const bodySchema = y.object({
	id: y.string().required(),
	key: y.string().oneOf(["wandb", "status"]).required(),
	value: y.string().required(),
});

/**
 * Webhook used by the fine-tuning job to report progress for a model.
 *
 * - `key: "wandb"` stores the W&B run URL.
 * - `key: "status"` (value "finished" | "error") updates the model state. On the first
 *   "error" report the job's cost is refunded.
 *
 * Returns 401 if the secret is missing or wrong. Throws if the model or, when
 * refunding, the user or job price cannot be found.
 */
export async function POST(request: Request) {
	// Fail closed: with no secret configured, a request sending an empty `x-secret`
	// header would otherwise match the "" fallback.
	if (WEBHOOK_SECRET === "" || request.headers.get("x-secret") !== WEBHOOK_SECRET) {
		return new Response("Unauthorized", {
			status: 401,
		});
	}

	const body = await request.json();
	const validatedBody = await bodySchema.validate(body);

	const model = await getModelFromId(validatedBody.id);
	if (model == null) {
		throw new Error("Could not find model");
	}

	if (validatedBody.key === "wandb" && validatedBody.id !== "") {
		sendEvent("system", EventName.FINE_TUNE_WANDB_ADDED, {wandbUrl: validatedBody.value, modelId: validatedBody.id});
		await addWandBUrl(validatedBody.id, validatedBody.value);
	}

	if (validatedBody.key === "status") {
		const status = await y.string().oneOf(["finished", "error"]).required().validate(validatedBody.value);

		// Emitted for both outcomes; previously FINE_TUNE_FAILED was unreachable.
		sendEvent(model.userId, status === "finished" ? EventName.FINE_TUNE_FINISHED : EventName.FINE_TUNE_FAILED);

		// Refund only once: a repeated "error" report for a model already marked
		// as failed must not credit the user again.
		if (status === "error" && model.state !== "error") {
			const jobPriceInCents = await getJobCostInCents(validatedBody.id);
			const user = await getUserFromId(model.userId);

			if (!user) {
				logger.error("Could not find user", {id: validatedBody.id, userId: model.userId});
				throw new Error("Could not find user");
			}

			if (jobPriceInCents == null) {
				logger.error("Could not find job price", {id: validatedBody.id, userId: model.userId});
				throw new Error("Could not find job price");
			}

			logger.error("Training job failed, refunding cost", {id: validatedBody.id, jobPriceInCents, userId: user.id});
			await increaseBalance(model.userId, jobPriceInCents, `Refund for failed model: ${validatedBody.id}`).catch(
				(e) => {
					logger.error("Could not refund", {id: validatedBody.id, jobPriceInCents, userId: user.id, error: e});
				},
			);
		}

		await updateState(validatedBody.id, status);
	}

	return new Response("OK", {
		status: 200,
	});
}
70 |
--------------------------------------------------------------------------------
/main/src/app/api/inference/route.ts:
--------------------------------------------------------------------------------
1 | import * as y from "yup";
2 | import {defaultModelLoopup, modelsToFinetune} from "~/constants/models";
3 | import {inferenceEndpoints} from "~/constants/modal";
4 | import {createChatRequest, getNumberOfChatRequestsInLast24Hours} from "~/server/database/chat-request";
5 | import {getModelFromId} from "~/server/database/model";
6 | import {logger} from "~/server/utils/observability/logtail";
7 | import {checkSessionAction} from "~/server/utils/session";
8 | import {bodySchema} from "./schema";
9 | import {EventName, sendEvent} from "~/server/utils/observability/posthog";
10 |
/**
 * Send a chat request to the LoRAX inference server at `host`.
 *
 * @param host inference endpoint URL for the base model
 * @param modelId id of a fine-tuned adapter, or undefined to use the plain base model
 * @param validatedBody request body already validated against `bodySchema`
 * @returns the raw `fetch` response from the inference server
 */
async function getResponse(host: string, modelId: string | undefined, validatedBody: y.InferType<typeof bodySchema>) {
	// Adapter path on the inference server's /trained_adapters mount.
	const adapter_id = modelId ? `/trained_adapters/${modelId}` : undefined;

	const body = JSON.stringify({
		parameters: {
			temperature: validatedBody.parameters.temperature,
			top_p: validatedBody.parameters.topP,
			do_sample: validatedBody.parameters.doSample,
			max_new_tokens: validatedBody.parameters.maxTokens,
			repetition_penalty: validatedBody.parameters.repetitionPenalty,
			adapter_id,
		},
		chat: validatedBody.history,
		auth_token: process.env.MODAL_AUTH_TOKEN,
	});

	// Do not log `body`: it contains MODAL_AUTH_TOKEN and the user's chat history.
	console.log("sending request", {host, adapter_id});

	return fetch(host, {
		method: "POST",
		headers: {
			"Content-Type": "application/json",
		},
		body,
	});
}
37 |
/**
 * Call `makeCall` up to three times, retrying when it throws or rejects.
 *
 * @throws Error("Inference failed") after three failed attempts
 */
async function retryInference(makeCall: () => Promise<Response>, logInfo: object) {
	for (let i = 0; i < 3; i++) {
		try {
			// `await` is essential: returning the bare promise would let a rejection
			// escape this try/catch, so failed calls were never retried.
			return await makeCall();
		} catch (e) {
			logger.error("Inference error, try no. " + i, {...logInfo, error: e});
		}
	}

	logger.error("Inference failed after 3 attempts.", logInfo);
	throw new Error("Inference failed");
}
50 |
/**
 * Chat inference endpoint for signed-in users.
 *
 * Validates the body, enforces the free-tier daily limit (402 when exceeded), resolves the
 * base model (a built-in default model or the caller's own fine-tuned model), and proxies
 * the request to that base model's inference endpoint. The upstream body and status are
 * passed through.
 */
export async function POST(request: Request) {
	const body = await request.json();
	const validatedBody = await bodySchema.validate(body);

	const [session, model] = await Promise.all([checkSessionAction(), getModelFromId(validatedBody.modelId)]);

	sendEvent(session.user.id, EventName.INFERENCE_REQUEST, {
		modelId: validatedBody.modelId,
	});

	logger.info("Inference request", {
		modelId: validatedBody.modelId,
		userId: session.user.id,
		body: JSON.stringify(body),
	});

	// Check if user has reached their daily limit
	const last24h = await getNumberOfChatRequestsInLast24Hours(session.user.id);
	if (last24h > 100) {
		logger.info("Daily inference rate limit reached", {modelId: validatedBody.modelId});
		return new Response("Daily limit reached for free tier", {
			status: 402,
		});
	}

	// Default models are available to all users
	const isDefaultModel = Object.keys(defaultModelLoopup).includes(validatedBody.modelId);

	// Adapter not found and not internal model
	if (!isDefaultModel && !model) {
		logger.error("Model not found", {modelId: validatedBody.modelId, userId: session.user.id});
		throw new Error("Model not found");
	}

	// Fine-tuned models belong to a single user. Reject other users with the same
	// "not found" error so the response does not reveal that the id exists.
	if (model && model.userId !== session.user.id) {
		logger.error("Model belongs to another user", {modelId: validatedBody.modelId, userId: session.user.id});
		throw new Error("Model not found");
	}

	const baseModel = model?.baseModel || defaultModelLoopup[validatedBody.modelId as keyof typeof defaultModelLoopup];
	const baseModelValidated = y.string().oneOf(modelsToFinetune).required().validateSync(baseModel);
	const host = inferenceEndpoints[baseModelValidated];

	return retryInference(
		async () => {
			const response = await getResponse(host, model?.id, validatedBody);
			if (response.ok) {
				await createChatRequest(session.user.id, validatedBody.modelId);
			}

			// Pass the upstream status through so clients can detect failed generations.
			return new Response(response.body, {status: response.status});
		},
		{modelId: validatedBody.modelId},
	);
}
101 |
--------------------------------------------------------------------------------
/main/src/app/api/inference/schema.ts:
--------------------------------------------------------------------------------
1 | import * as y from "yup";
2 |
// One entry of the chat history sent by the client.
const chatMessageSchema = y.object({
	role: y.string().oneOf(["user", "system", "assistant"]).required(),
	content: y.string().required(),
});

// Sampling settings, each bounded to the range the API accepts.
const generationParametersSchema = y.object({
	temperature: y.number().min(0).max(1.5).required(),
	topP: y.number().min(0).max(1).required(),
	maxTokens: y.number().min(10).max(2048).required(),
	repetitionPenalty: y.number().min(1).max(1.8).required(),
	doSample: y.boolean().required(),
});

/** Request body accepted by the inference route. */
export const bodySchema = y.object({
	modelId: y.string().required(),
	history: y.array(chatMessageSchema).required(),
	parameters: generationParametersSchema,
});
21 |
--------------------------------------------------------------------------------
/main/src/app/api/search/route.ts:
--------------------------------------------------------------------------------
1 | import * as y from "yup";
2 | import {getDatasetsByNameForUser} from "~/server/database/dataset";
3 | import {checkSessionAction} from "~/server/utils/session";
4 |
5 | import type {Dataset} from "@prisma/client";
6 |
// `query` is optional; a missing query is treated as the empty string.
const bodySchema = y.object({query: y.string()}).required();

/** Reduce datasets to the fields returned to the client: id, name and row count. */
function filterDatasetProperties(datasets: Dataset[]) {
	return datasets.map(({id, name, rows}) => ({id, name, rows}));
}

// `ReturnType` requires a type argument; without it this alias does not compile.
export type SearchResponse = ReturnType<typeof filterDatasetProperties>;

/** Search the signed-in user's datasets by name and return them as JSON. */
export async function POST(request: Request) {
	const session = await checkSessionAction();

	const body = await request.json();
	const validatedBody = await bodySchema.validate(body);

	const results = await getDatasetsByNameForUser(session.user.id, validatedBody.query || "");

	return new Response(JSON.stringify(filterDatasetProperties(results)), {
		status: 200,
		headers: {
			"Content-Type": "application/json",
		},
	});
}
38 |
--------------------------------------------------------------------------------
/main/src/app/api/trpc/[trpc]/route_inactive.ts:
--------------------------------------------------------------------------------
1 | // TODO
2 | /*
3 | import {createNextApiHandler} from "@trpc/server/adapters/next";
4 | import {NextApiRequest} from "next";
5 |
6 | import {env} from "~/env.mjs";
7 | import {appRouter} from "~/server/api/root";
8 | import {createTRPCContext} from "~/server/api/trpc";
9 |
10 | // export API handler
11 | const handler = createNextApiHandler({
12 | router: appRouter,
13 | createContext: createTRPCContext,
14 | onError:
15 | env.NODE_ENV === "development"
16 | ? ({path, error}) => {
17 | console.error(`❌ tRPC failed on ${path ?? ""}: ${error.message}`);
18 | }
19 | : undefined,
20 | }) as any as NextApiRequest;
21 | */
22 |
--------------------------------------------------------------------------------
/main/src/app/billing/actions.tsx:
--------------------------------------------------------------------------------
1 | "use server";
2 |
3 | import * as y from "yup";
4 |
5 | import {revalidatePath} from "next/cache";
6 | import {
7 | createPaymentIntent,
8 | createSetupIntent,
9 | createStripeCustomerAndAttachPaymentMethod,
10 | getStripeCreditCardInfo,
11 | validateSetupIntent,
12 | } from "~/server/controller/stripe";
13 | import {addStripeCustomerIdAndPaymentMethodToUser, increaseBalance, updateName} from "~/server/database/user";
14 | import {EventName, sendEvent} from "~/server/utils/observability/posthog";
15 | import {checkSessionAction} from "~/server/utils/session";
16 | import {logger} from "~/server/utils/observability/logtail";
17 |
/** Invalidate the cached `/billing` page so its next render fetches fresh data. */
export async function revalidate() {
  await Promise.resolve();
  revalidatePath("/billing");
}
21 |
/**
 * Persist a new display name for the signed-in user.
 *
 * @param name Non-empty name; validated before the session is checked.
 * @throws yup ValidationError when `name` is empty, or Error("Internal server error") when
 *   the database update fails (the original error is logged).
 */
export async function updateAccountName(name: string) {
  const nameSchema = y.string().required();
  nameSchema.validateSync(name);

  const {user} = await checkSessionAction();
  try {
    await updateName(user.id, name);
  } catch (e) {
    logger.error("Error updating name", {error: e});
    throw new Error("Internal server error");
  }
}
31 |
/**
 * Return the Stripe publishable key from the `STRIPE_PUBLISHABLE_KEY` environment variable.
 *
 * An empty value is treated the same as a missing one. Handing "" to the client would only
 * fail later, when Stripe.js is initialised in the browser, with a less obvious error.
 *
 * @throws Error when the variable is unset or empty.
 */
export async function getStripePublishableKey() {
  const key = process.env.STRIPE_PUBLISHABLE_KEY;
  if (!key) {
    logger.error("Stripe publishable key is not set.");
    throw new Error("Stripe publishable key is not set.");
  }
  return key;
}
39 |
/**
 * Look up the saved card of the signed-in user.
 *
 * @returns The card details from Stripe, or `undefined` when the user has no Stripe customer yet.
 * @throws Error("Internal server error") when the Stripe lookup fails (the cause is logged).
 */
export async function getCreditCardInformation() {
  const {user} = await checkSessionAction();
  const customerId = user.stripeCustomerId;

  if (!customerId) {
    return undefined;
  }

  try {
    return await getStripeCreditCardInfo(customerId);
  } catch (e) {
    logger.error("Error getting stripe credit card info", {error: e});
    throw new Error("Internal server error");
  }
}
50 |
/**
 * Create a Stripe SetupIntent for a payment method and return its client secret.
 *
 * Server actions are reachable by direct POST requests from any client. This one therefore
 * requires a non-empty payment method id and a signed-in session before it calls Stripe.
 *
 * @param stripeCreditCardId Id of the payment method created client-side by Stripe.js.
 * @returns The SetupIntent's `client_secret`; the caller treats a falsy value as failure.
 */
export async function createSetupIntentAction(stripeCreditCardId: string) {
  y.string().required().validateSync(stripeCreditCardId);
  await checkSessionAction();

  const intent = await createSetupIntent(stripeCreditCardId);
  return intent.client_secret;
}
56 |
/**
 * Return a log-safe identifier for a SetupIntent client secret.
 *
 * Stripe client secrets have the form `<setup intent id>_secret_<secret>`. Only the id part
 * is returned; anything that does not match that shape is fully redacted.
 */
function setupIntentIdForLogs(clientSecret: string) {
  const marker = clientSecret.indexOf("_secret_");
  return marker > 0 ? clientSecret.slice(0, marker) : "[redacted]";
}

/**
 * Attach a confirmed card to the signed-in user.
 *
 * Steps:
 * 1. Validate the SetupIntent.
 * 2. Create the Stripe customer with the card attached, passing the user's existing
 *    customer id if there is one.
 * 3. Store the customer and payment method ids on the user record.
 *
 * Each failure after input validation is logged, reported as an EMERGENCY event, and
 * rethrown as a generic Error.
 *
 * The client secret is never written to logs or analytics events (Stripe advises against
 * logging client secrets). Only the SetupIntent id derived from it is recorded.
 *
 * @param stripeCreditCardId Payment method id created client-side by Stripe.js.
 * @param stripeSetupIntentClientSecret Client secret returned by `createSetupIntentAction`.
 */
export async function finalizeCreditCard(stripeCreditCardId: string, stripeSetupIntentClientSecret: string) {
  // Server actions can be called with arbitrary input; reject missing/empty ids up front.
  y.string().required().validateSync(stripeCreditCardId);
  y.string().required().validateSync(stripeSetupIntentClientSecret);

  const session = await checkSessionAction();
  if (!session.user.email) {
    sendEvent(session.user.id, EventName.EMERGENCY, {
      message: `User email is not set. User id: ${session.user.id}`,
    });
    logger.error("User email is not set.");
    throw new Error("User email is not set.");
  }

  const setupIntentId = setupIntentIdForLogs(stripeSetupIntentClientSecret);

  // Validating setup intent
  const setupIntentValid = await validateSetupIntent(stripeSetupIntentClientSecret);
  if (!setupIntentValid) {
    logger.error("Setup intent is invalid", {
      setupIntentId,
      stripeCreditCardId,
    });
    sendEvent(session.user.id, EventName.EMERGENCY, {
      message: `Setup intent is invalid. User id: ${session.user.id}, setupIntentId: ${setupIntentId}, stripeCreditCardId: ${stripeCreditCardId}`,
    });
    throw new Error("Error validating setup intent.");
  }

  // Creating a stripe customer and attaching the payment method
  const stripeCustomer = await createStripeCustomerAndAttachPaymentMethod(
    session.user.name ?? "",
    session.user.email,
    stripeCreditCardId,
    session.user.stripeCustomerId ?? undefined,
  ).catch((e) => {
    logger.error("Error creating stripe customer", {
      setupIntentId,
      stripeCreditCardId,
      error: e,
    });
    sendEvent(session.user.id, EventName.EMERGENCY, {
      message: `Error creating stripe customer. User id: ${session.user.id}, setupIntentId: ${setupIntentId}, stripeCreditCardId: ${stripeCreditCardId}`,
      error: (e as Error).message,
    });
    throw new Error("Error creating stripe customer.");
  });

  // Attaching stripe customer id to the database user
  await addStripeCustomerIdAndPaymentMethodToUser(session.user.id, stripeCustomer.id, stripeCreditCardId).catch((e) => {
    logger.error("Error adding stripe customer id to user", {
      setupIntentId,
      stripeCreditCardId,
      error: e,
    });
    sendEvent(session.user.id, EventName.EMERGENCY, {
      message: `Error adding stripe customer id to user. User id: ${session.user.id}, setupIntentId: ${setupIntentId}, stripeCreditCardId: ${stripeCreditCardId}`,
      error: (e as Error).message,
    });
    throw new Error("Internal server error.");
  });

  sendEvent(session.user.id, EventName.CREDIT_CARD_ADDED);
}
116 |
/**
 * Charge the signed-in user's saved card and credit the amount to their balance.
 *
 * Limits:
 * - whole dollars only, at least $1 and at most $50 per top-up;
 * - no top-ups while the balance exceeds $100.
 *
 * @param amountInDollars Amount to add, in whole dollars.
 * @throws Error with a user-facing message when a limit is violated, when the charge fails,
 *   or when the balance update fails after a successful charge.
 */
export async function addBalanceAction(amountInDollars: number) {
  const session = await checkSessionAction();
  if (!session.user.stripeCustomerId || !session.user.stripePaymentMethodId) {
    logger.error("User has no stripe customer id");
    throw new Error("Internal error.");
  }

  // Server actions can be called with arbitrary input, not just what the form produces.
  // Without this check, NaN, 0, negative or fractional amounts would reach Stripe and the
  // balance update. For example, 10.555 would become a non-integer number of cents.
  if (!Number.isInteger(amountInDollars) || amountInDollars < 1) {
    logger.error("User is trying to add an invalid amount", {amountInDollars});
    throw new Error("Please enter a whole dollar amount of at least $1.");
  }

  if (amountInDollars > 50) {
    logger.error("User is trying to add more than 50 dollars", {amountInDollars});
    throw new Error("Can't add more than $50 at once.");
  }

  if (session.user.centsBalance > 10000) {
    logger.error("User already has more than 100 dollars");
    throw new Error("You can't add more funds when your balance currently exceeds $100.");
  }

  // Exact: amountInDollars is a positive integer at this point.
  const amountInCents = amountInDollars * 100;

  await createPaymentIntent(
    session.user.stripeCustomerId,
    session.user.stripePaymentMethodId,
    amountInCents,
    "Haven Account Top Up",
  ).catch((e) => {
    logger.error("Error creating off session payment intent", {error: e});
    sendEvent(session.user.id, EventName.EMERGENCY, {message: "User couldn't add balance", amount: amountInCents});
    throw new Error("Error charging credit card. If this persists, send us an email at hello@haven.run.");
  });

  await increaseBalance(session.user.id, amountInCents).catch((e) => {
    logger.error("Error increasing balance", {error: e});
    sendEvent(session.user.id, EventName.EMERGENCY, {
      message: "User charged but balance not increased.",
      amount: amountInCents,
    });
    throw new Error("Internal server error.");
  });

  sendEvent(session.user.id, EventName.MONEY_ADDED, {amount: amountInCents});
}
158 |
--------------------------------------------------------------------------------
/main/src/app/billing/add-balance.tsx:
--------------------------------------------------------------------------------
1 | import {Dialog} from "@headlessui/react";
2 |
3 | import Modal from "../../components/modal";
4 | import {XMarkIcon} from "@heroicons/react/24/outline";
5 | import TextField from "../../components/form/textfield";
6 | import Button from "../../components/form/button";
7 | import {useState} from "react";
8 | import {addBalanceAction} from "./actions";
9 |
/** Props for the add-balance modal: whether it is shown, and the setter that opens/closes it. */
type Props = {
  open: boolean;
  setOpen: (open: boolean) => void;
};
14 |
/**
 * Modal for topping up the account balance.
 *
 * The input is kept in the form "$ <digits>". On confirm, the digits are parsed and passed to
 * `addBalanceAction`, which charges the saved card. Errors thrown by the action are shown
 * via `error`, and the modal closes on success.
 */
15 | export default function AddBalance({open, setOpen}: Props) {
16 | const [value, setValue] = useState("$ 5");
17 | const [loading, setLoading] = useState(false);
18 | const [error, setError] = useState("");
19 |
20 | /**
21 | * Make sure only integers are allowed and they are prefixed with a dollar sign
 * NOTE(review): `\b` inside a character class matches a literal backspace character (U+0008),
 * not a word boundary, so backspace characters are accepted as input — probably unintended.
22 | */
23 | function onChange(str: string) {
24 | const re = /^\$ [0-9\b]*$/;
25 | if (re.test(str)) {
26 | setValue(str);
27 | }
28 | }
29 |
// NOTE(review): the regex above also accepts "$ " with no digits. parseInt then yields NaN,
// and NaN is sent to the server action. `error` is never reset on a retry or after success,
// so a stale message can remain in state.
30 | async function addBalance() {
31 | const amountInDollars = parseInt(value.replace("$ ", ""));
32 |
33 | setLoading(true);
34 | try {
35 | await addBalanceAction(amountInDollars);
36 | setOpen(false);
37 | } catch (e) {
38 | setError((e as Error).message);
39 | }
40 | setLoading(false);
41 | }
42 |
43 | return (
44 |
45 |
46 |
47 |
48 | Add balance
49 |
50 | setOpen(false)}>
51 |
52 |
53 |
54 |
55 | Add balance to your account. The selected amount will be charged to your credit card.
56 |
57 |
58 |
{error}
59 |
60 | Confirm
61 |
62 |
63 |
64 | );
65 | }
66 |
--------------------------------------------------------------------------------
/main/src/app/billing/add-credit-card.tsx:
--------------------------------------------------------------------------------
1 | import {useState} from "react";
2 | import {Dialog} from "@headlessui/react";
3 | import Button from "../../components/form/button";
4 | import {CardElement, useElements, useStripe} from "@stripe/react-stripe-js";
5 | import {XMarkIcon} from "@heroicons/react/20/solid";
6 | import Modal from "../../components/modal";
7 | import {createSetupIntentAction, finalizeCreditCard, updateAccountName} from "./actions";
8 | import TextField from "~/components/form/textfield";
9 | import Label from "~/components/form/label";
10 |
/** Props for the add-credit-card modal: whether it is shown, and the setter that opens/closes it. */
type Props = {
  open: boolean;
  setOpen: (open: boolean) => void;
};
15 |
/**
 * Modal that collects a cardholder name and card details and saves the card.
 *
 * Flow in `addCreditCard`:
 * 1. Save the name (`updateAccountName`).
 * 2. Create a payment method with Stripe.js.
 * 3. Get a SetupIntent client secret from the server (`createSetupIntentAction`).
 * 4. Confirm the setup with Stripe.js.
 * 5. Let the server attach the card to the user (`finalizeCreditCard`).
 */
16 | export default function AddCreditCard({open, setOpen}: Props) {
17 | const [loading, setLoading] = useState(false);
18 | const [error, setError] = useState("");
19 |
20 | const [name, setName] = useState("");
21 |
// NOTE(review): useElements()/useStripe() return null until Stripe.js has loaded; the `!`
// assertions hide that, so clicking before load would throw inside addCreditCard.
22 | const elements = useElements()!;
23 | const stripe = useStripe()!;
24 |
// NOTE(review): updateAccountName (line 30) and finalizeCreditCard (line 57) are awaited
// without try/catch, and both actions can throw (e.g. an empty name fails validation). When
// they do, `loading` stays true and no message is shown. The destructured `error` below also
// shadows the `error` state variable.
25 | async function addCreditCard() {
26 | setLoading(true);
27 |
28 | const cardElement = elements.getElement(CardElement);
29 |
30 | await updateAccountName(name);
31 |
32 | const {error, paymentMethod} = await stripe.createPaymentMethod({
33 | type: "card",
34 | card: cardElement!,
35 | });
36 |
37 | if (error) {
38 | setError("Failed to add credit card. Please try a different card or try again later.");
39 | setLoading(false);
40 | return;
41 | }
42 |
43 | const clientSecret = await createSetupIntentAction(paymentMethod.id);
44 | if (!clientSecret) {
45 | setError("Failed to add credit card. Please try a different card or try again later.");
46 | setLoading(false);
47 | return;
48 | }
49 |
50 | const confirmationResult = await stripe.confirmCardSetup(clientSecret);
51 | if (confirmationResult.error) {
52 | setError("Failed to add credit card. Please try a different card or try again later.");
53 | setLoading(false);
54 | return;
55 | }
56 |
57 | await finalizeCreditCard(paymentMethod.id, clientSecret);
58 |
59 | setError("");
60 | setLoading(false);
61 | setOpen(false);
62 | }
63 |
64 | return (
65 |
66 |
67 |
68 |
69 | Add a credit card
70 |
71 | setOpen(false)}>
72 |
73 |
74 |
75 |
76 | {"We won't charge your card until you add balance to your account manually."}
77 |
78 |
79 |
80 |
81 |
82 |
Credit card
83 |
84 |
85 |
86 |
87 |
{error}
88 |
89 |
90 | addCreditCard()} className="w-full justify-center" loading={loading}>
91 | Verify and add
92 |
93 |
94 |
100 |
101 | );
102 | }
103 |
--------------------------------------------------------------------------------
/main/src/app/billing/billing.tsx:
--------------------------------------------------------------------------------
1 | "use client";
2 | import Padding from "../../components/padding";
3 | import PageHeading from "../../components/page-heading";
4 | import Sidebar from "../../components/sidebar";
5 | import Warning from "../../components/warning";
6 | import {useState} from "react";
7 | import Button from "../../components/form/button";
8 | import AddCreditCard from "./add-credit-card";
9 | import AddBalance from "./add-balance";
10 | import {Elements} from "@stripe/react-stripe-js";
11 | import {revalidate} from "./actions";
12 | import {loadStripe} from "@stripe/stripe-js";
13 |
/** Data the billing server component passes down to this client component. */
type Props = {
  /** Balance already formatted for display, e.g. "$12.50". */
  balance: string;
  /** Summary of the saved card; absent when the user has not added one yet. */
  creditCardInformation?: {last4: string; expiry: string};
  /** Stripe publishable key, presumably used to initialise Stripe.js on the client. */
  stripePublishableKey: string;
};
22 |
/**
 * Client billing page: shows the balance and the saved card, and opens the add-credit-card
 * and add-balance modals. "Add balance" is only offered once a card is on file; until then
 * the same button opens the credit-card modal instead.
 */
23 | export default function Billing({balance, creditCardInformation, stripePublishableKey}: Props) {
24 | const [openCreditCardModal, setOpenCreditCardModal] = useState(false);
25 | const [openAddBalanceModal, setOpenAddBalanceModal] = useState(false);
26 |
27 | // Wrap a modal's open/close setter so that every state change also revalidates /billing.
// NOTE(review): revalidate() also fires when a modal opens, not only when it closes.
28 | function refreshWrapper(setOpenCloseFunction: (state: boolean) => void) {
29 | return (state: boolean) => {
30 | setOpenCloseFunction(state);
31 | void revalidate();
32 | };
33 | }
34 |
35 | return (
36 | <>
37 |
38 |
39 |
40 |
41 |
42 |
43 | Billing
44 |
45 |
46 | {!creditCardInformation && (
47 |
48 | )}
49 |
50 |
51 |
52 |
53 |
Balance
54 |
{balance}
55 |
{
57 | if (creditCardInformation) {
58 | setOpenAddBalanceModal(true);
59 | } else {
60 | setOpenCreditCardModal(true);
61 | }
62 | }}
63 | >
64 | {creditCardInformation ? "Add balance" : "Add credit card"}
65 |
66 |
67 | {creditCardInformation && (
68 |
69 |
70 |
71 |
72 |
**** {creditCardInformation.last4}
73 |
Expires {creditCardInformation.expiry}
74 |
75 |
76 |
77 |
Using
78 |
79 | {/*
80 | //TODO: Enable delete button
81 |
{}}>
82 | Delete
83 |
84 | */}
85 |
86 |
87 |
88 | )}
89 |
90 |
91 |
92 | >
93 | );
94 | }
95 |
--------------------------------------------------------------------------------
/main/src/app/billing/page.tsx:
--------------------------------------------------------------------------------
1 | import {checkSession} from "~/server/utils/session";
2 | import Billing from "./billing";
3 | import {getCreditCardInformation, getStripePublishableKey} from "./actions";
4 |
/**
 * Server component for /billing. Loads the session, the saved card and the Stripe publishable
 * key, and formats the balance (stored in cents) as a dollar string for the client component.
 */
5 | export default async function Page() {
// NOTE(review): the three awaits run one after another. None depends on another's result,
// so `await Promise.all([...])` would fetch them concurrently.
6 | const [session, creditCardInformation, stripePublishableKey] = [
7 | await checkSession(),
8 | await getCreditCardInformation(),
9 | await getStripePublishableKey(),
10 | ];
11 | const balance = "$" + (session.user.centsBalance / 100).toFixed(2);
12 |
13 | return (
14 | <>
15 |
20 | >
21 | );
22 | }
23 |
--------------------------------------------------------------------------------
/main/src/app/chat/[slug]/chat.tsx:
--------------------------------------------------------------------------------
1 | "use client";
2 | import {SendIcon} from "lucide-react";
3 | import {useEffect, useRef, useState} from "react";
4 | import Button from "~/components/form/button";
5 | import Label from "~/components/form/label";
6 | import Settings from "./settings";
7 | import {AdjustmentsHorizontalIcon} from "@heroicons/react/24/outline";
8 | import {parseStream} from "./process";
9 |
10 | import type {ChatMessage} from "./process";
11 | import type {ChatSettings} from "./settings";
12 | import UserAvatar from "~/components/user-avatar";
13 | import {CpuChipIcon} from "@heroicons/react/20/solid";
14 |
/**
 * Client-side chat UI for one model.
 *
 * The transcript lives in `chat`. On submit, the full history (prefixed with the system
 * prompt) and the current `chatSettings` are POSTed to `/api/inference`. The streamed reply
 * is then parsed into a new assistant message via `parseStream`.
 *
 * @param modelId Model to chat with; forwarded as `modelId` in the inference request.
 * @param email Optional user email; its use is in the JSX, which is not visible here.
 */
15 | export default function ChatComponent({modelId, email}: {modelId: string; email?: string}) {
16 | const [loading, setLoading] = useState(false);
17 |
18 | const [systemPrompt, setSystemPrompt] = useState("You're a useful assistant.");
19 | const [message, setMessage] = useState("");
20 |
// NOTE(review): as listed, `useRef(null)` and `useState([])` carry no type arguments. The
// `scrollIntoView` call below needs an element type (e.g. `useRef<HTMLDivElement>(null)`)
// and `chat` needs `ChatMessage[]` to type-check.
21 | const lastMessageRef = useRef(null);
22 | const [chat, setChat] = useState([]);
23 |
24 | const [settingsOpen, setSettingsOpen] = useState(false);
// Default generation parameters, sent as `parameters` with every request.
25 | const [chatSettings, setChatSettings] = useState({
26 | temperature: 0.9,
27 | topP: 0.8,
28 | maxTokens: 256,
29 | repetitionPenalty: 1.1,
30 | doSample: true,
31 | });
32 |
// NOTE(review): declared inside the component, so if it is rendered as <AssistantAvatar />,
// React sees a new component type on every render and remounts it; consider module scope.
33 | function AssistantAvatar() {
34 | return (
35 |
40 | );
41 | }
42 |
// Submit handler: appends the user's message, POSTs the conversation and streams the answer.
// NOTE(review): several problems in this handler:
// - The 402 branch returns without setLoading(false), leaving the form stuck in its loading
//   state. A rejected fetch also leaves loading true.
// - Other non-OK statuses are read as a stream as if successful.
// - `newChat` is handed to setChat and then mutated by `push`, i.e. the array already stored
//   in React state is modified in place.
43 | const handleFormSubmit = async (event: React.FormEvent) => {
44 | event.preventDefault();
45 |
46 | setLoading(true);
47 | setMessage("");
48 |
49 | const newChat: ChatMessage[] = [...chat, {role: "user", content: message}];
50 | setChat(newChat);
51 |
52 | const res = await fetch("/api/inference", {
53 | method: "POST",
54 | headers: {
55 | "content-type": "application/json",
56 | },
57 | body: JSON.stringify({
58 | modelId,
59 | history: [{role: "system", content: systemPrompt}, ...newChat],
60 | parameters: chatSettings,
61 | }),
62 | });
63 |
64 | if (res.status === 402) {
65 | alert(`You have reached the rate limit of 100 messages in 24h :(`);
66 | return;
67 | }
68 |
69 | // Res is a stream, read it line by line
70 | const reader = res.body!.getReader();
71 |
72 | newChat.push({role: "assistant", content: ""});
73 | await parseStream(reader, newChat, setChat);
74 |
75 | setLoading(false);
76 | };
77 |
// Keep the newest message in view whenever the transcript changes.
78 | useEffect(() => {
79 | if (lastMessageRef.current) {
80 | lastMessageRef.current.scrollIntoView({behavior: "smooth"});
81 | }
82 | }, [chat]);
83 |
84 | return (
85 | <>
86 |
92 | setSettingsOpen(true)}>
93 |
94 | Chat Parameters
95 |
96 |
97 | System prompt
98 | setSystemPrompt(e.target.value)}
103 | placeholder="You're a useful assistant."
104 | />
105 |
106 | {chat.map((message, i) => (
107 |
112 | {message.role === "user" ?
:
}
113 |
{message.content}
114 |
115 | ))}
116 |
117 |
137 | >
138 | );
139 | }
140 |
--------------------------------------------------------------------------------
/main/src/app/chat/[slug]/page.tsx:
--------------------------------------------------------------------------------
1 | import {ChevronLeftIcon} from "@heroicons/react/20/solid";
2 | import {redirect} from "next/navigation";
3 | import Button from "~/components/form/button";
4 | import Sidebar from "~/components/sidebar";
5 | import {getModelFromId} from "~/server/database/model";
6 | import {checkSession} from "~/server/utils/session";
7 | import ChatComponent from "./chat";
8 | import Link from "next/link";
9 | import {defaultModelLoopup} from "~/constants/models";
10 |
/**
 * Server component for /chat/[slug]. `slug` is either a built-in model key (from
 * `defaultModelLoopup`) or the id of a fine-tuned model. Users are redirected to /models
 * unless the slug is built-in or names a model they own.
 */
11 | export default async function Chat({params}: {params: {slug: string}}) {
12 | const {slug} = params;
13 |
14 | const [session, model] = await Promise.all([checkSession(), getModelFromId(slug)]);
15 |
// Built-in models are open to every signed-in user; the database lookup above is still made for them.
16 | const isInternal = Object.keys(defaultModelLoopup).includes(slug);
17 |
18 | if (!isInternal && model?.userId !== session.user.id) {
19 | redirect("/models");
20 | }
21 |
22 | return (
23 |
24 |
25 |
26 |
27 | Close chat
28 |
29 |
30 |
31 |
32 |
33 |
34 |
35 | );
36 | }
37 |
--------------------------------------------------------------------------------
/main/src/app/chat/[slug]/process.tsx:
--------------------------------------------------------------------------------
1 | import * as y from "yup";
2 |
/** Author of a chat message. */
type ChatRole = "user" | "assistant" | "system";

/**
 * One message of a chat transcript, as sent in the `history` of an inference request.
 * While a reply streams in, the last assistant message's `content` is extended in place.
 */
export type ChatMessage = {
  role: ChatRole;
  content: string;
};
7 |
// Schema for one streamed JSON event. Built once instead of once per line.
const streamEventSchema = y.object({
  // Our multi lora server provides this
  token: y
    .object({
      text: y.string(),
      special: y.boolean(),
    })
    .required(),
  // Fireworks AI provides this
  choices: y.array(
    y.object({
      index: y.number(),
      delta: y.object({
        content: y.string(),
      }),
    }),
  ),
});

/**
 * Parse one event line as JSON after dropping its leading 5-character "data:" prefix.
 *
 * @returns The parsed value, or `undefined` when the rest of the line is not valid JSON.
 *   JSON can never produce `undefined`, so the two cases are distinguishable.
 */
function parseEventPayload(line: string): unknown {
  try {
    return JSON.parse(line.slice(5));
  } catch {
    return undefined;
  }
}

/**
 * Consume the complete server-sent events buffered in `current` and append their text to the
 * last message of `newChat` (the assistant message being streamed).
 *
 * Events are separated by a blank line ("\n\n"), and each starts with a 5-character "data:"
 * prefix. Two payload shapes are understood:
 * - `{token: {text, special}}` from the multi lora server;
 * - `{choices: [{delta: {content}}]}` from Fireworks AI.
 *
 * The last message object is mutated in place. `setChat` is called once, with a fresh array,
 * if any text was appended.
 *
 * @returns The trailing fragment that does not parse yet (to be prefixed to the next chunk),
 *   or "" when everything was consumed.
 */
export async function processChunk(current: string, newChat: ChatMessage[], setChat: (chat: ChatMessage[]) => void) {
  // `split` always yields at least one element.
  const lines = current.split("\n\n");
  const lastLine = lines[lines.length - 1]!;

  // The final segment may be an event cut off mid-chunk; treat it as complete only if it parses.
  const lastLineIsIncomplete = parseEventPayload(lastLine) === undefined;
  const linesToProcess = lastLineIsIncomplete ? lines.slice(0, -1) : lines;

  // The message being streamed into; its `content` is extended in place.
  const streamed = newChat[newChat.length - 1]!;
  let appended = false;

  for (const line of linesToProcess) {
    if (!line) {
      continue;
    }

    const payload = parseEventPayload(line);
    if (payload === undefined) {
      continue;
    }

    const validated = await streamEventSchema.validate(payload).catch(() => undefined);
    if (!validated) {
      continue;
    }

    // Multi lora
    const {token} = validated;
    if (Object.keys(token).length > 0) {
      // Ignore special tokens
      if (token.special) {
        continue;
      }
      // Guard against a token object without text, e.g. one whose fields were filled in as
      // undefined for a non-multi-lora event. Appending it would add the literal "undefined".
      if (token.text !== undefined) {
        streamed.content += token.text;
        appended = true;
      }
    }

    // Fireworks AI
    const deltaContent = validated.choices?.[0]?.delta?.content;
    if (deltaContent) {
      streamed.content += deltaContent;
      appended = true;
    }
  }

  // One state update per chunk instead of one per token.
  if (appended) {
    setChat([...newChat.slice(0, -1), streamed]);
  }

  // Carry the partial trailing event over so the next chunk can complete it.
  return lastLineIsIncomplete ? lastLine : "";
}
89 |
/**
 * Read an inference response stream to the end, feeding decoded text through `processChunk`
 * so the last message of `newChat` grows as tokens arrive.
 *
 * @param reader Reader over the raw response body bytes.
 * @param newChat Transcript whose last entry is the assistant message being filled in.
 * @param setChat React state setter that receives updated copies of the transcript.
 */
export async function parseStream(
  reader: ReadableStreamDefaultReader<Uint8Array>,
  newChat: ChatMessage[],
  setChat: (chat: ChatMessage[]) => void,
) {
  // A single decoder in streaming mode keeps multi-byte UTF-8 sequences that are split across
  // network chunks intact. A fresh decoder per chunk would turn each half into U+FFFD.
  const decoder = new TextDecoder("utf-8");
  let current = "";

  try {
    while (true) {
      const {done, value} = await reader.read();

      // `stream: true` buffers an incomplete trailing sequence; the final (done) call flushes it.
      current += decoder.decode(value, {stream: !done});
      current = await processChunk(current, newChat, setChat);

      if (done) {
        break;
      }
    }
  } finally {
    reader.releaseLock();
  }
}
111 |
--------------------------------------------------------------------------------
/main/src/app/chat/[slug]/settings.tsx:
--------------------------------------------------------------------------------
1 | import Modal from "~/components/modal";
2 | import {Slider} from "~/components/ui/slider";
3 | import Label from "~/components/form/label";
4 | import {useState} from "react";
5 | import Button from "~/components/form/button";
6 | import {Dialog} from "@headlessui/react";
7 | import {XMarkIcon} from "@heroicons/react/20/solid";
8 | import {Switch} from "~/components/ui/switch";
9 |
/**
 * Generation parameters edited in the chat settings modal.
 *
 * The chat component sends this object as `parameters` in each `/api/inference` request.
 */
export interface ChatSettings {
  temperature: number;
  topP: number;
  maxTokens: number;
  repetitionPenalty: number;
  /** Value of the "Sample" switch in the settings modal. */
  doSample: boolean;
}
17 |
/**
 * Labelled slider row that displays the current value and reports changes through `setValue`.
 * The range defaults to min 0.1, max 1 and step 0.01 unless overridden.
 * NOTE(review): both `defaultValue` and `value` are passed to <Slider>. Since `value` makes it
 * controlled, `defaultValue` is presumably redundant.
 */
18 | function SettingsSlider({
19 | label,
20 | value,
21 | setValue,
22 | min = 0.1,
23 | max = 1,
24 | step = 0.01,
25 | }: {
26 | label: string;
27 | value: number;
28 | setValue: (value: number) => void;
29 | min?: number;
30 | max?: number;
31 | step?: number;
32 | }) {
33 | return (
34 |
35 |
36 |
{label}
37 |
{value}
38 |
39 |
setValue(value[0]!)}
42 | defaultValue={[value]}
43 | value={[value]}
44 | min={min}
45 | max={max}
46 | step={step}
47 | />
48 |
49 | );
50 | }
51 |
/**
 * Modal for editing the chat generation parameters. Edits are kept locally and only pushed to
 * the parent through `setChatSettings` when the user presses Save (which also closes the modal).
 */
52 | export default function Settings({
53 | open,
54 | setOpen,
55 | chatSettings,
56 | setChatSettings,
57 | }: {
58 | open: boolean;
59 | setOpen: (open: boolean) => void;
60 | chatSettings: ChatSettings;
61 | setChatSettings: (settings: ChatSettings) => void;
62 | }) {
63 | // NOTE: local draft copies of `chatSettings` (the author notes this extra state is not strictly needed).
// They are initialised once, so later changes to the `chatSettings` prop are not picked up.
// Closing without saving keeps the edited draft for as long as this component stays mounted.
64 | const [temperature, setTemperature] = useState(chatSettings.temperature);
65 | const [topP, setTopP] = useState(chatSettings.topP);
66 | const [maxTokens, setMaxTokens] = useState(chatSettings.maxTokens);
67 | const [repetitionPenalty, setRepetitionPenalty] = useState(chatSettings.repetitionPenalty);
68 | const [sample, setSample] = useState(chatSettings.doSample);
69 |
70 | function save() {
71 | setChatSettings({
72 | temperature,
73 | topP,
74 | maxTokens,
75 | repetitionPenalty,
76 | doSample: sample,
77 | });
78 | setOpen(false);
79 | }
80 |
81 | return (
82 | <>
83 |
84 | setOpen(false)}>
85 |
86 |
87 |
88 | Chat Parameters
89 |
90 |
91 |
92 |
93 |
100 |
101 | Sample
102 |
103 |
104 |
105 | Save
106 |
107 |
108 | >
109 | );
110 | }
111 |
--------------------------------------------------------------------------------
/main/src/app/datasets/new/actions.tsx:
--------------------------------------------------------------------------------
1 | "use server";
2 |
3 | import * as y from "yup";
4 | import {checkSessionAction} from "~/server/utils/session";
5 | import {revalidatePath} from "next/cache";
6 | import {uploadDataset} from "~/server/controller/new-dataset";
7 |
8 | import type {State} from "./form";
9 | import {logger} from "~/server/utils/observability/logtail";
10 |
// Upload form fields: a required dataset name and the file chosen in the dropzone.
// `y.mixed()` accepts any value, so this schema does not check that the file is a File.
const datasetNameField = y.string().required();
const datasetFileField = y.mixed().required();

const FormDataSchema = y.object({
  name: datasetNameField,
  dropzoneFile: datasetFileField,
});
15 |
/** Validated shape of the dataset upload form, derived from `FormDataSchema`. */
export type FormDataType = y.InferType<typeof FormDataSchema>;
17 |
/**
 * Assemble the form state returned to the upload form.
 *
 * Success sets `datasetId` and `datasetName`; failure sets only `message`.
 */
function buildState(datasetId?: string, datasetName?: string, message?: string) {
  const state = {datasetId, datasetName, message};
  return state;
}
25 |
/**
 * Server action behind the "Upload dataset" form.
 *
 * Validates the submitted fields, stores the dataset for the signed-in user and revalidates
 * `/datasets`. Validation and upload errors are not thrown to the client. They are logged and
 * returned as `message` in the resulting state so the form can display them.
 *
 * @param _ Previous form state (unused).
 * @param formData Raw form submission containing `name` and `dropzoneFile`.
 */
export async function uploadDatasetAction(_: State, formData: FormData) {
  const session = await checkSessionAction();

  let id: string | undefined;
  let name: string | undefined;
  try {
    const validatedForm = await FormDataSchema.validate(Object.fromEntries(formData));

    // `y.mixed()` accepts any value, so make sure an actual, non-empty file was uploaded.
    // File extends Blob. An untouched file input is submitted as an empty File, hence the size check.
    const file = validatedForm.dropzoneFile;
    if (!(file instanceof Blob) || file.size === 0) {
      throw new Error("Please select a non-empty dataset file to upload.");
    }

    id = await uploadDataset(session.user.id, validatedForm);
    name = validatedForm.name;

    revalidatePath("/datasets");
  } catch (e) {
    logger.warn("Failed to upload dataset", {error: e});
    return buildState(undefined, undefined, (e as Error).message);
  }

  return buildState(id, name, undefined);
}
47 |
--------------------------------------------------------------------------------
/main/src/app/datasets/new/form.tsx:
--------------------------------------------------------------------------------
1 | "use client";
2 | import {useFormState} from "react-dom";
3 | import Label from "~/components/form/label";
4 | import TextField from "~/components/form/textfield";
5 | import SubmitButton from "~/components/form/submit-button";
6 | import Tooltip from "~/components/tooltip";
7 | import {UploadIcon} from "lucide-react";
8 | import {uploadDatasetAction} from "./actions";
9 |
/** State carried between submissions of the upload form. */
export type State = {
  datasetId?: string;
  datasetName?: string;
  message?: string;
};

// Nothing uploaded yet and no error to show.
const initialState: State = {
  datasetId: undefined,
  datasetName: undefined,
  message: undefined,
};
21 |
// Help text for the upload form, keyed by field (presumably rendered in <Tooltip>s next to the
// inputs; the JSX that uses it is not visible here).
// NOTE(review): the name rule stated here (letters, numbers, dashes) is not enforced by the
// upload action's schema, which only requires a non-empty string.
22 | const tooltips = {
23 | datasetName: (
24 | <>What do you want to call your dataset? The name should be unique and contain only letters, numbers, and dashes.>
25 | ),
26 | dataset: (
27 | <>
28 | The dataset you want to use for fine-tuning. The dataset should be in{" "}
29 |
30 | this format
31 |
32 | . We recommend using a dataset with at least 100 conversations.
33 | >
34 | ),
35 | };
36 |
/**
 * Upload form for a new dataset, driven by `useFormState(uploadDatasetAction)`. Once the
 * action reports a `datasetId`, the parent is told to close the modal and receives the new
 * dataset's id and name.
 */
37 | export default function NewDatasetForm({
38 | setModalOpen,
39 | }: {
40 | setModalOpen: (open: boolean, datasetId?: string, datasetName?: string) => void;
41 | }) {
42 | const [state, formAction] = useFormState(uploadDatasetAction, initialState);
43 |
// NOTE(review): this calls the parent's state setter during render, which React warns about
// ("Cannot update a component while rendering a different component"). It would normally
// live in a useEffect keyed on `state.datasetId`.
44 | if (state.datasetId) {
45 | setModalOpen(false, state.datasetId, state.datasetName || undefined);
46 | }
47 |
48 | return (
49 | <>
50 |
89 | >
90 | );
91 | }
92 |
--------------------------------------------------------------------------------
/main/src/app/datasets/new/modal.tsx:
--------------------------------------------------------------------------------
1 | "use client";
2 | import {XMarkIcon} from "@heroicons/react/20/solid";
3 | import NewDatasetForm from "./form";
4 | import Modal from "~/components/modal";
5 | import {Dialog} from "@headlessui/react";
6 |
/**
 * Modal wrapper around the dataset upload form. `setOpen` also accepts the id and name of a
 * freshly uploaded dataset. Presumably it is forwarded to the form as `setModalOpen` (the JSX
 * is not visible here).
 */
7 | export default function NewDatasetModal({
8 | open,
9 | setOpen,
10 | }: {
11 | open: boolean;
12 | setOpen: (open: boolean, datasetId?: string, datasetName?: string) => void;
13 | }) {
14 | return (
15 |
16 | setOpen(false)}>
17 |
18 |
19 |
20 | Upload dataset
21 |
22 |
23 |
24 | );
25 | }
26 |
--------------------------------------------------------------------------------
/main/src/app/datasets/page.tsx:
--------------------------------------------------------------------------------
1 | import Padding from "~/components/padding";
2 | import PageHeading from "~/components/page-heading";
3 | import Sidebar from "~/components/sidebar";
4 | import {checkSession} from "~/server/utils/session";
5 | import DatasetTable from "./table";
6 | import {getDatasets} from "~/server/database/dataset";
7 |
8 | import type {Dataset} from "@prisma/client";
9 |
/**
 * Describe how long ago `updatedAt` was as a short, human-readable string.
 *
 * Uses "N weeks ago" from 3 weeks on, otherwise days, hours or minutes (floored; singular for
 * exactly 1). Anything under a minute gives "Just now", including future timestamps and
 * invalid dates.
 *
 * @param updatedAt Timestamp to describe (callers pass e.g. a dataset's `createdAt`).
 * @param now Reference time; defaults to the current time. Exposed so the output is testable.
 */
function updatedAtToPrettyString(updatedAt: Date, now: Date = new Date()) {
  const diffMs = now.getTime() - updatedAt.getTime();
  const diffInMinutes = diffMs / (1000 * 60);
  const diffInHours = diffMs / (1000 * 3600);
  const diffInDays = diffMs / (1000 * 3600 * 24);
  const diffInWeeks = diffInDays / 7;

  // Largest unit first; a unit is used once the elapsed amount reaches its threshold.
  const units: Array<[amount: number, threshold: number, unit: string]> = [
    [diffInWeeks, 3, "week"],
    [diffInDays, 1, "day"],
    [diffInHours, 1, "hour"],
    [diffInMinutes, 1, "minute"],
  ];

  for (const [amount, threshold, unit] of units) {
    if (amount >= threshold) {
      const whole = Math.floor(amount);
      return `${whole} ${unit}${whole === 1 ? "" : "s"} ago`;
    }
  }

  return "Just now";
}
48 |
/**
 * Reduce dataset records to the serialisable fields the client-side table renders.
 *
 * @returns `{id, name, rows, created}` per dataset, where `created` is a relative
 *   "… ago" string computed from `createdAt`.
 */
function filterPropsForTable(datasets: Dataset[]) {
  return datasets.map(({id, name, rows, createdAt}) => ({
    id,
    name,
    rows,
    created: updatedAtToPrettyString(createdAt),
  }));
}
57 |
/** Rows passed to the client-side dataset table, derived from `filterPropsForTable`. */
export type DatasetTableProps = ReturnType<typeof filterPropsForTable>;
59 |
/**
 * Server component for /datasets. Loads the signed-in user's datasets and reduces them to
 * table rows (id, name, rows, relative creation time). Presumably `filtered` is passed to
 * <DatasetTable> in the JSX, which is not visible here.
 */
60 | export default async function Page() {
61 | const session = await checkSession();
62 | const datasets = await getDatasets(session.user.id);
63 | const filtered = filterPropsForTable(datasets);
64 |
65 | return (
66 |
67 |
68 | Datasets
69 |
70 |
71 |
72 |
73 | );
74 | }
75 |
--------------------------------------------------------------------------------
/main/src/app/datasets/table.tsx:
--------------------------------------------------------------------------------
1 | "use client";
2 | import {PlusIcon} from "@heroicons/react/20/solid";
3 | import Button from "~/components/form/button";
4 | import {useState} from "react";
5 | import NewDatasetModal from "./new/modal";
6 |
7 | import type {DatasetTableProps} from "./page";
8 |
/**
 * Client-side table of the user's datasets (name, rows, created) with a button that opens the
 * "Upload dataset" modal.
 */
9 | export default function DatasetTable({datasets}: {datasets: DatasetTableProps}) {
10 | const [newDatasetModalOpen, setNewDatasetModalOpen] = useState(false);
11 |
// NOTE(review): despite its name, this button opens the new-*dataset* modal. Because it is
// declared inside DatasetTable, React treats it as a new component type on every render.
12 | function NewModelButton() {
13 | return (
14 | setNewDatasetModalOpen(true)}>
15 |
16 | Add new dataset
17 |
18 | );
19 | }
20 |
21 | return (
22 | <>
23 |
24 |
25 |
26 |
27 |
28 | You can create and manage your datasets here. Learn more about the dataset format{" "}
29 |
30 | in our docs.
31 |
32 |
33 |
34 |
35 |
36 |
37 |
38 |
39 |
40 |
41 |
42 |
43 |
44 |
45 | Name
46 |
47 |
48 | Rows
49 |
50 |
51 | Created
52 |
53 |
54 |
55 |
56 | {datasets.map((dataset) => (
57 |
58 |
59 | {dataset.name}
60 |
61 | {/*{dataset.description} */}
62 | {dataset.rows}
63 | {dataset.created}
64 | {/*
65 |
66 |
67 | View, {dataset.name}
68 |
69 |
70 | */}
71 |
72 | ))}
73 |
74 |
75 |
76 |
77 |
78 |
79 | >
80 | );
81 | }
82 |
--------------------------------------------------------------------------------
/main/src/app/layout.tsx:
--------------------------------------------------------------------------------
1 | import type {Metadata} from "next";
2 | import "../styles/globals.css";
3 |
4 | import {NextAuthProvider} from "./providers";
5 | import Script from "next/script";
6 |
/** Site-wide default <head> metadata (title and description) for the root layout. */
export const metadata: Metadata = {title: "Haven", description: "Fine-tune open source language models."};
11 |
/**
 * Root layout wrapping every page. The JSX is not visible in this listing. Given the imports,
 * it presumably renders the NextAuth provider and a <Script> tag around `children`.
 */
12 | export default function RootLayout({children}: {children: React.ReactNode}) {
13 | return (
14 |
15 |
19 |
20 | {children}
21 |
22 |
23 | );
24 | }
25 |
--------------------------------------------------------------------------------
/main/src/app/models/actions.tsx:
--------------------------------------------------------------------------------
1 | "use server";
2 | import * as y from "yup";
3 |
4 | import type {State} from "./export";
5 | import {checkSessionAction} from "~/server/utils/session";
6 | import {getModelFromId} from "~/server/database/model";
7 | import {logger} from "~/server/utils/observability/logtail";
8 | import {exportEndpoint} from "~/constants/modal";
9 |
/** Build a fresh non-empty string field schema. */
const requiredString = () => y.string().required();

// Fields posted by the model export form; all four are mandatory strings.
const FormDataSchema = y.object({
  modelId: requiredString(),
  hfToken: requiredString(),
  namespace: requiredString(),
  name: requiredString(),
});
16 |
/**
 * Server action behind the model export form: asks the export service to push a fine-tuned
 * model to Hugging Face.
 *
 * Validates the form, checks that the signed-in user owns the model, then POSTs the export
 * request to `exportEndpoint`. Expected failures do not throw. Instead it returns
 * `{success, error}`, where `error` is the message to display ("Success!" on success).
 *
 * The Hugging Face token is redacted from every log line.
 *
 * @param state Previous form state (unused; presumably required by `useFormState`).
 * @param payload Raw form data with `modelId`, `hfToken`, `namespace` and `name`.
 */
export async function exportModel(state: State, payload: FormData) {
  const session = await checkSessionAction();

  // validate form data
  let validatedForm: y.InferType<typeof FormDataSchema>;
  try {
    validatedForm = await FormDataSchema.validate(Object.fromEntries(payload));
  } catch (e: unknown) {
    const error = e as Error;
    logger.error("[exportModel] Invalid form data", {error: error.message});
    return {
      success: false,
      error: error.message,
    };
  }

  // check if model belongs to user
  const model = await getModelFromId(validatedForm.modelId);
  if (!model) {
    logger.error("[exportModel] Model not found", {modelId: validatedForm.modelId});
    return {
      success: false,
      error: "Unexpected error",
    };
  }

  if (model.userId !== session.user.id) {
    logger.error("[exportModel] User does not own model", {modelId: validatedForm.modelId});
    return {
      success: false,
      error: "Unexpected error",
    };
  }

  const body = {
    hf_name: validatedForm.namespace,
    model_name: validatedForm.name,
    model_id: validatedForm.modelId,
    base_model_name: model.baseModel,
    hf_token: validatedForm.hfToken,
  };
  // Same payload with the token masked; this is the only version that may be logged.
  const loggableBody = {...body, hf_token: "[redacted]"};

  logger.info("[exportModel] Starting export", {body: loggableBody});

  let res: Response;
  try {
    res = await fetch(exportEndpoint, {
      method: "POST",
      body: JSON.stringify(body),
      headers: {
        "Content-Type": "application/json",
      },
    });
  } catch (e: unknown) {
    // Network failure or unreachable endpoint: report it like any other export failure
    // instead of letting the rejection escape the server action.
    logger.error("[exportModel] Export request failed", {
      error: (e as Error).message,
      request: loggableBody,
    });
    return {
      success: false,
      error: "Something went wrong and the team has been notified.",
    };
  }

  if (res.status === 200) {
    logger.info("[exportModel] Exported model to huggingface", {body: loggableBody});
    return {
      success: true,
      error: "Success!",
    };
  }

  logger.error("[exportModel] Something went wrong when exporting", {
    status: res.status,
    request: loggableBody,
    report: await res.text(),
  });
  return {
    success: false,
    error: "Something went wrong and the team has been notified.",
  };
}
87 |
--------------------------------------------------------------------------------
/main/src/app/models/buttons.tsx:
--------------------------------------------------------------------------------
1 | "use client";
2 | import {ChatBubbleBottomCenterIcon, EllipsisVerticalIcon} from "@heroicons/react/20/solid";
3 | import {ExclamationTriangleIcon} from "@heroicons/react/24/outline";
4 | import {Loader2Icon} from "lucide-react";
5 | import Image from "next/image";
6 | import Link from "next/link";
7 | import {useState} from "react";
8 | import Button from "~/components/form/button";
9 | import Dropdown from "~/components/form/dropdown";
10 | import ExportModal from "./export";
11 |
12 | export default function Buttons({ // Action buttons for one model row; which control is shown depends on `state` (JSX for most controls is missing from this listing).
13 | modelId,
14 | state,
15 | wandbUrl,
16 | hfToken,
17 | }: {
18 | modelId: string;
19 | state: "training" | "finished" | "error" | "online";
20 | wandbUrl?: string; // Weights & Biases run URL, when the run has one
21 | hfToken?: string; // presumably prefilled into the export modal — markup not visible here
22 | }) {
23 | const [isExportModalOpen, setIsExportModalOpen] = useState(false); // set to true by the "Export" menu entry
24 |
25 | /**
26 | * Process dropdown item selection
27 | */
28 | function selectItem(item: string, modelId: string, wandbUrl?: string) { // "Logs" opens wandbUrl in a new window; "Export" opens the export modal. NOTE(review): `modelId` is unused here and shadows the prop.
29 | if (item === "Logs" && wandbUrl) {
30 | window.open(wandbUrl);
31 | }
32 |
33 | if (item === "Export") {
34 | setIsExportModalOpen(true);
35 | }
36 | }
37 |
38 | const wandbButton = (
39 |
40 |
41 |
42 | Logs
43 |
44 |
45 | );
46 |
47 | const chatButton = (
48 |
49 |
50 |
51 | Chat
52 |
53 |
54 | );
55 |
56 | if (state === "training" && wandbUrl) { // while training, the only action is the live-logs link
57 | return wandbButton;
58 | }
59 |
60 | if (state === "training") {
61 | return ;
62 | }
63 |
64 | if (state === "finished") { // finished: chat button plus a menu offering "Export" (and "Logs" when a W&B URL exists)
65 | const options = wandbUrl ? ["Logs", "Export"] : ["Export"];
66 |
67 | return (
68 | <>
69 |
70 |
71 | {chatButton}
72 |
selectItem(item, modelId, wandbUrl)}
75 | dropdownDirection="downLeft"
76 | >
77 |
78 |
79 |
80 |
81 |
82 | >
83 | );
84 | }
85 |
86 | // If model is hardcoded we just show the chat button.
87 | if (state === "online") {
88 | return chatButton;
89 | }
90 |
91 | // Error state
92 | return ;
93 | }
94 |
--------------------------------------------------------------------------------
/main/src/app/models/export.tsx:
--------------------------------------------------------------------------------
1 | import {Dialog} from "@headlessui/react";
2 | import {useFormState} from "react-dom";
3 | import Label from "~/components/form/label";
4 | import TextField from "~/components/form/textfield";
5 | import Modal from "~/components/modal";
6 | import Tooltip from "~/components/tooltip";
7 | import {exportModel} from "./actions";
8 | import {XMarkIcon} from "@heroicons/react/20/solid";
9 | import SubmitButton from "~/components/form/submit-button";
10 |
/** Result of the `exportModel` server action, as tracked by `useFormState`. */
export type State = {
  success: boolean; // true once the export request was accepted
  error: string; // user-facing status text (also carries the success message)
};

const initialState: State = {
  success: false,
  error: "",
};
20 |
21 | const tooltipContent = ( // Help text for the Hugging Face token field (the element around "here", presumably a link, is missing from this listing).
22 | <>
23 | Your Huggingface token is used to upload the model to your account. You can find your token{" "}
24 |
25 | here
26 |
27 | .
28 | >
29 | );
30 |
31 | export default function ExportModal({ // Modal holding the export form, presumably submitted through `formAction` (form markup is missing from this listing).
32 | open,
33 | setOpen,
34 | modelId,
35 | hfToken,
36 | }: {
37 | open: boolean;
38 | setOpen: (open: boolean) => void;
39 | modelId: string;
40 | hfToken?: string;
41 | }) {
42 | const [state, formAction] = useFormState(exportModel, initialState); // `state` is the latest {success, error} returned by exportModel
43 |
44 | return (
45 |
46 |
79 |
80 | );
81 | }
82 |
--------------------------------------------------------------------------------
/main/src/app/models/list.tsx:
--------------------------------------------------------------------------------
1 | import {ArrowUpLeftIcon, ArrowUpRightIcon} from "@heroicons/react/24/outline";
2 | import Buttons from "./buttons";
3 | import type {ModelProps} from "./page";
4 |
// Tailwind color classes per model status.
const environments = {
  training: "text-yellow-400 bg-yellow-400/10 ring-yellow-400/20",
  finished: "text-green-500 bg-green-500/10 ring-green-500/30",
  error: "text-red-500 bg-red-500/10 ring-red-500/30",
};

/** Join the truthy class names with single spaces. */
function classNames(...classes: string[]) {
  let joined = "";
  for (const cls of classes) {
    if (!cls) {
      continue;
    }
    joined = joined ? `${joined} ${cls}` : `${cls}`;
  }
  return joined;
}
14 |
15 | interface ModelListProps { // Props for List.
16 | header?: string;
17 | // one ModelProps row per model (see ./page), not raw getModels results
18 | models: ModelProps[];
19 | hfToken?: string;
20 | }
21 |
22 | function Empty() { // Empty-state placeholder, presumably shown when there are no models (markup missing from this listing).
23 | return (
24 |
39 | );
40 | }
41 |
42 | export default function List({header, models, hfToken}: ModelListProps) { // Renders an optional header, then one row per model, or (presumably) Empty when `models` is empty; row markup is largely missing from this listing.
43 | return (
44 | <>
45 | {header && (
46 | <>
47 |
50 |
51 | >
52 | )}
53 |
54 | {models.length ? (
55 |
56 | {models.map((model) => (
57 |
58 |
59 |
60 |
61 |
62 | {model.modelName}
63 |
64 |
65 |
71 | {model.status.charAt(0).toUpperCase() + model.status.slice(1)}
72 |
73 |
74 |
75 |
{model.baseModel}
76 | {model.datasetName && (
77 | <>
78 |
79 |
{model.datasetName}
80 | >
81 | )}
82 |
83 |
84 |
90 |
91 | ))}
92 |
93 | ) : (
94 |
95 | )}
96 | >
97 | );
98 | }
99 |
--------------------------------------------------------------------------------
/main/src/app/models/new/actions.tsx:
--------------------------------------------------------------------------------
1 | "use server";
2 | import {calculatePrice, createNewModelTraining} from "~/server/controller/new-model";
3 | import * as y from "yup";
4 | import {checkSessionAction} from "~/server/utils/session";
5 | import {revalidatePath} from "next/cache";
6 | import {modelsToFinetune} from "~/constants/models";
7 |
8 | import type {State} from "./form";
9 |
/** Server-side validation schema for the "start model training" form. */
const FormDataSchema = y.object({
  name: y.string().required(),
  datasetId: y.string().required(),
  learningRate: y.string().oneOf(["Low", "Medium", "High"]).required(),
  numberOfEpochs: y.number().min(1).required(),
  baseModel: y.string().oneOf(modelsToFinetune).required(),
  // Price in cents the user confirmed in the confirmation modal; absent when only a quote is requested.
  confirmedPrice: y.number(),
});

// `y.InferType` needs the schema type as its argument; without it the alias does not compile.
export type FormDataType = y.InferType<typeof FormDataSchema>;
20 |
/** Invalidate the cached /models route (called before navigating back to it). */
export async function revalidate() {
  // Let the current synchronous work finish before revalidating.
  await Promise.resolve();
  return revalidatePath("/models");
}
24 |
/** Assemble the `State` object returned by the model-creation server actions. */
function buildState(
  success: boolean,
  message: string | null,
  priceInCents: number | null,
  userHasEnoughCredits: boolean | null,
  formData: FormData | null,
): State {
  const state: State = {success, message, priceInCents, userHasEnoughCredits, formData};
  return state;
}
40 |
/**
 * Server action (first step): validate the form and quote the training price.
 *
 * Always returns `success: false`. On success the state carries the price, whether the
 * user's balance covers it, and the submitted form data for the confirmation step.
 * On failure it carries only the error message.
 */
export async function validateAndCalculatePrice(_: State, formData: FormData) {
  const session = await checkSessionAction();

  try {
    const parsed = await FormDataSchema.validate(Object.fromEntries(formData));
    const quote = await calculatePrice(session, parsed);
    const canAfford = session.user.centsBalance >= quote;
    return buildState(false, null, quote, canAfford, formData);
  } catch (err) {
    const {message} = err as Error;
    return buildState(false, message, null, null, null);
  }
}
53 |
/**
 * Server action (second step): re-validate the form and start training at the confirmed price.
 * Returns `success: true` once the training was created, otherwise the error message.
 */
export async function startNewTraining(_: State, formData: FormData) {
  const session = await checkSessionAction();

  try {
    const form = await FormDataSchema.validate(Object.fromEntries(formData));
    await createNewModelTraining(session, form, form.confirmedPrice);
  } catch (err: unknown) {
    const {message} = err as Error;
    return buildState(false, message, null, null, null);
  }

  return buildState(true, null, null, null, null);
}
66 |
--------------------------------------------------------------------------------
/main/src/app/models/new/confirm-price.tsx:
--------------------------------------------------------------------------------
1 | import Modal from "~/components/modal";
2 | import {useFormState} from "react-dom";
3 | import {revalidate, startNewTraining} from "./actions";
4 | import SubmitButton from "~/components/form/submit-button";
5 | import {useRouter} from "next/navigation";
6 |
7 | import type {State} from "./form";
8 |
9 | export default function ConfirmModal({ // Price-confirmation modal: shows the quoted cost and, when the balance covers it, starts training via `startNewTraining`.
10 | state,
11 | open,
12 | setOpen,
13 | }: {
14 | state: State; // quote from the price step (priceInCents, userHasEnoughCredits, formData)
15 | open: boolean;
16 | setOpen: (state: boolean) => void;
17 | }) {
18 | const [newState, formAction] = useFormState(startNewTraining, state);
19 |
20 | const router = useRouter();
21 |
22 | if (newState.success) { // NOTE(review): side effect in the render body — every re-render after success calls revalidate() and router.push again; move into a useEffect.
23 | void revalidate().then(() => router.push("/models"));
24 | }
25 |
26 | function formActionWrapper(_: FormData) { // Ignores the submitted form; re-submits the form data saved from the price step plus the confirmed price.
27 | if (!state.priceInCents || !state.formData) { // NOTE: a price of 0 is treated as missing here too
28 | // TODO: handle
29 | return;
30 | }
31 |
32 | (state.formData as any as string[][]).push(["confirmedPrice", state.priceInCents.toString()]); // NOTE(review): assumes formData arrived as an array of [key, value] entries (a real FormData has no push) — verify; this mutates the prop, so repeated clicks append duplicate entries (Object.fromEntries on the server keeps the last).
33 | formAction(state.formData);
34 | }
35 |
36 | return ( // NOTE(review): the reimbursement copy below reads "due to a an internal error".
37 |
38 | Cost for this run
39 |
46 | {`$${((state.priceInCents || 0) / 100).toFixed(2)}`}
49 | {state.userHasEnoughCredits ? (
50 | <>
51 |
52 | The fine-tuning process will start and this amount will be deducted from your balance once you click
53 | "Confirm and start".
54 |
55 |
56 | If the training process stops due to a an internal error, your credits will be reimbursed.
57 |
58 | >
59 | ) : (
60 |
61 | {"You don't have enough balance left. Please first add more balance to your account."}
62 |
63 | )}
64 |
65 | {newState.message && {newState.message}
}
66 | {state.userHasEnoughCredits && (
67 |
70 | )}
71 |
72 | );
73 | }
74 |
--------------------------------------------------------------------------------
/main/src/app/models/new/form.tsx:
--------------------------------------------------------------------------------
1 | "use client";
2 | import {ChevronDownIcon} from "@heroicons/react/20/solid";
3 | import {useState} from "react";
4 | import {useFormState} from "react-dom";
5 | import Dropdown from "~/components/form/dropdown";
6 | import Label from "~/components/form/label";
7 | import TextField from "~/components/form/textfield";
8 | import {validateAndCalculatePrice} from "./actions";
9 | import ConfirmModal from "./confirm-price";
10 | import SubmitButton from "~/components/form/submit-button";
11 | import Tooltip from "~/components/tooltip";
12 | import Search from "./search";
13 | import {modelsToFinetune} from "~/constants/models";
14 |
/** Form state shared by the price-quote and training-start steps of the new-model flow. */
export type State = {
  success: boolean;
  priceInCents: number | null;
  userHasEnoughCredits: boolean | null;
  formData: FormData | null;
  message: string | null;
};

const initialState: State = {
  success: false,
  priceInCents: null,
  userHasEnoughCredits: null,
  formData: null,
  message: null,
};
30 |
31 | const tooltips = { // Help-text fragments for the new-model form fields (some link/fragment tags are missing from this listing).
32 | modelName: ( // NOTE(review): FormDataSchema only requires a non-empty string; the uniqueness/charset rules stated here are not checked there.
33 | <>What do you want to call your model? The name should be unique and contain only letters, numbers, and dashes.>
34 | ),
35 | dataset: (
36 | <>
37 | The dataset you want to use for fine-tuning. The dataset should be in{" "}
38 |
39 | this format
40 |
41 | . We recommend using a dataset with at least 100 conversations.
42 | >
43 | ),
44 | learningRate: (
45 | <>
46 | Learning rate controls how much the model gets updated after each batch. In general, the larger the dataset, the
47 | lower the learning rate should be.
48 | >
49 | ),
50 | numberOfEpochs: (
51 | <>
52 | An epoch is one full pass through the dataset. The number of epochs controls how many times the model will see the
53 | dataset. Usually, a number between 1 and 5 is sufficient.
54 | >
55 | ),
56 | baseModel: (
57 | <>
58 | The base model is the model you want to fine-tune. Zephyr is a fine-tuned version of Mistral which we found to be
59 | very effective for further fine-tuning.
60 | >
61 | ),
62 | };
63 |
64 | export default function NewModelForm() { // New-model form: submitting requests a price quote (validateAndCalculatePrice) and opens the confirmation modal.
65 | const [open, setOpen] = useState(false); // confirmation modal visibility (presumably passed to ConfirmModal)
66 | const [state, formAction] = useFormState(validateAndCalculatePrice, initialState);
67 |
68 | function formActionSetOpen(formData: FormData) { // open the modal right away, then request the quote
69 | setOpen(true);
70 | formAction(formData);
71 | }
72 |
73 | const [learningRate, setLearningRate] = useState("");
74 | const [baseModel, setBaseModel] = useState("");
75 |
76 | return (
77 | <>
78 |
79 |
156 | >
157 | );
158 | }
159 |
--------------------------------------------------------------------------------
/main/src/app/models/new/page.tsx:
--------------------------------------------------------------------------------
1 | import Padding from "~/components/padding";
2 | import PageHeading from "~/components/page-heading";
3 | import Sidebar from "~/components/sidebar";
4 | import NewModelForm from "./form";
5 | import {checkSession} from "~/server/utils/session";
6 |
7 | export default async function Page() { // "Start model training" page; awaits checkSession() (presumably enforcing sign-in) before rendering the form.
8 | await checkSession();
9 |
10 | return (
11 |
12 |
13 | Start model training
14 |
15 |
16 |
17 |
18 |
19 |
20 | );
21 | }
22 |
--------------------------------------------------------------------------------
/main/src/app/models/new/search.tsx:
--------------------------------------------------------------------------------
1 | import * as y from "yup";
2 | import {Transition} from "@headlessui/react";
3 | import {MagnifyingGlassIcon, PlusIcon} from "@heroicons/react/20/solid";
4 | import {ArrowPathIcon, CheckIcon} from "@heroicons/react/24/outline";
5 | import {Fragment, useState} from "react";
6 |
7 | import type {SearchResponse} from "~/app/api/search/route";
8 | import NewDatasetModal from "~/app/datasets/new/modal";
9 |
10 | export default function Search() { // Dataset search box: queries /api/search, lists matches and records the chosen dataset id (presumably in the hidden input at the end).
11 | const [loading, setLoading] = useState(false);
12 | const [open, setOpen] = useState(false);
13 | const [value, setValue] = useState("");
14 | const [searchResponse, setSerachResponse] = useState([]); // (setter name is misspelled "Serach")
15 | const [selected, setSelected] = useState(null); // id of the chosen dataset
16 |
17 | const [uploadDatasetModelOpen, setUploadDatasetModelOpen] = useState(false);
18 | function setUploadDatasetModelOpenWrapper(open: boolean, datasetId?: string, datasetName?: string) { // when the upload modal reports a new dataset, select it and close the results list
19 | setUploadDatasetModelOpen(open);
20 | if (datasetId && datasetName) {
21 | setSelected(datasetId);
22 | setValue(datasetName);
23 | setOpen(false);
24 | }
25 | }
26 |
27 | async function fetchSearchResults() { // NOTE(review): no try/finally — if the request or validation throws, `loading` stays true; responses can also resolve out of order and overwrite newer results.
28 | setLoading(true);
29 | const results = await fetch("/api/search", {
30 | method: "POST",
31 | body: JSON.stringify({query: value}),
32 | });
33 |
34 | const schema = y
35 | .array(
36 | y
37 | .object({
38 | id: y.string().required(),
39 | name: y.string().required(),
40 | rows: y.number().required(),
41 | })
42 | .required(),
43 | )
44 | .required();
45 |
46 | const json = await results.json();
47 | const validated = await schema.validate(json);
48 |
49 | setSerachResponse(validated);
50 | setLoading(false);
51 | }
52 |
53 | async function onOpenChange(open: boolean) { // opening with no cached results triggers a search
54 | setOpen(open);
55 |
56 | if (open && searchResponse.length === 0) {
57 | await fetchSearchResults();
58 | }
59 | }
60 |
61 | async function onValueChange(value: string) { // NOTE(review): fetchSearchResults reads `value` from this render's closure, so it searches the previous input rather than the text just typed — pass the new value through.
62 | setValue(value);
63 | await fetchSearchResults();
64 | }
65 |
66 | function onSelectedChange(id: string, name: string) {
67 | setSelected(id);
68 | setValue(name);
69 | setOpen(false);
70 | }
71 |
72 | return (
73 | <>
74 |
75 |
76 |
77 |
78 |
79 | onValueChange(e.target.value)}
85 | onFocus={() => onOpenChange(true)}
86 | onBlur={() => onOpenChange(false)}
87 | />
88 |
89 | {loading && (
90 |
93 | )}
94 | {selected &&
}
95 |
96 |
106 |
107 |
108 |
setUploadDatasetModelOpen(true)}
110 | className={`cursor-pointer flex items-center w-full px-3 py-2 hover:bg-gray-100 hover:dark:bg-gray-900 ${
111 | searchResponse.length > 0 && "border-b-2"
112 | }`}
113 | >
114 |
115 |
upload new dataset
116 |
117 | {searchResponse.map((option, index) => (
118 |
onSelectedChange(option.id, option.name)}
120 | key={index}
121 | className="cursor-pointer flex items-center hover:bg-gray-100 hover:dark:bg-gray-900"
122 | >
123 |
{option.name}
124 |
{`${option.rows} rows`}
125 |
126 | ))}
127 |
128 |
129 |
130 |
131 | {/* Hidden input containing the datasetId */}
132 |
133 | >
134 | );
135 | }
136 |
--------------------------------------------------------------------------------
/main/src/app/models/page.tsx:
--------------------------------------------------------------------------------
1 | import Padding from "~/components/padding";
2 | import PageHeading from "~/components/page-heading";
3 | import Sidebar from "~/components/sidebar";
4 | import List from "./list";
5 | import Button from "~/components/form/button";
6 | import {PlusIcon} from "@heroicons/react/20/solid";
7 | import Link from "next/link";
8 | import {getModels} from "~/server/database/model";
9 | import {checkSession} from "~/server/utils/session";
10 |
11 | import type {Dataset, Model} from "@prisma/client";
12 | import {getDatasetById} from "~/server/database/dataset";
13 |
14 | function newModelButton() { // "Train model" call-to-action (link/button markup missing from this listing).
15 | return (
16 |
17 |
18 |
19 | Train model
20 |
21 |
22 | );
23 | }
24 |
/**
 * Map Prisma `Model` rows to the flat props the models table renders.
 *
 * @param models       the user's models
 * @param datasetNames dataset id -> dataset name, used to label each row
 * @returns one row per model; `datasetName` and `wandbUrl` are undefined when unavailable
 */
function filterPropsForTable(models: Model[], datasetNames: Map<string, string>) {
  return models.map((model) => ({
    id: model.id,
    modelName: model.name,
    // Models without a dataset get undefined instead of a lookup under the "" key.
    datasetName: model.datasetId ? datasetNames.get(model.datasetId) : undefined,
    status: model.state,
    baseModel: model.baseModel,
    wandbUrl: model.wandbUrl || undefined,
  }));
}

// `Map` and `ReturnType` both need type arguments; without them this module does not compile.
export type ModelProps = ReturnType<typeof filterPropsForTable>[number];
38 |
/** Build a table row for a hosted chat model; `baseModel` carries its description text. */
const hostedModel = (id: string, modelName: string, baseModel: string): ModelProps => ({
  id,
  modelName,
  status: "online",
  baseModel,
  datasetName: undefined,
  wandbUrl: undefined,
});

// Always-online chat models (presumably listed alongside the user's own models).
const defaultModels: ModelProps[] = [
  hostedModel("mixtral-8x7b", "Mixtral 8x7b Chat", "Chat with Mistral's new model - Limited availability"),
  hostedModel("zephyr", "Zephyr 7b", "A chat fine-tune of Mistral's 7b base-model"),
  hostedModel("llama2-7b", "Llama 2 7b Chat", "The chat version of Meta's popular Llama 2 model"),
];
65 |
66 | export default async function Page() { // Models page: loads the signed-in user's models and resolves their dataset names for the table.
67 | const session = await checkSession();
68 | const models = await getModels(session.user.id);
69 |
70 | // get unique dataset Ids
71 | const datasetIds = Array.from(new Set(models.map((model) => model.datasetId)));
72 |
73 | // filter out null values
74 | const filteredDatasetIds = datasetIds.filter((id) => id !== null) as string[];
75 |
76 | // for each dataset, get the dataset name
77 | const datasets = await Promise.all(filteredDatasetIds.map((id) => getDatasetById(id, session.user.id))); // lookups run concurrently, one per distinct dataset id
78 |
79 | // create a map of dataset id to dataset name
80 | const datasetMap = new Map();
81 | datasets.forEach((dataset) => {
82 | if (!dataset) { // skip ids the lookup did not return (presumably missing or not owned by this user)
83 | return;
84 | }
85 |
86 | datasetMap.set(dataset.id, dataset.name);
87 | });
88 |
89 | return (
90 |
91 |
92 | Models
93 |
94 |
95 |
96 |
97 |
98 |
99 | );
100 | }
101 |
--------------------------------------------------------------------------------
/main/src/app/page.tsx:
--------------------------------------------------------------------------------
1 | "use client";
2 | import {signIn} from "next-auth/react";
3 | import {useEffect} from "react";
4 |
/** Landing route: immediately starts the NextAuth sign-in flow and renders nothing. */
export default function Page() {
  useEffect(function startSignIn() {
    void signIn();
  }, []);

  return null;
}
12 |
--------------------------------------------------------------------------------
/main/src/app/providers.tsx:
--------------------------------------------------------------------------------
1 | "use client";
2 |
3 | import { SessionProvider } from "next-auth/react";
4 |
5 | type Props = { // props for NextAuthProvider
6 | children?: React.ReactNode;
7 | };
8 |
9 | export const NextAuthProvider = ({ children }: Props) => { // Client component exposing the NextAuth session to `children` (the SessionProvider wrapper is missing from this listing).
10 | return {children} ;
11 | };
12 |
--------------------------------------------------------------------------------
/main/src/app/settings/actions.tsx:
--------------------------------------------------------------------------------
1 | "use server";
2 | import {getServerSession} from "next-auth";
3 | import {revalidatePath} from "next/cache";
4 | import {authOptions} from "~/server/auth";
5 | import {updateHfToken, updateName} from "~/server/database/user";
6 | import type {initialState} from "./form";
7 | import {logger} from "~/server/utils/observability/logtail";
8 |
/**
 * Server action for the settings form: update the signed-in user's Hugging Face token
 * and/or display name. Missing or empty fields are left unchanged.
 *
 * @throws Error("Not authenticated") when there is no session.
 */
export async function submitForm(_: typeof initialState, formData: FormData) {
  const session = await getServerSession(authOptions);
  if (!session) {
    throw new Error("Not authenticated");
  }

  // TODO: validate
  const nameField = formData.get("name");
  const tokenField = formData.get("hf_token");

  if (typeof tokenField === "string" && tokenField) {
    logger.info("Updating hfToken");
    await updateHfToken(session.user.id, tokenField);
  }

  if (typeof nameField === "string" && nameField) {
    logger.info("Updating name", {oldName: session.user.name, name: nameField});
    await updateName(session.user.id, nameField);
  }

  revalidatePath("/settings");
  return {message: "Success!", color: "text-green-600"};
}
35 |
--------------------------------------------------------------------------------
/main/src/app/settings/form.tsx:
--------------------------------------------------------------------------------
1 | "use client";
2 | import {useFormState, useFormStatus} from "react-dom";
3 | import Button from "~/components/form/button";
4 | import TextField from "~/components/form/textfield";
5 | import {submitForm} from "./actions";
6 |
type SettingsFormState = {
  message: string | null; // status text returned by submitForm; null before the first submit
  color: string | null; // Tailwind text-color class for that message
};

/** State of the settings form before any submission. */
export const initialState: SettingsFormState = {
  message: null,
  color: null,
};
14 |
15 | function SubmitButton() { // "Save" button for the settings form.
16 | const {pending} = useFormStatus(); // true while the enclosing form's action runs (presumably drives the button's loading state)
17 |
18 | return (
19 |
20 | Save
21 |
22 | );
23 | }
24 |
25 | export default function SettingsForm({name, hfToken}: {name?: string; hfToken?: string}) { // Settings form for display name and Hugging Face token, backed by the submitForm server action (form markup missing from this listing).
26 | const [state, formAction] = useFormState(submitForm, initialState); // `state` holds the {message, color} status returned by submitForm
27 |
28 | return (
29 |
52 | );
53 | }
54 |
--------------------------------------------------------------------------------
/main/src/app/settings/page.tsx:
--------------------------------------------------------------------------------
1 | import Padding from "~/components/padding";
2 | import PageHeading from "~/components/page-heading";
3 | import Sidebar from "~/components/sidebar";
4 | import {checkSession} from "~/server/utils/session";
5 | import SettingsForm from "./form";
6 |
7 | export default async function Settings() { // Settings page: reads the user's name and HF token from the session, presumably to prefill SettingsForm.
8 | const session = await checkSession();
9 |
10 | const name = session.user.name ?? undefined; // null -> undefined to match SettingsForm's optional `name` prop
11 | const hfToken = session.user.hfToken;
12 |
13 | return (
14 |
15 |
16 | Settings
17 |
18 |
19 |
20 |
21 |
22 |
23 |
24 | );
25 | }
26 |
--------------------------------------------------------------------------------
/main/src/components/form/button.tsx:
--------------------------------------------------------------------------------
1 | "use client";
2 | import {Loader2} from "lucide-react";
3 |
4 | interface Props { // Props for the shared Button.
5 | id?: string;
6 | onClick?: () => void;
7 | loading?: boolean; // when true, something (presumably the Loader2 spinner) renders before the children
8 | className?: string;
9 | type?: "button" | "submit"; // set to submit for forms
10 | children?: React.ReactNode;
11 | }
12 |
13 | export default function Button({ // Shared button component; the element markup is missing from this listing.
14 | id,
15 | onClick = undefined,
16 | loading = false,
17 | className = "",
18 | type = "button",
19 | children,
20 | }: Props) {
21 | return (
22 |
29 | {loading && }
30 | {children}
31 |
32 | );
33 | }
34 |
--------------------------------------------------------------------------------
/main/src/components/form/dropdown.tsx:
--------------------------------------------------------------------------------
1 | "use client";
2 | import {Fragment} from "react";
3 | import {Menu, Transition} from "@headlessui/react";
4 | import Label from "./label";
5 |
/** Props for Dropdown. */
type DropdownProps = {
  children: React.ReactNode;
  options: string[]; // menu entries
  onSelect?: (selectedOption: string) => void; // called with the chosen entry
  dropdownDirection?: "up" | "downRight" | "downLeft"; // Dropdown defaults this to "downRight"
  label?: string;
};

/** Join the truthy class names with single spaces. */
function classNames(...classes: string[]) {
  const kept: string[] = [];
  for (const cls of classes) {
    if (cls) {
      kept.push(cls);
    }
  }
  return kept.join(" ");
}
17 |
// Classes shared by every menu panel, whichever way it opens. Kept as complete class
// names so Tailwind's source scanner still picks them up.
const panelClasses =
  "rounded-md bg-white dark:bg-black shadow-lg ring-1 ring-black ring-opacity-5 focus:outline-none";

// Positioning classes per dropdown direction, each followed by the shared panel classes.
// NOTE(review): downRight is anchored with left-0 but uses origin-top-right; origin-top-left may be intended.
const styles = Object.freeze({
  up: `absolute left-0 bottom-0 z-10 ml-3 mb-10 w-56 ${panelClasses}`,
  downLeft: `absolute right-0 z-10 mt-2 w-56 origin-top-right ${panelClasses}`,
  downRight: `absolute left-0 z-10 mt-2 w-56 origin-top-right ${panelClasses}`,
});
26 |
27 | export default function Dropdown({children, options, onSelect, label, dropdownDirection = "downRight"}: DropdownProps) { // Menu-style dropdown (Headless UI Menu/Transition are imported for it); most markup is missing from this listing.
28 | return ( // NOTE(review): `label` has no default, so `label !== ""` is also true when it is omitted and the (presumably empty) Label still renders — default it to "" as TextField does.
29 | <>
30 | {label !== "" && {label} }
31 |
32 | {children}
33 |
34 |
43 |
44 |
58 |
59 |
60 |
61 | >
62 | );
63 | }
64 |
--------------------------------------------------------------------------------
/main/src/components/form/file-upload.tsx:
--------------------------------------------------------------------------------
1 | import {ArrowUpTrayIcon} from "@heroicons/react/20/solid";
2 | import React, {useState} from "react";
3 |
4 | interface Props { // Props for FileUpload.
5 | onFile: (file: File) => void; // called with each file the user picks or drops
6 | }
7 |
8 | export default function FileUpload({onFile}: Props) { // Click-or-drop file picker that remembers the last chosen file and reports it through onFile.
9 | const [file, setFile] = useState(); // last chosen file; undefined until one is picked
10 | const [dragging, setDragging] = useState(false);
11 |
12 | function setDraggingWithoutDefault(e: React.DragEvent, value: boolean) { // suppress the browser's default drag handling and track whether a drag is over the zone
13 | e.preventDefault();
14 | e.stopPropagation();
15 | setDragging(value);
16 | }
17 |
18 | function fileAdded(newFile: File) {
19 | onFile(newFile);
20 | setFile(newFile);
21 | }
22 |
23 | function onDrop(e: React.DragEvent) { // only the first dropped file is used; `dragging` is not reset here, so the zone keeps its highlight after a drop
24 | e.preventDefault();
25 | e.stopPropagation();
26 |
27 | const files = e.dataTransfer.files;
28 |
29 | if (files.length > 0) {
30 | fileAdded(files[0]!);
31 | }
32 | }
33 |
34 | function onFileChangeInternal(e: React.ChangeEvent) { // only the first selected file is used
35 | const files = e.target.files;
36 |
37 | if (files?.[0]) {
38 | fileAdded(files[0]);
39 | }
40 | }
41 |
42 | const draggingType = "bg-gray-950"; // NOTE(review): same value as droppedType, so the file/no-file distinction below has no visual effect
43 | const droppedType = "bg-gray-950";
44 | const styles = dragging ? (file ? droppedType : draggingType) : "";
45 |
46 | return (
47 |
48 |
setDraggingWithoutDefault(e, true)}
52 | onDragLeave={(e) => setDraggingWithoutDefault(e, false)}
53 | onDrop={onDrop}
54 | >
55 |
56 |
57 | {file === undefined ? (
58 |
59 |
60 | Click to upload or drag and drop
61 |
62 | Dataset File in OpenAI Format. (.jsonl)
63 |
64 | ) : (
65 |
66 |
67 | Selected File: {file.name}
68 |
69 | Click or drag and drop to change file.
70 |
71 | )}
72 |
73 |
74 |
75 |
76 | );
77 | }
78 |
--------------------------------------------------------------------------------
/main/src/components/form/label.tsx:
--------------------------------------------------------------------------------
1 | export default function Label({className = "", children}: {className?: string; children: React.ReactNode}) { // Form label wrapper, presumably applying `className` to the label element (markup missing from this listing).
2 | return {children} ;
3 | }
4 |
--------------------------------------------------------------------------------
/main/src/components/form/submit-button.tsx:
--------------------------------------------------------------------------------
1 | import {useFormStatus} from "react-dom";
2 | import Button from "./button";
3 |
4 | export default function SubmitButton({className, children}: {className?: string; children: React.ReactNode}) { // Submit button that tracks the enclosing form's pending state (markup missing from this listing).
5 | const {pending} = useFormStatus(); // true while the parent form's action runs; presumably passed to Button's `loading`
6 |
7 | return (
8 |
9 | {children}
10 |
11 | );
12 | }
13 |
--------------------------------------------------------------------------------
/main/src/components/form/textfield.tsx:
--------------------------------------------------------------------------------
1 | "use client";
2 | interface TextFieldProps { // Props for the shared TextField.
3 | value?: string;
4 | defaultValue?: string;
5 | onChange?: (value: any) => void; // receives the input's string value (e.target.value), not the event
6 |
7 | type?: "text" | "number" | "password" | "email";
8 | name?: string;
9 | id?: string;
10 | label?: string;
11 | placeholder?: string;
12 | disabled?: boolean;
13 | required?: boolean;
14 |
15 | className?: string;
16 | }
17 |
18 | export default function TextField({ // Labelled input; the label renders only when non-empty (input markup missing from this listing).
19 | value = undefined,
20 | defaultValue = undefined,
21 | onChange = undefined,
22 | type = "text",
23 | name = "",
24 | id = "",
25 | label = "",
26 | placeholder = "",
27 | disabled = false,
28 | required = false,
29 | className = "",
30 | }: TextFieldProps) {
31 | return (
32 |
33 | {label !== "" && (
34 |
35 | {label}
36 |
37 | )}
38 |
39 | onChange && onChange(e.target.value)}
50 | />
51 |
52 |
53 | );
54 | }
55 |
--------------------------------------------------------------------------------
/main/src/components/modal.tsx:
--------------------------------------------------------------------------------
1 | import {Fragment} from "react";
2 | import {Dialog, Transition} from "@headlessui/react";
3 |
4 | interface Props { // Props for the shared Modal.
5 | open: boolean;
6 | setOpen: (open: boolean) => void;
7 | size?: "sm" | "xl"; // panel max width from the sm breakpoint up (see sizeClasses)
8 | children?: React.ReactNode;
9 | }
10 |
11 | export default function Modal({open, children, setOpen, size = "sm"}: Props) { // Modal built on the imported Headless UI Dialog/Transition; the markup is missing from this listing.
12 | const sizeClasses = {
13 | sm: "sm:max-w-sm",
14 | xl: "sm:max-w-xl",
15 | };
16 |
17 | return (
18 |
19 |
20 |
29 |
30 |
31 |
32 |
33 |
34 |
43 |
46 | {children}
47 |
48 |
49 |
50 |
51 |
52 |
53 | );
54 | }
55 |
--------------------------------------------------------------------------------
/main/src/components/padding.tsx:
--------------------------------------------------------------------------------
1 | export default function Padding(props: {children: React.ReactNode}) { // Layout wrapper around page content (wrapper element missing from this listing; presumably adds padding).
2 | return {props.children}
;
3 | }
4 |
--------------------------------------------------------------------------------
/main/src/components/page-heading.tsx:
--------------------------------------------------------------------------------
1 | interface Props { // Props for PageHeading.
2 | children: React.ReactNode; // heading content
3 | primary?: React.ReactNode; // optional extra content rendered after the heading (e.g. a primary action)
4 | }
5 |
6 | export default function PageHeading({children, primary}: Props) { // Page title row (markup missing from this listing).
7 | return (
8 |
9 |
10 |
{children}
11 |
12 |
{primary}
13 |
14 | );
15 | }
16 |
--------------------------------------------------------------------------------
/main/src/components/sidebar.tsx:
--------------------------------------------------------------------------------
1 | "use client";
2 | import {
3 | Bars3Icon,
4 | Cog8ToothIcon,
5 | ChevronUpDownIcon,
6 | XMarkIcon,
7 | CreditCardIcon,
8 | ChatBubbleLeftRightIcon,
9 | TableCellsIcon,
10 | } from "@heroicons/react/24/outline";
11 |
12 | import Dropdown from "./form/dropdown";
13 | import {Dialog, Transition} from "@headlessui/react";
14 | import {Fragment, useState} from "react";
15 | import Link from "next/link";
16 | import {signOut, useSession} from "next-auth/react";
17 | import Image from "next/image";
18 | import UserAvatar from "./user-avatar";
19 |
20 | // Navigation entry to highlight; "None" renders the sidebar without navigation entries.
21 | type SidebarItem = "Models" | "Datasets" | "Billing" | "Settings" | "None";
22 |
23 | // TODO: this helper is copied from tailwindui, so there are several copies in the codebase.
24 | // We should consolidate them into one.
25 | /** Joins the truthy class names with single spaces. */
26 | function classNames(...classes: string[]) {
27 | const truthy = classes.filter((cls) => Boolean(cls));
28 | return truthy.join(" ");
29 | }
30 |
31 | interface Props {
32 | // Page content rendered next to the sidebar.
33 | children: React.ReactNode;
34 | // Which navigation entry is active.
35 | current: SidebarItem;
36 | // Optional title shown in the top bar.
37 | title?: string;
38 | }
33 |
33a | /**
33b |  * App shell: a slide-over sidebar (opened/closed via `sidebarOpen`), a static desktop sidebar
33c |  * with the navigation links, a profile dropdown with sign-out, and the page content
33d |  * (`props.children`).
33e |  * NOTE(review): the JSX markup is stripped from this listing; this description is inferred
33f |  * from the remaining text and imports.
33g |  */
34 | export default function Sidebar(props: Props) {
35 | // TODO: pass in through props
36 | const email = useSession().data?.user.email || "";
37 |
38 | const [sidebarOpen, setSidebarOpen] = useState(false);
39 |
40 | let navigation = [
41 | {
42 | name: "Models",
43 | page: "/models",
44 | icon: ChatBubbleLeftRightIcon,
45 | current: false,
46 | },
47 | {
48 | name: "Datasets",
49 | page: "/datasets",
50 | icon: TableCellsIcon,
51 | current: false,
52 | },
53 | {name: "Billing", page: "/billing", icon: CreditCardIcon, current: false, headline: "Account"},
54 | {
55 | name: "Settings",
56 | page: "/settings",
57 | icon: Cog8ToothIcon,
58 | current: false,
59 | },
60 | ];
61 |
62 | // Mark the entry matching `props.current` as active. This only mutates the items, so
62a | // `forEach` is used rather than `map` (whose result would be thrown away).
63 | navigation.forEach((item) => {
64 | if (item.name === props.current) {
65 | item.current = true;
66 | }
67 | });
68 |
68a | // "None" renders the shell without any navigation entries.
69 | if (props.current === "None") {
70 | navigation = [];
71 | }
72 |
73 | return (
74 | <>
75 |
76 |
77 |
78 |
87 |
88 |
89 |
90 |
91 |
100 |
101 |
110 |
111 | setSidebarOpen(false)}>
112 | Close sidebar
113 |
114 |
115 |
116 |
117 |
118 |
119 |
120 |
127 |
128 |
129 |
130 |
131 |
132 |
133 | {navigation.map((item) => (
134 |
135 |
144 |
151 | {item.name}
152 |
153 |
154 | ))}
155 |
156 |
157 |
158 |
159 |
160 |
161 |
162 |
163 |
164 |
165 |
166 | {/* Static sidebar for desktop */}
167 |
168 | {/* Sidebar component, swap this element with another sidebar if you like */}
169 |
170 |
171 |
172 |
173 |
174 |
175 |
176 |
177 |
178 |
179 | {navigation.map((item) => (
180 |
181 | {item.headline && (
182 | {item.headline}
183 | )}
184 |
193 |
200 | {item.name}
201 |
202 |
203 | ))}
204 |
205 |
206 |
207 |
208 |
signOut({callbackUrl: "/"})} dropdownDirection="up">
209 |
210 |
211 |
212 | Your profile
213 | {email}
214 |
215 |
216 |
217 |
218 |
219 |
220 |
221 |
222 |
223 |
224 |
225 |
226 |
setSidebarOpen(true)}>
227 | Open sidebar
228 |
229 |
230 |
{props.title}
231 |
signOut({callbackUrl: "/"})}>
232 | Your profile
233 |
234 |
235 |
236 |
237 |
{props.children}
238 |
239 | >
240 | );
241 | }
242 |
--------------------------------------------------------------------------------
/main/src/components/tiles.tsx:
--------------------------------------------------------------------------------
1 | import Image from "next/image";
2 |
2a | /** One selectable tile. */
3 | export interface Tile {
4 | id: number;
5 | name: string;
5a | // Image source for the tile (presumably passed to next/image, which is imported above).
6 | imageUrl: string;
6a | // Presumably controls the selected styling — the markup is stripped in this listing.
7 | selected: boolean;
7a | // Invoked from the tile's click handler.
8 | onClick: () => void;
9 | }
10 |
11 | interface TileProps {
12 | tiles: Tile[];
13 | }
14 |
14a | /** Renders one clickable entry per tile, labelled with `tile.name`. The JSX is stripped in this listing. */
15 | export default function Tiles({tiles}: TileProps) {
16 | return (
17 |
18 | {tiles.map((tile) => (
19 | {
22 | tile.onClick();
23 | }}
24 | >
25 |
26 |
27 |
28 |
29 |
30 |
31 | {tile.name}
32 |
33 |
34 |
35 |
36 | ))}
37 |
38 | );
39 | }
40 |
--------------------------------------------------------------------------------
/main/src/components/tooltip.tsx:
--------------------------------------------------------------------------------
1 | import {HoverCard, HoverCardContent, HoverCardTrigger} from "~/components/ui/hover-card";
2 | import {QuestionMarkCircleIcon} from "@heroicons/react/24/outline";
3 |
3a | /**
3b |  * Help tooltip built from HoverCard and QuestionMarkCircleIcon (imported above); `children`
3c |  * is presumably the hover-card content and `className` presumably styles the trigger or
3d |  * content. NOTE(review): the JSX is stripped from this listing — confirm.
3e |  */
4 | export default function Tooltip({children, className}: {children: React.ReactElement; className?: string}) {
5 | return (
6 |
7 |
8 |
9 |
10 |
11 |
15 | {children}
16 |
17 |
18 |
19 | );
20 | }
21 |
--------------------------------------------------------------------------------
/main/src/components/ui/hover-card.tsx:
--------------------------------------------------------------------------------
1 | "use client";
2 |
3 | import * as React from "react";
4 | import * as HoverCardPrimitive from "@radix-ui/react-hover-card";
5 |
6 | import {cn} from "~/components/utils/utils";
7 |
7a | // Aliases for the Radix hover-card root and trigger primitives.
8 | const HoverCard = HoverCardPrimitive.Root;
9 |
10 | const HoverCardTrigger = HoverCardPrimitive.Trigger;
11 |
11a | /**
11b |  * Styled hover-card content panel. Defaults: `align = "center"`, `sideOffset = 4`; remaining
11c |  * props are forwarded. NOTE(review): the generic type arguments and JSX were stripped from
11d |  * this listing — it presumably renders `HoverCardPrimitive.Content` with `ref` and a
11e |  * `cn(...)`-merged `className`.
11f |  */
12 | const HoverCardContent = React.forwardRef<
13 | React.ElementRef,
14 | React.ComponentPropsWithoutRef
15 | >(({className, align = "center", sideOffset = 4, ...props}, ref) => (
16 |
26 | ));
26a | // Keep the primitive's display name so React DevTools shows a meaningful component name.
27 | HoverCardContent.displayName = HoverCardPrimitive.Content.displayName;
28 |
29 | export {HoverCard, HoverCardTrigger, HoverCardContent};
30 |
--------------------------------------------------------------------------------
/main/src/components/ui/slider.tsx:
--------------------------------------------------------------------------------
1 | "use client";
2 |
3 | import * as React from "react";
4 | import * as SliderPrimitive from "@radix-ui/react-slider";
5 |
6 | import {cn} from "~/components/utils/utils";
7 |
7a | /**
7b |  * Styled wrapper around the Radix slider root; `className` and the remaining props are
7c |  * forwarded. NOTE(review): the generic type arguments and JSX were stripped from this
7d |  * listing — it presumably renders `SliderPrimitive.Root` with track/range/thumb children.
7e |  */
8 | const Slider = React.forwardRef<
9 | React.ElementRef,
10 | React.ComponentPropsWithoutRef
11 | >(({className, ...props}, ref) => (
12 |
17 |
18 |
19 |
20 |
21 |
22 | ));
22a | // Keep the primitive's display name for React DevTools.
23 | Slider.displayName = SliderPrimitive.Root.displayName;
24 |
25 | export {Slider};
26 |
--------------------------------------------------------------------------------
/main/src/components/ui/switch.tsx:
--------------------------------------------------------------------------------
1 | "use client";
2 |
3 | import * as React from "react";
4 | import * as SwitchPrimitives from "@radix-ui/react-switch";
5 |
6 | import {cn} from "~/components/utils/utils";
7 |
7a | /**
7b |  * Styled wrapper around the Radix switch root; `className` and the remaining props are
7c |  * forwarded. NOTE(review): the generic type arguments and JSX were stripped from this
7d |  * listing — it presumably renders `SwitchPrimitives.Root` with a thumb child.
7e |  */
8 | const Switch = React.forwardRef<
9 | React.ElementRef,
10 | React.ComponentPropsWithoutRef
11 | >(({className, ...props}, ref) => (
12 |
20 |
25 |
26 | ));
26a | // Keep the primitive's display name for React DevTools.
27 | Switch.displayName = SwitchPrimitives.Root.displayName;
28 |
29 | export {Switch};
30 |
--------------------------------------------------------------------------------
/main/src/components/user-avatar.tsx:
--------------------------------------------------------------------------------
0a | /**
0b |  * Avatar placeholder showing the first character of `name`, upper-cased. A missing or empty
0c |  * name yields an empty string. The surrounding markup is stripped in this listing.
0d |  */
1 | export default function UserAvatar({name}: {name?: string}) {
1a | // `charAt(0)` on "" returns "", so a missing name renders nothing instead of throwing.
2 | const nameWithDefault = name || "";
3 |
4 | return (
5 |
6 |
7 | {nameWithDefault.charAt(0).toUpperCase()}
8 |
9 |
10 | );
11 | }
12 |
--------------------------------------------------------------------------------
/main/src/components/utils/utils.ts:
--------------------------------------------------------------------------------
1 | import clsx from "clsx";
2 | import {twMerge} from "tailwind-merge";
3 |
4 | import type {ClassValue} from "clsx";
5 |
6 | /**
7 |  * Builds a className string: `clsx` combines the class values (strings, arrays, conditional
8 |  * objects), then `twMerge` resolves conflicting Tailwind utility classes.
9 |  */
10 | export function cn(...inputs: ClassValue[]) {
11 | const combined = clsx(inputs);
12 | return twMerge(combined);
13 | }
9 |
--------------------------------------------------------------------------------
/main/src/components/warning.tsx:
--------------------------------------------------------------------------------
1 | import {ExclamationTriangleIcon} from "@heroicons/react/24/outline";
2 |
3 | interface Props {
3a | // Text shown in the warning.
4 | message: string;
5 | }
6 |
6a | /** Warning banner showing `props.message`, presumably next to ExclamationTriangleIcon (imported above); markup stripped in this listing. */
7 | export default function Warning(props: Props) {
8 | return (
9 |
10 |
11 |
12 |
13 |
{props.message}
14 |
15 | );
16 | }
17 |
--------------------------------------------------------------------------------
/main/src/constants/modal.ts:
--------------------------------------------------------------------------------
1 | import type {modelsToFinetune} from "./models";
2 |
3 | // Streaming generation endpoint (Modal-hosted LoRA server) for each supported base model.
4 | export const inferenceEndpoints: Readonly> = Object.freeze({
5 | "HuggingFaceH4/zephyr-7b-beta":
6 | "https://havenhq--lora-server-huggingfaceh4-zephyr-7b-beta-model--b80943.modal.run/generate_stream",
7 | "meta-llama/Llama-2-7b-chat-hf":
8 | "https://havenhq--lora-server-meta-llama-llama-2-7b-chat-hf-model-4eee5b.modal.run/generate_stream",
9 | "mistralai/Mixtral-8x7b-Instruct-v0.1":
10 | "https://havenhq--lora-server-mistralai-mixtral-8x7b-instruct-v0--9fe7b9.modal.run/generate_stream",
11 | });
12 |
13 | // Endpoint of the model export service.
14 | export const exportEndpoint = "https://havenhq--model-export-export.modal.run";
15 |
16 | // Fine-tuning endpoint per base model (Mixtral uses a dedicated service). Frozen like
17 | // `inferenceEndpoints` so the runtime object is as immutable as its `Readonly` type says.
18 | export const trainEndpoint: Readonly> = Object.freeze({
19 | "HuggingFaceH4/zephyr-7b-beta": "https://havenhq--finetuning-service-train.modal.run",
20 | "meta-llama/Llama-2-7b-chat-hf": "https://havenhq--finetuning-service-train.modal.run",
21 | "mistralai/Mixtral-8x7b-Instruct-v0.1": "https://havenhq--mixtral-finetuning-train.modal.run",
22 | });
19 |
--------------------------------------------------------------------------------
/main/src/constants/models.ts:
--------------------------------------------------------------------------------
1 | // The selection of fine-tunable models, shared by client-side and server-side code.
1a | // Maps a short model id to its Hugging Face repo id. (The export name's "Loopup" is a typo for
1b | // "Lookup"; it is kept because other modules import it under this name.)
2 | export const defaultModelLoopup = Object.freeze({
3 | "mixtral-8x7b": "mistralai/Mixtral-8x7b-Instruct-v0.1",
4 | zephyr: "HuggingFaceH4/zephyr-7b-beta",
5 | "llama2-7b": "meta-llama/Llama-2-7b-chat-hf",
6 | });
7 |
7a | // Hugging Face repo ids of all selectable base models.
8 | export const modelsToFinetune = Object.values(defaultModelLoopup);
8a | // Short model ids (the lookup keys).
9 | export const modelIds = Object.keys(defaultModelLoopup);
10 |
10a | // Type of a single element of `modelsToFinetune`.
11 | export type Models = (typeof modelsToFinetune)[number];
12 |
--------------------------------------------------------------------------------
/main/src/env.mjs:
--------------------------------------------------------------------------------
1 | import {createEnv} from "@t3-oss/env-nextjs";
2 | import {z} from "zod";
3 |
4 | export const env = createEnv({
5 | /**
6 | * Specify your server-side environment variables schema here. This way you can ensure the app
7 | * isn't built with invalid env vars.
8 | */
9 | server: {
10 | DATABASE_URL: z
11 | .string()
12 | .url()
13 | .refine((str) => !str.includes("YOUR_MYSQL_URL_HERE"), "You forgot to change the default URL"),
14 | NODE_ENV: z.enum(["development", "test", "production"]).default("development"),
14a | // Required in production; optional in development and test.
15 | NEXTAUTH_SECRET: process.env.NODE_ENV === "production" ? z.string() : z.string().optional(),
16 | NEXTAUTH_URL: z.preprocess(
17 | // This keeps Vercel deployments from failing when NEXTAUTH_URL is not set,
18 | // since NextAuth.js automatically uses VERCEL_URL if present.
19 | (str) => process.env.VERCEL_URL ?? str,
20 | // VERCEL_URL doesn't include `https`, so it can't be validated as a URL.
21 | process.env.VERCEL ? z.string() : z.string().url(),
22 | ),
23 | // Add `.min(1)` on ID and SECRET if you want to make sure they're not empty. (With
23a | // `emptyStringAsUndefined: true` below, empty strings are already rejected by `z.string()`.)
24 | GOOGLE_CLIENT_ID: z.string(),
25 | GOOGLE_CLIENT_SECRET: z.string(),
26 |
27 | STRIPE_SECRET_KEY: z.string(),
28 | STRIPE_PUBLISHABLE_KEY: z.string(),
29 |
30 | WEBHOOK_SECRET: z.string(),
31 | MODAL_HOSTNAME: z.string(),
32 | MODAL_AUTH_TOKEN: z.string(),
33 | RESEND_KEY: z.string(),
34 | LOGTAIL_KEY: z.string(),
35 | LOGTAIL_ENV: z.string(),
36 | },
37 |
38 | /**
39 | * Specify your client-side environment variables schema here. This way you can ensure the app
40 | * isn't built with invalid env vars. To expose them to the client, prefix them with
41 | * `NEXT_PUBLIC_`.
42 | */
43 | client: {
44 | // NEXT_PUBLIC_CLIENTVAR: z.string(),
45 | },
46 |
47 | /**
48 | * You can't destructure `process.env` as a regular object in the Next.js edge runtimes (e.g.
49 | * middlewares) or client-side, so each variable is listed explicitly.
50 | */
51 | runtimeEnv: {
52 | DATABASE_URL: process.env.DATABASE_URL,
53 | NODE_ENV: process.env.NODE_ENV,
54 | NEXTAUTH_SECRET: process.env.NEXTAUTH_SECRET,
55 | NEXTAUTH_URL: process.env.NEXTAUTH_URL,
56 | GOOGLE_CLIENT_ID: process.env.GOOGLE_CLIENT_ID,
57 | GOOGLE_CLIENT_SECRET: process.env.GOOGLE_CLIENT_SECRET,
58 | STRIPE_SECRET_KEY: process.env.STRIPE_SECRET_KEY,
59 | STRIPE_PUBLISHABLE_KEY: process.env.STRIPE_PUBLISHABLE_KEY,
60 | WEBHOOK_SECRET: process.env.WEBHOOK_SECRET,
61 | MODAL_HOSTNAME: process.env.MODAL_HOSTNAME,
62 | MODAL_AUTH_TOKEN: process.env.MODAL_AUTH_TOKEN,
63 | RESEND_KEY: process.env.RESEND_KEY,
64 | LOGTAIL_KEY: process.env.LOGTAIL_KEY,
65 | LOGTAIL_ENV: process.env.LOGTAIL_ENV,
66 | },
67 | /**
68 | * Run `build` or `dev` with `SKIP_ENV_VALIDATION` to skip env validation. This is especially
69 | * useful for Docker builds.
70 | */
71 | skipValidation: !!process.env.SKIP_ENV_VALIDATION,
72 | /**
73 | * Makes it so that empty strings are treated as undefined.
74 | * `SOME_VAR: z.string()` and `SOME_VAR=''` will throw an error.
75 | */
76 | emptyStringAsUndefined: true,
77 | });
78 |
--------------------------------------------------------------------------------
/main/src/pages/_app.tsx:
--------------------------------------------------------------------------------
1 | // pages/_app.tsx
2 | import {SessionProvider} from "next-auth/react";
3 | import "../styles/globals.css";
4 |
4a | /**
4b |  * Custom Next.js App for the pages/ router (global styles are imported above).
4c |  * NOTE(review): the JSX is stripped from this listing; it presumably wraps `Component` in
4d |  * `SessionProvider` (imported above) — confirm.
4e |  */
5 | function MyApp({Component, pageProps}: any) {
6 | return (
7 |
8 |
9 |
10 |
11 |
12 |
13 |
14 | );
15 | }
16 |
17 | export default MyApp;
18 |
--------------------------------------------------------------------------------
/main/src/pages/_document.tsx:
--------------------------------------------------------------------------------
1 | import Document, {Html, Head, Main, NextScript} from "next/document";
2 |
2a | /** Custom Next.js Document for the pages/ router; the Html/Head/Main/NextScript markup is stripped from this listing. */
3 | class MyDocument extends Document {
4 | render() {
5 | return (
6 |
7 |
8 |
9 |
10 |
11 |
12 |
13 | );
14 | }
15 | }
16 |
17 | export default MyDocument;
18 |
--------------------------------------------------------------------------------
/main/src/pages/auth/login/index.tsx:
--------------------------------------------------------------------------------
1 | import Button from "~/components/form/button";
2 | import TextField from "~/components/form/textfield";
3 | import {getCsrfToken, signIn} from "next-auth/react";
4 | import type {GetServerSidePropsContext, InferGetServerSidePropsType} from "next";
5 | import {useRouter} from "next/router";
6 | import Image from "next/image";
7 | import {getServerAuthSession} from "~/server/auth";
8 | import {useEffect} from "react";
9 |
9a | /**
9b |  * Login page. If `getServerSideProps` found an existing session (non-empty `expires`), the
9c |  * user is sent straight to /models. Offers "Sign in with Google"; the rest of the form markup
9d |  * is stripped from this listing.
9e |  */
10 | export default function SignIn({csrfToken, session}: InferGetServerSidePropsType) {
11 | const router = useRouter();
12 |
12a | // Redirect already-authenticated users. The dependency list stops this from re-running (and
12b | // pushing the route again) on every render.
13 | useEffect(() => {
14 | if (session.expires) {
15 | void router.push("/models");
16 | }
17 | }, [session.expires, router]);
18 |
19 | return (
20 | <>
21 |
22 |
23 |
24 |
25 |
26 | Log in or create an account
27 |
28 |
29 |
30 |
31 |
32 |
43 |
44 |
45 |
46 |
49 |
50 | or
51 |
52 |
53 |
54 |
signIn("google")}
56 | className="flex w-full items-center justify-center gap-3 rounded-md bg-white px-3 py-1.5 text-black focus-visible:outline focus-visible:outline-2 focus-visible:outline-offset-2 focus-visible:outline-[#24292F]"
57 | >
58 |
68 |
69 |
70 | Sign in with Google
71 |
72 |
73 |
84 |
85 |
86 |
87 |
88 |
89 | >
90 | );
91 | }
92 |
93 | /**
94 |  * Loads the CSRF token for the sign-in form and the expiry of any existing session.
95 |  * `session.expires` is "" when there is no session, which the page treats as "not logged in".
96 |  */
97 | export async function getServerSideProps(context: GetServerSidePropsContext) {
98 | const token = await getCsrfToken(context);
99 | const existingSession = await getServerAuthSession(context);
100 | const expires = existingSession?.expires || "";
101 |
102 | return {props: {csrfToken: token, session: {expires}}};
103 | }
106 |
--------------------------------------------------------------------------------
/main/src/pages/auth/new-user/index.tsx:
--------------------------------------------------------------------------------
1 | import {useRouter} from "next/router";
2 | import {useEffect} from "react";
3 |
3a | /** Post-sign-up landing page (`pages.newUser` in authOptions): immediately redirects to /models. */
4 | export default function Redirect() {
5 | const router = useRouter();
6 |
6a | // Only re-run when the router changes, not on every render (which would push repeatedly).
7 | useEffect(() => {
8 | // Redirect to /models
9 | void router.push("/models");
10 | }, [router]);
11 |
12 | return null;
13 | }
14 |
--------------------------------------------------------------------------------
/main/src/pages/auth/verify-request/index.tsx:
--------------------------------------------------------------------------------
1 | import Image from "next/image";
2 |
2a | /**
2b |  * Verify-request page (`pages.verifyRequest` in authOptions), shown after an email sign-in
2c |  * link was requested. Note: the component is named `Login` although it is this page. The
2d |  * markup around the message is stripped from this listing.
2e |  */
3 | export default function Login() {
4 | return (
5 | <>
6 |
7 |
8 |
9 |
16 |
17 |
18 |
19 |
20 | Please open the link we just sent to your email address to be logged in.
21 |
22 |
23 |
24 |
25 | >
26 | );
27 | }
28 |
--------------------------------------------------------------------------------
/main/src/server/auth.ts:
--------------------------------------------------------------------------------
1 | import {PrismaAdapter} from "@next-auth/prisma-adapter";
2 | import {type GetServerSidePropsContext} from "next";
3 | import {getServerSession, type DefaultSession, type NextAuthOptions} from "next-auth";
4 | import EmailProvider from "next-auth/providers/email";
5 | import GoogleProvider from "next-auth/providers/google";
6 |
7 | import {db} from "~/server/database";
8 | import {addApiKeyToUser} from "./database/user";
9 | import type {User} from "@prisma/client";
10 | import {sendVerifyEmail} from "./utils/mail";
11 | import {EventName, sendEvent} from "./utils/observability/posthog";
12 |
13 | /**
14 | * Module augmentation for `next-auth` types. Allows us to add custom properties to the `session`
15 | * object and keep type safety.
16 | *
17 | * @see https://next-auth.js.org/getting-started/typescript#module-augmentation
18 | */
19 | declare module "next-auth" {
20 | interface Session extends DefaultSession {
20a | // The extra fields are copied from the database user in the `session` callback of `authOptions`.
21 | user: DefaultSession["user"] & {
22 | id: string;
22a | // Account balance, in cents.
23 | centsBalance: number;
24 | stripeCustomerId?: string;
25 | stripePaymentMethodId?: string;
25a | // Key issued to every new user by the `createUser` event in `authOptions`.
26 | apiKey: string;
26a | // Hugging Face access token, if the user has provided one.
27 | hfToken?: string;
28 | image?: string;
29 | };
30 | }
31 |
32 | // interface User {
33 | // // ...other properties
34 | // // role: UserRole;
35 | // }
36 | }
37 |
38 | /**
39 |  * NextAuth.js configuration: database-backed sessions through the Prisma adapter, email
40 |  * (verification link) and Google sign-in, custom auth pages, and a `createUser` hook that
41 |  * records a sign-up event and gives the new user an API key.
42 |  *
43 |  * @see https://next-auth.js.org/configuration/options
44 |  */
45 | export const authOptions: NextAuthOptions = {
46 | secret: process.env.NEXTAUTH_SECRET,
47 | callbacks: {
48 | // Copy the app-specific columns of the database user onto the session's `user`.
49 | session: ({session, user}) => {
50 | const dbUser = user as User;
51 | return {
52 | ...session,
53 | user: {
54 | ...session.user,
55 | id: user.id,
56 | centsBalance: dbUser.centsBalance,
57 | stripeCustomerId: dbUser.stripeCustomerId,
58 | stripePaymentMethodId: dbUser.stripePaymentMethodId,
59 | apiKey: dbUser.apiKey,
60 | hfToken: dbUser.hfToken,
61 | image: dbUser.image,
62 | },
63 | };
64 | },
65 | },
66 | adapter: PrismaAdapter(db),
67 | providers: [
68 | // Sign-in links are delivered through the app's own mailer.
69 | EmailProvider({
70 | sendVerificationRequest: async ({identifier, url}) => {
71 | await sendVerifyEmail(identifier, url);
72 | },
73 | }),
74 | GoogleProvider({
75 | clientId: process.env.GOOGLE_CLIENT_ID ?? "",
76 | clientSecret: process.env.GOOGLE_CLIENT_SECRET ?? "",
77 | }),
78 | // More providers can be added here. Some need extra Account-model fields (e.g. GitHub's
79 | // `refresh_token_expires_in`); see https://next-auth.js.org/providers/github
80 | ],
81 | session: {strategy: "database"},
82 | pages: {
83 | signIn: "/auth/login",
84 | verifyRequest: "/auth/verify-request",
85 | newUser: "/auth/new-user",
86 | },
87 | events: {
88 | createUser: async ({user: newUser}) => {
89 | sendEvent(newUser.id, EventName.NEW_USER, {email: newUser.email});
90 |
91 | // Every account gets an API key at sign-up.
92 | await addApiKeyToUser(newUser.id);
93 | },
94 | },
95 | };
96 |
97 | /**
98 |  * Wrapper for `getServerSession` that supplies `authOptions`, so callers don't have to
99 |  * import it themselves.
100 |  *
101 |  * @see https://next-auth.js.org/configuration/nextjs
102 |  */
103 | export const getServerAuthSession = ({
104 | req,
105 | res,
106 | }: {
107 | req: GetServerSidePropsContext["req"];
108 | res: GetServerSidePropsContext["res"];
109 | }) => getServerSession(req, res, authOptions);
108 |
--------------------------------------------------------------------------------
/main/src/server/controller/new-dataset.ts:
--------------------------------------------------------------------------------
1 | import {promises as fs} from "fs";
2 | import {v4 as uuid} from "uuid";
3 | import {checkFileValidity} from "./process-dataset";
4 | import {createDataset} from "../database/dataset";
5 | import {uploadFile} from "../utils/modal";
6 |
7 | import type {FormDataType} from "~/app/datasets/new/actions";
8 | import {logger} from "../utils/observability/logtail";
9 |
10 | /**
11 |  * Validates an uploaded JSONL dataset, uploads it and creates its database entry.
12 |  * @param userId - owner of the new dataset.
13 |  * @param validatedForm - parsed "new dataset" form; `dropzoneFile` holds the uploaded file.
14 |  * @returns the id of the created dataset.
15 |  * @throws Error if the file is invalid, has fewer than 20 rows, or cannot be uploaded.
16 |  */
17 | export async function uploadDataset(userId: string, validatedForm: FormDataType) {
18 | const datasetText = await (validatedForm.dropzoneFile as File).text();
19 | const parsed = checkFileValidity(datasetText);
20 |
21 | if (parsed.length < 20) {
22 | logger.error("[uploadDataset] Dataset has less than 20 rows.");
23 | throw new Error("A dataset should have at least 20 rows.");
24 | }
25 |
26 | // Write dataset to temporary path
27 | const fileName = `${uuid()}.json`;
28 | const localPath = `./tmp/${fileName}`;
29 |
30 | // `recursive: true` succeeds when the directory already exists without swallowing other errors.
31 | await fs.mkdir("./tmp/", {recursive: true});
32 | await fs.writeFile(localPath, datasetText);
33 |
34 | try {
35 | // Upload dataset
36 | await uploadFile(datasetText, fileName).catch((e) => {
37 | logger.error("[uploadDataset] Could not upload file.", {error: e});
38 | throw new Error("Could not upload file.");
39 | });
40 | } finally {
41 | // The upload is given the text directly, so the temporary copy is not needed afterwards;
42 | // remove it so ./tmp doesn't grow with every upload. Cleanup is best-effort and must not
43 | // mask the upload result.
44 | await fs.rm(localPath, {force: true}).catch((e) => {
45 | logger.error("[uploadDataset] Could not remove temporary file.", {error: e});
46 | });
47 | }
48 |
49 | // Create new database entry
50 | const dataset = await createDataset(userId, validatedForm.name, fileName, parsed.length);
51 | return dataset.id;
52 | }
36 |
--------------------------------------------------------------------------------
/main/src/server/controller/new-model.ts:
--------------------------------------------------------------------------------
1 | import {createModel} from "~/server/database/model";
2 |
3 | import {downloadFile} from "~/server/utils/modal";
4 | import {decreaseBalance} from "~/server/database/user";
5 | import {EventName, sendEvent} from "~/server/utils/observability/posthog";
6 | import {checkFileValidity, processDataset} from "./process-dataset";
7 | import {getDatasetById} from "../database/dataset";
8 | import {logger} from "../utils/observability/logtail";
9 | import {trainEndpoint} from "~/constants/modal";
10 |
11 | import type {Session} from "next-auth";
12 | import type {FormDataType} from "~/app/models/new/actions";
13 | import type {Models} from "~/constants/models";
14 |
15 | /**
16 |  * Validate the dataset file and calculate the price.
17 |  * Loads the user's dataset record, downloads and validates its file, then derives the batch
18 |  * settings and price for `baseModel` / `numberOfEpochs`.
19 |  * @returns the values from `processDataset` plus the dataset's `datasetFileName`.
20 |  * @throws Error("Dataset not found.") if `getDatasetById(datasetId, userId)` finds nothing.
21 |  */
22 | async function validate(userId: string, datasetId: string, baseModel: Models, numberOfEpochs: number) {
23 | const datasetDb = await getDatasetById(datasetId, userId);
24 | if (!datasetDb) {
25 | logger.error("Dataset not found.", {datasetId, userId});
26 | throw new Error("Dataset not found.");
27 | }
28 |
29 | // Progress goes through the shared logger (not console.log) so it reaches the same sink as errors.
30 | logger.info(`[validate] dataset exists, now downloading file ${datasetDb.fileName}`);
31 |
32 | const dataset = await downloadFile(datasetDb.fileName);
33 |
34 | logger.info("[validate] file downloaded, now checking validity");
35 |
36 | const processedDataset = checkFileValidity(dataset);
37 |
38 | logger.info("[validate] file valid, now processing");
39 |
40 | // Extract relevant metadata
41 | const metaData = await processDataset(processedDataset, baseModel, numberOfEpochs);
42 |
43 | logger.info("[validate] file processed, now returning");
44 |
45 | return {
46 | ...metaData,
47 | datasetFileName: datasetDb.fileName,
48 | };
49 | }
45 |
46 | /**
47 |  * Calculates the price (in cents) of fine-tuning on the selected dataset and records a
48 |  * "price calculated" analytics event.
49 |  * @param session - authenticated session; only `user.id` is read.
50 |  * @param validatedForm - parsed "new model" form values.
51 |  * @returns the price in cents.
52 |  */
53 | export async function calculatePrice(session: Session, validatedForm: FormDataType) {
54 | const userId = session.user.id;
55 | const {baseModel, datasetId, numberOfEpochs} = validatedForm;
56 |
57 | console.log(`calculating price for ${baseModel}`);
58 | const {priceInCents} = await validate(userId, datasetId, baseModel, numberOfEpochs);
59 |
60 | sendEvent(userId, EventName.FINE_TUNE_PRICE_CALCULATED, {priceInCents, baseModel});
61 | return priceInCents;
62 | }
62 |
63 | /**
64 |  * Create a new model training job: re-validates the dataset and price, charges the user,
65 |  * creates the model record and asks the training service to start the job.
66 |  * @param session - authenticated session (user id and balance are read from it).
67 |  * @param validatedForm - parsed "new model" form values.
68 |  * @param confirmedPrice - price in cents the user confirmed; must equal the recalculated price.
69 |  * @throws Error if the price changed, the balance is too low, or the training service
70 |  *   rejects the request.
71 |  */
72 | export async function createNewModelTraining(session: Session, validatedForm: FormDataType, confirmedPrice?: number) {
73 | const {maxTokens, gradientAccumulationSteps, perDeviceTrainBatchSize, priceInCents, datasetFileName} = await validate(
74 | session.user.id,
75 | validatedForm.datasetId,
76 | validatedForm.baseModel,
77 | validatedForm.numberOfEpochs,
78 | );
79 |
80 | // Check that the price is correct
81 | if (priceInCents !== confirmedPrice) {
82 | throw new Error("Confirmed price does not match calculated price. Please try again.");
83 | }
84 |
85 | // Check that the user has enough money.
86 | // NOTE(review): this check and the decrement below are separate steps, so concurrent
87 | // requests could both pass it; confirm whether `decreaseBalance` guards against overdraft.
88 | if (session.user.centsBalance < priceInCents) {
89 | throw new Error("You do not have enough money to train this model.");
90 | }
91 |
92 | // Update user balance
93 | await decreaseBalance(session.user.id, priceInCents, `Model ${validatedForm.name}`);
94 |
95 | // Update stuff in the database
96 | const model = await createModel(
97 | session.user.id,
98 | validatedForm.name,
99 | priceInCents,
100 | validatedForm.datasetId,
101 | validatedForm.learningRate,
102 | validatedForm.numberOfEpochs,
103 | validatedForm.baseModel,
104 | );
105 |
106 | const pathToSend = `/datasets/${datasetFileName}`;
107 |
108 | // TODO: fix on Modal's side
109 | const model_name =
110 | validatedForm.baseModel === "mistralai/Mixtral-8x7b-Instruct-v0.1"
111 | ? "/pretrained_models/models--mistralai--Mixtral-8x7B-Instruct-v0.1/snapshots/f1ca00645f0b1565c7f9a1c863d2be6ebf896b04"
112 | : validatedForm.baseModel;
113 |
114 | // Send request to create new job
115 | const response = await fetch(trainEndpoint[validatedForm.baseModel], {
116 | method: "POST",
117 | headers: {
118 | "Content-Type": "application/json",
119 | },
120 | body: JSON.stringify({
121 | wandb_token: "", // TODO: let user provide this
122 | learning_rate: validatedForm.learningRate,
123 | num_epochs: validatedForm.numberOfEpochs,
124 | model_name,
125 | model_id: model.id,
126 | dataset_name: pathToSend,
127 | hf_repo: validatedForm.name,
128 | max_tokens: maxTokens,
129 | gradient_accumulation_steps: gradientAccumulationSteps,
130 | per_device_train_batch_size: perDeviceTrainBatchSize,
131 | auth_token: process.env.MODAL_AUTH_TOKEN,
132 | }),
133 | });
134 |
135 | // The user has already been charged and the model record exists, so a rejected request must
136 | // be surfaced instead of silently ignored. (The request body is not logged: it holds the auth token.)
137 | if (!response.ok) {
138 | const responseBody = await response.text().catch(() => "");
139 | logger.error("[createNewModelTraining] Training request failed.", {
140 | modelId: model.id,
141 | status: response.status,
142 | body: responseBody,
143 | });
144 | throw new Error("Could not start the training job.");
145 | }
146 |
147 | sendEvent(session.user.id, EventName.FINE_TUNE_STARTED, {priceInCents, baseModel: validatedForm.baseModel});
148 | }
132 |
--------------------------------------------------------------------------------
/main/src/server/controller/process-dataset.ts:
--------------------------------------------------------------------------------
1 | import * as y from "yup";
2 | import {AutoTokenizer} from "@xenova/transformers";
3 | import {createRepo, deleteRepo} from "@huggingface/hub";
4 | import {logger} from "~/server/utils/observability/logtail";
5 |
6 | import type {PreTrainedTokenizer} from "@xenova/transformers";
7 | import type {Models} from "~/constants/models";
8 |
9 | // One chat message: a role and its text content.
10 | const messageSchema = y
11 | .object({
12 | role: y.string().oneOf(["system", "user", "assistant"]).required(),
13 | content: y.string().required(),
14 | })
15 | .required();
16 |
17 | // One JSONL line of a dataset: `{"messages": [...]}`.
18 | const fileSchema = y.object({
19 | messages: y.array(messageSchema).required(),
20 | });
21 |
22 | /**
23 |  * Checks the conversation structure of one dataset row: at least two messages, the first with
24 |  * role "system", then "user"/"assistant" turns that never repeat a role back-to-back, ending
25 |  * with "assistant".
26 |  * @throws Error with a user-facing message for the first rule that is violated.
27 |  */
28 | function validateMessages(messages: y.InferType["messages"]): void {
29 | const first = messages[0];
30 |
31 | // Enforce what the error message promises: an empty row or a lone system prompt is rejected
32 | // here rather than with a less specific error further down.
33 | if (!first || messages.length < 2) {
34 | logger.info("[validateMessages] Fewer than two messages found", {count: messages.length});
35 | throw new Error("There must be at least two messages.");
36 | }
37 |
38 | if (first.role !== "system") {
39 | logger.info("[validateMessages] First message is not from system");
40 | throw new Error("The first message has to have the role 'system'.");
41 | }
42 |
43 | let lastRole = "system";
44 |
45 | for (const message of messages.slice(1)) {
46 | if (message.role !== "user" && message.role !== "assistant") {
47 | logger.info("[validateMessages] Invalid role", {role: message.role});
48 | throw new Error('Invalid role. Role must be either "user" or "assistant".');
49 | }
50 |
51 | if (message.role === lastRole) {
52 | logger.info("[validateMessages] Consecutive messages must have different roles");
53 | throw new Error("Consecutive messages must have different roles.");
54 | }
55 |
56 | lastRole = message.role;
57 | }
58 |
59 | if (lastRole !== "assistant") {
60 | logger.info("[validateMessages] Last message is not from assistant");
61 | throw new Error('The last message must be from "assistant".');
62 | }
63 | }
54 |
55 | /** Loads `model`'s tokenizer and returns the token count of every row, in order. */
56 | async function getTextLengths(dataset: y.InferType[], model: string) {
57 | const tokenizer = await AutoTokenizer.from_pretrained(model);
58 | return Promise.all(dataset.map((row) => getSampleLength(row, tokenizer)));
59 | }
61 |
62 | /**
63 |  * Token count of one row: all message contents joined with single spaces and tokenized
64 |  * together (roles and any chat template are not included).
65 |  */
66 | async function getSampleLength(
67 | messageObject: y.InferType,
68 | tokenizer: PreTrainedTokenizer,
69 | ): Promise {
70 | const contents = messageObject.messages.map(({content}) => content);
71 | const encoded: {input_ids: {size: number}} = await tokenizer(contents.join(" "));
72 | return encoded.input_ids.size;
73 | }
71 |
72 | /**
73 |  * Price of a fine-tuning job in whole cents.
74 |  * Charges per started block of 1000 tokens per epoch: 0.4 cents for most models, 0.8 cents for
75 |  * Mixtral. The total is rounded up to a whole cent so it can be charged and compared exactly.
76 |  */
77 | function calculatePrice(model: Models, totalTokens: number, epochs: number) {
78 | const baseFeeCents = 0;
79 |
80 | // Rates are kept in tenths of a cent so the product stays in exact integer arithmetic
81 | // (in floating point, 0.4 * 3 === 1.2000000000000002).
82 | const tenthsOfCentPerThousandTokens = model === "mistralai/Mixtral-8x7b-Instruct-v0.1" ? 8 : 4;
83 |
84 | const thousands = Math.ceil(totalTokens / 1000);
85 | return baseFeeCents + Math.ceil((thousands * tenthsOfCentPerThousandTokens * epochs) / 10);
86 | }
83 |
84 | /**
85 |  * Chooses training batch settings from the longest sample (in tokens).
86 |  * Mixtral always uses 16 x 1 with 2000 max tokens. Other models use 2600 max tokens, with
87 |  * smaller per-device batches for longer samples; gradientAccumulationSteps *
88 |  * perDeviceTrainBatchSize is 16 in every tier.
89 |  */
90 | function calculateBatchSize(model: Models, longestTextLength: number) {
91 | if (model === "mistralai/Mixtral-8x7b-Instruct-v0.1") {
92 | return {gradientAccumulationSteps: 16, perDeviceTrainBatchSize: 1, maxTokens: 2000};
93 | }
94 |
95 | // [exclusive upper bound on longestTextLength, gradientAccumulationSteps, perDeviceTrainBatchSize]
96 | const tiers: [number, number, number][] = [
97 | [500, 2, 8],
98 | [900, 4, 4],
99 | [1400, 8, 2],
100 | ];
101 | const tier = tiers.find(([limit]) => longestTextLength < limit);
102 | const [, gradientAccumulationSteps, perDeviceTrainBatchSize] = tier ?? [Infinity, 16, 1];
103 |
104 | return {gradientAccumulationSteps, perDeviceTrainBatchSize, maxTokens: 2600};
105 | }
113 |
114 | /**
115 |  * Makes sure a given JSONL dataset (one JSON object per line) is valid.
116 |  * Blank and whitespace-only lines, including the lone "\r" that CRLF line endings leave
117 |  * after splitting on "\n", are skipped.
118 |  * @param fileContent - raw file text.
119 |  * @returns - The validated and parsed dataset, in file order.
120 |  * @throws the JSON parse or validation error of the first invalid line.
121 |  */
122 | export function checkFileValidity(fileContent: string) {
123 | const validatedDataset: y.InferType[] = [];
124 | for (const rawLine of fileContent.split("\n")) {
125 | // JSON.parse ignores surrounding whitespace anyway; trimming lets whitespace-only lines be skipped.
126 | const line = rawLine.trim();
127 | if (!line) {
128 | continue;
129 | }
130 |
131 | try {
132 | const jsonData = JSON.parse(line);
133 |
134 | const validated = fileSchema.validateSync(jsonData);
135 | validateMessages(validated.messages);
136 |
137 | validatedDataset.push(validated);
138 | } catch (e) {
139 | logger.error("Invalid JSON structure", {line, error: e});
140 | throw e;
141 | }
142 | }
143 |
144 | return validatedDataset;
145 | }
141 |
142 | /**
143 |  * Calculates gradient accumulation steps, batch size and price for a given dataset.
144 |  * @param fileContent - rows already validated by `checkFileValidity`.
145 |  * @param modelName - base model; its tokenizer is used to count tokens.
146 |  * @param epochs - number of training epochs (scales the price).
147 |  * @returns gradientAccumulationSteps, perDeviceTrainBatchSize, maxTokens and priceInCents.
148 |  */
149 | export async function processDataset(fileContent: y.InferType[], modelName: Models, epochs: number) {
150 | const sampleLengths = await getTextLengths(fileContent, modelName);
151 | const totalTokens = sampleLengths.reduce((a, b) => a + b, 0);
152 | const priceInCents = calculatePrice(modelName, totalTokens, epochs);
153 |
154 | // Not `Math.max(...sampleLengths)`: spreading a very large array into call arguments can
155 | // throw a RangeError. Token counts are non-negative, so 0 is a safe seed.
156 | const maxLength = sampleLengths.reduce((longest, length) => Math.max(longest, length), 0);
157 | const {maxTokens, gradientAccumulationSteps, perDeviceTrainBatchSize} = calculateBatchSize(modelName, maxLength);
158 |
159 | return {
160 | gradientAccumulationSteps,
161 | perDeviceTrainBatchSize,
162 | maxTokens,
163 | priceInCents,
164 | };
165 | }
165 |
166 | /**
167 |  * Checks if a huggingface repo name is available by trying to create and delete it.
168 |  * @param hfToken - Hugging Face access token used for both calls.
169 |  * @param name - shape: "username/repo-name"
170 |  * @throws Error("This repo name is already taken.") if the repo cannot be created (note this
171 |  *   also covers other creation failures, e.g. an invalid token), or an Error if the probe
172 |  *   repo was created but could not be deleted again.
173 |  */
174 | export async function checkRepoNameAvailability(hfToken: string, name: string) {
175 | try {
176 | await createRepo({
177 | repo: {
178 | name,
179 | type: "model",
180 | },
181 | credentials: {
182 | accessToken: hfToken,
183 | },
184 | private: true,
185 | });
186 | } catch (e) {
187 | logger.error("Repo name already taken", {name, error: e});
188 | throw new Error("This repo name is already taken.");
189 | }
190 |
191 | // Deletion is checked separately: if it fails, the name *was* available but the probe repo
192 | // now exists, so reporting "already taken" would be misleading.
193 | try {
194 | await deleteRepo({
195 | repo: {
196 | name,
197 | type: "model",
198 | },
199 | credentials: {
200 | accessToken: hfToken,
201 | },
202 | });
203 | } catch (e) {
204 | logger.error("Could not delete temporary repo", {name, error: e});
205 | throw new Error(`Could not remove the temporary repo "${name}" from Hugging Face; please delete it and try again.`);
206 | }
207 | }
197 |
--------------------------------------------------------------------------------
/main/src/server/controller/stripe.ts:
--------------------------------------------------------------------------------
1 | import Stripe from "stripe";
2 |
// Falls back to an empty key so this module can be imported without the variable set;
// requests made with it will presumably be rejected by Stripe.
const stripeSecretKey = process.env.STRIPE_SECRET_KEY ?? "";

/**
 * Payment-method setup flow with Stripe:
 *
 * 1. The client sends the credit card details to Stripe.
 * 2. Stripe hands back a payment method ID.
 * 3. The client passes that payment method ID to the server.
 * 4. The server creates a Setup Intent for the payment method ID.
 * 5. The server returns the Setup Intent's client secret to the client.
 * 6. The client authenticates the payment method with that secret, sometimes via 3-D Secure.
 * 7. The client sends the payment method ID and intent ID back, confirming authentication finished.
 * 8. The server checks that the payment method really was authenticated.
 * 9. The server creates (or reuses) a Stripe customer and attaches the payment method to it.
 */

// Pinned so responses keep the shape this code was written against.
const STRIPE_API_VERSION = "2023-10-16";

// Name/version this integration reports to Stripe.
const STRIPE_APP_INFO = {name: "Haven NextJS", version: "0.0.1"};

const stripe = new Stripe(stripeSecretKey, {
	apiVersion: STRIPE_API_VERSION,
	appInfo: STRIPE_APP_INFO,
});
26 |
/**
 * Step 4 of the payment-method setup flow: create a Setup Intent for the payment
 * method the client received from Stripe.
 *
 * @param paymentMethodId - The Stripe payment method ID sent by the client.
 * @returns The created Setup Intent (its client secret goes back to the client in step 5).
 */
export async function createSetupIntent(paymentMethodId: string) {
	const params = {payment_method: paymentMethodId};
	return stripe.setupIntents.create(params);
}
35 |
/**
 * Step 8 of the payment-method setup flow: verify that the client finished authenticating
 * the Setup Intent.
 *
 * @param setupIntentClientSecret - The Setup Intent's client secret. The intent ID is the
 *   part before "_secret".
 * @param expectedPaymentMethodId - Optional. When given, the intent must also be bound to
 *   this payment method. Without it, a succeeded intent for one card would "validate" a
 *   different, unauthenticated card.
 * @returns `true` if the intent succeeded and, when requested, belongs to the expected
 *   payment method; otherwise `false`.
 * @throws Error("Invalid setup intent secret") if no intent ID can be extracted.
 */
export async function validateSetupIntent(setupIntentClientSecret: string, expectedPaymentMethodId?: string) {
	const setupIntentId = setupIntentClientSecret.split("_secret")[0];
	if (!setupIntentId) {
		throw new Error("Invalid setup intent secret");
	}

	const setupIntent = await stripe.setupIntents.retrieve(setupIntentId);
	if (setupIntent.status !== "succeeded") {
		return false;
	}

	if (expectedPaymentMethodId === undefined) {
		return true;
	}

	// `payment_method` is an ID string unless expanded, in which case it is the full object.
	const attached = setupIntent.payment_method;
	const attachedId = typeof attached === "string" ? attached : attached?.id;
	return attachedId === expectedPaymentMethodId;
}
48 |
/**
 * Fetches an existing Stripe customer, refusing ones that have been deleted on Stripe's side.
 */
async function retrieveActiveCustomer(customerId: string) {
	const existing = await stripe.customers.retrieve(customerId);

	// TODO: We should automatically recover from this at some point.
	if (!existing || existing.deleted) {
		throw new Error("Customer used to exist but was deleted.");
	}

	return existing;
}

/**
 * Step 9 of the payment-method setup flow: make sure a Stripe customer exists and attach
 * the payment method to it.
 *
 * @param name - Name for a newly created customer (unused when `customerId` is given).
 * @param email - Email for a newly created customer (unused when `customerId` is given).
 * @param paymentMethodId - The payment method to attach.
 * @param customerId - An existing Stripe customer to reuse instead of creating one.
 * @returns The (existing or new) customer.
 * @throws Error if `customerId` refers to a deleted customer.
 */
export async function createStripeCustomerAndAttachPaymentMethod(
	name: string,
	email: string,
	paymentMethodId: string,
	customerId?: string,
) {
	const customer = customerId
		? await retrieveActiveCustomer(customerId)
		: await stripe.customers.create({name, email});

	await stripe.paymentMethods.attach(paymentMethodId, {customer: customer.id});

	return customer;
}
80 |
/**
 * Returns display information for a customer's card.
 *
 * @param customerId - The Stripe customer ID.
 * @returns `{last4, expiry}` of the first card Stripe lists for the customer; `expiry` is
 *   formatted as "<exp_month>/<exp_year>".
 * @throws Error("Stripe customer not found") if the customer is missing or deleted.
 * @throws Error("No payment methods found") if the customer has no card.
 */
export async function getStripeCreditCardInfo(customerId: string) {
	const customer = await stripe.customers.retrieve(customerId);
	if (!customer || customer.deleted) {
		throw new Error("Stripe customer not found");
	}

	// The listing is already filtered to cards, and only the first one is shown,
	// so a single result is enough.
	const paymentMethods = await stripe.paymentMethods.list({
		customer: customerId,
		type: "card",
		limit: 1,
	});

	// Explicit check instead of a non-null assertion: a missing card becomes a clear error, not a TypeError.
	const card = paymentMethods.data[0]?.card;
	if (!card) {
		throw new Error("No payment methods found");
	}

	return {
		last4: card.last4,
		expiry: `${card.exp_month}/${card.exp_year}`,
	};
}
105 |
/**
 * Charges a saved payment method immediately, without the customer present (off-session),
 * confirming the Payment Intent in the same call.
 *
 * @param customerId - The Stripe customer that owns the payment method.
 * @param paymentMethodId - The payment method to charge.
 * @param amount - Amount in the smallest USD unit (cents).
 * @param description - Description stored on the Payment Intent.
 */
export function createPaymentIntent(customerId: string, paymentMethodId: string, amount: number, description: string) {
	const params: Stripe.PaymentIntentCreateParams = {
		customer: customerId,
		payment_method: paymentMethodId,
		amount,
		currency: "usd",
		description,
		off_session: true,
		confirm: true,
	};

	return stripe.paymentIntents.create(params);
}
117 |
--------------------------------------------------------------------------------
/main/src/server/database/chat-request.ts:
--------------------------------------------------------------------------------
1 | import {db} from ".";
2 |
/**
 * Counts the chat requests a user created within the rolling 24-hour window ending now.
 *
 * @param userId - The user whose requests are counted.
 */
export async function getNumberOfChatRequestsInLast24Hours(userId: string) {
	const windowStart = new Date(Date.now() - 24 * 60 * 60 * 1000);

	return db.chatRequest.count({
		where: {userId, createdAt: {gte: windowStart}},
	});
}
19 |
/**
 * Records one chat request by `userId` against `modelId`.
 *
 * @returns The created chat request row.
 */
export async function createChatRequest(userId: string, modelId: string) {
	return db.chatRequest.create({data: {userId, modelId}});
}
31 |
--------------------------------------------------------------------------------
/main/src/server/database/dataset.ts:
--------------------------------------------------------------------------------
1 | import {db} from ".";
2 |
/**
 * Stores a dataset record for a user.
 *
 * @param userId - Owner of the dataset.
 * @param name - Display name.
 * @param fileName - Name under which the dataset file is stored.
 * @param rows - Row count, as supplied by the caller.
 */
export async function createDataset(userId: string, name: string, fileName: string, rows: number) {
	return db.dataset.create({data: {userId, name, fileName, rows}});
}
13 |
/**
 * Lists all datasets owned by a user, newest first.
 */
export async function getDatasets(userId: string) {
	return db.dataset.findMany({
		where: {userId},
		orderBy: {createdAt: "desc"},
	});
}
24 |
/**
 * Returns up to five of the user's datasets whose name contains `name`, newest first.
 *
 * Note: this might become slow at some point. A `contains` filter presumably cannot use a
 * regular index.
 */
export async function getDatasetsByNameForUser(userId: string, name: string) {
	return db.dataset.findMany({
		where: {userId, name: {contains: name}},
		orderBy: {createdAt: "desc"},
		take: 5,
	});
}
40 |
/**
 * Fetches a dataset by ID, but only if it belongs to `userId`; otherwise resolves to `null`.
 */
export async function getDatasetById(id: string, userId: string) {
	return db.dataset.findUnique({where: {id, userId}});
}
49 |
--------------------------------------------------------------------------------
/main/src/server/database/index.ts:
--------------------------------------------------------------------------------
1 | import { PrismaClient } from "@prisma/client";
2 |
3 | import { env } from "~/env.mjs";
4 |
// Typed view of globalThis used to stash the Prisma client between module reloads.
const globalForPrisma = globalThis as unknown as {
	prisma: PrismaClient | undefined;
};

/** Builds a Prisma client that logs queries in development and only errors otherwise. */
function createPrismaClient() {
	return new PrismaClient({
		log: env.NODE_ENV === "development" ? ["query", "error", "warn"] : ["error"],
	});
}

export const db = globalForPrisma.prisma ?? createPrismaClient();

// Outside production, keep the instance on globalThis so re-evaluating this module reuses it.
if (env.NODE_ENV !== "production") {
	globalForPrisma.prisma = db;
}
17 |
--------------------------------------------------------------------------------
/main/src/server/database/model.ts:
--------------------------------------------------------------------------------
1 | import {db} from ".";
2 |
/**
 * Lists every model owned by the given user.
 */
export async function getModels(userId: string) {
	return db.model.findMany({where: {userId}});
}
10 |
/**
 * Fetches a model by its ID, or `null` if it does not exist.
 *
 * NOTE(review): not scoped to a user. Callers must check ownership themselves.
 */
export async function getModelFromId(id: string) {
	return db.model.findUnique({where: {id}});
}
18 |
/**
 * Creates a model record for a fine-tuning job.
 *
 * @param userId - Owner of the model.
 * @param name - Model name.
 * @param costInCents - Job cost in cents.
 * @param datasetId - Dataset the model is trained on.
 * @param learningRate - Learning rate, stored as a string.
 * @param epochs - Number of training epochs.
 * @param baseModel - Base model identifier.
 */
export async function createModel(
	userId: string,
	name: string,
	costInCents: number,
	datasetId: string,
	learningRate: string,
	epochs: number,
	baseModel: string,
) {
	return db.model.create({
		data: {userId, name, costInCents, datasetId, learningRate, epochs, baseModel},
	});
}
40 |
/**
 * Stores the Weights & Biases run URL on a model.
 */
export async function addWandBUrl(modelId: string, wandbUrl: string) {
	return db.model.update({where: {id: modelId}, data: {wandbUrl}});
}
51 |
/**
 * Sets a model's training state.
 */
export async function updateState(modelId: string, state: "training" | "finished" | "error") {
	return db.model.update({where: {id: modelId}, data: {state}});
}
62 |
/**
 * Returns the stored job cost (in cents) of a model, or `undefined` if the model does not exist.
 */
export async function getJobCostInCents(modelId: string) {
	const model = await db.model.findUnique({
		where: {id: modelId},
		select: {costInCents: true},
	});

	return model?.costInCents;
}
75 |
--------------------------------------------------------------------------------
/main/src/server/database/user.ts:
--------------------------------------------------------------------------------
1 | import {createHash} from "crypto";
2 | import {db} from ".";
3 |
4 | import {v4 as uuid} from "uuid";
5 |
/**
 * Returns the SHA-256 digest of `value` as a lowercase hex string.
 */
export function hash(value: string) {
	const hasher = createHash("sha256");
	hasher.update(value);
	return hasher.digest("hex");
}
9 |
/**
 * Fetches a user by ID, or `null` if none exists.
 */
export async function getUserFromId(id: string) {
	return db.user.findUnique({where: {id}});
}
21 | export async function addApiKeyToUser(userId: string) {
22 | const id = uuid();
23 |
24 | await db.user.update({
25 | where: {
26 | id: userId,
27 | },
28 | data: {
29 | apiKey: id,
30 | },
31 | });
32 |
33 | return id;
34 | }
35 |
/**
 * Stores the user's Hugging Face access token.
 */
export async function updateHfToken(userId: string, hfToken: string) {
	await db.user.update({where: {id: userId}, data: {hfToken}});
}
46 |
/**
 * Placeholder for changing a user's email; always throws.
 *
 * Warning for future Konsti: we might need to update the email on the Stripe customer too.
 */
export function updateEmail(_: string, __: string) {
	const notImplemented = new Error("Not implemented");
	throw notImplemented;
}
51 |
/**
 * Changes the user's display name.
 */
export async function updateName(userId: string, name: string) {
	await db.user.update({where: {id: userId}, data: {name}});
}
62 |
/**
 * Saves the Stripe customer ID and payment method ID on a user.
 *
 * @returns The updated user.
 */
export async function addStripeCustomerIdAndPaymentMethodToUser(
	id: string,
	stripeCustomerId: string,
	paymentMethodId: string,
) {
	return db.user.update({
		where: {id},
		data: {stripeCustomerId, stripePaymentMethodId: paymentMethodId},
	});
}
81 |
/**
 * Credits `cents` to the user's balance and records a matching transaction entry.
 *
 * Both writes run in one database transaction, so the transaction log and `centsBalance`
 * cannot drift apart if either write fails.
 *
 * @param id - The user's ID.
 * @param cents - Amount to add, in cents.
 * @param message - Transaction reason; defaults to "Account top-up" when omitted or empty.
 * @returns The updated user.
 */
export async function increaseBalance(id: string, cents: number, message?: string) {
	// TODO: move to controller
	const [, user] = await db.$transaction([
		db.transaction.create({
			data: {
				userId: id,
				amount: cents,
				reason: message || "Account top-up",
			},
		}),
		db.user.update({
			where: {id},
			data: {centsBalance: {increment: cents}},
		}),
	]);

	return user;
}
106 |
/**
 * Debits `cents` from the user's balance and records a matching (negative) transaction entry.
 *
 * Both writes run in one database transaction, so the transaction log and `centsBalance`
 * cannot drift apart if either write fails. No lower bound is enforced on the balance.
 *
 * @param id - The user's ID.
 * @param cents - Amount to subtract, in cents (stored as `-cents` in the transaction log).
 * @param reason - Transaction reason.
 * @returns The updated user.
 */
export async function decreaseBalance(id: string, cents: number, reason: string) {
	// TODO: move to controller
	const [, user] = await db.$transaction([
		db.transaction.create({
			data: {
				userId: id,
				amount: -cents,
				reason,
			},
		}),
		db.user.update({
			where: {id},
			data: {centsBalance: {decrement: cents}},
		}),
	]);

	return user;
}
131 |
--------------------------------------------------------------------------------
/main/src/server/utils/mail.ts:
--------------------------------------------------------------------------------
1 | import {Resend} from "resend";
2 |
const resend = new Resend(process.env.RESEND_KEY);

/**
 * Sends an HTML email from the Haven sender address via Resend.
 *
 * @returns Whatever `resend.emails.send` resolves to.
 */
async function sendMail(to: string, subject: string, html: string) {
	return resend.emails.send({from: "do-reply@haven.run", to, subject, html});
}
13 |
/** Escapes characters that could break out of, or be misparsed inside, a double-quoted HTML attribute. */
function escapeHtmlAttribute(value: string) {
	return value.replace(/&/g, "&amp;").replace(/"/g, "&quot;").replace(/</g, "&lt;").replace(/>/g, "&gt;");
}

/**
 * Sends the "Log in to Haven" email containing the sign-in link.
 *
 * @param email - Recipient address.
 * @param url - Sign-in link. It is HTML-escaped because it is interpolated into an
 *   attribute of the email's HTML body.
 */
export async function sendVerifyEmail(email: string, url: string) {
	const href = escapeHtmlAttribute(url);
	const message = `Someone is trying to log into Haven using your email address. If it's you, click <a href="${href}">here</a>. If not, you can safely ignore this email or respond to let us know.`;
	await sendMail(email, "Log in to Haven", message);
}
18 |
--------------------------------------------------------------------------------
/main/src/server/utils/modal.ts:
--------------------------------------------------------------------------------
// Base URL of the Modal proxy service; empty when MODAL_HOSTNAME is unset.
const HOSTNAME = process.env.MODAL_HOSTNAME || "";

/**
 * Uploads text content to the Modal proxy under `filename`.
 *
 * @param file - The file's text content.
 * @param filename - Name to store it under.
 * @returns The response body as text (the proxy in modal-proxy/run.py echoes the filename).
 * @throws Error("Could not upload file.") on a non-2xx response.
 */
export async function uploadFile(file: string, filename: string) {
	const response = await fetch(`${HOSTNAME}/upload`, {
		method: "POST",
		body: JSON.stringify({file, filename}),
		headers: {"Content-Type": "application/json"},
	});

	if (response.ok) {
		return response.text();
	}

	throw new Error("Could not upload file.");
}
23 |
/**
 * Downloads a stored file's text content from the Modal proxy.
 *
 * Note the request key is `fileName` (camel case), unlike `uploadFile`'s `filename`; it must
 * match what the proxy's /download route reads.
 *
 * @throws Error("Could not download file.") on a non-2xx response.
 */
export async function downloadFile(fileName: string) {
	const response = await fetch(`${HOSTNAME}/download`, {
		method: "POST",
		body: JSON.stringify({fileName}),
		headers: {"Content-Type": "application/json"},
	});

	if (response.ok) {
		return response.text();
	}

	throw new Error("Could not download file.");
}
43 |
--------------------------------------------------------------------------------
/main/src/server/utils/observability/logtail.ts:
--------------------------------------------------------------------------------
1 | import {Logtail} from "@logtail/node";
2 | import {getServerSession} from "next-auth";
3 | import {authOptions} from "../../auth";
4 |
export const logtail = new Logtail(process.env.LOGTAIL_KEY || "");

/**
 * Runs `fn` only when LOGTAIL_ENV is "production", so logs are not sent to Logtail in development.
 */
function ifProd(fn: () => void) {
	if (process.env.LOGTAIL_ENV !== "production") {
		return;
	}

	fn();
}
13 |
/**
 * Structured context attached to a log message. If `userId` is falsy, the logger fills it
 * in from the current session.
 */
type Data = {
	userId?: string;
	[key: string]: any;
};
18 |
/**
 * Builds a logger for one level. Each call writes to the console and, when LOGTAIL_ENV is
 * "production", sends the same message and payload to Logtail.
 *
 * If `data.userId` is not set, the current session's user ID is looked up and attached.
 * If that lookup rejects, the message is still logged, just without a user ID. Previously
 * such a rejection dropped the log line and surfaced as an unhandled promise rejection.
 */
function createLogger(level: (typeof levels)[number]) {
	const emit = (message: string, data: Data, sessionUserId: string | undefined) => {
		const payload = {
			...data,
			userId: data.userId ? data.userId : sessionUserId,
		};

		console[level](message, payload);
		ifProd(() => {
			// Delivery to Logtail is best-effort; report failures instead of leaving a rejection unhandled.
			logtail[level](message, payload).catch((error: unknown) => {
				console.error("Failed to send log to Logtail", error);
			});
		});
	};

	return (message: string, data: Data = {}) => {
		if (data.userId) {
			// The caller already identified the user; no session lookup needed.
			emit(message, data, undefined);
			return;
		}

		getServerSession(authOptions).then(
			(session) => emit(message, data, session?.user?.id),
			() => emit(message, data, undefined),
		);
	};
}
32 |
// Supported log levels; each one is also a method name on both `console` and `logtail`.
const levels = ["error", "warn", "info", "debug"] as const;

/**
 * Level-specific loggers. Each writes to the console and, in production, to Logtail.
 */
export const logger = {
	error: createLogger("error"),
	warn: createLogger("warn"),
	info: createLogger("info"),
	debug: createLogger("debug"),
};
41 |
--------------------------------------------------------------------------------
/main/src/server/utils/observability/posthog.ts:
--------------------------------------------------------------------------------
1 | import {PostHog} from "posthog-node";
2 |
// PostHog project key for the EU-hosted instance.
// NOTE(review): the client is never flushed or shut down in this file; confirm that buffered
// events are not lost when the process exits.
const POSTHOG_PROJECT_KEY = "phc_YpKoFD7smPe4SXRtVyMW766uP9AjUwnuRJ8hh2EJcVv";
const client = new PostHog(POSTHOG_PROJECT_KEY, {host: "https://eu.posthog.com"});

/**
 * Analytics event names passed to `sendEvent`. The string values are what PostHog receives,
 * so renaming one would presumably split that event's history.
 */
export enum EventName {
	// Accounts
	NEW_USER = "new-user",

	// Billing
	CREDIT_CARD_ADDED = "credit-card-added",
	MONEY_ADDED = "money-added",

	// Fine-tuning lifecycle
	FINE_TUNE_PRICE_CALCULATED = "fine-tune-price-calculated",
	FINE_TUNE_STARTED = "fine-tune-started",
	FINE_TUNE_FINISHED = "fine-tune-finished",
	FINE_TUNE_FAILED = "fine-tune-failed",
	FINE_TUNE_WANDB_ADDED = "fine-tune-wandb-added",

	// Inference
	INFERENCE_REQUEST = "inference-request",

	// Alerts
	EMERGENCY = "emergency",
}
21 |
/**
 * Captures an analytics event for a user. This is a no-op unless LOGTAIL_ENV is "production",
 * and capture errors are written to the console rather than thrown.
 */
export function sendEvent(userId: string, eventName: EventName, eventProperties: object = {}) {
	const isProduction = process.env.LOGTAIL_ENV === "production";

	if (isProduction) {
		try {
			client.capture({distinctId: userId, event: eventName, properties: eventProperties});
		} catch (error) {
			console.error(error);
		}
	}
}
37 |
--------------------------------------------------------------------------------
/main/src/server/utils/session.ts:
--------------------------------------------------------------------------------
1 | import {authOptions} from "../auth";
2 | import {getServerSession} from "next-auth";
3 | import {redirect} from "next/navigation";
4 |
/**
 * Returns the current session, redirecting to "/" when the visitor is not signed in.
 * `redirect` is typed `never`, so this function only returns a real session.
 */
export async function checkSession() {
	const session = await getServerSession(authOptions);

	if (session) {
		return session;
	}

	redirect("/");
}
16 |
/**
 * Returns the current session, throwing Error("Unauthorized") when the caller is not signed in.
 */
export async function checkSessionAction() {
	const session = await getServerSession(authOptions);

	if (session) {
		return session;
	}

	throw new Error("Unauthorized");
}
27 |
--------------------------------------------------------------------------------
/main/src/styles/globals.css:
--------------------------------------------------------------------------------
1 | @tailwind base;
2 | @tailwind components;
3 | @tailwind utilities;
4 |
@layer base {
	/* Light theme (default). Each token is an HSL triple, consumed in tailwind.config.ts as hsl(var(--token)). */
	:root {
		--background: 0 0% 100%;
		--foreground: 0 0% 3.9%;
		--card: 0 0% 100%;
		--card-foreground: 0 0% 3.9%;
		--popover: 0 0% 100%;
		--popover-foreground: 0 0% 3.9%;
		--primary: 0 0% 9%;
		--primary-foreground: 0 0% 98%;
		--secondary: 0 0% 96.1%;
		--secondary-foreground: 0 0% 9%;
		--muted: 0 0% 96.1%;
		--muted-foreground: 0 0% 45.1%;
		--accent: 0 0% 96.1%;
		--accent-foreground: 0 0% 9%;
		--destructive: 0 84.2% 60.2%;
		--destructive-foreground: 0 0% 98%;
		--border: 0 0% 89.8%;
		--input: 0 0% 89.8%;
		--ring: 0 0% 3.9%;
		--radius: 0.5rem;
	}

	/* Dark theme, enabled by the `dark` class (tailwind darkMode: "class"). */
	.dark {
		--background: 0 0% 3.9%;
		--foreground: 0 0% 98%;
		--card: 0 0% 3.9%;
		--card-foreground: 0 0% 98%;
		--popover: 0 0% 3.9%;
		--popover-foreground: 0 0% 98%;
		--primary: 0 0% 98%;
		--primary-foreground: 0 0% 9%;
		--secondary: 0 0% 14.9%;
		--secondary-foreground: 0 0% 98%;
		--muted: 0 0% 14.9%;
		--muted-foreground: 0 0% 63.9%;
		--accent: 0 0% 14.9%;
		--accent-foreground: 0 0% 98%;
		--destructive: 0 62.8% 30.6%;
		--destructive-foreground: 0 0% 98%;
		--border: 0 0% 14.9%;
		--input: 0 0% 14.9%;
		--ring: 0 0% 83.1%;
	}

	/* Default border color for every element. */
	* {
		@apply border-border;
	}

	/* Page background and text colors follow the active theme. */
	body {
		@apply bg-background text-foreground;
	}
}
--------------------------------------------------------------------------------
/main/tailwind.config.ts:
--------------------------------------------------------------------------------
/**
 * Tailwind configuration.
 *
 * Fix: `borderRadius`, `keyframes` and `animation` used to be nested inside `extend.colors`,
 * so they were treated as colors and the `--radius`-based radii and accordion animations
 * were never generated. `transparent`/`current` sat at the `theme` root, where they have no
 * effect. All of them now live under `theme.extend`.
 */
export default {
	darkMode: "class",
	content: ["./index.html", "./src/**/*.{js,ts,jsx,tsx}"],
	theme: {
		extend: {
			colors: {
				transparent: "transparent",
				current: "currentColor",
				gray: {
					50: "#FAFAFA",
					100: "#F3F3F3",
					200: "#E5E5E5",
					300: "#C7C7C7",
					400: "#ACACAC",
					500: "#808080",
					600: "#5F5F5F",
					700: "#4E4E4E",
					800: "#373737",
					900: "#2A2A2A",
					950: "#151515",
				},
				// SHADCN related config starts here
				border: "hsl(var(--border))",
				input: "hsl(var(--input))",
				ring: "hsl(var(--ring))",
				background: "hsl(var(--background))",
				foreground: "hsl(var(--foreground))",
				primary: {
					DEFAULT: "hsl(var(--primary))",
					foreground: "hsl(var(--primary-foreground))",
				},
				secondary: {
					DEFAULT: "hsl(var(--secondary))",
					foreground: "hsl(var(--secondary-foreground))",
				},
				destructive: {
					DEFAULT: "hsl(var(--destructive))",
					foreground: "hsl(var(--destructive-foreground))",
				},
				muted: {
					DEFAULT: "hsl(var(--muted))",
					foreground: "hsl(var(--muted-foreground))",
				},
				accent: {
					DEFAULT: "hsl(var(--accent))",
					foreground: "hsl(var(--accent-foreground))",
				},
				popover: {
					DEFAULT: "hsl(var(--popover))",
					foreground: "hsl(var(--popover-foreground))",
				},
				card: {
					DEFAULT: "hsl(var(--card))",
					foreground: "hsl(var(--card-foreground))",
				},
			},
			borderRadius: {
				lg: "var(--radius)",
				md: "calc(var(--radius) - 2px)",
				sm: "calc(var(--radius) - 4px)",
			},
			keyframes: {
				"accordion-down": {
					from: {height: 0},
					to: {height: "var(--radix-accordion-content-height)"},
				},
				"accordion-up": {
					from: {height: "var(--radix-accordion-content-height)"},
					to: {height: 0},
				},
			},
			animation: {
				"accordion-down": "accordion-down 0.2s ease-out",
				"accordion-up": "accordion-up 0.2s ease-out",
			},
		},
		container: {
			center: true,
			padding: "2rem",
			screens: {
				"2xl": "1400px",
			},
		},
		// SHADCN related config ends here
	},
	safelist: [
		{
			pattern:
				/^(bg-(?:slate|gray|zinc|neutral|stone|red|orange|amber|yellow|lime|green|emerald|teal|cyan|sky|blue|indigo|violet|purple|fuchsia|pink|rose)-(?:50|100|200|300|400|500|600|700|800|900|950))$/,
			variants: ["hover", "ui-selected"],
		},
		{
			pattern:
				/^(text-(?:slate|gray|zinc|neutral|stone|red|orange|amber|yellow|lime|green|emerald|teal|cyan|sky|blue|indigo|violet|purple|fuchsia|pink|rose)-(?:50|100|200|300|400|500|600|700|800|900|950))$/,
			variants: ["hover", "ui-selected"],
		},
		{
			pattern:
				/^(border-(?:slate|gray|zinc|neutral|stone|red|orange|amber|yellow|lime|green|emerald|teal|cyan|sky|blue|indigo|violet|purple|fuchsia|pink|rose)-(?:50|100|200|300|400|500|600|700|800|900|950))$/,
			variants: ["hover", "ui-selected"],
		},
		{
			pattern:
				/^(ring-(?:slate|gray|zinc|neutral|stone|red|orange|amber|yellow|lime|green|emerald|teal|cyan|sky|blue|indigo|violet|purple|fuchsia|pink|rose)-(?:50|100|200|300|400|500|600|700|800|900|950))$/,
		},
		{
			pattern:
				/^(stroke-(?:slate|gray|zinc|neutral|stone|red|orange|amber|yellow|lime|green|emerald|teal|cyan|sky|blue|indigo|violet|purple|fuchsia|pink|rose)-(?:50|100|200|300|400|500|600|700|800|900|950))$/,
		},
		{
			pattern:
				/^(fill-(?:slate|gray|zinc|neutral|stone|red|orange|amber|yellow|lime|green|emerald|teal|cyan|sky|blue|indigo|violet|purple|fuchsia|pink|rose)-(?:50|100|200|300|400|500|600|700|800|900|950))$/,
		},
	],
	plugins: [require("@headlessui/tailwindcss"), require("@tailwindcss/forms"), require("tailwindcss-animate")], // tailwindcss-animate is from SHADCN
};
117 |
--------------------------------------------------------------------------------
/main/tsconfig.json:
--------------------------------------------------------------------------------
1 | {
2 | "compilerOptions": {
3 | "target": "es2017",
4 | "lib": ["dom", "dom.iterable", "esnext"],
5 | "allowJs": true,
6 | "checkJs": true,
7 | "skipLibCheck": true,
8 | "strict": true,
9 | "forceConsistentCasingInFileNames": true,
10 | "noEmit": true,
11 | "esModuleInterop": true,
12 | "module": "esnext",
13 | "moduleResolution": "node",
14 | "resolveJsonModule": true,
15 | "isolatedModules": true,
16 | "jsx": "preserve",
17 | "incremental": true,
18 | "noUncheckedIndexedAccess": true,
19 | "baseUrl": ".",
20 | "paths": {
21 | "~/*": ["./src/*"]
22 | },
23 | "plugins": [{ "name": "next" }]
24 | },
25 | "include": [
26 | ".eslintrc.cjs",
27 | "next-env.d.ts",
28 | "**/*.ts",
29 | "**/*.tsx",
30 | "**/*.cjs",
31 | "**/*.mjs",
32 | ".next/types/**/*.ts"
33 | ],
34 | "exclude": ["node_modules"]
35 | }
36 |
--------------------------------------------------------------------------------
/modal-proxy/gunicorn_config.py:
--------------------------------------------------------------------------------
# Gunicorn settings for the modal-proxy Flask app.
bind = "0.0.0.0:8000"  # listen on all interfaces, port 8000
workers = 4  # number of worker processes
--------------------------------------------------------------------------------
/modal-proxy/requirements.txt:
--------------------------------------------------------------------------------
1 | modal==0.55.4091
2 | flask==3.0.0
3 | gunicorn==21.2.0
--------------------------------------------------------------------------------
/modal-proxy/run.py:
--------------------------------------------------------------------------------
import os
import tempfile

from flask import Flask, request
from modal.cli.volume import put, get
5 |
app = Flask(__name__)

# Modal volume holding the uploaded datasets.
DATASETS_VOLUME = 'datasets'


def _is_safe_filename(filename: object) -> bool:
    """Return True if *filename* is a plain, non-empty file name with no path components.

    The name comes straight from the request body and is joined onto a local
    directory, so anything with a separator, "." / "..", or a NUL byte is rejected.
    """
    return (
        isinstance(filename, str)
        and filename not in ('', '.', '..')
        and '\x00' not in filename
        and os.path.basename(filename) == filename
    )


@app.route('/upload', methods=['POST'])
def upload():
    """Store posted text content in the datasets volume.

    Expects a JSON body ``{"file": <text content>, "filename": <name>}`` and
    returns the file name; responds 400 for an unsafe file name.
    """
    payload = request.get_json()
    content: str = payload['file']
    filename = payload['filename']
    if not _is_safe_filename(filename):
        return 'Invalid filename', 400

    # Stage the file in a private temporary directory. Concurrent requests, and files next to
    # this script, can then never be overwritten or deleted, and cleanup happens even if
    # `put` raises.
    with tempfile.TemporaryDirectory() as tmp_dir:
        local_path = os.path.join(tmp_dir, filename)
        with open(local_path, 'w', encoding='utf-8') as f:
            f.write(content)
        # NOTE(review): relies on `put` storing the file at "/" + basename(local_path) when the
        # remote path ends with "/", as it did for the former cwd-relative path; confirm
        # against the pinned modal version.
        put(DATASETS_VOLUME, local_path, '/', env=None)

    return filename


@app.route('/download', methods=['POST'])
def download():
    """Return the text content of a file from the datasets volume.

    Expects a JSON body ``{"fileName": <name>}``; responds 400 for an unsafe file name.
    """
    payload = request.get_json()
    filename = payload['fileName']
    if not _is_safe_filename(filename):
        return 'Invalid filename', 400

    app.logger.info('download: %s', filename)

    # The temporary directory is removed on exit, even if `get` or the read fails.
    with tempfile.TemporaryDirectory() as tmp_dir:
        get(DATASETS_VOLUME, filename, tmp_dir, env=None)
        with open(os.path.join(tmp_dir, filename), 'r', encoding='utf-8') as f:
            return f.read()


if __name__ == '__main__':
    # MODAL_CONFIG_PATH=./.modal.toml python run.py
    app.run(debug=False)
--------------------------------------------------------------------------------
/prettier.config.js:
--------------------------------------------------------------------------------
1 | module.exports = {
2 | useTabs: true,
3 | semi: true,
4 | singleQuote: false,
5 | quoteProps: "as-needed",
6 | trailingComma: "all",
7 | bracketSpacing: false,
8 | arrowParantheses: "always",
9 | printWidth: 120,
10 | };
11 |
--------------------------------------------------------------------------------