├── .gitignore
├── prompted_responder
├── requirements.txt
└── responder.py
└── prompted-web-server
├── README.md
├── static
├── favicon.ico
└── index.html
├── go.mod
├── go.sum
├── broadcaster.go
├── server.go
└── mux.go
/.gitignore:
--------------------------------------------------------------------------------
1 | *.swp
2 |
--------------------------------------------------------------------------------
/prompted_responder/requirements.txt:
--------------------------------------------------------------------------------
1 | llama-cpp-python==0.2.6
2 | requests==2.31.0
3 |
--------------------------------------------------------------------------------
/prompted-web-server/README.md:
--------------------------------------------------------------------------------
1 | :snake: Serve static files + GET/POST API + websocket for status updates
2 |
--------------------------------------------------------------------------------
/prompted-web-server/static/favicon.ico:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/whatever/prompted/main/prompted-web-server/static/favicon.ico
--------------------------------------------------------------------------------
/prompted-web-server/go.mod:
--------------------------------------------------------------------------------
1 | module github.com/whatever/echo
2 |
3 | go 1.19
4 |
5 | require github.com/gorilla/websocket v1.5.0 // indirect
6 |
--------------------------------------------------------------------------------
/prompted-web-server/go.sum:
--------------------------------------------------------------------------------
1 | github.com/gorilla/websocket v1.5.0 h1:PPwGk2jz7EePpoHN/+ClbZu8SPxiqlu12wZP/3sWmnc=
2 | github.com/gorilla/websocket v1.5.0/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
3 |
--------------------------------------------------------------------------------
/prompted-web-server/broadcaster.go:
--------------------------------------------------------------------------------
1 | package main
2 |
3 | import (
4 | "sync"
5 |
6 | "github.com/gorilla/websocket"
7 | )
8 |
// Event is a single message received from a websocket client.
type Event struct {
	// Conn is the connection the message was read from.
	Conn *websocket.Conn
	// Data is the message payload as returned by Conn.ReadMessage.
	Data []byte
}
13 |
// Broadcaster collects client connections and broadcasts messages to them.
type Broadcaster struct {
	// Conns is the set of registered connections; Add stores true for each
	// connection and Remove deletes the key.
	Conns map[*websocket.Conn]bool
	// mutex guards Conns and is held by Broadcast while it writes.
	mutex sync.Mutex
	// events carries the messages read from registered connections.
	events chan Event
}
20 |
21 | // NewBrainBleachSocialConns returns a new, empty BrainBleachSocialConns obj.
22 | func NewBroadcaster() *Broadcaster {
23 | return &Broadcaster{
24 | Conns: make(map[*websocket.Conn]bool),
25 | mutex: sync.Mutex{},
26 | events: make(chan Event),
27 | }
28 | }
29 |
30 | // Add includes a new connection and sets up a go-routine to wait for the
31 | // connection to close.
32 | func (conns *Broadcaster) Add(conn *websocket.Conn) {
33 |
34 | conns.mutex.Lock()
35 | conns.Conns[conn] = true
36 | conns.mutex.Unlock()
37 |
38 | go func() {
39 | defer func() {
40 | conns.Remove(conn)
41 | }()
42 |
43 | for {
44 | if _, m, err := conn.ReadMessage(); err == nil {
45 | conns.events <- Event{conn, m}
46 | } else {
47 | break
48 | }
49 | }
50 | }()
51 | }
52 |
53 | // Remove closes and removes a connection from the Hub.
54 | func (conns *Broadcaster) Remove(conn *websocket.Conn) {
55 | conns.mutex.Lock()
56 | defer conns.mutex.Unlock()
57 | conn.Close()
58 | delete(conns.Conns, conn)
59 | }
60 |
61 | // Broadcast sends messages to every client.
62 | func (conns *Broadcaster) Broadcast(message []byte) error {
63 | conns.mutex.Lock()
64 | defer conns.mutex.Unlock()
65 | for conn, _ := range conns.Conns {
66 | conn.WriteMessage(websocket.TextMessage, message)
67 | }
68 | return nil
69 | }
70 |
71 | // Broadcast sends messages to every client.
72 | func (conns *Broadcaster) Send(conn *websocket.Conn, message []byte) error {
73 | return conn.WriteMessage(websocket.TextMessage, message)
74 | }
75 |
76 | // Events returns a read-only channel of valid messages sent from connections.
77 | func (conns *Broadcaster) Events() <-chan Event {
78 | return conns.events
79 | }
80 |
--------------------------------------------------------------------------------
/prompted-web-server/server.go:
--------------------------------------------------------------------------------
1 | package main
2 |
3 | import (
4 | "embed"
5 | "flag"
6 | "fmt"
7 | "log"
8 | "net/http"
9 | "sync"
10 | "time"
11 | )
12 |
// TrackerState names a state a prompt/response tracker can be in.
//
// NOTE(review): PromptResponseTracker.State is declared as a plain string and
// the handlers visible in this package compare against string literals rather
// than these constants — confirm whether the constants are meant to be used.
type TrackerState string

