├── LICENSE ├── README.md ├── api.go ├── binding ├── codegen │ ├── binding.go │ ├── codegen.go │ ├── decode.go │ ├── decode_anything.go │ ├── decode_binary.go │ ├── decode_enum.go │ ├── decode_map.go │ ├── decode_pointer.go │ ├── decode_simple_value.go │ ├── decode_slice.go │ ├── decode_struct.go │ ├── encode.go │ ├── encode_anything.go │ ├── encode_binary.go │ ├── encode_enum.go │ ├── encode_map.go │ ├── encode_pointer.go │ ├── encode_simple_value.go │ ├── encode_slice.go │ └── encode_struct.go └── reflection │ ├── decode.go │ ├── decode_map.go │ ├── decode_pointer.go │ ├── decode_simple_value.go │ ├── decode_slice.go │ ├── decode_struct.go │ ├── encode.go │ ├── encode_map.go │ ├── encode_pointer.go │ ├── encode_simple_value.go │ ├── encode_slice.go │ ├── encode_struct.go │ └── unsafe.go ├── cmd └── thrifter │ └── main.go ├── config.go ├── decoder.go ├── encoder.go ├── general ├── decode.go ├── decode_list.go ├── decode_map.go ├── decode_message.go ├── decode_struct.go ├── encode.go ├── encode_list.go ├── encode_map.go ├── encode_message.go ├── encode_struct.go ├── general_extension.go └── general_object.go ├── protocol ├── binary │ ├── discard.go │ ├── iterator.go │ ├── skip.go │ └── stream.go ├── compact │ ├── discard.go │ ├── iterator.go │ ├── skip.go │ ├── stream.go │ └── type.go └── protocol.go ├── raw ├── decode_list.go ├── decode_map.go ├── decode_struct.go ├── encode_list.go ├── encode_map.go ├── encode_struct.go ├── raw_extension.go └── raw_object.go ├── spi ├── discard.go └── spi.go └── test ├── api ├── api_test.go ├── binding_test.go ├── binding_test │ └── model.go ├── generated.go ├── init.go ├── raw_message_test.go └── shortcut_test.go ├── combinations.go ├── level_0 ├── bool_test.go ├── enum_test.go ├── enum_test │ └── Player.go ├── float64_test.go ├── int16_test.go ├── int32_test.go ├── int64_test.go ├── int8_test.go ├── int_test.go ├── level_0_test.go ├── uint16_test.go ├── uint32_test.go ├── uint64_test.go ├── uint8_test.go └── uint_test.go ├── 
level_1 ├── binary_test.go ├── level_1_test.go ├── list_test.go ├── map_test.go ├── pointer_test.go ├── string_test.go ├── struct_test.go └── struct_test │ └── TestObject.go └── level_2 ├── level_2_test.go ├── list_of_list_test.go ├── list_of_map_test.go ├── list_of_string_test.go ├── list_of_struct_test.go ├── list_of_struct_test └── TestObject.go ├── map_of_list_test.go ├── map_of_map_test.go ├── map_of_string_test.go ├── map_of_struct_test.go ├── map_of_struct_test └── TestObject.go ├── message_test.go ├── struct_complex_test.go ├── struct_complex_test └── TestObject.go ├── struct_of_list_test.go ├── struct_of_list_test └── TestObject.go ├── struct_of_map_test.go ├── struct_of_map_test └── TestObject.go ├── struct_of_pointer_test.go ├── struct_of_pointer_test ├── StructOf1Ptr.go └── StructOf2Ptr.go ├── struct_of_string_test.go ├── struct_of_string_test └── TestObject.go ├── struct_of_struct_test.go └── struct_of_struct_test └── TestObject.go /README.md: -------------------------------------------------------------------------------- 1 | # thrifter 2 | 3 | decode/encode thrift message without IDL 4 | 5 | Why? 6 | 7 | * because the IDL-generated model is ugly and inflexible, it is seldom used in applications directly. Instead we define another model, which leads to bad performance. 8 | * bytes need to be copied twice 9 | * more objects to gc 10 | * a thrift proxy cannot know all possible IDL in advance; in scenarios like an API gateway, we need to decode/encode in a generic way to modify the embedded header. 11 | * the official thrift library for Go is slow, verified in several benchmarks. It is even slower than [json-iterator](https://github.com/json-iterator/go) 12 | 13 | # works like encoding/json 14 | 15 | `encoding/json` has a super simple API to encode/decode JSON. 16 | thrifter mimics the same API. 
17 | 18 | ```go 19 | import "github.com/thrift-iterator/go" 20 | // marshal to thrift 21 | thriftEncodedBytes, err := thrifter.Marshal([]int{1, 2, 3}) 22 | // unmarshal back 23 | var val []int 24 | err = thrifter.Unmarshal(thriftEncodedBytes, &val) 25 | ``` 26 | 27 | Even struct data binding is supported: 28 | 29 | ```go 30 | import "github.com/thrift-iterator/go" 31 | 32 | type NewOrderRequest struct { 33 | Lines []NewOrderLine `thrift:",1"` 34 | } 35 | 36 | type NewOrderLine struct { 37 | ProductId string `thrift:",1"` 38 | Quantity int `thrift:",2"` 39 | } 40 | 41 | // marshal to thrift 42 | thriftEncodedBytes, err := thrifter.Marshal(NewOrderRequest{ 43 | Lines: []NewOrderLine{ 44 | {"apple", 1}, 45 | {"orange", 2}, 46 | }, 47 | }) 48 | // unmarshal back 49 | var val NewOrderRequest 50 | err = thrifter.Unmarshal(thriftEncodedBytes, &val) 51 | ``` 52 | 53 | # without IDL 54 | 55 | You do not need to define IDL. You do not need to use static code generation. 56 | You do not even need to define a struct. 57 | 58 | ```go 59 | import "github.com/thrift-iterator/go" 60 | import "github.com/thrift-iterator/go/general" 61 | 62 | var msg general.Message 63 | err := thrifter.Unmarshal(thriftEncodedBytes, &msg) 64 | // the RPC call method name, type is string 65 | fmt.Println(msg.MessageName) 66 | // the RPC call arguments, type is general.Struct 67 | fmt.Println(msg.Arguments) 68 | ``` 69 | 70 | What is `general.Struct`? It is defined as a map: 71 | 72 | ```go 73 | type FieldId int16 74 | type Struct map[FieldId]interface{} 75 | ``` 76 | 77 | We can extract a specific argument from deeply nested arguments in one line: 78 | 79 | ```go 80 | productId := msg.MessageArgs.Get( 81 | protocol.FieldId(1), // lines of request 82 | 0, // the first line 83 | protocol.FieldId(1), // product id 84 | ).(string) 85 | ``` 86 | 87 | You can unmarshal any thrift bytes into general objects. And you can marshal them back. 
88 | 89 | # Partial decoding 90 | 91 | Fully decoding into a Go struct consumes substantial resources. 92 | thrifter provides an option to do partial decoding. You can modify part of the 93 | message, with untouched parts kept in `[]byte` form. 94 | 95 | ```go 96 | import "github.com/thrift-iterator/go" 97 | import "github.com/thrift-iterator/go/protocol" 98 | import "github.com/thrift-iterator/go/raw" 99 | 100 | // partial decoding 101 | decoder := thrifter.NewDecoder(reader, nil) 102 | var msgHeader protocol.MessageHeader 103 | decoder.Decode(&msgHeader) 104 | var msgArgs raw.Struct 105 | decoder.Decode(&msgArgs) 106 | 107 | // modify... 108 | 109 | // encode back 110 | encoder := thrifter.NewEncoder(writer) 111 | encoder.Encode(msgHeader) 112 | encoder.Encode(msgArgs) 113 | ``` 114 | 115 | The definition of `raw.Struct` is: 116 | 117 | ```go 118 | type StructField struct { 119 | Buffer []byte 120 | Type protocol.TType 121 | } 122 | 123 | type Struct map[protocol.FieldId]StructField 124 | ``` 125 | 126 | # Performance 127 | 128 | thrifter does not compromise performance. 129 | 130 | gogoprotobuf 131 | 132 | ``` 133 | 5000000 366 ns/op 144 B/op 12 allocs/op 134 | ``` 135 | 136 | thrift 137 | 138 | ``` 139 | 1000000 1549 ns/op 528 B/op 9 allocs/op 140 | ``` 141 | 142 | thrifter by static codegen 143 | 144 | ``` 145 | 5000000 389 ns/op 192 B/op 6 allocs/op 146 | ``` 147 | 148 | thrifter by reflection 149 | 150 | ``` 151 | 2000000 585 ns/op 192 B/op 6 allocs/op 152 | ``` 153 | 154 | As you can see, the reflection implementation is not bad: it is much faster than the 155 | static code generated by the original thrift implementation. 156 | 157 | For the best performance, you can choose to use static code generation. The API 158 | is unchanged; you just need to add an extra static codegen step to your build and include 159 | the generated code in your package. The runtime will automatically use the 160 | generated encoder/decoder instead of reflection. 
161 | 162 | For an example of static codegen, check out [https://github.com/thrift-iterator/go/blob/master/test/api/init.go](https://github.com/thrift-iterator/go/blob/master/test/api/init.go) 163 | 164 | # Sync IDL and Go Struct 165 | 166 | Keeping IDL and your object model in sync is challenging. We do not always like the code 167 | generated from thrift IDL. But manually keeping the IDL and model in sync is 168 | tedious and error-prone. 169 | 170 | A separate toolchain to manipulate thrift IDL files and keep them bidirectionally in sync 171 | will be provided in another project. 172 | 173 | -------------------------------------------------------------------------------- /api.go: -------------------------------------------------------------------------------- 1 | package thrifter 2 | 3 | import ( 4 | "io" 5 | "github.com/thrift-iterator/go/spi" 6 | "github.com/thrift-iterator/go/general" 7 | ) 8 | 9 | type Protocol int 10 | 11 | var ProtocolBinary Protocol = 1 12 | var ProtocolCompact Protocol = 2 13 | 14 | type Config struct { 15 | Protocol Protocol 16 | StaticCodegen bool 17 | Extensions spi.Extensions 18 | } 19 | 20 | type API interface { 21 | // NewStream is low level streaming api 22 | NewStream(writer io.Writer, buf []byte) spi.Stream 23 | // NewIterator is low level streaming api 24 | NewIterator(reader io.Reader, buf []byte) spi.Iterator 25 | // Unmarshal from []byte 26 | Unmarshal(buf []byte, obj interface{}) error 27 | // UnmarshalMessage from []byte 28 | UnmarshalMessage(buf []byte) (general.Message, error) 29 | // Marshal to []byte 30 | Marshal(obj interface{}) ([]byte, error) 31 | // ToJSON convert thrift message to JSON string 32 | ToJSON(buf []byte) (string, error) 33 | // MarshalMessage to []byte 34 | MarshalMessage(msg general.Message) ([]byte, error) 35 | // NewDecoder to unmarshal from []byte or io.Reader 36 | NewDecoder(reader io.Reader, buf []byte) *Decoder 37 | // NewEncoder to marshal to io.Writer 38 | NewEncoder(writer io.Writer) *Encoder 39 | // 
WillDecodeFromBuffer should only be used in generic.Declare 40 | WillDecodeFromBuffer(sample ...interface{}) 41 | // WillDecodeFromReader should only be used in generic.Declare 42 | WillDecodeFromReader(sample ...interface{}) 43 | // WillEncode should only be used in generic.Declare 44 | WillEncode(sample ...interface{}) 45 | } 46 | 47 | var DefaultConfig = Config{Protocol: ProtocolBinary, StaticCodegen: false}.Froze() 48 | 49 | func NewStream(writer io.Writer, buf []byte) spi.Stream { 50 | return DefaultConfig.NewStream(writer, buf) 51 | } 52 | 53 | func NewIterator(reader io.Reader, buf []byte) spi.Iterator { 54 | return DefaultConfig.NewIterator(reader, buf) 55 | } 56 | 57 | func Unmarshal(buf []byte, obj interface{}) error { 58 | return DefaultConfig.Unmarshal(buf, obj) 59 | } 60 | 61 | // UnmarshalMessage demonstrate how to decode thrift binary without IDL into a general message struct 62 | func UnmarshalMessage(buf []byte) (general.Message, error) { 63 | return DefaultConfig.UnmarshalMessage(buf) 64 | } 65 | 66 | // ToJSON convert the thrift message to JSON string 67 | func ToJSON(buf []byte) (string, error) { 68 | return DefaultConfig.ToJSON(buf) 69 | } 70 | 71 | func Marshal(obj interface{}) ([]byte, error) { 72 | return DefaultConfig.Marshal(obj) 73 | } 74 | 75 | // MarshalMessage is just a shortcut to demonstrate message decoded by UnmarshalMessage can be encoded back 76 | func MarshalMessage(msg general.Message) ([]byte, error) { 77 | return DefaultConfig.MarshalMessage(msg) 78 | } 79 | 80 | func NewDecoder(reader io.Reader, buf []byte) *Decoder { 81 | return DefaultConfig.NewDecoder(reader, buf) 82 | } 83 | 84 | func NewEncoder(writer io.Writer) *Encoder { 85 | return DefaultConfig.NewEncoder(writer) 86 | } 87 | -------------------------------------------------------------------------------- /binding/codegen/binding.go: -------------------------------------------------------------------------------- 1 | package codegen 2 | 3 | import ( 4 | "reflect" 5 | 
"github.com/thrift-iterator/go/protocol" 6 | "strings" 7 | "strconv" 8 | ) 9 | 10 | var byteArrayType = reflect.TypeOf(([]byte)(nil)) 11 | 12 | var simpleValueMap = map[reflect.Kind]string{ 13 | reflect.Int: "Int", 14 | reflect.Int8: "Int8", 15 | reflect.Int16: "Int16", 16 | reflect.Int32: "Int32", 17 | reflect.Int64: "Int64", 18 | reflect.Uint: "Uint", 19 | reflect.Uint8: "Uint8", 20 | reflect.Uint16: "Uint16", 21 | reflect.Uint32: "Uint32", 22 | reflect.Uint64: "Uint64", 23 | reflect.Float32: "Float32", 24 | reflect.Float64: "Float64", 25 | reflect.String: "String", 26 | reflect.Bool: "Bool", 27 | } 28 | 29 | var thriftTypeMap = map[reflect.Kind]protocol.TType{ 30 | reflect.Int: protocol.TypeI64, 31 | reflect.Int8: protocol.TypeI08, 32 | reflect.Int16: protocol.TypeI16, 33 | reflect.Int32: protocol.TypeI32, 34 | reflect.Int64: protocol.TypeI64, 35 | reflect.Uint: protocol.TypeI64, 36 | reflect.Uint8: protocol.TypeI08, 37 | reflect.Uint16: protocol.TypeI16, 38 | reflect.Uint32: protocol.TypeI32, 39 | reflect.Uint64: protocol.TypeI64, 40 | reflect.Float32: protocol.TypeDouble, 41 | reflect.Float64: protocol.TypeDouble, 42 | reflect.String: protocol.TypeString, 43 | reflect.Bool: protocol.TypeBool, 44 | } 45 | 46 | func isEnumType(valType reflect.Type) bool { 47 | if valType.Kind() != reflect.Int64 { 48 | return false 49 | } 50 | _, hasStringMethod := valType.MethodByName("String") 51 | return hasStringMethod 52 | } 53 | 54 | func calcBindings(valType reflect.Type) interface{} { 55 | bindings := []interface{}{} 56 | for i := 0; i < valType.NumField(); i++ { 57 | field := valType.Field(i) 58 | fieldId := protocol.FieldId(0) 59 | thriftTag := field.Tag.Get("thrift") 60 | if thriftTag != "" { 61 | parts := strings.Split(thriftTag, ",") 62 | if len(parts) >= 2 { 63 | n, err := strconv.Atoi(parts[1]) 64 | if err != nil { 65 | panic("thrift tag must be integer") 66 | } 67 | fieldId = protocol.FieldId(n) 68 | } 69 | } 70 | if fieldId == 0 { 71 | continue 72 | } 73 | 
bindings = append(bindings, map[string]interface{}{ 74 | "fieldId": fieldId, 75 | "fieldName": field.Name, 76 | "fieldType": reflect.PtrTo(field.Type), 77 | }) 78 | } 79 | return bindings 80 | } -------------------------------------------------------------------------------- /binding/codegen/codegen.go: -------------------------------------------------------------------------------- 1 | package codegen 2 | 3 | import ( 4 | "github.com/thrift-iterator/go/spi" 5 | "reflect" 6 | ) 7 | 8 | type Extension struct { 9 | spi.Extension 10 | ExtTypes []reflect.Type 11 | } 12 | 13 | func (ext *Extension) MangledName() string { 14 | // TODO: hash extension to represent different config 15 | return "default" 16 | } -------------------------------------------------------------------------------- /binding/codegen/decode.go: -------------------------------------------------------------------------------- 1 | package codegen 2 | 3 | import ( 4 | "github.com/v2pro/wombat/generic" 5 | ) 6 | 7 | var Decode = generic.DefineFunc("Decode(dst interface{}, src interface{})"). 8 | Param("EXT", "user provided extension"). 9 | Param("DT", "the dst type to copy into"). 10 | Param("ST", "the src type to copy from"). 11 | ImportPackage("reflect"). 12 | Declare("var typeOf = reflect.TypeOf"). 13 | ImportFunc(decodeAnything). 
14 | Source(` 15 | {{ $decode := expand "DecodeAnything" "EXT" .EXT "DT" .DT "ST" .ST }} 16 | iter := src.({{.ST|name}}) 17 | {{ range $extType := .EXT.ExtTypes }} 18 | if iter.GetDecoder("{{$extType|name}}") == nil { 19 | iter.PrepareDecoder(reflect.TypeOf((*{{$extType|name}})(nil)).Elem()) 20 | } 21 | {{ end }} 22 | {{$decode}}(dst.({{.DT|name}}), iter) 23 | `) 24 | -------------------------------------------------------------------------------- /binding/codegen/decode_anything.go: -------------------------------------------------------------------------------- 1 | package codegen 2 | 3 | import ( 4 | "github.com/v2pro/wombat/generic" 5 | "reflect" 6 | ) 7 | 8 | func dispatchDecode(extension *Extension, dstType reflect.Type) string { 9 | if extension.DecoderOf(dstType) != nil { 10 | extension.ExtTypes = append(extension.ExtTypes, dstType) 11 | return "DecodeByExtension" 12 | } 13 | if dstType.Kind() != reflect.Ptr { 14 | panic("can only decode into pointer") 15 | } 16 | dstType = dstType.Elem() 17 | if dstType == byteArrayType { 18 | return "DecodeBinary" 19 | } 20 | if isEnumType(dstType) { 21 | return "DecodeEnum" 22 | } 23 | switch dstType.Kind() { 24 | case reflect.Slice: 25 | return "DecodeSlice" 26 | case reflect.Map: 27 | return "DecodeMap" 28 | case reflect.Struct: 29 | return "DecodeStruct" 30 | case reflect.Ptr: 31 | return "DecodePointer" 32 | } 33 | if _, isSimpleValue := simpleValueMap[dstType.Kind()]; isSimpleValue { 34 | return "DecodeSimpleValue" 35 | } 36 | panic("unsupported type") 37 | } 38 | 39 | var decodeAnything = generic.DefineFunc("DecodeAnything(dst DT, src ST)"). 40 | Param("EXT", "user provided extension"). 41 | Param("DT", "the dst type to copy into"). 42 | Param("ST", "the src type to copy from"). 43 | Generators("dispatchDecode", dispatchDecode). 
44 | Source(` 45 | {{ $tmpl := dispatchDecode .EXT .DT }} 46 | {{ if eq $tmpl "DecodeByExtension" }} 47 | src.GetDecoder("{{ .DT|name }}").Decode(dst, src) 48 | {{ else }} 49 | {{ $decode := expand $tmpl "EXT" .EXT "DT" .DT "ST" .ST }} 50 | {{$decode}}(dst, src) 51 | {{ end }} 52 | `) 53 | -------------------------------------------------------------------------------- /binding/codegen/decode_binary.go: -------------------------------------------------------------------------------- 1 | package codegen 2 | 3 | import ( 4 | "github.com/v2pro/wombat/generic" 5 | ) 6 | 7 | func init() { 8 | decodeAnything.ImportFunc(decodingBinary) 9 | } 10 | 11 | var decodingBinary = generic.DefineFunc( 12 | "DecodeBinary(dst DT, src ST)"). 13 | Param("EXT", "user provided extension"). 14 | Param("DT", "the dst type to copy into"). 15 | Param("ST", "the src type to copy from"). 16 | Source(` 17 | *dst = src.ReadBinary() 18 | `) -------------------------------------------------------------------------------- /binding/codegen/decode_enum.go: -------------------------------------------------------------------------------- 1 | package codegen 2 | 3 | import ( 4 | "github.com/v2pro/wombat/generic" 5 | ) 6 | 7 | func init() { 8 | decodeAnything.ImportFunc(decodeEnum) 9 | } 10 | 11 | var decodeEnum = generic.DefineFunc( 12 | "DecodeEnum(dst DT, src ST)"). 13 | Param("EXT", "user provided extension"). 14 | Param("DT", "the dst type to copy into"). 15 | Param("ST", "the src type to copy from"). 
16 | Source(` 17 | *dst = {{.DT|elem|name}}(src.ReadInt32()) 18 | `) 19 | -------------------------------------------------------------------------------- /binding/codegen/decode_map.go: -------------------------------------------------------------------------------- 1 | package codegen 2 | 3 | import ( 4 | "reflect" 5 | "github.com/v2pro/wombat/generic" 6 | ) 7 | 8 | func init() { 9 | decodeAnything.ImportFunc(decodeMap) 10 | } 11 | 12 | var decodeMap = generic.DefineFunc( 13 | "DecodeMap(dst DT, src ST)"). 14 | Param("EXT", "user provided extension"). 15 | Param("DT", "the dst type to copy into"). 16 | Param("ST", "the src type to copy from"). 17 | ImportFunc(decodeAnything). 18 | Generators( 19 | "ptrMapElem", func(typ reflect.Type) reflect.Type { 20 | return reflect.PtrTo(typ.Elem().Elem()) 21 | }, "ptrMapKey", func(typ reflect.Type) reflect.Type { 22 | return reflect.PtrTo(typ.Elem().Key()) 23 | }). 24 | Source(` 25 | {{ $decodeKey := expand "DecodeAnything" "EXT" .EXT "DT" (.DT|ptrMapKey) "ST" .ST }} 26 | {{ $decodeElem := expand "DecodeAnything" "EXT" .EXT "DT" (.DT|ptrMapElem) "ST" .ST }} 27 | if *dst == nil { 28 | *dst = {{.DT|elem|name}}{} 29 | } 30 | _, _, length := src.ReadMapHeader() 31 | for i := 0; i < length; i++ { 32 | newKey := new({{.DT|elem|key|name}}) 33 | {{$decodeKey}}(newKey, src) 34 | newElem := new({{.DT|elem|elem|name}}) 35 | {{$decodeElem}}(newElem, src) 36 | (*dst)[*newKey] = *newElem 37 | }`) -------------------------------------------------------------------------------- /binding/codegen/decode_pointer.go: -------------------------------------------------------------------------------- 1 | package codegen 2 | 3 | import ( 4 | "github.com/v2pro/wombat/generic" 5 | ) 6 | 7 | func init() { 8 | decodeAnything.ImportFunc(decodePointer) 9 | } 10 | 11 | var decodePointer = generic.DefineFunc( 12 | "DecodePointer(dst DT, src ST)"). 13 | Param("EXT", "user provided extension"). 14 | Param("DT", "the dst type to copy into"). 
15 | Param("ST", "the src type to copy from"). 16 | ImportFunc(decodeAnything). 17 | Source(` 18 | {{ $decode := expand "DecodeAnything" "EXT" .EXT "DT" (.DT|elem) "ST" .ST }} 19 | defDst := new({{ .DT|elem|elem|name }}) 20 | {{$decode}}(defDst, src) 21 | *dst = defDst 22 | `) -------------------------------------------------------------------------------- /binding/codegen/decode_simple_value.go: -------------------------------------------------------------------------------- 1 | package codegen 2 | 3 | import ( 4 | "github.com/v2pro/wombat/generic" 5 | "reflect" 6 | ) 7 | 8 | func init() { 9 | decodeAnything.ImportFunc(decodeSimpleValue) 10 | } 11 | 12 | var decodeSimpleValue = generic.DefineFunc( 13 | "DecodeSimpleValue(dst DT, src ST)"). 14 | Param("EXT", "user provided extension"). 15 | Param("DT", "the dst type to copy into"). 16 | Param("ST", "the src type to copy from"). 17 | Generators( 18 | "opFuncName", func(typ reflect.Type) string { 19 | funName := simpleValueMap[typ.Kind()] 20 | if funName == "" { 21 | panic(typ.String() + " is not simple value") 22 | } 23 | return funName 24 | }). 25 | Source(` 26 | *dst = {{.DT|elem|name}}(src.Read{{.DT|elem|opFuncName}}()) 27 | `) -------------------------------------------------------------------------------- /binding/codegen/decode_slice.go: -------------------------------------------------------------------------------- 1 | package codegen 2 | 3 | import ( 4 | "reflect" 5 | "github.com/v2pro/wombat/generic" 6 | ) 7 | 8 | func init() { 9 | decodeAnything.ImportFunc(decodeSlice) 10 | } 11 | 12 | var decodeSlice = generic.DefineFunc( 13 | "DecodeSlice(dst DT, src ST)"). 14 | Param("EXT", "user provided extension"). 15 | Param("DT", "the dst type to copy into"). 16 | Param("ST", "the src type to copy from"). 17 | ImportFunc(decodeAnything). 18 | Generators( 19 | "ptrSliceElem", func(typ reflect.Type) reflect.Type { 20 | return reflect.PtrTo(typ.Elem().Elem()) 21 | }). 
22 | Source(` 23 | {{ $decodeElem := expand "DecodeAnything" "EXT" .EXT "DT" (.DT|ptrSliceElem) "ST" .ST }} 24 | _, length := src.ReadListHeader() 25 | for i := 0; i < length; i++ { 26 | elem := new({{.DT|elem|elem|name}}) 27 | {{$decodeElem}}(elem, src) 28 | *dst = append(*dst, *elem) 29 | }`) -------------------------------------------------------------------------------- /binding/codegen/decode_struct.go: -------------------------------------------------------------------------------- 1 | package codegen 2 | 3 | import ( 4 | "github.com/v2pro/wombat/generic" 5 | ) 6 | 7 | func init() { 8 | decodeAnything.ImportFunc(decodeStruct) 9 | } 10 | 11 | var decodeStruct = generic.DefineFunc( 12 | "DecodeStruct(dst DT, src ST)"). 13 | Param("EXT", "user provided extension"). 14 | Param("DT", "the dst type to copy into"). 15 | Param("ST", "the src type to copy from"). 16 | ImportFunc(decodeAnything). 17 | Generators( 18 | "calcBindings", calcBindings, 19 | "assignDecode", func(binding map[string]interface{}, decodeFuncName string) string { 20 | binding["decode"] = decodeFuncName 21 | return "" 22 | }). 
23 | Source(` 24 | {{ $bindings := calcBindings (.DT|elem) }} 25 | {{ range $_, $binding := $bindings}} 26 | {{ $decode := expand "DecodeAnything" "EXT" $.EXT "DT" $binding.fieldType "ST" $.ST }} 27 | {{ assignDecode $binding $decode }} 28 | {{ end }} 29 | src.ReadStructHeader() 30 | for { 31 | fieldType, fieldId := src.ReadStructField() 32 | if fieldType == 0 { 33 | return 34 | } 35 | switch fieldId { 36 | {{ range $_, $binding := $bindings }} 37 | case {{ $binding.fieldId }}: 38 | {{$binding.decode}}(&dst.{{$binding.fieldName}}, src) 39 | {{ end }} 40 | default: 41 | src.Discard(fieldType) 42 | } 43 | }`) -------------------------------------------------------------------------------- /binding/codegen/encode.go: -------------------------------------------------------------------------------- 1 | package codegen 2 | 3 | import "github.com/v2pro/wombat/generic" 4 | 5 | var Encode = generic.DefineFunc("Encode(dst interface{}, src interface{})"). 6 | Param("EXT", "user provided extension"). 7 | Param("DT", "the dst type to copy into"). 8 | Param("ST", "the src type to copy from"). 9 | ImportFunc(encodeAnything). 10 | ImportPackage("reflect"). 11 | Declare("var typeOf = reflect.TypeOf"). 
12 | Source(` 13 | {{ $decode := expand "EncodeAnything" "EXT" .EXT "DT" .DT "ST" .ST }} 14 | stream := dst.({{.DT|name}}) 15 | {{ range $extType := .EXT.ExtTypes }} 16 | if stream.GetEncoder("{{$extType|name}}") == nil { 17 | stream.PrepareEncoder(reflect.TypeOf((*{{$extType|name}})(nil)).Elem()) 18 | } 19 | {{ end }} 20 | {{$decode}}(stream, src.({{.ST|name}})) 21 | `) 22 | -------------------------------------------------------------------------------- /binding/codegen/encode_anything.go: -------------------------------------------------------------------------------- 1 | package codegen 2 | 3 | import ( 4 | "reflect" 5 | "github.com/v2pro/wombat/generic" 6 | "github.com/thrift-iterator/go/protocol" 7 | ) 8 | 9 | func dispatchEncode(extension *Extension, srcType reflect.Type) (string, protocol.TType) { 10 | extEncoder := extension.EncoderOf(srcType) 11 | if extEncoder != nil { 12 | extension.ExtTypes = append(extension.ExtTypes, srcType) 13 | return "EncodeByExtension", extEncoder.ThriftType() 14 | } 15 | if srcType == byteArrayType { 16 | return "EncodeBinary", protocol.TypeString 17 | } 18 | if isEnumType(srcType) { 19 | return "EncodeEnum", protocol.TypeI32 20 | } 21 | switch srcType.Kind() { 22 | case reflect.Slice: 23 | return "EncodeSlice", protocol.TypeList 24 | case reflect.Map: 25 | return "EncodeMap", protocol.TypeMap 26 | case reflect.Struct: 27 | return "EncodeStruct", protocol.TypeStruct 28 | case reflect.Ptr: 29 | _, ttype := dispatchEncode(extension, srcType.Elem()) 30 | return "EncodePointer", ttype 31 | } 32 | return "EncodeSimpleValue", thriftTypeMap[srcType.Kind()] 33 | } 34 | 35 | func dispatchThriftType(extension *Extension, srcType reflect.Type) int { 36 | _, ttype := dispatchEncode(extension, srcType) 37 | return int(ttype) 38 | } 39 | 40 | var encodeAnything = generic.DefineFunc("EncodeAnything(dst DT, src ST)"). 41 | Param("EXT", "user provided extension"). 42 | Param("DT", "the dst type to copy into"). 
43 | Param("ST", "the src type to copy from"). 44 | Generators( 45 | "dispatchEncode", func(extension *Extension, srcType reflect.Type) string { 46 | encode, _ := dispatchEncode(extension, srcType) 47 | return encode 48 | }). 49 | Source(` 50 | {{ $tmpl := dispatchEncode .EXT .ST }} 51 | {{ if eq $tmpl "EncodeByExtension" }} 52 | dst.GetEncoder("{{ .ST|name }}").Encode(src, dst) 53 | {{ else }} 54 | {{ $encode := expand $tmpl "EXT" .EXT "DT" .DT "ST" .ST }} 55 | {{$encode}}(dst, src) 56 | {{ end }} 57 | `) 58 | -------------------------------------------------------------------------------- /binding/codegen/encode_binary.go: -------------------------------------------------------------------------------- 1 | package codegen 2 | 3 | import ( 4 | "github.com/v2pro/wombat/generic" 5 | ) 6 | 7 | func init() { 8 | encodeAnything.ImportFunc(encodeBinary) 9 | } 10 | 11 | var encodeBinary = generic.DefineFunc( 12 | "EncodeBinary(dst DT, src ST)"). 13 | Param("EXT", "user provided extension"). 14 | Param("DT", "the dst type to copy into"). 15 | Param("ST", "the src type to copy from"). 16 | Source(` 17 | dst.WriteBinary(src) 18 | `) -------------------------------------------------------------------------------- /binding/codegen/encode_enum.go: -------------------------------------------------------------------------------- 1 | package codegen 2 | 3 | import ( 4 | "github.com/v2pro/wombat/generic" 5 | ) 6 | 7 | func init() { 8 | encodeAnything.ImportFunc(encodeEnum) 9 | } 10 | 11 | var encodeEnum = generic.DefineFunc( 12 | "EncodeEnum(dst DT, src ST)"). 13 | Param("EXT", "user provided extension"). 14 | Param("DT", "the dst type to copy into"). 15 | Param("ST", "the src type to copy from"). 
16 | Source(` 17 | dst.WriteInt32(int32(src)) 18 | `) 19 | -------------------------------------------------------------------------------- /binding/codegen/encode_map.go: -------------------------------------------------------------------------------- 1 | package codegen 2 | 3 | import ( 4 | "github.com/v2pro/wombat/generic" 5 | ) 6 | 7 | func init() { 8 | encodeAnything.ImportFunc(encodeMap) 9 | } 10 | 11 | var encodeMap = generic.DefineFunc( 12 | "EncodeMap(dst DT, src ST)"). 13 | Param("EXT", "user provided extension"). 14 | Param("DT", "the dst type to copy into"). 15 | Param("ST", "the src type to copy from"). 16 | ImportFunc(encodeAnything). 17 | Generators( 18 | "thriftType", dispatchThriftType). 19 | Source(` 20 | {{ $encodeKey := expand "EncodeAnything" "EXT" .EXT "DT" .DT "ST" (.ST|key) }} 21 | {{ $encodeElem := expand "EncodeAnything" "EXT" .EXT "DT" .DT "ST" (.ST|elem) }} 22 | dst.WriteMapHeader({{.ST|key|thriftType .EXT}}, {{.ST|elem|thriftType .EXT}}, len(src)) 23 | for key, elem := range src { 24 | {{$encodeKey}}(dst, key) 25 | {{$encodeElem}}(dst, elem) 26 | }`) -------------------------------------------------------------------------------- /binding/codegen/encode_pointer.go: -------------------------------------------------------------------------------- 1 | package codegen 2 | 3 | import ( 4 | "github.com/v2pro/wombat/generic" 5 | ) 6 | 7 | func init() { 8 | encodeAnything.ImportFunc(encodePointer) 9 | } 10 | 11 | var encodePointer = generic.DefineFunc( 12 | "EncodePointer(dst DT, src ST)"). 13 | Param("EXT", "user provided extension"). 14 | Param("DT", "the dst type to copy into"). 15 | Param("ST", "the src type to copy from"). 16 | ImportFunc(encodeAnything). 
17 | Source(` 18 | {{ $encode := expand "EncodeAnything" "EXT" .EXT "DT" .DT "ST" (.ST|elem) }} 19 | {{$encode}}(dst, *src) 20 | `) -------------------------------------------------------------------------------- /binding/codegen/encode_simple_value.go: -------------------------------------------------------------------------------- 1 | package codegen 2 | 3 | import ( 4 | "github.com/v2pro/wombat/generic" 5 | "reflect" 6 | ) 7 | 8 | func init() { 9 | encodeAnything.ImportFunc(encodeSimpleValue) 10 | } 11 | 12 | var encodeSimpleValue = generic.DefineFunc( 13 | "EncodeSimpleValue(dst DT, src ST)"). 14 | Param("EXT", "user provided extension"). 15 | Param("DT", "the dst type to copy into"). 16 | Param("ST", "the src type to copy from"). 17 | Generators( 18 | "opFuncName", func(typ reflect.Type) string { 19 | funName := simpleValueMap[typ.Kind()] 20 | if funName == "" { 21 | panic(typ.String() + " is not simple value") 22 | } 23 | return funName 24 | }). 25 | Source(` 26 | dst.Write{{.ST|opFuncName}}(src) 27 | `) -------------------------------------------------------------------------------- /binding/codegen/encode_slice.go: -------------------------------------------------------------------------------- 1 | package codegen 2 | 3 | import ( 4 | "github.com/v2pro/wombat/generic" 5 | ) 6 | 7 | func init() { 8 | encodeAnything.ImportFunc(encodeSlice) 9 | } 10 | 11 | var encodeSlice = generic.DefineFunc( 12 | "EncodeSlice(dst DT, src ST)"). 13 | Param("EXT", "user provided extension"). 14 | Param("DT", "the dst type to copy into"). 15 | Param("ST", "the src type to copy from"). 16 | ImportFunc(encodeAnything). 17 | Generators( 18 | "thriftType", dispatchThriftType). 
19 | Source(` 20 | {{ $encodeElem := expand "EncodeAnything" "EXT" .EXT "DT" .DT "ST" (.ST|elem) }} 21 | dst.WriteListHeader({{.ST|elem|thriftType .EXT }}, len(src)) 22 | for _, elem := range src { 23 | {{$encodeElem}}(dst, elem) 24 | } 25 | `) 26 | -------------------------------------------------------------------------------- /binding/codegen/encode_struct.go: -------------------------------------------------------------------------------- 1 | package codegen 2 | 3 | import ( 4 | "github.com/v2pro/wombat/generic" 5 | ) 6 | 7 | func init() { 8 | encodeAnything.ImportFunc(encodeStruct) 9 | } 10 | 11 | var encodeStruct = generic.DefineFunc( 12 | "EncodeStruct(dst DT, src ST)"). 13 | Param("EXT", "user provided extension"). 14 | Param("DT", "the dst type to copy into"). 15 | Param("ST", "the src type to copy from"). 16 | ImportFunc(encodeAnything). 17 | Generators( 18 | "calcBindings", calcBindings, 19 | "assignEncode", func(binding map[string]interface{}, encodeFuncName string) string { 20 | binding["encode"] = encodeFuncName 21 | return "" 22 | }, 23 | "thriftType", dispatchThriftType). 
24 | Source(` 25 | {{ $bindings := calcBindings .ST }} 26 | dst.WriteStructHeader() 27 | {{ range $_, $binding := $bindings}} 28 | {{ $encode := expand "EncodeAnything" "EXT" $.EXT "DT" $.DT "ST" $binding.fieldType }} 29 | dst.WriteStructField({{$binding.fieldType|thriftType .EXT}}, {{$binding.fieldId}}) 30 | {{$encode}}(dst, &src.{{$binding.fieldName}}) 31 | {{ end }} 32 | dst.WriteStructFieldStop() 33 | `) 34 | -------------------------------------------------------------------------------- /binding/reflection/decode.go: -------------------------------------------------------------------------------- 1 | package reflection 2 | 3 | import ( 4 | "reflect" 5 | "github.com/thrift-iterator/go/spi" 6 | "unsafe" 7 | "github.com/thrift-iterator/go/protocol" 8 | "strings" 9 | "unicode" 10 | "strconv" 11 | ) 12 | 13 | var byteSliceType = reflect.TypeOf(([]byte)(nil)) 14 | 15 | func DecoderOf(extension spi.Extension, valType reflect.Type) spi.ValDecoder { 16 | if valType.Kind() != reflect.Ptr { 17 | return &valDecoderAdapter{&unknownDecoder{ 18 | prefix: "unmarshal into non-pointer type", valType: valType}} 19 | } 20 | return &valDecoderAdapter{decoderOf(extension, "", valType.Elem())} 21 | } 22 | 23 | func decoderOf(extension spi.Extension, prefix string, valType reflect.Type) internalDecoder { 24 | extDecoder := extension.DecoderOf(reflect.PtrTo(valType)) 25 | if extDecoder != nil { 26 | valObj := reflect.New(valType).Interface() 27 | valEmptyInterface := *(*emptyInterface)(unsafe.Pointer(&valObj)) 28 | return &internalDecoderAdapter{valEmptyInterface: valEmptyInterface, decoder: extDecoder} 29 | } 30 | if byteSliceType == valType { 31 | return &binaryDecoder{} 32 | } 33 | if isEnumType(valType) { 34 | return &int32Decoder{} 35 | } 36 | switch valType.Kind() { 37 | case reflect.Bool: 38 | return &boolDecoder{} 39 | case reflect.Float64: 40 | return &float64Decoder{} 41 | case reflect.Int: 42 | return &intDecoder{} 43 | case reflect.Uint: 44 | return &uintDecoder{} 45 | 
case reflect.Int8: 46 | return &int8Decoder{} 47 | case reflect.Uint8: 48 | return &uint8Decoder{} 49 | case reflect.Int16: 50 | return &int16Decoder{} 51 | case reflect.Uint16: 52 | return &uint16Decoder{} 53 | case reflect.Int32: 54 | return &int32Decoder{} 55 | case reflect.Uint32: 56 | return &uint32Decoder{} 57 | case reflect.Int64: 58 | return &int64Decoder{} 59 | case reflect.Uint64: 60 | return &uint64Decoder{} 61 | case reflect.String: 62 | return &stringDecoder{} 63 | case reflect.Ptr: 64 | return &pointerDecoder{ 65 | valType: valType.Elem(), 66 | valDecoder: decoderOf(extension, prefix+" [ptrElem]", valType.Elem()), 67 | } 68 | case reflect.Slice: 69 | return &sliceDecoder{ 70 | elemType: valType.Elem(), 71 | sliceType: valType, 72 | elemDecoder: decoderOf(extension, prefix+" [sliceElem]", valType.Elem()), 73 | } 74 | case reflect.Map: 75 | sampleObj := reflect.New(valType).Interface() 76 | return &mapDecoder{ 77 | keyType: valType.Key(), 78 | keyDecoder: decoderOf(extension, prefix+" [mapKey]", valType.Key()), 79 | elemType: valType.Elem(), 80 | elemDecoder: decoderOf(extension, prefix+" [mapElem]", valType.Elem()), 81 | mapType: valType, 82 | mapInterface: *(*emptyInterface)(unsafe.Pointer(&sampleObj)), 83 | } 84 | case reflect.Struct: 85 | decoderFields := make([]structDecoderField, 0, valType.NumField()) 86 | decoderFieldMap := map[protocol.FieldId]structDecoderField{} 87 | for i := 0; i < valType.NumField(); i++ { 88 | refField := valType.Field(i) 89 | fieldId := parseFieldId(refField) 90 | if fieldId == -1 { 91 | continue 92 | } 93 | decoderField := structDecoderField{ 94 | offset: refField.Offset, 95 | fieldId: fieldId, 96 | decoder: decoderOf(extension, prefix + " " + refField.Name, refField.Type), 97 | } 98 | decoderFields = append(decoderFields, decoderField) 99 | decoderFieldMap[fieldId] = decoderField 100 | } 101 | return &structDecoder{ 102 | fields: decoderFields, 103 | fieldMap: decoderFieldMap, 104 | } 105 | } 106 | return 
&unknownDecoder{prefix, valType} 107 | } 108 | 109 | func isEnumType(valType reflect.Type) bool { 110 | if valType.Kind() != reflect.Int64 { 111 | return false 112 | } 113 | _, hasStringMethod := valType.MethodByName("String") 114 | return hasStringMethod 115 | } 116 | 117 | func parseFieldId(refField reflect.StructField) protocol.FieldId { 118 | if !unicode.IsUpper(rune(refField.Name[0])) { 119 | return -1 120 | } 121 | thriftTag := refField.Tag.Get("thrift") 122 | if thriftTag == "" { 123 | return -1 124 | } 125 | parts := strings.Split(thriftTag, ",") 126 | if len(parts) < 2 { 127 | return -1 128 | } 129 | fieldId, err := strconv.Atoi(parts[1]) 130 | if err != nil { 131 | return -1 132 | } 133 | return protocol.FieldId(fieldId) 134 | } 135 | 136 | type unknownDecoder struct { 137 | prefix string 138 | valType reflect.Type 139 | } 140 | 141 | func (decoder *unknownDecoder) decode(ptr unsafe.Pointer, iterator spi.Iterator) { 142 | iterator.ReportError("decode "+decoder.prefix, "do not know how to decode "+decoder.valType.String()) 143 | } 144 | -------------------------------------------------------------------------------- /binding/reflection/decode_map.go: -------------------------------------------------------------------------------- 1 | package reflection 2 | 3 | import ( 4 | "reflect" 5 | "unsafe" 6 | "github.com/thrift-iterator/go/spi" 7 | ) 8 | 9 | type mapDecoder struct { 10 | mapType reflect.Type 11 | mapInterface emptyInterface 12 | keyType reflect.Type 13 | keyDecoder internalDecoder 14 | elemType reflect.Type 15 | elemDecoder internalDecoder 16 | } 17 | 18 | func (decoder *mapDecoder) decode(ptr unsafe.Pointer, iter spi.Iterator) { 19 | mapInterface := decoder.mapInterface 20 | mapInterface.word = ptr 21 | realInterface := (*interface{})(unsafe.Pointer(&mapInterface)) 22 | mapVal := reflect.ValueOf(*realInterface).Elem() 23 | if mapVal.IsNil() { 24 | mapVal.Set(reflect.MakeMap(decoder.mapType)) 25 | } 26 | _, _, length := iter.ReadMapHeader() 27 | for 
i := 0; i < length; i++ { 28 | keyVal := reflect.New(decoder.keyType) 29 | decoder.keyDecoder.decode(unsafe.Pointer(keyVal.Pointer()), iter) 30 | elemVal := reflect.New(decoder.elemType) 31 | decoder.elemDecoder.decode(unsafe.Pointer(elemVal.Pointer()), iter) 32 | mapVal.SetMapIndex(keyVal.Elem(), elemVal.Elem()) 33 | } 34 | } 35 | -------------------------------------------------------------------------------- /binding/reflection/decode_pointer.go: -------------------------------------------------------------------------------- 1 | package reflection 2 | 3 | import ( 4 | "unsafe" 5 | "github.com/thrift-iterator/go/spi" 6 | "reflect" 7 | ) 8 | 9 | type pointerDecoder struct { 10 | valType reflect.Type 11 | valDecoder internalDecoder 12 | } 13 | 14 | func (decoder *pointerDecoder) decode(ptr unsafe.Pointer, iter spi.Iterator) { 15 | value := reflect.New(decoder.valType).Interface() 16 | newPtr := (*emptyInterface)(unsafe.Pointer(&value)).word 17 | decoder.valDecoder.decode(newPtr, iter) 18 | *(*unsafe.Pointer)(ptr) = newPtr 19 | } 20 | -------------------------------------------------------------------------------- /binding/reflection/decode_simple_value.go: -------------------------------------------------------------------------------- 1 | package reflection 2 | 3 | import ( 4 | "github.com/thrift-iterator/go/spi" 5 | "unsafe" 6 | ) 7 | 8 | type binaryDecoder struct { 9 | } 10 | 11 | func (decoder *binaryDecoder) decode(ptr unsafe.Pointer, iter spi.Iterator) { 12 | *(*[]byte)(ptr) = iter.ReadBinary() 13 | } 14 | 15 | type boolDecoder struct { 16 | } 17 | 18 | func (decoder *boolDecoder) decode(ptr unsafe.Pointer, iter spi.Iterator) { 19 | *(*bool)(ptr) = iter.ReadBool() 20 | } 21 | 22 | type float64Decoder struct { 23 | } 24 | 25 | func (decoder *float64Decoder) decode(ptr unsafe.Pointer, iter spi.Iterator) { 26 | *(*float64)(ptr) = iter.ReadFloat64() 27 | } 28 | 29 | type int8Decoder struct { 30 | } 31 | 32 | func (decoder *int8Decoder) decode(ptr unsafe.Pointer, 
iter spi.Iterator) { 33 | *(*int8)(ptr) = iter.ReadInt8() 34 | } 35 | 36 | type uint8Decoder struct { 37 | } 38 | 39 | func (decoder *uint8Decoder) decode(ptr unsafe.Pointer, iter spi.Iterator) { 40 | *(*uint8)(ptr) = iter.ReadUint8() 41 | } 42 | 43 | type int16Decoder struct { 44 | } 45 | 46 | func (decoder *int16Decoder) decode(ptr unsafe.Pointer, iter spi.Iterator) { 47 | *(*int16)(ptr) = iter.ReadInt16() 48 | } 49 | 50 | type uint16Decoder struct { 51 | } 52 | 53 | func (decoder *uint16Decoder) decode(ptr unsafe.Pointer, iter spi.Iterator) { 54 | *(*uint16)(ptr) = iter.ReadUint16() 55 | } 56 | 57 | type int32Decoder struct { 58 | } 59 | 60 | func (decoder *int32Decoder) decode(ptr unsafe.Pointer, iter spi.Iterator) { 61 | *(*int32)(ptr) = iter.ReadInt32() 62 | } 63 | 64 | type uint32Decoder struct { 65 | } 66 | 67 | func (decoder *uint32Decoder) decode(ptr unsafe.Pointer, iter spi.Iterator) { 68 | *(*uint32)(ptr) = iter.ReadUint32() 69 | } 70 | 71 | type int64Decoder struct { 72 | } 73 | 74 | func (decoder *int64Decoder) decode(ptr unsafe.Pointer, iter spi.Iterator) { 75 | *(*int64)(ptr) = iter.ReadInt64() 76 | } 77 | 78 | type uint64Decoder struct { 79 | } 80 | 81 | func (decoder *uint64Decoder) decode(ptr unsafe.Pointer, iter spi.Iterator) { 82 | *(*uint64)(ptr) = iter.ReadUint64() 83 | } 84 | 85 | type intDecoder struct { 86 | } 87 | 88 | func (decoder *intDecoder) decode(ptr unsafe.Pointer, iter spi.Iterator) { 89 | *(*int)(ptr) = iter.ReadInt() 90 | } 91 | 92 | type uintDecoder struct { 93 | } 94 | 95 | func (decoder *uintDecoder) decode(ptr unsafe.Pointer, iter spi.Iterator) { 96 | *(*uint)(ptr) = iter.ReadUint() 97 | } 98 | 99 | type stringDecoder struct { 100 | } 101 | 102 | func (decoder *stringDecoder) decode(ptr unsafe.Pointer, iter spi.Iterator) { 103 | *(*string)(ptr) = iter.ReadString() 104 | } 105 | 106 | -------------------------------------------------------------------------------- /binding/reflection/decode_slice.go: 
-------------------------------------------------------------------------------- 1 | package reflection 2 | 3 | import ( 4 | "unsafe" 5 | "github.com/thrift-iterator/go/spi" 6 | "reflect" 7 | ) 8 | 9 | type sliceDecoder struct { 10 | sliceType reflect.Type 11 | elemType reflect.Type 12 | elemDecoder internalDecoder 13 | } 14 | 15 | func (decoder *sliceDecoder) decode(ptr unsafe.Pointer, iter spi.Iterator) { 16 | slice := (*sliceHeader)(ptr) 17 | slice.Len = 0 18 | offset := uintptr(0) 19 | _, length := iter.ReadListHeader() 20 | 21 | if slice.Cap < length { 22 | newVal := reflect.MakeSlice(decoder.sliceType, 0, length) 23 | slice.Data = unsafe.Pointer(newVal.Pointer()) 24 | slice.Cap = length 25 | } 26 | 27 | for i := 0; i < length; i++ { 28 | decoder.elemDecoder.decode(unsafe.Pointer(uintptr(slice.Data)+offset), iter) 29 | offset += decoder.elemType.Size() 30 | slice.Len += 1 31 | } 32 | } 33 | 34 | // grow grows the slice s so that it can hold extra more values, allocating 35 | // more capacity if needed. It also returns the old and new slice lengths. 
36 | func growOne(slice *sliceHeader, sliceType reflect.Type, elementType reflect.Type) { 37 | newLen := slice.Len + 1 38 | if newLen <= slice.Cap { 39 | slice.Len = newLen 40 | return 41 | } 42 | newCap := slice.Cap 43 | if newCap == 0 { 44 | newCap = 1 45 | } else { 46 | for newCap < newLen { 47 | if slice.Len < 1024 { 48 | newCap += newCap 49 | } else { 50 | newCap += newCap / 4 51 | } 52 | } 53 | } 54 | newVal := reflect.MakeSlice(sliceType, newLen, newCap) 55 | dst := unsafe.Pointer(newVal.Pointer()) 56 | // copy old array into new array 57 | originalBytesCount := slice.Len * int(elementType.Size()) 58 | srcSliceHeader := (unsafe.Pointer)(&sliceHeader{slice.Data, originalBytesCount, originalBytesCount}) 59 | dstSliceHeader := (unsafe.Pointer)(&sliceHeader{dst, originalBytesCount, originalBytesCount}) 60 | copy(*(*[]byte)(dstSliceHeader), *(*[]byte)(srcSliceHeader)) 61 | slice.Data = dst 62 | slice.Len = newLen 63 | slice.Cap = newCap 64 | } -------------------------------------------------------------------------------- /binding/reflection/decode_struct.go: -------------------------------------------------------------------------------- 1 | package reflection 2 | 3 | import ( 4 | "github.com/thrift-iterator/go/protocol" 5 | "unsafe" 6 | "github.com/thrift-iterator/go/spi" 7 | ) 8 | 9 | type structDecoder struct { 10 | fields []structDecoderField 11 | fieldMap map[protocol.FieldId]structDecoderField 12 | } 13 | 14 | type structDecoderField struct { 15 | offset uintptr 16 | fieldId protocol.FieldId 17 | decoder internalDecoder 18 | } 19 | 20 | func (decoder *structDecoder) decode(ptr unsafe.Pointer, iter spi.Iterator) { 21 | iter.ReadStructHeader() 22 | for _, field := range decoder.fields { 23 | fieldType, fieldId := iter.ReadStructField() 24 | if field.fieldId == fieldId { 25 | field.decoder.decode(unsafe.Pointer(uintptr(ptr) + field.offset), iter) 26 | } else { 27 | decoder.decodeByMap(ptr, iter, fieldType, fieldId) 28 | return 29 | } 30 | } 31 | fieldType, 
fieldId := iter.ReadStructField() 32 | decoder.decodeByMap(ptr, iter, fieldType, fieldId) 33 | } 34 | 35 | func (decoder *structDecoder) decodeByMap(ptr unsafe.Pointer, iter spi.Iterator, 36 | fieldType protocol.TType, fieldId protocol.FieldId) { 37 | for { 38 | if protocol.TypeStop == fieldType { 39 | return 40 | } 41 | field, isFound := decoder.fieldMap[fieldId] 42 | if isFound { 43 | field.decoder.decode(unsafe.Pointer(uintptr(ptr) + field.offset), iter) 44 | } else { 45 | iter.Discard(fieldType) 46 | } 47 | fieldType, fieldId = iter.ReadStructField() 48 | } 49 | } -------------------------------------------------------------------------------- /binding/reflection/encode.go: -------------------------------------------------------------------------------- 1 | package reflection 2 | 3 | import ( 4 | "github.com/thrift-iterator/go/protocol" 5 | "github.com/thrift-iterator/go/spi" 6 | "reflect" 7 | "unsafe" 8 | ) 9 | 10 | func EncoderOf(extension spi.Extension, valType reflect.Type) spi.ValEncoder { 11 | isPtr := valType.Kind() == reflect.Ptr 12 | isOnePtrArray := valType.Kind() == reflect.Array && valType.Len() == 1 && 13 | valType.Elem().Kind() == reflect.Ptr 14 | isOnePtrStruct := valType.Kind() == reflect.Struct && valType.NumField() == 1 && 15 | valType.Field(0).Type.Kind() == reflect.Ptr 16 | isOneMapStruct := valType.Kind() == reflect.Struct && valType.NumField() == 1 && 17 | valType.Field(0).Type.Kind() == reflect.Map 18 | if isPtr || isOnePtrArray || isOnePtrStruct || isOneMapStruct { 19 | return &ptrEncoderAdapter{encoderOf(extension, "", valType)} 20 | } 21 | return &valEncoderAdapter{encoderOf(extension, "", valType)} 22 | } 23 | 24 | func encoderOf(extension spi.Extension, prefix string, valType reflect.Type) internalEncoder { 25 | extEncoder := extension.EncoderOf(valType) 26 | if extEncoder != nil { 27 | valObj := reflect.New(valType).Elem().Interface() 28 | valEmptyInterface := *(*emptyInterface)(unsafe.Pointer(&valObj)) 29 | return 
&internalEncoderAdapter{valEmptyInterface: valEmptyInterface, encoder: extEncoder} 30 | } 31 | if byteSliceType == valType { 32 | return &binaryEncoder{} 33 | } 34 | if isEnumType(valType) { 35 | return &int32Encoder{} 36 | } 37 | switch valType.Kind() { 38 | case reflect.String: 39 | return &stringEncoder{} 40 | case reflect.Bool: 41 | return &boolEncoder{} 42 | case reflect.Int8: 43 | return &int8Encoder{} 44 | case reflect.Uint8: 45 | return &uint8Encoder{} 46 | case reflect.Int16: 47 | return &int16Encoder{} 48 | case reflect.Uint16: 49 | return &uint16Encoder{} 50 | case reflect.Int32: 51 | return &int32Encoder{} 52 | case reflect.Uint32: 53 | return &uint32Encoder{} 54 | case reflect.Int64: 55 | return &int64Encoder{} 56 | case reflect.Uint64: 57 | return &uint64Encoder{} 58 | case reflect.Int: 59 | return &intEncoder{} 60 | case reflect.Uint: 61 | return &uintEncoder{} 62 | case reflect.Float32: 63 | return &float32Encoder{} 64 | case reflect.Float64: 65 | return &float64Encoder{} 66 | case reflect.Slice: 67 | return &sliceEncoder{ 68 | sliceType: valType, 69 | elemType: valType.Elem(), 70 | elemEncoder: encoderOf(extension, prefix+" [sliceElem]", valType.Elem()), 71 | } 72 | case reflect.Map: 73 | sampleObj := reflect.New(valType).Elem().Interface() 74 | elemType := valType.Elem() 75 | if elemType.Kind() == reflect.Ptr { 76 | elemType = elemType.Elem() 77 | } 78 | return &mapEncoder{ 79 | keyEncoder: encoderOf(extension, prefix+" [mapKey]", valType.Key()), 80 | elemEncoder: encoderOf(extension, prefix+" [mapElem]", elemType), 81 | mapInterface: *(*emptyInterface)(unsafe.Pointer(&sampleObj)), 82 | } 83 | case reflect.Struct: 84 | encoderFields := make([]structEncoderField, 0, valType.NumField()) 85 | for i := 0; i < valType.NumField(); i++ { 86 | refField := valType.Field(i) 87 | fieldId := parseFieldId(refField) 88 | if fieldId == -1 { 89 | continue 90 | } 91 | encoderField := structEncoderField{ 92 | offset: refField.Offset, 93 | fieldId: fieldId, 94 | 
encoder: encoderOf(extension, prefix+" "+refField.Name, refField.Type), 95 | } 96 | encoderFields = append(encoderFields, encoderField) 97 | } 98 | return &structEncoder{ 99 | fields: encoderFields, 100 | } 101 | case reflect.Ptr: 102 | return &pointerEncoder{ 103 | valType: valType.Elem(), 104 | valEncoder: encoderOf(extension, prefix+" [ptrElem]", valType.Elem()), 105 | } 106 | } 107 | return &unknownEncoder{prefix, valType} 108 | } 109 | 110 | type unknownEncoder struct { 111 | prefix string 112 | valType reflect.Type 113 | } 114 | 115 | func (encoder *unknownEncoder) encode(ptr unsafe.Pointer, stream spi.Stream) { 116 | stream.ReportError("decode "+encoder.prefix, "do not know how to encode "+encoder.valType.String()) 117 | } 118 | 119 | func (encoder *unknownEncoder) thriftType() protocol.TType { 120 | return protocol.TypeStop 121 | } 122 | -------------------------------------------------------------------------------- /binding/reflection/encode_map.go: -------------------------------------------------------------------------------- 1 | package reflection 2 | 3 | import ( 4 | "github.com/thrift-iterator/go/protocol" 5 | "github.com/thrift-iterator/go/spi" 6 | "reflect" 7 | "unsafe" 8 | ) 9 | 10 | type mapEncoder struct { 11 | mapInterface emptyInterface 12 | keyEncoder internalEncoder 13 | elemEncoder internalEncoder 14 | } 15 | 16 | func (encoder *mapEncoder) encode(ptr unsafe.Pointer, stream spi.Stream) { 17 | mapInterface := encoder.mapInterface 18 | mapInterface.word = ptr 19 | realInterface := (*interface{})(unsafe.Pointer(&mapInterface)) 20 | mapVal := reflect.ValueOf(*realInterface) 21 | keys := mapVal.MapKeys() 22 | stream.WriteMapHeader(encoder.keyEncoder.thriftType(), encoder.elemEncoder.thriftType(), len(keys)) 23 | for _, key := range keys { 24 | keyObj := key.Interface() 25 | keyInf := (*emptyInterface)(unsafe.Pointer(&keyObj)) 26 | encoder.keyEncoder.encode(keyInf.word, stream) 27 | elem := mapVal.MapIndex(key) 28 | elemObj := elem.Interface() 
29 | elemInf := (*emptyInterface)(unsafe.Pointer(&elemObj)) 30 | encoder.elemEncoder.encode(elemInf.word, stream) 31 | } 32 | } 33 | 34 | func (encoder *mapEncoder) thriftType() protocol.TType { 35 | return protocol.TypeMap 36 | } 37 | -------------------------------------------------------------------------------- /binding/reflection/encode_pointer.go: -------------------------------------------------------------------------------- 1 | package reflection 2 | 3 | import ( 4 | "github.com/thrift-iterator/go/protocol" 5 | "github.com/thrift-iterator/go/spi" 6 | "reflect" 7 | "unsafe" 8 | ) 9 | 10 | type pointerEncoder struct { 11 | valType reflect.Type 12 | valEncoder internalEncoder 13 | } 14 | 15 | func (encoder *pointerEncoder) encode(ptr unsafe.Pointer, stream spi.Stream) { 16 | valPtr := *(*unsafe.Pointer)(ptr) 17 | if encoder.valType.Kind() == reflect.Map { 18 | valPtr = *(*unsafe.Pointer)(valPtr) 19 | } 20 | encoder.valEncoder.encode(valPtr, stream) 21 | } 22 | 23 | func (encoder *pointerEncoder) thriftType() protocol.TType { 24 | return encoder.valEncoder.thriftType() 25 | } 26 | -------------------------------------------------------------------------------- /binding/reflection/encode_simple_value.go: -------------------------------------------------------------------------------- 1 | package reflection 2 | 3 | import ( 4 | "unsafe" 5 | "github.com/thrift-iterator/go/spi" 6 | "github.com/thrift-iterator/go/protocol" 7 | ) 8 | 9 | type binaryEncoder struct { 10 | } 11 | 12 | func (encoder *binaryEncoder) encode(ptr unsafe.Pointer, iter spi.Stream) { 13 | iter.WriteBinary(*(*[]byte)(ptr)) 14 | } 15 | 16 | func (encoder *binaryEncoder) thriftType() protocol.TType { 17 | return protocol.TypeString 18 | } 19 | 20 | type stringEncoder struct { 21 | } 22 | 23 | func (encoder *stringEncoder) encode(ptr unsafe.Pointer, iter spi.Stream) { 24 | iter.WriteString(*(*string)(ptr)) 25 | } 26 | 27 | func (encoder *stringEncoder) thriftType() protocol.TType { 28 | return 
protocol.TypeString 29 | } 30 | 31 | type boolEncoder struct { 32 | } 33 | 34 | func (encoder *boolEncoder) encode(ptr unsafe.Pointer, iter spi.Stream) { 35 | iter.WriteBool(*(*bool)(ptr)) 36 | } 37 | 38 | func (encoder *boolEncoder) thriftType() protocol.TType { 39 | return protocol.TypeBool 40 | } 41 | 42 | type int8Encoder struct { 43 | } 44 | 45 | func (encoder *int8Encoder) encode(ptr unsafe.Pointer, iter spi.Stream) { 46 | iter.WriteInt8(*(*int8)(ptr)) 47 | } 48 | 49 | func (encoder *int8Encoder) thriftType() protocol.TType { 50 | return protocol.TypeI08 51 | } 52 | 53 | type uint8Encoder struct { 54 | } 55 | 56 | func (encoder *uint8Encoder) encode(ptr unsafe.Pointer, iter spi.Stream) { 57 | iter.WriteUint8(*(*uint8)(ptr)) 58 | } 59 | 60 | func (encoder *uint8Encoder) thriftType() protocol.TType { 61 | return protocol.TypeI08 62 | } 63 | 64 | type int16Encoder struct { 65 | } 66 | 67 | func (encoder *int16Encoder) encode(ptr unsafe.Pointer, iter spi.Stream) { 68 | iter.WriteInt16(*(*int16)(ptr)) 69 | } 70 | 71 | func (encoder *int16Encoder) thriftType() protocol.TType { 72 | return protocol.TypeI16 73 | } 74 | 75 | type uint16Encoder struct { 76 | } 77 | 78 | func (encoder *uint16Encoder) encode(ptr unsafe.Pointer, iter spi.Stream) { 79 | iter.WriteUint16(*(*uint16)(ptr)) 80 | } 81 | 82 | func (encoder *uint16Encoder) thriftType() protocol.TType { 83 | return protocol.TypeI16 84 | } 85 | 86 | type int32Encoder struct { 87 | } 88 | 89 | func (encoder *int32Encoder) encode(ptr unsafe.Pointer, iter spi.Stream) { 90 | iter.WriteInt32(*(*int32)(ptr)) 91 | } 92 | 93 | func (encoder *int32Encoder) thriftType() protocol.TType { 94 | return protocol.TypeI32 95 | } 96 | 97 | type uint32Encoder struct { 98 | } 99 | 100 | func (encoder *uint32Encoder) encode(ptr unsafe.Pointer, iter spi.Stream) { 101 | iter.WriteUint32(*(*uint32)(ptr)) 102 | } 103 | 104 | func (encoder *uint32Encoder) thriftType() protocol.TType { 105 | return protocol.TypeI32 106 | } 107 | 108 | type 
int64Encoder struct { 109 | } 110 | 111 | func (encoder *int64Encoder) encode(ptr unsafe.Pointer, iter spi.Stream) { 112 | iter.WriteInt64(*(*int64)(ptr)) 113 | } 114 | 115 | func (encoder *int64Encoder) thriftType() protocol.TType { 116 | return protocol.TypeI64 117 | } 118 | 119 | type uint64Encoder struct { 120 | } 121 | 122 | func (encoder *uint64Encoder) encode(ptr unsafe.Pointer, iter spi.Stream) { 123 | iter.WriteUint64(*(*uint64)(ptr)) 124 | } 125 | 126 | func (encoder *uint64Encoder) thriftType() protocol.TType { 127 | return protocol.TypeI64 128 | } 129 | 130 | type intEncoder struct { 131 | } 132 | 133 | func (encoder *intEncoder) encode(ptr unsafe.Pointer, iter spi.Stream) { 134 | iter.WriteInt(*(*int)(ptr)) 135 | } 136 | 137 | func (encoder *intEncoder) thriftType() protocol.TType { 138 | return protocol.TypeI64 139 | } 140 | 141 | type uintEncoder struct { 142 | } 143 | 144 | func (encoder *uintEncoder) encode(ptr unsafe.Pointer, iter spi.Stream) { 145 | iter.WriteUint(*(*uint)(ptr)) 146 | } 147 | 148 | func (encoder *uintEncoder) thriftType() protocol.TType { 149 | return protocol.TypeI64 150 | } 151 | 152 | type float64Encoder struct { 153 | } 154 | 155 | func (encoder *float64Encoder) encode(ptr unsafe.Pointer, iter spi.Stream) { 156 | iter.WriteFloat64(*(*float64)(ptr)) 157 | } 158 | 159 | func (encoder *float64Encoder) thriftType() protocol.TType { 160 | return protocol.TypeDouble 161 | } 162 | 163 | type float32Encoder struct { 164 | } 165 | 166 | func (encoder *float32Encoder) encode(ptr unsafe.Pointer, iter spi.Stream) { 167 | iter.WriteFloat64(float64(*(*float32)(ptr))) 168 | } 169 | 170 | func (encoder *float32Encoder) thriftType() protocol.TType { 171 | return protocol.TypeDouble 172 | } 173 | -------------------------------------------------------------------------------- /binding/reflection/encode_slice.go: -------------------------------------------------------------------------------- 1 | package reflection 2 | 3 | import ( 4 | "unsafe" 
5 | "github.com/thrift-iterator/go/spi" 6 | "reflect" 7 | "github.com/thrift-iterator/go/protocol" 8 | ) 9 | 10 | type sliceEncoder struct { 11 | sliceType reflect.Type 12 | elemType reflect.Type 13 | elemEncoder internalEncoder 14 | } 15 | 16 | func (encoder *sliceEncoder) encode(ptr unsafe.Pointer, stream spi.Stream) { 17 | slice := (*sliceHeader)(ptr) 18 | stream.WriteListHeader(encoder.elemEncoder.thriftType(), slice.Len) 19 | offset := uintptr(slice.Data) 20 | var addr unsafe.Pointer 21 | for i := 0; i < slice.Len; i++ { 22 | addr = unsafe.Pointer(offset) 23 | if encoder.elemType.Kind() == reflect.Map { 24 | addr = unsafe.Pointer((uintptr)(*(*uint64)(addr))) 25 | } 26 | encoder.elemEncoder.encode(addr, stream) 27 | offset += encoder.elemType.Size() 28 | } 29 | } 30 | 31 | func (encoder *sliceEncoder) thriftType() protocol.TType { 32 | return protocol.TypeList 33 | } 34 | -------------------------------------------------------------------------------- /binding/reflection/encode_struct.go: -------------------------------------------------------------------------------- 1 | package reflection 2 | 3 | import ( 4 | "github.com/thrift-iterator/go/protocol" 5 | "github.com/thrift-iterator/go/spi" 6 | "unsafe" 7 | ) 8 | 9 | type structEncoder struct { 10 | fields []structEncoderField 11 | } 12 | 13 | type structEncoderField struct { 14 | offset uintptr 15 | fieldId protocol.FieldId 16 | encoder internalEncoder 17 | } 18 | 19 | func (encoder *structEncoder) encode(ptr unsafe.Pointer, stream spi.Stream) { 20 | stream.WriteStructHeader() 21 | for _, field := range encoder.fields { 22 | fieldPtr := unsafe.Pointer(uintptr(ptr) + field.offset) 23 | switch field.encoder.(type) { 24 | case *pointerEncoder, *sliceEncoder: 25 | if *(*unsafe.Pointer)(fieldPtr) == nil { 26 | continue 27 | } 28 | case *mapEncoder: 29 | if *(*unsafe.Pointer)(fieldPtr) == nil { 30 | continue 31 | } 32 | fieldPtr = *(*unsafe.Pointer)(fieldPtr) 33 | } 34 | 
stream.WriteStructField(field.encoder.thriftType(), field.fieldId) 35 | field.encoder.encode(fieldPtr, stream) 36 | } 37 | stream.WriteStructFieldStop() 38 | } 39 | 40 | func (encoder *structEncoder) thriftType() protocol.TType { 41 | return protocol.TypeStruct 42 | } -------------------------------------------------------------------------------- /binding/reflection/unsafe.go: -------------------------------------------------------------------------------- 1 | package reflection 2 | 3 | import ( 4 | "unsafe" 5 | "github.com/thrift-iterator/go/spi" 6 | "github.com/thrift-iterator/go/protocol" 7 | ) 8 | 9 | type internalDecoder interface { 10 | decode(ptr unsafe.Pointer, iter spi.Iterator) 11 | } 12 | 13 | type valDecoderAdapter struct { 14 | decoder internalDecoder 15 | } 16 | 17 | func (decoder *valDecoderAdapter) Decode(val interface{}, iter spi.Iterator) { 18 | ptr := (*emptyInterface)(unsafe.Pointer(&val)).word 19 | decoder.decoder.decode(ptr, iter) 20 | } 21 | 22 | type internalDecoderAdapter struct { 23 | decoder spi.ValDecoder 24 | valEmptyInterface emptyInterface 25 | } 26 | 27 | func (decoder *internalDecoderAdapter) decode(ptr unsafe.Pointer, iter spi.Iterator) { 28 | valEmptyInterface := decoder.valEmptyInterface 29 | valEmptyInterface.word = ptr 30 | valObj := *(*interface{})((unsafe.Pointer(&valEmptyInterface))) 31 | decoder.decoder.Decode(valObj, iter) 32 | } 33 | 34 | type internalEncoder interface { 35 | encode(ptr unsafe.Pointer, stream spi.Stream) 36 | thriftType() protocol.TType 37 | } 38 | 39 | type valEncoderAdapter struct { 40 | encoder internalEncoder 41 | } 42 | 43 | func (encoder *valEncoderAdapter) Encode(val interface{}, stream spi.Stream) { 44 | ptr := (*emptyInterface)(unsafe.Pointer(&val)).word 45 | encoder.encoder.encode(ptr, stream) 46 | } 47 | 48 | func (encoder *valEncoderAdapter) ThriftType() protocol.TType { 49 | return encoder.encoder.thriftType() 50 | } 51 | 52 | type ptrEncoderAdapter struct { 53 | encoder internalEncoder 54 | 
} 55 | 56 | func (encoder *ptrEncoderAdapter) Encode(val interface{}, stream spi.Stream) { 57 | ptr := (*emptyInterface)(unsafe.Pointer(&val)).word 58 | encoder.encoder.encode(unsafe.Pointer(&ptr), stream) 59 | } 60 | 61 | func (encoder *ptrEncoderAdapter) ThriftType() protocol.TType { 62 | return encoder.encoder.thriftType() 63 | } 64 | 65 | type internalEncoderAdapter struct { 66 | encoder spi.ValEncoder 67 | valEmptyInterface emptyInterface 68 | } 69 | 70 | func (encoder *internalEncoderAdapter) encode(ptr unsafe.Pointer, stream spi.Stream) { 71 | valEmptyInterface := encoder.valEmptyInterface 72 | valEmptyInterface.word = ptr 73 | valObj := *(*interface{})((unsafe.Pointer(&valEmptyInterface))) 74 | encoder.encoder.Encode(valObj, stream) 75 | } 76 | 77 | func (encoder *internalEncoderAdapter) thriftType() protocol.TType { 78 | return encoder.encoder.ThriftType() 79 | } 80 | 81 | // emptyInterface is the header for an interface{} value. 82 | type emptyInterface struct { 83 | typ unsafe.Pointer 84 | word unsafe.Pointer 85 | } 86 | 87 | // sliceHeader is a safe version of SliceHeader used within this package. 
88 | type sliceHeader struct { 89 | Data unsafe.Pointer 90 | Len int 91 | Cap int 92 | } 93 | -------------------------------------------------------------------------------- /cmd/thrifter/main.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "flag" 5 | "github.com/v2pro/wombat" 6 | "os" 7 | ) 8 | 9 | func main() { 10 | pkgPath := flag.String("pkg", "", "the package to generate generic code for") 11 | flag.Parse() 12 | if *pkgPath == "" { 13 | flag.Usage() 14 | os.Exit(1) 15 | } 16 | wombat.Codegen(*pkgPath) 17 | } -------------------------------------------------------------------------------- /decoder.go: -------------------------------------------------------------------------------- 1 | package thrifter 2 | 3 | import ( 4 | "github.com/thrift-iterator/go/general" 5 | "github.com/thrift-iterator/go/protocol" 6 | "github.com/thrift-iterator/go/spi" 7 | "io" 8 | "reflect" 9 | ) 10 | 11 | type Decoder struct { 12 | cfg *frozenConfig 13 | iter spi.Iterator 14 | } 15 | 16 | func (decoder *Decoder) Decode(val interface{}) error { 17 | cfg := decoder.cfg 18 | valType := reflect.TypeOf(val) 19 | valDecoder := cfg.getGenDecoder(valType) 20 | if valDecoder == nil { 21 | valDecoder = cfg.decoderOf(valType) 22 | cfg.addGenDecoder(valType, valDecoder) 23 | } 24 | valDecoder.Decode(val, decoder.iter) 25 | if decoder.iter.Error() != nil { 26 | return decoder.iter.Error() 27 | } 28 | return nil 29 | } 30 | 31 | func (decoder *Decoder) DecodeMessage() (general.Message, error) { 32 | var msg general.Message 33 | err := decoder.Decode(&msg) 34 | return msg, err 35 | } 36 | 37 | func (decoder *Decoder) DecodeMessageHeader() (protocol.MessageHeader, error) { 38 | var msgHeader protocol.MessageHeader 39 | err := decoder.Decode(&msgHeader) 40 | return msgHeader, err 41 | } 42 | 43 | func (decoder *Decoder) DecodeMessageArguments() (general.Struct, error) { 44 | var msgArgs general.Struct 45 | err := 
decoder.Decode(&msgArgs) 46 | return msgArgs, err 47 | } 48 | 49 | func (decoder *Decoder) Reset(reader io.Reader, buf []byte) { 50 | decoder.iter.Reset(reader, buf) 51 | } 52 | -------------------------------------------------------------------------------- /encoder.go: -------------------------------------------------------------------------------- 1 | package thrifter 2 | 3 | import ( 4 | "github.com/thrift-iterator/go/general" 5 | "github.com/thrift-iterator/go/protocol" 6 | "github.com/thrift-iterator/go/spi" 7 | "io" 8 | "reflect" 9 | ) 10 | 11 | type Encoder struct { 12 | cfg *frozenConfig 13 | stream spi.Stream 14 | } 15 | 16 | func (encoder *Encoder) Encode(val interface{}) error { 17 | cfg := encoder.cfg 18 | valType := reflect.TypeOf(val) 19 | valEncoder := cfg.getGenEncoder(valType) 20 | if valEncoder == nil { 21 | valEncoder = cfg.encoderOf(valType) 22 | cfg.addGenEncoder(valType, valEncoder) 23 | } 24 | valEncoder.Encode(val, encoder.stream) 25 | encoder.stream.Flush() 26 | if encoder.stream.Error() != nil { 27 | return encoder.stream.Error() 28 | } 29 | return nil 30 | } 31 | 32 | func (encoder *Encoder) EncodeMessage(msg general.Message) error { 33 | return encoder.Encode(msg) 34 | } 35 | 36 | func (encoder *Encoder) EncodeMessageHeader(msgHeader protocol.MessageHeader) error { 37 | return encoder.Encode(msgHeader) 38 | } 39 | 40 | func (encoder *Encoder) EncodeMessageArguments(msgArgs general.Struct) error { 41 | return encoder.Encode(msgArgs) 42 | } 43 | 44 | func (encoder *Encoder) Reset(writer io.Writer) { 45 | encoder.stream.Reset(writer) 46 | } 47 | 48 | func (encoder *Encoder) Buffer() []byte { 49 | return encoder.stream.Buffer() 50 | } 51 | -------------------------------------------------------------------------------- /general/decode.go: -------------------------------------------------------------------------------- 1 | package general 2 | 3 | import ( 4 | "github.com/thrift-iterator/go/spi" 5 | "github.com/thrift-iterator/go/protocol" 6 
| ) 7 | 8 | func generalReaderOf(ttype protocol.TType) func(iter spi.Iterator) interface{} { 9 | switch ttype { 10 | case protocol.TypeBool: 11 | return readBool 12 | case protocol.TypeI08: 13 | return readInt8 14 | case protocol.TypeI16: 15 | return readInt16 16 | case protocol.TypeI32: 17 | return readInt32 18 | case protocol.TypeI64: 19 | return readInt64 20 | case protocol.TypeString: 21 | return readString 22 | case protocol.TypeDouble: 23 | return readFloat64 24 | case protocol.TypeList: 25 | return readList 26 | case protocol.TypeMap: 27 | return readMap 28 | case protocol.TypeStruct: 29 | return readStruct 30 | default: 31 | panic("unsupported type") 32 | } 33 | } 34 | 35 | func readFloat64(iter spi.Iterator) interface{} { 36 | return iter.ReadFloat64() 37 | } 38 | 39 | func readBool(iter spi.Iterator) interface{} { 40 | return iter.ReadBool() 41 | } 42 | 43 | func readInt8(iter spi.Iterator) interface{} { 44 | return iter.ReadInt8() 45 | } 46 | 47 | func readInt16(iter spi.Iterator) interface{} { 48 | return iter.ReadInt16() 49 | } 50 | 51 | func readInt32(iter spi.Iterator) interface{} { 52 | return iter.ReadInt32() 53 | } 54 | 55 | func readInt64(iter spi.Iterator) interface{} { 56 | return iter.ReadInt64() 57 | } 58 | 59 | func readString(iter spi.Iterator) interface{} { 60 | return iter.ReadString() 61 | } 62 | -------------------------------------------------------------------------------- /general/decode_list.go: -------------------------------------------------------------------------------- 1 | package general 2 | 3 | import "github.com/thrift-iterator/go/spi" 4 | 5 | type generalListDecoder struct { 6 | } 7 | 8 | func (decoder *generalListDecoder) Decode(val interface{}, iter spi.Iterator) { 9 | *val.(*List) = readList(iter).(List) 10 | } 11 | 12 | func readList(iter spi.Iterator) interface{} { 13 | elemType, length := iter.ReadListHeader() 14 | generalReader := generalReaderOf(elemType) 15 | var generalList List 16 | for i := 0; i < length; i++ { 
17 | generalList = append(generalList, generalReader(iter)) 18 | } 19 | return generalList 20 | } 21 | -------------------------------------------------------------------------------- /general/decode_map.go: -------------------------------------------------------------------------------- 1 | package general 2 | 3 | import "github.com/thrift-iterator/go/spi" 4 | 5 | type generalMapDecoder struct { 6 | } 7 | 8 | func (decoder *generalMapDecoder) Decode(val interface{}, iter spi.Iterator) { 9 | *val.(*Map) = readMap(iter).(Map) 10 | } 11 | 12 | func readMap(iter spi.Iterator) interface{} { 13 | keyType, elemType, length := iter.ReadMapHeader() 14 | generalMap := Map{} 15 | if length == 0 { 16 | return generalMap 17 | } 18 | keyReader := generalReaderOf(keyType) 19 | elemReader := generalReaderOf(elemType) 20 | for i := 0; i < length; i++ { 21 | key := keyReader(iter) 22 | elem := elemReader(iter) 23 | generalMap[key] = elem 24 | } 25 | return generalMap 26 | } -------------------------------------------------------------------------------- /general/decode_message.go: -------------------------------------------------------------------------------- 1 | package general 2 | 3 | import ( 4 | "github.com/thrift-iterator/go/spi" 5 | "github.com/thrift-iterator/go/protocol" 6 | ) 7 | 8 | type messageDecoder struct { 9 | } 10 | 11 | func (decoder *messageDecoder) Decode(val interface{}, iter spi.Iterator) { 12 | *val.(*Message) = Message{ 13 | MessageHeader: iter.ReadMessageHeader(), 14 | Arguments: readStruct(iter).(Struct), 15 | } 16 | } 17 | 18 | type messageHeaderDecoder struct { 19 | } 20 | 21 | func (decoder *messageHeaderDecoder) Decode(val interface{}, iter spi.Iterator) { 22 | *val.(*protocol.MessageHeader) = iter.ReadMessageHeader() 23 | } -------------------------------------------------------------------------------- /general/decode_struct.go: -------------------------------------------------------------------------------- 1 | package general 2 | 3 | import ( 4 | 
"github.com/thrift-iterator/go/spi" 5 | "github.com/thrift-iterator/go/protocol" 6 | ) 7 | 8 | type generalStructDecoder struct { 9 | } 10 | 11 | func (decoder *generalStructDecoder) Decode(val interface{}, iter spi.Iterator) { 12 | *val.(*Struct) = readStruct(iter).(Struct) 13 | } 14 | 15 | func readStruct(iter spi.Iterator) interface{} { 16 | generalStruct := Struct{} 17 | iter.ReadStructHeader() 18 | for { 19 | fieldType, fieldId := iter.ReadStructField() 20 | if fieldType == protocol.TypeStop { 21 | return generalStruct 22 | } 23 | generalReader := generalReaderOf(fieldType) 24 | generalStruct[fieldId] = generalReader(iter) 25 | } 26 | } -------------------------------------------------------------------------------- /general/encode.go: -------------------------------------------------------------------------------- 1 | package general 2 | 3 | import ( 4 | "github.com/thrift-iterator/go/spi" 5 | "github.com/thrift-iterator/go/protocol" 6 | "reflect" 7 | ) 8 | 9 | func generalWriterOf(sample interface{}) (protocol.TType, func(val interface{}, stream spi.Stream)) { 10 | switch sample.(type) { 11 | case bool: 12 | return protocol.TypeBool, writeBool 13 | case int8: 14 | return protocol.TypeI08, writeInt8 15 | case uint8: 16 | return protocol.TypeI08, writeUint8 17 | case int16: 18 | return protocol.TypeI16, writeInt16 19 | case uint16: 20 | return protocol.TypeI16, writeUint16 21 | case int32: 22 | return protocol.TypeI32, writeInt32 23 | case uint32: 24 | return protocol.TypeI32, writeUint32 25 | case int64: 26 | return protocol.TypeI64, writeInt64 27 | case uint64: 28 | return protocol.TypeI64, writeUint64 29 | case float64: 30 | return protocol.TypeDouble, writeFloat64 31 | case string: 32 | return protocol.TypeString, writeString 33 | case []byte: 34 | return protocol.TypeString, writeBinary 35 | case List: 36 | return protocol.TypeList, writeList 37 | case Map: 38 | return protocol.TypeMap, writeMap 39 | case Struct: 40 | return protocol.TypeStruct, 
writeStruct 41 | default: 42 | panic("unsupported type: " + reflect.TypeOf(sample).String()) 43 | } 44 | } 45 | 46 | func writeBool(val interface{}, stream spi.Stream) { 47 | stream.WriteBool(val.(bool)) 48 | } 49 | 50 | func writeInt8(val interface{}, stream spi.Stream) { 51 | stream.WriteInt8(val.(int8)) 52 | } 53 | 54 | func writeUint8(val interface{}, stream spi.Stream) { 55 | stream.WriteUint8(val.(uint8)) 56 | } 57 | 58 | func writeInt16(val interface{}, stream spi.Stream) { 59 | stream.WriteInt16(val.(int16)) 60 | } 61 | 62 | func writeUint16(val interface{}, stream spi.Stream) { 63 | stream.WriteUint16(val.(uint16)) 64 | } 65 | 66 | func writeInt32(val interface{}, stream spi.Stream) { 67 | stream.WriteInt32(val.(int32)) 68 | } 69 | 70 | func writeUint32(val interface{}, stream spi.Stream) { 71 | stream.WriteUint32(val.(uint32)) 72 | } 73 | 74 | func writeInt64(val interface{}, stream spi.Stream) { 75 | stream.WriteInt64(val.(int64)) 76 | } 77 | 78 | func writeUint64(val interface{}, stream spi.Stream) { 79 | stream.WriteUint64(val.(uint64)) 80 | } 81 | 82 | func writeFloat64(val interface{}, stream spi.Stream) { 83 | stream.WriteFloat64(val.(float64)) 84 | } 85 | 86 | func writeString(val interface{}, stream spi.Stream) { 87 | stream.WriteString(val.(string)) 88 | } 89 | 90 | func writeBinary(val interface{}, stream spi.Stream) { 91 | stream.WriteBinary(val.([]byte)) 92 | } -------------------------------------------------------------------------------- /general/encode_list.go: -------------------------------------------------------------------------------- 1 | package general 2 | 3 | import ( 4 | "github.com/thrift-iterator/go/spi" 5 | "github.com/thrift-iterator/go/protocol" 6 | ) 7 | 8 | type generalListEncoder struct { 9 | } 10 | 11 | func (encoder *generalListEncoder) Encode(val interface{}, stream spi.Stream) { 12 | writeList(val, stream) 13 | } 14 | 15 | func (encoder *generalListEncoder) ThriftType() protocol.TType { 16 | return protocol.TypeList 
17 | } 18 | 19 | func writeList(val interface{}, stream spi.Stream) { 20 | obj := val.(List) 21 | length := len(obj) 22 | if length == 0 { 23 | stream.WriteListHeader(protocol.TypeI64, 0) 24 | return 25 | } 26 | elemType, generalWriter := generalWriterOf(obj[0]) 27 | stream.WriteListHeader(elemType, length) 28 | for _, elem := range obj { 29 | generalWriter(elem, stream) 30 | } 31 | } -------------------------------------------------------------------------------- /general/encode_map.go: -------------------------------------------------------------------------------- 1 | package general 2 | 3 | import ( 4 | "github.com/thrift-iterator/go/spi" 5 | "github.com/thrift-iterator/go/protocol" 6 | ) 7 | 8 | type generalMapEncoder struct { 9 | } 10 | 11 | func (encoder *generalMapEncoder) Encode(val interface{}, stream spi.Stream) { 12 | writeMap(val, stream) 13 | } 14 | 15 | func (encoder *generalMapEncoder) ThriftType() protocol.TType { 16 | return protocol.TypeMap 17 | } 18 | 19 | func takeSampleFromMap(sample Map) (interface{}, interface{}){ 20 | for key, elem := range sample { 21 | return key, elem 22 | } 23 | panic("should not reach here") 24 | } 25 | 26 | func writeMap(val interface{}, stream spi.Stream) { 27 | obj := val.(Map) 28 | length := len(obj) 29 | if length == 0 { 30 | stream.WriteMapHeader(protocol.TypeI64, protocol.TypeI64, 0) 31 | return 32 | } 33 | keySample, elemSample := takeSampleFromMap(obj) 34 | keyType, generalKeyWriter := generalWriterOf(keySample) 35 | elemType, generalElemWriter := generalWriterOf(elemSample) 36 | stream.WriteMapHeader(keyType, elemType, length) 37 | for key, elem := range obj { 38 | generalKeyWriter(key, stream) 39 | generalElemWriter(elem, stream) 40 | } 41 | } -------------------------------------------------------------------------------- /general/encode_message.go: -------------------------------------------------------------------------------- 1 | package general 2 | 3 | import ( 4 | "github.com/thrift-iterator/go/spi" 5 
| "github.com/thrift-iterator/go/protocol" 6 | ) 7 | 8 | type messageEncoder struct { 9 | } 10 | 11 | func (encoder *messageEncoder) Encode(val interface{}, stream spi.Stream) { 12 | msg := val.(Message) 13 | stream.WriteMessageHeader(msg.MessageHeader) 14 | writeStruct(msg.Arguments, stream) 15 | } 16 | 17 | func (encoder *messageEncoder) ThriftType() protocol.TType { 18 | return protocol.TypeStruct 19 | } 20 | 21 | type messageHeaderEncoder struct { 22 | } 23 | 24 | func (encoder *messageHeaderEncoder) Encode(val interface{}, stream spi.Stream) { 25 | msgHeader := val.(protocol.MessageHeader) 26 | stream.WriteMessageHeader(msgHeader) 27 | } 28 | 29 | func (encoder *messageHeaderEncoder) ThriftType() protocol.TType { 30 | return protocol.TypeStruct 31 | } -------------------------------------------------------------------------------- /general/encode_struct.go: -------------------------------------------------------------------------------- 1 | package general 2 | 3 | import ( 4 | "github.com/thrift-iterator/go/spi" 5 | "github.com/thrift-iterator/go/protocol" 6 | ) 7 | 8 | type generalStructEncoder struct { 9 | } 10 | 11 | func (encoder *generalStructEncoder) Encode(val interface{}, stream spi.Stream) { 12 | writeStruct(val, stream) 13 | } 14 | 15 | func (encoder *generalStructEncoder) ThriftType() protocol.TType { 16 | return protocol.TypeStruct 17 | } 18 | 19 | func writeStruct(val interface{}, stream spi.Stream) { 20 | obj := val.(Struct) 21 | stream.WriteStructHeader() 22 | for fieldId, elem := range obj { 23 | fieldType, generalWriter := generalWriterOf(elem) 24 | stream.WriteStructField(fieldType, fieldId) 25 | generalWriter(elem, stream) 26 | } 27 | stream.WriteStructFieldStop() 28 | } 29 | -------------------------------------------------------------------------------- /general/general_extension.go: -------------------------------------------------------------------------------- 1 | package general 2 | 3 | import ( 4 | "reflect" 5 | 
"github.com/thrift-iterator/go/spi" 6 | "github.com/thrift-iterator/go/protocol" 7 | ) 8 | 9 | type Extension struct { 10 | } 11 | 12 | func (ext *Extension) EncoderOf(valType reflect.Type) spi.ValEncoder { 13 | switch valType { 14 | case reflect.TypeOf(List(nil)): 15 | return &generalListEncoder{} 16 | case reflect.TypeOf(Map(nil)): 17 | return &generalMapEncoder{} 18 | case reflect.TypeOf(Struct(nil)): 19 | return &generalStructEncoder{} 20 | case reflect.TypeOf((*Message)(nil)).Elem(): 21 | return &messageEncoder{} 22 | case reflect.TypeOf((*protocol.MessageHeader)(nil)).Elem(): 23 | return &messageHeaderEncoder{} 24 | } 25 | return nil 26 | } 27 | 28 | func (ext *Extension) DecoderOf(valType reflect.Type) spi.ValDecoder { 29 | switch valType { 30 | case reflect.TypeOf((*List)(nil)): 31 | return &generalListDecoder{} 32 | case reflect.TypeOf((*Map)(nil)): 33 | return &generalMapDecoder{} 34 | case reflect.TypeOf((*Struct)(nil)): 35 | return &generalStructDecoder{} 36 | case reflect.TypeOf((*Message)(nil)): 37 | return &messageDecoder{} 38 | case reflect.TypeOf((*protocol.MessageHeader)(nil)): 39 | return &messageHeaderDecoder{} 40 | } 41 | return nil 42 | } 43 | -------------------------------------------------------------------------------- /general/general_object.go: -------------------------------------------------------------------------------- 1 | package general 2 | 3 | import "github.com/thrift-iterator/go/protocol" 4 | 5 | type Object interface { 6 | Get(path ...interface{}) interface{} 7 | } 8 | 9 | type List []interface{} 10 | 11 | func (obj List) Get(path ...interface{}) interface{} { 12 | if len(path) == 0 { 13 | return obj 14 | } 15 | elem := obj[path[0].(int)] 16 | if len(path) == 1 { 17 | return elem 18 | } 19 | return elem.(Object).Get(path[1:]...) 
20 | } 21 | 22 | type Map map[interface{}]interface{} 23 | 24 | func (obj Map) Get(path ...interface{}) interface{} { 25 | if len(path) == 0 { 26 | return obj 27 | } 28 | elem := obj[path[0]] 29 | if len(path) == 1 { 30 | return elem 31 | } 32 | return elem.(Object).Get(path[1:]...) 33 | } 34 | 35 | type Struct map[protocol.FieldId]interface{} 36 | 37 | func (obj Struct) Get(path ...interface{}) interface{} { 38 | if len(path) == 0 { 39 | return obj 40 | } 41 | elem := obj[path[0].(protocol.FieldId)] 42 | if len(path) == 1 { 43 | return elem 44 | } 45 | return elem.(Object).Get(path[1:]...) 46 | } 47 | 48 | type Message struct { 49 | protocol.MessageHeader 50 | Arguments Struct 51 | } 52 | -------------------------------------------------------------------------------- /protocol/binary/discard.go: -------------------------------------------------------------------------------- 1 | package binary 2 | 3 | import ( 4 | "github.com/thrift-iterator/go/protocol" 5 | "github.com/thrift-iterator/go/spi" 6 | ) 7 | 8 | func (iter *Iterator) Discard(ttype protocol.TType) { 9 | switch ttype { 10 | case protocol.TypeBool, protocol.TypeI08: 11 | iter.readByte() 12 | case protocol.TypeI16: 13 | iter.readSmall(2) 14 | case protocol.TypeI32: 15 | iter.readSmall(4) 16 | case protocol.TypeI64, protocol.TypeDouble: 17 | iter.readSmall(8) 18 | case protocol.TypeString: 19 | iter.SkipBinary(nil) 20 | case protocol.TypeList: 21 | spi.DiscardList(iter) 22 | case protocol.TypeStruct: 23 | spi.DiscardStruct(iter) 24 | case protocol.TypeMap: 25 | spi.DiscardMap(iter) 26 | default: 27 | panic("unsupported type") 28 | } 29 | } 30 | -------------------------------------------------------------------------------- /protocol/binary/iterator.go: -------------------------------------------------------------------------------- 1 | package binary 2 | 3 | import ( 4 | "fmt" 5 | "github.com/thrift-iterator/go/protocol" 6 | "github.com/thrift-iterator/go/spi" 7 | "io" 8 | "math" 9 | ) 10 | 11 | type 
Iterator struct { 12 | spi.ValDecoderProvider 13 | reader io.Reader 14 | tmp []byte 15 | preread []byte 16 | skipped []byte 17 | err error 18 | } 19 | 20 | func NewIterator(provider spi.ValDecoderProvider, reader io.Reader, buf []byte) *Iterator { 21 | return &Iterator{ 22 | ValDecoderProvider: provider, 23 | reader: reader, 24 | tmp: make([]byte, 8), 25 | preread: buf, 26 | } 27 | } 28 | 29 | func (iter *Iterator) readByte() byte { 30 | tmp := iter.tmp[:1] 31 | if len(iter.preread) > 0 { 32 | tmp[0] = iter.preread[0] 33 | iter.preread = iter.preread[1:] 34 | } else { 35 | _, err := iter.reader.Read(tmp) 36 | if err != nil { 37 | iter.ReportError("read", err.Error()) 38 | return 0 39 | } 40 | } 41 | if iter.skipped != nil { 42 | iter.skipped = append(iter.skipped, tmp[0]) 43 | } 44 | return tmp[0] 45 | } 46 | 47 | func (iter *Iterator) readSmall(nBytes int) []byte { 48 | tmp := iter.tmp[:nBytes] 49 | wantBytes := nBytes 50 | if len(iter.preread) > 0 { 51 | if len(iter.preread) > nBytes { 52 | copy(tmp, iter.preread[:nBytes]) 53 | iter.preread = iter.preread[nBytes:] 54 | wantBytes = 0 55 | } else { 56 | prelength := len(iter.preread) 57 | copy(tmp[:prelength], iter.preread) 58 | wantBytes -= prelength 59 | iter.preread = nil 60 | } 61 | } 62 | if wantBytes > 0 { 63 | _, err := io.ReadFull(iter.reader, tmp[nBytes-wantBytes:nBytes]) 64 | if err != nil { 65 | for i := 0; i < len(tmp); i++ { 66 | tmp[i] = 0 67 | } 68 | iter.ReportError("read", err.Error()) 69 | return tmp 70 | } 71 | } 72 | if iter.skipped != nil { 73 | iter.skipped = append(iter.skipped, tmp...) 
74 | } 75 | return tmp 76 | } 77 | 78 | func (iter *Iterator) readLarge(nBytes int) []byte { 79 | // allocate new buffer if not enough 80 | if len(iter.tmp) < nBytes { 81 | iter.tmp = make([]byte, nBytes) 82 | } 83 | return iter.readSmall(nBytes) 84 | } 85 | 86 | func (iter *Iterator) Spawn() spi.Iterator { 87 | return NewIterator(iter.ValDecoderProvider, nil, nil) 88 | } 89 | 90 | func (iter *Iterator) Error() error { 91 | return iter.err 92 | } 93 | 94 | func (iter *Iterator) ReportError(operation string, err string) { 95 | if iter.err == nil { 96 | iter.err = fmt.Errorf("%s: %s", operation, err) 97 | } 98 | } 99 | 100 | func (iter *Iterator) Reset(reader io.Reader, buf []byte) { 101 | iter.reader = reader 102 | iter.preread = buf 103 | iter.err = nil 104 | } 105 | 106 | func (iter *Iterator) ReadMessageHeader() protocol.MessageHeader { 107 | versionAndMessageType := iter.ReadInt32() 108 | messageType := protocol.TMessageType(versionAndMessageType & 0x0ff) 109 | version := int64(int64(versionAndMessageType) & 0xffff0000) 110 | if version != protocol.BINARY_VERSION_1 { 111 | iter.ReportError("ReadMessageHeader", "unexpected version") 112 | return protocol.MessageHeader{} 113 | } 114 | messageName := iter.ReadString() 115 | seqId := protocol.SeqId(iter.ReadInt32()) 116 | return protocol.MessageHeader{ 117 | MessageName: messageName, 118 | MessageType: messageType, 119 | SeqId: seqId, 120 | } 121 | } 122 | 123 | func (iter *Iterator) ReadStructHeader() { 124 | // noop 125 | } 126 | 127 | func (iter *Iterator) ReadStructField() (fieldType protocol.TType, fieldId protocol.FieldId) { 128 | firstByte := iter.readByte() 129 | fieldType = protocol.TType(firstByte) 130 | if fieldType == protocol.TypeStop { 131 | return protocol.TypeStop, 0 132 | } 133 | fieldId = protocol.FieldId(iter.ReadUint16()) 134 | return fieldType, fieldId 135 | } 136 | 137 | func (iter *Iterator) ReadListHeader() (elemType protocol.TType, size int) { 138 | b := iter.readSmall(5) 139 | elemType = 
protocol.TType(b[0]) 140 | size = int(uint32(b[4]) | uint32(b[3])<<8 | uint32(b[2])<<16 | uint32(b[1])<<24) 141 | return elemType, size 142 | } 143 | 144 | func (iter *Iterator) ReadMapHeader() (keyType protocol.TType, elemType protocol.TType, size int) { 145 | b := iter.readSmall(6) 146 | keyType = protocol.TType(b[0]) 147 | elemType = protocol.TType(b[1]) 148 | size = int(uint32(b[5]) | uint32(b[4])<<8 | uint32(b[3])<<16 | uint32(b[2])<<24) 149 | return keyType, elemType, size 150 | } 151 | 152 | func (iter *Iterator) ReadBool() bool { 153 | return iter.ReadUint8() == 1 154 | } 155 | 156 | func (iter *Iterator) ReadInt() int { 157 | return int(iter.ReadInt64()) 158 | } 159 | 160 | func (iter *Iterator) ReadUint() uint { 161 | return uint(iter.ReadUint64()) 162 | } 163 | 164 | func (iter *Iterator) ReadInt8() int8 { 165 | return int8(iter.ReadUint8()) 166 | } 167 | 168 | func (iter *Iterator) ReadUint8() uint8 { 169 | return iter.readByte() 170 | } 171 | 172 | func (iter *Iterator) ReadInt16() int16 { 173 | return int16(iter.ReadUint16()) 174 | } 175 | 176 | func (iter *Iterator) ReadUint16() uint16 { 177 | b := iter.readSmall(2) 178 | return uint16(b[1]) | uint16(b[0])<<8 179 | } 180 | 181 | func (iter *Iterator) ReadInt32() int32 { 182 | return int32(iter.ReadUint32()) 183 | } 184 | 185 | func (iter *Iterator) ReadUint32() uint32 { 186 | b := iter.readSmall(4) 187 | return uint32(b[3]) | uint32(b[2])<<8 | uint32(b[1])<<16 | uint32(b[0])<<24 188 | } 189 | 190 | func (iter *Iterator) ReadInt64() int64 { 191 | return int64(iter.ReadUint64()) 192 | } 193 | 194 | func (iter *Iterator) ReadUint64() uint64 { 195 | b := iter.readSmall(8) 196 | return uint64(b[7]) | uint64(b[6])<<8 | uint64(b[5])<<16 | uint64(b[4])<<24 | 197 | uint64(b[3])<<32 | uint64(b[2])<<40 | uint64(b[1])<<48 | uint64(b[0])<<56 198 | } 199 | 200 | func (iter *Iterator) ReadFloat64() float64 { 201 | return math.Float64frombits(iter.ReadUint64()) 202 | } 203 | 204 | func (iter *Iterator) ReadString() 
string { 205 | length := iter.ReadUint32() 206 | return string(iter.readLarge(int(length))) 207 | } 208 | 209 | func (iter *Iterator) ReadBinary() []byte { 210 | length := iter.ReadUint32() 211 | tmp := make([]byte, length) 212 | copy(tmp, iter.readLarge(int(length))) 213 | return tmp 214 | } 215 | -------------------------------------------------------------------------------- /protocol/binary/skip.go: -------------------------------------------------------------------------------- 1 | package binary 2 | 3 | import "github.com/thrift-iterator/go/protocol" 4 | 5 | func (iter *Iterator) skip(skipper func(), space []byte) []byte { 6 | var tmp []byte 7 | iter.skipped = make([]byte, 0, 8) 8 | skipper() 9 | tmp, iter.skipped = iter.skipped, nil 10 | if iter.Error() != nil { 11 | return nil 12 | } 13 | if len(space) > 0 { 14 | return append(space, tmp...) 15 | } 16 | return tmp 17 | } 18 | 19 | func (iter *Iterator) Skip(ttype protocol.TType, space []byte) []byte { 20 | return iter.skip(func() { iter.Discard(ttype) }, space) 21 | } 22 | 23 | func (iter *Iterator) SkipMessageHeader(space []byte) []byte { 24 | return iter.skip(func() { iter.ReadMessageHeader() }, space) 25 | } 26 | 27 | func (iter *Iterator) SkipStruct(space []byte) []byte { 28 | return iter.skip(func() { iter.Discard(protocol.TypeStruct) }, space) 29 | } 30 | 31 | func (iter *Iterator) SkipList(space []byte) []byte { 32 | return iter.skip(func() { iter.Discard(protocol.TypeList) }, space) 33 | } 34 | 35 | func (iter *Iterator) SkipMap(space []byte) []byte { 36 | return iter.skip(func() { iter.Discard(protocol.TypeMap) }, space) 37 | } 38 | 39 | func (iter *Iterator) SkipBinary(space []byte) []byte { 40 | tmp := iter.ReadBinary() 41 | if iter.Error() != nil { 42 | return nil 43 | } 44 | if len(space) > 0 { 45 | return append(space, tmp...) 
46 | } 47 | return tmp 48 | } 49 | -------------------------------------------------------------------------------- /protocol/binary/stream.go: -------------------------------------------------------------------------------- 1 | package binary 2 | 3 | import ( 4 | "fmt" 5 | "github.com/thrift-iterator/go/protocol" 6 | "github.com/thrift-iterator/go/spi" 7 | "io" 8 | "math" 9 | ) 10 | 11 | type Stream struct { 12 | spi.ValEncoderProvider 13 | writer io.Writer 14 | buf []byte 15 | err error 16 | } 17 | 18 | func NewStream(provider spi.ValEncoderProvider, writer io.Writer, buf []byte) *Stream { 19 | return &Stream{ 20 | ValEncoderProvider: provider, 21 | writer: writer, 22 | buf: buf, 23 | } 24 | } 25 | 26 | func (stream *Stream) Spawn() spi.Stream { 27 | return &Stream{ 28 | ValEncoderProvider: stream.ValEncoderProvider, 29 | } 30 | } 31 | 32 | func (stream *Stream) Error() error { 33 | return stream.err 34 | } 35 | 36 | func (stream *Stream) ReportError(operation string, err string) { 37 | if stream.err == nil { 38 | stream.err = fmt.Errorf("%s: %s", operation, err) 39 | } 40 | } 41 | 42 | func (stream *Stream) Buffer() []byte { 43 | return stream.buf 44 | } 45 | 46 | func (stream *Stream) Reset(writer io.Writer) { 47 | stream.writer = writer 48 | stream.err = nil 49 | stream.buf = stream.buf[:0] 50 | } 51 | 52 | func (stream *Stream) Flush() { 53 | if stream.writer == nil { 54 | return 55 | } 56 | _, err := stream.writer.Write(stream.buf) 57 | if err != nil { 58 | stream.ReportError("Flush", err.Error()) 59 | return 60 | } 61 | if f, ok := stream.writer.(protocol.Flusher); ok { 62 | if err = f.Flush(); err != nil { 63 | stream.ReportError("Flush", err.Error()) 64 | } 65 | } 66 | stream.buf = stream.buf[:0] 67 | } 68 | 69 | func (stream *Stream) Write(buf []byte) error { 70 | stream.buf = append(stream.buf, buf...) 
71 | stream.Flush() 72 | return stream.Error() 73 | } 74 | 75 | func (stream *Stream) WriteMessageHeader(header protocol.MessageHeader) { 76 | versionAndMessageType := uint32(protocol.BINARY_VERSION_1) | uint32(header.MessageType) 77 | stream.WriteUint32(versionAndMessageType) 78 | stream.WriteString(header.MessageName) 79 | stream.WriteInt32(int32(header.SeqId)) 80 | } 81 | 82 | func (stream *Stream) WriteListHeader(elemType protocol.TType, length int) { 83 | stream.buf = append(stream.buf, byte(elemType), 84 | byte(length>>24), byte(length>>16), byte(length>>8), byte(length)) 85 | } 86 | 87 | func (stream *Stream) WriteStructHeader() { 88 | } 89 | 90 | func (stream *Stream) WriteStructField(fieldType protocol.TType, fieldId protocol.FieldId) { 91 | stream.buf = append(stream.buf, byte(fieldType), byte(fieldId>>8), byte(fieldId)) 92 | } 93 | 94 | func (stream *Stream) WriteStructFieldStop() { 95 | stream.buf = append(stream.buf, byte(protocol.TypeStop)) 96 | } 97 | 98 | func (stream *Stream) WriteMapHeader(keyType protocol.TType, elemType protocol.TType, length int) { 99 | stream.buf = append(stream.buf, byte(keyType), byte(elemType), 100 | byte(length>>24), byte(length>>16), byte(length>>8), byte(length)) 101 | } 102 | 103 | func (stream *Stream) WriteBool(val bool) { 104 | if val { 105 | stream.WriteUint8(1) 106 | } else { 107 | stream.WriteUint8(0) 108 | } 109 | } 110 | 111 | func (stream *Stream) WriteInt8(val int8) { 112 | stream.WriteUint8(uint8(val)) 113 | } 114 | 115 | func (stream *Stream) WriteUint8(val uint8) { 116 | stream.buf = append(stream.buf, byte(val)) 117 | } 118 | 119 | func (stream *Stream) WriteInt16(val int16) { 120 | stream.WriteUint16(uint16(val)) 121 | } 122 | 123 | func (stream *Stream) WriteUint16(val uint16) { 124 | stream.buf = append(stream.buf, byte(val>>8), byte(val)) 125 | } 126 | 127 | func (stream *Stream) WriteInt32(val int32) { 128 | stream.WriteUint32(uint32(val)) 129 | } 130 | 131 | func (stream *Stream) WriteUint32(val 
uint32) { 132 | stream.buf = append(stream.buf, byte(val>>24), byte(val>>16), byte(val>>8), byte(val)) 133 | } 134 | 135 | func (stream *Stream) WriteInt64(val int64) { 136 | stream.WriteUint64(uint64(val)) 137 | } 138 | 139 | func (stream *Stream) WriteUint64(val uint64) { 140 | stream.buf = append(stream.buf, 141 | byte(val>>56), byte(val>>48), byte(val>>40), byte(val>>32), 142 | byte(val>>24), byte(val>>16), byte(val>>8), byte(val)) 143 | } 144 | 145 | func (stream *Stream) WriteInt(val int) { 146 | stream.WriteInt64(int64(val)) 147 | } 148 | 149 | func (stream *Stream) WriteUint(val uint) { 150 | stream.WriteUint64(uint64(val)) 151 | } 152 | 153 | func (stream *Stream) WriteFloat64(val float64) { 154 | stream.WriteUint64(math.Float64bits(val)) 155 | } 156 | 157 | func (stream *Stream) WriteBinary(val []byte) { 158 | stream.WriteUint32(uint32(len(val))) 159 | stream.buf = append(stream.buf, val...) 160 | } 161 | 162 | func (stream *Stream) WriteString(val string) { 163 | stream.WriteUint32(uint32(len(val))) 164 | stream.buf = append(stream.buf, val...) 
165 | } 166 | -------------------------------------------------------------------------------- /protocol/compact/discard.go: -------------------------------------------------------------------------------- 1 | package compact 2 | 3 | import ( 4 | "github.com/thrift-iterator/go/protocol" 5 | "github.com/thrift-iterator/go/spi" 6 | ) 7 | 8 | func (iter *Iterator) Discard(ttype protocol.TType) { 9 | switch ttype { 10 | case protocol.TypeBool, protocol.TypeI08: 11 | iter.ReadInt8() 12 | case protocol.TypeI16: 13 | iter.ReadInt16() 14 | case protocol.TypeI32: 15 | iter.ReadInt32() 16 | case protocol.TypeI64: 17 | iter.ReadInt64() 18 | case protocol.TypeDouble: 19 | iter.ReadFloat64() 20 | case protocol.TypeString: 21 | iter.SkipBinary(nil) 22 | case protocol.TypeList: 23 | spi.DiscardList(iter) 24 | case protocol.TypeStruct: 25 | spi.DiscardStruct(iter) 26 | case protocol.TypeMap: 27 | spi.DiscardMap(iter) 28 | default: 29 | panic("unsupported type") 30 | } 31 | } 32 | -------------------------------------------------------------------------------- /protocol/compact/skip.go: -------------------------------------------------------------------------------- 1 | package compact 2 | 3 | import "github.com/thrift-iterator/go/protocol" 4 | 5 | func (iter *Iterator) skip(skipper func(), space []byte) []byte { 6 | var tmp []byte 7 | iter.skipped = make([]byte, 0, 8) 8 | skipper() 9 | tmp, iter.skipped = iter.skipped, nil 10 | if iter.Error() != nil { 11 | return nil 12 | } 13 | if len(space) > 0 { 14 | return append(space, tmp...) 
15 | } 16 | return tmp 17 | } 18 | 19 | func (iter *Iterator) Skip(ttype protocol.TType, space []byte) []byte { 20 | return iter.skip(func() { iter.Discard(ttype) }, space) 21 | } 22 | 23 | func (iter *Iterator) SkipMessageHeader(space []byte) []byte { 24 | return iter.skip(func() { iter.ReadMessageHeader() }, space) 25 | } 26 | 27 | func (iter *Iterator) SkipStruct(space []byte) []byte { 28 | return iter.skip(func() { iter.Discard(protocol.TypeStruct) }, space) 29 | } 30 | 31 | func (iter *Iterator) SkipList(space []byte) []byte { 32 | return iter.skip(func() { iter.Discard(protocol.TypeList) }, space) 33 | } 34 | 35 | func (iter *Iterator) SkipMap(space []byte) []byte { 36 | return iter.skip(func() { iter.Discard(protocol.TypeMap) }, space) 37 | } 38 | 39 | func (iter *Iterator) SkipBinary(space []byte) []byte { 40 | tmp := iter.ReadBinary() 41 | if iter.Error() != nil { 42 | return nil 43 | } 44 | if len(space) > 0 { 45 | return append(space, tmp...) 46 | } 47 | return tmp 48 | } 49 | -------------------------------------------------------------------------------- /protocol/compact/stream.go: -------------------------------------------------------------------------------- 1 | package compact 2 | 3 | import ( 4 | "fmt" 5 | "github.com/thrift-iterator/go/protocol" 6 | "github.com/thrift-iterator/go/spi" 7 | "io" 8 | "math" 9 | ) 10 | 11 | type Stream struct { 12 | spi.ValEncoderProvider 13 | writer io.Writer 14 | buf []byte 15 | err error 16 | fieldIdStack []protocol.FieldId 17 | lastFieldId protocol.FieldId 18 | pendingBoolField protocol.FieldId 19 | } 20 | 21 | func NewStream(provider spi.ValEncoderProvider, writer io.Writer, buf []byte) *Stream { 22 | return &Stream{ 23 | ValEncoderProvider: provider, 24 | writer: writer, 25 | buf: buf, 26 | pendingBoolField: -1, 27 | } 28 | } 29 | 30 | func (stream *Stream) Spawn() spi.Stream { 31 | return &Stream{ 32 | ValEncoderProvider: stream.ValEncoderProvider, 33 | } 34 | } 35 | 36 | func (stream *Stream) Error() error { 
37 | return stream.err 38 | } 39 | 40 | func (stream *Stream) ReportError(operation string, err string) { 41 | if stream.err == nil { 42 | stream.err = fmt.Errorf("%s: %s", operation, err) 43 | } 44 | } 45 | 46 | func (stream *Stream) Buffer() []byte { 47 | return stream.buf 48 | } 49 | 50 | func (stream *Stream) Reset(writer io.Writer) { 51 | stream.writer = writer 52 | stream.err = nil 53 | stream.buf = stream.buf[:0] 54 | } 55 | 56 | func (stream *Stream) Flush() { 57 | if stream.writer == nil { 58 | return 59 | } 60 | _, err := stream.writer.Write(stream.buf) 61 | if err != nil { 62 | stream.ReportError("Flush", err.Error()) 63 | return 64 | } 65 | if f, ok := stream.writer.(protocol.Flusher); ok { 66 | if err = f.Flush(); err != nil { 67 | stream.ReportError("Flush", err.Error()) 68 | } 69 | } 70 | stream.buf = stream.buf[:0] 71 | } 72 | 73 | func (stream *Stream) Write(buf []byte) error { 74 | stream.buf = append(stream.buf, buf...) 75 | stream.Flush() 76 | return stream.Error() 77 | } 78 | 79 | func (stream *Stream) WriteMessageHeader(header protocol.MessageHeader) { 80 | stream.buf = append(stream.buf, protocol.COMPACT_PROTOCOL_ID) 81 | stream.buf = append(stream.buf, (protocol.COMPACT_VERSION&protocol.COMPACT_VERSION_MASK)|((byte(header.MessageType)<<5)&0x0E0)) 82 | stream.writeVarInt32(int32(header.SeqId)) 83 | stream.WriteString(header.MessageName) 84 | } 85 | 86 | func (stream *Stream) WriteListHeader(elemType protocol.TType, length int) { 87 | if length <= 14 { 88 | stream.WriteUint8(uint8(int32(length<<4) | int32(compactTypes[elemType]))) 89 | return 90 | } 91 | stream.WriteUint8(0xf0 | uint8(compactTypes[elemType])) 92 | stream.writeVarInt32(int32(length)) 93 | } 94 | 95 | func (stream *Stream) WriteStructHeader() { 96 | stream.fieldIdStack = append(stream.fieldIdStack, stream.lastFieldId) 97 | stream.lastFieldId = 0 98 | } 99 | 100 | func (stream *Stream) WriteStructField(fieldType protocol.TType, fieldId protocol.FieldId) { 101 | if fieldType == 
protocol.TypeBool { 102 | stream.pendingBoolField = fieldId 103 | return 104 | } 105 | compactType := uint8(compactTypes[fieldType]) 106 | // check if we can use delta encoding for the field id 107 | if fieldId > stream.lastFieldId && fieldId-stream.lastFieldId <= 15 { 108 | stream.WriteUint8(uint8((fieldId-stream.lastFieldId)<<4) | compactType) 109 | } else { 110 | stream.WriteUint8(compactType) 111 | stream.WriteInt16(int16(fieldId)) 112 | } 113 | stream.lastFieldId = fieldId 114 | } 115 | 116 | func (stream *Stream) WriteStructFieldStop() { 117 | stream.buf = append(stream.buf, byte(TypeStop)) 118 | stream.lastFieldId = stream.fieldIdStack[len(stream.fieldIdStack)-1] 119 | stream.fieldIdStack = stream.fieldIdStack[:len(stream.fieldIdStack)-1] 120 | stream.pendingBoolField = -1 121 | } 122 | 123 | func (stream *Stream) WriteMapHeader(keyType protocol.TType, elemType protocol.TType, length int) { 124 | if length == 0 { 125 | stream.WriteUint8(0) 126 | return 127 | } 128 | stream.writeVarInt32(int32(length)) 129 | stream.WriteUint8(uint8(compactTypes[keyType]<<4 | TCompactType(compactTypes[elemType]))) 130 | } 131 | 132 | func (stream *Stream) WriteBool(val bool) { 133 | if stream.pendingBoolField == -1 { 134 | if val { 135 | stream.WriteUint8(1) 136 | } else { 137 | stream.WriteUint8(0) 138 | } 139 | return 140 | } 141 | var compactType TCompactType 142 | if val { 143 | compactType = TypeBooleanTrue 144 | } else { 145 | compactType = TypeBooleanFalse 146 | } 147 | fieldId := stream.pendingBoolField 148 | // check if we can use delta encoding for the field id 149 | if fieldId > stream.lastFieldId && fieldId-stream.lastFieldId <= 15 { 150 | stream.WriteUint8(uint8((fieldId-stream.lastFieldId)<<4) | uint8(compactType)) 151 | } else { 152 | stream.WriteUint8(uint8(compactType)) 153 | stream.WriteInt16(int16(fieldId)) 154 | } 155 | stream.lastFieldId = fieldId 156 | stream.pendingBoolField = -1 157 | } 158 | 159 | func (stream *Stream) WriteInt8(val int8) { 160 | 
stream.WriteUint8(uint8(val)) 161 | } 162 | 163 | func (stream *Stream) WriteUint8(val uint8) { 164 | stream.buf = append(stream.buf, byte(val)) 165 | } 166 | 167 | func (stream *Stream) WriteInt16(val int16) { 168 | stream.WriteInt32(int32(val)) 169 | } 170 | 171 | func (stream *Stream) WriteUint16(val uint16) { 172 | stream.WriteInt32(int32(val)) 173 | } 174 | 175 | func (stream *Stream) WriteInt32(val int32) { 176 | stream.writeVarInt32((val << 1) ^ (val >> 31)) 177 | } 178 | 179 | func (stream *Stream) WriteUint32(val uint32) { 180 | stream.WriteInt32(int32(val)) 181 | } 182 | 183 | // Write an i32 as a varint. Results in 1-5 bytes on the wire. 184 | func (stream *Stream) writeVarInt32(n int32) { 185 | for { 186 | if (n & ^0x7F) == 0 { 187 | stream.buf = append(stream.buf, byte(n)) 188 | break 189 | } else { 190 | stream.buf = append(stream.buf, byte((n&0x7F)|0x80)) 191 | u := uint64(n) 192 | n = int32(u >> 7) 193 | } 194 | } 195 | } 196 | 197 | func (stream *Stream) WriteInt64(val int64) { 198 | stream.writeVarInt64((val << 1) ^ (val >> 63)) 199 | } 200 | 201 | // Write an i64 as a varint. Results in 1-10 bytes on the wire. 
202 | func (stream *Stream) writeVarInt64(n int64) { 203 | for { 204 | if (n & ^0x7F) == 0 { 205 | stream.buf = append(stream.buf, byte(n)) 206 | break 207 | } else { 208 | stream.buf = append(stream.buf, byte((n&0x7F)|0x80)) 209 | u := uint64(n) 210 | n = int64(u >> 7) 211 | } 212 | } 213 | } 214 | 215 | func (stream *Stream) WriteUint64(val uint64) { 216 | stream.WriteInt64(int64(val)) 217 | } 218 | 219 | func (stream *Stream) WriteInt(val int) { 220 | stream.WriteInt64(int64(val)) 221 | } 222 | 223 | func (stream *Stream) WriteUint(val uint) { 224 | stream.WriteUint64(uint64(val)) 225 | } 226 | 227 | func (stream *Stream) WriteFloat64(val float64) { 228 | bits := math.Float64bits(val) 229 | stream.buf = append(stream.buf, 230 | byte(bits), 231 | byte(bits>>8), 232 | byte(bits>>16), 233 | byte(bits>>24), 234 | byte(bits>>32), 235 | byte(bits>>40), 236 | byte(bits>>48), 237 | byte(bits>>56), 238 | ) 239 | } 240 | 241 | func (stream *Stream) WriteBinary(val []byte) { 242 | stream.writeVarInt32(int32(len(val))) 243 | stream.buf = append(stream.buf, val...) 244 | } 245 | 246 | func (stream *Stream) WriteString(val string) { 247 | stream.writeVarInt32(int32(len(val))) 248 | stream.buf = append(stream.buf, val...) 
249 | } 250 | -------------------------------------------------------------------------------- /protocol/compact/type.go: -------------------------------------------------------------------------------- 1 | package compact 2 | 3 | import ( 4 | "github.com/thrift-iterator/go/protocol" 5 | ) 6 | 7 | type TCompactType byte 8 | 9 | const ( 10 | TypeStop TCompactType = 0x00 11 | TypeBooleanTrue TCompactType = 0x01 12 | TypeBooleanFalse TCompactType = 0x02 13 | TypeByte TCompactType = 0x03 14 | TypeI16 TCompactType = 0x04 15 | TypeI32 TCompactType = 0x05 16 | TypeI64 TCompactType = 0x06 17 | TypeDouble TCompactType = 0x07 18 | TypeBinary TCompactType = 0x08 19 | TypeList TCompactType = 0x09 20 | TypeSet TCompactType = 0x0A 21 | TypeMap TCompactType = 0x0B 22 | TypeStruct TCompactType = 0x0C 23 | ) 24 | 25 | var compactTypes = map[protocol.TType]TCompactType{ 26 | protocol.TypeStop: TypeStop, 27 | protocol.TypeBool: TypeBooleanTrue, 28 | protocol.TypeByte: TypeByte, 29 | protocol.TypeI16: TypeI16, 30 | protocol.TypeI32: TypeI32, 31 | protocol.TypeI64: TypeI64, 32 | protocol.TypeDouble: TypeDouble, 33 | protocol.TypeString: TypeBinary, 34 | protocol.TypeList: TypeList, 35 | protocol.TypeSet: TypeSet, 36 | protocol.TypeMap: TypeMap, 37 | protocol.TypeStruct: TypeStruct, 38 | } 39 | 40 | // TType value. 
41 | func (t TCompactType) ToTType() protocol.TType { 42 | switch TCompactType(byte(t) & 0x0f) { 43 | case TypeBooleanFalse, TypeBooleanTrue: 44 | return protocol.TypeBool 45 | case TypeByte: 46 | return protocol.TypeByte 47 | case TypeI16: 48 | return protocol.TypeI16 49 | case TypeI32: 50 | return protocol.TypeI32 51 | case TypeI64: 52 | return protocol.TypeI64 53 | case TypeDouble: 54 | return protocol.TypeDouble 55 | case TypeBinary: 56 | return protocol.TypeString 57 | case TypeList: 58 | return protocol.TypeList 59 | case TypeSet: 60 | return protocol.TypeSet 61 | case TypeMap: 62 | return protocol.TypeMap 63 | case TypeStruct: 64 | return protocol.TypeStruct 65 | } 66 | return protocol.TypeStop 67 | } 68 | -------------------------------------------------------------------------------- /protocol/protocol.go: -------------------------------------------------------------------------------- 1 | package protocol 2 | 3 | // Type constants in the Thrift protocol 4 | type TType byte 5 | type TMessageType int32 6 | type SeqId int32 7 | type FieldId int16 8 | 9 | const ( 10 | BINARY_VERSION_MASK = 0xffff0000 11 | BINARY_VERSION_1 = 0x80010000 12 | 13 | COMPACT_PROTOCOL_ID = 0x082 14 | COMPACT_VERSION = 0x01 15 | COMPACT_VERSON_BE = 0x02 16 | COMPACT_VERSION_MASK = 0x1f 17 | COMPACT_TYPE_BITS = 0x07 18 | COMPACT_TYPE_SHIFT_AMOUT = 5 19 | ) 20 | 21 | const ( 22 | MessgeTypeInvalid TMessageType = 0 23 | MessageTypeCall TMessageType = 1 24 | MessageTypeReply TMessageType = 2 25 | MessageTypeException TMessageType = 3 26 | MessageTypeOneWay TMessageType = 4 27 | ) 28 | 29 | const ( 30 | TypeStop TType = 0 31 | TypeVoid TType = 1 32 | TypeBool TType = 2 33 | TypeByte TType = 3 34 | TypeI08 TType = 3 35 | TypeDouble TType = 4 36 | TypeI16 TType = 6 37 | TypeI32 TType = 8 38 | TypeI64 TType = 10 39 | TypeString TType = 11 40 | TypeUTF7 TType = 11 41 | TypeStruct TType = 12 42 | TypeMap TType = 13 43 | TypeSet TType = 14 44 | TypeList TType = 15 45 | TypeUTF8 TType = 16 46 | 
TypeUTF16 TType = 17 47 | ) 48 | 49 | var typeNames = map[TType]string{ 50 | TypeStop: "Stop", 51 | TypeVoid: "Void", 52 | TypeBool: "Bool", 53 | TypeByte: "Byte", 54 | TypeDouble: "Double", 55 | TypeI16: "I16", 56 | TypeI32: "I32", 57 | TypeI64: "I64", 58 | TypeString: "String", 59 | TypeStruct: "Struct", 60 | TypeMap: "Map", 61 | TypeSet: "Set", 62 | TypeList: "List", 63 | TypeUTF8: "UTF8", 64 | TypeUTF16: "UTF16", 65 | } 66 | 67 | func (p TType) String() string { 68 | if s, ok := typeNames[p]; ok { 69 | return s 70 | } 71 | return "Unknown" 72 | } 73 | 74 | type MessageHeader struct { 75 | MessageName string 76 | MessageType TMessageType 77 | SeqId SeqId 78 | } 79 | 80 | type Flusher interface { 81 | Flush() error 82 | } 83 | -------------------------------------------------------------------------------- /raw/decode_list.go: -------------------------------------------------------------------------------- 1 | package raw 2 | 3 | import ( 4 | "github.com/thrift-iterator/go/spi" 5 | ) 6 | 7 | type rawListDecoder struct { 8 | } 9 | 10 | func (decoder *rawListDecoder) Decode(val interface{}, iter spi.Iterator) { 11 | elemType, length := iter.ReadListHeader() 12 | elements := make([][]byte, length) 13 | for i := 0; i < length; i++ { 14 | elements[i] = iter.Skip(elemType, nil) 15 | } 16 | obj := val.(*List) 17 | obj.ElementType = elemType 18 | obj.Elements = elements 19 | } -------------------------------------------------------------------------------- /raw/decode_map.go: -------------------------------------------------------------------------------- 1 | package raw 2 | 3 | import ( 4 | "github.com/thrift-iterator/go/spi" 5 | "github.com/thrift-iterator/go/protocol" 6 | ) 7 | 8 | type rawMapDecoder struct { 9 | } 10 | 11 | func (decoder *rawMapDecoder) Decode(val interface{}, iter spi.Iterator) { 12 | keyType, elemType, length := iter.ReadMapHeader() 13 | entries := make(map[interface{}]MapEntry, length) 14 | generalKeyReader := readerOf(keyType) 15 | keyIter := 
iter.Spawn() 16 | for i := 0; i < length; i++ { 17 | keyBuf := iter.Skip(keyType, nil) 18 | key := generalKeyReader(keyBuf, keyIter) 19 | elemBuf := iter.Skip(elemType, nil) 20 | entries[key] = MapEntry{ 21 | Key: keyBuf, 22 | Element: elemBuf, 23 | } 24 | } 25 | obj := val.(*Map) 26 | obj.KeyType = keyType 27 | obj.ElementType = elemType 28 | obj.Entries = entries 29 | } 30 | 31 | func readerOf(valType protocol.TType) func([]byte, spi.Iterator) interface{} { 32 | switch valType { 33 | case protocol.TypeBool: 34 | return readBool 35 | case protocol.TypeI08: 36 | return readInt8 37 | case protocol.TypeI16: 38 | return readInt16 39 | case protocol.TypeI32: 40 | return readInt32 41 | case protocol.TypeI64: 42 | return readInt64 43 | case protocol.TypeDouble: 44 | return readFloat64 45 | case protocol.TypeString: 46 | return readString 47 | default: 48 | panic("unsupported type") 49 | } 50 | } 51 | 52 | func readBool(buf []byte, iter spi.Iterator) interface{} { 53 | iter.Reset(nil, buf) 54 | return iter.ReadBool() 55 | } 56 | 57 | func readInt8(buf []byte, iter spi.Iterator) interface{} { 58 | iter.Reset(nil, buf) 59 | return iter.ReadInt8() 60 | } 61 | 62 | func readInt16(buf []byte, iter spi.Iterator) interface{} { 63 | iter.Reset(nil, buf) 64 | return iter.ReadInt16() 65 | } 66 | 67 | func readInt32(buf []byte, iter spi.Iterator) interface{} { 68 | iter.Reset(nil,buf) 69 | return iter.ReadInt32() 70 | } 71 | 72 | func readInt64(buf []byte, iter spi.Iterator) interface{} { 73 | iter.Reset(nil, buf) 74 | return iter.ReadInt64() 75 | } 76 | 77 | func readFloat64(buf []byte, iter spi.Iterator) interface{} { 78 | iter.Reset(nil,buf) 79 | return iter.ReadFloat64() 80 | } 81 | 82 | func readString(buf []byte, iter spi.Iterator) interface{} { 83 | iter.Reset(nil,buf) 84 | return iter.ReadString() 85 | } -------------------------------------------------------------------------------- /raw/decode_struct.go: 
-------------------------------------------------------------------------------- 1 | package raw 2 | 3 | import ( 4 | "github.com/thrift-iterator/go/spi" 5 | "github.com/thrift-iterator/go/protocol" 6 | ) 7 | 8 | type rawStructDecoder struct { 9 | } 10 | 11 | func (decoder *rawStructDecoder) Decode(val interface{}, iter spi.Iterator) { 12 | fields := Struct{} 13 | iter.ReadStructHeader() 14 | for { 15 | fieldType, fieldId := iter.ReadStructField() 16 | if fieldType == protocol.TypeStop { 17 | *val.(*Struct) = fields 18 | return 19 | } 20 | fields[fieldId] = StructField{ 21 | Type: fieldType, 22 | Buffer: iter.Skip(fieldType, nil), 23 | } 24 | } 25 | } -------------------------------------------------------------------------------- /raw/encode_list.go: -------------------------------------------------------------------------------- 1 | package raw 2 | 3 | import ( 4 | "github.com/thrift-iterator/go/spi" 5 | "github.com/thrift-iterator/go/protocol" 6 | ) 7 | 8 | type rawListEncoder struct { 9 | } 10 | 11 | func (encoder *rawListEncoder) Encode(val interface{}, stream spi.Stream) { 12 | obj := val.(List) 13 | stream.WriteListHeader(obj.ElementType, len(obj.Elements)) 14 | for _, elem := range obj.Elements { 15 | stream.Write(elem) 16 | } 17 | } 18 | 19 | func (encoder *rawListEncoder) ThriftType() protocol.TType { 20 | return protocol.TypeList 21 | } -------------------------------------------------------------------------------- /raw/encode_map.go: -------------------------------------------------------------------------------- 1 | package raw 2 | 3 | import ( 4 | "github.com/thrift-iterator/go/spi" 5 | "github.com/thrift-iterator/go/protocol" 6 | ) 7 | 8 | type rawMapEncoder struct { 9 | } 10 | 11 | func (encoder *rawMapEncoder) Encode(val interface{}, stream spi.Stream) { 12 | obj := val.(Map) 13 | length := len(obj.Entries) 14 | stream.WriteMapHeader(obj.KeyType, obj.ElementType, length) 15 | for _, entry := range obj.Entries { 16 | stream.Write(entry.Key) 17 | 
stream.Write(entry.Element) 18 | } 19 | } 20 | 21 | func (encoder *rawMapEncoder) ThriftType() protocol.TType { 22 | return protocol.TypeMap 23 | } -------------------------------------------------------------------------------- /raw/encode_struct.go: -------------------------------------------------------------------------------- 1 | package raw 2 | 3 | import ( 4 | "github.com/thrift-iterator/go/spi" 5 | "github.com/thrift-iterator/go/protocol" 6 | ) 7 | 8 | type rawStructEncoder struct { 9 | } 10 | 11 | func (encoder *rawStructEncoder) Encode(val interface{}, stream spi.Stream) { 12 | obj := val.(Struct) 13 | stream.WriteStructHeader() 14 | for fieldId, field := range obj { 15 | stream.WriteStructField(field.Type, fieldId) 16 | stream.Write(field.Buffer) 17 | } 18 | stream.WriteStructFieldStop() 19 | } 20 | 21 | func (encoder *rawStructEncoder) ThriftType() protocol.TType { 22 | return protocol.TypeStruct 23 | } -------------------------------------------------------------------------------- /raw/raw_extension.go: -------------------------------------------------------------------------------- 1 | package raw 2 | 3 | import ( 4 | "reflect" 5 | "github.com/thrift-iterator/go/spi" 6 | ) 7 | 8 | type Extension struct { 9 | } 10 | 11 | func (extension *Extension) DecoderOf(valType reflect.Type) spi.ValDecoder { 12 | switch valType { 13 | case reflect.TypeOf((*List)(nil)): 14 | return &rawListDecoder{} 15 | case reflect.TypeOf((*Map)(nil)): 16 | return &rawMapDecoder{} 17 | case reflect.TypeOf((*Struct)(nil)): 18 | return &rawStructDecoder{} 19 | } 20 | return nil 21 | } 22 | 23 | func (extension *Extension) EncoderOf(valType reflect.Type) spi.ValEncoder { 24 | switch valType { 25 | case reflect.TypeOf((*List)(nil)).Elem(): 26 | return &rawListEncoder{} 27 | case reflect.TypeOf((*Map)(nil)).Elem(): 28 | return &rawMapEncoder{} 29 | case reflect.TypeOf((*Struct)(nil)).Elem(): 30 | return &rawStructEncoder{} 31 | } 32 | return nil 33 | } 
-------------------------------------------------------------------------------- /raw/raw_object.go: -------------------------------------------------------------------------------- 1 | package raw 2 | 3 | import "github.com/thrift-iterator/go/protocol" 4 | 5 | type StructField struct { 6 | Buffer []byte 7 | Type protocol.TType 8 | } 9 | 10 | type Struct map[protocol.FieldId]StructField 11 | 12 | type List struct { 13 | ElementType protocol.TType 14 | Elements [][]byte 15 | } 16 | 17 | type MapEntry struct { 18 | Key []byte 19 | Element []byte 20 | } 21 | 22 | type Map struct { 23 | KeyType protocol.TType 24 | ElementType protocol.TType 25 | Entries map[interface{}]MapEntry 26 | } -------------------------------------------------------------------------------- /spi/discard.go: -------------------------------------------------------------------------------- 1 | package spi 2 | 3 | func DiscardList(iter Iterator) { 4 | elemType, size := iter.ReadListHeader() 5 | for i := 0; i < size; i++ { 6 | iter.Discard(elemType) 7 | } 8 | } 9 | 10 | func DiscardStruct(iter Iterator) { 11 | iter.ReadStructHeader() 12 | for { 13 | fieldType, _ := iter.ReadStructField() 14 | if fieldType == 0 { 15 | return 16 | } 17 | iter.Discard(fieldType) 18 | } 19 | } 20 | 21 | func DiscardMap(iter Iterator) { 22 | keyType, elemType, size := iter.ReadMapHeader() 23 | for i := 0; i < size; i++ { 24 | iter.Discard(keyType) 25 | iter.Discard(elemType) 26 | } 27 | } 28 | -------------------------------------------------------------------------------- /spi/spi.go: -------------------------------------------------------------------------------- 1 | package spi 2 | 3 | import ( 4 | "io" 5 | "github.com/thrift-iterator/go/protocol" 6 | "reflect" 7 | ) 8 | 9 | type Iterator interface { 10 | ValDecoderProvider 11 | Spawn() Iterator 12 | Error() error 13 | Reset(reader io.Reader, buf []byte) 14 | ReportError(operation string, err string) 15 | ReadMessageHeader() protocol.MessageHeader 16 | 
SkipMessageHeader(space []byte) []byte 17 | ReadStructHeader() 18 | ReadStructField() (fieldType protocol.TType, fieldId protocol.FieldId) 19 | SkipStruct(space []byte) []byte 20 | ReadListHeader() (elemType protocol.TType, size int) 21 | SkipList(space []byte) []byte 22 | ReadMapHeader() (keyType protocol.TType, elemType protocol.TType, size int) 23 | SkipMap(space []byte) []byte 24 | ReadBool() bool 25 | ReadInt() int 26 | ReadUint() uint 27 | ReadInt8() int8 28 | ReadUint8() uint8 29 | ReadInt16() int16 30 | ReadUint16() uint16 31 | ReadInt32() int32 32 | ReadUint32() uint32 33 | ReadInt64() int64 34 | ReadUint64() uint64 35 | ReadFloat64() float64 36 | ReadString() string 37 | ReadBinary() []byte 38 | SkipBinary(space []byte) []byte 39 | Skip(ttype protocol.TType, space []byte) []byte 40 | Discard(ttype protocol.TType) 41 | } 42 | 43 | type Stream interface { 44 | ValEncoderProvider 45 | Spawn() Stream 46 | Error() error 47 | ReportError(operation string, err string) 48 | Reset(writer io.Writer) 49 | Flush() 50 | Buffer() []byte 51 | Write(buf []byte) error 52 | WriteMessageHeader(header protocol.MessageHeader) 53 | WriteListHeader(elemType protocol.TType, length int) 54 | WriteStructHeader() 55 | WriteStructField(fieldType protocol.TType, fieldId protocol.FieldId) 56 | WriteStructFieldStop() 57 | WriteMapHeader(keyType protocol.TType, elemType protocol.TType, length int) 58 | WriteBool(val bool) 59 | WriteInt(val int) 60 | WriteUint(val uint) 61 | WriteInt8(val int8) 62 | WriteUint8(val uint8) 63 | WriteInt16(val int16) 64 | WriteUint16(val uint16) 65 | WriteInt32(val int32) 66 | WriteUint32(val uint32) 67 | WriteInt64(val int64) 68 | WriteUint64(val uint64) 69 | WriteFloat64(val float64) 70 | WriteBinary(val []byte) 71 | WriteString(val string) 72 | } 73 | 74 | type ValEncoder interface { 75 | Encode(val interface{}, stream Stream) 76 | ThriftType() protocol.TType 77 | } 78 | 79 | type ValDecoder interface { 80 | Decode(val interface{}, iter Iterator) 81 | } 
82 | 83 | type ValDecoderProvider interface { 84 | PrepareDecoder(valType reflect.Type) 85 | GetDecoder(decoderName string) ValDecoder 86 | } 87 | 88 | type ValEncoderProvider interface { 89 | PrepareEncoder(valType reflect.Type) 90 | GetEncoder(encoderName string) ValEncoder 91 | } 92 | 93 | type Extension interface { 94 | DecoderOf(valType reflect.Type) ValDecoder 95 | EncoderOf(valType reflect.Type) ValEncoder 96 | } 97 | 98 | type DummyExtension struct { 99 | } 100 | 101 | func (extension *DummyExtension) DecoderOf(valType reflect.Type) ValDecoder { 102 | return nil 103 | } 104 | 105 | func (extension *DummyExtension) EncoderOf(valType reflect.Type) ValEncoder { 106 | return nil 107 | } 108 | 109 | type Extensions []Extension 110 | 111 | func (extensions Extensions) DecoderOf(valType reflect.Type) ValDecoder { 112 | for _, extension := range extensions { 113 | decoder := extension.DecoderOf(valType) 114 | if decoder != nil { 115 | return decoder 116 | } 117 | } 118 | return nil 119 | } 120 | 121 | func (extensions Extensions) EncoderOf(valType reflect.Type) ValEncoder { 122 | for _, extension := range extensions { 123 | encoder := extension.EncoderOf(valType) 124 | if encoder != nil { 125 | return encoder 126 | } 127 | } 128 | return nil 129 | } -------------------------------------------------------------------------------- /test/api/api_test.go: -------------------------------------------------------------------------------- 1 | package test 2 | 3 | import ( 4 | "testing" 5 | "encoding/hex" 6 | "github.com/stretchr/testify/require" 7 | "github.com/thrift-iterator/go" 8 | "fmt" 9 | "github.com/thrift-iterator/go/protocol" 10 | "bytes" 11 | "github.com/thrift-iterator/go/general" 12 | ) 13 | 14 | type combination struct { 15 | encoded string 16 | api thrifter.API 17 | } 18 | 19 | var combinations = []combination{ 20 | { 21 | encoded: 
"800100010000000568656c6c6f0000000c0b00010000000a73657373696f6e2d69640c00020c00010a000100000000000000010a000200000000000000000b00030000000f43616c6c46726f6d496e626f756e64000c00020b0001000000093132372e302e302e310a000200000000000004d2000b00030000000568656c6c6f000c00030c00010a000100000000000000020a000200000000000000000b00030000000d52657475726e496e626f756e64000b000200000005776f726c64000f00040c000000010c00020c00010a000100000000000000020a000200000000000000000b00030000000d52657475726e496e626f756e64000b000200000005776f726c64000000", 22 | api: thrifter.Config{Protocol: thrifter.ProtocolBinary}.Froze(), 23 | }, 24 | { 25 | encoded: "82210c0568656c6c6f180a73657373696f6e2d69641c1c16021600180f43616c6c46726f6d496e626f756e64001c18093132372e302e302e3116a41300180568656c6c6f001c1c16041600180d52657475726e496e626f756e64001805776f726c6400191c2c1c16041600180d52657475726e496e626f756e64001805776f726c64000000", 26 | api: thrifter.Config{Protocol: thrifter.ProtocolCompact}.Froze(), 27 | }, 28 | } 29 | 30 | func Test_unmarshal_message(t *testing.T) { 31 | should := require.New(t) 32 | for _, c := range combinations { 33 | input, err := hex.DecodeString(c.encoded) 34 | should.NoError(err) 35 | var msg general.Message 36 | err = c.api.Unmarshal(input, &msg) 37 | should.NoError(err) 38 | fmt.Println(msg.MessageType) 39 | fmt.Println(msg.MessageName) 40 | for fieldId, fieldValue := range msg.Arguments { 41 | fmt.Println("!!!", fieldId, fieldValue) 42 | } 43 | } 44 | } 45 | 46 | func Test_marshal_message(t *testing.T) { 47 | should := require.New(t) 48 | msg := general.Message{ 49 | MessageHeader: protocol.MessageHeader{ 50 | MessageType: protocol.MessageTypeCall, 51 | MessageName: "hello", 52 | SeqId: protocol.SeqId(17), 53 | }, 54 | Arguments: general.Struct{ 55 | protocol.FieldId(1): int64(1), 56 | protocol.FieldId(2): int64(2), 57 | }, 58 | } 59 | output, err := thrifter.Marshal(msg) 60 | should.Nil(err) 61 | var msgRead general.Message 62 | err = thrifter.Unmarshal(output, &msgRead) 63 | 
should.NoError(err) 64 | fmt.Println(msgRead.MessageType) 65 | fmt.Println(msgRead.MessageName) 66 | for fieldId, fieldValue := range msgRead.Arguments { 67 | fmt.Println(fieldId, fieldValue) 68 | } 69 | } 70 | 71 | func Test_decode_message(t *testing.T) { 72 | should := require.New(t) 73 | input, err := hex.DecodeString("800100010000000568656c6c6f0000000c0b00010000000a73657373696f6e2d69640c00020c00010a000100000000000000010a000200000000000000000b00030000000f43616c6c46726f6d496e626f756e64000c00020b0001000000093132372e302e302e310a000200000000000004d2000b00030000000568656c6c6f000c00030c00010a000100000000000000020a000200000000000000000b00030000000d52657475726e496e626f756e64000b000200000005776f726c64000f00040c000000010c00020c00010a000100000000000000020a000200000000000000000b00030000000d52657475726e496e626f756e64000b000200000005776f726c64000000") 74 | should.NoError(err) 75 | reader := bytes.NewBuffer(input) 76 | cfg := thrifter.Config{Protocol: thrifter.ProtocolBinary}.Froze() 77 | decoder := cfg.NewDecoder(reader, nil) 78 | var msg general.Message 79 | should.NoError(decoder.Decode(&msg)) 80 | fmt.Println(msg.MessageType) 81 | fmt.Println(msg.MessageName) 82 | for fieldId, fieldValue := range msg.Arguments { 83 | fmt.Println(fieldId, fieldValue) 84 | } 85 | } 86 | 87 | func Test_encode_message(t *testing.T) { 88 | should := require.New(t) 89 | msg := general.Message{ 90 | MessageHeader: protocol.MessageHeader{ 91 | MessageType: protocol.MessageTypeCall, 92 | MessageName: "hello", 93 | SeqId: protocol.SeqId(17), 94 | }, 95 | Arguments: general.Struct{ 96 | protocol.FieldId(1): int64(1), 97 | protocol.FieldId(2): int64(2), 98 | }, 99 | } 100 | var msgRead general.Message 101 | buf := bytes.NewBuffer(nil) 102 | cfg := thrifter.Config{Protocol: thrifter.ProtocolBinary}.Froze() 103 | encoder := cfg.NewEncoder(buf) 104 | should.NoError(encoder.Encode(msg)) 105 | err := cfg.Unmarshal(buf.Bytes(), &msgRead) 106 | should.NoError(err) 107 | fmt.Println(msgRead.MessageType) 108 | 
fmt.Println(msgRead.MessageName) 109 | for fieldId, fieldValue := range msgRead.Arguments { 110 | fmt.Println(fieldId, fieldValue) 111 | } 112 | } 113 | 114 | type Foo struct { 115 | Sa string `thrift:"Sa,1" json:"Sa"` 116 | Ib int32 `thrift:"Ib,2" json:"Ib"` 117 | Lc []string `thrift:"Lc,3" json:"Lc"` 118 | } 119 | 120 | type Example struct { 121 | Name string `thrift:"Name,1" json:"Name"` 122 | Ia int64 `thrift:"Ia,2" json:"Ia"` 123 | Lb []string `thrift:"Lb,3" json:"Lb"` 124 | Mc map[string]*Foo `thrift:"Mc,4" json:"Mc"` 125 | } 126 | 127 | func TestPanic(t *testing.T) { 128 | var example = Example{ 129 | Name: "xxxxxxxxxxxxxxxx", 130 | Ia: 12345678, 131 | Lb: []string{"a", "b", "c", "d", "1", "2", "3", "4", "5"}, 132 | Mc: map[string]*Foo{ 133 | "t1": &Foo{Sa: "sss", Ib: 987654321, Lc: []string{"1", "2", "3"}}, 134 | }, 135 | } 136 | _, err := thrifter.Marshal(example) 137 | if err != nil { 138 | t.Fail() 139 | } 140 | } -------------------------------------------------------------------------------- /test/api/binding_test.go: -------------------------------------------------------------------------------- 1 | package test 2 | 3 | import ( 4 | "testing" 5 | "github.com/stretchr/testify/require" 6 | "git.apache.org/thrift.git/lib/go/thrift" 7 | "github.com/thrift-iterator/go/test/api/binding_test" 8 | ) 9 | 10 | func Test_binding(t *testing.T) { 11 | should := require.New(t) 12 | buf := thrift.NewTMemoryBuffer() 13 | transport := thrift.NewTFramedTransport(buf) 14 | proto := thrift.NewTBinaryProtocol(transport, true, true) 15 | proto.WriteStructBegin("hello") 16 | proto.WriteFieldBegin("field1", thrift.I64, 1) 17 | proto.WriteI64(1024) 18 | proto.WriteFieldEnd() 19 | proto.WriteFieldStop() 20 | proto.WriteStructEnd() 21 | transport.Flush() 22 | var val binding_test.TestObject 23 | should.NoError(api.Unmarshal(buf.Bytes()[4:], &val)) 24 | should.Equal(int64(1024), val.Field1) 25 | } 26 | 
--------------------------------------------------------------------------------
/test/api/binding_test/model.go:
--------------------------------------------------------------------------------
1 | package binding_test
2 |
3 | type TestObject struct {
4 | 	Field1 int64 `thrift:"field1,1"`
5 | }
6 |
--------------------------------------------------------------------------------
/test/api/generated.go:
--------------------------------------------------------------------------------
1 |
2 | package test
3 | import "github.com/v2pro/wombat/generic"
4 | import "reflect"
5 | import "github.com/thrift-iterator/go/test/api/binding_test"
6 | import "github.com/thrift-iterator/go/protocol/binary"
7 | func init() {
8 | 	generic.RegisterExpandedFunc("Decode_DT_ptr_binding_test__TestObject_EXT_default_ST_ptr_binary__Iterator",Decode_DT_ptr_binding_test__TestObject_EXT_default_ST_ptr_binary__Iterator)}
9 | var typeOf = reflect.TypeOf
10 | func DecodeSimpleValue_DT_ptr_int64_EXT_default_ST_ptr_binary__Iterator(dst *int64,src *binary.Iterator){
11 | 	*dst = int64(src.ReadInt64())
12 |
13 | }
14 | func DecodeAnything_DT_ptr_int64_EXT_default_ST_ptr_binary__Iterator(dst *int64,src *binary.Iterator){
15 |
16 |
17 |
18 | 	DecodeSimpleValue_DT_ptr_int64_EXT_default_ST_ptr_binary__Iterator(dst, src)
19 |
20 |
21 | }
22 | func DecodeStruct_DT_ptr_binding_test__TestObject_EXT_default_ST_ptr_binary__Iterator(dst *binding_test.TestObject,src *binary.Iterator){
23 |
24 |
25 |
26 |
27 |
28 | 	src.ReadStructHeader()
29 | 	for {
30 | 		fieldType, fieldId := src.ReadStructField()
31 | 		if fieldType == 0 {
32 | 			return
33 | 		}
34 | 		switch fieldId {
35 |
36 | 		case 1:
37 | 			DecodeAnything_DT_ptr_int64_EXT_default_ST_ptr_binary__Iterator(&dst.Field1, src)
38 |
39 | 		default:
40 | 			src.Discard(fieldType)
41 | 		}
42 | 	}
43 | }
44 | func DecodeAnything_DT_ptr_binding_test__TestObject_EXT_default_ST_ptr_binary__Iterator(dst *binding_test.TestObject,src *binary.Iterator){
45 |
46 |
47 |
48 | 	DecodeStruct_DT_ptr_binding_test__TestObject_EXT_default_ST_ptr_binary__Iterator(dst, src)
49 |
50 |
51 | }
52 | func Decode_DT_ptr_binding_test__TestObject_EXT_default_ST_ptr_binary__Iterator(dst interface{},src interface{}){
53 |
54 | 	iter := src.(*binary.Iterator)
55 |
56 | 	DecodeAnything_DT_ptr_binding_test__TestObject_EXT_default_ST_ptr_binary__Iterator(dst.(*binding_test.TestObject), iter)
57 |
58 | }
--------------------------------------------------------------------------------
/test/api/init.go:
--------------------------------------------------------------------------------
1 | package test
2 |
3 | import (
4 | 	"github.com/v2pro/wombat/generic"
5 | 	"github.com/thrift-iterator/go"
6 | 	"github.com/thrift-iterator/go/test/api/binding_test"
7 | )
8 |
9 | var api = thrifter.Config{
10 | 	Protocol: thrifter.ProtocolBinary,
11 | }.Froze()
12 |
13 | //go:generate go install github.com/thrift-iterator/go/cmd/thrifter
14 | //go:generate $GOPATH/bin/thrifter -pkg github.com/thrift-iterator/go/test/api
15 | func init() {
16 | 	generic.Declare(func() {
17 | 		api.WillDecodeFromBuffer(
18 | 			(*binding_test.TestObject)(nil),
19 | 		)
20 | 	})
21 | }
22 |
--------------------------------------------------------------------------------
/test/api/raw_message_test.go:
--------------------------------------------------------------------------------
1 | package test
2 |
3 | import (
4 | 	"testing"
5 | 	"github.com/stretchr/testify/require"
6 | 	"github.com/thrift-iterator/go"
7 | 	"github.com/thrift-iterator/go/general"
8 | 	"github.com/thrift-iterator/go/raw"
9 | 	"github.com/thrift-iterator/go/protocol"
10 | )
11 |
12 | // Test_decode_struct_of_raw_message decodes a struct into raw.Struct, replaces the
13 | // encoded bytes of field 0, re-encodes the struct and checks that the replacement
14 | // round-trips while field 1 is left untouched.
15 | func Test_decode_struct_of_raw_message(t *testing.T) {
16 | 	should := require.New(t)
17 | 	api := thrifter.Config{Protocol: thrifter.ProtocolBinary, StaticCodegen: false}.Froze()
18 | 	output, err := api.Marshal(general.Struct{
19 | 		0: general.Map{
20 | 			"key1": "value1",
21 | 		},
22 | 		1: "hello",
23 | 	})
24 | 	should.NoError(err)
25 | 	rawStruct := raw.Struct{}
26 | 	should.NoError(api.Unmarshal(output, &rawStruct))
27 | 	// parse arg1
28 | 	var arg1 string
29 | 	should.NoError(api.Unmarshal(rawStruct[protocol.FieldId(1)].Buffer, &arg1))
30 | 	should.Equal("hello", arg1)
31 | 	// parse arg0
32 | 	var arg0 map[string]string
33 | 	should.NoError(api.Unmarshal(rawStruct[protocol.FieldId(0)].Buffer, &arg0))
34 | 	should.Equal(map[string]string{"key1": "value1"}, arg0)
35 | 	// modify arg0
36 | 	arg0["key2"] = "value2"
37 | 	encodedArg0, err := api.Marshal(arg0)
38 | 	should.NoError(err)
39 | 	// set arg0 back
40 | 	rawStruct[protocol.FieldId(0)] = raw.StructField{
41 | 		Buffer: encodedArg0,
42 | 		Type:   protocol.TypeMap,
43 | 	}
44 | 	encodedArgs, err := api.Marshal(rawStruct)
45 | 	should.NoError(err)
46 | 	// verify it is changed: the re-encoded struct must still decode generically ...
47 | 	var val general.Struct
48 | 	should.NoError(api.Unmarshal(encodedArgs, &val))
49 | 	// ... and, decoded field by field, carry the modified arg0 and the untouched arg1
50 | 	changed := raw.Struct{}
51 | 	should.NoError(api.Unmarshal(encodedArgs, &changed))
52 | 	var changedArg0 map[string]string
53 | 	should.NoError(api.Unmarshal(changed[protocol.FieldId(0)].Buffer, &changedArg0))
54 | 	should.Equal(map[string]string{"key1": "value1", "key2": "value2"}, changedArg0)
55 | 	var changedArg1 string
56 | 	should.NoError(api.Unmarshal(changed[protocol.FieldId(1)].Buffer, &changedArg1))
57 | 	should.Equal("hello", changedArg1)
58 | }
--------------------------------------------------------------------------------
/test/api/shortcut_test.go:
--------------------------------------------------------------------------------
1 | package test
2 |
3 | import (
4 | 	"testing"
5 | 	"encoding/hex"
6 | 	"github.com/stretchr/testify/require"
7 | 	"github.com/thrift-iterator/go"
8 | 	"fmt"
9 | )
10 |
11 | // Test_to_json converts a strict binary-protocol message (0x8001 version prefix) to JSON.
12 | func Test_to_json(t *testing.T) {
13 | 	should := require.New(t)
14 | 	input, err := hex.DecodeString("800100010000000568656c6c6f0000000c0b00010000000a73657373696f6e2d69640c00020c00010a000100000000000000010a000200000000000000000b00030000000f43616c6c46726f6d496e626f756e64000c00020b0001000000093132372e302e302e310a000200000000000004d2000b00030000000568656c6c6f000c00030c00010a000100000000000000020a000200000000000000000b00030000000d52657475726e496e626f756e64000b000200000005776f726c64000f00040c000000010c00020c00010a000100000000000000020a000200000000000000000b00030000000d52657475726e496e626f756e64000b000200000005776f726c64000000")
15 | 	should.NoError(err)
16 | 	json, err := thrifter.ToJSON(input)
17 | 	should.NoError(err)
18 | 	// fail on an empty conversion instead of only printing the result
19 | 	should.NotEmpty(json)
20 | 	fmt.Println(json)
21 | }
22 |
--------------------------------------------------------------------------------
/test/combinations.go:
--------------------------------------------------------------------------------
1 | package test
2 |
3 | import (
4 | 	"github.com/thrift-iterator/go"
5 | 	"bytes"
6 | 	"git.apache.org/thrift.git/lib/go/thrift"
7 | 	"github.com/thrift-iterator/go/spi"
8 | )
9 | // Combination bundles one test setup: CreateProtocol builds reference bytes with the Apache thrift library, the other hooks read/write them with thrifter.
10 | type Combination struct {
11 | 	CreateProtocol func() (*thrift.TMemoryBuffer, thrift.TProtocol)
12 | 	CreateStream   func() spi.Stream
13 | 	CreateIterator func(buf []byte) spi.Iterator
14 | 	Unmarshal      func(buf []byte, val interface{}) error
15 | 	Marshal        func(val interface{}) ([]byte, error)
16 | }
17 | // binary uses the binary protocol and hands the bytes to thrifter directly (NewIterator(nil, buf), Unmarshal, Marshal).
18 | var binaryCfg = thrifter.Config{Protocol: thrifter.ProtocolBinary}
19 | var binary = Combination{
20 | 	CreateProtocol: func() (*thrift.TMemoryBuffer, thrift.TProtocol) {
21 | 		buf := thrift.NewTMemoryBuffer()
22 | 		proto := thrift.NewTBinaryProtocol(buf, true, true)
23 | 		return buf, proto
24 | 	},
25 | 	CreateStream: func() spi.Stream {
26 | 		return binaryCfg.Froze().NewStream(nil, nil)
27 | 	},
28 | 	CreateIterator: func(buf []byte) spi.Iterator {
29 | 		return binaryCfg.Froze().NewIterator(nil, buf)
30 | 	},
31 | 	Unmarshal: func(buf []byte, val interface{}) error {
32 | 		return binaryCfg.Froze().Unmarshal(buf, val)
33 | 	},
34 | 	Marshal: func(val interface{}) ([]byte, error) {
35 | 		return binaryCfg.Froze().Marshal(val)
36 | 	},
37 | }
38 | // binaryEncoderDecoder is like binary but reads through bytes.Buffer-backed iterators/decoders and writes through NewEncoder/Encode.
39 | var binaryEncoderDecoder = Combination{
40 | 	CreateProtocol: func() (*thrift.TMemoryBuffer, thrift.TProtocol) {
41 | 		buf := thrift.NewTMemoryBuffer()
42 | 		proto := thrift.NewTBinaryProtocol(buf, true, true)
43 | 		return buf, proto
44 | 	},
45 | 	CreateStream: func() spi.Stream {
46 | 		return binaryCfg.Froze().NewStream(nil, nil)
47 | 	},
48 | 	CreateIterator: func(buf []byte) spi.Iterator {
49 | 		return binaryCfg.Froze().NewIterator(bytes.NewBuffer(buf), nil)
50 | 	},
51 | 	Unmarshal: func(buf []byte, val interface{}) error {
52 | 		decoder := binaryCfg.Froze().NewDecoder(bytes.NewBuffer(buf), nil)
53 | 		return decoder.Decode(val)
54 | 	},
55 | 	Marshal: func(val interface{}) ([]byte, error) {
56 | 		encoder := binaryCfg.Froze().NewEncoder(nil)
57 | 		err := encoder.Encode(val)
58 | 		if err != nil {
59 | 			return nil, err
60 | 		}
61 | 		return encoder.Buffer(), nil
62 | 	},
63 | }
64 | // compact and compactEncoderDecoder mirror the two binary setups for the compact protocol.
65 | var compactCfg = thrifter.Config{Protocol: thrifter.ProtocolCompact}
66 | var compact = Combination{
67 | 	CreateProtocol: func() (*thrift.TMemoryBuffer, thrift.TProtocol) {
68 | 		buf := thrift.NewTMemoryBuffer()
69 | 		proto := thrift.NewTCompactProtocol(buf)
70 | 		return buf, proto
71 | 	},
72 | 	CreateStream: func() spi.Stream {
73 | 		return compactCfg.Froze().NewStream(nil, nil)
74 | 	},
75 | 	CreateIterator: func(buf []byte) spi.Iterator {
76 | 		return compactCfg.Froze().NewIterator(nil, buf)
77 | 	},
78 | 	Unmarshal: func(buf []byte, val interface{}) error {
79 | 		return compactCfg.Froze().Unmarshal(buf, val)
80 | 	},
81 | 	Marshal: func(val interface{}) ([]byte, error) {
82 | 		return compactCfg.Froze().Marshal(val)
83 | 	},
84 | }
85 |
86 | var compactEncoderDecoder = Combination{
87 | 	CreateProtocol: func() (*thrift.TMemoryBuffer, thrift.TProtocol) {
88 | 		buf := thrift.NewTMemoryBuffer()
89 | 		proto := thrift.NewTCompactProtocol(buf)
90 | 		return buf, proto
91 | 	},
92 | 	CreateStream: func() spi.Stream {
93 | 		return compactCfg.Froze().NewStream(nil, nil)
94 | 	},
95 | 	CreateIterator: func(buf []byte) spi.Iterator {
96 | 		return compactCfg.Froze().NewIterator(bytes.NewBuffer(buf), nil)
97 | 	},
98 | 	Unmarshal: func(buf []byte, val interface{}) error {
99 | 		decoder := compactCfg.Froze().NewDecoder(bytes.NewBuffer(buf), nil)
100 | 		return decoder.Decode(val)
101 | 	},
102 | 	Marshal: func(val interface{}) ([]byte, error) {
103 | 		encoder := compactCfg.Froze().NewEncoder(nil)
104 | 		err := encoder.Encode(val)
105 | 		if err != nil {
106 | 			return nil, err
107 | 		}
108 | 		return encoder.Buffer(), nil
109 | 	},
110 | }
111 | // NOTE(review): StaticCodegen: false is also the zero value binaryCfg/compactCfg get implicitly, so the *Dynamic setups look identical to binary/compact — confirm the intended config. They also leave CreateStream nil, so only CreateProtocol/CreateIterator/Unmarshal/Marshal may be used with them.
112 | var binaryDynamicCfg = thrifter.Config{Protocol: thrifter.ProtocolBinary, StaticCodegen: false}
113 | var binaryDynamic = Combination{
114 | 	CreateProtocol: func() (*thrift.TMemoryBuffer, thrift.TProtocol) {
115 | 		buf := thrift.NewTMemoryBuffer()
116 | 		proto := thrift.NewTBinaryProtocol(buf, true, true)
117 | 		return buf, proto
118 | 	},
119 | 	CreateIterator: func(buf []byte) spi.Iterator {
120 | 		return binaryDynamicCfg.Froze().NewIterator(nil, buf)
121 | 	},
122 | 	Unmarshal: func(buf []byte, val interface{}) error {
123 | 		return binaryDynamicCfg.Froze().Unmarshal(buf, val)
124 | 	},
125 | 	Marshal: func(val interface{}) ([]byte, error) {
126 | 		return binaryDynamicCfg.Froze().Marshal(val)
127 | 	},
128 | }
129 | var compactDynamicCfg = thrifter.Config{Protocol: thrifter.ProtocolCompact, StaticCodegen: false}
130 | var compactDynamic = Combination{
131 | 	CreateProtocol: func() (*thrift.TMemoryBuffer, thrift.TProtocol) {
132 | 		buf := thrift.NewTMemoryBuffer()
133 | 		proto := thrift.NewTCompactProtocol(buf)
134 | 		return buf, proto
135 | 	},
136 | 	CreateIterator: func(buf []byte) spi.Iterator {
137 | 		return compactDynamicCfg.Froze().NewIterator(nil, buf)
138 | 	},
139 | 	Unmarshal: func(buf []byte, val interface{}) error {
140 | 		return compactDynamicCfg.Froze().Unmarshal(buf, val)
141 | 	},
142 | 	Marshal: func(val interface{}) ([]byte, error) {
143 | 		return compactDynamicCfg.Froze().Marshal(val)
144 | 	},
145 | }
146 | // Combinations lists the setups that provide every hook, including CreateStream.
147 | var Combinations = []Combination{
148 | 	binary, binaryEncoderDecoder, compact, compactEncoderDecoder,
149 | }
150 | // The Combinations literal has len == cap, so append allocates a new array and leaves Combinations unchanged; MarshalCombinations shares UnmarshalCombinations' slice.
151 | var UnmarshalCombinations = append(Combinations,
152 | 	binaryDynamic, compactDynamic)
153 | var MarshalCombinations = UnmarshalCombinations
154 |
--------------------------------------------------------------------------------
/test/level_0/bool_test.go:
--------------------------------------------------------------------------------
1 | package test
2 |
3 | import (
4 | 	"github.com/stretchr/testify/require"
5 | 	"github.com/thrift-iterator/go/test"
6 | 	"testing"
7 | )
8 |
9 | func Test_decode_bool(t *testing.T) {
10 | 	should := require.New(t)
11 | 	for _, c := range test.Combinations {
12 | 		buf, proto := c.CreateProtocol()
13 | 		proto.WriteBool(true)
14 | 		iter := c.CreateIterator(buf.Bytes())
15 | 		should.Equal(true, iter.ReadBool())
16 |
17 | 		buf, proto = c.CreateProtocol()
18 | 		proto.WriteBool(false)
19 | 		iter = c.CreateIterator(buf.Bytes())
20 | 		should.Equal(false, iter.ReadBool())
21 | 	}
22 | }
23 |
24 | func Test_unmarshal_bool(t *testing.T) {
25 | 	should := require.New(t)
26 | 	for _, c := range test.UnmarshalCombinations {
27 | 		buf, proto := c.CreateProtocol()
28 | 		var val1 bool
29 | 		proto.WriteBool(true)
30 | 		should.NoError(c.Unmarshal(buf.Bytes(), &val1))
31 | 		should.Equal(true, val1)
32 | 		// val2 starts as true so this unmarshal must actively overwrite it with false
33 | 		buf, proto = c.CreateProtocol()
34 | 		var val2 bool = true
35 | 		proto.WriteBool(false)
36 | 		should.NoError(c.Unmarshal(buf.Bytes(), &val2))
37 | 		should.Equal(false, val2)
38 | 	}
39 | }
40 | // Test_encode_bool calls CreateStream, so it iterates Combinations (MarshalCombinations includes setups without CreateStream).
41 | func Test_encode_bool(t *testing.T) {
42 | 	should := require.New(t)
43 | 	for _, c := range test.Combinations {
44 | 		stream := c.CreateStream()
45 | 		stream.WriteBool(true)
46 | 		iter := c.CreateIterator(stream.Buffer())
47 | 		should.Equal(true, iter.ReadBool())
48 |
49 | 		stream = c.CreateStream()
50 | 		stream.WriteBool(false)
51 | 		iter = c.CreateIterator(stream.Buffer())
52 | 		should.Equal(false, iter.ReadBool())
53 | 	}
54 | }
55 |
56 | func Test_marshal_bool(t *testing.T) {
57 | 	should := require.New(t)
58 | 	for _, c := range test.MarshalCombinations {
59 | 		output, err := c.Marshal(true)
60 | 		should.NoError(err)
61 | 		iter := c.CreateIterator(output)
62 | 		should.Equal(true, iter.ReadBool())
63 |
64 | 		output, err = c.Marshal(false)
65 | 		should.NoError(err)
66 | 		iter = c.CreateIterator(output)
67 | 		should.Equal(false, iter.ReadBool())
68 | 	}
69 | }
70 |
--------------------------------------------------------------------------------
/test/level_0/enum_test.go:
--------------------------------------------------------------------------------
1 | package test
2 |
3 | import (
4 | 	"testing"
5 | 	"github.com/stretchr/testify/require"
6 | 	"github.com/thrift-iterator/go/test"
7 | 	"github.com/thrift-iterator/go/test/level_0/enum_test"
8 | )
9 | // The enum travels as an i32 (WriteI32 / ReadInt32 below) even though enum_test.Player is declared as int64.
10 | func Test_unmarshal_enum(t *testing.T) {
11 | 	should := require.New(t)
12 | 	for _, c := range test.UnmarshalCombinations {
13 | 		buf, proto := c.CreateProtocol()
14 | 		proto.WriteI32(1)
15 | 		var val enum_test.Player
16 | 		should.NoError(c.Unmarshal(buf.Bytes(), &val))
17 | 		should.Equal(enum_test.Player_FLASH, val)
18 | 	}
19 | }
20 |
21 | func Test_marshal_enum(t *testing.T) {
22 | 	should := require.New(t)
23 | 	for _, c := range test.MarshalCombinations {
24 | 		output, err := c.Marshal(enum_test.Player_FLASH)
25 | 		should.NoError(err)
26 | 		iter := c.CreateIterator(output)
27 | 		should.Equal(int32(1), iter.ReadInt32())
28 | 	}
29 | }
--------------------------------------------------------------------------------
/test/level_0/enum_test/Player.go:
--------------------------------------------------------------------------------
1 | package enum_test
2 | // Player is the enum type exercised by enum_test.go.
3 | type Player int64
4 | const (
5 | 	Player_JAVA  Player = 0
6 | 	Player_FLASH Player = 1
7 | )
8 | // String returns the enum name, or "" for values without one.
9 | func (p Player) String() string {
10 | 	switch p {
11 | 	case Player_JAVA: return "JAVA"
12 | 	case Player_FLASH: return "FLASH"
13 | 	}
14 | 	return ""
15 | }
--------------------------------------------------------------------------------
/test/level_0/float64_test.go:
--------------------------------------------------------------------------------
1 | package test
2 |
3 | import (
4 | 	"testing"
5 | 	"github.com/stretchr/testify/require"
6 | 	"github.com/thrift-iterator/go/test"
7 | )
8 |
9 | func Test_decode_float64(t *testing.T) {
10 | 	should := require.New(t)
11 | 	for _, c := range test.Combinations {
12 | 		buf, proto := c.CreateProtocol()
13 | 		proto.WriteDouble(10.24)
14 | 		iter := c.CreateIterator(buf.Bytes())
15 | 		should.Equal(10.24, iter.ReadFloat64())
16 | 	}
17 | }
18 |
19 | func Test_unmarshal_float64(t *testing.T) {
20 | 	should := require.New(t)
21 | 	for _, c := range test.UnmarshalCombinations {
22 | 		buf, proto := c.CreateProtocol()
23 | 		proto.WriteDouble(10.24)
24 | 		var val float64
25 | 		should.NoError(c.Unmarshal(buf.Bytes(), &val))
26 | 		should.Equal(10.24, val)
27 | 	}
28 | }
29 |
30 | func Test_encode_float64(t *testing.T) {
31 | 	should := require.New(t)
32 | 	for _, c := range test.Combinations {
33 | 		stream := c.CreateStream()
34 | 		stream.WriteFloat64(10.24)
35 | 		iter := c.CreateIterator(stream.Buffer())
36 | 		should.Equal(10.24, iter.ReadFloat64())
37 | 	}
38 | }
39 |
40 | func Test_marshal_float64(t *testing.T) {
41 | 	should := require.New(t)
42 | 	for _, c := range test.MarshalCombinations {
43 | 		output, err := c.Marshal(10.24)
44 | 		should.NoError(err)
45 | 		iter := c.CreateIterator(output)
46 | 		should.Equal(10.24, iter.ReadFloat64())
47 | 	}
48 | }
49 |
--------------------------------------------------------------------------------
/test/level_0/int16_test.go:
--------------------------------------------------------------------------------
1 | package test
2 |
3 | import (
4 | 	"testing"
5 | 	"github.com/stretchr/testify/require"
6 | 	"github.com/thrift-iterator/go/test"
7 | )
8 |
9 | func Test_decode_int16(t *testing.T) {
10 | 	should := require.New(t)
11 | 	for _, c := range test.Combinations {
12 | 		buf, proto := c.CreateProtocol()
13 | 		proto.WriteI16(-1)
14 | 		iter := c.CreateIterator(buf.Bytes())
15 | 		should.Equal(int16(-1), iter.ReadInt16())
16 | 	}
17 | }
18 |
19 | func Test_unmarshal_int16(t *testing.T) {
20 | 	should := require.New(t)
21 | 	for _, c := range test.UnmarshalCombinations {
22 | 		buf, proto := c.CreateProtocol()
23 | 		proto.WriteI16(-1)
24 | 		var val int16
25 | 		should.NoError(c.Unmarshal(buf.Bytes(), &val))
26 | 		should.Equal(int16(-1), val)
27 | 	}
28 | }
29 |
30 | func Test_encode_int16(t *testing.T) {
31 | 	should := require.New(t)
32 | 	for _, c := range test.Combinations {
33 | 		stream := c.CreateStream()
34 | 		stream.WriteInt16(-1)
35 | 		iter := c.CreateIterator(stream.Buffer())
36 | 		should.Equal(int16(-1), iter.ReadInt16())
37 | 	}
38 | }
39 |
40 | func Test_marshal_int16(t *testing.T) {
41 | 	should := require.New(t)
42 | 	for _, c := range test.MarshalCombinations {
43 | 		output, err := c.Marshal(int16(-1))
44 | 		should.NoError(err)
45 | 		iter := c.CreateIterator(output)
46 | 		should.Equal(int16(-1), iter.ReadInt16())
47 | 	}
48 | }
--------------------------------------------------------------------------------
/test/level_0/int32_test.go:
--------------------------------------------------------------------------------
1 | package test
2 |
3 | import (
4 | 	"testing"
5 | 	"github.com/stretchr/testify/require"
6 | 	"github.com/thrift-iterator/go/test"
7 | )
8 |
9 | func Test_decode_int32(t *testing.T) {
10 | 	should := require.New(t)
11 | 	for _, c := range test.Combinations {
12 | 		buf, proto := c.CreateProtocol()
13 | 		proto.WriteI32(-1)
14 | 		iter := c.CreateIterator(buf.Bytes())
15 | 		should.Equal(int32(-1), iter.ReadInt32())
16 | 	}
17 | }
18 |
19 | func Test_unmarshal_int32(t *testing.T) {
20 | 	should := require.New(t)
21 | 	for _, c := range test.UnmarshalCombinations {
22 | 		buf, proto := c.CreateProtocol()
23 | 		proto.WriteI32(-1)
24 | 		var val int32
25 | 		should.NoError(c.Unmarshal(buf.Bytes(), &val))
26 | 		should.Equal(int32(-1), val)
27 | 	}
28 | }
29 |
30 | func Test_encode_int32(t *testing.T) {
31 | 	should := require.New(t)
32 | 	for _, c := range test.Combinations {
33 | 		stream := c.CreateStream()
34 | 		stream.WriteInt32(-1)
35 | 		iter := c.CreateIterator(stream.Buffer())
36 | 		should.Equal(int32(-1), iter.ReadInt32())
37 | 	}
38 | }
39 |
40 | func Test_marshal_int32(t *testing.T) {
41 | 	should := require.New(t)
42 | 	for _, c := range test.MarshalCombinations {
43 | 		output, err := c.Marshal(int32(-1))
44 | 		should.NoError(err)
45 | 		iter := c.CreateIterator(output)
46 | 		should.Equal(int32(-1), iter.ReadInt32())
47 | 	}
48 | }
--------------------------------------------------------------------------------
/test/level_0/int64_test.go:
--------------------------------------------------------------------------------
1 | package test
2 |
3 | import (
4 | 	"testing"
5 | 	"github.com/stretchr/testify/require"
6 | 	"github.com/thrift-iterator/go/test"
7 | )
8 |
9 | func Test_decode_int64(t *testing.T) {
10 | 	should := require.New(t)
11 | 	for _, c := range test.Combinations {
12 | 		buf, proto := c.CreateProtocol()
13 | 		proto.WriteI64(-1)
14 | 		iter := c.CreateIterator(buf.Bytes())
15 | 		should.Equal(int64(-1), iter.ReadInt64())
16 | 	}
17 | }
18 |
19 | func Test_unmarshal_int64(t *testing.T) {
20 | 	should := require.New(t)
21 | 	for _, c := range test.UnmarshalCombinations {
22 | 		buf, proto := c.CreateProtocol()
23 | 		proto.WriteI64(-1)
24 | 		var val int64
25 | 		should.NoError(c.Unmarshal(buf.Bytes(), &val))
26 | 		should.Equal(int64(-1), val)
27 | 	}
28 | }
29 |
30 | func Test_encode_int64(t *testing.T) {
31 | 	should := require.New(t)
32 | 	for _, c := range test.Combinations {
33 | 		stream := c.CreateStream()
34 | 		stream.WriteInt64(-1)
35 | 		iter := c.CreateIterator(stream.Buffer())
36 | 		should.Equal(int64(-1), iter.ReadInt64())
37 | 	}
38 | }
39 |
40 | func Test_marshal_int64(t *testing.T) {
41 | 	should := require.New(t)
42 | 	for _, c := range test.MarshalCombinations {
43 | 		output, err := c.Marshal(int64(-1))
44 | 		should.NoError(err)
45 | 		iter := c.CreateIterator(output)
46 | 		should.Equal(int64(-1), iter.ReadInt64())
47 | 	}
48 | }
49 |
--------------------------------------------------------------------------------
/test/level_0/int8_test.go:
--------------------------------------------------------------------------------
1 | package test
2 |
3 | import (
4 | 	"testing"
5 | 	"github.com/stretchr/testify/require"
6 | 	"github.com/thrift-iterator/go/test"
7 | )
8 | // The reference encoder's WriteByte takes a signed value (-1 below), which thrifter reads back with ReadInt8.
9 | func Test_decode_int8(t *testing.T) {
10 | 	should := require.New(t)
11 | 	for _, c := range test.Combinations {
12 | 		buf, proto := c.CreateProtocol()
13 | 		proto.WriteByte(-1)
14 | 		iter := c.CreateIterator(buf.Bytes())
15 | 		should.Equal(int8(-1), iter.ReadInt8())
16 | 	}
17 | }
18 |
19 | func Test_unmarshal_int8(t *testing.T) {
20 | 	should := require.New(t)
21 | 	for _, c := range test.UnmarshalCombinations {
22 | 		buf, proto := c.CreateProtocol()
23 | 		proto.WriteByte(-1)
24 | 		var val int8
25 | 		should.NoError(c.Unmarshal(buf.Bytes(), &val))
26 | 		should.Equal(int8(-1), val)
27 | 	}
28 | }
29 |
30 | func Test_encode_int8(t *testing.T) {
31 | 	should := require.New(t)
32 | 	for _, c := range test.Combinations {
33 | 		stream := c.CreateStream()
34 | 		stream.WriteInt8(-1)
35 | 		iter := c.CreateIterator(stream.Buffer())
36 | 		should.Equal(int8(-1), iter.ReadInt8())
37 | 	}
38 | }
39 |
40 | func Test_marshal_int8(t *testing.T) {
41 | 	should := require.New(t)
42 | 	for _, c := range test.MarshalCombinations {
43 | 		output, err := c.Marshal(int8(-1))
44 | 		should.NoError(err)
45 | 		iter := c.CreateIterator(output)
46 | 		should.Equal(int8(-1), iter.ReadInt8())
47 | 	}
48 | }
--------------------------------------------------------------------------------
/test/level_0/int_test.go:
--------------------------------------------------------------------------------
1 | package test
2 |
3 | import (
4 | 	"testing"
5 | 	"github.com/stretchr/testify/require"
6 | 	"github.com/thrift-iterator/go/test"
7 | )
8 | // Go int is exchanged as an i64 on the wire: WriteI64 output is read back with ReadInt.
9 | func Test_decode_int(t *testing.T) {
10 | 	should := require.New(t)
11 | 	for _, c := range test.Combinations {
12 | 		buf, proto := c.CreateProtocol()
13 | 		proto.WriteI64(-1)
14 | 		iter := c.CreateIterator(buf.Bytes())
15 | 		should.Equal(int(-1), iter.ReadInt())
16 | 	}
17 | }
18 |
19 | func Test_unmarshal_int(t *testing.T) {
20 | 	should := require.New(t)
21 | 	for _, c := range test.UnmarshalCombinations {
22 | 		buf, proto := c.CreateProtocol()
23 | 		proto.WriteI64(-1)
24 | 		var val int
25 | 		should.NoError(c.Unmarshal(buf.Bytes(), &val))
26 | 		should.Equal(int(-1), val)
27 | 	}
28 | }
29 |
30 | func Test_encode_int(t *testing.T) {
31 | 	should := require.New(t)
32 | 	for _, c := range test.Combinations {
33 | 		stream := c.CreateStream()
34 | 		stream.WriteInt(-1)
35 | 		iter := c.CreateIterator(stream.Buffer())
36 | 		should.Equal(int(-1), iter.ReadInt())
37 | 	}
38 | }
39 |
40 | func Test_marshal_int(t *testing.T) {
41 | 	should := require.New(t)
42 | 	for _, c := range test.MarshalCombinations {
43 | 		output, err := c.Marshal(int(-1))
44 | 		should.NoError(err)
45 | 		iter := c.CreateIterator(output)
46 | 		should.Equal(int(-1), iter.ReadInt())
47 | 	}
48 | }
49 |
--------------------------------------------------------------------------------
/test/level_0/level_0_test.go:
--------------------------------------------------------------------------------
1 | package test
2 |
3 | import "github.com/v2pro/wombat/generic"
4 | // Turn on generic.DynamicCompilationEnabled for this test package (presumably lets wombat compile codecs at runtime — confirm).
5 | func init() {
6 | 	generic.DynamicCompilationEnabled = true
7 | }
--------------------------------------------------------------------------------
/test/level_0/uint16_test.go:
--------------------------------------------------------------------------------
1 | package test
2 |
3 | import (
4 | 	"testing"
5 | 	"github.com/stretchr/testify/require"
6 | 	"github.com/thrift-iterator/go/test"
7 | )
8 | // Unsigned values reuse the signed wire types: WriteI16 output is read back with ReadUint16.
9 | func Test_decode_uint16(t *testing.T) {
10 | 	should := require.New(t)
11 | 	for _, c := range test.Combinations {
12 | 		buf, proto := c.CreateProtocol()
13 | 		proto.WriteI16(1024)
14 | 		iter := c.CreateIterator(buf.Bytes())
15 | 		should.Equal(uint16(1024), iter.ReadUint16())
16 | 	}
17 | }
18 |
19 | func Test_unmarshal_uint16(t *testing.T) {
20 | 	should := require.New(t)
21 | 	for _, c := range test.UnmarshalCombinations {
22 | 		buf, proto := c.CreateProtocol()
23 | 		proto.WriteI16(1024)
24 | 		var val uint16
25 | 		should.NoError(c.Unmarshal(buf.Bytes(), &val))
26 | 		should.Equal(uint16(1024), val)
27 | 	}
28 | }
29 |
30 | func Test_encode_uint16(t *testing.T) {
31 | 	should := require.New(t)
32 | 	for _, c := range test.Combinations {
33 | 		stream := c.CreateStream()
34 | 		stream.WriteUint16(1024)
35 | 		iter := c.CreateIterator(stream.Buffer())
36 | 		should.Equal(uint16(1024), iter.ReadUint16())
37 | 	}
38 | }
39 |
40 | func Test_marshal_uint16(t *testing.T) {
41 | 	should := require.New(t)
42 | 	for _, c := range test.MarshalCombinations {
43 | 		output, err := c.Marshal(uint16(1024))
44 | 		should.NoError(err)
45 | 		iter := c.CreateIterator(output)
46 | 		should.Equal(uint16(1024), iter.ReadUint16())
47 | 	}
48 | }
49 |
--------------------------------------------------------------------------------
/test/level_0/uint32_test.go:
--------------------------------------------------------------------------------
1 | package test
2 |
3 | import (
4 | 	"testing"
5 | 	"github.com/stretchr/testify/require"
6 | 	"github.com/thrift-iterator/go/test"
7 | )
8 |
9 | func Test_decode_uint32(t *testing.T) {
10 | 	should := require.New(t)
11 | 	for _, c := range test.Combinations {
12 | 		buf, proto := c.CreateProtocol()
13 | 		proto.WriteI32(1024)
14 | 		iter := c.CreateIterator(buf.Bytes())
15 | 		should.Equal(uint32(1024), iter.ReadUint32())
16 | 	}
17 | }
18 |
19 | func Test_unmarshal_uint32(t *testing.T) {
20 | 	should := require.New(t)
21 | 	for _, c := range test.UnmarshalCombinations {
22 | 		buf, proto := c.CreateProtocol()
23 | 		proto.WriteI32(1024)
24 | 		var val uint32
25 | 		should.NoError(c.Unmarshal(buf.Bytes(), &val))
26 | 		should.Equal(uint32(1024), val)
27 | 	}
28 | }
29 |
30 | func Test_encode_uint32(t *testing.T) {
31 | 	should := require.New(t)
32 | 	for _, c := range test.Combinations {
33 | 		stream := c.CreateStream()
34 | 		stream.WriteUint32(1024)
35 | 		iter := c.CreateIterator(stream.Buffer())
36 | 		should.Equal(uint32(1024), iter.ReadUint32())
37 | 	}
38 | }
39 |
40 | func Test_marshal_uint32(t *testing.T) {
41 | 	should := require.New(t)
42 | 	for _, c := range test.MarshalCombinations {
43 | 		output, err := c.Marshal(uint32(1024))
44 | 		should.NoError(err)
45 | 		iter := c.CreateIterator(output)
46 | 		should.Equal(uint32(1024), iter.ReadUint32())
47 | 	}
48 | }
49 |
--------------------------------------------------------------------------------
/test/level_0/uint64_test.go:
-------------------------------------------------------------------------------- 1 | package test 2 | 3 | import ( 4 | "testing" 5 | "github.com/stretchr/testify/require" 6 | "github.com/thrift-iterator/go/test" 7 | ) 8 | 9 | func Test_decode_uint64(t *testing.T) { 10 | should := require.New(t) 11 | for _, c := range test.Combinations { 12 | buf, proto := c.CreateProtocol() 13 | proto.WriteI64(1024) 14 | iter := c.CreateIterator(buf.Bytes()) 15 | should.Equal(uint64(1024), iter.ReadUint64()) 16 | } 17 | } 18 | 19 | func Test_unmarshal_uint64(t *testing.T) { 20 | should := require.New(t) 21 | for _, c := range test.UnmarshalCombinations { 22 | buf, proto := c.CreateProtocol() 23 | proto.WriteI64(1024) 24 | var val uint64 25 | should.NoError(c.Unmarshal(buf.Bytes(), &val)) 26 | should.Equal(uint64(1024), val) 27 | } 28 | } 29 | 30 | func Test_encode_uint64(t *testing.T) { 31 | should := require.New(t) 32 | for _, c := range test.Combinations { 33 | stream := c.CreateStream() 34 | stream.WriteUint64(1024) 35 | iter := c.CreateIterator(stream.Buffer()) 36 | should.Equal(uint64(1024), iter.ReadUint64()) 37 | } 38 | } 39 | 40 | func Test_marshal_uint64(t *testing.T) { 41 | should := require.New(t) 42 | for _, c := range test.MarshalCombinations { 43 | output, err := c.Marshal(uint64(1024)) 44 | should.NoError(err) 45 | iter := c.CreateIterator(output) 46 | should.Equal(uint64(1024), iter.ReadUint64()) 47 | } 48 | } 49 | -------------------------------------------------------------------------------- /test/level_0/uint8_test.go: -------------------------------------------------------------------------------- 1 | package test 2 | 3 | import ( 4 | "testing" 5 | "github.com/stretchr/testify/require" 6 | "github.com/thrift-iterator/go/test" 7 | ) 8 | 9 | func Test_decode_uint8(t *testing.T) { 10 | should := require.New(t) 11 | for _, c := range test.Combinations { 12 | buf, proto := c.CreateProtocol() 13 | proto.WriteByte(100) 14 | iter := c.CreateIterator(buf.Bytes()) 15 | 
should.Equal(uint8(100), iter.ReadUint8()) 16 | } 17 | } 18 | 19 | func Test_unmarshal_uint8(t *testing.T) { 20 | should := require.New(t) 21 | for _, c := range test.UnmarshalCombinations { 22 | buf, proto := c.CreateProtocol() 23 | proto.WriteByte(100) 24 | var val uint8 25 | should.NoError(c.Unmarshal(buf.Bytes(), &val)) 26 | should.Equal(uint8(100), val) 27 | } 28 | } 29 | 30 | func Test_encode_uint8(t *testing.T) { 31 | should := require.New(t) 32 | for _, c := range test.Combinations { 33 | stream := c.CreateStream() 34 | stream.WriteUint8(100) 35 | iter := c.CreateIterator(stream.Buffer()) 36 | should.Equal(uint8(100), iter.ReadUint8()) 37 | } 38 | } 39 | 40 | func Test_marshal_uint8(t *testing.T) { 41 | should := require.New(t) 42 | for _, c := range test.MarshalCombinations { 43 | output, err := c.Marshal(uint8(100)) 44 | should.NoError(err) 45 | iter := c.CreateIterator(output) 46 | should.Equal(uint8(100), iter.ReadUint8()) 47 | } 48 | } 49 | -------------------------------------------------------------------------------- /test/level_0/uint_test.go: -------------------------------------------------------------------------------- 1 | package test 2 | 3 | import ( 4 | "testing" 5 | "github.com/stretchr/testify/require" 6 | "github.com/thrift-iterator/go/test" 7 | ) 8 | 9 | func Test_decode_uint(t *testing.T) { 10 | should := require.New(t) 11 | for _, c := range test.Combinations { 12 | buf, proto := c.CreateProtocol() 13 | proto.WriteI64(1024) 14 | iter := c.CreateIterator(buf.Bytes()) 15 | should.Equal(uint(1024), iter.ReadUint()) 16 | } 17 | } 18 | 19 | func Test_unmarshal_uint(t *testing.T) { 20 | should := require.New(t) 21 | for _, c := range test.UnmarshalCombinations { 22 | buf, proto := c.CreateProtocol() 23 | proto.WriteI64(1024) 24 | var val uint 25 | should.NoError(c.Unmarshal(buf.Bytes(), &val)) 26 | should.Equal(uint(1024), val) 27 | } 28 | } 29 | 30 | func Test_encode_uint(t *testing.T) { 31 | should := require.New(t) 32 | for _, c := range 
test.Combinations { 33 | stream := c.CreateStream() 34 | stream.WriteUint(1024) 35 | iter := c.CreateIterator(stream.Buffer()) 36 | should.Equal(uint(1024), iter.ReadUint()) 37 | } 38 | } 39 | 40 | func Test_marshal_uint(t *testing.T) { 41 | should := require.New(t) 42 | for _, c := range test.MarshalCombinations { 43 | output, err := c.Marshal(uint(1024)) 44 | should.NoError(err) 45 | iter := c.CreateIterator(output) 46 | should.Equal(uint(1024), iter.ReadUint()) 47 | } 48 | } 49 | -------------------------------------------------------------------------------- /test/level_1/binary_test.go: -------------------------------------------------------------------------------- 1 | package test 2 | 3 | import ( 4 | "testing" 5 | "github.com/stretchr/testify/require" 6 | "github.com/thrift-iterator/go/test" 7 | ) 8 | 9 | func Test_decode_binary(t *testing.T) { 10 | should := require.New(t) 11 | for _, c := range test.Combinations { 12 | buf, proto := c.CreateProtocol() 13 | proto.WriteBinary([]byte("hello")) 14 | iter := c.CreateIterator(buf.Bytes()) 15 | should.Equal("hello", string(iter.ReadBinary())) 16 | } 17 | } 18 | 19 | func Test_unmarshal_binary(t *testing.T) { 20 | should := require.New(t) 21 | for _, c := range test.UnmarshalCombinations { 22 | buf, proto := c.CreateProtocol() 23 | proto.WriteBinary([]byte("hello")) 24 | var val []byte 25 | should.NoError(c.Unmarshal(buf.Bytes(), &val)) 26 | should.Equal("hello", string(val)) 27 | } 28 | } 29 | 30 | func Test_encode_binary(t *testing.T) { 31 | should := require.New(t) 32 | for _, c := range test.Combinations { 33 | stream := c.CreateStream() 34 | stream.WriteBinary([]byte(`hello world!`)) 35 | iter := c.CreateIterator(stream.Buffer()) 36 | should.Equal([]byte(`hello world!`), iter.ReadBinary()) 37 | } 38 | } 39 | 40 | func Test_marshal_binary(t *testing.T) { 41 | should := require.New(t) 42 | for _, c := range test.MarshalCombinations { 43 | val := []byte("hello") 44 | output, err := c.Marshal(val) 45 | 
should.NoError(err) 46 | iter := c.CreateIterator(output) 47 | should.Equal("hello", string(iter.ReadBinary())) 48 | } 49 | } 50 | -------------------------------------------------------------------------------- /test/level_1/level_1_test.go: -------------------------------------------------------------------------------- 1 | package test 2 | 3 | import "github.com/v2pro/wombat/generic" 4 | 5 | func init() { 6 | generic.DynamicCompilationEnabled = true 7 | } -------------------------------------------------------------------------------- /test/level_1/list_test.go: -------------------------------------------------------------------------------- 1 | package test 2 | 3 | import ( 4 | "testing" 5 | "github.com/stretchr/testify/require" 6 | "git.apache.org/thrift.git/lib/go/thrift" 7 | "github.com/thrift-iterator/go/protocol" 8 | "github.com/thrift-iterator/go/test" 9 | "github.com/thrift-iterator/go/general" 10 | "github.com/thrift-iterator/go/raw" 11 | ) 12 | 13 | func Test_decode_list_by_iterator(t *testing.T) { 14 | should := require.New(t) 15 | for _, c := range test.Combinations { 16 | buf, proto := c.CreateProtocol() 17 | proto.WriteListBegin(thrift.I64, 3) 18 | proto.WriteI64(1) 19 | proto.WriteI64(2) 20 | proto.WriteI64(3) 21 | proto.WriteListEnd() 22 | iter := c.CreateIterator(buf.Bytes()) 23 | elemType, length := iter.ReadListHeader() 24 | should.Equal(protocol.TypeI64, elemType) 25 | should.Equal(3, length) 26 | should.Equal(uint64(1), iter.ReadUint64()) 27 | should.Equal(uint64(2), iter.ReadUint64()) 28 | should.Equal(uint64(3), iter.ReadUint64()) 29 | } 30 | } 31 | 32 | func Test_encode_list_by_stream(t *testing.T) { 33 | should := require.New(t) 34 | for _, c := range test.Combinations { 35 | stream := c.CreateStream() 36 | stream.WriteListHeader(protocol.TypeI64, 3) 37 | stream.WriteUint64(1) 38 | stream.WriteUint64(2) 39 | stream.WriteUint64(3) 40 | iter := c.CreateIterator(stream.Buffer()) 41 | elemType, length := iter.ReadListHeader() 42 | 
should.Equal(protocol.TypeI64, elemType) 43 | should.Equal(3, length) 44 | should.Equal(uint64(1), iter.ReadUint64()) 45 | should.Equal(uint64(2), iter.ReadUint64()) 46 | should.Equal(uint64(3), iter.ReadUint64()) 47 | } 48 | } 49 | 50 | func Test_skip_list(t *testing.T) { 51 | should := require.New(t) 52 | for _, c := range test.Combinations { 53 | buf, proto := c.CreateProtocol() 54 | proto.WriteListBegin(thrift.I64, 3) 55 | proto.WriteI64(1) 56 | proto.WriteI64(2) 57 | proto.WriteI64(3) 58 | proto.WriteListEnd() 59 | iter := c.CreateIterator(buf.Bytes()) 60 | should.Equal(buf.Bytes(), iter.SkipList(nil)) 61 | } 62 | } 63 | 64 | func Test_unmarshal_general_list(t *testing.T) { 65 | should := require.New(t) 66 | for _, c := range test.Combinations { 67 | buf, proto := c.CreateProtocol() 68 | proto.WriteListBegin(thrift.I64, 3) 69 | proto.WriteI64(1) 70 | proto.WriteI64(2) 71 | proto.WriteI64(3) 72 | proto.WriteListEnd() 73 | var val general.List 74 | should.NoError(c.Unmarshal(buf.Bytes(), &val)) 75 | should.Equal(general.List{int64(1), int64(2), int64(3)}, val) 76 | } 77 | } 78 | 79 | func Test_unmarshal_raw_list(t *testing.T) { 80 | should := require.New(t) 81 | for _, c := range test.Combinations { 82 | buf, proto := c.CreateProtocol() 83 | proto.WriteListBegin(thrift.I64, 3) 84 | proto.WriteI64(1) 85 | proto.WriteI64(2) 86 | proto.WriteI64(3) 87 | proto.WriteListEnd() 88 | var val raw.List 89 | should.NoError(c.Unmarshal(buf.Bytes(), &val)) 90 | should.Equal(3, len(val.Elements)) 91 | should.Equal(protocol.TypeI64, val.ElementType) 92 | iter := c.CreateIterator(val.Elements[0]) 93 | should.Equal(int64(1), iter.ReadInt64()) 94 | } 95 | } 96 | 97 | func Test_unmarshal_list(t *testing.T) { 98 | should := require.New(t) 99 | for _, c := range test.UnmarshalCombinations { 100 | buf, proto := c.CreateProtocol() 101 | proto.WriteListBegin(thrift.I64, 3) 102 | proto.WriteI64(1) 103 | proto.WriteI64(2) 104 | proto.WriteI64(3) 105 | proto.WriteListEnd() 106 | var val 
[]int64 107 | should.NoError(c.Unmarshal(buf.Bytes(), &val)) 108 | should.Equal([]int64{int64(1), int64(2), int64(3)}, val) 109 | } 110 | } 111 | 112 | func Test_marshal_general_list(t *testing.T) { 113 | should := require.New(t) 114 | for _, c := range test.Combinations { 115 | output, err := c.Marshal(general.List{ 116 | int64(1), int64(2), int64(3), 117 | }) 118 | should.NoError(err) 119 | iter := c.CreateIterator(output) 120 | elemType, length := iter.ReadListHeader() 121 | should.Equal(protocol.TypeI64, elemType) 122 | should.Equal(3, length) 123 | should.Equal(uint64(1), iter.ReadUint64()) 124 | should.Equal(uint64(2), iter.ReadUint64()) 125 | should.Equal(uint64(3), iter.ReadUint64()) 126 | } 127 | } 128 | 129 | 130 | func Test_marshal_raw_list(t *testing.T) { 131 | should := require.New(t) 132 | for _, c := range test.Combinations { 133 | buf, proto := c.CreateProtocol() 134 | proto.WriteListBegin(thrift.I64, 3) 135 | proto.WriteI64(1) 136 | proto.WriteI64(2) 137 | proto.WriteI64(3) 138 | proto.WriteListEnd() 139 | var val raw.List 140 | should.NoError(c.Unmarshal(buf.Bytes(), &val)) 141 | output, err := c.Marshal(val) 142 | should.NoError(err) 143 | var generalVal general.List 144 | should.NoError(c.Unmarshal(output, &generalVal)) 145 | should.Equal(general.List{int64(1), int64(2), int64(3)}, generalVal) 146 | } 147 | } 148 | 149 | func Test_marshal_list(t *testing.T) { 150 | should := require.New(t) 151 | for _, c := range test.MarshalCombinations { 152 | output, err := c.Marshal([]int64{1, 2, 3}) 153 | should.NoError(err) 154 | iter := c.CreateIterator(output) 155 | elemType, length := iter.ReadListHeader() 156 | should.Equal(protocol.TypeI64, elemType) 157 | should.Equal(3, length) 158 | should.Equal(uint64(1), iter.ReadUint64()) 159 | should.Equal(uint64(2), iter.ReadUint64()) 160 | should.Equal(uint64(3), iter.ReadUint64()) 161 | } 162 | } 163 | 164 | func Test_marshal_empty_list(t *testing.T) { 165 | should := require.New(t) 166 | for _, c := range 
test.MarshalCombinations { 167 | output, err := c.Marshal([]int64{}) 168 | should.NoError(err) 169 | iter := c.CreateIterator(output) 170 | elemType, length := iter.ReadListHeader() 171 | should.Equal(protocol.TypeI64, elemType) 172 | should.Equal(0, length) 173 | } 174 | } 175 | -------------------------------------------------------------------------------- /test/level_1/map_test.go: -------------------------------------------------------------------------------- 1 | package test 2 | 3 | import ( 4 | "testing" 5 | "github.com/stretchr/testify/require" 6 | "git.apache.org/thrift.git/lib/go/thrift" 7 | "github.com/thrift-iterator/go/protocol" 8 | "github.com/thrift-iterator/go/test" 9 | "github.com/thrift-iterator/go/general" 10 | "github.com/thrift-iterator/go/raw" 11 | ) 12 | 13 | func Test_decode_map_by_iterator(t *testing.T) { 14 | should := require.New(t) 15 | for _, c := range test.Combinations { 16 | buf, proto := c.CreateProtocol() 17 | proto.WriteMapBegin(thrift.STRING, thrift.I64, 3) 18 | proto.WriteString("k1") 19 | proto.WriteI64(1) 20 | proto.WriteString("k2") 21 | proto.WriteI64(2) 22 | proto.WriteString("k3") 23 | proto.WriteI64(3) 24 | proto.WriteMapEnd() 25 | iter := c.CreateIterator(buf.Bytes()) 26 | keyType, elemType, length := iter.ReadMapHeader() 27 | should.Equal(protocol.TypeString, keyType) 28 | should.Equal(protocol.TypeI64, elemType) 29 | should.Equal(3, length) 30 | should.Equal("k1", iter.ReadString()) 31 | should.Equal(uint64(1), iter.ReadUint64()) 32 | should.Equal("k2", iter.ReadString()) 33 | should.Equal(uint64(2), iter.ReadUint64()) 34 | should.Equal("k3", iter.ReadString()) 35 | should.Equal(uint64(3), iter.ReadUint64()) 36 | } 37 | } 38 | 39 | func Test_encode_map_by_stream(t *testing.T) { 40 | should := require.New(t) 41 | for _, c := range test.Combinations { 42 | stream := c.CreateStream() 43 | stream.WriteMapHeader(protocol.TypeString, protocol.TypeI64, 3) 44 | stream.WriteString("k1") 45 | stream.WriteUint64(1) 46 | 
stream.WriteString("k2") 47 | stream.WriteUint64(2) 48 | stream.WriteString("k3") 49 | stream.WriteUint64(3) 50 | iter := c.CreateIterator(stream.Buffer()) 51 | keyType, elemType, length := iter.ReadMapHeader() 52 | should.Equal(protocol.TypeString, keyType) 53 | should.Equal(protocol.TypeI64, elemType) 54 | should.Equal(3, length) 55 | should.Equal("k1", iter.ReadString()) 56 | should.Equal(uint64(1), iter.ReadUint64()) 57 | should.Equal("k2", iter.ReadString()) 58 | should.Equal(uint64(2), iter.ReadUint64()) 59 | should.Equal("k3", iter.ReadString()) 60 | should.Equal(uint64(3), iter.ReadUint64()) 61 | } 62 | } 63 | 64 | func Test_skip_map(t *testing.T) { 65 | should := require.New(t) 66 | for _, c := range test.Combinations { 67 | buf, proto := c.CreateProtocol() 68 | proto.WriteMapBegin(thrift.I32, thrift.I64, 3) 69 | proto.WriteI32(1) 70 | proto.WriteI64(1) 71 | proto.WriteI32(2) 72 | proto.WriteI64(2) 73 | proto.WriteI32(3) 74 | proto.WriteI64(3) 75 | proto.WriteMapEnd() 76 | iter := c.CreateIterator(buf.Bytes()) 77 | should.Equal(buf.Bytes(), iter.SkipMap(nil)) 78 | } 79 | } 80 | 81 | func Test_unmarshal_general_map(t *testing.T) { 82 | should := require.New(t) 83 | for _, c := range test.Combinations { 84 | buf, proto := c.CreateProtocol() 85 | proto.WriteMapBegin(thrift.I32, thrift.I64, 3) 86 | proto.WriteI32(1) 87 | proto.WriteI64(1) 88 | proto.WriteI32(2) 89 | proto.WriteI64(2) 90 | proto.WriteI32(3) 91 | proto.WriteI64(3) 92 | proto.WriteMapEnd() 93 | var val general.Map 94 | should.NoError(c.Unmarshal(buf.Bytes(), &val)) 95 | should.Equal(general.Map{ 96 | int32(1): int64(1), 97 | int32(2): int64(2), 98 | int32(3): int64(3), 99 | }, val) 100 | } 101 | } 102 | 103 | 104 | func Test_unmarshal_raw_map(t *testing.T) { 105 | should := require.New(t) 106 | for _, c := range test.Combinations { 107 | buf, proto := c.CreateProtocol() 108 | proto.WriteMapBegin(thrift.I32, thrift.I64, 3) 109 | proto.WriteI32(1) 110 | proto.WriteI64(1) 111 | proto.WriteI32(2) 112 
| proto.WriteI64(2) 113 | proto.WriteI32(3) 114 | proto.WriteI64(3) 115 | proto.WriteMapEnd() 116 | var val raw.Map 117 | should.NoError(c.Unmarshal(buf.Bytes(), &val)) 118 | should.Equal(3, len(val.Entries)) 119 | should.Equal(protocol.TypeI32, val.KeyType) 120 | should.Equal(protocol.TypeI64, val.ElementType) 121 | iter := c.CreateIterator(val.Entries[int32(1)].Element) 122 | should.Equal(int64(1), iter.ReadInt64()) 123 | } 124 | } 125 | 126 | func Test_unmarshal_map(t *testing.T) { 127 | should := require.New(t) 128 | for _, c := range test.UnmarshalCombinations { 129 | buf, proto := c.CreateProtocol() 130 | proto.WriteMapBegin(thrift.I32, thrift.I64, 3) 131 | proto.WriteI32(1) 132 | proto.WriteI64(1) 133 | proto.WriteI32(2) 134 | proto.WriteI64(2) 135 | proto.WriteI32(3) 136 | proto.WriteI64(3) 137 | proto.WriteMapEnd() 138 | val := map[int32]int64{} 139 | should.NoError(c.Unmarshal(buf.Bytes(), &val)) 140 | should.Equal(map[int32]int64{ 141 | int32(1): int64(1), 142 | int32(2): int64(2), 143 | int32(3): int64(3), 144 | }, val) 145 | } 146 | } 147 | 148 | func Test_marshal_general_map(t *testing.T) { 149 | should := require.New(t) 150 | for _, c := range test.Combinations { 151 | m := general.Map{ 152 | int32(1): int64(1), 153 | int32(2): int64(2), 154 | int32(3): int64(3), 155 | } 156 | 157 | output, err := c.Marshal(m) 158 | should.NoError(err) 159 | output1, err := c.Marshal(&m) 160 | should.NoError(err) 161 | var val, val1 general.Map 162 | should.NoError(c.Unmarshal(output, &val)) 163 | should.NoError(c.Unmarshal(output1, &val1)) 164 | should.Equal(val, val1) 165 | should.Equal(general.Map{ 166 | int32(1): int64(1), 167 | int32(2): int64(2), 168 | int32(3): int64(3), 169 | }, val) 170 | } 171 | } 172 | 173 | 174 | func Test_marshal_raw_map(t *testing.T) { 175 | should := require.New(t) 176 | for _, c := range test.Combinations { 177 | buf, proto := c.CreateProtocol() 178 | proto.WriteMapBegin(thrift.I32, thrift.I64, 3) 179 | proto.WriteI32(1) 180 | 
proto.WriteI64(1) 181 | proto.WriteI32(2) 182 | proto.WriteI64(2) 183 | proto.WriteI32(3) 184 | proto.WriteI64(3) 185 | proto.WriteMapEnd() 186 | var val raw.Map 187 | should.NoError(c.Unmarshal(buf.Bytes(), &val)) 188 | 189 | output, err := c.Marshal(val) 190 | should.NoError(err) 191 | output1, err := c.Marshal(&val) 192 | should.NoError(err) 193 | var generalVal, generalVal1 general.Map 194 | should.NoError(c.Unmarshal(output, &generalVal)) 195 | should.NoError(c.Unmarshal(output1, &generalVal1)) 196 | should.Equal(generalVal, generalVal1) 197 | should.Equal(general.Map{ 198 | int32(1): int64(1), 199 | int32(2): int64(2), 200 | int32(3): int64(3), 201 | }, generalVal) 202 | } 203 | } 204 | 205 | func Test_marshal_map(t *testing.T) { 206 | should := require.New(t) 207 | for _, c := range test.MarshalCombinations { 208 | m := map[string]int64{ 209 | "k1": int64(1), 210 | "k2": int64(2), 211 | "k3": int64(3), 212 | } 213 | 214 | output, err := c.Marshal(m) 215 | should.NoError(err) 216 | output1, err := c.Marshal(&m) 217 | should.NoError(err) 218 | var val, val1 general.Map 219 | should.NoError(c.Unmarshal(output, &val)) 220 | should.NoError(c.Unmarshal(output1, &val1)) 221 | should.Equal(val, val1) 222 | should.Equal(general.Map{ 223 | "k1": int64(1), 224 | "k2": int64(2), 225 | "k3": int64(3), 226 | }, val) 227 | } 228 | } 229 | 230 | func Test_marshal_empty_map(t *testing.T) { 231 | should := require.New(t) 232 | for _, c := range test.MarshalCombinations { 233 | output, err := c.Marshal(map[string]int64{}) 234 | should.NoError(err) 235 | var val general.Map 236 | should.NoError(c.Unmarshal(output, &val)) 237 | should.Equal(general.Map{}, val) 238 | } 239 | } 240 | -------------------------------------------------------------------------------- /test/level_1/pointer_test.go: -------------------------------------------------------------------------------- 1 | package test 2 | 3 | import ( 4 | "testing" 5 | "github.com/stretchr/testify/require" 6 | 
"github.com/thrift-iterator/go/test" 7 | ) 8 | 9 | func Test_unmarshal_ptr_int64(t *testing.T) { 10 | should := require.New(t) 11 | for _, c := range test.UnmarshalCombinations { 12 | buf, proto := c.CreateProtocol() 13 | proto.WriteI64(2) // a bare i64 value; no list was begun, so there is no WriteListEnd 14 | var val *int64 15 | should.NoError(c.Unmarshal(buf.Bytes(), &val)) 16 | should.Equal(int64(2), *val) 17 | } 18 | } 19 | 20 | func Test_marshal_ptr_int64(t *testing.T) { 21 | should := require.New(t) 22 | for _, c := range test.MarshalCombinations { 23 | val := int64(2) 24 | output, err := c.Marshal(&val) 25 | should.NoError(err) 26 | iter := c.CreateIterator(output) 27 | should.Equal(int64(2), iter.ReadInt64()) 28 | } 29 | } -------------------------------------------------------------------------------- /test/level_1/string_test.go: -------------------------------------------------------------------------------- 1 | package test 2 | 3 | import ( 4 | "testing" 5 | "github.com/stretchr/testify/require" 6 | "github.com/thrift-iterator/go/test" 7 | ) 8 | 9 | func Test_decode_string(t *testing.T) { 10 | should := require.New(t) 11 | for _, c := range test.Combinations { 12 | buf, proto := c.CreateProtocol() 13 | proto.WriteString("hello") 14 | iter := c.CreateIterator(buf.Bytes()) 15 | should.Equal("hello", iter.ReadString()) 16 | } 17 | } 18 | 19 | func Test_unmarshal_string(t *testing.T) { 20 | should := require.New(t) 21 | for _, c := range test.UnmarshalCombinations { 22 | buf, proto := c.CreateProtocol() 23 | proto.WriteString("hello") 24 | var val string 25 | should.NoError(c.Unmarshal(buf.Bytes(), &val)) 26 | should.Equal("hello", val) 27 | } 28 | } 29 | 30 | func Test_encode_string(t *testing.T) { 31 | should := require.New(t) 32 | for _, c := range test.Combinations { 33 | stream := c.CreateStream() 34 | stream.WriteString("hello") 35 | iter := c.CreateIterator(stream.Buffer()) 36 | should.Equal("hello", iter.ReadString()) 37 | } 38 | } 39 | 40 | func Test_marshal_string(t *testing.T) { 41 | should
:= require.New(t) 42 | for _, c := range test.MarshalCombinations { 43 | output, err := c.Marshal("hello") 44 | should.NoError(err) 45 | iter := c.CreateIterator(output) 46 | should.Equal("hello", iter.ReadString()) 47 | } 48 | } 49 | -------------------------------------------------------------------------------- /test/level_1/struct_test/TestObject.go: -------------------------------------------------------------------------------- 1 | package struct_test 2 | 3 | type TestObject struct { 4 | Field1 int64 `thrift:",1"` 5 | } -------------------------------------------------------------------------------- /test/level_2/level_2_test.go: -------------------------------------------------------------------------------- 1 | package test 2 | 3 | import "github.com/v2pro/wombat/generic" 4 | 5 | func init() { 6 | generic.DynamicCompilationEnabled = true 7 | } -------------------------------------------------------------------------------- /test/level_2/list_of_list_test.go: -------------------------------------------------------------------------------- 1 | package test 2 | 3 | import ( 4 | "testing" 5 | "github.com/stretchr/testify/require" 6 | "git.apache.org/thrift.git/lib/go/thrift" 7 | "github.com/thrift-iterator/go/test" 8 | "github.com/thrift-iterator/go/general" 9 | ) 10 | 11 | func Test_skip_list_of_list(t *testing.T) { 12 | should := require.New(t) 13 | for _, c := range test.Combinations { 14 | buf, proto := c.CreateProtocol() 15 | proto.WriteListBegin(thrift.LIST, 2) 16 | proto.WriteListBegin(thrift.I64, 1) 17 | proto.WriteI64(1) 18 | proto.WriteListEnd() 19 | proto.WriteListBegin(thrift.I64, 1) 20 | proto.WriteI64(2) 21 | proto.WriteListEnd() 22 | proto.WriteListEnd() 23 | iter := c.CreateIterator(buf.Bytes()) 24 | should.Equal(buf.Bytes(), iter.SkipList(nil)) 25 | } 26 | } 27 | 28 | func Test_unmarshal_general_list_of_list(t *testing.T) { 29 | should := require.New(t) 30 | for _, c := range test.Combinations { 31 | buf, proto := c.CreateProtocol() 32 |
proto.WriteListBegin(thrift.LIST, 2) 33 | proto.WriteListBegin(thrift.I64, 1) 34 | proto.WriteI64(1) 35 | proto.WriteListEnd() 36 | proto.WriteListBegin(thrift.I64, 1) 37 | proto.WriteI64(2) 38 | proto.WriteListEnd() 39 | proto.WriteListEnd() 40 | var val general.List 41 | should.NoError(c.Unmarshal(buf.Bytes(), &val)) 42 | should.Equal(general.List{int64(1)}, val[0]) 43 | } 44 | } 45 | 46 | func Test_unmarshal_list_of_general_list(t *testing.T) { 47 | should := require.New(t) 48 | for _, c := range test.UnmarshalCombinations { 49 | buf, proto := c.CreateProtocol() 50 | proto.WriteListBegin(thrift.LIST, 2) 51 | proto.WriteListBegin(thrift.I64, 1) 52 | proto.WriteI64(1) 53 | proto.WriteListEnd() 54 | proto.WriteListBegin(thrift.I64, 1) 55 | proto.WriteI64(2) 56 | proto.WriteListEnd() 57 | proto.WriteListEnd() 58 | var val []general.List 59 | should.NoError(c.Unmarshal(buf.Bytes(), &val)) 60 | should.Equal(general.List{int64(1)}, val[0]) 61 | } 62 | } 63 | 64 | func Test_unmarshal_list_of_list(t *testing.T) { 65 | should := require.New(t) 66 | for _, c := range test.UnmarshalCombinations { 67 | buf, proto := c.CreateProtocol() 68 | proto.WriteListBegin(thrift.LIST, 2) 69 | proto.WriteListBegin(thrift.I64, 1) 70 | proto.WriteI64(1) 71 | proto.WriteListEnd() 72 | proto.WriteListBegin(thrift.I64, 1) 73 | proto.WriteI64(2) 74 | proto.WriteListEnd() 75 | proto.WriteListEnd() 76 | var val [][]int64 77 | should.NoError(c.Unmarshal(buf.Bytes(), &val)) 78 | should.Equal([][]int64{ 79 | {1}, {2}, 80 | }, val) 81 | } 82 | } 83 | 84 | func Test_marshal_general_list_of_list(t *testing.T) { 85 | should := require.New(t) 86 | for _, c := range test.Combinations { 87 | lst := general.List{ 88 | general.List{ 89 | int64(1), 90 | }, 91 | general.List{ 92 | int64(2), 93 | }, 94 | } 95 | 96 | output, err := c.Marshal(lst) 97 | should.NoError(err) 98 | output1, err := c.Marshal(&lst) 99 | should.NoError(err) 100 | should.Equal(output, output1) 101 | var val general.List 102 |
should.NoError(c.Unmarshal(output, &val)) 103 | should.Equal(general.List{int64(1)}, val[0]) 104 | } 105 | } 106 | 107 | func Test_marshal_list_of_general_list(t *testing.T) { 108 | should := require.New(t) 109 | for _, c := range test.MarshalCombinations { 110 | lst := []general.List{ 111 | { 112 | int64(1), 113 | }, 114 | { 115 | int64(2), 116 | }, 117 | } 118 | 119 | output, err := c.Marshal(lst) 120 | should.NoError(err) 121 | output1, err := c.Marshal(&lst) 122 | should.NoError(err) 123 | should.Equal(output, output1) 124 | var val general.List 125 | should.NoError(c.Unmarshal(output, &val)) 126 | should.Equal(general.List{int64(1)}, val[0]) 127 | } 128 | } 129 | 130 | func Test_marshal_list_of_list(t *testing.T) { 131 | should := require.New(t) 132 | for _, c := range test.MarshalCombinations { 133 | lst := [][]int64{ 134 | {1}, {2}, 135 | } 136 | 137 | output, err := c.Marshal(lst) 138 | should.NoError(err) 139 | output1, err := c.Marshal(&lst) 140 | should.NoError(err) 141 | should.Equal(output, output1) 142 | var val general.List 143 | should.NoError(c.Unmarshal(output, &val)) 144 | should.Equal(general.List{int64(1)}, val[0]) 145 | } 146 | } 147 | -------------------------------------------------------------------------------- /test/level_2/list_of_map_test.go: -------------------------------------------------------------------------------- 1 | package test 2 | 3 | import ( 4 | "testing" 5 | "github.com/stretchr/testify/require" 6 | "git.apache.org/thrift.git/lib/go/thrift" 7 | "github.com/thrift-iterator/go/test" 8 | "github.com/thrift-iterator/go/general" 9 | ) 10 | 11 | func Test_skip_list_of_map(t *testing.T) { 12 | should := require.New(t) 13 | for _, c := range test.Combinations { 14 | buf, proto := c.CreateProtocol() 15 | proto.WriteListBegin(thrift.MAP, 2) 16 | proto.WriteMapBegin(thrift.I32, thrift.I64, 1) 17 | proto.WriteI32(1) 18 | proto.WriteI64(1) 19 | proto.WriteMapEnd() 20 | proto.WriteMapBegin(thrift.I32, thrift.I64, 1) 21 |
proto.WriteI32(2) 22 | proto.WriteI64(2) 23 | proto.WriteMapEnd() 24 | proto.WriteListEnd() 25 | iter := c.CreateIterator(buf.Bytes()) 26 | should.Equal(buf.Bytes(), iter.SkipList(nil)) 27 | } 28 | } 29 | 30 | func Test_unmarshal_general_list_of_map(t *testing.T) { 31 | should := require.New(t) 32 | for _, c := range test.Combinations { 33 | buf, proto := c.CreateProtocol() 34 | proto.WriteListBegin(thrift.MAP, 2) 35 | proto.WriteMapBegin(thrift.I32, thrift.I64, 1) 36 | proto.WriteI32(1) 37 | proto.WriteI64(1) 38 | proto.WriteMapEnd() 39 | proto.WriteMapBegin(thrift.I32, thrift.I64, 1) 40 | proto.WriteI32(2) 41 | proto.WriteI64(2) 42 | proto.WriteMapEnd() 43 | proto.WriteListEnd() 44 | var val general.List 45 | should.NoError(c.Unmarshal(buf.Bytes(), &val)) 46 | should.Equal(general.Map{ 47 | int32(1): int64(1), 48 | }, val[0]) 49 | should.Equal(int64(1), val.Get(0, int32(1))) 50 | } 51 | } 52 | 53 | func Test_unmarshal_list_of_map(t *testing.T) { 54 | should := require.New(t) 55 | for _, c := range test.UnmarshalCombinations { 56 | buf, proto := c.CreateProtocol() 57 | proto.WriteListBegin(thrift.MAP, 2) 58 | proto.WriteMapBegin(thrift.I32, thrift.I64, 1) 59 | proto.WriteI32(1) 60 | proto.WriteI64(1) 61 | proto.WriteMapEnd() 62 | proto.WriteMapBegin(thrift.I32, thrift.I64, 1) 63 | proto.WriteI32(2) 64 | proto.WriteI64(2) 65 | proto.WriteMapEnd() 66 | proto.WriteListEnd() 67 | var val []map[int32]int64 68 | should.NoError(c.Unmarshal(buf.Bytes(), &val)) 69 | should.Equal([]map[int32]int64{ 70 | {1: 1}, {2: 2}, 71 | }, val) 72 | } 73 | } 74 | 75 | func Test_marshal_general_list_of_map(t *testing.T) { 76 | should := require.New(t) 77 | for _, c := range test.Combinations { 78 | lst := general.List{ 79 | general.Map{ 80 | int32(1): int64(1), 81 | }, 82 | general.Map{ 83 | int32(2): int64(2), 84 | }, 85 | } 86 | 87 | output, err := c.Marshal(lst) 88 | should.NoError(err) 89 | output1, err := c.Marshal(&lst) 90 | should.NoError(err) 91 | should.Equal(output, output1) 92 |
var val []map[int32]int64 93 | should.NoError(c.Unmarshal(output, &val)) 94 | should.Equal([]map[int32]int64{ 95 | {1: 1}, {2: 2}, 96 | }, val) 97 | } 98 | } 99 | 100 | func Test_marshal_list_of_map(t *testing.T) { 101 | should := require.New(t) 102 | for _, c := range test.MarshalCombinations { 103 | lst := []map[int32]int64{ 104 | {1: 1}, {2: 2}, 105 | } 106 | 107 | output, err := c.Marshal(lst) 108 | should.NoError(err) 109 | output1, err := c.Marshal(&lst) 110 | should.NoError(err) // check the error before comparing, so a failed Marshal is reported as such 111 | should.Equal(output, output1) 112 | var val []map[int32]int64 113 | should.NoError(c.Unmarshal(output, &val)) 114 | should.Equal([]map[int32]int64{ 115 | {1: 1}, {2: 2}, 116 | }, val) 117 | } 118 | } 119 | -------------------------------------------------------------------------------- /test/level_2/list_of_string_test.go: -------------------------------------------------------------------------------- 1 | package test 2 | 3 | import ( 4 | "testing" 5 | "github.com/stretchr/testify/require" 6 | "git.apache.org/thrift.git/lib/go/thrift" 7 | "github.com/thrift-iterator/go/test" 8 | "github.com/thrift-iterator/go/general" 9 | ) 10 | 11 | func Test_skip_list_of_string(t *testing.T) { 12 | should := require.New(t) 13 | for _, c := range test.Combinations { 14 | buf, proto := c.CreateProtocol() 15 | proto.WriteListBegin(thrift.STRING, 3) 16 | proto.WriteString("a") 17 | proto.WriteString("b") 18 | proto.WriteString("c") 19 | proto.WriteListEnd() 20 | iter := c.CreateIterator(buf.Bytes()) 21 | should.Equal(buf.Bytes(), iter.SkipList(nil)) 22 | } 23 | } 24 | 25 | func Test_unmarshal_general_list_of_string(t *testing.T) { 26 | should := require.New(t) 27 | for _, c := range test.Combinations { 28 | buf, proto := c.CreateProtocol() 29 | proto.WriteListBegin(thrift.STRING, 3) 30 | proto.WriteString("a") 31 | proto.WriteString("b") 32 | proto.WriteString("c") 33 | proto.WriteListEnd() 34 | var val general.List 35 | should.NoError(c.Unmarshal(buf.Bytes(), &val)) 36 |
should.Equal(general.List{"a", "b", "c"}, val) 37 | } 38 | } 39 | 40 | func Test_unmarshal_list_of_string(t *testing.T) { 41 | should := require.New(t) 42 | for _, c := range test.UnmarshalCombinations { 43 | buf, proto := c.CreateProtocol() 44 | proto.WriteListBegin(thrift.STRING, 3) 45 | proto.WriteString("a") 46 | proto.WriteString("b") 47 | proto.WriteString("c") 48 | proto.WriteListEnd() 49 | var val []string 50 | should.NoError(c.Unmarshal(buf.Bytes(), &val)) 51 | should.Equal([]string{ 52 | "a", "b", "c", 53 | }, val) 54 | } 55 | } 56 | 57 | func Test_marshal_general_list_of_string(t *testing.T) { 58 | should := require.New(t) 59 | for _, c := range test.Combinations { 60 | lst := general.List{ 61 | "a", "b", "c", 62 | } 63 | 64 | output, err := c.Marshal(lst) 65 | should.NoError(err) 66 | output1, err := c.Marshal(&lst) 67 | should.NoError(err) 68 | should.Equal(output, output1) 69 | var val []string 70 | should.NoError(c.Unmarshal(output, &val)) 71 | should.Equal([]string{ 72 | "a", "b", "c", 73 | }, val) 74 | } 75 | } 76 | 77 | func Test_marshal_list_of_string(t *testing.T) { 78 | should := require.New(t) 79 | for _, c := range test.MarshalCombinations { 80 | lst := []string{"a", "b", "c"} 81 | 82 | output, err := c.Marshal(lst) 83 | should.NoError(err) 84 | output1, err := c.Marshal(&lst) 85 | should.NoError(err) // check the error before comparing, so a failed Marshal is reported as such 86 | should.Equal(output, output1) 87 | var val []string 88 | should.NoError(c.Unmarshal(output, &val)) 89 | should.Equal([]string{ 90 | "a", "b", "c", 91 | }, val) 92 | } 93 | } 94 | -------------------------------------------------------------------------------- /test/level_2/list_of_struct_test.go: -------------------------------------------------------------------------------- 1 | package test 2 | 3 | import ( 4 | "testing" 5 | "github.com/stretchr/testify/require" 6 | "git.apache.org/thrift.git/lib/go/thrift" 7 | "github.com/thrift-iterator/go/protocol" 8 | "github.com/thrift-iterator/go/test" 9 |
"github.com/thrift-iterator/go/test/level_2/list_of_struct_test" 10 | "github.com/thrift-iterator/go/general" 11 | ) 12 | 13 | func Test_skip_list_of_struct(t *testing.T) { 14 | should := require.New(t) 15 | for _, c := range test.Combinations { 16 | buf, proto := c.CreateProtocol() 17 | proto.WriteListBegin(thrift.STRUCT, 2) 18 | proto.WriteStructBegin("hello") 19 | proto.WriteFieldBegin("field1", thrift.I64, 1) 20 | proto.WriteI64(1024) 21 | proto.WriteFieldEnd() 22 | proto.WriteFieldStop() 23 | proto.WriteStructEnd() 24 | proto.WriteStructBegin("hello") 25 | proto.WriteFieldBegin("field1", thrift.I64, 1) 26 | proto.WriteI64(1024) 27 | proto.WriteFieldEnd() 28 | proto.WriteFieldStop() 29 | proto.WriteStructEnd() 30 | proto.WriteListEnd() 31 | iter := c.CreateIterator(buf.Bytes()) 32 | should.Equal(buf.Bytes(), iter.SkipList(nil)) 33 | } 34 | } 35 | 36 | func Test_unmarshal_general_list_of_struct(t *testing.T) { 37 | should := require.New(t) 38 | for _, c := range test.Combinations { 39 | buf, proto := c.CreateProtocol() 40 | proto.WriteListBegin(thrift.STRUCT, 2) 41 | proto.WriteStructBegin("hello") 42 | proto.WriteFieldBegin("field1", thrift.I64, 1) 43 | proto.WriteI64(1024) 44 | proto.WriteFieldEnd() 45 | proto.WriteFieldStop() 46 | proto.WriteStructEnd() 47 | proto.WriteStructBegin("hello") 48 | proto.WriteFieldBegin("field1", thrift.I64, 1) 49 | proto.WriteI64(1024) 50 | proto.WriteFieldEnd() 51 | proto.WriteFieldStop() 52 | proto.WriteStructEnd() 53 | proto.WriteListEnd() 54 | var val general.List 55 | should.NoError(c.Unmarshal(buf.Bytes(), &val)) 56 | should.Equal(general.Struct{ 57 | protocol.FieldId(1): int64(1024), 58 | }, val[0]) 59 | } 60 | } 61 | 62 | func Test_unmarshal_list_of_struct(t *testing.T) { 63 | should := require.New(t) 64 | for _, c := range test.UnmarshalCombinations { 65 | buf, proto := c.CreateProtocol() 66 | proto.WriteListBegin(thrift.STRUCT, 2) 67 | proto.WriteStructBegin("hello") 68 | proto.WriteFieldBegin("field1", thrift.I64, 1)
69 | proto.WriteI64(1024) 70 | proto.WriteFieldEnd() 71 | proto.WriteFieldStop() 72 | proto.WriteStructEnd() 73 | proto.WriteStructBegin("hello") 74 | proto.WriteFieldBegin("field1", thrift.I64, 1) 75 | proto.WriteI64(1024) 76 | proto.WriteFieldEnd() 77 | proto.WriteFieldStop() 78 | proto.WriteStructEnd() 79 | proto.WriteListEnd() 80 | var val []list_of_struct_test.TestObject 81 | should.NoError(c.Unmarshal(buf.Bytes(), &val)) 82 | should.Equal([]list_of_struct_test.TestObject{ 83 | {1024}, {1024}, 84 | }, val) 85 | } 86 | } 87 | 88 | func Test_marshal_general_list_of_struct(t *testing.T) { 89 | should := require.New(t) 90 | for _, c := range test.Combinations { 91 | lst := general.List{ 92 | general.Struct{ 93 | protocol.FieldId(1): int64(1024), 94 | }, 95 | general.Struct{ 96 | protocol.FieldId(1): int64(1024), 97 | }, 98 | } 99 | 100 | output, err := c.Marshal(lst) 101 | should.NoError(err) 102 | output1, err := c.Marshal(&lst) 103 | should.NoError(err) 104 | should.Equal(output, output1) 105 | var val general.List 106 | should.NoError(c.Unmarshal(output, &val)) 107 | should.Equal(general.Struct{ 108 | protocol.FieldId(1): int64(1024), 109 | }, val[0]) 110 | should.Equal(general.Struct{ 111 | protocol.FieldId(1): int64(1024), 112 | }, val[1]) 113 | } 114 | } 115 | 116 | func Test_marshal_list_of_struct(t *testing.T) { 117 | should := require.New(t) 118 | for _, c := range test.MarshalCombinations { 119 | lst := []list_of_struct_test.TestObject{ 120 | {1024}, {1024}, 121 | } 122 | 123 | output, err := c.Marshal(lst) 124 | should.NoError(err) 125 | output1, err := c.Marshal(&lst) 126 | should.NoError(err) // check the error before comparing, so a failed Marshal is reported as such 127 | should.Equal(output, output1) 128 | var val general.List 129 | should.NoError(c.Unmarshal(output, &val)) 130 | should.Equal(general.Struct{ 131 | protocol.FieldId(1): int64(1024), 132 | }, val[0]) 133 | should.Equal(general.Struct{ 134 | protocol.FieldId(1): int64(1024), 135 | }, val[1]) 136 | } 137 | } 138 |
-------------------------------------------------------------------------------- /test/level_2/list_of_struct_test/TestObject.go: -------------------------------------------------------------------------------- 1 | package list_of_struct_test 2 | 3 | type TestObject struct { 4 | Field1 int64 `thrift:",1"` 5 | } -------------------------------------------------------------------------------- /test/level_2/map_of_list_test.go: -------------------------------------------------------------------------------- 1 | package test 2 | 3 | import ( 4 | "testing" 5 | "github.com/stretchr/testify/require" 6 | "git.apache.org/thrift.git/lib/go/thrift" 7 | "github.com/thrift-iterator/go/test" 8 | "github.com/thrift-iterator/go/general" 9 | ) 10 | 11 | func Test_skip_map_of_list(t *testing.T) { 12 | should := require.New(t) 13 | for _, c := range test.Combinations { 14 | buf, proto := c.CreateProtocol() 15 | proto.WriteMapBegin(thrift.I64, thrift.LIST, 1) 16 | proto.WriteI64(1) 17 | proto.WriteListBegin(thrift.I64, 1) 18 | proto.WriteI64(1) 19 | proto.WriteListEnd() 20 | proto.WriteMapEnd() 21 | iter := c.CreateIterator(buf.Bytes()) 22 | should.Equal(buf.Bytes(), iter.SkipMap(nil)) 23 | } 24 | } 25 | 26 | func Test_unmarshal_general_map_of_list(t *testing.T) { 27 | should := require.New(t) 28 | for _, c := range test.Combinations { 29 | buf, proto := c.CreateProtocol() 30 | proto.WriteMapBegin(thrift.I64, thrift.LIST, 1) 31 | proto.WriteI64(1) 32 | proto.WriteListBegin(thrift.I64, 1) 33 | proto.WriteI64(1) 34 | proto.WriteListEnd() 35 | proto.WriteMapEnd() 36 | var val general.Map 37 | should.NoError(c.Unmarshal(buf.Bytes(), &val)) 38 | should.Equal(general.List{ 39 | int64(1), 40 | }, val[int64(1)]) 41 | } 42 | } 43 | 44 | func Test_unmarshal_map_of_list(t *testing.T) { 45 | should := require.New(t) 46 | for _, c := range test.UnmarshalCombinations { 47 | buf, proto := c.CreateProtocol() 48 | proto.WriteMapBegin(thrift.I64, thrift.LIST, 1) 49 | proto.WriteI64(1) 50 |
proto.WriteListBegin(thrift.I64, 1) 51 | proto.WriteI64(1) 52 | proto.WriteListEnd() 53 | proto.WriteMapEnd() 54 | var val map[int64][]int64 55 | should.NoError(c.Unmarshal(buf.Bytes(), &val)) 56 | should.Equal(map[int64][]int64{ 57 | 1: {1}, 58 | }, val) 59 | } 60 | } 61 | 62 | func Test_marshal_general_map_of_list(t *testing.T) { 63 | should := require.New(t) 64 | for _, c := range test.Combinations { 65 | m := general.Map{ 66 | int64(1): general.List{int64(1)}, 67 | } 68 | 69 | output, err := c.Marshal(m) 70 | should.NoError(err) 71 | output1, err := c.Marshal(&m) 72 | should.NoError(err) 73 | should.Equal(output, output1) 74 | var val general.Map 75 | should.NoError(c.Unmarshal(output, &val)) 76 | should.Equal(general.List{ 77 | int64(1), 78 | }, val[int64(1)]) 79 | } 80 | } 81 | 82 | func Test_marshal_map_of_list(t *testing.T) { 83 | should := require.New(t) 84 | for _, c := range test.MarshalCombinations { 85 | m := map[int64][]int64{ 86 | 1: {1}, 87 | } 88 | 89 | output, err := c.Marshal(m) 90 | should.NoError(err) 91 | output1, err := c.Marshal(&m) 92 | should.NoError(err) // check the error before comparing, so a failed Marshal is reported as such 93 | should.Equal(output, output1) 94 | var val general.Map 95 | should.NoError(c.Unmarshal(output, &val)) 96 | should.Equal(general.List{ 97 | int64(1), 98 | }, val[int64(1)]) 99 | } 100 | } 101 | -------------------------------------------------------------------------------- /test/level_2/map_of_map_test.go: -------------------------------------------------------------------------------- 1 | package test 2 | 3 | import ( 4 | "testing" 5 | "github.com/stretchr/testify/require" 6 | "git.apache.org/thrift.git/lib/go/thrift" 7 | "github.com/thrift-iterator/go/test" 8 | "github.com/thrift-iterator/go/general" 9 | ) 10 | 11 | func Test_skip_map_of_map(t *testing.T) { 12 | should := require.New(t) 13 | for _, c := range test.Combinations { 14 | buf, proto := c.CreateProtocol() 15 | proto.WriteMapBegin(thrift.I64, thrift.MAP, 1) 16 | proto.WriteI64(1) 17 | 18 |
proto.WriteMapBegin(thrift.STRING, thrift.I64, 1) 19 | proto.WriteString("k1") 20 | proto.WriteI64(1) 21 | proto.WriteMapEnd() 22 | 23 | proto.WriteMapEnd() 24 | iter := c.CreateIterator(buf.Bytes()) 25 | should.Equal(buf.Bytes(), iter.SkipMap(nil)) 26 | } 27 | } 28 | 29 | func Test_unmarshal_general_map_of_map(t *testing.T) { 30 | should := require.New(t) 31 | for _, c := range test.Combinations { 32 | buf, proto := c.CreateProtocol() 33 | proto.WriteMapBegin(thrift.I64, thrift.MAP, 1) 34 | proto.WriteI64(1) 35 | 36 | proto.WriteMapBegin(thrift.STRING, thrift.I64, 1) 37 | proto.WriteString("k1") 38 | proto.WriteI64(1) 39 | proto.WriteMapEnd() 40 | 41 | proto.WriteMapEnd() 42 | var val general.Map 43 | should.NoError(c.Unmarshal(buf.Bytes(), &val)) 44 | should.Equal(general.Map{ 45 | "k1": int64(1), 46 | }, val[int64(1)]) 47 | } 48 | } 49 | 50 | func Test_marshal_general_map_of_map(t *testing.T) { 51 | should := require.New(t) 52 | for _, c := range test.Combinations { 53 | m := general.Map{ 54 | int64(1): general.Map{ 55 | "k1": int64(1), 56 | }, 57 | } 58 | 59 | output, err := c.Marshal(m) 60 | should.NoError(err) 61 | output1, err := c.Marshal(&m) 62 | should.NoError(err) 63 | should.Equal(output, output1) 64 | var val general.Map 65 | should.NoError(c.Unmarshal(output, &val)) 66 | should.Equal(general.Map{ 67 | "k1": int64(1), 68 | }, val[int64(1)]) 69 | } 70 | } 71 | 72 | func Test_marshal_map_of_map(t *testing.T) { 73 | should := require.New(t) 74 | for _, c := range test.MarshalCombinations { 75 | m := map[int64]map[string]int64{ 76 | 1: {"k1": 1}, 77 | } 78 | 79 | output, err := c.Marshal(m) 80 | should.NoError(err) 81 | output1, err := c.Marshal(&m) 82 | should.NoError(err) 83 | should.Equal(output, output1) 84 | var val general.Map 85 | should.NoError(c.Unmarshal(output, &val)) 86 | should.Equal(general.Map{ 87 | "k1": int64(1), 88 | }, val[int64(1)]) 89 | } 90 | } --------------------------------------------------------------------------------
/test/level_2/map_of_string_test.go: -------------------------------------------------------------------------------- 1 | package test 2 | 3 | import ( 4 | "testing" 5 | "github.com/stretchr/testify/require" 6 | "git.apache.org/thrift.git/lib/go/thrift" 7 | "github.com/thrift-iterator/go/test" 8 | "github.com/thrift-iterator/go/general" 9 | ) 10 | 11 | func Test_skip_map_of_string_key(t *testing.T) { 12 | should := require.New(t) 13 | for _, c := range test.Combinations { 14 | buf, proto := c.CreateProtocol() 15 | proto.WriteMapBegin(thrift.STRING, thrift.I64, 1) 16 | proto.WriteString("1") 17 | proto.WriteI64(1) 18 | proto.WriteMapEnd() 19 | iter := c.CreateIterator(buf.Bytes()) 20 | should.Equal(buf.Bytes(), iter.SkipMap(nil)) 21 | } 22 | } 23 | 24 | func Test_skip_map_of_string_elem(t *testing.T) { 25 | should := require.New(t) 26 | for _, c := range test.Combinations { 27 | buf, proto := c.CreateProtocol() 28 | proto.WriteMapBegin(thrift.I64, thrift.STRING, 1) 29 | proto.WriteI64(1) 30 | proto.WriteString("1") 31 | proto.WriteMapEnd() 32 | iter := c.CreateIterator(buf.Bytes()) 33 | should.Equal(buf.Bytes(), iter.SkipMap(nil)) 34 | } 35 | } 36 | 37 | func Test_unmarshal_general_map_of_string_key(t *testing.T) { 38 | should := require.New(t) 39 | for _, c := range test.Combinations { 40 | buf, proto := c.CreateProtocol() 41 | proto.WriteMapBegin(thrift.STRING, thrift.I64, 1) 42 | proto.WriteString("1") 43 | proto.WriteI64(1) 44 | proto.WriteMapEnd() 45 | var val general.Map 46 | should.NoError(c.Unmarshal(buf.Bytes(), &val)) 47 | should.Equal(general.Map{ 48 | "1": int64(1), 49 | }, val) 50 | } 51 | } 52 | 53 | func Test_unmarshal_map_of_string_key(t *testing.T) { 54 | should := require.New(t) 55 | for _, c := range test.UnmarshalCombinations { 56 | buf, proto := c.CreateProtocol() 57 | proto.WriteMapBegin(thrift.STRING, thrift.I64, 1) 58 | proto.WriteString("1") 59 | proto.WriteI64(1) 60 | proto.WriteMapEnd() 61 | var val map[string]int64 62 | 
should.NoError(c.Unmarshal(buf.Bytes(), &val)) 63 | should.Equal(map[string]int64{ 64 | "1": 1, 65 | }, val) 66 | } 67 | } 68 | 69 | func Test_marshal_general_map_of_string_key(t *testing.T) { 70 | should := require.New(t) 71 | for _, c := range test.Combinations { 72 | m := general.Map{ 73 | "1": int64(1), 74 | } 75 | 76 | output, err := c.Marshal(m) 77 | should.NoError(err) 78 | output1, err := c.Marshal(&m) 79 | should.NoError(err) 80 | should.Equal(output, output1) 81 | var val general.Map 82 | should.NoError(c.Unmarshal(output, &val)) 83 | should.Equal(general.Map{ 84 | "1": int64(1), 85 | }, val) 86 | } 87 | } 88 | 89 | func Test_marshal_map_of_string_key(t *testing.T) { 90 | should := require.New(t) 91 | for _, c := range test.MarshalCombinations { 92 | m := map[string]int64{ 93 | "1": 1, 94 | } 95 | 96 | output, err := c.Marshal(m) 97 | should.NoError(err) 98 | output1, err := c.Marshal(&m) 99 | should.NoError(err) 100 | should.Equal(output, output1) 101 | var val general.Map 102 | should.NoError(c.Unmarshal(output, &val)) 103 | should.Equal(general.Map{ 104 | "1": int64(1), 105 | }, val) 106 | } 107 | } 108 | -------------------------------------------------------------------------------- /test/level_2/map_of_struct_test.go: -------------------------------------------------------------------------------- 1 | package test 2 | 3 | import ( 4 | "testing" 5 | "github.com/stretchr/testify/require" 6 | "git.apache.org/thrift.git/lib/go/thrift" 7 | "github.com/thrift-iterator/go/protocol" 8 | "github.com/thrift-iterator/go/test" 9 | "github.com/thrift-iterator/go/test/level_2/map_of_struct_test" 10 | "github.com/thrift-iterator/go/general" 11 | ) 12 | 13 | func Test_skip_map_of_struct(t *testing.T) { 14 | should := require.New(t) 15 | for _, c := range test.Combinations { 16 | buf, proto := c.CreateProtocol() 17 | proto.WriteMapBegin(thrift.I64, thrift.STRUCT, 1) 18 | proto.WriteI64(1) 19 | 20 | proto.WriteStructBegin("hello") 21 | proto.WriteFieldBegin("field1", 
thrift.I64, 1) 22 | proto.WriteI64(1024) 23 | proto.WriteFieldEnd() 24 | proto.WriteFieldStop() 25 | proto.WriteStructEnd() 26 | 27 | proto.WriteMapEnd() 28 | iter := c.CreateIterator(buf.Bytes()) 29 | should.Equal(buf.Bytes(), iter.SkipMap(nil)) 30 | } 31 | } 32 | 33 | func Test_unmarshal_general_map_of_struct(t *testing.T) { 34 | should := require.New(t) 35 | for _, c := range test.Combinations { 36 | buf, proto := c.CreateProtocol() 37 | proto.WriteMapBegin(thrift.I64, thrift.STRUCT, 1) 38 | proto.WriteI64(1) 39 | 40 | proto.WriteStructBegin("hello") 41 | proto.WriteFieldBegin("field1", thrift.I64, 1) 42 | proto.WriteI64(1024) 43 | proto.WriteFieldEnd() 44 | proto.WriteFieldStop() 45 | proto.WriteStructEnd() 46 | 47 | proto.WriteMapEnd() 48 | var val general.Map 49 | should.NoError(c.Unmarshal(buf.Bytes(), &val)) 50 | should.Equal(general.Struct{ 51 | protocol.FieldId(1): int64(1024), 52 | }, val[int64(1)]) 53 | } 54 | } 55 | 56 | func Test_unmarshal_map_of_struct(t *testing.T) { 57 | should := require.New(t) 58 | for _, c := range test.UnmarshalCombinations { 59 | buf, proto := c.CreateProtocol() 60 | proto.WriteMapBegin(thrift.I64, thrift.STRUCT, 1) 61 | proto.WriteI64(1) 62 | 63 | proto.WriteStructBegin("hello") 64 | proto.WriteFieldBegin("field1", thrift.I64, 1) 65 | proto.WriteI64(1024) 66 | proto.WriteFieldEnd() 67 | proto.WriteFieldStop() 68 | proto.WriteStructEnd() 69 | 70 | proto.WriteMapEnd() 71 | var val map[int64]map_of_struct_test.TestObject 72 | should.NoError(c.Unmarshal(buf.Bytes(), &val)) 73 | should.Equal(map[int64]map_of_struct_test.TestObject{ 74 | 1: {1024}, 75 | }, val) 76 | } 77 | } 78 | 79 | func Test_marshal_general_map_of_struct(t *testing.T) { 80 | should := require.New(t) 81 | for _, c := range test.Combinations { 82 | m := general.Map{ 83 | int64(1): general.Struct { 84 | protocol.FieldId(1): int64(1024), 85 | }, 86 | } 87 | 88 | output, err := c.Marshal(m) 89 | should.NoError(err) 90 | output1, err := c.Marshal(&m) 91 | 
should.NoError(err) 92 | should.Equal(output, output1) 93 | var val general.Map 94 | should.NoError(c.Unmarshal(output, &val)) 95 | should.Equal(general.Struct{ 96 | protocol.FieldId(1): int64(1024), 97 | }, val[int64(1)]) 98 | } 99 | } 100 | 101 | func Test_marshal_map_of_struct(t *testing.T) { 102 | should := require.New(t) 103 | for _, c := range test.MarshalCombinations { 104 | m := map[int64]map_of_struct_test.TestObject{ 105 | 1: {1024}, 106 | } 107 | 108 | output, err := c.Marshal(m) 109 | should.NoError(err) 110 | output1, err := c.Marshal(&m) 111 | should.NoError(err) 112 | should.Equal(output, output1) 113 | var val general.Map 114 | should.NoError(c.Unmarshal(output, &val)) 115 | should.Equal(general.Struct{ 116 | protocol.FieldId(1): int64(1024), 117 | }, val[int64(1)]) 118 | } 119 | } -------------------------------------------------------------------------------- /test/level_2/map_of_struct_test/TestObject.go: -------------------------------------------------------------------------------- 1 | package map_of_struct_test 2 | 3 | type TestObject struct { 4 | Field1 int64 `thrift:",1"` 5 | } -------------------------------------------------------------------------------- /test/level_2/message_test.go: -------------------------------------------------------------------------------- 1 | package test 2 | 3 | import ( 4 | "testing" 5 | "github.com/stretchr/testify/require" 6 | "git.apache.org/thrift.git/lib/go/thrift" 7 | "github.com/thrift-iterator/go/protocol" 8 | "github.com/thrift-iterator/go/test" 9 | "github.com/thrift-iterator/go/general" 10 | ) 11 | 12 | func Test_skip_message(t *testing.T) { 13 | should := require.New(t) 14 | for _, c := range test.Combinations { 15 | buf, proto := c.CreateProtocol() 16 | proto.WriteMessageBegin("hello", thrift.CALL, 17) 17 | proto.WriteStructBegin("args") 18 | proto.WriteFieldBegin("field1", thrift.I64, 1) 19 | proto.WriteI64(1) 20 | proto.WriteFieldBegin("field2", thrift.I64, 2) 21 | proto.WriteI64(2) 22 | 
proto.WriteFieldEnd() 23 | proto.WriteFieldStop() 24 | proto.WriteStructEnd() 25 | proto.WriteMessageEnd() 26 | iter := c.CreateIterator(buf.Bytes()) 27 | should.Equal(buf.Bytes(), iter.SkipStruct(iter.SkipMessageHeader(nil))) 28 | } 29 | } 30 | 31 | func Test_unmarshal_message(t *testing.T) { 32 | should := require.New(t) 33 | for _, c := range test.Combinations { 34 | buf, proto := c.CreateProtocol() 35 | proto.WriteMessageBegin("hello", thrift.CALL, 17) 36 | proto.WriteStructBegin("args") 37 | proto.WriteFieldBegin("field1", thrift.I64, 1) 38 | proto.WriteI64(1) 39 | proto.WriteFieldBegin("field2", thrift.I64, 2) 40 | proto.WriteI64(2) 41 | proto.WriteFieldEnd() 42 | proto.WriteFieldStop() 43 | proto.WriteStructEnd() 44 | proto.WriteMessageEnd() 45 | var msg general.Message 46 | should.NoError(c.Unmarshal(buf.Bytes(), &msg)) 47 | should.Equal("hello", msg.MessageName) 48 | should.Equal(protocol.MessageTypeCall, msg.MessageType) 49 | should.Equal(protocol.SeqId(17), msg.SeqId) 50 | should.Equal(int64(1), msg.Arguments[protocol.FieldId(1)]) 51 | should.Equal(int64(2), msg.Arguments[protocol.FieldId(2)]) 52 | } 53 | } 54 | 55 | func Test_marshal_message(t *testing.T) { 56 | should := require.New(t) 57 | for _, c := range test.Combinations { 58 | output, err := c.Marshal(general.Message{ 59 | MessageHeader: protocol.MessageHeader{ 60 | MessageType: protocol.MessageTypeCall, 61 | MessageName: "hello", 62 | SeqId: protocol.SeqId(17), 63 | }, 64 | Arguments: general.Struct{ 65 | protocol.FieldId(1): int64(1), 66 | protocol.FieldId(2): int64(2), 67 | }, 68 | }) 69 | should.NoError(err) 70 | var msg general.Message 71 | should.NoError(c.Unmarshal(output, &msg)) 72 | should.Equal("hello", msg.MessageName) 73 | should.Equal(protocol.MessageTypeCall, msg.MessageType) 74 | should.Equal(protocol.SeqId(17), msg.SeqId) 75 | should.Equal(int64(1), msg.Arguments[protocol.FieldId(1)]) 76 | should.Equal(int64(2), msg.Arguments[protocol.FieldId(2)]) 77 | } 78 | } 79 | 
-------------------------------------------------------------------------------- /test/level_2/struct_complex_test.go: -------------------------------------------------------------------------------- 1 | package test 2 | 3 | import ( 4 | "github.com/stretchr/testify/require" 5 | "github.com/thrift-iterator/go/test" 6 | "github.com/thrift-iterator/go/test/level_2/struct_complex_test" 7 | "testing" 8 | ) 9 | 10 | func Test_marshal_struct_complex(t *testing.T) { 11 | should := require.New(t) 12 | for _, c := range test.MarshalCombinations[:] { 13 | var obj struct_complex_test.TestObject 14 | obj.Av = false 15 | obj.Ap = &obj.Av 16 | obj.Bv = 1 17 | obj.Bp = &obj.Bv 18 | obj.Cv = 2 19 | obj.Cp = &obj.Cv 20 | obj.Dv = 3 21 | obj.Dp = &obj.Dv 22 | obj.Ev = 4 23 | obj.Ep = &obj.Ev 24 | obj.Fv = 5 25 | obj.Fp = &obj.Fv 26 | obj.Gv = 3.1415926 27 | obj.Gp = &obj.Gv 28 | obj.Hv = "6" // 15 29 | obj.Hp = &obj.Hv // 16 30 | obj.Iv = []byte{7} // 17 31 | obj.Ip = &obj.Iv // 18 32 | obj.Jv = []string{"8"} // 19 33 | obj.Jp = &obj.Jv // 20 34 | obj.Kv = map[string]bool{"9": true} // 21 35 | obj.Kp = &obj.Kv 36 | obj.Lv = map[int32]struct_complex_test.SubType{10: {A: 10}} 37 | obj.Lp = &obj.Lv 38 | obj.Mv = map[int32]map[int32]string{ 39 | 101: {102: "103"}, 40 | } 41 | obj.Mp = &obj.Mv 42 | obj.Nv = [][]string{ 43 | {"201", "202"}, 44 | } 45 | obj.Np = &obj.Nv 46 | obj.Ov = 11 47 | obj.Op = &obj.Ov 48 | obj.Pv = struct_complex_test.Enum_B 49 | obj.Pp = &obj.Pv 50 | obj.Qv = map[int32][]string{ 51 | 12: {"1201", "1201"}, 52 | } 53 | obj.Qp = &obj.Qv 54 | obj.Rv = []map[string][]map[string]int32{ 55 | {"foo": []map[string]int32{ 56 | {"foo1": 1801}, 57 | {"foo2": 1802}, 58 | }}, 59 | {"bar": []map[string]int32{ 60 | {"bar1": 1803}, 61 | {"bar2": 1804}, 62 | }}, 63 | } 64 | obj.Rp = &obj.Rv 65 | 66 | output, err := c.Marshal(obj) 67 | should.NoError(err) 68 | output1, err := c.Marshal(&obj) 69 | should.NoError(err) 70 | should.Equal(output, output1) 71 | 72 | var val 
*struct_complex_test.TestObject 73 | should.NoError(c.Unmarshal(output, &val)) 74 | 75 | should.Equal(obj.Av, val.Av) 76 | should.Equal(*obj.Ap, *val.Ap) 77 | should.Equal(obj.Bv, val.Bv) 78 | should.Equal(obj.Bv, *val.Bp) 79 | should.Equal(obj.Cv, val.Cv) 80 | should.Equal(obj.Cv, *val.Cp) 81 | should.Equal(obj.Dv, val.Dv) 82 | should.Equal(obj.Dv, *val.Dp) 83 | should.Equal(obj.Ev, val.Ev) 84 | should.Equal(obj.Ev, *val.Ep) 85 | should.Equal(obj.Fv, val.Fv) 86 | should.Equal(obj.Fv, *val.Fp) 87 | should.Equal(obj.Gv, val.Gv) 88 | should.Equal(obj.Gv, *val.Gp) 89 | should.Equal(obj.Hv, val.Hv) 90 | should.Equal(obj.Hv, *val.Hp) 91 | should.Equal(obj.Iv, val.Iv) 92 | should.Equal(obj.Iv, *val.Ip) 93 | should.Equal(obj.Jv, val.Jv) 94 | should.Equal(obj.Jv, *val.Jp) 95 | should.Equal(obj.Kv, val.Kv) 96 | should.Equal(obj.Kv, *val.Kp) 97 | should.Equal(obj.Lv, val.Lv) 98 | should.Equal(obj.Lv, *val.Lp) 99 | should.Equal(obj.Mv, val.Mv) 100 | should.Equal(obj.Mv, *val.Mp) 101 | should.Equal(obj.Nv, val.Nv) 102 | should.Equal(obj.Nv, *val.Np) 103 | should.Equal(obj.Ov, val.Ov) 104 | should.Equal(obj.Ov, *val.Op) 105 | should.Equal(obj.Pv, val.Pv) 106 | should.Equal(obj.Pv, *val.Pp) 107 | should.Equal(obj.Qv, val.Qv) 108 | should.Equal(obj.Qv, *val.Qp) 109 | should.Equal(obj.Rv, val.Rv) 110 | should.Equal(obj.Rv, *val.Rp) 111 | } 112 | } 113 | -------------------------------------------------------------------------------- /test/level_2/struct_complex_test/TestObject.go: -------------------------------------------------------------------------------- 1 | package struct_complex_test 2 | 3 | type SubType struct { 4 | A int32 `thrift:"a,1"` 5 | } 6 | 7 | type Enum int32 8 | 9 | const ( 10 | Enum_A Enum = 1 11 | 12 | Enum_B Enum = 2 13 | ) 14 | 15 | type Int int32 16 | 17 | type TestObject struct { 18 | Av bool `thrift:"av,0"` 19 | Ap *bool `thrift:"ap,2,optional"` 20 | Bv int8 `thrift:"bv,3"` 21 | Bp *int8 `thrift:"bp,4,optional"` 22 | Cv int8 `thrift:"cv,5"` 23 | Cp *int8 
`thrift:"cp,6,optional"` 24 | Dv int16 `thrift:"dv,7"` 25 | Dp *int16 `thrift:"dp,8,optional"` 26 | Ev int32 `thrift:"ev,9"` 27 | Ep *int32 `thrift:"ep,10,optional"` 28 | Fv int64 `thrift:"fv,11"` 29 | Fp *int64 `thrift:"fp,12,optional"` 30 | Gv float64 `thrift:"gv,13"` 31 | Gp *float64 `thrift:"gp,14,optional"` 32 | Hv string `thrift:"hv,15"` 33 | Hp *string `thrift:"hp,16,optional"` 34 | Iv []byte `thrift:"iv,17,optional"` 35 | Ip *[]byte `thrift:"ip,18,optional"` 36 | Jv []string `thrift:"jv,19,optional"` 37 | Jp *[]string `thrift:"jp,20,optional"` 38 | Kv map[string]bool `thrift:"kv,21,optional"` 39 | Kp *map[string]bool `thrift:"kp,22,optional"` 40 | Lv map[int32]SubType `thrift:"lv,23,optional"` 41 | Lp *map[int32]SubType `thrift:"lp,24,optional"` 42 | Mv map[int32]map[int32]string `thrift:"mv,25,optional"` 43 | Mp *map[int32]map[int32]string `thrift:"mp,26,optional"` 44 | Nv [][]string `thrift:"nv,27,optional"` 45 | Np *[][]string `thrift:"np,28,optional"` 46 | Ov Int `thrift:"ov,29"` 47 | Op *Int `thrift:"op,30,optional"` 48 | Pv Enum `thrift:"pv,31"` 49 | Pp *Enum `thrift:"pp,32,optional"` 50 | Qv map[int32][]string `thrift:"qv,33,optional"` 51 | Qp *map[int32][]string `thrift:"qp,34,optional"` 52 | Rv []map[string][]map[string]int32 `thrift:"rv,35,optional"` 53 | Rp *[]map[string][]map[string]int32 `thrift:"rp,36,optional"` 54 | } 55 | -------------------------------------------------------------------------------- /test/level_2/struct_of_list_test.go: -------------------------------------------------------------------------------- 1 | package test 2 | 3 | import ( 4 | "testing" 5 | "github.com/stretchr/testify/require" 6 | "git.apache.org/thrift.git/lib/go/thrift" 7 | "github.com/thrift-iterator/go/protocol" 8 | "github.com/thrift-iterator/go/test" 9 | "github.com/thrift-iterator/go/test/level_2/struct_of_list_test" 10 | "github.com/thrift-iterator/go/general" 11 | ) 12 | 13 | func Test_skip_struct_of_list(t *testing.T) { 14 | should := require.New(t) 15 
| for _, c := range test.Combinations { 16 | buf, proto := c.CreateProtocol() 17 | proto.WriteStructBegin("hello") 18 | proto.WriteFieldBegin("field1", thrift.LIST, 1) 19 | proto.WriteListBegin(thrift.I64, 1) 20 | proto.WriteI64(1) 21 | proto.WriteListEnd() 22 | proto.WriteFieldEnd() 23 | proto.WriteFieldStop() 24 | proto.WriteStructEnd() 25 | iter := c.CreateIterator(buf.Bytes()) 26 | should.Equal(buf.Bytes(), iter.SkipStruct(nil)) 27 | } 28 | } 29 | 30 | func Test_unmarshal_general_struct_of_list(t *testing.T) { 31 | should := require.New(t) 32 | for _, c := range test.Combinations { 33 | buf, proto := c.CreateProtocol() 34 | proto.WriteStructBegin("hello") 35 | proto.WriteFieldBegin("field1", thrift.LIST, 1) 36 | proto.WriteListBegin(thrift.I64, 1) 37 | proto.WriteI64(1) 38 | proto.WriteListEnd() 39 | proto.WriteFieldEnd() 40 | proto.WriteFieldStop() 41 | proto.WriteStructEnd() 42 | var val general.Struct 43 | should.NoError(c.Unmarshal(buf.Bytes(), &val)) 44 | should.Equal(general.List{int64(1)}, val[protocol.FieldId(1)]) 45 | } 46 | } 47 | 48 | func Test_unmarshal_struct_of_list(t *testing.T) { 49 | should := require.New(t) 50 | for _, c := range test.UnmarshalCombinations { 51 | buf, proto := c.CreateProtocol() 52 | proto.WriteStructBegin("hello") 53 | proto.WriteFieldBegin("field1", thrift.LIST, 1) 54 | proto.WriteListBegin(thrift.I64, 1) 55 | proto.WriteI64(1) 56 | proto.WriteListEnd() 57 | proto.WriteFieldEnd() 58 | proto.WriteFieldStop() 59 | proto.WriteStructEnd() 60 | var val struct_of_list_test.TestObject 61 | should.NoError(c.Unmarshal(buf.Bytes(), &val)) 62 | should.Equal(struct_of_list_test.TestObject{ 63 | []int64{1}, 64 | }, val) 65 | } 66 | } 67 | 68 | func Test_marshal_general_struct_of_list(t *testing.T) { 69 | should := require.New(t) 70 | for _, c := range test.Combinations { 71 | obj := general.Struct { 72 | protocol.FieldId(1): general.List { 73 | int64(1), 74 | }, 75 | } 76 | 77 | output, err := c.Marshal(obj) 78 | should.NoError(err) 79 | 
output1, err := c.Marshal(&obj) 80 | should.NoError(err) 81 | should.Equal(output, output1) 82 | var val general.Struct 83 | should.NoError(c.Unmarshal(output, &val)) 84 | should.Equal(general.List{int64(1)}, val[protocol.FieldId(1)]) 85 | } 86 | } 87 | 88 | func Test_marshal_struct_of_list(t *testing.T) { 89 | should := require.New(t) 90 | for _, c := range test.MarshalCombinations { 91 | obj := struct_of_list_test.TestObject{ 92 | []int64{1}, 93 | } 94 | 95 | output, err := c.Marshal(obj) 96 | should.NoError(err) 97 | output1, err := c.Marshal(&obj) 98 | should.NoError(err) 99 | should.Equal(output, output1) 100 | var val general.Struct 101 | should.NoError(c.Unmarshal(output, &val)) 102 | should.Equal(general.List{int64(1)}, val[protocol.FieldId(1)]) 103 | } 104 | } 105 | -------------------------------------------------------------------------------- /test/level_2/struct_of_list_test/TestObject.go: -------------------------------------------------------------------------------- 1 | package struct_of_list_test 2 | 3 | type TestObject struct { 4 | Field1 []int64 `thrift:",1"` 5 | } -------------------------------------------------------------------------------- /test/level_2/struct_of_map_test.go: -------------------------------------------------------------------------------- 1 | package test 2 | 3 | import ( 4 | "testing" 5 | "github.com/stretchr/testify/require" 6 | "git.apache.org/thrift.git/lib/go/thrift" 7 | "github.com/thrift-iterator/go/protocol" 8 | "github.com/thrift-iterator/go/test" 9 | "github.com/thrift-iterator/go/test/level_2/struct_of_map_test" 10 | "github.com/thrift-iterator/go/general" 11 | ) 12 | 13 | func Test_skip_struct_of_map(t *testing.T) { 14 | should := require.New(t) 15 | for _, c := range test.Combinations { 16 | buf, proto := c.CreateProtocol() 17 | proto.WriteStructBegin("hello") 18 | proto.WriteFieldBegin("field1", thrift.MAP, 1) 19 | proto.WriteMapBegin(thrift.I32, thrift.I64, 1) 20 | proto.WriteI32(2) 21 | proto.WriteI64(2) 22 | 
proto.WriteMapEnd() 23 | proto.WriteFieldEnd() 24 | proto.WriteFieldStop() 25 | proto.WriteStructEnd() 26 | iter := c.CreateIterator(buf.Bytes()) 27 | should.Equal(buf.Bytes(), iter.SkipStruct(nil)) 28 | } 29 | } 30 | 31 | func Test_unmarshal_general_struct_of_map(t *testing.T) { 32 | should := require.New(t) 33 | for _, c := range test.Combinations { 34 | buf, proto := c.CreateProtocol() 35 | proto.WriteStructBegin("hello") 36 | proto.WriteFieldBegin("field1", thrift.MAP, 1) 37 | proto.WriteMapBegin(thrift.I32, thrift.I64, 1) 38 | proto.WriteI32(2) 39 | proto.WriteI64(2) 40 | proto.WriteMapEnd() 41 | proto.WriteFieldEnd() 42 | proto.WriteFieldStop() 43 | proto.WriteStructEnd() 44 | var val general.Struct 45 | should.NoError(c.Unmarshal(buf.Bytes(), &val)) 46 | should.Equal(general.Map{ 47 | int32(2): int64(2), 48 | }, val[protocol.FieldId(1)]) 49 | } 50 | } 51 | 52 | func Test_unmarshal_struct_of_map(t *testing.T) { 53 | should := require.New(t) 54 | for _, c := range test.UnmarshalCombinations { 55 | buf, proto := c.CreateProtocol() 56 | proto.WriteStructBegin("hello") 57 | proto.WriteFieldBegin("field1", thrift.MAP, 1) 58 | proto.WriteMapBegin(thrift.I32, thrift.I64, 1) 59 | proto.WriteI32(2) 60 | proto.WriteI64(2) 61 | proto.WriteMapEnd() 62 | proto.WriteFieldEnd() 63 | proto.WriteFieldStop() 64 | proto.WriteStructEnd() 65 | var val struct_of_map_test.TestObject 66 | should.NoError(c.Unmarshal(buf.Bytes(), &val)) 67 | should.Equal(struct_of_map_test.TestObject{ 68 | map[int32]int64{2: 2}, 69 | }, val) 70 | } 71 | } 72 | 73 | func Test_marshal_general_struct_of_map(t *testing.T) { 74 | should := require.New(t) 75 | for _, c := range test.Combinations { 76 | m := general.Struct{ 77 | protocol.FieldId(1): general.Map{ 78 | int32(2): int64(2), 79 | }, 80 | } 81 | 82 | output, err := c.Marshal(m) 83 | should.NoError(err) 84 | output1, err := c.Marshal(&m) 85 | should.NoError(err) 86 | should.Equal(output, output1) 87 | var val general.Struct 88 | 
should.NoError(c.Unmarshal(output, &val)) 89 | should.Equal(general.Map{ 90 | int32(2): int64(2), 91 | }, val[protocol.FieldId(1)]) 92 | } 93 | } 94 | 95 | func Test_marshal_struct_of_map(t *testing.T) { 96 | should := require.New(t) 97 | for _, c := range test.MarshalCombinations { 98 | m := struct_of_map_test.TestObject{ 99 | map[int32]int64{2: 2}, 100 | } 101 | 102 | output, err := c.Marshal(m) 103 | should.NoError(err) 104 | output1, err := c.Marshal(&m) 105 | should.NoError(err) 106 | should.Equal(output, output1) 107 | var val general.Struct 108 | should.NoError(c.Unmarshal(output, &val)) 109 | should.Equal(general.Map{ 110 | int32(2): int64(2), 111 | }, val[protocol.FieldId(1)]) 112 | } 113 | } 114 | -------------------------------------------------------------------------------- /test/level_2/struct_of_map_test/TestObject.go: -------------------------------------------------------------------------------- 1 | package struct_of_map_test 2 | 3 | type TestObject struct { 4 | Field1 map[int32]int64 `thrift:",1"` 5 | } -------------------------------------------------------------------------------- /test/level_2/struct_of_pointer_test.go: -------------------------------------------------------------------------------- 1 | package test 2 | 3 | import ( 4 | "testing" 5 | "github.com/stretchr/testify/require" 6 | "github.com/thrift-iterator/go/test" 7 | "git.apache.org/thrift.git/lib/go/thrift" 8 | "github.com/thrift-iterator/go/test/level_2/struct_of_pointer_test" 9 | ) 10 | 11 | func Test_unmarshal_struct_of_1_ptr(t *testing.T) { 12 | should := require.New(t) 13 | for _, c := range test.UnmarshalCombinations { 14 | buf, proto := c.CreateProtocol() 15 | proto.WriteStructBegin("hello") 16 | proto.WriteFieldBegin("field1", thrift.I64, 1) 17 | proto.WriteI64(1) 18 | proto.WriteFieldEnd() 19 | proto.WriteFieldStop() 20 | proto.WriteStructEnd() 21 | var val *struct_of_pointer_test.StructOf1Ptr 22 | should.NoError(c.Unmarshal(buf.Bytes(), &val)) 23 | should.Equal(1, 
*val.Field1) 24 | } 25 | } 26 | 27 | func Test_unmarshal_struct_of_2_ptr(t *testing.T) { 28 | should := require.New(t) 29 | for _, c := range test.UnmarshalCombinations { 30 | buf, proto := c.CreateProtocol() 31 | proto.WriteStructBegin("hello") 32 | proto.WriteFieldBegin("field1", thrift.I64, 1) 33 | proto.WriteI64(1) 34 | proto.WriteFieldEnd() 35 | proto.WriteFieldBegin("field2", thrift.I64, 2) 36 | proto.WriteI64(2) 37 | proto.WriteFieldEnd() 38 | proto.WriteFieldStop() 39 | proto.WriteStructEnd() 40 | var val *struct_of_pointer_test.StructOf2Ptr 41 | should.NoError(c.Unmarshal(buf.Bytes(), &val)) 42 | should.Equal(1, *val.Field1) 43 | should.Equal(2, *val.Field2) 44 | } 45 | } 46 | 47 | func Test_marshal_struct_of_1_ptr(t *testing.T) { 48 | should := require.New(t) 49 | for _, c := range test.MarshalCombinations { 50 | one := 1 51 | obj := struct_of_pointer_test.StructOf1Ptr{ 52 | &one, 53 | } 54 | 55 | output, err := c.Marshal(obj) 56 | should.NoError(err) 57 | output1, err := c.Marshal(&obj) 58 | should.NoError(err) 59 | should.Equal(output, output1) 60 | var val *struct_of_pointer_test.StructOf1Ptr 61 | should.NoError(c.Unmarshal(output, &val)) 62 | should.Equal(1, *val.Field1) 63 | } 64 | } 65 | 66 | func Test_marshal_struct_of_2_ptr(t *testing.T) { 67 | should := require.New(t) 68 | for _, c := range test.MarshalCombinations { 69 | one := 1 70 | two := 2 71 | obj := struct_of_pointer_test.StructOf2Ptr{ 72 | &one, &two, 73 | } 74 | 75 | output, err := c.Marshal(obj) 76 | should.NoError(err) 77 | output1, err := c.Marshal(&obj) 78 | should.NoError(err) 79 | should.Equal(output, output1) 80 | var val *struct_of_pointer_test.StructOf2Ptr 81 | should.NoError(c.Unmarshal(output, &val)) 82 | should.Equal(1, *val.Field1) 83 | should.Equal(2, *val.Field2) 84 | } 85 | } -------------------------------------------------------------------------------- /test/level_2/struct_of_pointer_test/StructOf1Ptr.go: 
-------------------------------------------------------------------------------- 1 | package struct_of_pointer_test 2 | 3 | type StructOf1Ptr struct { 4 | Field1 *int `thrift:",1"` 5 | } -------------------------------------------------------------------------------- /test/level_2/struct_of_pointer_test/StructOf2Ptr.go: -------------------------------------------------------------------------------- 1 | package struct_of_pointer_test 2 | 3 | type StructOf2Ptr struct { 4 | Field1 *int `thrift:",1"` 5 | Field2 *int `thrift:",2"` 6 | } -------------------------------------------------------------------------------- /test/level_2/struct_of_string_test.go: -------------------------------------------------------------------------------- 1 | package test 2 | 3 | import ( 4 | "testing" 5 | "github.com/stretchr/testify/require" 6 | "git.apache.org/thrift.git/lib/go/thrift" 7 | "github.com/thrift-iterator/go/protocol" 8 | "github.com/thrift-iterator/go/test" 9 | "github.com/thrift-iterator/go/test/level_2/struct_of_string_test" 10 | "github.com/thrift-iterator/go/general" 11 | ) 12 | 13 | func Test_skip_struct_of_string(t *testing.T) { 14 | should := require.New(t) 15 | for _, c := range test.Combinations { 16 | buf, proto := c.CreateProtocol() 17 | proto.WriteStructBegin("hello") 18 | proto.WriteFieldBegin("field1", thrift.STRING, 1) 19 | proto.WriteString("abc") 20 | proto.WriteFieldEnd() 21 | proto.WriteFieldStop() 22 | proto.WriteStructEnd() 23 | iter := c.CreateIterator(buf.Bytes()) 24 | should.Equal(buf.Bytes(), iter.SkipStruct(nil)) 25 | } 26 | } 27 | 28 | func Test_unmarshal_general_struct_of_string(t *testing.T) { 29 | should := require.New(t) 30 | for _, c := range test.Combinations { 31 | buf, proto := c.CreateProtocol() 32 | proto.WriteStructBegin("hello") 33 | proto.WriteFieldBegin("field1", thrift.STRING, 1) 34 | proto.WriteString("abc") 35 | proto.WriteFieldEnd() 36 | proto.WriteFieldStop() 37 | proto.WriteStructEnd() 38 | var val general.Struct 39 | 
should.NoError(c.Unmarshal(buf.Bytes(), &val)) 40 | should.Equal("abc", val[protocol.FieldId(1)]) 41 | } 42 | } 43 | 44 | func Test_unmarshal_struct_of_string(t *testing.T) { 45 | should := require.New(t) 46 | for _, c := range test.UnmarshalCombinations { 47 | buf, proto := c.CreateProtocol() 48 | proto.WriteStructBegin("hello") 49 | proto.WriteFieldBegin("field1", thrift.STRING, 1) 50 | proto.WriteString("abc") 51 | proto.WriteFieldEnd() 52 | proto.WriteFieldStop() 53 | proto.WriteStructEnd() 54 | var val struct_of_string_test.TestObject 55 | should.NoError(c.Unmarshal(buf.Bytes(), &val)) 56 | should.Equal(struct_of_string_test.TestObject{ 57 | "abc", 58 | }, val) 59 | } 60 | } 61 | 62 | func Test_marshal_general_struct_of_string(t *testing.T) { 63 | should := require.New(t) 64 | for _, c := range test.Combinations { 65 | obj := general.Struct{ 66 | protocol.FieldId(1): "abc", 67 | } 68 | 69 | output, err := c.Marshal(obj) 70 | should.NoError(err) 71 | output1, err := c.Marshal(&obj) 72 | should.NoError(err) 73 | should.Equal(output, output1) 74 | var val general.Struct 75 | should.NoError(c.Unmarshal(output, &val)) 76 | should.Equal("abc", val[protocol.FieldId(1)]) 77 | } 78 | } 79 | 80 | func Test_marshal_struct_of_string(t *testing.T) { 81 | should := require.New(t) 82 | for _, c := range test.MarshalCombinations { 83 | obj := struct_of_string_test.TestObject{ 84 | "abc", 85 | } 86 | 87 | output, err := c.Marshal(obj) 88 | should.NoError(err) 89 | output1, err := c.Marshal(&obj) 90 | should.NoError(err) 91 | should.Equal(output, output1) 92 | var val general.Struct 93 | should.NoError(c.Unmarshal(output, &val)) 94 | should.Equal("abc", val[protocol.FieldId(1)]) 95 | } 96 | } -------------------------------------------------------------------------------- /test/level_2/struct_of_string_test/TestObject.go: -------------------------------------------------------------------------------- 1 | package struct_of_string_test 2 | 3 | type TestObject struct { 4 | Field1 
string `thrift:",1"` 5 | } -------------------------------------------------------------------------------- /test/level_2/struct_of_struct_test.go: -------------------------------------------------------------------------------- 1 | package test 2 | 3 | import ( 4 | "testing" 5 | "github.com/stretchr/testify/require" 6 | "git.apache.org/thrift.git/lib/go/thrift" 7 | "github.com/thrift-iterator/go/protocol" 8 | "github.com/thrift-iterator/go/test" 9 | "github.com/thrift-iterator/go/test/level_2/struct_of_struct_test" 10 | "github.com/thrift-iterator/go/general" 11 | ) 12 | 13 | func Test_skip_struct_of_struct(t *testing.T) { 14 | should := require.New(t) 15 | for _, c := range test.Combinations { 16 | buf, proto := c.CreateProtocol() 17 | proto.WriteStructBegin("hello") 18 | proto.WriteFieldBegin("field1", thrift.STRUCT, 1) 19 | 20 | proto.WriteStructBegin("hello") 21 | proto.WriteFieldBegin("field1", thrift.STRING, 1) 22 | proto.WriteString("abc") 23 | proto.WriteFieldEnd() 24 | proto.WriteFieldStop() 25 | proto.WriteStructEnd() 26 | 27 | proto.WriteFieldEnd() 28 | proto.WriteFieldStop() 29 | proto.WriteStructEnd() 30 | iter := c.CreateIterator(buf.Bytes()) 31 | should.Equal(buf.Bytes(), iter.SkipStruct(nil)) 32 | } 33 | } 34 | 35 | func Test_unmarshal_general_struct_of_struct(t *testing.T) { 36 | should := require.New(t) 37 | for _, c := range test.Combinations { 38 | buf, proto := c.CreateProtocol() 39 | proto.WriteStructBegin("hello") 40 | proto.WriteFieldBegin("field1", thrift.STRUCT, 1) 41 | 42 | proto.WriteStructBegin("hello") 43 | proto.WriteFieldBegin("field1", thrift.STRING, 1) 44 | proto.WriteString("abc") 45 | proto.WriteFieldEnd() 46 | proto.WriteFieldStop() 47 | proto.WriteStructEnd() 48 | 49 | proto.WriteFieldEnd() 50 | proto.WriteFieldStop() 51 | proto.WriteStructEnd() 52 | var val general.Struct 53 | should.NoError(c.Unmarshal(buf.Bytes(), &val)) 54 | should.Equal(general.Struct{ 55 | protocol.FieldId(1): "abc", 56 | }, val[protocol.FieldId(1)]) 57 
| } 58 | } 59 | 60 | func Test_unmarshal_struct_of_struct(t *testing.T) { 61 | should := require.New(t) 62 | for _, c := range test.UnmarshalCombinations { 63 | buf, proto := c.CreateProtocol() 64 | proto.WriteStructBegin("hello") 65 | proto.WriteFieldBegin("field1", thrift.STRUCT, 1) 66 | 67 | proto.WriteStructBegin("hello") 68 | proto.WriteFieldBegin("field1", thrift.STRING, 1) 69 | proto.WriteString("abc") 70 | proto.WriteFieldEnd() 71 | proto.WriteFieldStop() 72 | proto.WriteStructEnd() 73 | 74 | proto.WriteFieldEnd() 75 | proto.WriteFieldStop() 76 | proto.WriteStructEnd() 77 | var val struct_of_struct_test.TestObject 78 | should.NoError(c.Unmarshal(buf.Bytes(), &val)) 79 | should.Equal(struct_of_struct_test.TestObject{ 80 | struct_of_struct_test.EmbeddedObject{"abc"}, 81 | }, val) 82 | } 83 | } 84 | 85 | func Test_marshal_general_struct_of_struct(t *testing.T) { 86 | should := require.New(t) 87 | for _, c := range test.Combinations { 88 | obj := general.Struct{ 89 | protocol.FieldId(1): general.Struct{ 90 | protocol.FieldId(1): "abc", 91 | }, 92 | } 93 | 94 | output, err := c.Marshal(obj) 95 | should.NoError(err) 96 | output1, err := c.Marshal(&obj) 97 | should.NoError(err) 98 | should.Equal(output, output1) 99 | var val general.Struct 100 | should.NoError(c.Unmarshal(output, &val)) 101 | should.Equal(general.Struct{ 102 | protocol.FieldId(1): "abc", 103 | }, val[protocol.FieldId(1)]) 104 | } 105 | } 106 | 107 | func Test_marshal_struct_of_struct(t *testing.T) { 108 | should := require.New(t) 109 | for _, c := range test.MarshalCombinations { 110 | obj := struct_of_struct_test.TestObject{ 111 | struct_of_struct_test.EmbeddedObject{"abc"}, 112 | } 113 | 114 | output, err := c.Marshal(obj) 115 | should.NoError(err) 116 | output1, err := c.Marshal(&obj) 117 | should.NoError(err) 118 | should.Equal(output, output1) 119 | var val general.Struct 120 | should.NoError(c.Unmarshal(output, &val)) 121 | should.Equal(general.Struct{ 122 | protocol.FieldId(1): "abc", 123 | 
}, val[protocol.FieldId(1)]) 124 | } 125 | } -------------------------------------------------------------------------------- /test/level_2/struct_of_struct_test/TestObject.go: -------------------------------------------------------------------------------- 1 | package struct_of_struct_test 2 | 3 | type TestObject struct { 4 | Field1 EmbeddedObject `thrift:",1"` 5 | } 6 | 7 | type EmbeddedObject struct { 8 | Field1 string `thrift:",1"` 9 | } --------------------------------------------------------------------------------