├── LICENSE
├── README.md
├── go.mod
├── go.sum
├── images
└── postgres.png
├── main.go
├── perf
├── go.mod
├── go.sum
├── measure.go
├── plot.py
└── postgresql.conf
├── postgres.c
├── postgres.h
├── test
├── client.go
├── go.mod
├── go.sum
└── postgresql.conf
└── util.go
/LICENSE:
--------------------------------------------------------------------------------
1 | BSD 2-Clause License
2 |
3 | Copyright (c) 2024, Teodor Janez Podobnik
4 |
5 | Redistribution and use in source and binary forms, with or without
6 | modification, are permitted provided that the following conditions are met:
7 |
8 | 1. Redistributions of source code must retain the above copyright notice, this
9 | list of conditions and the following disclaimer.
10 |
11 | 2. Redistributions in binary form must reproduce the above copyright notice,
12 | this list of conditions and the following disclaimer in the documentation
13 | and/or other materials provided with the distribution.
14 |
15 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
16 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
17 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
18 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
19 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
20 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
21 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
22 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
23 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
24 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
25 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # PostgreSQL eBPF
2 |
3 | This is demo code showcasing observability of the PostgreSQL protocol using eBPF. It is inspired by Alaz, the Kubernetes eBPF agent developed by Anteon.
4 |
5 |
6 |
7 | In order to try it out locally:
8 |
9 | - Run eBPF program using
10 | ```
11 | go generate
12 | go build
13 | sudo ./postgres
14 | ```
15 | - Run the PostgreSQL container using
16 | ```
17 | docker run --name postgres-container -e POSTGRES_PASSWORD=mysecretpassword -e POSTGRES_DB=mydb -d -p 5432:5432 postgres
18 | ```
19 | - Run client inside `/test` using
20 | ```
21 | go run client.go
22 | ```
23 | - In another shell, inspect eBPF program logs using
24 | ```
25 | sudo cat /sys/kernel/debug/tracing/trace_pipe
26 | ```
27 | - To run performance evaluation, inside `/perf` directory run:
28 | ```
29 | go run measure.go
30 | ```
--------------------------------------------------------------------------------
/go.mod:
--------------------------------------------------------------------------------
1 | module postgres
2 |
3 | go 1.22.4
4 |
5 | require github.com/cilium/ebpf v0.15.0
6 |
7 | require (
8 | golang.org/x/exp v0.0.0-20230224173230-c95f2b4c22f2 // indirect
9 | golang.org/x/sys v0.15.0 // indirect
10 | )
11 |
--------------------------------------------------------------------------------
/go.sum:
--------------------------------------------------------------------------------
1 | github.com/cilium/ebpf v0.15.0 h1:7NxJhNiBT3NG8pZJ3c+yfrVdHY8ScgKD27sScgjLMMk=
2 | github.com/cilium/ebpf v0.15.0/go.mod h1:DHp1WyrLeiBh19Cf/tfiSMhqheEiK8fXFZ4No0P1Hso=
3 | github.com/go-quicktest/qt v1.101.0 h1:O1K29Txy5P2OK0dGo59b7b0LR6wKfIhttaAhHUyn7eI=
4 | github.com/go-quicktest/qt v1.101.0/go.mod h1:14Bz/f7NwaXPtdYEgzsx46kqSxVwTbzVZsDC26tQJow=
5 | github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38=
6 | github.com/google/go-cmp v0.5.9/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
7 | github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE=
8 | github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk=
9 | github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
10 | github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
11 | github.com/rogpeppe/go-internal v1.11.0 h1:cWPaGQEPrBb5/AsnsZesgZZ9yb1OQ+GOISoDNXVBh4M=
12 | github.com/rogpeppe/go-internal v1.11.0/go.mod h1:ddIwULY96R17DhadqLgMfk9H9tvdUzkipdSkR5nkCZA=
13 | golang.org/x/exp v0.0.0-20230224173230-c95f2b4c22f2 h1:Jvc7gsqn21cJHCmAWx0LiimpP18LZmUxkT5Mp7EZ1mI=
14 | golang.org/x/exp v0.0.0-20230224173230-c95f2b4c22f2/go.mod h1:CxIveKay+FTh1D0yPZemJVgC/95VzuuOLq5Qi4xnoYc=
15 | golang.org/x/sys v0.15.0 h1:h48lPFYpsTvQJZF4EKyI4aLHaev3CxivZmv7yZig9pc=
16 | golang.org/x/sys v0.15.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
17 |
--------------------------------------------------------------------------------
/images/postgres.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dorkamotorka/postgres-ebpf/43805aaed7d4d274667ecd34816e81e842cdceaa/images/postgres.png
--------------------------------------------------------------------------------
/main.go:
--------------------------------------------------------------------------------
1 | package main
2 |
3 | import (
4 | "os"
5 | "log"
6 | "unsafe"
7 | "regexp"
8 | "strings"
9 | "github.com/cilium/ebpf/rlimit"
10 | "github.com/cilium/ebpf/link"
11 | "github.com/cilium/ebpf/perf"
12 | )
13 |
14 | //go:generate go run github.com/cilium/ebpf/cmd/bpf2go postgres postgres.c
15 |
16 | var re *regexp.Regexp
17 | var keywords = []string{"SELECT", "INSERT INTO", "UPDATE", "DELETE FROM", "CREATE TABLE", "ALTER TABLE", "DROP TABLE", "TRUNCATE TABLE", "BEGIN", "COMMIT", "ROLLBACK", "SAVEPOINT", "CREATE INDEX", "DROP INDEX", "CREATE VIEW", "DROP VIEW", "GRANT", "REVOKE", "EXECUTE"}
18 | var pgObjs postgresObjects
19 |
20 | func main() {
21 | // Allow the current process to lock memory for eBPF resources.
22 | if err := rlimit.RemoveMemlock(); err != nil {
23 | log.Fatal(err)
24 | }
25 |
26 | // Load pre-compiled programs and maps into the kernel.
27 | pgObjs = postgresObjects{}
28 | if err := loadPostgresObjects(&pgObjs, nil); err != nil {
29 | log.Fatal(err)
30 | }
31 |
32 | w, err := link.Tracepoint("syscalls", "sys_enter_write", pgObjs.HandleWrite, nil)
33 | if err != nil {
34 | log.Fatal("link sys_enter_write tracepoint")
35 | }
36 | defer w.Close()
37 |
38 | r, err := link.Tracepoint("syscalls", "sys_enter_read", pgObjs.HandleRead, nil)
39 | if err != nil {
40 | log.Fatal("link sys_enter_read tracepoint")
41 | }
42 | defer r.Close()
43 |
44 | rexit, err := link.Tracepoint("syscalls", "sys_exit_read", pgObjs.HandleReadExit, nil)
45 | if err != nil {
46 | log.Fatal("link sys_exit_read tracepoint")
47 | }
48 | defer rexit.Close()
49 |
50 | L7EventsReader, err := perf.NewReader(pgObjs.L7Events, int(4096)*os.Getpagesize())
51 | if err != nil {
52 | log.Fatal("error creating perf event array reader")
53 | }
54 |
55 | // Case-insensitive matching
56 | re = regexp.MustCompile(strings.Join(keywords, "|"))
57 | pgStatements := make(map[string]string)
58 |
59 | for {
60 | var record perf.Record
61 | err := L7EventsReader.ReadInto(&record)
62 | if err != nil {
63 | log.Print("error reading from perf array")
64 | }
65 |
66 | if record.LostSamples != 0 {
67 | log.Printf("lost samples l7-event %d", record.LostSamples)
68 | }
69 |
70 | if record.RawSample == nil || len(record.RawSample) == 0 {
71 | log.Print("read sample l7-event nil or empty")
72 | return
73 | }
74 |
75 | l7Event := (*bpfL7Event)(unsafe.Pointer(&record.RawSample[0]))
76 |
77 | protocol := L7ProtocolConversion(l7Event.Protocol).String()
78 |
79 | // copy payload slice
80 | payload := [1024]uint8{}
81 | copy(payload[:], l7Event.Payload[:])
82 |
83 | if (protocol == "POSTGRES") {
84 | out, err := parseSqlCommand(l7Event, &pgStatements)
85 | if err != nil {
86 | log.Printf("Error parsing sql command: %s", err)
87 | } else {
88 | log.Printf("%s", out)
89 | }
90 | }
91 | }
92 | }
--------------------------------------------------------------------------------
/perf/go.mod:
--------------------------------------------------------------------------------
1 | module perf-postgres
2 |
3 | go 1.22.4
4 |
5 | require github.com/lib/pq v1.10.9
6 |
--------------------------------------------------------------------------------
/perf/go.sum:
--------------------------------------------------------------------------------
1 | github.com/lib/pq v1.10.9 h1:YXG7RB+JIjhP29X+OtkiDnYaXQwpS4JEWq7dtCCRUEw=
2 | github.com/lib/pq v1.10.9/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o=
3 |
--------------------------------------------------------------------------------
/perf/measure.go:
--------------------------------------------------------------------------------
1 | package main
2 |
3 | import (
4 | "fmt"
5 | "time"
6 | "database/sql"
7 | _ "github.com/lib/pq"
8 | )
9 |
// Connection settings for the PostgreSQL instance the benchmark talks to.
const host = "localhost"
const port = 5432
const user = "postgres"
const password = "mysecretpassword"
const dbname = "mydb"
17 |
18 | func main() {
19 | // Connection string
20 | psqlInfo := fmt.Sprintf("host=%s port=%d user=%s password=%s dbname=%s sslmode=disable",
21 | host, port, user, password, dbname)
22 |
23 | // Connect to the PostgreSQL database
24 | db, err := sql.Open("postgres", psqlInfo)
25 | CheckError(err)
26 | defer db.Close()
27 |
28 | // Create table
29 | createTable := `CREATE TABLE IF NOT EXISTS "Students" ("id" serial primary key, "Name" TEXT, "Roll" INTEGER)`
30 | _, err = db.Exec(createTable)
31 | CheckError(err)
32 |
33 | const repeat = 1000
34 | var totalInsertTime, totalUpdateTime, totalDeleteTime, totalQueryTime time.Duration
35 |
36 | for i := 0; i < repeat; i++ {
37 | start := time.Now()
38 | // Insert data into the table
39 | insertStmt := `INSERT INTO "Students"("Name", "Roll") VALUES('John', 1)`
40 | _, err = db.Exec(insertStmt)
41 | elapsed := time.Since(start)
42 | totalInsertTime += elapsed
43 | CheckError(err)
44 |
45 | start = time.Now()
46 | // Insert data into the table using dynamic SQL
47 | insertDynStmt := `INSERT INTO "Students"("Name", "Roll") VALUES($1, $2)`
48 | _, err = db.Exec(insertDynStmt, "Jane", 2)
49 | elapsed = time.Since(start)
50 | totalInsertTime += elapsed
51 | CheckError(err)
52 |
53 | start = time.Now()
54 | // Update data in the table
55 | updateStmt := `UPDATE "Students" SET "Name"=$1, "Roll"=$2 WHERE "id"=$3`
56 | _, err = db.Exec(updateStmt, "Mary", 3, 2)
57 | elapsed = time.Since(start)
58 | totalUpdateTime += elapsed
59 | CheckError(err)
60 |
61 | start = time.Now()
62 | // Delete data from the table
63 | deleteStmt := `DELETE FROM "Students" WHERE id=$1`
64 | _, err = db.Exec(deleteStmt, 1)
65 | elapsed = time.Since(start)
66 | totalDeleteTime += elapsed
67 | CheckError(err)
68 |
69 | start = time.Now()
70 | rows, err := db.Query(`SELECT "Name", "Roll" FROM "Students"`)
71 | elapsed = time.Since(start)
72 | totalQueryTime += elapsed
73 | CheckError(err)
74 |
75 | defer rows.Close()
76 | for rows.Next() {
77 | var name string
78 | var roll int
79 | err = rows.Scan(&name, &roll)
80 | CheckError(err)
81 | }
82 | CheckError(rows.Err())
83 |
84 | // Clean up table for next iteration
85 | _, err = db.Exec(`TRUNCATE "Students" RESTART IDENTITY`)
86 | CheckError(err)
87 | }
88 |
89 | fmt.Printf("Average INSERT latency: %v\n", totalInsertTime/(2*time.Duration(repeat)))
90 | fmt.Printf("Average UPDATE latency: %v\n", totalUpdateTime/time.Duration(repeat))
91 | fmt.Printf("Average DELETE latency: %v\n", totalDeleteTime/time.Duration(repeat))
92 | fmt.Printf("Average QUERY latency: %v\n", totalQueryTime/time.Duration(repeat))
93 | fmt.Println("Perfomance test finished successfully.")
94 | }
95 |
// CheckError panics with err when it is non-nil and does nothing otherwise.
func CheckError(err error) {
	if err == nil {
		return
	}
	panic(err)
}
101 |
--------------------------------------------------------------------------------
/perf/plot.py:
--------------------------------------------------------------------------------
import matplotlib.pyplot as plt

