├── LICENSE.md
├── README.md
├── config.png
├── lloom.py
├── screenshot.png
├── search.py
└── viz.py
/LICENSE.md:
--------------------------------------------------------------------------------
1 | Copyright © 2024 Mikhail Ravkine
2 | 
3 | Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the “Software”), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
4 | 
5 | The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
6 | 
7 | THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # The LLooM
2 | 
3 | Leverage raw LLM logits to weave the threads of probability, a few tokens at a time.
4 | 
5 | The problem with straight greedy decoding is that, due to the auto-regressive nature of LLMs, if a high-probability token is hidden behind a low-probability one then greedy decoding won't find it.
6 | 
7 | Conceptually this idea is similar to beam search, tracking multiple candidates at once, but with a human in the loop and unlimited beams.
8 | 
9 | # News
10 | 
11 | *06/02* Released **v0.3** with [vLLM](https://github.com/vllm-project/vllm) support and a quality-of-life improvement: if a suggestion beam starts with a stop token, it will be allowed to continue.
12 | 
13 | # Screenshot
14 | 
15 | ![LLooM Screenshot](screenshot.png "LLooM Screenshot")
16 | 
17 | # Using
18 | 
19 | Give the LLooM a starting prompt, or change the Story at any time by editing it directly in the top input area and pressing Ctrl+Enter.
20 | 
21 | Click ➡️ beside a suggestion to accept it, or edit the suggestion in-line (press Enter when done) before accepting.
22 | 
23 | *Have fun!*
24 | 
25 | # Launching
26 | 
27 | ## Prepare environment
28 | 
29 | `pip3 install requests graphviz streamlit`
30 | 
31 | `sudo apt-get install -y graphviz`
32 | 
33 | ## Usage with vLLM
34 | 
35 | Download an appropriate GPTQ or AWQ quant for your system, such as [study-hjt/Meta-Llama-3-70B-Instruct-GPTQ-Int4](https://huggingface.co/study-hjt/Meta-Llama-3-70B-Instruct-GPTQ-Int4).
36 | 
37 | Launch a vLLM OpenAI-compatible API server:
38 | 
39 | ```
40 | python3 -m vllm.entrypoints.openai.api_server --model ~/models/study-hjt-Meta-Llama-3-70B-Instruct-GPTQ-Int4/ --enable-prefix-caching
41 | ```
42 | 
43 | Remember to add `--tensor-parallel-size <number of GPUs>` if you have multiple GPUs.
44 | 
45 | Then launch the frontend with `VLLM_API_URL` set to the host and port of the server:
46 | 
47 | ```
48 | LLAMA_PIPELINE_REQUESTS=6 VLLM_API_URL=http://127.0.0.1:8000 streamlit run lloom.py
49 | ```
50 | 
51 | Increase `LLAMA_PIPELINE_REQUESTS` until it stops making things faster (if you have powerful GPUs).
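Before launching the frontend you can optionally sanity-check that the server is up and returning logprobs. The sketch below mirrors the single-token, top-logprobs request that `search.py` sends to the OpenAI-compatible completions endpoint; the host/port and the prompt are assumptions matching the example above:

```
import requests

base_url = "http://127.0.0.1:8000"  # assumed vLLM host/port from the example above

# discover the served model name, the same way the frontend does on its first request
model = requests.get(base_url + "/v1/models").json()["data"][0]["id"]

# request a single token along with the top-5 logprobs (mirrors get_logprobs_vllm in search.py)
resp = requests.post(base_url + "/v1/completions", json={
    "model": model,
    "prompt": "Once upon a time,",
    "max_tokens": 1,
    "temperature": 0.0,
    "logprobs": 5,
})
print(resp.json()["choices"][0]["logprobs"]["top_logprobs"][0])
```

If this prints a dictionary of candidate next tokens and their logprobs, the frontend should work against the same `VLLM_API_URL`.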
52 | 
53 | ### If you have 48GB VRAM and want to fit a Llama3-70B GPTQ model
54 | 
55 | Add `--enforce-eager --max-model-len 2048 --gpu_memory_utilization 1.0` to the vLLM server command line.
56 | 
57 | ### If you have a P100
58 | 
59 | Use [vllm-ci](https://github.com/sasha0552/vllm-ci).
60 | 
61 | Run `export VLLM_ATTENTION_BACKEND=XFORMERS` to force the xformers attention backend.
62 | 
63 | Add `--dtype half` to the vLLM server command line.
64 | 
65 | If you get the "Cannot convert f16 to f16" crash, either remove `--enable-prefix-caching` or build triton from source with the patches from the vllm-ci repo.
66 | 
67 | ## Usage with llama.cpp
68 | 
69 | Download an appropriate quant for your system from [dolphin-2.9-llama3-70b-GGUF](https://huggingface.co/crusoeai/dolphin-2.9-llama3-70b-GGUF).
70 | 
71 | Launch a llama.cpp server with a good Llama3-70B finetune:
72 | 
73 | ```
74 | ./server -m ~/models/dolphin-2.9-llama3-70b.Q4_K_M.gguf -ngl 99 -sm row --host 0.0.0.0 -c 8192 --log-format text
75 | ```
76 | 
77 | | :exclamation: Note that you cannot use `-fa`, as this results in all of the logits being `null`. It is also strongly discouraged to launch with any kind of parallelism, because this both reduces the available context size and seems to break the KV caching, so performance suffers. |
78 | |-----------------------------------------|
79 | 
80 | 
81 | Then launch the frontend with `LLAMA_API_URL` set to the host and port of the server:
82 | 
83 | | :exclamation: LLooM makes a large number of network calls and is latency-sensitive; make sure the llama.cpp server is running on the same machine or LAN as the frontend to avoid degraded performance. If you cannot avoid going over a high-latency connection, setting `LLAMA_PIPELINE_REQUESTS=2` should improve performance. |
84 | |-----------------------------------------|
85 | 
86 | ```
87 | LLAMA_API_URL=http://127.0.0.1:8080 streamlit run lloom.py
88 | ```
89 | 
90 | ## Usage with OpenAI
91 | 
92 | Launch the frontend with `OPENAI_API_KEY` set:
93 | 
94 | ```
95 | OPENAI_API_KEY=sk-... streamlit run lloom.py
96 | ```
97 | 
98 | The model is currently hard-coded to `gpt-3.5-turbo`.
99 | 
100 | # Configuration
101 | 
102 | You can open the Configuration dropdown at the top at any time to adjust parameters.
103 | 
104 | ![LLooM Screenshot](config.png "LLooM Screenshot")
105 | 
106 | The parameters are grouped into two sections: when to stop, and when to split.
107 | 
108 | ## Stop Conditions
109 | 
110 | `Auto-Stop` Early-terminate suggestion beams when a "." or "," character is encountered.
111 | 
112 | `Max Depth` The maximum number of tokens a suggestion beam can have. Note that if you disable the Auto-Stop condition, then all beams will have exactly this number of tokens.
113 | 
114 | `Beam Limit` The maximum number of completed suggestion beams to return (this can be really useful to limit run-time if the model is slow).
115 | 
116 | ## Split Conditions
117 | 
118 | `Cutoff` The minimum token probability (0.0 - 1.0) to spawn a new suggestion beam.
119 | 
120 | `Multiplier` The per-token slope applied to the Cutoff (1.0: fixed cutoff, <1.0: the cutoff decreases with depth, >1.0: the cutoff increases with depth).
121 | 
122 | `Split Limit` The maximum number of times a suggestion beam can split (analogous to beam search with top-k).
123 | 
--------------------------------------------------------------------------------
/config.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/the-crypt-keeper/LLooM/a77ec96ae4ea0eaa8d86c4d7246ae882a4d003c5/config.png
--------------------------------------------------------------------------------
/lloom.py:
--------------------------------------------------------------------------------
1 | import streamlit as st 2 | import hashlib 3 | import time 4 | import os 5 | 6 | from viz import visualize_common_prefixes 7 | from search import parallel_lloom_search 8 | 9 | STARTING_STORIES = [ 10 | "Once upon a time,", 11 | "The forest seemed darker than usual, but that did not bother Elis in the least.", 12 | "In the age before man," 13 | ] 14 | 15 | LLAMA_PIPELINE_REQUESTS = int(os.getenv('LLAMA_PIPELINE_REQUESTS', 1)) 16 | print("LLAMA_PIPELINE_REQUESTS", LLAMA_PIPELINE_REQUESTS) 17 | 18 | def computeMD5hash(my_string): 19 | m = hashlib.md5() 20 | m.update(my_string.encode('utf-8')) 21 | return m.hexdigest() 22 | 23 | def main(): 24 | st.set_page_config(layout='wide', page_title='The LLooM') 25 | st.markdown(""" 26 | 37 | """, unsafe_allow_html=True) 38 | 39 | if 'page' not in st.session_state: 40 | st.session_state.page = 0 41 | st.session_state.threads = None 42 | 43 | logo, config = st.columns((1,5)) 44 | logo.markdown("### The LLooM :green[v0.3]") 45 | 46 | with config.expander('Configuration', expanded=False): 47 | config_cols = st.columns((1,1)) 48 | config_cols[0].markdown('_Stop conditions_') 49 | story_depth = config_cols[0].checkbox("Auto-Stop (early terminate if a period or comma is encountered)", value=True) 50 | depth = config_cols[0].number_input("Maximum Depth", min_value=1, max_value=50, value=12, help="Terminate a suggestion when it gets this long") 51 | maxsuggestions = config_cols[0].number_input("Beam Limit", min_value=5, max_value=100, value=25, help="Stop spawning new beams when the number of suggestions hits this limit") 52 | 53 | config_cols[1].markdown('_Split conditions_\n\nLower the Cutoff to get more variety (at the expense of quality and speed), raise the Cutoff for a smaller number of better suggestions.') 54 | cutoff = config_cols[1].number_input("Cutoff", help="Minimum probability of a token to have it split a new suggestion beam", min_value=0.0, max_value=1.0, value=0.2, step=0.01) 55 | multiplier = config_cols[1].number_input("Multiplier", help="The cutoff is scaled by Multiplier each time a new token is generated", min_value=0.0, max_value=2.0, value=1.0, step=0.1) 56 | maxsplits = config_cols[1].number_input("Split Limit", help="The maximum number of splits from a single source token, raise to get more variety.", min_value=0, max_value=10, value=3) 57 | 58 | left, right = st.columns((2,3)) 59 | 60 | if st.session_state.page == 0: 61 | st.write('Open the Configuration panel above to adjust settings; Auto-Stop mode is particularly useful, at the expense of longer generation times.
You will be able to change settings at any point and regenerate suggestions.\n\nThe starting prompts below are just suggestions, once inside the playground you can fully edit the prompt.') 62 | start_prompt = st.selectbox("Start Prompt", STARTING_STORIES, index=1) 63 | if st.button("Start"): 64 | st.session_state.story_so_far = start_prompt 65 | st.session_state.page = 1 66 | st.rerun() 67 | else: 68 | story_so_far = st.session_state.story_so_far 69 | new_story_so_far = left.text_area("Story so far", story_so_far, label_visibility='hidden', height=300) 70 | if left.button('Suggest Again'): 71 | story_so_far = new_story_so_far 72 | st.session_state.story_so_far = story_so_far 73 | st.session_state.threads = None 74 | 75 | if st.session_state.threads == None: 76 | please_wait = st.empty() 77 | t0 = time.time() 78 | tokens = 0 79 | 80 | with please_wait.status('Searching for suggestions, please wait..') as status: 81 | threads = [] 82 | for thread in parallel_lloom_search(story_so_far, depth, maxsuggestions, ['.',','] if story_depth else [], cutoff, multiplier, maxsplits, LLAMA_PIPELINE_REQUESTS): 83 | label = thread[1][len(story_so_far):] 84 | status.update(label=label, state="running") 85 | threads.append(thread) 86 | tokens += thread[2] 87 | 88 | delta = time.time() - t0 89 | tps = tokens/delta 90 | status.update(label=f"Search completed, found {len(threads)} suggestion in {delta:.2f}s @ {tps:.2f} tokens/sec", state="complete", expanded=False) 91 | 92 | sorted_threads = sorted(threads, key=lambda x: x[0], reverse=True) 93 | 94 | # remove duplicate threads 95 | dedupe = {} 96 | good_threads = [] 97 | add_space = False 98 | for prob, thread, depth in sorted_threads: 99 | new_tokens = thread[len(story_so_far):] 100 | if new_tokens[0] == ' ': 101 | new_tokens = new_tokens[1:] 102 | thread = story_so_far + " " + thread[len(story_so_far):] 103 | add_space = True 104 | if dedupe.get(new_tokens) is None: 105 | dedupe[new_tokens] = prob 106 | good_threads.append( (prob, new_tokens) ) 107 | 108 | st.session_state.threads = good_threads 109 | st.session_state.add_space = add_space 110 | 111 | # if there is only one option - take it. 
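        # (auto-accepting appends the lone suggestion to the story and immediately re-runs the search)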
112 | if len(good_threads) == 1: 113 | st.session_state.story_so_far += (" " if add_space else "") + good_threads[0][1] 114 | st.session_state.threads = None 115 | st.rerun() 116 | 117 | threads = st.session_state.threads 118 | add_space = st.session_state.add_space 119 | 120 | labels = [ thread for prob, thread in threads ] 121 | viz = visualize_common_prefixes(labels) 122 | with right: 123 | st.graphviz_chart(viz) 124 | st.download_button('Download DOT Graph', viz.source, 'graph.dot', 'text/plain') 125 | st.download_button('Download PNG', viz.pipe(format='png'), 'graph.png', 'image/png') 126 | 127 | controls = st.container() 128 | buttons = st.container() 129 | 130 | with controls: 131 | user_add_space = st.checkbox("Prefix space", value=add_space, key=computeMD5hash('prefix-'+story_so_far)) 132 | 133 | sum_probs = sum([prob for prob, _ in threads]) 134 | with buttons: 135 | for prob, thread in threads: 136 | col1, col2 = st.columns((3,1)) 137 | col2.progress(value=prob/sum_probs) 138 | new_text = col1.text_input(thread, value=thread, key='text-'+computeMD5hash(thread), label_visibility='hidden') 139 | if col2.button(':arrow_right:', key='ok-'+computeMD5hash(thread)): 140 | st.session_state.story_so_far += (" " if user_add_space else "") + new_text 141 | st.session_state.threads = None 142 | st.rerun() 143 | 144 | if __name__ == "__main__": 145 | main() 146 | -------------------------------------------------------------------------------- /screenshot.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/the-crypt-keeper/LLooM/a77ec96ae4ea0eaa8d86c4d7246ae882a4d003c5/screenshot.png -------------------------------------------------------------------------------- /search.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os 3 | 4 | openai_client = None 5 | def get_logprobs_openai(prompt, model="gpt-3.5-turbo"): 6 | global openai_client 7 | if openai_client is None: 8 | from openai import OpenAI 9 | openai_client = OpenAI() 10 | 11 | messages = [{'role': 'user', 'content': prompt}] 12 | response = openai_client.chat.completions.create( 13 | model=model, 14 | messages=messages, 15 | temperature=0.7, 16 | max_tokens=1, 17 | logprobs=True, 18 | top_logprobs=10, 19 | n=1 20 | ) 21 | 22 | top_logprobs = response.choices[0].logprobs.content[0].top_logprobs 23 | for logprob in top_logprobs: 24 | logprob.probability = np.exp(logprob.logprob) 25 | 26 | return top_logprobs 27 | 28 | class SimpleProbability: 29 | def __init__(self, token, probability): 30 | self.token = token 31 | self.probability = probability 32 | 33 | def get_logprobs_llama(prompt, base_url): 34 | import requests 35 | 36 | url = base_url+'/completion' 37 | payload = { 'prompt': prompt, 38 | 'cache_prompt': True, 39 | 'temperature': 1.0, 40 | 'n_predict': 1, 41 | 'top_k': 10, 42 | 'top_p': 1.0, 43 | 'n_probs': 10 44 | } 45 | 46 | response = requests.post(url, json=payload) 47 | probs = response.json()['completion_probabilities'][0]['probs'] 48 | print(probs) 49 | 50 | return [ SimpleProbability(prob['tok_str'], prob['prob']) for prob in probs] 51 | 52 | vllm_model_name = None 53 | def get_logprobs_vllm(prompt, base_url): 54 | import requests 55 | 56 | global vllm_model_name 57 | if vllm_model_name is None: 58 | models = requests.get(base_url+'/v1/models').json() 59 | vllm_model_name = models['data'][0]['id'] 60 | print('VLLM model name:', vllm_model_name) 61 | 62 | url = base_url+'/v1/completions' 63 | 
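    # single-token completion request: ask the OpenAI-compatible endpoint for the top-5 logprobs of the next token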
payload = { 64 | "prompt": prompt, 65 | "n": 1, 66 | "temperature": 0.0, 67 | "max_tokens": 1, 68 | "stream": False, 69 | "logprobs": 5, 70 | "model": vllm_model_name 71 | } 72 | 73 | response = requests.post(url, json=payload) 74 | probs = response.json()['choices'][0]['logprobs']['top_logprobs'][0] 75 | return [ SimpleProbability(k,np.exp(v)) for k,v in probs.items()] 76 | 77 | from concurrent.futures import ThreadPoolExecutor, as_completed 78 | 79 | def parallel_get_logprobs(prompt, acc): 80 | # Choose which API to use based on environment variables 81 | if os.getenv('LLAMA_API_URL') is not None: 82 | logprobs = get_logprobs_llama(prompt, os.getenv('LLAMA_API_URL')) 83 | elif os.getenv('VLLM_API_URL') is not None: 84 | logprobs = get_logprobs_vllm(prompt, os.getenv('VLLM_API_URL')) 85 | elif os.getenv('OPENAI_API_KEY') is not None: 86 | logprobs = get_logprobs_openai(prompt) 87 | else: 88 | raise Exception('Please set either OPENAI_API_KEY or LLAMA_API_URL') 89 | 90 | return (prompt, acc, logprobs) 91 | 92 | def parallel_lloom_search(initial_prompt, max_depth, max_beams, stop_tokens, initial_cutoff, multiplier, maxsplits, parallelism=2): 93 | 94 | tasks = [(initial_prompt, 0.0)] 95 | cutoff = initial_cutoff 96 | depth = max_depth 97 | done_beams = 0 98 | 99 | with ThreadPoolExecutor(max_workers=parallelism) as executor: 100 | while tasks: 101 | # spawn futures 102 | futures = [] 103 | for task in tasks: 104 | print("spawning depth:", depth ,"task:", task) 105 | futures.append(executor.submit(parallel_get_logprobs, *task)) 106 | 107 | total_futures = len(tasks) 108 | tasks = [] 109 | done_futures = 0 110 | 111 | # process futures as they come in 112 | for future in as_completed(futures): 113 | res = future.result() 114 | (prompt, acc, logprobs) = res 115 | 116 | count = 0 117 | for logprob_choice in logprobs: 118 | token = logprob_choice.token 119 | probability = logprob_choice.probability 120 | 121 | if count > 0 and probability < cutoff: break 122 | if maxsplits > 0 and count == maxsplits: break 123 | 124 | count += 1 125 | 126 | new_prompt = prompt + token 127 | early_finish = False 128 | 129 | if depth == 0 or ((max_beams > 0) and (done_beams+total_futures-done_futures >= max_beams)): 130 | yield (acc + probability, new_prompt, max_depth - depth) 131 | early_finish = True 132 | else: 133 | new_tokens = new_prompt[len(initial_prompt):] 134 | stop_search_tokens = new_tokens 135 | 136 | for st in stop_tokens: 137 | # starting with a stop token is OK, keep searching until there's some meat 138 | if stop_search_tokens[0:len(st)] == st: 139 | stop_search_tokens = stop_search_tokens[len(st):] 140 | 141 | if (not early_finish) and (st in stop_search_tokens): 142 | trimmed_prompt = initial_prompt + new_tokens[:new_tokens.find(st)+1] 143 | yield (acc + probability, trimmed_prompt, max_depth - depth) 144 | early_finish = True 145 | 146 | if not early_finish: 147 | new_task =(new_prompt, acc + probability) 148 | tasks.append(new_task) 149 | else: 150 | done_beams += 1 151 | 152 | done_futures += 1 153 | 154 | # adjust for next cycle 155 | cutoff = cutoff * multiplier 156 | depth = depth - 1 157 | -------------------------------------------------------------------------------- /viz.py: -------------------------------------------------------------------------------- 1 | from graphviz import Digraph 2 | from collections import defaultdict 3 | 4 | def find_common_prefix(strings): 5 | if not strings: 6 | return "", [] 7 | if len(strings) == 1: 8 | return strings[0], [] 9 | 10 | prefix = strings[0] 11 
| for s in strings[1:]: 12 | i = 0 13 | while i < len(prefix) and i < len(s) and prefix[i] == s[i]: 14 | i += 1 15 | prefix = prefix[:i] 16 | if not prefix: 17 | break 18 | 19 | if ' ' in prefix: 20 | # stop at the last space 21 | prefix = prefix.rsplit(' ', 1)[0] 22 | 23 | remaining = [s[len(prefix):].strip() for s in strings if s[len(prefix):].strip()] 24 | return prefix, remaining 25 | 26 | def visualize_common_prefixes(strings): 27 | graph = Digraph() 28 | graph.attr(rankdir='LR') # Set the direction to left-to-right 29 | 30 | def add_nodes_and_edges(strings, parent=None, level=0): 31 | if not strings: 32 | return 33 | 34 | prefix, remainder = find_common_prefix(strings) 35 | 36 | if prefix: 37 | node_label = prefix.strip() 38 | node_id = f"{level}-{node_label}" 39 | graph.node(node_id, node_label, shape='box', style='filled', fillcolor='lightgrey', fontsize='12') 40 | if parent and parent != node_label: # Avoid self-looping edge 41 | graph.edge(parent, node_id) 42 | print(level, prefix, ':', parent, '=>', node_label) 43 | 44 | if remainder: 45 | child_groups = defaultdict(list) 46 | for s in remainder: 47 | child_groups[s.split()[0]].append(s) 48 | 49 | for group in child_groups.values(): 50 | add_nodes_and_edges(group, node_id, level+1) 51 | 52 | # Add "[start]" prefix to each string 53 | prefixed_strings = ["[start] " + s for s in strings] 54 | add_nodes_and_edges(prefixed_strings, None) 55 | return graph 56 | 57 | if __name__ == "__main__": 58 | strings = ["There was once", "There was a", "One sunny"] 59 | graph = visualize_common_prefixes(strings) 60 | graph.render('common_prefixes', format='png', cleanup=True) --------------------------------------------------------------------------------