const (
	// TrackerStateWaiting matches the state set when a prompt is accepted.
	TrackerStateWaiting TrackerState = "waiting"
	// TrackerStateWorking matches the state the responder reports by heartbeat.
	TrackerStateWorking TrackerState = "working"
	// TrackerStateReady matches the state a new tracker starts in.
	TrackerStateReady TrackerState = "ready"
)
20 |
// static holds the contents of the "static" directory, embedded at build time.
//
//go:embed static
var static embed.FS
23 |
// PromptResponseTracker is a prompt -> response pair together with the
// bookkeeping needed to coordinate multiple participants.
//
// Only Prompt, Response and State are included when the tracker is encoded as
// JSON. A tracker contains a mutex and must not be copied; pass it by pointer.
type PromptResponseTracker struct {
	// Prompt is the most recently accepted prompt.
	Prompt string `json:"prompt"`
	// Response is the response recorded for Prompt, or "" if there is none yet.
	Response string `json:"response"`
	// Secret is the value requests must present to update the tracker.
	Secret string `json:"-"`
	// State is the current state string, e.g. "ready", "waiting" or "working".
	State string `json:"state"`
	// LastTouched is when a prompt or response was last accepted.
	LastTouched time.Time `json:"-"`
	// LastHeartbeat is when a heartbeat was last accepted.
	LastHeartbeat time.Time `json:"-"`

	// mutex guards the fields above. It is unexported, so encoding/json never
	// sees it and it needs no struct tag.
	mutex sync.Mutex
}
34 |
// PromptResponseMessage is the JSON message returned by the HTTP endpoints and
// broadcast to websocket clients. It is probably redundant with
// PromptResponseTracker.
type PromptResponseMessage struct {
	// Prompt is the prompt the message refers to.
	Prompt string `json:"prompt"`
	// Response is the response to Prompt, if any.
	Response string `json:"response"`
	// State is the tracker state reported with the message.
	State string `json:"state"`
	// Error describes why a request was rejected; omitted from JSON when empty.
	Error string `json:"error,omitempty"`
}
42 |
43 | // NewPromptResponseTracker returns an object to track a single prompt and its response
44 | func NewPromptResponseTracker() *PromptResponseTracker {
45 | return &PromptResponseTracker{
46 | Prompt: "",
47 | Response: "",
48 | State: "ready",
49 | Secret: "8181", // fmt.Sprintf("%d", rand.Intn(81818181)),
50 | mutex: sync.Mutex{},
51 | }
52 | }
53 |
54 | // StatusMessage returns a message with the current state/status of the tracker
55 | func (tracker *PromptResponseTracker) StatusMessage() PromptResponseMessage {
56 | return PromptResponseMessage{
57 | Prompt: tracker.Prompt,
58 | Response: tracker.Response,
59 | State: tracker.State,
60 | Error: "",
61 | }
62 | }
63 |
// port is the TCP port the server listens on; set with -port, default 8182.
var port = flag.Int("port", 8182, "set port to listen on")
65 |
66 | func main() {
67 | mux, err := NewMux()
68 |
69 | if err != nil {
70 | log.Fatal(err)
71 | }
72 |
73 | flag.Parse()
74 |
75 | log.Printf("Listening on port %d", *port)
76 |
77 | log.Fatal(http.ListenAndServe(fmt.Sprintf(":%d", *port), mux))
78 | }
79 |
--------------------------------------------------------------------------------
/prompted-web-server/static/index.html:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 | machine gun funk
5 |
26 |
27 |
94 |
95 |
96 |
97 | state:
98 |
99 |
100 |
101 |
102 |
103 |
104 |
105 |
106 |
107 |
--------------------------------------------------------------------------------
/prompted_responder/responder.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 |
3 |
4 | import argparse
5 | import logging
6 | import requests
7 | import signal
8 | import sys
9 | import time
10 |
11 |
# llama_cpp is optional: when it cannot be imported the script runs in debug
# mode, with Llama replaced by a mock, so the rest of the script still runs.
try:
    from llama_cpp import Llama
    DEBUG = False
