├── comparer ├── src │ ├── react-app-env.d.ts │ ├── App.tsx │ ├── setupTests.ts │ ├── App.test.tsx │ ├── FileComparer.css │ ├── Comparer.css │ ├── index.css │ ├── reportWebVitals.ts │ ├── index.tsx │ ├── components │ │ ├── Prompt.css │ │ └── Prompt.tsx │ ├── FileLoaderTextArea.css │ ├── Comparer.tsx │ ├── useHooks.ts │ ├── logo.svg │ ├── FileComparer.tsx │ └── FileLoaderTextArea.tsx ├── public │ ├── robots.txt │ ├── favicon.ico │ ├── EG_fevicon.png │ ├── logo_expgrad.png │ ├── manifest.json │ └── index.html ├── README.md ├── .gitignore ├── tsconfig.json └── package.json ├── funtuner ├── custom_datasets │ ├── __init__.py │ ├── utils.py │ └── sftdataset.py ├── config │ ├── datasets.yaml │ ├── trainer │ │ └── default.yaml │ ├── config.yaml │ ├── zero2.json │ ├── zero3.json │ └── templates.json ├── sampling.py ├── utils.py ├── inference.py └── trainer.py ├── evals ├── config │ └── generation.yaml ├── sampler.py └── results │ └── results-open-llama-3B-orcastyle.json ├── run_gpu.sh ├── infer.sh ├── hf_upload.py ├── README.md ├── .gitignore └── LICENSE /comparer/src/react-app-env.d.ts: -------------------------------------------------------------------------------- 1 | /// 2 | -------------------------------------------------------------------------------- /comparer/public/robots.txt: -------------------------------------------------------------------------------- 1 | # https://www.robotstxt.org/robotstxt.html 2 | User-agent: * 3 | Disallow: 4 | -------------------------------------------------------------------------------- /comparer/public/favicon.ico: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vibrantlabsai/Funtuner/HEAD/comparer/public/favicon.ico -------------------------------------------------------------------------------- /comparer/public/EG_fevicon.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vibrantlabsai/Funtuner/HEAD/comparer/public/EG_fevicon.png -------------------------------------------------------------------------------- /comparer/public/logo_expgrad.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vibrantlabsai/Funtuner/HEAD/comparer/public/logo_expgrad.png -------------------------------------------------------------------------------- /funtuner/custom_datasets/__init__.py: -------------------------------------------------------------------------------- 1 | from funtuner.custom_datasets.sftdataset import FunDataset, FunDataCollator 2 | from funtuner.custom_datasets.utils import get_datasets 3 | -------------------------------------------------------------------------------- /evals/config/generation.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | temperature: 0.0 3 | max_new_tokens: 128 4 | do_sample: true 5 | 6 | k50: 7 | top_k: 30 8 | top_p: 0.95 9 | repetition_penalty: 1.2 10 | do_sample: false -------------------------------------------------------------------------------- /comparer/src/App.tsx: -------------------------------------------------------------------------------- 1 | import React from 'react'; 2 | import { Comparer } from './Comparer'; 3 | 4 | function App() { 5 | return ( 6 |
7 | <Comparer /> 8 | </div>
9 | ); 10 | } 11 | 12 | export default App; 13 | -------------------------------------------------------------------------------- /comparer/src/setupTests.ts: -------------------------------------------------------------------------------- 1 | // jest-dom adds custom jest matchers for asserting on DOM nodes. 2 | // allows you to do things like: 3 | // expect(element).toHaveTextContent(/react/i) 4 | // learn more: https://github.com/testing-library/jest-dom 5 | import '@testing-library/jest-dom'; 6 | -------------------------------------------------------------------------------- /funtuner/config/datasets.yaml: -------------------------------------------------------------------------------- 1 | databricks/databricks-dolly-15k: 2 | prompt: instruction 3 | context: context 4 | response: response 5 | 6 | Dahoas/cot_gsm8k: 7 | prompt: question 8 | response: answer 9 | 10 | psmathur/WizardLM_Orca: 11 | prompt: instruction 12 | response: output -------------------------------------------------------------------------------- /comparer/src/App.test.tsx: -------------------------------------------------------------------------------- 1 | import React from 'react'; 2 | import { render, screen } from '@testing-library/react'; 3 | import App from './App'; 4 | 5 | test('renders learn react link', () => { 6 | render(); 7 | const linkElement = screen.getByText(/learn react/i); 8 | expect(linkElement).toBeInTheDocument(); 9 | }); 10 | -------------------------------------------------------------------------------- /comparer/src/FileComparer.css: -------------------------------------------------------------------------------- 1 | .config_options { 2 | display: flex; 3 | gap: 10px; 4 | width: 100%; 5 | background: #eeeeee; 6 | padding: 10px; 7 | box-sizing: border-box; 8 | background: #eeeecc; 9 | } 10 | 11 | .configContainer { 12 | display: flex; 13 | gap: 3px; 14 | } 15 | 16 | .configContainer label { 17 | cursor: pointer; 18 | } -------------------------------------------------------------------------------- /comparer/README.md: -------------------------------------------------------------------------------- 1 | # Simple model generation comparator 2 | 3 | Example report files: 4 | 5 | https://github.com/LAION-AI/Open-Assistant/tree/main/model/model_eval/manual/sampling_reports 6 | 7 | Building 8 | ======== 9 | 10 | ```sh 11 | npm install 12 | npm start 13 | ``` 14 | 15 | ## Credits 16 | * [Open-assistant](https://github.com/Open-Assistant/oasst-model-eval/tree/main/model_comparer/public) 17 | -------------------------------------------------------------------------------- /comparer/src/Comparer.css: -------------------------------------------------------------------------------- 1 | .comparer { 2 | width: 100%; 3 | height: 100%; 4 | padding: 10px 20px; 5 | box-sizing: border-box; 6 | } 7 | .comparer input { 8 | width: 100%; 9 | border: 0; 10 | outline: 0; 11 | color: darkgray; 12 | font-weight: bold; 13 | font-family: 'Roboto', sans-serif; 14 | } 15 | 16 | h1 { 17 | font-family: 'Roboto', sans-serif; 18 | color: black; 19 | font-size: 22px; 20 | } 21 | -------------------------------------------------------------------------------- /comparer/.gitignore: -------------------------------------------------------------------------------- 1 | # See https://help.github.com/articles/ignoring-files/ for more about ignoring files. 
2 | 3 | # dependencies 4 | /node_modules 5 | /.pnp 6 | .pnp.js 7 | 8 | # testing 9 | /coverage 10 | 11 | # production 12 | /build 13 | 14 | # misc 15 | .DS_Store 16 | .env.local 17 | .env.development.local 18 | .env.test.local 19 | .env.production.local 20 | 21 | npm-debug.log* 22 | yarn-debug.log* 23 | yarn-error.log* 24 | -------------------------------------------------------------------------------- /comparer/src/index.css: -------------------------------------------------------------------------------- 1 | body { 2 | margin: 0; 3 | font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', 'Roboto', 'Oxygen', 4 | 'Ubuntu', 'Cantarell', 'Fira Sans', 'Droid Sans', 'Helvetica Neue', 5 | sans-serif; 6 | -webkit-font-smoothing: antialiased; 7 | -moz-osx-font-smoothing: grayscale; 8 | } 9 | 10 | code { 11 | font-family: source-code-pro, Menlo, Monaco, Consolas, 'Courier New', 12 | monospace; 13 | } 14 | 15 | h2 { 16 | font-size: 14px; 17 | } 18 | -------------------------------------------------------------------------------- /comparer/src/reportWebVitals.ts: -------------------------------------------------------------------------------- 1 | import { ReportHandler } from 'web-vitals'; 2 | 3 | const reportWebVitals = (onPerfEntry?: ReportHandler) => { 4 | if (onPerfEntry && onPerfEntry instanceof Function) { 5 | import('web-vitals').then(({ getCLS, getFID, getFCP, getLCP, getTTFB }) => { 6 | getCLS(onPerfEntry); 7 | getFID(onPerfEntry); 8 | getFCP(onPerfEntry); 9 | getLCP(onPerfEntry); 10 | getTTFB(onPerfEntry); 11 | }); 12 | } 13 | }; 14 | 15 | export default reportWebVitals; 16 | -------------------------------------------------------------------------------- /comparer/public/manifest.json: -------------------------------------------------------------------------------- 1 | { 2 | "short_name": "Comparer", 3 | "name": "Comparer", 4 | "icons": [ 5 | { 6 | "src": "favicon.ico", 7 | "sizes": "64x64 32x32 24x24 16x16", 8 | "type": "image/x-icon" 9 | }, 10 | { 11 | "src": "logo_expgrad.png", 12 | "type": "image/png", 13 | "sizes": "192x192" 14 | }, 15 | { 16 | "src": "logo_expgrad.png", 17 | "type": "image/png", 18 | "sizes": "512x512" 19 | } 20 | ], 21 | "start_url": ".", 22 | "display": "standalone", 23 | "theme_color": "#000000", 24 | "background_color": "#ffffff" 25 | } 26 | -------------------------------------------------------------------------------- /funtuner/config/trainer/default.yaml: -------------------------------------------------------------------------------- 1 | _target_: transformers.TrainingArguments 2 | output_dir: "." 
3 | learning_rate: 1e-4 4 | gradient_checkpointing: true 5 | gradient_accumulation_steps: 16 6 | per_device_train_batch_size: 2 7 | per_device_eval_batch_size: 2 8 | adam_beta1: 0.9 9 | adam_beta2: 0.95 10 | adam_epsilon: 1e-12 11 | weight_decay: 0.001 12 | eval_steps: 100 13 | save_steps: 100 14 | num_train_epochs: 1 15 | logging_steps: 10 16 | max_grad_norm: 1.0 17 | save_total_limit: 4 18 | fp16: true 19 | bf16: false 20 | lr_scheduler_type: cosine 21 | warmup_ratio: 0.15 22 | evaluation_strategy: steps 23 | use_legacy_prediction_loop: false 24 | -------------------------------------------------------------------------------- /comparer/src/index.tsx: -------------------------------------------------------------------------------- 1 | import React from 'react'; 2 | import ReactDOM from 'react-dom/client'; 3 | import './index.css'; 4 | import App from './App'; 5 | import reportWebVitals from './reportWebVitals'; 6 | 7 | const root = ReactDOM.createRoot( 8 | document.getElementById('root') as HTMLElement 9 | ); 10 | root.render( 11 | 12 | 13 | 14 | ); 15 | 16 | // If you want to start measuring performance in your app, pass a function 17 | // to log results (for example: reportWebVitals(console.log)) 18 | // or send to an analytics endpoint. Learn more: https://bit.ly/CRA-vitals 19 | reportWebVitals(); 20 | -------------------------------------------------------------------------------- /comparer/tsconfig.json: -------------------------------------------------------------------------------- 1 | { 2 | "compilerOptions": { 3 | "target": "es5", 4 | "lib": [ 5 | "dom", 6 | "dom.iterable", 7 | "esnext" 8 | ], 9 | "allowJs": true, 10 | "skipLibCheck": true, 11 | "esModuleInterop": true, 12 | "allowSyntheticDefaultImports": true, 13 | "strict": true, 14 | "forceConsistentCasingInFileNames": true, 15 | "noFallthroughCasesInSwitch": true, 16 | "module": "esnext", 17 | "moduleResolution": "node", 18 | "resolveJsonModule": true, 19 | "isolatedModules": true, 20 | "noEmit": true, 21 | "jsx": "react-jsx" 22 | }, 23 | "include": [ 24 | "src" 25 | ] 26 | } 27 | -------------------------------------------------------------------------------- /run_gpu.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash --login 2 | #SBATCH -J RM # job name 3 | #SBATCH -o o.%x.%j # output file 4 | #SBATCH -e e.%x.%j # error file 5 | #SBATCH -p gpu_v100 # partition 6 | #SBATCH --gres=gpu:2 7 | #SBATCH -n 8 # number of tasks (1 CPU per task by default) 8 | #SBATCH --time=06:00:00 # time 9 | #SBATCH --account=scw2050 # project account number 10 | 11 | git pull origin dev-train 12 | module purge 13 | module load deepspeed 14 | module list 15 | export PYTHONPATH="${PYTHONPATH}:/home/c.scmse/Funtuner" 16 | exec singularity exec --nv $DEEPSPEED_IMAGE /nfshome/store03/users/c.scmse/venv/bin/python3 funtuner/trainer.py 17 | 18 | -------------------------------------------------------------------------------- /infer.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash --login 2 | #SBATCH -J RM # job name 3 | #SBATCH -o o.%x.%j # output file 4 | #SBATCH -e e.%x.%j # error file 5 | #SBATCH -p gpu_v100 # partition 6 | #SBATCH --gres=gpu:1 7 | #SBATCH -n 8 # number of tasks (1 CPU per task by default) 8 | #SBATCH --time=01:00:00 # time 9 | #SBATCH --account=scw2050 # project account number 10 | 11 | git pull origin dev-train 12 | module purge 13 | module load deepspeed 14 | module list 15 | export 
PYTHONPATH="${PYTHONPATH}:/home/c.scmse/Funtuner" 16 | exec singularity exec --nv $DEEPSPEED_IMAGE /nfshome/store03/users/c.scmse/venv/bin/python funtuner/sampling.py --model_url shahules786/Redpajama-3B-CoT --dataset Dahoas/cot_gsm8k 17 | exec singularity exec --nv $DEEPSPEED_IMAGE /nfshome/store03/users/c.scmse/venv/bin/python evals/sampler.py --model_url shahules786/Redpajama-3B-CoT 18 | 19 | -------------------------------------------------------------------------------- /funtuner/config/config.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - trainer: default 3 | model: openlm-research/open_llama_7b 4 | log_dir: "/scratch/c.scmse/Funtuner-logs" 5 | log_wandb: true 6 | run_name: "" 7 | wandb_entity: "shahules786" 8 | max_length: 2048 9 | per_digit_tokens: False 10 | special_tokens: 11 | eos_token: "" 12 | sep_token: "" 13 | pad_token: "" 14 | datasets: 15 | 16 | - Dahoas/cot_gsm8k: 17 | split: ["train","val"] 18 | - psmathur/WizardLM_Orca: 19 | split: ["train"] 20 | 21 | validation_size: 0.02 22 | deepspeed: true 23 | deepspeed_config: "./funtuner/config/zero2.json" 24 | LoRa: true 25 | LoraConfig: 26 | r: 8 27 | target_modules: all 28 | lora_alpha: 16 29 | bias: none 30 | lora_dropout: 0.05 31 | task_type: CAUSAL_LM 32 | inference_mode: false 33 | qlora: true 34 | qlora_config: 35 | double_quant: true 36 | quant_type: nf4 37 | load_in_4_bit: true 38 | load_in_8_bit: false 39 | template: alpaca-lora 40 | -------------------------------------------------------------------------------- /funtuner/config/zero2.json: -------------------------------------------------------------------------------- 1 | { 2 | "fp16": { 3 | "enabled": "auto", 4 | "loss_scale": 0, 5 | "loss_scale_window": 1000, 6 | "initial_scale_power": 16, 7 | "hysteresis": 2, 8 | "min_loss_scale": 1 9 | }, 10 | "bf16": { 11 | "enabled": "auto" 12 | }, 13 | "zero_optimization": { 14 | "stage": 2, 15 | "offload_optimizer": { 16 | "device": "cpu" 17 | }, 18 | "allgather_partitions": true, 19 | "allgather_bucket_size": 5e8, 20 | "overlap_comm": true, 21 | "reduce_scatter": true, 22 | "reduce_bucket_size": "auto", 23 | "contiguous_gradients": true 24 | }, 25 | "gradient_accumulation_steps": "auto", 26 | "gradient_clipping": "auto", 27 | "steps_per_print": 2000, 28 | "train_batch_size": "auto", 29 | "train_micro_batch_size_per_gpu": "auto", 30 | "wall_clock_breakdown": false 31 | } -------------------------------------------------------------------------------- /comparer/src/components/Prompt.css: -------------------------------------------------------------------------------- 1 | .promptBubble { 2 | background: rgb(0, 120, 255); 3 | width: fit-content; 4 | padding: 1px 10px; 5 | border-radius: 10px 10px 10px 0; 6 | margin: 10px; 7 | cursor: pointer; 8 | } 9 | 10 | .replyBubble { 11 | background: lightblue; 12 | padding: 1px 10px; 13 | border-radius: 10px 10px 0 10px; 14 | margin: 2px 0 10px 0; 15 | line-height: 1.4em; 16 | } 17 | 18 | .noMarkdown { 19 | white-space: pre-wrap; 20 | padding: 16px 0; 21 | } 22 | 23 | 24 | .model_name { 25 | font-family: monospace; 26 | } 27 | 28 | .sampling_config { 29 | display: flex; 30 | gap: 10px; 31 | color: gray; 32 | font-size: 12px; 33 | margin: 0 0 0 10px; 34 | } 35 | 36 | .param-value, .sampling_config b { 37 | color: #222222; 38 | } 39 | 40 | .replyTableWithAltRows { 41 | border-collapse: collapse; 42 | } 43 | 44 | /* color rows with alternating colors */ 45 | .replyTableWithAltRows tr:nth-child(even) { 46 | background-color: 
#f2f2f2; 47 | } 48 | 49 | .replyTableWithAltRows tr { 50 | border-bottom: 1px solid #ddd; 51 | } 52 | 53 | .replyTableWithAltRows td { 54 | padding-bottom: 10px; 55 | } -------------------------------------------------------------------------------- /comparer/src/FileLoaderTextArea.css: -------------------------------------------------------------------------------- 1 | 2 | .filelist { 3 | font-family: 'Roboto', sans-serif; 4 | color: gray; 5 | } 6 | 7 | textarea.filenames { 8 | width: 100%; 9 | height: 200px; 10 | box-sizing: border-box; 11 | resize: vertical; 12 | word-break: break-all; 13 | border: 1px solid lightgray; 14 | background: #eeeeee; 15 | padding: 7px; 16 | } 17 | .loading_wait textarea { 18 | color: black; 19 | } 20 | .loading_success textarea { 21 | color: darkgreen; 22 | } 23 | .loading_errors, .loading_errors textarea { 24 | color: red; 25 | } 26 | .local_files { 27 | font-family: monospace; 28 | margin: 7px; 29 | color: darkgreen; 30 | } 31 | 32 | .error { 33 | color: darkred; 34 | } 35 | 36 | .hint { 37 | color: #555555; 38 | font-size: 12px; 39 | text-align: center; 40 | padding-top: 10px; 41 | padding-bottom: 10px; 42 | } 43 | 44 | .dropzone { 45 | cursor: pointer; 46 | min-height: 100px; 47 | background: repeating-linear-gradient(-45deg, #eeeeee, #eeeeee 1px, white 5px, white 10px); 48 | display: flex; 49 | flex-direction: column; 50 | margin-bottom: 10px; 51 | } 52 | 53 | .dropzone_messages { 54 | flex-grow: 1; 55 | } -------------------------------------------------------------------------------- /funtuner/config/zero3.json: -------------------------------------------------------------------------------- 1 | { 2 | "fp16": { 3 | "enabled": "auto", 4 | "loss_scale": 0, 5 | "loss_scale_window": 1000, 6 | "initial_scale_power": 16, 7 | "hysteresis": 2, 8 | "min_loss_scale": 1 9 | }, 10 | "bf16": { 11 | "enabled": "auto" 12 | }, 13 | "zero_optimization": { 14 | "stage": 3, 15 | "offload_optimizer": { 16 | "device": "cpu", 17 | "pin_memory": true 18 | }, 19 | "offload_param": { 20 | "device": "cpu", 21 | "pin_memory": true 22 | }, 23 | "overlap_comm": true, 24 | "contiguous_gradients": true, 25 | "sub_group_size": 1e9, 26 | "reduce_bucket_size": "auto", 27 | "stage3_prefetch_bucket_size": "auto", 28 | "stage3_param_persistence_threshold": "auto", 29 | "stage3_max_live_parameters": 1e9, 30 | "stage3_max_reuse_distance": 1e9, 31 | "stage3_gather_16bit_weights_on_model_save": true 32 | }, 33 | "gradient_accumulation_steps": "auto", 34 | "gradient_clipping": "auto", 35 | "steps_per_print": 2000, 36 | "train_batch_size": "auto", 37 | "train_micro_batch_size_per_gpu": "auto", 38 | "wall_clock_breakdown": false 39 | } -------------------------------------------------------------------------------- /hf_upload.py: -------------------------------------------------------------------------------- 1 | from huggingface_hub import HfApi 2 | import os 3 | import argparse 4 | 5 | api = HfApi() 6 | 7 | tokenizer_files = ["tokenizer.json", "special_tokens_map.json", "tokenizer_config.json", "tokenizer.model"] 8 | model_files = ["adapter_config.json", "adapter_model.bin"] 9 | 10 | if __name__ == "__main__": 11 | 12 | parser = argparse.ArgumentParser(description="") 13 | parser.add_argument("--model_url", type=str, help="model url") 14 | parser.add_argument("--root_dir", type=str, help="model url") 15 | parser.add_argument("--checkpoint", type=str, help="checkpoint id") 16 | 17 | args = parser.parse_args().__dict__ 18 | root = args.get("root_dir") 19 | 20 | files = 
[os.path.join(args.get("root_dir"), file) for file in tokenizer_files] + \ 21 | [os.path.join(args.get("root_dir") ,args.get("checkpoint"), file) for file in model_files] 22 | 23 | for file in files: 24 | try: 25 | api.upload_file( 26 | path_or_fileobj=file, 27 | repo_id=args.get("model_url"), 28 | repo_type="model", 29 | path_in_repo=file.split('/')[-1] 30 | ) 31 | except Exception as e: 32 | print(e) -------------------------------------------------------------------------------- /comparer/src/Comparer.tsx: -------------------------------------------------------------------------------- 1 | import React, { useMemo, useState } from 'react'; 2 | import './Comparer.css'; 3 | import { FileComparer } from './FileComparer'; 4 | import { FileLoaderTextArea } from './FileLoaderTextArea'; 5 | 6 | export interface JsonFilePrompt { 7 | outputs: string[]; 8 | sampling_config: string; 9 | sampling_params: {[key: string]: string | number | boolean}; 10 | } 11 | 12 | export interface JsonFile { 13 | model_name: string; 14 | date: string; 15 | args: {[key: string]: any}; 16 | prompts: {prompt: string, results: JsonFilePrompt[]}[]; 17 | } 18 | 19 | export const Comparer = () => { 20 | const [files, setFiles] = useState<(JsonFile | undefined)[]>([]); // Can be undefined when loading happens out of order 21 | const [localFiles, setLocalFiles] = useState([]); 22 | 23 | function localFileAdded(file: JsonFile) { 24 | setLocalFiles(localFiles => [...localFiles, file]); 25 | } 26 | 27 | const filesMerged = useMemo(() => ([...files, ...localFiles].filter( f => f) as JsonFile[]), [files, localFiles]); 28 | 29 | return ( 30 |
31 | <h1>Battle of LLMs</h1> 32 | <FileLoaderTextArea setFiles={setFiles} localFiles={localFiles} localFileAdded={localFileAdded} /> 33 | <FileComparer files={filesMerged} /> 34 | </div>
35 | ); 36 | } 37 | -------------------------------------------------------------------------------- /comparer/src/useHooks.ts: -------------------------------------------------------------------------------- 1 | import { useState } from "react"; 2 | 3 | export function useStateWithLocalStorageInt(initial: number, key: string): [number, (value: number) => void] { 4 | const [value, setValue] = useState(parseInt(localStorage.getItem(key) || initial.toString())); 5 | const setValueAndStore = (newValue: number) => { 6 | localStorage.setItem(key, newValue.toString()); 7 | setValue(newValue); 8 | } 9 | return [value, setValueAndStore]; 10 | } 11 | 12 | export function useStateWithLocalStorageBoolean(initial: boolean, key: string): [boolean, (value: boolean) => void] { 13 | const ls = localStorage.getItem(key); 14 | const [value, setValue] = useState(ls === "true" || ls === "false" ? ls === "true" : initial); 15 | const setValueAndStore = (newValue: boolean) => { 16 | localStorage.setItem(key, newValue.toString()); 17 | setValue(newValue); 18 | } 19 | return [value, setValueAndStore]; 20 | } 21 | 22 | export function useStateWithLocalStorageString(initial: string, key: string): [string, (value: string) => void] { 23 | const [value, setValue] = useState(localStorage.getItem(key) ?? initial); 24 | const setValueAndStore = (newValue: string) => { 25 | localStorage.setItem(key, newValue.toString()); 26 | setValue(newValue); 27 | } 28 | return [value, setValueAndStore]; 29 | } 30 | -------------------------------------------------------------------------------- /funtuner/config/templates.json: -------------------------------------------------------------------------------- 1 | { 2 | "alpaca-lora" : { 3 | "description":"", 4 | "prompt_and_input":"###System\nBelow is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.\n\n### Instruction:\n{instruction}\n\n### Input:\n{context}\n\n### Response:\n", 5 | "prompt_only":"###System\nBelow is an instruction that describes a task. Write a response that appropriately completes the request.\n\n### Instruction:\n{instruction}\n\n### Response:\n", 6 | "response_split": "### Response:" 7 | }, 8 | 9 | "orca-style":{ 10 | "description":"", 11 | "prompt_and_input":"###System Instruction:\n You are an AI assistant. User will you give you a task paired with an input that provides further context. Your goal is to complete the task as faithfully as you can. While performing the task think step-by-step and justify your steps. \n\n### Instruction:\n{instruction}\n\n### Input:\n{context}\n\n### Response:\n", 12 | "prompt_only":"System Instruction: You are an AI assistant. User will you give you a task. Your goal is to complete the task as faithfully as you can. While performing the task think step-by-step and justify your steps.\n\n### Instruction:\n{instruction}\n\n### Response:\n", 13 | "response_split": "### Response:" 14 | } 15 | } -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # FunTuner 2 | A no nonsense easy to configure model fine-tuning framework for GPT based models that can get the job done in a memory and time efficient manner. 
3 | 4 | :radioactive: Work in progress 5 | 6 | ## Components 7 | ✅Hydra configuration 8 | 9 | ✅DeepSpeed support 10 | 11 | ✅8-bit training 12 | 13 | ✅LoRA using peft 14 | 15 | ✅Sequence bucketing 16 | 17 | ✅Inference 18 | 19 | ✅single 20 | ✅batch 21 | ❎stream 22 | 23 | ✅Supported Models 24 | 25 | ✅GPTNeoX - Redpajama, Pythia, etc. 26 | ❎LLaMA 27 | ❎Falcon 28 | 29 | ❎Flash attention 30 | 31 | 32 | ## Train 33 | 34 | * Using DeepSpeed 35 | 36 | ```bash 37 | deepspeed funtuner/trainer.py 38 | ``` 39 | 40 | ## Inference 41 | ```python 42 | from funtuner.inference import Inference 43 | model = Inference("shahules786/GPTNeo-125M-lora") 44 | kwargs = {"temperature":0.1, 45 | "top_p":0.75, 46 | "top_k":5, 47 | "num_beams":2, 48 | "max_new_tokens":128,} 49 | 50 | ## single 51 | output = model.generate("Which is a species of fish? Tope or Rope",**kwargs) 52 | 53 | ## batch 54 | inputs = [["There was a tiger in the hidden"],["Which is a species of fish? Tope or Rope"]] 55 | output = model.batch_generate(inputs,**kwargs) 56 | 57 | ``` 58 | 59 | 60 | ## Sampling 61 | 62 | ```bash 63 | python funtuner/sampling.py --model_url shahules786/Redpajama-3B-CoT --dataset Dahoas/cot_gsm8k 64 | ``` -------------------------------------------------------------------------------- /funtuner/custom_datasets/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils.data import random_split, ConcatDataset 3 | from funtuner.custom_datasets.sftdataset import FunDataset 4 | import yaml 5 | from pathlib import Path 6 | 7 | generator = torch.Generator().manual_seed(42) 8 | DATASET_MAPPING = yaml.safe_load(Path("funtuner/config/datasets.yaml").read_text()) 9 | 10 | 11 | def get_single_dataset(name, split, template): 12 | args = DATASET_MAPPING.get(name) 13 | if args is not None: 14 | dataset = FunDataset(name=name, split=split, template=template, **args) 15 | else: 16 | raise ValueError(f"Invalid dataset name {name}.
Add dataset to dataset.yaml") 17 | 18 | return dataset 19 | 20 | 21 | def get_datasets(config): 22 | dataset_list = [] 23 | template = config.template 24 | for dataset in config.datasets: 25 | name = list(dataset.keys())[0] 26 | splits = dataset[name].get("split", "train") 27 | dataset_list.append( 28 | get_single_dataset( 29 | name, 30 | splits, 31 | template, 32 | ) 33 | ) 34 | 35 | dataset = ConcatDataset(dataset_list) 36 | train_dataset, valid_dataset = random_split( 37 | dataset, 38 | [1 - config.validation_size, config.validation_size], 39 | generator=generator, 40 | ) 41 | return train_dataset, valid_dataset 42 | -------------------------------------------------------------------------------- /comparer/package.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "Battle of LLMs", 3 | "version": "0.1.0", 4 | "homepage": "https://open-assistant.github.io/FunTuner/", 5 | "private": true, 6 | "dependencies": { 7 | "@testing-library/jest-dom": "^5.16.5", 8 | "@testing-library/react": "^13.4.0", 9 | "@testing-library/user-event": "^13.5.0", 10 | "@types/jest": "^27.5.2", 11 | "@types/node": "^16.18.14", 12 | "@types/react": "^18.0.28", 13 | "@types/react-dom": "^18.0.11", 14 | "react": "^18.2.0", 15 | "react-dom": "^18.2.0", 16 | "react-dropzone": "^14.2.3", 17 | "react-markdown": "^8.0.5", 18 | "react-scripts": "5.0.1", 19 | "react-syntax-highlighter": "^15.5.0", 20 | "rehype-katex": "^6.0.2", 21 | "remark-math": "^5.1.1", 22 | "typescript": "^4.9.5", 23 | "web-vitals": "^2.1.4" 24 | }, 25 | "scripts": { 26 | "predeploy": "npm run build", 27 | "deploy": "gh-pages -d build", 28 | "start": "react-scripts start", 29 | "build": "react-scripts build", 30 | "test": "react-scripts test", 31 | "eject": "react-scripts eject" 32 | }, 33 | "eslintConfig": { 34 | "extends": [ 35 | "react-app", 36 | "react-app/jest" 37 | ] 38 | }, 39 | "browserslist": { 40 | "production": [ 41 | ">0.2%", 42 | "not dead", 43 | "not op_mini all" 44 | ], 45 | "development": [ 46 | "last 1 chrome version", 47 | "last 1 firefox version", 48 | "last 1 safari version" 49 | ] 50 | }, 51 | "devDependencies": { 52 | "@types/react-syntax-highlighter": "^15.5.6", 53 | "gh-pages": "^5.0.0" 54 | } 55 | } 56 | -------------------------------------------------------------------------------- /comparer/public/index.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 12 | 13 | 17 | 18 | 27 | Comparer 28 | 29 | 30 | 31 |
32 | 42 | 43 | 44 | -------------------------------------------------------------------------------- /funtuner/sampling.py: -------------------------------------------------------------------------------- 1 | import itertools 2 | from funtuner.inference import Inference 3 | from funtuner.custom_datasets.utils import DATASET_MAPPING 4 | from datasets import load_dataset 5 | import argparse 6 | 7 | def sampling(examples, model, dataset, **generation_args): 8 | 9 | dataset = DATASET_MAPPING[dataset] 10 | instruction, input = dataset.get("prompt"), dataset.get("context",None) 11 | instruction = examples[instruction] 12 | input = examples[input] if input is not None else [None] 13 | examples = list(itertools.zip_longest(instruction, input)) 14 | output = model.batch_generate(examples, **generation_args) 15 | return {"completion": output} 16 | 17 | 18 | 19 | 20 | if __name__ == "__main__": 21 | 22 | parser = argparse.ArgumentParser(description="") 23 | parser.add_argument("--model_url", type=str, help="model name") 24 | parser.add_argument("--load_8_bits", type=bool, default=False, help="model name") 25 | 26 | parser.add_argument("--dataset", type=str, help="dataset name") 27 | parser.add_argument("--split", type=str, default="test", help="dataset split") 28 | 29 | parser.add_argument("--num_samples", type=int, default=100, help="num of samples to run inference") 30 | parser.add_argument("--save_path", type=str, default="results.json", help="save path") 31 | 32 | parser.add_argument("--batch_size", type=int, default=4, help="") 33 | parser.add_argument("--temperature", type=float, default=0.1, help="") 34 | parser.add_argument("--top_p", type=float, default=0.75, help="") 35 | parser.add_argument("--top_k", type=int, default=40, help="") 36 | parser.add_argument("--num_beams", type=int, default=4, help="") 37 | parser.add_argument("--max_new_tokens", type=int, default=128, help="") 38 | 39 | generation_args = ["temperature", "top_p", "top_k", "num_beams", "max_new_tokens"] 40 | 41 | 42 | args = parser.parse_args().__dict__ 43 | 44 | generation_args = {k: args.get(k) for k in generation_args} 45 | model = Inference(args.get("model_url"), load_in_8bit=args.get("load_8_bits")) 46 | dataset = load_dataset(args.get("dataset"), split=args.get("split")).select(range(0, args.get("num_samples"))) 47 | dataset = dataset.map(lambda batch: sampling(batch, model, args.get("dataset"), **generation_args), batch_size=args.get("batch_size"), batched=True) 48 | dataset.to_json(args.get("save_path"), indent=4) 49 | 50 | -------------------------------------------------------------------------------- /comparer/src/logo.svg: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /evals/sampler.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from datasets import load_dataset 3 | import yaml 4 | from pathlib import Path 5 | from datetime import datetime 6 | from funtuner.inference import Inference 7 | from funtuner.utils import save_json 8 | from funtuner.inference import Inference 9 | 10 | DATA = "shahules786/llm-eval" 11 | 12 | def merge_dicts(generation_args, default_args): 13 | 14 | for _, args_dict in generation_args.items(): 15 | args = {k: v for k, v in default_args.items() if k not in args_dict.keys()} 16 | args_dict.update(args) 17 | return generation_args 18 | 19 | 20 | def sampling(examples, model, generation_args): 21 | datadict = {} 22 
| instruction = examples["instruction"] 23 | inputs = examples["input"] 24 | examples = list(zip(instruction, inputs)) 25 | for key, args in generation_args.items(): 26 | output = model.batch_generate(examples, **args) 27 | datadict[f"{key}_completion"] = output 28 | 29 | return datadict 30 | 31 | def update_results(dataset, generation_args: dict, results: dict): 32 | 33 | for item in dataset: 34 | sample = [] 35 | for _, key in enumerate(generation_args.keys()): 36 | sample.append({"sampling_config": key, 37 | "sampling_params": generation_args[key], 38 | "outputs": [item[f"{key}_completion"]]}) 39 | sample_dict = {"prompt": item["instruction"], "results": sample} 40 | results["prompts"].append(sample_dict) 41 | 42 | return results 43 | 44 | 45 | if __name__ == "__main__": 46 | 47 | parser = argparse.ArgumentParser(description="") 48 | parser.add_argument("--model_url", type=str, help="model name") 49 | parser.add_argument("--batch_size", type=int, default=2, help="model name") 50 | parser.add_argument("--load_8_bits", type=bool, default=False, help="model name") 51 | args = parser.parse_args().__dict__ 52 | 53 | model_name = args.get("model_url") 54 | load_8_bits = args.get("load_8_bits") 55 | model = Inference(model_name, load_8_bits) 56 | dataset = load_dataset(DATA, split="train").shuffle(seed=42) 57 | generation_args = yaml.safe_load(Path("evals/config/generation.yaml").read_text()) 58 | default_args = generation_args.pop("defaults") 59 | generation_args = merge_dicts(generation_args, default_args) 60 | dataset = dataset.map(lambda batch: sampling(batch, model, generation_args), 61 | batch_size=args.get("batch_size"), batched=True) 62 | 63 | results = { 64 | "model_name": model_name, 65 | "date": datetime.utcnow().isoformat(), 66 | "args":{ 67 | "device":"cuda", 68 | "batch_size":args.get("batch_size"), 69 | "dataset":DATA 70 | }, 71 | "prompts":[] 72 | } 73 | results = update_results(dataset, generation_args, results) 74 | model_name = model_name.split("/")[-1] 75 | save_json(f"results-{model_name}.json", results) -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/#use-with-ide 110 | .pdm.toml 111 | 112 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 113 | __pypackages__/ 114 | 115 | # Celery stuff 116 | celerybeat-schedule 117 | celerybeat.pid 118 | 119 | # SageMath parsed files 120 | *.sage.py 121 | 122 | # Environments 123 | .env 124 | .venv 125 | env/ 126 | venv/ 127 | ENV/ 128 | env.bak/ 129 | venv.bak/ 130 | 131 | # Spyder project settings 132 | .spyderproject 133 | .spyproject 134 | 135 | # Rope project settings 136 | .ropeproject 137 | 138 | # mkdocs documentation 139 | /site 140 | 141 | # mypy 142 | .mypy_cache/ 143 | .dmypy.json 144 | dmypy.json 145 | 146 | # Pyre type checker 147 | .pyre/ 148 | 149 | # pytype static type analyzer 150 | .pytype/ 151 | 152 | # Cython debug symbols 153 | cython_debug/ 154 | 155 | # PyCharm 156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 158 | # and can be added to the global gitignore or merged into this file. For a more nuclear 159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 
160 | #.idea/ 161 | -------------------------------------------------------------------------------- /funtuner/utils.py: -------------------------------------------------------------------------------- 1 | from tokenizers import pre_tokenizers 2 | from transformers import AutoConfig, AutoTokenizer, LlamaTokenizer 3 | from transformers.models.auto.modeling_auto import MODEL_FOR_CAUSAL_LM_MAPPING 4 | import requests 5 | import random 6 | from pynvml import * 7 | import json 8 | from glob import glob 9 | import os 10 | import torch 11 | import bitsandbytes as bnb 12 | from peft.tuners.lora import LoraLayer 13 | from omegaconf import OmegaConf 14 | 15 | MODEL_MAPPINGS = [MODEL_FOR_CAUSAL_LM_MAPPING] 16 | 17 | 18 | def get_tokenizer(config): 19 | 20 | if "llama" not in config.model: 21 | tokenizer = AutoTokenizer.from_pretrained(config.model) 22 | else: 23 | tokenizer = LlamaTokenizer.from_pretrained(config.model) 24 | 25 | 26 | if hasattr(config, "per_digit_tokens") and config.per_digit_tokens: 27 | tokenizer._tokenizer.pre_processor = pre_tokenizers.Digits(True) 28 | 29 | if config.special_tokens: 30 | special_tokens = { 31 | "pad_token": config.special_tokens.pad_token, 32 | "eos_token": config.special_tokens.eos_token, 33 | "sep_token": config.special_tokens.sep_token, 34 | } 35 | tokenizer.add_special_tokens(special_tokens) 36 | 37 | return tokenizer 38 | 39 | 40 | def get_model(name, **kwargs): 41 | model_config = AutoConfig.from_pretrained(name) 42 | for mapping in MODEL_MAPPINGS: 43 | model = mapping.get(type(model_config), None) 44 | if model is not None: 45 | return model.from_pretrained(name, config=model_config, 46 | **kwargs) 47 | 48 | def get_name(): 49 | word_site = "https://www.mit.edu/~ecprice/wordlist.10000" 50 | response = requests.get(word_site) 51 | WORDS = response.content.splitlines() 52 | return random.choice(WORDS).decode('UTF-8') 53 | 54 | 55 | def print_gpu_utilization(): 56 | nvmlInit() 57 | deviceCount = nvmlDeviceGetCount() 58 | for i in range(deviceCount): 59 | handle = nvmlDeviceGetHandleByIndex(i) 60 | info = nvmlDeviceGetMemoryInfo(handle) 61 | print(f"GPU memory occupied: {info.used//1024**2} MB.") 62 | 63 | 64 | def save_json(filename, data): 65 | with open(filename, "w") as file: 66 | json.dump(data, file, indent=4) 67 | 68 | 69 | def add_additional_config(cfg): 70 | config_files = glob(os.path.join(cfg.log_dir, "**/*.json"), recursive=True) 71 | for file in config_files: 72 | config = json.load(open(file)) 73 | config["template"] = cfg.template 74 | config["train_max_len"] = cfg.max_length 75 | save_json(file, config) 76 | 77 | 78 | def get_lora_modules(model, cfg): 79 | 80 | modules = cfg.LoraConfig.target_modules 81 | cls = bnb.nn.Linear4bit if cfg.load_in_4_bit == 4 else (bnb.nn.Linear8bitLt if cfg.load_in_8_bit == 8 else torch.nn.Linear) 82 | if modules != "all": 83 | return modules 84 | 85 | modules = { 86 | name.split('.')[-1] 87 | for name, module in model.named_modules() 88 | if isinstance(module, cls) 89 | } 90 | if 'lm_head' in modules: 91 | modules.remove('lm_head') 92 | return list(modules) 93 | 94 | 95 | def prepare_model_types(model, cfg): 96 | 97 | for name, module in model.named_modules(): 98 | if isinstance(model, LoraLayer): 99 | if cfg.trainer.bf16: 100 | module = module.to(torch.bfloat16) 101 | if "norm" in name: 102 | module = module.to(torch.float32) 103 | if "lm_head" in name or "embed_tokens" in name: 104 | if hasattr(module, "weight"): 105 | if cfg.trainer.bf16: 106 | module = module.to(torch.bfloat16) 107 | return model 108 | 
-------------------------------------------------------------------------------- /comparer/src/components/Prompt.tsx: -------------------------------------------------------------------------------- 1 | import { PromptResults } from "../FileComparer"; 2 | import ReactMarkdown from 'react-markdown'; 3 | import {Prism as SyntaxHighlighter} from 'react-syntax-highlighter'; 4 | import {dark} from 'react-syntax-highlighter/dist/esm/styles/prism' 5 | import './Prompt.css'; 6 | import remarkMath from 'remark-math'; 7 | import rehypeKatex from 'rehype-katex'; 8 | import 'katex/dist/katex.min.css' 9 | 10 | interface PromptProps { 11 | prompt: string; 12 | results: PromptResults[]; 13 | showSamplingMethod: boolean; 14 | outputIndex: number; 15 | showSamplingConfig: boolean; 16 | collapsed?: boolean; 17 | onToggleCollapsed?: () => void; 18 | renderMarkdown: boolean; 19 | }; 20 | 21 | export const Prompt = ({collapsed, onToggleCollapsed, prompt, results, outputIndex, showSamplingMethod, showSamplingConfig, renderMarkdown}: PromptProps) => { 22 | return ( 23 |
24 |
{prompt}
25 | { !collapsed && 26 | 27 | 28 | {results.map((result, modelIndex) => 29 | 30 | 42 | )} 43 | 44 |
{result.file.model_name}{result.results.map((result, result_index) => ( 31 |
32 | {(showSamplingMethod || showSamplingConfig) &&
33 | {showSamplingMethod &&
Sampling config: {result.sampling_config}
} 34 | {showSamplingConfig && Object.keys(result.sampling_params).map(param =>
{param}: {result.sampling_params[param].toString()}
)} 35 |
} 36 | {outputIndex === -1 ? 37 | result.outputs.map((output, index) => ) : 38 | 39 | } 40 |
41 | ))}
45 | } 46 |
47 | ); 48 | } 49 | 50 | type ReplyBubbleProps = { modelIndex: number, output: string, saturation: number, renderMarkdown: boolean }; 51 | 52 | const endoftextToken = '<|endoftext|>'; 53 | const ReplyBubble = ({modelIndex, output, saturation, renderMarkdown} : ReplyBubbleProps) => { 54 | // remove <|endoftext|> from the end of the output if it is present 55 | if (output.endsWith(endoftextToken)) { 56 | output = output.slice(0, -endoftextToken.length); 57 | } 58 | const out = {output}; 59 | return
{out}
60 | } 61 | 62 | type RenderMarkdownProps = { 63 | children: string; 64 | renderMarkdown: boolean 65 | } 66 | 67 | const RenderMarkdown = ({children, renderMarkdown} : RenderMarkdownProps ) => { 68 | return renderMarkdown ? 83 | ) : ( 84 | 85 | {children} 86 | 87 | ) 88 | } 89 | }} 90 | /> :
{children}
; 91 | } -------------------------------------------------------------------------------- /funtuner/inference.py: -------------------------------------------------------------------------------- 1 | from peft import PeftModel 2 | from funtuner.utils import get_model 3 | from transformers import AutoTokenizer, LlamaTokenizer 4 | from funtuner.custom_datasets.sftdataset import PromptFormater 5 | from typing import List, Optional 6 | import torch 7 | from huggingface_hub import hf_hub_download 8 | import json 9 | import os 10 | 11 | class Inference: 12 | def __init__( 13 | self, 14 | model_name:str, 15 | load_in_8bit:bool=False, 16 | ): 17 | 18 | self.device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") 19 | config = self.load_config(model_name) 20 | base_model = config["base_model_name_or_path"] 21 | 22 | self.tokenizer = self.load_tokenizer(model_name, base_model) 23 | model = get_model(base_model, load_in_8bit) 24 | 25 | 26 | model.resize_token_embeddings(len(self.tokenizer)) 27 | self.model = PeftModel.from_pretrained(model, model_name).eval() 28 | if not load_in_8bit: 29 | self.model = self.model.half() 30 | self.model.to(self.device) 31 | self.tokenizer.padding_side = "left" 32 | self.template = PromptFormater(config.get("template", "alpaca-lora")) 33 | 34 | def load_config(self, model_name): 35 | 36 | if os.path.exists(model_name): 37 | config = os.path.join(model_name, "adapter_config.json") 38 | else: 39 | config = hf_hub_download(repo_id=model_name, filename="adapter_config.json", local_dir=".") 40 | config = "adapter_config.json" 41 | 42 | config = json.load(open(config)) 43 | return config 44 | 45 | def load_tokenizer(self, model_name, base_model): 46 | 47 | if "llama" not in base_model: 48 | tokenizer = AutoTokenizer.from_pretrained(model_name) 49 | else: 50 | tokenizer = LlamaTokenizer.from_pretrained(model_name) 51 | return tokenizer 52 | def generate(self, 53 | instruction:str, 54 | context:Optional[str]=None, 55 | **kwargs, 56 | ): 57 | 58 | text = self.template.format(instruction, context) 59 | inputs = self.tokenizer(text, return_tensors="pt").to(self.device) 60 | kwargs.update({ 61 | "input_ids": inputs["input_ids"], 62 | "attention_mask": inputs["attention_mask"], 63 | "pad_token_id": self.tokenizer.pad_token_id, 64 | "eos_token_id": self.tokenizer.eos_token_id, 65 | }) 66 | with torch.no_grad(): 67 | output = self.model.generate(**kwargs)[0] 68 | output = self.tokenizer.decode(output, skip_special_tokens=True) 69 | return self.template.response(output) 70 | 71 | def batch_generate( 72 | self, 73 | inputs: List[List[str]], 74 | **kwargs, 75 | ): 76 | # TODO: Add batch_size and iterate if needed 77 | format_inputs = [item if (len(item) == 2 and item[-1] != "") else [item[0],None] for item in inputs ] 78 | format_inputs = [self.template.format(instruction, context) for instruction, context in format_inputs] 79 | format_inputs = self.tokenizer.batch_encode_plus(format_inputs, return_attention_mask=True, 80 | return_tensors="pt", padding="longest").to(self.device) 81 | kwargs.update({ 82 | "input_ids":format_inputs["input_ids"], 83 | "attention_mask":format_inputs["attention_mask"], 84 | "pad_token_id": self.tokenizer.pad_token_id, 85 | "eos_token_id": self.tokenizer.eos_token_id, 86 | }) 87 | 88 | output = self.model.generate(**kwargs) 89 | output = self.tokenizer.batch_decode(output, skip_special_tokens=True) 90 | return [self.template.response(text) for text in output] 91 | 92 | 93 | 94 | 95 | 96 | 97 | if __name__ == "__main__": 98 | 99 | 
model = Inference("shahules786/GPTNeo-125M-lora") 100 | kwargs = {"temperature":0.1, 101 | "top_p":0.75, 102 | "top_k":5, 103 | "num_beams":2, 104 | "max_new_tokens":128,} 105 | print(model.generate("Which is a species of fish? Tope or Rope", **kwargs)) -------------------------------------------------------------------------------- /comparer/src/FileComparer.tsx: -------------------------------------------------------------------------------- 1 | import { useEffect, useMemo, useState } from "react"; 2 | import { JsonFile, JsonFilePrompt } from "./Comparer"; 3 | import { Prompt } from "./components/Prompt"; 4 | import './FileComparer.css' 5 | import { useStateWithLocalStorageBoolean, useStateWithLocalStorageInt, useStateWithLocalStorageString } from "./useHooks"; 6 | 7 | export interface PromptResults { 8 | file: JsonFile; 9 | results: JsonFilePrompt[]; 10 | } 11 | 12 | export const FileComparer = ({files}: {files:JsonFile[]}) => { 13 | const samplingMethods = useMemo(() => { 14 | const s = new Set(); 15 | files.forEach(file => file?.prompts?.forEach(p => p.results.forEach(result => s.add(result.sampling_config)))); 16 | return Array.from(s.values()); 17 | }, [files]); 18 | 19 | const [samplingMethod, setSamplingMethod] = useStateWithLocalStorageString(samplingMethods[0] || 'beam5', 'samplingMethod'); 20 | const [outputIndex, setOutputIndex] = useStateWithLocalStorageInt(0, 'outputIndex'); // -1 for all 21 | const [showSamplingConfig, setShowSamplingConfig] = useStateWithLocalStorageBoolean(true, 'samplingConfig'); 22 | const [renderMarkdown, setRenderMarkdown] = useStateWithLocalStorageBoolean(true, 'renderMarkdown'); 23 | const [expandedPrompts, setExpandedPrompts] = useState>(new Set()); 24 | 25 | useEffect(() => { 26 | if (samplingMethods && samplingMethods.length > 0 && samplingMethod && !samplingMethods.includes(samplingMethod)) { 27 | setSamplingMethod(samplingMethods[0] || ''); 28 | } 29 | }, [samplingMethods, setSamplingMethod, samplingMethod]); 30 | 31 | const toggleExpandedPrompts = () => { 32 | if(expandedPrompts.size === 0) { 33 | setExpandedPrompts(new Set(Object.keys(prompts))); 34 | } else { 35 | setExpandedPrompts(new Set()); 36 | } 37 | } 38 | const toggleCollapsed = (prompt: string, value: boolean) => { 39 | const newExpandedPrompts = new Set(expandedPrompts); 40 | if (value) { 41 | newExpandedPrompts.add(prompt); 42 | } else { 43 | newExpandedPrompts.delete(prompt); 44 | } 45 | setExpandedPrompts(newExpandedPrompts); 46 | } 47 | 48 | const prompts = useMemo(() => { 49 | const prompts: {[prompt: string]: PromptResults[]} = {}; 50 | files.forEach(file => { 51 | file.prompts?.forEach(p => { 52 | if (!prompts[p.prompt]) { 53 | prompts[p.prompt] = []; 54 | } 55 | prompts[p.prompt] = [...prompts[p.prompt], {file, results: p.results.filter(r => !samplingMethod || (r.sampling_config === samplingMethod))}]; 56 | }); 57 | }); 58 | return prompts; 59 | }, [files, samplingMethod]); 60 | 61 | return ( 62 |
63 |
64 |
65 | 66 | 70 |
71 |
72 | 73 | 78 |
79 |
80 | setShowSamplingConfig(e.target.checked)}/> 81 | 82 |
83 |
84 | setRenderMarkdown(e.target.checked)}/> 85 | 86 |
87 |
88 | 89 |
90 |
91 | {Object.keys(prompts).map((p) => 92 | toggleCollapsed(p, !expandedPrompts.has(p))} 101 | renderMarkdown={renderMarkdown} 102 | /> 103 | )} 104 |
105 | ); 106 | } 107 | -------------------------------------------------------------------------------- /funtuner/custom_datasets/sftdataset.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data import Dataset 2 | from datasets import load_dataset 3 | from typing import Union, Optional 4 | from transformers.tokenization_utils_base import PreTrainedTokenizerBase 5 | from dataclasses import dataclass 6 | from datasets import Split, DatasetDict, concatenate_datasets 7 | import json 8 | from omegaconf import OmegaConf 9 | import torch 10 | 11 | 12 | class PromptFormater: 13 | def __init__(self, template): 14 | self.template = json.load(open("funtuner/config/templates.json"))[template] 15 | 16 | def format( 17 | self, 18 | instruction: str, 19 | context: Optional[str] = None, 20 | ): 21 | return ( 22 | self.template["prompt_and_input"].format( 23 | instruction=instruction, context=context 24 | ) 25 | if context 26 | else self.template["prompt_only"].format(instruction=instruction) 27 | ) 28 | def response( 29 | self, 30 | generated_text:str 31 | ): 32 | index = generated_text.find(self.template["response_split"]) 33 | return generated_text[index + len(self.template["response_split"]):] if index != -1 else generated_text 34 | 35 | class FunDataset(Dataset): 36 | def __init__( 37 | self, 38 | name: str = "databricks/databricks-dolly-15k", 39 | split: Optional[Union[str, Split]] = "train", 40 | template: str = "alpaca-lora", 41 | **kwargs, 42 | ): 43 | split = OmegaConf.to_object(split) 44 | if isinstance(split, list) and len(split) == 1: 45 | split = split[0] 46 | self.dataset = load_dataset(name, split=split) 47 | if isinstance(self.dataset, list): 48 | self.dataset = concatenate_datasets(self.dataset) 49 | if isinstance(self.dataset, DatasetDict): 50 | features = self.dataset[list(self.dataset.keys())[0]].features.keys() 51 | else: 52 | features = self.dataset.features.keys() 53 | self.prompt = kwargs.get("prompt", "instruction") 54 | self.context = kwargs.get("context", None) 55 | self.response = kwargs.get("response", "response") 56 | for col in [self.prompt, self.context, self.response]: 57 | if (col is not None) and (col not in features): 58 | raise ValueError(f"feature {col} is not present in {name}") 59 | 60 | self.prompt_formater = PromptFormater(template) 61 | 62 | def __len__(self): 63 | return len(self.dataset) 64 | 65 | def __getitem__(self, index): 66 | item = self.dataset[index] 67 | prompt, response = item[self.prompt], item[self.response] 68 | if self.context is not None: 69 | context = item[self.context] 70 | else: 71 | context = None 72 | 73 | return self.prompt_formater.format(prompt, context), response 74 | 75 | 76 | @dataclass 77 | class FunDataCollator: 78 | tokenizer: PreTrainedTokenizerBase 79 | max_length: int = 512 80 | 81 | def __call__(self, batch): 82 | batch_maxlen = 0 83 | batch_input_ids, batch_label_ids = [], [] 84 | prompts, responses = zip(*batch) 85 | prompt_tokens = self.tokenizer.batch_encode_plus( 86 | prompts, return_attention_mask=False 87 | ).input_ids 88 | response_tokens = self.tokenizer.batch_encode_plus( 89 | responses, return_attention_mask=False 90 | ).input_ids 91 | for prompt, rsp in zip(prompt_tokens, response_tokens): 92 | input_ids = prompt + rsp 93 | input_len = len(input_ids) 94 | if input_len > (self.max_length - 1): 95 | trun_len = (input_len - self.max_length + 1) 96 | input_ids = input_ids[:-trun_len] 97 | 98 | input_ids += [self.tokenizer.eos_token_id] 99 | label_ids = 
input_ids.copy() 100 | label_ids[: len(prompt)] = [-100] * min(len(input_ids), len(prompt)) 101 | if len(input_ids) > batch_maxlen: 102 | batch_maxlen = len(input_ids) 103 | 104 | batch_input_ids.append(input_ids) 105 | batch_label_ids.append(label_ids) 106 | 107 | batch = self.tokenizer.pad( 108 | {"input_ids": batch_input_ids}, 109 | max_length=batch_maxlen, 110 | return_attention_mask=True, 111 | return_tensors="pt", 112 | ) 113 | batch_label_ids = self.tokenizer.pad( 114 | {"input_ids": batch_label_ids}, 115 | max_length=batch_maxlen, 116 | return_attention_mask=False, 117 | return_tensors="pt", 118 | )["input_ids"] 119 | batch_label_ids = torch.where(batch_label_ids==self.tokenizer.pad_token_id, -100, batch_label_ids) 120 | 121 | batch["labels"] = batch_label_ids 122 | return batch 123 | -------------------------------------------------------------------------------- /funtuner/trainer.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import hydra 4 | from hydra.utils import instantiate 5 | from omegaconf import DictConfig 6 | from transformers import Trainer, BitsAndBytesConfig 7 | from funtuner.custom_datasets import get_datasets, FunDataCollator 8 | from funtuner.utils import get_model, get_name, get_tokenizer, add_additional_config, get_lora_modules, prepare_model_types 9 | from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training 10 | from omegaconf import OmegaConf 11 | import torch 12 | import warnings 13 | warnings.filterwarnings("ignore") 14 | 15 | JOB_ID = os.environ.get("SLURM_JOB_ID",None) 16 | class FunTrainer(Trainer): 17 | def __init__(self, **kwargs): 18 | super().__init__(**kwargs) 19 | 20 | def compute_loss(self, model, inputs, return_outputs=False): 21 | outputs = model(**inputs) 22 | loss = outputs.get("loss") 23 | return (loss, outputs) if return_outputs else loss 24 | 25 | 26 | @hydra.main(version_base=None, config_path="config", config_name="config") 27 | def train(cfg: DictConfig) -> None: 28 | random_runname = get_name() 29 | if not os.path.exists(cfg.log_dir): 30 | os.mkdir(cfg.log_dir) 31 | if JOB_ID is not None: 32 | cfg.log_dir = os.path.join(cfg.log_dir, JOB_ID) 33 | if not os.path.exists(cfg.log_dir): 34 | os.mkdir(cfg.log_dir) 35 | 36 | 37 | if not cfg.log_wandb: 38 | os.environ["WANDB_MODE"] = "offline" 39 | 40 | if cfg.log_wandb: 41 | import wandb 42 | if cfg.run_name == "": 43 | cfg.run_name = random_runname 44 | name = f"{cfg.model.split('/')[-1]}-{cfg.run_name}" 45 | wandb.init( 46 | project="Funtuner", 47 | entity=cfg.wandb_entity, 48 | name=name, 49 | config=cfg, 50 | ) 51 | print("DEVICES", [torch.cuda.get_device_name(i) for i in range(torch.cuda.device_count())]) 52 | model_args = {"load_in_4bit": cfg.load_in_4_bit, "load_in_8bit": cfg.load_in_8_bit} 53 | if cfg.qlora: 54 | compute_dtype = (torch.float16 if cfg.trainer.fp16 else (torch.bfloat16 if cfg.trainer.bf16 else torch.float32)) 55 | quantization_config = BitsAndBytesConfig( 56 | load_in_8bit=cfg.load_in_8_bit, 57 | load_in_4bit=cfg.load_in_4_bit, 58 | llm_int8_threshold=6.0, 59 | llm_int8_has_fp16_weight=False, 60 | bnb_4bit_compute_dtype=compute_dtype, 61 | bnb_4bit_use_double_quant=cfg.qlora_config.double_quant, 62 | bnb_4bit_quant_type=cfg.qlora_config.quant_type, 63 | ) 64 | model_args.update({"quantization_config":quantization_config, 65 | "torch_dtype":compute_dtype}) 66 | 67 | model = get_model(cfg.model, **model_args) 68 | setattr(model, 'model_parallel', True) 69 | setattr(model, 'is_parallelizable', 
True) 70 | tokenizer = get_tokenizer(cfg) 71 | model.resize_token_embeddings(len(tokenizer)) 72 | 73 | if hasattr(model, "enable_input_require_grads"): 74 | model.enable_input_require_grads() 75 | else: 76 | def make_inputs_require_grad(module, input, output): 77 | output.requires_grad_(True) 78 | model.get_input_embeddings().register_forward_hook(make_inputs_require_grad) 79 | 80 | if cfg.qlora: 81 | model = prepare_model_for_kbit_training(model, use_gradient_checkpointing=cfg.trainer.gradient_checkpointing) 82 | 83 | 84 | if cfg.LoRa: 85 | Lora_config = OmegaConf.to_object(cfg.LoraConfig) 86 | Lora_config["target_modules"] = get_lora_modules(model, cfg) 87 | Lora_config = LoraConfig( 88 | **Lora_config 89 | ) 90 | 91 | model = get_peft_model(model, Lora_config) 92 | print("--------LoRA------------") 93 | model.print_trainable_parameters() 94 | 95 | model = prepare_model_types(model, cfg) 96 | 97 | 98 | training_args = instantiate( 99 | cfg.trainer, 100 | deepspeed=cfg.deepspeed_config if cfg.deepspeed else None, 101 | report_to="wandb" if cfg.log_wandb else None, 102 | output_dir=cfg.log_dir, 103 | 104 | ) 105 | train_dataset, validation_dataset = get_datasets(cfg) 106 | 107 | datacollator = FunDataCollator( 108 | tokenizer=tokenizer, 109 | max_length=cfg.max_length, 110 | ) 111 | 112 | # from tqdm import tqdm 113 | # for i,item in tqdm(enumerate(train_dataset)): 114 | # output = datacollator([item]) 115 | # if not (output["input_ids"].size() == output["labels"].size() == output["attention_mask"].size()) : 116 | # print("ERROR",i) 117 | # if output['input_ids'].size(-1) > 512: 118 | # print(output['input_ids'].size(-1)) 119 | # Initialize our Trainer 120 | trainer = FunTrainer( 121 | model=model, 122 | args=training_args, 123 | train_dataset=train_dataset, 124 | eval_dataset=validation_dataset, 125 | data_collator=datacollator, 126 | ) 127 | 128 | # training 129 | trainer.train() 130 | 131 | # save 132 | trainer.save_model(os.path.join(cfg.log_dir, f"{cfg.model.split('/')[-1]}-model")) 133 | tokenizer.save_pretrained(cfg.log_dir) 134 | add_additional_config(cfg.log_dir) 135 | 136 | if __name__ == "__main__": 137 | train() 138 | -------------------------------------------------------------------------------- /comparer/src/FileLoaderTextArea.tsx: -------------------------------------------------------------------------------- 1 | import { useCallback, useEffect, useState } from "react"; 2 | import { JsonFile } from "./Comparer"; 3 | import { useDropzone } from 'react-dropzone'; 4 | import './FileLoaderTextArea.css'; 5 | 6 | const someUrls = `https://raw.githubusercontent.com/LAION-AI/Open-Assistant/main/model/model_eval/manual/sampling_reports/2023-03-01_theblackcat102_pythia-12b-deduped-sft_sampling.json 7 | https://raw.githubusercontent.com/LAION-AI/Open-Assistant/main/model/model_eval/manual/sampling_reports/2023-03-01_theblackcat102_pythia-1b-deduped-sft_sampling.json 8 | https://raw.githubusercontent.com/LAION-AI/Open-Assistant/main/model/model_eval/manual/sampling_reports/2023-03-01_theblackcat102_pythia-3b-deduped-sft_sampling.json`; 9 | 10 | const fileCache: {[key:string]:JsonFile } = {} 11 | 12 | async function catchedFetch(url: string) { 13 | if (fileCache[url]) { 14 | return fileCache[url]; 15 | } 16 | return await fetch(url).then(r => r.json()).then(json => { 17 | fileCache[url] = json; 18 | return json; 19 | }); 20 | } 21 | 22 | interface FileLoaderTextAreaProps { 23 | setFiles: (files: (JsonFile| undefined)[] | ((f: (JsonFile | undefined)[]) => (JsonFile| undefined)[])) => 
void; 24 | localFileAdded: (file: JsonFile) => void; 25 | localFiles: JsonFile[]; 26 | } 27 | 28 | export const FileLoaderTextArea = ({setFiles, localFiles, localFileAdded} : FileLoaderTextAreaProps) => { 29 | // get window location param for ?f= 30 | const urlParams = new URLSearchParams(window.location.search); 31 | const filename = urlParams.get('f'); 32 | const [filenamesTxt, setFilenamesTxt] = useState(filename ?? localStorage.getItem('filenames') ?? someUrls); 33 | const [fileErrors, setFileErrors] = useState<string[]>([]); 34 | const [loading, setLoading] = useState(0); 35 | 
36 | useEffect(() => { 37 | localStorage.setItem('filenames', filenamesTxt); 38 | // Update window location search param with f= filenamesTxt 39 | const url = new URL(window.location.href); 40 | if (filenamesTxt) { 41 | url.searchParams.set('f', filenamesTxt); 42 | } else { 43 | url.searchParams.delete('f'); 44 | } 45 | window.history.pushState({}, '', url.toString()); 46 | }, [filenamesTxt]); 47 | 
48 | useEffect(() => { 49 | const urls = filenamesTxt.split(/[\r\n]/).map(s => s.trim()).filter(s => s); 50 | setLoading(urls.length); 51 | setFileErrors([]); 52 | setFiles([]); 53 | urls.forEach((url, index) => catchedFetch(url).then(json => { 54 | setFiles(fs => { 55 | const fs_ = [...fs]; 56 | fs_[index] = json; 57 | return fs_; 58 | }); 59 | setFileErrors(es => { const es_ = [...es]; es_[index] = ''; return es_}); 60 | }).catch(e => { 61 | setFileErrors(es => { const es_ = [...es]; es_[index] = `Failed to load ${url}: ${e.toString()}`; return es_}); 62 | }).finally(() => setLoading(l => l - 1))); 63 | }, [filenamesTxt, setFiles]); 64 | 65 | return ( 66 | <> 67 | 68 | {loading > 0 &&
<div>Loading... ({loading} more)</div>}
 69 | <div>{fileErrors.map((e, i) => e && <div key={i}>{e}</div>)}</div>
 70 | <textarea value={filenamesTxt} onChange={e => setFilenamesTxt(e.target.value)} />
 71 | <div className={loading > 0 ? "loading_wait" : fileErrors.some(e => e) ? "loading_errors" : "loading_success"}>
 72 |