# Latency values in milliseconds, one entry per operation.
operations = ['INSERT', 'UPDATE', 'DELETE', 'QUERY']
latency_with_ebpf = [1.5487, 1.3383, 1.2542, 0.3433610]
latency_without_ebpf = [1.3566, 1.1601, 1.0735, 0.2975380]

x = range(len(operations))
bar_width = 0.4

# Create the plot
fig, ax = plt.subplots()

# Draw the two series side by side: each bar is shifted by half a bar width
# away from the tick so the bars of one group touch but never overlap.
ax.bar([i - bar_width / 2 for i in x], latency_with_ebpf,
       width=bar_width, label='With eBPF', align='center')
ax.bar([i + bar_width / 2 for i in x], latency_without_ebpf,
       width=bar_width, label='Without eBPF', align='center')

# Adding labels and title
ax.set_xlabel('Operation')
ax.set_ylabel('Latency (ms)')
ax.set_title('PostgreSQL Latency Comparison with and without eBPF')
ax.set_xticks(x)
ax.set_xticklabels(operations)
ax.legend()

# Display the plot
plt.tight_layout()
plt.show()
--------------------------------------------------------------------------------
/perf/postgresql.conf:
--------------------------------------------------------------------------------
1 | # PostgreSQL configuration file - postgresql.conf
2 |
3 | shared_buffers = 1GB # recommended: 25% of total memory
4 | effective_cache_size = 3GB # recommended: 75% of total memory
5 | work_mem = 64MB # recommended: 2-4MB per CPU core
6 | maintenance_work_mem = 512MB # recommended: 10% of total memory
7 |
--------------------------------------------------------------------------------
/postgres.c:
--------------------------------------------------------------------------------
1 | //go:build ignore
2 |
3 | #include "postgres.h"
4 |
5 | char LICENSE[] SEC("license") = "Dual BSD/GPL";
6 |
7 | // Instead of allocating on bpf stack, we allocate on a per-CPU array map due to BPF stack limit of 512 bytes
8 | struct {
9 | __uint(type, BPF_MAP_TYPE_PERCPU_ARRAY);
10 | __type(key, __u32);
11 | __type(value, struct l7_request);
12 | __uint(max_entries, 1);
13 | } l7_request_heap SEC(".maps");
14 |
15 | // Instead of allocating on bpf stack, we allocate on a per-CPU array map due to BPF stack limit of 512 bytes
16 | struct {
17 | __uint(type, BPF_MAP_TYPE_PERCPU_ARRAY);
18 | __type(key, __u32);
19 | __type(value, struct l7_event);
20 | __uint(max_entries, 1);
21 | } l7_event_heap SEC(".maps");
22 |
23 | // To transfer read parameters from enter to exit
24 | struct {
25 | __uint(type, BPF_MAP_TYPE_HASH);
26 | __type(key, __u64); // pid_tgid
27 | __uint(value_size, sizeof(struct read_args));
28 | __uint(max_entries, 10240);
29 | } active_reads SEC(".maps");
30 |
31 | struct {
32 | __uint(type, BPF_MAP_TYPE_LRU_HASH);
33 | __uint(max_entries, 32768);
34 | __type(key, struct socket_key);
35 | __type(value, struct l7_request);
36 | } active_l7_requests SEC(".maps");
37 |
38 | // Map to share l7 events with the userspace application
39 | struct {
40 | __uint(type, BPF_MAP_TYPE_PERF_EVENT_ARRAY);
41 | __uint(key_size, sizeof(int));
42 | __uint(value_size, sizeof(int));
43 | } l7_events SEC(".maps");
44 |
45 | // Processing enter of write syscall triggered on the client side
46 | static __always_inline
47 | int process_enter_of_syscalls_write(void* ctx, __u64 fd, char* buf, __u64 payload_size){
48 |
49 | // Retrieve the l7_request struct from the eBPF map (check above the map definition, why we use per-CPU array map for this purpose)
50 | int zero = 0;
51 | struct l7_request *req = bpf_map_lookup_elem(&l7_request_heap, &zero);
52 | if (!req) {
53 | return 0;
54 | }
55 |
56 | // Check if the L7 protocol is Postgres otherwise set to unknown
57 | req->protocol = PROTOCOL_UNKNOWN;
58 | req->method = METHOD_UNKNOWN;
59 | req->request_type = 0;
60 | if (buf) {
61 | if (parse_client_postgres_data(buf, payload_size, &req->request_type)) {
62 | bpf_printk("Client request type: %c\n", req->request_type);
63 | if (req->request_type == POSTGRES_MESSAGE_TERMINATE){
64 | req->protocol = PROTOCOL_POSTGRES;
65 | req->method = METHOD_STATEMENT_CLOSE_OR_CONN_TERMINATE;
66 | }
67 | req->protocol = PROTOCOL_POSTGRES;
68 | }
69 | }
70 |
71 | // Copy the payload from the packet and check whether it fit below the MAX_PAYLOAD_SIZE
72 | bpf_probe_read(&req->payload, sizeof(req->payload), (const void *)buf);
73 | if (payload_size > MAX_PAYLOAD_SIZE) {
74 | // We werent able to copy all of it (setting payload_read_complete to 0)
75 | req->payload_size = MAX_PAYLOAD_SIZE;
76 | req->payload_read_complete = 0;
77 | } else {
78 | req->payload_size = payload_size;
79 | req->payload_read_complete = 1;
80 | }
81 |
82 | // Store active L7 request struct for later usage
83 | struct socket_key k = {};
84 | __u64 id = bpf_get_current_pid_tgid();
85 | k.pid = id >> 32;
86 | k.fd = fd;
87 | long res = bpf_map_update_elem(&active_l7_requests, &k, req, BPF_ANY);
88 | if (res < 0) {
89 | bpf_printk("Failed to store struct to active_l7_requests eBPF map");
90 | }
91 |
92 | return 0;
93 | }
94 |
95 | // Processing enter of read syscall triggered on the server side
96 | static __always_inline
97 | int process_enter_of_syscalls_read(struct trace_event_raw_sys_enter_read *ctx) {
98 | __u64 id = bpf_get_current_pid_tgid();
99 |
100 | // Store an active read struct for later usage
101 | struct read_args args = {};
102 | args.fd = ctx->fd;
103 | args.buf = ctx->buf;
104 | args.size = ctx->count;
105 | long res = bpf_map_update_elem(&active_reads, &id, &args, BPF_ANY);
106 | if (res < 0) {
107 | bpf_printk("write to active_reads failed");
108 | }
109 |
110 | return 0;
111 | }
112 |
113 | static __always_inline
114 | int process_exit_of_syscalls_read(void* ctx, __s64 ret) {
115 | __u64 id = bpf_get_current_pid_tgid();
116 | __u32 pid = id >> 32;
117 |
118 | // Retrieve the active read struct from the enter of read syscall
119 | struct read_args *read_info = bpf_map_lookup_elem(&active_reads, &id);
120 | if (!read_info) {
121 | return 0;
122 | }
123 |
124 | // Retrieve the active L7 request struct from the write syscall
125 | struct socket_key k = {};
126 | k.pid = pid;
127 | k.fd = read_info->fd;
128 | struct l7_request *active_req = bpf_map_lookup_elem(&active_l7_requests, &k);
129 | if (!active_req) {
130 | return 0;
131 | }
132 |
133 | // Retrieve the active L7 event struct from the eBPF map (check above the map definition, why we use per-CPU array map for this purpose)
134 | // This event struct is then forwarded to the userspace application
135 | int zero = 0;
136 | struct l7_event *e = bpf_map_lookup_elem(&l7_event_heap, &zero);
137 | if (!e) {
138 | bpf_map_delete_elem(&active_l7_requests, &k);
139 | bpf_map_delete_elem(&active_reads, &id);
140 | return 0;
141 | }
142 | e->fd = k.fd;
143 | e->pid = k.pid;
144 | e->method = active_req->method;
145 | e->protocol = active_req->protocol;
146 | e->payload_size = active_req->payload_size;
147 | e->payload_read_complete = active_req->payload_read_complete;
148 | bpf_probe_read(e->payload, MAX_PAYLOAD_SIZE, active_req->payload);
149 |
150 | if (read_info->buf) {
151 | if (e->protocol == PROTOCOL_POSTGRES) {
152 | e->status = parse_postgres_server_resp(read_info->buf, ret);
153 | if (active_req->request_type == POSTGRES_MESSAGE_SIMPLE_QUERY) {
154 | e->method = METHOD_SIMPLE_QUERY;
155 | bpf_printk("Simple Query read on the Server\n");
156 | } else if (active_req->request_type == POSTGRES_MESSAGE_PARSE || active_req->request_type == POSTGRES_MESSAGE_BIND) {
157 | e->method = METHOD_EXTENDED_QUERY;
158 | bpf_printk("Extended Query read on the Server\n");
159 | }
160 | }
161 | } else {
162 | bpf_map_delete_elem(&active_reads, &id);
163 | return 0;
164 | }
165 |
166 | // All data is now stored in the L7 Event and we can clean up the structs in the eBPF maps
167 | bpf_map_delete_elem(&active_reads, &id);
168 | bpf_map_delete_elem(&active_l7_requests, &k);
169 |
170 | // Forward L7 event to userspace application
171 | long r = bpf_perf_event_output(ctx, &l7_events, BPF_F_CURRENT_CPU, e, sizeof(*e));
172 | if (r < 0) {
173 | bpf_printk("failed write to l7_events");
174 | }
175 |
176 | return 0;
177 | }
178 |
179 |
180 | // /sys/kernel/debug/tracing/events/syscalls/sys_enter_write/format
181 | SEC("tracepoint/syscalls/sys_enter_write")
182 | int handle_write(struct trace_event_raw_sys_enter_write* ctx) {
183 | return process_enter_of_syscalls_write(ctx, ctx->fd, ctx->buf, ctx->count);
184 | }
185 |
186 | // /sys/kernel/debug/tracing/events/syscalls/sys_enter_read/format
187 | SEC("tracepoint/syscalls/sys_enter_read")
188 | int handle_read(struct trace_event_raw_sys_enter_read* ctx) {
189 | return process_enter_of_syscalls_read(ctx);
190 | }
191 |
192 | // /sys/kernel/debug/tracing/events/syscalls/sys_exit_read/format
193 | SEC("tracepoint/syscalls/sys_exit_read")
194 | int handle_read_exit(struct trace_event_raw_sys_exit_read* ctx) {
195 | return process_exit_of_syscalls_read(ctx, ctx->ret);
196 | }
--------------------------------------------------------------------------------
/postgres.h:
--------------------------------------------------------------------------------
1 | #include
2 | #include
3 | #include
4 | #include
5 |
// Number of payload bytes captured per request/event.
#define MAX_PAYLOAD_SIZE 1024

