├── .gitignore ├── prompted_responder ├── requirements.txt └── responder.py └── prompted-web-server ├── README.md ├── static ├── favicon.ico └── index.html ├── go.mod ├── go.sum ├── broadcaster.go ├── server.go └── mux.go /.gitignore: -------------------------------------------------------------------------------- 1 | *.swp 2 | -------------------------------------------------------------------------------- /prompted_responder/requirements.txt: -------------------------------------------------------------------------------- 1 | llama-cpp-python==0.2.6 2 | requests==2.31.0 3 | -------------------------------------------------------------------------------- /prompted-web-server/README.md: -------------------------------------------------------------------------------- 1 | :snake: Serve static files + GET/POST API + websocket for status updates 2 | -------------------------------------------------------------------------------- /prompted-web-server/static/favicon.ico: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/whatever/prompted/main/prompted-web-server/static/favicon.ico -------------------------------------------------------------------------------- /prompted-web-server/go.mod: -------------------------------------------------------------------------------- 1 | module github.com/whatever/echo 2 | 3 | go 1.19 4 | 5 | require github.com/gorilla/websocket v1.5.0 6 | -------------------------------------------------------------------------------- /prompted-web-server/go.sum: -------------------------------------------------------------------------------- 1 | github.com/gorilla/websocket v1.5.0 h1:PPwGk2jz7EePpoHN/+ClbZu8SPxiqlu12wZP/3sWmnc= 2 | github.com/gorilla/websocket v1.5.0/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= 3 | -------------------------------------------------------------------------------- /prompted-web-server/broadcaster.go: 
-------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "sync" 5 | 6 | "github.com/gorilla/websocket" 7 | ) 8 | 9 | type Event struct { 10 | Conn *websocket.Conn 11 | Data []byte 12 | } 13 | 14 | // BrainBleachSocialConns collects clients and broadcasts messages to clients. 15 | type Broadcaster struct { 16 | Conns map[*websocket.Conn]bool 17 | mutex sync.Mutex 18 | events chan Event 19 | } 20 | 21 | // NewBrainBleachSocialConns returns a new, empty BrainBleachSocialConns obj. 22 | func NewBroadcaster() *Broadcaster { 23 | return &Broadcaster{ 24 | Conns: make(map[*websocket.Conn]bool), 25 | mutex: sync.Mutex{}, 26 | events: make(chan Event), 27 | } 28 | } 29 | 30 | // Add includes a new connection and sets up a go-routine to wait for the 31 | // connection to close. 32 | func (conns *Broadcaster) Add(conn *websocket.Conn) { 33 | 34 | conns.mutex.Lock() 35 | conns.Conns[conn] = true 36 | conns.mutex.Unlock() 37 | 38 | go func() { 39 | defer func() { 40 | conns.Remove(conn) 41 | }() 42 | 43 | for { 44 | if _, m, err := conn.ReadMessage(); err == nil { 45 | conns.events <- Event{conn, m} 46 | } else { 47 | break 48 | } 49 | } 50 | }() 51 | } 52 | 53 | // Remove closes and removes a connection from the Hub. 54 | func (conns *Broadcaster) Remove(conn *websocket.Conn) { 55 | conns.mutex.Lock() 56 | defer conns.mutex.Unlock() 57 | conn.Close() 58 | delete(conns.Conns, conn) 59 | } 60 | 61 | // Broadcast sends messages to every client. 62 | func (conns *Broadcaster) Broadcast(message []byte) error { 63 | conns.mutex.Lock() 64 | defer conns.mutex.Unlock() 65 | for conn, _ := range conns.Conns { 66 | conn.WriteMessage(websocket.TextMessage, message) 67 | } 68 | return nil 69 | } 70 | 71 | // Broadcast sends messages to every client. 
72 | func (conns *Broadcaster) Send(conn *websocket.Conn, message []byte) error { 73 | return conn.WriteMessage(websocket.TextMessage, message) 74 | } 75 | 76 | // Events returns a read-only channel of valid messages sent from connections. 77 | func (conns *Broadcaster) Events() <-chan Event { 78 | return conns.events 79 | } 80 | -------------------------------------------------------------------------------- /prompted-web-server/server.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "embed" 5 | "flag" 6 | "fmt" 7 | "log" 8 | "net/http" 9 | "sync" 10 | "time" 11 | ) 12 | 13 | type TrackerState string 14 | 15 | const ( 16 | TrackerStateWaiting TrackerState = "waiting" 17 | TrackerStateWorking TrackerState = "working" 18 | TrackerStateReady TrackerState = "ready" 19 | ) 20 | 21 | //go:embed static 22 | var static embed.FS 23 | 24 | // PromptResponse is a prompt -> response pair with some things to coordinate multiple participants. 
25 | type PromptResponseTracker struct { 26 | Prompt string `json:"prompt"` 27 | Response string `json:"response"` 28 | Secret string `json:"-"` 29 | State string `json:"state"` 30 | LastTouched time.Time `json:"-"` 31 | LastHeartbeat time.Time `json:"-"` 32 | mutex sync.Mutex `json:"-"` 33 | } 34 | 35 | // PromptResponseMessage is a response type which is probably redundant with the above 36 | type PromptResponseMessage struct { 37 | Prompt string `json:"prompt"` 38 | Response string `json:"response"` 39 | State string `json:"state"` 40 | Error string `json:"error,omitempty"` 41 | } 42 | 43 | // NewPromptResponseTracker returns an object to track a single prompt and its response 44 | func NewPromptResponseTracker() *PromptResponseTracker { 45 | return &PromptResponseTracker{ 46 | Prompt: "", 47 | Response: "", 48 | State: "ready", 49 | Secret: "8181", // fmt.Sprintf("%d", rand.Intn(81818181)), 50 | mutex: sync.Mutex{}, 51 | } 52 | } 53 | 54 | // StatusMessage returns a message with the current state/status of the tracker 55 | func (tracker *PromptResponseTracker) StatusMessage() PromptResponseMessage { 56 | return PromptResponseMessage{ 57 | Prompt: tracker.Prompt, 58 | Response: tracker.Response, 59 | State: tracker.State, 60 | Error: "", 61 | } 62 | } 63 | 64 | var port = flag.Int("port", 8182, "set port to listen on") 65 | 66 | func main() { 67 | mux, err := NewMux() 68 | 69 | if err != nil { 70 | log.Fatal(err) 71 | } 72 | 73 | flag.Parse() 74 | 75 | log.Printf("Listening on port %d", *port) 76 | 77 | log.Fatal(http.ListenAndServe(fmt.Sprintf(":%d", *port), mux)) 78 | } 79 | -------------------------------------------------------------------------------- /prompted-web-server/static/index.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | machine gun funk 5 | 26 | 27 | 94 | 95 | 96 |
97 | state: 98 |
99 | 100 |
101 | 102 |
103 | 104 |
105 | 106 | 107 | -------------------------------------------------------------------------------- /prompted_responder/responder.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | 4 | import argparse 5 | import logging 6 | import requests 7 | import signal 8 | import sys 9 | import time 10 | 11 | 12 | try: 13 | from llama_cpp import Llama 14 | DEBUG = False 15 | except ImportError: 16 | from unittest import mock 17 | logging.warn("Could not import Llama from llama_cpp") 18 | DEBUG = True 19 | Llama = mock.MagicMock() 20 | 21 | 22 | def signal_handler(sig, frame): 23 | """Exit program successfully""" 24 | logging.info("exiting") 25 | sys.exit(0) 26 | 27 | 28 | def predict(prompt): 29 | """Return a LLaMa response given a prompt""" 30 | 31 | if DEBUG: 32 | time.sleep(2) 33 | return f"[DEBUG] Q: {prompt} A: This is a response" 34 | 35 | if not prompt.startswith("Q:"): 36 | return "prompt must start with Q:" 37 | 38 | elif not prompt.endswith("A:"): 39 | return "prompt must end with A:" 40 | 41 | try: 42 | output = llm( 43 | resp["prompt"], 44 | max_tokens=32, 45 | stop=["Q:", "\n"], 46 | echo=True, 47 | ) 48 | response = output["choices"][0]["text"] 49 | except Exception as e: 50 | logging.error("Received error!", e) 51 | response = "some other error happened" 52 | 53 | return response 54 | 55 | 56 | def heartbeat(url, secret, state): 57 | """Send a heartbeat to the server""" 58 | 59 | d = { 60 | "secret": secret, 61 | "state": state, 62 | } 63 | 64 | return requests.post(url, data=d) 65 | 66 | 67 | def respond(url, secret, prompt, response): 68 | """Send a response to the server""" 69 | 70 | d = { 71 | "secret": secret, 72 | "prompt": prompt, 73 | "response": response, 74 | "state": "ready", 75 | } 76 | 77 | return requests.post(url, data=d) 78 | 79 | 80 | if __name__ == "__main__": 81 | 82 | parser = argparse.ArgumentParser() 83 | parser.add_argument("--host", default="whatever.rip") 84 | 
parser.add_argument("--secret", default="8181") 85 | args = parser.parse_args() 86 | 87 | llm = Llama( 88 | model_path="/home/matt/Models/Llama-2-7b-chat-hf/ggml-model-q4_0.gguf", 89 | n_gpu_layers=32, 90 | ) 91 | 92 | URLS = {} 93 | URLS["status"] = f"http://{args.host}/status" 94 | URLS["respond"] = f"http://{args.host}/respond" 95 | URLS["heartbeat"] = f"http://{args.host}/heartbeat" 96 | 97 | signal.signal(signal.SIGINT, signal_handler) 98 | 99 | while True: 100 | 101 | time.sleep(1) 102 | 103 | # Fetch status and determine whether we should ignore the response 104 | 105 | resp = requests.get(URLS["status"]).json() 106 | 107 | if resp.get("state") != "waiting": 108 | logging.info("server is not waiting, so skipping") 109 | continue 110 | 111 | if "error" in resp: 112 | logging.warn("There was an error:", resp["error"]) 113 | continue 114 | 115 | if resp.get("response"): 116 | logging.warn("Response was non-empty, so no need to compute a response") 117 | continue 118 | 119 | if not resp.get("prompt"): 120 | logging.warn("prompt was empty skipping") 121 | continue 122 | 123 | heartbeat( 124 | URLS["heartbeat"], 125 | args.secret, 126 | "working", 127 | ) 128 | 129 | response = predict(resp["prompt"].strip()) 130 | 131 | respond( 132 | URLS["respond"], 133 | args.secret, 134 | resp["prompt"], 135 | response, 136 | ).json() 137 | -------------------------------------------------------------------------------- /prompted-web-server/mux.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "encoding/json" 5 | "io/fs" 6 | "log" 7 | "net/http" 8 | "time" 9 | 10 | "github.com/gorilla/websocket" 11 | ) 12 | 13 | // SendStatusUpdate sends a status update to all connected websockets 14 | func SendStatusUpdate(caster *Broadcaster, tracker *PromptResponseTracker) { 15 | encoded, err := json.Marshal(tracker.StatusMessage()) 16 | if err != nil { 17 | log.Println(err) 18 | return 19 | } 20 | 
caster.Broadcast(encoded) 21 | } 22 | 23 | // NewMux returns a new set of routes 24 | func NewMux() (*http.ServeMux, error) { 25 | 26 | fs, err := fs.Sub(static, "static") 27 | 28 | if err != nil { 29 | return nil, err 30 | } 31 | 32 | tracker := NewPromptResponseTracker() 33 | 34 | log.Printf("Initialized with secret %v", tracker.Secret) 35 | 36 | caster := NewBroadcaster() 37 | 38 | mux := http.NewServeMux() 39 | 40 | mux.HandleFunc("/prompt", func(w http.ResponseWriter, req *http.Request) { 41 | 42 | tracker.mutex.Lock() 43 | defer tracker.mutex.Unlock() 44 | 45 | now := time.Now() 46 | 47 | var resp PromptResponseMessage 48 | 49 | req.ParseForm() 50 | 51 | switch { 52 | case now.Add(-3 * time.Second).Before(tracker.LastTouched): 53 | resp = PromptResponseMessage{ 54 | Prompt: "", 55 | Response: "", 56 | State: tracker.State, 57 | Error: "Request happened too soon", 58 | } 59 | 60 | case tracker.State == "working": 61 | resp = PromptResponseMessage{ 62 | Prompt: "", 63 | Response: "", 64 | State: tracker.State, 65 | Error: "Request occurred while another is being computed", 66 | } 67 | 68 | case !req.Form.Has("prompt"): 69 | resp = PromptResponseMessage{ 70 | Prompt: "", 71 | Response: "", 72 | State: tracker.State, 73 | Error: "Request is missing prompt in post form data field", 74 | } 75 | 76 | default: 77 | resp = PromptResponseMessage{ 78 | Prompt: req.Form.Get("prompt"), 79 | Response: "", 80 | State: tracker.State, 81 | Error: "", 82 | } 83 | tracker.LastTouched = now 84 | tracker.State = "waiting" 85 | tracker.Prompt = req.Form.Get("prompt") 86 | tracker.Response = "" 87 | } 88 | 89 | encoded, err := json.Marshal(resp) 90 | 91 | if err != nil { 92 | return 93 | } 94 | 95 | w.Write(encoded) 96 | SendStatusUpdate(caster, tracker) 97 | }) 98 | 99 | mux.HandleFunc("/heartbeat", func(w http.ResponseWriter, req *http.Request) { 100 | tracker.mutex.Lock() 101 | defer tracker.mutex.Unlock() 102 | 103 | req.ParseForm() 104 | 105 | errmsg := "" 106 | now := 
time.Now() 107 | 108 | switch { 109 | case req.Form.Get("secret") != tracker.Secret: 110 | errmsg = "provided secret is incorrect" 111 | case !req.Form.Has("state"): 112 | errmsg = "request is missing state parameter" 113 | default: 114 | tracker.State = req.Form.Get("state") 115 | tracker.LastHeartbeat = now 116 | } 117 | 118 | if errmsg != "" { 119 | log.Printf("error: %v", errmsg) 120 | } 121 | 122 | json.NewEncoder(w).Encode(PromptResponseMessage{ 123 | Prompt: tracker.Prompt, 124 | Response: tracker.Response, 125 | State: tracker.State, 126 | Error: errmsg, 127 | }) 128 | 129 | }) 130 | 131 | mux.HandleFunc("/respond", func(w http.ResponseWriter, req *http.Request) { 132 | 133 | tracker.mutex.Lock() 134 | defer tracker.mutex.Unlock() 135 | 136 | now := time.Now() 137 | 138 | req.ParseForm() 139 | 140 | errmsg := "" 141 | 142 | state := tracker.State 143 | 144 | if req.Form.Has("state") { 145 | state = req.Form.Get("state") 146 | } 147 | 148 | switch { 149 | case !req.Form.Has("prompt"): 150 | errmsg = "request is missing prompt parameter" 151 | 152 | case !req.Form.Has("response"): 153 | errmsg = "request is missing response parameter" 154 | 155 | case !req.Form.Has("secret"): 156 | errmsg = "request is missing secret parameter" 157 | 158 | case req.Form.Get("prompt") != tracker.Prompt: 159 | log.Printf("%v != %v", req.Form.Get("prompt"), tracker.Prompt) 160 | errmsg = "request is responding to the wrong prompt" 161 | 162 | case req.Form.Get("secret") != tracker.Secret: 163 | log.Printf("secret is incorrect") 164 | errmsg = "request is not sending the correct secret" 165 | 166 | default: 167 | tracker.State = state 168 | tracker.Response = req.Form.Get("response") 169 | tracker.LastTouched = now 170 | } 171 | 172 | resp := PromptResponseMessage{ 173 | Prompt: tracker.Prompt, 174 | Response: tracker.Response, 175 | State: tracker.State, 176 | Error: errmsg, 177 | } 178 | 179 | json.NewEncoder(w).Encode(resp) 180 | }) 181 | 182 | mux.HandleFunc("/status", func(w 
http.ResponseWriter, req *http.Request) { 183 | tracker.mutex.Lock() 184 | defer tracker.mutex.Unlock() 185 | 186 | json.NewEncoder(w).Encode(PromptResponseMessage{ 187 | Prompt: tracker.Prompt, 188 | Response: tracker.Response, 189 | State: tracker.State, 190 | Error: "", 191 | }) 192 | 193 | SendStatusUpdate(caster, tracker) 194 | }) 195 | 196 | upgrader := websocket.Upgrader{ 197 | ReadBufferSize: 1024, 198 | WriteBufferSize: 1024, 199 | } 200 | 201 | mux.HandleFunc("/ws", func(w http.ResponseWriter, req *http.Request) { 202 | conn, err := upgrader.Upgrade(w, req, nil) 203 | 204 | if err != nil { 205 | log.Println(err) 206 | return 207 | } 208 | 209 | message, err := json.Marshal(tracker.StatusMessage()) 210 | 211 | if err != nil { 212 | log.Println(err) 213 | return 214 | } 215 | 216 | conn.WriteMessage(websocket.TextMessage, message) 217 | 218 | caster.Add(conn) 219 | }) 220 | 221 | mux.Handle("/", http.FileServer(http.FS(fs))) 222 | 223 | return mux, nil 224 | } 225 | --------------------------------------------------------------------------------