├── .gitignore ├── LICENSE ├── Makefile ├── README.md ├── examples ├── hello │ └── cmd.go ├── trading │ └── cmd.go └── translate │ └── cmd.go ├── go.mod ├── go.sum ├── pkg ├── ask.go ├── ask_test.go ├── auth │ ├── chatgpt_auth.go │ └── chatgpt_auth_test.go ├── client.go ├── client_test.go ├── httpx │ ├── cookies.go │ ├── cookies_test.go │ ├── http_session.go │ └── http_session_test.go ├── log.go ├── log_test.go ├── option.go ├── option_test.go └── utils │ ├── regexp.go │ └── regexp_test.go └── test ├── ask_test.go ├── auth_test.go ├── auth_test.py └── client_test.go /.gitignore: -------------------------------------------------------------------------------- 1 | # Binaries for programs and plugins 2 | *.exe 3 | *.exe~ 4 | *.dll 5 | *.so 6 | *.dylib 7 | 8 | # Test binary, built with `go test -c` 9 | *.test 10 | 11 | # Output of the go coverage tool, specifically when used with LiteIDE 12 | *.out 13 | 14 | # Dependency directories (remove the comment below to include it) 15 | # vendor/ 16 | 17 | /bin 18 | 19 | .env.local 20 | 21 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 Owen Wu 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 
14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | .PHONY: deps clean build *test run 2 | 3 | NAME="chatgpt" 4 | 5 | deps: 6 | go install github.com/joho/godotenv/cmd/godotenv@latest 7 | 8 | unit-test: 9 | go test ./pkg/... 10 | 11 | int-test: deps 12 | godotenv -f .env.local go test ./test/... 13 | 14 | test: unit-test int-test 15 | 16 | run-%: 17 | go run ./examples/$*/cmd.go 18 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # chatgpt-go 2 | chatGPT golang client translated from https://github.com/acheong08/ChatGPT 3 | 4 | ## Installation 5 | 6 | ```shell script 7 | go get github.com/yubing744/chatgpt-go 8 | ``` 9 | 10 | ## Usage 11 | 12 | Config .env.local file 13 | ``` 14 | CHATGPT_EMAIL="your chat gpt account" 15 | CHATGPT_PASSWORD="your chat gpt password" 16 | ``` 17 | 18 | ``` go 19 | package main 20 | 21 | import ( 22 | "context" 23 | "fmt" 24 | "log" 25 | "os" 26 | 27 | "github.com/Valgard/godotenv" 28 | "github.com/yubing744/chatgpt-go/pkg" 29 | ) 30 | 31 | func main() { 32 | dotenv := godotenv.New() 33 | if err := dotenv.Load(".env.local"); err != nil { 34 | panic(err) 35 | } 36 | 37 | email := os.Getenv("CHATGPT_EMAIL") 38 | password := os.Getenv("CHATGPT_PASSWORD") 39 | if email == "" || password == "" 
{ 40 | log.Panic("CHATGPT_EMAIL or CHATGPT_PASSWORD not set in .env.local") 41 | } 42 | 43 | client := pkg.NewChatgptClient(email, password) 44 | 45 | fmt.Print("Starting ...\n") 46 | err := client.Start(context.Background()) 47 | defer client.Stop() 48 | 49 | if err != nil { 50 | log.Fatalf("Start fail: %s\n", err.Error()) 51 | return 52 | } 53 | 54 | fmt.Print("Start success\n") 55 | 56 | prompt := "Hello" 57 | fmt.Printf("You: %s", prompt) 58 | result, err := client.Ask(context.Background(), prompt, nil, nil) 59 | if err != nil { 60 | fmt.Printf("Ask fail: %s\n", err.Error()) 61 | return 62 | } 63 | 64 | if result.Code == 0 { 65 | fmt.Printf("AI: %s\n", result.Data.Text) 66 | } 67 | 68 | fmt.Print("Done\n") 69 | } 70 | ``` -------------------------------------------------------------------------------- /examples/hello/cmd.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | "log" 7 | "os" 8 | 9 | "github.com/Valgard/godotenv" 10 | "github.com/yubing744/chatgpt-go/pkg" 11 | ) 12 | 13 | func main() { 14 | dotenv := godotenv.New() 15 | if err := dotenv.Load(".env.local"); err != nil { 16 | panic(err) 17 | } 18 | 19 | email := os.Getenv("CHATGPT_EMAIL") 20 | password := os.Getenv("CHATGPT_PASSWORD") 21 | if email == "" || password == "" { 22 | log.Panic("CHATGPT_EMAIL or CHATGPT_PASSWORD not set in .env.local") 23 | } 24 | 25 | client := pkg.NewChatgptClient(email, password) 26 | 27 | fmt.Print("Starting ...\n") 28 | err := client.Start(context.Background()) 29 | defer client.Stop() 30 | 31 | if err != nil { 32 | log.Fatalf("Start fail: %s\n", err.Error()) 33 | return 34 | } 35 | 36 | fmt.Print("Start success\n") 37 | 38 | prompt := "Hello" 39 | fmt.Printf("You: %s", prompt) 40 | result, err := client.Ask(context.Background(), prompt, nil, nil) 41 | if err != nil { 42 | fmt.Printf("Ask fail: %s\n", err.Error()) 43 | return 44 | } 45 | 46 | if result.Code == 0 { 47 | 
fmt.Printf("AI: %s\n", result.Data.Text) 48 | } 49 | 50 | fmt.Print("Done\n") 51 | } 52 | -------------------------------------------------------------------------------- /examples/trading/cmd.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | "log" 7 | "os" 8 | 9 | "github.com/Valgard/godotenv" 10 | "github.com/yubing744/chatgpt-go/pkg" 11 | ) 12 | 13 | func main() { 14 | dotenv := godotenv.New() 15 | if err := dotenv.Load(".env.local"); err != nil { 16 | panic(err) 17 | } 18 | 19 | email := os.Getenv("CHATGPT_EMAIL") 20 | password := os.Getenv("CHATGPT_PASSWORD") 21 | if email == "" || password == "" { 22 | log.Panic("CHATGPT_EMAIL or CHATGPT_PASSWORD not set in .env.local") 23 | } 24 | 25 | client := pkg.NewChatgptClient(email, password) 26 | 27 | fmt.Print("Start ...\n") 28 | err := client.Start(context.Background()) 29 | defer client.Stop() 30 | if err != nil { 31 | log.Fatalf("Start fail: %s\n", err.Error()) 32 | return 33 | } 34 | 35 | fmt.Print("Start success\n") 36 | 37 | prompt := `You:BOLL data changed: UpBand:[2.653 2.645 2.640 2.634 2.622 2.611 2.614 2.615 2.618 2.618 2.619 2.619 2.622 2.624 2.624 2.624 2.624 2.624 2.624 2.627], SMA:[2.605 2.603 2.601 2.599 2.596 2.594 2.595 2.596 2.598 2.599 2.600 2.599 2.599 2.598 2.598 2.598 2.598 2.598 2.598 2.599], DownBand:[2.557 2.561 2.562 2.564 2.570 2.577 2.575 2.577 2.579 2.579 2.581 2.580 2.575 2.572 2.572 2.572 2.572 2.572 2.571 2.571] 38 | You:RSI data changed: [55.703 78.253 44.869 33.871 26.280 30.286 81.857 78.360 85.344 38.224 40.336 12.013 8.355 24.564 64.706 72.386 64.481 44.202 75.244 83.419] 39 | You:There are currently no open positions 40 | You:Analyze the data and generate only one trading command: /open_long_position, /open_short_position, /close_position or /no_action, the entity will execute the command and give you feedback. 
41 | AI:` 42 | fmt.Printf("%s", prompt) 43 | result, err := client.Ask(context.Background(), prompt, nil, nil) 44 | if err != nil { 45 | fmt.Printf("Ask fail: %s\n", err.Error()) 46 | return 47 | } 48 | 49 | if result.Code == 0 { 50 | fmt.Printf("AI: %s\n", result.Data.Text) 51 | } 52 | 53 | fmt.Print("Done\n") 54 | } 55 | -------------------------------------------------------------------------------- /examples/translate/cmd.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | "log" 7 | "os" 8 | 9 | "github.com/Valgard/godotenv" 10 | "github.com/yubing744/chatgpt-go/pkg" 11 | ) 12 | 13 | func main() { 14 | dotenv := godotenv.New() 15 | if err := dotenv.Load(".env.local"); err != nil { 16 | panic(err) 17 | } 18 | 19 | email := os.Getenv("CHATGPT_EMAIL") 20 | password := os.Getenv("CHATGPT_PASSWORD") 21 | if email == "" || password == "" { 22 | log.Panic("CHATGPT_EMAIL or CHATGPT_PASSWORD not set in .env.local") 23 | } 24 | 25 | client := pkg.NewChatgptClient(email, password) 26 | 27 | fmt.Print("Starting ...\n") 28 | err := client.Start(context.Background()) 29 | defer client.Stop() 30 | if err != nil { 31 | log.Fatalf("Start fail: %s\n", err.Error()) 32 | return 33 | } 34 | 35 | fmt.Print("Start success\n") 36 | 37 | prompt := "翻译成英文:你还需要哪些指标帮助决策交易命令?" 
38 | fmt.Printf("You: %s", prompt) 39 | result, err := client.Ask(context.Background(), prompt, nil, nil) 40 | if err != nil { 41 | fmt.Printf("Ask fail: %s\n", err.Error()) 42 | return 43 | } 44 | 45 | if result.Code == 0 { 46 | fmt.Printf("AI: %s\n", result.Data.Text) 47 | } 48 | 49 | fmt.Print("Done\n") 50 | } 51 | -------------------------------------------------------------------------------- /go.mod: -------------------------------------------------------------------------------- 1 | module github.com/yubing744/chatgpt-go 2 | 3 | go 1.19 4 | 5 | require ( 6 | github.com/Valgard/godotenv v0.0.0-20200511222744-8873b92a09c5 7 | github.com/google/uuid v1.3.0 8 | github.com/pkg/errors v0.9.1 9 | github.com/stretchr/testify v1.8.1 10 | golang.org/x/net v0.7.0 11 | ) 12 | 13 | require ( 14 | github.com/Valgard/go-pcre v0.0.0-20200510215507-235e400e25e9 // indirect 15 | github.com/davecgh/go-spew v1.1.1 // indirect 16 | github.com/joho/godotenv v1.5.1 // indirect 17 | github.com/pmezard/go-difflib v1.0.0 // indirect 18 | gopkg.in/yaml.v3 v3.0.1 // indirect 19 | ) 20 | -------------------------------------------------------------------------------- /go.sum: -------------------------------------------------------------------------------- 1 | github.com/Valgard/go-pcre v0.0.0-20200510215507-235e400e25e9 h1:0lG/MypDQyrv+g3GoApJwgAX8W7eUskpZaDcobxoHA4= 2 | github.com/Valgard/go-pcre v0.0.0-20200510215507-235e400e25e9/go.mod h1:dPmwfLc83w8+8tnYkKrPzon1X6EDhKQzcgeJw9cBm9g= 3 | github.com/Valgard/godotenv v0.0.0-20200511222744-8873b92a09c5 h1:wOsS79keYJAZ/JFhbeAukKOskg0YLy/K3T+d52ybI+M= 4 | github.com/Valgard/godotenv v0.0.0-20200511222744-8873b92a09c5/go.mod h1:8CKlwPUYW/JXcoFHqWYb44VlOjnrbdbUie/SOlwGaZE= 5 | github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= 6 | github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= 7 | github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= 8 
| github.com/google/uuid v1.3.0 h1:t6JiXgmwXMjEs8VusXIJk2BXHsn+wx8BZdTaoZ5fu7I= 9 | github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= 10 | github.com/joho/godotenv v1.5.1 h1:7eLL/+HRGLY0ldzfGMeQkb7vMd0as4CfYvUVzLqw0N0= 11 | github.com/joho/godotenv v1.5.1/go.mod h1:f4LDr5Voq0i2e/R5DDNOoa2zzDfwtkZa6DnEwAbqwq4= 12 | github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= 13 | github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= 14 | github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= 15 | github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= 16 | github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= 17 | github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= 18 | github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= 19 | github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= 20 | github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= 21 | github.com/stretchr/testify v1.8.1 h1:w7B6lhMri9wdJUVmEZPGGhZzrYTPvgJArz7wNPgYKsk= 22 | github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= 23 | golang.org/x/net v0.7.0 h1:rJrUqqhjsgNp7KqAIc25s9pZnjU7TUcSY7HcVZjdn1g= 24 | golang.org/x/net v0.7.0/go.mod h1:2Tu9+aMcznHK/AK1HMvgo6xiTLG5rD5rZLDS+rp2Bjs= 25 | gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= 26 | gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= 27 | gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= 28 | gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= 29 | gopkg.in/yaml.v3 v3.0.1/go.mod 
h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= 30 | -------------------------------------------------------------------------------- /pkg/ask.go: -------------------------------------------------------------------------------- 1 | package pkg 2 | 3 | import ( 4 | "bufio" 5 | "context" 6 | "encoding/json" 7 | "fmt" 8 | "io/ioutil" 9 | "net/http" 10 | "strings" 11 | 12 | "github.com/google/uuid" 13 | "github.com/pkg/errors" 14 | ) 15 | 16 | type Message struct { 17 | Text string 18 | ConversationID string 19 | ParentID string 20 | } 21 | 22 | type AskResult struct { 23 | Code int 24 | Detail string `json:"detail"` 25 | Data *Message 26 | } 27 | 28 | func (client *ChatgptClient) Ask(ctx context.Context, prompt string, conversationId *string, parentId *string) (*AskResult, error) { 29 | url := fmt.Sprintf("%s/%s", client.baseURL, "api/conversation") 30 | headers := http.Header{ 31 | "Accept": {"application/json; charset=utf-8"}, 32 | } 33 | 34 | data := map[string]interface{}{ 35 | "action": "next", 36 | "messages": []map[string]interface{}{ 37 | { 38 | "id": uuid.New().String(), 39 | "role": "user", 40 | "content": map[string]interface{}{ 41 | "content_type": "text", 42 | "parts": []string{prompt}, 43 | }, 44 | }, 45 | }, 46 | "model": "text-davinci-002-render-sha", 47 | } 48 | 49 | if conversationId != nil { 50 | data["conversation_id"] = *conversationId 51 | } 52 | 53 | if parentId != nil { 54 | data["parent_message_id"] = *parentId 55 | } else { 56 | data["parent_message_id"] = uuid.New().String() 57 | } 58 | 59 | payload, _ := json.Marshal(data) 60 | resp, err := client.session.Post(url, headers, payload, true) 61 | if err != nil { 62 | return nil, errors.Wrapf(err, "error in get %s", url) 63 | } 64 | 65 | defer resp.Body.Close() 66 | 67 | if resp.StatusCode == http.StatusOK { 68 | result := &AskResult{ 69 | Code: 0, 70 | Detail: "", 71 | } 72 | 73 | msgs, err := client.parseResponse(resp) 74 | if err != nil { 75 | return nil, err 76 | } 77 | 78 | if len(msgs) > 
0 { 79 | result.Data = msgs[len(msgs)-1] 80 | } 81 | 82 | return result, nil 83 | } 84 | 85 | body, _ := ioutil.ReadAll(resp.Body) 86 | return nil, errors.Errorf("Error in ask: %s", string(body)) 87 | } 88 | 89 | func (client *ChatgptClient) parseResponse(response *http.Response) ([]*Message, error) { 90 | log := client.logger 91 | 92 | messages := make([]*Message, 0) 93 | 94 | log.Printf("\n") 95 | log.Printf("Parse response ") 96 | 97 | scanner := bufio.NewScanner(response.Body) 98 | for scanner.Scan() { 99 | line := scanner.Text() 100 | 101 | log.Printf(".") 102 | 103 | if client.debug { 104 | log.Printf("new line: %s\n", line) 105 | } 106 | 107 | if line == "" { 108 | continue 109 | } 110 | 111 | if strings.HasPrefix(line, "event: ") { 112 | continue 113 | } 114 | 115 | if !strings.HasPrefix(line, "data: ") { 116 | log.Printf("line: %s\n", line) 117 | 118 | line = strings.ReplaceAll(line, `\"`, `"`) 119 | line = strings.ReplaceAll(line, `\'`, `'`) 120 | line = strings.ReplaceAll(line, `\\`, `\`) 121 | 122 | var data struct { 123 | Detail string `json:"detail"` 124 | } 125 | err := json.Unmarshal([]byte(line), &data) 126 | if err != nil { 127 | return nil, errors.New(line) 128 | } 129 | 130 | return nil, errors.New(data.Detail) 131 | } 132 | 133 | line = strings.TrimPrefix(line, "data: ") 134 | if line == "[DONE]" { 135 | break 136 | } 137 | 138 | var parsedLine map[string]interface{} 139 | err := json.Unmarshal([]byte(line), &parsedLine) 140 | if err != nil { 141 | log.Printf("Error in Unmarshal: %s\n", line) 142 | continue 143 | } 144 | 145 | if !checkFields(parsedLine) { 146 | log.Printf("Field missing\n") 147 | log.Printf("%v", parsedLine) 148 | continue 149 | } 150 | 151 | messageContextType := parsedLine["message"].(map[string]interface{})["content"].(map[string]interface{})["content_type"].(string) 152 | if messageContextType == "text" { 153 | message := 
// checkFields reports whether parsedLine carries the fields parseResponse
// reads after calling it: a top-level "conversation_id", a "message" object,
// a "content" object inside it, and "content_type" and "parts" inside that.
//
// It returns false, rather than panicking, when "message" or "content" is
// absent, null or not a JSON object. Only presence is checked for
// "conversation_id", "content_type" and "parts"; their value types are not
// validated here.
func checkFields(parsedLine map[string]interface{}) bool {
	if _, ok := parsedLine["conversation_id"]; !ok {
		return false
	}

	// Two-value assertions: a missing key yields a nil interface, which the
	// single-value form would turn into a runtime panic.
	message, ok := parsedLine["message"].(map[string]interface{})
	if !ok {
		return false
	}

	content, ok := message["content"].(map[string]interface{})
	if !ok {
		return false
	}

	if _, ok := content["content_type"]; !ok {
		return false
	}

	_, ok = content["parts"]
	return ok
}
TestParseResponseForUnmarshalError(t *testing.T) { 26 | client := NewChatgptClient("test", "test") 27 | 28 | body := `data: 2023-02-21 07:00:21.653311` 29 | resp := &http.Response{ 30 | Body: ioutil.NopCloser(bytes.NewReader([]byte(body))), 31 | } 32 | 33 | msgs, err := client.parseResponse(resp) 34 | assert.NoError(t, err) 35 | assert.NotNil(t, msgs) 36 | } 37 | 38 | func TestParseResponseForDetail(t *testing.T) { 39 | client := NewChatgptClient("test", "test") 40 | 41 | body := `{"detail":"Too many requests in 1 hour. Try again later."}` 42 | resp := &http.Response{ 43 | Body: ioutil.NopCloser(bytes.NewReader([]byte(body))), 44 | } 45 | 46 | _, err := client.parseResponse(resp) 47 | assert.Error(t, err) 48 | assert.Equal(t, "Too many requests in 1 hour. Try again later.", err.Error()) 49 | } 50 | 51 | func TestParseResponseForServerError(t *testing.T) { 52 | client := NewChatgptClient("test", "test") 53 | 54 | body := `{"detail":{"message":"The server had an error while processing your request. Sorry about that! You can retry your request, or contact us through our help center at help.openai.com if the error persists. (Please include the request ID 985e0eeb2c44145e93637d2d79d416cf in your message.)","type":"server_error","param":null,"code":null}}` 55 | resp := &http.Response{ 56 | Body: ioutil.NopCloser(bytes.NewReader([]byte(body))), 57 | } 58 | 59 | _, err := client.parseResponse(resp) 60 | assert.Error(t, err) 61 | assert.Equal(t, `{"detail":{"message":"The server had an error while processing your request. Sorry about that! You can retry your request, or contact us through our help center at help.openai.com if the error persists. 
(Please include the request ID 985e0eeb2c44145e93637d2d79d416cf in your message.)","type":"server_error","param":null,"code":null}}`, err.Error()) 62 | } 63 | 64 | func TestParseResponseForInternalServerError(t *testing.T) { 65 | client := NewChatgptClient("test", "test") 66 | 67 | body := `Internal Server Error` 68 | resp := &http.Response{ 69 | Body: ioutil.NopCloser(bytes.NewReader([]byte(body))), 70 | } 71 | 72 | _, err := client.parseResponse(resp) 73 | assert.Error(t, err) 74 | assert.Equal(t, "Internal Server Error", err.Error()) 75 | } 76 | 77 | func TestParseResponseForLongMessagesShouldBeOk(t *testing.T) { 78 | client := NewChatgptClient("test", "test") 79 | 80 | body := `data: {"message": {"id": "d4cb6686-6ef7-46da-87d7-df16abbec928", "role": "assistant", "user": null, "create_time": null, "update_time": null, "content": {"content_type": "text", "parts": ["The term \"911\" is commonly associated with the emergency telephone number in the United States and Canada. When someone dials 911, it connects them with emergency services such as police, fire, or medical"]}, "end_turn": null, "weight": 1.0, "metadata": {"message_type": "next", "model_slug": "text-davinci-002-render-sha"}, "recipient": "all"}, "conversation_id": "45e1a523-c85a-4d11-96c5-b91e38f0ee83", "error": null}` + "\n" 81 | body = body + "data: [DONE]" 82 | 83 | resp := &http.Response{ 84 | Body: ioutil.NopCloser(bytes.NewReader([]byte(body))), 85 | } 86 | 87 | msgs, err := client.parseResponse(resp) 88 | assert.NoError(t, err) 89 | assert.Len(t, msgs, 1) 90 | assert.Equal(t, `The term "911" is commonly associated with the emergency telephone number in the United States and Canada. 
When someone dials 911, it connects them with emergency services such as police, fire, or medical`, msgs[0].Text) 91 | } 92 | -------------------------------------------------------------------------------- /pkg/auth/chatgpt_auth.go: -------------------------------------------------------------------------------- 1 | package auth 2 | 3 | import ( 4 | "encoding/json" 5 | "fmt" 6 | "io/ioutil" 7 | "log" 8 | "net/http" 9 | "net/url" 10 | "strings" 11 | "time" 12 | 13 | "github.com/pkg/errors" 14 | "github.com/yubing744/chatgpt-go/pkg/httpx" 15 | "github.com/yubing744/chatgpt-go/pkg/utils" 16 | ) 17 | 18 | // Error represents the base error class 19 | type Error struct { 20 | location string 21 | statusCode int 22 | details string 23 | } 24 | 25 | func (e *Error) Error() string { 26 | return e.details 27 | } 28 | 29 | // Authenticator represents the OpenAI Authentication Reverse Engineered 30 | type Authenticator struct { 31 | sessionToken string 32 | emailAddress string 33 | password string 34 | proxy string 35 | session *httpx.HttpSession 36 | accessToken string 37 | userAgent string 38 | } 39 | 40 | // NewAuthenticator creates a new instance of Authenticator 41 | func NewAuthenticator(emailAddress, password, proxy string) *Authenticator { 42 | auth := &Authenticator{ 43 | emailAddress: emailAddress, 44 | password: password, 45 | proxy: proxy, 46 | userAgent: "Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/109.0.0.0 Safari/537.36", 47 | } 48 | 49 | session, err := httpx.NewHttpSession(time.Second * 60) 50 | if err != nil { 51 | log.Fatal("init http session fail") 52 | } 53 | 54 | auth.session = session 55 | 56 | return auth 57 | } 58 | 59 | // urlEncode encodes the string to URL format 60 | func urlEncode(str string) string { 61 | return url.QueryEscape(str) 62 | } 63 | 64 | // begin starts the authentication process 65 | func (a *Authenticator) Begin() error { 66 | url := "https://explorer.api.openai.com/api/auth/csrf" 67 | headers := 
http.Header{ 68 | "Host": {"explorer.api.openai.com"}, 69 | "Accept": {"*/*"}, 70 | "Connection": {"keep-alive"}, 71 | "User-Agent": {a.userAgent}, 72 | "Accept-Language": {"en-GB,en-US;q=0.9,en;q=0.8"}, 73 | "Referer": {"https://explorer.api.openai.com/auth/login"}, 74 | "Accept-Encoding": {"gzip, deflate, br"}, 75 | } 76 | 77 | resp, err := a.session.Get(url, headers, true) 78 | if err != nil { 79 | return errors.Wrapf(err, "error in get %s", url) 80 | } 81 | 82 | defer resp.Body.Close() 83 | 84 | if resp.StatusCode == http.StatusOK && 85 | strings.Contains(resp.Header.Get("Content-Type"), "application/json") { 86 | var data struct { 87 | CsrfToken string `json:"csrfToken"` 88 | } 89 | if err := json.NewDecoder(resp.Body).Decode(&data); err != nil { 90 | return err 91 | } 92 | 93 | err := a.partOne(data.CsrfToken) 94 | if err != nil { 95 | return err 96 | } 97 | } else { 98 | body, _ := ioutil.ReadAll(resp.Body) 99 | return &Error{ 100 | location: "Begin", 101 | statusCode: resp.StatusCode, 102 | details: fmt.Sprintf("response error, detail: %s", body), 103 | } 104 | } 105 | 106 | return nil 107 | } 108 | 109 | func (a *Authenticator) partOne(token string) error { 110 | url := "https://explorer.api.openai.com/api/auth/signin/auth0?prompt=login" 111 | payload := fmt.Sprintf("callbackUrl=%s&csrfToken=%s&json=true", "%2F", token) 112 | 113 | headers := http.Header{} 114 | headers.Set("Content-Type", "application/x-www-form-urlencoded") 115 | headers.Set("User-Agent", a.userAgent) 116 | headers.Set("Host", "explorer.api.openai.com") 117 | headers.Set("Accept", "*/*") 118 | headers.Set("Accept-Language", "en-US,en;q=0.8") 119 | headers.Set("Origin", "https://explorer.api.openai.com") 120 | headers.Set("Referer", "https://explorer.api.openai.com/auth/login") 121 | headers.Set("Accept-Encoding", "gzip, deflate") 122 | 123 | resp, err := a.session.Post(url, headers, []byte(payload), true) 124 | if err != nil { 125 | return errors.Wrapf(err, "error in get %s", url) 126 | 
} 127 | 128 | defer resp.Body.Close() 129 | 130 | if resp.StatusCode == http.StatusOK && 131 | strings.Contains(resp.Header.Get("Content-Type"), "application/json") { 132 | var data struct { 133 | URL string `json:"url"` 134 | } 135 | if err := json.NewDecoder(resp.Body).Decode(&data); err != nil { 136 | return err 137 | } 138 | 139 | if data.URL == "https://explorer.api.openai.com/api/auth/error?error=OAuthSignin" || strings.Contains(data.URL, "error") { 140 | return &Error{ 141 | location: "partOne", 142 | statusCode: resp.StatusCode, 143 | details: "You have been rate limited. Please try again later.", 144 | } 145 | } 146 | 147 | err := a.partTwo(data.URL) 148 | if err != nil { 149 | return err 150 | } 151 | } else { 152 | body, _ := ioutil.ReadAll(resp.Body) 153 | return &Error{ 154 | location: "partOne", 155 | statusCode: resp.StatusCode, 156 | details: fmt.Sprintf("response error, detail: %s", body), 157 | } 158 | } 159 | 160 | return nil 161 | } 162 | 163 | func (a *Authenticator) partTwo(url string) error { 164 | headers := http.Header{ 165 | "Host": {"auth0.openai.com"}, 166 | "Accept": {"text/html,application/xhtml+xml,application/xml;q=0.9,*/*;q=0.8"}, 167 | "Connection": {"keep-alive"}, 168 | "User-Agent": {a.userAgent}, 169 | "Accept-Language": {"en-US,en;q=0.9"}, 170 | "Referer": {"https://explorer.api.openai.com/"}, 171 | } 172 | 173 | resp, err := a.session.Get(url, headers, true) 174 | if err != nil { 175 | return errors.Wrapf(err, "error in get %s", url) 176 | } 177 | 178 | defer resp.Body.Close() 179 | 180 | if resp.StatusCode == http.StatusFound || resp.StatusCode == http.StatusOK { 181 | body, err := ioutil.ReadAll(resp.Body) 182 | if err != nil { 183 | return errors.Wrap(err, "error in read body in part three") 184 | } 185 | 186 | bodyString := string(body) 187 | state, ok := utils.RegexpExtra(bodyString, `state=([a-zA-Z0-9-_]*)`, 1) 188 | if !ok { 189 | return errors.New("not found state in respone body") 190 | } 191 | 192 | err = 
a.partThree(state) 193 | if err != nil { 194 | return err 195 | } 196 | } else { 197 | body, _ := ioutil.ReadAll(resp.Body) 198 | return &Error{ 199 | location: "partTwo", 200 | statusCode: resp.StatusCode, 201 | details: fmt.Sprintf("response error, detail: %s", body), 202 | } 203 | } 204 | 205 | return nil 206 | } 207 | 208 | func (auth *Authenticator) partThree(state string) error { 209 | url := fmt.Sprintf("https://auth0.openai.com/u/login/identifier?state=%s", state) 210 | headers := http.Header{ 211 | "Host": []string{"auth0.openai.com"}, 212 | "Accept": []string{"text/html,application/xhtml+xml,application/xml;q=0.9,*/*;q=0.8"}, 213 | "Connection": []string{"keep-alive"}, 214 | "User-Agent": []string{auth.userAgent}, 215 | "Accept-Language": []string{"en-US,en;q=0.9"}, 216 | "Referer": []string{"https://explorer.api.openai.com/"}, 217 | } 218 | 219 | resp, err := auth.session.Get(url, headers, true) 220 | if err != nil { 221 | return errors.Wrapf(err, "error in get %s", url) 222 | } 223 | 224 | defer resp.Body.Close() 225 | 226 | if resp.StatusCode == http.StatusOK { 227 | err := auth.partFour(state) 228 | if err != nil { 229 | return err 230 | } 231 | } else { 232 | body, _ := ioutil.ReadAll(resp.Body) 233 | return &Error{ 234 | location: "partThree", 235 | statusCode: resp.StatusCode, 236 | details: fmt.Sprintf("response error, detail: %s", body), 237 | } 238 | } 239 | 240 | return nil 241 | } 242 | 243 | func (a *Authenticator) partFour(state string) error { 244 | url := fmt.Sprintf("https://auth0.openai.com/u/login/identifier?state=%s", state) 245 | emailURLEncoded := urlEncode(a.emailAddress) 246 | 247 | headers := http.Header{} 248 | headers.Add("Host", "auth0.openai.com") 249 | headers.Add("Origin", "https://auth0.openai.com") 250 | headers.Add("Connection", "keep-alive") 251 | headers.Add("Accept", "text/html,application/xhtml+xml,application/xml;q=0.9,*/*;q=0.8") 252 | headers.Add("User-Agent", a.userAgent) 253 | headers.Add("Referer", 
fmt.Sprintf("https://auth0.openai.com/u/login/identifier?state=%s", state)) 254 | headers.Add("Accept-Language", "en-US,en;q=0.9") 255 | headers.Add("Content-Type", "application/x-www-form-urlencoded") 256 | 257 | payload := fmt.Sprintf("state=%s&username=%s&js-available=false&webauthn-available=true&is-brave=false&webauthn-platform-available=true&action=default", state, emailURLEncoded) 258 | 259 | resp, err := a.session.Post(url, headers, []byte(payload), true) 260 | if err != nil { 261 | return errors.Wrapf(err, "error in get %s", url) 262 | } 263 | 264 | defer resp.Body.Close() 265 | 266 | if resp.StatusCode == 302 || resp.StatusCode == 200 { 267 | err = a.partFive(state) 268 | if err != nil { 269 | return err 270 | } 271 | } else { 272 | body, _ := ioutil.ReadAll(resp.Body) 273 | return &Error{ 274 | location: "partFour", 275 | statusCode: resp.StatusCode, 276 | details: fmt.Sprintf("response error, detail: %s", body), 277 | } 278 | } 279 | 280 | return nil 281 | } 282 | 283 | func (a *Authenticator) partFive(state string) error { 284 | url := fmt.Sprintf("https://auth0.openai.com/u/login/password?state=%s", state) 285 | headers := http.Header{ 286 | "Host": {"auth0.openai.com"}, 287 | "Origin": {"https://auth0.openai.com"}, 288 | "Connection": {"keep-alive"}, 289 | "Accept": {"text/html,application/xhtml+xml,application/xml;q=0.9,*/*;q=0.8"}, 290 | "User-Agent": {a.userAgent}, 291 | "Referer": {fmt.Sprintf("https://auth0.openai.com/u/login/password?state=%s", state)}, 292 | "Accept-Language": {"en-US,en;q=0.9"}, 293 | "Content-Type": {"application/x-www-form-urlencoded"}, 294 | } 295 | 296 | emailURLEncoded := urlEncode(a.emailAddress) 297 | passwordURLEncoded := urlEncode(a.password) 298 | payload := fmt.Sprintf("state=%s&username=%s&password=%s&action=default", state, emailURLEncoded, passwordURLEncoded) 299 | 300 | resp, err := a.session.Post(url, headers, []byte(payload), false) 301 | if err != nil { 302 | return errors.Wrapf(err, "error in get %s", url) 
303 | } 304 | 305 | defer resp.Body.Close() 306 | 307 | if resp.StatusCode == 302 || resp.StatusCode == 200 { 308 | body, err := ioutil.ReadAll(resp.Body) 309 | if err != nil { 310 | return errors.Wrap(err, "error in read body in part five") 311 | } 312 | 313 | bodyString := string(body) 314 | newState, ok := utils.RegexpExtra(bodyString, `state=([a-zA-Z0-9-_]*)`, 1) 315 | if !ok { 316 | fmt.Print(bodyString) 317 | return errors.New("not found state in respone body of part five") 318 | } 319 | 320 | err = a.partSix(state, newState) 321 | if err != nil { 322 | return err 323 | } 324 | } else { 325 | body, _ := ioutil.ReadAll(resp.Body) 326 | return &Error{ 327 | location: "partFive", 328 | statusCode: resp.StatusCode, 329 | details: fmt.Sprintf("response error, detail: %s", body), 330 | } 331 | } 332 | 333 | return nil 334 | } 335 | 336 | func (a *Authenticator) partSix(oldState, newState string) error { 337 | url := fmt.Sprintf("https://auth0.openai.com/authorize/resume?state=%s", newState) 338 | 339 | headers := http.Header{ 340 | "Host": {"auth0.openai.com"}, 341 | "Accept": {"text/html,application/xhtml+xml,application/xml;q=0.9,*/*;q=0.8"}, 342 | "Connection": {"keep-alive"}, 343 | "User-Agent": {a.userAgent}, 344 | "Accept-Language": {"en-GB,en-US;q=0.9,en;q=0.8"}, 345 | "Referer": {fmt.Sprintf("https://auth0.openai.com/u/login/password?state=%s", oldState)}, 346 | } 347 | 348 | resp, err := a.session.Get(url, headers, false) 349 | if err != nil { 350 | return errors.Wrapf(err, "error in get %s", url) 351 | } 352 | 353 | defer resp.Body.Close() 354 | 355 | if resp.StatusCode == 302 { 356 | // Print redirect url 357 | redirectURL := resp.Header.Get("Location") 358 | if err = a.partSeven(redirectURL, url); err != nil { 359 | return err 360 | } 361 | } else { 362 | body, _ := ioutil.ReadAll(resp.Body) 363 | return &Error{ 364 | location: "partSix", 365 | statusCode: resp.StatusCode, 366 | details: fmt.Sprintf("response error, detail: %s", body), 367 | } 368 | } 
369 | 370 | return nil 371 | } 372 | 373 | func (a *Authenticator) partSeven(redirectURL string, previousURL string) error { 374 | url := redirectURL 375 | headers := http.Header{ 376 | "Host": {"explorer.api.openai.com"}, 377 | "Accept": {"application/json"}, 378 | "Connection": {"keep-alive"}, 379 | "User-Agent": {a.userAgent}, 380 | "Accept-Language": {"en-GB,en-US;q=0.9,en;q=0.8"}, 381 | "Referer": {previousURL}, 382 | } 383 | 384 | resp, err := a.session.Get(url, headers, false) 385 | if err != nil { 386 | return errors.Wrapf(err, "error in get %s", url) 387 | } 388 | 389 | defer resp.Body.Close() 390 | 391 | if resp.StatusCode == 302 { 392 | cookies := httpx.Coookies(resp.Cookies()) 393 | sessionToken, ok := cookies.Get("__Secure-next-auth.session-token") 394 | 395 | if ok { 396 | a.sessionToken = sessionToken 397 | _, err = a.GetAccessToken() 398 | if err != nil { 399 | return err 400 | } 401 | } 402 | } else { 403 | body, _ := ioutil.ReadAll(resp.Body) 404 | return &Error{ 405 | location: "partSeven", 406 | statusCode: resp.StatusCode, 407 | details: fmt.Sprintf("response error, detail: %s", body), 408 | } 409 | } 410 | 411 | return nil 412 | } 413 | 414 | func (a *Authenticator) GetSessionToken() string { 415 | return a.sessionToken 416 | } 417 | 418 | func (a *Authenticator) GetAccessToken() (string, error) { 419 | a.session.Cookies("openai.com").Set( 420 | "__Secure-next-auth.session-token", 421 | a.sessionToken, 422 | ) 423 | 424 | resp, err := a.session.Get("https://explorer.api.openai.com/api/auth/session", nil, true) 425 | if err != nil { 426 | return "", err 427 | } 428 | if resp.StatusCode == 200 && 429 | strings.Contains(resp.Header.Get("Content-Type"), "application/json") { 430 | var data struct { 431 | AccessToken string `json:"accessToken"` 432 | } 433 | if err := json.NewDecoder(resp.Body).Decode(&data); err != nil { 434 | return "", err 435 | } 436 | a.accessToken = data.AccessToken 437 | return a.accessToken, nil 438 | } else { 439 | body, _ 
:= ioutil.ReadAll(resp.Body) 440 | return "", &Error{ 441 | location: "GetAccessToken", 442 | statusCode: resp.StatusCode, 443 | details: fmt.Sprintf("response error, detail: %s", body), 444 | } 445 | } 446 | } 447 | -------------------------------------------------------------------------------- /pkg/auth/chatgpt_auth_test.go: -------------------------------------------------------------------------------- 1 | package auth 2 | -------------------------------------------------------------------------------- /pkg/client.go: -------------------------------------------------------------------------------- 1 | package pkg 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | "log" 7 | "net/http" 8 | "time" 9 | 10 | "github.com/pkg/errors" 11 | "github.com/yubing744/chatgpt-go/pkg/auth" 12 | "github.com/yubing744/chatgpt-go/pkg/httpx" 13 | ) 14 | 15 | // Logger is used for logging formatted messages. 16 | type Logger interface { 17 | // Printf must have the same semantics as log.Printf. 18 | Printf(format string, args ...interface{}) 19 | } 20 | 21 | type ChatgptClient struct { 22 | session *httpx.HttpSession 23 | auth *auth.Authenticator 24 | logger Logger 25 | cancel context.CancelFunc 26 | baseURL string 27 | debug bool 28 | } 29 | 30 | func NewChatgptClient(email string, password string, opts ...Option) *ChatgptClient { 31 | cfg := &Options{ 32 | baseURL: "https://chatgpt.duti.tech", 33 | timeout: time.Second * 300, 34 | proxy: "", 35 | debug: false, 36 | logger: &Log{}, 37 | } 38 | 39 | for _, opt := range opts { 40 | opt(cfg) 41 | } 42 | 43 | client := &ChatgptClient{ 44 | baseURL: cfg.baseURL, 45 | debug: cfg.debug, 46 | logger: cfg.logger, 47 | } 48 | 49 | session, err := httpx.NewHttpSession(cfg.timeout) 50 | if err != nil { 51 | log.Fatal("init http session fatal") 52 | } 53 | 54 | client.session = session 55 | client.auth = auth.NewAuthenticator(email, password, cfg.proxy) 56 | 57 | return client 58 | } 59 | 60 | func (client *ChatgptClient) Start(ctx context.Context) 
error { 61 | ctx, client.cancel = context.WithCancel(ctx) 62 | 63 | err := client.auth.Begin() 64 | if err != nil { 65 | return errors.Wrap(err, "Error in auth") 66 | } 67 | 68 | err = client.refreshToken() 69 | if err != nil { 70 | return err 71 | } 72 | 73 | ticker := time.NewTicker(10 * time.Minute) // 每 10 分钟刷新一次 token 74 | 75 | go func() { 76 | for { 77 | select { 78 | case <-ctx.Done(): 79 | client.logger.Printf("stop ticker ...\n") 80 | ticker.Stop() 81 | return 82 | case <-ticker.C: 83 | // 执行刷新 token 的逻辑 84 | err := client.refreshToken() 85 | if err != nil { 86 | client.logger.Printf("fresh token error: %s\n", err.Error()) 87 | continue 88 | } 89 | } 90 | } 91 | }() 92 | 93 | return nil 94 | } 95 | 96 | func (client *ChatgptClient) Stop() { 97 | if client.cancel != nil { 98 | client.cancel() 99 | } 100 | } 101 | 102 | func (client *ChatgptClient) refreshToken() error { 103 | client.logger.Printf("fresh token ...\n") 104 | 105 | accessToken, err := client.auth.GetAccessToken() 106 | if err != nil { 107 | return errors.Wrap(err, "Error in get access token") 108 | } 109 | 110 | client.session.SetHeaders(http.Header{ 111 | "Accept": {"text/event-stream"}, 112 | "Authorization": {fmt.Sprintf("Bearer %s", accessToken)}, 113 | "Content-Type": {"application/json"}, 114 | "X-Openai-Assistant-App-Id": {""}, 115 | "Connection": {"close"}, 116 | "Accept-Language": {"en-US,en;q=0.9"}, 117 | "Referer": {"https://chat.openai.com/chat"}, 118 | }) 119 | 120 | client.logger.Printf("fresh token ok!\n") 121 | 122 | return nil 123 | } 124 | -------------------------------------------------------------------------------- /pkg/client_test.go: -------------------------------------------------------------------------------- 1 | package pkg 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/stretchr/testify/assert" 7 | ) 8 | 9 | func TestNewChatgptClient(t *testing.T) { 10 | client := NewChatgptClient("test", "test") 11 | assert.NotNil(t, client) 12 | } 13 | 
// Coookies is a list of HTTP cookies that can be queried and updated by
// cookie name.
type Coookies []*http.Cookie

// Get returns the value of the first cookie called name. The boolean result
// is false, and the value empty, when the list holds no such cookie.
func (c Coookies) Get(name string) (string, bool) {
	for i := 0; i < len(c); i++ {
		if cookie := c[i]; cookie.Name == name {
			return cookie.Value, true
		}
	}

	return "", false
}

// Set overwrites the value of the first cookie called name with val and
// reports whether such a cookie was found. It never adds a cookie: when no
// entry matches, the list is left unchanged and false is returned.
func (c Coookies) Set(name string, val string) bool {
	for i := range c {
		if c[i].Name != name {
			continue
		}

		c[i].Value = val
		return true
	}

	return false
}
val, ok := cookies.Get("SUID") 51 | assert.True(t, ok) 52 | assert.Equal(t, "xxxx", val) 53 | } 54 | -------------------------------------------------------------------------------- /pkg/httpx/http_session.go: -------------------------------------------------------------------------------- 1 | package httpx 2 | 3 | import ( 4 | "bytes" 5 | "net/http" 6 | "net/http/cookiejar" 7 | "net/url" 8 | "time" 9 | 10 | "golang.org/x/net/publicsuffix" 11 | ) 12 | 13 | // HttpSession 封装了 http.Client,实现了会话保持和 headers 的传递 14 | type HttpSession struct { 15 | client *http.Client 16 | headers http.Header 17 | } 18 | 19 | // NewHttpSessionClient 返回一个新的 HttpSessionClient 实例 20 | func NewHttpSession(timeout time.Duration) (*HttpSession, error) { 21 | opts := &cookiejar.Options{ 22 | PublicSuffixList: publicsuffix.List, 23 | } 24 | 25 | cookieJar, err := cookiejar.New(opts) 26 | if err != nil { 27 | return nil, err 28 | } 29 | httpClient := &http.Client{ 30 | Timeout: timeout, 31 | Jar: cookieJar, 32 | } 33 | return &HttpSession{ 34 | client: httpClient, 35 | }, nil 36 | } 37 | 38 | // Get 发送 GET 请求 39 | func (httpx *HttpSession) Get(url string, headers http.Header, allowRedirects bool) (*http.Response, error) { 40 | req, err := http.NewRequest("GET", url, nil) 41 | if err != nil { 42 | return nil, err 43 | } 44 | 45 | if httpx.headers != nil { 46 | for key, value := range httpx.headers { 47 | req.Header.Set(key, value[0]) 48 | } 49 | } 50 | 51 | for key, value := range headers { 52 | req.Header.Set(key, value[0]) 53 | } 54 | 55 | if !allowRedirects { 56 | httpx.client.CheckRedirect = func(req *http.Request, via []*http.Request) error { 57 | return http.ErrUseLastResponse 58 | } 59 | 60 | defer func() { 61 | httpx.client.CheckRedirect = nil 62 | }() 63 | } 64 | 65 | resp, err := httpx.client.Do(req) 66 | if err != nil { 67 | return nil, err 68 | } 69 | 70 | return resp, nil 71 | } 72 | 73 | // Post 发送 POST 请求 74 | func (httpx *HttpSession) Post(url string, headers http.Header, data 
[]byte, allowRedirects bool) (*http.Response, error) { 75 | req, err := http.NewRequest("POST", url, bytes.NewReader(data)) 76 | if err != nil { 77 | return nil, err 78 | } 79 | 80 | if httpx.headers != nil { 81 | for key, value := range httpx.headers { 82 | req.Header.Set(key, value[0]) 83 | } 84 | } 85 | 86 | for key, value := range headers { 87 | req.Header.Set(key, value[0]) 88 | } 89 | 90 | if !allowRedirects { 91 | httpx.client.CheckRedirect = func(req *http.Request, via []*http.Request) error { 92 | return http.ErrUseLastResponse 93 | } 94 | 95 | defer func() { 96 | httpx.client.CheckRedirect = nil 97 | }() 98 | } 99 | 100 | resp, err := httpx.client.Do(req) 101 | if err != nil { 102 | return nil, err 103 | } 104 | 105 | return resp, nil 106 | } 107 | 108 | // Cookies returns the value of a cookie 109 | func (httpx *HttpSession) Cookies(host string) Coookies { 110 | domain := &url.URL{ 111 | Scheme: "https", 112 | Host: host, 113 | Path: "/", 114 | } 115 | 116 | rawCookies := httpx.client.Jar.Cookies(domain) 117 | if rawCookies != nil { 118 | return Coookies(rawCookies) 119 | } 120 | 121 | return nil 122 | } 123 | 124 | func (httpx *HttpSession) SetHeaders(headers http.Header) { 125 | httpx.headers = headers 126 | } 127 | -------------------------------------------------------------------------------- /pkg/httpx/http_session_test.go: -------------------------------------------------------------------------------- 1 | package httpx 2 | 3 | import ( 4 | "testing" 5 | "time" 6 | 7 | "github.com/stretchr/testify/assert" 8 | ) 9 | 10 | func TestHttpSession(t *testing.T) { 11 | client, err := NewHttpSession(time.Second * 5) 12 | assert.NoError(t, err) 13 | assert.NotNil(t, client) 14 | } 15 | 16 | func TestHTTPXGet(t *testing.T) { 17 | client, err := NewHttpSession(time.Second * 5) 18 | assert.NoError(t, err) 19 | assert.NotNil(t, client) 20 | 21 | resp, err := client.Get("https://www.bing.com/", nil, true) 22 | if resp != nil { 23 | defer resp.Body.Close() 24 | } 
25 | 26 | assert.NoError(t, err) 27 | assert.NotEmpty(t, resp) 28 | } 29 | 30 | func TestHTTPXGetCookies(t *testing.T) { 31 | client, err := NewHttpSession(time.Second * 5) 32 | assert.NoError(t, err) 33 | assert.NotNil(t, client) 34 | 35 | resp, err := client.Get("https://www.bing.com/", nil, true) 36 | if resp != nil { 37 | defer resp.Body.Close() 38 | } 39 | 40 | assert.NoError(t, err) 41 | assert.NotEmpty(t, resp) 42 | 43 | cookies := client.Cookies("bing.com") 44 | assert.NotNil(t, cookies) 45 | } 46 | -------------------------------------------------------------------------------- /pkg/log.go: -------------------------------------------------------------------------------- 1 | package pkg 2 | 3 | import "fmt" 4 | 5 | type Log struct{} 6 | 7 | func (log *Log) Printf(format string, args ...interface{}) { 8 | fmt.Printf(format, args...) 9 | } 10 | -------------------------------------------------------------------------------- /pkg/log_test.go: -------------------------------------------------------------------------------- 1 | package pkg 2 | 3 | import "testing" 4 | 5 | func TestLog(t *testing.T) { 6 | log := &Log{} 7 | log.Printf("hello") 8 | } 9 | -------------------------------------------------------------------------------- /pkg/option.go: -------------------------------------------------------------------------------- 1 | package pkg 2 | 3 | import "time" 4 | 5 | type Options struct { 6 | baseURL string 7 | proxy string 8 | timeout time.Duration 9 | debug bool 10 | logger Logger 11 | } 12 | 13 | type Option func(opts *Options) 14 | 15 | func WithOptions(options Options) Option { 16 | return func(opts *Options) { 17 | *opts = options 18 | } 19 | } 20 | 21 | func WithBaseURL(baseURL string) Option { 22 | return func(opts *Options) { 23 | opts.baseURL = baseURL 24 | } 25 | } 26 | 27 | func WithProxy(proxy string) Option { 28 | return func(opts *Options) { 29 | opts.proxy = proxy 30 | } 31 | } 32 | 33 | func WithTimeout(timeout time.Duration) Option { 34 | 
return func(opts *Options) { 35 | opts.timeout = timeout 36 | } 37 | } 38 | 39 | func WithDebug(debug bool) Option { 40 | return func(opts *Options) { 41 | opts.debug = debug 42 | } 43 | } 44 | 45 | func WithLogger(logger Logger) Option { 46 | return func(opts *Options) { 47 | opts.logger = logger 48 | } 49 | } 50 | -------------------------------------------------------------------------------- /pkg/option_test.go: -------------------------------------------------------------------------------- 1 | package pkg 2 | 3 | import ( 4 | "testing" 5 | "time" 6 | 7 | "github.com/stretchr/testify/assert" 8 | ) 9 | 10 | func TestWithOptions(t *testing.T) { 11 | cfg := &Options{ 12 | baseURL: "https://chatgpt.duti.tech", 13 | } 14 | 15 | opt := WithOptions(Options{ 16 | baseURL: "https://chatgpt.duti.tech2", 17 | }) 18 | 19 | opt(cfg) 20 | 21 | assert.Equal(t, "https://chatgpt.duti.tech2", cfg.baseURL) 22 | } 23 | 24 | func TestWithBaseURL(t *testing.T) { 25 | cfg := &Options{ 26 | baseURL: "https://chatgpt.duti.tech", 27 | } 28 | 29 | opt := WithBaseURL("https://chatgpt.duti.tech2") 30 | 31 | opt(cfg) 32 | 33 | assert.Equal(t, "https://chatgpt.duti.tech2", cfg.baseURL) 34 | } 35 | 36 | func TestWithProxy(t *testing.T) { 37 | cfg := &Options{ 38 | proxy: "", 39 | } 40 | 41 | opt := WithProxy("127.0.0.1:8081") 42 | 43 | opt(cfg) 44 | 45 | assert.Equal(t, "127.0.0.1:8081", cfg.proxy) 46 | } 47 | 48 | func TestWithTimeout(t *testing.T) { 49 | cfg := &Options{ 50 | timeout: time.Second * 5, 51 | } 52 | 53 | opt := WithTimeout(time.Second * 10) 54 | 55 | opt(cfg) 56 | 57 | assert.Equal(t, float64(10), cfg.timeout.Seconds()) 58 | } 59 | 60 | func TestWithDebug(t *testing.T) { 61 | cfg := &Options{ 62 | debug: false, 63 | } 64 | 65 | opt := WithDebug(true) 66 | 67 | opt(cfg) 68 | 69 | assert.True(t, cfg.debug) 70 | } 71 | 72 | func TestWithLogger(t *testing.T) { 73 | cfg := &Options{ 74 | logger: nil, 75 | } 76 | 77 | opt := WithLogger(&Log{}) 78 | 79 | opt(cfg) 80 | 81 | 
// RegexpExtra returns capture group groupIndex of the first match of pattern
// in text. Group 0 is the whole match.
//
// The boolean result is false, and the string empty, when pattern is not a
// valid regular expression, when text does not match, or when groupIndex is
// negative or not a group of the pattern. An invalid pattern is reported this
// way rather than by panicking, so patterns built at run time are safe to
// pass in.
func RegexpExtra(text string, pattern string, groupIndex int) (string, bool) {
	re, err := regexp.Compile(pattern)
	if err != nil {
		return "", false
	}

	// Only the first match is needed; FindStringSubmatch yields nil when
	// there is none, which the bounds check below rejects.
	match := re.FindStringSubmatch(text)
	if groupIndex < 0 || groupIndex >= len(match) {
		return "", false
	}

	return match[groupIndex], true
}