// L7 protocol identifiers (l7_request.protocol / l7_event.protocol).
#define PROTOCOL_UNKNOWN 0
#define PROTOCOL_POSTGRES 1

// Request method identifiers (l7_request.method / l7_event.method).
#define METHOD_UNKNOWN 0
#define METHOD_STATEMENT_CLOSE_OR_CONN_TERMINATE 1
#define METHOD_SIMPLE_QUERY 2
#define METHOD_EXTENDED_QUERY 3

// Response status values returned by parse_postgres_server_resp.
#define COMMAND_COMPLETE 1
#define ERROR_RESPONSE 2

// Simple query: 'Q' (1 byte) + length (4 bytes) + query (length-4 bytes)
#define POSTGRES_MESSAGE_SIMPLE_QUERY 'Q'

// Close: 'C' (1 byte) + length (4 bytes) + 'S' (prepared statement) or 'P'
// (portal) + name of the prepared statement or portal (length-5 bytes)
#define POSTGRES_MESSAGE_CLOSE 'C'

// Terminate: 'X' (1 byte) + length (4 bytes)
#define POSTGRES_MESSAGE_TERMINATE 'X'

// Command completion: 'C' (1 byte) + length (4 bytes) + tag (length-4 bytes)
#define POSTGRES_MESSAGE_COMMAND_COMPLETION 'C'