except ImportError:
    from unittest import mock
    # logging.warn is a deprecated alias; logging.warning is the supported name.
    logging.warning("Could not import Llama from llama_cpp")
    DEBUG = True
    Llama = mock.MagicMock()
20 |
21 |
def signal_handler(sig, frame):
    """Log that the program is exiting and terminate with exit status 0.

    `sig` and `frame` are the arguments supplied to a signal handler; neither
    is used.
    """
    logging.info("exiting")
    raise SystemExit(0)
26 |
27 |
def predict(prompt):
    """Return a LLaMa response for `prompt`.

    In debug mode (llama_cpp not importable) a canned response is returned
    after a short delay. Otherwise the prompt must start with "Q:" and end
    with "A:"; if it does not, a message describing the problem is returned
    instead of a model response.

    Relies on the module-level `DEBUG` flag and `llm` model object.

    Returns:
        The text of the first completion choice, or a fixed error string if
        the model call fails.
    """

    if DEBUG:
        time.sleep(2)
        return f"[DEBUG] Q: {prompt} A: This is a response"

    if not prompt.startswith("Q:"):
        return "prompt must start with Q:"

    if not prompt.endswith("A:"):
        return "prompt must end with A:"

    try:
        # Pass the prompt that was just validated, not whatever a global
        # happens to hold at call time.
        output = llm(
            prompt,
            max_tokens=32,
            stop=["Q:", "\n"],
            echo=True,
        )
        return output["choices"][0]["text"]
    except Exception as e:
        logging.error("Received error: %s", e)
        return "some other error happened"
54 |
55 |
def heartbeat(url, secret, state, timeout=30):
    """Send a heartbeat to the server.

    Posts `secret` and `state` as form data to `url`.

    Args:
        url: heartbeat endpoint.
        secret: shared secret sent with the request.
        state: state string to report, e.g. "working".
        timeout: seconds to wait for the server before giving up; without a
            timeout the request could block forever.

    Returns:
        The response object returned by `requests.post`.
    """

    d = {
        "secret": secret,
        "state": state,
    }

    return requests.post(url, data=d, timeout=timeout)
65 |
66 |
def respond(url, secret, prompt, response, timeout=30):
    """Send a response to the server.

    Posts the prompt, its response, the secret and the state "ready" as form
    data to `url`.

    Args:
        url: respond endpoint.
        secret: shared secret sent with the request.
        prompt: the prompt being answered.
        response: the text produced for `prompt`.
        timeout: seconds to wait for the server before giving up; without a
            timeout the request could block forever.

    Returns:
        The response object returned by `requests.post`.
    """

    d = {
        "secret": secret,
        "prompt": prompt,
        "response": response,
        "state": "ready",
    }

    return requests.post(url, data=d, timeout=timeout)
