├── .gitignore ├── LICENSE ├── README.md ├── demo ├── filters │ └── filter.go ├── static │ └── forest.png └── user.go ├── examples ├── first_lesson │ ├── afterclass │ │ ├── fibonacci.go │ │ ├── fmt.go │ │ └── slice.go │ ├── array_slice │ │ ├── array.go │ │ └── slice.go │ ├── fmt │ │ └── fmt.go │ ├── for │ │ └── for.go │ ├── func_dec │ │ └── funcs.go │ ├── if_else │ │ └── ifelse.go │ ├── package_dec │ │ ├── multi_same │ │ │ ├── a.go │ │ │ └── b.go │ │ └── not_same │ │ │ └── not_same.go │ ├── switch │ │ └── switch.go │ ├── types │ │ ├── rune.go │ │ └── string.go │ └── var_and_const │ │ ├── assignment.go │ │ ├── const.go │ │ ├── var.go │ │ └── var_wrong.go ├── forth_lesson │ ├── atomic │ │ └── atomic.go │ ├── channel │ │ └── channel.go │ ├── context │ │ └── context.go │ ├── init │ │ ├── init_order.go │ │ └── multi_init.go │ ├── select │ │ └── select.go │ └── static_resource │ │ └── file_server.go ├── second_lesson │ ├── afterclass │ │ ├── set.go │ │ └── tree.go │ ├── composition │ │ ├── composition.go │ │ └── no_over_write.go │ ├── http │ │ └── request_body.go │ ├── map │ │ └── map.go │ ├── server_context │ │ └── signup.go │ └── struct │ │ ├── intf.go │ │ ├── pointer.go │ │ ├── receiver.go │ │ ├── self_ref.go │ │ ├── struct.go │ │ ├── type_a_b.go │ │ └── type_a_et_b.go └── third_lesson │ ├── closure │ └── closure.go │ ├── defer │ └── defer.go │ ├── errors │ ├── error.go │ └── panic.go │ ├── goroutine │ └── goroutine.go │ └── sync │ ├── map.go │ ├── mutex.go │ ├── once.go │ ├── pool.go │ └── wait_group.go ├── go.mod ├── go.sum ├── main.go ├── onclass └── main.go └── pkg ├── context.go ├── filter.go ├── graceful_shutdown.go ├── graceful_shutdown_signal_darwin.go ├── graceful_shutdown_signal_linux.go ├── graceful_shutdown_signal_windows.go ├── handler.go ├── hook.go ├── hook_test.go ├── map_router.go ├── server.go ├── static_resource.go ├── tree_node.go ├── tree_router.go ├── tree_router_test.go ├── v1 ├── context.go ├── filter.go ├── handler.go ├── map_router.go ├── 
server.go ├── tree_router.go └── tree_router_test.go ├── v2 ├── context.go ├── filter.go ├── handler.go ├── map_router.go ├── server.go ├── tree_router.go └── tree_router_test.go └── v3 ├── context.go ├── filter.go ├── handler.go ├── map_router.go ├── server.go ├── tree_node.go ├── tree_router.go └── tree_router_test.go /.gitignore: -------------------------------------------------------------------------------- 1 | # Binaries for programs and plugins 2 | *.exe 3 | *.exe~ 4 | *.dll 5 | *.so 6 | *.dylib 7 | 8 | # Test binary, built with `go test -c` 9 | *.test 10 | 11 | # Output of the go coverage tool, specifically when used with LiteIDE 12 | *.out 13 | 14 | # Dependency directories (remove the comment below to include it) 15 | # vendor/ 16 | 17 | .idea 18 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 Ming Deng 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # toy-web 2 | 用于极客时间go基础课程 3 | -------------------------------------------------------------------------------- /demo/filters/filter.go: -------------------------------------------------------------------------------- 1 | package filters 2 | 3 | import ( 4 | "fmt" 5 | web "geektime/toy-web/pkg" 6 | ) 7 | 8 | func init() { 9 | web.RegisterFilter("my-custom", myFilterBuilder) 10 | } 11 | 12 | func myFilterBuilder(next web.Filter) web.Filter { 13 | return func(c *web.Context) { 14 | fmt.Println("假装这是我自定义的 filter") 15 | next(c) 16 | } 17 | } 18 | -------------------------------------------------------------------------------- /demo/static/forest.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/flycash/toy-web/06bd53c25b602fa27f25e52564e9cd7a94283315/demo/static/forest.png -------------------------------------------------------------------------------- /demo/user.go: -------------------------------------------------------------------------------- 1 | package demo 2 | 3 | import ( 4 | "fmt" 5 | web "geektime/toy-web/pkg" 6 | "time" 7 | ) 8 | 9 | func SignUp(c *web.Context) { 10 | req := &signUpReq{} 11 | err := c.ReadJson(req) 12 | if err != nil { 13 | _ = c.BadRequestJson(&commonResponse{ 14 | BizCode: 4, // suppose biz code 4 means invalid input parameters 15 | // NOTE: this is a demo; in real code avoid exposing the raw error to clients 16 | Msg: fmt.Sprintf("invalid request: %v", err), 17 | }) 18 | return 19 | } 20 | _ = c.OkJson(&commonResponse{ 21 | // pretend this is the ID of the newly created user 22 | Data: 123, 23 | }) 24 | } 25 |
26 | func SlowService(c *web.Context) { 27 | time.Sleep(time.Second * 10) 28 | _ = c.OkJson(&commonResponse{ 29 | Msg: "Hi, this is msg from slow service", 30 | }) 31 | } 32 | 33 | type signUpReq struct { 34 | Email string `json:"email"` 35 | Password string `json:"password"` 36 | ConfirmedPassword string `json:"confirmed_password"` 37 | } 38 | 39 | type commonResponse struct { 40 | BizCode int `json:"biz_code"` 41 | Msg string `json:"msg"` 42 | Data interface{} `json:"data"` 43 | } -------------------------------------------------------------------------------- /examples/first_lesson/afterclass/fibonacci.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | func main() { 4 | 5 | } 6 | 7 | func fibonacci(n int) int { 8 | // TODO(homework): presumably should return the n-th Fibonacci number 9 | return 0 10 | } 11 | -------------------------------------------------------------------------------- /examples/first_lesson/afterclass/fmt.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | func main() { 4 | 5 | } 6 | 7 | // printNumWith2 should format the number with two decimal places (homework stub, not implemented yet) 8 | func printNumWith2(float642 float64) string { 9 | return "" 10 | } 11 | 12 | func printBytes(data []byte) string { 13 | return "" 14 | } 15 | -------------------------------------------------------------------------------- /examples/first_lesson/afterclass/slice.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | func main() { 4 | s := []int{1, 2, 4, 7} 5 | // expected result: 5, 1, 2, 4, 7 6 | s = Add(s, 0, 5) 7 | 8 | // expected result: 5, 9, 1, 2, 4, 7 9 | s = Add(s, 1, 9) 10 | 11 | // expected result: 5, 9, 1, 2, 4, 7, 13 12 | s = Add(s, 6, 13) 13 | 14 | // expected result: 5, 9, 2, 4, 7, 13 15 | s = Delete(s, 2) 16 | 17 | // expected result: 9, 2, 4, 7, 13 18 | s = Delete(s, 0) 19 | 20 | // expected result: 9, 2, 4, 7 21 | s = Delete(s, 4) 22 | 23 | } 24 | 25 | func Add(s []int, index int, value int) []int { 26 | // TODO(homework): insert value at position index and return the resulting slice 27 | return s 28 | } 29 | 30 | func Delete(s []int, index int) []int { 31 | // TODO(homework): remove the element at position index and return the resulting slice
32 | return s 33 | } 34 | -------------------------------------------------------------------------------- /examples/first_lesson/array_slice/array.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import "fmt" 4 | 5 | func main() { 6 | // initialize an array of exactly three elements; listing more than three values does not compile (fewer is allowed: the rest are zero) 7 | a1 := [3]int{9, 8, 7} 8 | fmt.Printf("a1: %v, len: %d, cap: %d \n", a1, len(a1), cap(a1)) 9 | 10 | // an array of three elements, all of them 0 11 | var a2 [3]int 12 | fmt.Printf("a2: %v, len: %d, cap: %d \n", a2, len(a2), cap(a2)) 13 | 14 | //a1 = append(a1, 12) arrays do not support append 15 | 16 | // access by index 17 | fmt.Printf("a1[1]: %d \n", a1[1]) 18 | // a constant index out of range is rejected at compile time 19 | //fmt.Printf("a1[99]: %d", a1[99]) 20 | } -------------------------------------------------------------------------------- /examples/first_lesson/array_slice/slice.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import "fmt" 4 | 5 | func main() { 6 | s1 := []int{1, 2, 3, 4} // a slice initialized directly with 4 elements 7 | fmt.Printf("s1: %v, len %d, cap: %d \n", s1, len(s1), cap(s1)) 8 | 9 | s2 := make([]int, 3, 4) // a slice with 3 elements and capacity 4 10 | fmt.Printf("s2: %v, len %d, cap: %d \n", s2, len(s2), cap(s2)) 11 | 12 | // s2 is currently [0, 0, 0]; what does it become after appending one element?
13 | s2 = append(s2, 7) // appends one element; still within capacity 4, so no reallocation happens 14 | fmt.Printf("s2: %v, len %d, cap: %d \n", s2, len(s2), cap(s2)) 15 | 16 | s2 = append(s2, 8) // appends another element; this exceeds the capacity and triggers growth 17 | fmt.Printf("s2: %v, len %d, cap: %d \n", s2, len(s2), cap(s2)) 18 | 19 | s3 := make([]int, 4) // with a single size argument, length and capacity are both 4 20 | // equivalent to s3 := make([]int, 4, 4) 21 | fmt.Printf("s3: %v, len %d, cap: %d \n", s3, len(s3), cap(s3)) 22 | 23 | // access by index 24 | fmt.Printf("s3[2]: %d \n", s3[2]) 25 | // an out-of-range index compiles but panics at run time: 26 | // runtime error: index out of range [99] with length 4 27 | // fmt.Printf("s3[99]: %d", s3[99]) 28 | 29 | // SubSlice() 30 | 31 | // ShareSlice() 32 | } 33 | 34 | func SubSlice() { 35 | s1 := []int{2, 4, 6, 8, 10} 36 | s2 := s1[1:3] 37 | fmt.Printf("s2: %v, len %d, cap: %d \n", s2, len(s2), cap(s2)) 38 | 39 | s3 := s1[2:] 40 | fmt.Printf("s3: %v, len %d, cap: %d \n", s3, len(s3), cap(s3)) 41 | 42 | s4 := s1[:3] 43 | fmt.Printf("s4: %v, len %d, cap: %d \n", s4, len(s4), cap(s4)) 44 | } 45 | 46 | func ShareSlice() { 47 | 48 | s1 := []int{1, 2, 3, 4} 49 | s2 := s1[2:] 50 | fmt.Printf("s1: %v, len %d, cap: %d \n", s1, len(s1), cap(s1)) 51 | fmt.Printf("s2: %v, len %d, cap: %d \n", s2, len(s2), cap(s2)) 52 | 53 | s2[0] = 99 54 | fmt.Printf("s1: %v, len %d, cap: %d \n", s1, len(s1), cap(s1)) 55 | fmt.Printf("s2: %v, len %d, cap: %d \n", s2, len(s2), cap(s2)) 56 | 57 | s2 = append(s2, 199) 58 | fmt.Printf("s1: %v, len %d, cap: %d \n", s1, len(s1), cap(s1)) 59 | fmt.Printf("s2: %v, len %d, cap: %d \n", s2, len(s2), cap(s2)) 60 | 61 | s2[1] = 1999 62 | fmt.Printf("s1: %v, len %d, cap: %d \n", s1, len(s1), cap(s1)) 63 | fmt.Printf("s2: %v, len %d, cap: %d \n", s2, len(s2), cap(s2)) 64 | } 65 | -------------------------------------------------------------------------------- /examples/first_lesson/fmt/fmt.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import "fmt" 4 | 5 | func main() { 6 | name := "Tom" 7 | age := 17
8 | // Sprintf returns the formatted string instead of printing it, so it is the one used most of the time 9 | str := fmt.Sprintf("hello, I am %s, I am %d years old \n", name, age) 10 | println(str) 11 | 12 | // Printf writes straight to standard output; simple programs often use it to print debug info to the console 13 | fmt.Printf("hello, I am %s, I am %d years old \n", name, age) 14 | 15 | replaceHolder() 16 | } 17 | 18 | func replaceHolder() { 19 | u := &user{ 20 | Name: "Tom", 21 | Age: 17, 22 | } 23 | fmt.Printf("v => %v \n", u) 24 | fmt.Printf("+v => %+v \n", u) 25 | fmt.Printf("#v => %#v \n", u) 26 | fmt.Printf("T => %T \n", u) 27 | } 28 | 29 | type user struct { 30 | Name string 31 | Age int 32 | } 33 | -------------------------------------------------------------------------------- /examples/first_lesson/for/for.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import "fmt" 4 | 5 | func main() { 6 | ForLoop() 7 | ForI() 8 | ForR() 9 | } 10 | 11 | func ForLoop() { 12 | arr := []int{9, 8, 7, 6} 13 | index := 0 14 | for { 15 | if index == len(arr) { 16 | // break exits the loop once every element has been printed 17 | break 18 | } 19 | fmt.Printf("%d => %d\n", index, arr[index]) 20 | index++ 21 | } 22 | fmt.Println(" for loop end \n ") 23 | } 24 | 25 | func ForI() { 26 | arr := []int{9, 8, 7, 6} 27 | for i := 0; i < len(arr); i++ { 28 | fmt.Printf("%d => %d \n", i, arr[i]) 29 | } 30 | fmt.Println("for i loop end \n ") 31 | } 32 | 33 | func ForR() { 34 | arr := []int{9, 8, 7, 6} 35 | 36 | for index, value := range arr { 37 | fmt.Printf("%d => %d\n", index, value) 38 | } 39 | 40 | // if only the value is needed, use _ in place of the index 41 | for _, value := range arr { 42 | fmt.Printf("only value: %d \n", value) 43 | } 44 | 45 | // if only the index is needed, the value can be dropped: for index := range arr 46 | for index := range arr { 47 | fmt.Printf("only index: %d \n", index) 48 | } 49 | 50 | fmt.Println("for r loop end \n ") 51 | } 52 | -------------------------------------------------------------------------------- /examples/first_lesson/func_dec/funcs.go:
-------------------------------------------------------------------------------- 1 | package main 2 | 3 | import "fmt" 4 | 5 | func main() { 6 | a := Fun0("Tom") 7 | println(a) 8 | 9 | b, c := Fun1("a", 17) 10 | println(b) 11 | println(c) 12 | 13 | _, d := Fun2("a", "b") 14 | println(d) 15 | 16 | // any number of values can be passed for the variadic parameter, or a slice can be spread with s... 17 | Fun4("hello", 19, "CUICUI", "DaMing") 18 | s := []string{"CUICUI", "DaMing"} 19 | Fun4("hello", 19, s...) 20 | } 21 | 22 | // Fun0 has a single result, so the result type needs no parentheses 23 | func Fun0(name string) string { 24 | return "Hello, " + name 25 | } 26 | 27 | // Fun1 takes several parameters and returns several results; the parameters are named, the results are not 28 | func Fun1(a string, b int) (int, string) { 29 | return 0, "你好" 30 | } 31 | 32 | // Fun2 has named results: they can be assigned inside the body and then returned with a bare return. 33 | // You may also ignore age and name and return other values explicitly. 34 | func Fun2(a string, b string) (age int, name string) { 35 | age = 19 36 | name = "Tom" 37 | return 38 | //return 19, "Tom" // returning the values explicitly works too 39 | } 40 | 41 | // Fun3: consecutive parameters (or results) of the same type can share a single type name 42 | func Fun3(a, b, c string, abc, bcd int, p string) (d, e int, g string) { 43 | d = 15 44 | e = 16 45 | g = "你好" 46 | return 47 | //return 0, 0, "你好" // this works too 48 | } 49 | 50 | // Fun4 is variadic; the variadic parameter must be the last one 51 | func Fun4(a string, b int, names ...string) { 52 | // inside the function, names can be used just like a []string slice 53 | for _, name := range names { 54 | fmt.Printf("不定参数:%s \n", name) 55 | } 56 | } -------------------------------------------------------------------------------- /examples/first_lesson/if_else/ifelse.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import "fmt" 4 | 5 | func main() { 6 | Young(9) 7 | Young(100) 8 | 9 | IfUsingNewVariable(10, 200) 10 | IfUsingNewVariable(100, 30) 11 | } 12 | 13 | func Young(age int) { 14 | if age < 18 { 15 | fmt.Println("I am a child!") 16 | } else { 17 | // the else branch is optional 18 | fmt.Println("I not a child") 19 | } 20 | } 21 | 22 | func IfUsingNewVariable(start int, end int) { 23 | if distance := end - start; distance > 100 { 24 | fmt.Printf("距离太远,不来了: %d\n", distance) 25 | } else {
26 | // the else branch is optional 27 | fmt.Printf("距离并不远,来一趟: %d\n", distance) 28 | } 29 | 30 | // distance is scoped to the if/else, so it cannot be accessed here (compile error) 31 | //fmt.Printf("距离是: %d\n", distance) 32 | } -------------------------------------------------------------------------------- /examples/first_lesson/package_dec/multi_same/a.go: -------------------------------------------------------------------------------- 1 | package multi_same 2 | -------------------------------------------------------------------------------- /examples/first_lesson/package_dec/multi_same/b.go: -------------------------------------------------------------------------------- 1 | package multi_same 2 | -------------------------------------------------------------------------------- /examples/first_lesson/package_dec/not_same/not_same.go: -------------------------------------------------------------------------------- 1 | package not_same_aaaa 2 | -------------------------------------------------------------------------------- /examples/first_lesson/switch/switch.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import "fmt" 4 | 5 | func main() { 6 | ChooseFruit("蓝莓") 7 | ChooseFruit("苹果") 8 | ChooseFruit("西瓜") 9 | } 10 | 11 | func ChooseFruit(fruit string) { 12 | switch fruit { 13 | case "苹果": 14 | fmt.Println("这是一个苹果") 15 | case "草莓", "蓝莓": 16 | fmt.Println("这是霉霉") 17 | default: 18 | fmt.Printf("不知道是啥:%s \n", fruit) 19 | } 20 | } 21 | -------------------------------------------------------------------------------- /examples/first_lesson/types/rune.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | func main() { 4 | var a byte = 13 // byte is an alias for uint8 5 | println(a) // a must be used: an unused local variable fails to compile ("declared and not used") 6 | } 7 | -------------------------------------------------------------------------------- /examples/first_lesson/types/string.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import "unicode/utf8" 4 | 5 | func main() {
6 | // interpreted (double-quoted) strings are best for short, single-line text without double quotes 7 | println("He said:\" Hello Go \" ") 8 | // raw (back-quoted) strings suit long or complex text, e.g. a JSON document 9 | println(`He said: "hello, Go" 10 | 我还可以换个行 11 | `) 12 | 13 | 14 | println(len("你好")) // prints 6: len counts bytes, and each of these two characters takes 3 bytes in UTF-8 15 | println(utf8.RuneCountInString("你好")) // prints 2 (runes) 16 | println(utf8.RuneCountInString("你好ab")) // prints 4 17 | 18 | // whenever you need a character count (e.g. the length of a user name or a blog post), 19 | // use utf8.RuneCountInString rather than len 20 | 21 | // + concatenation only works between strings 22 | println("Hello, " + "Go!") 23 | 24 | } 25 | -------------------------------------------------------------------------------- /examples/first_lesson/var_and_const/assignment.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | func main() { 4 | a := 13 5 | println(a) 6 | b := "你好" 7 | println(b) 8 | } 9 | 10 | -------------------------------------------------------------------------------- /examples/first_lesson/var_and_const/const.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | const internal = "包内可访问" 4 | const External = "包外可访问" 5 | 6 | func main() { 7 | const a = "你好" 8 | println(a) 9 | } 10 | -------------------------------------------------------------------------------- /examples/first_lesson/var_and_const/var.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | // Global starts with an upper-case letter, so it is exported and visible outside this package 4 | var Global = "全局变量" 5 | 6 | // local starts with a lower-case letter, so it is only visible inside this package 7 | // (not even its sub-packages can use it) 8 | var local = "包变量" 9 | 10 | var ( 11 | First string = "abc" 12 | second int32 = 16 13 | ) 14 | 15 | func main() { 16 | // the IDE greys out int because Go infers the type, so it may be omitted 17 | var a int = 13 18 | println(a) 19 | 20 | // the type is omitted here and inferred as int 21 | var b = 14 22 | println(b) 23 | 24 | // uint cannot be omitted here: without it, 15 would be inferred as int 25 | var c uint = 15 26 | println(c) 27 | 28 | // does not compile: Go is strongly typed and never converts implicitly between int and uint 29 | // println(a == c) 30 | 31 | // declared without a value: d holds the zero value 0, and the type cannot be omitted 32 | var d int 33 | println(d) 34 | } 35 |
-------------------------------------------------------------------------------- /examples/first_lesson/var_and_const/var_wrong.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | var aa = "hello" 4 | // var aa = "bbb" // would not compile: this package already declares aa 5 | func main() { 6 | aa := 13 // OK: declares a new local aa that shadows the package-level one 7 | println(aa) 8 | 9 | var bb = 15 10 | //var bb = 16 // redeclaring bb in the same scope does not compile either 11 | println(bb) 12 | 13 | bb = 17 // OK: not a redeclaration, just assigns a new value 14 | // bb := 18 // not allowed: := declares and assigns, so this would redeclare bb 15 | } 16 | -------------------------------------------------------------------------------- /examples/forth_lesson/atomic/atomic.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import "sync/atomic" 4 | 5 | var value int32 = 0 6 | func main() { 7 | // pass a pointer to value 8 | // atomically add 10 to value 9 | atomic.AddInt32(&value, 10) 10 | nv := atomic.LoadInt32(&value) 11 | // prints 10 12 | println(nv) 13 | // if the current value is 10, set it to 20 14 | swapped := atomic.CompareAndSwapInt32(&value, 10, 20) 15 | // prints true 16 | println(swapped) 17 | 18 | // if the current value is 19, set it to 50 19 | // value is 20 now, so nothing is swapped 20 | swapped = atomic.CompareAndSwapInt32(&value, 19, 50) 21 | // prints false 22 | println(swapped) 23 | 24 | old := atomic.SwapInt32(&value, 40) 25 | // prints 20: SwapInt32 returns the previous value 26 | println(old) 27 | // prints the value after the swap, 40 (read atomically, like every other access to value) 28 | println(atomic.LoadInt32(&value)) 29 | } -------------------------------------------------------------------------------- /examples/forth_lesson/channel/channel.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "fmt" 5 | "time" 6 | ) 7 | 8 | func main() { 9 | channelWithoutCache() 10 | channelWithCache() 11 | } 12 | 13 | func channelWithCache() { 14 | ch := make(chan string, 1) 15 | go func() { 16 | 17 | ch <- "Hello, first msg from channel" 18 | time.Sleep(time.Second) 19 | ch <- "Hello, second msg from channel" 20 | }() 21 | 22 | time.Sleep(2 * time.Second)
23 | msg := <-ch 24 | fmt.Println(time.Now().String() + msg) 25 | msg = <-ch 26 | fmt.Println(time.Now().String() + msg) 27 | // we slept 2s first, so the first message is already waiting in the buffer; 28 | // the blocked second send completes as soon as we take the first one, so the two prints are well under 1s apart 29 | // (in the author's runs the gap was always below 1ms) 30 | } 31 | 32 | func channelWithoutCache() { 33 | // unbuffered channel 34 | ch := make(chan string) 35 | go func() { 36 | time.Sleep(time.Second) 37 | ch <- "Hello, msg from channel" 38 | }() 39 | 40 | // easy to mistype as msg <- ch, which does not compile 41 | msg := <-ch 42 | fmt.Println(msg) 43 | } 44 | -------------------------------------------------------------------------------- /examples/forth_lesson/context/context.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | "time" 7 | ) 8 | 9 | func main() { 10 | WithTimeout() 11 | WithCancel() 12 | WithDeadline() 13 | WithValue() 14 | } 15 | 16 | func WithTimeout() { 17 | ctx, cancel := context.WithTimeout(context.Background(), time.Second * 2) 18 | defer cancel() 19 | 20 | start := time.Now().Unix() 21 | <-ctx.Done() 22 | end := time.Now().Unix() 23 | // prints 2 (whole seconds): ctx.Done() blocked for about two seconds 24 | fmt.Println(end-start) 25 | } 26 | 27 | func WithCancel() { 28 | ctx, cancel := context.WithCancel(context.Background()) 29 | go func() { 30 | <-ctx.Done() 31 | fmt.Println("context was canceled") 32 | }() 33 | // give the goroutine time to start and block on ctx.Done() 34 | time.Sleep(time.Second) 35 | cancel() 36 | // give the goroutine time to print before returning 37 | time.Sleep(time.Second) 38 | } 39 | 40 | func WithDeadline() { 41 | // the deadline is two seconds from now 42 | ctx, cancel := context.WithDeadline(context.Background(), 43 | time.Now().Add(2 * time.Second)) 44 | defer cancel() 45 | 46 | start := time.Now().Unix() 47 | <-ctx.Done() 48 | end := time.Now().Unix() 49 | // prints 2 (whole seconds): ctx.Done() blocked for about two seconds 50 | fmt.Println(end-start) 51 | } 52 | 53 | // ctxKey is an unexported key type for context values; the context package advises against 54 | // built-in types such as string as keys, to avoid collisions between packages 55 | type ctxKey string 56 | 57 | func WithValue() { 58 | parentKey := ctxKey("parent") 59 | parent := context.WithValue(context.Background(), parentKey, "this is parent") 60 | 61 | sonKey := ctxKey("son")
62 | son := context.WithValue(parent, sonKey, "this is son") 63 | 64 | // a parent cannot see values added by a context derived from it: looking up the son's key on parent yields nil 65 | if parent.Value(sonKey) == nil { 66 | fmt.Println("parent can not get son's key-value pair") 67 | } 68 | 69 | // a derived context can see the values of its ancestors 70 | if val := son.Value(parentKey); val != nil { 71 | fmt.Printf("son can get parent's key-value pair: %v\n", val) 72 | } 73 | } 74 | -------------------------------------------------------------------------------- /examples/forth_lesson/init/init_order.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | func init() { 4 | // init functions in different files run in the order the files are presented to the compiler, 5 | // which is fragile to rely on; calling the steps from a single init makes their order explicit 6 | initBeforeSomething() 7 | initSomething() 8 | initAfterSomething() 9 | } 10 | 11 | func initBeforeSomething() { 12 | 13 | } 14 | 15 | func initSomething() { 16 | 17 | } 18 | 19 | func initAfterSomething() { 20 | 21 | } -------------------------------------------------------------------------------- /examples/forth_lesson/init/multi_init.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | // a file may declare several init functions; they run in the order they appear in the source 4 | func init() { 5 | // runs first 6 | } 7 | 8 | func init() { 9 | // runs second 10 | } 11 | 12 | // main is required for package main to build; every init function runs before it 13 | func main() {} 14 | -------------------------------------------------------------------------------- /examples/forth_lesson/select/select.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "fmt" 5 | "time" 6 | ) 7 | 8 | func main() { 9 | Select() 10 | } 11 | 12 | // Select waits on two channels and prints each message as it arrives. Exactly two 13 | // messages are sent, so it receives exactly twice: an endless for-select would 14 | // deadlock ("all goroutines are asleep") once both senders have finished 15 | func Select() { 16 | ch1 := make(chan string) 17 | ch2 := make(chan string) 18 | 19 | go func() { 20 | time.Sleep(time.Second) 21 | ch1 <- "msg from channel1" 22 | }() 23 | 24 | go func() { 25 | time.Sleep(time.Second) 26 | ch2 <- "msg from channel2" 27 | }() 28 | 29 | for i := 0; i < 2; i++ { 30 | select { 31 | case msg := <-ch1: 32 | fmt.Println(msg) 33 | case msg := <-ch2: 34 | fmt.Println(msg) 35 | } 36 | } 37 | } 38 |
-------------------------------------------------------------------------------- /examples/forth_lesson/static_resource/file_server.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import "net/http" 4 | 5 | func main() { 6 | serve := http.FileServer(http.Dir(".")) 7 | // alternatively: http.Handle("/", serve) and pass nil to ListenAndServe 8 | if err := http.ListenAndServe(":8080", serve); err != nil { 9 | panic(err) 10 | } 11 | } 12 | -------------------------------------------------------------------------------- /examples/second_lesson/afterclass/set.go: -------------------------------------------------------------------------------- 1 | package afterclass 2 | 3 | type Set interface { 4 | Put(key string) 5 | Keys() []string 6 | Contains(key string) bool 7 | Remove(key string) 8 | // PutIfAbsent: if key is already present, returns the existing value and absent = false; 9 | // otherwise stores key and returns absent = true 10 | PutIfAbsent(key string) (old string, absent bool) 11 | } 12 | -------------------------------------------------------------------------------- /examples/second_lesson/afterclass/tree.go: -------------------------------------------------------------------------------- 1 | package afterclass 2 | 3 | type Tree interface { 4 | 5 | } 6 | 7 | // binaryTree is a binary tree (empty placeholder for the homework) 8 | type binaryTree struct { 9 | 10 | } 11 | 12 | // multiWayTree is a multi-way (n-ary) tree (empty placeholder for the homework) 13 | type multiWayTree struct { 14 | 15 | } 16 | -------------------------------------------------------------------------------- /examples/second_lesson/composition/composition.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import "fmt" 4 | 5 | func main() { 6 | 7 | } 8 | 9 | // Swimming describes something that can swim 10 | type Swimming interface { 11 | Swim() 12 | } 13 | 14 | type Duck interface { 15 | // ducks can swim, so Duck embeds Swimming 16 | Swimming 17 | } 18 | 19 | 20 | type Base struct { 21 | Name string 22 | } 23 | 24 | type Concrete1 struct { 25 | Base 26 | } 27 | 28 | type Concrete2 struct { 29 | *Base 30 | } 31 | 32 | func (c Concrete1) SayHello() { 33 | // c.Name reaches Base's Name field directly through the embedding
34 | fmt.Printf("I am base and my name is: %s \n", c.Name) 35 | // going through the embedded field explicitly works too 36 | fmt.Printf("I am base and my name is: %s \n", c.Base.Name) 37 | 38 | // call the embedded Base's method explicitly 39 | c.Base.SayHello() 40 | } 41 | 42 | func (b *Base) SayHello() { 43 | fmt.Printf("I am base and my name is: %s \n", b.Name) 44 | } -------------------------------------------------------------------------------- /examples/second_lesson/composition/no_over_write.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import "fmt" 4 | 5 | func main() { 6 | son := Son{ 7 | Parent{}, 8 | } 9 | 10 | son.SayHello() 11 | } 12 | 13 | type Parent struct { 14 | 15 | } 16 | 17 | func (p Parent) SayHello() { 18 | fmt.Println("I am " + p.Name()) 19 | } 20 | 21 | func (p Parent) Name() string { 22 | return "Parent" 23 | } 24 | 25 | type Son struct { 26 | Parent 27 | } 28 | 29 | // Son defines its own Name() method. This is not an override: Parent.SayHello always calls Parent.Name, so son.SayHello() prints "I am Parent" 30 | func (s Son) Name() string { 31 | return "Son" 32 | } 33 | 34 | -------------------------------------------------------------------------------- /examples/second_lesson/http/request_body.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "encoding/json" 5 | "fmt" 6 | "io" 7 | "net/http" 8 | ) 9 | 10 | func home(w http.ResponseWriter, r *http.Request) { 11 | fmt.Fprint(w, "Hi, this is home page") 12 | } 13 | 14 | func readBodyOnce(w http.ResponseWriter, r *http.Request) { 15 | body, err := io.ReadAll(r.Body) 16 | if err != nil { 17 | fmt.Fprintf(w, "read body failed: %v", err) 18 | // remember to return, otherwise the code below keeps running 19 | return 20 | } 21 | // type conversion from []byte to string 22 | fmt.Fprintf(w, "read the data: %s \n", string(body)) 23 | 24 | // reading again yields nothing because the body has already been consumed, yet no error is reported 25 | body, err = io.ReadAll(r.Body) 26 | if err != nil { 27 | // normally not reached 28 | fmt.Fprintf(w, "read the data one more time got error: %v", err) 29 | return 30 | } 31 | fmt.Fprintf(w, "read the data one more time: [%s] and read data length %d \n", string(body), len(body)) 32 | } 33 | 34 |
35 | func getBodyIsNil(w http.ResponseWriter, r *http.Request) { 36 | if r.GetBody == nil { 37 | fmt.Fprint(w, "GetBody is nil \n") 38 | } else { 39 | fmt.Fprint(w, "GetBody not nil \n") 40 | } 41 | } 42 | 43 | func queryParams(w http.ResponseWriter, r *http.Request) { 44 | values := r.URL.Query() 45 | fmt.Fprintf(w, "query is %v\n", values) 46 | } 47 | 48 | func wholeUrl(w http.ResponseWriter, r *http.Request) { 49 | data, _ := json.Marshal(r.URL) 50 | fmt.Fprint(w, string(data)) // not Fprintf: the JSON is data, not a format string, and may contain '%' (e.g. an escaped query) 51 | } 52 | 53 | func header(w http.ResponseWriter, r *http.Request) { 54 | fmt.Fprintf(w, "header is %v\n", r.Header) 55 | } 56 | 57 | func form(w http.ResponseWriter, r *http.Request) { 58 | fmt.Fprintf(w, "before parse form %v\n", r.Form) 59 | err := r.ParseForm() 60 | if err != nil { 61 | fmt.Fprintf(w, "parse form error %v\n", err) 62 | return 63 | } 64 | fmt.Fprintf(w, "after parse form %v\n", r.Form) 65 | } 66 | 67 | func main() { 68 | http.HandleFunc("/", home) 69 | http.HandleFunc("/body/once", readBodyOnce) 70 | http.HandleFunc("/body/multi", getBodyIsNil) 71 | http.HandleFunc("/url/query", queryParams) 72 | http.HandleFunc("/header", header) 73 | http.HandleFunc("/wholeUrl", wholeUrl) 74 | http.HandleFunc("/form", form) 75 | if err := http.ListenAndServe(":8080", nil); err != nil { 76 | panic(err) 77 | } 78 | } -------------------------------------------------------------------------------- /examples/second_lesson/map/map.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import "fmt" 4 | 5 | func main() { 6 | // a map with an initial capacity hint of 2 7 | m := make(map[string]string, 2) 8 | // no capacity hint 9 | m1 := make(map[string]string) 10 | // initialized with a map literal 11 | m2 := map[string]string{ 12 | "Tom": "Jerry", 13 | } 14 | 15 | // assignment 16 | m["hello"] = "world" 17 | m1["hello"] = "world" 18 | // assignment 19 | m2["hello"] = "world" 20 | // lookup 21 | val := m["hello"] 22 | println(val) 23 | 24 | // comma-ok lookup: ok reports whether the key is present (val gets the zero value "" if it is not) 25 | val, ok := m["invalid_key"] 26 | if !ok {
27 | println("key not found") 28 | } 29 | 30 | for key, val := range m { 31 | fmt.Printf("%s => %s \n", key, val) 32 | } 33 | } -------------------------------------------------------------------------------- /examples/second_lesson/server_context/signup.go: -------------------------------------------------------------------------------- 1 | package server_context 2 | 3 | import ( 4 | "encoding/json" 5 | "fmt" 6 | "geektime/toy-web/pkg/v2" 7 | "io" 8 | "net/http" 9 | ) 10 | 11 | // SignUpWithoutContext shows what a handler looks like without a context abstraction 12 | func SignUpWithoutContext(w http.ResponseWriter, r *http.Request) { 13 | req := &signUpReq{} 14 | body, err := io.ReadAll(r.Body) 15 | if err != nil { 16 | fmt.Fprintf(w, "read body failed: %v", err) 17 | // return here, otherwise the code below keeps running 18 | return 19 | } 20 | err = json.Unmarshal(body, req) 21 | if err != nil { 22 | fmt.Fprintf(w, "deserialized failed: %v", err) 23 | // return here, otherwise the code below keeps running 24 | return 25 | } 26 | 27 | // write a fake user id to signal a successful sign-up (err is always nil at this point) 28 | fmt.Fprintf(w, "%d", 123) 29 | } 30 | 31 | func SignUpWithoutWrite(w http.ResponseWriter, r *http.Request) { 32 | c := webv2.NewContext(w, r) 33 | req := &signUpReq{} 34 | err := c.ReadJson(req) 35 | if err != nil { 36 | resp := &commonResponse{ 37 | BizCode: 4, // suppose biz code 4 means invalid input parameters 38 | Msg: fmt.Sprintf("invalid request: %v", err), 39 | } 40 | respBytes, _ := json.Marshal(resp) 41 | fmt.Fprint(w, string(respBytes)) 42 | return 43 | } 44 | // without a write helper, the success response has to be marshalled to JSON by hand all over again 45 | resp := &commonResponse{ 46 | // pretend this is the ID of the newly created user 47 | Data: 123, 48 | } 49 | respBytes, _ := json.Marshal(resp) // marshalling this fixed value cannot fail; the error is ignored as in the branch above 50 | fmt.Fprint(w, string(respBytes)) 51 | } 52 | 53 | type signUpReq struct { 54 | Email string `json:"email"` 55 | Password string `json:"password"` 56 | ConfirmedPassword string `json:"confirmed_password"` 57 | } 58 | 59 | type commonResponse struct { 60 | BizCode int `json:"biz_code"` 61 | Msg string `json:"msg"` 62 | Data interface{} `json:"data"` 63 | } 64 | -------------------------------------------------------------------------------- /examples/second_lesson/struct/intf.go:
-------------------------------------------------------------------------------- 1 | package main 2 | 3 | // animal starts with a lowercase letter, so it is a package-private interface 4 | type animal interface { 5 | // An interface may declare any number of methods, but small interfaces 6 | // (only a few methods each) are generally recommended. 7 | // Method declarations inside an interface do not use the func keyword. 8 | 9 | Eat() 10 | } 11 | 12 | // Duck starts with an uppercase letter, so it is exported (usable outside the package) 13 | type Duck interface { 14 | Swim() 15 | } 16 | -------------------------------------------------------------------------------- /examples/second_lesson/struct/pointer.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import "fmt" 4 | 5 | func main() { 6 | // pointer types are written with * 7 | var p *ToyDuck = &ToyDuck{} 8 | // dereferencing yields the struct value (duck is a copy of *p) 9 | var duck ToyDuck = *p 10 | duck.Swim() 11 | 12 | // declared but never assigned: it holds the pointer zero value, nil 13 | var nilDuck *ToyDuck 14 | if nilDuck == nil { 15 | fmt.Println("nilDuck is nil") 16 | } 17 | } -------------------------------------------------------------------------------- /examples/second_lesson/struct/receiver.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import "fmt" 4 | 5 | func main() { 6 | 7 | // u is a struct value. ChangeName (value receiver) works on a copy, so u.Name stays "Tom"; but u is addressable, so u.ChangeAge(100) is called as (&u).ChangeAge(100) and u.Age DOES become 100 8 | u := User{ 9 | Name: "Tom", 10 | Age: 10, 11 | } 12 | u.ChangeName("Tom Changed!") 13 | u.ChangeAge(100) 14 | fmt.Printf("%v \n", u) 15 | 16 | // up is a pointer, so pointer-receiver methods modify the User it points to 17 | up := &User{ 18 | Name: "Jerry", 19 | Age: 12, 20 | } 21 | 22 | // ChangeName has a value receiver: it is called on a copy of *up, 23 | // so up.Name is still unchanged (only ChangeAge below takes effect) 24 | up.ChangeName("Jerry Changed!") 25 | up.ChangeAge(120) 26 | 27 | fmt.Printf("%v \n", up) 28 | } 29 | 30 | type User struct { 31 | Name string 32 | Age int 33 | } 34 | 35 | // ChangeName has a value receiver: u is a copy, so the assignment is invisible to the caller 36 | func (u User) ChangeName(newName string) { 37 | u.Name = newName 38 | } 39 | 40 | // ChangeAge has a pointer receiver: it modifies the caller's User 41 | func (u *User) ChangeAge(newAge int) { 42 | u.Age = newAge 43 | } 44 | -------------------------------------------------------------------------------- /examples/second_lesson/struct/self_ref.go: -------------------------------------------------------------------------------- 1
| package main 2 | 3 | func main() { 4 | 5 | } 6 | 7 | type Node struct { 8 | //a self-referencing struct must use pointers; these fields would not compile: 9 | //left Node 10 | //right Node 11 | 12 | left *Node 13 | right *Node 14 | 15 | // this would not compile either: NodeNode contains a Node by value, forming an invalid recursive type 16 | // nn NodeNode 17 | } 18 | 19 | 20 | type NodeNode struct { 21 | node Node 22 | } -------------------------------------------------------------------------------- /examples/second_lesson/struct/struct.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import "fmt" 4 | 5 | func main() { 6 | // duck1 is a *ToyDuck 7 | duck1 := &ToyDuck{} 8 | duck1.Swim() 9 | 10 | duck2 := ToyDuck{} 11 | duck2.Swim() 12 | 13 | // duck3 is a *ToyDuck 14 | duck3 := new(ToyDuck) 15 | duck3.Swim() 16 | 17 | // a declaration like this gives duck4 the zero value of ToyDuck; 18 | // there is no nil-pointer risk because duck4 is not a pointer at all 19 | var duck4 ToyDuck 20 | duck4.Swim() 21 | 22 | // duck5 is a pointer, and since it is never assigned it is nil 23 | var duck5 *ToyDuck 24 | // this panics (Swim reads t.Color through a nil t), so nothing below this line runs 25 | duck5.Swim() 26 | 27 | // initialization by field name 28 | duck6 := ToyDuck{ 29 | Color: "黄色", 30 | Price: 100, 31 | } 32 | duck6.Swim() 33 | 34 | // initialization by field order; not recommended 35 | duck7 := ToyDuck{"蓝色", 1024} 36 | duck7.Swim() 37 | 38 | // fields can also be assigned one by one afterwards 39 | duck8 := ToyDuck{} 40 | duck8.Color = "橘色" 41 | 42 | } 43 | 44 | // ToyDuck is a toy duck. 45 | type ToyDuck struct { 46 | Color string 47 | Price uint64 48 | } 49 | 50 | func (t *ToyDuck) Swim() { 51 | fmt.Printf("门前一条河,游过一群鸭,我是%s,%d一只\n", t.Color, t.Price) 52 | } 53 | 54 | 55 | -------------------------------------------------------------------------------- /examples/second_lesson/struct/type_a_b.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import "fmt" 4 | 5 | func main() { 6 | fake := FakeFish{} 7 | // fake cannot call Fish's methods: a defined type does not inherit them, 8 | // so this line would not compile 9 | //fake.Swim() 10 | fake.FakeSwim() 11 | 12 | // convert to Fish 13 | td := Fish(fake) 14 | // now it really is a Fish 15 | td.Swim() 16 | 17 | sFake := StrongFakeFish{} 18 | // this calls StrongFakeFish's own Swim method 19 | sFake.Swim() 20 | 21 | td = Fish(sFake) 22 | // now it really is a Fish 23 | td.Swim() 24 | } 25 | 26 | //
定义了一个新类型,注意是新类型 27 | type FakeFish Fish 28 | 29 | func (f FakeFish) FakeSwim() { 30 | fmt.Printf("我是山寨鱼,嘎嘎嘎\n") 31 | } 32 | 33 | // 定义了一个新类型 34 | type StrongFakeFish Fish 35 | 36 | func (f StrongFakeFish) Swim() { 37 | fmt.Printf("我是华强北山寨鱼,嘎嘎嘎\n") 38 | } 39 | 40 | type Fish struct { 41 | } 42 | 43 | func (f Fish) Swim() { 44 | fmt.Printf("我是鱼,假装自己是一直鸭子\n") 45 | } 46 | -------------------------------------------------------------------------------- /examples/second_lesson/struct/type_a_et_b.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import "fmt" 4 | 5 | func main() { 6 | var n News = fakeNews{ 7 | Name: "hello", 8 | } 9 | n.Report() 10 | } 11 | 12 | type News struct { 13 | Name string 14 | } 15 | 16 | func (d News) Report() { 17 | fmt.Println("I am news: " + d.Name) 18 | } 19 | 20 | type fakeNews = News -------------------------------------------------------------------------------- /examples/third_lesson/closure/closure.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "fmt" 5 | "time" 6 | ) 7 | 8 | func main() { 9 | 10 | i := 13 11 | a := func() { 12 | fmt.Printf("i is %d \n", i) 13 | } 14 | a() 15 | 16 | fmt.Println(ReturnClosure("Tom")()) 17 | 18 | Delay() 19 | time.Sleep(time.Second) 20 | } 21 | 22 | func ReturnClosure(name string) func() string { 23 | return func() string { 24 | return "Hello, " + name 25 | } 26 | } 27 | 28 | func Delay() { 29 | fns := make([]func(), 0, 10) 30 | for i := 0; i < 10; i++ { 31 | fns = append(fns, func() { 32 | fmt.Printf("hello, this is : %d \n", i) 33 | }) 34 | } 35 | 36 | for _, fn := range fns { 37 | fn() 38 | } 39 | } 40 | -------------------------------------------------------------------------------- /examples/third_lesson/defer/defer.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import "fmt" 4 | 5 | func main() { 6 | defer 
func() { 7 | fmt.Println("aaa") 8 | }() 9 | 10 | defer func() { 11 | fmt.Println("bbb") 12 | }() 13 | 14 | defer func() { 15 | fmt.Println("ccc") 16 | }() 17 | } 18 | -------------------------------------------------------------------------------- /examples/third_lesson/errors/error.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "errors" 5 | "fmt" 6 | ) 7 | 8 | func main() { 9 | var err error = &MyError{} 10 | println(err.Error()) 11 | 12 | ErrorsPkg() 13 | } 14 | 15 | type MyError struct { 16 | } 17 | 18 | func (m *MyError) Error() string { 19 | return "Hello, it's my error" 20 | } 21 | 22 | func ErrorsPkg() { 23 | err := &MyError{} 24 | // 使用 %w 占位符,返回的是一个新错误 25 | // wrappedErr 是一个新类型,fmt.wrapError 26 | wrappedErr := fmt.Errorf("this is an wrapped error %w", err) 27 | 28 | // 再解出来 29 | if err == errors.Unwrap(wrappedErr) { 30 | fmt.Println("unwrapped") 31 | } 32 | 33 | if errors.Is(wrappedErr, err) { 34 | // 虽然被包了一下,但是 Is 会逐层解除包装,判断是不是该错误 35 | fmt.Println("wrapped is err") 36 | } 37 | 38 | copyErr := &MyError{} 39 | // 这里尝试将 wrappedErr转换为 MyError 40 | // 注意我们使用了两次的取地址符号 41 | if errors.As(wrappedErr, ©Err) { 42 | fmt.Println("convert error") 43 | } 44 | } 45 | 46 | -------------------------------------------------------------------------------- /examples/third_lesson/errors/panic.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import "fmt" 4 | 5 | func main() { 6 | defer func() { 7 | if data := recover(); data != nil { 8 | fmt.Printf("hello, panic: %v\n", data) 9 | } 10 | fmt.Println("恢复之后从这里继续执行") 11 | }() 12 | 13 | panic("Boom") 14 | fmt.Println("这里将不会执行下来") 15 | } 16 | -------------------------------------------------------------------------------- /examples/third_lesson/goroutine/goroutine.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "fmt" 5 | 
"time" 6 | ) 7 | 8 | func main() { 9 | GoRoutine() 10 | } 11 | 12 | func GoRoutine() { 13 | go func() { 14 | time.Sleep(10 * time.Second) 15 | }() 16 | // 这里直接输出,不会等待十秒 17 | fmt.Println("I am here") 18 | } -------------------------------------------------------------------------------- /examples/third_lesson/sync/map.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "fmt" 5 | "sync" 6 | ) 7 | 8 | func main() { 9 | m := sync.Map{} 10 | m.Store("cat", "Tom") 11 | m.Store("mouse", "Jerry") 12 | 13 | // 这里重新读取出来的,就是 14 | val, ok := m.Load("cat") 15 | if ok { 16 | fmt.Println(len(val.(string))) 17 | } 18 | } -------------------------------------------------------------------------------- /examples/third_lesson/sync/mutex.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "sync" 5 | ) 6 | 7 | var mutex sync.Mutex 8 | var rwMutex sync.RWMutex 9 | func Mutex() { 10 | mutex.Lock() 11 | defer mutex.Unlock() 12 | // 你的代码 13 | } 14 | 15 | func RwMutex() { 16 | // 加读锁 17 | rwMutex.RLock() 18 | defer rwMutex.RUnlock() 19 | 20 | // 也可以加写锁 21 | rwMutex.Lock() 22 | defer rwMutex.Unlock() 23 | } 24 | 25 | // 不可重入例子 26 | func Failed1() { 27 | mutex.Lock() 28 | defer mutex.Unlock() 29 | 30 | // 这一句会死锁 31 | // 但是如果你只有一个goroutine,那么这一个会导致程序崩溃 32 | mutex.Lock() 33 | defer mutex.Unlock() 34 | } 35 | 36 | // 不可升级 37 | func Failed2() { 38 | rwMutex.RLock() 39 | defer rwMutex.RUnlock() 40 | 41 | // 这一句会死锁 42 | // 但是如果你只有一个goroutine,那么这一个会导致程序崩溃 43 | mutex.Lock() 44 | defer mutex.Unlock() 45 | } 46 | -------------------------------------------------------------------------------- /examples/third_lesson/sync/once.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "fmt" 5 | "sync" 6 | ) 7 | 8 | func main() { 9 | PrintOnce() 10 | PrintOnce() 11 | PrintOnce() 12 | } 13 | 14 | var 
once sync.Once 15 | 16 | // 这个方法,不管调用几次,只会输出一次 17 | func PrintOnce() { 18 | once.Do(func() { 19 | fmt.Println("只输出一次") 20 | }) 21 | } 22 | -------------------------------------------------------------------------------- /examples/third_lesson/sync/pool.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import "sync" 4 | 5 | func main() { 6 | pool := sync.Pool{ 7 | New: func() interface{}{ 8 | return &user{} 9 | }} 10 | 11 | // Get 返回的是 interface{},所以需要类型断言 12 | u := pool.Get().(*user) 13 | // defer 还回去 14 | defer pool.Put(u) 15 | 16 | // 紧接着重置 u 这个对象 17 | u.Reset("Tom", "my_email@qq.com") 18 | 19 | // 下边就是使用 u 来完成你的业务逻辑 20 | } 21 | 22 | type user struct { 23 | Name string 24 | Email string 25 | } 26 | 27 | // 一般来说,复用对象都要求我们取出来之后, 28 | // 重置里面的字段 29 | func (u *user) Reset(name string, email string) { 30 | u.Email = email 31 | u.Name = name 32 | } -------------------------------------------------------------------------------- /examples/third_lesson/sync/wait_group.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "fmt" 5 | "sync" 6 | ) 7 | 8 | func main() { 9 | res := 0 10 | wg := sync.WaitGroup{} 11 | wg.Add(10) 12 | for i := 0; i < 10; i++ { 13 | go func(val int) { 14 | res += val 15 | wg.Done() 16 | }(i) 17 | } 18 | // 把这个注释掉你会发现,什么结果你都可能拿到 19 | wg.Wait() 20 | fmt.Println(res) 21 | } -------------------------------------------------------------------------------- /go.mod: -------------------------------------------------------------------------------- 1 | module geektime/toy-web 2 | 3 | go 1.16 4 | 5 | require ( 6 | github.com/hashicorp/golang-lru v0.5.4 // indirect 7 | github.com/stretchr/testify v1.7.0 // indirect 8 | ) 9 | -------------------------------------------------------------------------------- /go.sum: -------------------------------------------------------------------------------- 1 | github.com/davecgh/go-spew v1.1.0 
h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8= 2 | github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= 3 | github.com/hashicorp/golang-lru v0.5.4 h1:YDjusn29QI/Das2iO9M0BHnIbxPeyuCHsjMW+lJfyTc= 4 | github.com/hashicorp/golang-lru v0.5.4/go.mod h1:iADmTwqILo4mZ8BN3D2Q6+9jd8WM5uGBxy+E8yxSoD4= 5 | github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= 6 | github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= 7 | github.com/stretchr/objx v0.1.0 h1:4G4v2dO3VZwixGIRoQ5Lfboy6nUhCyYzaqnIAPPhYs4= 8 | github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= 9 | github.com/stretchr/testify v1.7.0 h1:nwc3DEeHmmLAfoZucVR881uASk0Mfjw8xYJ99tb5CcY= 10 | github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= 11 | gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= 12 | gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c h1:dUUwHk2QECo/6vqA44rthZ8ie2QXMNeKRTHCNY2nXvo= 13 | gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= 14 | -------------------------------------------------------------------------------- /main.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | "geektime/toy-web/demo" 7 | _ "geektime/toy-web/demo/filters" 8 | "geektime/toy-web/pkg" 9 | "net/http" 10 | "time" 11 | ) 12 | 13 | func home(w http.ResponseWriter, r *http.Request) { 14 | fmt.Fprintf(w, "这是主页") 15 | } 16 | 17 | func user(w http.ResponseWriter, r *http.Request) { 18 | fmt.Fprintf(w, "这是用户") 19 | } 20 | 21 | func createUser(w http.ResponseWriter, r *http.Request) { 22 | fmt.Fprintf(w, "这是创建用户") 23 | } 24 | 25 | func order(w http.ResponseWriter, r *http.Request) { 26 | fmt.Fprintf(w, "这是订单") 27 | } 28 | 29 | func main() { 30 | shutdown
:= web.NewGracefulShutdown() 31 | server := web.NewSdkHttpServer("my-test-server", 32 | web.MetricFilterBuilder, shutdown.ShutdownFilterBuilder) 33 | adminServer := web.NewSdkHttpServer("admin-test-server", 34 | // Note: in a real deployment with several servers listening on different ports, 35 | // it is better to give each server its own GracefulShutdown so they do not contend with each other. 36 | // MetricFilterBuilder is stateless, so it has no such problem. 37 | web.MetricFilterBuilder, shutdown.ShutdownFilterBuilder) 38 | 39 | // register routes 40 | _ = server.Route("POST", "/user/create/*", demo.SignUp) 41 | _ = server.Route("POST", "/slowService", demo.SlowService) 42 | 43 | // static resource route 44 | 45 | staticHandler := web.NewStaticResourceHandler( 46 | "demo/static", "/static", 47 | web.WithMoreExtension(map[string]string{ 48 | "mp3": "audio/mp3", 49 | }), web.WithFileCache(1 << 20, 100)) 50 | // try: GET http://localhost:8080/static/forest.png 51 | _ = server.Route("GET", "/static/*", staticHandler.ServeStaticResource) 52 | 53 | go func() { 54 | if err := adminServer.Start(":8081"); err != nil { 55 | panic(err) 56 | } 57 | }() 58 | 59 | go func() { 60 | if err := server.Start(":8080"); err != nil { 61 | // fail fast: nothing can work if the server did not start 62 | panic(err) 63 | } 64 | // assume more work would follow here 65 | }() 66 | 67 | // Hooks run in order: notify the gateway, RejectNewRequestAndWaiting (reject new requests and 68 | // wait for in-flight ones), close the servers (BuildCloseServerHook closes them concurrently), then release resources. 69 | // 70 | web.WaitForShutdown( 71 | func(ctx context.Context) error { 72 | // suppose this hook 73 | // notifies the gateway that we are going offline 74 | fmt.Println("mock notify gateway") 75 | time.Sleep(time.Second * 2) 76 | return nil 77 | }, 78 | shutdown.RejectNewRequestAndWaiting, 79 | // once every request has finished, the servers can be closed 80 | web.BuildCloseServerHook(server, adminServer), 81 | func(ctx context.Context) error { 82 | // suppose we clean up temporary resources created while running 83 | fmt.Println("mock release resources") 84 | time.Sleep(time.Second * 2) 85 | return nil 86 | }) 87 | 88 | // filterNames := ReadFromConfig 89 | // thanks to the blank import of demo/filters, filters can be looked up by name here 90 | //web.NewSdkHttpServerWithFilterNames("my-server", filterNames...)
91 | 92 | } 93 | 94 | 95 | 96 | 97 | -------------------------------------------------------------------------------- /onclass/main.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "fmt" 5 | "net/http" 6 | ) 7 | 8 | func home(w http.ResponseWriter, r *http.Request) { 9 | fmt.Fprintf(w, "这是主页") 10 | } 11 | 12 | func user(w http.ResponseWriter, r *http.Request) { 13 | fmt.Fprintf(w, "这是用户") 14 | } 15 | 16 | func createUser(w http.ResponseWriter, r *http.Request) { 17 | fmt.Fprintf(w, "这是创建用户") 18 | } 19 | 20 | func order(w http.ResponseWriter, r *http.Request) { 21 | fmt.Fprintf(w, "这是订单") 22 | } 23 | 24 | 25 | func main() { 26 | http.HandleFunc("/", home) 27 | http.HandleFunc("/user", user) 28 | http.HandleFunc("/user/create", createUser) 29 | http.HandleFunc("/order", order) 30 | if err := http.ListenAndServe(":8080", nil); err != nil { 31 | panic(err) // ListenAndServe only returns on failure, e.g. when the port is already in use 32 | } 33 | } 34 | 35 | type Server interface { 36 | Route(pattern string, handlerFunc http.HandlerFunc) 37 | Start(address string) error 38 | } 39 | 40 | type sdkHttpServer struct { 41 | Name string 42 | } 43 | -------------------------------------------------------------------------------- /pkg/context.go: -------------------------------------------------------------------------------- 1 | package web 2 | 3 | import ( 4 | "encoding/json" 5 | "fmt" 6 | "io" 7 | "net/http" 8 | ) 9 | 10 | type Context struct { 11 | W http.ResponseWriter 12 | R *http.Request 13 | PathParams map[string]string 14 | } 15 | 16 | func (c *Context) ReadJson(data interface{}) error { 17 | body, err := io.ReadAll(c.R.Body) 18 | if err != nil { 19 | return err 20 | } 21 | return json.Unmarshal(body, data) 22 | } 23 | func (c *Context) OkJson(data interface{}) error { 24 | // net/http predefines constants for the status codes 25 | return c.WriteJson(http.StatusOK, data) 26 | } 27 | 28 | func (c *Context) SystemErrJson(data interface{}) error { 29 | // net/http predefines constants for the status codes 30 | return c.WriteJson(http.StatusInternalServerError, data) 31 | } 32 | 33 |
func (c *Context) BadRequestJson(data interface{}) error { 34 | // net/http predefines constants for the status codes 35 | return c.WriteJson(http.StatusBadRequest, data) 36 | } 37 | 38 | func (c *Context) WriteJson(status int, data interface{}) error { 39 | // marshal first: if it fails, nothing has been written yet and the caller can still send an error response 40 | bs, err := json.Marshal(data) 41 | if err != nil { 42 | return err 43 | } 44 | // headers must be set before WriteHeader, otherwise they are silently dropped 45 | c.W.Header().Set("Content-Type", "application/json") 46 | c.W.WriteHeader(status) 47 | _, err = c.W.Write(bs) 48 | return err 49 | } 50 | 51 | func NewContext(w http.ResponseWriter, r *http.Request) *Context { 52 | return &Context{ 53 | W: w, 54 | R: r, 55 | // there is usually at most one path parameter, so capacity 1 is enough 56 | PathParams: make(map[string]string, 1), 57 | } 58 | } 59 | 60 | func newContext() *Context { 61 | fmt.Println("create new context") 62 | return &Context{ 63 | } 64 | } 65 | 66 | func (c *Context) Reset(w http.ResponseWriter, r *http.Request) { 67 | c.W = w 68 | c.R = r 69 | c.PathParams = make(map[string]string, 1) 70 | } -------------------------------------------------------------------------------- /pkg/filter.go: -------------------------------------------------------------------------------- 1 | package web 2 | 3 | import ( 4 | "fmt" 5 | "time" 6 | ) 7 | 8 | type FilterBuilder func(next Filter) Filter 9 | 10 | type Filter func(c *Context) 11 | 12 | func MetricFilterBuilder(next Filter) Filter { 13 | return func(c *Context) { 14 | // time before running the rest of the chain 15 | startTime := time.Now().UnixNano() 16 | next(c) 17 | // time after it returns 18 | endTime := time.Now().UnixNano() 19 | fmt.Printf("run time: %d \n", endTime-startTime) 20 | } 21 | } 22 | 23 | var builderMap = make(map[string]FilterBuilder, 4) 24 | func RegisterFilter(name string, builder FilterBuilder) { 25 | // Case 1: if duplicate registration must be rejected, check builderMap for name first 26 | // Case 2: builderMap is a plain map; if this can be called concurrently, guard it with a sync.RWMutex 27 | builderMap[name] = builder 28 | } 29 | 30 | func GetFilterBuilder(name string) FilterBuilder { 31 | // returns nil for an unregistered name; callers that require a valid name must check for nil 32 | return builderMap[name] 33 | } -------------------------------------------------------------------------------- /pkg/graceful_shutdown.go:
-------------------------------------------------------------------------------- 1 | package web 2 | 3 | import ( 4 | "context" 5 | "errors" 6 | "fmt" 7 | "net/http" 8 | "os" 9 | "os/signal" 10 | "sync/atomic" 11 | "time" 12 | ) 13 | 14 | var ErrorHookTimeout = errors.New("the hook timeout") 15 | 16 | type GracefulShutdown struct { 17 | // number of requests currently being processed 18 | reqCnt int64 19 | // > 0 once shutdown has started; it only ever increases 20 | closing int32 21 | 22 | // signals that in-flight requests dropped to zero after shutdown started; buffered (cap 1) so the notifying request never blocks 23 | zeroReqCnt chan struct{} 24 | } 25 | 26 | func NewGracefulShutdown() *GracefulShutdown { 27 | return &GracefulShutdown{ 28 | zeroReqCnt: make(chan struct{}, 1), 29 | } 30 | } 31 | 32 | // ShutdownFilterBuilder returns a Filter that counts in-flight requests and, 33 | // once RejectNewRequestAndWaiting has been called, answers new ones with 503. 34 | // The request is counted BEFORE closing is checked: otherwise a request that 35 | // read closing == 0 could still be running after the shutdown side saw 36 | // reqCnt == 0 and returned. closing only ever goes up, so no lock is needed, 37 | // but both fields must be accessed atomically (a plain bool would be a data race). 38 | // closing is re-read after the request finishes, because the earlier value is stale, 39 | // and the zero-count notification is a non-blocking send that never hangs a request. 40 | func (g *GracefulShutdown) ShutdownFilterBuilder(next Filter) Filter { 41 | return func(c *Context) { 42 | atomic.AddInt64(&g.reqCnt, 1) 43 | defer func() { 44 | // last request out after shutdown started: wake up RejectNewRequestAndWaiting 45 | if atomic.AddInt64(&g.reqCnt, -1) == 0 && atomic.LoadInt32(&g.closing) > 0 { 46 | select { 47 | case g.zeroReqCnt <- struct{}{}: 48 | default: 49 | } 50 | } 51 | }() 52 | if atomic.LoadInt32(&g.closing) > 0 { 53 | c.W.WriteHeader(http.StatusServiceUnavailable) 54 | return 55 | } 56 | next(c) 57 | } 58 | } 59 | 60 | // RejectNewRequestAndWaiting makes the filter reject new requests, then waits until in-flight requests finish or ctx is done 61 | func (g *GracefulShutdown) RejectNewRequestAndWaiting(ctx context.Context) error { 62 | 63 | atomic.AddInt32(&g.closing, 1) 64 | 65 | // special case: all requests had already finished before shutdown started 66 | if atomic.LoadInt64(&g.reqCnt) == 0 { 67 | return nil 68 | } 69 | 70 | done := ctx.Done() 71 | // no surrounding for loop is needed because the transition is one-way: 72 | // once triggered, we never return to normal request processing. 73 | // This select ends either when ctx times out 74 | // or when all in-flight requests have finished 75 | select { 76 | case <- done: 77 |
fmt.Println("超时了,还没等到所有请求执行完毕") 78 | return ErrorHookTimeout 79 | case <- g.zeroReqCnt: 80 | fmt.Println("全部请求处理完了") 81 | } 82 | return nil 83 | } 84 | // WaitForShutdown blocks until one of ShutdownSignals arrives, then runs hooks in order (each with a 30-second timeout) and exits the process 85 | func WaitForShutdown(hooks...Hook) { 86 | signals := make(chan os.Signal, 1) 87 | signal.Notify(signals, ShutdownSignals...) 88 | select { 89 | case sig := <-signals: 90 | fmt.Printf("get signal %s, application will shutdown \n", sig) 91 | // if the hooks still have not finished after ten minutes, force the process to exit 92 | time.AfterFunc(time.Minute * 10, func() { 93 | fmt.Printf("Shutdown gracefully timeout, application will shutdown immediately. ") 94 | os.Exit(1) 95 | }) 96 | for _, h := range hooks { 97 | ctx, cancel := context.WithTimeout(context.Background(), time.Second * 30) 98 | err := h(ctx) 99 | if err != nil { 100 | fmt.Printf("failed to run hook, err: %v \n", err) 101 | } 102 | cancel() 103 | } 104 | os.Exit(0) 105 | } 106 | } 107 | -------------------------------------------------------------------------------- /pkg/graceful_shutdown_signal_darwin.go: -------------------------------------------------------------------------------- 1 | /* 2 | * Licensed to the Apache Software Foundation (ASF) under one or more 3 | * contributor license agreements. See the NOTICE file distributed with 4 | * this work for additional information regarding copyright ownership. 5 | * The ASF licenses this file to You under the Apache License, Version 2.0 6 | * (the "License"); you may not use this file except in compliance with 7 | * the License. You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License.
16 | */ 17 | 18 | package web 19 | 20 | import ( 21 | "os" 22 | "syscall" 23 | ) 24 | 25 | var ( 26 | // ShutdownSignals lists the signals WaitForShutdown subscribes to via signal.Notify. Per the os/signal docs, os.Kill/SIGKILL and SIGSTOP can never be caught, so listing them has no effect 27 | ShutdownSignals = []os.Signal{ 28 | os.Interrupt, os.Kill, syscall.SIGKILL, syscall.SIGSTOP, 29 | syscall.SIGHUP, syscall.SIGINT, syscall.SIGQUIT, syscall.SIGILL, syscall.SIGTRAP, 30 | syscall.SIGABRT, syscall.SIGSYS, syscall.SIGTERM, 31 | } 32 | ) 33 | -------------------------------------------------------------------------------- /pkg/graceful_shutdown_signal_linux.go: -------------------------------------------------------------------------------- 1 | /* 2 | * Licensed to the Apache Software Foundation (ASF) under one or more 3 | * contributor license agreements. See the NOTICE file distributed with 4 | * this work for additional information regarding copyright ownership. 5 | * The ASF licenses this file to You under the Apache License, Version 2.0 6 | * (the "License"); you may not use this file except in compliance with 7 | * the License. You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License.
16 | */ 17 | 18 | package web 19 | 20 | import ( 21 | "os" 22 | "syscall" 23 | ) 24 | 25 | var ( 26 | // ShutdownSignals lists the signals WaitForShutdown subscribes to via signal.Notify. Per the os/signal docs, os.Kill/SIGKILL and SIGSTOP can never be caught, so listing them has no effect 27 | ShutdownSignals = []os.Signal{ 28 | os.Interrupt, os.Kill, syscall.SIGKILL, syscall.SIGSTOP, 29 | syscall.SIGHUP, syscall.SIGINT, syscall.SIGQUIT, syscall.SIGILL, syscall.SIGTRAP, 30 | syscall.SIGABRT, syscall.SIGSYS, syscall.SIGTERM, 31 | } 32 | 33 | // DumpHeapShutdownSignals is a subset of ShutdownSignals (QUIT, ILL, TRAP, ABRT, SYS). It is not referenced in the visible code; presumably these are the signals that should trigger a heap dump (TODO confirm) 34 | DumpHeapShutdownSignals = []os.Signal{ 35 | syscall.SIGQUIT, syscall.SIGILL, 36 | syscall.SIGTRAP, syscall.SIGABRT, syscall.SIGSYS, 37 | } 38 | ) 39 | -------------------------------------------------------------------------------- /pkg/graceful_shutdown_signal_windows.go: -------------------------------------------------------------------------------- 1 | /* 2 | * Licensed to the Apache Software Foundation (ASF) under one or more 3 | * contributor license agreements. See the NOTICE file distributed with 4 | * this work for additional information regarding copyright ownership. 5 | * The ASF licenses this file to You under the Apache License, Version 2.0 6 | * (the "License"); you may not use this file except in compliance with 7 | * the License. You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License.
16 | */ 17 | 18 | package web 19 | 20 | import ( 21 | "os" 22 | "syscall" 23 | ) 24 | 25 | var ( 26 | // ShutdownSignals lists the signals WaitForShutdown subscribes to via signal.Notify. Per the os/signal docs, os.Kill/SIGKILL can never be caught, so listing it has no effect 27 | ShutdownSignals = []os.Signal{ 28 | os.Interrupt, os.Kill, syscall.SIGKILL, 29 | syscall.SIGHUP, syscall.SIGINT, syscall.SIGQUIT, syscall.SIGILL, syscall.SIGTRAP, 30 | syscall.SIGABRT, syscall.SIGTERM, 31 | } 32 | ) -------------------------------------------------------------------------------- /pkg/handler.go: -------------------------------------------------------------------------------- 1 | package web 2 | 3 | type Handler interface { 4 | ServeHTTP(c *Context) 5 | Routable 6 | } 7 | 8 | type handlerFunc func(c *Context) -------------------------------------------------------------------------------- /pkg/hook.go: -------------------------------------------------------------------------------- 1 | package web 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | "sync" 7 | "time" 8 | ) 9 | 10 | // Hook is a shutdown hook function. Note that 11 | // ctx is a context.Context with a timeout, 12 | // so the hook must handle that timeout itself 13 | type Hook func(ctx context.Context) error 14 | 15 | // BuildCloseServerHook shuts all servers down concurrently. errgroup could be used here, 16 | // but it is avoided on purpose so that each server is shut down independently 17 | // and the servers do not affect each other 18 | func BuildCloseServerHook(servers ...Server) Hook { 19 | return func(ctx context.Context) error { 20 | wg := sync.WaitGroup{} 21 | doneCh := make(chan struct{}) 22 | wg.Add(len(servers)) 23 | 24 | for _, s := range servers { 25 | go func(svr Server) { 26 | defer wg.Done() 27 | err := svr.Shutdown(ctx) 28 | if err != nil { 29 | fmt.Printf("server shutdown error: %v \n", err) 30 | } 31 | time.Sleep(time.Second) 32 | }(s) 33 | } 34 | go func() { 35 | wg.Wait() 36 | close(doneCh) // close instead of send: never blocks, so this goroutine cannot leak when ctx times out first 37 | }() 38 | select { 39 | case <- ctx.Done(): 40 | fmt.Printf("closing servers timeout \n") 41 | return ErrorHookTimeout 42 | case <- doneCh: 43 | fmt.Printf("close all servers \n") 44 | return nil 45 | } 46 | } 47 | } 48 | --------------------------------------------------------------------------------
/pkg/hook_test.go: -------------------------------------------------------------------------------- 1 | package web 2 | 3 | import ( 4 | "context" 5 | "github.com/stretchr/testify/assert" 6 | "testing" 7 | "time" 8 | ) 9 | 10 | func TestBuildCloseServerHook(t *testing.T) { 11 | svr := NewSdkHttpServer("test-sever") 12 | h := BuildCloseServerHook(svr, svr, svr, svr, svr) 13 | ctx, cancel := context.WithTimeout(context.Background(), time.Second * 10) 14 | defer cancel() 15 | err := h(ctx) 16 | assert.Nil(t, err) 17 | // 10ms is far shorter than each server's Shutdown (which sleeps for a second), so this call must time out 18 | ctx, cancel = context.WithTimeout(context.Background(), time.Millisecond * 10) 19 | defer cancel() 20 | err = h(ctx) 21 | assert.Equal(t, ErrorHookTimeout, err) 22 | } 23 | -------------------------------------------------------------------------------- /pkg/map_router.go: -------------------------------------------------------------------------------- 1 | package web 2 | 3 | import ( 4 | "fmt" 5 | "net/http" 6 | "sync" 7 | ) 8 | 9 | // A common Go idiom: a compile-time check 10 | // that *HandlerBasedOnMap implements Handler 11 | var _ Handler = &HandlerBasedOnMap{} 12 | 13 | // HandlerBasedOnMap routes on an exact "METHOD#path" key built by key(); wildcards and path parameters are not supported 14 | type HandlerBasedOnMap struct { 15 | handlers sync.Map 16 | } 17 | 18 | func (h *HandlerBasedOnMap) ServeHTTP(c *Context) { 19 | request := c.R 20 | key := h.key(request.Method, request.URL.Path) 21 | handler, ok := h.handlers.Load(key) 22 | if !ok { 23 | c.W.WriteHeader(http.StatusNotFound) 24 | _, _ = c.W.Write([]byte("not any router match")) 25 | return 26 | } 27 | 28 | handler.(handlerFunc)(c) 29 | } 30 | 31 | func (h *HandlerBasedOnMap) Route(method string, pattern string, 32 | handlerFunc handlerFunc) error { 33 | key := h.key(method, pattern) 34 | h.handlers.Store(key, handlerFunc) 35 | return nil 36 | } 37 | 38 | func (h *HandlerBasedOnMap) key(method string, 39 | path string) string { 40 | return fmt.Sprintf("%s#%s", method, path) 41 | } 42 | 43 | func NewHandlerBasedOnMap() *HandlerBasedOnMap { 44 | return &HandlerBasedOnMap{} 45 | } 46 | --------------------------------------------------------------------------------
-------------------------------------------------------------------------------- /pkg/server.go: -------------------------------------------------------------------------------- 1 | package web 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | "net/http" 7 | "sync" 8 | "time" 9 | ) 10 | 11 | // Routable is anything that routes can be registered on 12 | type Routable interface { 13 | // Route registers handlerFunc; requests matching method and pattern will run it 14 | Route(method string, pattern string, handlerFunc handlerFunc) error 15 | } 16 | 17 | // Server is the top-level abstraction of an HTTP server 18 | type Server interface { 19 | Routable 20 | // Start starts the server on address 21 | Start(address string) error 22 | 23 | Shutdown(ctx context.Context) error 24 | } 25 | 26 | // sdkHttpServer is a Server implemented on top of the net/http package 27 | type sdkHttpServer struct { 28 | // Name identifies the server, e.g. in log output 29 | Name string 30 | handler Handler 31 | root Filter 32 | ctxPool sync.Pool 33 | } 34 | 35 | func (s *sdkHttpServer) Route(method string, pattern string, 36 | handlerFunc handlerFunc) error { 37 | return s.handler.Route(method, pattern, handlerFunc) 38 | } 39 | 40 | func (s *sdkHttpServer) Start(address string) error { 41 | return http.ListenAndServe(address, s) 42 | } 43 | 44 | func (s *sdkHttpServer) ServeHTTP(writer http.ResponseWriter, request *http.Request) { 45 | c := s.ctxPool.Get().(*Context) 46 | defer func() { 47 | s.ctxPool.Put(c) 48 | }() 49 | c.Reset(writer, request) 50 | s.root(c) 51 | } 52 | 53 | func (s *sdkHttpServer) Shutdown(ctx context.Context) error { 54 | // this toy framework has nothing to clean up, so it just sleeps to simulate the work; 55 | // NOTE: the listener opened by Start (http.ListenAndServe) is NOT closed here 56 | fmt.Printf("%s shutdown...\n", s.Name) 57 | time.Sleep(time.Second) 58 | fmt.Printf("%s shutdown!!!\n", s.Name) 59 | return nil 60 | } 61 | 62 | func NewSdkHttpServer(name string, builders ...FilterBuilder) Server { 63 | 64 | // use the tree-based router 65 | handler := NewHandlerBasedOnTree() 66 | //handler := NewHandlerBasedOnMap() 67 | // the filters form a chain; the router's ServeHTTP is its last link 68 | var root Filter = handler.ServeHTTP 69 | // wrap from the last builder to the first, so builders[0] ends up outermost 70 | for i := len(builders) - 1; i >= 0; i-- {
71 | b := builders[i] 72 | root = b(root) 73 | } 74 | res := &sdkHttpServer{ 75 | Name: name, 76 | handler: handler, 77 | root: root, 78 | ctxPool: sync.Pool{New: func() interface {}{ 79 | return newContext() 80 | }}, 81 | } 82 | return res 83 | } 84 | 85 | // NewSdkHttpServerWithFilterNames builds a server from builders registered via RegisterFilter, 86 | // applied in the given order; it panics if a name has not been registered 87 | func NewSdkHttpServerWithFilterNames(name string, 88 | filterNames...string) Server { 89 | builders := make([]FilterBuilder, 0, len(filterNames)) 90 | for _, n := range filterNames { 91 | b := GetFilterBuilder(n) 92 | if b == nil { 93 | // fail fast with a clear message instead of a nil-function panic inside NewSdkHttpServer 94 | panic(fmt.Sprintf("web: no FilterBuilder registered under name %q", n)) 95 | } 96 | builders = append(builders, b) 97 | } 98 | 99 | return NewSdkHttpServer(name, builders...) 100 | } 101 | 102 | -------------------------------------------------------------------------------- /pkg/static_resource.go: -------------------------------------------------------------------------------- 1 | package web 2 | 3 | import ( 4 | "fmt" 5 | lru "github.com/hashicorp/golang-lru" 6 | "io/ioutil" 7 | "net/http" 8 | "os" 9 | "path/filepath" 10 | "strings" 11 | ) 12 | 13 | type StaticResourceHandlerOption func(h *StaticResourceHandler) 14 | 15 | type StaticResourceHandler struct { 16 | dir string 17 | pathPrefix string 18 | extensionContentTypeMap map[string]string 19 | 20 | // limits for caching static resources 21 | cache *lru.Cache 22 | maxFileSize int 23 | } 24 | 25 | type fileCacheItem struct { 26 | fileName string 27 | fileSize int 28 | contentType string 29 | data []byte 30 | } 31 | 32 | func NewStaticResourceHandler(dir string, pathPrefix string, 33 | options...StaticResourceHandlerOption) *StaticResourceHandler { 34 | res := &StaticResourceHandler{ 35 | dir: dir, 36 | pathPrefix: pathPrefix, 37 | extensionContentTypeMap: map[string]string{ 38 | // add more extensions here as needed 39 | "jpeg": "image/jpeg", 40 | "jpe": "image/jpeg", 41 | "jpg": "image/jpeg", 42 | "png": "image/png", 43 | "pdf": "application/pdf", 44 | }, 45 | } 46 | 47 | for _, o := range options { 48 | o(res) 49 | } 50 | return res 51 | } 52 | // WithFileCache enables caching of static files 53 | // maxFileSizeThreshold: files of at least this size count as big files and are not cached 54 | // maxCacheFileCnt: maximum number of files to cache 55 |
// 所以我们最多缓存 maxFileSizeThreshold * maxCacheFileCnt 56 | func WithFileCache(maxFileSizeThreshold int, maxCacheFileCnt int) StaticResourceHandlerOption { 57 | return func(h *StaticResourceHandler) { 58 | c, err := lru.New(maxCacheFileCnt) 59 | if err != nil { 60 | fmt.Printf("could not create LRU, we won't cache static file") 61 | } 62 | h.maxFileSize = maxFileSizeThreshold 63 | h.cache = c 64 | } 65 | } 66 | 67 | func WithMoreExtension(extMap map[string]string) StaticResourceHandlerOption { 68 | return func(h *StaticResourceHandler) { 69 | for ext, contentType := range extMap { 70 | h.extensionContentTypeMap[ext] = contentType 71 | } 72 | } 73 | } 74 | 75 | func (h *StaticResourceHandler) ServeStaticResource(c *Context) { 76 | req := strings.TrimPrefix(c.R.URL.Path, h.pathPrefix) 77 | if item, ok := h.readFileFromData(req); ok { 78 | fmt.Printf("read data from cache...") 79 | h.writeItemAsResponse(item, c.W) 80 | return 81 | } 82 | path := filepath.Join(h.dir, req) 83 | f, err := os.Open(path) 84 | if err != nil { 85 | c.W.WriteHeader(http.StatusInternalServerError) 86 | return 87 | } 88 | ext := getFileExt(f.Name()) 89 | t, ok := h.extensionContentTypeMap[ext] 90 | if !ok { 91 | c.W.WriteHeader(http.StatusBadRequest) 92 | return 93 | } 94 | 95 | data, err := ioutil.ReadAll(f) 96 | if err != nil { 97 | c.W.WriteHeader(http.StatusInternalServerError) 98 | return 99 | } 100 | item := &fileCacheItem{ 101 | fileSize: len(data), 102 | data: data, 103 | contentType: t, 104 | fileName: req, 105 | } 106 | 107 | h.cacheFile(item) 108 | h.writeItemAsResponse(item, c.W) 109 | 110 | } 111 | 112 | func (h *StaticResourceHandler) cacheFile(item *fileCacheItem) { 113 | if h.cache != nil && item.fileSize < h.maxFileSize { 114 | h.cache.Add(item.fileName, item) 115 | } 116 | } 117 | 118 | func (h *StaticResourceHandler) writeItemAsResponse(item *fileCacheItem, writer http.ResponseWriter) { 119 | writer.WriteHeader(http.StatusOK) 120 | writer.Header().Set("Content-Type", 
item.contentType) 121 | writer.Header().Set("Content-Length", fmt.Sprintf("%d", item.fileSize)) 122 | _, _ = writer.Write(item.data) 123 | 124 | } 125 | 126 | func (h *StaticResourceHandler) readFileFromData(fileName string) (*fileCacheItem, bool) { 127 | if h.cache != nil { 128 | if item, ok := h.cache.Get(fileName); ok { 129 | return item.(*fileCacheItem), true 130 | } 131 | } 132 | return nil, false 133 | } 134 | 135 | func getFileExt(name string) string { 136 | index := strings.LastIndex(name, ".") 137 | if index == len(name) - 1{ 138 | return "" 139 | } 140 | return name[index+1:] 141 | } 142 | -------------------------------------------------------------------------------- /pkg/tree_node.go: -------------------------------------------------------------------------------- 1 | package web 2 | 3 | import ( 4 | "strings" 5 | ) 6 | 7 | const ( 8 | 9 | // 根节点,只有根用这个 10 | nodeTypeRoot = iota 11 | 12 | // * 13 | nodeTypeAny 14 | 15 | // 路径参数 16 | nodeTypeParam 17 | 18 | // 正则 19 | nodeTypeReg 20 | 21 | // 静态,即完全匹配 22 | nodeTypeStatic 23 | ) 24 | 25 | const any = "*" 26 | 27 | // matchFunc 承担两个职责,一个是判断是否匹配,一个是在匹配之后 28 | // 将必要的数据写入到 Context 29 | // 所谓必要的数据,这里基本上是指路径参数 30 | type matchFunc func(path string, c *Context) bool 31 | 32 | type node struct { 33 | children []*node 34 | 35 | // 如果这是叶子节点, 36 | // 那么匹配上之后就可以调用该方法 37 | handler handlerFunc 38 | matchFunc matchFunc 39 | 40 | // 原始的 pattern。注意,它不是完整的pattern, 41 | // 而是匹配到这个节点的pattern 42 | pattern string 43 | nodeType int 44 | } 45 | 46 | // 静态节点 47 | func newStaticNode(path string) *node { 48 | return &node{ 49 | children: make([]*node, 0, 2), 50 | matchFunc: func(p string, c *Context) bool { 51 | return path == p && p != "*" 52 | }, 53 | nodeType: nodeTypeStatic, 54 | pattern: path, 55 | } 56 | } 57 | 58 | 59 | func newRootNode(method string) *node { 60 | return &node{ 61 | children: make([]*node, 0, 2), 62 | matchFunc: func( p string, c *Context) bool { 63 | panic("never call me") 64 | }, 65 | nodeType: 
nodeTypeRoot, 66 | pattern: method, 67 | } 68 | } 69 | 70 | func newNode(path string) *node { 71 | if path == "*"{ 72 | return newAnyNode() 73 | } 74 | if strings.HasPrefix(path, ":") { 75 | return newParamNode(path) 76 | } 77 | return newStaticNode(path) 78 | } 79 | 80 | // 通配符 * 节点 81 | func newAnyNode() *node { 82 | return &node{ 83 | // 因为我们不允许 * 后面还有节点,所以这里可以不用初始化 84 | //children: make([]*node, 0, 2), 85 | matchFunc: func(p string, c *Context) bool { 86 | return true 87 | }, 88 | nodeType: nodeTypeAny, 89 | pattern: any, 90 | } 91 | } 92 | 93 | // 路径参数节点 94 | func newParamNode(path string) *node { 95 | paramName := path[1:] 96 | return &node{ 97 | children: make([]*node, 0, 2), 98 | matchFunc: func(p string, c *Context) bool { 99 | if c != nil { 100 | c.PathParams[paramName] = p 101 | } 102 | // 如果自身是一个参数路由, 103 | // 然后又来一个通配符,我们认为是不匹配的 104 | return p != any 105 | }, 106 | nodeType: nodeTypeParam, 107 | pattern: path, 108 | } 109 | } 110 | 111 | // 正则节点 112 | //func newRegNode(path string) *node { 113 | // // 依据你的规则拿到正则表达式 114 | // return &node{ 115 | // children: make([]*node, 0, 2), 116 | // matchFunc: func(p string, c *Context) bool { 117 | // // 怎么写? 
118 | // }, 119 | // nodeType: nodeTypeParam, 120 | // pattern: path, 121 | // } 122 | //} 123 | 124 | -------------------------------------------------------------------------------- /pkg/tree_router.go: -------------------------------------------------------------------------------- 1 | package web 2 | 3 | import ( 4 | "errors" 5 | "net/http" 6 | "sort" 7 | "strings" 8 | ) 9 | 10 | var ErrorInvalidRouterPattern = errors.New("invalid router pattern") 11 | var ErrorInvalidMethod = errors.New("invalid method") 12 | 13 | type HandlerBasedOnTree struct { 14 | forest map[string]*node 15 | } 16 | 17 | var supportMethods = [4]string { 18 | http.MethodGet, 19 | http.MethodPost, 20 | http.MethodPut, 21 | http.MethodDelete, 22 | } 23 | 24 | func NewHandlerBasedOnTree() Handler { 25 | forest := make(map[string]*node, len(supportMethods)) 26 | for _, m :=range supportMethods { 27 | forest[m] = newRootNode(m) 28 | } 29 | return &HandlerBasedOnTree{ 30 | forest: forest, 31 | } 32 | } 33 | 34 | // ServeHTTP 就是从树里面找节点 35 | // 找到了就执行 36 | func (h *HandlerBasedOnTree) ServeHTTP(c *Context) { 37 | handler, found := h.findRouter(c.R.Method, c.R.URL.Path, c) 38 | if !found { 39 | c.W.WriteHeader(http.StatusNotFound) 40 | _, _ = c.W.Write([]byte("Not Found")) 41 | return 42 | } 43 | handler(c) 44 | } 45 | 46 | func (h *HandlerBasedOnTree) findRouter(method, path string, c *Context) (handlerFunc, bool) { 47 | // 去除头尾可能有的/,然后按照/切割成段 48 | paths := strings.Split(strings.Trim(path, "/"), "/") 49 | cur, ok := h.forest[method] 50 | if !ok { 51 | return nil, false 52 | } 53 | for _, p := range paths { 54 | // 从子节点里边找一个匹配到了当前 p 的节点 55 | matchChild, found := h.findMatchChild(cur, p, c) 56 | if !found { 57 | return nil, false 58 | } 59 | cur = matchChild 60 | } 61 | // 到这里,应该是找完了 62 | if cur.handler == nil { 63 | // 到达这里是因为这种场景 64 | // 比如说你注册了 /user/profile 65 | // 然后你访问 /user 66 | return nil, false 67 | } 68 | return cur.handler, true 69 | } 70 | 71 | // Route 就相当于往树里面插入节点 72 | func (h
*HandlerBasedOnTree) Route(method string, pattern string, 73 | handlerFunc handlerFunc) error { 74 | 75 | err := h.validatePattern(pattern) 76 | if err != nil { 77 | return err 78 | } 79 | 80 | // 将pattern按照URL的分隔符切割 81 | // 例如,/user/friends 将变成 [user, friends] 82 | // 将前后的/去掉,统一格式 83 | pattern = strings.Trim(pattern, "/") 84 | paths := strings.Split(pattern, "/") 85 | 86 | // 当前指向根节点 87 | cur, ok := h.forest[method] 88 | if !ok { 89 | return ErrorInvalidMethod 90 | } 91 | for index, path := range paths { 92 | 93 | // 从子节点里边找一个匹配到了当前 path 的节点 94 | matchChild, found := h.findMatchChild(cur, path, nil) 95 | // Reuse a child only when it was registered with exactly this segment: matchFunc alone lets an existing :id (or *) node match a new static segment such as "detail", which would then overwrite the :id handler. 96 | if found && matchChild.pattern == path { 97 | cur = matchChild 98 | } else { 99 | // 为当前节点根据 100 | h.createSubTree(cur, paths[index:], handlerFunc) 101 | return nil 102 | } 103 | } 104 | // 离开了循环,说明我们加入的是短路径, 105 | // 比如说我们先加入了 /order/detail 106 | // 再加入/order,那么会走到这里 107 | cur.handler = handlerFunc 108 | return nil 109 | } 110 | 111 | func (h *HandlerBasedOnTree) validatePattern(pattern string) error { 112 | // 校验 *,如果存在,必须在最后一个,并且它前面必须是/ 113 | // 即我们只接受 /* 的存在,abc*这种是非法 114 | 115 | pos := strings.Index(pattern, "*") 116 | // 找到了 * (pos == 0 included, so a pattern such as "*abc" is rejected as well) 117 | if pos >= 0 { 118 | // 必须是最后一个 119 | if pos != len(pattern) - 1 { 120 | return ErrorInvalidRouterPattern 121 | } 122 | if pos > 0 && pattern[pos-1] != '/' { 123 | return ErrorInvalidRouterPattern 124 | } 125 | } 126 | return nil 127 | } 128 | 129 | func (h *HandlerBasedOnTree) findMatchChild(root *node, 130 | path string, c *Context) (*node, bool) { 131 | candidates := make([]*node, 0, 2) 132 | for _, child := range root.children { 133 | if child.matchFunc(path, c) { 134 | candidates = append(candidates, child) 135 | } 136 | } 137 | 138 | if len(candidates) == 0 { 139 | return nil, false 140 | } 141 | 142 | // type 也决定了它们的优先级 143 | sort.Slice(candidates, func(i, j int) bool { 144 | return candidates[i].nodeType < candidates[j].nodeType 145 | }) 146 | return
candidates[len(candidates) - 1], true 147 | } 148 | 149 | func (h *HandlerBasedOnTree) createSubTree(root *node, paths []string, handlerFn handlerFunc) { 150 | cur := root 151 | for _, path := range paths { 152 | nn := newNode(path) 153 | cur.children = append(cur.children, nn) 154 | cur = nn 155 | } 156 | cur.handler = handlerFn 157 | } 158 | 159 | -------------------------------------------------------------------------------- /pkg/tree_router_test.go: -------------------------------------------------------------------------------- 1 | package web 2 | 3 | import ( 4 | "github.com/stretchr/testify/assert" 5 | "net/http" 6 | "reflect" 7 | "testing" 8 | ) 9 | 10 | func TestHandlerBasedOnTree_Route(t *testing.T) { 11 | handler := NewHandlerBasedOnTree().(*HandlerBasedOnTree) 12 | // 要确认已经为支持的方法创建了节点 13 | assert.Equal(t, len(supportMethods), len(handler.forest)) 14 | 15 | postNode := handler.forest[http.MethodPost] 16 | 17 | err := handler.Route(http.MethodPost, "/user", func(c *Context) {}) 18 | assert.Nil(t, err) 19 | assert.Equal(t, 1, len(postNode.children)) 20 | 21 | n := postNode.children[0] 22 | assert.NotNil(t, n) 23 | assert.Equal(t, "user", n.pattern) 24 | assert.NotNil(t, n.handler) 25 | assert.Empty(t, n.children) 26 | 27 | // 我们只有 28 | // user -> profile 29 | err = handler.Route(http.MethodPost, "/user/profile", func(c *Context) {}) 30 | assert.Nil(t, err) 31 | assert.Equal(t, 1, len(n.children)) 32 | profileNode := n.children[0] 33 | assert.NotNil(t, profileNode) 34 | assert.Equal(t, "profile", profileNode.pattern) 35 | assert.NotNil(t, profileNode.handler) 36 | assert.Empty(t, profileNode.children) 37 | 38 | // 试试重复 39 | err = handler.Route(http.MethodPost, "/user", func(c *Context) {}) 40 | assert.Nil(t, err) 41 | n = postNode.children[0] 42 | assert.NotNil(t, n) 43 | assert.Equal(t, "user", n.pattern) 44 | assert.NotNil(t, n.handler) 45 | // 有profile节点 46 | assert.Equal(t, 1, len(n.children)) 47 | 48 | // 给 user 再加一个节点 49 | err =
handler.Route(http.MethodPost, "/user/home", func(c *Context) {}) 50 | assert.Nil(t, err) 51 | assert.Equal(t, 2, len(n.children)) 52 | homeNode := n.children[1] 53 | assert.NotNil(t, homeNode) 54 | assert.Equal(t, "home", homeNode.pattern) 55 | assert.NotNil(t, homeNode.handler) 56 | assert.Empty(t, homeNode.children) 57 | 58 | // 添加 /order/detail 59 | err = handler.Route(http.MethodPost, "/order/detail", func(c *Context) {}) 60 | assert.Equal(t, 2, len(postNode.children)) 61 | orderNode := postNode.children[1] 62 | assert.NotNil(t, orderNode) 63 | assert.Equal(t, "order", orderNode.pattern) 64 | // 此刻我们只有/order/detail,但是没有/order 65 | assert.Nil(t, orderNode.handler) 66 | assert.Equal(t, 1, len(orderNode.children)) 67 | 68 | orderDetailNode := orderNode.children[0] 69 | assert.NotNil(t, orderDetailNode) 70 | assert.Empty(t, orderDetailNode.children) 71 | assert.Equal(t, "detail", orderDetailNode.pattern) 72 | assert.NotNil(t, orderDetailNode.handler) 73 | 74 | // 加一个 /order 75 | err = handler.Route(http.MethodPost, "/order", func(c *Context) {}) 76 | assert.Nil(t, err) 77 | assert.Equal(t, 2, len(postNode.children)) 78 | orderNode = postNode.children[1] 79 | assert.Equal(t, "order", orderNode.pattern) 80 | // 此时我们有了 /order 81 | assert.NotNil(t, orderNode.handler) 82 | 83 | err = handler.Route(http.MethodPost, "/order/*", func(c *Context) {}) 84 | assert.Nil(t, err) 85 | assert.Equal(t, 2, len(orderNode.children)) 86 | orderWildcard := orderNode.children[1] 87 | assert.NotNil(t, orderWildcard) 88 | assert.NotNil(t, orderWildcard.handler) 89 | assert.Equal(t, "*", orderWildcard.pattern) 90 | 91 | err = handler.Route(http.MethodPost, "/order/*/checkout", func(c *Context) {}) 92 | assert.Equal(t, ErrorInvalidRouterPattern, err) 93 | 94 | err = handler.Route(http.MethodConnect, "/order/checkout", func(c *Context) {}) 95 | assert.Equal(t, ErrorInvalidMethod, err) 96 | 97 | err = handler.Route(http.MethodPost, "/order/:id", func(c *Context){}) 98 | assert.Nil(t, err) 99
| // 这时候我们有/order/* 和 /order/:id 100 | // 因为我们并没有认为它们不兼容,而是/order/:id优先 101 | assert.Equal(t, 3, len(orderNode.children)) 102 | orderParamNode := orderNode.children[2] 103 | assert.Equal(t, ":id", orderParamNode.pattern) 104 | 105 | } 106 | 107 | func TestHandlerBasedOnTree_findRouter(t *testing.T) { 108 | handler := NewHandlerBasedOnTree().(*HandlerBasedOnTree) 109 | _ = handler.Route(http.MethodPost, "/user", func(c *Context) {}) 110 | ctx := NewContext(nil, nil) 111 | fn, found := handler.findRouter(http.MethodPost, "/user", ctx) 112 | assert.True(t, found) 113 | assert.NotNil(t, fn) 114 | _, found = handler.findRouter(http.MethodPost,"/user/profile", ctx) 115 | assert.False(t, found) 116 | 117 | _ = handler.Route(http.MethodPost, "/user/profile", func(c *Context) {}) 118 | _, found = handler.findRouter(http.MethodPost, "/user/profile", ctx) 119 | assert.True(t, found) 120 | 121 | _, found = handler.findRouter(http.MethodPost, "/user", ctx) 122 | assert.True(t, found) 123 | 124 | var detailHandler handlerFunc = func(c *Context) {} 125 | _ = handler.Route(http.MethodPost, "/order/detail", detailHandler) 126 | _, found = handler.findRouter(http.MethodPost,"/order", ctx) 127 | assert.False(t, found) 128 | 129 | fn, found = handler.findRouter(http.MethodPost,"/order/detail", ctx) 130 | assert.True(t, found) 131 | assert.True(t, handlerFuncEquals(detailHandler, fn)) 132 | 133 | var wildcardHandler handlerFunc = func(c *Context) {} 134 | _ = handler.Route(http.MethodPost, "/order/*", wildcardHandler) 135 | _, found = handler.findRouter(http.MethodPost,"/order", ctx) 136 | assert.False(t, found) 137 | 138 | fn, found = handler.findRouter(http.MethodPost,"/order/detail", ctx) 139 | assert.True(t, found) 140 | assert.True(t, handlerFuncEquals(detailHandler, fn)) 141 | 142 | fn, found = handler.findRouter(http.MethodPost,"/order/checkout", ctx) 143 | assert.True(t, found) 144 | assert.True(t, handlerFuncEquals(wildcardHandler, fn)) 145 | 146 | _, found =
handler.findRouter(http.MethodGet,"/order/checkout", ctx) 147 | assert.False(t, found) 148 | 149 | // 参数路由: a static segment registered after /item/:id must get its own node instead of overwriting the :id handler 150 | var idHandler handlerFunc = func(c *Context) {} 151 | var itemDetailHandler handlerFunc = func(c *Context) {} 152 | assert.Nil(t, handler.Route(http.MethodPost, "/item/:id", idHandler)) 153 | assert.Nil(t, handler.Route(http.MethodPost, "/item/detail", itemDetailHandler)) 154 | // a nil *Context keeps the lookup from writing path params 155 | fn, found = handler.findRouter(http.MethodPost, "/item/detail", nil) 156 | assert.True(t, found) 157 | assert.True(t, handlerFuncEquals(itemDetailHandler, fn)) 158 | fn, found = handler.findRouter(http.MethodPost, "/item/123", nil) 159 | assert.True(t, found) 160 | assert.True(t, handlerFuncEquals(idHandler, fn)) 161 | } 162 | 163 | func handlerFuncEquals(hf1 handlerFunc, hf2 handlerFunc) bool { 164 | return reflect.ValueOf(hf1).Pointer() == reflect.ValueOf(hf2).Pointer() 165 | } -------------------------------------------------------------------------------- /pkg/v1/context.go: -------------------------------------------------------------------------------- 1 | package webv1 2 | 3 | import ( 4 | "encoding/json" 5 | "io" 6 | "net/http" 7 | ) 8 | 9 | type Context struct { 10 | W http.ResponseWriter 11 | R *http.Request 12 | } 13 | 14 | func (c *Context) ReadJson(data interface{}) error { 15 | body, err := io.ReadAll(c.R.Body) 16 | if err != nil { 17 | return err 18 | } 19 | return json.Unmarshal(body, data) 20 | } 21 | func (c *Context) OkJson(data interface{}) error { 22 | // http 库里面提前定义好了各种响应码 23 | return c.WriteJson(http.StatusOK, data) 24 | } 25 | 26 | func (c *Context) SystemErrJson(data interface{}) error { 27 | // http 库里面提前定义好了各种响应码 28 | return c.WriteJson(http.StatusInternalServerError, data) 29 | } 30 | 31 | func (c *Context) BadRequestJson(data interface{}) error { 32 | // http 库里面提前定义好了各种响应码 33 | return c.WriteJson(http.StatusBadRequest, data) 34 | } 35 | 36 | func (c *Context) WriteJson(status int, data interface{}) error { 37 | bs, err := json.Marshal(data) 38 | if err != nil { 39 | return err 40 | } 41 | c.W.WriteHeader(status) // must come before Write: the first Write implicitly commits a 200 OK status 42 | _, err = c.W.Write(bs) 43 | if err != nil { 44 | return err 45 | } 46 | return nil 47 | } 48 | 49 | func NewContext(w http.ResponseWriter, r *http.Request) *Context { 50 | return &Context{ 51 | W: w, 52 | R: r, 53 | } 54 | } -------------------------------------------------------------------------------- /pkg/v1/filter.go: -------------------------------------------------------------------------------- 1 | package webv1 2 | 3 | import ( 4 | "fmt" 5 | "time" 6 |
) 7 | 8 | type FilterBuilder func(next Filter) Filter 9 | 10 | type Filter func(c *Context) 11 | 12 | func MetricFilterBuilder(next Filter) Filter { 13 | return func(c *Context) { 14 | // 执行前的时间 15 | startTime := time.Now().UnixNano() 16 | next(c) 17 | // 执行后的时间 18 | endTime := time.Now().UnixNano() 19 | fmt.Printf("run time: %d \n", endTime-startTime) 20 | } 21 | } -------------------------------------------------------------------------------- /pkg/v1/handler.go: -------------------------------------------------------------------------------- 1 | package webv1 2 | 3 | type Handler interface { 4 | ServeHTTP(c *Context) 5 | Routable 6 | } 7 | 8 | type handlerFunc func(c *Context) -------------------------------------------------------------------------------- /pkg/v1/map_router.go: -------------------------------------------------------------------------------- 1 | package webv1 2 | 3 | import ( 4 | "fmt" 5 | "net/http" 6 | "sync" 7 | ) 8 | 9 | // 一种常用的GO设计模式, 10 | // 用于确保HandlerBasedOnMap肯定实现了这个接口 11 | var _ Handler = &HandlerBasedOnMap{} 12 | 13 | 14 | type HandlerBasedOnMap struct { 15 | handlers sync.Map 16 | } 17 | 18 | func (h *HandlerBasedOnMap) ServeHTTP(c *Context) { 19 | request := c.R 20 | key := h.key(request.Method, request.URL.Path) 21 | handler, ok := h.handlers.Load(key) 22 | if !ok { 23 | c.W.WriteHeader(http.StatusNotFound) 24 | _, _ = c.W.Write([]byte("not any router match")) 25 | return 26 | } 27 | 28 | handler.(handlerFunc)(c) 29 | } 30 | 31 | func (h *HandlerBasedOnMap) Route(method string, pattern string, 32 | handlerFunc handlerFunc) { 33 | key := h.key(method, pattern) 34 | h.handlers.Store(key, handlerFunc) 35 | } 36 | 37 | func (h *HandlerBasedOnMap) key(method string, 38 | path string) string { 39 | return fmt.Sprintf("%s#%s", method, path) 40 | } 41 | 42 | func NewHandlerBasedOnMap() *HandlerBasedOnMap { 43 | return &HandlerBasedOnMap{} 44 | } 45 | -------------------------------------------------------------------------------- 
/pkg/v1/server.go: -------------------------------------------------------------------------------- 1 | package webv1 2 | 3 | import ( 4 | "net/http" 5 | ) 6 | 7 | // Routable 可路由的 8 | type Routable interface { 9 | // Route 设定一个路由,命中该路由的会执行handlerFunc的代码 10 | Route(method string, pattern string, handlerFunc handlerFunc) 11 | } 12 | 13 | // Server 是http server 的顶级抽象 14 | type Server interface { 15 | Routable 16 | // Start 启动我们的服务器 17 | Start(address string) error 18 | } 19 | 20 | // sdkHttpServer 这个是基于 net/http 这个包实现的 http server 21 | type sdkHttpServer struct { 22 | // Name server 的名字,给个标记,日志输出的时候用得上 23 | Name string 24 | handler Handler 25 | root Filter 26 | } 27 | 28 | func (s *sdkHttpServer) Route(method string, pattern string, 29 | handlerFunc handlerFunc) { 30 | s.handler.Route(method, pattern, handlerFunc) 31 | } 32 | 33 | func (s *sdkHttpServer) Start(address string) error { 34 | http.HandleFunc("/", func(writer http.ResponseWriter, 35 | request *http.Request) { 36 | c := NewContext(writer, request) 37 | s.root(c) 38 | }) 39 | return http.ListenAndServe(address, nil) 40 | } 41 | 42 | func NewSdkHttpServer(name string, builders ...FilterBuilder) Server { 43 | 44 | // 改用我们的树 45 | handler := NewHandlerBasedOnTree() 46 | //handler := NewHandlerBasedOnMap() 47 | // 因为我们是一个链,所以我们把最后的业务逻辑处理,也作为一环 48 | var root Filter = handler.ServeHTTP 49 | // 从后往前把filter串起来 50 | for i := len(builders) - 1; i >= 0; i-- { 51 | b := builders[i] 52 | root = b(root) 53 | } 54 | res := &sdkHttpServer{ 55 | Name: name, 56 | handler: handler, 57 | root: root, 58 | } 59 | return res 60 | } 61 | 62 | -------------------------------------------------------------------------------- /pkg/v1/tree_router.go: -------------------------------------------------------------------------------- 1 | package webv1 2 | 3 | import ( 4 | "net/http" 5 | "strings" 6 | ) 7 | 8 | type HandlerBasedOnTree struct { 9 | root *node 10 | 11 | } 12 | 13 | func NewHandlerBasedOnTree() Handler { 14 | return
&HandlerBasedOnTree{ 15 | root: &node{}, 16 | } 17 | } 18 | 19 | // ServeHTTP 就是从树里面找节点 20 | // 找到了就执行 21 | func (h *HandlerBasedOnTree) ServeHTTP(c *Context) { 22 | handler, found := h.findRouter(c.R.URL.Path) 23 | if !found { 24 | c.W.WriteHeader(http.StatusNotFound) 25 | _, _ = c.W.Write([]byte("Not Found")) 26 | return 27 | } 28 | handler(c) 29 | } 30 | 31 | // 这个是不好测试的版本,可以尝试为这个写单元测试 32 | // 会发现很难构造 request,也很难对 ResponseWriter做断言 33 | //func (h *HandlerBasedOnTree) ServeHTTP(c *Context) { 34 | // url := strings.Trim(c.R.URL.Path, "/") 35 | // paths := strings.Split(url, "/") 36 | // cur := h.root 37 | // for _, path := range paths { 38 | // // 从子节点里边找一个匹配到了当前 path 的节点 39 | // matchChild, found := h.findMatchChild(cur, path) 40 | // if !found { 41 | // // 找不到匹配的路径,直接返回 42 | // c.W.WriteHeader(404) 43 | // _, _ = c.W.Write([]byte("Not Found")) 44 | // return 45 | // } 46 | // cur = matchChild 47 | // } 48 | // // 到这里,应该是找完了 49 | // if cur.handler == nil { 50 | // // 到达这里是因为这种场景 51 | // // 比如说你注册了 /user/profile 52 | // // 然后你访问 /user 53 | // c.W.WriteHeader(404) 54 | // _, _ = c.W.Write([]byte("Not Found")) 55 | // return 56 | // } 57 | // cur.handler(c) 58 | //} 59 | 60 | func (h *HandlerBasedOnTree) findRouter(path string) (handlerFunc, bool) { 61 | // 去除头尾可能有的/,然后按照/切割成段 62 | paths := strings.Split(strings.Trim(path, "/"), "/") 63 | cur := h.root 64 | for _, p := range paths { 65 | // 从子节点里边找一个匹配到了当前 path 的节点 66 | matchChild, found := h.findMatchChild(cur, p) 67 | if !found { 68 | return nil, false 69 | } 70 | cur = matchChild 71 | } 72 | // 到这里,应该是找完了 73 | if cur.handler == nil { 74 | // 到达这里是因为这种场景 75 | // 比如说你注册了 /user/profile 76 | // 然后你访问 /user 77 | return nil, false 78 | } 79 | return cur.handler, true 80 | } 81 | 82 | // Route 就相当于往树里面插入节点 83 | func (h *HandlerBasedOnTree) Route(method string, pattern string, 84 | handlerFunc handlerFunc) { 85 | // 将pattern按照URL的分隔符切割 86 | // 例如,/user/friends 将变成 [user, friends] 87 | // 将前后的/去掉,统一格式 88 | pattern =
strings.Trim(pattern, "/") 89 | paths := strings.Split(pattern, "/") 90 | // 当前指向根节点 91 | cur := h.root 92 | for index, path := range paths { 93 | // 从子节点里边找一个匹配到了当前 path 的节点 94 | matchChild, found := h.findMatchChild(cur, path) 95 | if found { 96 | cur = matchChild 97 | } else { 98 | // 为当前节点根据 99 | h.createSubTree(cur, paths[index:], handlerFunc) 100 | return 101 | } 102 | } 103 | // 离开了循环,说明我们加入的是短路径, 104 | // 比如说我们先加入了 /order/detail 105 | // 再加入/order,那么会走到这里 106 | cur.handler = handlerFunc 107 | } 108 | 109 | func (h *HandlerBasedOnTree) findMatchChild(root *node, path string) (*node, bool) { 110 | for _, child := range root.children { 111 | if child.path == path { 112 | return child, true 113 | } 114 | } 115 | return nil, false 116 | } 117 | 118 | func (h *HandlerBasedOnTree) createSubTree(root *node, paths []string, handlerFn handlerFunc) { 119 | cur := root 120 | for _, path := range paths { 121 | nn := newNode(path) 122 | cur.children = append(cur.children, nn) 123 | cur = nn 124 | } 125 | cur.handler = handlerFn 126 | } 127 | 128 | type node struct { 129 | path string 130 | children []*node 131 | 132 | // 如果这是叶子节点, 133 | // 那么匹配上之后就可以调用该方法 134 | handler handlerFunc 135 | } 136 | 137 | func newNode(path string) *node { 138 | return &node{ 139 | path: path, 140 | children: make([]*node, 0, 2), 141 | } 142 | } 143 | -------------------------------------------------------------------------------- /pkg/v1/tree_router_test.go: -------------------------------------------------------------------------------- 1 | package webv1 2 | 3 | import ( 4 | "github.com/stretchr/testify/assert" 5 | "net/http" 6 | "testing" 7 | ) 8 | 9 | func TestHandlerBasedOnTree_Route(t *testing.T) { 10 | handler := NewHandlerBasedOnTree().(*HandlerBasedOnTree) 11 | assert.NotNil(t, handler.root) 12 | 13 | handler.Route(http.MethodPost, "/user", func(c *Context) {}) 14 | 15 | // 开始做断言,这个时候我们应该确认,在根节点之下只有一个user节点 16 | assert.Equal(t, 1, len(handler.root.children)) 17 | 18 | n :=
handler.root.children[0] 19 | assert.NotNil(t, n) 20 | assert.Equal(t, "user", n.path) 21 | assert.NotNil(t, n.handler) 22 | assert.Empty(t, n.children) 23 | 24 | // 我们只有 25 | // user -> profile 26 | handler.Route(http.MethodPost, "/user/profile", func(c *Context) {}) 27 | assert.Equal(t, 1, len(n.children)) 28 | profileNode := n.children[0] 29 | assert.NotNil(t, profileNode) 30 | assert.Equal(t, "profile", profileNode.path) 31 | assert.NotNil(t, profileNode.handler) 32 | assert.Empty(t, profileNode.children) 33 | 34 | // 试试重复 35 | handler.Route(http.MethodPost, "/user", func(c *Context) {}) 36 | n = handler.root.children[0] 37 | assert.NotNil(t, n) 38 | assert.Equal(t, "user", n.path) 39 | assert.NotNil(t, n.handler) 40 | // 有profile节点 41 | assert.Equal(t, 1, len(n.children)) 42 | 43 | // 给 user 再加一个节点 44 | handler.Route(http.MethodPost, "/user/home", func(c *Context) {}) 45 | assert.Equal(t, 2, len(n.children)) 46 | homeNode := n.children[1] 47 | assert.NotNil(t, homeNode) 48 | assert.Equal(t, "home", homeNode.path) 49 | assert.NotNil(t, homeNode.handler) 50 | assert.Empty(t, homeNode.children) 51 | 52 | // 添加 /order/detail 53 | handler.Route(http.MethodPost, "/order/detail", func(c *Context) {}) 54 | assert.Equal(t, 2, len(handler.root.children)) 55 | orderNode := handler.root.children[1] 56 | assert.NotNil(t, orderNode) 57 | assert.Equal(t, "order", orderNode.path) 58 | // 此刻我们只有/order/detail,但是没有/order 59 | assert.Nil(t, orderNode.handler) 60 | assert.Equal(t, 1, len(orderNode.children)) 61 | 62 | orderDetailNode := orderNode.children[0] 63 | assert.NotNil(t, orderDetailNode) 64 | assert.Empty(t, orderDetailNode.children) 65 | assert.Equal(t, "detail", orderDetailNode.path) 66 | assert.NotNil(t, orderDetailNode.handler) 67 | 68 | // 加一个 /order 69 | handler.Route(http.MethodPost, "/order", func(c *Context) {}) 70 | assert.Equal(t, 2, len(handler.root.children)) 71 | orderNode = handler.root.children[1] 72 | assert.Equal(t, "order", orderNode.path) 73 | //
此时我们有了 /order 74 | assert.NotNil(t, orderNode.handler) 75 | 76 | } 77 | 78 | func TestHandlerBasedOnTree_findRouter(t *testing.T) { 79 | handler := NewHandlerBasedOnTree().(*HandlerBasedOnTree) 80 | handler.Route(http.MethodPost, "/user", func(c *Context) {}) 81 | _, found := handler.findRouter("/user") 82 | assert.True(t, found) 83 | _, found = handler.findRouter("/user/profile") 84 | assert.False(t, found) 85 | 86 | handler.Route(http.MethodPost, "/user/profile", func(c *Context) {}) 87 | _, found = handler.findRouter("/user/profile") 88 | assert.True(t, found) 89 | 90 | _, found = handler.findRouter("/user") 91 | assert.True(t, found) 92 | 93 | handler.Route(http.MethodPost, "/order/detail", func(c *Context) {}) 94 | _, found = handler.findRouter("/order") 95 | assert.False(t, found) 96 | 97 | _, found = handler.findRouter("/order/detail") 98 | assert.True(t, found) 99 | 100 | handler.Route(http.MethodPost, "/order", func(c *Context) {}) 101 | _, found = handler.findRouter("/order") 102 | assert.True(t, found) 103 | } -------------------------------------------------------------------------------- /pkg/v2/context.go: -------------------------------------------------------------------------------- 1 | package webv2 2 | 3 | import ( 4 | "encoding/json" 5 | "io" 6 | "net/http" 7 | ) 8 | 9 | type Context struct { 10 | W http.ResponseWriter 11 | R *http.Request 12 | } 13 | 14 | func (c *Context) ReadJson(data interface{}) error { 15 | body, err := io.ReadAll(c.R.Body) 16 | if err != nil { 17 | return err 18 | } 19 | return json.Unmarshal(body, data) 20 | } 21 | func (c *Context) OkJson(data interface{}) error { 22 | // http 库里面提前定义好了各种响应码 23 | return c.WriteJson(http.StatusOK, data) 24 | } 25 | 26 | func (c *Context) SystemErrJson(data interface{}) error { 27 | // http 库里面提前定义好了各种响应码 28 | return c.WriteJson(http.StatusInternalServerError, data) 29 | } 30 | 31 | func (c *Context) BadRequestJson(data interface{}) error { 32 | // http 库里面提前定义好了各种响应码 33 | return
c.WriteJson(http.StatusBadRequest, data) 34 | } 35 | 36 | func (c *Context) WriteJson(status int, data interface{}) error { 37 | bs, err := json.Marshal(data) 38 | if err != nil { 39 | return err 40 | } 41 | c.W.WriteHeader(status) // must come before Write: the first Write implicitly commits a 200 OK status 42 | _, err = c.W.Write(bs) 43 | if err != nil { 44 | return err 45 | } 46 | return nil 47 | } 48 | 49 | func NewContext(w http.ResponseWriter, r *http.Request) *Context { 50 | return &Context{ 51 | W: w, 52 | R: r, 53 | } 54 | } -------------------------------------------------------------------------------- /pkg/v2/filter.go: -------------------------------------------------------------------------------- 1 | package webv2 2 | 3 | import ( 4 | "fmt" 5 | "time" 6 | ) 7 | 8 | type FilterBuilder func(next Filter) Filter 9 | 10 | type Filter func(c *Context) 11 | 12 | func MetricFilterBuilder(next Filter) Filter { 13 | return func(c *Context) { 14 | // 执行前的时间 15 | startTime := time.Now().UnixNano() 16 | next(c) 17 | // 执行后的时间 18 | endTime := time.Now().UnixNano() 19 | fmt.Printf("run time: %d \n", endTime-startTime) 20 | } 21 | } -------------------------------------------------------------------------------- /pkg/v2/handler.go: -------------------------------------------------------------------------------- 1 | package webv2 2 | 3 | type Handler interface { 4 | ServeHTTP(c *Context) 5 | Routable 6 | } 7 | 8 | type handlerFunc func(c *Context) -------------------------------------------------------------------------------- /pkg/v2/map_router.go: -------------------------------------------------------------------------------- 1 | package webv2 2 | 3 | import ( 4 | "fmt" 5 | "net/http" 6 | "sync" 7 | ) 8 | 9 | // 一种常用的GO设计模式, 10 | // 用于确保HandlerBasedOnMap肯定实现了这个接口 11 | var _ Handler = &HandlerBasedOnMap{} 12 | 13 | 14 | type HandlerBasedOnMap struct { 15 | handlers sync.Map 16 | } 17 | 18 | func (h *HandlerBasedOnMap) ServeHTTP(c *Context) { 19 | request := c.R 20 | key := h.key(request.Method, request.URL.Path) 21 | handler, ok :=
h.handlers.Load(key) 22 | if !ok { 23 | c.W.WriteHeader(http.StatusNotFound) 24 | _, _ = c.W.Write([]byte("not any router match")) 25 | return 26 | } 27 | 28 | handler.(handlerFunc)(c) 29 | } 30 | 31 | func (h *HandlerBasedOnMap) Route(method string, pattern string, 32 | handlerFunc handlerFunc) error { 33 | key := h.key(method, pattern) 34 | h.handlers.Store(key, handlerFunc) 35 | return nil 36 | } 37 | 38 | func (h *HandlerBasedOnMap) key(method string, 39 | path string) string { 40 | return fmt.Sprintf("%s#%s", method, path) 41 | } 42 | 43 | func NewHandlerBasedOnMap() *HandlerBasedOnMap { 44 | return &HandlerBasedOnMap{} 45 | } 46 | -------------------------------------------------------------------------------- /pkg/v2/server.go: -------------------------------------------------------------------------------- 1 | package webv2 2 | 3 | import ( 4 | "net/http" 5 | ) 6 | 7 | // Routable 可路由的 8 | type Routable interface { 9 | // Route 设定一个路由,命中该路由的会执行handlerFunc的代码 10 | Route(method string, pattern string, handlerFunc handlerFunc) error 11 | } 12 | 13 | // Server 是http server 的顶级抽象 14 | type Server interface { 15 | Routable 16 | // Start 启动我们的服务器 17 | Start(address string) error 18 | } 19 | 20 | // sdkHttpServer 这个是基于 net/http 这个包实现的 http server 21 | type sdkHttpServer struct { 22 | // Name server 的名字,给个标记,日志输出的时候用得上 23 | Name string 24 | handler Handler 25 | root Filter 26 | } 27 | 28 | func (s *sdkHttpServer) Route(method string, pattern string, 29 | handlerFunc handlerFunc) error { 30 | return s.handler.Route(method, pattern, handlerFunc) 31 | } 32 | 33 | func (s *sdkHttpServer) Start(address string) error { 34 | return http.ListenAndServe(address, s) 35 | } 36 | 37 | func (s *sdkHttpServer) ServeHTTP(writer http.ResponseWriter, request *http.Request) { 38 | c := NewContext(writer, request) 39 | s.root(c) 40 | } 41 | 42 | func NewSdkHttpServer(name string, builders ...FilterBuilder) Server { 43 | 44 | // 改用我们的树 45 | handler := NewHandlerBasedOnTree() 46 |
//handler := NewHandlerBasedOnMap() 47 | // 因为我们是一个链,所以我们把最后的业务逻辑处理,也作为一环 48 | var root Filter = handler.ServeHTTP 49 | // 从后往前把filter串起来 50 | for i := len(builders) - 1; i >= 0; i-- { 51 | b := builders[i] 52 | root = b(root) 53 | } 54 | res := &sdkHttpServer{ 55 | Name: name, 56 | handler: handler, 57 | root: root, 58 | } 59 | return res 60 | } 61 | 62 | -------------------------------------------------------------------------------- /pkg/v2/tree_router.go: -------------------------------------------------------------------------------- 1 | package webv2 2 | 3 | import ( 4 | "errors" 5 | "net/http" 6 | "strings" 7 | ) 8 | 9 | var ErrorInvalidRouterPattern = errors.New("invalid router pattern") 10 | 11 | type HandlerBasedOnTree struct { 12 | root *node 13 | 14 | } 15 | var supportMethods = [4]string {http.MethodPost, http.MethodGet, 16 | http.MethodDelete, http.MethodPut} 17 | func NewHandlerBasedOnTree() Handler { 18 | root := &node{ 19 | } 20 | return &HandlerBasedOnTree{ 21 | root: root, 22 | } 23 | } 24 | 25 | // ServeHTTP 就是从树里面找节点 26 | // 找到了就执行 27 | func (h *HandlerBasedOnTree) ServeHTTP(c *Context) { 28 | handler, found := h.findRouter(c.R.URL.Path) 29 | if !found { 30 | c.W.WriteHeader(http.StatusNotFound) 31 | _, _ = c.W.Write([]byte("Not Found")) 32 | return 33 | } 34 | handler(c) 35 | } 36 | 37 | // 这个是不好测试的版本,可以尝试为这个写单元测试 38 | // 会发现很难构造 request,也很难对 ResponseWriter做断言 39 | //func (h *HandlerBasedOnTree) ServeHTTP(c *Context) { 40 | // url := strings.Trim(c.R.URL.Path, "/") 41 | // paths := strings.Split(url, "/") 42 | // cur := h.root 43 | // for _, path := range paths { 44 | // // 从子节点里边找一个匹配到了当前 path 的节点 45 | // matchChild, found := h.findMatchChild(cur, path) 46 | // if !found { 47 | // // 找不到匹配的路径,直接返回 48 | // c.W.WriteHeader(404) 49 | // _, _ = c.W.Write([]byte("Not Found")) 50 | // return 51 | // } 52 | // cur = matchChild 53 | // } 54 | // // 到这里,应该是找完了 55 | // if cur.handler == nil { 56 | // // 到达这里是因为这种场景 57 | // // 比如说你注册了 /user/profile 58 |
// // 然后你访问 /user 59 | // c.W.WriteHeader(404) 60 | // _, _ = c.W.Write([]byte("Not Found")) 61 | // return 62 | // } 63 | // cur.handler(c) 64 | //} 65 | 66 | func (h *HandlerBasedOnTree) findRouter(path string) (handlerFunc, bool) { 67 | // 去除头尾可能有的/,然后按照/切割成段 68 | paths := strings.Split(strings.Trim(path, "/"), "/") 69 | cur := h.root 70 | for _, p := range paths { 71 | // 从子节点里边找一个匹配到了当前 path 的节点 72 | matchChild, found := h.findMatchChild(cur, p) 73 | if !found { 74 | return nil, false 75 | } 76 | cur = matchChild 77 | } 78 | // 到这里,应该是找完了 79 | if cur.handler == nil { 80 | // 到达这里是因为这种场景 81 | // 比如说你注册了 /user/profile 82 | // 然后你访问 /user 83 | return nil, false 84 | } 85 | return cur.handler, true 86 | } 87 | 88 | // Route 就相当于往树里面插入节点 89 | func (h *HandlerBasedOnTree) Route(method string, pattern string, 90 | handlerFunc handlerFunc) error { 91 | 92 | err := h.validatePattern(pattern) 93 | if err != nil { 94 | return err 95 | } 96 | 97 | // 将pattern按照URL的分隔符切割 98 | // 例如,/user/friends 将变成 [user, friends] 99 | // 将前后的/去掉,统一格式 100 | pattern = strings.Trim(pattern, "/") 101 | paths := strings.Split(pattern, "/") 102 | // 当前指向根节点 103 | cur := h.root 104 | for index, path := range paths { 105 | 106 | // 从子节点里边找一个匹配到了当前 path 的节点 107 | matchChild, found := h.findMatchChild(cur, path) 108 | if found { 109 | cur = matchChild 110 | } else { 111 | // 为当前节点根据 112 | h.createSubTree(cur, paths[index:], handlerFunc) 113 | return nil 114 | } 115 | } 116 | // 离开了循环,说明我们加入的是短路径, 117 | // 比如说我们先加入了 /order/detail 118 | // 再加入/order,那么会走到这里 119 | cur.handler = handlerFunc 120 | return nil 121 | } 122 | 123 | func (h *HandlerBasedOnTree) validatePattern(pattern string) error { 124 | // 校验 *,如果存在,必须在最后一个,并且它前面必须是/ 125 | // 即我们只接受 /* 的存在,abc*这种是非法 126 | 127 | pos := strings.Index(pattern, "*") 128 | // 找到了 * 129 | if pos > 0 { 130 | // 必须是最后一个 131 | if pos != len(pattern) - 1 { 132 | return ErrorInvalidRouterPattern 133 | } 134 | if pattern[pos-1] != '/' { 135 | return 
ErrorInvalidRouterPattern 136 | } 137 | } 138 | return nil 139 | } 140 | 141 | func (h *HandlerBasedOnTree) findMatchChild(root *node, path string) (*node, bool) { 142 | var wildcardNode *node 143 | for _, child := range root.children { 144 | // 并不是 * 的节点命中了,直接返回 145 | // != * 是为了防止用户乱输入 146 | if child.path == path && 147 | child.path != "*"{ 148 | return child, true 149 | } 150 | // 命中了通配符的,我们看看后面还有没有更加详细的 151 | if child.path == "*" { 152 | wildcardNode = child 153 | } 154 | } 155 | return wildcardNode, wildcardNode != nil 156 | } 157 | 158 | func (h *HandlerBasedOnTree) createSubTree(root *node, paths []string, handlerFn handlerFunc) { 159 | cur := root 160 | for _, path := range paths { 161 | nn := newNode(path) 162 | cur.children = append(cur.children, nn) 163 | cur = nn 164 | } 165 | cur.handler = handlerFn 166 | } 167 | 168 | type node struct { 169 | path string 170 | children []*node 171 | 172 | // 如果这是叶子节点, 173 | // 那么匹配上之后就可以调用该方法 174 | handler handlerFunc 175 | } 176 | 177 | func newNode(path string) *node { 178 | return &node{ 179 | path: path, 180 | children: make([]*node, 0, 2), 181 | } 182 | } 183 | -------------------------------------------------------------------------------- /pkg/v2/tree_router_test.go: -------------------------------------------------------------------------------- 1 | package webv2 2 | 3 | import ( 4 | "github.com/stretchr/testify/assert" 5 | "net/http" 6 | "reflect" 7 | "testing" 8 | ) 9 | 10 | func TestHandlerBasedOnTree_Route(t *testing.T) { 11 | handler := NewHandlerBasedOnTree().(*HandlerBasedOnTree) 12 | assert.NotNil(t, handler.root) 13 | 14 | err := handler.Route(http.MethodPost, "/user", func(c *Context) {}) 15 | assert.Nil(t, err) 16 | // 开始做断言,这个时候我们应该确认,在根节点之下只有一个user节点 17 | assert.Equal(t, 1, len(handler.root.children)) 18 | 19 | n := handler.root.children[0] 20 | assert.NotNil(t, n) 21 | assert.Equal(t, "user", n.path) 22 | assert.NotNil(t, n.handler) 23 | assert.Empty(t, n.children) 24 | 25 | // 我们只有 26 | // user 
-> profile 27 | err = handler.Route(http.MethodPost, "/user/profile", func(c *Context) {}) 28 | assert.Nil(t, err) 29 | assert.Equal(t, 1, len(n.children)) 30 | profileNode := n.children[0] 31 | assert.NotNil(t, profileNode) 32 | assert.Equal(t, "profile", profileNode.path) 33 | assert.NotNil(t, profileNode.handler) 34 | assert.Empty(t, profileNode.children) 35 | 36 | // 试试重复 37 | err = handler.Route(http.MethodPost, "/user", func(c *Context) {}) 38 | assert.Nil(t, err) 39 | n = handler.root.children[0] 40 | assert.NotNil(t, n) 41 | assert.Equal(t, "user", n.path) 42 | assert.NotNil(t, n.handler) 43 | // 有profile节点 44 | assert.Equal(t, 1, len(n.children)) 45 | 46 | // 给 user 再加一个节点 47 | err = handler.Route(http.MethodPost, "/user/home", func(c *Context) {}) 48 | assert.Nil(t, err) 49 | assert.Equal(t, 2, len(n.children)) 50 | homeNode := n.children[1] 51 | assert.NotNil(t, homeNode) 52 | assert.Equal(t, "home", homeNode.path) 53 | assert.NotNil(t, homeNode.handler) 54 | assert.Empty(t, homeNode.children) 55 | 56 | // 添加 /order/detail 57 | err = handler.Route(http.MethodPost, "/order/detail", func(c *Context) {}) 58 | assert.Equal(t, 2, len(handler.root.children)) 59 | orderNode := handler.root.children[1] 60 | assert.NotNil(t, orderNode) 61 | assert.Equal(t, "order", orderNode.path) 62 | // 此刻我们只有/order/detail,但是没有/order 63 | assert.Nil(t, orderNode.handler) 64 | assert.Equal(t, 1, len(orderNode.children)) 65 | 66 | orderDetailNode := orderNode.children[0] 67 | assert.NotNil(t, orderDetailNode) 68 | assert.Empty(t, orderDetailNode.children) 69 | assert.Equal(t, "detail", orderDetailNode.path) 70 | assert.NotNil(t, orderDetailNode.handler) 71 | 72 | // 加一个 /order 73 | err = handler.Route(http.MethodPost, "/order", func(c *Context) {}) 74 | assert.Nil(t, err) 75 | assert.Equal(t, 2, len(handler.root.children)) 76 | orderNode = handler.root.children[1] 77 | assert.Equal(t, "order", orderNode.path) 78 | // 此时我们有了 /order 79 | assert.NotNil(t, orderNode.handler) 80 | 81 
| err = handler.Route(http.MethodPost, "/order/*", func(c *Context) {}) 82 | assert.Nil(t, err) 83 | assert.Equal(t, 2, len(orderNode.children)) 84 | orderWildcard := orderNode.children[1] 85 | assert.NotNil(t, orderWildcard) 86 | assert.NotNil(t, orderWildcard.handler) 87 | assert.Equal(t, "*", orderWildcard.path) 88 | 89 | err = handler.Route(http.MethodPost, "/order/*/checkout", func(c *Context) {}) 90 | assert.NotNil(t, err) 91 | } 92 | 93 | func TestHandlerBasedOnTree_findRouter(t *testing.T) { 94 | handler := NewHandlerBasedOnTree().(*HandlerBasedOnTree) 95 | _ = handler.Route(http.MethodPost, "/user", func(c *Context) {}) 96 | fn, found := handler.findRouter("/user") 97 | assert.True(t, found) 98 | assert.NotNil(t, fn) 99 | _, found = handler.findRouter("/user/profile") 100 | assert.False(t, found) 101 | 102 | _ = handler.Route(http.MethodPost, "/user/profile", func(c *Context) {}) 103 | _, found = handler.findRouter("/user/profile") 104 | assert.True(t, found) 105 | 106 | _, found = handler.findRouter("/user") 107 | assert.True(t, found) 108 | 109 | var detailHandler handlerFunc = func(c *Context) {} 110 | _ = handler.Route(http.MethodPost, "/order/detail", detailHandler) 111 | _, found = handler.findRouter("/order") 112 | assert.False(t, found) 113 | 114 | fn, found = handler.findRouter("/order/detail") 115 | assert.True(t, found) 116 | assert.True(t, handlerFuncEquals(detailHandler, fn)) 117 | 118 | var wildcardHandler handlerFunc = func(c *Context) {} 119 | _ = handler.Route(http.MethodPost, "/order/*", wildcardHandler) 120 | _, found = handler.findRouter("/order") 121 | assert.False(t, found) 122 | 123 | fn, found = handler.findRouter("/order/detail") 124 | assert.True(t, found) 125 | assert.True(t, handlerFuncEquals(detailHandler, fn)) 126 | 127 | fn, found = handler.findRouter("/order/checkout") 128 | assert.True(t, found) 129 | assert.True(t, handlerFuncEquals(wildcardHandler, fn)) 130 | } 131 | 132 | func handlerFuncEquals(hf1 handlerFunc, hf2 
handlerFunc) bool { 133 | return reflect.ValueOf(hf1).Pointer() == reflect.ValueOf(hf2).Pointer() 134 | } -------------------------------------------------------------------------------- /pkg/v3/context.go: -------------------------------------------------------------------------------- 1 | package webv3 2 | 3 | import ( 4 | "encoding/json" 5 | "io" 6 | "net/http" 7 | ) 8 | 9 | type Context struct { 10 | W http.ResponseWriter 11 | R *http.Request 12 | PathParams map[string]string 13 | } 14 | 15 | func (c *Context) ReadJson(data interface{}) error { 16 | body, err := io.ReadAll(c.R.Body) 17 | if err != nil { 18 | return err 19 | } 20 | return json.Unmarshal(body, data) 21 | } 22 | func (c *Context) OkJson(data interface{}) error { 23 | // http 库里面提前定义好了各种响应码 24 | return c.WriteJson(http.StatusOK, data) 25 | } 26 | 27 | func (c *Context) SystemErrJson(data interface{}) error { 28 | // http 库里面提前定义好了各种响应码 29 | return c.WriteJson(http.StatusInternalServerError, data) 30 | } 31 | 32 | func (c *Context) BadRequestJson(data interface{}) error { 33 | // http 库里面提前定义好了各种响应码 34 | return c.WriteJson(http.StatusBadRequest, data) 35 | } 36 | 37 | func (c *Context) WriteJson(status int, data interface{}) error { 38 | c.W.WriteHeader(status) 39 | bs, err := json.Marshal(data) 40 | if err != nil { 41 | return err 42 | } 43 | _, err = c.W.Write(bs) 44 | if err != nil { 45 | return err 46 | } 47 | return nil 48 | } 49 | 50 | func NewContext(w http.ResponseWriter, r *http.Request) *Context { 51 | return &Context{ 52 | W: w, 53 | R: r, 54 | // 一般路径参数都是一个,所以容量1就可以了 55 | PathParams: make(map[string]string, 1), 56 | } 57 | } -------------------------------------------------------------------------------- /pkg/v3/filter.go: -------------------------------------------------------------------------------- 1 | package webv3 2 | 3 | import ( 4 | "fmt" 5 | "time" 6 | ) 7 | 8 | type FilterBuilder func(next Filter) Filter 9 | 10 | type Filter func(c *Context) 11 | 12 | func 
MetricFilterBuilder(next Filter) Filter { 13 | return func(c *Context) { 14 | // 执行前的时间 15 | startTime := time.Now().UnixNano() 16 | next(c) 17 | // 执行后的时间 18 | endTime := time.Now().UnixNano() 19 | fmt.Printf("run time: %d \n", endTime-startTime) 20 | } 21 | } -------------------------------------------------------------------------------- /pkg/v3/handler.go: -------------------------------------------------------------------------------- 1 | package webv3 2 | 3 | type Handler interface { 4 | ServeHTTP(c *Context) 5 | Routable 6 | } 7 | 8 | type handlerFunc func(c *Context) -------------------------------------------------------------------------------- /pkg/v3/map_router.go: -------------------------------------------------------------------------------- 1 | package webv3 2 | 3 | import ( 4 | "fmt" 5 | "net/http" 6 | "sync" 7 | ) 8 | 9 | // 一种常用的GO设计模式, 10 | // 用于确保HandlerBasedOnMap肯定实现了这个接口 11 | var _ Handler = &HandlerBasedOnMap{} 12 | 13 | 14 | type HandlerBasedOnMap struct { 15 | handlers sync.Map 16 | } 17 | 18 | func (h *HandlerBasedOnMap) ServeHTTP(c *Context) { 19 | request := c.R 20 | key := h.key(request.Method, request.URL.Path) 21 | handler, ok := h.handlers.Load(key) 22 | if !ok { 23 | c.W.WriteHeader(http.StatusNotFound) 24 | _, _ = c.W.Write([]byte("not any router match")) 25 | return 26 | } 27 | 28 | handler.(handlerFunc)(c) 29 | } 30 | 31 | func (h *HandlerBasedOnMap) Route(method string, pattern string, 32 | handlerFunc handlerFunc) error { 33 | key := h.key(method, pattern) 34 | h.handlers.Store(key, handlerFunc) 35 | return nil 36 | } 37 | 38 | func (h *HandlerBasedOnMap) key(method string, 39 | path string) string { 40 | return fmt.Sprintf("%s#%s", method, path) 41 | } 42 | 43 | func NewHandlerBasedOnMap() *HandlerBasedOnMap { 44 | return &HandlerBasedOnMap{} 45 | } 46 | -------------------------------------------------------------------------------- /pkg/v3/server.go: 
-------------------------------------------------------------------------------- 1 | package webv3 2 | 3 | import ( 4 | "net/http" 5 | ) 6 | 7 | // Routable 可路由的 8 | type Routable interface { 9 | // Route 设定一个路由,命中该路由的会执行handlerFunc的代码 10 | Route(method string, pattern string, handlerFunc handlerFunc) error 11 | } 12 | 13 | // Server 是http server 的顶级抽象 14 | type Server interface { 15 | Routable 16 | // Start 启动我们的服务器 17 | Start(address string) error 18 | } 19 | 20 | // sdkHttpServer 这个是基于 net/http 这个包实现的 http server 21 | type sdkHttpServer struct { 22 | // Name server 的名字,给个标记,日志输出的时候用得上 23 | Name string 24 | handler Handler 25 | root Filter 26 | } 27 | 28 | func (s *sdkHttpServer) Route(method string, pattern string, 29 | handlerFunc handlerFunc) error { 30 | return s.handler.Route(method, pattern, handlerFunc) 31 | } 32 | 33 | func (s *sdkHttpServer) Start(address string) error { 34 | return http.ListenAndServe(address, s) 35 | } 36 | 37 | func (s *sdkHttpServer) ServeHTTP(writer http.ResponseWriter, request *http.Request) { 38 | c := NewContext(writer, request) 39 | s.root(c) 40 | } 41 | 42 | func NewSdkHttpServer(name string, builders ...FilterBuilder) Server { 43 | 44 | // 改用我们的树 45 | handler := NewHandlerBasedOnTree() 46 | //handler := NewHandlerBasedOnMap() 47 | // 因为我们是一个链,所以我们把最后的业务逻辑处理,也作为一环 48 | var root Filter = handler.ServeHTTP 49 | // 从后往前把filter串起来 50 | for i := len(builders) - 1; i >= 0; i-- { 51 | b := builders[i] 52 | root = b(root) 53 | } 54 | res := &sdkHttpServer{ 55 | Name: name, 56 | handler: handler, 57 | root: root, 58 | } 59 | return res 60 | } 61 | 62 | -------------------------------------------------------------------------------- /pkg/v3/tree_node.go: -------------------------------------------------------------------------------- 1 | package webv3 2 | 3 | import ( 4 | "strings" 5 | ) 6 | 7 | const ( 8 | 9 | // 根节点,只有根用这个 10 | nodeTypeRoot = iota 11 | 12 | // * 13 | nodeTypeAny 14 | 15 | // 路径参数 16 | nodeTypeParam 17 | 18 | // 正则 19 | 
nodeTypeReg 20 | 21 | // 静态,即完全匹配 22 | nodeTypeStatic 23 | ) 24 | 25 | const any = "*" 26 | 27 | // matchFunc 承担两个职责,一个是判断是否匹配,一个是在匹配之后 28 | // 将必要的数据写入到 Context 29 | // 所谓必要的数据,这里基本上是指路径参数 30 | type matchFunc func(path string, c *Context) bool 31 | 32 | type node struct { 33 | children []*node 34 | 35 | // 如果这是叶子节点, 36 | // 那么匹配上之后就可以调用该方法 37 | handler handlerFunc 38 | matchFunc matchFunc 39 | 40 | // 原始的 pattern。注意,它不是完整的pattern, 41 | // 而是匹配到这个节点的pattern 42 | pattern string 43 | nodeType int 44 | } 45 | 46 | // 静态节点 47 | func newStaticNode(path string) *node { 48 | return &node{ 49 | children: make([]*node, 0, 2), 50 | matchFunc: func(p string, c *Context) bool { 51 | return path == p && p != "*" 52 | }, 53 | nodeType: nodeTypeStatic, 54 | pattern: path, 55 | } 56 | } 57 | 58 | 59 | func newRootNode(method string) *node { 60 | return &node{ 61 | children: make([]*node, 0, 2), 62 | matchFunc: func( p string, c *Context) bool { 63 | panic("never call me") 64 | }, 65 | nodeType: nodeTypeRoot, 66 | pattern: method, 67 | } 68 | } 69 | 70 | func newNode(path string) *node { 71 | if path == "*"{ 72 | return newAnyNode() 73 | } 74 | if strings.HasPrefix(path, ":") { 75 | return newParamNode(path) 76 | } 77 | return newStaticNode(path) 78 | } 79 | 80 | // 通配符 * 节点 81 | func newAnyNode() *node { 82 | return &node{ 83 | // 因为我们不允许 * 后面还有节点,所以这里可以不用初始化 84 | //children: make([]*node, 0, 2), 85 | matchFunc: func(p string, c *Context) bool { 86 | return true 87 | }, 88 | nodeType: nodeTypeAny, 89 | pattern: any, 90 | } 91 | } 92 | 93 | // 路径参数节点 94 | func newParamNode(path string) *node { 95 | paramName := path[1:] 96 | return &node{ 97 | children: make([]*node, 0, 2), 98 | matchFunc: func(p string, c *Context) bool { 99 | if c != nil { 100 | c.PathParams[paramName] = p 101 | } 102 | // 如果自身是一个参数路由, 103 | // 然后又来一个通配符,我们认为是不匹配的 104 | return p != any 105 | }, 106 | nodeType: nodeTypeParam, 107 | pattern: path, 108 | } 109 | } 110 | 111 | // 正则节点 112 | //func newRegNode(path 
string) *node { 113 | // // 依据你的规则拿到正则表达式 114 | // return &node{ 115 | // children: make([]*node, 0, 2), 116 | // matchFunc: func(p string, c *Context) bool { 117 | // // 怎么写? 118 | // }, 119 | // nodeType: nodeTypeParam, 120 | // pattern: path, 121 | // } 122 | //} 123 | 124 | -------------------------------------------------------------------------------- /pkg/v3/tree_router.go: -------------------------------------------------------------------------------- 1 | package webv3 2 | 3 | import ( 4 | "errors" 5 | "net/http" 6 | "sort" 7 | "strings" 8 | ) 9 | 10 | var ErrorInvalidRouterPattern = errors.New("invalid router pattern") 11 | var ErrorInvalidMethod = errors.New("invalid method") 12 | 13 | type HandlerBasedOnTree struct { 14 | forest map[string]*node 15 | } 16 | 17 | var supportMethods = [4]string { 18 | http.MethodGet, 19 | http.MethodPost, 20 | http.MethodPut, 21 | http.MethodDelete, 22 | } 23 | 24 | func NewHandlerBasedOnTree() Handler { 25 | forest := make(map[string]*node, len(supportMethods)) 26 | for _, m :=range supportMethods { 27 | forest[m] = newRootNode(m) 28 | } 29 | return &HandlerBasedOnTree{ 30 | forest: forest, 31 | } 32 | } 33 | 34 | // ServeHTTP 就是从树里面找节点 35 | // 找到了就执行 36 | func (h *HandlerBasedOnTree) ServeHTTP(c *Context) { 37 | handler, found := h.findRouter(c.R.Method, c.R.URL.Path, c) 38 | if !found { 39 | c.W.WriteHeader(http.StatusNotFound) 40 | _, _ = c.W.Write([]byte("Not Found")) 41 | return 42 | } 43 | handler(c) 44 | } 45 | 46 | func (h *HandlerBasedOnTree) findRouter(method, path string, c *Context) (handlerFunc, bool) { 47 | // 去除头尾可能有的/,然后按照/切割成段 48 | paths := strings.Split(strings.Trim(path, "/"), "/") 49 | cur, ok := h.forest[method] 50 | if !ok { 51 | return nil, false 52 | } 53 | for _, p := range paths { 54 | // 从子节点里边找一个匹配到了当前 p 的节点 55 | matchChild, found := h.findMatchChild(cur, p, c) 56 | if !found { 57 | return nil, false 58 | } 59 | cur = matchChild 60 | } 61 | // 到这里,应该是找完了 62 | if cur.handler == nil { 63 | // 
到达这里是因为这种场景 64 | // 比如说你注册了 /user/profile 65 | // 然后你访问 /user 66 | return nil, false 67 | } 68 | return cur.handler, true 69 | } 70 | 71 | // Route 就相当于往树里面插入节点 72 | func (h *HandlerBasedOnTree) Route(method string, pattern string, 73 | handlerFunc handlerFunc) error { 74 | 75 | err := h.validatePattern(pattern) 76 | if err != nil { 77 | return err 78 | } 79 | 80 | // 将pattern按照URL的分隔符切割 81 | // 例如,/user/friends 将变成 [user, friends] 82 | // 将前后的/去掉,统一格式 83 | pattern = strings.Trim(pattern, "/") 84 | paths := strings.Split(pattern, "/") 85 | 86 | // 当前指向根节点 87 | cur, ok := h.forest[method] 88 | if !ok { 89 | return ErrorInvalidMethod 90 | } 91 | for index, path := range paths { 92 | 93 | // 从子节点里边找一个匹配到了当前 path 的节点 94 | matchChild, found := h.findMatchChild(cur, path, nil) 95 | // != nodeTypeAny 是考虑到 /order/* 和 /order/:id 这种注册顺序 96 | if found && matchChild.nodeType != nodeTypeAny { 97 | cur = matchChild 98 | } else { 99 | // 为当前节点根据 100 | h.createSubTree(cur, paths[index:], handlerFunc) 101 | return nil 102 | } 103 | } 104 | // 离开了循环,说明我们加入的是短路径, 105 | // 比如说我们先加入了 /order/detail 106 | // 再加入/order,那么会走到这里 107 | cur.handler = handlerFunc 108 | return nil 109 | } 110 | 111 | func (h *HandlerBasedOnTree) validatePattern(pattern string) error { 112 | // 校验 *,如果存在,必须在最后一个,并且它前面必须是/ 113 | // 即我们只接受 /* 的存在,abc*这种是非法 114 | 115 | pos := strings.Index(pattern, "*") 116 | // 找到了 * 117 | if pos > 0 { 118 | // 必须是最后一个 119 | if pos != len(pattern) - 1 { 120 | return ErrorInvalidRouterPattern 121 | } 122 | if pattern[pos-1] != '/' { 123 | return ErrorInvalidRouterPattern 124 | } 125 | } 126 | return nil 127 | } 128 | 129 | func (h *HandlerBasedOnTree) findMatchChild(root *node, 130 | path string, c *Context) (*node, bool) { 131 | candidates := make([]*node, 0, 2) 132 | for _, child := range root.children { 133 | if child.matchFunc(path, c) { 134 | candidates = append(candidates, child) 135 | } 136 | } 137 | 138 | if len(candidates) == 0 { 139 | return nil, false 140 | } 141 | 142 | 
// type 也决定了它们的优先级 143 | sort.Slice(candidates, func(i, j int) bool { 144 | return candidates[i].nodeType < candidates[j].nodeType 145 | }) 146 | return candidates[len(candidates) - 1], true 147 | } 148 | 149 | func (h *HandlerBasedOnTree) createSubTree(root *node, paths []string, handlerFn handlerFunc) { 150 | cur := root 151 | for _, path := range paths { 152 | nn := newNode(path) 153 | cur.children = append(cur.children, nn) 154 | cur = nn 155 | } 156 | cur.handler = handlerFn 157 | } 158 | 159 | -------------------------------------------------------------------------------- /pkg/v3/tree_router_test.go: -------------------------------------------------------------------------------- 1 | package webv3 2 | 3 | import ( 4 | "github.com/stretchr/testify/assert" 5 | "net/http" 6 | "reflect" 7 | "testing" 8 | ) 9 | 10 | func TestHandlerBasedOnTree_Route(t *testing.T) { 11 | handler := NewHandlerBasedOnTree().(*HandlerBasedOnTree) 12 | // 要确认已经为支持的方法创建了节点 13 | assert.Equal(t, len(supportMethods), len(handler.forest)) 14 | 15 | postNode := handler.forest[http.MethodPost] 16 | 17 | err := handler.Route(http.MethodPost, "/user", func(c *Context) {}) 18 | assert.Nil(t, err) 19 | assert.Equal(t, 1, len(postNode.children)) 20 | 21 | n := postNode.children[0] 22 | assert.NotNil(t, n) 23 | assert.Equal(t, "user", n.pattern) 24 | assert.NotNil(t, n.handler) 25 | assert.Empty(t, n.children) 26 | 27 | // 我们只有 28 | // user -> profile 29 | err = handler.Route(http.MethodPost, "/user/profile", func(c *Context) {}) 30 | assert.Nil(t, err) 31 | assert.Equal(t, 1, len(n.children)) 32 | profileNode := n.children[0] 33 | assert.NotNil(t, profileNode) 34 | assert.Equal(t, "profile", profileNode.pattern) 35 | assert.NotNil(t, profileNode.handler) 36 | assert.Empty(t, profileNode.children) 37 | 38 | // 试试重复 39 | err = handler.Route(http.MethodPost, "/user", func(c *Context) {}) 40 | assert.Nil(t, err) 41 | n = postNode.children[0] 42 | assert.NotNil(t, n) 43 | assert.Equal(t, "user", 
n.pattern) 44 | assert.NotNil(t, n.handler) 45 | // 有profile节点 46 | assert.Equal(t, 1, len(n.children)) 47 | 48 | // 给 user 再加一个节点 49 | err = handler.Route(http.MethodPost, "/user/home", func(c *Context) {}) 50 | assert.Nil(t, err) 51 | assert.Equal(t, 2, len(n.children)) 52 | homeNode := n.children[1] 53 | assert.NotNil(t, homeNode) 54 | assert.Equal(t, "home", homeNode.pattern) 55 | assert.NotNil(t, homeNode.handler) 56 | assert.Empty(t, homeNode.children) 57 | 58 | // 添加 /order/detail 59 | err = handler.Route(http.MethodPost, "/order/detail", func(c *Context) {}) 60 | assert.Equal(t, 2, len(postNode.children)) 61 | orderNode := postNode.children[1] 62 | assert.NotNil(t, orderNode) 63 | assert.Equal(t, "order", orderNode.pattern) 64 | // 此刻我们只有/order/detail,但是没有/order 65 | assert.Nil(t, orderNode.handler) 66 | assert.Equal(t, 1, len(orderNode.children)) 67 | 68 | orderDetailNode := orderNode.children[0] 69 | assert.NotNil(t, orderDetailNode) 70 | assert.Empty(t, orderDetailNode.children) 71 | assert.Equal(t, "detail", orderDetailNode.pattern) 72 | assert.NotNil(t, orderDetailNode.handler) 73 | 74 | // 加一个 /order 75 | err = handler.Route(http.MethodPost, "/order", func(c *Context) {}) 76 | assert.Nil(t, err) 77 | assert.Equal(t, 2, len(postNode.children)) 78 | orderNode = postNode.children[1] 79 | assert.Equal(t, "order", orderNode.pattern) 80 | // 此时我们有了 /order 81 | assert.NotNil(t, orderNode.handler) 82 | 83 | err = handler.Route(http.MethodPost, "/order/*", func(c *Context) {}) 84 | assert.Nil(t, err) 85 | assert.Equal(t, 2, len(orderNode.children)) 86 | orderWildcard := orderNode.children[1] 87 | assert.NotNil(t, orderWildcard) 88 | assert.NotNil(t, orderWildcard.handler) 89 | assert.Equal(t, "*", orderWildcard.pattern) 90 | 91 | err = handler.Route(http.MethodPost, "/order/*/checkout", func(c *Context) {}) 92 | assert.Equal(t, ErrorInvalidRouterPattern, err) 93 | 94 | err = handler.Route(http.MethodConnect, "/order/checkout", func(c *Context) {}) 95 | 
assert.Equal(t, ErrorInvalidMethod, err) 96 | 97 | err = handler.Route(http.MethodPost, "/order/:id", func(c *Context){}) 98 | assert.Nil(t, err) 99 | // 这时候我们有/order/* 和 /order/:id 100 | // 因为我们并没有认为它们不兼容,而是/order/:id优先 101 | assert.Equal(t, 3, len(orderNode.children)) 102 | orderParamNode := orderNode.children[2] 103 | assert.Equal(t, ":id", orderParamNode.pattern) 104 | 105 | } 106 | 107 | func TestHandlerBasedOnTree_findRouter(t *testing.T) { 108 | handler := NewHandlerBasedOnTree().(*HandlerBasedOnTree) 109 | _ = handler.Route(http.MethodPost, "/user", func(c *Context) {}) 110 | ctx := NewContext(nil, nil) 111 | fn, found := handler.findRouter(http.MethodPost, "/user", ctx) 112 | assert.True(t, found) 113 | assert.NotNil(t, fn) 114 | _, found = handler.findRouter(http.MethodPost,"/user/profile", ctx) 115 | assert.False(t, found) 116 | 117 | _ = handler.Route(http.MethodPost, "/user/profile", func(c *Context) {}) 118 | _, found = handler.findRouter(http.MethodPost, "/user/profile", ctx) 119 | assert.True(t, found) 120 | 121 | _, found = handler.findRouter(http.MethodPost, "/user", ctx) 122 | assert.True(t, found) 123 | 124 | var detailHandler handlerFunc = func(c *Context) {} 125 | _ = handler.Route(http.MethodPost, "/order/detail", detailHandler) 126 | _, found = handler.findRouter(http.MethodPost,"/order", ctx) 127 | assert.False(t, found) 128 | 129 | fn, found = handler.findRouter(http.MethodPost,"/order/detail", ctx) 130 | assert.True(t, found) 131 | assert.True(t, handlerFuncEquals(detailHandler, fn)) 132 | 133 | var wildcardHandler handlerFunc = func(c *Context) {} 134 | _ = handler.Route(http.MethodPost, "/order/*", wildcardHandler) 135 | _, found = handler.findRouter(http.MethodPost,"/order", ctx) 136 | assert.False(t, found) 137 | 138 | fn, found = handler.findRouter(http.MethodPost,"/order/detail", ctx) 139 | assert.True(t, found) 140 | assert.True(t, handlerFuncEquals(detailHandler, fn)) 141 | 142 | fn, found = 
handler.findRouter(http.MethodPost,"/order/checkout", ctx) 143 | assert.True(t, found) 144 | assert.True(t, handlerFuncEquals(wildcardHandler, fn)) 145 | 146 | _, found = handler.findRouter(http.MethodGet,"/order/checkout", ctx) 147 | assert.False(t, found) 148 | 149 | // 参数路由 150 | handler.Route(http.MethodPost, "/order/*", wildcardHandler) 151 | } 152 | 153 | func handlerFuncEquals(hf1 handlerFunc, hf2 handlerFunc) bool { 154 | return reflect.ValueOf(hf1).Pointer() == reflect.ValueOf(hf2).Pointer() 155 | } --------------------------------------------------------------------------------