// Extended query protocol (prepared statements)
#define POSTGRES_MESSAGE_PARSE 'P' // 'P' + 4 bytes of length + statement name + query
#define POSTGRES_MESSAGE_BIND 'B' // 'B' + 4 bytes of length + portal name + statement name
34 |
// Header embedded as the first member (`ent`) of the sys_enter tracepoint
// context structs declared below.
struct trace_entry {
	unsigned short type;
	unsigned char flags;
	unsigned char preempt_count;
	int pid;
};
41 |
42 | struct socket_key {
43 | __u64 fd;
44 | __u32 pid;
45 | __u8 is_tls;
46 | };
47 |
48 | struct read_args {
49 | __u64 fd;
50 | char* buf;
51 | __u64 size;
52 | __u64 read_start_ns;
53 | };
54 |
55 | struct trace_event_raw_sys_enter_write {
56 | struct trace_entry ent;
57 | __s32 __syscall_nr;
58 | __u64 fd;
59 | char * buf;
60 | __u64 count;
61 | };
62 |
63 | struct trace_event_raw_sys_enter_read{
64 | struct trace_entry ent;
65 | int __syscall_nr;
66 | unsigned long int fd;
67 | char * buf;
68 | __u64 count;
69 | };
70 |
71 | struct trace_event_raw_sys_exit_read {
72 | __u64 unused;
73 | __s32 id;
74 | __s64 ret;
75 | };
76 |
77 | struct l7_request {
78 | __u64 write_time_ns;
79 | __u8 protocol;
80 | __u8 method;
81 | unsigned char payload[MAX_PAYLOAD_SIZE];
82 | __u32 payload_size;
83 | __u8 payload_read_complete;
84 | __u8 request_type;
85 | __u32 seq;
86 | __u32 tid;
87 | };
88 |
89 | struct l7_event {
90 | __u64 fd;
91 | __u64 write_time_ns;
92 | __u32 pid;
93 | __u32 status;
94 | __u64 duration;
95 | __u8 protocol;
96 | __u8 method;
97 | __u16 padding;
98 | unsigned char payload[MAX_PAYLOAD_SIZE];
99 | __u32 payload_size;
100 | __u8 payload_read_complete;
101 | __u8 failed;
102 | __u8 is_tls;
103 | __u32 seq;
104 | __u32 tid;
105 | };
106 |
107 | // Used on the client side
108 | // Checks if the message is a postgresql Q, C, X message
109 | static __always_inline
110 | int parse_client_postgres_data(char *buf, int buf_size, __u8 *request_type) {
111 | // Return immeadiately if buffer is empty
112 | if (buf_size < 1) {
113 | return 0;
114 | }
115 |
116 | // Parse the first byte of the buffer
117 | // This is the identifier of the PostgresQL message
118 | char identifier;
119 | if (bpf_probe_read(&identifier, sizeof(identifier), (void *)((char *)buf)) < 0) {
120 | return 0;
121 | }
122 |
123 | // the next four bytes specify the length of the rest of the message
124 | __u32 len;
125 | if (bpf_probe_read(&len, sizeof(len), (void *)((char *)buf + 1)) < 0) {
126 | return 0;
127 | }
128 |
129 | // Connection termination has the Terminate identifier ("X") and the length is 4 bytes
130 | if (identifier == POSTGRES_MESSAGE_TERMINATE && bpf_htonl(len) == 4) {
131 | bpf_printk("Client will send Terminate packet\n");
132 | *request_type = identifier;
133 | return 1;
134 | }
135 |
136 | // Simple Query Protocol
137 | if (identifier == POSTGRES_MESSAGE_SIMPLE_QUERY) {
138 | *request_type = identifier;
139 | bpf_printk("Client will send a Simple Query\n");
140 | return 1;
141 | }
142 |
143 | // Extended Query Protocol (Prepared Statement)
144 | // > P/D/S (Parse/Describe/Sync) creating a prepared statement
145 | // > B/E/S (Bind/Execute/Sync) executing a prepared statement
146 | if (identifier == POSTGRES_MESSAGE_PARSE || identifier == POSTGRES_MESSAGE_BIND) {
147 | // Read last 5 bytes of the buffer (Sync message)
148 | char sync[5];
149 | if (bpf_probe_read(&sync, sizeof(sync), (void *)((char *)buf + (buf_size - 5))) < 0) {
150 | return 0;
151 | }
152 |
153 | // Extended query protocol messages often end with a Sync (S) message.
154 | // Sync message is a 5 byte message with the first byte being 'S' and the rest indicating the length of the message, including self (4 bytes in this case - so no message body)
155 | if (sync[0] == 'S' && sync[1] == 0 && sync[2] == 0 && sync[3] == 0 && sync[4] == 4) {
156 | bpf_printk("Client will send an Extended Query\n");
157 | *request_type = identifier;
158 | return 1;
159 | }
160 | }
161 |
162 | return 0;
163 | }
164 |
165 | static __always_inline
166 | __u32 parse_postgres_server_resp(char *buf, int buf_size) {
167 | // Return immeadiately if buffer is empty
168 | if (buf_size < 1) {
169 | return 0;
170 | }
171 |
172 | // Parse the first byte of the buffer
173 | // This is the identifier of the PostgresQL message
174 | char identifier;
175 | if (bpf_probe_read(&identifier, sizeof(identifier), (void *)((char *)buf)) < 0) {
176 | return 0;
177 | }
178 |
179 | // Identifies the message as an error.
180 | if (identifier == 'E') {
181 | return ERROR_RESPONSE;
182 | }
183 |
184 | // TODO: multiple pg messages can be in one packet, need to parse all of them and check if any of them is a command complete
185 | // assume C came if you see a T or D
186 | // when parsed C, it will have sql command in it (tag field, e.g. SELECT, INSERT, UPDATE, DELETE, CREATE, DROP, etc.)
187 | if (identifier == 't' || identifier == 'T' || identifier == 'D' || identifier == 'C') {
188 | return COMMAND_COMPLETE;
189 | }
190 |
191 | return 0;
192 | }
--------------------------------------------------------------------------------
/test/client.go:
--------------------------------------------------------------------------------
1 | package main
2 |
3 | import (
4 | "fmt"
5 | "database/sql"
6 | _ "github.com/lib/pq"
7 | )
8 |
// Connection settings for the PostgreSQL instance the test client talks to.
const host = "localhost"
const port = 5432
const user = "postgres"
const password = "mysecretpassword"
const dbname = "mydb"
16 |
17 | func main() {
18 | // Connection string
19 | psqlInfo := fmt.Sprintf("host=%s port=%d user=%s password=%s dbname=%s sslmode=disable",
20 | host, port, user, password, dbname)
21 |
22 | // Connect to the PostgreSQL database
23 | db, err := sql.Open("postgres", psqlInfo)
24 | CheckError(err)
25 | defer db.Close()
26 |
27 | // Create table
28 | createTable := `CREATE TABLE "Students" ("id" serial primary key, "Name" TEXT, "Roll" INTEGER)`
29 | _, err = db.Exec(createTable)
30 | if err != nil {
31 | //fmt.Println(err)
32 | } else {
33 | fmt.Println("Students table created successfully")
34 | }
35 |
36 | // Insert data into the table
37 | insertStmt := `INSERT INTO "Students"("Name", "Roll") VALUES('John', 1)`
38 | _, err = db.Exec(insertStmt)
39 | CheckError(err)
40 |
41 | // Insert data into the table using dynamic SQL
42 | insertDynStmt := `INSERT INTO "Students"("Name", "Roll") VALUES($1, $2)`
43 | _, err = db.Exec(insertDynStmt, "Jane", 2)
44 | CheckError(err)
45 |
46 | // Update data in the table
47 | updateStmt := `UPDATE "Students" SET "Name"=$1, "Roll"=$2 WHERE "id"=$3`
48 | _, err = db.Exec(updateStmt, "Mary", 3, 2)
49 | CheckError(err)
50 |
51 | // Delete data from the table
52 | deleteStmt := `DELETE FROM "Students" WHERE id=$1`
53 | _, err = db.Exec(deleteStmt, 1)
54 | CheckError(err)
55 |
56 | rows, err := db.Query(`SELECT "Name", "Roll" FROM "Students"`)
57 | CheckError(err)
58 |
59 | defer rows.Close()
60 | for rows.Next() {
61 | var name string
62 | var roll int
63 |
64 | err = rows.Scan(&name, &roll)
65 | CheckError(err)
66 | }
67 |
68 | CheckError(err)
69 | fmt.Println("Client finished successfully.")
70 | }
71 |
// CheckError panics with err when it is non-nil and does nothing otherwise.
func CheckError(err error) {
	if err == nil {
		return
	}
	panic(err)
}
--------------------------------------------------------------------------------
/test/go.mod:
--------------------------------------------------------------------------------
1 | module pg-client
2 |
3 | go 1.22.4
4 |
5 | require github.com/lib/pq v1.10.9
6 |
--------------------------------------------------------------------------------
/test/go.sum:
--------------------------------------------------------------------------------
1 | github.com/lib/pq v1.10.9 h1:YXG7RB+JIjhP29X+OtkiDnYaXQwpS4JEWq7dtCCRUEw=
2 | github.com/lib/pq v1.10.9/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o=
3 |
--------------------------------------------------------------------------------
/test/postgresql.conf:
--------------------------------------------------------------------------------
1 | # PostgreSQL configuration file
2 |
3 | # Memory settings
4 | shared_buffers = 1GB # recommended: 25% of total memory
5 | effective_cache_size = 3GB # recommended: 75% of total memory
6 | work_mem = 64MB # recommended: 2-4MB per CPU core
7 | maintenance_work_mem = 512MB # recommended: 10% of total memory
8 |
--------------------------------------------------------------------------------
/util.go:
--------------------------------------------------------------------------------
1 | package main
2 |
3 | import (
4 | "fmt"
5 | "bytes"
6 | "strings"
7 | )
8 |
// Numeric protocol identifiers as reported by the eBPF program. The values
// must match PROTOCOL_UNKNOWN / PROTOCOL_POSTGRES in postgres.h.
const (
	BPF_L7_PROTOCOL_UNKNOWN  = 0
	BPF_L7_PROTOCOL_POSTGRES = 1
)