78 |
79 |
# Poll the server once a second; when it is waiting on a prompt, announce that
# work has started, compute a response and post it back.
if __name__ == "__main__":

    parser = argparse.ArgumentParser()
    parser.add_argument("--host", default="whatever.rip")
    parser.add_argument("--secret", default="8181")
    parser.add_argument(
        "--model",
        default="/home/matt/Models/Llama-2-7b-chat-hf/ggml-model-q4_0.gguf",
        help="path to the model file handed to Llama",
    )
    args = parser.parse_args()

    # predict() reads this module-level name.
    llm = Llama(
        model_path=args.model,
        n_gpu_layers=32,
    )

    URLS = {
        "status": f"http://{args.host}/status",
        "respond": f"http://{args.host}/respond",
        "heartbeat": f"http://{args.host}/heartbeat",
    }

    signal.signal(signal.SIGINT, signal_handler)

    while True:

        time.sleep(1)

        # Fetch status and determine whether we should ignore the response.
        # A network failure or a non-JSON body is logged and retried on the
        # next tick instead of terminating the responder.
        try:
            resp = requests.get(URLS["status"], timeout=30).json()
        except (requests.RequestException, ValueError) as e:
            logging.warning("Could not fetch status: %s", e)
            continue

        if resp.get("state") != "waiting":
            logging.info("server is not waiting, so skipping")
            continue

        if "error" in resp:
            logging.warning("There was an error: %s", resp["error"])
            continue

        if resp.get("response"):
            logging.warning("Response was non-empty, so no need to compute a response")
            continue

        if not resp.get("prompt"):
            logging.warning("prompt was empty skipping")
            continue

        try:
            heartbeat(
                URLS["heartbeat"],
                args.secret,
                "working",
            )

            # The model sees the stripped prompt, but the prompt is echoed
            # back to the server exactly as it was received.
            response = predict(resp["prompt"].strip())

            reply = respond(
                URLS["respond"],
                args.secret,
                resp["prompt"],
                response,
            ).json()
        except (requests.RequestException, ValueError) as e:
            logging.warning("Could not deliver response: %s", e)
            continue

        if reply.get("error"):
            logging.warning("Server rejected response: %s", reply["error"])
