├── LICENSE ├── README.md ├── go.mod ├── go.sum ├── images └── postgres.png ├── main.go ├── perf ├── go.mod ├── go.sum ├── measure.go ├── plot.py └── postgresql.conf ├── postgres.c ├── postgres.h ├── test ├── client.go ├── go.mod ├── go.sum └── postgresql.conf └── util.go /LICENSE: -------------------------------------------------------------------------------- 1 | BSD 2-Clause License 2 | 3 | Copyright (c) 2024, Teodor Janez Podobnik 4 | 5 | Redistribution and use in source and binary forms, with or without 6 | modification, are permitted provided that the following conditions are met: 7 | 8 | 1. Redistributions of source code must retain the above copyright notice, this 9 | list of conditions and the following disclaimer. 10 | 11 | 2. Redistributions in binary form must reproduce the above copyright notice, 12 | this list of conditions and the following disclaimer in the documentation 13 | and/or other materials provided with the distribution. 14 | 15 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 16 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 17 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 18 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 19 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 20 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 21 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 22 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 23 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 24 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
25 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # PostgreSQL eBPF 2 | 3 | This is demo code showcasing observability of the PostgreSQL wire protocol using eBPF. It is inspired by Alaz, the Kubernetes eBPF agent developed by Anteon. 4 | 5 | postgres 6 | 7 | To try it out locally: 8 | 9 | - Run the eBPF program using 10 | ``` 11 | go generate 12 | go build 13 | sudo ./postgres 14 | ``` 15 | - Run the PostgreSQL container (the client and the performance test connect to the `mydb` database, so it is created at startup) using 16 | ``` 17 | docker run --name postgres-container -e POSTGRES_PASSWORD=mysecretpassword -e POSTGRES_DB=mydb -d -p 5432:5432 postgres 18 | ``` 19 | - Run the client inside `/test` using 20 | ``` 21 | go run client.go 22 | ``` 23 | - In another shell, inspect the eBPF program logs using 24 | ``` 25 | sudo cat /sys/kernel/debug/tracing/trace_pipe 26 | ``` 27 | - To run the performance evaluation, inside the `/perf` directory run: 28 | ``` 29 | go run measure.go 30 | ``` -------------------------------------------------------------------------------- /go.mod: -------------------------------------------------------------------------------- 1 | module postgres 2 | 3 | go 1.22.4 4 | 5 | require github.com/cilium/ebpf v0.15.0 6 | 7 | require ( 8 | golang.org/x/exp v0.0.0-20230224173230-c95f2b4c22f2 // indirect 9 | golang.org/x/sys v0.15.0 // indirect 10 | ) 11 | -------------------------------------------------------------------------------- /go.sum: -------------------------------------------------------------------------------- 1 | github.com/cilium/ebpf v0.15.0 h1:7NxJhNiBT3NG8pZJ3c+yfrVdHY8ScgKD27sScgjLMMk= 2 | github.com/cilium/ebpf v0.15.0/go.mod h1:DHp1WyrLeiBh19Cf/tfiSMhqheEiK8fXFZ4No0P1Hso= 3 | github.com/go-quicktest/qt v1.101.0 h1:O1K29Txy5P2OK0dGo59b7b0LR6wKfIhttaAhHUyn7eI= 4 | github.com/go-quicktest/qt v1.101.0/go.mod h1:14Bz/f7NwaXPtdYEgzsx46kqSxVwTbzVZsDC26tQJow= 5 | github.com/google/go-cmp v0.5.9
h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38= 6 | github.com/google/go-cmp v0.5.9/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= 7 | github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= 8 | github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk= 9 | github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= 10 | github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= 11 | github.com/rogpeppe/go-internal v1.11.0 h1:cWPaGQEPrBb5/AsnsZesgZZ9yb1OQ+GOISoDNXVBh4M= 12 | github.com/rogpeppe/go-internal v1.11.0/go.mod h1:ddIwULY96R17DhadqLgMfk9H9tvdUzkipdSkR5nkCZA= 13 | golang.org/x/exp v0.0.0-20230224173230-c95f2b4c22f2 h1:Jvc7gsqn21cJHCmAWx0LiimpP18LZmUxkT5Mp7EZ1mI= 14 | golang.org/x/exp v0.0.0-20230224173230-c95f2b4c22f2/go.mod h1:CxIveKay+FTh1D0yPZemJVgC/95VzuuOLq5Qi4xnoYc= 15 | golang.org/x/sys v0.15.0 h1:h48lPFYpsTvQJZF4EKyI4aLHaev3CxivZmv7yZig9pc= 16 | golang.org/x/sys v0.15.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= 17 | -------------------------------------------------------------------------------- /images/postgres.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dorkamotorka/postgres-ebpf/43805aaed7d4d274667ecd34816e81e842cdceaa/images/postgres.png -------------------------------------------------------------------------------- /main.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "os" 5 | "log" 6 | "unsafe" 7 | "regexp" 8 | "strings" 9 | "github.com/cilium/ebpf/rlimit" 10 | "github.com/cilium/ebpf/link" 11 | "github.com/cilium/ebpf/perf" 12 | ) 13 | 14 | //go:generate go run github.com/cilium/ebpf/cmd/bpf2go postgres postgres.c 15 | 16 | var re *regexp.Regexp 17 | var keywords = []string{"SELECT", "INSERT INTO", "UPDATE", "DELETE FROM", "CREATE TABLE", "ALTER TABLE", "DROP TABLE", "TRUNCATE TABLE", "BEGIN", 
"COMMIT", "ROLLBACK", "SAVEPOINT", "CREATE INDEX", "DROP INDEX", "CREATE VIEW", "DROP VIEW", "GRANT", "REVOKE", "EXECUTE"} 18 | var pgObjs postgresObjects 19 | 20 | func main() { 21 | // Allow the current process to lock memory for eBPF resources. 22 | if err := rlimit.RemoveMemlock(); err != nil { 23 | log.Fatal(err) 24 | } 25 | 26 | // Load pre-compiled programs and maps into the kernel. 27 | pgObjs = postgresObjects{} 28 | if err := loadPostgresObjects(&pgObjs, nil); err != nil { 29 | log.Fatal(err) 30 | } 31 | 32 | w, err := link.Tracepoint("syscalls", "sys_enter_write", pgObjs.HandleWrite, nil) 33 | if err != nil { 34 | log.Fatal("link sys_enter_write tracepoint") 35 | } 36 | defer w.Close() 37 | 38 | r, err := link.Tracepoint("syscalls", "sys_enter_read", pgObjs.HandleRead, nil) 39 | if err != nil { 40 | log.Fatal("link sys_enter_read tracepoint") 41 | } 42 | defer r.Close() 43 | 44 | rexit, err := link.Tracepoint("syscalls", "sys_exit_read", pgObjs.HandleReadExit, nil) 45 | if err != nil { 46 | log.Fatal("link sys_exit_read tracepoint") 47 | } 48 | defer rexit.Close() 49 | 50 | L7EventsReader, err := perf.NewReader(pgObjs.L7Events, int(4096)*os.Getpagesize()) 51 | if err != nil { 52 | log.Fatal("error creating perf event array reader") 53 | } 54 | 55 | // Case-insensitive matching 56 | re = regexp.MustCompile(strings.Join(keywords, "|")) 57 | pgStatements := make(map[string]string) 58 | 59 | for { 60 | var record perf.Record 61 | err := L7EventsReader.ReadInto(&record) 62 | if err != nil { 63 | log.Print("error reading from perf array") 64 | } 65 | 66 | if record.LostSamples != 0 { 67 | log.Printf("lost samples l7-event %d", record.LostSamples) 68 | } 69 | 70 | if record.RawSample == nil || len(record.RawSample) == 0 { 71 | log.Print("read sample l7-event nil or empty") 72 | return 73 | } 74 | 75 | l7Event := (*bpfL7Event)(unsafe.Pointer(&record.RawSample[0])) 76 | 77 | protocol := L7ProtocolConversion(l7Event.Protocol).String() 78 | 79 | // copy payload 
slice 80 | payload := [1024]uint8{} 81 | copy(payload[:], l7Event.Payload[:]) 82 | 83 | if (protocol == "POSTGRES") { 84 | out, err := parseSqlCommand(l7Event, &pgStatements) 85 | if err != nil { 86 | log.Printf("Error parsing sql command: %s", err) 87 | } else { 88 | log.Printf("%s", out) 89 | } 90 | } 91 | } 92 | } -------------------------------------------------------------------------------- /perf/go.mod: -------------------------------------------------------------------------------- 1 | module perf-postgres 2 | 3 | go 1.22.4 4 | 5 | require github.com/lib/pq v1.10.9 6 | -------------------------------------------------------------------------------- /perf/go.sum: -------------------------------------------------------------------------------- 1 | github.com/lib/pq v1.10.9 h1:YXG7RB+JIjhP29X+OtkiDnYaXQwpS4JEWq7dtCCRUEw= 2 | github.com/lib/pq v1.10.9/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o= 3 | -------------------------------------------------------------------------------- /perf/measure.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "fmt" 5 | "time" 6 | "database/sql" 7 | _ "github.com/lib/pq" 8 | ) 9 | 10 | const ( 11 | host = "localhost" 12 | port = 5432 13 | user = "postgres" 14 | password = "mysecretpassword" 15 | dbname = "mydb" 16 | ) 17 | 18 | func main() { 19 | // Connection string 20 | psqlInfo := fmt.Sprintf("host=%s port=%d user=%s password=%s dbname=%s sslmode=disable", 21 | host, port, user, password, dbname) 22 | 23 | // Connect to the PostgreSQL database 24 | db, err := sql.Open("postgres", psqlInfo) 25 | CheckError(err) 26 | defer db.Close() 27 | 28 | // Create table 29 | createTable := `CREATE TABLE IF NOT EXISTS "Students" ("id" serial primary key, "Name" TEXT, "Roll" INTEGER)` 30 | _, err = db.Exec(createTable) 31 | CheckError(err) 32 | 33 | const repeat = 1000 34 | var totalInsertTime, totalUpdateTime, totalDeleteTime, totalQueryTime time.Duration 35 
| 36 | for i := 0; i < repeat; i++ { 37 | start := time.Now() 38 | // Insert data into the table 39 | insertStmt := `INSERT INTO "Students"("Name", "Roll") VALUES('John', 1)` 40 | _, err = db.Exec(insertStmt) 41 | elapsed := time.Since(start) 42 | totalInsertTime += elapsed 43 | CheckError(err) 44 | 45 | start = time.Now() 46 | // Insert data into the table using dynamic SQL 47 | insertDynStmt := `INSERT INTO "Students"("Name", "Roll") VALUES($1, $2)` 48 | _, err = db.Exec(insertDynStmt, "Jane", 2) 49 | elapsed = time.Since(start) 50 | totalInsertTime += elapsed 51 | CheckError(err) 52 | 53 | start = time.Now() 54 | // Update data in the table 55 | updateStmt := `UPDATE "Students" SET "Name"=$1, "Roll"=$2 WHERE "id"=$3` 56 | _, err = db.Exec(updateStmt, "Mary", 3, 2) 57 | elapsed = time.Since(start) 58 | totalUpdateTime += elapsed 59 | CheckError(err) 60 | 61 | start = time.Now() 62 | // Delete data from the table 63 | deleteStmt := `DELETE FROM "Students" WHERE id=$1` 64 | _, err = db.Exec(deleteStmt, 1) 65 | elapsed = time.Since(start) 66 | totalDeleteTime += elapsed 67 | CheckError(err) 68 | 69 | start = time.Now() 70 | rows, err := db.Query(`SELECT "Name", "Roll" FROM "Students"`) 71 | elapsed = time.Since(start) 72 | totalQueryTime += elapsed 73 | CheckError(err) 74 | 75 | defer rows.Close() 76 | for rows.Next() { 77 | var name string 78 | var roll int 79 | err = rows.Scan(&name, &roll) 80 | CheckError(err) 81 | } 82 | CheckError(rows.Err()) 83 | 84 | // Clean up table for next iteration 85 | _, err = db.Exec(`TRUNCATE "Students" RESTART IDENTITY`) 86 | CheckError(err) 87 | } 88 | 89 | fmt.Printf("Average INSERT latency: %v\n", totalInsertTime/(2*time.Duration(repeat))) 90 | fmt.Printf("Average UPDATE latency: %v\n", totalUpdateTime/time.Duration(repeat)) 91 | fmt.Printf("Average DELETE latency: %v\n", totalDeleteTime/time.Duration(repeat)) 92 | fmt.Printf("Average QUERY latency: %v\n", totalQueryTime/time.Duration(repeat)) 93 | fmt.Println("Perfomance test 
finished successfully.") 94 | } 95 | 96 | func CheckError(err error) { 97 | if err != nil { 98 | panic(err) 99 | } 100 | } 101 | -------------------------------------------------------------------------------- /perf/plot.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | 3 | # Latency values in milliseconds 4 | operations = ['INSERT', 'UPDATE', 'DELETE', 'QUERY'] 5 | latency_with_ebpf = [1.5487, 1.3383, 1.2542, 0.3433610] 6 | latency_without_ebpf = [1.3566, 1.1601, 1.0735, 0.2975380] 7 | 8 | x = range(len(operations)) 9 | 10 | # Create the plot 11 | fig, ax = plt.subplots() 12 | 13 | # Plotting the values 14 | ax.bar(x, latency_with_ebpf, width=0.4, label='With eBPF', align='center') 15 | ax.bar(x, latency_without_ebpf, width=0.4, label='Without eBPF', align='edge') 16 | 17 | # Adding labels and title 18 | ax.set_xlabel('Operation') 19 | ax.set_ylabel('Latency (ms)') 20 | ax.set_title('PostgreSQL Latency Comparison with and without eBPF') 21 | ax.set_xticks(x) 22 | ax.set_xticklabels(operations) 23 | ax.legend() 24 | 25 | # Display the plot 26 | plt.tight_layout() 27 | plt.show() -------------------------------------------------------------------------------- /perf/postgresql.conf: -------------------------------------------------------------------------------- 1 | # PostgreSQL configuration file - postgresql.conf 2 | 3 | shared_buffers = 1GB # recommended: 25% of total memory 4 | effective_cache_size = 3GB # recommended: 75% of total memory 5 | work_mem = 64MB # recommended: 2-4MB per CPU core 6 | maintenance_work_mem = 512MB # recommended: 10% of total memory 7 | -------------------------------------------------------------------------------- /postgres.c: -------------------------------------------------------------------------------- 1 | //go:build ignore 2 | 3 | #include "postgres.h" 4 | 5 | char LICENSE[] SEC("license") = "Dual BSD/GPL"; 6 | 7 | // Instead of allocating on bpf stack, we 
allocate on a per-CPU array map due to BPF stack limit of 512 bytes 8 | struct { 9 | __uint(type, BPF_MAP_TYPE_PERCPU_ARRAY); 10 | __type(key, __u32); 11 | __type(value, struct l7_request); 12 | __uint(max_entries, 1); 13 | } l7_request_heap SEC(".maps"); 14 | 15 | // Instead of allocating on bpf stack, we allocate on a per-CPU array map due to BPF stack limit of 512 bytes 16 | struct { 17 | __uint(type, BPF_MAP_TYPE_PERCPU_ARRAY); 18 | __type(key, __u32); 19 | __type(value, struct l7_event); 20 | __uint(max_entries, 1); 21 | } l7_event_heap SEC(".maps"); 22 | 23 | // To transfer read parameters from enter to exit 24 | struct { 25 | __uint(type, BPF_MAP_TYPE_HASH); 26 | __type(key, __u64); // pid_tgid 27 | __uint(value_size, sizeof(struct read_args)); 28 | __uint(max_entries, 10240); 29 | } active_reads SEC(".maps"); 30 | 31 | struct { 32 | __uint(type, BPF_MAP_TYPE_LRU_HASH); 33 | __uint(max_entries, 32768); 34 | __type(key, struct socket_key); 35 | __type(value, struct l7_request); 36 | } active_l7_requests SEC(".maps"); 37 | 38 | // Map to share l7 events with the userspace application 39 | struct { 40 | __uint(type, BPF_MAP_TYPE_PERF_EVENT_ARRAY); 41 | __uint(key_size, sizeof(int)); 42 | __uint(value_size, sizeof(int)); 43 | } l7_events SEC(".maps"); 44 | 45 | // Processing enter of write syscall triggered on the client side 46 | static __always_inline 47 | int process_enter_of_syscalls_write(void* ctx, __u64 fd, char* buf, __u64 payload_size){ 48 | 49 | // Retrieve the l7_request struct from the eBPF map (check above the map definition, why we use per-CPU array map for this purpose) 50 | int zero = 0; 51 | struct l7_request *req = bpf_map_lookup_elem(&l7_request_heap, &zero); 52 | if (!req) { 53 | return 0; 54 | } 55 | 56 | // Check if the L7 protocol is Postgres otherwise set to unknown 57 | req->protocol = PROTOCOL_UNKNOWN; 58 | req->method = METHOD_UNKNOWN; 59 | req->request_type = 0; 60 | if (buf) { 61 | if (parse_client_postgres_data(buf, payload_size, 
&req->request_type)) { 62 | bpf_printk("Client request type: %c\n", req->request_type); 63 | if (req->request_type == POSTGRES_MESSAGE_TERMINATE){ 64 | req->protocol = PROTOCOL_POSTGRES; 65 | req->method = METHOD_STATEMENT_CLOSE_OR_CONN_TERMINATE; 66 | } 67 | req->protocol = PROTOCOL_POSTGRES; 68 | } 69 | } 70 | 71 | // Copy the payload from the packet and check whether it fit below the MAX_PAYLOAD_SIZE 72 | bpf_probe_read(&req->payload, sizeof(req->payload), (const void *)buf); 73 | if (payload_size > MAX_PAYLOAD_SIZE) { 74 | // We werent able to copy all of it (setting payload_read_complete to 0) 75 | req->payload_size = MAX_PAYLOAD_SIZE; 76 | req->payload_read_complete = 0; 77 | } else { 78 | req->payload_size = payload_size; 79 | req->payload_read_complete = 1; 80 | } 81 | 82 | // Store active L7 request struct for later usage 83 | struct socket_key k = {}; 84 | __u64 id = bpf_get_current_pid_tgid(); 85 | k.pid = id >> 32; 86 | k.fd = fd; 87 | long res = bpf_map_update_elem(&active_l7_requests, &k, req, BPF_ANY); 88 | if (res < 0) { 89 | bpf_printk("Failed to store struct to active_l7_requests eBPF map"); 90 | } 91 | 92 | return 0; 93 | } 94 | 95 | // Processing enter of read syscall triggered on the server side 96 | static __always_inline 97 | int process_enter_of_syscalls_read(struct trace_event_raw_sys_enter_read *ctx) { 98 | __u64 id = bpf_get_current_pid_tgid(); 99 | 100 | // Store an active read struct for later usage 101 | struct read_args args = {}; 102 | args.fd = ctx->fd; 103 | args.buf = ctx->buf; 104 | args.size = ctx->count; 105 | long res = bpf_map_update_elem(&active_reads, &id, &args, BPF_ANY); 106 | if (res < 0) { 107 | bpf_printk("write to active_reads failed"); 108 | } 109 | 110 | return 0; 111 | } 112 | 113 | static __always_inline 114 | int process_exit_of_syscalls_read(void* ctx, __s64 ret) { 115 | __u64 id = bpf_get_current_pid_tgid(); 116 | __u32 pid = id >> 32; 117 | 118 | // Retrieve the active read struct from the enter of read 
syscall 119 | struct read_args *read_info = bpf_map_lookup_elem(&active_reads, &id); 120 | if (!read_info) { 121 | return 0; 122 | } 123 | 124 | // Retrieve the active L7 request struct from the write syscall 125 | struct socket_key k = {}; 126 | k.pid = pid; 127 | k.fd = read_info->fd; 128 | struct l7_request *active_req = bpf_map_lookup_elem(&active_l7_requests, &k); 129 | if (!active_req) { 130 | return 0; 131 | } 132 | 133 | // Retrieve the active L7 event struct from the eBPF map (check above the map definition, why we use per-CPU array map for this purpose) 134 | // This event struct is then forwarded to the userspace application 135 | int zero = 0; 136 | struct l7_event *e = bpf_map_lookup_elem(&l7_event_heap, &zero); 137 | if (!e) { 138 | bpf_map_delete_elem(&active_l7_requests, &k); 139 | bpf_map_delete_elem(&active_reads, &id); 140 | return 0; 141 | } 142 | e->fd = k.fd; 143 | e->pid = k.pid; 144 | e->method = active_req->method; 145 | e->protocol = active_req->protocol; 146 | e->payload_size = active_req->payload_size; 147 | e->payload_read_complete = active_req->payload_read_complete; 148 | bpf_probe_read(e->payload, MAX_PAYLOAD_SIZE, active_req->payload); 149 | 150 | if (read_info->buf) { 151 | if (e->protocol == PROTOCOL_POSTGRES) { 152 | e->status = parse_postgres_server_resp(read_info->buf, ret); 153 | if (active_req->request_type == POSTGRES_MESSAGE_SIMPLE_QUERY) { 154 | e->method = METHOD_SIMPLE_QUERY; 155 | bpf_printk("Simple Query read on the Server\n"); 156 | } else if (active_req->request_type == POSTGRES_MESSAGE_PARSE || active_req->request_type == POSTGRES_MESSAGE_BIND) { 157 | e->method = METHOD_EXTENDED_QUERY; 158 | bpf_printk("Extended Query read on the Server\n"); 159 | } 160 | } 161 | } else { 162 | bpf_map_delete_elem(&active_reads, &id); 163 | return 0; 164 | } 165 | 166 | // All data is now stored in the L7 Event and we can clean up the structs in the eBPF maps 167 | bpf_map_delete_elem(&active_reads, &id); 168 | 
bpf_map_delete_elem(&active_l7_requests, &k); 169 | 170 | // Forward L7 event to userspace application 171 | long r = bpf_perf_event_output(ctx, &l7_events, BPF_F_CURRENT_CPU, e, sizeof(*e)); 172 | if (r < 0) { 173 | bpf_printk("failed write to l7_events"); 174 | } 175 | 176 | return 0; 177 | } 178 | 179 | 180 | // /sys/kernel/debug/tracing/events/syscalls/sys_enter_write/format 181 | SEC("tracepoint/syscalls/sys_enter_write") 182 | int handle_write(struct trace_event_raw_sys_enter_write* ctx) { 183 | return process_enter_of_syscalls_write(ctx, ctx->fd, ctx->buf, ctx->count); 184 | } 185 | 186 | // /sys/kernel/debug/tracing/events/syscalls/sys_enter_read/format 187 | SEC("tracepoint/syscalls/sys_enter_read") 188 | int handle_read(struct trace_event_raw_sys_enter_read* ctx) { 189 | return process_enter_of_syscalls_read(ctx); 190 | } 191 | 192 | // /sys/kernel/debug/tracing/events/syscalls/sys_exit_read/format 193 | SEC("tracepoint/syscalls/sys_exit_read") 194 | int handle_read_exit(struct trace_event_raw_sys_exit_read* ctx) { 195 | return process_exit_of_syscalls_read(ctx, ctx->ret); 196 | } -------------------------------------------------------------------------------- /postgres.h: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include 5 | 6 | #define MAX_PAYLOAD_SIZE 1024 7 | 8 | #define PROTOCOL_UNKNOWN 0 9 | #define PROTOCOL_POSTGRES 1 10 | 11 | #define METHOD_UNKNOWN 0 12 | #define METHOD_STATEMENT_CLOSE_OR_CONN_TERMINATE 1 13 | #define METHOD_SIMPLE_QUERY 2 14 | #define METHOD_EXTENDED_QUERY 3 15 | 16 | #define COMMAND_COMPLETE 1 17 | #define ERROR_RESPONSE 2 18 | 19 | // Q(1 byte), length(4 bytes), query(length-4 bytes) 20 | #define POSTGRES_MESSAGE_SIMPLE_QUERY 'Q' // 'Q' + 4 bytes of length + query 21 | 22 | // C(1 byte), length(4 bytes), Byte1('S' to close a prepared statement; or 'P' to close a portal), name of the prepared statement or portal(length-5 bytes) 23 | #define 
POSTGRES_MESSAGE_CLOSE 'C' 24 | 25 | // X(1 byte), length(4 bytes) 26 | #define POSTGRES_MESSAGE_TERMINATE 'X' 27 | 28 | // C(1 byte), length(4 bytes), tag(length-4 bytes) 29 | #define POSTGRES_MESSAGE_COMMAND_COMPLETION 'C' 30 | 31 | // prepared statement 32 | #define POSTGRES_MESSAGE_PARSE 'P' // 'P' + 4 bytes of length + query 33 | #define POSTGRES_MESSAGE_BIND 'B' // 'P' + 4 bytes of length + query 34 | 35 | struct trace_entry { 36 | short unsigned int type; 37 | unsigned char flags; 38 | unsigned char preempt_count; 39 | int pid; 40 | }; 41 | 42 | struct socket_key { 43 | __u64 fd; 44 | __u32 pid; 45 | __u8 is_tls; 46 | }; 47 | 48 | struct read_args { 49 | __u64 fd; 50 | char* buf; 51 | __u64 size; 52 | __u64 read_start_ns; 53 | }; 54 | 55 | struct trace_event_raw_sys_enter_write { 56 | struct trace_entry ent; 57 | __s32 __syscall_nr; 58 | __u64 fd; 59 | char * buf; 60 | __u64 count; 61 | }; 62 | 63 | struct trace_event_raw_sys_enter_read{ 64 | struct trace_entry ent; 65 | int __syscall_nr; 66 | unsigned long int fd; 67 | char * buf; 68 | __u64 count; 69 | }; 70 | 71 | struct trace_event_raw_sys_exit_read { 72 | __u64 unused; 73 | __s32 id; 74 | __s64 ret; 75 | }; 76 | 77 | struct l7_request { 78 | __u64 write_time_ns; 79 | __u8 protocol; 80 | __u8 method; 81 | unsigned char payload[MAX_PAYLOAD_SIZE]; 82 | __u32 payload_size; 83 | __u8 payload_read_complete; 84 | __u8 request_type; 85 | __u32 seq; 86 | __u32 tid; 87 | }; 88 | 89 | struct l7_event { 90 | __u64 fd; 91 | __u64 write_time_ns; 92 | __u32 pid; 93 | __u32 status; 94 | __u64 duration; 95 | __u8 protocol; 96 | __u8 method; 97 | __u16 padding; 98 | unsigned char payload[MAX_PAYLOAD_SIZE]; 99 | __u32 payload_size; 100 | __u8 payload_read_complete; 101 | __u8 failed; 102 | __u8 is_tls; 103 | __u32 seq; 104 | __u32 tid; 105 | }; 106 | 107 | // Used on the client side 108 | // Checks if the message is a postgresql Q, C, X message 109 | static __always_inline 110 | int parse_client_postgres_data(char *buf, 
int buf_size, __u8 *request_type) { 111 | // Return immeadiately if buffer is empty 112 | if (buf_size < 1) { 113 | return 0; 114 | } 115 | 116 | // Parse the first byte of the buffer 117 | // This is the identifier of the PostgresQL message 118 | char identifier; 119 | if (bpf_probe_read(&identifier, sizeof(identifier), (void *)((char *)buf)) < 0) { 120 | return 0; 121 | } 122 | 123 | // the next four bytes specify the length of the rest of the message 124 | __u32 len; 125 | if (bpf_probe_read(&len, sizeof(len), (void *)((char *)buf + 1)) < 0) { 126 | return 0; 127 | } 128 | 129 | // Connection termination has the Terminate identifier ("X") and the length is 4 bytes 130 | if (identifier == POSTGRES_MESSAGE_TERMINATE && bpf_htonl(len) == 4) { 131 | bpf_printk("Client will send Terminate packet\n"); 132 | *request_type = identifier; 133 | return 1; 134 | } 135 | 136 | // Simple Query Protocol 137 | if (identifier == POSTGRES_MESSAGE_SIMPLE_QUERY) { 138 | *request_type = identifier; 139 | bpf_printk("Client will send a Simple Query\n"); 140 | return 1; 141 | } 142 | 143 | // Extended Query Protocol (Prepared Statement) 144 | // > P/D/S (Parse/Describe/Sync) creating a prepared statement 145 | // > B/E/S (Bind/Execute/Sync) executing a prepared statement 146 | if (identifier == POSTGRES_MESSAGE_PARSE || identifier == POSTGRES_MESSAGE_BIND) { 147 | // Read last 5 bytes of the buffer (Sync message) 148 | char sync[5]; 149 | if (bpf_probe_read(&sync, sizeof(sync), (void *)((char *)buf + (buf_size - 5))) < 0) { 150 | return 0; 151 | } 152 | 153 | // Extended query protocol messages often end with a Sync (S) message. 
154 | // Sync message is a 5 byte message with the first byte being 'S' and the rest indicating the length of the message, including self (4 bytes in this case - so no message body) 155 | if (sync[0] == 'S' && sync[1] == 0 && sync[2] == 0 && sync[3] == 0 && sync[4] == 4) { 156 | bpf_printk("Client will send an Extended Query\n"); 157 | *request_type = identifier; 158 | return 1; 159 | } 160 | } 161 | 162 | return 0; 163 | } 164 | 165 | static __always_inline 166 | __u32 parse_postgres_server_resp(char *buf, int buf_size) { 167 | // Return immeadiately if buffer is empty 168 | if (buf_size < 1) { 169 | return 0; 170 | } 171 | 172 | // Parse the first byte of the buffer 173 | // This is the identifier of the PostgresQL message 174 | char identifier; 175 | if (bpf_probe_read(&identifier, sizeof(identifier), (void *)((char *)buf)) < 0) { 176 | return 0; 177 | } 178 | 179 | // Identifies the message as an error. 180 | if (identifier == 'E') { 181 | return ERROR_RESPONSE; 182 | } 183 | 184 | // TODO: multiple pg messages can be in one packet, need to parse all of them and check if any of them is a command complete 185 | // assume C came if you see a T or D 186 | // when parsed C, it will have sql command in it (tag field, e.g. SELECT, INSERT, UPDATE, DELETE, CREATE, DROP, etc.) 
187 | if (identifier == 't' || identifier == 'T' || identifier == 'D' || identifier == 'C') { 188 | return COMMAND_COMPLETE; 189 | } 190 | 191 | return 0; 192 | } -------------------------------------------------------------------------------- /test/client.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "fmt" 5 | "database/sql" 6 | _ "github.com/lib/pq" 7 | ) 8 | 9 | const ( 10 | host = "localhost" 11 | port = 5432 12 | user = "postgres" 13 | password = "mysecretpassword" 14 | dbname = "mydb" 15 | ) 16 | 17 | func main() { 18 | // Connection string 19 | psqlInfo := fmt.Sprintf("host=%s port=%d user=%s password=%s dbname=%s sslmode=disable", 20 | host, port, user, password, dbname) 21 | 22 | // Connect to the PostgreSQL database 23 | db, err := sql.Open("postgres", psqlInfo) 24 | CheckError(err) 25 | defer db.Close() 26 | 27 | // Create table 28 | createTable := `CREATE TABLE "Students" ("id" serial primary key, "Name" TEXT, "Roll" INTEGER)` 29 | _, err = db.Exec(createTable) 30 | if err != nil { 31 | //fmt.Println(err) 32 | } else { 33 | fmt.Println("Students table created successfully") 34 | } 35 | 36 | // Insert data into the table 37 | insertStmt := `INSERT INTO "Students"("Name", "Roll") VALUES('John', 1)` 38 | _, err = db.Exec(insertStmt) 39 | CheckError(err) 40 | 41 | // Insert data into the table using dynamic SQL 42 | insertDynStmt := `INSERT INTO "Students"("Name", "Roll") VALUES($1, $2)` 43 | _, err = db.Exec(insertDynStmt, "Jane", 2) 44 | CheckError(err) 45 | 46 | // Update data in the table 47 | updateStmt := `UPDATE "Students" SET "Name"=$1, "Roll"=$2 WHERE "id"=$3` 48 | _, err = db.Exec(updateStmt, "Mary", 3, 2) 49 | CheckError(err) 50 | 51 | // Delete data from the table 52 | deleteStmt := `DELETE FROM "Students" WHERE id=$1` 53 | _, err = db.Exec(deleteStmt, 1) 54 | CheckError(err) 55 | 56 | rows, err := db.Query(`SELECT "Name", "Roll" FROM "Students"`) 57 | CheckError(err) 
58 | 59 | defer rows.Close() 60 | for rows.Next() { 61 | var name string 62 | var roll int 63 | 64 | err = rows.Scan(&name, &roll) 65 | CheckError(err) 66 | } 67 | 68 | CheckError(err) 69 | fmt.Println("Client finished successfully.") 70 | } 71 | 72 | func CheckError(err error) { 73 | if err != nil { 74 | panic(err) 75 | } 76 | } -------------------------------------------------------------------------------- /test/go.mod: -------------------------------------------------------------------------------- 1 | module pg-client 2 | 3 | go 1.22.4 4 | 5 | require github.com/lib/pq v1.10.9 6 | -------------------------------------------------------------------------------- /test/go.sum: -------------------------------------------------------------------------------- 1 | github.com/lib/pq v1.10.9 h1:YXG7RB+JIjhP29X+OtkiDnYaXQwpS4JEWq7dtCCRUEw= 2 | github.com/lib/pq v1.10.9/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o= 3 | -------------------------------------------------------------------------------- /test/postgresql.conf: -------------------------------------------------------------------------------- 1 | # PostgreSQL configuration file 2 | 3 | # Memory settings 4 | shared_buffers = 1GB # recommended: 25% of total memory 5 | effective_cache_size = 3GB # recommended: 75% of total memory 6 | work_mem = 64MB # recommended: 2-4MB per CPU core 7 | maintenance_work_mem = 512MB # recommended: 10% of total memory 8 | -------------------------------------------------------------------------------- /util.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "fmt" 5 | "bytes" 6 | "strings" 7 | ) 8 | 9 | // Order is important 10 | const ( 11 | BPF_L7_PROTOCOL_UNKNOWN = iota 12 | BPF_L7_PROTOCOL_POSTGRES 13 | ) 14 | 15 | const ( 16 | L7_PROTOCOL_POSTGRES = "POSTGRES" 17 | L7_PROTOCOL_UNKNOWN = "UNKNOWN" 18 | ) 19 | 20 | // Order is important 21 | const ( 22 | BPF_POSTGRES_METHOD_UNKNOWN = iota 23 | 
BPF_POSTGRES_METHOD_STATEMENT_CLOSE_OR_CONN_TERMINATE 24 | BPF_POSTGRES_METHOD_SIMPLE_QUERY 25 | BPF_POSTGRES_METHOD_EXTENDED_QUERY 26 | ) 27 | 28 | // for postgres, user space 29 | const ( 30 | CLOSE_OR_TERMINATE = "CLOSE_OR_TERMINATE" 31 | SIMPLE_QUERY = "SIMPLE_QUERY" 32 | EXTENDED_QUERY = "EXTENDED_QUERY" 33 | ) 34 | 35 | type L7Event struct { 36 | Fd uint64 37 | Pid uint32 38 | Status uint32 39 | Duration uint64 40 | Protocol string // L7_PROTOCOL_HTTP 41 | Tls bool // Whether request was encrypted 42 | Method string 43 | Payload [1024]uint8 44 | PayloadSize uint32 // How much of the payload was copied 45 | PayloadReadComplete bool // Whether the payload was copied completely 46 | Failed bool // Request failed 47 | WriteTimeNs uint64 // start time of write syscall 48 | Tid uint32 49 | Seq uint32 // tcp seq num 50 | EventReadTime int64 51 | } 52 | 53 | type bpfL7Event struct { 54 | Fd uint64 55 | WriteTimeNs uint64 56 | Pid uint32 57 | Status uint32 58 | Duration uint64 59 | Protocol uint8 60 | Method uint8 61 | Padding uint16 62 | Payload [1024]uint8 63 | PayloadSize uint32 64 | PayloadReadComplete uint8 65 | Failed uint8 66 | IsTls uint8 67 | _ [1]byte 68 | Seq uint32 69 | Tid uint32 70 | _ [4]byte 71 | } 72 | 73 | // Custom types for the enumeration 74 | type L7ProtocolConversion uint32 75 | type PostgresMethodConversion uint32 76 | 77 | // String representation of the enumeration values 78 | func (e L7ProtocolConversion) String() string { 79 | switch e { 80 | case BPF_L7_PROTOCOL_POSTGRES: 81 | return L7_PROTOCOL_POSTGRES 82 | case BPF_L7_PROTOCOL_UNKNOWN: 83 | return L7_PROTOCOL_UNKNOWN 84 | default: 85 | return "Unknown" 86 | } 87 | } 88 | 89 | // String representation of the enumeration values 90 | func (e PostgresMethodConversion) String() string { 91 | switch e { 92 | case BPF_POSTGRES_METHOD_STATEMENT_CLOSE_OR_CONN_TERMINATE: 93 | return CLOSE_OR_TERMINATE 94 | case BPF_POSTGRES_METHOD_SIMPLE_QUERY: 95 | return SIMPLE_QUERY 96 | case 
BPF_POSTGRES_METHOD_EXTENDED_QUERY: 97 | return EXTENDED_QUERY 98 | default: 99 | return "Unknown" 100 | } 101 | } 102 | 103 | func getKey(pid uint32, fd uint64, stmtName string) string { 104 | return fmt.Sprintf("%d-%d-%s", pid, fd, stmtName) 105 | } 106 | 107 | // Check if a string contains SQL keywords 108 | func containsSQLKeywords(input string) bool { 109 | return re.MatchString(strings.ToUpper(input)) 110 | } 111 | 112 | func parseSqlCommand(d *bpfL7Event, pgStatements *map[string]string) (string, error) { 113 | r := d.Payload[:d.PayloadSize] 114 | var sqlCommand string 115 | if PostgresMethodConversion(d.Method).String() == SIMPLE_QUERY { 116 | // SIMPLE_QUERY -> Q, 4 bytes of length, SQL command 117 | // Skip Q, (simple query) 118 | r = r[1:] 119 | 120 | // Skip 4 bytes of length 121 | r = r[4:] 122 | 123 | // Get sql command 124 | sqlCommand = string(r) 125 | 126 | // Garbage data can come for Postgres, we need to filter out 127 | // Search statement inside SQL keywords 128 | if !containsSQLKeywords(sqlCommand) { 129 | return "", fmt.Errorf("no sql command found") 130 | } 131 | } else if PostgresMethodConversion(d.Method).String() == EXTENDED_QUERY { 132 | id := r[0] 133 | switch id { 134 | case 'P': 135 | // EXTENDED_QUERY -> P, 4 bytes len, prepared statement name(str) (null terminated), query(str) (null terminated), parameters 136 | var stmtName string 137 | var query string 138 | vars := bytes.Split(r[5:], []byte{0}) 139 | if len(vars) >= 3 { 140 | stmtName = string(vars[0]) 141 | query = string(vars[1]) 142 | } else if len(vars) == 2 { // query too long for our buffer 143 | stmtName = string(vars[0]) 144 | query = string(vars[1]) + "..." 
145 | } else { 146 | return "", fmt.Errorf("could not parse 'parse' frame for postgres") 147 | } 148 | 149 | (*pgStatements)[getKey(d.Pid, d.Fd, stmtName)] = query 150 | return fmt.Sprintf("PREPARE %s AS %s", stmtName, query), nil 151 | case 'B': 152 | // EXTENDED_QUERY -> B, 4 bytes len, portal str (null terminated), prepared statement name str (null terminated) 153 | var stmtName string 154 | vars := bytes.Split(r[5:], []byte{0}) 155 | if len(vars) >= 2 { 156 | stmtName = string(vars[1]) 157 | } else { 158 | return "", fmt.Errorf("could not parse bind frame for postgres") 159 | } 160 | 161 | query, ok := (*pgStatements)[getKey(d.Pid, d.Fd, stmtName)] 162 | if !ok || query == "" { // we don't have the query for the prepared statement 163 | // Execute (name of prepared statement) [(parameter)] 164 | return fmt.Sprintf("EXECUTE %s *values*", stmtName), nil 165 | } 166 | return query, nil 167 | default: 168 | return "", fmt.Errorf("could not parse extended query for postgres") 169 | } 170 | } else if PostgresMethodConversion(d.Method).String() == CLOSE_OR_TERMINATE { 171 | sqlCommand = string(r) 172 | } 173 | 174 | return sqlCommand, nil 175 | } --------------------------------------------------------------------------------