// Protocol names used in user space.
const (
	L7_PROTOCOL_POSTGRES = "POSTGRES"
	L7_PROTOCOL_UNKNOWN  = "UNKNOWN"
)
19 |
// Numeric method identifiers as reported by the eBPF program. The values
// must match the METHOD_* defines in postgres.h.
const (
	BPF_POSTGRES_METHOD_UNKNOWN                           = 0
	BPF_POSTGRES_METHOD_STATEMENT_CLOSE_OR_CONN_TERMINATE = 1
	BPF_POSTGRES_METHOD_SIMPLE_QUERY                      = 2
	BPF_POSTGRES_METHOD_EXTENDED_QUERY                    = 3
)

// Method names used in user space for Postgres.
const (
	CLOSE_OR_TERMINATE = "CLOSE_OR_TERMINATE"
	SIMPLE_QUERY       = "SIMPLE_QUERY"
	EXTENDED_QUERY     = "EXTENDED_QUERY"
)
34 |
// L7Event is the user-space representation of a layer-7 event, with the
// numeric protocol and method identifiers resolved to strings.
type L7Event struct {
	Fd                  uint64
	Pid, Status         uint32
	Duration            uint64
	Protocol            string // one of the L7_PROTOCOL_* names
	Tls                 bool   // whether the request was encrypted
	Method              string
	Payload             [1024]uint8
	PayloadSize         uint32 // how much of the payload was copied
	PayloadReadComplete bool   // whether the payload was copied completely
	Failed              bool   // request failed
	WriteTimeNs         uint64 // start time of the write syscall
	Tid, Seq            uint32 // Seq is the tcp sequence number
	EventReadTime       int64
}
52 |
// bpfL7Event mirrors struct l7_event from postgres.h field by field. main
// reinterprets raw perf samples as this type, so the field order and the
// blank padding fields must not change.
type bpfL7Event struct {
	Fd, WriteTimeNs                   uint64
	Pid, Status                       uint32
	Duration                          uint64
	Protocol, Method                  uint8
	Padding                           uint16
	Payload                           [1024]uint8
	PayloadSize                       uint32
	PayloadReadComplete, Failed, IsTls uint8
	_                                 [1]byte
	Seq, Tid                          uint32
	_                                 [4]byte
}
72 |
// Named integer types whose String methods translate the numeric
// identifiers reported by the eBPF program into names.
type (
	L7ProtocolConversion     uint32
	PostgresMethodConversion uint32
)
76 |
77 | // String representation of the enumeration values
78 | func (e L7ProtocolConversion) String() string {
79 | switch e {
80 | case BPF_L7_PROTOCOL_POSTGRES:
81 | return L7_PROTOCOL_POSTGRES
82 | case BPF_L7_PROTOCOL_UNKNOWN:
83 | return L7_PROTOCOL_UNKNOWN
84 | default:
85 | return "Unknown"
86 | }
87 | }
88 |
89 | // String representation of the enumeration values
90 | func (e PostgresMethodConversion) String() string {
91 | switch e {
92 | case BPF_POSTGRES_METHOD_STATEMENT_CLOSE_OR_CONN_TERMINATE:
93 | return CLOSE_OR_TERMINATE
94 | case BPF_POSTGRES_METHOD_SIMPLE_QUERY:
95 | return SIMPLE_QUERY
96 | case BPF_POSTGRES_METHOD_EXTENDED_QUERY:
97 | return EXTENDED_QUERY
98 | default:
99 | return "Unknown"
100 | }
101 | }
102 |
// getKey builds the lookup key for a prepared statement in the form
// "<pid>-<fd>-<stmtName>".
func getKey(pid uint32, fd uint64, stmtName string) string {
	pidPart := fmt.Sprint(pid)
	fdPart := fmt.Sprint(fd)
	return pidPart + "-" + fdPart + "-" + stmtName
}
106 |
107 | // Check if a string contains SQL keywords
108 | func containsSQLKeywords(input string) bool {
109 | return re.MatchString(strings.ToUpper(input))
110 | }
111 |
112 | func parseSqlCommand(d *bpfL7Event, pgStatements *map[string]string) (string, error) {
113 | r := d.Payload[:d.PayloadSize]
114 | var sqlCommand string
115 | if PostgresMethodConversion(d.Method).String() == SIMPLE_QUERY {
116 | // SIMPLE_QUERY -> Q, 4 bytes of length, SQL command
117 | // Skip Q, (simple query)
118 | r = r[1:]
119 |
120 | // Skip 4 bytes of length
121 | r = r[4:]
122 |
123 | // Get sql command
124 | sqlCommand = string(r)
125 |
126 | // Garbage data can come for Postgres, we need to filter out
127 | // Search statement inside SQL keywords
128 | if !containsSQLKeywords(sqlCommand) {
129 | return "", fmt.Errorf("no sql command found")
130 | }
131 | } else if PostgresMethodConversion(d.Method).String() == EXTENDED_QUERY {
132 | id := r[0]
133 | switch id {
134 | case 'P':
135 | // EXTENDED_QUERY -> P, 4 bytes len, prepared statement name(str) (null terminated), query(str) (null terminated), parameters
136 | var stmtName string
137 | var query string
138 | vars := bytes.Split(r[5:], []byte{0})
139 | if len(vars) >= 3 {
140 | stmtName = string(vars[0])
141 | query = string(vars[1])
142 | } else if len(vars) == 2 { // query too long for our buffer
143 | stmtName = string(vars[0])
144 | query = string(vars[1]) + "..."
145 | } else {
146 | return "", fmt.Errorf("could not parse 'parse' frame for postgres")
147 | }
148 |
149 | (*pgStatements)[getKey(d.Pid, d.Fd, stmtName)] = query
150 | return fmt.Sprintf("PREPARE %s AS %s", stmtName, query), nil
151 | case 'B':
152 | // EXTENDED_QUERY -> B, 4 bytes len, portal str (null terminated), prepared statement name str (null terminated)
153 | var stmtName string
154 | vars := bytes.Split(r[5:], []byte{0})
155 | if len(vars) >= 2 {
156 | stmtName = string(vars[1])
157 | } else {
158 | return "", fmt.Errorf("could not parse bind frame for postgres")
159 | }
160 |
161 | query, ok := (*pgStatements)[getKey(d.Pid, d.Fd, stmtName)]
162 | if !ok || query == "" { // we don't have the query for the prepared statement
163 | // Execute (name of prepared statement) [(parameter)]
164 | return fmt.Sprintf("EXECUTE %s *values*", stmtName), nil
165 | }
166 | return query, nil
167 | default:
168 | return "", fmt.Errorf("could not parse extended query for postgres")
169 | }
170 | } else if PostgresMethodConversion(d.Method).String() == CLOSE_OR_TERMINATE {
171 | sqlCommand = string(r)
172 | }
173 |
174 | return sqlCommand, nil
175 | }
--------------------------------------------------------------------------------