137 |
--------------------------------------------------------------------------------
/prompted-web-server/mux.go:
--------------------------------------------------------------------------------
1 | package main
2 |
3 | import (
4 | "encoding/json"
5 | "io/fs"
6 | "log"
7 | "net/http"
8 | "time"
9 |
10 | "github.com/gorilla/websocket"
11 | )
12 |
13 | // SendStatusUpdate sends a status update to all connected websockets
14 | func SendStatusUpdate(caster *Broadcaster, tracker *PromptResponseTracker) {
15 | encoded, err := json.Marshal(tracker.StatusMessage())
16 | if err != nil {
17 | log.Println(err)
18 | return
19 | }
20 | caster.Broadcast(encoded)
21 | }
22 |
23 | // NewMux returns a new set of routes
24 | func NewMux() (*http.ServeMux, error) {
25 |
26 | fs, err := fs.Sub(static, "static")
27 |
28 | if err != nil {
29 | return nil, err
30 | }
31 |
32 | tracker := NewPromptResponseTracker()
33 |
34 | log.Printf("Initialized with secret %v", tracker.Secret)
35 |
36 | caster := NewBroadcaster()
37 |
38 | mux := http.NewServeMux()
39 |
40 | mux.HandleFunc("/prompt", func(w http.ResponseWriter, req *http.Request) {
41 |
42 | tracker.mutex.Lock()
43 | defer tracker.mutex.Unlock()
44 |
45 | now := time.Now()
46 |
47 | var resp PromptResponseMessage
48 |
49 | req.ParseForm()
50 |
51 | switch {
52 | case now.Add(-3 * time.Second).Before(tracker.LastTouched):
53 | resp = PromptResponseMessage{
54 | Prompt: "",
55 | Response: "",
56 | State: tracker.State,
57 | Error: "Request happened too soon",
58 | }
59 |
60 | case tracker.State == "working":
61 | resp = PromptResponseMessage{
62 | Prompt: "",
63 | Response: "",
64 | State: tracker.State,
65 | Error: "Request occurred while another is being computed",
66 | }
67 |
68 | case !req.Form.Has("prompt"):
69 | resp = PromptResponseMessage{
70 | Prompt: "",
71 | Response: "",
72 | State: tracker.State,
73 | Error: "Request is missing prompt in post form data field",
74 | }
75 |
76 | default:
77 | resp = PromptResponseMessage{
78 | Prompt: req.Form.Get("prompt"),
79 | Response: "",
80 | State: tracker.State,
81 | Error: "",
82 | }
83 | tracker.LastTouched = now
84 | tracker.State = "waiting"
85 | tracker.Prompt = req.Form.Get("prompt")
86 | tracker.Response = ""
87 | }
88 |
89 | encoded, err := json.Marshal(resp)
90 |
91 | if err != nil {
92 | return
93 | }
94 |
95 | w.Write(encoded)
96 | SendStatusUpdate(caster, tracker)
97 | })
98 |
99 | mux.HandleFunc("/heartbeat", func(w http.ResponseWriter, req *http.Request) {
100 | tracker.mutex.Lock()
101 | defer tracker.mutex.Unlock()
102 |
103 | req.ParseForm()
104 |
105 | errmsg := ""
106 | now := time.Now()
107 |
108 | switch {
109 | case req.Form.Get("secret") != tracker.Secret:
110 | errmsg = "provided secret is incorrect"
111 | case !req.Form.Has("state"):
112 | errmsg = "request is missing state parameter"
113 | default:
114 | tracker.State = req.Form.Get("state")
115 | tracker.LastHeartbeat = now
116 | }
117 |
118 | if errmsg != "" {
119 | log.Printf("error: %v", errmsg)
120 | }
121 |
122 | json.NewEncoder(w).Encode(PromptResponseMessage{
123 | Prompt: tracker.Prompt,
124 | Response: tracker.Response,
125 | State: tracker.State,
126 | Error: errmsg,
127 | })
128 |
129 | })
130 |
131 | mux.HandleFunc("/respond", func(w http.ResponseWriter, req *http.Request) {
132 |
133 | tracker.mutex.Lock()
134 | defer tracker.mutex.Unlock()
135 |
136 | now := time.Now()
137 |
138 | req.ParseForm()
139 |
140 | errmsg := ""
141 |
142 | state := tracker.State
143 |
144 | if req.Form.Has("state") {
145 | state = req.Form.Get("state")
146 | }
147 |
148 | switch {
149 | case !req.Form.Has("prompt"):
150 | errmsg = "request is missing prompt parameter"
151 |
152 | case !req.Form.Has("response"):
153 | errmsg = "request is missing response parameter"
154 |
155 | case !req.Form.Has("secret"):
156 | errmsg = "request is missing secret parameter"
157 |
158 | case req.Form.Get("prompt") != tracker.Prompt:
159 | log.Printf("%v != %v", req.Form.Get("prompt"), tracker.Prompt)
160 | errmsg = "request is responding to the wrong prompt"
161 |
162 | case req.Form.Get("secret") != tracker.Secret:
163 | log.Printf("secret is incorrect")
164 | errmsg = "request is not sending the correct secret"
165 |
166 | default:
167 | tracker.State = state
168 | tracker.Response = req.Form.Get("response")
169 | tracker.LastTouched = now
170 | }
171 |
172 | resp := PromptResponseMessage{
173 | Prompt: tracker.Prompt,
174 | Response: tracker.Response,
175 | State: tracker.State,
176 | Error: errmsg,
177 | }
178 |
179 | json.NewEncoder(w).Encode(resp)
180 | })
181 |
182 | mux.HandleFunc("/status", func(w http.ResponseWriter, req *http.Request) {
183 | tracker.mutex.Lock()
184 | defer tracker.mutex.Unlock()
185 |
186 | json.NewEncoder(w).Encode(PromptResponseMessage{
187 | Prompt: tracker.Prompt,
188 | Response: tracker.Response,
189 | State: tracker.State,
190 | Error: "",
191 | })
192 |
193 | SendStatusUpdate(caster, tracker)
194 | })
195 |
196 | upgrader := websocket.Upgrader{
197 | ReadBufferSize: 1024,
198 | WriteBufferSize: 1024,
199 | }
200 |
201 | mux.HandleFunc("/ws", func(w http.ResponseWriter, req *http.Request) {
202 | conn, err := upgrader.Upgrade(w, req, nil)
203 |
204 | if err != nil {
205 | log.Println(err)
206 | return
207 | }
208 |
209 | message, err := json.Marshal(tracker.StatusMessage())
210 |
211 | if err != nil {
212 | log.Println(err)
213 | return
214 | }
215 |
216 | conn.WriteMessage(websocket.TextMessage, message)
217 |
218 | caster.Add(conn)
219 | })
220 |
221 | mux.Handle("/", http.FileServer(http.FS(fs)))
222 |
223 | return mux, nil
224 | }
225 |
--------------------------------------------------------------------------------